# Create a Cellpose model

This program generates a new Cellpose model from a series of representative images and their corresponding manually labelled masks.

In [None]:
from cellpose import core, models, io, metrics
import os

import matplotlib.pyplot as plt
import numpy as np 
import pandas as pd
import tqdm 
import tifffile as tf

import tqdm
import tkinter as tk
from tkinter import filedialog

import datetime


from PIL import Image

### Set the folder path for the images (used for both training and testing)

In [None]:
# Ask the user for the folder containing the images (used for both training and testing).
root = tk.Tk()
root.attributes("-topmost", True)  # Keep the dialog in front of the notebook window.
root.withdraw()  # Hide the empty root window; only the dialog should be visible.
# Parent the dialog explicitly so the "-topmost" root above is the one it belongs to.
image_folder = filedialog.askdirectory(parent=root, title = 'Select image Folder')
# Destroy the hidden root so re-running cells does not accumulate Tk instances.
root.destroy()

if not image_folder:
    # askdirectory returns an empty value when the dialog is cancelled; fail
    # here with a clear message instead of later inside os.listdir('').
    raise ValueError('No image folder was selected.')

### Set the folder path for the user defined masks

In [None]:
# Ask the user for the folder containing the manually labelled masks.
root = tk.Tk()
root.attributes("-topmost", True)  # Keep the dialog in front of the notebook window.
root.withdraw()  # Hide the empty root window; only the dialog should be visible.
# Parent the dialog explicitly so the "-topmost" root above is the one it belongs to.
mask_folder = filedialog.askdirectory(parent=root, title = 'Select Masks Folder')
# Destroy the hidden root so re-running cells does not accumulate Tk instances.
root.destroy()

if not mask_folder:
    # askdirectory returns an empty value when the dialog is cancelled; fail
    # here with a clear message instead of later inside os.listdir('').
    raise ValueError('No mask folder was selected.')

#### Create a function to extract all the image filenames (.tif and .png) from a folder.

In [None]:
def get_files_from_folder(folder_path):
    '''Return the names of the image files (.tif or .png) in a folder.

    Only file names are returned (not full paths). os.listdir yields entries
    in an arbitrary order, so the names are sorted: this makes the list, and
    any index-based train/test split drawn from it, reproducible.

    Parameters
    ----------
    folder_path : str
        Folder to scan (not searched recursively).

    Returns
    -------
    list of str
        Sorted file names ending in '.tif' or '.png' (case-sensitive).
    '''
    return sorted(
        name for name in os.listdir(folder_path)
        if name.endswith(('.tif', '.png'))
    )

#### Create a function to load the image data from an image file.

In [None]:
def get_image_data(image_file):
    '''Load an image file and return its pixel data as a numpy array.

    TIFF files (.tif/.tiff) are read with tifffile. Any other format (e.g.
    the .png files that get_files_from_folder also accepts) is read with
    Pillow, because tifffile.imread only understands TIFF files.

    Parameters
    ----------
    image_file : str or path-like
        Path to the image file.

    Returns
    -------
    numpy.ndarray
        The image pixel data.
    '''
    if str(image_file).lower().endswith(('.tif', '.tiff')):
        image_data = tf.imread(image_file)
    else:
        # The context manager closes the file handle Pillow keeps open;
        # np.array copies the pixels first, so closing afterwards is safe.
        with Image.open(image_file) as pil_image:
            image_data = np.array(pil_image)

    return np.array(image_data)

In [None]:
# Gather the file names from both folders and report how many were found in each.
image_file_list = get_files_from_folder(image_folder)
mask_file_list = get_files_from_folder(mask_folder)

for found_files in (image_file_list, mask_file_list):
    print(len(found_files))

In [None]:
# Echo the selected folders, one per line.
print(image_folder, mask_folder, sep='\n')

#### Split the images into training and test datasets

In [None]:
# Randomly pick 80% of the image indices (without repeats) for training; the
# rest become the test set. Sampling from range(n_images) means every image,
# including the last one, can be selected.
n_images = len(image_file_list)
training_image_index = np.random.choice(n_images, size = int(0.8 * n_images), replace = False)

In [None]:
# Show the chosen training indices in ascending order, then how many there are.
print(np.sort(training_image_index), len(training_image_index), sep='\n')

In [None]:
# Every image index that was not chosen for training is used for testing.
# A set gives O(1) membership checks instead of scanning the training array
# once per image.
training_index_set = set(training_image_index.tolist())
test_image_index = [
    index for index in range(len(image_file_list))
    if index not in training_index_set
]

print(test_image_index)
print(len(test_image_index))

### Load the training images and user-defined masks into a format for Cellpose model training.

In [None]:
# Load every training image together with its user-defined mask. The mask is
# expected in mask_folder with the same base name as the image and a '.tif'
# extension.
ground_truth_training = []
training_images = []

for image_file_name in (image_file_list[index] for index in training_image_index):
    training_images.append(get_image_data(os.path.join(image_folder, image_file_name)))
    mask_file_name = os.path.splitext(image_file_name)[0] + '.tif'
    ground_truth_training.append(get_image_data(os.path.join(mask_folder, mask_file_name)))

# Sanity check: shape of the first image/mask pair and the number of images.
if training_images:
    print(training_images[0].shape)
    print(ground_truth_training[0].shape)
print(len(training_images))

### Load the test images and user-defined masks into a format for Cellpose model evaluation.

In [None]:
# Load every test image together with its user-defined mask. The mask is
# expected in mask_folder with the same base name as the image and a '.tif'
# extension.
ground_truth_test = []
test_images = []

