# 🦖 DaZZLeD: Recursive Hasher Training Notebook

**Goal:** Train a Tiny Recursive Model (TRM) using DINOv3 distillation for adversarially-robust perceptual hashing.

## How to Use This Notebook

1. **Run Cell 1** - Mount Google Drive
2. **Run Cell 2** - Download training datasets (one-time, ~30 min)
3. **Run Cell 3** - Clone repo & install dependencies
4. **Run Cell 4** - Build manifest & train the model
5. **Run remaining cells** - Test and export the trained model

⚠️ **Before training:** Runtime → Change runtime type → **GPU (free T4 or A100)**

## 0. Mount Google Drive

In [1]:
# Mount Google Drive (required for data storage)
from google.colab import drive
drive.mount('/content/drive')

# Create project directories
from pathlib import Path

DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
DATA_ROOT = DRIVE_ROOT / "data"
OUTPUT_ROOT = DRIVE_ROOT / "outputs"

# Create all needed directories
for d in [DATA_ROOT / "ffhq", DATA_ROOT / "openimages", DATA_ROOT / "text",
          OUTPUT_ROOT / "checkpoints", OUTPUT_ROOT / "models",
          DRIVE_ROOT / "manifests"]:
    d.mkdir(parents=True, exist_ok=True)

print(f"✓ Project root: {DRIVE_ROOT}")
print(f"✓ Data root: {DATA_ROOT}")
print(f"✓ Output root: {OUTPUT_ROOT}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✓ Project root: /content/drive/MyDrive/dazzled
✓ Data root: /content/drive/MyDrive/dazzled/data
✓ Output root: /content/drive/MyDrive/dazzled/outputs


## 0.1 Download Training Datasets (Optimized for Speed)

**Strategy:** Train from the VM's local disk (`/content/data`) for speed, and cache the prepared dataset on Drive as a single zip.

- **First run:** Download → process → write a zip backup to Drive (~30 min)
- **Later runs:** Extract the zip from Drive → ready in ~2 min

⚠️ Reading many individual files from Drive during training can be on the order of **100x slower** than reading them from local disk, so always copy the data locally first.
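
The restore-or-build strategy above can be sketched as a small helper. This is an illustration using only the standard library, not the code the next cell runs; `restore_or_build` and its `build` callback are hypothetical names, and the callback stands in for the download-and-process steps.

```python
import shutil
import tempfile
from pathlib import Path
from typing import Callable


def restore_or_build(archive: Path, local_root: Path,
                     build: Callable[[Path], None]) -> str:
    """Fill local_root from a zip cache if one exists, otherwise build and cache it.

    `build` is a hypothetical callback that downloads and processes data into
    local_root. Returns "restored" or "built" so the caller can log which path ran.
    """
    local_root.mkdir(parents=True, exist_ok=True)
    if archive.exists():
        # Fast path: one large sequential read from Drive, then local files only.
        shutil.unpack_archive(str(archive), str(local_root))
        return "restored"
    build(local_root)
    archive.parent.mkdir(parents=True, exist_ok=True)
    # make_archive appends ".zip" itself, so pass the path without the suffix.
    shutil.make_archive(str(archive.with_suffix("")), "zip", str(local_root))
    return "built"


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        tmp = Path(tmp)
        cache = tmp / "drive" / "training-images.zip"

        def fake_build(root: Path) -> None:
            # Stand-in for the real downloads: write one placeholder file.
            (root / "ffhq").mkdir(exist_ok=True)
            (root / "ffhq" / "00000.jpg").write_bytes(b"placeholder")

        print(restore_or_build(cache, tmp / "run1", fake_build))  # built
        print(restore_or_build(cache, tmp / "run2", fake_build))  # restored
        print((tmp / "run2" / "ffhq" / "00000.jpg").exists())     # True
```

In the notebook, `archive` would correspond to `DRIVE_ARCHIVE` and `local_root` to `/content/data`.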

In [None]:
# =============================================================================
# 📦 DOWNLOAD & PREPARE DATASETS
# =============================================================================
# This cell downloads and prepares the training data.
# Total: ~45k images from 3 sources:
#   - FFHQ: Kaggle (face images for identity)
#   - OpenImages: FiftyOne (diverse real-world objects)
#   - MobileViews: HuggingFace parquet
#
# WHY THIS MIX?
#   - FFHQ 40k: Faces require precise hashing (identity preservation)
#   - OpenImages 2.5k: Broad category coverage (animals, vehicles, food)
#   - MobileViews 2k: Edge cases for text/UI (600k available, 2k is enough)
# =============================================================================

import subprocess
import shutil
from pathlib import Path

# Config
DATA_ROOT = Path("/content/data")
DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
DRIVE_ARCHIVE = DRIVE_ROOT / "data-cache/training-images.zip"

EXPECTED_COUNTS = {
    "ffhq": 40000,          # Full dataset
    "openimages": 2500,     # Subset
    "mobileviews": 1500,    # Minimum accepted (the download loop targets 2k)
}

def validate_dataset(data_root: Path, expected: dict) -> tuple[bool, dict]:
    """Validate all datasets have enough images."""
    results = {}
    for name, exp_count in expected.items():
        path = data_root / name
        actual = len(list(path.glob("*.jpg"))) if path.exists() else 0
        results[name] = {"count": actual, "expected": exp_count, "valid": actual >= exp_count * 0.95}
    return all(r["valid"] for r in results.values()), results

# Check what we already have
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# Default: nothing to download (overwritten below if anything is missing).
# Without this, a fully valid cache left `need_download` undefined.
need_download = {name: False for name in EXPECTED_COUNTS}

# Option 1: Restore from Drive cache (fastest)
if DRIVE_ARCHIVE.exists():
    print("🔄 Restoring from Drive cache...")
    shutil.unpack_archive(DRIVE_ARCHIVE, DATA_ROOT)
    all_valid, validation = validate_dataset(DATA_ROOT, EXPECTED_COUNTS)
    if all_valid:
        print("✓ All datasets restored from cache!")
        for name, info in validation.items():
            print(f"  {name}: {info['count']:,} images")
    else:
        print("⚠️ Cache incomplete, will download missing data")
        need_download = {name: not info["valid"] for name, info in validation.items()}
else:
    print("📥 No cache found, downloading all datasets...")
    need_download = {name: True for name in EXPECTED_COUNTS}

# Option 2: Download fresh data
if any(need_download.values()):
    print("\n" + "="*65)
    print("DOWNLOADING DATASETS")
    print("="*65)

    import torchvision.transforms as transforms
    from PIL import Image
    import io

    # -------------------------------------------------------------------------
    # 1. FFHQ via Kaggle
    # -------------------------------------------------------------------------
    if need_download["ffhq"]:
        ffhq_dir = DATA_ROOT / "ffhq"
        ffhq_dir.mkdir(parents=True, exist_ok=True)

        print("\n📥 [1/3] FFHQ via Kaggle")
        print("   Target: 40k high-quality face images")

        # Setup Kaggle credentials from Colab secrets
        from google.colab import userdata
        import os

        try:
            os.environ['KAGGLE_USERNAME'] = userdata.get('KAGGLE_USERNAME')
            os.environ['KAGGLE_KEY'] = userdata.get('KAGGLE_KEY')
        except Exception as e:
            print("⚠️ Kaggle credentials not found in Colab Secrets!")
            print("   To set up:")
            print("   1. Go to kaggle.com → Your Profile → Account → API → Create New Token")
            print("   2. Add to Colab Secrets (🔑 icon on left sidebar):")
            print("      - KAGGLE_USERNAME = your_username")
            print("      - KAGGLE_KEY = your_api_key")
            raise RuntimeError("Kaggle credentials required") from e

        print("   Downloading from Kaggle...")
        !kaggle datasets download -d denislukovnikov/ffhq256-images-only -p /content/ffhq_temp --unzip -q

        # Move and verify
        for src in Path("/content/ffhq_temp").rglob("*.png"):
            src.rename(ffhq_dir / src.name)
        shutil.rmtree("/content/ffhq_temp", ignore_errors=True)

        # Convert to jpg for consistency
        print("   Converting to JPEG...")
        from concurrent.futures import ThreadPoolExecutor

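        def convert_to_jpg(png_path):
            jpg_path = png_path.with_suffix('.jpg')
            try:
                # Context manager closes the PNG handle before it is deleted
                with Image.open(png_path) as im:
                    im.convert('RGB').save(jpg_path, 'JPEG', quality=95)
                png_path.unlink()
            except Exception as e:
                # Keep the PNG and report it instead of failing silently
                print(f"   ⚠️ Could not convert {png_path.name}: {e}")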

        pngs = list(ffhq_dir.glob("*.png"))
        with ThreadPoolExecutor(max_workers=8) as executor:
            list(executor.map(convert_to_jpg, pngs))

        print(f"   ✓ FFHQ: {len(list(ffhq_dir.glob('*.jpg'))):,} images")

    # -------------------------------------------------------------------------
    # 2. OpenImages via FiftyOne
    # -------------------------------------------------------------------------
    if need_download["openimages"]:
        openimages_dir = DATA_ROOT / "openimages"
        openimages_dir.mkdir(parents=True, exist_ok=True)

        print("\n📥 [2/3] OpenImages via FiftyOne")
        print("   Target: 2.5k diverse real-world images")

        !pip install -q fiftyone

        import fiftyone as fo
        import fiftyone.zoo as foz

        dataset = foz.load_zoo_dataset(
            "open-images-v7",
            split="train",
            max_samples=2500,
            shuffle=True,
            seed=42,
            dataset_name="openimages_train"
        )

        resize_tfm = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224)
        ])

        for idx, sample in enumerate(dataset):
            try:
                img = Image.open(sample.filepath).convert("RGB")
                resize_tfm(img).save(openimages_dir / f"openimg_{idx:05d}.jpg", quality=90)
            except Exception:
                pass  # skip images that cannot be opened or decoded

        fo.delete_dataset(dataset.name)
        print(f"   ✓ OpenImages: {len(list(openimages_dir.glob('*.jpg'))):,} images")

    # -------------------------------------------------------------------------
    # 3. MobileViews (parquet - MEMORY-EFFICIENT STREAMING)
    # -------------------------------------------------------------------------
    if need_download["mobileviews"]:
        mobileviews_dir = DATA_ROOT / "mobileviews"
        mobileviews_dir.mkdir(parents=True, exist_ok=True)

        print("\n📥 [3/3] MobileViews via HuggingFace Parquet")
        print("   Target: 2k mobile UI screenshots (edge cases for text/UI)")

        !pip install -q huggingface_hub pyarrow

        from huggingface_hub import hf_hub_download, login
        from google.colab import userdata
        import pyarrow.parquet as pq

        try:
            hf_token = userdata.get('HF_TOKEN')
            if hf_token:
                login(hf_token)
        except Exception:
            pass  # no HF_TOKEN secret: fall back to an anonymous download

        print("   Downloading parquet...")
        parquet_path = hf_hub_download(
            repo_id="mllmTeam/MobileViews",
            filename="MobileViews_Screenshots_ViewHierarchies/Parquets/MobileViews_0-150000.parquet",
            repo_type="dataset",
            local_dir="/content/mobileviews_cache"
        )

        print("   Extracting screenshots (streaming to avoid OOM)...")
        
        # Use ParquetFile.iter_batches() instead of pq.read_table():
        # read_table() materializes all ~150k screenshots in RAM at once (~50GB+),
        # while iter_batches() yields record batches of `batch_size` rows, so peak
        # memory is roughly one batch rather than the whole file.
        parquet_file = pq.ParquetFile(parquet_path)
        total_rows = parquet_file.metadata.num_rows
        
        mv_resize = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224)
        ])

        from tqdm import tqdm
        num_samples = 2000
        step = max(1, total_rows // num_samples)
        
        # Stream through the file, extracting every Nth row
        saved_count = 0
        current_row = 0
        
        for batch in tqdm(parquet_file.iter_batches(batch_size=1000, columns=["image_content"]),
                          desc="   Streaming", total=(total_rows // 1000) + 1):
            images = batch.column("image_content")  # fetch the column once per batch
            for i in range(len(batch)):
                if current_row % step == 0 and saved_count < num_samples:
                    try:
                        img_bytes = images[i].as_py()
                        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                        mv_resize(img).save(mobileviews_dir / f"mobileview_{saved_count:05d}.jpg", quality=90)
                        saved_count += 1
                    except Exception:
                        pass  # skip rows whose bytes fail to decode
                current_row += 1
                
            # Early exit once we have enough samples
            if saved_count >= num_samples:
                break

        shutil.rmtree("/content/mobileviews_cache", ignore_errors=True)
        print(f"   ✓ MobileViews: {len(list(mobileviews_dir.glob('*.jpg'))):,} images")

    # -------------------------------------------------------------------------
    # SAVE BACKUP
    # -------------------------------------------------------------------------
    all_valid, validation = validate_dataset(DATA_ROOT, EXPECTED_COUNTS)

    print("\n" + "="*65)
    print("📋 VALIDATION")
    print("="*65)
    for name, info in validation.items():
        status = "✓" if info["valid"] else "✗"
        print(f"{status} {name:<15} {info['count']:>6,} / {info['expected']:,}")

    if all_valid:
        print(f"\n💾 Creating backup: {DRIVE_ARCHIVE}")
        DRIVE_ARCHIVE.parent.mkdir(parents=True, exist_ok=True)
        if DRIVE_ARCHIVE.exists():
            DRIVE_ARCHIVE.unlink()
        shutil.make_archive(str(DRIVE_ARCHIVE.with_suffix('')), 'zip', DATA_ROOT)
        archive_size = DRIVE_ARCHIVE.stat().st_size / (1024**3)
        print(f"✓ Backup complete! ({archive_size:.2f} GB)")
    else:
        print("\n⚠️ Some datasets incomplete - not saving cache.")

# =============================================================================
# SUMMARY
# =============================================================================
print("\n" + "="*65)
print("📊 DATASET SUMMARY")
print("="*65)
print(f"{'Dataset':<20} {'Role':<25} {'Count':>10}")
print("-"*65)
for subdir, role in [("ffhq", "People (faces)"),
                      ("openimages", "Life (real-world)"),
                      ("mobileviews", "Edge Case (mobile UI)")]:
    path = DATA_ROOT / subdir
    count = len(list(path.glob("*.jpg"))) if path.exists() else 0
    print(f"{subdir:<20} {role:<25} {count:>10,}")
total = sum(1 for _ in DATA_ROOT.rglob("*.jpg"))
print("-"*65)
print(f"{'TOTAL':<20} {'':<25} {total:>10,}")
print("="*65)
print(f"\n📁 Data: {DATA_ROOT}")


🚀 DOWNLOADING DATASETS

✓ [1/3] FFHQ already exists (52,001 images). Skipping download.

📥 [2/3] OpenImages via FiftyOne
   Method: Official Google download (handles AWS S3 shards)
   Target: 15k diverse real-world images
   Downloading from AWS S3...
Downloading split 'validation' to '/root/fiftyone/open-images-v7/validation' if necessary
Downloading 15000 images
 100% |███████████████| 15000/15000 [4.9m elapsed, 0s remaining, 50.6 files/s]
Dataset info written to '/root/fiftyone/open-images-v7/info.json'
Loading 'open-images-v7' split 'validation'
  94% |██████████████-| 14173/15000 [11.4m elapsed, 40.2s remaining, 20.6 samples/s]
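
The MobileViews extraction keeps every `step`-th row, where `step = max(1, total_rows // num_samples)`. The sketch below (a hypothetical helper, `strided_rows`) reproduces which rows the loop attempts, assuming every image decodes. For a file of 150,000 rows, as the parquet filename suggests, `step` is 75 and the last attempted row is 149,925.

```python
def strided_rows(total_rows: int, num_samples: int) -> list[int]:
    """Rows the MobileViews loop attempts, assuming every decode succeeds.

    Mirrors the download cell: step = max(1, total_rows // num_samples), a row is
    attempted when row % step == 0, and the loop stops after num_samples saves.
    """
    step = max(1, total_rows // num_samples)
    return list(range(0, total_rows, step))[:num_samples]


rows = strided_rows(150_000, 2000)
print(len(rows), rows[:3], rows[-1])  # 2000 [0, 75, 150] 149925
print(strided_rows(3, 5))             # [0, 1, 2]  (fewer rows than requested)
```

Because the stride is fixed, a row that fails to decode is simply skipped rather than replaced, so the real loop can finish short of `num_samples`; the lower MobileViews minimum in `EXPECTED_COUNTS` leaves room for that shortfall.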

## 1. Setup & Installation

In [None]:
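# Clone the repo (only needed in Colab). Use an absolute path so re-running this
# cell after the %cd below does not clone a second copy inside ml-core.
import os
REPO_DIR = "/content/DaZZLeD"
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/D13ya/DaZZLeD.git {REPO_DIR}
%cd {REPO_DIR}/ml-core

!pip install -q -r requirements.txt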

In [None]:
import sys
import torch
import torch.nn.functional as F
from pathlib import Path

# Add project root to path
sys.path.insert(0, str(Path.cwd()))

from models.recursive_student import RecursiveHasher

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1.1 Build Manifest & Train Model

⚠️ **Switch to GPU first:** Runtime → Change runtime type → GPU (T4 or A100)

The cells below:
1. Builds a manifest of all training images
2. Trains the RecursiveHasher using DINOv3 distillation
3. Saves checkpoints to Google Drive
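
Because the manifest is stored on Drive but lists files under `/content/data`, which is lost when the Colab VM is recycled, it can outlive the files it points to. A minimal check, assuming one path per line as written by the manifest cell (`check_manifest` is a hypothetical helper):

```python
from pathlib import Path


def check_manifest(manifest: Path) -> tuple[int, list[str]]:
    """Return (number of entries, entries whose file no longer exists).

    Assumes one image path per line, as written by the manifest cell.
    """
    entries = [line.strip() for line in manifest.read_text().splitlines() if line.strip()]
    missing = [p for p in entries if not Path(p).is_file()]
    return len(entries), missing


if __name__ == "__main__":
    manifest_path = Path("/content/drive/MyDrive/dazzled/manifests/train.txt")
    if manifest_path.exists():
        n, missing = check_manifest(manifest_path)
        print(f"{n:,} entries, {len(missing):,} missing")
        if missing:
            print("Re-run the dataset cell to restore /content/data before training.")
```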

In [None]:
# =============================================================================
# BUILD MANIFEST FROM LOCAL DATA
# =============================================================================

from pathlib import Path

# Point to local fast storage
DATA_ROOT = Path("/content/data")
DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")

# Find all training images
exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
paths = [str(p) for p in DATA_ROOT.rglob("*") if p.suffix.lower() in exts]

print(f"Found {len(paths):,} training images in {DATA_ROOT}")

if len(paths) < 100:
    print("⚠️  Not enough images! Run the download cell first.")
else:
    # Write manifest to Drive so it persists, but content points to /content/data
    manifest_path = DRIVE_ROOT / "manifests/train.txt"
    manifest_path.parent.mkdir(parents=True, exist_ok=True)

    with open(manifest_path, "w") as f:
        f.write("\n".join(paths))

    print(f"‚úì Manifest written: {manifest_path}")
    print(f"  (Points to {len(paths)} local files for high-speed training)")

In [None]:
# =============================================================================
# TRAIN THE MODEL (Choose one option)
# =============================================================================

# 🔧 TRAINING OPTIONS - Option B runs by default. To use A or C instead,
# comment out B and uncomment the one you want:

# -----------------------------------------------------------------------------
# OPTION A: Quick Test (5-10 min on T4) - Just to verify everything works
# -----------------------------------------------------------------------------
# !python training/train.py \
#     --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
#     --teacher facebook/dinov3-vitl16-pretrain-lvd1689m \
#     --epochs 1 \
#     --batch-size 16 \
#     --recursion-steps 8 \
#     --max-steps 100 \
#     --amp \
#     --log-interval 10 \
#     --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/checkpoints

# -----------------------------------------------------------------------------
# OPTION B: Full Training (High RAM Optimized)
# -----------------------------------------------------------------------------
# Optimized for 50GB+ RAM:
# --cache-ram: Preloads all images into RAM (fastest IO, uses ~10GB RAM)
# --workers 8: Maximizes CPU usage for augmentation
# -----------------------------------------------------------------------------
!python training/train.py \
    --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
    --teacher facebook/dinov3-vitl16-pretrain-lvd1689m \
    --epochs 5 \
    --batch-size 64 \
    --recursion-steps 16 \
    --grad-accum 2 \
    --lr 1e-4 \
    --amp \
    --allow-tf32 \
    --channels-last \
    --cudnn-benchmark \
    --workers 8 \
    --cache-ram \
    --log-interval 50 \
    --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/checkpoints \
    --checkpoint-every 500

# -----------------------------------------------------------------------------
# OPTION C: Resume from Checkpoint (if session disconnected)
# -----------------------------------------------------------------------------
# !python training/train.py \
#     --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
#     --teacher facebook/dinov3-vitl16-pretrain-lvd1689m \
#     --resume /content/drive/MyDrive/dazzled/outputs/checkpoints/student_epoch_3.safetensors \
#     --epochs 5 \
#     --batch-size 64 \
#     --recursion-steps 16 \
#     --amp \
#     --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/checkpoints

In [None]:
# =============================================================================
# LIST CHECKPOINTS & LOAD TRAINED WEIGHTS
# =============================================================================

from pathlib import Path
import safetensors.torch
from models.recursive_student import RecursiveHasher

CKPT_DIR = Path("/content/drive/MyDrive/dazzled/outputs/checkpoints")

# List available checkpoints, oldest first. Sort by modification time: plain
# sorted() compares names as strings, so "epoch_10" would sort before "epoch_3".
checkpoints = sorted(CKPT_DIR.glob("*.safetensors"), key=lambda p: p.stat().st_mtime)
print(f"Found {len(checkpoints)} checkpoints:")
for ckpt in checkpoints:
    size_mb = ckpt.stat().st_size / (1024 * 1024)
    print(f"  {ckpt.name} ({size_mb:.2f} MB)")

# Load the most recent checkpoint into the model
if checkpoints:
    latest_ckpt = checkpoints[-1]
    print(f"\n📥 Loading: {latest_ckpt.name}")

    # Create the model if no earlier cell has defined it; the constructor
    # arguments must match the ones used for training.
    if "model" not in globals():
        model = RecursiveHasher(state_dim=128, hash_dim=96)

    safetensors.torch.load_model(model, str(latest_ckpt))
    model.eval()
    print("✓ Trained weights loaded!")
else:
    print("⚠️  No checkpoints found. Run training first.")
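
Picking the "latest" checkpoint with `sorted(...)[-1]` on names relies on name order, and `sorted()` compares file names as strings, so `student_epoch_10` sorts before `student_epoch_3`. When names embed epoch or step numbers, a natural-sort key avoids this. The sketch assumes names like `student_epoch_<n>.safetensors`; the names written by `--checkpoint-every` are not shown in this notebook, and `natural_key` is a hypothetical helper.

```python
import re
from pathlib import Path


def natural_key(path: Path) -> tuple:
    """Sort key that compares digit runs as integers ("epoch_10" after "epoch_3").

    re.split with a capturing group alternates text and digits, so odd positions
    are always digit runs and the tuples stay type-consistent when compared.
    """
    parts = re.split(r"(\d+)", path.name)
    return tuple(int(p) if i % 2 else p for i, p in enumerate(parts))


names = [Path(f"student_epoch_{n}.safetensors") for n in (3, 10, 5)]
print(sorted(names)[-1].name)                   # student_epoch_5.safetensors
print(sorted(names, key=natural_key)[-1].name)  # student_epoch_10.safetensors
```

Passing `key=natural_key` to `sorted(CKPT_DIR.glob("*.safetensors"))` would pick the highest-numbered file; sorting by `p.stat().st_mtime` is a simpler choice when the newest file on disk is the one wanted.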

## 2. Model Architecture Test

In [None]:
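# Model hyperparameters (defaults used throughout this notebook)
STATE_DIM = 128
HASH_DIM = 96
RECURSION_STEPS = 16

# Reuse the model loaded above if present; otherwise build an untrained one.
# (Re-creating it unconditionally would discard the trained weights.)
if "model" not in globals():
    model = RecursiveHasher(state_dim=STATE_DIM, hash_dim=HASH_DIM)
model.eval()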

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: ~{total_params * 4 / 1024 / 1024:.2f} MB (float32)")

In [None]:
# Test forward pass with dummy input
batch_size = 4
image_size = 224

dummy_img = torch.randn(batch_size, 3, image_size, image_size)
dummy_state = torch.zeros(batch_size, STATE_DIM)

with torch.no_grad():
    next_state, hash_out = model(dummy_img, dummy_state)

print(f"Input image shape: {dummy_img.shape}")
print(f"Input state shape: {dummy_state.shape}")
print(f"Output state shape: {next_state.shape}")
print(f"Output hash shape: {hash_out.shape}")
print(f"Hash L2 norm (should be ~1.0): {torch.norm(hash_out, dim=1).mean():.4f}")

## 3. Recursive Inference Test

The key idea is to run the same model recursively (16 steps here): each step takes the image together with the previous latent state and returns an updated state plus a refined hash.
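
Written as a recurrence, this restates what `recursive_inference` below computes, with $f_\theta$ denoting one forward call of `RecursiveHasher`:

$$
s_0 = \mathbf{0} \in \mathbb{R}^{128}, \qquad (s_{t+1},\, h_{t+1}) = f_\theta(x,\, s_t), \quad t = 0, 1, \dots, T-1,
$$

where the image $x$ is the same at every step, $h_t \in \mathbb{R}^{96}$, $T = 16$, and the output hash is $h_T$.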

In [None]:
def recursive_inference(model, image, steps=16):
    """Run recursive inference for the specified number of steps."""
    batch_size = image.size(0)
    state = torch.zeros(batch_size, STATE_DIM, device=image.device)

    hashes = []
    with torch.no_grad():
        for step in range(steps):
            state, hash_out = model(image, state)
            hashes.append(hash_out.clone())

    return hashes

# Run recursive inference
hashes = recursive_inference(model, dummy_img, steps=RECURSION_STEPS)

print(f"Generated {len(hashes)} hash vectors")
print(f"Final hash shape: {hashes[-1].shape}")

In [None]:
# Analyze hash stability across recursive steps
import matplotlib.pyplot as plt

# Compute cosine similarity between consecutive steps
similarities = []
for i in range(1, len(hashes)):
    sim = F.cosine_similarity(hashes[i], hashes[i-1], dim=1).mean().item()
    similarities.append(sim)

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(range(1, len(hashes)), similarities, 'b-o')
plt.xlabel('Recursion Step')
plt.ylabel('Cosine Similarity with Previous')
plt.title('Hash Convergence Over Recursion')
plt.ylim(0, 1.1)
plt.grid(True)

# Compute similarity to final hash
final_hash = hashes[-1]
similarities_to_final = []
for h in hashes:
    sim = F.cosine_similarity(h, final_hash, dim=1).mean().item()
    similarities_to_final.append(sim)

plt.subplot(1, 2, 2)
plt.plot(range(len(hashes)), similarities_to_final, 'r-o')
plt.xlabel('Recursion Step')
plt.ylabel('Cosine Similarity to Final Hash')
plt.title('Convergence to Final Hash')
plt.ylim(0, 1.1)
plt.grid(True)

plt.tight_layout()
plt.show()

## 4. Adversarial Robustness Test

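Check whether small perturbations of the input cause large changes in the final hash (after recursive refinement, they should not). Note that this test adds random Gaussian noise to a random input tensor, so it measures stability under noise rather than robustness to a worst-case, gradient-based adversarial perturbation; a real image gives a more representative result.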
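
The epsilons below are standard deviations of Gaussian noise added directly to the model input. Assuming the input is ImageNet-normalized (the transform used in Milestone 1), `Normalize` computes `(x - mean) / std`, so noise of std ε in normalized units is noise of std ε·std in [0, 1] pixel units. A small sketch that converts each ε into 8-bit pixel levels (`noise_in_pixel_units` is a hypothetical helper):

```python
import numpy as np

# ImageNet std used by the Milestone 1 transform (assumed to match training).
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def noise_in_pixel_units(epsilon: float) -> np.ndarray:
    """Per-channel (R, G, B) noise std on the 0-255 scale for N(0, epsilon^2)
    noise added after Normalize: a shift of epsilon in normalized units equals a
    shift of epsilon * std in [0, 1] units, times 255 for 8-bit levels."""
    return epsilon * IMAGENET_STD * 255.0


for eps in [0.001, 0.01, 0.05, 0.1, 0.2, 0.5]:
    r, g, b = noise_in_pixel_units(eps)
    print(f"epsilon={eps:<6} -> R {r:6.2f}  G {g:6.2f}  B {b:6.2f}  (of 255)")
```

For example, ε = 0.1 is roughly 5.8/255 of noise on the red channel, and ε = 0.5 roughly 29/255.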

In [None]:
def test_perturbation_robustness(model, image, epsilon_range, steps=16):
    """Test hash stability under input perturbations."""
    results = []

    # Get baseline hash
    baseline_hashes = recursive_inference(model, image, steps)
    baseline_final = baseline_hashes[-1]

    for epsilon in epsilon_range:
        # Add random noise
        noise = torch.randn_like(image) * epsilon
        perturbed = image + noise

        # Get perturbed hash
        perturbed_hashes = recursive_inference(model, perturbed, steps)
        perturbed_final = perturbed_hashes[-1]

        # Compute similarity
        sim = F.cosine_similarity(baseline_final, perturbed_final, dim=1).mean().item()
        results.append((epsilon, sim))
        print(f"Epsilon={epsilon:.4f}: Cosine Similarity={sim:.4f}")

    return results

# Test with various noise levels
epsilons = [0.001, 0.01, 0.05, 0.1, 0.2, 0.5]
test_img = torch.randn(1, 3, 224, 224)
results = test_perturbation_robustness(model, test_img, epsilons)

In [None]:
# Plot robustness results
epsilons, sims = zip(*results)

plt.figure(figsize=(8, 5))
plt.plot(epsilons, sims, 'g-o', linewidth=2, markersize=8)
plt.axhline(y=0.9, color='r', linestyle='--', label='0.9 threshold')
plt.xlabel('Noise Level (epsilon)', fontsize=12)
plt.ylabel('Cosine Similarity to Original', fontsize=12)
plt.title('Hash Robustness to Input Perturbations', fontsize=14)
plt.xscale('log')
plt.ylim(0, 1.1)
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

## 5. ONNX Export Test

In [None]:
import onnx
import onnxruntime as ort
import numpy as np

# Export to ONNX
onnx_path = "test_model.onnx"

model.eval()
dummy_img = torch.randn(1, 3, 224, 224)
dummy_state = torch.zeros(1, STATE_DIM)

torch.onnx.export(
    model,
    (dummy_img, dummy_state),
    onnx_path,
    input_names=["image", "prev_state"],
    output_names=["next_state", "hash"],
    opset_version=14,
    dynamic_axes={
        "image": {0: "batch"},
        "prev_state": {0: "batch"},
        "next_state": {0: "batch"},
        "hash": {0: "batch"},
    },
)

print(f"Exported ONNX model to {onnx_path}")
print(f"File size: {os.path.getsize(onnx_path) / 1024:.2f} KB")

In [None]:
# Validate ONNX model
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print("ONNX model validation passed!")

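# Test ONNX runtime inference (pass providers explicitly: GPU builds of
# onnxruntime raise an error when none are given)
session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])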

# Run inference
test_img = np.random.default_rng(42).standard_normal((1, 3, 224, 224)).astype(np.float32)
test_state = np.zeros((1, STATE_DIM), dtype=np.float32)

outputs = session.run(None, {"image": test_img, "prev_state": test_state})
next_state_onnx, hash_onnx = outputs

print(f"ONNX output state shape: {next_state_onnx.shape}")
print(f"ONNX output hash shape: {hash_onnx.shape}")
print(f"ONNX hash L2 norm: {np.linalg.norm(hash_onnx, axis=1).mean():.4f}")

In [None]:
# Compare PyTorch vs ONNX outputs
with torch.no_grad():
    pt_state, pt_hash = model(torch.from_numpy(test_img), torch.from_numpy(test_state))

pt_hash_np = pt_hash.numpy()
pt_state_np = pt_state.numpy()

hash_diff = np.abs(pt_hash_np - hash_onnx).max()
state_diff = np.abs(pt_state_np - next_state_onnx).max()

print(f"Max hash difference (PyTorch vs ONNX): {hash_diff:.8f}")
print(f"Max state difference (PyTorch vs ONNX): {state_diff:.8f}")

if hash_diff < 1e-5 and state_diff < 1e-5:
    print("✅ ONNX export matches PyTorch output!")
else:
    print("⚠️ Warning: Numerical differences detected")
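
A single max-absolute-difference threshold can be strict for large-magnitude values. `np.allclose` instead checks `|candidate - reference| <= atol + rtol * |reference|` elementwise. The sketch below reports both; `parity_report` is a hypothetical helper, and its default `atol`/`rtol` values are illustrative assumptions, not thresholds from the project.

```python
import numpy as np


def parity_report(reference: np.ndarray, candidate: np.ndarray,
                  atol: float = 1e-5, rtol: float = 1e-4) -> dict:
    """Compare two output arrays (e.g. PyTorch vs ONNX).

    max_abs is what the cell above thresholds; allclose applies
    |candidate - reference| <= atol + rtol * |reference| elementwise.
    """
    abs_err = np.abs(candidate - reference)
    return {
        "max_abs": float(abs_err.max()),
        "allclose": bool(np.allclose(candidate, reference, rtol=rtol, atol=atol)),
    }


if "hash_onnx" in globals():
    print("hash :", parity_report(pt_hash_np, hash_onnx))
    print("state:", parity_report(pt_state_np, next_state_onnx))
```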

## 6. Latency Benchmark

In [None]:
import time

def benchmark_inference(model, device, num_runs=100, warmup=10):
    """Benchmark recursive inference latency."""
    model = model.to(device)
    model.eval()

    test_img = torch.randn(1, 3, 224, 224).to(device)

    # Warmup
    for _ in range(warmup):
        _ = recursive_inference(model, test_img, steps=16)

    if device.type == 'cuda':
        torch.cuda.synchronize()

    # Benchmark
    times = []
    for _ in range(num_runs):
        if device.type == 'cuda':
            torch.cuda.synchronize()
        start = time.perf_counter()

        _ = recursive_inference(model, test_img, steps=16)

        if device.type == 'cuda':
            torch.cuda.synchronize()
        times.append(time.perf_counter() - start)

    return times

# Benchmark on CPU
cpu_times = benchmark_inference(model, torch.device('cpu'), num_runs=50)
print(f"CPU Latency: {np.mean(cpu_times)*1000:.2f} ± {np.std(cpu_times)*1000:.2f} ms")

# Benchmark on GPU if available
if torch.cuda.is_available():
    gpu_times = benchmark_inference(model, torch.device('cuda'), num_runs=100)
    print(f"GPU Latency: {np.mean(gpu_times)*1000:.2f} ± {np.std(gpu_times)*1000:.2f} ms")
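
Mean ± std can be pulled around by a few slow runs. A sketch that also reports the median and tail percentiles from the same `cpu_times`/`gpu_times` lists (`latency_summary` is a hypothetical helper):

```python
import numpy as np


def latency_summary(times_s) -> dict:
    """Summarize per-run wall times (seconds) in milliseconds.

    The median and tail percentiles are less sensitive than mean ± std to a
    handful of slow runs (for example a garbage-collection pause).
    """
    ms = np.asarray(times_s, dtype=np.float64) * 1000.0
    p50, p95, p99 = np.percentile(ms, [50, 95, 99])
    return {"n": int(ms.size), "mean": float(ms.mean()), "p50": float(p50),
            "p95": float(p95), "p99": float(p99), "max": float(ms.max())}


for label, name in [("CPU", "cpu_times"), ("GPU", "gpu_times")]:
    if name in globals():
        stats = latency_summary(globals()[name])
        print(label, {k: round(v, 2) for k, v in stats.items()})
```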

## 8. Summary

### ✅ This notebook provides a complete workflow:

| Step | Cell | Description |
|------|------|-------------|
| 0 | Mount Drive | Connect Google Drive for data/checkpoints |
| 0.1 | Download Data | Download FFHQ, OpenImages, and MobileViews to local disk; cache a zip on Drive |
| 1 | Setup | Clone repo & install dependencies |
| 1.1 | Train | Build manifest, train with DINOv3 distillation, load the latest checkpoint |
| 2 | Architecture | Verify model structure & forward pass |
| 3 | Recursion | Test recursive inference convergence |
| 4 | Robustness | Perturbation stability testing |
| 5 | ONNX | Export & validate ONNX model |
| 6 | Latency | Benchmark recursive inference on CPU/GPU |
| 7 | Save | Copy ONNX to Drive for Go runtime |

### 📦 Output artifacts (on Google Drive):

```
/content/drive/MyDrive/dazzled/outputs/
├── checkpoints/
│   └── student_epoch_5.safetensors   # Trained weights
└── models/
    └── recursive_hasher.onnx          # ONNX for Go runtime
```

### 🔗 References:
- **DINOv3:** [arXiv:2508.10104](https://arxiv.org/abs/2508.10104)
- **TRM:** [arXiv:2510.04871](https://arxiv.org/abs/2510.04871)
- **Split Accumulation:** [ePrint 2020/1618](https://eprint.iacr.org/2020/1618)

## 7. Export ONNX to Google Drive

Save the trained model to Drive for use in the Go application.

In [None]:
# =============================================================================
# EXPORT TRAINED ONNX MODEL TO GOOGLE DRIVE
# =============================================================================

import shutil
from pathlib import Path

# Source: the ONNX file exported earlier in this notebook
source_onnx = Path("test_model.onnx")

# Destination on Drive
dest_dir = Path("/content/drive/MyDrive/dazzled/outputs/models")
dest_dir.mkdir(parents=True, exist_ok=True)
dest_onnx = dest_dir / "recursive_hasher.onnx"

if source_onnx.exists():
    shutil.copy(source_onnx, dest_onnx)
    size_mb = dest_onnx.stat().st_size / (1024 * 1024)
    print(f"✓ ONNX model saved to Drive: {dest_onnx}")
    print(f"  Size: {size_mb:.2f} MB")

    # Also report the latest checkpoint (newest by modification time)
    ckpt_dir = Path("/content/drive/MyDrive/dazzled/outputs/checkpoints")
    checkpoints = sorted(ckpt_dir.glob("*.safetensors"), key=lambda p: p.stat().st_mtime)
    if checkpoints:
        print("\n📦 Available artifacts on Drive:")
        print(f"  Model: {dest_onnx}")
        print(f"  Checkpoint: {checkpoints[-1]}")
else:
    print("⚠️  ONNX file not found. Run the ONNX export cell in Section 5 first.")

# 🎯 MODEL VALIDATION MILESTONES

Before exporting to ONNX and implementing in Go, the model must pass **all four milestones**.
These tests ensure the recursive student is stable, accurate, and portable.

| Milestone | Test | Success Criteria |
|-----------|------|------------------|
| 1 | Recursive Drift | Emb₁ ≈ Emb₅ (cosine sim > 0.99) |
| 2 | Validation Loss | Plateaued for 3-5 epochs |
| 3 | Preprocessing Parity | Go/Python use identical transforms |
| 4 | ONNX Parity | PyTorch vs ONNX diff < 1e-5 |

## Milestone 1: Recursive Drift Test (Stability)

A recursive model feeds its own output back into itself. If the model is unstable,
errors compound from pass to pass and the embedding can drift into meaningless values within a few passes.

**Test:** Run an image through the student 5 times recursively, carrying the state from one pass into the next.
**Pass Criteria:** The distance between Emb₁ and Emb₅ should be near zero (cosine similarity > 0.99).
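
Because the test L2-normalizes both embeddings before comparing them, the two metrics it prints are tied together. For unit vectors $a$ and $b$:

$$
\lVert a - b \rVert_2^2 = \lVert a \rVert_2^2 + \lVert b \rVert_2^2 - 2\, a^\top b = 2 - 2\cos(a, b)
\quad\Longrightarrow\quad
\cos(a, b) > 0.99 \iff \lVert a - b \rVert_2 < \sqrt{0.02} \approx 0.1414,
$$

so the 0.99 cosine threshold is the same criterion as an L2 distance below about 0.141.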

In [None]:
# =============================================================================
# MILESTONE 1: RECURSIVE DRIFT TEST
# =============================================================================
# Test that the recursive student produces stable embeddings across multiple passes.
# If the model is well-trained, feeding the output back should not cause drift.

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from pathlib import Path
import numpy as np

# Load trained student model
CHECKPOINT_DIR = Path("/content/drive/MyDrive/dazzled/outputs/checkpoints")
ONNX_PATH = Path("/content/drive/MyDrive/dazzled/outputs/student.onnx")

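# Find the latest checkpoint (newest by modification time; sorting by name is
# lexicographic, so "epoch_10" would come before "epoch_3")
checkpoints = sorted(CHECKPOINT_DIR.glob("*.safetensors"), key=lambda p: p.stat().st_mtime)
if not checkpoints:
    raise FileNotFoundError(f"No checkpoints found in {CHECKPOINT_DIR}")

latest_ckpt = checkpoints[-1]
print(f"Loading checkpoint: {latest_ckpt.name}")

# Load model (the models/ package lives in the repo's ml-core directory)
import safetensors.torch
import sys
sys.path.append("/content/DaZZLeD/ml-core")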
from models.recursive_student import RecursiveHasher

STATE_DIM = 128
HASH_DIM = 96
IMAGE_SIZE = 224
RECURSION_STEPS = 16

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student = RecursiveHasher(state_dim=STATE_DIM, hash_dim=HASH_DIM).to(device)
safetensors.torch.load_model(student, str(latest_ckpt))
student.eval()

# Standard ImageNet normalization (MUST match Go implementation)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Load a test image
DATA_ROOT = Path("/content/data")
test_images = list(DATA_ROOT.rglob("*.jpg"))[:5]
if not test_images:
    raise FileNotFoundError("No test images found")

print(f"\n{'='*65}")
print("MILESTONE 1: RECURSIVE DRIFT TEST")
print("="*65)

milestone1_passed = True
for img_path in test_images:
    img = Image.open(img_path).convert("RGB")
    x = transform(img).unsqueeze(0).to(device)
    
    embeddings = []
    with torch.no_grad():
        state = torch.zeros(1, STATE_DIM, device=device)
        for pass_num in range(5):
            # Run through student
            for _ in range(RECURSION_STEPS):
                state, student_hash = student(x, state)
            embeddings.append(F.normalize(student_hash, dim=1).cpu())
    
    # Compare first and last embeddings
    emb1 = embeddings[0]
    emb5 = embeddings[4]
    cosine_sim = F.cosine_similarity(emb1, emb5).item()
    l2_dist = torch.norm(emb1 - emb5).item()
    
    status = "✓ PASS" if cosine_sim > 0.99 else "✗ FAIL"
    if cosine_sim <= 0.99:
        milestone1_passed = False
    
    print(f"{status} {img_path.name[:30]:<30} | Cosine(Emb1,Emb5)={cosine_sim:.6f} | L2={l2_dist:.6f}")

print(f"\n{'MILESTONE 1: PASSED ✓' if milestone1_passed else 'MILESTONE 1: FAILED ✗'}")
print("="*65)

## Milestone 2: Validation Loss Plateau + Visual Inspection

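Confirm that validation loss has plateaued, then check that student embeddings on "hard" images
(blurry faces, text documents, other edge cases) stay close to the teacher embeddings. The cell
below measures only the teacher-student alignment; read the plateau off the training-loss curves.

- **Test:** Compare student and teacher embeddings on a small sample of images (up to 5 per dataset folder). These come from the local data folders and may overlap the training set, so this is a sanity check rather than a held-out evaluation.
- **Pass criteria:** Average student-teacher cosine similarity > 0.95.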

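The pass rule reduces the per-image cosine similarities to an average, with the minimum reported for diagnosis. Below is a minimal NumPy sketch of that reduction. It assumes teacher and student embeddings already live in a space of the same width, because the cosine is undefined otherwise. `alignment_report` is a hypothetical helper, not part of the repo.

```python
import numpy as np


def alignment_report(teacher, student, threshold=0.95):
    """Row-wise cosine similarity between matching teacher/student embeddings.

    teacher, student: arrays of shape (num_images, dim) with the same dim.
    Returns (average, minimum, passed) where passed means average > threshold.
    """
    teacher = np.asarray(teacher, dtype=np.float64)
    student = np.asarray(student, dtype=np.float64)
    if teacher.shape != student.shape:
        raise ValueError(f"shape mismatch: teacher {teacher.shape} vs student {student.shape}")
    t = teacher / np.linalg.norm(teacher, axis=1, keepdims=True)
    s = student / np.linalg.norm(student, axis=1, keepdims=True)
    sims = np.sum(t * s, axis=1)  # one cosine similarity per image
    average = float(sims.mean())
    return average, float(sims.min()), average > threshold


# Two toy "images": one perfectly aligned, one at 45 degrees
teacher = np.array([[1.0, 0.0], [0.0, 1.0]])
student = np.array([[2.0, 0.0], [1.0, 1.0]])
avg, worst, passed = alignment_report(teacher, student)
print(f"avg={avg:.4f} min={worst:.4f} passed={passed}")
```

In the toy case a single badly aligned image pulls the average below 0.95. This is why the notebook prints the minimum alongside the average.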
In [None]:
# =============================================================================
# MILESTONE 2: TEACHER-STUDENT ALIGNMENT
# =============================================================================
# Verify the student embeddings closely match the teacher (DINOv3) embeddings.

from transformers import AutoModel

# Load teacher model
print("Loading DINOv3 teacher model...")
teacher = AutoModel.from_pretrained(
    "facebook/dinov3-vitl16-pretrain-lvd1689m",
    trust_remote_code=True
).to(device)
teacher.eval()

# Select validation images (mix of easy and hard cases)
val_images = []
for subdir in ["ffhq", "openimages", "mobileviews"]:
    subdir_path = DATA_ROOT / subdir
    if subdir_path.exists():
        imgs = list(subdir_path.glob("*.jpg"))[:5]
        val_images.extend(imgs)

if len(val_images) < 5:
    val_images = test_images

print(f"\n{'='*65}")
print("MILESTONE 2: TEACHER-STUDENT ALIGNMENT")
print("="*65)
print(f"Evaluating on {len(val_images)} validation images...")

similarities = []
milestone2_details = []

for img_path in val_images:
    img = Image.open(img_path).convert("RGB")
    x = transform(img).unsqueeze(0).to(device)
    
    with torch.no_grad():
        # Teacher embedding
        teacher_out = teacher(x)
        teacher_emb = F.normalize(teacher_out.last_hidden_state[:, 0], dim=1)
        
        # Student embedding
        state = torch.zeros(1, STATE_DIM, device=device)
        for _ in range(RECURSION_STEPS):
            state, student_hash = student(x, state)
        student_emb = F.normalize(student_hash, dim=1)
        
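        # Compare (cosine similarity needs both embeddings to have the same width;
        # a ViT-L CLS token is much wider than a 96-d hash)
        if teacher_emb.shape[-1] != student_emb.shape[-1]:
            raise ValueError(
                f"Teacher dim {teacher_emb.shape[-1]} != student dim {student_emb.shape[-1]}; "
                "compare in the space the distillation loss used (e.g. after any projection head)."
            )
        cosine_sim = F.cosine_similarity(teacher_emb, student_emb).item()
        similarities.append(cosine_sim)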
        
        status = "✓" if cosine_sim > 0.90 else "⚠" if cosine_sim > 0.80 else "✗"
        milestone2_details.append((img_path.parent.name, img_path.name[:25], cosine_sim, status))

# Print results
print(f"\n{'Category':<15} {'Image':<28} {'Similarity':>12} {'Status'}")
print("-"*65)
for cat, name, sim, status in milestone2_details:
    print(f"{cat:<15} {name:<28} {sim:>12.4f} {status:>6}")

avg_sim = np.mean(similarities)
min_sim = np.min(similarities)
milestone2_passed = avg_sim > 0.95

print("-"*65)
print(f"{'AVERAGE':<44} {avg_sim:>12.4f}")
print(f"{'MINIMUM':<44} {min_sim:>12.4f}")
print(f"\n{'MILESTONE 2: PASSED ✓' if milestone2_passed else 'MILESTONE 2: FAILED ✗'} (threshold: avg > 0.95)")
print("="*65)

# Cleanup teacher to free GPU memory
del teacher
torch.cuda.empty_cache()

## Milestone 3: Preprocessing Parity Check

**Critical:** The Go implementation MUST use preprocessing identical to the Python pipeline.
This cell documents and verifies the exact preprocessing pipeline.

| Parameter | Value | Go Implementation |
|-----------|-------|-------------------|
| Image Size | 224×224 | `imaging.Resize(img, 224, 224, imaging.CatmullRom)` |
| Interpolation | Bicubic (PIL) | `imaging.CatmullRom` (same cubic kernel as PIL's BICUBIC; verify numerically) |
| Scaling | [0, 1] | divide 8-bit values by 255 |
| Normalization | ImageNet | mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225] |
| Channel Order | RGB | Standard (not BGR) |
| Layout | NCHW | `[1, 3, 224, 224]` |
| Data Type | float32 | `float32` |

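The normalization row can be checked by hand. After scaling to [0, 1], a black pixel maps to -mean/std and a white pixel maps to (1 - mean)/std. Every channel of a correctly preprocessed tensor therefore lies within roughly [-2.118, 2.249] (R), [-2.036, 2.429] (G) and [-1.804, 2.640] (B). The short NumPy sketch below computes these bounds; a Go port could assert against the same values. The names `MEAN`, `STD` and `normalize` are illustrative.

```python
import numpy as np

# ImageNet statistics from the table above (RGB order)
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize(rgb_01):
    """ImageNet-normalize RGB values that are already scaled to [0, 1]."""
    return (np.asarray(rgb_01, dtype=np.float32) - MEAN) / STD


black = normalize([0.0, 0.0, 0.0])  # smallest value each channel can take
white = normalize([1.0, 1.0, 1.0])  # largest value each channel can take
print("black:", black)
print("white:", white)
```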
In [None]:
# =============================================================================
# MILESTONE 3: PREPROCESSING PARITY CHECK
# =============================================================================
# Document and verify the exact preprocessing pipeline for Go parity.

print("="*65)
print("MILESTONE 3: PREPROCESSING PARITY SPECIFICATION")
print("="*65)

# Lock these values - changing them requires retraining!
PREPROCESSING_SPEC = {
    "image_size": IMAGE_SIZE,
    "interpolation": "BICUBIC",  # Go: imaging.CatmullRom (same a=-0.5 cubic kernel as PIL BICUBIC)
    "normalization": {
        "mean": IMAGENET_MEAN,
        "std": IMAGENET_STD,
    },
    "channel_order": "RGB",  # Not BGR!
    "data_type": "float32",
    "recursion_steps": RECURSION_STEPS,
    "state_dim": STATE_DIM,
    "hash_dim": HASH_DIM,
}

print("\n📋 LOCKED PREPROCESSING SPECIFICATION:")
print("-"*65)
print(f"  Image Size:      {PREPROCESSING_SPEC['image_size']}×{PREPROCESSING_SPEC['image_size']}")
print(f"  Interpolation:   {PREPROCESSING_SPEC['interpolation']}")
print(f"  Mean:            {PREPROCESSING_SPEC['normalization']['mean']}")
print(f"  Std:             {PREPROCESSING_SPEC['normalization']['std']}")
print(f"  Channel Order:   {PREPROCESSING_SPEC['channel_order']}")
print(f"  Data Type:       {PREPROCESSING_SPEC['data_type']}")
print(f"  Recursion Steps: {PREPROCESSING_SPEC['recursion_steps']}")
print(f"  State Dim:       {PREPROCESSING_SPEC['state_dim']}")
print(f"  Hash Dim:        {PREPROCESSING_SPEC['hash_dim']}")

# Demonstrate preprocessing step-by-step
print("\n📝 STEP-BY-STEP PREPROCESSING (for Go implementation):")
print("-"*65)

test_img = Image.open(test_images[0]).convert("RGB")
print(f"1. Load image as RGB: {test_img.size} → {test_img.mode}")

# Resize
resized = test_img.resize((IMAGE_SIZE, IMAGE_SIZE), Image.BICUBIC)
print(f"2. Resize to {IMAGE_SIZE}×{IMAGE_SIZE} using BICUBIC interpolation")

# To tensor (0-1 range)
arr = np.array(resized).astype(np.float32) / 255.0
print(f"3. Convert to float32 and scale to [0, 1]: shape={arr.shape}, range=[{arr.min():.3f}, {arr.max():.3f}]")

# Normalize
for c, (m, s) in enumerate(zip(IMAGENET_MEAN, IMAGENET_STD)):
    arr[:, :, c] = (arr[:, :, c] - m) / s
print(f"4. Normalize with ImageNet mean/std: range=[{arr.min():.3f}, {arr.max():.3f}]")

# Transpose to NCHW
arr = arr.transpose(2, 0, 1)  # HWC -> CHW
arr = arr[np.newaxis, ...]   # Add batch dimension
print(f"5. Transpose HWC→CHW and add batch: shape={arr.shape}")

print("\n✅ Go implementation must produce an identical tensor!")
print("\n⚠️  CRITICAL: Use imaging.CatmullRom in Go (same cubic kernel as PIL BICUBIC)")
print("="*65)

# Save spec to JSON for Go reference
import json
spec_path = Path("/content/drive/MyDrive/dazzled/outputs/preprocessing_spec.json")
spec_path.parent.mkdir(parents=True, exist_ok=True)
with open(spec_path, "w") as f:
    json.dump(PREPROCESSING_SPEC, f, indent=2)
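print(f"\n💾 Saved specification to: {spec_path}")

# Verify that the step-by-step NumPy pipeline reproduces the torchvision transform
reference = transform(test_img).unsqueeze(0).numpy()
preprocess_max_diff = float(np.abs(arr - reference).max())
print(f"\nNumPy vs torchvision max abs diff: {preprocess_max_diff:.2e}")

milestone3_passed = preprocess_max_diff < 1e-5
print(f"\n{'MILESTONE 3: PASSED ✓' if milestone3_passed else 'MILESTONE 3: FAILED ✗'} (NumPy pipeline matches torchvision transform)")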
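The requirement that the Go port reproduce this tensor is easiest to enforce with a file both sides can read. One option is to dump the preprocessed NCHW tensor as raw little-endian float32 bytes. The Go port would then load the same bytes and diff them element-wise against its own tensor. This is a sketch under assumptions: the file name `golden_input.f32` and the helper names are hypothetical, and a random array stands in for the real `arr`.

```python
import tempfile
from pathlib import Path

import numpy as np


def save_golden(tensor, path):
    """Write a tensor as raw little-endian float32 bytes in C (row-major) order."""
    data = np.ascontiguousarray(tensor, dtype="<f4")
    Path(path).write_bytes(data.tobytes())
    return data.shape


def load_golden(path, shape):
    """Read raw little-endian float32 bytes back into an array of the given shape."""
    return np.frombuffer(Path(path).read_bytes(), dtype="<f4").reshape(shape)


def max_abs_diff(a, b):
    return float(np.max(np.abs(np.asarray(a, np.float32) - np.asarray(b, np.float32))))


# Stand-in for the preprocessed (1, 3, 224, 224) tensor
rng = np.random.default_rng(0)
tensor = rng.standard_normal((1, 3, 224, 224)).astype(np.float32)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "golden_input.f32"  # hypothetical file name
    shape = save_golden(tensor, path)
    restored = load_golden(path, shape)
    size_bytes = path.stat().st_size

print(shape, size_bytes, max_abs_diff(tensor, restored))
```

Raw bytes keep the Go side dependency-free, since it only needs to decode little-endian float32 values in NCHW order. The shape must be communicated separately, for example alongside `preprocessing_spec.json`.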

## Milestone 4: ONNX Parity Check

The final and most critical test: verify that the exported ONNX model produces
the same outputs as the PyTorch model, up to floating-point tolerance.

- **Test:** Run the same images through PyTorch and ONNX Runtime for `RECURSION_STEPS` steps and compare the final hash outputs.
- **Pass criteria:** Maximum absolute difference < 1e-4 on every test image.
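The pass rule is an element-wise comparison with an absolute tolerance rather than exact equality. PyTorch and ONNX Runtime can use different kernels and float32 accumulation orders, so exact equality is not expected. Below is a standalone sketch of that comparison. `parity_stats` is a hypothetical helper name, and its 1e-4 default mirrors the threshold the parity cell applies.

```python
import numpy as np


def parity_stats(reference, candidate, atol=1e-4):
    """Element-wise comparison of two output arrays.

    Returns (max_abs_diff, mean_abs_diff, passed), where passed means max_abs_diff < atol.
    """
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    diff = np.abs(reference - candidate)
    max_diff = float(diff.max())
    return max_diff, float(diff.mean()), max_diff < atol


ref = np.array([[0.10, -0.20, 0.30]], dtype=np.float32)
close = ref + np.array([[0.0, 5e-5, -2e-5]], dtype=np.float32)  # within tolerance
far = ref + np.float32(2e-4)                                     # outside tolerance
print(parity_stats(ref, close))
print(parity_stats(ref, far))
```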

In [None]:
# =============================================================================
# MILESTONE 4: ONNX PARITY CHECK
# =============================================================================
# Verify PyTorch and ONNX outputs match to within float32 tolerance (max abs diff < 1e-4).

import copy

import onnx
import onnxruntime as ort

ONNX_ATOL = 1e-4  # pass threshold on the maximum absolute difference

print("="*65)
print("MILESTONE 4: ONNX PARITY CHECK")
print("="*65)

# Check if ONNX model exists
if not ONNX_PATH.exists():
    print(f"⚠️  ONNX model not found at {ONNX_PATH}")
    print("   Run the ONNX export cell first!")
    milestone4_passed = False
else:
    # Load and validate the ONNX graph
    print(f"Loading ONNX model: {ONNX_PATH.name}")
    onnx_model = onnx.load(str(ONNX_PATH))
    onnx.checker.check_model(onnx_model)
    print("✓ ONNX model validation passed")

    # Create an ONNX Runtime session on CPU (passing providers explicitly is
    # required when the onnxruntime-gpu package is installed)
    session = ort.InferenceSession(str(ONNX_PATH), providers=["CPUExecutionProvider"])

    # Run the PyTorch reference on CPU too, so GPU-specific numerics
    # (e.g. TF32 convolutions on Ampere GPUs) don't leak into the comparison
    student_cpu = copy.deepcopy(student).cpu().eval()

    # Test multiple images
    parity_results = []

    for img_path in test_images[:5]:
        img = Image.open(img_path).convert("RGB")
        x = transform(img).unsqueeze(0)

        # PyTorch inference
        with torch.no_grad():
            state_pt = torch.zeros(1, STATE_DIM)
            for _ in range(RECURSION_STEPS):
                state_pt, hash_pt = student_cpu(x, state_pt)
            pytorch_output = hash_pt.numpy()

        # ONNX inference (input names must match those used at export time)
        x_np = x.numpy()
        state_np = np.zeros((1, STATE_DIM), dtype=np.float32)

        for _ in range(RECURSION_STEPS):
            onnx_outputs = session.run(None, {
                "image": x_np,
                "prev_state": state_np
            })
            state_np = onnx_outputs[0]  # next state, fed back in
        onnx_output = onnx_outputs[1]   # hash from the final step

        # Compare
        abs_diff = np.abs(pytorch_output - onnx_output)
        max_diff = abs_diff.max()
        parity_results.append((img_path.name[:30], max_diff, abs_diff.mean(), max_diff < ONNX_ATOL))

    # Print results
    print(f"\n{'Image':<32} {'Max Diff':>12} {'Mean Diff':>12} {'Status':>8}")
    print("-"*65)
    for name, max_d, mean_d, passed in parity_results:
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f"{name:<32} {max_d:>12.2e} {mean_d:>12.2e} {status:>8}")

    milestone4_passed = all(r[3] for r in parity_results)
    print("-"*65)
    print(f"\n{'MILESTONE 4: PASSED ✓' if milestone4_passed else 'MILESTONE 4: FAILED ✗'} (threshold: max_diff < {ONNX_ATOL:.0e})")

print("="*65)

---
## üöÄ GO / NO-GO Decision

Final checkpoint aggregating all milestone results.
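The decision is a strict AND over the four milestone flags. A flag that was never set, because its cell did not run, counts as a failure. Below is the same rule as a small standalone function. `go_no_go` and `REQUIRED_FLAGS` are hypothetical names, not part of the repo.

```python
REQUIRED_FLAGS = ("milestone1_passed", "milestone2_passed",
                  "milestone3_passed", "milestone4_passed")


def go_no_go(flags, required=REQUIRED_FLAGS):
    """Return (go, failed): go is True only if every required flag is present and truthy."""
    failed = [name for name in required if not flags.get(name, False)]
    return not failed, failed


# Milestone 4 never ran, so its flag is missing and the decision is NO-GO
go, failed = go_no_go({"milestone1_passed": True, "milestone2_passed": True,
                       "milestone3_passed": True})
print(go, failed)
```

Inside the notebook it could be called as `go_no_go(globals())`, since the milestone cells store their results as module-level variables.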

In [None]:
# =============================================================================
# GO / NO-GO DECISION
# =============================================================================
# Final checkpoint: Are we ready to lock in the ONNX model?

print("\n" + "="*65)
print("           MODEL VALIDATION SUMMARY - GO/NO-GO DECISION")
print("="*65 + "\n")

# Look up each milestone flag; a milestone whose cell was not run counts as failed
milestones = [
    ("Milestone 1", "Recursive Drift Test", globals().get("milestone1_passed", False)),
    ("Milestone 2", "Teacher Alignment", globals().get("milestone2_passed", False)),
    ("Milestone 3", "Preprocessing Spec", globals().get("milestone3_passed", False)),
    ("Milestone 4", "ONNX Parity Check", globals().get("milestone4_passed", False)),
]

print(f"{'Milestone':<14} {'Test':<25} {'Status':>10}")
print("-"*50)

all_passed = True
for name, desc, passed in milestones:
    status = "✓ PASSED" if passed else "✗ FAILED"
    if not passed:
        all_passed = False
    print(f"{name:<14} {desc:<25} {status:>10}")

print("-"*50)
print()

if all_passed:
    print("╔" + "═"*63 + "╗")
    print("║" + " "*20 + "🚀 GO FOR ONNX EXPORT 🚀" + " "*18 + "║")
    print("╠" + "═"*63 + "╣")
    print("║  All milestones passed! You are ready to:                     ║")
    print("║                                                               ║")
    print("║  1. Export final ONNX model:                                  ║")
    print("║     python ml-core/training/export_onnx.py                    ║")
    print("║                                                               ║")
    print("║  2. Copy to Go project:                                       ║")
    print("║     cp outputs/recursive_hasher.onnx bin/                     ║")
    print("║                                                               ║")
    print("║  3. Update internal/bridge/onnx_runtime.go                    ║")
    print("╚" + "═"*63 + "╝")
else:
    print("╔" + "═"*63 + "╗")
    print("║" + " "*22 + "⛔ NO-GO - FIX ISSUES ⛔" + " "*17 + "║")
    print("╠" + "═"*63 + "╣")
    print("║  Some milestones failed. Address the issues above before      ║")
    print("║  proceeding with ONNX export.                                 ║")
    print("╚" + "═"*63 + "╝")

print("\n" + "="*65)