### 1 Fitting 

- Fitting with experimental data of DART group
- Only V-gain is free to change between states.
- Target points are six per state: three rEs and three rSs in the DART group (twelve in total over the two states)
- Parameters are five weights in stationary state and a gain of V

    wEE_star, wES_star, wSE, wSE_star, wSS_star, g
- External inputs are calculated from the data of the control group

    - rE(l,m,h) = gain_E*(wEE_star * rE(l,m,h) - wES_star * rS(l,m,h) + IE(l,m,h))^2
    - rS(l,m,h) = gain_S*((wSE - wSE_star) * rE(l,m,h) + wSS_star * rS(l,m,h) + IS(l,m,h))^2

In [None]:
import numpy as np
from scipy.optimize import minimize
from scipy.optimize import differential_evolution

# Experimental data. Every array is indexed as [state, contrast]:
# two states (rows) by three contrasts (columns).
rEdatacontrol = np.array([
    [0.1012, 0.1032, 0.1254],
    [0.1824, 0.1939, 0.2282],
])
rEdataDART = np.array([
    [0.1190, 0.1304, 0.1593],
    [0.2021, 0.2235, 0.2665],
])
rSdatacontrol = np.array([
    [0.0539, 0.0728, 0.1014],
    [0.1285, 0.1688, 0.2091],
])
rSdataDART = np.array([
    [0.0246, 0.0463, 0.0883],
    [0.1116, 0.1717, 0.2423],
])

# Per-point uncertainties of the DART rates (presumably standard errors);
# cost_function divides each residual by the matching entry.
rEseDART = np.array([
    [0.007428755604043545, 0.007788206032446528, 0.009314834901506188],
    [0.01152857600761969, 0.011741811713686944, 0.013709086399540904],
])
rSseDART = np.array([
    [0.004272683280978827, 0.006342322667567268, 0.010546259821164696],
    [0.013669393588839406, 0.015771868172978647, 0.021288306730872463],
])

# Gains of the quadratic transfer functions, normalised by the control rate
# at state index 1, contrast index 1.
alphaE = 0.25 / rEdatacontrol[1, 1]
alphaS = 0.25 / rSdatacontrol[1, 1]

# Stimulus contrasts, one per column of the arrays above.
contrast = [0.25, 0.5, 1.0]

def cost_function(params, x):
    """Return the uncertainty-weighted squared misfit to the DART rates.

    For each of the two states and three contrasts, the external inputs
    ``IE`` and ``IS`` are chosen so that the control rates are a fixed point
    of the model with unscaled ``WSE``.  The rate equations are then
    integrated with forward-Euler steps of size 0.1, with ``WSE`` scaled by
    ``(1 - x)`` and the DART rates as the initial condition.  The mean over
    the second half of the run is compared with the DART data.

    Reads the module-level names ``rEdatacontrol``, ``rSdatacontrol``,
    ``rEdataDART``, ``rSdataDART``, ``rEseDART``, ``rSseDART``, ``alphaE``,
    ``alphaS`` and ``total_duration``.

    Parameters
    ----------
    params : sequence of six floats
        ``WEE, WES, WSE, WSVE, WSS, g``.  ``g`` multiplies ``WSVE`` and
        ``WSS`` in state index 1 only; state index 0 uses them unscaled.
    x : float
        Fraction by which ``WSE`` is reduced during the simulation.

    Returns
    -------
    float or int
        Sum over all conditions of ``((model - data) / se) ** 2`` for rE and
        rS, or the flat penalty ``100`` when the parameters are rejected or a
        rate reaches 10 during the simulation.

    Raises
    ------
    ValueError
        If ``total_duration`` is not positive (nothing to average).
    """
    WEE, WES, WSE, WSVE, WSS, g = params
    # Index 0 holds the unscaled weight, index 1 the weight scaled by g.
    WSVE = [WSVE, WSVE * g]
    WSS = [WSS, WSS * g]

    # Reject parameter sets whose scaled weights leave the admissible range.
    if np.abs(WSVE[1]) > 10 or np.abs(WSS[1]) > 10:
        return 100

    # Reject sets where WSE is smaller than either WSVE value.
    if WSE < max(WSVE):
        return 100

    if total_duration <= 0:
        raise ValueError("total_duration must be positive")

    # Samples before this index are treated as transient and discarded.
    settle = total_duration // 2
    cost = 0.0

    for istate in range(2):

        for icontrast in range(3):
            # Inputs that make the control rates a fixed point of the model.
            IE = np.sqrt(rEdatacontrol[istate, icontrast] / alphaE) - WEE * rEdatacontrol[istate, icontrast] + WES * rSdatacontrol[istate, icontrast]
            IS = np.sqrt(rSdatacontrol[istate, icontrast] / alphaS) - (WSE - WSVE[istate]) * rEdatacontrol[istate, icontrast] - WSS[istate] * rSdatacontrol[istate, icontrast]
            rE = rEdataDART[istate, icontrast]
            rS = rSdataDART[istate, icontrast]
            rE_values = []
            rS_values = []

            for _ in range(total_duration):
                # A rate of 10 or more is treated as a diverging run.
                if rE >= 10.0 or rS >= 10.0:
                    return 100
                drE = -rE + alphaE * (WEE * rE - WES * rS + IE) ** 2
                drS = -rS + alphaS * (((1.0 - x) * WSE - WSVE[istate]) * rE + WSS[istate] * rS + IS) ** 2
                rE += 0.1 * drE
                rS += 0.1 * drS
                rE_values.append(rE)
                rS_values.append(rS)

            # Divide by the real number of averaged samples so that an odd
            # total_duration does not bias the mean.
            rE_tail = rE_values[settle:]
            rS_tail = rS_values[settle:]
            rE_avg = sum(rE_tail) / len(rE_tail)
            rS_avg = sum(rS_tail) / len(rS_tail)

            cost += ((rE_avg - rEdataDART[istate, icontrast])/rEseDART[istate, icontrast]) ** 2 + \
                ((rS_avg - rSdataDART[istate, icontrast])/rSseDART[istate, icontrast]) ** 2

    return cost

