# Fine tuning test
Test the possibility of manually annotating a few frames of an experiment, then fine-tuning the network on them to predict the rest of the frames.  
This amounts to overfitting on part of the test set in order to perform well on the rest of it, i.e. a form of domain adaptation.

In [1]:
%matplotlib inline

import os, sys, time, shutil, copy
import random
import ipywidgets as widgets
from ipywidgets import interact

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from skimage import io
from scipy import ndimage as ndi
import cv2
import imgaug.augmenters as iaa

import torch

from utils_common.image import imread_to_float, to_npint, overlay_preds_targets
from utils_common.metrics import dice_coef
from utils_data import normalize_range, get_all_dataloaders, pad_transform, pad_transform_stack, compute_weights
from utils_loss import get_BCEWithLogits_loss
from utils_metric import get_dice_metric
from utils_model import CustomUNet, load_model
from utils_test import predict, predict_stack, evaluate, evaluate_stack

# Seed the Python, NumPy and PyTorch random generators, each with its own
# value derived from `seed`
seed = 1
random.seed(seed)
np.random.seed(1234 + 10 * seed)
torch.manual_seed(4321 + 100 * seed)

# Run on the first CUDA device when there is one, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device:", device)

%load_ext autoreload
%autoreload 2

Device: cuda:0


### Parameters

In [2]:
# Batch size used for prediction/evaluation and as upper bound on the number
# of annotated frames per fine-tuning batch; learning rate of the Adam optimizer
batch_size = 16
learning_rate = 0.0005

# Choose whether or not to use synthetic data, augmentation, and pixel-wise weights for fine tuning
synth_data = False
synth_ratio = None # ratio of synthetic data vs. real data
only_synth = False # If True, will use only the synthetic data (and all of it, as opposed to ratio=1)
data_aug = False # If True, will use data augmentation (see below for augmentation sequence)
use_weights = True # if False use class weights, if True use pixelwise weights (if existing)

input_channels = "RG" # Channels to use as input (one letter of "RGB" per channel)
u_depth = 4 # passed to load_model and to the padding transforms
out1_channels = 16 # passed to load_model

out_model_name = "models/test_ft" # directory where the fine-tuned model is saved
model_name = "models/unet4-16_RG_cv-annotated/" # pre-trained model given to load_model
data_dir = "/data/talabot/pdm/dataset_cv-annotated/" # dataset given to get_all_dataloaders

### Prepare training
Create the dataloaders and compute the class weights needed to prepare the fine-tuning.

In [3]:
# Optional random augmentation, applied to the training inputs only
if not data_aug:
    aug_fn = lambda x: x # identity function
else:
    seq = iaa.GammaContrast((0.7, 1.3)) # Gamma correction
    aug_fn = seq.augment_image


def _pad_to_depth(img):
    """Apply pad_transform to `img` with the notebook's `u_depth`."""
    return pad_transform(img, u_depth)


def _train_input_transform(img):
    """Augment (with the current `aug_fn`), pad, then range-normalize a training input."""
    return normalize_range(_pad_to_depth(aug_fn(img)))


def _eval_input_transform(img):
    """Pad, then range-normalize an evaluation input (no augmentation)."""
    return normalize_range(_pad_to_depth(img))


# Create dataloaders
dataloaders = get_all_dataloaders(
    data_dir,
    batch_size,
    input_channels=input_channels,
    test_dataloader=True,
    use_weights=use_weights,
    synthetic_data=synth_data,
    synthetic_ratio=synth_ratio,
    synthetic_only=only_synth,
    train_transform=_train_input_transform,
    train_target_transform=_pad_to_depth,
    eval_transform=_eval_input_transform,
    eval_target_transform=_pad_to_depth,
)

# Keep the original collate function for later use, and make the train
# dataloader return its batches as plain lists of items instead
train_loader = dataloaders["train"]
collate_fn = train_loader.collate_fn
train_loader.collate_fn = lambda batch: batch

