<a href="https://colab.research.google.com/github/Chaudhari-Amar/econ8310-assignment-baseball-amar/blob/main/Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
from torchvision.transforms import functional as TF
import torchvision

from train_model import BoxRegressor


def load_model(weights_path: str = "model_weights.pth", device: str = None):
    """Build a ``BoxRegressor`` and load trained weights for inference.

    Args:
        weights_path: Path to a checkpoint file containing the model's
            ``state_dict``.
        device: Device string to place the model on. When omitted (or any
            falsy value), ``"cuda"`` is chosen if CUDA is available and
            ``"cpu"`` otherwise.

    Returns:
        A ``(model, device)`` tuple: the model, moved to ``device`` and
        switched to eval mode, together with the resolved device string.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model = BoxRegressor().to(device)
    # The checkpoint path comes from the caller/CLI, and a plain torch.load
    # unpickles arbitrary objects (i.e. can execute code from a tampered
    # file). Only a state_dict is needed here, so restrict loading to tensors
    # and plain containers. NOTE: `weights_only` requires torch >= 1.13.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # Eval mode: inference-time behaviour for layers such as dropout/batchnorm.
    model.eval()
    return model, device


def _clamp_to_frame(value: int, upper: int) -> int:
    """Clamp ``value`` into the closed pixel range ``[0, upper]``."""
    return min(max(value, 0), upper)


def predict_video_boxes(video_path: str, model: BoxRegressor, device: str):
    """Run the box regressor on every frame of a video.

    The whole video is decoded up front, then each frame is scaled to
    ``[0, 1]``, fed to ``model`` as a batch of one, and the predicted
    ``[cx, cy, w, h]`` box (expected to be normalised to ``[0, 1]``) is
    converted to absolute pixel corners.

    Args:
        video_path: Path to the video file to decode.
        model: Box regressor to run; it is called on a ``(1, 3, H, W)`` float
            tensor and its first output row is read as ``[cx, cy, w, h]``.
        device: Device string the frames are moved to before inference; it
            should match the device the model lives on.

    Returns:
        A list of ``(frame_index, x1, y1, x2, y2)`` tuples, one per frame.
        All four corners are clamped to the frame, i.e. ``0 <= x <= W - 1``
        and ``0 <= y <= H - 1``, even when the prediction falls outside it.
    """
    # Only the frame tensor is needed; audio and metadata are discarded.
    video, _, _ = torchvision.io.read_video(video_path, pts_unit="sec")  # (T,H,W,C)
    _, height, width, _ = video.shape
    max_x, max_y = width - 1, height - 1

    preds = []
    with torch.no_grad():
        for t, raw_frame in enumerate(video):
            # (H,W,C) pixel values -> (1,3,H,W) float in [0,1]
            frame = raw_frame.permute(2, 0, 1).float() / 255.0
            frame = frame.unsqueeze(0).to(device)
            cx, cy, bw, bh = model(frame)[0].cpu().tolist()
            # Convert to absolute pixel corners for readability. Every corner
            # is clamped on BOTH sides: a prediction that drifts out of frame
            # would otherwise produce x1 > W-1 or a negative x2.
            x1 = _clamp_to_frame(int((cx - bw / 2) * width), max_x)
            y1 = _clamp_to_frame(int((cy - bh / 2) * height), max_y)
            x2 = _clamp_to_frame(int((cx + bw / 2) * width), max_x)
            y2 = _clamp_to_frame(int((cy + bh / 2) * height), max_y)
            preds.append((t, x1, y1, x2, y2))
    return preds


if __name__ == "__main__":
    # Command-line entry point: print one predicted box per video frame.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("video", type=str, help="Path to a .mov/.mp4 video")
    cli.add_argument("--weights", type=str, default="model_weights.pth")
    options = cli.parse_args()

    # Load the trained regressor once, then run it over the whole video.
    net, run_device = load_model(options.weights)
    for t, x1, y1, x2, y2 in predict_video_boxes(options.video, net, run_device):
        # Tab-separated so the output is easy to post-process.
        print(f"frame={t}\tbox=({x1},{y1},{x2},{y2})")
