In [24]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt 
from tqdm import tqdm
import time 

# One seed shared by every RNG this notebook relies on, so later random
# operations are repeatable after a kernel restart.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Using device: {device}")



Using device: cpu


In [25]:
# Load the pre-generated synthetic tracking dataset from disk.
print("Loading dataset...")
# SECURITY NOTE: allow_pickle=True lets NumPy unpickle object arrays, and
# unpickling can execute arbitrary code. Only load this file from a trusted
# source.
data = np.load('../data/synthetic_tracking_dataset.npz', allow_pickle=True)
videos = data['videos']
positions = data['positions']

# Fail fast on a corrupt or mismatched file: every video needs its own
# position track, otherwise the error only appears later during batching.
if len(videos) != len(positions):
    raise ValueError(
        f"Dataset mismatch: {len(videos)} videos but "
        f"{len(positions)} position tracks"
    )

print(f"✓ Loaded {len(videos)} videos")
print(f"✓ Video shape: {videos.shape}")
print(f"✓ Each video has {videos.shape[1]} frames")

Loading dataset...
✓ Loaded 1000 videos
✓ Video shape: (1000, 50, 32, 32)
✓ Each video has 50 frames


In [26]:
class TrackingDataset(Dataset):
    """
    Dataset of consecutive-frame pairs for the tracking task.

    Each sample stacks frame ``t`` and frame ``t + 1`` of one video and is
    labelled with the position recorded for frame ``t + 1``.

    Args:
        videos: Array indexed as ``videos[video, frame]``; ``videos.shape[1]``
            is taken as the number of frames per video.
        positions: Per-video sequence of per-frame positions, indexed as
            ``positions[video][frame]``.
    """
    def __init__(self, videos, positions):
        self.videos = videos
        self.positions = positions

    def _pairs_per_video(self):
        """Number of (frame, next frame) pairs one video yields."""
        return self.videos.shape[1] - 1

    def __len__(self):
        """Total number of frame pairs across all videos."""
        return len(self.videos) * self._pairs_per_video()

    def __getitem__(self, idx):
        """
        Return ``(input_frames, target_pos)`` for the flat sample index ``idx``.

        ``input_frames`` is a float32 tensor holding the two frames stacked on
        a new leading axis; ``target_pos`` is a float32 tensor holding the
        position stored for the second frame.
        """
        pairs = self._pairs_per_video()
        # The flat index enumerates every pair of video 0, then video 1, ...
        video_idx = idx // pairs
        frame_idx = idx % pairs

        # np.stack allocates a fresh buffer, so the tensor below never aliases
        # the dataset's own storage.
        frame_pair = np.stack(
            [
                self.videos[video_idx, frame_idx],
                self.videos[video_idx, frame_idx + 1],
            ],
            axis=0,
        )
        # np.array copies for the same reason.
        target_pos = np.array(self.positions[video_idx][frame_idx + 1])

        # torch.as_tensor with an explicit dtype replaces the legacy
        # torch.FloatTensor(...) constructor.
        input_frames = torch.as_tensor(frame_pair, dtype=torch.float32)
        target_pos = torch.as_tensor(target_pos, dtype=torch.float32)

        return input_frames, target_pos

print("Creating dataset...")
full_dataset = TrackingDataset(videos, positions)

# 80/20 train/validation split over individual frame pairs.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# A dedicated, seeded generator makes the split identical every time this cell
# runs, regardless of how much of the global RNG earlier cells have consumed.
# NOTE(review): the split is over frame pairs, not whole videos, so pairs from
# the same video can land in both subsets -- consider splitting by video if
# validation numbers look optimistic.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

