<a href="https://colab.research.google.com/github/JanRibasGarriga/WhoIsWho/blob/main/ColabNotebooks/WhoIsWho_Coculture_TrainingAndValidation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

---
# **WHO IS WHO - [CO-CULTURE / TRAINING & VALIDATION]**

<font size = 4> WhoIsWho is a Google Colab-based tool that aims to classify cells based on features related to their nuclei and their neighbours.

<font size = 4> This notebook is meant to be used for the training and validation of the pipeline when using co-culture training datasets.

---



<font size = 4>Some blocks of the notebook are based on the paper...

*   <font size = 4> Stringer, Carsen, Tim Wang, Michalis Michaelos, and Marius Pachitariu. 2021. **“Cellpose: A Generalist Algorithm for Cellular Segmentation.”** Nature Methods 18 (1): 100–106. https://doi.org/10.1038/s41592-020-01018-x

  *   <font size = 4> The Original code is freely available on GitHub: https://github.com/MouseLand/cellpose

<font size = 4>and the website...

*   <font size = 4>**Transfer learning and fine-tuning**, 2017 François Chollet. https://www.tensorflow.org/tutorials/images/transfer_learning    

<font size = 4>Please also cite the original sources when developing this notebook.

In [None]:
#@markdown ##Double click to see the license information

#--- License for "Transfer learning and fine-tuning, 2017 François Chollet" ---
# MIT License
#
# Copyright (c) 2017 François Chollet                                                                                                                    # IGNORE_COPYRIGHT: cleared by OSS licensing
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# **1. Initialise the Colab session**

---



In [None]:
#@markdown ##Run this cell to check if you have GPU access.

import tensorflow as tf
if tf.test.gpu_device_name()=='':
  gpu_available = False
  print("You are not connected to a GPU.") 
  print("Please, check the selected runtime type.") 
  print("Expect slow performance, especially if classification training needs to be performed.")

else:
  gpu_available = True
  print("GPU device available")

if gpu_available:
  gpu_info = !nvidia-smi
  gpu_info = '\n'.join(gpu_info)
  if gpu_info.find('failed') >= 0:
    print("To enable a GPU accelerator, check the selected runtime type.")
  else:
    print(gpu_info)

  from psutil import virtual_memory
  ram_gb = virtual_memory().total / 1e9
  print("Your runtime has {:.1f} gigabytes of available RAM\n".format(ram_gb))
  if ram_gb < 20:
    print("To enable a high-RAM runtime, check the selected runtime type.")
  else:
    print('You are using a high-RAM runtime!')

In [None]:
#@markdown ##Run this cell to mount your Google Drive in the virtual machine of this executable Colab document.

#@markdown * Click on the URL. 

#@markdown * Sign in your Google Account. 

#@markdown * Copy the authorization code. 

#@markdown * Paste the authorization code. 

# The Drive is mounted under /content/gdrive; force_remount=True is passed so
# the mount is requested again even when this cell is re-run.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

# **2. Install dependencies**

---

In [None]:
#@markdown ##Install general dependencies.

import os
import shutil
from PIL import Image
import skimage
from skimage.util.shape import view_as_windows
from skimage import io
from skimage import util
from skimage import filters
from skimage import img_as_ubyte, img_as_uint
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import random
import matplotlib.pyplot as plt
import seaborn as sns
from google.colab import files

#Colors for the warning messages
class bcolors:
  # ANSI escape sequence that switches the console text colour to red
  WARNING = '\033[31m'

# ANSI escape sequences wrapped around console messages, e.g. print(R + "..." + W)
W  = '\033[0m'  # reset (back to the normal text colour)
R  = '\033[31m' # red

#Disable some of the warnings
# NOTE(review): called without a category, this filter silences every Python
# warning raised after this point, including deprecation and pandas warnings.
import warnings
warnings.filterwarnings("ignore")



In [None]:
#@markdown ##Install Segmentation Block's dependencies.

!pip install tifffile
from tifffile import imread, imsave
from zipfile import ZIP_DEFLATED

!pip install cellpose 
from cellpose import models
from cellpose import plot 


In [None]:
#@markdown ##Install Localization Block's dependencies.

from skimage import morphology
from skimage import measure
import pandas as pd 

from PIL import Image
import skimage
from skimage import img_as_bool, img_as_ubyte
import numpy as np
from skimage import segmentation

import ipywidgets as widgets
import matplotlib.patches as patches

In [None]:
#@markdown ##Install Classification Block's dependencies.

import tensorflow
from tensorflow import keras

from keras.preprocessing import image_dataset_from_directory
from tensorflow.data.experimental import cardinality
from keras.layers.experimental import preprocessing
from keras.applications import xception
from keras.layers import GlobalAveragePooling2D, Dense, Dropout
from keras.optimizers import Nadam
from keras.losses import SparseCategoricalCrossentropy
from keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger
from keras.models import load_model
from sklearn.metrics import confusion_matrix, classification_report

# **4. Data preparation**

---






## **4.1. Define path of the raw data**
---

<font size = 4> Before you run the following cell, please ensure that the training dataset has been properly stored on the respective Google Drive, using the following structure of paired images (the same file names in every channel folder):

*   Data/ 
  *   TL/
      * image_1.tif 
      * image_2.tif 
      * image_3.tif
      * ...
  
  *   DAPI/
      * image_1.tif 
      * image_2.tif 
      * image_3.tif
      * ...

  *   EGFP/
      * image_1.tif 
      * image_2.tif 
      * image_3.tif
      * ...

  *   mCherry/
      * image_1.tif 
      * image_2.tif 
      * image_3.tif
      * ...

In [None]:
#@markdown ##Please specify the path to where the data is located in your Google Drive.

#@markdown <font size = 4> **`data_folder_path`**: is the path of the "Data/" folder illustrated above.

data_folder_path = '' #@param {type:"string"}

# Validate BEFORE appending the separator: an empty field would otherwise turn
# into "/", the filesystem root, which always exists and hides the mistake.
data_folder_path = data_folder_path.strip()
data_folder_exists = bool(data_folder_path) and os.path.isdir(data_folder_path)
data_folder_path = data_folder_path + "/"

if data_folder_exists:
  tl_folder_path = os.path.join(data_folder_path, "TL/")
  dapi_folder_path = os.path.join(data_folder_path, "DAPI/")
  EGFP_folder_path = os.path.join(data_folder_path, "EGFP/")
  mCherry_folder_path = os.path.join(data_folder_path, "mCherry/")

  # All four channel folders are read by the following cells: report any missing one now
  for channel_folder_path in (tl_folder_path, dapi_folder_path, EGFP_folder_path, mCherry_folder_path):
    if not os.path.isdir(channel_folder_path):
      print(R+'!! WARNING: The folder '+channel_folder_path+' does not exist !!'+W)

else:
  print(R+'!! WARNING: The path selected does not exist !!'+W)

In [None]:
#@markdown ##Run this cell to create the directory structure in /content. 

whoiswho_folder_path = "/content/WhoIsWho"

# Always start from a clean tree: whatever a previous run left behind is discarded.
if os.path.isdir(whoiswho_folder_path):
  shutil.rmtree(whoiswho_folder_path)

