In [11]:
from ipywidgets import VBox, HBox, Layout, interactive_output, FloatSlider
from ipywidgets import interactive as interactive_ipyw

from brian2 import * #type: ignore
import NeuronEquations
import BrianHF
import numpy as np

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import webbrowser
import plotly.io as pio
# NOTE: Without a reset (v = vr), neurons spike more often, because v does not have to integrate back up to threshold after each spike.


# Use Plotly for the interactive plot


In [12]:
# NeuronEquations.EQ_SCM_IF and BrianHF.calculate_ChebyshevNeighbours come from the project-local NeuronEquations and BrianHF modules imported in the first cell.
def simulate_neurons_interactive(tau, vt, vr, beta, Wi, Wk, run_time):
    start_scope()
    defaultclock.dt = 0.5*ms

    grid_size_x = 5 # Number of neurons in the x direction
    grid_size_y = 5
    N_Neurons = grid_size_x * grid_size_y # Number of neurons\
        
        
    # Neuron Equation
    Eqs_Neurons = NeuronEquations.EQ_SCM_IF

    # Neuron Parameters
    Neuron_Params = {'tau': tau*ms, 'vt': vt, 'vr': vr, 'P': 0, 'incoming_spikes': 0, 'method': 'exact'}
    # Synapse Parameters
    Syn_Params = {'Num_Neighbours' : 8, 'beta': beta, 'Wi': Wi, 'Wk': Wk}
    Num_Neighbours = Syn_Params['Num_Neighbours']

    input = []
    # Input generation
    for i in range (5):
        input.extend(range(i, N_Neurons, 5))
    times = []
    for i in range(5):
        times.extend([i]*5)
    
    G_Spikes = SpikeGeneratorGroup(N_Neurons, input, times*ms)

    neuronsGrid = NeuronGroup(N_Neurons, Eqs_Neurons+ 'lastspike : second', threshold='v>vt',
                            reset='''
                            lastspike = t
                            v = vr
                            incoming_spikes_post = 0
                            ''',
                            events={'P_ON': 'v > vt', 'P_OFF': '(timestep(t - lastspike, dt) > timestep(dt, dt) and v <= vt)'},
                            namespace=Neuron_Params)
    
    neuronsGrid.set_event_schedule('P_ON', when = 'after_thresholds')
    neuronsGrid.run_on_event('P_ON', 'P = 1' , when = 'after_thresholds')
    neuronsGrid.set_event_schedule('P_OFF', when = 'groups')
    neuronsGrid.run_on_event('P_OFF', 'P = 0', when = 'groups')

    # # Generate x and y values for each neuron
    # x_values = np.repeat(np.arange(grid_size_x), grid_size_y)
    # y_values = np.tile(np.arange(grid_size_y), grid_size_x)

    # # Assign x, y, and z values to each neuron
    # neuronsGrid.X = x_values
    # neuronsGrid.Y = y_values

    # Creating the synapses
    Syn_Input_Neurons = Synapses(G_Spikes, neuronsGrid,
                                 'beta : 1 (constant)',
                                 on_pre='''
                                 ExtIn_post = beta
                                 ''',
                                 namespace=Syn_Params)

    Syn_Neurons_Neurons = Synapses(neuronsGrid, neuronsGrid,
                               '''
                               Wi : 1
                               Wk : 1
                               ''',
                               on_pre={'pre': 'incoming_spikes_post += 1; Exc_pre = Wi',
                                   'pre_2' : 'Inh_post = P_post * (Wk/incoming_spikes_post)'},
                               namespace=Syn_Params)

    # SYNAPSE CONNECTIONS:
    # Connect the first synapses from input to neuronsGrid on a 1 to 1 basis
    Syn_Input_Neurons.connect(j = 'i')
    Syn_Input_Neurons.beta = beta

    # Unzip the pairs into two lists
    indexes_i, indexes_j = BrianHF.calculate_ChebyshevNeighbours(neuronsGrid, Num_Neighbours)
    # Connect the last group of synapses from a neuron to its neighbors
    Syn_Neurons_Neurons.connect(i=indexes_i, j=indexes_j)
    Syn_Neurons_Neurons.Wk = Wk
    Syn_Neurons_Neurons.Wi = Wi

    # SETTING UP MONITORS:
    # Monitor the spikes
    SpikeMon_Input = SpikeMonitor(G_Spikes)
    SpikeMon_Neurons = SpikeMonitor(neuronsGrid)
    
    # RUNNING THE SIMULATION:
    # display(scheduling_summary())
    BrianLogger.log_level_error()
    run(run_time)
    
    return SpikeMon_Input, SpikeMon_Neurons

