In [None]:
# Cell 1 — Bootstrap (config + stable phase logging + EMNIST data-dir discovery)
from __future__ import annotations

from dataclasses import dataclass, asdict, is_dataclass
from pathlib import Path
from typing import Any, Mapping, Optional, Tuple
import os, sys, json, random, hashlib

# ============================================================
# Stable phase logging
# ============================================================

def _jsonable(x: Any) -> Any:
    """Recursively convert *x* into something ``json.dumps`` can serialize.

    Paths become strings, dataclass instances become dicts, tuples become
    lists and sets become sorted lists; any other value is returned as-is.
    """
    if isinstance(x, Path):
        return str(x)
    if is_dataclass(x):
        # asdict() yields a plain dict; fall through to the dict branch so
        # leaf values (e.g. Path) are normalized too.
        x = asdict(x)
    if isinstance(x, dict):
        return {key: _jsonable(val) for key, val in x.items()}
    if isinstance(x, (list, tuple)):
        return [_jsonable(item) for item in x]
    if isinstance(x, set):
        # Sets have no stable order; sort the converted members.
        return sorted(map(_jsonable, x))
    return x

# Registry of the most recent JSON-ready payload recorded for each phase name.
PHASES: dict[str, dict[str, Any]] = {}


def phase_print(phase: str, payload: Mapping[str, Any]) -> None:
    """Record *payload* under *phase* in ``PHASES`` and print it as JSON.

    The printed object is the payload plus a ``"phase"`` key, with sorted
    keys and NaN/Infinity rejected.

    Raises:
        KeyError: if *payload* already contains the key ``"phase"``.
    """
    if "phase" in payload:
        raise KeyError("phase_print payload must not contain key 'phase'")
    recorded = _jsonable(dict(payload))
    PHASES[phase] = recorded
    printable = {"phase": _jsonable(phase)}
    printable.update(recorded)
    text = json.dumps(
        printable,
        indent=2,
        sort_keys=True,
        ensure_ascii=False,
        allow_nan=False,
    )
    print(text)

# ============================================================
# Run config (single source of truth for Cell 3)
# ============================================================

@dataclass(frozen=True)
class RunConfig:
    """Immutable run settings shared by the notebook cells.

    Instances are normally built by ``load_config`` from environment
    variables; the field defaults below match that function's "quick" mode
    defaults.
    """
    # provenance / dataset
    seed: int = 0               # used to seed Python's `random` module
    run_mode: str = "quick"     # {"quick","full"}
    split: str = "digits"       # EMNIST split name (part of the data file names)

    # semantics lock (critical)
    canonicalize: bool = True   # upright canonicalization
    bin_threshold: int = 127    # binarize pixel > threshold

    # caps (0 means uncapped)
    max_train: int = 20_000     # max items yielded by the training stream
    max_test: int = 5_000       # max items yielded by the test stream

    # Cell 3 (compiled discrete margin repair) knobs
    finite_n: int = 256         # number of training items in the finite table
    cap_unary: int = 128        # max active-pixel features kept per image
    margin: int = 1             # required score gap between true class and competitors
    overshoot_delta: int = 2    # extra steps added to every repair update
    max_epochs: int = 50        # upper bound on passes over the finite table
    max_updates: int = 20_000   # upper bound on total repair updates
    probe_per_class: int = 10   # test items per class collected for the probe set

def _env_int(name: str, default: str, *, minimum: Optional[int] = None) -> int:
    """Read integer environment variable *name*, falling back to *default*.

    Raises:
        ValueError: naming *name* when the value is not an integer or is
            below *minimum*.
    """
    raw = os.environ.get(name, default)
    try:
        value = int(raw)
    except ValueError:
        raise ValueError(f"{name} must be an integer, got: {raw!r}") from None
    if minimum is not None and value < minimum:
        raise ValueError(f"{name} must be >= {minimum}, got: {value}")
    return value


def load_config() -> RunConfig:
    """Build the run configuration from environment variables.

    Recognized variables: SEED, RUN_MODE, EMNIST_SPLIT, BIN_THRESHOLD,
    MAX_TRAIN, MAX_TEST, FINITE_N, CAP_UNARY, MARGIN, OVERSHOOT_DELTA,
    MAX_EPOCHS, MAX_UPDATES, PROBE_PER_CLASS. Unset variables fall back to
    the defaults below; the caps default to 20000/5000 in "quick" mode and to
    0 (uncapped) in "full" mode.

    Raises:
        ValueError: if RUN_MODE is not "quick"/"full", if an integer variable
            cannot be parsed, or if MAX_TRAIN / MAX_TEST is negative.
    """
    seed = _env_int("SEED", "0")
    run_mode = os.environ.get("RUN_MODE", "quick").strip().lower()
    if run_mode not in {"quick", "full"}:
        raise ValueError(f"RUN_MODE must be 'quick' or 'full', got: {run_mode!r}")

    split = os.environ.get("EMNIST_SPLIT", "digits").strip().lower()

    # semantics lock: canonicalization is deliberately not overridable
    canonicalize = True
    bin_threshold = _env_int("BIN_THRESHOLD", "127")

    # caps by run_mode; 0 means uncapped. A negative cap would make the
    # stream stop after its first item, so it is rejected up front.
    default_train, default_test = ("20000", "5000") if run_mode == "quick" else ("0", "0")
    max_train = _env_int("MAX_TRAIN", default_train, minimum=0)
    max_test = _env_int("MAX_TEST", default_test, minimum=0)

    # Cell 3 knobs (override via env if desired)
    return RunConfig(
        seed=seed,
        run_mode=run_mode,
        split=split,
        canonicalize=canonicalize,
        bin_threshold=bin_threshold,
        max_train=max_train,
        max_test=max_test,
        finite_n=_env_int("FINITE_N", "256"),
        cap_unary=_env_int("CAP_UNARY", "128"),
        margin=_env_int("MARGIN", "1"),
        overshoot_delta=_env_int("OVERSHOOT_DELTA", "2"),
        max_epochs=_env_int("MAX_EPOCHS", "50"),
        max_updates=_env_int("MAX_UPDATES", "20000"),
        probe_per_class=_env_int("PROBE_PER_CLASS", "10"),
    )

