In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import random_split
from torch.utils.data import DataLoader

import importlib
from models.default_cnn import default_cnn
from training.trainer_base import Trainer
from data.transforms import ResizeTransform


torch.manual_seed(42)


in_size = 128
out_size = 2

In [6]:
data_type = 'heatmap'
model_type = 'def_cnn'

transform = ResizeTransform(in_size, in_size)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = default_cnn(in_size, out_size)
print('using device: ' + device.__str__())
# Create a dataset from image folder
dataset = ImageFolder(root='/home/gatemrou/uds/Eye_Tracking/data/cropped/{}s_sorted'.format(data_type), 
                      transform=transform)

# Define the proportions for train, val, and test sets
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15

# Calculate the sizes of each subset based on the ratios
train_size = int(train_ratio * len(dataset))
val_size = int(val_ratio * len(dataset))
test_size = len(dataset) - train_size - val_size

# Randomly split the dataset into train, val, and test subsets
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])


# Create data loaders for each subset
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

learning_rate = 0.02

# Define the loss function and optimizer
loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Define trainer 
save_path = '/home/gatemrou/uds/Eye_Tracking/saved_models/{}s_{}'.format(data_type, model_type)

trainer = Trainer(model, train_loader, val_loader, save_path, loss, optimizer)

using device: cuda


In [7]:
# training
total_epochs = 100
for _ in range(total_epochs):
    trainer.run_epoch()
trainer.run_epoch()

05.03 20:32:46 --- Epoch [1], Loss: 0.5684
New best validation accuracy: 68.6275
05.03 20:32:47 --- Epoch [2], Loss: 0.5102
05.03 20:32:48 --- Epoch [3], Loss: 0.5040
05.03 20:32:49 --- Epoch [4], Loss: 0.4972
05.03 20:32:50 --- Epoch [5], Loss: 0.4671


KeyboardInterrupt: 

In [3]:
trainer.run_epoch()

05.03 20:32:09 --- Epoch [1], Loss: 0.5684
New best validation accuracy: 68.6275


In [8]:
trainer.print_best_validation_acc()

Best validation accuracy: 68.62745098039215.4f
