# ZED-style Zero-Shot Detector (Entropy-based)

**What this notebook does**

- Implements a *ZED-style* detector for AI-generated images without using the original authors' code.
- Uses a pretrained neural **entropy model** from **CompressAI** to compute an entropy score (bits-per-pixel, bpp).
- **No fake data needed for training**: calibrate a threshold using *real images only*.
- Evaluate on your own real/fake sets; visualize histograms & ROC; export CSV scores.

**Folder layout expected**

```
data/
  real/   # put real images here (jpg/png)
  fake/   # put synthetic images here (optional, for evaluation)
```

**How to run**
1. Upload this notebook to **Google Colab** (recommended) or run locally with Python 3.10+.
2. Create `data/real` and (optionally) `data/fake` and add images.
3. Run all cells.
4. For deployment, use the *Single Image Inference* cell.

> ⚠️ This is a faithful *reproduction of the idea* behind ZED using open components. Results depend on your chosen entropy model and preprocessing.


In [1]:
#@title 0) Setup (installs) — run once per environment
import sys
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    !pip -q install compressai==1.2.6 torch torchvision scikit-learn pillow matplotlib
else:
    print("If packages are missing, run: pip install compressai torch torchvision scikit-learn pillow matplotlib")


If packages are missing, run: pip install compressai torch torchvision scikit-learn pillow matplotlib


In [3]:
!pip install compressai==1.2.6


Collecting compressai==1.2.6
  Downloading compressai-1.2.6.tar.gz (163 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.9/163.9 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting einops (from compressai==1.2.6)
  Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)
Collecting torch-geometric>=2.3.0 (from compressai==1.2.6)
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
Collecting pytorch-msssim (from compressai==1.2.6)
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[

In [4]:
#@title 1) Imports & utilities
import os, glob, math, json, random
from pathlib import Path
from typing import List, Tuple

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve, accuracy_score, f1_score

from compressai.zoo import bmshj2018_hyperprior, cheng2020_attn

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using device:', DEVICE)

IMG_EXTS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}

def list_images(folder: str) -> List[str]:
    p = Path(folder)
    if not p.exists():
        return []
    return sorted([str(x) for x in p.rglob('*') if x.suffix.lower() in IMG_EXTS])

def load_image(path: str, max_side: int = 512) -> torch.Tensor:
    """Load image -> torch.Tensor [1,3,H,W] in [0,1]. Resizes so the longer side <= max_side."""
    img = Image.open(path).convert('RGB')
    w, h = img.size
    scale = min(1.0, max_side / max(w, h))
    if scale < 1.0:
        new_w, new_h = int(w*scale), int(h*scale)
        img = img.resize((max(1,new_w), max(1,new_h)), Image.LANCZOS)
    arr = np.asarray(img).astype(np.float32) / 255.0
    x = torch.from_numpy(arr).permute(2,0,1).unsqueeze(0)
    return x

def choose_entropy_model(model_name: str = 'bmshj2018_hyperprior', quality: int = 8):
    """Load a pretrained CompressAI model. quality in [1..8] (higher = better rate-distortion, more compute)."""
    if model_name == 'bmshj2018_hyperprior':
        m = bmshj2018_hyperprior(quality=quality, pretrained=True)
    elif model_name == 'cheng2020_attn':
        m = cheng2020_attn(pretrained=True)
    else:
        raise ValueError('Unknown model_name')
    m.eval().to(DEVICE)
    try:
        m.update()  # updates entropy parameters if needed
    except Exception as e:
        print('Warning: model.update() failed:', e)
    return m

def bpp_from_likelihoods(x: torch.Tensor, out: dict) -> torch.Tensor:
    """Estimate bits-per-pixel using model likelihoods (faster than full arithmetic coding)."""
    N, C, H, W = x.shape
    total_bits = 0.0
    for k, lik in out['likelihoods'].items():
        # Ensure numerical stability
        lik = torch.clamp(lik, min=1e-9)
        total_bits += torch.sum(-torch.log2(lik))
    bpp = total_bits / (N * H * W)
    return bpp

def zed_score(image_path: str, model, multiscale=(1.0, 0.75, 0.5)) -> float:
    """Compute a ZED-style surprisal score: average bpp across a few downscale factors."""
    scores = []
    base = Image.open(image_path).convert('RGB')
    for s in multiscale:
        w, h = base.size
        img = base if s == 1.0 else base.resize((max(1,int(w*s)), max(1,int(h*s))), Image.LANCZOS)
        x = torch.from_numpy(np.asarray(img).astype(np.float32)/255.0).permute(2,0,1).unsqueeze(0).to(DEVICE)
        with torch.inference_mode():
            out = model(x)
            bpp = bpp_from_likelihoods(x, out)
        scores.append(float(bpp.detach().cpu()))
    return float(np.mean(scores))

def batch_scores(paths: List[str], model, desc: str = '') -> List[float]:
    scores = []
    for i, p in enumerate(paths):
        sc = zed_score(p, model)
        scores.append(sc)
        if (i+1) % 10 == 0:
            print(f"{desc} {i+1}/{len(paths)}: current={sc:.4f}")
    return scores


