In [1]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.nn as nn
import torch.optim as optim
from models_duplicate import EEGformer  # Import the EEGformer model
import resampy

# Define device and enable Data Parallelism if multiple GPUs are available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Check if multiple GPUs are available
num_gpus = torch.cuda.device_count()
print(f"Using {num_gpus} GPUs")
sampling_rate = 177          # Example sampling rate (samples per second)
duration = 3                  # Duration in seconds
samples_to_extract = sampling_rate * duration  # Total samples for 3 seconds

Using 4 GPUs


In [2]:
# Define a custom Dataset for loading .npz files
class EEGDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".npz")]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path = self.files[idx]
        npz_data = np.load(file_path)

        data = resampy.resample(npz_data['data'], sr_orig=npz_data['frequency'], sr_new=sampling_rate)
        label = npz_data['label']

        num_channels = data.shape[0]
        time_steps = data.shape[1]

        if samples_to_extract > time_steps:
            raise ValueError(f"Data only has {time_steps} time steps, but {samples_to_extract} are required.")

        data = data[0, :samples_to_extract]  # Select first channel and slice for 3 seconds
        data = np.expand_dims(data, axis=0)   # Shape: (1, samples_to_extract)

        data_tensor = torch.tensor(data.astype(np.float32))
        label_tensor = torch.tensor(label, dtype=torch.long)

        return data_tensor, label_tensor

In [3]:
# Paths to the processed .npz files
train_dir = "/home/hira/eeg/EEG_crops_per_channel/train"
val_dir = "/home/hira/eeg/EEG_crops_per_channel/eval"
test_dir = "/home/hira/eeg/EEG_crops_per_channel/test"
model_saving_path = "/home/hira/eeg/EEG_crops_per_channel/model/"
model_name = "eeg_former_v1"

# Initialize the datasets and dataloaders
batch_size = 8 * num_gpus  # Adjust batch size according to available GPUs
train_dataset = EEGDataset(data_dir=train_dir)
val_dataset = EEGDataset(data_dir=val_dir)
test_dataset = EEGDataset(data_dir=test_dir)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=10, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=10, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=10, pin_memory=True)

# Parameters
input_channels = 1
num_cls = 2
kernel_size = 10
num_blocks = 3
num_heads_rtm = 6
num_heads_stm = 6
num_heads_ttm = 12
num_submatrices = 12
CF_second = 2

# Create a dummy input with shape expected by the model
sample_input = torch.randn(8, samples_to_extract, input_channels).to(device)
print(sample_input.shape)

# Initialize the model
model = EEGformer(input=sample_input, num_cls=num_cls, input_channels=input_channels,
                  kernel_size=kernel_size, num_blocks=num_blocks, num_heads_RTM=num_heads_rtm,
                   num_heads_TTM=num_heads_ttm,
                  num_submatrices=num_submatrices, CF_second=CF_second)

# Use Data Parallelism if multiple GPUs are available
if num_gpus > 1:
    model = nn.DataParallel(model)

# Move the model to GPU(s)
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

torch.Size([8, 531, 1])


In [4]:
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

num_epochs = 10
best_val_acc = 0.0
model.train()

