In [1]:
# HYBRID: flat‑aware rotation/placement  ➜ optional curvature growth ➜ merge
import importlib
import numpy as np
import torch
import spike_aware_augmentation
import spike_aware_v1
import dataset
from spike_aware_augmentation import rotate_and_place_region_flataware2
importlib.reload(spike_aware_augmentation)
importlib.reload(spike_aware_v1)
importlib.reload(dataset)
from spike_aware_v1 import multi_apex_expand_mix, contour_curvature_apices
importlib.reload(spike_aware_v1)
import random
from spike_aware_augmentation import rotate_and_place_region_flataware2
import os, random, numpy as np, torch
from PIL import Image
import cv2
from dataset import MaskOnlyDataset,label_to_rgb, translate_1d_to_rgb,extract_connected_regions
def merge_grown_into_global(global_mask_tensor,
                            grown_binary_uint8,
                            label_value,
                            allow_overlap_labels=(3,)):
    """
    Paint a grown binary region into a label mask, touching background only.

    Parameters
    ----------
    global_mask_tensor : torch.Tensor
        Label mask. It is not modified; a CPU copy is edited and returned.
    grown_binary_uint8 : array-like
        Array broadcastable against the mask; entries equal to 1 are the
        candidate pixels to paint.
    label_value : int
        Label written into the painted pixels.
    allow_overlap_labels : iterable of int
        Kept for interface compatibility. It has no effect: only pixels that
        are currently background (0) are painted, so every existing non-zero
        label (including `label_value` and these labels) is left as-is.

    Returns
    -------
    torch.Tensor
        New CPU tensor with the same dtype as the input mask.
    """
    # detach()/cpu() make this safe for grad-tracking or non-CPU tensors;
    # copy() guarantees the caller's tensor storage is never written to.
    g = global_mask_tensor.detach().cpu().numpy().copy()
    src = np.asarray(grown_binary_uint8).astype(np.uint8)

    # Only paint where the destination is background and the source is set.
    paint = (src == 1) & (g == 0)
    g[paint] = label_value
    return torch.from_numpy(g)
def merge_add_remove_into_global(global_mask_tensor,
                                 add_binary_uint8,
                                 remove_binary_uint8,
                                 label_value,
                                 allow_overlap_labels=(3,)):
    """
    Apply removals for `label_value` first, then additions.

    - remove: set to 0 where (global == label_value) & (remove == 1);
      pixels carrying any other label are never removed.
    - add: paint `label_value` where (global == 0) & (add == 1), evaluated
      AFTER the removals, so a pixel flagged in both masks ends up painted.

    Parameters
    ----------
    global_mask_tensor : torch.Tensor
        Label mask. It is not modified; a CPU copy is edited and returned.
    add_binary_uint8, remove_binary_uint8 : array-like
        Arrays broadcastable against the mask; entries equal to 1 are active.
    label_value : int
        Label that is removed and/or painted.
    allow_overlap_labels : iterable of int
        Kept for interface compatibility. It has no effect: additions only
        ever land on background, so other labels are always preserved.

    Returns
    -------
    torch.Tensor
        New CPU tensor with the same dtype as the input mask.
    """
    # detach()/cpu() make this safe for grad-tracking or non-CPU tensors;
    # copy() guarantees the caller's tensor storage is never written to.
    g = global_mask_tensor.detach().cpu().numpy().copy()
    add = (np.asarray(add_binary_uint8).astype(np.uint8) == 1)
    rem = (np.asarray(remove_binary_uint8).astype(np.uint8) == 1)

    # 1) removals only affect pixels that currently hold this label
    g[(g == label_value) & rem] = 0

    # 2) additions only land on background (which includes freshly removed pixels)
    g[add & (g == 0)] = label_value

    return torch.from_numpy(g)
def _sample(v, is_int=False):
    if isinstance(v, (tuple, list)) and len(v) == 2:
        a, b = v
        return int(random.randint(int(a), int(b))) if is_int else float(random.uniform(float(a), float(b)))
    return v

def _class_pick(cfg, lbl, section, key, default=None, is_int=False):
    """
    Look up cfg[lbl][section][key], falling back to `default` when any level
    is missing or the lookup itself fails, then resolve the value with
    `_sample` (so a (min, max) pair is sampled, a scalar passes through).
    """
    try:
        label_cfg = cfg.get(lbl, {})
        section_cfg = label_cfg.get(section, {})
        value = section_cfg.get(key, default)
    except Exception:
        # Best-effort lookup: a malformed config falls back to the default.
        value = default
    return _sample(value, is_int=is_int)

