# In [1]:
from dataclasses import dataclass
import json
from pathlib import Path
from typing import Dict, List, Set, Tuple
from collections import defaultdict

@dataclass
class ExperimentConfig:
    """Sorted, de-duplicated hyperparameter values found in the NN results.

    Instances are built by ``ResultsLoader._extract_config``.
    """
    depths: List[int]  # distinct 'depth' values, ascending
    hidden_sizes: List[int]  # distinct 'hidden_size' values, ascending
    train_sizes: List[int]  # distinct 'n_train' values, ascending
    learning_rates: List[float]  # distinct learning rates ('learning_rate' or legacy 'lr'), ascending

@dataclass
class ResultsData:
    """Container for loaded NN and NTK results plus the inferred experiment grid."""
    nn_results: List[dict]
    ntk_results: List[dict]
    # Quoted so the class can be defined independently of ExperimentConfig.
    config: "ExperimentConfig"

    def filter_results(self, model_type: str, **kwargs) -> List[dict]:
        """
        Filter results based on given parameters.

        Args:
            model_type: Either 'nn' or 'ntk'.
            **kwargs: Key-value pairs that must all match exactly
                (e.g. depth=2, hidden_size=128). ``learning_rate`` also
                matches records that still store the value under ``'lr'``.

        Returns:
            List of results matching every criterion, in stored order.

        Raises:
            ValueError: If ``model_type`` is not 'nn' or 'ntk'. Previously any
                other value silently returned the NTK results.
        """
        sources = {'nn': self.nn_results, 'ntk': self.ntk_results}
        if model_type not in sources:
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected one of {sorted(sources)}"
            )
        return [r for r in sources[model_type] if self._matches(r, kwargs)]

    @staticmethod
    def _matches(result: dict, criteria: dict) -> bool:
        """Return True if ``result`` equals every key/value pair in ``criteria``."""
        for key, value in criteria.items():
            # Unstandardized records may keep the learning rate under 'lr'.
            if key == 'learning_rate' and 'lr' in result:
                actual = result['lr']
            else:
                actual = result.get(key)
            if actual != value:
                return False
        return True

class ResultsLoader:
    """Load NN and NTK result JSON files into a ResultsData object."""

    @staticmethod
    def _load_file(file_path: Path) -> List[dict]:
        """
        Load a single JSON results file.

        A file holding a single JSON object is wrapped in a one-element list.
        Empty, unreadable, or malformed files are reported on stdout and yield
        an empty list, so one bad file does not abort a whole load.
        """
        try:
            # JSON is UTF-8 by specification; do not depend on the locale encoding.
            with open(file_path, encoding='utf-8') as f:
                content = f.read().strip()
            if not content:
                print(f"Warning: Empty file found - {file_path}")
                return []
            results = json.loads(content)
        except (OSError, ValueError) as e:
            # OSError: missing/unreadable file. ValueError covers
            # json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error processing {file_path}: {str(e)}")
            return []
        return results if isinstance(results, list) else [results]

    @staticmethod
    def _extract_config(nn_results: List[dict], ntk_results: List[dict]) -> "ExperimentConfig":
        """
        Infer the experiment grid from the NN results.

        ``ntk_results`` is accepted for interface compatibility but is not used;
        the grid is derived from the NN results only.

        Raises:
            KeyError: If an NN result lacks 'depth', 'hidden_size' or 'n_train'.
        """
        depths = sorted({r['depth'] for r in nn_results})
        hidden_sizes = sorted({r['hidden_size'] for r in nn_results})
        train_sizes = sorted({r['n_train'] for r in nn_results})

        # Accept both the standardized key and the legacy 'lr' key.
        learning_rates = {
            lr for lr in (r.get('learning_rate', r.get('lr')) for r in nn_results)
            if lr is not None
        }

        return ExperimentConfig(
            depths=depths,
            hidden_sizes=hidden_sizes,
            train_sizes=train_sizes,
            learning_rates=sorted(learning_rates)
        )

    @staticmethod
    def _standardize_result(result: dict, model_type: str) -> dict:
        """
        Return a copy of ``result`` with normalized keys and a 'model_type' tag.

        A legacy 'lr' key is renamed to 'learning_rate' unless 'learning_rate'
        is already present. The input dict is not modified.
        """
        result = result.copy()
        if 'lr' in result and 'learning_rate' not in result:
            result['learning_rate'] = result.pop('lr')
        result['model_type'] = model_type
        return result

    @classmethod
    def load_results(cls, nn_results_dir: str, ntk_results_path: str) -> "ResultsData":
        """
        Load and preprocess both NN and NTK results into a structured format.

        Args:
            nn_results_dir: Directory containing NN ``results*.json`` files.
            ntk_results_path: Path to the NTK results JSON file.

        Returns:
            ResultsData with standardized results, sorted by
            (depth, hidden_size, n_train), and the inferred configuration.

        Raises:
            ValueError: If no ``results*.json`` files exist in ``nn_results_dir``.
            KeyError: If a result lacks one of the sort/config keys.
        """
        nn_dir = Path(nn_results_dir)
        ntk_path = Path(ntk_results_path)

        # sorted() makes the load order, and therefore the order of ties after
        # the stable sort below, independent of filesystem iteration order.
        nn_files = sorted(nn_dir.glob("results*.json"))
        if not nn_files:
            raise ValueError(f"No result files found in {nn_results_dir}")

        nn_results = [
            cls._standardize_result(r, 'nn')
            for file_path in nn_files
            for r in cls._load_file(file_path)
        ]
        ntk_results = [
            cls._standardize_result(r, 'ntk')
            for r in cls._load_file(ntk_path)
        ]

        print(f"Loaded {len(nn_results)} NN results and {len(ntk_results)} NTK results")

        config = cls._extract_config(nn_results, ntk_results)

        def sort_key(r):
            return (r['depth'], r['hidden_size'], r['n_train'])

        nn_results.sort(key=sort_key)
        ntk_results.sort(key=sort_key)

        return ResultsData(
            nn_results=nn_results,
            ntk_results=ntk_results,
            config=config
        )


