### Imports

In [1]:
import numpy as np
import os
os.environ["OMP_NUM_THREADS"]="1"
import random
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import toml
from time import time

from src.data_utils import color_augmentations
from src.embedding_loss import SpatialEmbLoss
from src.unet import UNet
from src.spatial_augmenter import SpatialAugmenter
from src.train_utils import instance_seg_train_step, instance_seg_validation, save_snapshot, save_model
from src.data_utils import H5pyDataset

### Helper functions

In [2]:
def validate(model, validation_dataloader, inst_loss_fn, device, step, writer, validation_loss, loss, inst_model):
    """Run one validation pass and persist the model if it is the best so far.

    The value returned by ``instance_seg_validation`` is appended to
    ``validation_loss`` (the list is mutated in place). When that value is
    less than or equal to the minimum of the updated list, the model is saved
    under ``<log_dir>/best_model``.

    Note: ``optimizer`` and ``log_dir`` are not parameters; they are read from
    the notebook's global namespace and must be defined before calling.

    Args:
        model: network forwarded to ``instance_seg_validation`` and ``save_model``.
        validation_dataloader: data source forwarded to ``instance_seg_validation``.
        inst_loss_fn: loss function forwarded to ``instance_seg_validation``.
        device: device forwarded to ``instance_seg_validation``.
        step: current training step, forwarded to validation and ``save_model``.
        writer: summary writer forwarded to ``instance_seg_validation``.
        validation_loss: list of previous validation results; extended in place.
        loss: loss value passed through to ``save_model``.
        inst_model: instance segmentation mode forwarded to ``instance_seg_validation``.
    """
    print('Validate')
    current_loss = instance_seg_validation(
        model, validation_dataloader, inst_loss_fn, device, step, writer, inst_model)
    validation_loss.append(current_loss)

    # The new value is already in the list, so a tie with the minimum counts as "best".
    is_best = current_loss <= np.min(validation_loss)
    if not is_best:
        return

    print('Save best model')
    best_model_path = os.path.join(log_dir, "best_model")
    save_model(step, model, optimizer, loss, best_model_path)

def create_snapshot(img_caug, gt_inst, params, pred_inst, inst_loss_fn, step):
    """Collect inputs and predictions as numpy arrays and write a snapshot.

    Always stores the augmented image (``img_caug[0]``) and the first slice of
    the ground truth. Depending on ``params['instance_seg']`` it additionally
    stores:

    * ``'embedding'``: the spatial embedding (first two prediction channels
      plus the coordinate grid ``inst_loss_fn.xym``), the sigma channels and
      the sigmoid of the seed channel;
    * ``'cpv_3c'``: the first two prediction channels and the softmax over
      the remaining channels.

    Only the first element of ``pred_inst`` along dim 0 is used. The result is
    handed to ``save_snapshot`` together with the notebook-global ``snap_dir``.
    """
    print('Save snapshot')

    def as_numpy(tensor):
        # Move to host memory and drop the autograd graph before converting.
        return tensor.cpu().detach().numpy()

    snapshot = {
        'img_caug': as_numpy(img_caug[0].squeeze(0)),
        'gt_inst': as_numpy(gt_inst.squeeze(0))[:1],
    }

    seg_mode = params['instance_seg']
    if seg_mode == 'embedding':
        _, _, height, width = pred_inst.shape
        n_sigma = inst_loss_fn.n_sigma
        seed_start = 2 + n_sigma

        coord_grid = inst_loss_fn.xym[:, 0:height, 0:width].contiguous()
        embedding = pred_inst[0, 0:2] + coord_grid  # 2 x h x w
        sigma_map = pred_inst[0, 2:seed_start]  # n_sigma x h x w
        seed_prob = torch.sigmoid(pred_inst[0, seed_start:seed_start + 1])  # 1 x h x w

        snapshot['embedding'] = as_numpy(embedding)
        snapshot['sigma'] = as_numpy(sigma_map)
        snapshot['seed_map'] = as_numpy(seed_prob)
    elif seg_mode == 'cpv_3c':
        snapshot['pred_cpv'] = as_numpy(pred_inst[0, :2])
        snapshot['pred_3c'] = as_numpy(pred_inst[0, 2:].softmax(0))

    save_snapshot(snap_dir, snapshot, step)


