In [32]:
import json
from typing import Dict, Any, Tuple, List, Optional
import numpy as np


def ece_function(
    jsonl_path: str,
    n_bins: int = 4,
    conf_key: str = "conf",
    pred_key: str = "pred",
    label_key: str = "label",
    skip_invalid: bool = True,
    return_details: bool = False,
) -> float | Tuple[float, Dict[str, Any]]:
    """
    Compute Expected Calibration Error (ECE) from a JSONL file where each line is a JSON object like:
        {"label": 1, "pred": 5, "conf": 0.143}

    This computes the standard top-1 ECE:
      - confidence = `conf` (assumed to be max softmax probability for the predicted class)
      - correctness = (pred == label)
      - uniform bins over [0, 1]

    Args:
        jsonl_path: Path to the .jsonl file.
        n_bins: Number of uniform confidence bins (paper you referenced uses 4).
        conf_key: JSON key for confidence value.
        pred_key: JSON key for predicted class.
        label_key: JSON key for true label.
        skip_invalid: If True, skip malformed lines / missing keys / out-of-range conf.
                      If False, raise a ValueError on the first invalid line.
        return_details: If True, also return per-bin stats for debugging/plotting.

    Returns:
        If return_details is False:
            ece (float)
        If return_details is True:
            (ece, details_dict)
    """
    preds: List[int] = []
    labels: List[int] = []
    confs: List[float] = []

    def _handle_invalid(msg: str, line_no: int, line: str) -> None:
        if skip_invalid:
            return
        raise ValueError(f"[Line {line_no}] {msg}. Line: {line.strip()[:200]}")

    # ---- Load JSONL ----
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue

            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                _handle_invalid("Invalid JSON", line_no, line)
                continue

            if not all(k in obj for k in (conf_key, pred_key, label_key)):
                _handle_invalid(
                    f"Missing keys (need: {conf_key}, {pred_key}, {label_key})",
                    line_no,
                    line,
                )
                continue

            try:
                conf = float(obj[conf_key])
                pred = int(obj[pred_key])
                label = int(obj[label_key])
            except (TypeError, ValueError):
                _handle_invalid("Could not parse conf/pred/label into float/int", line_no, line)
                continue

            # Confidence should be in [0, 1]
            if not (0.0 <= conf <= 1.0):
                _handle_invalid("Confidence out of [0, 1] range", line_no, line)
                continue

            confs.append(conf)
            preds.append(pred)
            labels.append(label)

    if len(confs) == 0:
        raise ValueError("No valid rows found. Check file path and JSON keys/format.")

    confs_np = np.asarray(confs, dtype=np.float64)
    preds_np = np.asarray(preds, dtype=np.int64)
    labels_np = np.asarray(labels, dtype=np.int64)
    correct_np = (preds_np == labels_np)

    # ---- Compute ECE ----
    # Uniform bins over [0, 1]. We'll use (lo, hi] except the first bin includes 0.
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0

    bin_counts = np.zeros(n_bins, dtype=np.int64)
    bin_acc = np.zeros(n_bins, dtype=np.float64)
    bin_conf = np.zeros(n_bins, dtype=np.float64)

    for i in range(n_bins):
        lo, hi = bin_edges[i], bin_edges[i + 1]

        if i == 0:
            in_bin = (confs_np >= lo) & (confs_np <= hi)
        else:
            in_bin = (confs_np > lo) & (confs_np <= hi)

        count = int(in_bin.sum())
        bin_counts[i] = count

        if count == 0:
            continue

        acc = float(correct_np[in_bin].mean())
        conf = float(confs_np[in_bin].mean())

        bin_acc[i] = acc
        bin_conf[i] = conf

        weight = count / len(confs_np)
        ece += weight * abs(acc - conf)

    if not return_details:
        return float(ece)

    details = {
        "n": int(len(confs_np)),
        "n_bins": int(n_bins),
        "bin_edges": bin_edges.tolist(),
        "bin_counts": bin_counts.tolist(),
        "bin_accuracy": bin_acc.tolist(),
        "bin_confidence": bin_conf.tolist(),
    }
    return float(ece), details


# Example:
# ece = ece_function("preds.jsonl", n_bins=4)
# print("ECE:", ece)
#
# ece, details = ece_function("preds.jsonl", n_bins=4, return_details=True)
# print("ECE:", ece)
# print(details)

In [33]:
import os
from pathlib import Path

