In [None]:
# ============================================================
# Foci Simulator (CSV -> auto cell mask -> 20 variations/image)
# - Input CSV: must contain columns: image_path, label (0/1)
# - Auto-extracts cell mask from the image (no mask files needed)
# - Makes cell darker + foci pop out naturally
# - Background outside cell kept black
# - Generates 20 variations per image
# - Saves:
#     <out_dir>/images/<stem>_vXXX.png
#     <out_dir>/masks/<stem>_vXXX_focimask.png
#     <out_dir>/masks/<stem>_vXXX_cellmask.png
#   and a CSV with paths + metadata
# ============================================================

import os
import math
import random
import numpy as np
import pandas as pd
from PIL import Image, ImageOps
from scipy.ndimage import (
    gaussian_filter,
    binary_dilation,
    binary_erosion,
    binary_opening,
    binary_closing,
    binary_fill_holes,
    label as cc_label,
)

def set_seed(seed=42):
    """Seed the global RNGs of the ``random`` module and of ``numpy.random``.

    Args:
        seed: Value handed unchanged to ``random.seed`` and ``np.random.seed``.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)

def ensure_dir(p):
    """Create directory ``p`` including missing parents; do nothing if it exists."""
    os.makedirs(name=p, exist_ok=True)

def stem_no_ext(p):
    """Return the file name of path ``p`` with its last extension removed."""
    file_name = os.path.basename(p)
    stem, _extension = os.path.splitext(file_name)
    return stem

def load_gray01_resized(path, size=224):
    """Load an image as a square grayscale float array in [0, 1].

    The image is opened, converted to RGB, resized to ``size`` x ``size``
    with bilinear resampling, reduced to grayscale and divided by 255.

    Args:
        path: Image file to open.
        size: Edge length in pixels of the (square) output.

    Returns:
        ``np.float32`` array of shape ``(size, size)``.
    """
    rgb = Image.open(path).convert("RGB")
    rgb = rgb.resize((size, size), resample=Image.BILINEAR)
    luminance_u8 = np.asarray(ImageOps.grayscale(rgb))
    return luminance_u8.astype(np.float32) / 255.0  # (H,W) float [0,1]

def to_rgb_uint8(gray01):
    """Convert a [0, 1] grayscale array into a 3-channel uint8 image.

    Values are clipped to [0, 1], scaled by 255 and rounded; the resulting
    plane is repeated along a new trailing axis of length 3.
    """
    levels = np.round(np.clip(gray01, 0.0, 1.0) * 255.0).astype(np.uint8)
    return np.stack((levels,) * 3, axis=-1)

def save_u8_png(arr_u8, path):
    """Wrap ``arr_u8`` in a PIL image and save it to ``path``."""
    picture = Image.fromarray(arr_u8)
    picture.save(path)

def save_mask_png(mask01, path):
    """Save a 0/1 mask to ``path`` as an image with pixel values 0 and 255.

    The mask is cast to uint8 before scaling, so fractional values are
    truncated by that cast rather than thresholded.
    """
    scaled = 255 * mask01.astype(np.uint8)
    Image.fromarray(scaled).save(path)

# -----------------------------
# Otsu threshold
# -----------------------------
def otsu_threshold(x01, nbins=256):
    x = np.clip(x01, 0.0, 1.0).ravel()
    hist, bin_edges = np.histogram(x, bins=nbins, range=(0.0, 1.0))
    hist = hist.astype(np.float64)

    w = hist / (hist.sum() + 1e-12)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) * 0.5

    w0 = np.cumsum(w)
    w1 = 1.0 - w0
    mu0 = np.cumsum(w * bin_centers) / (w0 + 1e-12)
    muT = (w * bin_centers).sum()
    mu1 = (muT - w0 * mu0) / (w1 + 1e-12)

    sigma_b2 = w0 * w1 * (mu0 - mu1) ** 2
    k = int(np.argmax(sigma_b2))
    return float(bin_centers[k])

# -----------------------------
# Auto cell mask extraction
# -----------------------------
def _keep_largest_component(mask):
    """Reduce ``mask`` to its largest connected component, with holes filled.

    A mask with zero or one component is returned unchanged.
    """
    labeled, n_components = cc_label(mask)
    if n_components <= 1:
        return mask
    # One bincount pass yields every component's area; index 0 is background.
    areas = np.bincount(labeled.ravel())[1:]
    biggest = int(np.argmax(areas)) + 1
    return binary_fill_holes(labeled == biggest)


def extract_cell_mask(gray01, rng,
                      pre_smooth_sigma_range=(1.0, 2.2),
                      min_area_frac=0.03):
    """
    Estimate a binary cell mask from a grayscale image.

    Assumes: cell is brighter than background (black-ish background).
    Steps:
      - Gaussian smooth (sigma drawn from ``pre_smooth_sigma_range``)
      - Otsu threshold
      - close (2 iterations) / open (1 iteration) / fill holes
      - keep largest connected component
      - if the result covers less than ``min_area_frac`` of the image, redo
        the segmentation once with a more permissive threshold

    Args:
        gray01: 2-D array of intensities (expected in [0, 1]).
        rng: Random generator providing ``uniform``; one value is drawn.
        pre_smooth_sigma_range: (low, high) for the pre-smoothing sigma.
        min_area_frac: Minimum mask area as a fraction of ``H * W`` below
            which the permissive fallback is used.

    Returns:
        ``np.float32`` array with the shape of ``gray01`` and values {0, 1}.
    """
    H, W = gray01.shape
    sig = float(rng.uniform(*pre_smooth_sigma_range))
    sm = gaussian_filter(gray01, sigma=sig)

    thr = otsu_threshold(sm)
    m = (sm > thr)

    # cleanup
    m = binary_closing(m, iterations=2)
    m = binary_opening(m, iterations=1)
    m = binary_fill_holes(m)
    m = _keep_largest_component(m)

    # ensure not tiny
    if m.sum() < int(min_area_frac * H * W):
        # fallback: more permissive threshold (never below 0.02); no opening
        # here so that faint, thin regions are not removed again
        thr2 = max(0.02, thr * 0.75)
        m = (sm > thr2)
        m = binary_closing(m, iterations=2)
        m = binary_fill_holes(m)
        m = _keep_largest_component(m)

    return m.astype(np.float32)  # (H,W) {0,1}

# -----------------------------
# Sampling + placement
# -----------------------------
def sample_n_foci(rng, label, neg_range=(0, 3), pos_range=(6, 25), mode="triangular"):
    """Draw the number of foci for one image.

    ``pos_range`` is used when ``int(label) == 1``, otherwise ``neg_range``.
    If the upper bound is below the lower bound it is raised to the lower
    bound. With ``mode == "triangular"`` and a non-degenerate range the count
    comes from a triangular distribution peaked at the midpoint (rounded and
    clipped to the range); in every other case it is a uniform integer in
    ``[lo, hi]``.

    Returns:
        The sampled count as a Python int.
    """
    low, high = pos_range if int(label) == 1 else neg_range
    low = int(low)
    high = max(int(high), low)

    if mode != "triangular" or high <= low:
        return int(rng.integers(low, high + 1))

    peak = (low + high) / 2.0
    draw = rng.triangular(low, peak, high)
    return int(np.clip(int(round(draw)), low, high))

def pick_centers_in_mask(rng, mask01, n, min_dist_px=6, max_tries=25000):
    """Pick ``n`` pixel positions inside a mask, spaced apart where possible.

    Candidates are the pixels where ``mask01 > 0.5``. Random candidates are
    accepted only if they lie at least ``min_dist_px`` away from every point
    accepted so far. After ``max_tries`` draws, any remaining slots are
    filled with random candidates without the distance check.

    Returns:
        List of ``(row, col)`` int tuples; empty if the mask has no pixels.
    """
    rows, cols = np.nonzero(mask01 > 0.5)
    n_candidates = len(rows)
    if n_candidates == 0:
        return []

    min_sq_dist = min_dist_px ** 2
    chosen = []
    attempts = 0
    while attempts < max_tries and len(chosen) < n:
        idx = int(rng.integers(0, n_candidates))
        y, x = int(rows[idx]), int(cols[idx])
        far_enough = all(
            (y - qy) ** 2 + (x - qx) ** 2 >= min_sq_dist
            for (qy, qx) in chosen
        )
        if far_enough:
            chosen.append((y, x))
        attempts += 1

    # fallback fill: spacing is no longer enforced
    while len(chosen) < n:
        idx = int(rng.integers(0, n_candidates))
        chosen.append((int(rows[idx]), int(cols[idx])))

    return chosen[:n]

# -----------------------------
# Natural looking foci generator
# -----------------------------
def elliptical_gaussian(H, W, cy, cx, sx, sy, theta_rad, amp=1.0):
    """Render a rotated 2-D Gaussian on an ``H`` x ``W`` grid.

    Args:
        H, W: Grid height and width.
        cy, cx: Centre of the Gaussian (row, column); may be fractional.
        sx, sy: Standard deviations along the rotated x and y axes.
        theta_rad: Rotation angle in radians.
        amp: Multiplier applied to the unit-peak Gaussian.

    Returns:
        ``amp`` times a ``np.float32`` array of shape ``(H, W)``.
    """
    # Open grids broadcast to (H, W) in the expressions below.
    rows, cols = np.ogrid[0:H, 0:W]
    dy = rows - cy
    dx = cols - cx

    cos_t = math.cos(theta_rad)
    sin_t = math.sin(theta_rad)
    u = cos_t * dx + sin_t * dy
    v = -sin_t * dx + cos_t * dy

    exponent = u * u / (2 * sx * sx) + v * v / (2 * sy * sy)
    bump = np.exp(-exponent).astype(np.float32)
    return amp * bump

def seal_mask(mask_bool):
    """Close small gaps in a binary mask: fill holes, close once, fill again."""
    filled = binary_fill_holes(mask_bool)
    closed = binary_closing(filled, iterations=1)
    return binary_fill_holes(closed)

def make_one_focus_natural(H, W, cy, cx, rng,
                           size_px_range=(1.8, 8.5),
                           amp_range=(0.7, 1.6),
                           # more blob-like, fewer donut/ring artifacts
                           type_probs=(0.70, 0.00, 0.30)):
    """
    Render one synthetic focus centred at ``(cy, cx)``.

    Types:
      0: multi-blob lumpy
      2: softly jagged blob (warp)

    Args:
        H, W: Size of the output arrays.
        cy, cx: Focus centre (row, column).
        rng: ``numpy.random.Generator`` used for all random draws.
        size_px_range: (low, high) range of the base size ``s`` in pixels.
        amp_range: (low, high) range of the base amplitude.
        type_probs: 3-tuple of weights; only entries 0 and 2 are used (there
            is no type 1 branch). They are renormalized to sum to 1, so a
            non-zero middle entry no longer makes ``rng.choice`` fail.

    Returns:
        ``(fmap, fmask)``: a float32 ``(H, W)`` intensity map (not
        normalized) and a boolean ``(H, W)`` mask obtained by thresholding
        the peak-normalized map.

    Raises:
        ValueError: If the two usable weights are negative or sum to zero.
    """
    # rng.choice requires p to sum to 1; renormalize the two usable weights.
    weights = np.array([type_probs[0], type_probs[2]], dtype=np.float64)
    total = float(weights.sum())
    if np.any(weights < 0) or not total > 0:
        raise ValueError(
            "type_probs[0] and type_probs[2] must be non-negative and not both zero"
        )
    if total != 1.0:
        weights = weights / total

    focus_type = int(rng.choice([0, 2], p=weights))
    base_amp = float(rng.uniform(*amp_range))
    theta = float(rng.uniform(0, math.pi))
    s = float(rng.uniform(*size_px_range))

    if focus_type == 0:
        # sum of 2-4 jittered elliptical Gaussians around the centre
        k = int(rng.integers(2, 5))
        fmap = np.zeros((H, W), np.float32)
        for _ in range(k):
            dy = float(rng.normal(0, s * 0.35))
            dx = float(rng.normal(0, s * 0.35))
            sx = float(rng.uniform(s * 0.45, s * 1.15))
            sy = float(rng.uniform(s * 0.45, s * 1.25))
            amp = base_amp * float(rng.uniform(0.55, 1.0))
            fmap += elliptical_gaussian(
                H, W, cy + dy, cx + dx, sx, sy,
                theta + float(rng.normal(0, 0.35)),
                amp=amp
            )

        # slightly blur to mimic PSF
        fmap = gaussian_filter(fmap, sigma=float(rng.uniform(0.4, 1.2)))

        # mask = pixels above a random fraction of the peak
        m = fmap / (fmap.max() + 1e-12)
        thr = float(rng.uniform(0.22, 0.38))
        fmask = (m >= thr)
        # occasional grow/shrink so mask outlines vary between foci
        if rng.random() < 0.35:
            fmask = binary_dilation(fmask, iterations=1)
        if rng.random() < 0.25:
            fmask = binary_erosion(fmask, iterations=1)
        fmask = seal_mask(fmask)
        return fmap, fmask

    # type 2: softly jagged blob
    sx = float(rng.uniform(s * 0.55, s * 1.20))
    sy = float(rng.uniform(s * 0.55, s * 1.25))
    base = elliptical_gaussian(H, W, cy, cx, sx, sy, theta, amp=base_amp)

    # smooth random field rescaled to [0, 1] -> multiplicative warp in [0.85, 1.30]
    noise = rng.normal(0.0, 1.0, size=(H, W)).astype(np.float32)
    noise = gaussian_filter(noise, sigma=float(rng.uniform(1.0, 2.3)))
    noise = (noise - noise.min()) / (noise.max() - noise.min() + 1e-12)
    warp = 0.85 + 0.45 * noise
    fmap = base * warp

    fmap = gaussian_filter(fmap, sigma=float(rng.uniform(0.5, 1.3)))

    m = fmap / (fmap.max() + 1e-12)
    thr = float(rng.uniform(0.24, 0.44))
    fmask = (m >= thr)
    fmask = seal_mask(fmask)
    return fmap, fmask

def generate_foci_map_and_mask(gray01, cell_mask01, rng, n_foci,
                               size_px_range=(1.8, 9.0),
                               amp_range=(0.7, 1.6),
                               min_dist_px=6,
                               keep_away_from_edge_px=3):
    """
    Place ``n_foci`` synthetic foci inside the cell.

    Args:
        gray01: 2-D image; only its shape is used.
        cell_mask01: Cell mask with the same shape; pixels > 0.5 are "cell".
        rng: ``numpy.random.Generator`` used for all random draws.
        n_foci: Number of foci to place.
        size_px_range, amp_range: Forwarded to ``make_one_focus_natural``.
        min_dist_px: Minimum spacing between focus centres.
        keep_away_from_edge_px: Erosion iterations applied to the cell mask
            before picking centres (ignored if fewer than 30 pixels remain).

    Returns:
        ``(fmap, fmask)``: float32 foci intensity map restricted to the cell
        and scaled so its maximum is 1 (left as zeros when no focus
        contributes), and a boolean foci mask restricted to the cell.
    """
    H, W = gray01.shape
    cell = (cell_mask01 > 0.5)

    # keep centers away from cell border
    m = cell
    if keep_away_from_edge_px > 0:
        m = binary_erosion(m, iterations=int(keep_away_from_edge_px))
        if m.sum() < 30:
            # erosion left almost nothing: fall back to the full cell
            m = cell

    centers = pick_centers_in_mask(rng, m.astype(np.float32), n_foci, min_dist_px=min_dist_px)

    fmap = np.zeros((H, W), np.float32)
    fmask = np.zeros((H, W), bool)

    for (cy, cx) in centers:
        fm, mk = make_one_focus_natural(
            H, W, cy, cx, rng,
            size_px_range=size_px_range,
            amp_range=amp_range
        )
        fmap += fm
        fmask |= mk

    # restrict to cell
    fmap *= cell.astype(np.float32)
    fmask &= cell

    # normalize foci map to [0,1]
    mx = float(fmap.max())
    if mx > 1e-8:
        fmap = fmap / mx

    # final seal; closing / hole filling can add pixels back outside the
    # cell, so the cell restriction is applied once more afterwards
    fmask = seal_mask(fmask) & cell
    return fmap, fmask

# -----------------------------
# Cell texture / noise (inside cell only)
# -----------------------------
def add_cell_body_noise_only(img01, cell_mask01, rng,
                             cell_noise_sigma_range=(0.008, 0.035),
                             lowfreq_amp_range=(0.00, 0.10),
                             speck_prob=0.0005,
                             blur_prob=0.20):
    """Add texture and noise to the pixels inside the cell mask only.

    Applied in order, each result clipped to [0, 1]:
      1. per-pixel Gaussian noise,
      2. a smooth low-frequency shading field in [-1, 1] times a random
         amplitude,
      3. sparse bright specks (each pixel selected with ``speck_prob``; all
         specks of one call share a single random boost),
      4. with probability ``blur_prob``, a partial blend with a blurred copy.

    Pixels where ``cell_mask01 <= 0.5`` are returned unchanged and the input
    array is not modified.

    Returns:
        A new array with the shape and dtype of ``img01``.
    """
    n_rows, n_cols = img01.shape
    inside = (cell_mask01 > 0.5)
    noisy = img01.copy()

    # 1) pixel noise
    pixel_sigma = float(rng.uniform(*cell_noise_sigma_range))
    pixel_noise = rng.normal(0.0, pixel_sigma, size=(n_rows, n_cols)).astype(np.float32)
    noisy[inside] = np.clip(noisy[inside] + pixel_noise[inside], 0.0, 1.0)

    # 2) low-frequency shading
    shading_amp = float(rng.uniform(*lowfreq_amp_range))
    if shading_amp > 0:
        field = rng.normal(0.0, 1.0, size=(n_rows, n_cols)).astype(np.float32)
        field_sigma = float(rng.uniform(10.0, 22.0))
        field = gaussian_filter(field, sigma=field_sigma)
        field = (field - field.min()) / (field.max() - field.min() + 1e-12)
        field = 2.0 * (field - 0.5)  # [0, 1] -> [-1, 1]
        noisy[inside] = np.clip(noisy[inside] + shading_amp * field[inside], 0.0, 1.0)

    # 3) bright specks
    if speck_prob and speck_prob > 0:
        speck_px = (rng.random((n_rows, n_cols)) < speck_prob) & inside
        if speck_px.any():
            boost = rng.uniform(0.15, 0.55)
            noisy[speck_px] = np.clip(noisy[speck_px] + boost, 0.0, 1.0)

    # 4) optional partial blur
    if rng.random() < blur_prob:
        blur_sigma = float(rng.uniform(0.3, 1.0))
        blurred = gaussian_filter(noisy, sigma=blur_sigma)
        alpha = float(rng.uniform(0.15, 0.55))
        blended = (1 - alpha) * noisy[inside] + alpha * blurred[inside]
        noisy[inside] = np.clip(blended, 0.0, 1.0)

    return noisy

# -----------------------------
# Darken cell body (cells dark, foci pop)
# -----------------------------
def darken_cell_only(gray01, cell_mask01, rng,
                     gamma_range=(1.6, 2.3),
                     scale_range=(0.55, 0.75),
                     blacklift_range=(0.01, 0.05)):
    """Darken the pixels inside the cell mask; leave the rest untouched.

    For pixels where ``cell_mask01 > 0.5`` the value is shifted down by a
    random ``blacklift``, multiplied by a random ``scale`` and raised to a
    random ``gamma``, clipping to [0, 1] after the first two steps.

    Returns:
        A new array; ``gray01`` itself is not modified.
    """
    inside = (cell_mask01 > 0.5)

    # draw order matters for reproducibility: gamma, scale, blacklift
    gamma = float(rng.uniform(*gamma_range))
    scale = float(rng.uniform(*scale_range))
    blacklift = float(rng.uniform(*blacklift_range))

    darkened = gray01.copy()
    lifted = np.clip(darkened[inside] - blacklift, 0.0, 1.0)
    scaled = np.clip(lifted * scale, 0.0, 1.0)
    darkened[inside] = scaled ** gamma
    return darkened

# -----------------------------
# Natural foci compositing (local contrast)
# -----------------------------
def add_foci_naturally(cell_bg01, foci_map01, foci_mask, rng,
                       add_gain_range=(0.20, 0.50),
                       mult_gain_range=(0.25, 0.95),
                       min_local_delta_range=(0.12, 0.26),
                       local_sigma_range=(2.0, 4.5),
                       foci_soften_sigma_range=(0.3, 1.2)):
    """
    Composite a foci map onto the cell background.

    Rather than a hard brightness boost:
      - the foci map is blurred slightly and clipped to [0, 1]
      - additive term:        + add_gain * foci_map
      - multiplicative term:  * (1 + mult_gain * foci_map)
      - inside ``foci_mask`` the result is raised to at least
        ``local + min_delta * (0.65 + 0.35 * foci_map)``, where ``local`` is
        a Gaussian-blurred copy of the background (a local mean)

    Returns:
        A new array clipped to [0, 1]; ``cell_bg01`` is not modified.
    """
    # random parameters, drawn in a fixed order
    soften_sigma = float(rng.uniform(*foci_soften_sigma_range))
    add_gain = float(rng.uniform(*add_gain_range))
    mult_gain = float(rng.uniform(*mult_gain_range))
    min_delta = float(rng.uniform(*min_local_delta_range))
    local_sigma = float(rng.uniform(*local_sigma_range))

    # soften foci to look PSF-like
    glow = gaussian_filter(foci_map01.astype(np.float32), sigma=soften_sigma)
    glow = np.clip(glow, 0.0, 1.0)

    # local baseline is taken from the background before any foci are added
    local_mean = gaussian_filter(cell_bg01, sigma=local_sigma)

    # additive + multiplicative
    composite = (cell_bg01 + add_gain * glow) * (1.0 + mult_gain * glow)

    # enforce local contrast only inside foci
    inside = foci_mask.astype(bool)
    if inside.any():
        floor = local_mean[inside] + min_delta * (0.65 + 0.35 * glow[inside])
        composite[inside] = np.maximum(composite[inside], floor)

    return np.clip(composite, 0.0, 1.0)

# -----------------------------
# Main pipeline: CSV in -> N simulated variations per source image (default 20)
# -----------------------------
def simulate_foci_from_csv(
    csv_path,
    output_dir,
    out_csv_path,
    image_size=224,
    variations_per_image=20,
    seed=42,

    # label-based counts
    neg_range=(0, 3),
    pos_range=(6, 25),
    count_mode="triangular",

    # foci shape
    size_px_range=(1.8, 9.5),
    amp_range=(0.7, 1.6),
    min_dist_px=6,

    # clean vs noisy cell
    p_clean_cell=0.35,
    clean_smooth_sigma_range=(0.8, 2.0),

    # noise (only if not clean)
    cell_noise_sigma_range=(0.008, 0.035),
    lowfreq_amp_range=(0.00, 0.10),
    speck_prob=0.0005,
    blur_prob=0.20,

    # cell darkening (make cells dark)
    cell_dark_gamma_range=(1.6, 2.3),
    cell_dark_scale_range=(0.55, 0.75),
    cell_blacklift_range=(0.01, 0.05),

    # foci pop
    add_gain_range=(0.20, 0.50),
    mult_gain_range=(0.25, 0.95),
    min_local_delta_range=(0.12, 0.26),
):
    """Simulate foci images for every row of an input CSV.

    For each source image the cell mask is extracted automatically and
    ``variations_per_image`` variants are written:

        <output_dir>/images/<stem>_vXXX.png
        <output_dir>/masks/<stem>_vXXX_focimask.png
        <output_dir>/masks/<stem>_vXXX_cellmask.png

    Rows whose image file is missing, or for which no cell could be
    segmented, are skipped with a ``[WARN]`` message. When several rows
    share the same file stem, later ones get a ``_dupNN`` suffix so that
    their outputs do not overwrite each other.

    Args:
        csv_path: Input CSV with columns ``image_path`` and ``label`` (0/1).
        output_dir: Root directory for generated images and masks.
        out_csv_path: Where the metadata CSV is written; its parent
            directory is created if needed.
        image_size: Edge length the source images are resized to.
        variations_per_image: Number of variants generated per source image.
        seed: Seed for the global RNGs and for the generator used here.
        neg_range, pos_range, count_mode: Foci-count sampling per label
            (see ``sample_n_foci``).
        size_px_range, amp_range, min_dist_px: Foci shape and spacing.
        p_clean_cell: Probability that a variant uses a smoothed ("clean")
            cell body instead of an added-noise one.
        clean_smooth_sigma_range: Blur sigma range for clean cells.
        cell_noise_sigma_range, lowfreq_amp_range, speck_prob, blur_prob:
            Noise settings for non-clean cells.
        cell_dark_gamma_range, cell_dark_scale_range, cell_blacklift_range:
            Cell darkening settings.
        add_gain_range, mult_gain_range, min_local_delta_range:
            Foci compositing settings.

    Returns:
        ``pandas.DataFrame`` with one row per generated variant (the same
        content that is written to ``out_csv_path``).

    Raises:
        ValueError: If the input CSV lacks ``image_path`` or ``label``.
    """
    ensure_dir(output_dir)
    img_out = os.path.join(output_dir, "images")
    msk_out = os.path.join(output_dir, "masks")
    ensure_dir(img_out)
    ensure_dir(msk_out)

    # The metadata CSV may live outside output_dir; create its folder now so
    # a long run cannot fail at the very last step.
    csv_parent = os.path.dirname(out_csv_path)
    if csv_parent:
        ensure_dir(csv_parent)

    set_seed(seed)
    rng = np.random.default_rng(seed)

    df = pd.read_csv(csv_path)
    # Explicit check instead of assert: asserts are stripped under `python -O`.
    missing_cols = [c for c in ("image_path", "label") if c not in df.columns]
    if missing_cols:
        raise ValueError("CSV must contain column: " + ", ".join(missing_cols))

    rows = []
    seen_stems = {}  # stem -> number of earlier sources that used it

    for _, row in df.iterrows():
        src_path = str(row["image_path"])
        if not os.path.exists(src_path):
            print(f"[WARN] missing: {src_path}")
            continue

        label = int(row["label"])

        # load grayscale
        base_gray01 = load_gray01_resized(src_path, size=image_size)

        # auto cell mask
        cell_mask01 = extract_cell_mask(base_gray01, rng=rng)
        inside_cell = (cell_mask01 > 0.5)
        if not inside_cell.any():
            # Without a cell every variant would be an all-black image whose
            # recorded n_foci does not match its (empty) foci mask.
            print(f"[WARN] no cell found, skipped: {src_path}")
            continue
        outside_cell = ~inside_cell

        # Disambiguate sources that share a basename so their outputs do
        # not silently overwrite each other.
        base_stem = stem_no_ext(src_path)
        n_seen = seen_stems.get(base_stem, 0)
        seen_stems[base_stem] = n_seen + 1
        if n_seen:
            base_stem = f"{base_stem}_dup{n_seen:02d}"

        # force outside-cell background black in base (helps consistency)
        base_gray01 = base_gray01.copy()
        base_gray01[outside_cell] = 0.0

        for v in range(int(variations_per_image)):
            # how many foci for this label
            n_foci = sample_n_foci(
                rng, label,
                neg_range=neg_range,
                pos_range=pos_range,
                mode=count_mode
            )

            # build cell background first
            is_clean = (rng.random() < float(p_clean_cell))
            cell_bg = base_gray01.copy()

            if is_clean:
                sig = float(rng.uniform(*clean_smooth_sigma_range))
                sm = gaussian_filter(cell_bg, sigma=sig)
                cell_bg[inside_cell] = sm[inside_cell]
            else:
                cell_bg = add_cell_body_noise_only(
                    cell_bg, cell_mask01, rng,
                    cell_noise_sigma_range=cell_noise_sigma_range,
                    lowfreq_amp_range=lowfreq_amp_range,
                    speck_prob=speck_prob,
                    blur_prob=blur_prob
                )

            # darken cell only (foci pop)
            cell_bg = darken_cell_only(
                cell_bg, cell_mask01, rng,
                gamma_range=cell_dark_gamma_range,
                scale_range=cell_dark_scale_range,
                blacklift_range=cell_blacklift_range
            )

            # keep outside cell black
            cell_bg[outside_cell] = 0.0

            # generate foci (map + mask)
            foci_map01, foci_mask = generate_foci_map_and_mask(
                cell_bg, cell_mask01, rng, n_foci,
                size_px_range=size_px_range,
                amp_range=amp_range,
                min_dist_px=min_dist_px
            )

            # composite foci naturally
            sim01 = add_foci_naturally(
                cell_bg, foci_map01, foci_mask, rng,
                add_gain_range=add_gain_range,
                mult_gain_range=mult_gain_range,
                min_local_delta_range=min_local_delta_range
            )

            # keep outside cell black (again)
            sim01[outside_cell] = 0.0

            out_stem = f"{base_stem}_v{v:03d}"
            out_img_path = os.path.join(img_out, f"{out_stem}.png")
            out_foci_mask_path = os.path.join(msk_out, f"{out_stem}_focimask.png")
            out_cell_mask_path = os.path.join(msk_out, f"{out_stem}_cellmask.png")

            save_u8_png(to_rgb_uint8(sim01), out_img_path)
            save_mask_png(foci_mask.astype(np.uint8), out_foci_mask_path)
            save_mask_png(inside_cell.astype(np.uint8), out_cell_mask_path)

            rows.append({
                "image_path": out_img_path,
                "mask_path": out_foci_mask_path,
                "cell_mask_path": out_cell_mask_path,
                "label": int(label),
                "n_foci": int(n_foci),
                "is_clean_cell": int(is_clean),
                "src_image_path": src_path,
                "variant_idx": int(v),
            })

    out_df = pd.DataFrame(rows)
    out_df.to_csv(out_csv_path, index=False)
    print("Saved:", out_csv_path)
    print("Num samples:", len(out_df))
    return out_df


In [None]:
if __name__ == "__main__":
    # Input CSV, output folder, and the metadata CSV written inside that folder.
    CSV_IN   = "/content/drive/MyDrive/FYP/foci_variations_simulated_v03.csv"
    OUT_DIR  = "/content/drive/MyDrive/FYP/foci_sim_out_darkcell_pop_foci_variations_simulated_v03"
    CSV_OUT  = "/content/drive/MyDrive/FYP/foci_sim_out_darkcell_pop_foci_variations_simulated_v03/out_foci_variations_simulated_v03.csv"

    simulate_foci_from_csv(
        # I/O
        csv_path=CSV_IN,
        output_dir=OUT_DIR,
        out_csv_path=CSV_OUT,

        # dataset shape / reproducibility
        image_size=224,
        variations_per_image=20,
        seed=42,

        # share of variants with a smoothed ("clean") cell body
        p_clean_cell=0.35,

        # darker cells than the function defaults
        cell_dark_gamma_range=(1.7, 2.5),
        cell_dark_scale_range=(0.50, 0.72),
        cell_blacklift_range=(0.01, 0.06),

        # stronger foci contrast than the function defaults
        add_gain_range=(0.22, 0.55),
        mult_gain_range=(0.30, 1.05),
        min_local_delta_range=(0.14, 0.30),
    )
