In [1]:
import os
# Order CUDA devices by PCI bus id and default to GPUs 1 and 2.
# setdefault: a value already exported in the environment (e.g.
# `CUDA_VISIBLE_DEVICES=0 jupyter notebook`) takes precedence over this default.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1,2")

import cv2
import glob
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torch.nn.functional as F
import torchvision.utils as vutils
import pickle
from PIL import ImageFile
from tqdm import tqdm

from unet import UNet
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

from sklearn.metrics import recall_score
from sklearn.metrics import confusion_matrix

from albumentations import (
    HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine,
    IAASharpen, IAAEmboss, RandomContrast, RandomBrightness, Flip, OneOf, Compose,
    RandomCrop, Normalize, Resize
)

# PIL flag: tolerate truncated image files when PIL decodes them.
# NOTE(review): FundusDataset reads images with cv2.imread, not PIL, so this
# flag appears to have no effect on this notebook's data loading — confirm.
ImageFile.LOAD_TRUNCATED_IMAGES = True

import matplotlib.pyplot as plt
def show(img):
    """Display a CxHxW image tensor with matplotlib.

    The tensor is detached and copied to the CPU first, so GPU tensors and
    tensors that require grad work too. A single-channel tensor (e.g. a mask)
    is drawn as a 2-D grayscale image, because imshow rejects HxWx1 arrays.
    """
    npimg = img.detach().cpu().numpy()
    if npimg.shape[0] == 1:
        plt.imshow(npimg[0], cmap='gray', interpolation='nearest')
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
    plt.show()

In [2]:
""" DeepLabv3 Model download and change the head for your prediction"""
from torchvision import models
from torchvision.models.segmentation.deeplabv3 import DeepLabHead

def createDeepLabv3(outputchannels=1):
    """Build torchvision's DeepLabv3-ResNet101 with a new prediction head.

    The backbone is created with ``pretrained=True, progress=True``. Its
    ``classifier`` is then replaced by a new ``DeepLabHead`` that maps the
    2048 backbone channels to ``outputchannels`` outputs.

    This function adds no sigmoid, so the head produces raw logits. The
    training cell pairs it with BCEWithLogitsLoss and applies torch.sigmoid
    itself.

    Returns:
        The model, switched to training mode.
    """
    net = models.segmentation.deeplabv3_resnet101(pretrained=True, progress=True)
    new_head = DeepLabHead(2048, outputchannels)
    net.classifier = new_head
    net.train()
    return net

### Note: change `PATH_DATA` and `PATH_MASK` in the next cell to point to your dataset

In [3]:
import random

niter = 20                                   # number of training epochs
outf = './results/prediction_whole_image/'   # checkpoint output directory
os.makedirs(outf, exist_ok=True)             # torch.save does not create missing directories

# Image/mask directories. Override them with the FUNDUS_DATA_DIR /
# FUNDUS_MASK_DIR environment variables instead of editing absolute paths.
# os.path.join(..., '') guarantees a trailing separator, because later cells
# build paths by string concatenation.
PATH_DATA = os.path.join(
    os.environ.get('FUNDUS_DATA_DIR',
                   '/home/quang/working/fundus_segmentation/data/segmentation_doctor/data_full/'),
    '')
PATH_MASK = os.path.join(os.environ.get('FUNDUS_MASK_DIR', PATH_DATA), '')

img_size = 512
IMG_HEIGHT, IMG_WIDTH = img_size, img_size
batchSize = 4

# Seed every RNG the pipeline touches (torch, numpy, python `random`) so
# shuffling, augmentation and head initialisation are reproducible.
SEED = 12
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

cuda


In [4]:
def _mask_path(img_path):
    """Map ``<dir>/<stem>.jpg`` to ``PATH_MASK/<stem>.labels.tif``.

    splitext strips only the final extension, so stems that contain dots
    (``a.b.jpg`` -> ``a.b``) are kept intact.
    """
    stem = os.path.splitext(os.path.basename(img_path))[0]
    return os.path.join(PATH_MASK, stem + '.labels.tif')


# glob order depends on the filesystem; sort so random_state=12 below gives
# the same train/test split on every machine.
list_paths_img = sorted(glob.glob(os.path.join(PATH_DATA, '*.jpg')))
list_paths_mask = [_mask_path(p) for p in list_paths_img]