Using device: cpu


  @amp.autocast(enabled=False)


In [5]:
#@title 2) Load model
model_name = 'bmshj2018_hyperprior'  #@param ['bmshj2018_hyperprior', 'cheng2020_attn']
quality = 8  #@param {type:'slider', min:1, max:8, step:1}
model = choose_entropy_model(model_name, quality)
print('Model loaded:', model_name, 'quality', quality)


Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-8-a583f0cf.pth.tar" to /Users/kuan_/.cache/torch/hub/checkpoints/bmshj2018-hyperprior-8-a583f0cf.pth.tar
100%|██████████| 46.0M/46.0M [03:02<00:00, 265kB/s]

Model loaded: bmshj2018_hyperprior quality 8





In [9]:
# --- Setup: download Tiny ImageNet and export N real images for ZED ---
import os, zipfile, urllib.request, random, shutil
from PIL import Image

ROOT = "data/tiny-imagenet"
OUT  = "data/zed"
N_TRAIN, N_VAL = 2000, 500   # change counts as you like
SIZE = 256                   # resize target (ZED usually uses 224/256)

os.makedirs(ROOT, exist_ok=True)
os.makedirs(f"{OUT}/train/real", exist_ok=True)
os.makedirs(f"{OUT}/val/real", exist_ok=True)

zip_path = f"{ROOT}/tiny-imagenet-200.zip"
if not os.path.exists(zip_path):
    urllib.request.urlretrieve(
        "http://cs231n.stanford.edu/tiny-imagenet-200.zip", zip_path
    )

with zipfile.ZipFile(zip_path, 'r') as zf:
    zf.extractall(ROOT)

# Collect all train/val image paths
def collect(img_dir):
    paths = []
    for cls in os.listdir(img_dir):
        p = os.path.join(img_dir, cls, "images")
        if os.path.isdir(p):
            for f in os.listdir(p):
                if f.lower().endswith((".jpg", ".jpeg", ".png")):
                    paths.append(os.path.join(p, f))
    return paths

train_imgs = collect(os.path.join(ROOT, "tiny-imagenet-200", "train"))
val_dir    = os.path.join(ROOT, "tiny-imagenet-200", "val", "images")
val_imgs   = [os.path.join(val_dir, f) for f in os.listdir(val_dir)
              if f.lower().endswith((".jpg", ".jpeg", ".png"))]

random.seed(42)
random.shuffle(train_imgs)
random.shuffle(val_imgs)
train_imgs = train_imgs[:N_TRAIN]
val_imgs   = val_imgs[:N_VAL]

def export(paths, outdir):
    for i, src in enumerate(paths):
        try:
            im = Image.open(src).convert("RGB")
            im = im.resize((SIZE, SIZE), Image.BICUBIC)
            im.save(os.path.join(outdir, f"real_{i:06d}.jpg"), quality=95)
        except Exception as e:
            print("skip", src, e)

export(train_imgs, f"{OUT}/train/real")
export(val_imgs,   f"{OUT}/val/real")

print("Done. Sample:",
      len(os.listdir(f'{OUT}/train/real')), "train,",
      len(os.listdir(f'{OUT}/val/real')),   "val")


KeyboardInterrupt: 

In [10]:
import os, urllib.request, zipfile, random, shutil

os.makedirs("data/real", exist_ok=True)
os.makedirs("data/fake", exist_ok=True)

# --- Download a tiny set of real images ---
url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
zip_path = "tiny-imagenet.zip"
if not os.path.exists(zip_path):
    urllib.request.urlretrieve(url, zip_path)

with zipfile.ZipFile(zip_path, 'r') as zf:
    zf.extractall(".")

# Copy 50 real images into data/real
src_dir = "tiny-imagenet-200/train/n01443537/images"
for i, f in enumerate(os.listdir(src_dir)[:50]):
    shutil.copy(os.path.join(src_dir, f), f"data/real/real_{i:03d}.jpg")

# --- Download some fake faces (StyleGAN on Kaggle mirror) ---
fake_url = "https://github.com/jeffheaton/WhichFaceIsReal-dataset/raw/master/fake.zip"
fake_zip = "fake.zip"
urllib.request.urlretrieve(fake_url, fake_zip)
with zipfile.ZipFile(fake_zip, 'r') as zf:
    zf.extractall("data/fake")

print("Real images:", len(os.listdir("data/real")))
print("Fake images:", len(os.listdir("data/fake")))


HTTPError: HTTP Error 404: Not Found

In [8]:
#@title 3) Point to your data folders
REAL_DIR = 'data/real'  #@param {type:'string'}
FAKE_DIR = 'data/fake'  #@param {type:'string'}

real_paths = list_images(REAL_DIR)
fake_paths = list_images(FAKE_DIR)
print(f"Found {len(real_paths)} real images, {len(fake_paths)} fake images")
assert len(real_paths) > 0, 'Please add some images into data/real first.'


Found 0 real images, 0 fake images


