<a href="https://colab.research.google.com/github/ArjunBalaji79/KD_CompVision/blob/main/VIT_distilled_4_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import timm

In [None]:
! pip install timm

Collecting timm
  Downloading timm-0.9.12-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m13.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.9.12


In [None]:
# CIFAR-10 images are resized to 224x224, the input size both models are built for.
# Mean/std of 0.5 map pixel values to [-1, 1].
# NOTE(review): confirm this matches the teacher's pretrained preprocessing
# (e.g. timm.data.resolve_data_config(model=teacher_model)).
IMAGE_SIZE = 224
BATCH_SIZE = 32
SEED = 42

# Seed the global RNG so the classifier heads / student initialised later are reproducible.
torch.manual_seed(SEED)

_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training transform adds a random horizontal flip: the student is trained from
# scratch for many epochs, and un-augmented data overfits quickly.
train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
])

# Deterministic transform for evaluation.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    _normalize,
])

# Training set: shuffled with a seeded generator so batch order is reproducible.
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
    num_workers=2,
    pin_memory=True,
)

# Validation set: CIFAR-10's held-out test split, evaluated in fixed order.
valset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
valloader = torch.utils.data.DataLoader(
    valset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 28907300.61it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Teacher: ViT-L/16 at 224px with pretrained weights ('vit_large' is ViT-L, not ViT-H).
num_classes = 10
TEACHER_HEAD_EPOCHS = 1  # linear-probe epochs for the teacher's new head; 0 skips (not recommended)
TEACHER_HEAD_LR = 1e-3

teacher_model = timm.create_model('vit_large_patch16_224', pretrained=True)
# The pretrained classifier is replaced with a freshly initialised 10-class layer.
# Its weights are random, so the teacher must be adapted to CIFAR-10 before its
# logits are meaningful soft targets for distillation.
teacher_model.head = nn.Linear(teacher_model.head.in_features, num_classes)
teacher_model = teacher_model.cuda()


def fit_teacher_head(model, dataloader, epochs=1, lr=1e-3):
    """Train only ``model.head`` with cross-entropy, keeping the backbone frozen.

    Args:
        model: teacher network whose ``head`` attribute is its classifier layer.
        dataloader: iterable of ``(inputs, labels)`` batches.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate for the head.

    Returns:
        ``model`` in eval mode with every parameter frozen.
    """
    model.requires_grad_(False)
    model.head.requires_grad_(True)
    head_optimizer = optim.Adam(model.head.parameters(), lr=lr)
    # eval() keeps any dropout / stochastic depth in the backbone disabled while
    # the head is fitted; the head itself is a plain Linear layer.
    model.eval()
    for _ in range(epochs):
        for inputs, labels in dataloader:
            inputs, labels = inputs.cuda(), labels.cuda()
            head_optimizer.zero_grad()
            loss = nn.functional.cross_entropy(model(inputs), labels)
            loss.backward()
            head_optimizer.step()
    # From here on the teacher is inference-only.
    model.requires_grad_(False)
    return model


teacher_model = fit_teacher_head(teacher_model, trainloader, TEACHER_HEAD_EPOCHS, TEACHER_HEAD_LR)

# Student: EfficientNet-B3 without pretrained weights, trained from scratch.
student_model = timm.create_model('efficientnet_b3', pretrained=False)
student_model.classifier = nn.Linear(student_model.classifier.in_features, num_classes)
# Assigned rather than left as the cell's last expression, which printed the
# entire module repr into the output.
student_model = student_model.cuda()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

EfficientNet(
  (conv_stem): Conv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNormAct2d(
    40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
    (drop): Identity()
    (act): SiLU(inplace=True)
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40, bias=False)
        (bn1): BatchNormAct2d(
          40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(40, 10, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(10, 40, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pw): Conv2d(40, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNormAct2d(
    

In [None]:
# Knowledge-distillation loss.
# KL(P || Q) = -sum_i P(i) * log(Q(i) / P(i)); here P is the teacher's
# temperature-softened distribution and Q is the student's.

def distillation_loss(student_output, labels, teacher_output, T=3, ce_weight=1.0, kd_weight=1.0):
    """Hard-label cross-entropy plus temperature-scaled KL divergence to the teacher.

    Args:
        student_output: student logits, classes on dim 1.
        labels: integer class targets for ``student_output``.
        teacher_output: teacher logits, same shape as ``student_output``.
        T: softmax temperature applied to both sets of logits.
        ce_weight: weight of the cross-entropy term (default 1.0 keeps the original loss).
        kd_weight: weight of the KL term (default 1.0 keeps the original loss).

    Returns:
        Scalar tensor ``ce_weight * CE + kd_weight * T**2 * KL(teacher || student)``.
    """
    hard_loss = nn.functional.cross_entropy(student_output, labels)
    # log_target=True lets kl_div consume the teacher's log-probabilities directly
    # instead of re-deriving logs from softmax probabilities.
    soft_loss = nn.functional.kl_div(
        nn.functional.log_softmax(student_output / T, dim=1),
        nn.functional.log_softmax(teacher_output / T, dim=1),
        reduction='batchmean',
        log_target=True,
    )
    # Multiplying by T**2 keeps the soft-target gradient scale comparable to the hard loss.
    return ce_weight * hard_loss + kd_weight * soft_loss * (T * T)

# Adam over every student parameter; only the student is optimised here.
optimizer = optim.Adam(params=student_model.parameters(), lr=1e-3)


In [None]:
# Model validation
def compute_accuracy(model, dataloader):
    """Return the top-1 accuracy of ``model`` on ``dataloader`` as a percentage.

    Batches are moved to the device holding the model's parameters (CPU if the
    model has none), so this works for CPU and CUDA models alike. The model's
    train/eval mode is restored before returning.

    Args:
        model: classifier returning logits with classes on dim 1.
        dataloader: iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy in ``[0, 100]``.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                images, labels = images.to(device), labels.to(device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)
    if total == 0:
        raise ValueError("dataloader yielded no samples")
    return 100 * correct / total

num_epochs = 50

# The teacher is only used for inference below. Put it in eval mode so any
# dropout / stochastic depth is disabled and its soft targets are deterministic.
teacher_model.eval()

# Training loop
for epoch in range(num_epochs):
    # compute_accuracy() switches the student to eval at the end of each epoch.
    student_model.train()
    total_loss = 0.0

    for inputs, labels in trainloader:
        inputs, labels = inputs.cuda(), labels.cuda()

        # Teacher logits need no gradient. no_grad (not inference_mode) is used
        # because these tensors feed a loss that is backpropagated.
        with torch.no_grad():
            teacher_outputs = teacher_model(inputs)

        optimizer.zero_grad()
        student_outputs = student_model(inputs)
        loss = distillation_loss(student_outputs, labels, teacher_outputs)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean per-batch training loss, then accuracy on the validation loader.
    train_loss = total_loss / len(trainloader)
    val_accuracy = compute_accuracy(student_model, valloader)
    print(f'Epoch {epoch + 1}, Loss: {train_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')



In [None]:
# Persist only the student's learned weights (its state_dict), not the module object.
torch.save(obj=student_model.state_dict(), f='student_model.pth')