In [None]:
from tqdm import tqdm
import os
import copy

import numpy as np
import pandas as pd
import seaborn as sns
import PIL.Image as Image

import nibabel as nib

import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold, train_test_split

import segmentation_models_pytorch as smp

from util.util import *
from util.visualize import *
from data.dataset_2d import *

# Root folder of the AI-Hub dataset; the folder name is Korean for
# "050. Clinical and medical-care data related to neurological diseases".
# NOTE(review): not referenced elsewhere in the visible notebook — confirm it is still needed.
common_dir = '/home/ncp/workspace/202002n050/050.신경계 질환 관련 임상 및 진료 데이터'

In [None]:
# Show GPU availability / memory usage before loading models (IPython shell escape).
!nvidia-smi

In [None]:
import os
import numpy as np
import pandas as pd
import glob

# Filename suffixes recognized by check_extension / load_file_path.
# Matching is case-sensitive, so upper-case variants are listed explicitly.
# NOTE: unlike IMG_EXTENSION, FILE_EXTENSION does not include '.jpeg'/'.JPEG'.
FILE_EXTENSION = ['.png', '.PNG', '.jpg', '.JPG', '.dcm', '.DCM', '.raw', '.RAW', '.svs', '.SVS']
# Raster image suffixes; used below to collect the DWI, ADC and mask slice files.
IMG_EXTENSION = ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG']
DCM_EXTENSION = ['.dcm', '.DCM']
RAW_EXTENSION = ['.raw', '.RAW']
# NOTE(review): only uncompressed '.nii'; gzipped '.nii.gz' files would not match.
NIFTI_EXTENSION = ['.nii']
NP_EXTENSION = ['.npy']

# Same path as common_dir. NOTE(review): not referenced elsewhere in the visible notebook.
mask_common_dir = '/home/ncp/workspace/202002n050/050.신경계 질환 관련 임상 및 진료 데이터'


def check_extension(filename, extension_ls=FILE_EXTENSION):
    """Return True when ``filename`` ends with any suffix in ``extension_ls`` (case-sensitive)."""
    # str.endswith accepts a tuple of candidate suffixes and is True if any matches.
    return filename.endswith(tuple(extension_ls))


def load_file_path(folder_path, extension_ls=FILE_EXTENSION, all_sub_folders=False):
    """Collect paths of files under ``folder_path`` whose names end with a suffix in ``extension_ls``.

    Parameters:
        folder_path (str) -- folder directory
        extension_ls (list) -- accepted filename suffixes (case-sensitive)
        all_sub_folders (bool) -- if True, search every sub-folder recursively;
                                  otherwise only files directly inside ``folder_path``

    Return:
        file_paths (list) -- matching file paths; sub-folders are visited in sorted
                             order and file names are sorted within each folder

    Raises:
        AssertionError -- if ``folder_path`` is not a directory
    """
    assert os.path.isdir(folder_path), f'{folder_path} is not a valid directory'

    file_paths = []
    for root, dirnames, fnames in os.walk(folder_path):
        # Sorting dirnames in place makes os.walk descend in a deterministic order.
        dirnames.sort()
        file_paths.extend(
            os.path.join(root, fname)
            for fname in sorted(fnames)
            if check_extension(fname, extension_ls)
        )
        if not all_sub_folders:
            # Only the top-level folder was requested: stop before os.walk descends.
            # (Wrapping os.walk in sorted() would traverse the whole tree first.)
            break

    return file_paths


def gen_new_dir(new_dir):
    """Create ``new_dir`` (and any missing parents) if it does not exist yet.

    This is best-effort: a failure is reported with its path and reason, but not
    raised, so the calling notebook cell keeps running.
    """
    try:
        # exist_ok=True avoids the race between an os.path.exists check and os.makedirs.
        os.makedirs(new_dir, exist_ok=True)
    except OSError as err:
        print(f"Error: Failed to create the directory {new_dir!r}: {err}")


def find_dwi_adc_dir(img_folder_dir, fname):
    """Locate the 'dwi' and 'adc' sub-folders of case ``fname``.

    Returns:
        (dwi_dir, adc_dir) when both ``img_folder_dir/fname/dwi`` and
        ``img_folder_dir/fname/adc`` are directories, otherwise None.
    """
    case_dir = os.path.join(img_folder_dir, fname)
    modality_dirs = tuple(os.path.join(case_dir, modality) for modality in ('dwi', 'adc'))
    if not all(os.path.isdir(d) for d in modality_dirs):
        return None
    return modality_dirs


def find_mask_dir(mask_folder_dir, fname):
    """Return ``mask_folder_dir/fname`` if that is an existing directory, otherwise None."""
    case_mask_dir = os.path.join(mask_folder_dir, fname)
    return case_mask_dir if os.path.isdir(case_mask_dir) else None


def pair_aihub_dwi_adc_img_mask_path(img_folder_dir, mask_folder_dir):
    """Pair each case's (DWI, ADC) slice paths with its mask slice paths.

    A case is every sub-folder ``fname`` of ``img_folder_dir`` that has both ``dwi``
    and ``adc`` folders and a matching ``mask_folder_dir/fname`` folder. Within each
    folder, files matching IMG_EXTENSION are sorted and paired by position.

    Cases whose DWI, ADC and mask slice counts differ are skipped with a warning.
    Downstream code concatenates the images and masks of all cases into two
    parallel flat arrays, so one mismatched case would shift the image/mask
    pairing of every case after it. ``zip`` alone would also silently drop the
    surplus DWI or ADC slices.

    Returns:
        dict -- fname -> [list of (dwi_path, adc_path), list of mask_path]
    """
    import warnings

    img_mask_path_dict = {}
    for fname in sorted(os.listdir(img_folder_dir)):
        dwi_adc_dir = find_dwi_adc_dir(img_folder_dir, fname)
        if not dwi_adc_dir:
            continue
        mask_dir = find_mask_dir(mask_folder_dir, fname)
        if not mask_dir:
            continue

        dwi_folder_dir, adc_folder_dir = dwi_adc_dir
        dwi_path_ls = sorted(load_file_path(dwi_folder_dir, IMG_EXTENSION))
        adc_path_ls = sorted(load_file_path(adc_folder_dir, IMG_EXTENSION))
        mask_path_ls = sorted(load_file_path(mask_dir, IMG_EXTENSION))

        if not (len(dwi_path_ls) == len(adc_path_ls) == len(mask_path_ls)):
            warnings.warn(
                f'Skipping case {fname!r}: {len(dwi_path_ls)} DWI, {len(adc_path_ls)} ADC '
                f'and {len(mask_path_ls)} mask slices do not match'
            )
            continue

        img_path_ls = list(zip(dwi_path_ls, adc_path_ls))
        img_mask_path_dict[fname] = [img_path_ls, mask_path_ls]
    return img_mask_path_dict


