# CA3/CA1 Neuromorphic Module – Advanced Visual Demo

This notebook provides a more advanced demo of the CA3→CA1 network:

1. Builds a 200-neuron network (100 CA3, 100 CA1)
2. Trains the network by presenting five random binary input patterns to CA3 (one presentation each)
3. Applies simple STDP plasticity after each pattern
4. Visualizes:
   - Firing activity of CA3 and CA1
   - Distribution of synaptic weights before and after learning
   - Recovery of a stored pattern from a noisy cue


In [ ]:
import random
import numpy as np
import matplotlib.pyplot as plt

from backend.network import CA3CA1Network
from backend.simulation import Simulation
from backend.plasticity import STDP
from agents.input_agent import InputAgent
from agents.monitoring_agent import MonitoringAgent
from orchestrator.orchestrator import Orchestrator

# Build a network with 100 CA3 and 100 CA1 neurons, wire up its synapses,
# and create the STDP rule used later when presenting patterns.
net = CA3CA1Network(n_ca3=100, n_ca1=100)
net.connect()
stdp = STDP()

print("Network created.")
summary = (
    ("CA3 neurons", len(net.ca3)),
    ("CA1 neurons", len(net.ca1)),
    ("Total synapses", len(net.synapses)),
)
for label, count in summary:
    print(f"{label}: {count}")

In [ ]:
# Helper functions
def get_activity(neurons):
    """Return a binary firing vector for *neurons*.

    Parameters
    ----------
    neurons : iterable
        Objects exposing a ``fired`` attribute; its truthiness is used.

    Returns
    -------
    numpy.ndarray
        1-D integer array, 1 where ``fired`` is truthy and 0 otherwise, in
        iteration order.
    """
    # Pin the dtype: without it, an empty input would produce a float64
    # array while any non-empty input produces integers.
    return np.array([1 if n.fired else 0 for n in neurons], dtype=int)

def get_weights(synapses):
    """Collect the ``weight`` attribute of every synapse into a NumPy array.

    Weights appear in the iteration order of *synapses*. The dtype is
    whatever NumPy infers from the weight values.
    """
    weight_values = []
    for synapse in synapses:
        weight_values.append(synapse.weight)
    return np.array(weight_values)

def present_pattern(pattern, net, stdp, steps=3):
    """Present one input pattern to *net*, simulate, then apply STDP.

    The pattern is delivered through an ``Orchestrator`` that wraps a fresh
    ``InputAgent``/``MonitoringAgent`` pair. The network is then advanced for
    ``steps`` simulation steps. Finally, ``stdp.apply`` is called once per
    synapse with the pre/post ``fired`` flags as they stand after the run.

    Returns
    -------
    tuple of numpy.ndarray
        ``(ca3_activity, ca1_activity)``, produced by ``get_activity`` on
        ``net.ca3`` and ``net.ca1`` after plasticity has been applied.

    NOTE(review): STDP only sees the ``fired`` flags present after the final
    simulation step. Whether ``fired`` reflects any earlier steps depends on
    the backend — confirm.
    """
    agents = [InputAgent(), MonitoringAgent()]
    orch = Orchestrator(net, agents)
    orch.run_cycle(pattern)  # deliver the pattern for one input cycle

    sim = Simulation(net)
    sim.run(steps)

    # Plasticity is applied only after the simulation run, synapse by synapse.
    for synapse in net.synapses:
        pre_spiked, post_spiked = synapse.pre.fired, synapse.post.fired
        stdp.apply(synapse, pre_spiked, post_spiked)

    return get_activity(net.ca3), get_activity(net.ca1)


In [ ]:
# Draw `num_patterns` random binary patterns, one bit per CA3 neuron.
num_patterns = 5
n_ca3 = len(net.ca3)

patterns = [
    [random.choice([0, 1]) for _bit in range(n_ca3)]
    for _pattern_idx in range(num_patterns)
]

print("Generated", num_patterns, "patterns.")

In [ ]:
# Snapshot the synaptic weights before any learning and plot their histogram.
initial_weights = get_weights(net.synapses)

fig, ax = plt.subplots(figsize=(6, 4))
ax.hist(initial_weights, bins=30)
ax.set_title("Initial Synaptic Weight Distribution")
ax.set_xlabel("Weight")
ax.set_ylabel("Count")
plt.show()