def calculate_statistics(results: List[dict], group_by: str) -> Dict:
    """
    Compute per-group mean, standard deviation and count of ``test_error``.

    Args:
        results: Result dicts; each must contain ``group_by`` and ``'test_error'``.
        group_by: Key whose value defines the groups (e.g. ``'depth'``).

    Returns:
        Dict mapping each group value to ``{'mean', 'std', 'count'}``.
        ``std`` is the population (ddof=0) standard deviation, always a float.

    Raises:
        KeyError: If a result lacks ``group_by`` or ``'test_error'``.
    """
    errors_by_group = defaultdict(list)
    for result in results:
        errors_by_group[result[group_by]].append(result['test_error'])

    formatted_stats = {}
    for group, errors in errors_by_group.items():
        count = len(errors)
        mean = sum(errors) / count
        # Two-pass variance avoids the catastrophic cancellation of
        # E[x^2] - E[x]^2 when values are large relative to their spread.
        variance = sum((e - mean) ** 2 for e in errors) / count
        formatted_stats[group] = {
            'mean': mean,
            'std': variance ** 0.5,
            'count': count,
        }

    return formatted_stats

# Example usage:
def load_and_analyze_results(nn_dir: str, ntk_path: str) -> ResultsData:
    """
    Load NN/NTK results and print a short summary.

    Prints the inferred experiment grid, then per-depth test-error statistics
    (mean, std, count) for the NN results followed by the NTK results.

    Args:
        nn_dir: Directory containing NN result JSON files.
        ntk_path: Path to the NTK results JSON file.

    Returns:
        The loaded ResultsData.
    """
    results_data = ResultsLoader.load_results(nn_dir, ntk_path)

    print("\nExperiment Configuration:")
    print(f"Depths: {results_data.config.depths}")
    print(f"Hidden sizes: {results_data.config.hidden_sizes}")
    print(f"Training sizes: {results_data.config.train_sizes}")
    print(f"Learning rates: {results_data.config.learning_rates}")

    summaries = (
        ("\nNN Results Summary:", results_data.nn_results),
        ("\nNTK Results Summary:", results_data.ntk_results),
    )
    for heading, model_results in summaries:
        print(heading)
        per_depth = calculate_statistics(model_results, 'depth')
        for depth in sorted(per_depth):
            summary = per_depth[depth]
            print(f"Depth {depth}:")
            print(f"  Mean: {summary['mean']:.4f}")
            print(f"  Std:  {summary['std']:.4f}")
            print(f"  Count: {summary['count']}")

    return results_data

# In [None]:
import matplotlib.pyplot as plt
from typing import List, Dict, Optional, Tuple
import numpy as np