def select_train_val_test(img_mask_path_dict, fname_list):
    """Return the entries of ``img_mask_path_dict`` named in ``fname_list``.

    Keys appear in ``fname_list`` order. Names that are missing, or that map to a
    falsy value, are left out.
    """
    return {
        name: img_mask_path_dict[name]
        for name in fname_list
        if img_mask_path_dict.get(name)
    }


def find_aihub_img_mask_paths(img_folder_dir, mask_folder_dir, fname_list):
    """Flatten the paired slice paths of the cases in ``fname_list`` into two parallel arrays.

    Returns:
        img_path_arr -- string array of shape (num_slices, 2); each row is (dwi_path, adc_path)
        mask_path_arr -- string array of shape (num_slices,); row i is the mask of img_path_arr[i]

    Slices are grouped by case, in ``fname_list`` order. An empty selection gives
    arrays of shape (0, 2) and (0,) instead of raising. ``np.concatenate`` on an
    empty list raises, and so does a case with zero slices, because of mismatched
    array ranks.
    """
    img_mask_path_dict = pair_aihub_dwi_adc_img_mask_path(img_folder_dir, mask_folder_dir)

    img_mask_path_dict_sel = select_train_val_test(img_mask_path_dict, fname_list)

    img_paths = [pair for img_path_ls, _ in img_mask_path_dict_sel.values() for pair in img_path_ls]
    mask_paths = [path for _, mask_path_ls in img_mask_path_dict_sel.values() for path in mask_path_ls]

    img_path_arr = np.array(img_paths, dtype=str).reshape(-1, 2)
    mask_path_arr = np.array(mask_paths, dtype=str)
    return img_path_arr, mask_path_arr

In [None]:
def dwi_adc_loader(dwi_adc_path):
    """Load a DWI/ADC slice pair and stack the two images along a new last axis.

    ``dwi_adc_path`` must unpack into exactly two paths: (dwi_path, adc_path).
    """
    dwi_path, adc_path = dwi_adc_path
    channels = []
    for path in (dwi_path, adc_path):
        with Image.open(path) as img:
            channels.append(np.array(img))
    return np.stack(channels, axis=-1)
def mask_loader(mask_path):
    """Load a mask image and binarize it (nonzero -> 1, zero -> 0) as uint8, adding a trailing axis."""
    with Image.open(mask_path) as img:
        binary = (np.array(img) != 0).astype(np.uint8)
    return binary[..., np.newaxis]

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, models, transforms
import os

import numpy as np
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2


def get_training_augmentation(params=None):
    """Build the training-time augmentation pipeline.

    No augmentation is enabled at the moment, so this returns an empty ``A.Compose``.
    ``params`` is accepted for interface compatibility and is unused.

    Previously tried transforms, currently disabled:
        A.HorizontalFlip(p=.5)
        A.VerticalFlip(p=.5)
        A.ShiftScaleRotate(scale_limit=0.1, rotate_limit=5, shift_limit=0.2, border_mode=0, p=.5)
        A.ShiftScaleRotate(scale_limit=0.01, rotate_limit=5, shift_limit=0., border_mode=0, p=.5)
    """
    enabled_transforms = []
    return A.Compose(enabled_transforms)


def get_preprocessing(params=None,resize=(256,256),convert=True):
    """Build the deterministic preprocessing pipeline.

    Always resizes to ``resize`` (height, width). When ``convert`` is True it also
    normalizes both channels with mean 0.5 / std 0.5 and converts the image and
    mask to tensors (``transpose_mask=True`` moves the mask channel axis first).
    With ``convert=False`` the outputs stay NumPy arrays, which suits visualization.
    ``params`` is unused.
    """
    steps = [A.Resize(*resize)]
    if not convert:
        return A.Compose(steps)
    # ImageNet 3-channel alternative, not used here:
    # A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    steps += [
        A.Normalize(mean=(0.5,0.5), std=(0.5,0.5)),
        ToTensorV2(transpose_mask=True),
    ]
    return A.Compose(steps)


class AIHUB_DWI_ADC_LesionSegDataset(torch.utils.data.Dataset):
    """2D lesion-segmentation dataset of stacked DWI/ADC slices and their masks.

    Cases are the ``name`` values of the rows in the CSV at ``data_df_path`` whose
    ``split_811`` column equals ``mode``. Slice paths come from
    ``find_aihub_img_mask_paths``, and item ``index`` pairs image slice ``index``
    with mask slice ``index``.

    Parameters:
        img_folder_dir (str) -- folder holding <case>/dwi and <case>/adc slice images
        mask_folder_dir (str) -- folder holding <case> mask slice images
        data_df_path (str) -- CSV with at least ``name`` and ``split_811`` columns
        img_loader (callable) -- maps a (dwi_path, adc_path) row to an image array
        mask_loader (callable) -- maps a mask path to a mask array
        augmentation (callable or None) -- called as ``f(image=..., mask=...)`` and
            returns a dict with 'image' and 'mask'; used only when mode == 'train'
        preprocessing (callable or None) -- same calling convention, applied after augmentation
        mode (str) -- split to load ('train', 'val' or 'test' in this notebook)
    """
    def __init__(self, 
                 img_folder_dir, 
                 mask_folder_dir, 
                 data_df_path,
                 img_loader=dwi_adc_loader, 
                 mask_loader=mask_loader,
                 augmentation=None, 
                 preprocessing=None,
                 mode='train'
    ):
        self.data_df = pd.read_csv(data_df_path)
        self.img_loader = img_loader
        self.mask_loader = mask_loader
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.mode = mode
        # Bracket access: attribute access (df.name) breaks if a DataFrame attribute shadows the column.
        self.fname_list = self.data_df.loc[self.data_df["split_811"] == self.mode, "name"].values
        
        self.img_path_arr, self.mask_path_arr = find_aihub_img_mask_paths(img_folder_dir, mask_folder_dir, self.fname_list)
        # Random augmentation only makes sense for the training split.
        if self.mode != 'train':
            self.augmentation = None
        
    def __getitem__(self, index):
        """Load slice ``index`` and return ``(image, mask)`` after the optional transforms."""
        image = self.img_loader(self.img_path_arr[index])
        mask = self.mask_loader(self.mask_path_arr[index])
        # Test for None explicitly: albumentations Compose defines __len__, so an empty
        # pipeline is falsy and a truthiness check would depend on its length.
        if self.augmentation is not None:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        if self.preprocessing is not None:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        
        return image, mask
    
    def __len__(self):
        """Number of paired slices in the selected split."""
        return len(self.img_path_arr)


