# Imports

In [1]:
import os
import time

from PIL import Image

import torch
from torch.utils.data import random_split, Subset, DataLoader
from torch.optim import Adam

import torch.nn as nn
from torch.nn import CrossEntropyLoss

from torchvision import datasets, transforms

import timm

# Function Definitions

In [1]:
def split_n_save_indices(data_path, save_path, transform, splits=(.70, .15), seed=None):
    """
    Split an ImageFolder dataset into train/valid/test subsets and save the
    index lists of the three subsets so the same split can be rebuilt later.

    Parameters
    ----------
    data_path : root directory passed to ``datasets.ImageFolder``.
    save_path : destination passed to ``torch.save``.
    transform : transform passed to ``datasets.ImageFolder``.
    splits    : two fractions ``(train, valid)``; whatever remains of the
                dataset becomes the test subset.
    seed      : optional int. When given, the split is drawn from a dedicated
                ``torch.Generator`` seeded with this value, so the same seed
                yields the same split. When ``None`` the global torch RNG is
                used, exactly as before.

    Raises
    ------
    ValueError : if ``splits`` does not hold exactly two non-negative
                 fractions whose sum is at most 1.
    """
    # Validate before touching the disk: fractions summing above 1 would give
    # a negative test size and a silently wrong split.
    if len(splits) != 2:
        raise ValueError(f'splits must contain exactly two fractions (train, valid), got {len(splits)}')
    train_frac, valid_frac = splits
    if train_frac < 0 or valid_frac < 0 or train_frac + valid_frac > 1 + 1e-9:
        raise ValueError(
            f'splits must be non-negative fractions summing to at most 1, got {tuple(splits)}'
        )

    dataset = datasets.ImageFolder(root=data_path, transform=transform)
    total = len(dataset)

    train_size = int(train_frac * total)
    valid_size = int(valid_frac * total)
    test_size = total - train_size - valid_size
    if test_size < 0:
        raise ValueError(f'splits {tuple(splits)} leave a negative test size for {total} samples')

    lengths = [train_size, valid_size, test_size]
    if seed is None:
        train_dataset, valid_dataset, test_dataset = random_split(dataset, lengths)
    else:
        generator = torch.Generator().manual_seed(seed)
        train_dataset, valid_dataset, test_dataset = random_split(dataset, lengths, generator=generator)

    torch.save({
        'train_indices': train_dataset.indices,
        'valid_indices': valid_dataset.indices,
        'test_indices': test_dataset.indices
    }, save_path)

