# Load required packages

In [2]:
## running with python 3.8.5
# installed packages are listed in train_requirements.txt
import pandas as pd
import time, os, sys
from urllib.parse import urlparse
import matplotlib.pyplot as plt
import matplotlib as mpl
import glob
import re
from tifffile import TiffFile
import tifffile as TIFFfile
import skimage
import random
import copy
import scipy
import scipy.ndimage
import math
import numpy as np
from sklearn.model_selection import train_test_split
import cellpose
from cellpose import models, utils, io
import cv2


# Parameters

In [5]:
# --- Object size ---------------------------------------------------------
# Typical diameter of the cells/nuclei, in µm. The loading cell rescales
# images so that objects of this size span 30 px.
diam_um = 3.5

# --- Starting model ------------------------------------------------------
# Pretrained weights to fine-tune, downloaded from
# https://www.cellpose.org/models/cyto2torch_0
pretrainedModel = './pretrainedModels/cyto2torch_0'

# --- Training inputs / outputs -------------------------------------------
# Number of augmentation rounds (the manuscript used 0 or 10; 10 is recommended).
n_augmentations = 10
# CSV listing every ground-truth image, its mask and the corresponding pixel sizes.
TrainDataFiles = "./trainingData/TrainDataFiles.csv"
# Folder that receives the retrained model files.
OutFolder = "./retrained_models"

# --- Evaluation ----------------------------------------------------------
# The Jaccard index is computed on the test data as-is and after Gaussian
# blurring with this sigma.
test_blur = 2.5

# Prepare training data

In [6]:
#### Load and rescale training data
# Each CSV row provides an image path (PathToFile), its label-mask path
# (PathToMask) and pixel sizes (xScale, yScale). Both are zoomed so that an
# object of diameter `diam_um` spans 30 px, matching diam_mean=30 used when
# the model is built.
list_traindata = pd.read_csv(TrainDataFiles)

data = []
labels = []

for row in list_traindata.itertuples(index=False):
    img = io.imread(row.PathToFile)
    ll = io.imread(row.PathToMask)
    xscale = row.xScale * 30 / diam_um
    yscale = row.yScale * 30 / diam_um

    # NOTE(review): scipy.ndimage.zoom applies the first factor to axis 0
    # (rows), so xScale is paired with rows here; confirm the CSV's xScale
    # refers to axis 0 (only matters when xScale != yScale).
    # Cubic interpolation when upsampling, quadratic otherwise
    # (only xscale is checked, as before).
    interp_order = 3 if xscale > 1 else 2
    img_scaled = scipy.ndimage.zoom(img, (xscale, yscale), order=interp_order)
    # Nearest-neighbour (order=0) keeps label ids intact.
    l_scaled = scipy.ndimage.zoom(ll, (xscale, yscale), order=0)

    # Spline overshoot (float input) or wide integer input can fall outside
    # the uint16 range; clip before casting instead of letting values wrap.
    data.append(np.clip(img_scaled, 0, np.iinfo(np.uint16).max).astype(np.uint16))
    labels.append(np.uint16(l_scaled))
        


### Split data into test and training data
# A fixed random_state keeps the 90/10 split reproducible.
train_data, test_data, train_label, test_label = train_test_split(
    data, labels, random_state=666, test_size=0.1)
print('length of training data: ' + str(len(train_data)))
print('length of test data: ' + str(len(test_data)))

# The augmented set starts as an independent copy of the training split;
# the augmentation cell appends to these lists.
train_data_augmented = copy.deepcopy(train_data)
train_label_augmented = copy.deepcopy(train_label)


### Blur test data with Gaussian sigma `test_blur` (in pixels)
# borderType must be passed by keyword: the 4th positional parameter of
# cv2.GaussianBlur is `dst`, not `borderType`. GaussianBlur returns a new
# array, so the originals in test_data are left untouched.
test_data_blurred = [
    cv2.GaussianBlur(test_img, (0, 0), test_blur, borderType=cv2.BORDER_DEFAULT)
    for test_img in test_data
]

# Offline data augmentation: each round appends one randomly transformed copy
# of every training image/mask pair to the augmented lists.
# Python's `random` drives the geometric/intensity choices (same draw order as
# before); a seeded NumPy generator drives the pixel noise, so the whole
# augmented set is reproducible.
random.seed(24579)
noise_rng = np.random.default_rng(24579)
uint16_max = np.iinfo(np.uint16).max