if not os.path.isdir(whoiswho_folder_path):

  # PATCHES: one sub-folder per channel, plus the Cellpose labels and the per-patch tables
  patches_folder_path = os.path.join(whoiswho_folder_path, "Patches")
  tlpatches_folder_path = os.path.join(patches_folder_path, "TL/")
  dapipatches_folder_path = os.path.join(patches_folder_path, "DAPI/")
  cellposelabelpatches_folder_path = os.path.join(patches_folder_path, "CellposeLabel/")
  EGFPpatches_folder_path = os.path.join(patches_folder_path, "EGFP/")
  mCherrypatches_folder_path = os.path.join(patches_folder_path, "mCherry/")
  regionpropsdf_folder_path = os.path.join(patches_folder_path, "Regionpropsdf/")

  # SNIPPETS
  snippets_folder_path = os.path.join(whoiswho_folder_path, "Snippets")
  tlsnippets_folder_path = os.path.join(snippets_folder_path, "TL/")

  # FIGURES
  figures_folder_path = os.path.join(whoiswho_folder_path, "Figures")
  generalfigures_folder_path = os.path.join(figures_folder_path, "GeneralFigures")
  cellposefigures_folder_path = os.path.join(figures_folder_path, "CellposeFigures")

  # MODEL
  model_folder_path = os.path.join(whoiswho_folder_path, "Model")

  # Parents are listed before their children so that every os.mkdir finds its parent in place
  for new_folder_path in (
      whoiswho_folder_path,
      patches_folder_path,
      tlpatches_folder_path,
      dapipatches_folder_path,
      cellposelabelpatches_folder_path,
      EGFPpatches_folder_path,
      mCherrypatches_folder_path,
      regionpropsdf_folder_path,
      snippets_folder_path,
      tlsnippets_folder_path,
      figures_folder_path,
      generalfigures_folder_path,
      cellposefigures_folder_path,
      model_folder_path,
  ):
    os.mkdir(new_folder_path)

print("/content/WhoIsWho directory just created")
print("Files will be stored under this directory")
print("Remember to run the last section of the notebook to select and download the results to your computer")


In [None]:
#@markdown ##Run to display and save an example of an image.

random_choice = random.choice(os.listdir(tl_folder_path))

# The same file name is looked up in each of the four channel folders.
tl_sample, dapi_sample, EGFP_sample, mCherry_sample = (
    io.imread(os.path.join(channel_folder, random_choice), as_gray = True)
    for channel_folder in (tl_folder_path, dapi_folder_path, EGFP_folder_path, mCherry_folder_path)
)

print(random_choice)

f = plt.figure(figsize = (20,10))

# One row of four panels, one per channel
panels = [
    (tl_sample, 'TL image'),
    (dapi_sample, 'DAPI image'),
    (EGFP_sample, 'EGFP image'),
    (mCherry_sample, 'mCherry image'),
]
for panel_index, (panel_image, panel_title) in enumerate(panels, start = 1):
  plt.subplot(1,4,panel_index)
  plt.imshow(panel_image, interpolation = 'nearest', cmap = 'gray')
  plt.title(panel_title)
  plt.axis('off');

# The figure is stored under the name of the image it shows
base_name_image = random_choice.replace(".tif", "")
figure_name = (f"Image_{base_name_image}")

figure_path = os.path.join(generalfigures_folder_path, figure_name + ".png")
plt.savefig(figure_path, dpi = 300, format = "png", bbox_inches = None)

## **4.2. Creation of patches**

---



In [None]:
def fittingPowerOfTwo(number):
  '''
  This function returns the largest power of two that does not exceed the given number

  INPUT:
  number: size to fit (e.g. the side of an image in pixels); must be at least 1

  OUTPUT:
  largest 2**k (with k >= 0) such that 2**k <= number, as an int

  RAISES:
  ValueError: if number is smaller than 1, since no power of two 2**k with k >= 0 fits
  '''
  if number < 1:
    raise ValueError(f"No power of two fits in {number}: the value must be at least 1")

  # The position of the highest set bit of the integer part is the exponent looked for
  return 1 << (int(number).bit_length() - 1)

def estimatePatchSize(data_path, max_width = 512, max_height = 512):
  '''
  This function estimates the patch dimensions from the smallest image found in a folder
  Every file of the folder is screened (sub-folders are ignored), and each dimension is
  reduced to the largest power of two that fits and then clipped at its maximum value

  INPUT:
  data_path: folder path of the images to screen
  max_width: upper limit of the first returned dimension
  max_height: upper limit of the second returned dimension

  OUTPUT:
  (width_min, height_min): tuple with the two patch dimensions

  RAISES:
  ValueError: if the folder does not contain any file

  NOTE(review): the first returned value is computed from the second entry of PIL's
  Image.size and the second one from the first entry, exactly as before. create_patches
  passes them in this order as the window shape of view_as_windows, so both places have
  to be changed together if the naming is ever revisited.
  '''
  image_paths = [os.path.join(data_path, file_)
                 for file_ in os.listdir(data_path)
                 if not os.path.isdir(os.path.join(data_path, file_))]

  if not image_paths:
    raise ValueError(f"No image files found in {data_path}")

  # Screen the size of the whole dataset (not only of its first file)
  sizes = []
  for image_path in image_paths:
    # The context manager closes the file as soon as its size has been read
    with Image.open(image_path) as image:
      sizes.append(image.size)

  width_min = min(size[1] for size in sizes)
  height_min = min(size[0] for size in sizes)

  # Largest power of two that fits within the smallest image, clipped at the maximum permissible values
  width_min = min(fittingPowerOfTwo(width_min), max_width)
  height_min = min(fittingPowerOfTwo(height_min), max_height)

  return (width_min, height_min)
  
#@markdown ##Run to define the dimensions of the patches. 

#@markdown <font size = 4> The notebook crops the data into patches of fixed size. Each patch side is the largest power of two (2^n) that fits in the data being used, and patches from the same image do not overlap. Each side is capped at 512 pixels in order to preserve the stability of the entire pipeline.

# The patch size is estimated from the transmitted-light (TL) folder only.
patch_width, patch_height = estimatePatchSize(
    tl_folder_path
)

print(f"Patch dimensions (patch_width, patch_height): ({patch_width}, {patch_height}).")

In [None]:
#@markdown ## Run to crop and filter patches of each image.

#@markdown ### Please select whether you want to use the default patches parameters.

Use_default_patches_parameters = True #@param {type:"boolean"}

#@markdown ### Otherwise, please define the parameter:

#@markdown <font size = 4>**`ratio_cells_vs_background`**: minimum percentage (%) of foreground pixels, over all the pixels of the patch, that a patch must reach to be considered for the following steps of the pipeline. When `Use_default_patches_parameters` is enabled, the parameter is set to 5%.

ratio_cells_vs_background = 5 #@param {type:"slider", min:1, max:100, step:1}

# When the defaults are enabled, the value selected with the slider is overridden.
if Use_default_patches_parameters:
  ratio_cells_vs_background = 5

def convert2mask(image, threshold = None):
  '''
  This function binarises an image into a foreground/background mask

  INPUT:
  image: image to binarise (it is left unmodified)
  threshold: pixels strictly above this value are foreground; when None, Otsu's threshold of the image is used

  OUTPUT:
  uint8 array with the shape of the image: 255 on the foreground and 0 on the background
  '''
  image = np.asarray(image)

  if threshold is None:
    threshold = filters.threshold_otsu(image)

  # The foreground is decided once, on the untouched values. Writing 255 in place and then
  # zeroing the pixels "<= threshold" removes the foreground again whenever threshold >= 255
  # (e.g. on 16-bit data), because the freshly written 255s fall under the threshold.
  return np.where(image > threshold, 255, 0).astype(np.uint8)

def normalizeMinMax(x, dtype = np.float32):
  '''
  This function rescales an array linearly so that its minimum maps to 0 and its maximum to 1

  INPUT:
  x: numpy array to normalize
  dtype: floating point type in which the computation is carried out

  OUTPUT:
  array of the given dtype with values within [0, 1]; an array of zeros when all the
  values of x are equal (the range is 0, so the plain formula would yield NaN)
  '''
  x = x.astype(dtype,copy=False)
  lowest = np.amin(x)
  value_range = np.amax(x) - lowest

  # Constant input: there is no contrast to stretch and 0/0 must be avoided
  if value_range == 0:
    return np.zeros_like(x)

  return (x - lowest) / value_range