# ============================================================
# EMNIST data directory discovery
# ============================================================

def find_data_dir(start: Optional[Path] = None) -> Tuple[Path, str]:
    """Locate the EMNIST data directory and report how it was found.

    If ``EMNIST_DATA_DIR`` is set (non-blank) it must name an existing
    directory and is returned with method ``"env:EMNIST_DATA_DIR"``.
    Otherwise *start* (default: the current working directory) and each of
    its ancestors are checked for a ``gzip`` sub-directory; the first hit is
    returned with method ``"found:gzip_ancestor"``.

    Raises:
        FileNotFoundError: the env path does not exist, or no ``gzip``
            directory was found.
        NotADirectoryError: the env path exists but is not a directory.
    """
    override = os.environ.get("EMNIST_DATA_DIR", "").strip()
    if override:
        target = Path(override).expanduser().resolve()
        if not target.exists():
            raise FileNotFoundError(f"EMNIST_DATA_DIR is set but does not exist: {target}")
        if not target.is_dir():
            raise NotADirectoryError(f"EMNIST_DATA_DIR exists but is not a directory: {target}")
        return target, "env:EMNIST_DATA_DIR"

    origin = (start or Path.cwd()).resolve()
    # Walk from the starting directory upward to the filesystem root.
    for folder in (origin, *origin.parents):
        candidate = folder / "gzip"
        if candidate.is_dir():
            return candidate.resolve(), "found:gzip_ancestor"

    raise FileNotFoundError(
        "Cannot locate EMNIST data directory. Expected either:\n"
        "  - EMNIST_DATA_DIR=<path-to-dir-containing-emnist-gz-files>\n"
        "  - or a ./gzip/ directory in the current working directory or any ancestor."
    )

# ============================================================
# Instantiate run context
# ============================================================

# Build the frozen run configuration from the environment and seed Python's RNG.
CFG = load_config()
random.seed(CFG.seed)

# Locate the EMNIST directory, starting the search at the current working
# directory. PROJECT_ROOT is defined as the parent of the data directory.
DATA_DIR_START = Path.cwd().resolve()
DATA_DIR, DATA_DIR_METHOD = find_data_dir(start=DATA_DIR_START)
PROJECT_ROOT = DATA_DIR.parent.resolve()

# Record and print the run header (config echo, interpreter version, paths).
phase_print("run_header", {
    "seed": CFG.seed,
    "run_mode": CFG.run_mode,
    "split": CFG.split,
    "canonicalize": "upright" if CFG.canonicalize else "raw",
    "bin_threshold": CFG.bin_threshold,
    "caps": {"max_train": CFG.max_train, "max_test": CFG.max_test},
    "cell3_cfg": {
        "finite_n": CFG.finite_n,
        "cap_unary": CFG.cap_unary,
        "margin": CFG.margin,
        "overshoot_delta": CFG.overshoot_delta,
        "max_epochs": CFG.max_epochs,
        "max_updates": CFG.max_updates,
        "probe_per_class": CFG.probe_per_class,
    },
    "python": sys.version.split()[0],
    # None when PYTHONHASHSEED is not set in the environment
    "python_hash_seed": os.environ.get("PYTHONHASHSEED", None),
    "project_root": PROJECT_ROOT,
    "data_dir": DATA_DIR,
    "data_dir_method": DATA_DIR_METHOD,
    "data_dir_start": DATA_DIR_START,
})

In [None]:
# Cell 2 — EMNIST kernel (streamed IDX-gz reader + mapping + canonicalization + dataset streams)
from __future__ import annotations

import gzip, io, struct
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterator, List, Tuple, Callable

# ============================================================
# Semantics-critical constants
# ============================================================

# Side length (in pixels) every image must have; the IDX reader rejects other sizes.
GRID: int = 28
# Number of bytes per flattened image (GRID x GRID).
N_PIXELS: int = GRID * GRID

# ============================================================
# IDX (.ubyte) reading
# ============================================================

@dataclass(frozen=True)
class IdxImagesHeader:
    """Fields parsed from the 16-byte header of an IDX images file."""
    count: int  # number of images declared by the header
    rows: int   # image height declared by the header
    cols: int   # image width declared by the header

@dataclass(frozen=True)
class IdxLabelsHeader:
    """Field parsed from the 8-byte header of an IDX labels file."""
    count: int  # number of labels declared by the header

def _read_exact(f: io.BufferedReader, n: int) -> bytes:
    """Read exactly *n* bytes from *f*.

    Raises:
        EOFError: if the stream returns ``None`` or fewer than *n* bytes.
    """
    data = f.read(n)
    got = 0 if data is None else len(data)
    if data is None or got != n:
        raise EOFError(f"expected {n} bytes, got {got}")
    return data

def read_idx_images_header(f: io.BufferedReader) -> IdxImagesHeader:
    """Parse the 16-byte big-endian header of an IDX images stream.

    Raises:
        EOFError: if fewer than 16 bytes are available.
        ValueError: if the magic number is not 2051.
    """
    fields = struct.unpack(">IIII", _read_exact(f, 16))
    if fields[0] != 2051:
        raise ValueError(f"bad IDX images magic {fields[0]} (expected 2051)")
    _, n_items, height, width = fields
    return IdxImagesHeader(count=int(n_items), rows=int(height), cols=int(width))

def read_idx_labels_header(f: io.BufferedReader) -> IdxLabelsHeader:
    """Parse the 8-byte big-endian header of an IDX labels stream.

    Raises:
        EOFError: if fewer than 8 bytes are available.
        ValueError: if the magic number is not 2049.
    """
    fields = struct.unpack(">II", _read_exact(f, 8))
    if fields[0] != 2049:
        raise ValueError(f"bad IDX labels magic {fields[0]} (expected 2049)")
    return IdxLabelsHeader(count=int(fields[1]))

