In [40]:
import json
import sys, os, math, re
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms

# Put the current directory on the import path so the project-local `models` module resolves.
sys.path.append('./')
from models import EfficientKAN, FastKAN, BSRBF_KAN, FasterKAN, MLP, FC_KAN, WCSRBFKAN, WendlandCSRBF, RadialBasisFunction, WCSRBFKANSolo
from torch.serialization import add_safe_globals
# Register the project classes with torch's allow-list for restricted (weights_only) loading.
# NOTE(review): the loader in this notebook calls torch.load(..., weights_only=False), which
# presumably does not rely on this allow-list — confirm whether the registration is still needed.
add_safe_globals([
    EfficientKAN, FastKAN, BSRBF_KAN, FasterKAN, MLP, FC_KAN, WCSRBFKAN, WCSRBFKANSolo, WendlandCSRBF, RadialBasisFunction
])

# Canonical dataset key -> label shown in the "Dataset" column of the result tables.
DISPLAY_DATASET = {"mnist": "MNIST", "fashion_mnist": "Fashion-MNIST"}
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [50]:
@torch.no_grad()
def check_csrbf_sparsity_on_dataset(model_path,
                                    dataset_name="MNIST",     # "MNIST" or "FashionMNIST"
                                    batch_size=256,
                                    max_batches=None,         # None = all test batches
                                    threshold=1e-8,           # tiny tol; W-CSRBF is truly zero outside support
                                    ):
    """
    Measure how sparse the basis-function activations of a saved model are on a test split.

    The model is loaded from ``model_path``, the test split (``train=False``) of the chosen
    dataset is pushed through ``model.layers`` one layer at a time, and for every layer that
    exposes a ``csrbf`` or ``rbf`` sub-module the basis activations ``phi`` are compared
    against ``threshold``.

    Parameters
    ----------
    model_path : path of a checkpoint that ``torch.load`` can restore as a full model object.
    dataset_name : dataset selector, case-insensitive; dashes, underscores and spaces are
        ignored, so "MNIST", "mnist-test", "FashionMNIST", "Fashion-MNIST", "fashion_mnist"
        and "fashion" are all accepted.
    batch_size : batch size of the evaluation ``DataLoader``.
    max_batches : stop after this many batches; ``None`` processes the whole test split.
    threshold : an activation counts as "active" when ``|phi| > threshold``.

    Returns dict with overall averages:
      - 'density'                               : active entries / all entries of phi
      - 'zero_fraction'                         : 1 - density
      - 'avg_active_centers_per_sample_feature' : active entries per (sample, feature) pair

    Raises
    ------
    ValueError   : unknown ``dataset_name`` (raised before the model is loaded), or a layer
                   whose basis output is not 3-dimensional.
    RuntimeError : the model exposes no ``.layers``, or no basis activations were measured.
    """
    # Validate the dataset name first so a typo fails before the checkpoint is unpickled
    # and before any dataset download starts.
    name = dataset_name.strip().lower().replace("-", "").replace("_", "").replace(" ", "")
    if name in ("mnist", "mnisttest"):
        use_fashion = False
    elif name in ("fashionmnist", "fashion"):
        use_fashion = True
    else:
        raise ValueError("dataset_name must be 'MNIST' or 'FashionMNIST'.")

    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can execute code.
    # Only load checkpoints that were produced by this project.
    model = torch.load(model_path, weights_only=False, map_location=device)
    model.eval()

    tfm = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5,), (0.5,))])
    dataset_cls = datasets.FashionMNIST if use_fashion else datasets.MNIST
    ds = dataset_cls(root="./data", train=False, download=True, transform=tfm)
    loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False)

    # The layer list may sit on the model itself or on a wrapped `.module`.
    layers = getattr(model, "layers", None)
    if layers is None and hasattr(model, "module"):
        layers = getattr(model.module, "layers", None)
    if layers is None:
        raise RuntimeError("Model must expose .layers (ModuleList of W-CSRBF layers in order).")

    total_active = 0   # entries of phi with |phi| > threshold
    total_elems = 0    # all entries of phi
    num_sf = 0         # total count of (sample, feature) pairs

    for b_idx, (xb, _) in enumerate(loader):
        if max_batches is not None and b_idx >= max_batches:
            break
        xb = xb.to(device)
        if xb.dim() == 4:               # (B,1,28,28) -> (B,784)
            xb = xb.view(xb.size(0), -1)
        x_curr = xb

        # Walk through the network, measuring phi at every layer that has a basis module.
        for layer in layers:
            # Input of this layer in its own space (after its layer norm, if it has one).
            # A layer may define the attribute but leave it as None when normalisation is off.
            z = x_curr
            layernorm = getattr(layer, "layernorm", None)
            if layernorm is not None:
                z = layernorm(z)

            phi = None
            if hasattr(layer, "csrbf"):
                # phi: (B, D*M) -> (B, D, M)
                D = int(layer.csrbf.in_features)
                M = int(layer.csrbf.n_centers)
                phi = layer.csrbf(z).view(z.size(0), D, M)
            elif hasattr(layer, "rbf"):
                # Layer with an `rbf` basis module; its output is expected to be (B, D, M).
                phi = layer.rbf(z)

            if phi is not None:
                if phi.dim() != 3:
                    raise ValueError(
                        f"Expected a 3-D basis output (B, D, M), got shape {tuple(phi.shape)}."
                    )
                mask = phi.abs() > threshold
                total_active += int(mask.sum().item())
                total_elems += int(mask.numel())
                # Integer bookkeeping: the number of active centers summed over all
                # (sample, feature) pairs is exactly total_active, so only the pair count
                # needs tracking (no float accumulation involved).
                num_sf += int(mask.size(0)) * int(mask.size(1))

            # Propagate to the next layer's input.
            x_curr = layer(x_curr)

    if total_elems == 0 or num_sf == 0:
        raise RuntimeError("No CSRBF features were found/processed. Check that your model has CSRBF layers.")

    density = total_active / float(total_elems)
    zero_fraction = 1.0 - density
    avg_active_centers = total_active / float(num_sf)

    return {
        "density": density,
        "zero_fraction": zero_fraction,
        "avg_active_centers_per_sample_feature": avg_active_centers,
    }


