Schritt 1: Alle benötigten Bibliotheken importieren

In [31]:
import torch
import torch.nn as nn
from norse.torch import LIFCell, LIFParameters

Schritt 2: Device festlegen (CUDA-GPU, falls verfügbar, sonst CPU)

In [32]:
# Run on the CUDA GPU when PyTorch reports one as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Schritt 3: Implementierung des SNN und des LIF

In [48]:
import torch
import torch.nn as nn
import torch.optim as optim

# Example LIF parameters and cell
class LIFParameters:
    """Parameter container for the example leaky integrate-and-fire (LIF) cell.

    NOTE: defining this class rebinds the name ``LIFParameters`` that the
    first cell imports from ``norse.torch``; everything below this cell uses
    this local version.

    Args:
        v_rest: Resting potential.
        v_thresh: Spike threshold.
        v_reset: Potential the membrane is reset to.
        tau_mem: Membrane time constant.

    Every argument defaults to the value that used to be hard-coded, so
    ``LIFParameters()`` yields exactly the same attributes as before.
    """

    def __init__(self, v_rest=-65, v_thresh=-50, v_reset=-65, tau_mem=10):
        self.v_rest = v_rest  # resting potential
        self.v_thresh = v_thresh  # threshold
        self.v_reset = v_reset  # reset potential
        self.tau_mem = tau_mem  # membrane time constant

class LIFCell(nn.Module):
    """Stand-in for a leaky integrate-and-fire cell.

    This is a placeholder only: no membrane dynamics are simulated. The
    spikes returned by :meth:`forward` are random and depend on the input
    solely through its shape, dtype and device, never through its values.
    """

    def __init__(self, p):
        super().__init__()
        self.p = p  # parameter object; stored but not read by forward()
        self.v = None  # membrane potential; never updated by this placeholder

    def forward(self, x):
        """Return ``(spikes, state)`` for the input ``x``.

        ``spikes`` is a boolean tensor shaped like ``x`` that is ``True``
        wherever a fresh standard-normal sample is positive (random spikes,
        for demonstration). ``state`` is always ``None``.
        """
        noise = torch.randn_like(x)
        return noise.gt(0), None

class STDP_SNN(nn.Module):
    """Small spiking network: input -> ``synapse_weights`` -> LIF layer -> linear readout.

    Args:
        input_size: Number of input features per time step.
        hidden_size: Number of hidden (LIF) units.
        output_size: Number of outputs of the linear readout.
        learning_rate: Step size used by :meth:`apply_stdp`.
        tau_plus: Time constant of the potentiation window in :meth:`apply_stdp`.
        tau_minus: Time constant of the depression window in :meth:`apply_stdp`.

    Note:
        When ``lif1`` returns boolean spikes (as the ``LIFCell`` defined in
        this notebook does), no gradient flows from the loss back into
        ``synapse_weights``; they are only changed by :meth:`apply_stdp`.
    """

    def __init__(self, input_size, hidden_size, output_size, learning_rate=0.01, tau_plus=20.0, tau_minus=20.0):
        super().__init__()
        # LIF parameters and cell
        lif_params = LIFParameters()
        self.lif1 = LIFCell(p=lif_params)
        # Input -> hidden projection, shape [input_size, hidden_size]
        self.synapse_weights = nn.Parameter(torch.randn(input_size, hidden_size) * 0.01)
        self.fc2 = nn.Linear(hidden_size, output_size)

        # Learning parameters
        self.learning_rate = learning_rate
        self.tau_plus = tau_plus
        self.tau_minus = tau_minus

    # STDP learning rule
    def apply_stdp(self, prespike_time, postsynaptic_time):
        """Update ``synapse_weights`` in place with a pair-based STDP rule.

        For every (input ``i``, hidden ``j``) pair, with
        ``dt = postsynaptic_time[j] - prespike_time[i]``:

        * ``dt > 0`` (pre before post): ``+learning_rate * exp(-dt / tau_plus)``
        * ``dt < 0`` (post before pre): ``-learning_rate * exp(dt / tau_minus)``
        * ``dt == 0``: no change

        The updated weights are clamped to ``[0, 1]``. The existing
        ``nn.Parameter`` object is modified in place, so optimizers that were
        created from ``self.parameters()`` keep pointing at the live weights.

        Args:
            prespike_time: 1-D tensor of presynaptic spike times, one entry
                per input unit (length ``input_size``).
            postsynaptic_time: 1-D tensor of postsynaptic spike times, one
                entry per hidden unit (length ``hidden_size``).
        """
        pre = prespike_time[:, None]  # [input_size, 1]
        post = postsynaptic_time[None, :]  # [1, hidden_size]
        time_difference = post - pre  # [input_size, hidden_size]

        # Clamp before exp so the branch that torch.where discards cannot overflow.
        potentiation = torch.exp(-time_difference.clamp(min=0) / self.tau_plus)
        depression = torch.exp(time_difference.clamp(max=0) / self.tau_minus)
        no_change = torch.zeros_like(potentiation)
        weight_delta = self.learning_rate * (
            torch.where(time_difference > 0, potentiation, no_change)
            - torch.where(time_difference < 0, depression, no_change)
        )

        # STDP is not a gradient step: keep it out of autograd and keep the
        # Parameter identity intact.
        with torch.no_grad():
            self.synapse_weights.add_(weight_delta.to(self.synapse_weights))
            self.synapse_weights.clamp_(min=0.0, max=1.0)

    # Input normalization
    def normalize_input(self, data):
        """Standardize ``data`` over all of its elements.

        Subtracts the global mean and divides by the global standard
        deviation; when the standard deviation is not positive the data is
        only centered.
        """
        centered = data - data.mean()
        std = data.std()
        return centered / std if std > 0 else centered

    # Forward pass
    def forward(self, x):
        """Compute readout values for a batch of inputs.

        The normalized input is projected through ``synapse_weights``, passed
        to ``lif1``, and the resulting spikes are fed to ``fc2``. For
        time-resolved input the spikes are averaged over the time dimension
        before the readout.

        Args:
            x: Either ``[batch, time, input_size]`` or a tensor whose
                non-batch dimensions flatten to exactly ``input_size``.

        Returns:
            Tensor of shape ``[batch, output_size]``.

        Raises:
            ValueError: If ``x`` cannot be interpreted as having
                ``input_size`` features.
        """
        x = self.normalize_input(x)

        input_size = self.synapse_weights.shape[0]
        is_time_resolved = x.dim() == 3 and x.size(-1) == input_size
        if not is_time_resolved:
            x = x.reshape(x.size(0), -1)  # flatten everything but the batch dim
            if x.size(1) != input_size:
                raise ValueError(
                    f"expected {input_size} input features per sample, got {x.size(1)}"
                )

        # Input -> hidden currents; works for [B, F] and [B, T, F] alike.
        currents = x.to(self.synapse_weights.dtype) @ self.synapse_weights
        spikes, _ = self.lif1(currents)  # keep only the spikes, drop the state

        # Convert spikes to float32 for the linear layer
        rates = spikes.float()
        if rates.dim() == 3:
            rates = rates.mean(dim=1)  # spike rate over the time dimension

        return self.fc2(rates)



