In [None]:
import torch
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Tuple, Dict, List, Optional

class DeepNN(torch.nn.Module):
    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'standard'):
        super().__init__()
        
        torch.set_default_dtype(torch.float32)
        
        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d
        
        layers = []
        prev_dim = d
        
        for layer_idx in range(depth):
            linear = torch.nn.Linear(prev_dim, hidden_size)
            
            if mode == 'standard':
                torch.nn.init.xavier_uniform_(linear.weight)
                torch.nn.init.zeros_(linear.bias)
            
            layers.extend([
                linear,
                torch.nn.ReLU()
            ])
            prev_dim = hidden_size
        
        final_layer = torch.nn.Linear(prev_dim, 1)
        if mode == 'standard':
            torch.nn.init.xavier_uniform_(final_layer.weight)
        torch.nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)
        
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x).squeeze()

def load_rank_results(result_files):
    """Load and combine results from multiple rank files"""
    combined_results = []
    empty_files = []
    
    for file_path in result_files:
        try:
            with open(file_path) as f:
                content = f.read().strip()
                if not content:
                    empty_files.append(file_path)
                    continue
                    
                results = json.loads(content)
                if isinstance(results, list):
                    combined_results.extend(results)
                else:
                    combined_results.append(results)
                    
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            continue
    
    if empty_files:
        print("\nThe following files were empty:")
        for f in empty_files:
            print(f"  - {f}")
            
    print(f"\nTotal results loaded: {len(combined_results)}")
    return combined_results

def load_experiment_data(results_dir: str) -> Tuple[List[Dict], Dict]:
    """Load experiment results from rank-based result files"""
    nn_files = list(Path(results_dir).glob("results*.json"))
    if not nn_files:
        raise ValueError(f"No result files found in {results_dir}")
    
    # Load results from all rank files
    nn_results = load_rank_results(nn_files)
    
    # Load hyperparameters from any rank file (they should be the same)
    hyperparams_file = next(Path(results_dir).glob("hyperparameters*.json"))
    with open(hyperparams_file, 'r') as f:
        hyperparams = json.load(f)
    
    return nn_results, hyperparams

def find_model_file(results_dir: str, model_prefix: str, rank: int) -> Optional[str]:
    """Find a model file matching the prefix and rank across different timestamps"""
    # Try both timestamps and model file patterns
    timestamps = ["20250107_150050", "20250107_151953"]
    patterns = [
        f"final_model_{model_prefix}_{{}}_rank{rank}.pt",
        f"initial_model_{model_prefix}_{{}}_rank{rank}.pt"
    ]
    
    for timestamp in timestamps:
        for pattern in patterns:
            filename = pattern.format(timestamp)
            filepath = os.path.join(results_dir, filename)
            if os.path.exists(filepath):
                return filepath
    return None

def find_data_file(results_dir: str, model_prefix: str, rank: int) -> Optional[str]:
    """Find a data file matching the prefix and rank across different timestamps"""
    timestamps = ["20241230_145424", "20241230_145221"]
    for timestamp in timestamps:
        filename = f"train_data_{model_prefix}_{timestamp}_rank{rank}.pt"
        filepath = os.path.join(results_dir, filename)
        if os.path.exists(filepath):
            return filepath
    return None



def load_model_and_data(data_dir: str, results_dir: str, result: Dict) -> Optional[Tuple[torch.nn.Module, torch.Tensor, torch.Tensor]]:
    """Load the model and its corresponding training data"""
    try:
        hyperparams_file = next(Path(results_dir).glob("hyperparameters*.json"))
        with open(hyperparams_file, 'r') as f:
            hyperparams = json.load(f)
        input_dim = hyperparams['d']
        
        hidden_size = result['hidden_size']
        depth = result['depth']
        n_train = result['n_train']
        lr = result.get('learning_rate', result.get('lr'))
        mode = result.get('mode', 'mup_pennington')  # Update default mode if needed
        shuffled = result.get('shuffled', False)
        rank = result.get('worker_rank', 0)
        
        model_prefix = f'h{hidden_size}_d{depth}_n{n_train}_lr{lr}_{mode}'
        if shuffled:
            model_prefix += '_shuffled'
        
        # Find model and data files
        model_path = find_model_file(results_dir, model_prefix, rank)
        train_dataset_path = find_data_file(results_dir, model_prefix, rank)
        
        if not model_path:
            print(f"Warning: No model file found for prefix: {model_prefix}, rank: {rank}")
            return None
            
        if not train_dataset_path:
            print(f"Warning: No dataset file found for prefix: {model_prefix}, rank: {rank}")
            return None
        
        model = DeepNN(input_dim, hidden_size, depth, mode=mode)
        model.load_state_dict(torch.load(model_path))
        
        train_data = torch.load(train_dataset_path)
        X_train = train_data['X']
        y_train = train_data['y']
        
        device = next(model.parameters()).device
        model = model.to(device)
        X_train = X_train.to(device)
        y_train = y_train.to(device)
        
        return model, X_train, y_train
        
    except Exception as e:
        print(f"Error loading model and data: {str(e)}")
        return None

def plot_low_dim_ratio(results: List[Dict], results_dir: str, threshold: float = 0.2):
    """Plot ratio of features with dimensionality below threshold vs training size"""
    train_sizes = sorted(set(r['n_train'] for r in results))
    hidden_sizes = sorted(set(r['hidden_size'] for r in results))
    
    plt.figure(figsize=(10, 6))
    
    for hidden_size in hidden_sizes:
        ratios = []
        valid_train_sizes = []
        for n_train in train_sizes:
            matching_results = [r for r in results 
                              if r['hidden_size'] == hidden_size and r['n_train'] == n_train]
            
            if matching_results:
                result = matching_results[0]
                model_data = load_model_and_data(results_dir, results_dir, result)
                
                if model_data is not None:
                    model, _, _ = model_data
                    feature_dims = analyze_feature_dimensionality(model)
                    ratio = np.mean(feature_dims < threshold)
                    ratios.append(ratio)
                    valid_train_sizes.append(n_train)
        
        if ratios:
            plt.plot(valid_train_sizes, ratios, '-o', label=f'h={hidden_size}')
    
    plt.xscale('log')
    plt.xlabel('Training Size')
    plt.ylabel(f'Ratio of Features with Dim < {threshold}')
    plt.title('Low-Dimensional Feature Ratio vs Training Size')
    plt.grid(True, alpha=0.3)
    plt.legend()
    
    return plt.gcf()

def create_feature_dim_histograms(results: List[Dict], results_dir: str):
    """Create histograms of feature dimensionality for different training sizes"""
    train_sizes = sorted(set(r['n_train'] for r in results))
    hidden_sizes = sorted(set(r['hidden_size'] for r in results))
    
    n_rows = len(hidden_sizes)
    n_cols = min(6, len(train_sizes))  # Limit number of columns for readability
    selected_train_sizes = np.logspace(np.log10(min(train_sizes)), 
                                     np.log10(max(train_sizes)), 
                                     n_cols).astype(int)
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3*n_rows))
    if n_rows == 1:
        axes = axes.reshape(1, -1)
    
    for i, hidden_size in enumerate(hidden_sizes):
        for j, n_train in enumerate(selected_train_sizes):
            ax = axes[i, j]
            
            # Find closest available training size
            closest_n_train = min(train_sizes, key=lambda x: abs(x - n_train))
            
            matching_results = [r for r in results 
                              if r['hidden_size'] == hidden_size and r['n_train'] == closest_n_train]
            
            if matching_results:
                result = matching_results[0]
                model_data = load_model_and_data(results_dir, results_dir, result)
                
                if model_data is not None:
                    model, _, _ = model_data
                    feature_dims = analyze_feature_dimensionality(model)
                    
                    ax.hist(np.log1p(feature_dims), bins=30, alpha=0.7)
                    ax.set_title(f'h={hidden_size}, n={closest_n_train}')
                else:
                    ax.text(0.5, 0.5, 'No data available', 
                           ha='center', va='center',
                           transform=ax.transAxes)
                
            if j == 0:
                ax.set_ylabel('Count')
            if i == n_rows-1:
                ax.set_xlabel('log(1 + dim)')
    
    plt.tight_layout()
    return fig

