In [1]:
import sys
import os

# Add the parent directory to sys.path (TODO: it should add the parent of the parent directory).
sys.path.append("..")

import torch
from torch import nn, optim

from torchvision import transforms
from tqdm import tqdm

from models import Res18FPNCEASC  # Adjust as needed
from utils.uavdt_dataloader import get_dataset
from utils.losses import Lnorm, Lamm  # Adjust as needed

In [2]:
def safe_shape(x):
    """Describe the shape of a tensor or of a nested list/tuple of tensors.

    Args:
        x: A ``torch.Tensor``, a list or tuple (nested to any depth) whose
            leaves may be tensors, or any other object.

    Returns:
        ``x.shape`` when ``x`` is a tensor; a list with the recursively
        computed description of every element when ``x`` is a list or tuple
        (tuples are reported as lists as well); ``type(x)`` for anything else.
    """
    if isinstance(x, torch.Tensor):
        return x.shape
    if isinstance(x, (list, tuple)):
        # Recurse element by element so nested containers keep their nesting.
        return list(map(safe_shape, x))
    # Non-tensor leaf: report its type instead of failing.
    return type(x)

In [3]:
# Run setup for this notebook.
# `mode` names the run mode; the other values are "eval" and "test".
mode = "train"

# Settings read by the driver cell below: dataset root, data-loader
# parameters, epoch count, optimizer learning rate, and the model config file.
config = dict(
    root_dir="/home/eyakub/scratch/CEASC_replicate",
    batch_size=1,
    num_workers=4,
    num_epochs=1,
    lr=1e-3,
    config_path="configs/resnet18_fpn_feature_extractor.py",
)

In [4]:
!pwd

/home/eyakub/projects/def-kohitij/eyakub/CEASC_Replicate


