# Retinal Disease Detection and Classification

## By: Jalen Wu, Yechan Na, Jonathan Zhang

__Project Description:__

The goal of this project is to develop a machine learning model capable of detecting retinal diseases by analyzing fundus images of the eye. Using computer vision and deep learning techniques, the model assists in early detection and diagnosis of retinal disease(s).

__Applications and Impact:__

This project could be used for clinical screening to help ophthalmologists identify diseases and improve efficiency on identifying these diseases. This automated detection system has the potential to make medical imaging diagnostics more accessible.

__What we hope to achieve:__

We hope to build a model that takes in images of the eye and accurately predicts whether an individual’s eyes are healthy or showing signs of disease. To quantify the effectiveness of our model, we will be measuring metrics such as F1 score, precision, recall, loss, and accuracy of our models and graphing them as a function of how many epochs we run on our training data (all of these scores should increase in subsequent epochs).

- __Dataset__: https://www.kaggle.com/datasets/andrewmvd/retinal-disease-classification/data
- __References__: 
    - https://www.mdpi.com/2306-5729/6/2/14
    - https://jamanetwork.com/journals/jama/fullarticle/2588763
    

In [1]:
# imports

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision import transforms
import torch.nn.functional as F
import numpy as np
from matplotlib import pyplot as plt
import os
from torch.utils.data import Dataset
import pandas as pd
from skimage import io, transform
from torch import tensor
from PIL import Image
from torch import flatten
from nltk.metrics.scores import (precision, recall, f_measure, accuracy)
from tqdm import tqdm

In [3]:
# Import Dataset
import kagglehub

# Download latest version
path = kagglehub.dataset_download("andrewmvd/retinal-disease-classification")

print("Path to dataset files:", path)

Path to dataset files: C:\Users\wendy\.cache\kagglehub\datasets\andrewmvd\retinal-disease-classification\versions\1


In [4]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cpu device


