In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
import time

# Natural Evolution Strategies (NES) Implementation
class NaturalEvolutionStrategy:
    """Evolution strategy that searches a flat parameter vector.

    Candidates are drawn from a Gaussian with mean ``self.mean`` and a single
    scalar standard deviation ``self.sigma`` shared by every dimension. Each
    generation the mean is moved along a Fisher-preconditioned gradient
    estimate and sigma is rescaled multiplicatively.

    Attributes written by the class:
        mean: current centre of the search distribution.
        sigma: current standard deviation.
        loss_history: ``-mean(fitness)`` of each generation of the last run.
        time_history: seconds elapsed since the start of the last run,
            recorded at the end of each generation.
    """

    def __init__(self, network, population_size, learning_rate, sigma_init, seed=None):
        """Set up the search distribution around the network's parameters.

        Args:
            network: object exposing ``get_parameters()``; the returned
                sequence becomes the initial mean.
            population_size: number of candidates sampled per generation.
            learning_rate: step size used for both the mean and sigma updates.
            sigma_init: initial standard deviation of the search distribution.
            seed: seed forwarded to ``np.random.default_rng``.
        """
        self.network = network
        self.population_size = population_size
        self.learning_rate = learning_rate
        self.sigma = sigma_init
        # No copy is taken: this is the very object the network returned, and
        # optimize() updates it in place with ``+=``.
        self.mean = self.network.get_parameters()
        self.dim = len(self.mean)
        self.rng = np.random.default_rng(seed)
        # Defined up front so the histories are readable before optimize()
        # has been called.
        self.loss_history = []
        self.time_history = []

    def fitness_shaping(self, rewards):
        """Turn raw rewards into rank-based, zero-mean utilities.

        The highest reward receives rank 0. Each rank is mapped to
        ``max(0, log(population_size / 2 + 1) - log(rank + 1))`` and the
        result is centred by subtracting its mean.

        Args:
            rewards: 1-D array of rewards, one per candidate.

        Returns:
            Array of utilities with the same length as ``rewards``.
        """
        # argsort of argsort yields each element's rank; negating the rewards
        # makes the largest reward rank 0.
        ranks = np.argsort(np.argsort(-rewards))
        utilities = np.maximum(0, np.log(self.population_size / 2 + 1) - np.log(ranks + 1))
        utilities -= np.mean(utilities)
        return utilities

    def optimize(self, fitness_function, generations, sigma_min=1e-5, sigma_max=10.0):
        """Run the optimisation loop, updating ``mean`` and ``sigma`` in place.

        ``loss_history`` and ``time_history`` are reset at the start of every
        call and receive one entry per generation.

        Args:
            fitness_function: callable mapping one parameter vector to a
                scalar fitness (higher is better).
            generations: number of generations to run.
            sigma_min: lower clip bound applied to sigma after each update.
            sigma_max: upper clip bound applied to sigma after each update.
        """
        self.loss_history = []
        self.time_history = []
        # perf_counter is monotonic, so elapsed times cannot go backwards if
        # the system clock is adjusted during a long run.
        start_time = time.perf_counter()

        # The ridge term added to the Fisher estimate never changes, so it is
        # built once instead of once per generation.
        ridge = np.eye(self.dim) * 1e-5

        for _ in range(generations):
            samples = self.rng.normal(self.mean, self.sigma, (self.population_size, self.dim))
            fitness = np.array([fitness_function(sample) for sample in samples])
            utilities = self.fitness_shaping(fitness)

            # Offsets of every candidate from the current mean, shared by both
            # log-derivative terms below.
            deviations = samples - self.mean
            log_derivative_mean = deviations / self.sigma**2
            # NOTE(review): the exact derivative of the isotropic Gaussian
            # log-density w.r.t. sigma is (||x - mean||^2 / sigma^2 - dim) / sigma;
            # this expression has no 1/sigma^2 on the squared norm. Left as is
            # because changing it alters the step-size dynamics - confirm intent.
            log_derivative_sigma = (np.linalg.norm(deviations, axis=1)**2 - self.dim) / self.sigma
            grad_mean = np.dot(utilities, log_derivative_mean) / self.population_size
            grad_sigma = np.dot(utilities, log_derivative_sigma) / self.population_size

            # Empirical Fisher matrix of the mean parameters; the ridge keeps
            # the linear solve well-posed when it is rank deficient.
            fisher = np.dot(log_derivative_mean.T, log_derivative_mean) / self.population_size
            reg_fisher = fisher + ridge

            self.mean += self.learning_rate * np.linalg.solve(reg_fisher, grad_mean)
            # Multiplicative update keeps sigma positive; the clip bounds it.
            self.sigma *= np.exp(self.learning_rate / 2 * grad_sigma)
            self.sigma = np.clip(self.sigma, sigma_min, sigma_max)

            self.loss_history.append(-np.mean(fitness))
            self.time_history.append(time.perf_counter() - start_time)

    def get_parameters(self):
        """Return the current mean of the search distribution."""
        return self.mean