In [51]:
"""

This code collects the JSON files in the results folder that contain the per-epoch training metrics.
It extracts the best epoch (by validation accuracy) from each training run, however many runs there are,
displays the best and the average metrics as tables,
and returns the details of the best epochs, including the JSON file name and path of each best run.

"""


def canon_dataset(name: str) -> str:
    """Return the canonical dataset key for a label.

    The label is stripped, lower-cased and has dashes turned into underscores;
    the known Fashion-MNIST spellings collapse to "fashion_mnist".
    """
    key = name.strip().lower().replace("-", "_")
    fashion_aliases = {"fashionmnist", "fashion__mnist", "fashion_mnist"}
    return "fashion_mnist" if key in fashion_aliases else key

def find_run_idx(path: Path) -> int | None:
    """Return the nearest parent named output<digits>."""
    for p in [path] + list(path.parents):
        m = re.fullmatch(r"output(\d+)", p.name, flags=re.IGNORECASE)
        if m:
            return int(m.group(1))
    return None

def scan_files(root: Path) -> list[dict]:
    """
    Collect result files shaped like output*/<dataset>/<model>/<model>__<dataset>__full.json.

    Each hit becomes a dict with the keys run, dataset, model, path, json_name.
    Files whose name does not match the pattern, or that have no `output<digits>`
    ancestor, are skipped.
    """
    name_pattern = re.compile(r"(.+?)__([^_].+?)__full\.json", flags=re.IGNORECASE)
    found = []
    for json_path in list(root.glob("output*/**/*__*__full.json")):
        parsed = name_pattern.fullmatch(json_path.name)
        if parsed is None:
            continue
        model_key = parsed.group(1).strip().lower()
        dataset_key = canon_dataset(parsed.group(2))
        run_idx = find_run_idx(json_path)
        if run_idx is None:
            continue
        found.append({
            "run": run_idx,
            "dataset": dataset_key,
            "model": model_key,
            "path": json_path.resolve(),
            "json_name": json_path.name,
        })
    return found