print(f"✓ Training samples: {len(train_dataset)}")
print(f"✓ Validation samples: {len(val_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Pull one batch to sanity-check the tensor shapes the model will receive.
test_input, test_target = next(iter(train_loader))
# Fixed: the original printed a literal "n" where a newline escape was meant.
print(f"\n✓ Batch input shape: {test_input.shape}")
print(f"✓ Batch target shape: {test_target.shape}")
        

    

Creating dataset...
✓ Training samples: 39200
✓ Validation samples: 9800
n✓ Batch input shape: torch.Size([32, 2, 32, 32])
✓ Batch target shape: torch.Size([32, 2])


In [27]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size``
    produces one ``embed_dim``-dimensional vector per patch.
    """
    def __init__(self, img_size=32, patch_size=4, in_channels=2, embed_dim=64):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2
        self.patch_size = patch_size

        # Convolutional layer that extracts and embeds the patches in one step.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Turn a batched (B, C, H, W) input into (B, patches, embed_dim) tokens."""
        feature_map = self.proj(x)        # (B, embed_dim, H', W')
        tokens = feature_map.flatten(2)   # (B, embed_dim, H' * W')
        return tokens.transpose(1, 2)     # (B, H' * W', embed_dim)

class StandardAttention(nn.Module):
    """Standard multi-head self-attention.

    Args:
        dim: Embedding size of each token; must be divisible by ``num_heads``.
        num_heads: Number of attention heads the embedding is split across.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """
    def __init__(self, dim, num_heads=4, dropout=0.1):
        super().__init__()
        # Reject bad configurations here; otherwise they only surface as an
        # opaque reshape error on the first forward pass.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # 1 / sqrt(head_dim): scaling factor for the dot-product scores.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)  # joint query/key/value projection
        self.proj = nn.Linear(dim, dim)     # output projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, N, C); returns (B, N, C)."""
        B, N, C = x.shape  # unpack the input shape

        # Project once, then split into query/key/value and heads:
        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # Pairwise token similarity per head -> attention scores (B, heads, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)  # convert scores to probabilities
        attn = self.dropout(attn)    # regularise the attention weights

        # Weight the values, then merge the heads back into one embedding.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

class TransformerBlock(nn.Module):
    """One transformer layer: self-attention followed by an MLP.

    Each sub-layer is preceded by its own LayerNorm and wrapped in a residual
    connection.
    """
    def __init__(self, dim, num_heads=4, mlp_ratio=4, dropout=0.1):
        super().__init__()
        # Width of the MLP's hidden layer, as a multiple of the token width.
        hidden_features = int(dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(dim)
        self.attn = StandardAttention(dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_features),
            nn.GELU(),  # activation function
            nn.Dropout(dropout),
            nn.Linear(hidden_features, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Run attention then the MLP, adding each result back onto ``x``."""
        attended = self.attn(self.norm1(x))
        x = x + attended

        transformed = self.mlp(self.norm2(x))
        return x + transformed

# Status message: reaching this line means the class definitions above ran
# without raising.
print("✓ Model components defined!")

✓ Model components defined!


In [28]:
class VisionTransformerTracker(nn.Module):
    """
    Vision Transformer for tracking.

    Consumes a pair of frames stacked along the channel axis and regresses a
    two-value (x, y) position.
    """

    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=2,
        embed_dim=64,
        depth=4,  # number of transformer blocks
        num_heads=4,
        mlp_ratio=4,
        dropout=0.1
    ):
        super().__init__()

        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )

        # Learnable positional embedding, one vector per patch, starting at zero.
        pos_shape = (1, self.patch_embed.num_patches, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(*pos_shape))
        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        )

        # Normalises the token features ahead of the prediction head.
        self.norm = nn.LayerNorm(embed_dim)

        head_hidden = embed_dim // 2
        self.head = nn.Sequential(
            nn.Linear(embed_dim, head_hidden),  # halve the feature width
            nn.GELU(),                          # add non-linearity
            nn.Linear(head_hidden, 2),          # two outputs: the coordinates
        )

    def forward(self, x):
        """Predict a position for each sample in the batch ``x``."""
        # Patch tokens plus their positional embedding.
        tokens = self.dropout(self.patch_embed(x) + self.pos_embed)

        for layer in self.blocks:
            tokens = layer(tokens)

        # Normalise, then average over the patch axis to get one vector each.
        pooled = self.norm(tokens).mean(dim=1)

        return self.head(pooled)

print("Creating model...")
model = VisionTransformerTracker(
    img_size=32,
    patch_size=4,
    embed_dim=64,
    depth=4,
    num_heads=4
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"✓ Model created with {total_params:,} parameters")

# Smoke test: push a random batch through the model to confirm the output
# shape. Gradients are not needed for a shape check, so skip building the
# autograd graph.
test_input = torch.randn(4, 2, 32, 32).to(device)
with torch.no_grad():
    test_output = model(test_input)
print(f"✓ Test output shape: {test_output.shape}")


Creating model...
✓ Model created with 208,418 parameters
✓ Test output shape: torch.Size([4, 2])
