# Patch-Based Segmentation of Fundus Imagery

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# Load the image names
import os
import os.path
import cv2

from sklearn.datasets import load_sample_image
from sklearn.feature_extraction import image as imgutil

import time
import torch.utils.data as utils

mpl.rcParams.update({'figure.dpi': 300})  # render matplotlib figures at 300 dpi

In [2]:
# Query the first CUDA device (the name itself is unused; presumably a quick check
# that a GPU is visible), then release unoccupied memory held by PyTorch's caching allocator.
_gpu_name = torch.cuda.get_device_name(0)
torch.cuda.empty_cache()

In [3]:
class FocalLoss(nn.Module):
    """Multi-class focal loss (Lin et al., "Focal Loss for Dense Object Detection").

    For each sample: ``-alpha[t] * (1 - p_t) ** gamma * log(p_t)``, where ``p_t`` is the
    softmax probability of the target class ``t``. With ``gamma=0`` and ``alpha=None``
    this is exactly cross entropy.

    Args:
        gamma: focusing parameter; 0 disables the modulating factor.
        alpha: optional per-class weights. A float/int ``a`` is expanded to
            ``[a, 1 - a]`` (binary case); a list/tuple (or a tensor) gives one weight
            per class.
        size_average: if True return the mean over all samples, otherwise the sum.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super().__init__()
        self.gamma = gamma
        if isinstance(alpha, (float, int)):
            alpha = torch.tensor([alpha, 1 - alpha], dtype=torch.float32)
        elif isinstance(alpha, (list, tuple)):
            alpha = torch.tensor(alpha, dtype=torch.float32)
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, input, target):
        """Return the focal loss of logits *input* (N, C[, ...]) against class ids *target*."""
        if input.dim() > 2:
            # N,C,H,W,... -> N*H*W*...,C so every spatial position is one sample.
            n, c = input.size(0), input.size(1)
            input = input.reshape(n, c, -1).transpose(1, 2).reshape(-1, c)
        # gather() needs int64 indices; accept labels of any integer dtype.
        target = target.reshape(-1, 1).long()

        logpt = F.log_softmax(input, dim=1).gather(1, target).view(-1)
        # Taken before alpha weighting so the modulating factor uses the true probability.
        pt = logpt.exp()

        if self.alpha is not None:
            # Match device/dtype per call rather than mutating self.alpha.
            alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            logpt = logpt * alpha.gather(0, target.view(-1))

        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()

In [4]:
class MultiClassClassifier(nn.Module):
    """CNN that maps a 3 x 32 x 32 patch to ``num_classes`` logits.

    Eight (pad -> 2x2 conv -> batch norm -> ReLU) stages; a 2x2 max-pool follows
    stages 2, 4 and 6, so a 32x32 input leaves the conv stack as 256 x 4 x 4 = 4096
    features, which feed a three-layer fully connected head.

    Module names ("Pad1", "Conv1", "BN1", ..., "Layer2MaxPool", ..., "FC3") are kept
    identical to the original hand-written layout so saved state_dicts still load.
    """

    # (in_channels, out_channels) for conv stages 1..8.
    _CONV_CHANNELS = ((3, 32), (32, 32), (32, 64), (64, 64),
                      (64, 128), (128, 128), (128, 256), (256, 256))
    # Stages followed by a 2x2 max-pool.
    _POOL_AFTER = (2, 4, 6)

    def __init__(self, num_classes=5):
        super().__init__()
        self.conv = nn.Sequential()
        for idx, (c_in, c_out) in enumerate(self._CONV_CHANNELS, start=1):
            # Pad right/bottom by one so the 2x2 conv keeps the spatial size.
            self.conv.add_module(f"Pad{idx}", nn.ConstantPad2d((0, 1, 0, 1), 0))
            self.conv.add_module(f"Conv{idx}", nn.Conv2d(c_in, c_out, kernel_size=2))
            self.conv.add_module(f"BN{idx}", nn.BatchNorm2d(c_out))
            self.conv.add_module(f"Relu{idx}", nn.ReLU())
            if idx in self._POOL_AFTER:
                self.conv.add_module(f"Layer{idx}MaxPool", nn.MaxPool2d(2))

        self.fc = nn.Sequential()
        self.fc.add_module("FC1", nn.Linear(4096, 1000))
        self.fc.add_module("Relu9", nn.ReLU())
        self.fc.add_module("Dropout1", nn.Dropout(0.5))
        self.fc.add_module("FC2", nn.Linear(1000, 100))
        self.fc.add_module("Relu10", nn.ReLU())
        # Distinct name: re-using "Dropout1" replaced the first dropout in place and
        # left the head with a single dropout layer.
        self.fc.add_module("Dropout2", nn.Dropout(0.5))
        self.fc.add_module("FC3", nn.Linear(100, num_classes))

    def forward(self, x):
        """Return logits of shape (N, num_classes) for patches *x* of shape (N, 3, 32, 32)."""
        x = self.conv(x)
        x = torch.flatten(x, 1)  # (N, 256 * 4 * 4) for 32x32 patches
        return self.fc(x)

In [5]:
class BinaryClassifier(nn.Module):
    """CNN that maps a 3 x 32 x 32 patch to ``num_classes`` (default 2) logits.

    Eight (pad -> 2x2 conv -> batch norm -> ReLU) stages; a 2x2 max-pool follows
    stages 2, 4 and 6, so a 32x32 input leaves the conv stack as 256 x 4 x 4 = 4096
    features, which feed a three-layer fully connected head.

    Module names ("Pad1", "Conv1", "BN1", ..., "Layer2MaxPool", ..., "FC3") are kept
    identical to the original hand-written layout so saved state_dicts still load.
    """

    # (in_channels, out_channels) for conv stages 1..8.
    _CONV_CHANNELS = ((3, 32), (32, 32), (32, 64), (64, 64),
                      (64, 128), (128, 128), (128, 256), (256, 256))
    # Stages followed by a 2x2 max-pool.
    _POOL_AFTER = (2, 4, 6)

    def __init__(self, num_classes=2):
        super().__init__()
        self.conv = nn.Sequential()
        for idx, (c_in, c_out) in enumerate(self._CONV_CHANNELS, start=1):
            # Pad right/bottom by one so the 2x2 conv keeps the spatial size.
            self.conv.add_module(f"Pad{idx}", nn.ConstantPad2d((0, 1, 0, 1), 0))
            self.conv.add_module(f"Conv{idx}", nn.Conv2d(c_in, c_out, kernel_size=2))
            self.conv.add_module(f"BN{idx}", nn.BatchNorm2d(c_out))
            self.conv.add_module(f"Relu{idx}", nn.ReLU())
            if idx in self._POOL_AFTER:
                self.conv.add_module(f"Layer{idx}MaxPool", nn.MaxPool2d(2))

        self.fc = nn.Sequential()
        self.fc.add_module("FC1", nn.Linear(4096, 1000))
        self.fc.add_module("Relu9", nn.ReLU())
        self.fc.add_module("Dropout1", nn.Dropout(0.5))
        self.fc.add_module("FC2", nn.Linear(1000, 100))
        self.fc.add_module("Relu10", nn.ReLU())
        # Distinct name: re-using "Dropout1" replaced the first dropout in place and
        # left the head with a single dropout layer.
        self.fc.add_module("Dropout2", nn.Dropout(0.5))
        self.fc.add_module("FC3", nn.Linear(100, num_classes))

    def forward(self, x):
        """Return logits of shape (N, num_classes) for patches *x* of shape (N, 3, 32, 32)."""
        x = self.conv(x)
        x = torch.flatten(x, 1)  # (N, 256 * 4 * 4) for 32x32 patches
        return self.fc(x)

In [6]:
# Active model: the 5-class lesion classifier on the GPU, optimised with Adam.
# (For the binary vessel task, BinaryClassifier() is the 2-output counterpart.)
learning_rate = 0.001
model = MultiClassClassifier().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [7]:
# Standard cross entropy over class logits (unweighted, averaged over the batch).
criterion = nn.CrossEntropyLoss()

In [8]:
#==================================
# Optionally resume from a previously saved state_dict (a *.sav file in the
# current directory). Set load_saved = True to be prompted for one.
load_saved = False
if load_saved:
    saved_models = sorted(
        os.path.splitext(fname)[0]
        for fname in os.listdir('.')
        if os.path.splitext(fname)[1] == '.sav'
    )
    print(saved_models)
    for _ in range(3):
        model_name = input("Choose a saved model: ")
        if model_name in saved_models:
            model.load_state_dict(torch.load(model_name + ".sav"))
            break
        print("Not found. Try again")
    else:
        # Make the fallback visible instead of silently training from scratch.
        print("No saved model loaded after 3 attempts; continuing with a freshly initialised model.")

In [9]:
def get_labelled_patches(img, truth, patch_size=32):
    """Cut every ``patch_size`` x ``patch_size`` window out of *img* and split by centre label.

    Args:
        img: H x W x C image (C is 3 for the colour images loaded in this notebook).
        truth: H x W label map; 0 is background, any other value is a class id.
        patch_size: side of the square patch; each patch is labelled by the value of
            *truth* at its pixel (patch_size // 2, patch_size // 2).

    Returns:
        (positive, negative, positive_labels): tensors of shape (P, C, ps, ps) and
        (Q, C, ps, ps) with the patches whose centre label is non-zero / zero, in a
        random order (``np.random.shuffle``), plus a list of the P non-zero centre
        labels aligned with *positive*. Either tensor may have zero rows.
    """
    stacked = np.dstack((img, truth))  # label map becomes the last channel
    n_channels = stacked.shape[2]
    # All windows in row-major order (same order as sklearn's extract_patches_2d).
    # ascontiguousarray forces a real copy so the in-place shuffle cannot alias *stacked*.
    windows = np.lib.stride_tricks.sliding_window_view(
        stacked, (patch_size, patch_size, n_channels))
    patches = np.ascontiguousarray(windows).reshape(-1, patch_size, patch_size, n_channels)
    np.random.shuffle(patches)

    half = patch_size // 2
    centre_labels = patches[:, half, half, -1]
    pixels = patches[..., :-1].transpose(0, 3, 1, 2)  # (N, ps, ps, C) -> (N, C, ps, ps)
    is_positive = centre_labels != 0

    # Boolean indexing keeps empty selections as (0, C, ps, ps) instead of failing
    # like torch.stack([]) would for an image without lesions.
    positive = torch.from_numpy(np.ascontiguousarray(pixels[is_positive]))
    negative = torch.from_numpy(np.ascontiguousarray(pixels[~is_positive]))
    return positive, negative, list(centre_labels[is_positive])

In [10]:
def create_tensors(images, features, neg_ratio=5):
    """Build a patch-level training set from whole images and their label maps.

    For each (image, label map) pair, every patch whose centre label is non-zero is
    kept, together with at most ``neg_ratio`` times as many background patches
    (centre label 0) to limit class imbalance.

    Returns:
        image_tensor: float32 tensor of patches as produced by get_labelled_patches
            ((N, C, 32, 32) with its default patch size), grouped image by image
            (that image's positives, then its negatives).
        label_tensor: float32 tensor (N,) of centre labels.

    Raises:
        ValueError: if *images*/*features* are empty.
    """
    patch_chunks = []
    label_chunks = []

    for image, feature in zip(images, features):
        pos, neg, pos_labels = get_labelled_patches(image, feature)
        neg = neg[:len(pos) * neg_ratio]
        patch_chunks.extend((pos, neg))
        label_chunks.extend((torch.Tensor(pos_labels), torch.zeros(len(neg))))

    if not patch_chunks:
        raise ValueError("create_tensors needs at least one (image, feature) pair")

    # One concatenation per output instead of stacking hundreds of thousands of
    # individual patch / scalar tensors.
    image_tensor = torch.cat(patch_chunks).float()
    label_tensor = torch.cat(label_chunks)
    return image_tensor, label_tensor

In [11]:
def load_numpy_data_task1(source):
    """Load Task 1 fundus images and their combined lesion label maps.

    Args:
        source: split sub-directory of ``Data_Group_Component_Task_1`` (e.g. "Train", "Test").

    Returns:
        (images, features): lists of 256x256 images as read by ``cv2.imread`` (BGR),
        and matching 256x256 uint8 label maps with 0 = background, 1 = hard exudate,
        2 = haemorrhage, 3 = soft exudate, 4 = microaneurysm. Where masks overlap the
        later class wins. A missing mask file means "no lesion of that type".

    Raises:
        FileNotFoundError: if a file in the image directory cannot be read as an image.
    """
    root = os.path.join("Data_Group_Component_Task_1", source)
    image_dir = os.path.join(root, "original_retinal_images")
    # (mask sub-directory, filename suffix) in label order 1..4.
    mask_specs = (
        ("masks_Hard_Exudates", "_EX.tif"),
        ("masks_Haemorrhages", "_HE.tif"),
        ("masks_Soft_Exudates", "_SE.tif"),
        ("masks_Microaneurysms", "_MA.tif"),
    )

    def read_image(path):
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"Could not read image {path!r}")
        return cv2.resize(img, (256, 256))

    def read_mask(path):
        # cv2.imread returns None for a missing/unreadable file: treat as an empty mask.
        mask = cv2.imread(path)
        if mask is None:
            return np.zeros((256, 256), dtype=np.uint8)
        # NOTE(review): cv2.resize defaults to bilinear interpolation, and any non-zero
        # value counts as lesion below, so mask boundaries depend on the interpolation
        # — consider cv2.INTER_NEAREST.
        return cv2.resize(mask, (256, 256))[:, :, 2]

    # Sorted so image order (and therefore test indices) is reproducible across runs.
    image_names = sorted(os.listdir(image_dir))

    images = []
    features = []
    for image_name in image_names:
        images.append(read_image(os.path.join(image_dir, image_name)))
        stem = os.path.splitext(image_name)[0]
        feature_map = np.zeros((256, 256), dtype=np.uint8)
        for label, (subdir, suffix) in enumerate(mask_specs, start=1):
            mask = read_mask(os.path.join(root, subdir, stem + suffix))
            feature_map[mask != 0] = label
        features.append(feature_map)

    return images, features

In [12]:
def load_numpy_data_task2(source):
    """Load Task 2 fundus images and binary blood-vessel label maps.

    Args:
        source: split sub-directory of ``Data_Group_Component_Task_2``; "Train" is
            mapped to the on-disk name "Training".

    Returns:
        (images, features): 256x256 images passed through ``normalise_image``, and
        matching 256x256 uint8 maps with 1 = vessel, 0 = background. A mask that
        cannot be read yields an all-zero map.
    """
    def read_first_frame(path):
        # Masks are .gif files, read through VideoCapture (cv2.imread has historically
        # lacked GIF support). Only the first frame is decoded; None if nothing is read.
        capture = cv2.VideoCapture(str(path))
        try:
            ok, frame = capture.read()
        finally:
            capture.release()
        return frame if ok else None

    if source == "Train":
        source = "Training"
    root = os.path.join("Data_Group_Component_Task_2", source)
    image_dir = os.path.join(root, "original_retinal_images")
    vessel_dir = os.path.join(root, "blood_vessel_segmentation_masks")

    # Sorted so image order (and therefore test indices) is reproducible across runs.
    image_names = sorted(os.listdir(image_dir))

    images = [cv2.resize(cv2.imread(os.path.join(image_dir, n)), (256, 256)) for n in image_names]
    # NOTE(review): normalise_image is not defined in the cells shown here — confirm it
    # exists before switching to the Task 2 loader.
    images = list(map(normalise_image, images))

    features = []
    for image_name in image_names:
        # Mask name: image-name prefix before the first '_' + '_manual1.gif'.
        vessel_path = os.path.join(vessel_dir, image_name.split('_')[0] + '_manual1.gif')
        print(vessel_path)

        vessels = read_first_frame(vessel_path)
        if vessels is None:
            vessels = np.zeros((256, 256), dtype=np.uint8)
        else:
            vessels = cv2.resize(vessels, (256, 256))[:, :, 2]

        features.append((vessels != 0).astype(np.uint8))

    return images, features

In [13]:
# Which task the rest of the notebook trains and evaluates on:
# 1 = lesion label maps (classes 0-4), 2 = binary blood-vessel maps.
TASK = 1
load_numpy_data = {1: load_numpy_data_task1, 2: load_numpy_data_task2}[TASK]

Load the data

In [14]:
# Training split: full images plus their per-pixel label maps.
image_list, label_list = load_numpy_data(source="Train")

In [15]:
image_tensor, label_tensor = create_tensors(image_list, label_list)

# Per-channel statistics of the training patches. image_tensor is (N, C, H, W), so
# reduce over every axis except the channel axis 1 (reducing over axes (0, 1, 2)
# gives one value per patch column instead).
# NOTE(review): mean/std are not applied to the data anywhere in the cells shown.
reduce_dims = (0, 2, 3)
mean = torch.mean(image_tensor, dim=reduce_dims)
std = torch.std(image_tensor, dim=reduce_dims)

Put the data into a PyTorch `DataLoader` for batching, then train.

In [None]:
epochs = 100
batch_size = 50

# Train on consecutive chunks of the patch set, each for `epochs` epochs. The last
# chunk ends at the real data length (previously a hard-coded 424976), and chunks
# that would be empty are skipped.
# NOTE(review): create_tensors appends patches image by image, so each chunk covers a
# different subset of images and earlier chunks are never revisited; shuffling before
# chunking may be intended — confirm.
n_patches = len(image_tensor)
chunk_starts = (0, 100000, 200000, 300000)
chunk_bounds = [
    (start, min(stop, n_patches))
    for start, stop in zip(chunk_starts, chunk_starts[1:] + (n_patches,))
    if start < n_patches
]

losses_per_dataset = []
for start, stop in chunk_bounds:
    train_dataset = utils.TensorDataset(image_tensor[start:stop], label_tensor[start:stop])
    train_dataloader = utils.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    losses_per_epoch = []
    model.train()

    for epoch in range(epochs):
        t0 = time.time()
        losses_per_batch = []

        for patches, labels in train_dataloader:
            optimizer.zero_grad()
            output = model(patches.cuda().float())
            loss = criterion(output, labels.cuda().long())
            loss.backward()
            optimizer.step()
            losses_per_batch.append(loss.item())

        t1 = time.time()
        epoch_loss = sum(losses_per_batch)  # summed (not averaged) over batches
        losses_per_epoch.append(epoch_loss)
        torch.cuda.empty_cache()
        print("Epoch: {0} Loss:{1} Trained in {2} seconds".format(epoch + 1, epoch_loss, t1 - t0))

    losses_per_dataset.append(losses_per_epoch)

Epoch: 1 Loss:713.6724153272808 Trained in 57.71755814552307 seconds
Epoch: 2 Loss:506.0491146296263 Trained in 57.03878211975098 seconds
Epoch: 3 Loss:424.66163576580584 Trained in 61.10589146614075 seconds
Epoch: 4 Loss:366.312396902591 Trained in 57.663920164108276 seconds
Epoch: 5 Loss:326.84037521854043 Trained in 56.72777462005615 seconds
Epoch: 6 Loss:287.71891102474183 Trained in 57.76373767852783 seconds
Epoch: 7 Loss:253.99396940972656 Trained in 61.29642081260681 seconds
Epoch: 8 Loss:224.2651601145044 Trained in 56.523176431655884 seconds
Epoch: 9 Loss:196.0254759497475 Trained in 56.214316606521606 seconds
Epoch: 10 Loss:170.1361532052979 Trained in 56.51357674598694 seconds
Epoch: 11 Loss:143.75554694479797 Trained in 59.27753257751465 seconds
Epoch: 12 Loss:128.34042917389888 Trained in 65.13123846054077 seconds
Epoch: 13 Loss:117.96772166597657 Trained in 59.18282151222229 seconds
Epoch: 14 Loss:99.64340712223202 Trained in 57.57897639274597 seconds
Epoch: 15 Loss:96.58

Epoch: 18 Loss:40.81916815043405 Trained in 69.70354461669922 seconds
Epoch: 19 Loss:37.585002107343826 Trained in 69.71152377128601 seconds
Epoch: 20 Loss:39.121547532418845 Trained in 67.84750938415527 seconds
Epoch: 21 Loss:30.735269776054793 Trained in 64.03371143341064 seconds
Epoch: 22 Loss:32.92275110141418 Trained in 69.32455945014954 seconds
Epoch: 23 Loss:33.428624439865985 Trained in 66.21798014640808 seconds
Epoch: 24 Loss:44.636184004233655 Trained in 64.51941156387329 seconds
Epoch: 25 Loss:28.24809390806513 Trained in 65.06595015525818 seconds
Epoch: 26 Loss:25.186901116719696 Trained in 65.1058440208435 seconds
Epoch: 27 Loss:27.959434051499557 Trained in 66.98880648612976 seconds
Epoch: 28 Loss:27.83032178459598 Trained in 64.1683509349823 seconds
Epoch: 29 Loss:24.016852821137945 Trained in 65.02505993843079 seconds
Epoch: 30 Loss:28.916332027465614 Trained in 65.38409876823425 seconds
Epoch: 31 Loss:27.526303284802452 Trained in 69.79130983352661 seconds
Epoch: 32 Lo

## Training

Train the model for **epochs** epochs on mini-batches of **batch_size** patches.

In [None]:
# Checkpoint the current weights (state_dict only) under a fixed file name.
checkpoint_path = "5class_normalised_1streak.sav"
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Save the current weights (state_dict only) under a user-chosen name.
model_name = input("Name the model ")
save_path = model_name + ".sav"
torch.save(model.state_dict(), save_path)
print(f"Saved as {save_path}")

### Evaluate Visually

In [None]:
def recolour(image):
    """Map a label image to a colour image with a fixed 5-entry palette.

    Args:
        image: integer array of class ids 0..4 (any shape; 2-D in this notebook).

    Returns:
        uint8 array of shape ``image.shape + (3,)``. Channels follow the BGR order
        cv2.imwrite expects, so saved images show 0 -> black, 1 -> green, 2 -> red,
        3 -> blue, 4 -> magenta.

    Raises:
        IndexError: if a label is 5 or larger.
    """
    # uint8 so the result stacks cleanly with 8-bit images and is a valid image for
    # cv2.imwrite (a default-int palette produced int32/int64 output).
    palette = np.array([
        [0, 0, 0],
        [0, 255, 0],
        [0, 0, 255],
        [255, 0, 0],
        [255, 0, 255],
    ], dtype=np.uint8)
    return palette[np.asarray(image)]

In [None]:
# Held-out split, loaded with the same loader as the training data.
image_test, feature_test = load_numpy_data(source="Test")


In [None]:
import pickle as pkl


def test(num, epochs, task, name):
    """Segment test image *num* with the current model and save the results.

    Every 32x32 patch of the image is classified, giving one predicted label per
    patch centre. The prediction grid therefore starts 16 pixels in from the
    top-left corner and is (H - 31) x (W - 31) (225x225 for 256x256 images); the
    truth map and image are cropped to the same region for display.

    Writes into ``outputs/{task}/{name}/`` (created if missing):
      * ``test{num}_{name}_{epochs}.jpg``: truth | prediction | image, side by side;
      * ``test{num}_{name}_{epochs}.sav``: a pickle with the prediction, the full
        truth map, the side-by-side image and the global ``losses_per_dataset``.

    Args:
        num: index into the global ``image_test`` / ``feature_test`` lists.
        epochs: label used only in the output file names.
        task: output sub-directory (e.g. "task1").
        name: run name, used as a sub-directory and in the file names.
    """
    patch_size = 32
    centre = patch_size // 2
    model.eval()
    image, truth = image_test[num], feature_test[num]

    patches = imgutil.extract_patches_2d(image, (patch_size, patch_size))
    # (P, H, W, C) -> (P, C, H, W) float32, the layout the model was trained on.
    patch_tensor = torch.from_numpy(np.ascontiguousarray(patches.transpose(0, 3, 1, 2))).float()
    out_h = image.shape[0] - patch_size + 1
    out_w = image.shape[1] - patch_size + 1
    image_loader = utils.DataLoader(utils.TensorDataset(patch_tensor), batch_size=out_w)

    predictions = []
    with torch.no_grad():  # inference only: don't build autograd graphs
        for (img_patch,) in image_loader:
            test_output = model(img_patch.cuda())
            # Logits -> class index per patch centre.
            predictions.append(torch.argmax(test_output, 1).cpu().numpy())
    generated_mask = np.concatenate(predictions).reshape(out_h, out_w)

    rows = slice(centre, centre + out_h)
    cols = slice(centre, centre + out_w)
    coloured = recolour(generated_mask)
    coloured_truth = recolour(truth[rows, cols])
    side_by_side = np.hstack((coloured_truth, coloured, image[rows, cols]))

    out_dir = os.path.join("outputs", task, name)
    os.makedirs(out_dir, exist_ok=True)  # cv2.imwrite fails silently on a missing directory
    stem = os.path.join(out_dir, f"test{num}_{name}_{epochs}")
    cv2.imwrite(stem + ".jpg", side_by_side)

    with open(stem + ".sav", 'wb') as fh:
        pkl.dump({
            'pred': generated_mask,
            'truth': truth,
            'sideBySide': side_by_side,
            'trainingLoss': losses_per_dataset,
        }, fh)

In [None]:
# Evaluate and save results for every held-out image.
# NOTE(review): the run labels say "focalLoss" / "10epochs", but the training cell uses
# criterion = F.cross_entropy with epochs = 100; they only affect output file names.
for test_index, _ in enumerate(image_test):
    test(test_index, "1streak10epochs", "task1", "focalLoss")

In [None]:
# Flatten per-chunk epoch losses into one sequence (chunk order, then epoch order),
# printing each value along the way.
learning = []
for chunk_losses in losses_per_dataset:
    for epoch_loss in chunk_losses:
        print(epoch_loss)
    learning.extend(chunk_losses)

In [None]:
# Number of recorded epoch losses across all training chunks (shown as the cell output).
len(learning)

In [None]:
# Learning curve over all training chunks, one point per epoch, drawn on a fresh
# figure so no earlier plot state leaks in.
fig, ax = plt.subplots()
ax.plot(learning)
ax.set_title('Learning Curve')
ax.set_xlabel('Epoch (all data chunks, in training order)')
# Each point is the cross-entropy loss summed over that epoch's batches, not a mean.
ax.set_ylabel('Cross Entropy Loss (sum over batches)')
fig.savefig('125epochs_1streak.png')

##### 