In [None]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.nn as nn
import torch.optim as optim
from models import EEGformer  # Import the EEGformer model
#from model_fft import EEGformer
import resampy

# Compute device: the default CUDA device when one is present, otherwise the CPU.
# (Any nn.DataParallel wrapping is done later, after the model is constructed.)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# How many CUDA devices this process can see (0 on a CPU-only machine).
num_gpus = torch.cuda.device_count()
print(f"Using {num_gpus} GPUs")

# Every recording is resampled to this rate (samples per second) before cropping.
sampling_rate = 177
# Length of each crop, in seconds.
duration = 3
# Number of samples in one crop: seconds * samples-per-second.
samples_to_extract = duration * sampling_rate

In [None]:
# Define a custom Dataset for loading .npz files
class EEGDataset(Dataset):
    """Single-channel EEG crops loaded from the ``.npz`` files in a directory.

    Each file is read for three entries: ``data`` (2-D, channels x time),
    ``frequency`` (the rate ``data`` was recorded at) and ``label``. Only the first
    channel is kept. It is resampled to the module-level ``sampling_rate`` and
    cropped to the first ``samples_to_extract`` samples.
    """

    def __init__(self, data_dir):
        """Index every ``*.npz`` file directly inside ``data_dir``.

        Args:
            data_dir: directory containing the ``.npz`` files (not searched recursively).
        """
        self.data_dir = data_dir
        # os.listdir returns entries in arbitrary, filesystem-dependent order; sorting
        # makes dataset indices (and hence runs) reproducible across machines.
        self.files = sorted(
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".npz")
        )

    def __len__(self):
        """Number of ``.npz`` files found in ``data_dir``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(data_tensor, label_tensor)`` for the file at position ``idx``.

        ``data_tensor`` is float32 with shape ``(1, samples_to_extract)``: the first
        channel, resampled to ``sampling_rate`` and cropped to its first
        ``samples_to_extract`` samples. ``label_tensor`` is the stored label as int64.

        Raises:
            ValueError: if the stored ``data`` array is not 2-D, or if the resampled
                channel has fewer than ``samples_to_extract`` samples.
        """
        file_path = self.files[idx]
        # NpzFile keeps the archive open until closed; the context manager releases
        # the handle even on error (matters with many DataLoader workers over many epochs).
        with np.load(file_path) as npz_data:
            raw = npz_data['data']
            orig_rate = float(npz_data['frequency'].item())
            label = npz_data['label']

        if raw.ndim != 2:
            raise ValueError(
                f"Expected 2-D 'data' (channels x time) in {file_path}, got shape {raw.shape}."
            )

        # Only the first channel is used, so resample just that row. Resampling runs
        # independently along the time axis, so this equals resampling every channel
        # and then taking row 0, at a fraction of the cost.
        channel = resampy.resample(raw[0], sr_orig=orig_rate, sr_new=sampling_rate)
        time_steps = channel.shape[0]

        if samples_to_extract > time_steps:
            raise ValueError(f"Data only has {time_steps} time steps, but {samples_to_extract} are required.")

        data = channel[:samples_to_extract][np.newaxis, :]  # Shape: (1, samples_to_extract)

        data_tensor = torch.tensor(data.astype(np.float32))
        label_tensor = torch.tensor(label, dtype=torch.long)

        return data_tensor, label_tensor

In [None]:
# Paths to the processed .npz files
# NOTE(review): train/val are read from nmt_events/ while test is read from
# EEG_crops_per_channel/ — confirm the test split is meant to come from a different source.
train_dir = "/home/hira/eeg/nmt_events/train"
val_dir = "/home/hira/eeg/nmt_events/eval"
test_dir = "/home/hira/eeg/EEG_crops_per_channel/test"
model_saving_path = "/home/hira/eeg/EEG_crops_per_channel/model/"
model_name = "eeg_former_v2"

# Initialize the datasets and dataloaders.
# 8 samples per device. max(..., 1) keeps the batch size positive on a CPU-only
# machine, where device_count() is 0 and DataLoader would reject batch_size=0.
batch_size = 8 * max(num_gpus, 1)
num_workers = 10
# Pinned host memory only helps host->GPU copies; skip it when running on CPU.
pin_memory = device.type == "cuda"

