<a href="https://colab.research.google.com/github/Fr2zyRoom/ISLES2017_LesionSegmentation_Tutorial/blob/main/ISLES2017_2dslice_lesion_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **ISLES2017 Lesion Segmentation : 2D slice models**

In [None]:
# gdown is the command-line downloader the following cells use to fetch files from Google Drive.
!pip install gdown

In [None]:
# Download the ADC dataset (2D-slice version, .npz) from Google Drive.
# NOTE(review): presumably this is the ISLES2017.zip archive unzipped further down — confirm.
!gdown "https://drive.google.com/uc?id=1rk8_WePcRn8sGxOwyJjwJKoBossyxl1A"

In [None]:
# Download the tutorial's helper packages from Google Drive.
# NOTE(review): presumably one of the util.zip / data.zip archives unzipped further down — confirm which.
!gdown "https://drive.google.com/uc?id=1q8JDkCKe5Iv-Zzd_Rv17I1-alz1lM5qF"

In [None]:
# Third Google Drive download.
# NOTE(review): presumably the remaining archive among ISLES2017.zip / util.zip / data.zip — confirm which.
!gdown "https://drive.google.com/uc?id=1TO6t8lS1Ie4rMoZuW12XQk1Tw7S_eEm6"

In [None]:
# Extract the dataset archive into ./ISLES2017.
# -o overwrites existing files without asking, so re-running the cell cannot block on
# unzip's interactive "replace?" prompt; -q suppresses the per-file listing.
!unzip -o -q ISLES2017.zip -d ./ISLES2017

In [None]:
# Extract the `util` package (imported below as util.util / util.visualize).
# -o overwrites existing files without asking, so re-running the cell cannot block on
# unzip's interactive "replace?" prompt; -q suppresses the per-file listing.
!unzip -o -q util.zip -d ./util

In [None]:
# Extract the `data` package (imported below as data.dataset_2d).
# -o overwrites existing files without asking, so re-running the cell cannot block on
# unzip's interactive "replace?" prompt; -q suppresses the per-file listing.
!unzip -o -q data.zip -d ./data

In [None]:
# Install the `tree` directory-listing utility.
# -y answers apt's confirmation prompt automatically so the cell works in a non-interactive run.
!apt-get install -y tree

In [None]:
# Pin albumentations to 1.0.3, replacing whatever version the runtime ships with.
# NOTE(review): presumably the augmentation pipeline in data.dataset_2d was written against this release — confirm.
!pip install --force-reinstall albumentations==1.0.3

In [None]:
# Install segmentation-models-pytorch (imported below as `smp`).
# NOTE(review): the version is not pinned, yet the notebook relies on the `smp.utils`
# train/loss/metric helpers — consider pinning the release this was developed with.
!pip install segmentation-models-pytorch

In [None]:
# Replace the runtime's matplotlib with the pinned 3.1.3 release.
# -y skips pip's "Proceed (Y/n)?" confirmation, which would otherwise stall a Run All.
!pip uninstall -y matplotlib
!pip install matplotlib==3.1.3

In [None]:
import copy
import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import PIL.Image as Image
import seaborn as sns
from tqdm import tqdm

# Project-local helpers (star imports kept in their original relative order).
from util.util import *
from util.visualize import *
from data.dataset_2d import *

# Imported after the star imports, as before, so these names cannot be shadowed by them.
from sklearn.model_selection import StratifiedKFold, train_test_split

import torch  # used directly below (DataLoader, optimizer, save/load) but was never imported explicitly
import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils  # newer smp releases expose `smp.utils` only after an explicit import

In [None]:
def split_train_val_test(file_id, val_case):
    """Return the split label for one case.

    Args:
        file_id: Identifier of the case to classify.
        val_case: Container of identifiers that belong to the validation split.

    Returns:
        'val' when ``file_id`` is contained in ``val_case``, otherwise 'train'.
    """
    is_validation = file_id in val_case
    return 'val' if is_validation else 'train'

In [None]:
# Split the training cases into FOLD stratified folds.
# Stratification is on the MRSScore column; the seed makes the shuffle reproducible.
FOLD = 5
random_seed = 50

