In [1]:
import os
import math
import random
import argparse
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import segmentation_models as sm
from time import gmtime, strftime

2022-11-08 03:44:21.726814: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-08 03:44:21.817376: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2022-11-08 03:44:21.836540: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Segmentation Models: using `keras` framework.


In [2]:
def split_train_test(tfrecords_pattern, rate=0.8, buffer_size=10000):
    filenames = tf.io.gfile.glob(tfrecords_pattern)
    random.shuffle(filenames)
    split_idx = int(len(filenames) * rate)

    return filenames[:split_idx], filenames[split_idx:]

In [3]:
def decode_image(image):
    image = tf.io.decode_png(image, 1, dtype=tf.dtypes.uint16)
    return image

In [4]:
def read_tfrecord(example):
    features_description = {
        "filename": tf.io.FixedLenFeature([], tf.string),
        "number": tf.io.FixedLenFeature([], tf.int64),
        "sample": tf.io.FixedLenFeature([], tf.int64),
        "image_raw": tf.io.FixedLenFeature([], tf.string),
        "mask_raw": tf.io.FixedLenFeature([], tf.string)
    }

    example = tf.io.parse_single_example(example, features_description, name="nii")
    image = decode_image(example["image_raw"])
    mask = tf.cast(decode_image(example["mask_raw"]), dtype=tf.float32)
    sample = tf.cast(example["sample"], tf.bool)
    
    return (image, mask), sample


In [5]:
def load_dataset(filenames):
    dataset = tf.data.TFRecordDataset(
        filenames,
        compression_type="GZIP"
    )
    dataset = dataset.map(read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
    
    for n_num, _ in enumerate(dataset):
        pass
    
    dataset = dataset.shuffle(2048, reshuffle_each_iteration=False)

    return dataset, n_num

In [6]:
def get_dataset(filenames, batch=4, repeat=False):
    dataset, n_num = load_dataset(filenames)
    dataset = dataset.filter(lambda x, s: s)
    dataset = dataset.map(lambda x, s: x)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch)
    if repeat:
        dataset = dataset.repeat()

    return dataset, n_num

In [7]:
def show_batch(image_batch):
    plt.figure(figsize=(10, 10))
    for n in range(len(image_batch)):
        ax = plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n])
        plt.axis("off")

In [8]:
def main(args):
    # Load Dataset
    x_list, y_list = split_train_test(
        os.path.join(args.dataset, "*.tfrecord"),
        rate=args.train_rate
    )

    x_dataset, x_num = get_dataset(x_list, batch=args.batch, repeat=True)    
    y_dataset, _ = get_dataset(y_list, batch=args.batch)
    
    # Build Model
    model = sm.Unet(args.backbone, encoder_weights=None, input_shape=(None, None, 1))
    
    dice_loss = sm.losses.DiceLoss()
    focal_loss = sm.losses.BinaryFocalLoss()
    totol_loss = dice_loss + (1 * focal_loss)
    
    model.compile(
        keras.optimizers.Adam(args.lr),
        loss=totol_loss,
        metrics=[sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)],
    )
    
    # current_time = strftime('%Y%m%d%H%M%S', gmtime())
    current_time = "20221107165659"
    logdir = os.path.join(args.logs, f"{args.name}-{current_time}")
    callbacks = [
        keras.callbacks.TensorBoard(log_dir=logdir),
        keras.callbacks.ModelCheckpoint(os.path.join(logdir, f"ICH-{args.name}-"+"{epoch}.h5"), save_weights_only=True, save_best_only=False, mode='min'),
        keras.callbacks.ReduceLROnPlateau(),
    ]

    # Training
    history = model.fit(
        x_dataset,
        epochs=args.epoch,
        steps_per_epoch=int(math.ceil(1. * x_num) / args.batch),
        callbacks=callbacks,
        validation_data=y_dataset,
        initial_epoch=31
    )

    print("fitted")
    
    # Plot training & validation iou_score values
    plt.figure(figsize=(30, 5))
    plt.subplot(121)
    plt.plot(history.history['iou_score'])
    plt.plot(history.history['val_iou_score'])
    plt.title('Model iou_score')
    plt.ylabel('iou_score')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Test'], loc='upper left')

    # Plot training & validation loss values
    plt.subplot(122)
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Test'], loc='upper left')
    plt.savefig("training.png")

In [9]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--name",
        default="ICH420",
        help="Training Name"
    )
    parser.add_argument(
        "--backbone",
        default="resnet101",
        help="Model Backbone"
    )
    parser.add_argument(
        "--batch",
        default=16,
        help="Batch Size",
        type=int
    )
    parser.add_argument(
        "--epoch",
        default=10,
        help="Training Epoch",
        type=int
    )
    parser.add_argument(
        "--lr",
        default=0.001,
        help="Steps per Epoch",
        type=float
    )
    parser.add_argument(
        "--dataset",
        default=os.path.join(os.getcwd(), "datasets/ICH_420/TFRecords/train"),
        help="/path/to/dataset"
    )
    parser.add_argument(
        "--train_rate",
        default=0.8,
        help="Use to split dataset to 'train' and 'valid'",
        type=float
    )
    parser.add_argument(
        "--logs",
        default=os.path.join(os.getcwd(), "logs"),
        help="/path/to/logs"
    )
    main(parser.parse_args([
        "--name", "ICH420",
        "--backbone", "resnet101",
        "--batch", "16",
        "--epoch", "100",
        "--lr", "0.001",
        "--dataset", os.path.join(os.getcwd(), "datasets/ICH_420/TFRecords/train"),
        "--train_rate", "0.8",
        "--logs", os.path.join(os.getcwd(), "logs")
    ]))

2022-11-08 03:44:23.555084: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-08 03:44:23.556761: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-08 03:44:23.556843: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:980] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-08 03:44:23.557125: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compi

Epoch 32/100


2022-11-08 03:44:38.034834: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8100
2022-11-08 03:44:38.901720: I tensorflow/stream_executor/cuda/cuda_blas.cc:1614] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 79/100
Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
  81/1056 [=>............................] - ETA: 9:55 - loss: 0.0090 - iou_score: 0.9843 - f1-score: 0.9921

KeyboardInterrupt: 