In [1]:
from google.colab import drive
drive.mount('/content/drive')

from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math, os, zipfile, shutil, random
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR


zip_path = '/content/drive/MyDrive/archive.zip'
extract_dir = '/content/asl_alphabet'

if not os.path.exists(extract_dir):
    print("Extracting archive.zip...")
    shutil.copy(zip_path, '/content/archive.zip')
    try:
        with zipfile.ZipFile('/content/archive.zip', 'r') as zip_ref:
            zip_ref.extractall(extract_dir)
        print("Extraction successful.")
    except zipfile.BadZipFile:
        print("Error: Not a valid zip file.")
else:
    print("Dataset already extracted.")

train_dir = os.path.join(extract_dir, 'asl_alphabet_train', 'asl_alphabet_train')
image_size = (100, 100)


transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@dataclass
class Config:
  n_embd: int = 768
  n_head: int = 12
  n_layer: int = 12
  n_class: int = 29
  patch_size: int = 16
  max_len: int = (100 // patch_size) * (100 // patch_size)  # 100x100 image




class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.scale * (x / (norm + self.eps))

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len, device=None):
        # shape: (seq_len)
        positions = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        # outer product → (seq_len, dim/2)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        # (seq_len, dim/2, 2)
        emb = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
        return emb  # shape: (seq_len, dim/2, 2)

def apply_rotary_emb(x, rope):
    # x: (batch, heads, seq, dim)
    # rope: (seq, dim/2, 2)
    x1 = x[..., ::2]  # even dims
    x2 = x[..., 1::2] # odd dims

    cos = rope[..., 0].unsqueeze(0).unsqueeze(0)  # (1, 1, seq, dim/2)
    sin = rope[..., 1].unsqueeze(0).unsqueeze(0)

    x_rotated_even = x1 * cos - x2 * sin
    x_rotated_odd = x1 * sin + x2 * cos

    return torch.stack([x_rotated_even, x_rotated_odd], dim=-1).flatten(-2)


class SelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.rope = RotaryEmbedding(config.n_embd // config.n_head)


        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias = False)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias = False)


        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size()  # Batch, Time, Channels
        qkv = self.c_attn(x)  # shape: (B, T, 3*C)

        q, k, v = qkv.split(self.n_embd, dim=2)

        # Reshape: (B, heads, T, head_dim)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # Apply  (RoPE)
        rope = self.rope(T, device=x.device)  # (T, head_dim/2, 2)
        q = apply_rotary_emb(q, rope)  # shape unchanged
        k = apply_rotary_emb(k, rope)

        # Attention score calculation
        att = (q @ k.transpose(-2, -1)) * (1.0 / (C // self.n_head) ** 0.5)
        att = F.softmax(att, dim=-1)

        # Apply attention to values
        y = att @ v  # (B, heads, T, head_dim)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # reshape to (B, T, C)

        # Final projection
        y = self.c_proj(y)
        return y


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU(approximate = 'tanh')
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

class PatchEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.proj = nn.Conv2d(
            in_channels =  3,
            out_channels = config.n_embd,
            kernel_size = config.patch_size,
            stride = config.patch_size
        )
    def forward(self,x):
      x = self.proj(x)
      x = x.flatten(2)
      x = x.transpose(1,2)
      return x

class VisionTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.patch_embedding = PatchEmbedding(config)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.n_embd))
        # self.pos_embedding = nn.Parameter(torch.zeros(1, 1 + config.max_len, config.n_embd))

        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layer)])

        self.head = nn.Linear(config.n_embd, config.n_class)

        self.ln_f = RMSNorm(config.n_embd)

    def forward(self, x):
        x = self.patch_embedding(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        # x = x + self.pos_embedding
        for block in self.blocks:
          x = block(x)
        x = self.ln_f(x)
        cls_output = x[:, 0]  # Shape: (B, C)
        return self.head(cls_output)




class Block(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.ln1 = RMSNorm(config.n_embd)
        self.attention = SelfAttention(config)
        self.ln2 = RMSNorm(config.n_embd)
        self.mlp = MLP(config)
    def forward(self, x):
        x = x + self.attention(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x



full_dataset = datasets.ImageFolder(train_dir, transform=transform)
val_percent = 0.1
val_size = int(val_percent * len(full_dataset))
train_size = len(full_dataset) - val_size

train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=os.cpu_count(), pin_memory=True, prefetch_factor=32)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False, num_workers=os.cpu_count(), pin_memory=True)

best_val_acc = 0.0
scaler = torch.cuda.amp.GradScaler()


config = Config()
model = VisionTransformer(config).to(device)
model = torch.compile(model)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)
early_stopping_patience = 5
early_stopping_counter = 0


