In [None]:
import torch
import torch.nn as nn
import torch.nn.utils as utils
import torch.nn.utils.prune as prune
import torch.optim as optim
import torch.utils.data as data_utils
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.models as models
import torch.quantization
from torchvision.models import vision_transformer as ViT


class TeacherViT(nn.Module):
    def __init__(self):
        super(TeacherViT, self).__init__()
        self.vit = models.vit_b_16(weights=ViT.ViT_B_16_Weights.IMAGENET1K_V1)
        self.vit.heads = nn.Linear(self.vit.hidden_dim, 10)

    def forward(self, x):
        return self.vit(x)

class StudentViT(nn.Module):
    def __init__(self):
        super(StudentViT, self).__init__()
        self.vit = models.vit_b_16()
        self.vit.heads = nn.Linear(self.vit.hidden_dim, 10)
        # decrease num of transformer blocks
        self.vit.encoder.layers = self.vit.encoder.layers[:6]

    def forward(self, x):
        return self.vit(x)

class DistillationLoss(nn.Module):
    def __init__(self, alpha=0.5, temperature=2.0):
        super(DistillationLoss, self).__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.kl_div = nn.KLDivLoss(reduction="batchmean")
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        kd_loss = self.kl_div(
            torch.log_softmax(student_logits / self.temperature, dim=1),
            torch.softmax(teacher_logits / self.temperature, dim=1)
        ) * (self.temperature ** 2)
        ce_loss = self.ce_loss(student_logits, labels)
        return self.alpha * kd_loss + (1 - self.alpha) * ce_loss

def train(teacher_model, student_model, train_loader, optimizer, distillation_loss, device):
    teacher_model.eval()
    student_model.train()
    print("Length: ", len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        with torch.no_grad():
            teacher_output = teacher_model(data)
        student_output = student_model(data)
        loss = distillation_loss(student_output, teacher_output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
          print(f"Batch {batch_idx+1}, Loss: {loss.item()}")

def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    return 100. * correct / total

def prune_model(model, amount=0.3):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=amount)
    return model

def qat_model(model):
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    # Only prepare the linear modules for quantization
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            torch.quantization.prepare_qat(module, inplace=True)
    return model



In [None]:
# Initialization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_subset = data_utils.Subset(trainset, range(8000))
test_subset = data_utils.Subset(trainset, range(2000))
train_loader = DataLoader(train_subset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=64, shuffle=False)
print("Data loaded successfully!")

teacher_model = TeacherViT().to(device)
student_model = StudentViT().to(device)
print("Models initialized successfully!")


Using cuda device
Files already downloaded and verified
Files already downloaded and verified
Data loaded successfully!
Models initialized successfully!


In [None]:
# Knowledge distillation
optimizer = optim.Adam(student_model.parameters(), lr=0.001)
distillation_loss = DistillationLoss()

for epoch in range(10):
    train(teacher_model, student_model, train_loader, optimizer, distillation_loss, device)
    accuracy = evaluate(student_model, test_loader, device)
    print(f"Epoch {epoch+1}, Accuracy: {accuracy}%")


Length:  125
Batch 1, Loss: 1.3523688316345215
Batch 11, Loss: 1.4267396926879883
Batch 21, Loss: 1.241178274154663
Batch 31, Loss: 1.2082445621490479
Batch 41, Loss: 1.1923753023147583
Batch 51, Loss: 1.1969530582427979
Batch 61, Loss: 1.1723254919052124
Batch 71, Loss: 1.2015101909637451
Batch 81, Loss: 1.1453453302383423
Batch 91, Loss: 1.1224390268325806
Batch 101, Loss: 1.1280912160873413
Batch 111, Loss: 1.1324542760849
Batch 121, Loss: 1.1429380178451538
Epoch 1, Accuracy: 19.15%
Length:  125
Batch 1, Loss: 1.1268446445465088
Batch 11, Loss: 1.162122130393982
Batch 21, Loss: 1.1351104974746704
Batch 31, Loss: 1.1078426837921143
Batch 41, Loss: 1.1341239213943481
Batch 51, Loss: 1.1478160619735718
Batch 61, Loss: 1.128321647644043
Batch 71, Loss: 1.1012498140335083
Batch 81, Loss: 1.1305257081985474
Batch 91, Loss: 1.1339735984802246
Batch 101, Loss: 1.0814881324768066
Batch 111, Loss: 1.2038604021072388
Batch 121, Loss: 1.1270062923431396
Epoch 2, Accuracy: 23.6%
Length:  125
Ba

In [None]:
pruned_model = prune_model(student_model)
accuracy = evaluate(pruned_model, test_loader, device)
print(f"Pruned model accuracy: {accuracy}%")


Pruned model accuracy: 26.3%


In [None]:
# QAT
pruned_model = StudentViT().to(device)
qat_model = qat_model(pruned_model.train())
optimizer = optim.Adam(qat_model.parameters(), lr=0.0001)

for epoch in range(5):
    train(teacher_model, qat_model, train_loader, optimizer, distillation_loss, device)
    accuracy = evaluate(qat_model, test_loader, device)
    print(f"QAT Epoch {epoch+1}, Accuracy: {accuracy}%")

quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)

# Save model
torch.jit.save(torch.jit.script(quantized_model), "quantized_vit.pt")



Length:  125
Batch 1, Loss: 1.316054344177246
Batch 11, Loss: 1.2011091709136963
Batch 21, Loss: 1.1418251991271973
Batch 31, Loss: 1.1504271030426025
Batch 41, Loss: 1.1098215579986572
Batch 51, Loss: 1.1115397214889526
Batch 61, Loss: 1.1399153470993042
Batch 71, Loss: 1.0721808671951294
Batch 81, Loss: 1.1016970872879028
Batch 91, Loss: 1.1157045364379883
Batch 101, Loss: 1.1227526664733887
Batch 111, Loss: 1.031515121459961
Batch 121, Loss: 1.1034893989562988
QAT Epoch 1, Accuracy: 28.15%
Length:  125
Batch 1, Loss: 1.1279518604278564
Batch 11, Loss: 1.1133867502212524
Batch 21, Loss: 1.0771567821502686
Batch 31, Loss: 1.024730920791626
Batch 41, Loss: 1.0906916856765747
Batch 51, Loss: 1.0369982719421387
Batch 61, Loss: 1.0119776725769043
Batch 71, Loss: 1.0649549961090088
Batch 81, Loss: 1.0364967584609985
Batch 91, Loss: 1.0247160196304321
Batch 101, Loss: 1.014143466949463
Batch 111, Loss: 1.048240065574646
Batch 121, Loss: 1.0948542356491089
QAT Epoch 2, Accuracy: 38.6%
Length