def _plot_error_curve(ax, results: List[dict], label: str, **line_kwargs) -> None:
    """Draw one test-error-vs-training-size curve, sorted by n_train; no-op if empty."""
    if not results:
        return
    ordered = sorted(results, key=lambda r: r['n_train'])
    ax.plot([r['n_train'] for r in ordered],
            [r['test_error'] for r in ordered],
            linewidth=2, markersize=6, alpha=0.8, label=label, **line_kwargs)


def create_focused_training_plot(
    results_data: ResultsData,
    nn_config: Dict[str, List],  # {depths: [...], widths: [...], learning_rates: [...]}
    ntk_config: Dict[str, List], # {depths: [...], widths: [...]}
    nngp_config: Optional[Dict[str, List]] = None,  # {depths: [...], widths: [...]}
    performance_threshold: float = 80.0,
    figsize: Tuple[int, int] = (15, 10),
    custom_colors: Optional[Dict[str, str]] = None,
    custom_markers: Optional[Dict[str, str]] = None,
    custom_linestyles: Optional[Dict[str, str]] = None
) -> plt.Figure:
    """
    Create a grid of test-error-vs-training-size plots.

    The grid has one row per NN depth and one column per NN learning rate.
    Each panel shows the NN curves for that (depth, lr) across all
    ``nn_config['widths']``, plus the NTK reference curves and, optionally,
    the NNGP reference curves. The reference curves are identical in every panel.

    Args:
        results_data: ResultsData object containing all results.
        nn_config: Dict with lists 'depths', 'widths' and 'learning_rates' for NN.
        ntk_config: Dict with lists 'depths' and 'widths' for NTK.
        nngp_config: Optional dict with lists 'depths' and 'widths' for NNGP.
            Curves are fetched via ``results_data.filter_results('nngp', ...)``,
            so ResultsData must support that model type.
        performance_threshold: Currently unused; no outperformance markers are
            drawn. Kept for backward compatibility.
        figsize: (width, height) of the figure in inches.
        custom_colors, custom_markers, custom_linestyles: Optional per-model
            overrides keyed by 'nn'/'ntk'/'nngp'. Missing keys fall back to the
            defaults.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: If ``nn_config`` lists no depths or no learning rates.
    """
    default_colors = {
        'nn': '#1f77b4',  # Blue
        'ntk': '#2ca02c',  # Green
        'nngp': '#ff7f0e'  # Orange
    }
    default_markers = {'nn': 'o', 'ntk': 's', 'nngp': '^'}
    default_linestyles = {'nn': '-', 'ntk': '--', 'nngp': ':'}

    # Merge overrides onto the defaults so a partial dict (e.g. only 'nn' and
    # 'ntk') does not raise KeyError for a model type it omits.
    colors = {**default_colors, **(custom_colors or {})}
    markers = {**default_markers, **(custom_markers or {})}
    linestyles = {**default_linestyles, **(custom_linestyles or {})}

    def style_for(model: str) -> dict:
        return dict(color=colors[model], marker=markers[model],
                    linestyle=linestyles[model])

    # One subplot row per NN depth, one column per NN learning rate.
    n_depths = len(nn_config['depths'])
    n_lrs = len(nn_config['learning_rates'])
    if n_depths == 0 or n_lrs == 0:
        raise ValueError("nn_config must list at least one depth and one learning rate")
    fig, axes = plt.subplots(n_depths, n_lrs, figsize=figsize, squeeze=False)

    title = ('Test Error vs Training Size\n' +
             f'NN: depths={nn_config["depths"]}, widths={nn_config["widths"]}\n' +
             f'NTK: depths={ntk_config["depths"]}, widths={ntk_config["widths"]}')
    if nngp_config:
        title += f'\nNNGP: depths={nngp_config["depths"]}, widths={nngp_config["widths"]}'
    fig.suptitle(title, fontsize=16, y=1.02)

    for i, nn_depth in enumerate(nn_config['depths']):
        for j, lr in enumerate(nn_config['learning_rates']):
            ax = axes[i, j]

            for nn_width in nn_config['widths']:
                _plot_error_curve(
                    ax,
                    results_data.filter_results('nn', depth=nn_depth,
                                                hidden_size=nn_width,
                                                learning_rate=lr),
                    f'NN d={nn_depth},h={nn_width}',
                    **style_for('nn'))

            for ntk_depth in ntk_config['depths']:
                for ntk_width in ntk_config['widths']:
                    _plot_error_curve(
                        ax,
                        results_data.filter_results('ntk', depth=ntk_depth,
                                                    hidden_size=ntk_width),
                        f'NTK d={ntk_depth},h={ntk_width}',
                        **style_for('ntk'))

            if nngp_config:
                for nngp_depth in nngp_config['depths']:
                    for nngp_width in nngp_config['widths']:
                        _plot_error_curve(
                            ax,
                            results_data.filter_results('nngp', depth=nngp_depth,
                                                        hidden_size=nngp_width),
                            f'NNGP d={nngp_depth},h={nngp_width}',
                            **style_for('nngp'))

            ax.set_xscale('log')
            ax.set_yscale('log')
            ax.grid(True, which="both", ls="-", alpha=0.2)

            # Axis labels only on the outer edge of the grid.
            if i == n_depths - 1:
                ax.set_xlabel('Training Size', fontsize=12)
            if j == 0:
                ax.set_ylabel('Test Error', fontsize=12)

            ax.text(0.05, 0.95, f'NN d={nn_depth}\nlr={lr:.1e}',
                    transform=ax.transAxes,
                    verticalalignment='top',
                    bbox=dict(facecolor='white', alpha=0.8))

            # A single shared legend, placed outside the top-left panel.
            if i == 0 and j == 0:
                ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left',
                          borderaxespad=0., fontsize=10)

    plt.tight_layout()
    return fig





