In [35]:
import numpy as np

from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets.cifar import CIFAR10

np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7df71b67a870>

### Classes

In [2]:
class Patcher(nn.Module):
  def __init__(self, patch_size):
    super(Patcher, self).__init__()
    self.patch_size=patch_size
    self.unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)

  def forward(self, images):
    batch_size, channels, height, width = images.shape
    patch_height, patch_width = [self.patch_size, self.patch_size]
    assert height % patch_height == 0 and width % patch_width == 0, "Height and width must be divisible by the patch size."

    patches = self.unfold(images) #bs (cxpxp) N
    patches = patches.view(batch_size, channels, patch_height, patch_width, -1).permute(0, 4, 1, 2, 3) # bs N C P P

    return patches

In [3]:
class TransformerBlock(nn.Module):
    def __init__(self, model_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(model_dim)
        self.attn = nn.MultiheadAttention(model_dim, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(model_dim)

        # Feedforward network
        self.mlp = nn.Sequential(
            nn.Linear(model_dim, int(model_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(model_dim * mlp_ratio), model_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Self-attention
        x = self.norm1(x)
        attn_out, _ = self.attn(x, x, x)
        x = x + attn_out

        # Feedforward network
        x = self.norm2(x)
        mlp_out = self.mlp(x)
        x = x + mlp_out

        return x

In [83]:
class ViT_RGB(nn.Module):
  def __init__(self, img_size, patch_size, model_dim= 30, num_heads=3, num_layers=2, n_classes=10):
    super().__init__()
    self.img_size = img_size
    self.patch_size = patch_size
    self.n_patches = (self.img_size // self.patch_size) ** 2
    self.model_dim = model_dim
    self.num_layers = num_layers
    self.num_heads= num_heads
    self.n_classes = n_classes

    # 1) Patching
    self.patcher = Patcher(patch_size=self.patch_size)

    # 2) Linear Prjection
    self.linear_projector = nn.Linear( 3 * self.patch_size ** 2, self.model_dim)

    # 3) Class Token
    self.class_token = nn.Parameter(torch.rand(1, 1, self.model_dim)) # This common for all images! TODO

    # 4) Positional Embedding
    self.positional_embedding = nn.Parameter(torch.rand(1,(img_size // patch_size) ** 2 + 1, model_dim))

    # 5) Transformer blocks
    self.blocks = nn.ModuleList([
        TransformerBlock( self.model_dim,  self.num_heads) for _ in range(num_layers)
    ])

    # 6) Classification MLPk
    self.mlp = nn.Sequential(
            nn.Linear(self.model_dim, self.n_classes),
            nn.Softmax(dim=-1)
        )

  def forward(self, x):

    x = self.patcher(x)

    x = x.flatten(start_dim=2)
    x = self.linear_projector(x)

    batch_size = x.shape[0]
    class_token = self.class_token.expand(batch_size, -1, -1)
    x = torch.cat((class_token, x), dim=1)

    x = x + self.positional_embedding

    for block in self.blocks:
      x = block(x)

    latent = x[:, 0]
    # latent = x.mean(dim=1)
    logits = self.mlp(latent)

    return logits, latent

### Test Classes

In [7]:
%pdb off
class ViTTester():
  def __init__(self):
    self.vit = ViT_RGB(img_size=32, patch_size=4, model_dim=20, num_heads=4, num_layers=3, n_classes=10)

  def test(self):
    images = torch.randn(7, 3, 32, 32)
    tokens = self.vit(images)
    print(tokens[0].shape)
    return tokens

tokens = ViTTester().test()

Automatic pdb calling has been turned OFF
torch.Size([7, 10])


In [None]:
!wget https://picsum.photos/32 -O image.jpg

In [72]:
from PIL import Image
import matplotlib.pyplot as plt
import cv2
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np

class PatcherTester():
  def __init__(self, image_path, patch_size):
    self.patcher = Patcher(patch_size)
    self.image_path = image_path

  def plot_original(self):
    img_src = self.image_path
    image = Image.open(img_src)
    return image

  def plot_patches(self):
    img_src = self.image_path
    image = Image.open(img_src)
    image = np.array(image)
    image = image.astype('float32') / 255.0  # Normalize to [0, 1]
    image = torch.from_numpy(image)
    image = image.permute(2,0,1)
    image = image.unsqueeze(0) #to add the batch dimension
    p = self.patcher(image)
    p = p.squeeze()
    fig = plt.figure(figsize=(8, 8))
    grid = ImageGrid(fig, 111, nrows_ncols=(4, 4), axes_pad=0.1)
    for i, ax in enumerate(grid):
        patch = p[i].permute(1, 2, 0).numpy()
        ax.imshow(patch)
        ax.axis('off')
    plt.show()

In [None]:
patcher = PatcherTester("./image.jpg", 75)
patcher.plot_original()

In [None]:
patcher = PatcherTester("./image.jpg", 4)
patcher.plot_patches()

### Train

In [84]:
del model
torch.cuda.empty_cache()

In [85]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViT_RGB(img_size=32, patch_size=4, model_dim=100, num_heads=4, num_layers=3, n_classes=10).to(device)

In [86]:
transform_train = transforms.Compose([
    # transforms.RandomCrop(32, padding=4),
    transforms.Resize(32),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_set = CIFAR10(root='./datasets', train=True, download=True, transform=transform_train)
test_set = CIFAR10(root='./datasets', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_set, shuffle=True, batch_size=64)
test_loader = DataLoader(test_set, shuffle=False, batch_size=64)


Files already downloaded and verified
Files already downloaded and verified


In [None]:
n_epochs = 200
lr = 0.0001

optimizer = Adam(model.parameters(), lr=lr)
criterion = CrossEntropyLoss()

for epoch in range(n_epochs):
    train_loss = 0.0
    for i,batch in enumerate(train_loader):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat, latent = model(x)
        loss = criterion(y_hat, y)

        batch_loss = loss.detach().cpu().item()
        train_loss += batch_loss / len(train_loader)

        optimizer.zero_grad()
        cls0 = model.class_token.clone()
        loss.backward()
        optimizer.step()
        cls1 = model.class_token.clone()

        # if((cls0 == cls1).sum()):
        #   print("cls==")
        #   break

        if i%100==0:
          print(f"Batch {i}/{len(train_loader)} loss: {batch_loss:.03f}")

    print(f"Epoch {epoch + 1}/{n_epochs} loss: {train_loss:.03f}")

In [77]:
# Test loop
with torch.no_grad():
    correct, total = 0, 0
    test_loss = 0.0
    for batch in tqdm(test_loader, desc="Testing"):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat, latent = model(x)
        loss = criterion(y_hat, y)
        test_loss += loss.detach().cpu().item() / len(test_loader)

        correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
        total += len(x)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")

Testing: 100%|██████████| 157/157 [00:04<00:00, 34.99it/s]

Test loss: 1.78
Test accuracy: 68.15%



