In [1]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class EEGOnsetLSTM(nn.Module):
    def __init__(self, n_channels, hidden_size=64, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size=n_channels,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            batch_first=True,
                            bidirectional=True)
        # Bidirektional → hidden_size * 2
        self.fc = nn.Linear(hidden_size * 2, 1)  # Output: 1 Wert pro Sample (Onset)

    def forward(self, x, lengths):
        # x: [batch, seq_len, n_channels]
        # lengths: [batch] echte Längen ohne Padding

        # Sortiere nach Länge absteigend (notwendig für pack_padded_sequence)
        lengths_sorted, sorted_idx = lengths.sort(descending=True)
        x_sorted = x[sorted_idx]

        # Packe die Sequenzen
        packed_input = pack_padded_sequence(x_sorted, lengths_sorted.cpu(), batch_first=True)

        # LSTM vorwärts
        packed_output, (hn, cn) = self.lstm(packed_input)

        # hn: [num_layers * num_directions, batch, hidden_size]
        # Für bidirektionales LSTM: 2 Richtungen → wir konkateniere die letzten Layerstates

        # Wir nehmen die letzten Layerstates beider Richtungen
        # Layerindex: -1 (letzte Schicht)
        # hn shape: [2, batch, hidden_size]
        # Wir transponieren und flatten zu [batch, hidden_size*2]
        hn = hn.view(self.lstm.num_layers, 2, x.size(0), self.lstm.hidden_size)
        hn_last_layer = hn[-1]  # Form: [2, batch, hidden_size]
        hn_cat = torch.cat((hn_last_layer[0], hn_last_layer[1]), dim=1)  # [batch, hidden_size*2]

        # Rücksortieren, um ursprüngliche Reihenfolge wiederherzustellen
        _, original_idx = sorted_idx.sort()
        hn_cat = hn_cat[original_idx]

        # Fully connected zum Onset (regression)
        output = self.fc(hn_cat).squeeze(1)  # [batch]

        return output

# Beispiel Daten und Labels (Dummy)
batch_size = 3
n_channels = 21
seq_lengths = torch.tensor([1000, 800, 600])  # variable Längen
max_len = seq_lengths.max()

# Zufällige Daten mit Padding (Nullen)
x = torch.zeros(batch_size, max_len, n_channels)
for i, length in enumerate(seq_lengths):
    x[i, :length] = torch.randn(length, n_channels)

# Beispiel-Onsets als Indexwerte (ground truth)
labels = torch.tensor([400, 350, 200], dtype=torch.float32)

# Modell, Optimizer, Loss
model = EEGOnsetLSTM(n_channels)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Trainingsschritt (ein Beispiel)
model.train()
optimizer.zero_grad()
outputs = model(x, seq_lengths)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

print(f"Predicted Onsets: {outputs.detach().cpu().numpy()}")
print(f"True Onsets: {labels.cpu().numpy()}")
print(f"Loss: {loss.item():.4f}")

Predicted Onsets: [0.00711361 0.02607869 0.02207232]
True Onsets: [400. 350. 200.]
Loss: 107489.0703