for epoch_idx in range(num_epochs):
    # --- Training Phase ---
    total_train_loss = 0.0
    train_true = []
    train_preds = []

    print(f"\nEpoch [{epoch_idx + 1}/{num_epochs}] - Training")
    with tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Training Epoch {epoch_idx + 1}") as train_bar:
        for batch_idx, (inputs, labels) in train_bar:
            if inputs.shape[0] != batch_size:
                continue
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update loss
            total_train_loss += loss.item()

            # Get predictions
            _, predicted = torch.max(outputs, 1)
            train_true.extend(labels.cpu().numpy())
            train_preds.extend(predicted.cpu().numpy())

            # Update the progress bar with current loss
            train_bar.set_postfix(loss=loss.item())

    # Compute training metrics
    avg_train_loss = total_train_loss / len(train_dataloader)
    train_accuracy = accuracy_score(train_true, train_preds) * 100
    train_precision = precision_score(train_true, train_preds, average='macro', zero_division=0)
    train_recall = recall_score(train_true, train_preds, average='macro', zero_division=0)
    train_f1 = f1_score(train_true, train_preds, average='macro', zero_division=0)

    print(f"Training Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.2f}%")
    print(f"Precision: {train_precision:.4f}, Recall: {train_recall:.4f}, F1-score: {train_f1:.4f}")

    # --- Validation Phase ---
    model.eval()
    total_val_loss = 0.0
    val_true = []
    val_preds = []

    print(f"\nEpoch [{epoch_idx + 1}/{num_epochs}] - Validation")
    with torch.no_grad():
        with tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc=f"Validation Epoch {epoch_idx + 1}") as val_bar:
            for batch_idx, (inputs, labels) in val_bar:
                if inputs.shape[0] != batch_size:
                    continue

                inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                total_val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_true.extend(labels.cpu().numpy())
                val_preds.extend(predicted.cpu().numpy())

                val_bar.set_postfix(loss=loss.item())

    avg_val_loss = total_val_loss / len(val_dataloader)
    val_accuracy = accuracy_score(val_true, val_preds) * 100

    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")

    # --- Save the model if validation accuracy improved ---
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        torch.save(model.state_dict(), model_saving_path + model_name + ".pth")
        print(f"Model saved at epoch {epoch_idx + 1} with validation accuracy: {val_accuracy:.2f}%")

    model.train()  # Switch back to training mode


Epoch [1/10] - Training


Training Epoch 1: 100%|██████████| 25740/25740 [52:30<00:00,  8.17it/s, loss=0.47]   


Training Loss: 0.5330, Accuracy: 78.03%
Precision: 0.5568, Recall: 0.5000, F1-score: 0.4383

Epoch [1/10] - Validation


Validation Epoch 1: 100%|██████████| 8580/8580 [07:06<00:00, 20.13it/s, loss=0.532]


Validation Loss: 0.5329, Accuracy: 78.03%
Model saved at epoch 1 with validation accuracy: 78.03%

Epoch [2/10] - Training


Training Epoch 2: 100%|██████████| 25740/25740 [29:12<00:00, 14.69it/s, loss=0.626]  


Training Loss: 0.5330, Accuracy: 78.03%
Precision: 0.3901, Recall: 0.5000, F1-score: 0.4383

Epoch [2/10] - Validation


Validation Epoch 2: 100%|██████████| 8580/8580 [05:53<00:00, 24.30it/s, loss=0.532]


Validation Loss: 0.5329, Accuracy: 78.03%

Epoch [3/10] - Training


Training Epoch 3: 100%|██████████| 25740/25740 [27:44<00:00, 15.47it/s, loss=0.532]


Training Loss: 0.5330, Accuracy: 78.03%
Precision: 0.3901, Recall: 0.5000, F1-score: 0.4383

Epoch [3/10] - Validation


Validation Epoch 3: 100%|██████████| 8580/8580 [06:12<00:00, 23.02it/s, loss=0.532]


Validation Loss: 0.5329, Accuracy: 78.03%

Epoch [4/10] - Training


Training Epoch 4: 100%|██████████| 25740/25740 [27:48<00:00, 15.42it/s, loss=0.438]


Training Loss: 0.5330, Accuracy: 78.03%
Precision: 0.3901, Recall: 0.5000, F1-score: 0.4383

Epoch [4/10] - Validation


Validation Epoch 4: 100%|██████████| 8580/8580 [05:34<00:00, 25.69it/s, loss=0.532]


Validation Loss: 0.5329, Accuracy: 78.03%

Epoch [5/10] - Training


Training Epoch 5: 100%|██████████| 25740/25740 [27:17<00:00, 15.72it/s, loss=0.47] 


Training Loss: 0.5330, Accuracy: 78.03%
Precision: 0.3901, Recall: 0.5000, F1-score: 0.4383

Epoch [5/10] - Validation


Validation Epoch 5: 100%|██████████| 8580/8580 [06:01<00:00, 23.75it/s, loss=0.532]


