# Imports and parameters

In [1]:
import ast
import os
import random
from typing import Optional

import imageio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import tensorflow as tf
from tensorflow.keras import layers, models
from tqdm import tqdm

In [2]:
# Show the GPUs TensorFlow can see; an empty list means no GPU is visible.
tf.config.list_physical_devices(device_type="GPU")

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
# Data locations. Each one can be overridden with an environment variable so the
# notebook runs on another machine without editing this cell; the fallbacks are
# the original hard-coded values.
metadata_folder = os.environ.get("CELL_METADATA_FOLDER", "data/")
crop_folder_ssd = os.environ.get("CELL_CROP_FOLDER", "C:/cell_crops/")

# Later cells build file paths by plain string concatenation (folder + file name),
# so both folders must end with a "/" separator.
if not metadata_folder.endswith("/"):
    metadata_folder += "/"
if not crop_folder_ssd.endswith("/"):
    crop_folder_ssd += "/"

In [4]:
# Human-readable class names.
# NOTE(review): presumably position i names the integer value i found in the
# `Label` column of the metadata - confirm against the data source.
classes = [
    "Nucleoplasm",
    "Nuclear membrane",
    "Nucleoli",
    "Nucleoli fibrillar center",
    "Nuclear speckles",
    "Nuclear bodies",
    "Endoplasmic reticulum",
    "Golgi apparatus",
    "Intermediate filaments",
    "Actin filaments",
    "Microtubules",
    "Mitotic spindle",
    "Centrosome",
    "Plasma membrane",
    "Mitochondria",
    "Aggresome",
    "Cytosol",
    "Vesicles and punctate cytosolic patterns",
    "Negative",
]

In [25]:
# --- Run parameters ---
# NOTE(review): SIZE is not used in the visible cells; presumably the crop side
# length in pixels - confirm.
SIZE = 410
BATCH_SIZE = 32
EPOCHS = 10

# Upper bound on the number of crops held in the prefetch buffer (converted to a
# number of batches where the dataset is built).
# Found by trial and error - fewer than about 20,500 crops fit in memory:
#   16384 -> crashed at step 1687/6207 (27%)
#   12288 -> does not seem to work when the test set is prefetched as well
#    8192 -> succeeded, with only half of the cache used
MAX_CROPS_PREFETCHED = 4 * 1024


In [6]:
# Load the saved "default" model just to check that it deserialises; the result
# is displayed, not kept in a variable.
models.load_model("models/default")

<tensorflow.python.keras.engine.sequential.Sequential at 0x1b0423daee0>

# Designing the model

# Label data

In [10]:
# Per-image bounding-box metadata, indexed by the image ID column.
# The CSV's unnamed first column is renamed to "new_index".
df = (
    pd.read_csv(f"{metadata_folder}train_bboxes.csv", index_col="ID")
    .rename(columns={"Unnamed: 0": "new_index"})
)

In [11]:
# Number of distinct values in the `Label` column; used as the one-hot depth
# when labels are encoded.
# NOTE(review): that depth is only correct if the labels are exactly
# 0 .. N_LABELS-1 with none missing from this file - confirm.
N_LABELS = len({label for label in df["Label"]})

In [12]:
# Sum of height x width over every image, multiplied by 4.
# NOTE(review): the factor 4 is presumably bytes per float32 pixel, making this
# a memory estimate for the full-size images - confirm.
4 * (df["image_height"] * df["image_width"]).sum()

221140484096

In [13]:
# Display the metadata frame as a visual sanity check.
df