In [None]:
import PIL.Image as Image
import matplotlib.pyplot as plt

In [None]:
import segmentation_models_pytorch as smp

In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, save_checkpoint prints a message when it saves.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to by save_checkpoint.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf, not np.Inf: the capitalized alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and update the early-stopping state.

        An epoch counts as an improvement unless ``-val_loss < best_score + delta``.
        Otherwise ``counter`` is incremented, and ``early_stop`` becomes True once the
        counter reaches ``patience``. ``model`` is only used by the disabled
        checkpoint calls below.
        """
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            # Checkpointing is disabled here; the training loop saves models itself.
            #self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            #self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model's state_dict to self.path and records val_loss as the new minimum.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
# Test split on the PNG slices under blocks1/dicom_to_png_2d, as normalized tensors for the model.
val_dataset = AIHUB_DWI_ADC_LesionSegDataset(
    '/home/ncp/workspace/blocks1/dicom_to_png_2d/',
    '/home/ncp/workspace/blocks/refined_mask',
    '/home/ncp/workspace/blocks1/aihub_df.csv',
    mode='test',
    preprocessing=get_preprocessing(resize=(256, 256)),
    augmentation=None,
)

In [None]:
# Same test split, resized but not normalized or converted (convert=False), for visual overlays.
vis_val_dataset = AIHUB_DWI_ADC_LesionSegDataset(
    '/home/ncp/workspace/blocks1/dicom_to_png_2d/',
    '/home/ncp/workspace/blocks/refined_mask',
    '/home/ncp/workspace/blocks1/aihub_df.csv',
    mode='test',
    preprocessing=get_preprocessing(resize=(256, 256), convert=False),
    augmentation=None,
)

In [None]:
# Load the best checkpoint of the run in save_path (a U-Net with ResNet-152 encoder, per the folder name).
# The file presumably holds a whole pickled model (torch.save(model, ...)), so only load
# trusted files. NOTE(review): torch>=2.6 defaults to weights_only=True and may need
# weights_only=False here — confirm against the installed torch version.
save_path = "./DWI_ADC_ckpt/2d_ckpt/Unet_resnet152"
best_model = torch.load(f"{save_path}/best_model01.pth")

In [None]:
# Deterministic order (no shuffling) so predictions line up with the dataset index.
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, shuffle=False, batch_size=8)

In [None]:
# Evaluate the loaded checkpoint on the test split.
# DEVICE, loss and metrics are defined here so this cell also works on a fresh kernel;
# before, they were only created in the training cell near the end of the notebook.
# They mirror that cell's settings (Dice loss, IoU at threshold 0.5).
DEVICE = 'cuda'
loss = smp.utils.losses.DiceLoss()
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

test_epoch = smp.utils.train.ValidEpoch(
    model=best_model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
)

logs = test_epoch.run(val_loader)
logs

In [None]:
# Learning curves of the loaded run. The '*_score' columns presumably hold the IoU,
# as written by the training loop at the end of this notebook.
train_history = pd.read_csv(os.path.join(save_path,'results01.csv'))
fig, ax = plt.subplots(1, 2, figsize=(12, 4))

ax[0].set_title('loss')
ax[0].plot(np.array(train_history['train_loss']), 'b', label='train')
ax[0].plot(np.array(train_history['valid_loss']), 'r', label='valid')

ax[1].set_title('IoU')
ax[1].plot(np.array(train_history['train_score']), 'b', label='train')
ax[1].plot(np.array(train_history['valid_score']), 'r', label='valid')

ax[0].set_xlabel('epoch index')
ax[1].set_xlabel('epoch index')
ax[0].legend()
ax[1].legend()
plt.show()

In [None]:
import cv2

In [None]:
def save_arr_to_png(im_2d, save_point, filename):
    """Write array ``im_2d`` to ``<save_point>/<filename>.png`` via PIL."""
    target_path = os.path.join(save_point, filename + '.png')
    png = Image.fromarray(im_2d)
    png.save(target_path)

In [None]:
# Reload the same best checkpoint before the prediction-export cells below.
# NOTE(review): duplicate of the earlier load cell; the file is presumably a whole pickled
# model, so only load trusted files (torch>=2.6 may need weights_only=False).
save_path = "./DWI_ADC_ckpt/2d_ckpt/Unet_resnet152"
best_model = torch.load(f"{save_path}/best_model01.pth")

In [None]:
# Test split (original 2D PNG slices), as tensors, for exporting predicted masks.
val_dataset = AIHUB_DWI_ADC_LesionSegDataset(
    '/home/ncp/workspace/blocks1/dicom_to_png_2d/',
    '/home/ncp/workspace/blocks/refined_mask',
    '/home/ncp/workspace/blocks1/aihub_df.csv',
    mode='test',
    preprocessing=get_preprocessing(resize=(256, 256)),
    augmentation=None,
)

In [None]:
# Unshuffled so the prediction order matches val_dataset.img_path_arr.
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, shuffle=False, batch_size=8)

In [None]:
# Run the model over val_loader and collect the outputs on the CPU.
# Inputs go to the device the loaded model lives on (DEVICE is only defined further
# down on a top-to-bottom run). Labels are not needed for prediction.
model_device = next(best_model.parameters()).device
predict_masks = []
for images, _ in tqdm(val_loader):
    pr_mask = best_model.predict(images.to(model_device))
    predict_masks.append(pr_mask.cpu().numpy())

In [None]:
# Stack the per-batch outputs into one array of shape (N, H, W); this assumes each
# batch is (B, 1, H, W), i.e. a single output channel. squeeze(axis=1) drops only that
# channel, whereas a bare np.squeeze would also drop N when there is a single slice.
# The isinstance guard makes re-running the cell a no-op instead of corrupting the array.
if isinstance(predict_masks, list):
    predict_masks = np.concatenate(predict_masks, axis=0).squeeze(axis=1)

In [None]:
# Expected (num_slices, 256, 256) — TODO confirm after stacking.
np.shape(predict_masks)

In [None]:
# Scale the model outputs (assumed to be in [0, 1]) to 8-bit grayscale; the cast truncates.
predict_masks_norm = np.multiply(predict_masks, 255).astype(np.uint8)

In [None]:
# Save each case's predicted slices as PNGs: <save_dir>/<case>/pred_masks/000.png, 001.png, ...
# Slices are counted per case from the case folder of each DWI path
# (<img_folder_dir>/<case>/dwi/<slice>) by exact name. A substring test such as
# `fname in path` would also count slices of any other case whose path merely contains
# `fname` (e.g. 'case1' inside '.../case10/...') and shift every later case.
# img_path_arr is grouped by case in fname_list order, so consecutive slices belong together.
from collections import Counter

assert len(predict_masks_norm) == len(val_dataset), (
    'predict_masks_norm does not match val_dataset; re-run the prediction cells for this split')

save_dir = '/home/ncp/workspace/blocks1/dicom_to_png_2d/'
slices_per_case = Counter(
    os.path.basename(os.path.dirname(os.path.dirname(dwi_path)))
    for dwi_path, _ in val_dataset.img_path_arr
)
len_cnt = 0
for fname in tqdm(val_dataset.fname_list):
    z = slices_per_case[fname]
    if z == 0:
        continue  # case has no paired slices in val_dataset; don't create an empty folder
    save_point = os.path.join(save_dir, fname, 'pred_masks')
    gen_new_dir(save_point)
    for i, slice_img in enumerate(predict_masks_norm[len_cnt:len_cnt + z]):
        save_arr_to_png(slice_img, save_point, str(i).zfill(3))
    len_cnt += z

In [None]:
# ---- Validation split: predict and save 2D PNG masks ----

In [None]:
# Validation split (original 2D PNG slices), as tensors, for exporting predicted masks.
val_dataset = AIHUB_DWI_ADC_LesionSegDataset(
    '/home/ncp/workspace/blocks1/dicom_to_png_2d/',
    '/home/ncp/workspace/blocks/refined_mask',
    '/home/ncp/workspace/blocks1/aihub_df.csv',
    mode='val',
    preprocessing=get_preprocessing(resize=(256, 256)),
    augmentation=None,
)

In [None]:
# Unshuffled so the prediction order matches val_dataset.img_path_arr.
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, shuffle=False, batch_size=8)

In [None]:
# Run the model over val_loader and collect the outputs on the CPU.
# Inputs go to the device the loaded model lives on (DEVICE is only defined further
# down on a top-to-bottom run). Labels are not needed for prediction.
model_device = next(best_model.parameters()).device
predict_masks = []
for images, _ in tqdm(val_loader):
    pr_mask = best_model.predict(images.to(model_device))
    predict_masks.append(pr_mask.cpu().numpy())

In [None]:
# Stack the per-batch outputs into one array of shape (N, H, W); this assumes each
# batch is (B, 1, H, W), i.e. a single output channel. squeeze(axis=1) drops only that
# channel, whereas a bare np.squeeze would also drop N when there is a single slice.
# The isinstance guard makes re-running the cell a no-op instead of corrupting the array.
if isinstance(predict_masks, list):
    predict_masks = np.concatenate(predict_masks, axis=0).squeeze(axis=1)

In [None]:
# Expected (num_slices, 256, 256) — TODO confirm after stacking.
np.shape(predict_masks)

In [None]:
# Scale the model outputs (assumed to be in [0, 1]) to 8-bit grayscale; the cast truncates.
predict_masks_norm = np.multiply(predict_masks, 255).astype(np.uint8)

In [None]:
# Save each case's predicted slices as PNGs: <save_dir>/<case>/pred_masks/000.png, 001.png, ...
# Slices are counted per case from the case folder of each DWI path
# (<img_folder_dir>/<case>/dwi/<slice>) by exact name. A substring test such as
# `fname in path` would also count slices of any other case whose path merely contains
# `fname` (e.g. 'case1' inside '.../case10/...') and shift every later case.
# img_path_arr is grouped by case in fname_list order, so consecutive slices belong together.
from collections import Counter

assert len(predict_masks_norm) == len(val_dataset), (
    'predict_masks_norm does not match val_dataset; re-run the prediction cells for this split')

save_dir = '/home/ncp/workspace/blocks1/dicom_to_png_2d/'
slices_per_case = Counter(
    os.path.basename(os.path.dirname(os.path.dirname(dwi_path)))
    for dwi_path, _ in val_dataset.img_path_arr
)
len_cnt = 0
for fname in tqdm(val_dataset.fname_list):
    z = slices_per_case[fname]
    if z == 0:
        continue  # case has no paired slices in val_dataset; don't create an empty folder
    save_point = os.path.join(save_dir, fname, 'pred_masks')
    gen_new_dir(save_point)
    for i, slice_img in enumerate(predict_masks_norm[len_cnt:len_cnt + z]):
        save_arr_to_png(slice_img, save_point, str(i).zfill(3))
    len_cnt += z

In [None]:
# ---- Training split: predict and save 2D PNG masks ----

In [None]:
# Training split (original 2D PNG slices), as tensors, for exporting predicted masks.
# mode='train' also disables nothing here because augmentation is None.
val_dataset = AIHUB_DWI_ADC_LesionSegDataset(
    '/home/ncp/workspace/blocks1/dicom_to_png_2d/',
    '/home/ncp/workspace/blocks/refined_mask',
    '/home/ncp/workspace/blocks1/aihub_df.csv',
    mode='train',
    preprocessing=get_preprocessing(resize=(256, 256)),
    augmentation=None,
)

In [None]:
# Unshuffled so the prediction order matches val_dataset.img_path_arr.
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, shuffle=False, batch_size=8)

In [None]:
# Run the model over val_loader and collect the outputs on the CPU.
# Inputs go to the device the loaded model lives on (DEVICE is only defined further
# down on a top-to-bottom run). Labels are not needed for prediction.
model_device = next(best_model.parameters()).device
predict_masks = []
for images, _ in tqdm(val_loader):
    pr_mask = best_model.predict(images.to(model_device))
    predict_masks.append(pr_mask.cpu().numpy())

In [None]:
# Stack the per-batch outputs into one array of shape (N, H, W); this assumes each
# batch is (B, 1, H, W), i.e. a single output channel. squeeze(axis=1) drops only that
# channel, whereas a bare np.squeeze would also drop N when there is a single slice.
# The isinstance guard makes re-running the cell a no-op instead of corrupting the array.
if isinstance(predict_masks, list):
    predict_masks = np.concatenate(predict_masks, axis=0).squeeze(axis=1)

In [None]:
# Expected (num_slices, 256, 256) — TODO confirm after stacking.
np.shape(predict_masks)

In [None]:
# Scale the model outputs (assumed to be in [0, 1]) to 8-bit grayscale; the cast truncates.
predict_masks_norm = np.multiply(predict_masks, 255).astype(np.uint8)

In [None]:
# Save each case's predicted slices as PNGs: <save_dir>/<case>/pred_masks/000.png, 001.png, ...
# Slices are counted per case from the case folder of each DWI path
# (<img_folder_dir>/<case>/dwi/<slice>) by exact name. A substring test such as
# `fname in path` would also count slices of any other case whose path merely contains
# `fname` (e.g. 'case1' inside '.../case10/...') and shift every later case.
# img_path_arr is grouped by case in fname_list order, so consecutive slices belong together.
from collections import Counter

assert len(predict_masks_norm) == len(val_dataset), (
    'predict_masks_norm does not match val_dataset; re-run the prediction cells for this split')

save_dir = '/home/ncp/workspace/blocks1/dicom_to_png_2d/'
slices_per_case = Counter(
    os.path.basename(os.path.dirname(os.path.dirname(dwi_path)))
    for dwi_path, _ in val_dataset.img_path_arr
)
len_cnt = 0
for fname in tqdm(val_dataset.fname_list):
    z = slices_per_case[fname]
    if z == 0:
        continue  # case has no paired slices in val_dataset; don't create an empty folder
    save_point = os.path.join(save_dir, fname, 'pred_masks')
    gen_new_dir(save_point)
    for i, slice_img in enumerate(predict_masks_norm[len_cnt:len_cnt + z]):
        save_arr_to_png(slice_img, save_point, str(i).zfill(3))
    len_cnt += z

In [None]:
################### Resampled data: save per-case predicted mask stacks as .npy (for 3D use) #####################

In [None]:
# Device on which smp epoch runners move the model and batches.
DEVICE = "cuda"

In [None]:
# Test split on the resampled slice export (blocks2/*_resample), as tensors.
val_dataset = AIHUB_DWI_ADC_LesionSegDataset(
    '/home/ncp/workspace/blocks2/dicom_to_png_2d_resample',
    '/home/ncp/workspace/blocks2/refined_mask_resample_2d',
    '/home/ncp/workspace/blocks1/aihub_df.csv',
    mode='test',
    preprocessing=get_preprocessing(resize=(256, 256)),
    augmentation=None,
)

In [None]:
# Unshuffled so the prediction order matches val_dataset.img_path_arr.
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, shuffle=False, batch_size=8)

In [None]:
# Run the model over val_loader and collect the outputs on the CPU.
# Inputs go to the device the loaded model lives on; labels are not needed for prediction.
model_device = next(best_model.parameters()).device
predict_masks = []
for images, _ in tqdm(val_loader):
    pr_mask = best_model.predict(images.to(model_device))
    predict_masks.append(pr_mask.cpu().numpy())

In [None]:
# Stack the per-batch outputs into one array of shape (N, H, W); this assumes each
# batch is (B, 1, H, W), i.e. a single output channel. squeeze(axis=1) drops only that
# channel, whereas a bare np.squeeze would also drop N when there is a single slice.
# The isinstance guard makes re-running the cell a no-op instead of corrupting the array.
if isinstance(predict_masks, list):
    predict_masks = np.concatenate(predict_masks, axis=0).squeeze(axis=1)

In [None]:
# Expected (num_slices, 256, 256) — TODO confirm after stacking.
np.shape(predict_masks)

In [None]:
# Save each case's stack of predicted slices (raw model outputs, not thresholded) as
# <save_dir>/<case>.npy. Slices are counted per case by the exact case-folder name of
# each DWI path (<img_folder_dir>/<case>/dwi/<slice>). A substring test such as
# `fname in path` would also count slices of other cases whose path contains `fname`.
# (The no-op `predict_masks.dtype` expression that opened this cell was removed:
# it was not the last line, so it never displayed anything.)
from collections import Counter

assert len(predict_masks) == len(val_dataset), (
    'predict_masks does not match val_dataset; re-run the prediction cells for this split')

save_dir = '/home/ncp/workspace/blocks1/pred_mask_resample'
gen_new_dir(save_dir)
slices_per_case = Counter(
    os.path.basename(os.path.dirname(os.path.dirname(dwi_path)))
    for dwi_path, _ in val_dataset.img_path_arr
)
len_cnt = 0
for fname in tqdm(val_dataset.fname_list):
    z = slices_per_case[fname]
    if z == 0:
        continue  # no paired slices for this case in val_dataset
    np.save(os.path.join(save_dir, fname + '.npy'), predict_masks[len_cnt:len_cnt + z])
    len_cnt += z

In [None]:
# Validation split on the resampled slice export (blocks2/*_resample), as tensors.
val_dataset = AIHUB_DWI_ADC_LesionSegDataset(
    '/home/ncp/workspace/blocks2/dicom_to_png_2d_resample',
    '/home/ncp/workspace/blocks2/refined_mask_resample_2d',
    '/home/ncp/workspace/blocks1/aihub_df.csv',
    mode='val',
    preprocessing=get_preprocessing(resize=(256, 256)),
    augmentation=None,
)

In [None]:
# Unshuffled so the prediction order matches val_dataset.img_path_arr.
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, shuffle=False, batch_size=8)

In [None]:
# Run the model over val_loader and collect the outputs on the CPU.
# Inputs go to the device the loaded model lives on; labels are not needed for prediction.
model_device = next(best_model.parameters()).device
predict_masks = []
for images, _ in tqdm(val_loader):
    pr_mask = best_model.predict(images.to(model_device))
    predict_masks.append(pr_mask.cpu().numpy())

In [None]:
# Stack the per-batch outputs into one array of shape (N, H, W); this assumes each
# batch is (B, 1, H, W), i.e. a single output channel. squeeze(axis=1) drops only that
# channel, whereas a bare np.squeeze would also drop N when there is a single slice.
# The isinstance guard makes re-running the cell a no-op instead of corrupting the array.
if isinstance(predict_masks, list):
    predict_masks = np.concatenate(predict_masks, axis=0).squeeze(axis=1)

In [None]:
# Expected (num_slices, 256, 256) — TODO confirm after stacking.
np.shape(predict_masks)

In [None]:
# Element dtype of the stacked predictions (saved as-is to .npy below).
np.result_type(predict_masks)

In [None]:
# Save each case's stack of predicted slices (raw model outputs, not thresholded) as
# <save_dir>/<case>.npy. Slices are counted per case by the exact case-folder name of
# each DWI path (<img_folder_dir>/<case>/dwi/<slice>). A substring test such as
# `fname in path` would also count slices of other cases whose path contains `fname`.
from collections import Counter

assert len(predict_masks) == len(val_dataset), (
    'predict_masks does not match val_dataset; re-run the prediction cells for this split')

save_dir = '/home/ncp/workspace/blocks1/pred_mask_resample'
gen_new_dir(save_dir)
slices_per_case = Counter(
    os.path.basename(os.path.dirname(os.path.dirname(dwi_path)))
    for dwi_path, _ in val_dataset.img_path_arr
)
len_cnt = 0
for fname in tqdm(val_dataset.fname_list):
    z = slices_per_case[fname]
    if z == 0:
        continue  # no paired slices for this case in val_dataset
    np.save(os.path.join(save_dir, fname + '.npy'), predict_masks[len_cnt:len_cnt + z])
    len_cnt += z

In [None]:
# Training split on the resampled slice export (blocks2/*_resample), as tensors.
val_dataset = AIHUB_DWI_ADC_LesionSegDataset(
    '/home/ncp/workspace/blocks2/dicom_to_png_2d_resample',
    '/home/ncp/workspace/blocks2/refined_mask_resample_2d',
    '/home/ncp/workspace/blocks1/aihub_df.csv',
    mode='train',
    preprocessing=get_preprocessing(resize=(256, 256)),
    augmentation=None,
)

In [None]:
# Unshuffled so the prediction order matches val_dataset.img_path_arr.
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, shuffle=False, batch_size=8)

In [None]:
# Run the model over val_loader and collect the outputs on the CPU.
# Inputs go to the device the loaded model lives on; labels are not needed for prediction.
model_device = next(best_model.parameters()).device
predict_masks = []
for images, _ in tqdm(val_loader):
    pr_mask = best_model.predict(images.to(model_device))
    predict_masks.append(pr_mask.cpu().numpy())

In [None]:
# Stack the per-batch outputs into one array of shape (N, H, W); this assumes each
# batch is (B, 1, H, W), i.e. a single output channel. squeeze(axis=1) drops only that
# channel, whereas a bare np.squeeze would also drop N when there is a single slice.
# The isinstance guard makes re-running the cell a no-op instead of corrupting the array.
if isinstance(predict_masks, list):
    predict_masks = np.concatenate(predict_masks, axis=0).squeeze(axis=1)

In [None]:
# Expected (num_slices, 256, 256) — TODO confirm after stacking.
np.shape(predict_masks)

In [None]:
# Element dtype of the stacked predictions (saved as-is to .npy below).
np.result_type(predict_masks)

In [None]:
# Save each case's stack of predicted slices (raw model outputs, not thresholded) as
# <save_dir>/<case>.npy. Slices are counted per case by the exact case-folder name of
# each DWI path (<img_folder_dir>/<case>/dwi/<slice>). A substring test such as
# `fname in path` would also count slices of other cases whose path contains `fname`.
from collections import Counter

assert len(predict_masks) == len(val_dataset), (
    'predict_masks does not match val_dataset; re-run the prediction cells for this split')

save_dir = '/home/ncp/workspace/blocks1/pred_mask_resample'
gen_new_dir(save_dir)
slices_per_case = Counter(
    os.path.basename(os.path.dirname(os.path.dirname(dwi_path)))
    for dwi_path, _ in val_dataset.img_path_arr
)
len_cnt = 0
for fname in tqdm(val_dataset.fname_list):
    z = slices_per_case[fname]
    if z == 0:
        continue  # no paired slices for this case in val_dataset
    np.save(os.path.join(save_dir, fname + '.npy'), predict_masks[len_cnt:len_cnt + z])
    len_cnt += z

In [None]:
# Overlay predictions on slices 1950-1959 of vis_val_dataset.
# Colors: red = false positive, green = true positive, blue = false negative.
# Model outputs are thresholded at 0.5 (the IoU metric's threshold). Casting the
# probabilities with astype(np.uint8) would truncate every value below 1.0 to 0.
PRED_THRESHOLD = 0.5
assert len(predict_masks) == len(vis_val_dataset), (
    'predict_masks does not correspond to vis_val_dataset; re-run the test-split prediction cells first')
for i in range(1950, 1960):
    image, mask = vis_val_dataset[i]
    image_rgb = visualize_grayscale(np.squeeze(image[:, :, 0]))
    predict = (predict_masks[i] > PRED_THRESHOLD).astype(np.uint8)[:, :, np.newaxis]
    intersect_mask = mask * predict
    only_mask = np.where((mask - intersect_mask) == 1, 1, 0)
    only_pred = np.where((predict - intersect_mask) == 1, 1, 0)
    tp_np_mask = np.concatenate([only_pred, intersect_mask, only_mask], axis=-1) * 255
    vis = (image_rgb / 2 + tp_np_mask / 2).astype(np.uint8)
    visualize(image=image_rgb, result=tp_np_mask, visualize=vis)

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
# code reference: https://gist.github.com/gergf/acd8e3fd23347cb9e6dc572f00c63d79
def dice(true_mask, pred_mask, non_seg_score=0.0):
    """
        Computes the Dice coefficient 2*|A & B| / (|A| + |B|) of two binary masks.
        Args:
            true_mask : Array-like of arbitrary shape; nonzero entries are foreground.
            pred_mask : Array-like with the same shape as true_mask.
            non_seg_score : Value returned when both masks are empty, where Dice is undefined.

        Returns:
            A scalar representing the Dice coefficient between the two segmentations.

        Raises:
            AssertionError : if the two masks have different shapes.
    """
    # Convert before the shape check so array-likes (e.g. nested lists) are accepted too.
    true_mask = np.asarray(true_mask).astype(np.bool_)
    pred_mask = np.asarray(pred_mask).astype(np.bool_)
    assert true_mask.shape == pred_mask.shape

    # If both segmentations are all zero, Dice is undefined: return non_seg_score (0.0 by default).
    im_sum = true_mask.sum() + pred_mask.sum()
    if im_sum == 0:
        return non_seg_score

    # Compute Dice coefficient
    intersection = np.logical_and(true_mask, pred_mask)
    return 2. * intersection.sum() / im_sum

In [None]:
# Mean Dice over the slices where both the thresholded prediction and the ground
# truth contain lesion pixels; slices where either is empty are excluded.
# Thresholding at 0.5 matches the IoU metric; astype(np.uint8) alone would truncate
# probabilities below 1.0 to 0.
PRED_THRESHOLD = 0.5
assert len(predict_masks) == len(vis_val_dataset), (
    'predict_masks does not correspond to vis_val_dataset; re-run the test-split prediction cells first')
dice_avg = 0
cnt = 0
for i in range(len(vis_val_dataset)):
    image, mask = vis_val_dataset[i]
    pred_bin = (predict_masks[i] > PRED_THRESHOLD).astype(np.uint8)
    if pred_bin.any() and mask.any():
        dice_avg += dice(np.squeeze(mask.astype(np.uint8)), pred_bin)
        cnt += 1
# NaN rather than ZeroDivisionError when no slice qualifies.
dice_avg = dice_avg / cnt if cnt else float('nan')

In [None]:
# Mean Dice over slices where both the prediction and the ground truth are non-empty.
dice_avg

In [None]:
# Per-slice [dataset_index, dice] pairs for the slices where both the thresholded
# prediction and the ground truth are non-empty (same selection as the mean above).
PRED_THRESHOLD = 0.5
assert len(predict_masks) == len(vis_val_dataset), (
    'predict_masks does not correspond to vis_val_dataset; re-run the test-split prediction cells first')
dice_score_list = []
cnt = 0
for i in range(len(vis_val_dataset)):
    image, mask = vis_val_dataset[i]
    pred_bin = (predict_masks[i] > PRED_THRESHOLD).astype(np.uint8)
    if pred_bin.any() and mask.any():
        dice_score_list.append([i, dice(np.squeeze(mask.astype(np.uint8)), pred_bin)])
        cnt += 1

In [None]:
# Rows are [dataset_index, dice]; a float array, so the indices need int() before indexing.
dice_score_arr = np.asarray(dice_score_list)

In [None]:
# Column 0: vis_val_dataset index of the slice; column 1: its Dice score.
val_dataset_idx, val_dataset_dice = (dice_score_arr[:, col] for col in (0, 1))

In [None]:
# vis_val_dataset index of the slice with the lowest Dice.
min_idx = int(val_dataset_idx[val_dataset_dice.argmin()])

In [None]:
# Pair each Dice score with its vis_val_dataset index (from val_dataset_idx), not with its
# position in val_dataset_dice. dice_score_list skips slices where either mask is empty,
# so positions and dataset indices diverge after the first skipped slice.
val_dataset_dice_w_idx = [[d, int(idx)] for idx, d in zip(val_dataset_idx, val_dataset_dice)]

In [None]:
# Ascending by Dice score (first element of each pair).
val_dataset_dice_w_idx = sorted(val_dataset_dice_w_idx, key=lambda pair: pair[0])

In [None]:
# Index stored with the 1101st-lowest Dice score (raises IndexError if fewer scored slices).
val_dataset_dice_w_idx[1100][-1]

In [None]:
# Full sorted [dice, index] list — long output; consider slicing before sharing.
val_dataset_dice_w_idx

In [None]:
# Overlay the slices ranked 600..1199 by ascending Dice (red = FP, green = TP, blue = FN).
# The range is clipped to the number of scored slices to avoid an IndexError.
# Predictions are thresholded at 0.5; astype(np.uint8) alone would truncate probabilities to 0.
PRED_THRESHOLD = 0.5
assert len(predict_masks) == len(vis_val_dataset), (
    'predict_masks does not correspond to vis_val_dataset; re-run the test-split prediction cells first')
print(np.min(val_dataset_dice))
for N in range(600, min(1200, len(val_dataset_dice_w_idx))):
    i = val_dataset_dice_w_idx[N][1]
    print(i)
    image, mask = vis_val_dataset[i]
    image_rgb = visualize_grayscale(np.squeeze(image[:, :, 0]))
    predict = (predict_masks[i] > PRED_THRESHOLD).astype(np.uint8)[:, :, np.newaxis]
    intersect_mask = mask * predict
    only_mask = np.where((mask - intersect_mask) == 1, 1, 0)
    only_pred = np.where((predict - intersect_mask) == 1, 1, 0)
    tp_np_mask = np.concatenate([only_pred, intersect_mask, only_mask], axis=-1) * 255
    vis = (image_rgb / 2 + tp_np_mask / 2).astype(np.uint8)
    visualize(image=image_rgb, result=tp_np_mask, visualize=vis)

In [None]:
# vis_val_dataset index of the slice with the highest Dice.
max_idx = int(val_dataset_idx[val_dataset_dice.argmax()])

In [None]:
# Overlay the best-Dice slice (red = FP, green = TP, blue = FN).
# Predictions are thresholded at 0.5; astype(np.uint8) alone would truncate probabilities to 0.
PRED_THRESHOLD = 0.5
assert len(predict_masks) == len(vis_val_dataset), (
    'predict_masks does not correspond to vis_val_dataset; re-run the test-split prediction cells first')
print(np.max(val_dataset_dice))
i = max_idx
image, mask = vis_val_dataset[i]
image_rgb = visualize_grayscale(np.squeeze(image[:, :, 0]))
predict = (predict_masks[i] > PRED_THRESHOLD).astype(np.uint8)[:, :, np.newaxis]
intersect_mask = mask * predict
only_mask = np.where((mask - intersect_mask) == 1, 1, 0)
only_pred = np.where((predict - intersect_mask) == 1, 1, 0)
tp_np_mask = np.concatenate([only_pred, intersect_mask, only_mask], axis=-1) * 255
vis = (image_rgb / 2 + tp_np_mask / 2).astype(np.uint8)
visualize(image=image_rgb, result=tp_np_mask, visualize=vis)

In [None]:
# Fine-tuning configuration: resumes from the trial-01 checkpoint in save_path and
# logs/saves under the trial number below.
import warnings

save_path = "./ADC_ckpt/2d_ckpt/Deeplabv3plus_b0"
gen_new_dir(save_path)
###############################
trial = 2
n_epoches = 10000
LR = 0.0001
LR_DECREASE = 1e-5
lr_decrease_epoch = 70
BATCH_SIZE = 64
patience= 15
###############################
# NOTE(review): train_dataset is never built in this notebook; it must come from an
# earlier session. Fail with a clear message instead of a bare NameError.
if 'train_dataset' not in globals():
    raise NameError('train_dataset is not defined: build the training dataset before running this cell')
# On a top-to-bottom run, val_dataset here is the last one built above (mode='train'),
# which would make early stopping and model selection use training data.
if getattr(val_dataset, 'mode', None) != 'val':
    warnings.warn(f"val_dataset.mode is {getattr(val_dataset, 'mode', None)!r}, expected 'val'")

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, 
                                           shuffle=True, drop_last=True)

val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=8, 
                                         shuffle=False)
# Encoder names tried previously:
#timm-resnest101e
#timm-efficientnet-b0
#timm-efficientnet-b2
#timm-efficientnet-b5
#se_resnext50_32x4d
#se_resnext101_32x4d
#resnet152
#densenet121
#densenet169

# Informational only: the model below is loaded from a checkpoint, not built from these.
ENCODER = 'timm-efficientnet-b0'
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = 'cuda'

# Resume from the existing trial-01 checkpoint (a whole pickled model; only load trusted files).
model = torch.load(os.path.join(save_path, 'best_model01.pth'))

loss = smp.utils.losses.DiceLoss()

metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

# A single parameter group holding all model parameters.
optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=LR),
])

In [None]:
# smp epoch runners: each .run(loader) makes one pass over the loader (training with
# optimizer steps, or evaluation) and returns a logs dict keyed by loss/metric name,
# e.g. 'dice_loss' and 'iou_score'.
train_epoch = smp.utils.train.TrainEpoch(
    model,
    optimizer=optimizer,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)

In [None]:
# Train for up to n_epoches epochs, stopping early when the validation Dice loss stops
# improving. Logs per-epoch loss/IoU to results<trial>.csv and keeps both the best-IoU
# model and the latest model.
trial_tag = str(trial).zfill(2)
results_csv = os.path.join(save_path, f'results{trial_tag}.csv')
best_model_path = os.path.join(save_path, f'best_model{trial_tag}.pth')
final_model_path = os.path.join(save_path, f'final_model{trial_tag}.pth')

max_score = 0
with open(results_csv, 'w') as f:
    f.write('epoch,train_loss,train_score,valid_loss,valid_score\n')

early_stopping = EarlyStopping(patience=patience, verbose=True)

for epoch in range(0, n_epoches):
    
    print(f'\nEpoch: {epoch}')
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(val_loader)
    
    # Reopen in append mode each epoch so the log survives an interrupted run.
    with open(results_csv, 'a') as f:
        f.write('%03d,%0.6f,%0.6f,%0.6f,%0.6f\n' % (
            (epoch + 1),
            train_logs['dice_loss'],
            train_logs['iou_score'],
            valid_logs['dice_loss'],
            valid_logs['iou_score'],
        ))
    
    # Keep the model with the best validation IoU so far.
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, best_model_path)
        print('New Record!')
        
    torch.save(model, final_model_path)
    
    early_stopping(valid_logs['dice_loss'], model)
    if early_stopping.early_stop:
        print("Early stopping")
        break
    
    if epoch == lr_decrease_epoch:
        # The optimizer has one group covering all model parameters, so this lowers the
        # learning rate of the whole model, not only the decoder.
        for param_group in optimizer.param_groups:
            param_group['lr'] = LR_DECREASE
        print(f'Decrease learning rate to {LR_DECREASE}!')