# Hierarchical Aux-Heads Dataset Builder (from L3 outputs)

This notebook prepares **chunked ROI files** to train **auxiliary L2/L1 heads** (groups/super) *off the base L3 YOLO model*.

**Key features**
- Runs L3 detector on the train split to harvest:
  - **Predicted positives** (TPs by IoU≥0.5),
  - **Predicted false positives** (FPs by IoU≤0.3),
  - **Ambiguous** predictions (0.3<IoU<0.5) → **ignored**,
- Adds **simulated detections**:
  - **Jittered positives** (IoU≥0.5) to **build tolerance** to minor misalignment,
  - **Jittered negatives** (IoU≤0.3),
  - **Background negatives** (random boxes with IoU<0.05 to any GT),
- Saves **chunked JSONL.GZ** files with ROI metadata + labels for L2/L1,
- Produces **class weights** (effective-number) for L2,
- Writes a **manifest.json** for the training script.

**Why IoU-conditioned labeling?**
- We *accept* slight jitter (IoU≥0.5) as positives so aux heads don’t kill good-but-slightly-off detections,
- We *reject* clear mistakes (IoU≤0.3) as negatives to build resilience,
- We **ignore** in-between to avoid noisy supervision.


In [1]:
# Core imports
import os, json, gzip, math, random
from pathlib import Path
from copy import deepcopy
from typing import List, Dict, Tuple, Optional
from collections import defaultdict, Counter

import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from ultralytics import YOLO
import yaml

import torch.nn as nn

## Configuration

This cell centralizes all parameters for the data preparation pipeline. We have introduced several new parameters to support the **Unified-Compact-ROI** feature extraction strategy.

### Key Engineering Decisions:
- **Output Format**: We are now generating `.pt` (PyTorch tensor) files instead of `jsonl.gz`. This is vastly more efficient for storing and loading large tensors.
- **Feature Extraction (`ROI_` & `COMPRESSED_` params)**: We will perform `roi_align` on a single, scale-appropriate feature map for each box. The `PYRAMID_THRESHOLDS` determine which map to use. The resulting `7x7` feature patch is immediately compressed to a small `256-dim` vector before being saved, solving the I/O bottleneck.
- **Batching (`PREP_BATCH_SIZE`)**: The process is now batch-oriented to maximize GPU utilization. A batch size of `16` is chosen as a safe default for an 8GB GPU, as it needs to hold the full-size feature maps in VRAM.
- **Chunking**: `CHUNK_SIZE` now refers to the number of ROIs per `.pt` chunk file. This ensures each file is a manageable size for streaming.

In [None]:
# Reproducibility
SEED = 2
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ---- Paths ----
ROOT = Path.cwd()
DATA_YAML = ROOT / "data_vehicle_hierarchy.yaml"
L3_CHECKPOINT = ROOT / "runs" / "yolo_l3_base" / "weights" / "best.pt"

# Output dirs - NOW SAVING .pt FILES
OUT_DIR = ROOT / "aux_heads_chunks_pt"
OUT_DIR.mkdir(parents=True, exist_ok=True)
PREVIEW_DIR = OUT_DIR / "preview"
PREVIEW_DIR.mkdir(exist_ok=True)

# Split-specific subfolders (train/val)
TRAIN_OUT_DIR = OUT_DIR / "train"
VAL_OUT_DIR = OUT_DIR / "val"
TRAIN_OUT_DIR.mkdir(exist_ok=True)
VAL_OUT_DIR.mkdir(exist_ok=True)

MANIFEST_PATH = OUT_DIR / "manifest.json"
STATS_TRAIN_PATH = OUT_DIR / "stats_train.json"
STATS_VAL_PATH = OUT_DIR / "stats_val.json"
CLASS_WEIGHTS_PATH = OUT_DIR / "class_weights.json"

# Hierarchy
L3_NAMES: List[str] = [
    "bus","work_van","single_unit_truck","pickup_truck",
    "articulated_truck","car","motorcycle","bicycle"
 ]
L2_NAMES: List[str] = ["heavy_vehicle","car_group","two_wheeled_vehicle"]
CLASS_TO_L2: Dict[str, str] = {
    "bus": "heavy_vehicle", "single_unit_truck": "heavy_vehicle", "articulated_truck": "heavy_vehicle",
    "car": "car_group", "pickup_truck": "car_group", "work_van": "car_group",
    "motorcycle": "two_wheeled_vehicle", "bicycle": "two_wheeled_vehicle",
}
L1_NAMES: List[str] = ["vehicle"]

# Index maps
L3_TO_IDX = {n:i for i,n in enumerate(L3_NAMES)}
L2_TO_IDX = {n:i for i,n in enumerate(L2_NAMES)}

# Matching Thresholds
# These IoU (Intersection over Union) thresholds define how we label ROI candidates.
# A "candidate" can be a box predicted by the L3 model or a synthetically jittered box.
# The gap between IOU_NEG and IOU_POS creates a "dead zone" for ambiguous overlaps, which are ignored.
IOU_POS = 0.55      # Minimum IoU with a ground truth box to be considered a POSITIVE sample.
                    # This teaches the aux heads tolerance for slight misalignments.
IOU_NEG = 0.25      # Maximum IoU with any ground truth box to be considered a NEGATIVE sample.
                    # This provides confident negative examples.
IOU_BG  = 0.05      # Stricter maximum IoU for randomly sampled background boxes to ensure
                    # they are "pure" background and do not contain parts of any object.

# L3 Inference & Data Preparation Settings
PREP_BATCH_SIZE = 16       # Number of images to process on the GPU in a single pass during this script.
                           # A balance between speed (higher is faster) and VRAM usage.
PRED_CONF_THRES = 0.01     # L3 model confidence threshold. Set very low to capture ALL potential detections,
                           # including weak ones, which the aux heads will learn to either rescue or reject.
PRED_IOU_NMS    = 0.70     # NMS IoU threshold for the L3 model. Set high to be permissive, allowing
                           # multiple overlapping candidates for the aux heads to analyze.
PRED_MAX_DET    = 300      # A safety cap on the maximum number of detections per image from the L3 model.
PREP_WORKERS    = 8        # Number of CPU worker threads for loading data. Half of total threads is a good start.

# Feature Extraction Configuration
# Parameters for our "Unified-Compact-ROI" feature extraction pipeline.
ROI_ALIGN_SIZE = (7, 7)    # The fixed spatial size (H, W) to which all ROIs are pooled by roi_align.
                           # 7x7 is a standard size that preserves spatial patterns efficiently.
COMPRESSED_DIM = 256       # The final dimension of the feature vector saved to disk. Our FeatureCompressor
                           # CNN will process the 7x7 patch down to this compact size.
PYRAMID_THRESHOLDS = (32.0, 96.0) # Pixel thresholds on a box's longest side to select the feature map:
                                  #   - box_side <= 32: Use P3 (high-res map for small objects)
                                  #   - 32 < box_side <= 96: Use P4 (medium-res map)
                                  #   - box_side > 96: Use P5 (low-res map for large objects)