Test: MNIST-Datensatz laden und Bilder in Spike-Folgen umwandeln

In [88]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Preprocessing for MNIST: convert each image to a tensor, then normalize
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST train and test splits, stored under ./data (download=True).
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Batches of 64; only the training batches are shuffled.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

def image_to_spikes(image, time_steps=100):
    """Convert a batch of flattened images to a spike representation.

    At step ``t`` (``0 <= t < time_steps``) a pixel emits a spike when its
    value is strictly greater than ``t / time_steps``. Consequently pixels
    with value ``<= 0`` never spike and pixels with value ``>= 1`` spike at
    every step.

    Args:
        image: Tensor of size ``[batch_size, features]`` with pixel values.
        time_steps: Number of time steps.

    Returns:
        float32 tensor of shape ``[batch_size, time_steps, features]`` holding
        0.0 / 1.0, located on the same device as ``image``.

    Raises:
        ValueError: If ``image`` is not 2-D.
    """
    if image.dim() != 2:
        raise ValueError(
            f"image must have shape [batch_size, features], got {tuple(image.shape)}"
        )

    # Thresholds 0/T, 1/T, ..., (T-1)/T. They are computed in float64 on the
    # host (same values as Python's `t / time_steps`) and then moved next to
    # the input so the comparison runs on the input's device.
    thresholds = torch.arange(time_steps, dtype=torch.float64) / time_steps
    compare_dtype = image.dtype if image.is_floating_point() else torch.get_default_dtype()
    thresholds = thresholds.to(device=image.device, dtype=compare_dtype)

    # Broadcast [B, 1, F] against [1, T, 1] -> [B, T, F]; replaces the
    # per-time-step Python loop.
    return (image.unsqueeze(1) > thresholds.view(1, -1, 1)).float()


