In [11]:
import torch
import neural_tangents as nt
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Set, Tuple
import torch.nn as nn
import torch.optim as optim
import os
from scipy.linalg import eigh
from functools import partial

class MSPFunction:
    """Sum of monomials over selected columns of the input.

    Every entry of ``sets`` is a set of column indices; the columns it names
    are multiplied together row by row, and the resulting monomials are
    added up by :meth:`evaluate`.
    """

    def __init__(self, P: int, sets: List[Set[int]]):
        # P is stored on the instance; evaluate() does not consult it.
        self.sets = sets
        self.P = P

    def _monomial(self, z: torch.Tensor, subset: Set[int]) -> torch.Tensor:
        """Row-wise product of the columns of ``z`` listed in ``subset``."""
        product = torch.ones(z.shape[0], dtype=torch.float32, device=z.device)
        for column in subset:
            product = product * z[:, column]
        return product

    def evaluate(self, z: torch.Tensor) -> torch.Tensor:
        """Evaluate the function on every row of ``z``.

        Args:
            z: 2-D tensor indexed as ``z[:, idx]``; one sample per row.

        Returns:
            1-D tensor with one value per row, allocated on ``z``'s device.
            An empty index set contributes a constant 1 to every row.
        """
        total = torch.zeros(z.shape[0], dtype=torch.float32, device=z.device)
        for subset in self.sets:
            total = total + self._monomial(z, subset)
        return total

class DeepNN(nn.Module):
    """Fully connected ReLU network with a single-output linear readout.

    The network is ``depth`` blocks of ``Linear -> ReLU`` of width
    ``hidden_size`` followed by ``Linear(hidden_size, 1)`` (or
    ``Linear(d, 1)`` when ``depth`` is 0).

    With ``mode == 'mup_pennington'`` every weight matrix is drawn from a
    normal distribution with std ``1 / sqrt(fan_in)``, biases are zeroed and
    a learning-rate scale per linear layer is recorded in ``layer_lrs``
    (``1.0`` for the first layer, ``1 / fan_in`` for every later layer,
    including the readout). The scales are only stored here; this class does
    not apply them. With any other mode the layers keep ``nn.Linear``'s
    default initialisation, ``layer_lrs`` stays empty, and only the readout
    bias is zeroed.
    """

    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'mup_pennington'):
        super().__init__()

        # NOTE: this sets the process-wide default dtype, not just this module's.
        torch.set_default_dtype(torch.float32)

        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d

        use_mup = mode == 'mup_pennington'
        layers = []
        prev_dim = d
        self.layer_lrs = []  # one learning-rate scale per linear layer (mup mode only)

        for layer_idx in range(depth):
            linear = nn.Linear(prev_dim, hidden_size)

            if use_mup:
                nn.init.normal_(linear.weight, mean=0.0, std=1.0 / np.sqrt(prev_dim))
                nn.init.zeros_(linear.bias)
                # O(1) scale for the embedding layer, O(1/n) for hidden layers.
                self.layer_lrs.append(1.0 if layer_idx == 0 else 1.0 / prev_dim)

            layers.extend([linear, nn.ReLU()])
            prev_dim = hidden_size

        # Readout layer.
        final_layer = nn.Linear(prev_dim, 1)
        if use_mup:
            nn.init.normal_(final_layer.weight, std=1.0 / np.sqrt(prev_dim))
            self.layer_lrs.append(1.0 / prev_dim)  # O(1/n) scale for the readout

        nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network and drop the trailing size-1 output dimension.

        Only the last dimension is squeezed, so a batch containing a single
        sample yields shape ``(1,)`` rather than a 0-dim tensor (a bare
        ``squeeze()`` would also remove the batch dimension in that case).
        """
        return self.network(x).squeeze(-1)

def generate_datasets(P: int, d: int, n_train: int, n_test: int, msp: MSPFunction, seed: int = 42):
    """Draw random +/-1 inputs and label them with ``msp.evaluate``.

    The global torch RNG is re-seeded with ``seed``; the training inputs are
    drawn first and the test inputs second. Inputs are sampled on the CPU and
    then moved to CUDA when it is available. ``P`` is accepted but not used.

    Returns:
        ``(X_train, y_train, X_test, y_test)`` where the X tensors have shape
        ``(n_train, d)`` / ``(n_test, d)``.
    """
    torch.manual_seed(seed)
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def sample_split(n_rows: int):
        # Bernoulli(0.5) in {0, 1} mapped to {-1, +1}.
        fair_coin = torch.full((n_rows, d), 0.5, dtype=torch.float32)
        inputs = (2 * torch.bernoulli(fair_coin) - 1).to(target)
        return inputs, msp.evaluate(inputs)

    X_train, y_train = sample_split(n_train)
    X_test, y_test = sample_split(n_test)

    return X_train, y_train, X_test, y_test

def convert_data_to_jax(data: torch.Tensor) -> jnp.ndarray:
    """Return a JAX array holding the values of ``data``.

    The tensor is detached from autograd and moved to host memory before it
    is handed to JAX through NumPy.
    """
    host_tensor = data.detach().cpu()
    return jnp.array(host_tensor.numpy())

def compute_kernels(X_train: jnp.ndarray, X_test: jnp.ndarray, params: List[jnp.ndarray], 
                   architecture: nn.Module) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute NNGP and NTK kernels for the train set and for test-vs-train.

    A neural_tangents ``stax`` network is assembled from
    ``architecture.depth`` blocks of ``Dense(architecture.hidden_size) -> Relu``
    followed by ``Dense(1)``, and its ``kernel_fn`` is evaluated.

    Args:
        X_train: training inputs, one sample per row.
        X_test: test inputs, one sample per row.
        params: unused; kept so existing callers keep working.
        architecture: object exposing ``hidden_size`` and ``depth``.

    Returns:
        ``(nngp_train, ntk_train, nngp_test, ntk_test)`` where the train
        kernels compare ``X_train`` with itself and the test kernels compare
        ``X_test`` (rows) with ``X_train`` (columns).
    """
    width = architecture.hidden_size

    # One Dense -> Relu block per hidden layer. Looping over ``depth`` itself
    # gives the same stack as before for depth >= 1 and, for depth == 0,
    # a bare Dense(1) readout, matching what DeepNN builds in that case.
    layers = []
    for _ in range(architecture.depth):
        layers.extend([nt.stax.Dense(width), nt.stax.Relu()])
    layers.append(nt.stax.Dense(1))

    _, _, kernel_fn = nt.stax.serial(*layers)

    # Request both kernels in one call per input pair instead of evaluating
    # kernel_fn separately for 'ntk' and 'nngp'.
    wanted = ('nngp', 'ntk')
    train_kernels = kernel_fn(X_train, X_train, wanted)
    test_kernels = kernel_fn(X_test, X_train, wanted)

    return train_kernels.nngp, train_kernels.ntk, test_kernels.nngp, test_kernels.ntk

def compute_spectrum(kernel: jnp.ndarray, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the ``k`` largest eigenvalues of ``kernel`` and their eigenvectors.

    The full symmetric eigendecomposition is computed with SciPy on a NumPy
    copy of the kernel, then the leading ``k`` pairs are selected.

    Returns:
        ``(eigenvalues, eigenvectors)`` as NumPy arrays: eigenvalues in
        descending order, eigenvectors as the matching columns.
    """
    dense = np.array(kernel)
    spectrum, basis = eigh(dense)

    # Positions of the eigenvalues from largest to smallest, cut to k entries.
    leading = np.argsort(spectrum)[::-1][:k]
    return spectrum[leading], basis[:, leading]

def plot_spectrum(eigenvals_nngp: jnp.ndarray, eigenvals_ntk: jnp.ndarray, 
                 save_path: str, title: str):
    """Save a log-scale line plot of the NNGP (blue) and NTK (red) eigenvalues.

    Each spectrum is plotted against its index; the figure is written to
    ``save_path`` and then closed.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    curves = (
        (eigenvals_nngp, 'b-', 'NNGP'),
        (eigenvals_ntk, 'r-', 'NTK'),
    )
    for spectrum, line_style, legend_name in curves:
        ax.semilogy(np.arange(len(spectrum)), spectrum, line_style, label=legend_name)

    ax.set_xlabel('Index')
    ax.set_ylabel('Eigenvalue (log scale)')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

    fig.savefig(save_path)
    plt.close(fig)

def plot_eigenfunctions(eigenvecs: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray,
                       save_path: str, title: str, k: int):
    """Save a figure showing the first ``k`` columns of ``eigenvecs`` as heatmaps.

    Values are rounded to the nearest integer and laid out row by row on a
    square grid whose side is ``ceil(sqrt(len(X)))``; unused cells at the end
    of the grid are filled with 0. Each panel uses a discrete 'viridis'
    colormap with one colour per integer between the column's min and max.

    Args:
        eigenvecs: 2-D array with one column per eigenfunction and one row
            per point.
        X: only its length is used, to size the grid.
        y: unused.
        save_path: where the figure is written.
        title: figure-level title.
        k: number of columns to plot (three panels per row).
    """
    # Round to integers so every panel can use a discrete colormap.
    values = np.round(np.array(eigenvecs)).astype(int)
    n_points = len(X)

    grid_size = int(np.ceil(np.sqrt(n_points)))

    n_rows = (k + 2) // 3
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5*n_rows))
    axes = axes.ravel()

    for i in range(k):
        column = values[:, i]

        # Row-major fill of the square grid; trailing cells stay 0 as padding.
        padded = np.zeros(grid_size * grid_size, dtype=int)
        padded[:column.size] = column
        grid_values = padded.reshape(grid_size, grid_size)

        vmin, vmax = int(column.min()), int(column.max())

        # One colour per integer value. plt.get_cmap replaces the deprecated
        # plt.cm.get_cmap, which newer Matplotlib releases no longer provide.
        num_colors = vmax - vmin + 1
        cmap = plt.get_cmap('viridis', num_colors)

        # Half-integer limits centre each colour band on an integer.
        im = axes[i].imshow(grid_values,
                            cmap=cmap,
                            vmin=vmin - 0.5,
                            vmax=vmax + 0.5)

        colorbar = plt.colorbar(im, ax=axes[i], ticks=np.arange(vmin, vmax + 1))
        colorbar.set_label('Rounded Value')

        axes[i].set_title(f'Eigenfunction {i+1}')
        axes[i].axis('off')  # Hide axes

    # Remove the panels left over when k is not a multiple of three.
    for i in range(k, len(axes)):
        fig.delaxes(axes[i])

    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def compute_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray, train_kernel: jnp.ndarray, y_train: jnp.ndarray):
    """
    Project eigenvectors through a kernel and weight them by the labels.

    Column ``i`` of the result is ``kernel @ eigenvecs[:, i]`` multiplied by
    the scalar ``eigenvecs[:, i] . y_train``.

    Parameters:
    kernel: kernel between the evaluation points (rows) and the training points (columns)
    eigenvecs: eigenvectors of the training kernel, one per column
    train_kernel: not used by the computation
    y_train: training labels; flattened into a column vector internally
    """
    labels_column = y_train.reshape(-1, 1)

    # One label coefficient per eigenvector, laid out as a row for broadcasting.
    weights_row = jnp.matmul(eigenvecs.T, labels_column).T

    projected = jnp.matmul(kernel, eigenvecs)
    return projected * weights_row

def analyze_model(model_path: str,
                 X_train: torch.Tensor, y_train: torch.Tensor,
                 X_test: torch.Tensor, y_test: torch.Tensor,
                 d: int, hidden_size: int, depth: int, mode: str,
                 k: int, output_dir: str):
    """Compute NNGP/NTK kernels for the data, print their top-k content and save plots.

    ``model_path`` is accepted but never read: no checkpoint is loaded. A
    fresh ``DeepNN`` is constructed only so ``compute_kernels`` can read its
    width and depth.

    Printed: the first 100 training labels, the first 100 values of each of
    the ``k`` NNGP eigenfunction columns on the training set, and the NNGP
    eigenvalues. Written to ``output_dir`` (created if missing):
    ``kernel_spectra.png``, ``nngp_eigenfunctions_train.png`` and
    ``nngp_eigenfunctions_test.png``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Move everything to JAX once, in the order train X, train y, test X, test y.
    train_x, train_y, test_x, test_y = (
        convert_data_to_jax(tensor) for tensor in (X_train, y_train, X_test, y_test)
    )

    print("\nActual function values (first 100 training points):")
    print(train_y[:100])

    network_spec = DeepNN(d, hidden_size, depth, mode)

    # The test-train NTK is computed by compute_kernels but not needed here.
    gram_nngp, gram_ntk, cross_nngp, _ = compute_kernels(train_x, test_x, None, network_spec)

    nngp_vals, nngp_vecs = compute_spectrum(gram_nngp, k)
    ntk_vals, _ = compute_spectrum(gram_ntk, k)

    # Same training eigenvectors and labels for both splits; only the kernel
    # used for the projection differs (train-train vs. test-train).
    train_funcs, test_funcs = (
        compute_eigenfunction_values(projection_kernel, nngp_vecs, gram_nngp, train_y)
        for projection_kernel in (gram_nngp, cross_nngp)
    )

    print("\nFirst 100 values of NNGP eigenfunctions in function space (training set):")
    head = train_funcs[:100]
    for i in range(k):
        print(f"\nEigenfunction {i+1}:")
        print(head[:, i])

    print("\nNNGP kernel eigenvalues:")
    print(nngp_vals)

    plot_spectrum(nngp_vals, ntk_vals,
                 os.path.join(output_dir, 'kernel_spectra.png'),
                 'Kernel Spectra')

    eigenfunction_plots = (
        (train_funcs, train_x, train_y,
         'nngp_eigenfunctions_train.png', 'NNGP Eigenfunctions (Training Set)'),
        (test_funcs, test_x, test_y,
         'nngp_eigenfunctions_test.png', 'NNGP Eigenfunctions (Test Set)'),
    )
    for funcs, inputs, labels, filename, heading in eigenfunction_plots:
        plot_eigenfunctions(funcs, inputs, labels,
                           os.path.join(output_dir, filename),
                           heading, k)

    # NOTE: NTK eigenfunction plots are intentionally not produced here; only
    # the NTK eigenvalues are used (in the spectrum plot).

# Script entry point: build a synthetic +/-1 dataset labelled by an MSP
# function and run the kernel analysis on it.
if __name__ == "__main__":
    # Model parameters
    P = 8  # passed to MSPFunction and generate_datasets; neither reads it in this cell
    d = 30  # number of input columns
    hidden_size = 500  # width of every hidden layer
    depth = 1  # number of hidden Linear -> ReLU blocks
    mode = 'mup_pennington'
    k = 10  # number of eigenvalues/eigenfunctions to analyze
    
    # Dataset parameters
    n_train = 2000
    n_test = 1000
    
    # Define MSP sets: each set lists the input columns multiplied into one monomial
    msp_sets = [{7}, {2,7}, {0,2,7}, {5,7,4}, {1}, {0,4}, {3,7}, {0,1,2,3,4,6,7}]
    
    # Initialize MSP function
    msp = MSPFunction(P, msp_sets)
    
    # Generate datasets (uses generate_datasets' default seed)
    X_train, y_train, X_test, y_test = generate_datasets(P, d, n_train, n_test, msp)
    
    # Model path and output directory
    # NOTE: analyze_model as defined in this cell accepts model_path but never reads it.
    # NOTE(review): the file name mentions n20000 while n_train above is 2000 -- confirm this is intended.
    model_path = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1612_nogrokk_mup_pennington/final_model_h500_d1_n20000_lr0.05_mup_pennington_20241219_163433_rank8.pt"
    output_dir = "ntk_analysis_results"
    
    # Run analysis
    analyze_model(model_path, X_train, y_train, X_test, y_test,
                 d, hidden_size, depth, mode, k, output_dir)


Actual function values (first 100 training points):
[ 0.  0.  4.  2.  2.  2.  2. -2. -2.  2.  0.  0.  2.  4.  4. -4.  0.  2.
 -2. -2.  4.  2. -6. -2. -2. -2. -2.  2.  4.  0.  0. -2.  4. -6.  4.  6.
 -6.  2.  2. -2.  6.  2. -2.  4.  0.  4.  2. -2. -2.  0.  0.  0. -4. -4.
 -2. -2.  2. -2.  0.  2.  2.  0.  0.  0.  4. -2. -2.  4.  0. -6.  0. -2.
 -4. -4. -6.  4.  0.  2.  4.  0.  0.  2.  0.  0. -2.  2.  2. -4.  0.  4.
 -2.  2. -2. -4.  0.  2.  4.  0. -4.  2.]



First 100 values of NNGP eigenfunctions in function space (training set):

Eigenfunction 1:
[-43.45184  -43.436447 -42.937305 -43.167183 -43.184265 -43.38768
 -42.91761  -43.677357 -43.573334 -43.254055 -43.280552 -43.43361
 -43.721256 -42.852036 -42.924114 -43.394966 -43.54008  -42.668392
 -43.68885  -43.287613 -42.946808 -43.409298 -43.549225 -43.40602
 -43.27262  -43.50223  -42.58378  -43.61826  -43.6876   -43.302174
 -42.582455 -43.636097 -43.49086  -43.114014 -43.341015 -43.173027
 -42.639687 -43.076183 -42.890278 -43.48782  -43.40509  -43.136242
 -43.263718 -42.84679  -42.995335 -42.729614 -42.942257 -42.65961
 -43.29551  -43.40476  -43.033657 -43.187244 -42.4333   -42.7052
 -43.07616  -42.8367   -43.055683 -43.522293 -42.873306 -43.363457
 -42.976707 -43.651794 -43.452934 -43.801987 -43.44698  -43.4505
 -42.839626 -43.143463 -43.606834 -43.372654 -43.55606  -43.56627
 -42.94402  -43.57842  -44.081726 -43.0494   -43.23356  -42.662544
 -43.283566 -43.192894 -43.35858  -43.504612 

  cmap = plt.cm.get_cmap('viridis', num_colors)


In [12]:
import torch
import neural_tangents as nt
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Set, Tuple
import torch.nn as nn
import torch.optim as optim
import os
from scipy.linalg import eigh
from functools import partial

class MSPFunction:
    """Sum of monomials over selected columns of the input.

    Every entry of ``sets`` is a set of column indices; the columns it names
    are multiplied together row by row, and the resulting monomials are
    added up by :meth:`evaluate`.
    """

    def __init__(self, P: int, sets: List[Set[int]]):
        # P is stored on the instance; evaluate() does not consult it.
        self.sets = sets
        self.P = P

    def _monomial(self, z: torch.Tensor, subset: Set[int]) -> torch.Tensor:
        """Row-wise product of the columns of ``z`` listed in ``subset``."""
        product = torch.ones(z.shape[0], dtype=torch.float32, device=z.device)
        for column in subset:
            product = product * z[:, column]
        return product

    def evaluate(self, z: torch.Tensor) -> torch.Tensor:
        """Evaluate the function on every row of ``z``.

        Args:
            z: 2-D tensor indexed as ``z[:, idx]``; one sample per row.

        Returns:
            1-D tensor with one value per row, allocated on ``z``'s device.
            An empty index set contributes a constant 1 to every row.
        """
        total = torch.zeros(z.shape[0], dtype=torch.float32, device=z.device)
        for subset in self.sets:
            total = total + self._monomial(z, subset)
        return total

class DeepNN(nn.Module):
    """Fully connected ReLU network with a single-output linear readout.

    The network is ``depth`` blocks of ``Linear -> ReLU`` of width
    ``hidden_size`` followed by ``Linear(hidden_size, 1)`` (or
    ``Linear(d, 1)`` when ``depth`` is 0).

    With ``mode == 'mup_pennington'`` every weight matrix is drawn from a
    normal distribution with std ``1 / sqrt(fan_in)``, biases are zeroed and
    a learning-rate scale per linear layer is recorded in ``layer_lrs``
    (``1.0`` for the first layer, ``1 / fan_in`` for every later layer,
    including the readout). The scales are only stored here; this class does
    not apply them. With any other mode the layers keep ``nn.Linear``'s
    default initialisation, ``layer_lrs`` stays empty, and only the readout
    bias is zeroed.
    """

    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'mup_pennington'):
        super().__init__()

        # NOTE: this sets the process-wide default dtype, not just this module's.
        torch.set_default_dtype(torch.float32)

        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d

        use_mup = mode == 'mup_pennington'
        layers = []
        prev_dim = d
        self.layer_lrs = []  # one learning-rate scale per linear layer (mup mode only)

        for layer_idx in range(depth):
            linear = nn.Linear(prev_dim, hidden_size)

            if use_mup:
                nn.init.normal_(linear.weight, mean=0.0, std=1.0 / np.sqrt(prev_dim))
                nn.init.zeros_(linear.bias)
                # O(1) scale for the embedding layer, O(1/n) for hidden layers.
                self.layer_lrs.append(1.0 if layer_idx == 0 else 1.0 / prev_dim)

            layers.extend([linear, nn.ReLU()])
            prev_dim = hidden_size

        # Readout layer.
        final_layer = nn.Linear(prev_dim, 1)
        if use_mup:
            nn.init.normal_(final_layer.weight, std=1.0 / np.sqrt(prev_dim))
            self.layer_lrs.append(1.0 / prev_dim)  # O(1/n) scale for the readout

        nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network and drop the trailing size-1 output dimension.

        Only the last dimension is squeezed, so a batch containing a single
        sample yields shape ``(1,)`` rather than a 0-dim tensor (a bare
        ``squeeze()`` would also remove the batch dimension in that case).
        """
        return self.network(x).squeeze(-1)

def generate_datasets(P: int, d: int, n_train: int, n_test: int, msp: MSPFunction, seed: int = 42):
    """Draw random +/-1 inputs and label them with ``msp.evaluate``.

    The global torch RNG is re-seeded with ``seed``; the training inputs are
    drawn first and the test inputs second. Inputs are sampled on the CPU and
    then moved to CUDA when it is available. ``P`` is accepted but not used.

    Returns:
        ``(X_train, y_train, X_test, y_test)`` where the X tensors have shape
        ``(n_train, d)`` / ``(n_test, d)``.
    """
    torch.manual_seed(seed)
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def sample_split(n_rows: int):
        # Bernoulli(0.5) in {0, 1} mapped to {-1, +1}.
        fair_coin = torch.full((n_rows, d), 0.5, dtype=torch.float32)
        inputs = (2 * torch.bernoulli(fair_coin) - 1).to(target)
        return inputs, msp.evaluate(inputs)

    X_train, y_train = sample_split(n_train)
    X_test, y_test = sample_split(n_test)

    return X_train, y_train, X_test, y_test

def convert_data_to_jax(data: torch.Tensor) -> jnp.ndarray:
    """Return a JAX array holding the values of ``data``.

    The tensor is detached from autograd and moved to host memory before it
    is handed to JAX through NumPy.
    """
    host_tensor = data.detach().cpu()
    return jnp.array(host_tensor.numpy())

