In [1]:
import os
# GPU selection. Enumerate devices in PCI bus order, then expose only physical
# GPUs 1 and 2 to this process. This has to happen before torch initialises CUDA,
# which is why it sits above the torch import.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1,2",
})

import cv2
import glob
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torch.nn.functional as F
import torchvision.utils as vutils
import pickle
from PIL import ImageFile
from PIL import Image
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR

from sklearn.metrics import jaccard_score
from unet import UNet
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

from sklearn.metrics import recall_score
from sklearn.metrics import confusion_matrix

from albumentations import (
    HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine,
    IAASharpen, IAAEmboss, RandomContrast, RandomBrightness, Flip, OneOf, Compose,
    RandomCrop, Normalize, Resize
)

# Let PIL decode truncated image files instead of raising. This only affects
# PIL-based loading; the cv2.imread calls used for prediction are unaffected.
ImageFile.LOAD_TRUNCATED_IMAGES = True

import matplotlib.pyplot as plt  # duplicate of the import above; harmless
def show(img):
    """Display a (C, H, W) image tensor with matplotlib.

    The tensor is detached and copied to the CPU first. Calling `.numpy()` on a
    CUDA tensor (e.g. a network output on this notebook's GPU path) raises.
    Channels are moved last, (C, H, W) -> (H, W, C), as imshow expects.
    """
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
    plt.show()

In [2]:
# Patch geometry seen by the network, and the inference batch size.
img_size = 128
IMG_HEIGHT = IMG_WIDTH = img_size
batchSize = 32

# Use the GPU(s) exposed through CUDA_VISIBLE_DEVICES when any are present.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

cuda


In [3]:
# The network input is 4 channels: the RGB patch plus the prior mask prediction,
# concatenated channel-wise in the prediction cell. Both branches must build the
# same architecture. Previously the CPU branch used n_channels=3, which cannot
# load the 4-channel checkpoint or accept the 4-channel input.
if use_cuda:
    # DataParallel splits each batch across the GPUs made visible above.
    net = nn.DataParallel(UNet(n_channels=4, n_classes=1)).cuda()
else:
    net = UNet(n_channels=4, n_classes=1)

## Load model
### Note: change `CHECKPOINT_PATH` below to the location of the trained weights on your machine

In [4]:
# Trained weights to evaluate. Change this to where the checkpoint lives locally.
# CHECKPOINT_PATH = './results/patch_withPred_ver2/net_025.pth'
CHECKPOINT_PATH = '/home/quang/working/fundus_segmentation/model_paper/results/patch_withPred_ver2/net_025.pth'

# map_location remaps tensors saved on GPU onto the current device, so the
# checkpoint can also be loaded on a CPU-only machine.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)

if not isinstance(net, nn.DataParallel):
    # The checkpoint loads cleanly into the nn.DataParallel wrapper, so its keys
    # carry the wrapper's "module." prefix. A bare UNet expects them without it.
    state_dict = {
        (key[len('module.'):] if key.startswith('module.') else key): value
        for key, value in state_dict.items()
    }

net.load_state_dict(state_dict)

<All keys matched successfully>

## Predict whole images (sliding-window inference)
### Note: change the input and output directory paths below to match your machine

In [5]:
IMG_SIZE = 1200   # every full image is resized to IMG_SIZE x IMG_SIZE before tiling
PATCH_SIZE = 128  # sliding-window size; same value as img_size in the config cell
STEP = 64         # window stride, i.e. neighbouring windows overlap by PATCH_SIZE - STEP

# Alternative dataset (doctor-annotated fundus images):
# INPUT_DIR_FULL_IMAGE = '/home/quang/working/fundus_segmentation/data/segmentation_doctor/data_full/'
# INPUT_DIR_MASK_PRED_FULL_IMAGE = '/home/quang/working/fundus_segmentation/data/segmentation_doctor/data_full_mask_pred/'

