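### Project Computational Science

Simulation of a small directed chain of Hodgkin–Huxley neurons: neuron 0 receives an external current that is switched off after t = 30, and the membrane voltage of every neuron is recorded and plotted.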

In [11]:
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from concept.hh_model import HodgkinHuxleyNeuron
import matplotlib
matplotlib.use('TkAgg')

# Build a directed chain of N Hodgkin-Huxley neurons with edges i -> i+1.
# The graph is stored as `network` (the name every later cell uses); the
# original construction referenced `network` without ever defining it.
N = 2                # number of neurons in the chain
SYN_WEIGHT = 0.05    # weight assigned to every edge

network = nx.DiGraph()
network.add_nodes_from(range(N))
network.add_edges_from((i, i + 1) for i in range(N - 1))

# Each node carries its own neuron model instance.
for node in network.nodes():
    network.nodes[node]['neuron'] = HodgkinHuxleyNeuron()

# A scalar value is applied to every edge under the 'weight' key.
nx.set_edge_attributes(network, SYN_WEIGHT, 'weight')

# Keep the original name `G` as an alias to the same graph object.
G = network


# --- Simulation parameters ---
T = 50.0                 # total simulated time (same units as dt)
dt = 0.01                # integration step passed to neuron.step()
time = np.arange(0, T, dt)

I_STIM = 15.0            # external current applied to neuron 0 while the stimulus is on
T_STIM_OFF = 30.0        # stimulus is switched off for t > T_STIM_OFF
TAU_SYN = 5.0            # decay time constant of the synaptic traces (same units as dt)
V_SPIKE = 0.0            # a spike is counted on an upward crossing of this voltage
DEFAULT_WEIGHT = 0.1     # fallback for edges without a 'weight' attribute

network.nodes[0]['neuron'].I_ext = I_STIM

V_record = {node: [] for node in network.nodes()}

# Each neuron has a synaptic trace (its output to downstream neurons). The
# trace jumps by 1 on every spike and decays exponentially with TAU_SYN, so
# the synaptic current stays bounded. The previous exp(-(V_post - V_pre) / tau)
# coupling grew exponentially with the voltage gap (e^20 for a 100 mV
# difference with tau = 5), and it used the "decay time constant" as a voltage scale.
syn_trace = {node: 0.0 for node in network.nodes()}
V_prev = {node: network.nodes[node]['neuron'].V for node in network.nodes()}
syn_decay = np.exp(-dt / TAU_SYN)

for t in time:
    # Switching the stimulus off needs to happen once per step, not once per node.
    if t > T_STIM_OFF:
        network.nodes[0]['neuron'].I_ext = 0.0

    # A neuron's input comes from its presynaptic neurons, i.e. its
    # predecessors in the DiGraph. `neighbors()` returns successors, which
    # meant the last neuron of the chain never received any input.
    # All currents are computed from the state at the start of the step, so
    # the result does not depend on the order in which nodes are stepped.
    I_syn_by_node = {
        node: sum(
            network[pre][node].get('weight', DEFAULT_WEIGHT) * syn_trace[pre]
            for pre in network.predecessors(node)
        )
        for node in network.nodes()
    }

    # NOTE(review): assumes step() treats its second argument as an additional
    # depolarizing input current alongside I_ext -- confirm in concept.hh_model.
    for node in network.nodes():
        neuron = network.nodes[node]['neuron']
        neuron.step(dt, I_syn_by_node[node])
        V_record[node].append(neuron.V)

    # Advance the synaptic traces: exponential decay plus a unit jump per spike.
    for node in network.nodes():
        V_now = network.nodes[node]['neuron'].V
        syn_trace[node] *= syn_decay
        if V_prev[node] < V_SPIKE <= V_now:
            syn_trace[node] += 1.0
        V_prev[node] = V_now

# NOTE(review): with edge weight 0.05 and unit spike increments, the synaptic
# current is small. Increase the edge weights if spikes are meant to propagate
# along the chain. The network itself is drawn in the Plot cell below.


### Plot: voltage traces and network structure

In [7]:
# Two-panel figure: voltage traces (left) and network structure (right).
fig, axs = plt.subplots(1, 2, figsize=(18, 8))
ax_traces, ax_graph = axs

# Panel 1: one voltage trace per neuron. The legend numbers neurons from 1;
# the graph panel uses the same numbering so the panels can be cross-referenced.
for node, V in V_record.items():
    ax_traces.plot(time, V, label=f'Neuron: {node + 1}')
ax_traces.legend()
ax_traces.set_title("Neuron Voltage Traces")
ax_traces.set_xlabel("Time")
ax_traces.set_ylabel("Voltage")

# Panel 2: the network. Node labels are shifted to 1-based to match the
# legend above. Edges are annotated with their 'weight' attribute where present.
pos = nx.circular_layout(network)
nx.draw(
    network,
    pos,
    labels={node: node + 1 for node in network.nodes()},
    with_labels=True,
    node_color='skyblue',
    node_size=500,
    font_size=10,
    font_weight='bold',
    edge_color='gray',
    ax=ax_graph,
)
edge_labels = {
    (u, v): f"{data['weight']:.2f}"
    for u, v, data in network.edges(data=True)
    if 'weight' in data
}
nx.draw_networkx_edge_labels(network, pos, edge_labels=edge_labels, font_size=10, ax=ax_graph)
ax_graph.set_title("Neural Network Visualization")

plt.tight_layout()
plt.show()