def update_plot_with_plotly(tau, vt, vr, beta, Wi, Wk):
    """Run the grid simulation for one set of slider values and plot both spike rasters.

    Used as the ``ipywidgets.interactive`` callback: each parameter is driven
    by one slider and is forwarded to ``simulate_neurons_interactive``
    (``tau`` in ms). The figure has the input raster on top and the grid
    raster below. It is written to ``plotly_output.html`` in the current
    working directory and opened in a new browser tab.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure that was written, so callers can reuse it.
    """
    from pathlib import Path

    # The original call omitted the required ``run_time`` argument and always
    # raised TypeError. It is fixed here instead of becoming a parameter,
    # because ipywidgets.interactive would try to build a widget for any extra
    # argument of this callback.
    # TODO(review): confirm the intended duration.
    sim_run_time = 20*ms
    SpikeMon_Input, SpikeMon_Neurons = simulate_neurons_interactive(
        tau, vt, vr, beta, Wi, Wk, run_time=sim_run_time)

    # Two stacked raster plots sharing the same layout
    fig = make_subplots(rows=2, cols=1)

    # Input spikes
    fig.add_trace(go.Scatter(x=SpikeMon_Input.t/ms, y=SpikeMon_Input.i, mode='markers', name='Input Spikes'), row=1, col=1)

    # Neuron Grid spikes
    fig.add_trace(go.Scatter(x=SpikeMon_Neurons.t/ms, y=SpikeMon_Neurons.i, mode='markers', name='Neuron Grid Spikes'), row=2, col=1)

    # Update layout
    fig.update_layout(height=600, width=800, title_text="Neuron Simulation Results")
    fig.update_xaxes(title_text="Time (ms)")
    fig.update_yaxes(title_text="Neuron index", row=1, col=1)
    fig.update_yaxes(title_text="Neuron index", row=2, col=1)

    # Save the plot as HTML and open it in a browser. webbrowser expects a URL,
    # and an absolute file:// URI opens reliably, unlike a bare relative filename.
    html_path = Path("plotly_output.html").resolve()
    pio.write_html(fig, file=str(html_path))
    webbrowser.open_new_tab(html_path.as_uri())
    return fig

In [13]:
# Create sliders for parameters

# One slider per simulation parameter. With continuous_update=False, the value
# (and therefore a new simulation) is only sent when the slider is released,
# not continuously while it is being dragged.
tau_slider = FloatSlider(min=0.01, max=5, step=0.01, value=0.2, description='tau', continuous_update=False)
vt_slider = FloatSlider(min=0.01, max=1.0, step=0.01, value=0.1, description='vt', continuous_update=False)
vr_slider = FloatSlider(min=0.0, max=1.1, step=0.1, value=0.0, description='vr', continuous_update=False)
beta_slider = FloatSlider(min=0.1, max=5.0, step=0.05, value=0.5, description='beta', continuous_update=False)
Wi_slider = FloatSlider(min=0.0, max=20.0, step=0.01, value=6.0, description='Wi', continuous_update=False)
Wk_slider = FloatSlider(min=-20.0, max=0.0, step=0.01, value=-3.0, description='Wk', continuous_update=False)

# Map each keyword argument of update_plot_with_plotly to its slider
sliders = {
    'tau': tau_slider,
    'vt': vt_slider,
    'vr': vr_slider,
    'beta': beta_slider,
    'Wi': Wi_slider,
    'Wk': Wk_slider
}

# Unpack the dictionary so each slider drives the parameter of the same name
interactive_plot_with_plotly = interactive_ipyw(update_plot_with_plotly, **sliders)

# Display the widget. Jupyter renders a bare last expression; the assignment
# alone produced no output, so the sliders never appeared.
interactive_plot_with_plotly