In [None]:
import plotly.io as pio
# Render plotly figures with both the plotly MIME-type renderer and the
# "notebook_connected" renderer (the "+" combines renderers).
pio.renderers.default = "plotly_mimetype+notebook_connected"

# Colour scale shared by every plotly heatmap and contour plot in this notebook.
colorscale = "Viridis"

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import brainpy as bp
import brainpy.math as bm

# BrainPy (used only for the phase-plane analysis) runs in float64 on the CPU.
bm.enable_x64()
bm.set_platform('cpu')

# Seed NumPy's global RNG so that "Restart & Run All" reproduces every noisy
# simulation and every randomly drawn stimulus/connectivity in this notebook.
np.random.seed(0)

## Decision Making


![](figs/shadlen_newsome.jpeg){fig-align="center" width="600"}

::: {style="text-align: right"}
[Shadlen and Newsome, 2001](https://doi.org/10.1152/jn.2001.86.4.1916)
:::

## Double-Well: Decision Making

:::: {.columns}

::: {.column}

In [None]:
# Weights of the E/I circuit (naming follows W^{post,pre}: WEI is I -> E).
WEE = 2.9      # recurrent excitation E -> E
WEI = -1       # inhibition I -> E (negative sign = inhibitory)
WIE = 1        # excitation E -> I
gamma = 1.8    # presumably the gain of the inhibitory population -- TODO confirm
sigma = 0.3    # scale of the Gaussian noise added to each input at every Euler step

tauE = 0.05    # excitatory time constant (s)

# Effective couplings once the shared inhibitory population is eliminated:
# JI is the cross-inhibition between the two excitatory populations and
# JE the self-excitation that remains after subtracting it.
JI = -gamma * WEI * WIE
JE = WEE - JI

@bp.odeint
def int_s1(s1, t, s2, coh=0, drive=0.1, JE=JE, JI=JI):
    """Rate-space dynamics of population 1 for the phase-plane analysis.

    The input current obeys tauE dI1/dt = -I1 + JE*s1 - JI*s2 + drive + coh, with
    rate s1 = (1 + tanh(I1 - 0.5)) / 2.  Writing this in terms of s1 multiplies the
    current equation by the slope of that transfer function, 2*s1*(1 - s1).
    Population 1 receives ``drive + coh``.
    """
    # Slope of the transfer function at rate s1, divided by the time constant.
    gain = 2*s1*(1.-s1)/tauE
    # Inverse transfer function: the current that produces rate s1.
    current = 0.5 + bm.atanh(2*s1 - 1.)
    return - current * gain + JE*s1 * gain - JI*s2 * gain + (drive + coh) *gain

@bp.odeint
def int_s2(s2, t, s1, coh=0, drive=0.1, JE=JE, JI=JI):
    """Rate-space dynamics of population 2 for the phase-plane analysis.

    Mirror image of population 1, except that population 2 receives
    ``drive - coh``.  This is the same split used by the noisy simulations
    (``inputE2 = drive - coh``) and by the energy landscape
    (``get_energy(drive + coh, drive - coh)``).
    """
    # Slope of the transfer function (1 + tanh(x - 0.5)) / 2 at rate s2, over tauE.
    fct = 2*s2*(1.-s2)/tauE
    # Inverse transfer function: the current that produces rate s2.
    cnv = 0.5 + bm.atanh(2*s2 - 1.)
    return - cnv * fct + JE*s2 * fct - JI*s1 * fct + (drive - coh) *fct


# Phase plane of the two-population decision model at drive 0.1, zero coherence.
analyzer = bp.analysis.PhasePlane2D(
    model=[int_s1, int_s2],
    target_vars={'s1': [0, 1], 's2': [0, 1]},
    pars_update={'drive': 0.1, 'coh': 0},
    resolutions=0.001,
)
analyzer.plot_vector_field()
# The keys and expressions of `coords` must use the model's own target variables
# (s1, s2); a key such as 'r2' matches no variable and has no effect.
analyzer.plot_nullcline(coords=dict(s2='s2-s1'),
                        x_style={'fmt': '-'},
                        y_style={'fmt': '-'})
analyzer.plot_fixed_point()
plt.gca().set_box_aspect(1)
plt.tight_layout()

:::

::: {.column .fragment}

In [None]:
dt = 0.005                     # Euler step (s)
time = np.arange(0, 10, dt)    # 10 s of simulated time
Nt = len(time)

def curr_to_rate(x):
    """Sigmoidal transfer function mapping an input current to a rate in [0, 1].

    Equal to 0.5 at x = 0.5; works element-wise on arrays.
    """
    return 0.5 * (1 + np.tanh(x - 0.5))

def get_traces(drive, coh):
    """Simulate one noisy trial of the two-population decision circuit.

    Both input currents start at -0.005.  From step Nt//4 onwards population 1
    receives ``drive + coh`` and population 2 ``drive - coh``; each Euler step
    also adds Gaussian noise of scale ``sigma`` to each population.

    Returns
    -------
    rateE1, rateE2 : ndarray, shape (Nt,)
        Rates of populations 1 and 2 (currents passed through ``curr_to_rate``).
    inputE1, inputE2 : ndarray, shape (Nt,)
        External inputs of populations 1 and 2.
    """
    onset = Nt // 4

    inputE1 = np.zeros((Nt,))
    inputE1[onset:] = drive + coh
    inputE2 = np.zeros((Nt,))
    inputE2[onset:] = drive - coh

    # Total input currents, integrated with forward Euler.
    inpE1 = np.full((Nt,), -0.005)
    inpE2 = np.full((Nt,), -0.005)

    step = dt/tauE
    for k in range(Nt - 1):
        cur1, cur2 = inpE1[k], inpE2[k]
        r1, r2 = curr_to_rate(cur1), curr_to_rate(cur2)
        # Noise for population 1 is drawn before the noise for population 2.
        inpE1[k + 1] = cur1 + step*(-cur1 + JE*r1 - JI*r2 + inputE1[k] + sigma*np.random.randn())
        inpE2[k + 1] = cur2 + step*(-cur2 + JE*r2 - JI*r1 + inputE2[k] + sigma*np.random.randn())

    rateE1 = curr_to_rate(inpE1)
    rateE2 = curr_to_rate(inpE2)
    return rateE1, rateE2, inputE1, inputE2

# Build 20 animation frames.  Every frame is an independent noisy simulation with
# the same drive (0.1) and zero coherence, so stepping through the frames shows
# the trial-to-trial variability of the dynamics.  Each frame adds its 4 traces
# in the same order, since plotly matches frame traces to figure traces by position.

frames = []
for step in range(20):
    # Top row (short): external inputs; bottom row (tall): population rates.
    fig = make_subplots(rows=2, cols=1, row_heights=[0.3, 0.7])
    rateE1, rateE2, inputE1, inputE2 = get_traces(0.1, 0.)
    fig.add_trace(
        go.Scatter(
            visible=True,
            line=dict(color="#FF0000", width=4),
            name="Population 1",
            x = time,
            y = rateE1,
            showlegend = False),
            row = 2, col = 1)
    fig.add_trace(
        go.Scatter(
            visible=True,
            line=dict(color="#DD0088", width=4),
            name="Population 2",
            x = time,
            y = rateE2,
            showlegend = False),
            row = 2, col = 1)
    fig.add_trace(
        go.Scatter(
            visible=True,
            line=dict(color="#FF0000", width=4),
            name="",
            x = time,
            y = inputE1,
            showlegend = False),
            row = 1, col = 1)
    fig.add_trace(
        go.Scatter(
            visible=True,
            line=dict(color="#DD0088", width=4),
            name="",
            x = time,
            y = inputE2,
            showlegend=False),
            row = 1, col = 1)

    fig.update_xaxes(title_text="", range=[0, 10], row=1, col=1)
    fig.update_xaxes(title_text="time (s)", range=[0, 10], row=2, col=1)
    fig.update_yaxes(title_text="current", range=[0, 0.25], row=1, col=1)
    fig.update_yaxes(title_text="rate", range=[0, 1], row=2, col=1)

    # Snapshot this trial's traces and axes as one named animation frame.
    frames += [go.Frame(data=fig.data, layout=fig.layout, name=str(step))]

    ## store the first frame to reuse later
    if step == 0:
        first_fig = fig

# The figure that is shown carries all frames; its initial content is frame 0.
fig = go.Figure(frames=frames)

## add the first frame to the figure so it shows up initially
fig.add_traces(first_fig.data,)
fig.layout = first_fig.layout

## the rest is copied from the plotly documentation example on MRI volume slices
def frame_args(duration):
    """Animation options shared by the play button and the slider steps.

    Parameters
    ----------
    duration : int
        Display time of each frame (plotly interprets it in milliseconds).

    Returns
    -------
    dict
        Options for plotly's "animate" method: jump immediately, continue from
        the current frame, and use no transition easing time.
    """
    return dict(
        frame=dict(duration=duration),
        mode="immediate",
        fromcurrent=True,
        transition=dict(duration=0, easing="linear"),
    )

# Slider under the plot: one "animate" step per stored frame, labelled 0, 1, 2, ...
sliders = [
    dict(
        pad=dict(b=10, t=60),
        len=0.9,
        x=0.1,
        y=0,
        steps=[
            dict(args=[[frame.name], frame_args(2000)],
                 label=str(idx),
                 method="animate")
            for idx, frame in enumerate(fig.frames)
        ],
    )
]

fig.update_layout(
    title='',
    width=500,
    height=400,
    margin=dict(t=0, b=0, l=0, r=0),
    hovermode=False,
    # Play / pause buttons placed to the left of the slider.
    updatemenus=[
        dict(
            buttons=[
                dict(args=[None, frame_args(2000)],
                     label="&#9654;",  # play symbol
                     method="animate"),
                dict(args=[[None], frame_args(2000)],
                     label="&#9724;",  # pause symbol
                     method="animate"),
            ],
            direction="left",
            pad=dict(r=10, t=70),
            type="buttons",
            x=0.1,
            y=0,
        )
    ],
    sliders=sliders,
)

# Hide plotly's mode bar.
config = {'displayModeBar': False}

fig.show(config=config)

:::


::: {style="text-align: right"}
[Wang 2002](https://doi.org/10.1016/s0896-6273(02)01092-9), [Wong and Wang 2006](https://doi.org/10.1523/JNEUROSCI.3733-05.2006)
:::

::::


## Double well: Decision making

In [None]:
def integral(rate):
    """Antiderivative of the inverse transfer function, used in the energy landscape.

    With u = 2*rate - 1, the inverse of the transfer function (1 + tanh(x - 0.5)) / 2
    is x = 0.5 + arctanh(u).  Its integral over rate is
    0.5*rate + 0.5*(u*arctanh(u) + 0.5*log(1 - u**2)), returned here shifted by the
    constant -0.35.

    At rate = 0 or 1 the two logarithmic terms are +inf and -inf.  Their sum tends
    to ln 2, so the finite limit is returned instead of NaN (and no RuntimeWarning
    is emitted).  Works on scalars and arrays.
    """
    lim = 2*rate - 1
    with np.errstate(divide="ignore", invalid="ignore"):
        core = lim*np.arctanh(lim) + 0.5*np.log(np.abs(1 - lim*lim))
    # Replace the inf - inf = NaN at the boundaries by the analytic limit ln 2;
    # `[()]` unwraps the 0-d result for scalar input.
    core = np.where(np.abs(lim) == 1, np.log(2.0), core)[()]
    return 0.5*rate + 0.5*(core - 0.7)

def get_energy(drive1, drive2, n_points=100):
    """Energy landscape of the two-population model on a regular rate grid.

    E(r1, r2) = -JE/2 (r1^2 + r2^2) + JI r1 r2 - (drive1 r1 + drive2 r2)
                + integral(r1) + integral(r2)

    Parameters
    ----------
    drive1, drive2 : float
        External inputs to populations 1 and 2.
    n_points : int, optional
        Number of grid points per axis on [0, 1] (default 100).

    Returns
    -------
    energy : ndarray, shape (n_points, n_points)
        ``energy[i, j]`` is the energy at (rate1[i], rate2[j]).  Note that plotly's
        contour/heatmap expect z[row][col] at (x[col], y[row]), so plot ``energy.T``
        with x=rate1, y=rate2.
    rate1, rate2 : ndarray, shape (n_points,)
        The grid coordinates.
    """
    rate1 = np.linspace(0, 1, n_points)
    rate2 = np.linspace(0, 1, n_points)
    # Broadcast instead of a Python double loop: axis 0 runs over rate1, axis 1 over rate2.
    r1 = rate1[:, np.newaxis]
    r2 = rate2[np.newaxis, :]
    energy = (-0.5*JE*(r1**2 + r2**2) + JI*r1*r2 - (drive1*r1 + drive2*r2)
              + integral(r1) + integral(r2))
    return energy, rate1, rate2

fig = make_subplots(rows=2, cols=2,
                    specs=[[{'rowspan': 1, 'colspan': 1}, {'rowspan': 2, 'colspan': 1}],
                           [{'rowspan': 1, 'colspan': 1}, None]],
                    row_heights=[0.3, 0.7],
                    vertical_spacing=0.12,
                    horizontal_spacing=0.08)

# One slider step per drive value.  Each step owns 5 consecutive traces:
# rate 1, rate 2, input 1, input 2 (time courses), then the energy landscape.
n_per_step = 5
coh = 0
for drive in np.arange(0, 0.2, 0.01):
    rateE1, rateE2, inputE1, inputE2 = get_traces(drive, 0.)
    for trace_y, color, row in ((rateE1, "#FF0000", 2),
                                (rateE2, "#DD0088", 2),
                                (inputE1, "#FF0000", 1),
                                (inputE2, "#DD0088", 1)):
        fig.add_trace(
            go.Scatter(
                visible=False,
                line=dict(color=color, width=4),
                name="",
                x=time,
                y=trace_y,
                showlegend=False),
            row=row, col=1)

    energy, grid1, grid2 = get_energy(drive+coh, drive-coh)
    fig.add_trace(
        go.Contour(
            visible=False,
            x=grid1,
            y=grid2,
            # get_energy returns energy[i, j] at (rate1[i], rate2[j]) but plotly draws
            # z[row][col] at (x[col], y[row]); transpose so the x axis really is rate 1.
            z=energy.T,
            contours={"start": -0.5, "end": -0.2, "size": 0.002, "coloring": "lines"},
            colorscale=colorscale,
            colorbar_title_text='energy'),
        row=1, col=2)

# Initially show the middle step, and start the slider on that same step.
n_steps = len(fig.data) // n_per_step
mid_step = n_steps // 2
for trace in fig.data[n_per_step*mid_step:n_per_step*(mid_step + 1)]:
    trace.visible = True

# Each slider step hides every trace except its own group.
steps = []
for i in range(n_steps):
    visible = [False] * len(fig.data)
    visible[n_per_step*i:n_per_step*(i + 1)] = [True] * n_per_step
    steps.append(dict(
        method="update",
        args=[{"visible": visible}],
        label=str(np.around(0.01 * i, 2)),
    ))

sliders = [dict(
    active=mid_step,
    currentvalue={"prefix": "Drive: "},
    pad={"t": 10},
    steps=steps
)]

fig.update_layout(
    sliders=sliders, height=450, width=1000,
    autosize=False,
    margin=dict(t=0, b=0, l=0, r=0),
    template="plotly_white",
    hovermode=False,
)

fig.update_xaxes(title_text="", range=[0, 10], row=1, col=1)
fig.update_xaxes(title_text="time (s)", range=[0, 10], row=2, col=1)
fig.update_yaxes(title_text="current", range=[0, 0.25], row=1, col=1)
fig.update_yaxes(title_text="rate", range=[0, 1], row=2, col=1)
fig.update_xaxes(title_text="rate 1", range=[0, 1], row=1, col=2)
fig.update_yaxes(title_text="rate 2", range=[0, 1], row=1, col=2)

config = {'displayModeBar': False}

fig.show(config=config)

## Double well: Decision making

In [None]:
fig = make_subplots(rows=2, cols=2,
                    specs=[[{'rowspan': 1, 'colspan': 1}, {'rowspan': 2, 'colspan': 1}],
                           [{'rowspan': 1, 'colspan': 1}, None]],
                    row_heights=[0.3, 0.7],
                    vertical_spacing=0.12,
                    horizontal_spacing=0.08)

# One slider step per coherence value.  Each step owns 5 consecutive traces:
# rate 1, rate 2, input 1, input 2 (time courses), then the energy landscape.
n_per_step = 5
drive = 0.1
for coh in np.arange(-0.05, 0.055, 0.01):
    rateE1, rateE2, inputE1, inputE2 = get_traces(drive, coh)
    for trace_y, color, row in ((rateE1, "#FF0000", 2),
                                (rateE2, "#DD0088", 2),
                                (inputE1, "#FF0000", 1),
                                (inputE2, "#DD0088", 1)):
        fig.add_trace(
            go.Scatter(
                visible=False,
                line=dict(color=color, width=4),
                name="",
                x=time,
                y=trace_y,
                showlegend=False),
            row=row, col=1)

    energy, grid1, grid2 = get_energy(drive+coh, drive-coh)
    fig.add_trace(
        go.Contour(
            visible=False,
            x=grid1,
            y=grid2,
            # get_energy returns energy[i, j] at (rate1[i], rate2[j]) but plotly draws
            # z[row][col] at (x[col], y[row]); transpose so the x axis really is rate 1
            # (visible here because coh != 0 makes the landscape asymmetric).
            z=energy.T,
            contours={"start": -0.5, "end": -0.2, "size": 0.002, "coloring": "lines"},
            colorscale=colorscale,
            colorbar_title_text='energy'),
        row=1, col=2)

# Initially show the middle step, and start the slider on that same step.
# Working in whole steps keeps the 5 traces of one step together even when the
# number of steps is odd (here 11).
n_steps = len(fig.data) // n_per_step
mid_step = n_steps // 2
for trace in fig.data[n_per_step*mid_step:n_per_step*(mid_step + 1)]:
    trace.visible = True

# Each slider step hides every trace except its own group.
steps = []
for i in range(n_steps):
    visible = [False] * len(fig.data)
    visible[n_per_step*i:n_per_step*(i + 1)] = [True] * n_per_step
    steps.append(dict(
        method="update",
        args=[{"visible": visible}],
        label=str(np.around(0.01 * i - 0.05, 2)),
    ))

sliders = [dict(
    active=mid_step,
    currentvalue={"prefix": "Coherency: "},
    pad={"t": 10},
    steps=steps
)]

fig.update_layout(
    sliders=sliders, height=450, width=1000,
    autosize=False,
    margin=dict(t=0, b=0, l=0, r=0),
    template="plotly_white",
    hovermode=False,
)

fig.update_xaxes(title_text="", range=[0, 10], row=1, col=1)
fig.update_xaxes(title_text="time (s)", range=[0, 10], row=2, col=1)
fig.update_yaxes(title_text="current", range=[0, 0.25], row=1, col=1)
fig.update_yaxes(title_text="rate", range=[0, 1], row=2, col=1)
fig.update_xaxes(title_text="rate 1", range=[0, 1], row=1, col=2)
fig.update_yaxes(title_text="rate 2", range=[0, 1], row=1, col=2)

config = {'displayModeBar': False}

fig.show(config=config)

## Experimental data

:::: {.columns}

::: {.column width="60%"}
![](figs/roitman1.png){fig-align="center" width="600"}

:::

::: {.column width="40%"}
![](figs/roitman23.png){fig-align="center" width="300"}
:::

::: {style="text-align: right"}
[Roitman and Shadlen 2002](https://doi.org/10.1523/JNEUROSCI.22-21-09475.2002)
:::

::::

## Experimental data


<!-- ![](figs/peixoto12.png){fig-align="center" width="600"} -->
![](figs/2025-12-03-16-58-17.png)

::: {style="text-align: right"}
<!-- [Peixoto et al. 2021](https://doi.org/10.1038/s41586-020-03181-9) -->
[Luo et al. 2025](https://doi.org/10.1038/s41586-025-09528-4)
:::


## Continuous spatial memory

:::: {.columns}

::: {.column}
![](figs/ODR.png)

:::

::: {.column .fragment}
![](figs/funahashi.png)

:::

::: {style="text-align: right"}
[Funahashi et al. 1989](https://doi.org/10.1152/jn.1989.61.2.331)
:::

::::

## Ring attractor model


:::: {.columns}

::: {.column width="35%"}
![](figs/ring.svg)

:::

::: {.column .fragment width="65%"}

\begin{eqnarray*}

\newcommand{\blue}[1]{\color{blue}{#1}}
\newcommand{\red}[1]{\color{red}{#1}}

\scriptstyle \tau_E \frac{d r^E_i(t)}{dt} &=& \scriptstyle -r^E_i(t) + F_E\left[ \sum_j \red{W^{EE}}_{ij} r^E_j(t) - \blue{W^{EI}} r^I(t) \right] \\

\scriptstyle \tau_I \frac{d r^I(t)}{dt} &=& \scriptstyle -r^I(t) + F_I\left[ \red{W^{IE}} \sum_j r^E_j(t) - \blue{W^{II}} r^I(t) \right] \\

\scriptstyle W^{EE}_{ij} &=& \scriptstyle J_E \cos(\theta_i - \theta_j)

\end{eqnarray*}


::: {.fragment}

![](figs/go-down.svg){fig-align="center" width="100"}

\begin{eqnarray*}

\newcommand{\blue}[1]{\color{blue}{#1}}
\newcommand{\red}[1]{\color{red}{#1}}

\scriptstyle \tau \frac{d r_i(t)}{dt} &=& \scriptstyle -r_i(t) + F\left[ \sum_j J_{ij} r_j(t) \right] \\

\scriptstyle J_{ij} &=& \scriptstyle \red{J_E} \cos(\theta_i - \theta_j) - \blue{\frac{\gamma W^{EI}W^{IE}}{1+\gamma W^{II}}}

\end{eqnarray*}

:::

:::

::::

## Ring attractor network

In [None]:
# Ring network: Nn neurons whose preferred angles x are evenly spaced on [-pi, pi).
Nn = 256
x= np.linspace(-np.pi, np.pi, Nn, endpoint=False)

cs = np.cos(x)
sn = np.sin(x)

# Connectivity J0 + J2*cos(x_i - x_j), built via cos(a - b) = cos a cos b + sin a sin b:
# J0 < 0 is uniform inhibition, J2 > 0 angle-tuned excitation.
# Note: this rebinds JE (a scalar in the decision-making cells) to an (Nn, Nn) matrix.
J0 = -3.2
J2 = 8.5
JE = J0 + J2* (np.outer(cs,cs) + np.outer(sn, sn))


# Euler step and 5 s of simulated time.
dt=0.01
T=5
time = np.arange(0,T,dt)
Nt = len(time)

# input-output function for all cells, as used previously (Brunel, Cereb Cortex 13:1151, 2003)
def fI(x):
    """Piecewise transfer function: 0 for x <= 0, x**2 on (0, 1), sqrt(4x - 3) for x >= 1.

    Continuous at x = 1 (both pieces equal 1).  Works element-wise on arrays.
    """
    # Boolean factors act as 0/1 masks; abs() keeps sqrt defined where its mask is 0.
    quadratic_part = x*x*(x > 0)*(x < 1)
    sqrt_part = np.sqrt(np.abs(4*x - 3))*(x >= 1)
    return quadratic_part + sqrt_part

# External input: a unit pulse between t = 1 s and t = 1.5 s.
input = np.zeros((Nt,))
input[int(1./dt):int(1.5/dt)] = 1

tau = 0.1  # rate time constant (s)

# Population-vector decoder: each neuron contributes the unit complex number at its
# preferred angle, repeated once per time step to match a (Nt, Nn) rate matrix.
vecs = np.tile(np.cos(x) + 1j*np.sin(x), (Nt, 1))

def decode(rate):
    """Return, for every time step, the angle (radians) of the rate-weighted population vector."""
    return np.angle((vecs*rate).sum(axis=1))
    
def get_traces(stimat):
    """Simulate the ring network with a cue centred at angle ``stimat`` (radians).

    The cue exp(4 cos(x - stimat)) is gated by the global pulse ``input``; each
    neuron receives recurrent input ``JE @ rate / Nn`` plus Gaussian noise of
    scale 0.3 at every Euler step.

    Returns
    -------
    ndarray, shape (Nt, Nn)
        Rate of every neuron at every time step, starting from zero.
    """
    cue = np.exp(4*np.cos(x - stimat))
    rate = np.zeros((Nt, Nn))
    for k in range(Nt - 1):
        recurrent = np.dot(JE, rate[k])/Nn
        noise = 0.3*np.random.randn(Nn)
        total = recurrent + noise + input[k]*cue
        rate[k + 1] = rate[k] + dt/tau*(-rate[k] + fI(total))
    return rate

# Build 20 animation frames, each a fresh simulation with a cue at a random
# integer angle in [-180, 180) degrees.  Every frame adds its 7 traces in the
# same order, since plotly matches frame traces to figure traces by position.

frames = []
stims = np.random.randint(-180, 180, 20)

for step, stimat in enumerate(stims):

    # Left column (2 wide): input pulse on top, rate heatmap below.
    # Right column (2 tall): rate profile at the final time step.
    fig = make_subplots(rows=2, cols=3,
        specs=[[{'rowspan': 1, 'colspan': 2}, None, {'rowspan': 2, 'colspan': 1}],
               [{'rowspan': 1, 'colspan': 2}, None, None]],
        row_heights=[0.3, 0.7],
        vertical_spacing=0.12,
        horizontal_spacing=0.09)

    rate = get_traces(np.deg2rad(stimat))
    # decoded bump position (radians) at every time step
    enc = decode(rate)

    # 1) external input pulse
    fig.add_trace(
        go.Scatter(
            visible=True,
            line=dict(color="#000000", width=4),
            name="",
            x = time,
            y = input,
            showlegend = False),
            row = 1, col = 1)

    # 2) rates of all neurons over time (neurons on the y axis, in degrees)
    fig.add_trace(go.Heatmap(z=rate.T,
            x=time,
            y=np.rad2deg(x),
            zmin=0,
            zmax=7,
            zauto=False,
            colorscale=colorscale,
            showscale=False), row=2, col=1)

    # 3) arrow marking the cue location at t = 1 s
    fig.add_trace(go.Scatter(mode="markers",
            x=[1.],
            y=[stimat],
            marker_symbol="arrow-right",
            marker_color="red",
            marker_size=15,
            showlegend = False), row=2, col=1)

    # 4) decoded bump position over time
    fig.add_trace(
        go.Scatter(
            visible=True,
            line=dict(color="magenta", width=2),
            name="",
            x = time,
            y = np.rad2deg(enc),
            showlegend = False),
            row = 2, col = 1)

    # 5) rate profile across the ring at the last time step
    fig.add_trace(
        go.Scatter(
            visible=True,
            line=dict(color="#000000", width=1),
            name="",
            x = np.rad2deg(x),
            y = rate[-1],
            showlegend = False),
            row = 1, col = 3)

    # 6) cue location (red) above the final profile
    fig.add_trace(
        go.Scatter(mode="markers",
            x=[stimat],
            y=[5.],
            marker_symbol="arrow-down",
            marker_color="red",
            marker_size=15,
            showlegend = False),
            row = 1, col = 3)

    # 7) final decoded position (magenta) above the final profile
    fig.add_trace(
        go.Scatter(mode="markers",
            x=[np.rad2deg(enc[-1])],
            y=[5.5],
            marker_symbol="arrow-down",
            marker_color="magenta",
            marker_size=15,
            showlegend = False),
            row = 1, col = 3)

    fig.update_xaxes(title_text="", range=[0, T], row=1, col=1)
    fig.update_xaxes(title_text="time (s)", range=[0, T], row=2, col=1)
    fig.update_yaxes(title_text="current", row=1, col=1)
    fig.update_yaxes(title_text="neurons (deg)", range=[-180, 180], row=2, col=1)
    fig.update_xaxes(title_text="neurons (deg)", range=[-180, 180], row=1, col=3)
    fig.update_yaxes(title_text="rate", row=1, col=3)

    # Snapshot this trial's traces and axes as one named animation frame.
    frames += [go.Frame(data=fig.data, layout=fig.layout, name=str(step))]

    ## store the first frame to reuse later
    if step == 0:
        first_fig = fig

# The figure that is shown carries all frames; its initial content is frame 0.
fig = go.Figure(frames=frames)

## add the first frame to the figure so it shows up initially
fig.add_traces(first_fig.data,)
fig.layout = first_fig.layout

## the rest is copied from the plotly documentation example on MRI volume slices
def frame_args(duration):
    """Animation options shared by the play button and the slider steps.

    Parameters
    ----------
    duration : int
        Display time of each frame (plotly interprets it in milliseconds).

    Returns
    -------
    dict
        Options for plotly's "animate" method: jump immediately, continue from
        the current frame, and use no transition easing time.
    """
    return dict(
        frame=dict(duration=duration),
        mode="immediate",
        fromcurrent=True,
        transition=dict(duration=0, easing="linear"),
    )

# Slider under the plot: one "animate" step per stored frame, labelled 0, 1, 2, ...
sliders = [
    dict(
        pad=dict(b=10, t=60),
        len=0.9,
        x=0.1,
        y=0,
        steps=[
            dict(args=[[frame.name], frame_args(2000)],
                 label=str(idx),
                 method="animate")
            for idx, frame in enumerate(fig.frames)
        ],
    )
]

fig.update_layout(
    title='',
    width=1000,
    height=500,
    autosize=False,
    margin=dict(t=0, b=0, l=0, r=0),
    hovermode=False,
    # Play / pause buttons placed to the left of the slider.
    updatemenus=[
        dict(
            buttons=[
                dict(args=[None, frame_args(2000)],
                     label="&#9654;",  # play symbol
                     method="animate"),
                dict(args=[[None], frame_args(2000)],
                     label="&#9724;",  # pause symbol
                     method="animate"),
            ],
            direction="left",
            pad=dict(r=10, t=70),
            type="buttons",
            x=0.1,
            y=0,
        )
    ],
    sliders=sliders,
)

# Hide plotly's mode bar.
config = {'displayModeBar': False}

fig.show(config=config)


## Ring attractor model

In [None]:
nsims=100
dt=0.01
time = np.arange(0,T,dt)
Nt = len(time)
x= np.linspace(-np.pi, np.pi, Nn, endpoint=False)

tau = 0.1
# Cue on for t in [1 s, 1.5 s), the same window as the `input` pulse of the
# animated ring example, derived from dt instead of hard-coded step indices.
stim_on, stim_off = int(1./dt), int(1.5/dt)

endpoints = np.zeros((nsims,))

# run the simulation multiple times, each with a cue at a random angle
for n in range(nsims):

    phase = 2*np.pi*np.random.rand()
    I = 2*np.exp(4*np.cos(x - phase))

    rate = np.zeros((Nt,Nn))

    for i in range(Nt-1):
        cue = I if stim_on <= i < stim_off else 0
        network_inputs = np.dot(JE, rate[i])/Nn
        noise = 0.3*np.random.randn(Nn)
        rate[i+1] = rate[i] + dt/tau * (-rate[i] + fI(network_inputs + cue + noise))

    # decoded bump position (radians) over time; keep its final value
    trace = decode(rate)

    plt.plot(time, np.rad2deg(trace))

    endpoints[n] = trace[-1]

plt.ylim([-180, 180])
plt.xlabel('time (s)')
plt.ylabel('neuron (deg)');

## Ring attractor model

In [None]:
nsims=100
dt=0.01
time = np.arange(0,T,dt)
Nt = len(time)

tau = 0.1
# Cue on for t in [1 s, 1.5 s), matching the `input` pulse of the animated ring
# example, derived from dt instead of hard-coded step indices.
stim_on, stim_off = int(1./dt), int(1.5/dt)
# The cue is always centred on 0 deg, so runs differ only through the noise.
I = 2*np.exp(4*np.cos(x))

endpoints = np.zeros((nsims,))

# run the simulation multiple times
for n in range(nsims):

    rate = np.zeros((Nt,Nn))

    for i in range(Nt-1):
        cue = I if stim_on <= i < stim_off else 0
        network_inputs = np.dot(JE, rate[i])/Nn
        noise = 0.3*np.random.randn(Nn)
        rate[i+1] = rate[i] + dt/tau * (-rate[i] + fI(network_inputs + cue + noise))

    # decoded bump position (radians) over time; keep its final value
    trace = decode(rate)

    plt.plot(time, np.rad2deg(trace))

    endpoints[n] = trace[-1]

plt.ylim([-20, 20])
plt.xlabel('time (s)')
plt.ylabel('neuron (deg)');

## Experimental evidence

:::: {.columns}

::: {.column width="40%"}
![](figs/wimmer.png){fig-align="center" width="350"}
::: 

::: {.column width="60%"}
![](figs/tschiersch.png){fig-align="center" width="600"}
::: 

::::

::: {style="text-align: right"}
[Wimmer et al. 2014](https://doi.org/10.1038/nn.3645)

[Tschiersch et al. 2025](https://doi.org/10.1101/2025.01.15.633176)
:::

## Network models reviewed

* E-I networks:
	- Wilson-Cowan network
	- Inhibition-stabilized network
* Double well networks:
	- selective working memory
	- decision making
* Ring attractor network

## As an RNN: $\scriptstyle \tau \frac{d I_i(t)}{dt} = -I_i(t) + \sum_j J_{ij} F[I_j(t)]$

:::: {.columns}

::: {.column .fragment}

$$
\newcommand{\blue}[1]{\color{blue}{#1}}
\newcommand{\red}[1]{\color{red}{#1}}
\scriptstyle J_{ij} = \scriptstyle \red{J_E} \cos(\theta_i - \theta_j) - \blue{\frac{\gamma W^{EI}W^{IE}}{1+\gamma W^{II}}}
$$

::: {width=200}

In [None]:
# J_ij is the weight from neuron j onto neuron i (the dynamics compute
# np.dot(JE, rate)), so imshow's columns (x axis) are the sending neuron and
# its rows (y axis) the receiving neuron.
plt.imshow(JE, origin="lower")
plt.xlabel('from neuron')
plt.ylabel('to neuron')
plt.colorbar(label="connection strength");

![](figs/ring.svg){.absolute top=90 left=-100 width="150"}

:::

:::

::: {.column .fragment}

$$
\newcommand{\blue}[1]{\color{blue}{#1}}
\newcommand{\red}[1]{\color{red}{#1}}
\scriptstyle J_{ij} = \scriptstyle \left\{
\begin{array}{ll}
      \scriptstyle \red{J_E}, & \scriptstyle C(i)=C(j) \\
      \scriptstyle \blue{J_I}, & \scriptstyle C(i) \neq C(j) \\
\end{array} 
\; \; \; \; \; \right.
$$

::: {width=200}

In [None]:
import scipy.linalg as linalg

# Two clusters of Nn//2 neurons: weight 1.1 - 1.8 within a cluster, -1.8 across.
a = 1.1*np.ones((Nn//2, Nn//2))
Jdw = -1.8*np.ones((Nn, Nn)) + linalg.block_diag(a,a)

# J_ij is the weight from neuron j onto neuron i, so columns (x) are the sending
# neuron and rows (y) the receiving neuron.
plt.imshow(Jdw, origin="lower")
plt.xlabel('from neuron')
plt.ylabel('to neuron')
plt.colorbar(label="connection strength");

:::

![](figs/doublewell3.svg){.absolute top=125 right=-70 width="200"}

:::

::::

## Rank-1 RNN

:::: {.columns}

::: {.column}

$\scriptstyle J_{ij} = \xi_i \xi_j$

In [None]:
# Random rank-1 connectivity J_ij = xi_i * xi_j.
csi = np.random.randn(Nn)
J1 = np.outer(csi,csi)

# J_ij is the weight from neuron j onto neuron i, so columns (x) are the sending
# neuron and rows (y) the receiving neuron.
plt.imshow(J1, origin="lower")
plt.xlabel('from neuron')
plt.ylabel('to neuron')
plt.colorbar(label="connection strength");

:::


::: {.column .fragment}
Now sort the neurons by their value of $\xi_i$:


In [None]:
# Sorting the neurons by xi_i reveals the structure of the same rank-1 matrix.
csi = np.sort(csi)
J1 = np.outer(csi,csi)

# J_ij is the weight from neuron j onto neuron i, so columns (x) are the sending
# neuron and rows (y) the receiving neuron.
plt.imshow(J1, origin="lower")
plt.xlabel('from neuron')
plt.ylabel('to neuron')
plt.colorbar(label="connection strength");

:::

::::


## Dynamics in rank-1 RNN

In [None]:
# Rank-1 structure from a random loading vector, with neurons sorted by loading.
R1 = np.random.randn(Nn)
inds = np.argsort(R1)
R1n = R1[inds]
W = np.outer(R1n,R1n)

# Connectivity: uniform term J0 plus the rank-1 structure scaled by J2.
# (Rebinds J0/J2 used by the ring-network cells.)
J0 = -5.2
J2 = 5
J = J0 + J2* W

# Plotting coordinate for the sorted neurons: their rank mapped onto [-1, 1].
x = np.linspace(-1, 1, Nn)

# Euler step and 3 s of simulated time.
dt=0.01
T=3
time = np.arange(0,T,dt)
Nt = len(time)

tau = 0.1

# External input: a unit pulse between t = 1 s and t = 1.5 s.
input = np.zeros((Nt,))
input[int(1./dt):int(1.5/dt)] = 1
    
def get_traces(stim):
    """Simulate the rank-1 network after a cue selecting one sign of the loading.

    Neurons whose loading ``R1n`` has the same sign as ``stim`` receive a cue of
    strength 3, gated by the global pulse ``input``.  Recurrent input is
    ``J @ rate / Nn`` and Gaussian noise of scale 0.5 is added at every Euler step.

    Returns
    -------
    ndarray, shape (Nt, Nn)
        Rate of every neuron at every time step, starting from zero.
    """
    cue = np.where(R1n*stim > 0, 3.0, 0.0)
    rate = np.zeros((Nt, Nn))
    for k in range(Nt - 1):
        recurrent = np.dot(J, rate[k])/Nn
        noise = 0.5*np.random.randn(Nn)
        rate[k + 1] = rate[k] + dt/tau*(-rate[k] + fI(recurrent + noise + input[k]*cue))
    return rate

# Build 20 animation frames, each a fresh simulation with a cue of random sign
# (+1 or -1).  Every frame adds its 4 traces in the same order, since plotly
# matches frame traces to figure traces by position.

frames = []
stims = np.random.choice([-1, 1], 20)

for step, stim in enumerate(stims):

    # Left column (2 wide): input pulse on top, rate heatmap below.
    # Right column (2 tall): rate profile at the final time step.
    fig = make_subplots(rows=2, cols=3,
        specs=[[{'rowspan': 1, 'colspan': 2}, None, {'rowspan': 2, 'colspan': 1}],
               [{'rowspan': 1, 'colspan': 2}, None, None]],
        row_heights=[0.3, 0.7],
        vertical_spacing=0.12,
        horizontal_spacing=0.09)

    rate = get_traces(stim)

    # 1) external input pulse
    fig.add_trace(
        go.Scatter(
            visible=True,
            line=dict(color="#000000", width=4),
            name="",
            x = time,
            y = input,
            showlegend = False),
            row = 1, col = 1)

    # 2) rates of all (sorted) neurons over time
    fig.add_trace(go.Heatmap(z=rate.T,
            x=time,
            y=x,
            zmin=0,
            zmax=10,
            zauto=False,
            colorscale=colorscale,
            showscale=False), row=2, col=1)

    # 3) rate profile at the last time step
    fig.add_trace(
        go.Scatter(
            visible=True,
            line=dict(color="#000000", width=1),
            name="",
            x = x,
            y = rate[-1],
            showlegend = False),
            row = 1, col = 3)

    # 4) bar from 0 to the cue sign, indicating which half was stimulated
    fig.add_trace(
        go.Scatter(
            x=[stim, 0],
            y=[8.5, 8.5],
            line=dict(color="#550000", width=4),
            showlegend = False),
            row = 1, col = 3)

    fig.update_xaxes(title_text="", range=[0, T], row=1, col=1)
    fig.update_xaxes(title_text="time (s)", range=[0, T], row=2, col=1)
    fig.update_yaxes(title_text="current", row=1, col=1)
    fig.update_yaxes(title_text="neurons", range=[-1, 1], row=2, col=1)
    fig.update_xaxes(title_text="neurons", range=[-1, 1], row=1, col=3)
    fig.update_yaxes(title_text="rate", range=[0, 10], row=1, col=3)

    # Snapshot this trial's traces and axes as one named animation frame.
    frames += [go.Frame(data=fig.data, layout=fig.layout, name=str(step))]

    ## store the first frame to reuse later
    if step == 0:
        first_fig = fig

# The figure that is shown carries all frames; its initial content is frame 0.
fig = go.Figure(frames=frames)

## add the first frame to the figure so it shows up initially
fig.add_traces(first_fig.data,)
fig.layout = first_fig.layout

## the rest is copied from the plotly documentation example on MRI volume slices
def frame_args(duration):
    """Animation options shared by the play button and the slider steps.

    Parameters
    ----------
    duration : int
        Display time of each frame (plotly interprets it in milliseconds).

    Returns
    -------
    dict
        Options for plotly's "animate" method: jump immediately, continue from
        the current frame, and use no transition easing time.
    """
    return dict(
        frame=dict(duration=duration),
        mode="immediate",
        fromcurrent=True,
        transition=dict(duration=0, easing="linear"),
    )

# Slider under the plot: one "animate" step per stored frame, labelled 0, 1, 2, ...
sliders = [
    dict(
        pad=dict(b=10, t=60),
        len=0.9,
        x=0.1,
        y=0,
        steps=[
            dict(args=[[frame.name], frame_args(2000)],
                 label=str(idx),
                 method="animate")
            for idx, frame in enumerate(fig.frames)
        ],
    )
]

fig.update_layout(
    title='',
    width=1000,
    height=500,
    autosize=False,
    margin=dict(t=0, b=0, l=0, r=0),
    hovermode=False,
    # Play / pause buttons placed to the left of the slider.
    updatemenus=[
        dict(
            buttons=[
                dict(args=[None, frame_args(2000)],
                     label="&#9654;",  # play symbol
                     method="animate"),
                dict(args=[[None], frame_args(2000)],
                     label="&#9724;",  # pause symbol
                     method="animate"),
            ],
            direction="left",
            pad=dict(r=10, t=70),
            type="buttons",
            x=0.1,
            y=0,
        )
    ],
    sliders=sliders,
)

# Hide plotly's mode bar.
config = {'displayModeBar': False}

fig.show(config=config)


## Rank-2 RNN

:::: {.columns}

::: {.column}

$\scriptstyle J_{ij} = \xi_i \xi_j + \psi_i \psi_j$

In [None]:
# Random rank-2 connectivity J_ij = xi_i xi_j + psi_i psi_j.
csi = np.random.randn(Nn)
psi = np.random.randn(Nn)
J2 = np.outer(csi,csi) + np.outer(psi,psi)

# J_ij is the weight from neuron j onto neuron i, so columns (x) are the sending
# neuron and rows (y) the receiving neuron.
plt.imshow(J2, origin="lower")
plt.xlabel('from neuron')
plt.ylabel('to neuron')
plt.colorbar(label="connection strength");

:::


::: {.column .fragment}
Now sort the neurons by the phase (complex angle) of $\xi_i + i \psi_i$:


In [None]:
# With z_i = xi_i + 1j*psi_i, W_ij = Re(z_i * conj(z_j)) = |z_i||z_j| cos(ang_i - ang_j),
# so sorting neurons by the phase ang reveals a ring-like (cosine) structure.
ang = np.angle(csi+1j*psi)
inds = np.argsort(ang)
W = np.outer(csi[inds],csi[inds]) + np.outer(psi[inds],psi[inds])

# J_ij is the weight from neuron j onto neuron i, so columns (x) are the sending
# neuron and rows (y) the receiving neuron.
plt.imshow(W, origin="lower")
plt.xlabel('from neuron')
plt.ylabel('to neuron')
plt.colorbar(label="connection strength");

:::

::::


## Dynamics in rank-2 RNN

In [None]:
# Connectivity: uniform term J0 plus the phase-sorted rank-2 structure W scaled by J2.
J0 = -5.2
#J2 = 11.119
J2 = 11.11
J = J0 + J2* W

# Angle coordinate for the sorted neurons, evenly spaced on [-pi, pi); used for
# the cue and for plotting.  The decoder instead uses each neuron's phase ang[inds].
x = np.linspace(-np.pi, np.pi, Nn, endpoint=False)

# Euler step and 3 s of simulated time.
dt=0.01
T=3
time = np.arange(0,T,dt)
Nt = len(time)

tau = 0.1

# External input: a unit pulse between t = 1 s and t = 1.5 s.
input = np.zeros((Nt,))
input[int(1./dt):int(1.5/dt)] = 1


# Population-vector decoder: each neuron contributes the unit complex number at its
# phase ang[inds], repeated once per time step to match a (Nt, Nn) rate matrix.
vecs = np.tile(np.cos(ang[inds]) + 1j*np.sin(ang[inds]), (Nt, 1))

def decode(rate):
    """Return, for every time step, the angle (radians) of the rate-weighted population vector."""
    return np.angle((vecs*rate).sum(axis=1))
    
def get_traces(stimat):
    """Simulate the rank-2 network with a cue centred at angle ``stimat`` (radians).

    The cue 2*exp(4 cos(x - stimat)) uses the evenly spaced coordinate ``x`` as an
    approximation of each sorted neuron's phase, and is gated by the global pulse
    ``input``.  Recurrent input is ``J @ rate / Nn``; Gaussian noise of scale 0.13
    is added at every Euler step.

    Returns
    -------
    ndarray, shape (Nt, Nn)
        Rate of every neuron at every time step, starting from zero.
    """
    cue = 2*np.exp(4*np.cos(x - stimat))
    rate = np.zeros((Nt, Nn))
    for k in range(Nt - 1):
        recurrent = np.dot(J, rate[k])/Nn
        noise = 0.13*np.random.randn(Nn)
        rate[k + 1] = rate[k] + dt/tau*(-rate[k] + fI(recurrent + noise + input[k]*cue))
    return rate

# Build 20 animation frames, each a fresh simulation with a cue at a random
# integer angle in [-180, 180) degrees.  Every frame adds its 7 traces in the
# same order, since plotly matches frame traces to figure traces by position.

frames = []
stims = np.random.randint(-180, 180, 20)

for step, stimat in enumerate(stims):

    # Left column (2 wide): input pulse on top, rate heatmap below.
    # Right column (2 tall): rate profile at the final time step.
    fig = make_subplots(rows=2, cols=3,
        specs=[[{'rowspan': 1, 'colspan': 2}, None, {'rowspan': 2, 'colspan': 1}],
               [{'rowspan': 1, 'colspan': 2}, None, None]],
        row_heights=[0.3, 0.7],
        vertical_spacing=0.12,
        horizontal_spacing=0.09)

    rate = get_traces(np.deg2rad(stimat))
    # decoded bump position (radians) at every time step
    enc = decode(rate)

    # 1) external input pulse
    fig.add_trace(
        go.Scatter(
            visible=True,
            line=dict(color="#000000", width=4),
            name="",
            x = time,
            y = input,
            showlegend = False),
            row = 1, col = 1)

    # 2) rates of all neurons over time (colour range chosen per frame by plotly)
    fig.add_trace(go.Heatmap(z=rate.T,
            x=time,
            y=np.rad2deg(x),
            colorscale=colorscale,
            showscale=False), row=2, col=1)

    # 3) arrow marking the cue location at t = 1 s
    fig.add_trace(go.Scatter(mode="markers",
            x=[1.],
            y=[stimat],
            marker_symbol="arrow-right",
            marker_color="red",
            marker_size=15,
            showlegend = False), row=2, col=1)

    # 4) decoded bump position over time
    fig.add_trace(
        go.Scatter(
            visible=True,
            line=dict(color="magenta", width=2),
            name="",
            x = time,
            y = np.rad2deg(enc),
            showlegend = False),
            row = 2, col = 1)

    # 5) rate profile across the neurons at the last time step
    fig.add_trace(
        go.Scatter(
            visible=True,
            line=dict(color="#000000", width=1),
            name="",
            x = np.rad2deg(x),
            y = rate[-1],
            showlegend = False),
            row = 1, col = 3)

    # 6) cue location (red) above the final profile
    fig.add_trace(
        go.Scatter(mode="markers",
            x=[stimat],
            y=[30.],
            marker_symbol="arrow-down",
            marker_color="red",
            marker_size=15,
            showlegend = False),
            row = 1, col = 3)

    # 7) final decoded position (magenta) above the final profile
    fig.add_trace(
        go.Scatter(mode="markers",
            x=[np.rad2deg(enc[-1])],
            y=[33.],
            marker_symbol="arrow-down",
            marker_color="magenta",
            marker_size=15,
            showlegend = False),
            row = 1, col = 3)

    fig.update_xaxes(title_text="", range=[0, T], row=1, col=1)
    fig.update_xaxes(title_text="time (s)", range=[0, T], row=2, col=1)
    fig.update_yaxes(title_text="current", row=1, col=1)
    fig.update_yaxes(title_text="neurons (deg)", range=[-180, 180], row=2, col=1)
    fig.update_xaxes(title_text="neurons (deg)", range=[-180, 180], row=1, col=3)
    fig.update_yaxes(title_text="rate", row=1, col=3)

    # Snapshot this trial's traces and axes as one named animation frame.
    frames += [go.Frame(data=fig.data, layout=fig.layout, name=str(step))]

    ## store the first frame to reuse later
    if step == 0:
        first_fig = fig

# The figure that is shown carries all frames; its initial content is frame 0.
fig = go.Figure(frames=frames)

## add the first frame to the figure so it shows up initially
fig.add_traces(first_fig.data,)
fig.layout = first_fig.layout

## the rest is copied from the plotly documentation example on MRI volume slices
def frame_args(duration):
    """Animation options shared by the play button and the slider steps.

    Parameters
    ----------
    duration : int
        Display time of each frame (plotly interprets it in milliseconds).

    Returns
    -------
    dict
        Options for plotly's "animate" method: jump immediately, continue from
        the current frame, and use no transition easing time.
    """
    return dict(
        frame=dict(duration=duration),
        mode="immediate",
        fromcurrent=True,
        transition=dict(duration=0, easing="linear"),
    )

sliders = [
            {
                "pad": {"b": 10, "t": 60},
                "len": 0.9,
                "x": 0.1,
                "y": 0,
                "steps": [
                    {
                        "args": [[f.name], frame_args(2000)],
                        "label": str(k),
                        "method": "animate",
                    }
                    for k, f in enumerate(fig.frames)
                ],
            }
        ]

fig.update_layout(
         title='',
         width=1000,
         height=500,
         autosize=False,
         margin=dict(t=0, b=0, l=0, r=0),
         hovermode=False,
         updatemenus = [
            {
                "buttons": [
                    {
                        "args": [None, frame_args(2000)],
                        "label": "&#9654;", # play symbol
                        "method": "animate",
                    },
                    {
                        "args": [[None], frame_args(2000)],
                        "label": "&#9724;", # pause symbol
                        "method": "animate",
                    },
                ],
                "direction": "left",
                "pad": {"r": 10, "t": 70},
                "type": "buttons",
                "x": 0.1,
                "y": 0,
            }
         ],
         sliders=sliders
)


config = {'displayModeBar': False}

fig.show(config=config)


## Dynamics in rank-2 RNN

Bumps converge onto a few stable fixed points instead of persisting at any angle: these are not true ring-attractor dynamics

In [None]:
# Run the rank-2 network `nsims` times. Each run gets a stimulus bump at a random
# phase for the first 0.5 s, then a delay period. The decoded angle of each run is
# plotted and its final value is stored in `endpoints`.
# Relies on names from earlier cells: `ang`, `inds`, `J`, `Nn`, `fI`, `decode`.
nsims = 100
dt = 0.01                 # integration step (s)
T = 3                     # trial length (s)
time = np.arange(0, T, dt)
Nt = len(time)

tau = 0.1                 # rate time constant (s); same for every trial, so set once
stim_steps = 0.5 / dt     # stimulus is on while step index i < stim_steps (first 0.5 s)
noise_std = 0.13          # std of the per-step Gaussian input noise

endpoints = np.zeros((nsims,))

# run the simulation multiple times
for n in range(nsims):

    # Stimulus bump centred at a random phase; `ang[inds]` is presumably the
    # neurons' preferred angles in radians — confirm against the earlier cell.
    phase = 2 * np.pi * np.random.rand()
    I = 2 * np.exp(4 * np.cos(ang[inds] + phase))

    rate = np.zeros((Nt, Nn))

    for i in range(Nt - 1):
        # Named `stim_input`, not `input`: assigning `input` at notebook top level
        # shadows the builtin for every later cell.
        stim_input = I if i < stim_steps else 0
        network_inputs = np.dot(J, rate[i]) / Nn
        # randn(Nn) draws the same values as the former randn(Nn, 1).flatten()
        noise = noise_std * np.random.randn(Nn)
        rate[i + 1] = rate[i] + dt / tau * (-rate[i] + fI(network_inputs + stim_input + noise))

    trace = decode(rate)
    plt.plot(time, trace)

    endpoints[n] = trace[-1]

# Axis labels only need to be set once, after all traces are drawn.
plt.xlabel("time since stimulus (s)")
plt.ylabel("decoded angle (deg)");

## Are ring attractors biologically plausible?

::: {.incremental}
* cellular homeostatic mechanisms: [Renart et al 2003](https://doi.org/10.1016/s0896-6273(03)00255-1)
* slow synapses: [Itskov et al 2011](https://doi.org/10.3389/fncom.2011.00040); [Hansel and Mato 2013](https://doi.org/10.1523/JNEUROSCI.3455-12.2013)
* connectivity training: [Darshan and Rivkind 2022](https://doi.org/10.1016/j.celrep.2022.110612); [Clark et al 2025](https://doi.org/10.1101/2025.01.26.634933)

* efficient coding? [Ganguli and Simoncelli 2014](https://doi.org/10.1162/NECO_a_00638); [Yang et al 2024](https://doi.org/10.7554/eLife.95160)

:::

# End of part 2
