# ConvNeXt-based v-CLR (VOC → Non-VOC) – Fast Training, AMP, Partial Fine-tuning

This notebook implements a ConvNeXt-based v-CLR training pipeline with:

- Precomputed `train_items` (no COCO object inside the Dataset) so DataLoader workers work on Windows.
- Random selection of depth vs. stylized view per step (always with the natural view).
- Mixed precision (AMP) for faster GPU training and lower memory.
- ConvNeXt-tiny backbone initialized from ImageNet and partially fine-tuned (last few stages unfrozen).


## v-CLR Paper Alignment Notes

This implementation attempts to recreate the v-CLR (View-Consistent Learning) approach from the paper 
"v-CLR: View-Consistent Learning for Open-World Instance Segmentation" (arXiv:2504.01383) using a CNN 
backbone (ConvNeXt) instead of the original transformer-based MaskDINO architecture.

### Key Differences from Original Paper:

1. **Architecture**: Uses ConvNeXt-tiny backbone instead of MaskDINO/Swin Transformer
2. **Detection Only**: Current implementation focuses on bounding box detection only 
   (masks are not yet implemented - set to `None`)
3. **Dense Prediction**: Uses a dense prediction head instead of DETR-style object queries

### Implementation Corrections Applied:

1. **VOC Class Names**: Fixed to use the correct 20 PASCAL VOC categories in COCO naming:
   - Removed incorrect classes: `truck`, `elephant`, `bear`, `zebra`, `giraffe`
   - Added missing VOC classes: `bottle`, `chair`, `couch`, `potted plant`, `dining table`

2. **GT Target Application**: Fixed training loop to apply ground-truth supervision only to the 
   natural view (as specified in v-CLR). Depth and stylized views are now trained via 
   view-consistency losses (L_obj, L_sim) only, not with GT annotations.

### Known Limitations:

1. **No Mask Prediction**: The original v-CLR paper includes instance segmentation masks. 
   This implementation only produces bounding boxes.

2. **Dense vs Query-based**: The paper uses object queries (like DETR/MaskDINO), while this 
   implementation uses dense per-pixel predictions reshaped as queries.

3. **Temperature Scaling**: The cosine similarity loss could benefit from temperature scaling 
   (common in contrastive learning).


In [1]:
import os
import json
import math
import random
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms as T
from torchvision.ops import box_iou, generalized_box_iou
from torchvision.models import convnext_tiny

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from PIL import Image

# Speed/predictability settings
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available – this notebook expects a GPU.")
device = torch.device("cuda")
print("Using device:", device, torch.cuda.get_device_name(0))


Using device: cuda NVIDIA GeForce RTX 4080


In [2]:
# === Configuration ===

DATA_ROOT = Path(r"C:\workspace\vclr\datasets")

COCO_ROOT = DATA_ROOT
COCO_TRAIN_JSON = COCO_ROOT / "annotations" / "instances_train2017.json"
COCO_VAL_JSON   = COCO_ROOT / "annotations" / "instances_val2017.json"

COCO_TRAIN_IMG_DIR = COCO_ROOT / "train2017"
COCO_VAL_IMG_DIR   = COCO_ROOT / "val2017"

# Extra views
DEPTH_TRAIN_DIR = DATA_ROOT / "train2017_depth_cmap"    # depth as .png
STYLE_TRAIN_DIR = DATA_ROOT / "style_coco_train2017"    # stylized view (if available)
EDGE_TRAIN_DIR  = None                                  # optional

# Non-VOC val + CutLER proposals
NONVOC_VAL_JSON       = DATA_ROOT / "uvo_nonvoc_val_rle.json"
CUTLER_PROPOSALS_JSON = DATA_ROOT / "vCLR_coco_train2017_top5.json"

NONVOC_IMG_DIR = DATA_ROOT / "uvo_videos_dense_frames"

# Training / eval hyperparameters
NUM_EPOCHS         = 8
BASE_LR            = 1e-4
TRAIN_BATCH_SIZE   = 15
VAL_BATCH_SIZE     = 15
TRAIN_NUM_WORKERS  = 0   # safe after we remove COCO from Dataset
VAL_NUM_WORKERS    = 0   # simpler eval
IMG_SIZE           = 800

# v-CLR loss weights
LAMBDA_GT    = 1.0
LAMBDA_OBJ   = 1.0
LAMBDA_SIM   = 1.0
LAMBDA_MATCH = 1.0

# MaskDINO/DINO-style detection and segmentation loss weights (from official configs)
# CLASS_WEIGHT, BOX_WEIGHT, GIOU_WEIGHT, MASK_WEIGHT, DICE_WEIGHT, NO_OBJECT_WEIGHT
LOSS_CLASS_WEIGHT      = 4.0   # classification term
LOSS_BOX_WEIGHT        = 5.0   # box L1 term
LOSS_GIOU_WEIGHT       = 2.0   # (G)IoU term
LOSS_MASK_BCE_WEIGHT   = 5.0   # mask BCE/focal term
LOSS_MASK_DICE_WEIGHT  = 5.0   # mask dice term
LOSS_NO_OBJECT_WEIGHT  = 0.1   # weight for "no object" queries

# Number of object queries per image (DETR / MaskDINO style)
NUM_QUERIES = 300

# VOC 20 classes in COCO naming
# Correct VOC 20 classes mapped to COCO naming conventions
# These are the 20 PASCAL VOC categories as they appear in COCO
VOC_CLASS_NAMES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus",
    "train", "boat",
    "bird", "cat", "dog", "horse", "sheep", "cow",
    "bottle", "chair", "couch", "potted plant", "dining table",
    "tv"
]


In [3]:
# === COCO setup & VOC / Non-VOC split ===

coco_train = COCO(str(COCO_TRAIN_JSON))

name_to_id = {cat["name"]: cat["id"] for cat in coco_train.cats.values()}

voc_cat_ids = []
for name in VOC_CLASS_NAMES:
    if name not in name_to_id:
        raise ValueError(f"VOC class '{name}' not found in COCO categories.")
    voc_cat_ids.append(name_to_id[name])

voc_cat_ids = sorted(set(voc_cat_ids))
all_cat_ids = sorted(coco_train.cats.keys())
nonvoc_cat_ids = [cid for cid in all_cat_ids if cid not in voc_cat_ids]

print("VOC category ids:", voc_cat_ids)
print("Non-VOC category ids ({} total):".format(len(nonvoc_cat_ids)),
      nonvoc_cat_ids[:10], "...")


loading annotations into memory...
Done (t=6.57s)
creating index...
index created!
VOC category ids: [1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 72]
Non-VOC category ids (60 total): [10, 11, 13, 14, 15, 27, 28, 31, 32, 33] ...


In [4]:
# === Load CutLER proposals (COCO xywh -> xyxy for IoU/L1 with detector) ===

with open(CUTLER_PROPOSALS_JSON, "r") as f:
    cutler_data = json.load(f)

proposal_dict: Dict[int, List[Dict[str, torch.Tensor]]] = {}

def _xywh_to_xyxy(boxes_xywh: torch.Tensor) -> torch.Tensor:
    """Convert [x, y, w, h] -> [x1, y1, x2, y2] in the same coordinate system."""
    if boxes_xywh.ndim == 1:
        boxes_xywh = boxes_xywh.unsqueeze(0)
    boxes_xyxy = boxes_xywh.clone()
    boxes_xyxy[:, 2] = boxes_xywh[:, 0] + boxes_xywh[:, 2]  # x2 = x + w
    boxes_xyxy[:, 3] = boxes_xywh[:, 1] + boxes_xywh[:, 3]  # y2 = y + h
    return boxes_xyxy

if isinstance(cutler_data, dict) and "annotations" in cutler_data and "images" in cutler_data:
    # Standard COCO-style proposals: each annotation has bbox=[x,y,w,h] in original pixels
    boxes_by_image = defaultdict(list)
    scores_by_image = defaultdict(list)

    for ann in cutler_data["annotations"]:
        img_id = int(ann["image_id"])
        boxes_by_image[img_id].append(ann["bbox"])               # COCO xywh
        scores_by_image[img_id].append(ann.get("score", 1.0))

    for img_id, boxes_list in boxes_by_image.items():
        boxes_xywh = torch.tensor(boxes_list, dtype=torch.float32)
        boxes_xyxy = _xywh_to_xyxy(boxes_xywh)
        scores = torch.tensor(scores_by_image[img_id], dtype=torch.float32)
        proposal_dict[img_id] = [{"boxes": boxes_xyxy, "scores": scores}]

    print(f"Loaded CutLER-style proposals (COCO JSON) for {len(proposal_dict)} training images.")

elif isinstance(cutler_data, dict):
    # Fallback: dictionary keyed directly by image_id (as string) -> boxes / {boxes,scores}
    skipped_meta_keys: List[str] = []
    for k, v in cutler_data.items():
        try:
            img_id = int(k)
        except ValueError:
            skipped_meta_keys.append(k)
            continue

        if isinstance(v, dict) and "boxes" in v:
            boxes_xywh = torch.tensor(v["boxes"], dtype=torch.float32)
            scores = torch.tensor(
                v.get("scores", [1.0] * len(boxes_xywh)),
                dtype=torch.float32,
            )
        else:
            boxes_xywh = torch.tensor(v, dtype=torch.float32)
            scores = torch.ones(len(boxes_xywh), dtype=torch.float32)

        boxes_xyxy = _xywh_to_xyxy(boxes_xywh)
        proposal_dict[img_id] = [{"boxes": boxes_xyxy, "scores": scores}]

    print("Loaded dict-style proposals for", len(proposal_dict), "training images")
    if skipped_meta_keys:
        print("Skipped non-image keys:", skipped_meta_keys)
else:
    raise ValueError("Unrecognized CUTLER_PROPOSALS_JSON format for proposals.")

