In [13]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.nn as nn
import torch.optim as optim
from models import EEGformer  # Import the EEGformer model
import resampy

# Define device and enable Data Parallelism if multiple GPUs are available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Check if multiple GPUs are available
num_gpus = torch.cuda.device_count()
print(f"Using {num_gpus} GPUs")
sampling_rate = 177          # Example sampling rate (samples per second)
duration = 3                  # Duration in seconds
samples_to_extract = sampling_rate * duration  # Total samples for 3 seconds

Using 4 GPUs


In [14]:
# Define a custom Dataset for loading .npz files
class EEGDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".npz")]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path = self.files[idx]
        npz_data = np.load(file_path)

        data = resampy.resample(npz_data['data'], sr_orig=npz_data['frequency'], sr_new=sampling_rate)
        label = npz_data['label']

        num_channels = data.shape[0]
        time_steps = data.shape[1]

        if samples_to_extract > time_steps:
            raise ValueError(f"Data only has {time_steps} time steps, but {samples_to_extract} are required.")

        data = data[0, :samples_to_extract]  # Select first channel and slice for 3 seconds
        data = np.expand_dims(data, axis=0)   # Shape: (1, samples_to_extract)

        data_tensor = torch.tensor(data.astype(np.float32))
        label_tensor = torch.tensor(label, dtype=torch.long)

        return data_tensor, label_tensor

In [15]:
# Paths to the processed .npz files
train_dir = "/home/hira/eeg/nmt_events/train"
val_dir = "/home/hira/eeg/nmt_events/eval"
test_dir = "/home/hira/eeg/EEG_crops_per_channel/test"
model_saving_path = "/home/hira/eeg/EEG_crops_per_channel/model/"
model_name = "eeg_former_v2"

# Initialize the datasets and dataloaders
batch_size = 8 * num_gpus  # Adjust batch size according to available GPUs
train_dataset = EEGDataset(data_dir=train_dir)
val_dataset = EEGDataset(data_dir=val_dir)
test_dataset = EEGDataset(data_dir=test_dir)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=10, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=10, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=10, pin_memory=True)

# Parameters
input_channels = 1
num_cls = 2
kernel_size = 10
num_blocks = 3
num_heads_rtm = 6
num_heads_stm = 6
num_heads_ttm = 11
num_submatrices = 12
CF_second = 2

# Create a dummy input with shape expected by the model
sample_input = torch.randn(8, samples_to_extract, input_channels).to(device)
print(sample_input.shape)

# Initialize the model
model = EEGformer(input=sample_input, num_cls=num_cls, input_channels=input_channels,
                  kernel_size=kernel_size, num_blocks=num_blocks, num_heads_RTM=num_heads_rtm,
                  num_heads_STM=num_heads_stm, num_heads_TTM=num_heads_ttm,
                  num_submatrices=num_submatrices, CF_second=CF_second)



# Use Data Parallelism if multiple GPUs are available
if num_gpus > 1:
    model = nn.DataParallel(model)

# Load the saved model state
model.load_state_dict(torch.load(model_saving_path + "eeg_former_v2" + ".pth"))

# Move the model to GPU(s)
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.000001)

torch.Size([8, 531, 1])


In [16]:
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Continue training
start_epoch = 57  # Start from the next epoch
num_epochs = 100  # Total number of epochs to train (25 more epochs)
best_val_acc = 82.72 # Set to the last saved validation accuracy
model.train()

# Early stopping parameters
patience = 10  # Number of epochs to wait for improvement (3 in this case)
counter = 0  # Counter for epochs without improvement

