# Shared

In [None]:
import os
import sys

# Prepend the project's local `src/` directory to the module search path so
# later cells can import project modules (e.g. `experiments`). The membership
# check keeps re-runs of this cell from stacking duplicate entries.
src_path = os.path.join(os.getcwd(), 'src')
if src_path not in sys.path:
    sys.path[:0] = [src_path]

In [None]:
import matplotlib.pyplot as plt

def scale_text(f=1):
    """Scale matplotlib's global font sizes by a common factor.

    Updates ``plt.rcParams`` in place, so it affects figures created after the
    call.

    Args:
        f: multiplicative factor applied to every base size below (1 = base).
    """
    base_sizes = {
        'font.size': 10,          # general text size
        'axes.titlesize': 10,     # axes titles
        'axes.labelsize': 10,     # axis labels
        'xtick.labelsize': 8,     # x-axis tick labels
        'ytick.labelsize': 8,     # y-axis tick labels
        'legend.fontsize': 10,    # legend text
        'figure.titlesize': 14,   # figure-level title (suptitle)
    }
    plt.rcParams.update({key: size * f for key, size in base_sizes.items()})

# Unit conversions used to size and place figures.
px = 1/plt.rcParams['figure.dpi']  # one pixel in inches at the current dpi
cm = 1/2.54  # one centimeter in inches
cm2px = 37.79  # pixels per centimeter for SVG placement (presumably ~96/2.54)

# Figure directories, relative to the notebook's working directory.
static_figures = "static_figures/"
output_figs = "output_figures/"
tmp_figs = f"{output_figs}tmp/"

# Figure 1

In [None]:
if True:
    # Short 3-trial simulation whose recorded spikes feed the Figure 1 raster.
    # (`smooth_spikes` is not needed here; it is defined in the Figure 3
    # "Simulation" cell, which is the only place it is used.)
    import numpy as np
    from matplotlib import pyplot as plt
    from matplotlib import patches

    from scipy.ndimage import gaussian_filter1d

    from experiments import (
        Experiment_StatePopulationsWithChannels,
        Experiment_ProAntiWithChannels,
    )

    n_actions = 2
    lr = 1e-2
    d1_lr = lr
    d2_lr = 2*lr  # the D2 learning rate is twice the D1 one
    # Seed right before building/running the experiment so the run is
    # reproducible (assuming it draws from NumPy's global RNG).
    np.random.seed(0)

    # `with_ach` is left at the experiment's default.
    e = Experiment_StatePopulationsWithChannels(
        n_trials=3,
        n_trial_types=n_actions,
        n_actions=n_actions,

        d1_learning_rate=d1_lr,
        d2_learning_rate=d2_lr,

        n_trial_steps=50,
        trial_step_duration=4,
        trial_width_step=1,
        trial_width=10,
    )

    e.run()
    # Recorded spikes as arrays: `ot` holds spike times and `oi` the matching
    # neuron indices (this is how the plotting cell uses them).
    ot = np.array(e.ot)
    oi = np.array(e.oi)