def _class_pick_mode(cfg, lbl, key, mode, default=None, is_int=False):
    """
    Resolve a growth parameter that may be configured per mode.

    Reads cfg[lbl]["grow"][key]. When that entry is a dict containing `mode`
    (e.g. {"expand": (1.2, 1.6), "shrink": (0.9, 1.2)}), the per-mode value is
    used; when it is any other non-None value, that shared value is used;
    otherwise `default` is used. The selected value is resolved by `_sample`.
    """
    entry = cfg.get(lbl, {}).get("grow", {}).get(key, None)

    if entry is None:
        selected = default
    elif isinstance(entry, dict) and mode in entry:
        selected = entry[mode]
    else:
        selected = entry

    return _sample(selected, is_int=is_int)

def remove_small_components(mask_tensor: torch.Tensor,
                            min_size: int = 10,
                            labels_to_filter=(1, 2),
                            connectivity: int = 8) -> torch.Tensor:
    """
    Remove connected components with area < min_size for the given labels.

    Labels not listed in `labels_to_filter` (e.g., ships=3) are left unchanged.

    Parameters
    ----------
    mask_tensor : torch.Tensor
        2D integer label mask with values in {0,1,2,3,...}. Not modified.
    min_size : int
        Components with fewer pixels than this are set to background (0).
    labels_to_filter : iterable of int
        Labels whose components are filtered.
    connectivity : int
        Passed to `cv2.connectedComponentsWithStats`.

    Returns
    -------
    torch.Tensor
        New CPU tensor of dtype uint8 (the result is cast to uint8 regardless
        of the input dtype).

    Raises
    ------
    ValueError
        If `mask_tensor` is not 2D.
    """
    # astype() returns a fresh array, so the caller's tensor is never edited.
    m = mask_tensor.detach().cpu().numpy().astype(np.int32)
    if m.ndim != 2:
        raise ValueError(f"mask_tensor must be 2D, got shape {m.shape}")

    for lbl in labels_to_filter:
        binmap = (m == lbl).astype(np.uint8)
        if binmap.sum() < min_size:
            # The whole label has fewer pixels than min_size, so no component
            # can be large enough to keep; skip the labeling pass entirely.
            m[binmap == 1] = 0
            continue

        # OpenCV returns: n, labels_img, stats, centroids
        # stats[:, cv2.CC_STAT_AREA] gives area in pixels
        n, labimg, stats, _ = cv2.connectedComponentsWithStats(
            binmap, connectivity=connectivity, ltype=cv2.CV_32S
        )
        if n <= 1:
            # only background
            continue

        areas = stats[1:, cv2.CC_STAT_AREA]  # skip background at index 0
        small_ids = np.where(areas < min_size)[0] + 1  # shift back (+1)
        if len(small_ids) == 0:
            continue

        # zero out the small components for this label
        m[np.isin(labimg, small_ids)] = 0

    return torch.from_numpy(m.astype(np.uint8))

def _select_apices_for_region(apices: np.ndarray,
                              region_mask: np.ndarray,
                              n_apices: int,
                              random_state: int = 42) -> np.ndarray:
    """
    Mirror of the selection policy in multi_apex_expand_mix:
      - KMeans with k = min(n_apices, len(apices))
      - pick farthest-from-centroid point within each cluster
    Returns an array of shape (k, 2) with (col, row).
    """
    if apices is None or len(apices) == 0 or n_apices <= 0:
        return np.empty((0, 2), dtype=int)

    rows, cols = np.nonzero(region_mask)
    cy, cx = rows.mean(), cols.mean()

    k = int(min(n_apices, len(apices)))
    if k <= 0:
        return np.empty((0, 2), dtype=int)

    from sklearn.cluster import KMeans
    km = KMeans(n_clusters=k, n_init='auto', random_state=random_state).fit(apices)
    labels = km.labels_

    chosen = []
    for cluster_id in range(k):
        idxs = np.where(labels == cluster_id)[0]
        if len(idxs) == 0:
            continue
        # farthest from (cx, cy) in that cluster
        d = np.hypot(apices[idxs,0] - cx, apices[idxs,1] - cy)
        pick = idxs[np.argmax(d)]
        chosen.append(apices[pick])

    if len(chosen) == 0:
        return np.empty((0, 2), dtype=int)

    return np.vstack(chosen).astype(int)