def analyze_feature_dimensionality(model: torch.nn.Module) -> np.ndarray:
    """Analyze feature dimensionality for layer 1 of the model"""
    first_layer = None
    for layer in model.network:
        if isinstance(layer, torch.nn.Linear):
            first_layer = layer
            break
    
    if first_layer is None:
        raise ValueError("No linear layer found in model")
        
    with torch.no_grad():
        features = first_layer.weight.data
        feature_norms = torch.norm(features, dim=0, keepdim=True)
        normalized_features = features / (feature_norms + 1e-8)
        feature_dots = normalized_features.T @ features
        feature_dots_squared = feature_dots ** 2
        feature_denominator = torch.sum(feature_dots_squared, dim=1)
        feature_numerator = torch.sum(features * features, dim=0)
        feature_dims = feature_numerator / (feature_denominator + 1e-8)
        
        return feature_dims.cpu().numpy()

def main():
    # Set your paths here
    nn_results_dir = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_0701_mup_lr005"
    ntk_results_path = "//mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1812_spectral/final_results_20241219_015151.json"
    output_dir = "analysis_outputs"
    os.makedirs(output_dir, exist_ok=True)
    
    # Load results
    nn_results, hyperparams = load_experiment_data(nn_results_dir)
    with open(ntk_results_path, 'r') as f:
        ntk_results = json.load(f)
    
    print(f"Loaded {len(nn_results)} NN results and {len(ntk_results)} NTK results")
    
    # Create threshold plot
    threshold_fig = create_threshold_plot(nn_results, ntk_results)
    threshold_fig.savefig(os.path.join(output_dir, 'threshold_analysis.png'), dpi=300, bbox_inches='tight')
    
    # Create low-dimensional ratio plot
    ratio_fig = plot_low_dim_ratio(nn_results, nn_results_dir)
    ratio_fig.savefig(os.path.join(output_dir, 'low_dim_ratio.png'), dpi=300, bbox_inches='tight')
    
    # Create feature dimensionality histograms
    hist_fig = create_feature_dim_histograms(nn_results, nn_results_dir)
    hist_fig.savefig(os.path.join(output_dir, 'feature_dim_histograms.png'), dpi=300, bbox_inches='tight')
    
    plt.close('all')

if __name__ == "__main__":
    main()




Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fdd021cbb20>>
Traceback (most recent call last):
  File "/mnt/users/goringn/NNs_vs_Kernels/env_dev_1/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 



Total results loaded: 866
Loaded 866 NN results and 2346 NTK results


  plt.tight_layout()
  threshold_fig.savefig(os.path.join(output_dir, 'threshold_analysis.png'), dpi=300, bbox_inches='tight')




  plt.legend()




In [None]:
import torch
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Dict, List, Optional
from pathlib import Path

class DeepNN(torch.nn.Module):
    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'standard'):
        super().__init__()
        
        torch.set_default_dtype(torch.float32)
        
        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d
        
        layers = []
        prev_dim = d
        
        for layer_idx in range(depth):
            linear = torch.nn.Linear(prev_dim, hidden_size)
            
            if mode == 'standard':
                torch.nn.init.xavier_uniform_(linear.weight)
                torch.nn.init.zeros_(linear.bias)
            
            layers.extend([
                linear,
                torch.nn.ReLU()
            ])
            prev_dim = hidden_size
        
        final_layer = torch.nn.Linear(prev_dim, 1)
        if mode == 'standard':
            torch.nn.init.xavier_uniform_(final_layer.weight)
        torch.nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)
        
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x).squeeze()

def find_model_file(results_dir: str, model_prefix: str, rank: int) -> Optional[str]:
    """Find a model file matching the prefix and rank across different timestamps"""
    # Try both timestamps and model file patterns
    timestamps = ["20250107_150050", "20250107_151953"]
    patterns = [
        f"final_model_{model_prefix}_{{}}_rank{rank}.pt",
        f"initial_model_{model_prefix}_{{}}_rank{rank}.pt"
    ]
    
    for timestamp in timestamps:
        for pattern in patterns:
            filename = pattern.format(timestamp)
            filepath = os.path.join(results_dir, filename)
            if os.path.exists(filepath):
                return filepath
    return None

def load_model(results_dir: str, result: Dict) -> Optional[torch.nn.Module]:
    """Load just the model without dataset"""
    try:
        hyperparams_file = next(Path(results_dir).glob("hyperparameters*.json"))
        with open(hyperparams_file, 'r') as f:
            hyperparams = json.load(f)
        input_dim = hyperparams['d']
        
        hidden_size = result['hidden_size']
        depth = result['depth']
        n_train = result['n_train']
        lr = result.get('learning_rate', result.get('lr'))
        mode = result.get('mode', 'mup_pennington')  # Update default mode if needed
        shuffled = result.get('shuffled', False)
        rank = result.get('worker_rank', 0)
        
        model_prefix = f'h{hidden_size}_d{depth}_n{n_train}_lr{lr}_{mode}'
        if shuffled:
            model_prefix += '_shuffled'
        
        # Find model file
        model_path = find_model_file(results_dir, model_prefix, rank)
        
        if not model_path:
            print(f"Warning: No model file found for prefix: {model_prefix}, rank: {rank}")
            return None
        
        model = DeepNN(input_dim, hidden_size, depth, mode=mode)
        model.load_state_dict(torch.load(model_path))
        
        device = next(model.parameters()).device
        model = model.to(device)
        
        return model
        
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return None

def load_experiment_data(results_dir: str) -> Tuple[List[Dict], Dict]:
    """Load experiment results from rank-based result files"""
    nn_files = list(Path(results_dir).glob("results*.json"))
    if not nn_files:
        raise ValueError(f"No result files found in {results_dir}")
    
    # Load results from all rank files
    combined_results = []
    empty_files = []
    
    for file_path in nn_files:
        try:
            with open(file_path) as f:
                content = f.read().strip()
                if not content:
                    empty_files.append(file_path)
                    continue
                    
                results = json.loads(content)
                if isinstance(results, list):
                    combined_results.extend(results)
                else:
                    combined_results.append(results)
                    
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            continue
    
    if empty_files:
        print("\nThe following files were empty:")
        for f in empty_files:
            print(f"  - {f}")
            
    print(f"\nTotal results loaded: {len(combined_results)}")
    
    # Load hyperparameters from any rank file (they should be the same)
    hyperparams_file = next(Path(results_dir).glob("hyperparameters*.json"))
    with open(hyperparams_file, 'r') as f:
        hyperparams = json.load(f)
    
    return combined_results, hyperparams

def analyze_feature_dimensionality(model: torch.nn.Module) -> np.ndarray:
    """Analyze feature dimensionality for layer 1 of the model"""
    first_layer = None
    for layer in model.network:
        if isinstance(layer, torch.nn.Linear):
            first_layer = layer
            break
    
    if first_layer is None:
        raise ValueError("No linear layer found in model")
        
    with torch.no_grad():
        features = first_layer.weight.data
        feature_norms = torch.norm(features, dim=0, keepdim=True)
        normalized_features = features / (feature_norms + 1e-8)
        feature_dots = normalized_features.T @ features
        feature_dots_squared = feature_dots ** 2
        feature_denominator = torch.sum(feature_dots_squared, dim=1)
        feature_numerator = torch.sum(features * features, dim=0)
        feature_dims = feature_numerator / (feature_denominator + 1e-8)
        
        return feature_dims.cpu().numpy()

def plot_low_dim_ratio(results: List[Dict], results_dir: str, threshold: float = 0.1):
    """Plot ratio of features with dimensionality below threshold vs training size"""
    train_sizes = sorted(set(r['n_train'] for r in results))
    hidden_sizes = sorted(set(r['hidden_size'] for r in results))
    
    plt.figure(figsize=(10, 6))
    
    for hidden_size in hidden_sizes:
        ratios = []
        valid_train_sizes = []
        for n_train in train_sizes:
            matching_results = [r for r in results 
                              if r['hidden_size'] == hidden_size and r['n_train'] == n_train]
            
            if matching_results:
                result = matching_results[0]
                model = load_model(results_dir, result)
                
                if model is not None:
                    feature_dims = analyze_feature_dimensionality(model)
                    ratio = np.mean(feature_dims < threshold)
                    ratios.append(ratio)
                    valid_train_sizes.append(n_train)
        
        if ratios:
            plt.plot(valid_train_sizes, ratios, '-o', label=f'h={hidden_size}')
    
    plt.xscale('log')
    plt.xlabel('Training Size')
    plt.ylabel(f'Ratio of Features with Dim < {threshold}')
    plt.title('Low-Dimensional Feature Ratio vs Training Size')
    plt.grid(True, alpha=0.3)
    plt.legend()
    
    return plt.gcf()

