In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Normalize, RandomHorizontalFlip, Resize, ToTensor
from transformers import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import warnings
# Install a global "ignore" filter: every warning category from every module
# is silenced for the rest of the kernel session (presumably to keep cell
# output uncluttered). NOTE(review): this also hides deprecation and numerical
# warnings that may matter while debugging.
warnings.filterwarnings("ignore")

### Patch Embedding

In [None]:
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    The patching and the linear projection are done in one step by a
    ``Conv2d`` whose kernel size and stride both equal ``patch_size``.

    Args:
        img_size: Side length of the (square) input image, in pixels.
        patch_size: Side length of each square patch, in pixels.
        in_chans: Number of input image channels.
        embed_dim: Size of the embedding produced for every patch. Defaults
            to 768, the same default width used by the attention modules in
            this notebook (the previous default of 786 was a typo and is not
            divisible by the default 12 attention heads).

    Attributes:
        n_patches: ``(img_size // patch_size) ** 2``, the number of patches
            produced for an ``img_size`` x ``img_size`` input.
        proj: The patch-projection convolution.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # Floor division matches what the strided convolution emits per side.
        self.n_patches = (img_size // patch_size) ** 2

        # kernel == stride -> each output position sees exactly one patch.
        self.proj = nn.Conv2d(in_channels=in_chans,
                              out_channels=embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Batched image tensor laid out as (batch, in_chans, H, W).

        Returns:
            Tensor of shape (batch, n_patches, embed_dim).
        """
        x = self.proj(x)       # (batch, embed_dim, H // patch, W // patch)
        x = x.flatten(2)       # (batch, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (batch, n_patches, embed_dim)
        return x

### Attention

In [None]:

class Head(nn.Module):
    """A single self-attention head.

    Projects the input to queries, keys and values of width ``head_dim`` and
    returns the attention-weighted average of the values.

    Args:
        embed_dim: Width of the incoming token embeddings.
        head_dim: Width of the query/key/value projections (and of the output).
        attn_p: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim=768, head_dim=768, attn_p=0):
        super().__init__()
        self.query = nn.Linear(embed_dim, head_dim)
        self.key = nn.Linear(embed_dim, head_dim)
        self.value = nn.Linear(embed_dim, head_dim)
        self.attn_dropout = nn.Dropout(attn_p)

    def forward(self, x):
        """Apply scaled dot-product self-attention.

        Args:
            x: Tensor of shape (batch, n_patch, embed_dim).

        Returns:
            Tensor of shape (batch, n_patch, head_dim).
        """
        q = self.query(x)  # (batch, n_patch, head_dim)
        k = self.key(x)    # (batch, n_patch, head_dim)
        v = self.value(x)  # (batch, n_patch, head_dim)

        # Scale by the key/query width (head_dim), not by the input embedding
        # width: the dot product sums over head_dim terms, so 1/sqrt(head_dim)
        # is what keeps the logits at unit variance. Using embed_dim here
        # under-scaled the logits whenever head_dim != embed_dim.
        scale = k.shape[-1] ** -0.5
        sam = (q @ k.transpose(-2, -1)) * scale  # (batch, n_patch, n_patch)
        attn = sam.softmax(dim=-1)
        attn = self.attn_dropout(attn)
        weighted_average = attn @ v  # (batch, n_patch, head_dim)
        return weighted_average
    
class MultiHeadedAttention(nn.Module):
    """Multi-head self-attention built from independent ``Head`` modules.

    Each head maps the input to ``embed_dim // num_heads`` features; the head
    outputs are concatenated along the feature axis and passed through a final
    linear projection followed by dropout.

    Args:
        embed_dim: Width of the token embeddings (input and output).
        num_heads: Number of attention heads. Must divide ``embed_dim``.
        attn_p: Dropout probability applied to each head's attention weights.
        proj_p: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim=768, num_heads=12, attn_p=0, proj_p=0):
        super().__init__()
        # Without this check the floor division below silently drops features
        # and the concatenated heads no longer match ``proj``'s input width,
        # which only surfaces as a shape error on the first forward pass.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.head_size = embed_dim // num_heads
        self.heads = nn.ModuleList([Head(embed_dim=embed_dim, head_dim=self.head_size, attn_p=attn_p) for _ in range(num_heads)])
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """Run every head on ``x`` and merge the results.

        Args:
            x: Tensor of shape (batch, n_patch, embed_dim).

        Returns:
            Tensor of shape (batch, n_patch, embed_dim).
        """
        # Concatenating num_heads outputs of width head_size restores embed_dim.
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj_drop(self.proj(out))
        return out

### Efficient Attention

In [None]:
class EfficientAttention(nn.Module):
    """Multi-head self-attention computed for all heads in one batched pass.

    A single linear layer produces queries, keys and values for every head at
    once; the result is reshaped so the heads become a batch-like axis and the
    attention is evaluated with batched matrix multiplications.

    Args:
        embed_dim: Width of the token embeddings (input and output).
        num_heads: Number of attention heads. Must divide ``embed_dim``.
        attn_p: Dropout probability applied to the attention weights.
        proj_p: Dropout probability applied after the output projection.

    The defaults mirror those of ``MultiHeadedAttention`` so the two modules
    can be constructed interchangeably.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim=768, num_heads=12, attn_p=0, proj_p=0):
        super().__init__()
        # A non-divisible pair would otherwise be truncated silently and only
        # fail later, inside ``reshape`` in ``forward``.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_size = embed_dim // num_heads

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_dropout = nn.Dropout(attn_p)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        """Apply multi-head scaled dot-product self-attention.

        Args:
            x: Tensor of shape (batch, patches, embed_dim).

        Returns:
            Tensor of shape (batch, patches, embed_dim).
        """
        batch, patches, _ = x.shape
        qkv = self.qkv(x)  # (batch, patches, 3 * embed_dim)
        qkv = qkv.reshape(batch, patches, 3, self.num_heads, self.head_size)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, heads, patches, head_size)
        q, k, v = qkv.unbind(0)

        # Scale by 1/sqrt(head_size): the dot product sums over head_size terms.
        sam = (q @ k.transpose(-2, -1)) * self.head_size ** -0.5  # (batch, heads, patches, patches)
        attn = sam.softmax(dim=-1)
        attn = self.attn_dropout(attn)
        weighted_average = attn @ v  # (batch, heads, patches, head_size)
        weighted_average = weighted_average.transpose(1, 2)  # (batch, patches, heads, head_size)
        weighted_average = weighted_average.flatten(2)  # (batch, patches, embed_dim)
        out = self.proj_drop(self.proj(weighted_average))
        return out

In [None]:
class MLP(nn.Module):
    """Two-layer feed-forward network with a GELU non-linearity.

    Computes ``drop(fc2(drop(gelu(fc1(x)))))``; the same dropout module is
    applied after the activation and after the second linear layer.

    Args:
        in_features: Width of the input features.
        hidden_features: Width of the hidden layer.
        out_features: Width of the output features.
        mlp_p: Dropout probability used at both dropout sites.
    """

    def __init__(self, in_features, hidden_features, out_features, mlp_p=0):
        super().__init__()
        # Creation order is kept (fc1, act, fc2, drop) so parameter
        # initialisation and state_dict ordering are unchanged.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(mlp_p)

    def forward(self, x):
        """Map ``x`` of shape (..., in_features) to (..., out_features)."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))