# SUR image model training

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets
from torchvision.transforms import v2
import torchvision.transforms.functional
import numpy as np

Convolutional neural network for binary classification

In [None]:
class Eye(nn.Module):
    """Small convolutional network producing one sigmoid probability per image.

    The network is a stem convolution followed by three double-convolution
    stages; every stage ends with the shared 2x2 max-pool, so the spatial
    size is halved four times in total.  The first linear layer expects a
    5x5x64 feature map, i.e. inputs whose height and width reduce to 5
    after those four poolings (e.g. 3x80x80 images).
    """

    def __init__(self):
        super().__init__()
        # Stem: 3 -> 16 channels, then the pooling layer shared by all stages.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        # Stage 2: 16 -> 32 channels.
        self.conv2a = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.conv2b = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.batchNorm2 = nn.BatchNorm2d(32)

        # Stage 3: 32 -> 64 channels.
        self.conv3a = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv3b = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.batchNorm3 = nn.BatchNorm2d(64)

        # Stage 4: stays at 64 channels.
        self.conv4a = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv4b = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.batchNorm4 = nn.BatchNorm2d(64)

        # Classifier head on the flattened 5x5x64 feature map.
        self.lin1 = nn.Linear(5 * 5 * 64, 64)
        self.lin2 = nn.Linear(64, 1)

    def _double_conv_stage(self, features, first_conv, second_conv, norm):
        """Apply conv -> ReLU -> conv -> ReLU -> batch norm -> 2x2 max-pool."""
        features = F.relu(first_conv(features))
        features = F.relu(second_conv(features))
        return self.pool1(norm(features))

    def forward(self, x):
        """Return a (batch, 1) tensor of sigmoid outputs for image batch ``x``."""
        # Stem; with an 80x80 input the map is 40x40 afterwards.
        x = self.pool1(F.relu(self.conv1(x)))
        # 40x40 -> 20x20
        x = self._double_conv_stage(x, self.conv2a, self.conv2b, self.batchNorm2)
        # 20x20 -> 10x10
        x = self._double_conv_stage(x, self.conv3a, self.conv3b, self.batchNorm3)
        # 10x10 -> 5x5 (64 channels)
        x = self._double_conv_stage(x, self.conv4a, self.conv4b, self.batchNorm4)

        # Keep the batch dimension, flatten everything else for the head.
        x = torch.flatten(x, start_dim=1)
        hidden = F.relu(self.lin1(x))
        return torch.sigmoid(self.lin2(hidden))

Helper transform for data loading.

In [None]:
# Data: conversion applied to every image as it is loaded.
# ToImage wraps the loaded picture as an image tensor; ToDtype then casts it
# to float32 with value scaling enabled (scale=True).
_to_image = v2.ToImage()
_to_float32 = v2.ToDtype(torch.float32, scale=True)
to_torch = v2.Compose([_to_image, _to_float32])

Augmentations used during training.

In [None]:
class AddGaussianNoise:
    """Transform that adds Gaussian noise to a tensor and clamps it to [0, 1].

    Args:
        mean: Offset added to every element together with the noise.
        std: Factor the standard-normal noise is multiplied by.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """Return ``tensor`` plus scaled noise, clamped to the range [0, 1]."""
        # Draw the noise on the same device as the input so the transform
        # also works for tensors that do not live on the default device.
        noise = torch.randn(tensor.size(), dtype=torch.float, device=tensor.device)
        return torch.clamp(tensor + noise * self.std + self.mean, 0, 1)

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"

In [None]:
# Augmentation pipeline, applied to the training inputs only.
transforms = v2.Compose([
    v2.RandomHorizontalFlip(),
    # Up to +/-15 degrees rotation, 5% translation per axis, 0.9-1.1 scaling.
    v2.RandomAffine(degrees=15, translate=[0.05, 0.05], scale=[0.9, 1.1]),
    v2.GaussianBlur(kernel_size=3),
    # Additive Gaussian noise (std 0.15) is applied with probability 0.3.
    v2.RandomApply([AddGaussianNoise(0, 0.15)], p=0.3),
    v2.RandomPhotometricDistort(),
])

Data loading.

In [None]:
# Datasets: one sub-directory per class, images converted by `to_torch`.
target_train = torchvision.datasets.ImageFolder("data/training", transform=to_torch)
target_validation = torchvision.datasets.ImageFolder("data/validation", transform=to_torch)

# ImageFolder already keeps the class index of every sample in `targets`, so
# the labels are read from there instead of iterating the dataset (which
# would load and transform every image just to obtain its label).
train_labels = np.asarray(target_train.targets)
inverse_counts = 1 / np.bincount(train_labels)
# Per-sample weight = inverse frequency of the sample's class.
class_weights = inverse_counts[train_labels].tolist()

# Target class oversampling: draw len(target_train) samples with replacement,
# weighted so that each class is drawn with equal total probability.
sampler = torch.utils.data.WeightedRandomSampler(class_weights, len(target_train), replacement=True)

loader = torch.utils.data.DataLoader(target_train, batch_size=30, sampler=sampler)
# The whole validation set is evaluated as a single batch.
validation_loader = torch.utils.data.DataLoader(target_validation, batch_size=len(target_validation))

Augmentation example.

In [None]:
# Preview: push the first training image through the augmentation pipeline
# and show the result as a PIL image.
x = target_train[0][0]
augmented_example = transforms(x)
example_image = torchvision.transforms.functional.to_pil_image(augmented_example)
display(example_image)

### Training

In [None]:
# Binary cross-entropy on the model's sigmoid outputs.
lfunc = nn.BCELoss()
model = Eye()
# AdamW with decoupled weight decay over all model parameters.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=1e-2,
)

In [None]:
n_epochs = 100

for i in range(n_epochs):
    # Training mode: BatchNorm layers normalise with batch statistics and
    # update their running estimates.
    model.train()
    for inputs, labels in loader:
        # NOTE(review): the pipeline receives the whole batch at once, so
        # each random transform presumably draws one set of parameters for
        # all images in the batch - confirm this is intended.
        inputs = transforms(inputs)
        optimizer.zero_grad()
        result = model(inputs)
        # The model outputs shape (batch, 1); BCELoss needs float targets
        # of the same shape.
        loss: torch.Tensor = lfunc(result, labels.reshape(-1, 1).float())
        loss.backward()
        optimizer.step()

    # Evaluation mode: BatchNorm layers use their running statistics, so the
    # validation data neither influences the normalisation nor alters the
    # model state.
    model.eval()
    with torch.no_grad():
        for inputs, labels in validation_loader:
            result = model(inputs)
            loss: torch.Tensor = lfunc(result, labels.reshape(-1, 1).float())
            # Epoch index and validation loss.
            print(i, loss.item())


Model export.

In [None]:
# Export only the model's state dict (not the module object itself).
trained_state = model.state_dict()
torch.save(trained_state, "image_model.pth")