In [None]:
import torch
import torch.nn as nn
import numpy as np
import random

# Seed every random source the script draws from (stdlib random, NumPy,
# torch) so that runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Letter bitmap geometry: 7 rows x 5 columns = 35 pixels.
ROWS, COLS = 7, 5

# Layer widths.
INPUT_NEURONS = 37          # bitmap pixels zero-padded to this length
RECOGNIZE_NEURONS = 43      # fixed width of the recognition-layer vector
CLASS_NEURONS = 26          # number of output classes (presumably one per letter)

# Experiment settings.
SPIKE_LENGTH = 35           # random spikes drawn per structured sample
A_CLASS_INDEX = 0           # target class index for the letter 'A'
TRAIN_SAMPLES = 3000        # training iterations, one sample each
TEST_SAMPLES = 400          # clean test samples

# Input-mode toggle: True -> structured spike pipeline,
# False -> direct (optionally noisy) grid encoding.
use_biological_input = True

# ------------------------
# Encode letter 'A'
# ------------------------
def encode_letter_A(size=None):
    """Return the 7x5 bitmap of the letter 'A' as a flat float32 tensor.

    The 35 grid pixels (row-major, 1 = black, 0 = white) are zero-padded, or
    truncated, so that the result has exactly ``size`` elements.

    Args:
        size: Desired output length. Defaults to the module constant
            ``INPUT_NEURONS`` when omitted.

    Returns:
        1-D ``torch.float32`` tensor of length ``size``.
    """
    if size is None:
        size = INPUT_NEURONS
    grid = np.array([
        [0, 0, 1, 0, 0],
        [0, 1, 0, 1, 0],
        [1, 0, 0, 0, 1],
        [1, 1, 1, 1, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1]
    ])
    flat = grid.flatten().tolist()
    # Pad with as many zeros as needed (none if already long enough), then
    # trim so the output is exactly `size` long whatever the constant is.
    flat += [0] * (size - len(flat))
    return torch.tensor(flat[:size], dtype=torch.float32)

# ------------------------
# Noise injection
# ------------------------
def generate_noisy_sample(base_tensor, noise_level=0.20):
    """Return a copy of ``base_tensor`` with random elements replaced by ``1 - value``.

    Each element is independently selected with probability ``noise_level``,
    using NumPy's global RNG, and replaced by ``1 - value``. For 0/1 inputs this
    is a bit flip.

    Args:
        base_tensor: Tensor to perturb; it is detached and copied to the CPU.
        noise_level: Per-element flip probability.

    Returns:
        A new CPU ``torch.float32`` tensor with the same shape as the input.
    """
    values = base_tensor.detach().cpu().numpy()
    flip_mask = np.random.rand(*values.shape) < noise_level
    flipped = 1 - values
    return torch.tensor(np.where(flip_mask, flipped, values), dtype=torch.float32)

# ------------------------
# Dynamic Input Module
# ------------------------
def input_module_dynamic(num_inputs=35, num_spikes=SPIKE_LENGTH):
    """Scatter random spikes over the input neurons and report the active ones.

    ``num_spikes`` spikes are assigned, one at a time, to neurons chosen
    uniformly at random (stdlib ``random``) from ``[0, num_inputs)``. Every
    neuron that received two or more spikes then has its count reduced by
    exactly one. Neurons whose count is still positive are considered active.

    Args:
        num_inputs: Number of input neurons.
        num_spikes: Number of spikes to scatter.

    Returns:
        ``(R, active_indices)``:
        - ``R`` is a float tensor holding the remaining counts of the active
          neurons, in ascending neuron order.
        - ``active_indices`` is the matching list of neuron indices.
    """
    counts = torch.zeros(num_inputs)
    for _ in range(num_spikes):
        counts[random.randint(0, num_inputs - 1)] += 1
    # Neurons hit at least twice lose a single spike (one decrement, not a cap).
    counts = torch.where(counts >= 2, counts - 1, counts)
    active_indices = [idx for idx, c in enumerate(counts.tolist()) if c > 0]
    return counts[active_indices], active_indices

