In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2'

import cv2
import glob
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torch.nn.functional as F
import torchvision.utils as vutils
import pickle
from PIL import ImageFile
from tqdm import tqdm

from unet import UNet
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

from sklearn.metrics import recall_score
from sklearn.metrics import confusion_matrix

from albumentations import (
    HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine,
    IAASharpen, IAAEmboss, RandomContrast, RandomBrightness, Flip, OneOf, Compose,
    RandomCrop, Normalize, Resize
)

# Let PIL load truncated image files instead of raising an error.
# NOTE(review): FundusDataset reads images with cv2.imread (not PIL), so this
# flag appears to have no effect on image loading in this notebook.
ImageFile.LOAD_TRUNCATED_IMAGES = True

import matplotlib.pyplot as plt
def show(img):
    """Display a channel-first image tensor with matplotlib.

    Args:
        img: tensor of shape (C, H, W), with C == 3 (RGB) or C == 1
            (grayscale). It may live on any device and may require grad.

    NOTE(review): imshow expects float RGB values in [0, 1] and clips anything
    outside that range; un-normalize tensors produced by a mean/std=0.5
    Normalize (range [-1, 1]) before calling this.
    """
    # .cpu() is required because .numpy() raises on CUDA tensors.
    npimg = img.detach().cpu().numpy()
    if npimg.ndim == 3 and npimg.shape[0] == 1:
        # imshow rejects (H, W, 1) arrays; show single-channel input as 2-D grayscale.
        plt.imshow(npimg[0], cmap='gray', interpolation='nearest')
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
    plt.show()

In [2]:
""" DeepLabv3 Model download and change the head for your prediction"""
from torchvision import models
from torchvision.models.segmentation.deeplabv3 import DeepLabHead

def createDeepLabv3(outputchannels=1):
    """Return a pretrained DeepLabv3-ResNet101 with a new ``outputchannels`` head.

    torchvision's pretrained segmentation weights are loaded (with a download
    progress bar if they are not cached yet), and then the classifier is
    replaced by a freshly initialised DeepLabHead.

    The new head ends in a plain convolution, with no activation. The ``'out'``
    entry of the model's output dict therefore holds raw logits; apply
    ``torch.sigmoid`` downstream to get probabilities.

    The model is returned in training mode.
    """
    net = models.segmentation.deeplabv3_resnet101(pretrained=True, progress=True)
    # 2048 is the channel count of the ResNet-101 feature map fed to the head.
    # The head is built after the backbone so random init order is unchanged.
    net.classifier = DeepLabHead(2048, outputchannels)
    net.train()
    return net

In [3]:
# Spatial size that images are resized to before inference.
IMG_HEIGHT = IMG_WIDTH = 512
# NOTE(review): batchSize appears unused here; the inference DataLoader sets its own batch size.
batchSize = 4

# Use the GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(device)

cuda


In [4]:
class ImageAug:
    """Adapt an albumentations-style transform into a plain ``f(img) -> img`` callable.

    The wrapped transform is invoked as ``aug(image=img)`` and must return a
    mapping that holds the transformed array under the ``'image'`` key.
    """

    def __init__(self, aug):
        self.aug = aug

    def __call__(self, img):
        result = self.aug(image=img)
        return result['image']

