In [None]:
    import numpy as np
    import pandas as pd
    import torch
    import matplotlib.pyplot as plt
    import snntorch as snn 
    import torch.nn as nn
    import torch.utils.data.dataloader
    import torchvision
    import torchvision.transforms as transforms
    from snntorch import surrogate
    from snntorch import functional as SF


# Test SNN for the MNIST dataset.

# Device configuration: use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
batch_size = 100
input_size = 784  # features per sample after flattening (28 * 28 = 784)
hidden_size = 196
output_size = 10
membrane_potential_decay_rate = 0.9  # passed to the model as the initial `beta`
num_steps = 20  # number of simulation time steps per forward pass
epochs = 10
# NOTE(review): `learning_rate` is not read anywhere in the visible code;
# main() builds its optimizer with a hard-coded lr=0.005. Confirm which value
# is intended before wiring this constant in.
learning_rate = 0.001

# Dataset preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean and std normalization
])

# download=True fetches the data into `root` when it is not already there.
train_dataset = torchvision.datasets.MNIST(
    root="../CNN,SNN,SCNN", train=True, transform=transform, download=True)

test_dataset = torchvision.datasets.MNIST(
    root="../CNN,SNN,SCNN", train=False, transform=transform, download=True)

# Only the training split is shuffled; evaluation order is kept fixed.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Surrogate gradient handed to the spiking neurons. surrogate.atan() is called
# with its default arguments here (no custom slope is configured).
spike_grad = surrogate.atan()

class ImprovedSpikingNeuralNetwork(nn.Module):
    """Two-layer fully connected network with `snn.Leaky` neurons after each layer.

    The same (flattened) input is presented at every simulation step; the
    forward pass returns the output-layer spikes averaged over all steps.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_size: Number of units in the hidden layer.
        output_size: Number of output units.
        beta: Initial decay value given to both `snn.Leaky` layers
            (created with ``learn_beta=True``).
        spike_grad: Surrogate gradient function passed to both `snn.Leaky` layers.
        num_steps: Number of simulation steps per forward pass. Defaults to
            20, the value of the module-level ``num_steps`` hyperparameter;
            pass it explicitly if a different number of steps is wanted.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """

    def __init__(self, input_size=784, hidden_size=196, output_size=10,
                 beta=0.9, spike_grad=surrogate.sigmoid(), num_steps=20):
        super().__init__()

        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        # Kept on the instance so forward() does not depend on module globals.
        self.num_steps = num_steps

        # Input layer, Xavier-initialised weights and zero bias.
        self.l1 = nn.Linear(input_size, hidden_size)
        nn.init.xavier_uniform_(self.l1.weight)
        nn.init.zeros_(self.l1.bias)

        # Leaky neuron for the hidden layer.
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad, learn_beta=True, threshold=1.0)

        # Output layer, initialised the same way as the input layer.
        self.l2 = nn.Linear(hidden_size, output_size)
        nn.init.xavier_uniform_(self.l2.weight)
        nn.init.zeros_(self.l2.bias)

        # Leaky neuron for the output layer.
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad, learn_beta=True)

    def forward(self, x):
        """Run ``self.num_steps`` simulation steps and return the averaged output spikes.

        Args:
            x: Input batch; it is reshaped to ``(-1, input_size)``.

        Returns:
            Tensor with one row per sample and one column per output unit,
            holding the output-layer spike tensors summed over the steps and
            divided by the number of steps.
        """
        # Flatten using the layer's own size rather than a module-level global.
        x = x.reshape(-1, self.l1.in_features)

        # Fresh membrane state for every forward pass.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # The input does not change between steps, so the hidden-layer input
        # current is computed once instead of once per step.
        cur1 = self.l1(x) / (self.l1.out_features ** 0.5)

        spike_record = []
        for _ in range(self.num_steps):
            # Hidden layer
            spk1, mem1 = self.lif1(cur1, mem1)

            # Output layer
            cur2 = self.l2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spike_record.append(spk2)

        # Sum spikes across the time axis and normalise by the step count.
        spike_sum = torch.stack(spike_record).sum(dim=0)
        return spike_sum / self.num_steps