# In [9]:
from pathlib import Path
import matplotlib.pyplot as plt
from typing import Optional

def create_analysis_plot(
    nn_dir: Optional[str] = None,
    ntk_path: Optional[str] = None,
    output_path: Optional[str] = None,
):
    """
    Load NN/NTK results, draw the focused training plot, and save it as a PNG.

    Args:
        nn_dir: Directory with NN ``results*.json`` files. Defaults to the
            original experiment directory.
        ntk_path: NTK results JSON file. Defaults to the original experiment file.
        output_path: Destination PNG. Defaults to ``plots/training_analysis1.png``.
            Missing parent directories are created.

    Returns:
        Tuple ``(results_data, fig)``.
    """
    # Defaults preserve the original hard-coded experiment locations.
    if nn_dir is None:
        nn_dir = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1612_nogrokk_mup_pennington"
    if ntk_path is None:
        ntk_path = "/mnt/users/goringn/NNs_vs_Kernels/stair_function/results/msp_NN_grid_1812_spectral/final_results_20241219_015151.json"

    print("Loading results...")
    results_data = ResultsLoader.load_results(
        nn_results_dir=nn_dir,
        ntk_results_path=ntk_path
    )

    # What to plot for each model type.
    nn_config = {
        'depths': [4],
        'widths': [100, 500],
        'learning_rates': [0.05]
    }
    ntk_config = {
        'depths': [4],
        'widths': [8000]
    }

    # Explicit style choices; these match the plotting function's defaults.
    custom_colors = {
        'nn': '#1f77b4',    # Strong blue
        'ntk': '#2ca02c',   # Strong green
        'nngp': '#ff7f0e'   # Strong orange
    }
    custom_markers = {'nn': 'o', 'ntk': 's', 'nngp': '^'}
    custom_linestyles = {'nn': '-', 'ntk': '--', 'nngp': ':'}

    print("\nCreating plot with:")
    print(f"NN config: {nn_config}")
    print(f"NTK config: {ntk_config}")

    fig = create_focused_training_plot(
        results_data,
        nn_config=nn_config,
        ntk_config=ntk_config,
        nngp_config=None,
        performance_threshold=80.0,
        figsize=(12, 8),
        custom_colors=custom_colors,
        custom_markers=custom_markers,
        custom_linestyles=custom_linestyles
    )

    output_file = (Path(output_path) if output_path is not None
                   else Path("plots") / "training_analysis1.png")
    output_file.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"\nPlot saved to: {output_file}")

    return results_data, fig

if __name__ == "__main__":
    # Builds, saves and shows the analysis figure. Later cells use the global
    # `results_data` bound here (assumes this runs as a script or notebook
    # cell, where __name__ == "__main__" — confirm for your environment).
    results_data, fig = create_analysis_plot()
    plt.show()

# In [12]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from cycler import cycler
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from pathlib import Path