def compute_kernels(X_train: jnp.ndarray, X_test: jnp.ndarray, params: List[jnp.ndarray], 
                   architecture: nn.Module) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute NNGP and NTK kernels for the train set and for test-vs-train.

    A neural_tangents ``stax`` network is assembled from
    ``architecture.depth`` blocks of ``Dense(architecture.hidden_size) -> Relu``
    followed by ``Dense(1)``, and its ``kernel_fn`` is evaluated.

    Args:
        X_train: training inputs, one sample per row.
        X_test: test inputs, one sample per row.
        params: unused; kept so existing callers keep working.
        architecture: object exposing ``hidden_size`` and ``depth``.

    Returns:
        ``(nngp_train, ntk_train, nngp_test, ntk_test)`` where the train
        kernels compare ``X_train`` with itself and the test kernels compare
        ``X_test`` (rows) with ``X_train`` (columns).
    """
    width = architecture.hidden_size

    # One Dense -> Relu block per hidden layer. Looping over ``depth`` itself
    # gives the same stack as before for depth >= 1 and, for depth == 0,
    # a bare Dense(1) readout, matching what DeepNN builds in that case.
    layers = []
    for _ in range(architecture.depth):
        layers.extend([nt.stax.Dense(width), nt.stax.Relu()])
    layers.append(nt.stax.Dense(1))

    _, _, kernel_fn = nt.stax.serial(*layers)

    # Request both kernels in one call per input pair instead of evaluating
    # kernel_fn separately for 'ntk' and 'nngp'.
    wanted = ('nngp', 'ntk')
    train_kernels = kernel_fn(X_train, X_train, wanted)
    test_kernels = kernel_fn(X_test, X_train, wanted)

    return train_kernels.nngp, train_kernels.ntk, test_kernels.nngp, test_kernels.ntk

def compute_spectrum(kernel: jnp.ndarray, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the ``k`` largest eigenvalues of ``kernel`` and their eigenvectors.

    The full symmetric eigendecomposition is computed with SciPy on a NumPy
    copy of the kernel, then the leading ``k`` pairs are selected.

    Returns:
        ``(eigenvalues, eigenvectors)`` as NumPy arrays: eigenvalues in
        descending order, eigenvectors as the matching columns.
    """
    dense = np.array(kernel)
    spectrum, basis = eigh(dense)

    # Positions of the eigenvalues from largest to smallest, cut to k entries.
    leading = np.argsort(spectrum)[::-1][:k]
    return spectrum[leading], basis[:, leading]

def plot_spectrum(eigenvals_nngp: jnp.ndarray, eigenvals_ntk: jnp.ndarray, 
                 save_path: str, title: str):
    """Save a log-scale line plot of the NNGP (blue) and NTK (red) eigenvalues.

    Each spectrum is plotted against its index; the figure is written to
    ``save_path`` and then closed.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    curves = (
        (eigenvals_nngp, 'b-', 'NNGP'),
        (eigenvals_ntk, 'r-', 'NTK'),
    )
    for spectrum, line_style, legend_name in curves:
        ax.semilogy(np.arange(len(spectrum)), spectrum, line_style, label=legend_name)

    ax.set_xlabel('Index')
    ax.set_ylabel('Eigenvalue (log scale)')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

    fig.savefig(save_path)
    plt.close(fig)

def plot_eigenfunctions(eigenvecs: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray,
                       save_path: str, title: str, k: int):
    """Save a figure showing the first ``k`` columns of ``eigenvecs`` as heatmaps.

    Values are rounded to the nearest integer and laid out row by row on a
    square grid whose side is ``ceil(sqrt(len(X)))``; unused cells at the end
    of the grid are filled with 0. Each panel uses a discrete 'viridis'
    colormap with one colour per integer between the column's min and max.

    Args:
        eigenvecs: 2-D array with one column per eigenfunction and one row
            per point.
        X: only its length is used, to size the grid.
        y: unused.
        save_path: where the figure is written.
        title: figure-level title.
        k: number of columns to plot (three panels per row).
    """
    # Round to integers so every panel can use a discrete colormap.
    values = np.round(np.array(eigenvecs)).astype(int)
    n_points = len(X)

    grid_size = int(np.ceil(np.sqrt(n_points)))

    n_rows = (k + 2) // 3
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5*n_rows))
    axes = axes.ravel()

    for i in range(k):
        column = values[:, i]

        # Row-major fill of the square grid; trailing cells stay 0 as padding.
        padded = np.zeros(grid_size * grid_size, dtype=int)
        padded[:column.size] = column
        grid_values = padded.reshape(grid_size, grid_size)

        vmin, vmax = int(column.min()), int(column.max())

        # One colour per integer value. plt.get_cmap replaces the deprecated
        # plt.cm.get_cmap, which newer Matplotlib releases no longer provide.
        num_colors = vmax - vmin + 1
        cmap = plt.get_cmap('viridis', num_colors)

        # Half-integer limits centre each colour band on an integer.
        im = axes[i].imshow(grid_values,
                            cmap=cmap,
                            vmin=vmin - 0.5,
                            vmax=vmax + 0.5)

        colorbar = plt.colorbar(im, ax=axes[i], ticks=np.arange(vmin, vmax + 1))
        colorbar.set_label('Rounded Value')

        axes[i].set_title(f'Eigenfunction {i+1}')
        axes[i].axis('off')  # Hide axes

    # Remove the panels left over when k is not a multiple of three.
    for i in range(k, len(axes)):
        fig.delaxes(axes[i])

    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def compute_nngp_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray, train_kernel: jnp.ndarray, y_train: jnp.ndarray):
    """Project eigenvectors through ``kernel`` and weight them by the labels.

    Column ``i`` of the result is ``kernel @ eigenvecs[:, i]`` multiplied by
    the scalar ``eigenvecs[:, i] . y_train``. ``train_kernel`` is accepted
    but not used by the computation.
    """
    labels_column = y_train.reshape(-1, 1)

    # One label coefficient per eigenvector, laid out as a row for broadcasting.
    weights_row = jnp.matmul(eigenvecs.T, labels_column).T

    projected = jnp.matmul(kernel, eigenvecs)
    return projected * weights_row

def compute_ntk_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray):
    """Return ``kernel @ eigenvecs``: the eigenvectors projected through the kernel.

    Unlike the NNGP variant, no label-dependent scaling is applied.
    """
    projected = jnp.matmul(kernel, eigenvecs)
    return projected

def analyze_model(model_path: str,
                 X_train: torch.Tensor, y_train: torch.Tensor,
                 X_test: torch.Tensor, y_test: torch.Tensor,
                 d: int, hidden_size: int, depth: int, mode: str,
                 k: int, output_dir: str):
    """Compute NNGP/NTK kernels for the data, print their top-k content and save plots.

    ``model_path`` is accepted but never read: no checkpoint is loaded. A
    fresh ``DeepNN`` is constructed only so ``compute_kernels`` can read its
    width and depth.

    Printed: the first 100 training labels, the first 100 training-set values
    of each of the ``k`` NNGP and NTK eigenfunction columns, and both sets of
    eigenvalues. Written to ``output_dir`` (created if missing):
    ``kernel_spectra.png`` plus train/test eigenfunction figures for NNGP and NTK.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Move everything to JAX once, in the order train X, train y, test X, test y.
    train_x, train_y, test_x, test_y = (
        convert_data_to_jax(tensor) for tensor in (X_train, y_train, X_test, y_test)
    )

    print("\nActual function values (first 100 training points):")
    print(train_y[:100])

    network_spec = DeepNN(d, hidden_size, depth, mode)
    gram_nngp, gram_ntk, cross_nngp, cross_ntk = compute_kernels(train_x, test_x, None, network_spec)

    nngp_vals, nngp_vecs = compute_spectrum(gram_nngp, k)
    ntk_vals, ntk_vecs = compute_spectrum(gram_ntk, k)

    # NNGP: projection through the kernel, weighted by the training labels.
    nngp_train_func, nngp_test_func = (
        compute_nngp_eigenfunction_values(projection_kernel, nngp_vecs, gram_nngp, train_y)
        for projection_kernel in (gram_nngp, cross_nngp)
    )

    # NTK: plain projection through the kernel.
    ntk_train_func, ntk_test_func = (
        compute_ntk_eigenfunction_values(projection_kernel, ntk_vecs)
        for projection_kernel in (gram_ntk, cross_ntk)
    )

    print("\nFirst 100 values of NNGP eigenfunctions in function space (training set):")
    nngp_head = nngp_train_func[:100]
    for i in range(k):
        print(f"\nNNGP Eigenfunction {i+1}:")
        print(nngp_head[:, i])

    print("\nFirst 100 values of NTK eigenfunctions in function space (training set):")
    ntk_head = ntk_train_func[:100]
    for i in range(k):
        print(f"\nNTK Eigenfunction {i+1}:")
        print(ntk_head[:, i])

    print("\nNNGP kernel eigenvalues:")
    print(nngp_vals)
    print("\nNTK kernel eigenvalues:")
    print(ntk_vals)

    plot_spectrum(nngp_vals, ntk_vals,
                 os.path.join(output_dir, 'kernel_spectra.png'),
                 'Kernel Spectra')

    # NNGP figures first, then NTK; train before test within each.
    eigenfunction_plots = (
        (nngp_train_func, train_x, train_y,
         'nngp_eigenfunctions_train.png', 'NNGP Eigenfunctions (Training Set)'),
        (nngp_test_func, test_x, test_y,
         'nngp_eigenfunctions_test.png', 'NNGP Eigenfunctions (Test Set)'),
        (ntk_train_func, train_x, train_y,
         'ntk_eigenfunctions_train.png', 'NTK Eigenfunctions (Training Set)'),
        (ntk_test_func, test_x, test_y,
         'ntk_eigenfunctions_test.png', 'NTK Eigenfunctions (Test Set)'),
    )
    for funcs, inputs, labels, filename, heading in eigenfunction_plots:
        plot_eigenfunctions(funcs, inputs, labels,
                           os.path.join(output_dir, filename),
                           heading, k)

# Script entry point: build a synthetic +/-1 dataset labelled by an MSP
# function and run the NNGP/NTK kernel analysis on it.
if __name__ == "__main__":
    # Model parameters
    P = 8  # passed to MSPFunction and generate_datasets; neither reads it in this cell
    d = 30  # number of input columns
    hidden_size = 500  # width of every hidden layer
    depth = 1  # number of hidden Linear -> ReLU blocks
    mode = 'mup_pennington'
    k = 20  # number of eigenvalues/eigenfunctions to analyze
    
    # Dataset parameters
    n_train = 5000
    n_test = 1000
    
    # Define MSP sets: each set lists the input columns multiplied into one monomial
    msp_sets = [{7}, {2,7}, {0,2,7}, {5,7,4}, {1}, {0,4}, {3,7}, {0,1,2,3,4,6,7}]
    
    # Initialize MSP function
    msp = MSPFunction(P, msp_sets)
    
    # Generate datasets (uses generate_datasets' default seed)
    X_train, y_train, X_test, y_test = generate_datasets(P, d, n_train, n_test, msp)
    
    # Model path and output directory
    # NOTE: analyze_model as defined in this cell accepts model_path but never reads it.
    # NOTE(review): the file name mentions n20000 while n_train above is 5000 -- confirm this is intended.
    model_path = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1612_nogrokk_mup_pennington/final_model_h500_d1_n20000_lr0.05_mup_pennington_20241219_163433_rank8.pt"
    output_dir = "ntk_analysis_results"
    
    # Run analysis
    analyze_model(model_path, X_train, y_train, X_test, y_test,
                 d, hidden_size, depth, mode, k, output_dir)


Actual function values (first 100 training points):
[ 0.  0.  4.  2.  2.  2.  2. -2. -2.  2.  0.  0.  2.  4.  4. -4.  0.  2.
 -2. -2.  4.  2. -6. -2. -2. -2. -2.  2.  4.  0.  0. -2.  4. -6.  4.  6.
 -6.  2.  2. -2.  6.  2. -2.  4.  0.  4.  2. -2. -2.  0.  0.  0. -4. -4.
 -2. -2.  2. -2.  0.  2.  2.  0.  0.  0.  4. -2. -2.  4.  0. -6.  0. -2.
 -4. -4. -6.  4.  0.  2.  4.  0.  0.  2.  0.  0. -2.  2.  2. -4.  0.  4.
 -2.  2. -2. -4.  0.  2.  4.  0. -4.  2.]

First 100 values of NNGP eigenfunctions in function space (training set):

NNGP Eigenfunction 1:
[-85.37371  -85.019646 -84.34767  -84.25189  -84.2034   -84.60511
 -84.746796 -85.31234  -85.07817  -84.87671  -85.07814  -85.15818
 -85.20113  -84.49798  -84.997116 -84.91655  -85.08416  -84.01941
 -85.25253  -84.75312  -85.05879  -84.7531   -85.214226 -85.672226
 -84.70594  -85.44693  -84.35448  -85.51021  -84.98478  -85.16694
 -84.60072  -85.07938  -84.63913  -84.76871  -84.69246  -84.38
 -84.424736 -85.28353  -84.62525  -85.03922  -85

  cmap = plt.cm.get_cmap('viridis', num_colors)


In [13]:
import torch
import neural_tangents as nt
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Set, Tuple
import torch.nn as nn
import torch.optim as optim
import os
from scipy.linalg import eigh
from functools import partial

class MSPFunction:
    """Sum of monomials over selected columns of the input.

    Every entry of ``sets`` is a set of column indices; the columns it names
    are multiplied together row by row, and the resulting monomials are
    added up by :meth:`evaluate`.
    """

    def __init__(self, P: int, sets: List[Set[int]]):
        # P is stored on the instance; evaluate() does not consult it.
        self.sets = sets
        self.P = P

    def _monomial(self, z: torch.Tensor, subset: Set[int]) -> torch.Tensor:
        """Row-wise product of the columns of ``z`` listed in ``subset``."""
        product = torch.ones(z.shape[0], dtype=torch.float32, device=z.device)
        for column in subset:
            product = product * z[:, column]
        return product

    def evaluate(self, z: torch.Tensor) -> torch.Tensor:
        """Evaluate the function on every row of ``z``.

        Args:
            z: 2-D tensor indexed as ``z[:, idx]``; one sample per row.

        Returns:
            1-D tensor with one value per row, allocated on ``z``'s device.
            An empty index set contributes a constant 1 to every row.
        """
        total = torch.zeros(z.shape[0], dtype=torch.float32, device=z.device)
        for subset in self.sets:
            total = total + self._monomial(z, subset)
        return total

class DeepNN(nn.Module):
    """Fully connected ReLU network with a single-output linear readout.

    The network is ``depth`` blocks of ``Linear -> ReLU`` of width
    ``hidden_size`` followed by ``Linear(hidden_size, 1)`` (or
    ``Linear(d, 1)`` when ``depth`` is 0).

    With ``mode == 'mup_pennington'`` every weight matrix is drawn from a
    normal distribution with std ``1 / sqrt(fan_in)``, biases are zeroed and
    a learning-rate scale per linear layer is recorded in ``layer_lrs``
    (``1.0`` for the first layer, ``1 / fan_in`` for every later layer,
    including the readout). The scales are only stored here; this class does
    not apply them. With any other mode the layers keep ``nn.Linear``'s
    default initialisation, ``layer_lrs`` stays empty, and only the readout
    bias is zeroed.
    """

    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'mup_pennington'):
        super().__init__()

        # NOTE: this sets the process-wide default dtype, not just this module's.
        torch.set_default_dtype(torch.float32)

        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d

        use_mup = mode == 'mup_pennington'
        layers = []
        prev_dim = d
        self.layer_lrs = []  # one learning-rate scale per linear layer (mup mode only)

        for layer_idx in range(depth):
            linear = nn.Linear(prev_dim, hidden_size)

            if use_mup:
                nn.init.normal_(linear.weight, mean=0.0, std=1.0 / np.sqrt(prev_dim))
                nn.init.zeros_(linear.bias)
                # O(1) scale for the embedding layer, O(1/n) for hidden layers.
                self.layer_lrs.append(1.0 if layer_idx == 0 else 1.0 / prev_dim)

            layers.extend([linear, nn.ReLU()])
            prev_dim = hidden_size

        # Readout layer.
        final_layer = nn.Linear(prev_dim, 1)
        if use_mup:
            nn.init.normal_(final_layer.weight, std=1.0 / np.sqrt(prev_dim))
            self.layer_lrs.append(1.0 / prev_dim)  # O(1/n) scale for the readout

        nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network and drop the trailing size-1 output dimension.

        Only the last dimension is squeezed, so a batch containing a single
        sample yields shape ``(1,)`` rather than a 0-dim tensor (a bare
        ``squeeze()`` would also remove the batch dimension in that case).
        """
        return self.network(x).squeeze(-1)

def generate_datasets(P: int, d: int, n_train: int, n_test: int, msp: MSPFunction, seed: int = 42):
    """Sample random +/-1 inputs and label them with ``msp.evaluate``.

    Seeds torch's global RNG with ``seed``, then draws the training inputs
    followed by the test inputs. Tensors are placed on CUDA when available,
    otherwise on the CPU. ``P`` is accepted but not used here.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.
    """
    torch.manual_seed(seed)
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def sample_inputs(n_rows: int) -> torch.Tensor:
        # Bernoulli(0.5) draws in {0, 1} are mapped to {-1, +1} on the CPU
        # and only then moved to the target device.
        half = 0.5 * torch.ones((n_rows, d), dtype=torch.float32)
        return (2 * torch.bernoulli(half) - 1).to(target)

    X_train = sample_inputs(n_train)
    y_train = msp.evaluate(X_train)

    X_test = sample_inputs(n_test)
    y_test = msp.evaluate(X_test)

    return X_train, y_train, X_test, y_test

def convert_data_to_jax(data: torch.Tensor) -> jnp.ndarray:
    """Return a JAX array holding the values of a PyTorch tensor.

    The tensor is detached from autograd and moved to the CPU before being
    handed to JAX through NumPy.
    """
    host_array = data.detach().cpu().numpy()
    return jnp.array(host_array)

def compute_kernels(X_train: jnp.ndarray, X_test: jnp.ndarray, params: List[jnp.ndarray], 
                   architecture: nn.Module) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute NNGP and NTK kernels for the train/train and test/train pairs.

    A ``neural_tangents`` stax network is built from ``architecture.depth``
    and ``architecture.hidden_size`` only: Dense+ReLU blocks followed by a
    one-unit Dense readout, all with the library's default Dense settings.
    The weights of ``architecture`` are never read.

    Args:
        X_train: Training inputs.
        X_test: Test inputs.
        params: Unused; kept for interface compatibility.
        architecture: Object exposing ``hidden_size`` and ``depth``.

    Returns:
        ``(nngp_train, ntk_train, nngp_test, ntk_test)`` where the test
        kernels are evaluated between ``X_test`` and ``X_train``.
    """
    # One hidden block is always present, plus one per additional unit of depth.
    n_hidden = 1 + max(architecture.depth - 1, 0)

    layers = []
    for _ in range(n_hidden):
        layers.extend([
            nt.stax.Dense(architecture.hidden_size),
            nt.stax.Relu()
        ])
    layers.append(nt.stax.Dense(1))

    _, _, kernel_fn = nt.stax.serial(*layers)

    # Requesting both kernels in one call avoids propagating each input pair
    # through the kernel computation twice.
    train_kernels = kernel_fn(X_train, X_train, ('nngp', 'ntk'))
    test_kernels = kernel_fn(X_test, X_train, ('nngp', 'ntk'))

    return train_kernels.nngp, train_kernels.ntk, test_kernels.nngp, test_kernels.ntk

def compute_spectrum(kernel: jnp.ndarray, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the ``k`` largest eigenvalues of a symmetric kernel and their eigenvectors.

    The kernel is copied to NumPy and fully diagonalised with
    ``scipy.linalg.eigh``; the results are therefore NumPy arrays.

    Args:
        kernel: Square symmetric matrix.
        k: Number of leading eigenpairs to keep.

    Returns:
        ``(eigenvals, eigenvecs)`` with eigenvalues in descending order and
        ``eigenvecs[:, i]`` the eigenvector belonging to ``eigenvals[i]``.
    """
    # SciPy's solver needs a host-side NumPy array.
    kernel_np = np.array(kernel)

    eigenvals, eigenvecs = eigh(kernel_np)

    # Indices of the k largest eigenvalues, largest first.
    top = eigenvals.argsort()[::-1][:k]

    return eigenvals[top], eigenvecs[:, top]

def plot_spectrum(eigenvals_nngp: jnp.ndarray, eigenvals_ntk: jnp.ndarray, 
                 save_path: str, title: str):
    """Save a log-scale plot of the NNGP (blue) and NTK (red) eigenvalues.

    Each spectrum is drawn against its own index range; the figure is
    written to ``save_path`` and then closed.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    series = (
        (eigenvals_nngp, 'b-', 'NNGP'),
        (eigenvals_ntk, 'r-', 'NTK'),
    )
    for spectrum, line_style, name in series:
        ax.semilogy(np.arange(len(spectrum)), spectrum, line_style, label=name)

    ax.set_xlabel('Index')
    ax.set_ylabel('Eigenvalue (log scale)')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

    fig.savefig(save_path)
    plt.close(fig)

def plot_eigenfunctions(eigenvecs: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray,
                       save_path: str, title: str, k: int):
    """Plot the first ``k`` columns of ``eigenvecs`` as heatmaps of rounded values.

    Every column is rounded to the nearest integer and laid out row by row on
    a square grid of side ``ceil(sqrt(len(X)))``. Each panel uses one discrete
    colour per integer between that column's minimum and maximum. Panels are
    arranged three per row and the figure is saved to ``save_path``.

    Args:
        eigenvecs: 2-D array with one column per function to plot.
        X: Only its length is used, to size the grid.
        y: Unused; kept for interface compatibility.
        save_path: Destination image path.
        title: Figure super-title.
        k: Number of columns to plot.
    """
    # Round once up front; colours are assigned per integer value.
    values = np.round(np.array(eigenvecs)).astype(int)
    grid_size = int(np.ceil(np.sqrt(len(X))))

    # Three panels per row
    n_rows = (k + 2) // 3
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5*n_rows))
    axes = axes.ravel()

    for i in range(k):
        column = values[:, i]

        # Cells past the end of the data stay NaN so imshow leaves them blank
        # instead of drawing a padding value clipped into the colour range.
        grid_values = np.full((grid_size, grid_size), np.nan)
        grid_values.reshape(-1)[:column.size] = column

        vmin, vmax = int(column.min()), int(column.max())

        # One colour per integer level; pyplot.get_cmap replaces the
        # deprecated plt.cm.get_cmap.
        num_colors = vmax - vmin + 1
        cmap = plt.get_cmap('viridis', num_colors)

        # Half-unit margins centre each integer inside its colour band.
        im = axes[i].imshow(grid_values,
                            cmap=cmap,
                            vmin=vmin - 0.5,
                            vmax=vmax + 0.5)

        colorbar = plt.colorbar(im, ax=axes[i], ticks=np.arange(vmin, vmax + 1))
        colorbar.set_label('Rounded Value')

        axes[i].set_title(f'Eigenfunction {i+1}')
        axes[i].axis('off')  # Hide axes

    # Remove the unused panels of the last row
    for i in range(k, len(axes)):
        fig.delaxes(axes[i])

    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close(fig)

def compute_nngp_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray, train_kernel: jnp.ndarray, y_train: jnp.ndarray):
    """Project eigenvectors through ``kernel`` and weight them by their overlap with ``y_train``.

    Column ``i`` of the result is ``(kernel @ eigenvecs[:, i])`` multiplied
    by the scalar ``eigenvecs[:, i] . y_train``. ``train_kernel`` is accepted
    but not used.
    """
    # Kernel applied to every eigenvector (one column each)
    projected = jnp.dot(kernel, eigenvecs)

    # Overlap of each eigenvector with the targets, as a column vector
    weights = jnp.dot(eigenvecs.T, y_train.reshape(-1, 1))

    # Broadcasting the transposed weights scales column i by weight i
    return projected * weights.T

