In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

In [56]:
# --- Data / batching --------------------------------------------------------
batch_size = 64        # samples per DataLoader batch
img_size = 32          # side length used to derive the patch grid below
num_channels = 3       # input channels of the patch-projection convolution

# --- Patching ---------------------------------------------------------------
patch_size = 8         # kernel size and stride of the patch-projection convolution
# The patch projection uses kernel == stride == patch_size, so an image size
# that is not a multiple of patch_size would silently ignore border pixels.
assert img_size % patch_size == 0, "img_size must be a multiple of patch_size"
num_patches = (img_size // patch_size) ** 2   # patches per image (square grid)

# --- Transformer ------------------------------------------------------------
num_heads = 24         # attention heads per encoder block
embed_dim = 768        # token embedding width
# Multi-head attention splits embed_dim evenly across the heads.
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
# NOTE(review): mlp_dim is not referenced by the cells visible here (the encoder
# block sizes its MLP as 4 * embed_dim) -- confirm whether it is still needed.
mlp_dim = 16
transformer_units = 6  # number of stacked encoder blocks

In [57]:
# Per-image preprocessing pipeline. It currently holds a single step
# (conversion to a tensor); Compose keeps it easy to append more steps.
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [58]:
# Load CIFAR-10 from ./data (download=True lets torchvision fetch it when
# needed). The train=True split is used for training and the train=False
# split for validation; both share the same preprocessing pipeline.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
valset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

In [59]:
# Wrap both splits in DataLoaders that yield batches of `batch_size` samples.
# Training batches are reshuffled every epoch; validation order stays fixed.
train_data = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
)
val_data = torch.utils.data.DataLoader(
    valset,
    batch_size=batch_size,
    shuffle=False,
)

