In [None]:
!pip install tf-nightly

In [None]:
import os
import sys
import json

In [None]:
# Process-wide setup for multi-worker training.
# NOTE(review): CUDA_VISIBLE_DEVICES="-1" is presumably meant to hide all GPUs from
# this process; the log recorded at the end of this notebook shows cuDNN being
# loaded, so confirm this cell runs before TensorFlow is imported.
os.environ.update(CUDA_VISIBLE_DEVICES="-1")

# Discard any TF_CONFIG left in the environment before a fresh one is written below.
os.environ.pop('TF_CONFIG', None)

# Make sure the relative path '.' is on the import path.
if sys.path.count('.') == 0:
    sys.path.insert(0, '.')

# Cluster description: two workers are listed and this process registers itself
# as worker index 0.
os.environ['TF_CONFIG'] = json.dumps(dict(
    cluster=dict(worker=['192.168.0.101:20000', '192.168.0.105:20000']),
    task=dict(type='worker', index=0),
))

In [30]:
import math
import random
import argparse
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import segmentation_models as sm
from time import gmtime, strftime

In [31]:
def split_train_test(tfrecords_pattern, rate=0.8, buffer_size=10000, seed=None):
    """Split the files matching a glob pattern into a train list and a test list.

    Args:
        tfrecords_pattern: Glob pattern handed to ``tf.io.gfile.glob``.
        rate: Fraction of the matched files, within ``[0, 1]``, that goes into
            the first (train) list; the remainder forms the second list.
        buffer_size: Unused. Kept so existing callers keep working.
        seed: Optional seed for the shuffle. With a seed, the same set of
            files always produces the same split, which lets several worker
            processes agree on it. With ``None`` (default) the global
            ``random`` state is used, as before.

    Returns:
        A ``(train_files, test_files)`` tuple of filename lists.

    Raises:
        ValueError: If ``rate`` is outside ``[0, 1]``.
    """
    if not 0.0 <= rate <= 1.0:
        raise ValueError(f"rate must be within [0, 1], got {rate!r}")

    # Sort first: the order returned by glob is not something to rely on, and
    # a seeded shuffle is only reproducible when it starts from a fixed order.
    filenames = sorted(tf.io.gfile.glob(tfrecords_pattern))

    shuffler = random if seed is None else random.Random(seed)
    shuffler.shuffle(filenames)

    split_idx = int(len(filenames) * rate)
    return filenames[:split_idx], filenames[split_idx:]

In [32]:
def decode_image(image):
    """Decode PNG-encoded bytes into a single-channel ``uint16`` tensor."""
    return tf.io.decode_png(contents=image, channels=1, dtype=tf.dtypes.uint16)

In [33]:
def read_tfrecord(example):
    """Parse one serialized example into ``((image, mask), sample)``.

    The record is expected to carry the scalar features ``filename``,
    ``number``, ``sample``, ``image_raw`` and ``mask_raw``.

    Returns:
        ``((image, mask), sample)`` where ``image`` is ``image_raw`` decoded
        by ``decode_image``, ``mask`` is ``mask_raw`` decoded the same way and
        cast to ``float32``, and ``sample`` is the ``sample`` field cast to
        ``bool``.
    """
    scalar_string = tf.io.FixedLenFeature([], tf.string)
    scalar_int = tf.io.FixedLenFeature([], tf.int64)
    schema = {
        "filename": scalar_string,
        "number": scalar_int,
        "sample": scalar_int,
        "image_raw": scalar_string,
        "mask_raw": scalar_string,
    }

    parsed = tf.io.parse_single_example(example, schema, name="nii")

    image = decode_image(parsed["image_raw"])
    raw_mask = decode_image(parsed["mask_raw"])
    mask = tf.cast(raw_mask, dtype=tf.float32)
    sample = tf.cast(parsed["sample"], tf.bool)

    return (image, mask), sample