class PlottingConfig:
    """Journal-specific matplotlib rcParams presets, applied globally."""

    @staticmethod
    def _reset_then_apply(overrides: dict) -> None:
        """Reset to matplotlib's default style, then layer ``overrides`` on top."""
        plt.style.use('default')
        mpl.rcParams.update(overrides)

    @staticmethod
    def setup_ieee_style():
        """Apply IEEE-like settings: 8 pt Times serif text, 600 dpi, thin axes."""
        text_pt = 8
        thin = 0.5
        PlottingConfig._reset_then_apply({
            'figure.figsize': (3.3, 2.5),
            'figure.dpi': 600,
            'font.family': 'serif',
            'font.serif': ['Times'],
            'font.size': text_pt,
            'axes.labelsize': text_pt,
            'xtick.labelsize': text_pt,
            'ytick.labelsize': text_pt,
            'legend.fontsize': text_pt,
            'axes.linewidth': thin,
            'grid.linewidth': thin,
            'lines.linewidth': 1.0,
        })

    @staticmethod
    def setup_nature_style():
        """Apply Nature-like settings: 7 pt sans-serif text, thin axes, small markers."""
        text_pt = 7
        thin = 0.5
        PlottingConfig._reset_then_apply({
            'figure.figsize': (3.3, 2.5),
            'font.family': 'sans-serif',
            'font.sans-serif': ['DejaVu Sans', 'Arial'],
            'font.size': text_pt,
            'axes.labelsize': text_pt,
            'xtick.labelsize': text_pt,
            'ytick.labelsize': text_pt,
            'legend.fontsize': text_pt,
            'axes.linewidth': thin,
            'grid.linewidth': thin,
            'lines.linewidth': 1.0,
            'lines.markersize': 3,
        })