def supervised_training(params, model, labeled_dataloader, validation_dataloader, fast_aug, color_aug_fn, inst_loss_fn, writer, device):
    """Train ``model`` on labeled batches with periodic validation, snapshots and checkpoints.

    Steps are numbered from 0 up to and including ``params['training_steps']``;
    training stops as soon as that last step has been processed, even in the
    middle of a pass over ``labeled_dataloader``.

    Steps whose loss is NaN or infinite are skipped entirely: no backward
    pass, no optimizer update, and no validation/snapshot/checkpoint for
    that step number.

    Note: ``optimizer`` and ``log_dir`` are read from the notebook's global
    namespace and must be defined before calling.

    Args:
        params: configuration dict; reads ``training_steps``,
            ``validation_step``, ``snapshot_step``, ``checkpoint_step`` and
            ``instance_seg``.
        model: network to train.
        labeled_dataloader: iterable yielding ``(raw, gt)`` batches.
        validation_dataloader: data source passed to ``validate``.
        fast_aug: spatial augmenter passed to ``instance_seg_train_step``.
        color_aug_fn: color augmentation passed to ``instance_seg_train_step``.
        inst_loss_fn: loss function used for training, validation and snapshots.
        writer: summary writer passed to the train and validation steps.
        device: device passed to the train and validation steps.

    Raises:
        ValueError: if ``labeled_dataloader`` yields no batches (the loop
            could otherwise never make progress).
    """
    last_step = params['training_steps']
    # Start at -1 so that the first executed step is numbered 0.
    step = -1
    validation_loss = []

    # training loop: one iteration of the outer loop is one pass over the loader
    while step < last_step:
        batches_this_pass = 0

        for raw, gt in labeled_dataloader:
            # Stop mid-pass once the final step has run; otherwise training
            # would overshoot `training_steps` by up to a whole pass.
            if step >= last_step:
                break
            batches_this_pass += 1
            step += 1
            optimizer.zero_grad()

            # training step
            print("Started training step %d" % step)
            loss, pred_inst, img_caug, gt_inst = instance_seg_train_step(model,
                                                                         raw,
                                                                         gt,
                                                                         fast_aug,
                                                                         color_aug_fn,
                                                                         inst_loss_fn,
                                                                         writer,
                                                                         device,
                                                                         step)
            # isfinite is False for NaN as well as +/-inf, so one check covers both.
            if not torch.isfinite(loss):
                print("Skipping training step %d: loss is not finite" % step)
                continue
            loss.backward()
            print("backward loss done")
            optimizer.step()
            print("Finished training step %d" % step)

            # validation
            if step % params['validation_step'] == 0:
                validate(model,
                         validation_dataloader,
                         inst_loss_fn,
                         device,
                         step,
                         writer,
                         validation_loss,
                         loss,
                         inst_model=params['instance_seg']
                         )

            # Create snapshot
            if step % params['snapshot_step'] == 0:
                create_snapshot(img_caug, gt_inst, params, pred_inst, inst_loss_fn, step)

            # Create checkpoint
            if step % params['checkpoint_step'] == 0:
                save_model(step, model, optimizer, loss, os.path.join(log_dir,"checkpoint_step_"+str(step)))

        if batches_this_pass == 0:
            # An empty loader would make the outer loop spin forever.
            raise ValueError("labeled_dataloader yielded no batches; cannot train")

In [3]:
# Scratch check of squeeze semantics: dim 0 has size 2, so squeezing it is a
# no-op and the displayed shape stays (2, 1, 2, 1, 2).
x = torch.zeros((2, 1, 2, 1, 2))
y = x.squeeze(dim=0)
y.size()

torch.Size([2, 1, 2, 1, 2])

### Set parameters

In [3]:
torch.backends.cudnn.benchmark = True

