# Imports

In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models, Input, Model
from tqdm import tqdm
import pandas as pd
import random
import time
import multiprocessing
import imageio
from collections import Counter
import numpy as np
import os
import matplotlib.pyplot as plt

# Settings

## Global parameters

In [2]:
# Current model to fit
IDX_MODEL = -1

In [3]:
plt.rcParams["figure.figsize"] = (16,9)

In [4]:
SIZE = 410
BATCH_SIZE = 32
EPOCHS = 20

# (TRAIN VALIDATION TEST)
SPLITS = (0.7, 0.2, 0.1)

MAX_CROPS_PREFETCHED = 2048
# After many tries, I found i could prefetch less than 20'500 crops in memory, at best.
# 12288 Doesn't seem to work with test prefetched as well
# 8192 # Success with only half the cache used
#16384 # Crashes 1687/6207 (27%)

BALANCED_DATASET = True # If you want the class-wise balanced and filtered dataset
SPARSE = False

## GPU

In [5]:
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [6]:
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
    # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
        print(e)


1 Physical GPUs, 1 Logical GPUs


## Folders

In [7]:
metadata_folder = "data"
crop_folder = "C:/cell_crops"
result_folder = "results"

In [8]:
whole_classes = ["Nucleoplasm",
"Nuclear membrane",
"Nucleoli",
"Nucleoli fibrillar center",
"Nuclear speckles",
"Nuclear bodies",
"Endoplasmic reticulum",
"Golgi apparatus",
"Intermediate filaments",
"Actin filaments",
"Microtubules",
"Mitotic spindle",
"Centrosome",
"Plasma membrane",
"Mitochondria",
"Aggresome",
"Cytosol",
"Vesicles and punctate cytosolic patterns",
"Negative"]

# Alarm (to launch another fit and/or wake up at night)

In [9]:
import winsound
def beep_loud():
    winsound.Beep(500, 500)
    winsound.Beep(800, 300)
    winsound.Beep(500, 250)
    winsound.Beep(800, 150)
    winsound.Beep(500, 500)
def beep_quick(n=5):
    for _ in range(n-1):
        winsound.Beep(800, 600)
        time.sleep(0.2)
    winsound.Beep(800, 600)

# Label data

## Whole Dataframe

In [10]:
df_whole = pd.read_csv(os.path.join(metadata_folder, "train_bboxes.csv"), index_col="ID").rename(columns={"Unnamed: 0": "new_index"})

In [11]:
# crop_names = 
whole_crops_path = []
# labels = []
for idx, row in tqdm(df_whole.iterrows(), total=len(df_whole), postfix = "Loading images path to in memory list"):
    n_box = len(eval(row["boxes"]))
    whole_crops_path.extend([os.path.join(crop_folder, f"{idx}_{i}.png") for i in range(n_box)])


100%|██████████| 10412/10412 [00:02<00:00, 4486.00it/s, Loading images path to in memory list]


## Other dataframe : balanced classes

In [12]:
# Keeping only classes with more than 200 occurrences
uniques_count = df_whole.Label.value_counts()
filtered_labels = uniques_count[uniques_count>200]
display(filtered_labels)
# Filtering in those labels in the original dataset
f_df = df_whole.loc[df_whole.Label.isin(filtered_labels.index), :].reset_index()
# Getting one row per crop
f_df["boxes"] = f_df.boxes.apply(eval)
f_df = f_df.explode("boxes", ignore_index=True)[["ID", "Label", "old_index", "boxes"]]
f_df

0     1551
14    1054
16     906
13     818
4      788
7      719
5      636
12     561
3      549
8      529
2      515
6      476
10     404
9      294
17     274
1      221
Name: Label, dtype: int64

