In [None]:
# ============================================================
# Synthetic Foci Generator v8 (224x224) — matches sample cell surfaces
# - Foci channel only (no nucleus)
# - One whole cell visible per image
# - Background outside the cell is black
# - Cell interior surface: cloudy + mottled (multi-scale correlated texture)
# - Noise: varied per image, microscopy-like (shot + read + correlated patterns)
# - Foci: bright vs background, mostly circular / elliptical, some round-ish irregular
# - Saves: image_path, mask_path (foci), cell_mask_path (cell)
# - Outputs train/val/test split CSVs compatible with training pipeline
# ============================================================

import os, random, math
import numpy as np
import pandas as pd
from PIL import Image
from scipy.ndimage import (
    gaussian_filter,
    binary_dilation,
    binary_erosion
)


# Side length (pixels) of every generated square image.
IMAGE_SIZE = 224  # match training pipeline

# Output locations. NOTE(review): this looks like a Colab Google Drive mount
# path; adjust OUT_ROOT when running elsewhere.
OUT_ROOT = "/content/drive/MyDrive/FYP/foci_simulation_from_scratch/sim_v8_fociChannel_surfaceMatched_224"
IMG_DIR  = os.path.join(OUT_ROOT, "images")
FOCI_DIR = os.path.join(OUT_ROOT, "foci_masks")
CELL_DIR = os.path.join(OUT_ROOT, "cell_masks")

# CSV manifests: one listing every generated sample, plus one per split.
CSV_ALL   = os.path.join(OUT_ROOT, "synthetic_all.csv")
CSV_TRAIN = os.path.join(OUT_ROOT, "train_split.csv")
CSV_VAL   = os.path.join(OUT_ROOT, "val_split.csv")
CSV_TEST  = os.path.join(OUT_ROOT, "test_split.csv")

# Output directories are created as soon as this cell runs; exist_ok=True
# makes re-running the cell safe.
os.makedirs(IMG_DIR, exist_ok=True)
os.makedirs(FOCI_DIR, exist_ok=True)
os.makedirs(CELL_DIR, exist_ok=True)

# Dataset size and split fractions. The test split receives whatever rows
# remain after train and val (nominally 1 - TRAIN_FRAC - VAL_FRAC).
N_IMAGES   = 5000
TRAIN_FRAC = 0.80
VAL_FRAC   = 0.10
# Base seed: used for the split shuffle and to derive per-image seeds.
GLOBAL_SEED = 42

# -----------------------
# Utilities
# -----------------------
def clamp01(x):
    """Clip every value of ``x`` into the closed interval [0, 1]."""
    lower, upper = 0.0, 1.0
    return np.clip(x, lower, upper)

def norm01(x, eps=1e-6):
    """Min-max rescale ``x`` to approximately [0, 1], returned as float32.

    ``eps`` is added to the value range so a constant input does not divide
    by zero (it maps to all zeros). Because of ``eps`` the maximum lands
    fractionally below 1.0.
    """
    arr = x.astype(np.float32)
    lowest = arr.min()
    value_range = arr.max() - lowest
    return (arr - lowest) / (value_range + eps)

def to_u8(x01):
    """Convert a [0, 1] float image to uint8: clip, scale by 255, truncate."""
    scaled = clamp01(x01) * 255.0
    return scaled.astype(np.uint8)

def save_rgb_png(path, x01):
    """Write a [0, 1] float image to ``path`` as an 8-bit PNG.

    A 2-D (grayscale) array is replicated into three identical channels.
    Any other array is converted and written as given.
    """
    rgb = np.stack([x01, x01, x01], axis=-1) if x01.ndim == 2 else x01
    Image.fromarray(to_u8(rgb)).save(path)

def save_mask_png(path, m01):
    """Write a binary mask to ``path`` as an 8-bit PNG with values {0, 255}.

    Values are cast to uint8 before scaling, so fractional values truncate
    toward 0. The input is expected to contain 0/1 values.
    """
    as_uint8 = m01.astype(np.uint8)
    Image.fromarray(as_uint8 * 255).save(path)