In [5]:
# X = 1424 x 2144 x 3 : h x w x colors
# y = label
transform = transforms.Compose([
    transforms.Resize((int(1424 / 8), int(2144 / 8))), # Standardize image dimensions to 1444 x 2144
    transforms.ToTensor(),          # Convert images to PyTorch tensors
    # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])

base_directory = path  
print(base_directory)

if os.path.exists(base_directory):
    print("Directory exists and its contents are:")
    print(os.listdir(base_directory))
else:
    print("Directory does not exist.")

C:\Users\wendy\.cache\kagglehub\datasets\andrewmvd\retinal-disease-classification\versions\1
Directory exists and its contents are:
['Evaluation_Set', 'Test_Set', 'Training_Set']


In [6]:
# Constants (paths that will be accessed later in the project)
BASE_DIRECTORY = path
TRAINING_DIRECTORY = os.path.join(BASE_DIRECTORY, 'Training_Set', 'Training_Set', 'Training')
TRAINING_LABELS = os.path.join(BASE_DIRECTORY, 'Training_Set', 'Training_Set', 'RFMiD_Training_Labels.csv')
TESTING_DIRECTORY = os.path.join(BASE_DIRECTORY, 'Test_Set', 'Test_Set', 'Test')
TESTING_LABELS = os.path.join(BASE_DIRECTORY, 'Test_Set', 'Test_Set', 'RFMiD_Testing_Labels.csv')

In [7]:
class MultiClassDataset(Dataset):
    """
    Supports structure of given dataset (images in child folder and labels in csv format).
    """
    def __init__(self, label_csv_file, image_directory, transform=None):
        """
        Args:
            label_csv_file (str): Path to the CSV label file.
            image_directory (str): Directory with eye images.
            transform (callable, optional): transform function to be applied to each image.
        """
        self.label_csv_file = pd.read_csv(label_csv_file)
        self.image_directory = image_directory
        self.transform = transform
    
    def __len__(self):
        return len(self.label_csv_file)    

    def __getitem__(self, index):
        """
        Args:
            index: the index of the image/label pair we want to retrieve
            
        Returns:
            image_and_label (dict): A dictionary containing the image and its corresponding label at the requested index.
        """

        # The images are PNG and one-indexed (1.png, 2.png, 3.png, ...)
        image_path = os.path.join(self.image_directory, str(index + 1) +'.png')
        
        image_label = self.label_csv_file.loc[index]
        image_label = image_label.to_numpy()
        image_label = image_label[1:]
        image_label = tensor(image_label)
        
        # Loads image at the given path.
        current_image = io.imread(image_path)
        
        if current_image is None or current_image.size == 0:
            raise ValueError(f"Failed to load image: {image_path}")
       
        # Convert to PIL Image
        current_image = Image.fromarray(current_image)
        
        # If transform function passed in, apply transform to image
        if self.transform:
            transformed_image = self.transform(current_image)
            return {'image': transformed_image, 'label': image_label}
        
        return {'image': current_image,'label': image_label}

In [8]:
# Initializing dataset and loader.
train_dataset = MultiClassDataset(label_csv_file=TRAINING_LABELS, image_directory=TRAINING_DIRECTORY, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64)

# Visualize loader
res = next(iter(train_loader))
print(res['image'].shape)
print(res['label'].shape)

print(res['label'][0])
print(res['image'][0][0])

torch.Size([64, 3, 178, 268])
torch.Size([64, 46])
tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [9]:
# image tensor format: [batch_size, channels, height, width]
print(res['image'][0].shape)

torch.Size([3, 178, 268])


In [10]:
import torch.nn as nn
import torch.nn.functional as F


class Conv_NN_v2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv3 = nn.Conv2d(16, 16, 5)
        self.fc1 = nn.Linear(8640, 1200)
        self.fc2 = nn.Linear(1200, 840)
        self.fc3 = nn.Linear(840, 230)
        self.fc4 = nn.Linear(230, 46)
        #self.sig = nn.Sigmoid()

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.sigmoid(self.fc4(x))
        return x


conv_model_v2 = Conv_NN_v2()

In [11]:

class Conv_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv3 = nn.Conv2d(16, 16, 5)
        self.fc1 = nn.Linear(1456, 1200)
        self.fc2 = nn.Linear(1200, 840)
        self.fc3 = nn.Linear(840, 230)
        self.fc4 = nn.Linear(230, 46)
        #self.sig = nn.Sigmoid()

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.sigmoid(self.fc4(x))
        return x


conv_model = Conv_NN()

In [15]:
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    size = len(dataloader.dataset)
    total_loss = 0.0
    for i, data in tqdm(enumerate(dataloader)):
        # get the inputs; data is a dict of [inputs, labels]
        inputs = data['image']
        labels = data['label']
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = loss_fn(outputs, labels.type(torch.FloatTensor))
        loss.backward()
        optimizer.step()

        # print statistics
        total_loss += loss.item()
    
    return total_loss/size

In [16]:
test_dataset = MultiClassDataset(label_csv_file=TESTING_LABELS, image_directory=TESTING_DIRECTORY, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64)

In [17]:
def PRFA(predictions, answers):
    pred_indices = [x for x in range(len(predictions)) if predictions[x] == 1]
    label_indices = [y for y in range(len(answers)) if answers[y] == 1]

    temp_precision = precision(set(pred_indices), set(label_indices)) # actual labels vs. predicted labels
    temp_recall = recall(set(pred_indices), set(label_indices))
    temp_f1 = f_measure(set(pred_indices), set(label_indices))
    temp_accuracy = accuracy(answers, predictions)
    return (temp_precision, temp_recall, temp_f1, temp_accuracy)


In [21]:
def test(dataloader, model):
    """
    Takes a dataloader for the test data and model
    Returns the Precision, Recall, F1 Score, and Accuracy of the model as a tuple
    """
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    correct = 0
    correct_each = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            # get the inputs; data is a list of [inputs, labels]
            inputs = data['image']
            labels = data['label']
            
            preds = model(inputs)
            
            for p_i in range(len(preds)):
                for l_i in range(len(preds[p_i])):
                    if preds[p_i][l_i] > 0.9:
                        preds[p_i][l_i] = 1.0
                    else:
                        preds[p_i][l_i] = 0.0
            
            all_preds.append(preds)
            all_labels.append(labels)
    
    return PRFA(torch.flatten(torch.cat(all_preds)), torch.flatten(torch.cat(all_labels)))

In [35]:
import torch.optim as optim

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(conv_model.parameters(), lr=0.001, momentum=0.9)

In [37]:
epochs = 40
conv_model_loss = []
p_scores = []
r_scores = []
f_scores = []
a_scores = []

#conv_model = Conv_NN()
#conv_model.load_state_dict(torch.load("5th_conv_model.pth", weights_only=True))

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    conv_loss = train(train_loader, conv_model, loss_fn, optimizer)
    print("conv_loss: ", conv_loss)
    
    if t > 3 and conv_loss > conv_model_loss[-1]:
        break
    
    p, r, f, a = test(test_loader, conv_model)
    print("prfa: ")
    print(p, r, f, a)
    p_scores.append(p)
    r_scores.append(r)
    f_scores.append(f)
    a_scores.append(a)
    
    conv_model_loss.append(conv_loss)
    torch.save(conv_model.state_dict(), "5th_conv_model.pth")


    file = open("5th_model_scores.txt", "w", encoding="utf8")
    file.write(f"losses: {conv_model_loss} \n")
    file.write(f"precisions: {p_scores} \n")
    file.write(f"recalls: {r_scores} \n")
    file.write(f"f1 scores: {f_scores} \n")
    file.write(f"accuracies: {a_scores} \n")

Epoch 1
-------------------------------


30it [14:19, 28.64s/it]


conv_loss:  0.09462826177477837
prfa: 
0.7863829787234042 0.24283837056504598 0.3710843373493976 tensor(0.8936)
Epoch 2
-------------------------------


30it [17:50, 35.69s/it]


conv_loss:  0.0946250177298983
prfa: 
0.7821276595744681 0.2479093606690046 0.37648504711183944 tensor(0.8966)
Epoch 3
-------------------------------


30it [16:24, 32.82s/it]


conv_loss:  0.09462742786854506
prfa: 
0.7821276595744681 0.24797625472207233 0.3765621798811719 tensor(0.8966)
Epoch 4
-------------------------------


30it [16:27, 32.92s/it]


conv_loss:  0.09462281508992115
prfa: 
0.7812765957446809 0.24938875305623473 0.3780889621087315 tensor(0.8974)
Epoch 5
-------------------------------


30it [16:48, 33.61s/it]


conv_loss:  0.09462200912336509
prfa: 
0.7821276595744681 0.24764214497440043 0.37617683176422434 tensor(0.8965)
Epoch 6
-------------------------------


30it [16:35, 33.18s/it]


conv_loss:  0.09461806248873472
prfa: 
0.7812765957446809 0.24945652173913044 0.37816683831101955 tensor(0.8975)
Epoch 7
-------------------------------


30it [14:48, 29.60s/it]


conv_loss:  0.09461749792098999
prfa: 
0.7821276595744681 0.24757543103448276 0.3760998567628402 tensor(0.8964)
Epoch 8
-------------------------------


30it [21:53, 43.78s/it]


conv_loss:  0.09461309934655825
prfa: 
0.7812765957446809 0.24952432726284315 0.3782447466007416 tensor(0.8975)
Epoch 9
-------------------------------


30it [15:41, 31.39s/it]


conv_loss:  0.09461275357753038
prfa: 
0.7821276595744681 0.24737550471063258 0.37586912065439676 tensor(0.8963)
Epoch 10
-------------------------------


30it [16:39, 33.31s/it]


conv_loss:  0.09460789542645216
prfa: 
0.7812765957446809 0.24966004895295077 0.37840065952184665 tensor(0.8976)
Epoch 11
-------------------------------


30it [25:06, 50.21s/it] 

conv_loss:  0.09460810863723358





# Analysis
- Plot loss/accuracy (y) with number of epochs ran (x)
- Recall/Precision/F1 score 

In [32]:
conv_model = Conv_NN()
conv_model.load_state_dict(torch.load("5th_conv_model.pth", weights_only=True))

test(test_loader, conv_model)

(0.7804255319148936, 0.24993186154265468, 0.37861271676300584, tensor(0.8978))

In [None]:
def graph_loss(x, y, metric):

    plt.plot(x, y)
    plt.xlabel("epoch #")
    plt.ylabel(f"{metric} score per epoch")
    plt.title(f"{metric} Scores")
    plt.savefig(f"{metric}_plot.png")
    plt.clf()

graph_loss(range(len(conv_model_loss)), conv_model_loss, "Cross Entropy Loss")
graph_loss(range(len(p_scores)), p_scores, "Precision")
graph_loss(range(len(r_scores)), r_scores, "Recall")
graph_loss(range(len(f_scores)), f_scores, "F1")
graph_loss(range(len(a_scores)), a_scores, "Accuracy")

In [48]:
print("First vs Final Scores:")
print("Cross Entropy Loss:", conv_model_loss[0], conv_model_loss[-1])
print("Precision:", p_scores[0], p_scores[-1])
print("Recall:", r_scores[0], r_scores[-1])
print("F1:", f_scores[0], f_scores[-1])
print("Accuracy:", float(a_scores[0]), float(a_scores[-1]))

First vs Final Scores:
Cross Entropy Loss: 0.09462826177477837 0.09460789542645216
Precision: 0.7863829787234042 0.7812765957446809
Recall: 0.24283837056504598 0.24966004895295077
F1: 0.3710843373493976 0.37840065952184665
Accuracy: 0.8936141133308411 0.897554337978363


In [None]:
# functions to show an image


def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(train_loader)
images, labels = next(dataiter)

# show images
imshow(images[0])
# print labels
print(classes[labels[0]], images[0].shape)