In [None]:
# Install the package that is imported below as `segmentation_models_3D` (alias `sm`).
!pip install segmentation-models-3D

In [None]:
import math, re, os, gc
import random
from functools import partial

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.experimental.numpy as tnp
import tensorflow_addons as tfa
import segmentation_models_3D as sm
from tensorflow import keras
from kaggle_datasets import KaggleDatasets
from sklearn.model_selection import train_test_split

In [None]:
import tensorflow as tf
# Look for a TPU first. The resolver needs no arguments when the TPU_NAME
# environment variable is set, which is always the case on Kaggle.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if not tpu:
    # No TPU found: fall back to TensorFlow's default distribution strategy,
    # which works on CPU and on a single GPU.
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)

AUTOTUNE = tf.data.experimental.AUTOTUNE
print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
# Let tf.data pick parallelism / prefetch buffer sizes dynamically.
AUTOTUNE = tf.data.AUTOTUNE

def seed_all(s):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy and TensorFlow with ``s`` and sets
    the ``TF_CUDNN_DETERMINISTIC`` and ``PYTHONHASHSEED`` environment
    variables.

    Args:
        s: Integer seed.
    """
    random.seed(s)
    np.random.seed(s)
    tf.random.set_seed(s)
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
    # PYTHONHASHSEED is read at interpreter start-up, so this assignment only
    # affects processes spawned after this point, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(s)

# Seed value for the run; call ``seed_all(seed)`` to apply it.
seed = 42

In [None]:
BATCH_SIZE = 1 * strategy.num_replicas_in_sync # Global batch size: one sample per replica

In [None]:
# Intensity preprocessing constants.
# PIXEL_MEAN is subtracted by zero_center(); MIN_BOUND / MAX_BOUND define the
# intensity window that normalize() rescales to [0, 1].
# NOTE(review): the bounds look like a CT window in Hounsfield units — confirm.
PIXEL_MEAN = 0.25
MIN_BOUND = -1000.0
MAX_BOUND = 400.0

def normalize(image):
    """Rescale intensities from [MIN_BOUND, MAX_BOUND] to [0, 1] and clip.

    Values below ``MIN_BOUND`` map to 0 and values above ``MAX_BOUND`` map to 1.
    The previous mask-multiply implementation zeroed everything >= 1 instead of
    saturating it at 1, which turned the brightest voxels black.

    Args:
        image: Float tensor of raw intensities.

    Returns:
        Tensor of the same shape and dtype with values in [0, 1].
    """
    image = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND)
    return tf.clip_by_value(image, 0.0, 1.0)

def zero_center(image):
    """Return ``image`` shifted down by the module-level ``PIXEL_MEAN``."""
    return image - PIXEL_MEAN

In [None]:
# Shape of one stored scan: (depth, height, width), plus the channel count the
# parser produces for the model input.
IMG_DEPTH, IMG_HEIGHT, IMG_WIDTH = 128, 512, 512
IMG_CHANNELS = 3

# Context (one value per record) features of the serialized SequenceExample.
feature_description = dict([
    ('scan_id', tf.io.FixedLenFeature([], tf.string)),
    ('height', tf.io.FixedLenFeature([], tf.int64)),
    ('width', tf.io.FixedLenFeature([], tf.int64)),
    ('num_channels', tf.io.FixedLenFeature([], tf.int64)),
    ('image', tf.io.FixedLenFeature([], tf.string)),
    ('mask', tf.io.FixedLenFeature([], tf.string)),
])

# Variable-length int64 index lists stored as sequence features.
feature_list_description = {
    key: tf.io.FixedLenSequenceFeature([], tf.int64)
    for key in (
        'annotation_indices',
        'random_sample_indices',
        'annotated_indices_corrected',
    )
}

def _parse_image_function(example_proto):
    """Decode one serialized ``SequenceExample`` into an ``(image, mask)`` pair.

    The context features (``feature_description``) carry the serialized image
    and mask tensors together with their dimensions. The sequence features
    (``feature_list_description``) are still parsed so the record schema is
    validated, but they are not needed to build the training pair.

    Args:
        example_proto: Scalar string tensor holding one serialized example.

    Returns:
        Tuple ``(img_array, mask)`` of float32 tensors with static shapes
        ``(128, 128, 128, 3)`` and ``(128, 128, 128, 1)``. The mask is binary
        (0.0 / 1.0).
    """
    context, _ = tf.io.parse_single_sequence_example(
        example_proto,
        sequence_features=feature_list_description,
        context_features=feature_description,
    )
    # 'num_channels' stores the number of slices in the scan.
    volume_shape = [context['num_channels'], context['height'], context['width']]

    img_array = tf.io.parse_tensor(context['image'], out_type=tf.float32)
    img_array = tf.reshape(img_array, shape=volume_shape)

    mask = tf.io.parse_tensor(context['mask'], out_type=tf.int8)
    mask = tf.cast(tf.reshape(mask, shape=volume_shape), tf.float32)

    # Replicate the single intensity channel three times to match the
    # 3-channel model input.
    img_array = tf.stack([img_array, img_array, img_array], axis=-1)

    # Pin the mask to a static (depth, height, width, 1) shape. The axis order
    # now matches `volume_shape`; the previous (depth, width, height) order was
    # only harmless because IMG_HEIGHT == IMG_WIDTH.
    mask = tf.reshape(mask, [IMG_DEPTH, IMG_HEIGHT, IMG_WIDTH, 1])

    # tf.image.resize treats the leading (depth) axis as the batch axis, so
    # every slice is resized in-plane independently.
    target_hw = [128, 128]
    img_array = tf.image.resize(img_array, target_hw)
    mask = tf.image.resize(mask, target_hw)

    # Fix the static shapes expected by the model.
    img_array = tf.reshape(img_array, [128, 128, 128, 3])
    mask = tf.reshape(mask, [128, 128, 128, 1])

    # Resizing interpolates the mask; re-binarize so that any positive value
    # counts as foreground.
    mask = tf.cast(mask > 0, dtype=tf.float32)

    return img_array, mask

def read_tf_dataset(storage_file_path):
    """Open GZIP-compressed TFRecord file(s) and decode every record.

    Args:
        storage_file_path: Path (or paths) accepted by ``tf.data.TFRecordDataset``.

    Returns:
        Dataset of ``(image, mask)`` pairs produced by ``_parse_image_function``.
    """
    records = tf.data.TFRecordDataset(storage_file_path, compression_type="GZIP")
    return records.map(_parse_image_function)

In [None]:
# Augmentations - Source: https://www.kaggle.com/code/sreevishnudamodaran/rsna-3d-clahe-voxels-tpu-3d-augmentations/notebook

from tensorflow_addons.image import utils as img_utils

# Augmentation hyper-parameters read by build_augmenter(). The last entry of
# every tuple is the probability of applying that transform to a sample.
FLIP = 0.5 # @params: probability
CONTRAST = (0.7, 1.3, 0.5) # @params: (min_factor, max_factor, probability)
BRIGHTNESS = (0.2, 0.4) # @params: (max_delta, probability); delta is drawn from [0, max_delta)
GAMMA = (0.8, 1.2, 0.5) # @params: (min_gamma, max_gamma, probability)
ROTATE = (20, 0.5) # @params: (max_angle_in_degrees, probability)
RANDOM_CROP = (64, 64, 0.4) # @params: (min_width, min_height, probability)
CUTOUT = ((8, 8), 8, 0.4) # @params: ((mask_dim0, mask_dim1), max_num_holes, probability)
BLUR = ([5, 5], 10, 0.4) # @params: ([filter_dim0, filter_dim1], max_sigma, probability)

def random_rotate3D(voxel, limit=90, p=0.5):
    """With probability ``p``, rotate every slice of ``voxel`` by one random angle.

    A single whole-degree angle is drawn from ``[-limit, limit)`` and handed to
    ``tfa.image.rotate`` for the full tensor, so all slices share the rotation.
    Pixels rotated in from outside the image are filled with 0.

    Args:
        voxel: Image tensor accepted by ``tfa.image.rotate``.
        limit: Bound on the rotation angle, in degrees.
        p: Probability of applying the rotation.

    Returns:
        The rotated tensor, or ``voxel`` unchanged.
    """
    if tf.random.uniform(()) < p:
        degrees = tf.random.uniform((), minval=-limit, maxval=limit,
                                    dtype=tf.int32)
        radians = tf.cast(degrees, tf.float32)*(math.pi/180)
        voxel = tfa.image.rotate(voxel,
                                 radians,
                                 interpolation='nearest',
                                 fill_mode='constant',
                                 fill_value=0.0)
    return voxel

def random_resized_crop3D(voxel, min_width, min_height, p=0.5):
    """With probability ``p``, crop a random in-plane window and resize it back.

    The same window is cut from every slice (axes 1 and 2 of ``voxel``) and the
    crop is resized to the original in-plane size with Lanczos-5 interpolation.

    ``tf.random.uniform`` excludes ``maxval`` for integers, so the bounds below
    add 1. Without that, the full extent could never be sampled, a crop could
    never touch the bottom/right edge, and the draw failed outright when
    ``min_width`` / ``min_height`` equalled the image size.

    Args:
        voxel: 4-D tensor ``(depth, height, width, channels)`` with a statically
            known height and width.
        min_width: Smallest crop width, in pixels.
        min_height: Smallest crop height, in pixels.
        p: Probability of applying the crop.

    Returns:
        Tensor with the same in-plane size as ``voxel``.
    """
    if tf.random.uniform(()) < p:
        voxel_shape = voxel.shape
        full_height, full_width = voxel_shape[1], voxel_shape[2]
        assert full_height >= min_height
        assert full_width >= min_width

        # Crop size in [min, full] inclusive.
        width = tf.random.uniform((), minval=min_width,
                                  maxval=full_width + 1,
                                  dtype=tf.int32)
        height = tf.random.uniform((), minval=min_height,
                                   maxval=full_height + 1,
                                   dtype=tf.int32)
        # Top-left corner in [0, full - size] inclusive, so the window always
        # fits and may touch any edge.
        x = tf.random.uniform((), minval=0,
                              maxval=full_width - width + 1,
                              dtype=tf.int32)
        y = tf.random.uniform((), minval=0,
                              maxval=full_height - height + 1,
                              dtype=tf.int32)
        voxel = voxel[:, y:y+height, x:x+width, :]
        voxel = tf.image.resize(voxel,
                                [full_height, full_width],
                                method='lanczos5')
    return voxel

def random_cutout3D(voxel, mask_shape=(10, 10), num_holes=20, p=0.5):
    """With probability ``p``, zero out random rectangles in every slice.

    Between 1 and ``num_holes`` (inclusive) rectangles of size ``mask_shape``
    are placed at random in-plane positions. The same rectangles are blanked in
    all slices and channels. Rectangles that overhang the border are clipped.

    The keep-mask is built from coordinate comparisons rather than by
    slice-assigning into a ``tf.Variable``: creating a variable on every call of
    a traced ``Dataset.map`` function is not supported. This also pairs each
    axis with its own extent; the previous code drew the axis-1 offset from the
    width range and sized it with ``mask_shape[1]``.

    Args:
        voxel: 4-D tensor ``(depth, height, width, channels)`` with a statically
            known height and width.
        mask_shape: ``(rows, cols)`` size of each hole.
        num_holes: Maximum number of holes.
        p: Probability of applying cutout.

    Returns:
        Tensor of the same shape as ``voxel`` with the holes set to 0.
    """
    if tf.random.uniform(()) < p:
        voxel_shape = voxel.shape
        assert voxel_shape[1] >= mask_shape[0]
        assert voxel_shape[2] >= mask_shape[1]

        # +1 because integer `maxval` is exclusive.
        holes = tf.random.uniform((), minval=1, maxval=num_holes + 1,
                                  dtype=tf.int32)
        count = tf.reshape(holes, [1])

        # One top-left corner per hole, shaped (holes, 1, 1) for broadcasting.
        tops = tf.random.uniform(count, minval=0,
                                 maxval=voxel_shape[1],
                                 dtype=tf.int32)[:, tf.newaxis, tf.newaxis]
        lefts = tf.random.uniform(count, minval=0,
                                  maxval=voxel_shape[2],
                                  dtype=tf.int32)[:, tf.newaxis, tf.newaxis]

        rows = tf.range(voxel_shape[1])[tf.newaxis, :, tf.newaxis]  # (1, H, 1)
        cols = tf.range(voxel_shape[2])[tf.newaxis, tf.newaxis, :]  # (1, 1, W)

        in_rows = tf.logical_and(rows >= tops, rows < tops + mask_shape[0])
        in_cols = tf.logical_and(cols >= lefts, cols < lefts + mask_shape[1])
        # (holes, H, W) -> (H, W): True wherever at least one hole covers.
        covered = tf.reduce_any(tf.logical_and(in_rows, in_cols), axis=0)

        keep = tf.cast(tf.logical_not(covered), voxel.dtype)
        voxel = tf.multiply(voxel, keep[tf.newaxis, :, :, tf.newaxis])
    return voxel

def _get_gaussian_kernel(sigma, filter_shape):
    """Return a 1-D Gaussian kernel of length ``filter_shape`` that sums to 1.

    The softmax of ``-x**2 / (2 * sigma**2)`` over the integer tap offsets is
    the normalized Gaussian ``exp(-x**2 / (2 * sigma**2))``.

    Args:
        sigma: Scalar tensor; its dtype becomes the kernel dtype.
        filter_shape: Integer kernel length.
    """
    first_tap = -filter_shape // 2 + 1
    end_tap = filter_shape // 2 + 1
    squared_offsets = tf.cast(tf.range(first_tap, end_tap) ** 2, sigma.dtype)
    return tf.nn.softmax(-squared_offsets / (2.0 * (sigma ** 2)))

def random_gaussian_blur3D(voxel, filter_shape=(5, 5), max_sigma=3, p=0.5):
    """With probability ``p``, blur every slice with one random Gaussian kernel.

    An integer sigma is drawn from ``[min(3, max_sigma), max_sigma]`` inclusive.
    The lower bound used to be a hard-coded 3 with an exclusive upper bound, so
    the function's own default ``max_sigma=3`` asked for an empty range and
    failed. A separable 2-D kernel is then applied to each channel
    independently with a depthwise convolution.

    Args:
        voxel: 4-D float tensor ``(depth, height, width, channels)`` with a
            statically known channel count.
        filter_shape: ``(rows, cols)`` size of the blur kernel.
        max_sigma: Largest sigma that may be drawn (positive integer).
        p: Probability of applying the blur.

    Returns:
        Tensor of the same shape and dtype as ``voxel``.
    """
    if tf.random.uniform(()) < p:
        # +1 because integer `maxval` is exclusive.
        sigma = tf.random.uniform((), minval=min(3, max_sigma),
                                  maxval=max_sigma + 1,
                                  dtype=tf.int32)
        sigma = tf.cast(sigma, voxel.dtype)
        kernel_hw = tf.constant(filter_shape)
        channels = voxel.shape[-1]

        # Outer product of the column and row kernels gives the 2-D kernel.
        kernel_x = _get_gaussian_kernel(sigma, kernel_hw[1])[tf.newaxis, :]
        kernel_y = _get_gaussian_kernel(sigma, kernel_hw[0])[:, tf.newaxis]
        kernel_2d = tf.matmul(kernel_y, kernel_x)

        # depthwise_conv2d expects (rows, cols, in_channels, multiplier).
        kernel_2d = kernel_2d[:, :, tf.newaxis, tf.newaxis]
        kernel_2d = tf.tile(kernel_2d, tf.constant([1, 1, channels, 1]))
        voxel = tf.nn.depthwise_conv2d(input=voxel,
                                       filter=kernel_2d,
                                       strides=(1, 1, 1, 1),
                                       padding="SAME",
                                       )
    return voxel

def build_augmenter(with_labels=True):
    """Build the per-sample augmentation function used with ``Dataset.map``.

    Every transform draws its random parameters once per call and applies them
    to the whole 4-D tensor, so all slices of a volume get the same transform.

    Order of operations on the image: flip, brightness, contrast, gamma,
    rotate, resized crop, cutout, blur, then NaN removal and clipping to [0, 1].

    Changes from the earlier version:

    * With labels, the geometric transforms (flip, rotate, resized crop) are
      applied to the image and the label together. Previously only the image
      was transformed, which left a segmentation mask misaligned with it.
    * Intensities are clipped to [0, 1] before the gamma adjustment. Brightness
      and contrast can push values below 0, and a fractional power of a
      negative number is NaN; those NaNs then spread through the resize and
      blur steps.
    * The three intensity draws no longer share one fixed op-level seed, which
      gave them identical random streams. Use ``tf.random.set_seed`` for
      repeatable runs.

    Args:
        with_labels: If True return ``f(voxel, label) -> (voxel, label)``,
            otherwise ``f(voxel) -> voxel``.

    Returns:
        The augmentation callable.
    """
    def _random_flip(volume):
        # Flip all slices about one of the two in-plane axes.
        if tf.random.uniform(()) < FLIP:
            if tf.random.uniform(()) < 0.5:
                volume = tf.image.flip_up_down(volume)
            else:
                volume = tf.image.flip_left_right(volume)
        return volume

    def _jitter_intensity(voxel):
        # Intensity-only transforms; never applied to labels.
        if tf.random.uniform(()) < BRIGHTNESS[1]:
            voxel = tf.image.adjust_brightness(
                voxel, tf.random.uniform((), minval=0.0,
                                         maxval=BRIGHTNESS[0]))
        if tf.random.uniform(()) < CONTRAST[2]:
            voxel = tf.image.adjust_contrast(
                voxel, tf.random.uniform((), minval=CONTRAST[0],
                                         maxval=CONTRAST[1]))
        if tf.random.uniform(()) < GAMMA[2]:
            # Keep the base of the power non-negative (see docstring).
            voxel = tf.clip_by_value(voxel, 0.0, 1.0)
            voxel = tf.image.adjust_gamma(
                voxel, tf.random.uniform((), minval=GAMMA[0],
                                         maxval=GAMMA[1]))
        return voxel

    def _rotate_and_crop(volume):
        volume = random_rotate3D(volume, limit=ROTATE[0],
                                 p=ROTATE[1])
        volume = random_resized_crop3D(volume, RANDOM_CROP[0],
                                       RANDOM_CROP[1],
                                       p=RANDOM_CROP[2])
        return volume

    def _occlude_and_blur(voxel):
        voxel = random_cutout3D(voxel, mask_shape=CUTOUT[0],
                                num_holes=CUTOUT[1],
                                p=CUTOUT[2])
        voxel = random_gaussian_blur3D(voxel,
                                       filter_shape=BLUR[0],
                                       max_sigma=BLUR[1],
                                       p=BLUR[2])
        return voxel

    def _finalize(voxel):
        # Safety net: drop any remaining NaNs, then clamp to [0, 1].
        voxel = tf.where(tf.math.is_nan(voxel),
                         tf.zeros_like(voxel), voxel)
        voxel = tf.clip_by_value(voxel, 0.0, 1.0)
        return tf.cast(voxel, tf.float32)

    def augment(voxel):
        voxel = _random_flip(voxel)
        voxel = _jitter_intensity(voxel)
        voxel = _rotate_and_crop(voxel)
        voxel = _occlude_and_blur(voxel)
        return _finalize(voxel)

    def augment_with_labels(voxel, label):
        image_channels = voxel.shape[-1]

        # Stack image and label along the channel axis so one random draw
        # moves both of them identically.
        both = tf.concat([voxel, tf.cast(label, voxel.dtype)], axis=-1)
        both = _random_flip(both)

        voxel = _jitter_intensity(both[..., :image_channels])

        both = tf.concat([voxel, both[..., image_channels:]], axis=-1)
        both = _rotate_and_crop(both)

        voxel = _occlude_and_blur(both[..., :image_channels])
        # The resize inside the crop interpolates the label; re-binarize it.
        label = tf.cast(both[..., image_channels:] > 0.5, label.dtype)
        return _finalize(voxel), label

    return augment_with_labels if with_labels else augment

In [None]:
def load_dataset(filenames, ordered=False):
    """Read GZIP-compressed TFRecord shards and decode them to ``(image, mask)``.

    Args:
        filenames: TFRecord file path(s).
        ordered: If False (default), deterministic ordering is switched off in
            the dataset options; if True, the default options are used.

    Returns:
        Unbatched dataset of pairs produced by ``_parse_image_function``.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False
    records = tf.data.TFRecordDataset(filenames,
                                      num_parallel_reads=AUTOTUNE,
                                      compression_type="GZIP")
    return records.with_options(options).map(_parse_image_function,
                                             num_parallel_calls=AUTOTUNE)