def create_feature_dim_histograms(results: List[Dict], results_dir: str):
    """Create histograms of feature dimensionality for different training sizes"""
    train_sizes = sorted(set(r['n_train'] for r in results))
    hidden_sizes = sorted(set(r['hidden_size'] for r in results))
    
    n_rows = len(hidden_sizes)
    n_cols = min(6, len(train_sizes))  # Limit number of columns for readability
    selected_train_sizes = np.logspace(np.log10(min(train_sizes)), 
                                     np.log10(max(train_sizes)), 
                                     n_cols).astype(int)
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3*n_rows))
    if n_rows == 1:
        axes = axes.reshape(1, -1)
    
    for i, hidden_size in enumerate(hidden_sizes):
        for j, n_train in enumerate(selected_train_sizes):
            ax = axes[i, j]
            
            # Find closest available training size
            closest_n_train = min(train_sizes, key=lambda x: abs(x - n_train))
            
            matching_results = [r for r in results 
                              if r['hidden_size'] == hidden_size and r['n_train'] == closest_n_train]
            
            if matching_results:
                result = matching_results[0]
                model = load_model(results_dir, result)
                
                if model is not None:
                    feature_dims = analyze_feature_dimensionality(model)
                    
                    ax.hist(np.log1p(feature_dims), bins=30, alpha=0.7)
                    ax.set_title(f'h={hidden_size}, n={closest_n_train}')
                else:
                    ax.text(0.5, 0.5, 'No data available', 
                           ha='center', va='center',
                           transform=ax.transAxes)
                
            if j == 0:
                ax.set_ylabel('Count')
            if i == n_rows-1:
                ax.set_xlabel('log(1 + dim)')
    
    plt.tight_layout()
    return fig

def create_threshold_plot(nn_results: List[Dict], ntk_results: List[Dict], performance_threshold: float = 80.0):
    """Create threshold crossing analysis plot"""
    # Extract parameters
    depths = sorted(set(r['depth'] for r in nn_results))
    learning_rates = sorted(set(r.get('learning_rate', r.get('lr')) for r in nn_results))
    hidden_sizes = sorted(set(r['hidden_size'] for r in nn_results))
    
    fig, axes = plt.subplots(1, len(learning_rates), figsize=(6*len(learning_rates), 5))
    if len(learning_rates) == 1:
        axes = [axes]
            
    fig.suptitle(f'Training Size at First Performance Threshold Crossing\n' +
                 f'(When NN error first becomes ≤ {100-performance_threshold}% of NTK error)',
                 fontsize=16, y=1.02)
    
    depth_colors = plt.cm.viridis(np.linspace(0, 1, len(depths)))
    
    for j, lr in enumerate(learning_rates):
        ax = axes[j]
        
        for depth_idx, depth in enumerate(depths):
            threshold_data = []
            
            for hidden_size in hidden_sizes:
                ntk_errors_dict = {}
                for r in ntk_results:
                    if (r['depth'] == depth and 
                        r['hidden_size'] == hidden_size and 
                        r['training_mode'] == 'ntk' and
                        r['status'] == 'success'):
                        ntk_errors_dict[r['n_train']] = r['test_error']
                
                nn_data = [(r['n_train'], r['test_error']) 
                          for r in nn_results 
                          if r['depth'] == depth and 
                          r['hidden_size'] == hidden_size and 
                          (r.get('learning_rate', r.get('lr')) == lr)]
                
                if nn_data:
                    nn_points = sorted(nn_data)
                    
                    for train_size, nn_error in nn_points:
                        ntk_error = ntk_errors_dict.get(train_size)
                        if ntk_error is not None:
                            if nn_error <= ntk_error * (100 - performance_threshold) / 100:
                                threshold_data.append((hidden_size, train_size))
                                break
            
            if threshold_data:
                hidden_widths, crossing_points = zip(*sorted(threshold_data))
                ax.plot(hidden_widths, crossing_points, '-o', linewidth=2, 
                       markersize=6, color=depth_colors[depth_idx],
                       label=f'd={depth}')
        
        ax.set_xscale('log')
        ax.set_xlabel('Hidden Width')
        if j == 0:
            ax.set_ylabel('Training Size at Threshold')
        
        ax.text(0.05, 0.95, f'lr={lr:.1e}', 
               transform=ax.transAxes, 
               verticalalignment='top',
               bbox=dict(facecolor='white', alpha=0.8))
        
        ax.grid(True, which="both", ls="-", alpha=0.2)
        ax.legend(title='Network Depth', bbox_to_anchor=(1.05, 1), loc='upper left')
    
    plt.tight_layout()
    return fig

def main():
    # Set your paths here
    nn_results_dir = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_0701_mup_lr005"
    ntk_results_path = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1812_spectral/final_results_20241219_015151.json"
    output_dir = "analysis_outputs"
    os.makedirs(output_dir, exist_ok=True)
    
    # Load results
    nn_results, hyperparams = load_experiment_data(nn_results_dir)
    with open(ntk_results_path, 'r') as f:
        ntk_results = json.load(f)
    
    print(f"Loaded {len(nn_results)} NN results and {len(ntk_results)} NTK results")
    
    # Create threshold plot
    threshold_fig = create_threshold_plot(nn_results, ntk_results)
    threshold_fig.savefig(os.path.join(output_dir, 'threshold_analysis.png'), dpi=300, bbox_inches='tight')
    
    # Create low-dimensional ratio plot
    ratio_fig = plot_low_dim_ratio(nn_results, nn_results_dir)
    ratio_fig.savefig(os.path.join(output_dir, 'low_dim_ratio.png'), dpi=300, bbox_inches='tight')
    
    # Create feature dimensionality histograms
    hist_fig = create_feature_dim_histograms(nn_results, nn_results_dir)
    hist_fig.savefig(os.path.join(output_dir, 'feature_dim_histograms.png'), dpi=300, bbox_inches='tight')
    
    plt.close('all')

if __name__ == "__main__":
    main()

KeyboardInterrupt: 

In [11]:
import torch
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Dict, List, Optional
from pathlib import Path

class DeepNN(torch.nn.Module):
    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'standard'):
        super().__init__()
        
        torch.set_default_dtype(torch.float32)
        
        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d
        
        layers = []
        prev_dim = d
        
        for layer_idx in range(depth):
            linear = torch.nn.Linear(prev_dim, hidden_size)
            
            if mode == 'standard':
                torch.nn.init.xavier_uniform_(linear.weight)
                torch.nn.init.zeros_(linear.bias)
            
            layers.extend([
                linear,
                torch.nn.ReLU()
            ])
            prev_dim = hidden_size
        
        final_layer = torch.nn.Linear(prev_dim, 1)
        if mode == 'standard':
            torch.nn.init.xavier_uniform_(final_layer.weight)
        torch.nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)
        
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x).squeeze()

def find_model_file(results_dir: str, model_prefix: str, rank: int) -> Optional[str]:
    """Find a model file matching the prefix and rank across different timestamps"""
    # Try both timestamps and model file patterns
    timestamps = ["20250107_150050", "20250107_151953"]
    patterns = [
        f"final_model_{model_prefix}_{{}}_rank{rank}.pt",
        f"initial_model_{model_prefix}_{{}}_rank{rank}.pt"
    ]
    
    for timestamp in timestamps:
        for pattern in patterns:
            filename = pattern.format(timestamp)
            filepath = os.path.join(results_dir, filename)
            if os.path.exists(filepath):
                return filepath
    return None

def load_model(results_dir: str, result: Dict) -> Optional[torch.nn.Module]:
    """Load just the model without dataset"""
    try:
        hyperparams_file = next(Path(results_dir).glob("hyperparameters*.json"))
        with open(hyperparams_file, 'r') as f:
            hyperparams = json.load(f)
        input_dim = hyperparams['d']
        
        hidden_size = result['hidden_size']
        depth = result['depth']
        n_train = result['n_train']
        lr = result.get('learning_rate', result.get('lr'))
        mode = result.get('mode', 'mup_pennington')  # Update default mode if needed
        shuffled = result.get('shuffled', False)
        rank = result.get('worker_rank', 0)
        
        model_prefix = f'h{hidden_size}_d{depth}_n{n_train}_lr{lr}_{mode}'
        if shuffled:
            model_prefix += '_shuffled'
        
        # Find model file
        model_path = find_model_file(results_dir, model_prefix, rank)
        
        if not model_path:
            print(f"Warning: No model file found for prefix: {model_prefix}, rank: {rank}")
            return None
        
        model = DeepNN(input_dim, hidden_size, depth, mode=mode)
        model.load_state_dict(torch.load(model_path))
        
        device = next(model.parameters()).device
        model = model.to(device)
        
        return model
        
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return None

