# Figure S6C and Figure S6D

- **Figure S6C**: Plots showing how wSE and wSVE change with x.
- **Figure S6D**: Plots showing how the cost changes as each parameter is varied around its fitted value.

**Dependencies**: Python 3.6+, numpy, scipy, matplotlib  
**Author**: Yingming Pei  
**Date**: 2025-05-14  
**Output**: Saves the figures as 'FigS6C.pdf' and 'FigS6D.pdf' (in the sensitivity-plotting code below, the `savefig` call is currently commented out).

In [None]:
import numpy as np
from scipy.optimize import differential_evolution
import matplotlib.pyplot as plt

# Configure matplotlib plotting settings
def configure_plot_style():
    """Apply the shared figure style to matplotlib's global rcParams.

    The style uses thick axes, lines and ticks, hides the top and right
    spines, and sets all text to 20 pt Arial. Because it mutates
    ``plt.rcParams``, it affects every figure created after the call.
    """
    style = {
        # Open-box axes with thick remaining spines and lines.
        "axes.spines.top": False,
        "axes.spines.right": False,
        "axes.linewidth": 2,
        "lines.linewidth": 2,
        # Long, thick major ticks; shorter minor ticks of the same width.
        "xtick.major.size": 10,
        "xtick.major.width": 2,
        "ytick.major.size": 10,
        "ytick.major.width": 2,
        "xtick.minor.size": 5,
        "xtick.minor.width": 2,
        "ytick.minor.size": 5,
        "ytick.minor.width": 2,
        # One uniform text size everywhere.
        "xtick.labelsize": 20,
        "ytick.labelsize": 20,
        "axes.labelsize": 20,
        "legend.fontsize": 20,
        "axes.titlesize": 20,
        "font.size": 20,
        "font.family": "Arial",
    }
    plt.rcParams.update(style)

# Define experimental data and parameters.
# Each rate array below has shape (2, 3). cost_function indexes it as
# [istate, icontrast] with istate in range(2) and icontrast in range(3);
# the columns presumably correspond to the `contrast` levels below — TODO confirm.
rEdatacontrol = np.array([[0.1012, 0.1032, 0.1254], [0.1824, 0.1939, 0.2282]])  # Pyr control firing rates
rEdataDART = np.array([[0.1190, 0.1304, 0.1593], [0.2021, 0.2235, 0.2665]])  # Pyr DART firing rates
rSdatacontrol = np.array([[0.0539, 0.0728, 0.1014], [0.1285, 0.1688, 0.2091]])  # SST control firing rates
rSdataDART = np.array([[0.0246, 0.0463, 0.0883], [0.1116, 0.1717, 0.2423]])  # SST DART firing rates
# Per-condition error scales for the Pyr / SST DART rates (presumably standard
# errors, judging by the "se" suffix — confirm). cost_function divides each
# DART residual by these before squaring.
rEseDART = np.array([[0.00742876, 0.00778821, 0.00931483],
                     [0.01152858, 0.01174181, 0.01370909]])
rSseDART = np.array([[0.00427268, 0.00634232, 0.01054626],
                     [0.01366939, 0.01577187, 0.02128831]])
# Gains of the quadratic transfer function used in cost_function (at steady
# state, rate = alpha * drive**2), scaled by the control rate at
# istate=1, icontrast=1.
alphaE = 0.25 / rEdatacontrol[1, 1] 
alphaS = 0.25 / rSdatacontrol[1, 1] 
contrast = [0.25, 0.5, 1.0]  # Contrast levels (not referenced by the code in this cell)
# Number of forward-Euler steps (step size 0.1) per simulation in cost_function;
# the second half of the trajectory is averaged to get steady-state rates.
total_duration = 1000  # Simulation length in Euler steps
# Search bounds, in the order unpacked by cost_function: (WEE, WES, WSE, WSVE, WSS, g).
bounds = [(-1, 1.5), (0, 2.5), (0, 10), (-4.2, 4.2), (-4.2, 4.2), (0, 2.4)]  # Parameter bounds
params_labels = ['wEE-wEE*', 'wES-wES*', 'wSE', 'wSE*', 'wSS*', 'g']  # Plot labels, same order as bounds