class SciencePlotter:
    """Class for creating publication-quality scientific plots."""

    def __init__(self, style: str = 'ieee'):
        """
        Initialize plotter with specified style.

        The preset is written to the global matplotlib rcParams, so it affects
        every figure created afterwards, not only this plotter's.

        Args:
            style: 'ieee' or 'nature' (case-insensitive).

        Raises:
            ValueError: If ``style`` is not a supported preset.
        """
        style_key = style.lower()
        if style_key == 'ieee':
            PlottingConfig.setup_ieee_style()
        elif style_key == 'nature':
            PlottingConfig.setup_nature_style()
        else:
            raise ValueError("Style must be 'ieee' or 'nature'")

    def create_figure(self, width: float = 3.3, height: float = 2.5) -> Tuple[plt.Figure, plt.Axes]:
        """Create a single-axes figure with a white face and thin black edge."""
        fig, ax = plt.subplots(figsize=(width, height), facecolor='white', edgecolor='black')
        fig.patch.set_linewidth(0.5)
        return fig, ax

    def plot_iv_curves(self,
                       orders: Optional[List[int]] = None,
                       voltage_range: Tuple[float, float] = (0.7, 1.2),
                       current_range: Tuple[float, float] = (0.0, 1.2),
                       num_points: int = 1000,
                       custom_colors: Optional[List[str]] = None) -> plt.Figure:
        """
        Create an IV curve plot of logistic curves centred at 1.0.

        Each curve is ``I = current_range[1] / (1 + exp(-order * (V - 1.0)))``.

        Args:
            orders: Steepness values, one curve each. Defaults to
                [10, 15, 20, 30, 50, 100].
            voltage_range: (min_voltage, max_voltage), also used as the x-limits.
            current_range: (min_current, max_current). The max sets the curve
                amplitude; both set the y-limits.
            num_points: Number of points in the voltage array.
            custom_colors: Optional list of colors. It is cycled if shorter
                than ``orders``.
        """
        # None instead of a list literal avoids a shared mutable default.
        if orders is None:
            orders = [10, 15, 20, 30, 50, 100]
        default_colors = ['#1f77b4', '#2ca02c', '#ff7f0e',
                          '#d62728', '#9467bd', '#7f7f7f']
        colors = custom_colors or default_colors

        fig, ax = self.create_figure()

        V = np.linspace(*voltage_range, num_points)

        for k, order in enumerate(orders):
            # Cycle the palette. A plain zip() would silently drop every order
            # beyond the number of colors.
            color = colors[k % len(colors)]
            I = current_range[1] / (1 + np.exp(-order * (V - 1.0)))
            ax.plot(V, I, '-', color=color, label=str(order), lw=1.0)

        ax.set_xlabel('Voltage (mV)')
        ax.set_ylabel('Current (μA)')
        ax.set_xlim(voltage_range)
        ax.set_ylim(current_range)

        legend = ax.legend(title='Order', frameon=False,
                           loc='upper left',
                           bbox_to_anchor=(0.02, 0.98))
        legend.get_title().set_fontsize(8)

        plt.tight_layout()
        return fig

    @staticmethod
    def _plot_curve(ax, results: List[dict], label: str, **line_kwargs) -> None:
        """Plot test error against training size for one configuration, sorted by n_train."""
        ordered = sorted(results, key=lambda r: r['n_train'])
        ax.plot([r['n_train'] for r in ordered],
                [r['test_error'] for r in ordered],
                label=label, **line_kwargs)

    def plot_training_curves(self,
                             results_data: ResultsData,
                             nn_config: Dict[str, List],
                             ntk_config: Dict[str, List],
                             nngp_config: Optional[Dict[str, List]] = None,
                             performance_threshold: float = 80.0,
                             figsize: Tuple[float, float] = (15, 10),
                             custom_colors: Optional[Dict[str, str]] = None,
                             custom_markers: Optional[Dict[str, str]] = None,
                             custom_linestyles: Optional[Dict[str, str]] = None) -> plt.Figure:
        """
        Plot NN and NTK test error against training size on shared log-log axes.

        Args:
            results_data: ResultsData object with experiment results.
            nn_config: Lists 'depths', 'widths' and 'learning_rates'. One curve
                is drawn per combination that has results.
            ntk_config: Lists 'depths' and 'widths'. One curve is drawn per
                combination that has results.
            nngp_config: Currently ignored; NNGP curves are not drawn.
            performance_threshold: Currently ignored.
            figsize: Figure size in inches.
            custom_colors, custom_markers, custom_linestyles: Optional per-model
                overrides keyed by 'nn'/'ntk'/'nngp'. Missing keys fall back to
                the defaults.
        """
        default_colors = {'nn': '#1f77b4', 'ntk': '#2ca02c', 'nngp': '#ff7f0e'}
        default_markers = {'nn': 'o', 'ntk': 's', 'nngp': '^'}
        default_linestyles = {'nn': '-', 'ntk': '--', 'nngp': ':'}

        # Merge so a partial override dict does not raise KeyError.
        colors = {**default_colors, **(custom_colors or {})}
        markers = {**default_markers, **(custom_markers or {})}
        linestyles = {**default_linestyles, **(custom_linestyles or {})}

        def style_for(model: str) -> dict:
            return dict(color=colors[model], marker=markers[model],
                        linestyle=linestyles[model])

        fig, ax = plt.subplots(figsize=figsize)

        for depth in nn_config['depths']:
            for width in nn_config['widths']:
                for lr in nn_config['learning_rates']:
                    results = results_data.filter_results(
                        'nn', depth=depth, hidden_size=width, learning_rate=lr
                    )
                    if results:
                        self._plot_curve(ax, results, f'NN d={depth},h={width},lr={lr}',
                                         **style_for('nn'))

        for depth in ntk_config['depths']:
            for width in ntk_config['widths']:
                results = results_data.filter_results(
                    'ntk', depth=depth, hidden_size=width
                )
                if results:
                    self._plot_curve(ax, results, f'NTK d={depth},h={width}',
                                     **style_for('ntk'))

        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.grid(True, which="both", ls="-", alpha=0.2)
        ax.set_xlabel('Training Size')
        ax.set_ylabel('Test Error')
        ax.legend()

        plt.tight_layout()
        return fig

    @staticmethod
    def save_figure(fig: plt.Figure,
                    filename: str,
                    dpi: int = 600,
                    bbox_inches: str = 'tight',
                    pad_inches: float = 0.1):
        """Save ``fig`` with a white face, black edge and the given resolution/padding."""
        fig.savefig(filename,
                    dpi=dpi,
                    bbox_inches=bbox_inches,
                    facecolor='white',
                    edgecolor='black',
                    pad_inches=pad_inches)

# Example usage in Jupyter notebook:
# NOTE(review): this cell uses the global `results_data` created by the
# earlier analysis cell (create_analysis_plot); run that cell first.

# Initialize plotter (this also applies the IEEE rcParams globally)
plotter = SciencePlotter(style='ieee')

# Create IV curves and write them to iv_curves.png
fig_iv = plotter.plot_iv_curves()
plotter.save_figure(fig_iv, 'iv_curves.png')

# Create training curves
# NN: depth 4, widths 100/500, lr 0.05; NTK reference: depth 4, width 8000
nn_config = {
    'depths': [4],
    'widths': [100, 500],
    'learning_rates': [0.05]
}
ntk_config = {
    'depths': [4],
    'widths': [8000]
}

fig_training = plotter.plot_training_curves(
    results_data,
    nn_config=nn_config,
    ntk_config=ntk_config
)
plotter.save_figure(fig_training, 'training_curves2.png')