Unnamed: 0,ID,Label,old_index,boxes
0,5e22a522-bb99-11e8-b2b9-ac1f6b6435d0,0,5,"[0, 217, 1798, 2381]"
1,5e22a522-bb99-11e8-b2b9-ac1f6b6435d0,0,5,"[114, 493, 0, 273]"
2,5e22a522-bb99-11e8-b2b9-ac1f6b6435d0,0,5,"[646, 853, 1954, 2321]"
3,5e22a522-bb99-11e8-b2b9-ac1f6b6435d0,0,5,"[1382, 1733, 194, 525]"
4,5e22a522-bb99-11e8-b2b9-ac1f6b6435d0,0,5,"[1778, 2085, 1986, 2125]"
...,...,...,...,...
245030,df573730-bbca-11e8-b2bc-ac1f6b6435d0,14,21804,"[2534, 3071, 0, 953]"
245031,df573730-bbca-11e8-b2bc-ac1f6b6435d0,14,21804,"[2720, 3071, 2674, 3071]"
245032,df573730-bbca-11e8-b2bc-ac1f6b6435d0,14,21804,"[2778, 3049, 1878, 2169]"
245033,df573730-bbca-11e8-b2bc-ac1f6b6435d0,14,21804,"[2882, 3071, 1390, 1785]"


In [13]:
# Suffixing the duplicate IDs. Will correspond to the crops filenames
suffixes = f_df.groupby("ID").cumcount().values.astype(str)
f_df.ID = f_df.ID+"_"+suffixes
f_df

Unnamed: 0,ID,Label,old_index,boxes
0,5e22a522-bb99-11e8-b2b9-ac1f6b6435d0_0,0,5,"[0, 217, 1798, 2381]"
1,5e22a522-bb99-11e8-b2b9-ac1f6b6435d0_1,0,5,"[114, 493, 0, 273]"
2,5e22a522-bb99-11e8-b2b9-ac1f6b6435d0_2,0,5,"[646, 853, 1954, 2321]"
3,5e22a522-bb99-11e8-b2b9-ac1f6b6435d0_3,0,5,"[1382, 1733, 194, 525]"
4,5e22a522-bb99-11e8-b2b9-ac1f6b6435d0_4,0,5,"[1778, 2085, 1986, 2125]"
...,...,...,...,...
245030,df573730-bbca-11e8-b2bc-ac1f6b6435d0_31,14,21804,"[2534, 3071, 0, 953]"
245031,df573730-bbca-11e8-b2bc-ac1f6b6435d0_32,14,21804,"[2720, 3071, 2674, 3071]"
245032,df573730-bbca-11e8-b2bc-ac1f6b6435d0_33,14,21804,"[2778, 3049, 1878, 2169]"
245033,df573730-bbca-11e8-b2bc-ac1f6b6435d0_34,14,21804,"[2882, 3071, 1390, 1785]"


In [14]:
# Sampling the same (min) amount by class
f_unique_counts = f_df.Label.value_counts()
display(f_unique_counts)
min_amount = f_unique_counts.min()
print(f"Min count is {min_amount}")
df_sampled = f_df.groupby("Label").apply(lambda x: x.sample(frac=min_amount/len(x))).reset_index(drop=True)

0     37472
14    27495
16    22738
13    21168
7     18825
4     17527
5     15337
12    13952
3     12882
2     12672
8     11194
6     10198
10     7789
17     5619
9      5322
1      4845
Name: Label, dtype: int64

Min count is 4845


In [15]:
df_sampled

Unnamed: 0,ID,Label,old_index,boxes
0,b838966e-bbaf-11e8-b2ba-ac1f6b6435d0_28,0,9683,"[646, 1035, 2739, 3071]"
1,d5e64e8a-bbc9-11e8-b2bc-ac1f6b6435d0_27,0,21311,"[1466, 2365, 1298, 2181]"
2,bad6e3ce-bbac-11e8-b2ba-ac1f6b6435d0_17,0,8352,"[907, 1345, 469, 764]"
3,c10bd260-bbb9-11e8-b2ba-ac1f6b6435d0_6,0,14181,"[538, 913, 904, 1224]"
4,df28267e-bbbc-11e8-b2ba-ac1f6b6435d0_8,0,15549,"[278, 513, 1450, 1725]"
...,...,...,...,...
77515,e7d0b8ac-bbaa-11e8-b2ba-ac1f6b6435d0_7,17,7527,"[294, 665, 0, 511]"
77516,d688cf90-bb9f-11e8-b2b9-ac1f6b6435d0_14,17,2715,"[1226, 1525, 1518, 1749]"
77517,55569706-bb9d-11e8-b2b9-ac1f6b6435d0_3,17,1747,"[394, 605, 1082, 1314]"
77518,2c1886b0-bbc5-11e8-b2bc-ac1f6b6435d0_2,17,19148,"[0, 276, 1450, 2047]"