# Number of Euler steps per simulated condition (read by cost_function).
total_duration = 1000

# Box constraints, in the parameter order WEE, WES, WSE, WSVE, WSS, g.
bounds = [(-1, 1.5), (0, 2.5), (0, 10), (-4.2, 4.2), (-4.2, 4.2), (0, 2.4)]

# Fraction by which WSE is reduced in the simulated condition.
x = 0.5
initial_guess = [0, 0, 0, 0, 0, 1]

# Extra arguments are passed as an explicit tuple, as documented for `args`.
result = minimize(cost_function, initial_guess, args=(x,), bounds=bounds)
# Alternative global optimizer:
# result = differential_evolution(cost_function, bounds, args=(x,))
if not result.success:
    # Make a failed fit visible instead of silently reusing its parameters.
    print(f"warning: optimizer did not converge: {result.message}")
best_params = result.x
best_cost = result.fun
print(f"{x:.6f} {best_cost:.6f} {best_params}")
# Row layout relied on by the later cells: [x, cost, WEE, WES, WSE, WSVE, WSS, g].
best_results = [x, best_cost, *best_params]


### 2 Visualization

In [142]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt

# Figure configurations: one global style shared by every plot drawn later.
plt.rcParams.update({
    # Hide the top and right spines.
    "axes.spines.top": False,
    "axes.spines.right": False,
    # Line widths of axes, plotted lines and ticks; tick lengths.
    "axes.linewidth": 2,
    "lines.linewidth": 2,
    "xtick.major.size": 10,
    "xtick.major.width": 2,
    "ytick.major.size": 10,
    "ytick.major.width": 2,
    "xtick.minor.size": 5,
    "xtick.minor.width": 2,
    "ytick.minor.size": 5,
    "ytick.minor.width": 2,
    # Font sizes: tick labels, axis labels, legend, title, default text.
    "xtick.labelsize": 20,
    "ytick.labelsize": 20,
    "axes.labelsize": 20,
    "legend.fontsize": 20,
    "axes.titlesize": 20,
    "font.size": 20,
    # Font family.
    "font.family": "Arial",
})


In [None]:
## sensitivity analysis

