# Miscellaneous imports

In [None]:
import warnings

# Install a blanket "ignore" filter: from here on no warning of any category
# is shown in this kernel session.
# NOTE(review): this also hides deprecation warnings raised by the libraries
# used below -- consider narrowing the filter if something misbehaves silently.
warnings.simplefilter('ignore')

In [None]:
import os
from functools import partial

import numpy as np
import pandas as pd
from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

# Config

In [None]:
import torch
import torch.nn as nn

from saltsegm.utils import ratio2groups, dump_json
from saltsegm.torch_models.unet import get_UNetRes
from saltsegm.metrics import dice_score, main_metric
from saltsegm.dataset import Dataset
from saltsegm.cross_validation import get_cv_111

In [None]:
# Root folders: DATA_PATH is given to Dataset, EXP_PATH to every
# saltsegm.experiment helper used later in the notebook.
DATA_PATH = '/cobrain/groups/ml_group/data/dustpelt/salt_prep/train/'
EXP_PATH = '/cobrain/groups/ml_group/experiments/dustpelt/salt_segm/unet_baseline/'

# Names forwarded to Dataset as its modalities / features / target arguments.
modalities = ['image-128']
features = None
target = 'target-128'

ds = Dataset(
    data_path=DATA_PATH,
    modalities=modalities,
    features=features,
    target=target,
)

# Network geometry passed to get_UNetRes: the input channel count comes from
# the dataset, the number of output classes is fixed to one.
n_channels = ds.n_channels
n_classes = 1

# ================================================================================

# --- cross-validation ---
n_splits = 5    # number of folds; the training loop runs once per fold
val_size = 100  # forwarded to get_cv_111 as `val_size`

# --- training schedule ---
batch_size = 32        # BatchIter batch size
epochs = 200           # fit_generator `epochs`
steps_per_epoch = 100  # fit_generator `steps_per_epoch`

# --- learning-rate policy ---
lr_init = 1e-2   # initial learning rate for Adam
patience = 5     # ReduceLROnPlateau `patience`
lr_factor = 0.5  # ReduceLROnPlateau `factor`

# --- regularisation ---
dropout_rate = 0.10      # passed to get_UNetRes
noise_augm_ratio = 0.05  # NOTE(review): not referenced anywhere else in the visible notebook

# Both objects are factories, not instances: they are handed to TorchModel,
# which presumably completes them with the model parameters / the optimizer
# -- TODO confirm in saltsegm.torch_models.model.
optim = partial(torch.optim.Adam, lr=lr_init)

# NOTE(review): `verbose` is deprecated in recent PyTorch releases -- check it
# against the torch version this project is pinned to.
lr_scheduler = partial(
    torch.optim.lr_scheduler.ReduceLROnPlateau,
    factor=lr_factor,
    patience=patience,
    verbose=True,
)

# ================================================================================

random_state = 42

# Seed the global NumPy and PyTorch generators too. On its own `random_state`
# only reaches get_cv_111, so weight initialisation, dropout and any library
# code drawing from the global RNGs would otherwise differ on every run.
# Best effort only: some GPU kernels are non-deterministic regardless.
np.random.seed(random_state)
torch.manual_seed(random_state)

# The `target_ratio` metadata column is turned into group labels that are
# handed to the splitter (presumably for stratification -- TODO confirm in
# saltsegm.utils.ratio2groups / saltsegm.cross_validation.get_cv_111).
ratio = ds.metadata['target_ratio'].values
groups = ratio2groups(ratio)

# Indexed by fold number; the training loop reads cv_splits[n_val]['train_ids'].
cv_splits = get_cv_111(ds.ids, n_splits=n_splits, val_size=val_size,
                       groups=groups, random_state=random_state)

# ================================================================================


# Sample loaders taken from the dataset object; BatchIter receives them as
# its `load_x` / `load_y` arguments.
load_x, load_y = ds.load_x, ds.load_y

# `metric_fn` goes to TorchModel; `metrics_dict` (metric name -> function)
# goes to calculate_metrics after each fold.
metric_fn = main_metric
metrics_dict = dict(dice_score=dice_score, main_metric=main_metric)

# Binary cross-entropy computed on raw logits.
loss_fn = nn.BCEWithLogitsLoss()

## Build experiment

In [None]:
from saltsegm.experiment import generate_experiment

# Set up the experiment under EXP_PATH from the dataset and the CV splits;
# later cells address it through the same `exp_path`.
generate_experiment(
    exp_path=EXP_PATH,
    cv_splits=cv_splits,
    dataset=ds,
)

# Carry out experiment

In [None]:
from saltsegm.experiment import load_val_data, make_predictions, calculate_metrics
from saltsegm.batch_iter import BatchIter
from saltsegm.torch_models.model import TorchModel

In [None]:
# Train, save and evaluate one model per cross-validation fold.
for n_val in range(n_splits):
    # Artifacts of this fold are written to <EXP_PATH>/experiment_<n_val>.
    fold_dir = os.path.join(EXP_PATH, f'experiment_{n_val}')

    # --- data: validation arrays for the fold + iterator over its train ids ---
    x_val, y_val = load_val_data(exp_path=EXP_PATH, n_val=n_val)

    batch_iter = BatchIter(
        train_ids=cv_splits[n_val]['train_ids'],
        load_x=load_x,
        load_y=load_y,
        batch_size=batch_size,
        verbose_loading=False,
    )

    # --- model: a fresh network per fold; no separate reference to the bare
    # network is kept outside the TorchModel wrapper ---
    model_torch = TorchModel(
        model=get_UNetRes(n_channels=n_channels, n_classes=n_classes,
                          dropout_rate=dropout_rate),
        loss_fn=loss_fn,
        metric_fn=metric_fn,
        optim=optim,
        lr_scheduler=lr_scheduler,
    )

    # --- training ---
    history = model_torch.fit_generator(
        batch_iter.flow(),
        epochs=epochs,
        val_data=(x_val, y_val),
        steps_per_epoch=steps_per_epoch,
        verbose=True,
    )

    # --- persist the training log and the trained network ---
    dump_json(history, os.path.join(fold_dir, 'log.json'))

    # The whole module object is serialised (not just a state_dict).
    # NOTE(review): make_predictions presumably reloads this file -- confirm
    # before changing the on-disk format.
    torch.save(model_torch.model, os.path.join(fold_dir, 'model.pt'))

    # --- predictions and per-fold metrics ---
    make_predictions(exp_path=EXP_PATH, n_val=n_val)
    calculate_metrics(exp_path=EXP_PATH, n_val=n_val, metrics_dict=metrics_dict)

## Results

In [None]:
from saltsegm.experiment import get_experiment_result

In [None]:
# Show the experiment result for the 'dice_score' metric (cell output).
get_experiment_result(
    exp_path=EXP_PATH,
    n_splits=n_splits,
    metric_name='dice_score',
)

In [None]:
# Show the experiment result for the 'main_metric' metric (cell output).
get_experiment_result(
    exp_path=EXP_PATH,
    n_splits=n_splits,
    metric_name='main_metric',
)

<hr/>