# DaZZLeD Hash Center Training Notebook (ResNet + Counterfactual VAE)

**Goal:** Train the ResNet Hash Center model from `resnet.tex` with counterfactual VAE, CF‑SimCLR, DHD, PGD, and TTC checks.

**Runtime:** Set Colab to GPU before running training cells.

**Note:** If you do not have VAE weights yet, you must train them first (Step 3). If you want a quick run without VAE, set `--counterfactual-mode aug` in Step 4.


## 0. Mount Google Drive


In [2]:
# Mount Google Drive (required for data storage)
from google.colab import drive
drive.mount('/content/drive')

# Create project directories
from pathlib import Path

DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
DATA_ROOT = DRIVE_ROOT / "data"
OUTPUT_ROOT = DRIVE_ROOT / "outputs"

# Create all needed directories
for d in [
    DATA_ROOT / "ffhq",
    DATA_ROOT / "openimages",
    DATA_ROOT / "text",
    OUTPUT_ROOT / "checkpoints",
    OUTPUT_ROOT / "models",
    DRIVE_ROOT / "manifests",
]:
    d.mkdir(parents=True, exist_ok=True)

print(f"OK: Project root: {DRIVE_ROOT}")
print(f"OK: Data root: {DATA_ROOT}")
print(f"OK: Output root: {OUTPUT_ROOT}")


Mounted at /content/drive
OK: Project root: /content/drive/MyDrive/dazzled
OK: Data root: /content/drive/MyDrive/dazzled/data
OK: Output root: /content/drive/MyDrive/dazzled/outputs


In [3]:
# Credentials check (Colab secrets)
try:
    from google.colab import userdata
    required = ["KAGGLE_USERNAME", "KAGGLE_KEY", "HF_TOKEN"]
    for key in required:
        val = userdata.get(key)
        print(f"{key}: {'OK' if val else 'MISSING'}")
except Exception as e:
    print("Credentials check skipped:", e)


KAGGLE_USERNAME: OK
KAGGLE_KEY: OK
HF_TOKEN: OK


## 1. Setup & Installation


In [63]:
import os
if not os.path.exists('DaZZLeD'):
    !git clone https://github.com/D13ya/DaZZLeD.git
    %cd DaZZLeD/ml-core
    !grep -n "self-supervised" training/train_hashnet.py
else:
    %cd DaZZLeD/ml-core

!pip install -q -r requirements.txt
!git pull


shell-init: error retrieving current directory: getcwd: cannot access parent directories: No such file or directory
fatal: could not create work tree dir 'DaZZLeD': No such file or directory
[Errno 2] No such file or directory: 'DaZZLeD/ml-core'
/content/DaZZLeD/ml-core/DaZZLeD/ml-core/DaZZLeD/ml-core/DaZZLeD/ml-core/DaZZLeD/ml-core
shell-init: error retrieving current directory: getcwd: cannot access parent directories: No such file or directory
grep: training/train_hashnet.py: No such file or directory
shell-init: error retrieving current directory: getcwd: cannot access parent directories: No such file or directory
shell-init: error retrieving current directory: getcwd: cannot access parent directories: No such file or directory
The folder you are executing pip from can no longer be found.
shell-init: error retrieving current directory: getcwd: cannot access parent directories: No such file or directory
fatal: Unable to read current working directory: No such file or directory


## 2.1 Restore Dataset (Drive Zip)

If you keep a dataset zip on Drive, extract it once to local disk (`/content/data`) for faster I/O.


In [5]:
# DOWNLOAD & PREPARE DATASETS
# FFHQ (Kaggle), OpenImages (FiftyOne), MobileViews (HF)
# Restores from Drive cache if available; otherwise downloads and builds cache.

import importlib
import os
import subprocess
import shutil
import sys
from pathlib import Path
from tqdm.auto import tqdm

DATA_ROOT = Path("/content/data")
DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
DRIVE_ARCHIVE = DRIVE_ROOT / "data-cache/training-images.zip"

# CHECK: Drive mounted
if not Path("/content/drive/MyDrive").exists():
    raise RuntimeError(
        "Google Drive is NOT mounted. Run the mount cell first, then re-run this cell."
    )

EXPECTED_COUNTS = {
    "ffhq": 40000,
    "openimages": 2500,
    "mobileviews": 2000,
}

def validate_dataset(data_root: Path, expected: dict):
    results = {}
    exts = {".jpg", ".jpeg", ".png", ".bmp"}
    for name, exp_count in expected.items():
        path = data_root / name
        if path.exists():
            actual = len([p for p in path.rglob("*") if p.is_file() and p.suffix.lower() in exts])
        else:
            actual = 0
        results[name] = {"count": actual, "expected": exp_count, "valid": actual >= exp_count * 0.95}
    return all(r["valid"] for r in results.values()), results