In [60]:
class PatchEmbedding(nn.Module):
    """Convert an image batch into a sequence of patch embeddings.

    A single convolution whose kernel size and stride both equal
    ``patch_size`` maps ``num_channels`` input channels to ``embed_dim``
    output channels, i.e. one embedding vector per non-overlapping patch.
    """

    def __init__(self):
        super().__init__()
        self.proj = nn.Conv2d(
            num_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Project ``x`` to patch embeddings.

        For a batched 4-D input the convolution output is flattened over its
        two spatial axes and transposed, so the result is laid out as
        ``(batch, patches, channels)``.
        """
        feature_map = self.proj(x)
        # Merge the spatial axes into one patch axis ...
        patch_seq = feature_map.flatten(2)
        # ... and move that axis in front of the channel axis.
        return patch_seq.transpose(1, 2)

In [61]:
class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Applies LayerNorm -> multi-head self-attention and LayerNorm -> MLP, each
    wrapped in a residual connection. The MLP hidden width is ``4 * dim``.

    Args:
        dim: Token embedding width. ``None`` (default) uses the notebook-level
            ``embed_dim``.
        heads: Number of attention heads. ``None`` (default) uses the
            notebook-level ``num_heads``.
    """

    def __init__(self, dim=None, heads=None):
        super().__init__()
        # Resolve defaults lazily so a bare `Block()` keeps reading the
        # notebook's global configuration at construction time.
        dim = embed_dim if dim is None else dim
        heads = num_heads if heads is None else heads

        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        """Return the block output for ``x``; the shape is unchanged.

        ``x`` is expected batch-first, matching ``batch_first=True`` on the
        attention layer.
        """
        # Normalise once and reuse the same tensor as query, key and value;
        # evaluating the LayerNorm separately for each role would triple its
        # forward/backward cost without changing the result.
        normed = self.ln1(x)
        # Only the attention output (index 0) is used, so skip building the
        # averaged attention-weight tensor.
        x = x + self.attn(normed, normed, normed, need_weights=False)[0]
        x = x + self.ff(self.ln2(x))
        return x

In [62]:
class DeiT(nn.Module):
    """Vision transformer with a class token and a distillation token.

    Two learnable tokens are prepended to the patch sequence. After the
    encoder blocks and a final LayerNorm, each token is passed through its
    own 10-way linear head, and both sets of logits are returned.
    """

    def __init__(self):
        super().__init__()
        self.patch_embed = PatchEmbedding()
        # Learnable tokens, stored with a leading axis of size 1 and
        # broadcast over the batch in forward().
        self.cls_token = nn.Parameter(torch.randn((1, 1, embed_dim)))
        self.dist_token = nn.Parameter(torch.randn((1, 1, embed_dim)))
        # One learnable position vector per patch, plus two for the tokens.
        self.pos_embed = nn.Parameter(torch.randn((1, num_patches + 2, embed_dim)))
        self.blocks = nn.ModuleList([Block() for _ in range(transformer_units)])
        self.norm = nn.LayerNorm(embed_dim)
        self.cls_head = nn.Linear(embed_dim, 10)
        self.dist_head = nn.Linear(embed_dim, 10)

    def forward(self, x):
        """Return ``(class_logits, distillation_logits)`` for the batch ``x``."""
        batch = x.size(0)
        patches = self.patch_embed(x)

        # Sequence layout along dim 1: [class token, distillation token, patches...]
        tokens = torch.cat(
            (
                self.cls_token.expand(batch, -1, -1),
                self.dist_token.expand(batch, -1, -1),
                patches,
            ),
            dim=1,
        )
        tokens = tokens + self.pos_embed

        for layer in self.blocks:
            tokens = layer(tokens)
        tokens = self.norm(tokens)

        # Position 0 holds the class token, position 1 the distillation token.
        cls_logits = self.cls_head(tokens[:, 0])
        dist_logits = self.dist_head(tokens[:, 1])
        return cls_logits, dist_logits



In [63]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the model on the chosen device, then the loss and optimiser around it.
model = DeiT().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

In [64]:
# Train for five epochs. Both heads are supervised with the ground-truth
# labels and their cross-entropy losses are combined with equal weight.
for epoch in range(5):
    model.train()
    total_loss, correct_epoch, total_epoch = 0, 0, 0
    print(f"\nEpoch {epoch+1}")

    for batch_idx, (images, labels) in enumerate(train_data):
        images = images.to(device)
        labels = labels.to(device)

        # One optimisation step on the equally weighted two-head loss.
        optimizer.zero_grad()
        cls_out, dist_out = model(images)
        loss_cls = criterion(cls_out, labels)
        loss_dist = criterion(dist_out, labels)
        loss = 0.5 * loss_cls + 0.5 * loss_dist
        loss.backward()
        optimizer.step()

        # Running statistics; the prediction is the argmax of the summed logits.
        total_loss += loss.item()
        preds = (cls_out + dist_out).argmax(dim=1)
        correct = preds.eq(labels).sum().item()
        correct_epoch += correct
        total_epoch += labels.size(0)

        # Only report every 100th batch.
        if batch_idx % 100 != 0:
            continue
        acc = 100.0 * correct / labels.size(0)
        print(f"  Batch {batch_idx+1:3d}: Loss={loss.item():.4f}, Accuracy={acc:.2f}%")

    epoch_acc = 100.0 * correct_epoch / total_epoch
    print(f"==> Epoch {epoch+1} Summary: Total Loss={total_loss:.4f}, Accuracy={epoch_acc:.2f}%")



Epoch 1
  Batch   1: Loss=2.6069, Accuracy=7.81%
  Batch 101: Loss=2.0498, Accuracy=21.88%
  Batch 201: Loss=1.8871, Accuracy=37.50%
  Batch 301: Loss=2.0224, Accuracy=25.00%
  Batch 401: Loss=1.9653, Accuracy=40.62%
  Batch 501: Loss=1.8015, Accuracy=28.12%
  Batch 601: Loss=1.4910, Accuracy=45.31%
  Batch 701: Loss=1.5908, Accuracy=45.31%
==> Epoch 1 Summary: Total Loss=1449.4591, Accuracy=32.38%

Epoch 2
  Batch   1: Loss=1.6271, Accuracy=42.19%
  Batch 101: Loss=1.5591, Accuracy=40.62%
  Batch 201: Loss=1.2982, Accuracy=50.00%
  Batch 301: Loss=1.3152, Accuracy=51.56%
  Batch 401: Loss=1.5943, Accuracy=39.06%
  Batch 501: Loss=1.5530, Accuracy=40.62%
  Batch 601: Loss=1.4504, Accuracy=42.19%
  Batch 701: Loss=1.6334, Accuracy=39.06%
==> Epoch 2 Summary: Total Loss=1161.5433, Accuracy=46.85%

Epoch 3
  Batch   1: Loss=1.3347, Accuracy=50.00%
  Batch 101: Loss=1.3858, Accuracy=39.06%
  Batch 201: Loss=1.3436, Accuracy=50.00%
  Batch 301: Loss=1.6385, Accuracy=39.06%
  Batch 401: Los

In [65]:
# Measure accuracy on the validation loader with the model in eval mode and
# gradient tracking disabled. Predictions use the summed logits of both heads.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in val_data:
        images = images.to(device)
        labels = labels.to(device)
        cls_out, dist_out = model(images)
        preds = (cls_out + dist_out).argmax(dim=1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

test_acc = 100.0 * correct / total
print(f"\n==> Val Accuracy: {test_acc:.2f}%")


==> Val Accuracy: 54.72%
