# CNN Training for Cell Cycle State Classification

### Welcome!

This notebook lets you train a convolutional neural network (CNN) on your annotated single-cell image patches and then use it to predict the labels of previously unseen images. Follow the step-by-step instructions below to train the network and test its accuracy.

This is a preview of the CNN training process with image annotations: <br/>
![image](../assets/cnn_training_process.png)


### Important Notes:

1. This notebook runs in the virtual environment of [Google Colab](https://colab.research.google.com/notebooks/intro.ipynb "Google Colaboratory"). Before you can train the neural network on your annotated data, you must first **import your data** into the folders the notebook reads from. Follow the instructions below after executing the first cell of this notebook.

2. If you are using Google Colab, the session will time out if you do not interact with it: after 90 minutes if you close the browser, or after 12 hours if you keep the browser open. If you close the browser while a code cell is running, that cell keeps executing on the virtual machine, and it will still be running when you reopen the browser if it has not finished yet. Please visit this [StackOverflow](https://stackoverflow.com/questions/54057011/google-colab-session-timeout "Google Colab Session Timeout") discussion for more details.


### Running Instructions:

1. Before running the whole notebook in one go, execute the first code cell. It installs the CellX library and creates local directories on the virtual machine.

2. When the first cell finishes, it prints ```Building wheel for cellx (setup.py) ... done```. Click the ``` 📁``` folder icon in the left-side panel of the Colab notebook. You should now see 4 subfolders: "sample_data" (a Colab default), plus "logs", "train" and "test", which should all be empty.

3. At this point, you should **manually move your 'annotation_XXX.zip' files into the "train" and "test" folders**. Doing so will allow the image patch data to be processed, divided into categories and used for model training & predictions.

> When training your network, it is important to hold back some of your annotated data so that those images and their labels are never seen by the network during training. This lets you test how well the model predicts labels for new, previously unseen images. We recommend distributing your annotation files so that **approx. 80% of the labels are in the train group and the remaining 20% are used for testing**. If you do not allocate enough examples to the test set, you will be unable to validate the performance of the network.

4. You can now run the entire notebook by clicking ```Runtime``` > ```Run all``` in the top menu bar. Re-running the first cell will print `mkdir` errors for the "logs", "train" and "test" folders because they already exist; these errors are harmless.

5. Before training the model, the notebook converts the image patches in the "train" and "test" folders into training and testing datasets and sets up data augmentations for training. It then trains the neural network using the hyperparameters you have set.

6. During training, you can actively visualise what the network is doing via [TensorBoard](https://www.tensorflow.org/tensorboard/get_started "TensorFlow || Tensorboard"), a tool for providing the measurements and visualizations needed during the machine learning workflow. It enables tracking experiment metrics like loss and accuracy, visualizing the model graph, projecting embeddings to a lower dimensional space, and much more.

7. **Do not terminate this notebook before saving the model.** To download the saved model (`model.h5`) to your local machine, open the file browser, press the '...' button next to the file and select 'Download'. If you do not save and download the model, you will lose all of the training progress you have achieved so far. In the next steps of this protocol you will be asked to import the downloaded model into the Colab environment again, so make sure you have a working model.

---

**Happy training!**

*Your [CellX](http://lowe.cs.ucl.ac.uk/cellx.html "Lowe Lab @ UCL") team*


### Install the CellX library & create subdirectories in the virtual machine:

In [None]:
# if using colab, install cellx library and make log and data folders

if 'google.colab' in str(get_ipython()):
    # GitHub no longer serves the unauthenticated git:// protocol, so install over https:
    !pip install -q git+https://github.com/quantumjot/cellx.git
    !mkdir logs
    !mkdir train
    !mkdir test

### Import libraries and CellX toolkit:

In [None]:
import os
import zipfile
import numpy as np
import matplotlib.pyplot as plt

from datetime import datetime
from skimage.transform import resize

In [None]:
import tensorflow.keras as K
import tensorflow as tf

In [None]:
from cellx.layers import Encoder2D
from cellx.tools.dataset import build_dataset
from cellx.tools.dataset import write_dataset
from cellx.augmentation.utils import append_conditional_augmentation, augmentation_label_handler
from cellx.callbacks import tensorboard_confusion_matrix_callback

### Define paths & class labels:

In [None]:
TRAIN_PATH = "./train"
TEST_PATH = "./test"
TRAIN_FILE = os.path.join(TRAIN_PATH, 'CNN_train.tfrecord')
TEST_FILE = os.path.join(TEST_PATH, 'CNN_test.tfrecord')
LABELS = ["Interphase", "Prometaphase", "Metaphase", "Anaphase", "Apoptosis"]

### Set up CNN training hyperparameters:

In [None]:
BATCH_SIZE = 64
BUFFER_SIZE = 20_000
TRAINING_EPOCHS = 100

### Load the TensorBoard extension for real-time visualisation of CNN training:

In [None]:
%load_ext tensorboard
LOG_ROOT = './logs'
LOG_DIR = os.path.join(LOG_ROOT, datetime.now().strftime("%Y%m%d-%H%M%S"))

### Generate TensorFlow Record (TFRecord) files:

In [None]:
def create_tf_record(
    root,
    filename,
    labels=LABELS
):

    _images = []
    _labels = []

    # Find the annotation zip files:
    zipfiles = [
        os.path.join(root, f) for f in os.listdir(root)
        if f.endswith(".zip") and f.startswith("annotation_")
    ]

    if len(zipfiles) == 0:
        raise FileNotFoundError(
            f"No 'annotation_*.zip' files found in '{root}'. "
            "Please upload your annotated data into the Colab environment."
        )

    # One counter per class (built from `labels`), plus one for flagged patches, which are excluded:
    label_counter = {label.capitalize(): 0 for label in labels}
    label_counter["Flagged"] = 0

    for zfn in zipfiles:
        print(f"Loading file: {zfn}")
        with zipfile.ZipFile(zfn, 'r') as zip_data:
            files = zip_data.namelist()

            for numeric_label, label in enumerate(labels):
                patch_files = [f for f in files if f.endswith(".tif") and f.startswith(label.capitalize())]

                # Count the label instances (using the same '_flagged' test as the image filter below):
                for f in patch_files:
                    if "_flagged" in f:
                        label_counter["Flagged"] += 1
                    else:
                        label_counter[label.capitalize()] += 1

                # Open the image patches & read the pixel data, skipping flagged patches:
                images = [plt.imread(zip_data.open(f)) for f in patch_files if "_flagged" not in f]
                images_resized = [resize(img, (64, 64), preserve_range=True) for img in images]

                _images += images_resized
                _labels += [numeric_label] * len(images_resized)

    if not _images:
        raise ValueError(f"No labelled (non-flagged) .tif patches found in '{root}'.")

    images_arr = np.stack(_images, axis=0)[..., np.newaxis]
    labels_arr = np.stack(_labels, axis=0)

    # Print out the statistics:
    print(f"Total images: {images_arr.shape[0]}")
    print(label_counter)

    # Visualise the class distribution (convert dict views to lists for matplotlib):
    classes = list(label_counter.keys())
    plt.bar(x=classes, height=list(label_counter.values()), color="grey")
    plt.title(f"Label Count per Class: {root} set of {images_arr.shape[0]} images")
    plt.xticks(ticks=classes, labels=classes, rotation=45)
    plt.show()
    plt.close()

    write_dataset(filename, images_arr.astype(np.uint8), labels=labels_arr.astype(np.int64))
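
The filename convention that `create_tf_record` relies on can be summarised in a small stand-alone function. `label_from_filename` is a hypothetical helper, not part of CellX: it assumes, as the cell above does, that each patch is a `.tif` whose name starts with its class label and that flagged patches contain `_flagged`. The example filenames are made up for illustration.

```python
LABELS = ["Interphase", "Prometaphase", "Metaphase", "Anaphase", "Apoptosis"]


def label_from_filename(name, labels=LABELS):
    """Return the integer class of a patch filename, or None if the file is skipped."""
    # Non-TIFF files and flagged patches are not used for training.
    if not name.endswith(".tif") or "_flagged" in name:
        return None
    # The class index is the position of the matching label prefix in `labels`.
    for numeric_label, label in enumerate(labels):
        if name.startswith(label.capitalize()):
            return numeric_label
    return None


print(label_from_filename("Metaphase_0001.tif"))          # 2
print(label_from_filename("Anaphase_0003_flagged.tif"))   # None
```

The integer returned here is the same index that `enumerate(labels)` produces in `create_tf_record`, so the order of `LABELS` defines which output of the network corresponds to which class.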

## IMPORTANT: 

**Before calling the function that creates the TFRecord files:**

You need to manually drag the `annotation_XXX.zip` files into the newly created folders. If you are working in the Google Colab environment, click on the folder icon at the left-side dashboard, which should now contain the 'logs', 'train' and 'test' directories. They should be empty until you drag your annotation files into them.
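
To confirm the upload worked before building the TFRecords, a quick check along these lines can help. `find_annotation_zips` and `check_data_folders` are hypothetical helpers written for this sketch; they apply the same `annotation_*.zip` filter that `create_tf_record` uses and raise an error if a folder is missing or contains no annotation files.

```python
import os


def find_annotation_zips(root):
    """Return the annotation_*.zip filenames in `root`, sorted."""
    return sorted(
        f for f in os.listdir(root)
        if f.startswith("annotation_") and f.endswith(".zip")
    )


def check_data_folders(folders=("./train", "./test")):
    """Map each folder to its annotation zips; raise if a folder is missing or empty."""
    found = {}
    for folder in folders:
        if not os.path.isdir(folder):
            raise FileNotFoundError(f"Folder not found: {folder}")
        zips = find_annotation_zips(folder)
        if not zips:
            raise FileNotFoundError(f"No annotation_*.zip files in {folder}")
        found[folder] = zips
    return found
```

In the notebook you could run `for folder, zips in check_data_folders().items(): print(folder, zips)` with the default `./train` and `./test` paths.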

When training your network, it is important to reserve some of your annotated data so that it is never seen by the network during training. This is the best way to test how well the model predicts labels for completely new, previously unseen images. We recommend distributing your annotation files so that **approx. 80% of the labels are in the train group and the remaining 20% are used for testing**. If you do not allocate enough examples to the test set, you will be unable to validate the performance of the network.
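
One way to approximate the 80/20 split is to count the labelled patches in each annotation file and assign whole files to the test set while doing so brings it closer to 20% of the labels. The helpers below are a hypothetical sketch, not part of CellX; they assume one `.tif` per labelled patch, skip `_flagged` patches as `create_tf_record` does, and never split a single zip file between the two sets, so the result is only approximately 80/20.

```python
import random
import zipfile


def count_labelled_patches(zip_path):
    """Count non-flagged .tif patches in one annotation zip file."""
    with zipfile.ZipFile(zip_path, "r") as zip_data:
        return sum(
            1 for f in zip_data.namelist()
            if f.endswith(".tif") and "_flagged" not in f
        )


def split_by_label_count(counts, test_fraction=0.2, seed=0):
    """Assign whole annotation files to a train list and a test list.

    counts: dict mapping zip filename -> number of labelled patches.
    Files are visited in a shuffled order; a file goes to the test list
    only if adding it brings the test total closer to the target.
    """
    names = sorted(counts)
    random.Random(seed).shuffle(names)
    target = test_fraction * sum(counts.values())

    train, test, n_test = [], [], 0
    for name in names:
        if abs(n_test + counts[name] - target) < abs(n_test - target):
            test.append(name)
            n_test += counts[name]
        else:
            train.append(name)
    return sorted(train), sorted(test)
```

Usage would be along the lines of `counts = {f: count_labelled_patches(f) for f in my_zip_files}` followed by `train_files, test_files = split_by_label_count(counts)` (where `my_zip_files` is a hypothetical list of your local annotation files), after which you move the files into the "train" and "test" folders. If `test_files` comes back empty, no single file is small enough to bring the test set closer to 20%, and you may need to spread your annotations across more files.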

Once the files have been imported, run the following cell:

In [None]:
create_tf_record(TRAIN_PATH, TRAIN_FILE)
create_tf_record(TEST_PATH, TEST_FILE)

### Create a simple CNN for classification:

In [None]:
img = K.layers.Input(shape=(64, 64, 1))
x = Encoder2D(layers=[8, 16, 32, 64, 128])(img)
x = K.layers.Flatten()(x)
x = K.layers.Dense(256, activation="relu")(x)
x = K.layers.Dropout(0.2)(x)
logits = K.layers.Dense(len(LABELS), activation="linear")(x)  # one logit per class

In [None]:
model = K.Model(inputs=img, outputs=logits)

In [None]:
model.summary()
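
Because the final `Dense` layer has a linear activation, the model outputs unnormalised logits rather than probabilities. To turn a batch of logits (for example from `model.predict` on images normalised as in training) into class names and probabilities, a softmax followed by an argmax is enough. The sketch below is a NumPy illustration; `logits_to_predictions` is a hypothetical helper and the logits in the usage example are made up.

```python
import numpy as np

LABELS = ["Interphase", "Prometaphase", "Metaphase", "Anaphase", "Apoptosis"]


def logits_to_predictions(logits, labels=LABELS):
    """Convert an (N, num_classes) array of logits to class names and probabilities."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    predicted = [labels[i] for i in probs.argmax(axis=-1)]
    return predicted, probs


# Usage with made-up logits for two patches:
names, probs = logits_to_predictions([[0.0, 0.0, 5.0, 0.0, 0.0],
                                      [2.0, 0.0, 0.0, 0.0, 0.0]])
print(names)  # ['Metaphase', 'Interphase']
```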

### Set up some augmentations to be used while training

In [None]:
@augmentation_label_handler
def normalize(img):
    img = tf.image.per_image_standardization(img)
    # clip to 4 standard deviations
    img = tf.clip_by_value(img, -4., 4.)
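    # check_numerics returns its input unchanged, or raises if it contains NaN or Inf values
    img = tf.debugging.check_numerics(img, "Image contains NaN or Inf")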
    return img
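
`tf.image.per_image_standardization` subtracts the image mean and divides by the standard deviation, using `max(stddev, 1/sqrt(N))` (with `N` the number of pixel values) so that flat images do not cause a division by zero. A NumPy re-implementation of the `normalize` step, useful for inspecting patches outside TensorFlow, might look like this; `normalize_np` is a hypothetical name and the sketch is not guaranteed to match TensorFlow bit-for-bit.

```python
import numpy as np


def normalize_np(img, clip=4.0):
    """Standardise one image to zero mean and unit variance, then clip to +/- `clip`."""
    img = np.asarray(img, dtype=np.float32)
    # Lower-bound the std so that a constant image is not divided by zero.
    adjusted_std = max(float(img.std()), 1.0 / np.sqrt(img.size))
    standardized = (img - img.mean()) / adjusted_std
    return np.clip(standardized, -clip, clip)


patch = np.arange(64 * 64, dtype=np.float32).reshape(64, 64, 1)
out = normalize_np(patch)
```

Clipping at 4 standard deviations mainly limits the influence of a few very bright pixels; for most patches it changes nothing.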

In [None]:
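@augmentation_label_handler
def augment(img):
    boundary_augmentation = True
    if boundary_augmentation:
        # Randomly simulate the cropping that occurs at the edge of an image
        # volume by zeroing the left-most `width` columns of the patch.
        # TF ops (not np.random) are used so that a new width is drawn for
        # every image; np.random would run only once, when the function is traced.
        width = tf.random.uniform(shape=(), maxval=30, dtype=tf.int32)
        columns = tf.range(64)
        vignette = tf.cast(columns >= width, img.dtype)[tf.newaxis, :, tf.newaxis]

        img = tf.cond(pred=tf.random.uniform(shape=()) < 0.05,
                      true_fn=lambda: tf.multiply(img, vignette),
                      false_fn=lambda: img)

    # Random rotation by k * 90 degrees, with k in {0, 1, 2, 3}:
    k = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
    img = tf.image.rot90(img, k=k)

    # Random horizontal and vertical flips:
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_flip_up_down(img)
    return img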
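
The same augmentation policy can be written in NumPy to make the per-image randomness explicit: each call draws its own vignette width, rotation and flips. This is only an illustrative sketch (`augment_np` is a hypothetical name, and it uses NumPy's `Generator` rather than TensorFlow ops, so it is not a drop-in replacement inside `dataset.map`).

```python
import numpy as np


def augment_np(img, rng, p_vignette=0.05, max_width=30):
    """Randomly vignette, rotate (k * 90 degrees) and flip one (H, W, C) image."""
    img = np.array(img, dtype=np.float32, copy=True)
    if rng.random() < p_vignette:
        # Zero the left-most `width` columns to mimic cropping at the volume edge.
        width = rng.integers(0, max_width)
        img[:, :width, ...] = 0
    img = np.rot90(img, k=int(rng.integers(0, 4)), axes=(0, 1))
    if rng.random() < 0.5:
        img = img[:, ::-1, ...]  # left-right flip
    if rng.random() < 0.5:
        img = img[::-1, ...]  # up-down flip
    return img


rng = np.random.default_rng(0)
patch = np.random.default_rng(1).random((64, 64, 1), dtype=np.float32)
augmented = augment_np(patch, rng)
```

Rotations and flips only rearrange pixels, so they suit cell patches whose class does not depend on orientation; the vignette is the only step that removes information.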

In [None]:
@augmentation_label_handler
def random_contrast(x):
    return tf.image.random_contrast(x, 0.3, 1.0)

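@augmentation_label_handler
def random_brightness(x):
    # random_brightness(image, max_delta, seed=None) adds one delta drawn from
    # [-max_delta, max_delta); a second positional argument would be the seed.
    return tf.image.random_brightness(x, max_delta=0.3)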
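
For reference, TensorFlow documents `tf.image.random_contrast(x, lower, upper)` as picking a factor from `[lower, upper)` and mapping each pixel to `(x - mean) * factor + mean`, with the mean taken per channel, and `tf.image.random_brightness(x, max_delta)` as adding a single delta drawn from `[-max_delta, max_delta)`. A NumPy sketch of both, assuming float images (TensorFlow first converts integer images to floats, which this sketch skips); the function names are hypothetical:

```python
import numpy as np


def random_contrast_np(img, lower, upper, rng):
    """Scale each channel's deviation from its mean by a random factor."""
    factor = rng.uniform(lower, upper)
    mean = img.mean(axis=(0, 1), keepdims=True)  # per-channel mean over H and W
    return (img - mean) * factor + mean, factor


def random_brightness_np(img, max_delta, rng):
    """Add one random offset to every pixel."""
    delta = rng.uniform(-max_delta, max_delta)
    return img + delta, delta


rng = np.random.default_rng(0)
patch = np.random.default_rng(1).random((64, 64, 1))
contrasted, factor = random_contrast_np(patch, 0.3, 1.0, rng)
brightened, delta = random_brightness_np(patch, 0.3, rng)
```

With `lower=0.3, upper=1.0` the factor is always below 1, so this augmentation can only reduce contrast, never increase it.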

### Build the training dataset, with random augmentations

In [None]:
dataset = build_dataset(TRAIN_FILE, read_label=True)

In [None]:
dataset = dataset.map(augment)
dataset = append_conditional_augmentation(dataset, [random_contrast, random_brightness])
dataset = dataset.map(normalize)
dataset = dataset.shuffle(buffer_size=BUFFER_SIZE, reshuffle_each_iteration=True)
dataset = dataset.repeat()
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
dataset = dataset.prefetch(1)
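
`dataset.shuffle(buffer_size=BUFFER_SIZE)` does not shuffle the whole dataset at once: per the `tf.data` documentation, it fills a buffer with `buffer_size` elements and samples randomly from it, replacing each sampled element with the next input. The pure-Python generator below is a sketch of those semantics (`buffered_shuffle` is a hypothetical name, not TensorFlow's implementation); it shows why a buffer at least as large as the training set gives a full shuffle, while a buffer of 1 leaves the order unchanged.

```python
import random


def buffered_shuffle(iterable, buffer_size, seed=None):
    """Yield items in a buffered random order, mimicking tf.data shuffle semantics."""
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Buffer is full: emit a random element; the next input refills the slot.
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            yield buffer.pop()
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


out = list(buffered_shuffle(range(100), buffer_size=10, seed=0))
```

Since the first output can only come from the first `buffer_size` inputs, a small buffer keeps the order close to the order in which patches were written to the TFRecord, which here is grouped by class.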

### Build the test dataset, without augmentations

In [None]:
test_dataset = build_dataset(TEST_FILE, read_label=True)
test_dataset = test_dataset.map(normalize)
test_dataset = test_dataset.take(-1).as_numpy_iterator()

test_images, test_labels = zip(*list(test_dataset))

### Set up TensorBoard callbacks to monitor training

In [None]:
tensorboard_callback = K.callbacks.TensorBoard(log_dir=LOG_DIR)
confusion_matrix_callback = tensorboard_confusion_matrix_callback(
    model, 
    np.asarray(test_images), 
    test_labels,
    LOG_DIR,
    class_names=LABELS,
    is_binary=False
)
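
The confusion matrix logged by `tensorboard_confusion_matrix_callback` compares the true labels of the test images with the classes the model predicts for them. How CellX builds it internally is not shown here; the NumPy sketch below uses the common convention of rows for true classes and columns for predicted classes (`confusion_matrix_np` is a hypothetical helper, and the labels and logits are made up).

```python
import numpy as np


def confusion_matrix_np(true_labels, logits, num_classes):
    """Count (true class, predicted class) pairs; rows = true, columns = predicted."""
    predicted = np.argmax(np.asarray(logits), axis=-1)
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(true_labels), predicted), 1)
    return cm


true_labels = [0, 1, 2, 2]
logits = np.array([[3.0, 0.0, 0.0],
                   [0.0, 1.0, 2.0],
                   [0.0, 0.0, 1.0],
                   [0.0, 0.5, 4.0]])
cm = confusion_matrix_np(true_labels, logits, num_classes=3)
# Per-class recall: correct predictions divided by the number of true examples of that class.
recall = np.diag(cm) / cm.sum(axis=1)
```

In this notebook `num_classes` would be `len(LABELS)`; off-diagonal cells show which cell cycle states the network confuses with each other.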

### Set up the loss function

In [None]:
loss = K.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer="adam", loss=loss, metrics=['accuracy'])
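
`SparseCategoricalCrossentropy(from_logits=True)` takes integer class labels (as written to the TFRecords) and raw logits, applies a softmax internally, and averages the negative log-probability of the true class over the batch. A NumPy sketch of that computation, using the log-sum-exp trick for stability; `sparse_ce_from_logits` is a hypothetical helper for illustration, not the Keras implementation:

```python
import numpy as np


def sparse_ce_from_logits(labels, logits):
    """Mean over the batch of -log softmax(logits)[label]."""
    logits = np.asarray(logits, dtype=np.float64)
    m = logits.max(axis=-1, keepdims=True)
    # log(sum(exp(logits))) computed stably by factoring out the row maximum.
    log_z = m[..., 0] + np.log(np.exp(logits - m).sum(axis=-1))
    picked = logits[np.arange(len(labels)), labels]
    return float(np.mean(log_z - picked))


uniform_loss = sparse_ce_from_logits([0, 1, 4], np.zeros((3, 5)))
```

With all-zero logits over 5 classes the loss equals `log(5)`, about 1.609, which is roughly what to expect at the very start of training before the network has learned anything.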

## Finally, train the model and evaluate performance using TensorBoard

In [None]:
%tensorboard --logdir $LOG_ROOT --host localhost

In [None]:
model.fit(
    dataset, 
    steps_per_epoch=BUFFER_SIZE//BATCH_SIZE, 
    epochs=TRAINING_EPOCHS, 
    callbacks=[tensorboard_callback, confusion_matrix_callback],
)
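
Because `dataset.repeat()` makes the training dataset infinite, Keras needs `steps_per_epoch` to know where an epoch ends, so one "epoch" here is a fixed number of batches rather than one pass over your training images. With the hyperparameters above, the arithmetic works out as follows (`passes_over_training_set` is a hypothetical helper, and the 4,992-image training set is an invented example):

```python
BATCH_SIZE = 64
BUFFER_SIZE = 20_000
TRAINING_EPOCHS = 100

steps_per_epoch = BUFFER_SIZE // BATCH_SIZE          # 312 batches per epoch
samples_per_epoch = steps_per_epoch * BATCH_SIZE     # 19,968 augmented images per epoch
total_samples = samples_per_epoch * TRAINING_EPOCHS  # 1,996,800 images over all epochs


def passes_over_training_set(n_train_images):
    """Average number of times each training image is seen during model.fit."""
    return total_samples / n_train_images


print(steps_per_epoch, passes_over_training_set(4_992))  # 312 400.0
```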

## Saving the Model

In [None]:
model_name = 'model'
model.save('{}.h5'.format(model_name))

## IMPORTANT:

Do not terminate this notebook before saving the model. To download the saved model (`model.h5`) to your local machine, open the file browser, press the '...' button next to the file and select 'Download'. If you do not save and download the model, you will lose all of the training progress you have achieved so far. In the next steps of this protocol you will be asked to import the downloaded model into the Colab environment again, so make sure you have a working model.
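
As an alternative to the file-browser menu, Colab's `google.colab.files.download` function can trigger the browser download from a code cell. The wrapper below is a hypothetical convenience helper; it only attempts the download when the `google.colab` module is importable and otherwise returns `False`, so it is safe to run outside Colab.

```python
import os


def download_model(path="model.h5"):
    """Trigger a browser download of the saved model when running in Colab.

    Returns True if a download was triggered, False when not running in Colab.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"{path} not found; run model.save() first")
    try:
        from google.colab import files  # only available inside Colab
    except ImportError:
        return False
    files.download(path)
    return True
```

Calling `download_model()` right after the save cell would download `model.h5`, the filename used above.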