# Patch Based Segmentation of Fundus Imagery

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# Load the image names
import os
import os.path
import cv2

from sklearn.datasets import load_sample_image
from sklearn.feature_extraction import image as imgutil

import time
import torch.utils.data as utils

mpl.rcParams['figure.dpi'] = 300

In [2]:
# Query the name of CUDA device 0. The return value is discarded here;
# presumably this is only a sanity check that a GPU is reachable.
torch.cuda.get_device_name(0)
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [3]:
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``.

    ``p_t`` is the softmax probability the model assigns to the true class.
    With ``gamma=0`` and ``alpha=None`` the result equals plain cross entropy.

    Args:
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        alpha: Optional per-class weights. A single number ``a`` is expanded
            to the two-class weights ``[a, 1 - a]``; a list is used as given
            (one weight per class); ``None`` disables class weighting.
        size_average: Return the mean over all elements when true, otherwise
            the sum.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        """Compute the loss.

        Args:
            input: Raw class scores (logits), either ``(N, C)`` or
                ``(N, C, d1, d2, ...)``.
            target: Integer class indices with one entry per scored element,
                i.e. ``N`` entries or ``N * d1 * d2 * ...`` entries.

        Returns:
            A scalar tensor (mean or sum, depending on ``size_average``).
        """
        if input.dim() > 2:
            # N,C,H,W => N,C,H*W => N,H*W,C => N*H*W,C
            # reshape (not view) so non-contiguous inputs are accepted too.
            input = input.reshape(input.size(0), input.size(1), -1)
            input = input.transpose(1, 2)
            input = input.reshape(-1, input.size(2))
        # reshape (not view): label maps that were transposed or sliced are
        # non-contiguous and would make view() raise.
        target = target.reshape(-1, 1)

        logpt = F.log_softmax(input, dim=1).gather(1, target).view(-1)
        pt = logpt.exp()

        if self.alpha is not None:
            # Keep the class weights on the same device/dtype as the logits.
            if self.alpha.dtype != input.dtype or self.alpha.device != input.device:
                self.alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            at = self.alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()

In [4]:
class MultiClassClassifier(nn.Module):
    """CNN mapping a 3-channel patch to 5 class scores (logits).

    The convolutional stack is eight blocks of
    ``ConstantPad2d((0, 1, 0, 1)) -> Conv2d(kernel_size=2) -> BatchNorm2d -> ReLU``.
    The one-pixel right/bottom padding compensates for the 2x2 kernel, so a
    block keeps the spatial size. Blocks 2, 4 and 6 are followed by a 2x2
    max-pool, so the spatial size shrinks by 8 overall. The first fully
    connected layer takes 4096 = 256 * 4 * 4 features, which means the
    network expects 32x32 input patches.

    Module names are unchanged from the earlier hand-written version, so
    existing state dicts still load. The second dropout layer used to be
    registered under the already-used name "Dropout1", which replaced the
    first one in place and left no dropout after "Relu10". It is now
    registered as "Dropout2". Dropout has no parameters, so checkpoints are
    unaffected.
    """

    # (in_channels, out_channels, followed by a 2x2 max-pool?) for Conv1..Conv8
    _CONV_BLOCKS = (
        (3, 32, False),
        (32, 32, True),
        (32, 64, False),
        (64, 64, True),
        (64, 128, False),
        (128, 128, True),
        (128, 256, False),
        (256, 256, False),
    )

    def __init__(self):
        super(MultiClassClassifier, self).__init__()
        self.conv = nn.Sequential()
        for index, (c_in, c_out, pool) in enumerate(self._CONV_BLOCKS, start=1):
            self.conv.add_module(f"Pad{index}", nn.ConstantPad2d((0, 1, 0, 1), 0))
            self.conv.add_module(f"Conv{index}", nn.Conv2d(c_in, c_out, kernel_size=2))
            self.conv.add_module(f"BN{index}", nn.BatchNorm2d(c_out))
            self.conv.add_module(f"Relu{index}", nn.ReLU())
            if pool:
                self.conv.add_module(f"Layer{index}MaxPool", nn.MaxPool2d(2))

        self.fc = nn.Sequential()
        self.fc.add_module("FC1", nn.Linear(4096, 1000))
        self.fc.add_module("Relu9", nn.ReLU())
        self.fc.add_module("Dropout1", nn.Dropout(0.5))
        self.fc.add_module("FC2", nn.Linear(1000, 100))
        self.fc.add_module("Relu10", nn.ReLU())
        self.fc.add_module("Dropout2", nn.Dropout(0.5))
        self.fc.add_module("FC3", nn.Linear(100, 5))

    def forward(self, x):
        """Return class logits of shape ``(batch, 5)`` for a batch of patches."""
        x = self.conv(x)
        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