DATA_ROOT.mkdir(parents=True, exist_ok=True)

need_download = {"ffhq": False, "openimages": False, "mobileviews": False}
skip_downloads = False

# Option 1: Restore from Drive cache
if DRIVE_ARCHIVE.exists():
    print("Found cached dataset on Drive:", DRIVE_ARCHIVE)
    shutil.unpack_archive(DRIVE_ARCHIVE, DATA_ROOT)
    all_valid, validation = validate_dataset(DATA_ROOT, EXPECTED_COUNTS)
    if all_valid:
        print("All datasets restored from cache successfully.")
        for name, info in validation.items():
            print(f"  OK {name}: {info['count']:,} images")
        print("Skipping downloads - data is ready.")
        skip_downloads = True
    else:
        print("Cache incomplete; will download missing data:")
        for name, info in validation.items():
            if not info["valid"]:
                need_download[name] = True
                print(f"  MISSING {name}: {info['count']:,}/{info['expected']:,}")
            else:
                print(f"  OK {name}: {info['count']:,} images")
else:
    print("No cache found on Drive; downloading datasets.")
    need_download = {"ffhq": True, "openimages": True, "mobileviews": True}

# Option 2: Download fresh data
if not skip_downloads and any(need_download.values()):
    print("")
    print("=" * 65)
    print("DOWNLOADING DATASETS")
    print("=" * 65)

    import torchvision.transforms as transforms
    from PIL import Image
    import io

    # 1) FFHQ via Kaggle
    if need_download["ffhq"]:
        ffhq_dir = DATA_ROOT / "ffhq"
        ffhq_dir.mkdir(parents=True, exist_ok=True)

        print("")
        print("[1/3] FFHQ via Kaggle")
        print("Target: 40k face images")

        try:
            from google.colab import userdata
            import os
            os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
            os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
            print("Kaggle credentials loaded from Colab secrets")
        except Exception as e:
            print("Could not load Kaggle secrets:", e)
            print("Add KAGGLE_USERNAME and KAGGLE_KEY to Colab secrets")

        subprocess.run([
            "kaggle", "datasets", "download", "-d", "arnaud58/flickrfaceshq-dataset-ffhq",
            "-p", str(ffhq_dir), "--unzip"
        ], check=True)

        # Flatten directory structure
        for nested in ffhq_dir.rglob("*"):
            if nested.is_file() and nested.suffix.lower() in {".jpg", ".png"}:
                target = ffhq_dir / nested.name
                if not target.exists():
                    shutil.move(str(nested), str(target))

        for d in ffhq_dir.iterdir():
            if d.is_dir():
                shutil.rmtree(d)

        count = len(list(ffhq_dir.glob("*.jpg"))) + len(list(ffhq_dir.glob("*.png")))
        print(f"FFHQ: {count:,} images")

    # 2) OpenImages via FiftyOne
    if need_download["openimages"]:
        oi_dir = DATA_ROOT / "openimages"
        oi_dir.mkdir(parents=True, exist_ok=True)

        print("")
        print("[2/3] OpenImages via FiftyOne")
        print("Target: 2.5k diverse images")

        def ensure_fiftyone():
            # 1. Forcefully clear any existing broken fiftyone modules from cache
            to_remove = [m for m in sys.modules if m.startswith("fiftyone")]
            for m in to_remove:
                del sys.modules[m]

            # 2. Try importing fresh
            try:
                import fiftyone as fo
                import fiftyone.zoo as foz
                # Verify critical attribute exists
                if not hasattr(fo, "config"):
                    raise ImportError("fiftyone.config is missing")
                return fo, foz
            except (ImportError, AttributeError):
                return None, None

        fo, foz = ensure_fiftyone()

        if fo is None:
            print("FiftyOne missing or broken. Installing...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", "fiftyone"])
            # Re-run the clean import after install
            fo, foz = ensure_fiftyone()

        if fo is None:
             raise RuntimeError("Failed to install/import fiftyone with valid config.")

        dataset = foz.load_zoo_dataset(
            "open-images-v7",
            split="validation",
            max_samples=2500,
            shuffle=True,
            seed=42,
        )

        for sample in dataset:
            src = Path(sample.filepath)
            dst = oi_dir / src.name
            if src.exists() and not dst.exists():
                shutil.copy2(src, dst)

        count = len(list(oi_dir.glob("*")))
        print(f"OpenImages: {count:,} images")

        fo.delete_dataset(dataset.name)

    # Optional: Hugging Face token (for gated datasets)
    try:
        from google.colab import userdata
        from huggingface_hub import login
        hf_token = userdata.get("HF_TOKEN")
        if hf_token:
            login(hf_token)
            print("HF token loaded")
        else:
            print("HF_TOKEN not found in Colab secrets")
    except Exception as e:
        print("HF login skipped:", e)

    # 3) MobileViews via HuggingFace
    if need_download["mobileviews"]:
        mv_dir = DATA_ROOT / "mobileviews"
        mv_dir.mkdir(parents=True, exist_ok=True)

        print("")
        print("[3/3] MobileViews via HuggingFace")
        print("Target: 2k mobile UI screenshots")

        try:
            from datasets import load_dataset

            ds = load_dataset(
                "mllmTeam/MobileViews",
                split="train",
                streaming=True,
            )

            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.Lambda(lambda x: x.convert("RGB")),
            ])

            count = 0
            target = 2000
            for sample in ds:
                if count >= target:
                    break
                try:
                    img = sample.get("image")
                    if img is not None:
                        img = transform(img)
                        img.save(mv_dir / f"mv_{count:05d}.jpg", "JPEG", quality=85)
                        count += 1
                except Exception:
                    continue

            print(f"MobileViews: {count:,} images")

        except ImportError:
            print("datasets not installed. Installing...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", "datasets"])
            print("Re-run this cell after installation")

    # Save cache to Drive
    print("")
    print("=" * 65)
    print("SAVING CACHE TO DRIVE")
    print("=" * 65)

    DRIVE_ARCHIVE.parent.mkdir(parents=True, exist_ok=True)

    print("Creating archive:", DRIVE_ARCHIVE)
    shutil.make_archive(
        str(DRIVE_ARCHIVE.with_suffix("")),
        "zip",
        DATA_ROOT,
    )

    archive_size = DRIVE_ARCHIVE.stat().st_size / (1024 ** 3)
    print(f"Cache saved ({archive_size:.2f} GB)")