def compute_ntk_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray):
    """Return ``kernel @ eigenvecs``, i.e. each eigenvector projected through the kernel.

    No further scaling is applied to the projected columns.
    """
    return jnp.dot(kernel, eigenvecs)

def analyze_model(model_path: str,
                 X_train: torch.Tensor, y_train: torch.Tensor,
                 X_test: torch.Tensor, y_test: torch.Tensor,
                 d: int, hidden_size: int, depth: int, mode: str,
                 k: int, output_dir: str):
    """Compute NNGP/NTK kernels for the data, then print and plot their leading spectrum.

    Steps: convert the tensors to JAX arrays, build a ``DeepNN`` (used only
    as a description of the architecture), compute the four kernels, take
    the top-``k`` eigenpairs of each training kernel, project the
    eigenvectors through the train and test kernels, print the first 100
    training values of every projected column plus the eigenvalues, and save
    five figures into ``output_dir``.

    ``model_path`` is accepted but never read; no checkpoint is loaded.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Torch tensors -> JAX arrays, in the order train inputs/targets, test inputs/targets
    X_train_jax, y_train_jax, X_test_jax, y_test_jax = (
        convert_data_to_jax(tensor) for tensor in (X_train, y_train, X_test, y_test)
    )

    print("\nActual function values (first 100 training points):")
    print(y_train_jax[:100])

    # The freshly initialised model only carries depth/width to compute_kernels;
    # no parameters are passed along.
    architecture = DeepNN(d, hidden_size, depth, mode)
    nngp_train, ntk_train, nngp_test, ntk_test = compute_kernels(
        X_train_jax, X_test_jax, None, architecture
    )

    # Leading eigenpairs of the two training kernels
    nngp_vals, nngp_vecs = compute_spectrum(nngp_train, k)
    ntk_vals, ntk_vecs = compute_spectrum(ntk_train, k)

    # NNGP columns are additionally weighted by their overlap with the targets
    nngp_train_func = compute_nngp_eigenfunction_values(nngp_train, nngp_vecs, nngp_train, y_train_jax)
    nngp_test_func = compute_nngp_eigenfunction_values(nngp_test, nngp_vecs, nngp_train, y_train_jax)

    # NTK columns are the plain kernel projection
    ntk_train_func = compute_ntk_eigenfunction_values(ntk_train, ntk_vecs)
    ntk_test_func = compute_ntk_eigenfunction_values(ntk_test, ntk_vecs)

    print("\nFirst 100 values of NNGP eigenfunctions in function space (training set):")
    for i in range(k):
        print(f"\nNNGP Eigenfunction {i+1}:")
        print(nngp_train_func[:100, i])

    print("\nFirst 100 values of NTK eigenfunctions in function space (training set):")
    for i in range(k):
        print(f"\nNTK Eigenfunction {i+1}:")
        print(ntk_train_func[:100, i])

    print("\nNNGP kernel eigenvalues:")
    print(nngp_vals)
    print("\nNTK kernel eigenvalues:")
    print(ntk_vals)

    plot_spectrum(nngp_vals, ntk_vals,
                  os.path.join(output_dir, 'kernel_spectra.png'),
                  'Kernel Spectra')

    # (values, inputs, targets, file name, figure title), plotted in this order
    figure_specs = (
        (nngp_train_func, X_train_jax, y_train_jax,
         'nngp_eigenfunctions_train.png', 'NNGP Eigenfunctions (Training Set)'),
        (nngp_test_func, X_test_jax, y_test_jax,
         'nngp_eigenfunctions_test.png', 'NNGP Eigenfunctions (Test Set)'),
        (ntk_train_func, X_train_jax, y_train_jax,
         'ntk_eigenfunctions_train.png', 'NTK Eigenfunctions (Training Set)'),
        (ntk_test_func, X_test_jax, y_test_jax,
         'ntk_eigenfunctions_test.png', 'NTK Eigenfunctions (Test Set)'),
    )
    for func_values, inputs, targets, file_name, figure_title in figure_specs:
        plot_eigenfunctions(func_values, inputs, targets,
                            os.path.join(output_dir, file_name),
                            figure_title, k)

if __name__ == "__main__":
    # Problem and architecture settings
    P, d = 8, 30
    hidden_size, depth = 500, 1
    mode = 'mup_pennington'
    k = 10  # how many leading eigenvalues/eigenfunctions to inspect

    # Sample counts
    n_train, n_test = 5000, 1000

    # Index sets of the monomials that are summed to form the target
    msp_sets = [{7}, {2,7}, {0,2,7}, {5,7,4}, {1}, {0,4}, {3,7}, {0,1,2,3,4,6,7}]

    # Checkpoint path (forwarded to analyze_model) and destination for figures
    model_path = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1612_nogrokk_mup_pennington/final_model_h500_d1_n20000_lr0.05_mup_pennington_20241219_163433_rank8.pt"
    output_dir = "ntk_analysis_results"

    # Target function and data
    msp = MSPFunction(P, msp_sets)
    X_train, y_train, X_test, y_test = generate_datasets(P, d, n_train, n_test, msp)

    analyze_model(
        model_path=model_path,
        X_train=X_train, y_train=y_train,
        X_test=X_test, y_test=y_test,
        d=d, hidden_size=hidden_size, depth=depth, mode=mode,
        k=k, output_dir=output_dir,
    )


Actual function values (first 100 training points):
[ 0.  0.  4.  2.  2.  2.  2. -2. -2.  2.  0.  0.  2.  4.  4. -4.  0.  2.
 -2. -2.  4.  2. -6. -2. -2. -2. -2.  2.  4.  0.  0. -2.  4. -6.  4.  6.
 -6.  2.  2. -2.  6.  2. -2.  4.  0.  4.  2. -2. -2.  0.  0.  0. -4. -4.
 -2. -2.  2. -2.  0.  2.  2.  0.  0.  0.  4. -2. -2.  4.  0. -6.  0. -2.
 -4. -4. -6.  4.  0.  2.  4.  0.  0.  2.  0.  0. -2.  2.  2. -4.  0.  4.
 -2.  2. -2. -4.  0.  2.  4.  0. -4.  2.]

First 100 values of NNGP eigenfunctions in function space (training set):

NNGP Eigenfunction 1:
[-85.373726 -85.01965  -84.34769  -84.25189  -84.20341  -84.60514
 -84.746796 -85.31235  -85.078186 -84.87671  -85.07814  -85.15819
 -85.20114  -84.49798  -84.997116 -84.91655  -85.08417  -84.0194
 -85.252525 -84.75312  -85.058784 -84.7531   -85.21424  -85.67222
 -84.70594  -85.44693  -84.35448  -85.51021  -84.98477  -85.16695
 -84.60073  -85.07938  -84.63912  -84.7687   -84.69246  -84.37998
 -84.424736 -85.28353  -84.625244 -85.03923  -8

  cmap = plt.cm.get_cmap('viridis', num_colors)


In [14]:
import torch
import neural_tangents as nt
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Set, Tuple
import torch.nn as nn
import torch.optim as optim
import os
from scipy.linalg import eigh
from functools import partial

class MSPFunction:
    """Sum of monomials over selected columns of a batch of inputs.

    Each entry of ``sets`` is a set of column indices. For one input row the
    function value is the sum, over all sets, of the product of the row's
    entries at that set's indices (an empty set contributes 1).
    """

    def __init__(self, P: int, sets: List[Set[int]]):
        # P is stored but not consulted by evaluate().
        self.sets = sets
        self.P = P

    def _monomial(self, z: torch.Tensor, indices: Set[int]) -> torch.Tensor:
        """Return the per-row product of the columns of ``z`` listed in ``indices``."""
        product = torch.ones(z.shape[0], dtype=torch.float32, device=z.device)
        for column in indices:
            product = product * z[:, column]
        return product

    def evaluate(self, z: torch.Tensor) -> torch.Tensor:
        """Evaluate the function on every row of ``z`` (indexed as ``z[:, idx]``).

        Returns a 1-D tensor with one value per row, on the same device as ``z``.
        """
        total = torch.zeros(z.shape[0], dtype=torch.float32, device=z.device)
        for indices in self.sets:
            total = total + self._monomial(z, indices)
        return total

class DeepNN(nn.Module):
    """Fully connected ReLU network with a single scalar output per input row.

    The network is ``depth`` blocks of ``Linear -> ReLU`` (each of width
    ``hidden_size``) followed by a ``Linear`` readout to one unit.

    When ``mode == 'mup_pennington'`` every weight matrix is drawn from a
    normal distribution with standard deviation ``1 / sqrt(fan_in)``, biases
    are zeroed, and a per-layer learning-rate scale is appended to
    ``layer_lrs`` (1 for the first layer, ``1 / fan_in`` for every later
    layer including the readout). For any other mode the hidden layers keep
    PyTorch's default initialisation, ``layer_lrs`` stays empty, and only the
    readout bias is zeroed.

    Args:
        d: Input dimension.
        hidden_size: Width of every hidden layer.
        depth: Number of hidden ``Linear -> ReLU`` blocks.
        mode: Initialisation scheme; only ``'mup_pennington'`` is special-cased.
    """

    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'mup_pennington'):
        super().__init__()

        # NOTE: this sets the process-wide default dtype, not just this model's.
        torch.set_default_dtype(torch.float32)

        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d

        use_mup = mode == 'mup_pennington'
        layers = []
        prev_dim = d
        self.layer_lrs = []  # Per-layer learning-rate scales (filled in muP mode only)

        for layer_idx in range(depth):
            linear = nn.Linear(prev_dim, hidden_size)

            if use_mup:
                # Same 1/sqrt(fan_in) init for every layer; only the LR scale
                # differs between the first (embedding) layer and the rest.
                std = 1.0 / np.sqrt(prev_dim)
                lr_scale = 1.0 if layer_idx == 0 else 1.0 / prev_dim
                nn.init.normal_(linear.weight, mean=0.0, std=std)
                nn.init.zeros_(linear.bias)
                self.layer_lrs.append(lr_scale)

            layers.extend([linear, nn.ReLU()])
            prev_dim = hidden_size

        # Readout layer
        final_layer = nn.Linear(prev_dim, 1)
        if use_mup:
            nn.init.normal_(final_layer.weight, std=1.0 / np.sqrt(prev_dim))
            self.layer_lrs.append(1.0 / prev_dim)

        # The readout bias is zeroed regardless of mode.
        nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network and drop the trailing size-1 output dimension.

        Only the last dimension is squeezed, so a batch containing a single
        row yields a tensor of shape ``(1,)`` rather than a 0-d scalar.
        """
        return self.network(x).squeeze(-1)

def generate_datasets(P: int, d: int, n_train: int, n_test: int, msp: MSPFunction, seed: int = 42):
    """Sample random +/-1 inputs and label them with ``msp.evaluate``.

    Seeds torch's global RNG with ``seed``, then draws the training inputs
    followed by the test inputs. Tensors are placed on CUDA when available,
    otherwise on the CPU. ``P`` is accepted but not used here.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.
    """
    torch.manual_seed(seed)
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def sample_inputs(n_rows: int) -> torch.Tensor:
        # Bernoulli(0.5) draws in {0, 1} are mapped to {-1, +1} on the CPU
        # and only then moved to the target device.
        half = 0.5 * torch.ones((n_rows, d), dtype=torch.float32)
        return (2 * torch.bernoulli(half) - 1).to(target)

    X_train = sample_inputs(n_train)
    y_train = msp.evaluate(X_train)

    X_test = sample_inputs(n_test)
    y_test = msp.evaluate(X_test)

    return X_train, y_train, X_test, y_test

def convert_data_to_jax(data: torch.Tensor) -> jnp.ndarray:
    """Return a JAX array holding the values of a PyTorch tensor.

    The tensor is detached from autograd and moved to the CPU before being
    handed to JAX through NumPy.
    """
    host_array = data.detach().cpu().numpy()
    return jnp.array(host_array)

def compute_kernels(X_train: jnp.ndarray, X_test: jnp.ndarray, params: List[jnp.ndarray], 
                   architecture: nn.Module) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute NNGP and NTK kernels for the train/train and test/train pairs.

    A ``neural_tangents`` stax network is built from ``architecture.depth``
    and ``architecture.hidden_size`` only: Dense+ReLU blocks followed by a
    one-unit Dense readout, all with the library's default Dense settings.
    The weights of ``architecture`` are never read.

    Args:
        X_train: Training inputs.
        X_test: Test inputs.
        params: Unused; kept for interface compatibility.
        architecture: Object exposing ``hidden_size`` and ``depth``.

    Returns:
        ``(nngp_train, ntk_train, nngp_test, ntk_test)`` where the test
        kernels are evaluated between ``X_test`` and ``X_train``.
    """
    # One hidden block is always present, plus one per additional unit of depth.
    n_hidden = 1 + max(architecture.depth - 1, 0)

    layers = []
    for _ in range(n_hidden):
        layers.extend([
            nt.stax.Dense(architecture.hidden_size),
            nt.stax.Relu()
        ])
    layers.append(nt.stax.Dense(1))

    _, _, kernel_fn = nt.stax.serial(*layers)

    # Requesting both kernels in one call avoids propagating each input pair
    # through the kernel computation twice.
    train_kernels = kernel_fn(X_train, X_train, ('nngp', 'ntk'))
    test_kernels = kernel_fn(X_test, X_train, ('nngp', 'ntk'))

    return train_kernels.nngp, train_kernels.ntk, test_kernels.nngp, test_kernels.ntk

def compute_spectrum(kernel: jnp.ndarray, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the ``k`` largest eigenvalues of a symmetric kernel and their eigenvectors.

    The kernel is copied to NumPy and fully diagonalised with
    ``scipy.linalg.eigh``; the results are therefore NumPy arrays.

    Args:
        kernel: Square symmetric matrix.
        k: Number of leading eigenpairs to keep.

    Returns:
        ``(eigenvals, eigenvecs)`` with eigenvalues in descending order and
        ``eigenvecs[:, i]`` the eigenvector belonging to ``eigenvals[i]``.
    """
    # SciPy's solver needs a host-side NumPy array.
    kernel_np = np.array(kernel)

    eigenvals, eigenvecs = eigh(kernel_np)

    # Indices of the k largest eigenvalues, largest first.
    top = eigenvals.argsort()[::-1][:k]

    return eigenvals[top], eigenvecs[:, top]

def plot_spectrum(eigenvals_nngp: jnp.ndarray, eigenvals_ntk: jnp.ndarray, 
                 save_path: str, title: str):
    """Save a log-scale plot of the NNGP (blue) and NTK (red) eigenvalues.

    Each spectrum is drawn against its own index range; the figure is
    written to ``save_path`` and then closed.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    series = (
        (eigenvals_nngp, 'b-', 'NNGP'),
        (eigenvals_ntk, 'r-', 'NTK'),
    )
    for spectrum, line_style, name in series:
        ax.semilogy(np.arange(len(spectrum)), spectrum, line_style, label=name)

    ax.set_xlabel('Index')
    ax.set_ylabel('Eigenvalue (log scale)')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

    fig.savefig(save_path)
    plt.close(fig)

def plot_eigenfunctions(eigenvecs: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray,
                       save_path: str, title: str, k: int):
    """Plot the first ``k`` columns of ``eigenvecs`` as heatmaps of rounded values.

    Every column is rounded to the nearest integer and laid out row by row on
    a square grid of side ``ceil(sqrt(len(X)))``. Each panel uses one discrete
    colour per integer between that column's minimum and maximum. Panels are
    arranged three per row and the figure is saved to ``save_path``.

    Args:
        eigenvecs: 2-D array with one column per function to plot.
        X: Only its length is used, to size the grid.
        y: Unused; kept for interface compatibility.
        save_path: Destination image path.
        title: Figure super-title.
        k: Number of columns to plot.
    """
    # Round once up front; colours are assigned per integer value.
    values = np.round(np.array(eigenvecs)).astype(int)
    grid_size = int(np.ceil(np.sqrt(len(X))))

    # Three panels per row
    n_rows = (k + 2) // 3
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5*n_rows))
    axes = axes.ravel()

    for i in range(k):
        column = values[:, i]

        # Cells past the end of the data stay NaN so imshow leaves them blank
        # instead of drawing a padding value clipped into the colour range.
        grid_values = np.full((grid_size, grid_size), np.nan)
        grid_values.reshape(-1)[:column.size] = column

        vmin, vmax = int(column.min()), int(column.max())

        # One colour per integer level; pyplot.get_cmap replaces the
        # deprecated plt.cm.get_cmap.
        num_colors = vmax - vmin + 1
        cmap = plt.get_cmap('viridis', num_colors)

        # Half-unit margins centre each integer inside its colour band.
        im = axes[i].imshow(grid_values,
                            cmap=cmap,
                            vmin=vmin - 0.5,
                            vmax=vmax + 0.5)

        colorbar = plt.colorbar(im, ax=axes[i], ticks=np.arange(vmin, vmax + 1))
        colorbar.set_label('Rounded Value')

        axes[i].set_title(f'Eigenfunction {i+1}')
        axes[i].axis('off')  # Hide axes

    # Remove the unused panels of the last row
    for i in range(k, len(axes)):
        fig.delaxes(axes[i])

    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close(fig)

def compute_nngp_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray, train_kernel: jnp.ndarray, y_train: jnp.ndarray):
    """Project eigenvectors through ``kernel`` and weight them by their overlap with ``y_train``.

    Column ``i`` of the result is ``(kernel @ eigenvecs[:, i])`` multiplied
    by the scalar ``eigenvecs[:, i] . y_train``. ``train_kernel`` is accepted
    but not used.
    """
    # Kernel applied to every eigenvector (one column each)
    projected = jnp.dot(kernel, eigenvecs)

    # Overlap of each eigenvector with the targets, as a column vector
    weights = jnp.dot(eigenvecs.T, y_train.reshape(-1, 1))

    # Broadcasting the transposed weights scales column i by weight i
    return projected * weights.T

def compute_ntk_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray):
    """Return ``kernel @ eigenvecs``, i.e. each eigenvector projected through the kernel.

    No further scaling is applied to the projected columns.
    """
    return jnp.dot(kernel, eigenvecs)

def analyze_model(model_path: str,
                 X_train: torch.Tensor, y_train: torch.Tensor,
                 X_test: torch.Tensor, y_test: torch.Tensor,
                 d: int, hidden_size: int, depth: int, mode: str,
                 k: int, output_dir: str):
    """Compute NNGP/NTK kernels for the data, then print and plot their leading spectrum.

    Steps: convert the tensors to JAX arrays, build a ``DeepNN`` (used only
    as a description of the architecture), compute the four kernels, take
    the top-``k`` eigenpairs of each training kernel, project the
    eigenvectors through the train and test kernels, print the first 100
    training values of every projected column plus the eigenvalues, and save
    five figures into ``output_dir``.

    ``model_path`` is accepted but never read; no checkpoint is loaded.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Torch tensors -> JAX arrays, in the order train inputs/targets, test inputs/targets
    X_train_jax, y_train_jax, X_test_jax, y_test_jax = (
        convert_data_to_jax(tensor) for tensor in (X_train, y_train, X_test, y_test)
    )

    print("\nActual function values (first 100 training points):")
    print(y_train_jax[:100])

    # The freshly initialised model only carries depth/width to compute_kernels;
    # no parameters are passed along.
    architecture = DeepNN(d, hidden_size, depth, mode)
    nngp_train, ntk_train, nngp_test, ntk_test = compute_kernels(
        X_train_jax, X_test_jax, None, architecture
    )

    # Leading eigenpairs of the two training kernels
    nngp_vals, nngp_vecs = compute_spectrum(nngp_train, k)
    ntk_vals, ntk_vecs = compute_spectrum(ntk_train, k)

    # NNGP columns are additionally weighted by their overlap with the targets
    nngp_train_func = compute_nngp_eigenfunction_values(nngp_train, nngp_vecs, nngp_train, y_train_jax)
    nngp_test_func = compute_nngp_eigenfunction_values(nngp_test, nngp_vecs, nngp_train, y_train_jax)

    # NTK columns are the plain kernel projection
    ntk_train_func = compute_ntk_eigenfunction_values(ntk_train, ntk_vecs)
    ntk_test_func = compute_ntk_eigenfunction_values(ntk_test, ntk_vecs)

    print("\nFirst 100 values of NNGP eigenfunctions in function space (training set):")
    for i in range(k):
        print(f"\nNNGP Eigenfunction {i+1}:")
        print(nngp_train_func[:100, i])

    print("\nFirst 100 values of NTK eigenfunctions in function space (training set):")
    for i in range(k):
        print(f"\nNTK Eigenfunction {i+1}:")
        print(ntk_train_func[:100, i])

    print("\nNNGP kernel eigenvalues:")
    print(nngp_vals)
    print("\nNTK kernel eigenvalues:")
    print(ntk_vals)

    plot_spectrum(nngp_vals, ntk_vals,
                  os.path.join(output_dir, 'kernel_spectra.png'),
                  'Kernel Spectra')

    # (values, inputs, targets, file name, figure title), plotted in this order
    figure_specs = (
        (nngp_train_func, X_train_jax, y_train_jax,
         'nngp_eigenfunctions_train.png', 'NNGP Eigenfunctions (Training Set)'),
        (nngp_test_func, X_test_jax, y_test_jax,
         'nngp_eigenfunctions_test.png', 'NNGP Eigenfunctions (Test Set)'),
        (ntk_train_func, X_train_jax, y_train_jax,
         'ntk_eigenfunctions_train.png', 'NTK Eigenfunctions (Training Set)'),
        (ntk_test_func, X_test_jax, y_test_jax,
         'ntk_eigenfunctions_test.png', 'NTK Eigenfunctions (Test Set)'),
    )
    for func_values, inputs, targets, file_name, figure_title in figure_specs:
        plot_eigenfunctions(func_values, inputs, targets,
                            os.path.join(output_dir, file_name),
                            figure_title, k)

if __name__ == "__main__":
    # Problem and architecture settings
    P, d = 8, 30
    hidden_size, depth = 500, 1
    mode = 'mup_pennington'
    k = 10  # how many leading eigenvalues/eigenfunctions to inspect

    # Sample counts
    n_train, n_test = 800, 700

    # Index sets of the monomials that are summed to form the target
    msp_sets = [{7}, {2,7}, {0,2,7}, {5,7,4}, {1}, {0,4}, {3,7}, {0,1,2,3,4,6,7}]

    # Checkpoint path (forwarded to analyze_model) and destination for figures
    model_path = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1612_nogrokk_mup_pennington/final_model_h500_d1_n20000_lr0.05_mup_pennington_20241219_163433_rank8.pt"
    output_dir = "ntk_analysis_results"

    # Target function and data
    msp = MSPFunction(P, msp_sets)
    X_train, y_train, X_test, y_test = generate_datasets(P, d, n_train, n_test, msp)

    analyze_model(
        model_path=model_path,
        X_train=X_train, y_train=y_train,
        X_test=X_test, y_test=y_test,
        d=d, hidden_size=hidden_size, depth=depth, mode=mode,
        k=k, output_dir=output_dir,
    )


Actual function values (first 100 training points):
[ 0.  0.  4.  2.  2.  2.  2. -2. -2.  2.  0.  0.  2.  4.  4. -4.  0.  2.
 -2. -2.  4.  2. -6. -2. -2. -2. -2.  2.  4.  0.  0. -2.  4. -6.  4.  6.
 -6.  2.  2. -2.  6.  2. -2.  4.  0.  4.  2. -2. -2.  0.  0.  0. -4. -4.
 -2. -2.  2. -2.  0.  2.  2.  0.  0.  0.  4. -2. -2.  4.  0. -6.  0. -2.
 -4. -4. -6.  4.  0.  2.  4.  0.  0.  2.  0.  0. -2.  2.  2. -4.  0.  4.
 -2.  2. -2. -4.  0.  2.  4.  0. -4.  2.]

