# In [None]:  (Jupyter cell marker from the notebook export; commented out so the file parses as Python)
# ============================================================
# FOCI SIMULATOR (v4 - "natural foci")
#  - Noise ONLY inside cell (optional)
#  - Mix of clean/noisy cells
#  - Foci added LAST in a microscope-like way (PSF blur + local contrast)
#  - Foci masks: no holes (fill_holes + closing)
# ============================================================

import os
import glob
import math
import random
import numpy as np
import pandas as pd
from PIL import Image, ImageOps
from scipy.ndimage import (
    gaussian_filter,
    binary_dilation,
    binary_erosion,
    binary_opening,
    binary_closing,
    binary_fill_holes,
)

# -----------------------------
# Repro
# -----------------------------
def set_seed(seed=42):
    """Seed both Python's ``random`` module and NumPy's legacy global RNG with *seed*."""
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

# -----------------------------
# I/O helpers
# -----------------------------
def ensure_dir(p):
    """Create directory *p* and any missing parents; a no-op if it already exists."""
    os.makedirs(name=p, exist_ok=True)

def stem_no_ext(p):
    """Return the file name of *p* with its directory and final extension removed."""
    file_name = os.path.basename(p)
    stem, _extension = os.path.splitext(file_name)
    return stem

def load_rgb_resized(path, size=224):
    """Load the image at *path* as RGB and resize it to ``size x size`` (bilinear).

    The source file is opened in a ``with`` block so its handle is released
    even for multi-frame formats (e.g. multi-page TIFF), which PIL otherwise
    keeps open after loading.

    Returns:
      PIL.Image.Image in mode "RGB", of size (size, size).
    """
    with Image.open(path) as src:
        # convert() returns a new, fully loaded image independent of `src`.
        img = src.convert("RGB")
    return img.resize((size, size), resample=Image.BILINEAR)

def load_mask01_resized(path, size=224):
    """Load a mask image, resize it to ``size x size`` and binarize it.

    The image is converted to 8-bit grayscale, resized with nearest-neighbour
    resampling (so no intermediate grey levels are invented), scaled to
    [0, 1] and thresholded at 0.5. The file is opened in a ``with`` block so
    its handle is always released.

    Returns:
      float32 array of shape (size, size) containing only 0.0 and 1.0.
    """
    with Image.open(path) as src:
        # convert() returns a new, fully loaded image independent of `src`.
        gray = src.convert("L")
    gray = gray.resize((size, size), resample=Image.NEAREST)
    m = np.asarray(gray).astype(np.float32) / 255.0
    return (m > 0.5).astype(np.float32)  # (H,W) 0/1

def to_rgb_uint8(gray01):
    """Turn a [0, 1] grayscale array into a 3-channel uint8 array (last axis = RGB).

    Values are clipped to [0, 1], scaled to 0..255 and rounded with NumPy's
    round-half-to-even before the cast; all three channels are identical.
    """
    scaled = np.round(np.clip(gray01, 0.0, 1.0) * 255.0)
    channel = scaled.astype(np.uint8)
    return np.repeat(channel[..., np.newaxis], 3, axis=-1)

def save_u8_png(arr_u8, path):
    """Write the uint8 array *arr_u8* to *path* through PIL (format chosen from the file name)."""
    image = Image.fromarray(arr_u8)
    image.save(path)

# -----------------------------
# Sampling + placement
# -----------------------------
def sample_n_foci(rng, positive, neg_range=(0, 4), pos_range=(5, 20), mode="triangular"):
    """Draw an integer foci count from ``pos_range`` or ``neg_range`` (inclusive).

    If the range is inverted, its upper bound is raised to the lower bound.
    With ``mode="triangular"`` and a non-degenerate range, the count is a
    rounded draw from a triangular distribution peaked at the range midpoint.
    Otherwise it is drawn uniformly from the inclusive integer range.
    """
    lo, hi = pos_range if positive else neg_range
    lo = int(lo)
    hi = max(int(hi), lo)

    if mode != "triangular" or hi == lo:
        return int(rng.integers(lo, hi + 1))

    draw = rng.triangular(lo, (lo + hi) / 2.0, hi)
    return int(np.clip(int(round(draw)), lo, hi))

