In [1]:
import torch
import norse.torch as norse
import torch.nn as nn

# Target device for the model and data: the GPU if PyTorch can see one,
# otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

class STDP_SNN(nn.Module):
    """Two-layer spiking network with a LIF hidden layer and an STDP-style side update.

    Data path per call: ``fc1 -> ReLU -> norse.LIFCell -> fc2``. ``fc1`` and
    ``fc2`` are ordinary ``nn.Linear`` layers trained by the optimizer.
    Independently, every forward pass calls :meth:`apply_stdp`, which updates
    the ``synapse_weights`` buffer outside of autograd.

    NOTE(review): ``forward`` never reads ``synapse_weights``, so the STDP
    update currently has no influence on the network output.

    Args:
        input_size: Number of input features.
        hidden_size: Number of LIF neurons (rows of ``synapse_weights``).
        output_size: Number of output logits.
        tau_plus: Decay constant of the potentiation term in :meth:`apply_stdp`.
        tau_minus: Decay constant of the depression term in :meth:`apply_stdp`.
        learning_rate: Scale factor of the STDP update.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        tau_plus=20.0,
        tau_minus=20.0,
        learning_rate=0.005,
    ):
        super().__init__()

        self.lif1 = norse.LIFCell()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        # A registered buffer follows ``module.to(...)`` / ``.cuda()`` together
        # with the rest of the model instead of being pinned to a global device.
        # persistent=False keeps the state_dict keys exactly as they were;
        # switch to True if the STDP-learned weights should be checkpointed.
        self.register_buffer(
            "synapse_weights",
            torch.randn(hidden_size, input_size),
            persistent=False,
        )

        # STDP parameters
        self.tau_plus = tau_plus
        self.tau_minus = tau_minus
        self.learning_rate = learning_rate

    def forward(self, x, state):
        """Run one simulation time step.

        Args:
            x: Input tensor whose trailing dimension is ``input_size``.
            state: LIF state returned by the previous call, or ``None``
                (handled by ``norse.LIFCell``).

        Returns:
            ``(logits, state)``: the ``fc2`` output for this step and the new
            LIF state to feed into the next call.
        """
        x = torch.relu(self.fc1(x))
        spiked, state = self.lif1(x, state)
        # Local learning rule; runs without autograd (see apply_stdp).
        self.apply_stdp(x, spiked)
        x = self.fc2(spiked)

        return x, state

    @torch.no_grad()
    def apply_stdp(self, prespike, postsynaptic_spike):
        """Scale each row of ``synapse_weights`` in place from pre/post activity.

        With ``d = postsynaptic_spike - prespike`` averaged over all leading
        (batch) dimensions, row ``j`` is multiplied by
        ``1 + learning_rate * (exp(-|d_j| / tau_plus) - exp(-|d_j| / tau_minus))``.

        NOTE(review): ``d`` is a difference of activations, not of spike times,
        so this is not classical pair-based STDP. With ``tau_plus == tau_minus``
        (the defaults) the two exponentials cancel and the update is zero.

        Runs under ``torch.no_grad()`` so the in-place update does not attach
        an autograd graph to ``synapse_weights`` that would grow on every call.

        Args:
            prespike: Tensor of shape ``(..., hidden_size)``; ``forward`` passes
                the ReLU'd ``fc1`` output.
            postsynaptic_spike: Tensor of the same shape; ``forward`` passes the
                LIF spikes.
        """
        time_difference = postsynaptic_spike - prespike
        weight_increase = torch.exp(-torch.abs(time_difference) / self.tau_plus)
        weight_decrease = torch.exp(-torch.abs(time_difference) / self.tau_minus)
        weight_delta = self.learning_rate * (weight_increase - weight_decrease)
        # One delta per hidden neuron: collapse any batch dimensions so the
        # (hidden_size, 1) column broadcasts against (hidden_size, input_size).
        # For a batch of one this is exactly the per-neuron delta.
        hidden_size = self.synapse_weights.shape[0]
        per_neuron_delta = weight_delta.reshape(-1, hidden_size).mean(dim=0)
        self.synapse_weights += per_neuron_delta.unsqueeze(1) * self.synapse_weights


# Demo training run: one random input sample and one random class target per
# simulation time step.
time_steps = 100
input_size = 10
hidden_size = 20
output_size = 5
learning_rate = 0.005

snn = STDP_SNN(input_size, hidden_size, output_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(snn.parameters(), lr=learning_rate)

target = torch.randint(0, output_size, (time_steps, 1), device=device)

# Carry the LIF state from one time step to the next. Passing ``None`` on every
# step restarts each neuron from its initial state; in norse's LIF step the
# input only enters the synaptic current, so a single step from rest cannot
# produce a spike and fc2 would only ever see zeros (verify against the
# installed norse version).
state = None
for t in range(time_steps):
    data = torch.rand(1, input_size, device=device)
    output, state = snn(data, state)
    loss = criterion(output, target[t].view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Truncate backprop at the step boundary: the graph behind ``state`` was
    # freed by backward(), so the next step must not try to reach into it.
    # Assumes the norse state is a NamedTuple of tensors — TODO confirm.
    state = type(state)(*(s.detach() for s in state))

    if t % 10 == 0:
        print(f"Step {t}, Loss: {loss.item()}")

print("Training abgeschlossen.")


Step 0, Loss: 1.788346767425537
Step 10, Loss: 1.7531473636627197
Step 20, Loss: 1.5394386053085327
Step 30, Loss: 1.7207698822021484
Step 40, Loss: 1.5212738513946533
Step 50, Loss: 1.5142273902893066
Step 60, Loss: 1.708132266998291
Step 70, Loss: 1.653657078742981
Step 80, Loss: 1.6463713645935059
Step 90, Loss: 1.631257176399231
Training abgeschlossen.