First 100 values of NNGP eigenfunctions in function space (training set):

NNGP Eigenfunction 1:
[-12.92087   -12.851156  -12.693413  -12.662748  -13.010132  -12.97372
 -12.797591  -13.00122   -13.122744  -12.792023  -12.926179  -13.005964
 -13.104382  -12.627592  -12.787695  -12.816565  -13.016538  -12.606921
 -12.933316  -12.867984  -12.791466  -12.844892  -12.905329  -12.740396
 -12.669813  -13.078498  -12.636147  -13.059945  -13.17651   -12.750199
 -12.47125   -13.0295925 -12.9372635 -12.906089  -12.996647  -12.926755
 -12.65088 

  cmap = plt.cm.get_cmap('viridis', num_colors)


In [None]:
#### function works

In [25]:
import torch
import neural_tangents as nt
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Set, Tuple
import torch.nn as nn
import torch.optim as optim
import os
from scipy.linalg import eigh
from functools import partial

class MSPFunction:
    """Sum of monomials over selected columns of a batch of inputs.

    Each entry of ``sets`` is a set of column indices. For one input row the
    function value is the sum, over all sets, of the product of the row's
    entries at that set's indices (an empty set contributes 1).
    """

    def __init__(self, P: int, sets: List[Set[int]]):
        # P is stored but not consulted by evaluate().
        self.sets = sets
        self.P = P

    def _monomial(self, z: torch.Tensor, indices: Set[int]) -> torch.Tensor:
        """Return the per-row product of the columns of ``z`` listed in ``indices``."""
        product = torch.ones(z.shape[0], dtype=torch.float32, device=z.device)
        for column in indices:
            product = product * z[:, column]
        return product

    def evaluate(self, z: torch.Tensor) -> torch.Tensor:
        """Evaluate the function on every row of ``z`` (indexed as ``z[:, idx]``).

        Returns a 1-D tensor with one value per row, on the same device as ``z``.
        """
        total = torch.zeros(z.shape[0], dtype=torch.float32, device=z.device)
        for indices in self.sets:
            total = total + self._monomial(z, indices)
        return total

class DeepNN(nn.Module):
    """Fully connected ReLU network with a single scalar output per input row.

    The network is ``depth`` blocks of ``Linear -> ReLU`` (each of width
    ``hidden_size``) followed by a ``Linear`` readout to one unit.

    When ``mode == 'mup_pennington'`` every weight matrix is drawn from a
    normal distribution with standard deviation ``1 / sqrt(fan_in)``, biases
    are zeroed, and a per-layer learning-rate scale is appended to
    ``layer_lrs`` (1 for the first layer, ``1 / fan_in`` for every later
    layer including the readout). For any other mode the hidden layers keep
    PyTorch's default initialisation, ``layer_lrs`` stays empty, and only the
    readout bias is zeroed.

    Args:
        d: Input dimension.
        hidden_size: Width of every hidden layer.
        depth: Number of hidden ``Linear -> ReLU`` blocks.
        mode: Initialisation scheme; only ``'mup_pennington'`` is special-cased.
    """

    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'mup_pennington'):
        super().__init__()

        # NOTE: this sets the process-wide default dtype, not just this model's.
        torch.set_default_dtype(torch.float32)

        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d

        use_mup = mode == 'mup_pennington'
        layers = []
        prev_dim = d
        self.layer_lrs = []  # Per-layer learning-rate scales (filled in muP mode only)

        for layer_idx in range(depth):
            linear = nn.Linear(prev_dim, hidden_size)

            if use_mup:
                # Same 1/sqrt(fan_in) init for every layer; only the LR scale
                # differs between the first (embedding) layer and the rest.
                std = 1.0 / np.sqrt(prev_dim)
                lr_scale = 1.0 if layer_idx == 0 else 1.0 / prev_dim
                nn.init.normal_(linear.weight, mean=0.0, std=std)
                nn.init.zeros_(linear.bias)
                self.layer_lrs.append(lr_scale)

            layers.extend([linear, nn.ReLU()])
            prev_dim = hidden_size

        # Readout layer
        final_layer = nn.Linear(prev_dim, 1)
        if use_mup:
            nn.init.normal_(final_layer.weight, std=1.0 / np.sqrt(prev_dim))
            self.layer_lrs.append(1.0 / prev_dim)

        # The readout bias is zeroed regardless of mode.
        nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network and drop the trailing size-1 output dimension.

        Only the last dimension is squeezed, so a batch containing a single
        row yields a tensor of shape ``(1,)`` rather than a 0-d scalar.
        """
        return self.network(x).squeeze(-1)

def generate_datasets(P: int, d: int, n_train: int, n_test: int, msp: MSPFunction, seed: int = 42):
    """Sample random +/-1 inputs and label them with ``msp.evaluate``.

    Seeds torch's global RNG with ``seed``, then draws the training inputs
    followed by the test inputs. Tensors are placed on CUDA when available,
    otherwise on the CPU. ``P`` is accepted but not used here.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.
    """
    torch.manual_seed(seed)
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def sample_inputs(n_rows: int) -> torch.Tensor:
        # Bernoulli(0.5) draws in {0, 1} are mapped to {-1, +1} on the CPU
        # and only then moved to the target device.
        half = 0.5 * torch.ones((n_rows, d), dtype=torch.float32)
        return (2 * torch.bernoulli(half) - 1).to(target)

    X_train = sample_inputs(n_train)
    y_train = msp.evaluate(X_train)

    X_test = sample_inputs(n_test)
    y_test = msp.evaluate(X_test)

    return X_train, y_train, X_test, y_test

def convert_data_to_jax(data: torch.Tensor) -> jnp.ndarray:
    """Return a JAX array holding the values of a PyTorch tensor.

    The tensor is detached from autograd and moved to the CPU before being
    handed to JAX through NumPy.
    """
    host_array = data.detach().cpu().numpy()
    return jnp.array(host_array)

def compute_kernels(X_train: jnp.ndarray, X_test: jnp.ndarray, params: List[jnp.ndarray], 
                   architecture: nn.Module) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute NNGP and NTK kernels for the train/train and test/train pairs.

    A ``neural_tangents`` stax network is built from ``architecture.depth``
    and ``architecture.hidden_size`` only: Dense+ReLU blocks followed by a
    one-unit Dense readout, all with the library's default Dense settings.
    The weights of ``architecture`` are never read.

    Args:
        X_train: Training inputs.
        X_test: Test inputs.
        params: Unused; kept for interface compatibility.
        architecture: Object exposing ``hidden_size`` and ``depth``.

    Returns:
        ``(nngp_train, ntk_train, nngp_test, ntk_test)`` where the test
        kernels are evaluated between ``X_test`` and ``X_train``.
    """
    # One hidden block is always present, plus one per additional unit of depth.
    n_hidden = 1 + max(architecture.depth - 1, 0)

    layers = []
    for _ in range(n_hidden):
        layers.extend([
            nt.stax.Dense(architecture.hidden_size),
            nt.stax.Relu()
        ])
    layers.append(nt.stax.Dense(1))

    _, _, kernel_fn = nt.stax.serial(*layers)

    # Requesting both kernels in one call avoids propagating each input pair
    # through the kernel computation twice.
    train_kernels = kernel_fn(X_train, X_train, ('nngp', 'ntk'))
    test_kernels = kernel_fn(X_test, X_train, ('nngp', 'ntk'))

    return train_kernels.nngp, train_kernels.ntk, test_kernels.nngp, test_kernels.ntk

def compute_spectrum(kernel: jnp.ndarray, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the ``k`` largest eigenvalues of a symmetric kernel and their eigenvectors.

    The kernel is copied to NumPy and fully diagonalised with
    ``scipy.linalg.eigh``; the results are therefore NumPy arrays.

    Args:
        kernel: Square symmetric matrix.
        k: Number of leading eigenpairs to keep.

    Returns:
        ``(eigenvals, eigenvecs)`` with eigenvalues in descending order and
        ``eigenvecs[:, i]`` the eigenvector belonging to ``eigenvals[i]``.
    """
    # SciPy's solver needs a host-side NumPy array.
    kernel_np = np.array(kernel)

    eigenvals, eigenvecs = eigh(kernel_np)

    # Indices of the k largest eigenvalues, largest first.
    top = eigenvals.argsort()[::-1][:k]

    return eigenvals[top], eigenvecs[:, top]

def plot_spectrum(eigenvals_nngp: jnp.ndarray, eigenvals_ntk: jnp.ndarray, 
                 save_path: str, title: str):
    """Save a semilog-y comparison of the NNGP and NTK eigenvalue spectra.

    Each spectrum is drawn against its own index range (NNGP in blue, NTK in
    red) and the figure is written to ``save_path`` and then closed.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    curves = (
        (eigenvals_nngp, 'b-', 'NNGP'),
        (eigenvals_ntk, 'r-', 'NTK'),
    )
    for spectrum, line_fmt, legend_name in curves:
        ax.semilogy(np.arange(len(spectrum)), spectrum, line_fmt, label=legend_name)

    ax.set_xlabel('Index')
    ax.set_ylabel('Eigenvalue (log scale)')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

    fig.savefig(save_path)
    plt.close(fig)

def plot_eigenfunctions(eigenvecs: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray,
                       save_path: str, title: str, k: int):
    """Save a figure with the first ``k`` columns drawn as 1-row colour strips.

    Args:
        eigenvecs: Array whose column ``i`` holds the values of eigenfunction
            ``i`` at every data point.
        X: Unused; kept for interface compatibility.
        y: Unused; kept for interface compatibility.
        save_path: Destination image path.
        title: Figure super-title.
        k: Number of columns (one subplot each) to draw.
    """
    eigenvecs_np = np.array(eigenvecs)

    # A colorbar shared by several axes is not compatible with tight_layout
    # (matplotlib emits a UserWarning); constrained layout handles it.
    # squeeze=False keeps `axes` an array even when k == 1.
    fig, axes = plt.subplots(k, 1, figsize=(20, 0.5*k), squeeze=False,
                             constrained_layout=True)
    try:
        axes = axes[:, 0]

        # One colour scale shared by every strip so they are comparable.
        vmin = np.min(eigenvecs_np)
        vmax = np.max(eigenvecs_np)
        norm = plt.Normalize(vmin=vmin, vmax=vmax)
        sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm)

        for i, ax in enumerate(axes):
            # A single image row with one cell per data point.
            values = eigenvecs_np[:, i].reshape(1, -1)
            ax.imshow(values,
                      aspect='auto',  # stretch the row across the full width
                      cmap='viridis',
                      vmin=vmin,
                      vmax=vmax)
            ax.set_title(f'Eigenfunction {i+1}')
            ax.set_yticks([])
            ax.set_xticks([])

        fig.colorbar(sm, ax=axes, label='Eigenfunction Value')
        fig.suptitle(title)
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)

def compute_nngp_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray, train_kernel: jnp.ndarray, y_train: jnp.ndarray):
    """Push eigenvectors through ``kernel`` and weight them by their overlap with ``y_train``.

    Column ``i`` of the result equals
    ``(kernel @ eigenvecs[:, i]) * (eigenvecs[:, i] . y_train)``.
    ``train_kernel`` is accepted but not used.
    """
    targets_col = y_train.reshape(-1, 1)

    # Overlap of every eigenvector with the targets, one row per eigenvector.
    overlaps = jnp.dot(eigenvecs.T, targets_col)

    # Kernel-projected eigenvectors, one column per eigenvector.
    projected = jnp.dot(kernel, eigenvecs)

    # Broadcasting the transposed overlaps scales each column separately.
    return projected * overlaps.T

def compute_ntk_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray):
    """Return ``kernel @ eigenvecs``: the eigenvectors projected through ``kernel``.

    No further scaling is applied; passing a test-vs-train kernel yields the
    projected values at the test points.
    """
    projected = jnp.dot(kernel, eigenvecs)
    return projected

def analyze_model(model_path: str, 
                 X_train: torch.Tensor, y_train: torch.Tensor,
                 X_test: torch.Tensor, y_test: torch.Tensor,
                 d: int, hidden_size: int, depth: int, mode: str,
                 k: int, output_dir: str):
    """Run the kernel-spectrum analysis and write every plot into ``output_dir``.

    Builds the NNGP and NTK kernels for the architecture described by
    ``d``/``hidden_size``/``depth``/``mode``, takes their top-``k`` eigenpairs
    on the training set, prints diagnostics, and saves the spectrum and
    eigenfunction figures. ``model_path`` is accepted but never read: no
    trained parameters are loaded.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Everything downstream works on JAX arrays.
    X_train_jax, y_train_jax, X_test_jax, y_test_jax = (
        convert_data_to_jax(tensor) for tensor in (X_train, y_train, X_test, y_test)
    )

    print("\nActual function values (first 100 training points):")
    print(y_train_jax[:100])

    # The network object is only a carrier of the architecture description;
    # it is handed to compute_kernels with no parameters (None).
    architecture = DeepNN(d, hidden_size, depth, mode)
    nngp_train, ntk_train, nngp_test, ntk_test = compute_kernels(
        X_train_jax, X_test_jax, None, architecture)

    # Leading eigenpairs of the train-vs-train kernels.
    nngp_vals, nngp_vecs = compute_spectrum(nngp_train, k)
    ntk_vals, ntk_vecs = compute_spectrum(ntk_train, k)

    # Eigenfunction values on the train and test points.
    nngp_train_func = compute_nngp_eigenfunction_values(nngp_train, nngp_vecs, nngp_train, y_train_jax)
    nngp_test_func = compute_nngp_eigenfunction_values(nngp_test, nngp_vecs, nngp_train, y_train_jax)
    ntk_train_func = compute_ntk_eigenfunction_values(ntk_train, ntk_vecs)
    ntk_test_func = compute_ntk_eigenfunction_values(ntk_test, ntk_vecs)

    # Diagnostics: first 100 training-point values of each eigenfunction.
    print("\nFirst 100 values of NNGP eigenfunctions in function space (training set):")
    for i in range(k):
        print(f"\nNNGP Eigenfunction {i+1}:")
        print(nngp_train_func[:100, i])

    print("\nFirst 100 values of NTK eigenfunctions in function space (training set):")
    for i in range(k):
        print(f"\nNTK Eigenfunction {i+1}:")
        print(ntk_train_func[:100, i])

    print("\nNNGP kernel eigenvalues:")
    print(nngp_vals)
    print("\nNTK kernel eigenvalues:")
    print(ntk_vals)

    plot_spectrum(nngp_vals, ntk_vals,
                  os.path.join(output_dir, 'kernel_spectra.png'),
                  'Kernel Spectra')

    # (values, inputs, targets, file name, figure title), in drawing order.
    figure_jobs = (
        (nngp_train_func, X_train_jax, y_train_jax,
         'nngp_eigenfunctions_train.png', 'NNGP Eigenfunctions (Training Set)'),
        (nngp_test_func, X_test_jax, y_test_jax,
         'nngp_eigenfunctions_test.png', 'NNGP Eigenfunctions (Test Set)'),
        (ntk_train_func, X_train_jax, y_train_jax,
         'ntk_eigenfunctions_train.png', 'NTK Eigenfunctions (Training Set)'),
        (ntk_test_func, X_test_jax, y_test_jax,
         'ntk_eigenfunctions_test.png', 'NTK Eigenfunctions (Test Set)'),
    )
    for func_values, inputs, targets, file_name, figure_title in figure_jobs:
        plot_eigenfunctions(func_values, inputs, targets,
                            os.path.join(output_dir, file_name),
                            figure_title, k)

# Script entry point: build a synthetic MSP dataset on the +/-1 hypercube and
# run the NNGP/NTK spectrum analysis on it.
if __name__ == "__main__":
    # Model parameters
    P = 8  # passed to MSPFunction and generate_datasets
    d = 30  # input dimension
    hidden_size = 500
    depth = 1
    mode = 'mup_pennington'
    k = 10  # number of eigenvalues/eigenfunctions to analyze
    
    # Dataset parameters
    n_train = 8000
    n_test = 8000
    
    # Define MSP sets: each set lists the input coordinates multiplied
    # together to form one term of the target function
    msp_sets = [{7}, {2,7}, {0,2,7}, {5,7,4}, {1}, {0,4}, {3,7}, {0,1,2,3,4,6,7}]
    
    # Initialize MSP function
    msp = MSPFunction(P, msp_sets)
    
    # Generate datasets
    X_train, y_train, X_test, y_test = generate_datasets(P, d, n_train, n_test, msp)
    
    # Model path and output directory
    # NOTE(review): the path names a "final_model" checkpoint while output_dir
    # ends in "_init"; analyze_model as defined in this cell never reads
    # model_path, so the path has no effect on the results - confirm intended.
    model_path = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1612_nogrokk_mup_pennington/final_model_h500_d1_n20000_lr0.05_mup_pennington_20241219_163433_rank8.pt"
    output_dir = "ntk_analysis_results_init"
    
    # Run analysis
    analyze_model(model_path, X_train, y_train, X_test, y_test,
                 d, hidden_size, depth, mode, k, output_dir)


Actual function values (first 100 training points):
[ 0.  0.  4.  2.  2.  2.  2. -2. -2.  2.  0.  0.  2.  4.  4. -4.  0.  2.
 -2. -2.  4.  2. -6. -2. -2. -2. -2.  2.  4.  0.  0. -2.  4. -6.  4.  6.
 -6.  2.  2. -2.  6.  2. -2.  4.  0.  4.  2. -2. -2.  0.  0.  0. -4. -4.
 -2. -2.  2. -2.  0.  2.  2.  0.  0.  0.  4. -2. -2.  4.  0. -6.  0. -2.
 -4. -4. -6.  4.  0.  2.  4.  0.  0.  2.  0.  0. -2.  2.  2. -4.  0.  4.
 -2.  2. -2. -4.  0.  2.  4.  0. -4.  2.]

First 100 values of NNGP eigenfunctions in function space (training set):

NNGP Eigenfunction 1:
[-57.547653 -57.385853 -57.32267  -57.295376 -57.245842 -57.338284
 -57.532696 -57.59592  -57.843735 -57.495667 -57.524254 -57.76261
 -57.545918 -57.3961   -57.424057 -57.569164 -57.558372 -57.21671
 -57.682575 -57.50311  -57.557114 -57.54716  -57.538612 -57.656208
 -57.305386 -57.86829  -57.298805 -57.654217 -57.569412 -57.5054
 -57.22424  -57.68526  -57.264507 -57.3798   -57.432182 -57.39058
 -57.104225 -57.7068   -57.563557 -57.519077 

  plt.tight_layout()


In [26]:
import torch
import neural_tangents as nt
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Set, Tuple
import torch.nn as nn
import torch.optim as optim
import os
from scipy.linalg import eigh
from functools import partial

class MSPFunction:
    """Sum of monomials over coordinate subsets.

    ``evaluate`` returns, for every row of its input, the sum over all stored
    index sets of the product of the row's entries at those indices.
    """

    def __init__(self, P: int, sets: List[Set[int]]):
        # P is stored but not consulted by evaluate().
        self.P = P
        self.sets = sets
    
    def evaluate(self, z: torch.Tensor) -> torch.Tensor:
        """Evaluate the function on a batch ``z`` indexed as ``z[:, idx]``.

        Returns a 1-D tensor with one value per row, on the same device as
        ``z``. An empty index set contributes a constant 1.
        """
        n_rows = z.shape[0]
        alloc = dict(dtype=torch.float32, device=z.device)

        total = torch.zeros(n_rows, **alloc)
        for subset in self.sets:
            # Start from 1 and multiply in each selected coordinate.
            monomial = torch.ones(n_rows, **alloc)
            for coord in subset:
                monomial = monomial * z[:, coord]
            total = total + monomial

        return total

class DeepNN(nn.Module):
    """Fully connected ReLU network with a scalar output.

    The network is ``depth`` blocks of ``Linear -> ReLU`` followed by a final
    ``Linear(hidden, 1)``. In ``'mup_pennington'`` mode every weight matrix is
    drawn from ``N(0, 1/fan_in)``, biases are zeroed, and a per-layer
    learning-rate multiplier is recorded in ``layer_lrs`` (1 for the first
    layer, ``1/fan_in`` for all later layers including the readout). In any
    other mode the hidden layers keep PyTorch's default initialisation and
    ``layer_lrs`` stays empty; the readout bias is zeroed in every mode.
    """

    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'mup_pennington'):
        super().__init__()
        
        # NOTE: this changes the process-wide default dtype, not just this model.
        torch.set_default_dtype(torch.float32)
        
        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d
        
        layers = []
        prev_dim = d
        self.layer_lrs = []  # Store layerwise learning rates
        
        for layer_idx in range(depth):
            linear = nn.Linear(prev_dim, hidden_size)
            
            if mode == 'mup_pennington':
                # Same init scale for every layer; only the LR multiplier
                # differs between the first (embedding) layer and the rest.
                std = 1.0 / np.sqrt(prev_dim)
                if layer_idx == 0:  # Embedding layer
                    lr_scale = 1.0  # O(1) learning rate for embedding
                else:  # Hidden layers
                    lr_scale = 1.0 / prev_dim  # O(1/n) learning rate for hidden
                nn.init.normal_(linear.weight, mean=0.0, std=std)
                nn.init.zeros_(linear.bias)
                self.layer_lrs.append(lr_scale)
            
            layers.extend([
                linear,
                nn.ReLU()
            ])
            prev_dim = hidden_size
        
        # Final layer
        final_layer = nn.Linear(prev_dim, 1)
        if mode == 'mup_pennington':
            std = 1.0 / np.sqrt(prev_dim)
            lr_scale = 1.0 / prev_dim  # O(1/n) learning rate for readout
            nn.init.normal_(final_layer.weight, std=std)
            self.layer_lrs.append(lr_scale)
            
        nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)
        
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one scalar per input row, shape ``(batch,)``."""
        # Drop only the trailing size-1 output dimension. A bare squeeze()
        # would also collapse the batch dimension for a single-sample batch
        # and hand back a 0-d tensor.
        return self.network(x).squeeze(-1)

def generate_datasets(P: int, d: int, n_train: int, n_test: int, msp: MSPFunction, seed: int = 42):
    """Draw seeded train/test inputs with +/-1 entries and label them with ``msp``.

    Inputs are sampled on the CPU after seeding torch's global RNG with
    ``seed``, then moved to CUDA when it is available. ``P`` is not used.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.
    """
    torch.manual_seed(seed)
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _sample_signs(n_rows: int) -> torch.Tensor:
        # Bernoulli(0.5) in {0, 1} mapped to {-1, +1}.
        coin_probs = 0.5 * torch.ones((n_rows, d), dtype=torch.float32)
        return (2 * torch.bernoulli(coin_probs) - 1).to(target)

    # Train is drawn before test so the random stream is consumed in a fixed order.
    X_train = _sample_signs(n_train)
    y_train = msp.evaluate(X_train)

    X_test = _sample_signs(n_test)
    y_test = msp.evaluate(X_test)

    return X_train, y_train, X_test, y_test

def convert_data_to_jax(data: torch.Tensor) -> jnp.ndarray:
    """Return a JAX array holding a copy of the values in ``data``.

    The tensor is detached from autograd and brought to the CPU before it is
    handed to NumPy, so tensors that require grad or live on another device
    are accepted.
    """
    host_values = data.detach().cpu().numpy()
    return jnp.array(host_values)