class FundusDataset(Dataset):
    """Fundus image dataset yielding ``{'image': ..., 'id': <file name>}`` samples.

    Each item is read from disk with OpenCV and converted from BGR to RGB.
    It then passes through the optional albumentations-style ``transform``,
    followed by the optional tensor conversion ``transform_torch``.
    """

    def __init__(self, PATH_IMG, transform=None, transform_torch=None, toTensor=None):
        """
        Args:
            PATH_IMG (sequence of str): Paths of the image files to load.
            transform (callable, optional): albumentations-style transform,
                called as ``transform(image=img)``. It must return a dict
                holding the augmented array under ``'image'``.
            transform_torch (callable, optional): Applied to the (possibly
                augmented) RGB array, e.g. ToTensor + Normalize.
            toTensor (callable, optional): Stored as ``self.ToTensor``.
                This class does not use it; it is kept for backward
                compatibility.
        """
        self.PATH_IMG = PATH_IMG
        self.ToTensor = toTensor
        self.transform = transform
        self.transform_torch = transform_torch

    def __len__(self):
        return len(self.PATH_IMG)

    def __getitem__(self, idx):
        """Load, transform and return the ``idx``-th image with its file name.

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        path = self.PATH_IMG[idx]
        # basename also handles OS-native separators, unlike split('/').
        file_name = os.path.basename(path)

        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError('Could not read image: ' + path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # The two stages are applied independently: tensor conversion still
        # runs when no augmentation is configured, and a missing
        # transform_torch is simply skipped.
        if self.transform is not None:
            image = self.transform(image=image)['image']
        if self.transform_torch is not None:
            image = self.transform_torch(image)

        return {'image': image, 'id': file_name}

    
def _normalized_tensor_transform():
    """Build a fresh ToTensor -> Normalize(mean=0.5, std=0.5) pipeline.

    ToTensor turns an HWC uint8 array into a CHW float tensor in [0, 1].
    Normalizing with mean = std = 0.5 per channel then maps it to [-1, 1].
    """
    half = (0.5, 0.5, 0.5)
    return transforms.Compose([transforms.ToTensor(), transforms.Normalize(half, half)])


# Tensor conversion used for training.
train_transform_torch = _normalized_tensor_transform()

# Inference-time geometric transform (albumentations): resize to the network input size.
test_album = Compose([Resize(IMG_HEIGHT, IMG_WIDTH)])

# Tensor conversion used for inference; a separate but identical pipeline.
test_transform_torch = _normalized_tensor_transform()

In [5]:
# DataParallel spreads batches over all visible GPUs (CUDA_VISIBLE_DEVICES is set
# in the first cell). Its 'module.'-prefixed state_dict keys are what the
# checkpoint loaded next expects. .to(device) instead of .cuda() lets the
# notebook also run on a CPU-only machine.
model = nn.DataParallel(createDeepLabv3()).to(device)

## Load model
### Note: change the checkpoint path below to match your environment

In [6]:
# Change this to where your checkpoint lives, e.g.
# './saved_model/image_DeepLab/net_040.pth' for a repo-relative layout.
CHECKPOINT_PATH = '/home/quang/working/fundus_segmentation/results/image_DeepLab/net_040.pth'

# map_location lets a checkpoint saved on GPU be loaded onto the current device
# (required on CPU-only machines).
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

<All keys matched successfully>

## Generate masks for all images
### Note: change `PATH_DATA` (input images) and `PATH_OUTPUT` (where masks are written) below to match your environment

In [7]:
PATH_DATA = '/home/quang/working/Face-Aging-CAAE/data/data_amd_resize_1200_8196/'
PATH_OUTPUT = '/home/quang/working/fundus_segmentation/data/segmentation_doctor/prediction_data_amd_512_8196/'

# cv2.imwrite does not create missing directories; it just returns False.
# Make sure the output folder exists before masks are written into it.
os.makedirs(PATH_OUTPUT, exist_ok=True)

# glob order is arbitrary; sort it so every run processes files in the same order.
# os.path.join works whether or not PATH_DATA ends with a separator.
list_paths_img = sorted(glob.glob(os.path.join(PATH_DATA, '*.jpg')))
if not list_paths_img:
    # Fail loudly instead of silently running inference over zero images.
    raise FileNotFoundError('No .jpg images found in ' + PATH_DATA)

In [8]:
valid_dataset_all = FundusDataset(list_paths_img, transform=test_album,
                                  transform_torch=test_transform_torch)

dataloader_valid_all = torch.utils.data.DataLoader(valid_dataset_all,
                                                   batch_size=32, shuffle=False,
                                                   num_workers=2)

# Inference only: eval() switches BatchNorm/Dropout to inference behaviour,
# and no_grad() skips autograd bookkeeping for the whole loop.
model.eval()
with torch.no_grad():
    for sample in tqdm(dataloader_valid_all):
        inputs = sample['image'].to(device)
        # The head outputs logits; sigmoid turns them into per-pixel probabilities.
        probs = torch.sigmoid(model(inputs)['out']).cpu().numpy()

        # sample['id'] is collated into a list of file names aligned with the batch.
        for f_name, prob in zip(sample['id'], probs):
            # Save channel 0 of the prediction, scaled to 0-255.
            # NOTE(review): f_name keeps its '.jpg' extension, so masks are stored
            # as lossy JPEG; consider PNG if exact mask values matter downstream.
            out_path = os.path.join(PATH_OUTPUT, f_name)
            if not cv2.imwrite(out_path, (prob[0] * 255).astype(np.uint8)):
                # imwrite reports failure (bad dir, bad extension) only via its return value.
                raise OSError('Failed to write mask to ' + out_path)

100%|██████████| 257/257 [06:44<00:00,  1.22s/it]