def get_training_dataset(dataset, do_aug=True, shuffleBuffer=5):
    """Turn a parsed dataset into an endlessly repeating, batched training stream.

    Args:
        dataset: Dataset of ``(image, mask)`` pairs.
        do_aug: If True, map the augmenter from ``build_augmenter`` over the data.
        shuffleBuffer: Currently unused; no shuffle step is applied.

    Returns:
        Dataset that repeats forever, batched by ``BATCH_SIZE`` and prefetched.
    """
    autotune = tf.data.experimental.AUTOTUNE

    if do_aug:
        dataset = dataset.map(build_augmenter(with_labels=True),
                              num_parallel_calls=autotune)

    return dataset.repeat().batch(BATCH_SIZE).prefetch(autotune)

def get_val_dataset(dataset, do_aug=False):
    """Batch and prefetch a parsed dataset for validation or inference.

    Args:
        dataset: Dataset of ``(image, mask)`` pairs.
        do_aug: Accepted but has no effect; no augmentation is applied.

    Returns:
        Dataset batched by ``BATCH_SIZE`` and prefetched (a single pass).
    """
    return dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)

In [None]:
from kaggle_datasets import KaggleDatasets
# Look up the GCS path of the Kaggle dataset named 'tfrecords4'; the TFRecord
# shards are globbed under this path. The bare expression displays the path.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('tfrecords4')
GCS_DS_PATH

