# In [3]:
from simulation_Pyr import simulation_Pyr
from simulation_PV import simulation_PV
from helper import firing_rate
import ray

# ``np`` is needed right here; the cell's other ``import numpy as np`` appears
# further down, after this code has already run.
import numpy as np

# Initialize Ray with multiple workers
ray.init(ignore_reinit_error=True, num_cpus=2)

total_time = 500  # total simulation time in ms

# Candidate values for each stimulation parameter, 10 per axis.
# NOTE(review): freq1 spans negative values; confirm that is intended
# (e.g. an exponent / log-scale parameter rather than a raw frequency).
amp1_range = np.linspace(-5, 5, 10)
amp2_range = np.linspace(0, 10, 10)
freq1_range = np.linspace(-3, 3, 10)
freq2_range = np.linspace(1, 5, 10)

# Full factorial search space: one row per (amp1, amp2, freq1, freq2)
# combination, i.e. 10**4 rows of 4 columns in that column order.
X1, X2, X3, X4 = np.meshgrid(amp1_range, amp2_range, freq1_range, freq2_range, indexing='ij')
grid_points = np.column_stack((X1.flatten(), X2.flatten(), X3.flatten(), X4.flatten()))
n_grid_points = len(grid_points)

def simulate_and_evaluate(x):
    """Run the Pyr and PV simulations for one parameter vector and score it.

    Parameters
    ----------
    x : sequence of 4 numbers
        ``(amp1, amp2, freq1, freq2)``, in the same column order as the
        rows of ``grid_points``.

    Returns
    -------
    tuple
        ``(FR_Pyr, FR_PV, FR_Pyr - FR_PV)``: the value of ``firing_rate`` for
        each simulated response and their difference (the objective that
        ``BO_loop`` maximises).
    """
    amp1, amp2, freq1, freq2 = x
    # Both simulations receive exactly the same stimulation settings.
    stim_kwargs = dict(
        num_electrode=1,
        amp1=amp1, amp2=amp2, freq1=freq1, freq2=freq2,
        total_time=total_time,
        plot_waveform=False,  # Set to True to plot injected current
    )
    # Submit both remote calls before blocking on ``ray.get`` so they can run
    # concurrently. Each is unpacked as (response, time); the time is unused.
    futures = [simulation_Pyr.remote(**stim_kwargs), simulation_PV.remote(**stim_kwargs)]
    (response_Pyr, _), (response_PV, _) = ray.get(futures)
    # Use the same duration the simulations were run for rather than a second,
    # hard-coded copy of it. NOTE(review): assumes firing_rate's second argument
    # is the simulated duration in ms; confirm in helper.firing_rate.
    FR_Pyr = firing_rate(response_Pyr, total_time)
    FR_PV = firing_rate(response_PV, total_time)
    return FR_Pyr, FR_PV, FR_Pyr - FR_PV

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF
from sklearn.preprocessing import StandardScaler
from numpy.random import seed

def sample_initial_points(n_samples=10, points=None):
    """Draw the random initial design for Bayesian optimisation.

    Parameters
    ----------
    n_samples : int, default 10
        Number of distinct rows to draw.
    points : ndarray of shape (n_points, n_dims), optional
        Candidate set to sample from. Defaults to the module-level
        ``grid_points``.

    Returns
    -------
    ndarray of shape (n_samples, n_dims)
        Rows of ``points`` chosen uniformly at random, without replacement.

    Raises
    ------
    ValueError
        If ``n_samples`` exceeds the number of candidate rows.
    """
    if points is None:
        points = grid_points
    # Without replacement: each evaluation is an expensive simulation, so
    # never spend two of the initial budget on the same grid point.
    idx = np.random.choice(len(points), size=n_samples, replace=False)
    return points[idx]

def acquisition_function(mean, std, best_f):
    """Expected Improvement (EI) acquisition for maximisation.

    EI = (mean - best_f) * Phi(Z) + std * phi(Z),  with  Z = (mean - best_f) / std,
    where Phi / phi are the standard normal CDF / PDF.

    Parameters
    ----------
    mean, std : array_like
        GP posterior mean and standard deviation at the candidate points
        (same shape; any shape is preserved in the output).
    best_f : float
        Best objective value observed so far.

    Returns
    -------
    ndarray
        EI at every candidate. Points with ``std == 0`` get 0.
    """
    mean = np.asarray(mean, dtype=float)
    std = np.asarray(std, dtype=float)
    # std == 0 makes Z inf/nan; silence the warnings and zero those entries below.
    with np.errstate(divide='ignore', invalid='ignore'):
        imp = mean - best_f
        Z = imp / std
        ei = imp * norm.cdf(Z) + std * norm.pdf(Z)
    ei[std == 0.0] = 0.0
    return ei

