## Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
import numpy as np
import h5py

# Load data: as in the CNN model, read the LFP segments from the HDF5 file, separate them by sleep state (NREM = 0, WAKE = 1), and zero-pad and stack them into arrays for building a dataset

In [4]:
# Path to the HDF5 recording for this subject.
h5_path = 'Part1SubjectHB10.h5'

# Read everything into memory inside a context manager so the HDF5 file handle
# is closed once loading finishes (it was previously left open for the whole
# kernel session). All values pulled out below are in-memory numpy objects,
# so they remain valid after the file is closed.
with h5py.File(h5_path, 'r') as f:
    fs = f.attrs['fs'][0]
    print("Sampling rate: %.1f Hz" % (fs))

    # Each top-level group is a state (e.g. NREM / WAKE); its members are segments.
    states = []
    for name, grp in f.items():
        states.append(name)
        print("State: %s" % (name))
        print("Segment IDs:", list(grp.keys()))

    # Segment keys are "1".."n"; iterate numerically so segments load in numeric
    # order rather than the lexicographic order h5py lists them in ("1", "10", ...).
    lfp = {key: [] for key in states}
    for key in states:
        group = f[key]
        for i in range(len(group)):
            lfp[key].append(group[str(i + 1)][()].astype(float))

# Label encoding: NREM -> 0, WAKE -> 1.
all_signals = lfp['NREM'] + lfp['WAKE']
all_labels = [0] * len(lfp['NREM']) + [1] * len(lfp['WAKE'])

# Zero-pad every segment at the END up to the longest segment so they can be
# stacked (the longest segment gets zero padding, i.e. is unchanged).
# NOTE(review): because padding is appended at the end, a model that reads only
# the final timestep sees padding for every shorter segment.
max_length = max(signal.shape[0] for signal in all_signals)
padded_signals = [np.pad(signal, (0, max_length - signal.shape[0]), mode='constant') for signal in all_signals]

# Stack the padded signals
signals = np.stack(padded_signals)
labels = np.array(all_labels)

print("Signals shape:", signals.shape)
print("Labels shape:", labels.shape)


Sampling rate: 1000.0 Hz
State: NREM
Segment IDs: ['1', '10', '11', '12', '13', '14', '15', '16', '17', '2', '3', '4', '5', '6', '7', '8', '9']
State: WAKE
Segment IDs: ['1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30', '31', '32', '33', '34', '35', '36', '37', '38', '4', '5', '6', '7', '8', '9']
Signals shape: (55, 285000)
Labels shape: (55,)


In [5]:
class H5SignalDataset(Dataset):
    """In-memory dataset pairing each signal with its integer class label.

    Signals are converted to a float32 tensor. A 2-D input, assumed to be
    (N, Length), gets a trailing feature axis appended, giving (N, Length, 1) for
    a batch_first LSTM. Inputs of any other rank are stored unchanged. Labels are
    stored as int64 (torch.long) class indices, as CrossEntropyLoss expects.
    """

    def __init__(self, signals, labels):
        as_tensor = torch.tensor(signals, dtype=torch.float32)
        # Give each timestep an explicit single-feature axis when one is missing.
        self.data = as_tensor.unsqueeze(-1) if as_tensor.dim() == 2 else as_tensor
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

full_dataset = H5SignalDataset(signals, labels)

# 80/20 train/validation split.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
# Seed the split so Restart & Run All reproduces the same partition;
# random_split otherwise draws from the unseeded global RNG.
# TODO: with only 55 segments (17 NREM / 38 WAKE), consider a stratified split so
# the small validation set keeps the class ratio.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=split_generator)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [None]:
class SignalLSTM(nn.Module):
    """LSTM sequence classifier over a batch_first signal tensor.

    Args:
        input_dim: Number of features per timestep (1 for a single raw channel).
        hidden_dim: LSTM hidden size.
        output_dim: Number of output classes (logits produced by the linear head).
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_dim=1, hidden_dim=64, output_dim=2, num_layers=2):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, lengths=None):
        """Return class logits of shape (batch, output_dim).

        Args:
            x: Float tensor of shape (batch, time, input_dim).
            lengths: Optional 1-D tensor/sequence giving the number of valid
                (unpadded) timesteps in each sequence. When given, the LSTM output at
                each sequence's last valid timestep is classified. This matters when
                sequences are zero-padded at the end. When None, the final timestep
                is used for every sequence, as before.
        """
        lstm_out, _ = self.lstm(x)
        if lengths is None:
            last_step = lstm_out[:, -1, :]
        else:
            idx = torch.as_tensor(lengths, dtype=torch.long, device=lstm_out.device) - 1
            # Clamp to a valid time index so bad lengths cannot index out of range.
            idx = idx.clamp(0, lstm_out.size(1) - 1)
            batch_idx = torch.arange(lstm_out.size(0), device=lstm_out.device)
            last_step = lstm_out[batch_idx, idx]
        return self.fc(last_step)  # output of the classification

# Fix the global torch seed so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(42)

# Use a GPU when available: a 2-layer LSTM stepping through every sample of a
# long raw signal is very slow on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SignalLSTM().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 5
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []


def _run_epoch(model, loader, criterion, optimizer=None, device="cpu"):
    """Run one pass over `loader` and return (mean per-sample loss, accuracy).

    If `optimizer` is given, the model is put in train mode and its weights are
    updated. Otherwise it runs in eval mode with gradients disabled. Returns
    (nan, nan) for an empty loader instead of dividing by zero.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.set_grad_enabled(training):
        # Distinct batch names so the module-level `signals` / `labels` arrays
        # built by the data-loading cell are not overwritten.
        for batch_signals, batch_labels in loader:
            batch_signals = batch_signals.to(device)
            batch_labels = batch_labels.to(device)
            if training:
                optimizer.zero_grad()
            # NOTE(review): sequences are zero-padded at the end and no lengths are
            # passed, so the model classifies the final (possibly padded) timestep.
            outputs = model(batch_signals)
            loss = criterion(outputs, batch_labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = batch_labels.size(0)
            # CrossEntropyLoss returns the batch mean; weight by batch size so a short
            # final batch does not skew the epoch average.
            loss_sum += loss.item() * batch_size
            n_correct += (outputs.argmax(dim=1) == batch_labels).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        return float("nan"), float("nan")
    return loss_sum / n_seen, n_correct / n_seen


for epoch in range(epochs):
    train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = _run_epoch(model, val_loader, criterion, None, device)

    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

Epoch 1/5, Train Loss: 0.6973, Val Loss: 0.6953, Train Acc: 0.3864, Val Acc: 0.4545


# Conclusion:

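An LSTM (long short-term memory network) is a type of recurrent neural network (RNN) that uses gates (input, forget, and output) to control what information is kept in its hidden state. This LSTM model processes the raw signal one timestep at a time and uses the hidden representation at the final timestep to classify the sleep state (NREM vs. WAKE). Training is similar to the earlier CNN model. On the forward pass, the model computes predictions for each batch. On the backward pass, it updates the weights based on the loss. Over the epochs this minimizes the training loss, so the model learns to predict the state from the raw signal. Note, however, that only the first epoch's output is shown here. Its accuracy (0.39 train, 0.45 validation) is below the majority-class baseline of about 0.69 (38 of the 55 segments are WAKE), so this run does not yet show that the model has learned to classify the signals.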