# In [13]:
# Initialize plotter
# NOTE(review): like the previous cell, this relies on the global
# `results_data` from the earlier analysis cell.
plotter = SciencePlotter(style='ieee')

# For IV curves
fig_iv = plotter.plot_iv_curves(
    orders=[10, 15, 20, 30, 50, 100],  # customize orders
    voltage_range=(0.7, 1.2),          # customize range
    custom_colors=['blue', 'green', 'orange', 'red', 'purple', 'gray']  # customize colors
)

# For training curves
nn_config = {
    'depths': [4],
    'widths': [100, 500],
    'learning_rates': [0.05]
}
ntk_config = {
    'depths': [4],
    'widths': [8000]
}

fig_training = plotter.plot_training_curves(
    results_data,
    nn_config=nn_config,
    ntk_config=ntk_config,
    custom_colors={'nn': 'blue', 'ntk': 'green'}  # customize colors (only 'nn'/'ntk' curves are drawn)
)

# Save figures (iv_curves.png overwrites the file written by the previous cell)
plotter.save_figure(fig_iv, 'iv_curves.png')
plotter.save_figure(fig_training, 'training_curves3.png')

# In [21]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from cycler import cycler
from typing import Dict, List, Optional, Tuple, Union
from dataclasses import dataclass
from pathlib import Path

@dataclass
class PlotStyle:
    """Line styling for one curve; each field is passed to ``Axes.plot``."""
    color: str  # matplotlib color spec
    linestyle: str  # e.g. '-', '--', ':'
    linewidth: float
    marker: Optional[str] = None  # None leaves the marker to matplotlib's default
    markersize: Optional[float] = None  # None leaves the size to matplotlib's default
    alpha: float = 1.0  # opacity

