<a href="https://colab.research.google.com/github/NidaNabi/simple-vision-transformer/blob/main/visiontrans.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torchvision.datasets import FashionMNIST
import torch.utils.data as Data
import torchvision.transforms as transforms
# FashionMNIST training split, stored under the current directory and
# downloaded on first use; every image is converted to a tensor.
train_data = FashionMNIST(
    root="./",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Shuffled mini-batches loaded in the main process (num_workers=0).
# A full image batch is torch.Size([64, 1, 28, 28]).
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)

class Attention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings.

    Args:
        dim: Embedding size of every token; must be divisible by ``head_num``.
        head_num: Number of attention heads the embedding is split into.
        drop1: Dropout probability applied to the fused q/k/v projection.
        drop2: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``dim`` is not divisible by ``head_num``.
    """

    def __init__(self, dim=768, head_num=8, drop1=0., drop2=0.):
        super(Attention, self).__init__()
        if dim % head_num != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by head_num ({head_num})"
            )
        self.head_num = head_num
        # One linear layer produces q, k and v together (3 * dim features).
        self.linear = nn.Linear(dim, dim * 3)
        self.W0 = nn.Linear(dim, dim)
        self.drop1 = nn.Dropout(drop1)
        self.drop2 = nn.Dropout(drop2)
        # Scale factor 1 / sqrt(head_dim); the scores are MULTIPLIED by it.
        self.d = (dim / head_num) ** -0.5

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, tokens, dim).

        Returns:
            A tensor with the same shape as ``x``.
        """
        batch, N, C = x.shape
        qkv = self.drop1(self.linear(x))
        # (batch, N, 3*C) -> (3, batch, heads, N, head_dim)
        qkv = qkv.view(batch, N, 3, self.head_num, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product: softmax(q k^T / sqrt(head_dim)) v
        scores = (q @ k.transpose(-1, -2)) * self.d
        weights = nn.functional.softmax(scores, dim=-1)
        out = weights @ v
        # (batch, heads, N, head_dim) -> (batch, N, C): concatenate the heads.
        out = out.transpose(1, 2).reshape(batch, N, C)
        return self.drop2(self.W0(out))

class Encoder_block(nn.Module):
    """Pre-norm transformer encoder block: attention and MLP, each residual.

    Args:
        drop_attention: Dropout probability on the attention branch output.
        drop_mlp: Dropout probability on the MLP branch output.
        dim: Token embedding size; it sizes the layer norm, the attention
            layer and the MLP so the block works for widths other than 768.
    """

    def __init__(self, drop_attention, drop_mlp, dim):
        super(Encoder_block, self).__init__()
        # NOTE: a single LayerNorm is applied before both sub-layers, so both
        # normalisations share the same learned scale and bias.
        self.layer_norm = nn.LayerNorm(dim)
        self.attention = Attention(dim=dim)
        self.drop_attention = nn.Dropout(drop_attention)
        self.mlp = MLP(dim=dim)
        self.drop_mlp = nn.Dropout(drop_mlp)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        # Attention sub-layer with residual connection.
        attended = self.attention(self.layer_norm(x))
        z = x + self.drop_attention(attended)
        # Feed-forward sub-layer with residual connection.
        fed = self.mlp(self.layer_norm(z))
        return z + self.drop_mlp(fed)


class MLP(nn.Module):
    """Two-layer feed-forward network with GELU used inside the encoder.

    Args:
        dim: Size of the input and output features.
        hidden_dim: Size of the hidden layer.
        drop: Dropout probability applied after the activation.
    """

    def __init__(self, dim=768, hidden_dim=1500, drop=0.2):
        super(MLP, self).__init__()
        self.l = nn.Sequential(nn.Linear(dim, hidden_dim),
                               nn.GELU(),
                               nn.Dropout(drop),
                               nn.Linear(hidden_dim, dim),
                               # No-op dropout kept so layer indices are stable.
                               nn.Dropout(0.))

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``."""
        return self.l(x)

class VIT(nn.Module):
    """Small Vision Transformer classifier for square images.

    Args:
        batchsize: Unused; kept only so existing callers keep working. The
            class token and position embedding are no longer tied to a batch
            size, so any batch size (including a short final batch) works.
        dim: Token embedding size.
        drop_pos: Dropout probability applied after adding position embeddings.
        drop_attention: Dropout probability on each attention branch.
        drop_mlp: Dropout probability on each MLP branch.
        classes: Number of output classes.
        img_size: Height/width of the input images.
        patch_size: Side of each square patch (conv kernel size and stride).
        in_channels: Number of input image channels.
        depth: How many times the encoder block is applied.
    """

    def __init__(self, batchsize=64, dim=768, drop_pos=0., drop_attention=0.,
                 drop_mlp=0., classes=10, img_size=28, patch_size=2,
                 in_channels=1, depth=12):
        super(VIT, self).__init__()
        num_patches = (img_size // patch_size) ** 2  # 196 for 28x28 / 2x2
        self.depth = depth
        # Leading dimension is 1 so the parameters broadcast over the batch.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        # Non-overlapping patches -> dim-sized embeddings via a strided conv.
        self.embeding = nn.Sequential(
            nn.Conv2d(in_channels, dim, patch_size, patch_size))
        self.pos = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        self.pos_drop = nn.Dropout(drop_pos)
        # NOTE: one block instance is reused at every depth step, so all
        # layers share the same weights.
        self.encoder_block = Encoder_block(drop_attention, drop_mlp, dim)
        self.layer_norm = nn.LayerNorm(dim)
        self.mlphead = nn.Sequential(nn.Linear(dim, 2000),
                                     nn.Tanh(),
                                     nn.Linear(2000, classes))

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape (batch, in_channels, img_size, img_size).

        Returns:
            Logits of shape (batch, classes).
        """
        # (B, dim, H', W') -> (B, dim, patches) -> (B, patches, dim)
        x = self.embeding(x).flatten(2).transpose(1, 2)
        # Prepend one class token per sample, whatever the batch size is.
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)  # (B, patches + 1, dim)
        x = self.pos_drop(x + self.pos)
        for _ in range(self.depth):
            x = self.encoder_block(x)
        x = self.layer_norm(x)
        # Only the class-token position feeds the classification head.
        return self.mlphead(x[:, 0])