num_epochs = 10


for epoch in range(num_epochs):
    model.train()
    train_loss, train_correct, total = 0.0, 0, 0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images, labels = images.to(device,memory_format = torch.channels_last, non_blocking = True), labels.to(device, non_blocking = True)
        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        train_correct += (outputs.argmax(1) == labels).sum().item()
        total += labels.size(0)

    train_acc = 100 * train_correct / total


    model.eval()
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            with torch.cuda.amp.autocast():
                outputs = model(images)
            val_correct += (outputs.argmax(1) == labels).sum().item()
            val_total += labels.size(0)

    val_acc = 100 * val_correct / val_total

    print(f"\nEpoch {epoch+1}: Train Loss = {train_loss/len(train_loader):.4f} | Train Acc = {train_acc:.2f}% | Val Acc = {val_acc:.2f}%")

    scheduler.step()


    if val_acc > best_val_acc:
        best_val_acc = val_acc
        early_stopping_counter = 0
        torch.save(model.state_dict(), "/content/best_vit_model.pth")
    else:
        early_stopping_counter+= 1
        if early_stopping_counter >= early_stopping_patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break




Mounted at /content/drive
Extracting archive.zip...
Extraction successful.


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
W0716 12:31:05.469000 309 torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode
Epoch 1/10: 100%|██████████| 153/153 [03:21<00:00,  1.32s/it]



Epoch 1: Train Loss = 2.0615 | Train Acc = 35.79% | Val Acc = 75.17%


Epoch 2/10: 100%|██████████| 153/153 [00:59<00:00,  2.56it/s]



Epoch 2: Train Loss = 0.4066 | Train Acc = 85.81% | Val Acc = 89.85%


Epoch 3/10: 100%|██████████| 153/153 [01:00<00:00,  2.53it/s]



Epoch 3: Train Loss = 0.1639 | Train Acc = 94.39% | Val Acc = 95.60%


Epoch 4/10: 100%|██████████| 153/153 [01:01<00:00,  2.49it/s]



Epoch 4: Train Loss = 0.0721 | Train Acc = 97.56% | Val Acc = 97.72%


Epoch 5/10: 100%|██████████| 153/153 [01:01<00:00,  2.49it/s]



Epoch 5: Train Loss = 0.0275 | Train Acc = 99.16% | Val Acc = 98.98%


Epoch 6/10: 100%|██████████| 153/153 [01:01<00:00,  2.49it/s]



Epoch 6: Train Loss = 0.0091 | Train Acc = 99.75% | Val Acc = 99.53%


Epoch 7/10: 100%|██████████| 153/153 [01:01<00:00,  2.49it/s]



Epoch 7: Train Loss = 0.0030 | Train Acc = 99.93% | Val Acc = 99.82%


Epoch 8/10: 100%|██████████| 153/153 [01:01<00:00,  2.49it/s]



Epoch 8: Train Loss = 0.0008 | Train Acc = 99.99% | Val Acc = 99.85%


Epoch 9/10: 100%|██████████| 153/153 [01:01<00:00,  2.50it/s]



Epoch 9: Train Loss = 0.0005 | Train Acc = 100.00% | Val Acc = 99.83%


Epoch 10/10: 100%|██████████| 153/153 [01:01<00:00,  2.48it/s]



Epoch 10: Train Loss = 0.0004 | Train Acc = 100.00% | Val Acc = 99.83%


In [5]:
torch.save(model.state_dict(), "/content/drive/MyDrive/asl_model_saves/Vit_asl.pth")
