# Implementierung eines SNN und CNN: Ein Vergleich mit dem MNIST-Datensatz


*Author: Ümmühan Ay*

Schritt 1: Alle nötigen Imports importieren

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import jit
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

Schritt 2: Cuda und GPU als device festlegen

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import torch
print("PyTorch-Version:", torch.__version__)
print("CUDA verfügbar:", torch.cuda.is_available())
print("CUDA-Version in PyTorch:", torch.version.cuda)
print("Gefundene GPUs:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("GPU-Name:", torch.cuda.get_device_name(0))


PyTorch-Version: 2.6.0+cpu
CUDA verfügbar: False
CUDA-Version in PyTorch: None
Gefundene GPUs: 0


Schritt 3: Implementierung des SNN und des LIF

In [3]:
class SNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, learning_rate=0.01, tau_plus=20.0, tau_minus=20.0, tau_mem=10.0, v_rest=-65.0, v_thresh=-50.0, v_reset=-65.0):
        super(SNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.synapse_weights = nn.Parameter(torch.randn(input_size, hidden_size) * 0.01)
        self.fc2 = nn.Linear(hidden_size, output_size)

        self.learning_rate = learning_rate
        self.tau_plus = tau_plus
        self.tau_minus = tau_minus

        self.tau_mem = tau_mem  # Membranzeitkonstante
        self.v_rest = v_rest  # Ruhepotenzial
        self.v_thresh = v_thresh  # Schwellenwert
        self.v_reset = v_reset  # Reset-Potenzial

        # Membranpotentiale initialisieren
        self.v_hidden = torch.full((hidden_size,), self.v_rest, device="cuda")

        # STDP Zustände (Spuren für pre- und post-Spikes)
        self.pre_trace = torch.zeros(input_size, device="cuda")
        self.post_trace = torch.zeros(hidden_size, device="cuda")

    def forward(self, x, dt=1e-3):
        batch_size, time_steps, _ = x.shape  
        dt_tensor = torch.tensor(dt, device=x.device)
        spikes_out = torch.zeros(batch_size, time_steps, self.hidden_size, device=x.device)

        for t in range(time_steps):
            z_pre = x[:, t, :]
            # Update des Membranpotentials mit Leck (LIF-Modell)
            self.v_hidden = self.v_hidden * torch.exp(-dt_tensor / self.tau_mem) + torch.matmul(z_pre, self.synapse_weights)
            # Spiking: Neuron feuert, wenn Membranpotential den Schwellenwert überschreitet
            z_post = (self.v_hidden > self.v_thresh).float()
            # Reset des Membranpotentials nach dem Spike
            self.v_hidden = torch.where(z_post > 0, self.v_reset, self.v_hidden)
            spikes_out[:, t, :] = z_post
            self.stdp_update(z_pre, z_post, dt)

        output = self.fc2(spikes_out.mean(dim=1))  
        return output

    def stdp_update(self, z_pre, z_post, dt):
        self.pre_trace = (1 - dt / self.tau_plus) * self.pre_trace + z_pre.mean(dim=0)
        self.post_trace = (1 - dt / self.tau_minus) * self.post_trace + z_post.mean(dim=0)
        hebbian = torch.ger(self.post_trace, self.pre_trace)  # (hidden_size, input_size)
        anti_hebbian = torch.ger(self.pre_trace, self.post_trace)  # (input_size, hidden_size)
        # STDP-Gewichtsänderung
        dw = self.learning_rate * (hebbian - anti_hebbian.T)  
        # Gewichte aktualisieren und begrenzen
        self.synapse_weights.data += dw.T  # Transponieren, damit es (784, 1000) hat


Schritt 4: MNIST-Datensatz laden

In [4]:
batch_size = 32
time_steps = 50

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

def image_to_spikes(images, time_steps):
    batch_size, num_inputs = images.shape
    
    images = images / 255.0
    random_values = torch.rand(batch_size, time_steps, num_inputs, device=images.device)

    spikes = (random_values < images.unsqueeze(1)).float()
    
    return spikes


Schritt 5: CNN erstellen

In [5]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # 1 input channel (grayscale), 32 output channels, 3x3 kernel
        self.conv1 = nn.Conv2d(1, 32, 3)  
        self.pool = nn.MaxPool2d(2, 2)
        # Fully connected layer with 32*13*13 input features (after conv + pool) and 10 output classes
        self.fc1 = nn.Linear(32 * 13 * 13, 10)  

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # Convolution + ReLU + MaxPooling
        # Flatten the output of the conv layer for the fully connected layer
        x = x.view(-1, 32 * 13 * 13)
        x = self.fc1(x)
        return x

6. Hyperparameter, Loss-Funktion und Optimizer definieren

In [6]:
# Hyperparameter
input_size = 28 * 28
hidden_size = 1000
output_size = 10
learning_rate = 0.001
epochs = 5
time_steps = 50
p = 0.5

device = torch.device('cuda')
snn_model = SNN(input_size=28*28, hidden_size=1000, output_size=10).to(device)
cnn_model = SimpleCNN().to(device)
snn_optimizer = torch.optim.Adam(snn_model.parameters(), lr=0.001)
cnn_optimizer = torch.optim.Adam(cnn_model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
# Function to average spikes over time
def average_spikes(spikes, time_steps):
    return spikes.mean(dim=1)

AssertionError: Torch not compiled with CUDA enabled

7. Training SNN

In [199]:
for epoch in range(epochs):
    snn_model.train()
    running_loss = 0.0
    
    for images, labels in train_loader:
        images = images.view(images.size(0), -1).to(device)  # Flatten (28x28 → 784)
        spikes = image_to_spikes(images, time_steps)  # Spikes generieren

        outputs = snn_model(spikes)
        loss = loss_fn(outputs, labels.to(device))

        snn_optimizer.zero_grad()
        loss.backward()
        snn_optimizer.step()

        running_loss += loss.item()
    
    print(f'Epoch [{epoch+1}/{epochs}], SNN Loss: {running_loss/len(train_loader):.4f}')

KeyboardInterrupt: 

8. CNN trainieren

In [8]:
# Training CNN
for epoch in range(epochs):
    cnn_model.train()
    running_loss = 0.0
    
    for images, labels in train_loader:
        images = images.view(images.size(0), -1)
        spikes = image_to_spikes(images, time_steps)
        averaged_spikes = average_spikes(spikes, time_steps)
        spikes_for_cnn = averaged_spikes.view(averaged_spikes.size(0), 1, 28, 28)
        
        outputs = cnn_model(spikes_for_cnn)
        
        loss = loss_fn(outputs, labels)
        cnn_optimizer.zero_grad()
        loss.backward()
        cnn_optimizer.step()
        
        running_loss += loss.item()
    
    print(f'Epoch [{epoch+1}/{epochs}], CNN Loss: {running_loss/len(train_loader):.4f}')

Epoch [1/5], CNN Loss: 0.2893
Epoch [2/5], CNN Loss: 0.1030
Epoch [3/5], CNN Loss: 0.0728
Epoch [4/5], CNN Loss: 0.0588
Epoch [5/5], CNN Loss: 0.0487


9. SNN Evaluieren

In [177]:
# Testing SNN
snn_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.view(images.size(0), -1)
        spikes = image_to_spikes(images, time_steps)
        outputs = snn_model(spikes)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy (SNN): {100 * correct / total:.2f}%')

Accuracy (SNN): 58.75%


10. CNN Evaluieren

In [10]:
# Testing CNN
cnn_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.view(images.size(0), -1)
        spikes = image_to_spikes(images, time_steps)
        averaged_spikes = average_spikes(spikes, time_steps)
        spikes_for_cnn = averaged_spikes.view(averaged_spikes.size(0), 1, 28, 28)
        outputs = cnn_model(spikes_for_cnn)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy (CNN): {100 * correct / total:.2f}%')

Accuracy (CNN): 97.76%