# Compute class weights (as pixel imbalance) over all training targets
pos_count, neg_count = 0, 0
for mask_path in train_loader.dataset.y_filenames:
    mask = io.imread(mask_path)
    pos_count = pos_count + (mask == 255).sum()
    neg_count = neg_count + (mask == 0).sum()
total_count = neg_count + pos_count
pos_weight = torch.tensor(total_count / (2 * pos_count)).to(device)
neg_weight = torch.tensor(total_count / (2 * neg_count)).to(device)

### Load model

In [4]:
# Load the pre-trained network described by the parameters above
model = load_model(
    model_name,
    input_channels=input_channels,
    u_depth=u_depth,
    out1_channels=out1_channels,
    device=device,
)

# Weighted loss and the metric(s) reported during evaluation
loss_fn = get_BCEWithLogits_loss(pos_weight=pos_weight, neg_weight=neg_weight)
metrics = dict(dice=get_dice_metric())

# Create the output folder of the future model and keep a copy of the model code next to it
os.makedirs(out_model_name, exist_ok=True)
shutil.copy("utils_model.py", os.path.join(out_model_name, "utils_model_save.py"))

'models/test_ft/utils_model_save.py'

## Load experiments and annotated frames
Load an experiment, predict its detections once with the pre-trained model, and create annotations for a few frames.

In [5]:
experiment = "/data/talabot/experiments/annotated/R70H06_20181202-tdTomGC6fopt-fl2/R70H06-tdTomGC6fopt-fly2-001/"

# Load the RGB stack; segmentation and pixel weights are optional and are set
# to None when their file is missing (or when weights are not requested)
rgb_stack = imread_to_float(os.path.join(experiment, "RGB.tif"))

seg_path = os.path.join(experiment, "seg_ROI.tif")
seg_stack = imread_to_float(seg_path) if os.path.isfile(seg_path) else None

weights_path = os.path.join(experiment, "weights.tif")
weights_stack = None
if use_weights and os.path.isfile(weights_path):
    weights_stack = imread_to_float(weights_path)

# Run the loaded model over the whole stack, apply a sigmoid to its outputs,
# and report how long this took
start = time.time()
predictions = torch.sigmoid(
    predict_stack(
        model, rgb_stack, batch_size, input_channels=input_channels,
        transform=lambda stack: normalize_range(pad_transform_stack(stack, u_depth)),
    )
)
print("Predicted experiment in %.1f s." % (time.time() - start))

# Score the thresholded predictions when a ground truth is available
if seg_stack is not None:
    print("Dice =", dice_coef((predictions > 0.5).numpy(), seg_stack))

@interact(image=(0, len(rgb_stack) - 1))
def plot_experiment(image=0):
    """Plot frame `image` of the experiment on a 2x3 grid.

    Always shows the raw input and the prediction; the ground truth, its
    overlay with the thresholded prediction, and the pixel weights are only
    shown when the corresponding stack was loaded.
    """
    # Each panel is (grid position, title, image, imshow keyword arguments)
    panels = [(1, "Raw input", rgb_stack[image], {})]
    if seg_stack is not None:
        panels.append((2, "Binary detection", seg_stack[image], {"cmap": "gray"}))
    if weights_stack is not None:
        panels.append((3, "Pixel weighting", weights_stack[image], {"cmap": "gray"}))
    panels.append((5, "Prediction", predictions[image], {"cmap": "gray"}))
    if seg_stack is not None:
        overlay = overlay_preds_targets((predictions[image] > 0.5).numpy(), seg_stack[image])
        panels.append((6, "Overlay with ground truth", overlay, {}))

    fig = plt.figure(figsize=(12, 8))
    for position, title, img, imshow_kwargs in panels:
        ax = fig.add_subplot(2, 3, position)
        ax.set_title(title)
        ax.imshow(img, **imshow_kwargs)
    fig.tight_layout()
    plt.show()