for epoch_idx in range(start_epoch, num_epochs):
    # --- Training Phase ---
    total_train_loss = 0.0
    train_true = []
    train_preds = []

    print(f"\nEpoch [{epoch_idx + 1}/{num_epochs}] - Training")
    with tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Training Epoch {epoch_idx + 1}") as train_bar:
        for batch_idx, (inputs, labels) in train_bar:
            if inputs.shape[0] != batch_size:
                continue
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update loss
            total_train_loss += loss.item()

            # Get predictions
            _, predicted = torch.max(outputs, 1)
            train_true.extend(labels.cpu().numpy())
            train_preds.extend(predicted.cpu().numpy())

            # Update the progress bar with current loss
            train_bar.set_postfix(loss=loss.item())

    # Compute training metrics
    avg_train_loss = total_train_loss / len(train_dataloader)
    train_accuracy = accuracy_score(train_true, train_preds) * 100
    train_precision = precision_score(train_true, train_preds, average='macro', zero_division=0)
    train_recall = recall_score(train_true, train_preds, average='macro', zero_division=0)
    train_f1 = f1_score(train_true, train_preds, average='macro', zero_division=0)

    print(f"Training Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.2f}%")
    print(f"Precision: {train_precision:.4f}, Recall: {train_recall:.4f}, F1-score: {train_f1:.4f}")

    # --- Validation Phase ---
    model.eval()
    total_val_loss = 0.0
    val_true = []
    val_preds = []

    print(f"\nEpoch [{epoch_idx + 1}/{num_epochs}] - Validation")
    with torch.no_grad():
        with tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc=f"Validation Epoch {epoch_idx + 1}") as val_bar:
            for batch_idx, (inputs, labels) in val_bar:
                if inputs.shape[0] != batch_size:
                    continue

                inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                total_val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_true.extend(labels.cpu().numpy())
                val_preds.extend(predicted.cpu().numpy())

                val_bar.set_postfix(loss=loss.item())

    avg_val_loss = total_val_loss / len(val_dataloader)
    val_accuracy = accuracy_score(val_true, val_preds) * 100

    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")

    # --- Save the model if validation accuracy improved ---
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        torch.save(model.state_dict(), model_saving_path + model_name + ".pth")
        print(f"Model saved at epoch {epoch_idx + 1} with validation accuracy: {val_accuracy:.2f}%")
        counter = 0  # Reset counter when improvement happens
    else:
        counter += 1
        print(f"Early stopping counter: {counter} out of {patience}")
        if counter >= patience:
            print("Early stopping triggered!")
            break #stop the training loop.

    model.train()  # Switch back to training mode


Epoch [36/100] - Training


Training Epoch 36: 100%|██████████| 5987/5987 [3:31:26<00:00,  2.12s/it, loss=0.384]  


Training Loss: 0.4302, Accuracy: 87.81%
Precision: 0.8821, Recall: 0.8745, F1-score: 0.8765

Epoch [36/100] - Validation


Validation Epoch 36: 100%|██████████| 1120/1120 [11:36<00:00,  1.61it/s, loss=0.434]

Validation Loss: 0.4814, Accuracy: 82.41%
Early stopping counter: 1 out of 10

Epoch [37/100] - Training



Training Epoch 37: 100%|██████████| 5987/5987 [3:32:27<00:00,  2.13s/it, loss=0.416]  


Training Loss: 0.4298, Accuracy: 87.86%
Precision: 0.8825, Recall: 0.8751, F1-score: 0.8771

Epoch [37/100] - Validation


Validation Epoch 37: 100%|██████████| 1120/1120 [11:36<00:00,  1.61it/s, loss=0.435]

Validation Loss: 0.4807, Accuracy: 82.51%
Early stopping counter: 2 out of 10

Epoch [38/100] - Training



Training Epoch 38: 100%|██████████| 5987/5987 [3:32:59<00:00,  2.13s/it, loss=0.374]  


Training Loss: 0.4294, Accuracy: 87.96%
Precision: 0.8835, Recall: 0.8761, F1-score: 0.8781

Epoch [38/100] - Validation


Validation Epoch 38: 100%|██████████| 1120/1120 [11:36<00:00,  1.61it/s, loss=0.434]


Validation Loss: 0.4803, Accuracy: 82.61%
Model saved at epoch 38 with validation accuracy: 82.61%

Epoch [39/100] - Training


Training Epoch 39: 100%|██████████| 5987/5987 [3:33:25<00:00,  2.14s/it, loss=0.39]   


