<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/vits/classification/ViT_classification(bounding_boxes).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchinfo

In [None]:
###############################
# DATASET WITH BOUNDING BOXES #
###############################

In [None]:
!wget https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip

In [None]:
!unzip PennFudanPed.zip

In [None]:
import torch
import random
import timeit
import math
import numpy
import numpy as np
from torch import optim
from torch import nn
from torchinfo import summary
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from dataclasses import dataclass
from torch.optim.lr_scheduler import CosineAnnealingLR
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

%matplotlib inline

In [None]:
class PennFudanSinglePed(Dataset):
    """Subset of a Penn-Fudan style dataset with exactly one annotated object per image.

    Images are read from ``<root>/PNGImages`` and instance masks from
    ``<root>/PedMasks``; the two sorted directory listings are paired by
    position. A sample is kept only when its mask contains exactly one
    non-zero id (0 is treated as background).

    Args:
        root: Dataset root directory containing ``PNGImages`` and ``PedMasks``.
        transform: Optional callable applied to the RGB PIL image.

    Each item is a dict with:
        ``"image"``: the (optionally transformed) image,
        ``"label"``: the constant class index ``1`` as a 0-d tensor,
        ``"bbox"``:  float32 tensor ``[xmin, ymin, xmax, ymax]`` normalised by
                     the original image width/height.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform

        self.img_dir = os.path.join(root, "PNGImages")
        self.mask_dir = os.path.join(root, "PedMasks")

        # Sorted listings are paired index-by-index (image i <-> mask i).
        self.images = sorted(os.listdir(self.img_dir))
        self.masks = sorted(os.listdir(self.mask_dir))

        # Remember which positions hold a mask with a single non-background id.
        self.valid_indices = []
        for idx in range(len(self.images)):
            with Image.open(os.path.join(self.mask_dir, self.masks[idx])) as mask_img:
                mask = np.array(mask_img)
            obj_ids = np.unique(mask)
            obj_ids = obj_ids[obj_ids != 0]
            if len(obj_ids) == 1:
                self.valid_indices.append(idx)

    def __len__(self):
        """Number of single-object samples."""
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Return the ``idx``-th single-object sample as a dict.

        Raises:
            IndexError: if ``idx`` is outside the filtered index list.
            ValueError: if the mask contains no foreground pixel.
        """
        real_idx = self.valid_indices[idx]

        img_path = os.path.join(self.img_dir, self.images[real_idx])
        mask_path = os.path.join(self.mask_dir, self.masks[real_idx])

        # Context managers release the underlying file handles promptly.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        with Image.open(mask_path) as mask_img:
            mask = np.array(mask_img)

        orig_w, orig_h = img.size

        # The filter in __init__ only guarantees ONE non-zero id, not that the
        # id equals 1, so select every foreground pixel instead of `mask == 1`.
        ys, xs = np.where(mask != 0)
        if xs.size == 0:
            raise ValueError(f"mask {mask_path} contains no foreground pixels")
        xmin, xmax = xs.min(), xs.max()
        ymin, ymax = ys.min(), ys.max()

        # Normalise BEFORE any transform: relative coordinates stay valid when
        # the transform only rescales the image (flips/crops would not be
        # reflected in the box).
        bbox = torch.tensor([
            xmin / orig_w,
            ymin / orig_h,
            xmax / orig_w,
            ymax / orig_h
        ], dtype=torch.float32)

        label = torch.tensor(1)

        if self.transform:
            img = self.transform(img)

        return {"image": img, "label": label, "bbox": bbox}


In [None]:
from torchvision import transforms
from torch.utils.data import DataLoader

# Preprocessing: resize every image to 224x224, then convert it to a tensor.
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

