In [252]:
import os
import math
import random
import argparse
import numpy as np
import tensorflow as tf
from tensorflow import keras
from pptx import Presentation
from pptx.util import Inches
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import segmentation_models as sm
from time import gmtime, strftime
from PIL import Image
from matplotlib.backends.backend_agg import FigureCanvasAgg

In [253]:
def split_train_test(tfrecords_pattern, rate=0.8, buffer_size=10000):
    filenames = tf.io.gfile.glob(tfrecords_pattern)
    random.shuffle(filenames)
    split_idx = int(len(filenames) * rate)

    return filenames[:split_idx], filenames[split_idx:]

In [254]:
def decode_image(image):
    image = tf.io.decode_png(image, 1, dtype=tf.dtypes.uint16)
    return image

In [255]:
def read_tfrecord(example):
    features_description = {
        "filename": tf.io.FixedLenFeature([], tf.string),
        "number": tf.io.FixedLenFeature([], tf.int64),
        "sample": tf.io.FixedLenFeature([], tf.int64),
        "image_raw": tf.io.FixedLenFeature([], tf.string),
        "mask_raw": tf.io.FixedLenFeature([], tf.string)
    }

    example = tf.io.parse_single_example(example, features_description, name="nii")
    image = decode_image(example["image_raw"])
    mask = tf.cast(decode_image(example["mask_raw"]), dtype=tf.float32)
    sample = tf.cast(example["sample"], tf.bool)
    
    return (image, mask), sample


In [256]:
def load_dataset(filenames):
    dataset = tf.data.TFRecordDataset(
        filenames,
        compression_type="GZIP"
    )
    dataset = dataset.map(read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
    
    for n_num, _ in enumerate(dataset):
        pass
    
    dataset = dataset.shuffle(2048, reshuffle_each_iteration=False)

    return dataset, n_num

In [257]:
def get_dataset(filenames, batch=4, repeat=False):
    dataset, n_num = load_dataset(filenames)
    # dataset = dataset.filter(lambda x, s: s)
    dataset = dataset.map(lambda x, s: x)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch)
    if repeat:
        dataset = dataset.repeat()

    return dataset, n_num

In [258]:
def show_batch(image_batch):
    plt.figure(figsize=(10, 10))
    for n in range(len(image_batch)):
        ax = plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n])
        plt.axis("off")

In [259]:
def compute_iou(gt_mask, pr_mask):
    inter = np.sum(gt_mask * pr_mask)
    union = np.sum(gt_mask) + np.sum(pr_mask) - inter
    return (inter / union), f"({np.sum(gt_mask)} x {np.sum(pr_mask)}) / ({np.sum(gt_mask)} + {np.sum(pr_mask)} - {np.sum(inter)})"
    # gt_mask = gt_tensor.numpy().T[0]
    # print(np.max(gt_mask), np.max(pr_mask.astype(np.bool_)))