class SciencePlotter:
    """Class for creating publication-quality scientific plots."""

    def __init__(self):
        """Apply the global science rcParams and store the IV-curve palette."""
        self.setup_science_style()

        # Colors keyed by IV-curve order.
        self.iv_colors = {
            10: '#01B5EB',   # blue
            15: '#96FFC0',   # green
            20: '#FFB463',   # orange
            30: '#FF0000',   # red
            50: '#8710FF',   # purple
            100: '#7f7f7f'   # gray
        }

    @staticmethod
    def setup_science_style():
        """Reset to matplotlib defaults, then apply serif, inward-tick, frameless-legend settings globally."""
        plt.style.use('default')
        mpl.rcParams.update({
            'figure.figsize': (3.3, 2.5),
            'figure.dpi': 600,

            # Font settings
            'font.family': 'serif',
            'font.serif': ['cmr10', 'Computer Modern Serif', 'DejaVu Serif'],
            'text.usetex': False,
            'axes.formatter.use_mathtext': True,
            'mathtext.fontset': 'cm',

            # Axis settings
            'axes.linewidth': 0.5,
            'axes.spines.top': True,
            'axes.spines.right': True,
            'axes.spines.left': True,
            'axes.spines.bottom': True,

            # Tick settings
            'xtick.direction': 'in',
            'ytick.direction': 'in',
            'xtick.major.width': 0.5,
            'ytick.major.width': 0.5,
            'xtick.minor.width': 0.5,
            'ytick.minor.width': 0.5,
            'xtick.major.size': 3,
            'ytick.major.size': 3,
            'xtick.minor.size': 1.5,
            'ytick.minor.size': 1.5,
            'xtick.top': True,
            'ytick.right': True,

            # Grid settings
            'grid.linewidth': 0.5,

            # Line settings
            'lines.linewidth': 1.0,
            'lines.markersize': 3,

            # Legend settings
            'legend.frameon': False,
            'legend.borderpad': 0,
            'legend.borderaxespad': 1.0,
            'legend.handlelength': 1.0,
            'legend.handletextpad': 0.5,
        })

    @staticmethod
    def _plot_curve(ax, results: List[dict], label: str, style: PlotStyle) -> None:
        """Plot test error against training size for one configuration using ``style``."""
        ordered = sorted(results, key=lambda r: r['n_train'])
        ax.plot([r['n_train'] for r in ordered],
                [r['test_error'] for r in ordered],
                color=style.color,
                linestyle=style.linestyle,
                linewidth=style.linewidth,
                marker=style.marker,
                markersize=style.markersize,
                alpha=style.alpha,
                label=label)

    def plot_training_curves(self,
                             results_data: ResultsData,
                             nn_config: Dict[str, List],
                             ntk_config: Dict[str, List],
                             nn_styles: Optional[Dict[int, PlotStyle]] = None,
                             ntk_styles: Optional[Dict[int, PlotStyle]] = None,
                             nngp_config: Optional[Dict[str, List]] = None,
                             nngp_styles: Optional[Dict[int, PlotStyle]] = None,
                             performance_threshold: float = 80.0,
                             figsize: Optional[Tuple[float, float]] = None) -> plt.Figure:
        """
        Create a training-curves plot with custom styling per width.

        Args:
            results_data: ResultsData object.
            nn_config: Lists 'depths', 'widths' and 'learning_rates'.
            ntk_config: Lists 'depths' and 'widths'.
            nn_styles: Maps width -> PlotStyle for NN curves (default: solid).
            ntk_styles: Maps width -> PlotStyle for NTK curves (default: dashed).
            nngp_config: Currently ignored; NNGP curves are not drawn.
            nngp_styles: Currently ignored.
            performance_threshold: Currently ignored.
            figsize: Figure size in inches; defaults to (3.3, 2.5).

        NN labels include the learning rate only when more than one is
        configured, so curves remain distinguishable in the legend.
        """
        if figsize is None:
            figsize = (3.3, 2.5)

        fig, ax = plt.subplots(figsize=figsize)

        default_nn_style = PlotStyle(color='#01B5EB', linestyle='-', linewidth=1.0)
        default_ntk_style = PlotStyle(color='#96FFC0', linestyle='--', linewidth=1.0)
        nn_styles = nn_styles or {}
        ntk_styles = ntk_styles or {}

        learning_rates = nn_config['learning_rates']
        label_lr = len(learning_rates) > 1

        for depth in nn_config['depths']:
            for width in nn_config['widths']:
                style = nn_styles.get(width, default_nn_style)
                for lr in learning_rates:
                    results = results_data.filter_results(
                        'nn', depth=depth, hidden_size=width, learning_rate=lr
                    )
                    if results:
                        label = f'NN d={depth},h={width}'
                        if label_lr:
                            label += f',lr={lr}'
                        self._plot_curve(ax, results, label, style)

        for depth in ntk_config['depths']:
            for width in ntk_config['widths']:
                style = ntk_styles.get(width, default_ntk_style)
                results = results_data.filter_results(
                    'ntk', depth=depth, hidden_size=width
                )
                if results:
                    self._plot_curve(ax, results, f'NTK d={depth},h={width}', style)

        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel('Training Size', labelpad=2)
        ax.set_ylabel('Test Error', labelpad=2)

        # Force y-axis ticks (major and minor) to point inward.
        ax.yaxis.set_tick_params(direction='in', which='both')
        ax.minorticks_on()

        ax.legend(frameon=False,
                  loc='upper right',
                  bbox_to_anchor=(0.98, 0.98),
                  handlelength=1.0,
                  handletextpad=0.5)

        plt.tight_layout()
        return fig

    @staticmethod
    def save_figure(fig: plt.Figure,
                    filename: str,
                    dpi: int = 600,
                    bbox_inches: str = 'tight',
                    pad_inches: float = 0.1):
        """Save ``fig`` with a white face, black edge and the given resolution/padding."""
        fig.savefig(filename,
                    dpi=dpi,
                    bbox_inches=bbox_inches,
                    facecolor='white',
                    edgecolor='black',
                    pad_inches=pad_inches)

# Example usage:
# NOTE(review): uses the global `results_data` from the earlier analysis cell,
# and calls plotter.save_figure, which the SciencePlotter class defined in this
# cell must provide. Confirm it does before running.

plotter = SciencePlotter()

# Define custom styles for different widths (keys are NN hidden sizes)
nn_styles = {
    100: PlotStyle(color='#01B5EB', linestyle='-', linewidth=1.0),
    500: PlotStyle(color='#FF0000', linestyle='-', linewidth=1.5)
}

# NTK reference curve: dashed purple line with circle markers
ntk_styles = {
    8000: PlotStyle(color='#8710FF', linestyle='--', linewidth=1.0, 
                    marker='o', markersize=3)
}

# Create and save training curves with custom styles
nn_config = {
    'depths': [4],
    'widths': [100, 500],
    'learning_rates': [0.05]
}
ntk_config = {
    'depths': [4],
    'widths': [8000]
}

fig_training = plotter.plot_training_curves(
    results_data,
    nn_config=nn_config,
    ntk_config=ntk_config,
    nn_styles=nn_styles,
    ntk_styles=ntk_styles
)
plotter.save_figure(fig_training, 'training_curves.png')