# Model with default hyper-parameters, Adam over all of its parameters,
# and cross-entropy as the classification loss.
vit = VIT()
opt = torch.optim.Adam(params=vit.parameters(), lr=1e-3)
loss = nn.CrossEntropyLoss()
def train(epoch,model,loader,optim,loss):
    """Train ``model`` and print the loss of every mini-batch.

    Args:
        epoch: Number of full passes over ``loader``.
        model: Module to optimise; it is switched to training mode.
        loader: Iterable yielding ``(inputs, targets)`` pairs.
        optim: Optimizer that updates the model parameters.
        loss: Callable computing the loss from ``(predictions, targets)``.
    """
    model.train()
    for _ in range(epoch):
        for x, y in loader:
            y_p = model(x)
            batch_loss = loss(y_p, y)
            optim.zero_grad()
            batch_loss.backward()
            # Step the optimizer that was passed in, not a module-level global.
            optim.step()
            print(batch_loss.item())


# Run ten epochs of training over the FashionMNIST loader.
train(epoch=10, model=vit, loader=train_loader, optim=opt, loss=loss)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./FashionMNIST/raw/train-images-idx3-ubyte.gz to ./FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./FashionMNIST/raw

2.3075625896453857
4.685646057128906
4.172881603240967
3.2076404094696045
6.592207431793213
4.815701484680176
4.107974529266357
4.20037841796875
3.437725305557251
3.6584601402282715
3.4304473400115967
2.9258852005004883
2.862039089202881
2.349529266357422
2.5239665508270264
3.1689720153808594
2.8336949348449707
2.3147199153900146
2.404148578643799
2.5409023761749268
2.464268207550049
2.3830718994140625
2.414585590362549
2.373171329498291
2.51781964302063
2.322843551635742
2.539923906326294
2.4608092308044434
2.554372787475586
2.4675300121307373