def load_json_any(path: Path) -> tuple[pd.DataFrame, float]:
    """Load one training log and return (epoch rows, training time).

    The file is read as JSON Lines first (one record per line: epoch records plus an
    optional final {'training time': ...} record). If that fails with ValueError it is
    read as a regular JSON document (e.g. an array of records) instead.

    Returns
    -------
    (df, ttime)
        df    : rows that have a `val_accuracy`, sorted by it in descending order with a
                fresh 0..n-1 index; an empty DataFrame when the log has no `epoch` or no
                `val_accuracy` column.
        ttime : first numeric value of the `training time` column, or NaN when the column
                is missing or holds no numeric value.
    """
    try:
        df_all = pd.read_json(path, lines=True)
    except ValueError:
        # Not JSON Lines: fall back to a plain JSON document.
        with open(path, "r", encoding="utf-8") as f:
            df_all = pd.DataFrame(json.load(f))

    # The column is optional; DataFrame.get() would hand None to pd.to_numeric, which
    # yields a scalar without .dropna(), so test for the column explicitly.
    ttime = np.nan
    if "training time" in df_all.columns:
        times = pd.to_numeric(df_all["training time"], errors="coerce").dropna()
        if not times.empty:
            ttime = float(times.iloc[0])

    # Without both columns there is nothing to rank; hand back an empty frame
    # instead of letting dropna(subset=...) raise KeyError.
    if "epoch" not in df_all.columns or "val_accuracy" not in df_all.columns:
        return pd.DataFrame(), ttime

    df = (df_all.dropna(subset=["val_accuracy"])
                 .sort_values("val_accuracy", ascending=False)
                 .reset_index(drop=True))
    return df, ttime

def percent(x: pd.Series) -> pd.Series:
    """Convert fractions to percentages, rounded to two decimals."""
    scaled = x * 100
    return scaled.round(2)

def build_display_table(df: pd.DataFrame) -> pd.DataFrame:
    """Return a presentation copy of a metrics frame.

    Adds a display label for the dataset, the parameter count looked up in the
    global PARAMS, percentage versions of the accuracy/F1/precision/recall
    columns and a rounded training time, then keeps only the display columns
    that are present, in a fixed order. The input frame is not modified.
    """
    shown = df.copy()

    if "dataset" in shown:
        raw_names = shown["dataset"]
        # Unknown dataset keys keep their raw name.
        shown["Dataset"] = raw_names.map(DISPLAY_DATASET).fillna(raw_names)
    if "model" in shown:
        shown["#Params"] = shown["model"].map(PARAMS).astype("Int64")

    # (source column, display column) pairs that are shown as percentages.
    percent_columns = (
        ("train_acc", "Train. Acc."),
        ("val_acc", "Val. Acc."),
        ("f1", "F1"),
        ("precision", "precision"),
        ("recall", "recall"),
    )
    for source, label in percent_columns:
        if source in shown:
            shown[label] = percent(shown[source])

    if "time_sec" in shown:
        shown["Time (sec)"] = shown["time_sec"].round(1)

    display_order = [
        "Dataset", "model", "Train. Acc.", "Val. Acc.", "F1", "precision", "recall",
        "Time (sec)", "#Params", "run", "final_epoch", "json_name", "path",
    ]
    return shown[[col for col in display_order if col in shown.columns]]