def _plot_hybrid_debug_summary(before_np, placed_np, after_np,
                               global_add, global_remove, apex_records,
                               debug_save_dir=None, debug_prefix=""):
    """
    Draw (and optionally save) the four-panel debug figure for one image.

    Panels: original mask, post-placement mask with the selected apices,
    overlay of the curvature edits (additions / removals), and the final mask.
    The figure is always closed, even if drawing or saving raises.
    """
    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D
    from matplotlib.patches import Patch

    # ----------------------- base RGBs -----------------------
    orig_rgb   = translate_1d_to_rgb(before_np)   # pre-placement
    placed_rgb = translate_1d_to_rgb(placed_np)   # post-placement
    final_rgb  = translate_1d_to_rgb(after_np)    # final

    # ------------------- unions (curvature) -------------------
    remove_only = global_remove.astype(bool)
    # additions that were not also flagged by some removal
    add_only = global_add.astype(bool) & ~remove_only

    fig, axes = plt.subplots(1, 4, figsize=(24, 6))
    try:
        # 1) Original
        axes[0].imshow(orig_rgb)
        axes[0].set_title("Original (Pre-Placement)")
        axes[0].axis('off')

        # 2) Post-Placement + Apices
        axes[1].imshow(placed_rgb)
        axes[1].set_title("Post-Placement + Selected Apices")
        axes[1].axis('off')

        # plot apices + "<region>-<apex>" labels
        for rec in apex_records:
            pts  = rec.get("chosen", None)
            ridx = rec.get("region_idx", -1)
            if pts is None or len(pts) == 0:
                continue
            xs, ys = pts[:, 0], pts[:, 1]
            axes[1].scatter(xs, ys, s=80, c='red', edgecolors='k', linewidths=1, alpha=1.0)
            for j, (x, y) in enumerate(zip(xs, ys)):
                axes[1].text(
                    x + 2, y + 2, f"{ridx}-{j}", fontsize=7, color='black',
                    bbox=dict(facecolor='white', alpha=0.9, edgecolor='none', pad=0.5)
                )

        # small legend inside subplot 2
        legend_apex = [
            Line2D([0], [0], marker='o', color='w', label='Selected Apices',
                   markerfacecolor='red', markeredgecolor='k', markersize=8)
        ]
        axes[1].legend(handles=legend_apex, loc='lower left',
                       fontsize=8, frameon=True, facecolor='white', framealpha=0.8)

        # 3) Curvature edits overlay
        axes[2].imshow(placed_rgb, interpolation='nearest')

        H, W = add_only.shape
        highlight = (add_only | remove_only)            # regions we want to keep visible

        # --- Dim everything else (dark layer with holes on highlights) ---
        dim_rgba = np.zeros((H, W, 4), dtype=np.float32)
        dim_rgba[..., :3] = 0.0                         # black
        dim_rgba[..., 3]  = 0.75                        # 75% dim over the whole image
        dim_rgba[..., 3][highlight] = 0.0               # no dim on add/shrink
        axes[2].imshow(dim_rgba, interpolation='nearest')

        # --- Paint add/remove overlays on top ---
        add_rgba = np.zeros((H, W, 4), dtype=np.float32)
        add_rgba[..., 1] = 1.0                          # green
        add_rgba[..., 3] = add_only.astype(np.float32) * 0.95
        axes[2].imshow(add_rgba, interpolation='nearest')

        rmv_rgba = np.zeros((H, W, 4), dtype=np.float32)
        rmv_rgba[..., 0] = 1.0                          # magenta = R+B
        rmv_rgba[..., 2] = 1.0
        rmv_rgba[..., 3] = remove_only.astype(np.float32) * 0.95
        axes[2].imshow(rmv_rgba, interpolation='nearest')

        axes[2].set_title("Curvature Edits (+Add / −Shrink)")
        axes[2].axis('off')

        # Legend inside the panel (bottom-left)
        legend_elems_sub3 = [
            Patch(facecolor='limegreen', edgecolor='none', label='+ Expansion'),
            Patch(facecolor='magenta',   edgecolor='none', label='− Shrink'),
        ]
        axes[2].legend(handles=legend_elems_sub3, loc='lower left',
                       fontsize=8, frameon=True, facecolor='white', framealpha=0.85)

        # 4) Final
        axes[3].imshow(final_rgb)
        axes[3].set_title("Final (Post-Growth/Shrink)")
        axes[3].axis('off')

        # layout + save
        plt.tight_layout()
        if debug_save_dir is not None:
            os.makedirs(debug_save_dir, exist_ok=True)
            out_path = os.path.join(debug_save_dir, f"{debug_prefix}_GLOBAL_summary_rgb.png")
            plt.savefig(out_path, dpi=150, bbox_inches='tight')
            print(f"[debug] saved per-image RGB summary → {out_path}")

        plt.show()
    finally:
        # Release the figure even when drawing/saving failed part-way.
        plt.close(fig)