for image_file_name in (image_file_list[index] for index in test_image_index):
    test_images.append(get_image_data(os.path.join(image_folder, image_file_name)))
    mask_file_name = os.path.splitext(image_file_name)[0] + '.tif'
    ground_truth_test.append(get_image_data(os.path.join(mask_folder, mask_file_name)))

# Sanity check: shape of the first image/mask pair and the number of images.
# (Indexing with [0] avoids relying on a loop variable that would be stale
# from an earlier cell if the test set were empty.)
if test_images:
    print(test_images[0].shape)
    print(ground_truth_test[0].shape)
print(len(test_images))

---

## Train a model using the training data provided to the program. 

In [None]:
def get_date_and_time(now=None):
    '''Return a date and time as filename-safe strings.

    Parameters
    ----------
    now : datetime.datetime, optional
        The moment to format. Defaults to datetime.datetime.now().

    Returns
    -------
    tuple of (str, str)
        (date, time) formatted as 'YYYY_MM_DD' and 'HH_MM_SS'.
    '''
    if now is None:
        now = datetime.datetime.now()

    # strftime always yields zero-padded fields. The previous approach sliced
    # str(now) up to the '.', but str() omits the fractional part when
    # microsecond == 0, so find('.') returned -1 and the last digit of the
    # seconds was lost.
    return now.strftime('%Y_%m_%d'), now.strftime('%H_%M_%S')

----
The next cell trains the Cellpose model. The code is taken directly from a Google Colaboratory notebook produced by the research group behind Cellpose, available at the following location: https://colab.research.google.com/github/MouseLand/cellpose/blob/main/notebooks/run_cellpose_2.ipynb

For more information about Cellpose, please see:   
Paper: https://www.nature.com/articles/s41592-022-01663-4   
Online documentation: https://cellpose.readthedocs.io/en/latest/   
Github Repository: https://github.com/MouseLand/cellpose/tree/main   

Reference:   
Pachitariu, M., Stringer, C. Cellpose 2.0: how to train your own model. Nat Methods 19, 1634–1641 (2022). https://doi.org/10.1038/s41592-022-01663-4

-----

In [None]:
# This cell is taken directly from the Colab notebook and then
# modified for my needs.

# Get the date and time; they are used to give the saved model a unique name.
date, time = get_date_and_time()

# start logger (to see training across epochs)
logger = io.logger_setup()

# DEFINE CELLPOSE MODEL (without size model), starting from the built-in 'cyto' weights.
model = models.CellposeModel(gpu=True, model_type='cyto')

# set channels (cellpose convention: [0, 0] = grayscale, no nuclear channel)
channels = [0, 0]
# Set Epoch to train over
n_epochs = 100
# Set learning rate.
learning_rate = 0.1
# Set the weight decay
weight_decay = 0.0001

# The training set is split again: the first 80% is used to fit the model and
# the remaining 20% is handed to cellpose as test_data/test_labels.
n_fit = int(0.8 * len(training_images))
model_name = date + '_' + time + '_' + 'CP_Models'

# model.train returns the path of the saved model weights.
new_model_path = model.train(training_images[:n_fit], ground_truth_training[:n_fit],
                              test_data=training_images[n_fit:],
                              test_labels=ground_truth_training[n_fit:],
                              channels=channels,
                              save_path=os.path.dirname(image_folder),
                              n_epochs=n_epochs,
                              learning_rate=learning_rate,
                              weight_decay=weight_decay,
                              nimg_per_epoch=8,
                              model_name=model_name)

# diameter of labels in training images
diam_labels = model.diam_labels.copy()

print(test_images[0].shape)

---------
## Test Accuracy of newly generated models

In [None]:
# Initialise
retrained_masks = []

# Load the model that was just trained. model.train() in the previous cell
# returned the path of the saved weights, so loading it via pretrained_model
# avoids having to register the model with cellpose manually, see:
# https://cellpose.readthedocs.io/en/latest/models.html#user-trained-models
model_path = new_model_path

# initialise model
model = models.CellposeModel(gpu=True, pretrained_model=model_path)

# Segment every test image, using the same channels the model was trained with.
for test_image in tqdm.tqdm(test_images):
    print(test_image.shape)
    masks = model.eval(test_image, channels = channels, diameter = None)[0]
    retrained_masks.append(masks)

# Average precision per image at cellpose's IoU thresholds (default
# [0.5, 0.75, 0.9]); column 0 is the AP at IoU 0.5.
ap = metrics.average_precision(ground_truth_test, retrained_masks)[0]
# Print the mean AP@0.5 over the test images.
print(ap[:,0].mean())

----------
### Plot the cellpose masks and the ground truth to compare results. 

In [None]:
# Index of the test image to inspect (requires at least two test images).
im = 1
raw_image = test_images[im][ :, :]
# Shared overlay styling: translucent labels, saturated at 1 so every object shows.
overlay_kwargs = dict(alpha = 0.25, cmap = 'inferno_r', vmax = 1)

# Cellpose prediction drawn over the raw image.
fig1, ax1 = plt.subplots()
ax1.imshow(raw_image, vmin = 50)
ax1.imshow(retrained_masks[im], **overlay_kwargs)
ax1.set_title('Cellpose Masks')

# Manually labelled ground truth drawn over the same image.
fig3, ax3 = plt.subplots()
ax3.imshow(raw_image, vmin = 50)
ax3.imshow(ground_truth_test[im], **overlay_kwargs)
ax3.set_title('Ground truth')