# 🦖 DaZZLeD: Recursive Hasher Training Notebook

**Goal:** Train a Tiny Recursive Model (TRM) using DINOv3 distillation for adversarially-robust perceptual hashing.

## How to Use This Notebook

⚠️ **FIRST:** Runtime → Change runtime type → **GPU (T4 free or A100)**

1. **Run Cell 1** - Mount Google Drive *(required after every runtime change!)*
2. **Run Cell 2** - Download training datasets (first run ~30 min, cached runs ~2 min)
3. **Run Cell 3** - Clone repo & install dependencies
4. **Run Cell 4** - Build manifest & train the model
5. **Run remaining cells** - Test and export the trained model

> 💡 **Tip:** If you change runtime mid-session, re-run Cell 1 (Mount Drive) first. The data zip on Drive persists - it just needs to be extracted to local storage again.

## 0. Mount Google Drive

In [1]:
# Mount Google Drive (required for data storage)
from google.colab import drive
drive.mount('/content/drive')

# Create project directories
from pathlib import Path

DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
DATA_ROOT = DRIVE_ROOT / "data"
OUTPUT_ROOT = DRIVE_ROOT / "outputs"

# Create all needed directories
for d in [DATA_ROOT / "ffhq", DATA_ROOT / "openimages", DATA_ROOT / "text",
          OUTPUT_ROOT / "checkpoints", OUTPUT_ROOT / "models",
          DRIVE_ROOT / "manifests"]:
    d.mkdir(parents=True, exist_ok=True)

print(f"✓ Project root: {DRIVE_ROOT}")
print(f"✓ Data root: {DATA_ROOT}")
print(f"✓ Output root: {OUTPUT_ROOT}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✓ Project root: /content/drive/MyDrive/dazzled
✓ Data root: /content/drive/MyDrive/dazzled/data
✓ Output root: /content/drive/MyDrive/dazzled/outputs


## 1.1 Build Manifest & Train TRM Model

⚠️ **Switch to GPU first:** Runtime → Change runtime type → GPU (T4 or A100)

This cell:
1. Builds a manifest of all training images
2. Trains the **TRM Hasher** using proper TRM deep supervision (from [arXiv:2510.04871](https://arxiv.org/abs/2510.04871))
3. Distills DINOv3 teacher into tiny 2-layer recursive network
4. Saves checkpoints to Google Drive

**Key TRM features:**
- Deep supervision: N_sup=16 steps per sample, loss at each step
- Two features: y (hash) and z (latent reasoning)
- Latent recursion: z = net(x, y, z) × 6, then y = net(y, z)
- EMA for stability (decay=0.999)
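
The control flow behind the feature list above can be sketched in plain Python. This is a toy under stated assumptions, not the code in `models/trm_hasher.py`: scalar floats stand in for tensors, `update_z` and `update_y` are hypothetical callables standing in for the single `net` (two are passed only so the toy functions can take different arguments), and the outer `--t` loop that the notebook's inference cell runs per supervision step is left out for brevity.

```python
def latent_recursion(update_z, update_y, x, y, z, n_latent=6):
    """Refine the latent z n_latent times, then update the hash y once."""
    for _ in range(n_latent):
        z = update_z(x, y, z)
    y = update_y(y, z)
    return y, z


def deep_supervision(update_z, update_y, x, y, z, n_sup=16, n_latent=6):
    """Run n_sup supervision steps, carrying (y, z) forward.

    Returns y after every step so a loss could be applied at each one.
    """
    outputs = []
    for _ in range(n_sup):
        y, z = latent_recursion(update_z, update_y, x, y, z, n_latent)
        outputs.append(y)
    return outputs


def ema_update(ema_value, current_value, decay=0.999):
    """Exponential moving average of a weight (decay as in the list above)."""
    return decay * ema_value + (1.0 - decay) * current_value


# Toy update rules (hypothetical) that also count how often they are called.
calls = {"z": 0, "y": 0}


def toy_update_z(x, y, z):
    calls["z"] += 1
    return 0.5 * z + 0.25 * (x + y)


def toy_update_y(y, z):
    calls["y"] += 1
    return 0.5 * (y + z)


step_outputs = deep_supervision(toy_update_z, toy_update_y, x=1.0, y=0.0, z=0.0)
print(f"{len(step_outputs)} hashes, {calls['z']} z-updates, {calls['y']} y-updates")
```

With the defaults, each sample yields 16 intermediate hashes from 96 latent updates and 16 hash updates, which is why deep supervision costs noticeably more per sample than a single forward pass.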

In [None]:
# =============================================================================
# 📦 DOWNLOAD & PREPARE DATASETS
# =============================================================================
# This cell downloads and prepares the training data.
# Total: ~45k images from 3 sources:
#   - FFHQ: Kaggle (face images for identity)
#   - OpenImages: FiftyOne (diverse real-world objects)
#   - MobileViews: HuggingFace parquet
#
# WHY THIS MIX?
#   - FFHQ 40k: Faces require precise hashing (identity preservation)
#   - OpenImages 2.5k: Broad category coverage (animals, vehicles, food)
#   - MobileViews 2k: Edge cases for text/UI (600k available, 2k is enough)
# =============================================================================

import subprocess
import shutil
from pathlib import Path

# Config
DATA_ROOT = Path("/content/data")
DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
DRIVE_ARCHIVE = DRIVE_ROOT / "data-cache/training-images.zip"

# =============================================================================
# ⚠️ CHECK: Is Google Drive mounted?
# =============================================================================
if not Path("/content/drive/MyDrive").exists():
    raise RuntimeError(
        "❌ Google Drive is NOT mounted!\n\n"
        "This happens after runtime changes (CPU→GPU, session timeout, etc.)\n\n"
        "FIX: Run the 'Mount Google Drive' cell above first (Cell 1),\n"
        "     then re-run this cell. Your cached data zip is safe on Drive."
    )

EXPECTED_COUNTS = {
    "ffhq": 40000,          # Full dataset
    "openimages": 2500,     # Subset
    "mobileviews": 1500,    # Validation floor; the download step targets 2k
}

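def validate_dataset(data_root: Path, expected: dict) -> tuple[bool, dict]:
    """Validate that each dataset has at least 95% of its expected image count."""
    # The download steps below accept both .jpg and .png, so count both here;
    # counting only *.jpg would mark a PNG-only dataset as missing.
    image_exts = {".jpg", ".jpeg", ".png"}
    results = {}
    for name, exp_count in expected.items():
        path = data_root / name
        if path.exists():
            actual = sum(1 for p in path.iterdir() if p.suffix.lower() in image_exts)
        else:
            actual = 0
        results[name] = {"count": actual, "expected": exp_count, "valid": actual >= exp_count * 0.95}
    return all(r["valid"] for r in results.values()), results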

# Check what we already have
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Initialize download flags
need_download = {"ffhq": False, "openimages": False, "mobileviews": False}
skip_downloads = False

# Option 1: Restore from Drive cache (fastest)
if DRIVE_ARCHIVE.exists():
    print("🔄 Found cached dataset on Drive!")
    print(f"   📁 {DRIVE_ARCHIVE}")
    shutil.unpack_archive(DRIVE_ARCHIVE, DATA_ROOT)
    all_valid, validation = validate_dataset(DATA_ROOT, EXPECTED_COUNTS)
    if all_valid:
        print("\n✅ All datasets restored from cache successfully!")
        for name, info in validation.items():
            print(f"   ✓ {name}: {info['count']:,} images")
        print("\n🎉 Skipping downloads - data is ready!")
        skip_downloads = True
    else:
        print("\n⚠️ Cache incomplete, will download missing data:")
        for name, info in validation.items():
            if not info["valid"]:
                need_download[name] = True
                print(f"   ✗ {name}: {info['count']:,}/{info['expected']:,} (need more)")
            else:
                print(f"   ✓ {name}: {info['count']:,} images (OK)")
else:
    print("📥 No cache found on Drive, downloading all datasets...")
    need_download = {"ffhq": True, "openimages": True, "mobileviews": True}

# Option 2: Download fresh data (only if cache was missing or incomplete)
if not skip_downloads and any(need_download.values()):
    print("\n" + "="*65)
    print("DOWNLOADING DATASETS")
    print("="*65)

    import torchvision.transforms as transforms
    from PIL import Image
    import io

    # -------------------------------------------------------------------------
    # 1. FFHQ via Kaggle
    # -------------------------------------------------------------------------
    if need_download["ffhq"]:
        ffhq_dir = DATA_ROOT / "ffhq"
        ffhq_dir.mkdir(parents=True, exist_ok=True)

        print("\n📥 [1/3] FFHQ via Kaggle")
        print("   Target: 40k high-quality face images")

        # Setup Kaggle credentials from Colab secrets
        try:
            from google.colab import userdata
            import os
            os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
            os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
            print("   ✓ Kaggle credentials loaded from Colab secrets")
        except Exception as e:
            print(f"   ⚠️ Could not load Kaggle secrets: {e}")
            print("   Add KAGGLE_USERNAME and KAGGLE_KEY to Colab secrets")

        # Download FFHQ from Kaggle
        subprocess.run([
            "kaggle", "datasets", "download", "-d", "arnaud58/flickrfaceshq-dataset-ffhq",
            "-p", str(ffhq_dir), "--unzip"
        ], check=True)

        # Flatten directory structure (Kaggle downloads into nested folders)
        for nested_dir in ffhq_dir.rglob("*"):
            if nested_dir.is_file() and nested_dir.suffix.lower() in {".jpg", ".png"}:
                target = ffhq_dir / nested_dir.name
                if not target.exists():
                    shutil.move(str(nested_dir), str(target))

        # Clean empty directories
        for d in ffhq_dir.iterdir():
            if d.is_dir():
                shutil.rmtree(d)

        count = len(list(ffhq_dir.glob("*.jpg"))) + len(list(ffhq_dir.glob("*.png")))
        print(f"   ✓ FFHQ: {count:,} images")

    # -------------------------------------------------------------------------
    # 2. OpenImages via FiftyOne
    # -------------------------------------------------------------------------
    if need_download["openimages"]:
        oi_dir = DATA_ROOT / "openimages"
        oi_dir.mkdir(parents=True, exist_ok=True)

        print("\n📥 [2/3] OpenImages via FiftyOne")
        print("   Target: 2.5k diverse real-world images")

        try:
            import fiftyone as fo
            import fiftyone.zoo as foz

            # Download a subset of OpenImages validation split
            dataset = foz.load_zoo_dataset(
                "open-images-v7",
                split="validation",
                max_samples=2500,
                shuffle=True,
                seed=42,
            )

            # Copy images to our directory
            for sample in dataset:
                src = Path(sample.filepath)
                dst = oi_dir / src.name
                if src.exists() and not dst.exists():
                    shutil.copy2(src, dst)

            count = len(list(oi_dir.glob("*")))
            print(f"   ✓ OpenImages: {count:,} images")

            # Cleanup FiftyOne dataset
            fo.delete_dataset(dataset.name)

        except ImportError:
            print("   ⚠️ FiftyOne not installed. Installing...")
            subprocess.run(["pip", "install", "fiftyone"], check=True)
            print("   Please re-run this cell after installation")

    # -------------------------------------------------------------------------
    # 3. MobileViews via HuggingFace
    # -------------------------------------------------------------------------
    if need_download["mobileviews"]:
        mv_dir = DATA_ROOT / "mobileviews"
        mv_dir.mkdir(parents=True, exist_ok=True)

        print("\n📥 [3/3] MobileViews via HuggingFace")
        print("   Target: 2k mobile UI screenshots")

        try:
            from datasets import load_dataset

            # Load subset of MobileViews dataset
            ds = load_dataset(
                "mllmTeam/MobileViews",
                split="train",
                streaming=True,
            )

            # Take first 2000 samples
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.Lambda(lambda x: x.convert("RGB")),
            ])

            count = 0
            target = 2000
            for i, sample in enumerate(ds):
                if count >= target:
                    break
                try:
                    img = sample.get("image")
                    if img is not None:
                        img = transform(img)
                        img.save(mv_dir / f"mv_{count:05d}.jpg", "JPEG", quality=85)
                        count += 1
                except Exception:
                    continue

            print(f"   ✓ MobileViews: {count:,} images")

        except ImportError:
            print("   ⚠️ datasets not installed. Installing...")
            subprocess.run(["pip", "install", "datasets"], check=True)
            print("   Please re-run this cell after installation")

    # =========================================================================
    # 💾 SAVE TO DRIVE CACHE (for future sessions)
    # =========================================================================
    print("\n" + "="*65)
    print("SAVING CACHE TO DRIVE")
    print("="*65)

    DRIVE_ARCHIVE.parent.mkdir(parents=True, exist_ok=True)

    # Create archive
    print(f"📦 Creating archive: {DRIVE_ARCHIVE}")
    shutil.make_archive(
        str(DRIVE_ARCHIVE.with_suffix("")),  # Remove .zip for make_archive
        "zip",
        DATA_ROOT,
    )

    archive_size = DRIVE_ARCHIVE.stat().st_size / (1024 ** 3)
    print(f"✓ Cache saved ({archive_size:.2f} GB)")
    print("  Next session will restore from this cache instead of re-downloading.")

# =========================================================================
# 📊 FINAL VALIDATION
# =========================================================================
print("\n" + "="*65)
print("FINAL DATASET VALIDATION")
print("="*65)

all_valid, validation = validate_dataset(DATA_ROOT, EXPECTED_COUNTS)
total_images = sum(info["count"] for info in validation.values())

for name, info in validation.items():
    status = "✓" if info["valid"] else "✗"
    print(f"{status} {name}: {info['count']:,} / {info['expected']:,} images")

print(f"\n📊 Total: {total_images:,} images")

if all_valid:
    print("✅ All datasets ready for training!")
else:
    print("⚠️ Some datasets are incomplete. Re-run this cell to download more.")


🚀 DOWNLOADING DATASETS

✓ [1/3] FFHQ already exists (52,001 images). Skipping download.

📥 [2/3] OpenImages via FiftyOne
   Method: Official Google download (handles AWS S3 shards)
   Target: 15k diverse real-world images
   Downloading from AWS S3...
Downloading split 'validation' to '/root/fiftyone/open-images-v7/validation' if necessary
Downloading 15000 images
 100% |███████████████| 15000/15000 [4.9m elapsed, 0s remaining, 50.6 files/s]
Dataset info written to '/root/fiftyone/open-images-v7/info.json'
Loading 'open-images-v7' split 'validation'
  94% |████████████-| 14173/15000 [11.4m elapsed, 40.2s remaining, 20.6 samples/s]

## 1. Setup & Installation

In [None]:
# Clone the repo (only needed in Colab)
import os
if not os.path.exists('DaZZLeD'):
    !git clone https://github.com/D13ya/DaZZLeD.git
    %cd DaZZLeD/ml-core
else:
    %cd DaZZLeD/ml-core

!pip install -q -r requirements.txt

In [None]:
import sys
import torch
import torch.nn.functional as F
from pathlib import Path

# Add project root to path
sys.path.insert(0, str(Path.cwd()))

from models.trm_hasher import TRMHasher

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1.1 Build Manifest & Train Model

⚠️ **Switch to GPU first:** Runtime → Change runtime type → GPU (T4 or A100)

This cell:
1. Builds a manifest of all training images
2. Trains the RecursiveHasher using DINOv3 distillation
3. Saves checkpoints to Google Drive
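
Because the manifest is stored on Drive but lists files under `/content/data`, it can outlive the images it points to after a runtime reset. A small pre-training check along the following lines may help catch that. It is a sketch: `check_manifest` is a hypothetical helper, not part of the DaZZLeD repo, and it assumes the one-path-per-line format written by the manifest cell.

```python
import tempfile
from pathlib import Path


def check_manifest(manifest_path):
    """Return (existing, missing) path lists for a one-path-per-line manifest."""
    lines = Path(manifest_path).read_text().splitlines()
    entries = [ln.strip() for ln in lines if ln.strip()]
    existing = [p for p in entries if Path(p).is_file()]
    missing = [p for p in entries if not Path(p).is_file()]
    return existing, missing


# Demo on a throwaway directory: one real file, one stale entry.
with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    real_image = tmp / "a.jpg"
    real_image.write_bytes(b"")
    manifest = tmp / "train.txt"
    manifest.write_text("\n".join([str(real_image), str(tmp / "gone.jpg")]))

    existing, missing = check_manifest(manifest)
    print(f"{len(existing)} present, {len(missing)} missing")
```

In the notebook, the same call could be pointed at `DRIVE_ROOT / "manifests/train.txt"`; a non-empty `missing` list suggests the dataset cell needs to be re-run before training.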

In [None]:
# =============================================================================
# BUILD MANIFEST FROM LOCAL DATA
# =============================================================================

from pathlib import Path

# Point to local fast storage
DATA_ROOT = Path("/content/data")
DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")

# Find all training images
exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
paths = [str(p) for p in DATA_ROOT.rglob("*") if p.suffix.lower() in exts]

print(f"Found {len(paths):,} training images in {DATA_ROOT}")

if len(paths) < 100:
    print("⚠️  Not enough images! Run the download cell first.")
else:
    # Write manifest to Drive so it persists, but content points to /content/data
    manifest_path = DRIVE_ROOT / "manifests/train.txt"
    manifest_path.parent.mkdir(parents=True, exist_ok=True)

    with open(manifest_path, "w") as f:
        f.write("\n".join(paths))

    print(f"✓ Manifest written: {manifest_path}")
    print(f"  (Points to {len(paths)} local files for high-speed training)")

In [None]:
# =============================================================================
# TRAIN THE MODEL WITH TRM ARCHITECTURE (Deep Supervision)
# =============================================================================
# Using the proper TRM algorithm from https://arxiv.org/abs/2510.04871
# Key differences from standard distillation:
#   1. Deep supervision: N_sup=16 steps, loss at EACH step
#   2. (y, z) features: y=hash, z=latent reasoning state
#   3. Carry (y, z) across supervision steps (detached)
#   4. EMA for stability (decay=0.999)

# 🔑 AUTHENTICATION (Required for DINOv3)
from huggingface_hub import login
from google.colab import userdata

try:
    print("🔐 Logging in to Hugging Face...")
    hf_token = userdata.get('HF_TOKEN')
    if hf_token:
        login(hf_token)
        print("✓ Logged in!")
    else:
        print("⚠️ HF_TOKEN not found in Secrets! You may encounter 401 Unauthorized errors.")
except Exception as e:
    print(f"⚠️ Login failed: {e}")

# 🔧 TRM TRAINING OPTIONS - Uncomment ONE of the following:

# -----------------------------------------------------------------------------
# OPTION A: Quick Test (5-10 min on T4) - Just to verify everything works
# -----------------------------------------------------------------------------
# !python training/train_trm.py \
#     --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
#     --teacher facebook/dinov3-vitl16-pretrain-lvd1689m \
#     --epochs 1 \
#     --batch-size 64 \
#     --n-sup 8 \
#     --max-steps 100 \
#     --amp \
#     --log-interval 10 \
#     --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/checkpoints

# -----------------------------------------------------------------------------
# OPTION B: A100 80GB - TRM + SimCLR (Option A)
# -----------------------------------------------------------------------------
# Loss function (mathematically clean):
#   L = (1/N_sup) * Σ_{k=1}^{N_sup} L_align(y_k, teacher)
#       + λ * L_SimCLR(y_N^1, y_N^2) / log(2B)
#
# Key properties:
# - Distillation at EVERY step k (core TRM deep supervision)
# - SimCLR ONLY at final step N (on most refined hash)
# - Contrast loss normalized by log(2B) for stable λ tuning
# - λ warmup: first epoch is pure TRM, then SimCLR kicks in
#
# Memory: batch 192 (two views) + DINOv3 + TRM ≈ 60GB VRAM
#
# TRM architecture (from arXiv:2510.04871):
#   --embed-dim 256, --hash-dim 128, --n-layers 2, --n-latent 6, --t 3, --n-sup 16
#
# SimCLR (true two-view NT-Xent, applied at final step only):
#   --contrast-weight 0.3: λ (normalized, 0.1-0.5 range)
#   --nce-temperature 0.07: τ (lower = sharper)
#   --contrast-warmup-epochs 1: first epoch is pure distillation
#
# If OOM: reduce --batch-size to 128
# -----------------------------------------------------------------------------
!python training/train_trm.py \
    --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
    --teacher facebook/dinov3-vitl16-pretrain-lvd1689m \
    --epochs 5 \
    --batch-size 192 \
    --embed-dim 256 \
    --hash-dim 128 \
    --n-layers 2 \
    --n-latent 6 \
    --t 3 \
    --n-sup 16 \
    --lr 1e-4 \
    --warmup-steps 2000 \
    --contrast-weight 0.3 \
    --nce-temperature 0.07 \
    --contrast-warmup-epochs 1 \
    --use-ema \
    --amp \
    --allow-tf32 \
    --channels-last \
    --cudnn-benchmark \
    --cache-ram \
    --workers 2 \
    --prefetch-factor 2 \
    --pin-memory \
    --log-interval 25 \
    --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/checkpoints \
    --checkpoint-every 500

# -----------------------------------------------------------------------------
# OPTION C: Resume from Checkpoint (if session disconnected)
# -----------------------------------------------------------------------------
# !python training/train_trm.py \
#     --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
#     --teacher facebook/dinov3-vitl16-pretrain-lvd1689m \
#     --resume /content/drive/MyDrive/dazzled/outputs/checkpoints/student_epoch_1.safetensors \
#     --epochs 5 \
#     --batch-size 192 \
#     --embed-dim 256 \
#     --hash-dim 128 \
#     --n-layers 2 \
#     --n-latent 6 \
#     --t 3 \
#     --n-sup 16 \
#     --contrast-weight 0.3 \
#     --nce-temperature 0.07 \
#     --contrast-warmup-epochs 0 \
#     --use-ema \
#     --amp \
#     --allow-tf32 \
#     --channels-last \
#     --cudnn-benchmark \
#     --cache-ram \
#     --workers 2 \
#     --prefetch-factor 2 \
#     --pin-memory \
#     --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/checkpoints \
#     --checkpoint-every 500
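
For reference, the loss described in the Option B comments can be typeset as below. This is a transcription of those comments rather than something read from `training/train_trm.py`. It assumes $B$ is the batch size (the comments do not define it), $y_k$ is the hash at supervision step $k$, and $y_N^{(1)}, y_N^{(2)}$ are the final-step hashes of the two augmented views.

```latex
\mathcal{L}
  = \frac{1}{N_{\text{sup}}} \sum_{k=1}^{N_{\text{sup}}}
      \mathcal{L}_{\text{align}}\bigl(y_k, \text{teacher}\bigr)
  + \lambda \,
      \frac{\mathcal{L}_{\text{SimCLR}}\bigl(y_N^{(1)}, y_N^{(2)}\bigr)}{\log(2B)}
```

With `--contrast-warmup-epochs 1`, the comments indicate that the second term is inactive during the first epoch, so training starts as pure distillation before the contrastive term is added.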

In [None]:
# =============================================================================
# LIST CHECKPOINTS & LOAD TRAINED TRM WEIGHTS
# =============================================================================
# This cell is SELF-CONTAINED - run after runtime restart to reload model

from pathlib import Path
import torch
import torch.nn.functional as F
import sys
import numpy as np

# Add project path
sys.path.insert(0, "/content/DaZZLeD/ml-core")

# TRM Architecture parameters (MUST match training)
EMBED_DIM = 256
HASH_DIM = 128  # Matches --hash-dim in the training cell
N_LAYERS = 2
N_LATENT = 6
T = 3  # Note: TRMHasher uses lowercase 't' parameter
N_SUP = 16  # Supervision steps for inference
IMAGE_SIZE = 224

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Checkpoint directory on Drive
CKPT_DIR = Path("/content/drive/MyDrive/dazzled/outputs/checkpoints")
ONNX_PATH = Path("/content/drive/MyDrive/dazzled/outputs/models/trm_hasher.onnx")

# Check if Drive is mounted
if not Path("/content/drive/MyDrive").exists():
    raise RuntimeError(
        "❌ Google Drive is NOT mounted!\n\n"
        "FIX: Run the 'Mount Google Drive' cell first (Cell 1),\n"
        "     then re-run this cell."
    )

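# List available checkpoints, oldest first (by modification time).
# A plain name sort is lexicographic, so step 1000 would sort before step 500
# and [-1] would not necessarily be the most recent file.
checkpoints = sorted(CKPT_DIR.glob("*.safetensors"), key=lambda p: p.stat().st_mtime)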
print(f"Found {len(checkpoints)} checkpoints in {CKPT_DIR}:")

# Separate epoch and step checkpoints
epoch_ckpts = [c for c in checkpoints if "epoch" in c.name]
step_ckpts = [c for c in checkpoints if "step" in c.name]

if epoch_ckpts:
    print("\n📁 Epoch checkpoints:")
    for ckpt in epoch_ckpts:
        size_mb = ckpt.stat().st_size / (1024 * 1024)
        print(f"  {ckpt.name} ({size_mb:.2f} MB)")

if step_ckpts:
    print(f"\n📁 Step checkpoints: {len(step_ckpts)} files")
    # Show first and last few
    for ckpt in step_ckpts[:2] + step_ckpts[-2:]:
        size_mb = ckpt.stat().st_size / (1024 * 1024)
        print(f"  {ckpt.name} ({size_mb:.2f} MB)")

# Load the latest checkpoint
if epoch_ckpts:
    latest = epoch_ckpts[-1]
elif step_ckpts:
    latest = step_ckpts[-1]
elif checkpoints:
    latest = checkpoints[-1]
else:
    raise FileNotFoundError(f"⚠️ No checkpoints found in {CKPT_DIR}!")

print(f"\n🔄 Loading: {latest.name}")

# Import and create TRM model
from models.trm_hasher import TRMHasher
import safetensors.torch

student = TRMHasher(
    embed_dim=EMBED_DIM,
    hash_dim=HASH_DIM,
    n_layers=N_LAYERS,
    n_latent=N_LATENT,
    t=T  # lowercase 't' to match TRMHasher signature
).to(device)

safetensors.torch.load_model(student, str(latest))
student.eval()

# Also set 'model' as alias for compatibility with other cells
model = student

# Count parameters
total_params = sum(p.numel() for p in student.parameters())
print(f"\n✓ Loaded TRM model from {latest.name}")
print(f"  Total parameters: {total_params:,}")
print(f"  Model size: ~{total_params * 4 / 1024 / 1024:.2f} MB (float32)")
print(f"\n✅ Ready for validation milestones!")

## 2. Model Architecture Test

In [None]:
# =============================================================================
# LOAD TRM HASHER MODEL
# =============================================================================

import torch
import torch.nn.functional as F
import sys
sys.path.insert(0, "/content/DaZZLeD/ml-core")
from models.trm_hasher import TRMHasher

# TRM Architecture parameters (must match training)
EMBED_DIM = 256
HASH_DIM = 128  # Matches --hash-dim in the training cell
N_LAYERS = 2
N_LATENT = 6
T = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Only create a new model if one doesn't already exist
if 'model' not in dir() or model is None:
    print("⚠️  Creating UNTRAINED model for architecture testing")
    model = TRMHasher(
        embed_dim=EMBED_DIM,
        hash_dim=HASH_DIM,
        n_layers=N_LAYERS,
        n_latent=N_LATENT,
        t=T  # lowercase 't', as in the checkpoint-loading cell
    ).to(device)
    model.eval()
else:
    print("✓ Using existing model (trained weights loaded)")

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: ~{total_params * 4 / 1024 / 1024:.2f} MB (float32)")

In [None]:
# Test forward pass with dummy input (TRM deep recursion)
batch_size = 4
image_size = 224

dummy_img = torch.randn(batch_size, 3, image_size, image_size).to(device)

with torch.no_grad():
    # Use inference mode with N_sup supervision steps
    hash_out = model.inference(dummy_img, n_sup=16)

print(f"Input image shape: {dummy_img.shape}")
print(f"Output hash shape: {hash_out.shape}")
print(f"Hash L2 norm (should be ~1.0): {torch.norm(hash_out, dim=1).mean():.4f}")

## 3. Recursive Inference Test

The key innovation is running the model recursively 16 times, refining the hash at each step.

In [None]:
def recursive_inference_trm(model, image, n_sup=16):
    """
    Run TRM recursive inference for the specified number of supervision steps.
    
    This implements the deep_recursion from the TRM paper:
    - For each supervision step, run latent_recursion
    - Return hash at each step to observe convergence
    """
    batch_size = image.size(0)
    device = image.device
    
    # Initialize y and z
    y = model.y_init.expand(batch_size, -1).to(device)
    z = model.z_init.expand(batch_size, -1).to(device)
    
    # Encode image once
    with torch.no_grad():
        x = model.image_encoder(image)
    
    hashes = []
    with torch.no_grad():
        for step in range(n_sup):
            # Deep recursion: T-1 without grad, 1 with grad (all no-grad in inference)
            for _ in range(model.t):  # lowercase 't' to match TRMHasher attribute
                y, z = model.latent_recursion(x, y, z)
            
            # Compute hash at this step
            hash_out = F.normalize(model.output_head(y), p=2, dim=-1)
            hashes.append(hash_out.clone())
            
            # Carry forward (already detached since no_grad)
    
    return hashes

# Run TRM recursive inference
hashes = recursive_inference_trm(model, dummy_img.to(device), n_sup=16)

print(f"Generated {len(hashes)} hash vectors (one per supervision step)")
print(f"Final hash shape: {hashes[-1].shape}")

In [None]:
# Analyze TRM hash convergence across supervision steps
import matplotlib.pyplot as plt
import numpy as np

# Compute cosine similarity between consecutive steps
similarities = []
for i in range(1, len(hashes)):
    sim = F.cosine_similarity(hashes[i], hashes[i-1], dim=1).mean().item()
    similarities.append(sim)

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(range(1, len(hashes)), similarities, 'b-o')
plt.xlabel('Supervision Step')
plt.ylabel('Cosine Similarity with Previous')
plt.title('TRM Hash Convergence')
plt.ylim([0.9, 1.0])
plt.grid(True)

# Plot hash norm evolution
norms = [torch.norm(h, dim=1).mean().item() for h in hashes]
plt.subplot(1, 2, 2)
plt.plot(range(len(hashes)), norms, 'g-o')
plt.xlabel('Supervision Step')
plt.ylabel('L2 Norm')
plt.title('Hash Norm Stability')
plt.grid(True)

plt.tight_layout()
plt.show()

# Convergence check: similarity should increase and stabilize
if len(similarities) > 5:
    early_sim = np.mean(similarities[:3])
    late_sim = np.mean(similarities[-3:])
    print(f"\nConvergence Analysis:")
    print(f"  Early steps avg similarity: {early_sim:.4f}")
    print(f"  Late steps avg similarity:  {late_sim:.4f}")
    print(f"  Improvement: {(late_sim - early_sim):.4f}")
    if late_sim > 0.995:
        print("  ✓ Model has converged (high stability)")
    else:
        print("  ⚠️ Model may need more training or supervision steps")

## 4. Adversarial Robustness Test

Test whether small input perturbations cause large changes in the hash (they should not after recursive refinement).
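
The measurement itself is simple and can be illustrated without the trained model. The sketch below is a toy under stated assumptions: `make_toy_hasher` is a hypothetical fixed random projection standing in for `model.inference`, and the sign-based bit comparison is one possible way to binarize a real-valued hash, not something this notebook specifies.

```python
import math
import random


def l2_normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def hamming_distance(a, b):
    """Number of coordinates whose sign differs (assumed binarization)."""
    return sum((x >= 0) != (y >= 0) for x, y in zip(a, b))


def make_toy_hasher(n_pixels, hash_dim, seed=0):
    """Hypothetical stand-in for the model: a fixed random linear projection."""
    rng = random.Random(seed)
    weights = [[rng.gauss(0, 1) for _ in range(n_pixels)] for _ in range(hash_dim)]

    def toy_hash(pixels):
        projected = [sum(w * p for w, p in zip(row, pixels)) for row in weights]
        return l2_normalize(projected)

    return toy_hash


# A fake 64-pixel "image" in [0, 1] and a copy with small uniform noise.
rng = random.Random(42)
pixels = [rng.random() for _ in range(64)]
perturbed = [min(1.0, max(0.0, p + rng.uniform(-0.01, 0.01))) for p in pixels]

toy_hash = make_toy_hasher(n_pixels=64, hash_dim=8)
h_clean, h_noisy = toy_hash(pixels), toy_hash(perturbed)

similarity = cosine_similarity(h_clean, h_noisy)
bit_flips = hamming_distance(h_clean, h_noisy)
print(f"cosine similarity: {similarity:.4f}, sign flips: {bit_flips}")
```

A robust hasher should keep the similarity close to 1 for perturbed copies of the same image while keeping it low for different images, which is why the cells below pair this test with a collapse check.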

In [None]:
# =============================================================================
# ROBUSTNESS TEST WITH PROPER SANITY CHECKS
# =============================================================================
from PIL import Image
from torchvision import transforms
from pathlib import Path
import json

# Load normalization from model_config.json (saved during training)
# This ensures eval uses the SAME mean/std as training (AutoImageProcessor values)
CHECKPOINT_DIR = Path("/content/drive/MyDrive/dazzled/outputs/checkpoints")
config_path = CHECKPOINT_DIR / "model_config.json"

if config_path.exists():
    with open(config_path) as f:
        model_config = json.load(f)
    NORM_MEAN = model_config.get("norm_mean", [0.485, 0.456, 0.406])
    NORM_STD = model_config.get("norm_std", [0.229, 0.224, 0.225])
    print(f"✓ Loaded normalization from {config_path}")
    print(f"  Mean: {NORM_MEAN}")
    print(f"  Std:  {NORM_STD}")
else:
    # Fallback to ImageNet defaults (may cause slight eval mismatch)
    print(f"⚠️ {config_path} not found, using ImageNet defaults")
    print("   (This may cause slight mismatch with training normalization)")
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Load real test images
DATA_ROOT = Path("/content/data")
test_image_paths = list(DATA_ROOT.rglob("*.jpg"))[:10]

if len(test_image_paths) < 2:
    print("⚠️ Not enough test images, using synthetic data")
    test_image_paths = None

# =============================================================================
# SANITY CHECK 1: Embedding Collapse Detection
# =============================================================================
print("="*65)
print("SANITY CHECK 1: Embedding Collapse Detection")
print("="*65)
print("Comparing hashes of different images (should have LOW similarity)\n")

if test_image_paths and len(test_image_paths) >= 5:
    hashes = []
    for img_path in test_image_paths[:5]:
        img = Image.open(img_path).convert("RGB")
        x = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            h = model.inference(x, n_sup=16)
            hashes.append(F.normalize(h, dim=1).clone())  # .clone() to avoid aliasing!
    
    # Compute pairwise similarities
    print(f"{'Image Pair':<30} {'Cosine Sim':>12}")
    print("-"*45)
    collapse_detected = True

In [None]:
# =============================================================================
# COMPREHENSIVE ROBUSTNESS ANALYSIS
# =============================================================================
# NOTE: Pillow must be upgraded BEFORE any PIL imports (see first cell).
#       If you get JPEG errors, restart runtime and re-run from cell 1.

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
import numpy as np
import random
import io

print("="*65)
print("COMPREHENSIVE ROBUSTNESS ANALYSIS")
print("="*65)

# -----------------------------------------------------------------------------
# 1. HISTOGRAM OF PAIRWISE SIMILARITIES (1000+ pairs)
# -----------------------------------------------------------------------------
print("\n📊 1. Pairwise Similarity Distribution")
print("-"*50)

# Get more images for statistical analysis
all_test_images = list(DATA_ROOT.rglob("*.jpg"))[:200]
random.shuffle(all_test_images)

# Compute hashes for many images
all_hashes = []
print(f"Computing hashes for {min(100, len(all_test_images))} images...")
for img_path in all_test_images[:100]:
    try:
        img = Image.open(img_path).convert("RGB")
        x = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            h = model.inference(x, n_sup=16)
            all_hashes.append(F.normalize(h, dim=1).clone().cpu())
    except Exception:
        continue

print(f"Computed {len(all_hashes)} hashes")

# Compute pairwise similarities
pairwise_sims = []
for i in range(len(all_hashes)):
    for j in range(i+1, len(all_hashes)):
        sim = F.cosine_similarity(all_hashes[i], all_hashes[j]).item()
        pairwise_sims.append(sim)

pairwise_sims = np.array(pairwise_sims)
print(f"Computed {len(pairwise_sims)} pairwise similarities")
print(f"   Mean: {pairwise_sims.mean():.4f}")
print(f"   Std:  {pairwise_sims.std():.4f}")
print(f"   Min:  {pairwise_sims.min():.4f}")
print(f"   Max:  {pairwise_sims.max():.4f}")
print(f"   % > 0.95: {100 * (pairwise_sims > 0.95).mean():.1f}%")
print(f"   % > 0.99: {100 * (pairwise_sims > 0.99).mean():.1f}%")

# Plot histogram
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].hist(pairwise_sims, bins=50, edgecolor='black', alpha=0.7)
axes[0].axvline(x=0.9, color='r', linestyle='--', label='0.9 threshold')
axes[0].axvline(x=pairwise_sims.mean(), color='g', linestyle='-', label=f'Mean={pairwise_sims.mean():.3f}')
axes[0].set_xlabel('Cosine Similarity')
axes[0].set_ylabel('Count')
axes[0].set_title('Pairwise Similarity Distribution')
axes[0].legend()

# -----------------------------------------------------------------------------
# 2. PIXEL-SPACE NOISE (before normalization)
# -----------------------------------------------------------------------------
print("\n📊 2. Pixel-Space Noise (Pre-Normalization)")
print("-"*50)

# Raw transform without normalization
raw_transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),  # Just [0,1] range, no normalization
])

normalize = transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)

# Use a real image
test_img = Image.open(all_test_images[0]).convert("RGB")
raw_tensor = raw_transform(test_img).unsqueeze(0).to(device)  # [0,1] range

# Baseline hash from the clean image (noise is added below, BEFORE normalization)
baseline_normalized = normalize(raw_tensor.squeeze(0)).unsqueeze(0)
with torch.no_grad():
    baseline_hash = model.inference(baseline_normalized, n_sup=16)
    baseline_hash = F.normalize(baseline_hash, dim=1).clone()

pixel_noise_results = []
pixel_epsilons = [0.01, 0.05, 0.1, 0.2, 0.3]  # In [0,1] pixel range

for eps in pixel_epsilons:
    # Add noise in pixel space [0,1]
    noise = torch.randn_like(raw_tensor) * eps
    noisy_pixels = torch.clamp(raw_tensor + noise, 0, 1)
    
    # Then normalize
    noisy_normalized = normalize(noisy_pixels.squeeze(0)).unsqueeze(0)
    
    with torch.no_grad():
        noisy_hash = model.inference(noisy_normalized, n_sup=16)
        noisy_hash = F.normalize(noisy_hash, dim=1).clone()
    
    sim = F.cosine_similarity(baseline_hash, noisy_hash).item()
    pixel_noise_results.append((eps, sim))
    print(f"Pixel noise σ={eps:.2f}: Cosine={sim:.6f}")

# -----------------------------------------------------------------------------
# 3. HARD TRANSFORMS (JPEG, Blur, Crop, Resize)
# -----------------------------------------------------------------------------
print("\n📊 3. Hard Transform Invariance")
print("-"*50)

hard_transform_results = []

# JPEG compression using torchvision (bypasses PIL's buggy JPEG encoder)
import torchvision.io as tvio
from torchvision.transforms.functional import to_tensor, to_pil_image

def jpeg_compress(pil_img, quality):
    """Compress image via JPEG using torchvision (avoids PIL JPEG bugs)."""
    t = (to_tensor(pil_img) * 255).clamp(0, 255).to(torch.uint8)
    enc = tvio.encode_jpeg(t, quality=quality)
    dec = tvio.decode_jpeg(enc)
    return to_pil_image(dec)

for quality in [95, 75, 50, 25, 10]:
    try:
        jpeg_img = jpeg_compress(test_img, quality)
        x = transform(jpeg_img).unsqueeze(0).to(device)
        with torch.no_grad():
            h = model.inference(x, n_sup=16)
            h = F.normalize(h, dim=1).clone()
        sim = F.cosine_similarity(baseline_hash.to(device), h).item()
        hard_transform_results.append((f"JPEG q={quality}", sim))
        print(f"JPEG quality={quality:2d}: Cosine={sim:.6f}")
    except Exception as e:
        print(f"JPEG quality={quality:2d}: FAILED - {e}")
        print("   Try: Runtime → Restart runtime, then re-run cells")

# Gaussian blur
for radius in [1, 2, 3, 5]:
    blurred = test_img.filter(ImageFilter.GaussianBlur(radius=radius))
    x = transform(blurred).unsqueeze(0).to(device)
    with torch.no_grad():
        h = model.inference(x, n_sup=16)
        h = F.normalize(h, dim=1).clone()
    sim = F.cosine_similarity(baseline_hash.to(device), h).item()
    hard_transform_results.append((f"Blur r={radius}", sim))
    print(f"Blur radius={radius}: Cosine={sim:.6f}")

# Center crop + resize back
for crop_pct in [0.9, 0.8, 0.7, 0.6]:
    w, h_img = test_img.size
    crop_w, crop_h = int(w * crop_pct), int(h_img * crop_pct)
    left = (w - crop_w) // 2
    top = (h_img - crop_h) // 2
    cropped = test_img.crop((left, top, left + crop_w, top + crop_h))
    cropped = cropped.resize((w, h_img), Image.BICUBIC)
    x = transform(cropped).unsqueeze(0).to(device)
    with torch.no_grad():
        hash_out = model.inference(x, n_sup=16)
        hash_out = F.normalize(hash_out, dim=1).clone()
    sim = F.cosine_similarity(baseline_hash.to(device), hash_out).item()
    hard_transform_results.append((f"Crop {int(crop_pct*100)}%", sim))
    print(f"Center crop {int(crop_pct*100)}%: Cosine={sim:.6f}")

# Plot hard transforms
transform_names = [r[0] for r in hard_transform_results]
transform_sims = [r[1] for r in hard_transform_results]

axes[1].barh(transform_names, transform_sims, color='steelblue', edgecolor='black')
axes[1].axvline(x=0.9, color='r', linestyle='--', label='0.9 threshold')
axes[1].set_xlabel('Cosine Similarity')
axes[1].set_title('Hard Transform Robustness')
axes[1].set_xlim(0, 1.05)
axes[1].legend()

# Plot pixel noise curve
pixel_eps = [r[0] for r in pixel_noise_results]
pixel_sims = [r[1] for r in pixel_noise_results]
axes[2].plot(pixel_eps, pixel_sims, 'go-', linewidth=2, markersize=8)
axes[2].axhline(y=0.9, color='r', linestyle='--', label='0.9 threshold')
axes[2].set_xlabel('Pixel Noise σ (in [0,1] range)')
axes[2].set_ylabel('Cosine Similarity')
axes[2].set_title('Pixel-Space Noise Robustness')
axes[2].set_ylim(0.5, 1.05)
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# -----------------------------------------------------------------------------
# SUMMARY
# -----------------------------------------------------------------------------
print("\n" + "="*65)
print("ROBUSTNESS SUMMARY")
print("="*65)
print(f"✓ Pairwise similarity: mean={pairwise_sims.mean():.3f}, std={pairwise_sims.std():.3f}")
print(f"  → {100*(pairwise_sims > 0.95).mean():.1f}% pairs > 0.95 (potential near-duplicates)")

# Safe extraction of results
jpeg_25 = [r[1] for r in hard_transform_results if 'q=25' in r[0]]
blur_3 = [r[1] for r in hard_transform_results if 'r=3' in r[0]]
crop_70 = [r[1] for r in hard_transform_results if '70%' in r[0]]

if jpeg_25: print(f"✓ JPEG q=25: {jpeg_25[0]:.4f}")
if blur_3: print(f"✓ Blur r=3:  {blur_3[0]:.4f}")
if crop_70: print(f"✓ Crop 70%:  {crop_70[0]:.4f}")

# Pass/fail assessment
jpeg_results = [r[1] for r in hard_transform_results if 'JPEG' in r[0]]
blur_results = [r[1] for r in hard_transform_results if 'Blur' in r[0]]

jpeg_robust = all(s > 0.9 for s in jpeg_results) if jpeg_results else False
blur_robust = all(s > 0.85 for s in blur_results) if blur_results else False
no_collapse = pairwise_sims.mean() < 0.9

print(f"\n{'✓' if jpeg_robust else '✗'} JPEG robustness (all > 0.9)")
print(f"{'✓' if blur_robust else '✗'} Blur robustness (all > 0.85)")
print(f"{'✓' if no_collapse else '✗'} No collapse (mean pairwise < 0.9)")

## Deterministic Eval & Threshold Calibration

The previous robustness analysis used random augmentations which inflate similarity.
This cell uses **deterministic transforms** (Resize → CenterCrop → Normalize) for accurate measurement.

Then we calibrate the threshold using ROC/PR curves:
- **Same image (augmented)**: Should match (positives)
- **Different images**: Should NOT match (negatives)
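
To make the calibration idea concrete, here is a minimal pure-Python sketch of how a Youden-style threshold could be picked from the two similarity lists. The function name `youden_threshold` and the toy similarity values are assumptions made for illustration; the cell below relies on scikit-learn's `roc_curve` instead.

```python
def youden_threshold(pos_sims, neg_sims):
    """Return (threshold, J) maximizing J = TPR - FPR.

    A pair is predicted 'match' when its similarity is >= threshold.
    Ties keep the lowest threshold that reaches the best J.
    """
    candidates = sorted(set(pos_sims) | set(neg_sims))
    best_t, best_j = None, -1.0
    for t in candidates:
        tpr = sum(s >= t for s in pos_sims) / len(pos_sims)  # positives kept
        fpr = sum(s >= t for s in neg_sims) / len(neg_sims)  # negatives wrongly kept
        j = tpr - fpr
        if j > best_j:
            best_t, best_j = t, j
    return best_t, best_j


# Toy similarities (made up for illustration, not model output)
pos = [0.97, 0.93, 0.88, 0.81]  # same image, augmented
neg = [0.62, 0.55, 0.84, 0.30]  # different images

threshold, j = youden_threshold(pos, neg)
print(threshold, j)  # 0.81 0.75
```

In this toy case one negative pair (0.84) sits above the weakest positive (0.81), so no threshold separates the two sets perfectly and the best achievable J is 0.75.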

In [None]:
# =============================================================================
# DETERMINISTIC EVALUATION WITH THRESHOLD CALIBRATION
# =============================================================================
# Uses eval transforms (no random augmentation) for unbiased measurement
# Calibrates threshold using ROC/PR curves
# IMPORTANT: Uses the SAME normalization as training for consistency!

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
from sklearn.metrics import roc_curve, precision_recall_curve, auc
from torchvision import transforms
import random
import json

print("="*65)
print("DETERMINISTIC EVALUATION & THRESHOLD CALIBRATION")
print("="*65)

# -----------------------------------------------------------------------------
# LOAD NORMALIZATION CONFIG FROM TRAINING (for consistency)
# -----------------------------------------------------------------------------
CKPT_DIR = Path("/content/drive/MyDrive/dazzled/outputs/checkpoints")
config_path = CKPT_DIR / "model_config.json"

if config_path.exists():
    with open(config_path) as f:
        model_config = json.load(f)
    NORM_MEAN = model_config.get("norm_mean", [0.485, 0.456, 0.406])
    NORM_STD = model_config.get("norm_std", [0.229, 0.224, 0.225])
    IMAGE_SIZE = model_config.get("image_size", 224)
    print(f"✓ Loaded normalization from training config:")
    print(f"   mean={NORM_MEAN}")
    print(f"   std={NORM_STD}")
else:
    # Fallback to ImageNet defaults
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]
    IMAGE_SIZE = 224
    print("⚠️  No model_config.json found, using ImageNet defaults")
    print("   (This may cause eval/train distribution mismatch)")

# -----------------------------------------------------------------------------
# DETERMINISTIC EVAL TRANSFORM (no randomness!)
# -----------------------------------------------------------------------------
eval_transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

print(f"✓ Using deterministic eval transform (Resize→CenterCrop({IMAGE_SIZE})→Normalize)")

# -----------------------------------------------------------------------------
# GATHER TEST IMAGES
# -----------------------------------------------------------------------------
all_test_images = list(DATA_ROOT.rglob("*.jpg"))
random.seed(42)  # Reproducible
random.shuffle(all_test_images)
test_images = all_test_images[:200]  # 200 images for calibration
print(f"✓ Selected {len(test_images)} test images")

# -----------------------------------------------------------------------------
# COMPUTE HASHES WITH DETERMINISTIC TRANSFORM
# -----------------------------------------------------------------------------
print("\n📊 Computing hashes with deterministic transform...")
hashes = []
for img_path in test_images:
    try:
        img = Image.open(img_path).convert("RGB")
        x = eval_transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            h = model.inference(x, n_sup=16)
            hashes.append(F.normalize(h, dim=1).cpu())
    except Exception:
        hashes.append(None)

valid_indices = [i for i, h in enumerate(hashes) if h is not None]
valid_hashes = [hashes[i] for i in valid_indices]
valid_images = [test_images[i] for i in valid_indices]
print(f"✓ Computed {len(valid_hashes)} valid hashes")

# Stack for batch operations
hash_tensor = torch.cat(valid_hashes, dim=0)  # [N, hash_dim]

# -----------------------------------------------------------------------------
# COMPUTE SIMILARITY DISTRIBUTIONS
# -----------------------------------------------------------------------------
print("\n📊 Computing similarity distributions...")

# 1. DIFFERENT IMAGES (negatives)
neg_sims = []
n = len(valid_hashes)
for i in range(n):
    for j in range(i+1, n):
        sim = F.cosine_similarity(valid_hashes[i], valid_hashes[j]).item()
        neg_sims.append(sim)
neg_sims = np.array(neg_sims)
print(f"  Different-image pairs: {len(neg_sims)}")
print(f"    Mean: {neg_sims.mean():.4f}, Std: {neg_sims.std():.4f}")
print(f"    [Min, Max]: [{neg_sims.min():.4f}, {neg_sims.max():.4f}]")

# 2. SAME IMAGE WITH AUGMENTATIONS (positives)
# Apply realistic augmentations to create "same image" pairs
augmentations = [
    ("JPEG q=50", lambda img: jpeg_compress_eval(img, 50)),
    ("JPEG q=25", lambda img: jpeg_compress_eval(img, 25)),
    ("Blur r=2", lambda img: img.filter(ImageFilter.GaussianBlur(radius=2))),
    ("Blur r=4", lambda img: img.filter(ImageFilter.GaussianBlur(radius=4))),
    ("Crop 80%", lambda img: center_crop_resize(img, 0.8)),
    ("Crop 70%", lambda img: center_crop_resize(img, 0.7)),
]

def jpeg_compress_eval(pil_img, quality):
    import io
    buffer = io.BytesIO()
    pil_img.save(buffer, format='JPEG', quality=quality)
    buffer.seek(0)
    return Image.open(buffer).convert("RGB")

def center_crop_resize(pil_img, crop_pct):
    w, h = pil_img.size
    crop_w, crop_h = int(w * crop_pct), int(h * crop_pct)
    left = (w - crop_w) // 2
    top = (h - crop_h) // 2
    cropped = pil_img.crop((left, top, left + crop_w, top + crop_h))
    return cropped.resize((w, h), Image.BICUBIC)

pos_sims = []
sample_images = valid_images[:50]  # Use 50 images for augmentation pairs

for img_path in sample_images:
    try:
        original = Image.open(img_path).convert("RGB")
        x_orig = eval_transform(original).unsqueeze(0).to(device)
        with torch.no_grad():
            h_orig = model.inference(x_orig, n_sup=16)
            h_orig = F.normalize(h_orig, dim=1)
        
        for aug_name, aug_fn in augmentations:
            augmented = aug_fn(original)
            x_aug = eval_transform(augmented).unsqueeze(0).to(device)
            with torch.no_grad():
                h_aug = model.inference(x_aug, n_sup=16)
                h_aug = F.normalize(h_aug, dim=1)
            
            sim = F.cosine_similarity(h_orig, h_aug).item()
            pos_sims.append(sim)
    except Exception:
        continue

pos_sims = np.array(pos_sims)
print(f"  Same-image (augmented) pairs: {len(pos_sims)}")
print(f"    Mean: {pos_sims.mean():.4f}, Std: {pos_sims.std():.4f}")
print(f"    [Min, Max]: [{pos_sims.min():.4f}, {pos_sims.max():.4f}]")

# -----------------------------------------------------------------------------
# ROC & PR CURVE ANALYSIS
# -----------------------------------------------------------------------------
print("\n📊 Computing ROC & PR curves...")

# Labels: 1 = same image (should match), 0 = different image (should not match)
y_true = np.concatenate([np.ones(len(pos_sims)), np.zeros(len(neg_sims))])
y_scores = np.concatenate([pos_sims, neg_sims])

# ROC curve
fpr, tpr, roc_thresholds = roc_curve(y_true, y_scores)
roc_auc = auc(fpr, tpr)

# PR curve
precision, recall, pr_thresholds = precision_recall_curve(y_true, y_scores)
pr_auc = auc(recall, precision)

# Find optimal thresholds
# 1. Youden's J statistic (maximize TPR - FPR)
j_scores = tpr - fpr
best_j_idx = np.argmax(j_scores)
threshold_youden = roc_thresholds[best_j_idx]

# 2. F1-optimal threshold
f1_scores = 2 * precision[:-1] * recall[:-1] / (precision[:-1] + recall[:-1] + 1e-8)
best_f1_idx = np.argmax(f1_scores)
threshold_f1 = pr_thresholds[best_f1_idx]

# 3. High-precision threshold (precision > 0.95)
high_prec_mask = precision[:-1] > 0.95
if high_prec_mask.any():
    high_prec_idx = np.where(high_prec_mask)[0][0]
    threshold_high_prec = pr_thresholds[high_prec_idx]
else:
    threshold_high_prec = 0.99

print(f"\n📏 Optimal Thresholds:")
print(f"  Youden (max TPR-FPR):     {threshold_youden:.4f}")
print(f"  F1-optimal:               {threshold_f1:.4f}")
print(f"  High-precision (>95%):    {threshold_high_prec:.4f}")

# -----------------------------------------------------------------------------
# VISUALIZATION
# -----------------------------------------------------------------------------
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# 1. Similarity histograms
axes[0].hist(neg_sims, bins=50, alpha=0.7, label=f'Different (n={len(neg_sims)})', color='red')
axes[0].hist(pos_sims, bins=50, alpha=0.7, label=f'Same (n={len(pos_sims)})', color='green')
axes[0].axvline(x=threshold_f1, color='blue', linestyle='--', label=f'F1 thresh={threshold_f1:.3f}')
axes[0].axvline(x=threshold_youden, color='orange', linestyle=':', label=f'Youden={threshold_youden:.3f}')
axes[0].set_xlabel('Cosine Similarity')
axes[0].set_ylabel('Count')
axes[0].set_title('Similarity Distributions')
axes[0].legend(fontsize=8)

# 2. ROC curve
axes[1].plot(fpr, tpr, 'b-', linewidth=2, label=f'ROC (AUC={roc_auc:.3f})')
axes[1].plot([0, 1], [0, 1], 'k--', alpha=0.5)
axes[1].scatter([fpr[best_j_idx]], [tpr[best_j_idx]], color='orange', s=100, zorder=5, label=f'Youden @ {threshold_youden:.3f}')
axes[1].set_xlabel('False Positive Rate')
axes[1].set_ylabel('True Positive Rate')
axes[1].set_title('ROC Curve')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# 3. PR curve
axes[2].plot(recall, precision, 'g-', linewidth=2, label=f'PR (AUC={pr_auc:.3f})')
axes[2].scatter([recall[best_f1_idx]], [precision[best_f1_idx]], color='blue', s=100, zorder=5, label=f'F1-opt @ {threshold_f1:.3f}')
axes[2].set_xlabel('Recall')
axes[2].set_ylabel('Precision')
axes[2].set_title('Precision-Recall Curve')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# -----------------------------------------------------------------------------
# SUMMARY & DIAGNOSIS
# -----------------------------------------------------------------------------
print("\n" + "="*65)
print("CALIBRATION SUMMARY")
print("="*65)

separation = pos_sims.mean() - neg_sims.mean()
overlap = (neg_sims > pos_sims.min()).sum() / len(neg_sims)

print(f"  Same-image mean:      {pos_sims.mean():.4f}")
print(f"  Different-image mean: {neg_sims.mean():.4f}")
print(f"  Separation:           {separation:.4f}")
print(f"  Overlap (neg > min_pos): {100*overlap:.1f}%")
print(f"\n  ROC AUC: {roc_auc:.4f}")
print(f"  PR AUC:  {pr_auc:.4f}")

# Diagnosis
if separation < 0.1:
    print("\n⚠️  COLLAPSE DETECTED: Separation < 0.1")
    print("   → Model is not discriminating between different images")
    print("   → RETRAIN with --contrast-weight 0.5 (InfoNCE loss)")
elif separation < 0.3:
    print("\n⚠️  WEAK DISCRIMINATION: Separation < 0.3")
    print("   → Consider increasing hash_dim (128 or 192)")
    print("   → Consider adding --contrast-weight 0.3")
else:
    print("\n✓ Good discrimination (separation > 0.3)")

# Warn below the same 0.95 bar that the message quotes
if roc_auc <= 0.95:
    print(f"\n⚠️  ROC AUC = {roc_auc:.3f} (should be > 0.95 for production)")
else:
    print(f"\n✓ ROC AUC = {roc_auc:.3f} (good)")

print(f"\n📝 RECOMMENDED THRESHOLD: {threshold_f1:.4f}")
print(f"   (Use this in your Go runtime for matching)")

## 5. ONNX Export Test

In [None]:
# =============================================================================
# ONNX EXPORT FOR TRM HASHER
# =============================================================================
# The TRM model uses recursive inference, which ONNX doesn't handle well.
# We export a single-step model that the Go runtime calls repeatedly.

import os

# Install onnxscript (required by the dynamo-based torch.onnx exporter in PyTorch 2.x)
!pip install -q onnxscript

import onnx
import onnxruntime as ort
import numpy as np

# Import ONNX wrapper
from models.trm_hasher import TRMHasherONNX

# Create ONNX-exportable wrapper from trained model
onnx_model = TRMHasherONNX(student)  # `student` is the trained TRMHasher
onnx_model.eval()
onnx_model = onnx_model.to("cpu")

# Export to ONNX
onnx_path = "trm_hasher.onnx"

# Dummy inputs for ONNX export
batch_size = 1
dummy_img = torch.randn(batch_size, 3, 224, 224)
dummy_y = torch.randn(batch_size, EMBED_DIM)
dummy_z = torch.randn(batch_size, EMBED_DIM)
dummy_x_cached = torch.zeros(batch_size, EMBED_DIM)  # Zeros = need to encode

torch.onnx.export(
    onnx_model,
    (dummy_img, dummy_y, dummy_z, dummy_x_cached),
    onnx_path,
    input_names=["image", "y", "z", "x_cached"],
    output_names=["x", "y_new", "z_new", "hash"],
    opset_version=18,
    dynamic_axes={
        "image": {0: "batch"},
        "y": {0: "batch"},
        "z": {0: "batch"},
        "x_cached": {0: "batch"},
        "x": {0: "batch"},
        "y_new": {0: "batch"},
        "z_new": {0: "batch"},
        "hash": {0: "batch"},
    },
)

print(f"✓ Exported TRM ONNX model to {onnx_path}")
print(f"  File size: {os.path.getsize(onnx_path) / 1024:.2f} KB")
print(f"\n  Inputs: image[B,3,224,224], y[B,{EMBED_DIM}], z[B,{EMBED_DIM}], x_cached[B,{EMBED_DIM}]")
print(f"  Outputs: x[B,{EMBED_DIM}], y_new[B,{EMBED_DIM}], z_new[B,{EMBED_DIM}], hash[B,{HASH_DIM}]")
print(f"\n  Usage: Call N_sup=16 times, passing (y_new, z_new, x) back as (y, z, x_cached)")

In [None]:
# Validate ONNX model and test recursive inference
onnx_model_loaded = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model_loaded)
print("✓ ONNX model validation passed!")

# Create ONNX Runtime session
session = ort.InferenceSession(onnx_path)

# Test recursive ONNX inference (simulating Go runtime)
N_SUP_TEST = 16
rng = np.random.default_rng(42)
test_img = rng.standard_normal((1, 3, 224, 224)).astype(np.float32)

# Initialize y and z (from model's learned init)
y = onnx_model.y_init.detach().cpu().numpy()
z = onnx_model.z_init.detach().cpu().numpy()
x_cached = np.zeros((1, EMBED_DIM), dtype=np.float32)

print(f"\nRunning {N_SUP_TEST} supervision steps in ONNX Runtime...")
for step in range(N_SUP_TEST):
    outputs = session.run(None, {
        "image": test_img,
        "y": y,
        "z": z,
        "x_cached": x_cached,
    })
    x_new, y_new, z_new, hash_onnx = outputs
    
    # For next iteration, use cached x (don't re-encode image)
    x_cached = x_new
    y = y_new
    z = z_new
    
    if step in [0, 7, 15]:
        print(f"  Step {step+1}: hash L2 norm = {np.linalg.norm(hash_onnx, axis=1).mean():.4f}")

print(f"\n✓ ONNX inference complete!")
print(f"  Final hash shape: {hash_onnx.shape}")
print(f"  Final hash L2 norm: {np.linalg.norm(hash_onnx, axis=1).mean():.4f}")

In [None]:
# Compare PyTorch vs ONNX outputs
print("Comparing PyTorch TRM vs ONNX outputs...")

# Run PyTorch inference with same input
test_img_torch = torch.from_numpy(test_img)
with torch.no_grad():
    pt_hash = student.inference(test_img_torch.to("cpu"), n_sup=N_SUP_TEST)
    pt_hash_np = pt_hash.cpu().numpy()

# Compare with ONNX output
hash_diff = np.abs(pt_hash_np - hash_onnx).max()
cosine_sim = np.dot(pt_hash_np.flatten(), hash_onnx.flatten()) / (
    np.linalg.norm(pt_hash_np) * np.linalg.norm(hash_onnx)
)

print(f"\nMax absolute difference: {hash_diff:.8f}")
print(f"Cosine similarity: {cosine_sim:.6f}")

if cosine_sim > 0.999:
    print("✅ ONNX export matches PyTorch output!")
else:
    print("⚠️ Warning: Some differences detected (may be due to floating point)")
    print("   This is expected for complex recursive models. Check cosine similarity.")

## 6. Latency Benchmark

In [None]:
import time

def benchmark_trm_inference(model, device, n_sup=16, num_runs=100, warmup=10):
    """Benchmark TRM inference latency."""
    model = model.to(device)
    model.eval()

    test_img = torch.randn(1, 3, 224, 224).to(device)

    # Warmup
    for _ in range(warmup):
        with torch.no_grad():
            _ = model.inference(test_img, n_sup=n_sup)

    if device.type == 'cuda':
        torch.cuda.synchronize()

    # Benchmark
    times = []
    for _ in range(num_runs):
        if device.type == 'cuda':
            torch.cuda.synchronize()
        start = time.perf_counter()

        with torch.no_grad():
            _ = model.inference(test_img, n_sup=n_sup)

        if device.type == 'cuda':
            torch.cuda.synchronize()
        times.append(time.perf_counter() - start)

    return times

# Benchmark on CPU
print(f"Benchmarking TRM inference (N_sup={N_SUP})...")
cpu_times = benchmark_trm_inference(student, torch.device('cpu'), n_sup=N_SUP, num_runs=50)
print(f"CPU Latency: {np.mean(cpu_times)*1000:.2f} ± {np.std(cpu_times)*1000:.2f} ms")

# Benchmark on GPU if available
if torch.cuda.is_available():
    gpu_times = benchmark_trm_inference(student.to('cuda'), torch.device('cuda'), n_sup=N_SUP, num_runs=100)
    print(f"GPU Latency: {np.mean(gpu_times)*1000:.2f} ± {np.std(gpu_times)*1000:.2f} ms")

## 8. Summary

### ✅ This notebook provides a complete TRM training workflow:

| Step | Cell | Description |
|------|------|-------------|
| 0 | Mount Drive | Connect Google Drive for data/checkpoints |
| 0.1 | Download Data | Download FFHQ, OpenImages, MobileViews (~45k images) |
| 1 | Setup | Clone repo & install dependencies |
| 1.1 | Train TRM | Build manifest & train with **TRM deep supervision** |
| 2-3 | Architecture | Verify TRM structure & forward pass |
| 4 | Recursion | Test N_sup supervision step convergence |
| 5 | Robustness | Perturbation stability testing |
| 6 | ONNX | Export & validate ONNX model |
| 7 | Save | Copy ONNX to Drive for Go runtime |

### 📦 Output artifacts (on Google Drive):

```
/content/drive/MyDrive/dazzled/outputs/
├── checkpoints/
│   └── student_epoch_5.safetensors   # Trained TRM weights
└── models/
    └── trm_hasher.onnx               # ONNX for Go runtime
```

### 🔑 TRM Architecture (from paper):
- **2-layer network** with RMSNorm + SwiGLU
- **Two features**: y (hash/answer), z (latent reasoning)
- **Deep supervision**: N_sup=16 steps, loss at each step
- **Latent recursion**: z = net(x,y,z) × 6, then y = net(y,z)
- **Deep recursion**: T=3 cycles (2 no-grad + 1 grad)
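
The control flow in the bullets above can be sketched in plain Python. This is only an illustration of the loop structure under stated assumptions: `CountingNet` is a hypothetical scalar stand-in for the shared network (it averages its inputs), and the function names are invented for this sketch rather than taken from `models/trm_hasher.py`.

```python
class CountingNet:
    """Toy stand-in for the shared TRM network: averages its inputs and counts calls."""

    def __init__(self):
        self.calls = 0

    def __call__(self, *parts):
        self.calls += 1
        return sum(parts) / len(parts)


def latent_recursion(net, x, y, z, n_latent=6):
    """Refine the latent z n_latent times, then update the answer y once."""
    for _ in range(n_latent):
        z = net(x, y, z)
    y = net(y, z)
    return y, z


def deep_recursion(net, x, y, z, n_latent=6, T=3):
    """Run T latent-recursion cycles.

    In training, the first T-1 cycles would be gradient-free and only the last
    would backpropagate; this toy has no gradients, so all cycles look alike.
    """
    for _ in range(T):
        y, z = latent_recursion(net, x, y, z, n_latent)
    return y, z


def supervised_steps(net, x, y, z, n_sup=16, n_latent=6, T=3):
    """Carry (y, z) across n_sup supervision steps; training would apply a loss after each."""
    for _ in range(n_sup):
        y, z = deep_recursion(net, x, y, z, n_latent, T)
    return y, z


net = CountingNet()
y_final, z_final = supervised_steps(net, x=1.0, y=0.0, z=0.0)
print(net.calls)  # 16 steps * 3 cycles * (6 latent updates + 1 answer update) = 336
```

The call count shows why the network itself can stay tiny: the effective depth comes from reusing the same two layers many times.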

### 🔗 References:
- **DINOv3:** [arXiv:2508.10104](https://arxiv.org/abs/2508.10104)
- **TRM:** [arXiv:2510.04871](https://arxiv.org/abs/2510.04871) "Less is More: Recursive Reasoning with Tiny Networks"
- **Split Accumulation:** [ePrint 2020/1618](https://eprint.iacr.org/2020/1618)

## 7. Export ONNX to Google Drive

Save the trained model to Drive for use in the Go application.

In [None]:
# =============================================================================
# EXPORT TRAINED TRM ONNX MODEL TO GOOGLE DRIVE
# =============================================================================

import shutil
from pathlib import Path

# Source: the ONNX file exported earlier in this notebook
source_onnx = Path("trm_hasher.onnx")

# Destination on Drive
dest_dir = Path("/content/drive/MyDrive/dazzled/outputs/models")
dest_dir.mkdir(parents=True, exist_ok=True)
dest_onnx = dest_dir / "trm_hasher.onnx"

# Check if ONNX already exists on Drive
if dest_onnx.exists():
    existing_size = dest_onnx.stat().st_size / 1024
    print(f"⚠️  ONNX model already exists on Drive!")
    print(f"   Path: {dest_onnx}")
    print(f"   Size: {existing_size:.2f} KB")
    print(f"\n   Skipping copy to preserve existing model.")
    print(f"   To overwrite, delete the file manually or set FORCE_OVERWRITE = True below.")
    
    FORCE_OVERWRITE = False  # <-- Set to True to overwrite
    
    if FORCE_OVERWRITE and source_onnx.exists():
        shutil.copy(source_onnx, dest_onnx)
        print(f"\n   ✓ Overwrote existing model (FORCE_OVERWRITE=True)")

elif source_onnx.exists():
    shutil.copy(source_onnx, dest_onnx)
    size_kb = dest_onnx.stat().st_size / 1024
    
    print(f"✓ Exported TRM ONNX model to Drive!")
    print(f"  Path: {dest_onnx}")
    print(f"  Size: {size_kb:.2f} KB")
    print(f"\n📋 Model Interface (for Go runtime):")
    print(f"  Inputs:")
    print(f"    - image: [B, 3, 224, 224] float32 (normalized ImageNet)")
    print(f"    - y: [B, {EMBED_DIM}] float32 (answer state)")
    print(f"    - z: [B, {EMBED_DIM}] float32 (latent state)")
    print(f"    - x_cached: [B, {EMBED_DIM}] float32 (cached image encoding)")
    print(f"  Outputs:")
    print(f"    - x: [B, {EMBED_DIM}] float32 (image encoding, cache this)")
    print(f"    - y_new: [B, {EMBED_DIM}] float32 (updated answer state)")
    print(f"    - z_new: [B, {EMBED_DIM}] float32 (updated latent state)")
    print(f"    - hash: [B, {HASH_DIM}] float32 (current hash)")
    print(f"\n🔁 Usage: Call {N_SUP} times, passing outputs back as inputs")
else:
    print(f"❌ ONNX file not found at {source_onnx}")
    print("   Run the ONNX export cell first!")

# 🎯 MODEL VALIDATION MILESTONES

Before the exported ONNX model is used from the Go runtime, the model must pass **all four milestones**.
These tests ensure the recursive student is stable, accurate, and portable.

| Milestone | Test | Success Criteria |
|-----------|------|------------------|
| 1 | Recursive Drift | Emb₁ ≈ Emb₅ (cosine sim > 0.99) |
| 2 | Validation Loss | Plateaued for 3-5 epochs |
| 3 | Preprocessing Parity | Go/Python use identical transforms |
| 4 | ONNX Parity | PyTorch vs ONNX diff < 1e-5 |
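
Milestone 4's criterion in the table can be read as an element-wise tolerance check. A minimal sketch, assuming both hashes are available as flat lists of floats; `max_abs_diff`, `onnx_parity_ok`, and the sample vectors are hypothetical names and values used only for illustration.

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))


def onnx_parity_ok(pt_hash, onnx_hash, tol=1e-5):
    """True when the PyTorch and ONNX hashes agree to within tol everywhere."""
    return max_abs_diff(pt_hash, onnx_hash) < tol


# Made-up vectors standing in for a PyTorch hash and two ONNX outputs
pt_hash = [0.12, -0.40, 0.91]
onnx_close = [0.120001, -0.400002, 0.909999]  # differs by about 1e-6
onnx_far = [0.12, -0.40, 0.9101]              # differs by about 1e-4

close_ok = onnx_parity_ok(pt_hash, onnx_close)
far_ok = onnx_parity_ok(pt_hash, onnx_far)
print(close_ok, far_ok)  # True False
```

A maximum-difference check is stricter than a cosine-similarity check, because one badly wrong component fails it even when the overall direction of the vector is almost unchanged.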

## Milestone 1: Recursive Drift Test (Stability)

A recursive model feeds its own output back into itself. If the model is unstable,
errors compound and the embedding "drifts" into garbage after 2-3 passes.

**Test:** Run each test image through `student.inference` five times (each call performs `N_SUP` recursive supervision steps) and compare the first and last hashes.
**Pass Criteria:** Distance between Emb₁ and Emb₅ should be near-zero (cosine similarity > 0.99).
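
To make the drift criterion concrete, here is a minimal sketch of a feedback-style check in which each pass consumes the previous pass's output. The step functions are toy assumptions (a contraction toward a fixed point and a rotation), not the TRM itself, and `drift_test` is a hypothetical helper name.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def drift_test(step_fn, emb, passes=5, threshold=0.99):
    """Feed the output back in `passes` times; compare pass 1 with the last pass."""
    outputs = []
    for _ in range(passes):
        emb = step_fn(emb)
        outputs.append(emb)
    sim = cosine(outputs[0], outputs[-1])
    return sim, sim > threshold


TARGET = [0.6, 0.8, 0.0]


def stable_step(e):
    """Toy stable update: pulls the embedding strongly toward a fixed point."""
    return [0.1 * v + 0.9 * t for v, t in zip(e, TARGET)]


def rotating_step(e):
    """Toy unstable update: rotates the first two components by 40 degrees per pass."""
    c, s = math.cos(math.radians(40)), math.sin(math.radians(40))
    return [c * e[0] - s * e[1], s * e[0] + c * e[1], e[2]]


stable_sim, stable_ok = drift_test(stable_step, [1.0, 0.0, 0.0])
rot_sim, rot_ok = drift_test(rotating_step, [1.0, 0.0, 0.0])
print(f"stable:   {stable_sim:.4f} pass={stable_ok}")
print(f"rotating: {rot_sim:.4f} pass={rot_ok}")
```

The contraction settles near its fixed point after the first pass and clears the 0.99 bar, while the rotation ends 160 degrees away from its first output and fails.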

In [None]:
# =============================================================================
# MILESTONE 1: TRM RECURSIVE DRIFT TEST
# =============================================================================
# Test that the TRM produces stable embeddings across multiple inference runs.
# If the model is well-trained, repeated inference should produce same hash.

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from pathlib import Path
import numpy as np

# Load trained TRM model
CHECKPOINT_DIR = Path("/content/drive/MyDrive/dazzled/outputs/checkpoints")
ONNX_PATH = Path("/content/drive/MyDrive/dazzled/outputs/models/trm_hasher.onnx")

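# Find checkpoints - prioritize epoch checkpoints over step checkpoints.
# Sort by the trailing number so that e.g. epoch_10 comes after epoch_5
# (a plain lexicographic sort would place it first).
import re

def _ckpt_number(path):
    nums = re.findall(r"\d+", path.stem)
    return int(nums[-1]) if nums else -1

all_checkpoints = sorted(CHECKPOINT_DIR.glob("*.safetensors"),
                         key=lambda p: (_ckpt_number(p), p.name))
epoch_ckpts = [c for c in all_checkpoints if "epoch" in c.name]
step_ckpts = [c for c in all_checkpoints if "step" in c.name]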

if epoch_ckpts:
    latest_ckpt = epoch_ckpts[-1]  # Use final epoch checkpoint
    print(f"✓ Loading EPOCH checkpoint: {latest_ckpt.name}")
elif step_ckpts:
    latest_ckpt = step_ckpts[-1]
    print(f"⚠️  No epoch checkpoints found. Loading step: {latest_ckpt.name}")
elif all_checkpoints:
    latest_ckpt = all_checkpoints[-1]
    print(f"⚠️  Loading: {latest_ckpt.name}")
else:
    raise FileNotFoundError(f"No checkpoints found in {CHECKPOINT_DIR}")

# Load TRM model
import safetensors.torch
import sys
sys.path.insert(0, "/content/DaZZLeD/ml-core")
from models.trm_hasher import TRMHasher

# TRM Architecture parameters (must match training)
EMBED_DIM = 256
HASH_DIM = 96
N_LAYERS = 2
N_LATENT = 6
T = 3
N_SUP = 16  # Supervision steps for inference
IMAGE_SIZE = 224

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student = TRMHasher(
    embed_dim=EMBED_DIM,
    hash_dim=HASH_DIM,
    n_layers=N_LAYERS,
    n_latent=N_LATENT,
    T=T
).to(device)
safetensors.torch.load_model(student, str(latest_ckpt))
student.eval()

# Standard ImageNet normalization (MUST match Go implementation)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Load up to five test images (sorted so repeated runs pick the same files)
DATA_ROOT = Path("/content/data")
test_images = sorted(DATA_ROOT.rglob("*.jpg"))[:5]
if not test_images:
    raise FileNotFoundError(f"No test images found under {DATA_ROOT}")

print(f"\n{'='*65}")
print("MILESTONE 1: TRM RECURSIVE DRIFT TEST")
print("="*65)

milestone1_passed = True
for img_path in test_images:
    img = Image.open(img_path).convert("RGB")
    x = transform(img).unsqueeze(0).to(device)
    
    embeddings = []
    with torch.no_grad():
        for pass_num in range(5):
            # Run TRM inference (N_SUP supervision steps)
            hash_out = student.inference(x, n_sup=N_SUP)
            embeddings.append(F.normalize(hash_out, dim=1).cpu())
    
    # Compare first and last embeddings
    emb1 = embeddings[0]
    emb5 = embeddings[4]
    cosine_sim = F.cosine_similarity(emb1, emb5).item()
    l2_dist = torch.norm(emb1 - emb5).item()
    
    status = "✓ PASS" if cosine_sim > 0.99 else "✗ FAIL"
    if cosine_sim <= 0.99:
        milestone1_passed = False
    
    print(f"{status} {img_path.name[:30]:<30} | Cosine(Run1,Run5)={cosine_sim:.6f} | L2={l2_dist:.6f}")

print(f"\n{'MILESTONE 1: PASSED ✓' if milestone1_passed else 'MILESTONE 1: FAILED ✗'}")
print("="*65)

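## Milestone 2: Teacher-Student Alignment

Check that the student preserves the similarity structure learned from the teacher.
The teacher (1024-dim) and student (96-dim) embeddings have different dimensions, so they
cannot be compared directly; instead, compare how each model scores pairs of validation images.

**Test:** Compare student and teacher pairwise cosine similarities on a small held-out validation set.
**Pass Criteria:** Spearman correlation > 0.5, median-split agreement rate > 70%, and average student pairwise similarity < 0.95 (no embedding collapse).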
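
The alignment test in the cell below relies on two statistics: Spearman rank correlation and a median-split agreement rate. Here is a minimal pure-Python sketch of both, using made-up similarity values. The helper names (`average_ranks`, `pearson`, `spearman`, `median_split_agreement`) are illustrative and are not part of the DaZZLeD codebase; the notebook itself calls `scipy.stats.spearmanr`.

```python
from statistics import median

def average_ranks(values):
    """Return 1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    pos = 0
    while pos < len(order):
        end = pos
        while end + 1 < len(order) and values[order[end + 1]] == values[order[pos]]:
            end += 1
        shared = (pos + end) / 2 + 1  # mean rank of the tied run
        for k in range(pos, end + 1):
            ranks[order[k]] = shared
        pos = end + 1
    return ranks

def pearson(a, b):
    """Pearson correlation of two equal-length sequences."""
    n = len(a)
    mean_a, mean_b = sum(a) / n, sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / (var_a * var_b) ** 0.5

def spearman(a, b):
    """Spearman's rho is the Pearson correlation of the ranks."""
    return pearson(average_ranks(a), average_ranks(b))

def median_split_agreement(teacher_sims, student_sims):
    """Fraction of pairs that both models place on the same side of their own median."""
    t_med, s_med = median(teacher_sims), median(student_sims)
    hits = sum((t > t_med) == (s > s_med) for t, s in zip(teacher_sims, student_sims))
    return hits / len(teacher_sims)

# Made-up pairwise similarities, for illustration only.
teacher = [0.91, 0.15, 0.62, 0.40, 0.78]
student = [0.80, 0.30, 0.55, 0.60, 0.70]
print(f"rho={spearman(teacher, student):.2f}, "
      f"agreement={median_split_agreement(teacher, student):.0%}")
```

Each model is thresholded at its own median because the two embedding spaces need not share a similarity scale; only the ordering of pairs is compared.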

In [None]:
# =============================================================================
# MILESTONE 2: TRM TEACHER-STUDENT ALIGNMENT (Relative Similarity Preservation)
# =============================================================================
# Since teacher (1024-dim) and student (96-dim) have different dimensions,
# we check if student PRESERVES RELATIVE SIMILARITIES learned from teacher.
# 
# Test: For pairs of images, does student agree with teacher on which pairs
# are more/less similar? This is the true test of distillation quality.

from transformers import AutoModel
from itertools import combinations

# Load teacher model
print("Loading DINOv3 teacher model...")
teacher = AutoModel.from_pretrained(
    "facebook/dinov3-vitl16-pretrain-lvd1689m",
    trust_remote_code=True
).to(device)
teacher.eval()

# Select validation images (mix of categories for diversity)
val_images = []
for subdir in ["ffhq", "openimages", "mobileviews"]:
    subdir_path = DATA_ROOT / subdir
    if subdir_path.exists():
        imgs = list(subdir_path.glob("*.jpg"))[:5]
        val_images.extend(imgs)

if len(val_images) < 6:
    val_images = test_images

print(f"\n{'='*65}")
print("MILESTONE 2: TRM TEACHER-STUDENT ALIGNMENT")
print("="*65)
print(f"Evaluating relative similarity preservation on {len(val_images)} images...")
print("(Checking if student agrees with teacher on which images are similar)\n")

# Collect all embeddings
teacher_embeddings = []
student_embeddings = []
image_names = []

for img_path in val_images:
    img = Image.open(img_path).convert("RGB")
    x = transform(img).unsqueeze(0).to(device)
    
    with torch.no_grad():
        # Teacher embedding (1024-dim)
        teacher_out = teacher(x)
        teacher_emb = F.normalize(teacher_out.last_hidden_state[:, 0], dim=1)
        teacher_embeddings.append(teacher_emb.cpu())
        
        # TRM Student embedding (96-dim via deep supervision inference)
        student_hash = student.inference(x, n_sup=N_SUP)
        student_emb = F.normalize(student_hash, dim=1)
        student_embeddings.append(student_emb.cpu())
        
        image_names.append(f"{img_path.parent.name}/{img_path.name[:15]}")

# Stack embeddings
teacher_embs = torch.cat(teacher_embeddings, dim=0)  # [N, 1024]
student_embs = torch.cat(student_embeddings, dim=0)  # [N, 96]

# Check for embedding collapse (all embeddings too similar)
student_avg_sim = 0
n_pairs = 0
for i in range(len(student_embs)):
    for j in range(i+1, len(student_embs)):
        student_avg_sim += F.cosine_similarity(student_embs[i:i+1], student_embs[j:j+1]).item()
        n_pairs += 1
student_avg_sim /= max(n_pairs, 1)

print(f"📊 Embedding Diversity Check:")
print(f"   Student avg pairwise similarity: {student_avg_sim:.4f}")
if student_avg_sim > 0.95:
    print(f"   ⚠️ WARNING: Possible embedding collapse! All outputs too similar.")
else:
    print(f"   ✓ Good diversity in student embeddings")

# Compute pairwise similarities for both models
n_images = len(val_images)
teacher_sims = []
student_sims = []
pair_names = []

for i, j in combinations(range(n_images), 2):
    t_sim = F.cosine_similarity(teacher_embs[i:i+1], teacher_embs[j:j+1]).item()
    s_sim = F.cosine_similarity(student_embs[i:i+1], student_embs[j:j+1]).item()
    teacher_sims.append(t_sim)
    student_sims.append(s_sim)
    pair_names.append(f"{image_names[i][:12]} ↔ {image_names[j][:12]}")

teacher_sims = np.array(teacher_sims)
student_sims = np.array(student_sims)

# Compute rank correlation (Spearman) - does student preserve similarity ordering?
from scipy.stats import spearmanr
correlation, p_value = spearmanr(teacher_sims, student_sims)

# Also check if student correctly identifies "similar" vs "different" pairs
# (using median as threshold)
teacher_median = np.median(teacher_sims)
student_median = np.median(student_sims)

teacher_similar = teacher_sims > teacher_median
student_similar = student_sims > student_median
agreement_rate = np.mean(teacher_similar == student_similar)

# Show sample pairs
print(f"\n{'Pair':<35} {'Teacher':>10} {'Student':>10} {'Agree?':>8}")
print("-"*65)
for idx in range(min(10, len(pair_names))):
    t_sim = teacher_sims[idx]
    s_sim = student_sims[idx]
    t_high = "high" if t_sim > teacher_median else "low"
    s_high = "high" if s_sim > student_median else "low"
    agree = "✓" if t_high == s_high else "✗"
    print(f"{pair_names[idx]:<35} {t_sim:>10.4f} {s_sim:>10.4f} {agree:>8}")

if len(pair_names) > 10:
    print(f"... ({len(pair_names) - 10} more pairs)")

print("-"*65)
print(f"\n📊 ALIGNMENT METRICS:")
print(f"   Spearman Correlation: {correlation:.4f} (p={p_value:.2e})")
print(f"   Agreement Rate:       {agreement_rate*100:.1f}%")
print(f"   (Student agrees with teacher on similar/different classification)")

# Pass criteria: 
# - Spearman correlation > 0.5 (moderate positive correlation)
# - Agreement rate > 70%
# - No embedding collapse (avg pairwise sim < 0.95)
milestone2_passed = correlation > 0.5 and agreement_rate > 0.70 and student_avg_sim < 0.95

print(f"\n{'MILESTONE 2: PASSED ✓' if milestone2_passed else 'MILESTONE 2: FAILED ✗'}")
print(f"   (requires: correlation > 0.5 AND agreement > 70% AND no collapse)")
print("="*65)

# Cleanup teacher to free GPU memory
del teacher
torch.cuda.empty_cache()

## Milestone 3: Preprocessing Parity Check

**Critical:** The Go implementation MUST use identical preprocessing to Python.
This cell documents and verifies the exact preprocessing pipeline.

| Parameter | Value | Go Implementation |
|-----------|-------|-------------------|
| Image Size | 224×224 | `imaging.Resize(224, 224, imaging.Lanczos)` |
| Interpolation | Bicubic | `imaging.Lanczos` (closest match) |
| Normalization | ImageNet | mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225] |
| Channel Order | RGB | Standard (not BGR) |
| Data Type | float32 | `float32` |
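
As a reference when porting the normalization and layout steps to Go, the sketch below applies the [0, 1] scaling and ImageNet mean/std from the table to single pixels and reorders a tiny image from HWC to CHW. It is an illustration in plain Python with hypothetical helper names (`normalize_pixel`, `hwc_to_chw`). Python floats are 64-bit, so the values should match a float32 pipeline only to roughly six or seven significant digits.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb):
    """Map one 8-bit RGB pixel to the normalized values the model sees."""
    return [(v / 255.0 - m) / s for v, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)]

def hwc_to_chw(image):
    """Reorder a nested [H][W][C] list into [C][H][W]."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [
        [[image[y][x][c] for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]

# A made-up 1x2 RGB image: one saturated-red-ish pixel and one black pixel.
tiny_hwc = [[normalize_pixel((255, 0, 128)), normalize_pixel((0, 0, 0))]]
tiny_chw = hwc_to_chw(tiny_hwc)

for name, plane in zip("RGB", tiny_chw):
    print(name, [round(v, 4) for v in plane[0]])
```

A Go port can print the same pixels after its own preprocessing and compare the numbers channel by channel before comparing full tensors.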

In [None]:
# =============================================================================
# MILESTONE 3: PREPROCESSING PARITY CHECK
# =============================================================================
# Document and verify the exact preprocessing pipeline for Go parity.

print("="*65)
print("MILESTONE 3: PREPROCESSING PARITY SPECIFICATION")
print("="*65)

# Lock these values - changing them requires retraining!
PREPROCESSING_SPEC = {
    "image_size": IMAGE_SIZE,
    "interpolation": "BICUBIC",  # Go: imaging.Lanczos (closest match)
    "normalization": {
        "mean": IMAGENET_MEAN,
        "std": IMAGENET_STD,
    },
    "channel_order": "RGB",  # Not BGR!
    "data_type": "float32",
    "recursion_steps": RECURSION_STEPS,
    "state_dim": STATE_DIM,
    "hash_dim": HASH_DIM,
}

print("\n📋 LOCKED PREPROCESSING SPECIFICATION:")
print("-"*65)
print(f"  Image Size:      {PREPROCESSING_SPEC['image_size']}×{PREPROCESSING_SPEC['image_size']}")
print(f"  Interpolation:   {PREPROCESSING_SPEC['interpolation']}")
print(f"  Mean:            {PREPROCESSING_SPEC['normalization']['mean']}")
print(f"  Std:             {PREPROCESSING_SPEC['normalization']['std']}")
print(f"  Channel Order:   {PREPROCESSING_SPEC['channel_order']}")
print(f"  Data Type:       {PREPROCESSING_SPEC['data_type']}")
print(f"  Recursion Steps: {PREPROCESSING_SPEC['recursion_steps']}")
print(f"  State Dim:       {PREPROCESSING_SPEC['state_dim']}")
print(f"  Hash Dim:        {PREPROCESSING_SPEC['hash_dim']}")

# Demonstrate preprocessing step-by-step
print("\n📝 STEP-BY-STEP PREPROCESSING (for Go implementation):")
print("-"*65)

test_img = Image.open(test_images[0]).convert("RGB")
print(f"1. Load image as RGB: {test_img.size} → {test_img.mode}")

# Resize
resized = test_img.resize((IMAGE_SIZE, IMAGE_SIZE), Image.BICUBIC)
print(f"2. Resize to {IMAGE_SIZE}×{IMAGE_SIZE} using BICUBIC interpolation")

# To tensor (0-1 range)
import numpy as np
arr = np.array(resized).astype(np.float32) / 255.0
print(f"3. Convert to float32 and scale to [0, 1]: shape={arr.shape}, range=[{arr.min():.3f}, {arr.max():.3f}]")

# Normalize
for c, (m, s) in enumerate(zip(IMAGENET_MEAN, IMAGENET_STD)):
    arr[:, :, c] = (arr[:, :, c] - m) / s
print(f"4. Normalize with ImageNet mean/std: range=[{arr.min():.3f}, {arr.max():.3f}]")

# Transpose to NCHW
arr = arr.transpose(2, 0, 1)  # HWC -> CHW
arr = arr[np.newaxis, ...]   # Add batch dimension
print(f"5. Transpose HWC→CHW and add batch: shape={arr.shape}")

print("\n✅ Go implementation must produce identical tensor!")
print("\n⚠️  CRITICAL: Use imaging.Lanczos in Go (closest to BICUBIC)")
print("="*65)

# Save spec to JSON for Go reference
import json
spec_path = Path("/content/drive/MyDrive/dazzled/outputs/preprocessing_spec.json")
spec_path.parent.mkdir(parents=True, exist_ok=True)
with open(spec_path, "w") as f:
    json.dump(PREPROCESSING_SPEC, f, indent=2)
print(f"\n💾 Saved specification to: {spec_path}")

milestone3_passed = True  # Manual check - specification documented
print(f"\n{'MILESTONE 3: PASSED ✓' if milestone3_passed else 'MILESTONE 3: FAILED ✗'} (specification documented)")

## Milestone 4: ONNX Parity Check

The final and most critical test: verify that the exported ONNX model reproduces the
PyTorch model's outputs to within floating-point tolerance.

**Test:** Run the same images through PyTorch and ONNX Runtime, then compare the outputs.
**Pass Criteria:** Maximum absolute difference < 1e-4.
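
The comparison itself reduces to an element-wise absolute difference between the two output vectors. The following sketch shows that metric on plain Python lists. `parity_report` and its `tol` parameter are illustrative names, the default mirrors the 1e-4 threshold used in the code cell below, and the sample vectors are invented.

```python
def parity_report(reference, candidate, tol=1e-4):
    """Return (max_abs_diff, mean_abs_diff, passed) for two equal-length vectors."""
    if len(reference) != len(candidate):
        raise ValueError("outputs must have the same length")
    diffs = [abs(r - c) for r, c in zip(reference, candidate)]
    max_diff = max(diffs)
    mean_diff = sum(diffs) / len(diffs)
    return max_diff, mean_diff, max_diff < tol

# Invented outputs standing in for the PyTorch and ONNX hash vectors.
pytorch_like = [0.12, -0.40, 0.88]
onnx_like = [0.12, -0.40, 0.881]  # last element is off by about 1e-3

max_d, mean_d, ok = parity_report(pytorch_like, onnx_like)
print(f"max={max_d:.2e} mean={mean_d:.2e} passed={ok}")
```

Gating on the maximum rather than the mean means a single badly mismatched hash dimension is enough to fail the check.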

In [None]:
# =============================================================================
# MILESTONE 4: ONNX PARITY CHECK
# =============================================================================
# Verify PyTorch and ONNX outputs match (max absolute difference < 1e-4).

import onnx
import onnxruntime as ort

print("="*65)
print("MILESTONE 4: ONNX PARITY CHECK")
print("="*65)

# Check if ONNX model exists
if not ONNX_PATH.exists():
    print(f"⚠️  ONNX model not found at {ONNX_PATH}")
    print("   Run the ONNX export cell first!")
    milestone4_passed = False
else:
    # Load ONNX model
    print(f"Loading ONNX model: {ONNX_PATH.name}")
    onnx_model = onnx.load(str(ONNX_PATH))
    onnx.checker.check_model(onnx_model)
    print("✓ ONNX model validation passed")
    
    # Create ONNX runtime session
    session = ort.InferenceSession(str(ONNX_PATH))
    
    # Test multiple images
    parity_results = []
    
    for img_path in test_images[:5]:
        img = Image.open(img_path).convert("RGB")
        x = transform(img).unsqueeze(0)
        
        # PyTorch inference
        with torch.no_grad():
            state_pt = torch.zeros(1, STATE_DIM, device=device)
            for _ in range(RECURSION_STEPS):
                state_pt, hash_pt = student(x.to(device), state_pt)
            pytorch_output = hash_pt.cpu().numpy()
        
        # ONNX inference
        x_np = x.numpy()
        state_np = np.zeros((1, STATE_DIM), dtype=np.float32)
        
        for _ in range(RECURSION_STEPS):
            onnx_outputs = session.run(None, {
                "image": x_np,
                "prev_state": state_np
            })
            state_np = onnx_outputs[0]
        onnx_output = onnx_outputs[1]
        
        # Compare element-wise; pass/fail is decided on the maximum difference
        abs_diff = np.abs(pytorch_output - onnx_output)
        max_diff = abs_diff.max()
        mean_diff = abs_diff.mean()
        
        parity_results.append((img_path.name[:30], max_diff, mean_diff, max_diff < 1e-4))
    
    # Print results
    print(f"\n{'Image':<32} {'Max Diff':>12} {'Mean Diff':>12} {'Status':>8}")
    print("-"*65)
    for name, max_d, mean_d, passed in parity_results:
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f"{name:<32} {max_d:>12.2e} {mean_d:>12.2e} {status:>8}")
    
    milestone4_passed = all(r[3] for r in parity_results)
    print("-"*65)
    print(f"\n{'MILESTONE 4: PASSED ✓' if milestone4_passed else 'MILESTONE 4: FAILED ✗'} (threshold: max_diff < 1e-4)")

print("="*65)

---
## 🚀 GO / NO-GO Decision

Final checkpoint aggregating all milestone results.
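
The aggregation rule is simple: every milestone must have run and passed. A minimal sketch of that rule follows, with a hypothetical `go_no_go` helper and invented results. A value of `None` stands for a milestone cell that was never executed, which is treated as a failure.

```python
def go_no_go(results):
    """results maps milestone name -> True, False, or None (cell never ran)."""
    failed = [name for name, passed in results.items() if not passed]
    decision = "GO" if not failed else "NO-GO"
    return decision, failed

# Invented example: Milestone 4 was skipped, so the overall decision is NO-GO.
example = {
    "Milestone 1": True,
    "Milestone 2": True,
    "Milestone 3": True,
    "Milestone 4": None,
}
decision, failed = go_no_go(example)
print(decision, failed)
```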

In [None]:
# =============================================================================
# GO / NO-GO DECISION
# =============================================================================
# Final checkpoint: Are we ready to lock in the ONNX model?

print("\n" + "="*65)
print("           MODEL VALIDATION SUMMARY - GO/NO-GO DECISION")
print("="*65 + "\n")

milestones = [
    ("Milestone 1", "Recursive Drift Test", "milestone1_passed" in dir() and milestone1_passed),
    ("Milestone 2", "Teacher Alignment", "milestone2_passed" in dir() and milestone2_passed),
    ("Milestone 3", "Preprocessing Spec", "milestone3_passed" in dir() and milestone3_passed),
    ("Milestone 4", "ONNX Parity Check", "milestone4_passed" in dir() and milestone4_passed),
]

print(f"{'Milestone':<14} {'Test':<25} {'Status':>10}")
print("-"*50)

all_passed = True
for name, desc, passed in milestones:
    status = "✓ PASSED" if passed else "✗ FAILED"
    if not passed:
        all_passed = False
    print(f"{name:<14} {desc:<25} {status:>10}")

print("-"*50)
print()

if all_passed:
    print("╔" + "═"*63 + "╗")
    print("║" + " "*20 + "🚀 GO FOR ONNX EXPORT 🚀" + " "*18 + "║")
    print("╠" + "═"*63 + "╣")
    print("║  All milestones passed! You are ready to:                    ║")
    print("║                                                               ║")
    print("║  1. Download ONNX from Drive:                                 ║")
    print("║     /content/drive/MyDrive/dazzled/outputs/models/            ║")
    print("║                                                               ║")
    print("║  2. Copy to Go project:                                       ║")
    print("║     trm_hasher.onnx → bin/                                    ║")
    print("║                                                               ║")
    print("║  3. Update internal/bridge/onnx_runtime.go                    ║")
    print("╚" + "═"*63 + "╝")
else:
    print("╔" + "═"*63 + "╗")
    print("║" + " "*22 + "⛔ NO-GO - FIX ISSUES ⛔" + " "*17 + "║")
    print("╠" + "═"*63 + "╣")
    print("║  Some milestones failed. Address the issues above before      ║")
    print("║  proceeding with ONNX export.                                 ║")
    print("╚" + "═"*63 + "╝")

print("\n" + "="*65)