def hybrid_flataware_then_curvature(
    sample_mask,
    label_value=1,
    max_shift=50,          # int or (min,max)  -> sampled PER REGION
    attempts=30,           # int or (min,max)  -> sampled PER REGION
    do_curvature=True,
    n_apices=3,            # int or (min,max)  -> sampled PER REGION
    alpha_deg=40,          # float or (min,max)-> sampled PER REGION
    scale=1.25,            # float or (min,max)-> sampled PER REGION
    max_radius=60,         # int or (min,max)  -> sampled PER REGION
    n_regions=8,
    n_rays=70,             # int
    region_pick="random",
    class_cfg: dict | None = None,
    growth_pick="largest",
    candidate_labels=None,
    p_expand=0.6,
    debug=False,
    enforce_label_diversity: bool = False,
    debug_save_dir: str | None = None,
    debug_prefix: str = "",
    per_region_debug: bool = False,
):
    """
    Two-stage mask augmentation: region placement, then optional curvature edits.

    Stage 1 picks up to `n_regions` connected regions (of `label_value`, or of
    every label in `candidate_labels` when given) and passes each one through
    `rotate_and_place_region_flataware2`. Stage 2 (when `do_curvature`) picks
    regions from the placed mask with `growth_pick`, finds curvature apices,
    obtains add/remove masks from `multi_apex_expand_mix`, merges them into the
    mask and drops label-1/2 components smaller than 500 px after each region.

    For any of {max_shift, attempts, n_apices, alpha_deg, scale, max_radius}:
      - scalar => fixed
      - (min,max) => sampled independently for EACH region
    `class_cfg[label]["place"|"grow"][...]` overrides these per label.

    Other parameters
    ----------------
    region_pick, growth_pick : "largest" | "smallest" | anything else (random)
        Selection policy for stage 1 / stage 2 regions.
    enforce_label_diversity : bool
        With >= 2 pooled labels and n_regions >= 2, split the stage-1 picks
        between the two lowest-numbered labels.
    debug, debug_save_dir, debug_prefix, per_region_debug
        Control the summary figure and the per-region debug output that is
        forwarded to `multi_apex_expand_mix`.

    Returns
    -------
    The augmented mask tensor. When no source region exists the input
    `sample_mask` object itself is returned unchanged.
    """
    cfg = class_cfg or {}

    def _area(region_dict):
        return int(np.sum(region_dict['binary_mask']))

    def _select_regions(regions, k, mode):
        if not regions:
            return []
        k = min(k, len(regions))
        if mode == "largest":
            return sorted(regions, key=_area, reverse=True)[:k]
        if mode == "smallest":
            return sorted(regions, key=_area)[:k]
        return random.sample(regions, k=k)

    # ---- Build the source pools, grouped by label ----
    regions_by_label = {}
    labels_to_pool = [label_value] if candidate_labels is None else list(candidate_labels)
    for lbl in labels_to_pool:
        regs = extract_connected_regions(sample_mask, lbl)
        if regs:
            regions_by_label[int(lbl)] = []
            for r in regs:
                rr = dict(r); rr['label'] = int(lbl)
                regions_by_label[int(lbl)].append(rr)

    # Flattened pool (for fallback selection)
    regions_src = [r for regs in regions_by_label.values() for r in regs]
    if not regions_src:
        return sample_mask

    # Balanced pick when asked and feasible (needs ≥2 labels & ≥2 regions)
    def _balanced_pick(regs_by_lbl: dict[int, list], k: int):
        labels = sorted(regs_by_lbl.keys())
        if len(labels) < 2:
            pool = [r for ll in labels for r in regs_by_lbl[ll]]
            return _select_regions(pool, k, region_pick)
        a, b = labels[0], labels[1]
        k_a = k // 2
        k_b = k - k_a
        A = _select_regions(regs_by_lbl[a], k_a, region_pick)
        B = _select_regions(regs_by_lbl[b], k_b, region_pick)
        # If one side is short, fill remaining from the other
        missing = k - (len(A) + len(B))
        if missing > 0:
            # Compare by identity: the region dicts hold numpy arrays, so a
            # plain `r not in A` falls back to dict equality and can raise
            # "truth value of an array is ambiguous".
            taken = {id(r) for r in A} | {id(r) for r in B}
            rest = [r for r in regs_by_lbl[a] + regs_by_lbl[b] if id(r) not in taken]
            A += _select_regions(rest, missing, region_pick)
        return A + B

    if enforce_label_diversity and len(regions_by_label) >= 2 and n_regions >= 2:
        chosen_src = _balanced_pick(regions_by_label, n_regions)
    else:
        chosen_src = _select_regions(regions_src, n_regions, region_pick)

    # ---- Step 1: per-REGION placement (translate/rotate) ----
    aug_mask = sample_mask.clone()
    Hf, Wf = aug_mask.shape
    # Unions of every curvature edit; only consumed by the debug figure.
    global_add    = np.zeros((Hf, Wf), dtype=bool)
    global_remove = np.zeros((Hf, Wf), dtype=bool)

    # Keep a copy of the original for final plot
    before_np = sample_mask.detach().cpu().numpy().astype(np.uint8)
    for reg in chosen_src:
        lbl = int(reg.get('label', label_value))

        # placement params (class-aware → fallback to global)
        ms  = _class_pick(cfg, lbl, "place", "max_shift", default=max_shift, is_int=True)
        att = _class_pick(cfg, lbl, "place", "attempts",  default=attempts,  is_int=True)

        aug_mask = rotate_and_place_region_flataware2(
            aug_mask,
            reg,
            label_value=lbl,
            max_shift=ms,
            attempts=att,
            n_rays=20,
            debug=False,
        )
    placed_np = aug_mask.detach().cpu().numpy().astype(np.uint8)
    if not do_curvature:
        return aug_mask

    # ---- Step 2: choose regions to grow (across labels if needed) ----
    labels_for_growth = tuple(regions_by_label.keys()) if candidate_labels is not None else (label_value,)
    regions_after = []
    for lbl in labels_for_growth:
        regs = extract_connected_regions(aug_mask, lbl)
        for r in regs:
            rr = dict(r); rr['label'] = int(lbl)
            regions_after.append(rr)

    if not regions_after:
        return aug_mask

    LARGE_OIL_FRACTION = 0.40  # 40% of the tile
    Hf, Wf = aug_mask.shape

    # Largest connected label-1 (oil) region, if it covers >= LARGE_OIL_FRACTION
    oil_regions = [r for r in regions_after if int(r.get('label', label_value)) == 1]
    large_oil_region = None
    if oil_regions:
        areas = [int(np.sum(r['binary_mask'])) for r in oil_regions]
        i_max = int(np.argmax(areas))
        area_frac_max = areas[i_max] / float(Hf * Wf + 1e-9)
        if area_frac_max >= LARGE_OIL_FRACTION:
            large_oil_region = oil_regions[i_max]

    # ---- Choose regions to grow
    chosen_to_grow = _select_regions(regions_after, len(chosen_src), growth_pick)

    # Ensure the large oil region (if any) is always processed: put it first
    # and trim back to the original budget (but never below one region).
    if large_oil_region is not None and not any(r is large_oil_region for r in chosen_to_grow):
        budget = max(1, len(chosen_src))
        chosen_to_grow = ([large_oil_region] + chosen_to_grow)[:budget]

    apex_records = []  # dicts with {label, region_idx, chosen} for the debug plot

    # ---- Step 3: per-REGION curvature (grow/shrink) with class-aware params ----
    for r_idx, r in enumerate(chosen_to_grow):
        region_mask = r['binary_mask'].astype(np.uint8)
        lbl = int(r.get('label', label_value))

        # class-aware growth params (fallback to global ranges)
        pexp   = _class_pick(cfg, lbl, "grow", "p_expand",   default=p_expand)
        naps   = _class_pick(cfg, lbl, "grow", "n_apices",   default=n_apices,  is_int=True)
        a_deg  = _class_pick(cfg, lbl, "grow", "alpha_deg",  default=alpha_deg)
        rays   = _class_pick(cfg, lbl, "grow", "n_rays",     default=n_rays,    is_int=True)

        # per-mode values (already sampled down to scalars when ranges are given)
        scale_expand  = _class_pick_mode(cfg, lbl, "scale",      mode="expand", default=scale)
        scale_shrink  = _class_pick_mode(cfg, lbl, "scale",      mode="shrink", default=scale)
        maxrad_expand = _class_pick_mode(cfg, lbl, "max_radius", mode="expand", default=max_radius, is_int=True)
        maxrad_shrink = _class_pick_mode(cfg, lbl, "max_radius", mode="shrink", default=max_radius, is_int=True)

        # Aggressive policy for very large oil blobs
        if lbl == 1:  # oil only
            area_frac = float(region_mask.sum()) / float(Hf * Wf + 1e-9)
            is_large_target = r is large_oil_region
            if is_large_target or (area_frac >= LARGE_OIL_FRACTION):
                # Strong bias to SHRINK with irregularity
                pexp  = min(pexp, 0)       # caps p_expand at 0 (presumably: never expand — confirm in multi_apex_expand_mix)
                naps  = max(naps, 20)      # more apices → more jaggedness
                a_deg = max(a_deg, 190.0)  # wider lobes
                rays  = max(rays, 200)     # denser sampling

                # NOTE(review): the values above were already sampled to
                # scalars, so these branches normally all fire. The expand
                # values become (min, max) tuples while the shrink values
                # become sampled floats — confirm multi_apex_expand_mix
                # accepts both forms.
                if not isinstance(scale_expand, (tuple, list)):
                    scale_expand = (1.0, 1.15)                  # almost no global expansion
                if not isinstance(scale_shrink, (tuple, list)):
                    scale_shrink = random.uniform(5, 8.0)       # strong shrink
                if not isinstance(maxrad_expand, (tuple, list)):
                    maxrad_expand = (10, 40)
                if not isinstance(maxrad_shrink, (tuple, list)):
                    maxrad_shrink = random.uniform(150, 600)    # long reach to carve edges

        apices, _, _ = contour_curvature_apices(
            region_mask,
            smooth_window=9, polyorder=3,
            peak_prom_quantile=0.90, min_separation_px=3,
            radial_boost=0.5, deriv_step=2,
            debug=debug
        )
        if len(apices) == 0:
            continue
        # Recorded for the debug plot only; multi_apex_expand_mix receives the
        # full apex list and does its own selection.
        chosen_pts = _select_apices_for_region(apices, region_mask, naps)
        apex_records.append({
            "label": lbl,
            "region_idx": r_idx,
            "chosen": chosen_pts,   # (m,2) (col,row)
        })

        add_mask, remove_mask = multi_apex_expand_mix(
            region_mask = region_mask,
            apices      = apices,
            n_apices    = naps,
            alpha_deg   = a_deg,
            scale_expand       = scale_expand,
            scale_shrink       = scale_shrink,
            max_radius_expand  = maxrad_expand,
            max_radius_shrink  = maxrad_shrink,
            p_expand    = pexp,
            n_rays      = rays,
            debug       = debug and per_region_debug,
            return_components=True,
            debug_save_dir =  (debug_save_dir if per_region_debug else None),
            debug_prefix   = debug_prefix
        )
        global_add    |= (add_mask.astype(bool))
        global_remove |= (remove_mask.astype(bool))
        aug_mask = merge_add_remove_into_global(
            global_mask_tensor = aug_mask,
            add_binary_uint8   = add_mask,
            remove_binary_uint8= remove_mask,
            label_value        = lbl,
            allow_overlap_labels=(3,)
        )
        # Clean up fragments created by this region's edits before the next one.
        aug_mask = remove_small_components(
            aug_mask, min_size=500, connectivity=4
        )
    after_np = aug_mask.detach().cpu().numpy().astype(np.uint8)
    if debug:
        try:
            _plot_hybrid_debug_summary(
                before_np, placed_np, after_np,
                global_add, global_remove, apex_records,
                debug_save_dir=debug_save_dir,
                debug_prefix=debug_prefix,
            )
        except Exception as e:
            # Debug plotting is best-effort and must never break augmentation.
            print(f"[debug] plotting failed: {e}")

    return aug_mask