Training Loss: 0.4289, Accuracy: 88.01%
Precision: 0.8842, Recall: 0.8765, F1-score: 0.8786

Epoch [39/100] - Validation


Validation Epoch 39: 100%|██████████| 1120/1120 [11:37<00:00,  1.61it/s, loss=0.44] 

Validation Loss: 0.4809, Accuracy: 82.41%
Early stopping counter: 1 out of 10

Epoch [40/100] - Training



Training Epoch 40: 100%|██████████| 5987/5987 [3:29:42<00:00,  2.10s/it, loss=0.444]  


Training Loss: 0.4289, Accuracy: 87.99%
Precision: 0.8839, Recall: 0.8764, F1-score: 0.8784

Epoch [40/100] - Validation


Validation Epoch 40: 100%|██████████| 1120/1120 [11:36<00:00,  1.61it/s, loss=0.441]

Validation Loss: 0.4805, Accuracy: 82.52%
Early stopping counter: 2 out of 10

Epoch [41/100] - Training



Training Epoch 41: 100%|██████████| 5987/5987 [3:30:13<00:00,  2.11s/it, loss=0.374]  


Training Loss: 0.4293, Accuracy: 87.94%
Precision: 0.8835, Recall: 0.8759, F1-score: 0.8779

Epoch [41/100] - Validation


Validation Epoch 41: 100%|██████████| 1120/1120 [11:38<00:00,  1.60it/s, loss=0.442]

Validation Loss: 0.4797, Accuracy: 82.59%
Early stopping counter: 3 out of 10

Epoch [42/100] - Training



Training Epoch 42: 100%|██████████| 5987/5987 [3:31:28<00:00,  2.12s/it, loss=0.566]  


Training Loss: 0.4289, Accuracy: 87.94%
Precision: 0.8834, Recall: 0.8759, F1-score: 0.8779

Epoch [42/100] - Validation


Validation Epoch 42: 100%|██████████| 1120/1120 [11:31<00:00,  1.62it/s, loss=0.434]

Validation Loss: 0.4800, Accuracy: 82.57%
Early stopping counter: 4 out of 10

Epoch [43/100] - Training



Training Epoch 43: 100%|██████████| 5987/5987 [3:32:17<00:00,  2.13s/it, loss=0.333]  


Training Loss: 0.4291, Accuracy: 87.95%
Precision: 0.8833, Recall: 0.8760, F1-score: 0.8780

Epoch [43/100] - Validation


Validation Epoch 43: 100%|██████████| 1120/1120 [11:39<00:00,  1.60it/s, loss=0.441]

Validation Loss: 0.4803, Accuracy: 82.54%
Early stopping counter: 5 out of 10

Epoch [44/100] - Training



Training Epoch 44: 100%|██████████| 5987/5987 [3:29:52<00:00,  2.10s/it, loss=0.434]  


Training Loss: 0.4289, Accuracy: 88.00%
Precision: 0.8841, Recall: 0.8765, F1-score: 0.8785

Epoch [44/100] - Validation


Validation Epoch 44: 100%|██████████| 1120/1120 [12:11<00:00,  1.53it/s, loss=0.443]


Validation Loss: 0.4799, Accuracy: 82.61%
Model saved at epoch 44 with validation accuracy: 82.61%

Epoch [45/100] - Training


Training Epoch 45: 100%|██████████| 5987/5987 [3:23:27<00:00,  2.04s/it, loss=0.404]  


Training Loss: 0.4289, Accuracy: 87.98%
Precision: 0.8837, Recall: 0.8764, F1-score: 0.8783

Epoch [45/100] - Validation


Validation Epoch 45: 100%|██████████| 1120/1120 [12:11<00:00,  1.53it/s, loss=0.441]


Validation Loss: 0.4795, Accuracy: 82.62%
Model saved at epoch 45 with validation accuracy: 82.62%

Epoch [46/100] - Training


Training Epoch 46: 100%|██████████| 5987/5987 [3:22:20<00:00,  2.03s/it, loss=0.448]  


Training Loss: 0.4292, Accuracy: 87.96%
Precision: 0.8837, Recall: 0.8760, F1-score: 0.8780

