<a href="https://colab.research.google.com/github/alecbidaran/Pytorch_excersies/blob/main/visual_trasformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch-snippets

Collecting torch-snippets
  Downloading torch_snippets-0.426-py3-none-any.whl (25 kB)
Collecting loguru
  Downloading loguru-0.5.3-py3-none-any.whl (57 kB)
[K     |████████████████████████████████| 57 kB 3.6 MB/s 
Collecting rich
  Downloading rich-10.6.0-py3-none-any.whl (208 kB)
[K     |████████████████████████████████| 208 kB 9.3 MB/s 
Collecting fastcore
  Downloading fastcore-1.3.20-py3-none-any.whl (53 kB)
[K     |████████████████████████████████| 53 kB 2.4 MB/s 
Collecting colorama<0.5.0,>=0.4.0
  Downloading colorama-0.4.4-py2.py3-none-any.whl (16 kB)
Collecting commonmark<0.10.0,>=0.9.0
  Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)
[K     |████████████████████████████████| 51 kB 6.6 MB/s 
[?25hInstalling collected packages: commonmark, colorama, rich, loguru, fastcore, torch-snippets
Successfully installed colorama-0.4.4 commonmark-0.9.1 fastcore-1.3.20 loguru-0.5.3 rich-10.6.0 torch-snippets-0.426


In [None]:
import torch 
from torchvision import transforms,datasets
from torch_snippets import *
import numpy as np 
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'
import torchvision.models as models
from collections import OrderedDict

In [None]:
def _make_transform(max_rotation=None):
  """Build one preprocessing pipeline.

  Steps, in order: resize to 32x32, optional ``RandomRotation(max_rotation)``
  (skipped when ``max_rotation`` is None), ``ToTensor``, then per-channel
  normalisation with mean 0.5 and std 0.5.
  """
  steps=[transforms.Resize((32,32))]
  if max_rotation is not None:
    steps.append(transforms.RandomRotation(max_rotation))
  steps.append(transforms.ToTensor())
  steps.append(transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)))
  return transforms.Compose(steps)

# NOTE(review): RandomRotation takes degrees, so 0.2 / 0.4 are almost no rotation —
# confirm that larger angles (or radians) were not intended.
train_transform=_make_transform(0.2)
valid_transform=_make_transform()
# NOTE(review): this pipeline is random at test time and is not referenced in the
# visible cells — confirm whether it is still needed.
test_transform=_make_transform(0.4)

In [None]:
# CIFAR-100 training split, with the augmenting pipeline.
train_dataset=datasets.CIFAR100(root="./data",download=True,train=True,transform=train_transform)
# The held-out split (train=False) is used for validation. It must go through the
# deterministic `valid_transform`; the original passed `train_transform`, which applied
# random rotation to the evaluation images and left `valid_transform` unused.
valid_dataset=datasets.CIFAR100(root="./data",download=True,train=False,transform=valid_transform)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# ResNet-18 with pretrained weights; parts of it are reused as a frozen feature stem.
resnet=models.resnet18(pretrained=True)
# Freeze every weight so the optimizer never updates the pretrained layers.
# (The original iterated over the misspelled name `resenet`, which raises NameError
# on a fresh kernel.)
for params in resnet.parameters():
  params.requires_grad=False
print(resnet)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
class VIT(torch.nn.Module):
  """Vision-Transformer-style model over 16x16 patches of 32x32 RGB images.

  Pipeline: cut the image into non-overlapping 16x16 patches, linearly project
  each flattened patch to 768 features, add a learned positional embedding,
  run 12 Transformer encoder layers, then apply an MLP head to every patch token.

  Args:
    num_classes: size of the final linear layer (number of output logits).

  forward(x):
    x: image batch of shape (batch, 3, 32, 32).
    Returns per-patch logits of shape (batch, n_patches, num_classes). Callers
    that need one prediction per image must reduce over the patch axis
    (for example with ``.mean(dim=1)``) before a classification loss.

  Raises:
    ValueError: if the input yields more patches than there are positional
      embeddings (4, i.e. inputs larger than 32x32).

  Note: ``self.backbone`` is built from the module-level ``resnet`` and kept as
  an attribute, but it is not applied in ``forward``; ``self.flatten`` is unused too.
  """
  def __init__(self,num_classes=100):
    super(VIT,self).__init__()
    patch_size=16
    embed_dim=768
    n_patches=(32//patch_size)**2
    # Early stages of the module-level `resnet`, collected in their original order.
    layers=OrderedDict()
    for name,modules in resnet.named_children():
      if name in ['conv1','bn1','relu','maxpool','layer1',]:
        layers[name]=modules
    self.backbone=torch.nn.Sequential(layers)
    self.flatten=torch.nn.Flatten(2,3)
    # padding=0 tiles the image exactly; the previous padding=1 shifted the patch
    # grid by one pixel (a zero border row/column in, the last row/column out).
    self.patches=torch.nn.Unfold(kernel_size=patch_size,stride=patch_size,padding=0)
    # Learned positional embedding, one vector per patch position.
    self.embedding=torch.nn.Embedding(n_patches,embed_dim)
    self.transformer_encoder=torch.nn.Sequential(*[torch.nn.TransformerEncoderLayer(embed_dim,12,dim_feedforward=3072,dropout=0.1) for _ in range(12)])
    # One flattened RGB patch has 3*16*16 = 768 values.
    self.projection=torch.nn.Linear(3*patch_size*patch_size,embed_dim)
    # LayerNorm without affine parameters normalises each token over its features,
    # which is what InstanceNorm1d(768) was computing on (batch, tokens, 768) input
    # (there it treated the token axis as channels). No parameters are added.
    self.mlp=torch.nn.Sequential(torch.nn.LayerNorm(embed_dim,elementwise_affine=False),
                                 torch.nn.Linear(embed_dim,embed_dim),
                                 torch.nn.GELU(),
                                 torch.nn.Dropout(0.2),
                                 torch.nn.Linear(embed_dim,num_classes))
  def forward(self,x):
    # (batch, 768, n_patches) -> (batch, n_patches, 768)
    tokens=self.patches(x).transpose(-2,-1)
    n_tokens=tokens.size(1)
    if n_tokens>self.embedding.num_embeddings:
      raise ValueError("input produces %d patches but only %d positional embeddings exist"
                       %(n_tokens,self.embedding.num_embeddings))
    tokens=self.projection(tokens)
    # Positions are created on the input's own device, so no global `device` is needed.
    positions=torch.arange(n_tokens,device=tokens.device)
    tokens=tokens+self.embedding(positions)
    # The encoder layers use the default (sequence, batch, features) layout, so swap
    # the first two axes around the call; otherwise attention mixes different images
    # of the batch instead of the patches of one image.
    encoded=self.transformer_encoder(tokens.transpose(0,1)).transpose(0,1)
    return self.mlp(encoded)

In [None]:
from torchsummary import summary

In [None]:
# Build the 100-class network, move it to the selected device, then print a
# layer-by-layer summary for a single 3x32x32 input.
model=VIT(num_classes=100)
model=model.to(device)
summary(model,(3,32,32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Unfold-1               [-1, 768, 4]               0
            Linear-2               [-1, 4, 768]         590,592
MultiheadAttention-3  [[-1, 4, 768], [-1, 2, 2]]               0
           Dropout-4               [-1, 4, 768]               0
         LayerNorm-5               [-1, 4, 768]           1,536
            Linear-6              [-1, 4, 3072]       2,362,368
           Dropout-7              [-1, 4, 3072]               0
            Linear-8               [-1, 4, 768]       2,360,064
           Dropout-9               [-1, 4, 768]               0
        LayerNorm-10               [-1, 4, 768]           1,536
TransformerEncoderLayer-11               [-1, 4, 768]               0
MultiheadAttention-12  [[-1, 4, 768], [-1, 2, 2]]               0
          Dropout-13               [-1, 4, 768]               0
        LayerNorm-14          

In [None]:
# Adam over everything `model.parameters()` yields, with step size 1e-3,
# and a cross-entropy objective.
optimizer=torch.optim.Adam(params=model.parameters(),lr=1e-3)
loss_fn=torch.nn.CrossEntropyLoss()

In [None]:

def train_batch(batch,model,loss_fn,optimizer):
  """Run one optimisation step on a single batch.

  Args:
    batch: ``(image, label)`` pair; ``label`` holds integer class indices of shape (batch,).
    model: network returning logits of shape (batch, classes) or per-token logits
      of shape (batch, tokens, classes).
    loss_fn: criterion called as ``loss_fn(logits, label)``.
    optimizer: optimizer that owns the model's parameters.

  Returns:
    ``(loss, accuracy)`` for this batch as Python floats.
  """
  # Evaluation code may have left the model in eval mode; re-enable dropout for training.
  model.train()
  image,label=batch
  # Move the batch to wherever the model's weights live.
  first_param=next(model.parameters(),None)
  target_device=first_param.device if first_param is not None else image.device
  image,label=image.to(target_device),label.to(target_device)
  optimizer.zero_grad()
  logit=model(image)
  if logit.dim()==3:
    # Per-token logits: average over the token axis to get one score vector per image.
    logit=logit.mean(dim=1)
  # Class-index targets. One-hot Long targets against (batch, tokens, classes) logits
  # made the loss treat the token axis as the class axis.
  loss=loss_fn(logit,label)
  loss.backward()
  optimizer.step()
  # Fraction of images whose highest-scoring class equals the label.
  acc=(logit.argmax(dim=1)==label).float().mean()
  return loss.item(),acc.item()
@torch.no_grad()
def valid_batch(batch,model,loss_fn):
  """Evaluate a single batch without tracking gradients.

  Puts ``model`` in eval mode (it is not switched back afterwards).

  Args:
    batch: ``(image, label)`` pair; ``label`` holds integer class indices of shape (batch,).
    model: network returning logits of shape (batch, classes) or per-token logits
      of shape (batch, tokens, classes).
    loss_fn: criterion called as ``loss_fn(logits, label)``.

  Returns:
    ``(loss, accuracy)`` for this batch as Python floats.
  """
  model.eval()
  image,label=batch
  # Move the batch to wherever the model's weights live.
  first_param=next(model.parameters(),None)
  target_device=first_param.device if first_param is not None else image.device
  image,label=image.to(target_device),label.to(target_device)
  logit=model(image)
  if logit.dim()==3:
    # Per-token logits: average over the token axis to get one score vector per image.
    logit=logit.mean(dim=1)
  # Class-index targets, not one-hot vectors.
  loss=loss_fn(logit,label)
  # Fraction of images whose highest-scoring class equals the label.
  acc=(logit.argmax(dim=1)==label).float().mean()
  return loss.item(),acc.item()

In [None]:
# Batches of 64 for both splits; only the training loader reshuffles.
trn_dl,val_dl=(
  torch.utils.data.DataLoader(split,batch_size=64,shuffle=reshuffle)
  for split,reshuffle in ((train_dataset,True),(valid_dataset,False))
)

In [None]:
n_epoch=100
lr_drop_epoch=10    # once the epoch with this index has finished, train at the reduced rate
reduced_lr=0.0001
log=Report(n_epoch)
for epochs in range(n_epoch):
  # The validation pass below leaves the model in eval mode; switch back so dropout
  # is active again for this epoch's training batches.
  model.train()
  N = len(trn_dl)
  for i,data in enumerate(trn_dl):
    loss,acc=train_batch(data,model,loss_fn,optimizer)
    # Fractional epoch position lets the report place each batch within the epoch.
    log.record(epochs+(i+1)/N,trn_loss=loss,trn_acc=acc,end='\r')
  N = len(val_dl)
  for b,data in enumerate(val_dl):
    loss,acc=valid_batch(data,model,loss_fn)
    log.record(epochs+(b+1)/N,val_loss=loss,val_acc=acc,end='\r')
  if epochs>=lr_drop_epoch:
    # Lower the learning rate on the existing optimizer. Building a new Adam here,
    # as before, threw away its running moment estimates after every epoch.
    for group in optimizer.param_groups:
      group['lr']=reduced_lr
  log.report_avgs(epochs+1)

EPOCH: 1.000	trn_loss: 1.388	trn_acc: 0.250	val_loss: 1.386	val_acc: 0.361	(161.19s - 15957.80s remaining)
EPOCH: 2.000	trn_loss: 1.386	trn_acc: 0.459	val_loss: 1.386	val_acc: 0.544	(320.03s - 15681.63s remaining)
EPOCH: 3.000	trn_loss: 1.386	trn_acc: 0.545	val_loss: 1.386	val_acc: 0.672	(478.91s - 15484.63s remaining)
EPOCH: 4.000	trn_loss: 1.386	trn_acc: 0.605	val_loss: 1.386	val_acc: 0.560	(637.90s - 15309.71s remaining)
EPOCH: 5.000	trn_loss: 1.386	trn_acc: 0.647	val_loss: 1.386	val_acc: 0.672	(796.71s - 15137.42s remaining)
EPOCH: 6.000	trn_loss: 1.386	trn_acc: 0.684	val_loss: 1.386	val_acc: 0.684	(955.46s - 14968.86s remaining)
EPOCH: 7.000	trn_loss: 1.386	trn_acc: 0.700	val_loss: 1.386	val_acc: 0.677	(1114.27s - 14803.83s remaining)
EPOCH: 8.000	trn_loss: 1.386	trn_acc: 0.719	val_loss: 1.386	val_acc: 0.759	(1272.97s - 14639.20s remaining)
EPOCH: 9.000	trn_loss: 1.386	trn_acc: 0.754	val_loss: 1.386	val_acc: 0.821	(1431.69s - 14475.93s remaining)
EPOCH: 10.000	trn_loss: 1.386	trn_

KeyboardInterrupt: ignored

In [None]:
# Plot the metrics recorded in `log` during the training loop.
# NOTE(review): `log=True` presumably selects a logarithmic axis — confirm against
# the torch_snippets Report documentation.
log.plot_epochs(log=True)

In [None]:
# Accuracy on the first validation batch.
model.eval()                      # turn dropout off for evaluation
img,label=next(iter(val_dl))
with torch.no_grad():             # no gradients are needed here
  logits=model(img.to(device))
if logits.dim()==3:
  # Per-token logits: average over the token axis to get one score vector per image.
  logits=logits.mean(dim=1)
pred=logits.argmax(dim=1)
# Compare predicted class indices with the integer labels. Comparing argmax indices
# with one-hot vectors, as before, mostly matched zeros and reported ~0.99 regardless
# of the model.
acc=(pred==label.to(device)).float().mean()
acc

tensor(0.9900, device='cuda:0')

# BERT