def pick_centers_in_mask(rng, mask01, n, min_dist_px=6, max_tries=20000):
    """Pick ``n`` (row, col) integer centres inside ``mask01 > 0.5``.

    Candidates are drawn uniformly from the mask pixels. A candidate is kept
    when it is at least ``min_dist_px`` away from every centre kept so far.
    Once more than ``max_tries // 3`` attempts have passed and at least one
    centre exists, the spacing rule is dropped. If ``max_tries`` runs out, the
    remaining slots are filled with unconstrained random mask pixels.

    Returns an empty list when the mask has no pixels; duplicates are possible.
    """
    ys, xs = np.where(mask01 > 0.5)
    n_candidates = len(ys)
    if n_candidates == 0:
        return []

    min_d2 = min_dist_px ** 2
    relax_after = max_tries // 3
    centers = []

    def _draw():
        k = int(rng.integers(0, n_candidates))
        return int(ys[k]), int(xs[k])

    for attempt in range(max_tries):
        if len(centers) >= n:
            break
        cy, cx = _draw()
        far_enough = all(
            (cy - py) ** 2 + (cx - px) ** 2 >= min_d2 for (py, px) in centers
        )
        relaxed = bool(centers) and attempt > relax_after
        if far_enough or relaxed:
            centers.append((cy, cx))

    # Budget exhausted: fill any remaining slots without the spacing rule.
    while len(centers) < n:
        centers.append(_draw())

    return centers[:n]

# -----------------------------
# Focus generator utilities
# -----------------------------
def elliptical_gaussian(H, W, cy, cx, sx, sy, theta_rad, amp=1.0):
    """Render a rotated 2-D Gaussian on an H x W pixel grid.

    The Gaussian is centred at (cy, cx), with standard deviation ``sx`` along
    the axis rotated by ``theta_rad`` from the x axis and ``sy`` along the
    perpendicular axis. Its peak value is ``amp``.

    Returns a float32 array of shape (H, W) (given a Python-float ``amp``).
    """
    rows, cols = np.mgrid[0:H, 0:W]
    dy = rows - cy
    dx = cols - cx

    cos_t = math.cos(theta_rad)
    sin_t = math.sin(theta_rad)
    # Coordinates in the Gaussian's own (rotated) frame.
    u = cos_t * dx + sin_t * dy
    v = -sin_t * dx + cos_t * dy

    quad = u * u / (2 * sx * sx) + v * v / (2 * sy * sy)
    profile = np.exp(-quad).astype(np.float32)
    return amp * profile

def _seal_mask_no_holes(mask_bool, rng=None):
    """
    Make sure there are NO holes and boundaries are slightly sealed.
    """
    m = binary_fill_holes(mask_bool)
    m = binary_closing(m, iterations=1)
    if rng is not None and rng.random() < 0.25:
        m = binary_dilation(m, iterations=1)
        m = binary_erosion(m, iterations=1)
        m = binary_fill_holes(m)
        m = binary_closing(m, iterations=1)
    return m

