In [2]:
import os
import mne
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import warnings
import random

warnings.filterwarnings("ignore")

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Global parameters (you can adjust these later)
T = 160
delta = 4
Delta = 8
Gamma = 30
channels = ['Oz..', 'T7..', 'Cz..']


In [3]:
def process_edf(file_path, channels, T=160, delta=4, Delta=8, Gamma=30):
    raw = mne.io.read_raw_edf(file_path, preload=True, verbose=False)
    raw.pick(channels)
    data = raw.get_data().T.astype(np.float32)
    
    # Normalize the data.
    scaler = MinMaxScaler()
    normalized_data = scaler.fit_transform(data).astype(np.float32)
    
    n_samples = normalized_data.shape[0]
    windows = [normalized_data[start:start+T, :] 
               for start in range(0, n_samples - T + 1, delta)]
    windows = np.array(windows, dtype=np.float32)
    
    # Create augmented samples by grouping Gamma consecutive windows with a step of Delta.
    augmented = [windows[i:i+Gamma] 
                 for i in range(0, len(windows) - Gamma + 1, Delta)]
    augmented = np.array(augmented, dtype=np.float32)
    
    return augmented


In [None]:
class PhysionetAugmentedDataset(Dataset):
    def __init__(self, subject_ids, channels, T=160, delta=4, Delta=8, Gamma=30):
        self.subject_ids = subject_ids
        self.channels = channels
        self.T = T
        self.delta = delta
        self.Delta = Delta
        self.Gamma = Gamma
        self.indices = []
        
        for label, subject in enumerate(tqdm(subject_ids, desc="Building index")):
            for r in range(1, 3):
                file_path = f'./files/S{subject:03d}/S{subject:03d}R{r:02d}.edf'
                try:
                    augmented = process_edf(file_path, channels, T, delta, Delta, Gamma)
                except Exception as e:
                    print(f"Error processing {file_path}: {e}")
                    continue
                num_augmented = augmented.shape[0]
                for aug_idx in range(num_augmented):
                    self.indices.append((subject, r, aug_idx, file_path, label))
    
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        subject, r, aug_idx, file_path, label = self.indices[idx]
        augmented = process_edf(file_path, self.channels, self.T, self.delta, self.Delta, self.Gamma)
        sample = augmented[aug_idx]
        sample_tensor = torch.tensor(sample, dtype=torch.float32)
        return sample_tensor, label
    
subjects = list(range(1, 91))
dataset = PhysionetAugmentedDataset(subjects, channels, T, delta, Delta, Gamma)


Building index: 100%|██████████| 90/90 [00:05<00:00, 17.92it/s]


In [5]:
# Define subject groups:
all_indices = list(range(len(dataset)))
labels = [dataset.indices[i][-1] for i in all_indices]  # label is the last element in the index tuple

# Perform a stratified split (80% train, 20% test)
train_idx, test_idx = train_test_split(all_indices, test_size=0.2, stratify=labels, random_state=42)

# Create subset datasets using the indices.
train_dataset = Subset(dataset, train_idx)
test_dataset = Subset(dataset, test_idx)

# Create DataLoaders for both splits.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)


In [None]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.maxpool = nn.MaxPool2d(kernel_size=(1,2))
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2,2))
        self.flatten = nn.Flatten()
        self.dense1 = nn.Linear(153600, 1024)
        self.dense2 = nn.Linear(1024, 90)
        self.dropout = nn.Dropout2d(0.5) #hyperparameter

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)
        x = F.relu(self.conv1(x))
        x = self.maxpool(x)
        
        x = F.relu(self.conv2(x))
        x = self.maxpool(x)

        x = F.relu(self.conv3(x))
        x = self.maxpool2(x)
        
        x = self.flatten(x)

        x = F.relu(self.dense1(x))
        x = self.dropout(x)
        x = self.dense2(x)
        return x
    
model = CNN().to(device)

In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

num_epochs = 20  # Adjust as needed

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    all_train_preds = []
    all_train_labels = []

    for batch_data, batch_labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(batch_data)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()

        preds = outputs.argmax(dim=1)
        all_train_preds.extend(preds.cpu().numpy())
        all_train_labels.extend(batch_labels.cpu().numpy())
    
    train_accuracy = accuracy_score(all_train_labels, all_train_preds)
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}, Train Accuracy = {train_accuracy:.4f}")

Epoch 1/20: 100%|██████████| 668/668 [05:10<00:00,  2.15it/s]


Epoch 1: Loss = 3.1628, Train Accuracy = 0.2207


Epoch 2/20: 100%|██████████| 668/668 [05:11<00:00,  2.14it/s]


Epoch 2: Loss = 1.9437, Train Accuracy = 0.4610


Epoch 3/20: 100%|██████████| 668/668 [05:13<00:00,  2.13it/s]


Epoch 3: Loss = 1.5204, Train Accuracy = 0.5532