dataset = PennFudanSinglePed(root="PennFudanPed", transform=transform)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Sanity check: inspect only the first batch (nothing is printed if the
# loader yields no batches).
sample_batch = next(iter(loader), None)
if sample_batch is not None:
    print(sample_batch["image"].shape)
    print(sample_batch["label"])
    print(sample_batch["bbox"])


In [None]:
#################
# CONFIGURATION #
#################

In [None]:
@dataclass
class vit_config:
    """Hyperparameters for the Vision Transformer and its training run.

    The notebook reads these values straight from the class
    (``config = vit_config``), so every field keeps a class-level default.
    When an instance is created, ``num_patches`` is recomputed from that
    instance's ``image_size`` and ``patch_size`` so it cannot go stale.
    """

    num_channels: int = 3
    batch_size: int = 16
    image_size: int = 224
    patch_size: int = 16
    num_heads: int = 8
    dropout: float = 0.0
    layer_norm_eps: float = 1e-6
    num_encoder_layers: int = 12
    random_seed: int = 42
    epochs: int = 30
    num_classes: int = 10
    learning_rate: float = 1e-5
    adam_weight_decay: float = 0.0
    adam_betas: tuple = (0.9, 0.999)
    # Default evaluated once, at class-definition time, from the defaults
    # above (16 * 16 * 3 = 768). It is NOT recomputed per instance.
    embd_dim: int = (patch_size ** 2) * num_channels           # 768
    # Class-level default (224 // 16) ** 2 = 196; instances recompute it in
    # __post_init__.
    num_patches: int = (image_size // patch_size) ** 2         # 196
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def __post_init__(self):
        # The patch grid is fully determined by image_size and patch_size;
        # without this, vit_config(patch_size=32) would keep the stale 196.
        self.num_patches = (self.image_size // self.patch_size) ** 2

In [None]:
# The config is used as a namespace of class-level defaults (not instantiated).
config = vit_config

# Seed every random number generator the notebook touches.
random.seed(config.random_seed)
numpy.random.seed(config.random_seed)
torch.manual_seed(config.random_seed)
torch.cuda.manual_seed(config.random_seed)
torch.cuda.manual_seed_all(config.random_seed)
torch.backends.cudnn.deterministic = True
# Benchmark mode auto-tunes convolution algorithms at runtime and may select
# different kernels from run to run, which defeats the deterministic flag
# above; it must be off for reproducible results.
torch.backends.cudnn.benchmark = False

In [None]:
##################
# MODEL BUILDING #
##################

In [None]:
class VisionEmbedding(nn.Module):
    """Convert an image batch into a token sequence for the encoder.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, a learnable class token is prepended,
    learnable position embeddings are added and dropout is applied.

    Input:  ``(B, num_channels, H, W)``
    Output: ``(B, tokens + 1, embd_dim)``; the token count must equal
            ``config.num_patches`` for the position table to line up.
    """

    def __init__(self, config: vit_config):
        super().__init__()
        self.config = config

        # Kernel size == stride, so each output location sees exactly one patch.
        patchify = nn.Conv2d(
            config.num_channels,
            config.embd_dim,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            padding="valid",
        )
        # Flatten collapses the spatial grid: (B, D, H', W') -> (B, D, H'*W').
        self.patch_embedding = nn.Sequential(patchify, nn.Flatten(start_dim=2))

        token_shape = (1, 1, config.embd_dim)
        table_shape = (1, config.num_patches + 1, config.embd_dim)
        self.cls_token = nn.Parameter(torch.randn(token_shape), requires_grad=True)
        self.pos_embeddings = nn.Parameter(torch.randn(table_shape), requires_grad=True)
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch = x.size(0)

        # (B, D, N) -> (B, N, D): one row per patch token.
        tokens = self.patch_embedding(x).permute(0, 2, 1)
        # Broadcast the single learned class token across the batch.
        cls = self.cls_token.expand(batch, -1, -1)

        sequence = torch.cat((cls, tokens), dim=1)
        return self.dropout(sequence + self.pos_embeddings)

In [None]:
class VisionAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Input and output are both ``(B, T, embd_dim)``. ``embd_dim`` has to be
    divisible by ``num_heads`` for the per-head reshape to succeed. Dropout is
    applied to the projected output only (not to the attention weights).
    """

    def __init__(self, config: vit_config):
        super().__init__()

        self.embd_dim = config.embd_dim
        self.num_heads = config.num_heads
        self.dropout = config.dropout

        width = self.embd_dim
        self.q_proj = nn.Linear(width, width)
        self.k_proj = nn.Linear(width, width)
        self.v_proj = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, width)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        head_dim = C // self.num_heads

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, C) -> (B, heads, T, head_dim)
            return t.view(B, T, self.num_heads, head_dim).transpose(1, 2)

        query = split_heads(self.q_proj(x))
        key = split_heads(self.k_proj(x))
        value = split_heads(self.v_proj(x))

        # Scale the logits by 1/sqrt(head_dim) before the softmax.
        scale = 1.0 / math.sqrt(head_dim)
        weights = torch.matmul(query, key.transpose(-2, -1)) * scale
        weights = F.softmax(weights, dim=-1).to(query.dtype)

        # Merge the heads back: (B, heads, T, head_dim) -> (B, T, C).
        context = torch.matmul(weights, value).transpose(1, 2).contiguous().view(B, T, C)
        projected = self.out_proj(context)
        return F.dropout(projected, p=self.dropout, training=self.training)

In [None]:
class VisionMLP(nn.Module):
    """Two-layer feed-forward block applied to every token independently.

    Expands ``embd_dim`` to ``3 * embd_dim``, applies tanh-approximated GELU
    and projects back to ``embd_dim``.
    """

    def __init__(self, config: vit_config):
        super().__init__()

        width = config.embd_dim
        hidden_width = 3 * width
        self.layer1 = nn.Linear(width, hidden_width)
        self.layer2 = nn.Linear(hidden_width, width)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.gelu(self.layer1(x), approximate="tanh")
        return self.layer2(hidden)

In [None]:
class VisionEncoderLayer(nn.Module):
    """One pre-norm transformer block: attention and MLP, each with a residual.

    LayerNorm is applied to the input of each sub-layer and the sub-layer
    output is added back onto the un-normalised stream.
    """

    def __init__(self, config: vit_config):
        super().__init__()

        self.embd_dim = config.embd_dim
        norm_eps = config.layer_norm_eps

        # Sub-modules are created in the same order as before so that seeded
        # initialisation and module ordering stay the same.
        self.attn = VisionAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embd_dim, eps=norm_eps)
        self.mlp = VisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embd_dim, eps=norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attended = self.attn(self.layer_norm1(x))
        stream = x + attended

        transformed = self.mlp(self.layer_norm2(stream))
        return stream + transformed

In [None]:
class VisionEncoder(nn.Module):
    """Stack of ``config.num_encoder_layers`` encoder layers applied in order."""

    def __init__(self, config: vit_config):
        super().__init__()

        blocks = []
        for _ in range(config.num_encoder_layers):
            blocks.append(VisionEncoderLayer(config))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

In [None]:
class VisionTransformer(nn.Module):
    """Vision Transformer with a classification head and a bounding-box head.

    Both heads read the encoder output at sequence position 0 (the class
    token that ``VisionEmbedding`` prepends).

    Args:
        config: Object exposing the hyperparameters used here and by the
            sub-modules (``embd_dim``, ``layer_norm_eps``, ``num_classes``, ...).
            Defaults to the ``vit_config`` class itself.

    ``forward`` returns a tuple ``(logits, bbox)``:
        logits: ``(B, num_classes)`` raw class scores.
        bbox:   ``(B, 4)`` values in ``[0, 1]`` after a sigmoid, ordered
                ``[x_min, y_min, x_max, y_max]``.
    """

    def __init__(self, config=vit_config):
        super().__init__()

        self.embedding = VisionEmbedding(config)
        self.encoder = VisionEncoder(config)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(config.embd_dim, eps=config.layer_norm_eps),
            nn.Linear(config.embd_dim, config.num_classes)
        )

        self.bbox_head = nn.Sequential(
            # Use the configured epsilon like every other LayerNorm in the
            # model (previously this one silently fell back to the default).
            nn.LayerNorm(config.embd_dim, eps=config.layer_norm_eps),
            nn.ReLU(),
            nn.Linear(config.embd_dim, 4)  # [x_min, y_min, x_max, y_max]
        )

    def forward(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        x = self.embedding(x)
        x = self.encoder(x)

        # Sequence position 0 holds the class token; both heads share it.
        cls_state = x[:, 0, :]
        logits = self.mlp_head(cls_state)
        # Sigmoid keeps the box in [0, 1], matching the normalised targets.
        bbox = torch.sigmoid(self.bbox_head(cls_state))

        return logits, bbox

In [None]:
# Throw-away instance used only to print the layer/parameter summary.
vit = VisionTransformer(vit_config)
summary(
    model=vit,
    # Derive the dummy input shape from the config instead of hard-coding
    # (16, 3, 224, 224), so the summary follows any config change.
    input_size=(
        vit_config.batch_size,
        vit_config.num_channels,
        vit_config.image_size,
        vit_config.image_size,
    ),
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

In [None]:
# Instantiate the model that is actually trained and move it to the
# configured device; the bare `model` on the last line displays its layout
# as the cell output.
config = vit_config
model = VisionTransformer(config).to(config.device)
model

In [None]:
##################
# MODEL TRAINING #
##################

In [None]:
# Losses: cross-entropy for the class logits, smooth-L1 for the box regression.
criterion_cls = nn.CrossEntropyLoss()
criterion_bbox = nn.SmoothL1Loss()
optimizer = optim.Adam(
    model.parameters(),
    lr=config.learning_rate,
    betas=config.adam_betas,
    weight_decay=config.adam_weight_decay,
)
scheduler = CosineAnnealingLR(optimizer, T_max=config.epochs)

for epoch in range(config.epochs):
    model.train()
    train_loss_running = 0.0

    for batch in loader:
        img = batch["image"].float().to(config.device)
        label = batch["label"].long().to(config.device)
        bbox = batch["bbox"].float().to(config.device)

        optimizer.zero_grad()
        logits, bbox_pred = model(img)

        loss_cls = criterion_cls(logits, label)
        loss_bbox = criterion_bbox(bbox_pred, bbox)
        loss = loss_cls + loss_bbox  # you can weight: loss = loss_cls + 5*loss_bbox

        loss.backward()
        optimizer.step()

        train_loss_running += loss.item()

    # Advance the cosine schedule once per epoch (T_max is given in epochs).
    # Without this call the scheduler is inert and the learning rate stays
    # at its initial value for the whole run.
    scheduler.step()

    print(f"Epoch {epoch+1}, Train Loss: {train_loss_running / len(loader):.4f}")


In [None]:
####################
# MODEL EVALUATION #
####################

In [None]:
def show_prediction(img_tensor, bbox_pred):
    """Display an image with a predicted bounding box drawn on top.

    Args:
        img_tensor: Image tensor laid out as ``(C, H, W)``.
        bbox_pred: Four values ``[xmin, ymin, xmax, ymax]`` expressed as
            fractions of the image width/height. The argument is only read;
            it is never modified.
    """
    # Local import: matplotlib.patches is not among the notebook's top-level imports.
    import matplotlib.patches as patches

    # (C, H, W) -> (H, W, C), the layout imshow expects.
    img_np = img_tensor.detach().permute(1, 2, 0).cpu().numpy()
    h, w = img_np.shape[:2]

    # Convert to plain floats BEFORE scaling. Unpacking a tensor yields views
    # of it, so an in-place `xmin *= w` would overwrite the caller's tensor.
    xmin, ymin, xmax, ymax = (float(v) for v in bbox_pred)
    xmin, xmax = xmin * w, xmax * w
    ymin, ymax = ymin * h, ymax * h

    fig, ax = plt.subplots()
    ax.imshow(img_np)
    rect = patches.Rectangle(
        (xmin, ymin),
        xmax - xmin,
        ymax - ymin,
        fill=False,
        edgecolor='red',
        linewidth=2
    )
    ax.add_patch(rect)
    plt.show()

In [None]:
model.eval()
batch = next(iter(loader))
# Add a batch dimension to the first image of the batch.
img = batch["image"][0].unsqueeze(0).to(config.device)
# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    _, bbox = model(img)
show_prediction(batch["image"][0], bbox[0].detach().cpu())