# Segmentation 

In [None]:
import os
# CUDA_VISIBLE_DEVICES only takes effect if it is set before TensorFlow
# initialises its GPU devices, so set it before tensorflow is imported.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import importlib
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import *

import models
import dataio
import utils
import ants
# Reload the project modules BEFORE the named/star imports below. `from x import *`
# copies the module's attributes at import time, so reloading afterwards would
# leave this notebook bound to the stale definitions.
importlib.reload(dataio)
importlib.reload(utils)
importlib.reload(models)
from models import Unet2D, Unet3D, StandardUnet3D
from dataio import *
from utils import *

# Imported after the star imports (as before) so these names are never shadowed by them.
from typing import Any
import warnings
from sklearn.model_selection import train_test_split

print("TensorFlow version: {}".format(tf.__version__))
print("Eager execution: {}".format(tf.executing_eagerly()))
# Use tf.keras.backend explicitly instead of relying on a `K` alias leaking from a star import.
tf.keras.backend.clear_session()
warnings.filterwarnings('ignore')
# from keras_radam.optimizer_v2 import RAdam

## Setup

In [None]:
prefix = 'SegmentGSR'
# Name of the CSV file that lists the images to load
csv_file = 'GSR_fucts.csv'
# Images are reduced to this size
new_shape = (256, 256, 32)
# A first patch_shape coordinate of 0 means whole images are used instead of patches
patch_shape = (0, 256, 32)
crop = 0
# When patches are used, sample n_examples patches
n_examples = 6
# NOTE(review): alpha and sigma are not referenced by any visible cell — confirm whether they are still needed
alpha = [7e2, 7e2, 0]
sigma = [6, 6, 1]
augment_funcs = [identity, flip_image_horizontally, flip_image_vertically]

# Tag used in the TFRecord file names: whole images vs. patches, plus an augmentation marker
type_str = '_patches_' if patch_shape[0] > 0 else '_imgs_'
if augment_funcs:
    type_str += 'augm_'

### Create Datasets

In [None]:
# csv_file = 'iStroke_148_FUCTS_CTAS.csv'
img_type = 'FUCT'
subject_df = pd.read_csv(csv_file, comment='#')
# Drop "Unnamed: N" columns (e.g. produced by a trailing comma or a saved index).
# The result must be assigned, and this must happen before dropna(): an all-NaN
# "Unnamed" column would otherwise make dropna() discard every row.
subject_df = subject_df.loc[:, ~subject_df.columns.str.match("Unnamed")]
subject_df = subject_df.dropna()
# Keep only subjects whose 'vol' exceeds 50000
subject_df = subject_df.query("vol > 50000")
strat_col = 'vol'
# Hold out 10% for testing, then 10% of the remainder for validation
# (presumably stratified on strat_col with seed 0 — confirm against split_dataset)
train_df, test_df = split_dataset(subject_df, strat_col, 0.1, 0)
train_df, val_df = split_dataset(train_df, strat_col, 0.1, 0)

file_writer = tf.summary.create_file_writer(prefix + '/logs')

train_TFRfile = prefix + '/train' + type_str + '.TFRecords'
val_TFRfile = prefix + '/validation' + type_str + '.TFRecords'
test_TFRfile = prefix + '/test' + type_str + '.TFRecords'

# Whole-image mode (patch_shape[0] == 0) keeps the in-plane size of new_shape
if patch_shape[0] == 0:
    example_shape = (new_shape[0], new_shape[1], patch_shape[2])
else:
    example_shape = patch_shape

print(subject_df.shape)

In [None]:
# Histogram (3 bins) of 'vol' for each split, titled with the split name and its subject count
for _split_name, _split_df in (('train', train_df), ('validation', val_df), ('test', test_df)):
    plt.hist(_split_df['vol'], 3)
    plt.title(f'{_split_name}: {_split_df.shape[0]} subjects')
    plt.xlabel('vol')
    plt.show()



In [None]:
# Convert the training subjects (with augment_funcs) to TFRecords; subjects_to_TFRecords
# returns the path that is then read back into a dataset.
print('training data...')
train_TFRfile = subjects_to_TFRecords(
    train_df, img_type, train_TFRfile,
    patch_shape, new_shape, crop, n_examples, augment_funcs,
)
train_ds = tf.data.TFRecordDataset(train_TFRfile).map(read_tfrecord(example_shape))
print(f'Number of subjects in training dataset is {len(train_df)}')
print(f'Number of samples in training dataset is {get_dataset_size(train_ds)}\n')

In [None]:
# Convert the validation subjects (no augmentation) to TFRecords and read them back.
print('validation data...')
val_TFRfile = subjects_to_TFRecords(
    val_df, img_type, val_TFRfile,
    patch_shape, new_shape, crop, n_examples, None,
)
val_ds = tf.data.TFRecordDataset(val_TFRfile).map(read_tfrecord(example_shape))
print(f'Number of subjects in validation dataset is {len(val_df)}')
print(f'Number of samples in validation dataset is {get_dataset_size(val_ds)}\n')

