In [2]:
%load_ext autoreload
%autoreload 2

import copy
from pathlib import Path

import cv2
import numpy as np
import tqdm

import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.utils.data

from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer
from overcooked_ai.dataset_types import DetectionDataset
from detectron2.data import MetadataCatalog, DatasetCatalog

SOURCE_DIR = Path("/home/mimic/Overcooked2_1-1_jpeg")
MODELS_DIR = Path("/home/mimic/objdet/overcooked_models")

In [3]:
import torch

TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

torch:  2.6 ; cuda:  cu124
Mon Jun 16 09:59:05 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.133.20             Driver Version: 570.133.20     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3080        Off |   00000000:01:00.0  On |                  N/A |
| 30%   32C    P5             29W /  320W |    2732MiB /  10240MiB |      3%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                     

In [5]:
## Load dataset and existing detectron2 object detection model

rng_seed = 1337
training_subset_ratio = 0.75
dataset_json_path = SOURCE_DIR / "detection_dataset.mar2025.json"

dataset = DetectionDataset.load_from_json(dataset_json_path)

# Update file_name to absolute paths
for entry in dataset.entries:
    entry.file_name = str(SOURCE_DIR / entry.file_name)

# Configure dataset into catalog, merely for loading existing trainer
dataset_dict = dataset.to_dict()["dataset_dict"]
thing_classes = dataset.thing_classes
thing_colors = [(255, 255, 255) for _ in range(len(thing_classes))]

# Split into train and val
np.random.seed(rng_seed)
np.random.shuffle(dataset_dict)
num_training_entries = int(len(dataset_dict) * training_subset_ratio)
train_dataset_dict = dataset_dict[:num_training_entries]
val_dataset_dict = dataset_dict[num_training_entries:]

for d in ["train", "val"]:
    dataset_name = "overcooked_" + d
    if dataset_name in DatasetCatalog:
        DatasetCatalog.remove(dataset_name)
        MetadataCatalog.remove(dataset_name)
    DatasetCatalog.register(dataset_name, lambda d=d: train_dataset_dict if d == "train" else val_dataset_dict)
    MetadataCatalog.get(dataset_name).set(thing_classes=thing_classes, thing_colors=thing_colors)

# Load existing model
cfg = get_cfg()
cfg.merge_from_file(MODELS_DIR / "20250323_225544" / "model.yaml")
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=True)

