# Segmentation (Otsu luma, SageMaker-style: read-only input, write-only output)

This notebook is designed to run on local CPU while matching the SageMaker mental model:

- Treat `INPUT_ROOT` as **read-only**.
- Write all new artifacts under `OUTPUT_ROOT`.
- Use shard parameters so the same code can scale to SageMaker Processing.

## Input criteria (must already be true)

`INPUT_ROOT/` must contain:

- `plates_structured/` with exactly 435 plate directories named `plate-###`
- For each plate directory:
  - `manifest.json`
  - `source/<file>`: the immutable source image, at the path (relative to the plate directory) given by `source_image` in `manifest.json`
- `schemas/plate.manifest.schema.json`
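- `schemas/run.manifest.schema.json`
- `data.json` (only checked during auto-discovery of `INPUT_ROOT`; not checked when `BWS_INPUT_ROOT` is set)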

## Midpoint artifacts (written once)

- `OUTPUT_ROOT/schemas/segmentation.otsu.schema.json`
- `OUTPUT_ROOT/reports/<run_id>/report.json`

## Outputs (append-only, per plate)

For each plate in the shard:

- `OUTPUT_ROOT/plates_structured/<plate_id>/runs/<run_id>/metrics.json`
- `OUTPUT_ROOT/plates_structured/<plate_id>/runs/<run_id>/segmentation.json`
- `OUTPUT_ROOT/plates_structured/<plate_id>/runs/<run_id>/segmentation_mask.png`

## Constraints enforced

- No mutation of `INPUT_ROOT`.
- Append-only run outputs in `OUTPUT_ROOT`.
- `metrics.json` validates against `schemas/run.manifest.schema.json`.
- `segmentation.json` validates against `schemas/segmentation.otsu.schema.json` (written by this notebook).


In [None]:
import hashlib
import json
import os
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
from jsonschema import Draft202012Validator
from PIL import Image
from tqdm import tqdm

# Turn off Pillow's pixel-count limit (its decompression-bomb guard).
# NOTE(review): presumably needed because some source scans exceed the default
# limit; this is only appropriate while INPUT_ROOT holds trusted images.
Image.MAX_IMAGE_PIXELS = None

# Fallback for the longest side (in pixels) of the downsampled image that is
# thresholded; used when the BWS_MAX_DIM environment variable is not set.
DEFAULT_MAX_DIM = 1024

def utc_iso() -> str:
    """Return the current UTC time as an ISO-8601 string ending in ``Z``."""
    now_utc = datetime.now(tz=timezone.utc)
    iso_text = now_utc.isoformat()
    # isoformat() spells UTC as "+00:00"; swap in the shorter "Z" designator.
    return iso_text.replace("+00:00", "Z")

def generate_run_id(models: list[str], note: str) -> str:
    """Build a run identifier of the form ``run-<YYYYMMDD-HHMMSS>-<hash8>``.

    The timestamp is the current UTC time.  The 8-character suffix is the start
    of the SHA-1 hex digest of the model names and the note joined with ``|``,
    so it depends only on *models* and *note*.
    """
    fingerprint = hashlib.sha1(f"{'|'.join(models)}|{note}".encode("utf-8")).hexdigest()
    now_utc = datetime.now(timezone.utc)
    return "run-{}-{}".format(now_utc.strftime("%Y%m%d-%H%M%S"), fingerprint[:8])

def load_validator(path: Path) -> Draft202012Validator:
    """Read the JSON Schema stored at *path* and wrap it in a Draft 2020-12 validator.

    The schema document itself is checked with ``check_schema`` first, so a
    malformed schema is rejected here rather than when it is first used.
    """
    with path.open("r", encoding="utf-8") as handle:
        schema_doc = json.load(handle)
    validator_cls = Draft202012Validator
    validator_cls.check_schema(schema_doc)
    return validator_cls(schema_doc)

def resize_fit(img: Image.Image, max_dim: int) -> Image.Image:
    """Shrink *img* so that its longest side is at most *max_dim* pixels.

    An image that already fits is returned unchanged (same object).  Otherwise
    both sides are scaled by the same factor, rounded, clamped to at least one
    pixel, and the image is resampled bilinearly.
    """
    width, height = img.size
    longest = width if width >= height else height
    if longest <= max_dim:
        return img
    scale = max_dim / longest
    target = tuple(max(1, int(round(side * scale))) for side in (width, height))
    return img.resize(target, resample=Image.BILINEAR)

def otsu_threshold(hist: list[int]) -> int:
    """Pick a global threshold from a 256-bin histogram using Otsu's method.

    Each level ``t`` splits the histogram into a low class (levels ``<= t``)
    and a high class (levels ``> t``).  The returned ``t`` is the first level
    that maximises the between-class variance.  ``127`` is returned when the
    histogram is empty or when no split leaves both classes non-empty.
    """
    fallback = 127
    pixel_count = float(sum(hist))
    if pixel_count <= 0:
        return fallback

    # First moment of the whole histogram: sum of level * count.
    moment_all = 0.0
    for level, count in enumerate(hist):
        moment_all += level * float(count)

    best_level = fallback
    best_score = -1.0
    weight_low = 0.0
    moment_low = 0.0

    for level in range(256):
        count = float(hist[level])
        weight_low += count
        if weight_low > 0:
            weight_high = pixel_count - weight_low
            if weight_high <= 0:
                # All mass is now in the low class; no further split is possible.
                break
            moment_low += float(level) * count
            mean_low = moment_low / weight_low
            mean_high = (moment_all - moment_low) / weight_high
            score = weight_low * weight_high * (mean_low - mean_high) ** 2
            # Strict comparison keeps the first level that reaches the maximum.
            if score > best_score:
                best_score = score
                best_level = level

    return int(best_level)


In [None]:
def find_input_root(start: Path) -> Path:
    """Locate the scaffolded dataset root.

    If the ``BWS_INPUT_ROOT`` environment variable is set, it is resolved and
    returned as long as the path exists; the scaffold layout is not checked in
    that case.  Otherwise *start* and each of its ancestors are examined in
    order, and for each of them its direct children (sorted by path) as well.
    The first directory that looks scaffolded wins.

    A directory counts as scaffolded when it contains ``plates_structured``,
    both manifest schemas under ``schemas/``, and ``data.json``.

    Args:
        start: Directory where the upward search begins.

    Returns:
        The resolved override path, or the first scaffolded directory found.

    Raises:
        RuntimeError: If ``BWS_INPUT_ROOT`` points to a missing path, or no
            scaffolded directory is found.
    """
    override = os.environ.get("BWS_INPUT_ROOT")
    if override:
        p = Path(override).expanduser().resolve()
        if p.exists():
            return p
        raise RuntimeError(f"BWS_INPUT_ROOT does not exist: {p}")

    required_entries = (
        Path("plates_structured"),
        Path("schemas", "plate.manifest.schema.json"),
        Path("schemas", "run.manifest.schema.json"),
        Path("data.json"),
    )

    def is_scaffolded(p: Path) -> bool:
        """Return True when *p* is a directory holding every required entry."""
        return p.is_dir() and all((p / entry).exists() for entry in required_entries)

    for base in [start, *start.parents]:
        if is_scaffolded(base):
            return base
        try:
            children = sorted(base.iterdir())
        except OSError:
            # Best effort: an unreadable ancestor must not abort the search.
            continue
        for child in children:
            try:
                if is_scaffolded(child):
                    return child
            except OSError:
                # Skip only the unreadable child; later siblings are still tried.
                continue

    raise RuntimeError("Could not find scaffolded INPUT_ROOT; set BWS_INPUT_ROOT")

# Run configuration, read from BWS_* environment variables with local defaults.
INPUT_ROOT = find_input_root(Path.cwd())
# NOTE(review): the default OUTPUT_ROOT is a subdirectory of INPUT_ROOT, which sits
# uneasily with treating INPUT_ROOT as read-only; set BWS_OUTPUT_ROOT to keep them apart.
OUTPUT_ROOT = Path(os.environ.get("BWS_OUTPUT_ROOT", str(INPUT_ROOT / "_RUN_OUTPUT"))).expanduser().resolve()
SHARD_INDEX = int(os.environ.get("BWS_SHARD_INDEX", "0"))
SHARD_COUNT = int(os.environ.get("BWS_SHARD_COUNT", "1"))
RUN_ID = os.environ.get("BWS_RUN_ID", None)
# Only the exact string "1" enables skipping of plates whose outputs already exist.
SKIP_IF_PRESENT = os.environ.get("BWS_SKIP_IF_PRESENT", "1") == "1"
MAX_DIM = int(os.environ.get("BWS_MAX_DIM", str(DEFAULT_MAX_DIM)))

# Fail fast on settings that would otherwise crash later (modulo by zero) or
# silently select no plates (shard index outside the shard range).
if SHARD_COUNT < 1:
    raise RuntimeError(f"BWS_SHARD_COUNT must be >= 1, got {SHARD_COUNT}")
if not 0 <= SHARD_INDEX < SHARD_COUNT:
    raise RuntimeError(f"BWS_SHARD_INDEX must be in [0, {SHARD_COUNT}), got {SHARD_INDEX}")
if MAX_DIM < 1:
    raise RuntimeError(f"BWS_MAX_DIM must be >= 1, got {MAX_DIM}")
# Writing straight into INPUT_ROOT would place run outputs inside the input plate directories.
if OUTPUT_ROOT == INPUT_ROOT:
    raise RuntimeError(f"OUTPUT_ROOT must differ from INPUT_ROOT: {OUTPUT_ROOT}")

print("INPUT_ROOT :", INPUT_ROOT)
print("OUTPUT_ROOT:", OUTPUT_ROOT)
print("SHARD      :", SHARD_INDEX, "/", SHARD_COUNT)
print("RUN_ID     :", RUN_ID)
print("SKIP       :", SKIP_IF_PRESENT)
print("MAX_DIM    :", MAX_DIM)


In [None]:
# Locations inside the read-only input tree.
PLATES_ROOT = INPUT_ROOT / "plates_structured"
SCHEMAS_ROOT = INPUT_ROOT / "schemas"

plate_schema_path = SCHEMAS_ROOT / "plate.manifest.schema.json"
run_schema_path = SCHEMAS_ROOT / "run.manifest.schema.json"

# Both schemas must be present before any validator is built (plate schema is checked first).
for _schema_path in (plate_schema_path, run_schema_path):
    if not _schema_path.exists():
        raise RuntimeError(f"Missing: {_schema_path}")

plate_validator = load_validator(plate_schema_path)
run_validator = load_validator(run_schema_path)

# Every "plate-*" directory, in sorted order so shard assignment is stable across runs.
_expected_plates = 435
plates = sorted(
    entry
    for entry in PLATES_ROOT.iterdir()
    if entry.is_dir() and entry.name.startswith("plate-")
)
if len(plates) != _expected_plates:
    raise RuntimeError(f"Unexpected plate count: {len(plates)}")

# Round-robin sharding: this shard takes the plates whose position modulo SHARD_COUNT equals SHARD_INDEX.
selected = [
    plate
    for position, plate in enumerate(plates)
    if position % SHARD_COUNT == SHARD_INDEX
]
print("plates_total   :", len(plates))
print("plates_selected:", len(selected))

def _geometry_schema() -> dict:
    """Return a fresh sub-schema for an object with positive integer width/height."""
    return {
        "type": "object",
        "additionalProperties": False,
        "required": ["width_px", "height_px"],
        "properties": {
            "width_px": {"type": "integer", "minimum": 1},
            "height_px": {"type": "integer", "minimum": 1},
        },
    }


# Threshold parameters; extra keys are allowed in this object only.
_params_schema = {
    "type": "object",
    "additionalProperties": True,
    "required": ["max_dim", "threshold", "polarity"],
    "properties": {
        "max_dim": {"type": "integer"},
        "threshold": {"type": "integer", "minimum": 0, "maximum": 255},
        "polarity": {"type": "string"},
    },
}

# File names of the artifacts written next to segmentation.json.
_outputs_schema = {
    "type": "object",
    "additionalProperties": False,
    "required": ["mask_png"],
    "properties": {"mask_png": {"type": "string"}},
}

# JSON Schema (draft 2020-12) that every segmentation.json must satisfy.
SEGMENTATION_SCHEMA = {
    "$schema": "https://json-schema.org/draft/2020-12/schema",
    "$id": "https://burning-world-series/schemas/segmentation.otsu.schema.json",
    "title": "Segmentation (Otsu Luma, v1)",
    "type": "object",
    "additionalProperties": False,
    "required": [
        "plate_id",
        "run_id",
        "timestamp",
        "source_image",
        "method",
        "params",
        "original_geometry",
        "mask_geometry",
        "outputs",
        "foreground_ratio",
    ],
    "properties": {
        "plate_id": {"type": "string"},
        "run_id": {"type": "string"},
        "timestamp": {"type": "string"},
        "source_image": {"type": "string"},
        "method": {"type": "string"},
        "params": _params_schema,
        "original_geometry": _geometry_schema(),
        "mask_geometry": _geometry_schema(),
        "outputs": _outputs_schema,
        "foreground_ratio": {"type": "number", "minimum": 0.0, "maximum": 1.0},
    },
}

# Persist the schema under OUTPUT_ROOT, then build the validator from the file just written.
schema_dir = OUTPUT_ROOT / "schemas"
schema_dir.mkdir(parents=True, exist_ok=True)
seg_schema_path = schema_dir / "segmentation.otsu.schema.json"

_schema_text = json.dumps(SEGMENTATION_SCHEMA, indent=2, ensure_ascii=False)
seg_schema_path.write_text(f"{_schema_text}\n", encoding="utf-8")
seg_validator = load_validator(seg_schema_path)
print("wrote schema:", seg_schema_path)


In [None]:
# Identity of this run: a free-text note plus the model descriptor (which records MAX_DIM).
NOTE = "cheap luma otsu threshold segmentation (downsampled mask)"
MODELS = ["segmentation-otsu-luma-v1(max_dim={})".format(MAX_DIM)]

# An explicit BWS_RUN_ID wins; otherwise a timestamped id is derived from MODELS and NOTE.
RUN_ID_EFFECTIVE = RUN_ID if RUN_ID else generate_run_id(MODELS, NOTE)

# Run-level report: static context first, then counters the per-plate loop updates.
report = dict(
    run_id=RUN_ID_EFFECTIVE,
    timestamp=utc_iso(),
    dataset_root=str(INPUT_ROOT),
    input_root=str(INPUT_ROOT),
    output_root=str(OUTPUT_ROOT),
    shard_index=SHARD_INDEX,
    shard_count=SHARD_COUNT,
    plates_total=len(plates),
    plates_selected=len(selected),
)
for _counter in ("plates_processed", "plates_skipped", "decode_failures", "schema_failures"):
    report[_counter] = 0
# Holds a bounded sample of error messages collected while processing plates.
report["errors_sample"] = []

def compute_segmentation(plate_dir: Path) -> tuple[dict, dict, Image.Image]:
    """Segment one plate with an Otsu threshold on its downsampled luma channel.

    The plate manifest is validated, the source image it names is converted to
    8-bit grayscale and shrunk to at most ``MAX_DIM`` on its longest side, and
    pixels in the dark Otsu class become foreground (255) in the mask.

    Nothing is written to disk here; the caller persists the returned values.

    Args:
        plate_dir: Input plate directory containing ``manifest.json``.

    Returns:
        ``(run_manifest, seg, mask_img)``: the run manifest dict, the
        segmentation record dict, and the binary mask as a mode-"L" image.

    Raises:
        RuntimeError: If the plate manifest, run manifest or segmentation
            record fails schema validation (message contains "schema error"),
            or if the source image file is missing.
    """
    manifest = json.loads((plate_dir / "manifest.json").read_text(encoding="utf-8"))
    errs = list(plate_validator.iter_errors(manifest))
    if errs:
        raise RuntimeError(f"plate manifest schema error: {list(errs[0].path)} -> {errs[0].message}")

    # source_image is interpreted relative to the plate directory.
    src_path = plate_dir / manifest["source_image"]
    if not src_path.exists():
        raise RuntimeError("missing source image")

    run_manifest = {
        "run_id": RUN_ID_EFFECTIVE,
        "plate_id": plate_dir.name,
        "timestamp": utc_iso(),
        "models": MODELS,
        "outputs": ["segmentation.json", "segmentation_mask.png"],
        "notes": NOTE,
    }

    # Validate before decoding so a bad run manifest fails without touching the image.
    rerrs = list(run_validator.iter_errors(run_manifest))
    if rerrs:
        raise RuntimeError(f"run manifest schema error: {list(rerrs[0].path)} -> {rerrs[0].message}")

    with Image.open(src_path) as img:
        w0, h0 = img.size
        luma = img.convert("L")
        luma_small = resize_fit(luma, MAX_DIM)
        ws, hs = luma_small.size

        hist = luma_small.histogram()[:256]
        thr = otsu_threshold(hist)
        arr = np.asarray(luma_small, dtype=np.uint8)
        # otsu_threshold returns the highest level of the dark class (levels <= thr),
        # so the comparison must be inclusive.  With a strict "<", a pure black/white
        # image (threshold 0) would produce an empty mask.
        mask_arr = (arr <= thr).astype(np.uint8) * 255
        mask_img = Image.fromarray(mask_arr, mode="L")

    # Mask values are 0 or 255, so mean / 255 is the fraction of foreground pixels.
    foreground_ratio = float(mask_arr.mean() / 255.0)

    seg = {
        "plate_id": manifest["plate_id"],
        "run_id": RUN_ID_EFFECTIVE,
        "timestamp": utc_iso(),
        "source_image": manifest["source_image"],
        "method": "otsu-luma-v1",
        "params": {"max_dim": int(MAX_DIM), "threshold": int(thr), "polarity": "dark-is-foreground"},
        "original_geometry": {"width_px": int(w0), "height_px": int(h0)},
        "mask_geometry": {"width_px": int(ws), "height_px": int(hs)},
        "outputs": {"mask_png": "segmentation_mask.png"},
        "foreground_ratio": foreground_ratio,
    }

    berrs = list(seg_validator.iter_errors(seg))
    if berrs:
        raise RuntimeError(f"segmentation schema error: {list(berrs[0].path)} -> {berrs[0].message}")

    return run_manifest, seg, mask_img

# Process every plate of this shard, then write the run-level report.
for plate_dir in tqdm(selected, desc="plates"):
    out_plate_dir = OUTPUT_ROOT / "plates_structured" / plate_dir.name
    out_run_dir = out_plate_dir / "runs" / RUN_ID_EFFECTIVE

    out_metrics = out_run_dir / "metrics.json"
    out_seg = out_run_dir / "segmentation.json"
    out_mask = out_run_dir / "segmentation_mask.png"

    # A plate counts as done only when all three artifacts are present.
    if SKIP_IF_PRESENT and out_metrics.exists() and out_seg.exists() and out_mask.exists():
        report["plates_skipped"] += 1
        continue

    try:
        run_manifest, seg, mask_img = compute_segmentation(plate_dir)
        # Create the run directory only once there is something to write, so a
        # plate that fails to compute does not leave an empty run directory.
        out_run_dir.mkdir(parents=True, exist_ok=True)
        out_metrics.write_text(json.dumps(run_manifest, indent=2), encoding="utf-8")
        out_seg.write_text(json.dumps(seg, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
        mask_img.save(out_mask, format="PNG", optimize=True)
        report["plates_processed"] += 1
    except Exception as e:
        # Per-plate boundary: record the failure and carry on with the next plate.
        # compute_segmentation reports validation problems as RuntimeError whose
        # message contains "schema error"; every other failure is counted under
        # decode_failures, as before.
        if isinstance(e, RuntimeError) and "schema error" in str(e):
            report["schema_failures"] += 1
        else:
            report["decode_failures"] += 1
        # Keep only a small sample of messages so the report stays compact.
        if len(report["errors_sample"]) < 10:
            report["errors_sample"].append(f"{plate_dir.name}: {type(e).__name__}: {str(e)}")

report_dir = OUTPUT_ROOT / "reports" / RUN_ID_EFFECTIVE
report_dir.mkdir(parents=True, exist_ok=True)
(report_dir / "report.json").write_text(json.dumps(report, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
print(json.dumps(report, indent=2))