In [89]:
class SimpleCNN(nn.Module):
    """Minimal CNN: 3x3 conv (1 -> 32 channels) -> ReLU -> 2x2 max-pool -> linear (10 outputs).

    The linear layer is sized for ``[batch, 1, 28, 28]`` inputs.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel (grayscale), 32 output channels, 3x3 kernel
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)  # max-pooling with 2x2 kernel and stride 2
        # 28x28 -> 3x3 conv without padding -> 26x26 -> 2x2 pool -> 13x13,
        # with 32 channels; 10 output classes.
        self.fc1 = nn.Linear(32 * 13 * 13, 10)

    def forward(self, x):
        """Return ``[batch, 10]`` outputs for ``x`` of shape ``[batch, 1, 28, 28]``."""
        # `nn.functional` is used directly because no `F` alias is imported
        # in this notebook.
        x = self.pool(nn.functional.relu(self.conv1(x)))
        # Flatten per sample. Unlike `view(-1, 32 * 13 * 13)` this never
        # changes the batch dimension; a wrong input size fails in fc1.
        x = x.flatten(1)
        return self.fc1(x)

In [103]:
# Hyperparameters
input_size = 28 * 28
hidden_size = 100
output_size = 10
learning_rate = 0.001
epochs = 5
time_steps = 100
seed = 42

# Seed PyTorch's global RNG so that weight initialization and any later
# draws from it are repeatable across Restart & Run All.
torch.manual_seed(seed)

# Both models are placed on the same device.
snn_model = STDP_SNN(input_size, hidden_size, output_size, learning_rate).to(device)
cnn_model = SimpleCNN().to(device)
snn_optimizer = torch.optim.Adam(snn_model.parameters(), lr=learning_rate)
cnn_optimizer = torch.optim.Adam(cnn_model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
# Collapse the time axis of a spike tensor by averaging
def average_spikes(spikes, time_steps):
    """Return the mean of ``spikes`` along dimension 1.

    For a ``[batch, time, features]`` input the result is
    ``[batch, features]``. ``time_steps`` is accepted so existing calls keep
    working, but it is not used.
    """
    return torch.mean(spikes, dim=1)

# Main training loop
# The batches come from the DataLoader on the CPU, while a model may live on
# another device. Look up where each model's parameters are and move its
# inputs and labels there.
snn_device = next(snn_model.parameters()).device
cnn_device = next(cnn_model.parameters()).device

for epoch in range(epochs):
    snn_model.train()
    cnn_model.train()
    running_loss = 0.0
    running_loss_2 = 0.0

    for images, labels in train_loader:
        images = images.view(images.size(0), -1)  # flatten the images (28x28 -> 784)

        # Generate spike trains using the image-to-spike conversion function
        spikes = image_to_spikes(images, time_steps)

        # Average spikes over the time dimension -> [batch_size, 784]
        averaged_spikes = average_spikes(spikes, time_steps)

        # Reshape the averaged spikes to [batch_size, 1, 28, 28] for the CNN
        spikes_for_cnn = averaged_spikes.view(averaged_spikes.size(0), 1, 28, 28)

        # Forward passes, each on its model's device
        outputs = snn_model(spikes.to(snn_device))
        outputs_2 = cnn_model(spikes_for_cnn.to(cnn_device))

        # Losses; the labels must sit on the same device as the outputs
        loss = loss_fn(outputs, labels.to(snn_device))
        loss_2 = loss_fn(outputs_2, labels.to(cnn_device))

        # Backpropagation
        snn_optimizer.zero_grad()
        cnn_optimizer.zero_grad()
        loss.backward()
        loss_2.backward()
        snn_optimizer.step()
        cnn_optimizer.step()

        # Accumulate loss
        running_loss += loss.item()
        running_loss_2 += loss_2.item()

    # Print the average loss for the epoch
    print(f'Epoch [{epoch+1}/{epochs}], SNN Loss: {running_loss/len(train_loader):.4f}')
    print(f'Epoch [{epoch+1}/{epochs}], CNN Loss: {running_loss_2/len(train_loader):.4f}')


# Evaluation
snn_model.eval()
cnn_model.eval()
correct = 0
correct_2 = 0  # for the CNN
total = 0

# The test batches come from the DataLoader on the CPU; move each model's
# inputs to wherever that model's parameters live.
snn_device = next(snn_model.parameters()).device
cnn_device = next(cnn_model.parameters()).device

with torch.no_grad():
    for images, labels in test_loader:
        images = images.view(images.size(0), -1)  # flatten the images (28x28 -> 784)

        # Generate spike data
        spikes = image_to_spikes(images, time_steps)

        # Average the spikes over the time dimension (for the CNN) -> [batch_size, 784]
        averaged_spikes = average_spikes(spikes, time_steps)

        # Reshape the spike data for the CNN (to [batch_size, 1, 28, 28])
        spikes_for_cnn = averaged_spikes.view(averaged_spikes.size(0), 1, 28, 28)

        # Run the models, each on its own device
        outputs = snn_model(spikes.to(snn_device))
        outputs_2 = cnn_model(spikes_for_cnn.to(cnn_device))

        # Predicted class = index of the largest output
        _, predicted = torch.max(outputs, 1)
        _, predicted_2 = torch.max(outputs_2, 1)

        # Compare on the predictions' device so the tensors match
        total += labels.size(0)
        correct += (predicted == labels.to(predicted.device)).sum().item()
        correct_2 += (predicted_2 == labels.to(predicted_2.device)).sum().item()

# Print the overall accuracy
print(f'Accuracy (SNN): {100 * correct / total:.2f}%')
print(f'Accuracy (CNN): {100 * correct_2 / total:.2f}%')


Epoch [1/5], SNN Loss: 2.3243
Epoch [1/5], CNN Loss: 0.2888
Epoch [2/5], SNN Loss: 2.3100
Epoch [2/5], CNN Loss: 0.1011
Epoch [3/5], SNN Loss: 2.3087
Epoch [3/5], CNN Loss: 0.0727
Epoch [4/5], SNN Loss: 2.3084
Epoch [4/5], CNN Loss: 0.0582
Epoch [5/5], SNN Loss: 2.3074
Epoch [5/5], CNN Loss: 0.0488
Accuracy (SNN): 10.83%
Accuracy (CNN): 97.91%
