Schritt 1: Alle benötigten Module importieren

In [None]:
import torch
import torch.nn as nn
from norse.torch import LIFCell, LIFParameters, LIFState

In [12]:
# Select the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Schritt 2: Implementierung des SNN mit LIF-Neuronen (Leaky Integrate-and-Fire)

In [None]:
class STDP_SNN(nn.Module):
    """Spiking network with STDP-trained input synapses, a LIF hidden layer and
    a gradient-trained linear readout.

    Data flow in :meth:`forward`: standardise the input, project it through
    ``synapse_weights`` into hidden-layer currents, simulate the LIF neurons
    for ``num_steps`` steps, and feed their firing rates to ``fc2``.

    Args:
        input_size: Number of input features.
        hidden_size: Number of LIF neurons in the hidden layer.
        output_size: Number of readout outputs (logits).
        learning_rate: Amplitude of the STDP update in :meth:`apply_stdp` and
            learning rate of the Adam optimiser in :meth:`apply_opt_function`.
        tau_plus: Time constant of the potentiation window (pre before post).
        tau_minus: Time constant of the depression window (post before pre).
        num_steps: Number of LIF simulation steps per :meth:`forward` call.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """

    def __init__(self, input_size, hidden_size, output_size, learning_rate=0.01,
                 tau_plus=20.0, tau_minus=20.0, num_steps=10):
        super().__init__()
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        # LIF neuron parameters (library defaults; normally kept constant).
        lif_params = LIFParameters()
        self.lif1 = LIFCell(p=lif_params)

        # Input -> hidden synapses, shape (input_size, hidden_size), Xavier-normal
        # initialised. Registered as a buffer: it follows .to(device) and is
        # saved in state_dict(), but is not part of parameters(), so only STDP
        # (not the gradient optimiser) changes it.
        synapse_weights = torch.empty(input_size, hidden_size)
        nn.init.xavier_normal_(synapse_weights)
        self.register_buffer("synapse_weights", synapse_weights)

        self.fc2 = nn.Linear(hidden_size, output_size)
        self.learning_rate = learning_rate
        self.tau_plus = tau_plus
        self.tau_minus = tau_minus
        self.num_steps = num_steps
        # Created lazily by apply_opt_function and reused afterwards.
        self.optimizer = None

        # Activation applied to the hidden firing rates. Rates are >= 0, so ReLU
        # leaves them unchanged; kept so the activation can be swapped out.
        self.hidden_activation = nn.ReLU()

    def apply_stdp(self, prespike_time, postsynaptic_time):
        """Update ``synapse_weights`` with a pair-based STDP rule.

        For every synapse, ``dt = post - pre``:
          * ``dt > 0`` (pre fired before post): potentiation by
            ``learning_rate * exp(-dt / tau_plus)``;
          * ``dt < 0`` (post fired before pre): depression by
            ``learning_rate * exp(dt / tau_minus)``;
          * ``dt == 0``: no change.

        Args:
            prespike_time: Presynaptic spike times. Either shape
                ``(input_size,)`` together with a ``postsynaptic_time`` of shape
                ``(hidden_size,)`` (then every input/hidden pair is compared),
                or any shapes whose difference ``post - pre`` has exactly
                ``input_size * hidden_size`` elements (reshaped row-major to the
                weight matrix).
            postsynaptic_time: Postsynaptic spike times.

        Raises:
            RuntimeError: If the time differences cannot be matched to the
                weight matrix shape.
        """
        weights = self.synapse_weights
        n_in, n_hidden = weights.shape
        pre = torch.as_tensor(prespike_time, dtype=weights.dtype, device=weights.device)
        post = torch.as_tensor(postsynaptic_time, dtype=weights.dtype, device=weights.device)

        if pre.shape == (n_in,) and post.shape == (n_hidden,):
            # All-pairs differences: entry [i, j] = post[j] - pre[i].
            time_difference = post.unsqueeze(0) - pre.unsqueeze(1)
        else:
            time_difference = (post - pre).reshape(weights.shape)

        with torch.no_grad():
            # exp(-|dt| / tau) equals the causal/anti-causal window in its own
            # branch and never overflows for large |dt|.
            potentiation = torch.exp(-torch.abs(time_difference) / self.tau_plus)
            depression = torch.exp(-torch.abs(time_difference) / self.tau_minus)
            window = torch.where(
                time_difference > 0,
                potentiation,
                torch.where(time_difference < 0, -depression, torch.zeros_like(time_difference)),
            )
            # Out-of-place update so tensors saved by autograd from an earlier
            # forward pass are not modified in place.
            self.synapse_weights = weights + self.learning_rate * window

    def apply_loss_function(self, y_true, y_pred):
        """Return the mean binary cross-entropy between targets and logits.

        Args:
            y_true: Float target probabilities in [0, 1], same shape as y_pred.
            y_pred: Raw network outputs (logits), e.g. from :meth:`forward`.

        Returns:
            Scalar loss tensor.
        """
        # Fused sigmoid + BCE: mathematically equal to BCELoss(sigmoid(y_pred)),
        # but numerically stable when the sigmoid saturates.
        return nn.functional.binary_cross_entropy_with_logits(y_pred, y_true)

    def apply_opt_function(self, loss):
        """Backpropagate ``loss`` and take one Adam step.

        The optimiser is created on first use and reused afterwards, so Adam's
        moment estimates persist between steps. It covers ``parameters()``
        (``fc2`` and any parameters registered by ``lif1``); ``synapse_weights``
        is a buffer and is only changed by :meth:`apply_stdp`.
        """
        if getattr(self, "optimizer", None) is None:
            self.optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def normalize_input(self, data, eps=1e-8):
        """Standardise ``data`` to zero mean and unit (sample) standard deviation.

        Statistics are computed over all elements of ``data``. ``eps`` avoids a
        division by zero for constant input; inputs with fewer than two
        elements have no sample standard deviation and are only mean-centred.
        """
        mean = data.mean()
        if data.numel() < 2:
            return data - mean
        return (data - mean) / (data.std() + eps)

    def forward(self, x):
        """Compute readout logits for ``x``.

        Args:
            x: Input features whose last dimension is ``input_size``
                (presumably (batch, input_size) -- TODO confirm with callers).

        Returns:
            Logits with shape ``(*x.shape[:-1], output_size)``.
        """
        x = self.normalize_input(x)
        # Hidden-layer input currents through the STDP-trained synapses.
        input_current = x @ self.synapse_weights

        # Drive the LIF layer with a constant current for num_steps steps.
        # LIFCell returns (spikes, new_state); state=None lets it create its
        # initial state. NOTE(review): with norse's LIFCell the input reaches
        # the membrane one step later, so a single step from rest emits no
        # spikes -- verify against the installed norse version.
        state = None
        spike_count = torch.zeros_like(input_current)
        for _ in range(self.num_steps):
            spikes, state = self.lif1(input_current, state)
            spike_count = spike_count + spikes

        firing_rate = spike_count / self.num_steps
        hidden_output = self.hidden_activation(firing_rate)
        return self.fc2(hidden_output)