In [None]:
# Glob patterns for the training and validation shards.
train_tf_gcs = GCS_DS_PATH + '/train*.tfrec'
val_tf_gcs = GCS_DS_PATH + '/val*.tfrec'

# Expand both patterns into concrete file lists.
train_tf_files, val_tf_files = map(tf.io.gfile.glob, (train_tf_gcs, val_tf_gcs))

print("Train TFrecord Files:", len(train_tf_files))
print("Val TFrecord Files:", len(val_tf_files))

In [None]:
# Parsed, unbatched datasets. load_dataset() defaults to ordered=False, so
# deterministic ordering is switched off for both.
train_dataset = load_dataset(train_tf_files)
val_dataset = load_dataset(val_tf_files)

In [None]:
# Sum the two characters that sit right before the 6-character ".tfrec" suffix
# of every training shard name.
# NOTE(review): presumably a two-digit per-shard record count — confirm against
# the shard naming scheme; it drives STEPS_PER_EPOCH.
NUM_FILES = sum(int(path.split('/')[-1][-8:-6]) for path in train_tf_files)

In [None]:
# Training configuration.
BACKBONE = 'vgg16'
EPOCHS = 10
# The training dataset repeats forever, so an epoch is defined by step count.
STEPS_PER_EPOCH = NUM_FILES // BATCH_SIZE
n_classes = 1
activation = 'sigmoid'