# Neural Network with Support for Multiple Layers
class NeuralNetwork:
    """Fully connected network whose weights and biases share one flat vector.

    For every consecutive pair of entries in ``layer_sizes`` the vector holds
    the layer's weights (``input_size * output_size`` values) followed by its
    biases (``output_size`` values). ``tanh`` is applied after every layer
    except the last one.
    """

    def __init__(self, layer_sizes):
        """Store the layer widths and draw the initial parameters."""
        self.layer_sizes = layer_sizes
        self.params = self.initialize_parameters()

    def _layer_shapes(self):
        """Yield ``(input_size, output_size)`` for each consecutive layer pair."""
        return zip(self.layer_sizes[:-1], self.layer_sizes[1:])

    def initialize_parameters(self):
        """Return a fresh flat parameter vector of scaled standard-normal draws."""
        chunks = []
        for fan_in, fan_out in self._layer_shapes():
            # Weights are drawn before biases so the layout matches predict().
            layer_weights = np.random.randn(fan_in * fan_out) * 0.1
            layer_biases = np.random.randn(fan_out) * 0.1
            chunks.extend((layer_weights, layer_biases))
        return np.concatenate(chunks)

    def get_parameters(self):
        """Return the flat parameter vector currently held by the network."""
        return self.params

    def set_parameters(self, params):
        """Replace the flat parameter vector; the object is stored as given."""
        self.params = params

    def predict(self, X):
        """Run a forward pass on ``X`` and return the output of the last layer."""
        flat = self.params
        final_layer = len(self.layer_sizes) - 2
        offset = 0
        activations = X
        for layer, (fan_in, fan_out) in enumerate(self._layer_shapes()):
            # Carve this layer's weight matrix, then its bias vector, out of
            # the flat parameter array.
            n_weights = fan_in * fan_out
            weights = flat[offset:offset + n_weights].reshape(fan_in, fan_out)
            offset += n_weights
            biases = flat[offset:offset + fan_out]
            offset += fan_out

            activations = np.dot(activations, weights) + biases
            # The output layer stays linear; hidden layers use tanh.
            if layer != final_layer:
                activations = np.tanh(activations)
        return activations

    def evaluate_loss(self, X, y):
        """Return the mean squared difference between ``predict(X)`` and ``y``."""
        residuals = self.predict(X) - y
        return np.mean(residuals ** 2)

# Benchmarking Configurations
def evaluate_nn_configuration(layer_sizes, population_size, generations=10000):
    """Train one network with NES on a 2-D toy regression task and evaluate it.

    The target is ``sin(pi * x1) + cos(pi * x2)``. Training uses 100 points
    drawn uniformly from ``[-1, 1]^2``; evaluation uses a regular 50 x 50 grid
    over the same square.

    Args:
        layer_sizes: layer widths passed to ``NeuralNetwork``.
        population_size: number of candidates per NES generation.
        generations: number of NES generations to run.

    Returns:
        Dict with the keys ``loss_history``, ``time_history``, ``total_time``,
        ``y_true``, ``y_pred`` (both flattened over the grid), ``errors``
        (absolute error reshaped to the grid), ``configuration`` (label
        string) and ``X_grid`` (the evaluation points).
    """
    def target(points):
        # Ground-truth function, used for both the training data and the grid.
        return np.sin(np.pi * points[:, 0]) + np.cos(np.pi * points[:, 1])

    X = np.random.uniform(-1, 1, (100, 2))
    # Column vector so it lines up with the network's (n, 1) output.
    y = target(X).reshape(-1, 1)

    network = NeuralNetwork(layer_sizes=layer_sizes)

    def fitness_function(params):
        network.set_parameters(params)
        return -network.evaluate_loss(X, y)

    nes = NaturalEvolutionStrategy(network, population_size, learning_rate=0.1, sigma_init=0.5, seed=42)
    nes.optimize(fitness_function, generations=generations)

    # fitness_function overwrites the network's parameters with every sampled
    # candidate, so after optimize() the network holds the last candidate
    # evaluated. Load the optimised mean before predicting.
    network.set_parameters(nes.get_parameters())

    grid_size = 50
    x1 = np.linspace(-1, 1, grid_size)
    x2 = np.linspace(-1, 1, grid_size)
    X1, X2 = np.meshgrid(x1, x2)
    X_grid = np.c_[X1.ravel(), X2.ravel()]
    y_true = target(X_grid)
    y_pred = network.predict(X_grid).flatten()
    errors = np.abs(y_true - y_pred).reshape(grid_size, grid_size)

    return {
        "loss_history": nes.loss_history,
        "time_history": nes.time_history,
        # Total time for all generations; 0.0 when no generation was run.
        "total_time": nes.time_history[-1] if nes.time_history else 0.0,
        "y_true": y_true,
        "y_pred": y_pred,
        "errors": errors,
        "configuration": f"Layers {layer_sizes}, Pop {population_size}",
        "X_grid": X_grid
    }