def compute_kernels(X_train: jnp.ndarray, X_test: jnp.ndarray, params: List[jnp.ndarray], 
                   architecture: nn.Module) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute NNGP and NTK kernels on the train and test inputs.

    A neural_tangents ``stax`` network is assembled as ``Dense -> Relu``
    blocks followed by ``Dense(1)``; only its ``kernel_fn`` is used, so no
    network parameters are involved.

    Args:
        X_train: Training inputs.
        X_test: Test inputs.
        params: Unused; kept so existing call sites keep working.
        architecture: Object exposing ``hidden_size`` and ``depth``.

    Returns:
        ``(nngp_train, ntk_train, nngp_test, ntk_test)``. The ``*_train``
        kernels are train-vs-train, the ``*_test`` kernels are test-vs-train.
    """
    # One hidden block is always present, plus one per extra unit of depth.
    n_hidden = max(architecture.depth, 1)

    # NOTE(review): stax.Dense is used with its default W_std / b_std; confirm
    # these correspond to the initialisation of the network being analysed.
    layers = []
    for _ in range(n_hidden):
        layers.extend([
            nt.stax.Dense(architecture.hidden_size),
            nt.stax.Relu()
        ])
    layers.append(nt.stax.Dense(1))

    _, _, kernel_fn = nt.stax.serial(*layers)

    # Requesting both kernels in a single call avoids repeating the shared
    # layer-by-layer recursion once for 'nngp' and once more for 'ntk'.
    wanted = ('nngp', 'ntk')
    train_kernels = kernel_fn(X_train, X_train, wanted)
    test_kernels = kernel_fn(X_test, X_train, wanted)

    return train_kernels.nngp, train_kernels.ntk, test_kernels.nngp, test_kernels.ntk

def compute_spectrum(kernel: jnp.ndarray, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the ``k`` largest eigenpairs of a symmetric kernel matrix.

    Args:
        kernel: Square symmetric matrix (anything ``np.array`` can convert).
        k: Number of leading eigenpairs wanted. Values larger than the matrix
            size are clamped to it; ``k <= 0`` yields empty results.

    Returns:
        ``(eigenvals, eigenvecs)`` as NumPy arrays: eigenvalues in descending
        order with shape ``(k,)`` and the matching eigenvectors as columns of
        an ``(n, k)`` array.
    """
    # eigh works on NumPy arrays, so leave JAX here.
    kernel_np = np.array(kernel)
    n = kernel_np.shape[0]

    k = min(int(k), n)
    if k <= 0:
        return (np.empty((0,), dtype=kernel_np.dtype),
                np.empty((n, 0), dtype=kernel_np.dtype))

    # Only the top-k pairs are needed, so ask for exactly that index range
    # instead of decomposing the whole matrix and discarding most of it.
    eigenvals, eigenvecs = eigh(kernel_np, subset_by_index=[n - k, n - 1])

    # eigh returns ascending eigenvalues; flip to descending and copy so the
    # results are ordinary contiguous arrays rather than reversed views.
    return eigenvals[::-1].copy(), eigenvecs[:, ::-1].copy()

def plot_spectrum(eigenvals_nngp: jnp.ndarray, eigenvals_ntk: jnp.ndarray, 
                 save_path: str, title: str):
    """Save a semilog-y comparison of the NNGP and NTK eigenvalue spectra.

    Each spectrum is drawn against its own index range (NNGP in blue, NTK in
    red) and the figure is written to ``save_path`` and then closed.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    curves = (
        (eigenvals_nngp, 'b-', 'NNGP'),
        (eigenvals_ntk, 'r-', 'NTK'),
    )
    for spectrum, line_fmt, legend_name in curves:
        ax.semilogy(np.arange(len(spectrum)), spectrum, line_fmt, label=legend_name)

    ax.set_xlabel('Index')
    ax.set_ylabel('Eigenvalue (log scale)')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

    fig.savefig(save_path)
    plt.close(fig)

def plot_eigenfunctions(eigenvecs: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray,
                       save_path: str, title: str, k: int):
    """Save a figure with the first ``k`` columns drawn as 1-row colour strips.

    Args:
        eigenvecs: Array whose column ``i`` holds the values of eigenfunction
            ``i`` at every data point.
        X: Unused; kept for interface compatibility.
        y: Unused; kept for interface compatibility.
        save_path: Destination image path.
        title: Figure super-title.
        k: Number of columns (one subplot each) to draw.
    """
    eigenvecs_np = np.array(eigenvecs)

    # A colorbar shared by several axes is not compatible with tight_layout
    # (matplotlib emits a UserWarning); constrained layout handles it.
    # squeeze=False keeps `axes` an array even when k == 1.
    fig, axes = plt.subplots(k, 1, figsize=(20, 0.5*k), squeeze=False,
                             constrained_layout=True)
    try:
        axes = axes[:, 0]

        # One colour scale shared by every strip so they are comparable.
        vmin = np.min(eigenvecs_np)
        vmax = np.max(eigenvecs_np)
        norm = plt.Normalize(vmin=vmin, vmax=vmax)
        sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm)

        for i, ax in enumerate(axes):
            # A single image row with one cell per data point.
            values = eigenvecs_np[:, i].reshape(1, -1)
            ax.imshow(values,
                      aspect='auto',  # stretch the row across the full width
                      cmap='viridis',
                      vmin=vmin,
                      vmax=vmax)
            ax.set_title(f'Eigenfunction {i+1}')
            ax.set_yticks([])
            ax.set_xticks([])

        fig.colorbar(sm, ax=axes, label='Eigenfunction Value')
        fig.suptitle(title)
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)

def compute_nngp_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray, train_kernel: jnp.ndarray, y_train: jnp.ndarray):
    """Push eigenvectors through ``kernel`` and weight them by their overlap with ``y_train``.

    Column ``i`` of the result equals
    ``(kernel @ eigenvecs[:, i]) * (eigenvecs[:, i] . y_train)``.
    ``train_kernel`` is accepted but not used.
    """
    targets_col = y_train.reshape(-1, 1)

    # Overlap of every eigenvector with the targets, one row per eigenvector.
    overlaps = jnp.dot(eigenvecs.T, targets_col)

    # Kernel-projected eigenvectors, one column per eigenvector.
    projected = jnp.dot(kernel, eigenvecs)

    # Broadcasting the transposed overlaps scales each column separately.
    return projected * overlaps.T

def compute_ntk_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray):
    """Return ``kernel @ eigenvecs``: the eigenvectors projected through ``kernel``.

    No further scaling is applied; passing a test-vs-train kernel yields the
    projected values at the test points.
    """
    projected = jnp.dot(kernel, eigenvecs)
    return projected

def analyze_model(model_path: str, 
                 X_train: torch.Tensor, y_train: torch.Tensor,
                 X_test: torch.Tensor, y_test: torch.Tensor,
                 d: int, hidden_size: int, depth: int, mode: str,
                 k: int, output_dir: str):
    """Run the kernel-spectrum analysis and write every plot into ``output_dir``.

    Builds the NNGP and NTK kernels for the architecture described by
    ``d``/``hidden_size``/``depth``/``mode``, takes their top-``k`` eigenpairs
    on the training set, prints diagnostics, and saves the spectrum and
    eigenfunction figures. ``model_path`` is accepted but never read: no
    trained parameters are loaded.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Everything downstream works on JAX arrays.
    X_train_jax, y_train_jax, X_test_jax, y_test_jax = (
        convert_data_to_jax(tensor) for tensor in (X_train, y_train, X_test, y_test)
    )

    print("\nActual function values (first 100 training points):")
    print(y_train_jax[:100])

    # The network object is only a carrier of the architecture description;
    # it is handed to compute_kernels with no parameters (None).
    architecture = DeepNN(d, hidden_size, depth, mode)
    nngp_train, ntk_train, nngp_test, ntk_test = compute_kernels(
        X_train_jax, X_test_jax, None, architecture)

    # Leading eigenpairs of the train-vs-train kernels.
    nngp_vals, nngp_vecs = compute_spectrum(nngp_train, k)
    ntk_vals, ntk_vecs = compute_spectrum(ntk_train, k)

    # Eigenfunction values on the train and test points.
    nngp_train_func = compute_nngp_eigenfunction_values(nngp_train, nngp_vecs, nngp_train, y_train_jax)
    nngp_test_func = compute_nngp_eigenfunction_values(nngp_test, nngp_vecs, nngp_train, y_train_jax)
    ntk_train_func = compute_ntk_eigenfunction_values(ntk_train, ntk_vecs)
    ntk_test_func = compute_ntk_eigenfunction_values(ntk_test, ntk_vecs)

    # Diagnostics: first 100 training-point values of each eigenfunction.
    print("\nFirst 100 values of NNGP eigenfunctions in function space (training set):")
    for i in range(k):
        print(f"\nNNGP Eigenfunction {i+1}:")
        print(nngp_train_func[:100, i])

    print("\nFirst 100 values of NTK eigenfunctions in function space (training set):")
    for i in range(k):
        print(f"\nNTK Eigenfunction {i+1}:")
        print(ntk_train_func[:100, i])

    print("\nNNGP kernel eigenvalues:")
    print(nngp_vals)
    print("\nNTK kernel eigenvalues:")
    print(ntk_vals)

    plot_spectrum(nngp_vals, ntk_vals,
                  os.path.join(output_dir, 'kernel_spectra.png'),
                  'Kernel Spectra')

    # (values, inputs, targets, file name, figure title), in drawing order.
    figure_jobs = (
        (nngp_train_func, X_train_jax, y_train_jax,
         'nngp_eigenfunctions_train.png', 'NNGP Eigenfunctions (Training Set)'),
        (nngp_test_func, X_test_jax, y_test_jax,
         'nngp_eigenfunctions_test.png', 'NNGP Eigenfunctions (Test Set)'),
        (ntk_train_func, X_train_jax, y_train_jax,
         'ntk_eigenfunctions_train.png', 'NTK Eigenfunctions (Training Set)'),
        (ntk_test_func, X_test_jax, y_test_jax,
         'ntk_eigenfunctions_test.png', 'NTK Eigenfunctions (Test Set)'),
    )
    for func_values, inputs, targets, file_name, figure_title in figure_jobs:
        plot_eigenfunctions(func_values, inputs, targets,
                            os.path.join(output_dir, file_name),
                            figure_title, k)

# Script entry point: build a synthetic MSP dataset on the +/-1 hypercube and
# run the NNGP/NTK spectrum analysis on it.
if __name__ == "__main__":
    # Model parameters
    P = 8  # passed to MSPFunction and generate_datasets
    d = 30  # input dimension
    hidden_size = 500
    depth = 1
    mode = 'mup_pennington'
    k = 10  # number of eigenvalues/eigenfunctions to analyze
    
    # Dataset parameters
    n_train = 8000
    n_test = 8000
    
    # Define MSP sets: each set lists the input coordinates multiplied
    # together to form one term of the target function
    msp_sets = [{7}, {2,7}, {0,2,7}, {5,7,4}, {1}, {0,4}, {3,7}, {0,1,2,3,4,6,7}]
    
    # Initialize MSP function
    msp = MSPFunction(P, msp_sets)
    
    # Generate datasets
    X_train, y_train, X_test, y_test = generate_datasets(P, d, n_train, n_test, msp)
    
    # Model path and output directory
    # NOTE(review): analyze_model as defined in this cell never reads
    # model_path, so the checkpoint named here is not loaded - confirm intended.
    model_path = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1612_nogrokk_mup_pennington/initial_model_h500_d1_n20000_lr0.05_mup_pennington_20241219_163433_rank8.pt"
    output_dir = "ntk_analysis_results_init"
    
    # Run analysis
    analyze_model(model_path, X_train, y_train, X_test, y_test,
                 d, hidden_size, depth, mode, k, output_dir)


Actual function values (first 100 training points):
[ 0.  0.  4.  2.  2.  2.  2. -2. -2.  2.  0.  0.  2.  4.  4. -4.  0.  2.
 -2. -2.  4.  2. -6. -2. -2. -2. -2.  2.  4.  0.  0. -2.  4. -6.  4.  6.
 -6.  2.  2. -2.  6.  2. -2.  4.  0.  4.  2. -2. -2.  0.  0.  0. -4. -4.
 -2. -2.  2. -2.  0.  2.  2.  0.  0.  0.  4. -2. -2.  4.  0. -6.  0. -2.
 -4. -4. -6.  4.  0.  2.  4.  0.  0.  2.  0.  0. -2.  2.  2. -4.  0.  4.
 -2.  2. -2. -4.  0.  2.  4.  0. -4.  2.]

First 100 values of NNGP eigenfunctions in function space (training set):

NNGP Eigenfunction 1:
[-57.547653 -57.385853 -57.32267  -57.295376 -57.245842 -57.338284
 -57.532696 -57.59592  -57.843735 -57.495667 -57.524254 -57.76261
 -57.545918 -57.3961   -57.424057 -57.569164 -57.558372 -57.21671
 -57.682575 -57.50311  -57.557114 -57.54716  -57.538612 -57.656208
 -57.305386 -57.86829  -57.298805 -57.654217 -57.569412 -57.5054
 -57.22424  -57.68526  -57.264507 -57.3798   -57.432182 -57.39058
 -57.104225 -57.7068   -57.563557 -57.519077 

  plt.tight_layout()


In [23]:
import torch
import neural_tangents as nt
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Set, Tuple
import torch.nn as nn
import torch.optim as optim
import os
from scipy.linalg import eigh
from functools import partial

class MSPFunction:
    """Sum of monomials over coordinate subsets.

    ``evaluate`` returns, for every row of its input, the sum over all stored
    index sets of the product of the row's entries at those indices.
    """

    def __init__(self, P: int, sets: List[Set[int]]):
        # P is stored but not consulted by evaluate().
        self.P = P
        self.sets = sets
    
    def evaluate(self, z: torch.Tensor) -> torch.Tensor:
        """Evaluate the function on a batch ``z`` indexed as ``z[:, idx]``.

        Returns a 1-D tensor with one value per row, on the same device as
        ``z``. An empty index set contributes a constant 1.
        """
        n_rows = z.shape[0]
        alloc = dict(dtype=torch.float32, device=z.device)

        total = torch.zeros(n_rows, **alloc)
        for subset in self.sets:
            # Start from 1 and multiply in each selected coordinate.
            monomial = torch.ones(n_rows, **alloc)
            for coord in subset:
                monomial = monomial * z[:, coord]
            total = total + monomial

        return total

class DeepNN(nn.Module):
    """Fully connected ReLU network with a scalar output.

    The network is ``depth`` blocks of ``Linear -> ReLU`` followed by a final
    ``Linear(hidden, 1)``. In ``'mup_pennington'`` mode every weight matrix is
    drawn from ``N(0, 1/fan_in)``, biases are zeroed, and a per-layer
    learning-rate multiplier is recorded in ``layer_lrs`` (1 for the first
    layer, ``1/fan_in`` for all later layers including the readout). In any
    other mode the hidden layers keep PyTorch's default initialisation and
    ``layer_lrs`` stays empty; the readout bias is zeroed in every mode.
    """

    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'mup_pennington'):
        super().__init__()
        
        # NOTE: this changes the process-wide default dtype, not just this model.
        torch.set_default_dtype(torch.float32)
        
        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d
        
        layers = []
        prev_dim = d
        self.layer_lrs = []  # Store layerwise learning rates
        
        for layer_idx in range(depth):
            linear = nn.Linear(prev_dim, hidden_size)
            
            if mode == 'mup_pennington':
                # Same init scale for every layer; only the LR multiplier
                # differs between the first (embedding) layer and the rest.
                std = 1.0 / np.sqrt(prev_dim)
                if layer_idx == 0:  # Embedding layer
                    lr_scale = 1.0  # O(1) learning rate for embedding
                else:  # Hidden layers
                    lr_scale = 1.0 / prev_dim  # O(1/n) learning rate for hidden
                nn.init.normal_(linear.weight, mean=0.0, std=std)
                nn.init.zeros_(linear.bias)
                self.layer_lrs.append(lr_scale)
            
            layers.extend([
                linear,
                nn.ReLU()
            ])
            prev_dim = hidden_size
        
        # Final layer
        final_layer = nn.Linear(prev_dim, 1)
        if mode == 'mup_pennington':
            std = 1.0 / np.sqrt(prev_dim)
            lr_scale = 1.0 / prev_dim  # O(1/n) learning rate for readout
            nn.init.normal_(final_layer.weight, std=std)
            self.layer_lrs.append(lr_scale)
            
        nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)
        
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one scalar per input row, shape ``(batch,)``."""
        # Drop only the trailing size-1 output dimension. A bare squeeze()
        # would also collapse the batch dimension for a single-sample batch
        # and hand back a 0-d tensor.
        return self.network(x).squeeze(-1)

def generate_datasets(P: int, d: int, n_train: int, n_test: int, msp: MSPFunction, seed: int = 42):
    """Draw seeded train/test inputs with +/-1 entries and label them with ``msp``.

    Inputs are sampled on the CPU after seeding torch's global RNG with
    ``seed``, then moved to CUDA when it is available. ``P`` is not used.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.
    """
    torch.manual_seed(seed)
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _sample_signs(n_rows: int) -> torch.Tensor:
        # Bernoulli(0.5) in {0, 1} mapped to {-1, +1}.
        coin_probs = 0.5 * torch.ones((n_rows, d), dtype=torch.float32)
        return (2 * torch.bernoulli(coin_probs) - 1).to(target)

    # Train is drawn before test so the random stream is consumed in a fixed order.
    X_train = _sample_signs(n_train)
    y_train = msp.evaluate(X_train)

    X_test = _sample_signs(n_test)
    y_test = msp.evaluate(X_test)

    return X_train, y_train, X_test, y_test

def convert_data_to_jax(data: torch.Tensor) -> jnp.ndarray:
    """Return a JAX array holding a copy of the values in ``data``.

    The tensor is detached from autograd and brought to the CPU before it is
    handed to NumPy, so tensors that require grad or live on another device
    are accepted.
    """
    host_values = data.detach().cpu().numpy()
    return jnp.array(host_values)

def compute_kernels(X_train: jnp.ndarray, X_test: jnp.ndarray, params: List[jnp.ndarray], 
                   architecture: nn.Module) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute NNGP and NTK kernels on the train and test inputs.

    A neural_tangents ``stax`` network is assembled as ``Dense -> Relu``
    blocks followed by ``Dense(1)``; only its ``kernel_fn`` is used, so no
    network parameters are involved.

    Args:
        X_train: Training inputs.
        X_test: Test inputs.
        params: Unused; kept so existing call sites keep working.
        architecture: Object exposing ``hidden_size`` and ``depth``.

    Returns:
        ``(nngp_train, ntk_train, nngp_test, ntk_test)``. The ``*_train``
        kernels are train-vs-train, the ``*_test`` kernels are test-vs-train.
    """
    # One hidden block is always present, plus one per extra unit of depth.
    n_hidden = max(architecture.depth, 1)

    # NOTE(review): stax.Dense is used with its default W_std / b_std; confirm
    # these correspond to the initialisation of the network being analysed.
    layers = []
    for _ in range(n_hidden):
        layers.extend([
            nt.stax.Dense(architecture.hidden_size),
            nt.stax.Relu()
        ])
    layers.append(nt.stax.Dense(1))

    _, _, kernel_fn = nt.stax.serial(*layers)

    # Requesting both kernels in a single call avoids repeating the shared
    # layer-by-layer recursion once for 'nngp' and once more for 'ntk'.
    wanted = ('nngp', 'ntk')
    train_kernels = kernel_fn(X_train, X_train, wanted)
    test_kernels = kernel_fn(X_test, X_train, wanted)

    return train_kernels.nngp, train_kernels.ntk, test_kernels.nngp, test_kernels.ntk

def compute_spectrum(kernel: jnp.ndarray, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the ``k`` largest eigenpairs of a symmetric kernel matrix.

    Args:
        kernel: Square symmetric matrix (anything ``np.array`` can convert).
        k: Number of leading eigenpairs wanted. Values larger than the matrix
            size are clamped to it; ``k <= 0`` yields empty results.

    Returns:
        ``(eigenvals, eigenvecs)`` as NumPy arrays: eigenvalues in descending
        order with shape ``(k,)`` and the matching eigenvectors as columns of
        an ``(n, k)`` array.
    """
    # eigh works on NumPy arrays, so leave JAX here.
    kernel_np = np.array(kernel)
    n = kernel_np.shape[0]

    k = min(int(k), n)
    if k <= 0:
        return (np.empty((0,), dtype=kernel_np.dtype),
                np.empty((n, 0), dtype=kernel_np.dtype))

    # Only the top-k pairs are needed, so ask for exactly that index range
    # instead of decomposing the whole matrix and discarding most of it.
    eigenvals, eigenvecs = eigh(kernel_np, subset_by_index=[n - k, n - 1])

    # eigh returns ascending eigenvalues; flip to descending and copy so the
    # results are ordinary contiguous arrays rather than reversed views.
    return eigenvals[::-1].copy(), eigenvecs[:, ::-1].copy()

def plot_spectrum(eigenvals_nngp: jnp.ndarray, eigenvals_ntk: jnp.ndarray, 
                 save_path: str, title: str):
    """Save a semilog-y comparison of the NNGP and NTK eigenvalue spectra.

    Each spectrum is drawn against its own index range (NNGP in blue, NTK in
    red) and the figure is written to ``save_path`` and then closed.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    curves = (
        (eigenvals_nngp, 'b-', 'NNGP'),
        (eigenvals_ntk, 'r-', 'NTK'),
    )
    for spectrum, line_fmt, legend_name in curves:
        ax.semilogy(np.arange(len(spectrum)), spectrum, line_fmt, label=legend_name)

    ax.set_xlabel('Index')
    ax.set_ylabel('Eigenvalue (log scale)')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

    fig.savefig(save_path)
    plt.close(fig)

def binary_array_to_int(arr):
    """Read the sign pattern of ``arr`` as a big-endian bit string.

    Entries ``> 0`` become bit 1 and all other entries bit 0; the first entry
    is the most significant bit. An empty input yields 0 instead of raising
    ``ValueError`` from parsing an empty string.

    Args:
        arr: Iterable of numbers.

    Returns:
        The non-negative Python ``int`` encoded by the bits.
    """
    value = 0
    for x in arr:
        # Shift in one bit per entry; Python ints do not overflow.
        value = (value << 1) | (1 if x > 0 else 0)
    return value

def get_gray_code_ordering(X: np.ndarray) -> np.ndarray:
    """Return row indices of ``X`` sorted by the Gray code of each row's sign pattern.

    Every row is thresholded at zero, read as a binary number via
    ``binary_array_to_int``, converted to its Gray code (``v ^ (v >> 1)``),
    and the rows are then argsorted by that key.
    """
    # Positive entries become 1, everything else 0.
    sign_bits = (X > 0).astype(int)

    row_values = (binary_array_to_int(row) for row in sign_bits)
    gray_keys = [value ^ (value >> 1) for value in row_values]

    return np.argsort(gray_keys)

def plot_eigenfunctions(eigenfuncs: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray,
                       save_path: str, title: str, k: int):
    """Save a figure with the first ``k`` columns drawn as Gray-code-ordered colour strips.

    Args:
        eigenfuncs: Array whose column ``i`` holds the values of eigenfunction
            ``i`` at every data point.
        X: Inputs for the same data points; used only to derive the display
            order through ``get_gray_code_ordering``.
        y: Unused; kept for interface compatibility.
        save_path: Destination image path.
        title: Figure super-title.
        k: Number of columns (one subplot each) to draw.
    """
    eigenfuncs_np = np.array(eigenfuncs)

    # Permutation that lays the data points out in Gray-code order.
    order = get_gray_code_ordering(np.array(X))

    # A colorbar shared by several axes is not compatible with tight_layout
    # (matplotlib emits a UserWarning); constrained layout handles it.
    # squeeze=False keeps `axes` an array even when k == 1.
    fig, axes = plt.subplots(k, 1, figsize=(15, 1.5*k), squeeze=False,
                             constrained_layout=True)
    try:
        axes = axes[:, 0]

        # One colour scale shared by every strip so they are comparable.
        vmin = np.min(eigenfuncs_np)
        vmax = np.max(eigenfuncs_np)
        norm = plt.Normalize(vmin=vmin, vmax=vmax)
        sm = plt.cm.ScalarMappable(cmap='rainbow', norm=norm)

        for i, ax in enumerate(axes):
            # Reorder the column, then show it as a single image row.
            values = eigenfuncs_np[order, i].reshape(1, -1)
            ax.imshow(values,
                      aspect=0.1,  # Reduced aspect ratio to make squares narrower
                      cmap='rainbow',
                      vmin=vmin,
                      vmax=vmax)
            ax.set_title(f'Eigenfunction {i+1}')
            ax.set_yticks([])
            ax.set_xticks([])

        fig.colorbar(sm, ax=axes, label='Eigenfunction Value')
        fig.suptitle(title)
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)