# Small sanity check: print one entry (id and shapes)
_example_key = next(iter(proposal_dict.keys()))
_example_entry = proposal_dict[_example_key][0]
print(
    f"Example proposals for image_id={_example_key}: "
    f"boxes={_example_entry['boxes'].shape}, scores={_example_entry['scores'].shape}"
)


Loaded CutLER-style proposals (COCO JSON) for 118287 training images.
Example proposals for image_id=558840: boxes=torch.Size([16, 4]), scores=torch.Size([16])


In [5]:
# === Precompute train_items for a Windows-safe Dataset ===

train_items: List[Dict[str, Any]] = []

for img_id in coco_train.getImgIds():
    img_info = coco_train.loadImgs([img_id])[0]
    ann_ids = coco_train.getAnnIds(imgIds=[img_id], iscrowd=None)
    anns = coco_train.loadAnns(ann_ids)

    boxes = []
    iscrowd = []
    segs = []

    for a in anns:
        if a["category_id"] not in voc_cat_ids:
            continue
        if "bbox" not in a:
            continue
        boxes.append(a["bbox"])              # [x, y, w, h]
        iscrowd.append(a.get("iscrowd", 0))
        segs.append(a.get("segmentation", None))

    item = {
        "image_id": img_id,
        "file_name": img_info["file_name"],
        "boxes": boxes,
        "iscrowd": iscrowd,
        "masks": segs,
        "orig_size": [img_info["height"], img_info["width"]],  # [H, W]
    }
    train_items.append(item)

print("len(train_items) =", len(train_items))


len(train_items) = 118287


In [6]:
%%writefile vclr_dataset2.py
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import pycocotools.mask as mask_utils


class VCLRTrainSubset(Dataset):
    """
    Windows-safe v-CLR training subset.

    Each entry in `items` is a Python dict:
      - image_id: int
      - file_name: str
      - boxes: List[List[float]] in xywh (original pixels)
      - iscrowd: List[int]
      - masks: List[segmentation] (COCO polygons or RLE)
      - orig_size: [H, W]

    `proposals` is:
      image_id -> list of { "boxes": Tensor[K,4] (xywh), "scores": Tensor[K] }.
    """

    def __init__(
        self,
        items: List[Dict[str, Any]],
        img_dir: Path,
        depth_dir: Optional[str] = None,
        stylized_dir: Optional[str] = None,
        edge_dir: Optional[str] = None,
        proposals: Optional[Dict[int, List[Dict[str, torch.Tensor]]]] = None,
        transform: Optional[Any] = None,
    ):
        self.items = items
        self.img_dir = Path(img_dir)
        self.depth_dir = Path(depth_dir) if depth_dir is not None else None
        self.stylized_dir = Path(stylized_dir) if stylized_dir is not None else None
        self.edge_dir = Path(edge_dir) if edge_dir is not None else None
        self.proposals = proposals or {}
        self.transform = transform

    def __len__(self) -> int:
        return len(self.items)

    def _load_optional_view(self, root: Optional[Path], file_name: str) -> Optional[Image.Image]:
        if root is None:
            return None
        path = root / file_name
        if not path.is_file():
            alt = None
            if path.suffix.lower() == ".png":
                alt = path.with_suffix(".jpg")
            elif path.suffix.lower() == ".jpg":
                alt = path.with_suffix(".png")
            if alt is None or not alt.is_file():
                return None
            path = alt
        return Image.open(path).convert("RGB")

    def _decode_masks(self, segs, h: int, w: int) -> torch.Tensor:
        if not segs:
            # Legit: there are no masks for this image
            return torch.zeros((0, h, w), dtype=torch.uint8)

        decoded = []
        failed = 0

        for s in segs:
            if s is None:
                failed += 1
                continue

            rle = None
            if isinstance(s, list):
                rles = mask_utils.frPyObjects(s, h, w)
                rle = mask_utils.merge(rles) if isinstance(rles, list) else rles

            elif isinstance(s, dict) and "counts" in s:
                if isinstance(s["counts"], list):
                    rle = mask_utils.frPyObjects(s, h, w)
                    if isinstance(rle, list):
                        rle = mask_utils.merge(rle)
                else:
                    rle = s

            if rle is None:
                failed += 1
                continue

            m = mask_utils.decode(rle).astype("uint8")
            decoded.append(m)

        if not decoded:
            # segs was non-empty but we failed to decode all of them
            # --> better to raise or log, not silently pretend it's background
            raise RuntimeError(
                f"Failed to decode any masks for image; got {len(segs)} seg entries, all invalid."
            )

        return torch.from_numpy(np.stack(decoded, axis=0))




    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = self.items[idx]
        img_id = int(item["image_id"])
        file_name = item["file_name"]

        # Natural image
        path = self.img_dir / file_name
        if not path.is_file():
            alt = None
            if path.suffix.lower() == ".png":
                alt = path.with_suffix(".jpg")
            elif path.suffix.lower() == ".jpg":
                alt = path.with_suffix(".png")
            if alt is None or not alt.is_file():
                raise FileNotFoundError(f"Image not found: {path} (alt={alt})")
            path = alt
        image_nat = Image.open(path).convert("RGB")

        h0, w0 = item["orig_size"][0], item["orig_size"][1]

        # Extra views
        image_depth = self._load_optional_view(self.depth_dir, file_name)
        image_style = self._load_optional_view(self.stylized_dir, file_name)
        image_edge  = self._load_optional_view(self.edge_dir, file_name)

        boxes = torch.as_tensor(item.get("boxes", []), dtype=torch.float32)
        if boxes.numel() == 0:
            boxes = boxes.reshape(0, 4)

        iscrowd = torch.as_tensor(item.get("iscrowd", []), dtype=torch.int64)
        if iscrowd.numel() == 0:
            iscrowd = iscrowd.reshape(0)

        segs = item.get("masks", [])
        masks = self._decode_masks(segs, h0, w0)  # (N, H0, W0)
        if masks.shape[0] != boxes.shape[0]:
            masks = torch.zeros((boxes.shape[0], h0, w0), dtype=torch.uint8)

        # Class-agnostic: all foreground
        labels = torch.ones((boxes.shape[0],), dtype=torch.int64)

        target: Dict[str, Any] = {
            "boxes": boxes,
            "labels": labels,
            "iscrowd": iscrowd,
            "masks": masks,
            "image_id": torch.tensor([img_id]),
            "orig_size": torch.tensor([h0, w0], dtype=torch.float32),
        }

        props = self.proposals.get(img_id, [])

        sample: Dict[str, Any] = {
            "image_nat": image_nat,
            "image_depth": image_depth,
            "image_style": image_style,
            "image_edge": image_edge,
            "target": target,
            "proposals": props,
        }

        if self.transform is not None:
            sample = self.transform(sample)
        return sample


Overwriting vclr_dataset2.py


In [7]:
import importlib
import vclr_dataset2
importlib.reload(vclr_dataset2)
from vclr_dataset2 import VCLRTrainSubset

class VCLRTransform:
    """
    - Resize images to (IMG_SIZE, IMG_SIZE)
    - Convert to tensor and normalize
    - Rescale boxes + proposals accordingly
    """
    def __init__(self, img_size=IMG_SIZE):
        self.img_size = img_size
        self.resize_img = T.Resize((img_size, img_size))
        self.to_tensor = T.ToTensor()
        self.normalize = T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    def _process_view(self, img):
        if img is None:
            return None
        img = self.resize_img(img)
        img = self.to_tensor(img)
        img = self.normalize(img)
        return img

    def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        img_nat = sample["image_nat"]
        h0, w0 = img_nat.size[1], img_nat.size[0]

        sample["image_nat"]   = self._process_view(sample["image_nat"])
        sample["image_depth"] = self._process_view(sample["image_depth"])
        sample["image_style"] = self._process_view(sample["image_style"])
        sample["image_edge"]  = self._process_view(sample["image_edge"])

        sx = self.img_size / float(w0)
        sy = self.img_size / float(h0)

        target = sample["target"]
        boxes = target["boxes"].clone()
        if boxes.numel() > 0:
            boxes[:, 0] = boxes[:, 0] * sx
            boxes[:, 1] = boxes[:, 1] * sy
            boxes[:, 2] = boxes[:, 2] * sx
            boxes[:, 3] = boxes[:, 3] * sy
        target["boxes"] = boxes
        target["orig_size"] = target["orig_size"]
        sample["target"] = target

        for entry in sample["proposals"]:
            b = entry["boxes"].clone()
            if b.numel() == 0:
                continue
            b[:, 0] = b[:, 0] * sx
            b[:, 1] = b[:, 1] * sy
            b[:, 2] = b[:, 2] * sx
            b[:, 3] = b[:, 3] * sy
            entry["boxes"] = b

        return sample


def vclr_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    out = {
        "image_nat": [],
        "image_depth": [],
        "image_style": [],
        "image_edge": [],
        "targets": [],
        "proposals": [],
    }
    for b in batch:
        out["image_nat"].append(b["image_nat"])
        out["image_depth"].append(b["image_depth"])
        out["image_style"].append(b["image_style"])
        out["image_edge"].append(b["image_edge"])
        out["targets"].append(b["target"])
        out["proposals"].append(b["proposals"])

    out["image_nat"] = torch.stack(out["image_nat"], dim=0)
    return out


train_transform = VCLRTransform(img_size=IMG_SIZE)

# ---- Train/val split on train_items ----
indices = list(range(len(train_items)))
rng = random.Random(SEED)
rng.shuffle(indices)

val_frac = 0.1  # 10% for validation loss
split = int((1.0 - val_frac) * len(indices))
train_indices = indices[:split]
val_indices   = indices[split:]

train_items_train = [train_items[i] for i in train_indices]
val_items_loss    = [train_items[i] for i in val_indices]

print("Train items:", len(train_items_train), "Val items (loss):", len(val_items_loss))