# Fail fast here rather than inside a DataLoader worker mid-epoch.
assert list_paths_img, 'no *.jpg images found in %s' % PATH_DATA
missing_masks = [p for p in list_paths_mask if not os.path.isfile(p)]
assert not missing_masks, '%d mask files missing, e.g. %s' % (len(missing_masks), missing_masks[:3])

list_paths_img_train, list_paths_img_test, list_paths_mask_train, list_paths_mask_test = train_test_split(
    list_paths_img, list_paths_mask, test_size=0.1, random_state=12)
print(len(list_paths_img), len(list_paths_img_train))

775 697


In [5]:
class ImageAug:
    """Wrap an albumentations-style transform as a plain ``img -> img`` callable.

    The wrapped transform is called as ``aug(image=img)``. The ``'image'``
    entry of the dict it returns is handed back.
    """

    def __init__(self, aug):
        self.aug = aug

    def __call__(self, img):
        return self.aug(image=img)['image']

class FundusDataset(Dataset):
    """Fundus image / binary segmentation-mask dataset.

    Each item is a dict with:
        'image': the RGB image as a tensor (via ``transform_torch``),
        'mask':  the mask binarised to {0., 1.} (any non-zero pixel is
                 foreground) as a float tensor with a leading channel axis,
        'id':    the image file name.
    """

    def __init__(self, PATH_IMG, PATH_MASK, transform=None, transform_torch=None, toTensor=None):
        """
        Args:
            PATH_IMG (list of str): Image file paths.
            PATH_MASK (list of str): Mask file paths, aligned index-by-index
                with PATH_IMG.
            transform (callable, optional): Joint image/mask augmentation,
                called as ``transform(image=..., mask=...)``. It must return
                a dict with 'image' and 'mask' keys (albumentations style).
            transform_torch (callable, optional): Converts the RGB HxWx3
                image array into a tensor (e.g. ToTensor + Normalize).
                Without it, the image becomes an unnormalised CxHxW tensor.
            toTensor (callable, optional): Converts the HxWx1 mask array into
                a tensor. Without it, a plain CxHxW tensor is built.

        Raises:
            ValueError: if PATH_IMG and PATH_MASK differ in length.
        """
        if len(PATH_IMG) != len(PATH_MASK):
            raise ValueError('got %d image paths but %d mask paths' % (len(PATH_IMG), len(PATH_MASK)))
        self.PATH_IMG = PATH_IMG
        self.PATH_MASK = PATH_MASK
        self.ToTensor = toTensor
        self.transform = transform
        self.transform_torch = transform_torch

    def __len__(self):
        return len(self.PATH_IMG)

    def __getitem__(self, idx):
        img_path = self.PATH_IMG[idx]
        mask_path = self.PATH_MASK[idx]
        file_name_temp = os.path.basename(img_path)

        # cv2.imread returns None instead of raising on unreadable files.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError('could not read image: %s' % img_path)
        mask = cv2.imread(mask_path, 0)  # flag 0: load as single-channel grayscale
        if mask is None:
            raise FileNotFoundError('could not read mask: %s' % mask_path)

        mask = np.expand_dims(np.where(mask > 0, 1., 0.), axis=-1)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # cv2 loads BGR

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        image = self._image_to_tensor(image)
        mask = self._mask_to_tensor(mask)
        return {'image': image, 'mask': mask.float(), 'id': file_name_temp}

    def _image_to_tensor(self, image):
        """Apply ``transform_torch``, or fall back to a raw CxHxW tensor."""
        if self.transform_torch is not None:
            return self.transform_torch(image)
        return torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))

    def _mask_to_tensor(self, mask):
        """Apply ``toTensor``, or fall back to a raw CxHxW tensor."""
        if self.ToTensor is not None:
            return self.ToTensor(mask)
        if mask.ndim == 2:
            mask = mask[..., None]
        return torch.from_numpy(np.ascontiguousarray(mask.transpose(2, 0, 1)))

    
# Training augmentation (albumentations): every geometric op is applied
# jointly to image and mask. With TRAIN_RESIZE_SCALE = 1.0 the RandomCrop
# keeps the whole resized image. Set it above 1.0 to make the crop random.
TRAIN_RESIZE_SCALE = 1.
train_album = Compose([
    Resize(int(IMG_HEIGHT * TRAIN_RESIZE_SCALE), int(IMG_WIDTH * TRAIN_RESIZE_SCALE)),
    RandomCrop(IMG_HEIGHT, IMG_WIDTH),
    HorizontalFlip(),
    ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=.75),
])