train_df = pd.read_csv("./ISLES2017/ISLES2017_Training_clr.csv")
case_name = train_df["Case SMIR ID 1"].values
mrss = train_df["MRSScore"].values

skf = StratifiedKFold(n_splits=FOLD, random_state=random_seed, shuffle=True)

# Work on a copy so the table as loaded stays available in `train_df`.
train_df_split = train_df.copy()

# Add one column per fold ('fold1' ... 'fold<FOLD>') labelling every case as
# 'val' (held out in that fold) or 'train'.
for num, (_, val_index) in enumerate(skf.split(case_name, mrss), start=1):
    fold_num = f'fold{num}'
    X_val = case_name[val_index]
    train_df_split[fold_num] = train_df_split["Case SMIR ID 1"].map(
        lambda x: split_train_val_test(x, val_case=X_val)
    )

In [None]:
# Write the table with the per-fold 'train'/'val' columns into the ./ISLES2017 directory.
kfold_df_path = f"./ISLES2017/ISLES2017_Training_clr_{FOLD}fold.csv"
train_df_split.to_csv(kfold_df_path, index=False)

In [None]:
# Training dataset for fold 1 (kfold=1, mode='train'), with training augmentation and
# preprocessing that resizes to 256x256.
# NOTE(review): df_path hard-codes the "5fold" file name; it matches the CSV written above
# only while FOLD == 5.
# NOTE(review): how `kfold`/`mode` select rows is defined in data.dataset_2d — presumably via
# the 'fold1' column of the CSV; confirm.
train_dataset = ISLES_ADCLesionSegDataset(
    dataset_dir="./ISLES2017/ISLES2017_Training_2d_ADC", 
    df_path="./ISLES2017/ISLES2017_Training_clr_5fold.csv",
    img_loader=img_loader, 
    mask_loader=mask_loader,
    augmentation=get_training_augmentation(), 
    preprocessing=get_preprocessing(resize=(256,256)),
    kfold=1,
    mode='train'
    )

In [None]:
# Validation dataset for fold 1 (kfold=1, mode='val').
# NOTE(review): this passes get_training_augmentation() for validation data. If that pipeline
# is random (to be confirmed in data.dataset_2d), validation metrics will vary between runs
# and each access to the same index may return a differently transformed sample.
val_dataset = ISLES_ADCLesionSegDataset(
    dataset_dir="./ISLES2017/ISLES2017_Training_2d_ADC", 
    df_path="./ISLES2017/ISLES2017_Training_clr_5fold.csv",
    img_loader=img_loader, 
    mask_loader=mask_loader,
    augmentation=get_training_augmentation(), 
    preprocessing=get_preprocessing(resize=(256,256)),
    kfold=1,
    mode='val'
    )

In [None]:
# Same samples as the fold-1 training dataset, but built with convert=False so the samples can
# be inspected visually in the augmentation check below.
# NOTE(review): presumably convert=False skips the final array/tensor conversion step of
# get_preprocessing — confirm in data.dataset_2d.
aug_dataset = ISLES_ADCLesionSegDataset(
    dataset_dir="./ISLES2017/ISLES2017_Training_2d_ADC", 
    df_path="./ISLES2017/ISLES2017_Training_clr_5fold.csv",
    img_loader=img_loader, 
    mask_loader=mask_loader,
    augmentation=get_training_augmentation(), 
    preprocessing=get_preprocessing(resize=(256,256),convert=False),
    kfold=1,
    mode='train'
    )

In [None]:
def visualize(**images):
    """Show the given images side by side in a single row.

    Each keyword argument becomes one panel. The keyword name, with
    underscores replaced by spaces and title-cased, is used as the panel
    title; the value is handed to ``plt.imshow``.
    """
    n_panels = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, n_panels, position)
        plt.axis("off")
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture)
    plt.show()

In [None]:
# Visual sanity check of the augmentation pipeline: show image/mask pairs for
# samples 10..39 of the un-converted training dataset.
for sample_idx in range(10, 40):
    image, mask = aug_dataset[sample_idx]
    visualize(
        image=visualize_grayscale(np.squeeze(image)),
        mask=visualize_grayscale(np.squeeze(mask)),
    )