train_dataset = EEGDataset(data_dir=train_dir)
val_dataset = EEGDataset(data_dir=val_dir)
test_dataset = EEGDataset(data_dir=test_dir)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

# Model hyperparameters (must match the checkpoint that is loaded below).
input_channels = 1
num_cls = 2
kernel_size = 10
num_blocks = 3
num_heads_rtm = 6
num_heads_stm = 6
num_heads_ttm = 11
num_submatrices = 12
CF_second = 2

# Dummy input passed to the EEGformer constructor.
# NOTE(review): this is (8, samples_to_extract, input_channels), while EEGDataset
# yields batches of (batch, 1, samples_to_extract) — confirm the layout EEGformer expects.
sample_input = torch.randn(8, samples_to_extract, input_channels).to(device)
print(sample_input.shape)

# Initialize the model
model = EEGformer(input=sample_input, num_cls=num_cls, input_channels=input_channels,
                  kernel_size=kernel_size, num_blocks=num_blocks, num_heads_RTM=num_heads_rtm,
                  num_heads_STM=num_heads_stm, num_heads_TTM=num_heads_ttm,
                  num_submatrices=num_submatrices, CF_second=CF_second)

# Use Data Parallelism if multiple GPUs are available
if num_gpus > 1:
    model = nn.DataParallel(model)

# Load the saved model state.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint_path = model_saving_path + model_name + ".pth"
state_dict = torch.load(checkpoint_path, map_location=device)
# A checkpoint saved from an nn.DataParallel model has every key prefixed with
# "module.". Strip only that leading prefix and load into the unwrapped network, so
# the checkpoint loads whatever GPU count was used to save it or is available now.
module_prefix = "module."
state_dict = {
    (key[len(module_prefix):] if key.startswith(module_prefix) else key): value
    for key, value in state_dict.items()
}
base_model = model.module if isinstance(model, nn.DataParallel) else model
base_model.load_state_dict(state_dict)

# Move the model to GPU(s)
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.000005, weight_decay=0.02)

In [None]:
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Continue training from the checkpoint loaded in the setup cell.
start_epoch = 95  # epochs already completed; training resumes at epoch 96 (1-based)
num_epochs = 100  # total epoch budget, so at most num_epochs - start_epoch = 5 more epochs run here
best_val_acc = 82.75  # validation accuracy (%) of the loaded checkpoint; only a better model overwrites it

# Early stopping parameters
patience = 10  # consecutive epochs without validation improvement before stopping
counter = 0  # current run of consecutive epochs without improvement