test_album = Compose([
    Resize(IMG_HEIGHT, IMG_WIDTH),
])


def _make_image_to_tensor():
    """ToTensor (uint8 [0, 255] -> float [0, 1]) then Normalize to [-1, 1] per channel."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


train_transform_torch = _make_image_to_tensor()
test_transform_torch = _make_image_to_tensor()

train_dataset = FundusDataset(list_paths_img_train, list_paths_mask_train, transform=train_album,
                              transform_torch=train_transform_torch, toTensor=transforms.ToTensor())
# drop_last avoids a smaller final training batch.
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batchSize, shuffle=True,
                                         num_workers=2, drop_last=True)

valid_dataset = FundusDataset(list_paths_img_test, list_paths_mask_test, transform=test_album,
                              transform_torch=test_transform_torch, toTensor=transforms.ToTensor())
dataloader_valid = torch.utils.data.DataLoader(valid_dataset, batch_size=1, shuffle=False,
                                               num_workers=2)

In [7]:
def show(img):
    """Display a CxHxW image tensor with matplotlib.

    The tensor is detached and copied to the CPU first. This avoids a failure
    on tensors that require grad or live on the GPU, where ``.numpy()`` raises.
    A single-channel tensor (e.g. a mask) is drawn as a 2-D grayscale image,
    because imshow rejects HxWx1 arrays.

    NOTE(review): this redefines (shadows) an earlier ``show`` in the first
    cell. Consider keeping only one definition.
    """
    npimg = img.detach().cpu().numpy()
    if npimg.shape[0] == 1:
        plt.imshow(npimg[0], cmap='gray', interpolation='nearest')
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
    plt.show()

In [10]:
def requires_grad(model, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``.

    flag=False freezes the parameters and flag=True unfreezes them.
    """
    params = model.parameters()
    for tensor in params:
        tensor.requires_grad = flag

In [14]:
# Generic loss modules. The training loop defines its own BCEWithLogitsLoss
# (`criterion`), so these are not used by it.
BCE, L1, CE, MSE = nn.BCELoss(), nn.L1Loss(), nn.CrossEntropyLoss(), nn.MSELoss()
if use_cuda:
    BCE, L1, CE, MSE = BCE.cuda(), L1.cuda(), CE.cuda(), MSE.cuda()

In [15]:
class FocalLoss(nn.Module):
    """Binary focal loss: ``alpha * (1 - p_t) ** gamma * BCE``, element-wise.

    Args:
        alpha: constant weight multiplied into every element's loss.
        gamma: focusing exponent. Larger values shrink the loss of elements
            that are already predicted well (p_t close to 1).
        logits: if True, ``inputs`` are raw logits. Otherwise they are
            probabilities in [0, 1].
        reduce: if True, return the mean over all elements. Otherwise return
            the element-wise loss tensor (same shape as ``inputs``).
    """

    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        # reduction='none' replaces the deprecated reduce=False argument.
        if self.logits:
            bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            bce = F.binary_cross_entropy(inputs, targets, reduction='none')
        # For hard 0/1 targets exp(-BCE) equals p_t, the probability assigned
        # to the true class.
        pt = torch.exp(-bce)
        focal = self.alpha * (1 - pt) ** self.gamma * bce
        return torch.mean(focal) if self.reduce else focal

In [16]:
def dice_loss(input, target):
    """Soft Dice loss over all elements of two tensors.

    Computes ``1 - (2 * sum(input * target) + 1) / (sum(input) + sum(target) + 1)``.

    Args:
        input: predicted probabilities (or a binary mask), any shape.
        target: ground-truth tensor with the same number of elements.

    Returns:
        A 0-dim tensor. 0 means perfect overlap. The +1 smoothing keeps the
        ratio defined (and the loss at 0) when both tensors are all zeros.
    """
    smooth = 1.

    # reshape (unlike view) also accepts non-contiguous tensors, e.g. transposed ones.
    iflat = input.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()

    return 1 - ((2. * intersection + smooth) /
                (iflat.sum() + tflat.sum() + smooth))