We went from 245035 to 77520 samples, but at least their labels are balanced.

### Also re-labelling (so we don't have useless columns in the one hot encoding)

In [16]:
# Resetting index will help use get rid of "holes"
sampled_uniques = pd.DataFrame(df_sampled.Label.unique()).rename(columns={0: "old_labelindex"}).rename_axis("new_labelindex").reset_index()
display(sampled_uniques)
new_label_dict = sampled_uniques.set_index("old_labelindex").to_dict()["new_labelindex"]
display(new_label_dict)
# Applying to the class list:
sampled_classes = [whole_classes[i] for i in sampled_uniques.old_labelindex]
display(sampled_classes)
# Also applying to the sampled dataframe
df_sampled = df_sampled.rename(columns={"Label": "Old_label"})
df_sampled["Label"] = df_sampled.Old_label.map(new_label_dict)
df_sampled = df_sampled.set_index("ID")
df_sampled

Unnamed: 0,new_labelindex,old_labelindex
0,0,0
1,1,1
2,2,2
3,3,3
4,4,4
5,5,5
6,6,6
7,7,7
8,8,8
9,9,9


{0: 0,
 1: 1,
 2: 2,
 3: 3,
 4: 4,
 5: 5,
 6: 6,
 7: 7,
 8: 8,
 9: 9,
 10: 10,
 12: 11,
 13: 12,
 14: 13,
 16: 14,
 17: 15}

['Nucleoplasm',
 'Nuclear membrane',
 'Nucleoli',
 'Nucleoli fibrillar center',
 'Nuclear speckles',
 'Nuclear bodies',
 'Endoplasmic reticulum',
 'Golgi apparatus',
 'Intermediate filaments',
 'Actin filaments',
 'Microtubules',
 'Centrosome',
 'Plasma membrane',
 'Mitochondria',
 'Cytosol',
 'Vesicles and punctate cytosolic patterns']

Unnamed: 0_level_0,Old_label,old_index,boxes,Label
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
b838966e-bbaf-11e8-b2ba-ac1f6b6435d0_28,0,9683,"[646, 1035, 2739, 3071]",0
d5e64e8a-bbc9-11e8-b2bc-ac1f6b6435d0_27,0,21311,"[1466, 2365, 1298, 2181]",0
bad6e3ce-bbac-11e8-b2ba-ac1f6b6435d0_17,0,8352,"[907, 1345, 469, 764]",0
c10bd260-bbb9-11e8-b2ba-ac1f6b6435d0_6,0,14181,"[538, 913, 904, 1224]",0
df28267e-bbbc-11e8-b2ba-ac1f6b6435d0_8,0,15549,"[278, 513, 1450, 1725]",0
...,...,...,...,...
e7d0b8ac-bbaa-11e8-b2ba-ac1f6b6435d0_7,17,7527,"[294, 665, 0, 511]",15
d688cf90-bb9f-11e8-b2b9-ac1f6b6435d0_14,17,2715,"[1226, 1525, 1518, 1749]",15
55569706-bb9d-11e8-b2b9-ac1f6b6435d0_3,17,1747,"[394, 605, 1082, 1314]",15
2c1886b0-bbc5-11e8-b2bc-ac1f6b6435d0_2,17,19148,"[0, 276, 1450, 2047]",15


In [17]:
# Already done
df_sampled.to_csv("train_bboxes_sampled.csv")

