In [1]:
import numpy as np

data = np.load('A01T.npz')  
signal = data['s']  # Assuming shape is (total_samples, 22)
event_types = data['etyp'].T[0]
event_positions = data['epos'].T[0]

# Initialize empty arrays
signals = []
trial_types = []
valid_labels = {769, 770, 771, 772}

for i in range(0, len(event_positions) - 1):
    event_type = event_types[i]
    next_event_type = event_types[i + 1]
    
    if event_type == 768 and next_event_type in valid_labels:  # Valid trial start
        pos = event_positions[i+1]
        
        # Extract 750 samples x 22 channels
        trial_signal = signal[pos+750 : pos+1500, 0:22]  # All 22 channels
        
        # Verify the shape is correct
        if trial_signal.shape != (750, 22):
            print(f"Unexpected shape at trial {len(signals)}: {trial_signal.shape}")
            continue
            
        signals.append(trial_signal)
        trial_types.append(next_event_type)

# Convert to numpy arrays
signals_array = np.array(signals)  # Shape (288, 750, 22)
labels_array = np.array(trial_types)  # Shape (288,)

# Verify final shapes
print("Signals shape:", signals_array.shape)
print("Labels shape:", labels_array.shape)
print("Unique labels:", np.unique(labels_array))

Signals shape: (273, 750, 22)
Labels shape: (273,)
Unique labels: [769 770 771 772]


In [2]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch.nn as nn 
import torch.optim as optim 

class EEGDataset(Dataset):
    def __init__(self, trials, labels):
        # Trials: (num_trials, 750, 22)
        # Labels: (num_trials,)
        
        # Normalize data per channel
        self.data = torch.tensor(trials, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        
        # Map original labels to 0-3
        self.label_mapping = {769: 0, 770: 1, 771: 2, 772: 3}
        self.labels = torch.tensor([self.label_mapping[int(x)] for x in labels])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Assuming you have loaded your data into trials and labels arrays
trials = np.load('eeg_signals.npy')  # Shape (288, 750, 22)
labels = np.load('eeg_labels.npy')     # Shape (288,)

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    trials, labels, test_size=0.2, stratify=labels, random_state=42
)

# Create datasets and dataloaders
train_dataset = EEGDataset(X_train, y_train)
test_dataset = EEGDataset(X_test, y_test)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [5]:
import torch
import torch.nn as nn

class LTC_Cell(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(LTC_Cell, self).__init__()
        self.hidden_dim = hidden_dim
        self.W_xh = nn.Linear(input_dim, hidden_dim)
        self.W_hh = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_tau = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, h):
        # Compute the time constant tau using a sigmoid and a small constant offset
        tau = torch.sigmoid(self.W_tau(h)) + 0.1
        # Compute the state derivative
        dh = -h / tau + torch.tanh(self.W_xh(x) + self.W_hh(h))
        # Update the hidden state with a fixed time step of 0.1
        return h + 0.1 * dh