# Jitter & Background Sampling
# Parameters to control the generation of synthetic ROIs for data augmentation and hard-negative mining.
JITTER_POS_PER_GT = 1      # Number of "jittered positive" boxes to generate per ground truth object.
JITTER_NEG_PER_GT = 2      # Number of "jittered negative" boxes to generate per ground truth object.
JITTER_SHIFT_RANGE = (0.1, 0.6) # A box's center will be randomly shifted by 10% to 60% of its size.
JITTER_SCALE_RANGE = (0.1, 0.6) # A box's size will be randomly scaled up or down by 10% to 60%.
BG_NEG_PER_IMG    = 3      # Number of random "background negative" boxes to sample per image.
BG_MIN_SIZE_FRAC  = 0.05   # Minimum size of a background box, as a fraction of image width/height,
                           # to prevent sampling tiny, meaningless patches.

# Chunking (for .pt files)
CHUNK_SIZE = 20_000   # ~20k ROI feature vectors per file
MAX_IMAGES = None     # Set to small int to debug on subset

print("Config loaded.")
print("Output will be written to:", OUT_DIR)

Config loaded.
Output will be written to: c:\Users\Mika\Desktop\New_Training_Run_Post_Bulgaria\aux_heads_chunks_pt


In [3]:
assert DATA_YAML.exists(), f"Missing data yaml: {DATA_YAML}"
with open(DATA_YAML, "r") as f:
    data_cfg = yaml.safe_load(f)

# Expected YOLO structure in data yaml
train_dir = Path(data_cfg["train"])  # folder path or list file
val_dir = Path(data_cfg.get("val", "")) if "val" in data_cfg else None

# Infer labels dir by standard YOLO layout
# images/train/*.jpg -> labels/train/*.txt
def infer_labels_dir(img_dir: Path) -> Path:
    if "images" in img_dir.parts:
        idx = img_dir.parts.index("images")
        labels = Path(*img_dir.parts[:idx], "labels", *img_dir.parts[idx+1:])
        return labels
    # fallback
    return img_dir.parent / "labels" / img_dir.name

labels_train = infer_labels_dir(train_dir)
labels_val = infer_labels_dir(val_dir) if val_dir and str(val_dir) != "" else None

assert train_dir.exists(), f"Train images dir not found: {train_dir}"
assert labels_train.exists(), f"Train labels dir not found: {labels_train}"
if val_dir and str(val_dir) != "":
    assert val_dir.exists(), f"Val images dir not found: {val_dir}"
    assert labels_val.exists(), f"Val labels dir not found: {labels_val}"

# Collect image files
IMG_EXTS = (".jpg",".jpeg",".png",".bmp",".tif",".tiff",".webp")
train_images = [p for p in sorted(train_dir.glob("**/*")) if p.suffix.lower() in IMG_EXTS]
val_images = [p for p in sorted(val_dir.glob("**/*")) if (val_dir and p.suffix.lower() in IMG_EXTS)] if (val_dir and str(val_dir) != "") else []
if MAX_IMAGES is not None:
    train_images = train_images[:MAX_IMAGES]
    val_images = val_images[:MAX_IMAGES]

print(f"Found {len(train_images)} train images.")
print(f"Found {len(val_images)} val images.")

Found 74746 train images.
Found 13190 val images.


## Negative Sampling Strategies

To ensure the **auxiliary L2/L1 heads** become **resilient to false positives from L3**, the dataset builder implements several complementary **negative sampling strategies**. Each is carefully designed to mimic the kinds of mistakes the base detector might make at inference:

1. **False Positive Predictions (`pred_fp`)**  
   - Collected directly from L3’s detections that **do not overlap any ground truth** sufficiently (IoU ≤ 0.3).  
   - These represent *realistic mistakes* the L3 head produces, teaching the aux heads to recognize and down-weight such spurious detections.

2. **Jittered Negatives (`jitter_neg`)**  
   - Generated by perturbing ground-truth boxes (shift/scale up to 20%).  
   - If the perturbed box has **low overlap with its GT (IoU ≤ 0.3)**, it is labeled as a *negative*.  
   - Simulates the situation where L3 places a box in the right neighborhood but **off-target enough to be wrong**.

3. **Ignored Ambiguities (0.3 < IoU < 0.5)**  
   - Predictions or jitters with intermediate IoU are **not included** in training.  
   - This avoids noisy labels in the “gray zone” and keeps supervision **clean and decisive**.

4. **Background Negatives (`bg_neg`)**  
   - Random boxes are sampled in the image, sized roughly like typical objects, but constrained to have **IoU ≤ 0.05 with any GT**.  
   - These are **pure background crops** (road, sky, walls, etc.), helping the aux heads avoid firing on areas with **no objects at all**.



### Why This Matters
- **Coverage**: The heads see *real mistakes* (FPs), *plausible misalignments* (jitter-neg), and *true background* (bg-neg).  
- **Robustness**: Aux heads learn to tolerate small misalignment (positives at IoU ≥ 0.5), but to veto boxes that are clearly wrong.  
- **Balance**: Ignoring the ambiguous range prevents confusing signals and stabilizes learning.

This combination ensures that aux heads act as **robust validators**:  
- They pass through true objects (even with slight misalignment),  
- But suppress random clutter and false alarms.  


## Positive Sampling Strategies

In addition to negatives, the dataset builder ensures that **auxiliary L2/L1 heads** receive strong and diverse **positive examples**, so they learn to *accept true objects* even when L3’s boxes are slightly imperfect:

1. **True Positives (`gt_pos`)**  
   - Directly taken from the **ground-truth (GT) annotations**.  
   - Serve as the **gold standard positives**, ensuring the aux heads learn to recognize each object correctly within the hierarchy.

2. **Jittered Positives (`jitter_pos`)**  
   - Perturbed versions of GT boxes (shifted/scaled by up to 20%).  
   - If the perturbed box maintains **sufficient overlap (IoU ≥ 0.5)** with the GT, it is treated as a **positive sample**.  
   - This teaches aux heads to be **tolerant to mild misalignments**, preparing them to accept real detections from L3 that are not pixel-perfect but still correct.

3. **IoU-Aware Filtering**  
   - Boxes in the ambiguous overlap range (0.3 < IoU < 0.5) are **ignored** rather than forced into positive/negative.  
   - Keeps training labels **clean**, avoiding confusing supervision that could make aux heads overly strict.

### Why This Matters
- **Tolerance**: By including jittered positives, aux heads won’t wrongly reject slightly off detections that are still valid.  
- **Generalization**: They learn not just from exact GT, but also from realistic detector outputs.  
- **Consistency**: The IoU-aware scheme ensures stable training by distinguishing true objects from noise without ambiguity.

This guarantees that aux heads act as **forgiving validators**:  
- They correctly accept slightly misaligned but true boxes,  
- While still rejecting boxes that drift too far into the false positive zone.


## Helper Functions (IOU, Box Ops, Pyramid Chooser)