def load_experiment_data(results_dir: str) -> Tuple[List[Dict], Dict]:
    """Load experiment results from rank-based result files"""
    nn_files = list(Path(results_dir).glob("results*.json"))
    if not nn_files:
        raise ValueError(f"No result files found in {results_dir}")
    
    # Load results from all rank files
    combined_results = []
    empty_files = []
    
    for file_path in nn_files:
        try:
            with open(file_path) as f:
                content = f.read().strip()
                if not content:
                    empty_files.append(file_path)
                    continue
                    
                results = json.loads(content)
                if isinstance(results, list):
                    combined_results.extend(results)
                else:
                    combined_results.append(results)
                    
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            continue
    
    if empty_files:
        print("\nThe following files were empty:")
        for f in empty_files:
            print(f"  - {f}")
            
    print(f"\nTotal results loaded: {len(combined_results)}")
    
    # Load hyperparameters from any rank file (they should be the same)
    hyperparams_file = next(Path(results_dir).glob("hyperparameters*.json"))
    with open(hyperparams_file, 'r') as f:
        hyperparams = json.load(f)
    
    return combined_results, hyperparams

def analyze_feature_dimensionality(model: torch.nn.Module) -> np.ndarray:
    """Analyze feature dimensionality for layer 1 of the model"""
    first_layer = None
    for layer in model.network:
        if isinstance(layer, torch.nn.Linear):
            first_layer = layer
            break
    
    if first_layer is None:
        raise ValueError("No linear layer found in model")
        
    with torch.no_grad():
        features = first_layer.weight.data
        feature_norms = torch.norm(features, dim=0, keepdim=True)
        normalized_features = features / (feature_norms + 1e-8)
        feature_dots = normalized_features.T @ features
        feature_dots_squared = feature_dots ** 2
        feature_denominator = torch.sum(feature_dots_squared, dim=1)
        feature_numerator = torch.sum(features * features, dim=0)
        feature_dims = feature_numerator / (feature_denominator + 1e-8)
        
        return feature_dims.cpu().numpy()

def plot_low_dim_ratio(results: List[Dict], results_dir: str, threshold: float = 0.1):
    """Plot ratio of features with dimensionality below threshold vs training size"""
    train_sizes = sorted(set(r['n_train'] for r in results))
    hidden_sizes = sorted(set(r['hidden_size'] for r in results))
    
    plt.figure(figsize=(10, 6))
    
    # Create colormap from red to blue
    n_sizes = len(hidden_sizes)
    colors = plt.cm.RdBu(np.linspace(0, 1, n_sizes))
    
    for idx, hidden_size in enumerate(hidden_sizes):
        ratios = []
        valid_train_sizes = []
        for n_train in train_sizes:
            matching_results = [r for r in results 
                              if r['hidden_size'] == hidden_size and r['n_train'] == n_train]
            
            if matching_results:
                result = matching_results[0]
                model = load_model(results_dir, result)
                
                if model is not None:
                    feature_dims = analyze_feature_dimensionality(model)
                    ratio = np.mean(feature_dims < threshold)
                    ratios.append(ratio)
                    valid_train_sizes.append(n_train)
        
        if ratios:
            plt.plot(valid_train_sizes, ratios, '-o', 
                    color=colors[idx], 
                    alpha=0.7,
                    label=f'h={hidden_size}',
                    markersize=4)
    
    plt.xscale('log')
    plt.xlabel('Training Size')
    plt.ylabel(f'Ratio of Features with Dim < {threshold}')
    plt.title('Low-Dimensional Feature Ratio vs Training Size')
    plt.grid(True, alpha=0.3)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    
    return plt.gcf()

def create_feature_dim_histograms(results: List[Dict], results_dir: str):
    """Create histograms of feature dimensionality for different training sizes"""
    train_sizes = sorted(set(r['n_train'] for r in results))
    hidden_sizes = sorted(set(r['hidden_size'] for r in results))
    
    n_rows = len(hidden_sizes)
    n_cols = min(6, len(train_sizes))  # Limit number of columns for readability
    selected_train_sizes = np.logspace(np.log10(min(train_sizes)), 
                                     np.log10(max(train_sizes)), 
                                     n_cols).astype(int)
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3*n_rows))
    if n_rows == 1:
        axes = axes.reshape(1, -1)
    
    for i, hidden_size in enumerate(hidden_sizes):
        for j, n_train in enumerate(selected_train_sizes):
            ax = axes[i, j]
            
            # Find closest available training size
            closest_n_train = min(train_sizes, key=lambda x: abs(x - n_train))
            
            matching_results = [r for r in results 
                              if r['hidden_size'] == hidden_size and r['n_train'] == closest_n_train]
            
            if matching_results:
                result = matching_results[0]
                model = load_model(results_dir, result)
                
                if model is not None:
                    feature_dims = analyze_feature_dimensionality(model)
                    
                    ax.hist(np.log1p(feature_dims), bins=30, alpha=0.7)
                    ax.set_title(f'h={hidden_size}, n={closest_n_train}')
                else:
                    ax.text(0.5, 0.5, 'No data available', 
                           ha='center', va='center',
                           transform=ax.transAxes)
                
            if j == 0:
                ax.set_ylabel('Count')
            if i == n_rows-1:
                ax.set_xlabel('log(1 + dim)')
    
    plt.tight_layout()
    return fig

def create_threshold_plot(nn_results: List[Dict], ntk_results: List[Dict], performance_threshold: float = 80.0):
    """Create threshold crossing analysis plot"""
    # Extract parameters
    depths = sorted(set(r['depth'] for r in nn_results))
    learning_rates = sorted(set(r.get('learning_rate', r.get('lr')) for r in nn_results))
    hidden_sizes = sorted(set(r['hidden_size'] for r in nn_results))
    
    fig, axes = plt.subplots(1, len(learning_rates), figsize=(6*len(learning_rates), 5))
    if len(learning_rates) == 1:
        axes = [axes]
            
    fig.suptitle(f'Training Size at First Performance Threshold Crossing\n' +
                 f'(When NN error first becomes ≤ {100-performance_threshold}% of NTK error)',
                 fontsize=16, y=1.02)
    
    depth_colors = plt.cm.viridis(np.linspace(0, 1, len(depths)))
    
    for j, lr in enumerate(learning_rates):
        ax = axes[j]
        
        for depth_idx, depth in enumerate(depths):
            threshold_data = []
            
            for hidden_size in hidden_sizes:
                ntk_errors_dict = {}
                for r in ntk_results:
                    if (r['depth'] == depth and 
                        r['hidden_size'] == hidden_size and 
                        r['training_mode'] == 'ntk' and
                        r['status'] == 'success'):
                        ntk_errors_dict[r['n_train']] = r['test_error']
                
                nn_data = [(r['n_train'], r['test_error']) 
                          for r in nn_results 
                          if r['depth'] == depth and 
                          r['hidden_size'] == hidden_size and 
                          (r.get('learning_rate', r.get('lr')) == lr)]
                
                if nn_data:
                    nn_points = sorted(nn_data)
                    
                    for train_size, nn_error in nn_points:
                        ntk_error = ntk_errors_dict.get(train_size)
                        if ntk_error is not None:
                            if nn_error <= ntk_error * (100 - performance_threshold) / 100:
                                threshold_data.append((hidden_size, train_size))
                                break
            
            if threshold_data:
                hidden_widths, crossing_points = zip(*sorted(threshold_data))
                ax.plot(hidden_widths, crossing_points, '-o', linewidth=2, 
                       markersize=6, color=depth_colors[depth_idx],
                       label=f'd={depth}')
        
        ax.set_xscale('log')
        ax.set_xlabel('Hidden Width')
        if j == 0:
            ax.set_ylabel('Training Size at Threshold')
        
        ax.text(0.05, 0.95, f'lr={lr:.1e}', 
               transform=ax.transAxes, 
               verticalalignment='top',
               bbox=dict(facecolor='white', alpha=0.8))
        
        ax.grid(True, which="both", ls="-", alpha=0.2)
        ax.legend(title='Network Depth', bbox_to_anchor=(1.05, 1), loc='upper left')
    
    plt.tight_layout()
    return fig