# Full-resolution .jpg images, and prior mask predictions stored under the same base name.
INPUT_DIR_FULL_IMAGE = '/home/quang/working/Face-Aging-CAAE/data/data_amd_resize_1200_8196/'
INPUT_DIR_MASK_PRED_FULL_IMAGE = '/home/quang/working/fundus_segmentation/data/segmentation_doctor/prediction_data_amd_512_8196/'

# OUTPUT_DIR_MASK_PRED_FULL_IMAGE = './predictions/main_model/'
OUTPUT_DIR_MASK_PRED_FULL_IMAGE = '/home/quang/working/fundus_segmentation/data/segmentation_doctor/prediction_data_amd_1200_8196/'

# glob returns files in filesystem order; sort so runs are reproducible.
list_paths_img = sorted(glob.glob(INPUT_DIR_FULL_IMAGE + '*.jpg'))

# Base names without the final extension. splitext keeps dots inside the name.
# split('.')[0] would turn "a.b.jpg" into "a", which the prediction loop then
# cannot find on disk.
list_name_f = [os.path.splitext(os.path.basename(x))[0] for x in list_paths_img]
print(len(list_name_f))

8196


In [6]:
# Test-time preprocessing for one PATCH_SIZE x PATCH_SIZE window.
#
# albumentations stage: the RGB patch ('image') and the prior prediction
# ('mask_pred', registered as a mask target) are resized together to the
# network input size. 'image0' is registered as an extra image target but the
# prediction loop does not pass it.
test_album_FULL_IMAGE = Compose(
    [Resize(int(PATCH_SIZE), int(PATCH_SIZE))],
    additional_targets={'image0': 'image', 'mask_pred': 'mask'},
)

# torchvision stage: ToTensor, then Normalize with mean 0.5 / std 0.5 per channel,
# i.e. x -> (x - 0.5) / 0.5.
test_transform_torch_FULL_IMAGE = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,) * 3, (0.5,) * 3)]
)
test_transform_torch_mask_pred_FULL_IMAGE = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Bare ToTensor helper (not used by the prediction loop in this notebook).
toTensor = transforms.ToTensor()

In [7]:
# Sliding-window inference. Each full image is resized to IMG_SIZE x IMG_SIZE and
# cut into overlapping PATCH_SIZE windows (stride STEP). Every window is fed to the
# network together with the matching window of the prior mask prediction. Outputs
# are averaged over all windows covering a pixel, thresholded at 0.5, and written
# as a 0/255 PNG.

# The previous gate read `... > PATCH_SIZE*PATCH_SIZE/2 or True`, so every window
# was always predicted; False keeps that behaviour. Set True to skip windows where
# at most half the pixels are bright (gray > 10) and inside the inscribed circle.
SKIP_BACKGROUND_PATCHES = False


def patch_starts(size, patch, step):
    """Return window start offsets along one axis that together cover [0, size).

    The regular grid 0, step, 2*step, ... is extended with a final window flush
    against the far edge. Without it the last pixels of every row and column are
    never covered, so their averaged prediction stays 0.

    Raises:
        ValueError: if `size` is smaller than `patch`.
    """
    if size < patch:
        raise ValueError(f"image size {size} is smaller than patch size {patch}")
    starts = list(range(0, size - patch, step))
    starts.append(size - patch)  # range() excludes its end, so this is never a duplicate
    return starts


def build_patch_input(img_patch, mask_pred_patch):
    """Apply the test-time transforms to one RGB window and its prior-mask window.

    Concatenates the transformed image channels and the mask channel along
    dim 0 into a single CPU tensor.
    """
    auged = test_album_FULL_IMAGE(image=img_patch, mask_pred=mask_pred_patch)
    image_tensor = test_transform_torch_FULL_IMAGE(auged['image'])
    mask_pred_tensor = test_transform_torch_mask_pred_FULL_IMAGE(auged['mask_pred']).float()
    return torch.cat([image_tensor, mask_pred_tensor], 0)


# cv2.imwrite does not raise on a missing directory; it just returns False.
os.makedirs(OUTPUT_DIR_MASK_PRED_FULL_IMAGE, exist_ok=True)