Predicted experiment in 2.3 s.
Dice = 0.19969367975092164


interactive(children=(IntSlider(value=0, description='image', max=567), Output()), _dom_classes=('widget-inter…

Use the ground truth as annotations to test how many frames are needed, and how to fine tune.

In [6]:
# Select and annotate frames
n_annotations = 4 # number of frames to annotate
n_valid = 1 # number of annotated frames to use for validation

# The ground truth stands in for manual annotations, so it must have been loaded,
# and at least one annotated frame must remain for training
if seg_stack is None:
    raise ValueError("No ground truth (seg_ROI.tif) was loaded for this experiment: "
                     "cannot use it as annotations.")
if not 0 <= n_valid < n_annotations <= len(rgb_stack):
    raise ValueError("Need 0 <= n_valid < n_annotations <= number of frames, got "
                     "n_valid={}, n_annotations={}, {} frames.".format(
                         n_valid, n_annotations, len(rgb_stack)))

# Randomly choose frames
indices_annotated = np.random.choice(np.arange(len(rgb_stack)), size=n_annotations, replace=False)
print("Indices of annotated frames:", indices_annotated, sep="\n")
rgb_annotated = np.stack([rgb_stack[idx] for idx in indices_annotated])

# Take ground truths as annotation
seg_annotated = np.stack([seg_stack[idx] for idx in indices_annotated])

# Take the pixel-wise weights of the annotated frames, if they are used
# (alternative left by the author: weights_annotated = compute_weights(seg_annotated))
if use_weights:
    if weights_stack is None:
        raise ValueError("use_weights is True but no pixel-wise weights (weights.tif) "
                         "were loaded for this experiment.")
    weights_annotated = np.stack([weights_stack[idx] for idx in indices_annotated])
else:
    weights_annotated = None

# Compute class weights (as pixel imbalance) on the frames used for training only,
# i.e. the first n_annotations - n_valid annotated frames
seg_train = seg_annotated[:n_annotations - n_valid]
pos_count = (seg_train == 1).sum()
neg_count = (seg_train == 0).sum()
if pos_count == 0 or neg_count == 0:
    # A missing class would give an infinite weight for it
    raise ValueError("The training annotations contain {} positive and {} negative pixels: "
                     "both classes are needed to compute class weights.".format(
                         pos_count, neg_count))
pos_weight = torch.tensor((neg_count + pos_count) / (2 * pos_count)).to(device)
neg_weight = torch.tensor((neg_count + pos_count) / (2 * neg_count)).to(device)
print("{:.6f} positive and {:.6f} negative weighting.".format(pos_weight.item(), neg_weight.item()))

Indices of annotated frames:
[436 289 480 135]
96.120148 positive and 0.502614 negative weighting.


In [7]:
# The first n_train annotated frames are used for training, the remaining n_valid for validation
n_train = n_annotations - n_valid
if n_train <= 0:
    raise ValueError("Need at least one annotated frame for training (n_annotations > n_valid).")
annotated_per_batch = min(n_train, batch_size) # number of annotated frames in each batch
n_iter = 200

# Checkpoint of the model with the best validation dice
best_model_path = os.path.join(out_model_name, "model_best.pth")


def _eval_stack_transform(stack):
    """Pad, then range-normalize a stack of frames before evaluation."""
    return normalize_range(pad_transform_stack(stack, u_depth))


def _stack_dice(rgb, seg):
    """Return the "dice" metric of `model_ft` (as reported by evaluate_stack) on the given stacks."""
    return evaluate_stack(model_ft, rgb, seg, batch_size, metrics=metrics,
                          input_channels=input_channels,
                          transform=_eval_stack_transform)["dice"]


# Fine tune a copy of the model, so that the original stays available for comparison
model_ft = copy.deepcopy(model)
loss_fn = get_BCEWithLogits_loss(pos_weight=pos_weight, neg_weight=neg_weight)
optimizer = torch.optim.Adam(model_ft.parameters(), lr=learning_rate)

# Set model to training mode
# NOTE(review): evaluate_stack is called inside the loop; if it switches the model
# to eval mode without restoring it, later iterations do not run in training mode
# — confirm in utils_test.
model_ft.train()

# Iterate over the data
# (batches are built from the annotated frames only; mixing them with batches
# of the train dataloader is currently disabled)
print("Iteration (over %d):" % n_iter)
best_iter, best_dice = -1, 0
for i in range(n_iter):
    ## Build the batch from the annotations
    # Randomly select elements
    rand_idx = np.random.choice(np.arange(n_train), size=annotated_per_batch, replace=False)
    # Keep only relevant input channels
    channel_imgs = {"R": rgb_annotated[rand_idx,:,:,0],
                    "G": rgb_annotated[rand_idx,:,:,1],
                    "B": rgb_annotated[rand_idx,:,:,2]}
    images = np.stack([channel_imgs[channel] for channel in input_channels], axis=1)
    # Apply train transforms
    images = [normalize_range(pad_transform(aug_fn(image), u_depth)) for image in images]
    targets = pad_transform_stack(seg_annotated[rand_idx], u_depth)
    if use_weights:
        weights = pad_transform_stack(weights_annotated[rand_idx], u_depth)
        items_annotated = list(zip(images, targets, weights))
    else:
        items_annotated = list(zip(images, targets))

    # Collate items into batch tensors and send them to the device of the trained model
    batch = collate_fn(items_annotated)
    batch_x = batch[0].to(model_ft.device)
    batch_y = batch[1].to(model_ft.device)
    if use_weights: # pixel-wise weights
        batch_w = batch[2].to(model_ft.device)
    else:
        batch_w = None

    # Forward pass
    y_pred = model_ft(batch_x)

    # Loss
    loss = loss_fn(y_pred, batch_y, batch_w)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Validate on the held-out annotated frames, and checkpoint on improvement
    if n_valid > 0:
        valid_dice = _stack_dice(rgb_annotated[n_train:], seg_annotated[n_train:])
        if best_dice < valid_dice:
            best_iter = i
            best_dice = valid_dice
            torch.save(model_ft.state_dict(), best_model_path)
    else:
        valid_dice = 0.0

    # Report progress ten times over the whole fine tuning
    if n_iter >= 10 and (i + 1) % (n_iter // 10) == 0:
        print("{}: dice_annotated = {:.6f} - dice_valid = {:.6f} - dice_full = {:.6f}".format(
            i + 1,
            _stack_dice(rgb_annotated[:n_train], seg_annotated[:n_train]),
            valid_dice,
            _stack_dice(rgb_stack, seg_stack)))

# Load best model found. Only do so if a checkpoint was written during THIS run:
# otherwise the file may be missing, or be a stale one from a previous run.
# (best_iter is the 0-based loop index, whereas the progress lines are 1-based.)
if n_valid > 0:
    if best_iter >= 0:
        print("Best model fine tuned in iteration %d." % best_iter)
        model_ft.load_state_dict(torch.load(best_model_path))
    else:
        print("Validation dice never exceeded 0: keeping the model of the last iteration.")

Iteration (over 200):
20: dice_annotated = 0.846047 - dice_valid = 0.846229 - dice_full = 0.851170
40: dice_annotated = 0.871750 - dice_valid = 0.851064 - dice_full = 0.863276
60: dice_annotated = 0.895420 - dice_valid = 0.882834 - dice_full = 0.877214
80: dice_annotated = 0.910595 - dice_valid = 0.889088 - dice_full = 0.887204
100: dice_annotated = 0.924317 - dice_valid = 0.897297 - dice_full = 0.893143
120: dice_annotated = 0.941968 - dice_valid = 0.901357 - dice_full = 0.897816
140: dice_annotated = 0.953773 - dice_valid = 0.904632 - dice_full = 0.900697
160: dice_annotated = 0.964700 - dice_valid = 0.906278 - dice_full = 0.904454
180: dice_annotated = 0.975496 - dice_valid = 0.916895 - dice_full = 0.909051
200: dice_annotated = 0.980084 - dice_valid = 0.922374 - dice_full = 0.912047
Best model fine tuned in iteration 193.


In [8]:
@interact(i=(0, len(batch_x) - 1))
def plot_batch(i=0):
    """Display element `i` of the last fine-tuning batch.

    Shows the input, its target and, when `use_weights` is True, its
    pixel-wise weights, side by side.
    """
    # Build a 3-channel image for display from the first two input planes
    # (the second plane is reused as third channel).
    # NOTE(review): the "(x + 1) / 2" rescaling assumes normalize_range outputs
    # values in [-1, 1] — confirm.
    planes = batch_x[i].cpu().numpy()
    rgb = (np.stack([planes[0], planes[1], planes[1]], axis=-1) + 1) / 2

    plt.figure(figsize=(12,4))
    plt.subplot(131)
    plt.imshow(rgb)
    plt.subplot(132)
    plt.imshow(batch_y[i].cpu().numpy())
    if use_weights:
        plt.subplot(133)
        plt.imshow(batch_w[i].cpu().numpy())
    plt.tight_layout()
    # plt.show must be called (the original only referenced it) so that the
    # figure is rendered inside the widget output on every slider change
    plt.show()

interactive(children=(IntSlider(value=0, description='i', max=2), Output()), _dom_classes=('widget-interact',)…

In [9]:
# Predict again, and compare results
start = time.time()
predictions_ft = predict_stack(model_ft, rgb_stack, batch_size, input_channels=input_channels,
                               transform=lambda stack: normalize_range(pad_transform_stack(stack, u_depth)))
predictions_ft = torch.sigmoid(predictions_ft)
print("Predicted experiment in %.1f s." % (time.time() - start))

if seg_stack is not None:
    print("Dice    =", dice_coef((predictions > 0.5).numpy(), seg_stack))
    print("Dice_ft =", dice_coef((predictions_ft > 0.5).numpy(), seg_stack))

@interact(image=(0, len(rgb_stack) - 1))
def plot_experiment(image=0):
    """Compare original and fine-tuned predictions for frame `image` on a 2x3 grid.

    Top row: raw input, original prediction, its overlay with the ground truth.
    Bottom row: ground truth, fine-tuned prediction, its overlay with the ground truth.
    Panels that need the ground truth are skipped when `seg_stack` is None.
    """
    has_seg = seg_stack is not None
    fig = plt.figure(figsize=(12, 8))

    def show(position, title, img, **imshow_kwargs):
        # Draw `img` with `title` at the given position of the 2x3 grid
        ax = fig.add_subplot(2, 3, position)
        ax.set_title(title)
        ax.imshow(img, **imshow_kwargs)

    show(1, "Raw input", rgb_stack[image])
    if has_seg:
        show(4, "Binary detection", seg_stack[image], cmap="gray")
    show(2, "Prediction", predictions[image], cmap="gray")
    if has_seg:
        show(3, "Overlay with ground truth",
             overlay_preds_targets((predictions[image] > 0.5).numpy(), seg_stack[image]))
    show(5, "Fine tuned prediction", predictions_ft[image], cmap="gray")
    if has_seg:
        show(6, "Overlay with ground truth",
             overlay_preds_targets((predictions_ft[image] > 0.5).numpy(), seg_stack[image]))
    fig.tight_layout()
    plt.show()

Predicted experiment in 2.2 s.
Dice    = 0.19969367975092164
Dice_ft = 0.9111517111723493


interactive(children=(IntSlider(value=0, description='image', max=567), Output()), _dom_classes=('widget-inter…

## Test drawing for manual annotations
Test drawing by mouse.