<a href="https://colab.research.google.com/github/Chaudhari-Amar/econ8310-assignment-baseball-amar/blob/main/train_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import resnet18

from dataloader import create_dataloader, DATA_PATH


class BoxRegressor(nn.Module):
    """Predict one bounding box per image with a ResNet-18 trunk and an MLP head.

    The output has four values per image, ordered (cx, cy, w, h). A final
    sigmoid squashes every value into (0, 1), so targets are presumably
    normalized box coordinates — TODO confirm against the dataloader.
    """

    def __init__(self):
        super().__init__()
        # weights=None: randomly initialized trunk, no pretrained download
        # (the assignment only permits plain PyTorch).
        trunk = resnet18(weights=None)
        feature_dim = trunk.fc.in_features
        # Remove the ImageNet classifier so the trunk returns pooled features.
        trunk.fc = nn.Identity()
        self.backbone = trunk

        # The layer order determines the state_dict keys (head.0.*, head.2.*).
        # Keep it stable so previously saved weights still load.
        head_layers = [
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 4),
            nn.Sigmoid(),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map an image batch x of shape (B, 3, H, W) to box predictions of shape (B, 4)."""
        return self.head(self.backbone(x))


def _train_one_epoch(
    model: nn.Module,
    loader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: str,
) -> float:
    """Run one optimization pass over ``loader`` and return the mean per-sample loss.

    Each batch is expected to be a mapping with ``"image"`` and ``"target"``
    tensors. Images are presumably (B, 3, H, W) and targets (B, 4) — TODO
    confirm against the dataloader.

    Raises:
        RuntimeError: if ``loader`` yields no samples, so there is nothing to
            average over.
    """
    model.train()
    running = 0.0
    n_seen = 0
    for batch in loader:
        imgs = batch["image"].to(device)

        optimizer.zero_grad()
        preds = model(imgs)
        # Match the target dtype to the predictions (e.g. float64 -> float32)
        # so the loss and its backward pass do not hit a dtype mismatch.
        # This is a no-op when the dtypes already agree.
        targets = batch["target"].to(device=device, dtype=preds.dtype)
        loss = criterion(preds, targets)
        loss.backward()
        optimizer.step()

        batch_n = imgs.size(0)
        # Weight each batch's mean loss by its size, so the epoch average is
        # per sample even when the last batch is smaller.
        running += loss.item() * batch_n
        n_seen += batch_n

    # Divide by the samples actually seen rather than len(dataset). This also
    # works for loaders without __len__ or with drop_last.
    if n_seen == 0:
        raise RuntimeError(
            "Training dataloader yielded no samples; check the data paths."
        )
    return running / n_seen


def train(
    epochs: int = 3,
    batch_size: int = 16,
    lr: float = 1e-3,
    max_frames_per_video: int = 64,
    frame_subsample: int = 1,
    device: Optional[str] = None,
    out_weights: str = "model_weights.pth",
    num_workers: int = 2,
):
    """Train a ``BoxRegressor`` on the baseball videos and save its weights.

    Args:
        epochs: Number of full passes over the training data.
        batch_size: Batch size passed to ``create_dataloader``.
        lr: Adam learning rate.
        max_frames_per_video: Passed through to ``create_dataloader``.
        frame_subsample: Passed through to ``create_dataloader``.
        device: Torch device string. Defaults to ``"cuda"`` when available,
            else ``"cpu"``.
        out_weights: Destination path for the trained ``state_dict``. Missing
            parent directories are created.
        num_workers: DataLoader worker processes (default 2, as before).

    Raises:
        RuntimeError: if the dataloader yields no samples.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Data: videos and annotations both live under DATA_PATH.
    dl = create_dataloader(
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        videos_dir=os.path.join(DATA_PATH, "Raw Videos"),
        ann_dir=os.path.join(DATA_PATH, "Annotations"),
        max_frames_per_video=max_frames_per_video,
        frame_subsample=frame_subsample,
    )

    model = BoxRegressor().to(device)
    criterion = nn.SmoothL1Loss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        epoch_loss = _train_one_epoch(model, dl, criterion, optimizer, device)
        print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}")

    # torch.save does not create directories, so make sure the target folder exists.
    out_dir = os.path.dirname(out_weights)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(model.state_dict(), out_weights)
    print(f"Saved weights to: {out_weights}")


if __name__ == "__main__":
    # Script/notebook entry point: a short training run using train()'s defaults.
    train()

