In [None]:
import sys

import torch

sys.path.append('./FC_KAN-main/FC_KAN-main')
from models import EfficientKAN, FastKAN, BSRBF_KAN, FasterKAN, MLP, FC_KAN, WCSRBFKAN, WendlandCSRBF, RadialBasisFunction
from torch.serialization import add_safe_globals
# Allow-list the model / basis classes for torch's weights_only unpickler, so
# checkpoints that reference them can be loaded with torch.load(..., weights_only=True).
add_safe_globals(
    [EfficientKAN, FastKAN, BSRBF_KAN, FasterKAN, MLP,
     FC_KAN, WCSRBFKAN, WendlandCSRBF, RadialBasisFunction]
)

# Human-readable display names, keyed by dataset identifier.
DISPLAY_DATASET = dict(mnist="MNIST", fashion_mnist="Fashion-MNIST")
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
@torch.no_grad()
def _sparsity_stats(phi, threshold=1e-8):
    """
    Summarise how sparse a tensor of basis activations is.

    Args:
        phi: (B, D, M) tensor of basis activations -- B samples, D input
            features, M centers (basis functions) per feature.
        threshold: an entry counts as active when ``|phi| > threshold``.

    Returns:
        dict with
          - "B", "D", "M": the three dimensions of ``phi``;
          - "density": fraction of entries that are active;
          - "zero_fraction": ``1 - density``;
          - "avg_active_centers_per_sample_feature": mean number of active
            centers per (sample, feature) pair;
          - "histogram_#active_centers_0..M": list of length M + 1 whose
            i-th entry counts the (sample, feature) pairs with exactly i
            active centers.
        For an empty ``phi``, "density" and the average are reported as 0.0
        instead of raising ZeroDivisionError / producing NaN.

    Raises:
        ValueError: if ``phi`` is not 3-dimensional.
    """
    if phi.dim() != 3:
        raise ValueError(f"expected a (B, D, M) tensor, got shape {tuple(phi.shape)}")
    B, D, M = phi.shape

    active = phi.abs() > threshold
    total = active.numel()
    nnz = int(active.sum().item())
    # Guard the empty-batch case (B, D or M == 0) instead of dividing by zero.
    density = nnz / total if total else 0.0
    zero_frac = 1.0 - density

    # Active centers per (sample, feature); summing a bool tensor yields int64,
    # which is exactly what bincount needs (no float round trip).
    active_per_sf = active.sum(dim=2)  # (B, D)
    avg_active = active_per_sf.float().mean().item() if active_per_sf.numel() else 0.0

    # hist[i] = number of (sample, feature) pairs with exactly i active centers.
    hist = torch.bincount(active_per_sf.reshape(-1), minlength=M + 1)
    return {
        "B": B, "D": D, "M": M,
        "density": density,
        "zero_fraction": zero_frac,
        "avg_active_centers_per_sample_feature": avg_active,
        "histogram_#active_centers_0..M": hist.tolist(),
    }

def _print_sparsity_stats(title, threshold, stats):
    """Print the scalar entries of a `_sparsity_stats` dict, then its histogram."""
    print(title, "(threshold =", threshold, ")")
    for key, value in stats.items():
        # the histogram is a list; it gets its own line below
        if isinstance(value, (int, float)):
            print(f"{key:>40}: {value}")
    print("histogram #active centers per (sample,feature):", stats["histogram_#active_centers_0..M"])


@torch.no_grad()
def check_sparsity_simple(
    B=1024, D=3, M=6,
    x_range=(-3.0, 3.0),
    k=2,
    center_range=(-2.0, 2.0),
    s_scale=0.8,
    threshold=1e-6,  # "near-zero" cutoff for Gaussian RBF (since it's never exactly 0)
    seed=None,
):
    """
    Compare activation sparsity of a Wendland CSRBF layer and a Gaussian RBF.

    Draws ``B`` inputs with ``D`` features uniformly from ``x_range``,
    evaluates both basis layers with ``M`` centers over ``center_range``, and
    prints the ``_sparsity_stats`` summary (density, zero fraction, average
    active centers, histogram of active centers per (sample, feature)) for each.

    Args:
        B, D, M: batch size, number of input features, number of centers.
        x_range: (low, high) range of the uniform input samples.
        k: forwarded to ``WendlandCSRBF`` as ``k`` (presumably the Wendland
            smoothness order -- confirm against the models module).
        center_range: (min, max) range used for the Wendland centers and the
            Gaussian RBF grid.
        s_scale: forwarded to ``WendlandCSRBF`` as ``s_scale``.
        threshold: magnitude above which an activation counts as non-zero.
        seed: optional int. When given, inputs are drawn from a
            ``torch.Generator`` seeded with it so the run is reproducible;
            ``None`` (default) uses torch's global RNG as before.

    Returns:
        None; the statistics are printed.
    """
    lo, hi = x_range
    gen = torch.Generator().manual_seed(seed) if seed is not None else None
    x = (hi - lo) * torch.rand(B, D, generator=gen) + lo  # (B, D)

    wendland = WendlandCSRBF(
        in_features=D, n_centers=M, k=k,
        center_range=center_range,
        per_feature_centers=True,
        trainable_centers=False, trainable_sigma=False,
        init_sigma=1.0, min_sigma=1e-6, s_scale=s_scale,
    )
    # reshape (rather than view) also accepts a non-contiguous layer output
    phi_w = wendland(x).reshape(B, D, M)
    stats_w = _sparsity_stats(phi_w, threshold=threshold)

    grid_min, grid_max = center_range
    span = grid_max - grid_min
    # Width = spacing between M evenly spaced grid points, (max - min) / (M - 1)
    # (assumes RadialBasisFunction spreads its grid evenly over [grid_min, grid_max]).
    # The parentheses matter: `span / M - 1` would parse as (span / M) - 1, giving a
    # negative, far too narrow width (-1/3 instead of 0.8 for the defaults).
    # With a single center there is no spacing, so fall back to the span itself.
    denom = span / (M - 1) if M > 1 else (span or 1.0)
    rbf = RadialBasisFunction(grid_min=grid_min, grid_max=grid_max, num_grids=M, denominator=denom)
    phi_g = rbf(x)  # (B, D, M)
    stats_g = _sparsity_stats(phi_g, threshold=threshold)

    _print_sparsity_stats("WCSRBF", threshold, stats_w)
    _print_sparsity_stats("\nGaussian RBF", threshold, stats_g)

# Run the comparison; the headline configuration (equal to the defaults) is spelled out.
check_sparsity_simple(B=1024, D=3, M=6, k=2)


=== WCSRBF (true sparsity; threshold = 1e-06 ) ===
                                       B: 1024
                                       D: 3
                                       M: 6
                                 density: 0.20670572916666666
                           zero_fraction: 0.7932942708333334
   avg_active_centers_per_sample_feature: 1.240234375
histogram #active centers per (sample,feature): [389, 1556, 1127, 0, 0, 0, 0]

=== Gaussian RBF (near-zero sparsity; threshold = 1e-06 ) ===
                                       B: 1024
                                       D: 3
                                       M: 6
                                 density: 0.40087890625
                           zero_fraction: 0.59912109375
   avg_active_centers_per_sample_feature: 2.4052734375
histogram #active centers per (sample,feature): [0, 579, 800, 1562, 131, 0, 0]