def split_df(df, train_frac=0.8, val_frac=0.1, seed=42):
    """Shuffle ``df`` reproducibly and cut it into (train, val, test) frames.

    The train and val sizes are ``round(frac * n)``. The test frame takes
    every remaining row. Each returned frame has a fresh 0..k-1 index. The
    split is a plain shuffle, not stratified by any column.
    """
    shuffled = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    total = len(shuffled)
    train_end = int(round(train_frac * total))
    val_end = train_end + int(round(val_frac * total))
    bounds = [(0, train_end), (train_end, val_end), (val_end, None)]
    return tuple(
        shuffled.iloc[start:stop].reset_index(drop=True) for start, stop in bounds
    )

def ellipse_mask(H, W, cx, cy, rx, ry, angle_rad):
    """Return a boolean (H, W) mask of a filled, rotated ellipse.

    The centre (cx, cy) is in pixel coordinates (x = column, y = row).
    ``rx``/``ry`` are the semi-axes, rotated by ``angle_rad`` radians. A 1e-6
    guard keeps zero radii from dividing by zero.
    """
    rows, cols = np.mgrid[0:H, 0:W].astype(np.float32)
    dx = cols - cx
    dy = rows - cy
    cos_t = math.cos(angle_rad)
    sin_t = math.sin(angle_rad)
    u = (cos_t * dx + sin_t * dy) / (rx + 1e-6)
    v = (cos_t * dy - sin_t * dx) / (ry + 1e-6)
    return u**2 + v**2 <= 1.0

def sample_points_in_mask(mask01, n, rng):
    """Draw ``n`` pixel positions uniformly, with replacement, from a mask.

    Pixels with value > 0.5 are eligible. Returns a list of (x, y) tuples of
    np.float32 (x = column, y = row), or [] when no pixel is eligible.
    Duplicate positions are possible because sampling is with replacement.
    """
    rows, cols = np.where(mask01 > 0.5)
    if cols.size == 0:
        return []
    picks = rng.integers(0, cols.size, size=(n,))
    xs = cols[picks].astype(np.float32)
    ys = rows[picks].astype(np.float32)
    return [(x, y) for x, y in zip(xs, ys)]

# -----------------------
# Texture fields to match natural examples
# -----------------------
def fractal_field(rng, H, W, sigmas=(3, 7, 15, 32), weights=(0.22, 0.30, 0.28, 0.20)):
    """Return a multi-scale "cloudy" texture, min-max normalised to [0, 1].

    Each (sigma, weight) pair adds one octave: fresh white Gaussian noise
    blurred with that sigma and scaled by that weight. Octaves are drawn
    from ``rng`` in the order given.
    """
    accumulated = np.zeros((H, W), dtype=np.float32)
    for octave_sigma, octave_weight in zip(sigmas, weights):
        white = rng.normal(0, 1, (H, W)).astype(np.float32)
        accumulated += float(octave_weight) * gaussian_filter(white, sigma=float(octave_sigma))
    return norm01(accumulated)

def illumination_blob(rng, H, W):
    """Return a smooth, normalised illumination field for an (H, W) image.

    The field is an off-centre Gaussian hot-spot plus a weak linear tilt,
    heavily blurred and min-max normalised to [0, 1].
    """
    rows, cols = np.mgrid[0:H, 0:W].astype(np.float32)

    # Hot-spot centred in the middle 30% of the frame with a broad spread.
    centre_x = float(rng.uniform(0.35*W, 0.65*W))
    centre_y = float(rng.uniform(0.35*H, 0.65*H))
    spread_x = float(rng.uniform(0.25*W, 0.55*W))
    spread_y = float(rng.uniform(0.25*H, 0.55*H))
    zx = (cols - centre_x) / spread_x
    zy = (rows - centre_y) / spread_y
    hotspot = np.exp(-0.5*(zx**2 + zy**2))

    # Linear tilt across the frame; max(1, ...) guards 1-pixel dimensions.
    tilt_x = float(rng.uniform(-0.25, 0.25))
    tilt_y = float(rng.uniform(-0.25, 0.25))
    tilt = (tilt_x * (cols / max(1, W-1)) + tilt_y * (rows / max(1, H-1)))

    smoothing = float(rng.uniform(10, 24))
    return norm01(gaussian_filter(hotspot + tilt, sigma=smoothing))