In [5]:
class BinaryClassifier(nn.Module):
    """CNN mapping a 3-channel patch to 2 class scores (logits).

    The convolutional stack is eight blocks of
    ``ConstantPad2d((0, 1, 0, 1)) -> Conv2d(kernel_size=2) -> BatchNorm2d -> ReLU``.
    The one-pixel right/bottom padding compensates for the 2x2 kernel, so a
    block keeps the spatial size. Blocks 2, 4 and 6 are followed by a 2x2
    max-pool, so the spatial size shrinks by 8 overall. The first fully
    connected layer takes 4096 = 256 * 4 * 4 features, which means the
    network expects 32x32 input patches.

    Module names are unchanged from the earlier hand-written version, so
    existing state dicts still load. The second dropout layer used to be
    registered under the already-used name "Dropout1", which replaced the
    first one in place and left no dropout after "Relu10". It is now
    registered as "Dropout2". Dropout has no parameters, so checkpoints are
    unaffected.
    """

    # (in_channels, out_channels, followed by a 2x2 max-pool?) for Conv1..Conv8
    _CONV_BLOCKS = (
        (3, 32, False),
        (32, 32, True),
        (32, 64, False),
        (64, 64, True),
        (64, 128, False),
        (128, 128, True),
        (128, 256, False),
        (256, 256, False),
    )

    def __init__(self):
        super(BinaryClassifier, self).__init__()
        self.conv = nn.Sequential()
        for index, (c_in, c_out, pool) in enumerate(self._CONV_BLOCKS, start=1):
            self.conv.add_module(f"Pad{index}", nn.ConstantPad2d((0, 1, 0, 1), 0))
            self.conv.add_module(f"Conv{index}", nn.Conv2d(c_in, c_out, kernel_size=2))
            self.conv.add_module(f"BN{index}", nn.BatchNorm2d(c_out))
            self.conv.add_module(f"Relu{index}", nn.ReLU())
            if pool:
                self.conv.add_module(f"Layer{index}MaxPool", nn.MaxPool2d(2))

        self.fc = nn.Sequential()
        self.fc.add_module("FC1", nn.Linear(4096, 1000))
        self.fc.add_module("Relu9", nn.ReLU())
        self.fc.add_module("Dropout1", nn.Dropout(0.5))
        self.fc.add_module("FC2", nn.Linear(1000, 100))
        self.fc.add_module("Relu10", nn.ReLU())
        self.fc.add_module("Dropout2", nn.Dropout(0.5))
        self.fc.add_module("FC3", nn.Linear(100, 2))

    def forward(self, x):
        """Return class logits of shape ``(batch, 2)`` for a batch of patches."""
        x = self.conv(x)
        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

In [6]:
# Choose the network: swap the comments to train the 5-class model instead.
# model = MultiClassClassifier()
model = BinaryClassifier()
# Move the parameters to the GPU; the training and evaluation cells also
# move their input batches with .cuda().
model.cuda()
learning_rate = 0.0001
# Adam over all model parameters with the learning rate above.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [7]:
# Loss used by the training loop: called as criterion(model_output, labels)
# with the raw model outputs and the labels cast to long.
criterion = F.cross_entropy

In [8]:
#==================================
# Optionally resume from a previously saved state dict.
# Set load_saved to True to be prompted for one of the *.sav files that sit
# in the current working directory (entered without the extension).
load_saved = False
if load_saved:
    saved_models = [
        os.path.splitext(entry)[0]
        for entry in os.listdir('.')
        if os.path.splitext(entry)[1] == '.sav'
    ]
    print(saved_models)
    # Give the user three tries; after that the freshly initialised model is kept.
    for attempts in range(1, 4):
        model_name = input("Choose a saved model: ")
        if model_name in saved_models:
            model.load_state_dict(torch.load(model_name+".sav"))
            break
        print("Not found. Try again")

