#  Data preprocessing
first iomt traffic dataset link: https://zenodo.org/records/8116338
second dataset ice dataset link: http://perception.inf.um.es/ICE-datasets/
third wustl dataset link: https://www.cse.wustl.edu/~jain/ehms/index.html

**Split the IoMT traffic dataset into two client datasets (to simulate more clients) without reordering any rows**

In [None]:
# Mount Drive
# Google Drive is mounted so the dataset and output paths under /content/drive resolve.
from google.colab import drive
drive.mount('/content/drive')

# === CONFIG ===
# INPUT_CSV is the flow CSV that gets split; the two client files are written
# into OUTPUT_DIR under the names CLIENT_A / CLIENT_B.
INPUT_CSV  = "/content/drive/MyDrive/fed_MID/data/iomt_trafficdata/IP-Based-Flows-Dataset.csv"
OUTPUT_DIR = "/content/drive/MyDrive/fed_MID/iomt_clients"
CLIENT_A   = "client_iomt_A.csv"
CLIENT_B   = "client_iomt_B.csv"
# ==============

import os, pathlib, hashlib
import numpy as np
import pandas as pd

# Create the output directory up front (no error if it already exists).
pathlib.Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
print("Reading:", INPUT_CSV)
# low_memory=False makes pandas read the file in one pass instead of in chunks.
df = pd.read_csv(INPUT_CSV, low_memory=False)
# Positional row id; the export step sorts on it so each client keeps file order.
df["_orig_idx"] = np.arange(len(df))  # preserve original order for flow-pure splits

# --- normalise some columns (keep original names too) ---
# Lower-cased name -> original column name, so the lookups below are case-insensitive.
cols = {c.lower(): c for c in df.columns}
# First available column describing the traffic scenario; None when none exists.
traffic_col = cols.get("traffic") or cols.get("scenario") or cols.get("label")
service_col = cols.get("service")
proto_col   = cols.get("proto")

def lc_series(s):
    """Return *s* converted to strings, stripped of surrounding whitespace and lower-cased."""
    return s.astype(str).str.strip().str.lower()

# When a column is absent, fall back to an all-empty string Series aligned with df.
# The scalar "" is broadcast over the index; a one-element list here would raise a
# length-mismatch ValueError for any frame with more than one row.
t = lc_series(df[traffic_col]) if traffic_col else pd.Series("", index=df.index, dtype="object")
s = lc_series(df[service_col]) if service_col else pd.Series("", index=df.index, dtype="object")
p = lc_series(df[proto_col])   if proto_col   else pd.Series("", index=df.index, dtype="object")

# --- define families by scenario keywords (tuned to this dataset) ---
# Client A: scenarios whose text mentions HTTP/MQTT-style attacks or services.
A_kw = (
    r"apache\s*killer|rudy|slow\s*loris|slow\s*read|malaria|mqtt|http"
)
# Client B: ARP / CAM-table / scanning / DNS scenarios.
# The optional separator uses a NON-capturing group: a capturing group makes
# pandas' str.contains emit a "pattern has match groups" UserWarning.
B_kw = (
    r"\barp\b|arp\s*spoof|cam(?:\s|_|-)?table|net\s*scan|netscan|dns"
)

# "" means "not assigned yet"; each step below only fills rows that are still "".
assign = pd.Series("", index=df.index, dtype="object")
assign = np.where(t.str.contains(A_kw, regex=True), "A", assign)
assign = np.where((assign=="") & t.str.contains(B_kw, regex=True), "B", assign)

# --- fallback by service family (if scenario text was missing) ---
assign = np.where((assign=="") & s.isin(["http","mqtt"]), "A", assign)
assign = np.where((assign=="") & s.isin(["arp","dns"]), "B", assign)

# --- final fallback: stable hash on a 5-tuple (keeps clients disjoint, deterministic) ---
def stable_bit(a):
    """Return the lowest bit of the MD5 digest of the fields in *a*.

    Fields are stringified (missing values become "") and joined with "||".
    Not called by the statements below in this cell, which hash whole rows instead.
    """
    parts = ["" if pd.isna(x) else str(x) for x in a]
    digest = hashlib.md5("||".join(parts).encode("utf-8")).hexdigest()
    return int(digest, 16) & 1


def _md5_parity(text):
    """Return the lowest bit of the MD5 digest of *text*."""
    return int(hashlib.md5(text.encode()).hexdigest(), 16) & 1


# vectorised hash: apply on a dataframe of fields
# Missing address/port columns contribute the constant "" to the key.
five_tuple = {
    name: df.get(name, "")
    for name in ("id.orig_h", "id.orig_p", "id.resp_h", "id.resp_p")
}
five_tuple["proto"] = df.get(proto_col, s)  # fallback to service if proto missing
hash_fields = pd.DataFrame(five_tuple)

# One "||"-joined key per row -> one parity bit per row.
row_keys = hash_fields.astype(str).agg("||".join, axis=1)
bits = row_keys.apply(_md5_parity)

# Only rows still unassigned ("") take the hash-based client.
hashed_client = np.where(bits==0, "A", "B")
assign = np.where(assign=="", hashed_client, assign)

df["__client__"] = assign

# --- binary attack indicator (robust to various label schemes) ---
def infer_binary_attack(frame):
    """Derive a 0/1 (int8) attack indicator for every row of *frame*.

    The first column found among is_attack/Is_Attack/label/Label/class/Class is
    used: numeric columns map non-zero to 1 (missing -> 0); text columns map to 0
    when the value contains "benign", "normal" or "clean", else 1.
    Without such a column the module-level ``traffic_col`` is used the same way,
    and if that is unset every row is labelled 0.
    """
    benign_rx = "benign|normal|clean"
    known_names = ("is_attack", "Is_Attack", "label", "Label", "class", "Class")

    # prefer is_attack if available
    chosen = next((name for name in known_names if name in frame.columns), None)
    if chosen is not None:
        values = frame[chosen]
        if values.dtype.kind in "biufc":
            numeric = pd.to_numeric(values, errors="coerce").fillna(0)
            return (numeric != 0).astype("int8")
        text = values.astype(str).str.lower().str.strip()
        return (~text.str.contains(benign_rx)).astype("int8")

    # fallback to traffic scenario
    if traffic_col:
        scenario = lc_series(frame[traffic_col])
        return (~scenario.str.contains(benign_rx)).astype("int8")

    # last resort: everything benign
    return pd.Series(np.zeros(len(frame), dtype="int8"), index=frame.index)

# Attach the inferred 0/1 attack label to every row before exporting.
df["Label_binary"] = infer_binary_attack(df)

# --- write clients in original file order (flow-pure) ---
# Rows are selected by client and sorted on _orig_idx; only the helper column
# __client__ is dropped, so _orig_idx and Label_binary remain in the exported CSVs.
dfA = df.loc[df["__client__"]=="A"].sort_values("_orig_idx").drop(columns=["__client__"])
dfB = df.loc[df["__client__"]=="B"].sort_values("_orig_idx").drop(columns=["__client__"])

outA = os.path.join(OUTPUT_DIR, CLIENT_A)
outB = os.path.join(OUTPUT_DIR, CLIENT_B)
# index=False: the pandas index is not written as a column.
dfA.to_csv(outA, index=False)
dfB.to_csv(outB, index=False)

def summarize(name, frame):
    """Print row count, attack count and attack percentage of *frame* under *name*.

    Attacks are the sum of the Label_binary column (0 when the column is absent);
    the percentage is 0.0 for an empty frame.
    """
    n = len(frame)
    pos = 0
    if "Label_binary" in frame.columns:
        pos = int(frame["Label_binary"].sum())
    pr = 100.0 * pos / n if n else 0.0
    print(f"{name}: rows={n:,} | attacks={pos:,} ({pr:.2f}%)")

# Report size and attack share of each exported client, then the output paths.
print("\n=== Split summary ===")
summarize("Client A (IoMT_A)", dfA)
summarize("Client B (IoMT_B)", dfB)
print("Saved:", outA, "and", outB)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Reading: /content/drive/MyDrive/fed_MID/data/iomt_trafficdata/IP-Based-Flows-Dataset.csv


Converting the ICE ransomware dataset from binetflow to CSV

In [None]:


# Directory holding the ICE *.binetflow files, and the single CSV they are merged into.
# NOTE(review): this cell uses pd, os and pathlib without importing them — presumably
# it relies on the imports of the previous cell; confirm before running standalone.
ICE_DIR   = "/content/drive/MyDrive/fed_MID/data/ICEdataset"
OUT_CSV   = "/content/drive/MyDrive/fed_MID/iomt_clients/client_ICE.csv"


def read_binetflow_with_rowid(path, file_id):
    """Load one comma-separated binetflow file and tag each row with its origin.

    Lines starting with "#" are treated as comments, blank and malformed lines
    are skipped, and header names are stripped of surrounding whitespace.
    Added columns: ``scenario`` (lower-cased file stem), ``__file_id__``
    (*file_id*) and ``__row_in_file__`` (0-based row position in the file).
    """
    frame = pd.read_csv(
        path,
        sep=",",
        comment="#",
        engine="python",
        on_bad_lines="skip",
        skip_blank_lines=True,
    )
    frame.columns = [header.strip() for header in frame.columns]

    scenario_name = pathlib.Path(path).stem.lower()
    frame["scenario"] = scenario_name
    frame["__file_id__"] = file_id
    # positional counter keeps the original within-file order recoverable
    frame["__row_in_file__"] = range(len(frame))
    return frame

# define a deterministic file order (edit if you prefer a different sequence)
files = [
    os.path.join(ICE_DIR, "clean.binetflow"),
    os.path.join(ICE_DIR, "wannacry.binetflow"),
    os.path.join(ICE_DIR, "petya.binetflow"),
    os.path.join(ICE_DIR, "badrabbit.binetflow"),
    os.path.join(ICE_DIR, "powerghost.binetflow"),
]
# file_id is the position in `files`; rows are concatenated in that order.
parts = [read_binetflow_with_rowid(p, i) for i, p in enumerate(files)]
df = pd.concat(parts, ignore_index=True)

# standardize is_attack without touching order
label_col = next((c for c in ["Label","label","is_attack","IsAttack"] if c in df.columns), None)
if label_col:
    raw = df[label_col]
    if raw.dtype.kind in "biuf":
        # Numeric/boolean labels (e.g. an existing 0/1 is_attack column): non-zero
        # means attack. Stringifying them would never equal "normal" and would
        # mark every row as an attack.
        df["is_attack"] = (pd.to_numeric(raw, errors="coerce").fillna(0) != 0).astype(int)
    else:
        # Text labels: only the exact value "normal" (case/space-insensitive) is benign.
        s = raw.astype(str).str.strip().str.lower()
        df["is_attack"] = (~s.str.contains("^normal$", na=False)).astype(int)
else:
    # No label column: fall back to the file-derived scenario name.
    df["is_attack"] = (~df["scenario"].str.contains("clean", na=False)).astype(int)

df.to_csv(OUT_CSV, index=False)
print("Saved:", OUT_CSV)




The pipeline standardises raw IoMT datasets for federated learning. It:

Loads client CSVs and infers binary labels.

Cleans features (removes IDs/biometrics, fixes missing/infinite values).

Splits into train/val/test with balance checks.

Normalises, scales, and windows the data for time-series learning.

Saves processed arrays, metadata, and validation reports for consistency.

In [None]:
# Cell 1 — Config
# Global settings for the preprocessing pipeline cells that follow.
from google.colab import drive
drive.mount('/content/drive')

# Input CSVs
ICE_CSV   = "/content/drive/MyDrive/fed_MID/iomt_clients/client_ICE.csv"
IOMT_B_CSV  = "/content/drive/MyDrive/fed_MID/iomt_clients/client_iomt_B.csv"
WUSTL_CSV = "/content/drive/MyDrive/fed_MID/iomt_clients/client_WUSTL.csv"
IOMT_A_CSV  = "/content/drive/MyDrive/fed_MID/iomt_clients/client_iomt_A.csv"
# Output directory for current dataset (change per run)
OUT_ICE     = "/content/drive/MyDrive/fed_MID/DATA_CLIENT_NPY/ICE"
OUT_IOMT_A     = "/content/drive/MyDrive/fed_MID/DATA_CLIENT_NPY/IOMT_A"
OUT_IOMT_B     = "/content/drive/MyDrive/fed_MID/DATA_CLIENT_NPY/IOMT_B"
OUT_WUSTL     = "/content/drive/MyDrive/fed_MID/DATA_CLIENT_NPY/WUSTL"
DATASET_NAME = "IOMT_B"




import os, pathlib, numpy as np, random
# Seed both NumPy's and Python's global RNGs.
np.random.seed(42); random.seed(42)

# Parameters
# W and HORIZONS are read as globals by make_splits_chunked to compute the
# minimum number of rows a split needs (W + max(HORIZONS) + 1).
W = 50
HORIZONS = (1,5,10)
STRIDE = 1                 # keep 1 to guarantee windows
# Default clipping quantiles of clip_and_scale.
CLIP_Q_LOW, CLIP_Q_HIGH = 0.005, 0.995
SAVE_AS_MEMMAP = True      # writes .npy (not .npz)

# --- splitting config ---
# NOTE(review): the code consuming SPLIT_MODE / CHUNK_* / POS_EXTREME / *_FRAC is not
# visible in this cell; presumably they are passed to make_splits_chunked — confirm.
SPLIT_MODE = "AUTO"         # "CONTIG", "CHUNKED", or "AUTO"
CHUNK_ROWS = 20_000        # size of contiguous chunks for chunked split
MIN_CHUNK_ROWS = 5_000
MAX_CHUNK_RETRIES = 3
POS_EXTREME = 0.98           # if val/test pre-window pos-rate >= 0.98 (or <= 0.02) we retry
VAL_FRAC   = 0.05
TEST_FRAC  = 0.15
# Guards
assert all(h > 0 for h in HORIZONS), "All horizons must be > 0"
assert STRIDE >= 1, "STRIDE must be >= 1"
# NOTE(review): only OUT_ICE is created here although DATASET_NAME is "IOMT_B";
# process_client also creates its own out_dir, so this may be a leftover.
pathlib.Path(OUT_ICE).mkdir(parents=True, exist_ok=True)

# Toggle EDA
RUN_EDA = True


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


The cell below defines helper functions for robust CSV reading, automatic inference of binary attack labels from diverse dataset columns, and addition of base metadata (row indices, file name, unique IDs). The label inference logic harmonises heterogeneous sources by mapping different label formats into a consistent Label_binary (0/1).

In [None]:
#cell2
import pandas as pd
import hashlib
import numpy as np
import os

def read_csv_robust(path):
    """Read the CSV at *path* with the Python parser, skipping malformed lines."""
    parser_options = {"engine": "python", "on_bad_lines": "skip"}
    return pd.read_csv(path, **parser_options)

def infer_binary_label(df):
    """Infer a 0/1 label for each row and report which column it came from.

    Column names are matched case-insensitively against a fixed preference list
    (is_attack, label_binary, label, class, attack, attack category,
    attack_category); the first match wins. Numeric columns map non-zero to 1
    (missing -> 0); text columns map to 0 only when the whole stripped,
    lower-cased value is "benign", "normal" or "clean". Failing that, a
    ``scenario`` column is used (0 when it contains "clean" or "normal").

    Returns:
        (labels, source): an int8 Series and the name of the column used.

    Raises:
        ValueError: if no usable column exists.
    """
    by_lower = {}
    for original in df.columns:
        by_lower[original.lower()] = original

    preference = ("is_attack", "label_binary", "label", "class", "attack",
                  "attack category", "attack_category")
    source = next((by_lower[key] for key in preference if key in by_lower), None)

    if source is not None:
        raw = df[source]
        if raw.dtype.kind in "biufc":
            is_attack = raw.fillna(0).astype(float) != 0
            return is_attack.astype("int8"), source
        text = raw.astype(str).str.strip().str.lower()
        is_attack = ~(text.str.fullmatch("benign|normal|clean"))
        return is_attack.astype("int8"), source

    if "scenario" not in df.columns:
        raise ValueError("Could not infer binary label column.")
    scenario = df["scenario"].astype(str).str.lower()
    is_attack = ~scenario.str.contains("clean|normal")
    return is_attack.astype("int8"), "scenario"

def add_base_meta(df, src_path):
    """Return a copy of *df* with bookkeeping columns added.

    Added columns: ``_orig_idx`` (0-based row position), ``Label_raw_file``
    (base name of *src_path*), ``row_uid`` ("<file>#<position>") and
    ``Split_hint`` (constant "all"). The input frame is not modified.
    """
    out = df.copy()
    file_name = os.path.basename(src_path)
    out["_orig_idx"] = np.arange(len(out))       # used for positional split
    out["Label_raw_file"] = file_name
    file_part = out["Label_raw_file"].astype(str)
    position_part = out["_orig_idx"].astype(str)
    out["row_uid"] = file_part + "#" + position_part
    out["Split_hint"] = "all"
    return out


Data integrity checks: The cell below audits duplicates and missing values without dropping any rows. In sequential traffic data, row order is critical because models learn from sliding windows over consecutive flows. Removing duplicates could disrupt temporal continuity and harm sequence learning. Instead, duplicates are hashed and audited to detect possible leakage across splits, while preserving the full flow order for training.

In [None]:
#cell3
def numeric_coerce_cols(df, names):
    """Convert, in place, every column of *df* listed in *names* to numeric.

    Names that are not columns of *df* are ignored; values that cannot be
    parsed become NaN. Returns None.
    """
    present = [name for name in names if name in df.columns]
    for name in present:
        df[name] = pd.to_numeric(df[name], errors="coerce")

def content_hash_row(vals):
    """Return the SHA-1 hex digest of the values in *vals*, rendered as text.

    Python/NumPy ints and floats are formatted with ``.10g`` so equal numbers
    hash alike regardless of their numeric type; everything else uses ``str``.
    The rendered parts are joined with "|" and encoded as UTF-8.
    """
    parts = []
    for v in vals:
        if isinstance(v, (float, np.floating, int, np.integer)):
            try:
                parts.append(f"{float(v):.10g}")
            except (OverflowError, ValueError, TypeError):
                # e.g. an int too large for a float: fall back to its plain text.
                # Only conversion errors are caught, so KeyboardInterrupt and
                # SystemExit are no longer swallowed as with a bare except.
                parts.append(str(v))
        else:
            parts.append(str(v))
    return hashlib.sha1("|".join(parts).encode("utf-8")).hexdigest()

