# Miscellaneous imports

In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import os
from functools import partial

import numpy as np
import pandas as pd
from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

# Config

In [3]:
import torch
import torch.nn as nn

from saltsegm.utils import ratio2groups, dump_json
from saltsegm.torch_models.unet import get_UNet
from saltsegm.metrics import dice_score, main_metric
from saltsegm.dataset import Dataset
from saltsegm.cross_validation import get_cv_111

In [4]:
# Input data location and the root folder where every fold's artefacts are written.
DATA_PATH = '/cobrain/groups/ml_group/data/dustpelt/salt_prep/train/'
EXP_PATH = '/cobrain/groups/ml_group/experiments/dustpelt/salt_segm/unet_baseline/'

# --- data ------------------------------------------------------------------------
modalities = ['image-128']      # one network input channel per modality (see n_channels)
target = 'target-128'
n_channels = len(modalities)
n_classes = 1

ds = Dataset(data_path=DATA_PATH, modalities=modalities, target=target)

# --- cross-validation and training schedule --------------------------------------
n_splits = 5
val_size = 100

batch_size = 32
epochs = 50
steps_per_epoch = 100

# --- optimisation ----------------------------------------------------------------
lr_init = 1e-2
patience = 5
lr_factor = 0.3

# Factories rather than instances: TorchModel presumably binds them to the model's
# parameters / optimizer itself (TODO confirm in saltsegm.torch_models.model).
optim = partial(torch.optim.Adam, lr=lr_init)
lr_scheduler = partial(
    torch.optim.lr_scheduler.ReduceLROnPlateau,
    factor=lr_factor,
    patience=patience,
    verbose=True,
)

# --- CV splits, grouped by each image's target ratio ----------------------------
random_state = 42

ratio = ds.metadata['target_ratio'].values
groups = ratio2groups(ratio)
cv_splits = get_cv_111(
    ds.ids,
    n_splits=n_splits,
    val_size=val_size,
    groups=groups,
    random_state=random_state,
)

# --- loaders, metrics and loss ---------------------------------------------------
load_x, load_y = ds.load_x, ds.load_y

metric_fn = dice_score
metrics_dict = dict(dice_score=dice_score, main_metric=main_metric)

# BCEWithLogitsLoss applies the sigmoid itself, so the UNet is presumably expected
# to output raw logits.
loss_fn = nn.BCEWithLogitsLoss()

## Build experiment

In [5]:
from saltsegm.experiment import generate_experiment

# Presumably lays out the per-fold `experiment_{i}` folders under EXP_PATH that the
# fold cells below read from (load_val_data) and write to (history/model/preds).
generate_experiment(
    exp_path=EXP_PATH,
    cv_splits=cv_splits,
    dataset=ds,
)

# Carry out the experiment

In [6]:
from saltsegm.experiment import load_val_data, make_predictions, calculate_metrics
from saltsegm.batch_iter import BatchIter
from saltsegm.torch_models.model import TorchModel

### val 0

In [7]:
# Fold index into cv_splits, plus the sub-directory this fold's outputs go to.
n_val = 0
VAL_PATH = os.path.join(EXP_PATH, 'experiment_{}'.format(n_val))

In [None]:
# Held-out arrays for this fold, handed to fit_generator as an (x, y) tuple.
x_val, y_val = load_val_data(exp_path=EXP_PATH, n_val=n_val)
val_data = x_val, y_val

# Training ids of the current fold feed the batch generator.
train_ids = cv_splits[n_val]['train_ids']
batch_iter = BatchIter(
    train_ids=train_ids,
    load_x=load_x,
    load_y=load_y,
    batch_size=batch_size,
)

In [None]:
# Seed the RNGs used by training so this fold is reproducible after a kernel
# restart: torch drives weight init (and dropout, if any); NumPy's global RNG is
# seeded too, assuming BatchIter samples from it (TODO confirm). The CV split
# itself is already fixed by `random_state` in the config cell.
np.random.seed(random_state)
torch.manual_seed(random_state)

model = get_UNet(n_channels=n_channels, n_classes=n_classes)
model_torch = TorchModel(model=model, loss_fn=loss_fn, metric_fn=metric_fn,
                         optim=optim, lr_scheduler=lr_scheduler)

history = model_torch.fit_generator(
    batch_iter.flow(), epochs=epochs, val_data=val_data,
    steps_per_epoch=steps_per_epoch, verbose=True
)

dump_json(history, os.path.join(VAL_PATH, 'history.json'))

# NOTE(review): this pickles the whole nn.Module, so loading it requires the same
# saltsegm code on the import path; `model_torch.model.state_dict()` would be more
# portable, but the format is kept for whatever currently loads model.pt.
model_filename = os.path.join(VAL_PATH, 'model.pt')
torch.save(model_torch.model, model_filename)

make_predictions(exp_path=EXP_PATH, n_val=n_val, model=model_torch)

calculate_metrics(exp_path=EXP_PATH, n_val=n_val, metrics_dict=metrics_dict)

### val 1

In [None]:
# Fold index into cv_splits, plus the sub-directory this fold's outputs go to.
n_val = 1
VAL_PATH = os.path.join(EXP_PATH, 'experiment_{}'.format(n_val))

In [None]:
# Held-out arrays for this fold, handed to fit_generator as an (x, y) tuple.
x_val, y_val = load_val_data(exp_path=EXP_PATH, n_val=n_val)
val_data = x_val, y_val

# Training ids of the current fold feed the batch generator.
train_ids = cv_splits[n_val]['train_ids']
batch_iter = BatchIter(
    train_ids=train_ids,
    load_x=load_x,
    load_y=load_y,
    batch_size=batch_size,
)

In [None]:
# Seed the RNGs used by training so this fold is reproducible after a kernel
# restart: torch drives weight init (and dropout, if any); NumPy's global RNG is
# seeded too, assuming BatchIter samples from it (TODO confirm). The CV split
# itself is already fixed by `random_state` in the config cell.
np.random.seed(random_state)
torch.manual_seed(random_state)

model = get_UNet(n_channels=n_channels, n_classes=n_classes)
model_torch = TorchModel(model=model, loss_fn=loss_fn, metric_fn=metric_fn,
                         optim=optim, lr_scheduler=lr_scheduler)

history = model_torch.fit_generator(
    batch_iter.flow(), epochs=epochs, val_data=val_data,
    steps_per_epoch=steps_per_epoch, verbose=True
)

dump_json(history, os.path.join(VAL_PATH, 'history.json'))

# NOTE(review): this pickles the whole nn.Module, so loading it requires the same
# saltsegm code on the import path; `model_torch.model.state_dict()` would be more
# portable, but the format is kept for whatever currently loads model.pt.
model_filename = os.path.join(VAL_PATH, 'model.pt')
torch.save(model_torch.model, model_filename)

make_predictions(exp_path=EXP_PATH, n_val=n_val, model=model_torch)

calculate_metrics(exp_path=EXP_PATH, n_val=n_val, metrics_dict=metrics_dict)

### val 2

In [None]:
# Fold index into cv_splits, plus the sub-directory this fold's outputs go to.
n_val = 2
VAL_PATH = os.path.join(EXP_PATH, 'experiment_{}'.format(n_val))

In [None]:
# Held-out arrays for this fold, handed to fit_generator as an (x, y) tuple.
x_val, y_val = load_val_data(exp_path=EXP_PATH, n_val=n_val)
val_data = x_val, y_val

# Training ids of the current fold feed the batch generator.
train_ids = cv_splits[n_val]['train_ids']
batch_iter = BatchIter(
    train_ids=train_ids,
    load_x=load_x,
    load_y=load_y,
    batch_size=batch_size,
)

In [None]:
# Seed the RNGs used by training so this fold is reproducible after a kernel
# restart: torch drives weight init (and dropout, if any); NumPy's global RNG is
# seeded too, assuming BatchIter samples from it (TODO confirm). The CV split
# itself is already fixed by `random_state` in the config cell.
np.random.seed(random_state)
torch.manual_seed(random_state)

model = get_UNet(n_channels=n_channels, n_classes=n_classes)
model_torch = TorchModel(model=model, loss_fn=loss_fn, metric_fn=metric_fn,
                         optim=optim, lr_scheduler=lr_scheduler)

history = model_torch.fit_generator(
    batch_iter.flow(), epochs=epochs, val_data=val_data,
    steps_per_epoch=steps_per_epoch, verbose=True
)

dump_json(history, os.path.join(VAL_PATH, 'history.json'))

# NOTE(review): this pickles the whole nn.Module, so loading it requires the same
# saltsegm code on the import path; `model_torch.model.state_dict()` would be more
# portable, but the format is kept for whatever currently loads model.pt.
model_filename = os.path.join(VAL_PATH, 'model.pt')
torch.save(model_torch.model, model_filename)

make_predictions(exp_path=EXP_PATH, n_val=n_val, model=model_torch)

calculate_metrics(exp_path=EXP_PATH, n_val=n_val, metrics_dict=metrics_dict)

### val 3

In [None]:
# Fold index into cv_splits, plus the sub-directory this fold's outputs go to.
n_val = 3
VAL_PATH = os.path.join(EXP_PATH, 'experiment_{}'.format(n_val))

In [None]:
# Held-out arrays for this fold, handed to fit_generator as an (x, y) tuple.
x_val, y_val = load_val_data(exp_path=EXP_PATH, n_val=n_val)
val_data = x_val, y_val

# Training ids of the current fold feed the batch generator.
train_ids = cv_splits[n_val]['train_ids']
batch_iter = BatchIter(
    train_ids=train_ids,
    load_x=load_x,
    load_y=load_y,
    batch_size=batch_size,
)

In [None]:
# Seed the RNGs used by training so this fold is reproducible after a kernel
# restart: torch drives weight init (and dropout, if any); NumPy's global RNG is
# seeded too, assuming BatchIter samples from it (TODO confirm). The CV split
# itself is already fixed by `random_state` in the config cell.
np.random.seed(random_state)
torch.manual_seed(random_state)

model = get_UNet(n_channels=n_channels, n_classes=n_classes)
model_torch = TorchModel(model=model, loss_fn=loss_fn, metric_fn=metric_fn,
                         optim=optim, lr_scheduler=lr_scheduler)

history = model_torch.fit_generator(
    batch_iter.flow(), epochs=epochs, val_data=val_data,
    steps_per_epoch=steps_per_epoch, verbose=True
)

dump_json(history, os.path.join(VAL_PATH, 'history.json'))

# NOTE(review): this pickles the whole nn.Module, so loading it requires the same
# saltsegm code on the import path; `model_torch.model.state_dict()` would be more
# portable, but the format is kept for whatever currently loads model.pt.
model_filename = os.path.join(VAL_PATH, 'model.pt')
torch.save(model_torch.model, model_filename)

make_predictions(exp_path=EXP_PATH, n_val=n_val, model=model_torch)

calculate_metrics(exp_path=EXP_PATH, n_val=n_val, metrics_dict=metrics_dict)

### val 4

In [None]:
# Fold index into cv_splits, plus the sub-directory this fold's outputs go to.
n_val = 4
VAL_PATH = os.path.join(EXP_PATH, 'experiment_{}'.format(n_val))

In [None]:
# Held-out arrays for this fold, handed to fit_generator as an (x, y) tuple.
x_val, y_val = load_val_data(exp_path=EXP_PATH, n_val=n_val)
val_data = x_val, y_val

# Training ids of the current fold feed the batch generator.
train_ids = cv_splits[n_val]['train_ids']
batch_iter = BatchIter(
    train_ids=train_ids,
    load_x=load_x,
    load_y=load_y,
    batch_size=batch_size,
)

In [None]:
# Seed the RNGs used by training so this fold is reproducible after a kernel
# restart: torch drives weight init (and dropout, if any); NumPy's global RNG is
# seeded too, assuming BatchIter samples from it (TODO confirm). The CV split
# itself is already fixed by `random_state` in the config cell.
np.random.seed(random_state)
torch.manual_seed(random_state)

model = get_UNet(n_channels=n_channels, n_classes=n_classes)
model_torch = TorchModel(model=model, loss_fn=loss_fn, metric_fn=metric_fn,
                         optim=optim, lr_scheduler=lr_scheduler)

history = model_torch.fit_generator(
    batch_iter.flow(), epochs=epochs, val_data=val_data,
    steps_per_epoch=steps_per_epoch, verbose=True
)

dump_json(history, os.path.join(VAL_PATH, 'history.json'))

# NOTE(review): this pickles the whole nn.Module, so loading it requires the same
# saltsegm code on the import path; `model_torch.model.state_dict()` would be more
# portable, but the format is kept for whatever currently loads model.pt.
model_filename = os.path.join(VAL_PATH, 'model.pt')
torch.save(model_torch.model, model_filename)

make_predictions(exp_path=EXP_PATH, n_val=n_val, model=model_torch)

calculate_metrics(exp_path=EXP_PATH, n_val=n_val, metrics_dict=metrics_dict)

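# Temporary results

Experiment-level scores over all `n_splits` folds. NOTE: these cells were executed in a separate kernel session (their execution counts restart at `In [5]`).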

In [5]:
from saltsegm.experiment import get_experiment_result

In [6]:
# Dice score for the whole experiment (last expression, so it is displayed).
get_experiment_result(
    exp_path=EXP_PATH,
    n_splits=n_splits,
    metric_name='dice_score',
)

0.7980003825141514

In [7]:
# Main competition metric for the whole experiment (last expression, so it is displayed).
get_experiment_result(
    exp_path=EXP_PATH,
    n_splits=n_splits,
    metric_name='main_metric',
)

0.7275803757190253