In [2]:
def train_model(epochs, model, criterion, optimizer, train_loader, valid_loader, device, save_dir, save_every, print_every):
    """
    Train ``model`` for ``epochs`` epochs, validating after every epoch.

    Parameters
    ----------
    epochs       : number of passes over ``train_loader``.
    model        : module called as ``model(inputs)``; its output is scored
                   with ``criterion`` and arg-maxed over dim 1 for accuracy.
    criterion    : loss called as ``criterion(predictions, targets)``. Epoch
                   losses are sample-weighted averages of the per-batch
                   values, which assumes the criterion returns a batch mean
                   (the default reduction of ``CrossEntropyLoss``).
    optimizer    : optimizer stepping the model parameters.
    train_loader : iterable of ``(inputs, targets)`` training batches.
    valid_loader : iterable of ``(inputs, targets)`` validation batches.
    device       : device every batch is moved to.
    save_dir     : directory receiving checkpoints (created if missing).
    save_every   : save an iteration checkpoint every this many batches;
                   a falsy value disables iteration checkpoints.
    print_every  : print the batch loss every this many batches;
                   a falsy value disables progress printing.

    Returns
    -------
    dict with the per-epoch lists ``hist_train_loss``, ``hist_train_accs``,
    ``hist_valid_loss`` and ``hist_valid_accs`` (the same keys stored in
    ``best_model.pth``).

    Side effects
    ------------
    Writes ``epoch-{e}_iter-{i}_checkpoint.pth`` files and, whenever the
    validation accuracy improves, ``best_model.pth`` into ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)

    hist_train_loss = []
    hist_valid_loss = []
    hist_train_accs = []
    hist_valid_accs = []

    best_accuracy = 0.0

    for epoch in range(epochs):
        # Validation leaves the model in eval mode, so re-enable training
        # behaviour (dropout, batch-norm updates) at the start of each epoch.
        model.train()

        train_corr = 0
        train_seen = 0
        train_loss_sum = 0.0

        for i, (X_train, y_train) in enumerate(train_loader):
            X_train, y_train = X_train.to(device), y_train.to(device)

            train_pred = model(X_train)
            train_loss = criterion(train_pred, y_train)

            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            # Bookkeeping on detached tensors: no graph is kept alive and no
            # host/device sync is forced on every batch.
            batch_size = y_train.size(0)
            train_corr += (train_pred.detach().argmax(dim=1) == y_train).sum()
            train_loss_sum += train_loss.detach() * batch_size
            train_seen += batch_size

            # print progress
            if print_every and (i + 1) % print_every == 0:
                print(f'Iteration {i + 1}/{len(train_loader)}; Loss: {train_loss.item()}')

            # checkpoint saving based on iterations
            if save_every and (i + 1) % save_every == 0:
                checkpoint_path = os.path.join(save_dir, f'epoch-{epoch}_iter-{i + 1}_checkpoint.pth')
                torch.save({
                    'epoch': epoch,
                    'iteration': i + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss.item()
                }, checkpoint_path)
                print(f'Checkpoint saved at {checkpoint_path} after {i + 1} iterations of epoch {epoch}')

        # Epoch metrics cover every sample seen, not just the last batch.
        if train_seen:
            epoch_train_loss = float(train_loss_sum) / train_seen
            train_accuracy = int(train_corr) / train_seen
        else:
            epoch_train_loss = float('nan')
            train_accuracy = 0.0

        hist_train_loss.append(epoch_train_loss)
        hist_train_accs.append(train_accuracy)

        # Evaluate with dropout / batch-norm in inference mode.
        model.eval()

        valid_corr = 0
        valid_seen = 0
        valid_loss_sum = 0.0

        with torch.no_grad():
            for X_valid, y_valid in valid_loader:
                X_valid, y_valid = X_valid.to(device), y_valid.to(device)

                valid_pred = model(X_valid)

                batch_size = y_valid.size(0)
                valid_loss_sum += criterion(valid_pred, y_valid) * batch_size
                valid_corr += (valid_pred.argmax(dim=1) == y_valid).sum()
                valid_seen += batch_size

        if valid_seen:
            epoch_valid_loss = float(valid_loss_sum) / valid_seen
            valid_accuracy = int(valid_corr) / valid_seen
        else:
            epoch_valid_loss = float('nan')
            valid_accuracy = 0.0

        hist_valid_loss.append(epoch_valid_loss)
        hist_valid_accs.append(valid_accuracy)

        print(
            f'[epoch: {epoch}]\n',
            f'- train loss: {epoch_train_loss}, train accuracy: {train_accuracy}\n',
            f'- valid loss: {epoch_valid_loss}, valid accuracy: {valid_accuracy}'
        )

        if valid_accuracy > best_accuracy:
            best_accuracy = valid_accuracy
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'hist_train_loss': hist_train_loss,
                'hist_train_accs': hist_train_accs,
                'hist_valid_loss': hist_valid_loss,
                'hist_valid_accs': hist_valid_accs
            }, os.path.join(save_dir, 'best_model.pth'))
            # The accuracy is a fraction in [0, 1], so no percent sign.
            print(f'- New best model saved with accuracy {valid_accuracy:.4f}')

    return {
        'hist_train_loss': hist_train_loss,
        'hist_train_accs': hist_train_accs,
        'hist_valid_loss': hist_valid_loss,
        'hist_valid_accs': hist_valid_accs
    }

# Main Execution

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# torch.manual_seed seeds the CPU generator and every CUDA device. Seeding only
# CUDA would leave CPU-side randomness (random_split, DataLoader shuffling,
# weight initialisation before the model is moved to the device) unseeded.
torch.manual_seed(42)
print(f'Using device: {device}')

Using device: cuda


In [None]:
# DATASET PREPARATION

# Image root folder and the file that stores the saved split indices.
data_path, save_path = '../data/pretraining-data', '../data/pretraining-dataset-indices.pth'

# Every image is resized to 224x224 and then converted to a tensor.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

dataset = datasets.ImageFolder(data_path, transform=transform)

In [None]:
# Create the split only once. Re-running this cell used to overwrite the saved
# indices with a fresh random split, so earlier runs could not be reproduced.
# Delete the file at save_path to force a new split.
if not os.path.exists(save_path):
    split_n_save_indices(data_path, save_path, transform)

In [None]:
# Load indices
indices = torch.load(save_path)

train_indices = indices['train_indices']
valid_indices = indices['valid_indices']
test_indices = indices['test_indices']

# Fraction of every split that is kept; this single value drives all three
# lengths below (it was previously defined but ignored in favour of literals).
subset_percent = 0.2
train_len = int(subset_percent * len(train_indices))
valid_len = int(subset_percent * len(valid_indices))
test_len = int(subset_percent * len(test_indices))

train_dataset = Subset(dataset, train_indices[:train_len])
valid_dataset = Subset(dataset, valid_indices[:valid_len])
test_dataset = Subset(dataset, test_indices[:test_len])

bs = 128

# Only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

In [15]:
# BUILD THE ViT MODEL AND TRAIN IT

# Model created without pretrained weights, with one output per dataset class.
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=False,
    num_classes=len(dataset.classes),
).to(device)

epochs = 20

criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=1e-4)

# Checkpoint directory and the checkpoint / progress-print intervals (in iterations).
save_dir = 'models-test'
save_every = 60
print_every = 20
os.makedirs(save_dir, exist_ok=True)

# Wall-clock timing of the whole training run.
initial = time.time()

train_model(
    epochs=epochs,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    valid_loader=valid_loader,
    device=device,
    save_dir=save_dir,
    save_every=save_every,
    print_every=print_every,
)

final = time.time()

print(f'Total time for {epochs} epochs of {len(train_loader)} iterations: {final - initial:.4f} seconds')

  return F.conv2d(input, weight, bias, self.stride,


Iteration 20/540; Loss: 1.610893964767456
Iteration 40/540; Loss: 1.4994583129882812
Iteration 60/540; Loss: 1.2412899732589722
Checkpoint saved at models-test/epoch-0_iter-60_checkpoint.pth after 60 iterations of epoch 0
Iteration 80/540; Loss: 0.9037249684333801
Iteration 100/540; Loss: 0.8036942481994629
Iteration 120/540; Loss: 0.8705576062202454
Checkpoint saved at models-test/epoch-0_iter-120_checkpoint.pth after 120 iterations of epoch 0
Iteration 140/540; Loss: 0.7765486240386963
Iteration 160/540; Loss: 0.7585942149162292
Iteration 180/540; Loss: 0.8005860447883606
Checkpoint saved at models-test/epoch-0_iter-180_checkpoint.pth after 180 iterations of epoch 0
Iteration 200/540; Loss: 0.8790202140808105
Iteration 220/540; Loss: 0.7823211550712585
Iteration 240/540; Loss: 0.7159841060638428
Checkpoint saved at models-test/epoch-0_iter-240_checkpoint.pth after 240 iterations of epoch 0
Iteration 260/540; Loss: 0.7567704916000366
Iteration 280/540; Loss: 0.6931222677230835
Iterati