In [1]:
import os
import sys
sys.path.append("..")

import tensorflow as tf

import utils as utils
import CST as CST
from metrics import recall_m, precision_m, f1_m, auc_m

In [2]:
# TODO: check whether results improve when training on corrupted CIFAR vs. clean CIFAR


#### Load the previously trained model and instantiate CST

In [3]:
# Checkpoint of the previously trained network that CST starts from.
model_path = "../models/CST4_alpha1_DC4.h5"

# Side length (in pixels) of the square input tiles.
tile_size = 128

# Passed to CST as `alpha` — presumably the weight of the stability term (confirm in CST).
alpha = 1

# Per-distortion parameters handed to CST, one entry per distortion type.
dist_params = dict(
    contrast=dict(lower=0.6, upper=1.6),
    color=dict(factor=[20, 0, 20]),
    # the effective blur kernel size is 'kernel_size * 2 + 1'
    blur=dict(kernel_size=2, sigma=5.0),
    brightness=dict(max_delta=0.3),
)

# Restore the saved Keras network from its HDF5 checkpoint.
model = tf.keras.models.load_model(model_path)

# Wrap the restored network for contrastive stability training, using the
# tile size, distortion parameters and alpha configured above.
cst = CST.ContrastiveStabilityTraining(
    model=model, tile_size=tile_size, dist_params=dist_params, alpha=alpha
)

# Print the layer summary of the model built by CST.
cst.cst_model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 128, 128, 3)       0         
_________________________________________________________________
sequential (Sequential)      (None, 1)                 22859809  
Total params: 22,859,809
Trainable params: 22,825,377
Non-trainable params: 34,432
_________________________________________________________________


#### Compile the CST model

In [4]:
# Adam with the AMSGrad variant. `learning_rate` is the supported keyword;
# the old `lr` alias is deprecated in TF2 and rejected by newer Keras releases.
opt = tf.keras.optimizers.Adam(learning_rate=1e-4, amsgrad=True)

# Custom metrics from the local `metrics` module, reported during CST training.
metrics = [recall_m, precision_m, f1_m, auc_m]

cst.compile_cst(optimizer=opt, metrics=metrics)

#### Create the data generators and train the network

In [5]:
# generator parameters
data_path = "../data/histo"
batch_size = 64
# Fixed seed for the training flow's shuffling, so the batch order is reproducible.
shuffle_seed = 42

# train parameters
save_all_epochs = True
model_save_path = "../models"
model_name = "cst_model_name"
save_metrics = True
epochs = 10


# One generator serves both subsets. `validation_split=0.2` reserves 20% of the
# images under `data_path` for the 'validation' subset. Every image is passed
# through `utils.normalize_image`.
gen = tf.keras.preprocessing.image.ImageDataGenerator(
    validation_split=0.2,
    preprocessing_function=utils.normalize_image
)

# Arguments shared by the training and validation flows (previously duplicated).
flow_kwargs = dict(
    directory=data_path,
    target_size=(tile_size, tile_size),
    color_mode="rgb",  # 'rgb' for color tiles
    batch_size=batch_size,
    class_mode="binary",  # 'binary' for two classes; 'sparse' (int) or 'categorical' (one-hot) for multiclass
)

t_flow = gen.flow_from_directory(
    subset="training",
    shuffle=True,
    seed=shuffle_seed,
    **flow_kwargs
)

# Validation is not shuffled, so batches follow the file order — presumably so
# predictions line up with v_flow.classes / v_flow.filenames.
v_flow = gen.flow_from_directory(
    subset="validation",
    shuffle=False,
    **flow_kwargs
)

# Per-class weights computed from the training labels by the project helper
# (presumably to offset class imbalance — see utils.get_class_weights).
class_weight = utils.get_class_weights(t_flow.classes)

Found 41770 images belonging to 2 classes.
Found 10442 images belonging to 2 classes.


In [7]:
# Run CST training on the training flow, validating on the held-out flow.
# All options come from the parameters cell above.
cst.train_cst(
    # data
    x=t_flow,
    validation_data=v_flow,
    # optimisation
    epochs=epochs,
    class_weight=class_weight,
    # saving options
    model_save_path=model_save_path,
    model_name=model_name,
    save_all_epochs=save_all_epochs,
    save_metrics=save_metrics,
)

Epoch 1/10

Process ForkPoolWorker-8:
Process ForkPoolWorker-6:
Process ForkPoolWorker-4:
Process ForkPoolWorker-7:
Process ForkPoolWorker-2:
Traceback (most recent call last):
Process ForkPoolWorker-5:
Process ForkPoolWorker-3:
Process ForkPoolWorker-1:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
T

KeyboardInterrupt: 