In [61]:
import torch
from torch import nn
import lightning
from einops import rearrange
from einops import repeat
from torch.nn import functional as F

In [44]:
class ImagePathces(nn.Module):
    def __init__(self, img_size, in_ch:int = 3,patch_size: int = 16):
        super().__init__()
        assert img_size % patch_size == 0

        emb_dim = patch_size * patch_size * in_ch
        self.cls_token = nn.Parameter(torch.randn(1,1,emb_dim))
        self.pos_emb = nn.Parameter(torch.randn(1,1+(img_size//patch_size)**2,emb_dim))
        self.patches = nn.Conv2d(in_ch,emb_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.patches(x)

        # x = torch.flatten(x,2)
        # b, e, hw = x.shape
        # x = x.reshape((b,hw,e))
        x = rearrange(x, "b e h w -> b (h w) e")
        
        b, _, _ = x.shape

        #cls_token = self.cls_token.repeat((b,1,1))
        cls_token = repeat(self.cls_token,"() s e -> b s e",b=b)

        x = torch.concat((x, cls_token),dim=1)

        return self.pos_emb + x

In [81]:
patch = ImagePathces(224,3)
x = torch.randn(2, 3, 224, 224)
patches = patch(x)
patches.shape

torch.Size([2, 197, 768])

In [50]:
class MLP(nn.Module):
    def __init__(self,input_size: int, hidden_size: int = None, output_size: int = None, droput = 0.,layer_act: torch.nn = nn.GELU) -> None:
        super().__init__()
        if hidden_size is None:
            hidden_size = input_size
        if output_size is None:
            output_size = input_size

        self.fc1 = nn.Linear(input_size,hidden_size)
        self.fc2 = nn.Linear(hidden_size,output_size)
        self.actv = layer_act()
        self.drop = nn.Dropout(droput)

    def forward(self,x):
        return self.drop(self.fc2(self.actv(self.fc1(x))))

In [55]:
mlp = MLP(234,345,10,droput=0.2,layer_act=nn.ReLU)

x = torch.randn((2,234))

mlp(x).shape

torch.Size([2, 10])

In [189]:
class Attention(nn.Module):
    def __init__(self,dim: int,qkv_bias: bool = False, num_heads: int =8, attn_drop=0., out_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim,dim*3,bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_drop = nn.Dropout(out_drop)
        self.out_lin = nn.Linear(dim,dim)

    def forward(self,x):
        x = self.attn_drop(self.qkv(x))
        b, e, _ = x.shape
        # b e (q k v)
        x = x.reshape((b,3,self.num_heads,e,-1))
        # b qkv heads e s 
        
        att = F.log_softmax((x[:,0] @ x[:,1].transpose(2,3))*self.scale,dim=3)

        att = att @ x[:,2]
        
        #out = out.reshape(b,e,-1)
        att = rearrange(att, "b h n e -> b n (h e)")

        out = self.out_drop(self.out_lin(att))

        return out

In [188]:
att = Attention(768)
att(patches).shape

torch.Size([2, 197, 768])

In [248]:
class Block(nn.Module):
    def __init__(self, dim, qkv_bias: bool = False, num_heads:int = 8, mlp_ratio:int = 4, drop_rate: float = 0.) -> None:
        super().__init__()
        self.att = Attention(dim,num_heads=num_heads,qkv_bias=qkv_bias)

        self.att_norm = nn.LayerNorm(dim)

        self.mlp = MLP(dim,dim*mlp_ratio,dim)
        
        self.mlp_norm = nn.LayerNorm(dim)
        

    def forward(self, x):
        att = self.att(x)
        att_out = self.att_norm(x + att)

        mlp = self.mlp(att_out)
        mlp_out = self.mlp_norm(x + mlp)
        return mlp_out

In [212]:
bl = Block(768)

bl(patches).shape

torch.Size([2, 197, 768])

In [213]:
class Encoder(nn.Module):
    def __init__(self, dim, qkv_bias: bool = False, N:int = 6, num_heads: int = 8, mlp_ratio: int = 4, drop_rate: float = 0.) -> None:
        super().__init__()
        self.blocks =nn.ModuleList([
            Block(dim,qkv_bias,num_heads,mlp_ratio,drop_rate)
        for _ in range(N)
        ])

    def forward(self,x):
        for block in self.blocks:
            x = block(x)
        return x

In [215]:
enc = Encoder(768)

enc(patches).shape

torch.Size([2, 197, 768])

In [251]:
class ViT(nn.Module):
    def __init__(self, img_size: int, patch_size: int, num_classes: int, 
                 in_ch: int = 3, depth: int=6, num_heads: int = 8, 
                 mlp_ratio: int = 4, qkv_bias=False, drop_rate: float = 0.) -> None:
        super().__init__()

        emb_dim = in_ch*patch_size**2
        
        self.patch = ImagePathces(img_size = img_size,
                                  in_ch = in_ch,
                                  patch_size = patch_size)

        self.encoder = Encoder(dim = emb_dim,
                               qkv_bias = qkv_bias,
                               N = depth,
                               num_heads = num_heads,
                               mlp_ratio = mlp_ratio,
                               drop_rate = drop_rate)
        
        self.mlp_head = nn.Linear(emb_dim,num_classes)
    
    def forward(self,x):
        x = self.patch(x)

        x = self.encoder(x)

        x = self.mlp_head(x)
        return F.softmax(x,dim=2)

In [252]:
x = torch.randn(2, 3, 224, 224)

In [253]:
x.shape

torch.Size([2, 3, 224, 224])

In [260]:
vit = ViT(img_size=224,
          patch_size=16,
          num_classes=10)

### Training

In [3]:
from typing import List, Union
from cifar_cnn.data.data_loader import CIFAR10DataModule
#from cifar_cnn.models.cnn import CIFAR10Model
from cifar_cnn.models.vit import ViT
import torch
import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
import wandb

ACCELERATOR = "gpu" if torch.cuda.is_available() else "cpu"
CONF = {
    "architecture": "ViT", 
    "batch_size": 20
    }


class ImageCallback(L.Callback):
    def __init__(self) -> None:
        super().__init__()
        self.outputs = None
        self.x = None
        self.y = None


    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx == trainer.num_training_batches-1:
            self.x, self.y = batch
            self.outputs = torch.argmax(outputs["prediction"],dim=1)


    def on_train_epoch_end(self, trainer, pl_module):
        n = 10
        x, y = self.x, self.y

        images = [img for img in x[:n]]
        captions = [f'Target: {y_i} - Prediction: {y_pred}' 
            for y_i, y_pred in zip(y[:n], self.outputs[:n])]

        trainer.logger.log_image(
                key='sample_images', 
                images=images, 
                caption=captions)


callbacks =[
    LearningRateMonitor(logging_interval='step'),
    ImageCallback()  
]

   
def train(epoch: int = 10,
          device: str = "auto",
          lr: float = 2e-3,
          path: str  = "/CIFAR10/datasets/raw") -> None:

    data_module = CIFAR10DataModule(path)
    model_module = ViT(
        img_size=32,
        patch_size=8,
        num_classes=10,
        in_ch=3,
        num_heads= 2
    )
    
    LOGGER = WandbLogger(log_model=True, name=f"{CONF['architecture']}-lr({lr})-epoch({epoch})")
    try:
        trainer = L.Trainer(
            accelerator=ACCELERATOR,
            devices=device,
            max_epochs=epoch,
            callbacks= callbacks
        )
        wandb.run

        trainer.fit(model_module, datamodule = data_module)           
        trainer.test(model_module,datamodule = data_module)

        wandb.finish(0)
    except RuntimeError:
        wandb.finish(1)
        

In [4]:
train(epoch=8,lr=1e-5)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/root/miniconda3/envs/dl_env/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


Files already downloaded and verified
Files already downloaded and verified
TrainerFn.FITTING


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type         | Params
------------------------------------------
0 | patch    | ImagePathces | 40.5 K
1 | encoder  | Encoder      | 2.7 M 
2 | mlp_head | Linear       | 1.9 K 
------------------------------------------
2.7 M     Trainable params
0         Non-trainable params
2.7 M     Total params
10.833    Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]