# DaZZLeD Hash Center Training Notebook (ResNet + Counterfactual VAE)

**Goal:** Train the ResNet Hash Center model from `resnet.tex` with counterfactual VAE, CF‑SimCLR, DHD, PGD, and TTC checks.

**Runtime:** Set Colab to GPU before running training cells.

**Note:** Step 4 loads counterfactual VAE weights. If you do not have them yet, train them first (Step 3). For a quick run without the VAE, skip Step 3 and set `--counterfactual-mode aug` in Step 4.


## 0. Mount Google Drive


In [None]:
from google.colab import drive
from pathlib import Path

# Mount Google Drive so datasets, manifests and checkpoints survive the Colab session.
MOUNT_POINT = "/content/drive"
drive.mount(MOUNT_POINT)

# Project layout on Drive: raw data under data/, training artefacts under outputs/.
DRIVE_ROOT = Path(MOUNT_POINT) / "MyDrive" / "dazzled"
DATA_ROOT = DRIVE_ROOT.joinpath("data")
OUTPUT_ROOT = DRIVE_ROOT.joinpath("outputs")
OUTPUT_ROOT.mkdir(exist_ok=True, parents=True)


## 1. Setup & Installation


In [None]:
import os
if not os.path.exists('DaZZLeD'):
    !git clone https://github.com/D13ya/DaZZLeD.git
    %cd DaZZLeD/ml-core
else:
    %cd DaZZLeD/ml-core

!pip install -q -r requirements.txt


## 2.1 Restore Dataset (Drive Zip)

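If you keep a dataset zip on Drive, this cell extracts it to local disk (`/content/data`) for faster I/O. It uses the first source it finds, in this order:

1. `data-cache/training-images.zip`
2. `dazzled_dataset_v4.zip` (both zips live under `MyDrive/dazzled`)
3. `DATASET_URL`, if set

With `BUILD_CACHE_ZIP = True`, it also writes `data-cache/training-images.zip` after a successful extraction, so later sessions can restore from it.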


In [None]:
from pathlib import Path
import zipfile
import shutil
import urllib.request

# Restore training images from a zip (on Drive or at a URL) onto local disk,
# which is faster to read during training than the Drive mount.
DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
LOCAL_DATA = Path("/content/data")
LOCAL_DATA.mkdir(parents=True, exist_ok=True)

RESET_LOCAL = True  # set False to reuse existing /content/data
DATASET_URL = ""    # optional: direct http(s) link to a zip file
BUILD_CACHE_ZIP = True  # write training-images.zip after extraction if missing

# Zip sources, tried in this order: CACHE_ZIP, ALT_ZIP, then DATASET_URL.
CACHE_ZIP = DRIVE_ROOT / "data-cache/training-images.zip"
ALT_ZIP = DRIVE_ROOT / "dazzled_dataset_v4.zip"

if RESET_LOCAL and LOCAL_DATA.exists():
    shutil.rmtree(LOCAL_DATA)
    LOCAL_DATA.mkdir(parents=True, exist_ok=True)

exts = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images(root: Path) -> int:
    """Return how many files under *root* (recursively) have a suffix in ``exts``."""
    return sum(
        1 for p in root.rglob("*")
        if p.is_file() and p.suffix.lower() in exts
    )


def extract_zip(zip_path: Path, dest: Path) -> int:
    """Extract *zip_path* into *dest* and return the image count now under *dest*.

    The count covers everything in *dest*, including files that were already
    there, not only the members of this archive.
    """
    print(f"Extracting {zip_path} -> {dest}")
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(dest)
    extracted = count_images(dest)
    print(f"Extracted {extracted} images to {dest}")
    return extracted


def _part_path(dest: Path) -> Path:
    """Return the temporary sibling path used while *dest* is being written."""
    return dest.with_name(dest.name + ".part")


def download_zip(url: str, dest: Path) -> None:
    """Download *url* to *dest*, publishing the file only once it is complete.

    The data is written to a ``.part`` sibling and renamed on success. Source
    selection below relies only on ``CACHE_ZIP.exists()``, so writing straight
    to *dest* would let an interrupted download pose as a valid cache.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp = _part_path(dest)
    print(f"Downloading {url} -> {dest}")
    try:
        urllib.request.urlretrieve(url, tmp)
        tmp.replace(dest)
    finally:
        # No-op after a successful rename; removes the partial file otherwise.
        tmp.unlink(missing_ok=True)


def _write_cache_zip(src_root: Path, dest: Path) -> None:
    """Zip every file under *src_root* into *dest* (relative arcnames), atomically."""
    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp = _part_path(dest)
    try:
        with zipfile.ZipFile(tmp, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            for p in src_root.rglob("*"):
                if p.is_file():
                    zf.write(p, p.relative_to(src_root))
        tmp.replace(dest)
    finally:
        # Never leave a truncated cache zip behind if zipping is interrupted.
        tmp.unlink(missing_ok=True)


existing = count_images(LOCAL_DATA)
if existing > 0 and not RESET_LOCAL:
    print(f"Local data already present under {LOCAL_DATA} ({existing} images); skipping extract.")
else:
    extracted = 0
    source = None

    if CACHE_ZIP.exists():
        source = CACHE_ZIP
    elif ALT_ZIP.exists():
        source = ALT_ZIP
    elif DATASET_URL:
        download_zip(DATASET_URL, CACHE_ZIP)
        source = CACHE_ZIP

    if source is not None:
        extracted = extract_zip(source, LOCAL_DATA)

    # The outer archive may only wrap another zip; try any nested archives.
    if extracted == 0:
        nested = [p for p in LOCAL_DATA.rglob("*.zip") if p.is_file()]
        for nested_zip in nested:
            if source is not None and nested_zip.resolve() == source.resolve():
                continue
            extracted = extract_zip(nested_zip, LOCAL_DATA)
            if extracted > 0:
                break

    if extracted == 0:
        print("No images found after extraction. Check the zip contents or path.")

    if BUILD_CACHE_ZIP and extracted > 0 and not CACHE_ZIP.exists():
        print(f"Building cache zip at {CACHE_ZIP}")
        _write_cache_zip(LOCAL_DATA, CACHE_ZIP)
        print(f"Wrote cache zip: {CACHE_ZIP}")


In [None]:
from pathlib import Path
from collections import Counter
from PIL import Image

exts = {".jpg", ".jpeg", ".png", ".bmp"}

# Resolve DATA_ROOT with the same rule as the manifest cell (2.2): prefer the
# local extraction, else the Drive folder. Do not reuse a DATA_ROOT from an
# earlier cell. Step 0 sets it to the Drive folder unconditionally, which would
# make this cell skip the local copy restored in 2.1.
DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
LOCAL_DATA = Path("/content/data")
DRIVE_DATA = DRIVE_ROOT / "data"


def has_images(root: Path) -> bool:
    """Return True if *root* exists and contains at least one file with a suffix in ``exts``."""
    if not root.exists():
        return False
    return any(p.is_file() and p.suffix.lower() in exts for p in root.rglob("*"))


DATA_ROOT = LOCAL_DATA if has_images(LOCAL_DATA) else DRIVE_DATA
print(f"Validating DATA_ROOT: {DATA_ROOT}")

# Sorted so the spot-check below samples the same files on every run.
paths = sorted(p for p in DATA_ROOT.rglob("*") if p.is_file() and p.suffix.lower() in exts)
print(f"Total images found: {len(paths)}")

ext_counts = Counter(p.suffix.lower() for p in paths)
print("Extension counts:", dict(ext_counts))

# Fail before the spot-check so an empty dataset is not reported as "all readable".
if not paths:
    raise ValueError("No images found for training. Check extraction paths or zip contents.")

# Decode a bounded sample; convert() loads pixel data, not just the header.
sample = paths[:200]
bad = []
for p in sample:
    try:
        with Image.open(p) as img:
            img.convert("RGB")
    except Exception as e:  # PIL raises several unrelated exception types for bad files
        bad.append((p, str(e)))

if bad:
    print(f"Corrupt/unreadable samples: {len(bad)} (showing up to 5)")
    for p, err in bad[:5]:
        print(f"  - {p}: {err}")
else:
    print("Sample check: all images readable and convertible to RGB.")


## 2.2 Build Manifest (Optional)

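If you already have a manifest at `/content/drive/MyDrive/dazzled/manifests/train.txt`, you can skip this. Running this cell overwrites that file. It lists images from `/content/data` when that folder contains images, and otherwise from the Drive `data` folder.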


In [None]:
from pathlib import Path
import re

DRIVE_ROOT = Path("/content/drive/MyDrive/dazzled")
LOCAL_DATA = Path("/content/data")
DRIVE_DATA = DRIVE_ROOT / "data"
MANIFEST = DRIVE_ROOT / "manifests/train.txt"
MANIFEST.parent.mkdir(parents=True, exist_ok=True)

exts = {".jpg", ".jpeg", ".png", ".bmp"}
# Label = leading "<domain>_<digits>" or bare digits of the file name, so
# files sharing that prefix share a label; other names fall back to the stem.
LABEL_REGEX = re.compile(r"^((?:ffhq|openimages|openimg|mobileviews?)_\d+|\d+)")


def has_images(root: Path) -> bool:
    """Return True if *root* exists and contains at least one file with a suffix in ``exts``."""
    if not root.exists():
        return False
    return any(p.is_file() and p.suffix.lower() in exts for p in root.rglob("*"))


# Prefer the fast local extraction (Step 2.1); fall back to the Drive copy.
DATA_ROOT = LOCAL_DATA if has_images(LOCAL_DATA) else DRIVE_DATA
print(f"Using DATA_ROOT: {DATA_ROOT}")

# Each line is "<absolute path> <label>". Sorted so the manifest is reproducible.
lines = []
for p in sorted(DATA_ROOT.rglob("*")):
    if not p.is_file() or p.suffix.lower() not in exts:
        continue
    match = LABEL_REGEX.search(p.name)
    label = match.group(1) if match else p.stem
    lines.append(f"{p} {label}")

# Refuse to clobber an existing manifest with an empty one (e.g. when neither
# the local extraction nor the Drive data folder is populated).
if not lines:
    raise ValueError(
        f"No images found under {DATA_ROOT}; not overwriting {MANIFEST}."
    )

MANIFEST.write_text("\n".join(lines))
print(f"Wrote {len(lines)} lines to {MANIFEST} (per-image labels)")


## 2.5. Sanity Checks (Labels + Domains)

Run this once after the manifest is created to verify labels/domains before any training.


In [None]:
from pathlib import Path
from collections import Counter
import re

MANIFEST = Path("/content/drive/MyDrive/dazzled/manifests/train.txt")
# The label pattern must match the manifest builder's (2.2), which accepts
# both "mobileview" and "mobileviews". The domain pattern must match the
# --domain-regex passed to the training scripts in Steps 3 and 4.
LABEL_REGEX = re.compile(r"^((?:ffhq|openimages|openimg|mobileviews?)_\d+|\d+)")
DOMAIN_REGEX = re.compile(r"(?:^|/)(ffhq|openimages|openimg|mobileviews?)(?:/|_)")

if not MANIFEST.exists():
    raise FileNotFoundError(f"Manifest not found: {MANIFEST}")

# Relative paths in the manifest are resolved against the manifest's folder.
base = MANIFEST.resolve().parent
lines = [line.strip() for line in MANIFEST.read_text().splitlines() if line.strip() and not line.strip().startswith('#')]
total = len(lines)

labels = []
domains = []
missing = []

for line in lines:
    # Format: "<path> [label]", whitespace-separated.
    parts = line.split()
    path = Path(parts[0])
    if not path.is_absolute():
        path = (base / path).resolve()

    # Explicit label column wins; otherwise try to derive it from the file name.
    label = parts[1] if len(parts) > 1 else None
    if label is None:
        match = LABEL_REGEX.search(path.name)
        if match:
            label = match.group(1)

    domain = None
    match = DOMAIN_REGEX.search(str(path))
    if match:
        domain = match.group(1)

    labels.append(label)
    domains.append(domain)
    if not path.exists():
        missing.append(str(path))

label_known = [str(lbl) for lbl in labels if lbl is not None]
domain_known = [str(dom) for dom in domains if dom is not None]

label_unique = len(set(label_known))
label_unlabeled = total - len(label_known)
label_pct = (label_unlabeled / total * 100.0) if total else 0.0

domain_unique = len(set(domain_known))
domain_unlabeled = total - len(domain_known)
domain_pct = (domain_unlabeled / total * 100.0) if total else 0.0

print(f"Label stats: {label_unique} unique, {label_unlabeled}/{total} unlabeled ({label_pct:.1f}%).")
if label_known:
    top = Counter(label_known).most_common(5)
    top_str = ", ".join(f"{k}:{v}" for k, v in top)
    print(f"Top labels: {top_str}")

print(f"Domain stats: {domain_unique} unique, {domain_unlabeled}/{total} unlabeled ({domain_pct:.1f}%).")
if domain_known:
    top = Counter(domain_known).most_common(5)
    top_str = ", ".join(f"{k}:{v}" for k, v in top)
    print(f"Top domains: {top_str}")

if missing:
    print(f"Missing files: {len(missing)} (showing up to 5)")
    for p in missing[:5]:
        print(f"  - {p}")

if label_unique < 2:
    raise ValueError("CRITICAL: fewer than 2 unique labels found. Check regex/manifest.")
if domain_unique < 2:
    print("WARNING: fewer than 2 unique domains found; VAE training will fail.")


## 3. Train Counterfactual VAE (Save Weights)

This produces the `--counterfactual-weights` file used by HashNet.


In [None]:
import os
os.chdir("/content/DaZZLeD/ml-core")

# Train the counterfactual VAE on the Step 2.2 manifest and save its weights to
# Drive under outputs/cf_vae. Step 4 loads
# outputs/cf_vae/cf_vae_final.safetensors. That is presumably the final file
# this script writes into --checkpoint-dir; confirm the name after the run.
# --domain-regex should stay in sync with DOMAIN_REGEX in the Step 2.5 sanity
# check, which warns when it finds fewer than 2 domains.
!PYTHONPATH=/content/DaZZLeD/ml-core python training/train_counterfactual_vae.py \
  --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
  --epochs 10 \
  --batch-size 96 \
  --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/cf_vae \
  --domain-mode regex \
  --domain-regex "(?:^|/)(ffhq|openimages|openimg|mobileviews?)(?:/|_)" \
  --workers 4 \
  --prefetch-factor 2 \
  --pin-memory \
  --amp


## 4. Train HashNet (ResNet + Hash Centers + CF/DHD/PGD)

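This uses the VAE weights from Step 3 and writes checkpoints to Drive. Note that `--adv-weight 0` sets the adversarial (PGD) loss weight to zero. Raise it to train with PGD.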


In [None]:
import os
os.chdir("/content/DaZZLeD/ml-core")

# Train the ResNet hash-center model. It reads the VAE weights from Step 3
# (--counterfactual-weights) and writes checkpoints to outputs/hashnet, which
# Steps 5 and 6 read.
# PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True is a PyTorch CUDA
# caching-allocator option meant to reduce memory fragmentation.
# --adv-weight 0 zeroes the adversarial (PGD) loss weight. Presumably PGD is
# effectively off unless this is raised; confirm in train_hashnet.py.
# --domain-regex should match the one used in Step 3 and in the Step 2.5 check.
!PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
PYTHONPATH=/content/DaZZLeD/ml-core \
python training/train_hashnet.py \
  --data-list /content/drive/MyDrive/dazzled/manifests/train.txt \
  --epochs 10 \
  --batch-size 64 \
  --center-mode random \
  --extra-negatives 1024 \
  --center-neg-k 0 \
  --counterfactual-mode vae \
  --counterfactual-weights /content/drive/MyDrive/dazzled/outputs/cf_vae/cf_vae_final.safetensors \
  --checkpoint-dir /content/drive/MyDrive/dazzled/outputs/hashnet \
  --domain-mode regex \
  --domain-regex '(?:^|/)(ffhq|openimages|openimg|mobileviews?)(?:/|_)' \
  --workers 2 \
  --prefetch-factor 1 \
  --pin-memory \
  --amp \
  --channels-last \
  --allow-tf32 \
  --cudnn-benchmark \
  --lr 3e-4 \
  --warmup-steps 500 \
  --center-weight 10 \
  --distinct-weight 0.5 \
  --quant-weight 0.1 \
  --cf-weight 0.1 \
  --dhd-weight 0.1 \
  --adv-weight 0


## 5. List Checkpoints


In [None]:
from pathlib import Path

# Enumerate the HashNet checkpoints that Step 4 saved to Drive, in name order.
CKPT_DIR = Path("/content/drive/MyDrive/dazzled/outputs/hashnet")
ckpts = list(CKPT_DIR.glob("*.safetensors"))
ckpts.sort()

print(f"Found {len(ckpts)} checkpoints")
for name in (entry.name for entry in ckpts):
    print(name)


## 6. TTC Inference (Production-Style)

Run the standalone TTC inference script on a sample image.


In [None]:
import os
os.chdir("/content/DaZZLeD/ml-core")

from pathlib import Path

CKPT_DIR = Path("/content/drive/MyDrive/dazzled/outputs/hashnet")
IMAGE_PATH = "/content/drive/MyDrive/dazzled/data/ffhq/224/00000.jpg"  # TODO: set a real path

ckpts = sorted(CKPT_DIR.glob("*.safetensors"))
if not ckpts:
    raise FileNotFoundError(f"No checkpoints in {CKPT_DIR}")

checkpoint = str(ckpts[-1])
print(f"Using checkpoint: {checkpoint}")

!python inference.py   --image "{IMAGE_PATH}"   --checkpoint "{checkpoint}"   --backbone resnet50   --hash-dim 128   --proj-dim 512   --ttc-views 8   --stability-threshold 0.9   --hamming-threshold 10
