In [None]:
import os
import warnings

import mne
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm import tqdm

# NOTE(review): this silences every warning, including deprecation notices from MNE/PyTorch.
warnings.filterwarnings("ignore")

# Device setup: use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs behind DataLoader shuffling, weight initialisation and dropout so a
# Restart & Run All reproduces the same numbers (torch.manual_seed also seeds CUDA).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Global parameters (you can adjust these later); process_edf shows how each is used.
T = 160      # samples per window (1 s if the recordings are sampled at 160 Hz — verify)
delta = 4    # step, in samples, between the starts of consecutive windows
Delta = 8    # step, in windows, between the starts of consecutive augmented samples
Gamma = 30   # number of consecutive windows stacked into one augmented sample
channels = ['Oz..', 'T7..', 'Cz..']  # EDF channel labels passed to raw.pick (exact names as stored in the files)


In [None]:
def _stack_windows(arr, size, step):
    """Stack ``arr[i:i + size]`` for ``i = 0, step, 2*step, ...`` along a new axis 0.

    Only windows lying entirely inside ``arr`` are kept. If none fit, an empty float32
    array of shape ``(0, size, *arr.shape[1:])`` is returned, so the result always has
    rank ``arr.ndim + 1`` (``np.array([])`` would collapse to shape ``(0,)``).
    """
    starts = range(0, len(arr) - size + 1, step)
    if len(starts) == 0:
        return np.empty((0, size) + arr.shape[1:], dtype=np.float32)
    return np.stack([arr[s:s + size] for s in starts]).astype(np.float32, copy=False)


def process_edf(file_path, channels, T=160, delta=4, Delta=8, Gamma=30):
    """Load one EDF recording and cut it into overlapping augmented samples.

    The selected channels are min-max scaled using this recording's own per-channel
    range. The signal is cut into windows of ``T`` samples taken every ``delta`` samples.
    Runs of ``Gamma`` consecutive windows, taken every ``Delta`` windows, are then stacked
    into augmented samples.

    Args:
        file_path: path to the ``.edf`` file.
        channels: channel names passed to ``raw.pick``.
        T: samples per window.
        delta: step between window starts, in samples.
        Delta: step between augmented-sample starts, in windows.
        Gamma: windows per augmented sample.

    Returns:
        float32 array of shape ``(n_augmented, Gamma, T, n_channels)``. ``n_augmented``
        is 0 (with the same trailing shape) when the recording is too short.

    Errors raised by MNE (unreadable file, missing channel, ...) propagate to the caller.
    """
    raw = mne.io.read_raw_edf(file_path, preload=True, verbose=False)
    raw.pick(channels)
    # get_data() is channel-major; transpose to time-major (n_samples, n_channels).
    data = raw.get_data().T.astype(np.float32)

    # Normalize each channel (column) to [0, 1] over this recording.
    scaler = MinMaxScaler()
    normalized_data = scaler.fit_transform(data).astype(np.float32)

    # (n_windows, T, n_channels)
    windows = _stack_windows(normalized_data, T, delta)
    # Create augmented samples by grouping Gamma consecutive windows with a step of Delta:
    # (n_augmented, Gamma, T, n_channels)
    return _stack_windows(windows, Gamma, Delta)


In [None]:
class PhysionetAugmentedDataset(Dataset):
    """Subject-labelled augmented EEG samples built from per-run EDF files.

    For each subject in ``subject_ids`` and each run 1..14, the file
    ``./files/S{subject:03d}/S{subject:03d}R{run:02d}.edf`` is read once. Its
    ``channels`` are min-max scaled over the recording, and the normalized signal is
    kept in memory (``n_samples * n_channels * 4`` bytes per run).

    Item ``i`` equals ``process_edf(file_path, ...)[aug_idx]`` for the entry
    ``self.indices[i] == (subject, run, aug_idx, file_path, label)``. It is sliced
    directly from the cached signal rather than re-reading and re-windowing the whole
    file on every access.

    Labels are the 0-based position of the subject in ``subject_ids``. Two datasets
    built from different subject lists therefore reuse the same label values for
    different people.

    Runs that cannot be read are reported and skipped. Runs too short to yield a single
    augmented sample contribute no items.
    """

    def __init__(self, subject_ids, channels, T=160, delta=4, Delta=8, Gamma=30):
        self.subject_ids = subject_ids
        self.channels = channels
        self.T = T
        self.delta = delta
        self.Delta = Delta
        self.Gamma = Gamma
        self.indices = []
        # file_path -> normalized float32 signal of shape (n_samples, n_channels).
        self._signals = {}

        for label, subject in enumerate(tqdm(subject_ids, desc="Building index")):
            for r in range(1, 15):
                file_path = f'./files/S{subject:03d}/S{subject:03d}R{r:02d}.edf'
                try:
                    signal = self._load_signal(file_path, channels)
                except Exception as e:
                    # Best effort: a missing/unreadable run is reported and skipped.
                    print(f"Error processing {file_path}: {e}")
                    continue
                num_augmented = self._num_augmented(len(signal), T, delta, Delta, Gamma)
                if num_augmented == 0:
                    continue
                self._signals[file_path] = signal
                for aug_idx in range(num_augmented):
                    self.indices.append((subject, r, aug_idx, file_path, label))

    @staticmethod
    def _load_signal(file_path, channels):
        """Read ``channels`` from an EDF file and min-max scale each to [0, 1].

        Mirrors the preprocessing at the top of ``process_edf``; keep the two in sync.
        """
        raw = mne.io.read_raw_edf(file_path, preload=True, verbose=False)
        raw.pick(channels)
        data = raw.get_data().T.astype(np.float32)
        return MinMaxScaler().fit_transform(data).astype(np.float32)

    @staticmethod
    def _num_augmented(n_samples, T, delta, Delta, Gamma):
        """Number of augmented samples the windowing in ``process_edf`` yields for ``n_samples`` rows."""
        n_windows = len(range(0, n_samples - T + 1, delta))
        return len(range(0, n_windows - Gamma + 1, Delta))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        _, _, aug_idx, file_path, label = self.indices[idx]
        signal = self._signals[file_path]
        # Augmented sample aug_idx stacks windows aug_idx*Delta .. aug_idx*Delta + Gamma - 1;
        # window w covers signal rows w*delta .. w*delta + T - 1.
        window_starts = (aug_idx * self.Delta + np.arange(self.Gamma)) * self.delta
        rows = window_starts[:, None] + np.arange(self.T)  # (Gamma, T) row indices
        sample = signal[rows]                               # (Gamma, T, n_channels)
        sample_tensor = torch.tensor(sample, dtype=torch.float32)
        return sample_tensor, label