# Define cost function for optimization and sensitivity analysis
def cost_function(params, x, penalty=100):
    """Score a parameter set against the DART firing-rate data.

    For each of the 2 states and 3 contrast indices, the external inputs
    ``IE`` and ``IS`` are solved from the control rates. They are chosen so
    that the control rates are a fixed point of the model when ``x = 0``.
    The model is then integrated with forward Euler, starting from the DART
    rates, with the S-population weight ``WSE`` scaled by ``(1 - x)``. The
    mean rates over the second half of the run are compared with the DART
    data.

    Reads the module-level arrays ``rEdatacontrol``, ``rSdatacontrol``,
    ``rEdataDART``, ``rSdataDART``, ``rEseDART`` and ``rSseDART``, the gains
    ``alphaE`` and ``alphaS``, and ``total_duration``.

    Parameters
    ----------
    params : sequence of 6 floats
        ``(WEE, WES, WSE, WSVE, WSS, g)`` in the order of ``bounds``.
        ``g`` multiplies ``WSVE`` and ``WSS`` in state 1.
    x : float
        Fraction by which ``WSE`` is reduced in the simulated condition
        that is compared to the DART data.
    penalty : float, optional
        Value returned for rejected parameter sets. A set is rejected if it
        violates a weight constraint, if a rate reaches 10, or if the
        trajectory has not settled. Defaults to 100, the historical value.
        Note that a valid but poor fit can still cost more than this.

    Returns
    -------
    float
        Sum over both populations, states and contrasts of
        ``((simulated - DART) / error_scale) ** 2``, or ``penalty``.
    """
    WEE, WES, WSE, WSVE, WSS, g = params
    # State-dependent weights: state 0 uses the raw value, state 1 scales it by g.
    WSVE = [WSVE, WSVE * g]
    WSS = [WSS, WSS * g]
    # Constraint checks: the scaled weights must stay within +/-10, and
    # WSE - WSVE must be non-negative in both states.
    if np.abs(WSVE[1]) > 10 or np.abs(WSS[1]) > 10:
        return penalty
    if WSE < max(WSVE):
        return penalty
    half = total_duration // 2
    cost = 0.0
    for istate in range(2):
        for icontrast in range(3):
            rE_ctrl = rEdatacontrol[istate, icontrast]
            rS_ctrl = rSdatacontrol[istate, icontrast]
            # Solve rate = alpha * drive**2 for the external input. This makes
            # the control rates a fixed point of the dynamics at x = 0.
            IE = np.sqrt(rE_ctrl / alphaE) - WEE * rE_ctrl + WES * rS_ctrl
            IS = np.sqrt(rS_ctrl / alphaS) - \
                 (WSE - WSVE[istate]) * rE_ctrl - WSS[istate] * rS_ctrl
            rE = rEdataDART[istate, icontrast]
            rS = rSdataDART[istate, icontrast]
            rE_values = []
            rS_values = []
            for _ in range(total_duration):
                # Reject runaway trajectories once either rate reaches 10.
                if rE >= 10.0 or rS >= 10.0:
                    return penalty
                # NOTE(review): the drive is squared without rectification, so
                # a negative drive also yields a positive rate — confirm intended.
                drE = -rE + alphaE * (WEE * rE - WES * rS + IE) ** 2
                drS = -rS + alphaS * (((1.0 - x) * WSE - WSVE[istate]) * rE + WSS[istate] * rS + IS) ** 2
                # Forward-Euler step with dt = 0.1.
                rE += 0.1 * drE
                rS += 0.1 * drS
                rE_values.append(rE)
                rS_values.append(rS)
            # Reject trajectories that have not settled: the peak E rate in the
            # second half must not exceed the first-half peak by more than 1e-4.
            if max(rE_values[half:]) - max(rE_values[:half]) > 1e-4:
                return penalty
            rE_avg = np.mean(rE_values[half:])
            rS_avg = np.mean(rS_values[half:])
            cost += ((rE_avg - rEdataDART[istate, icontrast]) / rEseDART[istate, icontrast]) ** 2 + \
                    ((rS_avg - rSdataDART[istate, icontrast]) / rSseDART[istate, icontrast]) ** 2
    return cost

# Perform parameter optimization
def optimize_parameters(x, **de_kwargs):
    """Fit the six model parameters for a given x with differential evolution.

    Parameters
    ----------
    x : float
        Passed through to ``cost_function`` as its second argument.
    **de_kwargs
        Extra keyword arguments forwarded unchanged to
        ``scipy.optimize.differential_evolution``. For example, pass
        ``seed=0`` (or ``rng=`` on newer SciPy) to make a fit reproducible,
        or set ``maxiter``, ``popsize`` or ``workers``. With no extras, the
        call is identical to the previous, unseeded behavior.

    Returns
    -------
    list
        ``[x, best_cost, WEE, WES, WSE, WSVE, WSS, g]``: the input ``x``,
        the minimum cost found, then the optimal parameters in ``bounds``
        order.
    """
    result = differential_evolution(cost_function, bounds, args=(x,), **de_kwargs)
    return [x, result.fun, *result.x]

