# Vision Transformers for Object Detection in PyTorch

This Jupyter notebook is based on the blog post [Vision Transformers for Object Detection](https://www.labellerr.com/blog/vision-transformers-for-object-detection/), adapted to use PyTorch instead of TensorFlow/Keras. The original post demonstrates a simple Vision Transformer (ViT) model for bounding box regression on a single object per image, using subsets of the Caltech-101 dataset (cars and faces). We've translated the architecture, data preparation, training, and evaluation to PyTorch.

The model treats images as sequences of patches, encodes them with positional embeddings, applies transformer layers, and outputs a 4-dimensional vector for the bounding box (normalized top-left x, top-left y, bottom-right x, bottom-right y).

**Note:** This is a simplified object localizer, not a full multi-object detector like DETR. Run this in a Jupyter environment (e.g., Google Colab) with PyTorch, torchvision, OpenCV, and SciPy installed. A GPU is recommended for faster training; the code falls back to the CPU when CUDA is unavailable.

In [None]:
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
import cv2
from scipy.io import loadmat
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import StepLR  # Optional for learning rate scheduling
import requests  # For downloading the dataset
from io import BytesIO
from PIL import Image
import tarfile

# Pick the compute device: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## Download and Prepare the Dataset

The dataset is Caltech-101. This notebook uses the `Motorbikes` and `Faces_easy` categories together with their bounding-box annotations, which are distributed separately from the images.

In [None]:
# Download Caltech-101 dataset if not present
if not os.path.exists("101_ObjectCategories.tar.gz"):
    url = "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view?usp=sharing"
    print("Note: The direct download link may require manual intervention. Alternatively, download from http://www.vision.caltech.edu/datasets/ and place '101_ObjectCategories.tar.gz' in the current directory.")
    # For automation, you can use gdown if installed: !gdown 137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
else:
    print("Dataset already downloaded.")


def _extract_tar(path):
    """Extract the tar archive at ``path`` into the current directory."""
    # tarfile.open's default mode auto-detects compression (gzip or plain tar).
    with tarfile.open(path) as tar:
        # When available (Python 3.12+ and recent 3.8-3.11 patch releases), the
        # "data" filter refuses absolute paths, ".." components and links that
        # would escape the extraction directory.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(filter="data")
        else:
            tar.extractall()


# Extract the image archive (skipped if it was already extracted)
if not os.path.exists("101_ObjectCategories"):
    if not os.path.exists("101_ObjectCategories.tar.gz"):
        raise FileNotFoundError(
            "101_ObjectCategories.tar.gz not found; download it as described above "
            "before running this cell."
        )
    _extract_tar("101_ObjectCategories.tar.gz")

# Prepare directories
data_dir = "data"
os.makedirs(data_dir, exist_ok=True)
# The blog uses car_side and Faces; this notebook uses Faces_easy and Motorbikes.
# Swap the names here to follow the blog exactly.
categories = ["Faces_easy", "Motorbikes"]

# Copy images. dirs_exist_ok lets the cell be re-run without FileExistsError.
for category in categories:
    shutil.copytree(
        os.path.join("101_ObjectCategories", category),
        os.path.join(data_dir, category),
        dirs_exist_ok=True,
    )

# Bounding-box annotations ship separately from the images, as
# http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar
# (expected to unpack into ./Annotations).
annotation_dir = "Annotations"
if not os.path.exists(annotation_dir):
    if os.path.exists("Annotations.tar"):
        _extract_tar("Annotations.tar")
    else:
        print(
            f"Warning: '{annotation_dir}' not found. Download Annotations.tar and "
            "extract it here; images without an annotation file are skipped."
        )

# Caltech-101 stores annotations under folder names that do not always match the
# image folder names (this is the mapping used by torchvision's Caltech101 dataset).
# Categories not listed here use the same folder name for images and annotations.
_ANNOTATION_DIR_NAMES = {
    "Faces": "Faces_2",
    "Faces_easy": "Faces_3",
    "Motorbikes": "Motorbikes_16",
    "airplanes": "Airplanes_Side_2",
}


def load_data(category, image_size=224):
    """Load resized images and normalized bounding boxes for one Caltech-101 category.

    Images are read from ``101_ObjectCategories/<category>``. Each ``image_XXXX.jpg``
    is paired with ``<annotation_dir>/<annotation folder>/annotation_XXXX.mat``.
    Images that have no annotation file, or that OpenCV cannot read, are skipped.

    Args:
        category: Image folder name inside ``101_ObjectCategories``.
        image_size: Side length that every image is resized to (default 224).

    Returns:
        A tuple ``(images, bboxes)``:

        - ``images`` is a uint8 array of shape (N, image_size, image_size, 3) in the
          channel order returned by ``cv2.imread`` (BGR).
        - ``bboxes`` is a float32 array of shape (N, 4) holding
          [x1, y1, x2, y2] normalized by the ORIGINAL image width/height.
    """
    images = []
    bboxes = []
    img_dir = os.path.join("101_ObjectCategories", category)
    ann_dir = os.path.join(annotation_dir, _ANNOTATION_DIR_NAMES.get(category, category))
    # Sort so the sample order (and thus the later shuffle) does not depend on the filesystem.
    for filename in sorted(os.listdir(img_dir)):
        if not filename.endswith(".jpg"):
            continue
        # image_0001.jpg -> annotation_0001.mat
        stem = os.path.splitext(filename)[0]
        ann_path = os.path.join(ann_dir, stem.replace("image_", "annotation_", 1) + ".mat")
        if not os.path.exists(ann_path):
            continue
        img = cv2.imread(os.path.join(img_dir, filename))
        if img is None:  # corrupt or unreadable file
            continue
        # Normalize with the original size: the box is in original pixel coordinates,
        # and the non-aspect-preserving resize below keeps relative coordinates valid.
        h, w = img.shape[:2]
        # Caltech-101 stores box_coord as [top, bottom, left, right] (i.e. y1, y2, x1, x2).
        top, bottom, left, right = loadmat(ann_path)["box_coord"][0][:4].astype(np.float32)
        bboxes.append(np.array([left / w, top / h, right / w, bottom / h], dtype=np.float32))
        images.append(cv2.resize(img, (image_size, image_size)))
    if not images:
        # Keep consistent shapes so np.concatenate with other categories still works.
        return (
            np.empty((0, image_size, image_size, 3), dtype=np.uint8),
            np.empty((0, 4), dtype=np.float32),
        )
    return np.array(images), np.array(bboxes)

# Load categories. The *_cars names are historical: they hold the Motorbikes category.
images_cars, bboxes_cars = load_data("Motorbikes")
images_faces, bboxes_faces = load_data("Faces_easy")

# Combine
x = np.concatenate([images_cars, images_faces], axis=0)
y = np.concatenate([bboxes_cars, bboxes_faces], axis=0)
num_samples = len(x)
if num_samples == 0:
    raise RuntimeError(
        "No annotated images were loaded; check that 101_ObjectCategories/ and the "
        "annotation directory are extracted next to this notebook."
    )

# Shuffle and split train/test (90/10). A seeded generator keeps the split
# identical across runs, so results from different runs are comparable.
rng = np.random.default_rng(42)
indices = rng.permutation(num_samples)
split = int(0.9 * num_samples)
x_train, y_train = x[indices[:split]], y[indices[:split]]
x_test, y_test = x[indices[split:]], y[indices[split:]]
print(f"Train samples: {len(x_train)}, Test samples: {len(x_test)}")

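**Note:** The blog uses the `car_side` and `Faces` categories. Both exist in Caltech-101; this notebook uses `Motorbikes` and `Faces_easy` instead, so change the category names if you want to match the blog exactly. The bounding-box annotations are not part of `101_ObjectCategories.tar.gz` and must be downloaded separately (`Annotations.tar`). Some annotation folders are named differently from the image folders (e.g. `Faces_easy` → `Faces_3`, `Motorbikes` → `Motorbikes_16`).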

## Visualize Patches

In [None]:
# Parameters: square inputs of image_size pixels cut into patch_size x patch_size patches
image_size, patch_size = 224, 32

# Function to extract patches (for visualization)
def extract_patches(image, patch_size):
    """Cut an (H, W, 3) numpy image into non-overlapping square patches.

    Returns a numpy array of shape (num_patches, 3 * patch_size * patch_size).
    Patches are ordered row by row over the patch grid, and each patch is
    flattened channel-first, i.e. in (C, patch_size, patch_size) order.
    Rows/columns that do not fill a whole patch are dropped by ``unfold``.
    """
    chw = torch.from_numpy(image).permute(2, 0, 1)
    grid = chw.unsqueeze(0).unfold(2, patch_size, patch_size)
    grid = grid.unfold(3, patch_size, patch_size)  # (1, C, nH, nW, p, p)
    per_patch = grid.permute(0, 2, 3, 1, 4, 5)  # (1, nH, nW, C, p, p)
    flat = per_patch.reshape(1, -1, 3 * patch_size * patch_size)
    return flat[0].numpy()

# Display original image and patches. The images were loaded with cv2.imread
# (BGR), so the channel axis is reversed for matplotlib, which expects RGB.
plt.figure(figsize=(4, 4))
plt.imshow(x_train[0][..., ::-1])
plt.axis("off")
plt.show()

patches = extract_patches(x_train[0], patch_size)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[0]}")
print(f"Elements per patch: {patches.shape[1]}")

n = int(np.sqrt(patches.shape[0]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches):
    plt.subplot(n, n, i + 1)
    # extract_patches flattens each patch channel-first (C, p, p). Restore that
    # layout, move channels last for imshow, then flip BGR -> RGB.
    patch_img = patch.reshape(3, patch_size, patch_size).transpose(1, 2, 0)[..., ::-1]
    plt.imshow(patch_img)
    plt.axis("off")
plt.show()

## Define Custom Layers and Model

In [None]:
class MLP(nn.Module):
    """Stack of Linear -> GELU -> Dropout layers, one group per entry of ``hidden_units``.

    Args:
        hidden_units: Output size of each Linear layer, applied in order.
        dropout_rate: Dropout probability after every activation.
        input_dim: Feature size of the input. Defaults to ``hidden_units[0]``,
            which reproduces the previous layer shapes for equal-width stacks.
    """

    def __init__(self, hidden_units, dropout_rate, input_dim=None):
        super(MLP, self).__init__()
        self.layers = nn.ModuleList()
        widths = list(hidden_units)
        if input_dim is not None:
            in_features = input_dim
        else:
            in_features = widths[0] if widths else None
        for units in widths:
            # Chain widths: each Linear consumes the previous layer's output size.
            self.layers.append(nn.Linear(in_features, units))
            self.layers.append(nn.GELU())
            self.layers.append(nn.Dropout(dropout_rate))
            in_features = units

    def forward(self, x):
        """Apply every layer in order; with no layers this returns ``x`` unchanged."""
        for layer in self.layers:
            x = layer(x)
        return x

# Note: MLP is not used by the model below; TransformerBlock and ViTObjectDetector
# build their feed-forward stacks inline (the transformer one is [proj*2, proj]).

class PatchEncoder(nn.Module):
    """Linearly project flattened patches and add learned position embeddings.

    Args:
        num_patches: Number of patches per image (size of the position table).
        projection_dim: Output embedding size.
        patch_dim: Length of one flattened patch. Defaults to
            ``3 * patch_size * patch_size``, using the notebook-level ``patch_size``
            (the previous behavior).
    """

    def __init__(self, num_patches, projection_dim, patch_dim=None):
        super(PatchEncoder, self).__init__()
        if patch_dim is None:
            # Backward-compatible fallback to the notebook-global patch size.
            patch_dim = 3 * patch_size * patch_size
        self.num_patches = num_patches
        self.projection = nn.Linear(patch_dim, projection_dim)
        self.position_embedding = nn.Embedding(num_patches, projection_dim)

    def forward(self, patches):
        """Encode patches of shape (batch, num_patches, patch_dim) to (batch, num_patches, projection_dim)."""
        # Use the stored patch count rather than a notebook global so the module is self-contained.
        positions = torch.arange(0, self.num_patches, device=patches.device)
        projected = self.projection(patches)
        encoded = projected + self.position_embedding(positions)
        return encoded

class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``h = x + SelfAttention(LN(x))`` and returns ``h + MLP(LN(h))``.
    ``nn.MultiheadAttention`` is built with its default sequence-first layout, so
    the batch and token axes are swapped around the attention call.

    Args:
        projection_dim: Token embedding size.
        num_heads: Number of attention heads.
        transformer_units: Feed-forward sizes. Only the first two entries are used
            (hidden width, output width); the output width must equal
            ``projection_dim`` for the residual add to be valid.
        dropout_rate: Dropout inside attention and after each MLP activation.
    """

    def __init__(self, projection_dim, num_heads, transformer_units, dropout_rate):
        super(TransformerBlock, self).__init__()
        self.ln1 = nn.LayerNorm(projection_dim)
        self.attention = nn.MultiheadAttention(projection_dim, num_heads, dropout=dropout_rate)
        self.ln2 = nn.LayerNorm(projection_dim)
        hidden_width = transformer_units[0]
        out_width = transformer_units[1]
        self.mlp = nn.Sequential(
            nn.Linear(projection_dim, hidden_width), nn.GELU(), nn.Dropout(dropout_rate),
            nn.Linear(hidden_width, out_width), nn.GELU(), nn.Dropout(dropout_rate),
        )

    def forward(self, x):
        # x: (batch, num_patches, proj_dim) -> (num_patches, batch, proj_dim) for attention
        normed = self.ln1(x).permute(1, 0, 2)
        attended = self.attention(normed, normed, normed)[0].permute(1, 0, 2)
        h = x + attended
        return h + self.mlp(self.ln2(h))

class ViTObjectDetector(nn.Module):
    """ViT-style regressor that predicts one bounding box per image.

    The image is split into non-overlapping ``patch_size`` patches and encoded with
    ``PatchEncoder``. It then passes through ``transformer_layers`` ``TransformerBlock``s,
    is layer-normalized and flattened, and an MLP head maps it to 4 outputs.

    Args:
        input_shape: Accepted for API compatibility; not used by the module.
        patch_size: Side length of each square patch.
        num_patches: Patches per image. ``extract_patches`` reshapes to this count, so
            it must match (H // patch_size) * (W // patch_size) of the inputs.
        projection_dim: Token embedding size.
        num_heads: Attention heads per transformer block.
        transformer_units: Feed-forward sizes passed to every TransformerBlock.
        transformer_layers: Number of transformer blocks.
        mlp_head_units: Hidden sizes of the regression head, in order.
        dropout_rate: Dropout on the flattened representation and in the head
            (the transformer blocks use a fixed 0.1).
    """

    def __init__(self, input_shape, patch_size, num_patches, projection_dim, num_heads, transformer_units, transformer_layers, mlp_head_units, dropout_rate=0.3):
        super(ViTObjectDetector, self).__init__()
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.patch_encoder = PatchEncoder(num_patches, projection_dim)
        self.transformer_blocks = nn.ModuleList(
            TransformerBlock(projection_dim, num_heads, transformer_units, dropout_rate=0.1)
            for _ in range(transformer_layers)
        )
        self.ln = nn.LayerNorm(projection_dim)
        self.dropout = nn.Dropout(dropout_rate)
        head = []
        width = num_patches * projection_dim  # size of the flattened token sequence
        for hidden in mlp_head_units:
            head += [nn.Linear(width, hidden), nn.GELU(), nn.Dropout(dropout_rate)]
            width = hidden
        self.mlp_head = nn.Sequential(*head)
        self.output = nn.Linear(width, 4)

    def extract_patches(self, images):
        """Split (B, C, H, W) images into (B, num_patches, C * patch_size**2) patch vectors."""
        p = self.patch_size
        grid = images.unfold(2, p, p).unfold(3, p, p)  # (B, C, nH, nW, p, p)
        return grid.permute(0, 2, 3, 1, 4, 5).reshape(images.shape[0], self.num_patches, -1)

    def forward(self, x):
        tokens = self.patch_encoder(self.extract_patches(x))
        for block in self.transformer_blocks:
            tokens = block(tokens)
        flat = self.dropout(self.ln(tokens).flatten(1))
        return self.output(self.mlp_head(flat))

## Prepare Dataset and DataLoaders

In [None]:
class ObjectDetectionDataset(Dataset):
    """In-memory dataset that pairs each image with its bounding box.

    Args:
        images: Indexable collection of HWC numpy images, converted with ``ToTensor``.
        boxes: Indexable collection of numpy box vectors, returned as float32 tensors.
    """

    def __init__(self, images, boxes):
        self.images, self.boxes = images, boxes
        # ToTensor turns an HWC uint8 array into a CHW float tensor in [0, 1].
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return (
            self.transform(self.images[idx]),
            torch.from_numpy(self.boxes[idx]).float(),
        )

# Wrap the train/test arrays in datasets, then batch them; only the training
# loader reshuffles every epoch.
train_dataset = ObjectDetectionDataset(x_train, y_train)
test_dataset = ObjectDetectionDataset(x_test, y_test)

batch_size = 32
train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
)

## Instantiate Model and Train

In [None]:
# Hyperparameters from blog
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [projection_dim * 2, projection_dim]
transformer_layers = 4
mlp_head_units = [2048, 1024, 512, 64, 32]
learning_rate = 0.001
weight_decay = 0.0001
num_epochs = 100
val_fraction = 0.1  # share of the training set held out for early stopping

# Model
model = ViTObjectDetector(
    input_shape=(3, image_size, image_size),  # C, H, W
    patch_size=patch_size,
    num_patches=num_patches,
    projection_dim=projection_dim,
    num_heads=num_heads,
    transformer_units=transformer_units,
    transformer_layers=transformer_layers,
    mlp_head_units=mlp_head_units
).to(device)

# Optimizer and loss: MSE regression on the 4 normalized box coordinates
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
criterion = nn.MSELoss()

# Model-selection data. Early stopping and checkpointing must not look at the
# test set; otherwise the test IoU computed later is optimistically biased.
# Hold out a fixed-seed validation split of the TRAINING set instead.
num_val = max(1, int(val_fraction * len(train_dataset)))
if num_val >= len(train_dataset):
    raise ValueError("Training set is too small to hold out a validation split.")
fit_subset, val_subset = torch.utils.data.random_split(
    train_dataset,
    [len(train_dataset) - num_val, num_val],
    generator=torch.Generator().manual_seed(42),
)
fit_loader = DataLoader(fit_subset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

# Training loop
history = {'loss': [], 'val_loss': []}
best_val_loss = float('inf')
patience = 10  # epochs without validation improvement before stopping
counter = 0

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for imgs, boxes in fit_loader:
        imgs, boxes = imgs.to(device), boxes.to(device)
        optimizer.zero_grad()
        preds = model(imgs)
        loss = criterion(preds, boxes)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch value is a per-sample mean.
        train_loss += loss.item() * imgs.size(0)
    train_loss /= len(fit_loader.dataset)
    history['loss'].append(train_loss)

    # Validation on the held-out part of the training set
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for imgs, boxes in val_loader:
            imgs, boxes = imgs.to(device), boxes.to(device)
            preds = model(imgs)
            loss = criterion(preds, boxes)
            val_loss += loss.item() * imgs.size(0)
    val_loss /= len(val_loader.dataset)
    history['val_loss'].append(val_loss)

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Early stopping: keep the best checkpoint, stop after `patience` stale epochs
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "vit_object_detector.pth")
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping")
            break

# Plot history
plt.plot(history['loss'], label='Train Loss')
plt.plot(history['val_loss'], label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('MSE loss')
plt.legend()
plt.show()

## Evaluation: IoU Calculation and Visualization

In [None]:
def bounding_box_iou(box_pred, box_true):
    """Return the intersection-over-union of two axis-aligned boxes.

    Both boxes are [x1, y1, x2, y2] (top-left, bottom-right) in the same units,
    which are normalized coordinates in this notebook. No +1 pixel correction is applied.
    Returns 0 when the union is not positive.
    """
    inter_w = max(0, min(box_pred[2], box_true[2]) - max(box_pred[0], box_true[0]))
    inter_h = max(0, min(box_pred[3], box_true[3]) - max(box_pred[1], box_true[1]))
    intersection = inter_w * inter_h

    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    union = _area(box_pred) + _area(box_true) - intersection
    if union > 0:
        return intersection / union
    return 0

# Restore the checkpoint the training loop saved at its lowest validation loss
model.load_state_dict(torch.load("vit_object_detector.pth"))
model.eval()

# Predict the box for the first test image
sample_img, sample_box = test_dataset[0]
with torch.no_grad():
    batch = sample_img.unsqueeze(0).to(device)
    pred_box = model(batch).cpu().squeeze(0).numpy()

# Calculate IoU against the ground-truth box
iou = bounding_box_iou(pred_box, sample_box.numpy())
print(f"IoU: {iou:.4f}")

# Visualize
def draw_bbox(img, bbox, color=(0, 255, 0)):
    """Draw a normalized [x1, y1, x2, y2] box onto ``img`` in place.

    The box is scaled by the image width/height and truncated to integer pixels.
    ``color`` is passed unchanged to ``cv2.rectangle``, so it must be given in the
    image's channel order. Line thickness is 2.
    """
    height, width = img.shape[:2]
    scale = np.array([width, height, width, height])
    left, top, right, bottom = (bbox * scale).astype(int)
    cv2.rectangle(img, (left, top), (right, bottom), color, 2)

# The image comes from cv2.imread (BGR), so colors passed to draw_bbox are BGR
# tuples; the result is converted to RGB for matplotlib below.
img = x_test[0].copy()
draw_bbox(img, sample_box.numpy(), (0, 255, 0))  # Green for ground truth
draw_bbox(img, pred_box, (0, 0, 255))  # Red for prediction (BGR order)

plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.title(f"Ground Truth (Green) vs Prediction (Red), IoU: {iou:.4f}")
plt.axis("off")
plt.show()

This notebook replicates the blog's functionality in PyTorch. For full multi-object detection, consider extending to models like DETR. Adjust hyperparameters or add data augmentation for better performance.