# Final validation
print("")
print("=" * 65)
print("FINAL DATASET VALIDATION")
print("=" * 65)

all_valid, validation = validate_dataset(DATA_ROOT, EXPECTED_COUNTS)
total_images = sum(info["count"] for info in validation.values())

for name, info in validation.items():
    status = "OK" if info["valid"] else "BAD"
    print(f"{status} {name}: {info['count']:,} / {info['expected']:,} images")

print(f"Total: {total_images:,} images")

if all_valid:
    print("All datasets ready for training.")
else:
    print("Some datasets are incomplete. Re-run this cell to download more.")

Found cached dataset on Drive: /content/drive/MyDrive/dazzled/data-cache/training-images.zip
All datasets restored from cache successfully.
  OK ffhq: 52,001 images
  OK openimages: 2,500 images
  OK mobileviews: 2,000 images
Skipping downloads - data is ready.

FINAL DATASET VALIDATION
OK ffhq: 52,001 / 40,000 images
OK openimages: 2,500 / 2,500 images
OK mobileviews: 2,000 / 2,000 images
Total: 56,501 images
All datasets ready for training.


In [29]:
from pathlib import Path
from collections import Counter
from PIL import Image

exts = {".jpg", ".jpeg", ".png", ".bmp"}

if "DATA_ROOT" not in globals():
    DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
    LOCAL_DATA = Path("/content/data")
    DRIVE_DATA = DRIVE_ROOT / "data"
    def has_images(root: Path) -> bool:
        if not root.exists():
            return False
        return any(p.is_file() and p.suffix.lower() in exts for p in root.rglob("*"))
    DATA_ROOT = LOCAL_DATA if has_images(LOCAL_DATA) else DRIVE_DATA

print(f"Validating DATA_ROOT: {DATA_ROOT}")

paths = [p for p in DATA_ROOT.rglob("*") if p.is_file() and p.suffix.lower() in exts]
print(f"Total images found: {len(paths)}")

ext_counts = Counter(p.suffix.lower() for p in paths)
print("Extension counts:", dict(ext_counts))

sample = paths[:200]
bad = []
for p in sample:
    try:
        with Image.open(p) as img:
            img.convert("RGB")
    except Exception as e:
        bad.append((p, str(e)))

if bad:
    print(f"Corrupt/unreadable samples: {len(bad)} (showing up to 5)")
    for p, err in bad[:5]:
        print(f"  - {p}: {err}")
else:
    print("Sample check: all images readable and convertible to RGB.")

if len(paths) == 0:
    raise ValueError("No images found for training. Check extraction paths or zip contents.")


Validating DATA_ROOT: /content/data
Total images found: 56500
Extension counts: {'.png': 52000, '.jpg': 4500}
Sample check: all images readable and convertible to RGB.