def audit_duplicates_no_drop(df, feature_cols, label_col="Label_binary", audit_path=None):
    """
    Hash every row over feature_cols and summarise rows sharing a hash.
    No row is ever removed.

    Returns a 3-tuple:
      - a copy of df with an added content_hash column,
      - 0 (the number of rows removed, always zero),
      - a per-hash table with total_rows, nlabels (distinct label values),
        pos_count (labels equal to 1) and neg_count (total_rows - pos_count).
    If audit_path is given the table is also written there as CSV; a failed
    write only prints a warning.
    """
    def _row_hash(row):
        return content_hash_row(row.values)

    def _n_distinct(labels):
        return labels.nunique()

    def _n_positive(labels):
        return (labels.astype(int) == 1).sum()

    hashed = df.assign(content_hash=df[feature_cols].apply(_row_hash, axis=1))

    labels_by_hash = hashed.groupby("content_hash")[label_col]
    stats = labels_by_hash.agg(
        total_rows="count",
        nlabels=_n_distinct,
        pos_count=_n_positive,
    ).reset_index()
    stats["neg_count"] = stats["total_rows"] - stats["pos_count"]

    if audit_path is not None:
        # best-effort: the audit file is a side product, not a requirement
        try:
            stats.to_csv(audit_path, index=False)
        except Exception as e:
            print(f"[warn] audit write failed: {e}")

    return hashed, 0, stats


def missing_inf_report(df):
    """Count missing and infinite values per numeric column of *df*.

    Object/string columns are first coerced with ``pd.to_numeric`` (unparseable
    entries count as missing). Boolean, complex and other non-numeric columns
    (categorical, datetime, ...) are left out of the report.

    Returns:
        dict mapping column name -> {"missing", "pos_inf", "neg_inf"} counts.
    """
    types = pd.api.types
    rep = {}
    for c in df.columns:
        s = df[c]
        if types.is_object_dtype(s.dtype) or isinstance(s.dtype, pd.StringDtype):
            s = pd.to_numeric(s, errors="coerce")
        # pandas dtype predicates are used instead of np.issubdtype, which raises
        # TypeError on extension dtypes (categorical, nullable ints, strings).
        if (types.is_bool_dtype(s.dtype) or types.is_complex_dtype(s.dtype)
                or not types.is_numeric_dtype(s.dtype)):
            continue
        # plain float array so the inf checks also work for nullable dtypes
        vals = s.astype("float64").to_numpy()
        rep[c] = {
            "missing": int(s.isna().sum()),
            "pos_inf": int(np.isposinf(vals).sum()),
            "neg_inf": int(np.isneginf(vals).sum()),
        }
    return rep


**Feature curation**
The cell below prevents leakage and reduces noise. Raw ports are bucketed into categories (well-known, registered, ephemeral). Identifiers (IPs, MACs, UIDs), timestamps, and label/meta fields are dropped to avoid shortcuts. Small categorical fields are one-hot encoded, and biometric columns are removed for privacy. Finally, a negative-value audit flags unexpected negatives in non-signed features.

In [None]:
#cell4
import re


def port_bucket_series(s):
    """Map port numbers to "well_known" / "registered" / "ephemeral" / "unknown".

    Values are parsed with ``pd.to_numeric``; 0-1023 is well_known, 1024-49151
    registered, 49152-65535 ephemeral. Anything else (unparseable, missing,
    out of range) is "unknown". Returns a categorical Series on s's index.
    """
    ports = pd.to_numeric(s, errors="coerce")
    in_range = [
        (ports >= 0) & (ports <= 1023),
        (ports >= 1024) & (ports <= 49151),
        (ports >= 49152) & (ports <= 65535),
    ]
    names = ["well_known", "registered", "ephemeral"]
    # the three ranges are disjoint, so the order of the conditions is irrelevant
    buckets = np.select(in_range, names, default="unknown").astype(object)
    return pd.Series(buckets, index=s.index, dtype="category")

def apply_port_flags(df):
    """Replace raw port columns with int8 bucket indicator columns.

    A column counts as a port column when its name ends (case-insensitively) in
    sport, dport, id.orig_p or id.resp_p, either as the whole name or after a
    "_" or ".". For each one, three columns ``<col>_bucket_<bucket>`` are added
    (well_known / registered / ephemeral) and the raw column is dropped.
    Works on a copy; the input frame is left untouched.
    """
    out = df.copy()
    port_name_rx = re.compile(r"(?:^|[_\.])(sport|dport|id\.orig_p|id\.resp_p)$", re.I)
    port_cols = [name for name in out.columns if port_name_rx.search(name)]

    for name in port_cols:
        buckets = port_bucket_series(out[name])
        for label in ("well_known", "registered", "ephemeral"):
            out[f"{name}_bucket_{label}"] = (buckets == label).astype("int8")

    out.drop(columns=port_cols, inplace=True, errors="ignore")
    return out

# safer ID-like pattern: DO NOT include 'ip' (keeps resp_ip_bytes etc.)
ID_LIKE_PAT = re.compile(r"(mac|addr|address|uuid|uid|guid|session|ssid|hash|md5|sha1|patient|nhs|ssn|room|bed|ward)", re.I)
# Time-like column names. "ts" must stand alone as a name token (whole name, or
# delimited by "_" / "."): as a bare substring it also matched packet features
# such as TotPkts, SrcPkts, fwd_pkts_tot and *_pkts_per_sec, dropping them.
TIME_LIKE_PAT = re.compile(r"(time|timestamp|starttime|date|(?:^|[_\.])ts(?:$|[_\.]))", re.I)

# Bookkeeping columns that must never be used as model features.
NEVER_FEATURES = {"row_uid", "content_hash", "_orig_idx", "Label_raw_file", "Split_hint", "Label_raw"}

# Columns blocked by exact (case-sensitive) name.
EXPLICIT_BLOCK = {
    # network identifiers/leakage
    "id.orig_h", "id.resp_h", "id.orig_p", "id.resp_p",
    "SrcAddr", "DstAddr", "SrcMac", "DstMac", "Packet_num",
    # protocol/state-ish categoricals you don't want as raw IDs
    "proto", "Proto", "service", "conn_state", "history", "tunnel_parents",
    "State", "Flgs", "Dir",
    # labels/meta
    "traffic", "is_attack", "Label", "label", "class", "scenario", "Label_binary",
    "Attack Category", "Attack_Category", "StartTime",
}

# Biometric columns are matched against the lower-cased column name.
DROP_BIOMETRICS = True
BIOMETRIC_COLS = {
    "temp", "spo2", "pulse_rate", "sys", "dia", "heart_rate", "resp_rate", "st"
}
# Low-cardinality categoricals that curate_features one-hot encodes when present.
SMALL_ONE_HOTS = ["proto", "Proto", "service", "conn_state", "state", "State",
                  "dir", "Dir", "Flgs", "local_orig", "local_resp"]

def curate_features(df):
    """One-hot encode small categoricals and pick the numeric feature columns.

    Works on a copy of *df*. Columns listed in SMALL_ONE_HOTS are replaced by
    indicator columns. A column is then excluded when it is in NEVER_FEATURES or
    EXPLICIT_BLOCK, when its name matches TIME_LIKE_PAT or ID_LIKE_PAT, or (with
    DROP_BIOMETRICS) when its lower-cased name is in BIOMETRIC_COLS.

    Returns:
        (frame, feature_cols): the encoded frame and the names of the remaining
        numeric (non-boolean) columns, in frame order.

    Raises:
        AssertionError: if no numeric feature column remains.
    """
    df = df.copy()
    present_small = [c for c in SMALL_ONE_HOTS if c in df.columns]
    if present_small:
        # Explicit integer dtype: since pandas 2.0 get_dummies returns bool
        # columns, which the numeric filter below would silently discard.
        df = pd.get_dummies(df, columns=present_small, dummy_na=False, dtype="uint8")

    block = set()
    for c in df.columns:
        if c in NEVER_FEATURES or c in EXPLICIT_BLOCK:
            block.add(c)
        elif TIME_LIKE_PAT.search(c) or ID_LIKE_PAT.search(c):
            block.add(c)
        elif DROP_BIOMETRICS and c.lower() in BIOMETRIC_COLS:
            # drop WUSTL-like biometrics if requested
            block.add(c)

    # pandas dtype predicates also accept extension dtypes (string, categorical,
    # nullable ints), on which np.issubdtype raises TypeError.
    is_numeric = pd.api.types.is_numeric_dtype
    is_bool = pd.api.types.is_bool_dtype
    cand_num = [
        c for c in df.columns
        if c not in block and is_numeric(df[c].dtype) and not is_bool(df[c].dtype)
    ]
    assert len(cand_num) > 0, "No numeric feature columns remained after curation."
    return df, cand_num




# columns whose names match this may legitimately contain negatives
_NEG_ALLOW_RX = re.compile(r"(diff|delta|change|zscore|(^|[_\.])z($|[_\.])|resid|residual|signed|skew)", re.I)

def negative_value_audit(df, feature_cols):
    """
    List feature columns that contain negative values although their name does
    not suggest a signed quantity (names matching _NEG_ALLOW_RX are skipped).
    Values are parsed with pd.to_numeric; unparseable/missing entries are ignored.
    Returns {col: {"negatives": n, "total": N, "frac": n/N, "min": smallest value}}
    for the offending columns only; N counts the non-missing values.
    """
    report = {}
    audited = [name for name in feature_cols if not _NEG_ALLOW_RX.search(name)]
    for name in audited:
        values = pd.to_numeric(df[name], errors="coerce").dropna()
        if values.empty:
            continue
        below_zero = values < 0
        if not below_zero.any():
            continue
        report[name] = {
            "negatives": int(below_zero.sum()),
            "total": int(values.size),
            "frac": float(below_zero.mean()),
            "min": float(values[below_zero].min()),
        }
    return report



**Train/Val/Test splitting (temporal + balanced fallback)**
We first split by row order (flows in file sequence) to preserve temporal structure and avoid look-ahead. We also run a leak check that counts the content_hash values shared between splits, so exact duplicate flows that appear across splits are reported.
However, two of the four datasets had long contiguous regions of a single class, which can make val/test nearly all positives or all negatives. When that happens, we switch to a chunked, class-balanced split: we cut the dataset into contiguous chunks, compute each chunk’s positive rate, then distribute high/low chunks round-robin into train/val/test (with retries and smaller chunks) until val/test class balance is reasonable and each split is large enough to form at least one window.

In [None]:
#cell5
def temporal_split_indices(df, test_frac=0.15, val_frac=0.05):
    """Return positional (train, val, test) slices over the rows of *df*.

    The last ``test_frac`` of the rows is test, the ``val_frac`` before it is
    validation and everything earlier is train; boundaries are truncated to int.
    """
    total = len(df)
    val_start = int(total*(1-test_frac-val_frac))
    test_start = int(total*(1-test_frac))
    train = slice(0, val_start)
    val = slice(val_start, test_start)
    test = slice(test_start, total)
    return train, val, test

def make_splits(df):
    """Split *df* contiguously by row position and count cross-split duplicates.

    Returns:
        (idx, leak): ``idx`` maps "train"/"val"/"test" to positional slices from
        temporal_split_indices (default fractions). ``leak`` maps
        "<a>_x_<b>" to the number of distinct content_hash values present in
        both splits; it is empty when *df* has no content_hash column.
    """
    # Always split by row index (flow order)
    names = ("train", "val", "test")
    idx = dict(zip(names, temporal_split_indices(df)))

    # leak check
    leak = {}
    if "content_hash" in df.columns:
        hashes = {name: set(df.iloc[idx[name]]["content_hash"]) for name in names}
        for left, right in (("train", "val"), ("train", "test"), ("val", "test")):
            leak[f"{left}_x_{right}"] = len(hashes[left] & hashes[right])
    return idx, leak

#   chunk-stratified split (balanced + retries)

def _ranges_to_index(ranges):
    out = []
    for a, b in ranges:
        out.extend(range(a, b))
    return np.array(out, dtype=int)

def _split_sizes(n, val_frac, test_frac):
    n_val  = int(round(n * val_frac))
    n_test = int(round(n * test_frac))
    n_train = n - n_val - n_test
    return n_train, n_val, n_test

def _pos_rate_from_index(y_all, ix):
    arr = y_all[ix] if isinstance(ix, slice) else y_all[np.asarray(ix)]
    return float(arr.mean()) if arr.size else 0.0, int(arr.size)