In [100]:
def get_labelled_patches(img, truth):
    """Cut an image into all 32x32 patches and split them by centre label.

    The label map is stacked onto the image as a fourth channel so that
    patches and labels stay aligned through extraction and shuffling. A
    patch's label is the label-map value at patch position (16, 16).

    Args:
        img: Image array; presumably H x W x 3 as returned by ``cv2.imread``
            (only the first three channels of the stack are kept as pixels).
        truth: H x W label map; 0 means background.

    Returns:
        A tuple ``(positive_tensor, negative_tensor, positive_labels)``:
        ``positive_tensor`` holds the patches whose centre label is non-zero,
        as ``(P, 3, 32, 32)``, or ``None`` when there are none;
        ``negative_tensor`` holds the remaining patches as ``(Q, 3, 32, 32)``
        (``Q`` may be 0); ``positive_labels`` is the list of centre labels of
        the positive patches, in the same order as ``positive_tensor``.
    """
    stacked = np.dstack((img, truth))
    patches = imgutil.extract_patches_2d(stacked, (32, 32))
    # Shuffle in place along the first axis, so the order of patches (and
    # therefore which negatives survive later truncation) is random.
    np.random.shuffle(patches)

    centre_labels = patches[:, 16, 16, 3]
    is_positive = centre_labels != 0

    # Drop the label channel and go from (N, H, W, C) to (N, C, H, W).
    pixels = patches[:, :, :, :3].transpose(0, 3, 1, 2)

    positive_labels = list(centre_labels[is_positive])
    if positive_labels:
        positive_tensor = torch.from_numpy(np.ascontiguousarray(pixels[is_positive]))
    else:
        positive_tensor = None
    # An image without background patches yields an empty (0, 3, 32, 32)
    # tensor here instead of failing on an empty stack.
    negative_tensor = torch.from_numpy(np.ascontiguousarray(pixels[~is_positive]))

    return positive_tensor, negative_tensor, positive_labels

In [98]:
def create_tensors(images, features):
    """Build a class-balanced patch dataset from images and their label maps.

    For every image, all positive patches are kept together with an equal
    number of negative patches (fewer if the image has fewer negatives).
    Images without any positive patch are skipped entirely.

    Args:
        images: Iterable of image arrays.
        features: Iterable of label maps, paired with ``images`` by position.

    Returns:
        ``(image_tensor, label_tensor)``: float patches stacked along the
        first dimension and a 1-D float tensor of their labels (0 for the
        negative patches). Within one image the positives come first.

    Raises:
        RuntimeError: If no image contains a positive patch, so there is
            nothing to concatenate.
    """
    patch_chunks = []
    label_chunks = []

    for image, feature in zip(images, features):
        pos, neg, pos_labels = get_labelled_patches(image, feature)
        if pos is None:
            continue
        # Balance the classes: at most as many negatives as positives.
        neg = neg[:len(pos)]
        patch_chunks += [pos, neg]
        label_chunks += [torch.Tensor(pos_labels), torch.zeros(len(neg))]

    if not patch_chunks:
        raise RuntimeError("create_tensors: no image contained a positive patch")

    # One concatenation per output instead of one tensor object per patch.
    image_tensor = torch.cat(patch_chunks).float()
    label_tensor = torch.cat(label_chunks)
    return image_tensor, label_tensor