def create_patches(tl_folder_path, dapi_folder_path, EGFP_folder_path, mCherry_folder_path, patch_width, patch_height):
  '''
  This function tiles every image of the dataset into non-overlapping patches and saves the patches that contain enough foreground
  The foreground is obtained by thresholding the DAPI image with convert2mask; a patch is kept when the percentage of
  mask pixels equal to 255 reaches the global ratio_cells_vs_background
  The kept patches are written as "<image name>--PATCH<number>.tif" into the global folders tlpatches_folder_path,
  dapipatches_folder_path, EGFPpatches_folder_path and mCherrypatches_folder_path
  The patch number runs over the whole dataset and also advances for the patches that are discarded

  INPUT:
  tl_folder_path: folder path of the transmitted-light images (its file names drive the loop)
  dapi_folder_path: folder path of the DAPI images
  EGFP_folder_path: folder path of the EGFP images
  mCherry_folder_path: folder path of the mCherry images
  patch_width: first dimension of the window used to tile the images
  patch_height: second dimension of the window used to tile the images
  '''
  window = (patch_width, patch_height)

  def tile(image):
    # Stepping by the window size itself yields tiles that do not overlap;
    # the grid of tiles is then flattened into a single list of tiles
    grid = view_as_windows(image, window, window)
    return grid.reshape(grid.shape[0]*grid.shape[1], patch_width, patch_height)

  patch_num = 1

  for filename in tqdm(sorted(os.listdir(tl_folder_path))):

    stem = os.path.splitext(filename)[0]

    tl = io.imread(os.path.join(tl_folder_path, filename))
    dapi = io.imread(os.path.join(dapi_folder_path, filename))
    # The mask is built from its own read of the DAPI file
    mask = img_as_ubyte(convert2mask(io.imread(os.path.join(dapi_folder_path, filename))))
    EGFP = io.imread(os.path.join(EGFP_folder_path, filename))
    mCherry = io.imread(os.path.join(mCherry_folder_path, filename))

    tl_tiles, dapi_tiles, mask_tiles, EGFP_tiles, mCherry_tiles = (
        tile(channel) for channel in (tl, dapi, mask, EGFP, mCherry)
    )

    for i in range(tl_tiles.shape[0]):
      patch_filename = stem+'--PATCH'+str(patch_num)+'.tif'
      patch_num += 1

      # Percentage of the patch covered by foreground pixels
      foreground_pixels = np.count_nonzero(mask_tiles[i] == 255)
      ratio = (foreground_pixels / (mask_tiles[i].shape[0] * mask_tiles[i].shape[1])) * 100

      if ratio < ratio_cells_vs_background:
        continue

      io.imsave(os.path.join(tlpatches_folder_path, patch_filename), img_as_ubyte(normalizeMinMax(tl_tiles[i])))
      io.imsave(os.path.join(dapipatches_folder_path, patch_filename), img_as_ubyte(normalizeMinMax(dapi_tiles[i])))
      io.imsave(os.path.join(EGFPpatches_folder_path, patch_filename), img_as_ubyte(EGFP_tiles[i]))
      io.imsave(os.path.join(mCherrypatches_folder_path, patch_filename), img_as_ubyte(mCherry_tiles[i]))

# Tile the whole dataset with the patch dimensions estimated above; the patches
# that pass the foreground filter are written to the patch folders.
create_patches(tl_folder_path, 
               dapi_folder_path,
               EGFP_folder_path, 
               mCherry_folder_path, 
               patch_width, patch_height,
)

print("Patches just created")

In [None]:
#@markdown ##Run to display and save an example of a patch.

random_choice = random.choice(os.listdir(tlpatches_folder_path))

# The same patch name is looked up in each of the four channel patch folders.
tl_sample, dapi_sample, EGFP_sample, mCherry_sample = (
    io.imread(os.path.join(channel_folder, random_choice), as_gray = True)
    for channel_folder in (tlpatches_folder_path, dapipatches_folder_path, EGFPpatches_folder_path, mCherrypatches_folder_path)
)

print(random_choice)

f = plt.figure(figsize = (20,10))

# One row of four panels, one per channel
panels = [
    (tl_sample, 'TL image'),
    (dapi_sample, 'DAPI image'),
    (EGFP_sample, 'EGFP image'),
    (mCherry_sample, 'mCherry image'),
]
for panel_index, (panel_image, panel_title) in enumerate(panels, start = 1):
  plt.subplot(1,4,panel_index)
  plt.imshow(panel_image, interpolation = 'nearest', cmap = 'gray')
  plt.title(panel_title)
  plt.axis('off');

# The figure is stored under the name of the patch it shows
base_name_patch = random_choice.replace(".tif", "")
figure_name = (f"Patch_{base_name_patch}")

figure_path = os.path.join(generalfigures_folder_path, figure_name + ".png")
plt.savefig(figure_path, dpi = 300, format = "png", bbox_inches = None)


# **5. Segmentation Block**
---









In [None]:
#@markdown ## Run to perform image segmentation using Cellpose

#@markdown ### Please select whether you want to use the default segmentation parameters.

Use_default_segmentation_parameters = True #@param {type:"boolean"}

#@markdown ### Otherwise, please define the parameters:

#@markdown <font size = 4>**`Object_diameter`**: approximate value of the diameter of the nuclei. If `Use_default_segmentation_parameters` is enabled, the parameter is set to 25.

#@markdown <font size = 4>**`Flow_threshold`**: flow error threshold that dictates which ROIs are kept or removed. If `Use_default_segmentation_parameters` is enabled, the parameter is set to 0.4.

# NOTE(review): the stored value 24.5 does not lie on the slider grid (step 1);
# it is only used when the default parameters are disabled.
Object_diameter = 24.5 #@param {type:"slider", min:7, max:30, step:1}

Flow_threshold = 0.4 #@param {type:"slider", min:0.1, max:1.1, step:0.1}

#@markdown ### Please, specify if you want to save the Cellpose's segmentation plot.

#@markdown <font size = 4>**`save_cellpose`**: if enabled, Cellpose's segmentation plot will be displayed and stored.

save_cellpose = False #@param {type:"boolean"}

# Passed as `channels` to model.eval below
# (presumably [0,0] means a single grayscale channel - confirm against the Cellpose documentation)
Channel_to_segment = [0,0]
if Use_default_segmentation_parameters:
  print("Default segmentation parameters enabled")
  Object_diameter = 25
  Flow_threshold = 0.4

# NOTE(review): gpu is requested unconditionally, independently of the result of the GPU check
# cell. The keyword arguments below depend on the installed Cellpose version, which is not
# pinned by the install cell - confirm they are accepted by the version in use.
model = models.Cellpose(gpu = True, 
                        model_type = "nuclei",
                        net_avg = True,
                        device = None,
                        torch = False
)

# Every DAPI patch is segmented and its label image saved under the same file name
for filename in sorted(os.listdir(dapipatches_folder_path)):

  print(f"Cellpose is computing labels from {filename}")

  image = io.imread(os.path.join(dapipatches_folder_path, filename))
  masks, flows, styles, diams = model.eval(image, 
                                          batch_size = 8,
                                          channels = Channel_to_segment,
                                          invert = False,
                                          normalize = False,
                                          diameter = Object_diameter,
                                          net_avg = True,
                                          augment = False,
                                          tile = True,
                                          resample = False,
                                          interp = True,
                                          flow_threshold = Flow_threshold,
                                          cellprob_threshold = 0
  )
  
  if save_cellpose:
    # One figure per patch, stored as <patch name>.png
    # NOTE(review): the figures are never closed, so they accumulate while the loop runs
    basename_filename = os.path.splitext(filename)[0]
    flowi = flows[0]
    fig = plt.figure(figsize=(12,5))
    plot.show_segmentation(fig, image, masks, flowi, channels=[0,0])
    plt.tight_layout()
    plt.savefig(os.path.join(cellposefigures_folder_path, str(basename_filename) + ".png"), 
                dpi = 100, format = "png", bbox_inches = None)

  # NOTE(review): the working directory of the process is changed on every iteration and
  # the label image is then written relative to it; the change persists after this cell.
  os.chdir(cellposelabelpatches_folder_path)
  imsave(str(filename), masks, compress = ZIP_DEFLATED)


