# Müntz-Szász Networks (MSN) - Demo Notebook

This notebook demonstrates the core capabilities of **Müntz-Szász Networks**, a neural network architecture with learnable fractional power bases.

**Paper**: "Müntz-Szász Networks: Neural Architectures with Learnable Power-Law Bases"  
**Author**: Gnankan Landry Regis N'guessan

## Contents
1. Installation & Setup
2. Basic MSN Usage
3. Supervised Learning: Approximating √x
4. Interpretability: Examining Learned Exponents
5. PINN Example: Singular ODE
6. Comparison with MLP

## 1. Installation & Setup

In [None]:
# Install MSN (if not already installed)
# !pip install -e ..

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Import MSN components
import sys
sys.path.insert(0, '..')

from msn import MSN, MSNTrainer
from msn.baselines import MLP, build_param_matched_mlp
from msn.utils import count_params, dump_exponents

# Setup: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Plotting style (applies globally to every figure created below)
plt.rcParams.update({
    'figure.figsize': (10, 4),
    'font.size': 12,
})

## 2. Basic MSN Usage

MSN replaces fixed activations with learnable power functions:

$$\phi(x) = \sum_k a_k |x|^{\mu_k} + \sum_k b_k \text{sgn}(x)|x|^{\lambda_k}$$

where $\mu_k, \lambda_k$ are **learned exponents**.

In [None]:
# Create a simple MSN: 1 -> 8 -> 8 -> 1, with 6 even and 6 odd exponents per edge.
basic_msn_config = dict(
    dims=[1, 8, 8, 1],        # layer widths
    Ke=6,                     # number of even exponents per edge
    Ko=6,                     # number of odd exponents per edge
    p_max_even=4.0,           # maximum (even) exponent value
    exponent_mode="bounded",  # stable exponent parameterization
)
model = MSN(**basic_msn_config)

print(f"MSN Parameters: {count_params(model):,}")
print(f"\nArchitecture: {model.dims}")
print(f"Number of layers: {len(model.layers)}")

In [None]:
# Examine the initial (random) exponents of the first layer.
exp_dict = dump_exponents(model, layer_idx=0)
for label, key in (("Initial even exponents (μ):", 'mu'),
                   ("Initial odd exponents (λ):", 'lam')):
    print(label, [f"{x:.3f}" for x in exp_dict[key]])

## 3. Supervised Learning: Approximating √x

The function $f(x) = \sqrt{x}$ has a singular derivative at $x=0$: $f'(x) = \frac{1}{2\sqrt{x}} \to \infty$ as $x \to 0^+$. This makes it hard for standard MLPs to fit. MSN can represent it directly by learning an exponent $\mu \approx 0.5$.

In [None]:
# Generate training data
def f_sqrt(x):
    """Return the target function sqrt(x) evaluated elementwise on tensor ``x``.

    A constant 1e-12 is added inside the root, so ``f_sqrt(0)`` returns 1e-6
    rather than exactly 0.
    """
    shifted = x + 1e-12
    return shifted.sqrt()

# Training set: n_train uniform samples on [0, 1) with targets from f_sqrt.
torch.manual_seed(42)
n_train = 2048
# Draw on the CPU, then move, so the seeded samples are the same on any device.
x_train = torch.rand(n_train, 1).to(device)
y_train = f_sqrt(x_train)

# Test set: 500 evenly spaced points on [0, 1], including the endpoint x = 0.
x_test = torch.linspace(0, 1, 500).unsqueeze(1).to(device)
y_test = f_sqrt(x_test)

print(f"Training samples: {n_train}")

In [None]:
# Create and train MSN
msn_model = MSN(
    dims=[1, 8, 8, 1],
    Ke=6, Ko=6,
    p_max_even=2.0,  # Restrict to small exponents for sqrt
    exponent_mode="bounded"
).to(device)

optimizer = torch.optim.Adam(msn_model.parameters(), lr=2e-3)
criterion = nn.MSELoss()

# Training loop.
# `losses` records the data-fit MSE only. The MLP trained later records plain
# MSE (it has no regularizers), so this keeps the learning-curve comparison
# like-for-like. The optimized objective also includes the Müntz and L1
# coefficient regularizers; it is recorded separately in `total_losses`.
losses = []
total_losses = []
for step in range(2000):
    optimizer.zero_grad()
    pred = msn_model(x_train)
    data_loss = criterion(pred, y_train)

    # Add Müntz regularizer (C=2.0) and L1 coefficient regularizer
    loss = (data_loss
            + 0.01 * msn_model.muntz_regularizer(C=2.0)
            + 1e-4 * msn_model.l1_coeff_regularizer())

    loss.backward()
    optimizer.step()

    losses.append(data_loss.item())
    total_losses.append(loss.item())
    if step % 500 == 0:
        print(f"Step {step}: loss = {loss.item():.6f} (data MSE = {data_loss.item():.6f})")

print(f"\nFinal loss: {total_losses[-1]:.6f} (data MSE = {losses[-1]:.6f})")

In [None]:
# Evaluate on the test grid, then plot the prediction and its pointwise error.
msn_model.eval()
with torch.no_grad():
    y_pred = msn_model(x_test)
    rmse = ((y_pred - y_test) ** 2).mean().sqrt().item()

print(f"MSN Test RMSE: {rmse:.5f}")

x_cpu = x_test.cpu()
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
ax_sol, ax_err = axes

# Left panel: true function vs. MSN prediction
ax_sol.plot(x_cpu, y_test.cpu(), 'b-', label=r'True: $\sqrt{x}$', linewidth=2)
ax_sol.plot(x_cpu, y_pred.cpu(), 'r--', label='MSN prediction', linewidth=2)
ax_sol.set(xlabel='x', ylabel='y', title=r'MSN Approximation of $f(x) = \sqrt{x}$')
ax_sol.legend()

# Right panel: absolute error; the 1e-10 floor keeps exact zeros off the log axis
error = torch.abs(y_pred - y_test).cpu()
ax_err.semilogy(x_cpu, error + 1e-10, 'r-', linewidth=2)
ax_err.set(xlabel='x', ylabel='|error|', title='Pointwise Error (log scale)')

for ax in axes:
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 4. Interpretability: Examining Learned Exponents

A key advantage of MSN is that the learned exponents reveal the structure of the solution. For $\sqrt{x}$, we expect at least one first-layer exponent to satisfy $\mu \approx 0.5$.

In [None]:
# Get learned exponents (first layer) and flag those near the target 0.5.
exp_dict = dump_exponents(msn_model, layer_idx=0)
mu_learned = exp_dict['mu']

print("Learned even exponents (μ):")
for k, mu in enumerate(mu_learned, start=1):
    near_half = abs(mu - 0.5) < 0.1
    print(f"  μ_{k} = {mu:.4f}{'  ← CLOSE TO 0.5!' if near_half else ''}")

closest = min(mu_learned, key=lambda m: abs(m - 0.5))
print("\nTarget exponent for √x: α = 0.5")
print(f"Closest learned exponent: μ = {closest:.4f}")

In [None]:
# Visualize exponent distribution: even exponents on one row, odd on another,
# with the target exponent 0.5 marked.
fig, ax = plt.subplots(figsize=(8, 4))

mu_vals = exp_dict['mu']
lam_vals = exp_dict['lam']

ax.scatter(mu_vals, [1]*len(mu_vals), s=200, c='green', marker='o', label='Even (μ)', alpha=0.7)
ax.scatter(lam_vals, [0.5]*len(lam_vals), s=200, c='orange', marker='s', label='Odd (λ)', alpha=0.7)

# Mark target
ax.axvline(x=0.5, color='red', linestyle='--', linewidth=2, label=r'Target $\alpha=0.5$')

ax.set_xlabel('Exponent value')
ax.set_yticks([0.5, 1])
ax.set_yticklabels(['Odd (λ)', 'Even (μ)'])
ax.set_title(r'Learned Exponents for $f(x) = \sqrt{x}$')
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3, axis='x')

# Default window is [0, 2.5], widened only when an exponent falls outside it.
# A fixed window would silently hide such points; for example, the odd-exponent
# bound is not set when this model is built.
all_exps = [float(v) for v in (*mu_vals, *lam_vals)]
x_lo, x_hi = 0.0, 2.5
if all_exps:
    if min(all_exps) < x_lo:
        x_lo = min(all_exps) - 0.1
    if max(all_exps) > x_hi:
        x_hi = max(all_exps) + 0.1
ax.set_xlim(x_lo, x_hi)

plt.tight_layout()
plt.show()

## 5. PINN Example: Singular ODE

Physics-Informed Neural Networks (PINNs) solve differential equations by adding the residual of the governing equation and its boundary conditions to the training loss. Consider:

$$u'(x) = \frac{1}{2\sqrt{x}}, \quad u(0) = 0$$

The solution is $u(x) = \sqrt{x}$, which has a **singular derivative at $x=0$**.

In [None]:
def pinn_loss(model, n_col=2048, *, device=None):
    """PINN loss for u'(x) = 1/(2√x), u(0) = 0.

    Args:
        model: network called on (N, 1) input tensors; its output is treated
            as u(x). The derivative is taken via ``u.sum()``, which assumes the
            model acts independently on each row (no cross-batch mixing).
        n_col: number of random collocation points for the ODE residual.
        device: device for the sampled points. If None, uses the device of the
            model's first parameter (CPU if it has none). This means the
            function no longer relies on a notebook-level ``device`` global.

    Returns:
        (pde_loss, bc_loss): scalar tensors holding the mean squared ODE
        residual and the mean squared model value at x = 0.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    # Collocation points x = U**2 with U ~ Uniform[0, 1): squaring biases
    # the samples toward the singular endpoint x = 0.
    x = torch.rand(n_col, 1, device=device) ** 2
    x.requires_grad_(True)

    u = model(x)
    # du/dx at each point; create_graph=True lets the residual be backpropagated.
    u_x = torch.autograd.grad(u.sum(), x, create_graph=True)[0]

    # ODE residual: u' - 1/(2√x) = 0. The 1e-12 shift avoids division by zero.
    # NOTE(review): rhs is unbounded as x -> 0, so single residuals near 0 can
    # be very large and make this loss noisy.
    rhs = 0.5 / torch.sqrt(x + 1e-12)
    pde_loss = torch.mean((u_x - rhs) ** 2)

    # Boundary condition: u(0) = 0
    x0 = torch.zeros(256, 1, device=device)
    bc_loss = torch.mean(model(x0) ** 2)

    return pde_loss, bc_loss

In [None]:
# Create MSN for PINN
pinn_msn = MSN(
    dims=[1, 8, 8, 1],
    Ke=6,
    Ko=6,
    p_max_even=3.0,
    exponent_mode="bounded",
).to(device)

# Use MSNTrainer for stable training (warmup, separate exponent LR, clipping)
trainer_settings = dict(
    lr=1e-3,
    lr_exp_mult=0.02,
    warmup_steps=500,
    exp_grad_clip=0.05,
    use_muntz=True,
    use_l1=True,
    device=device,
)
trainer = MSNTrainer(pinn_msn, **trainer_settings)

# Training loop: ODE residual plus a heavily weighted u(0) = 0 penalty.
bc_weight = 200
pinn_losses = []
for step in range(3000):
    pde_loss, bc_loss = pinn_loss(pinn_msn)
    metrics = trainer.step(pde_loss + bc_weight * bc_loss)
    pinn_losses.append(metrics['loss'])

    if step % 500 == 0:
        warmup_status = "" if metrics['warmup_done'] else "(warmup)"
        print(f"Step {step}: loss={metrics['loss']:.6f} {warmup_status}")

In [None]:
# Evaluate the PINN solution on the same test grid (exact solution is √x).
pinn_msn.eval()
with torch.no_grad():
    pinn_pred = pinn_msn(x_test)
    sq_err = (pinn_pred - y_test) ** 2
    pinn_rmse = sq_err.mean().sqrt().item()

print(f"PINN MSN Test RMSE: {pinn_rmse:.5f}")

# First-layer even exponent closest to the target 0.5
pinn_exp = dump_exponents(pinn_msn, layer_idx=0)
closest_mu = min(pinn_exp['mu'], key=lambda m: abs(m - 0.5))
print(f"Closest exponent to 0.5: μ = {closest_mu:.4f}")

## 6. Comparison with MLP

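Let's compare the supervised MSN from Section 3 with a standard MLP. The MLP has a matched parameter count and is trained on the same data with the same optimizer, learning rate and number of steps.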

In [None]:
# Build an MLP whose parameter count is matched to the supervised MSN.
msn_params = count_params(msn_model)
mlp_model, mlp_params, H = build_param_matched_mlp(
    Din=1,
    Dout=1,
    target_params=msn_params,
    depth=3,
)
mlp_model = mlp_model.to(device)

for line in (f"MSN params: {msn_params}", f"MLP params: {mlp_params} (H={H})"):
    print(line)

In [None]:
# Train the MLP on the same data with the same optimizer, learning rate and
# step count as the MSN (plain MSE, no regularizers).
mlp_optimizer = torch.optim.Adam(mlp_model.parameters(), lr=2e-3)

mlp_losses = []
for step in range(2000):
    mlp_optimizer.zero_grad()
    loss = criterion(mlp_model(x_train), y_train)
    loss.backward()
    mlp_optimizer.step()
    mlp_losses.append(loss.item())

# Evaluate the MLP on the test grid
mlp_model.eval()
with torch.no_grad():
    mlp_pred = mlp_model(x_test)
    mlp_rmse = ((mlp_pred - y_test) ** 2).mean().sqrt().item()

ratio = mlp_rmse / rmse
print(f"MLP Test RMSE: {mlp_rmse:.5f}")
print(f"MSN Test RMSE: {rmse:.5f}")
print(f"\nMSN improvement: {ratio:.1f}×")

In [None]:
# Final comparison plot: solutions, pointwise errors and training curves.
x_cpu = x_test.cpu()
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
ax_sol, ax_err, ax_loss = axes

# Solutions
ax_sol.plot(x_cpu, y_test.cpu(), 'b-', label=r'True $\sqrt{x}$', linewidth=2)
ax_sol.plot(x_cpu, y_pred.cpu(), 'g--', label=f'MSN (RMSE={rmse:.4f})', linewidth=2)
ax_sol.plot(x_cpu, mlp_pred.cpu(), 'r:', label=f'MLP (RMSE={mlp_rmse:.4f})', linewidth=2)
ax_sol.set(xlabel='x', ylabel='y', title='Solution Comparison')

# Error comparison (the 1e-10 floor keeps exact zeros plottable on a log axis)
msn_error = torch.abs(y_pred - y_test).cpu()
mlp_error = torch.abs(mlp_pred - y_test).cpu()
ax_err.semilogy(x_cpu, msn_error + 1e-10, 'g-', label='MSN error', linewidth=2)
ax_err.semilogy(x_cpu, mlp_error + 1e-10, 'r--', label='MLP error', linewidth=2)
ax_err.set(xlabel='x', ylabel='|error|', title='Pointwise Error (log scale)')

# Learning curves
ax_loss.semilogy(losses, 'g-', label='MSN', alpha=0.8)
ax_loss.semilogy(mlp_losses, 'r-', label='MLP', alpha=0.8)
ax_loss.set(xlabel='Step', ylabel='Loss', title='Training Loss')

# Shared decorations, applied once all curves are drawn
for ax in axes:
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Summary

This demo showed:

1. **MSN learns interpretable exponents**: For $\sqrt{x}$, MSN learns $\mu \approx 0.5$
2. **MSN outperforms MLP on singular functions**: it achieves lower error with a matched parameter budget
3. **Stable training techniques**: Warmup, two-time-scale optimization, gradient clipping
4. **PINN integration**: MSN works naturally with physics-informed learning

For more details, see the paper and full experiments in `experiments/`.