def main():
    # Set your paths here
    nn_results_dir = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_0701_mup_lr005"
    ntk_results_path = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1812_spectral/final_results_20241219_015151.json"
    output_dir = "analysis_outputs"
    os.makedirs(output_dir, exist_ok=True)
    
    # Load results
    nn_results, hyperparams = load_experiment_data(nn_results_dir)
    with open(ntk_results_path, 'r') as f:
        ntk_results = json.load(f)
    
    print(f"Loaded {len(nn_results)} NN results and {len(ntk_results)} NTK results")
    
    # Create threshold plot
    threshold_fig = create_threshold_plot(nn_results, ntk_results)
    threshold_fig.savefig(os.path.join(output_dir, 'threshold_analysis.png'), dpi=300, bbox_inches='tight')
    
    # Create low-dimensional ratio plot
    ratio_fig = plot_low_dim_ratio(nn_results, nn_results_dir)
    ratio_fig.savefig(os.path.join(output_dir, 'low_dim_ratio.png'), dpi=300, bbox_inches='tight')
    
    # Create feature dimensionality histograms
    hist_fig = create_feature_dim_histograms(nn_results, nn_results_dir)
    hist_fig.savefig(os.path.join(output_dir, 'feature_dim_histograms.png'), dpi=300, bbox_inches='tight')
    
    plt.close('all')

if __name__ == "__main__":
    main()


Total results loaded: 866
Loaded 866 NN results and 2346 NTK results


  model.load_state_dict(torch.load(model_path))


In [2]:
#### working msp
import torch
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from typing import Tuple, Dict, List, Optional
from pathlib import Path

def setup_science_style(font_size: float = 8):
    """Set up consistent science styling with proper log-scale ticks."""
    plt.style.use('default')
    mpl.rcParams.update({
        # Font sizes
        'font.size': font_size,
        'axes.labelsize': font_size,
        'xtick.labelsize': font_size,
        'ytick.labelsize': font_size,
        'legend.fontsize': font_size,
        'figure.figsize': (7.5, 3.9),
        'figure.dpi': 300,
        
        # Font settings
        'font.family': 'serif',
        'font.serif': ['cmr10', 'Computer Modern Serif', 'DejaVu Serif'],
        'text.usetex': False,
        'axes.formatter.use_mathtext': True,
        'mathtext.fontset': 'cm',
        
        # Axis settings
        'axes.linewidth': 0.5,
        'axes.spines.top': True,
        'axes.spines.right': True,
        'axes.spines.left': True,
        'axes.spines.bottom': True,
        
        # Tick settings
        'xtick.direction': 'in',
        'ytick.direction': 'in',
        'xtick.major.width': 0.5,
        'ytick.major.width': 0.5,
        'xtick.minor.width': 0.5,
        'ytick.minor.width': 0.5,
        'xtick.major.size': 3,
        'ytick.major.size': 3,
        'xtick.minor.size': 1.5,
        'ytick.minor.size': 1.5,
        'xtick.top': True,
        'ytick.right': True,
        
        # Grid settings
        'grid.linewidth': 0.5,
        
        # Line settings
        'lines.linewidth': 2.0,
        'lines.markersize': 3,
        
        # Legend settings
        'legend.frameon': False,
        'legend.borderpad': 0,
        'legend.borderaxespad': 1.0,
        'legend.handlelength': 1.0,
        'legend.handletextpad': 0.5,
    })

class DeepNN(torch.nn.Module):
    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'standard'):
        super().__init__()
        torch.set_default_dtype(torch.float32)
        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d
        
        layers = []
        prev_dim = d
        
        for layer_idx in range(depth):
            linear = torch.nn.Linear(prev_dim, hidden_size)
            if mode == 'standard':
                torch.nn.init.xavier_uniform_(linear.weight)
                torch.nn.init.zeros_(linear.bias)
            layers.extend([linear, torch.nn.ReLU()])
            prev_dim = hidden_size
        
        final_layer = torch.nn.Linear(prev_dim, 1)
        if mode == 'standard':
            torch.nn.init.xavier_uniform_(final_layer.weight)
        torch.nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)
        
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x).squeeze()

def load_experiment_data(results_dir: str) -> Tuple[List[Dict], Dict]:
    """Load experiment results from rank-based result files"""
    nn_files = list(Path(results_dir).glob("results*.json"))
    if not nn_files:
        raise ValueError(f"No result files found in {results_dir}")
    
    combined_results = []
    empty_files = []
    
    for file_path in nn_files:
        try:
            with open(file_path) as f:
                content = f.read().strip()
                if not content:
                    empty_files.append(file_path)
                    continue
                results = json.loads(content)
                if isinstance(results, list):
                    combined_results.extend(results)
                else:
                    combined_results.append(results)
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            continue
    
    if empty_files:
        print("\nEmpty files found:")
        for f in empty_files:
            print(f"  - {f}")
    
    print(f"\nTotal results loaded: {len(combined_results)}")
    
    hyperparams = {
        "ambient_dim": 30,
        "hidden_sizes": sorted(set(r['hidden_size'] for r in combined_results)),
        "depths": sorted(set(r['depth'] for r in combined_results))
    }
    
    for result in combined_results:
        result['ambient_dim'] = hyperparams['ambient_dim']
    
    return combined_results, hyperparams

def find_model_file(results_dir: str, model_prefix: str, rank: int) -> Optional[str]:
    """Find a model file matching the prefix and rank across different timestamps"""
    timestamps = ["20250107_150050", "20250107_151953"]
    patterns = [
        f"final_model_{model_prefix}_{{}}_rank{rank}.pt",
        f"initial_model_{model_prefix}_{{}}_rank{rank}.pt"
    ]
    
    for timestamp in timestamps:
        for pattern in patterns:
            filename = pattern.format(timestamp)
            filepath = os.path.join(results_dir, filename)
            if os.path.exists(filepath):
                return filepath
    return None

def load_model(results_dir: str, result: Dict) -> Optional[torch.nn.Module]:
    """Load just the model without dataset"""
    try:
        hidden_size = result['hidden_size']
        depth = result['depth']
        n_train = result['n_train']
        lr = result.get('learning_rate', result.get('lr'))
        mode = result.get('mode', 'mup_pennington')
        rank = result.get('worker_rank', 0)
        ambient_dim = result.get('ambient_dim', 20)
        
        model_prefix = f'h{hidden_size}_d{depth}_n{n_train}_lr{lr}_{mode}'
        model_path = find_model_file(results_dir, model_prefix, rank)
        
        if not model_path:
            print(f"Warning: No model file found for prefix: {model_prefix}, rank: {rank}")
            return None
        
        model = DeepNN(ambient_dim, hidden_size, depth)
        model.load_state_dict(torch.load(model_path))
        model = model.to(next(model.parameters()).device)
        
        return model
        
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return None

def analyze_feature_dimensionality(model: torch.nn.Module) -> np.ndarray:
    """Analyze feature dimensionality for layer 1 of the model"""
    first_layer = None
    for layer in model.network:
        if isinstance(layer, torch.nn.Linear):
            first_layer = layer
            break
    
    if first_layer is None:
        raise ValueError("No linear layer found in model")
        
    with torch.no_grad():
        features = first_layer.weight.data
        feature_norms = torch.norm(features, dim=0, keepdim=True)
        normalized_features = features / (feature_norms + 1e-8)
        feature_dots = normalized_features.T @ features
        feature_dots_squared = feature_dots ** 2
        feature_denominator = torch.sum(feature_dots_squared, dim=1)
        feature_numerator = torch.sum(features * features, dim=0)
        feature_dims = feature_numerator / (feature_denominator + 1e-8)
        
        return feature_dims.cpu().numpy()

