In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'

import cv2
import glob
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torch.nn.functional as F
import torchvision.utils as vutils
import pickle
from PIL import ImageFile
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR

from unet import UNet
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

from sklearn.metrics import recall_score
from sklearn.metrics import confusion_matrix

from albumentations import (
    HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine,
    IAASharpen, IAAEmboss, RandomContrast, RandomBrightness, Flip, OneOf, Compose,
    RandomCrop, Normalize, Resize
)

ImageFile.LOAD_TRUNCATED_IMAGES = True

import matplotlib.pyplot as plt
def show(img):
    npimg = img.detach().numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')
    plt.show()

In [2]:
niter=10
outf='./saved_model/'

IS_TRAINING = True
n_channel = 3
n_disc = 16
n_gen = 64
n_encode = 64
n_l = 10
n_z = 50
img_size = 128
IMG_HEIGHT, IMG_WIDTH = 128, 128
batchSize = 32
use_cuda = torch.cuda.is_available()
n_age = int(n_z/n_l)
n_gender = int(n_z/2)

n_class_age = 6
n_repeat = 2

PATH_DATA_TRAIN = '/home/quang/working/fundus_segmentation/data/segmentation_doctor/patch_data_train_DeepLab/'
PATH_DATA_TEST = '/home/quang/working/fundus_segmentation/data/segmentation_doctor/patch_data_test_DeepLab/'

PATH_DATA_MASKPRED_TRAIN = '/home/quang/working/fundus_segmentation/data/segmentation_doctor/patch_data_mask_pred_train_DeepLab/'
PATH_DATA_MASKPRED_TEST = '/home/quang/working/fundus_segmentation/data/segmentation_doctor/patch_data_mask_pred_test_DeepLab/'

def show(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')
    plt.show()
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print (device)

cuda


In [3]:
list_paths_img_train = glob.glob(PATH_DATA_TRAIN + '*.jpg')
list_paths_mask_train = [x.split('.')[0] + '.labels.tif' for x in list_paths_img_train]
   
print (len(list_paths_img_train))

list_paths_img_test = glob.glob(PATH_DATA_TEST + '*.jpg')
list_paths_mask_test = [x.split('.')[0] + '.labels.tif' for x in list_paths_img_test]
   
print (len(list_paths_img_train), len(list_paths_img_test))

10913
10913 1087


In [4]:
class ImageAug:
    def __init__(self, aug):
        self.aug=aug

    def __call__(self, img):
        img = self.aug(image=img)['image']
        return img

class FundusDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, PATH_IMG, PATH_MASK, PATH_DATA_MASKPRED, transform=None, transform_torch=None, toTensor=None, train_transform_torch_mask_pred=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
#         self.root_dir = root_dir
           
#         self.df_paired = df_paired.loc[df_paired['train'] == train]
        self.PATH_IMG = PATH_IMG
        self.PATH_MASK = PATH_MASK
        self.PATH_DATA_MASKPRED = PATH_DATA_MASKPRED
        self.ToTensor = toTensor
        self.transform = transform
        self.transform_torch = transform_torch
        self.train_transform_torch_mask_pred = train_transform_torch_mask_pred

    def __len__(self):
#         return 100
        return len(self.PATH_IMG)

    def __getitem__(self, idx):
        file_name_temp = self.PATH_IMG[idx].split('/')[-1]
        image = cv2.imread(self.PATH_IMG[idx])
        mask = cv2.imread(self.PATH_MASK[idx],0)
        mask = np.where(mask > 0, 1., 0.)
        mask = np.expand_dims(mask, axis=-1)

        mask_pred = cv2.imread(self.PATH_DATA_MASKPRED + file_name_temp,0)
        mask_pred = np.expand_dims(mask_pred, axis=-1)

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)   
        
        if self.transform:
            img_auged = self.transform(image=image, mask=mask, mask_pred=mask_pred)
            image = img_auged['image']
            mask = img_auged['mask']
            mask_pred = img_auged['mask_pred']
            image = self.transform_torch(image)

            mask = self.ToTensor(mask).float()
            mask_pred = self.train_transform_torch_mask_pred(mask_pred)
        sample = {'image': image, 'mask': mask, 'mask_pred': mask_pred, 'id': file_name_temp}

        return sample
    
