In [None]:
"""
Also check out the PyTorch paper-replication walkthrough:
https://www.learnpytorch.io/08_pytorch_paper_replicating/"""
from google.colab import drive

# Mount Google Drive at /content/drive; DATA_ROOT and CHECKPOINT_PATH (CONFIG cell) point below it.
drive.mount(mountpoint="/content/drive")

In [None]:
"""CONFIG"""
import torch
from pathlib import Path

# --- Paths (on the mounted Google Drive) ---
DATA_ROOT = Path("/content/drive/MyDrive/research")
CHECKPOINT_PATH = DATA_ROOT / "vision_transformer/model_checkpoints"
# When True, the training cell tries to load weights from CHECKPOINT_PATH before training.
RESUME = False

# --- Training ---
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
NUM_EPOCHS = 300
BATCH_SIZE = 32  # samples per DataLoader batch (train and test)

In [None]:
"""DATASET"""

"""
USEFUL REFERENCES
Pathlib - https://pythonandvba.com/wp-content/uploads/2021/11/Pathlib_Cheat_Sheet.pdf
Caltech256 Dataset : https://pytorch.org/vision/stable/generated/torchvision.datasets.Caltech256.html#torchvision.datasets.Caltech256

"""

import torch
from torchvision import datasets
from torchvision.transforms import transforms
from sklearn.model_selection import train_test_split

torch.random.manual_seed(42)  # fixed global seed: reproducible weight init and DataLoader shuffling

def to_rgb(image):
  """Return `image` in RGB mode so every sample becomes a 3-channel tensor.

  The model's patch embedding is built with 3 input channels by default, so any
  non-RGB mode must be converted: single-band modes such as "L" (grayscale) or
  "P", but also multi-band modes such as "LA", "RGBA" or "CMYK", which would
  otherwise yield 2 or 4 channels after ToTensor.

  PIL's ``convert('RGB')`` replicates/maps the existing bands; it does not add
  empty channels. Images already in RGB mode are returned unchanged.
  """
  if image.mode != 'RGB':
    return image.convert('RGB')
  return image

# Preprocessing applied to every sample.
transform = transforms.Compose([
    transforms.Lambda(to_rgb),      # ensure 3 channels for the 3-channel patch embedding
    transforms.Resize((224, 224)),  # fixed size matching VisionTransformer's default img_size
    transforms.ToTensor(),          # PIL image -> float tensor of shape (C, H, W)
])

# dataset = datasets.Caltech256(DATA_ROOT, transform=transform, download=True)
dataset = datasets.Caltech101(DATA_ROOT, transform=transform, download=True)
print("Dataset size : ",len(dataset))
indices = list(range(len(dataset)))

# torchvision's Caltech101 builds its sample list one category at a time, so a
# plain head/tail slice (indices[:80%]) would leave the last categories only in
# the test set, never seen during training. Shuffle reproducibly and, when
# per-sample labels are exposed as `dataset.y`, stratify on them so every class
# is split roughly 80/20.
train_indices, test_indices = train_test_split(
    indices,
    train_size=0.8,
    shuffle=True,
    random_state=42,
    stratify=getattr(dataset, "y", None),
)

# Create training and test subsets using Subset
train_dataset = torch.utils.data.Subset(dataset, train_indices)
test_dataset = torch.utils.data.Subset(dataset, test_indices)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
"""MODEL"""

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
  """Cut an image into non-overlapping square patches and embed each one linearly.

  Parameters
  ----------
  img_size : int
    Side length of the square input image.
  patch_size : int
    Side length of each square patch.
  in_channel : int
    Number of input image channels.
  embed_dim : int
    Size of the embedding vector produced for every patch.
  """

  def __init__(self, img_size, patch_size, in_channel=3, embed_dim=768):
    super().__init__()
    self.img_size = img_size
    self.patch_size = patch_size
    patches_per_side = img_size // patch_size
    self.n_patches = patches_per_side * patches_per_side

    # Kernel size == stride == patch_size: one shared linear projection is
    # applied to each patch independently, with no overlap between patches.
    self.proj = nn.Conv2d(in_channel, embed_dim, kernel_size=patch_size, stride=patch_size)

  def forward(self, x):
    """Embed the patches of `x`.

    `x` is expected to be (batch, in_channel, H, W); the result is
    (batch, n_patches, embed_dim).
    """
    # (B, C, H, W) -> (B, E, H/P, W/P) -> (B, E, N) -> (B, N, E)
    return self.proj(x).flatten(2).transpose(1, 2)