window_starts = patch_starts(IMG_SIZE, PATCH_SIZE, STEP)

# Filled disc inscribed in the resized image. Loop-invariant, so built once.
circle = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
cv2.circle(circle, (IMG_SIZE // 2, IMG_SIZE // 2), IMG_SIZE // 2, 1, thickness=-1)

missing_inputs = []  # names whose image or prior mask could not be read
failed_writes = []   # output paths that cv2.imwrite reported as not written

net.eval()
with torch.no_grad():
    for temp_name in tqdm(list_name_f):
        img = cv2.imread(INPUT_DIR_FULL_IMAGE + temp_name + '.jpg')
        mask_pred = cv2.imread(INPUT_DIR_MASK_PRED_FULL_IMAGE + temp_name + '.jpg', 0)
        # cv2.imread returns None (rather than raising) for missing/unreadable files.
        if img is None or mask_pred is None:
            missing_inputs.append(temp_name)
            continue

        # OpenCV loads BGR. NOTE(review): presumably the model was trained on
        # RGB-ordered input; confirm against the training pipeline.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_rescaled = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_CUBIC)
        mask_pred_rescaled = cv2.resize(mask_pred, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_CUBIC)
        mask_pred_rescaled = np.expand_dims(mask_pred_rescaled, axis=-1)

        coords = [(i, j) for i in window_starts for j in window_starts]
        if SKIP_BACKGROUND_PATCHES:
            # img_rescaled is RGB here, so RGB2GRAY (not BGR2GRAY) is the right conversion.
            img_gray = cv2.cvtColor(img_rescaled, cv2.COLOR_RGB2GRAY)
            fundus_mask = (img_gray > 10).astype(np.uint8) * circle
            coords = [
                (i, j) for i, j in coords
                if fundus_mask[i:i + PATCH_SIZE, j:j + PATCH_SIZE].sum() > PATCH_SIZE * PATCH_SIZE / 2
            ]

        mask_predict = np.zeros((IMG_SIZE, IMG_SIZE))
        mask_predict_count = np.zeros((IMG_SIZE, IMG_SIZE))

        # Run windows through the network batchSize at a time. In eval mode the
        # outputs match one-at-a-time inference, with far fewer forward calls.
        for start in range(0, len(coords), batchSize):
            chunk = coords[start:start + batchSize]
            batch = torch.stack([
                build_patch_input(
                    img_rescaled[i:i + PATCH_SIZE, j:j + PATCH_SIZE, :],
                    mask_pred_rescaled[i:i + PATCH_SIZE, j:j + PATCH_SIZE],
                )
                for i, j in chunk
            ]).to(device)
            # NOTE(review): thresholding the average at 0.5 assumes the network
            # outputs probabilities (sigmoid inside UNet). Confirm; for raw
            # logits the threshold would be 0.
            preds = net(batch).cpu().numpy()[:, 0]
            for (i, j), pred in zip(chunk, preds):
                mask_predict[i:i + PATCH_SIZE, j:j + PATCH_SIZE] += pred
                mask_predict_count[i:i + PATCH_SIZE, j:j + PATCH_SIZE] += 1

        # Pixels covered by no window (only possible when skipping) keep the value 0.
        mask_predict_avg = mask_predict / np.maximum(mask_predict_count, 1)
        # uint8 is a depth cv2.imwrite supports for PNG; the old int64 array was not.
        mask_predict_avg_binary = np.where(mask_predict_avg > 0.5, 255, 0).astype(np.uint8)

        out_path = OUTPUT_DIR_MASK_PRED_FULL_IMAGE + temp_name + '.png'
        if not cv2.imwrite(out_path, mask_predict_avg_binary):
            failed_writes.append(out_path)

if missing_inputs:
    print(f"Skipped {len(missing_inputs)} image(s) with missing/unreadable inputs, e.g. {missing_inputs[:5]}")
if failed_writes:
    print(f"Failed to write {len(failed_writes)} mask(s), e.g. {failed_writes[:5]}")

100%|██████████| 8196/8196 [4:22:36<00:00,  1.91s/it]  