## 2.2 Build Manifest (Optional)

If you already have a manifest at `/content/drive/MyDrive/dazzled/manifests/train.txt`, you can skip this.


In [30]:
from pathlib import Path
import re

DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
LOCAL_DATA = Path("/content/data")
DRIVE_DATA = DRIVE_ROOT / "data"
MANIFEST = DRIVE_ROOT / "manifests/train.txt"
MANIFEST.parent.mkdir(parents=True, exist_ok=True)

exts = {".jpg", ".jpeg", ".png", ".bmp"}
LABEL_REGEX = re.compile(r"^((?:ffhq|openimages|openimg|mobileviews?)_\d+|\d+)")

def has_images(root: Path) -> bool:
    if not root.exists():
        return False
    return any(p.is_file() and p.suffix.lower() in exts for p in root.rglob("*"))

DATA_ROOT = LOCAL_DATA if has_images(LOCAL_DATA) else DRIVE_DATA
print(f"Using DATA_ROOT: {DATA_ROOT}")

lines = []
for p in DATA_ROOT.rglob("*"):
    if not p.is_file() or p.suffix.lower() not in exts:
        continue
    match = LABEL_REGEX.search(p.name)
    label = match.group(1) if match else p.stem
    lines.append(f"{p} {label}")

MANIFEST.write_text("\n".join(lines))
print(f"Wrote {len(lines)} lines to {MANIFEST} (per-image labels)")


Using DATA_ROOT: /content/data
Wrote 56500 lines to /content/drive/MyDrive/dazzled/manifests/train.txt (per-image labels)


## 2.5. Sanity Checks (Labels + Domains)

Run this once after the manifest is created to verify labels/domains before any training.


In [27]:
from pathlib import Path
from collections import Counter
import re

MANIFEST = Path("/content/drive/MyDrive/dazzled/manifests/train.txt")
LABEL_REGEX = re.compile(r"^((?:ffhq|openimages|openimg|mobileview)_\d+|\d+)")
DOMAIN_REGEX = re.compile(r"(?:^|/)(ffhq|openimages|openimg|mobileviews?)(?:/|_)")

if not MANIFEST.exists():
    raise FileNotFoundError(f"Manifest not found: {MANIFEST}")

base = MANIFEST.resolve().parent
lines = [line.strip() for line in MANIFEST.read_text().splitlines() if line.strip() and not line.strip().startswith('#')]
total = len(lines)

labels = []
domains = []
missing = []

for line in lines:
    parts = line.split()
    path = Path(parts[0])
    if not path.is_absolute():
        path = (base / path).resolve()

    label = parts[1] if len(parts) > 1 else None
    if label is None:
        match = LABEL_REGEX.search(path.name)
        if match:
            label = match.group(1)

    domain = None
    match = DOMAIN_REGEX.search(str(path))
    if match:
        domain = match.group(1)

    labels.append(label)
    domains.append(domain)
    if not path.exists():
        missing.append(str(path))

label_known = [str(l) for l in labels if l is not None]
domain_known = [str(d) for d in domains if d is not None]

label_unique = len(set(label_known))
label_unlabeled = total - len(label_known)
label_pct = (label_unlabeled / total * 100.0) if total else 0.0

domain_unique = len(set(domain_known))
domain_unlabeled = total - len(domain_known)
domain_pct = (domain_unlabeled / total * 100.0) if total else 0.0

print(f"Label stats: {label_unique} unique, {label_unlabeled}/{total} unlabeled ({label_pct:.1f}%).")
if label_known:
    top = Counter(label_known).most_common(5)
    top_str = ", ".join(f"{k}:{v}" for k, v in top)
    print(f"Top labels: {top_str}")

print(f"Domain stats: {domain_unique} unique, {domain_unlabeled}/{total} unlabeled ({domain_pct:.1f}%).")
if domain_known:
    top = Counter(domain_known).most_common(5)
    top_str = ", ".join(f"{k}:{v}" for k, v in top)
    print(f"Top domains: {top_str}")

if missing:
    print(f"Missing files: {len(missing)} (showing up to 5)")
    for p in missing[:5]:
        print(f"  - {p}")

if label_unique < 2:
    raise ValueError("CRITICAL: fewer than 2 unique labels found. Check regex/manifest.")
if domain_unique < 2:
    print("WARNING: fewer than 2 unique domains found; VAE training will fail.")


Label stats: 55568 unique, 0/56500 unlabeled (0.0%).
Top labels: 5:67, 4:62, 2:61, 3:61, 7:57
Domain stats: 3 unique, 0/56500 unlabeled (0.0%).
Top domains: ffhq:52000, openimages:2500, mobileviews:2000


