In [1]:
import sys
sys.path.insert(0, '..')  # make the project-local packages below importable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
# NOTE(review): star imports hide provenance; EyeImageDataset, EyeClassifier and
# TransferFunction presumably come from these two modules — confirm and import
# them by name.
from utils.eye_dataset import *
from eye_classifier import *
import torchvision.transforms as transforms

In [2]:
# Paths into the ODIR-5K dataset, relative to this notebook's directory.
base_dir = "../../data"
image_dir_training = f"{base_dir}/ODIR-5K/training"
# NOTE(review): image_dir_testing is not used by any cell in this notebook;
# evaluation currently runs on the training dataset `ds`.
image_dir_testing = f"{base_dir}/ODIR-5K/testing"
csv_file = f'{base_dir}/ODIR-5K/data.csv'

# Images are resized/center-cropped to input_size x input_size; the flattened
# feature count entering the classifier's first Linear layer depends on it.
input_size = 512
BATCH_SIZE = 4

# Seed torch's global RNG so weight init, dropout masks and DataLoader
# shuffling are repeatable across Restart & Run All. Python `random` / numpy
# are not seeded here — add them if EyeImageDataset / EyeClassifier use them.
SEED = 42
torch.manual_seed(SEED)

print('reading input dataset')
apply_transforms = transforms.Compose([
    transforms.Resize(size=input_size),      # int size: shorter edge -> input_size, aspect kept
    transforms.CenterCrop(size=input_size),  # then a square input_size x input_size crop
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean / std.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

ds = EyeImageDataset(root=image_dir_training, data_info_csv_file=csv_file, transform=apply_transforms)
# NOTE(review): train_loader is not consumed by any later cell — the training
# cell passes `ds` straight to the classifier. Remove it or wire it through.
train_loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)


reading input dataset


Building the model

In [3]:
class CustomEyeClassifier(EyeClassifier):
    """Small CNN classifier for ODIR-5K eye images.

    Builds the layer list handed to ``EyeClassifier`` (in order):
    conv(3->6) -> max-pool -> dropout -> conv(6->16) -> max-pool -> dropout
    -> linear(->84) -> linear(84->42) -> linear(42->num_classes).
    Every conv / pool stage uses a 5x5 window, stride 2, padding 0, dilation 1.

    NOTE(review): the conv layers are paired with TransferFunction.NotApplicable;
    if that means "no activation", the convolutions have no nonlinearity after
    them — confirm this is intended.

    Args:
        num_classes: number of outputs of the final linear layer.
        input_size: side length, in pixels, of the square images the network
            receives. The flattened feature count entering the first linear
            layer is derived from it; the default 512 gives 16 * 29 * 29 =
            13456, the value that used to be hard-coded.

    Raises:
        ValueError: if ``input_size`` is too small to survive the four
            downsampling stages (i.e. below 61 pixels).
    """

    def __init__(self, num_classes: int, input_size: int = 512) -> None:
        kernel, stride = 5, 2  # shared by every conv and pool stage
        conv1_channels, conv2_channels = 6, 16

        def window_output(size: int) -> int:
            # Output length of an unpadded, dilation-1 Conv2d / MaxPool2d with
            # ceil_mode=False: floor((size - kernel) / stride) + 1.
            return (size - kernel) // stride + 1

        side = input_size
        for _ in range(4):  # conv -> pool -> conv -> pool
            side = window_output(side)
        if side < 1:
            raise ValueError(
                f"input_size={input_size} is too small for CustomEyeClassifier; "
                f"the feature map collapses after the conv/pool stages")
        flat_features = conv2_channels * side * side

        # `torch.nn` is spelled out because the global name `nn` is rebound to
        # the model instance at the end of this cell; resolving `nn.Conv2d`
        # through that name would break when the cell is re-run.
        super().__init__(model=[

            (torch.nn.Conv2d(in_channels=3, out_channels=conv1_channels,
                             kernel_size=(kernel, kernel), stride=(stride, stride),
                             padding=(0, 0), dilation=(1, 1)),
             TransferFunction.NotApplicable),

            (torch.nn.MaxPool2d(
                kernel_size=(kernel, kernel), stride=(stride, stride),
                padding=(0, 0), dilation=(1, 1)),
             TransferFunction.NotApplicable),

            (torch.nn.Dropout(),
             TransferFunction.NotApplicable),

            (torch.nn.Conv2d(in_channels=conv1_channels, out_channels=conv2_channels,
                             kernel_size=(kernel, kernel), stride=(stride, stride),
                             padding=(0, 0), dilation=(1, 1)),
             TransferFunction.NotApplicable),

            (torch.nn.MaxPool2d(
                kernel_size=(kernel, kernel), stride=(stride, stride),
                padding=(0, 0), dilation=(1, 1)),
             TransferFunction.NotApplicable),

            (torch.nn.Dropout(),
             TransferFunction.NotApplicable),

            (torch.nn.Linear(in_features=flat_features, out_features=84),
             TransferFunction.Relu),

            (torch.nn.Linear(in_features=84, out_features=42),
             TransferFunction.Relu),

            (torch.nn.Linear(in_features=42, out_features=num_classes),
             TransferFunction.NotApplicable),
        ])

model = CustomEyeClassifier(num_classes=len(ds.classes), input_size=input_size)
# The training and testing cells call `nn.train` / `nn.test`, so the instance is
# also bound to `nn`. This shadows the `torch.nn` module for the rest of the
# notebook. TODO: switch those cells to `model` and drop this alias.
nn = model
model

CustomEyeClassifier(
  (layer 1): Conv2d(3, 6, kernel_size=(5, 5), stride=(2, 2))
  (layer 2): MaxPool2d(kernel_size=(5, 5), stride=(2, 2), padding=(0, 0), dilation=(1, 1), ceil_mode=False)
  (layer 3): Dropout(p=0.5, inplace=False)
  (layer 4): Conv2d(6, 16, kernel_size=(5, 5), stride=(2, 2))
  (layer 5): MaxPool2d(kernel_size=(5, 5), stride=(2, 2), padding=(0, 0), dilation=(1, 1), ceil_mode=False)
  (layer 6): Dropout(p=0.5, inplace=False)
  (layer 7): Linear(in_features=13456, out_features=84, bias=True)
  (layer 8): Linear(in_features=84, out_features=42, bias=True)
  (layer 9): Linear(in_features=42, out_features=8, bias=True)
)


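Training the model

The log below comes from a run that was interrupted (`KeyboardInterrupt`) at epoch 24 of 100. The execution counter then jumps from `In [4]` to `In [7]`, which suggests cells were run and later deleted, so the model evaluated in the next section may not correspond to this log.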

In [4]:
# Train the classifier on `ds` (built from image_dir_training), CPU only.
# The epoch count is not passed here; the recorded log shows 100 epochs, and
# that run was interrupted at epoch 24.
# NOTE(review): the printed model repr looks like a torch.nn.Module's. If
# EyeClassifier is one, this `train(dataset, gpu=...)` overrides
# Module.train(mode), which Module.eval() calls — confirm dropout is actually
# switched off when testing.
use_gpu = False
nn.train(ds, gpu=use_gpu)


training (1%) epoch 1/100, loss = 2.6104
training (2%) epoch 2/100, loss = 1.7469
training (3%) epoch 3/100, loss = 1.7995
training (4%) epoch 4/100, loss = 2.1684
training (5%) epoch 5/100, loss = 1.0854
training (6%) epoch 6/100, loss = 1.1470
training (7%) epoch 7/100, loss = 0.8077
training (8%) epoch 8/100, loss = 2.1153
training (9%) epoch 9/100, loss = 1.5670
training (10%) epoch 10/100, loss = 0.8468
training (11%) epoch 11/100, loss = 1.0097
training (12%) epoch 12/100, loss = 3.6972
training (13%) epoch 13/100, loss = 1.8472
training (14%) epoch 14/100, loss = 1.6781
training (15%) epoch 15/100, loss = 1.3144
training (16%) epoch 16/100, loss = 1.1918
training (17%) epoch 17/100, loss = 0.7174
training (18%) epoch 18/100, loss = 1.7985
training (19%) epoch 19/100, loss = 1.5440
training (20%) epoch 20/100, loss = 2.2547
training (21%) epoch 21/100, loss = 1.2940
training (22%) epoch 22/100, loss = 2.7675
training (23%) epoch 23/100, loss = 0.6979
training (24%) epoch 24/100, 

KeyboardInterrupt: 

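Testing the model

Note: the evaluation below runs on `ds`, the dataset built from the *training* directory. The reported accuracy is therefore training-set accuracy, not an estimate of performance on unseen images.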

In [7]:

# NOTE(review): `ds` was built from image_dir_training, so the accuracy below is
# measured on the images the model was trained on — it is not a held-out
# estimate. image_dir_testing is defined in the config cell but never used.
eval_dataset = ds
nn.test(eval_dataset, gpu=False)

testing 1% [1641 files]
accuracy: 66.16% [34742/52512]
	 - Normal: 47.75% [3134/6564]
	 - Diabetes: 30.83% [2024/6564]
	 - Glaucoma: 72.81% [4779/6564]
	 - Cataract: 93.22% [6119/6564]
	 - AgeRelatedMacularDegeneration: 89.21% [5856/6564]
	 - Hypertension: 79.83% [5240/6564]
	 - PathologicalMyopia: 96.74% [6350/6564]
	 - Other: 18.89% [1240/6564]
