In [139]:
import torch

In [140]:
device = "cuda"

In [141]:
BATCH_SIZE = 12

In [142]:
IMG_SIZE = 28

In [143]:
PATCH_SIZE = 7

In [144]:
CLASSES = 10

In [145]:
EPOCHS_STUDENT = 64

In [146]:
LR_STUDENT = 3e-5

In [147]:
CHANNELS = 3

In [148]:
ATTENTION_HEADS = 4
TRANSFORMER_LAYERS = 4
TEMPERATURE = 2
ALPHA = 0.5

In [149]:
import torchvision
import torchvision.datasets as dataset

#@title import data

data_transformations = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Lambda(lambda x: x.repeat(3, 1, 1))
])

train_ds = dataset.MNIST("./data", train=True, download=True, transform=data_transformations)
test_ds = dataset.MNIST("./data", train=False, download=True, transform=data_transformations)

In [150]:
from torch.utils.data import DataLoader, Subset

train_ds = Subset(train_ds, range(1000))
test_ds = Subset(test_ds, range(1000))

train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [151]:
len(train_ds)

1000

In [152]:
import torchvision.models as models
teacher = models.resnet50(weights = models.ResNet50_Weights.IMAGENET1K_V2)

In [153]:
import torch.nn as nn
teacher.fc = nn.Linear(teacher.fc.in_features, CLASSES)
teacher.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [154]:
# Patch_EMB

EMBED_DIM = 32
img_size = IMG_SIZE

class PatchEmbed(nn.Module):
    def __init__(self,
                 channels = CHANNELS,
                 attention_heads = ATTENTION_HEADS,
                 classes = CLASSES,
                 embed_dim = EMBED_DIM,
                 patch_size = PATCH_SIZE):
        super(PatchEmbed, self).__init__()

        self.conv = nn.Conv2d(
            channels,
            embed_dim,
            patch_size,
            patch_size)
        self.n = (img_size // patch_size) ** 2
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.conv(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        x = self.proj(x)
        x = self.norm(x)
        return x


# ViT class

In [155]:
class ViT(nn.Module):
    def __init__(self,
                 embed_dim=EMBED_DIM,
                 attention_heads=ATTENTION_HEADS,
):
        super().__init__()

        self.patch_embed = PatchEmbed(
            channels=CHANNELS,
            attention_heads=ATTENTION_HEADS,
            classes=CLASSES,
            embed_dim=EMBED_DIM,
            patch_size=PATCH_SIZE)

        self.cls = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.distillation = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.n = self.patch_embed.n
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.n + 2, embed_dim))

        qkv_dim = embed_dim // attention_heads
        self.transformer_blocks = nn.Sequential(
            *[nn.TransformerEncoderLayer(
              d_model=embed_dim,
              nhead=attention_heads,
              batch_first=True)
              for _ in range(TRANSFORMER_LAYERS)])
        self.layernorm = nn.LayerNorm(embed_dim)

        self.head_cls = nn.Linear(embed_dim, CLASSES)
        self.head_distil = nn.Linear(embed_dim, CLASSES)

    def forward(self, x):
      B = x.shape[0]
      x = self.patch_embed(x)
      cls = self.cls.expand(B, -1, -1)
      dist = self.distillation.expand(B, -1, -1) # Corrected from self.cls

      x = torch.cat([cls, x, dist], dim=1) + self.pos_embed

      x = self.transformer_blocks(x)
      x = self.layernorm(x)

      cls_token = x[:, 0]
      distil_token = x[:, -1]

      cls_logits = self.head_cls(cls_token)
      distil_logits = self.head_distil(distil_token)
      return cls_logits, distil_logits

In [156]:
student = ViT().to(device)
opt_s = torch.optim.Adam(student.parameters(), lr=LR_STUDENT)

In [157]:
student.parameters()

<generator object Module.parameters at 0x7da935445700>

In [158]:
import torch.nn.functional as F

def kd_loss(
    s_logits,
    t_logits,
    y,
    alpha = ALPHA,
    temperature = TEMPERATURE
):
  ce = F.cross_entropy(s_logits, y)
  kd = F.kl_div(F.softmax(s_logits / temperature, dim=1).log(),
                F.softmax(t_logits / temperature, dim=1),
                reduction='batchmean') * (temperature ** 2)
  return alpha * ce + (1 - alpha) * kd

In [159]:
for e in range(EPOCHS_STUDENT):
  student.train()
  for x, y in train_dl:

    x = x.to(device)
    y = y.to(device)
    with torch.no_grad():
      t_logits = teacher(x)

    s_logits, _ = student(x)
    loss = kd_loss(s_logits, t_logits, y)
    opt_s.zero_grad()
    loss.backward()
    opt_s.step()
    print(f"Epoch: {e+1}, Loss: {loss.item():.4f}")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch: 5, Loss: 1.0571
Epoch: 5, Loss: 1.0158
Epoch: 5, Loss: 1.0675
Epoch: 5, Loss: 1.0131
Epoch: 5, Loss: 1.0594
Epoch: 5, Loss: 1.0058
Epoch: 5, Loss: 1.0404
Epoch: 5, Loss: 1.1426
Epoch: 5, Loss: 1.0577
Epoch: 5, Loss: 1.1091
Epoch: 5, Loss: 0.9865
Epoch: 5, Loss: 1.0479
Epoch: 5, Loss: 1.1217
Epoch: 5, Loss: 1.0684
Epoch: 5, Loss: 1.0474
Epoch: 5, Loss: 1.0148
Epoch: 5, Loss: 1.0181
Epoch: 5, Loss: 1.0723
Epoch: 5, Loss: 1.1081
Epoch: 5, Loss: 1.1005
Epoch: 5, Loss: 1.0827
Epoch: 5, Loss: 1.0911
Epoch: 5, Loss: 1.0588
Epoch: 5, Loss: 1.0188
Epoch: 5, Loss: 0.9926
Epoch: 5, Loss: 1.0071
Epoch: 5, Loss: 1.0421
Epoch: 5, Loss: 1.0441
Epoch: 5, Loss: 0.9951
Epoch: 5, Loss: 1.0429
Epoch: 5, Loss: 0.9773
Epoch: 5, Loss: 1.1128
Epoch: 5, Loss: 1.1301
Epoch: 5, Loss: 1.0752
Epoch: 5, Loss: 1.0570
Epoch: 5, Loss: 0.9739
Epoch: 5, Loss: 1.0501
Epoch: 5, Loss: 1.0237
Epoch: 5, Loss: 0.9262
Epoch: 5, Loss: 0.9825
Epoch: 5, Loss:

In [160]:
# evaluate now
print("evaluate now")


evaluate now


In [161]:
# find accuracy of trained model
student.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in test_dl:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs, _ = student(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 66 %