In [11]:
def load_numpy_data_task1(source):
    """Load the task-1 images and a combined lesion label map for each.

    Reads every file in
    ``Data_Group_Component_Task_1/<source>/original_retinal_images`` and, for
    each, the four lesion masks named after the image (text before the first
    '.', plus a lesion suffix). Everything is resized to 256x256.

    Label values in the returned maps: 0 background, 1 hard exudates,
    2 haemorrhages, 3 soft exudates, 4 microaneurysms. Where masks overlap,
    the higher label wins because it is written last. A mask file that
    cannot be read counts as empty.

    Args:
        source: Sub-directory of the data set, e.g. "Train" or "Test".

    Returns:
        ``(images, features)``: two lists in ``os.listdir`` order, holding
        the resized images and the ``uint8`` 256x256 label maps.
    """
    size = (256, 256)
    root = os.path.join("Data_Group_Component_Task_1", source)
    image_dir = os.path.join(root, "original_retinal_images")

    # (mask directory, file-name suffix, label); applied in this order.
    lesions = (
        ("masks_Hard_Exudates", "_EX.tif", 1),
        ("masks_Haemorrhages", "_HE.tif", 2),
        ("masks_Soft_Exudates", "_SE.tif", 3),
        ("masks_Microaneurysms", "_MA.tif", 4),
    )

    def read_mask(path):
        # Channel 2 of the resized mask, or an all-zero mask when the file
        # is missing/unreadable (cv2.imread returns None in that case).
        mask = cv2.imread(path)
        if mask is None:
            return np.zeros(size, dtype=np.uint8)
        return cv2.resize(mask, size)[:, :, 2]

    image_names = os.listdir(image_dir)
    images = [
        cv2.resize(cv2.imread(os.path.join(image_dir, name)), size)
        for name in image_names
    ]

    features = []
    for name in image_names:
        stem = name.split('.')[0]
        feature_map = np.zeros(size, dtype=np.uint8)
        for directory, suffix, label in lesions:
            mask = read_mask(os.path.join(root, directory, stem + suffix))
            feature_map[mask != 0] = label
        features.append(feature_map)

    return images, features

In [91]:
def load_numpy_data_task2(source):
    """Load the task-2 images and their binary blood-vessel masks.

    Reads every file in
    ``Data_Group_Component_Task_2/<source>/original_retinal_images`` and the
    matching ``<prefix>_manual1.gif`` mask, where ``<prefix>`` is the image
    file name up to the first '_'. Everything is resized to 256x256.

    A mask that cannot be read is replaced by an all-zero mask, as in the
    other loaders. Previously an unreadable mask raised ``IndexError`` before
    that fallback could run, and every frame of the file was decoded just to
    use the first one.

    Args:
        source: Sub-directory of the data set; "Train" is mapped to
            "Training", which is the directory name used for this task.

    Returns:
        ``(images, features)``: two lists in ``os.listdir`` order, holding
        the resized images and ``uint8`` 256x256 masks with values 0/1.
    """
    size = (256, 256)

    def read_first_frame(video_path):
        # The .gif masks are read through VideoCapture rather than
        # cv2.imread (presumably because imread cannot decode GIF files).
        video = cv2.VideoCapture(str(video_path))
        try:
            ok, frame = video.read()
        finally:
            video.release()
        return frame if ok else None

    if source == "Train":
        source = "Training"
    image_dir = os.path.join("Data_Group_Component_Task_2", source, "original_retinal_images")
    vessel_dir = os.path.join("Data_Group_Component_Task_2", source, "blood_vessel_segmentation_masks")

    image_names = os.listdir(image_dir)

    images = [
        cv2.resize(cv2.imread(os.path.join(image_dir, name)), size)
        for name in image_names
    ]

    features = []
    for name in image_names:
        mask_path = os.path.join(vessel_dir, name.split('_')[0] + '_manual1.gif')
        vessels = read_first_frame(mask_path)

        if vessels is not None:
            vessels = cv2.resize(vessels, size)[:, :, 2]
        else:
            vessels = np.zeros(size, dtype=np.uint8)

        features.append((vessels != 0).astype(np.uint8))

    return images, features

In [92]:
def load_numpy_data_task1_specificfeature(source, feature, suffix):
    """Load the task-1 images with a binary mask for one lesion type.

    Args:
        source: Sub-directory of ``Data_Group_Component_Task_1``
            (e.g. "Train" or "Test").
        feature: Name of the mask directory inside ``source``
            (e.g. "masks_Soft_Exudates").
        suffix: Mask file-name suffix; the mask for ``name.ext`` is looked up
            as ``name_<suffix>.tif``, where ``name`` is the text before the
            first '.'.

    Returns:
        ``(images, features)``: two lists in ``os.listdir`` order, holding
        the images resized to 256x256 and ``uint8`` 256x256 masks with
        values 0/1. An unreadable mask file gives an all-zero mask.
    """
    target_size = (256, 256)
    image_dir = os.path.join("Data_Group_Component_Task_1", source, "original_retinal_images")
    mask_dir = os.path.join("Data_Group_Component_Task_1", source, feature)

    file_names = os.listdir(image_dir)

    images = [
        cv2.resize(cv2.imread(os.path.join(image_dir, file_name)), target_size)
        for file_name in file_names
    ]

    features = []
    for file_name in file_names:
        mask_path = os.path.join(mask_dir, file_name.split('.')[0] + f'_{suffix}.tif')
        mask = cv2.imread(mask_path)

        if mask is None:
            # No mask on disk for this image: treat it as lesion-free.
            channel = np.zeros((256,256), dtype=np.uint8)
        else:
            channel = cv2.resize(mask, target_size)[:, :, 2]

        features.append((channel != 0).astype(np.uint8))

    return images, features