train_album = Compose([
    Resize(int(IMG_HEIGHT*1.), int(IMG_WIDTH*1.)), 
    RandomCrop(IMG_HEIGHT, IMG_HEIGHT),
    HorizontalFlip(),
    ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=.75)
], additional_targets = {'image0': 'image', 'mask_pred': 'mask'})

train_transform_torch = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

train_transform_torch_mask_pred = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

test_album = Compose([
    Resize(IMG_HEIGHT, IMG_WIDTH)
], additional_targets = {'image0': 'image', 'mask_pred': 'mask'})

test_transform_torch = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])


train_dataset = FundusDataset(list_paths_img_train, list_paths_mask_train, PATH_DATA_MASKPRED_TRAIN, transform=train_album, transform_torch=train_transform_torch, 
                              toTensor=transforms.ToTensor(), train_transform_torch_mask_pred=train_transform_torch_mask_pred)

dataloader = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=batchSize, shuffle=True,
                                             num_workers=2, drop_last = True)

valid_dataset = FundusDataset(list_paths_img_test, list_paths_mask_test, PATH_DATA_MASKPRED_TEST, transform=test_album, transform_torch=test_transform_torch, 
                              toTensor=transforms.ToTensor(), train_transform_torch_mask_pred=train_transform_torch_mask_pred)

dataloader_valid = torch.utils.data.DataLoader(valid_dataset,
                                             batch_size=32, shuffle=False,
                                             num_workers=2)