train_dataset = VCLRTrainSubset(
    items=train_items_train,
    img_dir=COCO_TRAIN_IMG_DIR,
    depth_dir=str(DEPTH_TRAIN_DIR),
    stylized_dir=str(STYLE_TRAIN_DIR),
    edge_dir=str(EDGE_TRAIN_DIR) if EDGE_TRAIN_DIR is not None else None,
    proposals=proposal_dict,
    transform=train_transform,
)

val_dataset_loss = VCLRTrainSubset(
    items=val_items_loss,
    img_dir=COCO_TRAIN_IMG_DIR,
    depth_dir=str(DEPTH_TRAIN_DIR),
    stylized_dir=str(STYLE_TRAIN_DIR),
    edge_dir=str(EDGE_TRAIN_DIR) if EDGE_TRAIN_DIR is not None else None,
    proposals=proposal_dict,
    transform=train_transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=TRAIN_NUM_WORKERS,
    pin_memory=False,
    persistent_workers=False,
    collate_fn=vclr_collate,
)

val_loader_loss = DataLoader(
    val_dataset_loss,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    num_workers=VAL_NUM_WORKERS,
    pin_memory=False,
    persistent_workers=False,
    collate_fn=vclr_collate,
)

print("Train batches:", len(train_loader), "Val (loss) batches:", len(val_loader_loss))


Train items: 106458 Val items (loss): 11829
Train batches: 7098 Val (loss) batches: 789


In [8]:
# === Box utilities ===

def box_xywh_to_cxcywh(boxes: torch.Tensor) -> torch.Tensor:
    x, y, w, h = boxes.unbind(-1)
    cx = x + 0.5 * w
    cy = y + 0.5 * h
    return torch.stack([cx, cy, w, h], dim=-1)

def box_cxcywh_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
    cx, cy, w, h = boxes.unbind(-1)
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h
    return torch.stack([x1, y1, x2, y2], dim=-1)


In [9]:
# === ConvNeXt-based dense detector with partial fine-tuning + greedy GPU matching ===

class ConvNeXtDetector(nn.Module):
    def __init__(self, d_model: int = 256, unfreeze_stages: int = 2):
        super().__init__()
        self.d_model = d_model

        try:
            from torchvision.models import ConvNeXt_Tiny_Weights
            weights = ConvNeXt_Tiny_Weights.IMAGENET1K_V1
            backbone = convnext_tiny(weights=weights)
        except Exception:
            backbone = convnext_tiny(pretrained=True)

        for p in backbone.features.parameters():
            p.requires_grad = False

        if unfreeze_stages > 0:
            for layer in backbone.features[-unfreeze_stages:]:
                for p in layer.parameters():
                    p.requires_grad = True

        self.backbone_features = backbone.features  # [B,C,Hf,Wf]

        self.conv_proj: Optional[nn.Conv2d] = None

        self.cls_head = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(d_model, 1, kernel_size=1),
        )
        self.box_head = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(d_model, 4, kernel_size=1),
        )

    def forward(
        self,
        images: torch.Tensor,
        targets: Optional[List[Dict[str, Any]]] = None,
        proposals: Optional[List[Any]] = None,
        view_name: str = "nat",
    ):
        B, _, H, W = images.shape
        images = images.to(memory_format=torch.channels_last)

        feat = self.backbone_features(images)  # [B,C,Hf,Wf]
        if self.conv_proj is None:
            self.conv_proj = nn.Conv2d(feat.shape[1], self.d_model, kernel_size=1).to(feat.device)
        feat = self.conv_proj(feat)           # [B,d_model,Hf,Wf]

        cls_map = self.cls_head(feat)         # [B,1,Hf,Wf]
        box_map = self.box_head(feat).sigmoid()  # [B,4,Hf,Wf]

        B, _, Hf, Wf = cls_map.shape
        N = Hf * Wf

        pred_logits = cls_map.view(B, 1, N).permute(0, 2, 1)   # [B,N,1]
        pred_boxes  = box_map.view(B, 4, N).permute(0, 2, 1)   # [B,N,4]
        query_feats = feat.view(B, self.d_model, N).permute(0, 2, 1)

        outputs_raw = {
            "pred_logits": pred_logits,
            "pred_boxes": pred_boxes,
            "query_feats": query_feats,
        }

        outputs_list: List[Dict[str, torch.Tensor]] = []
        for b in range(B):
            scores = pred_logits[b].sigmoid().squeeze(-1)      # [N]
            boxes_norm_cxcywh = pred_boxes[b]                  # [N,4]
            boxes_norm_xyxy = box_cxcywh_to_xyxy(boxes_norm_cxcywh)
            scale = torch.tensor([W, H, W, H], device=boxes_norm_xyxy.device)
            boxes_xyxy = boxes_norm_xyxy * scale
            outputs_list.append({
                "boxes": boxes_xyxy,
                "scores": scores,
                "masks": None,
                "query_feats": query_feats[b],
            })

        if self.training and targets is not None:
            loss_dict = self.compute_losses(outputs_raw, targets, (H, W))
        else:
            loss_dict = {}

        return outputs_list, loss_dict

    def compute_losses(
        self,
        outputs_raw: Dict[str, torch.Tensor],
        targets: List[Dict[str, Any]],
        img_size_hw: Tuple[int, int],
    ) -> Dict[str, torch.Tensor]:
        pred_logits = outputs_raw["pred_logits"]  # [B,N,1]
        pred_boxes  = outputs_raw["pred_boxes"]   # [B,N,4]

        B, N, _ = pred_boxes.shape
        H, W = img_size_hw

        loss_cls_list = []
        loss_bbox_list = []
        loss_giou_list = []

        for b in range(B):
            logits = pred_logits[b].squeeze(-1)  # [N]
            pb = pred_boxes[b]                  # [N,4]

            tgt = targets[b]
            gt_boxes_xywh = tgt["boxes"].to(pb.device)  # [Ng,4] in resized pixels

            if gt_boxes_xywh.numel() == 0:
                target_classes = torch.zeros(N, dtype=torch.float32, device=logits.device)
                loss_cls_list.append(
                    F.binary_cross_entropy_with_logits(logits, target_classes)
                )
                continue

            sizes = torch.tensor([W, H, W, H], device=pb.device)
            gt_boxes_norm_xywh = gt_boxes_xywh / sizes
            gt_boxes_norm_cxcywh = box_xywh_to_cxcywh(gt_boxes_norm_xywh)

            pred_xyxy = box_cxcywh_to_xyxy(pb)
            gt_xyxy   = box_cxcywh_to_xyxy(gt_boxes_norm_cxcywh)

            ious = box_iou(pred_xyxy, gt_xyxy)  # [N_pred, N_gt]

            best_pred_for_gt = ious.argmax(dim=0)         # [Ng]
            matched_pred_idx = best_pred_for_gt.unique()  # [Nm]
            matched_gt_idx = []
            for j in matched_pred_idx:
                gt_idxs = (best_pred_for_gt == j).nonzero(as_tuple=False).squeeze(-1)
                _, k = ious[j, gt_idxs].max(dim=0)
                matched_gt_idx.append(gt_idxs[k])
            matched_gt_idx = torch.stack(matched_gt_idx, dim=0)

            target_classes = torch.zeros(N, dtype=torch.float32, device=logits.device)
            target_classes[matched_pred_idx] = 1.0

            loss_cls_list.append(
                F.binary_cross_entropy_with_logits(logits, target_classes)
            )

            pb_matched = pb[matched_pred_idx]
            gt_matched = gt_boxes_norm_cxcywh[matched_gt_idx]

            loss_bbox_list.append(F.l1_loss(pb_matched, gt_matched, reduction="mean"))

            giou = generalized_box_iou(
                box_cxcywh_to_xyxy(pb_matched),
                box_cxcywh_to_xyxy(gt_matched),
            )
            loss_giou_list.append(1.0 - giou.diag().mean())

        loss_cls  = torch.stack(loss_cls_list).mean()
        loss_bbox = torch.stack(loss_bbox_list).mean() if loss_bbox_list else torch.tensor(0.0, device=loss_cls.device)
        loss_giou = torch.stack(loss_giou_list).mean() if loss_giou_list else torch.tensor(0.0, device=loss_cls.device)

        return {
            "loss_cls": loss_cls,
            "loss_bbox": loss_bbox,
            "loss_giou": loss_giou,
        }


model = ConvNeXtDetector(d_model=256, unfreeze_stages=4).to(device)
model = model.to(memory_format=torch.channels_last)

optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=BASE_LR, weight_decay=1e-4
)

def adjust_lr(optimizer, epoch: int):
    factor = 0.1 if epoch >= 7 else 1.0
    for pg in optimizer.param_groups:
        pg["lr"] = BASE_LR * factor

print("Total trainable params:",
      sum(p.numel() for p in model.parameters() if p.requires_grad))


Total trainable params: 27764997


In [10]:
# === Teacher-Student wrapper around ConvNeXtDetector (v-CLR two-branch) ===

import copy
import torch.nn as nn

class TeacherStudentVCLR(nn.Module):
    """
    EMA teacher–student wrapper around ConvNeXtDetector.
    - teacher: natural-image branch (no gradients, EMA-updated)
    - student: transformed-image branch (gets all supervision + gradients)
    """
    def __init__(self, base_detector: nn.Module, ema_momentum: float = 0.999):
        super().__init__()
        self.student = base_detector           # reuse your existing ConvNeXtDetector
        self.teacher = copy.deepcopy(base_detector)
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.ema_m = ema_momentum

    @torch.no_grad()
    def update_teacher(self):
        m = self.ema_m
        for t_param, s_param in zip(self.teacher.parameters(), self.student.parameters()):
            t_param.data.mul_(m).add_(s_param.data, alpha=1.0 - m)

    def train(self, mode: bool = True):
        """
        Only the student is put in train() / eval().
        Teacher is always kept in eval mode.
        """
        self.student.train(mode)
        self.teacher.eval()
        self.training = mode
        return self

    def eval(self):  # type: ignore[override]
        return self.train(False)