In [ ]:
# Present each pattern once (STDP is applied inside present_pattern) and stack
# the responses row-wise: row i holds the activity for patterns[i].
responses = []
for idx, pattern in enumerate(patterns):
    print(f"\nPresenting pattern {idx}...")
    responses.append(present_pattern(pattern, net, stdp, steps=3))

ca3_activities = np.array([resp[0] for resp in responses])
ca1_activities = np.array([resp[1] for resp in responses])
print("Training completed.")

In [ ]:
# Visualize CA3 and CA1 activity as heatmaps (rows = patterns, columns = neurons).
# The colour range is pinned to the 0/1 firing values. With autoscaling, a layer
# that is uniformly silent or uniformly active has vmin == vmax, so every cell
# maps to the lowest colour. "All firing" would then look the same as
# "all silent", and the two plots would not share a colour scale.
for activities, region in ((ca3_activities, "CA3"), (ca1_activities, "CA1")):
    plt.figure(figsize=(8, 4))
    plt.imshow(activities, aspect='auto', vmin=0, vmax=1, interpolation='nearest')
    plt.title(f"{region} Activity Across Patterns")
    plt.xlabel("Neuron index")
    plt.ylabel("Pattern index")
    plt.colorbar(label="Firing (0/1)")
    plt.show()

In [ ]:
# Visualize synaptic weights after training, overlaid on the pre-training
# distribution. Both histograms share one set of bin edges so their bar heights
# are directly comparable. Two independent `bins=30` calls would each derive
# their edges from their own min/max.
final_weights = get_weights(net.synapses)

shared_bins = np.histogram_bin_edges(
    np.concatenate([np.ravel(initial_weights), np.ravel(final_weights)]), bins=30
)

plt.figure(figsize=(6, 4))
plt.hist(initial_weights, bins=shared_bins, alpha=0.5, label="Before training")
plt.hist(final_weights, bins=shared_bins, alpha=0.5, label="After training")
plt.title("Final Synaptic Weight Distribution After Training")
plt.xlabel("Weight")
plt.ylabel("Count")
plt.legend()
plt.show()

print("Initial mean weight:", float(np.mean(initial_weights)))
print("Final mean weight:", float(np.mean(final_weights)))

In [ ]:
# Test pattern recall with a noisy cue.
# A stored pattern is corrupted by independent bit flips and presented WITHOUT
# applying STDP. The resulting CA1 response is compared with the CA1 response
# recorded for the clean pattern during training (ca1_activities[test_index]).
test_index = 0
flip_prob = 0.2  # probability that each bit of the cue is flipped
original = patterns[test_index]

# Create a noisy version: each bit is kept when random() > flip_prob, else flipped.
noisy = [v if random.random() > flip_prob else 1 - v for v in original]
n_flipped = sum(o != v for o, v in zip(original, noisy))

print("Original pattern (first 30 bits):", original[:30])
print("Noisy pattern    (first 30 bits):", noisy[:30])
print(f"Bits flipped: {n_flipped}/{len(original)}")

# Present noisy pattern WITHOUT applying further learning.
# NOTE(review): no reset is performed here, so neuron state left over from the
# last training presentation may influence this response — confirm whether the
# backend offers a reset.
input_agent = InputAgent()
monitor = MonitoringAgent()
orch = Orchestrator(net, [input_agent, monitor])

orch.run_cycle(noisy)
sim = Simulation(net)
sim.run(3)

recalled_ca1 = get_activity(net.ca1)

# The reference is the CA1 response recorded when the clean pattern was
# presented during training. Patterns presented after it may have changed the
# weights since then.
reference_ca1 = np.asarray(ca1_activities[test_index])
agreement = float(np.mean(recalled_ca1 == reference_ca1))
print(f"CA1 agreement with trained response: {agreement:.2%}")

panels = [
    (np.array(original), "Original (CA3)"),
    (np.array(noisy), "Noisy Input (CA3)"),
    (reference_ca1, "CA1 Response (trained, clean cue)"),
    (recalled_ca1, "CA1 Response"),
]

plt.figure(figsize=(13, 4))
for panel_idx, (row, title) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), panel_idx)
    # Pin the colour range so every binary panel uses the same 0/1 scale.
    plt.imshow(row[None, :], aspect='auto', vmin=0, vmax=1, interpolation='nearest')
    plt.title(title)
    plt.yticks([])
plt.tight_layout()
plt.show()