def smooth_mottles(rng, H, W, n_spots_range=(10, 26)):
    """Return soft, overlapping blotches ("mottling") normalised to [0, 1].

    The spot count is drawn from ``n_spots_range`` (inclusive both ends).
    Each spot is an isotropic Gaussian with random centre, width 10-28 px and
    weight 0.4-1.0. The sum is blurred and normalised.
    """
    rows, cols = np.mgrid[0:H, 0:W].astype(np.float32)
    fewest, most = n_spots_range[0], n_spots_range[1]
    spot_count = int(rng.integers(fewest, most + 1))

    field = np.zeros((H, W), dtype=np.float32)
    for _ in range(spot_count):
        px = float(rng.uniform(0, W))
        py = float(rng.uniform(0, H))
        width = float(rng.uniform(10, 28))
        weight = float(rng.uniform(0.4, 1.0))
        radial = ((cols - px) / width) ** 2 + ((rows - py) / width) ** 2
        field += weight * np.exp(-0.5 * radial)

    field = gaussian_filter(field, sigma=float(rng.uniform(6, 14)))
    return norm01(field)

# -----------------------
# Cell masks (ellipse OR deformed blob)
# -----------------------
def deformed_cell_mask(H, W, rng, margin=18):
    """Sample an irregular single-cell mask built from overlapping Gaussian lobes.

    How the shape is built:
    - 2-5 randomly rotated, anisotropic Gaussian lobes are scattered around
      a centre that is kept at least ``margin + 70`` px from every image edge.
    - The lobes are summed, lightly smoothed and min-max normalised.
    - The result is thresholded, and the binary outline is smoothed once more.

    Only the largest connected component is kept. Distant lobes can
    threshold into separate islands, and those must not become extra "cells".

    Returns:
        An (H, W) float32 mask with values {0, 1}. Returns ``None`` when the
        shape is unusable, so the caller can fall back to an ellipse:
        - the mask has fewer than 9000 pixels, or
        - the mask touches the image border. That cell would be clipped,
          breaking the "one whole cell visible per image" guarantee.
    """
    from scipy.ndimage import label as _cc_label

    cx = float(rng.uniform(margin + 70, W - margin - 70))
    cy = float(rng.uniform(margin + 70, H - margin - 70))

    yy, xx = np.mgrid[0:H, 0:W].astype(np.float32)
    n_lobes = int(rng.integers(2, 6))
    field = np.zeros((H, W), dtype=np.float32)

    for _ in range(n_lobes):
        dx = float(rng.normal(0, rng.uniform(8, 30)))
        dy = float(rng.normal(0, rng.uniform(8, 30)))
        x0 = cx + dx
        y0 = cy + dy
        sx = float(rng.uniform(24, 52))
        sy = float(rng.uniform(22, 54))
        ang = float(rng.uniform(0, 2*np.pi))
        ca, sa = math.cos(ang), math.sin(ang)
        x = xx - x0
        y = yy - y0
        xr =  ca*x + sa*y
        yr = -sa*x + ca*y
        g = np.exp(-0.5*((xr/(sx+1e-6))**2 + (yr/(sy+1e-6))**2))
        field += float(rng.uniform(0.7, 1.3)) * g

    field = gaussian_filter(field, sigma=float(rng.uniform(1.8, 3.6)))
    field = norm01(field)

    thr = float(rng.uniform(0.34, 0.58))
    cell = (field > thr).astype(np.float32)

    # Blur-and-rethreshold smooths jagged outlines into a membrane-like edge.
    cell = gaussian_filter(cell, sigma=float(rng.uniform(0.9, 1.8)))
    cell = cell > 0.5

    # Keep a single cell: drop every connected component but the largest.
    components, n_components = _cc_label(cell)
    if n_components > 1:
        sizes = np.bincount(components.ravel())
        sizes[0] = 0  # label 0 is the background
        cell = components == int(np.argmax(sizes))

    if cell.sum() < 9000:
        return None
    # A mask reaching the outermost row/column is cut off by the frame.
    if cell[0, :].any() or cell[-1, :].any() or cell[:, 0].any() or cell[:, -1].any():
        return None
    return cell.astype(np.float32)