# **6. Localization Block**

---



In [None]:
def measure_regionprops(labelpatch_folder_path, filename, min_area = 250):
  '''
  This function computes image properties from the regions of a labelled image
  The information extracted will be used to crop each cell as snippet and to measure the mean intensity on the fluorescence channels

  Column naming: "cx" holds the centroid coordinate along the rows (axis 0) of the label image and
  "cy" the coordinate along the columns (axis 1). "xmin"/"xmax" are therefore row limits and
  "ymin"/"ymax" column limits of the bounding box, which extends one diameter from the centroid on each side

  INPUT:
  labelpatch_folder_path: folder path of the label patches
  filename: unique name of the patch
  min_area: regions whose area is not above this value are discarded

  OUTPUT:
  regionprops_df: dataframe with one row per region kept and the columns cell_label, cx, cy, area,
                  diameter, xmin, xmax, ymin, ymax and snippet_id
  '''
  # Read label patches
  label = io.imread(os.path.join(labelpatch_folder_path, filename))

  # Compute image properties and convert the resulting table to a dataframe
  regionprops_df = pd.DataFrame(
      measure.regionprops_table(label,
                                properties = ["label",
                                              "centroid",
                                              "area",
                                              "equivalent_diameter"])
  )

  # Rename a few columns to be aligned to the structure (axis) of the data
  regionprops_df = regionprops_df.rename(columns={"label": "cell_label",
                                                  "centroid-0": "cx",
                                                  "centroid-1": "cy",
                                                  "equivalent_diameter": "diameter"})

  # Remove rows in which nuclei's areas are under a specific value
  # (.copy() so that the columns added below are written to an independent dataframe)
  regionprops_df = regionprops_df[regionprops_df["area"] > min_area].copy()

  # Calculate bounding box (bbox) coordinates and store them directly into the dataframe
  regionprops_df["xmin"] = regionprops_df["cx"] - regionprops_df["diameter"]
  regionprops_df["xmax"] = regionprops_df["cx"] + regionprops_df["diameter"]
  regionprops_df["ymin"] = regionprops_df["cy"] - regionprops_df["diameter"]
  regionprops_df["ymax"] = regionprops_df["cy"] + regionprops_df["diameter"]

  # Remove rows in which bboxes' dimensions are out of patch's original dimensions
  # x limits are row coordinates and are checked against the number of rows (shape[0]);
  # y limits are column coordinates and are checked against the number of columns (shape[1])
  n_rows, n_cols = label.shape[0], label.shape[1]
  inside_patch = ((regionprops_df["xmin"] >= 0) &
                  (regionprops_df["xmax"] <= n_rows) &
                  (regionprops_df["ymin"] >= 0) &
                  (regionprops_df["ymax"] <= n_cols))

  # Reset indexes of the dataframe and drop the previous index column
  regionprops_df = regionprops_df[inside_patch].reset_index(drop = True)

  # Add the name of the snippet in the dataframe as id value
  patch_name = os.path.splitext(filename)[0]
  regionprops_df["snippet_id"] = patch_name + "--SNIPPET" + regionprops_df["cell_label"].astype(str) + ".png"

  return regionprops_df

def compute_meanintensity(filename, labelpatch_folder_path, EGFPpatches_folder_path, mCherrypatches_folder_path, regionprops_df):
  '''
  This function measures both fluorescence channels on the nucleus of each bounding box
  For every row of regionprops_df, the pixels that do not belong to the nucleus are set to 0 and the
  mean over the whole bounding box is stored in the columns "meanintensityEGFP" and "meanintensitymCherry"
  The dataframe is modified in place and nothing is returned

  INPUT:
  filename: unique name of the patch
  labelpatch_folder_path: folder path of the label patches
  EGFPpatches_folder_path: folder path of the EGFP patches
  mCherrypatches_folder_path: folder path of the mCherry patches
  regionprops_df: dataframe of the mask properties (cell_label, xmin, xmax, ymin and ymax are read)
  '''
  # Read patches as PIL images; the context managers close the files once all the snippets are measured
  with Image.open(os.path.join(EGFPpatches_folder_path, filename)) as EGFP, \
       Image.open(os.path.join(mCherrypatches_folder_path, filename)) as mCherry:

    # Read label patches. Labels are identifiers, not intensities: they are kept exactly as stored,
    # because an intensity conversion such as img_as_ubyte rescales wider integer types
    label = io.imread(os.path.join(labelpatch_folder_path, filename))

    for idx in regionprops_df.index:
      # Box in the order expected by PIL's crop
      box = (regionprops_df["ymin"][idx],
             regionprops_df["xmin"][idx],
             regionprops_df["ymax"][idx],
             regionprops_df["xmax"][idx])

      # --- MASK ---
      # Keep only the cell of interest, so that neighbour cells on the snippet are removed
      mask = Image.fromarray(label == regionprops_df["cell_label"][idx])
      mask_snippet = img_as_bool(mask.crop(box))

      # --- EGFP ---
      EGFP_snippet = img_as_ubyte(np.array(EGFP.crop(box)))

      # --- mCherry ---
      mCherry_snippet = img_as_ubyte(np.array(mCherry.crop(box)))

      # --- MULTIPLICATION AND STORAGE ---
      # Multiply mask and both fluorescence snippets and store the mean of each result into the dataframe
      regionprops_df.loc[[idx], ["meanintensityEGFP"]] = (mask_snippet * EGFP_snippet).mean()
      regionprops_df.loc[[idx], ["meanintensitymCherry"]] = (mask_snippet * mCherry_snippet).mean()

def save_regionpropsdataframe(save_regionprops_folder_path, regionprops_df, filename):
  '''
  This function saves the dataframe of the patch being analyzed as .csv

  INPUT:
  save_regionprops_folder_path: folder path in which the dataframe will be saved
  regionprops_df: dataframe of the mask properties
  filename: unique name of the patch; its extension is replaced by ".csv"
  '''
  # Get the basename of the file
  # (str.rstrip(".tif") would strip any trailing '.', 't', 'i' or 'f' character, not the extension)
  basename = os.path.splitext(filename)[0]

  # Get the complete path (folder path and filename) where the dataframe will be saved
  complete_save_path = os.path.join(save_regionprops_folder_path, basename + ".csv")

  # Save the dataframe as .csv, without the index column
  regionprops_df.to_csv(complete_save_path, index = False)

def crop_tlsnippet(tlpatches_folder_path, regionprops_df, save_tlsnippet_folder_path, filename):
  '''
  This function crops one snippet per row of the dataframe out of the transmitted-light patch and saves it
  Each snippet is cut along the bounding box of its row and stored under the name held in "snippet_id"

  INPUT:
  tlpatches_folder_path: folder path of the transmitted-light patches
  regionprops_df: dataframe of the mask properties (snippet_id, xmin, xmax, ymin and ymax are read)
  save_tlsnippet_folder_path: folder path in which transmitted-light snippets will be saved
  filename: unique name of the patch
  '''
  # Read patch as PIL image
  tl_image = Image.open(os.path.join(tlpatches_folder_path, filename))

  # Columns of the bounding box, in the order in which they are handed to PIL's crop
  box_columns = ("ymin", "xmin", "ymax", "xmax")

  for idx in regionprops_df.index:
    # Complete path (folder path and filename) of the snippet
    snippet_name = str(regionprops_df["snippet_id"][idx])
    complete_save_path = os.path.join(save_tlsnippet_folder_path, snippet_name)

    # Crop the snippet and convert it to a uint8 numpy array
    box = tuple(regionprops_df[column][idx] for column in box_columns)
    tl_snippet = img_as_ubyte(np.array(tl_image.crop(box)))

    # Save the snippet into the folder
    io.imsave(complete_save_path, tl_snippet)