In [18]:
# We need a specifid "crops path" for our sampled dataset.
sampled_crops_path = df_sampled.reset_index().ID.apply(lambda x: os.path.join(crop_folder, f"{x}.png")).tolist()
sampled_crops_path

['C:/cell_crops\\b838966e-bbaf-11e8-b2ba-ac1f6b6435d0_28.png',
 'C:/cell_crops\\d5e64e8a-bbc9-11e8-b2bc-ac1f6b6435d0_27.png',
 'C:/cell_crops\\bad6e3ce-bbac-11e8-b2ba-ac1f6b6435d0_17.png',
 'C:/cell_crops\\c10bd260-bbb9-11e8-b2ba-ac1f6b6435d0_6.png',
 'C:/cell_crops\\df28267e-bbbc-11e8-b2ba-ac1f6b6435d0_8.png',
 'C:/cell_crops\\c7257380-bbb0-11e8-b2ba-ac1f6b6435d0_21.png',
 'C:/cell_crops\\c243eec0-bbb3-11e8-b2ba-ac1f6b6435d0_11.png',
 'C:/cell_crops\\9bf68f5e-bbca-11e8-b2bc-ac1f6b6435d0_14.png',
 'C:/cell_crops\\0412456c-bba6-11e8-b2ba-ac1f6b6435d0_0.png',
 'C:/cell_crops\\d36f519c-bba6-11e8-b2ba-ac1f6b6435d0_3.png',
 'C:/cell_crops\\ee7befac-bbbc-11e8-b2ba-ac1f6b6435d0_14.png',
 'C:/cell_crops\\4ab4f494-bba4-11e8-b2b9-ac1f6b6435d0_14.png',
 'C:/cell_crops\\c9dd317a-bbac-11e8-b2ba-ac1f6b6435d0_20.png',
 'C:/cell_crops\\cab08d6c-bb99-11e8-b2b9-ac1f6b6435d0_8.png',
 'C:/cell_crops\\67fee340-bbb1-11e8-b2ba-ac1f6b6435d0_7.png',
 'C:/cell_crops\\8dbd4e38-bb9c-11e8-b2b9-ac1f6b6435d0_7.png',

## Train-test split

In [19]:
def get_filename_dataset():
    # Seleting the crops path, depending on the gloabl parameter BALANCED_DATASET
    source_path = sampled_crops_path if BALANCED_DATASET else whole_crops_path
    # Shuffling
    source_path = random.sample(source_path, len(source_path))
    test_size = int(len(source_path) * SPLITS[2])
    
    return tf.data.Dataset.from_tensor_slices(source_path[:-test_size]), tf.data.Dataset.from_tensor_slices(source_path[-test_size:])

In [20]:
trainval_filelist_ds, test_filelist_ds = get_filename_dataset()

# Applying global parameters

In [21]:
# PARAMETER : whole or sampled dataframe
df, crops_path, classes = (df_sampled, sampled_crops_path, sampled_classes) if BALANCED_DATASET else (df_whole, whole_crops_path, whole_classes)
N_LABELS = len(set(df.Label))

# Pre-process

In [22]:
def extract_crop_id(crop_path: str):
    crop_id = crop_path.split("\\")[-1]
    crop_id = crop_id.split(".") if BALANCED_DATASET else crop_id.split("_")
    return crop_id[0]


def get_onehot_label(filepath):
    return tf.one_hot(df.loc[extract_crop_id(filepath), "Label"], N_LABELS)

def get_label(filepath):
    return df.loc[extract_crop_id(filepath), "Label"]

def get_crop(filepath):
    f_str = filepath.numpy().decode('UTF-8')
    return tf.convert_to_tensor(imageio.imread(f_str), dtype=tf.float32), get_label(f_str) if SPARSE else get_onehot_label(f_str)

def tensor_extract_crop_id(tensor: tf.Tensor):
    """
    Not useful anymore.
    """
    return bytes.decode(tf.strings.split(tf.strings.split(tensor, "/")[-1], "_")[0].numpy())
    

def get_classweights(filenames: tf.data.Dataset):
    """
    Class weights can't work with Onehot output encoding or more than 2 classes.
    We have to do sample weights
    """
    print("\nPREPARING CLASS WEIGHTS\n")
    crop_ids = list(map(extract_crop_id, map(bytes.decode, [c for c in filenames.as_numpy_iterator()])))
    labels = df.loc[crop_ids, "Label"]
    counts = Counter(labels)
    size = sum(counts.values())
    # Taking the medium value as weight "1.".
    med = np.median(tuple(counts.values()))
    weights = {k: med / v for k, v in counts.items()}
    return weights

def get_sampleweights(filenames: tf.data.Dataset):
    print("\nPREPARING CLASS WEIGHTS\n")
    crop_ids = list(map(extract_crop_id, map(bytes.decode, [c for c in filenames.as_numpy_iterator()])))
    labels = df.loc[crop_ids, "Label"]
    counts = Counter(labels)
    size = sum(counts.values())
    # Taking the medium value as weight "1.".
    med = np.median(tuple(counts.values()))
    weights = {k: med / v for k, v in counts.items()}
    return np.array(map(lambda x: weights.get(x), weights))

In [23]:
def prepare_dataset(filename_ds: tf.data.Dataset):
    return filename_ds.map(lambda x: tf.py_function(func=get_crop,
                    inp=[x],
                    Tout=(tf.float32, tf.float32)), 
                    num_parallel_calls = tf.data.AUTOTUNE)

def prepare_training_split(crops_path, train_ratio: float = 0.8, n_prefetch: int = None, shuffle: bool = True):
    train_ratio = SPLITS[0] / sum(SPLITS[:2])
    # Shuffling
    if shuffle:
        crops_path = random.sample(crops_path, len(crops_path))

    # Calcultating the number of prefetch that the memory can hold.
    if n_prefetch is None:
        n_prefetch = MAX_CROPS_PREFETCHED // BATCH_SIZE
        print(f"\nPREFETCH = {n_prefetch}\n")

    filelist_ds = tf.data.Dataset.from_tensor_slices(crops_path)

    train_length = int(len(crops_path)*train_ratio)
    train_filename_ds = filelist_ds.take(train_length)
#     weights = get_sampleweights(train_filename_ds)

    # Train set
    ret_train = prepare_dataset(train_filename_ds)

    # Test set
    ret_val = prepare_dataset(filelist_ds.skip(train_length))

    # Not returning the weights anymore
    return ret_train.batch(BATCH_SIZE).prefetch(n_prefetch), ret_val.batch(BATCH_SIZE).prefetch(n_prefetch).repeat()

# Models

In [24]:
model_list = []

In [25]:
# Better. Best so far
# 14
model = models.Sequential()
model.add(layers.Conv2D(128, (15, 15), input_shape=(SIZE, SIZE, 3)))
model.add(layers.ReLU())
model.add(layers.BatchNormalization())
model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Conv2D(128, (11, 11)))
model.add(layers.ReLU())
model.add(layers.BatchNormalization())
model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Conv2D(64, (7, 7)))
model.add(layers.ReLU())
model.add(layers.BatchNormalization())
model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Conv2D(64, (5, 5)))
model.add(layers.ReLU())
model.add(layers.BatchNormalization())
model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Conv2D(32, (3, 3)))
model.add(layers.ReLU())
model.add(layers.BatchNormalization())
model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Conv2D(32, (1, 1)))
model.add(layers.ReLU())
model.add(layers.BatchNormalization())
model.add(layers.MaxPooling2D((2, 2)))