def cost_function(params, x):
    """Return the uncertainty-weighted squared misfit to the DART rates.

    For each of the two states and three contrasts, the external inputs
    ``IE`` and ``IS`` are chosen so that the control rates are a fixed point
    of the model with unscaled ``WSE``.  The rate equations are then
    integrated with forward-Euler steps of size 0.1, with ``WSE`` scaled by
    ``(1 - x)`` and the DART rates as the initial condition.  The mean over
    the second half of the run is compared with the DART data.

    Reads the module-level names ``rEdatacontrol``, ``rSdatacontrol``,
    ``rEdataDART``, ``rSdataDART``, ``rEseDART``, ``rSseDART``, ``alphaE``,
    ``alphaS`` and ``total_duration``.

    Parameters
    ----------
    params : sequence of six floats
        ``WEE, WES, WSE, WSVE, WSS, g``.  ``g`` multiplies ``WSVE`` and
        ``WSS`` in state index 1 only; state index 0 uses them unscaled.
    x : float
        Fraction by which ``WSE`` is reduced during the simulation.

    Returns
    -------
    float or int
        Sum over all conditions of ``((model - data) / se) ** 2`` for rE and
        rS, or the flat penalty ``100`` when the parameters are rejected or a
        rate reaches 10 during the simulation.

    Raises
    ------
    ValueError
        If ``total_duration`` is not positive (nothing to average).
    """
    WEE, WES, WSE, WSVE, WSS, g = params
    # Index 0 holds the unscaled weight, index 1 the weight scaled by g.
    WSVE = [WSVE, WSVE * g]
    WSS = [WSS, WSS * g]

    # Reject parameter sets whose scaled weights leave the admissible range.
    if np.abs(WSVE[1]) > 10 or np.abs(WSS[1]) > 10:
        return 100

    # Reject sets where WSE is smaller than either WSVE value.
    if WSE < max(WSVE):
        return 100

    if total_duration <= 0:
        raise ValueError("total_duration must be positive")

    # Samples before this index are treated as transient and discarded.
    settle = total_duration // 2
    cost = 0.0

    for istate in range(2):

        for icontrast in range(3):
            # Inputs that make the control rates a fixed point of the model.
            IE = np.sqrt(rEdatacontrol[istate, icontrast] / alphaE) - WEE * rEdatacontrol[istate, icontrast] + WES * rSdatacontrol[istate, icontrast]
            IS = np.sqrt(rSdatacontrol[istate, icontrast] / alphaS) - (WSE - WSVE[istate]) * rEdatacontrol[istate, icontrast] - WSS[istate] * rSdatacontrol[istate, icontrast]
            rE = rEdataDART[istate, icontrast]
            rS = rSdataDART[istate, icontrast]
            rE_values = []
            rS_values = []

            for _ in range(total_duration):
                # A rate of 10 or more is treated as a diverging run.
                if rE >= 10.0 or rS >= 10.0:
                    return 100
                drE = -rE + alphaE * (WEE * rE - WES * rS + IE) ** 2
                drS = -rS + alphaS * (((1.0 - x) * WSE - WSVE[istate]) * rE + WSS[istate] * rS + IS) ** 2
                rE += 0.1 * drE
                rS += 0.1 * drS
                rE_values.append(rE)
                rS_values.append(rS)

            # Divide by the real number of averaged samples so that an odd
            # total_duration does not bias the mean.
            rE_tail = rE_values[settle:]
            rS_tail = rS_values[settle:]
            rE_avg = sum(rE_tail) / len(rE_tail)
            rS_avg = sum(rS_tail) / len(rS_tail)

            cost += ((rE_avg - rEdataDART[istate, icontrast])/rEseDART[istate, icontrast]) ** 2 + \
                ((rS_avg - rSdataDART[istate, icontrast])/rSseDART[istate, icontrast]) ** 2

    return cost


