In [1]:
import os
import sys
import json
import math
import shutil
import random
import subprocess
from pathlib import Path

import numpy as np
from PIL import Image

In [2]:
# Workspace root = four directory levels above the kernel's current working directory.
# NOTE: this depends on where the kernel was started (os.getcwd()), not on the notebook
# file's own location, so launching Jupyter from another folder changes main_dir.
main_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd()))))
# Appended on every execution of this cell; re-running adds a duplicate entry.
sys.path.append(main_dir)

print("Main directory added to sys.path:", main_dir)

Main directory added to sys.path: /data/inr/llm


In [3]:
# Root of the real training split. The scan further down looks for "Rural" and "Urban"
# sub-folders directly under this directory.
DATASET_DIR = os.path.join(
    main_dir, "Datasets","LOVEDA", "Train", "Train"
)

In [4]:
# GPU selection. These variables live in os.environ, so the generator subprocess started
# later inherits them. With a single visible device, DEVICE = "cuda:0" below is expected
# to address physical GPU 7 — TODO confirm on the target machine.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
# setdefault: an allocator configuration already present in the environment wins.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# Real dataset root as a Path (same location as DATASET_DIR).
REAL_ROOT = Path(DATASET_DIR)

# Generator checkpoints (machine-specific absolute paths; adjust per environment)
LAYOUT_CKPT = "/data/inr/llm/DIFF_CD/Diffusor/outputsV3/layout_d3pm_masked_sparse_80k_domain_cond"
# Step number of the layout checkpoint inside LAYOUT_CKPT (passed as --layout_checkpoint).
LAYOUT_CHECKPOINT = 79000
LAYOUT_DIFFUSION_TYPE = "d3pm"
DOMAIN_COND_SCALE = 1.0

CONTROLNET_CKPT = "/data/inr/llm/DIFF_CD/Diffusor/outputsV3/controlnet_ratio_lora_ckpt18000_layout80000/checkpoint-112000"
BASE_MODEL = "/home/nvidia/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14"

# Output synthetic dataset root (a LoveDA-like folder structure is created under it)
# Saved under: <main_dir>/DIFF_CD/Diffusor/SyntheticDataset/...
SYNTH_ROOT = Path(main_dir) / "DIFF_CD" / "Diffusor" /"SyntheticDataset"
# Total number of synthetic samples to generate (split between domains by the quota cell)
SYNTH_TOTAL = 2000

# Prompt templates (domain-specific); keys are the lowercase domain names used throughout
PROMPTS = {
    "rural": "A high-resolution satellite image of a rural area",
    "urban": "A high-resolution satellite image of an urban area",
}

# Inference settings; each one is forwarded as a command-line flag to sample_pair.py
IMAGE_SIZE = 1024
NUM_STEPS_LAYOUT = 200
NUM_STEPS_IMAGE = 50
GUIDANCE_SCALE = 1.0
GUIDANCE_RESCALE = 0.0
CONTROL_SCALE = 1.0
LORA_SCALE = 1.0
DTYPE = "fp16"
DEVICE = "cuda:0"
SAMPLER = "ddim"

# Base value from which the per-attempt generator seeds are derived
SEED0 = 123


In [5]:
# Balancing strategy: per-image class ratios are NOT fully balanced.
# Instead we:
#   - enforce a domain quota so (real rural + synthetic rural) ≈ (real urban + synthetic urban)
#   - within each domain, steer generation by the *pixel deficit* of every class
#     against a flattened target distribution
#   - prompt ratios ONLY within safe bounds around the real-domain mean μ_d,k

# Target flattening within each domain:
#   t_nonbg ∝ (pi_nonbg)^BETA, with BETA<1 => flatter => minorities get deficits => generated more
BETA = 0.55

# Extra minority prioritization factor:
#   score_k ∝ deficit_k * clip(pi_k, EPS, 1)^(-GAMMA)
GAMMA = 0.7

# Safety caps: maximum *relative* boost of the prompted ratio over the real-domain mean,
# i.e. the prompted ratio is kept <= μ_d,k * (1 + MAX_DELTA[class]).
# propose_ratio falls back to 0.7 for a class name missing from this dict.
MAX_DELTA = {
    "background": 0.15,
    "building": 1.20,
    "road": 1.10,
    "water": 0.80,
    "barren": 1.00,
    "forest": 0.35,
    "agriculture": 0.35,
}