## 3. Train Counterfactual VAE (Save Weights)

This produces the `--counterfactual-weights` file used by HashNet.


In [31]:
import os
os.chdir("/content/DaZZLeD/ml-core")

!PYTHONPATH=/content/DaZZLeD/ml-core python training/train_counterfactual_vae.py \
  --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
  --epochs 10 \
  --batch-size 96 \
  --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/cf_vae \
  --domain-mode regex \
  --domain-regex "(?:^|/)(ffhq|openimages|openimg|mobileviews?)(?:/|_)" \
  --workers 4 \
  --prefetch-factor 2 \
  --pin-memory \
  --amp


Domain stats: 3 unique, 0/56500 unlabeled (0.0%).
Top domains: ffhq:52000, openimages:2500, mobileviews:2000
Dataset: 56500 images, 588 batches

Counterfactual VAE Training
Domains: 3
Batch: 96, Epochs: 10
LR: 0.0001, WD: 0.01

E1 B50 loss=583680.9375 recon=583270.5833 kld=410.3486
E1 B100 loss=583605.8125 recon=583301.3750 kld=304.4023
E1 B150 loss=430412.8125 recon=430077.9167 kld=334.8824
E1 B200 loss=393849.4688 recon=393437.1667 kld=412.2728
E1 B250 loss=389832.8125 recon=389344.3333 kld=488.4381
E1 B300 loss=403627.1875 recon=403055.4583 kld=571.6875
E1 B350 loss=336065.8125 recon=335434.0208 kld=631.7756
E1 B400 loss=322120.9375 recon=321430.3958 kld=690.5272
E1 B450 loss=291824.8125 recon=291075.5208 kld=749.2822
E1 B500 loss=244133.4062 recon=243313.8750 kld=819.5300
E1 B550 loss=278633.2812 recon=277812.3958 kld=820.8669

>>> Epoch 1/10 done. Avg loss=394905.9670

Saved: /content/drive/MyDrive/dazzled/outputs/cf_vae/cf_vae_e1.safetensors
E2 B50 loss=255907.0938 recon=255017.4

In [32]:
from PIL import Image, ImageFile
import os
from pathlib import Path
from tqdm.auto import tqdm

# Define data root (matches your setup)
DATA_ROOT = Path("/content/data")

print(f"Scanning {DATA_ROOT} for corrupt images... This may take a few minutes.")
bad_files = []
exts = {".jpg", ".jpeg", ".png", ".bmp"}

# Gather all image paths
files = [p for p in DATA_ROOT.rglob("*") if p.is_file() and p.suffix.lower() in exts]

for p in tqdm(files):
    try:
        with Image.open(p) as img:
            img.load()  # Force a full load of pixel data to catch truncation
            img.convert("RGB") # Ensure it's valid image data
    except Exception as e:
        print(f"\nFound corrupt image: {p} - {e}")
        bad_files.append(p)

if bad_files:
    print(f"\nDeleting {len(bad_files)} corrupt images...")
    for p in bad_files:
        try:
            os.remove(p)
        except OSError as e:
            print(f"Error deleting {p}: {e}")
    print("Cleanup complete. \n\nIMPORTANT: Now go back and re-run Step 2.2 (Build Manifest) to remove these files from your training list.")
else:
    print("\nNo corrupt images detected by PIL.")

Scanning /content/data for corrupt images... This may take a few minutes.


  0%|          | 0/56500 [00:00<?, ?it/s]

KeyboardInterrupt: 

## 4. Train HashNet (ResNet + Hash Centers + CF/DHD/PGD)

This uses the VAE weights from Step 3 and writes checkpoints to Drive.


In [80]:
import os
os.chdir("/content/DaZZLeD/ml-core")

!PYTORCH_ALLOC_CONF=expandable_segments:True \
PYTHONPATH=/content/DaZZLeD/ml-core \
python training/train_hashnet.py \
  --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
  --backbone resnet50 \
  --epochs 10 \
  --batch-size 128 \
  --grad-checkpoint \
  --label-mode parent \
  --center-mode hadamard \
  --center-weight 10.0 \
  --distinct-weight 0.5 \
  --quant-weight 0.1 \
  --cf-weight 1.0 \
  --dhd-weight 1.0 \
  --hash-contrastive-weight 1.0 \
  --counterfactual-mode aug \
  --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/hashnet \
  --workers 8 \
  --prefetch-factor 4 \
  --pin-memory \
  --amp \
  --channels-last \
  --allow-tf32 \
  --cudnn-benchmark \
  --lr 1e-4 \
  --warmup-steps 500 \
  --adv-weight 0

  self.setter(val)
