In [40]:
import numpy as np 
import torch 
from tqdm import tqdm,trange
from torch.nn import CrossEntropyLoss 
from torch.optim import Adam 
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor 
from torchvision.datasets.mnist import MNIST 
from torchvision import transforms
from torch import nn

In [41]:
torch.manual_seed(0)

<torch._C.Generator at 0x2d719057230>

In [55]:
class VIT(nn.Module):
    def __init__(self,chw=(1,28,28),n_patches=7,n_heads=2,
                 n_blocks=2,hidden_d=2,out_d=10):
        super(VIT,self).__init__()

        self.chw = chw 
        self.n_patches = n_patches
        self.hidden_d = hidden_d
        self.n_blocks = n_blocks
        self.n_heads = n_heads 
        self.out_d = out_d


        #28//7 = 4
        assert chw[1]% n_patches == 0, 'Input shape should be divisble by n_patches'
        assert chw[2]% n_patches == 0, 'Input shape should be divisble by n_patches'

        self.patch_size = (chw[1]/ n_patches, chw[2] / n_patches) #4,4

        #linear mapper 
        self.input_d = int(self.chw[0]*self.patch_size[0]*self.patch_size[1]) #1* 4*4 = 16 
        self.linear_mapper = nn.Linear(self.input_d,self.hidden_d)

        #learnable classification token
        """nn. Parameter is used to explicitly specify which tensors should be treated as the model's learnable parameters. 
        So that those tensors are learned (updated) during the training process to minimize the loss function."""
        self.class_token = nn.Parameter(torch.randn(1,self.hidden_d))

        #positional embedding 
        self.pos_embed = nn.Parameter(torch.randn(self.n_patches**2 + 1,self.hidden_d))
        self.pos_embed.requires_grad = False 

        #Encoder block 
        self.encoder_blocks = nn.ModuleList(
            [EncoderVIT(self.hidden_d,n_heads) for _ in range(n_blocks)]
            )

        #classification MLP 
        self.mlp = nn.Sequential(nn.Linear(self.hidden_d,out_d),
                                 nn.Softmax(dim=-1)
                                 )

    def forward(self,images):

        n,c,h,w = images.shape #n,c,h,w ->n,1,28,28
        patches = self.patch_embedding(images,self.n_patches) #n,49,16
        #print(patches.shape)
        tokens = self.linear_mapper(patches.to(images.device)) #n,49,8
        tokens = torch.stack([torch.vstack((self.class_token,tokens[i])) for i in range(len(tokens))]) #n,50,8
        
        pos_embed = self.pos_embed.repeat(n,1,1).to(images.device) #n,50,8
        out = tokens + pos_embed

        #transformer block 
        for block in self.encoder_blocks:
            out = block(out)
            print(out.shape)

        #classification token 
        out = out[:,0] #1,8
        
        
        return self.mlp(out)
    """The @staticmethod decorator in Python is used to define a method that belongs to a class but does not access any properties or methods of the class.
      Here’s a detailed explanation of the role and usage of @staticmethod"""

    @staticmethod
    def patch_embedding(images,n_patches): # 7*7 -> n_patches
        n,c,h,w = images #n,c,h,w ->n,1,28,28

        assert h==w ,'Patch embedding required the dimensions of the height and width to be the same'

        patches =torch.zeros(n,n_patches**2,h*w*c//n_patches**2,device = images.device) # h//n_patches,w//n_patches ->28/7 = 4->N,49,16
        patch_size = h//n_patches


        for idx,image in enumerate(images):
            for i in range(n_patches):
                for j in range(n_patches):
                    patch = image[:,i*patch_size:(i+1)*patch_size,j*patch_size:(j+1)*patch_size]
                    #image 2D patch 0--->4 4-----12 
                    patches[idx,i*n_patches+j] = patch.flatten()

        return patches, #n,49,16 

    @staticmethod
    def positional_embedding(sequence_length,d): #n,49,8
        """ 
        p(i,j) ={sin(i/10000^(j/d(emd_dim))) if j is even ,j represent the position of the dimension
        cos(i/10000^(j/d(emd_dim))) if j is odd}
        """
        result = torch.ones(sequence_length,d)
        for i in range(sequence_length):
            for j in range(d): #j->dimension 
                result[i][j] = np.sin(i /(10000**(j/d))) if j%2 == 0 else np.cos(i/(10000**((j-1) /d)))
        
        return result 
    




                

In [56]:
class MultiHeadAttention(nn.Module):
    def __init__(self,d,n_heads):
        super(MultiHeadAttention,self).__init__()

        self.d = d 
        self.n_heads = n_heads

        assert d% n_heads == 0 ,f'Dimension{d} is not divisble by head : {n_heads}'

        #patches --> q,k,v -->n_heads 

        d_head = int(d/n_heads) #8/2 = 4 
        self.q = nn.ModuleList([nn.Linear(d_head,d_head) for _ in range(self.n_heads)]) #(4,4) (4,4)
        self.k = nn.ModuleList([nn.Linear(d_head,d_head) for _ in range(self.n_heads)]) #(4,4) (4,4)
        self.v = nn.ModuleList([nn.Linear(d_head,d_head) for _ in range(self.n_heads)]) #(4,4) (4,4)
        self.d_head = d_head 
        self.softmax = nn.Softmax(dim=-1)

    def forward (self,sequences):
        #N,sequence_length,token_dim #n,50,8/d heads 
        #patch 8 / head 
        result = []
        for sequence in sequences:
            seq_res = [ ]
            for head in range(self.n_heads): #0,1 [2 times]
                q_mapping = self.q[head]
                k_mapping = self.k[head]
                v_mapping = self.v[head]
                seq = sequence[: head * self.d_head : (head+1)* self.d_head]
                q,k,v = q_mapping(seq),k_mapping(seq),v_mapping(seq)

                attention = self.softmax(q@k.T /(self.d_head ** 0.5))
                seq_res.append(attention@v)
            result.append(torch.hstack(seq_res))
        return torch.cat([torch.unsqueeze(r,dim=0) for r in result])



             


In [57]:
class EncoderVIT(nn.Module):
    def __init__(self,hidden_d,n_heads,mlp_ratio=4):
        super(EncoderVIT,self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads
        self.mlp_ratio = mlp_ratio
        
        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MultiHeadAttention(hidden_d,n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d,mlp_ratio*hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d,hidden_d)
        )

        def forward(self,x):
            out = x + self.mhsa(self.norm1(x))
            out = out + self.mlp(self.norm2(out))
            return out
        

In [61]:

def main():
    transform = ToTensor()
    train_set = MNIST(root='./datasets', train=True, download=True, transform=transform)
    test_set = MNIST(root='./datasets', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=True, batch_size=128)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = VIT(chw=(1,28,28), n_patches=7, n_heads=2, n_blocks=2, hidden_d=2, out_d=10).to(device)
    
    n_epochs = 5 
    lr = 0.001

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for epoch in trange(n_epochs, desc='Training'):
        train_loss = 0.0 
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} in training"):
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1} loss : {train_loss}")
        
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="testing"):
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)
            pred = torch.argmax(y_hat, dim=1)
            correct += torch.sum(pred == y).detach().cpu().item()
            total += len(y)
        print(f"Test loss : {test_loss}")
        print(f"Accuracy : {correct / total}")
    
    torch.save(model.state_dict(), 'model.pth')