#@markdown ##Run to compute TL snippets and regionprops dataframe.

print("Snippets being created...")

for filename in sorted(os.listdir(cellposelabelpatches_folder_path)):

  check_label = io.imread(os.path.join(cellposelabelpatches_folder_path, filename))
  # More than one distinct value means that at least one object was segmented on the patch
  if len(np.unique(check_label)) > 1:
    regionprops_df = measure_regionprops(cellposelabelpatches_folder_path, filename)
    compute_meanintensity(filename, cellposelabelpatches_folder_path, EGFPpatches_folder_path, mCherrypatches_folder_path, regionprops_df)
    save_regionpropsdataframe(regionpropsdf_folder_path, regionprops_df, filename)
    crop_tlsnippet(tlpatches_folder_path, regionprops_df, tlsnippets_folder_path, filename)

  else:
    # Patches without any segmented object are discarded
    os.remove(os.path.join(tlpatches_folder_path, filename))
    os.remove(os.path.join(dapipatches_folder_path, filename))
    os.remove(os.path.join(cellposelabelpatches_folder_path, filename))
    if save_cellpose:
      figure_filename = filename.replace(".tif", ".png")
      os.remove(os.path.join(cellposefigures_folder_path, figure_filename))

print("\n") 
print("Snippets just created")

print("\n") 
number_of_snippets = len(os.listdir(tlsnippets_folder_path))
print(f"TOTAL: {(number_of_snippets)} snippets")

# The tabular data computed for each patch is stored to a common tabular data.
# The tables are read first and concatenated once: DataFrame.append was removed in pandas 2.0
# and growing a dataframe table by table copies all the rows on every step.
patch_tables = [
    pd.read_csv(os.path.join(regionpropsdf_folder_path, table_filename), index_col = False)
    for table_filename in sorted(os.listdir(regionpropsdf_folder_path))
]
if patch_tables:
  general_regionprops_df = pd.concat(patch_tables, ignore_index = True)
else:
  general_regionprops_df = pd.DataFrame(data = None, index = None, columns = None)

# The outlier instances of the tabular data are removed:
# only the rows strictly between the 1st and the 99th percentile of both channels are kept
q_low_EGFP = general_regionprops_df["meanintensityEGFP"].quantile(0.01)
q_hi_EGFP  = general_regionprops_df["meanintensityEGFP"].quantile(0.99)
q_low_mCherry = general_regionprops_df["meanintensitymCherry"].quantile(0.01)
q_hi_mCherry  = general_regionprops_df["meanintensitymCherry"].quantile(0.99)

# .copy() so that the columns added below are written to an independent dataframe
filtered_general_regionprops_df = general_regionprops_df[
                                                  (general_regionprops_df["meanintensityEGFP"] < q_hi_EGFP) & 
                                                  (general_regionprops_df["meanintensityEGFP"] > q_low_EGFP) & 
                                                  (general_regionprops_df["meanintensitymCherry"] < q_hi_mCherry) & 
                                                  (general_regionprops_df["meanintensitymCherry"] > q_low_mCherry)
].copy()

# Compute the log value of each meanintensityEGFP and meanintensitymCherry attributes.
filtered_general_regionprops_df['meanintensityEGFP_log'] = np.log2(filtered_general_regionprops_df['meanintensityEGFP'])
filtered_general_regionprops_df['meanintensitymCherry_log'] = np.log2(filtered_general_regionprops_df['meanintensitymCherry'])




---


---



<font size = 4> The two following steps of the Localization Block are meant to be used to manually cluster each group of cells according to the cell line they belong to. Thus, you should repeatedly run both cells, one after the other, until each population of cells has been labelled.

<font size = 4> On the first cell of code, you should use the sliders displayed in order to adjust the position and shape of the rectangle to the proper location. Then, on the following cell of code, you shall specify the name of the cell line. 

In [None]:
#@markdown ##Run to plot the distribution of the TL snippets according to the mean intensity of both fluorescence channels.

# Limits of the plot (with a small margin) and starting position of the rectangle
xmin = filtered_general_regionprops_df["meanintensityEGFP_log"].min() - 0.25
xmax = filtered_general_regionprops_df["meanintensityEGFP_log"].max() + 0.25
xmean = filtered_general_regionprops_df["meanintensityEGFP_log"].mean()
ymin = filtered_general_regionprops_df["meanintensitymCherry_log"].min() - 0.25 
ymax = filtered_general_regionprops_df["meanintensitymCherry_log"].max() + 0.25
ymean = filtered_general_regionprops_df["meanintensitymCherry_log"].mean()

# The labelling state is initialised only once: this cell is meant to be run again for every
# cell line, and resetting it here would erase the populations that were already labelled.
# To start labelling from scratch, run again the cell that builds filtered_general_regionprops_df.
if "cell_line" not in filtered_general_regionprops_df.columns:
  filtered_general_regionprops_df["cell_line"] = "undefined"
  cell_lines_list = []

def update_plot(x, y, width, height):
  '''
  This function draws the scatter plot of the snippets and the selection rectangle on top of it
  It is called by the interactive widget every time one of the sliders is moved

  INPUT:
  x: position of the left side of the rectangle
  y: position of the bottom side of the rectangle
  width: width of the rectangle
  height: height of the rectangle
  '''
  fig, ax = plt.subplots(figsize=(16,6))
  ax.scatter(
    x = filtered_general_regionprops_df["meanintensityEGFP_log"],
    y = filtered_general_regionprops_df["meanintensitymCherry_log"],
    alpha = 0.2
  )

  ax.set_xlim(xmin, xmax)
  ax.set_ylim(ymin, ymax)
  ax.set_xlabel("Mean nuclear pixel intensity EGFP (AU)", fontsize = 12)
  ax.set_ylabel("Mean nuclear pixel intensity mCherry (AU)", fontsize = 12)
  ax.add_patch(
      patches.Rectangle((x,y), width, height, angle = 0, color ='forestgreen', alpha = 0.5
  ))

# The values of the sliders are read by the following cell to label the selected cells
x_slider = widgets.FloatSlider(min = xmin, max = xmax, step = 0.1, value = xmean)
y_slider = widgets.FloatSlider(min = ymin, max = ymax, step = 0.1, value = ymean)
width_slider = widgets.FloatSlider(min = 0.25, max = 2, step = 0.1, value = 1)
height_slider = widgets.FloatSlider(min = 0.25, max = 2, step = 0.1, value = 1)

interactive_plot = widgets.interactive(
    update_plot, 
    x = x_slider,
    y = y_slider,
    width = width_slider,
    height = height_slider
)
output = interactive_plot.children[-1]
display(interactive_plot)

In [None]:
#@markdown ## Run to label the cells just selected.

#@markdown <font size = 4>**`cell_line_name`**: name of the cell line selected by the green rectangle in the previous step.

cell_line_name = "" #@param {type:"string"}

# The name is later used as a folder name and as a class label, so it cannot be empty.
if not cell_line_name.strip():
  raise ValueError("Please type a non-empty `cell_line_name` before running this cell.")

# Register each cell line only once, even if this cell is re-run for the same population
# (duplicates would later produce duplicated folders and snippet counts).
if cell_line_name not in cell_lines_list:
  cell_lines_list.append(cell_line_name)

# Bounds of the rectangle drawn in the previous cell (both edges are inclusive).
x_low = x_slider.value
x_high = x_slider.value + width_slider.value
y_low = y_slider.value
y_high = y_slider.value + height_slider.value

# Label every row inside the rectangle in a single vectorised assignment. Using .loc avoids
# chained indexing (df[col][idx] = ...), which may silently write to a temporary copy.
inside_rectangle = (
    filtered_general_regionprops_df["meanintensityEGFP_log"].between(x_low, x_high) &
    filtered_general_regionprops_df["meanintensitymCherry_log"].between(y_low, y_high)
)
filtered_general_regionprops_df.loc[inside_rectangle, "cell_line"] = cell_line_name