In [6]:
def show(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')
    plt.show()

In [9]:
def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

In [10]:
if use_cuda:
#     net = UNet(n_channels=3, n_classes=1).cuda()
    net = nn.DataParallel(UNet(n_channels=4, n_classes=1)).cuda()
else:
    net = UNet(n_channels=3, n_classes=1)

In [12]:
optimizerE = optim.Adam(net.parameters(),lr=0.01, betas=(0.9, 0.995))
scheduler = StepLR(optimizerE, step_size=17, gamma=0.1)
# optimizerE = optim.SGD(net.parameters(),lr=0.01, momentum=0.9, weight_decay=0.0005)

In [14]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss

In [15]:
# Focal Tversky loss, brought to you by:  https://github.com/nabsabraham/focal-tversky-unet
def tversky(y_true, y_pred, smooth=1e-6):
    y_true_pos = y_true.view(-1)
    y_pred_pos = y_pred.view(-1)
    true_pos = (y_true_pos * y_pred_pos).sum()
    false_neg = (y_true_pos * (1-y_pred_pos)).sum()
    false_pos = ((1-y_true_pos)*y_pred_pos).sum()
    alpha = 0.75
    return (true_pos + smooth)/(true_pos + alpha*false_neg + (1-alpha)*false_pos + smooth)

def tversky_loss(y_true, y_pred):
    return 1 - tversky(y_true,y_pred)

def focal_tversky_loss(y_true,y_pred):
    pt_1 = tversky(y_true, y_pred)
    gamma = 0.6
#     return tf.keras.backend.pow((1-pt_1), gamma)
    return (1-pt_1).pow(gamma)

In [16]:
def dice_loss(input, target):
    smooth = 1.

    iflat = input.view(-1)
    tflat = target.view(-1)
    intersection = (iflat * tflat).sum()
    
    return 1 - ((2. * intersection + smooth) /
              (iflat.sum() + tflat.sum() + smooth))

def dice_loss_np(input, target):
    smooth = 1.

    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()
    
    return 1 - ((2. * intersection + smooth) /
              (iflat.sum() + tflat.sum() + smooth))

# dice_loss = DiceLoss().to(device)
focal_loss = FocalLoss().to(device)
# bounary_loss = BoundaryLoss(1).to(device)

In [17]:
criterion = nn.BCELoss()

niter = 25

requires_grad(net, True)
net.train()

LIST_DICE_VAL = []
LIST_BCE_TRAIN = []
LIST_DICE_LOSS_TRAIN = []

if IS_TRAINING:
    for epoch in range(0,niter):
        epoch_loss = 0
        epoch_dice = 0
        epoch_bounary = 0
        epoch_loss_weight = 0
        requires_grad(net, True)
        net.train()
        for i,data in enumerate(tqdm(dataloader)):  
            img_data_v = Variable(data['image'])
            mask_data_v = Variable(data['mask'])
            mask_pred_data_v = Variable(data['mask_pred'])

            if use_cuda:
                img_data_v = img_data_v.cuda()
                mask_data_v = mask_data_v.cuda()
                mask_pred_data_v = mask_pred_data_v.cuda()

            batchSize = img_data_v.size(0)

            net.zero_grad()
            data_catted = torch.cat([img_data_v, mask_pred_data_v], 1)
            masks_pred = net(data_catted)

            masks_probs_flat = masks_pred.view(-1)
            true_masks_flat = mask_data_v.view(-1)

#             focal_loss_ = focal_loss(masks_probs_flat, true_masks_flat)
            ft_loss = focal_tversky_loss(true_masks_flat, masks_probs_flat)
            dice_loss_ = dice_loss(masks_probs_flat, true_masks_flat)
            loss = ft_loss
            epoch_loss_weight += loss.item()
            epoch_loss += loss.item()
            epoch_dice += dice_loss_.item()
            epoch_bounary += 0
            loss.backward()

            optimizerE.step()        

        scheduler.step()
        requires_grad(net, False)
        net.eval()
        dice_list = []
        sensivity_list = []
        specitivity_list = []
        for j_,data in enumerate(tqdm(dataloader_valid)):      
            img_data_v = Variable(data['image'])
            mask_data_v = Variable(data['mask'])
            mask_pred_data_v = Variable(data['mask_pred'])
            if use_cuda:
                img_data_v = img_data_v.cuda()
                mask_data_v = mask_data_v.cuda()
                mask_pred_data_v = mask_pred_data_v.cuda()

            batchSize = img_data_v.size(0)
            data_catted = torch.cat([img_data_v, mask_pred_data_v], 1)
            masks_pred = net(data_catted)
            masks_probs_flat = masks_pred.view(-1)
            true_masks_flat = mask_data_v.view(-1)
            masks_probs_flat_np = masks_probs_flat.cpu().numpy()
            masks_probs_flat_np = np.where(masks_probs_flat_np > 0.5, 1, 0)
            true_masks_flat_np = true_masks_flat.cpu().numpy()
            true_masks_flat_np = np.where(true_masks_flat_np > 0.5, 1, 0)

            dice_ = dice_loss_np(masks_probs_flat_np.reshape(-1).astype(int), true_masks_flat_np.reshape(-1).astype(int))
            dice_list.append(dice_)
        
            tn_, fp_, fn_, tp_ = confusion_matrix(true_masks_flat_np.reshape(-1).astype(int), masks_probs_flat_np.reshape(-1).astype(int)).ravel()
            specificity = tn_ / (tn_+fp_)
            sensitivity = tp_ / (tp_+fn_)

            sensivity_list.append(sensitivity)
            specitivity_list.append(specificity)
    

        ## checkpoint
        if epoch%10==0 or epoch == (niter-1):
            torch.save(net.state_dict(),"%s/net_%03d.pth"%(outf,epoch+1))

        print()
        print('Epoch {} finished ! Loss: {:04.3f} | Loss_weight: {:04.3f} | sensivity: {:04.3f} | specitivity: {:04.3f}| Dice score: {:04.3f}'.format(epoch,epoch_loss / i, 
                                                epoch_loss_weight/i, np.array(sensivity_list).mean(),
                                                np.array(specitivity_list).mean(), 1 - np.array(dice_list).mean()))
        for param_group in optimizerE.param_groups:
            print("Current learning rate is: {}".format(param_group['lr']))
        print("-"*80)

100%|██████████| 341/341 [00:40<00:00,  9.08it/s]
100%|██████████| 34/34 [00:15<00:00,  2.38it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 0 finished ! Loss: 0.449 | Loss_weight: 0.449 | sensivity: 0.661 | specitivity: 0.964| Dice score: 0.639
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.85it/s]
100%|██████████| 34/34 [00:15<00:00,  2.33it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 1 finished ! Loss: 0.425 | Loss_weight: 0.425 | sensivity: 0.731 | specitivity: 0.944| Dice score: 0.618
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.53it/s]
100%|██████████| 34/34 [00:15<00:00,  2.22it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 2 finished ! Loss: 0.407 | Loss_weight: 0.407 | sensivity: 0.697 | specitivity: 0.963| Dice score: 0.658
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.90it/s]
100%|██████████| 34/34 [00:15<00:00,  2.41it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 3 finished ! Loss: 0.390 | Loss_weight: 0.390 | sensivity: 0.663 | specitivity: 0.970| Dice score: 0.660
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.92it/s]
100%|██████████| 34/34 [00:15<00:00,  2.31it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 4 finished ! Loss: 0.381 | Loss_weight: 0.381 | sensivity: 0.755 | specitivity: 0.953| Dice score: 0.660
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:39<00:00,  8.83it/s]
100%|██████████| 34/34 [00:15<00:00,  2.19it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 5 finished ! Loss: 0.378 | Loss_weight: 0.378 | sensivity: 0.723 | specitivity: 0.962| Dice score: 0.672
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.15it/s]
100%|██████████| 34/34 [00:15<00:00,  2.07it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 6 finished ! Loss: 0.374 | Loss_weight: 0.374 | sensivity: 0.648 | specitivity: 0.976| Dice score: 0.673
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:39<00:00,  8.34it/s]
100%|██████████| 34/34 [00:15<00:00,  2.29it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 7 finished ! Loss: 0.374 | Loss_weight: 0.374 | sensivity: 0.735 | specitivity: 0.959| Dice score: 0.669
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:39<00:00,  8.64it/s]
100%|██████████| 34/34 [00:15<00:00,  2.27it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 8 finished ! Loss: 0.372 | Loss_weight: 0.372 | sensivity: 0.693 | specitivity: 0.971| Dice score: 0.683
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.69it/s]
100%|██████████| 34/34 [00:15<00:00,  2.32it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 9 finished ! Loss: 0.372 | Loss_weight: 0.372 | sensivity: 0.687 | specitivity: 0.971| Dice score: 0.682
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:39<00:00,  8.89it/s]
100%|██████████| 34/34 [00:15<00:00,  2.07it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 10 finished ! Loss: 0.370 | Loss_weight: 0.370 | sensivity: 0.729 | specitivity: 0.960| Dice score: 0.667
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.65it/s]
100%|██████████| 34/34 [00:15<00:00,  2.37it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 11 finished ! Loss: 0.369 | Loss_weight: 0.369 | sensivity: 0.658 | specitivity: 0.976| Dice score: 0.683
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:39<00:00,  8.64it/s]
100%|██████████| 34/34 [00:15<00:00,  2.33it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 12 finished ! Loss: 0.368 | Loss_weight: 0.368 | sensivity: 0.727 | specitivity: 0.962| Dice score: 0.674
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.86it/s]
100%|██████████| 34/34 [00:16<00:00,  2.20it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 13 finished ! Loss: 0.367 | Loss_weight: 0.367 | sensivity: 0.796 | specitivity: 0.938| Dice score: 0.638
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:39<00:00,  8.76it/s]
100%|██████████| 34/34 [00:16<00:00,  2.08it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 14 finished ! Loss: 0.367 | Loss_weight: 0.367 | sensivity: 0.673 | specitivity: 0.974| Dice score: 0.684
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.45it/s]
100%|██████████| 34/34 [00:16<00:00,  2.21it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 15 finished ! Loss: 0.366 | Loss_weight: 0.366 | sensivity: 0.689 | specitivity: 0.972| Dice score: 0.687
Current learning rate is: 0.01
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.84it/s]
100%|██████████| 34/34 [00:15<00:00,  2.32it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 16 finished ! Loss: 0.365 | Loss_weight: 0.365 | sensivity: 0.735 | specitivity: 0.961| Dice score: 0.676
Current learning rate is: 0.001
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.87it/s]
100%|██████████| 34/34 [00:15<00:00,  2.16it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 17 finished ! Loss: 0.360 | Loss_weight: 0.360 | sensivity: 0.719 | specitivity: 0.966| Dice score: 0.684
Current learning rate is: 0.001
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.77it/s]
100%|██████████| 34/34 [00:15<00:00,  2.35it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 18 finished ! Loss: 0.360 | Loss_weight: 0.360 | sensivity: 0.708 | specitivity: 0.969| Dice score: 0.687
Current learning rate is: 0.001
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:39<00:00,  8.89it/s]
100%|██████████| 34/34 [00:16<00:00,  2.09it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 19 finished ! Loss: 0.358 | Loss_weight: 0.358 | sensivity: 0.705 | specitivity: 0.969| Dice score: 0.687
Current learning rate is: 0.001
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.80it/s]
100%|██████████| 34/34 [00:15<00:00,  2.13it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 20 finished ! Loss: 0.359 | Loss_weight: 0.359 | sensivity: 0.712 | specitivity: 0.968| Dice score: 0.686
Current learning rate is: 0.001
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:39<00:00,  8.82it/s]
100%|██████████| 34/34 [00:15<00:00,  2.23it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 21 finished ! Loss: 0.358 | Loss_weight: 0.358 | sensivity: 0.732 | specitivity: 0.963| Dice score: 0.682
Current learning rate is: 0.001
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.81it/s]
100%|██████████| 34/34 [00:15<00:00,  2.21it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 22 finished ! Loss: 0.358 | Loss_weight: 0.358 | sensivity: 0.690 | specitivity: 0.973| Dice score: 0.689
Current learning rate is: 0.001
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:38<00:00,  8.60it/s]
100%|██████████| 34/34 [00:15<00:00,  2.10it/s]
  0%|          | 0/341 [00:00<?, ?it/s]


Epoch 23 finished ! Loss: 0.360 | Loss_weight: 0.360 | sensivity: 0.730 | specitivity: 0.964| Dice score: 0.682
Current learning rate is: 0.001
--------------------------------------------------------------------------------


100%|██████████| 341/341 [00:39<00:00,  8.83it/s]
100%|██████████| 34/34 [00:15<00:00,  2.43it/s]


Epoch 24 finished ! Loss: 0.358 | Loss_weight: 0.358 | sensivity: 0.712 | specitivity: 0.968| Dice score: 0.687
Current learning rate is: 0.001
--------------------------------------------------------------------------------





## Test performance

In [18]:
# net.load_state_dict(torch.load('./saved_model/net_040.pth'))

## Predict whole image

In [19]:
IMG_SIZE = 1200
PATCH_SIZE = 128
STEP = 64

INPUT_DIR_FULL_IMAGE = '/home/quang/working/fundus_segmentation/data/segmentation_doctor/data_full/'
INPUT_DIR_MASK_GT_FULL_IMAGE = '/home/quang/working/fundus_segmentation/data/segmentation_doctor/data_full/'
INPUT_DIR_MASK_PRED_FULL_IMAGE = '/home/quang/working/fundus_segmentation/data/segmentation_doctor/data_full_mask_pred/'
OUTPUT_DIR_MASK_PRED_FULL_IMAGE = './predictions/'

list_paths_img = glob.glob(INPUT_DIR_FULL_IMAGE + '*.jpg')
list_paths_mask = [x.split('.')[0] + '.labels.tif' for x in list_paths_img]

list_paths_img_train, list_paths_img_test, list_paths_mask_train, list_paths_mask_test = train_test_split(list_paths_img, list_paths_mask, 
                                                                                                          test_size=0.1, random_state=12)

list_name_f = [x.split('/')[-1].split('.')[0] for x in list_paths_img_test]

print (len(list_paths_img), len(list_paths_img_train))

775 697


In [20]:
def dice_with_res(pred_np, true_np, SIZE = 512):
    pred_rescaled = cv2.resize((pred_np*255).astype(np.uint8), (SIZE, SIZE), interpolation = cv2.INTER_CUBIC)
    pred_rescaled = np.where(pred_rescaled > 128, 1, 0).reshape(-1).astype(int)
    
    true_rescaled = cv2.resize((true_np*255).astype(np.uint8), (SIZE, SIZE), interpolation = cv2.INTER_CUBIC)
    true_rescaled = np.where(true_rescaled > 128, 1, 0).reshape(-1).astype(int)

    return dice_loss_np(pred_rescaled, true_rescaled.reshape(-1).astype(int))

In [21]:
test_album_FULL_IMAGE = Compose([
    Resize(int(PATCH_SIZE), int(PATCH_SIZE)), 
], additional_targets = {'image0': 'image', 'mask_pred': 'mask'})

test_transform_torch_FULL_IMAGE = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

test_transform_torch_mask_pred_FULL_IMAGE = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

toTensor = transforms.ToTensor()

In [22]:
from sklearn.metrics import jaccard_score

In [23]:
net.eval()

dice_list_whole_image = []
sensivity_list_whole_image = []
specitivity_list_whole_image = []
acc_list = []
jacc_list = []

for temp_name in tqdm(list_name_f):
    img = cv2.imread(INPUT_DIR_FULL_IMAGE + temp_name + '.jpg')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #??????
    img_rescaled = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation = cv2.INTER_CUBIC)

    mask_GT = cv2.imread(INPUT_DIR_MASK_GT_FULL_IMAGE + temp_name + '.labels.tif', 0)
    mask_GT = np.where(mask_GT > 0, 255, 0).astype(np.uint8)
    mask_GT_rescaled = cv2.resize(mask_GT, (IMG_SIZE, IMG_SIZE), interpolation = cv2.INTER_CUBIC)
    mask_GT_rescaled = np.where(mask_GT_rescaled > 128, 1, 0).astype(np.uint8)
    mask_GT_rescaled = np.expand_dims(mask_GT_rescaled, axis=-1)

    mask_pred = cv2.imread(INPUT_DIR_MASK_PRED_FULL_IMAGE + temp_name + '.jpg', 0)
    mask_pred_rescaled = cv2.resize(mask_pred, (IMG_SIZE, IMG_SIZE), interpolation = cv2.INTER_CUBIC)
    mask_pred_rescaled = np.expand_dims(mask_pred_rescaled, axis=-1)

    img_gray = cv2.cvtColor(img_rescaled, cv2.COLOR_BGR2GRAY)
    img_gray_mask = np.where(img_gray > 10, 1, 0)
    circle = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
    cv2.circle(circle, (int(IMG_SIZE/2), int(IMG_SIZE/2)), int(IMG_SIZE/2), 1, thickness=-1)

    img_gray_mask = img_gray_mask*circle

    mask_predict = np.zeros((IMG_SIZE,IMG_SIZE))
    mask_predict_count = np.zeros((IMG_SIZE,IMG_SIZE))

    for i in range(0, IMG_SIZE-PATCH_SIZE, STEP):
        for j in range(0, IMG_SIZE-PATCH_SIZE, STEP):
    #         if img_gray_mask[i:i+PATCH_SIZE, j:j+PATCH_SIZE].sum() > PATCH_SIZE*PATCH_SIZE/2 and mask_cropped[i:i+PATCH_SIZE, j:j+PATCH_SIZE].sum() > 0:
            if img_gray_mask[i:i+PATCH_SIZE, j:j+PATCH_SIZE].sum() > PATCH_SIZE*PATCH_SIZE/2 or True:
                img_patch = img_rescaled[i:i+PATCH_SIZE, j:j+PATCH_SIZE,:]
                mask_GT_patch = mask_GT_rescaled[i:i+PATCH_SIZE, j:j+PATCH_SIZE]
                mask_pred_patch = mask_pred_rescaled[i:i+PATCH_SIZE, j:j+PATCH_SIZE]

                img_auged = test_album_FULL_IMAGE(image=img_patch, mask=mask_GT_patch, mask_pred=mask_pred_patch)

                image_tensor = img_auged['image']
                mask_tensor = img_auged['mask']
                mask_pred_aug = img_auged['mask_pred']
                image_tensor = test_transform_torch_FULL_IMAGE(image_tensor)
