# notebook for overfitting AMM loss on one image from VisDrone

In [2]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import sys
import os
import pickle

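# Add the parent directory to sys.path so the `models` and `utils` imports below can resolve.
# NOTE(review): the original comment said this "should add parent of parent" — confirm which directory is intended.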
sys.path.append("..")

import torch
from torch import nn, optim

from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms
from tqdm import tqdm

from models import Res18FPNCEASC  # Adjust as needed
from utils.visdrone_dataloader import get_dataset
from utils.losses import Lnorm, Lamm  # Adjust as needed

In [4]:
def safe_shape(x):
    """Describe the shape of a (possibly nested) model output.

    Returns ``x.shape`` for a tensor, a list of recursively computed
    descriptions for a list or tuple, and ``type(x)`` for anything else.
    """
    if isinstance(x, torch.Tensor):
        return x.shape
    if not isinstance(x, (list, tuple)):
        # Not a tensor and not a container: report the type instead.
        return type(x)
    return list(map(safe_shape, x))

In [5]:
# Run setup: a mode selector plus the paths and hyper-parameters for this notebook.
mode = "train"  # Change to "eval" or "test" as needed

# Settings read by the training cell via config[...] lookups.
config = dict(
    root_dir="/home/soroush1/scratch/eecs_project",
    batch_size=1,
    num_workers=4,
    num_epochs=1,
    lr=1e-2,
    config_path="../configs/resnet18_fpn_feature_extractor.py",
)

In [7]:
if __name__ == "__main__":
    # Overfit the AMM loss on a single batch: fetch one batch from the
    # training loader, run `n_iters` SGD steps on it, log the loss to
    # TensorBoard, and pickle soft-mask snapshots from a few iterations.

    # Seed torch so the shuffled "first batch" (and any torch-based random
    # initialisation) is the same on every Restart & Run All. Override by
    # adding a "seed" entry to `config`.
    torch.manual_seed(config.get("seed", 0))

    # Artefacts pickled at the end: the image id plus the mask snapshots.
    vis_dict = {}

    # Unpack config
    root_dir = config["root_dir"]
    batch_size = config["batch_size"]
    num_workers = config["num_workers"]
    num_epochs = config["num_epochs"]
    learning_rate = config["lr"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset and loader
    dataloader = get_dataset(
        root_dir=root_dir,
        split="train",
        transform=None,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    # Model
    model = Res18FPNCEASC(config_path=config["config_path"], num_classes=10)
    model.to(device)
    model.train()

    # Optimizer
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    # Losses
    l_amm = Lamm()

    # Only the first batch is ever used; every iteration below trains on it.
    batch = next(iter(dataloader))

    images = batch["image"].to(device)
    # Targets are passed through as delivered by the loader (not moved to `device`).
    targets = {
        "boxes": batch["boxes"],
        "labels": batch["labels"],
        "image_id": batch["image_id"],
        "orig_size": batch["orig_size"],
    }
    # Debug tip: print safe_shape(targets["boxes"]) to inspect the target structure.

    vis_dict["image_id"] = targets["image_id"]

    n_iters = 1000
    # Iterations (0-based) at which the predicted soft masks are snapshotted.
    snapshot_iters = (24, 49, 99)
    # Image dimensions handed to the AMM loss.
    # NOTE(review): presumably the resized network input resolution; confirm against the dataloader.
    im_dimx, im_dimy = 1333, 800

    # Context manager guarantees the event file is flushed and closed even if
    # an iteration raises (e.g. CUDA OOM) or the cell is interrupted.
    with SummaryWriter() as writer:
        for n in range(n_iters):

            optimizer.zero_grad()

            # Forward pass
            outputs = model(images, stage="train")
            (
                cls_outs,
                reg_outs,
                soft_mask_outs,
                sparse_cls_feats_outs,
                sparse_reg_feats_outs,
                dense_cls_feats_outs,
                dense_reg_feats_outs,
                feats,
                anchors,
            ) = outputs
            # Debug tip: print safe_shape(outputs) to inspect the output shapes.

            if n in snapshot_iters:
                # Store detached CPU copies so the pickle can be loaded on a
                # machine without a GPU and does not keep the graph alive.
                vis_dict[f"{n}_mask"] = [
                    s.detach().to("cpu", copy=True) for s in soft_mask_outs
                ]

            loss_amm = l_amm(
                soft_mask_outs, targets["boxes"], im_dimx=im_dimx, im_dimy=im_dimy
            )

            # Read the scalar once; each .item() call forces a device sync.
            loss_value = loss_amm.item()
            print(f"Loss AMM, iter {n}: {loss_value}")

            writer.add_scalar('AMM Loss/overfit', loss_value, n)

            loss_amm.backward()

            optimizer.step()

    print('Overfit complete')

    with open("visdrone_masks.pickle", "wb") as f:
        pickle.dump(vis_dict, f)

Lamm init called
Loss AMM, iter 0: 0.2353719025850296
Loss AMM, iter 1: 0.23240871727466583
Loss AMM, iter 2: 0.22862663865089417
Loss AMM, iter 3: 0.23308204114437103
Loss AMM, iter 4: 0.23013949394226074
Loss AMM, iter 5: 0.222884401679039
Loss AMM, iter 6: 0.21268267929553986
Loss AMM, iter 7: 0.21511700749397278
Loss AMM, iter 8: 0.20486898720264435
Loss AMM, iter 9: 0.20724628865718842
Loss AMM, iter 10: 0.1933915913105011
Loss AMM, iter 11: 0.19243361055850983
Loss AMM, iter 12: 0.1890181005001068
Loss AMM, iter 13: 0.1800224781036377
Loss AMM, iter 14: 0.18674185872077942
Loss AMM, iter 15: 0.18616944551467896
Loss AMM, iter 16: 0.1758638173341751
Loss AMM, iter 17: 0.17470967769622803
Loss AMM, iter 18: 0.1687769889831543
Loss AMM, iter 19: 0.16449229419231415
Loss AMM, iter 20: 0.164544939994812
Loss AMM, iter 21: 0.16062717139720917
Loss AMM, iter 22: 0.14974288642406464
Loss AMM, iter 23: 0.15504346787929535
Loss AMM, iter 24: 0.14641724526882172
Loss AMM, iter 25: 0.1478549