# In [None]:

# =============================================================================
# 1. Reproducibility and Device Setup
# =============================================================================

def seed_everything(seed=4):
    """
    Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (the CPU
    generator and, when CUDA is available, every CUDA device). It also
    configures cuDNN so that repeated runs select the same kernels.

    Args:
        seed (int): Seed value for all random number generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Requesting deterministic cuDNN kernels is not enough on its own: the
    # benchmark autotuner may pick different algorithms from run to run, so it
    # must be switched off as well.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything()

# Run on CUDA whenever PyTorch can see a GPU; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# =============================================================================
# 2. Mathematical Function Definitions
# =============================================================================

# --- Symbolic Functions ---
def symbolic_function_1(x):
    """
    Evaluate the Gaussian-damped sine f(x) = sin(5x) * exp(-x^2).

    Args:
        x (np.array): Points at which to evaluate the function.

    Returns:
        np.array: f evaluated element-wise at ``x``.
    """
    oscillation = np.sin(5 * x)
    envelope = np.exp(-x**2)
    return oscillation * envelope

def symbolic_function_2(x):
    """
    Evaluate a branch-switching damped wave.

    For negative inputs the wave is sin(5x); for non-negative inputs it is
    cos(5x). Either branch is then scaled by the Gaussian envelope exp(-x^2).

    Args:
        x (np.array): Points at which to evaluate the function.

    Returns:
        np.array: f evaluated element-wise at ``x``.
    """
    scaled = 5 * x
    negative = x < 0
    wave = np.where(negative, np.sin(scaled), np.cos(scaled))
    return wave * np.exp(-x**2)

def symbolic_function_3(x):
    """
    Evaluate the rectified damped sine f(x) = |sin(5x)| * exp(-x^2).

    Args:
        x (np.array): Points at which to evaluate the function.

    Returns:
        np.array: Non-negative values of f at ``x``.
    """
    magnitude = np.abs(np.sin(5 * x))
    decay = np.exp(-x**2)
    return magnitude * decay

def symbolic_function_4(x):
    """
    Evaluate the Gaussian-damped sine f(x) = sin(10x) * exp(-x^2).

    Args:
        x (np.array): Points at which to evaluate the function.

    Returns:
        np.array: f evaluated element-wise at ``x``.
    """
    wave = np.sin(10 * x)
    damping = np.exp(-x**2)
    return wave * damping

def symbolic_function_5(x):
    """
    Evaluate the Gaussian-damped sine f(x) = sin(15x) * exp(-x^2).

    Args:
        x (np.array): Points at which to evaluate the function.

    Returns:
        np.array: f evaluated element-wise at ``x``.
    """
    wave = np.sin(15 * x)
    damping = np.exp(-x**2)
    return wave * damping

def symbolic_function_6(x):
    """
    Computes a piecewise function:
      - f(x) = sin(5x) for x < 0
      - f(x) = sin(5x) + 1 for x >= 0
    Then multiplies the result by exp(-x^2).

    Args:
        x (np.array): Input values (any numeric dtype).

    Returns:
        np.array: Float function values.
    """
    # Work in floating point: an output buffer that inherits an integer dtype
    # from ``x`` (as np.piecewise's does) would truncate the sine values.
    x = np.asarray(x, dtype=float)
    # The two branches differ only by a +1 offset on the non-negative side.
    offset = np.where(x < 0, 0.0, 1.0)
    return (np.sin(5 * x) + offset) * np.exp(-x**2)

# --- Additional Functions: Polynomials, Exponentials, Logarithms, etc. ---
def polynomial_function_1(x):
    """Return the square of ``x``, i.e. f(x) = x^2 (element-wise for arrays)."""
    squared = x**2
    return squared

def polynomial_function_2(x):
    """Evaluate the cubic f(x) = x^3 - 2x + 1 (element-wise for arrays)."""
    cubic_term = x**3
    linear_term = 2 * x
    return cubic_term - linear_term + 1

def exponential_function_1(x):
    """Return the natural exponential f(x) = e^x, element-wise."""
    growth = np.exp(x)
    return growth

def exponential_function_2(x):
    """Return the scaled slow exponential f(x) = 2 * e^(x/2), element-wise."""
    half_rate = 0.5 * x
    return 2 * np.exp(half_rate)

def logarithmic_function_1(x):
    """
    Computes f(x) = log(|x| + 1).

    Uses ``np.log1p`` instead of ``np.log(1 + |x|)``. When |x| is tiny,
    ``1 + |x|`` rounds to exactly 1 and the naive form returns 0; ``log1p``
    keeps full relative precision there.
    """
    return np.log1p(np.abs(x))

def logarithmic_function_2(x):
    """
    Computes f(x) = log10(|x| + 1).

    Evaluated as ``log1p(|x|) / ln(10)``, which keeps full relative precision
    for tiny |x|. In that regime ``np.log10(1 + |x|)`` would round to 0.
    """
    return np.log1p(np.abs(x)) / np.log(10)

def trigonometric_function_3(x):
    """
    Return the Gaussian-damped tangent f(x) = tan(x) * exp(-x^2).

    tan(x) diverges at x = ±π/2 (≈ ±1.571), so values sampled close to those
    points can be very large.
    """
    envelope = np.exp(-x**2)
    return np.tan(x) * envelope

def trigonometric_function_4(x):
    """
    Return the Gaussian-damped cotangent f(x) = (1/tan(x)) * exp(-x^2).

    At x = 0 the reciprocal divides by zero, so the result there is infinite
    (NumPy emits a RuntimeWarning rather than raising).
    """
    cot = 1 / np.tan(x)
    return cot * np.exp(-x**2)

def step_function_1(x):
    """Return a unit step at 0: 0 where x < 0, 1 everywhere else (integer array)."""
    is_negative = x < 0
    return np.where(is_negative, 0, 1)

def step_function_2(x):
    """Return a unit step at -1: 0 where x < -1, 1 everywhere else (integer array)."""
    below_threshold = x < -1
    return np.where(below_threshold, 0, 1)

def uniform_function_1(x):
    """Return the constant 5 at every point of ``x``, keeping the shape and dtype of ``x``."""
    baseline = np.ones_like(x)
    return baseline * 5

def uniform_function_2(x):
    """
    Evaluate the noisy line f(x) = 3x + 2 + U(-0.5, 0.5).

    The noise is drawn from NumPy's global RNG (one sample per element of
    ``x``), so the output depends on that generator's current state.

    Args:
        x (np.array): Input values; their ``shape`` sizes the noise draw.

    Returns:
        np.array: Line values plus independent uniform noise.
    """
    noise = np.random.uniform(-0.5, 0.5, size=x.shape)
    line = 3 * x + 2
    return line + noise

# --- Dictionary for Easy Function Access ---
# Registry of target functions keyed by display label. The comparison loop
# walks it in insertion order, so the order below is the processing order.
functions = dict([
    ('Function 1', symbolic_function_1),
    ('Function 2', symbolic_function_2),
    ('Function 3', symbolic_function_3),
    ('Function 4', symbolic_function_4),
    ('Function 5', symbolic_function_5),
    ('Function 6', symbolic_function_6),
    ('Polynomial 1', polynomial_function_1),
    ('Polynomial 2', polynomial_function_2),
    ('Exponential 1', exponential_function_1),
    ('Exponential 2', exponential_function_2),
    ('Logarithmic 1', logarithmic_function_1),
    ('Logarithmic 2', logarithmic_function_2),
    ('Trigonometric 3', trigonometric_function_3),
    ('Trigonometric 4', trigonometric_function_4),
    ('Step 1', step_function_1),
    ('Step 2', step_function_2),
    ('Uniform 1', uniform_function_1),
    ('Uniform 2', uniform_function_2),
])

# =============================================================================
# 3. Data Generation and Normalization
# =============================================================================

def _make_loader(x, y, batch_size, shuffle):
    """
    Wrap paired 1-D NumPy arrays in a DataLoader of (N, 1) float32 tensors on ``device``.

    Args:
        x (np.array): 1-D input values.
        y (np.array): 1-D target values, same length as ``x``.
        batch_size (int): Mini-batch size.
        shuffle (bool): Whether the loader reshuffles samples each epoch.

    Returns:
        DataLoader: Loader yielding ``(inputs, targets)`` batches.
    """
    # unsqueeze(1) adds the feature dimension: shape (N,) -> (N, 1).
    x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(1).to(device)
    y_tensor = torch.tensor(y, dtype=torch.float32).unsqueeze(1).to(device)
    # num_workers=0: the tensors already live on `device`, so load in-process.
    return DataLoader(TensorDataset(x_tensor, y_tensor), batch_size=batch_size,
                      shuffle=shuffle, num_workers=0)


def generate_data(symbolic_function, n_samples=1000, batch_size=64, x_range=(-2, 2)):
    """
    Sample a 1-D function, split and normalize the samples, and wrap them in DataLoaders.

    ``n_samples`` evenly spaced x-values in ``x_range`` are shuffled with NumPy's
    global RNG. They are then split 70% / 15% / 15% into training, validation
    and test sets.
    - Inputs are standardized with the training mean and std.
    - Targets are MinMax-scaled with a scaler fitted on the training targets
      only, so no statistics leak from the validation or test sets.

    Args:
        symbolic_function (callable): Vectorized function mapping an array of x to y.
        n_samples (int): Total number of samples to generate.
        batch_size (int): Mini-batch size for all three DataLoaders.
        x_range (tuple): ``(low, high)`` interval the x-values are spread over.

    Returns:
        tuple: (train_loader, val_loader, test_loader, x_test, y_test_orig, y_scaler)
            - ``x_test``: the *standardized* test inputs (NumPy array).
            - ``y_test_orig``: the test targets in their original scale (float array).
            - ``y_scaler``: the fitted MinMaxScaler.

    Raises:
        ValueError: If ``n_samples`` is too small for every split to be non-empty.
    """
    low, high = x_range
    x_values = np.linspace(low, high, n_samples)
    np.random.shuffle(x_values)
    y_values = symbolic_function(x_values)

    # Split data: 70% training, 15% validation, 15% test.
    train_split = int(0.7 * n_samples)
    val_split = int(0.85 * n_samples)
    if not 0 < train_split < val_split < n_samples:
        raise ValueError(
            f"n_samples={n_samples} is too small: the train/validation/test "
            "splits must all be non-empty"
        )

    x_train, x_val, x_test = np.split(x_values, [train_split, val_split])
    y_train, y_val, y_test = np.split(y_values, [train_split, val_split])

    # Standardize inputs using training statistics only.
    x_mean = x_train.mean()
    x_std = x_train.std()
    if x_std == 0:
        # Degenerate interval (low == high): avoid dividing by zero.
        x_std = 1.0
    x_train = (x_train - x_mean) / x_std
    x_val = (x_val - x_mean) / x_std
    x_test = (x_test - x_mean) / x_std

    # Scale targets to [0, 1] with a scaler fitted on the training targets.
    y_scaler = MinMaxScaler()
    y_train_scaled = y_scaler.fit_transform(y_train.reshape(-1, 1)).flatten()
    y_val_scaled = y_scaler.transform(y_val.reshape(-1, 1)).flatten()
    y_test_scaled = y_scaler.transform(y_test.reshape(-1, 1)).flatten()

    # Original-scale test targets for plotting. Use the unscaled values directly:
    # a transform/inverse_transform round trip would only add rounding error.
    y_test_orig = y_test.astype(float)

    train_loader = _make_loader(x_train, y_train_scaled, batch_size, shuffle=True)
    val_loader = _make_loader(x_val, y_val_scaled, batch_size, shuffle=False)
    test_loader = _make_loader(x_test, y_test_scaled, batch_size, shuffle=False)

    return train_loader, val_loader, test_loader, x_test, y_test_orig, y_scaler



# =============================================================================
# 5. Training and Evaluation Functions
# =============================================================================

def train_model(model, optimizer, criterion, train_loader, val_loader, num_epochs=100, patience=50):
    """
    Train ``model`` with early stopping and restore its best validation weights.

    Each epoch:
    - runs one pass over ``train_loader``, clipping the gradient norm at 1.0;
    - computes the mean validation loss over ``val_loader``;
    - steps a ``ReduceLROnPlateau`` scheduler (factor 0.5, patience 10) with
      that loss.

    Training stops early after ``patience`` consecutive epochs without a new
    best validation loss. On return, the model holds the weights from the
    epoch with the lowest validation loss.

    Args:
        model (nn.Module): Neural network model.
        optimizer (torch.optim.Optimizer): Optimizer for training.
        criterion (nn.Module): Loss function.
        train_loader (DataLoader): DataLoader for training data.
        val_loader (DataLoader): DataLoader for validation data.
        num_epochs (int): Maximum number of training epochs.
        patience (int): Number of epochs to wait for improvement before stopping.

    Returns:
        tuple: (train_losses, val_losses), one mean loss per completed epoch.
    """
    train_losses = []
    val_losses = []
    best_val_loss = np.inf
    best_model_state = None
    patience_counter = 0

    # Reduce learning rate if the validation loss plateaus.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5)

    for _ in tqdm(range(num_epochs), desc='Training'):
        model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            # Clip gradients to prevent exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # Weight by batch size so the epoch value is a per-sample mean even
            # when the last batch is smaller (assumes criterion averages over
            # the batch, as nn.MSELoss does by default).
            running_loss += loss.item() * inputs.size(0)
        epoch_train_loss = running_loss / len(train_loader.dataset)
        train_losses.append(epoch_train_loss)

        # Evaluate on the validation set.
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = model(inputs)
                val_loss += criterion(outputs, targets).item() * inputs.size(0)
        epoch_val_loss = val_loss / len(val_loader.dataset)
        val_losses.append(epoch_val_loss)

        scheduler.step(epoch_val_loss)

        # Early stopping logic.
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            patience_counter = 0
            # state_dict() returns references to the live tensors, which later
            # optimizer steps modify in place. Snapshot detached copies so the
            # best weights really survive until they are restored below.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    # Restore the best model weights.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return train_losses, val_losses

def evaluate_model(model, criterion, test_loader, y_scaler):
    """
    Evaluate the trained model on test data and compute error metrics.

    The loss is computed on the scaled targets exactly as during training.
    All other metrics are computed after mapping predictions and targets back
    to the original scale with ``y_scaler``.

    Args:
        model (nn.Module): Trained model.
        criterion (nn.Module): Loss function.
        test_loader (DataLoader): DataLoader for test data.
        y_scaler (MinMaxScaler): Scaler for inverse-transforming outputs.

    Returns:
        tuple: (test_loss, MAE, R², MAPE, Max Error, predictions in original scale,
                true targets in original scale)
            - ``test_loss``: per-sample mean criterion value on the scaled targets.
            - MAPE: expressed in percent (0-100 scale).
    """
    model.eval()
    test_loss = 0.0
    preds = []
    trues = []
    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item() * inputs.size(0)
            preds.append(outputs.cpu().numpy())
            trues.append(targets.cpu().numpy())
    test_loss = test_loss / len(test_loader.dataset)
    preds = np.concatenate(preds).flatten()
    trues = np.concatenate(trues).flatten()

    # Inverse transform predictions and true targets to original scale.
    preds_orig = y_scaler.inverse_transform(preds.reshape(-1, 1)).flatten()
    trues_orig = y_scaler.inverse_transform(trues.reshape(-1, 1)).flatten()

    mae = mean_absolute_error(trues_orig, preds_orig)
    r2 = r2_score(trues_orig, preds_orig)
    # sklearn returns MAPE as a fraction (1.0 means 100%). Convert to percent so
    # it matches the '%' suffix and 'MAPE (%)' labels used when reporting it.
    mape = 100 * mean_absolute_percentage_error(trues_orig, preds_orig)
    max_err = max_error(trues_orig, preds_orig)

    return test_loss, mae, r2, mape, max_err, preds_orig, trues_orig

# =============================================================================
# 6. Main Training Loop for Model Comparison
# =============================================================================

# Loss shared by training and by the test-set "MSE" reported below. It is
# applied to the MinMax-scaled targets that generate_data puts in the loaders.
criterion = nn.MSELoss()
# Per-function store, filled by the loop below:
# results[func_name] = {'MLP': {...}, 'KAN': {...}, 'x_test': ..., 'y_test': ...}
results = {}

# Loop over all defined functions for comparison.
for func_name, func in functions.items():
    print(f"\nProcessing {func_name}...")

    # Generate normalized data from the current function.
    # x_test holds standardized inputs; y_test_orig is in the original target scale.
    train_loader, val_loader, test_loader, x_test, y_test_orig, y_scaler = generate_data(func)

    # Initialize models and print parameter counts.
    # MLP and KAN are defined outside this chunk and are built with their
    # default arguments. A fresh pair is created per function, so no weights carry over.
    mlp_model = MLP().to(device)
    kan_model = KAN().to(device)
    print("MLP model parameter count:", sum(p.numel() for p in mlp_model.parameters()))
    print("KAN model parameter count:", sum(p.numel() for p in kan_model.parameters()))

    # Set up optimizers.
    # Both use Adam with lr=1e-3; only the KAN optimizer adds weight decay.
    mlp_optimizer = optim.Adam(mlp_model.parameters(), lr=0.001)
    kan_optimizer = optim.Adam(kan_model.parameters(), lr=0.001, weight_decay=1e-4)

    # Train and evaluate the MLP model.
    print("Training MLP Model...")
    mlp_train_losses, mlp_val_losses = train_model(mlp_model, mlp_optimizer, criterion, train_loader, val_loader)
    mlp_test_loss, mlp_mae, mlp_r2, mlp_mape, mlp_max_err, mlp_preds, mlp_trues = evaluate_model(
        mlp_model, criterion, test_loader, y_scaler
    )

    # Train and evaluate the KAN model.
    print("Training KAN Model...")
    kan_train_losses, kan_val_losses = train_model(kan_model, kan_optimizer, criterion, train_loader, val_loader)
    kan_test_loss, kan_mae, kan_r2, kan_mape, kan_max_err, kan_preds, kan_trues = evaluate_model(
        kan_model, criterion, test_loader, y_scaler
    )

    # Store all metrics and predictions for later analysis.
    results[func_name] = {
        'MLP': {
            'train_losses': mlp_train_losses,
            'val_losses': mlp_val_losses,
            'test_loss': mlp_test_loss,
            'mae': mlp_mae,
            'r2': mlp_r2,
            'mape': mlp_mape,
            'max_error': mlp_max_err,
            'preds': mlp_preds,
            'trues': mlp_trues,
        },
        'KAN': {
            'train_losses': kan_train_losses,
            'val_losses': kan_val_losses,
            'test_loss': kan_test_loss,
            'mae': kan_mae,
            'r2': kan_r2,
            'mape': kan_mape,
            'max_error': kan_max_err,
            'preds': kan_preds,
            'trues': kan_trues,
        },
        'x_test': x_test,
        'y_test': y_test_orig,
    }

    # ---------------------------
    # Plot Training and Validation Losses
    # ---------------------------
    # Curves may have different lengths when early stopping ends one run sooner.
    plt.figure(figsize=(12, 6))
    plt.plot(mlp_train_losses, label='MLP Training Loss')
    plt.plot(mlp_val_losses, label='MLP Validation Loss')
    plt.plot(kan_train_losses, label='KAN Training Loss')
    plt.plot(kan_val_losses, label='KAN Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.title(f'Training and Validation Losses for {func_name}')
    plt.legend()
    plt.show()

    # ---------------------------
    # Print Test Metrics for Both Models
    # ---------------------------
    # NOTE(review): the '%' suffix on MAPE assumes evaluate_model returns it on a
    # 0-100 scale; sklearn's mean_absolute_percentage_error itself returns a
    # fraction, so confirm the scaling.
    print(f"MLP Test MSE Loss: {results[func_name]['MLP']['test_loss']:.6f}")
    print(f"MLP Test MAE: {results[func_name]['MLP']['mae']:.6f}")
    print(f"MLP Test R^2 Score: {results[func_name]['MLP']['r2']:.6f}")
    print(f"MLP Test MAPE: {results[func_name]['MLP']['mape']:.2f}%")
    print(f"MLP Test Max Error: {results[func_name]['MLP']['max_error']:.6f}")
    
    print(f"KAN Test MSE Loss: {results[func_name]['KAN']['test_loss']:.6f}")
    print(f"KAN Test MAE: {results[func_name]['KAN']['mae']:.6f}")
    print(f"KAN Test R^2 Score: {results[func_name]['KAN']['r2']:.6f}")
    print(f"KAN Test MAPE: {results[func_name]['KAN']['mape']:.2f}%")
    print(f"KAN Test Max Error: {results[func_name]['KAN']['max_error']:.6f}")

    # ---------------------------
    # Plot Predictions vs. True Function
    # ---------------------------
    # NOTE(review): x_test comes from generate_data, which standardizes the
    # inputs, so this x-axis shows standardized x rather than the raw domain.
    x_test = results[func_name]['x_test']
    mlp_preds = results[func_name]['MLP']['preds']
    kan_preds = results[func_name]['KAN']['preds']
    true_vals = results[func_name]['MLP']['trues']  # True targets are the same for both models.

    # Sort the test data for a cleaner plot.
    # The test split is shuffled, so line plots need ascending x.
    sorted_indices = np.argsort(x_test)
    x_test_sorted = x_test[sorted_indices]
    true_sorted = true_vals[sorted_indices]
    mlp_preds_sorted = mlp_preds[sorted_indices]
    kan_preds_sorted = kan_preds[sorted_indices]

    plt.figure(figsize=(12, 6))
    plt.plot(x_test_sorted, true_sorted, 'k-', label='True Function', linewidth=2)
    plt.plot(x_test_sorted, mlp_preds_sorted, 'r--', label='MLP Predictions', linewidth=2)
    plt.plot(x_test_sorted, kan_preds_sorted, 'b--', label='KAN Predictions', linewidth=2)
    plt.xlabel('x')
    plt.ylabel('f(x)')
    plt.title(f'Function Approximation for {func_name}')
    plt.legend()
    plt.show()

    # ---------------------------
    # Plot Residuals (Predicted - True)
    # ---------------------------
    # Scatter plots do not need sorted x, so the unsorted arrays are used directly.
    plt.figure(figsize=(12, 6))
    plt.scatter(x_test, mlp_preds - true_vals, alpha=0.5, label='MLP Residuals')
    plt.scatter(x_test, kan_preds - true_vals, alpha=0.5, label='KAN Residuals')
    plt.hlines(0, x_test.min(), x_test.max(), colors='k', linestyles='dashed')
    plt.xlabel('x')
    plt.ylabel('Residual')
    plt.title(f'Residuals for {func_name}')
    plt.legend()
    plt.show()

    # ---------------------------
    # Plot Histogram of Residuals
    # ---------------------------
    # Overlapping semi-transparent histograms, 30 bins each, one per model.
    plt.figure(figsize=(12, 6))
    plt.hist(mlp_preds - true_vals, bins=30, alpha=0.5, label='MLP Residuals')
    plt.hist(kan_preds - true_vals, bins=30, alpha=0.5, label='KAN Residuals')
    plt.xlabel('Residual')
    plt.ylabel('Frequency')
    plt.title(f'Residual Histogram for {func_name}')
    plt.legend()
    plt.show()

# =============================================================================
# 7. Summary Table of Metrics and Visualization
# =============================================================================

# Long-format rows: one row per (function, model) pair with its test metrics.
summary_data = []

# Create a summary table that collects error metrics for both models across functions.
for func_name in functions.keys():
    mlp_metrics = results[func_name]['MLP']
    kan_metrics = results[func_name]['KAN']

    # NOTE: 'MSE' is the criterion value computed on the MinMax-scaled targets,
    # whereas MAE, R² and Max Error are in the original target scale.
    # NOTE(review): 'MAPE (%)' assumes evaluate_model reports MAPE on a 0-100
    # scale; sklearn's mean_absolute_percentage_error returns a fraction — confirm.
    summary_data.append({
        'Function': func_name,
        'Model': 'MLP',
        'MSE': mlp_metrics['test_loss'],
        'MAE': mlp_metrics['mae'],
        'R²': mlp_metrics['r2'],
        'MAPE (%)': mlp_metrics['mape'],
        'Max Error': mlp_metrics['max_error']
    })

    summary_data.append({
        'Function': func_name,
        'Model': 'KAN',
        'MSE': kan_metrics['test_loss'],
        'MAE': kan_metrics['mae'],
        'R²': kan_metrics['r2'],
        'MAPE (%)': kan_metrics['mape'],
        'Max Error': kan_metrics['max_error']
    })

# Build the table, print it, and save it as CSV in the working directory.
summary_df = pd.DataFrame(summary_data)
print("\nSummary of Metrics:")
print(summary_df)
summary_df.to_csv('model_comparison_summary.csv', index=False)

# Grouped bar charts, one figure per metric. Functions run along the x-axis
# and the two models are distinguished by hue.
sns.set(style="whitegrid")
metrics = ['MSE', 'MAE', 'R²', 'MAPE (%)', 'Max Error']

for metric in metrics:
    plt.figure(figsize=(14, 7))
    sns.barplot(data=summary_df, x='Function', y=metric, hue='Model')
    plt.title(f'Comparison of {metric} between MLP and KAN')
    plt.xlabel('Function')
    plt.ylabel(metric)
    # Long function labels: rotate the ticks so they do not overlap.
    plt.xticks(rotation=45)
    plt.legend(title='Model')
    plt.tight_layout()
    plt.show()