# Acceptance criteria for a generated layout with class proportions p
# (prevents "ratio too far from average" failures); all three must hold.
FOCUS_ABS_TOL = 0.09      # |p_k - r_k| <= this   (k = focus class, r_k = requested ratio)
GLOBAL_MAX_ABS = 0.32     # max_j |p_j - μ_j| <= this
KL_MAX = 0.90             # KL(p || μ) <= this

# Maximum generator calls per planned request before the round is declared failed
MAX_ATTEMPTS_PER_SAMPLE = 12

# Class order; the list index is the indexed label value (0 = background ... 6 = agriculture).
# Stated to match DEFAULT_LOVEDA_CLASS_NAMES in the generator pipeline — not verifiable from here.
CLASS_NAMES = ["background", "building", "road", "water", "barren", "forest", "agriculture"]
K = len(CLASS_NAMES)


In [6]:

# ---------------------------
# 2) Utility: find repo root / sample_pair script
# ---------------------------
def find_repo_root() -> Path:
    """
    Locate the repository that ships src/scripts/sample_pair.py.

    The search starts at the resolved current working directory and climbs
    towards the filesystem root. At every level three layouts are probed, in
    this order: the directory itself, a "SyntheticGen" child, a "Diffusor" child.

    Returns:
        The first directory (in search order) that contains the script.

    Raises:
        FileNotFoundError: when no level of the hierarchy matches.
    """
    script_rel = Path("src") / "scripts" / "sample_pair.py"
    # "" stands for "the directory itself"; the others are known nesting folders.
    nestings = ("", "SyntheticGen", "Diffusor")

    start = Path(os.getcwd()).resolve()
    for level in (start, *start.parents):
        for nesting in nestings:
            root = level / nesting if nesting else level
            if (root / script_rel).exists():
                return root
    raise FileNotFoundError("Could not locate src/scripts/sample_pair.py from current path parents.")

# Resolve the generator repository once; raises FileNotFoundError if the script is not found.
REPO_ROOT = find_repo_root()
# Entry-point script that is launched as a subprocess for every generation attempt.
SAMPLE_PAIR = REPO_ROOT / "src" / "scripts" / "sample_pair.py"
print("Repo root:", REPO_ROOT)
print("sample_pair:", SAMPLE_PAIR)
print("Synthetic dataset root:", SYNTH_ROOT)


Repo root: /data/inr/llm/DIFF_CD/Diffusor/SyntheticGen
sample_pair: /data/inr/llm/DIFF_CD/Diffusor/SyntheticGen/src/scripts/sample_pair.py
Synthetic dataset root: /data/inr/llm/DIFF_CD/Diffusor/SyntheticDataset


In [7]:

# ---------------------------
# 3) Utility: mask loading and histogram (robust to raw vs indexed)
# ---------------------------
# Label value of "ignore" pixels in indexed masks; such pixels are excluded from histograms.
IGNORE_INDEX = 255

def load_mask_as_indexed(mask_path: Path, encoding: str = "auto") -> np.ndarray:
    """
    Load a label mask and return it in indexed format:
      classes: 0..6, ignore: 255

    Args:
        mask_path: mask image to read; for multi-channel images only channel 0 is used.
        encoding: how the stored labels are encoded
          - "raw":     LoveDA raw (0=ignore, 1..7=classes); converted to indexed
          - "indexed": already 0..6 with optional 255 ignore; returned unchanged
          - "auto":    (default) decide from the value range, see below

    Returns:
        int64 array with the spatial shape of the mask.

    Raises:
        ValueError: if ``encoding`` is not one of the three names above.

    Auto-detection: a mask whose values all lie in 0..7 is treated as raw; a mask
    containing any larger value (e.g. 255) is treated as indexed. The presence of a
    0 pixel is deliberately NOT required: raw masks without ignore pixels contain
    only 1..7 and must still be shifted, otherwise every class is off by one and
    label 7 lands outside the valid 0..6 range. A mask with values only in 0..6 is
    inherently ambiguous; pass ``encoding`` explicitly when the source is known.
    """
    if encoding not in ("auto", "raw", "indexed"):
        raise ValueError(f"encoding must be 'auto', 'raw' or 'indexed', got {encoding!r}")

    # The context manager releases the file handle as soon as the pixels are copied out.
    with Image.open(mask_path) as img:
        arr = np.array(img)
    if arr.ndim == 3:
        arr = arr[:, :, 0]
    out = arr.astype(np.int64)

    if encoding == "auto":
        # An empty mask carries no evidence either way; leave it untouched.
        looks_raw = out.size > 0 and int(out.min()) >= 0 and int(out.max()) <= 7
        encoding = "raw" if looks_raw else "indexed"

    if encoding == "raw":
        ignore = out == 0
        out[ignore] = IGNORE_INDEX
        out[~ignore] -= 1  # 1..7 -> 0..6
    return out

