In [1]:
import sys
sys.path.insert(0, '..')

import torch.nn as nn
from torch.utils.data import DataLoader
from utils.eye_dataset import *
from eye_classifier import *
import torchvision.transforms as transforms
import torchvision.models as models

In [2]:
# Filesystem layout of the data, relative to the notebook's working directory.
base_dir = "../../data"
image_dir = base_dir + "/preprocessed_images"

# Everything belonging to the ODIR-5K dataset lives under one sub-directory.
odir_dir = base_dir + "/ODIR-5K"
image_dir_training = odir_dir + "/training"
image_dir_testing = odir_dir + "/testing"
csv_file = odir_dir + "/data.csv"

print('reading input dataset')

# Side length (in pixels) used for both the resize and the centre crop.
input_size = 224

# Per-channel mean / standard deviation handed to Normalize.
# NOTE(review): these look like the usual ImageNet statistics -- confirm they
# are the intended values for this dataset.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Preprocessing applied to every image: resize, centre-crop, convert to a
# tensor, then normalise each channel.
preprocessing_steps = [
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
apply_transforms = transforms.Compose(preprocessing_steps)

# Training dataset: images from the training directory, described by the CSV.
ds = EyeImageDataset(
    root=image_dir_training,
    data_info_csv_file=csv_file,
    transform=apply_transforms,
)


reading input dataset


Building the model

In [3]:
class ResnetEyeClassifier(EyeClassifier):
    """Eye-image classifier: a ResNet-18 backbone followed by a small dense head.

    The layer stack is handed to ``EyeClassifier`` as a list of
    ``(module, TransferFunction)`` pairs. How the base class wires the pairs
    together is not visible in this notebook -- presumably each module is
    applied in order and followed by its transfer function (TODO confirm
    against ``EyeClassifier``).
    """

    def __init__(self, num_classes: int) -> None:
        """Build the layer stack and pass it to the base class.

        Args:
            num_classes: Output width of the final linear layer, i.e. the
                number of scores produced per image.
        """
        # Import torch.nn under a private local alias instead of relying on
        # the notebook-global name `nn`. This notebook rebinds `nn` to the
        # classifier instance right after the class definition; from then on
        # `nn.Linear` no longer resolves to torch.nn, so constructing another
        # classifier (e.g. by re-running this cell) would break.
        import torch.nn as torch_nn

        # Zero-argument super() also keeps working when the cell is re-run
        # and the class object is re-created under the same name.
        super().__init__(model=[

            # Backbone without pretrained weights. torchvision's resnet18
            # ends in a 1000-way output layer by default, which is why the
            # first dense layer below takes 1000 input features.
            (models.resnet18(pretrained=False), TransferFunction.NotApplicable),

            # Dense head: 1000 -> 256 -> 64 -> 16 -> num_classes.
            (torch_nn.Linear(in_features=1000, out_features=256),
             TransferFunction.LeakyRelu),

            (torch_nn.Linear(in_features=256, out_features=64),
             TransferFunction.LeakyRelu),

            (torch_nn.Linear(in_features=64, out_features=16),
             TransferFunction.Relu),

            # Final layer is paired with NotApplicable, i.e. no transfer
            # function is requested for the output scores.
            (torch_nn.Linear(in_features=16, out_features=num_classes),
             TransferFunction.NotApplicable),
        ])


# One output per class known to the training dataset.
num_eye_classes = len(ds.classes)

# NOTE: this rebinds the notebook-global name `nn` (imported at the top as
# torch.nn) to the classifier instance. The later cells rely on `nn` being the
# classifier, so the name is kept; any code evaluated after this point that
# expects `nn` to be torch.nn will see the classifier instead.
nn = ResnetEyeClassifier(num_classes=num_eye_classes)
print(nn)


ResnetEyeClassifier(
  (layer 1): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tru

Training the model

In [4]:
# Optional warm start, disabled by default: load previously saved weights and
# call freeze_layer(0) before training (presumably this freezes the first
# entry of the layer stack, i.e. the ResNet backbone -- TODO confirm).
#nn.load_weights('eye_resnet18.dat')
#nn.freeze_layer(0)
# `nn` is the ResnetEyeClassifier instance here, not torch.nn.
# NOTE(review): the saved output below this cell reports "epoch N/100", which
# does not match num_epochs=10 -- the output is probably left over from an
# earlier run with different settings; re-run the cell to confirm.
nn.train(ds, batch_size=16, num_epochs=10)


training (1%) epoch 1/100, loss = 2.0973
training (2%) epoch 2/100, loss = 1.2665
training (3%) epoch 3/100, loss = 1.0868
training (4%) epoch 4/100, loss = 1.1189
training (5%) epoch 5/100, loss = 0.7957
training (6%) epoch 6/100, loss = 0.7116
training (7%) epoch 7/100, loss = 0.4425
training (8%) epoch 8/100, loss = 0.5986
training (9%) epoch 9/100, loss = 0.3236
training (10%) epoch 10/100, loss = 0.6395
training (11%) epoch 11/100, loss = 0.1378
training (12%) epoch 12/100, loss = 0.2401
training (13%) epoch 13/100, loss = 0.2263
training (14%) epoch 14/100, loss = 0.0229
training (15%) epoch 15/100, loss = 0.3152
training (16%) epoch 16/100, loss = 0.4140
training (17%) epoch 17/100, loss = 0.1724
training (18%) epoch 18/100, loss = 0.1346
training (19%) epoch 19/100, loss = 0.3900
training (20%) epoch 20/100, loss = 0.1868
training (21%) epoch 21/100, loss = 0.1086
training (22%) epoch 22/100, loss = 0.1772
training (23%) epoch 23/100, loss = 0.2375
training (24%) epoch 24/100, 

Testing the model with the training set

In [None]:

# `ds` is still the training dataset at this point, so this evaluates the
# classifier on the same images it was trained on.
nn.test(ds)

Testing the model with the test set

In [None]:
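# Optional: uncomment to save the trained weights under the file name used by
# the commented-out load_weights call in the training cell. `torch` itself is
# not imported by name in this notebook, so add `import torch` first unless
# one of the star imports already provides it.
#torch.save(nn.state_dict(), "eye_resnet18.dat")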

In [None]:
# Same preprocessing as for the training images, rebuilt here for the test set.
input_size = 224

# Per-channel mean / standard deviation handed to Normalize.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

preprocessing_steps = [
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
apply_transforms = transforms.Compose(preprocessing_steps)

# NOTE: `ds` is rebound from the training dataset to the test dataset here.
# Re-running an earlier cell that uses `ds` after this one will therefore
# operate on the test images.
ds = EyeImageDataset(
    root=image_dir_testing,
    data_info_csv_file=csv_file,
    transform=apply_transforms,
)


In [None]:
# `ds` now refers to the dataset built from image_dir_testing in the previous
# cell, so this evaluates the classifier on the test images.
nn.test(ds)