This cell contains core utility functions. We have added the `choose_pyramid_level` helper, which is a key component of our scale-aware feature extraction strategy.

### Engineering Decision:
- **Scale-Aware Pooling**: Instead of wastefully pooling from all three feature maps, this function uses simple box size heuristics to select the *single most appropriate* feature map (P3 for small objects, P4 for medium, P5 for large). This is a standard FPN technique that significantly improves efficiency while ensuring features are extracted at the correct resolution.

In [None]:
from torchvision.ops import roi_align

def yolo_txt_to_boxes(label_path: Path, img_w: int, img_h: int):
    # (Existing function, no changes needed)
    if not label_path.exists(): return []
    lines = label_path.read_text().strip().splitlines()
    out = []
    for ln in lines:
        parts = ln.strip().split()
        if len(parts) < 5: continue
        c = int(float(parts[0]))
        xc, yc, w, h = [float(p) for p in parts[1:5]]
        x1 = max(0.0, (xc - w/2.0) * img_w)
        y1 = max(0.0, (yc - h/2.0) * img_h)
        x2 = min(float(img_w), (xc + w/2.0) * img_w)
        y2 = min(float(img_h), (yc + h/2.0) * img_h)
        if x2 > x1 and y2 > y1:
            out.append((c, [x1,y1,x2,y2]))
    return out