# Show the current labelling state and save it (the file is overwritten on every run).
plt.figure(figsize=(16, 6))
ax = sns.scatterplot(
    x = filtered_general_regionprops_df["meanintensityEGFP_log"],
    y = filtered_general_regionprops_df["meanintensitymCherry_log"],
    alpha = 0.2,
    hue = filtered_general_regionprops_df["cell_line"]
)
ax.set_xlabel("Mean nuclear pixel intensity EGFP (AU)", fontsize = 12) 
ax.set_ylabel("Mean nuclear pixel intensity mCherry (AU)", fontsize = 12)

plt.savefig(os.path.join(generalfigures_folder_path, "CocultureScatterPlot" + ".png"),
                dpi = 300, format = "png", bbox_inches = None)



---



---




In [None]:
#@markdown ##Run to organize the snippets into the respective folders.

# number_of_snippets[i] is the number of files stored for the i-th cell line in
# alphabetical order; the class weights are later computed from it.
number_of_snippets = []

for cell_line in sorted(cell_lines_list):
  tlsnippets_cellline_folder_path = os.path.join(tlsnippets_folder_path, cell_line)
  # exist_ok keeps the cell re-runnable: os.mkdir would raise if the folder already exists.
  os.makedirs(tlsnippets_cellline_folder_path, exist_ok = True)

  # Select the snippets labelled with this cell line in one go instead of scanning every row.
  is_cell_line = filtered_general_regionprops_df["cell_line"] == cell_line
  for snippet_id in filtered_general_regionprops_df.loc[is_cell_line, "snippet_id"]:
    source_path = os.path.join(tlsnippets_folder_path, snippet_id)
    # Snippets already moved by a previous run are no longer in the source folder: skip them.
    if os.path.isfile(source_path):
      shutil.move(source_path, os.path.join(tlsnippets_cellline_folder_path, snippet_id))

  # Count what is actually in the folder (computed once and reused for the report).
  number_of_files = len([filename for filename in os.listdir(tlsnippets_cellline_folder_path) if os.path.isfile(os.path.join(tlsnippets_cellline_folder_path, filename))])
  number_of_snippets.append(number_of_files)
  print(f"{cell_line} snippets have just been stored in the respective folder.")
  print(f"{number_of_files} {cell_line} snippets.")
print("\n") 
print(f"TOTAL: {sum(number_of_snippets)} snippets of {len(cell_lines_list)} cell lines.")

In [None]:
#@markdown ##Run to remove the "undefined" snippets.

# Any .png still sitting directly in the snippets folder was not moved into a
# cell-line sub-folder by the previous step, so it is deleted.
for leftover_name in os.listdir(tlsnippets_folder_path):
  if leftover_name.endswith(".png"):
    os.remove(os.path.join(tlsnippets_folder_path, leftover_name))

# Persist the labelled region-properties table next to the other region-props files.
complete_save_path = os.path.join(regionpropsdf_folder_path, "_filtered_general_regionprops.csv")
filtered_general_regionprops_df.to_csv(complete_save_path, index = False)

In [None]:
#@markdown ##Run to display and save an example of each snippet.

for cell_line_name in sorted(cell_lines_list):
  tlsnippets_cellline_folder_path = os.path.join(tlsnippets_folder_path, cell_line_name)

  # Only the first file listed in each cell-line folder is shown (one example per class).
  for filename in os.listdir(tlsnippets_cellline_folder_path):

    tl_sample = io.imread(os.path.join(tlsnippets_cellline_folder_path, filename), as_gray = True)

    print(filename)

    f = plt.figure(figsize = (5,5))

    plt.subplot(1,1,1)
    plt.imshow(tl_sample, interpolation = 'nearest', cmap = 'gray')
    plt.title(f'{cell_line_name} Bright-field snippet ')
    plt.axis('off');

    # Name the figure after the snippet being displayed. The original code referenced
    # `base_name_patch`, which is not defined in this cell (NameError or a stale value).
    base_name_snippet = filename.replace(".png", "")
    figure_name = (f"Snippet{cell_line_name}_{base_name_snippet}")

    plt.savefig(os.path.join(generalfigures_folder_path, figure_name + ".png"),
                dpi = 300, format = "png", bbox_inches = None)
    
    break


# **7. Classification Block**

---



## **7.1. Define main parameters**

---



In [None]:
#@markdown ##Please specify the path where the model is going to be saved.

#@markdown <font size = 4> **`drivemodel_folder_path`**: is the Google Drive's folder path where the trained classification neural network is going to be stored.

drivemodel_folder_path = '' #@param {type:"string"}
# Terminate the folder path with a slash before it is used to build file paths.
drivemodel_folder_path += "/"

In [None]:
#@markdown #Please specify the main training parameters.

#@markdown ### Please select whether you want to use the default classification parameters.

Use_default_classification_parameters = True #@param {type:"boolean"}

#@markdown ### Otherwise, please define the parameters:

#@markdown <font size = 4>**`initial_epochs`**: number of epochs for the first training stage. If `Use_default_classification_parameters` is enabled, the parameter is set to 10.

initial_epochs = 10 #@param {type:"number"}

#@markdown <font size = 4>**`fine_tune_epochs`**: number of epochs for the second training stage. If `Use_default_classification_parameters` is enabled, the parameter is set to 10.

fine_tune_epochs = 10 #@param {type:"number"}

#@markdown <font size = 4>**`batch_size`**: This parameter describes the amount of images that are loaded into the network per step. Smaller batchsizes may improve training performance slightly but may increase training time. If the notebook crashes while loading the dataset this can be due to a too large batch size. Decrease the number in this case. If `Use_default_classification_parameters` is enabled, the parameter is set to 4.

batch_size =  2 #@param [2, 4, 8, 16, 32, 64, 128, 256] {type:"raw"}

#@markdown <font size = 4>**`percentage_validation`**: percentage of your training dataset you want to use to validate the network during training. If `Use_default_classification_parameters` is enabled, the parameter is set to 10.

percentage_validation = 10  #@param {type:"slider", min:1, max:100, step:1}

#@markdown <font size = 4>**`EarlyStopping_patience`**: number of epochs with no improvement after which training will be stopped when using the EarlyStopping method. If `Use_default_classification_parameters` is enabled, the parameter is set to 3.

patience = 3 #@param {type:"slider", min:1, max:10, step:1}

#@markdown <font size = 4>**`monitor_callback`**: metric to be monitored on the EarlyStopping and ModelCheckpoint methods. If `Use_default_classification_parameters` is enabled, the parameter is set to "val_loss".

monitor_callback = "val_loss" #@param ["accuracy", "loss", "val_accuracy", "val_loss"]

#@markdown ###Are you debugging?

#@markdown <font size = 4>**`DEBUG`**: enable the parameter if you want to know some extra information for each step of the classification block

DEBUG = True #@param {type:"boolean"}

# When the defaults are requested, every value chosen in the form above is overridden.
if Use_default_classification_parameters: 
  print("Default classification parameters enabled")
  initial_epochs, fine_tune_epochs = 10, 10
  batch_size, percentage_validation = 4, 10
  patience, monitor_callback = 3, "val_loss"

# Input resolution of the snippets fed to the network (square, 299 x 299).
img_height = img_width = 299
img_size = (img_height, img_width)
# Fine-tuning continues the epoch count of the first stage, hence the sum.
total_epochs = fine_tune_epochs + initial_epochs

# Weight of each class = size of the largest class / size of that class. The class index
# follows the alphabetical order of the cell lines, matching the order of number_of_snippets.
class_weights = {
    class_index: max(number_of_snippets) / number_of_snippets[class_index]
    for class_index, _ in enumerate(sorted(cell_lines_list))
}

## **7.2. Creation of train and validation dataset**

---



In [None]:
#@markdown ##Run to create the training  and validation datasets.

