In [1]:
import os

# Runtime switches for CUDA / TensorFlow. They are set here, ahead of the
# `import tensorflow` further down, on the assumption that the libraries read
# them at start-up — TODO confirm for the installed TensorFlow version.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # number GPUs by PCI bus position
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # expose no GPU, i.e. run on CPU only
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # quiet TensorFlow's native (C++) logging

import keras
import tensorflow as tf

import sys
import json
import importlib
from pathlib import Path

sys.path.append("..")

from losses import DiceCELoss
from dataloader import get_dataset
from schedulers import WarmUpCosine
from models.vanilla import UNetEncoder, UNet, get_skip_names_from_encoder


def import_method(method_path: str):
    """Resolve a dotted path to the object it names.

    The path is split at its last dot: everything before it is imported as a
    module and the final component is looked up as an attribute of that module.
    If the first component is one of the short aliases ``np``, ``tf`` or
    ``tfa`` it is expanded to ``numpy``, ``tensorflow`` or
    ``tensorflow_addons`` before importing.

    Args:
        method_path: Dotted path of the form ``"<module path>.<attribute>"``,
            e.g. ``"keras.layers.BatchNormalization"`` or ``"np.zeros"``.

    Returns:
        The attribute named by the last path component (a function, class or
        any other module-level object).

    Raises:
        ValueError: If ``method_path`` has no module component (no dot, or
            nothing before the last dot), so there is nothing to import.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    aliases = {
        "np": "numpy",
        "tf": "tensorflow",
        "tfa": "tensorflow_addons",
    }

    module_path, _, method_name = method_path.rpartition(".")
    if not module_path:
        # Without this guard the failure surfaces inside importlib as the
        # much less helpful "Empty module name".
        raise ValueError(
            "Expected a dotted path such as 'package.module.name', "
            f"got {method_path!r}"
        )

    # Only the leading package name is eligible for alias expansion.
    root, dot, remainder = module_path.partition(".")
    module_path = aliases.get(root, root) + dot + remainder

    module = importlib.import_module(module_path)
    return getattr(module, method_name)


# Experiment configuration

In [2]:
# Experiment settings. The dict is written to <experiment_dir>/config.json
# further down in this cell and is then unpacked by the dataset, model and
# optimizer code in the following cells.
config = {
    # Output directory; config.json, the TensorBoard logs and weights.h5 are
    # all placed inside it.
    "experiment_dir": "/home/sangohe/projects/pathology/results/TCGA_patches_224_56-fold0-DiceCE",
    "train_tfrecord_path": "/home/sangohe/projects/pathology/data/TCGA_patches_224_56/fold0_train.tfrecord",
    # Alternative path for the training TFRecord (currently disabled):
    # "train_tfrecord_path": "/data/histopathology/TCGA/tfrecords/TCGA_patches_224_56/fold0_train.tfrecord",
    "val_tfrecord_path": "/home/sangohe/projects/pathology/data/TCGA_patches_224_56/fold0_val.tfrecord",
    # Alternative path for the validation TFRecord (currently disabled):
    # "val_tfrecord_path": "/data/histopathology/TCGA/tfrecords/TCGA_patches_224_56/fold0_val.tfrecord",
    # The next three values size the learning-rate schedule:
    # total_steps = int(num_train_samples / batch_size) * epochs and
    # warmup_steps = int(total_steps * warmup_epoch_percentage).
    "epochs": 100,
    "warmup_epoch_percentage": 0.1,
    # Presumably the number of examples in the training TFRecord; it is not
    # checked against the file — TODO confirm.
    "num_train_samples": 120_000,
    # Passed as keyword arguments to get_dataset for the training set. The
    # validation set reuses only "batch_size".
    "dataloader": {
        "augmentations": True,
        "filter_non_zero_prob": 0.8,
        "batch_size": 32,
        "cache": False,
        "prefetch": True,
        "shuffle_size": 100_000,
    },
    # Passed as keyword arguments to both UNetEncoder and UNet. The four
    # "*_layer" entries are dotted import paths; after the config has been
    # saved they are replaced in place by the imported objects.
    "model": {
        "input_shape": (224, 224, 3),
        "filters_per_level": [128, 256, 512, 1024, 1024],
        "activation": "relu",
        "kernel_size": 3,
        "strides": 1,
        "dilation_rate": 1,
        "padding": "same",
        "norm_layer": "keras.layers.BatchNormalization",
        "pooling_layer": "keras.layers.MaxPooling2D",
        "blocks_depth": [2, 2, 2, 2, 2],
        "dropout_rate": 0.3,
        # UNet specific.
        "num_classes": 2,
        "out_activation": "linear",  # Because DiceCE loss requires logits.
        "upsample_layer": "keras.layers.UpSampling2D",
        "attention_layer": "models.conv_layers.AdditiveCrossAttention",
    },
    # "learning_rate" is given to WarmUpCosine as both the base and the
    # warm-up learning rate; "weight_decay" is passed to AdamW.
    "optimizer": {
        "learning_rate": 0.001,
        "weight_decay": 0.004,
    },
}

# Make sure the experiment directory exists before anything is written to it.
experiment_dir = Path(config["experiment_dir"])
experiment_dir.mkdir(parents=True, exist_ok=True)

# Locations of the saved config, the TensorBoard logs and the model weights.
config_path = experiment_dir / "config.json"
logs_path = experiment_dir / "logs"
weights_path = experiment_dir / "weights.h5"

# Write the config as JSON. The context manager guarantees the file is flushed
# and closed as soon as the dump finishes, instead of leaving an open handle
# for the garbage collector to clean up.
with config_path.open("w") as config_file:
    json.dump(config, config_file, indent=4)

# Only now swap the dotted-path strings for the objects they name: this must
# happen after the dump because the imported layer classes are not JSON
# serializable.
for attr in ["norm_layer", "pooling_layer", "upsample_layer", "attention_layer"]:
    config["model"][attr] = import_method(config["model"][attr])


# Dataset

In [3]:
dataloader_options = config["dataloader"]

# Training pipeline: every dataloader option from the config is forwarded.
train_dset = get_dataset(config["train_tfrecord_path"], **dataloader_options)

# Validation pipeline: only the batch size is shared with training; all other
# get_dataset arguments are left at their defaults.
val_dset = get_dataset(
    config["val_tfrecord_path"],
    batch_size=dataloader_options["batch_size"],
)




# Create, compile and fit the model

In [4]:
# Build the encoder first; the UNet is then assembled around it using the
# skip-connection names reported by get_skip_names_from_encoder.
encoder = UNetEncoder(**config["model"])
skip_names = get_skip_names_from_encoder(encoder)
model = UNet(encoder, skip_names=skip_names, **config["model"])

# Learning-rate schedule, sized from the configured dataset size and epochs.
optimizer_config = config["optimizer"]
steps_per_epoch = int(
    config["num_train_samples"] / config["dataloader"]["batch_size"]
)
total_steps = steps_per_epoch * config["epochs"]
warmup_steps = int(total_steps * config["warmup_epoch_percentage"])
# NOTE(review): the warm-up learning rate is set to the same value as the base
# learning rate. Depending on how WarmUpCosine uses it, the warm-up phase may
# be flat — confirm this is intended.
scheduled_lrs = WarmUpCosine(
    learning_rate_base=optimizer_config["learning_rate"],
    total_steps=total_steps,
    warmup_learning_rate=optimizer_config["learning_rate"],
    warmup_steps=warmup_steps,
)
opt = tf.keras.optimizers.AdamW(scheduled_lrs, optimizer_config["weight_decay"])

model.compile(
    opt,
    loss=DiceCELoss(y_one_hot=True, reduce_batch=True, include_background=False),
)

# TensorBoard logs go to <experiment_dir>/logs; the checkpoint keeps only the
# weights of the best model seen so far in <experiment_dir>/weights.h5.
training_callbacks = [
    keras.callbacks.TensorBoard(logs_path),
    keras.callbacks.ModelCheckpoint(
        weights_path, save_best_only=True, save_weights_only=True
    ),
]

# NOTE(review): this fits a single training batch and a single validation
# batch for one epoch, while the schedule above is sized for config["epochs"]
# epochs. It looks like a smoke test of the pipeline — confirm before using it
# for a real run.
model.fit(
    train_dset.take(1),
    validation_data=val_dset.take(1),
    epochs=1,
    callbacks=training_callbacks,
    verbose=1,
)


- [x] add ModelCheckpoint and TensorBoard callbacks
- [ ] add Dice score metric
- [x] add learning-rate schedule (WarmUpCosine)