CE Synthetic Electropherogram Generator
======================================

This notebook generates synthetic capillary electrophoresis (CE) electropherograms.

Outputs (in output folder)
- signals/: M1_plXXXX.csv and M2_plXXXX.csv
- labels_masks/: labels_mask_M{1|2}_plXXXX.csv
- labels_centers/: labels_center_M{1|2}_plXXXX.csv
- labels_map/: labels_map.csv + peak_positions_detailed.csv
- plots/ (optional): per-channel and stacked plots

Model rules
- Template statistics are computed using molw >= 50 bp (templates can be noisy below 50 bp).
- Channels 1-4: molw < 50 bp => baseline + normal noise only (no peaks).
- Channel 5 (ladder): GeneScan 500 LIZ peaks at
  35, 50, 75, 100, 139, 150, 160, 200, 250, 300, 340, 350, 400, 450, 490, 500 bp.
- Baseline centering (optional): subtract a constant estimated only from molw < 50 bp,
  which shifts the signal to baseline ~0 without changing noise texture.


In [2]:
from __future__ import annotations

import math
import os
import random
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Silence every RuntimeWarning for the whole process (NumPy reports overflow /
# invalid-value conditions through this category). NOTE: this is global state,
# so it also hides such warnings raised by unrelated code after this point.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# =============================================================================
# Config
# =============================================================================

@dataclass
class GeneratorConfig:
    """Tunable parameters for the synthetic electropherogram generator.

    Fields are grouped by the generation stage that reads them. Range-style
    fields are ``(low, high)`` tuples that the helpers unpack into
    ``random.uniform`` / ``random.randint`` / ``np.random.uniform`` calls.
    A default instance is created at module level as ``CFG``.
    """

    # Core axis rules
    template_stats_min_bp: float = 50.0   # ignore <50 bp when estimating template stats
    prepeak_cutoff_bp: float = 50.0      # channels 1-4: <50 bp => baseline+noise only

    # Ladder (channel 5): GeneScan 500 LIZ
    liz500_peaks: Tuple[int, ...] = (35, 50, 75, 100, 139, 150, 160, 200, 250, 300, 340, 350, 400, 450, 490, 500)
    ladder_noise_mult: float = 0.35   # scales the template ladder noise for the ladder's white noise
    ladder_drift_mult: float = 0.015  # scales the template ladder noise for the ladder's linear drift

    # Baseline centering (DC offset removal only)
    center_baseline: bool = True
    center_q: int = 50  # percentile on region <50 bp (median recommended)

    # Peak counts / placement
    max_true_peaks: int = 12
    max_total_events_per_channel: int = 28
    # choose_peak_count() uses p_sparse and p_sparse + p_medium as cumulative
    # thresholds; everything above falls into the dense regime (p_dense itself
    # is not read by that function).
    p_sparse: float = 0.45
    p_medium: float = 0.40
    p_dense: float = 0.15
    min_sep_bp: float = 1.1   # minimum centre spacing enforced by sample_peak_centers()
    p_overlap: float = 0.15   # per-draw chance that sample_peak_centers() skips the spacing check

    # Peak shape/amplitude
    amp_floor: float = 80.0
    sigma_min: float = 0.30
    sigma_max: float = 2.20
    tau_min: float = 0.25
    tau_max: float = 2.60
    p_left_tail: float = 0.08
    amp_logn_sigma: float = 0.55

    # Stutter
    p_stutter: float = 0.28
    p_forward_stutter: float = 0.06
    p_double_step: float = 0.06
    repeat_units: Tuple[int, ...] = (2, 3, 4, 5)

    # Spurious peaks
    p_spurious: float = 0.40
    spurious_count_range: Tuple[int, int] = (1, 4)
    spurious_amp_range: Tuple[float, float] = (30.0, 520.0)

    # Noise
    hf_noise_range: Tuple[float, float] = (0.45, 1.35)
    corr_k_range: Tuple[int, int] = (6, 18)            # moving-average window drawn in add_noise()
    hetero_range: Tuple[float, float] = (0.0010, 0.0035)
    hetero_sqrt_cap: float = 300.0                     # upper clip on sqrt(|signal|) in add_noise()
    het_winsor: Tuple[float, float] = (0.5, 99.5)      # percentiles passed to winsor() in add_noise()
    abs_noise_floor_range: Tuple[float, float] = (0.8, 2.2)

    micro_white_range: Tuple[float, float] = (0.12, 0.33)
    micro_ar_range: Tuple[float, float] = (0.08, 0.27)
    micro_ar_phi_range: Tuple[float, float] = (0.35, 0.75)

    # Baseline LF components
    rw_sigma_factor: float = 0.014
    drift_slope_factor: float = 0.03
    lf_smooth_range: Tuple[int, int] = (31, 121)
    lf_sine_count_range: Tuple[int, int] = (1, 3)
    lf_period_range: Tuple[float, float] = (350.0, 1500.0)
    lf_sine_amp_range: Tuple[float, float] = (0.2, 1.2)
    lf_amp_factor: float = 0.50

    # Used by limit_slow_baseline(): smoothing window, cap as a fraction of
    # q_low, and the exponent applied to the compression ratio.
    baseline_lf_smooth: int = 121
    baseline_lf_cap_frac_qlow: float = 0.70
    baseline_lf_soft: float = 0.85

    # Crosstalk
    # 4x4 coefficient matrix; apply_crosstalk() reads entry [source, target].
    # default_factory gives every config instance its own array.
    crosstalk_base: np.ndarray = field(default_factory=lambda: np.array([
        [0.0, 0.028, 0.010, 0.006],
        [0.018, 0.0, 0.028, 0.010],
        [0.010, 0.018, 0.0, 0.028],
        [0.006, 0.010, 0.018, 0.0],
    ], dtype=float))
    crosstalk_scale_range: Tuple[float, float] = (0.45, 0.80)
    crosstalk_jitter: float = 0.25
    p_pulldown: float = 0.12                              # chance that some coefficients are made negative
    pulldown_range: Tuple[float, float] = (0.001, 0.006)  # magnitude of those negative coefficients

    # Saturation
    sat_factor: float = 1.05
    overload_softness: float = 1.0

    # Simple sample shift per multiplex
    mux_shift_max: int = 2

    # molw randomization
    molw_scale_range: Tuple[float, float] = (0.994, 1.006)
    molw_shift_range: Tuple[float, float] = (-4.0, 4.0)
    p_warp: float = 0.25
    warp_strength_range: Tuple[float, float] = (-0.006, 0.006)
    warp_kind_probs: Tuple[float, float] = (0.55, 0.45)  # quad, cubicS
    p_molw_jitter: float = 0.65
    molw_jitter_amp_range: Tuple[float, float] = (0.00, 0.25)
    molw_jitter_smooth_range: Tuple[int, int] = (41, 121)

    # Modes
    # choose_mode() draws one entry of `modes` using `mode_probs` as weights.
    modes: Tuple[str, ...] = ("normal", "noisy", "clean", "saturated")
    mode_probs: Tuple[float, ...] = (0.60, 0.20, 0.10, 0.10)

    # Labels
    label_min_abs_height: float = 15.0
    label_snr_k: float = 4.0
    mask_sigma_mult: float = 2.0
    mask_w_min: float = 1.5
    mask_w_max: float = 6.0
    include_mask_types: Tuple[str, ...] = ("true", "stutter", "spurious", "crosstalk")

    # Legacy peak_positions_detailed.csv
    legacy_peak_width_mult: float = 6.0
    legacy_plant_peak_width_max_pb: float = 7.0
    ampclass_t1: float = 900.0   # amp_class(): below this => "low"
    ampclass_t2: float = 2570.0  # amp_class(): below this => "mid", otherwise "high"

    # Plotting
    plot_qlo: float = 0.5
    plot_qhi: float = 99.5


# Module-level default configuration instance.
CFG = GeneratorConfig()

# =============================================================================
# Utilities
# =============================================================================

def ensure_dir(path: str) -> None:
    """Create directory *path* with any missing parents; an existing directory is not an error."""
    os.makedirs(name=path, exist_ok=True)

def clamp(x: float, lo: float, hi: float) -> float:
    """Limit *x* to the closed interval [lo, hi] and return it as a Python float."""
    raised = max(lo, x)      # apply the lower bound first ...
    bounded = min(hi, raised)  # ... then the upper bound
    return float(bounded)

def mad(x: np.ndarray) -> float:
    """Return the median absolute deviation of *x*, scaled by 1.4826.

    A tiny constant (1e-12) is added so the result is never exactly zero.
    """
    values = np.asarray(x, float)
    deviations = np.abs(values - np.median(values))
    return 1.4826 * np.median(deviations) + 1e-12

def smooth_noise(n: int, k: int) -> np.ndarray:
    """Return *n* samples of unit white noise smoothed by a *k*-point moving average.

    Args:
        n: number of samples to produce.
        k: moving-average window length; values below 3 are raised to 3.

    Returns:
        Array of exactly ``n`` floats. The series is treated as zero outside
        its ends, so edge samples are attenuated.

    ``np.convolve(..., mode="same")`` returns ``max(n, k)`` samples, so a
    window longer than the series would hand callers an array of the wrong
    length. That case is handled by slicing the centred part of the full
    convolution; ``n == 0`` returns an empty array instead of raising.
    """
    k = max(3, int(k))
    z = np.random.normal(0.0, 1.0, n)
    if z.size == 0:
        return z
    kernel = np.ones(k) / k
    if k <= z.size:
        # Usual case: "same" already yields n samples aligned with z.
        return np.convolve(z, kernel, mode="same")
    # Window longer than the series: take the n samples of the full
    # convolution that are centred on z (same alignment as mode="same").
    start = (k - 1) // 2
    return np.convolve(z, kernel, mode="full")[start:start + z.size]

def shift_series(y: np.ndarray, shift: int) -> np.ndarray:
    """Shift *y* by *shift* samples, padding the vacated end with the edge value.

    A positive shift moves samples towards higher indices and repeats the
    first sample at the start; a negative shift does the opposite and repeats
    the last sample at the end. A new float array is always returned.
    """
    series = np.asarray(y, float)
    if not shift:
        return series.copy()

    moved = np.roll(series, shift)
    if shift > 0:
        # After the roll, the original first sample sits at index `shift`.
        moved[:shift] = moved[shift]
        return moved
    # After the roll, the original last sample sits just before the wrapped tail.
    moved[shift:] = moved[shift - 1]
    return moved

def resample_to_molw(molw_src: np.ndarray, y_src: np.ndarray, molw_dst: np.ndarray) -> np.ndarray:
    """Linearly interpolate the samples (molw_src, y_src) onto the axis *molw_dst*.

    The source points are sorted by molw first. Destination values outside the
    source range take the value of the nearest end point.
    """
    src_axis = np.asarray(molw_src, float)
    src_vals = np.asarray(y_src, float)
    dst_axis = np.asarray(molw_dst, float)

    rank = np.argsort(src_axis)
    sorted_axis, sorted_vals = src_axis[rank], src_vals[rank]
    return np.interp(dst_axis, sorted_axis, sorted_vals, left=sorted_vals[0], right=sorted_vals[-1])

def nearest_index(x: np.ndarray, v: float) -> int:
    """Return the index of the element of the ascending array *x* closest to *v*.

    Args:
        x: values sorted in ascending order.
        v: value to locate.

    Returns:
        Index in ``[0, len(x) - 1]``. Values outside the range of *x* map to
        the first or last index; exact ties go to the lower index.

    Raises:
        ValueError: if *x* is empty (there is no valid index to return).

    ``np.searchsorted`` alone yields the insertion point, i.e. the first
    element >= v, which can be one sample to the right of the closest one;
    the left neighbour is therefore compared explicitly.
    """
    x = np.asarray(x, float)
    n = len(x)
    if n == 0:
        raise ValueError("nearest_index() requires a non-empty axis")

    right = int(np.searchsorted(x, v))
    if right <= 0:
        return 0
    if right >= n:
        return n - 1
    left = right - 1
    return left if (v - x[left]) <= (x[right] - v) else right

def ar1(n: int, phi: float, sigma: float) -> np.ndarray:
    """Generate *n* samples of a first-order autoregressive series.

    ``y[i] = phi * y[i - 1] + e[i]`` with Gaussian innovations of standard
    deviation *sigma*. The first sample is always 0 (its innovation is drawn
    but not used).
    """
    innovations = np.random.normal(0.0, sigma, n)
    series = np.zeros(n)
    previous = 0.0
    for i in range(1, n):
        previous = phi * previous + innovations[i]
        series[i] = previous
    return series

def winsor(x: np.ndarray, lo: float, hi: float) -> np.ndarray:
    """Clip *x* to its own [lo, hi] percentile range (percentiles given in 0-100)."""
    bounds = np.percentile(x, [lo, hi])
    return np.clip(x, bounds[0], bounds[1])

def span_p99_p1(y: np.ndarray) -> float:
    """Return the distance between the 99th and the 1st percentile of *y*."""
    values = np.asarray(y, float)
    p1, p99 = np.percentile(values, [1, 99])
    return float(p99 - p1)

# =============================================================================
# EMG peak
# =============================================================================

try:
    from scipy.special import erfcx  # type: ignore
except Exception:
    # SciPy could not be imported: fall back to a NumPy/stdlib implementation.
    from math import erfc as _erfc_scalar

    def erfcx(z):
        """Scaled complementary error function, ``exp(z**2) * erfc(z)``.

        Evaluated literally, so for large ``|z|`` the exponential can overflow
        and the product may be inf/nan.
        """
        arg = np.asarray(z, dtype=float)
        elementwise_erfc = np.vectorize(_erfc_scalar)
        return np.exp(arg * arg) * elementwise_erfc(arg)

def emg_peak(x: np.ndarray, amp: float, mu: float, sigma: float, tau: float, left_tail: bool = False) -> np.ndarray:
    """Evaluate an exponentially modified Gaussian peak on the grid *x*.

    Args:
        x: sample positions.
        amp: height of the returned peak; the curve is rescaled so that its
            largest sampled value equals *amp*.
        mu: centre of the Gaussian component.
        sigma: Gaussian width (floored at 1e-6).
        tau: exponential decay constant (floored at 1e-6).
        left_tail: when True the grid is mirrored about *mu*, which puts the
            exponential tail on the low-x side.

    Returns:
        Array shaped like *x*; all zeros when the curve has no positive sample.
    """
    grid = np.asarray(x, float)
    sigma = max(float(sigma), 1e-6)
    tau = max(float(tau), 1e-6)
    if left_tail:
        grid = (2.0 * mu) - grid

    rate = 1.0 / tau
    erf_arg = (mu + rate * sigma * sigma - grid) / (math.sqrt(2.0) * sigma)
    exponent = (rate / 2.0) * (2.0 * mu + rate * sigma * sigma - 2.0 * grid)

    # erfc(a) is written as exp(-a^2) * erfcx(a) so the large exponent and the
    # tiny erfc value are combined inside a single exp() call.
    with np.errstate(over="ignore", under="ignore", invalid="ignore", divide="ignore"):
        curve = (rate / 2.0) * np.exp(exponent - erf_arg * erf_arg) * erfcx(erf_arg)

    curve = np.nan_to_num(curve, nan=0.0, posinf=0.0, neginf=0.0)
    top = float(curve.max()) if curve.size else 0.0
    if top > 0.0:
        return (amp / top) * curve
    return np.zeros_like(grid)

def soft_saturate(y: np.ndarray, sat: float, softness: float) -> np.ndarray:
    """Compress *y* smoothly towards +/- *sat* with a tanh curve.

    *sat* is floored at 1.0 and *softness* at 1e-6. Any non-finite result is
    replaced by 0.
    """
    ceiling = max(float(sat), 1.0)
    with np.errstate(over="ignore", invalid="ignore"):
        squashed = ceiling * np.tanh(y / (ceiling * max(1e-6, softness)))
    return np.nan_to_num(squashed, nan=0.0, posinf=0.0, neginf=0.0)

# =============================================================================
# Template reading
# =============================================================================

@dataclass
class Template:
    """Parsed template electropherogram plus the statistics derived from it.

    Instances are built by read_template().
    """

    df: pd.DataFrame           # template table; required columns coerced to numeric, NaN -> 0
    molw: np.ndarray           # the "molw" column as a float array
    min_bp: float              # lower bound of the bp range computed by read_template()
    max_bp: float              # upper bound of that range
    ladder_noise: float        # MAD-based noise estimate of channel_5
    # Per-channel statistics for channels 1-4; read_template() fills the keys
    # base_med, base_noise, q_low, q_high, amp_cap and sat.
    stats: Dict[int, Dict[str, float]]  # ch -> dict


def read_template(csv_path: str, cfg: GeneratorConfig) -> Template:
    """Load a ';'-separated template CSV and derive per-channel statistics.

    Args:
        csv_path: path of the template file; it must contain the columns
            index, molw and channel_1 .. channel_5.
        cfg: generator configuration (template_stats_min_bp, amp_floor and
            sat_factor are read here).

    Returns:
        A Template holding the numeric frame, the molw axis, the usable bp
        range, the ladder noise estimate and the statistics of channels 1-4.

    Raises:
        ValueError: if any required column is missing.
    """
    required = ["index", "molw", "channel_1", "channel_2", "channel_3", "channel_4", "channel_5"]
    df = pd.read_csv(csv_path, sep=";")
    missing = [name for name in required if name not in df.columns]
    if missing:
        raise ValueError(f"Template missing columns: {missing}")

    # Unparseable cells become NaN and are then replaced by 0.
    for name in required:
        df[name] = pd.to_numeric(df[name], errors="coerce").fillna(0.0)

    molw = df["molw"].to_numpy(float)
    stats_floor = cfg.template_stats_min_bp
    good = molw >= stats_floor
    if not good.any():
        # Nothing at or above the threshold: fall back to using every sample.
        good = np.ones_like(molw, dtype=bool)

    # Usable bp range: keep a margin at both ends of the axis.
    axis_lo, axis_hi = molw.min(), molw.max()
    margin = max(40.0, 0.05 * (axis_hi - axis_lo))
    min_bp = float(max(axis_lo + margin, stats_floor))
    max_bp = float(axis_hi - margin)
    if max_bp <= min_bp:
        # Axis too short for the margin: use the extent of the selected samples.
        selected = molw[good]
        min_bp = float(max(np.min(selected), stats_floor))
        max_bp = float(np.max(selected))

    def lowest_decile(values: np.ndarray) -> np.ndarray:
        # Samples at or below the 10 % quantile serve as the baseline region.
        return values[values <= np.quantile(values, 0.10)]

    stats: Dict[int, Dict[str, float]] = {}
    for ch in range(1, 5):
        trace = df[f"channel_{ch}"].to_numpy(float)[good]
        low = lowest_decile(trace)
        reference = low if low.size else trace

        base_med = float(np.median(reference))
        base_noise = float(mad(reference))
        q_low = float(np.quantile(trace, 0.80))
        q_high = float(np.quantile(trace, 0.9995))
        ceiling = max(cfg.amp_floor, q_high)

        stats[ch] = {
            "base_med": base_med,
            "base_noise": max(base_noise, 1e-6),
            "q_low": q_low,
            "q_high": q_high,
            "amp_cap": ceiling * 1.2,
            "sat": ceiling * cfg.sat_factor,
        }

    ladder = df["channel_5"].to_numpy(float)[good]
    ladder_low = lowest_decile(ladder)
    ladder_noise = float(mad(ladder_low)) if ladder_low.size else float(mad(ladder))

    return Template(df=df, molw=molw, min_bp=min_bp, max_bp=max_bp, ladder_noise=ladder_noise, stats=stats)

# =============================================================================
# Axis randomization
# =============================================================================

def enforce_monotonic(x: np.ndarray, min_step: float = 1e-6) -> np.ndarray:
    """Return a copy of *x* that is strictly increasing.

    Scanning left to right, any sample that is not larger than its (already
    corrected) predecessor is replaced by ``predecessor + min_step``.
    """
    out = np.asarray(x, float).copy()
    if len(out) == 0:
        return out

    previous = out[0]
    for i in range(1, len(out)):
        current = out[i]
        if current <= previous:
            current = previous + min_step
            out[i] = current
        previous = current
    return out

def warp_molw(molw_base: np.ndarray, cfg: GeneratorConfig) -> Tuple[np.ndarray, Dict[str, float | str]]:
    """Randomly distort a molecular-weight axis.

    Steps, in order: an affine scale/shift, an optional smooth warp
    (probability ``cfg.p_warp``; quadratic or cubic in the normalised
    position), an optional smooth jitter (probability ``cfg.p_molw_jitter``),
    and finally a pass that forces the axis to be strictly increasing.

    Returns:
        ``(axis, meta)`` where *meta* records the drawn parameters under the
        keys molw_scale, molw_shift, warp_kind, warp_k and molw_jitter_amp.
    """
    x = np.asarray(molw_base, float)

    scale = random.uniform(*cfg.molw_scale_range)
    shift = random.uniform(*cfg.molw_shift_range)
    warped = scale * x + shift

    warp_kind, warp_k = "none", 0.0
    if random.random() < cfg.p_warp:
        warp_kind = random.choices(["quad", "cubicS"], weights=cfg.warp_kind_probs, k=1)[0]
        span = x.max() - x.min()
        t = (x - x.mean()) / (span + 1e-9)  # normalised position along the axis
        warp_k = random.uniform(*cfg.warp_strength_range)
        if warp_kind == "quad":
            bend = t * t - np.mean(t * t)
        else:
            bend = t * t * t - t
        warped = warped + (warp_k * span) * bend

    jitter_amp = 0.0
    if random.random() < cfg.p_molw_jitter:
        jitter_amp = random.uniform(*cfg.molw_jitter_amp_range)
        if jitter_amp > 0:
            window = random.randint(*cfg.molw_jitter_smooth_range)
            wobble = smooth_noise(len(warped), window)
            wobble = wobble / (np.std(wobble) + 1e-9)  # rescale to roughly unit std
            warped = warped + jitter_amp * wobble

    meta = {
        "molw_scale": scale,
        "molw_shift": shift,
        "warp_kind": warp_kind,
        "warp_k": warp_k,
        "molw_jitter_amp": jitter_amp,
    }
    return enforce_monotonic(warped), meta

# =============================================================================
# Baseline + noise
# =============================================================================

def baseline_components(molw: np.ndarray, base_med: float, base_noise: float, plant_offset: float,
                        baseline_gain: float, mode: str, cfg: GeneratorConfig) -> np.ndarray:
    """Build a random baseline trace over the axis *molw*.

    The baseline is the sum of a constant offset, a linear drift, a
    zero-mean random walk, smoothed low-frequency noise and a few sinusoids.
    *mode* ("noisy", "clean", "saturated"; anything else is neutral) scales
    the drift and random-walk terms. Non-finite samples are replaced by 0.
    """
    n = len(molw)
    centre = float(molw.mean())
    span = float(molw.max() - molw.min() + 1e-9)

    # (drift multiplier, random-walk multiplier) per mode
    mode_scaling = {
        "noisy": (1.35, 1.25),
        "clean": (0.75, 0.80),
        "saturated": (1.10, 1.10),
    }
    drift_mult, rw_mult = mode_scaling.get(mode, (1.0, 1.0))

    offset = (base_med + plant_offset) + np.random.normal(0.0, 1.2 * base_noise)

    slope = np.random.normal(0.0, drift_mult * cfg.drift_slope_factor * base_noise / span)
    drift = slope * (molw - centre)

    steps = np.random.normal(0.0, rw_mult * cfg.rw_sigma_factor * base_noise * baseline_gain, n)
    walk = np.cumsum(steps)
    walk -= walk.mean()

    lf_window = random.randint(*cfg.lf_smooth_range)
    lf = smooth_noise(n, lf_window) * (cfg.lf_amp_factor * base_noise * baseline_gain)

    sines = np.zeros(n)
    from_start = molw - molw.min()
    n_sines = random.randint(*cfg.lf_sine_count_range)
    for _ in range(n_sines):
        period = np.random.uniform(*cfg.lf_period_range)
        phase = np.random.uniform(0, 2 * math.pi)
        amp = np.random.uniform(*cfg.lf_sine_amp_range) * base_noise * baseline_gain
        sines += amp * np.sin(2 * math.pi * from_start / period + phase)

    total = offset + drift + walk + lf + sines
    return np.nan_to_num(total, nan=0.0, posinf=0.0, neginf=0.0)

def limit_slow_baseline(base: np.ndarray, q_low: float, cfg: GeneratorConfig) -> np.ndarray:
    """Compress the slow component of *base* when its swing is too large.

    The trace is split into a moving-average ("slow") part and the residual.
    If the 5-95 percentile swing of the slow part exceeds
    ``cfg.baseline_lf_cap_frac_qlow * max(1, q_low)``, the slow part is shrunk
    around its mean by ``(cap / swing) ** cfg.baseline_lf_soft``; the residual
    is added back unchanged. Otherwise *base* itself is returned.
    """
    window = min(cfg.baseline_lf_smooth, max(9, len(base) // 10))
    kernel = np.ones(window) / window
    slow = np.convolve(base, kernel, mode="same")

    swing = np.percentile(slow, 95) - np.percentile(slow, 5)
    cap = cfg.baseline_lf_cap_frac_qlow * max(1.0, float(q_low))
    if swing <= cap or swing <= 1e-9:
        return base

    shrink = (cap / swing) ** cfg.baseline_lf_soft
    compressed = (slow - slow.mean()) * shrink + slow.mean()
    residual = base - slow
    return np.nan_to_num(compressed + residual, nan=0.0, posinf=0.0, neginf=0.0)

def micro_texture(n: int, base_noise: float, abs_floor: float, mode: str, cfg: GeneratorConfig) -> np.ndarray:
    """Return *n* samples of fine-grained noise: white noise plus an AR(1) series.

    Both standard deviations are drawn as a fraction of *base_noise* and
    floored at a quarter of *abs_floor*. The upper end of each draw range is
    widened in "noisy" mode and narrowed in "clean" mode.
    """
    white_lo, white_hi = cfg.micro_white_range
    ar_lo, ar_hi = cfg.micro_ar_range
    upper_scale = {"noisy": 1.25, "clean": 0.75}.get(mode)
    if upper_scale is not None:
        white_hi *= upper_scale
        ar_hi *= upper_scale

    floor = 0.25 * abs_floor
    white_sigma = max(np.random.uniform(white_lo, white_hi) * base_noise, floor)
    ar_sigma = max(np.random.uniform(ar_lo, ar_hi) * base_noise, floor)

    white = np.random.normal(0.0, white_sigma, n)
    phi = np.random.uniform(*cfg.micro_ar_phi_range)
    correlated = ar1(n, phi, ar_sigma)
    return np.nan_to_num(white + correlated, nan=0.0, posinf=0.0, neginf=0.0)

def add_noise(base_noise: float, signal: np.ndarray, noise_gain: float, mode: str, cfg: GeneratorConfig) -> np.ndarray:
    """Generate a noise trace matching the length of *signal*.

    The result is the sum of white noise, moving-average-correlated noise,
    signal-dependent noise whose spread grows with sqrt(|signal|) (capped and
    winsorized), and the fine texture from micro_texture(). Only the noise is
    returned; it is not added to *signal* here. Non-finite samples become 0.
    """
    n = len(signal)

    # (absolute-floor multiplier, multiplier for the upper HF bound) per mode
    floor_mult, hf_hi_mult = {"noisy": (1.2, 1.35), "clean": (0.8, 0.80)}.get(mode, (None, None))

    abs_floor = random.uniform(*cfg.abs_noise_floor_range)
    if floor_mult is not None:
        abs_floor *= floor_mult

    hf_lo, hf_hi = cfg.hf_noise_range
    if hf_hi_mult is not None:
        hf_hi *= hf_hi_mult

    hf_sigma = max(np.random.uniform(hf_lo, hf_hi) * base_noise * noise_gain, abs_floor)

    white = np.random.normal(0.0, hf_sigma, n)
    corr_window = random.randint(*cfg.corr_k_range)
    correlated = smooth_noise(n, k=corr_window) * (0.55 * hf_sigma)

    hetero = np.random.uniform(*cfg.hetero_range) * noise_gain
    root_level = np.clip(np.sqrt(np.maximum(np.abs(signal), 0.0)), 0.0, cfg.hetero_sqrt_cap)
    signal_dependent = winsor(np.random.normal(0.0, hetero * root_level, n), *cfg.het_winsor)

    texture = micro_texture(n, base_noise, abs_floor, mode, cfg)
    total = white + correlated + signal_dependent + texture
    return np.nan_to_num(total, nan=0.0, posinf=0.0, neginf=0.0)

# =============================================================================
# Peak sampling helpers
# =============================================================================

def choose_mode(cfg: GeneratorConfig) -> str:
    """Draw one entry of ``cfg.modes`` using ``cfg.mode_probs`` as weights."""
    (picked,) = random.choices(cfg.modes, weights=cfg.mode_probs, k=1)
    return picked

def choose_peak_count(cfg: GeneratorConfig) -> int:
    """Draw the number of peaks for one channel.

    A uniform draw selects the regime: sparse (0-2 peaks, with fixed
    10/45/45 % odds), medium (Poisson(5) clipped to 3-8) or dense
    (Poisson(13) clipped to 9-20).
    """
    regime = random.random()

    if regime < cfg.p_sparse:
        pick = random.random()
        if pick < 0.10:
            return 0
        if pick < 0.55:
            return 1
        return 2

    if regime < cfg.p_sparse + cfg.p_medium:
        rate, lowest, highest = 5, 3, 8
    else:
        rate, lowest, highest = 13, 9, 20
    return int(np.clip(np.random.poisson(rate), lowest, highest))

def sample_peak_centers(count: int, min_bp: float, max_bp: float, cfg: GeneratorConfig) -> List[float]:
    """Draw *count* peak centres uniformly in [min_bp, max_bp], sorted ascending.

    A candidate is accepted when it lies at least ``cfg.min_sep_bp`` away from
    every accepted centre, or - with probability ``cfg.p_overlap`` - without
    that check. After 7000 attempts the remaining centres are drawn without
    any spacing constraint.
    """
    centers: List[float] = []
    attempts_left = 7000
    while len(centers) < count and attempts_left > 0:
        attempts_left -= 1
        candidate = random.uniform(min_bp, max_bp)
        overlap_allowed = random.random() < cfg.p_overlap
        if overlap_allowed or all(abs(candidate - c) >= cfg.min_sep_bp for c in centers):
            centers.append(candidate)

    # Attempt budget exhausted: fill up ignoring the spacing rule.
    while len(centers) < count:
        centers.append(random.uniform(min_bp, max_bp))

    return sorted(centers)

def dropout_prob(q_low: float, x: float, cfg: GeneratorConfig) -> float:
    """Logistic function of amplitude *x* that decreases as *x* grows.

    The midpoint (value 0.5) is ``0.9 * max(cfg.amp_floor, q_low)`` and the
    scale is 30 % of the midpoint (at least 1). The two branches are the same
    sigmoid written so that exp() is only called on a non-positive argument.
    """
    mid = max(cfg.amp_floor, q_low) * 0.9
    scale = max(1.0, mid * 0.30)
    z = (mid - x) / scale
    if z < 0:
        grow = math.exp(z)
        return grow / (1 + grow)
    return 1 / (1 + math.exp(-z))

def stutter_ratio(parent_amp: float, frac: float, cfg: GeneratorConfig) -> float:
    """Draw a stutter-to-parent amplitude ratio.

    The expected ratio rises linearly with *frac* and falls with the log of
    the parent amplitude; it is clipped to [0.02, 0.30]. A Gaussian draw
    (std 0.03) around it is clipped to [0.01, 0.35]. *cfg* is not used.
    """
    log_amp = math.log(max(parent_amp, 1.0))
    expected = 0.05 + 0.10 * frac - 0.025 * log_amp
    expected = float(np.clip(expected, 0.02, 0.30))
    drawn = np.random.normal(expected, 0.03)
    return float(np.clip(drawn, 0.01, 0.35))

# =============================================================================
# Crosstalk
# =============================================================================

def sample_crosstalk_matrix(cfg: GeneratorConfig) -> np.ndarray:
    """Draw a randomised copy of ``cfg.crosstalk_base``.

    The copy is scaled by one global factor, perturbed element-wise with
    multiplicative Gaussian jitter and clipped at zero. With probability
    ``cfg.p_pulldown`` up to three randomly chosen off-diagonal entries are
    replaced by small negative coefficients. The diagonal is always zero.
    """
    matrix = np.array(cfg.crosstalk_base, float)  # copy, the config is left untouched
    matrix *= random.uniform(*cfg.crosstalk_scale_range)
    matrix *= (1.0 + np.random.normal(0.0, cfg.crosstalk_jitter, matrix.shape))
    matrix = np.clip(matrix, 0.0, None)

    if random.random() < cfg.p_pulldown:
        n_attempts = random.randint(1, 3)
        for _ in range(n_attempts):
            src = random.randint(0, 3)
            tgt = random.randint(0, 3)
            if src == tgt:
                continue  # a diagonal pick is simply skipped
            matrix[src, tgt] = -random.uniform(*cfg.pulldown_range)

    np.fill_diagonal(matrix, 0.0)
    return matrix

def apply_crosstalk(peak_comp: Dict[str, np.ndarray], sat_levels: Dict[str, float],
                    M: np.ndarray, cfg: GeneratorConfig) -> Tuple[Dict[str, np.ndarray], List[Tuple[str, str, float]]]:
    """Add cross-channel bleed to the peak components of channels 1-4.

    For every ordered pair (source, target) with a coefficient ``M[i, j]`` of
    magnitude >= 1e-8, ``coef * soft_saturate(source)`` is added to the
    target. The inputs are not modified.

    Returns:
        ``(mixed, used)``: the per-channel result (non-finite samples set to
        0) and the list of applied ``(source, target, coefficient)`` triples.
    """
    names = [f"channel_{k}" for k in range(1, 5)]
    mixed = {name: peak_comp[name].copy() for name in names}
    # The bleed is taken from the saturated source peaks.
    saturated = {
        name: soft_saturate(peak_comp[name], sat_levels[name], cfg.overload_softness)
        for name in names
    }

    used: List[Tuple[str, str, float]] = []
    for row, source in enumerate(names):
        for col, target in enumerate(names):
            if row == col:
                continue
            coef = float(M[row, col])
            if abs(coef) >= 1e-8:
                mixed[target] += coef * saturated[source]
                used.append((source, target, coef))

    for name in names:
        mixed[name] = np.nan_to_num(mixed[name], nan=0.0, posinf=0.0, neginf=0.0)

    return mixed, used

# =============================================================================
# Labels
# =============================================================================

def passes_label_filter(event: Dict, cfg: GeneratorConfig) -> bool:
    """Tell whether an event is tall enough to be labelled.

    The event's ``peak_height`` must reach both the absolute minimum
    ``cfg.label_min_abs_height`` and ``cfg.label_snr_k`` times its
    ``base_noise``. Missing keys count as 0.
    """
    height = float(event.get("peak_height", 0.0))
    noise = float(event.get("base_noise", 0.0))
    threshold = max(cfg.label_min_abs_height, cfg.label_snr_k * noise)
    return height >= threshold

def build_dense_mask(molw: np.ndarray, events_out_axis: List[Dict], filename: str, cfg: GeneratorConfig) -> pd.DataFrame:
    """Build per-sample 0/1 peak masks for channels 1-4 of one file.

    Only events that belong to *filename*, whose ``peak_type`` is listed in
    ``cfg.include_mask_types`` and that pass passes_label_filter() are drawn.
    Each one marks the samples within a half-width of its centre; the
    half-width is ``cfg.mask_sigma_mult * sigma`` clamped to
    [``cfg.mask_w_min``, ``cfg.mask_w_max``].

    Returns:
        DataFrame with a ``molw`` column and ``label_channel_1..4``.
    """
    table = pd.DataFrame({"molw": molw})
    for ch in range(1, 5):
        table[f"label_channel_{ch}"] = 0

    for event in events_out_axis:
        skip = (
            event["filename"] != filename
            or event.get("peak_type") not in cfg.include_mask_types
            or not passes_label_filter(event, cfg)
        )
        if skip:
            continue

        channel_no = int(event["channel"].split("_")[1])
        center = float(event["peak_center_molw"])
        sigma = float(event.get("sigma", 0.0))
        if sigma > 0:
            half_width = cfg.mask_sigma_mult * sigma
        else:
            half_width = cfg.mask_w_min  # no usable sigma: use the minimum width
        half_width = clamp(half_width, cfg.mask_w_min, cfg.mask_w_max)

        inside = np.abs(molw - center) <= half_width
        table.loc[inside, f"label_channel_{channel_no}"] = 1

    return table

def build_center_labels(molw: np.ndarray, events_out_axis: List[Dict], filename: str, cfg: GeneratorConfig) -> pd.DataFrame:
    """Build per-sample 0/1 peak-centre labels for channels 1-4 of one file.

    The same event selection as for the dense masks applies (matching
    *filename*, ``peak_type`` in ``cfg.include_mask_types``, passing
    passes_label_filter()). Each selected event sets a single sample, the one
    returned by nearest_index() for its centre.

    Returns:
        DataFrame with a ``molw`` column and ``center_channel_1..4``.
    """
    table = pd.DataFrame({"molw": molw})
    for ch in range(1, 5):
        table[f"center_channel_{ch}"] = 0

    for event in events_out_axis:
        skip = (
            event["filename"] != filename
            or event.get("peak_type") not in cfg.include_mask_types
            or not passes_label_filter(event, cfg)
        )
        if skip:
            continue

        channel_no = int(event["channel"].split("_")[1])
        center = float(event["peak_center_molw"])
        row = nearest_index(molw, center)
        table.loc[row, f"center_channel_{channel_no}"] = 1

    return table

# =============================================================================
# Baseline centering (DC only)
# =============================================================================

def center_baseline_pre50(y: np.ndarray, molw: np.ndarray, cutoff: float, q: int) -> np.ndarray:
    """Subtract a constant from *y* so that its early region sits near zero.

    The constant is the *q*-th percentile of the samples whose molw is below
    *cutoff*; if there are none, the percentile of the whole trace is used.
    Only a constant is removed; a new array is returned.
    """
    trace = np.asarray(y, float).copy()
    early = np.asarray(molw, float) < float(cutoff)
    reference = trace[early] if early.any() else trace
    return trace - float(np.percentile(reference, q))

# =============================================================================
# Ladder
# =============================================================================

def generate_liz500_ladder(molw_base: np.ndarray, molw_out: np.ndarray, ladder_noise: float, mux_shift: int,
                           cfg: GeneratorConfig) -> np.ndarray:
    """Synthesise the ladder channel and resample it onto *molw_out*.

    One EMG peak is placed for every size in ``cfg.liz500_peaks`` that lies
    inside the range of *molw_base*. Peak heights follow a shallow parabola
    centred on 250 bp with +/-15 % random variation; widths vary +/-15 %
    around a per-trace sigma/tau. A linear drift and white noise scaled from
    *ladder_noise* are added, the trace is shifted by *mux_shift* samples and
    finally interpolated from *molw_base* to *molw_out*.
    """
    x = np.asarray(molw_base, float)
    trace = np.zeros_like(x)

    trace_sigma = np.random.uniform(0.35, 0.75)
    trace_tau = np.random.uniform(0.25, 0.65)

    axis_lo, axis_hi = float(x.min()), float(x.max())
    for size_bp in cfg.liz500_peaks:
        if size_bp < axis_lo or size_bp > axis_hi:
            continue  # fragment outside the simulated axis
        t = (size_bp - 250.0) / 250.0
        amp = 12000.0 * (1.0 - 0.15 * (t * t)) * np.random.uniform(0.85, 1.15)
        sigma = trace_sigma * np.random.uniform(0.85, 1.15)
        tau = trace_tau * np.random.uniform(0.85, 1.15)
        trace += emg_peak(x, float(amp), float(size_bp), float(sigma), float(tau), left_tail=False)

    centre = float(x.mean())
    span = float(x.max() - x.min() + 1e-9)
    drift = np.random.normal(0.0, cfg.ladder_drift_mult * ladder_noise / span) * (x - centre)
    white = np.random.normal(0.0, cfg.ladder_noise_mult * ladder_noise, size=trace.shape)

    trace = trace + drift + white
    shifted = shift_series(np.nan_to_num(trace), mux_shift)
    return resample_to_molw(x, shifted, np.asarray(molw_out, float))

# =============================================================================
# peak_positions_detailed (legacy)
# =============================================================================

def amp_class(amplitude: float, cfg: GeneratorConfig) -> str:
    """Classify *amplitude* as "low", "mid" or "high".

    Values below ``cfg.ampclass_t1`` are "low", values below
    ``cfg.ampclass_t2`` are "mid", everything else is "high".
    """
    upper_bounds = (("low", cfg.ampclass_t1), ("mid", cfg.ampclass_t2))
    for label, bound in upper_bounds:
        if amplitude < bound:
            return label
    return "high"

def build_peak_positions_detailed(rows: List[Dict]) -> pd.DataFrame:
    """Assemble the legacy peak table from per-peak dictionaries.

    Rows are sorted (stable sort) by plant, multiplex, channel, peak kind and
    then ``mu_pb``; ``peak_index`` numbers the peaks 1, 2, ... within each
    (plant, multiplex, channel, peak kind) group. An empty input yields an
    empty frame with the expected columns.
    """
    columns = [
        "plant_id", "plant_id_str", "multiplex", "channel", "marker_id", "peak_kind", "peak_index",
        "mu_pb", "sigma_pb", "peak_width_pb", "amplitude", "amp_class", "plant_peak_width_max_pb"
    ]
    if not rows:
        return pd.DataFrame(columns=columns)

    grouping = ["plant_id_str", "multiplex", "channel", "peak_kind"]
    table = pd.DataFrame(rows)
    table = table.sort_values([*grouping, "mu_pb"], kind="mergesort")
    table = table.reset_index(drop=True)
    table["peak_index"] = table.groupby(grouping).cumcount() + 1
    return table[columns]

# =============================================================================
# Plotting (optional)
# =============================================================================

def robust_ylim(ax, y, qlo=0.5, qhi=99.5):
    """Set the y-limits of *ax* from percentiles of *y* instead of its extremes.

    The limits are the *qlo* / *qhi* percentiles of the finite samples,
    widened by 5 % of their distance (or by 1.0 when they coincide). Nothing
    is changed when *y* has no finite sample.
    """
    values = np.asarray(y, float)
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return

    lower = np.percentile(finite, qlo)
    upper = np.percentile(finite, qhi)
    if upper > lower:
        pad = 0.05 * (upper - lower)
    else:
        pad = 1.0
    ax.set_ylim(lower - pad, upper + pad)

def plot_channel(df: pd.DataFrame, out_path: str, col: str, qlo: float, qhi: float):
    """Plot column *col* of *df* against ``molw`` and save the figure to *out_path*.

    The y-range is set by robust_ylim() with the percentiles *qlo* / *qhi*;
    the figure is closed after saving.
    """
    molw = df["molw"].to_numpy(float)
    signal = df[col].to_numpy(float)

    fig, ax = plt.subplots(figsize=(10, 3))
    ax.plot(molw, signal, lw=0.7)
    ax.grid(True, alpha=0.25)
    ax.set_xlabel("molw (bp)")
    ax.set_ylabel(col)
    ax.set_xlim(float(molw.min()), float(molw.max()))
    robust_ylim(ax, signal, qlo, qhi)

    fig.tight_layout()
    fig.savefig(out_path, dpi=140)
    plt.close(fig)  # release the figure

def plot_stacked(df: pd.DataFrame, out_path: str, qlo: float, qhi: float):
    """Plot channel_1..channel_5 of *df* as five stacked panels and save to *out_path*.

    All panels share the ``molw`` x-axis; each y-range is set by
    robust_ylim() with the percentiles *qlo* / *qhi*. The figure is closed
    after saving.
    """
    molw = df["molw"].to_numpy(float)
    fig, axes = plt.subplots(5, 1, figsize=(10, 8), sharex=True)
    fig.subplots_adjust(hspace=0.35)

    for number, ax in enumerate(axes, start=1):
        column = f"channel_{number}"
        signal = df[column].to_numpy(float)
        ax.plot(molw, signal, lw=0.7)
        ax.set_ylabel(column, fontsize=8)
        ax.grid(True, alpha=0.25)
        robust_ylim(ax, signal, qlo, qhi)

    bottom = axes[-1]
    bottom.set_xlabel("molw (bp)")
    bottom.set_xlim(float(molw.min()), float(molw.max()))
    fig.tight_layout()
    fig.savefig(out_path, dpi=160)
    plt.close(fig)  # release the figure

def progress(done: int, total: int, width: int = 28, label: str = "Generating plots"):
    """Redraw a one-line text progress bar on stdout.

    The line starts with a carriage return so successive calls overwrite each
    other; a newline is printed once *done* reaches *total*. A *total* of 0 is
    shown as complete.
    """
    frac = done / total if total else 1.0
    n_filled = int(width * frac)
    bar = f"{'#' * n_filled}{'-' * (width - n_filled)}"
    line = f"\r{label}: [{bar}] {done}/{total} ({100*frac:5.1f}%)"
    print(line, end="", flush=True)
    if done >= total:
        print("")

# =============================================================================
# Main generator
# =============================================================================

def generate_dataset(
    template_m1_path: str,
    template_m2_path: str,
    n_plants: int,
    out_dir: str,
    cfg: GeneratorConfig,
    make_plots: bool = False,
    seed: int | None = None,
) -> Dict[str, str | None]:
    """Generate synthetic electropherograms and label files for ``n_plants`` plants.

    For every plant, two multiplexes ("M1" and "M2") are produced from their
    respective templates. Per multiplex the function builds channels 1-4
    (baseline + peaks, crosstalk, noise, saturation, shift, resampling onto a
    warped molw axis), a ladder in ``channel_5``, and writes:

    - ``signals/<mux>_plXXXX.csv``
    - ``labels_masks/labels_mask_<mux>_plXXXX.csv``
    - ``labels_centers/labels_center_<mux>_plXXXX.csv``
    - optional plots under ``plots/by_channel`` and ``plots/stacked``

    After all plants, ``labels_map/labels_map.csv`` and
    ``labels_map/peak_positions_detailed.csv`` are written.

    Args:
        template_m1_path: Path passed to ``read_template`` for multiplex M1.
        template_m2_path: Path passed to ``read_template`` for multiplex M2.
        n_plants: Number of plants; plants are numbered 1..n_plants.
        out_dir: Root output folder (created via ``ensure_dir``).
        cfg: Generator configuration.
        make_plots: When true, render five per-channel plots and one stacked
            plot per plant and multiplex, with a text progress bar.
        seed: When not ``None``, seeds the global ``random`` and
            ``numpy.random`` generators. Results depend on the order of the
            random draws below, so do not reorder them.

    Returns:
        Dict with the output directory paths; ``plots_dir`` is ``None`` when
        ``make_plots`` is false.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    ensure_dir(out_dir)
    signals_dir = os.path.join(out_dir, "signals")
    masks_dir = os.path.join(out_dir, "labels_masks")
    centers_dir = os.path.join(out_dir, "labels_centers")
    map_dir = os.path.join(out_dir, "labels_map")
    for d in (signals_dir, masks_dir, centers_dir, map_dir):
        ensure_dir(d)

    plots_root = os.path.join(out_dir, "plots")
    stacked_dir = os.path.join(plots_root, "stacked")
    by_channel_dir = os.path.join(plots_root, "by_channel")
    if make_plots:
        ensure_dir(stacked_dir)
        ensure_dir(by_channel_dir)
        # 2 multiplexes per plant, each with 5 channel plots + 1 stacked plot.
        total_plots = n_plants * 2 * 6
        done_plots = 0

    tpl1 = read_template(template_m1_path, cfg)
    tpl2 = read_template(template_m2_path, cfg)

    # Rows accumulated across all plants for the two global CSV files.
    labels_map_rows: List[Dict] = []
    legacy_rows: List[Dict] = []

    for plant_i in range(1, n_plants + 1):
        plant_id = int(plant_i)
        plant_id_str = int(plant_i)      # legacy example uses non-padded ints
        plant_tag = f"pl{plant_i:04d}"

        # Per-plant factors shared by both multiplexes: exp() of a normal draw,
        # passed through clamp() with the given bounds, plus an additive offset.
        plant_gain = clamp(np.exp(np.random.normal(0.0, 0.18)), 0.6, 1.8)
        plant_noise_gain = clamp(np.exp(np.random.normal(0.0, 0.14)), 0.7, 1.8)
        plant_baseline_gain = clamp(np.exp(np.random.normal(0.0, 0.12)), 0.7, 1.6)
        plant_offset = float(np.random.normal(0.0, 12.0))

        # Correlated per-channel gains for channels 1-4: C has a unit diagonal
        # and `corr` everywhere else; its Cholesky factor mixes independent
        # normal draws. The gains are exp(z), rescaled to have mean 1.
        corr = 0.45
        C = (1 - corr) * np.eye(4) + corr * np.ones((4, 4))
        L = np.linalg.cholesky(C)
        z = L @ np.random.normal(0.0, 0.16, 4)
        ch_gain = np.exp(z)
        ch_gain = ch_gain / np.mean(ch_gain)

        # One crosstalk matrix per plant, reused for M1 and M2.
        CT = sample_crosstalk_matrix(cfg)

        for mux_idx, (mux_name, tpl) in enumerate([("M1", tpl1), ("M2", tpl2)], start=1):
            # Per-multiplex draws: mode, gains, and an integer shift in
            # [-cfg.mux_shift_max, cfg.mux_shift_max].
            mode = choose_mode(cfg)
            mux_gain = clamp(np.exp(np.random.normal(0.0, 0.10)), 0.75, 1.35)
            mux_noise_gain = clamp(np.exp(np.random.normal(0.0, 0.08)), 0.75, 1.35)
            mux_shift = random.randint(-cfg.mux_shift_max, cfg.mux_shift_max)

            filename = f"{mux_name}_{plant_tag}.csv"

            # molw_base is the template axis; molw_out is the warped output axis.
            molw_base = tpl.molw
            molw_out, molw_meta = warp_molw(molw_base, cfg)
            molw_out = np.asarray(molw_out, float)
            # Samples below the pre-peak cutoff on the output axis.
            pre_mask = molw_out < cfg.prepeak_cutoff_bp

            base_by_ch: Dict[str, np.ndarray] = {}
            peaks_by_ch: Dict[str, np.ndarray] = {}
            sat_levels: Dict[str, float] = {}
            events: List[Dict] = []

            # --- build baseline+peaks on base axis for channels 1..4 ---
            for ch in range(1, 5):
                ch_name = f"channel_{ch}"
                gain = plant_gain * float(ch_gain[ch - 1]) * mux_gain
                st = tpl.stats[ch]

                base = baseline_components(molw_base, st["base_med"], st["base_noise"],
                                          plant_offset, plant_baseline_gain, mode, cfg)
                base = limit_slow_baseline(base, st["q_low"], cfg)

                peaks = np.zeros_like(molw_base)
                evs: List[Dict] = []

                n_true = choose_peak_count(cfg)
                centers = sample_peak_centers(n_true, tpl.min_bp, tpl.max_bp, cfg)
                for mu in centers:
                    # Stop once either the true-peak cap or the total event
                    # cap for this channel is reached.
                    if sum(1 for e in evs if e.get("peak_type") == "true") >= cfg.max_true_peaks:
                        break
                    if len(evs) >= cfg.max_total_events_per_channel:
                        break

                    # Relative position of the peak within [min_bp, max_bp].
                    frac = (mu - tpl.min_bp) / max(1e-9, (tpl.max_bp - tpl.min_bp))
                    q_low = st["q_low"]
                    q_high = max(max(cfg.amp_floor, q_low) * 1.2, st["q_high"])
                    amp_cap = max(q_high, st["amp_cap"])

                    # Amplitude: log-uniform between max(amp_floor, q_low) and
                    # q_high, times a log-normal factor, times the gain, then
                    # clipped to [amp_floor, amp_cap].
                    amp = math.exp(np.random.uniform(math.log(max(cfg.amp_floor, q_low) + 1e-9),
                                                     math.log(q_high + 1e-9)))
                    amp *= math.exp(np.random.normal(0.0, cfg.amp_logn_sigma))
                    amp *= gain
                    amp = float(np.clip(amp, cfg.amp_floor, amp_cap))

                    # `decayed` only feeds the dropout probability; a peak that
                    # survives is drawn and labelled with the undecayed `amp`.
                    decayed = amp * math.exp(-np.random.uniform(0.0, 1.1) * frac)
                    if random.random() < dropout_prob(q_low, decayed, cfg):
                        continue

                    sigma = float(np.clip(np.random.uniform(cfg.sigma_min, cfg.sigma_max), cfg.sigma_min, cfg.sigma_max))
                    tau = float(np.random.uniform(cfg.tau_min, cfg.tau_max))
                    left = (random.random() < cfg.p_left_tail)

                    peaks += emg_peak(molw_base, amp, mu, sigma, tau, left_tail=left)
                    evs.append(dict(
                        filename=filename, channel=ch_name,
                        peak_center_molw=float(mu), peak_height=float(amp),
                        peak_type="true", sigma=float(sigma), tau=float(tau),
                        base_noise=float(st["base_noise"]),
                    ))

                    # stutter
                    # Only attempted when two more events still fit under the cap.
                    if random.random() < cfg.p_stutter and (len(evs) + 2) <= cfg.max_total_events_per_channel:
                        ru = random.choice(cfg.repeat_units)
                        if random.random() < cfg.p_double_step:
                            ru *= 2
                        # Always a backward stutter at -ru; a forward one at
                        # +ru is added with probability p_forward_stutter.
                        shifts = [-ru]
                        if random.random() < cfg.p_forward_stutter:
                            shifts.append(ru)

                        for sh in shifts:
                            if len(evs) >= cfg.max_total_events_per_channel:
                                break
                            mu_s = mu + float(sh)
                            # Skip stutters that fall outside the template range.
                            if not (tpl.min_bp <= mu_s <= tpl.max_bp):
                                continue
                            ratio = stutter_ratio(amp, frac, cfg)
                            amp_s = amp * ratio
                            sigma_s = max(cfg.sigma_min, sigma * np.random.uniform(0.85, 1.10))
                            tau_s = max(cfg.tau_min, tau * np.random.uniform(0.90, 1.15))
                            peaks += emg_peak(molw_base, amp_s, mu_s, sigma_s, tau_s, left_tail=left)
                            evs.append(dict(
                                filename=filename, channel=ch_name,
                                peak_center_molw=float(mu_s), peak_height=float(amp_s),
                                peak_type="stutter", sigma=float(sigma_s), tau=float(tau_s),
                                base_noise=float(st["base_noise"]),
                            ))

                # spurious peaks
                # With probability p_spurious, add a random number of extra
                # peaks (count drawn from spurious_count_range) until the cap.
                if random.random() < cfg.p_spurious:
                    for mu in sample_peak_centers(random.randint(*cfg.spurious_count_range), tpl.min_bp, tpl.max_bp, cfg):
                        if len(evs) >= cfg.max_total_events_per_channel:
                            break
                        amp_s = random.uniform(*cfg.spurious_amp_range) * gain
                        sigma_s = float(np.random.uniform(0.25, 2.0))
                        tau_s = float(np.random.uniform(cfg.tau_min, cfg.tau_max))
                        peaks += emg_peak(molw_base, amp_s, mu, sigma_s, tau_s, left_tail=(random.random() < 0.04))
                        evs.append(dict(
                            filename=filename, channel=ch_name,
                            peak_center_molw=float(mu), peak_height=float(amp_s),
                            peak_type="spurious", sigma=float(sigma_s), tau=float(tau_s),
                            base_noise=float(st["base_noise"]),
                        ))

                base_by_ch[ch_name] = base
                peaks_by_ch[ch_name] = peaks
                # Saturation level scales with the same gain as the peaks.
                sat_levels[ch_name] = st["sat"] * gain
                events.extend(evs)

            # --- crosstalk on peaks only ---
            # `used` is iterated below as (source channel, target channel, coefficient).
            peaks_ct, used = apply_crosstalk(peaks_by_ch, sat_levels, CT, cfg)

            # record crosstalk events (derived from donor peaks)
            # True when `extra` more events still fit under the per-channel
            # cap for `chname` in the current file.
            def can_add_event(chname: str, extra: int = 1) -> bool:
                c = sum(1 for e in events if e["filename"] == filename and e["channel"] == chname)
                return (c + extra) <= cfg.max_total_events_per_channel

            # NOTE(review): `events` is re-snapshotted for every (src, tgt)
            # pair, so crosstalk events appended for an earlier pair can act
            # as donors for a later pair whose source is that target channel.
            # Confirm this matches how apply_crosstalk builds the signal.
            for src, tgt, coef in used:
                for e in list(events):
                    if e["channel"] != src:
                        continue
                    amp_ct = float(e["peak_height"]) * abs(float(coef))
                    # Bleed-through below 8.0 is not recorded as an event.
                    if amp_ct < 8.0:
                        continue
                    if not can_add_event(tgt, 1):
                        continue
                    events.append(dict(
                        filename=filename, channel=tgt,
                        peak_center_molw=float(e["peak_center_molw"]),
                        peak_height=float(amp_ct),
                        peak_type="crosstalk",
                        sigma=float(e.get("sigma", 0.0)),
                        tau=float(e.get("tau", 0.0)),
                        base_noise=float(e.get("base_noise", 0.0)),
                    ))

            # --- output dataframe ---
            out = tpl.df.copy()
            out["molw"] = molw_out

            # Ladder
            out["channel_5"] = generate_liz500_ladder(molw_base, molw_out, tpl.ladder_noise, mux_shift, cfg).astype(float)

            # Map a position on the base axis to the warped output axis by
            # linear interpolation between the two axes.
            # NOTE(review): only the molw warp is applied here; `mux_shift`
            # (applied to the signals via shift_series) is not applied to event
            # centers in this function. Confirm labels are meant to ignore it.
            def mu_out(mu: float) -> float:
                return float(np.interp(mu, molw_base, molw_out))

            # --- build final channels 1..4 with splice+centering ---
            for ch in range(1, 5):
                ch_name = f"channel_{ch}"
                st = tpl.stats[ch]

                # Full
                # Baseline + crosstalk-adjusted peaks, then noise, saturation,
                # sample shift, and resampling onto the output axis.
                y_full = base_by_ch[ch_name] + peaks_ct[ch_name]
                y_full += add_noise(st["base_noise"], y_full, plant_noise_gain * mux_noise_gain, mode, cfg)
                y_full = soft_saturate(y_full, sat_levels[ch_name], cfg.overload_softness)
                y_full = shift_series(np.nan_to_num(y_full), mux_shift)
                y_out_full = resample_to_molw(molw_base, y_full, molw_out)

                # Baseline-only (for <50)
                # Same pipeline without peaks; uses its own add_noise call.
                y_pre = base_by_ch[ch_name].copy()
                y_pre += add_noise(st["base_noise"], y_pre, plant_noise_gain * mux_noise_gain, mode, cfg)
                y_pre = soft_saturate(y_pre, sat_levels[ch_name], cfg.overload_softness)
                y_pre = shift_series(np.nan_to_num(y_pre), mux_shift)
                y_out_pre = resample_to_molw(molw_base, y_pre, molw_out)

                # splice: <50 bp baseline-only
                y = np.asarray(y_out_full, float).copy()
                y[pre_mask] = np.asarray(y_out_pre, float)[pre_mask]

                # baseline centering: constant from <50 region
                if cfg.center_baseline:
                    y = center_baseline_pre50(y, molw_out, cfg.prepeak_cutoff_bp, cfg.center_q)

                # Any remaining NaN/inf is written as 0.0.
                out[ch_name] = np.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0).astype(float)

            # Save signals
            out.to_csv(os.path.join(signals_dir, filename), sep=";", index=False)

            # Events on out-axis for masks/centers
            # NOTE(review): unlike the labels_map and legacy rows below, these
            # events are not filtered by the pre-peak cutoff here; presumably
            # build_dense_mask / build_center_labels handle that - verify.
            events_out = []
            for e in events:
                eo = dict(e)
                eo["peak_center_molw"] = mu_out(float(e["peak_center_molw"]))
                events_out.append(eo)

            mask_df = build_dense_mask(molw_out, events_out, filename, cfg)
            mask_df.insert(0, "index", tpl.df["index"].to_numpy(int))
            mask_df.to_csv(os.path.join(masks_dir, f"labels_mask_{mux_name}_{plant_tag}.csv"), sep=";", index=False)

            center_df = build_center_labels(molw_out, events_out, filename, cfg)
            center_df.insert(0, "index", tpl.df["index"].to_numpy(int))
            center_df.to_csv(os.path.join(centers_dir, f"labels_center_{mux_name}_{plant_tag}.csv"), sep=";", index=False)

            # labels_map rows
            # One row per event that lies at or above the cutoff on the output
            # axis and passes passes_label_filter; both the base-axis and the
            # output-axis centers are stored, along with the warp metadata.
            for e in events:
                mu_base = float(e["peak_center_molw"])
                muo = mu_out(mu_base)
                if e["channel"].startswith("channel_") and muo < cfg.prepeak_cutoff_bp:
                    continue
                if not passes_label_filter(e, cfg):
                    continue
                labels_map_rows.append(dict(
                    filename=e["filename"],
                    channel=e["channel"],
                    peak_center_molw=float(mu_base),
                    peak_center_molw_out=float(muo),
                    peak_height=float(e["peak_height"]),
                    peak_type=str(e.get("peak_type", "")),
                    mode=mode,
                    mux_shift=mux_shift,
                    molw_scale=float(molw_meta["molw_scale"]),
                    molw_shift=float(molw_meta["molw_shift"]),
                    warp_kind=str(molw_meta["warp_kind"]),
                    warp_k=float(molw_meta["warp_k"]),
                    molw_jitter_amp=float(molw_meta["molw_jitter_amp"]),
                    bp_min_stats=float(cfg.template_stats_min_bp),
                    prepeak_cutoff=float(cfg.prepeak_cutoff_bp),
                ))

            # peak_positions_detailed (legacy): only main/spurious
            # Stutter and crosstalk events are skipped; "true" is reported as
            # "main". Positions are on the output axis.
            for e in events:
                ch_name = e["channel"]
                if ch_name not in ("channel_1", "channel_2", "channel_3", "channel_4"):
                    continue
                ptype = str(e.get("peak_type", ""))
                if ptype not in ("true", "spurious"):
                    continue

                ch_num = int(ch_name.split("_")[1])
                muo = mu_out(float(e["peak_center_molw"]))
                if muo < cfg.prepeak_cutoff_bp:
                    continue

                peak_kind = "main" if ptype == "true" else "spurious"
                sigma_pb = float(e.get("sigma", 0.0) or 0.0)
                peak_width_pb = cfg.legacy_peak_width_mult * sigma_pb
                amplitude = float(e.get("peak_height", 0.0) or 0.0)

                # marker_id is set to the channel number; amp_class is only
                # computed for main peaks and is NaN for spurious ones.
                legacy_rows.append(dict(
                    plant_id=plant_id,
                    plant_id_str=plant_id_str,
                    multiplex=int(mux_idx),
                    channel=int(ch_num),
                    marker_id=int(ch_num),
                    peak_kind=str(peak_kind),
                    mu_pb=float(muo),
                    sigma_pb=float(sigma_pb),
                    peak_width_pb=float(peak_width_pb),
                    amplitude=float(amplitude),
                    amp_class=(amp_class(amplitude, cfg) if peak_kind == "main" else np.nan),
                    plant_peak_width_max_pb=float(cfg.legacy_plant_peak_width_max_pb),
                ))

            # Optional plots
            # Five per-channel plots plus one stacked plot; each one advances
            # the progress bar by one step.
            if make_plots:
                chan_dir = os.path.join(by_channel_dir, plant_tag, mux_name)
                ensure_dir(chan_dir)
                for c in range(1, 6):
                    plot_channel(out, os.path.join(chan_dir, f"{mux_name}_{plant_tag}_channel_{c}.png"),
                                 f"channel_{c}", cfg.plot_qlo, cfg.plot_qhi)
                    done_plots += 1
                    progress(done_plots, total_plots)

                plot_stacked(out, os.path.join(stacked_dir, f"{mux_name}_{plant_tag}_STACKED.png"),
                             cfg.plot_qlo, cfg.plot_qhi)
                done_plots += 1
                progress(done_plots, total_plots)

    # Write global files
    # labels_map.csv uses ";" as separator; peak_positions_detailed.csv is
    # written with the to_csv default separator.
    pd.DataFrame(labels_map_rows).to_csv(os.path.join(map_dir, "labels_map.csv"), sep=";", index=False)
    build_peak_positions_detailed(legacy_rows).to_csv(os.path.join(map_dir, "peak_positions_detailed.csv"), index=False)

    return dict(
        out_dir=out_dir,
        signals_dir=signals_dir,
        labels_masks_dir=masks_dir,
        labels_centers_dir=centers_dir,
        labels_map_dir=map_dir,
        plots_dir=(plots_root if make_plots else None),
    )

# =============================================================================
# Run section
# =============================================================================

# Template files for the two multiplexes and the default output folder.
TEMPLATE_M1 = "M1_pl1.csv"
TEMPLATE_M2 = "M2_pl1.csv"
DEFAULT_OUT = "dataset_synthetic"

# Number of plants: blank or non-integer input falls back to 20.
try:
    n = int(input("How many individuals (plants) to generate? [20]: ") or "20")
except ValueError:
    n = 20

make_plots = (input("Generate plots? (y/n) [n]: ").strip().lower() == "y")

# Seed: blank means "random". A non-integer entry is reported and treated
# like a blank one instead of aborting the run with a ValueError.
seed_in = input("Seed (empty = random): ").strip()
try:
    seed = int(seed_in) if seed_in else None
except ValueError:
    print(f"Invalid seed {seed_in!r}; using a random seed instead.")
    seed = None

out_dir = input(f"Output folder (empty = {DEFAULT_OUT}): ").strip() or DEFAULT_OUT

paths = generate_dataset(
    template_m1_path=TEMPLATE_M1,
    template_m2_path=TEMPLATE_M2,
    n_plants=n,
    out_dir=out_dir,
    cfg=CFG,
    make_plots=make_plots,
    seed=seed,
)

print("\nDone.")
for k, v in paths.items():
    print(f"- {k}: {v}")


How many individuals (plants) to generate? [20]:  20
Generate plots? (y/n) [n]:  y
Seed (empty = random):  
Output folder (empty = dataset_synthetic):  


Generating plots: [############################] 240/240 (100.0%)

Done.
- out_dir: dataset_synthetic
- signals_dir: dataset_synthetic/signals
- labels_masks_dir: dataset_synthetic/labels_masks
- labels_centers_dir: dataset_synthetic/labels_centers
- labels_map_dir: dataset_synthetic/labels_map
- plots_dir: dataset_synthetic/plots