def make_one_focus_irregular(
    H, W, cy, cx, rng,
    size_px_range=(2.0, 10.0),
    amp_range=(0.6, 1.4),
    # Rings can create holes; keep them rare
    type_probs=(0.70, 0.03, 0.27)
):
    """
    Synthesize the intensity map and binary mask of a single focus.

    Focus types (chosen with probabilities ``type_probs``, which must sum to 1):
      0: lumpy blob  - sum of 2-3 jittered, rotated elliptical Gaussians
      1: ring-ish    - outer minus inner Gaussian (rare); holes are sealed
      2: jagged blob - one elliptical Gaussian modulated by smoothed noise

    Args:
      H, W: output height / width in pixels.
      cy, cx: focus centre (row, column).
      rng: numpy ``Generator``; all randomness is drawn from it.
      size_px_range: range of the base Gaussian scale ``s`` in pixels.
      amp_range: range of the base amplitude.
      type_probs: probabilities of types 0, 1, 2.

    Returns:
      fmap: float32 (H, W) intensity map (not normalized)
      fmask: bool (H, W) mask, thresholded from fmap relative to its own max,
             then hole-filled via ``_seal_mask_no_holes``.
    """
    focus_type = int(rng.choice([0, 1, 2], p=np.array(type_probs)))
    base_amp = float(rng.uniform(*amp_range))
    theta = float(rng.uniform(0, math.pi))
    s = float(rng.uniform(*size_px_range))

    if focus_type == 0:
        k = int(rng.integers(2, 4))  # 2 or 3 sub-blobs
        fmap = np.zeros((H, W), np.float32)
        for _ in range(k):
            dy = float(rng.normal(0, s * 0.35))
            dx = float(rng.normal(0, s * 0.35))
            sx = float(rng.uniform(s * 0.5, s * 1.2))
            sy = float(rng.uniform(s * 0.4, s * 1.3))
            amp = base_amp * float(rng.uniform(0.6, 1.0))
            fmap += elliptical_gaussian(
                H, W, cy + dy, cx + dx, sx, sy,
                theta + float(rng.normal(0, 0.4)),
                amp=amp
            )

        m = fmap / (fmap.max() + 1e-12)
        thr = float(rng.uniform(0.22, 0.36))
        fmask = (m >= thr)

        if rng.random() < 0.5:
            fmask = binary_dilation(fmask, iterations=int(rng.integers(1, 3)))
        if rng.random() < 0.4:
            n_erode = int(rng.integers(0, 2))
            # SciPy treats iterations < 1 as "repeat until the mask stops
            # changing", which erodes the whole blob away. A draw of 0 must
            # therefore mean "no erosion", not "erode to nothing".
            if n_erode > 0:
                eroded = binary_erosion(fmask, iterations=n_erode)
                # Never let erosion delete the focus outright: its intensity
                # is still in fmap, so an empty mask would mislabel it.
                if eroded.any():
                    fmask = eroded

        fmask = _seal_mask_no_holes(fmask, rng=rng)
        return fmap, fmask

    if focus_type == 1:
        # ring-ish: outer - inner. We'll still seal/fill holes.
        sx = float(rng.uniform(s * 0.7, s * 1.4))
        sy = float(rng.uniform(s * 0.7, s * 1.4))
        outer = elliptical_gaussian(H, W, cy, cx, sx, sy, theta, amp=base_amp)
        inner = elliptical_gaussian(
            H, W,
            cy + float(rng.normal(0, 0.5)),
            cx + float(rng.normal(0, 0.5)),
            sx * float(rng.uniform(0.35, 0.6)),
            sy * float(rng.uniform(0.35, 0.6)),
            theta + float(rng.normal(0, 0.2)),
            amp=base_amp * float(rng.uniform(0.8, 1.1))
        )
        fmap = np.clip(outer - inner, 0.0, None)

        m = fmap / (fmap.max() + 1e-12)
        thr = float(rng.uniform(0.18, 0.30))
        fmask = (m >= thr)
        fmask = _seal_mask_no_holes(fmask, rng=rng)
        return fmap, fmask

    # jagged blob: Gaussian multiplied by a smooth noise field in [0.85, 1.30]
    sx = float(rng.uniform(s * 0.6, s * 1.25))
    sy = float(rng.uniform(s * 0.6, s * 1.25))
    base = elliptical_gaussian(H, W, cy, cx, sx, sy, theta, amp=base_amp)

    noise = rng.normal(0.0, 1.0, size=(H, W)).astype(np.float32)
    noise = gaussian_filter(noise, sigma=float(rng.uniform(1.0, 2.2)))
    noise = (noise - noise.min()) / (noise.max() - noise.min() + 1e-12)
    warp = 0.85 + 0.45 * noise
    fmap = base * warp

    m = fmap / (fmap.max() + 1e-12)
    thr = float(rng.uniform(0.24, 0.42))
    fmask = (m >= thr)

    if rng.random() < 0.35:
        opened = binary_opening(fmask, iterations=1)
        # Opening can erase a very small mask; keep the original in that case.
        if opened.any():
            fmask = opened
    if rng.random() < 0.55:
        fmask = binary_closing(fmask, iterations=1)

    fmask = _seal_mask_no_holes(fmask, rng=rng)
    return fmap, fmask

