<a href="https://colab.research.google.com/github/MittalNeha/ERA1/blob/main/S17/train_vit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# !git clone https://github.com/MittalNeha/ERA1.git

In [2]:
# Switch into the S17 folder of the ERA1 checkout. The path is relative, so this
# assumes the repository already exists in the kernel's current working directory
# (the clone command in the first cell is commented out).
%cd ERA1/S17

/content/ERA1/S17


In [3]:
# from google.colab import drive
# drive.mount('/content/drive')

## Contains the training dataset and some helper code for training

In [4]:
# !cp -r /content/drive/MyDrive/ERA_V1/S17/. .

In [5]:
import matplotlib.pyplot as plt
import torch
import torchvision
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from super_repo import data_setup, engine, utils
from ViT import build_ViT

In [6]:
# Workaround: let the process keep running when more than one copy of the OpenMP
# runtime ends up loaded (some setups otherwise abort with a duplicate-library error).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Seed torch's global RNG before the model and the data loaders are created, so
# that everything drawing from it afterwards (e.g. the shuffled training batches)
# is repeatable on "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader requires an integer; fall back to 0 (load data in the main process).
NUM_WORKERS = os.cpu_count() or 0
# Set the batch size
BATCH_SIZE = 32  # this is lower than the ViT paper but it's because we're starting small

# Dataset location; train/ and test/ are read through ImageFolder further down
image_path = "./pizza_steak_sushi"
train_dir = image_path + "/train"
test_dir = image_path + "/test"

# Side length (in pixels) that every image is resized to
IMG_SIZE = 224

# Create transform pipeline manually: resize to a square, then convert to a tensor
manual_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])
print(f"Manually created transforms: {manual_transforms}")

# Use the GPU when one is visible to torch, otherwise stay on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print("training device: ", device)

Manually created transforms: Compose(
    Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=warn)
    ToTensor()
)
training device:  cuda


In [7]:
def create_dataloaders(
    train_dir: str,
    test_dir: str,
    transform: transforms.Compose,
    batch_size: int,
    num_workers: int = NUM_WORKERS
):
    """Build the training and testing DataLoaders from two image directories.

    Args:
        train_dir: Directory handed to ``datasets.ImageFolder`` for training data.
        test_dir: Directory handed to ``datasets.ImageFolder`` for testing data.
        transform: Transform pipeline applied to the images of both datasets.
        batch_size: Number of samples per batch in both loaders.
        num_workers: Worker count passed to both loaders (defaults to NUM_WORKERS).

    Returns:
        A tuple ``(train_dataloader, test_dataloader, class_names)`` where
        ``class_names`` is the ``classes`` attribute of the training dataset.
    """
    # Options common to both loaders; only the shuffling differs between them.
    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )

    # Both datasets are read with the same transform.
    train_set = datasets.ImageFolder(train_dir, transform=transform)
    test_set = datasets.ImageFolder(test_dir, transform=transform)

    # Training batches are shuffled; test batches keep the dataset order.
    train_loader = DataLoader(train_set, shuffle=True, **loader_options)
    test_loader = DataLoader(test_set, shuffle=False, **loader_options)

    return train_loader, test_loader, train_set.classes

In [8]:
# Build both loaders with the manually defined transform pipeline; class_names
# is taken from the training dataset inside create_dataloaders.
train_dataloader, test_dataloader, class_names = create_dataloaders(
    train_dir,
    test_dir,
    transform=manual_transforms,
    batch_size=BATCH_SIZE,
)

In [9]:
# Create the model through the project's build_ViT helper, passing the number of
# classes discovered in the training dataset as num_classes.
vit = build_ViT(num_classes=len(class_names))

In [13]:
# Loss function for multi-class classification
loss_fn = torch.nn.CrossEntropyLoss()

# Adam over all model parameters, using the hyperparameters quoted from the ViT paper
optimizer = torch.optim.Adam(
    vit.parameters(),
    lr=3e-3,             # Base LR from Table 3 for ViT-* ImageNet-1k
    betas=(0.9, 0.999),  # default values but also mentioned in ViT paper section 4.1 (Training & Fine-tuning)
    weight_decay=0.3,    # from the ViT paper section 4.1 (Training & Fine-tuning) and Table 3 for ViT-* ImageNet-1k
)

# Train for 10 epochs on the selected device and keep whatever engine.train
# returns (per the original note, a dictionary of training results).
results = engine.train(
    model=vit,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=10,
    device=torch.device(device),
)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 5.1080 | train_acc: 0.2969 | test_loss: 7.3566 | test_acc: 0.2604
Epoch: 2 | train_loss: 3.8140 | train_acc: 0.3320 | test_loss: 5.1516 | test_acc: 0.1979
Epoch: 3 | train_loss: 2.5621 | train_acc: 0.4336 | test_loss: 2.5879 | test_acc: 0.5417
Epoch: 4 | train_loss: 2.7546 | train_acc: 0.2773 | test_loss: 2.8308 | test_acc: 0.2604
Epoch: 5 | train_loss: 1.7354 | train_acc: 0.4102 | test_loss: 2.7835 | test_acc: 0.5417
Epoch: 6 | train_loss: 1.5746 | train_acc: 0.4141 | test_loss: 2.8134 | test_acc: 0.5417
Epoch: 7 | train_loss: 1.5231 | train_acc: 0.4453 | test_loss: 1.7999 | test_acc: 0.2604
Epoch: 8 | train_loss: 1.5297 | train_acc: 0.3242 | test_loss: 1.6965 | test_acc: 0.5417
Epoch: 9 | train_loss: 1.8020 | train_acc: 0.3867 | test_loss: 2.5582 | test_acc: 0.1979
Epoch: 10 | train_loss: 1.4944 | train_acc: 0.2500 | test_loss: 2.4501 | test_acc: 0.1979


In [11]:
!pip install -q torchinfo

In [12]:
from torchinfo import summary

# Print a per-layer summary of the model for one batch-shaped input. The shape is
# built from BATCH_SIZE and IMG_SIZE so it stays in sync with the data pipeline
# settings instead of repeating their literal values.
summary(
    model=vit,
    input_size=(BATCH_SIZE, 3, IMG_SIZE, IMG_SIZE),  # (batch_size, color_channels, height, width)
    # col_names=["input_size"],  # uncomment for smaller output
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                                                Input Shape          Output Shape         Param #              Trainable
ViT (ViT)                                                              [32, 3, 224, 224]    [32, 3]              152,064              True
├─PatchEmbedding (src_embed)                                           [32, 3, 224, 224]    [32, 196, 768]       --                   True
│    └─Conv2d (patcher)                                                [32, 3, 224, 224]    [32, 768, 14, 14]    590,592              True
│    └─Flatten (flatten)                                               [32, 768, 14, 14]    [32, 768, 196]       --                   --
├─Encoder (encoder)                                                    [32, 197, 768]       [32, 197, 768]       2                    True
│    └─ModuleList (layers)                                             --                   --                   --                   True
│    │    └─EncoderBlock