params = {
    'data_path': '/home/julia/AG Kainmüller/Alzheimer_Segmentation_Annotation/tau_nerve_segmentation/data/data/cells',
    # directory name of report
    'experiment' : 'instance_seg_cell_bodies',
    'batch_size': 1,
    'training_steps': 10, #:400000,
    'in_channels': 3,
    'num_fmaps': 32,
    'fmap_inc_factors': 2,
    'downsample_factors': [ [ 2, 2,], [ 2, 2,], [ 2, 2,], [ 2, 2,],],
    'num_fmaps_out': 5,
    'constant_upsample': False,
    'padding': 'same',
    'activation': 'ReLU',
    'weight_decay': 1e-5,
    'learning_rate': 3e-4,
    'seed': 42,
    'num_validation': 15,
    'checkpoint_path': None, # 'exp_0_dsb/best_model',
    'pretrained_model': False,
    'multi_head': False,
    'uniform_class_sampling': False,
    'optimizer': 'AdamW', # one of SGD AdamW AdaBound , Adahessian breaks memory and is not supported
    'validation_step' : 10, #500,
    'snapshot_step' : 5, #5000,
    'checkpoint_step': 5, #20000,
    'instance_seg': 'embedding', # 'embedding' or 'cpv_3c'
    'attention': True,
    'color_augmentation_s': 0.4,
    'to_center': True
    }

# Seed every RNG this notebook imports from the single configured seed, so
# the value in `params` is the one source of truth.
random.seed(params['seed'])
np.random.seed(params['seed'])
torch.manual_seed(params['seed'])

# augmentation parameters (currently empty: no settings are passed to the augmenter)
aug_params_fast = {}
# Example settings kept for reference:
#     'mirror': {'prob_x': 0.5, 'prob_y': 0.5, 'prob': 0.5},
#     'translate': {'max_percent':0.05, 'prob': 0.2},
#     'scale': {'min': 0.8, 'max':1.2, 'prob': 0.2},
#     'zoom': {'min': 0.8, 'max':1.2, 'prob': 0.2},
#     'rotate': {'max_degree': 179, 'prob': 0.75},
#     'shear': {'max_percent': 0.1, 'prob': 0.2},
#     'elastic': {'alpha': [120,120], 'sigma': 8, 'prob': 0.5}
print("parameters set")

# set directories for report
log_dir = os.path.join(params['experiment'],'train')
snap_dir = os.path.join(log_dir,'snaps')
os.makedirs(snap_dir, exist_ok=True)
writer_dir = os.path.join(log_dir,'summary', str(time()))
os.makedirs(writer_dir, exist_ok=True)
writer = SummaryWriter(writer_dir)
print("directories set")

# set device
# Prefer the second GPU when there is one; fall back to the first GPU on
# single-GPU machines (where "cuda:1" would be an invalid device) and to the
# CPU when CUDA is unavailable.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device("cuda:%d" % gpu_index)
else:
    device = torch.device("cpu")
print(device)
params['device'] = device
params['aug_params_fast'] = aug_params_fast

# A torch.device is not a TOML value, so the report stores its string form;
# `params['device']` itself stays a torch.device for the rest of the notebook.
with open(os.path.join(params['experiment'], 'params.toml'), 'w', encoding='utf-8') as f:
    toml.dump({**params, 'device': str(device)}, f)

parameters set


2023-01-19 16:28:47.304692: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-01-19 16:28:49.203192: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/julia/.local/share/virtualenvs/tau_nerve_segmentation-s0uDq-Et/lib/python3.8/site-packages/cv2/../../lib64:
2023-01-19 16:28:49.203281: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/julia/.loc

directories set
cpu


### Load data

In [4]:
# os.listdir returns entries in arbitrary order, so sort before the seeded
# shuffle; otherwise the train/validation split can differ between machines
# or runs even with the same seed.
file_paths = [os.path.join(params['data_path'], file)
              for file in sorted(os.listdir(params['data_path']))]
random.Random(params['seed']).shuffle(file_paths)

num_validation = params['num_validation']
if len(file_paths) <= num_validation:
    raise ValueError(
        "Found %d file(s) in %r but need more than num_validation=%d "
        "to leave at least one training file"
        % (len(file_paths), params['data_path'], num_validation))