def boxes_iou_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # (Existing function, no changes needed)
    if len(a)==0 or len(b)==0: return np.zeros((len(a), len(b)), dtype=np.float32)
    x1 = np.maximum(a[:,None,0], b[None,:,0])
    y1 = np.maximum(a[:,None,1], b[None,:,1])
    x2 = np.minimum(a[:,None,2], b[None,:,2])
    y2 = np.minimum(a[:,None,3], b[None,:,3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area_a = (a[:,2]-a[:,0]) * (a[:,3]-a[:,1])
    area_b = (b[:,2]-b[:,0]) * (b[:,3]-b[:,1])
    union = area_a[:,None] + area_b[None,:] - inter
    iou = np.where(union > 0, inter/union, 0.0)
    return iou.astype(np.float32)

def clip_box_xyxy(xyxy, w, h):
    # (Existing function, no changes needed)
    x1,y1,x2,y2 = xyxy
    x1 = float(np.clip(x1, 0, w-1)); y1 = float(np.clip(y1, 0, h-1))
    x2 = float(np.clip(x2, 0, w-1)); y2 = float(np.clip(y2, 0, h-1))
    if x2 <= x1 or y2 <= y1: return None
    return [x1,y1,x2,y2]

def to_one_hot(idx: Optional[int], n: int):
    # (Existing function, no changes needed)
    v = [0]*n
    if idx is not None and 0 <= idx < n: v[idx] = 1
    return v

# NEW HELPER FUNCTION
def choose_pyramid_level(boxes_xyxy: torch.Tensor, thresholds: Tuple[float, float] = (32.0, 96.0)):
    """
    Choose which pyramid level to use based on box size.
    Args:
        boxes_xyxy: Float tensor [N, 4] (x1, y1, x2, y2) in pixels.
        thresholds: tuple(t1, t2) thresholds (in pixels).
                    if max(w,h) <= t1: use P3 (high-res)
                    elif <= t2: use P4
                    else: P5
    Returns:
        levels: Long tensor [N] with values in {3, 4, 5}
    """
    if boxes_xyxy.numel() == 0:
        return torch.empty(0, dtype=torch.long, device=boxes_xyxy.device)
    
    w = boxes_xyxy[:, 2] - boxes_xyxy[:, 0]
    h = boxes_xyxy[:, 3] - boxes_xyxy[:, 1]
    max_side = torch.maximum(w, h)

    t1, t2 = thresholds
    levels = torch.full_like(max_side, 4, dtype=torch.long) # Default to P4
    levels[max_side <= t1] = 3
    levels[max_side > t2] = 5
    return levels

## Jitter & Background Samplers (IoU-aware)

In [5]:
def jitter_box(xyxy, w, h, shift_frac=0.2, scale_frac=0.2):
    """
    Randomly jitter a box by shifting center and scaling size.
    shift_frac: max absolute fraction for center shift relative to box w/h.
    scale_frac: max absolute fraction for scaling (±).
    """
    x1,y1,x2,y2 = xyxy
    bw = x2 - x1
    bh = y2 - y1
    cx = x1 + bw/2.0
    cy = y1 + bh/2.0

    # sample actual magnitudes (already provided by caller)
    # shift
    dx = (random.uniform(-shift_frac, shift_frac)) * bw
    dy = (random.uniform(-shift_frac, shift_frac)) * bh
    # scale
    sx = 1.0 + (random.uniform(-scale_frac, scale_frac))
    sy = 1.0 + (random.uniform(-scale_frac, scale_frac))

    new_bw = max(2.0, bw * sx)
    new_bh = max(2.0, bh * sy)
    ncx = cx + dx
    ncy = cy + dy

    nx1 = ncx - new_bw/2.0
    ny1 = ncy - new_bh/2.0
    nx2 = ncx + new_bw/2.0
    ny2 = ncy + new_bh/2.0

    clipped = clip_box_xyxy([nx1,ny1,nx2,ny2], w, h)
    return clipped

def sample_bg_boxes(w, h, n=3, min_size_frac=0.05, size_ref=None):
    """
    Sample random background boxes. If size_ref is provided (list of gt sizes),
    sample sizes near typical GT sizes; else use a broad range.
    """
    out = []
    for _ in range(n*3):  # try more to meet IoU constraints later
        if size_ref and len(size_ref)>0:
            sw, sh = random.choice(size_ref)
            # vary by ±30%
            sw = max(4.0, sw * random.uniform(0.7, 1.3))
            sh = max(4.0, sh * random.uniform(0.7, 1.3))
        else:
            min_w = max(4.0, w * min_size_frac)
            min_h = max(4.0, h * min_size_frac)
            # pick random size
            sw = random.uniform(min_w, w * 0.5)
            sh = random.uniform(min_h, h * 0.5)
        cx = random.uniform(sw/2.0, w - sw/2.0)
        cy = random.uniform(sh/2.0, h - sh/2.0)
        x1 = cx - sw/2.0
        y1 = cy - sh/2.0
        x2 = cx + sw/2.0
        y2 = cy + sh/2.0
        b = clip_box_xyxy([x1,y1,x2,y2], w, h)
        if b is not None:
            out.append(b)
        if len(out) >= n:
            break
    return out

##  Feature Compressor Definition

This cell defines the `FeatureCompressor` module. This lightweight CNN is a critical part of our pipeline.

### Engineering Decision:
- **Pre-computation**: Instead of saving large `7x7` feature patches to disk and processing them during training, we use this module *once* during data preparation. It takes the `roi_align` output and distills it into a compact, fixed-size vector (`256-dim`).
- **Efficiency**: This strategy solves the I/O and storage bottleneck. We save only the small, final vectors, allowing the training script to be extremely fast and memory-efficient, enabling large batch sizes. The compressor model itself is moved to the GPU for fast processing during this data prep stage.

In [6]:
class FeatureCompressor(nn.Module):
    """
    Compresses a [N, C, H, W] feature map from ROI Align into a compact [N, D] vector.
    This is used once during data preparation.
    """
    def __init__(self, in_channels: int, out_dim: int, hidden_dim_factor: int = 2):
        super().__init__()
        hidden_dim = out_dim * hidden_dim_factor
        self.trunk = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(inplace=True),
            nn.Conv2d(hidden_dim, out_dim, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.SiLU(inplace=True),
        )
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [N, C, H, W] from roi_align
        features = self.trunk(x)
        pooled_features = self.pool(features)
        return torch.flatten(pooled_features, 1) # [N, out_dim]

# We will define the in_channels later once we hook the model
# For now, we create a placeholder for the variable.
feature_compressor = None
print("FeatureCompressor class defined. Will be instantiated after model loading.")

FeatureCompressor class defined. Will be instantiated after model loading.


## Model Loading, Feature Hooking & Compressor Instantiation

This cell now performs three critical setup tasks:
1.  **Loads the pre-trained L3 YOLO Model** as before.
2.  **Identifies and Hooks Neck Layers**: We programmatically find the specific layers in the YOLOv11 neck that produce the P3, P4, and P5 feature maps. We then attach **forward hooks** to these layers. A hook is a PyTorch function that runs during the model's forward pass, allowing us to capture intermediate outputs (the feature maps) without modifying the original model's code. This is a clean and robust way to get the data we need.
3.  **Instantiates Feature Compressors**: Once the hooks are in place and we know the exact channel dimensions of P3, P4, and P5, we can properly instantiate our `FeatureCompressor` models. We create a dictionary of three separate compressors, one for each pyramid level, and move them to the GPU, ready for the data generation loop.

In [None]:
# This dictionary will store the feature maps captured by our hooks
feature_maps = {}
# This dictionary will store the instantiated FeatureCompressor modules
feature_compressors = {}
# This dictionary will store the spatial stride for each feature map
strides = {}

# 1. Load the L3 Model
assert L3_CHECKPOINT.exists(), f"Missing L3 checkpoint: {L3_CHECKPOINT}"
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = YOLO(str(L3_CHECKPOINT))
model.to(DEVICE)
print("Loaded L3 model:", L3_CHECKPOINT.name)

# 2. Identify and Hook Neck Layers
# The hook function is a closure that captures the output of a specific layer
def get_features_hook(name):
    def hook(model, input, output):
        # For some YOLO versions, the neck returns a tuple. We take the tensor output.
        feature_maps[name] = output[0] if isinstance(output, tuple) else output
    return hook

# =================================================================================
# TODO: UPDATE THESE INDICES based on the architcture of used model!
# =================================================================================
p5_module_idx = 16
p4_module_idx = 19
p3_module_idx = 22
# =================================================================================

try:
    hook_p5 = model.model.model[p5_module_idx].register_forward_hook(get_features_hook('p5'))
    hook_p4 = model.model.model[p4_module_idx].register_forward_hook(get_features_hook('p4'))
    hook_p3 = model.model.model[p3_module_idx].register_forward_hook(get_features_hook('p3'))
    print(f"Successfully attached forward hooks to layers {p3_module_idx}(P3), {p4_module_idx}(P4), {p5_module_idx}(P5).")

except IndexError as e:
    print(f"\n[ERROR] IndexError: {e}")
    print("The indices you provided are incorrect for this model's architecture.")
    print("Please re-run the 'Probe' cell, carefully identify the correct indices, and update them.")
    raise e
except Exception as e:
    print(f"Could not attach hooks. Error: {e}")
    raise

# 3. Quick Probe to Get Channel Dims and Instantiate Compressors
if len(train_images) > 0:
    print("\nRunning a single-image probe to capture feature map shapes...")
    # Using model.predict() is the easiest way to trigger a full forward pass
    with torch.no_grad():
        _ = model.predict(source=[str(train_images[0])], device=DEVICE, verbose=False)

    if not feature_maps:
        raise RuntimeError("Hooks did not capture any feature maps. Check module indices.")

    # Get channel dimensions from the captured feature maps
    p3_channels = feature_maps['p3'].shape[1]
    p4_channels = feature_maps['p4'].shape[1]
    p5_channels = feature_maps['p5'].shape[1]
    
    strides = {
        'p3': model.model.stride[0], # Stride 8
        'p4': model.model.stride[1], # Stride 16
        'p5': model.model.stride[2], # Stride 32
    }

    print(f"Captured feature map shapes (B, C, H, W) and strides:")
    print(f"  P3: {feature_maps['p3'].shape}, Stride: {strides['p3']}")
    print(f"  P4: {feature_maps['p4'].shape}, Stride: {strides['p4']}")
    print(f"  P5: {feature_maps['p5'].shape}, Stride: {strides['p5']}")

    # Now, instantiate the FeatureCompressor for each pyramid level
    feature_compressors = nn.ModuleDict({
        'p3': FeatureCompressor(in_channels=p3_channels, out_dim=COMPRESSED_DIM),
        'p4': FeatureCompressor(in_channels=p4_channels, out_dim=COMPRESSED_DIM),
        'p5': FeatureCompressor(in_channels=p5_channels, out_dim=COMPRESSED_DIM)
    }).to(DEVICE)
    feature_compressors.eval() # Set to eval mode as we only use it for inference here
    
    total_params = sum(p.numel() for p in feature_compressors.parameters())
    print(f"\nFeatureCompressor models instantiated and moved to {DEVICE}.")
    print(f"Total parameters in all compressors: {total_params:,}")

else:
    print("No training images found, skipping probe and compressor instantiation.")

# 4. Final Sanity Check
print("\nPhase 2 setup complete. Ready for data generation.")

Loaded L3 model: best.pt
Successfully attached forward hooks to layers 22(P3), 19(P4), 16(P5).

Running a single-image probe to capture feature map shapes...
Captured feature map shapes (B, C, H, W) and strides:
  P3: torch.Size([1, 256, 14, 20]), Stride: 8.0
  P4: torch.Size([1, 128, 28, 40]), Stride: 16.0
  P5: torch.Size([1, 64, 56, 80]), Stride: 32.0

FeatureCompressor models instantiated and moved to cuda:0.
Total parameters in all compressors: 5,607,936

Phase 2 setup complete. Ready for data generation.


## Ground Truth & Labeling Helpers

This cell contains the helper functions responsible for loading ground truth annotations from YOLO `.txt` files and mapping the L3 leaf classes to their L2 parent classes in the hierarchy. This logic is essential for creating the correct training targets for our auxiliary heads.

In [8]:
def l3_to_l2_idx(l3_idx: int) -> Optional[int]:
    """Maps an L3 class index to its corresponding L2 class index."""
    if l3_idx is None or l3_idx < 0 or l3_idx >= len(L3_NAMES):
        return None
    l3_name = L3_NAMES[l3_idx]
    l2_name = CLASS_TO_L2.get(l3_name, None)
    return L2_TO_IDX[l2_name] if l2_name in L2_TO_IDX else None

def gt_to_l2_idx_by_name(l3_name: str) -> Optional[int]:
    """Maps an L3 class name to its corresponding L2 class index."""
    l2_name = CLASS_TO_L2.get(l3_name, None)
    return L2_TO_IDX[l2_name] if l2_name in L2_TO_IDX else None

def prepare_gt(image_path: Path, labels_dir: Path):
    """
    Loads an image to get its dimensions and parses the corresponding
    YOLO label file to get ground truth boxes and their hierarchical labels.
    """
    # Get image size
    try:
        with Image.open(image_path) as im:
            w, h = im.size
    except Exception as e:
        print(f"Warning: Could not open image {image_path}. Skipping. Error: {e}")
        return [], (0,0), []

    # Find the corresponding label path
    # Standard: images/train/xxx.jpg -> labels/train/xxx.txt
    lbl_path = labels_dir / image_path.with_suffix(".txt").name
    # Handle nested directories if they exist
    if not lbl_path.exists():
        try:
            # Assumes structure like .../images/split_name/subdir/img.jpg
            # Maps to .../labels/split_name/subdir/img.txt
            split_name = image_path.parent.parent.name # e.g., 'train' or 'val'
            relative_path = image_path.relative_to(image_path.parent.parent.parent / "images" / split_name)
            lbl_path = labels_dir.parent / split_name / relative_path.with_suffix(".txt")
        except Exception:
            # Fallback if path logic is complex
            pass

    # Parse GT boxes from the label file
    gt_lst = yolo_txt_to_boxes(lbl_path, w, h)

    # Final list of GTs with (l3_idx, l2_idx, box_xyxy)
    gt = []
    # List of box sizes (w,h) for sampling background boxes
    size_ref = []
    for cls_id, b in gt_lst:
        if cls_id < 0 or cls_id >= len(L3_NAMES):
            continue
        l3_name = L3_NAMES[cls_id]
        l2_idx = gt_to_l2_idx_by_name(l3_name)
        if l2_idx is not None:
            gt.append( (cls_id, l2_idx, b) )
            bw = max(1.0, b[2]-b[0])
            bh = max(1.0, b[3]-b[1])
            size_ref.append((bw, bh))
            
    return gt, (w,h), size_ref

## Core Builder: Batched Feature Extraction (Final Version)

This is the final, definitive version of the core data preparation logic. It uses a robust manual batching loop and includes all necessary metadata for downstream tasks.

### Key Engineering Decisions:
- **Manual Batching**: We iterate through image paths in batches, calling `model.predict()` on one batch at a time. This is efficient and avoids the "Too many open files" error.
- **Pre-filtering for `roi_align`**: We identify which ROIs belong to a specific pyramid level **before** running `roi_align`, which is both correct and computationally efficient.
- **Complete Metadata**: Crucially, we now save the source `img_path` along with the type, IoU, and label for **every single ROI**. This is essential for visualization, debugging, and advanced analysis later.
- **Detailed Progress Monitoring**: The `tqdm` progress bar tracks the running total of each ROI type, giving a clear view of the dataset composition as it's being built.

In [None]:
from collections import Counter

# This is the final, corrected version of the core builder.
# It includes detailed reporting and saves the 'img_path' in the metadata.

def build_chunks_for_split(image_paths: List[Path], labels_dir: Path, out_dir: Path) -> List[str]:
    out_dir.mkdir(exist_ok=True)
    
    buffer = {
        'features': [], 'l1_targets': [], 'l2_targets': [],
        'boxes': [], 'metadata': []
    }
    chunk_idx = 0
    chunk_paths_local: List[str] = []
    total_rois_processed = 0
    stats_counter = Counter()

    pbar = tqdm(total=len(image_paths), desc=f"Processing split: {out_dir.name}")

    # Manually iterate through image paths in batches
    for i in range(0, len(image_paths), PREP_BATCH_SIZE):
        batch_paths = image_paths[i : i + PREP_BATCH_SIZE]
        if not batch_paths: continue

        with torch.no_grad():
            results_list = model.predict(source=[str(p) for p in batch_paths], device=DEVICE, verbose=False)

        batch_rois_by_image = [[] for _ in batch_paths]
        batch_metadata_by_image = [[] for _ in batch_paths]

        for batch_idx, res in enumerate(results_list):
            img_path = Path(res.path)
            gt, (w, h), size_ref = prepare_gt(img_path, labels_dir)
            gt_boxes = np.array([g[2] for g in gt]) if gt else np.zeros((0, 4))
            gt_l3s, gt_l2s = [g[0] for g in gt], [g[1] for g in gt]
            rois_this_image = []
            
            # 1. GT boxes
            for gt_l3, gt_l2, gxy in gt:
                rois_this_image.append((gxy, "gt_pos", 1.0, gt_l3, gt_l2)); stats_counter["gt_pos"] += 1
            # 2. Predicted boxes
            boxes = res.boxes
            pred_xyxy = boxes.xyxy.cpu().numpy() if boxes is not None else np.zeros((0, 4))
            if pred_xyxy.size > 0 and gt_boxes.size > 0:
                iou_mat = boxes_iou_matrix(pred_xyxy, gt_boxes); best_iou, best_gt = iou_mat.max(axis=1), iou_mat.argmax(axis=1)
            else:
                best_iou, best_gt = np.zeros(len(pred_xyxy)), -np.ones(len(pred_xyxy), dtype=int)
            matched_gts = set()
            for k in range(len(pred_xyxy)):
                iou, j = float(best_iou[k]), int(best_gt[k])
                if j != -1 and iou >= IOU_POS and j not in matched_gts:
                    matched_gts.add(j); rois_this_image.append((pred_xyxy[k].tolist(), "pred_tp", iou, gt_l3s[j], gt_l2s[j])); stats_counter["pred_tp"] += 1
                elif iou <= IOU_NEG:
                    rois_this_image.append((pred_xyxy[k].tolist(), "pred_fp", iou, None, None)); stats_counter["pred_fp"] += 1
            # 3. Jittered & Background boxes
            for gt_l3, gt_l2, gxy in gt:
                for _ in range(JITTER_POS_PER_GT):
                    if jb := jitter_box(gxy, w, h, random.uniform(*JITTER_SHIFT_RANGE), random.uniform(*JITTER_SCALE_RANGE)):
                        iou = float(boxes_iou_matrix(np.array([jb]), np.array([gxy]))[0, 0])
                        if iou >= IOU_POS: rois_this_image.append((jb, "jitter_pos", iou, gt_l3, gt_l2)); stats_counter["jitter_pos"] += 1
                        elif iou <= IOU_NEG: rois_this_image.append((jb, "jitter_neg", iou, None, None)); stats_counter["jitter_neg"] += 1
                for _ in range(JITTER_NEG_PER_GT):
                     if jb := jitter_box(gxy, w, h, random.uniform(*JITTER_SHIFT_RANGE), random.uniform(*JITTER_SCALE_RANGE)):
                        iou = float(boxes_iou_matrix(np.array([jb]), np.array([gxy]))[0, 0])
                        if iou <= IOU_NEG: rois_this_image.append((jb, "jitter_neg", iou, None, None)); stats_counter["jitter_neg"] += 1
            bg_boxes = sample_bg_boxes(w, h, n=BG_NEG_PER_IMG, size_ref=size_ref)
            if bg_boxes:
                iou_max = boxes_iou_matrix(np.array(bg_boxes), gt_boxes).max(axis=1) if gt_boxes.size > 0 else np.zeros(len(bg_boxes))
                for box, iou in zip(bg_boxes, iou_max):
                    if iou <= IOU_BG: rois_this_image.append((box, "bg_neg", float(iou), None, None)); stats_counter["bg_neg"] += 1
            if rois_this_image:
                batch_rois_by_image[batch_idx] = [r[0] for r in rois_this_image]
                batch_metadata_by_image[batch_idx] = [r[1:] for r in rois_this_image]

        all_rois_flat = torch.cat([torch.tensor(rois, dtype=torch.float32, device=DEVICE) for rois in batch_rois_by_image], dim=0)
        if all_rois_flat.numel() == 0:
            pbar.update(len(batch_paths)); continue
            
        levels = choose_pyramid_level(all_rois_flat, thresholds=PYRAMID_THRESHOLDS)
        final_features = torch.zeros(all_rois_flat.shape[0], COMPRESSED_DIM, device=DEVICE)
        with torch.no_grad():
            for level_idx in [3, 4, 5]:
                level_name = f'p{level_idx}'; mask = (levels == level_idx)
                if not mask.any(): continue
                original_indices = mask.nonzero().squeeze(1)
                roi_counts = [len(r) for r in batch_rois_by_image]
                image_indices_for_rois = torch.repeat_interleave(torch.arange(len(roi_counts), device=DEVICE), torch.tensor(roi_counts, device=DEVICE))
                boxes_on_level = all_rois_flat[mask]
                image_indices_on_level = image_indices_for_rois[mask]
                rois_for_align_this_level = [boxes_on_level[image_indices_on_level == i] for i in range(len(batch_paths))]
                pooled_feats = roi_align(feature_maps[level_name], rois_for_align_this_level, output_size=ROI_ALIGN_SIZE, spatial_scale=1.0 / strides[level_name])
                if pooled_feats.numel() > 0:
                    compressed_feats = feature_compressors[level_name](pooled_feats)
                    final_features[original_indices] = compressed_feats

        buffer['features'].append(final_features.cpu())
        buffer['boxes'].append(all_rois_flat.cpu())
        
        # Populate Buffer with full metadata
        current_rois_count = 0
        for batch_idx, metadata_list in enumerate(batch_metadata_by_image):
            if not metadata_list: continue
            
            img_path = Path(results_list[batch_idx].path) # Get the correct img_path for this set of ROIs
            
            for sample_type, iou, gt_l3, gt_l2 in metadata_list:
                is_positive = gt_l2 is not None and iou >= IOU_POS
                l1 = 1 if is_positive else 0
                l2 = to_one_hot(gt_l2 if is_positive else None, len(L2_NAMES))
                buffer['l1_targets'].append(l1)
                buffer['l2_targets'].append(l2)
                # Add img_path to each metadata entry
                buffer['metadata'].append({
                    'img_path': str(img_path),
                    'type': sample_type,
                    'iou': iou,
                    'gt_cls_l2': gt_l2 if is_positive else None
                })
        
        total_rois_processed += all_rois_flat.shape[0]; pbar.update(len(batch_paths))
        postfix_stats = {k: f"{v/1e3:.1f}k" for k, v in stats_counter.items()}
        postfix_stats['Total_ROIs'] = f"{total_rois_processed/1e6:.2f}M"
        pbar.set_postfix(postfix_stats)

        if len(buffer['metadata']) >= CHUNK_SIZE:
            save_path = out_dir / f"rois_chunk_{chunk_idx:04d}.pt"
            chunk_data = {'features': torch.cat(buffer['features'], dim=0).half(), 'l1_targets': torch.tensor(buffer['l1_targets'], dtype=torch.bool), 'l2_targets': torch.tensor(buffer['l2_targets'], dtype=torch.bool), 'boxes': torch.cat(buffer['boxes'], dim=0), 'metadata': buffer['metadata']}
            torch.save(chunk_data, save_path); chunk_paths_local.append(str(save_path))
            buffer = {k: [] for k in buffer}; chunk_idx += 1

    if buffer['features']:
        save_path = out_dir / f"rois_chunk_{chunk_idx:04d}.pt"
        chunk_data = {'features': torch.cat(buffer['features'], dim=0).half(), 'l1_targets': torch.tensor(buffer['l1_targets'], dtype=torch.bool), 'l2_targets': torch.tensor(buffer['l2_targets'], dtype=torch.bool), 'boxes': torch.cat(buffer['boxes'], dim=0), 'metadata': buffer['metadata']}
        torch.save(chunk_data, save_path); chunk_paths_local.append(str(save_path))

    pbar.close()
    print(f"Finished split {out_dir.name}. Total ROIs: {total_rois_processed}. Chunks: {len(chunk_paths_local)}")
    return chunk_paths_local

# Main Execution Logic
assert len(train_images) > 0, "No train images found."
train_chunk_paths = build_chunks_for_split(train_images, labels_train, TRAIN_OUT_DIR)
val_chunk_paths = []
if len(val_images) > 0 and labels_val is not None:
    val_chunk_paths = build_chunks_for_split(val_images, labels_val, VAL_OUT_DIR)
print(f"\nFinished all processing. Train chunks: {len(train_chunk_paths)}, Val chunks: {len(val_chunk_paths)}")

Processing split: train: 100%|██████████| 74746/74746 [17:36<00:00, 70.75it/s, gt_pos=188.5k, pred_tp=181.3k, pred_fp=13.8k, jitter_neg=41.6k, jitter_pos=77.0k, bg_neg=175.8k, Total_ROIs=0.68M]


Finished split train. Total ROIs: 678015. Chunks: 34


Processing split: val: 100%|██████████| 13190/13190 [02:42<00:00, 80.94it/s, gt_pos=48.2k, pred_tp=44.1k, jitter_pos=19.6k, jitter_neg=10.6k, bg_neg=31.5k, pred_fp=3.3k, Total_ROIs=0.16M]

Finished split val. Total ROIs: 157400. Chunks: 8

Finished all processing. Train chunks: 34, Val chunks: 8





## Quick Visual Sanity Check (from `.pt` Chunks)

This cell provides a visual confirmation that our data generation is correct. It now loads a random `.pt` chunk, groups the ROIs by their source image, and draws the bounding boxes with colors corresponding to their sample type (`gt_pos`, `pred_fp`, etc.). This allows us to quickly verify that our sampling strategies are producing a diverse and correctly labeled set of examples on real images.

In [10]:
def draw_boxes_pt(image_path: Path, boxes: torch.Tensor, metadata: List[Dict], out_path: Path, max_draw=100):
    """Draws boxes from a .pt chunk onto an image."""
    if not image_path.exists():
        print(f"Warning: Image not found at {image_path}, cannot draw preview.")
        return
        
    with Image.open(image_path).convert("RGB") as im:
        draw = ImageDraw.Draw(im)
        color_map = {
            "gt_pos": (128,0,255),  # Purple
            "pred_tp": (0,255,0),    # Green
            "pred_fp": (255,0,0),    # Red
            "jitter_pos": (0,200,255),# Cyan
            "jitter_neg": (255,128,0),# Orange
            "bg_neg": (200,200,200),  # Gray
        }
        for i in range(min(len(boxes), max_draw)):
            box = boxes[i].tolist()
            meta = metadata[i]
            sample_type = meta.get('type', 'unknown')
            color = color_map.get(sample_type, (255, 255, 0)) # Yellow for unknown
            draw.rectangle(box, outline=color, width=2)
    im.save(out_path)
    print(f"Preview saved: {out_path}")

# --- Load a chunk and visualize a few images from it ---
PREVIEW_IMAGES = 30
chunk_to_load = None
if train_chunk_paths:
    chunk_to_load = random.choice(train_chunk_paths)
elif val_chunk_paths:
    chunk_to_load = random.choice(val_chunk_paths)

if chunk_to_load:
    print(f"Loading chunk for preview: {chunk_to_load}")
    data = torch.load(chunk_to_load)
    
    # Group boxes and metadata by image path
    rois_by_image = defaultdict(lambda: {'boxes': [], 'metadata': []})
    all_boxes = data['boxes']
    all_metadata = data['metadata']

    for i, meta in enumerate(all_metadata):
        img_path = meta.get('img_path')
        if img_path:
            rois_by_image[img_path]['boxes'].append(all_boxes[i])
            rois_by_image[img_path]['metadata'].append(meta)
    
    # Select a few images from this chunk to visualize
    image_paths_in_chunk = list(rois_by_image.keys())
    images_to_preview = random.sample(image_paths_in_chunk, k=min(PREVIEW_IMAGES, len(image_paths_in_chunk)))
    
    print(f"Found {len(image_paths_in_chunk)} unique images in this chunk. Visualizing {len(images_to_preview)} of them.")
    for img_path_str in images_to_preview:
        img_path = Path(img_path_str)
        img_data = rois_by_image[img_path_str]
        
        # Stack the list of tensors into a single tensor for drawing
        boxes_tensor = torch.stack(img_data['boxes'])
        
        out_path = PREVIEW_DIR / (img_path.stem + "_preview.jpg")
        draw_boxes_pt(img_path, boxes_tensor, img_data['metadata'], out_path)

else:
    print("No chunks found to generate a preview.")

Loading chunk for preview: c:\Users\Mika\Desktop\New_Training_Run_Post_Bulgaria\aux_heads_chunks_pt\train\rois_chunk_0008.pt
Found 2160 unique images in this chunk. Visualizing 30 of them.
Preview saved: c:\Users\Mika\Desktop\New_Training_Run_Post_Bulgaria\aux_heads_chunks_pt\preview\00026064_preview.jpg
Preview saved: c:\Users\Mika\Desktop\New_Training_Run_Post_Bulgaria\aux_heads_chunks_pt\preview\00026612_preview.jpg
Preview saved: c:\Users\Mika\Desktop\New_Training_Run_Post_Bulgaria\aux_heads_chunks_pt\preview\00029486_preview.jpg
Preview saved: c:\Users\Mika\Desktop\New_Training_Run_Post_Bulgaria\aux_heads_chunks_pt\preview\00027563_preview.jpg
Preview saved: c:\Users\Mika\Desktop\New_Training_Run_Post_Bulgaria\aux_heads_chunks_pt\preview\00025745_preview.jpg
Preview saved: c:\Users\Mika\Desktop\New_Training_Run_Post_Bulgaria\aux_heads_chunks_pt\preview\00026067_preview.jpg
Preview saved: c:\Users\Mika\Desktop\New_Training_Run_Post_Bulgaria\aux_heads_chunks_pt\preview\00029687_prev

## Dataset Stats & Class Weights (from `.pt` Chunks)

This cell calculates summary statistics and class weights for the L2 head. It reads data from our `.pt` chunk format and uses an **adaptive beta** for class-balanced weighting.

### Engineering Decision:
- **Efficient Iteration**: The `iter_chunk_metadata` function loads each `.pt` file and yields only the `metadata`, which is memory-efficient.
- **Adaptive Class-Balanced Weights**: We calculate an adaptive `beta` based on the frequency of the most common class. This is a robust method from the paper "Class-Balanced Loss Based on Effective Number of Samples" (Cui et al., 2019).
- **Stable Normalization**: To prevent excessively large weights that could destabilize training, we normalize the final weights so that their **mean is 1.0**. This balances the importance of rare classes without causing volatile gradients, making the training process less sensitive to the learning rate.

In [None]:
from collections import defaultdict

def iter_chunk_metadata(paths: List[str]):
    """Generator to iterate through metadata from all .pt chunk files."""
    for cp in paths:
        try:
            data = torch.load(cp, map_location='cpu')
            for meta_item in data['metadata']:
                yield meta_item
        except Exception as e:
            print(f"Warning: Could not load or process chunk {cp}. Error: {e}")
            continue

# Calculate stats for the training split
counts_by_type_train = defaultdict(int)
counts_by_l2_train = defaultdict(int)
total_train = 0

for meta in iter_chunk_metadata(train_chunk_paths):
    counts_by_type_train[meta["type"]] += 1
    total_train += 1
    if meta.get("gt_cls_l2") is not None:
        counts_by_l2_train[meta["gt_cls_l2"]] += 1

print(f"[Train] Total ROIs: {total_train:,}")
print(f"[Train] Counts by type: {dict(counts_by_type_train)}")
l2_counts_named = {L2_NAMES[k]: v for k, v in sorted(counts_by_l2_train.items())}
print(f"[Train] L2 positive counts: {l2_counts_named}")

# Calculate stats for the validation split
total_val = 0
if val_chunk_paths:
    counts_by_type_val = defaultdict(int)
    counts_by_l2_val = defaultdict(int)
    for meta in iter_chunk_metadata(val_chunk_paths):
        counts_by_type_val[meta["type"]] += 1
        total_val += 1
        if meta.get("gt_cls_l2") is not None:
            counts_by_l2_val[meta["gt_cls_l2"]] += 1

    print(f"\n[Val] Total ROIs: {total_val:,}")
    print(f"[Val] Counts by type: {dict(counts_by_type_val)}")
    l2_counts_named_val = {L2_NAMES[k]: v for k, v in sorted(counts_by_l2_val.items())}
    print(f"[Val] L2 positive counts: {l2_counts_named_val}")
else:
    print("\n[Val] No validation chunks found.")

# Calculate Class-Balanced Weights for L2 with Adaptive Beta
def effective_number_weights(counts: Dict[int, int], beta: float):
    """Calculates class-balanced weights and normalizes them to have a mean of 1."""
    weights = {}
    num_classes = len(L2_NAMES)
    if not counts: return {k: 1.0 for k in range(num_classes)}
    
    # 1. Calculate raw weights based on effective number of samples
    for k in range(num_classes):
        n_k = counts.get(k, 1)
        weights[k] = (1.0 - beta) / (1.0 - (beta ** n_k))
    
    # 2. Normalize weights so their mean is 1.0 for training stability
    sum_of_weights = sum(weights.values())
    mean_weight = sum_of_weights / num_classes
    
    final_weights = {k: v / mean_weight for k, v in weights.items()}
    return final_weights

# Derive adaptive beta from train positives
if counts_by_l2_train:
    n_max = max(1, max(counts_by_l2_train.values()))
    beta_used = 1.0 - 1.0 / n_max
    beta_used = float(np.clip(beta_used, 0.9, 0.9999999)) # Clamp for stability
else:
    beta_used = 0.999 # Fallback if no positive samples

print(f"\nAdaptive beta calculated based on max class frequency: {beta_used:.7f}")

l2_weights = effective_number_weights(counts_by_l2_train, beta=beta_used)
print("L2 class weights (normalized to mean=1):", {L2_NAMES[k]: round(v, 4) for k, v in l2_weights.items()})

# Save Stats and Weights
with open(STATS_TRAIN_PATH, "w") as f:
    json.dump({'total_roi': total_train, 'counts_by_type': dict(counts_by_type_train), 'l2_positive_counts': dict(counts_by_l2_train)}, f, indent=2)
if val_chunk_paths:
    with open(STATS_VAL_PATH, "w") as f:
        json.dump({'total_roi': total_val, 'counts_by_type': dict(counts_by_type_val), 'l2_positive_counts': dict(counts_by_l2_val)}, f, indent=2)

with open(CLASS_WEIGHTS_PATH, "w") as f:
    json.dump({
        'l2_weights': l2_weights,
        'l2_beta_used': beta_used,
        'l2_names': L2_NAMES,
        'l1_names': L1_NAMES,
        'l3_names': L3_NAMES
    }, f, indent=2)

print("\nStats saved to:", STATS_TRAIN_PATH)
if val_chunk_paths: print("Stats saved to:", STATS_VAL_PATH)
print("Class weights saved to:", CLASS_WEIGHTS_PATH)

[Train] Total ROIs: 678,015
[Train] Counts by type: {'gt_pos': 188542, 'pred_tp': 181304, 'pred_fp': 13800, 'jitter_neg': 41554, 'jitter_pos': 77040, 'bg_neg': 175775}
[Train] L2 positive counts: {'heavy_vehicle': 32097, 'car_group': 408075, 'two_wheeled_vehicle': 6714}

[Val] Total ROIs: 157,400
[Val] Counts by type: {'gt_pos': 48226, 'pred_tp': 44110, 'jitter_pos': 19612, 'jitter_neg': 10647, 'bg_neg': 31496, 'pred_fp': 3309}
[Val] L2 positive counts: {'heavy_vehicle': 13147, 'car_group': 97594, 'two_wheeled_vehicle': 1207}

Adaptive beta calculated based on max class frequency: 0.9999975
L2 class weights (normalized to mean=1): {'heavy_vehicle': 0.5213, 'car_group': 0.0624, 'two_wheeled_vehicle': 2.4163}

Stats saved to: c:\Users\Mika\Desktop\New_Training_Run_Post_Bulgaria\aux_heads_chunks_pt\stats_train.json
Stats saved to: c:\Users\Mika\Desktop\New_Training_Run_Post_Bulgaria\aux_heads_chunks_pt\stats_val.json
Class weights saved to: c:\Users\Mika\Desktop\New_Training_Run_Post_Bulg

## Write Manifest File

This final step creates the `manifest.json` file. This file is the bridge between our data preparation and the training script. It contains all the necessary information for the trainer to understand the dataset, including paths to the `.pt` chunks and the configuration parameters used to generate the features.

### Engineering Decision:
- **Self-Contained Configuration**: By including the `feature_extraction` parameters (`COMPRESSED_DIM`, etc.) in the manifest, we make the training process more robust. The training script can read these values and dynamically build a compatible model, preventing errors from mismatched feature dimensions.
- **Portability**: All paths are saved as relative paths from the project root. This ensures that the entire project folder can be moved to a different machine without breaking the file links.

In [12]:
manifest = {
    "chunks": {
        "train": [str(Path(p).relative_to(ROOT)) for p in train_chunk_paths],
        "val": [str(Path(p).relative_to(ROOT)) for p in val_chunk_paths]
    },
    "hierarchy": {
        "L1_NAMES": L1_NAMES,
        "L2_NAMES": L2_NAMES,
        "L3_NAMES": L3_NAMES,
        "CLASS_TO_L2": CLASS_TO_L2
    },
    "feature_extraction": { # <-- NEW SECTION
        "roi_align_size": ROI_ALIGN_SIZE,
        "compressed_dim": COMPRESSED_DIM,
        "pyramid_thresholds": PYRAMID_THRESHOLDS
    },
    "thresholds": {
        "IOU_POS": IOU_POS,
        "IOU_NEG": IOU_NEG,
        "IOU_BG":  IOU_BG,
    },
    "sampling": {
        "JITTER_POS_PER_GT": JITTER_POS_PER_GT,
        "JITTER_NEG_PER_GT": JITTER_NEG_PER_GT,
        "BG_NEG_PER_IMG": BG_NEG_PER_IMG
    },
    "paths": {
        "data_yaml": str(DATA_YAML.relative_to(ROOT)),
        "l3_checkpoint": str(L3_CHECKPOINT.relative_to(ROOT)),
        "out_dir": str(OUT_DIR.relative_to(ROOT)),
        "stats_train": str(STATS_TRAIN_PATH.relative_to(ROOT)),
        "stats_val": str(STATS_VAL_PATH.relative_to(ROOT)) if val_chunk_paths else None,
        "class_weights": str(CLASS_WEIGHTS_PATH.relative_to(ROOT)),
    },
    "seed": SEED
}

with open(MANIFEST_PATH, "w") as f:
    json.dump(manifest, f, indent=2)

print("Manifest written to:", MANIFEST_PATH)
print("\n✅ Data preparation notebook is complete!")
print("You are now ready to run the training script for the auxiliary heads.")

Manifest written to: c:\Users\Mika\Desktop\New_Training_Run_Post_Bulgaria\aux_heads_chunks_pt\manifest.json

✅ Data preparation notebook is complete!
You are now ready to run the training script for the auxiliary heads.
