In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random

# Model hyperparameters shared by every module defined in this file.
n_embed = 48      # width of each token embedding
head_size = 12    # width of one attention head (n_heads * head_size == n_embed)
n_heads = 4       # attention heads per block
n_layers = 7      # number of stacked transformer blocks
dropout = 0.4     # dropout probability used in the feed-forward sub-layer
mlp_ratio = 2     # hidden width of the MLPs, as a multiple of n_embed
# Use the GPU when one is present, otherwise fall back to the CPU so the
# file still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
block_size = 64   # number of rows in the positional-embedding table

class Patching(nn.Module):
    """Cut an image into non-overlapping square patches and embed each patch.

    A convolution whose stride equals its kernel size visits every
    ``patch_size`` x ``patch_size`` tile exactly once; the resulting grid of
    embeddings is then flattened into a sequence.
    """

    def __init__(self,
                 in_channels=3,
                 patch_size=4,
                 embedding_dim=n_embed,
                ):
        super().__init__()
        window = (patch_size, patch_size)
        to_patches = nn.Conv2d(in_channels, embedding_dim,
                               kernel_size=window, stride=window)
        # Merge the two spatial axes (dims 2 and 3) into one patch axis.
        flatten_grid = nn.Flatten(2, 3)
        self.patch = nn.Sequential(to_patches, flatten_grid)

    def forward(self, x):
        """Return patch embeddings with the patch axis before the channel axis.

        For a batched image tensor (batch, channels, height, width) the
        result is (batch, num_patches, embedding_dim).
        """
        tokens = self.patch(x)
        return tokens.transpose(-2, -1)

class Head(nn.Module):
    """Single self-attention head.

    Projects every token to key, query and value vectors of width
    ``head_size`` and returns, for each token, the attention-weighted sum of
    the value vectors.
    """

    def __init__(self, head_size):
        super().__init__()
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)

    def forward(self, x, attention_mask=None):
        """Apply scaled dot-product self-attention.

        Args:
            x: 3-D tensor (batch, tokens, channels).
            attention_mask: optional tensor, assumed to be (batch, tokens)
                with non-zero entries marking tokens that may be attended
                to. Tokens whose entry is 0 receive zero attention weight.

        Returns:
            Tensor of shape (batch, tokens, head_size).
        """
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Operand order (k @ q^T) is kept from the original so the roles of
        # the two learned projections do not change. The scale is the
        # per-head width, which is the length of the vectors being dotted.
        scores = torch.bmm(k, q.transpose(-2, -1)) * (k.shape[-1] ** -0.5)

        if attention_mask is not None:
            # (batch, 1, tokens): broadcast over rows so that the *attended*
            # positions are blocked. Masking has to happen on the logits;
            # multiplying them by 0 would still leave weight exp(0) after
            # the softmax.
            blocked = attention_mask.unsqueeze(1) == 0
            # The most negative finite value (rather than -inf) keeps a fully
            # masked row from turning into NaN.
            scores = scores.masked_fill(blocked, torch.finfo(scores.dtype).min)

        weights = F.softmax(scores, dim=-1)
        return torch.bmm(weights, v)

class MultiHead(nn.Module):
    """Run several attention heads side by side and mix their outputs.

    Each head sees the same input; the per-head results are concatenated
    along the last axis and passed through a linear projection.
    """

    def __init__(self, head_size, n_heads):
        super().__init__()
        self.heads = nn.ModuleList(Head(head_size) for _ in range(n_heads))
        self.proj = nn.Linear(n_embed, n_embed)

    def forward(self, x, attention_mask):
        """Return the projected concatenation of every head's output."""
        per_head = tuple(head(x, attention_mask) for head in self.heads)
        merged = torch.cat(per_head, dim=-1)
        return self.proj(merged)