def plot_low_dim_ratio(results: List[Dict], results_dir: str, threshold: float = 0.2, font_size: float = 8):
    """Plot ratio of features with dimensionality below threshold vs training size"""
    setup_science_style(font_size=font_size)
    
    train_sizes = sorted(set(r['n_train'] for r in results))
    hidden_sizes = sorted(set(r['hidden_size'] for r in results))
    
    print("\nFound train sizes:", train_sizes)
    print("Found hidden sizes:", hidden_sizes)
    
    # Create figure with specific size ratio
    plt.figure(figsize=(7.5, 3.9))
    
    # Create color gradient from purple to cyan
    num_sizes = len(hidden_sizes)
    colors = []
    start_rgb = np.array([238, 0, 255])   # #ee00ff (purple)
    end_rgb = np.array([0, 251, 255])     # #00fbff (cyan)
    
    for i in range(num_sizes):
        t = i / max(1, num_sizes - 1)
        rgb = start_rgb * (1 - t) + end_rgb * t
        rgb = rgb.astype(int)
        colors.append(f'#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}')
    
    for idx, hidden_size in enumerate(hidden_sizes):
        ratios = []
        valid_train_sizes = []
        
        for n_train in train_sizes:
            matching_results = [r for r in results 
                              if r['hidden_size'] == hidden_size and r['n_train'] == n_train]
            
            if matching_results:
                result = matching_results[0]
                model = load_model(results_dir, result)
                
                if model is None:
                    print(f"Could not load model for width={hidden_size}, n_train={n_train}")
                    continue
                    
                feature_dims = analyze_feature_dimensionality(model)
                ratio = np.mean(feature_dims < threshold)
                ratios.append(ratio)
                valid_train_sizes.append(n_train)
        
        if ratios:
            plt.plot(valid_train_sizes, ratios, '-s', 
                    color=colors[idx],
                    linewidth=2.0,
                    markersize=3,
                    label=f'$N={hidden_size}$')
    
    plt.xscale('log')
    plt.xlabel(r'Training Set Size $m$', labelpad=2)
    plt.ylabel(r'Ratio of Features with Dim $< 0.2$', labelpad=2)
    
    # Configure ticks and grid
    ax = plt.gca()
    ax.minorticks_on()
    ax.tick_params(which='both', direction='in')
    plt.grid(True, alpha=0.3, which='both')
    
    # Place legend outside
    plt.legend(bbox_to_anchor=(1.0, 1), loc='upper left', fontsize=font_size)
    
    # Adjust layout to prevent legend cutoff
    plt.tight_layout()
    
    return plt.gcf()

def main():
    # Set paths
    results_dir = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_0701_mup_lr005"
    output_dir = "analysis_outputs"
    os.makedirs(output_dir, exist_ok=True)
    
    # Load results
    nn_results, hyperparams = load_experiment_data(results_dir)
    print(f"\nLoaded {len(nn_results)} NN results")
    print("\nUnique configurations found:")
    for result in nn_results:
        print(f"Width: {result['hidden_size']}, Depth: {result['depth']}, N_train: {result['n_train']}")
    
    # Create plot
    ratio_fig = plot_low_dim_ratio(nn_results, results_dir, font_size=7)
    
    # Save as PDF with high quality
    ratio_fig.savefig(os.path.join(output_dir, 'low_dim_ratio.pdf'), 
                     format='pdf',
                     dpi=300, 
                     bbox_inches='tight',
                     facecolor='white',
                     edgecolor='black')
    
    plt.close('all')

if __name__ == "__main__":
    main()


Total results loaded: 866

Loaded 866 NN results

Unique configurations found:
Width: 800, Depth: 4, N_train: 4000
Width: 800, Depth: 4, N_train: 400
Width: 800, Depth: 1, N_train: 10000
Width: 800, Depth: 1, N_train: 2000
Width: 800, Depth: 1, N_train: 10
Width: 8000, Depth: 4, N_train: 4000
Width: 8000, Depth: 4, N_train: 400
Width: 8000, Depth: 1, N_train: 10000
Width: 8000, Depth: 1, N_train: 2000
Width: 8000, Depth: 1, N_train: 10
Width: 2000, Depth: 4, N_train: 4000
Width: 2000, Depth: 4, N_train: 400
Width: 2000, Depth: 1, N_train: 10000
Width: 2000, Depth: 1, N_train: 2000
Width: 2000, Depth: 1, N_train: 10
Width: 150, Depth: 4, N_train: 4000
Width: 150, Depth: 4, N_train: 400
Width: 150, Depth: 1, N_train: 10000
Width: 150, Depth: 1, N_train: 2000
Width: 150, Depth: 1, N_train: 10
Width: 120, Depth: 4, N_train: 4000
Width: 120, Depth: 4, N_train: 400
Width: 120, Depth: 1, N_train: 10000
Width: 120, Depth: 1, N_train: 2000
Width: 120, Depth: 1, N_train: 10
Width: 100, Depth: 4

  model.load_state_dict(torch.load(model_path))


In [None]:
##########

In [None]:
import torch
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Dict, List, Optional
from pathlib import Path

def load_experiment_data(results_dir: str) -> Tuple[List[Dict], Dict]:
    """Load experiment results from rank-based result files"""
    nn_files = list(Path(results_dir).glob("results*.json"))
    if not nn_files:
        raise ValueError(f"No result files found in {results_dir}")
    
    combined_results = []
    empty_files = []
    
    for file_path in nn_files:
        try:
            with open(file_path) as f:
                content = f.read().strip()
                if not content:
                    empty_files.append(file_path)
                    continue
                    
                results = json.loads(content)
                if isinstance(results, list):
                    combined_results.extend(results)
                else:
                    combined_results.append(results)
                    
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            continue
    
    if empty_files:
        print("\nThe following files were empty:")
        for f in empty_files:
            print(f"  - {f}")
            
    print(f"\nTotal results loaded: {len(combined_results)}")
    
    # Build hyperparams from defaults and available data
    hyperparams = {
        "ambient_dim": 20,  # Fixed for this dataset
        "hidden_sizes": sorted(set(r['hidden_size'] for r in combined_results)),
        "depths": sorted(set(r['depth'] for r in combined_results))
    }
    
    # Add ambient_dim to all results
    for result in combined_results:
        result['ambient_dim'] = hyperparams['ambient_dim']
    
    return combined_results, hyperparams

class DeepNN(torch.nn.Module):
    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'standard'):
        super().__init__()
        
        torch.set_default_dtype(torch.float32)
        
        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d
        
        layers = []
        prev_dim = d
        
        for layer_idx in range(depth):
            linear = torch.nn.Linear(prev_dim, hidden_size)
            
            if mode == 'standard':
                torch.nn.init.xavier_uniform_(linear.weight)
                torch.nn.init.zeros_(linear.bias)
            
            layers.extend([
                linear,
                torch.nn.ReLU()
            ])
            prev_dim = hidden_size
        
        final_layer = torch.nn.Linear(prev_dim, 1)
        if mode == 'standard':
            torch.nn.init.xavier_uniform_(final_layer.weight)
        torch.nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)
        
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x).squeeze()

def find_model_file(results_dir: str, model_prefix: str, rank: int) -> Optional[str]:
    """Find a model file matching the prefix and rank across different timestamps"""
    timestamps = ["20241230_145221", "20241230_145424"]
    patterns = [
        f"final_model_{model_prefix}_{{}}_rank{rank}.pt",
        f"initial_model_{model_prefix}_{{}}_rank{rank}.pt"
    ]
    
    for timestamp in timestamps:
        for pattern in patterns:
            filename = pattern.format(timestamp)
            filepath = os.path.join(results_dir, filename)
            if os.path.exists(filepath):
                return filepath
    return None

def load_model(results_dir: str, result: Dict) -> Optional[torch.nn.Module]:
    """Load just the model without dataset"""
    try:
        hidden_size = result['hidden_size']
        depth = result['depth']
        n_train = result['n_train']
        lr = result.get('learning_rate', result.get('lr'))
        mode = result.get('mode', 'mup_pennington')
        rank = result.get('worker_rank', 0)
        ambient_dim = result.get('ambient_dim', 20)  # Default to 20 if not found
        
        model_prefix = f'h{hidden_size}_d{depth}_n{n_train}_lr{lr}_{mode}'
        
        # Find model file
        model_path = find_model_file(results_dir, model_prefix, rank)
        
        if not model_path:
            print(f"Warning: No model file found for prefix: h{hidden_size}_d{depth}_n{n_train}_lr{lr}_{mode}, rank: {rank}")
            return None
        
        model = DeepNN(ambient_dim, hidden_size, depth)
        model.load_state_dict(torch.load(model_path))
        
        device = next(model.parameters()).device
        model = model.to(device)
        
        return model
        
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return None

