# Miscellaneous imports

In [1]:
import warnings
# Silence every warning (all categories, all modules) for the rest of the
# kernel session.
# NOTE(review): this also hides library deprecation warnings raised by the
# torch / saltsegm calls below; consider narrowing the filter to specific
# categories or modules.
warnings.filterwarnings('ignore')

In [2]:
import os
from functools import partial

import numpy as np
import pandas as pd
from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

# Config

In [3]:
import torch
import torch.nn as nn

from saltsegm.utils import ratio2groups
from saltsegm.torch_models.unet import get_UNet
from saltsegm.metrics import dice_score
from saltsegm.dataset import Dataset
from saltsegm.cross_validation import get_cv_111

In [4]:
# Where the preprocessed training data lives and where this experiment's
# per-fold outputs are written.
DATA_PATH = '/cobrain/groups/ml_group/data/dustpelt/salt_prep/train/'
EXP_PATH = '/cobrain/groups/ml_group/experiments/dustpelt/salt_segm/unet_baseline/'

# Inputs/target read from the dataset. The network gets one input channel per
# modality (n_channels) and n_classes output channels.
modalities = ['image-128']
target = 'target-128'
n_channels = len(modalities)
n_classes = 1

ds = Dataset(
    data_path=DATA_PATH,
    modalities=modalities,
    target=target,
)

# ================================================================================

# Cross-validation layout (passed to get_cv_111).
n_splits = 5
val_size = 100

# Training schedule: batch size, number of epochs, batches drawn per epoch.
batch_size = 32
epochs = 50
steps_per_epoch = 100

# Optimiser and ReduceLROnPlateau hyper-parameters.
lr_init = 1e-2
patience = 5
lr_factor = 0.3

# These are factories (functools.partial), not instances. Presumably
# TorchModel completes them with the model parameters / optimiser.
optim = partial(torch.optim.Adam, lr=lr_init)
lr_scheduler = partial(
    torch.optim.lr_scheduler.ReduceLROnPlateau,
    factor=lr_factor,
    patience=patience,
    verbose=True,
)

# ================================================================================

random_state = 42

# Per-sample target ratio, bucketed into groups that get_cv_111 uses when
# building the folds (presumably to balance folds by target coverage).
ratio = ds.metadata['target_ratio'].values
groups = ratio2groups(ratio)

cv_splits = get_cv_111(ds.ids, n_splits=n_splits, val_size=val_size,
                       groups=groups, random_state=random_state)

# random_state above only controls the CV split. The global numpy and torch
# RNGs were never seeded, so weight initialisation (and, presumably,
# BatchIter's batch sampling) differed on every run. Seed them here so fold
# results are reproducible. This is done AFTER get_cv_111 so the split is
# guaranteed to be unchanged: training reads it from memory via
# cv_splits[n_val]['train_ids'], and it must still match the split saved to disk.
np.random.seed(random_state)
torch.manual_seed(random_state)

# ================================================================================


# Sample loaders handed to BatchIter, taken from the dataset object.
load_x, load_y = ds.load_x, ds.load_y

# Training objective. BCEWithLogitsLoss applies the sigmoid itself, so it
# expects the model to emit raw logits.
loss_fn = nn.BCEWithLogitsLoss()

# Metric reported during training, and the name -> function map passed to
# calculate_metrics.
metric_fn = dice_score
metrics_dict = {'dice_score': metric_fn}

## Build experiment

In [5]:
from saltsegm.experiment import generate_experiment

# One-off setup: presumably writes the per-fold experiment_<i> layout under
# EXP_PATH from cv_splits and the dataset. The load_experiment_data calls below
# read from EXP_PATH. Left commented out, presumably because it has already
# been run and re-running could overwrite the saved splits -- confirm before
# uncommenting.
# generate_experiment(exp_path=EXP_PATH, cv_splits=cv_splits, dataset=ds)

# Run the experiment (one section per CV fold)

In [6]:
from saltsegm.experiment import load_experiment_data, make_predictions, calculate_metrics
from saltsegm.batch_iter import BatchIter
from saltsegm.torch_models.model import TorchModel

### val 0

In [7]:
# Fold index into cv_splits. This fold's outputs go under
# EXP_PATH/experiment_<n_val>.
n_val = 0
VAL_PATH = os.path.join(EXP_PATH, f'experiment_{n_val}')