class FeedForward(nn.Module):
    """Per-token MLP: widen by ``mlp_ratio``, ReLU, narrow back, dropout."""

    def __init__(self):
        super().__init__()
        hidden = n_embed * mlp_ratio
        widen = nn.Linear(n_embed, hidden)
        narrow = nn.Linear(hidden, n_embed)
        self.net = nn.Sequential(widen, nn.ReLU(), narrow, nn.Dropout(dropout))

    def forward(self, x):
        """Apply the MLP to the last axis of ``x``."""
        transformed = self.net(x)
        return transformed

class Block(nn.Module):
    """Transformer block: self-attention then feed-forward, each pre-normed.

    Both sub-layers are wrapped as ``x + sublayer(LayerNorm(x))`` so the
    residual path carries the un-normalised activations straight through.
    """

    def __init__(self):
        super().__init__()
        self.multihead = MultiHead(head_size,n_heads)
        self.ffwd = FeedForward()
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self,x,attention_mask=None):
        """Return ``x`` after the attention and feed-forward sub-layers.

        Args:
            x: token tensor whose last axis has width ``n_embed``.
            attention_mask: optional mask forwarded unchanged to the
                attention sub-layer.
        """
        # Normalise only the sub-layer input; adding the sub-layer output to
        # the normalised tensor would re-normalise the skip path every layer.
        x = x + self.multihead(self.ln1(x), attention_mask)
        x = x + self.ffwd(self.ln2(x))
        return x


class ViT(nn.Module):
    """Small Vision Transformer classifier with learned sequence pooling.

    Pipeline: patch embedding -> add learned positional embedding ->
    ``n_layers`` transformer blocks -> LayerNorm -> softmax-weighted average
    over tokens -> MLP head producing 10 class scores.
    """

    def __init__(self):
        super().__init__()
        self.patch_embedding = Patching()
        self.positional_embedding = nn.Embedding(block_size,n_embed)
        self.blocks = nn.ModuleList([Block() for _ in range(n_layers)])
        self.ln = nn.LayerNorm(n_embed)
        self.cl_head = nn.Sequential(
            nn.Linear(n_embed,n_embed * mlp_ratio),
            nn.ReLU(),
            nn.Linear(n_embed * mlp_ratio,10)
        )
        # Scores each token; the softmax of these scores weights the pooling.
        self.sequence_pooling = nn.Linear(n_embed,1)

    def forward(self,x,attention_mask=None,targets=None):
        """Classify a batch of images.

        Args:
            x: image batch accepted by the patch embedding.
            attention_mask: optional mask passed to every transformer block.
            targets: accepted for interface compatibility; not used here.

        Returns:
            Tensor of shape (batch, 10) with one score per class.

        Raises:
            ValueError: if the image yields more patches than the positional
                embedding table has rows.
        """
        tokens = self.patch_embedding(x)  # (B, N, D)
        num_tokens = tokens.shape[1]
        max_tokens = self.positional_embedding.num_embeddings
        if num_tokens > max_tokens:
            raise ValueError(
                f"input produces {num_tokens} patches but the positional "
                f"embedding only covers {max_tokens}"
            )

        # Build the indices on the same device as the activations instead of
        # relying on a module-level device setting.
        positions = torch.arange(num_tokens, device=tokens.device)
        x = tokens + self.positional_embedding(positions)

        for block in self.blocks:
            x = block(x,attention_mask)
        x = self.ln(x)  # (B, N, D)

        # Sequence pooling: one weight per token, normalised over the tokens.
        pool_logits = self.sequence_pooling(x).transpose(-2,-1)  # (B, 1, N)
        pool_weights = F.softmax(pool_logits,dim=2)  # (B, 1, N)
        x = torch.bmm(pool_weights,x).squeeze(1)  # (B, D)
        return self.cl_head(x)

In [39]:
# Smoke test on the CPU: build the model and push one random batch of
# 32 RGB images of size 32x32 through it.
device = 'cpu'
model = ViT()
dummy_batch = torch.randn(32, 3, 32, 32)
out = model(dummy_batch)

In [41]:
# Inspect the result shape: one row of 10 class scores per input image.
out.shape

torch.Size([32, 10])