Epoch 4/20: 100%|██████████| 668/668 [05:13<00:00,  2.13it/s]


Epoch 4: Loss = 1.2239, Train Accuracy = 0.6283


Epoch 5/20: 100%|██████████| 668/668 [05:13<00:00,  2.13it/s]


Epoch 5: Loss = 0.9922, Train Accuracy = 0.6884


Epoch 6/20: 100%|██████████| 668/668 [05:13<00:00,  2.13it/s]


Epoch 6: Loss = 0.8074, Train Accuracy = 0.7420


Epoch 7/20: 100%|██████████| 668/668 [05:14<00:00,  2.12it/s]


Epoch 7: Loss = 0.6917, Train Accuracy = 0.7752


Epoch 8/20: 100%|██████████| 668/668 [05:17<00:00,  2.10it/s]


Epoch 8: Loss = 0.5694, Train Accuracy = 0.8145


Epoch 9/20: 100%|██████████| 668/668 [05:17<00:00,  2.10it/s]


Epoch 9: Loss = 0.4778, Train Accuracy = 0.8409


Epoch 10/20: 100%|██████████| 668/668 [05:16<00:00,  2.11it/s]


Epoch 10: Loss = 0.4145, Train Accuracy = 0.8607


Epoch 11/20: 100%|██████████| 668/668 [05:16<00:00,  2.11it/s]


Epoch 11: Loss = 0.3492, Train Accuracy = 0.8826


Epoch 12/20: 100%|██████████| 668/668 [05:16<00:00,  2.11it/s]


Epoch 12: Loss = 0.3011, Train Accuracy = 0.8996


Epoch 13/20: 100%|██████████| 668/668 [05:12<00:00,  2.13it/s]


Epoch 13: Loss = 0.2642, Train Accuracy = 0.9101


Epoch 14/20: 100%|██████████| 668/668 [05:13<00:00,  2.13it/s]


Epoch 14: Loss = 0.2415, Train Accuracy = 0.9198


Epoch 15/20: 100%|██████████| 668/668 [05:13<00:00,  2.13it/s]


Epoch 15: Loss = 0.1945, Train Accuracy = 0.9345


Epoch 16/20: 100%|██████████| 668/668 [05:12<00:00,  2.14it/s]


Epoch 16: Loss = 0.1768, Train Accuracy = 0.9415


Epoch 17/20: 100%|██████████| 668/668 [05:13<00:00,  2.13it/s]


Epoch 17: Loss = 0.1555, Train Accuracy = 0.9497


Epoch 18/20: 100%|██████████| 668/668 [05:14<00:00,  2.12it/s]


Epoch 18: Loss = 0.1608, Train Accuracy = 0.9479


Epoch 19/20: 100%|██████████| 668/668 [05:15<00:00,  2.11it/s]


Epoch 19: Loss = 0.1330, Train Accuracy = 0.9550


Epoch 20/20: 100%|██████████| 668/668 [05:16<00:00,  2.11it/s]

Epoch 20: Loss = 0.1204, Train Accuracy = 0.9604





In [11]:
model.eval()
all_test_preds = []
all_test_labels = []
with torch.no_grad():
    for batch_data, batch_labels in tqdm(test_loader, desc="Evaluating"):
        batch_data = batch_data.to(device)
        batch_labels = batch_labels.to(device)
        outputs = model(batch_data)
        preds = outputs.argmax(dim=1)
        all_test_preds.extend(preds.cpu().numpy())
        all_test_labels.extend(batch_labels.cpu().numpy())
    
    test_accuracy = accuracy_score(all_test_labels, all_test_preds)
    print(f"Epoch {epoch+1}: Test Accuracy = {test_accuracy:.4f}")

Evaluating: 100%|██████████| 167/167 [01:15<00:00,  2.22it/s]

Epoch 20: Test Accuracy = 0.9079





In [None]:
model.load_state_dict(torch.load)
model.eval()

embedding_output = None
def hook_fn(module, input, output):
    global embedding_output
    embedding_output = output.detach()

hook_handle = model.dropout.register_forward_hook(hook_fn)

def get_embedding(sample):
    global embedding_output
    if sample.dim() == 3:
        sample = sample.unsqueeze(0)
    sample = sample.to(device)

    with torch.no_grad():
        _ = model(sample)

    emb = embedding_output.unsqueeze(0)
    norm = torch.norm(emb, p=2)
    if norm > 0:
        emb = emb / norm
    return emb.cpu().numpy()

In [None]:
enrolled_subjects = {
    'subject1': [sample1, sample2, sample3],
    'subject2': [sample1, sample2, sample3],
    # Add additional subjects as needed.
}