# Build and compile inside the strategy scope so the model's variables are
# created under the selected distribution strategy.
with strategy.scope():
    # `n_classes` and `activation` were defined but never passed on; wire them
    # in so the output head follows the configuration above.
    model = sm.Unet(BACKBONE,
                    input_shape=(128, 128, 128, 3),
                    classes=n_classes,
                    activation=activation,
                    encoder_weights='imagenet')
    optimizer = tf.keras.optimizers.Adam(1e-3)
    total_loss = sm.losses.binary_focal_dice_loss
    metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
    model.compile(optimizer, total_loss, metrics)

In [None]:
# Learning-rate schedule parameters.
START_LR = 1e-6   # rate at epoch 0
MAX_LR = 1e-3     # rate reached after the ramp and held during the plateau
MIN_LR = 1e-8     # floor the decay approaches
LR_RAMP = 3       # number of linear warm-up epochs
LR_SUSTAIN = 3    # number of epochs held at MAX_LR
LR_DECAY = 0.90   # per-epoch decay factor after the plateau

def lrfn(epoch):
    """Return the learning rate for ``epoch`` (0-based).

    Linear ramp from ``START_LR`` towards ``MAX_LR`` for ``LR_RAMP`` epochs,
    then ``MAX_LR`` for ``LR_SUSTAIN`` epochs, then exponential decay from
    ``MAX_LR`` towards ``MIN_LR`` with factor ``LR_DECAY`` per epoch.
    """
    if LR_RAMP > 0 and epoch < LR_RAMP:
        slope = (MAX_LR-START_LR)/(LR_RAMP*1.0)
        return slope*epoch+START_LR
    if epoch < LR_RAMP+LR_SUSTAIN:
        return MAX_LR
    epochs_decayed = epoch-LR_RAMP-LR_SUSTAIN
    return (MAX_LR-MIN_LR)*LR_DECAY**epochs_decayed+MIN_LR
    