# Wrap your existing single-branch ConvNeXtDetector as student
# (assumes `model` already exists from the previous cell)
base_detector = model
model = TeacherStudentVCLR(base_detector, ema_momentum=0.999).to(device)

# Use channels_last for both branches to speed up ConvNeXt
model.student = model.student.to(memory_format=torch.channels_last)
model.teacher = model.teacher.to(memory_format=torch.channels_last)

# --- Optimizer + LR schedule (backbone vs head LRs) ---

# BASE_LR should already be defined above (your "main" LR)
backbone_lr = BASE_LR * 0.1   # smaller LR for pretrained ConvNeXt backbone
head_lr     = BASE_LR         # original LR for detection/mask/v-CLR heads

# Identify backbone parameters on the *student* branch
backbone_param_ids = {id(p) for p in model.student.backbone_features.parameters()}

backbone_params = []
head_params = []

for p in model.parameters():
    if not p.requires_grad:
        continue
    if id(p) in backbone_param_ids:
        backbone_params.append(p)
    else:
        head_params.append(p)

print(f"[DEBUG] trainable backbone params: {len(backbone_params)}")
print(f"[DEBUG] trainable head params:     {len(head_params)}")

optimizer = torch.optim.AdamW(
    [
        {"params": backbone_params, "lr": backbone_lr},
        {"params": head_params,     "lr": head_lr},
    ],
    weight_decay=1e-4,
)

# Remember each group's initial LR so we keep the ratio when decaying
for pg in optimizer.param_groups:
    pg["initial_lr"] = pg["lr"]

def adjust_lr(optimizer, epoch: int):
    """
    Simple step schedule: drop all LRs by 10× after epoch 7,
    while preserving the backbone/head LR ratio.
    """
    factor = 0.1 if epoch >= 7 else 1.0
    for pg in optimizer.param_groups:
        pg["lr"] = pg["initial_lr"] * factor

print(
    "Total trainable params:",
    sum(p.numel() for p in model.parameters() if p.requires_grad),
)


[DEBUG] trainable backbone params: 116
[DEBUG] trainable head params:     8
Total trainable params: 27764997


In [11]:
# === v-CLR auxiliary losses (L_obj, L_sim) – NaN-safe L_gt ===

