In [1]:
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torchvision import datasets, models, transforms
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import torchvision

In [2]:
input_path = "C:/Users/AS-GP/Desktop/Resnet50/lisat_gaze_data_v1/"

In [3]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

data_transforms = {
    'train':
    transforms.Compose([
        transforms.Resize((224,224)),
        transforms.RandomAffine(0, shear=10, scale=(0.8,1.2)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]),
    'validation':
    transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(), #for parallel computations
        normalize
    ]),
}

#train_subset_size = int(0.4 * len(datasets.ImageFolder(input_path + 'lisat_gaze_data_v1/train', transform=data_transforms['train']))) #testing hashelha b3den
#val_subset_size = int(0.4 * len(datasets.ImageFolder(input_path + 'lisat_gaze_data_v1/val', transform=data_transforms['validation'])))
image_datasets = {
    'train':
    datasets.ImageFolder(input_path + 'lisat_gaze_data_v1/train', data_transforms['train']),
    'validation':
    datasets.ImageFolder(input_path + 'lisat_gaze_data_v1/val', data_transforms['validation'])
}

dataloaders = {
    'train':
    torch.utils.data.DataLoader(image_datasets['train'],
                                batch_size=32,
                                shuffle=True,
                                num_workers=2),  # for now 2
    'validation':
    torch.utils.data.DataLoader(image_datasets['validation'],
                                batch_size=32,
                                shuffle=False,
                                num_workers=2)  # for Kaggle
}

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [5]:
model =torchvision.models.mobilenet_v2(pretrained=True).to(device)
model



MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

In [6]:
model2 =torchvision.models.mobilenet_v2(pretrained=True).to(device)

for param in model.parameters():
    param.requires_grad = False

model2.classifier = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=1280, out_features=8, bias=True),
            nn.Softmax(dim=1))

model2

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model2.classifier.parameters())

In [15]:
class EarlyStopper:
    def __init__(self, patience=1, min_delta=3):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None

    def early_stop(self, current_loss):
        if self.best_loss is None:
            self.best_loss = current_loss
            print('best_loss is None')
        elif self.best_loss - current_loss > self.min_delta:
            self.best_loss = current_loss
            print('self.best_loss - current_loss')
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [16]:
weights_path='model_weights.h5'

In [19]:
early_stopper = EarlyStopper(3,3)
def train_model(model2, criterion, optimizer, num_epochs=5):
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'validation']:
            if phase == 'train':
                model2.train().to(device) #added to device
            else:
                model2.eval().to(device)

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model2(inputs)
                loss = criterion(outputs, labels)

                if phase == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                _, preds = torch.max(outputs, 1)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(image_datasets[phase])
            epoch_acc = running_corrects.double() / len(image_datasets[phase])

            if phase == 'validation':
                print('hey i am in val now')
                if early_stopper.early_stop(epoch_loss):
                    print("Early stopping triggered")
                    return model2
                print(early_stopper.best_loss)

                if epoch_loss < early_stopper.best_loss:
                    early_stopper.best_loss = epoch_loss
                    torch.save(model2.state_dict(), weights_path)
                    print("Model saved as it achieved the best validation loss so far.")

            print('{} loss: {:.4f}, acc: {:.4f}'.format(phase,
                                                        epoch_loss,
                                                        epoch_acc))

    return model2

In [20]:
model_trained = train_model(model2, criterion, optimizer, num_epochs=5) #1 epoch for testing gpu

Epoch 1/5
----------
train loss: 1.5874, acc: 0.6987
hey i am in val now
best_loss is None
1.6775708421045001
validation loss: 1.6776, acc: 0.5975
Epoch 2/5
----------
train loss: 1.5892, acc: 0.6934
hey i am in val now
1.6775708421045001
Model saved as it achieved the best validation loss so far.
validation loss: 1.6742, acc: 0.5983
Epoch 3/5
----------
train loss: 1.5836, acc: 0.7004
hey i am in val now
1.674218123357655
Model saved as it achieved the best validation loss so far.
validation loss: 1.6613, acc: 0.6130
Epoch 4/5
----------
train loss: 1.5772, acc: 0.7078
hey i am in val now
Early stopping triggered


In [11]:
weights_path = 'C:/Users/AS-GP/Desktop/MobileNet/model_weights.h5'

# model weights
torch.save(model_trained.state_dict(), weights_path)

In [None]:
#Later to restore:
model.load_state_dict(torch.load(weights_path))
model.eval()