In [8]:
# The fold's arrays are read from the experiment directory on disk. Only the
# validation pair is used here; x_train, y_train, x_test and y_test are loaded
# but not used anywhere in this notebook.
# NOTE(review): train ids come from the in-memory cv_splits, not from disk,
# so cv_splits must reproduce the split that was saved for this fold.
x_train, y_train, x_val, y_val, x_test, y_test = load_experiment_data(exp_path=EXP_PATH, n_val=n_val)
val_data = (x_val, y_val)

train_ids = cv_splits[n_val]['train_ids']

# Batch generator over the training ids, loading samples with load_x/load_y.
batch_iter = BatchIter(train_ids=train_ids, load_x=load_x, load_y=load_y,
                       batch_size=batch_size)

In [9]:
# A new UNet is built for this fold (presumably freshly initialised), wrapped
# with the loss, metric and optimiser/scheduler factories.
model = get_UNet(n_channels=n_channels, n_classes=n_classes)
model_torch = TorchModel(model=model, loss_fn=loss_fn, metric_fn=metric_fn,
                         optim=optim, lr_scheduler=lr_scheduler)

# `epochs` epochs of `steps_per_epoch` batches each. val_loss / val_metric on
# val_data are reported per epoch (see the log below).
history = model_torch.fit_generator(
    batch_iter.flow(), epochs=epochs, val_data=val_data,
    steps_per_epoch=steps_per_epoch, verbose=True
)

# torch.save on the module pickles the whole nn.Module, not just a state_dict.
# Loading it back requires the saltsegm model classes to be importable.
model_filename = os.path.join(VAL_PATH, 'model.pt')
torch.save(model_torch.model, model_filename)

# Presumably writes this fold's test predictions, then the per-metric results
# that get_experiment_result later reads from
# experiment_<i>/test_metrics/<metric>.json.
make_predictions(exp_path=EXP_PATH, n_val=n_val, model=model_torch)

calculate_metrics(exp_path=EXP_PATH, n_val=n_val, metrics_dict=metrics_dict)

epoch 1/50: 101it [00:36,  2.66it/s, val_loss=0.598502, val_metric=0.353]                         
epoch 2/50: 101it [00:35,  2.62it/s, val_loss=0.57563883, val_metric=0.37]                         
epoch 3/50: 101it [00:35,  2.61it/s, val_loss=0.3677874, val_metric=0.591]                         
epoch 4/50: 101it [00:35,  2.65it/s, val_loss=0.461218, val_metric=0.473]                         
epoch 5/50: 101it [00:35,  2.64it/s, val_loss=0.4069781, val_metric=0.501]                         
epoch 6/50: 101it [00:35,  2.60it/s, val_loss=0.41882268, val_metric=0.498]                         
epoch 7/50: 101it [00:35,  2.59it/s, val_loss=0.3614318, val_metric=0.601]                         
epoch 8/50: 101it [00:35,  2.61it/s, val_loss=0.4754977, val_metric=0.584]                         
epoch 9/50: 101it [00:35,  2.62it/s, val_loss=0.384659, val_metric=0.552]                         
epoch 10/50: 101it [00:35,  2.66it/s, val_loss=0.25505042, val_metric=0.658]                         


Epoch    34: reducing learning rate of group 0 to 3.0000e-03.


epoch 36/50: 101it [00:35,  2.62it/s, val_loss=0.18856698, val_metric=0.771]                         
epoch 37/50: 101it [00:35,  2.59it/s, val_loss=0.1457588, val_metric=0.773]                         
epoch 38/50: 101it [00:35,  2.66it/s, val_loss=0.15731198, val_metric=0.781]                         
epoch 39/50: 101it [00:35,  2.62it/s, val_loss=0.16010068, val_metric=0.735]                         
epoch 40/50: 101it [00:35,  2.63it/s, val_loss=0.20277686, val_metric=0.775]                         
epoch 41/50: 101it [00:35,  2.63it/s, val_loss=0.15700138, val_metric=0.785]                         
epoch 42/50: 101it [00:35,  2.62it/s, val_loss=0.16711503, val_metric=0.78]                         
epoch 43/50: 101it [00:35,  2.67it/s, val_loss=0.16938014, val_metric=0.783]                         
epoch 44/50:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch    42: reducing learning rate of group 0 to 9.0000e-04.