In [260]:
# helper function for data visualization
def visualize(et="", fn=f"tmp-{''.join(random.sample('zyxwvutsrqponmlkjihgfedcba0123456789',6))}.png", **images):
    """PLot images in one row."""
    num_of_images = len(images)
    
    iou_score, message = compute_iou(images["gt_mask"], images["pr_mask"])     
    
    fig = plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        fig.suptitle(f"IoU: {iou_score}, {message}")
        plt.subplot(1, num_of_images, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        if name == "image":
            plt.imshow(image, cmap=plt.cm.gray)
        else:
            plt.imshow(image, cmap=plt.cm.binary_r)
        # plt.show()
        if i == 2:
            plt.savefig(os.path.join("/workspaces/Intracranial-Hemorrhage/ICH-Segmentation/output", f"{name}-{i}-{et}-{fn}"))
            plt.close()

In [261]:
# helper function for data visualization
def get_plot_image(**images):
    """PLot images in one row."""
    num_of_images = len(images)
    
    iou_score, message = compute_iou(images["gt_mask"], images["pr_mask"])     
    
    fig = plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        fig.suptitle(f"IoU: {iou_score}, {message}")
        plt.subplot(1, num_of_images, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        if name == "image":
            plt.imshow(image, cmap=plt.cm.gray)
        else:
            plt.imshow(image, cmap=plt.cm.binary_r)
        # # plt.show()
        # if i == 2:
        #     plt.savefig(os.path.join("/workspaces/Intracranial-Hemorrhage/ICH-Segmentation/output", f"{name}-{i}-{et}-{fn}"))
        #     plt.close()
    figure = plt.gcf().canvas
    ag = figure.switch_backends(FigureCanvasAgg)
    ag.draw()
    plt.close()
    return np.asarray(ag.buffer_rgba())

In [262]:
def main(args):
    # Load Dataset
    x_list, _ = split_train_test(
        os.path.join(args.dataset, "*.tfrecord"),
        rate=args.train_rate
    )

    x_dataset, x_num = get_dataset(x_list, batch=args.batch)    
    # y_dataset, _ = get_dataset(y_list, batch=args.batch)
    
    # Build Model
    model = sm.Unet(args.backbone, encoder_weights=None, input_shape=(None, None, 1))
    
    # dice_loss = sm.losses.DiceLoss()
    # focal_loss = sm.losses.BinaryFocalLoss()
    # totol_loss = dice_loss + (1 * focal_loss)
    
    model.load_weights("/workspaces/Intracranial-Hemorrhage/ICH-Segmentation/logs/ICH420-20221107165659/ICH-ICH420-50.h5")
    # model.compile(
    #     keras.optimizers.Adam(args.lr),
    #     loss=totol_loss,
    #     metrics=[sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)],
    # )
    
    prs = Presentation()
    title_slide_layout = prs.slide_layouts[0]
    slide = prs.slides.add_slide(title_slide_layout)
    title = slide.shapes.title
    title.text = "ICH 420 Paired"
    
    for idx, (image, mask) in tqdm(enumerate(x_dataset), total=x_num):
        if idx > 200:
            break
    
        slide = prs.slides.add_slide(prs.slide_layouts[1])
        shapes = slide.shapes
        title_shape = shapes.title
        title_shape.text = f"ICH420 _ {str(idx).zfill(6)}.png"
        
        result = model.predict(image)
        plot_image = get_plot_image(
            image=image[0],
            gt_mask=mask[..., 0].numpy().squeeze() > 0,
            pr_mask=result[..., 0].squeeze() > 0.5
        )
        plot_image = Image.fromarray(plot_image)
        plot_image.save("/tmp/image.png")
        shapes.add_picture("/tmp/image.png", Inches(0), Inches(2.5), Inches(10))
        
    prs.save("pythonppt.pptx")
            
        
    # current_time = strftime('%Y%m%d%H%M%S', gmtime())
    # logdir = os.path.join(args.logs, f"{args.name}-{current_time}")
    # callbacks = [
    #     keras.callbacks.TensorBoard(log_dir=logdir),
    #     keras.callbacks.ModelCheckpoint(os.path.join(logdir, f"ICH-{args.name}-"+"{epoch}.h5"), save_weights_only=True, save_best_only=False, mode='min'),
    #     keras.callbacks.ReduceLROnPlateau(),
    # ]

    # Training
    # history = model.fit(
    #     x_dataset,
    #     epochs=args.epoch,
    #     steps_per_epoch=int(math.ceil(1. * x_num) / args.batch),
    #     callbacks=callbacks,
    #     validation_data=y_dataset
    # )

    # print("fitted")
    
    # # Plot training & validation iou_score values
    # plt.figure(figsize=(30, 5))
    # plt.subplot(121)
    # plt.plot(history.history['iou_score'])
    # plt.plot(history.history['val_iou_score'])
    # plt.title('Model iou_score')
    # plt.ylabel('iou_score')
    # plt.xlabel('Epoch')
    # plt.legend(['Train', 'Test'], loc='upper left')

    # # Plot training & validation loss values
    # plt.subplot(122)
    # plt.plot(history.history['loss'])
    # plt.plot(history.history['val_loss'])
    # plt.title('Model loss')
    # plt.ylabel('Loss')
    # plt.xlabel('Epoch')
    # plt.legend(['Train', 'Test'], loc='upper left')
    # plt.savefig("training.png")

In [263]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--name",
        default="ICH420",
        help="Training Name"
    )
    parser.add_argument(
        "--backbone",
        default="resnet101",
        help="Model Backbone"
    )
    parser.add_argument(
        "--batch",
        default=16,
        help="Batch Size",
        type=int
    )
    parser.add_argument(
        "--epoch",
        default=10,
        help="Training Epoch",
        type=int
    )
    parser.add_argument(
        "--lr",
        default=0.001,
        help="Steps per Epoch",
        type=float
    )
    parser.add_argument(
        "--dataset",
        default=os.path.join(os.getcwd(), "datasets/ICH_420/TFRecords/val"),
        help="/path/to/dataset"
    )
    parser.add_argument(
        "--train_rate",
        default=0.8,
        help="Use to split dataset to 'train' and 'valid'",
        type=float
    )
    parser.add_argument(
        "--logs",
        default=os.path.join(os.getcwd(), "logs"),
        help="/path/to/logs"
    )
    main(parser.parse_args([
        "--name", "ICH420",
        "--backbone", "resnet101",
        "--batch", "1",
        "--epoch", "100",
        "--lr", "0.001",
        "--dataset", os.path.join(os.getcwd(), "datasets/ICH_420/TFRecords/val"),
        "--train_rate", "1",
        "--logs", os.path.join(os.getcwd(), "logs")
    ]))

  0%|          | 0/2485 [00:00<?, ?it/s]



  return (inter / union), f"({np.sum(gt_mask)} x {np.sum(pr_mask)}) / ({np.sum(gt_mask)} + {np.sum(pr_mask)} - {np.sum(inter)})"