def download_preds_artifacts(run, out_dir="artifacts", skip_existing=False):
    """
    Download all artifacts from a W&B run whose name starts with 'preds'.

    Each matching artifact is downloaded into ``out_dir/<name>``, where
    ``<name>`` is the artifact name with ':' replaced by '-'.

    Args:
        run: Object exposing ``logged_artifacts()``; each yielded artifact must
            provide a ``name`` attribute and a ``download(root=...)`` method.
        out_dir: Parent directory for the downloads (created if missing).
        skip_existing: If True, an artifact whose target directory already
            exists is skipped and ``download`` is not called for it. The
            default (False) calls ``download`` for every matching artifact,
            even when its target directory is already present.

    Returns:
        None.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    for artifact in run.logged_artifacts():
        # ':' separates name and version; swap it so the name is a safe folder name.
        name = artifact.name.replace(':', '-')

        if not name.startswith("preds"):
            continue

        target_dir = out_dir / name

        if skip_existing and target_dir.exists():
            print(f"Skipping existing artifact: {name}")
            continue

        print(f"Downloading artifact: {name}")
        artifact.download(root=str(target_dir))


In [34]:
import json
import re
from pathlib import Path

def _first_record_has_key(jsonl_path, key):
    """Return True if the first parseable JSON object in `jsonl_path` contains `key`."""
    with jsonl_path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                # A corrupt line should not abort the whole directory walk.
                continue
            if not isinstance(obj, dict):
                # Bare scalars/arrays are not records; `key in obj` could raise on them.
                continue
            return key in obj
    # Empty file, or no line held a JSON object.
    return False


def compute_ece_per_set_once(parent_dir, ece_function, n_bins=4):
    """
    Walk `parent_dir` (containing downloaded W&B artifact folders).
    For each artifact folder, find .jsonl files and compute ECE IF:
      1) JSONL has a 'conf' field (checked on its first parseable JSON object)
      2) ECE not already computed for that set number

    Folders are visited in sorted order, so when several folders share a set
    number only the first one yielding a usable .jsonl is used. The set number
    is taken from a "set<digits>" token delimited by '-' or the start/end of
    the folder name; folders without such a token are ignored.

    Args:
        parent_dir: Directory whose immediate sub-directories are scanned.
        ece_function: Callable invoked as ``ece_function(path_str, n_bins=n_bins)``;
            its result is converted with ``float``.
        n_bins: Forwarded to `ece_function`.

    Returns: dict like {set_num: {"ece": float, "jsonl_path": str, "artifact_dir": str}}
    """
    parent_dir = Path(parent_dir)
    results = {}

    # Example folder name contains "...-set3-..." -> capture 3
    set_re = re.compile(r"(?:^|-)set(\d+)(?:-|$)")

    for artifact_dir in sorted(p for p in parent_dir.iterdir() if p.is_dir()):
        m = set_re.search(artifact_dir.name)
        if not m:
            continue

        set_num = int(m.group(1))
        if set_num in results:
            # requirement #2: already computed for this set number
            continue

        # requirement #1: first .jsonl (in sorted order) whose records carry 'conf'
        picked = next(
            (
                jsonl_path
                for jsonl_path in sorted(artifact_dir.rglob("*.jsonl"))
                if _first_record_has_key(jsonl_path, "conf")
            ),
            None,
        )
        if picked is None:
            # no usable jsonl with conf in this artifact dir
            continue

        ece = ece_function(str(picked), n_bins=n_bins)
        results[set_num] = {
            "ece": float(ece),
            "jsonl_path": str(picked),
            "artifact_dir": str(artifact_dir),
        }

    return results


In [None]:
import wandb
import pandas as pd

ENTITY = "jacoba-california-state-university-east-bay"
PROJECT = "humaid_ssl"
ARTIFACT_PARENT = "wandb-preds"
# Set numbers whose ECE is recorded per sweep (columns ece1, ece2, ece3).
SET_NUMS = (1, 2, 3)

api = wandb.Api()

project_obj = api.project(name=PROJECT, entity=ENTITY)
sweeps = project_obj.sweeps()

best_runs = []

for sweep in sweeps:
    try:
        best = sweep.best_run(order="dev_macro-F1")
        if not best:
            continue
    except Exception as e:
        print(f"Could not get best run for sweep {sweep.id} — skipping. Error: {e}")
        continue

    # Give every run its own download folder. compute_ece_per_set_once keeps the
    # first folder (in sorted order) per set number, so a shared parent directory
    # would let a later sweep pick up artifacts downloaded for an earlier one.
    run_artifact_dir = Path(ARTIFACT_PARENT) / best.id
    download_preds_artifacts(best, run_artifact_dir)
    ece_dict = compute_ece_per_set_once(run_artifact_dir, ece_function)

    # A run without a usable preds file for some set should not abort the whole
    # loop with a KeyError; record NaN and say so instead.
    missing_sets = [k for k in SET_NUMS if k not in ece_dict]
    if missing_sets:
        print(
            f"No ECE available for set(s) {missing_sets} in run {best.id} "
            f"(sweep {sweep.id}); recording NaN."
        )

    record = {
        "sweep_id": sweep.id,
        "sweep_name": getattr(sweep, "name", None),
        "best_run_id": best.id,
        "best_run_name": best.name,
        "summary": best.summary,
        "config": best.config,
    }
    for k in SET_NUMS:
        record[f"ece{k}"] = ece_dict[k]["ece"] if k in ece_dict else float("nan")
    best_runs.append(record)

df = pd.DataFrame(best_runs)


In [37]:
import pandas as pd
import re

# One row per sweep: its name plus the dev/test macro-F1 read from the run summary.
summaries = df["summary"]
df2 = pd.DataFrame(
    {
        "sweep_name": df["sweep_name"],
        "dev_macro-F1": summaries.apply(lambda summary: summary.get("dev_macro-F1")),
        "test_macro-F1": summaries.apply(lambda summary: summary.get("test_macro-F1")),
    }
)

sweep_name_str = df2["sweep_name"].str

# lbcl count parsed out of the sweep name (e.g., kerala_floods_2018_50lbcl → 50),
# stored as float so names without a match become NaN.
df2["lbcl_count"] = sweep_name_str.extract(r"_(\d+)lbcl", expand=False).astype(float)

# Disaster name = everything preceding _<number>lbcl
df2["disaster_name"] = sweep_name_str.extract(r"^(.*)_\d+lbcl", expand=False)

# Order rows by lbcl count, then by disaster name (both ascending).
df2 = df2.sort_values(by=["lbcl_count", "disaster_name"])


In [38]:
# Distinct lbcl values, ascending
lbcls = sorted(df2["lbcl_count"].unique())

# Map each lbcl value to its test macro-F1 scores, in df2's current row order
tables = {}
for lbcl in lbcls:
    rows_for_lbcl = df2["lbcl_count"] == lbcl
    tables[lbcl] = df2.loc[rows_for_lbcl, "test_macro-F1"].tolist()

# Print one block per lbcl value: a header line, one score per line, then a blank line
for lbcl_value, scores in tables.items():
    print(f"{lbcl_value}:")
    for score in scores:
        print(score)
    print()

5.0:
0.49733936820431807
0.5191025072017287
0.4586526049724043
0.4966738793228979
0.5905430216894625
0.4936308624500024
0.5419256023827606
0.5920581987871771
0.619976551536117
0.43018815932671367

10.0:
0.5592429270583196
0.5986065538148083
0.50437982285394
0.5218830168807028
0.6310426842947895
0.5895643006829325
0.5968836940064904
0.6163896344323987
0.6801174396152766
0.49110881680786383

25.0:
0.6512848089093568
0.5939190763698257
0.6395447126493671
0.5772570764098851
0.6635741437243766
0.6478561567792466
0.6131685982859593
0.6613639528348317
0.7128565840439821
0.564855271916366

50.0:
0.6526330159256859
0.6135023641401601
0.6092617177414079
0.5936654183611628
0.6939616083843202
0.6768169689963898
0.6475989843849889
0.6681890838371644
0.7671146722989439
0.583763307865972



In [47]:
import re

def parse_event_lbcl(name: str):
    """
    Extract event name and lbcl integer from sweep/run name.

    Example:
      'kerala_floods_2018_50lbcl' ->
          ('kerala_floods_2018', 50)

    Args:
        name: Sweep/run name expected to end in ``_<digits>lbcl``.

    Returns:
        ``(event, lbcl)`` where ``event`` is the text before the final
        ``_<digits>lbcl`` suffix and ``lbcl`` is that number as an int, or
        ``(None, None)`` when `name` is not a string (e.g. None or NaN) or
        does not end with such a suffix.
    """
    # Missing names reach this function as None/NaN; re.search would raise TypeError.
    if not isinstance(name, str):
        return None, None

    m = re.search(r"(.+)_([0-9]+)lbcl$", name)
    if not m:
        return None, None
    return m.group(1), int(m.group(2))


# Assemble one record per sweep: parsed event / lbcl plus the three per-set ECE values
rows = []

ece_columns = zip(df["sweep_name"], df["ece1"], df["ece2"], df["ece3"])
for sweep_name, ece1, ece2, ece3 in ece_columns:
    event, lbcl = parse_event_lbcl(sweep_name)
    rows.append(
        dict(
            event=event,
            lbcl=lbcl,
            sweep_name=sweep_name,
            ece_set1=ece1,
            ece_set2=ece2,
            ece_set3=ece3,
        )
    )

# Sort by lbcl then event, renumber the index, and persist without the index column
ece_table = pd.DataFrame(rows)
ece_table = ece_table.sort_values(by=["lbcl", "event"]).reset_index(drop=True)

ece_table.to_csv("ece_by_sweep.csv", index=False)