def generate_foci_map_and_mask(
    gray01, cell_mask01, rng, n_foci,
    size_px_range=(2.0, 10.0),
    amp_range=(0.6, 1.4),
    min_dist_px=6
):
    """
    Build the combined intensity map and binary mask for ``n_foci`` foci in a cell.

    Args:
      gray01: 2-D image; only its shape (H, W) is used here.
      cell_mask01: (H, W) cell mask; pixels > 0.5 count as inside the cell.
      rng: numpy ``Generator``.
      n_foci: number of focus centres to place (see ``pick_centers_in_mask``).
      size_px_range, amp_range: forwarded to ``make_one_focus_irregular``.
      min_dist_px: preferred minimum spacing between focus centres.

    Returns:
      foci_map01: float32 (H,W) in 0..1, zero outside the cell
      foci_mask: bool (H,W) hole-sealed, False outside the cell
    """
    H, W = gray01.shape
    inside = cell_mask01 > 0.5
    centers = pick_centers_in_mask(rng, cell_mask01, n_foci, min_dist_px=min_dist_px)

    fmap = np.zeros((H, W), np.float32)
    fmask = np.zeros((H, W), bool)

    for (cy, cx) in centers:
        fm, mk = make_one_focus_irregular(
            H, W, cy, cx, rng,
            size_px_range=size_px_range,
            amp_range=amp_range
        )
        fmap += fm
        fmask |= mk

    # restrict to cell
    fmap *= cell_mask01.astype(np.float32)
    fmask &= inside

    # normalize map to 0..1
    mx = float(fmap.max())
    if mx > 1e-8:
        fmap = fmap / mx
    fmap = np.clip(fmap, 0.0, 1.0)

    # Final seal, then restrict to the cell again: the seal's closing /
    # dilation / hole filling can grow the mask across the cell border
    # (e.g. bridging a concave boundary). That would label pixels as foci
    # where fmap is zero.
    fmask = _seal_mask_no_holes(fmask, rng=rng) & inside
    return fmap, fmask

# -----------------------------
# Noise ONLY inside cell
# -----------------------------
def add_cell_body_noise_only(
    img01, cell_mask01, rng,
    cell_noise_sigma_range=(0.01, 0.05),
    lowfreq_amp_range=(0.00, 0.10),
    speck_prob=0.0006,          # keep small (specks can look like fake foci)
    blur_prob=0.20
):
    """
    Perturb the cell body of a 2-D [0, 1] image; pixels outside the cell keep their values.

    The steps below run in order. Each one touches only pixels where
    ``cell_mask01 > 0.5`` and clips the result to [0, 1]:
      1. additive Gaussian noise with sigma ~ U(cell_noise_sigma_range);
      2. smooth low-frequency texture. This is blurred noise rescaled to
         [-1, 1], scaled by an amplitude ~ U(lowfreq_amp_range);
      3. rare bright specks, each pixel chosen with probability ``speck_prob``.
         All specks are raised by one shared offset ~ U(0.15, 0.45);
      4. with probability ``blur_prob``, a partial blend toward a lightly
         blurred copy of the image.

    The input array is not modified; a new array is returned.
    """
    H, W = img01.shape
    inside = cell_mask01 > 0.5
    result = img01.copy()

    # 1) per-pixel Gaussian noise
    noise_sigma = float(rng.uniform(*cell_noise_sigma_range))
    pixel_noise = rng.normal(0.0, noise_sigma, size=(H, W)).astype(np.float32)
    result[inside] = np.clip(result[inside] + pixel_noise[inside], 0.0, 1.0)

    # 2) low-frequency texture
    texture_amp = float(rng.uniform(*lowfreq_amp_range))
    if texture_amp > 0:
        field = rng.normal(0.0, 1.0, size=(H, W)).astype(np.float32)
        field = gaussian_filter(field, sigma=float(rng.uniform(10.0, 22.0)))
        f_lo, f_hi = field.min(), field.max()
        field = (field - f_lo) / (f_hi - f_lo + 1e-12)  # 0..1
        field = (field - 0.5) * 2.0  # -1..1
        result[inside] = np.clip(result[inside] + texture_amp * field[inside], 0.0, 1.0)

    # 3) very rare specks, inside the cell only
    if (speck_prob or 0) > 0:
        speck_px = (rng.random((H, W)) < speck_prob) & inside
        if speck_px.any():
            result[speck_px] = np.clip(result[speck_px] + rng.uniform(0.15, 0.45), 0.0, 1.0)

    # 4) optional soft blur blend (microscope softness)
    if rng.random() < blur_prob:
        softened = gaussian_filter(result, sigma=float(rng.uniform(0.3, 1.0)))
        mix = float(rng.uniform(0.15, 0.50))
        result[inside] = np.clip((1 - mix) * result[inside] + mix * softened[inside], 0.0, 1.0)

    return result