def _random_ellipse_cell(H, W, rng, margin, rx_min, ry_min):
    """Sample a rotated elliptical cell mask lying fully inside ``margin``.

    The semi-axes are drawn from [rx_min, max_r] and [ry_min, max_r], where
    ``max_r = min(H, W) / 2 - margin``. The centre is drawn so that a circle
    of radius max(rx, ry), which contains the rotated ellipse, stays at least
    ``margin`` px from every edge. Returns an (H, W) float32 {0, 1} mask.
    """
    max_r = min(H, W)/2 - margin
    rx = float(rng.uniform(rx_min, max_r))
    ry = float(rng.uniform(ry_min, max_r))
    ang = float(rng.uniform(0, 2*np.pi))
    R = max(rx, ry)
    cx = float(rng.uniform(margin + R, W - margin - R))
    cy = float(rng.uniform(margin + R, H - margin - R))
    return ellipse_mask(H, W, cx, cy, rx, ry, ang).astype(np.float32)


def generate_single_cell_base(H=224, W=224, rng=None):
    """Render one cell's background surface in the foci channel (no foci yet).

    FOCI CHANNEL ONLY:
    - no nucleus
    - cloudy + mottled interior surface intended to match the sample images
    - background is zero outside the cell

    Args:
        H, W: image height and width in pixels.
        rng: numpy Generator. A fresh unseeded one is created when None.

    Returns:
        (img, cell): ``img`` is an (H, W) float image clipped to [0, 1] and
        zero outside the cell. ``cell`` is the (H, W) float32 {0, 1} mask.
    """
    rng = np.random.default_rng() if rng is None else rng
    margin = int(rng.integers(14, 26))

    # ----- cell shape: 50% ellipse, 50% deformed blob (ellipse fallback) -----
    if rng.random() < 0.50:
        cell = _random_ellipse_cell(H, W, rng, margin, rx_min=48, ry_min=34)
    else:
        cell = deformed_cell_mask(H, W, rng, margin=margin)
        if cell is None:
            # Deformed shape rejected: fall back to a slightly larger ellipse.
            cell = _random_ellipse_cell(H, W, rng, margin, rx_min=50, ry_min=40)

    # ----- surface mode: bright (65%) or dark (35%), as in the samples -----
    mode = "bright" if rng.random() < 0.65 else "dark"

    cloud = fractal_field(rng, H, W, sigmas=(3, 7, 15, 32), weights=(0.22, 0.30, 0.28, 0.20))
    mott  = smooth_mottles(rng, H, W, n_spots_range=(10, 26))
    illum = illumination_blob(rng, H, W)

    surface = 0.55*cloud + 0.25*mott + 0.20*illum
    surface = norm01(gaussian_filter(surface, sigma=float(rng.uniform(1.5, 3.0))))

    if mode == "bright":
        base_level = float(rng.uniform(0.18, 0.32))
        base_var   = float(rng.uniform(0.10, 0.22))
    else:
        base_level = float(rng.uniform(0.04, 0.10))
        base_var   = float(rng.uniform(0.05, 0.14))

    # surface in [0, 1] maps to base_level +/- base_var. In dark mode this
    # can dip below 0; those pixels are clipped to 0 below.
    img = base_level + base_var * (surface - 0.5) * 2.0

    # Faint rim: positive just inside the boundary, where blurring lowers the mask.
    edge = cell - gaussian_filter(cell, sigma=2.0)
    edge = np.clip(edge, 0.0, 1.0)
    img += float(rng.uniform(0.00, 0.04)) * edge

    img *= cell
    img = clamp01(img)
    return img, cell