Epoch [46/100] - Validation


Validation Epoch 46: 100%|██████████| 1120/1120 [12:04<00:00,  1.55it/s, loss=0.446]

Validation Loss: 0.4799, Accuracy: 82.57%
Early stopping counter: 1 out of 10

Epoch [47/100] - Training



Training Epoch 47: 100%|██████████| 5987/5987 [3:24:36<00:00,  2.05s/it, loss=0.397]  


Training Loss: 0.4287, Accuracy: 87.96%
Precision: 0.8835, Recall: 0.8761, F1-score: 0.8781

Epoch [47/100] - Validation


Validation Epoch 47: 100%|██████████| 1120/1120 [12:09<00:00,  1.54it/s, loss=0.435]

Validation Loss: 0.4799, Accuracy: 82.61%
Early stopping counter: 2 out of 10

Epoch [48/100] - Training



Training Epoch 48: 100%|██████████| 5987/5987 [3:27:40<00:00,  2.08s/it, loss=0.445]  


Training Loss: 0.4285, Accuracy: 88.04%
Precision: 0.8842, Recall: 0.8770, F1-score: 0.8790

Epoch [48/100] - Validation


Validation Epoch 48: 100%|██████████| 1120/1120 [11:36<00:00,  1.61it/s, loss=0.443]


Validation Loss: 0.4801, Accuracy: 82.55%
Early stopping counter: 3 out of 10

Epoch [49/100] - Training


Training Epoch 49: 100%|██████████| 5987/5987 [3:33:58<00:00,  2.14s/it, loss=0.476]  


Training Loss: 0.4287, Accuracy: 88.00%
Precision: 0.8839, Recall: 0.8765, F1-score: 0.8785

Epoch [49/100] - Validation


Validation Epoch 49: 100%|██████████| 1120/1120 [11:33<00:00,  1.62it/s, loss=0.449]


Validation Loss: 0.4794, Accuracy: 82.65%
Model saved at epoch 49 with validation accuracy: 82.65%

Epoch [50/100] - Training


Training Epoch 50: 100%|██████████| 5987/5987 [3:35:28<00:00,  2.16s/it, loss=0.435]  


Training Loss: 0.4287, Accuracy: 87.98%
Precision: 0.8837, Recall: 0.8764, F1-score: 0.8784

Epoch [50/100] - Validation


Validation Epoch 50: 100%|██████████| 1120/1120 [13:02<00:00,  1.43it/s, loss=0.448]

Validation Loss: 0.4796, Accuracy: 82.57%
Early stopping counter: 1 out of 10

Epoch [51/100] - Training



Training Epoch 51: 100%|██████████| 5987/5987 [4:25:02<00:00,  2.66s/it, loss=0.366]  


Training Loss: 0.4284, Accuracy: 88.08%
Precision: 0.8845, Recall: 0.8774, F1-score: 0.8793

Epoch [51/100] - Validation


Validation Epoch 51: 100%|██████████| 1120/1120 [11:47<00:00,  1.58it/s, loss=0.451]


Validation Loss: 0.4791, Accuracy: 82.72%
Model saved at epoch 51 with validation accuracy: 82.72%

Epoch [52/100] - Training


Training Epoch 52: 100%|██████████| 5987/5987 [3:52:16<00:00,  2.33s/it, loss=0.417]  


Training Loss: 0.4285, Accuracy: 88.01%
Precision: 0.8842, Recall: 0.8766, F1-score: 0.8786

Epoch [52/100] - Validation


Validation Epoch 52: 100%|██████████| 1120/1120 [15:26<00:00,  1.21it/s, loss=0.45] 


Validation Loss: 0.4796, Accuracy: 82.65%
Early stopping counter: 1 out of 10

Epoch [53/100] - Training


Training Epoch 53: 100%|██████████| 5987/5987 [3:54:53<00:00,  2.35s/it, loss=0.369]  


Training Loss: 0.4282, Accuracy: 88.04%
Precision: 0.8843, Recall: 0.8770, F1-score: 0.8790