In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with the current validation loss and then
    check ``early_stop``. The first call only records the loss as the best so
    far. After that, a call counts as an improvement when
    ``val_loss <= best_loss - delta`` (so with ``delta=0`` a tie resets the
    counter); any other call increments the counter, and once the counter
    reaches ``patience`` the ``early_stop`` flag is set.

    ``__call__`` does not write checkpoints; ``save_checkpoint`` has to be
    invoked explicitly if a checkpoint is wanted.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each checkpoint written by
                            ``save_checkpoint``.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # `np.Inf` was removed in NumPy 2.0; `np.inf` is the spelling valid in every release.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and update the stopping state.

        Args:
            val_loss (float): Validation loss of the epoch that just finished.
            model: Accepted for interface compatibility; not used here because
                checkpointing is not performed in this method.
        """
        # Scores are negated losses, so "higher is better" below.
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model's state_dict to ``self.path`` and records ``val_loss`` as the new minimum.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
# Output directory for checkpoints and the per-epoch results CSV.
# NOTE(review): gen_new_dir comes from util.util; presumably it creates the directory — confirm.
save_path = "./ADC_ckpt/2d_ckpt/UNet_resnet152"
gen_new_dir(save_path)
###############################
# Hyper-parameters
trial = 1                 # run index, used in the output file names
n_epoches = 10000         # upper bound on epochs; early stopping may end training sooner
LR = 0.0001               # initial learning rate
LR_DECREASE = 1e-5        # learning rate applied once `lr_decrease_epoch` is reached
lr_decrease_epoch = 70
BATCH_SIZE = 16
patience = 15             # early-stopping patience, in epochs
###############################
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True
)

val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False)

ENCODER = 'resnet152'
ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation
# Fall back to the CPU when no GPU is visible instead of failing on a hard-coded 'cuda'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Create the segmentation model. encoder_weights=None means the encoder is NOT
# initialised from pretrained weights; the network takes 1-channel input and
# predicts a single class.
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=None,
    in_channels=1,
    classes=1,
    activation=ACTIVATION,
)

loss = smp.utils.losses.DiceLoss()

metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

# A single parameter group holding every model parameter.
optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=LR),
])

In [None]:
# Build the per-epoch runners: each one iterates once over a dataloader,
# the training runner additionally stepping the optimizer.
_shared_epoch_args = dict(loss=loss, metrics=metrics, device=DEVICE, verbose=True)

train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_shared_epoch_args)

valid_epoch = smp.utils.train.ValidEpoch(model, **_shared_epoch_args)

In [None]:
# Train for at most `n_epoches` epochs; early stopping on the validation Dice loss
# can end the run sooner.

max_score = 0

# Output locations, built once instead of on every epoch.
run_tag = str(trial).zfill(2)
results_path = os.path.join(save_path, f'results{run_tag}.csv')
best_model_path = os.path.join(save_path, f'best_model{run_tag}.pth')
final_model_path = os.path.join(save_path, f'final_model{run_tag}.pth')

# Start a fresh results file with the header row.
with open(results_path, 'w') as f:
    f.write('epoch,train_loss,train_score,valid_loss,valid_score\n')

early_stopping = EarlyStopping(patience=patience, verbose=True)

for epoch in range(0, n_epoches):

    print(f'\nEpoch: {epoch}')
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(val_loader)

    # Append this epoch's metrics (the CSV stores 1-based epoch numbers).
    with open(results_path, 'a') as f:
        f.write('%03d,%0.6f,%0.6f,%0.6f,%0.6f\n' % (
            (epoch + 1),
            train_logs['dice_loss'],
            train_logs['iou_score'],
            valid_logs['dice_loss'],
            valid_logs['iou_score'],
        ))

    # Keep the model with the best validation IoU seen so far.
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, best_model_path)
        print('New Record!')

    # Always keep the most recent model as well.
    torch.save(model, final_model_path)

    # Early stopping watches the validation Dice loss (not the IoU used above).
    early_stopping(valid_logs['dice_loss'], model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

    # One-off learning-rate drop; applies to every parameter group of the optimizer.
    if epoch == lr_decrease_epoch:
        for param_group in optimizer.param_groups:
            param_group['lr'] = LR_DECREASE
        print(f'Decrease learning rate to {LR_DECREASE}!')

In [None]:
# Re-create the fold-1 validation dataset for the evaluation section (same arguments as the
# validation dataset built before training).
# NOTE(review): get_training_augmentation() is applied to validation data here. If that
# pipeline is random (to be confirmed in data.dataset_2d), evaluation results are not
# deterministic.
val_dataset = ISLES_ADCLesionSegDataset(
    dataset_dir="./ISLES2017/ISLES2017_Training_2d_ADC", 
    df_path="./ISLES2017/ISLES2017_Training_clr_5fold.csv",
    img_loader=img_loader, 
    mask_loader=mask_loader,
    augmentation=get_training_augmentation(), 
    preprocessing=get_preprocessing(resize=(256,256)),
    kfold=1,
    mode='val'
    )

In [None]:
# Fold-1 validation dataset built with convert=False; the cells below index it to obtain the
# image and ground-truth mask used for plotting and for the Dice computation.
# NOTE(review): presumably convert=False skips the final array/tensor conversion — confirm in
# data.dataset_2d.
# NOTE(review): predictions are produced from `val_dataset`, while images/masks are read from
# this dataset. Both apply get_training_augmentation(); if that pipeline is random, sample i
# here is not guaranteed to be transformed the same way as the sample the prediction was made
# on — verify before trusting the overlays and Dice values.
vis_val_dataset = ISLES_ADCLesionSegDataset(
    dataset_dir="./ISLES2017/ISLES2017_Training_2d_ADC", 
    df_path="./ISLES2017/ISLES2017_Training_clr_5fold.csv",
    img_loader=img_loader, 
    mask_loader=mask_loader,
    augmentation=get_training_augmentation(), 
    preprocessing=get_preprocessing(resize=(256,256), convert=False),
    kfold=1,
    mode='val'
    )

In [None]:
# load best saved checkpoint
save_path = "./ADC_ckpt/2d_ckpt/UNet_resnet152"
# map_location places the restored model on DEVICE, so a checkpoint written on a GPU
# can also be opened on a machine without one.
# NOTE(review): the file is a whole pickled model (written with torch.save(model, ...)).
# Recent PyTorch releases default torch.load to weights_only=True, which rejects such
# files; pass weights_only=False there, and only for checkpoints you trust.
best_model = torch.load(os.path.join(save_path, 'best_model01.pth'), map_location=DEVICE)

In [None]:
# Loader over the validation dataset for evaluation: fixed order, batches of 8.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
)

In [None]:
# Evaluate the best checkpoint with the same loss and metrics used during training.
# Note: the data evaluated here is `val_loader` (the fold-1 validation split), not a separate
# held-out test set.
test_epoch = smp.utils.train.ValidEpoch(
    model=best_model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
)

logs = test_epoch.run(val_loader)

In [None]:
# Plot the per-epoch history written by the training loop:
# blue = training, red = validation.
train_history = pd.read_csv(os.path.join(save_path,'results01.csv'))
fig, ax = plt.subplots(1, 2)

ax[0].set_title('loss')
ax[0].plot(np.array(train_history['train_loss']), 'b', label='train')
ax[0].plot(np.array(train_history['valid_loss']), 'r', label='valid')
ax[0].set_xlabel('epoch')
ax[0].legend()

# The *_score columns hold the IoU metric logged during training, not an accuracy.
ax[1].set_title('IoU score')
ax[1].plot(np.array(train_history['train_score']), 'b', label='train')
ax[1].plot(np.array(train_history['valid_score']), 'r', label='valid')
ax[1].set_xlabel('epoch')
ax[1].legend()

fig.tight_layout()
plt.show()

In [None]:
import cv2

In [None]:
# Run the best model over the validation loader and collect binarised predictions,
# one array per batch.
predict_masks = []

for images, _ in val_loader:  # labels are not needed for inference
    images = images.to(DEVICE)
    pr_mask = best_model.predict(images)
    # Move to host memory and round the model outputs to 0/1.
    predict_masks.append(pr_mask.cpu().numpy().round())

In [None]:
# Merge the per-batch predictions into a single (N, H, W) array.
# The isinstance guard makes the cell safe to re-run: stacking an already merged array
# again would silently reshape it.
if isinstance(predict_masks, list):
    # Each batch is (B, 1, H, W) because the model has a single output class. Only that
    # channel axis is dropped, so the sample axis survives even when N == 1.
    predict_masks = np.vstack(predict_masks).squeeze(axis=1)

In [None]:
# Overlay predictions on validation slices.
# Channel layout of the overlay (shown as RGB by imshow):
#   red   = predicted only        (false positive)
#   green = predicted and labelled (true positive)
#   blue  = labelled only          (false negative)
# NOTE(review): assumes `mask` from vis_val_dataset is (H, W, 1) with values 0/1 and that
# visualize_grayscale returns an (H, W, 3) image — confirm in data.dataset_2d / util.visualize.
# NOTE(review): predictions were made on `val_dataset`; if the augmentation is random,
# sample i of vis_val_dataset may not be the same transformed slice.

# Show at most 63 slices, never indexing past the available samples.
n_show = min(63, len(vis_val_dataset), len(predict_masks))
for i in range(n_show):
    image, mask = vis_val_dataset[i]
    image_rgb = visualize_grayscale(np.squeeze(image))
    # (H, W) prediction -> (H, W, 1) so it lines up with the mask.
    predict = predict_masks[i].astype(np.uint8)[:, :, np.newaxis]
    intersect_mask = mask * predict
    only_mask = np.where((mask - intersect_mask) == 1, 1, 0)
    only_pred = np.where((predict - intersect_mask) == 1, 1, 0)
    tp_np_mask = np.concatenate([only_pred, intersect_mask, only_mask], axis=-1) * 255
    # Blend the slice and the colour-coded result half and half.
    vis = (image_rgb / 2 + tp_np_mask / 2).astype(np.uint8)
    visualize(image=image_rgb, result=tp_np_mask, visualize=vis)

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
# code reference: https://gist.github.com/gergf/acd8e3fd23347cb9e6dc572f00c63d79
def dice(true_mask, pred_mask, non_seg_score=0.0):
    """
        Computes the Dice coefficient.
        Args:
            true_mask : Array (or array-like) of arbitrary shape; non-zero entries count
                        as foreground.
            pred_mask : Array (or array-like) with the same shape as true_mask.
            non_seg_score : Value returned when both masks are entirely empty, a case
                        in which the Dice coefficient is undefined. Default: 0.0

        Returns:
            A scalar representing the Dice coefficient between the two segmentations,
            or non_seg_score when neither mask contains any foreground.

        Raises:
            AssertionError: If the two masks do not have the same shape.
    """
    # Convert first so that plain lists/tuples are accepted as well as arrays.
    true_mask = np.asarray(true_mask).astype(np.bool_)
    pred_mask = np.asarray(pred_mask).astype(np.bool_)
    assert true_mask.shape == pred_mask.shape

    # If both segmentations are all zero, return the caller-chosen non_seg_score
    # instead of dividing by zero.
    im_sum = true_mask.sum() + pred_mask.sum()
    if im_sum == 0:
        return non_seg_score

    # Dice = 2 * |A ∩ B| / (|A| + |B|)
    intersection = np.logical_and(true_mask, pred_mask)
    return 2. * intersection.sum() / im_sum

In [None]:
# Mean Dice over the validation slices.
# Only slices where BOTH the prediction and the ground truth contain foreground are counted,
# so completely missed lesions and predictions on lesion-free slices do not lower the average.
# NOTE(review): predictions were made on `val_dataset`; if the augmentation is random,
# sample i of vis_val_dataset may not be the same transformed slice.
dice_sum = 0
cnt = 0
for i in range(len(vis_val_dataset)):
    image, mask = vis_val_dataset[i]
    if predict_masks[i].max() != 0. and mask.max() != 0.:
        dice_sum += dice(np.squeeze(mask.astype(np.uint8)), predict_masks[i].astype(np.uint8))
        cnt += 1

# Avoid a ZeroDivisionError when no slice qualifies; NaN marks "no value".
dice_avg = dice_sum / cnt if cnt else float('nan')

In [None]:
# Display the mean Dice computed in the previous cell.
dice_avg