In [None]:
# Convert the test subjects (no augmentation) to TFRecords and read them back.
print('testing data...')
test_TFRfile = subjects_to_TFRecords(
    test_df, img_type, test_TFRfile,
    patch_shape, new_shape, crop, n_examples, None,
)
test_ds = tf.data.TFRecordDataset(test_TFRfile).map(read_tfrecord(example_shape))
print(f'Number of subjects in testing dataset is {len(test_df)}')
print(f'Number of samples in testing dataset is {get_dataset_size(test_ds)}\n')

### Clip data

In [None]:
# Re-read each split and apply ds_clip(0, 150) (presumably an intensity clip to [0, 150]).
train_ds, val_ds, test_ds = (
    tf.data.TFRecordDataset(tfr_file).map(read_tfrecord(example_shape)).map(ds_clip(0, 150))
    for tfr_file in (train_TFRfile, val_TFRfile, test_TFRfile)
)

### Inspect Data

In [None]:
def plot_patch_label_pair(patch: np.ndarray, label: np.ndarray) -> None:
    """Show an image and its label side by side at the label's largest slice.

    Axis 2 is treated as the slice axis. The displayed slice is the first one
    whose label sum over axes 0 and 1 is maximal. Inputs may be NumPy arrays or
    eager tensors (as yielded by a tf.data dataset); both are converted with
    ``np.asarray`` once.

    Args:
        patch: image volume, indexed as [:, :, slice, ...].
        label: label volume with the same spatial layout as ``patch``.
    """
    patch = np.asarray(patch)
    label = np.asarray(label)
    slvol = np.sum(label, axis=(0, 1))
    # First slice with the largest label sum
    idx = np.where(slvol == np.max(slvol))[0][0]
    fig = plt.figure(figsize=(12, 12))
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.imshow(np.squeeze(patch[:, :, idx]))
    ax1.set_title('image')
    ax1.grid(False)
    ax2 = fig.add_subplot(1, 2, 2)
    ax2.imshow(np.squeeze(label[:, :, idx]))
    ax2.set_title('label')
    ax2.grid(False)
    plt.show()

# Visual sanity check: print the shape and plot the first 6 (image, label) pairs
# of train_ds (the clipped version, if the "Clip data" cell has been run).
ds = train_ds
for p,l in ds.take(6):
    print(p.shape)
    plot_patch_label_pair(p,l)

### Setup LR Schedule

In [None]:
from tensorflow.keras.callbacks import LearningRateScheduler

class LearningRateDecay:
    """Base class for epoch -> learning-rate schedules.

    Subclasses implement ``__call__(epoch)`` and return the learning rate for
    that epoch, so an instance can be handed to ``LearningRateScheduler``.
    """

    def plot(self, epochs, title="Learning Rate Schedule"):
        """Plot the schedule for ``epochs``.

        ``epochs`` is iterated twice (once to compute the rates, once to plot),
        so it must be re-iterable, e.g. a ``range``.
        """
        # compute the learning rate for each corresponding epoch
        lrs = [self(i) for i in epochs]
        # plot the learning rate schedule
        plt.style.use("ggplot")
        plt.figure()
        plt.plot(epochs, lrs)
        plt.title(title)
        plt.xlabel("Epoch #")
        plt.ylabel("Learning Rate")


class StepDecay(LearningRateDecay):
    """Epoch-based learning-rate decay.

    mode="inverse_time" (default): lr = initAlpha / (1 + factor * epoch).
        ``dropEvery`` has no effect in this mode.
    mode="step": lr = initAlpha * factor ** floor((1 + epoch) / dropEvery),
        i.e. the rate is multiplied by ``factor`` every ``dropEvery`` epochs.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """

    _MODES = ("inverse_time", "step")

    def __init__(self, initAlpha=0.01, factor=0.1, dropEvery=10, mode="inverse_time"):
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        # store the base initial learning rate, decay factor, drop period and mode
        self.initAlpha = initAlpha
        self.factor = factor
        self.dropEvery = dropEvery
        self.mode = mode

    def __call__(self, epoch):
        """Return the learning rate for ``epoch`` as a Python float."""
        if self.mode == "step":
            exp = np.floor((1 + epoch) / self.dropEvery)
            alpha = self.initAlpha * (self.factor ** exp)
        else:
            alpha = self.initAlpha / (1 + self.factor * epoch)
        return float(alpha)
    
# Schedule starting at 1e-3; with StepDecay's default behaviour the rate follows
# 1e-3 / (1 + 0.01 * epoch) and dropEvery does not change it. Preview 400 epochs.
step_schedule = StepDecay(initAlpha=1e-3,factor=0.01, dropEvery = 100)
step_schedule.plot(range(0,400))
# NOTE(review): lr_callback is not in the callbacks passed to model.fit in the Train
# cell, so this schedule is currently NOT applied during training.
lr_callback = LearningRateScheduler(step_schedule)