# -----------------------
# Foci rendering
# -----------------------
def render_gaussian_elliptic(H, W, x0, y0, sx, sy, ang, amp):
    """Render ``amp`` times an unnormalised, rotated 2-D Gaussian on an (H, W) grid.

    The centre (x0, y0) is in pixel coordinates (x = column, y = row).
    ``sx``/``sy`` are the standard deviations along the axes rotated by
    ``ang`` radians. The peak value at the centre is ``amp``; the Gaussian
    is not scaled to unit area.
    """
    rows, cols = np.mgrid[0:H, 0:W].astype(np.float32)
    dx, dy = cols - x0, rows - y0
    c, s = math.cos(ang), math.sin(ang)
    u = (c * dx + s * dy) / (sx + 1e-6)
    v = (c * dy - s * dx) / (sy + 1e-6)
    return amp * np.exp(-0.5 * (u**2 + v**2))

def render_roundish_irregular(H, W, x0, y0, rng, base_sigma=2.2, amp=1.0):
    """Render a roughly round focus whose intensity is unevenly textured.

    1. Build a unit-peak elliptical Gaussian core. Both sigmas are drawn from
       [0.8, 1.3] x ``base_sigma``.
    2. Multiply it by a smooth random factor of 1 +/- up to 0.35.
    3. Scale by ``amp`` and apply a final light blur.
    """
    sigma_lo, sigma_hi = base_sigma*0.8, base_sigma*1.3
    sig_x = float(rng.uniform(sigma_lo, sigma_hi))
    sig_y = float(rng.uniform(sigma_lo, sigma_hi))
    theta = float(rng.uniform(0, 2*np.pi))
    core = render_gaussian_elliptic(H, W, x0, y0, sx=sig_x, sy=sig_y, ang=theta, amp=1.0)

    texture = rng.normal(0, 1, (H, W)).astype(np.float32)
    texture = norm01(gaussian_filter(texture, sigma=float(rng.uniform(2.0, 5.0))))
    strength = float(rng.uniform(-0.35, 0.35))
    perturb = 1.0 + strength * (texture - 0.5) * 2.0

    shaped = amp * core * perturb
    return gaussian_filter(shaped, sigma=float(rng.uniform(0.5, 1.1)))

def render_one_focus(H, W, x0, y0, rng):
    """Render one focus of a randomly chosen morphology centred at (x0, y0).

    Morphology probabilities: point 22%, round 38%, ellipse 18%, cluster 14%,
    irregular 8%. Returns an (H, W) intensity field.
    """
    kinds = ["point", "round", "ellipse", "cluster", "irregular"]
    probs = [0.22, 0.38, 0.18, 0.14, 0.08]
    kind = rng.choice(kinds, p=probs)

    if kind == "point":
        # Tiny, near diffraction-limited spot.
        sigma = float(rng.uniform(0.6, 1.0))
        peak = float(rng.uniform(0.90, 1.60))
        return render_gaussian_elliptic(H, W, x0, y0, sigma, sigma, 0.0, peak)

    if kind == "round":
        sigma = float(rng.uniform(1.2, 3.8))
        peak = float(rng.uniform(0.85, 1.60))
        return render_gaussian_elliptic(H, W, x0, y0, sigma, sigma, 0.0, peak)

    if kind == "ellipse":
        sig_x = float(rng.uniform(1.0, 4.2))
        sig_y = float(rng.uniform(0.8, 3.4))
        theta = float(rng.uniform(0, 2*np.pi))
        peak = float(rng.uniform(0.80, 1.55))
        return render_gaussian_elliptic(H, W, x0, y0, sig_x, sig_y, theta, peak)

    if kind == "cluster":
        # 2-6 small sub-spots jittered around the centre, then softened.
        n_sub = int(rng.integers(2, 7))
        blob = np.zeros((H, W), dtype=np.float32)
        for _ in range(n_sub):
            off_x = float(rng.normal(0, rng.uniform(0.8, 3.6)))
            off_y = float(rng.normal(0, rng.uniform(0.8, 3.6)))
            sigma = float(rng.uniform(0.7, 2.8))
            peak = float(rng.uniform(0.45, 1.10))
            blob += render_gaussian_elliptic(H, W, x0+off_x, y0+off_y, sigma, sigma, 0.0, peak)
        return gaussian_filter(blob, sigma=float(rng.uniform(0.35, 0.90)))

    # Remaining kind: "irregular".
    peak = float(rng.uniform(0.90, 1.65))
    base_sigma = float(rng.uniform(1.8, 4.2))
    return render_roundish_irregular(H, W, x0, y0, rng, base_sigma=base_sigma, amp=peak)