Label stats: 3 unique, 0/56500 unlabeled (0.0%).
Top labels: ffhq:52000, openimages:2500, mobileviews:2000
Domain stats: 0 unique, 56500/56500 unlabeled (100.0%).
Dataset: 56500 images, 441 batches
Gradient checkpointing enabled (~50% memory savings)
Scheduler: 4,410 total steps, 500 warmup

Hash Center Training - ResNet
Mode: two-view
Backbone: resnet50 (pretrained=True)
Hash: 128d, Proj: 512d
Batch: 128, Epochs: 10
Losses: center(C)=10.0 distinct(D)=0.5 quant(Q)=0.1
Center mode: hadamard
Extra negatives: 0
Counterfactual mode: aug
Extras: CF=1.0 DHD=1.0 ADV=0.0 PGD=7 TTC=False

Hash centers: 3 (classes=3)
E1 B50 U50 tot=10.0072 ctr=0.6928 dst=0.6937 q=0.9999 cf=1.2072 dhd=0.4502 hcl=1.6688 adv=0.0000 lr=1.01e-05
E1 B100 U100 tot=8.3417 ctr=0.6926 dst=0.6933 q=0.9999 cf=0.3965 dhd=0.3286 hcl=0.9372 adv=0.0000 lr=2.01e-05
E1 B150 U150 tot=7.3632 ctr=0.6920 dst=0.6932 q=0.9998 cf=0.0937 dhd=0.1972 hcl=0.3986 adv=0.0000 lr=3.01e-05
E1 B200 U200 tot=7.0304 ctr=0.6910 ds

In [81]:
# STEP 4.5: HARDENING (Adversarial Fine-Tuning)
import os
os.chdir("/content/DaZZLeD/ml-core")

!PYTORCH_ALLOC_CONF=expandable_segments:True \
PYTHONPATH=/content/DaZZLeD/ml-core \
python training/train_hashnet.py \
  --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
  --backbone resnet50 \
  --resume /content/drive/MyDrive/dazzled/outputs/hashnet/student_e10.safetensors \
  --epochs 2 \
  --batch-size 128 \
  --grad-checkpoint \
  --label-mode parent \
  --center-mode hadamard \
  --center-weight 1.0 \
  --distinct-weight 0.5 \
  --quant-weight 0.1 \
  --cf-weight 1.0 \
  --dhd-weight 1.0 \
  --hash-contrastive-weight 1.0 \
  --counterfactual-mode aug \
  --adv-weight 0.1 \
  --pgd-steps 3 \
  --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/hashnet_hardened \
  --workers 4 \
  --prefetch-factor 2 \
  --pin-memory \
  --amp \
  --channels-last \
  --lr 1e-5 \
  --warmup-steps 0 \
  --adv-weight 0.1

Label stats: 3 unique, 0/56500 unlabeled (0.0%).
Top labels: ffhq:52000, openimages:2500, mobileviews:2000
Domain stats: 0 unique, 56500/56500 unlabeled (100.0%).
Dataset: 56500 images, 441 batches
Gradient checkpointing enabled (~50% memory savings)
Resuming from checkpoint: /content/drive/MyDrive/dazzled/outputs/hashnet/student_e10.safetensors
Checkpoint loaded successfully
Scheduler: 882 total steps, 0 warmup

Hash Center Training - ResNet
Mode: two-view
Backbone: resnet50 (pretrained=True)
Hash: 128d, Proj: 512d
Batch: 128, Epochs: 2
Losses: center(C)=1.0 distinct(D)=0.5 quant(Q)=0.1
Center mode: hadamard
Extra negatives: 0
Counterfactual mode: aug
Extras: CF=1.0 DHD=1.0 ADV=0.1 PGD=3 TTC=False