model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(32, activation='relu'))
model.add(layers.Dense(len(classes), activation="softmax"))

model_list.append({"name":"invert_conv", "model": model})

# Train

In [29]:
ds_train, ds_val = prepare_training_split(crops_path) # not getting weight anymore


PREFETCH = 64



In [30]:
checked_metrics = 'val_categorical_accuracy'

early_stop = tf.keras.callbacks.EarlyStopping(
                 monitor=checked_metrics, min_delta=0, patience=5, verbose=0,mode='max', baseline=None,
                 restore_best_weights=True)

checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath="\\tmp\\checkpoint",
    save_weights_only=True,
    monitor= checked_metrics,
    mode='max',
    save_best_only=True)

callbacks = [early_stop, checkpoint]

In [31]:
def train(name, model=None, epochs=EPOCHS):
    
    if isinstance(name, dict):
        model = name["model"]
        name = name["name"]
        
    if model is None:
        raise Exception("Please give a model in argument or pass a dict{'name':name, 'model': model}.")
        
    model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy() if SPARSE else tf.keras.losses.CategoricalCrossentropy(),
              metrics=[tf.keras.metrics.CategoricalAccuracy()]
                 )
    print(f"\nTRAINING {name.upper()}\n")
    print(model.summary())
    try:
        history = model.fit(ds_train, 
                            validation_data=ds_val,
    #                         sample_weight=weights,
                            validation_steps = 128,
                            epochs = epochs,
                            callbacks = callbacks
                           )
        model.save(f"models\\{name}_{'sampled' if BALANCED_DATASET else 'whole'}")
        res_df = pd.DataFrame(history.history)
        res_df.to_csv(os.path.join(result_folder, f"history_{name}_{'sampled' if BALANCED_DATASET else 'whole'}.csv"))
        plot_history(name, res_df)
        print(f"Model, results and plots saved as {name}.")
    except Exception as e:
        beep_quick(8)
        print(f"ERROR with model {name}")
        print(e)
    beep_loud()
    beep_loud()
    return history