class Attention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  Parameters
  ----------
  dim : int
    Input/output feature size of each token. Must be divisible by `n_heads`.
  n_heads : int
    Number of attention heads; each head works on `dim // n_heads` features.
  qkv_bias : bool
    Whether the joint query/key/value projection has a bias.
  attn_p : float
    Dropout probability applied to the attention weights.
  proj_p : float
    Dropout probability applied after the output projection.

  Raises
  ------
  ValueError
    If `dim` is not divisible by `n_heads`.
  """
  def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
    super().__init__()
    # Fail early with a clear message instead of an opaque reshape error in forward().
    if dim % n_heads != 0:
      raise ValueError(f"dim ({dim}) must be divisible by n_heads ({n_heads})")
    self.n_heads = n_heads
    self.dim = dim
    self.head_dim = dim // n_heads
    self.scale = self.head_dim ** -0.5  # 1/sqrt(d_k) keeps softmax logits well-scaled

    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.attn_drop = nn.Dropout(attn_p)
    self.proj = nn.Linear(dim, dim)
    self.proj_drop = nn.Dropout(proj_p)

  def forward(self, x):
    """Apply self-attention to `x` of shape (batch, n_tokens, dim).

    Returns a tensor of the same shape. Raises ValueError when the last
    dimension of `x` differs from `dim`.
    """
    batch, n_tokens, dim = x.shape
    if dim != self.dim:
      raise ValueError(f"Expected input feature dim {self.dim}, got {dim}")

    qkv = self.qkv(x)  # (batch, n_tokens, 3 * dim)
    qkv = qkv.reshape(batch, n_tokens, 3, self.n_heads, self.head_dim)
    qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, n_heads, n_tokens, head_dim)
    q, k, v = qkv[0], qkv[1], qkv[2]

    # Attention scores: (batch, n_heads, n_tokens, n_tokens)
    scores = (q @ k.transpose(-2, -1)) * self.scale
    attn = scores.softmax(dim=-1)
    attn = self.attn_drop(attn)

    weighted_avg = attn @ v  # (batch, n_heads, n_tokens, head_dim)
    # Merge the heads back: (batch, n_tokens, n_heads * head_dim) == (batch, n_tokens, dim)
    weighted_avg = weighted_avg.transpose(1, 2).flatten(2)

    x = self.proj(weighted_avg)
    x = self.proj_drop(x)
    return x

class MLP(nn.Module):
  """Two-layer feed-forward network used inside each transformer block.

  Computes fc2(dropout(GELU(fc1(x)))) followed by a second dropout; both
  dropouts share the probability `p`.

  Parameters
  ----------
  in_features : int
    Size of the last dimension of the input.
  hidden_features : int
    Width of the hidden layer.
  out_features : int
    Size of the last dimension of the output.
  p : float
    Dropout probability.
  """
  def __init__(self, in_features, hidden_features, out_features, p=0.):
    super().__init__()
    self.fc1 = nn.Linear(in_features, hidden_features)
    self.act = nn.GELU()
    self.fc2 = nn.Linear(hidden_features, out_features)
    self.drop = nn.Dropout(p)

  def forward(self, x):
    """Map (..., in_features) to (..., out_features)."""
    hidden = self.drop(self.act(self.fc1(x)))
    return self.drop(self.fc2(hidden))

class TransformerBlock(nn.Module):
  """Pre-norm transformer encoder block.

  forward(x) computes  x = x + Attention(LayerNorm(x))  then
  x = x + MLP(LayerNorm(x)).

  Parameters
  ----------
  dim : int
    Token embedding size.
  n_heads : int
    Number of attention heads.
  mlp_ratio : float
    Hidden width of the MLP as a multiple of `dim`.
  qkv_bias : bool
    Whether the query/key/value projection has a bias.
  p : float
    Dropout probability for the attention output projection and for the MLP.
  attn_p : float
    Dropout probability applied to the attention weights.
  """
  def __init__(self, dim, n_heads, mlp_ratio=4.0, qkv_bias=True, p=0, attn_p=0.):
    super().__init__()
    self.norm1 = nn.LayerNorm(dim, eps=1e-6)
    self.attn = Attention(
      dim,
      n_heads=n_heads,
      qkv_bias=qkv_bias,
      attn_p=attn_p,
      proj_p=p
    )
    self.norm2 = nn.LayerNorm(dim, eps=1e-6)
    hidden_features = int(dim*mlp_ratio)
    # Pass `p` through so the MLP's dropout uses the configured rate, matching
    # the attention projection dropout above.
    self.mlp = MLP(
      in_features=dim,
      hidden_features=hidden_features,
      out_features=dim,
      p=p
    )

  def forward(self, x):
    """Apply attention and MLP sub-layers, each with a residual connection."""
    x = x + self.attn(self.norm1(x))
    x = x + self.mlp(self.norm2(x))
    return x

class VisionTransformer(nn.Module):
  """Vision Transformer (ViT) image classifier.

  The image is split into patches that are linearly embedded. A learnable class
  token is prepended and learnable positional embeddings are added. The sequence
  then runs through `depth` transformer blocks. Finally, the class-token output
  is layer-normalised and mapped to `n_classes` logits.

  Parameters
  ----------
  img_size : int
    Side length of the square input image.
  patch_size : int
    Side length of each square patch.
  in_channels : int
    Number of image channels.
  n_classes : int
    Number of output logits.
  embed_dim : int
    Token embedding size.
  depth : int
    Number of transformer blocks.
  n_heads : int
    Attention heads per block.
  mlp_ratio : float
    MLP hidden width as a multiple of `embed_dim`.
  qkv_bias : bool
    Whether the query/key/value projections have a bias.
  p : float
    Dropout rate after adding positional embeddings and inside each block.
  attn_p : float
    Dropout rate applied to attention weights.
  """
  def __init__(
    self,
    img_size=224,
    patch_size=16,
    in_channels=3,
    n_classes=1000,
    embed_dim=768,
    depth=12,
    n_heads=12,
    mlp_ratio=4,
    qkv_bias=True,
    p=0.,
    attn_p=0.,
  ):
    super().__init__()

    # NOTE: submodules/parameters are created in a fixed order, which fixes
    # random-init consumption and the order of model.parameters().
    self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                  in_channel=in_channels, embed_dim=embed_dim)
    n_tokens = 1 + self.patch_embed.n_patches  # class token + one token per patch
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embed = nn.Parameter(torch.zeros(1, n_tokens, embed_dim))
    self.pos_drop = nn.Dropout(p=p)

    encoder_blocks = []
    for _ in range(depth):
      encoder_blocks.append(
        TransformerBlock(dim=embed_dim, n_heads=n_heads, mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias, p=p, attn_p=attn_p)
      )
    self.blocks = nn.ModuleList(encoder_blocks)

    self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
    self.head = nn.Linear(embed_dim, n_classes)

  def forward(self, x):
    """Classify a batch of images.

    `x` is expected to be (batch, in_channels, img_size, img_size); the result
    is a (batch, n_classes) tensor of logits.
    """
    batch_size = x.shape[0]
    tokens = self.patch_embed(x)                                # (batch, n_patches, embed_dim)
    cls_tokens = self.cls_token.expand(batch_size, -1, -1)     # (batch, 1, embed_dim)
    tokens = torch.cat((cls_tokens, tokens), dim=1) + self.pos_embed
    tokens = self.pos_drop(tokens)

    for block in self.blocks:
      tokens = block(tokens)

    tokens = self.norm(tokens)
    return self.head(tokens[:, 0])  # classify from the class-token position only


"""
Next steps:
- Download ImageNet and train on it.
  ImageNet (ILSVRC 2012) download : https://image-net.org/challenges/LSVRC/2012/2012-downloads.php
  ImageNet download page          : https://image-net.org/download-images
  PyTorch ImageNet dataset usage  : https://pytorch.org/vision/main/generated/torchvision.datasets.ImageNet.html
