# DaZZLeD Hash Center Training Notebook (ResNet + Counterfactual VAE)

**Goal:** Train the ResNet Hash Center model from `resnet.tex` with counterfactual VAE, CF‑SimCLR, DHD, PGD, and TTC checks.

**Runtime:** Set Colab to GPU before running training cells.

**Note:** If you do not have VAE weights yet, you must train them first (Step 3). If you want a quick run without VAE, set `--counterfactual-mode aug` in Step 4.


## 0. Mount Google Drive


In [2]:
# Mount Google Drive (required for data storage)
from google.colab import drive
drive.mount('/content/drive')

# Create project directories
from pathlib import Path

DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
DATA_ROOT = DRIVE_ROOT / "data"
OUTPUT_ROOT = DRIVE_ROOT / "outputs"

# Create all needed directories
for d in [
    DATA_ROOT / "ffhq",
    DATA_ROOT / "openimages",
    DATA_ROOT / "text",
    OUTPUT_ROOT / "checkpoints",
    OUTPUT_ROOT / "models",
    DRIVE_ROOT / "manifests",
]:
    d.mkdir(parents=True, exist_ok=True)

print(f"OK: Project root: {DRIVE_ROOT}")
print(f"OK: Data root: {DATA_ROOT}")
print(f"OK: Output root: {OUTPUT_ROOT}")


Mounted at /content/drive
OK: Project root: /content/drive/MyDrive/dazzled
OK: Data root: /content/drive/MyDrive/dazzled/data
OK: Output root: /content/drive/MyDrive/dazzled/outputs


In [3]:
# Credentials check (Colab secrets)
try:
    from google.colab import userdata
    required = ["KAGGLE_USERNAME", "KAGGLE_KEY", "HF_TOKEN"]
    for key in required:
        val = userdata.get(key)
        print(f"{key}: {'OK' if val else 'MISSING'}")
except Exception as e:
    print("Credentials check skipped:", e)


KAGGLE_USERNAME: OK
KAGGLE_KEY: OK
HF_TOKEN: OK


## 1. Setup & Installation


In [4]:
import os
if not os.path.exists('DaZZLeD'):
    !git clone https://github.com/D13ya/DaZZLeD.git
    %cd DaZZLeD/ml-core
else:
    %cd DaZZLeD/ml-core

!pip install -q -r requirements.txt


