In [None]:
# Sanity check: overfit the AMM loss (`Lamm`) on a single UAVDT training image.

In [9]:
import sys
import os

# Make the repo root (the parent of this notebooks/ directory) importable so the `models` and `utils` imports below resolve.
sys.path.append("..")

import torch
from torch import nn, optim

from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms
from tqdm import tqdm

from models import Res18FPNCEASC  # Adjust as needed
from utils.uavdt_dataloader import get_dataset
from utils.losses import Lnorm, Lamm  # Adjust as needed

In [10]:
def safe_shape(x):
    """Describe the shape of ``x`` for debug printing.

    Tensors map to their ``torch.Size``. Lists and tuples are walked
    recursively and always come back as a ``list``. Any other value is
    reported by its type.
    """
    if isinstance(x, (list, tuple)):
        return list(map(safe_shape, x))
    if not isinstance(x, torch.Tensor):
        return type(x)
    return x.shape

In [11]:
!pwd

/lustre06/project/6067616/eyakub/CEASC_Replicate/notebooks


In [12]:
# Run configuration. `root_dir` defaults to the original cluster path but can be
# overridden with the UAVDT_ROOT environment variable when running elsewhere.
# NOTE(review): `mode` is not read by any of the cells visible here.
mode = "train"  # Change to "eval" or "test" as needed

config = {
    "root_dir": os.environ.get("UAVDT_ROOT", "/home/eyakub/scratch/CEASC_replicate"),
    "batch_size": 1,
    "num_workers": 4,
    "num_epochs": 1,
    "lr": 1e-2,
    # Relative path: resolved against the notebook's working directory (notebooks/).
    "config_path": "../configs/resnet18_fpn_feature_extractor.py",
}

In [16]:
# Overfit the AMM loss on one fixed training batch. A healthy mask head + loss
# should drive the loss down steadily; a flat or diverging curve points at a bug.


def _print_targets(targets):
    """Print image id, original size, boxes and labels for each sample.

    Args:
        targets: dict with "boxes", "labels", "image_id" and "orig_size" keys,
            each indexable per sample (built from one dataloader batch below).
    """
    print("\n🔍 Inspecting `targets` structure:")
    for i in range(len(targets["boxes"])):
        print(f"--- Sample {i} ---")
        print(f"Image ID:         {targets['image_id'][i]}")
        print(f"Original Size:    {targets['orig_size'][i]}")
        print(f"Boxes shape:      {targets['boxes'][i].shape}")  # [N_i, 4]
        print(f"Labels shape:     {targets['labels'][i].shape}")  # [N_i]
        print(f"Boxes:            {targets['boxes'][i]}")
        print(f"Labels:           {targets['labels'][i]}")


def _print_output_shapes(outputs):
    """Print per-FPN-level shapes of the model's training outputs (debug aid).

    Args:
        outputs: the 9-tuple returned by ``model(images, stage="train")``. The
            first 8 entries are indexed per FPN level; the last is the anchors.

    Uses ``safe_shape`` from an earlier cell.
    """
    names = (
        "cls_outs",
        "reg_outs",
        "soft_mask_outs",
        "sparse_cls_feats",
        "sparse_reg_feats",
        "dense_cls_feats",
        "dense_reg_feats",
        "feats",
    )
    per_level_outputs = outputs[: len(names)]
    anchors = outputs[len(names)]
    print("\n🔍 Output shapes from model:")
    for level in range(len(per_level_outputs[0])):
        print(f"--- FPN Level {level} ---")
        for name, level_outputs in zip(names, per_level_outputs):
            label = f"{name}[{level}]:"
            print(f"{label:<28}{safe_shape(level_outputs[level])}")
    for i, anchor in enumerate(anchors):
        print(f"P{i+3} Anchors shape: {anchor.shape}")


# Unpack config. The optional keys fall back to the values this notebook was
# originally hard-coded with, so the existing config cell keeps working.
root_dir = config["root_dir"]
batch_size = config["batch_size"]
num_workers = config["num_workers"]
learning_rate = config["lr"]
n_iters = config.get("n_iters", 100)
seed = config.get("seed", 0)
debug_shapes = config.get("debug_shapes", False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed before building the loader and model. Otherwise shuffle=True and random
# weight init pick a different image and starting point on every run.
torch.manual_seed(seed)

# Dataset and loader
dataloader = get_dataset(
    root_dir=root_dir,
    split="train",
    transform=None,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

# Model
model = Res18FPNCEASC(config_path=config["config_path"], num_classes=3)
model.to(device)
model.train()

optimizer = optim.SGD(model.parameters(), lr=learning_rate)
l_amm = Lamm()

# One batch, reused for every step: that is what makes this an overfit test.
batch = next(iter(dataloader))
images = batch["image"].to(device)
targets = {
    "boxes": batch["boxes"],
    "labels": batch["labels"],
    "image_id": batch["image_id"],
    "orig_size": batch["orig_size"],
}
_print_targets(targets)

# Lamm takes a single image size for the whole batch. Read it from the first
# sample's orig_size, which is (H, W) (e.g. tensor([540, 1024]) for UAVDT),
# instead of hard-coding 1024x540. Assumes all images in the batch share it.
img_h, img_w = (int(v) for v in targets["orig_size"][0])

writer = SummaryWriter()
losses = []
try:
    progress = tqdm(range(n_iters), desc="Overfit AMM")
    for n in progress:
        optimizer.zero_grad()

        outputs = model(images, stage="train")
        if debug_shapes and n == 0:
            _print_output_shapes(outputs)
        # Output order: cls, reg, soft_mask, sparse_cls_feats, sparse_reg_feats,
        # dense_cls_feats, dense_reg_feats, feats, anchors. Only the soft masks
        # feed the AMM loss.
        soft_mask_outs = outputs[2]

        loss_amm = l_amm(
            soft_mask_outs, targets["boxes"], im_dimx=img_w, im_dimy=img_h
        )
        loss_value = loss_amm.item()
        losses.append(loss_value)
        writer.add_scalar("AMM Loss/overfit", loss_value, n)
        progress.set_postfix(loss=f"{loss_value:.4f}")

        loss_amm.backward()
        optimizer.step()
finally:
    # Flush TensorBoard events even if a step fails part-way through.
    writer.close()

if losses:
    print(
        f"Loss AMM: iter 0 = {losses[0]:.6f}, "
        f"iter {len(losses) - 1} = {losses[-1]:.6f}"
    )
print('Overfit complete')

Lamm init called

🔍 Inspecting `targets` structure:
--- Sample 0 ---
Image ID:         tensor([19962])
Original Size:    tensor([ 540, 1024])
Boxes shape:      torch.Size([21, 4])
Labels shape:     torch.Size([21])
Boxes:            tensor([[ 16., 297.,  61., 328.],
        [ 91., 306., 143., 337.],
        [459., 389., 506., 437.],
        [462., 320., 505., 356.],
        [248., 273., 285., 299.],
        [223., 315., 270., 345.],
        [104., 348., 155., 379.],
        [  2., 444.,  58., 477.],
        [359., 266., 417., 303.],
        [255., 362., 304., 393.],
        [548., 391., 598., 441.],
        [354., 348., 406., 373.],
        [674., 484., 719., 532.],
        [648., 500., 693., 538.],
        [629., 129., 665., 152.],
        [648., 134., 695., 169.],
        [696., 125., 725., 148.],
        [744.,  96., 774., 122.],
        [808.,  46., 837.,  68.],
        [787.,  72., 820.,  97.],
        [532., 132., 558., 165.]])
Labels:           tensor([0, 0, 0, 0, 0, 0, 0, 0, 1,