In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from tqdm import tqdm

Load the MNIST images and build the train/test DataLoaders

In [2]:
# Training split: images are converted to tensors by ToTensor and served in
# shuffled mini-batches of 64.
train_dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# Held-out split: `train` defaults to True, so it must be set to False
# explicitly or the "test" set is just a second copy of the training data.
test_dataset = datasets.MNIST(root="dataset/", train=False, transform=transforms.ToTensor(), download=True)
# Evaluation does not benefit from shuffling; keep the order deterministic.
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

In [3]:
class PatchEmbedding(nn.Module):
    """Project an image batch into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    maps every non-overlapping patch to one ``embed_dim``-sized vector.

    Args:
        in_channels: number of channels of the input images.
        embed_dim: size of the vector produced for every patch.
        patch_size: side length (kernel size and stride) of a patch.
    """

    def __init__(self, in_channels=1, embed_dim=768, patch_size=4):
        super().__init__()
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return patch embeddings of shape (B, num_patches, embed_dim).

        Args:
            x: image batch accepted by the convolution, i.e. (B, in_channels, H, W).
        """
        feature_map = self.proj(x)  # (B, embed_dim, H/patch_size, W/patch_size)
        channels_last = feature_map.permute(0, 2, 3, 1)  # embedding axis moved to the end
        batch = channels_last.size(0)
        # Collapse the two spatial axes into a single patch axis.
        return channels_last.reshape(batch, -1, self.embed_dim)
    

class PatchMaker(nn.Module):
    """Cut images into non-overlapping patches and linearly project each one.

    Every channel of every patch is flattened and projected independently,
    so the output keeps a channel axis.

    Args:
        patch_size: either a single int (square patches) or a pair whose
            first entry is the patch extent along the last image axis and
            whose second entry is the extent along the second-to-last axis.
        embed_dim: output size of the linear projection (768 by default).

    Note:
        The unused ``PatchEmbedding`` sub-module that used to be created here
        has been dropped; it was never called and only allocated parameters.
    """

    def __init__(self, patch_size=4, embed_dim=768):
        super().__init__()
        # The int default used to be indexed directly and raised TypeError.
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.x_axis = patch_size[0]
        self.y_axis = patch_size[1]
        # Projects one flattened single-channel patch to `embed_dim`.
        self.linear = nn.Linear(self.x_axis * self.y_axis, embed_dim, bias=False)

    def forward(self, images):
        """Return projected patches.

        Args:
            images: tensor of shape [batch_size, channels, height, width].

        Returns:
            Tensor of shape [batch_size, num_patches, channels, embed_dim],
            with patches ordered row by row.

        Raises:
            AssertionError: if the image size is not a multiple of the patch size.
        """
        assert(images.shape[-1]%self.x_axis==0 and images.shape[-2]%self.y_axis==0),"image dimentions aren't divisible by patch size"
        batch_size, channels = images.shape[0], images.shape[1]

        # Slide a non-overlapping window over height, then width:
        # [B, C, H, W] -> [B, C, rows, cols, y_axis, x_axis]
        patches = images.unfold(2, self.y_axis, self.y_axis).unfold(3, self.x_axis, self.x_axis)
        num_patches = patches.shape[2] * patches.shape[3]

        # Bring the patch grid in front of the channel axis, then flatten the
        # grid into one patch axis and each patch into one vector. `reshape`
        # (not `view`) is required because the permuted tensor is not contiguous.
        patches = patches.permute(0, 2, 3, 1, 4, 5)
        flattened_patches = patches.reshape(batch_size, num_patches, channels, self.y_axis * self.x_axis)

        return self.linear(flattened_patches)

In [4]:
# Idea: a 20x20x3 image -> 4 patches of 10x10x3 -> 4 flattened vectors of length 10*10*3 -> nn.Linear() -> patch embeddings

In [5]:
class ViT(nn.Module):
    """Embedding front-end of a Vision Transformer (shape-checking prototype).

    Builds the token sequence only: patch embeddings, a prepended class
    token and additive position embeddings. There are no encoder layers or
    classification head here, and ``forward`` reports shapes, not tensors.

    Args:
        embed_dim: size of every token vector.
        img_size: side length of the (square) input images.
        patch_size: side length of a patch; must match the patch embedding
            so that ``num_patches`` agrees with the real sequence length.
    """

    def __init__(self, embed_dim=768, img_size=224, patch_size=16):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.embed_dim = embed_dim
        # Forward patch_size: PatchEmbedding defaults to 4, which would
        # produce a different number of patches than pos_emb is sized for.
        self.patchemb = PatchEmbedding(embed_dim=embed_dim, patch_size=patch_size)
        # One position vector per patch plus one for the class token.
        self.pos_emb = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        """Build the token sequence and return its shape for inspection.

        Args:
            x: image batch accepted by ``PatchEmbedding``.

        Returns:
            Tuple ``(token_sequence_shape, class_token_shape)``.
        """
        batch_size = x.shape[0]
        x = self.patchemb(x)  # (B, num_patches, embed_dim)

        # Broadcast the single learned class token over the batch and
        # prepend it to the patch sequence.
        class_token = self.class_token.expand([batch_size, 1, self.embed_dim])
        x = torch.cat((class_token, x), dim=1)

        x = x + self.pos_emb

        return (x.shape, class_token.shape)
        

In [6]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier for small square images.

    Pipeline: strided-convolution patch embedding -> prepend class token ->
    add position embeddings -> dropout -> stack of Transformer encoder
    layers -> linear head on the class token.

    Args:
        img_size: side length of the (square) input images.
        patch_size: side length of a patch (conv kernel size and stride).
        in_channels: number of image channels.
        num_classes: number of output logits.
        embed_dim: token embedding size.
        num_heads: attention heads per encoder layer.
        num_layers: number of encoder layers.
        dropout: dropout probability for the embeddings and encoder layers.
    """

    def __init__(self, img_size=28, patch_size=4, in_channels=1, num_classes=10, embed_dim=768, num_heads=12, num_layers=12, dropout=0.1):
        super().__init__()

        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        self.patch_embeddings = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

        # +1 position for the class token.
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.dropout = nn.Dropout(dropout)

        # batch_first=True is essential: forward() feeds (batch, seq, embed)
        # tensors. With the default (seq, batch, embed) layout attention would
        # run across the images of a batch instead of across the patches of
        # one image, and the class token would never see any pixel data.
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(embed_dim, num_heads, dropout=dropout, batch_first=True)
            for _ in range(num_layers)
        ])

        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: tensor of shape (batch_size, in_channels, img_size, img_size).

        Returns:
            Logits of shape (batch_size, num_classes).
        """
        batch_size = x.size(0)

        # Patch embedding: (B, embed_dim, H/ps, W/ps) -> (B, num_patches, embed_dim)
        x = self.patch_embeddings(x)
        x = x.flatten(2).transpose(1, 2)

        # Prepend the class token: (B, num_patches + 1, embed_dim)
        class_token = self.class_token.expand(batch_size, -1, -1)
        x = torch.cat((class_token, x), dim=1)

        # Out-of-place add keeps autograd free of in-place modification issues.
        x = x + self.position_embeddings
        x = self.dropout(x)

        for layer in self.transformer_layers:
            x = layer(x)

        # The class token (sequence position 0) summarises the image.
        x = x[:, 0]  # (B, embed_dim)
        return self.fc(x)


In [7]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [8]:
# Build the classifier with its default hyperparameters and move its
# parameters to the selected device (`.to` returns the same module).
model = VisionTransformer().to(device)
model


VisionTransformer(
  (patch_embeddings): Conv2d(1, 768, kernel_size=(4, 4), stride=(4, 4))
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer_layers): ModuleList(
    (0-11): 12 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (linear1): Linear(in_features=768, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=768, bias=True)
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
  (fc): Linear(in_features=768, out_features=10, bias=True)
)

In [9]:
# Classification loss on the raw logits, and an Adam optimizer over every
# model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [10]:
num_epochs = 10

for epoch in range(num_epochs):
    # ---- Training pass: one optimizer step per mini-batch ----
    model.train()
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ---- Evaluation pass ----
    # Accuracy is measured on the held-out loader: re-scoring the training
    # batches says nothing about generalisation and left test_loader unused.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Index of the largest logit is the predicted class.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Epoch: {epoch+1}/{num_epochs}, Accuracy: {accuracy:.2f}%')

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
  0%|          | 3/938 [00:04<21:16,  1.37s/it]