def table(ROOT):
    """Summarise all training logs found under ROOT and return the best-run table.

    For every log located by scan_files, the epoch with the highest validation accuracy is
    taken as that run's result. Two tables are printed:
      * the best run per (dataset, model), ranked by validation accuracy;
      * the mean of those per-run results per (dataset, model).

    Parameters
    ----------
    ROOT : directory (Path) that contains the output*/ run folders.

    Returns
    -------
    The display table of the best run per (dataset, model), including `json_name` and `path`.

    Raises
    ------
    SystemExit : no matching log files were found, or none of them had usable
                 epoch/accuracy columns.
    """
    records = scan_files(ROOT)
    if not records:
        raise SystemExit("No matching files found (output*/**/*__*__full.json).")

    final_rows = []

    for rec in records:
        df, ttime = load_json_any(rec["path"])
        if df.empty or "val_accuracy" not in df or "train_accuracy" not in df:
            continue

        # Best epoch of this run = row with the highest validation accuracy.
        final = df.loc[df["val_accuracy"].idxmax()]
        final_rows.append({
            "run": rec["run"],
            "dataset": rec["dataset"],
            "model": rec["model"],
            "final_epoch": int(final["epoch"]),
            "val_acc": float(final["val_accuracy"]),
            "train_acc": float(final["train_accuracy"]),
            "f1": float(final.get("f1_macro", np.nan)),
            "precision": float(final.get("pre_macro", np.nan)),
            "recall": float(final.get("re_macro", np.nan)),
            "time_sec": float(ttime),
            "json_name": rec["json_name"],
            "path": str(rec["path"]),
        })

    final_df = pd.DataFrame(final_rows)
    if final_df.empty:
        raise SystemExit("No usable results (missing epoch/accuracy columns).")

    # Within each (dataset, model) group keep the run with the highest validation accuracy.
    best_final = (
        final_df
        .sort_values(
            by=["dataset","model","val_acc"],
            ascending=[True, True, False]
        )
        .groupby(["dataset","model"], as_index=False)
        .head(1)
    )

    best_table = build_display_table(best_final.sort_values(by=["val_acc"],ascending=[False]))
    print("Best metrics per dataset/model (best epoch by Val. Acc)")
    print(best_table.to_string(index=False))

    # Average the per-run results; precision and recall are included so this table
    # reports the same metrics as the best-run table.
    avg_per_dataset = (
        final_df
        .groupby(["dataset","model"], as_index=False)
        .agg(train_acc=("train_acc","mean"),
             val_acc=("val_acc","mean"),
             f1=("f1","mean"),
             precision=("precision","mean"),
             recall=("recall","mean"),
             time_sec=("time_sec","mean"))
    )
    avg_table = build_display_table(avg_per_dataset.sort_values(by=["val_acc"],ascending=[False]))
    # Report the number of runs actually found instead of assuming a fixed count.
    n_runs = int(final_df["run"].nunique())
    print(f"\nAverage metrics across runs (mean of best epoch by Val. Acc for {n_runs} independent training runs) per dataset/model")
    print(avg_table.to_string(index=False))

    return best_table

In [52]:
def get_best_model_paths_dict(result, base_dir=None):
    """Map each model name to the checkpoint path that sits next to its best-run JSON log.

    The checkpoint is assumed to share the log's directory and file name, with a `.pth`
    extension instead of `.json`. Paths are returned relative to `base_dir` in the form
    "./<relative/posix/path>.pth".

    Parameters
    ----------
    result : DataFrame with the columns "model" and "path" (path of the JSON log).
    base_dir : directory the returned paths are relative to; defaults to the current
        working directory. This replaces a fixed "drop the first 8 path components"
        rule that only worked at one particular absolute directory depth.

    Returns
    -------
    dict of model name -> relative checkpoint path. When a model occurs more than once,
    the last row wins.
    """
    base = os.path.realpath(os.fspath(base_dir) if base_dir is not None else os.getcwd())
    best_model_path = {}
    for model, json_path in zip(result["model"], result["path"]):
        pth_path = os.path.realpath(Path(json_path).with_suffix(".pth"))
        # relpath uses the OS separator; normalise to forward slashes for a stable result.
        relative = Path(os.path.relpath(pth_path, start=base)).as_posix()
        best_model_path[model] = f"./{relative}"
    return best_model_path

In [53]:
# Parameter count per model key; build_display_table reads this global to fill "#Params".
PARAMS = dict(
    bsrbf_kan=457344,
    fast_kan=457418,
    faster_kan=406528,
    efficient_kan=508160,
    wcsrbf_kan_un=457492,
    wcsrbf_kan_solo=406602,
    mlp=50816,
    fc_kan=558979,
)

# Results folder, located relative to the current working directory.
root = Path(os.getcwd()) / "output_without_ln"
btable = table(ROOT=root)

Best metrics per dataset/model (best epoch by Val. Acc)
      Dataset           model  Train. Acc.  Val. Acc.    F1  precision  recall  Time (sec)  #Params  run  final_epoch                                 json_name                                                                                                                                         path
        MNIST          fc_kan        99.45      97.66 97.62      97.62   97.62       210.2   558979    2           15                  fc_kan__mnist__full.json                                   /mnt/c/Users/aravi/OneDrive/Desktop/Thesis/output_without_ln/output2/mnist/fc_kan/fc_kan__mnist__full.json
        MNIST             mlp        97.52      97.14 97.11      97.12   97.10       119.5    50816    3           15                     mlp__mnist__full.json                                         /mnt/c/Users/aravi/OneDrive/Desktop/Thesis/output_without_ln/output3/mnist/mlp/mlp__mnist__full.json
        MNIST   wcsrbf_kan_un        97.7