def compute_nngp_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray, train_kernel: jnp.ndarray, y_train: jnp.ndarray):
    """Push eigenvectors through ``kernel`` and weight them by their overlap with ``y_train``.

    Column ``i`` of the result equals
    ``(kernel @ eigenvecs[:, i]) * (eigenvecs[:, i] . y_train)``.
    ``train_kernel`` is accepted but not used.
    """
    targets_col = y_train.reshape(-1, 1)

    # Overlap of every eigenvector with the targets, one row per eigenvector.
    overlaps = jnp.dot(eigenvecs.T, targets_col)

    # Kernel-projected eigenvectors, one column per eigenvector.
    projected = jnp.dot(kernel, eigenvecs)

    # Broadcasting the transposed overlaps scales each column separately.
    return projected * overlaps.T

def compute_ntk_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray):
    """Return ``kernel @ eigenvecs``: the eigenvectors carried through ``kernel``.

    No label weighting or eigenvalue scaling is applied; the caller decides
    which kernel (train-train or test-train) to pass in.
    """
    return jnp.dot(kernel, eigenvecs)

def analyze_model(model_path: str,
                 X_train: torch.Tensor, y_train: torch.Tensor,
                 X_test: torch.Tensor, y_test: torch.Tensor,
                 d: int, hidden_size: int, depth: int, mode: str,
                 k: int, output_dir: str):
    """Run the NNGP/NTK spectral analysis end to end: kernels, spectra, printouts, plots.

    ``model_path`` is accepted but never read: the kernels are built from the
    architecture settings alone and no trained parameters are loaded.
    Every figure is written as a PNG inside ``output_dir``, which is created
    if it does not exist yet.
    """
    os.makedirs(output_dir, exist_ok=True)

    def out_file(name):
        return os.path.join(output_dir, name)

    # Move all four tensors into JAX before touching the kernel code.
    X_train_jax, y_train_jax, X_test_jax, y_test_jax = (
        convert_data_to_jax(tensor) for tensor in (X_train, y_train, X_test, y_test)
    )

    print("\nActual function values (first 100 training points):")
    print(y_train_jax[:100])

    # The architecture object is passed to compute_kernels together with
    # ``None`` in place of parameters.
    architecture = DeepNN(d, hidden_size, depth, mode)
    nngp_train, ntk_train, nngp_test, ntk_test = compute_kernels(
        X_train_jax, X_test_jax, None, architecture)

    # Spectra come from the train-train kernels only.
    nngp_vals, nngp_vecs = compute_spectrum(nngp_train, k)
    ntk_vals, ntk_vecs = compute_spectrum(ntk_train, k)

    # Evaluate the eigenvectors on the training points, then on the test points.
    nngp_train_func, nngp_test_func = (
        compute_nngp_eigenfunction_values(kern, nngp_vecs, nngp_train, y_train_jax)
        for kern in (nngp_train, nngp_test)
    )
    ntk_train_func, ntk_test_func = (
        compute_ntk_eigenfunction_values(kern, ntk_vecs)
        for kern in (ntk_train, ntk_test)
    )

    print("\nFirst 100 values of NNGP eigenfunctions in function space (training set):")
    for col in range(k):
        print(f"\nNNGP Eigenfunction {col+1}:")
        print(nngp_train_func[:100, col])

    print("\nFirst 100 values of NTK eigenfunctions in function space (training set):")
    for col in range(k):
        print(f"\nNTK Eigenfunction {col+1}:")
        print(ntk_train_func[:100, col])

    print("\nNNGP kernel eigenvalues:")
    print(nngp_vals)
    print("\nNTK kernel eigenvalues:")
    print(ntk_vals)

    plot_spectrum(nngp_vals, ntk_vals, out_file('kernel_spectra.png'), 'Kernel Spectra')

    # (values, inputs, labels, file name, figure title) for each eigenfunction figure.
    figures = (
        (nngp_train_func, X_train_jax, y_train_jax,
         'nngp_eigenfunctions_train.png', 'NNGP Eigenfunctions (Training Set)'),
        (nngp_test_func, X_test_jax, y_test_jax,
         'nngp_eigenfunctions_test.png', 'NNGP Eigenfunctions (Test Set)'),
        (ntk_train_func, X_train_jax, y_train_jax,
         'ntk_eigenfunctions_train.png', 'NTK Eigenfunctions (Training Set)'),
        (ntk_test_func, X_test_jax, y_test_jax,
         'ntk_eigenfunctions_test.png', 'NTK Eigenfunctions (Test Set)'),
    )
    for funcs, inputs, labels, filename, heading in figures:
        plot_eigenfunctions(funcs, inputs, labels, out_file(filename), heading, k)

if __name__ == "__main__":
    # Problem size and network shape.
    P, d = 8, 30
    hidden_size, depth = 500, 1
    mode = 'mup_pennington'
    k = 10  # how many leading eigenvalues/eigenfunctions to inspect

    # Sample counts for the two splits.
    n_train, n_test = 800, 700

    # Index sets handed to MSPFunction, one set per term.
    msp_sets = [{7}, {2, 7}, {0, 2, 7}, {5, 7, 4}, {1}, {0, 4}, {3, 7}, {0, 1, 2, 3, 4, 6, 7}]
    msp = MSPFunction(P, msp_sets)

    X_train, y_train, X_test, y_test = generate_datasets(P, d, n_train, n_test, msp)

    # Checkpoint path passed through to analyze_model, and the folder for its figures.
    model_path = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1612_nogrokk_mup_pennington/final_model_h500_d1_n20000_lr0.05_mup_pennington_20241219_163433_rank8.pt"
    output_dir = "ntk_analysis_results"

    analyze_model(model_path, X_train, y_train, X_test, y_test,
                  d, hidden_size, depth, mode, k, output_dir)


Actual function values (first 100 training points):
[ 0.  0.  4.  2.  2.  2.  2. -2. -2.  2.  0.  0.  2.  4.  4. -4.  0.  2.
 -2. -2.  4.  2. -6. -2. -2. -2. -2.  2.  4.  0.  0. -2.  4. -6.  4.  6.
 -6.  2.  2. -2.  6.  2. -2.  4.  0.  4.  2. -2. -2.  0.  0.  0. -4. -4.
 -2. -2.  2. -2.  0.  2.  2.  0.  0.  0.  4. -2. -2.  4.  0. -6.  0. -2.
 -4. -4. -6.  4.  0.  2.  4.  0.  0.  2.  0.  0. -2.  2.  2. -4.  0.  4.
 -2.  2. -2. -4.  0.  2.  4.  0. -4.  2.]

First 100 values of NNGP eigenfunctions in function space (training set):

NNGP Eigenfunction 1:
[-12.92087   -12.851156  -12.693413  -12.662748  -13.010132  -12.97372
 -12.797591  -13.00122   -13.122744  -12.792023  -12.926179  -13.005964
 -13.104382  -12.627592  -12.787695  -12.816565  -13.016538  -12.606921
 -12.933316  -12.867984  -12.791466  -12.844892  -12.905329  -12.740396
 -12.669813  -13.078498  -12.636147  -13.059945  -13.17651   -12.750199
 -12.47125   -13.0295925 -12.9372635 -12.906089  -12.996647  -12.926755
 -12.65088 

  plt.tight_layout()


In [22]:
import torch
import neural_tangents as nt
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Set, Tuple
import torch.nn as nn
import torch.optim as optim
import os
from scipy.linalg import eigh
from functools import partial

class MSPFunction:
    """Sum of monomials over selected input columns.

    Each entry of ``sets`` names the column indices whose values are
    multiplied together; ``evaluate`` adds those products up row by row.
    """

    def __init__(self, P: int, sets: List[Set[int]]):
        # P is stored for reference only; evaluate() does not consult it.
        self.sets = sets
        self.P = P

    def _monomial(self, z: torch.Tensor, indices: Set[int]) -> torch.Tensor:
        """Row-wise product of the columns of ``z`` listed in ``indices``."""
        product = torch.ones(z.shape[0], dtype=torch.float32, device=z.device)
        for column in indices:
            product = product * z[:, column]
        return product

    def evaluate(self, z: torch.Tensor) -> torch.Tensor:
        """Return one value per row of ``z``: the sum of all monomials.

        An empty index set contributes 1 to every row; with no sets at all
        the result is a float32 vector of zeros.
        """
        total = torch.zeros(z.shape[0], dtype=torch.float32, device=z.device)
        for indices in self.sets:
            total = total + self._monomial(z, indices)
        return total

class DeepNN(nn.Module):
    """Fully connected ReLU network with a single scalar output per sample.

    In ``'mup_pennington'`` mode every weight matrix is drawn from a normal
    distribution with standard deviation ``1/sqrt(fan_in)``, hidden biases are
    zeroed, and a learning-rate multiplier per linear layer is recorded in
    ``layer_lrs`` (1 for the first layer, ``1/fan_in`` for every later layer,
    readout included).  In any other mode the layers keep PyTorch's default
    initialisation and ``layer_lrs`` stays empty.  The readout bias is zeroed
    in every mode.

    Args:
        d: Input dimension.
        hidden_size: Width of every hidden layer.
        depth: Number of hidden (Linear + ReLU) blocks.
        mode: Initialisation scheme; only ``'mup_pennington'`` is special-cased.
    """

    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'mup_pennington'):
        super().__init__()

        # NOTE: this sets the process-wide default dtype, not just this module's.
        torch.set_default_dtype(torch.float32)

        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d
        self.layer_lrs = []  # Layerwise learning-rate multipliers (muP mode only)

        use_mup = mode == 'mup_pennington'
        layers = []
        prev_dim = d

        for layer_idx in range(depth):
            linear = nn.Linear(prev_dim, hidden_size)

            if use_mup:
                # Same weight scale for every layer; only the learning-rate
                # multiplier distinguishes the first layer from the rest.
                nn.init.normal_(linear.weight, mean=0.0, std=1.0 / np.sqrt(prev_dim))
                nn.init.zeros_(linear.bias)
                self.layer_lrs.append(1.0 if layer_idx == 0 else 1.0 / prev_dim)

            layers.extend([linear, nn.ReLU()])
            prev_dim = hidden_size

        # Readout layer
        final_layer = nn.Linear(prev_dim, 1)
        if use_mup:
            nn.init.normal_(final_layer.weight, std=1.0 / np.sqrt(prev_dim))
            self.layer_lrs.append(1.0 / prev_dim)

        nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network and drop the trailing size-1 output axis.

        Only the last axis is squeezed, so a batch holding a single sample
        still yields a length-1 vector instead of collapsing to a 0-d tensor.
        """
        return self.network(x).squeeze(-1)

def generate_datasets(P: int, d: int, n_train: int, n_test: int, msp: MSPFunction, seed: int = 42):
    """Draw random +/-1 inputs and label them with ``msp``.

    The torch RNG is seeded with ``seed``, the training inputs are sampled
    first and the test inputs second, and each split is labelled with
    ``msp.evaluate``.  Tensors are moved to CUDA when available.  ``P`` is
    accepted but not used here.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.
    """
    torch.manual_seed(seed)
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def draw_inputs(n_rows):
        # Fair coin flips in {0, 1}, mapped to {-1, +1}, then moved to the target device.
        coin_flips = torch.bernoulli(0.5 * torch.ones((n_rows, d), dtype=torch.float32))
        return (2 * coin_flips - 1).to(target)

    X_train = draw_inputs(n_train)
    y_train = msp.evaluate(X_train)

    X_test = draw_inputs(n_test)
    y_test = msp.evaluate(X_test)

    return X_train, y_train, X_test, y_test

def convert_data_to_jax(data: torch.Tensor) -> jnp.ndarray:
    """Copy a PyTorch tensor into a JAX array, going through NumPy on the host.

    The tensor is detached from autograd and moved to CPU first.
    """
    host_array = data.detach().cpu().numpy()
    return jnp.array(host_array)

def compute_kernels(X_train: jnp.ndarray, X_test: jnp.ndarray, params: List[jnp.ndarray],
                   architecture: nn.Module) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Evaluate NNGP and NTK kernels for a Dense/ReLU stack mirroring ``architecture``.

    Only ``architecture.hidden_size`` and ``architecture.depth`` are read;
    ``params`` is accepted but unused.  The stack always has at least one
    hidden Dense+ReLU block, plus ``depth - 1`` further blocks, followed by
    a width-1 Dense output.

    Returns:
        ``(nngp_train, ntk_train, nngp_test, ntk_test)`` where the "train"
        kernels compare ``X_train`` with itself and the "test" kernels
        compare ``X_test`` (rows) with ``X_train`` (columns).
    """
    def hidden_block():
        return [nt.stax.Dense(architecture.hidden_size), nt.stax.Relu()]

    # One mandatory hidden block, then the remaining depth - 1 blocks.
    stack = hidden_block()
    for _ in range(architecture.depth - 1):
        stack += hidden_block()
    stack.append(nt.stax.Dense(1))

    # Only the kernel function of the (init, apply, kernel) triple is needed.
    _, _, kernel_fn = nt.stax.serial(*stack)

    ntk_train = kernel_fn(X_train, X_train, 'ntk')
    nngp_train = kernel_fn(X_train, X_train, 'nngp')

    ntk_test = kernel_fn(X_test, X_train, 'ntk')
    nngp_test = kernel_fn(X_test, X_train, 'nngp')

    return nngp_train, ntk_train, nngp_test, ntk_test

def compute_spectrum(kernel: jnp.ndarray, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the ``k`` largest eigenvalues of ``kernel`` and their eigenvectors.

    The full symmetric eigendecomposition is computed with SciPy on a NumPy
    copy of ``kernel``.  Eigenvalues come back in descending order and the
    eigenvectors are the matching columns.  If ``k`` exceeds the matrix size,
    all eigenpairs are returned.
    """
    all_vals, all_vecs = eigh(np.array(kernel))

    # Positions of the k largest eigenvalues, largest first.
    top = np.argsort(all_vals)[::-1][:k]

    return all_vals[top], all_vecs[:, top]

def plot_spectrum(eigenvals_nngp: jnp.ndarray, eigenvals_ntk: jnp.ndarray,
                 save_path: str, title: str):
    """Save a log-scale line plot comparing the NNGP and NTK eigenvalue sequences.

    Eigenvalues are plotted against their position in each sequence
    (blue for NNGP, red for NTK) and the figure is written to ``save_path``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    series = (
        (eigenvals_nngp, 'b-', 'NNGP'),
        (eigenvals_ntk, 'r-', 'NTK'),
    )
    for spectrum, line_style, name in series:
        ax.semilogy(np.arange(len(spectrum)), spectrum, line_style, label=name)

    ax.set_xlabel('Index')
    ax.set_ylabel('Eigenvalue (log scale)')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

    fig.savefig(save_path)
    plt.close(fig)

def binary_array_to_int(arr):
    """Read ``arr`` as a bit string, first entry most significant, and return its value.

    An entry counts as bit 1 when it is strictly positive and bit 0 otherwise.
    """
    digits = []
    for entry in arr:
        digits.append('1' if entry > 0 else '0')
    return int(''.join(digits), 2)

def get_gray_code_ordering(X: np.ndarray) -> np.ndarray:
    """Return the row permutation that sorts ``X`` by a gray-code key.

    Each row is thresholded at zero into a bit pattern, read as an integer
    ``v``, and mapped to ``v ^ (v >> 1)``; rows are then ordered by that key.
    """
    bit_rows = (X > 0).astype(int)

    natural_codes = (binary_array_to_int(bits) for bits in bit_rows)
    gray_codes = [code ^ (code >> 1) for code in natural_codes]

    return np.argsort(gray_codes)

def plot_eigenfunctions(eigenfuncs: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray,
                       save_path: str, title: str, k: int,
                       box_width: float = 0.1,    # Width of each box in inches
                       box_height: float = 10.0):   # Height of each box in inches
    """
    Plot the first ``k`` eigenfunctions as stacked 1 x n strips of colored boxes.

    Samples run along the horizontal axis in gray-code order of their sign
    pattern.  All strips share one color scale (global min/max of
    ``eigenfuncs``) and one colorbar.  The figure is saved to ``save_path``
    at 300 dpi and then closed.

    Parameters:
        eigenfuncs: 2-D array with one row per sample and one column per
            eigenfunction; must have at least ``k`` columns.
        X: Inputs matching the rows of ``eigenfuncs``; used only for ordering.
        y: Unused; kept so existing call sites stay valid.
        save_path: Destination image file.
        title: Figure-level title.
        k: Number of eigenfunctions (strips) to draw.
        box_width: Nominal width of each colored box in inches
        box_height: Nominal height of each colored box in inches
    """
    # Convert JAX arrays to numpy arrays
    eigenfuncs_np = np.array(eigenfuncs)
    X_np = np.array(X)
    n_points = len(X_np)

    # Get gray code ordering
    order = get_gray_code_ordering(X_np)

    # Size the figure from the nominal box dimensions
    total_width = box_width * n_points + 2  # 2 extra inches for margins and colorbar
    total_height = k * (box_height + 0.5)   # 0.5 inch of spacing per row

    # Constrained layout is used instead of tight_layout(): the colorbar below
    # is attached to several axes at once, which tight_layout cannot arrange
    # and warns about.  squeeze=False keeps ``axes`` 2-D even when k == 1.
    fig, axes = plt.subplots(k, 1, figsize=(total_width, total_height),
                             squeeze=False, constrained_layout=True)
    strip_axes = list(axes[:, 0])

    # Global min/max so every strip uses the same color scale
    vmin = np.min(eigenfuncs_np)
    vmax = np.max(eigenfuncs_np)
    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    sm = plt.cm.ScalarMappable(cmap='rainbow', norm=norm)

    # imshow's ``aspect`` is the displayed height of a cell divided by its
    # width, so a box_width x box_height cell needs height / width.
    aspect = box_height / box_width

    for i, ax in enumerate(strip_axes):
        # Reorder values using gray code ordering and reshape to a single row
        values = eigenfuncs_np[order, i].reshape(1, -1)

        ax.imshow(values,
                  aspect=aspect,
                  cmap='rainbow',
                  vmin=vmin,
                  vmax=vmax,
                  interpolation='nearest')

        ax.set_title(f'Eigenfunction {i+1}')
        ax.set_yticks([])
        ax.set_xticks([])

    # Single colorbar shared by all strips
    fig.colorbar(sm, ax=strip_axes, label='Eigenfunction Value')

    fig.suptitle(title)
    fig.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close(fig)

def compute_nngp_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray, train_kernel: jnp.ndarray, y_train: jnp.ndarray):
    """Map eigenvectors through ``kernel`` and scale each column by its label overlap.

    Args:
        kernel: Matrix with as many columns as ``eigenvecs`` has rows.
        eigenvecs: Eigenvectors, one per column.
        train_kernel: Unused; present only so the call signature stays stable.
        y_train: Labels, reshaped internally to a column vector.

    Returns:
        Array shaped like ``kernel @ eigenvecs`` whose column ``j`` has been
        multiplied by the inner product of ``eigenvecs[:, j]`` with ``y_train``.
    """
    label_column = y_train.reshape(-1, 1)

    # Row vector holding, for each eigenvector, its inner product with the labels.
    overlaps = jnp.dot(eigenvecs.T, label_column).T

    carried = jnp.dot(kernel, eigenvecs)
    return carried * overlaps

def compute_ntk_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray):
    """Return the matrix product of ``kernel`` with the column eigenvectors.

    Passing a test-train kernel evaluates the eigenvectors at the test
    points; nothing else (labels, eigenvalues) enters the result.
    """
    return jnp.dot(kernel, eigenvecs)

def analyze_model(model_path: str,
                 X_train: torch.Tensor, y_train: torch.Tensor,
                 X_test: torch.Tensor, y_test: torch.Tensor,
                 d: int, hidden_size: int, depth: int, mode: str,
                 k: int, output_dir: str):
    """Compute NNGP/NTK kernels, print their leading eigen-data, and save all plots.

    ``model_path`` is part of the signature but is never read: no trained
    parameters are loaded, and the kernels depend only on the architecture
    settings.  Figures are saved as PNG files under ``output_dir`` (created
    on demand).
    """
    os.makedirs(output_dir, exist_ok=True)

    def out_file(name):
        return os.path.join(output_dir, name)

    # Convert the four tensors to JAX arrays in one pass.
    X_train_jax, y_train_jax, X_test_jax, y_test_jax = (
        convert_data_to_jax(tensor) for tensor in (X_train, y_train, X_test, y_test)
    )

    print("\nActual function values (first 100 training points):")
    print(y_train_jax[:100])

    # compute_kernels receives the architecture object and ``None`` for parameters.
    architecture = DeepNN(d, hidden_size, depth, mode)
    nngp_train, ntk_train, nngp_test, ntk_test = compute_kernels(
        X_train_jax, X_test_jax, None, architecture)

    # Top-k eigenpairs of each train-train kernel.
    nngp_vals, nngp_vecs = compute_spectrum(nngp_train, k)
    ntk_vals, ntk_vecs = compute_spectrum(ntk_train, k)

    # Eigenfunction values on the training points first, then on the test points.
    nngp_train_func, nngp_test_func = (
        compute_nngp_eigenfunction_values(kern, nngp_vecs, nngp_train, y_train_jax)
        for kern in (nngp_train, nngp_test)
    )
    ntk_train_func, ntk_test_func = (
        compute_ntk_eigenfunction_values(kern, ntk_vecs)
        for kern in (ntk_train, ntk_test)
    )

    print("\nFirst 100 values of NNGP eigenfunctions in function space (training set):")
    for col in range(k):
        print(f"\nNNGP Eigenfunction {col+1}:")
        print(nngp_train_func[:100, col])

    print("\nFirst 100 values of NTK eigenfunctions in function space (training set):")
    for col in range(k):
        print(f"\nNTK Eigenfunction {col+1}:")
        print(ntk_train_func[:100, col])

    print("\nNNGP kernel eigenvalues:")
    print(nngp_vals)
    print("\nNTK kernel eigenvalues:")
    print(ntk_vals)

    plot_spectrum(nngp_vals, ntk_vals, out_file('kernel_spectra.png'), 'Kernel Spectra')

    # One row per eigenfunction figure: values, inputs, labels, file name, title.
    figures = (
        (nngp_train_func, X_train_jax, y_train_jax,
         'nngp_eigenfunctions_train.png', 'NNGP Eigenfunctions (Training Set)'),
        (nngp_test_func, X_test_jax, y_test_jax,
         'nngp_eigenfunctions_test.png', 'NNGP Eigenfunctions (Test Set)'),
        (ntk_train_func, X_train_jax, y_train_jax,
         'ntk_eigenfunctions_train.png', 'NTK Eigenfunctions (Training Set)'),
        (ntk_test_func, X_test_jax, y_test_jax,
         'ntk_eigenfunctions_test.png', 'NTK Eigenfunctions (Test Set)'),
    )
    for funcs, inputs, labels, filename, heading in figures:
        plot_eigenfunctions(funcs, inputs, labels, out_file(filename), heading, k)

