In [1]:
import sys
sys.path.insert(0, '..')

import torch.nn as nn
from torch.utils.data import DataLoader
from utils.eye_dataset import *
from eye_classifier import *
import torchvision.transforms as transforms
import torchvision.models as models

In [2]:
# Dataset locations, all resolved relative to the shared data directory.
base_dir = "../../data"
image_dir = f"{base_dir}/preprocessed_images"

image_dir_training = f"{base_dir}/ODIR-5K/training"
image_dir_testing = f"{base_dir}/ODIR-5K/testing"
csv_file = f'{base_dir}/ODIR-5K/data.csv'

# Side length, in pixels, of the square image handed to the network.
input_size = 224
print('reading input dataset')

# Preprocessing applied to every image, in order:
#   1. resize so the shorter edge equals `input_size`,
#   2. crop the central `input_size` x `input_size` square,
#   3. convert to a tensor,
#   4. normalise each channel with the mean/std values conventionally used
#      for torchvision's ImageNet-trained models.
preprocessing_steps = [
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
apply_transforms = transforms.Compose(preprocessing_steps)

# Training split: images under the training directory, described by the CSV.
ds = EyeImageDataset(
    root=image_dir_training,
    data_info_csv_file=csv_file,
    transform=apply_transforms,
)


reading input dataset


Building the model

In [3]:
class ResnetEyeClassifier(EyeClassifier):
    """Eye-image classifier: a ResNet-34 followed by a small fully connected head.

    The layer list handed to ``EyeClassifier`` is, in order:

    * ``torchvision.models.resnet34`` created without pretrained weights,
    * ``Linear(1000 -> 256)`` paired with ``TransferFunction.LeakyRelu``,
    * ``Linear(256 -> 64)`` paired with ``TransferFunction.LeakyRelu``,
    * ``Linear(64 -> 16)`` paired with ``TransferFunction.Relu``,
    * ``Linear(16 -> num_classes)`` paired with ``TransferFunction.NotApplicable``.

    How each ``(module, TransferFunction)`` pair is applied is defined by
    ``EyeClassifier`` (presumably module first, then the activation --
    confirm against ``eye_classifier``).
    """

    def __init__(self, num_classes: int) -> None:
        """Build the layer stack.

        Args:
            num_classes: Number of output units of the final linear layer.
        """
        # Import torch.nn under a private local alias rather than relying on
        # the notebook-global ``nn``: that global is later rebound to the
        # classifier instance itself, after which ``nn.Linear`` would no
        # longer resolve and building a second classifier (e.g. re-running
        # this cell) would fail.
        import torch.nn as torch_nn

        super().__init__(model=[

            # Backbone; its weights are expected to be loaded separately.
            (models.resnet34(pretrained=False), TransferFunction.NotApplicable),

            (torch_nn.Linear(in_features=1000, out_features=256),
             TransferFunction.LeakyRelu),

            (torch_nn.Linear(in_features=256, out_features=64),
             TransferFunction.LeakyRelu),

            (torch_nn.Linear(in_features=64, out_features=16),
             TransferFunction.Relu),

            # Output layer: one unit per class, no transfer function.
            (torch_nn.Linear(in_features=16, out_features=num_classes),
             TransferFunction.NotApplicable),
        ])


# One output unit per class known to the dataset.
class_count = len(ds.classes)

# NOTE: this rebinds `nn`, shadowing the `torch.nn` alias imported in the
# first cell; every later cell relies on `nn` being the classifier instance.
nn = ResnetEyeClassifier(num_classes=class_count)
print(nn)


ResnetEyeClassifier(
  (layer 1): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tru

Training the model

In [5]:
# Index 0 is the first entry of the classifier's layer list, i.e. the ResNet
# backbone (assuming layer indices follow list order -- confirm in
# eye_classifier). Load its saved weights, then freeze it.
backbone_layer = 0
nn.load_layer_weights(backbone_layer, "resnet34.dat")
nn.freeze_layer(backbone_layer)

# Alternative kept for reference: restore a full checkpoint instead of only
# the backbone weights.
#nn.load_weights('eye_resnet18.dat')
#nn.freeze_layer(0)

# Train on the training dataset built earlier.
training_options = {"batch_size": 16, "num_epochs": 30}
nn.train_model(ds, **training_options)


training (1%) epoch 1/30, loss = 2.1823
training (2%) epoch 1/30, loss = 1.4892
training (3%) epoch 1/30, loss = 1.4491
training (4%) epoch 1/30, loss = 1.3497
training (5%) epoch 2/30, loss = 1.2107
training (6%) epoch 2/30, loss = 1.6123
training (7%) epoch 2/30, loss = 0.9128
training (8%) epoch 3/30, loss = 1.2033
training (9%) epoch 3/30, loss = 1.3054
training (10%) epoch 3/30, loss = 1.4078
training (11%) epoch 4/30, loss = 1.5028
training (12%) epoch 4/30, loss = 1.2998
training (13%) epoch 4/30, loss = 0.9520
training (14%) epoch 4/30, loss = 1.4890
training (15%) epoch 5/30, loss = 1.4733
training (16%) epoch 5/30, loss = 1.3858
training (17%) epoch 5/30, loss = 1.5783
training (18%) epoch 6/30, loss = 0.9785
training (19%) epoch 6/30, loss = 0.8635
training (20%) epoch 6/30, loss = 0.9822
training (21%) epoch 7/30, loss = 0.8750
training (22%) epoch 7/30, loss = 0.7725
training (23%) epoch 7/30, loss = 1.1118
training (24%) epoch 7/30, loss = 1.2641
training (25%) epoch 8/30

KeyboardInterrupt: 

Testing the model with the training set

In [6]:

# Evaluate on `ds`, which at this point is still the training dataset built in
# the data-loading cell, i.e. data the model has already seen during training.
nn.test_model(ds)

testing 1% [4 / 6564 files]
testing 2% [72 / 6564 files]
testing 3% [140 / 6564 files]
testing 4% [208 / 6564 files]
testing 5% [276 / 6564 files]
testing 6% [344 / 6564 files]


KeyboardInterrupt: 

Testing the model with the test set

In [None]:
# Optional: persist the classifier's state dict to disk (disabled).
# NOTE(review): the import cell only binds `torch.nn` (as `nn`), not `torch`
# itself, so enabling this line may need `import torch` unless one of the
# star imports provides it. The file name says "resnet18" although the
# backbone built above is a ResNet-34 -- confirm the intended name.
#torch.save(nn.state_dict(), "eye_resnet18.dat")

In [None]:
# Side length, in pixels, of the square image handed to the network.
input_size = 224

# Same preprocessing as for the training images: resize the shorter edge to
# `input_size`, take the central square crop, convert to a tensor, and
# normalise each channel with the given mean/std.
preprocessing_steps = [
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
apply_transforms = transforms.Compose(preprocessing_steps)

# NOTE: `ds` is rebound here from the training dataset to the images under
# the testing directory; cells run after this one see the testing data.
ds = EyeImageDataset(
    root=image_dir_testing,
    data_info_csv_file=csv_file,
    transform=apply_transforms,
)


In [None]:
# `ds` was rebound in the previous cell to the dataset rooted at the testing
# directory, so this evaluates the classifier on those images.
nn.test_model(ds)