In [34]:
def load_dataset(filenames):
    """Build the decoded, shuffled dataset for a list of GZIP TFRecord files.

    Args:
        filenames: TFRecord files to read (GZIP-compressed).

    Returns:
        A ``(dataset, n_num)`` tuple. ``dataset`` yields the
        ``((image, mask), sample)`` elements produced by ``read_tfrecord``,
        shuffled with a 2048-element buffer in an order that stays the same on
        every iteration. ``n_num`` is the total number of records in the
        files (0 when there are none).
    """
    records = tf.data.TFRecordDataset(
        filenames,
        compression_type="GZIP"
    )

    # Count the serialized records before decoding them: this yields the real
    # record count (not the last index), gives 0 instead of failing when there
    # are no records, and avoids decoding every PNG just to count.
    n_num = int(
        records.reduce(
            tf.constant(0, dtype=tf.int64),
            lambda count, _: count + 1
        ).numpy()
    )

    dataset = records.map(read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.shuffle(2048, reshuffle_each_iteration=False)

    return dataset, n_num

In [35]:
def get_dataset(filenames, batch=4, repeat=False):
    """Build the batched ``(image, mask)`` input pipeline for a list of files.

    Args:
        filenames: TFRecord files, forwarded to ``load_dataset``.
        batch: Number of elements per batch.
        repeat: When true, the batched dataset repeats indefinitely.

    Returns:
        A ``(dataset, n_num)`` tuple. ``dataset`` yields batches of
        ``(image, mask)`` for the records whose ``sample`` flag is true.
        ``n_num`` is the value reported by ``load_dataset``; it is computed
        before the ``sample`` filter is applied here, so it can be larger than
        the number of elements the dataset actually yields.
    """
    dataset, n_num = load_dataset(filenames)

    # Keep only the flagged records, then drop the flag itself.
    dataset = dataset.filter(lambda pair, keep: keep)
    dataset = dataset.map(lambda pair, keep: pair)

    dataset = dataset.batch(batch)
    if repeat:
        dataset = dataset.repeat()

    # Prefetch last so that complete batches are prepared ahead of the
    # consumer; this does not change which elements are produced.
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    return dataset, n_num

In [36]:
def show_batch(image_batch):
    """Draw every image of a batch on a square grid of axis-less subplots.

    The grid is 5x5 for up to 25 images and grows to the smallest square that
    fits the batch beyond that, so larger batches no longer fail on an
    out-of-range subplot index.

    Args:
        image_batch: Indexable collection of images accepted by ``plt.imshow``.
    """
    n_images = len(image_batch)
    side = max(5, math.ceil(math.sqrt(n_images)))

    plt.figure(figsize=(10, 10))
    for idx in range(n_images):
        plt.subplot(side, side, idx + 1)
        plt.imshow(image_batch[idx])
        plt.axis("off")

In [37]:
def _plot_history(history, output_path="training.png"):
    """Plot the IoU and loss curves of a Keras ``History`` and save the figure."""
    curves = history.history
    fig, (ax_iou, ax_loss) = plt.subplots(1, 2, figsize=(30, 5))

    # Training & validation iou_score values
    ax_iou.plot(curves['iou_score'])
    ax_iou.plot(curves['val_iou_score'])
    ax_iou.set_title('Model iou_score')
    ax_iou.set_ylabel('iou_score')
    ax_iou.set_xlabel('Epoch')
    ax_iou.legend(['Train', 'Test'], loc='upper left')

    # Training & validation loss values
    ax_loss.plot(curves['loss'])
    ax_loss.plot(curves['val_loss'])
    ax_loss.set_title('Model loss')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_xlabel('Epoch')
    ax_loss.legend(['Train', 'Test'], loc='upper left')

    fig.savefig(output_path)


def main(args):
    """Train a U-Net segmentation model on TFRecords with multi-worker mirroring.

    Args:
        args: Namespace providing ``dataset``, ``train_rate``, ``batch``,
            ``backbone``, ``lr``, ``logs``, ``name`` and ``epoch``.

    Side effects:
        Writes TensorBoard logs and per-epoch weight files below
        ``<args.logs>/<args.name>-<timestamp>`` and saves the training curves
        to ``training.png``.
    """
    # Multi Node
    # Created before any other TensorFlow work (the dataset code below executes
    # TensorFlow ops). NOTE(review): TensorFlow's distributed-training docs ask
    # for the multi-worker strategy to be created at program startup - confirm
    # against the installed version.
    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    # Load Dataset
    train_files, valid_files = split_train_test(
        os.path.join(args.dataset, "*.tfrecord"),
        rate=args.train_rate
    )

    train_dataset, train_count = get_dataset(train_files, batch=args.batch, repeat=True)
    valid_dataset, _ = get_dataset(valid_files, batch=args.batch)

    # Round up so a trailing partial batch still counts as a step, and never
    # request zero steps. (The previous expression applied ceil before the
    # division and therefore rounded down.)
    steps_per_epoch = max(1, math.ceil(train_count / args.batch))

    with strategy.scope():
        # Build Model
        model = sm.Unet(args.backbone, encoder_weights=None, input_shape=(None, None, 1))

        dice_loss = sm.losses.DiceLoss()
        focal_loss = sm.losses.BinaryFocalLoss()
        total_loss = dice_loss + (1 * focal_loss)

        model.compile(
            keras.optimizers.Adam(args.lr),
            loss=total_loss,
            metrics=[sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)],
        )

        current_time = strftime('%Y%m%d%H%M%S', gmtime())
        logdir = os.path.join(args.logs, f"{args.name}-{current_time}")
        callbacks = [
            keras.callbacks.TensorBoard(log_dir=logdir),
            keras.callbacks.ModelCheckpoint(os.path.join(logdir, f"ICH-{args.name}-"+"{epoch}.h5"), save_weights_only=True, save_best_only=False, mode='min'),
            keras.callbacks.ReduceLROnPlateau(),
        ]

        # Training
        history = model.fit(
            train_dataset,
            epochs=args.epoch,
            steps_per_epoch=steps_per_epoch,
            callbacks=callbacks,
            validation_data=valid_dataset
        )

    print("fitted")

    _plot_history(history, "training.png")