for epoch_idx in range(start_epoch, num_epochs):
    # --- Training Phase ---
    model.train()
    total_train_loss = 0.0
    num_train_batches = 0  # full batches actually trained on (short batches are skipped)
    train_true = []
    train_preds = []

    print(f"\nEpoch [{epoch_idx + 1}/{num_epochs}] - Training")
    with tqdm(train_dataloader, total=len(train_dataloader), desc=f"Training Epoch {epoch_idx + 1}") as train_bar:
        for inputs, labels in train_bar:
            # Only full batches are used. NOTE(review): presumably because the model is
            # built around a fixed batch (8-sample dummy input) — confirm before removing.
            if inputs.shape[0] != batch_size:
                continue
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate loss over processed batches only
            total_train_loss += loss.item()
            num_train_batches += 1

            # Get predictions
            _, predicted = torch.max(outputs, 1)
            train_true.extend(labels.cpu().numpy())
            train_preds.extend(predicted.cpu().numpy())

            # Update the progress bar with current loss
            train_bar.set_postfix(loss=loss.item())

    if num_train_batches == 0:
        raise RuntimeError(
            f"No full training batch of size {batch_size} was produced; every batch was skipped."
        )

    # Average over batches actually processed; len(train_dataloader) would also count
    # the skipped short batch and understate the loss.
    avg_train_loss = total_train_loss / num_train_batches
    train_accuracy = accuracy_score(train_true, train_preds) * 100
    train_precision = precision_score(train_true, train_preds, average='macro', zero_division=0)
    train_recall = recall_score(train_true, train_preds, average='macro', zero_division=0)
    train_f1 = f1_score(train_true, train_preds, average='macro', zero_division=0)

    print(f"Training Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.2f}%")
    print(f"Precision: {train_precision:.4f}, Recall: {train_recall:.4f}, F1-score: {train_f1:.4f}")

    # --- Validation Phase ---
    model.eval()
    total_val_loss = 0.0
    num_val_batches = 0  # full batches actually evaluated
    val_true = []
    val_preds = []

    print(f"\nEpoch [{epoch_idx + 1}/{num_epochs}] - Validation")
    with torch.no_grad():
        with tqdm(val_dataloader, total=len(val_dataloader), desc=f"Validation Epoch {epoch_idx + 1}") as val_bar:
            for inputs, labels in val_bar:
                # Same full-batch restriction as training (see NOTE above).
                if inputs.shape[0] != batch_size:
                    continue

                inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                total_val_loss += loss.item()
                num_val_batches += 1
                _, predicted = torch.max(outputs, 1)
                val_true.extend(labels.cpu().numpy())
                val_preds.extend(predicted.cpu().numpy())

                val_bar.set_postfix(loss=loss.item())

    if num_val_batches == 0:
        raise RuntimeError(
            f"No full validation batch of size {batch_size} was produced; every batch was skipped."
        )

    avg_val_loss = total_val_loss / num_val_batches
    val_accuracy = accuracy_score(val_true, val_preds) * 100

    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")

    # --- Save the model if validation accuracy improved ---
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        torch.save(model.state_dict(), model_saving_path + model_name + ".pth")
        print(f"Model saved at epoch {epoch_idx + 1} with validation accuracy: {val_accuracy:.2f}%")
        counter = 0  # Reset counter when improvement happens
    else:
        counter += 1
        print(f"Early stopping counter: {counter} out of {patience}")
        if counter >= patience:
            print("Early stopping triggered!")
            break  # stop the training loop

In [None]:
# Imported here as well so this cell does not depend on the training cell having run.
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Load the best model weights.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
state_dict = torch.load(model_saving_path + model_name + ".pth", map_location=device)

# A checkpoint saved from an nn.DataParallel model prefixes every key with "module.".
# Strip only that leading prefix (str.replace would also rewrite inner key segments),
# then load into the unwrapped network. Loading the stripped keys into a model that is
# still wrapped in DataParallel would fail with missing/unexpected keys.
module_prefix = "module."
state_dict = {
    (key[len(module_prefix):] if key.startswith(module_prefix) else key): value
    for key, value in state_dict.items()
}
base_model = model.module if isinstance(model, nn.DataParallel) else model
base_model.load_state_dict(state_dict)

model.to(device)  # Move to GPU
model.eval()  # Set model to evaluation mode

total_test_loss = 0.0
num_test_batches = 0  # full batches actually evaluated
test_true = []
test_preds = []

print("\nTesting Phase")
with torch.no_grad():
    with tqdm(test_dataloader, total=len(test_dataloader), desc="Testing") as test_bar:
        for inputs, labels in test_bar:
            # Only full batches are evaluated, matching training. NOTE(review): the final
            # short batch's samples are therefore excluded from the test metrics.
            if inputs.shape[0] != batch_size:
                continue
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            total_test_loss += loss.item()
            num_test_batches += 1
            _, predicted = torch.max(outputs, 1)
            test_true.extend(labels.cpu().numpy())
            test_preds.extend(predicted.cpu().numpy())

            test_bar.set_postfix(loss=loss.item())

if num_test_batches == 0:
    raise RuntimeError(
        f"No full test batch of size {batch_size} was produced; every batch was skipped."
    )

# Average over batches actually evaluated, not len(test_dataloader).
avg_test_loss = total_test_loss / num_test_batches
test_accuracy = accuracy_score(test_true, test_preds) * 100
test_precision = precision_score(test_true, test_preds, average='macro', zero_division=0)
test_recall = recall_score(test_true, test_preds, average='macro', zero_division=0)
test_f1 = f1_score(test_true, test_preds, average='macro', zero_division=0)

print(f"Test Loss: {avg_test_loss:.4f}, Accuracy: {test_accuracy:.2f}%")
print(f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1-score: {test_f1:.4f}")