# -----------------------
# Noise model (varied, microscopy-like, not dot soup)
# -----------------------
def apply_varied_noise(img01, cell01, rng):
    """Degrade a clean [0, 1] image with randomised microscopy-style noise.

    Stages, in order:
    - optical blur
    - Poisson shot noise (55-220 photons at full scale)
    - Gaussian read noise
    - optional low-frequency background wobble (75% of images)
    - optional row banding (15%)
    - optional light softening (55%)
    The result is clipped to [0, 1] and zeroed outside ``cell01``.
    """
    blurred = gaussian_filter(img01, sigma=float(rng.uniform(0.45, 1.35)))

    photons = float(rng.uniform(55, 220))
    out = rng.poisson(clamp01(blurred) * photons) / photons

    read_sigma = float(rng.uniform(0.003, 0.020))
    out = out + rng.normal(0, read_sigma, size=out.shape).astype(np.float32)

    shape = blurred.shape

    if rng.random() < 0.75:
        wobble = rng.normal(0, 1, shape).astype(np.float32)
        wobble = norm01(gaussian_filter(wobble, sigma=float(rng.uniform(6, 18)))) - 0.5
        out = out + float(rng.uniform(-0.05, 0.05)) * wobble

    if rng.random() < 0.15:
        # One value per row, broadcast across columns, so each row gets a
        # constant offset.
        band = rng.normal(0, 1, (shape[0], 1)).astype(np.float32)
        band = norm01(gaussian_filter(band, sigma=float(rng.uniform(2.0, 6.0)))) - 0.5
        out = out + float(rng.uniform(-0.02, 0.02)) * band

    out = clamp01(out)

    if rng.random() < 0.55:
        out = gaussian_filter(out, sigma=float(rng.uniform(0.20, 0.65)))

    return clamp01(out) * cell01

# -----------------------
# Generate one image
# -----------------------
def generate_one_image(seed, H=224, W=224):
    """Generate one synthetic foci-channel image and its annotations.

    All randomness comes from ``np.random.default_rng(seed)``, so a seed
    reproduces the same sample.

    Returns:
        (img_rgb, foci_mask, cell_mask, meta)
        - img_rgb: (H, W, 3) noisy image; the three channels are identical.
        - foci_mask: (H, W) float32 {0, 1} mask, built from the noise-free
          foci field and restricted to the cell.
        - cell_mask: (H, W) float32 {0, 1} cell mask.
        - meta: dict with "seed", "foci_count" and "label" (1 if
          foci_count > 0, else 0).
    """
    rng = np.random.default_rng(int(seed))
    img, cell = generate_single_cell_base(H, W, rng=rng)

    # Foci centres come from an eroded cell so they sit away from the edge.
    # Fall back to the full cell when erosion leaves fewer than 200 px.
    inside = cell.astype(bool)
    core = binary_erosion(inside, iterations=int(rng.integers(6, 12)))
    if core.sum() < 200:
        core = inside

    # 60% of images are positive, with 1-5 base foci plus a skewed 0-5 extra.
    if rng.random() < 0.6:
        n_foci = int(rng.integers(1, 6)) + int(
            rng.choice([0,0,1,2,3,4,5], p=[0.22,0.18,0.18,0.15,0.12,0.10,0.05])
        )
    else:
        n_foci = 0

    centres = sample_points_in_mask(core.astype(np.float32), n_foci, rng) if n_foci > 0 else []

    foci_field = np.zeros((H, W), dtype=np.float32)
    for fx, fy in centres:
        foci_field += render_one_focus(H, W, fx, fy, rng)
    foci_field *= cell

    gain = float(rng.uniform(2.6, 4.6))
    noisy = apply_varied_noise(clamp01(img + gain * foci_field), cell, rng)

    # Ground-truth mask: threshold the clean field, grow slightly, clip to cell.
    level = float(rng.uniform(0.10, 0.18))
    foci_mask = (foci_field > level).astype(np.float32)
    if foci_mask.sum() > 0:
        grow = int(rng.integers(1, 3))
        foci_mask = binary_dilation(foci_mask.astype(bool), iterations=grow).astype(np.float32)
    foci_mask = (foci_mask * (cell > 0.5)).astype(np.float32)

    meta = {"seed": int(seed), "foci_count": int(n_foci), "label": int(n_foci > 0)}
    img_rgb = np.stack([noisy, noisy, noisy], axis=-1)
    return img_rgb, foci_mask, cell.astype(np.float32), meta