In [2]:
import os, random, numpy as np, torch
from PIL import Image
import cv2
# your modules

# ------------------------ helpers ------------------------

def _resolve_scalar(v, is_int=False):
    """
    If v is a 2-tuple/list -> sample uniform in [v[0], v[1]] (int or float).
    Else return v.
    """
    if isinstance(v, (tuple, list)) and len(v) == 2:
        a, b = v
        if is_int:
            return int(random.randint(int(a), int(b)))
        else:
            return float(random.uniform(float(a), float(b)))
    return v

def _choose_target_label(mask_np, candidates=(1, 2)):
    """
    Return one label to augment based on presence:
      - if exactly one of {1,2} present -> return that one
      - if both present -> random choice
      - if none present -> return None
    """
    present = [lbl for lbl in candidates if (mask_np == lbl).any()]
    if not present:
        return None
    if len(present) == 1:
        return present[0]
    return random.choice(present)
def remove_small_components(mask_tensor: torch.Tensor,
                            min_size: int = 10,
                            labels_to_filter=(1, 2),
                            connectivity: int = 8) -> torch.Tensor:
    """
    Remove connected components with area < min_size for the given labels.

    Labels not listed in `labels_to_filter` (e.g., ships=3) are left unchanged.

    Parameters
    ----------
    mask_tensor : torch.Tensor
        2D integer label mask with values in {0,1,2,3,...}. Not modified.
    min_size : int
        Components with fewer pixels than this are set to background (0).
    labels_to_filter : iterable of int
        Labels whose components are filtered.
    connectivity : int
        Passed to `cv2.connectedComponentsWithStats`.

    Returns
    -------
    torch.Tensor
        New CPU tensor of dtype uint8 (the result is cast to uint8 regardless
        of the input dtype).

    Raises
    ------
    ValueError
        If `mask_tensor` is not 2D.
    """
    # astype() returns a fresh array, so the caller's tensor is never edited.
    m = mask_tensor.detach().cpu().numpy().astype(np.int32)
    if m.ndim != 2:
        raise ValueError(f"mask_tensor must be 2D, got shape {m.shape}")

    for lbl in labels_to_filter:
        binmap = (m == lbl).astype(np.uint8)
        if binmap.sum() < min_size:
            # The whole label has fewer pixels than min_size, so no component
            # can be large enough to keep; skip the labeling pass entirely.
            m[binmap == 1] = 0
            continue

        # OpenCV returns: n, labels_img, stats, centroids
        # stats[:, cv2.CC_STAT_AREA] gives area in pixels
        n, labimg, stats, _ = cv2.connectedComponentsWithStats(
            binmap, connectivity=connectivity, ltype=cv2.CV_32S
        )
        if n <= 1:
            # only background
            continue

        areas = stats[1:, cv2.CC_STAT_AREA]  # skip background at index 0
        small_ids = np.where(areas < min_size)[0] + 1  # shift back (+1)
        if len(small_ids) == 0:
            continue

        # zero out the small components for this label
        m[np.isin(labimg, small_ids)] = 0

    return torch.from_numpy(m.astype(np.uint8))
