In [None]:
import numpy as np
import os
import warnings
from datetime import datetime
import mlflow
from dotenv import load_dotenv
import tensorflow as tf
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report
from tensorflow import keras 

# Suppress all Python warnings for the rest of the session.
warnings.filterwarnings(action='ignore')

# Pull settings from the .env file in the parent directory.
load_dotenv('../.env')

# Seed handed to the dataset loaders.
RSEED = 123

# Output folders, both located next to the notebook's parent directory.
MODELS_DIR = os.path.join(os.pardir, 'models')
MODEL_CHECKPOINTS_DIR = os.path.join(os.pardir, 'model_checkpoints')

# Timestamp suffix (for example '-2024-01-31-17:05:09') appended to checkpoint
# and saved-model names.  '%T' is a platform-specific strftime extension, so
# the portable equivalent '%H:%M:%S' is spelled out instead.
start_time = datetime.now().strftime('-%Y-%m-%d-%H:%M:%S')

In [None]:
# Make sure the output folders exist.  makedirs(exist_ok=True) replaces the
# racy "check, then mkdir" pair, and the loop variable no longer shadows the
# built-in `dir`.
for output_dir in (MODELS_DIR, MODEL_CHECKPOINTS_DIR):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Training hyperparameters
batch_size = 64
patience = 10
min_delta = 0.001
dropout_rate = 0.25
initial_learning_rate = 0.0005

# Compact tag that encodes the hyperparameters; reused in every run name.
run_name_params = '_'.join((
    f'bs{batch_size}',
    f'pat{patience}',
    f'del{min_delta}',
    f'dr{dropout_rate}',
    f'lr{initial_learning_rate}',
))

# Name of the parent MLflow run; also the stem of checkpoint/model file names.
parent_run_name = 'mobilenetv2_' + run_name_params + '_save'

In [None]:
# Set up gcloud TPUs
# Resolve the TPU cluster (tpu='local' -- presumably the TPU attached to this
# machine; confirm against the deployment), connect to it and initialise the
# TPU system before the distribution strategy is created.
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
# Distribution strategy; the model is built and compiled inside strategy.scope().
strategy = tf.distribute.TPUStrategy(cluster_resolver)


In [None]:
# MLflow metadata attached to the runs
dataset = 'full_data_stratified'
run_description = """
Fully trained models 
    - classify each slice by tumour/tissue regions in the segmentation
    - Uses MobileNetV2
    - Saves model at end of run
"""

# Connection settings come from the environment; each one is applied only
# when it is set, otherwise MLflow keeps its current configuration.
mlflow_tracking_uri = os.environ.get('MLFLOW_URI')
mlflow_expt = os.environ.get('CLASSIFICATION_EXPT')

if mlflow_tracking_uri:
    mlflow.set_tracking_uri(mlflow_tracking_uri)
if mlflow_expt:
    mlflow.set_experiment(mlflow_expt)

print('Logging to \n URI:{}\n Expt:{}'.format(mlflow_tracking_uri, mlflow_expt))


In [None]:


with mlflow.start_run(
    run_name=parent_run_name,
    tags={'dataset': dataset},
    description=run_description,
):

    # Input image geometry and location of the training slices.
    img_height = 240
    img_width = 240
    data_dir = os.path.join('..','data','UPENN-GBM','slice_classification_common_stratify','train')

    # Arguments common to the training and validation loaders: both read the
    # same directory with the same seed and an 80/20 split.
    split_kwargs = dict(
        validation_split=0.2,
        color_mode="rgba",
        seed=RSEED,
        image_size=(img_height, img_width),
        batch_size=batch_size,
    )
    train_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir, subset="training", **split_kwargs)
    val_ds = tf.keras.utils.image_dataset_from_directory(
        data_dir, subset="validation", **split_kwargs)

    class_names = train_ds.class_names
    num_classes = len(class_names)

    # Collect every training label once and derive 'balanced' class weights,
    # later passed to model.fit(class_weight=...).
    ds_classes = np.concatenate([labels.numpy() for _, labels in train_ds])
    present_classes = np.unique(ds_classes)
    balanced_weights = compute_class_weight(
        class_weight='balanced',
        classes=present_classes,
        y=ds_classes,
    )
    # Map class index -> weight.
    class_weight = dict(zip(present_classes, balanced_weights))

    # Cache the decoded batches and prefetch; only the training set is
    # reshuffled.
    AUTOTUNE = tf.data.AUTOTUNE
    val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
    train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)

    # Image size left after cropping `margin` pixels from every edge.
    margin = 8
    scaled_height = img_height - 2*margin
    scaled_width = img_width - 2*margin

    # Build the model with a frozen base:
    # crop -> rescale -> 4-to-3 channel conv -> MobileNetV2 -> global average
    # pool -> dropout -> dense logits.
    with strategy.scope():
        # Trim `margin` pixels from each edge of the input image.
        crop_layer = tf.keras.layers.Cropping2D(margin)
        # Computes x / 127.5 - 1, i.e. maps 0..255 onto -1..1.
        rescale_initial = tf.keras.layers.Rescaling(1./127.5, offset=-1)
        # 1x1 convolution reducing the 4 input channels to the 3 channels the
        # base model is built for.
        conv_4to3_channel = tf.keras.layers.Conv2D(3,1,padding='same',activation='tanh')
        # Keras image shapes are (height, width, channels), matching the
        # image_size=(img_height, img_width) used by the dataset loaders.
        base_model = tf.keras.applications.MobileNetV2(
            input_shape=(scaled_height, scaled_width, 3),
            include_top=False,
            weights='imagenet'
        )
        global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
        # No activation: the model emits logits (the loss uses from_logits=True).
        prediction_layer = tf.keras.layers.Dense(num_classes)

        # Freeze the pretrained base for the first training stage.
        base_model.trainable = False

        inputs = tf.keras.Input(shape=(img_height, img_width, 4))
        x = crop_layer(inputs)
        x = rescale_initial(x)
        x = conv_4to3_channel(x)
        # The base model is always called in inference mode, even after its
        # layers are made trainable in the later stages.
        x = base_model(x, training=False)
        x = global_average_layer(x)
        x = tf.keras.layers.Dropout(dropout_rate)(x)
        outputs = prediction_layer(x)

        # Stop a training stage once val_loss has not improved by at least
        # `min_delta` for `patience` epochs.  The same callback is reused by
        # every stage.
        # NOTE(review): restore_best_weights is not set, so after an early
        # stop the model presumably keeps the last epoch's weights rather
        # than the best ones -- confirm this is intended.
        earlystopping = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=patience,
            min_delta=min_delta,
        )

        model = tf.keras.Model(inputs, outputs)
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=initial_learning_rate),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=['accuracy'],
        )

        
    # Stage 1: fit the classifier head and the 4-to-3 channel convolution
    # while the base model stays frozen.  Logged as a nested MLflow run.
    with mlflow.start_run(
        run_name='fixed_' + run_name_params,
        nested=True,
        tags={'dataset': dataset},
    ):
        mlflow.tensorflow.autolog()
        # Training and validation share the same batch size.
        for param_name in ('ds_batch_size', 'ds_validation_batch_size'):
            mlflow.log_param(param_name, batch_size)

        # Upper bound on epochs; early stopping may end the stage sooner.
        fixed_base_epochs = 80
        history_fixed_base = model.fit(
            train_ds,
            epochs=fixed_base_epochs,
            validation_data=val_ds,
            callbacks=[earlystopping],
            class_weight=class_weight,
        )

    # Stage 2: unfreeze the base model, then re-freeze its first
    # `fix_below_layer` layers so only the upper layers are fine-tuned.
    fix_below_layer = 100
    base_model.trainable = True
    for frozen_layer in base_model.layers[:fix_below_layer]:
        frozen_layer.trainable = False

    # Recompile after changing `trainable`, with a tenth of the initial
    # learning rate.
    with strategy.scope():
        model.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=tf.keras.optimizers.Adam(learning_rate=initial_learning_rate/10.0),
            metrics=['accuracy'],
        )

    with mlflow.start_run(
        run_name='partial_' + run_name_params,
        nested=True,
        tags={'dataset': dataset},
    ):
        mlflow.tensorflow.autolog()
        for param_name in ('ds_batch_size', 'ds_validation_batch_size'):
            mlflow.log_param(param_name, batch_size)

        # Continue the epoch numbering from where stage 1 stopped.  `epochs`
        # is the index of the final epoch, so with initial_epoch = last + 1
        # this stage runs for at most 99 epochs.
        last_fixed_epoch = history_fixed_base.epoch[-1]
        partial_relax_epochs = last_fixed_epoch + 100
        history_partial_relax = model.fit(
            train_ds,
            validation_data=val_ds,
            initial_epoch=last_fixed_epoch + 1,
            epochs=partial_relax_epochs,
            callbacks=[earlystopping],
            class_weight=class_weight,
        )

    # Stage 3: mark the whole model trainable and fine-tune it end to end.
    model.trainable = True

    # Recompile after changing `trainable`; same reduced learning rate as the
    # previous stage.
    with strategy.scope():
        model.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=tf.keras.optimizers.Adam(learning_rate=initial_learning_rate/10.0),
            metrics=['accuracy'],
        )

    with mlflow.start_run(
        run_name='relax_' + run_name_params,
        nested=True,
        tags={'dataset': dataset},
    ):
        mlflow.tensorflow.autolog()
        for param_name in ('ds_batch_size', 'ds_validation_batch_size'):
            mlflow.log_param(param_name, batch_size)

        # Checkpoint the full model whenever val_loss reaches a new minimum.
        # The {epoch} and {val_loss} fields of the file name are filled in by
        # the callback.
        checkpoint_name = parent_run_name + start_time + "-{epoch:03d}-{val_loss:.4f}.ckpt"
        checkpoint_path = os.path.join(MODEL_CHECKPOINTS_DIR, checkpoint_name)
        ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_path,
            monitor='val_loss',
            mode='min',
            save_best_only=True,
            save_weights_only=False,
            save_freq='epoch',
            verbose=1,
        )

        # Continue the epoch numbering from where stage 2 stopped.
        last_partial_epoch = history_partial_relax.epoch[-1]
        full_relax_epochs = last_partial_epoch + 100
        history_full_relax = model.fit(
            train_ds,
            validation_data=val_ds,
            initial_epoch=last_partial_epoch + 1,
            epochs=full_relax_epochs,
            callbacks=[earlystopping, ckpt_callback],
            class_weight=class_weight,
        )