if __name__ == "__main__":
    # Problem size and network shape.
    P, d = 8, 30
    hidden_size, depth = 500, 1
    mode = 'mup_pennington'
    k = 10  # how many leading eigenvalues/eigenfunctions to inspect

    # Sample counts for the two splits.
    n_train, n_test = 800, 700

    # Each set lists the input columns multiplied together in one term of the target.
    msp_sets = [{7}, {2, 7}, {0, 2, 7}, {5, 7, 4}, {1}, {0, 4}, {3, 7}, {0, 1, 2, 3, 4, 6, 7}]
    msp = MSPFunction(P, msp_sets)

    X_train, y_train, X_test, y_test = generate_datasets(P, d, n_train, n_test, msp)

    # Checkpoint path (analyze_model accepts it but does not read it) and figure folder.
    model_path = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1612_nogrokk_mup_pennington/final_model_h500_d1_n20000_lr0.05_mup_pennington_20241219_163433_rank8.pt"
    output_dir = "ntk_analysis_results"

    analyze_model(model_path, X_train, y_train, X_test, y_test,
                  d, hidden_size, depth, mode, k, output_dir)


Actual function values (first 100 training points):
[ 0.  0.  4.  2.  2.  2.  2. -2. -2.  2.  0.  0.  2.  4.  4. -4.  0.  2.
 -2. -2.  4.  2. -6. -2. -2. -2. -2.  2.  4.  0.  0. -2.  4. -6.  4.  6.
 -6.  2.  2. -2.  6.  2. -2.  4.  0.  4.  2. -2. -2.  0.  0.  0. -4. -4.
 -2. -2.  2. -2.  0.  2.  2.  0.  0.  0.  4. -2. -2.  4.  0. -6.  0. -2.
 -4. -4. -6.  4.  0.  2.  4.  0.  0.  2.  0.  0. -2.  2.  2. -4.  0.  4.
 -2.  2. -2. -4.  0.  2.  4.  0. -4.  2.]

First 100 values of NNGP eigenfunctions in function space (training set):

NNGP Eigenfunction 1:
[-12.92087   -12.851156  -12.693413  -12.662748  -13.010132  -12.97372
 -12.797591  -13.00122   -13.122744  -12.792023  -12.926179  -13.005964
 -13.104382  -12.627592  -12.787695  -12.816565  -13.016538  -12.606921
 -12.933316  -12.867984  -12.791466  -12.844892  -12.905329  -12.740396
 -12.669813  -13.078498  -12.636147  -13.059945  -13.17651   -12.750199
 -12.47125   -13.0295925 -12.9372635 -12.906089  -12.996647  -12.926755
 -12.65088 

  plt.tight_layout()


In [5]:
import torch
import neural_tangents as nt
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Set, Tuple
import torch.nn as nn
import torch.optim as optim
import os
from scipy.linalg import eigh
from functools import partial

class MSPFunction:
    """Sum of monomials over selected input columns.

    Each entry of ``sets`` names the column indices whose values are
    multiplied together; ``evaluate`` adds those products up row by row.
    """

    def __init__(self, P: int, sets: List[Set[int]]):
        # P is stored for reference only; evaluate() does not consult it.
        self.sets = sets
        self.P = P

    def _monomial(self, z: torch.Tensor, indices: Set[int]) -> torch.Tensor:
        """Row-wise product of the columns of ``z`` listed in ``indices``."""
        product = torch.ones(z.shape[0], dtype=torch.float32, device=z.device)
        for column in indices:
            product = product * z[:, column]
        return product

    def evaluate(self, z: torch.Tensor) -> torch.Tensor:
        """Return one value per row of ``z``: the sum of all monomials.

        An empty index set contributes 1 to every row; with no sets at all
        the result is a float32 vector of zeros.
        """
        total = torch.zeros(z.shape[0], dtype=torch.float32, device=z.device)
        for indices in self.sets:
            total = total + self._monomial(z, indices)
        return total

class DeepNN(nn.Module):
    """Fully connected ReLU network with a single scalar output per sample.

    In ``'mup_pennington'`` mode every weight matrix is drawn from a normal
    distribution with standard deviation ``1/sqrt(fan_in)``, hidden biases are
    zeroed, and a learning-rate multiplier per linear layer is recorded in
    ``layer_lrs`` (1 for the first layer, ``1/fan_in`` for every later layer,
    readout included).  In any other mode the layers keep PyTorch's default
    initialisation and ``layer_lrs`` stays empty.  The readout bias is zeroed
    in every mode.

    Args:
        d: Input dimension.
        hidden_size: Width of every hidden layer.
        depth: Number of hidden (Linear + ReLU) blocks.
        mode: Initialisation scheme; only ``'mup_pennington'`` is special-cased.
    """

    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'mup_pennington'):
        super().__init__()

        # NOTE: this sets the process-wide default dtype, not just this module's.
        torch.set_default_dtype(torch.float32)

        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d
        self.layer_lrs = []  # Layerwise learning-rate multipliers (muP mode only)

        use_mup = mode == 'mup_pennington'
        layers = []
        prev_dim = d

        for layer_idx in range(depth):
            linear = nn.Linear(prev_dim, hidden_size)

            if use_mup:
                # Same weight scale for every layer; only the learning-rate
                # multiplier distinguishes the first layer from the rest.
                nn.init.normal_(linear.weight, mean=0.0, std=1.0 / np.sqrt(prev_dim))
                nn.init.zeros_(linear.bias)
                self.layer_lrs.append(1.0 if layer_idx == 0 else 1.0 / prev_dim)

            layers.extend([linear, nn.ReLU()])
            prev_dim = hidden_size

        # Readout layer
        final_layer = nn.Linear(prev_dim, 1)
        if use_mup:
            nn.init.normal_(final_layer.weight, std=1.0 / np.sqrt(prev_dim))
            self.layer_lrs.append(1.0 / prev_dim)

        nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network and drop the trailing size-1 output axis.

        Only the last axis is squeezed, so a batch holding a single sample
        still yields a length-1 vector instead of collapsing to a 0-d tensor.
        """
        return self.network(x).squeeze(-1)

def convert_data_to_jax(data: torch.Tensor) -> jnp.ndarray:
    """Copy a PyTorch tensor into a JAX array, going through NumPy on the host.

    The tensor is detached from autograd and moved to CPU first.
    """
    host_array = data.detach().cpu().numpy()
    return jnp.array(host_array)

def compute_kernels(X_train: jnp.ndarray, X_test: jnp.ndarray, params: List[jnp.ndarray],
                   architecture: nn.Module) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Evaluate NNGP and NTK kernels for a Dense/ReLU stack mirroring ``architecture``.

    Only ``architecture.hidden_size`` and ``architecture.depth`` are read;
    ``params`` is accepted but unused.  The stack always has at least one
    hidden Dense+ReLU block, plus ``depth - 1`` further blocks, followed by
    a width-1 Dense output.

    Returns:
        ``(nngp_train, ntk_train, nngp_test, ntk_test)`` where the "train"
        kernels compare ``X_train`` with itself and the "test" kernels
        compare ``X_test`` (rows) with ``X_train`` (columns).
    """
    def hidden_block():
        return [nt.stax.Dense(architecture.hidden_size), nt.stax.Relu()]

    # One mandatory hidden block, then the remaining depth - 1 blocks.
    stack = hidden_block()
    for _ in range(architecture.depth - 1):
        stack += hidden_block()
    stack.append(nt.stax.Dense(1))

    # Only the kernel function of the (init, apply, kernel) triple is needed.
    _, _, kernel_fn = nt.stax.serial(*stack)

    ntk_train = kernel_fn(X_train, X_train, 'ntk')
    nngp_train = kernel_fn(X_train, X_train, 'nngp')

    ntk_test = kernel_fn(X_test, X_train, 'ntk')
    nngp_test = kernel_fn(X_test, X_train, 'nngp')

    return nngp_train, ntk_train, nngp_test, ntk_test

def compute_spectrum(kernel: jnp.ndarray, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the ``k`` largest eigenvalues of ``kernel`` and their eigenvectors.

    The full symmetric eigendecomposition is computed with SciPy on a NumPy
    copy of ``kernel``.  Eigenvalues come back in descending order and the
    eigenvectors are the matching columns.  If ``k`` exceeds the matrix size,
    all eigenpairs are returned.
    """
    all_vals, all_vecs = eigh(np.array(kernel))

    # Positions of the k largest eigenvalues, largest first.
    top = np.argsort(all_vals)[::-1][:k]

    return all_vals[top], all_vecs[:, top]

def plot_spectrum(eigenvals_nngp: jnp.ndarray, eigenvals_ntk: jnp.ndarray,
                 save_path: str, title: str):
    """Save a log-scale line plot comparing the NNGP and NTK eigenvalue sequences.

    Eigenvalues are plotted against their position in each sequence
    (blue for NNGP, red for NTK) and the figure is written to ``save_path``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    series = (
        (eigenvals_nngp, 'b-', 'NNGP'),
        (eigenvals_ntk, 'r-', 'NTK'),
    )
    for spectrum, line_style, name in series:
        ax.semilogy(np.arange(len(spectrum)), spectrum, line_style, label=name)

    ax.set_xlabel('Index')
    ax.set_ylabel('Eigenvalue (log scale)')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

    fig.savefig(save_path)
    plt.close(fig)

def plot_eigenfunctions_parallel(eigenfuncs: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray,
                             save_path: str, title: str, k: int):
    """Save ``k`` stacked bar charts, one per eigenfunction column.

    Samples are placed along the x-axis in gray-code order of ``X``.  Bar
    colors and y-limits use the global min/max of ``eigenfuncs`` so all
    panels are directly comparable; each panel gets its own colorbar.
    ``y`` is accepted but not used.
    """
    X_np = np.array(X)
    sample_positions = np.arange(len(X_np))

    # Rows rearranged into gray-code order once, up front.
    ordered = np.array(eigenfuncs)[get_gray_code_ordering(X_np)]

    fig, axes = plt.subplots(k, 1, figsize=(12, 3*k))
    if k == 1:
        axes = [axes]

    # Shared value range for colors and y-limits across panels.
    lo = np.min(ordered)
    hi = np.max(ordered)
    norm = plt.Normalize(vmin=lo, vmax=hi)

    for i, ax in enumerate(axes):
        column = ordered[:, i]

        ax.bar(sample_positions, column,
               color=plt.cm.rainbow(norm(column)),
               width=1.0)

        sm = plt.cm.ScalarMappable(cmap='rainbow', norm=norm)
        plt.colorbar(sm, ax=ax, label=f'Eigenfunction {i+1} Value')

        ax.set_title(f'Eigenfunction {i+1}')
        ax.set_ylabel('Value')
        ax.set_xlabel('Index (Gray Code Ordered)')

        # 10% headroom above and below the global range.
        ax.set_ylim(lo - 0.1 * (hi - lo), hi + 0.1 * (hi - lo))

    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close()

def generate_datasets(P: int, d: int, n_train: int, n_test: int, msp: MSPFunction, seed: int = 42):
    """Draw random +/-1 inputs and label them with ``msp``.

    The torch RNG is seeded with ``seed``, the training inputs are sampled
    first and the test inputs second, and each split is labelled with
    ``msp.evaluate``.  Tensors are moved to CUDA when available.  ``P`` is
    accepted but not used here.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.
    """
    torch.manual_seed(seed)
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def draw_inputs(n_rows):
        # Fair coin flips in {0, 1}, mapped to {-1, +1}, then moved to the target device.
        coin_flips = torch.bernoulli(0.5 * torch.ones((n_rows, d), dtype=torch.float32))
        return (2 * coin_flips - 1).to(target)

    X_train = draw_inputs(n_train)
    y_train = msp.evaluate(X_train)

    X_test = draw_inputs(n_test)
    y_test = msp.evaluate(X_test)

    return X_train, y_train, X_test, y_test

def compute_nngp_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray, train_kernel: jnp.ndarray, y_train: jnp.ndarray):
    """Return ``kernel @ eigenvecs`` with each column scaled by its label overlap.

    Column ``j`` is multiplied by the inner product of ``eigenvecs[:, j]``
    with ``y_train``.  ``train_kernel`` is accepted but never read.
    """
    label_column = y_train.reshape(-1, 1)
    # Transposed so the per-eigenvector weights broadcast across rows.
    overlaps = jnp.dot(eigenvecs.T, label_column).T
    return jnp.dot(kernel, eigenvecs) * overlaps

def compute_ntk_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray):
    """Multiply ``kernel`` by the column eigenvectors and return the product."""
    projected = jnp.dot(kernel, eigenvecs)
    return projected

if __name__ == "__main__":
    # Problem size and network shape.
    P, d = 8, 30
    hidden_size, depth = 500, 1
    mode = 'mup_pennington'
    k = 10  # how many leading eigenvalues/eigenfunctions to inspect

    # Sample counts for the two splits (kept small in this cell).
    n_train, n_test = 200, 50

    # Each set lists the input columns multiplied together in one term of the target.
    msp_sets = [{7}, {2, 7}, {0, 2, 7}, {5, 7, 4}, {1}, {0, 4}, {3, 7}, {0, 1, 2, 3, 4, 6, 7}]
    msp = MSPFunction(P, msp_sets)

    print("Generating datasets...")
    X_train, y_train, X_test, y_test = generate_datasets(P, d, n_train, n_test, msp)
    print(f"Generated datasets - X_train: {X_train.shape}, X_test: {X_test.shape}")

    # Checkpoint path handed to analyze_model, and the folder for its output.
    model_path = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1612_nogrokk_mup_pennington/final_model_h500_d1_n20000_lr0.05_mup_pennington_20241219_163433_rank8.pt"
    output_dir = "ntk_analysis_results"

    # NOTE(review): analyze_model is not defined in this cell; it presumably
    # comes from an earlier notebook cell -- confirm before running standalone.
    print("\nStarting analysis...")
    analyze_model(model_path, X_train, y_train, X_test, y_test,
                  d, hidden_size, depth, mode, k, output_dir)
    print("Analysis complete!")

Generating datasets...
Generated datasets - X_train: torch.Size([200, 30]), X_test: torch.Size([50, 30])

Starting analysis...
Starting analysis...
Created output directory: ntk_analysis_results
Converting data to JAX arrays...
Data shapes - X_train: (200, 30), X_test: (50, 30)

Actual function values (first 100 training points):
[ 0.  0.  4.  2.  2.  2.  2. -2. -2.  2.  0.  0.  2.  4.  4. -4.  0.  2.
 -2. -2.  4.  2. -6. -2. -2. -2. -2.  2.  4.  0.  0. -2.  4. -6.  4.  6.
 -6.  2.  2. -2.  6.  2. -2.  4.  0.  4.  2. -2. -2.  0.  0.  0. -4. -4.
 -2. -2.  2. -2.  0.  2.  2.  0.  0.  0.  4. -2. -2.  4.  0. -6.  0. -2.
 -4. -4. -6.  4.  0.  2.  4.  0.  0.  2.  0.  0. -2.  2.  2. -4.  0.  4.
 -2.  2. -2. -4.  0.  2.  4.  0. -4.  2.]

Initializing model architecture...

Computing kernels...
Kernel shapes - NNGP train: (200, 200), NTK train: (200, 200)

Computing eigendecomposition...
Eigenvalue shapes - NNGP: (10,), NTK: (10,)

Computing eigenfunction values...
Eigenfunction shapes - NNGP t

In [11]:
import torch
import neural_tangents as nt
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Set, Tuple
import torch.nn as nn
import torch.optim as optim
import os
from scipy.linalg import eigh
from functools import partial

class MSPFunction:
    """Sum of monomials over coordinate subsets of the input.

    Each entry of ``sets`` is a set of column indices; its monomial is the
    product of those columns. ``evaluate`` returns the sum of all monomials.
    ``P`` is stored but not used by ``evaluate``.
    """

    def __init__(self, P: int, sets: List[Set[int]]):
        self.sets = sets
        self.P = P

    def evaluate(self, z: torch.Tensor) -> torch.Tensor:
        """Evaluate the function on every row of ``z``.

        Args:
            z: 2-D tensor indexed as ``z[:, idx]`` for each index in a set.

        Returns:
            1-D tensor with one value per row of ``z``, allocated on
            ``z``'s device.
        """
        n_rows = z.shape[0]
        alloc = dict(dtype=torch.float32, device=z.device)

        total = torch.zeros(n_rows, **alloc)
        for subset in self.sets:
            # An empty subset leaves the monomial at its initial value of 1.
            monomial = torch.ones(n_rows, **alloc)
            for coord in subset:
                monomial = monomial * z[:, coord]
            total = total + monomial

        return total

class DeepNN(nn.Module):
    """Fully connected ReLU network with a single-output linear readout.

    The network is ``depth`` blocks of ``Linear -> ReLU`` followed by a
    ``Linear(., 1)`` layer.

    Args:
        d: Input dimension.
        hidden_size: Width of every hidden layer.
        depth: Number of hidden ``Linear -> ReLU`` blocks.
        mode: Initialization scheme. Only ``'mup_pennington'`` applies the
            custom initialization below and fills ``layer_lrs``; any other
            value keeps ``nn.Linear``'s default weight initialization and
            leaves ``layer_lrs`` empty.

    Attributes:
        layer_lrs: Per-layer learning-rate scales in layer order (hidden
            layers first, readout last).
        network: The assembled ``nn.Sequential``.
    """

    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'mup_pennington'):
        super().__init__()

        # NOTE: this changes the process-wide default dtype, not just this model.
        torch.set_default_dtype(torch.float32)

        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d

        use_mup = mode == 'mup_pennington'
        layers = []
        prev_dim = d
        self.layer_lrs = []  # Store layerwise learning rates

        for layer_idx in range(depth):
            linear = nn.Linear(prev_dim, hidden_size)

            if use_mup:
                # Every layer uses std = 1/sqrt(fan_in); only the learning-rate
                # scale differs: 1 for the first layer, 1/fan_in afterwards.
                nn.init.normal_(linear.weight, mean=0.0, std=1.0 / np.sqrt(prev_dim))
                nn.init.zeros_(linear.bias)
                self.layer_lrs.append(1.0 if layer_idx == 0 else 1.0 / prev_dim)

            layers.extend([linear, nn.ReLU()])
            prev_dim = hidden_size

        # Final layer
        final_layer = nn.Linear(prev_dim, 1)
        if use_mup:
            nn.init.normal_(final_layer.weight, std=1.0 / np.sqrt(prev_dim))
            self.layer_lrs.append(1.0 / prev_dim)

        # The readout bias is zeroed regardless of mode.
        nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the network output with the trailing size-1 axis removed.

        Only the last axis is squeezed, so a batch containing a single sample
        keeps its batch dimension (shape ``(1,)``) instead of collapsing to a
        0-d tensor as a bare ``squeeze()`` would.
        """
        return self.network(x).squeeze(-1)

def generate_datasets(P: int, d: int, n_train: int, n_test: int, msp: MSPFunction, seed: int = 42):
    """Sample random +/-1 inputs and label them with ``msp.evaluate``.

    The global torch RNG is seeded with ``seed``. Training inputs are drawn
    before test inputs. Tensors are moved to CUDA when it is available,
    otherwise they stay on the CPU. ``P`` is accepted but not used here.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.
    """
    torch.manual_seed(seed)
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def draw_signs(n_rows: int) -> torch.Tensor:
        # Bernoulli(0.5) samples in {0, 1} mapped to {-1, +1}.
        probs = 0.5 * torch.ones((n_rows, d), dtype=torch.float32)
        return (2 * torch.bernoulli(probs) - 1).to(target)

    X_train = draw_signs(n_train)
    y_train = msp.evaluate(X_train)

    X_test = draw_signs(n_test)
    y_test = msp.evaluate(X_test)

    return X_train, y_train, X_test, y_test

def convert_data_to_jax(data: torch.Tensor) -> jnp.ndarray:
    """Return a JAX array holding the values of a PyTorch tensor.

    The tensor is detached from autograd and copied to host memory first.
    """
    host_values = data.detach().cpu().numpy()
    return jnp.array(host_values)

def compute_kernels(X_train: jnp.ndarray, X_test: jnp.ndarray, params: List[jnp.ndarray], 
                   architecture: nn.Module) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute NNGP and NTK kernels for both train and test sets.

    A neural_tangents ``stax`` network is built with the same layer layout as
    ``architecture`` (Dense -> Relu blocks of width ``hidden_size`` followed
    by ``Dense(1)``). No parameters are passed to ``kernel_fn``, so the
    result does not depend on any trained weights.

    Args:
        X_train: Training inputs.
        X_test: Test inputs.
        params: Unused; kept so existing callers keep working.
        architecture: Object exposing ``hidden_size`` and ``depth``.

    Returns:
        ``(nngp_train, ntk_train, nngp_test, ntk_test)`` where the train
        kernels are train-vs-train and the test kernels are test-vs-train.
    """
    # One hidden block is always present; ``depth - 1`` more follow, so a
    # depth below 1 still yields a single hidden block.
    n_hidden_blocks = max(architecture.depth, 1)

    layers = []
    for _ in range(n_hidden_blocks):
        layers.extend([
            nt.stax.Dense(architecture.hidden_size),
            nt.stax.Relu()
        ])

    # Output layer
    layers.append(nt.stax.Dense(1))

    # Only the kernel function is needed; init/apply are discarded.
    _, _, kernel_fn = nt.stax.serial(*layers)

    # Request both kernels from a single evaluation per input pair instead of
    # running the kernel computation separately for 'ntk' and 'nngp'.
    train_kernels = kernel_fn(X_train, X_train, ('nngp', 'ntk'))
    test_kernels = kernel_fn(X_test, X_train, ('nngp', 'ntk'))

    return train_kernels.nngp, train_kernels.ntk, test_kernels.nngp, test_kernels.ntk