# The first `num_validation` shuffled files are held out for validation.
val_paths = file_paths[:num_validation]
labeled_paths = file_paths[num_validation:]

crop_size = (1200, 1312)
labeled_dataset = H5pyDataset(labeled_paths, raw_keys=['raw'], label_keys=['gt_instances'], crop_size=crop_size)
validation_dataset = H5pyDataset(val_paths, raw_keys=['raw'], label_keys=['gt_instances'], crop_size=crop_size)

labeled_dataloader = DataLoader(labeled_dataset,
                        batch_size=params['batch_size'],
                        shuffle=True,
                        #prefetch_factor=4,
                        num_workers=0)

validation_dataloader = DataLoader(validation_dataset,
                    batch_size=1,
                    shuffle=True,
                   # prefetch_factor=4,
                    num_workers=0)

### Select model

In [5]:
# Build the U-Net from the configured architecture parameters and move it to
# the selected device.
model = UNet(in_channels = params['in_channels'],
                num_fmaps = params['num_fmaps'],
                fmap_inc_factor = params['fmap_inc_factors'],
                downsample_factors = params['downsample_factors'],
                activation = params['activation'],
                padding = params['padding'],
                num_fmaps_out = params['num_fmaps_out'],
                constant_upsample = params['constant_upsample'],
            ).to(params['device'])

# Optionally resume from a checkpoint; a missing key, None or '' means "start fresh".
checkpoint_path = params.get('checkpoint_path')
if checkpoint_path:
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on one GPU can be loaded on CPU or a different GPU.
    checkpoint = torch.load(checkpoint_path, map_location=params['device'])
    model.load_state_dict(checkpoint['model_state_dict'])

model = model.train()

### Initialize optimizer and augmentations

In [6]:
# Optimizer
if params['optimizer'] == 'SGD':
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=params['learning_rate'],
                                momentum=0.9,
                                weight_decay=params['weight_decay'],
                                nesterov=True)
elif params['optimizer'] == 'AdamW':
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=params['learning_rate'],
                                  weight_decay=params['weight_decay'])
else:
    # Fail here with a clear message instead of a NameError on `optimizer`
    # further down.
    raise ValueError(
        "Unsupported optimizer %r; this notebook implements 'SGD' and 'AdamW'"
        % (params['optimizer'],))

# Cosine decay of the learning rate towards eta_min over `training_steps` steps.
# NOTE(review): the schedule only takes effect if lr_scheduler.step() is
# called during training — confirm that the training loop does so.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=params['training_steps'], eta_min=5e-6)

# Augmentation
fast_aug = SpatialAugmenter(aug_params_fast)#, padding_mode='reflection')
color_aug_fn = color_augmentations(100, s=params['color_augmentation_s'])

# set loss function
# NOTE(review): H and W equal the crop size used when building the datasets —
# presumably they must stay in sync; confirm against SpatialEmbLoss.
inst_loss_fn = SpatialEmbLoss(n_sigma=1, to_center=params['to_center'], foreground_weight=10, H=1200, W=1312).to(device)

Created spatial emb loss function with: to_center: True, n_sigma: 1, foreground_weight: 10


### Train

In [None]:
# Start the supervised training run with the objects built in the cells above.
supervised_training(
    params=params,
    model=model,
    labeled_dataloader=labeled_dataloader,
    validation_dataloader=validation_dataloader,
    fast_aug=fast_aug,
    color_aug_fn=color_aug_fn,
    inst_loss_fn=inst_loss_fn,
    writer=writer,
    device=device,
)

Started training step 0
step started
finished aug
loss 2.4153785705566406
backward loss done
Finished training step 0
Validate
Validation loss:  2.3075337409973145
Save best model
Save model
Save snapshot
Save training snapshot
Save model
Started training step 1
step started
finished aug
loss 2.3503928184509277
backward loss done
Finished training step 1
Started training step 2
step started
finished aug
loss 2.431063175201416
backward loss done
Finished training step 2
Started training step 3
step started
finished aug