In [None]:
if True:
    # Figure 1: input spikes of the short 3-trial run (`e`, `ot`, `oi` from the
    # previous cell), with trial starts and post-onset exclusion periods marked.
    scale_text(0.75)
    
    # Time window shown (s).
    t = 0
    dt = 6
    # NOTE(review): `x_legend` is not used in this cell; the x label is set
    # explicitly with ax.set_xlabel below.
    x_legend = "Trial time (s)"
    
    fig, axs = plt.subplots(1,1, figsize=(14*cm, 7*cm))
    ax = axs
    
    # Dense time grid over the window; used as the x values of fill_between below.
    x = np.linspace(t, t + dt, num=1000)
    
    # Trial labels centred in each trial (trials start at 0, 2 and 4 s; y is in
    # the 0..1 data range set by set_ylim below).
    ax.text(1, 0.5, "Trial #1", horizontalalignment='center')
    ax.text(3, 0.5, "Trial #2", horizontalalignment='center')
    ax.text(5, 0.5, "Trial #3", horizontalalignment='center')
    
    # Trial starts (dotted lines) and the shaded exclusion period after each start.
    exclusion_periods = [(0, 0.2), (2.0, 2.2), (4.0, 4.2)]
    for i, (start, end) in enumerate(exclusion_periods):
        # Only the first artist of each kind gets a label, so the legend lists it once.
        ax.axvline(x=start, linestyle=':', color='black', label='Trial start' if i==0 else None)
        ax.fill_between(x, 0, 1, where=(start < x) * (x < end), label='Exclusion period' if i==0 else None, alpha=0.3, color='gray')
        #ax.vlines(end, 0,1, linestyle='--', color='black', linewidth=0.8, alpha=0.7)
    ax.set_ylim((0,1))
    
    ax.legend(loc='upper right')
    ax.get_yaxis().set_visible(False)
    ax.set_title('Inputs spikes through different trials')
    ax.set_xlabel('Simulation time (s)')
    
    # Spike raster on a twin axis. Its y axis is moved to the left below, because
    # the primary axis' y axis is hidden.
    ax2 = ax.twinx()
    # Plot only input neurons whose index is a multiple of 4, to thin the raster.
    subset_inputs = e.input_layer[e.input_layer % 4 == 0]
    sub = np.isin(oi, subset_inputs) * (ot >= t) * (ot <= (t + dt))
    n_neurons_per_input = e.n_inputs // e.n_actions
    # Stimulus index of each spike's neuron. This assumes the input layer is laid
    # out as contiguous blocks of n_neurons_per_input neurons per stimulus,
    # starting at index 0 (TODO confirm).
    colors = (oi[sub] // n_neurons_per_input)
    
    # One scatter per stimulus. y is flipped (max + 1 - index) so that lower
    # neuron indices appear at the top. The empty scatter only adds a legend
    # entry with a larger marker.
    for icolor in set(colors):
        _sub = colors == icolor
        ax2.scatter(ot[sub][_sub], np.max(oi[sub]) + 1 - oi[sub][_sub], marker='|', s=30, color=f'C{icolor}')
        ax2.scatter([], [], marker='|', s=50, color=f'C{icolor}', label=f'Stimulus #{icolor + 1} spikes')
    
    # Stimulus legend in the lower-left corner (the trial legend is upper-right).
    ax2.legend(loc='lower left')
    ax2.set_ylabel('Neuron index')
    ax2.yaxis.set_label_position("left")
    ax2.yaxis.tick_left()
    
    plt.tight_layout()
    # Save as vector (SVG) and 300-dpi raster (PNG).
    plt.savefig(output_figs+"figure_1.svg", bbox_inches="tight")
    plt.savefig(output_figs+"figure_1.png", dpi=300, bbox_inches="tight")

    plt.show()

# Figure 2

In [None]:
scale_text()  # back to the base font sizes (f defaults to 1; Figure 1 used 0.75)

## Subpanel B

In [None]:
if True:
    import numpy as np
    from matplotlib import pyplot as plt

    # Panel B: learning kernels, i.e. normalized weight change as a function of
    # the spike time difference, for D1 and D2 SPNs under reward and punishment.
    fig, axs = plt.subplots(1,2, figsize=(9*cm,6*cm))

    for ax, title in zip(axs, ("SPN D1", "SPN D2")):
        _ = ax.set_title(title)
        ax.set_ylim((-0.5,1))
        ax.set_yticks([0,1])
        ax.set_xlim((-0.2,0.2))
    _ = fig.suptitle("Learning Kernels")
    _ = axs[0].set_ylabel("Normalized\nweight change")
    _ = fig.supxlabel("Time difference (s)")

    # Label-only proxy lines, so the legend on the D2 axes shows one entry per signal.
    for color, label in (('green', 'Reward kernel'), ('red', 'Punish kernel')):
        axs[1].plot([],[], color=color, label=label)

    # Kernel amplitudes: the "prepos" value scales the x < 0 branch and the
    # "pospre" value scales the x > 0 branch.
    stdp_d1_rw_prepos, stdp_d1_rw_pospre = 1.0, 0.0
    stdp_d1_pn_prepos, stdp_d1_pn_pospre = -0.5, -0.5
    stdp_d2_rw_prepos, stdp_d2_rw_pospre = -0.5, 0.0
    stdp_d2_pn_prepos, stdp_d2_pn_pospre = 1.0, 1.0

    tau_decay = 0.032  # exponential time constant of both branches (s)

    x = np.linspace(-0.2,0.2, num=1000)

    def _two_sided_kernel(prepos, pospre):
        """prepos*exp(x/tau) for x < 0, pospre*exp(-x/tau) for x > 0, 0 at x == 0."""
        y = np.zeros_like(x)
        y[x<0] = prepos * np.exp(x[x<0] / tau_decay)
        y[x>0] = pospre * np.exp(-x[x>0] / tau_decay)
        return y

    y_rw_d1 = _two_sided_kernel(stdp_d1_rw_prepos, stdp_d1_rw_pospre)
    y_pn_d1 = _two_sided_kernel(stdp_d1_pn_prepos, stdp_d1_pn_pospre)
    y_rw_d2 = _two_sided_kernel(stdp_d2_rw_prepos, stdp_d2_rw_pospre)
    y_pn_d2 = _two_sided_kernel(stdp_d2_pn_prepos, stdp_d2_pn_pospre)

    # Each branch is its own line, so the two sides are not joined at x = 0.
    for ax, kernels in ((axs[0], (y_rw_d1, y_pn_d1)), (axs[1], (y_rw_d2, y_pn_d2))):
        for y, color in zip(kernels, ('green', 'red')):
            for side in (x<0, x>0):
                ax.plot(x[side], y[side], color=color)

    plt.tight_layout()

    axs[1].legend(bbox_to_anchor=(1.1,0.75))

    plt.savefig(tmp_figs + "figure_2b.svg")
    plt.show()

## Subpanel C

In [None]:
if True:
    import numpy as np
    from matplotlib import pyplot as plt

    # Panel C: toy example of how pre/post spike traces turn into weight changes
    # under classic STDP, reward-modulated STDP and the DA/ACh-gated rules of
    # D1 and D2 SPNs.

    total_time = 3
    dt = 1e-3
    tau_decay = 0.3
    decay = np.exp(-dt/tau_decay)  # per-sample decay factor of every trace

    # `decay` assumes consecutive samples are exactly `dt` apart, so the grid
    # needs round(total_time/dt) + 1 points (both endpoints included).
    time = np.linspace(0, total_time, num=int(round(total_time/dt)) + 1)

    # Event times (s), and the ACh gate (1 inside the gating windows, 0 elsewhere).
    spk_pre_times = [0.2, 1, 2]
    spk_pos_times = [0.15, 1.1, 1.9]
    rw_times = [0.5, 1.25, 1.7]
    pn_times = [0.3, 2.25]
    ach_value = np.zeros_like(time)
    ach_value[(1<time)*(time<1.5)] = 1
    ach_value[(1.9<time)*(time<2.4)] = 1

    # Rule amplitudes. Classic STDP; reward/punish-modulated STDP (same values as
    # the D1 rule); and the D2 rule.
    stdp_prepos = 1.0
    stdp_pospre = -1.0
    stdp_rw_prepos = stdp_d1_rw_prepos = 1.0
    stdp_rw_pospre = stdp_d1_rw_pospre = 0.0
    stdp_pn_prepos = stdp_d1_pn_prepos = -0.5
    stdp_pn_pospre = stdp_d1_pn_pospre = -0.5
    stdp_d2_rw_prepos = -0.5
    stdp_d2_rw_pospre = 0.0
    stdp_d2_pn_prepos = 1.0
    stdp_d2_pn_pospre = 1.0

    def _decay_in_place(trace, factor):
        """Apply trace[k] += factor * trace[k-1] for k = 1 .. len(trace)-1, in place.

        The recurrence must start at k = 1. Starting at k = 0 would make trace[0]
        read trace[-1] (the last sample) and would leave the last sample
        unaccumulated, so every curve would drop to 0 at t = total_time.
        """
        for k in range(1, len(trace)):
            trace[k] += trace[k-1] * factor
        return trace

    def _gated_dw(events, gate):
        """Return the accumulated weight change driven by reinforcement events.

        `events` is a sequence of (sample_indices, prepos_amp, pospre_amp). At those
        samples the instantaneous change is
        (prepos_trace*prepos_amp + pospre_trace*pospre_amp) * gate.
        Later entries overwrite earlier ones on shared samples. The result is the
        running sum over time.
        """
        dw = np.zeros_like(time)
        for idx, amp_prepos, amp_pospre in events:
            dw[idx] = (prepos_trace[idx]*amp_prepos + pospre_trace[idx]*amp_pospre) * gate[idx]
        return np.cumsum(dw)

    # Sample index of every event on the time grid.
    pre_idx = np.searchsorted(time, spk_pre_times)
    pos_idx = np.searchsorted(time, spk_pos_times)
    rw_idx = np.searchsorted(time, rw_times)
    pn_idx = np.searchsorted(time, pn_times)

    # Single-neuron traces: each spike adds 1, then the trace decays exponentially.
    pre_trace = np.zeros_like(time)
    pos_trace = np.zeros_like(time)
    pre_trace[pre_idx] = 1
    pos_trace[pos_idx] = 1
    _decay_in_place(pre_trace, decay)
    _decay_in_place(pos_trace, decay)

    # Pair traces. Each post spike samples the pre trace (pre-post pairing) and
    # each pre spike samples the post trace (post-pre pairing); both then decay.
    prepos_trace = np.zeros_like(time)
    pospre_trace = np.zeros_like(time)
    pospre_trace[pre_idx] = pos_trace[pre_idx]
    prepos_trace[pos_idx] = pre_trace[pos_idx]
    _decay_in_place(prepos_trace, decay)
    _decay_in_place(pospre_trace, decay)

    # Classic STDP: weights change at the spikes themselves (post spikes are
    # written last, so they win if both land on one sample).
    dw_stdp = np.zeros_like(time)
    dw_stdp[pre_idx] = pospre_trace[pre_idx] * stdp_pospre
    dw_stdp[pos_idx] = prepos_trace[pos_idx] * stdp_prepos
    dw_stdp = np.cumsum(dw_stdp)

    # Reward-modulated STDP: changes happen only at reward/punish times (no gate).
    dw_rf = _gated_dw(
        [(rw_idx, stdp_rw_prepos, stdp_rw_pospre), (pn_idx, stdp_pn_prepos, stdp_pn_pospre)],
        np.ones_like(time),
    )
    # ACh-gated STDP, SPN D1 and SPN D2: as above, but multiplied by the ACh gate.
    dw_ach_d1 = _gated_dw(
        [(rw_idx, stdp_d1_rw_prepos, stdp_d1_rw_pospre), (pn_idx, stdp_d1_pn_prepos, stdp_d1_pn_pospre)],
        ach_value,
    )
    dw_ach_d2 = _gated_dw(
        [(rw_idx, stdp_d2_rw_prepos, stdp_d2_rw_pospre), (pn_idx, stdp_d2_pn_prepos, stdp_d2_pn_pospre)],
        ach_value,
    )

    # Plotting: spikes (top), traces with reinforcement/ACh events (middle), and
    # accumulated weight change per rule, vertically offset for readability (bottom).

    fig, axs = plt.subplots(3,1, figsize=(5, 4.8), height_ratios=[1,2,1] )

    axs[0].scatter(spk_pre_times, np.ones_like(spk_pre_times)*1, marker='|', s=100, label='pre-synaptic\nspikes')
    axs[0].scatter(spk_pos_times, np.ones_like(spk_pos_times)*0, marker='|', s=100, label='post-synaptic\nspikes')
    axs[0].set_ylim((-0.5,1.5))
    axs[0].set_xlim((0, total_time))
    axs[0].get_xaxis().set_visible(False)
    axs[0].set_ylabel('Neurons')

    axs[1].plot(time, pre_trace, linestyle=':', linewidth=1, color='C0', label='pre- trace')
    axs[1].plot(time, pos_trace, linestyle=':', linewidth=1, color='C1', label='post- trace')
    axs[1].plot(time, prepos_trace, color='C0', label='pre-post trace')
    axs[1].plot(time, pospre_trace, color='C1', label='post-pre trace')
    axs[1].vlines(rw_times, 0,1, linewidth=10.0, alpha=0.5, color='limegreen', label='Reward')
    axs[1].vlines(pn_times, 0,1, linewidth=10.0, alpha=0.5, color='red', label='Punish')
    axs[1].fill_between(time, 0,1, where=ach_value==1, color='black', alpha=0.25, label='ACh signal')
    axs[1].set_xlim((0, total_time))
    axs[1].get_xaxis().set_visible(False)
    axs[1].set_ylabel('Trace intensity')

    offset = 0.05
    axs[2].set_xlim((0, total_time))
    axs[2].plot(time, offset*0 + dw_stdp, color='black', linewidth=2.0, label='Classic STDP')
    axs[2].plot(time, offset*1 + dw_rf, color='brown', linewidth=2.0,  label='Reward-mod\nSTDP')
    axs[2].plot(time, offset*2 + dw_ach_d1, color='aqua', linewidth=2.0, label='DA/ACh-mod\nSTDP (SPN D1)')
    axs[2].plot(time, offset*3 + dw_ach_d2, color='teal', linewidth=2.0, label='DA/ACh-mod\nSTDP (SPN D2)')
    axs[2].set_xlabel('Time (s)')
    axs[2].set_ylabel('Accumulated\nweight change')

    axs[0].legend(bbox_to_anchor=(1.41, 1.1))
    axs[1].legend(bbox_to_anchor=(1.01, 1.1))
    axs[2].legend(bbox_to_anchor=(1.0125, 1.2))

    plt.savefig(tmp_figs + "figure_2c.svg")
    plt.show()

## All

In [None]:
import svgutils.transform as sg
import sys

# Compose Figure 2 from three SVG panels: the static network diagram (A) and
# the panels B and C saved by the two previous cells.
fig = sg.SVGFigure()

# Load the panel SVG files.
fig1, fig2, fig3 = (
    sg.fromfile(static_figures + 'figure_network.svg'),
    sg.fromfile(tmp_figs + 'figure_2b.svg'),
    sg.fromfile(tmp_figs + 'figure_2c.svg'),
)

# Root groups of each panel. B and C are shifted to the right of panel A and
# shrunk; positions are written in cm and converted to SVG pixels.
plot1, plot2, plot3 = fig1.getroot(), fig2.getroot(), fig3.getroot()
plot2.moveto(9.75*cm2px, 0.25*cm2px, scale_x=0.65, scale_y=0.65)
plot3.moveto(9*cm2px, 3*cm2px, scale_x=0.6, scale_y=0.6)

# Bold panel letters.
txt1, txt2, txt3 = (
    sg.TextElement(pos_x, pos_y, letter, size=12, weight="bold")
    for pos_x, pos_y, letter in (("2.5mm", "5mm", "A"), ("9cm", "5mm", "B"), ("9cm", "35mm", "C"))
)

# SVG paints in document order, so panel C goes before (underneath) panel B.
fig.append([plot1, plot3, plot2])
fig.append([txt1, txt2, txt3])

fig.save(output_figs + "figure_2.svg")

In [None]:
from IPython.display import HTML  # public API (IPython.core.display is internal)

# Preview the composed Figure 2 on a white background, enlarged 1.5x with CSS
# (affects only this on-screen preview, not the saved file).
img_path = output_figs + "figure_2.svg"
HTML(
    f'<div style="background-color:white; padding-left:130px; padding-top:80px; display:inline-block; transform: scale(1.5);">'
    f'<img src="{img_path}" width="620px" height="320px"></div>'
)

# Figure 3

## Simulation

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import patches

from scipy.ndimage import gaussian_filter1d
def smooth_spikes(t,dt, spike_times, sigma_s=200.0, resolution=1000):
    """Estimate a smoothed firing rate from the spike times in [t, t + dt].

    Spikes are binned into `resolution` equal bins, and the counts are smoothed
    with a Gaussian of standard deviation ``sigma_s/dt`` bins. The result is then
    converted to a rate by dividing by the bin width ``dt/resolution``.

    Note: because a bin spans ``dt/resolution`` time units, the Gaussian width in
    time units is ``sigma_s/resolution``. `sigma_s` is *not* a width in seconds.

    Args:
        t: window start (same time unit as `spike_times`; seconds in this notebook).
        dt: window duration; the window ends at ``t + dt``.
        spike_times: spike times; any sequence convertible to a 1-D float array.
        sigma_s: Gaussian width parameter (see the note above).
        resolution: number of histogram bins.

    Returns:
        (times, signal): the left edge of each bin and the smoothed rate (spikes
        per time unit), both of length `resolution`.
    """
    spike_times = np.asarray(spike_times, dtype=float)
    # Keep only spikes inside the closed window.
    in_window = (spike_times >= t) & (spike_times <= (t + dt))
    hist, edges = np.histogram(spike_times[in_window], bins=resolution, range=(t, t + dt))
    signal = gaussian_filter1d(hist.astype(float), sigma_s/dt) * resolution/dt
    return edges[:-1], signal

from experiments import (
    Experiment_StatePopulationsWithChannels, 
    Experiment_ProAntiWithChannels,
)

In [None]:
# No-ACh run: ACh gating disabled (`with_ach=False`).
# NOTE(review): this run uses lr = 1e-3 over 1000 trials, while the ACh run
# uses lr = 1e-2 over 500 trials. Confirm the difference is intended.

n_actions = 2
lr = 1e-3
d1_lr, d2_lr = lr, 2*lr  # the D2 learning rate is twice the D1 one

e1_params = dict(
    n_trials=1000,
    n_trial_types=n_actions,
    n_actions=n_actions,
    with_ach=False,
    d1_learning_rate=d1_lr,
    d2_learning_rate=d2_lr,
    n_trial_steps=50,
    trial_step_duration=4,
    trial_width_step=1,
    trial_width=10,
)

# Seed right before building/running so the run is reproducible (assuming the
# experiment draws from NumPy's global RNG).
np.random.seed(0)
e1 = Experiment_StatePopulationsWithChannels(**e1_params)
e1.run()

# Recorded spike times and neuron indices, as arrays for the plotting cell.
ot1, oi1 = np.array(e1.ot), np.array(e1.oi)

In [None]:
# ACh run: `with_ach` is left at the experiment's default (presumably ACh
# gating enabled; only the no-ACh run passes `with_ach=False`).

n_actions = 2
lr = 1e-2
d1_lr, d2_lr = lr, 2*lr  # the D2 learning rate is twice the D1 one

e2_params = dict(
    n_trials=500,
    n_trial_types=n_actions,
    n_actions=n_actions,
    d1_learning_rate=d1_lr,
    d2_learning_rate=d2_lr,
    n_trial_steps=50,
    trial_step_duration=4,
    trial_width_step=1,
    trial_width=10,
)

# Seed right before building/running so the run is reproducible (assuming the
# experiment draws from NumPy's global RNG).
np.random.seed(0)
e2 = Experiment_StatePopulationsWithChannels(**e2_params)
e2.run()

# Recorded spike times and neuron indices, as arrays for the plotting cell.
ot2, oi2 = np.array(e2.ot), np.array(e2.oi)

## Plotting

In [None]:
# Build the Figure 3 layout: axes, a shared sectioned legend and bracket
# annotations. Data are drawn into these axes further down in this cell.

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import matplotlib.lines as mlines

# Create the figure on a 4x4 GridSpec. Rows: inputs, channel #1, channel #2,
# D1 weights. Columns: global modulation, ACh gating, legend. The fourth column
# is not populated in this cell.
fig = plt.figure(figsize=(12, 6))
gs = gridspec.GridSpec(4, 4, height_ratios=[0.5, 2, 2, 2], width_ratios=[1, 1, 0.3, 1])

# Panel letters, in figure coordinates.
fig.text(0.03,    0.95, "A", fontsize=14, fontweight="bold")
fig.text(0.32, 0.95, "B", fontsize=14, fontweight="bold")

# Top panels (inputs and reinforcement); both axes are hidden, only the data show.
ax_top_left = fig.add_subplot(gs[0, 0])  # Global modulation
ax_top_right = fig.add_subplot(gs[0, 1])  # Channel-specific ACh gating
ax_top_left.set_title("Global modulation", fontsize=12)#, fontweight='bold')
ax_top_right.set_title("Channel-specific ACh gating", fontsize=12)#, fontweight='bold')
for ax in [ax_top_left, ax_top_right]:
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

# Middle panels (firing rates). Each column shares its x axis with its top panel.
ax_mid_left_top = fig.add_subplot(gs[1, 0], sharex=ax_top_left)  # Channel #1
ax_mid_left_bottom = fig.add_subplot(gs[2, 0], sharex=ax_mid_left_top)  # Channel #2
ax_mid_right_top = fig.add_subplot(gs[1, 1], sharex=ax_top_right)  # Channel #1 (ACh gating)
ax_mid_right_bottom = fig.add_subplot(gs[2, 1], sharex=ax_mid_right_top)  # Channel #2 (ACh gating)
for ax in [ax_mid_left_top, ax_mid_right_top]:
    ax.get_xaxis().set_visible(False)

ax_mid_left_top.set_title("Channel #1", color='C0', fontsize=10, fontweight='bold')
ax_mid_left_bottom.set_title("Channel #2", color='C1', fontsize=10, fontweight='bold')
ax_mid_right_top.set_title("Channel #1", color='C0', fontsize=10, fontweight='bold')
ax_mid_right_bottom.set_title("Channel #2", color='C1', fontsize=10, fontweight='bold')



# NOTE(review): the y labels set here are overwritten with 'Firing rate\n(Hz)'
# when the data are drawn below.
for ax in [ax_mid_left_top, ax_mid_left_bottom, ax_mid_right_top, ax_mid_right_bottom]:
    ax.set_ylabel("Firing rate (Hz)", fontsize=8)
    ax.set_xlabel("Time (s)", fontsize=8)
    ax.tick_params(axis='both', labelsize=7)

# Bottom panels (D1 weight evolution)
ax_bottom_left = fig.add_subplot(gs[3, 0])  # Global modulation weights
ax_bottom_right = fig.add_subplot(gs[3, 1])  # Channel-specific ACh gating weights

ax_bottom_left.set_title("D1 Weights Evolution", fontsize=10)
ax_bottom_right.set_title("D1 Weights Evolution", fontsize=10)
ax_bottom_left.set_xlabel("Time (s)", fontsize=8)
ax_bottom_right.set_xlabel("Time (s)", fontsize=8)
ax_bottom_left.set_ylabel("Mean weight", fontsize=8)
ax_bottom_right.set_ylabel("Mean weight", fontsize=8)
ax_bottom_left.tick_params(axis='both', labelsize=7)
ax_bottom_right.tick_params(axis='both', labelsize=7)

# Hide the y axes of the right-hand (ACh gating) column.
for ax in [ax_mid_right_top, ax_mid_right_bottom, ax_bottom_right]:
    ax.get_yaxis().set_visible(False)


# Shared legend, split into titled sections, drawn in its own axes.
ax_legend = fig.add_subplot(gs[:, 2])  # spans every row of the third column
ax_legend.axis("off")  # hide the legend axes' frame, ticks and labels

# Legend sections and the entries in each.
sections = {
    "Stimuli": ["Stimulus #1", "Stimulus #2"],
    "Reinforcement": ["Rewards", "Punishments"],
    "Activity": ["Trial init", "SPN D1", "SPN D2", "Actions"]
}
colors = {
    "Stimulus #1": 'C0', "Stimulus #2": 'C1',
    "Rewards": 'limegreen', "Punishments": 'red',
    "Trial init": 'silver', "SPN D1": 'teal', "SPN D2": 'magenta', "Actions": 'black'
}

# Build the handles for legend().
handles = []
for section, labels in sections.items():
    handles.append(mpatches.Patch(color='white', label=f"\n{section}"))  # section header: a white patch whose label is the section name
    handles.extend([mpatches.Patch(color=colors[label], label=label) for label in labels])

ax_legend.legend(handles=handles, loc='center', fontsize=8, frameon=True, handleheight=1.5)

# Adjust spacing
plt.tight_layout()
fig.subplots_adjust(wspace=0.1, hspace=0.4)  # manual spacing on top of tight_layout

# Shift both weight panels down by 0.05 (figure coordinates).
x,y, w,h = ax_bottom_left.get_position().bounds
ax_bottom_left.set_position([x,y-0.05,w,h])  # [x, y, width, height]
x,y, w,h = ax_bottom_right.get_position().bounds
ax_bottom_right.set_position([x,y-0.05,w,h])  # [x, y, width, height]

# Zoom rectangles (disabled)
#rect_bottom = mpatches.Rectangle((0.75, 0.1), 0.2, 0.8, transform=ax_bottom_left.transAxes, color='black', fill=False, linewidth=1.5)
#ax_bottom_left.add_patch(rect_bottom)
# Horizontal bracket under the left column: a line at y = 0.28 (figure
# coordinates) with short slanted ticks at both ends.
line = mlines.Line2D([0.045, 0.31], [0.28, 0.28], transform=fig.transFigure, color='black', linewidth=1.5); fig.lines.append(line)
line = mlines.Line2D([0.045, 0.04], [0.28, 0.29], transform=fig.transFigure, color='black', linewidth=1.5); fig.lines.append(line)
line = mlines.Line2D([0.31, 0.315], [0.28, 0.29], transform=fig.transFigure, color='black', linewidth=1.5); fig.lines.append(line)

# Zoom rectangles (disabled)
#rect_bottom = mpatches.Rectangle((0.75, 0.1), 0.2, 0.8, transform=ax_bottom_right.transAxes, color='black', fill=False, linewidth=1.5)
#ax_bottom_right.add_patch(rect_bottom)
# The same bracket under the right column.
line = mlines.Line2D([0.335, 0.6], [0.28, 0.28], transform=fig.transFigure, color='black', linewidth=1.5); fig.lines.append(line)
line = mlines.Line2D([0.335, 0.33], [0.28, 0.29], transform=fig.transFigure, color='black', linewidth=1.5); fig.lines.append(line)
line = mlines.Line2D([0.6, 0.605], [0.28, 0.29], transform=fig.transFigure, color='black', linewidth=1.5); fig.lines.append(line)


# Draw the data

# First column: the no-ACh run (e1, ot1, oi1), i.e. global modulation

dt=6  # width of the zoomed time window (s)
#t = e1.current_time-dt
t = 1966#1976
#trial_changes = [1994, 1996, 1998]
# Starts of the three trials shown; this assumes 2 s per trial (TODO confirm).
trial_changes = [t + i*2 for i in range(3)]
# The [from_time, to_time] interval is outlined on the weights panel.
from_time = t-50
to_time = t+50
weight_lim = (3.5,9.2)  # y-limits of both weight panels (also used for the right column)

channels = e1.channels
n_channels = list(range(len(e1.channels)))

# Mark each trial start (+0.1 s) with a wide translucent silver line behind the
# data (zorder=-10) on the top and middle panels. It is labelled only once.
axs = [ax_top_left, ax_mid_left_top, ax_mid_left_bottom]
for i, start in enumerate(trial_changes):
    for j, ax in enumerate(axs):
        ax.axvline(x=start+0.1, linestyle='-', color='silver', linewidth=9, alpha=0.5, label='Trial init' if i==0 and j==2 else None, zorder=-10)

# Inputs and rewards
# Rewards and punishments: shade reward/punish spike times across the vertical
# span of the input raster. y is the negated neuron index, so higher indices
# plot lower. fill_between only fills between consecutive True samples, so
# isolated events are mostly rendered by the wide edge linewidth.
ax = ax_top_left
ax.set_xlim((t,t+dt))
inputs_sub = np.isin(oi1, e1.input_layer) * (ot1>=t) * (ot1<=(t+dt))
min_y = -np.min(oi1[inputs_sub])
max_y = -np.max(oi1[inputs_sub])
sub = np.isin(oi1, e1.reward_layer)
ax.fill_between(ot1, min_y,max_y, where=sub, color='limegreen', linewidth=10, alpha=0.5)
sub = np.isin(oi1, e1.punish_layer)
ax.fill_between(ot1, min_y,max_y, where=sub, color='red', linewidth=10, alpha=0.5)
# Inputs: raster of input spikes in the window, coloured by stimulus index
# (assumes contiguous blocks of n_neurons_per_input neurons per stimulus).
sub = np.isin(oi1, e1.input_layer) * (ot1>=t) * (ot1<=(t+dt)) 
n_neurons_per_input = e1.n_inputs // e1.n_actions
colors = (oi1[sub] // n_neurons_per_input)
for icolor in set(colors):
    _sub = colors==icolor
    ax.scatter(ot1[sub][_sub],-oi1[sub][_sub], s=1, color=f'C{icolor}')
    #ax.scatter([],[], s=10, color=f'C{icolor}', label=f'Stimulus #{icolor+1}')
    #ax.fill_between([],[],[], label='Rewards', color='limegreen', alpha=0.5)
    #ax.fill_between([],[],[], label='Punishments', color='red', alpha=0.5)
    #if with_legend: ax.legend(loc='center right', bbox_to_anchor=(1.6, 0.5), ncol=1)

# Channels: MSN activity and actions (one middle panel per channel)
axs = [ax_mid_left_top, ax_mid_left_bottom]
for i, (n, c) in enumerate(zip(n_channels, channels)):
    ax = axs[i]
    ax.set_xlim((t,t+dt))
        
    # Transparent background, so the twin axes drawn behind (zorder -1) stay visible.
    ax.patch.set_alpha(0.0)
    ax.set_ylabel('Firing rate\n(Hz)')
    
    # Mean firing rate: smoothed population rate divided by channels[0].n_outputs
    # (presumably the number of neurons per population; confirm).
    # NOTE(review): smooth_spikes' 2nd argument is a duration, but t+dt (an end
    # time) is passed. The rate normalisation stays consistent and set_xlim
    # clips the view to [t, t+dt]; only the bin width and the computed span change.
    sub1 = np.isin(oi1, c.msn_d1_layer)
    sub2 = np.isin(oi1, c.msn_d2_layer)
    x,y = smooth_spikes(t,t+dt, ot1[sub1], sigma_s=3000.0, resolution=100000)
    y /= channels[0].n_outputs
    ax.plot(x,y, color='teal', label='SPN D1', linewidth=2.0)
    x,y = smooth_spikes(t,t+dt, ot1[sub2], sigma_s=3000.0, resolution=100000)
    y /= channels[0].n_outputs
    ax.plot(x,y, color='magenta', label='SPN D2', linewidth=2.0)
    #ax.plot([],[], color='black', label='Actions', linewidth=2.0)
    #if i==(len(channels)-1) and with_legend: ax.legend(loc='lower right', bbox_to_anchor=(1.6, 0.75), ncol=1)
    ax.set_ylim((0,100))

    # Twin axes (y in neuron-index units, hidden) for actions and the raster,
    # drawn behind the rate curves.
    _ax = ax.twinx()
    #ax.yaxis.set_label_position("right")
    #ax.yaxis.tick_right()
    
    ax = _ax
    ax.zorder = -1
    ax.get_yaxis().set_visible(False)
    #ax.set_title(f'Channel #{n+1}', color=f'C{n}', fontweight='bold')
    ax.set_xlim((t,t+dt))
    
    # Actions: vertical black lines spanning the channel's D1..D2 neuron index range.
    sub = oi1==e1.action_layer[n]
    ax.vlines(ot1[sub], min(c.msn_d1_layer), max(c.msn_d2_layer), color='black', linewidth=2)
    
    # Raster plot of the D1/D2 spikes (only for short windows).
    if dt <= 20:
        ax.scatter(ot1[sub1],oi1[sub1], s=1, color='skyblue')
        ax.scatter(ot1[sub2],oi1[sub2], s=1, color='violet')
#ax_mid_left_bottom.set_xlabel('Time (s)')    

# Weights. The smoothed reward signal is computed only for its time axis `x`
# (y is not used).
x,y = smooth_spikes(0,e1.current_time, ot1[oi1==e1.reward_layer[0]], sigma_s=20.0e3)
time = x
# d1w is indexed below as [channel, recording step, synapse].
d1w = np.array(e1.channel_d1_weights)
    
# Stimulus index of each synapse. This assumes synapses are laid out in blocks of
# conns_per_stimulus that alternate between stimuli (TODO confirm against experiments).
# `n_actions` is the global set by the simulation cells.
conns_per_neuron = d1w.shape[2] // e1.channels[0].n_outputs
conns_per_stimulus = conns_per_neuron // n_actions
sub_stim = (np.arange(d1w.shape[2]) // conns_per_stimulus) % n_actions
for i_chn in [0,1]:#range(n_actions):
    for i_stm in range(2):
        # i_stm == 0: the channel's own stimulus (solid); 1: the other stimulus (dashed).
        if i_stm==0:
            sub = sub_stim==i_chn
        else:
            sub = sub_stim==(i_chn+1) % n_actions
        # An integer and a boolean index around a slice put the selected-synapse
        # axis first, so ws is (n_selected, n_steps) and axis=0 averages synapses.
        ws = d1w[i_chn, :, sub]
        mean = np.mean(ws, axis=0)
        std = np.std(ws, axis=0)
        se = std / np.sqrt(np.sum(sub))  # NOTE(review): computed but not plotted
        # Resample the time axis to one point per weight recording.
        time = np.linspace(time[0], time[-1], num=mean.size)
        # NOTE(review): the label uses (i_chn-1) while the selection uses (i_chn+1);
        # they agree only for n_actions == 2.
        label = f'Chn{i_chn+1}/Stim{i_chn+1}' if i_stm==0 else f'Chn{i_chn+1}/Stim{((i_chn-1)%n_actions)+1}' 
        ax_bottom_left.plot(time, mean, color=f'C{i_chn}', linestyle='-' if i_stm==0 else '--', label=label)
# Outline (no fill) the [from_time, to_time] interval on the weights panel.
ax_bottom_left.fill_between(time, weight_lim[0]+0.25,weight_lim[1]-0.25, where=(time>=from_time)*(time<=to_time), color='none', edgecolor='black', linewidth=2)
ax_bottom_left.set_ylim(weight_lim)
#if with_legend: ax_bottom_left.legend(loc='center', bbox_to_anchor=(x_legend, 0.5), ncol=1)
#ax_bottom_left.set_title('D1 Weights Evolution')
#ax_bottom_left.set_ylabel('Mean weight')

# Second column: the ACh run (e2, ot2, oi2), i.e. channel-specific ACh gating

dt=6  # width of the zoomed time window (s)
t = e2.current_time-dt  # show the last dt seconds of the run
# NOTE(review): hard-coded trial starts; they line up with t only if
# e2.current_time == 1000. Confirm.
trial_changes = [994, 996, 998]
# The [from_time, to_time] interval is outlined on the weights panel.
from_time = 950
to_time = 1050

channels = e2.channels
n_channels = list(range(len(e2.channels)))

# Mark each trial start (+0.1 s) with a wide translucent silver line behind the
# data (zorder=-10) on the top and middle panels. It is labelled only once.
axs = [ax_top_right, ax_mid_right_top, ax_mid_right_bottom]
for i, start in enumerate(trial_changes):
    for j, ax in enumerate(axs):
        ax.axvline(x=start+0.1, linestyle='-', color='silver', linewidth=15, alpha=0.5, label='Trial init' if i==0 and j==2 else None, zorder=-10)

# Inputs and rewards
# Rewards and punishments: shade reward/punish spike times across the vertical
# span of the input raster (y is the negated neuron index).
ax = ax_top_right
ax.set_xlim((t,t+dt))
inputs_sub = np.isin(oi2, e2.input_layer) * (ot2>=t) * (ot2<=(t+dt))
min_y = -np.min(oi2[inputs_sub])
max_y = -np.max(oi2[inputs_sub])
#ax.get_yaxis().set_visible(False)
#ax.set_title('Inputs and Reinforcement')
sub = np.isin(oi2, e2.reward_layer)
ax.fill_between(ot2, min_y,max_y, where=sub, color='limegreen', linewidth=10, alpha=0.5)
sub = np.isin(oi2, e2.punish_layer)
ax.fill_between(ot2, min_y,max_y, where=sub, color='red', linewidth=10, alpha=0.5)
# Inputs: raster of input spikes in the window, coloured by stimulus index
# (assumes contiguous blocks of n_neurons_per_input neurons per stimulus).
sub = np.isin(oi2, e2.input_layer) * (ot2>=t) * (ot2<=(t+dt)) 
n_neurons_per_input = e2.n_inputs // e2.n_actions
colors = (oi2[sub] // n_neurons_per_input)
for icolor in set(colors):
    _sub = colors==icolor
    ax.scatter(ot2[sub][_sub],-oi2[sub][_sub], s=1, color=f'C{icolor}')
    #ax.scatter([],[], s=10, color=f'C{icolor}', label=f'Stimulus #{icolor+1}')
    #ax.fill_between([],[],[], label='Rewards', color='limegreen', alpha=0.5)
    #ax.fill_between([],[],[], label='Punishments', color='red', alpha=0.5)
    #if with_legend: ax.legend(loc='center right', bbox_to_anchor=(1.6, 0.5), ncol=1)

# Channels: MSN activity and actions (one middle panel per channel). These axes
# share x with ax_top_right, so its xlim above also applies here.
axs = [ax_mid_right_top, ax_mid_right_bottom]
for i, (n, c) in enumerate(zip(n_channels, channels)):
    ax = axs[i]
        
    # Transparent background, so the twin axes drawn behind (zorder -1) stay visible.
    ax.patch.set_alpha(0.0)
    ax.set_ylabel('Firing rate\n(Hz)')
    
    # Mean firing rate: smoothed population rate divided by channels[0].n_outputs
    # (presumably the number of neurons per population; confirm).
    # NOTE(review): t+dt (an end time) is passed where smooth_spikes expects a duration.
    sub1 = np.isin(oi2, c.msn_d1_layer)
    sub2 = np.isin(oi2, c.msn_d2_layer)
    x,y = smooth_spikes(t,t+dt, ot2[sub1], sigma_s=3000.0, resolution=100000)
    y /= channels[0].n_outputs
    ax.plot(x,y, color='teal', label='SPN D1', linewidth=2.0)
    x,y = smooth_spikes(t,t+dt, ot2[sub2], sigma_s=3000.0, resolution=100000)
    y /= channels[0].n_outputs
    ax.plot(x,y, color='magenta', label='SPN D2', linewidth=2.0)
    #ax.plot([],[], color='black', label='Actions', linewidth=2.0)
    #if i==(len(channels)-1) and with_legend: ax.legend(loc='lower right', bbox_to_anchor=(1.6, 0.75), ncol=1)
    ax.set_ylim((0,100))

    # Twin axes (y in neuron-index units, hidden) for actions and the raster,
    # drawn behind the rate curves.
    _ax = ax.twinx()
    #ax.yaxis.set_label_position("right")
    #ax.yaxis.tick_right()
    
    ax = _ax
    ax.zorder = -1
    ax.get_yaxis().set_visible(False)
    #ax.set_title(f'Channel #{n+1}', color=f'C{n}', fontweight='bold')
    ax.set_xlim((t,t+dt))
    
    # Actions: vertical black lines spanning the channel's D1..D2 neuron index range.
    sub = oi2==e2.action_layer[n]
    ax.vlines(ot2[sub], min(c.msn_d1_layer), max(c.msn_d2_layer), color='black', linewidth=2)
    
    # Raster plot of the D1/D2 spikes (only for short windows).
    if dt <= 20:
        ax.scatter(ot2[sub1],oi2[sub1], s=1, color='skyblue')
        ax.scatter(ot2[sub2],oi2[sub2], s=1, color='violet')
#ax_mid_left_bottom.set_xlabel('Time (s)')    

# Weights. The smoothed reward signal is computed only for its time axis `x`
# (y is not used).
x,y = smooth_spikes(0,e2.current_time, ot2[oi2==e2.reward_layer[0]], sigma_s=20.0e3)
time = x
# d1w is indexed below as [channel, recording step, synapse].
d1w = np.array(e2.channel_d1_weights)
    
# Stimulus index of each synapse (same layout assumption as the first column).
conns_per_neuron = d1w.shape[2] // e2.channels[0].n_outputs
conns_per_stimulus = conns_per_neuron // n_actions
sub_stim = (np.arange(d1w.shape[2]) // conns_per_stimulus) % n_actions
for i_chn in [0,1]:#range(n_actions):
    for i_stm in range(2):
        # i_stm == 0: the channel's own stimulus (solid); 1: the other stimulus (dashed).
        if i_stm==0:
            sub = sub_stim==i_chn
        else:
            sub = sub_stim==(i_chn+1) % n_actions
        # ws is (n_selected_synapses, n_steps); axis=0 averages over synapses.
        ws = d1w[i_chn, :, sub]
        mean = np.mean(ws, axis=0)
        std = np.std(ws, axis=0)
        se = std / np.sqrt(np.sum(sub))  # NOTE(review): computed but not plotted
        # Resample the time axis to one point per weight recording.
        time = np.linspace(time[0], time[-1], num=mean.size)
        label = f'Chn{i_chn+1}/Stim{i_chn+1}' if i_stm==0 else f'Chn{i_chn+1}/Stim{((i_chn-1)%n_actions)+1}' 
        ax_bottom_right.plot(time, mean, color=f'C{i_chn}', linestyle='-' if i_stm==0 else '--', label=label)
# Outline (no fill) the [from_time, to_time] interval on the weights panel.
ax_bottom_right.fill_between(time, weight_lim[0]+0.25,weight_lim[1]-0.25, where=(time>=from_time)*(time<=to_time), color='none', edgecolor='black', linewidth=2)
ax_bottom_right.set_ylim(weight_lim)
#if with_legend: ax_bottom_left.legend(loc='center', bbox_to_anchor=(x_legend, 0.5), ncol=1)
#ax_bottom_left.set_title('D1 Weights Evolution')
#ax_bottom_left.set_ylabel('Mean weight')

# Save the whole figure as SVG and as a 300-dpi PNG.
plt.savefig(output_figs+"figure_3.svg", bbox_inches="tight")
plt.savefig(output_figs+"figure_3.png", dpi=300, bbox_inches="tight")
plt.show()

# Figure 4

In [None]:
import itertools
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import json
import os
from matplotlib import pyplot as plt
import matplotlib as mpl
from scipy.stats import norm, chi2
import random

def product_dict(**kwargs):
    """Yield one dict per element of the Cartesian product of the keyword values.

    Example: ``product_dict(a=[1, 2], b='x')`` yields ``{'a': 1, 'b': 'x'}`` and
    then ``{'a': 2, 'b': 'x'}``. Keys keep the order in which they were passed,
    and the last keyword varies fastest.
    """
    names = list(kwargs)
    value_lists = [kwargs[name] for name in names]
    for combination in itertools.product(*value_lists):
        yield dict(zip(names, combination))

# CSV of experimental results (presumably read by the Figure 4 panels below).
data_file = 'data/experimental_data_nov14params.csv'

## Subpanel A

In [None]:
class StimulusActionMapping:
    """Tabular epsilon-greedy Q-learning baseline for the stimulus -> action task.

    Every episode draws one stimulus uniformly at random, chooses one action and
    receives +1 when the action equals the stimulus index, -1 otherwise.
    """

    def __init__(self, n_stimuli, alpha=0.1, gamma=0.9, epsilon=0.1, rng=None):
        """
        Parameters:
            n_stimuli (int): number of stimuli, which is also the number of actions.
            alpha (float): learning rate.
            gamma (float): discount factor of the bootstrap term in update_q_table.
            epsilon (float): probability of choosing a uniformly random action.
            rng: source of randomness providing the ``uniform``/``randint`` methods
                of ``random.Random``. ``None`` (default) uses the global ``random``
                module, exactly as before; pass ``random.Random(seed)`` for a
                reproducible, independent stream.
        """
        self.n_stimuli = n_stimuli
        self.alpha = alpha  # Learning rate
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Exploration rate
        self.rng = random if rng is None else rng
        # Q-table: rows are stimuli, columns are actions; starts at zero.
        self.q_table = np.zeros((n_stimuli, n_stimuli))

        # Fixed identity mapping (not random): stimulus i is rewarded only for action i.
        self.correct_action = {i: i for i in range(n_stimuli)}

    def get_reward(self, stimulus, action):
        """Reward is 1 if action matches stimulus mapping; otherwise -1."""
        return 1 if self.correct_action[stimulus] == action else -1

    def choose_action(self, stimulus):
        """Choose an action with an epsilon-greedy policy.

        Greedy ties are broken towards the lowest action index (np.argmax).
        """
        if self.rng.uniform(0, 1) < self.epsilon:
            return self.rng.randint(0, self.n_stimuli - 1)  # Explore: random action
        return np.argmax(self.q_table[stimulus])  # Exploit: best known action

    def update_q_table(self, stimulus, action, reward, next_action):
        """Apply one Q-learning style update to Q[stimulus, action].

        ``next_action`` is accepted for interface compatibility but unused.

        NOTE(review): episodes are single-step, yet the bootstrap term is
        gamma * max_a Q[stimulus, a] of the *same* stimulus instead of treating the
        step as terminal, so values do not converge to the immediate reward. Kept
        unchanged so existing results are reproduced.
        """
        best_future_q = np.max(self.q_table[stimulus])
        self.q_table[stimulus, action] += self.alpha * (reward + self.gamma * best_future_q - self.q_table[stimulus, action])

    def train(self, episodes=1000):
        """Train for ``episodes`` single-step episodes and track greedy accuracy.

        Returns:
            list: fraction of stimuli whose greedy action is correct, measured
            after the update of every 10th episode (episodes 0, 10, 20, ...).
        """
        accuracy_over_time = []
        for episode in range(episodes):
            stimulus = self.rng.randint(0, self.n_stimuli - 1)  # Random stimulus
            action = self.choose_action(stimulus)
            reward = self.get_reward(stimulus, action)
            self.update_q_table(stimulus, action, reward, action)

            # Calculate accuracy every 10 episodes
            if episode % 10 == 0:
                correct_count = sum(
                    np.argmax(self.q_table[s]) == self.correct_action[s]
                    for s in range(self.n_stimuli)
                )
                accuracy_over_time.append(correct_count / self.n_stimuli)

        return accuracy_over_time

    def test(self, stimulus):
        """Return the greedy (learned) action for ``stimulus``."""
        return np.argmax(self.q_table[stimulus])

episodes = 500
simulations = 100
confidence_level = 0.95
# StimulusActionMapping.train records accuracy every 10 episodes.
x_dom = list(range(0,episodes,10))

# StimulusActionMapping draws from Python's global `random`; seed it so the
# Q-learning baseline is reproducible.
random.seed(0)

# Q-learning agent results: one record per (n_stimuli, simulation), collected in a
# list and turned into a DataFrame once (growing a frame row by row is quadratic).
rl_records = []
for n_stimuli in tqdm([2, 4, 8, 16, 32]):
    for _ in range(simulations):
        agent = StimulusActionMapping(
            n_stimuli, 
            alpha = 0.1,
            gamma = 0.9,
            epsilon = 0.1
        )
        accuracy_over_time = agent.train(episodes=episodes)
        rl_records.append({'n_actions': n_stimuli, 'accuracy': np.array(accuracy_over_time)})
df_rl = pd.DataFrame(rl_records, columns=['n_actions', 'accuracy'])

# Load the SNN results
df_loaded = pd.read_csv(data_file)

# Constants shared by all runs, read from the first row
times = np.array(json.loads(df_loaded['times'].iloc[0]))
time_per_trial = times[-1] / df_loaded['n_trials'].iloc[0]
episode = times / time_per_trial

# Decode the JSON accuracy arrays and scale them by time_per_trial.
# NOTE(review): presumably the stored values are per unit of time — confirm.
df_loaded['accuracy'] = df_loaded['accuracy'].apply(
    lambda raw: np.array(json.loads(raw)) * time_per_trial
)

In [None]:
# Normal plot
# Parameters
confidence_level = 0.95

# Two-sided critical value of the standard normal distribution (a z-score,
# despite the `t_` prefix) for the requested confidence level.
t_alpha_over_2 = norm.ppf(1 - (1 - confidence_level) / 2)
fig, ax = plt.subplots(figsize=(6, 4.5))


def _mean_std_count(curves):
    """Return (mean, std, n) of a group of equal-length curves, taken across runs."""
    stacked = np.stack(curves)  # stack once and reuse for both statistics
    return np.mean(stacked, axis=0), np.std(stacked, axis=0), len(curves)


def _color_index(n_actions):
    """Colour slot shared by both agent types: 2 -> 0, 4 -> 1, 8 -> 2, ...

    Uses a rounded log2 so floating-point error in a log ratio cannot truncate
    an action count onto the wrong slot.
    """
    return int(round(np.log2(n_actions))) - 1


# Q-learning agents: hatched confidence bands only (their mean lines are not drawn).
grouped = df_rl.groupby(['n_actions'])['accuracy'].apply(_mean_std_count)
mpl.rcParams['hatch.linewidth'] = 3.0  # previous pdf hatch linewidth
for n_actions, (m, s, n) in grouped.items():
    if n_actions <= 2: continue
    ci = t_alpha_over_2 * (s / np.sqrt(n))
    # Colour follows the action count, matching the SNN curves below.
    # The fixed +3e-3 on the upper edge is kept from the original figure.
    ax.fill_between(x_dom, m-ci, m+ci+3e-3, color='white', alpha=0.5, edgecolor=f'C{_color_index(n_actions)}', linewidth=1.0, linestyle=':', hatch = '\\\\')


# SNN agents: mean curve with a shaded confidence band.
# NOTE(review): runs are grouped by (n_actions, with_ach), but both ACh settings
# would get the same label and colour; this is fine only if the CSV holds one setting.
grouped = df_loaded.groupby(['n_actions', 'with_ach'])['accuracy'].apply(_mean_std_count)
for (n_actions, _with_ach), (m, s, n) in grouped.items():
    ci = t_alpha_over_2 * (s / np.sqrt(n))
    color = f'C{_color_index(n_actions)}'
    ax.plot(episode, m, label=f"{n_actions} actions", color=color)
    ax.fill_between(episode, m-ci, m+ci, color=color, alpha=0.5)


# Customise the plot
ax.set_title("Mean Accuracy Over Time (CI 95%)")
ax.set_xlabel("Episode")
ax.set_ylabel("Accuracy")
ax.legend(loc='lower right')
ax.set_ylim((5e-2,1))

plt.savefig(tmp_figs+"figure_4a.svg", bbox_inches="tight")
plt.savefig(tmp_figs+"figure_4a.png", dpi=300, bbox_inches="tight")

plt.show()

## Subpanel B

In [None]:
import numpy as np
from scipy.stats import linregress
import matplotlib.pyplot as plt

# Function to process grouped data and calculate log values
def process_grouped_data(grouped, x_dom, skip, logy_threshold, is_snn=False):
    """
    Log-transform mean accuracy curves and fit a line to each in log-log space.

    Parameters:
        grouped (mapping or pandas.Series): key -> (mean, std, n) tuple; only the
            mean array (element 0) is used.
        x_dom (sequence): x values (episodes) matching each mean array.
        skip (int): number of leading points dropped before the log transform
            (needed when x_dom starts at 0).
        logy_threshold (float): only points with log10(mean) strictly below this
            value are used for the fit.
        is_snn (bool): kept for backward compatibility; RL and SNN groups are
            handled identically.

    Returns:
        log_episodes (ndarray): log10(x_dom[skip:]).
        log_means (dict): key -> log10(mean[skip:] + 1e-10).
        slopes (dict): key -> (slope, intercept, r²) for every group with at least
            two points below the threshold; other groups are omitted.
    """
    log_episodes = np.log10(np.asarray(x_dom, dtype=float)[skip:])

    log_means = {}
    for key, values in grouped.items():
        # The 1e-10 offset avoids log10(0) for curves that start at zero accuracy.
        log_means[key] = np.log10(values[0][skip:] + 1e-10)

    slopes = {}
    for key, log_mean in log_means.items():
        valid_mask = log_mean < logy_threshold
        # A single point cannot determine a slope, so require at least two.
        if np.count_nonzero(valid_mask) < 2:
            continue
        slope, intercept, r_value, _, _ = linregress(log_episodes[valid_mask], log_mean[valid_mask])
        slopes[key] = (slope, intercept, r_value**2)
        print(f"Key: {key}, Slope: {slope:.3f}, R²: {r_value**2:.3f}")

    return log_episodes, log_means, slopes

# Function to plot data and fitted lines
def plot_log_data(ax, log_episodes, log_means, slopes, label_prefix, plot_real_data, logy_threshold, color_offset=0, linestyle='-'):
    """
    Draw log-log accuracy curves and/or their fitted lines on ``ax``.

    Groups with 4 or fewer actions are skipped. A group's colour is
    ``C{position + color_offset}``, where position is its index in ``log_means``
    (skipped groups still count).

    Parameters:
        ax: matplotlib Axes (anything with a compatible ``plot`` method).
        log_episodes (ndarray): shared log10 x values.
        log_means (dict): key -> log10 mean accuracy; the key is the action count
            or a tuple whose first element is the action count.
        slopes (dict): key -> (slope, intercept, r²), as from process_grouped_data.
        label_prefix (str): text placed before "<n> actions" in the legend label.
        plot_real_data (bool): also draw the points below ``logy_threshold``; a
            group with no such points is then skipped entirely (fitted line too).
        logy_threshold (float): cut-off used to select the raw points.
        color_offset (int): added to the position to pick the colour.
        linestyle (str): line style for both raw and fitted lines.
    """
    for position, (key, log_mean) in enumerate(log_means.items()):
        n_actions = key[0] if isinstance(key, tuple) else key
        if n_actions <= 4:
            continue
        color = f'C{position + color_offset}'

        if plot_real_data:
            below = log_mean < logy_threshold
            if not below.any():
                continue
            ax.plot(log_episodes[below], log_mean[below], linestyle=linestyle, color=color, alpha=0.5)

        if key not in slopes:
            continue
        slope, intercept, _ = slopes[key]
        ax.plot(log_episodes, slope * log_episodes + intercept, linestyle=linestyle, color=color, label=f"{label_prefix} {n_actions} actions")

# Main script
# Fit settings: drop the first `skip` samples (both x grids start at 0, whose
# log10 is -inf) and fit only points with log10(accuracy) < logy_threshold.
skip = 2
logy_threshold = -0.1
plot_real_data = False

fig, ax = plt.subplots(figsize=(6, 4.5))


def _summarize_runs(runs):
    """(mean, std, count) across a group of equal-length accuracy arrays."""
    return (np.mean(np.stack(runs), axis=0), np.std(np.stack(runs), axis=0), len(runs))


# Q-learning agents: accuracy sampled every 10 episodes over 500 episodes.
x_dom_rl = list(range(0, 500, 10))
grouped_rl = df_rl.groupby(['n_actions'])['accuracy'].apply(_summarize_runs)
log_episodes_rl, log_means_rl, slopes_rl = process_grouped_data(grouped_rl, x_dom_rl, skip, logy_threshold)
plot_log_data(ax, log_episodes_rl, log_means_rl, slopes_rl, "QL", plot_real_data, logy_threshold, linestyle='--')

# SNN agents: x is the sample index of each stored accuracy trace.
# NOTE(review): panel A plots the SNN curves against `episode` (times / time_per_trial),
# but here x is the raw sample index. A constant rescaling of x leaves the fitted
# slopes unchanged but shifts the SNN lines horizontally relative to the QL ones,
# unless each sample is exactly one episode -- confirm.
x_dom_snn = list(range(len(df_loaded.iloc[0]['accuracy'])))
grouped_snn = df_loaded.groupby(['n_actions', 'with_ach'])['accuracy'].apply(_summarize_runs)
log_episodes_snn, log_means_snn, slopes_snn = process_grouped_data(grouped_snn, x_dom_snn, skip, logy_threshold, is_snn=True)
plot_log_data(ax, log_episodes_snn, log_means_snn, slopes_snn, "SNN", plot_real_data, logy_threshold, color_offset=1)

# Titles, labels and limits
ax.set_title("Log-Log Plot with Fitted Lines")
ax.set_xlabel("Log10(Episode)")
ax.set_ylabel("Log10(Accuracy)")
ax.set_ylim((-1.5, 0.0))
ax.set_xlim((1.5, 3))
ax.legend()

for path, extra in ((tmp_figs+"figure_4b.svg", {}), (tmp_figs+"figure_4b.png", {"dpi": 300})):
    plt.savefig(path, bbox_inches="tight", **extra)

plt.show()

## All

In [None]:
import svgutils.transform as sg
import sys 

# Assemble Figure 4: panel A on the left, panel B on the right, on one SVG canvas.
fig = sg.SVGFigure()

# Matplotlib-generated panels saved by the two cells above.
fig1, fig2 = sg.fromfile(tmp_figs + 'figure_4a.svg'), sg.fromfile(tmp_figs + 'figure_4b.svg')

# Root groups of each panel; panel B is shifted 10.5 cm to the right.
plot1, plot2 = fig1.getroot(), fig2.getroot()
plot2.moveto(10.5*cm2px, 0*cm2px)

# Bold panel letters at the top-left corner of each panel.
txt1, txt2 = (
    sg.TextElement("1mm", "5mm", "A", size=18, weight="bold"),
    sg.TextElement("10.5cm", "5mm", "B", size=18, weight="bold"),
)

# Panels first, then letters, so the letters are drawn on top.
fig.append([plot1, plot2])
fig.append([txt1, txt2])

fig.save(output_figs + "figure_4.svg")

# Figure 5

## Simulating

In [None]:
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import patches

from scipy.ndimage import gaussian_filter1d
def smooth_spikes(t,dt, spike_times, sigma_s=200.0, resolution=1000):
    """Smoothed event rate of ``spike_times`` inside the window [t, t+dt].

    Spikes are binned into ``resolution`` equal bins, blurred with a Gaussian of
    ``sigma_s/dt`` bins (i.e. ``sigma_s/resolution`` time units) and scaled by
    ``resolution/dt`` (one over the bin width), giving events per time unit.

    Returns:
        (bin_starts, rate): the left edge of every bin and the smoothed rate.
    """
    window_end = t + dt
    inside = (spike_times >= t) & (spike_times <= window_end)
    counts, edges = np.histogram(spike_times[inside], bins=resolution, range=(t, window_end))
    blurred = gaussian_filter1d(counts.astype('float'), sigma_s / dt)
    rate = blurred * resolution / dt
    return edges[:-1], rate

from experiments import Experiment_StatePopulationWithChannelsReLearning


n_actions = 4
lr = 1e-2
d1_lr = lr
d2_lr = 5*lr

np.random.seed(1)
n_trials = 500
# The reward policy is remapped after this trial: the last one of the first half.
# Derived from n_trials so it stays aligned with the experiment_design indices below.
switch_trial = n_trials//2 - 1
e = Experiment_StatePopulationWithChannelsReLearning(
    n_trials=n_trials, 
    n_trial_types=n_actions,
    n_actions=n_actions,
    
    #with_ach=False,
    
    d1_learning_rate = d1_lr, #2e-2, #5e-3,
    d2_learning_rate = d2_lr, #8e-2, #2e-2 #*4

    n_trial_steps = 50,
    trial_step_duration = 4,
    trial_width_step = 1,
    trial_width = 10, #10,
)

# Pin the first and last n_actions trials of each half to present stimulus
# 0..n_actions-1 in order (presumably experiment_design maps trial index ->
# trial type; confirm).
for i in range(n_actions):
    e.experiment_design[i] = i
    e.experiment_design[n_trials//2 - n_actions + i] = i
    e.experiment_design[n_trials//2 + i] = i
    e.experiment_design[n_trials - n_actions + i] = i

for current_time, i_trial, trial in e.run(as_generator=True):
    if i_trial == switch_trial:
        # Cyclically shift the rows of the reward policy by one (new row k = old row k-1).
        remap = np.roll(range(e.reward_policy.shape[0]), 1)
        e.reward_policy = e.reward_policy[remap, :]

# Recorded spikes: times (ot) and neuron ids (oi).
ot = np.array(e.ot)
oi = np.array(e.oi)

## Plotting

### Subplot A

In [None]:
fig, ax = plt.subplots(figsize=(10,1.5))

# Smoothed reward and punishment event rates over the whole run; time is turned
# into an episode index with x // 2 (two time units per episode).
x,y = smooth_spikes(0,e.current_time, ot[oi==e.reward_layer[0]], sigma_s=10.0e3)
episode = x // 2
ax.plot(episode, y*200, color='green', label='rewards')
x, y = smooth_spikes(0,e.current_time, ot[oi==e.punish_layer[0]], sigma_s=10.0e3)
ax.plot(episode, y*200, color='red', label='punishments')

ax.legend(loc='center right')
ax.set_xlabel('Episode')
ax.set_ylabel('% of events')
ax.set_ylim((-5,105))
_ = ax.set_title(f'Reinforcement history with {e.n_actions} actions')
# Dotted line at the policy switch (episode 250).
ax.vlines(250, -10,110, linestyle=':', linewidth=2, color='black')

# Callout box with an arrow pointing at the policy switch.
xy_target = (250, 50)  # point the arrow points to
xy_text = (150, 50)    # position of the text box
ax.annotate(
    "Policy\nswitch",
    xy=xy_target,
    xytext=xy_text,
    fontsize=10, 
    color="black",
    bbox=dict(boxstyle="round,pad=0.4", fc="lightyellow", ec="black", lw=1),
    arrowprops=dict(arrowstyle="->", color="black", lw=1.5),
)

# Semi-transparent grey boxes, 8 episodes wide, highlighting four episode ranges
# (presumably the windows detailed in the following panels -- confirm).
for left in (0, 250-10, 253, 500-8):
    rect = patches.Rectangle(
        (left, -2.5), 8, 105,
        linewidth=1.5, edgecolor="black", facecolor="gray", alpha=0.5,
    )
    ax.add_patch(rect)


plt.savefig(tmp_figs+"figure_5a.svg", bbox_inches="tight")
plt.savefig(tmp_figs+"figure_5a.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
def plot_spikes(t=None, dt=10, xsize=2.5, ysize=0.75, yaxis=False, channels=None, x_legend=None, color_shift=0, with_legend=False): 
    """Raster / firing-rate figure of the global experiment `e` over [t, t+dt].

    Top panel: input spikes coloured by stimulus, reward (green) and punishment
    (red) shading, and grey trial-start lines at every even time step. One panel
    per channel below: smoothed MSN D1/D2 firing rates, action times (black
    lines) and, for dt <= 20, the raw D1/D2 rasters.

    Reads the module-level `e`, `ot` (spike times), `oi` (neuron ids) and the
    helper `smooth_spikes`. Sets `plt.rcParams['figure.figsize']`.

    Parameters:
        t: window start; None means the last `dt` of the simulation.
        dt: window length.
        xsize, ysize: figure width and per-row height (rcParams figsize units).
        yaxis: show the firing-rate axis on the right of each channel panel.
        channels: indices of the channels to draw; None draws all of `e.channels`.
        x_legend: if not None, x anchor of the legend on the last channel panel.
        color_shift: offset applied to the colour of the channel titles.
        with_legend: draw the legend of the top panel.

    Returns:
        (fig, axs) of the created figure.
    """
    t = e.current_time-dt if t is None else t
    if channels is None:
        channels = e.channels
        n_channels = list(range(len(e.channels)))
    else:
        n_channels = channels[:]
        channels = [e.channels[i_chn] for i_chn in channels]
    # Trial boundaries: every even time step inside the window.
    trial_changes = [i for i in range(int(t), int(t+dt+1)) if i%2==0]

    plt.rcParams['figure.figsize'] = (xsize, ysize*(2+len(channels)))
    fig, axs = plt.subplots(
        1+len(channels),1, sharex=True, 
        height_ratios=[0.5] + [1]*len(channels) 
    )

    # --- Top panel: inputs and reinforcement ---
    ax1 = axs[0]
    for k, start in enumerate(trial_changes):
        ax1.axvline(x=start+0.1, linestyle='-', color='silver', linewidth=5, label='Trial start' if k==0 else None, zorder=-10)
    # Input spikes inside the window.
    inputs_sub = np.isin(oi, e.input_layer) * (ot>=t) * (ot<=(t+dt))
    input_ids = oi[inputs_sub]
    # Reinforcement shading spans the (negated) id range of the visible inputs.
    # If the window has no input spikes, fall back to 0 (min/max of an empty array raise).
    min_y = -np.min(input_ids) if input_ids.size else 0
    max_y = -np.max(input_ids) if input_ids.size else 0
    ax1.get_yaxis().set_visible(False)
    ax1.set_title('Inputs and Reinforcement')
    sub = np.isin(oi, e.reward_layer)
    ax1.fill_between(ot, min_y,max_y, where=sub, color='limegreen', linewidth=10, alpha=0.5)
    sub = np.isin(oi, e.punish_layer)
    ax1.fill_between(ot, min_y,max_y, where=sub, color='red', linewidth=10, alpha=0.5)
    # Input raster, coloured by the stimulus group of each input neuron.
    sub = inputs_sub
    n_neurons_per_input = e.n_inputs // e.n_actions
    colors = (oi[sub] // n_neurons_per_input)
    for icolor in set(colors):
        _sub = colors==icolor
        ax1.scatter(ot[sub][_sub],-oi[sub][_sub], s=1, color=f'C{icolor}')
        ax1.scatter([],[], s=10, color=f'C{icolor}', label=f'Stimulus #{icolor+1}')
    ax1.fill_between([],[],[], label='Rewards', color='limegreen', alpha=0.5)
    ax1.fill_between([],[],[], label='Punishments', color='red', alpha=0.5)
    if with_legend: ax1.legend(loc='center right', bbox_to_anchor=(1.6, 0.5), ncol=1)

    # --- One panel per channel: MSN activity and actions ---
    for i, (n, c) in enumerate(zip(n_channels, channels)):
        ax = axs[1+i]

        # Trial changes. A separate loop variable keeps `i` equal to the channel
        # position, which the legend check below relies on.
        for k, start in enumerate(trial_changes):
            ax.axvline(x=start+0.1, linestyle='-', color='silver', linewidth=5, label='Trial start' if k==0 else None, zorder=-10)

        ax.patch.set_alpha(0.0)
        ax.set_ylabel('Firing rate\n(Hz)')
        # Smoothed D1 / D2 population rates, divided by this channel's n_outputs.
        sub1 = np.isin(oi, c.msn_d1_layer)
        sub2 = np.isin(oi, c.msn_d2_layer)
        # smooth_spikes takes (start, duration): pass dt, not the end time t+dt.
        x,y = smooth_spikes(t,dt, ot[sub1], sigma_s=3000.0, resolution=100000)
        y /= c.n_outputs
        ax.plot(x,y, color='teal', label='MSN D1', linewidth=2.0)
        x,y = smooth_spikes(t,dt, ot[sub2], sigma_s=3000.0, resolution=100000)
        y /= c.n_outputs
        ax.plot(x,y, color='magenta', label='MSN D2', linewidth=2.0)
        ax.plot([],[], color='black', label='Actions', linewidth=2.0)
        if i==(len(channels)-1) and x_legend is not None: ax.legend(loc='lower right', bbox_to_anchor=(x_legend, 0.5), ncol=1)
        ax.set_ylim((0,100))

        # Rate axis on the right; raster / actions go on a twin axis behind it.
        _ax = ax.twinx()
        ax.yaxis.set_label_position("right")
        ax.yaxis.tick_right()
        if yaxis==False:
            ax.get_yaxis().set_visible(False)

        ax = _ax
        ax.zorder = -1
        ax.get_yaxis().set_visible(False)
        ax.set_title(f'Channel #{n+1}', color=f'C{(n+color_shift)%len(channels)}', fontweight='bold')
        ax.set_xlim((t,t+dt))

        # Actions: vertical lines spanning the channel's neuron-id range.
        sub = oi==e.action_layer[n]
        ax.vlines(ot[sub], min(c.msn_d1_layer), max(c.msn_d2_layer), color='black')

        # Raster plot (only for short windows)
        if dt <= 20:
            ax.scatter(ot[sub1],oi[sub1], s=1, color='skyblue')
            ax.scatter(ot[sub2],oi[sub2], s=1, color='violet')

    axs[-1].set_xlabel('Time (s)')
    plt.tight_layout()
    return fig, axs

### Subplot B

In [None]:
# Panel B: the window [0, 8] at the start of the run.
plot_spikes(t=0, dt=8)
for path, extra in ((tmp_figs+"figure_5b.svg", {}), (tmp_figs+"figure_5b.png", {"dpi": 300})):
    plt.savefig(path, bbox_inches="tight", **extra)

### Subplot C

In [None]:
# Panel C: the window [492, 500].
plot_spikes(dt=8, t=492)
for path, extra in ((tmp_figs+"figure_5c.svg", {}), (tmp_figs+"figure_5c.png", {"dpi": 300})):
    plt.savefig(path, bbox_inches="tight", **extra)

### Subplot D

In [None]:
# Panel D: the window [500, 508], channel title colours shifted by one.
plot_spikes(dt=8, t=500, color_shift=1)
for path, extra in ((tmp_figs+"figure_5d.svg", {}), (tmp_figs+"figure_5d.png", {"dpi": 300})):
    plt.savefig(path, bbox_inches="tight", **extra)

### Subplot E

In [None]:
# Panel E: the last 8 time units of the run (default t), wider, with the rate axis shown.
plot_spikes(dt=8, xsize=3.1, yaxis=True, color_shift=1)
for path, extra in ((tmp_figs+"figure_5e.svg", {}), (tmp_figs+"figure_5e.png", {"dpi": 300})):
    plt.savefig(path, bbox_inches="tight", **extra)

### All

In [None]:
# These images are too big to work with them vectorized, so I'll just merge them manually as pngs
# The SVG composition below is disabled (`if False`) and kept only for reference.
if False:
    import svgutils.transform as sg
    import sys 
    
    # Create a new SVG figure.
    fig = sg.SVGFigure()
    
    # Load the matplotlib-generated panels A-E.
    fig1 = sg.fromfile(tmp_figs + 'figure_5a.svg')
    fig2 = sg.fromfile(tmp_figs + 'figure_5b.svg')
    fig3 = sg.fromfile(tmp_figs + 'figure_5c.svg')
    fig4 = sg.fromfile(tmp_figs + 'figure_5d.svg')
    fig5 = sg.fromfile(tmp_figs + 'figure_5e.svg')
    
    # Get the plot objects.
    plot1 = fig1.getroot()
    plot2 = fig2.getroot()
    plot3 = fig3.getroot()
    plot4 = fig4.getroot()
    plot5 = fig5.getroot()
    # TODO: only panel B is positioned; panels C-E still sit at the origin.
    plot2.moveto(10.5*cm2px, 0*cm2px, scale_x=1, scale_y=1)
    
    # Panel letters, one per panel. TODO: C-E reuse B's coordinates until the
    # layout is decided.
    txt1 = sg.TextElement("1mm", "5mm", "A", size=18, weight="bold")
    txt2 = sg.TextElement("10.5cm", "5mm", "B", size=18, weight="bold")
    txt3 = sg.TextElement("10.5cm", "5mm", "C", size=18, weight="bold")
    txt4 = sg.TextElement("10.5cm", "5mm", "D", size=18, weight="bold")
    txt5 = sg.TextElement("10.5cm", "5mm", "E", size=18, weight="bold")
    
    # Append plots and labels to the figure.
    fig.append([plot1, plot2, plot3, plot4, plot5])
    fig.append([txt1, txt2, txt3, txt4, txt5])
    
    # Save the generated SVG file.
    fig.save(output_figs + "figure_5.svg")