def dice_loss_np(input, target):
    """Dice loss for array-likes with a ``reshape`` method (NumPy arrays).

    Returns ``1 - (2 * sum(input * target) + 1) / (sum(input) + sum(target) + 1)``.
    That is 0 for perfect overlap. The +1 smoothing makes two all-zero masks
    score a loss of 0.
    """
    pred = input.reshape(-1)
    truth = target.reshape(-1)
    overlap = (pred * truth).sum()
    numerator = 2. * overlap + 1.
    denominator = pred.sum() + truth.sum() + 1.
    return 1 - numerator / denominator

# Focal Tversky loss, adapted from: https://github.com/nabsabraham/focal-tversky-unet
def tversky(y_true, y_pred, smooth=1e-6, alpha=0.75):
    """Smoothed Tversky index of ``y_pred`` against ``y_true``.

    Computes ``(TP + smooth) / (TP + alpha * FN + (1 - alpha) * FP + smooth)``
    from soft counts over all elements.

    Args:
        y_true: ground-truth tensor (values in [0, 1]), any shape.
        y_pred: predicted probabilities with the same number of elements.
        smooth: small constant that keeps the ratio defined for empty masks.
        alpha: weight on false negatives. ``1 - alpha`` weights false
            positives. alpha > 0.5 penalises missed foreground more.

    Returns:
        A 0-dim tensor. 1 means perfect agreement.
    """
    # reshape (unlike view) also accepts non-contiguous tensors.
    y_true_pos = y_true.reshape(-1)
    y_pred_pos = y_pred.reshape(-1)
    true_pos = (y_true_pos * y_pred_pos).sum()
    false_neg = (y_true_pos * (1 - y_pred_pos)).sum()
    false_pos = ((1 - y_true_pos) * y_pred_pos).sum()
    return (true_pos + smooth) / (true_pos + alpha * false_neg + (1 - alpha) * false_pos + smooth)

def tversky_loss(y_true, y_pred):
    """Tversky loss: one minus the Tversky index of ``y_pred`` against ``y_true``."""
    index = tversky(y_true, y_pred)
    return 1 - index

def focal_tversky_loss(y_true, y_pred, gamma=0.75):
    """Focal Tversky loss: ``(1 - tversky(y_true, y_pred)) ** gamma``.

    Args:
        y_true: ground-truth tensor.
        y_pred: predicted probabilities with the same number of elements.
        gamma: focusing exponent (default 0.75, the previously hard-coded
            value). For gamma < 1, a base ``1 - index`` in (0, 1) is raised
            towards 1, which amplifies small losses.

    Returns:
        A 0-dim tensor.
    """
    pt_1 = tversky(y_true, y_pred)
    return (1 - pt_1).pow(gamma)

# Focal loss on probabilities (logits=False). Referenced only by the
# commented-out combined-loss experiment in the training loop.
focal_loss = FocalLoss(alpha=1, gamma=2, logits=False, reduce=True).to(device)

In [17]:
dataloaders = {'Train': dataloader, 'Test': dataloader_valid}

# The DataParallel wrapper is kept on every machine so saved state_dict keys
# stay the same as in earlier runs. .to(device), unlike .cuda(), lets this
# cell also run on a CPU-only machine.
model = nn.DataParallel(createDeepLabv3()).to(device)
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

## Training the model

### Adjust the `pos_weight` of `BCEWithLogitsLoss` to your dataset's class balance, or switch to a different loss function

In [None]:
# Training / validation loop.
# Each epoch does one pass over the training loader (with gradient updates)
# and one over the validation loader. It then logs per-phase metrics,
# appends to the learning-curve lists below and writes checkpoints into `outf`.
best_loss = 1e10      # lowest mean validation loss seen so far
batchsummary = {}     # epoch -> {'Train': metrics dict, 'Test': metrics dict}

## Adjust the weight with your dataset
# pos_weight > 1 makes each positive (foreground) pixel weigh more in the BCE.
weight = torch.tensor([15.])
if use_cuda:
    weight = weight.cuda()

criterion = nn.BCEWithLogitsLoss(pos_weight=weight)

LIST_DICE_VAL = []         # validation Dice score (1 - Dice loss), one per epoch
LIST_BCE_TRAIN = []        # mean training loss, one per epoch
LIST_DICE_LOSS_TRAIN = []  # mean training Dice loss, one per epoch

os.makedirs(outf, exist_ok=True)  # torch.save fails if the directory is missing


