# napari-mAIcrobe cell cycle model training
---
napari-mAIcrobe is a napari-based re-implementation of eHooke, an image analysis framework developed specifically for the automated analysis of microscopy images of spherical bacterial cells. The original eHooke contained a trained artificial neural network that automatically classifies the cell cycle phase of individual *S. aureus* cells.

This notebook lets you train the original eHooke neural network for an arbitrary number of channels and classes.

---

## Before getting started

-   Make sure you have the training and test data already loaded into your Google Drive.

-   Preprocess your data into the correct format using the generate_pickles widget in napari-mAIcrobe.

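-   The model input is given by two folders, the training source and the training target, each of which can contain multiple pickle (.p) files. Files in the two folders are paired by sorted filename, so the files in both folders must sort into the same order.
-   The training source is a pickled Python list of NumPy arrays, where each array has shape 100 x (#channels*100), i.e. the channels are placed side by side.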

-   The training target is a pickled list of integer labels, starting at 1, that give the cell cycle phase of each image (in the same order as the training source).
    

In [None]:
#@title 1. Install all dependencies
# Remove packages preinstalled in the Colab image that (presumably) conflict
# with the pinned TensorFlow / Keras 2.15 stack installed on the next line.
!pip uninstall -qy fastai albucore spacy albumentations grpcio-status jax flax optax chex dopamine-rl tensorflow-text tensorflow-decision-forests keras-hub orbax-checkpoint tensorstore opencv-python-headless opencv-python opencv-contrib-python ydf thinc
!pip install -q tensorflow==2.15.0 keras==2.15.0 numpy==1.26.4 tf-keras==2.15 grpcio==1.71.0
# Terminate the Python process -- presumably so Colab restarts the runtime
# and the following cells import the freshly pinned package versions.
exit()

In [None]:
#@title 2. Import necessary libraries and connect to Google Drive

import sys
import os

import math

import numpy as np

import tensorflow
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Flatten, RandomFlip, RandomRotation
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from glob import glob
from tifffile import imread

from matplotlib import pyplot as plt

import pickle

class CellCycleSequence(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(images, labels)`` batches for ``fit``.

    Samples are served in the order they are stored; this class does not
    reshuffle them itself. The final batch may be smaller than
    ``batch_size``. Each image is reshaped to ``(100, n_channels * 100, 1)``,
    i.e. the channels side by side plus a trailing singleton channel axis.

    Args:
        x_set: sequence of images, each with ``100 * n_channels * 100`` values.
        y_set: sequence of integer labels aligned with ``x_set``.
        batch_size: number of samples per batch.
        n_channels: number of channels laid side by side in each image.
    """

    def __init__(self, x_set, y_set, batch_size, n_channels):
        super().__init__()
        self.x = x_set
        self.y = y_set
        self.batch_size = batch_size
        self.n_channels = n_channels

    def __len__(self):
        # Round up so a trailing partial batch is still served.
        n_samples = len(self.y)
        return math.ceil(n_samples / self.batch_size)

    def __getitem__(self, idx):
        start = idx * self.batch_size
        stop = start + self.batch_size

        images = np.array(self.x[start:stop])
        labels = np.array(self.y[start:stop])

        batch_shape = (-1, 100, self.n_channels * 100, 1)
        return images.reshape(batch_shape), labels


class ModelTrainer:
    """Load pickled cell crops, build the eHooke-style CNN and train it.

    Intended call order: ``load_data`` -> ``create_model`` ->
    ``compile_model`` -> ``train_model`` -> ``save_model``.

    Attributes:
        X_trn, Y_trn: training images and 0-based integer labels (lists),
            filled by ``load_data``.
        X_val, Y_val: validation images and 0-based integer labels (lists),
            filled by ``load_data``.
        y_trn, y_val: lower-case aliases of ``Y_trn`` / ``Y_val``.
        trn_sequence, val_sequence: the batch sequences last built by
            ``train_model``.
        model: the Keras ``Sequential`` model built by ``create_model``
            (``None`` until then).
    """

    def __init__(self):
        self.X_trn = None
        self.Y_trn = None

        self.X_val = None
        self.Y_val = None

        # Lower-case aliases of Y_trn / Y_val (the names originally
        # initialised here); load_data keeps them in sync.
        self.y_trn = None
        self.y_val = None

        self.trn_sequence = None
        self.val_sequence = None

        self.model = None

    def _require_model(self):
        """Return ``self.model``, raising a clear error if it is not built yet."""
        if self.model is None:
            raise RuntimeError("No model yet: call create_model() first.")
        return self.model

    @staticmethod
    def _load_pickles(X_path, y_path):
        """Read and concatenate every paired ``*.p`` pickle in the two folders."""
        # Files are paired by sorted filename order.
        X_files = sorted(glob(os.path.join(X_path, "*.p")))
        y_files = sorted(glob(os.path.join(y_path, "*.p")))
        if not X_files or not y_files:
            raise FileNotFoundError(
                f"No pickle (*.p) files found in {X_path!r} and/or {y_path!r}")
        if len(X_files) != len(y_files):
            raise ValueError("Number of X and y pickle files must be the same")
        print(f"Found {len(X_files)} pickle files in {X_path} and {y_path}")

        X = []
        y = []
        for xf, yf in zip(X_files, y_files):
            print(f"Loading {xf} and {yf}")
            # NOTE: unpickling can execute arbitrary code -- only load
            # files you trust.
            with open(xf, 'rb') as xfile:
                X_part = pickle.load(xfile)
            with open(yf, 'rb') as yfile:
                y_part = pickle.load(yfile)
            # A length mismatch would silently shift every later label.
            if len(X_part) != len(y_part):
                raise ValueError(
                    f"{xf} holds {len(X_part)} images but {yf} holds "
                    f"{len(y_part)} labels")
            X.extend(X_part)
            y.extend(y_part)
        return np.array(X), np.array(y)

    @staticmethod
    def _show_samples(X, y, nchannels, rng):
        """Plot 3 random samples, one subplot per channel, titled by phase."""
        sample_inds = rng.randint(0, len(y), size=3)
        fig = plt.figure(figsize=(5, 9))
        subfigs = fig.subfigures(nrows=3, ncols=1)
        for subfig, sind in zip(subfigs, sample_inds):
            # y holds 0-based labels; display the 1-based phase number.
            subfig.suptitle(f"Phase # {y[sind]+1}")
            # squeeze=False keeps a 2-D axes array even for a single channel.
            axs = subfig.subplots(nrows=1, ncols=nchannels, squeeze=False)
            for nc in range(nchannels):
                # Channels are stored side by side in 100-pixel-wide columns.
                channel = X[sind][:, nc*100:(nc+1)*100]
                axs[0, nc].imshow(channel, cmap='gray', clim=(0, 1))
                axs[0, nc].set_title(f"Channel # {nc}")

    def load_data(self, X_path, y_path, val_split, nchannels, seed=None):
        """Load paired image/label pickles, split them and preview samples.

        Args:
            X_path: folder with the image pickles (``*.p``); each holds a list
                of arrays of shape ``(100, nchannels * 100)``.
            y_path: folder with the matching label pickles (``*.p``), paired
                with ``X_path`` files by sorted filename. Labels are 1-based.
            val_split: fraction of samples held out for validation. At least
                one sample goes to validation and at least one to training.
            nchannels: number of channels laid side by side in each image.
            seed: optional seed for the shuffle/preview RNG, for a
                reproducible split. ``None`` (default) keeps it random.

        Raises:
            FileNotFoundError: if either folder contains no ``*.p`` files.
            ValueError: if the folders (or a pair of files) disagree in size,
                fewer than two samples are found, or a label is below 1.
        """
        X, y = self._load_pickles(X_path, y_path)
        print(f"Loaded {len(X)} samples")

        # Convert the 1-based phase labels to the 0-based class indices used
        # by the sparse categorical cross-entropy loss.
        y = [int(i)-1 for i in y]
        if any(label < 0 for label in y):
            raise ValueError("Labels must be 1-based integers (>= 1)")
        if len(y) < 2:
            raise ValueError(
                f"Need at least 2 samples for a training/validation split, "
                f"found {len(y)}")

        rng = np.random.RandomState(seed)
        ind = rng.permutation(len(y))
        # Hold out at least one sample for validation, but never all of them.
        n_val = min(max(1, int(round(val_split * len(ind)))), len(ind) - 1)
        ind_train, ind_val = ind[:-n_val], ind[-n_val:]
        self.X_val, self.Y_val = [X[i] for i in ind_val], [y[i] for i in ind_val]
        self.X_trn, self.Y_trn = [X[i] for i in ind_train], [y[i] for i in ind_train]
        self.y_trn, self.y_val = self.Y_trn, self.Y_val
        print('number of images: %3d' % len(X))
        print('- training:       %3d' % len(self.X_trn))
        print('- validation:     %3d' % len(self.X_val))

        self._show_samples(X, y, nchannels, rng)

    @staticmethod
    def _conv_block(filters):
        """Return a 3x3 'same'-padded Conv2D with ``filters`` kernels + ReLU."""
        return [Conv2D(filters, (3, 3), padding='same'), Activation('relu')]

    @staticmethod
    def _pool_block():
        """Return a single 2x2 max-pooling layer, wrapped in a list."""
        return [MaxPooling2D(pool_size=(2, 2))]

    def create_model(self, depth, cellcyclephases, n_channels, augbool=False, hflip=False, vflip=False, rotfactor=0):
        """Build the CNN classifier and store it in ``self.model``.

        Args:
            depth: how many feature-extraction stages to enable. A first
                Conv2D(16)+ReLU is always present; each further stage ``k``
                (1..9, see ``optional_stages``) is added when ``depth > k``,
                and ``depth > 10`` adds Dense(100)+ReLU after flattening.
            cellcyclephases: number of output classes (softmax units).
            n_channels: channels laid side by side; the input shape is
                ``(100, n_channels * 100, 1)``.
            augbool: enable the augmentation layers below.
            hflip, vflip: add a horizontal and/or vertical ``RandomFlip``.
            rotfactor: factor passed to ``RandomRotation``; added only if > 0.
        """
        self.model = Sequential()

        self.model.add(Input(shape=(100, n_channels*100, 1)))

        if augbool:
            if hflip and vflip:
                self.model.add(RandomFlip(mode='horizontal_and_vertical'))
            elif hflip:
                self.model.add(RandomFlip(mode='horizontal'))
            elif vflip:
                self.model.add(RandomFlip(mode='vertical'))

            if rotfactor > 0:
                self.model.add(RandomRotation(rotfactor))

        self.model.add(Conv2D(16, (3, 3), padding='same'))
        self.model.add(Activation('relu'))

        # Optional stages in order; stage k (1-based) is added iff depth > k.
        # The thresholds increase, so each stage implies all earlier ones.
        optional_stages = (
            self._pool_block,               # depth > 1
            lambda: self._conv_block(16),   # depth > 2
            self._pool_block,               # depth > 3
            lambda: self._conv_block(16),   # depth > 4
            self._pool_block,               # depth > 5
            lambda: self._conv_block(32),   # depth > 6
            lambda: self._conv_block(32),   # depth > 7
            lambda: self._conv_block(32),   # depth > 8
            lambda: self._conv_block(32),   # depth > 9
        )
        for stage, build_layers in enumerate(optional_stages, start=1):
            if not depth > stage:
                break
            for layer in build_layers():
                self.model.add(layer)

        self.model.add(Flatten())
        if depth > 10:
            # Deepest configuration: extra fully connected hidden layer.
            self.model.add(Dense(100))
            self.model.add(Activation('relu'))

        self.model.add(Dense(cellcyclephases))
        self.model.add(Activation('softmax'))

    def compile_model(self, learningrate):
        """Compile with Adam(``learningrate``), sparse CE loss and accuracy.

        Sparse categorical cross-entropy takes the 0-based integer labels
        produced by ``load_data``.

        Raises:
            RuntimeError: if ``create_model`` has not been called.
        """
        self._require_model().compile(loss='sparse_categorical_crossentropy',
                                      optimizer=tf.keras.optimizers.Adam(learningrate),
                                      metrics=['accuracy'])

    def train_model(self, model_path, n_epochs, n_batch_size, n_channels):
        """Fit the model on the loaded split and return the Keras History.

        Whenever ``val_loss`` improves, the model is saved to
        ``model_path/weights_{epoch}.keras``.

        Args:
            model_path: folder for the checkpoint files.
            n_epochs: number of training epochs.
            n_batch_size: samples per batch.
            n_channels: channels laid side by side in each image.

        Raises:
            RuntimeError: if ``create_model`` or ``load_data`` was not called.
        """
        model = self._require_model()
        if self.X_trn is None or self.X_val is None:
            raise RuntimeError("No data yet: call load_data() first.")

        checkpoint = ModelCheckpoint(os.path.join(model_path, "weights_{epoch}.keras"),
                                     verbose=1, monitor="val_loss", save_best_only=True)

        self.trn_sequence = CellCycleSequence(self.X_trn, self.Y_trn, n_batch_size, n_channels)
        self.val_sequence = CellCycleSequence(self.X_val, self.Y_val, n_batch_size, n_channels)
        history = model.fit(x=self.trn_sequence,
                            validation_data=self.val_sequence,
                            epochs=n_epochs, verbose=1,
                            callbacks=[checkpoint])

        return history

    def save_model(self, model_path):
        """Save the current model to ``model_path/weights_last.keras``.

        Raises:
            RuntimeError: if ``create_model`` has not been called.
        """
        self._require_model().save(os.path.join(model_path, "weights_last.keras"))

model_trainer = ModelTrainer()

# Colab-only setup: enable the custom widget manager and mount Google Drive
# at /content/drive so data and models on Drive are reachable.
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    from google.colab import drive, output
    output.enable_custom_widget_manager()
    drive.mount('/content/drive')



In [None]:
#@title 3. Load data and select training parameters

#@markdown ###Load data:
Training_source = '' #@param {type:"string"}
Training_target = '' #@param {type:"string"}

#@markdown ###Model parameters:
n_cell_cycle_stages = 3#@param {type:"integer"}
n_channels = 2#@param {type:"integer"}
depth = 11#@param {type:"integer"}

#@markdown ###Training parameters:
number_of_epochs =  200#@param {type:"integer"}
batch_size =  32#@param {type:"integer"}
percentage_validation =  0.2#@param{type:"number"}
initial_learning_rate = 0.001 #@param {type:"number"}

#@markdown #### Data augmentation

dataaugmentation = True #@param {type:"boolean"}

#@markdown Rotation
factor_rotations =  0.5 #@param {type:"slider", min:0, max:1, step:0.05}

#@markdown Flips
horizontal_flip = True #@param {type:"boolean"}
vertical_flip = True #@param {type:"boolean"}


#@markdown ###Save model:
model_path = '' #@param {type:"string"}

# Validate the form before doing any work, so an unset field gives a clear
# message instead of an obscure error (e.g. from os.mkdir('')).
for _field, _value in (("Training_source", Training_source),
                       ("Training_target", Training_target),
                       ("model_path", model_path)):
    if not _value:
        raise ValueError(f"Please fill in '{_field}' in the form above.")
# percentage_validation is a fraction of the samples, not a percentage.
if not 0 < percentage_validation < 1:
    raise ValueError("percentage_validation must be a fraction between 0 and 1.")

# makedirs also creates missing parent folders (os.mkdir fails on a nested
# path whose parent does not exist yet), and exist_ok makes re-runs safe.
os.makedirs(model_path, exist_ok=True)


model_trainer.load_data(X_path=Training_source,y_path=Training_target,val_split=percentage_validation,nchannels=n_channels)

model_trainer.create_model(depth,n_cell_cycle_stages,n_channels,augbool = dataaugmentation, hflip=horizontal_flip, vflip=vertical_flip, rotfactor=factor_rotations)

model_trainer.compile_model(initial_learning_rate)



In [None]:
#@title 4. Start model training (Warning: this might take a long time)

# Fit the network (train_model writes a checkpoint to model_path whenever
# val_loss improves), then save the final-epoch model next to them.
model_history = model_trainer.train_model(
    model_path=model_path,
    n_epochs=number_of_epochs,
    n_batch_size=batch_size,
    n_channels=n_channels,
)
model_trainer.save_model(model_path=model_path)


In [None]:
#@title 5. Assess Quality Control

#@markdown ###Load data:
QC_source = '' #@param {type:"string"}
QC_target = '' #@param {type:"string"}

# QC_source / QC_target are single pickle files (not folders): a list of
# (100, n_channels*100) images and the matching list of 1-based labels.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
with open(QC_source, "rb") as qc_x_file:
    tX = np.array(pickle.load(qc_x_file)).reshape(-1, 100, n_channels*100, 1)
with open(QC_target, "rb") as qc_y_file:
    # Cast to int (as the training loader does) so labels stored as strings
    # or floats still compare equal to the integer predictions below.
    ty = np.array([int(label) for label in pickle.load(qc_y_file)])

# The model outputs 0-based class indices; shift to the 1-based phase labels.
predicted = np.argmax(model_trainer.model.predict(tX), axis=1) + 1

model_metrics = model_history.history
fig, axs = plt.subplots(1, 2)

axs[0].plot(model_metrics['loss'], label='loss')
axs[0].plot(model_metrics['val_loss'], label='val_loss')
axs[0].legend()
axs[0].set_xlabel('Epoch')
axs[0].set_ylabel('loss')

axs[1].plot(model_metrics['accuracy'], label='accuracy')
axs[1].plot(model_metrics['val_accuracy'], label='val_accuracy')
axs[1].legend()
axs[1].set_xlabel('Epoch')
axs[1].set_ylabel('accuracy')

plt.tight_layout()
plt.show()


# Pin the label set to the configured number of phases. The matrix is then
# always n x n, even if a phase is absent from both ty and predicted, and
# the tick labels always match its size.
phase_labels = list(range(1, n_cell_cycle_stages + 1))
conf = confusion_matrix(ty, predicted, labels=phase_labels)
confaxs = ConfusionMatrixDisplay(confusion_matrix=conf, display_labels=phase_labels)
confaxs.plot()
plt.title("Confusion matrix")
plt.show()