In [None]:
# Held-out test slices, loaded with the same image size, colour mode and
# batch size as the training data.  shuffle=False keeps the batch order fixed.
test_data_dir = os.path.join('..','data','UPENN-GBM','slice_classification_common_stratify','test')

test_ds = tf.keras.utils.image_dataset_from_directory(
    directory=test_data_dir,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    color_mode="rgba",
    shuffle=False,
    seed=RSEED,
)

In [None]:
# Predict on the validation split.  The model emits logits, so softmax turns
# them into class probabilities.
val_pred = model.predict(val_ds)
val_prob = tf.nn.softmax(val_pred)
# Vectorised argmax over the class axis instead of a per-row Python loop.
val_class_pred = np.argmax(val_prob.numpy(), axis=1).tolist()
# Baseline that always predicts class 0.
val_base = [0] * len(val_class_pred)

# True labels as plain ints.  This second pass must see the batches in the
# same order as model.predict() above; val_ds is cached, so the order is
# fixed once the cache has been filled during training.
val_true_class = np.concatenate(
    [labels.numpy() for _, labels in val_ds]
).tolist()




In [None]:
# Predict on the test set.  The model emits logits, so softmax turns them
# into class probabilities.
test_pred = model.predict(test_ds)
test_prob = tf.nn.softmax(test_pred)
# Vectorised argmax over the class axis instead of a per-row Python loop.
test_class_pred = np.argmax(test_prob.numpy(), axis=1).tolist()
# Baseline that always predicts class 0.
test_base = [0] * len(test_class_pred)

# True labels as plain ints.  test_ds is built with shuffle=False, so this
# second pass yields the batches in the same order as model.predict() above.
test_true_class = np.concatenate(
    [labels.numpy() for _, labels in test_ds]
).tolist()


In [None]:
# Per-class precision / recall / F1 on the validation split.
val_report = classification_report(y_true=val_true_class, y_pred=val_class_pred)
print(val_report)

In [None]:
# Per-class precision / recall / F1 on the test set.
test_report = classification_report(y_true=test_true_class, y_pred=test_class_pred)
print(test_report)

In [None]:
# Save the model in its final state under MODELS_DIR, named after the parent
# run and the notebook start timestamp.
model_file_name = os.path.join(MODELS_DIR, f'{parent_run_name}{start_time}')
model.save(model_file_name)