Cloning into 'DaZZLeD'...
remote: Enumerating objects: 455, done.[K
remote: Counting objects: 100% (455/455), done.[K
remote: Compressing objects: 100% (275/275), done.[K
remote: Total 455 (delta 201), reused 370 (delta 119), pack-reused 0 (from 0)[K
Receiving objects: 100% (455/455), 272.17 KiB | 9.72 MiB/s, done.
Resolving deltas: 100% (201/201), done.
/content/DaZZLeD/ml-core
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.5/17.5 MB[0m [31m91.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m89.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25h

## 2.1 Restore Dataset (Drive Zip)

If you keep a dataset zip on Drive, extract it once to local disk (`/content/data`) for faster I/O.


In [5]:
# DOWNLOAD & PREPARE DATASETS
# FFHQ (Kaggle), OpenImages (FiftyOne), MobileViews (HF)
# Restores from Drive cache if available; otherwise downloads and builds cache.

import importlib
import os
import subprocess
import shutil
import sys
from pathlib import Path
from tqdm.auto import tqdm

DATA_ROOT = Path("/content/data")
DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
DRIVE_ARCHIVE = DRIVE_ROOT / "data-cache/training-images.zip"

# CHECK: Drive mounted
if not Path("/content/drive/MyDrive").exists():
    raise RuntimeError(
        "Google Drive is NOT mounted. Run the mount cell first, then re-run this cell."
    )

EXPECTED_COUNTS = {
    "ffhq": 40000,
    "openimages": 2500,
    "mobileviews": 2000,
}

def validate_dataset(data_root: Path, expected: dict):
    results = {}
    exts = {".jpg", ".jpeg", ".png", ".bmp"}
    for name, exp_count in expected.items():
        path = data_root / name
        if path.exists():
            actual = len([p for p in path.rglob("*") if p.is_file() and p.suffix.lower() in exts])
        else:
            actual = 0
        results[name] = {"count": actual, "expected": exp_count, "valid": actual >= exp_count * 0.95}
    return all(r["valid"] for r in results.values()), results

DATA_ROOT.mkdir(parents=True, exist_ok=True)

need_download = {"ffhq": False, "openimages": False, "mobileviews": False}
skip_downloads = False

# Option 1: Restore from Drive cache
if DRIVE_ARCHIVE.exists():
    print("Found cached dataset on Drive:", DRIVE_ARCHIVE)
    shutil.unpack_archive(DRIVE_ARCHIVE, DATA_ROOT)
    all_valid, validation = validate_dataset(DATA_ROOT, EXPECTED_COUNTS)
    if all_valid:
        print("All datasets restored from cache successfully.")
        for name, info in validation.items():
            print(f"  OK {name}: {info['count']:,} images")
        print("Skipping downloads - data is ready.")
        skip_downloads = True
    else:
        print("Cache incomplete; will download missing data:")
        for name, info in validation.items():
            if not info["valid"]:
                need_download[name] = True
                print(f"  MISSING {name}: {info['count']:,}/{info['expected']:,}")
            else:
                print(f"  OK {name}: {info['count']:,} images")
else:
    print("No cache found on Drive; downloading datasets.")
    need_download = {"ffhq": True, "openimages": True, "mobileviews": True}

# Option 2: Download fresh data
if not skip_downloads and any(need_download.values()):
    print("")
    print("=" * 65)
    print("DOWNLOADING DATASETS")
    print("=" * 65)

    import torchvision.transforms as transforms
    from PIL import Image
    import io

    # 1) FFHQ via Kaggle
    if need_download["ffhq"]:
        ffhq_dir = DATA_ROOT / "ffhq"
        ffhq_dir.mkdir(parents=True, exist_ok=True)

        print("")
        print("[1/3] FFHQ via Kaggle")
        print("Target: 40k face images")

        try:
            from google.colab import userdata
            import os
            os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
            os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
            print("Kaggle credentials loaded from Colab secrets")
        except Exception as e:
            print("Could not load Kaggle secrets:", e)
            print("Add KAGGLE_USERNAME and KAGGLE_KEY to Colab secrets")

        subprocess.run([
            "kaggle", "datasets", "download", "-d", "arnaud58/flickrfaceshq-dataset-ffhq",
            "-p", str(ffhq_dir), "--unzip"
        ], check=True)

        # Flatten directory structure
        for nested in ffhq_dir.rglob("*"):
            if nested.is_file() and nested.suffix.lower() in {".jpg", ".png"}:
                target = ffhq_dir / nested.name
                if not target.exists():
                    shutil.move(str(nested), str(target))

        for d in ffhq_dir.iterdir():
            if d.is_dir():
                shutil.rmtree(d)

        count = len(list(ffhq_dir.glob("*.jpg"))) + len(list(ffhq_dir.glob("*.png")))
        print(f"FFHQ: {count:,} images")

    # 2) OpenImages via FiftyOne
    if need_download["openimages"]:
        oi_dir = DATA_ROOT / "openimages"
        oi_dir.mkdir(parents=True, exist_ok=True)

        print("")
        print("[2/3] OpenImages via FiftyOne")
        print("Target: 2.5k diverse images")

        def ensure_fiftyone():
            # 1. Forcefully clear any existing broken fiftyone modules from cache
            to_remove = [m for m in sys.modules if m.startswith("fiftyone")]
            for m in to_remove:
                del sys.modules[m]

            # 2. Try importing fresh
            try:
                import fiftyone as fo
                import fiftyone.zoo as foz
                # Verify critical attribute exists
                if not hasattr(fo, "config"):
                    raise ImportError("fiftyone.config is missing")
                return fo, foz
            except (ImportError, AttributeError):
                return None, None

        fo, foz = ensure_fiftyone()

        if fo is None:
            print("FiftyOne missing or broken. Installing...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", "fiftyone"])
            # Re-run the clean import after install
            fo, foz = ensure_fiftyone()

        if fo is None:
             raise RuntimeError("Failed to install/import fiftyone with valid config.")

        dataset = foz.load_zoo_dataset(
            "open-images-v7",
            split="validation",
            max_samples=2500,
            shuffle=True,
            seed=42,
        )

        for sample in dataset:
            src = Path(sample.filepath)
            dst = oi_dir / src.name
            if src.exists() and not dst.exists():
                shutil.copy2(src, dst)

        count = len(list(oi_dir.glob("*")))
        print(f"OpenImages: {count:,} images")

        fo.delete_dataset(dataset.name)

    # Optional: Hugging Face token (for gated datasets)
    try:
        from google.colab import userdata
        from huggingface_hub import login
        hf_token = userdata.get("HF_TOKEN")
        if hf_token:
            login(hf_token)
            print("HF token loaded")
        else:
            print("HF_TOKEN not found in Colab secrets")
    except Exception as e:
        print("HF login skipped:", e)

    # 3) MobileViews via HuggingFace
    if need_download["mobileviews"]:
        mv_dir = DATA_ROOT / "mobileviews"
        mv_dir.mkdir(parents=True, exist_ok=True)

        print("")
        print("[3/3] MobileViews via HuggingFace")
        print("Target: 2k mobile UI screenshots")

        try:
            from datasets import load_dataset

            ds = load_dataset(
                "mllmTeam/MobileViews",
                split="train",
                streaming=True,
            )

            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.Lambda(lambda x: x.convert("RGB")),
            ])

            count = 0
            target = 2000
            for sample in ds:
                if count >= target:
                    break
                try:
                    img = sample.get("image")
                    if img is not None:
                        img = transform(img)
                        img.save(mv_dir / f"mv_{count:05d}.jpg", "JPEG", quality=85)
                        count += 1
                except Exception:
                    continue

            print(f"MobileViews: {count:,} images")

        except ImportError:
            print("datasets not installed. Installing...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", "datasets"])
            print("Re-run this cell after installation")

    # Save cache to Drive
    print("")
    print("=" * 65)
    print("SAVING CACHE TO DRIVE")
    print("=" * 65)

    DRIVE_ARCHIVE.parent.mkdir(parents=True, exist_ok=True)

    print("Creating archive:", DRIVE_ARCHIVE)
    shutil.make_archive(
        str(DRIVE_ARCHIVE.with_suffix("")),
        "zip",
        DATA_ROOT,
    )

    archive_size = DRIVE_ARCHIVE.stat().st_size / (1024 ** 3)
    print(f"Cache saved ({archive_size:.2f} GB)")

# Final validation
print("")
print("=" * 65)
print("FINAL DATASET VALIDATION")
print("=" * 65)

all_valid, validation = validate_dataset(DATA_ROOT, EXPECTED_COUNTS)
total_images = sum(info["count"] for info in validation.values())

for name, info in validation.items():
    status = "OK" if info["valid"] else "BAD"
    print(f"{status} {name}: {info['count']:,} / {info['expected']:,} images")

print(f"Total: {total_images:,} images")

if all_valid:
    print("All datasets ready for training.")
else:
    print("Some datasets are incomplete. Re-run this cell to download more.")

Found cached dataset on Drive: /content/drive/MyDrive/dazzled/data-cache/training-images.zip
All datasets restored from cache successfully.
  OK ffhq: 52,001 images
  OK openimages: 2,500 images
  OK mobileviews: 2,000 images
Skipping downloads - data is ready.

FINAL DATASET VALIDATION
OK ffhq: 52,001 / 40,000 images
OK openimages: 2,500 / 2,500 images
OK mobileviews: 2,000 / 2,000 images
Total: 56,501 images
All datasets ready for training.


In [6]:
from pathlib import Path
from collections import Counter
from PIL import Image

exts = {".jpg", ".jpeg", ".png", ".bmp"}

if "DATA_ROOT" not in globals():
    DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
    LOCAL_DATA = Path("/content/data")
    DRIVE_DATA = DRIVE_ROOT / "data"
    def has_images(root: Path) -> bool:
        if not root.exists():
            return False
        return any(p.is_file() and p.suffix.lower() in exts for p in root.rglob("*"))
    DATA_ROOT = LOCAL_DATA if has_images(LOCAL_DATA) else DRIVE_DATA

print(f"Validating DATA_ROOT: {DATA_ROOT}")

paths = [p for p in DATA_ROOT.rglob("*") if p.is_file() and p.suffix.lower() in exts]
print(f"Total images found: {len(paths)}")

ext_counts = Counter(p.suffix.lower() for p in paths)
print("Extension counts:", dict(ext_counts))

sample = paths[:200]
bad = []
for p in sample:
    try:
        with Image.open(p) as img:
            img.convert("RGB")
    except Exception as e:
        bad.append((p, str(e)))

if bad:
    print(f"Corrupt/unreadable samples: {len(bad)} (showing up to 5)")
    for p, err in bad[:5]:
        print(f"  - {p}: {err}")
else:
    print("Sample check: all images readable and convertible to RGB.")

if len(paths) == 0:
    raise ValueError("No images found for training. Check extraction paths or zip contents.")


Validating DATA_ROOT: /content/data
Total images found: 56501
Extension counts: {'.png': 52001, '.jpg': 4500}
Sample check: all images readable and convertible to RGB.


## 2.2 Build Manifest (Optional)

If you already have a manifest at `/content/drive/MyDrive/dazzled/manifests/train.txt`, you can skip this.


In [7]:
from pathlib import Path
import re

DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
LOCAL_DATA = Path("/content/data")
DRIVE_DATA = DRIVE_ROOT / "data"
MANIFEST = DRIVE_ROOT / "manifests/train.txt"
MANIFEST.parent.mkdir(parents=True, exist_ok=True)

exts = {".jpg", ".jpeg", ".png", ".bmp"}
LABEL_REGEX = re.compile(r"^((?:ffhq|openimages|openimg|mobileviews?)_\d+|\d+)")

def has_images(root: Path) -> bool:
    if not root.exists():
        return False
    return any(p.is_file() and p.suffix.lower() in exts for p in root.rglob("*"))

DATA_ROOT = LOCAL_DATA if has_images(LOCAL_DATA) else DRIVE_DATA
print(f"Using DATA_ROOT: {DATA_ROOT}")

lines = []
for p in DATA_ROOT.rglob("*"):
    if not p.is_file() or p.suffix.lower() not in exts:
        continue
    match = LABEL_REGEX.search(p.name)
    label = match.group(1) if match else p.stem
    lines.append(f"{p} {label}")

MANIFEST.write_text("\n".join(lines))
print(f"Wrote {len(lines)} lines to {MANIFEST} (per-image labels)")


Using DATA_ROOT: /content/data
Wrote 56501 lines to /content/drive/MyDrive/dazzled/manifests/train.txt (per-image labels)


## 2.5. Sanity Checks (Labels + Domains)

Run this once after the manifest is created to verify labels/domains before any training.


In [8]:
from pathlib import Path
from collections import Counter
import re

MANIFEST = Path("/content/drive/MyDrive/dazzled/manifests/train.txt")
LABEL_REGEX = re.compile(r"^((?:ffhq|openimages|openimg|mobileview)_\d+|\d+)")
DOMAIN_REGEX = re.compile(r"(?:^|/)(ffhq|openimages|openimg|mobileviews?)(?:/|_)")

if not MANIFEST.exists():
    raise FileNotFoundError(f"Manifest not found: {MANIFEST}")

base = MANIFEST.resolve().parent
lines = [line.strip() for line in MANIFEST.read_text().splitlines() if line.strip() and not line.strip().startswith('#')]
total = len(lines)

labels = []
domains = []
missing = []

for line in lines:
    parts = line.split()
    path = Path(parts[0])
    if not path.is_absolute():
        path = (base / path).resolve()

    label = parts[1] if len(parts) > 1 else None
    if label is None:
        match = LABEL_REGEX.search(path.name)
        if match:
            label = match.group(1)

    domain = None
    match = DOMAIN_REGEX.search(str(path))
    if match:
        domain = match.group(1)

    labels.append(label)
    domains.append(domain)
    if not path.exists():
        missing.append(str(path))

label_known = [str(l) for l in labels if l is not None]
domain_known = [str(d) for d in domains if d is not None]

label_unique = len(set(label_known))
label_unlabeled = total - len(label_known)
label_pct = (label_unlabeled / total * 100.0) if total else 0.0

domain_unique = len(set(domain_known))
domain_unlabeled = total - len(domain_known)
domain_pct = (domain_unlabeled / total * 100.0) if total else 0.0

print(f"Label stats: {label_unique} unique, {label_unlabeled}/{total} unlabeled ({label_pct:.1f}%).")
if label_known:
    top = Counter(label_known).most_common(5)
    top_str = ", ".join(f"{k}:{v}" for k, v in top)
    print(f"Top labels: {top_str}")

print(f"Domain stats: {domain_unique} unique, {domain_unlabeled}/{total} unlabeled ({domain_pct:.1f}%).")
if domain_known:
    top = Counter(domain_known).most_common(5)
    top_str = ", ".join(f"{k}:{v}" for k, v in top)
    print(f"Top domains: {top_str}")

if missing:
    print(f"Missing files: {len(missing)} (showing up to 5)")
    for p in missing[:5]:
        print(f"  - {p}")

if label_unique < 2:
    raise ValueError("CRITICAL: fewer than 2 unique labels found. Check regex/manifest.")
if domain_unique < 2:
    print("WARNING: fewer than 2 unique domains found; VAE training will fail.")


Label stats: 55569 unique, 0/56501 unlabeled (0.0%).
Top labels: 5:67, 4:62, 2:61, 3:61, 7:57
Domain stats: 3 unique, 0/56501 unlabeled (0.0%).
Top domains: ffhq:52001, openimages:2500, mobileviews:2000


## 3. Train Counterfactual VAE (Save Weights)

This produces the `--counterfactual-weights` file used by HashNet.


In [9]:
import os
os.chdir("/content/DaZZLeD/ml-core")

!PYTHONPATH=/content/DaZZLeD/ml-core python training/train_counterfactual_vae.py \
  --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
  --epochs 10 \
  --batch-size 96 \
  --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/cf_vae \
  --domain-mode regex \
  --domain-regex "(?:^|/)(ffhq|openimages|openimg|mobileviews?)(?:/|_)" \
  --workers 4 \
  --prefetch-factor 2 \
  --pin-memory \
  --amp


Domain stats: 3 unique, 0/56501 unlabeled (0.0%).
Top domains: ffhq:52001, openimages:2500, mobileviews:2000
Dataset: 56501 images, 588 batches

Counterfactual VAE Training
Domains: 3
Batch: 96, Epochs: 10
LR: 0.0001, WD: 0.01

E1 B50 loss=575423.1875 recon=575009.6250 kld=413.5504
E1 B100 loss=555152.5625 recon=554856.3750 kld=296.1842
E1 B150 loss=470267.3438 recon=469912.6250 kld=354.6986
E1 B200 loss=378435.1875 recon=378015.5417 kld=419.6417
E1 B250 loss=408987.1875 recon=408496.3750 kld=490.7999
Traceback (most recent call last):
  File "/content/DaZZLeD/ml-core/training/train_counterfactual_vae.py", line 383, in <module>
    main()
  File "/content/DaZZLeD/ml-core/training/train_counterfactual_vae.py", line 379, in main
    train(args)
  File "/content/DaZZLeD/ml-core/training/train_counterfactual_vae.py", line 291, in train
    for batch_idx, batch in enumerate(loader, start=1):
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packag

## 4. Train HashNet (ResNet + Hash Centers + CF/DHD/PGD)

This uses the VAE weights from Step 3 and writes checkpoints to Drive.


In [None]:
import os
os.chdir("/content/DaZZLeD/ml-core")

!PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
PYTHONPATH=/content/DaZZLeD/ml-core \
python training/train_hashnet.py \
  --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
  --epochs 10 \
  --batch-size 64 \
  --center-mode random \
  --extra-negatives 1024 \
  --center-neg-k 0 \
  --counterfactual-mode vae \
  --counterfactual-weights /content/drive/MyDrive/dazzled/outputs/cf_vae/cf_vae_final.safetensors \
  --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/hashnet \
  --domain-mode regex \
  --domain-regex '(?:^|/)(ffhq|openimages|openimg|mobileviews?)(?:/|_)' \
  --workers 2 \
  --prefetch-factor 1 \
  --pin-memory \
  --amp \
  --channels-last \
  --allow-tf32 \
  --cudnn-benchmark \
  --lr 3e-4 \
  --warmup-steps 500 \
  --center-weight 10 \
  --distinct-weight 0.5 \
  --quant-weight 0.1 \
  --cf-weight 0.1 \
  --dhd-weight 0.1 \
  --adv-weight 0


## 5. List Checkpoints


In [None]:
from pathlib import Path

CKPT_DIR = Path("/content/drive/MyDrive/dazzled/outputs/hashnet")
ckpts = sorted(CKPT_DIR.glob("*.safetensors"))
print(f"Found {len(ckpts)} checkpoints")
for ckpt in ckpts:
    print(ckpt.name)


## 6. TTC Inference (Production-Style)

Run the standalone TTC inference script on a sample image.


In [None]:
import os
os.chdir("/content/DaZZLeD/ml-core")

from pathlib import Path

CKPT_DIR = Path("/content/drive/MyDrive/dazzled/outputs/hashnet")
IMAGE_PATH = "/content/drive/MyDrive/dazzled/data/ffhq/224/00000.jpg"  # TODO: set a real path

ckpts = sorted(CKPT_DIR.glob("*.safetensors"))
if not ckpts:
    raise FileNotFoundError(f"No checkpoints in {CKPT_DIR}")

checkpoint = str(ckpts[-1])
print(f"Using checkpoint: {checkpoint}")

!python inference.py   --image "{IMAGE_PATH}"   --checkpoint "{checkpoint}"   --backbone resnet50   --hash-dim 128   --proj-dim 512   --ttc-views 8   --stability-threshold 0.9   --hamming-threshold 10