def compute_spectrum(kernel: jnp.ndarray, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the k largest eigenvalues of ``kernel`` and their eigenvectors.

    The full symmetric eigendecomposition is computed with SciPy on a NumPy
    copy of the kernel, then the leading ``k`` pairs are selected.

    Returns:
        ``(eigenvalues, eigenvectors)`` with eigenvalues in descending order
        and eigenvectors stored as the matching columns.
    """
    dense_kernel = np.array(kernel)
    all_vals, all_vecs = eigh(dense_kernel)

    # Indices of the eigenvalues from largest to smallest, truncated to k.
    top = np.argsort(all_vals)[::-1][:k]

    return all_vals[top], all_vecs[:, top]

def compute_nngp_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray, train_kernel: jnp.ndarray, y_train: jnp.ndarray):
    """Project eigenvectors through ``kernel`` and weight them by the targets.

    Column ``i`` of the result is ``kernel @ eigenvecs[:, i]`` multiplied by
    the inner product of ``eigenvecs[:, i]`` with ``y_train``.
    ``train_kernel`` is accepted but not used.
    """
    # One coefficient per eigenvector: its overlap with the training targets.
    targets_col = y_train.reshape(-1, 1)
    overlaps = jnp.dot(eigenvecs.T, targets_col)

    # Kernel-projected eigenvectors, one column each.
    projected = jnp.dot(kernel, eigenvecs)

    # Broadcasting the row of overlaps scales each column by its coefficient.
    return projected * overlaps.T

def compute_ntk_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray):
    """Return ``kernel @ eigenvecs``: each eigenvector projected through the kernel."""
    # No further scaling is applied to the projected columns.
    return jnp.dot(kernel, eigenvecs)

def plot_spectrum(eigenvals_nngp: jnp.ndarray, eigenvals_ntk: jnp.ndarray, 
                 save_path: str, title: str):
    """Save a log-scale plot of the NNGP and NTK eigenvalues to ``save_path``.

    Eigenvalues are drawn against their position in the array, NNGP in blue
    and NTK in red. The figure is closed after saving.
    """
    plt.figure(figsize=(10, 6))

    curves = (
        (eigenvals_nngp, 'b-', 'NNGP'),
        (eigenvals_ntk, 'r-', 'NTK'),
    )
    for spectrum, line_style, legend_name in curves:
        positions = np.arange(len(spectrum))
        plt.semilogy(positions, spectrum, line_style, label=legend_name)

    plt.xlabel('Index')
    plt.ylabel('Eigenvalue (log scale)')
    plt.title(title)
    plt.legend()
    plt.grid(True)

    plt.savefig(save_path)
    plt.close()

def plot_eigenfunctions(eigenfuncs: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray,
                       save_path: str, title: str, k: int):
    """Save a column of ``k`` histograms, one per column of ``eigenfuncs``.

    ``X`` and ``y`` are accepted but not used. The figure is written to
    ``save_path`` and then closed.
    """
    # squeeze=False always yields a 2-D axes grid, so k == 1 needs no special case.
    fig, axes_grid = plt.subplots(k, 1, figsize=(12, 3*k), squeeze=False)

    for number, panel in enumerate(axes_grid[:, 0], start=1):
        panel.hist(eigenfuncs[:, number - 1], bins=50, density=True)
        panel.set_title(f'Eigenfunction {number}')
        panel.set_xlabel('Value')
        panel.set_ylabel('Density')

    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def analyze_model(model_path: str, 
                 X_train: torch.Tensor, y_train: torch.Tensor,
                 X_test: torch.Tensor, y_test: torch.Tensor,
                 d: int, hidden_size: int, depth: int, mode: str,
                 k: int, output_dir: str):
    """Compute NNGP/NTK kernels, print their top-k spectra and save plots.

    A fresh ``DeepNN`` is constructed only to describe the layer layout;
    ``model_path`` is accepted but never read, so no saved weights are loaded.
    Spectrum and histogram plots are written under ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Move all data to JAX in one pass (same order as the arguments).
    X_train_jax, y_train_jax, X_test_jax, y_test_jax = (
        convert_data_to_jax(tensor)
        for tensor in (X_train, y_train, X_test, y_test)
    )

    print("\nActual function values (first 100 training points):")
    print(y_train_jax[:100])

    # The architecture object only supplies hidden_size/depth to the kernel builder.
    architecture = DeepNN(d, hidden_size, depth, mode)
    nngp_train, ntk_train, nngp_test, ntk_test = compute_kernels(X_train_jax, X_test_jax, None, architecture)

    # Leading k eigenpairs of the train-vs-train kernels.
    nngp_vals, nngp_vecs = compute_spectrum(nngp_train, k)
    ntk_vals, ntk_vecs = compute_spectrum(ntk_train, k)

    # Eigenvectors pushed through the train and test kernels.
    nngp_train_func = compute_nngp_eigenfunction_values(nngp_train, nngp_vecs, nngp_train, y_train_jax)
    nngp_test_func = compute_nngp_eigenfunction_values(nngp_test, nngp_vecs, nngp_train, y_train_jax)
    ntk_train_func = compute_ntk_eigenfunction_values(ntk_train, ntk_vecs)
    ntk_test_func = compute_ntk_eigenfunction_values(ntk_test, ntk_vecs)

    print("\nFirst 100 values of NNGP eigenfunctions in function space (training set):")
    for rank in range(1, k + 1):
        print(f"\nNNGP Eigenfunction {rank}:")
        print(nngp_train_func[:100, rank - 1])

    print("\nFirst 100 values of NTK eigenfunctions in function space (training set):")
    for rank in range(1, k + 1):
        print(f"\nNTK Eigenfunction {rank}:")
        print(ntk_train_func[:100, rank - 1])

    print("\nNNGP kernel eigenvalues:")
    print(nngp_vals)
    print("\nNTK kernel eigenvalues:")
    print(ntk_vals)

    spectra_file = os.path.join(output_dir, 'kernel_spectra.png')
    plot_spectrum(nngp_vals, ntk_vals, spectra_file, 'Kernel Spectra')

    # (values, inputs, targets, file name, figure title) for each histogram figure.
    histogram_jobs = [
        (nngp_train_func, X_train_jax, y_train_jax,
         'nngp_eigenfunctions_train.png', 'NNGP Eigenfunctions (Training Set)'),
        (nngp_test_func, X_test_jax, y_test_jax,
         'nngp_eigenfunctions_test.png', 'NNGP Eigenfunctions (Test Set)'),
        (ntk_train_func, X_train_jax, y_train_jax,
         'ntk_eigenfunctions_train.png', 'NTK Eigenfunctions (Training Set)'),
        (ntk_test_func, X_test_jax, y_test_jax,
         'ntk_eigenfunctions_test.png', 'NTK Eigenfunctions (Test Set)'),
    ]
    for values, inputs, targets, file_name, heading in histogram_jobs:
        plot_eigenfunctions(values, inputs, targets,
                            os.path.join(output_dir, file_name),
                            heading, k)

# Script entry point: configure the experiment and build the datasets.
# NOTE(review): this block appears to be cut off; it never sets a model path
# or calls analyze_model, and ends in a stray indented import.
if __name__ == "__main__":
    # Model parameters
    P = 8
    d = 30
    hidden_size = 500
    depth = 1
    mode = 'mup_pennington'
    k = 10  # number of eigenvalues/eigenfunctions to analyze
    
    # Dataset parameters
    n_train = 8000
    n_test = 8000
    
    # Define MSP sets (each set lists the input columns multiplied together)
    msp_sets = [{7}, {2,7}, {0,2,7}, {5,7,4}, {1}, {0,4}, {3,7}, {0,1,2,3,4,6,7}]
    
    # Initialize MSP function
    msp = MSPFunction(P, msp_sets)
    
    # Generate datasets
    X_train, y_train, X_test, y_test = generate_datasets(P, d, n_train, n_test, msp)
    
    # Model path and output directory
    # NOTE(review): no path or directory is assigned here; the next statement
    # looks like the first line of a pasted-in copy of the script.
    import torch
import neural_tangents as nt
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Set, Tuple
import torch.nn as nn
import torch.optim as optim
import os
from scipy.linalg import eigh
from functools import partial

class MSPFunction:
    """Sum of monomials over coordinate subsets of the input.

    Each entry of ``sets`` is a set of column indices; its monomial is the
    product of those columns. ``evaluate`` returns the sum of all monomials.
    ``P`` is stored but not used by ``evaluate``.
    """

    def __init__(self, P: int, sets: List[Set[int]]):
        self.sets = sets
        self.P = P

    def evaluate(self, z: torch.Tensor) -> torch.Tensor:
        """Evaluate the function on every row of ``z``.

        Args:
            z: 2-D tensor indexed as ``z[:, idx]`` for each index in a set.

        Returns:
            1-D tensor with one value per row of ``z``, allocated on
            ``z``'s device.
        """
        n_rows = z.shape[0]
        alloc = dict(dtype=torch.float32, device=z.device)

        total = torch.zeros(n_rows, **alloc)
        for subset in self.sets:
            # An empty subset leaves the monomial at its initial value of 1.
            monomial = torch.ones(n_rows, **alloc)
            for coord in subset:
                monomial = monomial * z[:, coord]
            total = total + monomial

        return total

class DeepNN(nn.Module):
    """Fully connected ReLU network with a single-output linear readout.

    The network is ``depth`` blocks of ``Linear -> ReLU`` followed by a
    ``Linear(., 1)`` layer.

    Args:
        d: Input dimension.
        hidden_size: Width of every hidden layer.
        depth: Number of hidden ``Linear -> ReLU`` blocks.
        mode: Initialization scheme. Only ``'mup_pennington'`` applies the
            custom initialization below and fills ``layer_lrs``; any other
            value keeps ``nn.Linear``'s default weight initialization and
            leaves ``layer_lrs`` empty.

    Attributes:
        layer_lrs: Per-layer learning-rate scales in layer order (hidden
            layers first, readout last).
        network: The assembled ``nn.Sequential``.
    """

    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'mup_pennington'):
        super().__init__()

        # NOTE: this changes the process-wide default dtype, not just this model.
        torch.set_default_dtype(torch.float32)

        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d

        use_mup = mode == 'mup_pennington'
        layers = []
        prev_dim = d
        self.layer_lrs = []  # Store layerwise learning rates

        for layer_idx in range(depth):
            linear = nn.Linear(prev_dim, hidden_size)

            if use_mup:
                # Every layer uses std = 1/sqrt(fan_in); only the learning-rate
                # scale differs: 1 for the first layer, 1/fan_in afterwards.
                nn.init.normal_(linear.weight, mean=0.0, std=1.0 / np.sqrt(prev_dim))
                nn.init.zeros_(linear.bias)
                self.layer_lrs.append(1.0 if layer_idx == 0 else 1.0 / prev_dim)

            layers.extend([linear, nn.ReLU()])
            prev_dim = hidden_size

        # Final layer
        final_layer = nn.Linear(prev_dim, 1)
        if use_mup:
            nn.init.normal_(final_layer.weight, std=1.0 / np.sqrt(prev_dim))
            self.layer_lrs.append(1.0 / prev_dim)

        # The readout bias is zeroed regardless of mode.
        nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)

        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the network output with the trailing size-1 axis removed.

        Only the last axis is squeezed, so a batch containing a single sample
        keeps its batch dimension (shape ``(1,)``) instead of collapsing to a
        0-d tensor as a bare ``squeeze()`` would.
        """
        return self.network(x).squeeze(-1)

def generate_datasets(P: int, d: int, n_train: int, n_test: int, msp: MSPFunction, seed: int = 42):
    """Sample random +/-1 inputs and label them with ``msp.evaluate``.

    The global torch RNG is seeded with ``seed``. Training inputs are drawn
    before test inputs. Tensors are moved to CUDA when it is available,
    otherwise they stay on the CPU. ``P`` is accepted but not used here.

    Returns:
        ``(X_train, y_train, X_test, y_test)``.
    """
    torch.manual_seed(seed)
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def draw_signs(n_rows: int) -> torch.Tensor:
        # Bernoulli(0.5) samples in {0, 1} mapped to {-1, +1}.
        probs = 0.5 * torch.ones((n_rows, d), dtype=torch.float32)
        return (2 * torch.bernoulli(probs) - 1).to(target)

    X_train = draw_signs(n_train)
    y_train = msp.evaluate(X_train)

    X_test = draw_signs(n_test)
    y_test = msp.evaluate(X_test)

    return X_train, y_train, X_test, y_test

def convert_data_to_jax(data: torch.Tensor) -> jnp.ndarray:
    """Return a JAX array holding the values of a PyTorch tensor.

    The tensor is detached from autograd and copied to host memory first.
    """
    host_values = data.detach().cpu().numpy()
    return jnp.array(host_values)

def compute_kernels(X_train: jnp.ndarray, X_test: jnp.ndarray, params: List[jnp.ndarray], 
                   architecture: nn.Module) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute NNGP and NTK kernels for both train and test sets.

    A neural_tangents ``stax`` network is built with the same layer layout as
    ``architecture`` (Dense -> Relu blocks of width ``hidden_size`` followed
    by ``Dense(1)``). No parameters are passed to ``kernel_fn``, so the
    result does not depend on any trained weights.

    Args:
        X_train: Training inputs.
        X_test: Test inputs.
        params: Unused; kept so existing callers keep working.
        architecture: Object exposing ``hidden_size`` and ``depth``.

    Returns:
        ``(nngp_train, ntk_train, nngp_test, ntk_test)`` where the train
        kernels are train-vs-train and the test kernels are test-vs-train.
    """
    # One hidden block is always present; ``depth - 1`` more follow, so a
    # depth below 1 still yields a single hidden block.
    n_hidden_blocks = max(architecture.depth, 1)

    layers = []
    for _ in range(n_hidden_blocks):
        layers.extend([
            nt.stax.Dense(architecture.hidden_size),
            nt.stax.Relu()
        ])

    # Output layer
    layers.append(nt.stax.Dense(1))

    # Only the kernel function is needed; init/apply are discarded.
    _, _, kernel_fn = nt.stax.serial(*layers)

    # Request both kernels from a single evaluation per input pair instead of
    # running the kernel computation separately for 'ntk' and 'nngp'.
    train_kernels = kernel_fn(X_train, X_train, ('nngp', 'ntk'))
    test_kernels = kernel_fn(X_test, X_train, ('nngp', 'ntk'))

    return train_kernels.nngp, train_kernels.ntk, test_kernels.nngp, test_kernels.ntk

def compute_spectrum(kernel: jnp.ndarray, k: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return the k largest eigenvalues of ``kernel`` and their eigenvectors.

    The full symmetric eigendecomposition is computed with SciPy on a NumPy
    copy of the kernel, then the leading ``k`` pairs are selected.

    Returns:
        ``(eigenvalues, eigenvectors)`` with eigenvalues in descending order
        and eigenvectors stored as the matching columns.
    """
    dense_kernel = np.array(kernel)
    all_vals, all_vecs = eigh(dense_kernel)

    # Indices of the eigenvalues from largest to smallest, truncated to k.
    top = np.argsort(all_vals)[::-1][:k]

    return all_vals[top], all_vecs[:, top]

def compute_nngp_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray, train_kernel: jnp.ndarray, y_train: jnp.ndarray):
    """Project eigenvectors through ``kernel`` and weight them by the targets.

    Column ``i`` of the result is ``kernel @ eigenvecs[:, i]`` multiplied by
    the inner product of ``eigenvecs[:, i]`` with ``y_train``.
    ``train_kernel`` is accepted but not used.
    """
    # One coefficient per eigenvector: its overlap with the training targets.
    targets_col = y_train.reshape(-1, 1)
    overlaps = jnp.dot(eigenvecs.T, targets_col)

    # Kernel-projected eigenvectors, one column each.
    projected = jnp.dot(kernel, eigenvecs)

    # Broadcasting the row of overlaps scales each column by its coefficient.
    return projected * overlaps.T

def compute_ntk_eigenfunction_values(kernel: jnp.ndarray, eigenvecs: jnp.ndarray):
    """Return ``kernel @ eigenvecs``: each eigenvector projected through the kernel."""
    # No further scaling is applied to the projected columns.
    return jnp.dot(kernel, eigenvecs)

def plot_spectrum(eigenvals_nngp: jnp.ndarray, eigenvals_ntk: jnp.ndarray, 
                 save_path: str, title: str):
    """Save a log-scale plot of the NNGP and NTK eigenvalues to ``save_path``.

    Eigenvalues are drawn against their position in the array, NNGP in blue
    and NTK in red. The figure is closed after saving.
    """
    plt.figure(figsize=(10, 6))

    curves = (
        (eigenvals_nngp, 'b-', 'NNGP'),
        (eigenvals_ntk, 'r-', 'NTK'),
    )
    for spectrum, line_style, legend_name in curves:
        positions = np.arange(len(spectrum))
        plt.semilogy(positions, spectrum, line_style, label=legend_name)

    plt.xlabel('Index')
    plt.ylabel('Eigenvalue (log scale)')
    plt.title(title)
    plt.legend()
    plt.grid(True)

    plt.savefig(save_path)
    plt.close()

def plot_eigenfunctions(eigenfuncs: jnp.ndarray, X: jnp.ndarray, y: jnp.ndarray,
                       save_path: str, title: str, k: int):
    """Save a column of ``k`` histograms, one per column of ``eigenfuncs``.

    ``X`` and ``y`` are accepted but not used. The figure is written to
    ``save_path`` and then closed.
    """
    # squeeze=False always yields a 2-D axes grid, so k == 1 needs no special case.
    fig, axes_grid = plt.subplots(k, 1, figsize=(12, 3*k), squeeze=False)

    for number, panel in enumerate(axes_grid[:, 0], start=1):
        panel.hist(eigenfuncs[:, number - 1], bins=50, density=True)
        panel.set_title(f'Eigenfunction {number}')
        panel.set_xlabel('Value')
        panel.set_ylabel('Density')

    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def analyze_model(model_path: str, 
                 X_train: torch.Tensor, y_train: torch.Tensor,
                 X_test: torch.Tensor, y_test: torch.Tensor,
                 d: int, hidden_size: int, depth: int, mode: str,
                 k: int, output_dir: str):
    """Compute NNGP/NTK kernels, print their top-k spectra and save plots.

    A fresh ``DeepNN`` is constructed only to describe the layer layout;
    ``model_path`` is accepted but never read, so no saved weights are loaded.
    Spectrum and histogram plots are written under ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Move all data to JAX in one pass (same order as the arguments).
    X_train_jax, y_train_jax, X_test_jax, y_test_jax = (
        convert_data_to_jax(tensor)
        for tensor in (X_train, y_train, X_test, y_test)
    )

    print("\nActual function values (first 100 training points):")
    print(y_train_jax[:100])

    # The architecture object only supplies hidden_size/depth to the kernel builder.
    architecture = DeepNN(d, hidden_size, depth, mode)
    nngp_train, ntk_train, nngp_test, ntk_test = compute_kernels(X_train_jax, X_test_jax, None, architecture)

    # Leading k eigenpairs of the train-vs-train kernels.
    nngp_vals, nngp_vecs = compute_spectrum(nngp_train, k)
    ntk_vals, ntk_vecs = compute_spectrum(ntk_train, k)

    # Eigenvectors pushed through the train and test kernels.
    nngp_train_func = compute_nngp_eigenfunction_values(nngp_train, nngp_vecs, nngp_train, y_train_jax)
    nngp_test_func = compute_nngp_eigenfunction_values(nngp_test, nngp_vecs, nngp_train, y_train_jax)
    ntk_train_func = compute_ntk_eigenfunction_values(ntk_train, ntk_vecs)
    ntk_test_func = compute_ntk_eigenfunction_values(ntk_test, ntk_vecs)

    print("\nFirst 100 values of NNGP eigenfunctions in function space (training set):")
    for rank in range(1, k + 1):
        print(f"\nNNGP Eigenfunction {rank}:")
        print(nngp_train_func[:100, rank - 1])

    print("\nFirst 100 values of NTK eigenfunctions in function space (training set):")
    for rank in range(1, k + 1):
        print(f"\nNTK Eigenfunction {rank}:")
        print(ntk_train_func[:100, rank - 1])

    print("\nNNGP kernel eigenvalues:")
    print(nngp_vals)
    print("\nNTK kernel eigenvalues:")
    print(ntk_vals)

    spectra_file = os.path.join(output_dir, 'kernel_spectra.png')
    plot_spectrum(nngp_vals, ntk_vals, spectra_file, 'Kernel Spectra')

    # (values, inputs, targets, file name, figure title) for each histogram figure.
    histogram_jobs = [
        (nngp_train_func, X_train_jax, y_train_jax,
         'nngp_eigenfunctions_train.png', 'NNGP Eigenfunctions (Training Set)'),
        (nngp_test_func, X_test_jax, y_test_jax,
         'nngp_eigenfunctions_test.png', 'NNGP Eigenfunctions (Test Set)'),
        (ntk_train_func, X_train_jax, y_train_jax,
         'ntk_eigenfunctions_train.png', 'NTK Eigenfunctions (Training Set)'),
        (ntk_test_func, X_test_jax, y_test_jax,
         'ntk_eigenfunctions_test.png', 'NTK Eigenfunctions (Test Set)'),
    ]
    for values, inputs, targets, file_name, heading in histogram_jobs:
        plot_eigenfunctions(values, inputs, targets,
                            os.path.join(output_dir, file_name),
                            heading, k)

if __name__ == "__main__":
    # Problem size and architecture settings.
    P, d = 8, 30
    hidden_size, depth = 500, 1
    mode = 'mup_pennington'
    k = 10  # how many leading eigenvalues/eigenfunctions to report

    # Train and test sets have the same number of samples.
    n_train = n_test = 8000

    # Target function: one monomial per index set.
    msp_sets = [{7}, {2,7}, {0,2,7}, {5,7,4}, {1}, {0,4}, {3,7}, {0,1,2,3,4,6,7}]
    msp = MSPFunction(P, msp_sets)

    X_train, y_train, X_test, y_test = generate_datasets(P, d, n_train, n_test, msp)

    # Two checkpoint paths are assigned in sequence; the second one wins.
    model_path = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1612_nogrokk_mup_pennington/final_model_h500_d1_n20000_lr0.05_mup_pennington_20241219_163433_rank8.pt"
    model_path = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1612_nogrokk_standard/final_model_h600_d1_n20000_lr0.005_standard_20241219_042616_rank90.pt"
    output_dir = "ntk_analysis_results_init"

    analyze_model(
        model_path=model_path,
        X_train=X_train, y_train=y_train,
        X_test=X_test, y_test=y_test,
        d=d, hidden_size=hidden_size, depth=depth, mode=mode,
        k=k, output_dir=output_dir,
    )


Actual function values (first 100 training points):
[ 0.  0.  4.  2.  2.  2.  2. -2. -2.  2.  0.  0.  2.  4.  4. -4.  0.  2.
 -2. -2.  4.  2. -6. -2. -2. -2. -2.  2.  4.  0.  0. -2.  4. -6.  4.  6.
 -6.  2.  2. -2.  6.  2. -2.  4.  0.  4.  2. -2. -2.  0.  0.  0. -4. -4.
 -2. -2.  2. -2.  0.  2.  2.  0.  0.  0.  4. -2. -2.  4.  0. -6.  0. -2.
 -4. -4. -6.  4.  0.  2.  4.  0.  0.  2.  0.  0. -2.  2.  2. -4.  0.  4.
 -2.  2. -2. -4.  0.  2.  4.  0. -4.  2.]

First 100 values of NNGP eigenfunctions in function space (training set):

NNGP Eigenfunction 1:
[-57.547653 -57.385853 -57.32267  -57.295376 -57.245842 -57.338284
 -57.532696 -57.59592  -57.843735 -57.495667 -57.524254 -57.76261
 -57.545918 -57.3961   -57.424057 -57.569164 -57.558372 -57.21671
 -57.682575 -57.50311  -57.557114 -57.54716  -57.538612 -57.656208
 -57.305386 -57.86829  -57.298805 -57.654217 -57.569412 -57.5054
 -57.22424  -57.68526  -57.264507 -57.3798   -57.432182 -57.39058
 -57.104225 -57.7068   -57.563557 -57.519077 