def BO_loop(opt_trials=10):
    """Maximise ``FR_Pyr - FR_PV`` over ``grid_points`` with Bayesian optimisation.

    Evaluates the random initial design from ``sample_initial_points``. Then,
    for ``opt_trials`` iterations, it fits a Gaussian-process surrogate (RBF
    kernel) to all observations so far, scores every grid point with
    ``acquisition_function`` and simulates the highest-scoring one. After the
    last iteration the state is plotted with ``plot_bo_iteration``.

    Parameters
    ----------
    opt_trials : int, default 10
        Number of acquisition-driven evaluations after the initial design.

    Returns
    -------
    X_data : ndarray, shape (n_init + opt_trials, n_dims)
        Evaluated parameter vectors, in evaluation order.
    y_data : ndarray, shape (n_init + opt_trials, 2)
        ``(FR_Pyr, FR_PV)`` for each row of ``X_data``.
    obj_data : ndarray, shape (n_init + opt_trials,)
        Objective value for each row of ``X_data``.
    best_observed : ndarray, shape (opt_trials + 1,)
        Best objective seen after the initial design and after each iteration.
    """
    X_init = sample_initial_points()
    n_init, n_dims = X_init.shape
    total_points = n_init + opt_trials

    # Preallocate storage for every evaluation (initial design + iterations).
    X_data = np.zeros((total_points, n_dims))
    y_data = np.zeros((total_points, 2))
    obj_data = np.zeros(total_points)

    # Evaluate the initial design, storing whole rows per point.
    for i, x in enumerate(X_init):
        FR_Pyr, FR_PV, obj_val = simulate_and_evaluate(x)
        X_data[i] = x
        y_data[i] = (FR_Pyr, FR_PV)
        obj_data[i] = obj_val
    n_points = n_init

    kernel = RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e3))
    gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=5, normalize_y=True, alpha=1e-6)
    best_observed = np.zeros(opt_trials + 1)
    current_best = np.max(obj_data[:n_points])
    best_observed[0] = current_best
    X_search = grid_points
    for i in range(opt_trials):
        # Refit on only the rows filled so far (the rest are still zeros).
        gp.fit(X_data[:n_points], obj_data[:n_points])
        mean, std = gp.predict(X_search, return_std=True)
        ei_values = acquisition_function(mean, std, current_best)
        next_idx = np.argmax(ei_values)
        next_point = X_search[next_idx]
        FR_Pyr, FR_PV, next_value = simulate_and_evaluate(next_point)
        X_data[n_points] = next_point
        y_data[n_points] = (FR_Pyr, FR_PV)
        obj_data[n_points] = next_value
        n_points += 1
        # The incumbent must be updated, otherwise EI keeps comparing
        # against the initial best and best_observed never changes.
        current_best = max(current_best, next_value)
        best_observed[i + 1] = current_best
        if i == opt_trials - 1:
            plot_bo_iteration(X_data[:n_points], obj_data[:n_points], gp, X_search, ei_values, next_point, i)
    return X_data, y_data, obj_data, best_observed


def plot_bo_iteration(X_data, y_data, gp, X_search, ei_values, next_point, iteration,
                      dim=0, n_init=10):
    """Plot a 1-D slice of the GP surrogate through ``next_point`` and save it.

    The inputs are multi-dimensional, so the surrogate is drawn along
    coordinate ``dim`` (spanning that coordinate's range in ``X_search``)
    with all other coordinates held at ``next_point``. Observations are
    shown projected onto that coordinate. They generally do not lie on the
    slice, so their vertical distance from the mean curve is not a residual.
    The true objective is not drawn because each evaluation is an expensive
    simulation.

    Parameters
    ----------
    X_data : ndarray, shape (n_obs, n_dims)
        Evaluated inputs; the first ``n_init`` rows are the initial design.
    y_data : array_like, shape (n_obs,)
        Objective values for ``X_data``.
    gp : fitted sklearn ``GaussianProcessRegressor``
        Surrogate; must accept ``(m, n_dims)`` inputs.
    X_search : ndarray, shape (n_candidates, n_dims)
        Candidate set; used only for the slice range.
    ei_values : array_like
        Acquisition values. Accepted for interface compatibility; not plotted.
    next_point : array_like, shape (n_dims,)
        Point chosen by the acquisition; the slice passes through it.
    iteration : int
        Zero-based iteration index; the figure is saved as
        ``bo_iteration_{iteration+1}.png``.
    dim : int, default 0
        Coordinate to vary along the x-axis.
    n_init : int, default 10
        Number of leading rows of ``X_data`` that came from the initial design.
    """
    X_data = np.asarray(X_data, dtype=float)
    y_data = np.asarray(y_data, dtype=float).ravel()
    X_search = np.asarray(X_search, dtype=float)
    next_point = np.asarray(next_point, dtype=float)

    # Build the slice: copies of next_point with only column `dim` varied.
    n_plot = 200
    X_plot = np.tile(next_point, (n_plot, 1))
    X_plot[:, dim] = np.linspace(X_search[:, dim].min(), X_search[:, dim].max(), n_plot)
    mean, std = gp.predict(X_plot, return_std=True)
    xs = X_plot[:, dim]

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(xs, mean, 'b-', label='GP Mean')
    ax.fill_between(xs, mean - 2 * std, mean + 2 * std, alpha=0.2, color='b', label='GP Mean ± 2 std')
    ax.scatter(X_data[:n_init, dim], y_data[:n_init], c='r', marker='o', label='Init Observations')
    ax.scatter(X_data[n_init:, dim], y_data[n_init:], c='g', marker='o', label='Iterated Observations')
    ax.axvline(next_point[dim], color='k', linestyle='--', label='Next Point')
    ax.set_xlabel(f'x[{dim}]')
    ax.set_ylabel('f(x)')
    ax.grid(True)
    ax.legend()
    fig.tight_layout()
    fig.savefig(f'bo_iteration_{iteration+1}.png')
    plt.show()

# Guarded so that importing this module does not launch the (expensive)
# optimisation run; inside a notebook __name__ == "__main__", so it still runs.
if __name__ == "__main__":
    BO_loop()

# Output of the last run of this cell: ModuleNotFoundError: No module named 'simulation_Pyr'