### Setup callbacks

In [None]:
import datetime
file_writer = tf.summary.create_file_writer(f"{prefix}/logs")
# Timestamped TensorBoard run directory under <prefix>/logs
dtnow = datetime.datetime.now().strftime("%d%m%Y-%H%M%S")
log_dir = f"{prefix}/logs/logs_{dtnow}"
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    profile_batch=0,
    embeddings_freq=0,
)

# Keep only the weights of the epoch with the highest validation Dice coefficient
checkpoint_path = f"{prefix}/checkpoints/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_dice_coef',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

### Setup Model

In [None]:
from losses import dice_coef_loss, dice_coef
from tensorflow.keras.optimizers import Adam
# Hyper-parameters passed to Unet2D/Unet3D; 'lr', 'bs' and 'epochs' are also read below
conf = dict(model = 'UNET',
            filters = [16,32,64],
            bnorm=True,
            SN = False,
            se = False,
            SA = False,
            r = 8,
            alpha = 0.0,
            dropout = 0.0,
            lr = 1e-3,
            wd=0.0e-4,
            spatial_dropout = 0.20,
            ratio=2,
            bs=1,
            epochs=300,
            noise_sigma=0,
            init = 'lecun_uniform'
            )
print(conf)

# Patch depth > 1 selects the 3D U-Net; spatial input dimensions are left unspecified (None)
if patch_shape[-1] > 1:
    input_img = Input(shape=(None, None, None, 1))
    model = Unet3D(input_img, conf)
else:
    input_img = Input(shape=(None, None, 1))
    model = Unet2D(input_img, conf)

Optimizer = Adam(conf['lr'])
loss = dice_coef_loss
# NOTE(review): MeanIoU casts predictions to integer class ids. If the network outputs
# probabilities (e.g. sigmoid), they are truncated rather than thresholded, so this
# metric may under-report. Consider thresholding or tf.keras.metrics.BinaryIoU — confirm.
miou = tf.keras.metrics.MeanIoU(num_classes=2)
model.compile(optimizer=Optimizer, loss=loss, metrics=['accuracy', miou, dice_coef])
# summary() prints itself and returns None; wrapping it in print() emitted a stray "None"
model.summary()

## Train

In [None]:
print(f'Fitting model over {conf["epochs"]} epochs with batch size of {conf["bs"]} ....')
# No clear_session() here: the model has already been built and compiled, and
# resetting the global Keras state between compile() and fit() serves no purpose.
# NOTE(review): lr_callback (LR schedule cell) is not in the callback list, so training
# runs at the constant conf['lr']; add it here if the schedule is intended.
history = model.fit(
    train_ds.batch(conf["bs"], drop_remainder=True),
    validation_data=val_ds.batch(1),
    epochs=conf['epochs'],
    callbacks=[tensorboard_callback, cp_callback],
    verbose=0,
)

## Test

In [None]:
# Restore the checkpointed weights (best val_dice_coef) before exporting the full model
model.load_weights(checkpoint_path)
# Derive the resolution tag from new_shape so the file name stays in sync with it
# (gives 'GSR_CTAPRED_256x256x32_0.h5' for new_shape=(256, 256, 32))
model_path = prefix + '/GSR_CTAPRED_' + 'x'.join(map(str, new_shape)) + '_0.h5'
model.save(model_path)

In [None]:
# Batching comes from the dataset itself; Keras documents that batch_size must not
# be specified when the input is a tf.data.Dataset, so it is omitted here.
model.evaluate(test_ds.batch(1))

In [None]:
print(img_type)
# Run predict_dataframe over the held-out test subjects.
# NOTE(review): (0,150) is presumably the same intensity clip range used for the
# datasets — confirm against predict_dataframe's signature.
results = predict_dataframe(model,test_df, img_type, example_shape, new_shape, crop,(0,150))

In [None]:
# Restrict the summary to rows with vol_gt > 2500 (presumably ground-truth volume)
df = results.query("vol_gt>2500")
print(df)
# numeric_only=True: pandas >= 2.0 raises TypeError on non-numeric columns
# (such as any subject identifiers) instead of silently skipping them.
print(df.mean(numeric_only=True))

In [None]:
# Predict the first test subject using the exported model file (model_path).
# NOTE(review): the positional arguments after crop (True, False, (0,150), 16, True, 0.5)
# are opaque here. (0,150) presumably is the intensity clip range and 0.5 a probability
# threshold — confirm against predict_subject's signature.
res= predict_subject([model_path],test_df.iloc[0],img_type,new_shape,example_shape, crop,True,False,(0,150),16,True,0.5)