# ------------------------
# Dynamic Recognize Module 1
# ------------------------
def recognize_module_1_dynamic(R, active_indices, group_size=7):
    """Pack the active-neuron values into consecutive rows of ``group_size``.

    The number of rows is ``ceil(len(active_indices) / group_size)``. Value
    ``R[p]`` lands at row ``p // group_size``, column ``p % group_size``. Unused
    trailing cells of the last row stay zero.

    Args:
        R: 1-D sequence of values; its length is expected to equal
            ``len(active_indices)``.
        active_indices: Indices of the active neurons (only its length is used).
        group_size: Row width.

    Returns:
        Float tensor of shape ``(num_groups, group_size)``.
    """
    n_rows = -(-len(active_indices) // group_size)  # ceiling division
    grouped = torch.zeros(n_rows, group_size)
    for position, value in enumerate(R):
        row, col = divmod(position, group_size)
        grouped[row, col] = value
    return grouped

# ------------------------
# Dynamic Recognize Module 2
# ------------------------
def recognize_module_2_dynamic(T):
    """Sum each group (row along dim 0) of ``T`` into a single float32 value.

    Args:
        T: Tensor whose first dimension indexes the groups.

    Returns:
        1-D float32 tensor of length ``T.shape[0]`` holding the per-group sums.
    """
    group_totals = torch.zeros(len(T))
    for k, group in enumerate(T):
        group_totals[k] = group.sum()
    return group_totals

# ------------------------
# Dynamic Structured Pipeline
# ------------------------
def simulate_structured_input_dynamic(noise_level=0.20):
    """Produce one sample from the structured (spike -> group -> sum) pipeline.

    Steps:
    1. Scatter ``SPIKE_LENGTH`` random spikes over 35 inputs.
    2. Pack the active counts into rows of 7 and sum each row.
    3. Zero-pad the sums into a ``RECOGNIZE_NEURONS``-long vector.
    4. Apply ``generate_noisy_sample``.

    With 35 inputs and groups of 7 there are at most 5 sums, so only the
    first few entries of the padded vector can be non-zero before noise.

    NOTE(review): the noise replaces a value c with 1 - c, so a flipped count
    of e.g. 5 becomes -4. Confirm that is intended for count-valued input.

    Args:
        noise_level: Per-element flip probability passed to the noise step.

    Returns:
        Float tensor of shape ``(1, RECOGNIZE_NEURONS)``.
    """
    spikes, active = input_module_dynamic(num_inputs=35, num_spikes=SPIKE_LENGTH)
    grouped = recognize_module_1_dynamic(spikes, active, group_size=7)
    group_totals = recognize_module_2_dynamic(grouped)

    padded = torch.zeros(RECOGNIZE_NEURONS)
    keep = min(RECOGNIZE_NEURONS, len(group_totals))
    padded[:keep] = group_totals[:keep]

    return generate_noisy_sample(padded, noise_level=noise_level).unsqueeze(0)

# ------------------------
# Noisy encoded input (direct grid)
# ------------------------
def simulate_noisy_encoded_input(noise_level=0.20):
    """Return a noisy copy of the encoded 'A' bitmap with a leading batch dim of 1.

    Args:
        noise_level: Per-element flip probability passed to ``generate_noisy_sample``.

    Returns:
        Float tensor of shape ``(1, len(encode_letter_A()))``.
    """
    return generate_noisy_sample(encode_letter_A(), noise_level=noise_level).unsqueeze(0)

# ------------------------
# Pixel retention
# ------------------------
def count_matching_black_pixels(original, noisy):
    """Measure how many black (== 1) pixels of ``original`` are still 1 in ``noisy``.

    Args:
        original: Reference tensor (1 = black).
        noisy: Tensor of the same shape to compare against.

    Returns:
        ``(black_matches, total_black, retention_ratio)``:
        - ``black_matches``: number of positions that are 1 in both tensors.
        - ``total_black``: number of 1s in ``original``.
        - ``retention_ratio``: ``black_matches / total_black`` as a percentage.
        When ``original`` has no black pixels the ratio is undefined and
        ``float("nan")`` is returned instead of raising ``ZeroDivisionError``.
    """
    is_black = original == 1
    black_matches = torch.sum(is_black & (noisy == 1)).item()
    total_black = torch.sum(is_black).item()
    if total_black == 0:
        return black_matches, total_black, float("nan")
    retention_ratio = black_matches / total_black * 100
    return black_matches, total_black, retention_ratio

# ------------------------
# Hebbian Spiking Layer
# ------------------------
class HebbianSpikingLayer(nn.Module):
    """Bias-free linear layer whose weights can also be adjusted by a Hebbian rule.

    ``forward`` computes ``x @ weights.T`` and caches its input and output so
    that ``hebbian_update`` can apply a local learning rule based on the most
    recent call.

    Args:
        input_size: Number of input features.
        output_size: Number of output neurons.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Weight matrix of shape (output_size, input_size), standard-normal init.
        self.weights = nn.Parameter(torch.randn(output_size, input_size))
        # Per-neuron firing threshold. Only hebbian_update reads it; forward does
        # not, so it never receives a gradient.
        self.threshold = nn.Parameter(torch.ones(output_size) * 0.5)

    def forward(self, x):
        """Return ``x @ weights.T`` and cache input/output for ``hebbian_update``.

        Args:
            x: Input tensor whose last dimension is ``input_size``.

        Returns:
            Tensor whose last dimension is ``output_size``.
        """
        spike_counts = torch.matmul(x, self.weights.T)
        self.last_input = x
        self.last_output = spike_counts
        return spike_counts

    def hebbian_update(self, learning_rate=0.2):
        """Apply one Hebbian step using batch row 0 of the last forward pass.

        Neuron ``i`` "fires" when its cached output exceeds ``threshold[i]``.
        Each weight ``w[i, j]`` changes according to the presynaptic value ``x[j]``:

        * ``x[j] == 1`` and neuron fired  -> ``+learning_rate``
        * ``x[j] == 1`` and neuron silent -> ``-learning_rate``
        * ``x[j] == 0`` and neuron fired  -> ``-learning_rate``
        * otherwise (silent with ``x[j] == 0``, or ``x[j]`` not 0/1) -> unchanged

        Must be called after ``forward``.

        Args:
            learning_rate: Magnitude of each weight increment or decrement.
        """
        with torch.no_grad():
            x = self.last_input[0]   # presynaptic values, one per input feature
            y = self.last_output[0]  # postsynaptic values, one per neuron
            # Broadcast to an (output_size, input_size) grid of conditions,
            # matching the layout of self.weights.
            fired = (y > self.threshold).unsqueeze(1)
            pre_on = (x == 1).unsqueeze(0)
            pre_off = (x == 0).unsqueeze(0)
            potentiate = pre_on & fired
            depress = (pre_on & ~fired) | (pre_off & fired)
            # The two masks are disjoint, so each weight moves at most once.
            self.weights[potentiate] += learning_rate
            self.weights[depress] -= learning_rate

# ------------------------
# Multi-layer Hebbian Network
# ------------------------
class HebbianSpikingNetwork(nn.Module):
    """Stack of ``HebbianSpikingLayer``s applied in sequence.

    Args:
        layer_sizes: Widths ``[n0, n1, ..., nk]``. Layer ``m`` maps
            ``layer_sizes[m]`` features to ``layer_sizes[m + 1]``.
    """

    def __init__(self, layer_sizes):
        super().__init__()
        # Consecutive (fan_in, fan_out) pairs define the layers.
        size_pairs = zip(layer_sizes[:-1], layer_sizes[1:])
        self.layers = nn.ModuleList(
            [HebbianSpikingLayer(n_in, n_out) for n_in, n_out in size_pairs]
        )

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the final output."""
        activity = x
        for layer in self.layers:
            activity = layer(activity)
        return activity

    def hebbian_update(self, learning_rate=0.2):
        """Apply each layer's Hebbian rule, first layer to last."""
        for layer in self.layers:
            layer.hebbian_update(learning_rate)

# ------------------------
# Initialize model
# ------------------------
# The first layer's fan-in must match the input representation in use:
# - the structured pipeline emits RECOGNIZE_NEURONS values per sample;
# - the direct grid encoding emits INPUT_NEURONS values.
input_size = RECOGNIZE_NEURONS if use_biological_input else INPUT_NEURONS
net = HebbianSpikingNetwork([input_size, RECOGNIZE_NEURONS, CLASS_NEURONS])
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
# Every training sample is the letter 'A', so the target class is constant.
target_class = torch.tensor([A_CLASS_INDEX])

# ------------------------
# Training Phase (Clean Only)
# ------------------------
print("üîß Training Phase (Clean Only)")
# Choose the sample generator once; it is always called with noise_level=0.0,
# so training only ever sees clean inputs.
_make_training_input = (
    simulate_structured_input_dynamic if use_biological_input
    else simulate_noisy_encoded_input
)
for epoch in range(TRAIN_SAMPLES):
    input_A = _make_training_input(noise_level=0.0)

    # One SGD step on the cross-entropy loss ...
    optimizer.zero_grad()
    output = net(input_A)
    loss = loss_fn(output, target_class)
    loss.backward()
    optimizer.step()
    # ... followed by the local Hebbian rule on every layer.
    net.hebbian_update()

    # Log progress every 100 iterations.
    if epoch % 100 != 0:
        continue
    predicted = torch.argmax(output).item()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}, Predicted Index: {predicted}")
    print(f"Output: {output.detach().numpy()}")