# ------------------------ main batch driver ------------------------

def batch_hybrid_augment(
    label_dir: str,
    out_dir: str,
    num_images: int = 10,
    pick_mode: str = "random",
    indices: list[int] | None = None,
    # --- placement params (scalar or (min, max) interval) ---
    max_shift=50,
    attempts=30,
    # --- curvature params (scalar or (min, max) interval) ---
    do_curvature: bool = True,
    n_apices=6,
    alpha_deg=40,
    scale=1.3,
    max_radius=60,
    p_expand: float = 0.6,
    # --- multi-blob controls ---
    class_cfg: dict | None = None,
    n_regions=1,
    n_rays: int = 60,
    region_pick: str = "random",
    growth_pick: str = "largest",
    per_region_params: bool = True,
    # --- cross-label selection ---
    any_label_by_size: bool = False,
    candidate_labels = (1, 2),
    # --- IO & reproducibility ---
    save_rgb: bool = True,
    min_region_pixels: int = 10,
    random_seed: int | None = None,
    force_label: int | None = None,
    translate_only: bool = False,
    move_both_if_present: bool = False,
    debug: bool = False,
):
    """
    Run the hybrid augmentation over several masks from `label_dir`.

    Image selection: explicit `indices` (out-of-range ones dropped, truncated
    to `num_images`), else the first `num_images` when pick_mode == "first",
    else a random sample.

    Label selection per image:
      - force_label in (1, 2): use that label, skip the image if it is absent
      - any_label_by_size: pool regions from `candidate_labels`
        (balanced across labels when `move_both_if_present`)
      - otherwise: pick one of the present labels {1, 2}, skip if none

    `translate_only` disables the curvature stage. With `per_region_params`
    interval parameters are forwarded untouched (sampled per region
    downstream); otherwise they are sampled once per image here. After
    augmentation, components smaller than `min_region_pixels` are removed.

    Saves outputs in:
        <out_dir>/label_1d/aug_<tag>_label_<name>.png
        <out_dir>/color_rgb/aug_<tag>_rgb_<name>.png   (when save_rgb)

    Returns a dict with the counts ("processed", "saved", "skipped") and the
    output directories ("out_dir_rgb" is None when save_rgb is False).
    """
    # Output folders
    out_dir_1d = os.path.join(out_dir, "label_1d")
    out_dir_rgb = os.path.join(out_dir, "color_rgb")
    os.makedirs(out_dir_1d, exist_ok=True)
    if save_rgb:
        os.makedirs(out_dir_rgb, exist_ok=True)

    # Seed every RNG used downstream
    if random_seed is not None:
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)

    ds = MaskOnlyDataset(label_dir)
    total = len(ds)

    if translate_only:
        do_curvature = False  # placement only

    # Which dataset entries to process
    if indices is not None:
        idxs = [i for i in indices if 0 <= i < total][:num_images]
    elif pick_mode == "first":
        idxs = list(range(min(num_images, total)))
    else:
        idxs = random.sample(range(total), k=min(num_images, total))

    saved = 0
    skipped = 0
    for i in idxs:
        mask = ds[i]
        mask_np = mask.cpu().numpy()

        # ---- decide which label(s) this image is augmented for ----
        if force_label in (1, 2):
            if not (mask_np == force_label).any():
                skipped += 1
                continue
            label_to_expand = force_label
            cand_labels = None
            enforce_label_diversity = False   # single label, nothing to balance
        elif any_label_by_size:
            label_to_expand = None
            cand_labels = candidate_labels
            enforce_label_diversity = bool(move_both_if_present)
        else:
            label_to_expand = _choose_target_label(mask_np, candidates=(1, 2))
            if label_to_expand is None:
                skipped += 1
                continue
            cand_labels = None
            enforce_label_diversity = False

        # n_regions is always resolved once per image
        regions_this_image = _resolve_scalar(n_regions, is_int=True)

        # ---- interval params: forward as-is, or sample once per image ----
        if per_region_params:
            shift_arg, attempts_arg, apices_arg = max_shift, attempts, n_apices
            alpha_arg, scale_arg, radius_arg = alpha_deg, scale, max_radius
        else:
            shift_arg    = _resolve_scalar(max_shift, is_int=True)
            attempts_arg = _resolve_scalar(attempts,  is_int=True)
            apices_arg   = _resolve_scalar(n_apices,  is_int=True)
            alpha_arg    = _resolve_scalar(alpha_deg, is_int=False)
            scale_arg    = _resolve_scalar(scale,     is_int=False)
            radius_arg   = _resolve_scalar(max_radius, is_int=True)

        aug_mask = hybrid_flataware_then_curvature(
            sample_mask=mask,
            label_value=1 if label_to_expand is None else label_to_expand,
            max_shift=shift_arg,
            attempts=attempts_arg,
            do_curvature=do_curvature,
            n_apices=apices_arg,
            alpha_deg=alpha_arg,
            n_rays=n_rays,
            scale=scale_arg,
            max_radius=radius_arg,
            n_regions=regions_this_image,
            region_pick=region_pick,
            growth_pick=growth_pick,
            candidate_labels=cand_labels,
            class_cfg=class_cfg,
            p_expand=p_expand,
            debug=debug,
            enforce_label_diversity=enforce_label_diversity,
            debug_save_dir=out_dir_rgb,
            debug_prefix=f"img_{i}_",
            per_region_debug=False,
        )

        # ---- post-filter small leftovers for the labels we touched ----
        if any_label_by_size:
            labels_clean = candidate_labels
        elif label_to_expand is not None:
            labels_clean = (label_to_expand,)
        else:
            labels_clean = (1, 2)
        aug_mask = remove_small_components(
            aug_mask, min_size=min_region_pixels, labels_to_filter=labels_clean, connectivity=4
        )

        # ---- write outputs ----
        final_np = aug_mask.cpu().numpy().astype(np.uint8)
        base_no_ext = os.path.splitext(ds.label_files[i])[0]

        is_oil = (force_label == 1 or label_to_expand == 1)
        tag = "oil" if is_oil else "lookalike"

        out_1d_path = os.path.join(out_dir_1d, f"aug_{tag}_label_{base_no_ext}.png")
        Image.fromarray(final_np, mode="L").save(out_1d_path)

        if save_rgb:
            rgb_np = translate_1d_to_rgb(final_np)
            out_rgb_path = os.path.join(out_dir_rgb, f"aug_{tag}_rgb_{base_no_ext}.png")
            Image.fromarray(rgb_np).save(out_rgb_path)

        saved += 1

    summary = {
        "processed": len(idxs),
        "saved": saved,
        "skipped": skipped,
        "out_dir_1d": out_dir_1d,
        "out_dir_rgb": out_dir_rgb if save_rgb else None
    }
    return summary