# -----------------------------
# Normalization helper (match mean/std)
# -----------------------------
def match_mean_std(target01, ref01, mask01=None, eps=1e-6):
    """Affinely rescale *target01* so its mean/std match those of *ref01*.

    The statistics come from all pixels, or only from pixels where
    ``mask01 > 0.5`` when a mask is given. The resulting affine transform is
    applied to the WHOLE target array, including pixels outside the mask.
    If the mask selects fewer than 20 pixels, *target01* is returned unchanged
    (the same object). The target std is floored at ``eps``. The output is not
    clipped.
    """
    if mask01 is not None:
        selected = mask01 > 0.5
        if selected.sum() < 20:
            return target01
        t_vals, r_vals = target01[selected], ref01[selected]
    else:
        t_vals, r_vals = target01, ref01

    t_mu = float(t_vals.mean())
    t_sd = max(float(t_vals.std()), eps)
    r_mu = float(r_vals.mean())
    r_sd = float(r_vals.std())

    return (target01 - t_mu) / t_sd * r_sd + r_mu

def gamma_jitter(img01, rng, gamma_range=(0.90, 1.10)):
    """Clip *img01* to [0, 1] and raise it to a random exponent ~ U(gamma_range)."""
    exponent = float(rng.uniform(*gamma_range))
    clipped = np.clip(img01, 0.0, 1.0)
    return clipped ** exponent

# -----------------------------
# Natural foci add (core + PSF + local contrast)
# -----------------------------
def add_foci_natural(
    cell_bg01, foci_map01, foci_mask, rng,
    alpha_range=(0.35, 1.10),
    psf_sigma_range=(0.6, 1.4),
    foci_gamma_range=(0.60, 0.95),
    halo_strength_range=(0.10, 0.30),
    local_bg_sigma_range=(4.0, 8.0),
    min_local_delta_range=(0.02, 0.05),
    edge_soft_sigma_range=(0.9, 1.8),
):
    """
    Brighten ``cell_bg01`` where foci are, instead of pasting a hard binary boost.

    Pipeline (random parameters are drawn from ``rng`` in this order:
    local-bg sigma, PSF sigma, gamma, soft-edge sigma, boost strength,
    alpha, minimum delta):
      - estimate the local background by a wide Gaussian blur of ``cell_bg01``;
      - blur the foci map (PSF), renormalise it to a peak of 1 and apply a
        gamma < 1 to flatten the cores;
      - multiply the map by ``1 + strength * DoG(mask)``. The DoG is a
        difference of Gaussian blurs (sigma 1.2 minus 3.2) of the mask,
        clipped to >= 0;
      - brightness delta = alpha * map * (0.35 + 0.65 * (1 - local_bg)),
        which is larger on darker background; it is floored at
        ``min_local_delta * map``;
      - weight the delta by a Gaussian-softened copy of the mask and add it.

    Returns a new array clipped to [0, 1]; the delta is never negative.
    """
    # Local nuclear background estimate (wide blur).
    local_bg = gaussian_filter(cell_bg01, sigma=float(rng.uniform(*local_bg_sigma_range)))

    # PSF blur of the foci intensity, renormalised to peak 1.
    spot = gaussian_filter(foci_map01.astype(np.float32), sigma=float(rng.uniform(*psf_sigma_range)))
    peak = float(spot.max())
    if peak > 1e-8:
        spot = spot / peak
    spot = np.clip(spot, 0.0, 1.0)

    # Core shaping.
    spot = spot ** float(rng.uniform(*foci_gamma_range))

    # Soft blending weight (no hard edges).
    mask_f = foci_mask.astype(np.float32)
    weight = np.clip(
        gaussian_filter(mask_f, sigma=float(rng.uniform(*edge_soft_sigma_range))),
        0.0, 1.0,
    )

    # Subtle DoG-based modulation of the map.
    boost = float(rng.uniform(*halo_strength_range))
    dog = gaussian_filter(mask_f, sigma=1.2) - gaussian_filter(mask_f, sigma=3.2)
    dog = np.clip(dog, 0.0, 1.0)
    spot = np.clip(spot * (1.0 + boost * dog), 0.0, 1.0)

    # Brightness delta: stronger on darker background (avoids saturation).
    gain = float(rng.uniform(*alpha_range))
    delta = gain * spot * (0.35 + 0.65 * (1.0 - local_bg))

    # Minimum local contrast, still shaped by the map.
    floor = float(rng.uniform(*min_local_delta_range))
    delta = np.maximum(delta, floor * spot)

    delta = delta * weight
    return np.clip(cell_bg01 + delta, 0.0, 1.0)