# Calculate sensitivity analysis results
def calculate_sensitivity(x_values, best_results, n_steps=1000, rel_step=0.02, verbose=True):
    """Run a one-at-a-time sensitivity scan of the cost around a fitted optimum.

    Each of the fitted parameters is perturbed in turn to
    ``best * (1 ± istep * rel_step)`` for ``istep = 1 .. n_steps - 1``. The
    other parameters stay at their optimal values, and the cost is recorded
    for each perturbation.

    Parameters
    ----------
    x_values : iterable of float
        Values of x at which the cost is evaluated.
    best_results : sequence
        Output of ``optimize_parameters``: ``[x_fit, best_cost, *best_params]``.
    n_steps : int, optional
        Exclusive upper bound of the step counter (default 1000).
    rel_step : float, optional
        Step size relative to the optimal parameter value (default 0.02).
    verbose : bool, optional
        Print every evaluation (default True).

    Returns
    -------
    numpy.ndarray
        One row per evaluation: ``[x, param_index, param_value, cost, *params]``.
        For each ``(x, param_index)`` the first row is the unperturbed optimum.
    """
    x_fit, best_cost = best_results[0], best_results[1]
    pre_best = list(best_results[2:])
    new_results = []
    for x in x_values:
        # best_cost was computed at x_fit. When scanning a different x,
        # re-evaluate so the centre row reports the cost at this x.
        center_cost = best_cost if x == x_fit else cost_function(pre_best, x)
        for iw in range(len(pre_best)):
            new_results.append([x, iw, pre_best[iw], center_cost, *pre_best])
            # NOTE: steps are relative, so a parameter whose optimum is exactly
            # 0 is never actually perturbed.
            for istep in range(1, n_steps):
                for direction in (1, -1):
                    current_params = list(pre_best)
                    current_weight = pre_best[iw] + direction * pre_best[iw] * istep * rel_step
                    current_params[iw] = current_weight
                    cost = cost_function(current_params, x)
                    new_results.append([x, iw, current_weight, cost, *current_params])
                    if verbose:
                        print(f"x={x:.2f}, param={params_labels[iw]}, weight={current_weight:.4f}, cost={cost:.4f}")
    return np.array(new_results)

# Visualize sensitivity analysis results
def visualize_sensitivity(x_values, sa_results, params_labels):
    """Plot cost versus each parameter from a one-at-a-time sensitivity scan.

    One figure is drawn per ``(x, parameter)`` pair. The black curve is the
    cost as that parameter varies; the blue marker is the unperturbed
    optimum.

    Parameters
    ----------
    x_values : iterable of float
        x values to plot. Each must match column 0 of ``sa_results`` exactly.
    sa_results : array-like
        Rows ``[x, param_index, param_value, cost, *params]`` as returned by
        ``calculate_sensitivity``.
    params_labels : sequence of str
        Title and axis label for each parameter index.
    """
    configure_plot_style()
    bests = np.asarray(sa_results)
    for x in x_values:
        for iw, label in enumerate(params_labels):
            state_all = bests[(bests[:, 0] == x) & (bests[:, 1] == iw)]
            if len(state_all) == 0:
                continue  # nothing recorded for this (x, parameter) pair
            plt.figure(figsize=(8, 4))
            order = np.argsort(state_all[:, 2])
            plt.plot(state_all[order, 2], state_all[order, 3], label='All', color='black')
            # calculate_sensitivity appends the unperturbed optimum first for each
            # (x, parameter) pair, and boolean masking preserves row order. Row 0
            # is therefore the optimum, whatever the number of scan steps.
            plt.scatter(state_all[0, 2], state_all[0, 3], color='blue', label='Best All')
            plt.title(f'Cost Change with {label} when x = {x}')
            plt.xlabel(label)
            plt.ylabel('Cost')
            # Optional figure-specific axis formatting (disabled):
            # plt.xlim(1.5, 2.5)
            # plt.ylim(0, 10)
            # plt.xticks([1.5, 2, 2.5])
            # plt.yticks([0, 5, 10])
            plt.legend()
            # NOTE(review): this filename says FigS7D while the notebook header
            # lists FigS6D — confirm before re-enabling.
            # plt.savefig(f'Sensitivity_{params_labels[iw]}_x{x}_FigS7D.pdf', format='pdf', dpi=400, bbox_inches='tight')
            plt.show()


# Fit the model at the single x of interest, then scan and plot the cost
# landscape around that fit. Deriving x from x_values keeps the two in sync.
x_values = [0.5]
x = x_values[0]
best_results = optimize_parameters(x)
sa_results = calculate_sensitivity(x_values, best_results)
visualize_sensitivity(x_values, sa_results, params_labels)