In [92]:
# --- per-label configuration consumed by hybrid_flataware_then_curvature ---
# Layout: CLASS_CFG[label]["place" | "grow"][param].
# A (min, max) pair is sampled per region; a scalar is used as-is.
# "scale" and "max_radius" may be given per mode ("expand" / "shrink").
CLASS_CFG = {
    1: {  # oil
        "place": {"max_shift": (120, 140), "attempts": (25, 55)},
        "grow": {
            "p_expand": 0.6,
            "n_apices": (6, 8),
            "alpha_deg": (90, 180),
            "n_rays": 120,
            "scale": {           # per-mode (optional)
                "expand": (1.2, 1.4),
                "shrink": (1,2.5),
            },
            "max_radius": {      # per-mode (optional)
                "expand": (30, 80),
                "shrink": (150, 180),
            },
        },
    },
    2: {  # look-alike
        "place": {"max_shift": (30, 140), "attempts": (20, 45)},
        "grow": {
            "p_expand": 0.5,
            "n_apices": (5, 6),
            "alpha_deg": (90, 180),
            "n_rays": 120,
            "scale": {
                "expand": (2.3,2.6),
                "shrink": (3.5,3.6),
            },
            "max_radius": {
                "expand": (400,500),
                "shrink": (300,700 ),
            },
        },
    },
}