@tf.function
def lrfn_tffun(epoch):
    """``tf.function``-wrapped version of ``lrfn``; returns ``lrfn(epoch)``."""
    return lrfn(epoch)

# Plot the learning rate that lrfn() yields for every training epoch.
fig, ax = plt.subplots(figsize=(14, 5))
rng = list(range(EPOCHS))
ax.plot(rng, [lrfn(x) for x in rng], marker='o')
ax.ticklabel_format(axis="y", style="plain")
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
plt.show()

# The scheduler calls its function once per epoch with a plain Python int.
# Passing the pure-Python `lrfn` directly returns a float and avoids routing
# the call through a tf.function, which retraced for every new epoch value and
# handed the callback a Tensor instead of a float.
lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=1)

# Keep only the checkpoint with the highest validation IoU.
model_save = tf.keras.callbacks.ModelCheckpoint('./model.h5',
                                                save_best_only=True,
                                                monitor='val_iou_score',
                                                mode='max',
                                                verbose=1)

# Stop once validation IoU has not improved for 2 consecutive epochs.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_iou_score',
                                              patience=2,
                                              mode='max',
                                              verbose=1)

In [None]:
# Print the layer-by-layer summary of the compiled model.
model.summary()

In [None]:
# Train the model. The training dataset repeats indefinitely
# (get_training_dataset calls repeat()), so `steps_per_epoch` decides where one
# epoch ends. The validation dataset is a single batched pass.
history = model.fit(
    get_training_dataset(train_dataset, do_aug=True),
    steps_per_epoch = STEPS_PER_EPOCH,
    epochs = EPOCHS,
    # Best-checkpoint saving, learning-rate schedule and early stopping.
    callbacks = [model_save,
                 lr_callback,
                 early_stop
                ],
    validation_data = get_val_dataset(val_dataset),
    verbose=1
)

In [None]:
# Reload './model.h5', the file written by the ModelCheckpoint callback
# (save_best_only=True, monitoring val_iou_score), replacing the in-memory model.
# `custom_objects` supplies the segmentation_models_3D loss and metric objects
# under the given names.
# NOTE(review): the keys presumably have to match the names stored in the
# checkpoint — confirm if loading fails with an unknown-object error.
with strategy.scope():
    model = tf.keras.models.load_model('./model.h5', custom_objects={'binary_focal_loss_plus_dice_loss': sm.losses.binary_focal_dice_loss,
                                    'dice_loss': sm.losses.dice_loss,
                                   'iou_score': sm.metrics.IOUScore(threshold=0.5), 
                                  'f1-score': sm.metrics.FScore(threshold=0.5)})

In [None]:
# Rebuild the validation dataset with ordered=True, which leaves deterministic
# ordering at its default instead of switching it off, then predict on it.
val_dataset = load_dataset(val_tf_files, ordered=True)
valds = get_val_dataset(val_dataset)
pred = model.predict(valds, verbose=1)

In [None]:
# Show index 50 along the first spatial axis of the first predicted volume,
# binarized at 0.9 (the training metrics above use a 0.5 threshold).
plt.imshow(pred[0][50] > 0.9)