In [54]:
# Select the Fashion-MNIST rows of the best-run table and derive the checkpoint paths.
# NOTE(review): despite the `_mnist` suffix these names hold Fashion-MNIST results; they are
# left unchanged because later cells read `best_model_path_mnist`.
result_mnist = btable.query("Dataset=='Fashion-MNIST'")
best_model_path_mnist = get_best_model_paths_dict(result_mnist)
print(best_model_path_mnist)

{'fc_kan': './output_without_ln/output2/fashion_mnist/fc_kan/fc_kan__fashion_mnist__full.pth', 'bsrbf_kan': './output_without_ln/output3/fashion_mnist/bsrbf_kan/bsrbf_kan__fashion_mnist__full.pth', 'fast_kan': './output_without_ln/output1/fashion_mnist/fast_kan/fast_kan__fashion_mnist__full.pth', 'mlp': './output_without_ln/output3/fashion_mnist/mlp/mlp__fashion_mnist__full.pth', 'wcsrbf_kan_un': './output_without_ln/output1/fashion_mnist/wcsrbf_kan_un/wcsrbf_kan_un__fashion_mnist__full.pth', 'faster_kan': './output_without_ln/output1/fashion_mnist/faster_kan/faster_kan__fashion_mnist__full.pth', 'wcsrbf_kan_solo': './output_without_ln/output2/fashion_mnist/wcsrbf_kan_solo/wcsrbf_kan_solo__fashion_mnist__full.pth', 'efficient_kan': './output_without_ln/output2/fashion_mnist/efficient_kan/efficient_kan__fashion_mnist__full.pth'}


In [56]:
# Activation sparsity on the Fashion-MNIST test split for fast_kan and the two base
# W-CSRBF models.
models_list = ["fast_kan", "wcsrbf_kan_un", "wcsrbf_kan_solo"]
sparsity_options = dict(
    dataset_name="Fashion-MNIST",
    batch_size=256,
    max_batches=None,
    threshold=1e-8,
)

for model_name in models_list:
    path = best_model_path_mnist[model_name]
    stats_kan_base = check_csrbf_sparsity_on_dataset(model_path=path, **sparsity_options)
    print(f"{model_name}: {stats_kan_base}")


fast_kan: {'density': 0.876907797759434, 'zero_fraction': 0.12309220224056605, 'avg_active_centers_per_sample_feature': 7.015262382075472}
wcsrbf_kan_un: {'density': 0.24408677771226414, 'zero_fraction': 0.7559132222877358, 'avg_active_centers_per_sample_feature': 1.9526942216981131}
wcsrbf_kan_solo: {'density': 0.24572326061320754, 'zero_fraction': 0.7542767393867924, 'avg_active_centers_per_sample_feature': 1.9657860849056603}


In [None]:
# Parameter counts for the W-CSRBF variants; this rebinds the global PARAMS that
# build_display_table reads, so it must run before the table() call below.
PARAMS = dict(
    wcsrbf_kan_un=457492,
    wcsrbf_kan_tc_un=464276,
    wcsrbf_kan_ts_un=464276,
    wcsrbf_kan_tc_ts_un=471060,
    wcsrbf_kan_solo=406602,
    wcsrbf_kan_solo_tc=413386,
    wcsrbf_kan_solo_ts=413386,
    wcsrbf_kan_solo_tc_ts=420170,
)

# Results folder of the W-CSRBF runs, located relative to the current working directory.
root = Path(os.getcwd()) / "wcsrbf_op"
btable_wcsrbf = table(ROOT=root)

In [58]:
# Select the Fashion-MNIST rows of the W-CSRBF best-run table and derive the checkpoint paths.
# NOTE(review): despite the `_mnist` in the names these hold Fashion-MNIST results; they are
# left unchanged because later cells read `best_model_path_mnist_wcsrbf`.
result_mnist_wcsrbf = btable_wcsrbf.query("Dataset=='Fashion-MNIST'")
best_model_path_mnist_wcsrbf = get_best_model_paths_dict(result_mnist_wcsrbf)

print(best_model_path_mnist_wcsrbf)