epoch 44/50: 101it [00:35,  2.61it/s, val_loss=0.15901853, val_metric=0.758]                         
epoch 45/50: 101it [00:35,  2.59it/s, val_loss=0.17106569, val_metric=0.766]                         
epoch 46/50: 101it [00:35,  2.64it/s, val_loss=0.16817418, val_metric=0.764]                         
epoch 47/50: 101it [00:35,  2.64it/s, val_loss=0.15852167, val_metric=0.772]                         
epoch 48/50: 101it [00:35,  2.65it/s, val_loss=0.16429971, val_metric=0.758]                         
epoch 49/50: 101it [00:35,  2.61it/s, val_loss=0.16388805, val_metric=0.775]                         
epoch 50/50:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch    48: reducing learning rate of group 0 to 2.7000e-04.


epoch 50/50: 101it [00:35,  2.63it/s, val_loss=0.16600184, val_metric=0.774]                         


### val 1

In [10]:
# Fold 1: same procedure as every other fold, with n_val = 1.
# TODO(review): the per-fold cells are copy-pasted. Consider a function run
# over range(n_splits).
n_val = 1
VAL_PATH = os.path.join(EXP_PATH, f'experiment_{n_val}')

In [None]:
# Only the validation pair is used; the other arrays are unused in this
# notebook.
# NOTE(review): train ids come from the in-memory cv_splits and must match the
# split saved on disk for this fold.
x_train, y_train, x_val, y_val, x_test, y_test = load_experiment_data(exp_path=EXP_PATH, n_val=n_val)
val_data = (x_val, y_val)

train_ids = cv_splits[n_val]['train_ids']

batch_iter = BatchIter(train_ids=train_ids, load_x=load_x, load_y=load_y,
                       batch_size=batch_size)

In [None]:
# A new UNet for this fold, trained, then saved as a pickled nn.Module.
model = get_UNet(n_channels=n_channels, n_classes=n_classes)
model_torch = TorchModel(model=model, loss_fn=loss_fn, metric_fn=metric_fn,
                         optim=optim, lr_scheduler=lr_scheduler)

history = model_torch.fit_generator(
    batch_iter.flow(), epochs=epochs, val_data=val_data,
    steps_per_epoch=steps_per_epoch, verbose=True
)

model_filename = os.path.join(VAL_PATH, 'model.pt')
torch.save(model_torch.model, model_filename)

# Presumably writes test predictions, then per-metric results for this fold.
make_predictions(exp_path=EXP_PATH, n_val=n_val, model=model_torch)

calculate_metrics(exp_path=EXP_PATH, n_val=n_val, metrics_dict=metrics_dict)

### val 2

In [None]:
# Fold 2: same procedure as every other fold, with n_val = 2.
# TODO(review): the per-fold cells are copy-pasted. Consider a function run
# over range(n_splits).
n_val = 2
VAL_PATH = os.path.join(EXP_PATH, f'experiment_{n_val}')

In [None]:
# Only the validation pair is used; the other arrays are unused in this
# notebook.
# NOTE(review): train ids come from the in-memory cv_splits and must match the
# split saved on disk for this fold.
x_train, y_train, x_val, y_val, x_test, y_test = load_experiment_data(exp_path=EXP_PATH, n_val=n_val)
val_data = (x_val, y_val)

train_ids = cv_splits[n_val]['train_ids']

batch_iter = BatchIter(train_ids=train_ids, load_x=load_x, load_y=load_y,
                       batch_size=batch_size)

In [None]:
# A new UNet for this fold, trained, then saved as a pickled nn.Module.
model = get_UNet(n_channels=n_channels, n_classes=n_classes)
model_torch = TorchModel(model=model, loss_fn=loss_fn, metric_fn=metric_fn,
                         optim=optim, lr_scheduler=lr_scheduler)

history = model_torch.fit_generator(
    batch_iter.flow(), epochs=epochs, val_data=val_data,
    steps_per_epoch=steps_per_epoch, verbose=True
)

model_filename = os.path.join(VAL_PATH, 'model.pt')
torch.save(model_torch.model, model_filename)

# Presumably writes test predictions, then per-metric results for this fold.
make_predictions(exp_path=EXP_PATH, n_val=n_val, model=model_torch)

calculate_metrics(exp_path=EXP_PATH, n_val=n_val, metrics_dict=metrics_dict)

### val 3

In [None]:
# Fold 3: same procedure as every other fold, with n_val = 3.
# TODO(review): the per-fold cells are copy-pasted. Consider a function run
# over range(n_splits).
n_val = 3
VAL_PATH = os.path.join(EXP_PATH, f'experiment_{n_val}')

In [None]:
# Only the validation pair is used; the other arrays are unused in this
# notebook.
# NOTE(review): train ids come from the in-memory cv_splits and must match the
# split saved on disk for this fold.
x_train, y_train, x_val, y_val, x_test, y_test = load_experiment_data(exp_path=EXP_PATH, n_val=n_val)
val_data = (x_val, y_val)