# ------------------------
# Testing Phase
# ------------------------
print("\nüß† Testing Phase")
correct = 0
original = encode_letter_A()
for _ in range(TEST_SAMPLES):
    # Clean evaluation: the structured mode draws a fresh random spike pattern
    # each time, while the grid mode reuses the single clean bitmap.
    if use_biological_input:
        test_input = simulate_structured_input_dynamic(noise_level=0.0)
        logits = net(test_input)
    else:
        logits = net(original.unsqueeze(0))
    predicted_class = torch.argmax(logits).item()
    correct += int(predicted_class == A_CLASS_INDEX)

accuracy = correct / TEST_SAMPLES * 100
print(f"‚Üí Recognition Accuracy (Clean Only): {accuracy:.2f}% ({correct}/{TEST_SAMPLES})")

# ------------------------
# Recognition accuracy across noise levels
# ------------------------
print("\nRecognition Accuracy Across Noise Levels:")
noise_levels = [0.00, 0.03, 0.06, 0.09, 0.12, 0.14, 0.17, 0.20]
NOISE_EVAL_SAMPLES = 500  # samples evaluated per noise level
for nl in noise_levels:
    correct = 0
    # Evaluation only: skip building an autograd graph for each forward pass.
    with torch.no_grad():
        for _ in range(NOISE_EVAL_SAMPLES):
            if use_biological_input:
                test_input = simulate_structured_input_dynamic(noise_level=nl)
                predicted_class = torch.argmax(net(test_input)).item()
            else:
                noisy_input = generate_noisy_sample(original, noise_level=nl).unsqueeze(0)
                predicted_class = torch.argmax(net(noisy_input)).item()
            if predicted_class == A_CLASS_INDEX:
                correct += 1
    accuracy = correct / NOISE_EVAL_SAMPLES * 100
    # round(), not int(): int() truncates float error (e.g. 0.29 * 100 -> 28).
    print(f"Noise {round(nl * 100)}% ‚Üí Accuracy: {accuracy:.2f}% ({correct}/{NOISE_EVAL_SAMPLES})")

