In [1]:
import torchvision.models
import torch

import numpy as np

from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST

# Seed NumPy and PyTorch so repeated runs start from the same random state.
np.random.seed(0)
torch.manual_seed(0)

# Prefer the GPU when one is present, but fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# torchvision's ViT-B/16 built with default arguments; no other cell visible
# here references it.
vit16 = torchvision.models.vit_b_16()
transform = ToTensor()

# Both MNIST splits share the same root folder, download flag and transform.
mnist_kwargs = dict(root='./../datasets', download=True, transform=transform)
train_set = MNIST(train=True, **mnist_kwargs)
test_set = MNIST(train=False, **mnist_kwargs)

# Only the training split is shuffled; both use batches of 128 images.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False)

In [3]:
import torch.nn as nn

class Attention(nn.Module):
  """Multi-head scaled dot-product attention.

  Args:
    embed_size: width of the input/output features; must be divisible by
      ``heads``.
    dout: dropout probability applied to the per-head attention output.
    heads: number of attention heads.
  """
  def __init__(self,embed_size,dout,heads=4):
    super().__init__()
    assert embed_size % heads == 0, "embed_size must be divisible by heads"
    self.Q2d = nn.Linear(embed_size,embed_size)
    self.K2d = nn.Linear(embed_size,embed_size)
    self.V2d = nn.Linear(embed_size,embed_size)
    self.d2O = nn.Linear(embed_size,embed_size)

    self.heads = heads
    self.h_dim = embed_size // heads
    self.drop_out = nn.Dropout(dout)
  def forward(self,Q,K,V,mask=None):
    """Attend from ``Q`` over ``K``/``V``.

    Args:
      Q: queries of shape (B, T_q, embed_size).
      K: keys of shape (B, T_k, embed_size).
      V: values of shape (B, T_k, embed_size).
      mask: optional boolean tensor broadcastable to (B, heads, T_q, T_k);
        positions where it is True are filled with -inf before the softmax,
        i.e. they are excluded from attention.

    Returns:
      Tensor of shape (B, T_q, embed_size).
    """
    B , T_q , D = Q.shape
    # Keys/values may have a different length than the queries
    # (cross-attention), so take their length from K rather than from Q.
    T_k = K.shape[1]
    Q = self.Q2d(Q).reshape(B,T_q,self.heads,self.h_dim)
    K = self.K2d(K).reshape(B,T_k,self.heads,self.h_dim)
    V = self.V2d(V).reshape(B,T_k,self.heads,self.h_dim)
    attention = torch.einsum("bqhd,bkhd->bhqk",Q,K)
    # Each dot product sums over h_dim terms, so the scale is sqrt(h_dim),
    # the per-head width, not the full embedding width.
    attention = attention / (self.h_dim ** 0.5)
    if mask is not None:
      attention = attention.masked_fill(mask,-float("inf"))
    attention = torch.softmax(attention,dim=-1)
    out = torch.einsum("bhqk,bkhd->bqhd",attention,V)
    out = self.drop_out(out)
    # Merge the heads back into a single embed_size-wide feature vector.
    out = out.reshape(B,T_q,D)
    return self.d2O(out)

class ResidualNetwork(nn.Module):
  """Pre-norm residual wrapper: ``x + dropout(sublayer(layer_norm(x)))``.

  Args:
    embed_size: size of the last dimension normalised by the LayerNorm.
    dout: dropout probability applied to the sub-layer output.
  """
  def __init__(self,embed_size,dout):
    super().__init__()
    self.ln = nn.LayerNorm(embed_size)
    self.drop_out = nn.Dropout(dout)
  def forward(self,x,sublayer):
    """Apply ``sublayer`` to the normalised input and add the skip path.

    Args:
      x: input tensor whose last dimension is ``embed_size``.
      sublayer: callable mapping a tensor to a tensor of the same shape.

    Returns:
      ``x`` plus the dropped-out sub-layer output.
    """
    branch = self.drop_out(sublayer(self.ln(x)))
    return x + branch

class EncoderLayer(nn.Module):
  """One Transformer encoder block: self-attention followed by a feed-forward net.

  Both sub-layers are wrapped in a pre-norm residual connection. Each
  sub-layer owns its own wrapper (and therefore its own LayerNorm) so the
  normalisation parameters are not shared between attention and the MLP.

  Args:
    embed_size: feature width of the token embeddings.
    dout: dropout probability used throughout the block.
    heads: number of attention heads (must divide ``embed_size``).
    ff_size: hidden width of the feed-forward network.
  """
  def __init__(self,embed_size,dout,heads=4,ff_size=32):
    super().__init__()
    self.residual = ResidualNetwork(embed_size,dout)      # wraps self-attention
    self.residual_ff = ResidualNetwork(embed_size,dout)   # wraps the feed-forward net
    self.attention = Attention(embed_size,dout,heads)
    self.feedforward = nn.Sequential(
        nn.Linear(embed_size,ff_size),
        nn.ReLU(True),
        nn.Dropout(dout),
        nn.Linear(ff_size,embed_size),
    )
  def forward(self,x,mask=None):
    """Run the block on ``x``.

    Args:
      x: token embeddings whose last dimension is ``embed_size``.
      mask: optional attention mask, passed unchanged to the attention layer.

    Returns:
      Tensor with the same shape as ``x``.
    """
    # Self-attention: the normalised input serves as query, key and value.
    x = self.residual(x,lambda h:self.attention(h,h,h,mask))
    x = self.residual_ff(x,self.feedforward)
    return x