## Load the correct data

In [93]:
# Training data for a single lesion type (soft exudates, binary masks).
image_list, label_list = load_numpy_data_task1_specificfeature("Train", "masks_Soft_Exudates", "SE")
# Alternative: the blood-vessel data of task 2.
# image_list, label_list = load_numpy_data_task2("Train")

In [101]:
# Turn the images and masks into balanced patch/label tensors.
image_tensor, label_tensor = create_tensors(image_list, label_list)

Put the data into a PyTorch `DataLoader` for batching.

In [102]:
# Total number of patches (positives plus the matching negatives).
print(len(image_tensor))

11566


In [103]:
def create_partitions(n, parts):
    """Split the index range ``[0, n)`` into ``parts`` consecutive slices.

    The boundaries are ``parts + 1`` evenly spaced values from 0 to ``n``,
    truncated to unsigned integers, so slice lengths differ by at most one.

    Returns:
        A list of ``(start, stop)`` tuples of ``np.uint64`` values; each
        ``stop`` is the next tuple's ``start``.
    """
    edges = np.linspace(0, n, parts+1).astype(np.uint64)
    # Pair every boundary with its successor.
    return list(zip(edges[:-1], edges[1:]))

In [104]:
# Split the patch indices into 5 consecutive slices; the training loop
# trains on them one after another.
partitions = create_partitions(len(image_tensor), 5)
print(partitions)

[(0, 2313), (2313, 4626), (4626, 6939), (6939, 9252), (9252, 11566)]


In [121]:
epochs = 100
batch_size = 50

# One entry per partition; each entry is the list of that partition's
# per-epoch losses (sum of the batch losses of the epoch).
losses_per_dataset = []
for partition, dataset in enumerate(partitions, start=1):
    print(f"================================  Partition:{partition}  ======================================")
    start, stop = dataset
    train_dataset = utils.TensorDataset(image_tensor[start:stop], label_tensor[start:stop])
    train_dataloader = utils.DataLoader(train_dataset, batch_size = batch_size, shuffle=True)
    losses_per_epoch = []
    model.train()

    # The same model and optimizer carry over from one partition to the
    # next; only the data slice changes.
    for epoch in range(epochs):
        t0 = time.time()
        losses_per_batch = []

        for patches, labels in train_dataloader:
            # clear gradients left over from the previous step
            optimizer.zero_grad()

            # forward pass on the GPU
            output = model(patches.cuda().float())

            # batch loss, then backpropagate and update the weights
            loss = criterion(output, labels.cuda().long())
            loss.backward()
            optimizer.step()

            losses_per_batch.append(loss.item())

        t1 = time.time()
        epoch_loss = sum(losses_per_batch)
        losses_per_epoch.append(epoch_loss)
        torch.cuda.empty_cache()
        print("Epoch: {0} Loss:{1} Trained in {2} seconds".format(epoch+1, epoch_loss, t1-t0))

    losses_per_dataset.append(losses_per_epoch)

Epoch: 1 Loss:7.7166856955736876 Trained in 1.4014768600463867 seconds
Epoch: 2 Loss:1.1241815702960594 Trained in 1.4092507362365723 seconds
Epoch: 3 Loss:0.10534545283007901 Trained in 1.389035940170288 seconds
Epoch: 4 Loss:0.01636983340995357 Trained in 1.4002623558044434 seconds
Epoch: 5 Loss:0.007772507643238669 Trained in 1.4910125732421875 seconds
Epoch: 6 Loss:0.00923547380313039 Trained in 1.5627751350402832 seconds
Epoch: 7 Loss:0.005448279186399674 Trained in 1.5048158168792725 seconds
Epoch: 8 Loss:0.11509303577413021 Trained in 1.4581074714660645 seconds
Epoch: 9 Loss:0.226950641502647 Trained in 1.5351450443267822 seconds
Epoch: 10 Loss:0.9191907774568904 Trained in 1.765279769897461 seconds
Epoch: 11 Loss:0.02978351098136045 Trained in 1.6191380023956299 seconds
Epoch: 12 Loss:0.01484833801532659 Trained in 1.5114150047302246 seconds
Epoch: 13 Loss:0.00535729956789055 Trained in 1.701448917388916 seconds
Epoch: 14 Loss:0.0082472923036363 Trained in 1.4451539516448975 se