def hist_counts(mask_idx: np.ndarray) -> np.ndarray:
    """
    Per-class pixel counts of an indexed mask.

    Ignore pixels and any label outside 0..K-1 are discarded before counting,
    so the result always has exactly K entries (float64).
    """
    labels = mask_idx.ravel()
    countable = (labels != IGNORE_INDEX) & (labels >= 0) & (labels < K)
    return np.bincount(labels[countable], minlength=K).astype(np.float64)

def safe_kl(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> float:
    """
    KL(p || q) in nats, made finite by clipping.

    Both inputs are clipped to [eps, 1] and renormalized to sum to one before
    the divergence is evaluated, so zero entries never produce log(0) or x/0.
    """
    def _as_distribution(values: np.ndarray) -> np.ndarray:
        clipped = np.clip(values, eps, 1.0)
        return clipped / clipped.sum()

    p_n = _as_distribution(p)
    q_n = _as_distribution(q)
    return float(np.sum(p_n * np.log(p_n / q_n)))

# ---------------------------
# 4) Scan real dataset: domain counts + pixel totals
# ---------------------------
def domain_dirs(root: Path, domain_title: str) -> tuple[Path, Path]:
    """
    Resolve the image and mask folders of one domain.

    The LoveDA names ("images_png", "masks_png") are preferred. When one of
    them is missing, the first existing alternative name is used instead; if
    none exists either, the preferred (non-existent) path is returned as-is.

    Returns:
        (images_dir, masks_dir)
    """
    domain_root = root / domain_title

    def _pick(preferred: str, alternatives: tuple) -> Path:
        primary = domain_root / preferred
        if primary.exists():
            return primary
        existing = (domain_root / name for name in alternatives if (domain_root / name).exists())
        return next(existing, primary)

    imgs = _pick("images_png", ("images", "imgs", "JPEGImages"))
    masks = _pick("masks_png", ("masks", "labels", "SegmentationClass"))
    return imgs, masks

def scan_domain(root: Path, domain_title: str) -> dict:
    """
    Aggregate class statistics over every *.png mask of one domain.

    Returns a dict with:
      n               number of mask files
      totals          per-class pixel counts summed over all masks (length K)
      valid_px_total  total number of counted (non-ignored, in-range) pixels
      avg_valid_px    valid_px_total divided by the number of masks
      masks_dir, imgs_dir   the folders that were resolved, as strings

    Raises:
        FileNotFoundError: if the mask folder contains no *.png file.
    """
    imgs_dir, masks_dir = domain_dirs(root, domain_title)
    mask_files = sorted(masks_dir.glob("*.png"))
    if not mask_files:
        raise FileNotFoundError(f"No masks found for {domain_title} under {masks_dir}")

    totals = np.zeros(K, dtype=np.float64)
    valid_px_total = 0.0
    for mask_file in mask_files:
        counts = hist_counts(load_mask_as_indexed(mask_file))
        totals += counts
        valid_px_total += float(counts.sum())

    n_masks = len(mask_files)
    return {
        "n": n_masks,
        "totals": totals,
        "valid_px_total": valid_px_total,
        "avg_valid_px": valid_px_total / max(1, n_masks),
        "masks_dir": str(masks_dir),
        "imgs_dir": str(imgs_dir),
    }

# Scan both real domains once. Each result holds the mask count ("n"), per-class pixel
# totals ("totals") and valid-pixel statistics as returned by scan_domain.
real_r = scan_domain(REAL_ROOT, "Rural")
real_u = scan_domain(REAL_ROOT, "Urban")

print("Real counts:", {"rural": real_r["n"], "urban": real_u["n"]})
print("Real avg valid pixels:", {"rural": round(real_r["avg_valid_px"], 2), "urban": round(real_u["avg_valid_px"], 2)})

# Real-domain mean proportions μ_d (used for safe prompting + acceptance).
# max(1.0, ...) guards the division when a domain has no counted pixels.
mu_real = {
    "rural": real_r["totals"] / max(1.0, real_r["totals"].sum()),
    "urban": real_u["totals"] / max(1.0, real_u["totals"].sum()),
}


Real counts: {'rural': 1366, 'urban': 1156}
Real avg valid pixels: {'rural': 696951.32, 'urban': 982431.24}


In [8]:

# ---------------------------
# 5) Domain quota: balance Rural vs Urban after coupling with original
# ---------------------------
# Number of real masks per domain.
nR, nU = real_r["n"], real_u["n"]

# Per-domain size (real + synthetic) that would split the combined dataset evenly.
target_per_domain = math.floor((nR + nU + SYNTH_TOTAL) / 2.0)

# Synthetic samples each domain needs to reach that size (never negative).
need_r = max(0, target_per_domain - nR)
need_u = max(0, target_per_domain - nU)
allocated = need_r + need_u
rem = SYNTH_TOTAL - allocated
# Under-allocation (e.g. one sample lost to floor on an odd total): hand out the
# leftover, with urban receiving the odd sample.
if rem > 0:
    need_r += rem // 2
    need_u += rem - rem // 2

# Over-allocation (one domain alone would need more than SYNTH_TOTAL): trim one sample
# at a time, always from the larger request, until the budget is respected.
while (need_r + need_u) > SYNTH_TOTAL:
    if need_r >= need_u and need_r > 0:
        need_r -= 1
    elif need_u > 0:
        need_u -= 1
    else:
        break

# Number of synthetic samples to generate per domain (keys match PROMPTS).
domain_quota = {"rural": need_r, "urban": need_u}
print("Synthetic domain quota:", domain_quota, "sum=", sum(domain_quota.values()))


Synthetic domain quota: {'rural': 895, 'urban': 1105} sum= 2000


In [9]:

# ---------------------------
# 6) Build LoveDA-like synthetic structure
# IMPORTANT: we will SAVE masks in *raw LoveDA encoding* (0 ignore, 1..7 classes),
# so LoveDADataset(remap_loveda_labels) will work without hacks.
# ---------------------------
def ensure_synth_structure(root: Path):
    """
    Create the LoveDA-like output tree under root (no-op for folders that exist):

      root/Train/Train/{Rural,Urban}/{images_png,masks_png}
      root/meta
    """
    split_root = root / "Train" / "Train"
    wanted = [
        split_root / dom / sub
        for dom in ("Rural", "Urban")
        for sub in ("images_png", "masks_png")
    ]
    wanted.append(root / "meta")
    for folder in wanted:
        folder.mkdir(parents=True, exist_ok=True)

# Create the output tree up front; safe to re-run because existing folders are kept.
ensure_synth_structure(SYNTH_ROOT)

def save_mask_raw_from_indexed(mask_idx: np.ndarray, out_path: Path):
    """
    Write an indexed mask to disk in LoveDA raw encoding.

    Indexed 0..6 becomes raw 1..7 and the ignore label (255) becomes raw 0.
    The result is stored as an 8-bit single-channel image at out_path.
    """
    is_ignore = mask_idx == IGNORE_INDEX
    # Ignore pixels map to 0; every other label is shifted up by one.
    raw = np.where(is_ignore, 0, mask_idx + 1).astype(np.uint8)
    Image.fromarray(raw, mode="L").save(out_path)


In [10]:

# ---------------------------
# 7) Core math: domain-aware flattened target + pixel deficits -> class choice -> safe ratio prompt
# ---------------------------
# Small positive floor used when clipping proportions and guarding divisions.
EPS = 1e-12

def make_flat_target(pi: np.ndarray, bg_idx: int = 0) -> np.ndarray:
    """
    Build the per-domain target distribution from the current one.

    The background share is carried over unchanged; the remaining mass is
    redistributed over the non-background classes proportionally to
    (share among non-background)^BETA, which flattens it for BETA < 1.
    The result is renormalized to sum to one.
    """
    probs = np.clip(pi, EPS, 1.0)
    probs = probs / probs.sum()
    bg_share = float(probs[bg_idx])

    foreground = probs.copy()
    foreground[bg_idx] = 0.0
    fg_mass = foreground.sum()
    if fg_mass <= 0:
        # Nothing but background: there is nothing to flatten.
        return probs

    flattened = np.power(foreground / fg_mass, BETA)
    flattened = flattened / flattened.sum()

    target = (1.0 - bg_share) * flattened
    target[bg_idx] += bg_share
    return target / target.sum()

def pick_focus_class(deficit: np.ndarray, pi: np.ndarray) -> int:
    """
    Randomly draw the class to focus on for the next sample.

    Sampling weights are deficit_k * clip(pi_k)^(-GAMMA); class 0 (background)
    is never chosen. If all weights vanish, the draw falls back to weights
    proportional to 1 / clip(pi_k), again excluding class 0.
    """
    clipped = np.clip(pi, EPS, 1.0)

    weights = deficit * np.power(clipped, -GAMMA)
    weights[0] = 0.0  # never focus background
    total = weights.sum()

    if total <= 0:
        # No class is in deficit: prefer rare classes purely by inverse frequency.
        weights = np.power(clipped, -1.0)
        weights[0] = 0.0
        total = weights.sum()

    return int(np.random.choice(np.arange(K), p=weights / total))

def propose_ratio(domain: str, focus_k: int, t: np.ndarray, pi: np.ndarray) -> float:
    """
    Choose the ratio to request for the focus class.

    The request starts from the real-domain mean of the class (falling back to
    the current proportion when that mean is not positive), is boosted by 85%
    of the relative gap between target and current proportion (capped by
    MAX_DELTA, default 0.7), jittered with N(0, 0.015) noise, and finally
    clipped first to [0.01, 0.60] and then to [0.75 * mean, (1 + cap) * mean].
    """
    anchor = float(mu_real[domain][focus_k])
    if anchor <= 0:
        anchor = float(pi[focus_k])

    lift = float(t[focus_k] / max(pi[focus_k], EPS))
    cap = float(MAX_DELTA.get(CLASS_NAMES[focus_k], 0.7))
    boost = min(cap, max(0.0, 0.85 * (lift - 1.0)))

    ratio = anchor * (1.0 + boost)
    ratio = float(ratio + np.random.normal(0.0, 0.015))
    ratio = float(np.clip(ratio, 0.01, 0.60))

    lower, upper = anchor * (1.0 - 0.25), anchor * (1.0 + cap)
    return float(np.clip(ratio, lower, upper))

def accept_layout(domain: str, p: np.ndarray, focus_k: int, r_req: float) -> tuple[bool, dict]:
    """
    Decide whether a generated layout with class proportions p is kept.

    A layout passes only if the focus class is within FOCUS_ABS_TOL of the
    requested ratio, no class deviates from the real-domain mean by more than
    GLOBAL_MAX_ABS, and KL(p || mean) does not exceed KL_MAX.

    Returns:
        (accepted, metrics) where metrics has keys "max_abs", "kl", "focus_err".
    """
    reference = mu_real[domain]
    metrics = {
        "max_abs": float(np.max(np.abs(p - reference))),
        "kl": safe_kl(p, reference),
        "focus_err": float(abs(p[focus_k] - r_req)),
    }
    checks = (
        metrics["focus_err"] <= FOCUS_ABS_TOL,
        metrics["max_abs"] <= GLOBAL_MAX_ABS,
        metrics["kl"] <= KL_MAX,
    )
    return all(checks), metrics


In [11]:

# ---------------------------
# 8) Generator call (sample_pair) and dataset writer
# ---------------------------
def run_sample_pair(tmp_out: Path, domain: str, ratios_str: str, seed: int):
    """
    Run one generation with sample_pair.py as a blocking subprocess.

    Args:
        tmp_out: directory handed to the script as --save_dir.
        domain: "rural" or "urban"; also selects the text prompt from PROMPTS.
        ratios_str: value for --ratios, e.g. "barren:0.0748".
        seed: value for --seed.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero status.
    """
    # Flag -> value, in the order the flags are placed on the command line.
    options = {
        "--layout_ckpt": str(LAYOUT_CKPT),
        "--layout_checkpoint": str(LAYOUT_CHECKPOINT),
        "--layout_diffusion_type": str(LAYOUT_DIFFUSION_TYPE),
        "--domain": str(domain),
        "--domain_cond_scale": str(DOMAIN_COND_SCALE),
        "--controlnet_ckpt": str(CONTROLNET_CKPT),
        "--base_model": str(BASE_MODEL),
        "--save_dir": str(tmp_out),
        "--ratios": ratios_str,
        "--prompt": PROMPTS[domain],
        "--image_size": str(IMAGE_SIZE),
        "--num_inference_steps_layout": str(NUM_STEPS_LAYOUT),
        "--num_inference_steps_image": str(NUM_STEPS_IMAGE),
        "--guidance_scale": str(GUIDANCE_SCALE),
        "--guidance_rescale": str(GUIDANCE_RESCALE),
        "--control_scale": str(CONTROL_SCALE),
        "--lora_scale": str(LORA_SCALE),
        "--sampler": str(SAMPLER),
        "--seed": str(seed),
        "--dtype": str(DTYPE),
        "--device": str(DEVICE),
    }

    # NOTE(review): "python3" is resolved through PATH and may not be the kernel's
    # interpreter — confirm it points at the environment that has the generator deps.
    cmd = ["python3", str(SAMPLE_PAIR)]
    for flag, value in options.items():
        cmd.extend((flag, value))

    subprocess.run(cmd, cwd=str(REPO_ROOT), check=True)

def write_synth_sample(domain: str, stem: str, tmp_out: Path):
    """
    Publish one accepted generation from tmp_out into the synthetic dataset.

    image.png is moved to images_png/<stem>.png, layout.png is re-encoded to
    LoveDA raw labels and written to masks_png/<stem>.png, and metadata.json
    is moved to meta/<stem>.json (a minimal {"domain": ...} file is written
    when the generator produced no metadata).

    Returns:
        (image_path, mask_path, meta_path, indexed_mask) where indexed_mask is
        the int64 layout as read from layout.png.
    """
    domain_dir = SYNTH_ROOT / "Train" / "Train" / ("Rural" if domain == "rural" else "Urban")
    img_dst = domain_dir / "images_png" / f"{stem}.png"
    msk_dst = domain_dir / "masks_png" / f"{stem}.png"
    meta_dst = SYNTH_ROOT / "meta" / f"{stem}.json"

    shutil.move(str(tmp_out / "image.png"), str(img_dst))

    layout = np.array(Image.open(tmp_out / "layout.png"))
    # Multi-channel layouts carry the label in channel 0.
    m_idx = (layout[:, :, 0] if layout.ndim == 3 else layout).astype(np.int64)
    save_mask_raw_from_indexed(m_idx, msk_dst)

    meta_src = tmp_out / "metadata.json"
    if not meta_src.exists():
        meta_dst.write_text(json.dumps({"domain": domain}, indent=2))
    else:
        shutil.move(str(meta_src), str(meta_dst))

    return img_dst, msk_dst, meta_dst, m_idx

# ---------------------------
# 9) Main loop: domain quota + pixel-deficit controller
# ---------------------------
# Seed every RNG the controller draws from (random.random in choose_domain,
# np.random in pick_focus_class / propose_ratio) so that a fresh-kernel run replays the
# same domain / focus-class / ratio plan. SEED0 is also the base of the generator seeds.
random.seed(SEED0)
np.random.seed(SEED0)

# Running per-class pixel totals per domain; start from the real data and grow with
# every accepted synthetic mask (copies, so the scan results stay untouched).
cur = {
    "rural": real_r["totals"].copy(),
    "urban": real_u["totals"].copy(),
}
# Average number of valid pixels per real mask, used to project the final pixel budget.
avg_valid_px = {
    "rural": float(real_r["avg_valid_px"]),
    "urban": float(real_u["avg_valid_px"]),
}

# Accepted synthetic samples per domain, and one manifest entry per accepted sample.
accepted = {"rural": 0, "urban": 0}
manifest = []

print(SYNTH_ROOT)

# One JSON record per generation attempt (accepted or rejected) is appended here.
log_path = SYNTH_ROOT / "meta" / "generation_log.jsonl"
print("Logging to:", log_path)

def choose_domain(domain_quota: dict) -> str:
    """
    Pick the domain of the next sample.

    A domain whose quota is already met is never chosen while the other still
    has samples left; when both have samples left, the draw is random with
    probability proportional to the remaining counts. "rural" is returned
    when neither domain has anything left.
    """
    left = {name: domain_quota[name] - accepted[name] for name in ("rural", "urban")}

    # Covers both "urban is done" and "both are done".
    if left["urban"] <= 0:
        return "rural"
    if left["rural"] <= 0:
        return "urban"

    p_rural = left["rural"] / (left["rural"] + left["urban"])
    if random.random() < p_rural:
        return "rural"
    return "urban"

# Total number of synthetic samples to accept across both domains; the generation loop
# runs until this many have been accepted.
total_to_make = sum(domain_quota.values())
print(f"Will generate {total_to_make} synthetic samples under {SYNTH_ROOT}")

/data/inr/llm/DIFF_CD/Diffusor/SyntheticDataset
Logging to: /data/inr/llm/DIFF_CD/Diffusor/SyntheticDataset/meta/generation_log.jsonl
Will generate 2000 synthetic samples under /data/inr/llm/DIFF_CD/Diffusor/SyntheticDataset


In [12]:
# Back-off policy for rounds in which every attempt was rejected.
RATIO_DAMP = 0.85        # multiplicative damping applied to the requested ratio
MAX_BACKOFF_ROUNDS = 3   # consecutive damped retries before re-planning from scratch

# Main generation loop.
# Each *round* plans one request (domain, focus class, ratio) and calls the generator up
# to MAX_ATTEMPTS_PER_SAMPLE times until a layout passes accept_layout. If the whole round
# is rejected, the same request is retried with a damped ratio (at most
# MAX_BACKOFF_ROUNDS times in a row) before a fresh request is planned.
# A generator failure (CalledProcessError) is not caught: it aborts the run.
with open(log_path, "a", encoding="utf-8") as flog:
    global_i = 0
    # Counts rounds, accepted or not. Seeds are derived from it so that a failed round is
    # never replayed with the same seeds; it equals the accepted count as long as no
    # round fails.
    round_i = 0
    # (domain, focus_k, r_req, n_backoffs) carried over from a failed round, or None.
    carry = None

    while (accepted["rural"] + accepted["urban"]) < total_to_make:
        if carry is None:
            domain = choose_domain(domain_quota)

            # Current class proportions of the domain and its flattened target.
            pi = cur[domain] / max(EPS, cur[domain].sum())
            t = make_flat_target(pi)

            # Projected pixel budget once the domain quota is filled.
            rem = (domain_quota[domain] - accepted[domain])
            plan_total = float(cur[domain].sum() + rem * avg_valid_px[domain])

            # Pixels each class still lacks to reach the target at that budget.
            deficit = np.maximum(0.0, plan_total * t - cur[domain])

            focus_k = pick_focus_class(deficit, pi)
            r_req = propose_ratio(domain, focus_k, t, pi)
            n_backoffs = 0
        else:
            # Retry the rejected request with its damped ratio.
            domain, focus_k, r_req, n_backoffs = carry
            carry = None

        focus_name = CLASS_NAMES[focus_k]
        ratios_str = f"{focus_name}:{r_req:.4f}"

        ok = False
        last_metrics = None

        for attempt in range(MAX_ATTEMPTS_PER_SAMPLE):
            seed = SEED0 + round_i * 1000 + attempt
            tmp_out = SYNTH_ROOT / "tmp" / f"{domain}_{accepted[domain]:06d}_try{attempt:02d}"
            if tmp_out.exists():
                shutil.rmtree(tmp_out, ignore_errors=True)
            tmp_out.mkdir(parents=True, exist_ok=True)

            run_sample_pair(tmp_out, domain, ratios_str, seed)

            # The generated layout is read as an indexed mask (no raw conversion).
            layout_path = tmp_out / "layout.png"
            with Image.open(layout_path) as layout_img:
                m_idx = np.array(layout_img)
            if m_idx.ndim == 3:
                m_idx = m_idx[:, :, 0]
            m_idx = m_idx.astype(np.int64)

            counts = hist_counts(m_idx)
            p = counts / max(EPS, counts.sum())

            ok, metrics = accept_layout(domain, p, focus_k, r_req)
            last_metrics = metrics

            # Every attempt is logged, accepted or not, and flushed immediately.
            record = {
                "domain": domain,
                "focus_class": focus_name,
                "ratios": ratios_str,
                "seed": seed,
                "attempt": attempt,
                "ok": bool(ok),
                "metrics": metrics,
                "p": [float(x) for x in p.tolist()],
            }
            flog.write(json.dumps(record) + "\n")
            flog.flush()

            if not ok:
                shutil.rmtree(tmp_out, ignore_errors=True)
                continue

            stem = f"{domain}_{accepted[domain]:06d}"
            img_dst, msk_dst, meta_dst, m_idx2 = write_synth_sample(domain, stem, tmp_out)

            # Feed the accepted mask back into the running totals that drive planning.
            cur[domain] += hist_counts(m_idx2)

            accepted[domain] += 1
            global_i += 1

            manifest.append({
                "domain": domain,
                "image": str(img_dst),
                "mask": str(msk_dst),
                "meta": str(meta_dst),
                "ratios": ratios_str,
                "focus_class": focus_name,
                "metrics": last_metrics,
            })

            shutil.rmtree(tmp_out, ignore_errors=True)

            if global_i % 25 == 0:
                print(f"[{global_i}/{total_to_make}] accepted | quotas={domain_quota} | accepted={accepted} | last={domain}/{ratios_str} metrics={last_metrics}")
            break

        if not ok:
            if n_backoffs < MAX_BACKOFF_ROUNDS:
                # Actually apply the back-off: the next round reuses this domain and
                # focus class with the damped ratio (and fresh seeds).
                r_next = r_req * RATIO_DAMP
                carry = (domain, focus_k, r_next, n_backoffs + 1)
                ratios_str = f"{focus_name}:{r_next:.4f}"
                print(f"WARNING: too many rejects for {domain}/{focus_name}. Backing off ratio to {ratios_str} and continuing.")
            else:
                # Damping did not help; drop the request so the next round re-plans.
                print(f"WARNING: too many rejects for {domain}/{focus_name} after {MAX_BACKOFF_ROUNDS} back-offs. Re-planning.")

        round_i += 1

# Summary of the run: configuration, final counts and one entry per accepted sample.
manifest_path = SYNTH_ROOT / "meta" / "manifest.json"
manifest_path.write_text(json.dumps({
    "real_root": str(REAL_ROOT),
    "synth_root": str(SYNTH_ROOT),
    "synth_total": total_to_make,
    "domain_quota": domain_quota,
    "accepted": accepted,
    "items": manifest,
}, indent=2), encoding="utf-8")

print("Done.")
print("Synthetic dataset:", SYNTH_ROOT)
print("Manifest:", manifest_path)
print("Accepted:", accepted)

INFO - Resolved layout checkpoint to: /data/inr/llm/DIFF_CD/Diffusor/outputsV3/layout_d3pm_masked_sparse_80k_domain_cond/checkpoint-79000
INFO - Inferred --layout_size=256 from --image_size=1024 (override with --layout_size).
`torch_dtype` is deprecated! Use `dtype` instead!
Traceback (most recent call last):
  File "/data/inr/llm/DIFF_CD/Diffusor/SyntheticGen/src/scripts/sample_pair.py", line 1186, in <module>
    main()
  File "/data/inr/llm/DIFF_CD/Diffusor/SyntheticGen/src/scripts/sample_pair.py", line 1137, in main
    images = _vae_decode(vae, latents / vae.config.scaling_factor)
  File "/data/inr/llm/DIFF_CD/Diffusor/SyntheticGen/src/scripts/sample_pair.py", line 390, in _vae_decode
    decoded = vae.decode(latents)
  File "/data/inr/llm/DIFF_CD/Diffusor/diffusers/src/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/data/inr/llm/DIFF_CD/Diffusor/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 237, in

CalledProcessError: Command '['python3', '/data/inr/llm/DIFF_CD/Diffusor/SyntheticGen/src/scripts/sample_pair.py', '--layout_ckpt', '/data/inr/llm/DIFF_CD/Diffusor/outputsV3/layout_d3pm_masked_sparse_80k_domain_cond', '--layout_checkpoint', '79000', '--layout_diffusion_type', 'd3pm', '--domain', 'urban', '--domain_cond_scale', '1.0', '--controlnet_ckpt', '/data/inr/llm/DIFF_CD/Diffusor/outputsV3/controlnet_ratio_lora_ckpt18000_layout80000/checkpoint-112000', '--base_model', '/home/nvidia/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/451f4fe16113bff5a5d2269ed5ad43b0592e9a14', '--save_dir', '/data/inr/llm/DIFF_CD/Diffusor/SyntheticDataset/tmp/urban_000000_try00', '--ratios', 'barren:0.0748', '--prompt', 'A high-resolution satellite image of an urban area', '--image_size', '1024', '--num_inference_steps_layout', '200', '--num_inference_steps_image', '50', '--guidance_scale', '1.0', '--guidance_rescale', '0.0', '--control_scale', '1.0', '--lora_scale', '1.0', '--sampler', 'ddim', '--seed', '123', '--dtype', 'fp16', '--device', 'cuda:0']' returned non-zero exit status 1.