Validation Loss: 0.5329, Accuracy: 78.03%

Epoch [6/10] - Training


Training Epoch 6: 100%|██████████| 25740/25740 [27:17<00:00, 15.72it/s, loss=0.345]


Training Loss: 0.5330, Accuracy: 78.03%
Precision: 0.3901, Recall: 0.5000, F1-score: 0.4383

Epoch [6/10] - Validation


Validation Epoch 6: 100%|██████████| 8580/8580 [05:30<00:00, 25.98it/s, loss=0.532]


Validation Loss: 0.5329, Accuracy: 78.03%

Epoch [7/10] - Training


Training Epoch 7: 100%|██████████| 25740/25740 [27:55<00:00, 15.36it/s, loss=0.501]


Training Loss: 0.5330, Accuracy: 78.03%
Precision: 0.3901, Recall: 0.5000, F1-score: 0.4383

Epoch [7/10] - Validation


Validation Epoch 7: 100%|██████████| 8580/8580 [05:50<00:00, 24.48it/s, loss=0.532]


Validation Loss: 0.5329, Accuracy: 78.03%

Epoch [8/10] - Training


Training Epoch 8: 100%|██████████| 25740/25740 [27:25<00:00, 15.64it/s, loss=0.532]


Training Loss: 0.5330, Accuracy: 78.03%
Precision: 0.3901, Recall: 0.5000, F1-score: 0.4383

Epoch [8/10] - Validation


Validation Epoch 8: 100%|██████████| 8580/8580 [05:53<00:00, 24.26it/s, loss=0.532]


Validation Loss: 0.5329, Accuracy: 78.03%

Epoch [9/10] - Training


Training Epoch 9: 100%|██████████| 25740/25740 [27:03<00:00, 15.86it/s, loss=0.438]


Training Loss: 0.5330, Accuracy: 78.03%
Precision: 0.3901, Recall: 0.5000, F1-score: 0.4383

Epoch [9/10] - Validation


Validation Epoch 9: 100%|██████████| 8580/8580 [05:41<00:00, 25.12it/s, loss=0.532]


Validation Loss: 0.5329, Accuracy: 78.03%

Epoch [10/10] - Training


Training Epoch 10: 100%|██████████| 25740/25740 [43:05<00:00,  9.96it/s, loss=0.501]  


Training Loss: 0.5330, Accuracy: 78.03%
Precision: 0.3901, Recall: 0.5000, F1-score: 0.4383

Epoch [10/10] - Validation


Validation Epoch 10: 100%|██████████| 8580/8580 [10:50<00:00, 13.18it/s, loss=0.532]


Validation Loss: 0.5329, Accuracy: 78.03%


In [5]:
# Load the best model weights
state_dict = torch.load(model_saving_path + model_name + ".pth")

# If trained with DataParallel, remove "module." prefix
if "module." in list(state_dict.keys())[0]:  
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict[k.replace("module.", "")] = v
    state_dict = new_state_dict  # Update state dict

model.load_state_dict(state_dict)
model.to(device)  # Move to GPU
model.eval()  # Set model to evaluation mode

total_test_loss = 0.0
test_true = []
test_preds = []

print("\nTesting Phase")
with torch.no_grad():
    with tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc="Testing") as test_bar:
        for batch_idx, (inputs, labels) in test_bar:
            if inputs.shape[0] != batch_size:
                continue
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            total_test_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            test_true.extend(labels.cpu().numpy())
            test_preds.extend(predicted.cpu().numpy())
            
            test_bar.set_postfix(loss=loss.item())

avg_test_loss = total_test_loss / len(test_dataloader)
test_accuracy = accuracy_score(test_true, test_preds) * 100
test_precision = precision_score(test_true, test_preds, average='macro', zero_division=0)
test_recall = recall_score(test_true, test_preds, average='macro', zero_division=0)
test_f1 = f1_score(test_true, test_preds, average='macro', zero_division=0)