def _load_snippet_subset(subset):
  """Load one split of the snippets folder as a batched tf.data dataset.

  subset: "training" or "validation", forwarded to image_dataset_from_directory.
  Both splits are loaded with the same seed and validation_split, and labels are
  inferred from the names of the sub-folders of tlsnippets_folder_path.
  """
  return image_dataset_from_directory(
      directory = tlsnippets_folder_path,
      labels = "inferred",
      label_mode = "int",
      class_names = None, 
      color_mode = "rgb", 
      batch_size = batch_size, 
      image_size = (img_height, img_width),
      shuffle = True, 
      seed = 32, 
      validation_split = (percentage_validation / 100), 
      subset = subset,
      follow_links = False
  )

print("TRAINING DATASET") 
train_dataset = _load_snippet_subset("training")
print(f"Number of training batches: {tf.data.experimental.cardinality(train_dataset)}")

print("\n") 
print("VALIDATION DATASET") 
validation_dataset = _load_snippet_subset("validation")

# Use the same fully-qualified cardinality function as for the training dataset.
print(f"Number of validation batches: {tf.data.experimental.cardinality(validation_dataset)}.")

# class_names is needed by the model-building cell regardless of DEBUG, so it is read
# unconditionally, and before the datasets are wrapped by cache()/prefetch() below.
class_names = train_dataset.class_names
if DEBUG: 
  print("\n") 
  print(f"Name of the classes found: {class_names}")

AUTOTUNE = tf.data.AUTOTUNE
train_dataset = train_dataset.cache().prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.cache().prefetch(buffer_size=AUTOTUNE)
if DEBUG: 
  print("\n") 
  print("Enabled buffered prefetching so you can yield data from disk without having I/O become blocking.")

## **7.3. Set the data augmentation**

---



In [None]:
#@markdown ##Run to configure the data augmentation layer, as well as display and save an example.

# Augmentation pipeline: random flips, rotations and zooms applied in sequence.
# It is also used as the first stage of the model built later on.
data_augmentation = tf.keras.Sequential([
  preprocessing.RandomFlip(mode = "horizontal_and_vertical"),
  preprocessing.RandomRotation(0.2, fill_mode = 'reflect', interpolation = 'bilinear'),
  preprocessing.RandomZoom(0.2, fill_mode = 'reflect', interpolation = 'bilinear')
])

# Show nine random augmentations of the first image of the first training batch.
for image, _ in train_dataset.take(1):
  plt.figure(figsize=(10, 10))
  first_image = image[0]
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
    plt.imshow(augmented_image[0] / 255)
    plt.axis('off')

  # Save once, after the 3x3 grid is complete (previously the figure was written to disk
  # on each of the nine iterations, overwriting the same file every time).
  plt.savefig(os.path.join(model_folder_path, "DataAugmentationExample.png"), 
              dpi = 300, format = "png", bbox_inches = None)

## **7.4. Build the model**

---



In [None]:
#@markdown ##Run to build the base model and stack the classification layers on top.

# This cell builds the classifier: an ImageNet-pretrained Xception base (frozen) followed by
# global average pooling, dropout and a dense layer with one output per class.
# The DEBUG blocks only push one training batch through each stage to report its output shape.

# Preprocessing function matching the chosen base model (Xception)
preprocess_input = xception.preprocess_input

# LOAD IN THE PRETRAINED BASE MODEL (AND PRETRAINED WEIGHTS)
print("Loading in the pretrained base model (and pretrained weights)...")
print("\n")
# Creation of the base model; img_shape appends the 3 colour channels to (height, width)
img_shape = img_size + (3,)
base_model = xception.Xception(
    include_top = False, 
    weights = 'imagenet',  
    input_shape = img_shape
)
if DEBUG:
  # feature_batch is reused by the DEBUG blocks below; it only exists when DEBUG is enabled
  image_batch, label_batch = next(iter(train_dataset))
  feature_batch = base_model(image_batch)
  print("The base model act as a feature extractor.")
  print(f"Snippets are reshaped from 299x299x3 into a {feature_batch.shape[1:]} block of features.")
  print("\n") 

# Freeze the convolutional base
base_model.trainable = False
if DEBUG:
  print("Freezed the convolutional base.")
  print("It will prevent the weights of the base model from being updated during training.")
  base_model.summary()
  print("\n") 

# STACK THE CLASSIFICATION LAYERS ON TOP
print("Stacking the classification layers on top...")
print("\n")

# Addition of a classification head
global_average_layer = GlobalAveragePooling2D()
if DEBUG:
  feature_batch_average = global_average_layer(feature_batch)
  print("Applied a GlobalAveragePooling2D() layer.")
  print(f"To generate predictions from the block of features, average over the spatial {feature_batch.shape[1:3]} spatial locations.")
  print(f"In order to convert the features to a single {feature_batch_average.shape[1]}-element vector per image.")
  print("\n") 

# class_names is expected to have been defined by the dataset-creation cell
num_classes = len(class_names)
# No activation: the layer outputs logits (the loss is configured with from_logits = True when compiling)
prediction_layer = Dense(num_classes)
if DEBUG:
  prediction_batch = prediction_layer(feature_batch_average)
  print("Applied a Dense() layer.")
  print("Convert features into predictions according to the number of classes.")
  print(f"There is no need to apply activation function because prediction will be treated as a logit of {prediction_batch.shape[1]} outputs.")
  print("\n") 
  
# Build the model: augmentation -> Xception preprocessing -> base model -> pooling -> dropout -> logits
inputs = keras.Input(shape = img_shape)
x = data_augmentation(inputs)
x = preprocess_input(x)
# The base model is called with training = False (inference mode)
x = base_model(x, training = False)
x = global_average_layer(x)
x = Dropout(0.2)(x)
outputs = prediction_layer(x)
model = keras.Model(inputs, outputs)

print("The model has been composed correctly!")

## **7.5. Train the model**

---



### **7.5.1. Train the classification layers**

In [None]:
#@markdown ##Run to compile the model.

# Learning rate of the first training stage; the fine-tuning stage derives its own from it.
base_learning_rate = 0.001

head_optimizer = Nadam(learning_rate = base_learning_rate, beta_1 = 0.9, beta_2 = 0.999, epsilon = 1e-07, name='Nadam')
# The prediction layer outputs logits, hence from_logits = True.
head_loss = SparseCategoricalCrossentropy(from_logits = True)

model.compile(optimizer = head_optimizer, loss = head_loss, metrics = ["accuracy"])

if DEBUG:
  model.summary()
  for debug_message in (
      "\n",
      "Base model being used as a fixed feature structure.",
      "Just the classification head layers being trained and updated.",
  ):
    print(debug_message)

In [None]:
#@markdown ##Run to train the top classification layers of the model.

# Callback that stores the best model seen so far (according to monitor_callback) in the Drive folder.
best_model_path = os.path.join(drivemodel_folder_path, "weights_best.hdf5")
checkpoint_cb = ModelCheckpoint(
    best_model_path,
    monitor = monitor_callback,
    verbose = 1,
    save_best_only = True,
    mode = "auto",
    save_weights_only = False,
    save_freq = 'epoch',
)

# Callback that writes the per-epoch metrics to a CSV file, starting a fresh log.
training_log_path = os.path.join(model_folder_path, "model_log.csv")
csvlogger_cb = CSVLogger(training_log_path, separator = ',', append = False)

if DEBUG:
  for debug_message in (
      "Set ModelCheckpoint() as a callback for the model.",
      f"ModelCheckpoint() is monitoring {monitor_callback} during training of the model.",
      "\n",
      "Set CSVLogger() as a callback for the model.",
      "\n",
  ):
    print(debug_message)

# First training stage: only the layers on top of the frozen base are updated.
history = model.fit(
    train_dataset,
    validation_data = validation_dataset,
    epochs = initial_epochs,
    class_weight = class_weights,
    callbacks = [checkpoint_cb, csvlogger_cb],
    verbose = 1
)

