In [None]:
# Notebook: overfit the AMM loss on a single VisDrone training image.

In [7]:
import sys
import os
import pickle

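# Add the parent of the kernel's working directory to sys.path so `models` and `utils` import.
# NOTE(review): the original note said this "should add parent of parent" — confirm the intended path.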
sys.path.append("..")

import torch
from torch import nn, optim

from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms
from tqdm import tqdm

from models import Res18FPNCEASC  # Adjust as needed
from utils.dataset import get_dataset
from utils.losses import Lnorm, Lamm  # Adjust as needed

In [8]:
def safe_shape(x):
    """Return a printable shape summary of ``x`` for debug output.

    * ``torch.Tensor`` -> its ``.shape`` (a ``torch.Size``)
    * ``list`` / ``tuple`` -> a ``list`` holding the summary of each element (recursive)
    * anything else -> ``type(x)``
    """
    if not isinstance(x, torch.Tensor):
        if isinstance(x, (list, tuple)):
            # Sequences are walked recursively; tuples also come back as lists.
            return list(map(safe_shape, x))
        return type(x)
    return x.shape

In [9]:
# Run configuration: values a reader may want to tune for this experiment.
mode = "train"  # Change to "eval" or "test" as needed. NOTE(review): not read by the visible cells.

config = {
    # Dataset root. Set the CEASC_ROOT environment variable to point elsewhere
    # instead of editing this user-specific default path.
    "root_dir": os.environ.get("CEASC_ROOT", "/home/eyakub/scratch/CEASC_replicate"),
    "batch_size": 1,
    "num_workers": 4,
    "num_epochs": 1,
    "lr": 1e-2,
    "config_path": "../configs/resnet18_fpn_feature_extractor.py",
}

In [10]:
def _print_targets(targets):
    """Print a per-sample summary of a ``targets`` dict.

    Reads the ``"boxes"``, ``"labels"``, ``"image_id"`` and ``"orig_size"`` keys,
    each indexed by sample position within the batch.
    """
    print("\n🔍 Inspecting `targets` structure:")
    for i in range(len(targets["boxes"])):
        print(f"--- Sample {i} ---")
        print(f"Image ID:         {targets['image_id'][i]}")
        print(f"Original Size:    {targets['orig_size'][i]}")
        print(f"Boxes shape:      {targets['boxes'][i].shape}")  # [N_i, 4]
        print(f"Labels shape:     {targets['labels'][i].shape}")  # [N_i]
        print(f"Boxes:            {targets['boxes'][i]}")
        print(f"Labels:           {targets['labels'][i]}")


def _print_output_shapes(outputs):
    """Debug aid: print per-FPN-level shapes of the model's 9-element train-stage output.

    Relies on ``safe_shape`` from the helper cell above.
    """
    (
        cls_outs,
        reg_outs,
        soft_mask_outs,
        sparse_cls_feats_outs,
        sparse_reg_feats_outs,
        dense_cls_feats_outs,
        dense_reg_feats_outs,
        feats,
        anchors,
    ) = outputs
    per_level = [
        ("cls_outs", cls_outs),
        ("reg_outs", reg_outs),
        ("soft_mask_outs", soft_mask_outs),
        ("sparse_cls_feats", sparse_cls_feats_outs),
        ("sparse_reg_feats", sparse_reg_feats_outs),
        ("dense_cls_feats", dense_cls_feats_outs),
        ("dense_reg_feats", dense_reg_feats_outs),
        ("feats", feats),
    ]
    print("\n🔍 Output shapes from model:")
    for i in range(len(cls_outs)):
        print(f"--- FPN Level {i} ---")
        for label, seq in per_level:
            print(f"{label + f'[{i}]:':<28}{safe_shape(seq[i])}")
    # NOTE(review): assumes anchors are listed starting from pyramid level P3.
    for i, anchor in enumerate(anchors):
        print(f"P{i+3} Anchors shape: {anchor.shape}")


if __name__ == "__main__":

    # Tensors collected for later mask visualization; pickled at the end.
    vis_dict = {}

    # Unpack config
    root_dir = config["root_dir"]
    batch_size = config["batch_size"]
    num_workers = config["num_workers"]
    num_epochs = config["num_epochs"]
    learning_rate = config["lr"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Overfit settings (config keys are optional; defaults reproduce the original run).
    n_iters = config.get("n_iters", 100)
    # Iterations whose soft masks are snapshotted; the last one tracks n_iters.
    snapshot_iters = {24, 49, n_iters - 1}
    debug_shapes = False  # set True to print model output shapes on the first iteration
    # Image dimensions handed to Lamm.
    # NOTE(review): hard-coded; confirm they match the coordinate frame of
    # targets["boxes"] (the dataset below is built with transform=None).
    amm_im_dimx, amm_im_dimy = 1333, 800

    # Seed torch's global RNG so the shuffled loader yields the same image every run.
    torch.manual_seed(config.get("seed", 0))

    # Dataset and loader
    dataloader = get_dataset(
        root_dir=root_dir,
        split="train",
        transform=None,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    # Model
    model = Res18FPNCEASC(
        config_path=config["config_path"], num_classes=config.get("num_classes", 10)
    )
    model.to(device)
    model.train()

    # Optimizer
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    # Losses
    l_amm = Lamm()

    # A single batch is reused for every iteration: that is what "overfit" means here.
    batch = next(iter(dataloader))

    images = batch["image"].to(device)
    # NOTE(review): target tensors are left on their loader device; presumably Lamm
    # handles device placement — confirm when running on CUDA.
    targets = {
        "boxes": batch["boxes"],
        "labels": batch["labels"],
        "image_id": batch["image_id"],
        "orig_size": batch["orig_size"],
    }
    _print_targets(targets)

    vis_dict["image_id"] = targets["image_id"]

    writer = SummaryWriter()
    try:
        progress = tqdm(range(n_iters), desc="AMM overfit")
        for n in progress:
            optimizer.zero_grad()

            # Forward pass
            outputs = model(images, stage="train")
            (
                cls_outs,
                reg_outs,
                soft_mask_outs,
                sparse_cls_feats_outs,
                sparse_reg_feats_outs,
                dense_cls_feats_outs,
                dense_reg_feats_outs,
                feats,
                anchors,
            ) = outputs

            if debug_shapes and n == 0:
                _print_output_shapes(outputs)

            if n in snapshot_iters:
                # Copy to CPU so the pickle can be loaded on a machine without CUDA.
                vis_dict[f"{n}_mask"] = [
                    s.detach().to("cpu", copy=True) for s in soft_mask_outs
                ]

            loss_amm = l_amm(
                soft_mask_outs, targets["boxes"], im_dimx=amm_im_dimx, im_dimy=amm_im_dimy
            )
            loss_value = loss_amm.item()  # single host sync per iteration

            progress.set_postfix(loss=f"{loss_value:.6f}")
            writer.add_scalar("AMM Loss/overfit", loss_value, n)

            loss_amm.backward()
            optimizer.step()
    finally:
        # Flush and close the event file even if an iteration raises.
        writer.close()

    print("Overfit complete")

    with open("visdrone_masks.pickle", "wb") as f:
        pickle.dump(vis_dict, f)

Lamm init called

🔍 Inspecting `targets` structure:
--- Sample 0 ---
Image ID:         tensor([5578])
Original Size:    tensor([1050, 1400])
Boxes shape:      torch.Size([103, 4])
Labels shape:     torch.Size([103])
Boxes:            tensor([[ 479.,  922.,  497.,  959.],
        [1372.,  949., 1384.,  985.],
        [1103.,  697., 1108.,  708.],
        [1099.,  690., 1105.,  701.],
        [1071.,  681., 1076.,  692.],
        [1021.,  634., 1024.,  641.],
        [1020.,  636., 1023.,  644.],
        [ 635.,  678.,  639.,  687.],
        [ 623.,  663.,  627.,  674.],
        [ 766.,  640.,  769.,  650.],
        [ 795.,  638.,  799.,  646.],
        [ 799.,  620.,  802.,  628.],
        [ 907.,  551.,  909.,  554.],
        [ 911.,  552.,  913.,  554.],
        [ 914.,  552.,  917.,  554.],
        [ 918.,  552.,  922.,  555.],
        [ 913.,  559.,  916.,  561.],
        [ 923.,  555.,  926.,  557.],
        [ 918.,  559.,  921.,  562.],
        [ 922.,  560.,  925.,  563.],
      