In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

sys.path.append("..")

In [3]:
!module list


Currently Loaded Modules:
  1) CCconfig            7) libfabric/1.18.0      13) StdEnv/2023       ([1;31mS[0m)
  2) gentoo/2023   ([1;31mS[0m)   8) pmix/4.2.4            14) mii/1.1.2
  3) [2mgcccore/.12.3[0m (H)   9) ucc/1.2.0             15) python/3.11.5     ([1;34mt[0m)
  4) gcc/12.3      ([1;34mt[0m)  10) openmpi/4.1.5    ([1;31mm[0m)  16) ipykernel/2025a
  5) hwloc/2.9.1        11) flexiblas/3.3.1       17) scipy-stack/2025a ([1;32mmath[0m)
  6) ucx/1.14.1         12) blis/0.9.0            18) opencv/4.11.0     ([1;34mvis[0m)

  Where:
   [1;31mS[0m:     Module is Sticky, requires --force to unload or purge
   [1;31mm[0m:     MPI implementations / Implémentations MPI
   [1;32mmath[0m:  Mathematical libraries / Bibliothèques mathématiques
   [1;34mt[0m:     Tools for development / Outils de développement
   [1;34mvis[0m:   Visualisation software / Logiciels de visualisation
   H:                Hidden Module

 



In [4]:
import os
# Set before `torch` is imported in a later cell. CUDA_LAUNCH_BLOCKING=1 is
# presumably here to get synchronous kernel launches while debugging, so a
# device-side error is reported at the call that caused it — TODO confirm it is
# still needed, since it slows GPU execution.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [4]:
from models import Res18FPNCEASC  # Adjust as needed
from utils.dataset import get_dataset
from utils.losses import Lnorm, Lamm, LDetection  # Adjust as needed

In [5]:
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
import torchvision.transforms.functional as TF
import torchvision.ops as ops
import torch.nn as nn

In [6]:
def safe_shape(x):
    """Summarise the shape of a (possibly nested) model output.

    Args:
        x: A tensor, a list/tuple of such values (nested arbitrarily), or
            anything else.

    Returns:
        ``x.shape`` for a tensor; a list with one summary per element for a
        list or tuple (tuples come back as lists too); otherwise ``type(x)``.
    """
    # Containers first: recurse element-wise and always hand back a list.
    if isinstance(x, (list, tuple)):
        return list(map(safe_shape, x))
    # Anything that is neither a container nor a tensor is reported by type.
    if not isinstance(x, torch.Tensor):
        return type(x)
    return x.shape

In [7]:
# Settings for this debugging session. `root_dir` is an absolute,
# machine-specific path and has to be adapted when running elsewhere.
config = dict(
    root_dir="/home/soroush1/scratch/eecs_project",
    batch_size=4,
    num_workers=4,
    num_epochs=1,
    lr=1e-3,
    config_path="../configs/resnet18_fpn_feature_extractor.py",
)

# Pull the individual settings out into notebook-level names.
root_dir, batch_size, num_workers, num_epochs, learning_rate = (
    config[key]
    for key in ("root_dir", "batch_size", "num_workers", "num_epochs", "lr")
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training split, unshuffled so the first batch is the same on every run.
dataloader = get_dataset(
    root_dir=root_dir,
    split="train",
    transform=None,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)

num_classes = 10

# Detector under inspection, moved to the chosen device and put in eval mode.
model = Res18FPNCEASC(num_classes=num_classes, config_path=config["config_path"])
model.to(device)
model.eval()

# Loss modules exercised further down in the notebook.
l_norm = Lnorm()
l_amm = Lamm()
l_det = LDetection(num_classes=num_classes, num_bins=16, top_k=9)

# A single batch is enough for the shape and loss checks below.
batch = next(iter(dataloader))

Lamm init called


In [8]:
images = batch["image"].to(device)

# Ground truth for the batch; labels are clamped into [0, num_classes - 1].
targets = dict(
    boxes=batch["boxes"],
    labels=[label.clamp(0, num_classes - 1) for label in batch["labels"]],
    image_id=batch["image_id"],
    orig_size=batch["orig_size"],
)

print("\n🔍 Inspecting `targets` structure:")
for i in range(len(targets["boxes"])):
    sample_boxes = targets["boxes"][i]    # [N_i, 4]
    sample_labels = targets["labels"][i]  # [N_i]
    # One print call per sample; every field still lands on its own line.
    print(
        f"--- Sample {i} ---",
        f"Image ID:         {targets['image_id'][i]}",
        f"Original Size:    {targets['orig_size'][i]}",
        f"Boxes shape:      {sample_boxes.shape}",
        f"Labels shape:     {sample_labels.shape}",
        f"Boxes:            {sample_boxes}",
        f"Labels:           {sample_labels}",
        sep="\n",
    )


🔍 Inspecting `targets` structure:
--- Sample 0 ---
Image ID:         tensor([0])
Original Size:    tensor([540, 960])
Boxes shape:      torch.Size([82, 4])
Labels shape:     torch.Size([82])
Boxes:            tensor([[708., 471., 782., 504.],
        [639., 425., 700., 471.],
        [594., 399., 658., 450.],
        [562., 390., 623., 428.],
        [540., 372., 605., 405.],
        [514., 333., 582., 368.],
        [501., 317., 565., 348.],
        [501., 299., 546., 327.],
        [489., 284., 537., 311.],
        [463., 262., 511., 291.],
        [458., 252., 507., 274.],
        [448., 242., 493., 262.],
        [442., 230., 491., 249.],
        [439., 214., 484., 235.],
        [429., 208., 471., 227.],
        [420., 199., 463., 219.],
        [398., 188., 439., 206.],
        [ 46., 391.,  60., 417.],
        [421., 433., 495., 477.],
        [369., 346., 433., 380.],
        [398., 410., 470., 456.],
        [394., 393., 464., 429.],
        [377., 364., 448., 402.],
        

In [11]:
# Forward pass
# NOTE(review): the model was switched to eval() when it was built, yet it is
# called with stage="train" here — confirm this combination is intended.
outputs = model(images, stage="train")
# Positional unpacking of the nine values returned by the model. Only one
# soft-mask entry (`soft_mask_outs`) is bound here; no separate cls/reg
# soft-mask names are created by this cell.
(
    cls_outs,
    reg_outs,
    soft_mask_outs,
    sparse_cls_feats_outs,
    sparse_reg_feats_outs,
    dense_cls_feats_outs,
    dense_reg_feats_outs,
    feats,
    anchors,
) = outputs

In [None]:
# Report the shape of every per-level output of the forward pass.
#
# The forward pass binds a single `soft_mask_outs` list; the names
# `cls_soft_mask_outs` / `reg_soft_mask_outs` previously used here are not
# defined anywhere and raised NameError on a fresh kernel.
per_level_outputs = {
    "cls_outs": cls_outs,
    "reg_outs": reg_outs,
    "soft_mask_outs": soft_mask_outs,
    "sparse_cls_feats": sparse_cls_feats_outs,
    "sparse_reg_feats": sparse_reg_feats_outs,
    "dense_cls_feats": dense_cls_feats_outs,
    "dense_reg_feats": dense_reg_feats_outs,
    "feats": feats,
}

print("\n🔍 Output shapes from model:")
for i in range(len(cls_outs)):
    print(f"--- FPN Level {i} ---")
    for name, per_level in per_level_outputs.items():
        label = f"{name}[{i}]:"
        # Pad the label so the shapes line up in one column.
        print(f"{label:<26}{safe_shape(per_level[i])}")

# Anchor levels are labelled P3, P4, ... (index + 3).
for i, anchor in enumerate(anchors):
    print(f"P{i+3} Anchors shape: {anchor.shape}")

In [None]:
# AMM loss, evaluated one image at a time and averaged over the batch.
losses_per_sample = []

for i in range(len(targets["boxes"])):
    boxes = [targets["boxes"][i].to(device)]  # keep list structure
    im_h, im_w = targets["orig_size"][i]
    im_dim = (int(im_w.item()), int(im_h.item()))  # convert to (W, H)

    # Per-sample soft mask for every level, keeping a leading batch axis of 1.
    # Uses `soft_mask_outs`, the only soft-mask name bound by the forward-pass
    # cell; `cls_soft_mask_outs` is undefined and raised NameError.
    # Assumes each level is [B, 1, H, W] — TODO confirm against the model.
    soft_mask_sample = [f[i].unsqueeze(0) for f in soft_mask_outs]

    loss_i = l_amm(soft_mask_sample, boxes, im_dim=im_dim)
    losses_per_sample.append(loss_i)

loss_amm = sum(losses_per_sample) / len(losses_per_sample)
loss_amm

In [None]:
# Norm loss between the dense and sparse classification features.
# Uses `soft_mask_outs`, the only soft-mask name bound by the forward-pass cell;
# `cls_soft_mask_outs` is undefined and raised NameError.
# Each level's mask is repeated 4 times — presumably once per head layer;
# TODO confirm against Lnorm's expected input.
loss_norm = l_norm(dense_cls_feats_outs, [[m] * 4 for m in soft_mask_outs], sparse_cls_feats_outs)
loss_norm

In [None]:
def plot_anchors_on_image(image_tensor, anchors, num_to_plot=100, title="Anchors", color="red"):
    """
    Plots anchor boxes on an image.

    The image is un-normalised the same way `visualize_item` does it before
    being shown. Previously the un-normalised array was computed and then
    thrown away, and the raw tensor was displayed instead.

    Args:
        image_tensor (Tensor | array-like): image of shape (3, H, W) when a
            tensor; anything else is handed to matplotlib unchanged.
        anchors (Tensor | array-like): shape (N, 4), format (x1, y1, x2, y2)
        num_to_plot (int): maximum number of anchor boxes to plot
        title (str): title of the plot
        color (str): edge color of the anchor boxes
    """
    if isinstance(image_tensor, torch.Tensor):
        # CHW tensor -> HWC array for matplotlib.
        img = image_tensor.detach().permute(1, 2, 0).cpu().numpy()
        # Undo per-channel normalisation. These look like the usual ImageNet
        # statistics — assumes the dataset applied them; TODO confirm.
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img = np.clip(img * std + mean, 0, 1)
    else:
        # PIL images / arrays can be displayed directly.
        img = np.asarray(image_tensor)

    if isinstance(anchors, torch.Tensor):
        # detach() so anchors that carry gradients can still be converted.
        anchors_np = anchors.detach().cpu().numpy()
    else:
        anchors_np = np.asarray(anchors)

    fig, ax = plt.subplots(1, figsize=(12, 8))
    ax.imshow(img)
    ax.set_title(title)

    # Draw at most `num_to_plot` boxes (none for a non-positive count).
    for x1, y1, x2, y2 in anchors_np[: max(num_to_plot, 0)]:
        rect = patches.Rectangle(
            (x1, y1),
            x2 - x1,
            y2 - y1,
            linewidth=1,
            edgecolor=color,
            facecolor="none",
        )
        ax.add_patch(rect)

    ax.axis("off")
    plt.show()

# NOTE(review): `ds` is not used by the call below.
ds = dataloader.dataset
# Overlay up to 5000 anchors of the first anchor level on batch image 3
# (indexing with 3 needs a batch of at least four images).
plot_anchors_on_image(images[3], anchors[0], num_to_plot=5000, title="Anchors at FPN Level 0")

In [None]:
def visualize_item(image_tensor, boxes, figsize=(10, 10), save_path="test.png"):
    """
    Draw an image with its bounding boxes and save the figure to disk.

    Args:
        image_tensor (Tensor | PIL.Image | array-like): a (3, H, W) tensor, or
            an image with values in the 0-255 range.
        boxes (Tensor | iterable): boxes in (x1, y1, x2, y2) format.
        figsize (tuple): Figure size
        save_path (str): where the figure is written. Defaults to "test.png",
            the path that used to be hard-coded.
    """
    # Uses the notebook-level `plt` / `patches` imports.
    if isinstance(image_tensor, torch.Tensor):
        # CHW tensor -> HWC array for matplotlib.
        img = image_tensor.detach().permute(1, 2, 0).cpu().numpy()
        # Undo per-channel normalisation. These look like the usual ImageNet
        # statistics — assumes the dataset applied them; TODO confirm.
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img = np.clip(img * std + mean, 0, 1)
    else:
        # Non-tensor input: scale 0-255 values to 0-1. This branch used to read
        # an undefined `sample` variable and raised NameError.
        img = np.array(image_tensor) / 255.0

    if isinstance(boxes, torch.Tensor):
        boxes = boxes.detach().cpu().numpy()

    fig, ax = plt.subplots(1, figsize=figsize)
    ax.imshow(img)

    # Every box is outlined in red.
    for box in boxes:
        x1, y1, x2, y2 = (float(v) for v in box)
        rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor="red", facecolor="none"
        )
        ax.add_patch(rect)

    ax.axis("off")
    fig.tight_layout()

    # Save instead of show, then release the figure.
    fig.savefig(save_path)
    plt.close(fig)

# Render batch image 3 with its ground-truth boxes; the figure is written to
# test.png rather than displayed inline.
visualize_item(images[3], targets["boxes"][3])

In [None]:
# Show the ground-truth boxes of the first sample as the cell output.
targets["boxes"][0]

In [None]:
from mmdet.models.task_modules.prior_generators import AnchorGenerator
import torch

# Stand-alone experiment: generate mmdet anchors for a fixed 765 x 1360 (H x W) input.
img_height, img_width = 765, 1360
# NOTE(review): this rebinds the notebook-wide `device` from a torch.device to a
# plain string.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Five levels; 3 ratios x 2 scales are requested per level.
anchor_generator = AnchorGenerator(
    strides=[4, 8, 16, 32, 64],
    ratios=[0.5, 1.0, 2.0],
    scales=[8, 16],
    base_sizes=[16, 32, 64, 128, 256],
)

# Fix: extract stride from tuple
# Each stride entry is indexed as a pair. NOTE(review): the element order of
# the pair (w/h vs h/w) is not visible here — harmless while both entries are
# equal, confirm against mmdet. Floor division is used for the map sizes.
feature_map_sizes = [
    (img_height // s[0], img_width // s[1]) for s in anchor_generator.strides
]

# Generate anchors in image space
multi_level_anchors = anchor_generator.grid_priors(
    featmap_sizes=feature_map_sizes,
    dtype=torch.float32,
    device=device
)

# NOTE(review): `all_anchors` is assigned again in a later cell from the
# model's own anchors.
all_anchors = torch.cat(multi_level_anchors, dim=0)

print("Total anchors:", all_anchors.shape)
print("Sample anchors:\n", all_anchors[:5])


# Implementing Detection loss

In [16]:
class ATSSMatcher:
    """Assign ground-truth boxes to anchors with adaptive (ATSS-style) thresholds.

    For every ground-truth box:
      1. on each pyramid level, the ``top_k`` anchors whose centres are closest
         to the box centre become candidates;
      2. the IoU threshold is mean + std of the candidates' IoUs with the box;
      3. candidates whose IoU reaches the threshold and whose centre lies
         inside the box are positives.
    An anchor that is positive for several boxes is given to the box it
    overlaps most.

    Changes from the previous version: candidates are taken per level (as the
    ``top_k`` comment always stated) instead of over all levels at once;
    positives are restricted to the candidates; levels with fewer than
    ``top_k`` anchors no longer crash ``torch.topk``; a single candidate no
    longer yields a NaN threshold; overlapping boxes are resolved by highest
    IoU instead of "last box wins"; and ``device=0`` is no longer ignored.
    """

    def __init__(self, top_k=9):
        self.top_k = top_k  # number of anchors to select per level

    def __call__(self, anchors_per_level, gt_boxes, device=None):
        """
        Args:
            anchors_per_level: List[Tensor[N_i, 4]] in (x1, y1, x2, y2) format,
                one tensor per pyramid level.
            gt_boxes: Tensor[M, 4] in (x1, y1, x2, y2) format.
            device: optional device both inputs are moved to before matching.
        Returns:
            matched_idxs: Tensor[N_total] (long) with GT index or -1
            max_ious: Tensor[N_total] (float), IoU with the matched GT box and
                0 for unmatched anchors.
        """
        all_anchors = torch.cat(anchors_per_level, dim=0)  # [N_total, 4]
        if device is not None:
            all_anchors = all_anchors.to(device)
            gt_boxes = gt_boxes.to(device)

        num_anchors = all_anchors.size(0)
        num_gt = gt_boxes.size(0)

        matched_idxs = torch.full((num_anchors,), -1, dtype=torch.long, device=gt_boxes.device)
        max_ious = torch.zeros(num_anchors, dtype=torch.float, device=gt_boxes.device)

        # Nothing to match: every anchor stays unassigned.
        if num_gt == 0 or num_anchors == 0:
            return matched_idxs, max_ious

        # IoU between all anchors and GTs.
        ious = ops.box_iou(all_anchors, gt_boxes)  # [N_total, M]

        anchor_centers = (all_anchors[:, :2] + all_anchors[:, 2:]) / 2  # [N, 2]
        gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2  # [M, 2]
        level_sizes = [level.size(0) for level in anchors_per_level]

        for gt_idx in range(num_gt):
            # Distance from this GT centre to every anchor centre.
            distances = torch.norm(anchor_centers - gt_centers[gt_idx][None, :], dim=1)  # [N]

            candidates = self._select_candidates(distances, level_sizes)
            if candidates.numel() == 0:
                continue

            cand_ious = ious[candidates, gt_idx]
            # Adaptive threshold; std() of one value is NaN, so it is only
            # added when there are at least two candidates.
            dynamic_thresh = cand_ious.mean()
            if cand_ious.numel() > 1:
                dynamic_thresh = dynamic_thresh + cand_ious.std()

            inside_gt = self.anchor_inside_box(all_anchors[candidates], gt_boxes[gt_idx])
            keep = (cand_ious >= dynamic_thresh) & inside_gt

            pos_indices = candidates[keep]
            pos_ious = cand_ious[keep].to(max_ious.dtype)

            # Only take over an anchor that is still free or that this box
            # overlaps more than its current match.
            better = (matched_idxs[pos_indices] < 0) | (pos_ious > max_ious[pos_indices])
            matched_idxs[pos_indices[better]] = gt_idx
            max_ious[pos_indices[better]] = pos_ious[better]

        return matched_idxs, max_ious

    def _select_candidates(self, distances, level_sizes):
        """Return flat indices of the ``top_k`` closest anchors on each level."""
        picked = []
        start = 0
        for size in level_sizes:
            k = min(self.top_k, size)  # a level may hold fewer than top_k anchors
            if k > 0:
                local = torch.topk(distances[start:start + size], k, largest=False).indices
                picked.append(local + start)  # level-local -> flat index
            start += size
        if not picked:
            return torch.empty(0, dtype=torch.long, device=distances.device)
        return torch.cat(picked)

    def anchor_inside_box(self, anchors, gt_box):
        """
        Return a mask of anchors whose center is inside the GT box.
        """
        cx = (anchors[:, 0] + anchors[:, 2]) / 2
        cy = (anchors[:, 1] + anchors[:, 3]) / 2

        return (
            (cx >= gt_box[0]) & (cx <= gt_box[2]) &
            (cy >= gt_box[1]) & (cy <= gt_box[3])
        )

In [44]:
matcher = ATSSMatcher(top_k=9)
matched_idxs, max_ious = [], []

# Match each image's ground-truth boxes against the anchors returned by the model.
for batch_id, sample_gt_boxes in enumerate(targets["boxes"]):
    matched_idx, iou = matcher(anchors, sample_gt_boxes, device=device)
    matched_idxs += [matched_idx]
    max_ious += [iou]

# Both results hold one entry per anchor.
matched_idxs[0].shape, max_ious[0].shape

(torch.Size([534138]), torch.Size([534138]))

In [45]:
# Flatten the per-level anchors; targets below use this flat anchor indexing.
all_anchors = torch.cat(anchors, dim=0)  # [N_total, 4]
num_total_anchors = all_anchors.shape[0]
batch_size = len(targets["boxes"])

# NOTE(review): unmatched anchors keep the value 0, which is also a valid label
# after the labels were clamped to [0, num_classes - 1] — confirm how the loss
# separates background from class 0.
cls_targets = torch.zeros(batch_size, num_total_anchors, dtype=torch.long, device=device)

for batch_idx in range(batch_size):
    sample_matches = matched_idxs[batch_idx]
    pos_mask = sample_matches >= 0
    if not pos_mask.any():
        continue  # no positive anchors for this image
    # Positive anchors take the class label of the GT box they were matched to.
    gt_indices = sample_matches[pos_mask]
    target = targets["labels"][batch_idx].to(device)
    cls_targets[batch_idx, pos_mask] = target[gt_indices]

cls_targets.shape

torch.Size([4, 534138])

In [46]:
# Per-anchor IoU target: IoU with the matched GT box, 0 for unmatched anchors.
iou_targets = torch.zeros(batch_size, num_total_anchors, dtype=torch.float, device=device)

for batch_idx in range(batch_size):
    pos_mask = matched_idxs[batch_idx] >= 0
    if not pos_mask.any():
        continue  # nothing matched for this image
    sample_ious = max_ious[batch_idx]
    iou_targets[batch_idx, pos_mask] = sample_ious[pos_mask]

iou_targets.shape

torch.Size([4, 534138])

In [47]:
# Per-anchor regression target: coordinates of the matched GT box, zeros for
# unmatched anchors.
reg_targets = torch.zeros(batch_size, num_total_anchors, 4, dtype=torch.float, device=device)

for batch_idx in range(batch_size):
    sample_matches = matched_idxs[batch_idx]
    pos_mask = sample_matches >= 0
    if not pos_mask.any():
        continue  # nothing matched for this image
    gt_indices = sample_matches[pos_mask]
    bbox = targets["boxes"][batch_idx].to(device)
    # Gather the matched box for every positive anchor.
    gt_boxes_matched = bbox[gt_indices]
    reg_targets[batch_idx, pos_mask] = gt_boxes_matched

reg_targets.shape

torch.Size([4, 534138, 4])

In [48]:
# Prepare Model Predictions
# Flattens every level's classification map to [B, positions, C] and joins the
# levels along the position axis.
# NOTE(review): the recorded output gives 89023 positions in total, while the
# matcher cell reports 534138 anchors (6x as many) — predictions and the
# anchor-indexed targets cannot be paired element-wise until this is reconciled.

batch_size = cls_outs[0].shape[0]
all_cls_preds = []

for level_idx, cls_pred in enumerate(cls_outs):
    # [B, C, H, W] -> [B, H, W, C] -> [B, H*W*A, C]
    # NOTE(review): the last axis keeps the full channel count C, so the "A"
    # factor above only appears if C is per-anchor — confirm the head layout.
    B, C, H, W = cls_pred.shape
    cls_pred = cls_pred.permute(0, 2, 3, 1).reshape(B, -1, C)
    print(f"{cls_pred.size() = }")
    all_cls_preds.append(cls_pred)

cls_preds = torch.cat(all_cls_preds, dim=1)  # [B, N, C]

cls_preds.size()

cls_pred.size() = torch.Size([4, 66800, 10])
cls_pred.size() = torch.Size([4, 16700, 10])
cls_pred.size() = torch.Size([4, 4200, 10])
cls_pred.size() = torch.Size([4, 1050, 10])
cls_pred.size() = torch.Size([4, 273, 10])


torch.Size([4, 89023, 10])