In [38]:
if __name__ == "__main__":
    # Command-line style configuration for the training run.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--name",
        default="ICH420",
        help="Training Name"
    )
    parser.add_argument(
        "--backbone",
        default="resnet101",
        help="Model Backbone"
    )
    parser.add_argument(
        "--batch",
        default=16,
        help="Batch Size",
        type=int
    )
    parser.add_argument(
        "--epoch",
        default=10,
        help="Training Epoch",
        type=int
    )
    parser.add_argument(
        "--lr",
        default=0.001,
        help="Learning Rate",  # was mislabelled "Steps per Epoch"
        type=float
    )
    parser.add_argument(
        "--dataset",
        default=os.path.join(os.getcwd(), "datasets/ICH_420/TFRecords/train"),
        help="/path/to/dataset"
    )
    parser.add_argument(
        "--train_rate",
        default=0.8,
        help="Use to split dataset to 'train' and 'valid'",
        type=float
    )
    parser.add_argument(
        "--logs",
        default=os.path.join(os.getcwd(), "logs"),
        help="/path/to/logs"
    )

    # The argument list is passed explicitly, so the process's own command line
    # (sys.argv) is not parsed.
    main(parser.parse_args([
        "--name", "ICH420",
        "--backbone", "resnet101",
        "--batch", "16",
        "--epoch", "100",
        "--lr", "0.001",
        "--dataset", os.path.join(os.getcwd(), "datasets/ICH_420/TFRecords/train"),
        "--train_rate", "0.8",
        "--logs", os.path.join(os.getcwd(), "logs")
    ]))

Epoch 1/100


2022-11-07 16:57:05.744944: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8100
2022-11-07 16:57:06.593364: I tensorflow/stream_executor/cuda/cuda_blas.cc:1614] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100