subject_fingerprints = {}
for subject_id, samples in enrolled_subjects.items():
    embeddings = [get_embedding(sample) for sample in samples]
    embeddings = np.stack(embeddings, axis=0)
    # Average embeddings across samples to create a fingerprint.
    avg_embedding = np.mean(embeddings, axis=0)
    # Normalize the fingerprint.
    avg_embedding = avg_embedding / np.linalg.norm(avg_embedding)
    subject_fingerprints[subject_id] = avg_embedding

In [None]:
def classify_sample(sample, subject_fingerprints, threshold=0.275):
    """
    Classify a new EEG sample by comparing its embedding to enrolled subject fingerprints.
    The classification uses cosine distance with a threshold margin.
    
    Args:
        sample (torch.Tensor): New EEG sample tensor.
        subject_fingerprints (dict): Mapping of subject IDs to fingerprint embeddings.
        threshold (float): Threshold margin for cosine distance.
    
    Returns:
        tuple: (predicted_subject, similarity_scores, cosine_distances)
               predicted_subject is the subject if the minimum cosine distance is below threshold,
               otherwise it is set to 'dissimilar' (or 'unknown').
    """
    emb = get_embedding(sample)
    similarities = {}
    cosine_distances = {}
    
    # Compute cosine similarity and derive cosine distance.
    for subject_id, fingerprint in subject_fingerprints.items():
        similarity = np.dot(emb, fingerprint)
        distance = 1 - similarity  # cosine distance for normalized vectors.
        similarities[subject_id] = similarity
        cosine_distances[subject_id] = distance
    
    # Identify the subject with the minimum cosine distance.
    best_subject = min(cosine_distances, key=cosine_distances.get)
    best_distance = cosine_distances[best_subject]
    
    # If the best distance is less than the threshold, we consider it a match.
    if best_distance < threshold:
        predicted_subject = best_subject
    else:
        predicted_subject = "dissimilar"  # or "unknown"
    
    return predicted_subject, similarities, cosine_distances

In [None]:
new_sample = ...  # your new EEG sample tensor here
predicted_subject, similarity_scores = classify_sample(new_sample, subject_fingerprints)
print("Predicted Subject:", predicted_subject)
print("Cosine Similarity Scores:", similarity_scores)

# --- 8. Cleanup: Remove the Hook When Done ---
hook_handle.remove()

In [None]:
subject_fingerprints_new = {}
enrollment_runs = range(1, 4)  # use runs 1, 2, 3 for enrollment

for subject in range(91, 110):  # subjects 91 to 109
    embeddings = []
    for r in enrollment_runs:
        file_path = f'./files/S{subject:03d}/S{subject:03d}R{r:02d}.edf'
        try:
            augmented = process_edf(file_path, channels, T, delta, Delta, Gamma)
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            continue
        # For each augmented sample, extract an embedding.
        for i in range(augmented.shape[0]):
            sample = augmented[i]  # shape: (Gamma, T, n_channels)
            # Convert to torch.Tensor
            sample_tensor = torch.tensor(sample, dtype=torch.float32)
            emb = get_embedding(sample_tensor)
            embeddings.append(emb)
    if len(embeddings) == 0:
        print(f"No valid embeddings for subject {subject}")
        continue
    embeddings = np.stack(embeddings, axis=0)
    avg_embedding = np.mean(embeddings, axis=0)
    avg_embedding = avg_embedding / np.linalg.norm(avg_embedding)
    subject_fingerprints_new[subject] = avg_embedding

print("Enrollment complete for subjects 91-109.")

# --- Validation ---
# We use runs 4-6 for validation.
validation_runs = range(4, 7)
results = {}

for subject in range(91, 110):
    subject_results = []
    for r in validation_runs:
        file_path = f'./files/S{subject:03d}/S{subject:03d}R{r:02d}.edf'
        try:
            augmented = process_edf(file_path, channels, T, delta, Delta, Gamma)
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            continue
        for i in range(augmented.shape[0]):
            sample = augmented[i]  # shape: (Gamma, T, n_channels)
            sample_tensor = torch.tensor(sample, dtype=torch.float32)
            # classify_sample returns (predicted_subject, similarities, cosine_distances)
            predicted_subject, similarities, cosine_distances = classify_sample(sample_tensor, subject_fingerprints_new, threshold=0.275)
            distance = cosine_distances.get(predicted_subject, None)
            subject_results.append({
                'run': r,
                'aug_idx': i,
                'predicted': predicted_subject,
                'distance': distance
            })
            print(f"Subject {subject:03d} Run {r:02d} Aug {i}: predicted = {predicted_subject}, cosine distance = {distance:.4f}")
    results[subject] = subject_results

# --- Optional Accuracy Calculation ---
# If the ground truth for validation is that the enrolled subject should match, compute accuracy.
correct = 0
total = 0
for subject in results:
    for res in results[subject]:
        total += 1
        if res['predicted'] == subject:
            correct += 1

print(f"Validation Accuracy: {correct/total:.2%} ({correct}/{total})")