In [None]:
# Define subject groups:
alpha_subjects = list(range(1, 91))    # Subjects the 90-way classifier learns to identify
beta_subjects = list(range(91, 110))   # NOTE(review): unseen subjects; unused by the closed-set test below

# Closed-set identification: CNN has one output per alpha subject, and dataset labels are
# each subject's position in the list given to the dataset. A dataset built from
# beta_subjects would label subjects 91..109 as 0..18, i.e. as training subjects 1..19,
# so its accuracy would be meaningless. Instead, hold out whole runs of the alpha
# subjects. Holding out whole runs (not random samples) matters because windows and
# augmented samples overlap within a run, and a per-sample split would leak
# near-duplicates into the test set.
test_runs = {13, 14}  # adjustable

alpha_dataset = PhysionetAugmentedDataset(alpha_subjects, channels, T, delta, Delta, Gamma)
train_idx = [i for i, (_, run, *_) in enumerate(alpha_dataset.indices) if run not in test_runs]
test_idx = [i for i, (_, run, *_) in enumerate(alpha_dataset.indices) if run in test_runs]

# Create datasets:
train_dataset = Subset(alpha_dataset, train_idx)
test_dataset = Subset(alpha_dataset, test_idx)

# Create DataLoaders:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0)


In [None]:
class CNN(nn.Module):
    """Three conv blocks plus two dense layers mapping stacked EEG windows to class logits.

    Input: ``(batch, n_windows, window_len, in_channels)``. This is the layout of batched
    ``process_edf`` samples ``(Gamma, T, n_channels)``; ``forward`` moves the channel
    axis to dim 1 for ``Conv2d``.

    Args:
        in_channels: size of the last input axis (EEG channels). Default 3.
        n_classes: number of output logits. Default 90.
        n_windows: size of input axis 1. Default 30. Must be >= 2 (halved once).
        window_len: size of input axis 2. Default 160. Must be >= 8 (halved three times).
        dropout: dropout probability after the first dense layer. Default 0.5.

    The defaults reproduce the original layer sizes (dense1 input 512 * 15 * 20 = 153600).
    """

    def __init__(self, in_channels=3, n_classes=90, n_windows=30, window_len=160, dropout=0.5):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        # After conv1 and conv2: halve only the time axis. After conv3: halve both axes.
        self.maxpool = nn.MaxPool2d(kernel_size=(1, 2))
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2))
        self.flatten = nn.Flatten()
        # Padded 3x3 convs keep H x W; the pools give H -> H // 2 and W -> W // 8.
        flat_features = 512 * (n_windows // 2) * (window_len // 8)
        self.dense1 = nn.Linear(flat_features, 1024)
        self.dense2 = nn.Linear(1024, n_classes)
        # The input here is 2-D (batch, features); Dropout2d on 2-D input is deprecated and
        # acts element-wise anyway, so plain Dropout is the equivalent, supported layer.
        self.dropout = nn.Dropout(dropout)  # hyperparameter

    def forward(self, x):
        # (batch, n_windows, window_len, in_channels) -> (batch, in_channels, n_windows, window_len)
        x = x.permute(0, 3, 1, 2)
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))
        x = self.maxpool2(F.relu(self.conv3(x)))

        x = self.flatten(x)

        x = self.dropout(F.relu(self.dense1(x)))
        return self.dense2(x)
    
model = CNN()
model = model.to(device)  # Module.to moves parameters in place and returns the same module

In [None]:
# Cross-entropy on raw logits, optimised with Adam at a fixed learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 10  # Adjust as needed

for epoch in range(num_epochs):
    model.train()  # enable dropout
    epoch_loss = 0.0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for inputs, targets in progress:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Mean of per-batch losses over the epoch.
    mean_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")

In [None]:
# Top-1 accuracy on the test loader (dropout off, no gradient tracking).
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch_data, batch_labels in tqdm(test_loader, desc="Evaluating"):
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)

        outputs = model(batch_data)
        predicted = outputs.argmax(dim=1)  # index of the largest logit = predicted class
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

# Guard against an empty test set (e.g. every file failed to load) instead of dividing by zero.
if total:
    print(f"Test Accuracy: {100 * correct / total:.2f}%")
else:
    print("Test Accuracy: n/a (test set is empty)")