def analyze_feature_dimensionality(model: torch.nn.Module) -> np.ndarray:
    """Analyze feature dimensionality for layer 1 of the model"""
    first_layer = None
    for layer in model.network:
        if isinstance(layer, torch.nn.Linear):
            first_layer = layer
            break
    
    if first_layer is None:
        raise ValueError("No linear layer found in model")
        
    with torch.no_grad():
        features = first_layer.weight.data
        feature_norms = torch.norm(features, dim=0, keepdim=True)
        normalized_features = features / (feature_norms + 1e-8)
        feature_dots = normalized_features.T @ features
        feature_dots_squared = feature_dots ** 2
        feature_denominator = torch.sum(feature_dots_squared, dim=1)
        feature_numerator = torch.sum(features * features, dim=0)
        feature_dims = feature_numerator / (feature_denominator + 1e-8)
        
        return feature_dims.cpu().numpy()

def plot_low_dim_ratio(results: List[Dict], results_dir: str, threshold: float = 0.2):
    """Plot ratio of features with dimensionality below threshold vs training size"""
    train_sizes = sorted(set(r['n_train'] for r in results))
    hidden_sizes = sorted(set(r['hidden_size'] for r in results))
    
    plt.figure(figsize=(10, 6))
    
    # Create colormap from red to blue
    n_sizes = len(hidden_sizes)
    colors = plt.cm.RdBu(np.linspace(0, 1, n_sizes))
    
    for idx, hidden_size in enumerate(hidden_sizes):
        ratios = []
        valid_train_sizes = []
        for n_train in train_sizes:
            matching_results = [r for r in results 
                              if r['hidden_size'] == hidden_size and r['n_train'] == n_train]
            
            if matching_results:
                result = matching_results[0]
                model = load_model(results_dir, result)
                
                if model is not None:
                    feature_dims = analyze_feature_dimensionality(model)
                    ratio = np.mean(feature_dims < threshold)
                    ratios.append(ratio)
                    valid_train_sizes.append(n_train)
        
        if ratios:
            plt.plot(valid_train_sizes, ratios, '-o', 
                    color=colors[idx], 
                    alpha=0.7,
                    label=f'h={hidden_size}',
                    markersize=4)
    
    plt.xscale('log')
    plt.xlabel('Training Size')
    plt.ylabel(f'Ratio of Features with Dim < {threshold}')
    plt.title('Low-Dimensional Feature Ratio vs Training Size')
    plt.grid(True, alpha=0.3)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    
    return plt.gcf()

def create_feature_dim_histograms(results: List[Dict], results_dir: str):
    """Create histograms of feature dimensionality for different training sizes"""
    train_sizes = sorted(set(r['n_train'] for r in results))
    hidden_sizes = sorted(set(r['hidden_size'] for r in results))
    
    n_rows = len(hidden_sizes)
    n_cols = min(6, len(train_sizes))  # Limit number of columns for readability
    selected_train_sizes = np.logspace(np.log10(min(train_sizes)), 
                                     np.log10(max(train_sizes)), 
                                     n_cols).astype(int)
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3*n_rows))
    if n_rows == 1:
        axes = axes.reshape(1, -1)
    
    for i, hidden_size in enumerate(hidden_sizes):
        for j, n_train in enumerate(selected_train_sizes):
            ax = axes[i, j]
            
            # Find closest available training size
            closest_n_train = min(train_sizes, key=lambda x: abs(x - n_train))
            
            matching_results = [r for r in results 
                              if r['hidden_size'] == hidden_size and r['n_train'] == closest_n_train]
            
            if matching_results:
                result = matching_results[0]
                model = load_model(results_dir, result)
                
                if model is not None:
                    feature_dims = analyze_feature_dimensionality(model)
                    
                    ax.hist(np.log1p(feature_dims), bins=30, alpha=0.7)
                    ax.set_title(f'h={hidden_size}, n={closest_n_train}')
                else:
                    ax.text(0.5, 0.5, 'No data available', 
                           ha='center', va='center',
                           transform=ax.transAxes)
                
            if j == 0:
                ax.set_ylabel('Count')
            if i == n_rows-1:
                ax.set_xlabel('log(1 + dim)')
    
    plt.tight_layout()
    return fig

def main():
    # Set your paths here
    results_dir = "/mnt/users/goringn/NNs_vs_Kernels/low_dim_poly/results/low_dim_poly_NN_2812_mup_lr0001"
    output_dir = "analysis_outputs"
    os.makedirs(output_dir, exist_ok=True)
    
    # Load results
    nn_results, hyperparams = load_experiment_data(results_dir)
    print(f"Loaded {len(nn_results)} NN results")
    
    # Create low-dimensional ratio plot
    ratio_fig = plot_low_dim_ratio(nn_results, results_dir)
    ratio_fig.savefig(os.path.join(output_dir, 'low_dim_ratio.png'), dpi=300, bbox_inches='tight')
    
    # Create feature dimensionality histograms
    hist_fig = create_feature_dim_histograms(nn_results, results_dir)
    hist_fig.savefig(os.path.join(output_dir, 'feature_dim_histograms.png'), dpi=300, bbox_inches='tight')
    
    plt.close('all')

if __name__ == "__main__":
    main()


Total results loaded: 827
Loaded 827 NN results


  model.load_state_dict(torch.load(model_path))


In [26]:
#### working low_dim
import torch
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from typing import Tuple, Dict, List, Optional
from pathlib import Path

def setup_science_style(font_size: float = 8):
    """Set up consistent science styling with proper log-scale ticks."""
    plt.style.use('default')
    mpl.rcParams.update({
        # Font sizes
        'font.size': font_size,
        'axes.labelsize': font_size,
        'xtick.labelsize': font_size,
        'ytick.labelsize': font_size,
        'legend.fontsize': font_size,
        'figure.figsize': (4, 2.7),
        'figure.dpi': 100,
        
        # Font settings
        'font.family': 'serif',
        'font.serif': ['cmr10', 'Computer Modern Serif', 'DejaVu Serif'],
        'text.usetex': False,
        'axes.formatter.use_mathtext': True,
        'mathtext.fontset': 'cm',
        
        # Axis settings
        'axes.linewidth': 0.5,
        'axes.spines.top': True,
        'axes.spines.right': True,
        'axes.spines.left': True,
        'axes.spines.bottom': True,
        
        # Tick settings
        'xtick.direction': 'in',
        'ytick.direction': 'in',
        'xtick.major.width': 0.5,
        'ytick.major.width': 0.5,
        'xtick.minor.width': 0.5,
        'ytick.minor.width': 0.5,
        'xtick.major.size': 3,
        'ytick.major.size': 3,
        'xtick.minor.size': 1.5,
        'ytick.minor.size': 1.5,
        'xtick.top': True,
        'ytick.right': True,
        
        # Grid settings
        'grid.linewidth': 0.5,
        
        # Line settings
        'lines.linewidth': 2.0,
        'lines.markersize': 3,
        
        # Legend settings
        'legend.frameon': False,
        'legend.borderpad': 0,
        'legend.borderaxespad': 1.0,
        'legend.handlelength': 1.0,
        'legend.handletextpad': 0.5,
    })

def load_experiment_data(results_dir: str) -> Tuple[List[Dict], Dict]:
    """Load experiment results from rank-based result files"""
    nn_files = list(Path(results_dir).glob("results*.json"))
    if not nn_files:
        raise ValueError(f"No result files found in {results_dir}")
    
    combined_results = []
    empty_files = []
    
    for file_path in nn_files:
        try:
            with open(file_path) as f:
                content = f.read().strip()
                if not content:
                    empty_files.append(file_path)
                    continue
                    
                results = json.loads(content)
                if isinstance(results, list):
                    combined_results.extend(results)
                else:
                    combined_results.append(results)
                    
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            continue
    
    if empty_files:
        print("\nThe following files were empty:")
        for f in empty_files:
            print(f"  - {f}")
            
    print(f"\nTotal results loaded: {len(combined_results)}")
    
    # Build hyperparams from defaults and available data
    hyperparams = {
        "ambient_dim": 20,  # Fixed for this dataset
        "hidden_sizes": sorted(set(r['hidden_size'] for r in combined_results)),
        "depths": sorted(set(r['depth'] for r in combined_results))
    }
    
    # Add ambient_dim to all results
    for result in combined_results:
        result['ambient_dim'] = hyperparams['ambient_dim']
    
    return combined_results, hyperparams