print(f"Test Loss: {avg_test_loss:.4f}, Accuracy: {test_accuracy:.2f}%")
print(f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1-score: {test_f1:.4f}")


RuntimeError: Error(s) in loading state_dict for DataParallel:
	Missing key(s) in state_dict: "module.projection_layer.weight", "module.projection_layer.bias", "module.odcm.cvf1.weight", "module.odcm.cvf1.bias", "module.odcm.cvf2.weight", "module.odcm.cvf2.bias", "module.odcm.cvf3.weight", "module.odcm.cvf3.bias", "module.rtm.weight", "module.rtm.bias", "module.rtm.cls", "module.rtm.tfb.0.Wqkv", "module.rtm.tfb.0.Wo", "module.rtm.tfb.0.lnorm.weight", "module.rtm.tfb.0.lnorm.bias", "module.rtm.tfb.0.lnormz.weight", "module.rtm.tfb.0.lnormz.bias", "module.rtm.tfb.0.mlp.fc1.weight", "module.rtm.tfb.0.mlp.fc1.bias", "module.rtm.tfb.0.mlp.fc2.weight", "module.rtm.tfb.0.mlp.fc2.bias", "module.rtm.tfb.1.Wqkv", "module.rtm.tfb.1.Wo", "module.rtm.tfb.1.lnorm.weight", "module.rtm.tfb.1.lnorm.bias", "module.rtm.tfb.1.lnormz.weight", "module.rtm.tfb.1.lnormz.bias", "module.rtm.tfb.1.mlp.fc1.weight", "module.rtm.tfb.1.mlp.fc1.bias", "module.rtm.tfb.1.mlp.fc2.weight", "module.rtm.tfb.1.mlp.fc2.bias", "module.rtm.tfb.2.Wqkv", "module.rtm.tfb.2.Wo", "module.rtm.tfb.2.lnorm.weight", "module.rtm.tfb.2.lnorm.bias", "module.rtm.tfb.2.lnormz.weight", "module.rtm.tfb.2.lnormz.bias", "module.rtm.tfb.2.mlp.fc1.weight", "module.rtm.tfb.2.mlp.fc1.bias", "module.rtm.tfb.2.mlp.fc2.weight", "module.rtm.tfb.2.mlp.fc2.bias", "module.ttm.weight", "module.ttm.bias", "module.ttm.cls", "module.ttm.projection_layer.weight", "module.ttm.projection_layer.bias", "module.ttm.tfb.0.Wqkv", "module.ttm.tfb.0.Wo", "module.ttm.tfb.0.lnorm.weight", "module.ttm.tfb.0.lnorm.bias", "module.ttm.tfb.0.lnormz.weight", "module.ttm.tfb.0.lnormz.bias", "module.ttm.tfb.0.mlp.fc1.weight", "module.ttm.tfb.0.mlp.fc1.bias", "module.ttm.tfb.0.mlp.fc2.weight", "module.ttm.tfb.0.mlp.fc2.bias", "module.ttm.tfb.1.Wqkv", "module.ttm.tfb.1.Wo", "module.ttm.tfb.1.lnorm.weight", "module.ttm.tfb.1.lnorm.bias", "module.ttm.tfb.1.lnormz.weight", "module.ttm.tfb.1.lnormz.bias", "module.ttm.tfb.1.mlp.fc1.weight", "module.ttm.tfb.1.mlp.fc1.bias", "module.ttm.tfb.1.mlp.fc2.weight", "module.ttm.tfb.1.mlp.fc2.bias", "module.ttm.tfb.2.Wqkv", "module.ttm.tfb.2.Wo", "module.ttm.tfb.2.lnorm.weight", "module.ttm.tfb.2.lnorm.bias", "module.ttm.tfb.2.lnormz.weight", "module.ttm.tfb.2.lnormz.bias", "module.ttm.tfb.2.mlp.fc1.weight", "module.ttm.tfb.2.mlp.fc1.bias", "module.ttm.tfb.2.mlp.fc2.weight", "module.ttm.tfb.2.mlp.fc2.bias", "module.ttm.lnorm_extra.weight", "module.ttm.lnorm_extra.bias", "module.cnndecoder.cvd1.weight", "module.cnndecoder.cvd1.bias", "module.cnndecoder.cvd2.weight", "module.cnndecoder.cvd2.bias", "module.cnndecoder.cvd3.weight", "module.cnndecoder.cvd3.bias", "module.cnndecoder.fc.weight", "module.cnndecoder.fc.bias". 
	Unexpected key(s) in state_dict: "projection_layer.weight", "projection_layer.bias", "odcm.cvf1.weight", "odcm.cvf1.bias", "odcm.cvf2.weight", "odcm.cvf2.bias", "odcm.cvf3.weight", "odcm.cvf3.bias", "rtm.weight", "rtm.bias", "rtm.cls", "rtm.tfb.0.Wqkv", "rtm.tfb.0.Wo", "rtm.tfb.0.lnorm.weight", "rtm.tfb.0.lnorm.bias", "rtm.tfb.0.lnormz.weight", "rtm.tfb.0.lnormz.bias", "rtm.tfb.0.mlp.fc1.weight", "rtm.tfb.0.mlp.fc1.bias", "rtm.tfb.0.mlp.fc2.weight", "rtm.tfb.0.mlp.fc2.bias", "rtm.tfb.1.Wqkv", "rtm.tfb.1.Wo", "rtm.tfb.1.lnorm.weight", "rtm.tfb.1.lnorm.bias", "rtm.tfb.1.lnormz.weight", "rtm.tfb.1.lnormz.bias", "rtm.tfb.1.mlp.fc1.weight", "rtm.tfb.1.mlp.fc1.bias", "rtm.tfb.1.mlp.fc2.weight", "rtm.tfb.1.mlp.fc2.bias", "rtm.tfb.2.Wqkv", "rtm.tfb.2.Wo", "rtm.tfb.2.lnorm.weight", "rtm.tfb.2.lnorm.bias", "rtm.tfb.2.lnormz.weight", "rtm.tfb.2.lnormz.bias", "rtm.tfb.2.mlp.fc1.weight", "rtm.tfb.2.mlp.fc1.bias", "rtm.tfb.2.mlp.fc2.weight", "rtm.tfb.2.mlp.fc2.bias", "ttm.weight", "ttm.bias", "ttm.cls", "ttm.projection_layer.weight", "ttm.projection_layer.bias", "ttm.tfb.0.Wqkv", "ttm.tfb.0.Wo", "ttm.tfb.0.lnorm.weight", "ttm.tfb.0.lnorm.bias", "ttm.tfb.0.lnormz.weight", "ttm.tfb.0.lnormz.bias", "ttm.tfb.0.mlp.fc1.weight", "ttm.tfb.0.mlp.fc1.bias", "ttm.tfb.0.mlp.fc2.weight", "ttm.tfb.0.mlp.fc2.bias", "ttm.tfb.1.Wqkv", "ttm.tfb.1.Wo", "ttm.tfb.1.lnorm.weight", "ttm.tfb.1.lnorm.bias", "ttm.tfb.1.lnormz.weight", "ttm.tfb.1.lnormz.bias", "ttm.tfb.1.mlp.fc1.weight", "ttm.tfb.1.mlp.fc1.bias", "ttm.tfb.1.mlp.fc2.weight", "ttm.tfb.1.mlp.fc2.bias", "ttm.tfb.2.Wqkv", "ttm.tfb.2.Wo", "ttm.tfb.2.lnorm.weight", "ttm.tfb.2.lnorm.bias", "ttm.tfb.2.lnormz.weight", "ttm.tfb.2.lnormz.bias", "ttm.tfb.2.mlp.fc1.weight", "ttm.tfb.2.mlp.fc1.bias", "ttm.tfb.2.mlp.fc2.weight", "ttm.tfb.2.mlp.fc2.bias", "ttm.lnorm_extra.weight", "ttm.lnorm_extra.bias", "cnndecoder.cvd1.weight", "cnndecoder.cvd1.bias", "cnndecoder.cvd2.weight", "cnndecoder.cvd2.bias", "cnndecoder.cvd3.weight", "cnndecoder.cvd3.bias", "cnndecoder.fc.weight", "cnndecoder.fc.bias". 