{'wcsrbf_kan_tc_ts_un': './wcsrbf_op/output1/fashion_mnist/wcsrbf_kan_tc_ts_un/wcsrbf_kan_tc_ts_un__fashion_mnist__full.pth', 'wcsrbf_kan_ts_un': './wcsrbf_op/output3/fashion_mnist/wcsrbf_kan_ts_un/wcsrbf_kan_ts_un__fashion_mnist__full.pth', 'wcsrbf_kan_solo_tc_ts': './wcsrbf_op/output2/fashion_mnist/wcsrbf_kan_solo_tc_ts/wcsrbf_kan_solo_tc_ts__fashion_mnist__full.pth', 'wcsrbf_kan_solo_ts': './wcsrbf_op/output2/fashion_mnist/wcsrbf_kan_solo_ts/wcsrbf_kan_solo_ts__fashion_mnist__full.pth', 'wcsrbf_kan_tc_un': './wcsrbf_op/output1/fashion_mnist/wcsrbf_kan_tc_un/wcsrbf_kan_tc_un__fashion_mnist__full.pth', 'wcsrbf_kan_solo_tc': './wcsrbf_op/output2/fashion_mnist/wcsrbf_kan_solo_tc/wcsrbf_kan_solo_tc__fashion_mnist__full.pth', 'wcsrbf_kan_un': './wcsrbf_op/output1/fashion_mnist/wcsrbf_kan_un/wcsrbf_kan_un__fashion_mnist__full.pth', 'wcsrbf_kan_solo': './wcsrbf_op/output2/fashion_mnist/wcsrbf_kan_solo/wcsrbf_kan_solo__fashion_mnist__full.pth'}


In [64]:
# Activation sparsity on the Fashion-MNIST test split for the `tc` variants.
models_list = ["wcsrbf_kan_tc_un", "wcsrbf_kan_solo_tc"]
sparsity_options = dict(
    dataset_name="Fashion-MNIST",
    batch_size=256,
    max_batches=None,
    threshold=1e-8,
)

for model_name in models_list:
    path = best_model_path_mnist_wcsrbf[model_name]
    stats_kan_base = check_csrbf_sparsity_on_dataset(model_path=path, **sparsity_options)
    print(f"{model_name}: {stats_kan_base}")


wcsrbf_kan_tc_un: {'density': 0.2503430571933962, 'zero_fraction': 0.7496569428066038, 'avg_active_centers_per_sample_feature': 2.00274445754717}
wcsrbf_kan_solo_tc: {'density': 0.24938012971698112, 'zero_fraction': 0.7506198702830189, 'avg_active_centers_per_sample_feature': 1.995041037735849}


In [65]:
# Activation sparsity on the Fashion-MNIST test split for the `ts` variants.
models_list = ["wcsrbf_kan_ts_un", "wcsrbf_kan_solo_ts"]
sparsity_options = dict(
    dataset_name="Fashion-MNIST",
    batch_size=256,
    max_batches=None,
    threshold=1e-8,
)

for model_name in models_list:
    path = best_model_path_mnist_wcsrbf[model_name]
    stats_kan_base = check_csrbf_sparsity_on_dataset(model_path=path, **sparsity_options)
    print(f"{model_name}: {stats_kan_base}")


wcsrbf_kan_ts_un: {'density': 0.9831933372641509, 'zero_fraction': 0.016806662735849076, 'avg_active_centers_per_sample_feature': 7.865546698113207}
wcsrbf_kan_solo_ts: {'density': 0.9459000884433962, 'zero_fraction': 0.054099911556603764, 'avg_active_centers_per_sample_feature': 7.56720070754717}


In [66]:
# Activation sparsity on the Fashion-MNIST test split for the combined `tc_ts` variants.
models_list = ["wcsrbf_kan_tc_ts_un", "wcsrbf_kan_solo_tc_ts"]
sparsity_options = dict(
    dataset_name="Fashion-MNIST",
    batch_size=256,
    max_batches=None,
    threshold=1e-8,
)

for model_name in models_list:
    path = best_model_path_mnist_wcsrbf[model_name]
    stats_kan_base = check_csrbf_sparsity_on_dataset(model_path=path, **sparsity_options)
    print(f"{model_name}: {stats_kan_base}")


wcsrbf_kan_tc_ts_un: {'density': 0.9602665536556604, 'zero_fraction': 0.03973344634433962, 'avg_active_centers_per_sample_feature': 7.682132429245283}
wcsrbf_kan_solo_tc_ts: {'density': 0.9084645194575471, 'zero_fraction': 0.09153548054245286, 'avg_active_centers_per_sample_feature': 7.267716155660377}