class EMNISTIdxGz:
    """
    Lazily yield (image_bytes, label_int) pairs from a pair of gzipped IDX files.

    - image_bytes: GRID*GRID raw grayscale bytes, exactly as stored in the file
    - label_int: the label byte as stored in the file (before any class mapping)

    Both files are reopened on every iteration, so the object is re-iterable.
    """

    def __init__(self, images_gz: Path, labels_gz: Path) -> None:
        """Remember the paths of the images and labels archives."""
        self.images_gz = Path(images_gz)
        self.labels_gz = Path(labels_gz)

    def __iter__(self) -> Iterator[Tuple[bytes, int]]:
        """Stream the pairs in file order.

        Raises:
            ValueError: bad magic, image/label count mismatch, or image size
                other than GRID x GRID.
            EOFError: either stream ends before the declared count is read.
        """
        with gzip.open(self.images_gz, "rb") as raw_images, \
                gzip.open(self.labels_gz, "rb") as raw_labels:
            image_stream = io.BufferedReader(raw_images)
            label_stream = io.BufferedReader(raw_labels)

            image_header = read_idx_images_header(image_stream)
            label_header = read_idx_labels_header(label_stream)

            if image_header.count != label_header.count:
                raise ValueError(
                    f"count mismatch: images={image_header.count}, labels={label_header.count}"
                )
            if (image_header.rows, image_header.cols) != (GRID, GRID):
                raise ValueError(
                    f"expected {GRID}x{GRID} but got {image_header.rows}x{image_header.cols}"
                )

            bytes_per_image = image_header.rows * image_header.cols
            remaining = image_header.count
            while remaining > 0:
                # Image first, then its one-byte label.
                pixels = _read_exact(image_stream, bytes_per_image)
                label_byte = _read_exact(label_stream, 1)
                yield pixels, int(label_byte[0])
                remaining -= 1

def split_files(split: str) -> Tuple[str, str, str, str, str]:
    """Return the five EMNIST file names for *split*.

    The split name is stripped and lower-cased. Order of the result:
    mapping, train images, train labels, test images, test labels.
    """
    stem = "emnist-" + split.strip().lower()
    return (
        f"{stem}-mapping.txt",
        f"{stem}-train-images-idx3-ubyte.gz",
        f"{stem}-train-labels-idx1-ubyte.gz",
        f"{stem}-test-images-idx3-ubyte.gz",
        f"{stem}-test-labels-idx1-ubyte.gz",
    )

# ============================================================
# Mapping file parsing
# ============================================================

@dataclass(frozen=True)
class ClassMapping:
    """Label/class/codepoint tables parsed from an EMNIST mapping file."""
    label_to_class: Dict[int, int]   # dataset label -> contiguous class index
    class_to_label: List[int]        # class index -> dataset label (ascending labels)
    class_to_codepoint: List[int]    # class index -> second column of the mapping file

def parse_emnist_mapping_txt(path: Path) -> ClassMapping:
    """Parse an EMNIST ``<label> <codepoint>`` mapping file.

    Blank lines, lines starting with ``#`` and lines with fewer than two
    whitespace-separated fields are skipped. Class indices are assigned by
    ascending label.

    Raises:
        ValueError: on a duplicate label, a non-integer field, or when no
            pair could be parsed.
    """
    codepoint_of: Dict[int, int] = {}

    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text == "" or text[0] == "#":
                continue
            fields = text.split()
            if len(fields) < 2:
                continue
            label, codepoint = int(fields[0]), int(fields[1])
            if label in codepoint_of:
                raise ValueError(f"duplicate label {label} in mapping file {path}")
            codepoint_of[label] = codepoint

    if not codepoint_of:
        raise ValueError(f"no mapping pairs parsed from {path}")

    ordered_labels = sorted(codepoint_of)
    return ClassMapping(
        label_to_class={label: index for index, label in enumerate(ordered_labels)},
        class_to_label=ordered_labels,
        class_to_codepoint=[codepoint_of[label] for label in ordered_labels],
    )

# ============================================================
# Semantics: canonicalization (upright = rotate CW 90 then mirror horizontally)
# ============================================================

def flip_horizontal(img: bytes, *, grid: int = GRID) -> bytes:
    """Return *img* (row-major, grid x grid) with every row reversed.

    Raises:
        ValueError: if ``len(img)`` is not ``grid * grid``.
    """
    expected = grid * grid
    if len(img) != expected:
        raise ValueError(f"expected image length {expected}, got {len(img)}")
    mirrored = bytearray(expected)
    for row_start in (r * grid for r in range(grid)):
        row_end = row_start + grid
        mirrored[row_start:row_end] = img[row_start:row_end][::-1]
    return bytes(mirrored)

def rotate90_cw(img: bytes, *, grid: int = GRID) -> bytes:
    """Return *img* (row-major, grid x grid) rotated 90 degrees clockwise.

    Source pixel (r, c) lands at (c, grid - 1 - r).

    Raises:
        ValueError: if ``len(img)`` is not ``grid * grid``.
    """
    expected = grid * grid
    if len(img) != expected:
        raise ValueError(f"expected image length {expected}, got {len(img)}")
    rotated = bytearray(expected)
    for r in range(grid):
        # Source row r becomes destination column (grid - 1 - r): write the
        # whole row into that column with one strided slice assignment.
        dest_col = grid - 1 - r
        rotated[dest_col::grid] = img[r * grid:(r + 1) * grid]
    return bytes(rotated)

def canonicalize_upright(img: bytes, *, grid: int = GRID) -> bytes:
    """Rotate *img* 90 degrees clockwise, then mirror it horizontally.

    The net effect of the two steps is a transpose: source pixel (r, c) ends
    up at (c, r).
    """
    rotated = rotate90_cw(img, grid=grid)
    return flip_horizontal(rotated, grid=grid)

# ============================================================
# Re-iterable dataset streams for Cell 3
# ============================================================

class _IterableFromFactory:
    """Re-iterable wrapper: each ``iter()`` call invokes the factory for a fresh iterator."""
    def __init__(self, factory: Callable[[], Iterator[Tuple[bytes, int]]]) -> None:
        # Zero-argument callable that returns a new iterator on every call.
        self._factory = factory
    def __iter__(self) -> Iterator[Tuple[bytes, int]]:
        # Delegate to the factory so the stream can be consumed more than once.
        return self._factory()