In [None]:
#@markdown ##Run to display and save the learning curves of the training of the top classification layers of the model.

# Per-epoch metrics recorded during the first training stage.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Epoch numbers shifted to start at 1 for the x axis.
initial_epoch = [e + 1 for e in history.epoch]

# One entry per panel: position, training curve, validation curve, curve labels,
# legend location, title, y label and (optional) x label.
curve_panels = (
    (1, acc, val_acc, 'Training Accuracy', 'Validation Accuracy',
     'lower right', 'Training and Validation Accuracy', 'Accuracy', None),
    (2, loss, val_loss, 'Training Loss', 'Validation Loss',
     'upper right', 'Training and Validation Loss', 'Sparse Categorical Crossentropy', 'epoch'),
)

plt.figure(figsize=(8, 8))
for (panel_position, training_curve, validation_curve, training_label, validation_label,
     legend_location, panel_title, y_label, x_label) in curve_panels:
  plt.subplot(2, 1, panel_position)
  plt.plot(initial_epoch, training_curve, label=training_label)
  plt.plot(initial_epoch, validation_curve, label=validation_label)
  plt.legend(loc=legend_location)
  plt.title(panel_title)
  plt.ylabel(y_label)
  if x_label is not None:
    plt.xlabel(x_label)
  plt.autoscale(enable=True, axis='both')
  plt.xticks(ticks = initial_epoch)

plt.savefig(os.path.join(model_folder_path, "LearningCurvesTopClassificationLayers.png"),
            dpi = 300, format = "png", bbox_inches = None)

### **7.5.2. Round of fine-tuning**

In [None]:
#@markdown ## Run to compile the model.

# Make the whole base trainable, then freeze again its first two thirds:
# only the layers from index fine_tune_at onwards will be updated.
base_model.trainable = True
number_of_base_layers = len(base_model.layers)
fine_tune_at = round(number_of_base_layers * (2/3))
for frozen_layer in base_model.layers[:fine_tune_at]:
  frozen_layer.trainable = False
if DEBUG:
  for debug_message in (
      f"There are {len(base_model.layers)} layers on the base model",
      f"Fine-tuning from {fine_tune_at} layer onwards.",
      "\n",
  ):
    print(debug_message)

# Recompile so the new trainable flags take effect, with a learning rate 10 times smaller
# than in the first stage.
fine_tuning_optimizer = Nadam(learning_rate = base_learning_rate / 10, beta_1 = 0.9, beta_2 = 0.999, epsilon = 1e-07, name = 'Nadam')
fine_tuning_loss = SparseCategoricalCrossentropy(from_logits = True)
model.compile(optimizer = fine_tuning_optimizer, loss = fine_tuning_loss, metrics = ['accuracy'])
if DEBUG:
  print("Model compiled!")
  model.summary()


In [None]:
#@markdown ##Run to train the entire model.

# Stop the fine-tuning when the monitored metric has not improved for `patience` epochs.
earlystopping_cb = EarlyStopping(
    monitor = monitor_callback,
    patience = patience,
    min_delta = 0,
    mode = 'auto',
    verbose = 1,
)

# Keep writing to the CSV log of the first stage instead of starting a new one.
csvlogger_cb.append = True

if DEBUG:
  for debug_message in (
      "Set ModelCheckpoint() as a callback for the model.",
      f"ModelCheckpoint() is monitoring {monitor_callback} during training of the model.",
      "\n",
      "Set EarlyStopping() as a callback for the model.",
      f"EarlyStopping()'s patience = {patience}.",
      "\n",
  ):
    print(debug_message)

# Resume the epoch counter right after the last epoch of the first stage.
first_fine_tuning_epoch = history.epoch[-1]+1
history_fine = model.fit(
    train_dataset,
    validation_data = validation_dataset,
    initial_epoch = first_fine_tuning_epoch,
    epochs = total_epochs,
    class_weight = class_weights,
    callbacks = [checkpoint_cb, earlystopping_cb, csvlogger_cb],
    verbose = 1
)

# Store the model as it is at the end of training, next to the best checkpoint.
last_model_path = os.path.join(drivemodel_folder_path, 'weights_last.hdf5')
model.save(last_model_path)

In [None]:
#@markdown ##Run to display and save the learning curves of the training of the entire model.

# Build the complete curves as NEW lists (first stage followed by fine-tuning).
# The previous `acc += ...` extended the lists stored in history.history in place, so running
# this cell twice duplicated the fine-tuning values and broke the plot.
acc_total = history.history['accuracy'] + history_fine.history['accuracy']
val_acc_total = history.history['val_accuracy'] + history_fine.history['val_accuracy']
loss_total = history.history['loss'] + history_fine.history['loss']
val_loss_total = history.history['val_loss'] + history_fine.history['val_loss']

# Epoch numbers (starting at 1) of each stage. A dedicated name is used for the x axis so
# the `fine_tune_epochs` training parameter is no longer overwritten with a list.
head_epoch_numbers = [e + 1 for e in history.epoch]
fine_epoch_numbers = [e + 1 for e in history_fine.epoch]
all_epoch_numbers = head_epoch_numbers + fine_epoch_numbers
# Last epoch of the first stage: position of the vertical "Start Fine Tuning" marker.
fine_tuning_start = head_epoch_numbers[-1]

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(all_epoch_numbers, acc_total, label='Training Accuracy')
plt.plot(all_epoch_numbers, val_acc_total, label='Validation Accuracy')
plt.plot([fine_tuning_start, fine_tuning_start],
          plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.ylabel('Accuracy')
plt.autoscale(enable=True, axis='both')
plt.xticks(ticks = all_epoch_numbers)

plt.subplot(2, 1, 2)
plt.plot(all_epoch_numbers, loss_total, label='Training Loss')
plt.plot(all_epoch_numbers, val_loss_total, label='Validation Loss')
plt.plot([fine_tuning_start, fine_tuning_start],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.ylabel('Sparse Categorical Crossentropy')
plt.xlabel('epoch')
plt.autoscale(enable=True, axis='both')
plt.xticks(ticks = all_epoch_numbers)

plt.savefig(os.path.join(model_folder_path, "LearningCurvesFineTuning.png"),
            dpi = 300, format = "png", bbox_inches = None)

# **8. Download files**

In [None]:
#@markdown ##Please select the folders to be downloaded.

figures_folder = False #@param {type:"boolean"}
model_folder = False #@param {type:"boolean"}
regionprops_folder = False #@param {type:"boolean"}
patches_folder = False #@param {type:"boolean"}
snippets_folder = False #@param {type:"boolean"}

if figures_folder:
  !zip -r /content/WhoIsWho/Figures.zip /content/WhoIsWho/Figures

if model_folder:
  !zip -r /content/WhoIsWho/Model.zip /content/WhoIsWho/Model

if regionprops_folder:
  !zip -r /content/WhoIsWho/Regionpropsdf.zip /content/WhoIsWho/Patches/Regionpropsdf

if patches_folder:
  !zip -r /content/WhoIsWho/Patches.zip /content/WhoIsWho/Patches

if snippets_folder:
  !zip -r /content/WhoIsWho/Snippets.zip /content/WhoIsWho/Snippets



In [None]:
#@markdown ##Run to download as a zip the folders selected.

# (selected?, archive to download), in the same order as the selection form.
download_requests = (
    (figures_folder, "/content/WhoIsWho/Figures.zip"),
    (model_folder, "/content/WhoIsWho/Model.zip"),
    (regionprops_folder, "/content/WhoIsWho/Regionpropsdf.zip"),
    (patches_folder, "/content/WhoIsWho/Patches.zip"),
    (snippets_folder, "/content/WhoIsWho/Snippets.zip"),
)

# Trigger a browser download for each archive whose folder was selected.
for download_selected, download_archive_path in download_requests:
  if download_selected:
    files.download(download_archive_path)