# Neural Field Equation
Consider the ring model:

$$\tau \frac{\partial r(\theta, t)}{\partial t} = -r(\theta, t) + \phi\left(\frac{1}{2\pi} \int_{-\pi}^{\pi} W(\theta - \theta') r(\theta', t) d\theta' + I_0\right)$$

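with the kernel $W(\Delta \theta) = W_0 + W_1 \cos(\Delta \theta)$. The kernel is symmetric, $W(\Delta \theta) = W(-\Delta \theta)$, because cosine is an even function. For a spatially uniform rate $r(\theta, t) = r_0$, the cosine term integrates to zero over $[-\pi, \pi]$, so the recurrent input reduces to $W_0 r_0$.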

## Steady-State Solutions
- Homogeneous steady states (stationary solutions) are given by: $$r_0 = \phi(r_0 W_0 + I_0).$$

- Using the nonlinearity function $\phi(x)$ defined as:

$$
\phi(x) = 
\begin{cases} 
x^2 & \text{if } 0 \leq x \leq 1, \\
2\sqrt{x - \frac{3}{4}} & \text{if } x > 1, \\
0 & \text{otherwise}.
\end{cases}
$$

- Depending on which branch of $\phi$ the argument $r_0 W_0 + I_0$ falls in, the candidate steady-state firing rates $r_0$ are:

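#### When $0 \leq r_0 W_0 + I_0 \leq 1$:
Here $r_0 = (r_0 W_0 + I_0)^2$, i.e. $W_0^2 r_0^2 + (2 W_0 I_0 - 1) r_0 + I_0^2 = 0$ (for $W_0 \neq 0$). Its roots are real when $4 W_0 I_0 \leq 1$. A root is a steady state only if it actually satisfies $0 \leq r_0 W_0 + I_0 \leq 1$.

- $r_{01}$ (`r_01` in the code):
  $$
  r_0 = \frac{1 - 2W_0 I_0 + \sqrt{1 - 4W_0 I_0}}{2 W_0^2}
  $$

- $r_{02}$ (`r_02` in the code):
  $$
  r_0 = \frac{1 - 2W_0 I_0 - \sqrt{1 - 4W_0 I_0}}{2 W_0^2}
  $$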

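#### When $r_0 W_0 + I_0 > 1$:
Here $r_0 = 2\sqrt{r_0 W_0 + I_0 - \frac{3}{4}}$, so $r_0 \geq 0$. Squaring gives $r_0^2 - 4 W_0 r_0 + 3 - 4 I_0 = 0$.

- $r_{03}$ (`r_03` in the code):
  $$
  r_0 = 2 W_0 + 2 \sqrt{W_0^2 + I_0 - \frac{3}{4}}
  $$

The other root, $2 W_0 - 2 \sqrt{W_0^2 + I_0 - \frac{3}{4}}$, is also a steady state whenever it is non-negative and satisfies $r_0 W_0 + I_0 > 1$. For example, $W_0 = 0.9$, $I_0 = 0$ gives $r_0 \approx 1.31$. In general, a root of the squared equation is a genuine steady state only if it is non-negative and lies in the branch it was derived for.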

In [56]:
# TO DO: proper graph bounds, so that we can see better the solutions, plot theoretical solutions r only if r \approx \phi(wr+I)

%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from common_utils import fixed_point_solver, nonlinearity, derivative_nonlinearity, r_01, r_02, r_03

def _is_fixed_point(r, w0, I0, rtol=1e-6, atol=1e-9):
    """Return True if ``r`` is a real, finite solution of r = phi(w0 * r + I0).

    Closed-form candidates solve a *squared* version of the fixed-point
    equation for one branch of phi. Such a root can fall outside that
    branch, be negative, or be non-finite. In any of those cases it is not
    a steady state and this check returns False.
    """
    r_arr = np.atleast_1d(np.asarray(r))
    if not np.isrealobj(r_arr) or not np.all(np.isfinite(r_arr)):
        return False
    residual_target = nonlinearity(w0 * r_arr + I0)
    return bool(np.allclose(r_arr, residual_target, rtol=rtol, atol=atol))


def update_plots(w0, I0):
    """Redraw the graphical fixed-point analysis of the homogeneous ring model.

    Plots phi(w0 * r0 + I0) and the identity line against r0. Their
    intersections are the homogeneous steady states. The background is
    shaded according to which branch of phi the argument w0 * r0 + I0 lies
    in. The closed-form candidates r_01, r_02 and r_03 are marked only when
    they actually satisfy r0 = phi(w0 * r0 + I0).

    Args:
        w0: uniform coupling W_0 (slider value).
        I0: uniform external input I_0 (slider value).
    """
    plt.clf()

    r0_values = np.linspace(-5, 18, 1000)
    drive = w0 * r0_values + I0  # argument of phi on the r0 grid
    phi_values = nonlinearity(drive)

    # Shade the r0 axis by branch of phi. Raw strings keep the LaTeX
    # backslashes without invalid-escape warnings.
    plt.fill_between(r0_values, -10, 15, where=drive < 0, color='lightgray', alpha=0.3,
                     label=r'$(w_0 \cdot r_0 + I_0) < 0$')
    plt.fill_between(r0_values, -10, 15, where=(drive >= 0) & (drive <= 1), color='lightblue', alpha=0.3,
                     label=r'$0 \leq (w_0 \cdot r_0 + I_0) \leq 1$')
    plt.fill_between(r0_values, -10, 15, where=drive > 1, color='lightcoral', alpha=0.3,
                     label=r'$(w_0 \cdot r_0 + I_0) > 1$')

    # Steady states are the intersections of these two curves.
    plt.plot(r0_values, phi_values, label=r'$\phi(w_0 \cdot r_0 + I_0)$', color='blue')
    plt.plot(r0_values, r0_values, label='$r_0$', linestyle='--', color='red')

    # Only mark candidates that genuinely solve the fixed-point equation.
    candidates = (
        (r_01, 'r_{01}', 'red'),
        (r_02, 'r_{02}', 'orange'),
        (r_03, 'r_{03}', 'purple'),
    )
    for r_func, name, color in candidates:
        r_val = r_func(w0, I0)
        if _is_fixed_point(r_val, w0, I0):
            plt.scatter([r_val], [r_val], color=color, label=f'${name}$: {r_val:.2f}', zorder=5)

    # Set the plot properties
    plt.gca().set_facecolor('white')
    plt.gcf().set_facecolor('white')
    plt.xlabel('$r_0$', color='black')
    plt.ylabel('Value', color='black')
    plt.tick_params(axis='x', colors='black')
    plt.tick_params(axis='y', colors='black')
    plt.legend()
    plt.grid(color='gray')

    plt.show()

# Interactive controls: coupling w0 in [-5, 5] and external input I0 in [0, 2].
# Plots are only recomputed when a slider is released (continuous_update=False).
w0_slider = widgets.FloatSlider(
    description='$w_0$:',
    value=-1,
    min=-5,
    max=5,
    step=0.1,
    continuous_update=False,
)
I0_slider = widgets.FloatSlider(
    description='$I_0$:',
    value=0.5,
    min=0,
    max=2,
    step=0.01,
    continuous_update=False,
)

widgets.interactive(update_plots, w0=w0_slider, I0=I0_slider)


interactive(children=(FloatSlider(value=-1.0, continuous_update=False, description='$w_0$:', max=5.0, min=-5.0…

In [57]:
# TO DO: plot critical w0, fix r_03 solution

%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from scipy.integrate import solve_ivp
from common_utils import r_01, r_02, r_03, nonlinearity, derivative_nonlinearity

def dr_dt(t, r, w0, I0):
    """Right-hand side of the homogeneous rate equation with tau = 1.

    Returns -r + phi(w0 * r + I0). ``t`` is unused; it is part of the
    ``fun(t, y, *args)`` signature expected by ``solve_ivp``.
    """
    total_input = w0 * r + I0
    return nonlinearity(total_input) - r

def update_plots(I0):
    """Redraw the bifurcation diagram of homogeneous steady states r0 versus w0.

    For a fixed input I0, the closed-form candidates r_01, r_02 and r_03 are
    evaluated on a grid of w0 values. They are drawn only where they are
    genuine steady states: w0 * r0 + I0 must lie in the branch of phi the
    formula was derived for, and r0 must be >= 0. At up to 10 evenly spaced
    w0 values per branch, the ODE ``dr_dt`` is integrated from the analytic
    value as a numerical consistency check. A dashed line marks
    w0 = 1 / (4 I0), where the discriminant 1 - 4 w0 I0 of r_01/r_02
    vanishes.

    Args:
        I0: uniform external input I_0 (slider value).
    """
    N = 1000
    w0_values = np.linspace(-5, 5, N)
    plt.clf()

    def apply_mask(w0, I_0, r_func, mask_type=1):
        """Evaluate r_func on the w0 grid and set non-steady-state entries to NaN.

        mask_type 1 keeps 0 <= w0 * r + I_0 <= 1 (quadratic branch of phi).
        mask_type 2 keeps w0 * r + I_0 > 1 with r >= 0 (square-root branch).
        Returns the unchanged w0 grid and the filtered r values.
        """
        r_values = r_func(w0, I_0)
        drive = w0 * r_values + I_0
        if mask_type == 1:
            # On this branch r = drive**2, so valid roots are automatically >= 0.
            mask = (0 <= drive) & (drive <= 1)
        elif mask_type == 2:
            # phi is non-negative, so a steady state needs r >= 0. Squaring
            # r = 2*sqrt(drive - 3/4) admits negative roots that still give
            # drive > 1. With the r_03 formula this happens for some w0 < 0
            # when I0 < 1/2. Drop those roots.
            mask = (drive > 1) & (r_values >= 0)
        else:
            raise ValueError("Invalid mask_type specified. Use 1 or 2.")

        r_values_filtered = np.where(mask, r_values, np.nan)
        return w0, r_values_filtered

    def select_and_solve(w0, r_filtered, I_0, func):
        """Integrate the rate ODE at up to 10 evenly spaced w0 where the branch exists.

        Each run starts exactly at the analytic value func(w0, I_0) and
        returns r(t = 100). Without a perturbation, this checks that the
        value is a fixed point of dr_dt, not that it is stable.
        """
        non_nan_indices = ~np.isnan(r_filtered)
        w0_non_nan = w0[non_nan_indices]

        if len(w0_non_nan) == 0:
            return np.array([]), np.array([])

        indices_selected = np.linspace(0, len(w0_non_nan) - 1, min(len(w0_non_nan), 10), dtype=int)
        w0_selected = w0_non_nan[indices_selected]
        t_span = [0, 100]
        t_eval = np.linspace(t_span[0], t_span[1], 1000)

        r_num = []
        for w0_val in w0_selected:
            r0 = [func(w0_val, I_0)]  # add a small offset (e.g. +0.01) to probe stability
            sol = solve_ivp(dr_dt, t_span, r0, args=(w0_val, I_0), t_eval=t_eval, method='RK45')
            r_num.append(sol.y[0, -1])

        return w0_selected, r_num

    branches = [(r_01, 1, 'r_{01}', '--'),
                (r_02, 1, 'r_{02}', '-'),
                (r_03, 2, 'r_{03}', '-')]
    for func, mask_type, label, linestyle in branches:
        w0_filtered, r_filtered = apply_mask(w0_values, I0, func, mask_type)
        # apply_mask always returns the full grid. Skip branches with no
        # valid points instead of adding empty curves to the legend.
        if np.isnan(r_filtered).all():
            continue
        plt.plot(w0_filtered, r_filtered, label=f'${label}(w_0, I_0={I0:.2f})$', linestyle=linestyle)

        w0_selected, r_num = select_and_solve(w0_filtered, r_filtered, I0, func)
        if len(w0_selected) > 0:
            plt.plot(w0_selected, r_num, '.', label=f'Numerical {label}', alpha=0.75, markersize=5)

    if I0 != 0:
        plt.axvline(x=1/(4*I0), color='red', linestyle='--', label='$\\frac{1}{4I_0} = ' + f'{1/(4*I0):.2f}$' + ' Vertical Line')

    plt.xlabel('$w_0$')
    plt.ylabel('$r_0$')
    plt.title(f'Bifurcation Diagram at $I_0 = {I0:.2f}$')
    plt.legend()
    plt.grid(True)
    plt.show()


# Slider for the uniform input I0. The default 1/8 places the critical
# coupling 1/(4*I0) at w0 = 2, inside the plotted w0 range.
slider = widgets.FloatSlider(
    description='$I_0$:',
    value=1/8,
    min=0,
    max=1,
    step=0.01,
    continuous_update=False,
)
widgets.interactive(update_plots, I0=slider)


interactive(children=(FloatSlider(value=0.125, continuous_update=False, description='$I_0$:', max=1.0, step=0.…