def iter_raw_images_labels(*, train: bool) -> Iterator[Tuple[bytes, int]]:
    """Yield (image_bytes, class_index) pairs for the configured EMNIST split.

    Reads the train files when *train* is true, otherwise the test files.
    Items whose label is missing from the mapping file are skipped, images
    are canonicalized when ``CFG.canonicalize`` is set, and the stream stops
    after ``CFG.max_train`` / ``CFG.max_test`` yielded items (0 = no cap).
    """
    map_name, train_images, train_labels, test_images, test_labels = split_files(CFG.split)
    root = Path(DATA_DIR)
    class_map = parse_emnist_mapping_txt(root / map_name)

    if train:
        images_name, labels_name, limit = train_images, train_labels, int(CFG.max_train)
    else:
        images_name, labels_name, limit = test_images, test_labels, int(CFG.max_test)
    source = EMNISTIdxGz(root / images_name, root / labels_name)

    lookup = class_map.label_to_class
    emitted = 0
    for pixels, raw_label in source:
        if raw_label not in lookup:
            continue
        if CFG.canonicalize:
            pixels = canonicalize_upright(pixels, grid=GRID)
        yield pixels, int(lookup[raw_label])
        emitted += 1
        # Only yielded items count toward the cap; skipped labels do not.
        if limit and emitted >= limit:
            return

# Expose globals expected by Cell 3: mapping_file, mapping, n_classes,
# train_ds and test_ds.
mapping_file, *_rest = split_files(CFG.split)
mapping = parse_emnist_mapping_txt(Path(DATA_DIR) / mapping_file)
n_classes = len(mapping.class_to_label)

# Re-iterable streams: every iteration reopens the gz files from the start.
train_ds = _IterableFromFactory(lambda: iter_raw_images_labels(train=True))
test_ds  = _IterableFromFactory(lambda: iter_raw_images_labels(train=False))

# Record and print the dataset summary for this run.
phase_print("datasets_ready", {
    "ok": True,
    "split": CFG.split,
    "mapping_file": mapping_file,
    "n_classes": n_classes,
    "canonicalize": "upright" if CFG.canonicalize else "raw",
    "bin_threshold": int(CFG.bin_threshold),
    "caps": {"max_train": int(CFG.max_train), "max_test": int(CFG.max_test)},
})

In [None]:
# Cell 3 — Discrete margin repair

from __future__ import annotations

import os, time, json, math, hashlib
from collections import defaultdict, Counter
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, List, Tuple, Iterable, Iterator, Any, Optional

# ============================================================
# Helpers (stable hashing / I/O)
# ============================================================

def _now_utc_iso() -> str:
    """Return the current UTC time as an ISO-8601 string truncated to whole seconds."""
    stamp = datetime.now(timezone.utc)
    return stamp.replace(microsecond=0).isoformat()

def _sha256_hex(b: bytes) -> str:
    """Return the SHA-256 digest of *b* as a lowercase hex string."""
    digest = hashlib.sha256()
    digest.update(b)
    return digest.hexdigest()

def _stable_json(obj: Any) -> str:
    """Serialize *obj* to compact JSON with sorted keys, for hashing.

    NaN and Infinity are rejected (``allow_nan=False``) and non-ASCII
    characters are emitted unescaped.
    """
    return json.dumps(
        obj,
        allow_nan=False,
        ensure_ascii=False,
        separators=(",", ":"),
        sort_keys=True,
    )

def _hash_obj(obj: Any) -> str:
    """Return the SHA-256 hex digest of the stable JSON encoding of *obj*."""
    encoded = _stable_json(obj).encode("utf-8")
    return _sha256_hex(encoded)