# -----------------------
# Generate dataset + splits
# -----------------------
random.seed(GLOBAL_SEED)
np.random.seed(GLOBAL_SEED)

rows = []
for i in range(N_IMAGES):
    seed = GLOBAL_SEED * 100000 + i
    img_rgb, foci_m, cell_m, meta = generate_one_image(seed, H=IMAGE_SIZE, W=IMAGE_SIZE)

    img_name  = f"sim_{i:06d}.png"
    foci_name = f"sim_{i:06d}_foci.png"
    cell_name = f"sim_{i:06d}_cell.png"

    img_path  = os.path.join(IMG_DIR,  img_name)
    foci_path = os.path.join(FOCI_DIR, foci_name)
    cell_path = os.path.join(CELL_DIR, cell_name)

    save_rgb_png(img_path, img_rgb)
    save_mask_png(foci_path, foci_m)
    save_mask_png(cell_path, cell_m)

    rows.append({
        "image_path": img_path,
        "mask_path":  foci_path,
        "cell_mask_path": cell_path,
        "label": meta["label"],
        "foci_count": meta["foci_count"],
        "seed": meta["seed"],
    })

df_all = pd.DataFrame(rows)
df_all.to_csv(CSV_ALL, index=False)
print("Saved:", CSV_ALL, "rows:", len(df_all))

df_train, df_val, df_test = split_df(df_all, train_frac=TRAIN_FRAC, val_frac=VAL_FRAC, seed=GLOBAL_SEED)
df_train.to_csv(CSV_TRAIN, index=False)
df_val.to_csv(CSV_VAL, index=False)
df_test.to_csv(CSV_TEST, index=False)

print("Saved splits:")
print("  train:", CSV_TRAIN, len(df_train))
print("  val:  ", CSV_VAL,   len(df_val))
print("  test: ", CSV_TEST,  len(df_test))

print("\nLabel balance (mean label):")
print("  all :", df_all["label"].mean())
print("  train:", df_train["label"].mean())
print("  val :", df_val["label"].mean())
print("  test:", df_test["label"].mean())


Saved: /content/drive/MyDrive/FYP/foci_simulation_from_scratch/sim_v8_fociChannel_surfaceMatched_224/synthetic_all.csv rows: 5000
Saved splits:
  train: /content/drive/MyDrive/FYP/foci_simulation_from_scratch/sim_v8_fociChannel_surfaceMatched_224/train_split.csv 4000
  val:   /content/drive/MyDrive/FYP/foci_simulation_from_scratch/sim_v8_fociChannel_surfaceMatched_224/val_split.csv 500
  test:  /content/drive/MyDrive/FYP/foci_simulation_from_scratch/sim_v8_fociChannel_surfaceMatched_224/test_split.csv 500

Label balance (mean label):
  all : 0.6072
  train: 0.6095
  val : 0.586
  test: 0.61