- Move the configuration constants into a separate config.py module."""


In [None]:
import torch
from pathlib import Path

def save_checkpoint(state_dict, epoch, loss, path):
    """Save a checkpoint dict to ``{path}/vit{epoch}.pth``.

    The saved object is ``{"epoch": epoch, "state_dict": state_dict, "loss": loss}``.
    A file for the same epoch is overwritten.

    Args:
        state_dict: Model weights, e.g. ``model.state_dict()``.
        epoch: Epoch label, used in the file name and stored in the checkpoint.
        loss: Latest training loss. A single-element tensor is stored as a plain
            Python float, so the checkpoint does not keep a device tensor or its
            autograd graph. Any other tensor is detached and moved to the CPU.
        path: Directory to write into; it is created (with parents) if missing.

    Returns:
        pathlib.Path of the written checkpoint file.
    """
    p = Path(path)
    if not p.exists():
        print("Creating folder")
        p.mkdir(parents=True, exist_ok=True)

    if isinstance(loss, torch.Tensor):
        loss = loss.item() if loss.numel() == 1 else loss.detach().cpu()

    ckpt_file = p / f"vit{epoch}.pth"
    model_details = {
        "epoch": epoch,
        "state_dict": state_dict,
        "loss": loss,
    }
    torch.save(model_details, ckpt_file)
    print(f"model saved at path : {ckpt_file}")
    return ckpt_file


def load_pretrained(model, path, epoch=None, map_location="cpu"):
    """Load the ``"state_dict"`` entry of a checkpoint into ``model``.

    Args:
        model: Module whose ``load_state_dict`` receives the stored weights.
        path: When ``epoch`` is given, the checkpoint directory; the file
            ``{path}/vit{epoch}.pth`` is read. When ``epoch`` is None, ``path``
            is the checkpoint file itself.
        epoch: Epoch label of the checkpoint to load, or None (see ``path``).
        map_location: Forwarded to ``torch.load``. The default "cpu" lets a
            checkpoint written on a GPU be opened on a CPU-only machine;
            ``load_state_dict`` then copies the values into the model's
            parameters on whatever device the model is on.

    Returns:
        The same ``model`` object, with weights loaded.
    """
    ckpt_file = Path(path) if epoch is None else Path(path) / f"vit{epoch}.pth"
    checkpoint = torch.load(ckpt_file, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    return model



In [None]:

import torch.nn as nn
import torch.optim as optim
import time

if __name__ == "__main__":

  # Size the classifier head to the dataset (Caltech101 exposes its class names)
  # instead of the ImageNet default of 1000 outputs.
  model = VisionTransformer(n_classes=len(dataset.categories))
  start_epoch = 0

  if RESUME and CHECKPOINT_PATH.exists():
    # Checkpoints are written as vit{epoch}.pth. Compare epoch numbers as
    # integers: a plain string sort ranks "vit9" after "vit10".
    saved_epochs = []
    for ckpt_file in CHECKPOINT_PATH.glob("vit*.pth"):
      label = ckpt_file.stem[len("vit"):]
      if label.isdigit():
        saved_epochs.append(int(label))

    if saved_epochs:
      last_epoch = max(saved_epochs)
      print("Resuming from checkpoint of epoch : ", last_epoch)
      model = load_pretrained(model, CHECKPOINT_PATH, last_epoch)
      # Checkpoints are saved mid-epoch, so the labelled (1-based) epoch may be
      # incomplete: restart the loop at that epoch's 0-based index.
      # NOTE(review): optimizer (Adam) state is not checkpointed, so it restarts fresh.
      start_epoch = last_epoch - 1
    else:
      print("No checkpoint found in", CHECKPOINT_PATH, "- training from scratch")

  model = model.to(DEVICE)
  print("Num parameters of model : ", sum(p.numel() for p in model.parameters()))

  criterion = nn.CrossEntropyLoss()
  # NOTE(review): lr=1e-3 is on the high side for training a ViT from scratch with Adam.
  optimizer = optim.Adam(model.parameters(), lr=0.001)
  print("_______________________\nStarting Training\n")

  SAVE_EVERY = 50  # iterations between loss logs / checkpoint saves

  for epoch in range(start_epoch, NUM_EPOCHS):
    print("Training for epoch : ", epoch+1)
    model.train()  # the evaluation pass below switches the model to eval mode
    start = time.time()
    for i, (imgs, labels) in enumerate(train_loader):
      imgs = imgs.to(DEVICE)
      labels = labels.to(DEVICE)

      prediction = model(imgs)
      loss = criterion(prediction, labels)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      if i % SAVE_EVERY == 0:
        # .item() forces a GPU sync, so only read the loss when logging.
        loss_value = loss.item()
        print(f"Iteration {i} loss : {loss_value:.4f}")
        print("Saving checkpoint : ", i)
        # Store a plain float rather than the device tensor still tied to the graph.
        save_checkpoint(model.state_dict(), epoch+1, loss_value, CHECKPOINT_PATH)

    end = time.time()
    print(f"Epoch {epoch+1} Training time : {end-start}s")

    # Held-out accuracy on the test split.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
      for imgs, labels in test_loader:
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)
        correct += (model(imgs).argmax(dim=1) == labels).sum().item()
        total += labels.numel()
    if total:
      print(f"Epoch {epoch+1} test accuracy : {correct/total:.4f}")