def _sha256_file(path: Path, *, chunk: int = 1 << 20) -> str:
    """Return the SHA-256 hex digest of the file at *path*, read in *chunk*-byte pieces."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        # iter() with a sentinel stops at the first empty read (end of file).
        for piece in iter(lambda: handle.read(chunk), b""):
            digest.update(piece)
    return digest.hexdigest()

def _env_flag(name: str, default: str = "0") -> bool:
    """Return True when environment variable *name* holds a truthy word.

    Truthy words (case-insensitive, surrounding whitespace ignored) are
    1, true, yes, y and on; *default* is used when the variable is unset.
    """
    truthy = ("1", "true", "yes", "y", "on")
    return os.environ.get(name, default).strip().lower() in truthy

def _argmax_stable(scores: List[int]) -> int:
    """Return the index of the maximum of *scores*; ties go to the lowest index.

    Raises:
        IndexError: if *scores* is empty.
    """
    top = scores[0]
    where = 0
    for idx, val in enumerate(scores):
        # Strict comparison keeps the earliest index among equal maxima.
        if val > top:
            top, where = val, idx
    return where

def _best_competitor_stable(scores: List[int], y: int) -> Tuple[int, int]:
    """Return (class, score) of the highest-scoring class other than *y*.

    Ties go to the lowest class index. When there is no other class the
    result is ``(-1, 0)``.
    """
    rival = -1
    top: Optional[int] = None
    for idx, val in enumerate(scores):
        if idx != y and (top is None or val > top):
            rival, top = idx, val
    if top is None:
        return rival, 0
    return rival, int(top)

def _iter_ds(ds: Iterable[Tuple[bytes, int]]) -> Iterator[Tuple[bytes, int]]:
    """Yield every item of *ds* in order, as a generator."""
    yield from ds

# ============================================================
# Semantics Lock (explicit, hashed)
# ============================================================

# Declarative description of the feature-extraction semantics used in this
# cell. Its hash (LOCK_HASH) is bound into the finite-table id and the
# certificate below, so changing any entry changes those identifiers.
SEM_LOCK = {
    "grid": int(GRID),
    "n_pixels": int(N_PIXELS),
    "canonicalize": "upright" if bool(CFG.canonicalize) else "raw",
    "binarize_predicate": "v > bin_threshold",
    "bin_threshold": int(CFG.bin_threshold),
    "feature_schema": "unary: bias + active pixels",
    "bias_feature": "bias:1",
    "active_pixel_feature": "px:<idx>",
    "cap_policy": "if cap_unary>0 and active>cap_unary: take top by (value desc, idx asc)",
    "ordering_policy": "facts list begins with bias; then pixels in idx ascending (after optional cap)",
    "dedup_policy": "preserve first occurrence",
    "argmax_tie_break": "stable first-maximum (least index among maxima)",
}
# SHA-256 of the stable JSON encoding of SEM_LOCK.
LOCK_HASH = _hash_obj(SEM_LOCK)

# ============================================================
# Dataset provenance (no dataset upload; optional strong hashes)
# ============================================================

# Names from earlier cells used here: split_files (Cell 2), CFG and DATA_DIR
# (Cell 1). mapping_file is re-derived from the configured split.
mapping_file, tr_i, tr_l, te_i, te_l = split_files(CFG.split)
DATA_DIR_P = Path(DATA_DIR)

# Resolved paths of the five EMNIST files this cell depends on.
paths = {
    "mapping_txt": (DATA_DIR_P / mapping_file).resolve(),
    "train_images_gz": (DATA_DIR_P / tr_i).resolve(),
    "train_labels_gz": (DATA_DIR_P / tr_l).resolve(),
    "test_images_gz": (DATA_DIR_P / te_i).resolve(),
    "test_labels_gz": (DATA_DIR_P / te_l).resolve(),
}

# Fail early, naming the missing file, before any hashing or training starts.
for k, p in paths.items():
    if not p.exists():
        raise FileNotFoundError(f"missing required EMNIST file for {k}: {p}")

strong_hash = _env_flag("HASH_DATASET", "0")  # opt-in; may be slow
# Provenance record: file names and sizes always, the mapping file's SHA-256
# always, and SHA-256 of the four gz files only when HASH_DATASET is set.
# NOTE(review): "data_dir" is the local directory string, so PROV_HASH (and
# everything that hashes `prov`, such as the certificate) differs between
# machines that store the same files in different locations — confirm this
# is intended for a hash described as citable.
prov = {
    "split": str(CFG.split),
    "data_dir": str(DATA_DIR_P),
    "files": {
        k: {
            "name": p.name,
            "bytes": int(p.stat().st_size),
        }
        for k, p in paths.items()
    },
    "mapping_sha256": _sha256_file(paths["mapping_txt"]),
    "strong_hashes_included": bool(strong_hash),
}
if strong_hash:
    prov["files"]["train_images_gz"]["sha256"] = _sha256_file(paths["train_images_gz"])
    prov["files"]["train_labels_gz"]["sha256"] = _sha256_file(paths["train_labels_gz"])
    prov["files"]["test_images_gz"]["sha256"]  = _sha256_file(paths["test_images_gz"])
    prov["files"]["test_labels_gz"]["sha256"]  = _sha256_file(paths["test_labels_gz"])

# SHA-256 of the stable JSON encoding of the provenance record.
PROV_HASH = _hash_obj(prov)

# ============================================================
# Unary facts from raw bytes (0..255) under Semantics Lock
# Represent facts as strings to avoid repr()-based instability.
# ============================================================

def _feat_bias() -> str:
    """Return the name of the always-present bias feature."""
    return "bias:1"

def _feat_px(idx: int) -> str:
    """Return the feature name for the active pixel at flat index *idx*."""
    return "px:" + str(int(idx))

def facts_unary_bytes(img: bytes, *, bin_threshold: int, cap_unary: int) -> List[str]:
    """Extract the unary feature list of one image.

    The result always starts with the bias feature, followed by one
    ``px:<idx>`` feature per pixel whose value is strictly greater than
    *bin_threshold*, in ascending index order. When *cap_unary* is positive
    and more pixels are active, only the *cap_unary* brightest are kept
    (ties broken by lower index).

    Raises:
        ValueError: if ``len(img)`` is not ``N_PIXELS``.
    """
    if len(img) != N_PIXELS:
        raise ValueError(f"expected image length {N_PIXELS}, got {len(img)}")

    limit = int(bin_threshold)
    # (value, index) for every pixel above the threshold
    lit: List[Tuple[int, int]] = [
        (int(val), int(pos)) for pos, val in enumerate(img) if val > limit
    ]

    # No active pixel: the bias feature alone.
    if not lit:
        return [_feat_bias()]

    # Deterministic cap: brightest first, lower index wins ties.
    if cap_unary and cap_unary > 0 and len(lit) > cap_unary:
        lit = sorted(lit, key=lambda pair: (-pair[0], pair[1]))[:cap_unary]
    # Whether capped or not, emit in ascending index order.
    lit.sort(key=lambda pair: (pair[1],))

    ordered = [_feat_bias()]
    ordered.extend(_feat_px(pos) for _val, pos in lit)

    # Order-preserving dedup (first occurrence wins).
    return list(dict.fromkeys(ordered))

# ============================================================
# Build finite training table + balanced probe set
# ============================================================

# Finite training table: the first CFG.finite_n items of the training stream,
# kept with their position in that stream.
finite_imgs: List[bytes] = []
finite_labels: List[int] = []
finite_srcpos: List[int] = []

for srcpos, (img, y) in enumerate(_iter_ds(train_ds)):
    finite_imgs.append(img)
    finite_labels.append(int(y))
    finite_srcpos.append(int(srcpos))
    if len(finite_imgs) >= int(CFG.finite_n):
        break

# The stream (possibly capped by max_train) must supply at least finite_n items.
if len(finite_imgs) < int(CFG.finite_n):
    raise RuntimeError(f"train_ds yielded only {len(finite_imgs)} items; need finite_n={CFG.finite_n}.")

# Number of classes: n_classes from Cell 2 when defined, else inferred from labels.
K = int(globals().get("n_classes", max(finite_labels) + 1))

# Balanced probe set: the first probe_per_class test items of each class.
# NOTE(review): nothing checks afterwards that every class reached its quota;
# if the test stream runs out, the probe set is simply smaller.
probe_items: List[Tuple[bytes, int]] = []
need = Counter({c: int(CFG.probe_per_class) for c in range(K)})
for img, y in _iter_ds(test_ds):
    y = int(y)
    if 0 <= y < K and need[y] > 0:
        probe_items.append((img, y))
        need[y] -= 1
        # Stop reading the test stream once every class quota is filled.
        if all(v == 0 for v in need.values()):
            break

# ============================================================
# Build feature table + postings
# ============================================================

# FI[i] is the feature list of finite-table item i.
FI: List[List[str]] = []
# TAB_post[f] lists the item indices whose feature list contains f (postings).
TAB_post: Dict[str, List[int]] = defaultdict(list)

feat_vocab: Dict[str, int] = {}  # stable IDs for bookkeeping only
def _feat_id(f: str) -> int:
    """Return the 1-based id of feature *f*, assigning the next free id on first sight."""
    # setdefault() leaves an existing id untouched; the candidate id is the
    # vocabulary size before insertion plus one.
    return feat_vocab.setdefault(f, len(feat_vocab) + 1)

# Extract features for every finite-table item and fill the postings lists
# and the feature vocabulary.
for i in range(int(CFG.finite_n)):
    feats = facts_unary_bytes(
        finite_imgs[i],
        bin_threshold=int(CFG.bin_threshold),
        cap_unary=int(CFG.cap_unary),
    )
    FI.append(feats)
    for f in feats:
        TAB_post[f].append(i)
        _feat_id(f)

# Label histogram of the finite table (reported in the certificate).
per_class = Counter(finite_labels)

# Fingerprint of FI itself (small for finite_n=256; this is the real feature-table identity)
FI_hash = _hash_obj({
    "finite_n": int(CFG.finite_n),
    "facts": FI,
})

# Summary fingerprint: vocabulary/postings sizes, the 500 smallest postings
# lengths, and the FI hash.
table_fp = _hash_obj({
    "feat_vocab_size": int(len(feat_vocab)),
    "postings_size": int(len(TAB_post)),
    "postings_hist_prefix": sorted([len(v) for v in TAB_post.values()])[:500],
    "fi_hash": FI_hash,
})

# Dataset identity for *this* finite witness table (reconstructible from EMNIST + srcpos + lock)
finite_table_id = _hash_obj({
    "prov_hash": PROV_HASH,
    "lock_hash": LOCK_HASH,
    "finite_srcpos": finite_srcpos,
    "finite_labels": finite_labels,  # redundant but practical
    "fi_hash": FI_hash,              # binds extraction result
})

# ============================================================
# Initialize integer weights and cached scores
# ============================================================

# Sparse integer weights; a missing key is treated as weight 0 everywhere below.
W: Dict[Tuple[int, str], int] = {}           # (class, feature) -> weight
# S[i][c] caches the score of class c on finite-table item i (all zero initially).
S: List[List[int]] = [[0] * K for _ in range(int(CFG.finite_n))]  # cached scores

def _count_violations_from_cache(Sm: List[List[int]], Y: List[int], margin: int) -> Tuple[int, int]:
    """Count margin violations in a score table.

    A pair (item i, class c != Y[i]) is a violation when
    ``Sm[i][c] + margin > Sm[i][Y[i]]``.

    Returns:
        ``(number of violating pairs, largest positive slack)``; the slack is
        0 when there is no violation.
    """
    pairs = 0
    worst = 0
    for i, truth in enumerate(Y):
        row = Sm[i]
        target = row[truth]
        for c, score in enumerate(row):
            if c == truth:
                continue
            gap = (score + margin) - target
            if gap > 0:
                pairs += 1
                worst = max(worst, gap)
    return pairs, worst

def _predict_from_weights(img: bytes) -> Tuple[int, List[int]]:
    """Score *img* against the current weights ``W``.

    Returns:
        ``(predicted class, per-class scores)``; the prediction is the
        lowest-index class with the maximal score.
    """
    feats = facts_unary_bytes(
        img,
        bin_threshold=int(CFG.bin_threshold),
        cap_unary=int(CFG.cap_unary),
    )
    # Score of a class = sum of its weights over the image's features.
    scores = [sum(W.get((c, f), 0) for f in feats) for c in range(K)]
    return _argmax_stable(scores), scores

def _accuracy(items: List[Tuple[bytes, int]]) -> float:
    """Return the fraction of *items* classified correctly by the current weights.

    Returns NaN when *items* is empty.
    """
    if not items:
        return float("nan")
    hits = sum(1 for img, y in items if _predict_from_weights(img)[0] == int(y))
    return hits / len(items)

# ============================================================
# Training loop (compiled discrete margin repair)
# ============================================================

# Per-epoch diagnostics, plus the best probe accuracy seen so far and a
# snapshot of the weights that achieved it.
epoch_rows: List[dict] = []
best = {"epoch": None, "probe_acc": float("-inf"), "weights_hash": None}
best_W: Optional[Dict[Tuple[int, str], int]] = None

# Number of repair updates applied across all epochs (bounded by max_updates).
updates_total = 0

for epoch in range(1, int(CFG.max_epochs) + 1):
    ep_start = time.time()
    viol_before, _ = _count_violations_from_cache(S, finite_labels, int(CFG.margin))

    ep_updates = 0
    ep_max_slack_seen = 0

    for i in range(int(CFG.finite_n)):
        y = finite_labels[i]
        scores_i = S[i]

        # Only the strongest competitor of the true class is repaired per visit.
        c_star, sc = _best_competitor_stable(scores_i, y)
        sy = scores_i[y]
        slack = (sc + int(CFG.margin)) - sy
        if slack <= 0:
            # Margin already satisfied against the best competitor.
            continue

        # Step size: slack spread over the item's features (rounded up),
        # plus the configured overshoot.
        feats = FI[i]
        denom = max(1, len(feats))
        steps = int(math.ceil(slack / denom))
        if steps <= 0:
            steps = 1
        steps += int(CFG.overshoot_delta)

        if slack > ep_max_slack_seen:
            ep_max_slack_seen = int(slack)

        # paired weight + cache updates, propagated through postings (compiled)
        for f in feats:
            W[(y, f)] = int(W.get((y, f), 0) + steps)
            W[(c_star, f)] = int(W.get((c_star, f), 0) - steps)
            # Every item containing f sees the same change in its cached scores.
            for j in TAB_post[f]:
                S[j][y] += steps
                S[j][c_star] -= steps

        ep_updates += 1
        updates_total += 1
        if updates_total >= int(CFG.max_updates):
            break

    viol_after, _ = _count_violations_from_cache(S, finite_labels, int(CFG.margin))
    probe_acc = _accuracy(probe_items)

    # Strict improvement only: the earliest epoch wins ties, and a NaN
    # accuracy (empty probe set) never becomes "best".
    if probe_acc > float(best["probe_acc"]):
        best = {"epoch": int(epoch), "probe_acc": float(probe_acc), "weights_hash": None}
        best_W = dict(W)

    epoch_rows.append({
        "epoch": int(epoch),
        "updates": int(ep_updates),
        "viol_before": int(viol_before),
        "viol_after": int(viol_after),
        "max_slack_seen": int(ep_max_slack_seen),
        "probe_acc": float(probe_acc),
        "seconds": float(time.time() - ep_start),  # diagnostics only (non-certificate)
        "w_keys": int(sum(1 for v in W.values() if v != 0)),
    })

    # Stop when the update budget is spent or the finite table has no violations.
    if updates_total >= int(CFG.max_updates):
        break
    if viol_after == 0:
        break

def _hash_weights(Wd: Dict[Tuple[int, str], int]) -> str:
    """Return a SHA-256 hex digest identifying the non-zero weights in *Wd*.

    Zero-valued entries are ignored and the remaining
    ``(class, feature, weight)`` triples are sorted before hashing, so the
    digest does not depend on insertion order.
    """
    triples = sorted(
        (int(c), str(f), int(w)) for (c, f), w in Wd.items() if w != 0
    )
    return _sha256_hex(_stable_json(triples).encode("utf-8"))

# Hash of the best-probe snapshot (falls back to the final weights when no
# snapshot was taken) and hash of the final weights.
best["weights_hash"] = _hash_weights(best_W if best_W is not None else W)
weights_hash = _hash_weights(W)

# ============================================================
# Exact verification (definitional recomputation on finite table)
# ============================================================

def _finite_margin_check() -> dict:
    """Recompute all scores on the finite table from ``W`` and audit the margin.

    The cache ``S`` is not consulted. A pair (item i, class c != y_i) is a
    violation when ``score_c + margin > score_y``.

    Returns:
        dict with ``ok``, ``violations_total`` and ``violations_prefix``
        (details of at most the first 20 violations).
    """
    required = int(CFG.margin)
    total = 0
    examples = []

    for i, feats in enumerate(FI):
        y = finite_labels[i]
        scores = [sum(W.get((c, f), 0) for f in feats) for c in range(K)]
        sy = scores[y]
        for c, sc in enumerate(scores):
            if c == y:
                continue
            slack = (sc + required) - sy
            if slack <= 0:
                continue
            total += 1
            if len(examples) < 20:
                examples.append({
                    "i": int(i), "y": int(y), "c": int(c),
                    "sy": int(sy), "sc": int(sc), "slack": int(slack),
                })

    return {
        "ok": (total == 0),
        "violations_total": int(total),
        "violations_prefix": examples,
    }

def _cache_consistency_check() -> dict:
    """Check that every cached score ``S[i][c]`` equals its definition from ``W``.

    Returns:
        dict with ``ok``, ``mismatches_total`` and ``mismatches_prefix``
        (details of at most the first 20 mismatches).
    """
    total = 0
    examples = []
    for i, feats in enumerate(FI):
        for c in range(K):
            by_definition = sum(W.get((c, f), 0) for f in feats)
            if int(S[i][c]) == int(by_definition):
                continue
            total += 1
            if len(examples) < 20:
                examples.append({
                    "i": int(i), "c": int(c),
                    "cache": int(S[i][c]), "def": int(by_definition),
                })
    return {
        "ok": (total == 0),
        "mismatches_total": int(total),
        "mismatches_prefix": examples,
    }

# Run both audits; the run is OK only when the margin holds on every
# finite-table item and the score cache matches the weights.
FULL = _finite_margin_check()
CACHE = _cache_consistency_check()
OK = bool(FULL["ok"] and CACHE["ok"])

# ============================================================
# Deterministic certificate payload vs run record
# ============================================================

# Hash of the configuration values echoed below (caps max_train/max_test are
# not part of this hash).
cfg_hash = _hash_obj({
    "seed": int(CFG.seed),
    "run_mode": str(CFG.run_mode),
    "split": str(CFG.split),
    "canonicalize": bool(CFG.canonicalize),
    "bin_threshold": int(CFG.bin_threshold),
    "finite_n": int(CFG.finite_n),
    "cap_unary": int(CFG.cap_unary),
    "margin": int(CFG.margin),
    "overshoot_delta": int(CFG.overshoot_delta),
    "max_epochs": int(CFG.max_epochs),
    "max_updates": int(CFG.max_updates),
    "probe_per_class": int(CFG.probe_per_class),
})

# Certificate payload: contains no timestamps or timings, only the lock,
# hashes, finite-table description, weight summary and audit results.
CERT = {
    "schema": "typedrepair: compiled discrete margin repair (integer weights; postings + cache)",
    "ok": bool(OK),
    "lock": SEM_LOCK,
    "lock_hash": LOCK_HASH,
    "cfg_hash": cfg_hash,
    "dataset_provenance": prov,          # deterministic; strong hashes optional but still deterministic
    "prov_hash": PROV_HASH,
    "finite_table": {
        "finite_table_id": finite_table_id,
        "finite_n": int(CFG.finite_n),
        "src_positions_in_train_stream": list(map(int, finite_srcpos)),
        "labels": list(map(int, finite_labels)),
        "class_hist": dict(sorted({int(k): int(v) for k, v in per_class.items()}.items())),
        "fi_hash": FI_hash,
        "table_fingerprint": table_fp,
    },
    "features": {
        "schema": "bias:1 + px:<idx> for pixels with value > bin_threshold; optional cap by intensity",
        "cap_unary": int(CFG.cap_unary),
        "feat_vocab_size": int(len(feat_vocab)),
        "postings_size": int(len(TAB_post)),
    },
    "weights": {
        "weights_hash": weights_hash,
        "nonzero_keys": int(sum(1 for v in W.values() if v != 0)),
    },
    "verification": {
        "finite_full_margin_check": FULL,
        "cache_consistency_check": CACHE,
        "note": "Exact audit on finite witness table; cache check enforces definitional score equality on the table.",
    },
}

# SHA-256 of the stable JSON encoding of the whole certificate.
CERT_HASH = _hash_obj(CERT)

# Non-certificate run record (may include time, probe curve, etc.)
# It is not covered by CERT_HASH.
RUN_RECORD = {
    "experiment": "logical_backpropagation_compiled_discrete_margin_repair_curve",
    "created_utc": _now_utc_iso(),
    "emnist": {
        "split": str(CFG.split),
        "n_classes": int(K),
        "mapping_file": str(mapping_file),
        "canonicalize": "upright" if bool(CFG.canonicalize) else "raw",
        "bin_threshold": int(CFG.bin_threshold),
    },
    "cfg": {
        "finite_n": int(CFG.finite_n),
        "cap_unary": int(CFG.cap_unary),
        "margin": int(CFG.margin),
        "overshoot_delta": int(CFG.overshoot_delta),
        "max_epochs": int(CFG.max_epochs),
        "max_updates": int(CFG.max_updates),
        "probe_per_class": int(CFG.probe_per_class),
        "seed": int(CFG.seed),
    },
    "epoch_curve": epoch_rows,  # includes seconds (non-deterministic) by design
    "best_probe": dict(best),   # diagnostic only (not certified)
}

# Certificate (with its own hash attached) bundled with the run record.
PAYLOAD = {
    "certificate": dict(CERT, certificate_hash=CERT_HASH),
    "run_record": RUN_RECORD,
}

# For phase logging / later cells
# (stored directly in PHASES; phase_print is not used, so nothing is printed here)
PHASES["logical_backpropagation_compiled_curve"] = PAYLOAD

# ============================================================
# Compact summary
# ============================================================

# Violations after the last epoch (None when no epoch ran) and the maximum
# possible number of violating (item, competitor) pairs.
final_v = epoch_rows[-1]["viol_after"] if epoch_rows else None
max_v = int(CFG.finite_n) * (int(K) - 1)

print("\nSummary (certificate-relevant):")
print(f"  EMNIST/{CFG.split} | finite_n={CFG.finite_n} | classes={K} | margin={CFG.margin} | cap_unary={CFG.cap_unary}")
print(f"  Violations: {final_v}/{max_v} (finite-table) | Cache: {'OK' if CACHE.get('ok') else 'FAIL'} | Finite-check: {'OK' if FULL.get('ok') else 'FAIL'}")
print(f"  Weights hash: {weights_hash}")
print(f"  Finite table id: {finite_table_id}")
print(f"  Provenance hash: {PROV_HASH}  (HASH_DATASET={'1' if strong_hash else '0'})")
print(f"  Lock hash: {LOCK_HASH}")
print(f"  CERTIFICATE HASH (citable): {CERT_HASH}")

# Optional diagnostic print
# (skipped when no epoch ever improved on the initial best, e.g. empty probe set)
best_epoch = best.get("epoch")
if best_epoch is not None:
    print("\nDiagnostics (non-certified):")
    print(f"  Best probe: epoch {best_epoch} | acc={best.get('probe_acc', float('nan')):.3f} | weights_hash(best)={best.get('weights_hash')}")

In [None]:
from pathlib import Path
import json
import hashlib
import time

def stable_json_bytes(obj) -> bytes:
    """Serialize *obj* to compact, key-sorted JSON encoded as UTF-8.

    NaN and Infinity are rejected so the output is always valid JSON.

    Raises:
        ValueError: if *obj* contains NaN or an infinity.
        TypeError: if *obj* contains a value ``json`` cannot serialize.
    """
    text = json.dumps(
        obj,
        sort_keys=True,
        separators=(",", ":"),
        ensure_ascii=False,
        allow_nan=False,  # without this, json emits the non-JSON tokens NaN/Infinity
    )
    return text.encode("utf-8")

def emit_artifact(record: dict, out_dir: Path) -> Path:
    """Write *record* as a content-addressed JSON artifact and return its path.

    The file is named ``artifact_<sha256>.json`` where the hash is taken over
    the serialized bytes, and ``LATEST.txt`` in the same directory is updated
    to hold that file name. The hash and path are also printed.

    Raises:
        AssertionError: if the artifact is missing or has the wrong size
            after writing.
    """
    out_dir = out_dir.resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    payload = stable_json_bytes(record)
    h = hashlib.sha256(payload).hexdigest()

    # Deterministic filename, easy to locate
    out_path = out_dir / f"artifact_{h}.json"
    out_path.write_bytes(payload)

    # Explicit raises instead of `assert` statements: asserts are removed
    # when Python runs with -O, which would silently disable these checks.
    if not out_path.exists():
        raise AssertionError(f"Artifact write failed: {out_path}")
    if out_path.stat().st_size != len(payload):
        raise AssertionError("Short write detected.")

    # Update the convenience pointer only after the artifact is verified, so
    # LATEST.txt never names a missing or truncated file.
    (out_dir / "LATEST.txt").write_text(str(out_path.name) + "\n", encoding="utf-8")

    print("Artifact SHA-256:", h)
    print("Artifact path:", str(out_path))
    return out_path

# Example: build the record you are already hashing/printing
# NOTE(review): every value below is a hard-coded placeholder; nothing is read
# from the Cell 3 results (CERT, PAYLOAD, weights_hash, ...). Running this
# cell as-is writes an artifact that does not describe the actual run.
certificate_record = {
    # Wall-clock timestamp: it makes the artifact hash differ on every run.
    "created_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "lock": {
        "bin_threshold": 127,
        "cap_unary": 128,
        "finite_n": 256,
        # include your dataset/schema fingerprints here
    },
    "result": {
        "violations": 0,
        "cache_ok": True,
        "probe_acc": 0.710,
        "epoch_best": 6,
    },
    # include weights + any other structures you certify (prefer a compressed encoding if large)
    # "weights": ...
}

# Writes artifacts/artifact_<sha256>.json relative to the current working directory.
emit_artifact(certificate_record, out_dir=Path("artifacts"))