In [5]:
if __name__ == "__main__":
    # Single-batch smoke test: build the data loader and the model, fetch one
    # batch, run one forward pass, print the structure of the targets and of
    # every model output, then evaluate the Lamm loss on the soft masks.
    # No backward pass and no optimizer step are performed in this cell.

    # Unpack config
    root_dir = config["root_dir"]
    batch_size = config["batch_size"]
    num_workers = config["num_workers"]
    num_epochs = config["num_epochs"]  # unpacked but not used in this cell
    learning_rate = config["lr"]
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset and loader
    # NOTE(review): the split is hard-coded to "train"; the notebook-level
    # `mode` variable is not consulted here.
    dataloader = get_dataset(
        root_dir=root_dir,
        split="train",
        transform=None,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    # Model
    # Built from the config file named in `config`, with a hard-coded class count.
    model = Res18FPNCEASC(config_path=config["config_path"], num_classes=3)
    model.to(device)
    model.train()  # training mode; the forward pass below also passes stage="train"

    # Optimizer
    # Created here but never stepped in this cell.
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # NOTE: Adam was not used in the paper
    
    # Losses
    l_amm = Lamm()

    # Pull a single batch; with shuffle=True the sample can differ between runs.
    batch = next(iter(dataloader))

    # Only the image tensor is moved to `device`; the annotation fields are
    # kept exactly as the loader returned them.
    # NOTE(review): confirm Lamm copes with boxes that are not on `device`.
    images = batch["image"].to(device)
    targets = {
        "boxes": batch["boxes"],
        "labels": batch["labels"],
        "image_id": batch["image_id"],
        "orig_size": batch["orig_size"],
    }
    # Dump every target field, one entry of targets["boxes"] at a time.
    print("\n🔍 Inspecting `targets` structure:")
    for i in range(len(targets["boxes"])):
        print(f"--- Sample {i} ---")
        print(f"Image ID:         {targets['image_id'][i]}")
        print(f"Original Size:    {targets['orig_size'][i]}")
        print(f"Boxes shape:      {targets['boxes'][i].shape}")  # [N_i, 4]
        print(f"Labels shape:     {targets['labels'][i].shape}")  # [N_i]
        print(f"Boxes:            {targets['boxes'][i]}")
        print(f"Labels:           {targets['labels'][i]}")

    # Forward pass
    # The "train" stage output is unpacked into nine values; the first eight
    # are indexed per FPN level in the loop below and `anchors` is iterated
    # separately.
    outputs = model(images, stage="train")
    (
        cls_outs,
        reg_outs,
        soft_mask_outs,
        sparse_cls_feats_outs,
        sparse_reg_feats_outs,
        dense_cls_feats_outs,
        dense_reg_feats_outs,
        feats,
        anchors,
    ) = outputs

    # Print the shape of each output at every level; safe_shape also copes
    # with entries that are nested lists/tuples or not tensors at all.
    print("\n🔍 Output shapes from model:")
    for i in range(len(cls_outs)):
        print(f"--- FPN Level {i} ---")
        print(f"cls_outs[{i}]:              {safe_shape(cls_outs[i])}")
        print(f"reg_outs[{i}]:              {safe_shape(reg_outs[i])}")
        print(
            f"soft_mask_outs[{i}]:    {safe_shape(soft_mask_outs[i])}"
        )
        print(
            f"sparse_cls_feats[{i}]:      {safe_shape(sparse_cls_feats_outs[i])}"
        )
        print(
            f"sparse_reg_feats[{i}]:      {safe_shape(sparse_reg_feats_outs[i])}"
        )
        print(
            f"dense_cls_feats[{i}]:       {safe_shape(dense_cls_feats_outs[i])}"
        )
        print(
            f"dense_reg_feats[{i}]:       {safe_shape(dense_reg_feats_outs[i])}"
        )
        print(f"feats[{i}]:                 {safe_shape(feats[i])}")
    
    # Anchors are labelled P3, P4, ... (index + 3) in the printout.
    for i, anchor in enumerate(anchors):
        print(f"P{i+3} Anchors shape: {anchor.shape}")

    # NOTE(review): im_dimx/im_dimy are hard-coded. They match the
    # [540, 1024] orig_size shown in this notebook's saved output —
    # presumably width and height in pixels; confirm against Lamm and derive
    # them from the batch if image sizes can vary.
    loss_amm = l_amm(
        soft_mask_outs, targets["boxes"], im_dimx=1024, im_dimy=540
    )  # used the soft masks in this version, might be incorrect
    
    # Lnorm is imported at the top of the notebook, but no `loss_norm` is
    # computed in this cell, hence the disabled print below.
    # print(f"Loss Norm: {loss_norm.item()}")
    print(f"Loss AMM: {loss_amm.item()}")

Lamm init called

🔍 Inspecting `targets` structure:
--- Sample 0 ---
Image ID:         tensor([10218])
Original Size:    tensor([ 540, 1024])
Boxes shape:      torch.Size([62, 4])
Labels shape:     torch.Size([62])
Boxes:            tensor([[577., 143., 603., 184.],
        [615.,  15., 642.,  52.],
        [590.,  53., 613.,  85.],
        [573.,  97., 597., 127.],
        [644.,  18., 668.,  51.],
        [625.,  57., 647.,  86.],
        [604.,  90., 631., 126.],
        [505., 214., 543., 234.],
        [438., 206., 473., 225.],
        [  7.,  52.,  42.,  80.],
        [  0.,  95.,  41., 130.],
        [ 58.,  80.,  94., 103.],
        [ 31., 143.,  78., 173.],
        [ 75., 134., 112., 161.],
        [ 89.,  73., 127.,  99.],
        [ 88., 117., 127., 142.],
        [157., 101., 193., 127.],
        [155., 171., 187., 195.],
        [ 96., 170., 195., 227.],
        [216., 226., 247., 250.],
        [212., 150., 249., 175.],
        [654., 296., 696., 325.],
        [206., 171.

In [None]:
# Debugging aid: print the base classes of Lamm, the type of the `l_amm`
# instance, whether that instance is a torch.nn.Module, and its attribute list.
# Relies on `Lamm` (imported in the first cell) and on `l_amm`, which is
# created in the driver cell above, so that cell must have been run first.
print("Lamm base classes:", Lamm.__bases__)
print("l_amm type:", type(l_amm))
print("isinstance l_amm of nn.Module:", isinstance(l_amm, torch.nn.Module))
print("dir(l_amm):", dir(l_amm))