def make_splits_chunked(df, val_frac=0.05, test_frac=0.15, chunk_rows=20_000, seed=42,
                        pos_extreme=0.98, min_chunk_rows=5_000, max_retries=3):
    """
    Balanced chunked split:
      - make contiguous chunks [a,b)
      - compute pos-rate per chunk
      - distribute chunks in round-robin across (train, val, test) from high & low ends
      - retry with smaller chunk_rows if any split is near-all-positive/negative

    Parameters:
      df             frame with a Label_binary column (and optionally content_hash)
      val_frac       target fraction of rows for validation
      test_frac      target fraction of rows for test
      chunk_rows     initial chunk length; halved on every retry
      seed           accepted for interface compatibility; the procedure is
                     deterministic and does not use it
      pos_extreme    val/test positive rate must lie in [1 - pos_extreme, pos_extreme]
      min_chunk_rows lower bound for chunk_rows when retrying
      max_retries    number of retries with smaller chunks

    Returns (idx, leak):
      idx   {"train"|"val"|"test": integer position array}; when no acceptable
            split is found the result of make_splits(df) is returned instead,
            whose idx values are slices
      leak  {"<a>_x_<b>": number of content_hash values shared by two splits},
            empty when df has no content_hash column

    Every chunk is assigned to exactly one split. Each split must also hold at
    least W + max(HORIZONS) + 1 rows (module-level settings).

    NOTE(review): the ranges of a split are kept in assignment order (by positive
    rate), not file order — confirm whether downstream windowing expects that.
    """
    y_all = df["Label_binary"].astype("int8").to_numpy()
    n = len(df)
    n_train, n_val, n_test = _split_sizes(n, val_frac, test_frac)
    targets = {"train": n_train, "val": n_val, "test": n_test}
    order_cycle = ["train", "val", "test"]

    for attempt in range(max_retries + 1):
        # build contiguous chunks [a, b) together with their positive rate
        per_chunk = []
        for a in range(0, n, chunk_rows):
            b = min(a + chunk_rows, n)
            seg = y_all[a:b]
            pr = float(seg.mean()) if seg.size else 0.0
            per_chunk.append((pr, a, b))

        # sort by pos-rate
        per_chunk.sort(key=lambda t: t[0])   # low .. high
        lo, hi = 0, len(per_chunk) - 1

        # assign in round-robin high/low fashion
        sel = {"train": [], "val": [], "test": []}
        used = {"train": 0, "val": 0, "test": 0}
        pick_hi = True
        cyc_idx = 0

        def _room(split):
            return max(0, targets[split] - used[split])

        while lo <= hi:
            # Consume the chunk and advance the SAME end it came from. Advancing
            # after toggling pick_hi moved the opposite pointer, which skipped
            # the lowest-rate chunk and assigned another chunk twice.
            if pick_hi:
                _, a, b = per_chunk[hi]
                hi -= 1
            else:
                _, a, b = per_chunk[lo]
                lo += 1
            pick_hi = not pick_hi
            span = b - a

            # try up to 3 different next-splits to place without overrunning target too much
            tried = 0
            placed = False
            while tried < 3 and not placed:
                split = order_cycle[cyc_idx % 3]
                cyc_idx += 1
                tried += 1
                room = _room(split)
                # allow small overshoot (<= half a chunk) to finish targets
                if room >= span or (room == 0 and used[split] < targets[split] + chunk_rows // 2):
                    sel[split].append((a, b))
                    used[split] += span
                    placed = True

            if not placed:
                # force into the split with max remaining room
                split = max(order_cycle, key=lambda s: _room(s))
                sel[split].append((a, b))
                used[split] += span

        # materialize indices
        idx = {k: _ranges_to_index(v) for k, v in sel.items()}

        # quick pre-window diagnostics
        pr_tr, n_tr = _pos_rate_from_index(y_all, idx["train"])
        pr_va, n_va = _pos_rate_from_index(y_all, idx["val"])
        pr_te, n_te = _pos_rate_from_index(y_all, idx["test"])
        print(f"[split] pre-window pos-rate  train={pr_tr:.3f} (n={n_tr})  "
              f"val={pr_va:.3f} (n={n_va})  test={pr_te:.3f} (n={n_te})")

        # sanity: each split must have at least one window
        min_win = W + max(HORIZONS) + 1
        ok_sizes = (n_tr >= min_win) and (n_va >= min_win) and (n_te >= min_win)
        rate_min, rate_max = (1 - pos_extreme), pos_extreme
        ok_balance = (rate_min <= pr_va <= rate_max) and (rate_min <= pr_te <= rate_max)

        if ok_sizes and ok_balance:
            # leak check via content_hash (optional)
            leak = {}
            if "content_hash" in df.columns:
                for a, b in [("train","val"), ("train","test"), ("val","test")]:
                    s = set(df.iloc[idx[a]]["content_hash"]).intersection(
                        set(df.iloc[idx[b]]["content_hash"])
                    )
                    leak[f"{a}_x_{b}"] = len(s)
            return idx, leak

        # retry with smaller chunk_rows
        if attempt < max_retries and chunk_rows > min_chunk_rows:
            chunk_rows = max(min_chunk_rows, chunk_rows // 2)
            print(f"[split] retrying with smaller CHUNK_ROWS={chunk_rows} "
                  f"(attempt {attempt+1}/{max_retries})")
        else:
            print("[split] Could not obtain balanced chunked splits within retries; "
                  "falling back to contiguous.")
            return make_splits(df)  # fallback



**Imputation, variance filtering, clipping, and scaling**
We handle feature preprocessing in three stages:

Imputation: Missing values are filled with training-set medians (applied consistently to val/test) to avoid information leakage.

Variance filtering: Features with zero variance in the training split are dropped, since they carry no predictive signal.

Clipping & scaling: Each feature is clipped to the [0.5%, 99.5%] quantile range (robust against outliers), then transformed with a RobustScaler to normalize distributions while reducing sensitivity to extreme values.

In [None]:
#cell6
from sklearn.preprocessing import RobustScaler

def train_only_impute(df_tr, df_va, df_te, feature_cols):
    """Fill missing feature values in all three frames with training medians.

    The frames are modified in place. Returns ``(medians, (df_tr, df_va, df_te))``
    where ``medians`` is the per-column median of the training frame.
    """
    med = df_tr[feature_cols].median()
    for part in (df_tr, df_va, df_te):
        part[feature_cols] = part[feature_cols].fillna(med)
    return med, (df_tr, df_va, df_te)

def drop_zero_var(df_tr, df_va, df_te, feature_cols):
    """Keep only features with more than one distinct non-missing training value.

    Returns ``(kept_names, (train, val, test))`` where each frame is restricted
    to the kept columns, in the order given by *feature_cols*.
    """
    keep = []
    for name in feature_cols:
        distinct = df_tr[name].nunique(dropna=True)
        if distinct > 1:
            keep.append(name)
    reduced = (df_tr[keep], df_va[keep], df_te[keep])
    return keep, reduced

def clip_and_scale(df_tr, df_va, df_te, feature_cols, qlow=CLIP_Q_LOW, qhigh=CLIP_Q_HIGH):
    """Clip features to training quantiles, then robust-scale them.

    The ``qlow`` / ``qhigh`` quantiles are computed per feature on the training
    frame and used to clip all three frames (on copies). A RobustScaler is then
    fitted on the clipped training features and applied to all three.

    Returns:
        (scaler, (Xtr, Xva, Xte), bounds): the fitted scaler, the scaled arrays
        and ``{"clip_low": ..., "clip_high": ...}`` dicts of the bounds used.

    Raises:
        AssertionError: if any scaled value is not finite.
    """
    train_features = df_tr[feature_cols]
    ql = train_features.quantile(qlow)
    qh = train_features.quantile(qhigh)

    # nudge degenerate bounds so the upper bound is strictly above the lower one
    collapsed = qh <= ql
    qh[collapsed] = ql[collapsed] + 1e-9

    low = ql[feature_cols].astype(float).to_numpy()
    high = qh[feature_cols].astype(float).to_numpy()

    def _clipped(frame):
        block = frame[feature_cols].to_numpy(dtype=np.float64, copy=True)
        # bounds broadcast across rows: one (low, high) pair per column
        np.clip(block, low, high, out=block)
        result = frame.copy()
        result[feature_cols] = block.astype(np.float32, copy=False)
        return result

    tr_c, va_c, te_c = (_clipped(part) for part in (df_tr, df_va, df_te))

    scaler = RobustScaler()
    Xtr = scaler.fit_transform(tr_c[feature_cols].values.astype("float32"))
    Xva = scaler.transform(va_c[feature_cols].values.astype("float32"))
    Xte = scaler.transform(te_c[feature_cols].values.astype("float32"))
    assert np.isfinite(Xtr).all() and np.isfinite(Xva).all() and np.isfinite(Xte).all()

    bounds = {"clip_low": ql.to_dict(), "clip_high": qh.to_dict()}
    return scaler, (Xtr, Xva, Xte), bounds


**Sliding-window labeling with earliness (K):**
We convert row-wise features into fixed-length sequences for sequence models. For each time index t, we take the past W rows as one window X[t-W:t]. For each prediction horizon h ∈ {1,5,10}, the label Y[h] is 1 if any positive event occurs in the next h steps (look-ahead), else 0. We also record K[h]: the step (1..h) of the first positive within that horizon, or −1 if none—this supports earliness/lead-time analysis. The loop advances by stride (default 1). Outputs: a tensor X of shape (num_windows, W, F), dicts Y[h] (int8) and K[h] (int16) aligned per window.

In [None]:
#cell7

def build_windows_by_rows_with_k(X_mat, y_vec, W=50, horizons=(1,5,10), stride=1):
    """Cut a row-wise feature matrix into look-back windows with look-ahead labels.

    For every position t (starting at W, advancing by *stride*, while
    ``t + max(horizons)`` rows exist) the window is ``X_mat[t-W:t]``. For each
    horizon h, the label is 1 if any of ``y_vec[t:t+h]`` is > 0, and K is the
    1-based offset of the first such entry (-1 when there is none).

    Returns:
        (X, Y, K): X is float32 with shape (num_windows, W, F); Y[h] is int8 and
        K[h] is int16, both with one entry per window.
    """
    n_rows, n_feat = X_mat.shape
    h_max = max(horizons)
    windows = []
    labels = {h: [] for h in horizons}
    first_hit = {h: [] for h in horizons}  # 1..h to first attack; -1 if none

    for t in range(W, n_rows - h_max + 1, stride):
        windows.append(X_mat[t-W:t])
        lookahead = y_vec[t:t+h_max]
        for h in horizons:
            hits = lookahead[:h] > 0
            if hits.any():
                labels[h].append(1)
                first_hit[h].append(np.argmax(hits) + 1)
            else:
                labels[h].append(0)
                first_hit[h].append(-1)

    if windows:
        X = np.stack(windows).astype("float32")
    else:
        X = np.empty((0, W, n_feat), np.float32)
    Y = {h: np.array(labels[h], dtype="int8") for h in horizons}
    K = {h: np.array(first_hit[h], dtype="int16") for h in horizons}
    return X, Y, K


**integrates all the earlier preprocessing steps into a single pipeline.**
The cell below runs the helper functions (label inference, feature curation, duplicate audit, splitting, imputation, scaling, and windowing) in sequence, adds guardrails for data balance and minimum window size, and finally saves processed datasets, audits, and metadata.

In [None]:
#cell8
import pickle, json

def process_client(csv_path, out_dir, dataset_name):
    """Run the full preprocessing pipeline for one client CSV and save the results.

    Steps, in order: read the CSV, infer binary labels, coerce numeric
    columns, audit missing/inf values, derive port flags, curate features,
    audit duplicates (no rows are dropped), split into train/val/test,
    impute, drop zero-variance features, clip and scale, build sliding
    windows, then write arrays, audits and metadata under ``out_dir``.

    Relies on notebook-level configuration (SPLIT_MODE, VAL_FRAC, TEST_FRAC,
    CHUNK_ROWS, POS_EXTREME, MIN_CHUNK_ROWS, MAX_CHUNK_RETRIES, W, HORIZONS,
    STRIDE, SAVE_AS_MEMMAP, RUN_EDA) and on helper functions defined in
    earlier cells.

    Args:
        csv_path: path of the client CSV to read.
        out_dir: directory that receives every output file (created if absent).
        dataset_name: label used in log lines, error messages and meta.json.

    Returns:
        None. All results are written to ``out_dir``.

    Raises:
        ValueError: if the inferred binary label contains values other than
            0/1, or if any split is too small to form a single window.
    """
    print(f"\n=== {dataset_name} ===\nReading {csv_path}")
    os.makedirs(out_dir, exist_ok=True)
    df = read_csv_robust(csv_path)
    df = add_base_meta(df, csv_path)

    # labels
    y_bin, y_raw_col = infer_binary_label(df)
    df["Label_binary"] = y_bin
    if y_raw_col not in ("Label_binary",):
        df["Label_raw"] = df[y_raw_col]
    # ensure binary labels are strictly 0/1
    if not set(np.unique(df["Label_binary"].dropna().values)).issubset({0, 1}):
        raise ValueError("Label_binary must be strictly 0/1.")
    # DO NOT reorder by time: keep original row order for flow windows

    # coerce common numerics (best-effort; ok if absent)
    numeric_coerce_cols(df, [
        # IoMT traffic / CICFlowMeter / Zeek
        "duration", "flow_duration", "Dur", "orig_bytes", "resp_bytes", "TotBytes", "tot_bytes",
        "TotPkts", "fwd_pkts_tot", "bwd_pkts_tot",
        "fwd_pkts_per_sec", "bwd_pkts_per_sec", "flow_pkts_per_sec",
        "fwd_header_size_tot", "bwd_header_size_tot", "down_up_ratio",
        "payload_bytes_per_second",
        # Argus/CLE/WUSTL variants
        "SrcBytes", "DstBytes", "SrcPkts", "Load", "SrcLoad", "Rate",
        "SIntPkt", "DIntPkt", "SrcGap", "DstGap", "Loss", "pLoss", "pSrcLoss", "pDstLoss",
        "sMaxPktSz", "dMaxPktSz", "sMinPktSz", "dMinPktSz", "Trans",
    ])

    # pre-clean audit
    pre = missing_inf_report(df)

    # ports → buckets; curate feature set; fix inf
    df = apply_port_flags(df)
    df, feature_cols = curate_features(df)
    for c in feature_cols:
        df[c] = df[c].replace([np.inf, -np.inf], np.nan)

    # audit-only (no row removal): flows can repeat; we keep all rows
    aud_path = os.path.join(out_dir, "duplicates_report.csv")
    df_dedup, removed, dup_audit = audit_duplicates_no_drop(
        df, feature_cols, label_col="Label_binary", audit_path=aud_path
    )

    # Guard both statistics: the audit frame may be None.
    if dup_audit is not None:
        n_groups = len(dup_audit)
        n_conflict = int((dup_audit['nlabels'] > 1).sum())
    else:
        n_groups, n_conflict = 0, 0
    print(f"[audit] duplicate groups (no removal): {n_groups} | "
          f"groups with mixed labels: {n_conflict} | rows removed: {removed}")
    print(f"[audit] wrote per-hash stats to {aud_path}")

    # ---- choose split mode (contiguous / chunked / auto) ----
    if SPLIT_MODE == "CONTIG":
        idx, leak = make_splits(df_dedup)
    elif SPLIT_MODE == "CHUNKED":
        idx, leak = make_splits_chunked(
            df_dedup, val_frac=VAL_FRAC, test_frac=TEST_FRAC,
            chunk_rows=CHUNK_ROWS, seed=42,
            pos_extreme=POS_EXTREME, min_chunk_rows=MIN_CHUNK_ROWS, max_retries=MAX_CHUNK_RETRIES
        )
    else:
        # AUTO: try CONTIG, switch to CHUNKED if extreme
        idx, leak = make_splits(df_dedup)
        y_all = df_dedup["Label_binary"].astype("int8").to_numpy()

        def _pr(ix):
            arr = y_all[ix] if isinstance(ix, slice) else y_all[np.asarray(ix)]
            return float(arr.mean()) if arr.size else 0.0

        if (_pr(idx["val"]) >= POS_EXTREME or _pr(idx["val"]) <= 1-POS_EXTREME or
                _pr(idx["test"]) >= POS_EXTREME or _pr(idx["test"]) <= 1-POS_EXTREME):
            print("[split] AUTO switching to CHUNKED due to extreme val/test balance.")
            idx, leak = make_splits_chunked(
                df_dedup, val_frac=VAL_FRAC, test_frac=TEST_FRAC,
                chunk_rows=CHUNK_ROWS, seed=42,
                pos_extreme=POS_EXTREME, min_chunk_rows=MIN_CHUNK_ROWS, max_retries=MAX_CHUNK_RETRIES
            )

    # ---- ensure each split has enough rows to form at least one window ----
    def _len_ix(ix):
        return (ix.stop - ix.start) if isinstance(ix, slice) else int(np.asarray(ix).size)

    min_win = W + max(HORIZONS) + 1  # minimum rows needed to form a window
    n_tr0, n_va0, n_te0 = _len_ix(idx["train"]), _len_ix(idx["val"]), _len_ix(idx["test"])

    if (n_tr0 < min_win) or (n_va0 < min_win) or (n_te0 < min_win):
        print(f"[split] Selected split produced too-small split(s) "
              f"(train={n_tr0}, val={n_va0}, test={n_te0}, need≥{min_win}). "
              f"Retrying CHUNKED with smaller CHUNK_ROWS...")

        retry_chunk = max(min_win, CHUNK_ROWS // 2)
        idx, leak = make_splits_chunked(
            df_dedup, val_frac=VAL_FRAC, test_frac=TEST_FRAC,
            chunk_rows=retry_chunk, seed=42
        )

        n_tr0, n_va0, n_te0 = _len_ix(idx["train"]), _len_ix(idx["val"]), _len_ix(idx["test"])
        if (n_tr0 < min_win) or (n_va0 < min_win) or (n_te0 < min_win):
            print(f"[split] Retry still too small (train={n_tr0}, val={n_va0}, test={n_te0}). "
                  f"Falling back to CONTIG for safety.")
            idx, leak = make_splits(df_dedup)

        # NOTE(review): this pos-rate summary only runs on the too-small retry
        # path; confirm whether it was meant to run for every split mode.
        y_all = df_dedup["Label_binary"].astype("int8").to_numpy()

        def _pos(ix):
            arr = y_all[ix] if isinstance(ix, slice) else y_all[np.asarray(ix)]
            return float(arr.mean()) if arr.size else 0.0, arr.size

        pr_tr, n_tr = _pos(idx["train"]); pr_va, n_va = _pos(idx["val"]); pr_te, n_te = _pos(idx["test"])
        print(f"[split] pre-window pos-rate  train={pr_tr:.3f} (n={n_tr})  val={pr_va:.3f} (n={n_va})  test={pr_te:.3f} (n={n_te})")

    # materialize splits (idx holds positional indices, hence iloc)
    tr = df_dedup.iloc[idx["train"]].copy()
    va = df_dedup.iloc[idx["val"]].copy()
    te = df_dedup.iloc[idx["test"]].copy()

    # --- Guardrail 1: ensure each split can form ≥1 window ---
    for name, part in [("train", tr), ("val", va), ("test", te)]:
        if len(part) < min_win:
            raise ValueError(
                f"{dataset_name}:{name} split too small for windowing — "
                f"need ≥ {min_win} rows, have {len(part)}."
            )

    # train-only impute
    med, (tr, va, te) = train_only_impute(tr, va, te, feature_cols)

    # drop zero-variance (decided on train); keep full frames
    keep_cols, _ = drop_zero_var(tr, va, te, feature_cols)
    feature_cols = keep_cols

    # clip+scale
    scaler, (Xtr_raw, Xva_raw, Xte_raw), clip_bounds = clip_and_scale(tr, va, te, feature_cols)

    # negative value audit
    neg_audit = negative_value_audit(tr, feature_cols)

    # save audits
    post = missing_inf_report(df_dedup)
    with open(os.path.join(out_dir, "missing_inf_report_preclean.json"), "w") as f:
        json.dump(pre, f, indent=2)
    with open(os.path.join(out_dir, "missing_inf_report_postclean.json"), "w") as f:
        json.dump(post, f, indent=2)
    with open(os.path.join(out_dir, "negative_values_report.json"), "w") as f:
        json.dump(neg_audit, f, indent=2)
    with open(os.path.join(out_dir, "cross_split_leak_report.json"), "w") as f:
        json.dump(leak, f, indent=2)
    with open(os.path.join(out_dir, "feature_columns.txt"), "w") as f:
        f.write("\n".join(feature_cols))

    # build ROW windows (now also returning K = distance-to-event)
    def split_to_windows_with_k(df_split, X_split):
        y_vec = df_split["Label_binary"].values.astype("int8")
        return build_windows_by_rows_with_k(X_split, y_vec, W=W, horizons=HORIZONS, stride=STRIDE)

    Xtr_win, Ytr, Ktr = split_to_windows_with_k(tr, Xtr_raw)
    Xva_win, Yva, Kva = split_to_windows_with_k(va, Xva_raw)
    Xte_win, Yte, Kte = split_to_windows_with_k(te, Xte_raw)

    # --- Persist exact split indices (relative to df_dedup) ---
    def _slice_to_list(slc):
        if isinstance(slc, slice):
            return list(range(slc.start, slc.stop))
        # handle numpy/pandas indexers too
        return [int(i) for i in (slc if hasattr(slc, "__iter__") else [int(slc)])]

    split_indices = {
        "train": _slice_to_list(idx["train"]),
        "val":   _slice_to_list(idx["val"]),
        "test":  _slice_to_list(idx["test"]),
    }
    with open(os.path.join(out_dir, "split_indices.json"), "w") as f:
        json.dump(split_indices, f, indent=2)

    # --- Guardrail 2a: X/Y alignment sanity ---
    n_tr, n_va, n_te = Xtr_win.shape[0], Xva_win.shape[0], Xte_win.shape[0]
    for h in HORIZONS:
        assert len(Ytr[h]) == n_tr, f"h={h} train y len {len(Ytr[h])} != X len {n_tr}"
        assert len(Yva[h]) == n_va, f"h={h} val   y len {len(Yva[h])} != X len {n_va}"
        assert len(Yte[h]) == n_te, f"h={h} test  y len {len(Yte[h])} != X len {n_te}"

    # --- Guardrail 2b: quick balance audit (just prints) ---
    def frac(y):
        return float((y > 0).mean()) if len(y) else 0.0

    print("Positive fractions by horizon (train | val | test):")
    for h in HORIZONS:
        print(f"  h={h}: {frac(Ytr[h]):.3f} | {frac(Yva[h]):.3f} | {frac(Yte[h]):.3f}")

    def _pos_frac(y):
        y = y.astype("int8")
        return float((y == 1).mean()) if len(y) else 0.0

    class_weights = {}
    for split, Y in [("train", Ytr), ("val", Yva), ("test", Yte)]:
        cw = {}
        for h in HORIZONS:
            p = _pos_frac(Y[h])
            # inverse-prevalence for positives, capped for stability
            w_pos = float(min(50.0, (1.0 - p) / max(1e-6, p))) if p > 0 else 1.0
            cw[int(h)] = {"w_neg": 1.0, "w_pos": w_pos}
        class_weights[split] = cw

    with open(os.path.join(out_dir, "class_weights.json"), "w") as f:
        json.dump(class_weights, f, indent=2)

    def save_split(name, X, Y, K):
        # Only the X container depends on SAVE_AS_MEMMAP; labels are always .npy.
        # NOTE(review): the .npz branch writes X_<name>.npz, while sanity() in
        # this notebook loads X_<name>.npy — confirm the consumers you rely on.
        if SAVE_AS_MEMMAP:
            np.save(os.path.join(out_dir, f"X_{name}.npy"), X.astype("float32"))
        else:
            np.savez_compressed(os.path.join(out_dir, f"X_{name}.npz"), X=X.astype("float32"))
        for h in HORIZONS:
            np.save(os.path.join(out_dir, f"y_{h}_{name}.npy"), Y[h].astype("int8"))
            np.save(os.path.join(out_dir, f"k_{h}_{name}.npy"), K[h].astype("int16"))

    save_split("train", Xtr_win, Ytr, Ktr)
    save_split("val",   Xva_win, Yva, Kva)
    save_split("test",  Xte_win, Yte, Kte)

    # reports
    horizon_balance = {
        "train": {int(h): int(Ytr[h].sum()) for h in HORIZONS},
        "val":   {int(h): int(Yva[h].sum()) for h in HORIZONS},
        "test":  {int(h): int(Yte[h].sum()) for h in HORIZONS}
    }
    with open(os.path.join(out_dir, "horizon_balance.json"), "w") as f:
        json.dump(horizon_balance, f, indent=2)

    artifacts = {
        "scaler": scaler,
        "clip_bounds": clip_bounds,
        "feature_cols": feature_cols,
        "medians": med.to_dict(),
        "W": W, "HORIZONS": list(HORIZONS), "STRIDE": STRIDE
    }
    with open(os.path.join(out_dir, "artifacts.pkl"), "wb") as f:
        pickle.dump(artifacts, f)

    meta = {
        "dataset": dataset_name,
        "n_features": len(feature_cols),
        "W": W, "horizons": list(HORIZONS), "stride": STRIDE,
        "shapes": {
            "train": [int(x) for x in Xtr_win.shape],
            "val":   [int(x) for x in Xva_win.shape],
            "test":  [int(x) for x in Xte_win.shape]
        },
        "leak_report": leak,
        "split_mode": SPLIT_MODE,
        "chunk_rows": CHUNK_ROWS,
        "val_frac": VAL_FRAC,
        "test_frac": TEST_FRAC
    }

    with open(os.path.join(out_dir, "meta.json"), "w") as f:
        json.dump(meta, f, indent=2)

    print(f"{dataset_name}: X_train {Xtr_win.shape}, X_val {Xva_win.shape}, X_test {Xte_win.shape}")
    print("Horizon positives (train):", {h: int(Ytr[h].sum()) for h in HORIZONS})
    if RUN_EDA:
        run_minimal_eda(df_dedup, idx, out_dir, feature_cols, corr_sample=5000)



**EDA (sanity visuals & quick stats)**
The cell below generates a lightweight EDA pack per client split to catch issues early. It writes a text summary (rows, kept features, split sizes), plots class balance (linear & log scale), shows top-missing features, a few train feature histograms (auto log if heavy-tailed), a small correlation heatmap (sampled), and an “attack-rate over file order” line to reveal drift. If Label_raw exists, it also saves a top-30 class snapshot and a binary-vs-raw crosstab. Outputs go to <out_dir>/EDA/.

In [None]:
# cell 8.1 — Minimal EDA figs to <out_dir>/EDA/
import matplotlib.pyplot as plt
import os, math, numpy as np, pandas as pd

# ---- toggles for optional EDA bits ----
# Module-level switches read by run_minimal_eda() each time it is called.
SHOW_FEATURE_HISTS = True     # set False to skip histograms
MAX_FEAT_PLOTS     = 6        # how many features to show in hist figure
SHOW_LABEL_RAW     = True     # set False to skip label_raw plots/crosstab

def run_minimal_eda(df_full, idx, out_dir, feature_cols, corr_sample=5000, seed=7):
    """Write a small set of sanity plots and summaries to ``<out_dir>/EDA/``.

    Args:
        df_full: frame containing ``Label_binary``, the feature columns and
            optionally ``Label_raw``.
        idx: mapping with keys "train", "val", "test" whose values are
            POSITIONAL row selectors (slices or integer arrays) into
            ``df_full``. They are applied with ``iloc``; label-based ``loc``
            would include the slice stop row and would break on a
            non-default index.
        out_dir: parent output directory; an ``EDA`` sub-folder is created.
        feature_cols: names of the feature columns to inspect.
        corr_sample: maximum number of train rows sampled for the correlation map.
        seed: seed for histogram subsampling and correlation sampling.

    Returns:
        None. Files are written under ``<out_dir>/EDA/``.
    """
    eda_dir = os.path.join(out_dir, "EDA")
    os.makedirs(eda_dir, exist_ok=True)

    labels = df_full["Label_binary"]

    # 0) small text summary
    with open(os.path.join(eda_dir, "summary.txt"), "w") as f:
        f.write(f"rows={len(df_full)}\nfeatures_kept={len(feature_cols)}\n")
        for split in ("train", "val", "test"):
            y = labels.iloc[idx[split]].astype("int64")
            f.write(f"{split}: n={len(y)}, pos={int((y==1).sum())}, neg={int((y==0).sum())}\n")

    # 1) Class balance (linear & log) per split
    for name, slc in idx.items():
        y = labels.iloc[slc].values
        pos = int((y == 1).sum()); neg = int((y == 0).sum())
        for scale, tag in [("linear", ""), ("log", "_log")]:
            plt.figure()
            plt.bar(["neg", "pos"], [max(0, neg), max(0, pos)])
            if scale == "log":
                plt.yscale("log")
            plt.title(f"class_balance_{name} (neg={neg}, pos={pos})")
            plt.tight_layout()
            plt.savefig(os.path.join(eda_dir, f"class_balance_{name}{tag}.png"), dpi=150)
            plt.close()

    # 2) Missingness (top-40 by NaN count)
    na_counts = df_full[feature_cols].isna().sum()
    na_counts = na_counts[na_counts > 0].sort_values(ascending=False).head(40)
    if len(na_counts) > 0:
        plt.figure(figsize=(10, max(3, 0.28*len(na_counts))))
        plt.barh(na_counts.index.tolist(), na_counts.values)
        plt.title("missingness_top40")
        plt.xlabel("NaN count")
        plt.tight_layout()
        plt.savefig(os.path.join(eda_dir, "missingness_top40.png"), dpi=150)
        plt.close()
    else:
        with open(os.path.join(eda_dir, "missingness_top40.txt"), "w") as f:
            f.write("No missing values in feature columns.\n")

    # Train rows of the feature columns, selected by position.
    tr = df_full[feature_cols].iloc[idx["train"]]

    # 3) Distributions for a few train features (optional; auto log-hist if heavy skew)
    if SHOW_FEATURE_HISTS:
        feats = feature_cols[:MAX_FEAT_PLOTS]
        cols = 2; rows = max(1, math.ceil(len(feats)/cols))
        plt.figure(figsize=(11, 3.2*rows))
        rng = np.random.default_rng(seed)
        for i, c in enumerate(feats, 1):
            s = pd.to_numeric(tr[c], errors="coerce").dropna()
            if len(s) > 200000:   # cap for big datasets
                s = pd.Series(rng.choice(s.values, size=200000, replace=False))
            plt.subplot(rows, cols, i)
            if len(s) == 0:
                plt.title(f"{c} (no data)"); plt.hist([]); continue
            use_log = False
            if s.min() >= 0:
                # Smallest strictly-positive value; NaN when every value is 0.
                mn = s.replace(0, np.nan).min(skipna=True)
                use_log = bool(pd.notna(mn) and s.max()/max(1e-9, mn) > 1e3)
            if use_log:
                s = s.replace(0, np.nan).dropna()
                plt.hist(np.log10(s.values), bins=60)
                plt.xlabel("log10(value)")
            else:
                plt.hist(s.values, bins=60)
            plt.title(c, fontsize=10)
        plt.tight_layout()
        plt.savefig(os.path.join(eda_dir, "feature_hist_train.png"), dpi=150)
        plt.close()

    # 4) Correlation heatmap on a sample of train (numeric only)
    if len(tr) > 1:
        n_samp = min(corr_sample, len(tr))
        sample = tr.sample(n_samp, random_state=seed)
        C = sample.corr(numeric_only=True)
        plt.figure(figsize=(min(12, 0.25*C.shape[1]+4), min(10, 0.25*C.shape[0]+4)))
        plt.imshow(C.values, aspect='auto', interpolation='nearest')
        plt.colorbar()
        plt.title("corr_train_sample")
        plt.tight_layout()
        plt.savefig(os.path.join(eda_dir, "corr_train_sample.png"), dpi=150)
        plt.close()

    # 5) Attack-rate drift over file order (per client file)
    chunk = max(1000, min(10000, len(df_full)//100))  # always >= 1000
    y = labels.astype("int8").values
    pts, rates = [], []
    for i in range(0, len(y), chunk):
        seg = y[i:i+chunk]
        pts.append(i + len(seg)//2)
        rates.append(float((seg==1).mean()))
    plt.figure(figsize=(10, 3))
    plt.plot(pts, rates, marker='.')
    plt.title(f"attack_rate_by_position (chunks≈{chunk} rows)")
    plt.xlabel("row index (file order)")
    plt.ylabel("positive rate")
    plt.tight_layout()
    plt.savefig(os.path.join(eda_dir, "attack_rate_by_position.png"), dpi=150)
    plt.close()

    # 6) Multi-class snapshot (only if present AND enabled)
    if SHOW_LABEL_RAW and ("Label_raw" in df_full.columns):
        lab = (df_full["Label_raw"].astype(str).str.strip().str.lower()
               .replace({"": "<empty>", "nan": "<nan>"}))
        vc = lab.value_counts()
        if len(vc) > 0:
            vc_top = vc.head(30)
            plt.figure(figsize=(10, max(3, 0.28*len(vc_top))))
            plt.barh(vc_top.index[::-1], vc_top.values[::-1])
            plt.title("label_raw_distribution_top30")
            plt.xlabel("count")
            plt.tight_layout()
            plt.savefig(os.path.join(eda_dir, "label_raw_distribution_top30.png"), dpi=150)
            plt.close()
            pd.crosstab(df_full["Label_binary"].astype(int), lab, dropna=False) \
              .to_csv(os.path.join(eda_dir, "Label_binary_by_Label_raw.csv"))


The helper function below verifies dataset integrity after preprocessing. It loads the metadata and all saved arrays (X, y, and optionally k) for each split (train/val/test) and horizon (1, 5, 10), and asserts alignment (the same number of samples across features, labels, and distance-to-event k), so that mismatches are caught before training.

In [None]:
def sanity(out_dir, horizons=None):
    """Check that the saved X, y and (optional) k arrays have matching lengths.

    Args:
        out_dir: directory containing ``meta.json``, ``X_<split>.npy`` and
            ``y_<h>_<split>.npy`` (plus optional ``k_<h>_<split>.npy``).
        horizons: horizons to check. When None, the ``"horizons"`` entry of
            ``meta.json`` is used, falling back to ``(1, 5, 10)`` if absent.

    Raises:
        AssertionError: if any label or k array differs in length from its X
            array. Raised explicitly so the check also runs under ``python -O``.
    """
    with open(os.path.join(out_dir, "meta.json")) as f:
        meta = json.load(f)
    print("Shapes:", meta["shapes"])
    if horizons is None:
        horizons = meta.get("horizons") or (1, 5, 10)
    for split in ["train", "val", "test"]:
        # mmap_mode avoids pulling the full window tensor into memory.
        X = np.load(os.path.join(out_dir, f"X_{split}.npy"), mmap_mode="r")
        n = X.shape[0]
        for h in horizons:
            y = np.load(os.path.join(out_dir, f"y_{h}_{split}.npy"))
            if len(y) != n:
                raise AssertionError((split, h, "y len", len(y), "X len", n))
            k_path = os.path.join(out_dir, f"k_{h}_{split}.npy")
            if os.path.exists(k_path):
                k = np.load(k_path)
                if len(k) != n:
                    raise AssertionError((split, h, "k len", len(k), "X len", n))
    print(" aligned for all splits/horizons (and K if present)")


# preprocessing results for each dataset

client 1 : ice ransomware dataset



In [None]:
# Preprocess the ICE client CSV, then check the saved arrays are aligned.
# NOTE(review): the logged header below shows the client name, so DATASET_NAME
# is presumably reassigned before each of these run cells — confirm.
process_client(ICE_CSV, OUT_ICE, DATASET_NAME)
sanity(OUT_ICE)



=== ICE ===
Reading /content/drive/MyDrive/fed_MID/iomt_clients/client_ICE.csv
[audit] duplicate groups (no removal): 507300 | groups with mixed labels: 0 | rows removed: 0
[audit] wrote per-hash stats to /content/drive/MyDrive/fed_MID/DATA_CLIENT_NPY/ICE/duplicates_report.csv
Positive fractions by horizon (train | val | test):
  h=1: 0.217 | 0.338 | 0.302
  h=5: 0.268 | 0.502 | 0.538
  h=10: 0.278 | 0.609 | 0.589
ICE: X_train (405780, 50, 15), X_val (25307, 50, 15), X_test (76036, 50, 15)
Horizon positives (train): {1: 88204, 5: 108858, 10: 112650}
Shapes: {'train': [405780, 50, 15], 'val': [25307, 50, 15], 'test': [76036, 50, 15]}
 aligned for all splits/horizons (and K if present)


VALIDATION WAS ACROSS ALL THE DATA

In [None]:
# Cell: validator
import os, json, numpy as np, glob

def validate_npys(out_dir, horizons=(1,5,10), W=50):
    """Validate one preprocessed client directory and print a report.

    Checks, in order: presence of the expected metadata files; that each
    ``X_<split>.npy`` is 3-D with window length ``W`` and that every
    ``y_<h>_<split>.npy`` has the same number of rows; that no split has an
    all-0 or all-1 label vector and that labels are monotone across
    consecutive horizons; and that no feature name contains a token that
    suggests a time/ID column.

    Missing or unparseable files are recorded as problems instead of raising,
    so one run reports everything that is wrong.

    Args:
        out_dir: directory to validate.
        horizons: horizons to check; the monotonicity check compares
            consecutive entries, so pass them in ascending order.
        W: expected window length (second axis of X).

    Returns:
        None. The report is printed.
    """
    problems = []

    # 1) presence
    need = ["meta.json","artifacts.pkl","feature_columns.txt","split_indices.json","class_weights.json"]
    for fname in need:
        if not os.path.exists(os.path.join(out_dir, fname)):
            problems.append(f"[missing] {fname}")

    # 2) shapes & alignment
    meta_path = os.path.join(out_dir, "meta.json")
    if os.path.exists(meta_path):
        try:
            with open(meta_path) as f:
                json.load(f)  # parse only to confirm the file is valid JSON
        except ValueError as err:  # json.JSONDecodeError is a ValueError
            problems.append(f"[meta] meta.json is not valid JSON: {err}")

    splits = ["train", "val", "test"]
    labels = {}  # split -> list of label arrays (None where the file is missing)
    for split in splits:
        x_path = os.path.join(out_dir, f"X_{split}.npy")
        X = None
        if os.path.exists(x_path):
            X = np.load(x_path, mmap_mode="r")
            if X.ndim != 3 or X.shape[1] != W:
                problems.append(f"[shape] X_{split} bad shape {X.shape}, expect (?,{W},F)")
        else:
            problems.append(f"[missing] X_{split}.npy")

        ys = []
        for h in horizons:
            y_path = os.path.join(out_dir, f"y_{h}_{split}.npy")
            if not os.path.exists(y_path):
                problems.append(f"[missing] y_{h}_{split}.npy")
                ys.append(None)
                continue
            y = np.load(y_path)
            ys.append(y)
            if X is not None and len(y) != X.shape[0]:
                problems.append(f"[align] y_{h}_{split} len={len(y)} != X_{split} n={X.shape[0]}")
        labels[split] = ys

    # 3) label sanity (prevalence and monotonicity)
    def _prev(p):
        return float((p>0).mean()) if len(p) else 0.0

    for split in splits:
        ys = labels[split]
        if any(y is None for y in ys):
            continue  # already reported as missing above
        prevs = [_prev(y) for y in ys]
        if any(p==0 or p==1 for p in prevs):
            problems.append(f"[labels] {split} extreme class ratio {prevs}")
        # monotonic: y_h must be <= y_h' if h<h'
        for i in range(len(horizons)-1):
            if np.any((ys[i]==1) & (ys[i+1]==0)):
                problems.append(f"[labels] non-monotonic: h={horizons[i]} > h={horizons[i+1]} in {split}")

    # 4) feature names (just warn if time/IDs slipped through)
    bad_tokens = ("time","timestamp","starttime","uid","guid","mac","addr","sha","hash","patient","nhs")
    feat_path = os.path.join(out_dir, "feature_columns.txt")
    if os.path.exists(feat_path):
        with open(feat_path) as f:
            cols = [c.strip().lower() for c in f]
        bad = [c for c in cols if any(tok in c for tok in bad_tokens)]
        if bad:
            problems.append(f"[leak?] suspicious features present: {bad[:10]}{'...' if len(bad)>10 else ''}")

    # summary
    if problems:
        print(f"[VALIDATOR] Issues in {out_dir}:")
        for p in problems:
            print(" -", p)
    else:
        print(f"[VALIDATOR] OK: {out_dir}")

# Run for each client dir you produced (this cell validates only the ICE output directory)
validate_npys(OUT_ICE, horizons=(1,5,10), W=50)



[VALIDATOR] OK: /content/drive/MyDrive/fed_MID/DATA_CLIENT_NPY/ICE


client 2 :  iomt-a (IoMT-TrafficData)





In [None]:
# Preprocess the IOMT_A client CSV, then check the saved arrays are aligned.
# NOTE(review): DATASET_NAME is presumably set to this client's name before the call — confirm.
process_client(IOMT_A_CSV, OUT_IOMT_A, DATASET_NAME)
sanity(OUT_IOMT_A)



=== IOMT_A ===
Reading /content/drive/MyDrive/fed_MID/iomt_clients/client_iomt_A.csv
[audit] duplicate groups (no removal): 852920 | groups with mixed labels: 5 | rows removed: 0
[audit] wrote per-hash stats to /content/drive/MyDrive/fed_MID/DATA_CLIENT_NPY/IOMT_A/duplicates_report.csv
[split] AUTO switching to CHUNKED due to extreme val/test balance.
[split] pre-window pos-rate  train=0.794 (n=1343076)  val=0.500 (n=80000)  test=0.231 (n=260000)
Positive fractions by horizon (train | val | test):
  h=1: 0.794 | 0.500 | 0.231
  h=5: 0.794 | 0.500 | 0.231
  h=10: 0.794 | 0.500 | 0.231
IOMT_A: X_train (1343017, 50, 55), X_val (79941, 50, 55), X_test (259941, 50, 55)
Horizon positives (train): {1: 1065978, 5: 1066034, 10: 1066104}
Shapes: {'train': [1343017, 50, 55], 'val': [79941, 50, 55], 'test': [259941, 50, 55]}
 aligned for all splits/horizons (and K if present)


client 3 :  iomt-b (IoMT-TrafficData)

In [None]:
# Preprocess the IOMT_B client CSV, then check the saved arrays are aligned.
# NOTE(review): DATASET_NAME is presumably set to this client's name before the call — confirm.
process_client(IOMT_B_CSV, OUT_IOMT_B, DATASET_NAME)
sanity(OUT_IOMT_B)



=== IOMT_B ===
Reading /content/drive/MyDrive/fed_MID/iomt_clients/client_iomt_B.csv
[audit] duplicate groups (no removal): 436558 | groups with mixed labels: 2 | rows removed: 0
[audit] wrote per-hash stats to /content/drive/MyDrive/fed_MID/DATA_CLIENT_NPY/IOMT_B/duplicates_report.csv
[split] AUTO switching to CHUNKED due to extreme val/test balance.
[split] pre-window pos-rate  train=0.879 (n=1240112)  val=0.600 (n=100000)  test=0.909 (n=220000)
Positive fractions by horizon (train | val | test):
  h=1: 0.879 | 0.600 | 0.909
  h=5: 0.879 | 0.600 | 0.909
  h=10: 0.879 | 0.600 | 0.909
IOMT_B: X_train (1240053, 50, 58), X_val (99941, 50, 58), X_test (219941, 50, 58)
Horizon positives (train): {1: 1090177, 5: 1090189, 10: 1090204}
Shapes: {'train': [1240053, 50, 58], 'val': [99941, 50, 58], 'test': [219941, 50, 58]}
 aligned for all splits/horizons (and K if present)


client 4:  WUSTL(wustl-ehms-2020)

In [None]:
# Preprocess the WUSTL client CSV, then check the saved arrays are aligned.
# NOTE(review): DATASET_NAME is presumably set to this client's name before the call — confirm.
process_client(WUSTL_CSV, OUT_WUSTL, DATASET_NAME)
sanity(OUT_WUSTL)



=== WUSTL ===
Reading /content/drive/MyDrive/fed_MID/iomt_clients/client_WUSTL.csv
[audit] duplicate groups (no removal): 16318 | groups with mixed labels: 0 | rows removed: 0
[audit] wrote per-hash stats to /content/drive/MyDrive/fed_MID/DATA_CLIENT_NPY/WUSTL/duplicates_report.csv
Positive fractions by horizon (train | val | test):
  h=1: 0.126 | 0.125 | 0.127
  h=5: 0.147 | 0.141 | 0.147
  h=10: 0.173 | 0.161 | 0.172
WUSTL: X_train (12995, 50, 18), X_val (757, 50, 18), X_test (2389, 50, 18)
Horizon positives (train): {1: 1641, 5: 1913, 10: 2253}
Shapes: {'train': [12995, 50, 18], 'val': [757, 50, 18], 'test': [2389, 50, 18]}
 aligned for all splits/horizons (and K if present)


# FEDERATED LEARNING AND LOCAL MODELS



Installs and imports the needed libraries (Flower, PyTorch, scikit-learn, matplotlib) and mounts Google Drive.

Sets project paths to the preprocessed client datasets (ICE, IOMT_A/B, WUSTL) plus config/results folders.

Fixes randomness  

Defines globals: window size W=50, prediction HORIZONS=[1,5,10], device, and training hyperparams (EPOCHS_LOCAL, BATCH, LR, GRAD_CLIP).

Early stopping policy: monitor validation PRAUC (or AUROC), with patience=2 and min_delta=1e-3, optionally verbose.

Federated simulation config: 4 clients, 7 rounds × 5 local epochs per round, plus optional robust aggregation knobs (FED_TRIMMED_BETA) and FedProx strength (FED_FEDPROX_MU).

Client name/index maps for convenient referencing during simulation.

In [None]:
!pip -q install "flwr[simulation]" torch torchvision pyyaml scikit-learn


In [None]:
#cell1  Core
import os, json, yaml, math, gc, pathlib, random
import numpy as np
from typing import Dict, List, Tuple

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Metrics/plots
from sklearn.metrics import (
    f1_score, roc_auc_score, average_precision_score,
    precision_recall_curve
)
import matplotlib.pyplot as plt

# Flower
import flwr
from flwr.app import Context
from flwr.client import NumPyClient, Client
from flwr.clientapp import ClientApp

from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg, FedProx
from flwr.simulation import run_simulation
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays

# Colab Drive
from google.colab import drive
drive.mount('/content/drive')

# ------------------ Paths ------------------
# One directory per client under DRIVE_ROOT; the dict order fixes the
# client → index mapping built at the bottom of this cell.
DRIVE_ROOT = "/content/drive/MyDrive/fed_MID/DATA_CLIENT_NPY"
CLIENTS = {
    "ICE":     os.path.join(DRIVE_ROOT, "ICE"),
    "IOMT_A":  os.path.join(DRIVE_ROOT, "IOMT_A"),
    "IOMT_B":  os.path.join(DRIVE_ROOT, "IOMT_B"),
    "WUSTL":   os.path.join(DRIVE_ROOT, "WUSTL"),
}

# Config and results folders are created up front so later cells can write to them.
CFG_DIR = "/content/drive/MyDrive/fed_MID/configs"
CLIENT_CFG_DIR = os.path.join(CFG_DIR, "clients")
RESULTS_DIR = "/content/drive/MyDrive/fed_MID/results"
for d in [CFG_DIR, CLIENT_CFG_DIR, RESULTS_DIR]:
    pathlib.Path(d).mkdir(parents=True, exist_ok=True)

# ------------------ Globals ------------------
HORIZONS = [1, 5, 10]
W = 50
SEED = 42
# Seed Python, NumPy and PyTorch (CPU + all CUDA devices) with the same value.
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
# Ask cuDNN for deterministic kernels and disable its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training budgets
EPOCHS_LOCAL = 20                 # each deep model (local-only & centralized)
BATCH = 128
LR = 1e-3
GRAD_CLIP = 1.0
# ---- Early Stopping (validation-based) ----
ES_USE = True          # master toggle
ES_PATIENCE = 2        # epochs with no improvement before stopping
ES_MIN_DELTA = 1e-3    # min improvement to reset patience
ES_MONITOR = "prauc"   # "prauc"  or "auroc"
VERBOSE_ES = True      # print per-epoch train/val and overfit warnings

# Federated (simulation)
FED_NUM_CLIENTS = 4
FED_NUM_ROUNDS = 7
FED_EPOCHS_PER_ROUND = 5
FED_TRIMMED_BETA = 0.10  # presumably the trim fraction for robust aggregation — confirm where it is used
FED_FEDPROX_MU = 0.01  # proximal strength (client-side)

# Make client→index maps
CLIENT_NAMES = list(CLIENTS.keys())
PID_TO_NAME = {i: n for i, n in enumerate(CLIENT_NAMES)}
NAME_TO_PID = {n: i for i, n in PID_TO_NAME.items()}


  return datetime.utcnow().replace(tzinfo=utc)


Mounted at /content/drive


Feature alignment (centralized preprocessing)
Builds the union of feature names across all client sites and rewrites each site’s X_* arrays to that common order (missing features zero-filled). Copies labels/metadata when present, updates meta.json with shapes and n_features, and writes via chunked memmaps for scalability—producing aligned datasets under ALIGNED/.

In [None]:
# FEATURE ALIGNMENT (ONLY centralized)
import os, json, numpy as np, pathlib, shutil

# Source directories (one per client) and the destination root for aligned copies.
DRIVE_ROOT = "/content/drive/MyDrive/fed_MID/DATA_CLIENT_NPY"
SRC_DIRS = {
    client: f"{DRIVE_ROOT}/{client}"
    for client in ("ICE", "IOMT_A", "IOMT_B", "WUSTL")
}
ALIGNED_ROOT = f"{DRIVE_ROOT}/ALIGNED"     # destination for aligned copies
pathlib.Path(ALIGNED_ROOT).mkdir(parents=True, exist_ok=True)

# Read every client's feature list, then take the sorted union of all names.
feat_lists = {}
for name, d in SRC_DIRS.items():
    with open(os.path.join(d, "feature_columns.txt")) as f:
        stripped = (line.strip() for line in f)
        feat_lists[name] = [col for col in stripped if col]

all_features = set()
for cols_for_client in feat_lists.values():
    all_features.update(cols_for_client)
feat_union = sorted(all_features)
print("Union feature count:", len(feat_union))

def repack_to(dest_dir, src_dir, union_cols, W=50, splits=("train","val","test"), chunk=4000,
              horizons=(1, 5, 10)):
    """Copy one client directory, re-laying its X arrays out on ``union_cols``.

    Metadata/audit files and the ``y_*`` / ``k_*`` arrays are copied as-is when
    present. Each ``X_<split>.npy`` is rewritten so that feature axis position
    ``j`` holds the column named ``union_cols[j]``; features the client does
    not have are zero-filled. ``feature_columns.txt`` and the ``shapes`` /
    ``n_features`` entries of ``meta.json`` are updated to match.

    NOTE(review): ``artifacts.pkl`` is copied unchanged, so whatever it stores
    still describes the source feature layout — confirm before reusing it
    with the aligned arrays.

    Args:
        dest_dir: output directory (created if absent).
        src_dir: preprocessed client directory to read from.
        union_cols: target feature order.
        W: expected window length of the source X arrays.
        splits: split names to process.
        chunk: number of windows rewritten per step, to bound memory use.
        horizons: horizons whose ``y_*`` / ``k_*`` files are copied.

    Raises:
        AssertionError: if a source X array has a window length other than ``W``.
        ValueError: if a source X array's feature count differs from the number
            of names in the source ``feature_columns.txt``.
    """
    os.makedirs(dest_dir, exist_ok=True)

    # Copy metadata/audits if present (optional; skipped if missing)
    meta_like = [
        "class_weights.json","split_indices.json","artifacts.pkl",
        "missing_inf_report_preclean.json","missing_inf_report_postclean.json",
        "negative_values_report.json","horizon_balance.json",
        "meta.json","cross_split_leak_report.json","leak_report.json"
    ]
    for fname in meta_like:
        src = os.path.join(src_dir, fname)
        if os.path.exists(src):
            shutil.copy2(src, dest_dir)

    # Copy labels (and k_*) as-is
    for split in splits:
        for h in horizons:
            for base in [f"y_{h}_{split}.npy", f"k_{h}_{split}.npy"]:
                s = os.path.join(src_dir, base)
                if os.path.exists(s):
                    shutil.copy2(s, dest_dir)

    # Map X_* to union: destination positions and the source columns feeding them.
    with open(os.path.join(src_dir, "feature_columns.txt")) as f:
        old_cols = [c.strip() for c in f if c.strip()]
    old_index = {c: i for i, c in enumerate(old_cols)}
    dst_idx = [j for j, c in enumerate(union_cols) if c in old_index]
    src_idx = [old_index[union_cols[j]] for j in dst_idx]
    F_new = len(union_cols)

    for split in splits:
        src_X = os.path.join(src_dir, f"X_{split}.npy")
        X = np.load(src_X, mmap_mode="r")  # (N, W, F_old)
        N, W0, F_old = X.shape
        assert W0 == W, f"W mismatch in {src_X}: {W0} vs {W}"
        if F_old != len(old_cols):
            # A stale feature list would silently shuffle or drop columns.
            raise ValueError(
                f"{src_X} has {F_old} features but feature_columns.txt lists {len(old_cols)}"
            )
        # Write to a temp file and rename so a crash never leaves a partial X file.
        tmp = os.path.join(dest_dir, f"X_{split}.npy.tmp")
        X_new = np.lib.format.open_memmap(tmp, mode="w+", dtype=np.float32, shape=(N, W, F_new))
        for a in range(0, N, chunk):
            b = min(N, a+chunk)
            out = np.zeros((b-a, W, F_new), np.float32)
            if src_idx:
                out[:, :, dst_idx] = X[a:b][:, :, src_idx]
            X_new[a:b] = out
        X_new.flush()
        del X_new
        os.replace(tmp, os.path.join(dest_dir, f"X_{split}.npy"))
        print(f"[{os.path.basename(dest_dir)}] {split}: {F_old}→{F_new}, N={N}")

    # Update feature list and meta (shapes)
    with open(os.path.join(dest_dir, "feature_columns.txt"), "w") as f:
        f.write("\n".join(union_cols))
    meta_path = os.path.join(dest_dir, "meta.json")
    meta = {}
    if os.path.exists(meta_path):
        with open(meta_path) as f:
            try:
                meta = json.load(f)
            except ValueError:  # unparseable meta.json: start from an empty dict
                meta = {}
    meta.setdefault("shapes", {})
    for split in splits:
        Xs = np.load(os.path.join(dest_dir, f"X_{split}.npy"), mmap_mode="r")
        meta["shapes"][split] = [int(Xs.shape[0]), int(Xs.shape[1]), int(Xs.shape[2])]
    meta["n_features"] = len(union_cols)
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)

# Run repack for each client into ALIGNED/*
# Every client is rewritten onto the same feat_union order, so the aligned
# X arrays share one feature axis layout.
for name, src in SRC_DIRS.items():
    repack_to(os.path.join(ALIGNED_ROOT, name), src, feat_union, W=50)

print("Aligned copies ready under:", ALIGNED_ROOT)


# Data I/O and DataLoaders
 Loads per-split arrays (X_*, y_h_*, optional k_h_*), reads per-horizon class weights, wraps data in a windowed Dataset, and returns train/val/test DataLoaders plus the feature dimension required to instantiate models.

In [None]:

def load_split_arrays(data_dir: str, split: str, horizons: List[int]):
    """Load the feature windows and per-horizon labels of one split.

    ``X_{split}.npy`` is opened as a read-only memory map; each
    ``y_{h}_{split}.npy`` is read fully into memory.

    Returns:
        ``(X, ys)`` where ``ys`` maps every horizon ``h`` to its label array.
    """
    features_path = os.path.join(data_dir, f"X_{split}.npy")
    X = np.load(features_path, mmap_mode="r")

    ys = {}
    for h in horizons:
        labels_path = os.path.join(data_dir, f"y_{h}_{split}.npy")
        ys[h] = np.load(labels_path)
    return X, ys

def load_split_arrays_with_k(data_dir: str, split: str, horizons: List[int]):
    """Load features, labels and the optional ``k_{h}_{split}.npy`` arrays of one split.

    Returns:
        ``(X, ys, Ks)``. ``X`` is a read-only memory map, ``ys[h]`` the label
        array of horizon ``h`` and ``Ks[h]`` the matching ``k`` array, or
        ``None`` when that file does not exist.
    """
    def _in_dir(filename):
        return os.path.join(data_dir, filename)

    X = np.load(_in_dir(f"X_{split}.npy"), mmap_mode="r")
    ys = {h: np.load(_in_dir(f"y_{h}_{split}.npy")) for h in horizons}

    Ks = {}
    for h in horizons:
        k_path = _in_dir(f"k_{h}_{split}.npy")
        # The k file is optional: a missing file is recorded as None.
        Ks[h] = np.load(k_path) if os.path.exists(k_path) else None
    return X, ys, Ks

def load_class_weights(data_dir: str, horizons: List[int]):
    """Read the positive-class weight of each horizon from ``class_weights.json``.

    The value is taken from ``["train"][str(h)]["w_pos"]``; any missing level
    of that path falls back to a weight of 1.0.

    Returns:
        Dict mapping each horizon to its weight as a float.
    """
    weights_path = os.path.join(data_dir, "class_weights.json")
    with open(weights_path) as fh:
        table = json.load(fh)

    train_table = table.get("train", {})
    return {
        h: float(train_table.get(str(h), {}).get("w_pos", 1.0))
        for h in horizons
    }

class WindowDataset(Dataset):
    """Dataset of feature windows with one label column per horizon.

    ``X`` is viewed as float32 via ``np.asarray``. ``Y`` is a 2-D int64 array
    built by stacking ``Y_dict[h]`` for every horizon along axis 1, in the
    order given by ``horizons``.
    """

    def __init__(self, X, Y_dict, horizons: List[int]):
        self.X = np.asarray(X, dtype=np.float32)
        self.horizons = list(horizons)
        label_columns = [np.asarray(Y_dict[h], dtype=np.int64) for h in horizons]
        self.Y = np.stack(label_columns, axis=1)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        # np.array copies by default, so the returned tensors never alias the
        # backing arrays (make_loaders passes a read-only memory map as X).
        window = np.array(self.X[idx])
        labels = np.array(self.Y[idx])
        return torch.from_numpy(window), torch.from_numpy(labels)


def make_loaders(data_dir, horizons, batch_size, shuffle_train=True):
    """Build the train/val/test DataLoaders for one data directory.

    Only the training loader may shuffle (controlled by ``shuffle_train``);
    validation and test loaders keep the stored order.

    Returns:
        ``(dl_train, dl_val, dl_test, n_features)`` where ``n_features`` is
        ``X_train.shape[2]``.
    """
    datasets = {}
    n_features = None
    for split in ("train", "val", "test"):
        X, Y = load_split_arrays(data_dir, split, horizons)
        if split == "train":
            n_features = X.shape[2]
        datasets[split] = WindowDataset(X, Y, horizons)

    def _loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=0, pin_memory=True)

    return (
        _loader(datasets["train"], shuffle_train),
        _loader(datasets["val"], False),
        _loader(datasets["test"], False),
        n_features,
    )


# Metrics, calibration, and plots
Chooses decision thresholds on the validation set via macro-F1 grid search, computes evaluation metrics (macro-F1, AUROC, PR-AUC), estimates 95% bootstrap CIs, computes Expected Calibration Error (ECE), draws reliability diagrams, and provides precision@target-recall utilities and a safe savefig helper.

In [None]:

def pick_threshold_from_val(y_true, y_score):
    """Pick the decision threshold that maximises macro-F1.

    Candidates are the 19 evenly spaced values from 0.05 to 0.95; when several
    candidates tie, the smallest one wins.
    """
    labels = np.asarray(y_true, int)
    probs = np.asarray(y_score, float)
    candidates = np.linspace(0.05, 0.95, 19)

    f1_per_candidate = [
        f1_score(labels, (probs >= thr).astype(int), average="macro", zero_division=0)
        for thr in candidates
    ]
    # np.argmax returns the first maximum, i.e. the lowest tied threshold.
    winner = int(np.argmax(f1_per_candidate))
    return float(candidates[winner])

def evaluate_split(y_true, y_score, threshold):
    """Compute macro-F1 at ``threshold`` plus threshold-free AUROC and PR-AUC.

    Returns:
        Dict with keys ``F1_macro``, ``AUROC``, ``PR_AUC`` and ``threshold``.
        AUROC / PR-AUC are NaN when the underlying metric raises.
    """
    labels = np.asarray(y_true, int)
    probs = np.asarray(y_score, float)
    hard_preds = (probs >= threshold).astype(int)

    def _or_nan(metric):
        # Ranking metrics are best-effort: a failure is reported as NaN.
        try:
            return float(metric(labels, probs))
        except Exception:
            return float("nan")

    return {
        "F1_macro": float(f1_score(labels, hard_preds, average="macro", zero_division=0)),
        "AUROC": _or_nan(roc_auc_score),
        "PR_AUC": _or_nan(average_precision_score),
        "threshold": float(threshold),
    }

def bootstrap_ci(metric_fn, y_true, y_score, iters=1000, alpha=0.05, seed=SEED):
    """Percentile-bootstrap confidence interval for ``metric_fn(y_true, y_score)``.

    Args:
        metric_fn: Callable taking ``(y_true, y_score)`` and returning a scalar.
        y_true, y_score: Equal-length sequences, resampled jointly with replacement.
        iters: Number of bootstrap resamples.
        alpha: Two-sided miscoverage; the interval spans the ``alpha/2`` and
            ``1 - alpha/2`` quantiles.
        seed: Seed of the NumPy generator used for resampling.

    Returns:
        ``(lo, hi)`` as floats.

    Raises:
        ValueError: If the input is empty or no resample yields a finite metric.
    """
    rng = np.random.default_rng(seed)
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)
    n = len(y_true)
    if n == 0:
        raise ValueError("bootstrap_ci needs at least one sample")

    vals = []
    for _ in range(iters):
        idx = rng.integers(0, n, size=n)
        # A resample can be degenerate for the metric (e.g. it contains a single
        # class). Skip that draw instead of abandoning the whole interval.
        try:
            value = float(metric_fn(y_true[idx], y_score[idx]))
        except ValueError:
            continue
        if np.isfinite(value):
            vals.append(value)

    if not vals:
        # Keeps the previous contract for fully degenerate inputs: callers that
        # wrap this call in try/except still reach their fallback.
        raise ValueError("no bootstrap resample produced a finite metric value")

    lo, hi = np.quantile(vals, [alpha / 2, 1 - alpha / 2])
    return float(lo), float(hi)

def ece_score(y_true, y_score, n_bins=15):
    """Expected Calibration Error of positive-class probabilities.

    Scores are grouped into ``n_bins`` equal-width bins over [0, 1]. In each
    non-empty bin the mean predicted probability is compared with the observed
    fraction of positives, and the absolute gaps are averaged with weights
    proportional to bin size.

    Args:
        y_true: Binary labels (0/1).
        y_score: Predicted probabilities of the positive class.
        n_bins: Number of equal-width bins.

    Returns:
        The ECE as a float; 0.0 for empty input.
    """
    y_true = np.asarray(y_true, int)
    y_score = np.asarray(y_score, float)
    if y_score.size == 0:
        return 0.0

    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for i in range(n_bins):
        lower, upper = edges[i], edges[i + 1]
        if i == n_bins - 1:
            # Close the last bin on the right so a score of exactly 1.0 is counted.
            in_bin = (y_score >= lower) & (y_score <= upper)
        else:
            in_bin = (y_score >= lower) & (y_score < upper)
        if not in_bin.any():
            continue
        mean_confidence = y_score[in_bin].mean()
        # The calibration target for P(y=1) is the empirical positive rate,
        # not the accuracy of a 0.5-thresholded prediction.
        positive_rate = y_true[in_bin].mean()
        ece += in_bin.mean() * abs(positive_rate - mean_confidence)
    return float(ece)

def plot_reliability(ax, y_true, y_score, n_bins=15, title="Calibration"):
    """Draw a reliability diagram on ``ax``.

    For every non-empty equal-width score bin the mean predicted probability
    (x) is plotted against the observed fraction of positives (y), next to the
    dashed diagonal of perfect calibration.

    Args:
        ax: Axes-like object providing ``plot``, ``set_xlabel``, ``set_ylabel``
            and ``set_title``.
        y_true: Binary labels (0/1).
        y_score: Predicted probabilities of the positive class.
        n_bins: Number of equal-width bins over [0, 1].
        title: Axes title.
    """
    y_true = np.asarray(y_true, int)
    y_score = np.asarray(y_score, float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)

    xs, ys = [], []
    for i in range(n_bins):
        lower, upper = edges[i], edges[i + 1]
        if i == n_bins - 1:
            # Close the last bin on the right so a score of exactly 1.0 is plotted.
            in_bin = (y_score >= lower) & (y_score <= upper)
        else:
            in_bin = (y_score >= lower) & (y_score < upper)
        if in_bin.any():
            xs.append(y_score[in_bin].mean())
            # Use the labels: the y value is the empirical positive rate in the
            # bin, not the share of scores above 0.5.
            ys.append(y_true[in_bin].mean())

    ax.plot([0, 1], [0, 1], linestyle="--")
    ax.plot(xs, ys, marker="o")
    ax.set_xlabel("Confidence")
    ax.set_ylabel("Accuracy")
    ax.set_title(title)

def precision_at_recall(y_true, y_score, target_recall=0.90):
    """Return the best precision among operating points with recall >= ``target_recall``.

    Yields NaN when no point on the precision-recall curve reaches the target.
    """
    precision, recall, _thresholds = precision_recall_curve(y_true, y_score)
    eligible = precision[recall >= target_recall]
    if eligible.size == 0:
        return float("nan")
    return float(eligible.max())

def savefig(path):
    """Save the current matplotlib figure to ``path`` (dpi=140) and close it.

    The parent directory of ``path`` is created first if it is missing.
    """
    parent_dir = pathlib.Path(os.path.dirname(path))
    parent_dir.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path, dpi=140)
    plt.close()


# Model: HIPFedPredict with attention (FedBN & private heads)
Defines an additive-attention model: 1×1 projection → Conv-BN-ReLU×2 → BiLSTM → attention → per-horizon linear heads. Exposes FL-aware sharing so BatchNorm, the projection layer, and heads remain local/private (FedBN + private heads), while the rest of the trunk is shared via ndarray export/import.

In [None]:

class AdditiveAttention(nn.Module):
    """Additive attention pooling over the time axis.

    Each time step is scored as ``v(tanh(W(x_t)))``; the scores are
    softmax-normalised over time and used to form a weighted sum of the inputs.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, in_dim, bias=True)
        self.v = nn.Linear(in_dim, 1, bias=False)

    def forward(self, seq):  # seq: (B, T, C)
        """Return ``(context, weights)`` with shapes (B, C) and (B, T)."""
        hidden = self.W(seq).tanh()
        scores = self.v(hidden).squeeze(-1)          # (B, T)
        weights = scores.softmax(dim=1)              # (B, T), sums to 1 over T
        context = torch.sum(seq * weights.unsqueeze(-1), dim=1)  # (B, C)
        return context, weights


class HIPFedPredict(nn.Module):
    """Projection (1x1 Conv) → Conv-BN-ReLU x2 → BiLSTM → Additive Attention → private heads.

    The forward pass returns one logit vector per horizon. For federated use
    only part of the state dict is exchanged: every key under ``heads.``,
    every key under ``proj.`` and every BatchNorm1d entry stays local.
    """

    def __init__(self, in_features, horizons: List[int]):
        super().__init__()
        self.horizons = list(horizons)
        width = 128

        # The 1x1 convolution mixes the input features at every time step;
        # (W, F) windows are fed to Conv1d as (F, W).
        self.proj = nn.Conv1d(in_channels=in_features, out_channels=width, kernel_size=1)
        self.conv1 = nn.Conv1d(width, width, kernel_size=3, padding=1, dilation=1)
        self.bn1 = nn.BatchNorm1d(width)
        self.conv2 = nn.Conv1d(width, width, kernel_size=3, padding=2, dilation=2)
        self.bn2 = nn.BatchNorm1d(width)

        # Bidirectional LSTM doubles the channel count seen by attention and heads.
        self.bilstm = nn.LSTM(input_size=width, hidden_size=width,
                              batch_first=True, bidirectional=True)
        self.attn = AdditiveAttention(in_dim=2 * width)
        self.heads = nn.ModuleDict({str(h): nn.Linear(2 * width, 1) for h in self.horizons})

    def forward(self, x):           # x: (B, W, F)
        """Return a dict mapping each horizon to a logit tensor of shape (B,)."""
        feats = self.proj(x.transpose(1, 2))              # (B, 128, W)
        feats = F.relu(self.bn1(self.conv1(feats)))       # (B, 128, W)
        feats = F.relu(self.bn2(self.conv2(feats)))       # (B, 128, W)
        seq_out, _ = self.bilstm(feats.transpose(1, 2))   # (B, W, 256)
        context, _ = self.attn(seq_out)                   # (B, 256)

        logits = {}
        for h in self.horizons:
            logits[h] = self.heads[str(h)](context).squeeze(-1)
        return logits

    # ----- FedBN + private heads: share everything EXCEPT BatchNorm, proj and heads -----
    def shared_keys(self):
        """List, in state-dict order, the keys that are exchanged with the server."""
        bn_prefixes = tuple(
            f"{name}."
            for name, module in self.named_modules()
            if isinstance(module, nn.BatchNorm1d)
        )
        # heads: private per client; proj: kept local; BatchNorm: FedBN.
        private_prefixes = ("heads.", "proj.") + bn_prefixes
        return [k for k in self.state_dict().keys() if not k.startswith(private_prefixes)]

    def get_shared_ndarrays(self):
        """Return ``(keys, arrays)`` for the shared part of the state dict.

        NOTE(review): ``.cpu().numpy()`` on a tensor that already lives on the
        CPU presumably shares memory with it, so these arrays may change when
        the model is trained further — copy them if a snapshot is needed.
        """
        state = self.state_dict()
        keys = self.shared_keys()
        return keys, [state[k].detach().cpu().numpy() for k in keys]

    def load_shared_ndarrays(self, keys, arrs, strict=False):
        """Overwrite the state-dict entries named in ``keys`` with ``arrs``.

        Keys and arrays are paired positionally; keys unknown to this model are
        ignored. The merged state dict is then loaded with ``strict``.
        """
        state = self.state_dict()
        incoming = {k: torch.from_numpy(a) for k, a in zip(keys, arrs) if k in state}
        state.update(incoming)
        self.load_state_dict(state, strict=strict)


# Training & inference primitives (FedProx-ready)
Implements one training epoch with per-horizon weighted BCE, optional FedProx proximal regularization against a global snapshot, and gradient clipping. Adds batched inference (predict_logits, which returns sigmoid probabilities together with the true labels), validation summaries (avg AUROC/PR-AUC), and comprehensive per-horizon reporting (F1, AUROC/PR-AUC with 95% CIs, P@R=0.90, ECE, and “earliness” when k_h_* is available).

In [None]:

def make_pos_weight_tensor(pos_weight_scalar):
    """Wrap a scalar positive-class weight as a one-element float32 tensor on the global ``device``."""
    values = [pos_weight_scalar]
    return torch.tensor(values, device=device, dtype=torch.float32)

def train_epoch(model, loader, optimizer, pos_weights: Dict[int, float],
                grad_clip=1.0, fedprox_mu=0.0, global_state: Dict[str, torch.Tensor] = None):
    """Run one training epoch with per-horizon weighted BCE and optional FedProx.

    Args:
        model: Module with a ``horizons`` list whose forward pass returns a dict
            mapping each horizon to logits of shape (B,).
        loader: Iterable of ``(xb, yb)`` batches; column ``i`` of ``yb`` holds
            the labels of ``model.horizons[i]``.
        optimizer: Optimizer over ``model``'s parameters.
        pos_weights: Positive-class weight per horizon.
        grad_clip: Max gradient norm; falsy or non-positive disables clipping.
        fedprox_mu: Proximal coefficient; the term is added only when > 0 and
            ``global_state`` is given.
        global_state: Reference tensors keyed by parameter name.

    Returns:
        Mean loss per batch (0.0 for an empty loader).
    """
    model.train()
    # Run on the device the model actually lives on, not on a notebook-level
    # global: callers such as per-actor clients may place the model elsewhere.
    dev = next(model.parameters()).device

    # Loss modules and proximal anchors do not change within the epoch, so they
    # are built and moved to the device once instead of once per batch.
    criteria = [
        nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor([pos_weights[h]], dtype=torch.float32, device=dev)
        )
        for h in model.horizons
    ]
    prox_pairs = []
    if fedprox_mu > 0.0 and global_state is not None:
        prox_pairs = [
            (param, global_state[name].detach().to(dev))
            for name, param in model.named_parameters()
            if param.requires_grad and name in global_state
        ]

    total = 0.0
    for xb, yb in loader:
        xb = xb.to(dev)          # (B, W, F)
        yb = yb.to(dev)          # (B, H)
        optimizer.zero_grad()
        logits = model(xb)       # dict h -> (B,)

        loss = 0.0
        for i, (h, criterion) in enumerate(zip(model.horizons, criteria)):
            loss = loss + criterion(logits[h], yb[:, i].float())

        # FedProx proximal term (client-side): (mu / 2) * ||w - w_global||^2
        if prox_pairs:
            prox = sum(((param - anchor) ** 2).sum() for param, anchor in prox_pairs)
            loss = loss + (fedprox_mu / 2.0) * prox

        loss.backward()
        if grad_clip and grad_clip > 0:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        total += float(loss.item())
    return total / max(1, len(loader))

@torch.no_grad()
def predict_logits(model, loader):
    """Run batched inference and collect labels and scores per horizon.

    Despite the name, the returned scores are sigmoid probabilities, not raw
    logits.

    Args:
        model: Module with a ``horizons`` list whose forward pass returns a dict
            mapping each horizon to logits of shape (B,).
        loader: Iterable of ``(xb, yb)`` batches; column ``i`` of ``yb`` holds
            the labels of ``model.horizons[i]``.

    Returns:
        ``(truths, scores)``: two dicts mapping each horizon to a 1-D NumPy
        array concatenated over all batches.
    """
    model.eval()
    # Run on the device the model actually lives on, not on a notebook-level
    # global: callers such as per-actor clients may place the model elsewhere.
    dev = next(model.parameters()).device

    scores = {h: [] for h in model.horizons}
    truths = {h: [] for h in model.horizons}
    for xb, yb in loader:
        out = model(xb.to(dev))
        for i, h in enumerate(model.horizons):
            scores[h].append(torch.sigmoid(out[h]).cpu().numpy())
            truths[h].append(yb[:, i].cpu().numpy())
    scores = {h: np.concatenate(scores[h], 0) for h in model.horizons}
    truths = {h: np.concatenate(truths[h], 0) for h in model.horizons}
    return truths, scores

def val_summary(model, dl_va):
    """Return (avg_pr_auc, avg_auroc) across horizons on validation loader.

    A horizon whose metric raises is left out of that average; an average with
    no usable horizon is reported as 0.0.
    """
    truths, probs = predict_logits(model, dl_va)

    def _mean_over_horizons(metric):
        collected = []
        for h in model.horizons:
            try:
                collected.append(metric(truths[h], probs[h]))
            except Exception:
                continue
        return float(np.mean(collected)) if collected else 0.0

    pr_avg = _mean_over_horizons(average_precision_score)
    roc_avg = _mean_over_horizons(roc_auc_score)
    return pr_avg, roc_avg

def pack_metrics_per_horizon(horizons, y_true_dict, y_score_dict, val_thr_dict, k_test=None):
    """Build the per-horizon evaluation report plus cross-horizon averages.

    Args:
        horizons: Horizons to report on.
        y_true_dict, y_score_dict: Labels / scores keyed by horizon.
        val_thr_dict: Decision threshold per horizon.
        k_test: Optional dict of ``k`` arrays keyed by horizon, used for the
            "Earliness" entry; horizons without an array get ``"N/A"``.

    Returns:
        Dict keyed by ``str(h)`` with the metrics of each horizon, plus
        ``_avg_<metric>`` entries for F1_macro, AUROC, PR_AUC and Prec@R0.90.
        An average is ``"N/A"`` when no horizon has a finite value.
    """
    report = {}
    for h in horizons:
        y_true_h = np.asarray(y_true_dict[h], dtype=int)
        y_score_h = np.asarray(y_score_dict[h], dtype=float)
        thr = val_thr_dict[h]

        rep = evaluate_split(y_true_h, y_score_h, thr)

        # Precision@Recall=0.90
        rep["Prec@R0.90"] = precision_at_recall(y_true_h, y_score_h, 0.90)

        # 95% CIs (bootstrap) for AUROC/PR-AUC; best-effort with a string fallback.
        for ci_key, metric in (("AUROC_CI95", roc_auc_score),
                               ("PR_AUC_CI95", average_precision_score)):
            try:
                lo, hi = bootstrap_ci(metric, y_true_h, y_score_h)
                rep[ci_key] = [lo, hi]
            except Exception:
                rep[ci_key] = ["nan", "nan"]

        # Expected Calibration Error
        try:
            rep["ECE"] = ece_score(y_true_h, y_score_h)
        except Exception:
            rep["ECE"] = "nan"

        # Earliness (if k_{h}_* exists): mean of (h - k) / h, clipped to [0, 1],
        # over true positives that have k > 0.
        rep["Earliness"] = "N/A"
        if k_test is not None and k_test.get(h) is not None:
            y_pred_h = (y_score_h >= thr).astype(int)
            k_arr = np.asarray(k_test[h], dtype=int)
            hit = (y_true_h == 1) & (y_pred_h == 1) & (k_arr > 0)
            if hit.any():
                lead = (h - k_arr[hit]) / float(h)
                rep["Earliness"] = float(np.clip(lead, 0.0, 1.0).mean())

        report[str(h)] = rep

    # Averages across horizons. Per-horizon metrics can be NaN (undefined on
    # that horizon); those are excluded so that one degenerate horizon does not
    # turn the whole average into NaN.
    for key in ["F1_macro", "AUROC", "PR_AUC", "Prec@R0.90"]:
        vals = [
            v
            for v in (report[str(h)][key] for h in horizons)
            if isinstance(v, (int, float)) and np.isfinite(v)
        ]
        report[f"_avg_{key}"] = float(np.mean(vals)) if vals else "N/A"

    return report


# Local training, curves, and centralized oracle
Trains a single-site model with early stopping on avg validation AUROC, saves learning-curve and calibration plots, and returns thresholds and a metrics report. Includes helpers to run all local baselines (results/local/) and to pool data across sites for a centralized “oracle” model with pooled class weights (results/centralized/).

In [None]:

def save_json(d, path):
    """Write ``d`` to ``path`` as 2-space-indented JSON, creating the parent directory if needed."""
    parent_dir = pathlib.Path(os.path.dirname(path))
    parent_dir.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as handle:
        json.dump(d, handle, indent=2)

# ---- helpers for validation/monitoring ----
@torch.no_grad()
def _val_metrics_avg(model, dl_va):
    """Compute threshold-free metrics on validation: avg AUROC/PR-AUC across horizons.

    Returns:
        ``(avg_auroc, avg_pr_auc, y_true_dict, y_score_dict)``. A horizon whose
        metric raises is skipped; an average with no usable horizon is NaN.
    """
    yv, sv = predict_logits(model, dl_va)

    def _mean_over_horizons(metric):
        collected = []
        for h in model.horizons:
            labels = np.asarray(yv[h])
            probs = np.asarray(sv[h])
            try:
                collected.append(metric(labels, probs))
            except Exception:
                continue
        return float(np.mean(collected)) if collected else float("nan")

    au = _mean_over_horizons(roc_auc_score)
    pr = _mean_over_horizons(average_precision_score)
    return au, pr, yv, sv
def train_local_model(
    data_dir,
    horizons,
    feature_dim,
    epochs,
    batch_size,
    lr,
    grad_clip,
    class_pos_weights,
    model_cls=HIPFedPredict,
    fedprox_mu=0.0,
    global_state=None,
    # --- early stopping knobs ---
    early_stop=True,
    patience=4,
    min_delta=1e-4,
):
    """
    Train one model on ``data_dir`` and evaluate it on the test split.

    The monitored quantity is the average validation AUROC (higher is better).
    The weights of the best epoch are restored before thresholds are chosen on
    the validation split and metrics are computed on the test split. Learning
    curves (JSON + PNG) and per-horizon calibration plots are written under
    ``RESULTS_DIR/local/<basename of data_dir>``.

    Returns:
        ``(model, report, val_thr)``: the trained model, the metrics report
        (including an ``_early_stop`` section) and the threshold per horizon.
    """
    from copy import deepcopy

    tag = model_cls.__name__
    model = model_cls(in_features=feature_dim, horizons=horizons).to(device)
    dl_tr, dl_va, dl_te, _ = make_loaders(data_dir, horizons, batch_size)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # Per-epoch histories used for the saved curves.
    hist = {"train_loss": [], "val_auroc_avg": [], "val_prauc_avg": []}

    # Early-stopping bookkeeping (monitor = val_auroc_avg, maximised).
    best_metric = -float("inf")
    best_state = None
    best_epoch = -1
    epochs_without_gain = 0

    # ---- training loop ----
    for ep in range(1, epochs + 1):
        tr_loss = train_epoch(
            model,
            dl_tr,
            opt,
            class_pos_weights,
            grad_clip=grad_clip,
            fedprox_mu=fedprox_mu,
            global_state=global_state,
        )
        val_au, val_pr, _, _ = _val_metrics_avg(model, dl_va)

        hist["train_loss"].append(float(tr_loss))
        hist["val_auroc_avg"].append(val_au)
        hist["val_prauc_avg"].append(val_pr)

        # A NaN monitor value never counts as an improvement.
        if np.isnan(val_au) or not (val_au > best_metric + min_delta):
            epochs_without_gain += 1
        else:
            best_metric = float(val_au)
            best_state = deepcopy(model.state_dict())
            best_epoch = ep
            epochs_without_gain = 0

        # Short progress line (handy in Colab logs).
        print(f"[{tag}] ep {ep:02d}/{epochs} | train_loss={tr_loss:.4f} | "
              f"val_AUROC={val_au:.4f} | val_PRAUC={val_pr:.4f} | best_ep={best_epoch} ", flush=True)

        if early_stop and epochs_without_gain >= patience:
            print(f"[{tag}] Early stop at epoch {ep} (patience={patience}, best_ep={best_epoch})")
            break

    # Restore the best weights, if any epoch ever improved the monitor.
    if best_state is not None:
        model.load_state_dict(best_state, strict=False)

    # ---- thresholds from validation, using the restored model ----
    yv, sv = predict_logits(model, dl_va)
    val_thr = {h: pick_threshold_from_val(yv[h], sv[h]) for h in horizons}

    # ---- final test evaluation ----
    yt, st = predict_logits(model, dl_te)
    _, _, Kte = load_split_arrays_with_k(data_dir, "test", horizons)
    report = pack_metrics_per_horizon(horizons, yt, st, val_thr, k_test=Kte)

    epochs_run = len(hist["train_loss"])
    report["_early_stop"] = {
        "enabled": bool(early_stop),
        "best_epoch": int(best_epoch if best_epoch != -1 else epochs_run),
        "monitor": "val_auroc_avg",
        "best_val_auroc": float(best_metric) if best_metric != -float("inf") else "nan",
        "total_epochs_run": int(epochs_run),
    }

    # ---- output directory is derived from the last component of data_dir ----
    client_name = os.path.basename(data_dir.rstrip("/"))
    base_dir = os.path.join(RESULTS_DIR, "local", client_name)
    pathlib.Path(base_dir).mkdir(parents=True, exist_ok=True)

    # JSON history
    save_json(hist, os.path.join(base_dir, f"{tag}_curves.json"))

    # Curve plots: (history key, y-axis label, title suffix, file suffix).
    curve_specs = (
        ("train_loss", "Train loss", "train loss", "train_loss"),
        ("val_auroc_avg", "Val AUROC (avg)", "val AUROC", "val_auroc"),
        ("val_prauc_avg", "Val PR-AUC (avg)", "val PR-AUC", "val_prauc"),
    )
    try:
        for hist_key, y_label, title_suffix, file_suffix in curve_specs:
            series = hist[hist_key]
            plt.figure()
            plt.plot(range(1, len(series) + 1), series, marker="o")
            plt.xlabel("Epoch")
            plt.ylabel(y_label)
            plt.title(f"{tag} — {title_suffix}")
            savefig(os.path.join(base_dir, f"{tag}_{file_suffix}.png"))
    except Exception:
        # Plotting is best-effort; a failure must not discard the trained model.
        pass

    # ---- calibration figures, one per horizon ----
    for h in horizons:
        plt.figure()
        plot_reliability(plt.gca(), yt[h], st[h], title=f"Calibration (h={h})")
        savefig(os.path.join(base_dir, f"{tag}_calib_h{h}.png"))

    return model, report, val_thr

def run_local_only_all():
    """Train and evaluate one local model per entry of ``CLIENTS``.

    Each client's report is written to ``RESULTS_DIR/local/<client>/results.json``
    and all of them together to ``RESULTS_DIR/local/summary.json``.
    """
    local_root = os.path.join(RESULTS_DIR, "local")
    summary = {}

    for cname, cdir in CLIENTS.items():
        print(f"[LOCAL] {cname}")
        # The loaders are built here only to learn the feature dimension.
        *_, feat_dim = make_loaders(cdir, HORIZONS, BATCH)
        posw = load_class_weights(cdir, HORIZONS)

        # Deep model
        _, report, _ = train_local_model(
            cdir,
            HORIZONS,
            feat_dim,
            epochs=EPOCHS_LOCAL,
            batch_size=BATCH,
            lr=LR,
            grad_clip=GRAD_CLIP,
            class_pos_weights=posw,
            model_cls=HIPFedPredict,
            early_stop=True,
            patience=4,
            min_delta=1e-4,
        )
        res = {"DL": {"HIPFedPredict": report}}

        # Persist this client's result and add it to the summary.
        save_json(res, os.path.join(local_root, cname, "results.json"))
        summary[cname] = res
        gc.collect()

    save_json(summary, os.path.join(local_root, "summary.json"))
    print("[LOCAL] done.")

def run_centralized_oracle():
    """Pool every client's splits and train one centralized model on them.

    The train, val and test arrays of all ``CLIENTS`` are concatenated in
    client order and saved under a fixed pooled directory. Positive-class
    weights are recomputed from the pooled training labels, and the report is
    written to ``RESULTS_DIR/centralized/results.json``.
    """
    POOL_DIR = "/content/drive/MyDrive/fed_MID/_central_pooled"

    def _pool(split):
        # Concatenate one split across all clients (features and labels).
        feats = []
        labels = {h: [] for h in HORIZONS}
        for cdir in CLIENTS.values():
            X, Y = load_split_arrays(cdir, split, HORIZONS)
            feats.append(np.asarray(X))
            for h in HORIZONS:
                labels[h].append(Y[h])
        return np.concatenate(feats, 0), {h: np.concatenate(labels[h], 0) for h in HORIZONS}

    def _dump(split, X, Y):
        # Features are stored as float32 and labels as int8.
        np.save(os.path.join(POOL_DIR, f"X_{split}.npy"), X.astype(np.float32))
        for h in HORIZONS:
            np.save(os.path.join(POOL_DIR, f"y_{h}_{split}.npy"), Y[h].astype(np.int8))

    # pool train/val
    Xtr_all, Ytr_all = _pool("train")
    Xva_all, Yva_all = _pool("val")
    pooled_F = Xtr_all.shape[2]

    pathlib.Path(POOL_DIR).mkdir(parents=True, exist_ok=True)
    _dump("train", Xtr_all, Ytr_all)
    _dump("val", Xva_all, Yva_all)

    # Pooled test = concat all tests
    Xte_all, Yte_all = _pool("test")
    _dump("test", Xte_all, Yte_all)

    # pos weights from pooled train: negatives/positives ratio capped at 50,
    # or 1.0 when there are no positives.
    pooled_posw = {}
    for h in HORIZONS:
        pos_rate = float(np.mean(Ytr_all[h]))
        if pos_rate > 0:
            pooled_posw[h] = float(min(50.0, (1.0 - pos_rate) / max(1e-6, pos_rate)))
        else:
            pooled_posw[h] = 1.0

    _, report, _ = train_local_model(
        POOL_DIR,
        HORIZONS,
        pooled_F,
        epochs=EPOCHS_LOCAL,
        batch_size=BATCH,
        lr=LR,
        grad_clip=GRAD_CLIP,
        class_pos_weights=pooled_posw,
        model_cls=HIPFedPredict,
        early_stop=True,
        patience=4,
        min_delta=1e-4,
    )
    res = {"DL": {"HIPFedPredict": report}}

    save_json(res, os.path.join(RESULTS_DIR, "centralized", "results.json"))
    print("[CENTRALIZED] done.")


**Flower client (NumPyClient)**
Implements a per-site client: sets device, builds loaders and class weights, initializes HIPFedPredict+Adam, and defines FL hooks. Shares only permitted layers, trains locally (optionally with FedProx), tracks the best validation AUROC to return the best shared weights + metrics, and saves each client’s full final model (local BN + private heads).

In [None]:
#cell9
from flwr.common import parameters_to_ndarrays  # keep your existing imports

class IoMTClient(NumPyClient):
    """Flower NumPyClient for one site.

    Holds the site's loaders, class weights, a ``HIPFedPredict`` model and an
    Adam optimizer. Only the arrays listed by ``model.shared_keys()`` are
    exchanged with the server; everything else stays on the client.
    """

    def __init__(self, name: str):
        self.name = name
        self.data_dir = CLIENTS[name]
        self.horizons = HORIZONS
        self.batch = BATCH
        self.epochs = FED_EPOCHS_PER_ROUND
        self.lr = LR
        self.grad_clip = GRAD_CLIP
        self.fedprox_mu = FED_FEDPROX_MU

        # --- pick device inside the actor (not from a global) ---
        _has_cuda = torch.cuda.is_available() and torch.cuda.device_count() > 0
        self.device = torch.device("cuda") if _has_cuda else torch.device("cpu")

        # data
        self.dl_tr, self.dl_va, self.dl_te, self.F = make_loaders(
            self.data_dir, self.horizons, self.batch
        )
        self.posw = load_class_weights(self.data_dir, self.horizons)

        # model/opt
        self.model = HIPFedPredict(self.F, self.horizons).to(self.device)
        self.opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def get_parameters(self, config):
        """Return the shared arrays in ``shared_keys()`` order."""
        _, arrs = self.model.get_shared_ndarrays()
        return arrs

    @staticmethod
    def _normalize_incoming(parameters):
        """Accept either list[np.ndarray] or Flower Parameters; return list[np.ndarray] or None."""
        if parameters is None:
            return None
        if isinstance(parameters, list):
            return parameters
        try:
            return parameters_to_ndarrays(parameters)
        except Exception as err:
            # Keep the original failure attached so the root cause stays visible.
            raise TypeError(f"Unsupported parameters type: {type(parameters)}") from err

    def _val_metrics(self):
        """Average AUROC / PR-AUC over horizons on the validation loader.

        A horizon whose metric cannot be computed is skipped; an average with
        no usable horizon is reported as 0.0.
        """
        yv, sv = predict_logits(self.model, self.dl_va)
        aurocs, praucs = [], []
        for h in self.horizons:
            try:
                aurocs.append(roc_auc_score(yv[h], sv[h]))
            except Exception:
                pass  # metric undefined for this horizon; leave it out of the mean
            try:
                praucs.append(average_precision_score(yv[h], sv[h]))
            except Exception:
                pass  # metric undefined for this horizon; leave it out of the mean
        return {
            "val_auroc_avg": float(np.mean(aurocs)) if aurocs else 0.0,
            "val_prauc_avg": float(np.mean(praucs)) if praucs else 0.0,
            "client_name": self.name,
        }

    def _snapshot_shared(self):
        """Return an independent copy of the current shared arrays.

        ``get_shared_ndarrays`` goes through ``tensor.cpu().numpy()``, which
        shares memory with the live parameters when the model is on the CPU.
        Without the copy, a "best epoch" snapshot would be silently overwritten
        by later optimizer steps.
        """
        _, live = self.model.get_shared_ndarrays()
        return [np.array(a, copy=True) for a in live]

    def fit(self, parameters, config):
        """Load the global shared weights, train locally and return the best shared weights.

        Returns:
            ``(shared_arrays, num_train_examples, metrics)`` where the arrays
            and metrics belong to the local epoch with the highest average
            validation AUROC.
        """
        # 1) Load global shared
        arrs = self._normalize_incoming(parameters)
        if arrs is not None:
            keys = self.model.shared_keys()
            self.model.load_shared_ndarrays(keys, arrs, strict=False)

        # 2) FedProx snapshot (pre-update)
        strategy_name = str(config.get("strategy_name", "FedAvg"))
        use_prox = config.get("strategy_name", "") == "FedProx"
        global_state = None
        if use_prox and self.fedprox_mu > 0.0:
            global_state = {k: v.detach().clone().cpu() for k, v in self.model.state_dict().items()}

        # 3) Train for self.epochs, keep best-by-val AUROC
        best_val = -float("inf")
        best_params_shared = None
        best_metrics = None

        for _ in range(self.epochs):
            train_epoch(
                self.model, self.dl_tr, self.opt, self.posw,
                grad_clip=self.grad_clip,
                fedprox_mu=(self.fedprox_mu if use_prox else 0.0),
                global_state=global_state,
            )

            metrics = self._val_metrics()
            if metrics["val_auroc_avg"] > best_val:
                best_val = metrics["val_auroc_avg"]
                best_params_shared = self._snapshot_shared()
                best_metrics = metrics

        # No epoch was recorded as best (e.g. zero local epochs): report the
        # current weights and their validation metrics.
        if best_params_shared is None:
            best_params_shared = self._snapshot_shared()
            best_metrics = self._val_metrics()

        # 4) Save full model on last round
        rnd = int(config.get("round", 1))
        nrounds = int(config.get("num_rounds", 1))
        if rnd == nrounds:
            save_dir = os.path.join(RESULTS_DIR, "federated", strategy_name, "final_clients")
            pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
            torch.save(self.model.state_dict(), os.path.join(save_dir, f"{self.name}_final.pt"))

        num_examples = len(self.dl_tr.dataset)
        # NumPyClient: return (list[np.ndarray], int, dict)
        return best_params_shared, num_examples, best_metrics

    def evaluate(self, parameters, config):
        """Placeholder: reports a loss of 0.0, the test-set size and no metrics."""
        return 0.0, len(self.dl_te.dataset), {}


# Server strategies & federation orchestration
Provides server-side strategies: FedAvg with round-wise weighted-metric logging and latest-weights retention, plus a β-Trimmed-Mean aggregator for robustness. Builds client/server apps with correctly initialized parameter ordering, runs the simulation, and saves per-round AUROC/PR-AUC plots and a round_log.json for analysis.

In [None]:
#cell10
class FedAvgKeepLast(FedAvg):
    """FedAvg that keeps latest aggregated parameters and logs round metrics.

    ``latest_params`` holds the most recent aggregated weights as ndarrays.
    ``round_log`` collects, per round, the example-weighted means of the
    clients' ``val_auroc_avg`` and ``val_prauc_avg`` fit metrics.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.latest_params = None
        self.round_log = []  # one dict of weighted fit metrics per round

    def aggregate_fit(self, rnd, results, failures):
        out = super().aggregate_fit(rnd, results, failures)
        if out is not None:
            params, _metrics = out
            if params is not None:
                self.latest_params = parameters_to_ndarrays(params)

        if not results:
            return out

        # Weight every client's metric by its number of training examples.
        fit_results = [res for _, res in results]
        total = sum(res.num_examples for res in fit_results)

        def _weighted_mean(metric_key):
            acc = 0.0
            for res in fit_results:
                client_metrics = res.metrics or {}
                acc += res.num_examples * float(client_metrics.get(metric_key, 0.0))
            return float(acc / total) if total > 0 else 0.0

        self.round_log.append({
            "round": int(rnd),
            "val_auroc_avg_w": _weighted_mean("val_auroc_avg"),
            "val_prauc_avg_w": _weighted_mean("val_prauc_avg"),
        })
        return out

class TrimmedMeanStrategy(FedAvgKeepLast):
    """Coordinate-wise β-trimmed-mean aggregation of the clients' shared weights.

    For every coordinate the client values are sorted, the ``floor(beta * k)``
    smallest and the same number of largest values are dropped (``k`` = number
    of clients), and the remainder is averaged without example weighting.
    Round metrics are logged in the same format as ``FedAvgKeepLast``.
    """

    def __init__(self, beta=0.1, **kwargs):
        super().__init__(**kwargs)
        self.beta = beta

    def aggregate_fit(self, rnd, results, failures):
        if not results:
            # record a placeholder so round indexing stays consistent
            self.round_log.append({
                "round": int(rnd),
                "val_auroc_avg_w": 0.0,
                "val_prauc_avg_w": 0.0
            })
            return None

        # collect ndarrays from each client
        all_nd = [parameters_to_ndarrays(res.parameters) for _, res in results]
        k = len(all_nd)

        # Number of values dropped at EACH end. It is clamped so that at least
        # one value always survives (a large beta would otherwise leave an
        # empty slice and produce NaN weights), and the upper bound is derived
        # from the same integer so both ends are trimmed equally regardless of
        # floating-point rounding in beta * k.
        trim = max(0, min(int(np.floor(self.beta * k)), (k - 1) // 2))

        # elementwise trimmed mean
        agg = []
        for layer_vals in zip(*all_nd):
            ordered = np.sort(np.stack(layer_vals, axis=0), axis=0)
            agg.append(ordered[trim:k - trim].mean(axis=0))

        # wrap and record
        params = ndarrays_to_parameters(agg)

        # log metrics same as parent (weighted by number of examples)
        total = sum(res.num_examples for _, res in results)
        w_auroc = 0.0
        w_prauc = 0.0
        for _, res in results:
            m = res.metrics or {}
            n = res.num_examples
            w_auroc += n * float(m.get("val_auroc_avg", 0.0))
            w_prauc += n * float(m.get("val_prauc_avg", 0.0))
        self.round_log.append({
            "round": int(rnd),
            "val_auroc_avg_w": float(w_auroc / total) if total > 0 else 0.0,
            "val_prauc_avg_w": float(w_prauc / total) if total > 0 else 0.0
        })
        self.latest_params = [a.copy() for a in agg]
        return (params, {})  # (parameters, metrics)

def build_client_app():
    """Create the ``ClientApp`` used for the federated simulation.

    The app's factory reads ``partition-id`` from the node config, maps it to
    a client name through ``PID_TO_NAME`` and returns the matching
    ``IoMTClient`` converted with ``to_client()``.
    """
    def _make_client(context: Context) -> Client:
        partition_id = int(context.node_config["partition-id"])
        return IoMTClient(PID_TO_NAME[partition_id]).to_client()

    app = ClientApp(client_fn=_make_client)
    return app

def build_server_app(strategy_name="FedAvg", trimmed_beta=0.1):
    """Build the ``ServerApp`` and the strategy object that drives it.

    Args:
        strategy_name: ``"FedAvg"``, ``"TrimmedMean"`` or ``"FedProx"``.
        trimmed_beta: Trim fraction passed to ``TrimmedMeanStrategy``;
            ignored by the other strategies.

    Returns:
        ``(server_app, strategy)``. The strategy is returned as well so the
        caller can read its state after the simulation.

    Raises:
        ValueError: If ``strategy_name`` is not one of the supported names.
    """
    # Validate first so a typo fails before any data is loaded or a model built.
    if strategy_name not in ("FedAvg", "TrimmedMean", "FedProx"):
        raise ValueError(f"Unknown strategy: {strategy_name!r}")

    # Initial global parameters come from a freshly built model's shared
    # arrays, so they follow the ordering of get_shared_ndarrays()
    # (presumably the same ordering the clients send -- confirm).
    any_dir = CLIENTS[CLIENT_NAMES[0]]
    _, _, _, Fdim = make_loaders(any_dir, HORIZONS, BATCH)
    init = HIPFedPredict(Fdim, HORIZONS).to(device)
    _, arrs = init.get_shared_ndarrays()
    initial_parameters = ndarrays_to_parameters(arrs)

    def fit_config(r):
        # Per-round config sent to every client.
        return {"round": r, "num_rounds": FED_NUM_ROUNDS,
                "local_epochs": FED_EPOCHS_PER_ROUND,
                "strategy_name": strategy_name}

    # Settings shared by all three strategies.
    common = dict(
        fraction_fit=1.0, fraction_evaluate=0.0,
        min_fit_clients=FED_NUM_CLIENTS, min_available_clients=FED_NUM_CLIENTS,
        initial_parameters=initial_parameters,
        on_fit_config_fn=fit_config,
    )

    if strategy_name == "TrimmedMean":
        strategy = TrimmedMeanStrategy(beta=trimmed_beta, **common)
    else:
        # FedProx aggregates exactly like FedAvg; the proximal term is applied
        # on the clients, which see "FedProx" in their fit config.
        strategy = FedAvgKeepLast(**common)

    def server_fn(_: Context) -> ServerAppComponents:
        return ServerAppComponents(strategy=strategy, config=ServerConfig(num_rounds=FED_NUM_ROUNDS))

    return ServerApp(server_fn=server_fn), strategy

def run_federated(strategy_name="FedAvg", trimmed_beta=FED_TRIMMED_BETA):
    """Run one federated simulation and save its per-round curves.

    Builds the client and server apps, runs the simulation with
    ``FED_NUM_CLIENTS`` supernodes, then writes ``val_auroc_rounds.png``,
    ``val_prauc_rounds.png`` and ``round_log.json`` under
    ``RESULTS_DIR/federated/<strategy_name>/``.

    Args:
        strategy_name: Strategy name forwarded to ``build_server_app``.
        trimmed_beta: Trim fraction forwarded to ``build_server_app``.

    Returns:
        The strategy object used by the server, after the simulation.
    """
    import torch  # local import: only needed here to probe for a GPU

    client_app = build_client_app()
    server_app, strategy = build_server_app(strategy_name=strategy_name, trimmed_beta=trimmed_beta)

    # Only request a GPU per client when one exists; a GPU request on a
    # CPU-only runtime cannot be satisfied, so no client would be scheduled.
    num_gpus = 1 if torch.cuda.is_available() else 0
    run_simulation(
        server_app=server_app,
        client_app=client_app,
        num_supernodes=FED_NUM_CLIENTS,
        backend_config={"client_resources": {"num_cpus": 1, "num_gpus": num_gpus}},
    )

    out_dir = os.path.join(RESULTS_DIR, "federated", strategy_name)
    round_log = strategy.round_log
    rounds = [r["round"] for r in round_log]

    # One curve per logged metric: (log key, y-axis label, file name).
    curves = (
        ("val_auroc_avg_w", "Val AUROC (avg)", "val_auroc_rounds.png"),
        ("val_prauc_avg_w", "Val PR-AUC (avg)", "val_prauc_rounds.png"),
    )
    for key, ylabel, fname in curves:
        plt.figure()
        plt.plot(rounds, [r[key] for r in round_log], marker="o")
        plt.xlabel("Round")
        plt.ylabel(ylabel)
        savefig(os.path.join(out_dir, fname))

    # save history json
    save_json(round_log, os.path.join(out_dir, "round_log.json"))
    return strategy


**Final federated evaluation (per client)**
Loads each client’s saved final model, overwrites only the shared trunk with the server’s latest global weights, re-selects decision thresholds on that client’s validation set, evaluates on its test set, and saves calibration plots plus a consolidated `final_reports.json` under `results/federated/<strategy>/`.

In [None]:
#cell11
@torch.no_grad()
def eval_federated_final(strategy, strategy_name="FedAvg"):
    """Evaluate the final federated model on every client's test split.

    For each client in ``CLIENTS`` this loads the saved client checkpoint,
    overwrites its shared parameters with ``strategy.latest_params``, picks
    one threshold per horizon on the validation split, computes test metrics
    and saves one reliability plot per horizon. Clients without a saved
    checkpoint are reported with a warning and skipped.

    Args:
        strategy: Strategy object exposing ``latest_params``.
        strategy_name: Name of the results sub-folder to read from and
            write to.

    Returns:
        Dict mapping client name to its metrics report; the same dict is
        written to ``final_reports.json``.
    """
    # latest global shared (trunk/proj, no BN, no heads)
    global_shared = strategy.latest_params
    out_dir = os.path.join(RESULTS_DIR, "federated", strategy_name)
    reports = {}

    for cname, cdir in CLIENTS.items():
        # Check for the checkpoint first so a missing one costs no data loading.
        ckpt_path = os.path.join(out_dir, "final_clients", f"{cname}_final.pt")
        if not os.path.exists(ckpt_path):
            print(f"[WARN] Missing saved model for {cname}: {ckpt_path}")
            continue

        # A single call yields both the evaluation loaders and the input dim.
        _, dl_va, dl_te, feat_dim = make_loaders(cdir, HORIZONS, BATCH)
        model = HIPFedPredict(feat_dim, HORIZONS).to(device)

        # Load the saved client model (has private heads and local BN)
        state = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(state, strict=False)

        # Overwrite only the shared (global) parts
        model.load_shared_ndarrays(model.shared_keys(), global_shared, strict=False)

        # Thresholds come from validation only, then are applied to test.
        # NOTE(review): predict_logits is assumed to switch the model to eval
        # mode -- confirm, since this function does not call model.eval().
        yv, sv = predict_logits(model, dl_va)
        val_thr = {h: pick_threshold_from_val(yv[h], sv[h]) for h in HORIZONS}
        yt, st = predict_logits(model, dl_te)
        _, _, Kte = load_split_arrays_with_k(cdir, "test", HORIZONS)
        rep = pack_metrics_per_horizon(HORIZONS, yt, st, val_thr, k_test=Kte)

        # Calibration figs
        for h in HORIZONS:
            plt.figure()
            plot_reliability(plt.gca(), yt[h], st[h], title=f"{strategy_name}:{cname} (h={h})")
            savefig(os.path.join(out_dir, "calibration", f"{cname}_h{h}.png"))

        reports[cname] = rep

    save_json(reports, os.path.join(out_dir, "final_reports.json"))
    print(f"[{strategy_name}] final per-client reports saved.")
    return reports


**Final toggles — Select and run experiments**
Simple flags choose which pipelines to execute (local-only, centralized oracle, and federated variants: FedAvg, Trimmed-Mean, FedProx). Running this cell executes the selected experiments, performs training/evaluation, and prints where outputs are stored.

In [None]:


# Experiment driver: each flag below enables one pipeline. The pipelines run in
# the order they appear; every function called here is defined in an earlier cell.
# Toggles
RUN_LOCAL_ONLY   = True
RUN_CENTRALIZED  = False
RUN_FEDAVG       = False
RUN_TRIMMED_MEAN = False
RUN_FEDPROX      = False

# 1) Local-only baselines (per client)
if RUN_LOCAL_ONLY:
    print("\n=== Running Local-only baselines ===")
    run_local_only_all()

# 2) Centralized (oracle)
if RUN_CENTRALIZED:
    print("\n=== Running Centralized (oracle on ALIGNED features) ===")
    run_centralized_oracle()

# 3) Federated (simulation)
# Each variant runs the simulation, then evaluates per client. The returned
# strategy and reports stay in module-level names (strat_*, reports_*) so they
# remain available after this cell has run.

if RUN_FEDAVG:
    print("\n=== Federated: FedAvg ===")
    strat_fa = run_federated(strategy_name="FedAvg")
    reports_fa = eval_federated_final(strat_fa, "FedAvg")

if RUN_TRIMMED_MEAN:
    print("\n=== Federated: Trimmed-Mean ===")
    # The trim fraction comes from the global FED_TRIMMED_BETA setting.
    strat_tm = run_federated(strategy_name="TrimmedMean", trimmed_beta=FED_TRIMMED_BETA)
    reports_tm = eval_federated_final(strat_tm, "TrimmedMean")

if RUN_FEDPROX:
    print("\n=== Federated: FedProx ===")

    strat_fp = run_federated(strategy_name="FedProx")
    reports_fp = eval_federated_final(strat_fp, "FedProx")

# Printed unconditionally, even when every toggle is off.
print("\nAll selected runs finished. Outputs live under:", RESULTS_DIR)



=== Running Local-only baselines ===
[LOCAL] ICE
[HIPFedPredict] ep 01/20 | train_loss=0.2818 | val_AUROC=0.8788 | val_PRAUC=0.8871 | best_ep=1 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 02/20 | train_loss=0.2524 | val_AUROC=0.8791 | val_PRAUC=0.8871 | best_ep=2 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 03/20 | train_loss=0.2469 | val_AUROC=0.8844 | val_PRAUC=0.8936 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 04/20 | train_loss=0.2431 | val_AUROC=0.8797 | val_PRAUC=0.8913 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 05/20 | train_loss=0.2384 | val_AUROC=0.8815 | val_PRAUC=0.8935 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 06/20 | train_loss=0.2342 | val_AUROC=0.8795 | val_PRAUC=0.8911 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 07/20 | train_loss=0.2316 | val_AUROC=0.8860 | val_PRAUC=0.8945 | best_ep=7 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 08/20 | train_loss=0.2283 | val_AUROC=0.8777 | val_PRAUC=0.8852 | best_ep=7 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 09/20 | train_loss=0.2258 | val_AUROC=0.8707 | val_PRAUC=0.8809 | best_ep=7 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 10/20 | train_loss=0.2238 | val_AUROC=0.8698 | val_PRAUC=0.8757 | best_ep=7 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 11/20 | train_loss=0.2205 | val_AUROC=0.8749 | val_PRAUC=0.8857 | best_ep=7 
[HIPFedPredict] Early stop at epoch 11 (patience=4, best_ep=7)


  return datetime.utcnow().replace(tzinfo=utc)


[LOCAL] IOMT_A


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 01/20 | train_loss=0.0028 | val_AUROC=0.9986 | val_PRAUC=0.9988 | best_ep=1 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 02/20 | train_loss=0.0016 | val_AUROC=0.9984 | val_PRAUC=0.9986 | best_ep=1 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 03/20 | train_loss=0.0014 | val_AUROC=0.9999 | val_PRAUC=0.9999 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 04/20 | train_loss=0.0014 | val_AUROC=0.9999 | val_PRAUC=0.9999 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 05/20 | train_loss=0.0014 | val_AUROC=0.9998 | val_PRAUC=0.9999 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 06/20 | train_loss=0.0013 | val_AUROC=0.9999 | val_PRAUC=0.9999 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 07/20 | train_loss=0.0012 | val_AUROC=0.9999 | val_PRAUC=0.9999 | best_ep=3 
[HIPFedPredict] Early stop at epoch 7 (patience=4, best_ep=3)


  return datetime.utcnow().replace(tzinfo=utc)


[LOCAL] IOMT_B


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 01/20 | train_loss=0.0010 | val_AUROC=0.9997 | val_PRAUC=0.9998 | best_ep=1 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 02/20 | train_loss=0.0004 | val_AUROC=0.9998 | val_PRAUC=0.9999 | best_ep=1 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 03/20 | train_loss=0.0003 | val_AUROC=0.9999 | val_PRAUC=0.9999 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 04/20 | train_loss=0.0003 | val_AUROC=0.9999 | val_PRAUC=0.9999 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 05/20 | train_loss=0.0004 | val_AUROC=0.9998 | val_PRAUC=0.9999 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 06/20 | train_loss=0.0003 | val_AUROC=0.9995 | val_PRAUC=0.9997 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 07/20 | train_loss=0.0003 | val_AUROC=0.9998 | val_PRAUC=0.9999 | best_ep=3 
[HIPFedPredict] Early stop at epoch 7 (patience=4, best_ep=3)


  return datetime.utcnow().replace(tzinfo=utc)


[LOCAL] WUSTL


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 01/20 | train_loss=2.4745 | val_AUROC=0.9478 | val_PRAUC=0.6812 | best_ep=1 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 02/20 | train_loss=1.8824 | val_AUROC=0.9374 | val_PRAUC=0.6940 | best_ep=1 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 03/20 | train_loss=1.7392 | val_AUROC=0.9571 | val_PRAUC=0.7612 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 04/20 | train_loss=1.6582 | val_AUROC=0.9292 | val_PRAUC=0.7014 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 05/20 | train_loss=1.4911 | val_AUROC=0.9459 | val_PRAUC=0.7710 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 06/20 | train_loss=1.3226 | val_AUROC=0.9197 | val_PRAUC=0.6605 | best_ep=3 


  return datetime.utcnow().replace(tzinfo=utc)


[HIPFedPredict] ep 07/20 | train_loss=1.1672 | val_AUROC=0.9299 | val_PRAUC=0.7145 | best_ep=3 
[HIPFedPredict] Early stop at epoch 7 (patience=4, best_ep=3)


  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)


[LOCAL] done.

All selected runs finished. Outputs live under: /content/drive/MyDrive/fed_MID/results


  return datetime.utcnow().replace(tzinfo=utc)
