In [15]:
# Root scratch location for anything this notebook produces.
OUTPUT_FOLDER = "/scratch/aakash_ks.iitr"

# Dataset root; the trailing slash matters because sub-paths below are built
# by plain string composition.
DATA_FOLDER = "/scratch/aakash_ks.iitr/data/diabetic-retinopathy/"

# Folder with the resized training images (also slash-terminated).
TRAIN_DATA_FOLDER = f"{DATA_FOLDER}resized_train/"

In [16]:
# PyTorch modules
import torch
from torch import nn 
from torch.utils import data 
import torch.nn.functional as F 
from torchvision import transforms, models

# other modules
import os
import pandas as pd
import numpy as np
from PIL import Image
from glob import glob
import random
import itertools

In [17]:
# Annotation table: one row per image. The dataset class in this notebook
# reads its 'image' (file stem) and 'level' (severity label) columns.
labels_csv_path = os.path.join(DATA_FOLDER, 'trainLabels.csv')
training_table = pd.read_csv(labels_csv_path)

In [52]:
class ROP_dataset_v5(data.Dataset):
    """
    Dataset of randomly sampled image pairs labelled "change" vs "no change".

    - Each item is a pair of images plus a binary label (0 = both images have
      the same severity level, 1 = the levels differ) and a metadata dict.
    - The requested label is drawn uniformly from {0, 1} for every item, so
      pairs are balanced 50:50 in expectation.
    - Pairs are drawn at random on every access; the index passed to
      ``__getitem__`` is ignored and ``epoch_size`` only fixes the reported
      length of the dataset.

    Concepts adapted from:
    - https://hackernoon.com/facial-similarity-with-siamese-networks-in-pytorch-9642aa9db2f7
    - https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
    """

    # Extension appended to the 'image' column values to obtain file names.
    IMAGE_EXTENSION = '.jpeg'

    def __init__(self, patient_table, image_dir, epoch_size, transform=None):
        """
        Args:
            patient_table (pd.DataFrame): table with an 'image' column (file
                name without extension) and a 'level' column (severity label;
                must support subtraction).
            image_dir (string): directory containing all of the image files
            epoch_size (int): value reported by ``len(dataset)``
            transform (callable, optional): transform applied to each image;
                defaults to ``transforms.ToTensor()``
        """
        self.patient_table = patient_table
        self.image_dir = image_dir
        self.transform = transform
        self.epoch_size = epoch_size
        if self.transform is None:
            self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        """Return the configured number of samples per epoch."""
        return self.epoch_size

    def _pick_existing_image(self, name_list):
        """Draw random names until one has a matching file in ``image_dir``; return that name."""
        while True:
            name = random.choice(name_list)
            filename = name + self.IMAGE_EXTENSION
            # Probe the single file instead of listing the whole directory
            # for every candidate.
            if os.path.isfile(os.path.join(self.image_dir, filename)):
                return name
            print('attempted to get following image, but missing: ' + filename)

    def _level_of(self, name):
        """Return the 'level' value of the first table row whose 'image' equals ``name``."""
        matches = self.patient_table.loc[self.patient_table['image'] == name, 'level']
        return matches.values[0]

    def _load_greyscale(self, name):
        """Open the image file for ``name`` and return it as a single-channel ('L') image."""
        path = os.path.join(self.image_dir, name + self.IMAGE_EXTENSION)
        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the conversion is done.
        with Image.open(path) as img:
            return img.convert("L")

    def __getitem__(self, idx):
        """
        Sample one random image pair.

        Args:
            idx: ignored; every call draws a fresh random pair.

        Returns:
            tuple: ``(img0, img1, label, meta)`` where ``label`` is 0 when both
            images share the same level and 1 otherwise, and ``meta`` holds the
            subject ids, eyes, levels and signed level difference of the pair.

        Note:
            Sampling retries until a pair with the requested label is found.
            It does not terminate if no such pair exists (e.g. every row has
            the same level) or if none of the listed images are on disk.
        """
        name_list = list(self.patient_table['image'])

        # goal is 50:50 distribution of change vs no change
        change_binary = random.randint(0, 1)

        # keep drawing pairs until one matches the requested label
        while True:
            name_0 = self._pick_existing_image(name_list)
            name_1 = self._pick_existing_image(name_list)

            plus_disease_0 = self._level_of(name_0)
            plus_disease_1 = self._level_of(name_1)

            if int(plus_disease_0 != plus_disease_1) == change_binary:
                break

        # signed difference in severity level between the two images
        plus_disease_change = plus_disease_1 - plus_disease_0

        # 0 for no change, 1 for change
        label = 0 if plus_disease_change == 0 else 1

        # names look like '<subject>_<eye>'; split the bare name so the eye
        # does not carry the file extension
        parts_0 = name_0.split('_')
        parts_1 = name_1.split('_')

        meta = {"subject_id_0": parts_0[0],
                "subject_id_1": parts_1[0],
                "eye_0": parts_0[1],
                "eye_1": parts_1[1],
                "plus_disease_0": plus_disease_0,
                "plus_disease_1": plus_disease_1,
                "plus_disease_change": plus_disease_change,
               }

        # open images and convert to single channel (greyscale)
        img0 = self._load_greyscale(name_0)
        img1 = self._load_greyscale(name_1)

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

        return img0, img1, label, meta
              

In [53]:
# Folder the dataset reads its image files from.
image_dir = TRAIN_DATA_FOLDER

# TRAINING DATA - INTER-patient image pairs BUT not limited to a single time point

# Augmentation pipeline; the dataset applies it to each image of a pair
# separately, so the two images receive independent random draws.
training_transforms = transforms.Compose(
    [
        # mirror half of the images so the model does not key on eye laterality
        transforms.RandomHorizontalFlip(p=0.5),
        # random rotation in [-10, +10] degrees around the centre
        transforms.RandomRotation(degrees=10),
        # random 224 x 224 pixel crop
        transforms.RandomCrop(size=224),
        # brightness / contrast factors drawn within +/- 3% of the original
        transforms.ColorJitter(brightness=0.03, contrast=0.03),
        transforms.ToTensor(),
    ]
)

# Dataset v5: random inter-patient pairs with a 50:50 change / no-change target.
# Arguments are (patient_table, image_dir, epoch_size, transform).
training_siamese_dataset = ROP_dataset_v5(
    training_table,
    image_dir,
    3200,
    training_transforms,
)

# 3200 samples per epoch at 16 per batch -> 200 batches per epoch.
# shuffle stays off: the dataset ignores the index and samples at random itself.
training_dataloader = data.DataLoader(
    dataset=training_siamese_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
)