def train(model, train_loader, optimizer, criterion, device, epochs=10):
    """Train *model* for *epochs* passes over *train_loader*.

    Each batch is moved to *device*, run through the model, and the
    resulting loss is back-propagated with the gradient norm clipped to 1.0
    before the optimizer step. The mean batch loss is printed per epoch.

    Args:
        model: Module to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(images, labels)`` batches. It does
            not need to implement ``__len__``.
        optimizer: Optimizer updating the model's parameters.
        criterion: Callable ``criterion(output, labels)`` returning the loss.
        device: Device the batches are moved to.
        epochs: Number of passes over *train_loader*.

    Returns:
        List with the mean batch loss of every epoch, in order.

    Raises:
        ValueError: If *train_loader* yields no batches in an epoch.
    """
    model.train()

    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            # Move to device
            images = images.to(device)
            labels = labels.to(device)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass and loss
            output = model(images)
            loss = criterion(output, labels)

            # Backward pass with gradient clipping
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Update weights
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Counting batches avoids both len() on length-less iterables and a
        # ZeroDivisionError when the loader is empty.
        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        mean_loss = total_loss / num_batches
        epoch_losses.append(mean_loss)
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {mean_loss:.4f}')

    return epoch_losses

def test(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            output = model(images)
            _, predicted = torch.max(output, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

def report_gradient_stats(model):
    """Print mean/std/min/max of the gradient stored on every parameter of *model*.

    Parameters without a gradient are reported as such. A warning line is
    printed for gradients that are entirely zero or contain NaN/Inf, since
    those are the conditions this report exists to catch.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
    """
    print("\nGradient Statistics:")
    for name, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            print(f"{name} has no gradient")
            continue

        # The default (sample) standard deviation of a single element is NaN,
        # which would look like a broken gradient for scalar parameters such
        # as a learnable beta; report 0.0 for them instead.
        std = grad.std().item() if grad.numel() > 1 else 0.0

        print(f"{name} gradient stats:")
        print(f"  Mean: {grad.mean().item()}")
        print(f"  Std: {std}")
        print(f"  Min: {grad.min().item()}")
        print(f"  Max: {grad.max().item()}")

        if not bool(torch.isfinite(grad).all()):
            print("  WARNING: gradient contains NaN or Inf")
        elif bool((grad == 0).all()):
            print("  WARNING: gradient is all zeros")


# Main training script
def main():
    """Build the model, train it, evaluate it and print gradient statistics.

    Reads the module-level hyperparameters, data loaders and ``device``.

    Returns:
        The trained model, so callers can keep inspecting it afterwards.
    """
    # Model initialization
    model = ImprovedSpikingNeuralNetwork(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        beta=membrane_potential_decay_rate,
        spike_grad=spike_grad
    ).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): lr is hard-coded to 0.005 while the module-level
    # `learning_rate` (0.001) is unused; left unchanged to keep training
    # behaviour the same. Confirm which value is intended.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-5)

    # Training loop
    train(model, train_loader, optimizer, criterion, device, epochs)

    # Testing
    test(model, test_loader, device)

    # The gradients inspected here are the ones left on the parameters by the
    # last training batch; evaluation runs under no_grad and does not touch them.
    report_gradient_stats(model)

    return model


# The gradients reported above must be neither all-zero nor NaN.

if __name__ == "__main__":
    # Keep a module-level reference: `model` is local to main(), so the
    # follow-up inspection below would otherwise fail with a NameError.
    model = main()

    for name, param in model.named_parameters():
        if param.grad is not None:
            print(f"{name} gradient stats:")
            print(f"  Mean: {param.grad.mean().item()}")

Epoch [1/10], Loss: 1.5467
Epoch [2/10], Loss: 1.4982
Epoch [3/10], Loss: 1.4901
Epoch [4/10], Loss: 1.4865
Epoch [5/10], Loss: 1.4846
Epoch [6/10], Loss: 1.4827
Epoch [7/10], Loss: 1.4812
Epoch [8/10], Loss: 1.4812
Epoch [9/10], Loss: 1.4801
Epoch [10/10], Loss: 1.4795
Test Accuracy: 97.20%

Gradient Statistics:
l1.weight gradient stats:
  Mean: -9.24091727938503e-06
  Std: 0.0013174987398087978
  Min: -0.018957529217004776
  Max: 0.021073246374726295
l1.bias gradient stats:
  Mean: 2.7900770874111913e-05
  Std: 0.0012870851205661893
  Min: -0.007065191864967346
  Max: 0.007519489154219627
lif1.beta gradient stats:
  Mean: 0.0
  Std: nan
  Min: 0.0
  Max: 0.0
l2.weight gradient stats:
  Mean: -5.4274929425446317e-05
  Std: 0.0030238705221563578
  Min: -0.020267846062779427
  Max: 0.010348159819841385
l2.bias gradient stats:
  Mean: -0.00025546387769281864
  Std: 0.007648996543139219
  Min: -0.015035318210721016
  Max: 0.01034816075116396
lif2.beta gradient stats:
  Mean: 0.0
  Std: nan

  print(f"  Std: {param.grad.std().item()}")


NameError: name 'model' is not defined