[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CalculatedContent/xgboost2ww/blob/main/notebooks/XGBWWDataMultiSourceXGBoost2WW.ipynb)

# Multi-source xgbwwdata + xgboost2ww Experiments

This Colab notebook trains strong XGBoost classifiers across multiple dataset sources from `xgbwwdata`, converts each trained model to an `xgboost2ww` layer, and evaluates the resulting spectral diagnostics with WeightWatcher.

Key goals:
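- Pull up to **N=20 binary-classification datasets from each source** (`openml`, `pmlb`, `keel`, `libsvm`, `amlb`).
- Tune and train an XGBoost model with cross-validation on the training split only (as in `GoodModelsXGBoost2WW.ipynb`); models whose test accuracy falls below `MIN_GOOD_TEST_ACC` are not analyzed.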
- Convert each model to a chosen xgboost2ww matrix (`W1`, `W2`, `W7`, or `W8`).
- Run WeightWatcher with `ERG=True` and `randomize=True`.
- Plot train/test accuracies versus `alpha`, `ERG_gap`, and `num_traps` in dedicated cells.


In [None]:
#@title Experiment configuration
MATRIX = "W8"  # @param ["W1", "W2", "W7", "W8"]

# Sources to draw datasets from; matched case-/whitespace-insensitively against
# the catalog's `source` column.
DATA_SOURCES = ["openml", "pmlb", "keel", "libsvm", "amlb"]
# Maximum number of accepted datasets per source.
DATASETS_PER_SOURCE = 20
TARGET_DATASETS = DATASETS_PER_SOURCE * len(DATA_SOURCES)

# Precomputed xgbwwdata catalog; must contain `dataset_uid` and `source` columns.
CATALOG_CSV = "/content/drive/MyDrive/xgbwwdata/catalog_checkpoint/dataset_catalog.csv"
# Global seed for the train/test split, CV folds, XGBoost and xgboost2ww conversion.
RNG = 0

TEST_SIZE = 0.20  # held-out test fraction (stratified split)
NFOLDS = 5  # folds for hyperparameter CV and for xgboost2ww.convert
T_TRAJ = 160  # `t_points` passed to xgboost2ww.convert

MAX_ROWS = 60000  # per-dataset row cap; check the experiment loop for how it is applied
MAX_FEATURES_GUARD = 50_000  # skip datasets with more feature columns than this
MAX_DENSE_ELEMENTS = int(2e8)  # skip if dense train rows*cols (or sparse train nnz) exceeds this

GOOD_TRIALS = 5  # random hyperparameter configurations tried per dataset
CV_MAX_ROUNDS = 3000  # upper bound on boosting rounds during CV
CV_EARLY_STOP = 150  # early-stopping patience (rounds) during CV
MIN_GOOD_TEST_ACC = 0.75  # models below this test accuracy are not converted/analyzed


In [None]:
#@title Mount Google Drive and create output directory
from google.colab import drive
import os
from datetime import datetime

GDRIVE_DIR = "/content/drive/MyDrive/xgboost2ww_runs"

# Attach Google Drive (an existing mount is reused, not remounted) so results
# survive runtime restarts.
drive.mount("/content/drive", force_remount=False)

# Results directory on Drive: created on first use, reused afterwards.
os.makedirs(GDRIVE_DIR, exist_ok=True)
print("Saving results under:", GDRIVE_DIR)


In [None]:
#@title Install dependencies and xgbwwdata
# System git, needed to clone the xgbwwdata repository below.
!apt-get -qq update && apt-get -qq install -y git

# Python dependencies (pandas pinned to 2.2.2; pyarrow is used by the feather save/load cells).
%pip install -q -U pip setuptools wheel
%pip install -q "pandas==2.2.2" xgboost weightwatcher scikit-learn scipy pyarrow xgboost2ww

# Fresh clone of xgbwwdata, then run the repo's Colab install script
# (presumably installs xgbwwdata into this runtime — see scripts/colab_install.py).
!rm -rf /content/repo_xgbwwdata
!git clone https://github.com/CalculatedContent/xgbwwdata.git /content/repo_xgbwwdata
%run /content/repo_xgbwwdata/scripts/colab_install.py --repo /content/repo_xgbwwdata

# Sanity check: show which installed copy of each package is actually imported.
import xgboost2ww
import xgbwwdata
print("xgboost2ww:", getattr(xgboost2ww, "__file__", None))
print("xgbwwdata:", getattr(xgbwwdata, "__file__", None))


In [None]:
#@title Imports and shared helpers
import warnings, time, gc
from pathlib import Path
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import xgboost as xgb
import matplotlib.pyplot as plt
from scipy import sparse as sp

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, log_loss

import torch
import weightwatcher as ww

from xgbwwdata import load_dataset
from xgboost2ww import convert

# Notebook-level NumPy generator seeded from RNG.
# NOTE(review): `rng` is not referenced by any other cell in this notebook;
# per-dataset randomness is drawn from generators created inside the helpers.
rng = np.random.default_rng(RNG)


In [None]:
#@title Optional: GPU detection for XGBoost

def xgb_gpu_available() -> bool:
    """Return True if a tiny XGBoost training run succeeds on the GPU backend.

    The probe fits a 5-round, depth-2 binary model on 256 random rows with
    ``tree_method="gpu_hist"``. Any exception while building the probe data or
    training (for example no GPU, or a CPU-only XGBoost build) means "no GPU".
    """
    try:
        features = np.random.randn(256, 8).astype(np.float32)
        labels = (features[:, 0] > 0).astype(np.int32)
        probe = xgb.DMatrix(features, label=labels)
        probe_params = {
            "objective": "binary:logistic",
            "eval_metric": "logloss",
            "tree_method": "gpu_hist",
            "predictor": "gpu_predictor",
            "max_depth": 2,
            "learning_rate": 0.2,
            "seed": RNG,
        }
        xgb.train(params=probe_params, dtrain=probe, num_boost_round=5, verbose_eval=False)
    except Exception:
        return False
    return True


USE_GPU = xgb_gpu_available()
print("XGBoost GPU available:", USE_GPU)


In [None]:
#@title Load dataset candidates from xgbwwdata catalog checkpoint CSV
# Fail fast when the checkpointed catalog has not been generated yet.
if not Path(CATALOG_CSV).exists():
    raise FileNotFoundError(f"Catalog CSV not found: {CATALOG_CSV}")

df_catalog = pd.read_csv(CATALOG_CSV)
print("Catalog shape:", df_catalog.shape)

# The experiment loop needs at least these two columns.
required_cols = {"dataset_uid", "source"}
missing = required_cols - set(df_catalog.columns)
if missing:
    raise ValueError(f"Catalog is missing required columns: {missing}")

# Trim and lower-case sources on both sides so e.g. " OpenML" in the catalog
# still matches "openml" in DATA_SOURCES.
normalized_sources = [str(src).strip().lower() for src in DATA_SOURCES]
df_catalog["source_norm"] = (
    df_catalog["source"]
    .astype(str)
    .str.strip()
    .str.lower()
)

in_config = df_catalog["source_norm"].isin(normalized_sources)
df_registry = df_catalog.loc[in_config].copy()
if df_registry.empty:
    available = sorted(df_catalog["source_norm"].dropna().unique())[:20]
    raise RuntimeError(
        "No datasets found in catalog for selected DATA_SOURCES after normalization. "
        f"Configured={normalized_sources}; available={available}"
    )

# When the catalog records a task type, keep rows whose task_type contains
# "classification". Multiclass entries can pass this filter; the binary-label
# requirement is enforced per dataset by encode_binary_labels in the experiment loop.
if "task_type" in df_registry.columns:
    is_classification = df_registry["task_type"].astype(str).str.contains(
        "classification", case=False, na=False
    )
    df_registry = df_registry.loc[is_classification].copy()

print("Candidates loaded from catalog:", len(df_registry))
preview_cols = [
    col for col in ("source", "dataset_uid", "name", "task_type") if col in df_registry.columns
]
display(df_registry[preview_cols].head(10))
print(df_registry["source_norm"].value_counts(dropna=False))


In [None]:

#@title Training and WeightWatcher utilities
import hashlib

def stable_int_seed(value, mod=(2**31 - 1)):
    """Map ``value`` to a reproducible integer in ``[0, mod)``.

    The built-in ``hash`` is salted per process for strings. This function instead
    hashes ``str(value)`` with SHA-256, so the same dataset id yields the same seed
    in every session. The first 8 digest bytes are read as a little-endian
    unsigned integer and reduced modulo ``mod`` (default ``2**31 - 1``).
    """
    payload = str(value).encode("utf-8")
    head = hashlib.sha256(payload).digest()[:8]
    return int.from_bytes(head, byteorder="little") % int(mod)


def encode_binary_labels(y):
    """Encode a two-valued label vector as 0/1 ``int32`` codes.

    Codes follow the sorted order of the distinct labels (the smaller label
    becomes 0).

    Returns:
        ``(codes, classes)``: ``codes`` has one int32 entry per element of ``y``,
        and ``classes`` holds the two original label values in sorted order.

    Raises:
        ValueError: if ``y`` does not contain exactly two distinct values.
    """
    labels = np.asarray(y)
    uniq, inverse = np.unique(labels, return_inverse=True)
    if len(uniq) == 2:
        return inverse.astype(np.int32), uniq
    raise ValueError(f"Expected binary labels, got classes={uniq}")


def pick_good_params_via_cv(Xtr, ytr, nfold=5, *, dataset_seed: int):
    """Random-search XGBoost hyperparameters using CV on the training split only.

    Draws ``GOOD_TRIALS`` random configurations from a generator seeded with
    ``RNG + dataset_seed``, so each dataset gets its own reproducible search. Each
    configuration is scored with stratified ``xgb.cv`` using early stopping, and
    the configuration with the lowest mean held-out log-loss is kept.

    Args:
        Xtr: Training features (anything ``xgb.DMatrix`` accepts).
        ytr: Training labels; the objective is ``binary:logistic``, so 0/1 codes are expected.
        nfold: Number of CV folds.
        dataset_seed: Per-dataset offset added to ``RNG`` to seed the search generator.

    Returns:
        ``(params, rounds, cv_logloss)`` for the best trial. ``rounds`` is the
        length of the CV history, i.e. the rounds retained by early stopping.
        Returns ``None`` if no trial scored below ``inf`` (e.g. ``GOOD_TRIALS == 0``
        or every score was NaN).
    """
    dtrain = xgb.DMatrix(Xtr, label=ytr)
    local_rng = np.random.default_rng(RNG + int(dataset_seed))

    best = None
    best_score = np.inf

    for _ in range(GOOD_TRIALS):
        # The order of the local_rng draws below fixes which configurations are
        # sampled for a given seed; reordering the keys changes the results.
        # Ranges: learning_rate ~[0.01, 0.25] (log-uniform), max_depth 2..6,
        # min_child_weight and reg_lambda in [1, 100] (log-uniform).
        params = dict(
            objective="binary:logistic",
            eval_metric="logloss",
            tree_method="hist",
            seed=RNG,
            learning_rate=float(10 ** local_rng.uniform(-2.0, -0.6)),
            max_depth=int(local_rng.integers(2, 7)),
            min_child_weight=float(10 ** local_rng.uniform(0.0, 2.0)),
            subsample=float(local_rng.uniform(0.6, 0.9)),
            colsample_bytree=float(local_rng.uniform(0.6, 0.9)),
            reg_lambda=float(10 ** local_rng.uniform(0.0, 2.0)),
            gamma=float(local_rng.uniform(0.0, 0.5)),
        )
        if USE_GPU:
            # NOTE(review): legacy GPU parameter names; XGBoost >= 2.0 deprecates
            # them in favour of device="cuda" — confirm against the installed version.
            params["tree_method"] = "gpu_hist"
            params["predictor"] = "gpu_predictor"

        # Same fold seed for every trial, so the trial scores are comparable.
        cv = xgb.cv(
            params=params,
            dtrain=dtrain,
            num_boost_round=CV_MAX_ROUNDS,
            nfold=nfold,
            stratified=True,
            early_stopping_rounds=CV_EARLY_STOP,
            seed=RNG,
            verbose_eval=False,
        )

        # With early stopping, xgb.cv ends its history at the best iteration, so
        # the last row is the best score and len(cv) is the matching round count.
        score = float(cv["test-logloss-mean"].iloc[-1])
        rounds = int(len(cv))
        if score < best_score:
            best_score = score
            best = (params, rounds, score)

    return best


def train_eval_fulltrain(Xtr, ytr, Xte, yte, params, rounds):
    """Fit on the full training split, then score on both train and test.

    Trains for exactly ``rounds`` boosting rounds with no early stopping. Raw
    margins are converted to probabilities with a float32 sigmoid and
    thresholded at 0.5.

    Returns:
        ``(train_acc, test_acc, test_loss, booster)``, where ``test_loss`` is the
        binary log-loss on the test split.
    """
    train_matrix = xgb.DMatrix(Xtr, label=ytr)
    test_matrix = xgb.DMatrix(Xte, label=yte)

    booster = xgb.train(params=params, dtrain=train_matrix, num_boost_round=rounds, verbose_eval=False)

    def _positive_prob(dmatrix):
        # Sigmoid of the float32 margins; staying in float32 keeps the 0.5
        # threshold decision numerically identical to the original computation.
        margins = booster.predict(dmatrix, output_margin=True).astype(np.float32)
        return 1.0 / (1.0 + np.exp(-margins))

    prob_train = _positive_prob(train_matrix)
    prob_test = _positive_prob(test_matrix)

    train_acc = float(accuracy_score(ytr, (prob_train >= 0.5).astype(int)))
    test_acc = float(accuracy_score(yte, (prob_test >= 0.5).astype(int)))

    # Two-column (P(class 0), P(class 1)) matrix for sklearn's log_loss.
    prob_matrix = np.column_stack([1 - prob_test, prob_test])
    test_loss = float(log_loss(yte, prob_matrix, labels=[0, 1]))

    return train_acc, test_acc, test_loss, booster


def first_metric(details_df, column):
    """Return the first value of ``details_df[column]`` as a float, or NaN.

    NaN is returned instead of raising when ``details_df`` is None, has no rows,
    lacks ``column``, or holds a value that cannot be converted to float
    (e.g. None).
    """
    if details_df is None or len(details_df) == 0:
        return np.nan
    if column not in details_df.columns:
        return np.nan
    try:
        return float(details_df[column].iloc[0])
    except (TypeError, ValueError):
        return np.nan


def ww_metrics_from_layer(layer):
    """Run WeightWatcher on a single converted layer and extract headline metrics.

    Calls ``analyze(randomize=True, ERG=True, plot=False)`` and reads the first
    row of the returned details table.

    Returns:
        ``(alpha, num_traps, ERG_gap)`` as floats. A missing column, a
        non-numeric value, or an empty details table yields NaN for that metric
        instead of raising.
    """
    watcher = ww.WeightWatcher(model=layer)
    details_df = watcher.analyze(randomize=True, ERG=True, plot=False)
    alpha = first_metric(details_df, "alpha")
    traps = first_metric(details_df, "num_traps")
    ERG_gap = first_metric(details_df, "ERG_gap")
    return alpha, traps, ERG_gap


In [None]:
#@title Run experiment: keep up to N datasets per source
import collections

rows = []
rows_below_threshold = []
accepted_by_source = {str(s).strip().lower(): 0 for s in DATA_SOURCES}
skip = collections.Counter()

# Adaptive fallback so downstream plotting still works when thresholds are too strict.
fallback_min_keep = max(10, len(DATA_SOURCES))


def _all_sources_full():
    """Return True once every configured source has DATASETS_PER_SOURCE accepted datasets."""
    return all(v >= DATASETS_PER_SOURCE for v in accepted_by_source.values())


def _subsample_rows(y_all, max_rows, seed):
    """Return sorted indices of a reproducible ``max_rows``-row subsample of ``y_all``.

    The subsample is stratified on the labels when possible. It falls back to a
    plain seeded draw when sklearn cannot stratify, e.g. when the discarded
    remainder has fewer rows than there are classes.
    """
    all_idx = np.arange(len(y_all))
    try:
        keep, _ = train_test_split(all_idx, train_size=max_rows, random_state=seed, stratify=y_all)
    except ValueError:
        keep = np.random.default_rng(seed).choice(all_idx, size=max_rows, replace=False)
    return np.sort(keep)


print("Registry size:", len(df_registry))
print("Registry sources:\n", df_registry["source_norm"].value_counts(dropna=False))
print("DATA_SOURCES:", DATA_SOURCES)
print("DATASETS_PER_SOURCE:", DATASETS_PER_SOURCE)
print("TARGET_DATASETS:", TARGET_DATASETS)

registry_source_keys = set(df_registry["source_norm"].dropna().astype(str).unique()) if "source_norm" in df_registry.columns else set()
config_source_keys = set(accepted_by_source.keys())
if len(df_registry) == 0:
    print("df_registry is empty after filtering")
elif registry_source_keys and registry_source_keys.isdisjoint(config_source_keys):
    print("source keys mismatch")


t0 = time.time()
loop_iterations = 0
for i, (_, rec) in enumerate(df_registry.iterrows(), start=1):
    loop_iterations += 1

    # Checked before the per-source skips. Once every source is full, each record
    # hits "source_already_full" first, so an exit placed after that skip is never reached.
    if _all_sources_full():
        break

    source = str(rec.get("source_norm", rec.get("source", "unknown"))).strip().lower()
    dataset_uid = rec.get("dataset_uid", "<missing_uid>")

    if i % 25 == 0:
        print(f"progress i={i} dataset_uid={dataset_uid} source={source} accepted={accepted_by_source.get(source, 'n/a')}", flush=True)

    if source not in accepted_by_source:
        skip["source_not_in_config"] += 1
        continue
    if accepted_by_source[source] >= DATASETS_PER_SOURCE:
        skip["source_already_full"] += 1
        continue

    try:
        X, y, meta = load_dataset(dataset_uid=dataset_uid, source=source, preprocess=True)
    except Exception as e:
        skip["load_fail"] += 1
        print("SKIP load:", dataset_uid, type(e).__name__, e)
        continue

    try:
        y, class_values = encode_binary_labels(y)
    except Exception:
        skip["nonbinary_labels"] += 1
        continue

    if int(X.shape[1]) > MAX_FEATURES_GUARD:
        skip["too_many_features_guard"] += 1
        continue

    if np.min(np.bincount(y)) < 2:
        skip["min_class_count_lt2"] += 1
        continue

    # Row fancy-indexing below needs CSR; other sparse formats (e.g. COO) reject it.
    if sp.issparse(X):
        X = X.tocsr()

    # Enforce the configured MAX_ROWS cap with a reproducible subsample.
    n_rows_total = int(X.shape[0])
    if n_rows_total > MAX_ROWS:
        keep_idx = _subsample_rows(y, MAX_ROWS, RNG)
        X, y = X[keep_idx], y[keep_idx]
        if np.min(np.bincount(y, minlength=2)) < 2:
            skip["min_class_count_lt2"] += 1
            continue

    try:
        tr_idx, te_idx = train_test_split(
            np.arange(len(y)),
            test_size=TEST_SIZE,
            random_state=RNG,
            stratify=y,
        )
    except Exception:
        skip["split_fail"] += 1
        continue

    Xtr, Xte = X[tr_idx], X[te_idx]
    ytr, yte = y[tr_idx], y[te_idx]

    is_sparse = sp.issparse(Xtr)
    if is_sparse:
        Xtr = Xtr.tocsr().astype(np.float32)
        Xte = Xte.tocsr().astype(np.float32)
        # Sparse inputs are capped on stored non-zeros rather than rows*cols.
        if int(Xtr.nnz) > MAX_DENSE_ELEMENTS:
            skip["sparse_guard_triggered"] += 1
            continue
    else:
        Xtr = np.asarray(Xtr, dtype=np.float32)
        Xte = np.asarray(Xte, dtype=np.float32)
        if int(Xtr.size) > MAX_DENSE_ELEMENTS:
            skip["dense_guard_triggered"] += 1
            continue

    seed_from_uid = stable_int_seed(dataset_uid)
    try:
        good_params, good_rounds, good_cv_logloss = pick_good_params_via_cv(
            Xtr, ytr, nfold=NFOLDS, dataset_seed=seed_from_uid
        )

        good_train_acc, good_test_acc, good_test_loss, bst = train_eval_fulltrain(
            Xtr, ytr, Xte, yte, good_params, good_rounds
        )
    except Exception as e:
        skip["train_fail"] += 1
        print("SKIP train:", dataset_uid, type(e).__name__, e)
        del X, y, Xtr, Xte, ytr, yte
        gc.collect()
        continue

    base_row = dict(
        dataset_uid=dataset_uid,
        source=source,
        dataset=meta.get("name", rec.get("name", dataset_uid)),
        original_classes=str(tuple(class_values.tolist())),
        n_rows_total=n_rows_total,          # rows in the loaded dataset
        n_rows_used=int(X.shape[0]),        # rows after the MAX_ROWS cap
        n_train=int(Xtr.shape[0]),
        n_test=int(Xte.shape[0]),
        n_features=int(X.shape[1]),
        rounds=int(good_rounds),
        cv_logloss=float(good_cv_logloss),
        good_train_acc=float(good_train_acc),
        good_test_acc=float(good_test_acc),
        good_test_loss=float(good_test_loss),
    )

    if good_test_acc < MIN_GOOD_TEST_ACC:
        skip["below_min_test_acc"] += 1
        rows_below_threshold.append(base_row)
        del bst, X, y, Xtr, Xte, ytr, yte
        gc.collect()
        continue

    try:
        layer_W = convert(
            model=bst,
            data=Xtr,
            labels=ytr,
            W=MATRIX,
            nfolds=NFOLDS,
            t_points=T_TRAJ,
            random_state=RNG,
            train_params=good_params,
            num_boost_round=good_rounds,
            multiclass="error",
            return_type="torch",
            verbose=False,
        )
    except Exception as e:
        skip["convert_fail"] += 1
        print("SKIP convert:", dataset_uid, type(e).__name__, e)
        del bst, X, y, Xtr, Xte, ytr, yte
        gc.collect()
        continue

    # A WeightWatcher failure on one layer must not abort the whole run.
    try:
        alpha_W, traps_W, ERG_gap_W = ww_metrics_from_layer(layer_W)
    except Exception as e:
        skip["ww_fail"] += 1
        print("SKIP ww:", dataset_uid, type(e).__name__, e)
        del bst, layer_W, X, y, Xtr, Xte, ytr, yte
        gc.collect()
        continue

    rows.append(dict(base_row, alpha_W=float(alpha_W), traps_W=float(traps_W), ERG_gap_W=float(ERG_gap_W)))

    accepted_by_source[source] += 1
    kept_total = sum(accepted_by_source.values())
    elapsed = (time.time() - t0) / 60.0
    print(
        f"[{kept_total}/{TARGET_DATASETS}] src={source} ({accepted_by_source[source]}/{DATASETS_PER_SOURCE}) "
        f"{meta.get('name', dataset_uid)} | train/test={good_train_acc:.3f}/{good_test_acc:.3f} "
        f"| α={alpha_W:.2f} traps={traps_W:.1f} ERG_gap={ERG_gap_W:.2f} | elapsed={elapsed:.1f} min",
        flush=True,
    )

    del bst, layer_W, X, y, Xtr, Xte, ytr, yte
    gc.collect()

if loop_iterations == 0:
    if len(df_registry) == 0:
        print("df_registry is empty after filtering")
    else:
        print("source keys mismatch")

if len(rows) == 0 and len(rows_below_threshold) > 0:
    print(
        f"No datasets passed MIN_GOOD_TEST_ACC={MIN_GOOD_TEST_ACC:.3f}. "
        f"Falling back to top-{fallback_min_keep} by test accuracy for diagnostics. "
        f"below-threshold candidates={len(rows_below_threshold)}"
    )
    fallback_df = pd.DataFrame(rows_below_threshold).sort_values("good_test_acc", ascending=False).head(fallback_min_keep)
    fallback_df["alpha_W"] = np.nan
    fallback_df["traps_W"] = np.nan
    fallback_df["ERG_gap_W"] = np.nan
    fallback_df["fallback_only"] = True
    rows = fallback_df.to_dict(orient="records")

skip_summary = dict(sorted(skip.items(), key=lambda kv: kv[1], reverse=True))
print("Skip counters:", skip_summary)
print("Accepted per source:", accepted_by_source)

df_good = pd.DataFrame(rows)
print(f"DONE. datasets_kept={df_good['dataset_uid'].nunique() if len(df_good) else 0} rows={len(df_good)}")
display(df_good.head(20))


In [None]:
#@title Plot train/test accuracies vs alpha_W
if not df_good.empty:
    # One scatter series per split, both against the same spectral metric.
    metric = "alpha_W"
    fig, ax = plt.subplots(figsize=(6, 4))
    for acc_col, series_label in (("good_train_acc", "train_acc"), ("good_test_acc", "test_acc")):
        ax.scatter(df_good[metric], df_good[acc_col], label=series_label, alpha=0.8)
    ax.set_xlabel(metric)
    ax.set_ylabel("accuracy")
    ax.set_title(f"Train/Test Accuracy vs alpha ({MATRIX})")
    ax.legend()
    fig.tight_layout()
    plt.show()
else:
    print("No datasets kept. Try lowering MIN_GOOD_TEST_ACC or reviewing the catalog selection.")


In [None]:
#@title Plot train/test accuracies vs ERG_gap_W
if not df_good.empty:
    # One scatter series per split, both against the same spectral metric.
    metric = "ERG_gap_W"
    fig, ax = plt.subplots(figsize=(6, 4))
    for acc_col, series_label in (("good_train_acc", "train_acc"), ("good_test_acc", "test_acc")):
        ax.scatter(df_good[metric], df_good[acc_col], label=series_label, alpha=0.8)
    ax.set_xlabel(metric)
    ax.set_ylabel("accuracy")
    ax.set_title(f"Train/Test Accuracy vs ERG_gap ({MATRIX})")
    ax.legend()
    fig.tight_layout()
    plt.show()
else:
    print("No datasets kept. Try lowering MIN_GOOD_TEST_ACC or reviewing the catalog selection.")


In [None]:
#@title Plot train/test accuracies vs traps_W (num_traps)
if not df_good.empty:
    # One scatter series per split, both against the same spectral metric.
    metric = "traps_W"
    fig, ax = plt.subplots(figsize=(6, 4))
    for acc_col, series_label in (("good_train_acc", "train_acc"), ("good_test_acc", "test_acc")):
        ax.scatter(df_good[metric], df_good[acc_col], label=series_label, alpha=0.8)
    ax.set_xlabel(metric)
    ax.set_ylabel("accuracy")
    ax.set_title(f"Train/Test Accuracy vs num_traps ({MATRIX})")
    ax.legend()
    fig.tight_layout()
    plt.show()
else:
    print("No datasets kept. Try lowering MIN_GOOD_TEST_ACC or reviewing the catalog selection.")


In [None]:
#@title Additional structural diagnostics
if df_good.empty:
    print("No results to plot.")
else:
    # One histogram per WeightWatcher metric; NaNs (e.g. fallback rows) are dropped.
    for column, short_name in (("alpha_W", "alpha"), ("traps_W", "traps"), ("ERG_gap_W", "ERG_gap")):
        fig, ax = plt.subplots()
        ax.hist(df_good[column].dropna().values, bins=30)
        ax.set_xlabel(column)
        ax.set_ylabel("count")
        ax.set_title(f"Histogram of {short_name}({MATRIX})")
        fig.tight_layout()
        plt.show()


In [None]:
#@title Save results to Google Drive
# Second-resolution timestamp so successive runs write distinct files.
ts = f"{datetime.now():%Y%m%d_%H%M%S}"
results_name = f"{MATRIX}_multisource_results_{ts}.feather"
RESULTS_FEATHER = os.path.join(GDRIVE_DIR, results_name)

df_good.to_feather(RESULTS_FEATHER)
print(f"Saved {len(df_good)} rows to: {RESULTS_FEATHER}")


In [None]:
#@title Reload latest results from Google Drive and re-plot summary
import glob

# Timestamps in the file names are zero-padded YYYYmmdd_HHMMSS, so lexical
# order is chronological and the last file is the newest run.
files = sorted(glob.glob(os.path.join(GDRIVE_DIR, f"{MATRIX}_multisource_results_*.feather")))
if not files:
    raise FileNotFoundError(f"No {MATRIX}_multisource_results_*.feather files found in {GDRIVE_DIR}")

RESULTS_FEATHER = files[-1]
print("Loading:", RESULTS_FEATHER)

df = pd.read_feather(RESULTS_FEATHER)
print("Rows:", len(df), "| Cols:", len(df.columns))
display(df.head(10))

if "good_test_acc" in df.columns:
    df = df.sort_values("good_test_acc", ascending=False)

# A run that kept no datasets saves a column-less frame, so guard the scatter
# the same way summary_cols is guarded below.
if len(df) > 0 and {"good_test_acc", "alpha_W"}.issubset(df.columns):
    plt.figure(); plt.scatter(df["good_test_acc"].values, df["alpha_W"].values)
    plt.xlabel("good_test_acc"); plt.ylabel("alpha_W")
    plt.title(f"alpha({MATRIX}) vs test accuracy")
    plt.tight_layout(); plt.show()
else:
    print("Skipping alpha-vs-accuracy plot: no rows or missing good_test_acc/alpha_W columns.")

summary_cols = [c for c in [
    "dataset", "dataset_uid", "source", "good_train_acc", "good_test_acc",
    "alpha_W", "traps_W", "ERG_gap_W", "rounds"
] if c in df.columns]
print("Top 15 by test accuracy:")
display(df[summary_cols].head(15))