def plot_history(name, history = None, show=True, save=True, balanced=BALANCED_DATASET):
    
    # Transforming to DataFrame, if not already
    if history is None:
        history = pd.read_csv(os.path.join(result_folder, f"history_{name}_{'sampled' if balanced else 'whole'}.csv"))
    elif not isinstance(history, pd.DataFrame):
        history = pd.DataFrame(history.history)
    
    # Plotting
    fig, (ax1, ax2) = plt.subplots(2,1)
    
    ax1.plot(history['categorical_accuracy'])
    ax1.plot(history['val_categorical_accuracy'])
    ax1.set_title(f'Model categorical accuracy - {name}')
    ax1.set_ylabel('Accuracy')
    ax1.set_xlabel('Epoch')
    ax1.legend(['Train', 'Validation'], loc='upper left')
    ax1.grid()
    
    ax2.plot(history['loss'])
    ax2.plot(history['val_loss'])
    ax2.set_title('Model loss')
    ax2.set_ylabel('Loss')
    ax2.set_xlabel('Epoch')
    ax2.legend(['Train', 'Validation'], loc='upper left')
    ax2.grid()
    
    if save:
        dest_path = os.path.join(result_folder, f"plot_{name}_{'sampled' if balanced else 'whole'}.png")
        plt.savefig(dest_path)
        print(f"History plots saved at {dest_path}")
    if show:
        plt.show()

# Lets'go

## Training

In [32]:
model_dict = model_list[IDX_MODEL]

In [33]:
model_dict

{'name': 'invert_conv',
 'model': <tensorflow.python.keras.engine.sequential.Sequential at 0x1c03e359430>}

In [31]:
# Had to interrupt this model fitting for the final presentation !
hist = train(model_dict)


TRAINING INVERT_CONV_DROPOUT

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 396, 396, 128)     86528     
_________________________________________________________________
re_lu (ReLU)                 (None, 396, 396, 128)     0         
_________________________________________________________________
batch_normalization (BatchNo (None, 396, 396, 128)     512       
_________________________________________________________________
dropout (Dropout)            (None, 396, 396, 128)     0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 198, 198, 128)     0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 188, 188, 128)     1982592   
_________________________________________________________________
re_lu_1 (ReLU)           

KeyboardInterrupt: 

Testing : soon...