In [None]:
# Enable IPython's autoreload extension. Mode 2 re-imports all loaded modules
# before each cell runs, so edits to the local `model` / `load_data` modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from PIL import Image
from sklearn.model_selection import train_test_split

from model import build_model
from load_data import get_patches_path, data_generate

# Environment report: TensorFlow build and the number of GPUs it can see.
print(tf.__version__)
gpu_devices = tf.config.list_physical_devices("GPU")
print("Num GPUs Available: ", len(gpu_devices))

# Notebook-wide configuration.
# Folder handed to get_patches_path when collecting the patch bags.
INPUT_PATCH_FOLDER = "data/train_imgs_patch"
# Side length of a (square) patch; also used in the dataset output shapes.
PATCH_WIDTH = 128
# (width, width, 3) shape passed to build_model.
# Presumably the trailing 3 is the RGB channel axis -- confirm against load_data.
PATCH_SHAPE = (PATCH_WIDTH,) * 2 + (3,)
# Divisor used when deriving steps_per_epoch / validation_steps for model.fit.
BATCH_SIZE = 1

In [None]:
# Read the clinical table and pull out the two columns used for training:
# "ID" (bag identifier) and "N_category" (the training target).
clinical_data = pd.read_csv("data/train.csv")
bag_names = clinical_data["ID"].tolist()
labels = clinical_data["N_category"].tolist()

# Collect the patch bags stored under INPUT_PATCH_FOLDER; the exact structure
# of the result is defined by load_data.get_patches_path.
patch_bags = get_patches_path(INPUT_PATCH_FOLDER)

In [None]:
# Hold out a validation subset of the bags (train_test_split's default
# train/test proportions are used).
#
# Names and labels are truncated to len(patch_bags) so that all three
# sequences have the same length, which train_test_split requires.
# NOTE(review): this assumes the i-th CSV row describes the i-th entry of
# patch_bags -- confirm against the ordering returned by get_patches_path.
#
# A fixed random_state makes the split identical on every Restart & Run All,
# so validation losses (and the checkpoints selected from them) stay
# comparable between runs.
SPLIT_SEED = 42
n_bags = len(patch_bags)

(
    train_bag_names,
    val_bag_names,
    train_y,
    val_y,
    train_bags,
    val_bags,
) = train_test_split(
    bag_names[:n_bags],
    labels[:n_bags],
    patch_bags,
    random_state=SPLIT_SEED,
)

In [None]:
def _bag_dataset(names, targets, bags):
    """Wrap ``data_generate`` in a ``tf.data.Dataset`` for one data split.

    Each element is a pair of float32 tensors: the first has shape
    ``(None, PATCH_WIDTH, PATCH_WIDTH, 3)`` and the second has shape ``(1, 1)``.
    Presumably these are the patches of one bag and that bag's label --
    confirm against ``load_data.data_generate``.

    Args:
        names: Bag names for the split, forwarded to ``data_generate``.
        targets: Labels for the split, forwarded to ``data_generate``.
        bags: Patch bags for the split, forwarded to ``data_generate``.

    Returns:
        The generator-backed ``tf.data.Dataset``.
    """
    return tf.data.Dataset.from_generator(
        generator=data_generate,
        output_types=(tf.float32, tf.float32),
        output_shapes=(
            tf.TensorShape([None, PATCH_WIDTH, PATCH_WIDTH, 3]),
            tf.TensorShape([1, 1]),
        ),
        args=(names, targets, bags),
    )


train_dataset = _bag_dataset(train_bag_names, train_y, train_bags)
val_dataset = _bag_dataset(val_bag_names, val_y, val_bags)

In [None]:
# Construct the network from PATCH_SHAPE (presumably the per-patch input
# shape -- see model.build_model) and print its summary.
model = build_model(PATCH_SHAPE)
model.summary()

In [None]:
# Make sure the checkpoint directory exists before the first save.
os.makedirs("check_points", exist_ok=True)

# Filename template for saved weights. The {placeholders} are filled in at
# save time from the epoch number and that epoch's logged metrics.
# The suffix is ".weights.h5" rather than the original ".hd5": ".hd5" is not a
# recognised HDF5 extension, so TF2 Keras silently wrote TF-checkpoint shards
# under that name and Keras 3 rejects it outright for weights-only
# checkpoints. ".weights.h5" is accepted by both.
model_name = (
    "check_points/"
    + "acc({accuracy:.4f})"
    + "epoch({epoch})"
    + "val_loss({val_loss:.4f}).weights.h5"
)

# Save weights only, and only when val_loss reaches a new minimum.
check_point = tf.keras.callbacks.ModelCheckpoint(
    model_name,
    monitor="val_loss",
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
    mode="min",
)

# Stop training once val_loss has not improved for EARLY_STOPPING_PATIENCE epochs.
EARLY_STOPPING_PATIENCE = 5
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=EARLY_STOPPING_PATIENCE
)

# Multiply the learning rate by 0.2 when val_loss plateaus.
# The patience must be shorter than the early-stopping patience: with the
# Keras default (10 epochs) training would normally be stopped before the
# learning rate was ever reduced, making this callback a no-op.
REDUCE_LR_PATIENCE = 2
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.2,  # (=decay)
    patience=REDUCE_LR_PATIENCE,
    verbose=True,
)
callbacks = [check_point, early_stopping, reduce_lr]

# Adam with a 1e-3 learning rate; binary cross-entropy applied to the model's
# outputs as probabilities (from_logits=False), with accuracy tracked as a metric.
adam = tf.keras.optimizers.Adam(0.001)
bce_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)

model.compile(optimizer=adam, loss=bce_loss, metrics=["accuracy"])

# Both datasets are repeated indefinitely, so Keras needs explicit step counts
# to know where an epoch (and a validation pass) ends.
train_steps = len(train_bag_names) // BATCH_SIZE
val_steps = len(val_bag_names) // BATCH_SIZE

model.fit(
    train_dataset.repeat(),
    epochs=100,
    steps_per_epoch=train_steps,
    validation_data=val_dataset.repeat(),
    validation_steps=val_steps,
    callbacks=callbacks,
)