class DeepNN(torch.nn.Module):
    def __init__(self, d: int, hidden_size: int, depth: int, mode: str = 'standard'):
        super().__init__()
        
        torch.set_default_dtype(torch.float32)
        
        self.mode = mode
        self.depth = depth
        self.hidden_size = hidden_size
        self.input_dim = d
        
        layers = []
        prev_dim = d
        
        for layer_idx in range(depth):
            linear = torch.nn.Linear(prev_dim, hidden_size)
            
            if mode == 'standard':
                torch.nn.init.xavier_uniform_(linear.weight)
                torch.nn.init.zeros_(linear.bias)
            
            layers.extend([
                linear,
                torch.nn.ReLU()
            ])
            prev_dim = hidden_size
        
        final_layer = torch.nn.Linear(prev_dim, 1)
        if mode == 'standard':
            torch.nn.init.xavier_uniform_(final_layer.weight)
        torch.nn.init.zeros_(final_layer.bias)
        layers.append(final_layer)
        
        self.network = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x).squeeze()

def find_model_file(results_dir: str, model_prefix: str, rank: int) -> Optional[str]:
    """Find a model file matching the prefix and rank across different timestamps"""
    timestamps = ["20241230_145221", "20241230_145424"]
    patterns = [
        f"final_model_{model_prefix}_{{}}_rank{rank}.pt",
        f"initial_model_{model_prefix}_{{}}_rank{rank}.pt"
    ]
    
    for timestamp in timestamps:
        for pattern in patterns:
            filename = pattern.format(timestamp)
            filepath = os.path.join(results_dir, filename)
            if os.path.exists(filepath):
                return filepath
    return None

def load_model(results_dir: str, result: Dict) -> Optional[torch.nn.Module]:
    """Load just the model without dataset"""
    try:
        hidden_size = result['hidden_size']
        depth = result['depth']
        n_train = result['n_train']
        lr = result.get('learning_rate', result.get('lr'))
        mode = result.get('mode', 'mup_pennington')
        rank = result.get('worker_rank', 0)
        ambient_dim = result.get('ambient_dim', 20)
        
        model_prefix = f'h{hidden_size}_d{depth}_n{n_train}_lr{lr}_{mode}'
        
        model_path = find_model_file(results_dir, model_prefix, rank)
        
        if not model_path:
            print(f"Warning: No model file found for prefix: h{hidden_size}_d{depth}_n{n_train}_lr{lr}_{mode}, rank: {rank}")
            return None
        
        model = DeepNN(ambient_dim, hidden_size, depth)
        model.load_state_dict(torch.load(model_path))
        
        device = next(model.parameters()).device
        model = model.to(device)
        
        return model
        
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return None

def analyze_feature_dimensionality(model: torch.nn.Module) -> np.ndarray:
    """Analyze feature dimensionality for layer 1 of the model"""
    first_layer = None
    for layer in model.network:
        if isinstance(layer, torch.nn.Linear):
            first_layer = layer
            break
    
    if first_layer is None:
        raise ValueError("No linear layer found in model")
        
    with torch.no_grad():
        features = first_layer.weight.data
        feature_norms = torch.norm(features, dim=0, keepdim=True)
        normalized_features = features / (feature_norms + 1e-8)
        feature_dots = normalized_features.T @ features
        feature_dots_squared = feature_dots ** 2
        feature_denominator = torch.sum(feature_dots_squared, dim=1)
        feature_numerator = torch.sum(features * features, dim=0)
        feature_dims = feature_numerator / (feature_denominator + 1e-8)
        
        return feature_dims.cpu().numpy()

def plot_low_dim_ratio(results: List[Dict], results_dir: str, threshold: float = 0.2, font_size: float = 8):
    """Plot ratio of features with dimensionality below threshold vs training size"""
    # Set up the science style with font size
    setup_science_style(font_size=font_size)
    
    train_sizes = sorted(set(r['n_train'] for r in results))
    hidden_sizes = sorted(set(r['hidden_size'] for r in results))
    
    print("\nFound train sizes:", train_sizes)
    print("Found hidden sizes:", hidden_sizes)
    
    # Create figure with specific size
    plt.figure(figsize=(7.5, 3.9))
    
    # Create an evenly spaced colormap from purple to cyan
    num_sizes = len(hidden_sizes)
    colors = []
    
    # Define start and end colors
    start_rgb = np.array([238, 0, 255])   # #ee00ff (purple)
    end_rgb = np.array([0, 251, 255])     # #00fbff (cyan)
    
    for i in range(num_sizes):
        t = i / max(1, num_sizes - 1)
        rgb = start_rgb * (1 - t) + end_rgb * t
        rgb = rgb.astype(int)
        colors.append(f'#{rgb[0]:02x}{rgb[1]:02x}{rgb[2]:02x}')
    
    for idx, hidden_size in enumerate(hidden_sizes):
        ratios = []
        valid_train_sizes = []
        for n_train in train_sizes:
            matching_results = [r for r in results 
                              if r['hidden_size'] == hidden_size and r['n_train'] == n_train]
            
            if matching_results:
                result = matching_results[0]
                model = load_model(results_dir, result)
                
                if model is None:
                    print(f"Could not load model for width={hidden_size}, n_train={n_train}")
                
                if model is not None:
                    feature_dims = analyze_feature_dimensionality(model)
                    ratio = np.mean(feature_dims < threshold)
                    ratios.append(ratio)
                    valid_train_sizes.append(n_train)
        
        if ratios:
            plt.plot(valid_train_sizes, ratios, '-s', 
                    color=colors[idx % len(colors)],
                    linewidth=2.0,
                    markersize=3,
                    label=f'$N={hidden_size}$')
    
    plt.xscale('log')
    plt.xlabel(r'Training Set Size $m$', labelpad=2)
    plt.ylabel(r'Ratio of Features with Dim $< 0.2$', labelpad=2)
    
    # Configure ticks
    ax = plt.gca()
    ax.minorticks_on()
    ax.tick_params(which='both', direction='in')
    
    # Add grid with low opacity
    plt.grid(True, alpha=0.3, which='both')
    
    # Move legend outside
    plt.legend(bbox_to_anchor=(1.0, 1), loc='upper left', fontsize=font_size)
    
    # Adjust layout to prevent legend cutoff
    plt.tight_layout()
    
    return plt.gcf()

def main():
    # Set your paths here
    results_dir = "/mnt/users/goringn/NNs_vs_Kernels/low_dim_poly/results/low_dim_poly_NN_2812_mup_lr0001"
    output_dir = "analysis_outputs"
    os.makedirs(output_dir, exist_ok=True)
    
    # Load results
    nn_results, hyperparams = load_experiment_data(results_dir)
    print(f"\nLoaded {len(nn_results)} NN results")
    print("\nUnique configurations found:")
    for result in nn_results:
        print(f"Width: {result['hidden_size']}, Depth: {result['depth']}, N_train: {result['n_train']}")
    
    # Create low-dimensional ratio plot with custom font size
    # You can adjust this value to make fonts larger or smaller
    ratio_fig = plot_low_dim_ratio(nn_results, results_dir, font_size=7)
    
    # Save as PDF with high quality
    ratio_fig.savefig(os.path.join(output_dir, 'low_dim_ratio.pdf'), 
                     format='pdf',
                     dpi=300, 
                     bbox_inches='tight',
                     facecolor='white',
                     edgecolor='black')
    
    plt.close('all')

if __name__ == "__main__":
    main()


Total results loaded: 827

Loaded 827 NN results

Unique configurations found:
Width: 5000, Depth: 4, N_train: 60000
Width: 5000, Depth: 4, N_train: 2500
Width: 5000, Depth: 4, N_train: 50
Width: 5000, Depth: 1, N_train: 8000
Width: 5000, Depth: 1, N_train: 200
Width: 3000, Depth: 4, N_train: 15000
Width: 3000, Depth: 4, N_train: 400
Width: 3000, Depth: 1, N_train: 30000
Width: 3000, Depth: 1, N_train: 800
Width: 2000, Depth: 4, N_train: 60000
Width: 2000, Depth: 4, N_train: 2500
Width: 2000, Depth: 4, N_train: 50
Width: 2000, Depth: 1, N_train: 8000
Width: 2000, Depth: 1, N_train: 200
Width: 800, Depth: 4, N_train: 15000
Width: 800, Depth: 4, N_train: 400
Width: 800, Depth: 1, N_train: 30000
Width: 800, Depth: 1, N_train: 800
Width: 600, Depth: 4, N_train: 60000
Width: 600, Depth: 4, N_train: 2500
Width: 600, Depth: 4, N_train: 50
Width: 600, Depth: 1, N_train: 8000
Width: 600, Depth: 1, N_train: 200
Width: 400, Depth: 4, N_train: 15000
Width: 400, Depth: 4, N_train: 400
Width: 400, 

  model.load_state_dict(torch.load(model_path))