main()


Epoch 1 in training: 100%|██████████| 469/469 [07:08<00:00,  1.09it/s]
Training:  20%|██        | 1/5 [07:08<28:33, 428.32s/it]

Epoch 1 loss : 2.3017357397181133


Epoch 2 in training: 100%|██████████| 469/469 [06:53<00:00,  1.14it/s]
Training:  40%|████      | 2/5 [14:01<20:58, 419.36s/it]

Epoch 2 loss : 2.2900404513263495


Epoch 3 in training: 100%|██████████| 469/469 [20:11<00:00,  2.58s/it]
Training:  60%|██████    | 3/5 [34:13<26:02, 781.26s/it]

Epoch 3 loss : 2.2615771618987455


Epoch 4 in training: 100%|██████████| 469/469 [07:24<00:00,  1.06it/s]
Training:  80%|████████  | 4/5 [41:37<10:48, 648.21s/it]

Epoch 4 loss : 2.2502054364950688


Epoch 5 in training: 100%|██████████| 469/469 [07:21<00:00,  1.06it/s]
Training: 100%|██████████| 5/5 [48:59<00:00, 587.91s/it]


Epoch 5 loss : 2.2429188766967507


testing: 100%|██████████| 79/79 [00:44<00:00,  1.78it/s]

Test loss : 2.24631729307054
Accuracy : 0.2153