Epoch: 10 Loss:0.8959252558231583 Trained in 1.3384056091308594 seconds
Epoch: 11 Loss:0.08052442825794515 Trained in 1.349423885345459 seconds
Epoch: 12 Loss:0.043655478921664326 Trained in 1.3446543216705322 seconds
Epoch: 13 Loss:0.01632654967875169 Trained in 1.3538165092468262 seconds
Epoch: 14 Loss:0.060084884878051525 Trained in 1.3487300872802734 seconds
Epoch: 15 Loss:0.00734053549568614 Trained in 1.3563926219940186 seconds
Epoch: 16 Loss:0.006741037180944431 Trained in 1.3494219779968262 seconds
Epoch: 17 Loss:0.002902398294015285 Trained in 1.336864709854126 seconds
Epoch: 18 Loss:0.00369037296855268 Trained in 1.3430864810943604 seconds
Epoch: 19 Loss:0.0032257813765710353 Trained in 1.3493919372558594 seconds
Epoch: 20 Loss:0.001947049552533997 Trained in 1.359678030014038 seconds
Epoch: 21 Loss:0.0010511178159013923 Trained in 1.3593904972076416 seconds
Epoch: 22 Loss:0.0035668626154858885 Trained in 1.341423749923706 seconds
Epoch: 23 Loss:0.0013697589681136435 Trained 

Epoch: 19 Loss:0.0037093211834449136 Trained in 1.3643217086791992 seconds
Epoch: 20 Loss:0.0007990308871530516 Trained in 1.341423749923706 seconds
Epoch: 21 Loss:0.0001877927773712429 Trained in 1.3423168659210205 seconds
Epoch: 22 Loss:0.003916958555455707 Trained in 1.3493905067443848 seconds
Epoch: 23 Loss:0.0019196721587455556 Trained in 1.352710485458374 seconds
Epoch: 24 Loss:0.00013282299240913176 Trained in 1.3282749652862549 seconds
Epoch: 25 Loss:0.0012621283381282211 Trained in 1.3369956016540527 seconds
Epoch: 26 Loss:0.00015473585905745324 Trained in 1.342428207397461 seconds
Epoch: 27 Loss:0.0006210326979925185 Trained in 1.3697104454040527 seconds
Epoch: 28 Loss:8.072376194689923e-05 Trained in 1.3543980121612549 seconds
Epoch: 29 Loss:0.0003444136121473207 Trained in 1.338829755783081 seconds
Epoch: 30 Loss:3.2444000330755784e-05 Trained in 1.338538408279419 seconds
Epoch: 31 Loss:0.000250706668571965 Trained in 1.3476207256317139 seconds
Epoch: 32 Loss:0.000128444892

Epoch: 28 Loss:0.002015552520873598 Trained in 1.323469638824463 seconds
Epoch: 29 Loss:0.00013881683010552592 Trained in 1.387373924255371 seconds
Epoch: 30 Loss:0.00016221046035802544 Trained in 1.3334298133850098 seconds
Epoch: 31 Loss:0.0008724689447952017 Trained in 1.3344640731811523 seconds
Epoch: 32 Loss:5.7258607030874487e-05 Trained in 1.322188377380371 seconds
Epoch: 33 Loss:0.00026570027241668015 Trained in 1.3531088829040527 seconds
Epoch: 34 Loss:1.184536869920482 Trained in 1.3350272178649902 seconds
Epoch: 35 Loss:0.24735834460443584 Trained in 1.3553738594055176 seconds
Epoch: 36 Loss:0.11125826980651254 Trained in 1.3255865573883057 seconds
Epoch: 37 Loss:0.07342778420968443 Trained in 1.3367033004760742 seconds
Epoch: 38 Loss:0.01600542012403139 Trained in 1.3291068077087402 seconds
Epoch: 39 Loss:0.0033457321118817163 Trained in 1.3460056781768799 seconds
Epoch: 40 Loss:0.003113849493317389 Trained in 1.3250155448913574 seconds
Epoch: 41 Loss:0.03136140243017138 Tra

