# CNN Training for Cell Cycle State Classification

### Welcome!

This notebook allows you to train a convolutional neural network (CNN) using your annotated single-cell image patches and make predictions about the labels of previously unseen images. Follow the step-by-step instructions below to train the network and test its accuracy.


### Important Notes:

1. You are using the virtual environment of [Google Colab](https://colab.research.google.com/notebooks/intro.ipynb "Google Colaboratory"). To train the neural network on your annotated data, you must first **import your data** into the folders the notebook reads from. Please follow the instructions after executing the first cell of this notebook.

2. If using Google Colab: this session will 'time out' if you do not interact with it, after 90 minutes if you close the browser or 12 hours if you keep it open. If you close your browser while a code cell is running, the cell keeps executing on the virtual machine; if it has not finished by the time you reopen the browser, it will still be running. Please visit this [StackOverflow](https://stackoverflow.com/questions/54057011/google-colab-session-timeout "Google Colab Session Timeout") discussion for more details.


### Running Instructions:

1. Before running the whole notebook in one go, make sure to execute the first code cell. It installs the CellX library and creates local directories in the environment of the virtual machine.

2. When the first cell has finished, it will print ```Building wheel for cellx (setup.py) ... done```. Click on the ``` 📁``` folder icon on the left-side dashboard of the Colab notebook. You should now see 4 subfolders in this directory: "sample_data" (default), "logs", "train" and "test". The "logs", "train" and "test" folders should be empty.

3. At this point, **manually move your 'annotation_XXX.zip' files into the "train" and "test" folders**. The notebook will then process the image patches, sort them into their class labels and use them for model training and predictions.

4. You can now run the entire notebook by clicking on ```Runtime``` > ```Run all``` in the upper main dashboard. If the first cell is re-run, `mkdir` will report that the "logs", "train" and "test" folders already exist; this error is harmless and the existing folders are kept.

5. Before training the model, this notebook will distribute the image patch data into the training and testing sets and apply data augmentations. It will then train the neural network using the hyperparameters you have set up.

6. During training, you can monitor the network's progress with [TensorBoard](https://www.tensorflow.org/tensorboard/get_started "TensorFlow || Tensorboard"), a tool that provides the measurements and visualisations needed during the machine learning workflow. It can track experiment metrics such as loss and accuracy, visualise the model graph, project embeddings to a lower-dimensional space, and much more.
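TensorBoard reads its data from a log folder. Below is a minimal sketch of how a timestamped subfolder of "logs" could be created for each training run. It assumes one subfolder per run, and the helper name `make_run_log_dir` is hypothetical, not part of CellX. The returned path would be passed to `tf.keras.callbacks.TensorBoard(log_dir=...)`. In Colab, the dashboard is opened with the `%load_ext tensorboard` and `%tensorboard --logdir logs` magics.

```python
import os
from datetime import datetime


def make_run_log_dir(root="logs"):
    """Create and return a timestamped subfolder of `root` for one training run.

    Hypothetical helper: one folder per run keeps TensorBoard curves separate.
    """
    run_name = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = os.path.join(root, run_name)
    # exist_ok=True avoids an error if the folder (or `root` itself) already exists
    os.makedirs(log_dir, exist_ok=True)
    return log_dir

# Example usage (requires TensorFlow and a compiled model):
# log_dir = make_run_log_dir()
# tb_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
# model.fit(..., callbacks=[tb_callback])
```

Separate run folders let TensorBoard overlay the curves of different runs instead of mixing them into one.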

---

**Happy training!**

*Your [CellX](http://lowe.cs.ucl.ac.uk/cellx.html "Lowe Lab @ UCL") team*


### Install the CellX library & create subdirectories in the virtual machine:

In [1]:
# if using colab, install cellx library and make log and data folders

if 'google.colab' in str(get_ipython()):
    # GitHub no longer serves the unauthenticated git:// protocol, so install over https
    !pip install -q git+https://github.com/quantumjot/cellx.git
    !mkdir logs
    !mkdir train
    !mkdir test

### Import libraries and CellX toolkit:

In [2]:
import os
import zipfile
import numpy as np
import matplotlib.pyplot as plt

from datetime import datetime
from skimage.transform import resize
from scipy.special import softmax

In [3]:
import tensorflow.keras as K
import tensorflow as tf

In [8]:
from cellx.layers import Encoder2D
from cellx.tools.dataset import build_dataset
from cellx.tools.dataset import write_dataset
from cellx.augmentation.utils import append_conditional_augmentation, augmentation_label_handler
from cellx.callbacks import tensorboard_confusion_matrix_callback
from cellx.core import load_model
from cellx.tools.confusion import plot_confusion_matrix

In [11]:
from sklearn.metrics import confusion_matrix,precision_recall_fscore_support

### Define paths & class labels:

In [5]:
TRAIN_PATH = "./train"
TEST_PATH = "./test"
TRAIN_FILE = os.path.join(TRAIN_PATH, 'CNN_train.tfrecord')
TEST_FILE = os.path.join(TEST_PATH, 'CNN_test.tfrecord')
LABELS = ["Interphase", "Prometaphase", "Metaphase", "Anaphase", "Apoptosis"]

### Generate TensorFlow Record (TFRecord) files:

In [6]:
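def create_tf_record(
    root,
    filename,
    labels=LABELS
):
    """Read every 'annotation_*.zip' in `root` and write the patches to a TFRecord file.

    Each zip is expected to contain .tif patches whose filenames start with the
    class label (e.g. 'Metaphase...'); the label's index in `labels` becomes the
    numeric class stored alongside the image.
    """
    _images = []
    _labels = []

    # find the annotation zip files:
    zipfiles = [
        os.path.join(root, f) for f in os.listdir(root)
        if f.endswith(".zip") and f.startswith("annotation_")
    ]

    if len(zipfiles) == 0:
        raise FileNotFoundError(f"No 'annotation_*.zip' files found in {root}.")

    for zfn in zipfiles:
        print(f"Loading file: {zfn}")
        with zipfile.ZipFile(zfn, 'r') as zip_data:
            files = zip_data.namelist()

            for numeric_label, label in enumerate(labels):
                # select the .tif patches belonging to this class
                patch_files = [f for f in files if f.endswith(".tif") and f.startswith(label.capitalize())]
                images = [plt.imread(zip_data.open(f)) for f in patch_files]
                # resize every patch to the 64 x 64 network input size
                images_resized = [resize(img, (64, 64), preserve_range=True) for img in images]

                _images += images_resized
                _labels += [numeric_label] * len(images_resized)

    # stack into (N, 64, 64, 1) images and (N,) integer labels
    images_arr = np.stack(_images, axis=0)[..., np.newaxis]
    labels_arr = np.stack(_labels, axis=0)

    print(f"Total images: {images_arr.shape[0]}")
    # note: astype(np.uint8) assumes the patch intensities already lie in 0-255
    write_dataset(filename, images_arr.astype(np.uint8), labels=labels_arr.astype(np.int64))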
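`create_tf_record` casts the resized patches with `astype(np.uint8)`. This only preserves intensities if the original patches already lie in the 0-255 range. If your annotation patches are stored at a higher bit depth, such as 16-bit TIFFs, one option is to min-max rescale each patch to 0-255 before the cast. That bit depth is an assumption; check `images[0].dtype` for your data. The helper `rescale_to_uint8` below is a hypothetical sketch, not part of CellX. Because the later `normalize` step standardizes each image anyway, a per-image linear rescale changes little beyond the 8-bit quantization.

```python
import numpy as np


def rescale_to_uint8(img):
    """Linearly map an image's [min, max] range onto [0, 255] and cast to uint8.

    Hypothetical helper for patches stored at a higher bit depth (e.g. 16-bit TIFF).
    """
    img = np.asarray(img, dtype=np.float64)
    lo, hi = img.min(), img.max()
    if hi == lo:
        # constant image: avoid division by zero
        return np.zeros(img.shape, dtype=np.uint8)
    scaled = (img - lo) / (hi - lo) * 255.0
    # round before casting so values are not truncated toward zero
    return np.round(scaled).astype(np.uint8)
```

Inside `create_tf_record`, this would wrap the resize call, e.g. `rescale_to_uint8(resize(img, (64, 64), preserve_range=True))`.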

## IMPORTANT: 

**Before calling the function that creates the TFRecord files:**

You need to manually drag the annotation_XXX.zip files into the newly created folders. If you are working in the Google Colab environment, click on the folder icon at the left-side dashboard, which should now contain the 'logs', 'train' and 'test' directories. They should be empty until you drag your annotation files into them.
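Before building the TFRecords, it can help to confirm that each archive contains patches for every class. Below is a minimal sketch. Like `create_tf_record`, it assumes patch files are `.tif` files whose names start with the class label. `count_patches_per_label` is a hypothetical helper, not part of CellX.

```python
import os
import zipfile

LABELS = ["Interphase", "Prometaphase", "Metaphase", "Anaphase", "Apoptosis"]


def count_patches_per_label(root, labels=LABELS):
    """Count .tif patches per class label in every annotation_*.zip under `root`."""
    counts = {}
    for fname in sorted(os.listdir(root)):
        if not (fname.startswith("annotation_") and fname.endswith(".zip")):
            continue
        with zipfile.ZipFile(os.path.join(root, fname), "r") as zip_data:
            names = zip_data.namelist()
        # same filename rule as create_tf_record: prefix = label, suffix = .tif
        counts[fname] = {
            label: sum(1 for n in names if n.endswith(".tif") and n.startswith(label))
            for label in labels
        }
    return counts
```

For example, `for name, per_label in count_patches_per_label("./train").items(): print(name, per_label)` would list the class counts per archive. A zero count usually means a class is missing or its files are named differently.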

Once the files have been imported, run the following cell:

In [7]:
create_tf_record(TRAIN_PATH, TRAIN_FILE)
create_tf_record(TEST_PATH, TEST_FILE)

### Load the Model

The `load_model` function from the CellX library lets us import models without having to specify the CellX custom layers used to build them. The cell below expects a trained model saved as `model.h5` in the working directory.

In [9]:
model_name = 'model'
model = load_model('{}.h5'.format(model_name))

In [None]:
model.summary()

### Build the test dataset, without augmentations

In [None]:
@augmentation_label_handler
def normalize(img):
    img = tf.image.per_image_standardization(img)
    # clip to 4 standard deviations
    img = tf.clip_by_value(img, -4., 4.)
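    # check_numerics returns its input unchanged; using the returned tensor
    # ensures the NaN/Inf check is not dropped as an unused op inside dataset.map
    img = tf.debugging.check_numerics(img, "Image contains NaN or Inf")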
    return img
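The following NumPy sketch shows what `normalize` does numerically, step by step. It follows the documented behaviour of `tf.image.per_image_standardization`, which divides by `max(stddev, 1/sqrt(num_elements))` so that constant images do not cause a division by zero. The function name `normalize_np` is hypothetical and is not used elsewhere in the notebook.

```python
import numpy as np


def normalize_np(img, clip=4.0):
    """NumPy equivalent of the notebook's `normalize` (standardize, then clip)."""
    img = np.asarray(img, dtype=np.float32)
    mean = img.mean()
    # per_image_standardization guards against a tiny stddev with 1/sqrt(N)
    adjusted_std = max(img.std(), 1.0 / np.sqrt(img.size))
    standardized = (img - mean) / adjusted_std
    # clip to +/- `clip` standard deviations, as in the TensorFlow version
    out = np.clip(standardized, -clip, clip)
    if not np.all(np.isfinite(out)):
        raise ValueError("Image contains NaN or Inf")
    return out
```

After this step every image has roughly zero mean and unit variance. A single very bright pixel is therefore capped at 4 instead of dominating the input.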

In [None]:
test_dataset = build_dataset(TEST_FILE, read_label=True)
test_dataset = test_dataset.map(normalize)
test_dataset = test_dataset.take(-1).as_numpy_iterator()

test_images, test_labels = zip(*list(test_dataset))
test_images = np.array(test_images)
test_labels = np.array(test_labels)

## Run the Model on the Test Images

The shape of `test_images` should be (N, 64, 64, 1), where N is the number of individual images in the test set.
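A shape assertion is a quick way to confirm this layout before calling `model.predict`. The sketch below also shows how the `zip(*list(test_dataset))` idiom splits `(image, label)` pairs into two arrays. The helper name `check_image_batch` and the toy data are hypothetical.

```python
import numpy as np


def check_image_batch(images, size=64, channels=1):
    """Raise if `images` is not shaped (N, size, size, channels); return N."""
    images = np.asarray(images)
    expected = (size, size, channels)
    if images.ndim != 4 or images.shape[1:] != expected:
        raise ValueError(
            f"expected shape (N, {size}, {size}, {channels}), got {images.shape}"
        )
    return images.shape[0]


# toy (image, label) pairs standing in for the TFRecord dataset
pairs = [(np.zeros((64, 64, 1), dtype=np.float32), label) for label in [0, 3, 4]]
images, labels = zip(*pairs)  # tuple of images, tuple of labels
images, labels = np.array(images), np.array(labels)
```

In the notebook, `check_image_batch(test_images)` would return N, or raise if the channel axis is missing.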

In [10]:
test_predictions = model.predict(test_images) # shape = (N,5)

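The `softmax` function transforms `test_predictions` into an array of scores for each class for each image in the test set; across classes, the scores sum to one. The class with the highest score is the model's prediction. This step assumes the model outputs raw logits. If its last layer already applies a softmax, a second softmax makes the scores flatter but does not change which class scores highest.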

In [12]:
test_predictions = softmax(test_predictions,axis=1)
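Written out, softmax maps a row of scores `z` to `exp(z_k) / sum_j exp(z_j)`. The NumPy sketch below computes the same row-wise result as `scipy.special.softmax(..., axis=1)`. It uses the usual max-subtraction for numerical stability and shows that the predicted class (the `argmax`) is the same before and after the transform. The toy logits are made up for illustration.

```python
import numpy as np


def softmax_rows(scores):
    """Row-wise softmax: each row of `scores` becomes probabilities summing to 1."""
    scores = np.asarray(scores, dtype=np.float64)
    # subtracting the row max does not change the result but avoids overflow in exp
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# toy logits for 2 images and the 5 classes in LABELS
logits = np.array([[2.0, 0.5, 0.1, -1.0, 0.0],
                   [0.0, 0.0, 3.0, 1.0, -2.0]])
probs = softmax_rows(logits)
predicted = probs.argmax(axis=1)  # index into LABELS
```

Here `predicted` is `[0, 2]`, i.e. "Interphase" and "Metaphase".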

### Predictions on Testing Images

Show a selection of N images from the test set, each labelled with the model's predicted class.

In [None]:
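def Show_Testing_Predictions(
    image_num_list, # indices of the examples in the testing set to be shown
    test_images
):
    """Plot the selected test images in a 5-column grid, labelled with the predicted class.

    Reads the global `test_predictions` array (softmax scores, shape (N, 5)).
    """
    n_cols = 5
    n_rows = int(np.ceil(len(image_num_list) / n_cols))
    plt.figure(figsize=(10, 3 * n_rows))
    plt.suptitle('Predictions', fontsize=25, x=0.5, y=0.95)
    # enumerate gives each image's position in the grid; subplot indices start at 1
    for position, image_num in enumerate(image_num_list):
        plt.subplot(n_rows, n_cols, position + 1)
        plt.imshow(test_images[image_num, :, :, 0])
        plt.title('Image {}'.format(image_num))
        plt.yticks([])
        plt.xticks([])
        # predicted class = index of the highest softmax score
        plt.xlabel(LABELS[np.argmax(test_predictions[image_num])])
    plt.show()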

In [None]:
Show_Testing_Predictions(np.arange(21),test_images)
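The call above shows the first 21 test images. To look at a random selection instead, the indices can be drawn without replacement. The helper `sample_indices`, its seed and the sample size are arbitrary, hypothetical choices.

```python
import numpy as np


def sample_indices(n_total, n_show=20, seed=0):
    """Draw `n_show` distinct indices from range(n_total), sorted for readability."""
    rng = np.random.default_rng(seed)  # fixed seed -> same images on every run
    n_show = min(n_show, n_total)
    return np.sort(rng.choice(n_total, size=n_show, replace=False))
```

For example: `Show_Testing_Predictions(sample_indices(len(test_images)), test_images)`. Fixing the seed makes the figure reproducible; change it to see a different subset.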

### Evaluation Metrics

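A confusion matrix summarises classification performance per class. Entry (i, j) counts the test images whose true class is i and whose predicted class is j; this is the row/column convention of scikit-learn's `confusion_matrix`. Correct predictions lie on the diagonal, while off-diagonal entries show which cell cycle states the model confuses with each other. Precision for a class is the fraction of images predicted as that class that truly belong to it. Recall is the fraction of images of that class that the model labels correctly.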

In [None]:
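# evaluate() returns [loss, accuracy] only if the model was compiled with an accuracy metric
loss, accuracy = model.evaluate(test_images, test_labels)

# scikit-learn metrics need class indices (N,), not the softmax scores (N, 5)
predicted_labels = np.argmax(test_predictions, axis=1)
class_ids = np.arange(len(LABELS))  # keep every class, in LABELS order

test_confusion_matrix = confusion_matrix(test_labels, predicted_labels, labels=class_ids)
test_confusion_matrix_plot = plot_confusion_matrix(test_confusion_matrix, LABELS)
test_confusion_matrix_plot.show()

print('Testing Accuracy = ', accuracy)
print('Testing Loss = ', loss)

# per-class metrics, one value per entry of LABELS
precision, recall, fscore, support = precision_recall_fscore_support(
    test_labels, predicted_labels, labels=class_ids
)
print('Testing Precision = ', precision)
print('Testing Recall = ', recall)
print('Testing F1-score = ', fscore)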
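To make these metrics concrete, here is a NumPy sketch that builds the confusion matrix by hand and derives per-class precision, recall and overall accuracy from it. Rows are true classes and columns are predicted classes, matching scikit-learn's convention. It expects class indices such as `np.argmax(test_predictions, axis=1)`, not the softmax scores themselves. The function name `confusion_and_scores` and the toy labels are hypothetical.

```python
import numpy as np


def confusion_and_scores(y_true, y_pred, n_classes):
    """Return (confusion matrix, precision, recall, accuracy) from class indices."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    # count each (true, predicted) pair; add.at handles repeated index pairs
    np.add.at(cm, (y_true, y_pred), 1)
    tp = np.diag(cm).astype(np.float64)
    predicted_per_class = cm.sum(axis=0)  # column sums
    true_per_class = cm.sum(axis=1)       # row sums = support
    # precision = TP / predicted count; recall = TP / true count (0 if undefined)
    precision = np.divide(tp, predicted_per_class,
                          out=np.zeros_like(tp), where=predicted_per_class > 0)
    recall = np.divide(tp, true_per_class,
                       out=np.zeros_like(tp), where=true_per_class > 0)
    accuracy = tp.sum() / cm.sum()
    return cm, precision, recall, accuracy
```

In the notebook this would be called as `confusion_and_scores(test_labels, np.argmax(test_predictions, axis=1), len(LABELS))`. The results should agree with the scikit-learn values printed above.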