AssertionError: Please add some images into data/real first.

In [None]:
#@title 4) Calibrate threshold on REAL images only
target_fpr = 0.05  #@param {type:'number'}
np.random.seed(0)
calib_subset = real_paths  # you can subsample if you have many
real_scores = batch_scores(calib_subset, model, desc='Real')
thr = float(np.quantile(real_scores, 1.0 - target_fpr))
print(f"Calibrated threshold @FPR~{target_fpr:.2f}: {thr:.4f} bpp")

# Save for later use
os.makedirs('artifacts', exist_ok=True)
json.dump({'threshold_bpp': thr, 'model_name': model_name, 'quality': quality}, open('artifacts/zed_threshold.json','w'))
np.savetxt('artifacts/real_scores.csv', np.array(real_scores), delimiter=',')
print('Saved artifacts to artifacts/ directory')


In [None]:
#@title 5) Evaluate on REAL and FAKE (if available)
def predict_label(score, thr):
    return 1 if score > thr else 0  # 1=fake, 0=real

all_y, all_s = [], []
print('Scoring real set...')
real_scores_eval = batch_scores(real_paths, model, desc='RealEval')
all_y += [0]*len(real_scores_eval)
all_s += real_scores_eval

fake_scores_eval = []
if len(fake_paths) > 0:
    print('Scoring fake set...')
    fake_scores_eval = batch_scores(fake_paths, model, desc='FakeEval')
    all_y += [1]*len(fake_scores_eval)
    all_s += fake_scores_eval

metrics = {}
if len(set(all_y)) == 2:
    auroc = roc_auc_score(all_y, all_s)
    fpr, tpr, _ = roc_curve(all_y, all_s)
    preds = [predict_label(s, thr) for s in all_s]
    acc = accuracy_score(all_y, preds)
    f1 = f1_score(all_y, preds)
    metrics = {'AUROC': auroc, 'ACC@thr': acc, 'F1@thr': f1}
    print('Metrics:', metrics)
else:
    print('Only real images found; evaluated calibration only.')

# Save detailed CSV
import csv
with open('artifacts/scores_detailed.csv','w', newline='') as f:
    w = csv.writer(f)
    w.writerow(['path','label(0=real,1=fake)','score_bpp','pred(0=real,1=fake)'])
    for p, s in zip(real_paths, real_scores_eval):
        w.writerow([p, 0, s, predict_label(s, thr)])
    for p, s in zip(fake_paths, fake_scores_eval):
        w.writerow([p, 1, s, predict_label(s, thr)])
print('Saved artifacts/scores_detailed.csv')

# Plots
plt.figure(figsize=(6,4))
if len(real_scores_eval):
    plt.hist(real_scores_eval, bins=40, alpha=0.6, label='real')
if len(fake_scores_eval):
    plt.hist(fake_scores_eval, bins=40, alpha=0.6, label='fake')
plt.axvline(thr, linestyle='--', label=f'Threshold={thr:.3f}')
plt.xlabel('Entropy score (bpp)')
plt.ylabel('Count')
plt.title('Score distribution')
plt.legend()
plt.show()

if len(metrics):
    plt.figure(figsize=(6,4))
    plt.plot(fpr, tpr)
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.title(f'ROC (AUROC={metrics["AUROC"]:.3f})')
    plt.show()


In [None]:
#@title 6) Single Image Inference (deploy-style)
import json
cfg = json.load(open('artifacts/zed_threshold.json')) if os.path.exists('artifacts/zed_threshold.json') else None
if cfg:
    print('Loaded threshold config:', cfg)
else:
    print('No saved threshold found; using current settings.')
    cfg = {'threshold_bpp': thr, 'model_name': model_name, 'quality': quality}

TEST_IMAGE = ''  #@param {type:'string'}
if TEST_IMAGE:
    score = zed_score(TEST_IMAGE, model)
    decision = 'FAKE' if score > cfg['threshold_bpp'] else 'REAL'
    print(f"Image: {TEST_IMAGE}\nScore (bpp): {score:.4f}\nDecision: {decision}")
else:
    print('Set TEST_IMAGE to a file path to run inference.')


In [None]:
#@title 7) (Optional) Ensemble over multiple models
def zed_score_ensemble(image_path: str, models: list) -> float:
    return float(np.mean([zed_score(image_path, m) for m in models]))

## Example usage:
# models = [choose_entropy_model('bmshj2018_hyperprior', q) for q in (6,8)]
# sc = zed_score_ensemble('data/real/example.jpg', models)
# print('Ensemble score:', sc)


## Notes & Tips
- **Calibration**: The 95th percentile threshold on real scores targets ~5% FPR. Adjust `target_fpr` to your needs.
- **Speed**: Using `bpp_from_likelihoods` avoids full arithmetic coding and is faster, while remaining faithful to the idea.
- **Robustness**: You can improve robustness by averaging scores over multiple scales and mild JPEG compressions.
- **Security**: Like all detectors, this can be attacked. Consider ensembling, input randomization, and frequency-domain checks for production.
