In [1]:
import sys
sys.path.insert(0, '..')

import torch.nn as nn
from torch.utils.data import DataLoader
from utils.eye_dataset import *
from eye_classifier import *
import torchvision.transforms as transforms
import torchvision.models as models

In [2]:
# Dataset locations (relative to this notebook's folder).
base_dir = "../../data"

image_dir_training = f"{base_dir}/ODIR-5K/training"
# Alternative training source: f"{base_dir}/preprocessed_images"
image_dir_testing = f"{base_dir}/ODIR-5K/testing"
csv_file = f'{base_dir}/ODIR-5K/data.csv'

print('reading input dataset')

# Side length (pixels) of the square image fed to the network.
input_size = 224

# Per-channel normalization statistics: the standard ImageNet mean/std used by
# torchvision's pretrained models.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Deterministic preprocessing: resize the shorter side to input_size, crop the
# central input_size x input_size square, convert to a tensor, then normalize.
apply_transforms = transforms.Compose(
    [
        transforms.Resize(size=input_size),
        transforms.CenterCrop(size=input_size),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
)

ds = EyeImageDataset(
    root=image_dir_training,
    data_info_csv_file=csv_file,
    transform=apply_transforms,
)


reading input dataset


Building the model

In [3]:

class ResnetEyeClassifier(EyeClassifier):
    """Eye-image classifier: a pretrained ResNet-18 followed by a small linear head.

    The network is described as a list of ``(layer, TransferFunction)`` pairs
    handed to ``EyeClassifier``; the transfer function presumably selects the
    activation applied after each layer (``NotApplicable`` = none).
    """

    def __init__(self, num_classes: int) -> None:
        """Build the layer stack.

        Args:
            num_classes: number of outputs of the final linear layer.
        """
        # Import torch.nn locally: this cell rebinds the global name ``nn`` to
        # the classifier instance below, so resolving ``nn.Linear`` through the
        # global would fail whenever the cell is re-run.
        from torch import nn as torch_nn

        super().__init__(model=[
            # torchvision's resnet18 keeps its 1000-way ImageNet fc layer, hence
            # in_features=1000 for the first head layer.
            # NOTE(review): ``pretrained=True`` is deprecated in torchvision >= 0.13
            # in favour of ``weights=models.ResNet18_Weights.DEFAULT``.
            (models.resnet18(pretrained=True), TransferFunction.NotApplicable),

            (torch_nn.Linear(in_features=1000, out_features=84),
             TransferFunction.Relu),

            (torch_nn.Linear(in_features=84, out_features=42),
             TransferFunction.Relu),

            # One raw score per class.
            (torch_nn.Linear(in_features=42, out_features=num_classes),
             TransferFunction.NotApplicable),
        ])


# The instance is bound to ``nn`` because the following cells call
# ``nn.train_model`` / ``nn.test_model`` / ``nn.save_weights``.
# NOTE(review): this shadows ``torch.nn`` for the rest of the notebook; a name
# such as ``classifier`` would be clearer if the later cells are renamed too.
nn = ResnetEyeClassifier(num_classes=len(ds.classes))
print(nn)

ResnetEyeClassifier(
  (layer 1): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tru

Training the model

In [4]:
# Train the classifier on ``ds`` (the training dataset built in cell 2);
# train_model prints progress and the loss for each epoch (see output below).
# shuffle=True presumably randomizes the order in which samples are visited.
# TODO(review): no random seed is set in this notebook, so losses differ between runs.
# Optional, currently disabled: #ds.set_buffer_size(1024)
nn.train_model(ds, shuffle=True)


training (1%) epoch 1/100, loss = 2.0092
training (2%) epoch 2/100, loss = 1.2397
training (3%) epoch 3/100, loss = 0.8970
training (4%) epoch 4/100, loss = 1.9033
training (5%) epoch 5/100, loss = 0.6805
training (6%) epoch 6/100, loss = 0.7784
training (7%) epoch 7/100, loss = 0.7446
training (8%) epoch 8/100, loss = 0.8137
training (9%) epoch 9/100, loss = 1.1901
training (10%) epoch 10/100, loss = 0.7632
training (11%) epoch 11/100, loss = 0.4135
training (12%) epoch 12/100, loss = 0.5246
training (13%) epoch 13/100, loss = 0.7102
training (14%) epoch 14/100, loss = 0.2448
training (15%) epoch 15/100, loss = 0.8478
training (16%) epoch 16/100, loss = 1.0685
training (17%) epoch 17/100, loss = 0.0406
training (18%) epoch 18/100, loss = 0.3717
training (19%) epoch 19/100, loss = 0.0949
training (20%) epoch 20/100, loss = 0.0216
training (21%) epoch 21/100, loss = 0.4898
training (22%) epoch 22/100, loss = 0.0799
training (23%) epoch 23/100, loss = 0.0597
training (24%) epoch 24/100, 

Testing the model with the training set

In [5]:
# Evaluate on the same dataset used for training: this measures how well the
# model fits the training data, not how well it generalizes.
nn.test_model(ds)

testing 1% [4 / 6564 files]
testing 2% [72 / 6564 files]
testing 3% [140 / 6564 files]
testing 4% [208 / 6564 files]
testing 5% [276 / 6564 files]
testing 6% [344 / 6564 files]
testing 7% [412 / 6564 files]
testing 8% [480 / 6564 files]
testing 9% [548 / 6564 files]
testing 10% [616 / 6564 files]
testing 11% [684 / 6564 files]
testing 12% [752 / 6564 files]
testing 13% [820 / 6564 files]
testing 14% [888 / 6564 files]
testing 15% [956 / 6564 files]
testing 16% [1024 / 6564 files]
testing 17% [1092 / 6564 files]
testing 18% [1160 / 6564 files]
testing 19% [1228 / 6564 files]
testing 20% [1296 / 6564 files]
testing 21% [1364 / 6564 files]
testing 22% [1432 / 6564 files]
testing 23% [1500 / 6564 files]
testing 24% [1568 / 6564 files]
testing 25% [1636 / 6564 files]
testing 26% [1704 / 6564 files]
testing 27% [1772 / 6564 files]
testing 28% [1840 / 6564 files]
testing 29% [1908 / 6564 files]
testing 30% [1976 / 6564 files]
testing 31% [2044 / 6564 files]
testing 32% [2112 / 6564 files]
tes

Saving the weights and testing the network with the test set

In [6]:

# Persist the trained network to disk.
# Single layer: model-list entry 0, presumably the resnet18 backbone (as the
# file name suggests).
backbone_weights_path = "eye_classification_net_full_resnet18.w"
# Presumably all layers (backbone + linear head).
full_weights_path = "eye_classification_net_full.w"

nn.save_layer_weights(0, backbone_weights_path)
nn.save_weights(full_weights_path)


In [7]:
# Build the held-out test dataset.
# Reuse the preprocessing pipeline defined in cell 2 instead of re-declaring an
# identical copy, so training and test preprocessing cannot drift apart. That
# pipeline (Resize -> CenterCrop -> ToTensor -> Normalize) is deterministic and
# therefore suitable for evaluation; if random augmentation is ever added to it
# for training, define a separate deterministic pipeline here instead.
test_transforms = apply_transforms

# NOTE(review): this rebinds ``ds`` from the training set to the test set, so
# re-running the training/testing cells above after this one silently uses
# test data. ``ds`` is kept because the next cell relies on it.
ds = EyeImageDataset(root=image_dir_testing, data_info_csv_file=csv_file, transform=test_transforms)


In [8]:
# Evaluate on the test dataset built in the previous cell, with the dataset's
# buffer size set to 16 first (buffer semantics are defined by EyeImageDataset).
# NOTE(review): the recorded output reports 6564 files, the same count as the
# training-set evaluation above — confirm the test dataset is really distinct.
test_buffer_size = 16
ds.set_buffer_size(test_buffer_size)
nn.test_model(ds)

testing 1% [4 / 6564 files]
testing 2% [72 / 6564 files]
testing 3% [140 / 6564 files]
testing 4% [208 / 6564 files]
testing 5% [276 / 6564 files]
testing 6% [344 / 6564 files]
testing 7% [412 / 6564 files]
testing 8% [480 / 6564 files]
testing 9% [548 / 6564 files]
testing 10% [616 / 6564 files]
testing 11% [684 / 6564 files]
testing 12% [752 / 6564 files]
testing 13% [820 / 6564 files]
testing 14% [888 / 6564 files]
testing 15% [956 / 6564 files]
testing 16% [1024 / 6564 files]
testing 17% [1092 / 6564 files]
testing 18% [1160 / 6564 files]
testing 19% [1228 / 6564 files]
testing 20% [1296 / 6564 files]
testing 21% [1364 / 6564 files]
testing 22% [1432 / 6564 files]
testing 23% [1500 / 6564 files]
testing 24% [1568 / 6564 files]
testing 25% [1636 / 6564 files]
testing 26% [1704 / 6564 files]
testing 27% [1772 / 6564 files]
testing 28% [1840 / 6564 files]
testing 29% [1908 / 6564 files]
testing 30% [1976 / 6564 files]
testing 31% [2044 / 6564 files]
testing 32% [2112 / 6564 files]
tes