def calculate_cost(x_values, best_results, *, weight_indices=range(5, 6),
                   n_steps=1000, rel_step=0.02, verbose=True):
    """Scan the cost around the fitted optimum, one parameter at a time.

    For every ``x`` and every scanned parameter index ``iw``, the fitted
    value is moved up and down in multiples of ``rel_step`` times its own
    magnitude while all other parameters stay at their fitted values, and
    ``cost_function`` is evaluated at each point.

    Parameters
    ----------
    x_values : iterable of float
        Values of ``x`` handed to ``cost_function``.
    best_results : sequence
        ``[x, best_cost, *best_params]``; the same row is used as the
        baseline for every entry of ``x_values``.  The parameter slice must
        provide ``.copy()`` (a list or a NumPy array).
    weight_indices : iterable of int, keyword-only
        Indices into ``best_params`` to scan.  Defaults to index 5 only.
    n_steps : int, keyword-only
        The scan uses the multiples ``istep = 1 .. n_steps - 1`` in both
        directions.
    rel_step : float, keyword-only
        Step size relative to the fitted value.  A fitted value of exactly 0
        therefore yields a degenerate scan.
    verbose : bool, keyword-only
        Print the baseline and every evaluated point.

    Returns
    -------
    list of list
        Rows ``[x, iw, weight, cost, *params]``.  The first row for each
        ``(x, iw)`` pair is the unperturbed optimum.
    """
    # Materialise once so a one-shot iterator can be reused for every x.
    weight_indices = tuple(weight_indices)
    new_results = []

    for x in x_values:
        pre_line = best_results
        if verbose:
            print(pre_line)
        pre_best = pre_line[2:]
        if verbose:
            print(pre_best)

        for iw in weight_indices:
            # Baseline row: the fitted optimum and its cost.
            new_results.append([x, iw, pre_best[iw], pre_line[1], *pre_best])

            for istep in range(1, n_steps):

                for direction in (1, -1):  # 1 for up-scan, -1 for down-scan
                    current_params = pre_best.copy()
                    current_weight = pre_best[iw] + direction * pre_best[iw] * istep * rel_step
                    current_params[iw] = current_weight
                    cost = cost_function(current_params, x)
                    if verbose:
                        print(f"{x} {iw} {current_weight} {cost} {current_params}")
                    new_results.append([x, iw, current_weight, cost, *current_params])
    return new_results

# Not referenced by the code visible in this notebook section.
states = range(2)

# Values of x for which the cost landscape is scanned.
x_values = [0.5]

# Rows of [x, parameter index, scanned value, cost, *parameters].
sa_results = calculate_cost(x_values, best_results)


In [None]:
# Display label for each parameter index, used in plot titles and axis labels.
params_labels = ['wEE-wEE*', 'wES-wES*', 'wSE', 'wSE*', 'wSS*', 'g']


def visualize_results(x_values, sa_results, best_results):
    """Plot cost against the scanned parameter for every x in ``x_values``.

    Parameters
    ----------
    x_values : iterable of float
        Values of ``x`` to plot; one figure is drawn per value.
    sa_results : sequence of rows
        Rows ``[x, parameter index, scanned value, cost, ...]``.
    best_results : sequence
        ``[x, best_cost, *best_params]``; marks the fitted optimum.

    Returns
    -------
    None
        The figures are shown with ``plt.show()``.
    """
    bests = np.array(sa_results)

    for x in x_values:

        for iw in range(5, 6):
            plt.figure(figsize=(8, 4))
            state_all = bests[(bests[:, 0] == x) & (bests[:, 1] == iw)]
            sort_index = np.argsort(state_all[:, 2])
            plt.plot(state_all[sort_index, 2], state_all[sort_index, 3], label='All', color='black')
            # Mark the optimum from the fit itself. The midpoint of the
            # sorted scan is the optimum only for a complete symmetric scan.
            plt.scatter(best_results[2 + iw], best_results[1], color='blue', label='Best All')

            plt.title(f'Cost change with {params_labels[iw]} when x = {x}')
            # Other x-ranges used previously (parameter mapping not recorded):
            # (0.16, 0.4), (0.9, 1.1), (2.4, 5), (-4.2, -3), (-2.5, 0.5)
            plt.xlim(1.5, 2.5)
            plt.xlabel(f'{params_labels[iw]}')
            # plt.grid()
            plt.ylim(0, 10)
            # Matching tick sets for the other x-ranges:
            # [0.16, 0.28, 0.4], [0.9, 1.0, 1.1], [2.4, 3.7, 5.0],
            # [-4.2, -3.6, -3], [-2.5, -1, 0.5]
            plt.xticks([1.5, 2, 2.5])
            plt.yticks([0,5,10])
            plt.ylabel('Cost')
            # plt.legend()
            # plt.savefig(f'{params_labels[iw]}_x{x}.pdf', format='pdf')
            plt.show()
# Draw one cost-versus-parameter figure per entry of x_values.
visualize_results(x_values, sa_results, best_results)
