***
# Retraining: transfer learning

This part on retraining CellPose is copied from a previous course: https://github.com/gletort/NeubiasPasteur2023_AdvancedCellPose/blob/main/CellPose.ipynb

When only a small training dataset is available, *pretraining* (on somewhat similar images and a related task) helps initialize the network.
The pretrained network can then be *fine-tuned* on a new dataset close to the original one by *retraining* it.

Should all the layers be retrained?

For classification networks (https://www.sciencedirect.com/science/article/abs/pii/S0010482520304467), the recommendations are:

- Many images: retrain all layers.
- Few images, similar to the original training images: high risk of overfitting. Instead, retrain only the last layer: keep the features learned during the original training and update only the last layer.
- Few images, different from the original ones: retrain an intermediate number of layers (not all of them). The very first layers of a CNN learn very general features (largely independent of the images and of the task), so they can stay frozen, while retraining too many layers risks overfitting.
- Training from scratch with few images: pretrain on a different task using the same images. For example, for segmentation, first train an auto-encoder to reconstruct the images, then retrain it for segmentation.
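
The "retrain only the last layer" strategy can be illustrated with a toy NumPy sketch. It uses hypothetical data and weights and is not CellPose code. `W1` plays the role of the pretrained feature layers and is frozen, mirroring `param.requires_grad = False` in PyTorch. Gradient descent on a mean squared error updates only the last layer `W2`.

```python
import numpy as np

# Toy sketch, NOT CellPose code: a 2-layer network where W1 stands in for
# pretrained feature layers (frozen) and W2 for the last layer (retrained).
rng = np.random.default_rng(0)
params = {
    "W1": rng.normal(size=(4, 8)),  # hypothetical "pretrained" weights
    "W2": rng.normal(size=(8, 1)),  # last layer, to be fine-tuned
}
frozen = {"W1": True, "W2": False}  # analogous to requires_grad = False in PyTorch

# Hypothetical small "new" dataset (random numbers, for illustration only)
X = rng.normal(size=(16, 4))
y = rng.normal(size=(16, 1))

def forward(params, X):
    h = np.maximum(X @ params["W1"], 0.0)  # ReLU features from the frozen layer
    return h, h @ params["W2"]

def mse(pred, y):
    return float(np.mean((pred - y) ** 2))

W1_before = params["W1"].copy()
loss_before = mse(forward(params, X)[1], y)

lr = 0.005
for _ in range(300):
    h, pred = forward(params, X)
    grad_pred = 2 * (pred - y) / len(X)              # d(MSE)/d(pred)
    grad_h = (grad_pred @ params["W2"].T) * (h > 0)  # backprop through the ReLU
    grads = {"W1": X.T @ grad_h, "W2": h.T @ grad_pred}
    for name, g in grads.items():
        if not frozen[name]:                          # frozen weights are never updated
            params[name] -= lr * g

loss_after = mse(forward(params, X)[1], y)
print(f"loss before: {loss_before:.3f}, after: {loss_after:.3f}")
```

Only the 8 weights of `W2` are fitted to the 16 new samples. Keeping the number of fitted parameters small is the point of freezing when data are scarce.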


However, CellPose models have a U-Net-like architecture in which skip connections link the deep layers to the shallow ones. This paper: https://arxiv.org/abs/2110.02196 suggests that freezing only the "bottleneck" block (the bottom of the U) gives better retraining results.
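
Choosing which layers to freeze can be done by matching parameter names against module prefixes. The sketch below uses a hypothetical helper `frozen_mask` and hypothetical parameter names written in the style printed by PyTorch. It also assumes that the bottleneck corresponds to the deepest encoder block, `downsample.down.res_down_3`. Check the names of your own model with `print(model.net)` before relying on these prefixes.

```python
def frozen_mask(param_names, freeze_prefixes):
    """Map each parameter name to True if it lies in one of the modules to freeze."""
    return {
        # Match the module itself or its children ("res_down_3." but not "res_down_30")
        name: any(name == p or name.startswith(p + ".") for p in freeze_prefixes)
        for name in param_names
    }

# Hypothetical parameter names in the style of `model.net.named_parameters()`
names = [
    "downsample.down.res_down_0.proj.1.weight",
    "downsample.down.res_down_3.conv.conv_0.1.weight",
    "downsample.down.res_down_3.proj.1.weight",
    "upsample.up.res_up_3.conv.conv_0.1.weight",
    "output.2.weight",
]
# Assumption: the bottleneck is the deepest encoder block, "res_down_3"
mask = frozen_mask(names, ["downsample.down.res_down_3"])
for name, is_frozen in mask.items():
    print("frozen   " if is_frozen else "trainable", name)
```

With PyTorch, first build the mask from `[n for n, _ in model.net.named_parameters()]`. Then apply it with `for name, p in model.net.named_parameters(): p.requires_grad = not mask[name]`.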

## Testing CellPose retraining on our test dataset

We test on oocyte segmentation and compare four settings: no retraining, retraining from scratch, retraining from a pretrained network, and retraining from a pretrained network with some layers frozen.

In [None]:
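## Imports
import os, shutil, random, time
import numpy as np
import tifffile
import matplotlib.pyplot as plt
from glob import glob
import tempfile
import napari
import cv2                                  # image reading (cv2.imread)
from skimage import io                      # mask reading (io.imread)
from memory_profiler import memory_usage    # peak memory measurement during training

## The "classic" CellPose version (version 3)
import sys
sys.path.insert(1,os.path.abspath('src/cellpose3/'))
import src.cellpose3.cellpose as cellpose3
import src.cellpose3.cellpose.models as cp3_models
import src.cellpose3.cellpose.dynamics as cp3_dyn
import src.cellpose3.cellpose.metrics as metrics   # average_precision, used to evaluate the models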

In [None]:
## Load the data for retraining (datadir is assumed to be defined earlier in the notebook)
## Images are read as 2D grayscale arrays; masks are 0/255 PNGs, divided by 255 so that
## all foreground pixels get the label 1

ootrainimg_files = sorted(glob(datadir+"/dataOocytes/clin2/input/*.png"))
ootrainmask_files = [ datadir+"/dataOocytes/clin2/mask/"+os.path.basename(filepath) for filepath in ootrainimg_files]
ootrain_img = [cv2.imread(fimg, cv2.IMREAD_GRAYSCALE) for fimg in ootrainimg_files]
ootrain_mask = [np.uint8(io.imread(fimg)/255) for fimg in ootrainmask_files]

ootestimg_files = sorted(glob(datadir+"/dataOocytes/clin2_test/input/*.png"))
ootestmask_files = [ datadir+"/dataOocytes/clin2_test/mask/"+os.path.basename(filepath) for filepath in ootestimg_files]
ootest_img = [cv2.imread(fimg, cv2.IMREAD_GRAYSCALE) for fimg in ootestimg_files]
ootest_mask = [np.uint8(io.imread(fimg)/255) for fimg in ootestmask_files]
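
Dividing the 0/255 masks by 255 gives every foreground pixel the label 1. This is fine when each image shows a single oocyte. However, CellPose reads each distinct integer in a label image as a separate object, so an image with several separate oocytes needs distinct labels.

The sketch below uses a hypothetical helper, `label_binary_mask`, and assumes binary masks in which objects do not touch (touching objects cannot be separated this way). It gives each 4-connected foreground region its own integer label.

```python
from collections import deque
import numpy as np

def label_binary_mask(binary):
    """Give each 4-connected foreground region of a binary mask its own integer label."""
    binary = np.asarray(binary) > 0
    labels = np.zeros(binary.shape, dtype=np.int32)
    n_rows, n_cols = binary.shape
    current = 0
    for r in range(n_rows):
        for c in range(n_cols):
            if binary[r, c] and labels[r, c] == 0:
                current += 1                    # new object found: flood-fill it
                labels[r, c] = current
                queue = deque([(r, c)])
                while queue:
                    y, x = queue.popleft()
                    for dy, dx in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                        ny, nx = y + dy, x + dx
                        if (0 <= ny < n_rows and 0 <= nx < n_cols
                                and binary[ny, nx] and labels[ny, nx] == 0):
                            labels[ny, nx] = current
                            queue.append((ny, nx))
    return labels

# Toy 0/255 mask with two separate objects
m = np.array([[255, 255, 0,   0],
              [255, 0,   0, 255],
              [0,   0, 255, 255]])
print(label_binary_mask(m))
```

In practice, `skimage.measure.label(binary_mask)` does the same much faster. Its default connectivity in 2D is 8 (diagonal neighbours are connected), not 4 as in this sketch.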

In [None]:
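## Start from the pretrained "CP" CellPose model (use_GPU is assumed to be defined earlier in the notebook)
model = cp3_models.CellposeModel(gpu=use_GPU, model_type='CP', nchan=2, pretrained_model=os.fspath(cp3_models.MODEL_DIR.joinpath("CP")))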

In [None]:
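## Show the CellPose model architecture (PyTorch)
print(model.net)

## To freeze some layers of the network, uncomment the lines below and choose which layers to freeze
## (frozen parameters get requires_grad = False and are not updated during training)

## Freeze the whole encoder (downsampling path)
#for param in model.net.downsample.parameters():
#    param.requires_grad = False
## Note: max-pooling layers have no trainable parameters, so this loop changes nothing
#for param in model.net.downsample.maxpool.parameters():
#    param.requires_grad = False
## Freeze only the proj sub-layer of block res_down_3
#for param in model.net.downsample.down.res_down_3.proj.parameters():
#    param.requires_grad = False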

In [None]:
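## Retrain the model on the oocyte images; the weights are saved in ./retrain under the name 'test'
## (if your CellPose copy has no model.train, CellPose 3 trains with cellpose.train.train_seg(model.net, ...))
def run_retrain():
    return model.train(ootrain_img, ootrain_mask,
                       min_train_masks=1,     # keep training images that contain at least one mask
                       test_data=ootest_img,
                       test_labels=ootest_mask,
                       channels=[0,0],        # grayscale images, no second channel
                       save_path='./retrain',
                       n_epochs=20,
                       learning_rate=0.001,
                       weight_decay=0.0001,
                       model_name='test')

## Measure the training time and the peak memory usage (memory_profiler reports MiB)
start_time = time.time()
mem_usage, new_model_path = memory_usage(run_retrain, retval=True)
end_time = time.time()
print('Maximum memory usage: %s MiB' % max(mem_usage))
print('Execution time:', end_time-start_time, 'seconds')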

In [None]:
## Test performance of the retrained model on the oocyte test images
masks, flows, styles = model.eval(ootest_img, channels=[0,0])
## average_precision returns (ap, tp, fp, fn); ap has one column per IoU threshold
ap = metrics.average_precision(ootest_mask, masks, threshold=[0.5, 0.75, 0.9])[0]
print(f'>>> average precision at iou threshold 0.5 = {ap[:,0].mean():.3f}')
print(f'>>> average precision at iou threshold 0.9 = {ap[:,2].mean():.3f}')

## Same evaluation on test_imgs / test_masks (not defined in this section; assumed loaded earlier in the notebook)
cpmasks, flows, styles = model.eval(test_imgs, channels=[0,0])
ap = metrics.average_precision(test_masks, cpmasks, threshold=[0.5, 0.75, 0.9])[0]
print(f'>>> average precision at iou threshold 0.5 = {ap[:,0].mean():.3f}')
print(f'>>> average precision at iou threshold 0.9 = {ap[:,2].mean():.3f}')
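
For each image, `metrics.average_precision` computes AP = TP / (TP + FP + FN) at each IoU threshold. A true positive is a predicted object whose intersection over union (IoU) with a ground-truth object is above the threshold. With the default thresholds [0.5, 0.75, 0.9], columns 0 and 2 of `ap` hold the 0.5 and 0.9 scores printed above.

The NumPy sketch below computes this score for one pair of label images, using hypothetical helper names `iou_matrix` and `average_precision_at`. CellPose matches objects with an optimal assignment. This simplified version instead relies on thresholds of at least 0.5, for which matches are necessarily one-to-one, and it uses a strict `>` comparison.

```python
import numpy as np

def iou_matrix(true_labels, pred_labels):
    """IoU between each true object (rows) and each predicted object (columns); 0 = background."""
    t_ids = [i for i in np.unique(true_labels) if i != 0]
    p_ids = [j for j in np.unique(pred_labels) if j != 0]
    iou = np.zeros((len(t_ids), len(p_ids)))
    for a, ti in enumerate(t_ids):
        t = true_labels == ti
        for b, pj in enumerate(p_ids):
            p = pred_labels == pj
            iou[a, b] = np.logical_and(t, p).sum() / np.logical_or(t, p).sum()
    return iou

def average_precision_at(true_labels, pred_labels, threshold):
    """AP = TP / (TP + FP + FN) at one IoU threshold; sketch only valid for threshold >= 0.5."""
    iou = iou_matrix(true_labels, pred_labels)
    n_true, n_pred = iou.shape
    # Objects within a label image do not overlap, so for threshold >= 0.5 an object can
    # exceed the threshold with at most one object of the other image: counting such
    # pairs gives a one-to-one matching.
    tp = int((iou > threshold).sum())
    fp = n_pred - tp   # predicted objects matched to no true object
    fn = n_true - tp   # true objects missed by the prediction
    denom = tp + fp + fn
    return tp / denom if denom > 0 else float("nan")  # undefined if both images are empty

# Toy example: one exact match and one over-segmented object (IoU = 4/6)
true = np.zeros((4, 4), dtype=int)
true[0:2, 0:2] = 1
true[2:4, 2:4] = 2
pred = np.zeros((4, 4), dtype=int)
pred[0:2, 0:2] = 5
pred[2:4, 1:4] = 7
print(average_precision_at(true, pred, 0.5))  # both objects matched -> 1.0
print(average_precision_at(true, pred, 0.9))  # only the exact match counts -> 1/3
```

On real data, the per-image values are averaged over the test set, as `ap[:,0].mean()` does in the cell above.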

In [None]:
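## Look at the results: input images (top row) and predicted masks (bottom row)
nimg = len(ootest_img)
plt.figure(figsize=(nimg,2), dpi=200)
for ind, img in enumerate(ootest_img):
    plt.subplot(2,nimg,ind+1)
    plt.imshow(img, cmap='gray')   # images were read as 2D grayscale arrays
    plt.axis('off')
    plt.subplot(2,nimg,nimg+ind+1)
    plt.imshow(masks[ind])         # predicted label image
    plt.axis('off')
plt.show()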