for _ in range(n_augmentations):
    for img_src, lbl_src in zip(train_data, train_label):
        img = np.copy(img_src)
        lbl = np.copy(lbl_src)

        # Rotation angle in [0, 45) degrees (applied below, after the flips).
        angle = random.random() * 45

        # Random transpose / flips, applied identically to image and mask.
        if random.random() > 0.5:
            img = np.transpose(img, (1, 0))
            lbl = np.transpose(lbl, (1, 0))
        if random.random() >= 0.5:
            img = np.fliplr(img)
            lbl = np.fliplr(lbl)
        if random.random() >= 0.5:
            img = np.flipud(img)
            lbl = np.flipud(lbl)

        # Rotate by `angle` on an enlarged canvas (reshape=True). Shifting the
        # image so its minimum is 0 makes the cval=0 padding equal the image
        # minimum once shifted back; shift back in float so the addition
        # cannot wrap around in uint16.
        min_i = np.amin(img)
        img_rot = scipy.ndimage.rotate(img - min_i, angle, axes=(1, 0), reshape=True,
                                       order=3, mode='constant', cval=0.0, prefilter=True)
        img = img_rot.astype(np.float64) + float(min_i)
        lbl_out = scipy.ndimage.rotate(lbl, angle, axes=(1, 0), reshape=True,
                                       order=0, mode='constant', cval=0.0, prefilter=True)

        # Gaussian noise (mean 0, variance 0.01, result clipped to [0, 1]) on the
        # image scaled to [0, 1] -- the defaults of skimage.util.random_noise
        # 'gaussian' mode, but drawn from the seeded generator. An all-zero
        # image would otherwise divide by zero.
        img_max = np.amax(img)
        img_unit = img / img_max if img_max > 0 else np.zeros_like(img)
        img_noise = np.clip(img_unit + noise_rng.normal(0.0, 0.1, size=img_unit.shape), 0.0, 1.0)

        # Blend clean and noisy image with a random weight.
        noiseratio = random.random()
        img_out = img * (1 - noiseratio) + noiseratio * img_noise * img_max

        # Gaussian blur with a random sigma in [0, 3). borderType is passed by
        # keyword: the 4th positional parameter of GaussianBlur is `dst`.
        aug_blur = random.random() * 3
        img_blur = cv2.GaussianBlur(img_out, (0, 0), aug_blur, borderType=cv2.BORDER_DEFAULT)

        # Scale intensity by a random factor in [0.5, 1.5); saturate at the
        # uint16 range instead of wrapping on the cast.
        int_factor = random.random() + 0.5
        train_data_augmented.append(np.clip(img_blur * int_factor, 0, uint16_max).astype(np.uint16))
        train_label_augmented.append(np.uint16(lbl_out))

print('augmentation done')
# Same model name as used by the training cell.
print('saving to ' + 'cellpose500epochs_meioticNuclei_' + str(n_augmentations) + 'Augm')
print('number of images:')
print(len(train_data_augmented))



# Training

In [7]:
# Fine-tune the pretrained cellpose model on the augmented training set.
# diam_mean=30 matches the 30 px object diameter the images were rescaled to.
model3D = models.CellposeModel(
    gpu=True,
    diam_mean=30,
    pretrained_model=pretrainedModel,
)
chan = [0, 0]  # single-colour images only

# Output model name encodes the epoch count and number of augmentation rounds.
model_name = 'cellpose500epochs_meioticNuclei_' + str(n_augmentations) + 'Augm'

# 500 epochs; checkpoints are written to OutFolder (save_every=20).
modelout = model3D.train(
    train_data_augmented,
    train_label_augmented,
    test_data=test_data,
    test_labels=test_label,
    channels=chan,
    normalize=True,
    save_path=OutFolder,
    model_name=model_name,
    save_every=20,
    learning_rate=0.05,
    min_train_masks=2,
    n_epochs=500,
    momentum=0.9,
    weight_decay=1e-05,
    batch_size=8,
    rescale=True,
)

print(modelout)


# Evaluate model

In [8]:
def _print_jaccard_summary(header, scores):
    """Print `header` followed by the min, mean and max of `scores`."""
    print(header)
    for stat_prefix, stat_fn in (('min: ', np.amin), ('mean: ', np.mean), ('max: ', np.amax)):
        print(stat_prefix + str(stat_fn(scores)))


# Identical evaluation settings for the clean and the blurred test images.
eval_kwargs = dict(diameter=30, channels=chan, batch_size=3, normalize=True, rescale=True)
masks_pred, flows, styles = model3D.eval(test_data, **eval_kwargs)
masks_pred_blur, flows_blur, styles_blur = model3D.eval(test_data_blurred, **eval_kwargs)

# Per-image aggregated Jaccard index against the ground-truth test masks.
metrics = cellpose.metrics.aggregated_jaccard_index(test_label, masks_pred)
metrics_blurred = cellpose.metrics.aggregated_jaccard_index(test_label, masks_pred_blur)

print('processed:  cellpose500epochs_meioticNuclei_' + str(n_augmentations) + 'Augm')
_print_jaccard_summary('Summary of jaccard index (no blurring): ', metrics)
_print_jaccard_summary('Summary of jaccard index (with blurring by ' + str(test_blur) + '): ',
                       metrics_blurred)