Epoch: 38 Loss:0.0031320929386087215 Trained in 1.344407081604004 seconds
Epoch: 39 Loss:0.0003174448048532952 Trained in 1.3673427104949951 seconds
Epoch: 40 Loss:0.02398973986530173 Trained in 1.3414113521575928 seconds
Epoch: 41 Loss:0.0003732252102359723 Trained in 1.3463845252990723 seconds
Epoch: 42 Loss:0.00031782150566783685 Trained in 1.3414289951324463 seconds
Epoch: 43 Loss:0.0004067230122135612 Trained in 1.332566261291504 seconds
Epoch: 44 Loss:0.0006164169289082366 Trained in 1.3633527755737305 seconds
Epoch: 45 Loss:0.001144552221152395 Trained in 1.3383445739746094 seconds
Epoch: 46 Loss:0.00038678170977313187 Trained in 1.3292758464813232 seconds
Epoch: 47 Loss:0.05567924975606431 Trained in 1.3449993133544922 seconds
Epoch: 48 Loss:0.34999020473690834 Trained in 1.3454177379608154 seconds
Epoch: 49 Loss:1.4491265268658395 Trained in 1.3374226093292236 seconds
Epoch: 50 Loss:0.5390910961359623 Trained in 1.3274757862091064 seconds
Epoch: 51 Loss:0.11590845302713149 Tra

## Training

Train the model for **epochs** epochs on each partition, using mini-batches of **batch_size** patches.

In [None]:
# Save the current weights (state dict only) under a fixed file name.
torch.save(model.state_dict(), "5class_normalised_1streak" + ".sav")

In [52]:
# Ask for a name and save the current weights (state dict only) as
# <name>.sav in the working directory, where the load cell looks for them.
model_name = input("Name the model ")
torch.save(model.state_dict(), model_name + ".sav")
print(f"Saved as {model_name}.sav")

Name the model vessels_crossentropyloss_1streak10epochs
Saved as vessels_crossentropyloss_1streak10epochs.sav


### Evaluate Visually

In [26]:
def recolour(image):
    """Map a 2-D array of class indices (0-4) to a 3-channel colour image.

    Each index selects one row of a fixed five-entry palette. The result has
    shape ``(image.shape[0], image.shape[1], 3)``.
    NOTE(review): the result is later written with cv2.imwrite, so the
    triples are presumably in BGR order - confirm.
    """
    palette = np.array([
        [0,0,0],
        [0,255,0],
        [0,0,255],
        [255,0,0],
        [255,0,255],
    ])

    height, width = image.shape[0], image.shape[1]
    # Look the colours up on the flattened indices, then restore the layout.
    flat_colours = palette.take(image.flatten(), axis=0)
    return flat_colours.reshape((height, width, 3))

In [111]:
# Earlier variant of the call below, kept for reference (no function named
# load_numpy_data is defined in the visible part of this notebook):
#image_test, feature_test = load_numpy_data("Test")
# Test images and soft-exudate masks, loaded the same way as the training data.
image_test, feature_test = load_numpy_data_task1_specificfeature("Test", "masks_Soft_Exudates", "SE")
# Disabled experiment: per-channel standardisation of the test images.
# image_test = np.array(image_test)

# image_test_mean = np.mean(image_test, axis=tuple(range(image_test.ndim-1)))
# image_test_std = np.std(image_test, axis=tuple(range(image_test.ndim-1)))

# image_test = image_test - image_test_mean
# image_test = image_test / image_test_std


In [112]:
import pickle as pkl



