# Imports

In [4]:
import os

import timm

from PIL import Image

import torch
from torch.utils.data import random_split, DataLoader
from torch.optim import Adam

import torch.nn as nn
from torch.nn import CrossEntropyLoss

from torchvision import datasets, transforms

# Function Definitions

In [19]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss, total_correct = 0, 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_correct += (outputs.argmax(1) == labels).type(torch.float).sum().item()

    return total_loss / len(dataloader), total_correct / len(dataloader.dataset)

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss, total_correct = 0, 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            total_loss += loss.itm()
            total_correct += (outputs.argmax(1) == labels).type(torch.float).sum().item()

    return total_loss / len(dataloader)

def train_model(start, epochs, model, criterion, optimizer, train_loader, valid_loader, device, save_dir):
    hist_train_loss = []
    hist_valid_loss = []
    hist_train_accs = []
    hist_valid_accs = []

    for epoch in range(start, start + epochs):
        train_corr = 0
        valid_corr = 0
        batch_corr = 0

        for X_train, y_train in train_loader:
            X_train, y_train = X_train.to(device), y_train.to(device)

            train_pred = model(X_train)
            train_loss = criterion(train_pred, y_train)

            train_predicted = torch.max(train_pred.data, 1)[1]
            batch_corr = (train_predicted == y_train).sum()
            train_corr += batch_corr

            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

        train_accuracy = train_corr.item() / len(train_loader.dataset)

        hist_train_loss.append(train_loss)
        hist_train_accs.append(train_accuracy)

        with torch.no_grad():
            for X_valid, y_valid in valid_loader:
                X_valid, y_valid = X_valid.to(device), y_valid.to(device)

                valid_pred = model(X_valid)

                valid_predicted = torch.max(valid_pred.data, 1)[1]
                valid_corr += (valid_predicted == y_valid).sum()

        valid_accuracy = valid_corr.item() / len(valid_loader.dataset)
        valid_loss = criterion(valid_pred, y_valid)

        hist_valid_loss.append(valid_loss)
        hist_valid_accs.append(valid_accuracy)

        print(
            f'[epoch: {epoch}]\n', 
            f'- train loss: {train_loss.item():}, train accuracy: {train_accuracy}\n',
            f'- valid loss: {valid_loss.item()}, valid accuracy: {valid_accuracy}'
        )

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'hist_train_loss': hist_train_loss, 
            'hist_train_accs': hist_train_accs,
            'hist_valid_loss': hist_valid_loss, 
            'hist_valid_accs': hist_valid_accs
        }, os.path.join(save_dir, f'epoch_{epoch}_checkpoint.pth'))

# Main Excecution

In [25]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.manual_seed(42) if device.type == 'cuda' else torch.manual_seed(42)
print(f'Using device: {device}')

Using device: cpu


In [20]:
# PREPARING THE DATASET

data_path = 'data/Modern images for Abhinav'

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

dataset = datasets.ImageFolder(root=data_path, transform=transform)

train_size = int(0.7 * len(dataset))
valid_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - valid_size
train_dataset, valid_dataset, test_dataset = random_split(dataset, [train_size, valid_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=256, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [21]:
# CREATING AND TRAINING THE ViT MODEL

model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=len(dataset.classes))
model = model.to(device)

start = 0
epochs = 2

optimizer = Adam(model.parameters(), lr=1e-4)
criterion = CrossEntropyLoss()

save_dir = 'models'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

train_model(start, epochs, model, criterion, optimizer, train_loader, valid_loader, device, save_dir)

[epoch: 0]
 - train loss: 1.9107078313827515, train accuracy: 0.17801047120418848
 - valid loss: 3.81400990486145, valid accuracy: 0.525
[epoch: 1]
 - train loss: 4.377239227294922, train accuracy: 0.4293193717277487
 - valid loss: 4.2843732833862305, valid accuracy: 0.125