# --- single-image, two-class augmentation ---
label_dir = r"D:\AndreJuarez\Code_OilSpill_hpc\Figuras_datasets_articulo\morp_parte_out\data_temp_peru_label1d"
out_dir   = r"D:\AndreJuarez\Code_OilSpill_hpc\Figuras_datasets_articulo\morp_parte_out\version2_labels"
idx       = 0 # dataset index of the image to augment

result = batch_hybrid_augment(
    label_dir=label_dir,
    out_dir=out_dir,
    num_images=1,                 # process exactly one image
    indices=[idx],                # which one
    any_label_by_size=True,       # pool regions from all candidate labels
    candidate_labels=(1, 2),      # operate on oil & look-alike
    move_both_if_present=True,    # balanced pick across labels when possible
    n_regions=5,                  # regions to place per image (split across the two labels when balanced)
    class_cfg=CLASS_CFG,          # class-aware params (placement + growth)
    per_region_params=True,       # sample tuple params per region
    do_curvature=True,            # include the curvature stage
    min_region_pixels=2000,         # final filter: drop candidate-label components smaller than this
    save_rgb=True,                # also write a color preview
    random_seed=6,               # reproducibility (optional)
    debug = True,                  # show and save the 4-panel summary figure
)
print(result)

{'processed': 1, 'saved': 1, 'skipped': 0, 'out_dir_1d': 'D:\\AndreJuarez\\Code_OilSpill_hpc\\Figuras_datasets_articulo\\morp_parte_out\\version2_labels\\label_1d', 'out_dir_rgb': 'D:\\AndreJuarez\\Code_OilSpill_hpc\\Figuras_datasets_articulo\\morp_parte_out\\version2_labels\\color_rgb'}