def test(num, epochs, task, name):
    """Segment one test image patch by patch and save the result.

    Every 32x32 patch of ``image_test[num]`` is classified by the global
    ``model``; the predicted class becomes the value of the pixel at patch
    position (16, 16). The predicted mask therefore covers the image minus a
    border, and truth and image are cropped to the same region before being
    placed side by side (truth | prediction | image).

    Two files are written to ``outputs/<task>/<name>/`` (created if needed):
    ``test<num>_<name>_<epochs>.jpg`` with the side-by-side picture, and
    ``test<num>_<name>_<epochs>.sav``, a pickle holding the prediction, the
    full truth mask, the picture and the global ``losses_per_dataset``.

    Args:
        num: Index into ``image_test`` / ``feature_test``.
        epochs: Free-form tag used only in the output file names.
        task: Output sub-directory (first level).
        name: Output sub-directory (second level) and file-name component.

    Raises:
        OSError: If the side-by-side image could not be written.
    """
    patch_size = 32
    half = patch_size // 2

    model.eval()
    image, truth = image_test[num], feature_test[num]
    rows = image.shape[0] - patch_size + 1
    cols = image.shape[1] - patch_size + 1

    patches = imgutil.extract_patches_2d(image, (patch_size, patch_size))
    # (N, H, W, C) -> (N, C, H, W) in one step instead of one tensor per patch.
    patch_tensor = torch.from_numpy(
        np.ascontiguousarray(patches.transpose(0, 3, 1, 2))
    ).float()
    image_loader = utils.DataLoader(utils.TensorDataset(patch_tensor), batch_size=225)

    predictions = []
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for (img_patch,) in image_loader:
            test_output = model(img_patch.cuda())
            # class scores -> index of the best class per patch
            predictions.append(torch.argmax(test_output, 1).cpu().numpy())

    # Patches come out in row-major order, so the flat predictions fold back
    # into a (rows, cols) mask regardless of the batch size used above.
    generated_mask = np.concatenate(predictions).reshape(rows, cols)

    coloured = recolour(generated_mask)
    coloured_truth = recolour(truth[half:half + rows, half:half + cols])

    # recolour() returns a platform-integer array; convert to 8-bit (all
    # values are 0..255) so the picture is a regular image for cv2.imwrite.
    side_by_side = np.hstack(
        (coloured_truth, coloured, image[half:half + rows, half:half + cols])
    ).astype(np.uint8)

    out_dir = f"outputs/{task}/{name}"
    os.makedirs(out_dir, exist_ok=True)
    stem = f"{out_dir}/test{num}_{name}_{epochs}"

    # cv2.imwrite reports failure through its return value only.
    if not cv2.imwrite(stem + ".jpg", side_by_side):
        raise OSError(f"could not write {stem}.jpg")

    with open(stem + ".sav", 'wb') as file:
        pkl.dump({
            'pred': generated_mask,
            'truth': truth,
            'sideBySide': side_by_side,
            'trainingLoss': losses_per_dataset,
        }, file)

In [120]:
# Segment every test image; the second argument is only a tag that ends up
# in the output file names.
for i in range(len(image_test)):
    test(i, "2streak100epochs", "task1", "softexudates")

In [114]:
# Flatten the per-partition loss histories into one sequence, in the order
# the epochs were trained.
learning = [
    epoch_loss
    for partition_losses in losses_per_dataset
    for epoch_loss in partition_losses
]

62.46484309434891
26.250777304172516
20.4270920753479
13.255280308425426
7.77558471262455
4.586254129186273
2.3648105040192604
1.5346070677042007
0.9536549397744238
0.5772387200267985
13.45454003661871
4.412263004109263
2.074827627511695
1.1346344805788249
0.6379011685494334
0.3920715559506789
0.2104671751440037
0.22020873207657132
0.1363369211758254
0.12140374958107714
16.681800670921803
4.65898478589952
2.5014692340046167
1.4243572717532516
1.056384357623756
0.538886453199666
0.393902646668721
0.2203140258206986
0.11009480235225055
0.13226510648382828
9.487128367647529
3.457006884738803
1.8514068583026528
1.1478114165365696
0.7540193410823122
0.9356588879600167
0.8494800234329887
0.3681835462921299
0.26040642510633916
0.19696869706967846
9.565196793526411
3.100829637609422
1.839425751240924
1.0322476993314922
0.6479089525528252
0.4886001986451447
0.22339871019357815
0.21318444173084572
0.1461899876157986
0.09051430795079796


In [None]:
# Number of recorded epochs across all partitions.
len(learning)

In [None]:
# Plot the flattened loss history. Each point is the sum of the batch
# losses of one epoch, and the x axis runs across all partitions in
# training order.
plt.plot(learning)
plt.title('Learning Curve')
plt.xlabel('Epoch')
plt.ylabel('Cross Entropy Loss')
# Written to the current working directory.
plt.savefig('125epochs_1streak.png')

##### 