# -----------------------------
# Main generator
# -----------------------------
# File extensions searched for both cell masks and their source images.
_IMAGE_EXTS = ("png", "jpg", "jpeg", "tif", "tiff", "bmp")


def _find_mask_files(input_dir):
    """Return the sorted paths in *input_dir* whose file stem contains ``_mask``."""
    found = []
    for ext in _IMAGE_EXTS:
        found += glob.glob(os.path.join(input_dir, f"*_*mask*.{ext}"))
    return sorted(p for p in found if "_mask" in stem_no_ext(p))


def _find_source_image(input_dir, base_stem):
    """Return the first existing ``<base_stem>.<ext>`` in *input_dir*, or None."""
    for ext in _IMAGE_EXTS:
        cand = os.path.join(input_dir, f"{base_stem}.{ext}")
        if os.path.exists(cand):
            return cand
    return None


def generate_sim_foci_from_masked_folder(
    input_dir,
    output_dir,
    out_csv_path,
    input_size=224,
    variants_per_image=3,
    seed=42,

    # label & count
    p_positive=0.5,
    neg_range=(0, 4),
    pos_range=(5, 20),
    count_mode="triangular",
    pos_threshold=5,

    # foci shape
    size_px_range=(2.0, 10.0),
    amp_range=(0.6, 1.4),
    min_dist_px=6,

    # clean/noisy mixture
    P_CLEAN_CELL=0.30,
    CLEAN_SMOOTH_SIGMA_RANGE=(0.8, 2.0),

    # noise params (noisy samples only)
    cell_noise_sigma_range=(0.01, 0.05),
    lowfreq_amp_range=(0.00, 0.10),
    speck_prob=0.0006,
    blur_prob=0.20,

    # gamma
    gamma_range=(0.90, 1.10),
    APPLY_GAMMA_ON_CLEAN=False,

    # match stats
    match_stats_in_cell=True,

    # natural foci params
    foci_alpha_range=(0.35, 1.10),
    foci_psf_sigma_range=(0.6, 1.4),
    foci_gamma_range=(0.60, 0.95),
    foci_halo_strength_range=(0.10, 0.30),
    foci_local_bg_sigma_range=(4.0, 8.0),
    foci_min_local_delta_range=(0.02, 0.05),
    foci_edge_soft_sigma_range=(0.9, 1.8),
):
    """
    Simulate foci for every ``<stem>_mask`` / ``<stem>.<ext>`` pair in a folder.

    For each cell mask in ``input_dir``, the matching image is loaded, resized
    to ``input_size`` x ``input_size`` and converted to grayscale. Then
    ``variants_per_image`` samples are rendered:

      1. sample a foci count (``p_positive``, ``neg_range`` / ``pos_range``,
         ``count_mode``);
      2. build a foci intensity map and a hole-sealed foci mask inside the cell;
      3. build a clean (in-cell smoothed, prob. ``P_CLEAN_CELL``) or noisy
         cell background;
      4. optionally match the in-cell mean/std to the source image. The affine
         transform from ``match_mean_std`` applies to the whole image. Then
         apply gamma jitter (noisy samples, or all if ``APPLY_GAMMA_ON_CLEAN``);
      5. add the foci with ``add_foci_natural``.

    Each sample is saved in ``output_dir`` as ``<stem>_simNNN.png``, plus
    ``<stem>_simNNN_focimask.png`` and ``<stem>_simNNN_cellmask.png``. One
    CSV row per sample is written to ``out_csv_path``; its directory is
    created if missing.

    All randomness comes from ``np.random.default_rng(seed)``. ``set_seed``
    additionally seeds ``random`` / ``np.random``.

    Returns:
      pandas.DataFrame with one row per sample and columns image_path,
      mask_path, label, n_foci, is_clean_cell, src_image_path,
      src_cellmask_path, cell_mask_path. ``label`` is 1 when
      ``n_foci >= pos_threshold`` else 0.

    Raises:
      FileNotFoundError: if no ``*_mask`` files are found in ``input_dir``.
    """
    ensure_dir(output_dir)
    set_seed(seed)
    rng = np.random.default_rng(seed)

    # find mask files
    mask_files = _find_mask_files(input_dir)
    if len(mask_files) == 0:
        raise FileNotFoundError(f"No '*_mask' files found in {input_dir}")

    # Create the CSV's folder up front so a missing directory fails fast
    # rather than after every image has already been rendered and saved.
    csv_dir = os.path.dirname(out_csv_path)
    if csv_dir:
        ensure_dir(csv_dir)

    rows = []

    for mpath in mask_files:
        mstem = stem_no_ext(mpath)
        base_stem = mstem[:-5] if mstem.endswith("_mask") else mstem.replace("_mask", "")

        # find corresponding image
        img_path = _find_source_image(input_dir, base_stem)
        if img_path is None:
            print(f"[WARN] No image found for mask: {mpath}")
            continue

        img_pil = load_rgb_resized(img_path, size=input_size)
        base_gray01 = np.asarray(ImageOps.grayscale(img_pil)).astype(np.float32) / 255.0
        cell_mask01 = load_mask01_resized(mpath, size=input_size)

        if (cell_mask01 > 0.5).sum() < 50:
            print(f"[WARN] Cell mask too small/empty: {mpath}")
            continue

        for v in range(variants_per_image):
            # --- sample label/count
            positive = (rng.random() < float(p_positive))
            n_foci = sample_n_foci(rng, positive, neg_range=neg_range, pos_range=pos_range, mode=count_mode)
            label = 1 if n_foci >= int(pos_threshold) else 0

            # --- generate foci map + hole-sealed mask
            foci_map01, foci_mask = generate_foci_map_and_mask(
                base_gray01, cell_mask01, rng, n_foci,
                size_px_range=size_px_range,
                amp_range=amp_range,
                min_dist_px=min_dist_px
            )

            # ======================================================
            # 1) Make cell background (clean or noisy)
            # ======================================================
            is_clean = (rng.random() < float(P_CLEAN_CELL))
            cell_bg01 = base_gray01.copy()

            if is_clean:
                # smooth only inside cell to reduce texture and make foci pop
                sig = float(rng.uniform(*CLEAN_SMOOTH_SIGMA_RANGE))
                smooth = gaussian_filter(cell_bg01, sigma=sig)
                m = (cell_mask01 > 0.5)
                cell_bg01[m] = smooth[m]
            else:
                cell_bg01 = add_cell_body_noise_only(
                    cell_bg01, cell_mask01, rng,
                    cell_noise_sigma_range=cell_noise_sigma_range,
                    lowfreq_amp_range=lowfreq_amp_range,
                    speck_prob=speck_prob,
                    blur_prob=blur_prob
                )

            if match_stats_in_cell:
                cell_bg01 = match_mean_std(cell_bg01, base_gray01, mask01=cell_mask01)
                cell_bg01 = np.clip(cell_bg01, 0.0, 1.0)

            if (not is_clean) or APPLY_GAMMA_ON_CLEAN:
                cell_bg01 = gamma_jitter(cell_bg01, rng, gamma_range=gamma_range)

            cell_bg01 = np.clip(cell_bg01, 0.0, 1.0)

            # ======================================================
            # 2) Add foci LAST in a natural way (no hard boost)
            # ======================================================
            sim01 = add_foci_natural(
                cell_bg01, foci_map01, foci_mask, rng,
                alpha_range=foci_alpha_range,
                psf_sigma_range=foci_psf_sigma_range,
                foci_gamma_range=foci_gamma_range,
                halo_strength_range=foci_halo_strength_range,
                local_bg_sigma_range=foci_local_bg_sigma_range,
                min_local_delta_range=foci_min_local_delta_range,
                edge_soft_sigma_range=foci_edge_soft_sigma_range,
            )

            # --- save
            out_stem = f"{base_stem}_sim{v:03d}"
            out_img_path = os.path.join(output_dir, f"{out_stem}.png")
            out_foci_mask_path = os.path.join(output_dir, f"{out_stem}_focimask.png")
            out_cell_mask_copy = os.path.join(output_dir, f"{out_stem}_cellmask.png")

            save_u8_png(to_rgb_uint8(sim01), out_img_path)
            save_u8_png((foci_mask.astype(np.uint8) * 255), out_foci_mask_path)
            save_u8_png((cell_mask01.astype(np.uint8) * 255), out_cell_mask_copy)

            rows.append({
                "image_path": out_img_path,
                "mask_path": out_foci_mask_path,
                "label": int(label),
                "n_foci": int(n_foci),
                "is_clean_cell": int(is_clean),
                "src_image_path": img_path,
                "src_cellmask_path": mpath,
                "cell_mask_path": out_cell_mask_copy
            })

    df_out = pd.DataFrame(rows)
    df_out.to_csv(out_csv_path, index=False)
    print("Saved CSV:", out_csv_path)
    print("Num simulated samples:", len(df_out))
    return df_out