[32m[06/16 09:59:47 d2.engine.defaults]: [0mModel:
GeneralizedRCNN(
  (backbone): FPN(
    (fpn_lateral2): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelMaxPool()
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )
 

In [21]:
## Build embedding dataset

rcnn = trainer.model

embedding_dataset = []
for entry_dict in tqdm.tqdm(dataset_dict, desc="frame"):
    entry_dict = copy.deepcopy(entry_dict)
    image = cv2.imread(entry_dict["file_name"], cv2.IMREAD_COLOR)  # detectron2 defaults to BGR
    image_tensor = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
    entry_dict["image"] = image_tensor

    H_grid_frame = np.array(entry_dict["H_grid_img"]).reshape(3, 3)
    H_grid_frame /= H_grid_frame[2, 2]
    H_grid_frame = torch.tensor(H_grid_frame, dtype=torch.float64)

    with torch.no_grad():
        images = rcnn.preprocess_image([entry_dict])
        features = rcnn.backbone(images.tensor)

        fpn_flattened_vectors = []
        for key in features:
            x = F.adaptive_avg_pool2d(features[key], (1, 1))
            x = x.view(x.shape[0], -1)
            fpn_flattened_vectors.append(x)

        feature_embedding = torch.cat(fpn_flattened_vectors, dim=1).squeeze(0)
    
    embedding_dataset.append({
        "file_name": entry_dict["file_name"],
        "feature_embedding": feature_embedding.to("cpu"),
        "H_grid_frame": H_grid_frame.to("cpu"),
    })

embedding_dataset_path = MODELS_DIR / "embedding_dataset.pth"
torch.save(embedding_dataset, embedding_dataset_path)
print(f"Saved embedding dataset to {embedding_dataset_path}")

frame: 100%|██████████| 860/860 [00:48<00:00, 17.90it/s]

Saved embedding dataset to /home/mimic/objdet/overcooked_models/embedding_dataset.pth





In [24]:
## Build nn.Module to regress homography matrix from feature embedding

from overcooked_ai.game_maps import world_1_1_tile_short_labels
from overcooked_ai.type_conversions import get_world_1_1_corner_grid_xys


class DifferentiableDLT(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, target_xys: torch.Tensor, source_xys: torch.Tensor) -> torch.Tensor:
        """
        @param target_xys: (B x N x 2) target points
        @param source_xys: (B x N x 2) source points

        @return H_target_source: (B x 3 x 3) homography matrix
        """
        B = source_xys.shape[0]
        assert source_xys.shape == target_xys.shape == (B, 4, 2)

        # Extract individual coordinates, of dim [B, 4]
        source_xs, source_ys = source_xys[:, :, 0], source_xys[:, :, 1]
        target_xs, target_ys = target_xys[:, :, 0], target_xys[:, :, 1]
        zeros = torch.zeros_like(target_xs)
        ones = torch.ones_like(target_xs)

        # Build the A matrix for each correspondence
        A1 = torch.stack([zeros, zeros, zeros, -source_xs, -source_ys, -ones, target_ys * source_xs, target_ys * source_ys, target_ys], dim=2)
        A2 = torch.stack([source_xs, source_ys, ones, zeros, zeros, zeros, -target_xs * source_xs, -target_xs * source_ys, -target_xs], dim=2)
        A = torch.stack([A1, A2], dim=2).reshape(B, 8, 9)  # [B, 8, 9]

        # Compute homography via SVD
        _, _, V = torch.linalg.svd(A)  # [B, 9, 9]
        H_target_source = V[:, -1, :].reshape(B, 3, 3)  # Take last row of V and reshape

        # Normalize so that H[2,2] = 1
        H_target_source = H_target_source / H_target_source[:, 2:3, 2:3]

        return H_target_source


class HomographyNet(nn.Module):
    def __init__(self):
        super().__init__()

        # [1, 4, 2]
        self.corner_grid_xys = torch.tensor(get_world_1_1_corner_grid_xys(), dtype=torch.float64).unsqueeze(0)

        self.corner_frame_xys_regressor = nn.Sequential(
            nn.Linear(1280, 320),
            nn.ReLU(),
            nn.Linear(320, 80),
            nn.ReLU(),
            nn.Linear(80, 8),
        )
        self.dlt = DifferentiableDLT()

    def forward(self, feature_embedding: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        @param feature_embedding: (B, 1280)

        @return predicted_image_xys: (B, 8)
        """
        corners_frame_xys_vec = self.corner_frame_xys_regressor(feature_embedding)  # [B, 8]
        corners_frame_xys = corners_frame_xys_vec.view(-1, 4, 2).to(torch.float64)  # [B, 4, 2]

        corner_grid_xys = self.corner_grid_xys.to(feature_embedding.device)  # [1, 4, 2]
        corner_grid_xys = corner_grid_xys.repeat(corners_frame_xys.shape[0], 1, 1)  # [B, 4, 2]
        
        # NOTE: predict H_grid_frame, to easily apply to frame xys (as predicted by this net and by objdet)
        H_grid_frame = self.dlt(target_xys=corner_grid_xys, source_xys=corners_frame_xys)  # [B, 3, 3]

        return {
            "corners_frame_xys": corners_frame_xys,
            "H_grid_frame": H_grid_frame,
        }


def batched_apply_homography(H_target_source: torch.Tensor, xys_source: torch.Tensor) -> torch.Tensor:
    """Apply homography to source points.

    @param H_target_source: (batch_size x 3 x 3) homography matrix
    @param xys_source: (batch_size x N x 2) source points

    @return xys_target: (batch_size x N x 2) target points
    """

    hxys_source = torch.cat([xys_source, torch.ones_like(xys_source[:, :, :1])], dim=2)  # [batch_size, num_pts, 3]
    hxys_target = H_target_source @ hxys_source.transpose(1, 2)  # [batch_size, 3, num_pts]
    xys_target = hxys_target.transpose(1, 2)[:, :, :2] / hxys_target.transpose(1, 2)[:, :, 2:3]  # [batch_size, num_pts, 2]

    return xys_target

# Test if module works
h_net = HomographyNet().to(device)

feature_embedding = embedding_dataset[0]["feature_embedding"].to(device)

res = h_net(feature_embedding)
corners_frame_xys = res["corners_frame_xys"]
H_grid_frame = res["H_grid_frame"]
print(H_grid_frame)

tensor([[[-2.4178e+02, -1.7869e+02,  7.5200e+01],
         [ 4.1452e+02,  4.7522e+01,  4.2581e+02],
         [ 6.3797e-01,  3.1237e-01,  1.0000e+00]]], device='cuda:0',
       dtype=torch.float64, grad_fn=<DivBackward0>)


In [26]:
## Test homography fitting logic against existing ground truth

from overcooked_ai.grid_homography import apply_homography
from overcooked_ai.type_conversions import get_world_1_1_corner_grid_xys

entry_dict = dataset_dict[0]

H_grid_frame = np.array(entry_dict["H_grid_img"]).reshape(3, 3)
H_grid_frame /= H_grid_frame[2, 2]

print("GT H_grid_frame:\n", H_grid_frame)

# Get some ground truth homography, and map it into frame xys
corner_grid_xys = get_world_1_1_corner_grid_xys()
corner_grid_hxys = np.concatenate([corner_grid_xys, np.ones((4, 1))], axis=1)

corner_frame_hxys = apply_homography(np.linalg.inv(H_grid_frame), corner_grid_hxys)
corner_frame_xys = corner_frame_hxys[:, :2]

dlt = DifferentiableDLT()

T_grid_xys = torch.tensor(corner_grid_xys, dtype=torch.float64).unsqueeze(0)
T_frame_xys = torch.tensor(corner_frame_xys, dtype=torch.float64).unsqueeze(0)

T_H_grid_frame = dlt(target_xys=T_grid_xys, source_xys=T_frame_xys)
print("T_H_grid_frame\n", T_H_grid_frame[0])

print("diff w/ gt:")
print(np.abs(T_H_grid_frame.cpu().detach().numpy() - H_grid_frame))

T_pred_frame_xys = batched_apply_homography(torch.linalg.inv(T_H_grid_frame), T_grid_xys)
print("T_pred_frame_xys\n", T_pred_frame_xys)
print("diff w/ gt:\n", np.abs(T_pred_frame_xys.cpu().detach().numpy() - corner_frame_xys))

GT H_grid_frame:
 [[ 1.29720744e+00  3.49389082e-01 -6.41973130e+02]
 [-7.60501013e-06  1.79799546e+00 -2.65532313e+02]
 [-7.40649707e-06  5.60571219e-04  1.00000000e+00]]
T_H_grid_frame
 tensor([[ 1.2972e+00,  3.4939e-01, -6.4197e+02],
        [-7.6050e-06,  1.7980e+00, -2.6553e+02],
        [-7.4065e-06,  5.6057e-04,  1.0000e+00]], dtype=torch.float64)
diff w/ gt:
[[[9.77662395e-13 2.69784195e-14 4.08704182e-10]
  [4.98995376e-13 1.21680443e-13 2.89901436e-10]
  [4.54703312e-14 1.28442287e-13 0.00000000e+00]]]
T_pred_frame_xys
 tensor([[[ 489.1548,  178.1699],
         [ 351.0563,  735.7790],
         [1644.6209,  730.5700],
         [1496.0181,  177.9636]]], dtype=torch.float64)
diff w/ gt:
 [[[7.34985406e-11 4.51905180e-12]
  [7.63287744e-09 4.27741043e-08]
  [2.20211405e-08 1.07354481e-08]
  [4.26791757e-08 9.96948302e-10]]]


In [31]:
## Build embeddings+homography dataset and dataloader

batch_size = 256

class EmbeddingDataset(torch.utils.data.Dataset):
    def __init__(self, embedding_dataset_path: Path):
        self.embedding_dataset = torch.load(embedding_dataset_path)

    def __len__(self):
        return len(self.embedding_dataset)
    
    def __getitem__(self, idx: int) -> dict:
        return self.embedding_dataset[idx]

dataset = EmbeddingDataset(embedding_dataset_path)

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [training_subset_ratio, 1 - training_subset_ratio])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

next(iter(train_dataloader))["feature_embedding"].shape

torch.Size([256, 1280])

In [41]:
import tqdm.notebook as tqdm

def eval_homography_net(
        homography_net: HomographyNet,
        dataloader: torch.utils.data.DataLoader) -> tuple[float, float]:
    corner_grid_xys = homography_net.corner_grid_xys.to(device)

    total_num_samples = 0
    total_homography_l2_loss = 0.0
    total_reprojection_loss = 0.0

    homography_net.eval()
    with torch.no_grad():
        for batch in dataloader:
            feature_embedding = batch["feature_embedding"].to(device)
            H_grid_frame_gt = batch["H_grid_frame"].to(device)

            result = homography_net(feature_embedding)
            corners_frame_xys_pred = result["corners_frame_xys"]
            H_grid_frame_pred = result["H_grid_frame"]

            homography_l2_loss = F.mse_loss(H_grid_frame_pred, H_grid_frame_gt)
            total_homography_l2_loss += homography_l2_loss.item() * len(batch)

            corners_grid_xys_pred = batched_apply_homography(H_grid_frame_pred, corners_frame_xys_pred)
            reprojection_loss = F.mse_loss(corners_grid_xys_pred, corner_grid_xys)
            total_reprojection_loss += reprojection_loss.item() * len(batch)

            total_num_samples += len(batch)

        total_homography_l2_loss /= total_num_samples
        total_reprojection_loss /= total_num_samples

    homography_net.train()

    return total_homography_l2_loss, total_reprojection_loss


def train_homography_net(
        homography_net: HomographyNet,
        optimizer: torch.optim.Optimizer,
        train_dataloader: torch.utils.data.DataLoader,
        val_dataloader: torch.utils.data.DataLoader,
        num_epochs: int) -> None:
    corner_grid_xys = homography_net.corner_grid_xys.to(device)

    homography_net.train()
    optimizer.zero_grad()

    pbar = tqdm.tqdm(range(num_epochs), desc="epoch")
    for epoch_idx in pbar:
        for batch in train_dataloader:
            feature_embedding = batch["feature_embedding"].to(device)
            H_grid_frame_gt = batch["H_grid_frame"].to(device)

            result = homography_net(feature_embedding)
            corners_frame_xys_pred = result["corners_frame_xys"]
            H_grid_frame_pred = result["H_grid_frame"]

            homography_l2_loss = F.mse_loss(H_grid_frame_pred, H_grid_frame_gt)

            corners_grid_xys_pred = batched_apply_homography(H_grid_frame_pred, corners_frame_xys_pred)
            reprojection_loss = F.mse_loss(corners_grid_xys_pred, corner_grid_xys)

            # TODO: tune hparam weights
            loss = homography_l2_loss + reprojection_loss

            loss.backward()
            optimizer.step()
        
        val_homography_l2_loss, val_reprojection_loss = eval_homography_net(homography_net, val_dataloader)
        pbar.set_description(f"val_H_l2: {val_homography_l2_loss:.4f}, val_reproj: {val_reprojection_loss:.4f}")

homography_net = HomographyNet().to(device)
# TODO: tune lr
optimizer = torch.optim.AdamW(homography_net.parameters(), lr=3e-4)
train_homography_net(homography_net, optimizer, train_dataloader, val_dataloader, num_epochs=1000)
# eval_homography_net(homography_net, val_dataloader)

# TODO: why is H l2 loss so high and increasing during training? bad initialization?

epoch:   0%|          | 0/1000 [00:00<?, ?it/s]

  reprojection_loss = F.mse_loss(corners_grid_xys_pred, corner_grid_xys)
  reprojection_loss = F.mse_loss(corners_grid_xys_pred, corner_grid_xys)
  reprojection_loss = F.mse_loss(corners_grid_xys_pred, corner_grid_xys)