class SimpleHybridLTC_LSTM(nn.Module):
    def __init__(self, input_dim=22, cnn_dim=32, lstm_hidden_dim=64,
                 ltc_hidden_dim=64, num_classes=4, num_ltc_layers=1):
        super(SimpleHybridLTC_LSTM, self).__init__()
        
        # CNN block: processes the raw input
        self.cnn = nn.Sequential(
            nn.Conv1d(input_dim, cnn_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(cnn_dim),
            nn.ReLU()
        )
        
        # LSTM block: encodes sequential information from the CNN features
        self.lstm = nn.LSTM(input_size=cnn_dim, hidden_size=lstm_hidden_dim, batch_first=True)
        
        # LTC layers: further process the LSTM output for dynamic temporal behavior
        self.ltc_layers = nn.ModuleList([
            LTC_Cell(lstm_hidden_dim if i == 0 else ltc_hidden_dim, ltc_hidden_dim)
            for i in range(num_ltc_layers)
        ])
        
        # Classifier block: aggregates the processed sequence and outputs class logits
        self.classifier = nn.Sequential(
            nn.Linear(ltc_hidden_dim, 32),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(32, num_classes)
        )

    def forward(self, x):
        # x is assumed to have the shape: [batch, seq_len, channels]
        
        # Rearrange to [batch, channels, seq_len] for the CNN layer
        x = x.permute(0, 2, 1)
        x = self.cnn(x)
        
        # Swap back to [batch, seq_len, cnn_dim]
        x = x.permute(0, 2, 1)
        
        # Process the features with the LSTM layer
        x, _ = self.lstm(x)
        batch_size, seq_len, _ = x.size()
        
        # Initialize the LTC layers' hidden state(s)
        # Here we support multiple LTC layers; keep each layer’s state separately
        h_ltc = [torch.zeros(batch_size, layer.hidden_dim, device=x.device)
                 for layer in self.ltc_layers]
        
        all_hidden = []
        # Process the sequence time step by time step with LTC layers
        for t in range(seq_len):
            x_t = x[:, t, :]
            for i, ltc_layer in enumerate(self.ltc_layers):
                h_ltc[i] = ltc_layer(x_t, h_ltc[i])
                x_t = h_ltc[i]
            all_hidden.append(x_t)
        
        # Aggregate temporal information (here using mean pooling)
        hidden_stack = torch.stack(all_hidden, dim=1)
        pooled = hidden_stack.mean(dim=1)
        
        # Output the final class predictions
        return self.classifier(pooled)

# Example instantiation:
model = SimpleHybridLTC_LSTM(
    input_dim=22,         # For example, the number of EEG channels
    cnn_dim=32,           
    lstm_hidden_dim=64,  
    ltc_hidden_dim=64,   
    num_classes=4,        
    num_ltc_layers=2      
)
print(model)


SimpleHybridLTC_LSTM(
  (cnn): Sequential(
    (0): Conv1d(22, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (lstm): LSTM(32, 64, batch_first=True)
  (ltc_layers): ModuleList(
    (0-1): 2 x LTC_Cell(
      (W_xh): Linear(in_features=64, out_features=64, bias=True)
      (W_hh): Linear(in_features=64, out_features=64, bias=False)
      (W_tau): Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (classifier): Sequential(
    (0): Linear(in_features=64, out_features=32, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=32, out_features=4, bias=True)
  )
)


In [7]:
# Initialize model
import os
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


model = SimpleHybridLTC_LSTM(
    input_dim=22,         # For example, the number of EEG channels
    cnn_dim=32,           
    lstm_hidden_dim=64,  
    ltc_hidden_dim=64,   
    num_classes=4,        
    num_ltc_layers=1      
).to(device)

model_path = 'ltc_cnn_lstm_model.pth'
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
    print("Loaded model from checkpoint.")
else:
    print("No checkpoint found. Starting from scratch.")

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)

num_epochs = 100
previous_val_acc = None
same_acc_streak = 0
best_val_acc = 0

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    
    for batch, (data, labels) in enumerate(train_loader):
        # Move data to device
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        train_loss += loss.item()
    
    # Validation
    model.eval()
    val_loss, correct, total = 0, 0, 0
   
    
    with torch.no_grad():
        for data, labels in test_loader:
            data, labels = data.to(device), labels.to(device)
            outputs = model(data)
            val_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    val_acc = 100 * correct / total
    train_loss /= len(train_loader)
    val_loss /= len(test_loader)
    scheduler.step(val_loss)

    if previous_val_acc is None or val_acc != previous_val_acc:
        same_acc_streak = 0  # Reset counter if there's any change in accuracy
    else:
        same_acc_streak += 1  # Increment counter if accuracy is the same

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'ltc_cnn_lstm_model.pth')
    
    print(f'Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | 'f'Same accuracy streak: {same_acc_streak}/10')

    # Early stopping triggered after 10 consecutive epochs with no change in accuracy
    if same_acc_streak >= 10:
        print(f'\nEarly stopping triggered after {epoch+1} epochs!')
        break
    
    # Update previous_val_acc for the next iteration
    previous_val_acc = val_acc

# Load best model and final evaluation
print(f'\nTraining complete. Best validation accuracy: {best_val_acc:.2f}%')




No checkpoint found. Starting from scratch.
Epoch 1/100 | Train Loss: 1.3901 | Val Loss: 1.3855 | Val Acc: 25.00% | Same accuracy streak: 0/10
Epoch 2/100 | Train Loss: 1.3875 | Val Loss: 1.3838 | Val Acc: 25.00% | Same accuracy streak: 1/10
Epoch 3/100 | Train Loss: 1.3909 | Val Loss: 1.3822 | Val Acc: 25.00% | Same accuracy streak: 2/10
Epoch 4/100 | Train Loss: 1.3900 | Val Loss: 1.3816 | Val Acc: 25.00% | Same accuracy streak: 3/10
Epoch 5/100 | Train Loss: 1.3859 | Val Loss: 1.3813 | Val Acc: 25.00% | Same accuracy streak: 4/10
Epoch 6/100 | Train Loss: 1.3884 | Val Loss: 1.3810 | Val Acc: 25.00% | Same accuracy streak: 5/10
Epoch 7/100 | Train Loss: 1.3836 | Val Loss: 1.3800 | Val Acc: 25.00% | Same accuracy streak: 6/10
Epoch 8/100 | Train Loss: 1.3855 | Val Loss: 1.3790 | Val Acc: 25.00% | Same accuracy streak: 7/10
Epoch 9/100 | Train Loss: 1.3823 | Val Loss: 1.3782 | Val Acc: 27.08% | Same accuracy streak: 0/10
Epoch 10/100 | Train Loss: 1.3850 | Val Loss: 1.3772 | Val Acc: 2