Epoch [53/100] - Validation


Validation Epoch 53: 100%|██████████| 1120/1120 [11:40<00:00,  1.60it/s, loss=0.438]

Validation Loss: 0.4793, Accuracy: 82.59%
Early stopping counter: 2 out of 10

Epoch [54/100] - Training



Training Epoch 54: 100%|██████████| 5987/5987 [3:33:23<00:00,  2.14s/it, loss=0.416]  


Training Loss: 0.4279, Accuracy: 88.13%
Precision: 0.8852, Recall: 0.8779, F1-score: 0.8798

Epoch [54/100] - Validation


Validation Epoch 54: 100%|██████████| 1120/1120 [12:00<00:00,  1.56it/s, loss=0.442]

Validation Loss: 0.4793, Accuracy: 82.65%
Early stopping counter: 3 out of 10

Epoch [55/100] - Training



Training Epoch 55: 100%|██████████| 5987/5987 [4:10:05<00:00,  2.51s/it, loss=0.387]  


Training Loss: 0.4282, Accuracy: 88.07%
Precision: 0.8847, Recall: 0.8772, F1-score: 0.8792

Epoch [55/100] - Validation


Validation Epoch 55: 100%|██████████| 1120/1120 [11:38<00:00,  1.60it/s, loss=0.449]

Validation Loss: 0.4793, Accuracy: 82.65%
Early stopping counter: 4 out of 10

Epoch [56/100] - Training



Training Epoch 56: 100%|██████████| 5987/5987 [4:03:35<00:00,  2.44s/it, loss=0.493]  


Training Loss: 0.4281, Accuracy: 88.07%
Precision: 0.8847, Recall: 0.8773, F1-score: 0.8793

Epoch [56/100] - Validation


Validation Epoch 56: 100%|██████████| 1120/1120 [15:32<00:00,  1.20it/s, loss=0.454]

Validation Loss: 0.4796, Accuracy: 82.56%
Early stopping counter: 5 out of 10

Epoch [57/100] - Training



Training Epoch 57:  13%|█▎        | 799/5987 [34:43<3:50:31,  2.67s/it, loss=0.403]

In [None]:
# existing_model_path = model_saving_path + model_name + '.pth'
# print(existing_model_path)

In [None]:
# import shutil
# import torch

# #... your existing code...

# # Path to your existing model weights file
# existing_model_path = model_saving_path + model_name + ".pth"

# # Path to the new file where you want to copy the weights
# new_model_path = model_saving_path + model_name + "_25epochs_backup.pth"  # Or any other descriptive name

# # Copy the file
# shutil.copyfile(existing_model_path, new_model_path)

#... continue with your training code...

In [None]:
# Load the best model weights
state_dict = torch.load(model_saving_path + model_name + ".pth")

# If trained with DataParallel, remove "module." prefix
if "module." in list(state_dict.keys())[0]:  
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict[k.replace("module.", "")] = v
    state_dict = new_state_dict  # Update state dict

model.load_state_dict(state_dict)
model.to(device)  # Move to GPU
model.eval()  # Set model to evaluation mode

total_test_loss = 0.0
test_true = []
test_preds = []

print("\nTesting Phase")
with torch.no_grad():
    with tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc="Testing") as test_bar:
        for batch_idx, (inputs, labels) in test_bar:
            if inputs.shape[0] != batch_size:
                continue
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            total_test_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            test_true.extend(labels.cpu().numpy())
            test_preds.extend(predicted.cpu().numpy())
            
            test_bar.set_postfix(loss=loss.item())

avg_test_loss = total_test_loss / len(test_dataloader)
test_accuracy = accuracy_score(test_true, test_preds) * 100
test_precision = precision_score(test_true, test_preds, average='macro', zero_division=0)
test_recall = recall_score(test_true, test_preds, average='macro', zero_division=0)
test_f1 = f1_score(test_true, test_preds, average='macro', zero_division=0)

print(f"Test Loss: {avg_test_loss:.4f}, Accuracy: {test_accuracy:.2f}%")
print(f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1-score: {test_f1:.4f}")