def get_positional_embedding(seq,embed_size):
  """Build the fixed sinusoidal positional-embedding table.

  For position ``k`` and pair index ``i``:
    ``mat[k, 2*i]     = sin(k / 10000 ** (2*i / embed_size))``
    ``mat[k, 2*i + 1] = cos(k / 10000 ** (2*i / embed_size))``

  Args:
    seq: number of positions (rows).
    embed_size: embedding width (columns). When it is odd, the last column
      has no sin/cos pair and keeps its initial value of 1.

  Returns:
    A float32 CPU tensor of shape (seq, embed_size).
  """
  mat = torch.ones(seq,embed_size)
  half = embed_size // 2
  # Compute in float64 and cast once at the end so the values match a
  # scalar double-precision evaluation of the formula.
  positions = torch.arange(seq,dtype=torch.float64).unsqueeze(1)   # (seq, 1)
  pair_idx = torch.arange(half,dtype=torch.float64)                # (half,)
  denominator = torch.pow(10000.0,2 * pair_idx / embed_size)
  angles = positions / denominator                                 # (seq, half)
  mat[:,0:2*half:2] = torch.sin(angles).to(mat.dtype)
  mat[:,1:2*half:2] = torch.cos(angles).to(mat.dtype)
  return mat

class vit(nn.Module):
  """Small Vision-Transformer classifier for single-channel images.

  The image is cut into non-overlapping 4x4 patches by a strided convolution,
  each patch is projected to ``embed_size``, a learnable class token is
  prepended, sinusoidal positional embeddings are added, and the sequence is
  passed through a stack of encoder layers. The class-token output is mapped
  to ``out_size`` scores.

  Args:
    embed_size: token embedding width.
    dout: dropout probability used in the encoder layers.
    out_size: number of output classes.
    depth: number of stacked encoder layers.
  """
  def __init__(self,embed_size,dout=.1,out_size=10,depth=3):
    super().__init__()
    self.conv2d = nn.Conv2d(1,16,kernel_size=(4,4),stride=4,bias=False)
    self.embed_size = embed_size
    self.mapping = nn.Linear(16,embed_size)
    # Created on the CPU like every other parameter; ``model.to(...)`` moves
    # it, so the module does not depend on a global ``device``.
    self.cls_token = nn.Parameter(torch.rand(1,embed_size))
    self.EncodingLayer = nn.ModuleList([EncoderLayer(embed_size,dout) for _ in range(depth)])
    # The head returns raw logits: nn.CrossEntropyLoss applies log-softmax
    # itself, so a Softmax layer here would normalise twice and flatten the
    # gradients. argmax over logits equals argmax over probabilities.
    self.mlp = nn.Sequential(
        nn.Linear(embed_size,out_size),
    )
    # Cache for the positional table (plain attribute, not part of state_dict).
    self._pos_embed = None
  def _positional_embedding(self,seq,ref):
    """Return the (seq, embed_size) positional table on ``ref``'s device/dtype."""
    pos = self._pos_embed
    if pos is None or pos.size(0) != seq:
      pos = get_positional_embedding(seq,self.embed_size)
    pos = pos.to(device=ref.device,dtype=ref.dtype)
    self._pos_embed = pos
    return pos
  def forward(self,x,mask=None):
    """Classify a batch of images.

    Args:
      x: image batch of shape (B, 1, H, W).
      mask: optional attention mask forwarded to every encoder layer; it must
        account for the extra class token at sequence position 0.

    Returns:
      Unnormalised class scores (logits) of shape (B, out_size). Apply
      ``torch.softmax(..., dim=-1)`` to obtain probabilities.
    """
    B = x.size(0)
    x = self.conv2d(x)
    # (B, C, H', W') -> (B, H'*W', C): one token per patch.
    x = x.flatten(2).permute(0,2,1)
    x = self.mapping(x)
    x = torch.cat((self.cls_token.expand(B,1,-1),x),dim=1)

    x = x + self._positional_embedding(x.size(1),x)
    for encodingLayer in self.EncodingLayer:
      x = encodingLayer(x,mask)

    # Only the class-token position feeds the classification head.
    x = x[:,0]
    return self.mlp(x)

In [None]:
# Build the model and optimiser, then train for 20 epochs on the MNIST training split.
model = vit(8).to(device)
optimizer = torch.optim.Adam(model.parameters(),lr=0.005) #,betas=(0.9,0.999),weight_decay=.001)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=5,eta_min=0,last_epoch=-1)
criterion = nn.CrossEntropyLoss()

for epoch in range(20):
    # Make sure dropout is active even if an evaluation cell ran before this one.
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        x,y = batch
        x , y = x.to(device),y.to(device)
        pred = model(x)
        loss = criterion(pred,y)
        # Running mean of the per-batch losses over the epoch.
        total_loss += loss.detach().item() / len(train_loader)
        loss.backward()
        optimizer.step()
    # Report the epoch average rather than the loss of the last batch only.
    print(f"epoch {epoch}: loss {total_loss:.4f}")
    # scheduler.step()


 40%|█████████████████████████████████████████████████████████████████████▊                                                                                                         | 187/469 [00:07<00:10, 27.17it/s]

In [None]:
# Evaluate the model on the test loader: mean loss and top-1 accuracy.
model.eval()  # switch dropout layers to inference behaviour
with torch.no_grad():
    correct, total = 0, 0
    test_loss = 0.0
    num_test_batches = len(test_loader)
    for batch in tqdm(test_loader, desc="Testing"):
        x, y = (tensor.to(device) for tensor in batch)
        y_hat = model(x)
        loss = criterion(y_hat, y)
        # Accumulate the mean of the per-batch losses.
        test_loss += loss.item() / num_test_batches

        # A prediction is correct when the highest-scoring class equals the label.
        correct += (y_hat.argmax(dim=1) == y).sum().item()
        total += x.size(0)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")