# Configurations to Compare
# Two architectures crossed with two population sizes, ordered by architecture
# first and population size second. Each entry gets its own layer list.
configurations = [
    {"layer_sizes": list(layers), "population_size": population}
    for layers in ([2, 10, 1], [2, 10, 20, 10, 1])
    for population in (50, 500)
]


# Plotting Results with Improvements
def _comparison_grid(n_panels, n_cols=2):
    """Create a figure with one subplot per panel, ``n_cols`` panels per row.

    Returns the figure and a 1-D array holding exactly ``n_panels`` axes;
    grid cells beyond that are hidden.
    """
    n_rows = (n_panels + n_cols - 1) // n_cols
    # squeeze=False keeps ``axes`` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 6 * n_rows), squeeze=False)
    panels = axes.ravel()
    for spare in panels[n_panels:]:
        spare.set_visible(False)
    return fig, panels[:n_panels]


def plot_comparisons(results):
    """Draw the loss curves, prediction scatter plots and error heatmaps.

    Three figures are shown: MSE against generation (log-scaled x axis), a
    grid of true-vs-predicted scatter plots, and a grid of absolute-error
    heatmaps that share one colour scale and one colour bar. The grids hold
    two panels per row and grow with the number of results.

    Args:
        results: sequence of dicts providing the keys ``loss_history``,
            ``total_time``, ``configuration``, ``X_grid``, ``y_true``,
            ``y_pred`` and ``errors``. Nothing is drawn when it is empty.
    """
    if not results:
        return

    # Plot MSE vs Generations
    plt.figure(figsize=(12, 6))
    for result in results:
        history = result["loss_history"]
        plt.plot(range(1, len(history) + 1), history,
                 label=f"{result['configuration']} (Total Time: {result['total_time']:.2f}s)")
    plt.xscale("log")
    plt.xlabel("Generations (Log Scale)")
    plt.ylabel("MSE Loss")
    plt.title("MSE Loss vs Generations for Different Configurations")
    plt.legend()
    plt.show()

    # One panel per configuration for "True vs Predicted Output"
    fig, panels = _comparison_grid(len(results))
    for ax, result in zip(panels, results):
        X_grid = result["X_grid"]

        ax.scatter(X_grid[:, 0], result["y_true"], alpha=0.5, label="True Function (X2 slices)")
        ax.scatter(X_grid[:, 0], result["y_pred"], alpha=0.5, label="NN Predictions (X2 slices)")
        ax.set_title(result["configuration"])
        ax.set_xlabel("Input Feature (X1)")
        ax.set_ylabel("Output (y)")
        ax.legend()

    fig.suptitle("True vs Predicted Output for All Configurations", fontsize=16)
    plt.tight_layout()
    plt.show()

    # One heatmap per configuration, all on a shared colour scale
    fig, panels = _comparison_grid(len(results))
    # Global min/max for colour normalisation, taken per map so the error
    # arrays are not required to share a shape.
    min_error = min(np.min(result["errors"]) for result in results)
    max_error = max(np.max(result["errors"]) for result in results)

    for ax, result in zip(panels, results):
        im = ax.imshow(result["errors"], extent=(-1, 1, -1, 1), origin="lower", cmap="plasma",
                       vmin=min_error, vmax=max_error, aspect="auto")
        ax.set_title(result["configuration"])
        ax.set_xlabel("X1")
        ax.set_ylabel("X2")

    # Add a single color bar for all subplots; every image uses the same
    # vmin/vmax, so the last one can stand in for all of them.
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    fig.colorbar(im, cax=cbar_ax, label="Absolute Error")

    fig.suptitle("Prediction Errors Heatmap for All Configurations", fontsize=16)
    plt.tight_layout(rect=[0, 0, 0.9, 1])  # Adjust layout to accommodate color bar
    plt.show()


# Run all configurations and generate plots
# Benchmark each configuration in order, then visualise all of them together.
results = [
    evaluate_nn_configuration(**settings)
    for settings in configurations
]
plot_comparisons(results)

KeyboardInterrupt: 