# ------------------------
# Pixel retention (averaged across trials)
# ------------------------
print("\nBlack Pixel Retention Across Noise Levels (Averaged):")
noise_levels = [0.00, 0.03, 0.06, 0.09, 0.12, 0.14, 0.17, 0.20]
trials = 200  # number of trials per noise level

original = encode_letter_A()

for nl in noise_levels:
    avg_retention = 0
    total_black_matches = 0
    total_black = 0

    for _ in range(trials):
        noisy_sample = generate_noisy_sample(original, noise_level=nl)
        black_matches, total_black_pixels, retention = count_matching_black_pixels(original, noisy_sample)
        avg_retention += retention
        total_black_matches += black_matches
        total_black = total_black_pixels  # same across trials

    avg_retention /= trials
    # True mean; floor division would understate it and disagree with the
    # percentage printed alongside.
    avg_black_matches = total_black_matches / trials
    # round(), not int(): int() truncates float error (e.g. 0.29 * 100 -> 28).
    print(f"Noise {round(nl * 100)}% ‚Üí Avg Retention: {avg_retention:.2f}% "
          f"(avg {avg_black_matches:.2f}/{total_black})")

print("\nüß™ Pixel-wise Fidelity Across Noise Levels:")
num_trials = 200  # samples per noise level (loop-invariant, so set once)
for nl in noise_levels:
    pixel_matches = 0
    total_pixels = 0
    for _ in range(num_trials):
        noisy = generate_noisy_sample(original, noise_level=nl)
        pixel_matches += torch.sum(noisy == original).item()
        total_pixels += noisy.numel()
    # Percentage of pixels (black and white) left unchanged by the noise.
    pixel_accuracy = pixel_matches / total_pixels * 100
    # round(), not int(): int() truncates float error (e.g. 0.29 * 100 -> 28).
    print(f"Noise {round(nl * 100)}% ‚Üí Fidelity: {pixel_accuracy:.2f}%")

üîß Training Phase (Clean Only)
Epoch 0, Loss: 340.7761, Predicted Index: 8
Output: [[-107.46562   -86.58119   -65.648544   56.368416  -16.921783  -46.78425
  -146.7969    -30.6078    233.31049    94.30464  -106.851974   -7.716512
    15.722309  -71.16066    45.64531    95.45817     5.073925 -218.95131
    75.063416   51.873108   80.125595  169.40639  -201.54193  -210.65039
  -121.93263   -33.075127]]
Epoch 100, Loss: 0.0000, Predicted Index: 0
Output: [[ 390.41342   -129.106      -23.894978    57.598263    84.69662
   -45.53135    -96.03798    -87.2378       5.4866266   81.451225
   -26.010754    86.31819     23.728823   -42.01634    -97.48688
   121.13432    110.03923   -228.84438    -28.854805    37.37793
   197.76855   -166.73006   -147.02023   -116.64503    -43.462524
   -22.874529 ]]
Epoch 200, Loss: 0.0000, Predicted Index: 0
Output: [[1069.8762    -102.818726     6.2571564   73.71859    147.01134
    15.168076  -101.34212   -136.21184      6.1322174   21.488968
    45.02735   