# Neural Field Equation: 
Consider the ring model:

$$\tau \frac{\partial r(\theta, t)}{\partial t} = -r(\theta, t) + \phi\left(\frac{1}{2\pi} \int_{-\pi}^{\pi} W(\theta - \theta') r(\theta', t) d\theta' + I_0\right)$$

with the kernel $W(\Delta \theta) = W_0 + W_1 \cos(\Delta \theta)$. Note that the kernel is symmetric (even), $W(\Delta \theta) = W(-\Delta \theta)$, because $\cos(\Delta \theta) = \cos(-\Delta \theta)$.

## Steady-State Solutions

### Homogeneous Steady States
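For a spatially homogeneous state $r(\theta, t) = r_0$, the cosine part of the kernel integrates to zero over the ring, $\frac{1}{2\pi}\int_{-\pi}^{\pi} \cos(\theta - \theta')\, d\theta' = 0$. The steady-state firing rate $r_0$ is therefore determined by the equation:
$$ r_0 = \phi(r_0 W_0 + I_0), $$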
where the nonlinearity function $\phi(x)$ is defined by the piecewise function:

$$
\phi(x) = 
\begin{cases} 
x^2 & \text{for } 0 \leq x \leq 1, \\
2\sqrt{x - \frac{3}{4}} & \text{for } x > 1, \\
0 & \text{otherwise}.
\end{cases}
$$

### Solutions Based on Input and Coupling Conditions
Depending on the value of the input $I_0$ and the coupling $W_0$, the solutions are:

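#### For $0 \leq r_0 W_0 + I_0 \leq 1$:
Here $r_0 = (W_0 r_0 + I_0)^2$. This is the quadratic $W_0^2 r_0^2 + (2 W_0 I_0 - 1) r_0 + I_0^2 = 0$, assuming $W_0 \neq 0$; for $W_0 = 0$ the fixed point is simply $r_0 = \phi(I_0)$. Its roots are real only when $1 - 4 W_0 I_0 \geq 0$. Because squaring can introduce spurious roots, each candidate must be checked against the original fixed-point equation (the code does this with `np.isclose`).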
- **r_01 (in the code)**: 
$$
r_0 = \frac{1 - 2W_0 I_0 + \sqrt{1 - 4W_0 I_0}}{2 W_0^2}
$$

- **r_02 (in the code)**: 
$$
r_0 = \frac{1 - 2W_0 I_0 - \sqrt{1 - 4W_0 I_0}}{2 W_0^2}
$$

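#### For $r_0 W_0 + I_0 > 1$:
Here $r_0 = 2\sqrt{W_0 r_0 + I_0 - \frac{3}{4}}$. Squaring gives $r_0^2 - 4 W_0 r_0 - 4\left(I_0 - \frac{3}{4}\right) = 0$, with roots $r_0 = 2W_0 \pm 2\sqrt{W_0^2 + I_0 - \frac{3}{4}}$. On this branch $W_0 r_0 + I_0 = \frac{r_0^2}{4} + \frac{3}{4}$, so a real root is a valid fixed point exactly when $r_0 > 1$.
- **r_03 (in the code)**: 
$$
r_0 = 2 W_0 + 2 \sqrt{W_0^2 + I_0 - \frac{3}{4}}
$$
- The other root, $r_0 = 2 W_0 - 2 \sqrt{W_0^2 + I_0 - \frac{3}{4}}$, is not implemented in the code. It is nevertheless a valid fixed point whenever it is real and exceeds 1; for example, $W_0 = 0.85$, $I_0 = 0.1$ gives $r_0 \approx 1.16$.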

The following code visualizes the fixed-point solutions of the nonlinear equation $r_0 = \phi(W_0 r_0 + I_0)$ (the coupling $W_0$ is written `w0` in the code):

In [26]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from common_utils import fixed_point_solver, nonlinearity, derivative_nonlinearity, r_01, r_02, r_03

def update_plots(w0, I0):
    """Redraw the graphical fixed-point construction for the given w0 and I0.

    Plots phi(w0 * r0 + I0) and the identity line r0 against r0 on
    [-2.5, 15]. It shades where the total input lies in [0, 1] (quadratic
    branch of phi) and where it exceeds 1 (square-root branch). Each
    analytic candidate r_01, r_02, r_03 is marked only if it satisfies
    r0 = phi(w0 * r0 + I0) to within an absolute tolerance of 1e-4.

    Parameters
    ----------
    w0 : float
        Homogeneous coupling strength (W_0 in the text).
    I0 : float
        External input.
    """
    plt.clf()

    # Evaluate every analytic candidate up front, before anything is drawn.
    candidates = [
        (r_01(w0, I0), '01', 'red'),
        (r_02(w0, I0), '02', 'orange'),
        (r_03(w0, I0), '03', 'purple'),
    ]

    r0_values = np.linspace(-2.5, 15, 1000)
    total_input = w0 * r0_values + I0
    phi_values = nonlinearity(total_input)

    # Raw strings: the mathtext labels contain backslashes (\leq, \cdot, \phi)
    # that are invalid escape sequences in ordinary Python string literals
    # (SyntaxWarning since Python 3.12). The resulting text is unchanged.
    plt.fill_between(r0_values, -2.5, 15,
                     where=(total_input >= 0) & (total_input <= 1),
                     color='lightblue', alpha=0.3,
                     label=r'$0 \leq (w_0 \cdot r_0 + I_0) \leq 1$')
    plt.fill_between(r0_values, -2.5, 15, where=total_input > 1,
                     color='lightgray', alpha=0.3,
                     label=r'$(w_0 \cdot r_0 + I_0) > 1$')

    plt.plot(r0_values, phi_values, label=r'$\phi(w_0 \cdot r_0 + I_0)$', color='blue')
    plt.plot(r0_values, r0_values, label='$r_0$', linestyle='--', color='red')

    tolerance = 1e-4
    for r_val, tag, color in candidates:
        phi_val = nonlinearity(w0 * r_val + I0)
        # Mark only genuine solutions of r0 = phi(w0 r0 + I0); a NaN candidate
        # fails np.isclose and is skipped.
        if np.isclose(r_val, phi_val, atol=tolerance):
            plt.scatter([r_val], [phi_val], color=color,
                        label=f'$r_{{{tag}}}$: {r_val:.2f}', zorder=5)

    plt.gca().set_facecolor('white')
    plt.gcf().set_facecolor('white')
    plt.xlabel('$r_0$', color='black')
    plt.ylabel('Value', color='black')
    plt.tick_params(axis='x', colors='black')
    plt.tick_params(axis='y', colors='black')
    plt.legend()
    plt.grid(color='gray')

    plt.show()

# Slider for the homogeneous coupling w0, in steps of 0.1 over [-3, 3.5].
w0_slider = widgets.FloatSlider(
    value=-1,
    min=-3,
    max=3.5,
    step=0.1,
    description='$w_0$:',
    continuous_update=False,
)
# Slider for the external input I0, in steps of 0.05 over [0, 2].
I0_slider = widgets.FloatSlider(
    value=0.5,
    min=0,
    max=2,
    step=0.05,
    description='$I_0$:',
    continuous_update=False,
)

# Wire both sliders to update_plots. With continuous_update=False the figure is
# redrawn only when a slider is released. This stays the cell's last expression
# so the notebook renders the widget.
widgets.interactive(update_plots, w0=w0_slider, I0=I0_slider)

interactive(children=(FloatSlider(value=-1.0, continuous_update=False, description='$w_0$:', max=3.5, min=-3.0…

In [69]:
# TO DO: plot critical w0, fix r_03 solution

%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from scipy.integrate import solve_ivp
from scipy.optimize import fsolve
from common_utils import dr_dt, r_01, r_02, r_03, nonlinearity, derivative_nonlinearity

def update_plots(I0):
    """Draw the bifurcation diagram of the homogeneous fixed points versus w0.

    For the given input ``I0``:

    - Each analytic branch (r_01, r_02, r_03) is evaluated on 1000 w0 values
      in [-3, 5] and plotted only where it satisfies r0 = phi(w0 * r0 + I0).
    - Up to 10 points per branch are re-checked by integrating ``dr_dt``
      from the analytic value.
    - For each branch, the critical coupling where
      w0 * phi'(w0 * r0 + I0) = 1 is marked when it can be found.

    Parameters
    ----------
    I0 : float
        External input.
    """
    N = 1000
    w0_values = np.linspace(-3, 5, N)
    plt.clf()

    def apply_mask(w0, I_0, r_func, tolerance=1e-1):
        """Return (w0, r_func(w0, I_0)) with NaN wherever r0 != phi(w0 * r0 + I_0).

        ``tolerance`` is the absolute tolerance passed to ``np.isclose``.
        """
        r_values = r_func(w0, I_0)
        mask = np.isclose(r_values, nonlinearity(w0 * r_values + I_0), atol=tolerance)

        r_values_filtered = np.where(mask, r_values, np.nan)
        w0_filtered = np.where(mask, w0, np.nan)

        return w0_filtered, r_values_filtered

    def select_and_solve(w0, r_filtered, I_0, func):
        """Simulate from the analytic fixed point at up to 10 evenly spaced valid w0.

        Returns the selected w0 values and the simulated state at t = 100.
        """
        non_nan_indices = ~np.isnan(r_filtered)
        w0_non_nan = w0[non_nan_indices]

        if len(w0_non_nan) == 0:
            return np.array([]), np.array([])

        indices_selected = np.linspace(0, len(w0_non_nan) - 1, min(len(w0_non_nan), 10), dtype=int)
        w0_selected = w0_non_nan[indices_selected]
        r_num = []
        t_span = [0, 100]
        t_eval = np.linspace(t_span[0], t_span[1], 1000)

        for w0_val in w0_selected:
            # Starts exactly on the analytic fixed point. To probe stability,
            # perturb it (e.g. add 0.01) so an unstable branch can drift away.
            r0 = [func(w0_val, I_0)]
            sol = solve_ivp(dr_dt, t_span, r0, args=(w0_val, I_0), t_eval=t_eval, method='RK45')
            r_num.append(sol.y[0, -1])

        return w0_selected, r_num

    def find_critical_w0(r_func, I_0, initial_guess=0.5):
        """Solve w0 * phi'(w0 * r_func(w0, I_0) + I_0) = 1 for w0.

        Returns None when fsolve reports non-convergence, the root is not
        finite, or the branch is not a valid fixed point at the root. This
        prevents spurious lines at the initial guess.
        """
        def equation(w0):
            return w0 * derivative_nonlinearity(w0 * r_func(w0, I_0) + I_0) - 1

        # full_output exposes the convergence flag (ier == 1 means converged)
        # instead of returning the last iterate silently.
        root, _info, ier, _msg = fsolve(equation, initial_guess, full_output=True)
        critical_w0 = root[0]
        if ier != 1 or not np.isfinite(critical_w0):
            return None
        r_crit = r_func(critical_w0, I_0)
        if not np.isclose(r_crit, nonlinearity(critical_w0 * r_crit + I_0), atol=1e-6):
            return None
        return critical_w0

    for func, label, linestyle in [(r_01, 'r_{01}', '--'), (r_02, 'r_{02}', '-'), (r_03, 'r_{03}', '-')]:
        w0_filtered, r_filtered = apply_mask(w0_values, I0, func)
        # np.where preserves the full length N, so test for surviving (non-NaN)
        # points. A len() check is always true and would add empty legend entries.
        if np.any(~np.isnan(r_filtered)):
            plt.plot(w0_filtered, r_filtered, label=f'${label}$', linestyle=linestyle)

            w0_selected, r_num = select_and_solve(w0_filtered, r_filtered, I0, func)
            if len(w0_selected) > 0:
                plt.plot(w0_selected, r_num, '.', label=f'num ${label}$', alpha=0.75, markersize=5)

    # The r_01/r_02 fold at w0 = 1/(4 I0) has total input w0*r0 + I0 = 2*I0, so it
    # lies on the quadratic branch (input <= 1) only for I0 < 0.5.
    if I0 != 0 and I0 < 0.5:
        plt.axvline(x=1/(4*I0), color='red', linestyle='--', label='$\\frac{1}{4I_0} = ' + f'{1/(4*I0):.2f}$' + ' Vertical Line')

    for r_func, tag, color in [(r_01, '01', 'purple'), (r_02, '02', 'green'), (r_03, '03', 'blue')]:
        critical = find_critical_w0(r_func, I0)
        if critical is None:
            print(f'Critical w0 for r{tag} not found')
            continue
        plt.axvline(x=critical, color=color, linestyle='--',
                    label=f'critical $w_0$ ($r_{{{tag}}}$) = {critical:.2f}')
        print('Critical w0 for r' + tag + ' = ' + f'{critical:.2f}')

    plt.xlabel('$w_0$')
    plt.ylabel('$r_0$')
    plt.title(f'Bifurcation Diagram at $I_0 = {I0:.2f}$')
    plt.legend()
    plt.grid(True)
    plt.show()


# Input slider for the bifurcation diagram: I0 in [0, 1.5], starting at 1/8.
slider = widgets.FloatSlider(
    value=1/8,
    min=0,
    max=1.5,
    step=0.05,
    description='$I_0$:',
    continuous_update=False,
)
# Redraw the diagram when the slider is released. This stays the cell's last
# expression so the notebook renders the widget.
widgets.interactive(update_plots, I0=slider)


interactive(children=(FloatSlider(value=0.125, continuous_update=False, description='$I_0$:', max=1.5, step=0.…

There is a trade-off in the `tolerance` parameter used in the mask. The smaller it is, the "smoother" the graph looks, but the worse the numerical solutions for the unstable solution r_01 become. I find the best balance at 1e-1; with this value the plot of the theoretical solutions looks smooth.