# ============================================================
# EXAMPLE USAGE
# ============================================================
if __name__ == "__main__":
    # Example batch paths; point these three at another batch folder to process it.
    INPUT_DIR = "/content/drive/MyDrive/FYP/gH2AX_1Gy_4Gy_8Gy_10Gy_Replicates2+3_patches_224_batched_nuclie_channel_batch/batch_0023"
    OUTPUT_DIR = "/content/drive/MyDrive/FYP/foci_simulation_out_v04_natural/gH2AX_1Gy_4Gy_8Gy_10Gy_Replicates2+3_patches_224/batch_0023"
    OUT_CSV = "/content/drive/MyDrive/FYP/foci_simulation_out_v04_natural/gH2AX_1Gy_4Gy_8Gy_10Gy_Replicates2+3_patches_224/batch_0023.csv"

    sim_config = dict(
        input_size=224,
        variants_per_image=5,
        seed=42,
        # mix of clean/noisy cells
        P_CLEAN_CELL=0.35,
        APPLY_GAMMA_ON_CLEAN=False,
        # foci count -> label (label = n_foci >= pos_threshold)
        p_positive=0.5,
        neg_range=(0, 4),
        pos_range=(5, 20),
        pos_threshold=5,
        # foci shape variety
        size_px_range=(2.0, 12.0),
        amp_range=(0.6, 1.5),
        min_dist_px=5,
        # noisy cells: keep specks very low (specks can look like fake foci)
        cell_noise_sigma_range=(0.01, 0.06),
        lowfreq_amp_range=(0.00, 0.10),
        speck_prob=0.0005,
        blur_prob=0.20,
        # natural foci strength (tune this first)
        foci_alpha_range=(0.40, 1.20),
        foci_psf_sigma_range=(0.7, 1.6),
        foci_gamma_range=(0.60, 0.92),
        foci_halo_strength_range=(0.08, 0.22),
        foci_min_local_delta_range=(0.02, 0.05),
    )

    df = generate_sim_foci_from_masked_folder(
        input_dir=INPUT_DIR,
        output_dir=OUTPUT_DIR,
        out_csv_path=OUT_CSV,
        **sim_config,
    )

    print(df.head())