def cosine_sim_loss(q1: torch.Tensor, q2: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """
    Compute cosine similarity loss between query features with temperature scaling.
    
    Temperature scaling is standard in contrastive learning (e.g., SimCLR, CLIP).
    Lower temperature makes the similarity more peaked (sharper), encouraging 
    better feature alignment. Default temperature of 0.07 is commonly used.
    
    The loss is computed as: 1 - cosine_similarity
    Temperature is applied to scale the similarity before computing the loss,
    making the optimization more sensitive to small differences.
    """
    q1 = F.normalize(q1, dim=-1)
    q2 = F.normalize(q2, dim=-1)
    # Compute raw cosine similarity (range: [-1, 1])
    similarity = (q1 * q2).sum(dim=-1)
    # Clamp to valid cosine similarity range
    similarity = similarity.clamp(-1, 1)
    # Apply temperature: lower temp = sharper gradients for near-perfect matches
    # The loss is (1 - similarity) / temperature, but we normalize by temperature
    # to keep the loss scale reasonable
    loss = (1.0 - similarity) / temperature
    return loss.mean()


def match_proposals_to_predictions(
    proposal_boxes: torch.Tensor,
    pred_boxes: torch.Tensor,
    iou_thresh: float = 0.3,
    max_pairs: int = 32,
):
    """
    Greedy IoU-based matching: for each proposal, pick the best prediction
    above iou_thresh, then truncate to at most max_pairs matches.
    Returns:
      idx_p:    indices into proposal_boxes  [K]
      idx_pred: indices into pred_boxes      [K]
    """
    if proposal_boxes.numel() == 0 or pred_boxes.numel() == 0:
        device = pred_boxes.device
        return (
            torch.empty(0, dtype=torch.long, device=device),
            torch.empty(0, dtype=torch.long, device=device),
        )

    ious = box_iou(proposal_boxes, pred_boxes)  # [Np, Npred]
    best_iou, best_idx = ious.max(dim=1)        # for each proposal: best pred

    keep = best_iou > iou_thresh
    idx_p = torch.nonzero(keep, as_tuple=False).squeeze(-1)
    idx_pred = best_idx[keep]

    if idx_p.numel() > max_pairs:
        idx_p = idx_p[:max_pairs]
        idx_pred = idx_pred[:max_pairs]

    return idx_p, idx_pred


def _intersect1d_with_indices(a: torch.Tensor, b: torch.Tensor):
    """
    Slow but simple 1D "set intersection" with index tracking.

    Returns:
      common:      values common to a and b          [K]
      idx_a_local: indices in a for those values     [K]
      idx_b_local: indices in b for those values     [K]
    """
    assert a.dim() == 1 and b.dim() == 1
    device = a.device
    common_vals = []
    idx_a = []
    idx_b = []

    for i in range(a.numel()):
        v = a[i]
        matches = (b == v).nonzero(as_tuple=False)
        if matches.numel() > 0:
            common_vals.append(v)
            idx_a.append(i)
            idx_b.append(matches[0].item())

    if len(common_vals) == 0:
        return (
            torch.empty(0, dtype=a.dtype, device=device),
            torch.empty(0, dtype=torch.long, device=device),
            torch.empty(0, dtype=torch.long, device=device),
        )

    common = torch.stack(common_vals).to(device)
    idx_a_local = torch.tensor(idx_a, dtype=torch.long, device=device)
    idx_b_local = torch.tensor(idx_b, dtype=torch.long, device=device)
    return common, idx_a_local, idx_b_local


def compute_vclr_losses(
    outputs_nat,   loss_nat,
    outputs_depth, loss_depth,
    outputs_style, loss_style,
    proposals_batch,
) -> Dict[str, torch.Tensor]:
    """
    Aggregate:
      - L_gt   : detection loss (nat + optional depth/style)
      - L_obj  : proposal–prediction L1 alignment across views
      - L_sim  : query feature cosine similarity between nat and other views
    """

    # ---- device for scalar tensors ----
    # outputs_nat is a list of per-image dicts; each has "boxes".
    device = outputs_nat[0]["boxes"].device

    def sum_loss_dict(ld: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Sum detection losses in a dict, ignoring non-finite terms.
        Always returns a 0-D tensor on `device`.
        """
        if not ld:
            return torch.tensor(0.0, device=device)

        vals = []
        for v in ld.values():
            if not torch.is_tensor(v):
                continue
            v = v.to(device=device, dtype=torch.float32)
            if not torch.isfinite(v):
                # Rare; if this ever triggers, it's exactly the source of NaNs.
                # Uncomment if you want explicit logging:
                # print(f"[WARN] non-finite detection loss term {v.item()} in sum_loss_dict, skipping.")
                continue
            vals.append(v)

        if not vals:
            return torch.tensor(0.0, device=device)
        # NOTE: no reweighting here; this keeps the same behaviour as before.
        return torch.stack(vals).sum()

    # ---- Ground-truth supervised term L_gt (nat + optional depth/style) ----
    L_gt_nat   = sum_loss_dict(loss_nat)
    L_gt_depth = sum_loss_dict(loss_depth)
    L_gt_style = sum_loss_dict(loss_style)

    # All of these are now guaranteed 0-D tensors on `device`, no Python floats.
    L_gt = L_gt_nat + L_gt_depth + L_gt_style

    # ---- Proposal-objectness term L_obj and query-similarity term L_sim ----
    L_obj_list: List[torch.Tensor] = []
    L_sim_list: List[torch.Tensor] = []

    B = len(outputs_nat)

    for b in range(B):
        props = proposals_batch[b]
        if not props:
            continue

        # proposals: list of dicts (we use the first entry)
        prop_boxes = props[0]["boxes"].to(device)

        # Per-view outputs for this image
        nat_out   = outputs_nat[b]
        depth_out = outputs_depth[b] if outputs_depth else None
        style_out = outputs_style[b] if outputs_style else None

        # 1) L_obj: L1 alignment between proposals and predicted boxes per view
        #    Use normalized coordinates (divide by image size) so this lives on ~[0, 1]
        for ov in [nat_out, depth_out, style_out]:
            if ov is None:
                continue
            pred_boxes = ov["boxes"]  # xyxy in pixels
            if pred_boxes.numel() == 0:
                continue

            idx_p, idx_pred = match_proposals_to_predictions(prop_boxes, pred_boxes)
            if idx_p.numel() == 0:
                continue

            sel_props = prop_boxes[idx_p]   # [K,4], xyxy in pixels
            sel_preds = pred_boxes[idx_pred]

            # Normalize coordinates by image size
            scale_xyxy = torch.tensor(
                [IMG_SIZE, IMG_SIZE, IMG_SIZE, IMG_SIZE],
                device=sel_preds.device,
                dtype=sel_preds.dtype,
            )
            sel_props_norm = sel_props / scale_xyxy
            sel_preds_norm = sel_preds / scale_xyxy

            L_obj_list.append(F.l1_loss(sel_preds_norm, sel_props_norm))


        # 2) L_sim: cosine similarity between query features for nat vs each other view
        idx_p_nat, idx_pred_nat = match_proposals_to_predictions(
            prop_boxes, nat_out["boxes"]
        )
        if idx_p_nat.numel() == 0:
            continue

        def get_query_feats(out, idx_pred):
            if out is None:
                return None
            q = out.get("query_feats", None)
            if q is None:
                return None
            if idx_pred.numel() == 0 or idx_pred.max().item() >= q.shape[0]:
                return None
            return q[idx_pred]

        for ov in [depth_out, style_out]:
            if ov is None:
                continue

            idx_p_view, idx_pred_view = match_proposals_to_predictions(
                prop_boxes, ov["boxes"]
            )
            if idx_p_view.numel() == 0:
                continue

            common_prop, idx_nat_local, idx_view_local = _intersect1d_with_indices(
                idx_p_nat, idx_p_view
            )
            if common_prop.numel() == 0:
                continue

            q_nat  = get_query_feats(nat_out,  idx_pred_nat[idx_nat_local])
            q_view = get_query_feats(ov,       idx_pred_view[idx_view_local])

            if q_nat is None or q_view is None or q_nat.shape[0] == 0:
                continue

            L_sim_list.append(cosine_sim_loss(q_nat, q_view))

    # Safe defaults if we collected no matches
    L_obj = (
        torch.stack(L_obj_list).mean()
        if L_obj_list
        else torch.tensor(0.0, device=device)
    )
    L_sim = (
        torch.stack(L_sim_list).mean()
        if L_sim_list
        else torch.tensor(0.0, device=device)
    )

    # v-CLR total loss (same structure as before)
    L_total = (
        LAMBDA_GT * L_gt +
        LAMBDA_MATCH * (LAMBDA_OBJ * L_obj + LAMBDA_SIM * L_sim)
    )

    return {
        "L_gt":    L_gt,
        "L_obj":   L_obj,
        "L_sim":   L_sim,
        "L_total": L_total,
    }


In [12]:
# History for plotting
train_history = {
    "epoch": [],
    "L_total_train": [],
    "L_gt_train": [],
    "L_obj_train": [],
    "L_sim_train": [],
    "L_total_val": [],
    "L_gt_val": [],
    "L_obj_val": [],
    "L_sim_val": [],
}


In [13]:
# === Validation loss loop (v-CLR with teacher–student, robust view handling) ===

def compute_epoch_loss_on_loader(model: nn.Module, loader: DataLoader) -> Dict[str, float]:
    model.eval()
    is_ts = hasattr(model, "student") and hasattr(model, "teacher")
    student = model.student if is_ts else model
    teacher = model.teacher if is_ts else model

    total_L_total = 0.0
    total_L_gt    = 0.0
    total_L_obj   = 0.0
    total_L_sim   = 0.0
    batch_count   = 0

    with torch.no_grad():
        for batch in loader:
            images_nat   = batch["image_nat"].to(device, non_blocking=True)
            images_depth = _normalize_view_batch(batch["image_depth"], device)
            images_style = _normalize_view_batch(batch["image_style"], device)
            targets      = batch["targets"]
            proposals    = batch["proposals"]

            has_depth = images_depth is not None
            has_style = images_style is not None

            if has_depth and has_style:
                if random.random() < 0.5:
                    use_depth, use_style = True, False
                else:
                    use_depth, use_style = False, True
            else:
                use_depth, use_style = has_depth, has_style

            with torch.cuda.amp.autocast(dtype=torch.float16):
                if is_ts:
                    # EMA teacher on natural images (no gradient)
                    outputs_nat_teacher, _ = teacher(
                        images_nat, targets=None, proposals=proposals, view_name="nat"
                    )
                    # Student gets GT detection loss on nat
                    outputs_nat_student, loss_nat = student(
                        images_nat, targets=targets, proposals=proposals, view_name="nat"
                    )
                    outputs_nat = outputs_nat_teacher
                else:
                    outputs_nat, loss_nat = student(
                        images_nat, targets=targets, proposals=proposals, view_name="nat"
                    )

                outputs_depth, loss_depth = outputs_nat, {}
                outputs_style, loss_style = outputs_nat, {}

                # NOTE: Depth/style views do NOT receive GT targets per v-CLR paper.
                if use_depth and images_depth is not None:
                    outputs_depth, loss_depth = student(
                        images_depth, targets=None, proposals=proposals, view_name="depth"
                    )
                if use_style and images_style is not None:
                    outputs_style, loss_style = student(
                        images_style, targets=None, proposals=proposals, view_name="style"
                    )

                vclr_loss = compute_vclr_losses(
                    outputs_nat,  loss_nat,
                    outputs_depth, loss_depth,
                    outputs_style, loss_style,
                    proposals,
                )
                loss = vclr_loss["L_total"]

            total_L_total += float(loss)
            total_L_gt    += float(vclr_loss["L_gt"])
            total_L_obj   += float(vclr_loss["L_obj"])
            total_L_sim   += float(vclr_loss["L_sim"])
            batch_count   += 1

    if batch_count == 0:
        return {"L_total": 0.0, "L_gt": 0.0, "L_obj": 0.0, "L_sim": 0.0}

    return {
        "L_total": total_L_total / batch_count,
        "L_gt":    total_L_gt    / batch_count,
        "L_obj":   total_L_obj   / batch_count,
        "L_sim":   total_L_sim   / batch_count,
    }


In [14]:
import torch
from typing import Optional, Union, List

def _normalize_view_batch(
    x: Union[None, torch.Tensor, List[Optional[torch.Tensor]]],
    device: torch.device,
) -> Optional[torch.Tensor]:
    """
    Normalize depth/style batch from collate_fn:
      - None -> None
      - Tensor -> Tensor on device
      - list[Tensor] -> stacked Tensor on device
      - list[None] -> None
    Raises if list contains a mix of None and Tensor, which indicates a real data issue.
    """
    if x is None:
        return None

    if isinstance(x, torch.Tensor):
        return x.to(device, non_blocking=True)

    if isinstance(x, list):
        if len(x) == 0:
            return None
        if all(v is None for v in x):
            return None
        if all(isinstance(v, torch.Tensor) for v in x):
            return torch.stack(x, dim=0).to(device, non_blocking=True)
        # Mixed None and Tensor -> this is a bug upstream
        raise RuntimeError("image_depth/image_style list has mixed None and Tensors; please fix Dataset/Collate.")

    raise TypeError(f"Unexpected type for view batch: {type(x)}")


In [16]:
# === DEBUG: inspect one batch (GT boxes, CutLER proposals, teacher/student losses) ===

import math

# Handle teacher–student or single model uniformly
is_ts = hasattr(model, "student") and hasattr(model, "teacher")
student = model.student if is_ts else model
teacher = model.teacher if is_ts else None

print("=== DEBUG: single-batch inspection (nat view only) ===")
print(f"len(train_dataset) = {len(train_dataset)}")
print(f"len(train_loader)  = {len(train_loader)}")

# Get a single batch from the train loader
batch = next(iter(train_loader))

images_nat = batch["image_nat"]  # should be a tensor [B,3,H,W]
print("[DEBUG] type(images_nat):", type(images_nat))
if isinstance(images_nat, torch.Tensor):
    print("[DEBUG] images_nat.shape:", images_nat.shape)
else:
    raise TypeError(f"Expected images_nat to be Tensor, got {type(images_nat)}")

images_nat = images_nat.to(device, non_blocking=True)

targets   = batch["targets"]    # list of dict
proposals = batch["proposals"]  # list of list-of-dicts (CutLER proposals)

print(f"[DEBUG] batch size (targets): {len(targets)}")
print(f"[DEBUG] batch size (proposals): {len(proposals)}")

# ---- 1) Inspect GT boxes for degenerate / weird values ----
for i, tgt in enumerate(targets):
    boxes = tgt["boxes"]          # [Ng,4] in resized pixels (xywh in your pipeline)
    img_id = int(tgt["image_id"].item())
    print(f"\n[DEBUG] sample {i}: image_id={img_id}, num_gt_boxes={boxes.shape[0]}")

    if boxes.numel() == 0:
        print("  -> no GT boxes")
        continue

    wh = boxes[:, 2:4]
    print("  GT w/h min:", wh.min(dim=0).values.cpu().tolist(),
          "max:",        wh.max(dim=0).values.cpu().tolist())

    print("  any NaN in GT boxes? ", torch.isnan(boxes).any().item())
    print("  any Inf in GT boxes?  ",
          (torch.isinf(boxes) & (boxes > 0)).any().item())

# ---- 2) Inspect CutLER proposals for this batch ----
for i, props in enumerate(proposals):
    if not props:
        print(f"\n[DEBUG] sample {i}: NO proposals")
        continue

    entry = props[0]
    boxes_p = entry["boxes"]   # should now be xyxy after our earlier fix
    scores_p = entry["scores"]

    print(f"\n[DEBUG] sample {i}: proposals -> boxes {boxes_p.shape}, scores {scores_p.shape}")
    if boxes_p.numel() > 0:
        print("  proposals min:", boxes_p.min().item(), "max:", boxes_p.max().item())
    print("  any NaN in proposals? ", torch.isnan(boxes_p).any().item())
    print("  any Inf in proposals?  ",
          (torch.isinf(boxes_p) & (boxes_p > 0)).any().item())

# ---- 3) Run student detection once and inspect loss_nat (source of L_gt_nat) ----
student.eval()
with torch.no_grad():
    outputs_nat, loss_nat = student(
        images_nat, targets=targets, proposals=proposals, view_name="nat"
    )

print("\n[DEBUG] loss_nat raw:", loss_nat)
for k, v in loss_nat.items():
    if torch.is_tensor(v):
        print(f"[DEBUG] loss_nat[{k}]: {v.item():.6f}, finite={bool(torch.isfinite(v))}")
    else:
        print(f"[DEBUG] loss_nat[{k}]: non-tensor value -> {v}")

# Optional: also check teacher branch if you’re actually using it for GT loss
if is_ts and teacher is not None:
    teacher.eval()
    with torch.no_grad():
        outputs_nat_teacher, loss_nat_teacher = teacher(
            images_nat, targets=targets, proposals=proposals, view_name="nat"
        )
    print("\n[DEBUG] teacher loss_nat raw:", loss_nat_teacher)
    for k, v in loss_nat_teacher.items():
        if torch.is_tensor(v):
            print(f"[DEBUG] teacher loss_nat[{k}]: {v.item():.6f}, finite={bool(torch.isfinite(v))}")
        else:
            print(f"[DEBUG] teacher loss_nat[{k}]: non-tensor value -> {v}")


=== DEBUG: single-batch inspection (nat view only) ===
len(train_dataset) = 106458
len(train_loader)  = 10646
[DEBUG] type(images_nat): <class 'torch.Tensor'>
[DEBUG] images_nat.shape: torch.Size([10, 3, 800, 800])
[DEBUG] batch size (targets): 10
[DEBUG] batch size (proposals): 10

[DEBUG] sample 0: image_id=578350, num_gt_boxes=8
  GT w/h min: [29.145540237426758, 80.4749984741211] max: [553.4647827148438, 481.75]
  any NaN in GT boxes?  False
  any Inf in GT boxes?   False

[DEBUG] sample 1: image_id=45387, num_gt_boxes=0
  -> no GT boxes

[DEBUG] sample 2: image_id=174213, num_gt_boxes=3
  GT w/h min: [23.125, 102.18267059326172] max: [122.61249542236328, 250.49180603027344]
  any NaN in GT boxes?  False
  any Inf in GT boxes?   False

[DEBUG] sample 3: image_id=518415, num_gt_boxes=0
  -> no GT boxes

[DEBUG] sample 4: image_id=67443, num_gt_boxes=1
  GT w/h min: [169.21249389648438, 410.2857360839844] max: [169.21249389648438, 410.2857360839844]
  any NaN in GT boxes?  False
  an

In [17]:
# === DEBUG v2: inspect detection losses in TRAIN mode (single batch) ===

import math

# Make sure we use the student network (if teacher–student is wrapped)
is_ts = hasattr(model, "student") and hasattr(model, "teacher")
student = model.student if is_ts else model

print("=== DEBUG v2: single-batch detection loss inspection (TRAIN mode) ===")
print(f"len(train_dataset) = {len(train_dataset)}")
print(f"len(train_loader)  = {len(train_loader)}")

# Switch student to train mode so `compute_losses` actually runs
student.train()

# Get one batch
batch = next(iter(train_loader))

images_nat   = batch["image_nat"].to(device, non_blocking=True)
images_depth = _normalize_view_batch(batch["image_depth"], device)
images_style = _normalize_view_batch(batch["image_style"], device)
targets      = batch["targets"]
proposals    = batch["proposals"]

print("[DEBUG] images_nat.shape:", images_nat.shape)
print("[DEBUG] batch size:", len(targets))

# ---- 1) One forward pass on nat view, with GT + proposals ----
# No need for autocast/GradScaler here; we just want the raw losses.
outputs_nat, loss_nat = student(
    images_nat, targets=targets, proposals=proposals, view_name="nat"
)

print("\n[DEBUG] loss_nat keys:", list(loss_nat.keys()))
for k, v in loss_nat.items():
    if torch.is_tensor(v):
        print(f"  loss_nat[{k}] = {v.item():.6f}, finite={bool(torch.isfinite(v))}")
    else:
        print(f"  loss_nat[{k}] = {v} (non-tensor)")

# ---- 2) If depth/style are present, you can optionally inspect them too ----
if images_depth is not None:
    print("\n[DEBUG] running depth view forward for detection loss...")
    outputs_depth, loss_depth = student(
        images_depth, targets=targets, proposals=proposals, view_name="depth"
    )
    for k, v in loss_depth.items():
        if torch.is_tensor(v):
            print(f"  loss_depth[{k}] = {v.item():.6f}, finite={bool(torch.isfinite(v))}")
        else:
            print(f"  loss_depth[{k}] = {v} (non-tensor)")

if images_style is not None:
    print("\n[DEBUG] running style view forward for detection loss...")
    outputs_style, loss_style = student(
        images_style, targets=targets, proposals=proposals, view_name="style"
    )
    for k, v in loss_style.items():
        if torch.is_tensor(v):
            print(f"  loss_style[{k}] = {v.item():.6f}, finite={bool(torch.isfinite(v))}")
        else:
            print(f"  loss_style[{k}] = {v} (non-tensor)")


=== DEBUG v2: single-batch detection loss inspection (TRAIN mode) ===
len(train_dataset) = 106458
len(train_loader)  = 10646
[DEBUG] images_nat.shape: torch.Size([10, 3, 800, 800])
[DEBUG] batch size: 10

[DEBUG] loss_nat keys: ['loss_cls', 'loss_bbox', 'loss_giou']
  loss_nat[loss_cls] = 0.723318, finite=True
  loss_nat[loss_bbox] = 0.187690, finite=True
  loss_nat[loss_giou] = 0.827295, finite=True

[DEBUG] running depth view forward for detection loss...
  loss_depth[loss_cls] = 0.726135, finite=True
  loss_depth[loss_bbox] = 0.196456, finite=True
  loss_depth[loss_giou] = 0.842603, finite=True

[DEBUG] running style view forward for detection loss...
  loss_style[loss_cls] = 0.726785, finite=True
  loss_style[loss_bbox] = 0.187293, finite=True
  loss_style[loss_giou] = 0.829546, finite=True


In [None]:
# === Training loop (teacher–student v-CLR, AMP, loss prints, robust view handling) ===

scaler = torch.amp.GradScaler('cuda')

is_ts = hasattr(model, "student") and hasattr(model, "teacher")
student = model.student if is_ts else model
teacher = model.teacher if is_ts else model

best_val_L_total = float("inf")
start_epoch = 1

print("[DEBUG] Starting training.")
print(f"[DEBUG] len(train_dataset) = {len(train_dataset)}")
print(f"[DEBUG] len(train_loader)  = {len(train_loader)}")
print(f"[DEBUG] TRAIN_BATCH_SIZE = {TRAIN_BATCH_SIZE}, TRAIN_NUM_WORKERS = {TRAIN_NUM_WORKERS}")

for epoch in range(start_epoch, NUM_EPOCHS + 1):
    model.train()
    adjust_lr(optimizer, epoch)

    sum_L_total = 0.0
    sum_L_gt    = 0.0
    sum_L_obj   = 0.0
    sum_L_sim   = 0.0
    batch_count = 0

    print(f"\n[Epoch {epoch}] -----------------------------")

    for step, batch in enumerate(train_loader):
        images_nat   = batch["image_nat"].to(device, non_blocking=True)
        images_depth = _normalize_view_batch(batch["image_depth"], device)
        images_style = _normalize_view_batch(batch["image_style"], device)
        targets      = batch["targets"]
        proposals    = batch["proposals"]

        has_depth = images_depth is not None
        has_style = images_style is not None

        # v-CLR view sampling: randomly pick one extra view when both exist
        if has_depth and has_style:
            if random.random() < 0.5:
                use_depth, use_style = True, False
            else:
                use_depth, use_style = False, True
        else:
            use_depth, use_style = has_depth, has_style

        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(dtype=torch.float16):
            if is_ts:
                # EMA teacher on natural images (no gradient)
                with torch.no_grad():
                    outputs_nat_teacher, _ = teacher(
                        images_nat, targets=None, proposals=proposals, view_name="nat"
                    )
                # Student gets GT detection loss on nat
                outputs_nat_student, loss_nat = student(
                    images_nat, targets=targets, proposals=proposals, view_name="nat"
                )
                outputs_nat = outputs_nat_teacher
            else:
                outputs_nat, loss_nat = student(
                    images_nat, targets=targets, proposals=proposals, view_name="nat"
                )

            outputs_depth, loss_depth = outputs_nat, {}
            outputs_style, loss_style = outputs_nat, {}

            # NOTE: Depth/style views do NOT receive GT targets per v-CLR paper.
            # These views are trained via view-consistency losses (L_obj, L_sim) only.
            # GT supervision is applied only to the natural view.
            if use_depth and images_depth is not None:
                outputs_depth, loss_depth = student(
                    images_depth, targets=None, proposals=proposals, view_name="depth"
                )
            if use_style and images_style is not None:
                outputs_style, loss_style = student(
                    images_style, targets=None, proposals=proposals, view_name="style"
                )

            vclr_loss = compute_vclr_losses(
                outputs_nat,  loss_nat,
                outputs_depth, loss_depth,
                outputs_style, loss_style,
                proposals,
            )
            loss = vclr_loss["L_total"]

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if is_ts:
            model.update_teacher()

        L_total = float(vclr_loss["L_total"])
        L_gt    = float(vclr_loss["L_gt"])
        L_obj   = float(vclr_loss["L_obj"])
        L_sim   = float(vclr_loss["L_sim"])

        sum_L_total += L_total
        sum_L_gt    += L_gt
        sum_L_obj   += L_obj
        sum_L_sim   += L_sim
        batch_count += 1

        if (step + 1) % 50 == 0:
            print(
                f"[Epoch {epoch} | step {step+1}/{len(train_loader)}] "
                f"L_total={L_total:.4f}  L_gt={L_gt:.4f}  "
                f"L_obj={L_obj:.4f}  L_sim={L_sim:.4f}"
            )

    avg_L_total = sum_L_total / max(1, batch_count)
    avg_L_gt    = sum_L_gt    / max(1, batch_count)
    avg_L_obj   = sum_L_obj   / max(1, batch_count)
    avg_L_sim   = sum_L_sim   / max(1, batch_count)

    train_history["L_total_train"].append(avg_L_total)
    train_history["L_gt_train"].append(avg_L_gt)
    train_history["L_obj_train"].append(avg_L_obj)
    train_history["L_sim_train"].append(avg_L_sim)

    # === Validation epoch (loss-only) ===
    val_losses = compute_epoch_loss_on_loader(model, val_loader_loss)
    train_history["L_total_val"].append(val_losses["L_total"])
    train_history["L_gt_val"].append(val_losses["L_gt"])
    train_history["L_obj_val"].append(val_losses["L_obj"])
    train_history["L_sim_val"].append(val_losses["L_sim"])

    print(
        f"[Epoch {epoch}] TRAIN: "
        f"L_total={avg_L_total:.4f}  L_gt={avg_L_gt:.4f}  "
        f"L_obj={avg_L_obj:.4f}  L_sim={avg_L_sim:.4f}"
    )
    print(
        f"[Epoch {epoch}] VAL  : "
        f"L_total={val_losses['L_total']:.4f}  "
        f"L_gt={val_losses['L_gt']:.4f}  "
        f"L_obj={val_losses['L_obj']:.4f}  "
        f"L_sim={val_losses['L_sim']:.4f}"
    )

    if val_losses["L_total"] < best_val_L_total:
        best_val_L_total = val_losses["L_total"]
        print(f"[Epoch {epoch}] New best val L_total={best_val_L_total:.4f}")
        torch.save(
            {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "train_history": train_history,
            },
            "vclr_convnext_teacher_student_best.pth",
        )
        print("  -> Saved checkpoint: vclr_convnext_teacher_student_best.pth")


[DEBUG] Starting training.
[DEBUG] len(train_dataset) = 106458
[DEBUG] len(train_loader)  = 7098
[DEBUG] TRAIN_BATCH_SIZE = 15, TRAIN_NUM_WORKERS = 0

[Epoch 1] -----------------------------


  with torch.cuda.amp.autocast(dtype=torch.float16):


[Epoch 1 | step 50/7098] L_total=2.3412  L_gt=1.7174  L_obj=0.1247  L_sim=0.4991
[Epoch 1 | step 100/7098] L_total=1.8388  L_gt=1.2556  L_obj=0.1039  L_sim=0.4793
[Epoch 1 | step 150/7098] L_total=2.0708  L_gt=1.4860  L_obj=0.1166  L_sim=0.4682
[Epoch 1 | step 200/7098] L_total=1.9727  L_gt=1.3801  L_obj=0.1204  L_sim=0.4722
[Epoch 1 | step 250/7098] L_total=1.8291  L_gt=1.2487  L_obj=0.1156  L_sim=0.4648
[Epoch 1 | step 300/7098] L_total=1.7003  L_gt=1.1425  L_obj=0.1066  L_sim=0.4512
[Epoch 1 | step 350/7098] L_total=1.7608  L_gt=1.2183  L_obj=0.1173  L_sim=0.4252
[Epoch 1 | step 400/7098] L_total=1.7618  L_gt=1.2268  L_obj=0.1121  L_sim=0.4229
[Epoch 1 | step 450/7098] L_total=1.9242  L_gt=1.3817  L_obj=0.1134  L_sim=0.4291
[Epoch 1 | step 500/7098] L_total=1.4072  L_gt=0.9233  L_obj=0.1055  L_sim=0.3784
[Epoch 1 | step 550/7098] L_total=1.5720  L_gt=1.0906  L_obj=0.1041  L_sim=0.3774
[Epoch 1 | step 600/7098] L_total=1.6215  L_gt=1.1309  L_obj=0.1063  L_sim=0.3843
[Epoch 1 | step 6

  with torch.cuda.amp.autocast(dtype=torch.float16):


[Epoch 1] TRAIN: L_total=1.4314  L_gt=1.1637  L_obj=0.0572  L_sim=0.2105
[Epoch 1] VAL  : L_total=0.1572  L_gt=0.0000  L_obj=0.0416  L_sim=0.1156
[Epoch 1] New best val L_total=0.1572
  -> Saved checkpoint: vclr_convnext_teacher_student_best.pth

[Epoch 2] -----------------------------
[Epoch 2 | step 50/7098] L_total=1.5085  L_gt=1.2859  L_obj=0.1097  L_sim=0.1129
[Epoch 2 | step 100/7098] L_total=1.4478  L_gt=1.1831  L_obj=0.1458  L_sim=0.1189
[Epoch 2 | step 150/7098] L_total=1.4446  L_gt=1.2108  L_obj=0.1233  L_sim=0.1104
[Epoch 2 | step 200/7098] L_total=1.3065  L_gt=1.0652  L_obj=0.1380  L_sim=0.1033
[Epoch 2 | step 250/7098] L_total=1.1765  L_gt=0.9513  L_obj=0.1139  L_sim=0.1113
[Epoch 2 | step 300/7098] L_total=1.2283  L_gt=0.9636  L_obj=0.1547  L_sim=0.1100
[Epoch 2 | step 350/7098] L_total=1.3042  L_gt=1.0576  L_obj=0.1331  L_sim=0.1135
[Epoch 2 | step 400/7098] L_total=1.1759  L_gt=0.9506  L_obj=0.1232  L_sim=0.1021
[Epoch 2 | step 450/7098] L_total=1.3706  L_gt=1.1516  L_o

In [16]:
# === Save a checkpoint after training
from pathlib import Path

CHECKPOINT_DIR = Path("checkpoints")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

ckpt_path = CHECKPOINT_DIR / f"convnext_vclr_epoch{NUM_EPOCHS}_rev2.pth"

torch.save(
    {
        "epoch": NUM_EPOCHS,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "config": {
            "IMG_SIZE": IMG_SIZE,
            "TRAIN_BATCH_SIZE": TRAIN_BATCH_SIZE,
            "BASE_LR": BASE_LR,
            "NUM_EPOCHS": NUM_EPOCHS,
        },
    },
    ckpt_path,
)

print("Saved checkpoint to:", ckpt_path)


Saved checkpoint to: checkpoints\convnext_vclr_epoch4_rev2.pth


In [None]:
# === Validation dataset for Non-VOC AR (UVO, boxes + proxy masks) ===

from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from pycocotools.coco import COCO
from PIL import Image
import torchvision.transforms as T
import torch

# COCO-format Non-VOC annotations (UVO subset)
# NONVOC_VAL_JSON and NONVOC_IMG_DIR should already be defined as:
# NONVOC_VAL_JSON = DATA_ROOT / "uvo_nonvoc_val_rle.json"
# NONVOC_IMG_DIR  = DATA_ROOT / "uvo_videos_dense_frames"
coco_nonvoc = COCO(str(NONVOC_VAL_JSON))
val_img_ids = sorted(coco_nonvoc.getImgIds())
print("Non-VOC val images:", len(val_img_ids))


class VCLRValDataset(Dataset):
    def __init__(self, coco: COCO, img_root: Path, img_size: int = IMG_SIZE):
        """
        img_root: directory that 'file_name' in JSON is relative to.
        Example: file_name = '--33Lscn6sk/180.png'
        and actual disk path:
            NONVOC_IMG_DIR/--33Lscn6sk/180.png
        """
        self.coco = coco
        self.img_root = Path(img_root)
        self.img_ids = sorted(coco.getImgIds())
        self.img_size = img_size

        self.resize_img = T.Resize((img_size, img_size))
        self.to_tensor = T.ToTensor()
        self.normalize = T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs([img_id])[0]

        # e.g. '--33Lscn6sk/180.png'
        rel_path = Path(img_info["file_name"])
        path = self.img_root / rel_path

        # Fallback: .png <-> .jpg if needed
        if not path.is_file():
            alt = None
            if path.suffix.lower() == ".png":
                alt = path.with_suffix(".jpg")
            elif path.suffix.lower() == ".jpg":
                alt = path.with_suffix(".png")
            if alt is not None and alt.is_file():
                path = alt
            else:
                raise FileNotFoundError(
                    f"Image not found for img_id={img_id}. Tried {path}"
                    + (f" and {alt}" if alt is not None else "")
                )

        img = Image.open(path).convert("RGB")
        w0, h0 = img.size
        orig_size = torch.tensor([h0, w0], dtype=torch.float32)

        img = self.resize_img(img)
        img = self.to_tensor(img)
        img = self.normalize(img)

        # Returned:
        #   img        : resized + normalized tensor (C, H, W)
        #   img_id     : COCO/UVO image id (int)
        #   orig_size  : original (H, W) for rescaling boxes/masks back
        return img, img_id, orig_size


def val_collate(batch):
    images = torch.stack([b[0] for b in batch], dim=0)
    img_ids = [b[1] for b in batch]
    orig_sizes = torch.stack([b[2] for b in batch], dim=0)
    return images, img_ids, orig_sizes


val_dataset = VCLRValDataset(coco_nonvoc, NONVOC_IMG_DIR, img_size=IMG_SIZE)
val_loader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    num_workers=VAL_NUM_WORKERS,
    pin_memory=True,
    collate_fn=val_collate,
)

print("Val (Non-VOC UVO) batches:", len(val_loader))

In [None]:
# === COCO AR evaluation + Table-1-style row (bbox + segm) ===

import gc
from pycocotools.cocoeval import COCOeval
from pycocotools import mask as maskUtils
import torch

def evaluate_vclr(model: torch.nn.Module):
    """
    Evaluate AR/AP for bounding boxes and proxy masks on Non-VOC UVO.
    If a teacher–student wrapper is used, the EMA teacher is used for inference.
    """
    # Use EMA teacher if available; otherwise use the model directly
    if hasattr(model, "teacher"):
        net = model.teacher
    else:
        net = model

    net.eval()
    dets_bbox = []
    dets_segm = []

    with torch.no_grad():
        for images, img_ids, orig_sizes in val_loader:
            images = images.to(device, non_blocking=True)

            # Forward pass on natural images
            outputs, _ = net(images, targets=None, proposals=None, view_name="nat")

            for i, img_id in enumerate(img_ids):
                out = outputs[i]
                boxes = out["boxes"].detach().cpu().clone()
                scores = out["scores"].detach().cpu()

                if boxes.numel() == 0:
                    continue

                # Scale boxes from network (IMG_SIZE) back to original H,W
                h0, w0 = orig_sizes[i].tolist()   # orig_size was [H, W]
                sx = w0 / IMG_SIZE
                sy = h0 / IMG_SIZE

                boxes[:, 0] *= sx
                boxes[:, 2] *= sx
                boxes[:, 1] *= sy
                boxes[:, 3] *= sy

                ws = boxes[:, 2] - boxes[:, 0]
                hs = boxes[:, 3] - boxes[:, 1]

                H = int(round(h0))
                W = int(round(w0))
                image_id_int = int(img_id)

                for k in range(boxes.size(0)):
                    x0 = float(boxes[k, 0])
                    y0 = float(boxes[k, 1])
                    w  = float(ws[k])
                    h  = float(hs[k])
                    score = float(scores[k])

                    if w <= 0 or h <= 0:
                        continue

                    # Bounding box detection
                    dets_bbox.append(
                        {
                            "image_id": image_id_int,
                            "category_id": 1,
                            "bbox": [x0, y0, w, h],
                            "score": score,
                        }
                    )

                    # Rectangular proxy mask encoded as RLE for COCO "segm" eval
                    x1 = x0 + w
                    y1 = y0 + h
                    poly = [x0, y0, x1, y0, x1, y1, x0, y1]  # xyxy polygon

                    # frPyObjects -> RLE; COCO expects RLE for detection segm
                    rle = maskUtils.frPyObjects([poly], H, W)[0]
                    if isinstance(rle["counts"], bytes):
                        rle["counts"] = rle["counts"].decode("ascii")

                    dets_segm.append(
                        {
                            "image_id": image_id_int,
                            "category_id": 1,
                            "segmentation": rle,
                            "score": score,
                        }
                    )

    if len(dets_bbox) == 0:
        print("[WARN] No detections produced on Non-VOC val set.")
        return None

    # Bounding box AR / AP
    coco_dt_box = coco_nonvoc.loadRes(dets_bbox)
    coco_eval_box = COCOeval(coco_nonvoc, coco_dt_box, iouType="bbox")
    coco_eval_box.evaluate()
    coco_eval_box.accumulate()
    coco_eval_box.summarize()
    stats_box = coco_eval_box.stats

    # Segmentation AR / AP using rectangular proxy masks (RLE)
    coco_dt_segm = coco_nonvoc.loadRes(dets_segm)
    coco_eval_segm = COCOeval(coco_nonvoc, coco_dt_segm, iouType="segm")
    coco_eval_segm.evaluate()
    coco_eval_segm.accumulate()
    coco_eval_segm.summarize()
    stats_segm = coco_eval_segm.stats

    return {
        "ARb_1":   stats_box[6],
        "ARb_10":  stats_box[7],
        "ARb_100": stats_box[8],
        "APb":     stats_box[0],
        "ARm_1":   stats_segm[6],
        "ARm_10":  stats_segm[7],
        "ARm_100": stats_segm[8],
        "APm":     stats_segm[0],
    }


# --- Run evaluation and print Table-1-style row ---

torch.cuda.empty_cache()
gc.collect()

ar = evaluate_vclr(model)

if ar is not None:
    ARb10  = ar["ARb_10"]  * 100.0
    ARb100 = ar["ARb_100"] * 100.0
    ARm10  = ar["ARm_10"]  * 100.0
    ARm100 = ar["ARm_100"] * 100.0

    print("\nTable-1-style row (Non-VOC UVO):")
    print(
        f"ConvNeXt v-CLR (CNN, EMA) | "
        f"AR^b_10 = {ARb10:.1f}  "
        f"AR^b_100 = {ARb100:.1f}  "
        f"AR^m_10 = {ARm10:.1f}  "
        f"AR^m_100 = {ARm100:.1f}"
    )


In [21]:
print(coco_nonvoc.cats)
print("Category IDs:", coco_nonvoc.getCatIds())

{1: {'supercategory': 'person', 'id': 1, 'name': 'person'}}
Category IDs: [1]


In [20]:
# === Debug: compare training vs validation image dimensions ===

import torch

# ----- 1) One batch from TRAIN loader (v-CLR training set) -----
train_batch = next(iter(train_loader))

images_nat   = train_batch["image_nat"]
images_depth = train_batch["image_depth"]
images_style = train_batch["image_style"]

print("=== TRAIN BATCH ===")
print("image_nat batch shape :",
      images_nat.shape if isinstance(images_nat, torch.Tensor) else type(images_nat))

if isinstance(images_nat, torch.Tensor):
    print("  sample[0] shape   :", images_nat[0].shape)

# Depth view
if images_depth is None:
    print("image_depth         : None")
elif isinstance(images_depth, torch.Tensor):
    print("image_depth batch   :", images_depth.shape)
elif isinstance(images_depth, list) and len(images_depth) > 0 and isinstance(images_depth[0], torch.Tensor):
    print("image_depth[0] shape:", images_depth[0].shape)
else:
    print("image_depth type    :", type(images_depth))

# Style view
if images_style is None:
    print("image_style         : None")
elif isinstance(images_style, torch.Tensor):
    print("image_style batch   :", images_style.shape)
elif isinstance(images_style, list) and len(images_style) > 0 and isinstance(images_style[0], torch.Tensor):
    print("image_style[0] shape:", images_style[0].shape)
else:
    print("image_style type    :", type(images_style))

# ----- 2) One batch from AR VAL loader (UVO Non-VOC) -----
val_batch = next(iter(val_loader))
val_imgs, val_ids, val_orig_sizes = val_batch

print("\n=== VAL-AR BATCH (UVO Non-VOC) ===")
print("images batch shape  :", val_imgs.shape)         # [B, 3, H_val, W_val]
print("sample[0] shape     :", val_imgs[0].shape)      # [3, H_val, W_val]
print("orig_size[0] (H,W)  :", val_orig_sizes[0].tolist())
print("orig_size[1] (H,W)  :", val_orig_sizes[1].tolist() if len(val_orig_sizes) > 1 else "(only one sample)")


=== TRAIN BATCH ===
image_nat batch shape : torch.Size([10, 3, 800, 800])
  sample[0] shape   : torch.Size([3, 800, 800])
image_depth[0] shape: torch.Size([3, 800, 800])
image_style[0] shape: torch.Size([3, 800, 800])

=== VAL-AR BATCH (UVO Non-VOC) ===
images batch shape  : torch.Size([10, 3, 800, 800])
sample[0] shape     : torch.Size([3, 800, 800])
orig_size[0] (H,W)  : [480.0, 854.0]
orig_size[1] (H,W)  : [480.0, 854.0]


In [None]:
# === Plot training / validation losses for v-CLR ===

import matplotlib.pyplot as plt
import numpy as np

# Quick sanity check on what's in train_history
print("L_total_train len:", len(train_history.get("L_total_train", [])))
print("L_total_val   len:", len(train_history.get("L_total_val", [])))
print("L_gt_train    len:", len(train_history.get("L_gt_train", [])))
print("L_gt_val      len:", len(train_history.get("L_gt_val", [])))
print("L_obj_train   len:", len(train_history.get("L_obj_train", [])))
print("L_obj_val     len:", len(train_history.get("L_obj_val", [])))
print("L_sim_train   len:", len(train_history.get("L_sim_train", [])))
print("L_sim_val     len:", len(train_history.get("L_sim_val", [])))

# Use the minimum length across train/val so x and y always match
n_total = min(len(train_history.get("L_total_train", [])),
              len(train_history.get("L_total_val", [])))
n_gt    = min(len(train_history.get("L_gt_train", [])),
              len(train_history.get("L_gt_val", [])))
n_obj   = min(len(train_history.get("L_obj_train", [])),
              len(train_history.get("L_obj_val", [])))
n_sim   = min(len(train_history.get("L_sim_train", [])),
              len(train_history.get("L_sim_val", [])))

if n_total == 0:
    print("No logged epochs in train_history. Run training first, then re-run this cell.")
else:
    epochs_total = np.arange(1, n_total + 1)

    # 1) Total loss: train vs val
    plt.figure(figsize=(6, 4))
    plt.plot(epochs_total,
             train_history["L_total_train"][:n_total],
             label="train L_total")
    plt.plot(epochs_total,
             train_history["L_total_val"][:n_total],
             label="val L_total")
    plt.xlabel("Epoch")
    plt.ylabel("L_total")
    plt.title("v-CLR total loss")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # 2) L_gt
    if n_gt > 0:
        epochs_gt = np.arange(1, n_gt + 1)
        plt.figure(figsize=(6, 4))
        plt.plot(epochs_gt,
                 train_history["L_gt_train"][:n_gt],
                 label="train L_gt")
        plt.plot(epochs_gt,
                 train_history["L_gt_val"][:n_gt],
                 label="val L_gt")
        plt.xlabel("Epoch")
        plt.ylabel("L_gt")
        plt.title("v-CLR GT loss")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    # 3) L_obj
    if n_obj > 0:
        epochs_obj = np.arange(1, n_obj + 1)
        plt.figure(figsize=(6, 4))
        plt.plot(epochs_obj,
                 train_history["L_obj_train"][:n_obj],
                 label="train L_obj")
        plt.plot(epochs_obj,
                 train_history["L_obj_val"][:n_obj],
                 label="val L_obj")
        plt.xlabel("Epoch")
        plt.ylabel("L_obj")
        plt.title("v-CLR object loss")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    # 4) L_sim
    if n_sim > 0:
        epochs_sim = np.arange(1, n_sim + 1)
        plt.figure(figsize=(6, 4))
        plt.plot(epochs_sim,
                 train_history["L_sim_train"][:n_sim],
                 label="train L_sim")
        plt.plot(epochs_sim,
                 train_history["L_sim_val"][:n_sim],
                 label="val L_sim")
        plt.xlabel("Epoch")
        plt.ylabel("L_sim")
        plt.title("v-CLR similarity loss")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()