Unnamed: 0_level_0,new_index,old_index,Label,image_height,image_width,boxes_height,boxes_width,boxes
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
5e22a522-bb99-11e8-b2b9-ac1f6b6435d0,0,5,0,3072,3072,"[217, 379, 207, 351, 307, 743, 399, 231, 255, ...","[583, 273, 367, 331, 139, 283, 383, 219, 499, ...","[[0, 217, 1798, 2381], [114, 493, 0, 273], [64..."
5f79a114-bb99-11e8-b2b9-ac1f6b6435d0,1,6,14,2048,2048,"[433, 599, 165, 558, 735, 771, 669, 583, 767, ...","[605, 606, 699, 609, 617, 447, 452, 469, 610, ...","[[0, 433, 0, 605], [0, 599, 568, 1174], [0, 16..."
5c801c04-bb99-11e8-b2b9-ac1f6b6435d0,2,9,14,3072,3072,"[249, 285, 315, 335, 306, 399, 415, 289, 277, ...","[331, 543, 537, 339, 275, 195, 319, 137, 276, ...","[[0, 249, 42, 373], [0, 285, 1894, 2437], [226..."
5e9afd56-bb99-11e8-b2b9-ac1f6b6435d0,3,10,0,2048,2048,"[146, 261, 288, 368, 269, 211, 156, 208, 195, ...","[205, 209, 283, 129, 231, 179, 210, 162, 126, ...","[[0, 146, 455, 660], [0, 261, 635, 844], [5, 2..."
5f1af6b4-bb99-11e8-b2b9-ac1f6b6435d0,4,11,3,2048,2048,"[365, 525, 162, 273, 279, 323, 469, 310, 316, ...","[884, 415, 582, 708, 485, 429, 557, 286, 475, ...","[[0, 365, 0, 884], [0, 525, 766, 1181], [0, 16..."
...,...,...,...,...,...,...,...,...
d9d99186-bbca-11e8-b2bc-ac1f6b6435d0,10407,21798,3,2048,2048,"[845, 173, 299, 455, 587, 1127, 659, 811, 395,...","[1093, 607, 437, 655, 741, 387, 483, 265, 422,...","[[0, 845, 0, 1093], [0, 173, 1122, 1729], [70,..."
daa22470-bbca-11e8-b2bc-ac1f6b6435d0,10408,21799,0,2048,2048,"[1149, 909, 439, 623, 940, 721, 625]","[463, 815, 735, 1057, 536, 223, 451]","[[0, 1149, 370, 833], [0, 909, 666, 1481], [75..."
dc261180-bbca-11e8-b2bc-ac1f6b6435d0,10409,21800,6,2048,2048,"[165, 288, 329, 275, 260, 259, 358, 431, 387, ...","[378, 182, 291, 411, 266, 157, 259, 291, 365, ...","[[0, 165, 78, 456], [0, 288, 459, 641], [0, 32..."
dd0989c4-bbca-11e8-b2bc-ac1f6b6435d0,10410,21801,14,2048,2048,"[357, 325, 425, 345, 515, 527, 251, 424, 343, ...","[399, 512, 591, 695, 461, 437, 333, 507, 615, ...","[[0, 357, 0, 399], [0, 325, 249, 761], [0, 425..."


In [14]:
# Build the flat list of crop image paths: one "<image ID>_<box index>.png" entry
# per bounding box listed for each image.
crops_path_hdd = []  # never filled here; kept because a later cell still reads it
crops_path_ssd = []
for idx, boxes_repr in tqdm(df["boxes"].items(), total=len(df), postfix = "Loading images path to in memory list"):
    # `boxes` holds the text form of a list of boxes. literal_eval only parses
    # Python literals, so unlike eval() it cannot execute code found in the CSV.
    n_box = len(ast.literal_eval(boxes_repr))
    crops_path_ssd.extend(f"{crop_folder_ssd}{idx}_{i}.png" for i in range(n_box))

100%|██████████| 10412/10412 [00:01<00:00, 6197.56it/s, Loading images path to in memory list]


In [15]:
# Wrap each path list in a tf.data dataset holding one element per list entry.
filelist_ds_hdd, filelist_ds_ssd = (
    tf.data.Dataset.from_tensor_slices(paths)
    for paths in (crops_path_hdd, crops_path_ssd)
)

# Pre-process

In [16]:
# Sanity check on one arbitrary crop: recover the image ID from its file name
# (the part before the first "_") and show the one-hot label it maps to.
sample_id = crops_path_ssd[1568].rsplit("/", 1)[-1].split("_", 1)[0]
display(sample_id)
display(tf.one_hot(df.loc[sample_id, "Label"], N_LABELS))
del sample_id

'9df04740-bb99-11e8-b2b9-ac1f6b6435d0'

<tf.Tensor: shape=(19,), dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
       0., 0.], dtype=float32)>

In [17]:
def get_label(filepath):
    """Return the one-hot encoded label of the crop stored at `filepath`.

    `filepath` is a plain Python string using "/" separators. The image ID is the
    file name up to its first underscore; it is looked up in the index of the
    global `df`, and the `Label` found there is one-hot encoded with depth
    `N_LABELS`.
    """
    file_name = filepath.rsplit("/", 1)[-1]
    image_id = file_name.split("_", 1)[0]
    return tf.one_hot(df.loc[image_id, "Label"], N_LABELS)

def get_crop(filepath):
    """Read one crop image from disk and pair it with its label.

    `filepath` must be an eager string tensor (it is converted with `.numpy()`),
    which is what `tf.py_function` passes in.

    Returns a tuple `(image, label)`: the image read by imageio converted to a
    float32 tensor, and the one-hot label from `get_label`.
    """
    path = filepath.numpy().decode('UTF-8')
    image = tf.convert_to_tensor(imageio.imread(path), dtype=tf.float32)
    label = get_label(path)
    return image, label