#                 mask_tensor = toTensor(mask_tensor).float()
                mask_tensor = test_transform_torch_mask_pred_FULL_IMAGE(mask_tensor)
                mask_pred_tensor = test_transform_torch_mask_pred_FULL_IMAGE(mask_pred_aug).float() #???? mask_pred_tensor = toTensor(mask_pred_aug).float()

                image_tensor = image_tensor.cuda()
                mask_tensor = mask_tensor.cuda()
                mask_pred_tensor = mask_pred_tensor.cuda()

                data_catted = torch.cat([image_tensor, mask_pred_tensor], 0)

                masks_pred = net(torch.unsqueeze(data_catted, 0))
                mask_predict[i:i+PATCH_SIZE, j:j+PATCH_SIZE] += masks_pred.cpu().numpy()[0,0]
                mask_predict_count[i:i+PATCH_SIZE, j:j+PATCH_SIZE] += 1

    mask_predict_count_temp = np.where(mask_predict_count == 0, 1, mask_predict_count)
    mask_predict_avg = mask_predict/mask_predict_count_temp
    mask_predict_avg_binary = np.where(mask_predict_avg > 0.5, 1, 0)

    dice_list_whole_image.append(dice_with_res(mask_predict_avg_binary, mask_GT_rescaled))

    tn_, fp_, fn_, tp_ = confusion_matrix(mask_GT_rescaled.reshape(-1).astype(int), mask_predict_avg_binary.reshape(-1).astype(int)).ravel()
    specificity = tn_ / (tn_+fp_)
    sensitivity = tp_ / (tp_+fn_)
    
    jacc = jaccard_score(mask_GT_rescaled.reshape(-1).astype(int), mask_predict_avg_binary.reshape(-1).astype(int))
    jacc_list.append(jacc)
    sensivity_list_whole_image.append(sensitivity)
    specitivity_list_whole_image.append(specificity)
    acc_list.append((tp_ + tn_)/ (tp_+tn_+fp_+fn_))
    
    cv2.imwrite(OUTPUT_DIR_MASK_PRED_FULL_IMAGE + temp_name + '.png', mask_predict_avg_binary*255)
    
print ('Dice score: ', 1 - np.array(dice_list_whole_image).mean())
print ('Sensivity: ', np.array(sensivity_list_whole_image).mean())
print ('Specitivity: ', np.array(specitivity_list_whole_image).mean())
print ('Accuracy: ', np.array(acc_list).mean())
print ('Jaccard: ', np.array(jacc_list).mean())

100%|██████████| 78/78 [05:05<00:00,  4.22s/it]

Dice score:  0.6250615175252656
Sensivity:  0.6615360988014666
Specitivity:  0.9970243503297536
Accuracy:  0.9948013977920226
Jaccard:  0.4719299921313919



