# Imports

In [2]:
import timm

from PIL import Image

import torch
from torch.utils.data import random_split, DataLoader
from torch.optim import Adam

import torch.nn as nn
from torch.nn import CrossEntropyLoss

from torchvision import datasets, transforms

# Function Definitions

In [4]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run a single optimisation pass of ``model`` over ``dataloader``.

    Puts the model in training mode. For every batch it zeroes gradients,
    backpropagates ``criterion`` and steps ``optimizer``.

    Returns:
        (mean loss, accuracy) where the mean loss is the average of the
        per-batch losses and accuracy is the number of correct argmax
        predictions divided by ``len(dataloader.dataset)``.
    """
    model.train()
    running_loss = 0
    n_correct = 0

    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        hits = logits.argmax(1) == batch_y
        n_correct += hits.type(torch.float).sum().item()

    mean_loss = running_loss / len(dataloader)
    accuracy = n_correct / len(dataloader.dataset)
    return mean_loss, accuracy

def evaluate(model, dataloader, criterion, device, return_accuracy=False):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Switches the model to eval mode (it is left in eval mode afterwards).

    Args:
        model: module to evaluate; expected to already be on ``device``.
        dataloader: DataLoader yielding ``(inputs, labels)`` batches.
        criterion: loss callable ``criterion(outputs, labels) -> scalar tensor``.
        device: device each batch is moved to.
        return_accuracy: if True, also return the accuracy. It defaults to
            False so existing callers still receive a single float.

    Returns:
        The mean of the per-batch losses. If ``return_accuracy`` is True, the
        tuple ``(mean_loss, accuracy)`` instead, where accuracy is the number
        of correct argmax predictions divided by ``len(dataloader.dataset)``.
        This matches the ``(loss, accuracy)`` pair returned by ``train_epoch``.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss, total_correct = 0.0, 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Fixed: was `loss.itm()`, which raised AttributeError on the first batch.
            total_loss += loss.item()
            total_correct += (outputs.argmax(1) == labels).type(torch.float).sum().item()

    avg_loss = total_loss / len(dataloader)
    if return_accuracy:
        return avg_loss, total_correct / len(dataloader.dataset)
    return avg_loss

def train_model(start, epochs, model, criterion, optimizer, train_loader, valid_loader, device, PATH='~',
                save_checkpoints=False):
    """Train ``model`` for ``epochs`` epochs, validating on ``valid_loader`` after each.

    Epochs are labelled ``start .. start + epochs - 1`` so a resumed run keeps
    counting where the previous one stopped. One summary line is printed per epoch.

    Losses and accuracies are sample-weighted averages over the whole epoch.
    Previously only the last training batch and the last validation batch were
    reported. The weighting assumes ``criterion`` returns the batch *mean*
    (``reduction='mean'``). NOTE(review): confirm if a different reduction is
    used.

    Args:
        start: label of the first epoch.
        epochs: number of epochs to run.
        model: module to train; expected to already be on ``device``.
        criterion: loss callable ``criterion(logits, labels) -> scalar tensor``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: DataLoader yielding ``(inputs, labels)`` training batches.
        valid_loader: DataLoader yielding ``(inputs, labels)`` validation batches.
        device: device each batch is moved to.
        PATH: checkpoint file path (``~`` is expanded). Only used when
            ``save_checkpoints`` is True.
        save_checkpoints: when True, overwrite ``PATH`` after every epoch with the
            epoch label, the model and optimizer state dicts, and the validation loss.

    Returns:
        dict with keys ``'train_loss'``, ``'valid_loss'``, ``'train_acc'`` and
        ``'valid_acc'``, each a list holding one plain float per epoch.

    Raises:
        ValueError: if ``train_loader`` or ``valid_loader`` yields no samples.
    """
    import os

    history = {'train_loss': [], 'valid_loss': [], 'train_acc': [], 'valid_acc': []}
    was_training = model.training  # restored on exit so the caller's mode is untouched

    try:
        for i in range(start, start + epochs):
            # ---- training pass ----
            model.train()
            train_loss_sum, train_corr, n_train = 0.0, 0, 0
            for X_train, y_train in train_loader:
                X_train, y_train = X_train.to(device), y_train.to(device)

                optimizer.zero_grad()
                train_pred = model(X_train)
                batch_loss = criterion(train_pred, y_train)
                batch_loss.backward()
                optimizer.step()

                # .item() yields a plain float, so no autograd graph is retained in history.
                batch_size = y_train.size(0)
                train_loss_sum += batch_loss.item() * batch_size
                train_corr += (train_pred.argmax(1) == y_train).sum().item()
                n_train += batch_size
            if n_train == 0:
                raise ValueError('train_loader yielded no samples')
            train_loss = train_loss_sum / n_train
            train_accuracy = train_corr / n_train

            # ---- validation pass ----
            # eval() disables train-only behaviour such as dropout.
            model.eval()
            valid_loss_sum, valid_corr, n_valid = 0.0, 0, 0
            with torch.no_grad():
                for X_valid, y_valid in valid_loader:
                    X_valid, y_valid = X_valid.to(device), y_valid.to(device)

                    valid_pred = model(X_valid)

                    batch_size = y_valid.size(0)
                    valid_loss_sum += criterion(valid_pred, y_valid).item() * batch_size
                    valid_corr += (valid_pred.argmax(1) == y_valid).sum().item()
                    n_valid += batch_size
            if n_valid == 0:
                raise ValueError('valid_loader yielded no samples')
            valid_loss = valid_loss_sum / n_valid
            valid_accuracy = valid_corr / n_valid

            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_accuracy)
            history['valid_loss'].append(valid_loss)
            history['valid_acc'].append(valid_accuracy)

            print(
                    f'epoch: {i} training loss: {train_loss} training accuracy: {train_accuracy} '
                    f'val loss: {valid_loss} val accuracy: {valid_accuracy}'
            )

            if save_checkpoints:
                # Fixed from the earlier draft, which stored optimizer.state_dict()
                # under 'model_state_dict'.
                torch.save(
                        {
                            'epoch': i,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'loss': valid_loss,
                        },
                        os.path.expanduser(PATH),
                )
    finally:
        model.train(was_training)

    return history


# Main Execution

In [5]:
# PREPARING THE DATASET

# Fixed seed so the train/valid/test split and the model's initial weights are
# identical on every Restart-&-Run-All.
SEED = 42
torch.manual_seed(SEED)

data_path = 'data/Modern images for Abhinav'

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

dataset = datasets.ImageFolder(root=data_path, transform=transform)

# 70 / 15 / 15 split; the test split absorbs the integer-rounding remainder.
train_size = int(0.7 * len(dataset))
valid_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - valid_size
train_dataset, valid_dataset, test_dataset = random_split(
    dataset,
    [train_size, valid_size, test_size],
    generator=torch.Generator().manual_seed(SEED),  # dedicated RNG -> reproducible split
)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=256, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)


# CREATING AND TRAINING THE ViT MODEL

# Trained from scratch (pretrained=False); the classifier head is sized to the
# number of ImageFolder class directories.
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=len(dataset.classes))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = Adam(model.parameters(), lr=1e-4)
criterion = CrossEntropyLoss()

start = 0
epochs = 2

# Assigned rather than left as the cell's last expression, so any returned
# history is kept for later cells instead of being dumped as cell output.
history = train_model(start, epochs, model, criterion, optimizer, train_loader, valid_loader, device)

Preparing the dataset...
[x] Dataset loaded =)
[x] DataLoaders defined =)
Training the model...
[x] Model instantiated and moved to cpu
epoch: 0 training loss: 1.3623961210250854 training accuracy: 0.29842931937172773 val loss: 5.290616035461426 val accuracy: 0.375
epoch: 1 training loss: 4.693085193634033 training accuracy: 0.45549738219895286 val loss: 2.8824639320373535 val accuracy: 0.225
epoch: 2 training loss: 2.8795652389526367 training accuracy: 0.17277486910994763 val loss: 1.8244686126708984 val accuracy: 0.125
epoch: 3 training loss: 1.644390344619751 training accuracy: 0.18848167539267016 val loss: 1.70224928855896 val accuracy: 0.375
epoch: 4 training loss: 1.5857023000717163 training accuracy: 0.45549738219895286 val loss: 1.8623991012573242 val accuracy: 0.375
epoch: 5 training loss: 1.7326589822769165 training accuracy: 0.45549738219895286 val loss: 1.8335832357406616 val accuracy: 0.375
epoch: 6 training loss: 1.689814567565918 training accuracy: 0.45549738219895286 va