In [18]:
def prepare_dataset(crops_path, train_ratio: float = 0.8, n_prefetch: Optional[int] = None,
                    shuffle: bool = True, seed: Optional[int] = None):
    """Split the crop paths into batched, prefetched train and test datasets.

    Args:
        crops_path: list of crop file paths.
        train_ratio: fraction of the paths assigned to the training set; the
            rest goes to the test set. Must lie in [0, 1].
        n_prefetch: number of batches to prefetch. Defaults to
            MAX_CROPS_PREFETCHED // BATCH_SIZE.
        shuffle: shuffle the paths once before splitting.
        seed: optional seed for that shuffle, making the split reproducible.
            With None the global `random` state is used, as before.

    Returns:
        (train_dataset, test_dataset); each yields batches of (image, label).

    Raises:
        ValueError: if `train_ratio` is outside [0, 1].

    Note:
        Paths are shuffled and split individually, so crops cut from the same
        source image can end up on both sides of the split.
    """
    # A negative ratio would make take() keep everything and skip() drop
    # everything, silently producing an empty test set.
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be within [0, 1], got {train_ratio!r}")

    if shuffle:
        rng = random if seed is None else random.Random(seed)
        crops_path = rng.sample(crops_path, len(crops_path))

    if n_prefetch is None:
        n_prefetch = MAX_CROPS_PREFETCHED // BATCH_SIZE
        print(f"\nPREFETCH = {n_prefetch}\n")

    filelist_ds = tf.data.Dataset.from_tensor_slices(crops_path)
    train_length = int(len(crops_path)*train_ratio)
    print(str(train_length))

    def load_crops(paths_ds):
        # Image decoding is plain Python (imageio), so it runs through py_function.
        return paths_ds.map(
            lambda path: tf.py_function(func=get_crop, inp=[path], Tout=(tf.float32, tf.float32)),
            num_parallel_calls=tf.data.AUTOTUNE,
        )

    ret_train = load_crops(filelist_ds.take(train_length))
    ret_test = load_crops(filelist_ds.skip(train_length))
    print(len(ret_train), len(ret_test))

    # batch() is called without drop_remainder, so the final batch of each
    # dataset may hold fewer than BATCH_SIZE elements.
    return (
        ret_train.batch(BATCH_SIZE).prefetch(n_prefetch),
        ret_test.batch(BATCH_SIZE).prefetch(n_prefetch),
    )
# .prefetch(n_prefetch).cache()

# .prefetch(n_prefetch).cache()

# Train

In [19]:
# Build the train/test pipelines from the crop paths, using the default split
# ratio and prefetch size.
ds_train, ds_test = prepare_dataset(crops_path=crops_path_ssd)


PREFETCH = 128

198616
198616 49655


In [20]:
# Metric watched by both callbacks; higher is better, hence mode='max'.
checked_metrics = 'val_categorical_accuracy'

# Stop after 5 epochs without improvement and restore the best weights seen.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor=checked_metrics,
    mode='max',
    min_delta=0,
    patience=5,
    baseline=None,
    restore_best_weights=True,
    verbose=0,
)

# Write weights only, and only when the monitored metric improves.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath="\\tmp\\checkpoint",
    monitor=checked_metrics,
    mode='max',
    save_best_only=True,
    save_weights_only=True,
)

callbacks = [early_stop, checkpoint]

In [None]:
# Load the saved network that the next cell fits.
# TensorFlow has no `tf.models` attribute; the Keras loader is
# `tf.keras.models.load_model`.
model = tf.keras.models.load_model("models/conv_smaller")

In [26]:
# Train for EPOCHS epochs. Validation runs on only 8 batches of the test set per
# epoch (validation_steps=8), and that value drives the callbacks.
history2 = model.fit(
    ds_train,
    epochs=EPOCHS,
    validation_data=ds_test,
    validation_steps=8,
    callbacks=callbacks,
)

Epoch 1/10
 535/6207 [=>............................] - ETA: 40:46 - loss: 0.1356 - categorical_accuracy: 0.4386

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "N:\Conda folders\envs\env_tf\lib\site-packages\IPython\core\interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-26-f12e410695c0>", line 1, in <module>
    history2 = model.fit(ds_train, validation_data=ds_test, validation_steps = 8, epochs = EPOCHS, callbacks = callbacks)
  File "N:\Conda folders\envs\env_tf\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1100, in fit
    tmp_logs = self.train_function(iterator)
  File "N:\Conda folders\envs\env_tf\lib\site-packages\tensorflow\python\eager\def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "N:\Conda folders\envs\env_tf\lib\site-packages\tensorflow\python\eager\def_function.py", line 855, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
  File "N:\Conda folders\envs\env_tf\lib\site-packages\tensorflow\python\eager\function.py", line 29

TypeError: object of type 'NoneType' has no len()

# Debug and local performance test

## Cropping

# "Equal" timing

## Resizing

# API