# 🦖 DaZZLeD: Recursive Hasher Training Notebook

**Goal:** Train a Tiny Recursive Model (TRM) using DINOv3 distillation for adversarially-robust perceptual hashing.

## How to Use This Notebook

1. **Run Cell 1** - Mount Google Drive
2. **Run Cell 2** - Download training datasets (one-time, ~30 min)
3. **Run Cell 3** - Clone repo & install dependencies
4. **Run Cell 4** - Build manifest & train the model
5. **Run remaining cells** - Test and export the trained model

⚠️ **Before training:** Runtime → Change runtime type → **GPU (T4 free or A100)**

## 0. Mount Google Drive

In [None]:
# Mount Google Drive (required for data storage)
from google.colab import drive
drive.mount('/content/drive')

# Create project directories
from pathlib import Path

DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
DATA_ROOT = DRIVE_ROOT / "data"
OUTPUT_ROOT = DRIVE_ROOT / "outputs"

# Create all needed directories
for d in [DATA_ROOT / "ffhq", DATA_ROOT / "openimages", DATA_ROOT / "text",
          OUTPUT_ROOT / "checkpoints", OUTPUT_ROOT / "models",
          DRIVE_ROOT / "manifests"]:
    d.mkdir(parents=True, exist_ok=True)

print(f"✓ Project root: {DRIVE_ROOT}")
print(f"✓ Data root: {DATA_ROOT}")
print(f"✓ Output root: {OUTPUT_ROOT}")

## 0.1 Download Training Datasets (Optimized for Speed)

**Strategy:** Use local disk (`/content/data`) for training speed, then cache to Drive as a zip.

- **First run:** Downloads data → processes → creates zip backup on Drive (~30 min)
- **Future runs:** Extracts from Drive zip → ready in ~2 min

⚠️ Accessing individual files from Drive during training can be on the order of **100x slower** than local disk, so always train from `/content/data`.

In [None]:
# =============================================================================
# DOWNLOAD & PREPARE DATASET (Optimized for Colab I/O)
#
# STANDARD APPROACHES:
#   - FFHQ: Kaggle bulk download → ALL 70k images (more data = better model)
#   - OpenImages: FiftyOne (official Google method)
#   - MobileViews: HuggingFace parquet
#
# WHY THESE COUNTS:
#   - FFHQ 70k: Full dataset, 1024x1024 faces → resize to 224
#   - OpenImages 15k: Diverse real-world (validation split has 41k, we sample 15k)
#   - MobileViews 2k: Edge cases for text/UI (600k available, 2k is enough)
#   - Total: ~87k images for distillation training
# =============================================================================

import os
import shutil
import io
from pathlib import Path
from PIL import Image
from torchvision import transforms

# =============================================================================
# ⚙️ CONFIGURATION
# =============================================================================
FORCE_REDOWNLOAD = False  # Set True to bypass cache

# Expected counts - USING FULL FFHQ (70k)
EXPECTED_COUNTS = {
    "ffhq": 60000,        # Target: 70k, allow ~85% success (some corrupt)
    "openimages": 12000,  # Target: 15k
    "mobileviews": 1500,  # Target: 2k
}

# Paths
DATA_ROOT = Path("/content/data")
DATA_ROOT.mkdir(parents=True, exist_ok=True)
DRIVE_ARCHIVE = Path("/content/drive/MyDrive/dazzled/dazzled_dataset_v5.zip")  # v5 = full FFHQ

# =============================================================================
# VALIDATION HELPER
# =============================================================================
def validate_dataset(data_root, expected_counts):
    results = {}
    all_valid = True
    for name, min_count in expected_counts.items():
        path = data_root / name
        count = len(list(path.glob("*.jpg"))) + len(list(path.glob("*.png"))) if path.exists() else 0
        valid = count >= min_count
        results[name] = {"count": count, "expected": min_count, "valid": valid}
        if not valid:
            all_valid = False
    return all_valid, results

# =============================================================================
# CHECK CACHE
# =============================================================================
need_download = {"ffhq": True, "openimages": True, "mobileviews": True}

if FORCE_REDOWNLOAD:
    print("🔄 FORCE_REDOWNLOAD=True - Starting fresh...")
    shutil.rmtree(DATA_ROOT, ignore_errors=True)
    DATA_ROOT.mkdir(parents=True, exist_ok=True)
elif DRIVE_ARCHIVE.exists():
    print(f"📦 Found cache: {DRIVE_ARCHIVE}")
    shutil.unpack_archive(DRIVE_ARCHIVE, DATA_ROOT)
    all_valid, validation = validate_dataset(DATA_ROOT, EXPECTED_COUNTS)
    print("\n   Validation:")
    for name, info in validation.items():
        status = "✓" if info["valid"] else "✗"
        print(f"   {status} {name}: {info['count']:,} / {info['expected']:,}")
        need_download[name] = not info["valid"]
    if all_valid:
        print("\n✓ All datasets ready!")

# =============================================================================
# DOWNLOAD MISSING DATASETS
# =============================================================================
if any(need_download.values()):
    print("\n" + "="*65)
    print("🚀 DOWNLOADING DATASETS")
    print("="*65)

    # -------------------------------------------------------------------------
    # 1. FFHQ via Kaggle - FULL 70k DATASET
    #    Why Kaggle: Single 89GB zip, fast CDN, no rate limits
    #    Why 70k: More data = better distillation. We resize to 224x224.
    # -------------------------------------------------------------------------
    if need_download["ffhq"]:
        ffhq_dir = DATA_ROOT / "ffhq"
        ffhq_dir.mkdir(parents=True, exist_ok=True)
        
        print("\n📥 [1/3] FFHQ via Kaggle - FULL 70k DATASET")
        print("   Source: arnaud58/flickrfaceshq-dataset-ffhq (89GB → 70k × 1024×1024)")
        print("   Output: 70k × 224×224 (~3GB after resize)")
        print("   Time estimate: ~15 min download + ~10 min resize")
        
        # Setup Kaggle credentials
        from google.colab import userdata
        kaggle_ok = False
        try:
            kaggle_username = userdata.get('KAGGLE_USERNAME')
            kaggle_key = userdata.get('KAGGLE_KEY')
            
            # Write credentials with json.dump so special characters are escaped
            import json
            os.makedirs('/root/.kaggle', exist_ok=True)
            with open('/root/.kaggle/kaggle.json', 'w') as f:
                json.dump({"username": kaggle_username, "key": kaggle_key}, f)
            os.chmod('/root/.kaggle/kaggle.json', 0o600)
            print("   ✓ Kaggle credentials configured")
            kaggle_ok = True
        except Exception as e:
            print(f"\n   ❌ Kaggle credentials not found ({type(e).__name__})!")
            print("   To set up:")
            print("   1. Go to kaggle.com → Your Profile → Account → API → Create New Token")
            print("   2. Add to Colab Secrets (🔑 icon on left sidebar):")
            print("      - KAGGLE_USERNAME = your_username")
            print("      - KAGGLE_KEY = your_api_key")
        
        if kaggle_ok:
            print("\n   Downloading from Kaggle CDN...")
            !pip install -q kaggle
            !kaggle datasets download -d arnaud58/flickrfaceshq-dataset-ffhq -p /content/ffhq_download --unzip
            
            # Process ALL images (not sampling)
            print("\n   Resizing all images to 224×224...")
            from tqdm import tqdm
            resize = transforms.Resize((224, 224))
            
            ffhq_source = Path("/content/ffhq_download")
            all_images = sorted(list(ffhq_source.rglob("*.png")) + list(ffhq_source.rglob("*.jpg")))
            print(f"   Found {len(all_images):,} images")
            
            failed = 0
            for i, img_path in enumerate(tqdm(all_images, desc="   Resizing")):
                try:
                    img = Image.open(img_path).convert("RGB")
                    resize(img).save(ffhq_dir / f"ffhq_{i:05d}.jpg", quality=90)
                except Exception:
                    # Skip unreadable or corrupt files, but keep count of them
                    failed += 1
            if failed:
                print(f"   Skipped {failed:,} unreadable images")
            
            # Cleanup (important - 89GB!)
            print("   Cleaning up original 1024×1024 images...")
            shutil.rmtree("/content/ffhq_download", ignore_errors=True)
            print(f"   ✓ FFHQ: {len(list(ffhq_dir.glob('*.jpg'))):,} images")

    # -------------------------------------------------------------------------
    # 2. OpenImages via FiftyOne (official method, bug workaround)
    # -------------------------------------------------------------------------
    if need_download["openimages"]:
        openimages_dir = DATA_ROOT / "openimages"
        openimages_dir.mkdir(parents=True, exist_ok=True)
        
        print("\n📥 [2/3] OpenImages via FiftyOne")
        print("   Method: Official Google download (handles AWS S3 shards)")
        print("   Target: 15k diverse real-world images")
        
        !pip install -q fiftyone
        
        import fiftyone as fo
        import fiftyone.zoo as foz
        
        print("   Downloading from AWS S3...")
        # Workaround: do NOT set dataset_dir (FiftyOne bug)
        dataset = foz.load_zoo_dataset(
            "open-images-v7",
            split="validation",
            max_samples=15000,
            shuffle=True,
            seed=42
        )
        
        print("   Copying images to openimages_dir...")
        from tqdm import tqdm
        oi_resize = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224)
        ])
        
        failed = 0
        for i, sample in enumerate(tqdm(dataset, desc="   Resizing")):
            try:
                img = Image.open(sample.filepath).convert("RGB")
                oi_resize(img).save(openimages_dir / f"openimages_{i:05d}.jpg", quality=90)
            except Exception:
                # Skip unreadable or corrupt files, but keep count of them
                failed += 1
        if failed:
            print(f"   Skipped {failed:,} unreadable images")
        
        fo.delete_dataset(dataset.name)
        print(f"   ✓ OpenImages: {len(list(openimages_dir.glob('*.jpg'))):,} images")

    # -------------------------------------------------------------------------
    # 3. MobileViews (parquet - already optimal)
    # -------------------------------------------------------------------------
    if need_download["mobileviews"]:
        mobileviews_dir = DATA_ROOT / "mobileviews"
        mobileviews_dir.mkdir(parents=True, exist_ok=True)
        
        print("\n📥 [3/3] MobileViews via HuggingFace Parquet")
        print("   Target: 2k mobile UI screenshots (edge cases for text/UI)")
        
        !pip install -q huggingface_hub pyarrow
        
        from huggingface_hub import hf_hub_download, login
        from google.colab import userdata
        import pyarrow.parquet as pq
        
        try:
            hf_token = userdata.get('HF_TOKEN')
            if hf_token:
                login(hf_token)
        except Exception:
            # No usable HF_TOKEN secret; continue and attempt an anonymous download
            pass
        
        print("   Downloading parquet...")
        parquet_path = hf_hub_download(
            repo_id="mllmTeam/MobileViews",
            filename="MobileViews_Screenshots_ViewHierarchies/Parquets/MobileViews_0-150000.parquet",
            repo_type="dataset",
            local_dir="/content/mobileviews_cache"
        )
        
        print("   Extracting screenshots...")
        table = pq.read_table(parquet_path, columns=["image_content"])
        
        mv_resize = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224)
        ])
        
        from tqdm import tqdm
        num_samples = 2000
        step = max(1, len(table) // num_samples)
        
        for idx, i in enumerate(tqdm(range(0, len(table), step), desc="   Processing", total=num_samples)):
            if idx >= num_samples:
                break
            try:
                img_bytes = table.column("image_content")[i].as_py()
                img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                mv_resize(img).save(mobileviews_dir / f"mobileview_{idx:05d}.jpg", quality=90)
            except Exception:
                # Skip rows whose image bytes cannot be decoded
                pass
        
        shutil.rmtree("/content/mobileviews_cache", ignore_errors=True)
        print(f"   ✓ MobileViews: {len(list(mobileviews_dir.glob('*.jpg'))):,} images")

    # -------------------------------------------------------------------------
    # SAVE BACKUP
    # -------------------------------------------------------------------------
    all_valid, validation = validate_dataset(DATA_ROOT, EXPECTED_COUNTS)
    
    print("\n" + "="*65)
    print("📋 VALIDATION")
    print("="*65)
    for name, info in validation.items():
        status = "✓" if info["valid"] else "✗"
        print(f"{status} {name:<15} {info['count']:>6,} / {info['expected']:,}")
    
    if all_valid:
        print(f"\n💾 Creating backup: {DRIVE_ARCHIVE}")
        DRIVE_ARCHIVE.parent.mkdir(parents=True, exist_ok=True)
        if DRIVE_ARCHIVE.exists():
            DRIVE_ARCHIVE.unlink()
        shutil.make_archive(str(DRIVE_ARCHIVE.with_suffix('')), 'zip', DATA_ROOT)
        archive_size = DRIVE_ARCHIVE.stat().st_size / (1024**3)
        print(f"✓ Backup complete! ({archive_size:.2f} GB)")
    else:
        print("\n⚠️ Some datasets incomplete - not saving cache.")

# =============================================================================
# SUMMARY
# =============================================================================
print("\n" + "="*65)
print("📊 DATASET SUMMARY")
print("="*65)
print(f"{'Dataset':<20} {'Role':<25} {'Count':>10}")
print("-"*65)
for subdir, role in [("ffhq", "People (faces)"),
                      ("openimages", "Life (real-world)"),
                      ("mobileviews", "Edge Case (mobile UI)")]:
    path = DATA_ROOT / subdir
    count = len(list(path.glob("*.jpg"))) if path.exists() else 0
    print(f"{subdir:<20} {role:<25} {count:>10,}")
total = sum(1 for _ in DATA_ROOT.rglob("*.jpg"))
print("-"*65)
print(f"{'TOTAL':<20} {'':<25} {total:>10,}")
print("="*65)
print(f"\n📁 Data: {DATA_ROOT}")
print(f"💾 Cache: {DRIVE_ARCHIVE}")

## 1. Setup & Installation

In [None]:
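# Clone the repo (only needed in Colab)
# An absolute path keeps this cell safe to re-run: with a relative path, a
# second run from inside DaZZLeD/ml-core would clone a nested copy.
import os
REPO_DIR = '/content/DaZZLeD'
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/D13ya/DaZZLeD.git {REPO_DIR}
%cd {REPO_DIR}/ml-core

!pip install -q -r requirements.txt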

In [None]:
import sys
import torch
import torch.nn.functional as F
from pathlib import Path

# Add project root to path
sys.path.insert(0, str(Path.cwd()))

from models.recursive_student import RecursiveHasher

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1.1 Build Manifest & Train Model

⚠️ **Switch to GPU first:** Runtime → Change runtime type → GPU (T4 or A100)

This cell:
1. Builds a manifest of all training images
2. Trains the RecursiveHasher using DINOv3 distillation
3. Saves checkpoints to Google Drive

In [None]:
# =============================================================================
# BUILD MANIFEST FROM LOCAL DATA
# =============================================================================

from pathlib import Path

# Point to local fast storage
DATA_ROOT = Path("/content/data")
DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")

# Find all training images
exts = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
paths = [str(p) for p in DATA_ROOT.rglob("*") if p.suffix.lower() in exts]

print(f"Found {len(paths):,} training images in {DATA_ROOT}")

if len(paths) < 100:
    print("⚠️  Not enough images! Run the download cell first.")
else:
    # Write manifest to Drive so it persists, but content points to /content/data
    manifest_path = DRIVE_ROOT / "manifests/train.txt"
    manifest_path.parent.mkdir(parents=True, exist_ok=True)

    with open(manifest_path, "w") as f:
        f.write("\n".join(paths))

    print(f"✓ Manifest written: {manifest_path}")
    print(f"  (Points to {len(paths)} local files for high-speed training)")
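
Because the manifest lives on Drive while the images it lists live in `/content/data`, which is wiped whenever the Colab runtime resets, a stale manifest can point at files that no longer exist. The sketch below is one way to check for that before a long training run. `check_manifest` is a hypothetical helper, not part of the DaZZLeD repo.

```python
from pathlib import Path


def check_manifest(manifest_path):
    """Return (existing, missing) path lists for a newline-separated manifest."""
    lines = Path(manifest_path).read_text().splitlines()
    # Ignore blank lines, e.g. a trailing newline at the end of the file
    entries = [line.strip() for line in lines if line.strip()]
    existing = [p for p in entries if Path(p).is_file()]
    missing = [p for p in entries if not Path(p).is_file()]
    return existing, missing
```

For example, calling `existing, missing = check_manifest(manifest_path)` and confirming that `missing` is empty would show that the dataset has been restored to local disk in the current session.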

In [None]:
# =============================================================================
# TRAIN THE MODEL (Choose one option)
# =============================================================================

# 🔧 TRAINING OPTIONS - Uncomment ONE of the following:

# -----------------------------------------------------------------------------
# OPTION A: Quick Test (5-10 min on T4) - Just to verify everything works
# -----------------------------------------------------------------------------
# !python training/train.py \
#     --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
#     --teacher facebook/dinov3-vitl16-pretrain-lvd1689m \
#     --epochs 1 \
#     --batch-size 16 \
#     --recursion-steps 8 \
#     --max-steps 100 \
#     --amp \
#     --log-interval 10 \
#     --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/checkpoints

# -----------------------------------------------------------------------------
# OPTION B: Full Training (High RAM Optimized)
# -----------------------------------------------------------------------------
# Optimized for 50GB+ RAM:
# --cache-ram: Preloads all images into RAM (fastest IO, uses ~10GB RAM)
# --workers 8: Maximizes CPU usage for augmentation
# -----------------------------------------------------------------------------
!python training/train.py \
    --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
    --teacher facebook/dinov3-vitl16-pretrain-lvd1689m \
    --epochs 5 \
    --batch-size 64 \
    --recursion-steps 16 \
    --grad-accum 2 \
    --lr 1e-4 \
    --amp \
    --allow-tf32 \
    --channels-last \
    --cudnn-benchmark \
    --workers 8 \
    --cache-ram \
    --log-interval 50 \
    --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/checkpoints \
    --checkpoint-every 500

# -----------------------------------------------------------------------------
# OPTION C: Resume from Checkpoint (if session disconnected)
# -----------------------------------------------------------------------------
# !python training/train.py \
#     --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
#     --teacher facebook/dinov3-vitl16-pretrain-lvd1689m \
#     --resume /content/drive/MyDrive/dazzled/outputs/checkpoints/student_epoch_3.safetensors \
#     --epochs 5 \
#     --batch-size 64 \
#     --recursion-steps 16 \
#     --amp \
#     --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/checkpoints
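
To sanity-check the Option B settings, the arithmetic below estimates the effective batch size, the number of optimizer steps per epoch, and the memory needed to hold the decoded images. It assumes the ~87k image total from the download cell, that `--grad-accum` accumulates gradients over that many batches before each optimizer step, and that `--cache-ram` stores images as 224×224 RGB `uint8` arrays. None of these is confirmed against `training/train.py`, so treat the numbers as rough estimates.

```python
import math

# Assumed values: the image total comes from the download cell's comments,
# the batch settings from Option B above.
num_images = 87_000
batch_size = 64
grad_accum = 2

effective_batch = batch_size * grad_accum
micro_batches_per_epoch = math.ceil(num_images / batch_size)
optimizer_steps_per_epoch = math.ceil(micro_batches_per_epoch / grad_accum)

# Assumption: cached images are held decoded as 224x224 RGB uint8 arrays
bytes_per_image = 224 * 224 * 3
cache_gib = num_images * bytes_per_image / 1024**3

print(f"Effective batch size:      {effective_batch}")
print(f"Micro-batches per epoch:   {micro_batches_per_epoch:,}")
print(f"Optimizer steps per epoch: {optimizer_steps_per_epoch:,}")
print(f"Decoded image cache:       ~{cache_gib:.1f} GiB")
```

Under these assumptions the cache comes out somewhat above the ~10GB quoted in the Option B comment, which may simply reflect a different storage format in the training script.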

In [None]:
# =============================================================================
# LIST CHECKPOINTS & LOAD TRAINED WEIGHTS
# =============================================================================

from pathlib import Path
import safetensors.torch

CKPT_DIR = Path("/content/drive/MyDrive/dazzled/outputs/checkpoints")

# List available checkpoints, oldest first. Sorting by modification time
# avoids the name-order pitfall where epoch 10 sorts before epoch 2.
checkpoints = sorted(CKPT_DIR.glob("*.safetensors"), key=lambda p: p.stat().st_mtime)
print(f"Found {len(checkpoints)} checkpoints:")
for ckpt in checkpoints:
    size_mb = ckpt.stat().st_size / (1024 * 1024)
    print(f"  {ckpt.name} ({size_mb:.2f} MB)")

# Load the latest checkpoint into model
if checkpoints:
    latest_ckpt = checkpoints[-1]
    print(f"\n📥 Loading: {latest_ckpt.name}")
    
    # No model has been created yet at this point in the notebook, so build
    # one here with the same dimensions used in Section 2.
    from models.recursive_student import RecursiveHasher
    model = RecursiveHasher(state_dim=128, hash_dim=96)
    
    safetensors.torch.load_model(model, str(latest_ckpt))
    model.eval()
    print("✓ Trained weights loaded!")
else:
    print("⚠️  No checkpoints found. Run training first.")

## 2. Model Architecture Test

In [None]:
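# Initialize model with default params
STATE_DIM = 128
HASH_DIM = 96
RECURSION_STEPS = 16

# Reuse a model that already holds trained weights, if one was loaded above;
# otherwise fall back to a freshly initialized (untrained) model so the
# tests below still run.
if "model" not in globals():
    model = RecursiveHasher(state_dim=STATE_DIM, hash_dim=HASH_DIM)
model.eval()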

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: ~{total_params * 4 / 1024 / 1024:.2f} MB (float32)")

In [None]:
# Test forward pass with dummy input
batch_size = 4
image_size = 224

dummy_img = torch.randn(batch_size, 3, image_size, image_size)
dummy_state = torch.zeros(batch_size, STATE_DIM)

with torch.no_grad():
    next_state, hash_out = model(dummy_img, dummy_state)

print(f"Input image shape: {dummy_img.shape}")
print(f"Input state shape: {dummy_state.shape}")
print(f"Output state shape: {next_state.shape}")
print(f"Output hash shape: {hash_out.shape}")
print(f"Hash L2 norm (should be ~1.0): {torch.norm(hash_out, dim=1).mean():.4f}")

## 3. Recursive Inference Test

The key innovation is running the model recursively 16 times, refining the hash at each step.

In [None]:
def recursive_inference(model, image, steps=16):
    """Run recursive inference for the specified number of steps."""
    batch_size = image.size(0)
    state = torch.zeros(batch_size, STATE_DIM, device=image.device)
    
    hashes = []
    with torch.no_grad():
        for step in range(steps):
            state, hash_out = model(image, state)
            hashes.append(hash_out.clone())
    
    return hashes

# Run recursive inference
hashes = recursive_inference(model, dummy_img, steps=RECURSION_STEPS)

print(f"Generated {len(hashes)} hash vectors")
print(f"Final hash shape: {hashes[-1].shape}")

In [None]:
# Analyze hash stability across recursive steps
import matplotlib.pyplot as plt

# Compute cosine similarity between consecutive steps
similarities = []
for i in range(1, len(hashes)):
    sim = F.cosine_similarity(hashes[i], hashes[i-1], dim=1).mean().item()
    similarities.append(sim)

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(range(1, len(hashes)), similarities, 'b-o')
plt.xlabel('Recursion Step')
plt.ylabel('Cosine Similarity with Previous')
plt.title('Hash Convergence Over Recursion')
plt.ylim(0, 1.1)
plt.grid(True)

# Compute similarity to final hash
final_hash = hashes[-1]
similarities_to_final = []
for h in hashes:
    sim = F.cosine_similarity(h, final_hash, dim=1).mean().item()
    similarities_to_final.append(sim)

plt.subplot(1, 2, 2)
plt.plot(range(len(hashes)), similarities_to_final, 'r-o')
plt.xlabel('Recursion Step')
plt.ylabel('Cosine Similarity to Final Hash')
plt.title('Convergence to Final Hash')
plt.ylim(0, 1.1)
plt.grid(True)

plt.tight_layout()
plt.show()
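
The plots give a visual impression of convergence, but a single number makes runs easier to compare. The sketch below reports the first recursion step from which the hash stays within a chosen cosine similarity of the final hash. The helper name `first_converged_step` and the 0.99 threshold are illustrative choices, not part of the notebook's API.

```python
def first_converged_step(similarities_to_final, threshold=0.99):
    """Return the index of the first step from which every later step stays
    at or above `threshold` similarity to the final hash, or None."""
    converged_from = None
    for step, sim in enumerate(similarities_to_final):
        if sim >= threshold:
            if converged_from is None:
                converged_from = step
        else:
            # Similarity dipped below the threshold again, so reset
            converged_from = None
    return converged_from
```

Applied to the `similarities_to_final` list built in the cell above, a small result would suggest that fewer than 16 recursion steps may suffice, while a result close to the last index would suggest the hash is still moving late in the recursion.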

## 4. Adversarial Robustness Test

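Test whether small perturbations to the input cause large changes in the hash (they shouldn't after recursive refinement). Note that the cell below adds random Gaussian noise, which is a weaker test than a gradient-based adversarial attack.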

In [None]:
def test_perturbation_robustness(model, image, epsilon_range, steps=16):
    """Test hash stability under input perturbations."""
    results = []
    
    # Get baseline hash
    baseline_hashes = recursive_inference(model, image, steps)
    baseline_final = baseline_hashes[-1]
    
    for epsilon in epsilon_range:
        # Add random noise
        noise = torch.randn_like(image) * epsilon
        perturbed = image + noise
        
        # Get perturbed hash
        perturbed_hashes = recursive_inference(model, perturbed, steps)
        perturbed_final = perturbed_hashes[-1]
        
        # Compute similarity
        sim = F.cosine_similarity(baseline_final, perturbed_final, dim=1).mean().item()
        results.append((epsilon, sim))
        print(f"Epsilon={epsilon:.4f}: Cosine Similarity={sim:.4f}")
    
    return results

# Test with various noise levels
epsilons = [0.001, 0.01, 0.05, 0.1, 0.2, 0.5]
test_img = torch.randn(1, 3, 224, 224)
results = test_perturbation_robustness(model, test_img, epsilons)

In [None]:
# Plot robustness results
epsilons, sims = zip(*results)

plt.figure(figsize=(8, 5))
plt.plot(epsilons, sims, 'g-o', linewidth=2, markersize=8)
plt.axhline(y=0.9, color='r', linestyle='--', label='0.9 threshold')
plt.xlabel('Noise Level (epsilon)', fontsize=12)
plt.ylabel('Cosine Similarity to Original', fontsize=12)
plt.title('Hash Robustness to Input Perturbations', fontsize=14)
plt.xscale('log')
plt.ylim(0, 1.1)
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()
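
The 0.9 line in the plot suggests a pass/fail criterion. As a sketch, the helper below reads the `(epsilon, similarity)` pairs returned by `test_perturbation_robustness` and reports the largest noise level for which the hash, and every hash at a smaller tested noise level, stays at or above the threshold. The name `max_tolerated_epsilon` is hypothetical.

```python
def max_tolerated_epsilon(results, threshold=0.9):
    """Largest epsilon such that all tested epsilons up to and including it
    keep cosine similarity >= threshold; None if even the smallest fails."""
    tolerated = None
    # Sort by epsilon so the scan runs from the weakest to the strongest noise
    for epsilon, sim in sorted(results):
        if sim < threshold:
            break
        tolerated = epsilon
    return tolerated
```

Tracking this value across checkpoints is one simple way to see whether training is improving robustness, keeping in mind that it only reflects the noise levels that were actually tested.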

## 5. ONNX Export Test

In [None]:
import onnx
import onnxruntime as ort
import numpy as np

# Export to ONNX
onnx_path = "test_model.onnx"

model.eval()
dummy_img = torch.randn(1, 3, 224, 224)
dummy_state = torch.zeros(1, STATE_DIM)

torch.onnx.export(
    model,
    (dummy_img, dummy_state),
    onnx_path,
    input_names=["image", "prev_state"],
    output_names=["next_state", "hash"],
    opset_version=14,
    dynamic_axes={
        "image": {0: "batch"},
        "prev_state": {0: "batch"},
        "next_state": {0: "batch"},
        "hash": {0: "batch"},
    },
)

print(f"Exported ONNX model to {onnx_path}")
print(f"File size: {os.path.getsize(onnx_path) / 1024:.2f} KB")

In [None]:
# Validate ONNX model
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print("ONNX model validation passed!")

# Test ONNX runtime inference
session = ort.InferenceSession(onnx_path)

# Run inference
test_img = np.random.randn(1, 3, 224, 224).astype(np.float32)
test_state = np.zeros((1, STATE_DIM), dtype=np.float32)

outputs = session.run(None, {"image": test_img, "prev_state": test_state})
next_state_onnx, hash_onnx = outputs

print(f"ONNX output state shape: {next_state_onnx.shape}")
print(f"ONNX output hash shape: {hash_onnx.shape}")
print(f"ONNX hash L2 norm: {np.linalg.norm(hash_onnx, axis=1).mean():.4f}")

In [None]:
# Compare PyTorch vs ONNX outputs
with torch.no_grad():
    pt_state, pt_hash = model(torch.from_numpy(test_img), torch.from_numpy(test_state))

pt_hash_np = pt_hash.numpy()
pt_state_np = pt_state.numpy()

hash_diff = np.abs(pt_hash_np - hash_onnx).max()
state_diff = np.abs(pt_state_np - next_state_onnx).max()

print(f"Max hash difference (PyTorch vs ONNX): {hash_diff:.8f}")
print(f"Max state difference (PyTorch vs ONNX): {state_diff:.8f}")

if hash_diff < 1e-5 and state_diff < 1e-5:
    print("✅ ONNX export matches PyTorch output!")
else:
    print("⚠️ Warning: Numerical differences detected")
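
Note that the exported graph performs a single recursion step, so any consumer of the ONNX file has to run the loop itself and feed `next_state` back in as `prev_state`. A minimal sketch of such a driver follows. `run_recursive` and `step_fn` are hypothetical names, and the default of 16 steps is taken from `RECURSION_STEPS`.

```python
def run_recursive(step_fn, image, initial_state, steps=16):
    """Apply a single-step model `steps` times, threading the state through.

    `step_fn(image, state)` must return a `(next_state, hash)` pair.
    """
    if steps < 1:
        raise ValueError("steps must be at least 1")
    state, hash_out = initial_state, None
    for _ in range(steps):
        state, hash_out = step_fn(image, state)
    return state, hash_out
```

With the `session`, `test_img`, and `test_state` objects from the cells above, `step_fn` could be `lambda img, st: session.run(None, {"image": img, "prev_state": st})`, since `session.run` returns the outputs in the order they were named at export (`next_state`, then `hash`).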

## 6. Latency Benchmark

In [None]:
import time

def benchmark_inference(model, device, num_runs=100, warmup=10):
    """Benchmark recursive inference latency."""
    model = model.to(device)
    model.eval()
    
    test_img = torch.randn(1, 3, 224, 224).to(device)
    
    # Warmup
    for _ in range(warmup):
        _ = recursive_inference(model, test_img, steps=16)
    
    if device.type == 'cuda':
        torch.cuda.synchronize()
    
    # Benchmark
    times = []
    for _ in range(num_runs):
        if device.type == 'cuda':
            torch.cuda.synchronize()
        start = time.perf_counter()
        
        _ = recursive_inference(model, test_img, steps=16)
        
        if device.type == 'cuda':
            torch.cuda.synchronize()
        times.append(time.perf_counter() - start)
    
    return times

# Benchmark on CPU
cpu_times = benchmark_inference(model, torch.device('cpu'), num_runs=50)
print(f"CPU Latency: {np.mean(cpu_times)*1000:.2f} ± {np.std(cpu_times)*1000:.2f} ms")

# Benchmark on GPU if available
if torch.cuda.is_available():
    gpu_times = benchmark_inference(model, torch.device('cuda'), num_runs=100)
    print(f"GPU Latency: {np.mean(gpu_times)*1000:.2f} ± {np.std(gpu_times)*1000:.2f} ms")
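
Mean and standard deviation can hide tail latency, which often matters more for an on-device hasher. The sketch below summarizes a list of per-run timings (in seconds, as returned by `benchmark_inference`) with the median and 95th percentile, using only the standard library. `summarize_latency` is a hypothetical helper.

```python
import statistics


def summarize_latency(times_s):
    """Return median, p95, and max latency in milliseconds."""
    times_ms = sorted(t * 1000 for t in times_s)
    if len(times_ms) > 1:
        # quantiles(n=20) yields 19 cut points; the last is the 95th percentile
        p95 = statistics.quantiles(times_ms, n=20)[-1]
    else:
        p95 = times_ms[0]
    return {
        "median_ms": statistics.median(times_ms),
        "p95_ms": p95,
        "max_ms": times_ms[-1],
    }
```

For example, `summarize_latency(cpu_times)` gives a compact view of the CPU run that can be compared directly against the GPU timings.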

## 8. Summary

### ✅ This notebook provides a complete workflow:

| Step | Cell | Description |
|------|------|-------------|
| 0 | Mount Drive | Connect Google Drive for data/checkpoints |
| 0.1 | Download Data | Download FFHQ, OpenImages, and MobileViews to local disk; cache a zip on Drive |
| 1 | Setup | Clone repo & install dependencies |
| 1.1 | Train | Build manifest & train with DINOv3 distillation |
| 2 | Architecture | Verify model structure & forward pass |
| 3 | Recursion | Test recursive inference convergence |
| 4 | Robustness | Perturbation stability testing |
| 5 | ONNX | Export & validate ONNX model |
| 6 | Benchmark | Measure recursive inference latency on CPU/GPU |
| 7 | Save | Copy ONNX to Drive for Go runtime |

### 📦 Output artifacts (on Google Drive):

```
/content/drive/MyDrive/dazzled/outputs/
├── checkpoints/
│   └── student_epoch_5.safetensors   # Trained weights
└── models/
    └── recursive_hasher.onnx          # ONNX for Go runtime
```

### 🔗 References:
- **DINOv3:** [arXiv:2508.10104](https://arxiv.org/abs/2508.10104)
- **TRM:** [arXiv:2510.04871](https://arxiv.org/abs/2510.04871)
- **Split Accumulation:** [ePrint 2020/1618](https://eprint.iacr.org/2020/1618)

## 7. Export ONNX to Google Drive

Save the trained model to Drive for use in the Go application.

In [None]:
# =============================================================================
# EXPORT TRAINED ONNX MODEL TO GOOGLE DRIVE
# =============================================================================

import shutil
from pathlib import Path

# Source: the ONNX file exported earlier in this notebook
source_onnx = Path("test_model.onnx")

# Destination on Drive
dest_dir = Path("/content/drive/MyDrive/dazzled/outputs/models")
dest_dir.mkdir(parents=True, exist_ok=True)
dest_onnx = dest_dir / "recursive_hasher.onnx"

if source_onnx.exists():
    shutil.copy(source_onnx, dest_onnx)
    size_mb = dest_onnx.stat().st_size / (1024 * 1024)
    print(f"✓ ONNX model saved to Drive: {dest_onnx}")
    print(f"  Size: {size_mb:.2f} MB")
    
    # Also report the latest checkpoint (most recently modified)
    ckpt_dir = Path("/content/drive/MyDrive/dazzled/outputs/checkpoints")
    checkpoints = sorted(ckpt_dir.glob("*.safetensors"), key=lambda p: p.stat().st_mtime)
    if checkpoints:
        print("\n📦 Available artifacts on Drive:")
        print(f"  Model: {dest_onnx}")
        print(f"  Checkpoint: {checkpoints[-1]}")
else:
    print("⚠️  ONNX file not found. Run the ONNX export cells first (Section 5).")