Hash centers: 3 (classes=3)
E1 B50 U50 tot=-706.7818 ctr=0.0723 dst=1301.2964 q=0.6179 cf=3.9683 dhd=0.0078 hcl=4.8206 adv=-650.6437 lr=9.92e-06
E1 B100 U100 tot=-711.5756 ctr=0.0708 dst=1311.2511 q=0.6100 cf=4.0991 dhd=0.0061 hcl=5.0183 adv=-652.0532 lr=9.69e-06
Traceback (most recent call

In [93]:
import os
os.chdir("/content/DaZZLeD/ml-core")

!PYTORCH_ALLOC_CONF=expandable_segments:True \
PYTHONPATH=/content/DaZZLeD/ml-core \
python training/train_hashnet.py \
  --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
  --backbone resnet50 \
  --epochs 10 \
  --batch-size 128 \
  --grad-checkpoint \
  --label-mode none \
  --center-weight 0 \
  --distinct-weight 0 \
  --quant-weight 0.1 \
  --cf-weight 0 \
  --dhd-weight 0.5 \
  --hash-contrastive-weight 1.0 \
  --adv-weight 0.0 \
  --counterfactual-mode aug \
  --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/hashnet_perimage_test \
  --workers 4 \
  --pin-memory \
  --amp \
  --channels-last \
  --lr 5e-4 \
  --warmup-steps 500

Domain stats: 0 unique, 56500/56500 unlabeled (100.0%).
Dataset: 56500 images, 441 batches
Gradient checkpointing enabled (~50% memory savings)
Scheduler: 4,410 total steps, 500 warmup

Hash Center Training - ResNet
Mode: two-view
Backbone: resnet50 (pretrained=True)
Hash: 128d, Proj: 512d
Batch: 128, Epochs: 10
Losses: center(C)=0.0 distinct(D)=0.0 quant(Q)=0.1
Center mode: hadamard
Extra negatives: 0
Counterfactual mode: aug
Extras: CF=0.0 DHD=0.5 ADV=0.0 PGD=7 TTC=False

E1 B50 U50 tot=0.9288 ctr=0.0000 dst=0.0000 q=0.9998 cf=0.0000 dhd=0.2751 hcl=0.6913 adv=0.0000 lr=5.04e-05
E1 B100 U100 tot=0.3527 ctr=0.0000 dst=0.0000 q=0.9998 cf=0.0000 dhd=0.1249 hcl=0.1903 adv=0.0000 lr=1.00e-04
E1 B150 U150 tot=0.2274 ctr=0.0000 dst=0.0000 q=0.9997 cf=0.0000 dhd=0.0622 hcl=0.0963 adv=0.0000 lr=1.50e-04
E1 B200 U200 tot=0.1928 ctr=0.0000 dst=0.0000 q=0.9997 cf=0.0000 dhd=0.0468 hcl=0.0694 adv=0.0000 lr=2.00e-04
E1 B250 U250 tot=0.1738 ctr=0.0000 dst=0.0000 q=0.9997 cf=0.0000 dhd=0.0425 hcl=0.0

## 5. List Checkpoints


In [105]:
from pathlib import Path

CKPT_DIR = Path("/content/drive/MyDrive/dazzled/outputs/hashnet")
ckpts = sorted(CKPT_DIR.glob("*.safetensors"))
print(f"Found {len(ckpts)} checkpoints")
for ckpt in ckpts:
    print(ckpt.name)


Found 11 checkpoints
student_e1.safetensors
student_e10.safetensors
student_e2.safetensors
student_e3.safetensors
student_e4.safetensors
student_e5.safetensors
student_e6.safetensors
student_e7.safetensors
student_e8.safetensors
student_e9.safetensors
student_final.safetensors


## 6. TTC Inference (Production-Style)

Run the standalone TTC inference script on a sample image.


In [107]:
import os
os.chdir("/content/DaZZLeD/ml-core")

from pathlib import Path

CKPT_DIR = Path("/content/drive/MyDrive/dazzled/outputs/hashnet_perimage_test/")
IMAGE_PATH = "/content/data/openimages/0019308d876736fe.jpg"  # TODO: set a real path
IMAGE_PATH2 = "/content/data/openimages/0019308d876736fe.jpg"  # TODO: set a real path
IMAGE_PATH3 = "/content/data/openimages/0a37aa0734ac8016.jpg"  # TODO: set a real path
IMAGE_PATH4 = "/content/data/mobileviews/mv_00034.jpg"  # TODO: set a real path
IMAGE_PATH5 = "/content/data/mobileviews/mv_00334.jpg"  # TODO: set a real path


ckpts = sorted(CKPT_DIR.glob("*.safetensors"))
if not ckpts:
    raise FileNotFoundError(f"No checkpoints in {CKPT_DIR}")

checkpoint = str(ckpts[-1])
print(f"Using checkpoint: {checkpoint}")

!python inference.py   --image "{IMAGE_PATH}"   --checkpoint "{checkpoint}"   --backbone resnet50   --hash-dim 128   --proj-dim 512   --ttc-views 8   --stability-threshold 0.9   --hamming-threshold 10
!python inference.py   --image "{IMAGE_PATH2}"   --checkpoint "{checkpoint}"   --backbone resnet50   --hash-dim 128   --proj-dim 512   --ttc-views 8   --stability-threshold 0.9   --hamming-threshold 10
!python inference.py   --image "{IMAGE_PATH3}"   --checkpoint "{checkpoint}"   --backbone resnet50   --hash-dim 128   --proj-dim 512   --ttc-views 8   --stability-threshold 0.9   --hamming-threshold 10
!python inference.py   --image "{IMAGE_PATH4}"   --checkpoint "{checkpoint}"   --backbone resnet50   --hash-dim 128   --proj-dim 512   --ttc-views 8   --stability-threshold 0.9   --hamming-threshold 10
!python inference.py   --image "{IMAGE_PATH5}"   --checkpoint "{checkpoint}"   --backbone resnet50   --hash-dim 128   --proj-dim 512   --ttc-views 8   --stability-threshold 0.9   --hamming-threshold 10



Using checkpoint: /content/drive/MyDrive/dazzled/outputs/hashnet_perimage_test/student_final.safetensors
TTC: ACCEPT
Stability: 0.9668
Max Hamming distance: 10.00
Hash (128 bits): 10000000101110101111001000011111101110011101001101000101010110110110001100000011001101101101010101011101111000111000111101100101
TTC: ACCEPT
Stability: 0.9668
Max Hamming distance: 10.00
Hash (128 bits): 10000000101110101111001000011111101110011101001101000101010110110110001100000011001101101101010101011101111000111000111101100101
TTC: ACCEPT
Stability: 0.9863
Max Hamming distance: 4.00
Hash (128 bits): 10001110010011101010010000000110010010010111111110111101101000100000000100010110000000101011111110100011111000000000011101001000
TTC: ACCEPT
Stability: 0.9785
Max Hamming distance: 4.00
Hash (128 bits): 01010110101000010111000000001101011111011000110001101101000000001101010110111101001101100100110000011110000010101010101110011000
TTC: ACCEPT
Stability: 0.9473
Max Hamming distance: 8.00
Hash (128 bits): 1101001

In [109]:
import os
!pip install onnxscript onnxruntime -q
os.chdir("/content/DaZZLeD/ml-core")

import torch
import safetensors.torch
from training.train_hashnet import ResNetHashNet

# Load your trained model
model = ResNetHashNet("resnet50", hash_dim=128, proj_dim=512, pretrained=False)
safetensors.torch.load_model(
    model,
    "/content/drive/MyDrive/dazzled/outputs/hashnet_perimage_test/student_final.safetensors"
)
model.eval()

# Verify output format
dummy = torch.randn(1, 3, 224, 224)
out = model(dummy)
print(f"Output type: {type(out)}, shape: {out.shape if isinstance(out, torch.Tensor) else [o.shape for o in out]}")

# Wrapper class for clean single-output ONNX export
class HashNetONNX(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        out = self.model(x, return_proj=False)
        # Ensure single tensor output
        if isinstance(out, tuple):
            return out[0]
        return out

# Wrap and export
wrapper = HashNetONNX(model)
wrapper.eval()

dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    wrapper,
    dummy_input,
    "/content/drive/MyDrive/dazzled/outputs/hashnet_perimage_test/hashnet.onnx",
    input_names=["image"],
    output_names=["hash"],
    dynamic_axes={"image": {0: "batch"}, "hash": {0: "batch"}},
    opset_version=14
)
print("Exported to hashnet.onnx")


# Verify ONNX output
import onnxruntime as ort
session = ort.InferenceSession("/content/drive/MyDrive/dazzled/outputs/hashnet_perimage_test/hashnet.onnx")
print(f"ONNX inputs: {[i.name for i in session.get_inputs()]}")
print(f"ONNX outputs: {[o.name for o in session.get_outputs()]}")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/693.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m693.4/693.4 kB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/139.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.1/139.1 kB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m
[?25hOutput type: <class 'torch.Tensor'>, shape: torch.Size([1, 128])


  torch.onnx.export(
W0112 17:50:44.549000 2798 torch/onnx/_internal/exporter/_compat.py:114] Setting ONNX exporter to use operator set version 18 because the requested opset_version 14 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features


[torch.onnx] Obtain model graph for `HashNetONNX([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `HashNetONNX([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...


Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 127, in call
    converted_proto = _c_api_utils.call_onnx_api(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/_c_api_utils.py", line 65, in call_onnx_api
    result = func(proto)
             ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnxscript/version_converter/__init__.py", line 122, in _partial_convert_version
    return onnx.version_converter.convert_version(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/onnx/version_converter.py", line 39, in convert_version
    converted_model_str = C.convert_version(model_str, target_version)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /github/workspace/onnx/version_converter/adapters/axes_input_to_attribute.h:65: adapt: Asserti

[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 107 of general pattern rewrite rules.
Exported to hashnet.onnx
ONNX inputs: ['image']
ONNX outputs: ['hash']