train_ids = cv_splits[n_val]['train_ids']

batch_iter = BatchIter(train_ids=train_ids, load_x=load_x, load_y=load_y,
                       batch_size=batch_size)

In [None]:
# A new UNet for this fold, trained, then saved as a pickled nn.Module.
model = get_UNet(n_channels=n_channels, n_classes=n_classes)
model_torch = TorchModel(model=model, loss_fn=loss_fn, metric_fn=metric_fn,
                         optim=optim, lr_scheduler=lr_scheduler)

history = model_torch.fit_generator(
    batch_iter.flow(), epochs=epochs, val_data=val_data,
    steps_per_epoch=steps_per_epoch, verbose=True
)

model_filename = os.path.join(VAL_PATH, 'model.pt')
torch.save(model_torch.model, model_filename)

# Presumably writes test predictions, then per-metric results for this fold.
make_predictions(exp_path=EXP_PATH, n_val=n_val, model=model_torch)

calculate_metrics(exp_path=EXP_PATH, n_val=n_val, metrics_dict=metrics_dict)

### val 4

In [None]:
# Fold 4: same procedure as every other fold, with n_val = 4.
# TODO(review): the per-fold cells are copy-pasted. Consider a function run
# over range(n_splits).
n_val = 4
VAL_PATH = os.path.join(EXP_PATH, f'experiment_{n_val}')

In [None]:
# Only the validation pair is used; the other arrays are unused in this
# notebook.
# NOTE(review): train ids come from the in-memory cv_splits and must match the
# split saved on disk for this fold.
x_train, y_train, x_val, y_val, x_test, y_test = load_experiment_data(exp_path=EXP_PATH, n_val=n_val)
val_data = (x_val, y_val)

train_ids = cv_splits[n_val]['train_ids']

batch_iter = BatchIter(train_ids=train_ids, load_x=load_x, load_y=load_y,
                       batch_size=batch_size)

In [None]:
# A new UNet for this fold, trained, then saved as a pickled nn.Module.
model = get_UNet(n_channels=n_channels, n_classes=n_classes)
model_torch = TorchModel(model=model, loss_fn=loss_fn, metric_fn=metric_fn,
                         optim=optim, lr_scheduler=lr_scheduler)

history = model_torch.fit_generator(
    batch_iter.flow(), epochs=epochs, val_data=val_data,
    steps_per_epoch=steps_per_epoch, verbose=True
)

model_filename = os.path.join(VAL_PATH, 'model.pt')
torch.save(model_torch.model, model_filename)

# Presumably writes test predictions, then per-metric results for this fold.
make_predictions(exp_path=EXP_PATH, n_val=n_val, model=model_torch)

calculate_metrics(exp_path=EXP_PATH, n_val=n_val, metrics_dict=metrics_dict)

## Temporary results: mean test Dice score over all CV folds

In [24]:
from saltsegm.utils import load_json

def get_experiment_result(exp_path, n_splits, metric_name):
    """Average a saved test metric over all cross-validation folds.

    For each fold ``i`` in ``range(n_splits)`` this loads
    ``<exp_path>/experiment_<i>/test_metrics/<metric_name>.json``. That file is
    a mapping whose values are metric scores (presumably one per test sample;
    confirm against calculate_metrics). The function takes the mean of those
    values per fold, then returns the mean of the per-fold means.

    Parameters
    ----------
    exp_path : str
        Root experiment directory containing the ``experiment_<i>`` folders.
    n_splits : int
        Number of folds to aggregate (folds ``0 .. n_splits - 1``).
    metric_name : str
        Metric file name without the ``.json`` suffix, e.g. ``'dice_score'``.

    Returns
    -------
    numpy.float64
        Mean over folds of each fold's mean metric value.
    """
    fold_means = []
    for i in range(n_splits):
        # Build the path from the ``exp_path`` argument. The previous version
        # silently read the global EXP_PATH and ignored this parameter.
        metric_path = os.path.join(exp_path, f'experiment_{i}', 'test_metrics',
                                   f'{metric_name}.json')
        results_dict = load_json(metric_path)
        fold_means.append(np.mean(list(results_dict.values())))

    return np.mean(fold_means)

In [25]:
# Cross-validated test Dice: mean over the n_splits folds of each fold's mean.
get_experiment_result(
    exp_path=EXP_PATH,
    n_splits=n_splits,
    metric_name='dice_score',
)

0.7828686297599752