def _binary_counts(y_pred, y_true):
    """Return ``(tn, fp, fn, tp)`` for two flat 0/1 integer arrays.

    This always yields four counts, even when only one class is present.
    In that case ``confusion_matrix(...).ravel()`` returns a single value
    and a 4-way unpack fails.
    """
    pred_pos = y_pred == 1
    true_pos = y_true == 1
    tp = int(np.count_nonzero(pred_pos & true_pos))
    fp = int(np.count_nonzero(pred_pos & ~true_pos))
    fn = int(np.count_nonzero(~pred_pos & true_pos))
    tn = int(y_true.size) - tp - fp - fn
    return tn, fp, fn, tp


def _safe_ratio(num, den):
    """``num / den``, or NaN when ``den`` is 0 (e.g. a mask with no foreground)."""
    return num / den if den else float('nan')


def _run_phase(phase):
    """Run one pass over ``dataloaders[phase]`` and return its mean metrics.

    Gradients are tracked, and the optimizer stepped, only when phase is
    'Train'. Predictions are binarised at probability 0.5 for the Dice,
    sensitivity and specificity metrics. Sensitivity and specificity are
    averaged over batches ignoring NaN, i.e. batches with no positive or no
    negative pixels.

    Raises:
        RuntimeError: if the dataloader yields no batches.
    """
    is_train = phase == 'Train'
    model.train(is_train)  # train() for 'Train', eval() otherwise

    total_loss = 0.0
    n_batches = 0
    dice_losses, sensitivities, specificities = [], [], []
    for sample in tqdm(dataloaders[phase]):
        inputs = sample['image'].to(device)
        masks = sample['mask'].to(device)
        optimizer.zero_grad()

        # track history only in train
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs['out'], masks)
            if is_train:
                loss.backward()
                optimizer.step()

        total_loss += loss.item()
        n_batches += 1

        probs = torch.sigmoid(outputs['out']).detach().cpu().numpy().ravel()
        y_pred = (probs > 0.5).astype(int)
        y_true = (masks.detach().cpu().numpy().ravel() > 0.5).astype(int)

        dice_losses.append(dice_loss_np(y_pred, y_true))
        tn_, fp_, fn_, tp_ = _binary_counts(y_pred, y_true)
        sensitivities.append(_safe_ratio(tp_, tp_ + fn_))
        specificities.append(_safe_ratio(tn_, tn_ + fp_))

    if n_batches == 0:
        raise RuntimeError('dataloader for phase %r yielded no batches' % phase)
    return {
        'loss': total_loss / n_batches,
        'dice_loss': float(np.mean(dice_losses)),
        'sensitivity': float(np.nanmean(sensitivities)),
        'specificity': float(np.nanmean(specificities)),
    }


for epoch in range(niter):  # niter is set in the configuration cell
    print('Epoch {}/{}'.format(epoch, niter))
    print('-' * 10)
    batchsummary[epoch] = {}

    for phase in ['Train', 'Test']:
        stats = _run_phase(phase)
        batchsummary[epoch][phase] = stats
        print('Epoch {} [{}] finished ! Loss: {:04.3f} | sensitivity: {:04.3f} | '
              'specificity: {:04.3f} | Dice score: {:04.3f}'.format(
                  epoch, phase, stats['loss'], stats['sensitivity'],
                  stats['specificity'], 1 - stats['dice_loss']))
        print("-" * 80)

    train_stats = batchsummary[epoch]['Train']
    val_stats = batchsummary[epoch]['Test']
    LIST_BCE_TRAIN.append(train_stats['loss'])
    LIST_DICE_LOSS_TRAIN.append(train_stats['dice_loss'])
    LIST_DICE_VAL.append(1 - val_stats['dice_loss'])

    ## checkpoint: keep the best model by validation loss ...
    if val_stats['loss'] < best_loss:
        best_loss = val_stats['loss']
        torch.save(model.state_dict(), os.path.join(outf, 'net_best.pth'))
    # ... plus periodic snapshots, written once per epoch after both phases.
    if epoch % 10 == 0 or epoch == (niter - 1):
        torch.save(model.state_dict(), os.path.join(outf, 'net_%03d.pth' % (epoch + 1)))

  0%|          | 0/174 [00:00<?, ?it/s]

Epoch 0/20
----------


  9%|▉         | 16/174 [00:37<05:47,  2.20s/it]