# CRL Pre-Training with Linear Evaluation

#### Importing Dependencies

In [4]:
import os
import numpy as np
import glob
import logging
from typing import Dict, Optional

# Container for per-split, per-class EEG epochs: eeg_data[split][label] -> list of epochs.
eeg_data = {
    split: {label: [] for label in range(5)}
    for split in ('train', 'validation', 'test')
}

dataset_path = "./dset/Sleep-EDF-2018/npz/Fpz-Cz"
num_files_to_process = 200  # set to None to use every recording

npz_files = sorted(glob.glob(os.path.join(dataset_path, '*.npz')))
if not npz_files:
    raise FileNotFoundError(f"No .npz files found in {dataset_path}.")
if num_files_to_process is not None:
    npz_files = npz_files[:num_files_to_process]

# Extract subject indices from filenames: characters [3:5] of the basename are
# parsed as an int (e.g. "SC4001E0.npz" -> 0). Presumably these two digits are the
# Sleep-EDF subject id, so several recordings of one subject share an index and
# stay together in a subject-wise split — confirm against the dataset naming.
subject_indices = [int(os.path.basename(npz_file)[3:5]) for npz_file in npz_files]

# Sorted so the order handed to the (shuffling) train/validation/test split is
# deterministic rather than dependent on set iteration order.
unique_subject_indices = sorted(set(subject_indices))

In [6]:
# Number of distinct subject indices among the processed recordings.
len(unique_subject_indices)

78

In [None]:


# Split subject indices (not individual epochs) into train / validation / test so
# that all recordings of a subject land in the same split.
from sklearn.model_selection import train_test_split

# `logger` is used below but was never defined in this notebook; create it here.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

SPLIT_SEED = 42  # fixed seed so the subject split is reproducible across runs

train_subjects, temp_subjects = train_test_split(
    unique_subject_indices, test_size=0.2, random_state=SPLIT_SEED
)
validation_subjects, test_subjects = train_test_split(
    temp_subjects, test_size=0.5, random_state=SPLIT_SEED
)

logger.info(f"Subjects split into train ({len(train_subjects)}), validation ({len(validation_subjects)}), test ({len(test_subjects)}).")

# Subject -> split lookup. setdefault keeps the original train > validation > test
# precedence, and lookups are O(1) instead of scanning three lists per file.
subject_to_split = {}
for set_name, subjects in (
    ('train', train_subjects),
    ('validation', validation_subjects),
    ('test', test_subjects),
):
    for subject in subjects:
        subject_to_split.setdefault(subject, set_name)

# Re-create the container so re-running this cell neither appends duplicate epochs
# nor calls `.extend` on the arrays produced by a previous run.
eeg_data = {
    split: {label: [] for label in range(5)}
    for split in ('train', 'validation', 'test')
}

# Process files and assign to corresponding sets
for idx, npz_file in enumerate(npz_files, 1):
    try:
        basename = os.path.basename(npz_file)
        subject_idx = int(basename[3:5])

        # Determine the set for the current subject
        set_name = subject_to_split.get(subject_idx)
        if set_name is None:
            logger.warning(f"Subject index {subject_idx} not found in any set.")
            continue

        with np.load(npz_file) as data:
            eeg_epochs, labels = data['x'], data['y']
            for label in range(5):
                eeg_data[set_name][label].extend(eeg_epochs[labels == label])
    except Exception as e:
        # Best-effort loading: record the failure (with traceback) and move on.
        logger.exception(f"Error processing {npz_file}: {e}")
    if idx % 10 == 0 or idx == len(npz_files):
        logger.info(f"Processed {idx}/{len(npz_files)} files.")

# Convert lists to numpy arrays
for set_name, per_label in eeg_data.items():
    for label in per_label:
        per_label[label] = np.array(per_label[label])

# Show a compact summary (epochs per split and class) instead of dumping every array.
{split: {label: len(epochs) for label, epochs in per_label.items()} for split, per_label in eeg_data.items()}

In [1]:
from utils import *

In [2]:
# Load the Fpz-Cz recordings via the project helper imported from `utils`.
# The result is indexed as eeg_data[split][label] with splits 'train' /
# 'validation' / 'test' and labels 0-4 (see the counts printed by the next cell).
# NOTE(review): the second argument presumably caps the number of files, like
# num_files_to_process above — confirm in utils.load_eeg_data.
eeg_data = load_eeg_data("./dset/Sleep-EDF-2018/npz/Fpz-Cz", 200)

In [9]:
# Report how many epochs each split holds for every class label 0-4.
for split_name, per_class in eeg_data.items():
    for stage in range(5):
        print(f"{split_name} {stage}: {len(per_class[stage])}")

train 0: 55834
train 1: 16840
train 2: 53972
train 3: 10284
train 4: 20483
validation 0: 7091
validation 1: 2410
validation 2: 7889
validation 3: 1039
validation 4: 2563
test 0: 6607
test 1: 2141
test 2: 6790
test 3: 1677
test 4: 2669


In [2]:
import glob
import json
import random
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import os
# NOTE(review): OpenMP/MKL runtimes generally read OMP_NUM_THREADS when they are
# first loaded. numpy is already imported above in this cell (and in earlier
# cells), so this may not take effect for it; for a reliable effect, set it in the
# first cell of the notebook before any numeric library is imported.
os.environ['OMP_NUM_THREADS'] = '4'  # Set to avoid memory leak warning on Windows with MKL

import umap
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


from sklearn.manifold import TSNE, trustworthiness
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    silhouette_score, davies_bouldin_score, adjusted_rand_score,
    adjusted_mutual_info_score
)

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report

from scipy import signal
from scipy.ndimage import shift
from scipy.stats import entropy
from scipy.optimize import linear_sum_assignment


import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 
warnings.filterwarnings("ignore", category=UserWarning) 

### Loading EEG Data

In [1]:
# config_path = 'config.json'
# with open(config_path, 'r') as f:
#     config = json.load(f)

# dset_cfg = config['dataset']
# root_dir = dset_cfg['root_dir']
# dset_name = dset_cfg['name']
# eeg_channel = dset_cfg['eeg_channel']
# num_splits = dset_cfg['num_splits']
# seq_len = dset_cfg['seq_len']
# target_idx = dset_cfg['target_idx']

# print("Configuration Loaded:")
# print(json.dumps(dset_cfg, indent=4))

import glob
import json
import random
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import os

root_dir = "./"
dset_name = "Sleep-EDF-2018"
eeg_channel = "Fpz-Cz"
dataset_path = os.path.join(root_dir, 'dset', dset_name, 'npz', eeg_channel)
npz_files = sorted(glob.glob(os.path.join(dataset_path, '*.npz')))

print(f"Found {len(npz_files)} npz files in {dataset_path}.")
# Fail with a clear message instead of an IndexError on npz_files[0] below.
if not npz_files:
    raise FileNotFoundError(f"No .npz files found in {dataset_path}.")

# Inspect a single recording to sanity-check array shapes and class balance.
sample_file = npz_files[0]
# np.load on an .npz returns an NpzFile that holds the file open; the context
# manager closes it once the arrays have been read.
with np.load(sample_file) as data:
    x = data['x']
    y = data['y']
print(f"Loaded data from {sample_file}:")
print("EEG data shape (epochs, samples):", x.shape)
print("Labels shape:", y.shape)

unique, counts = np.unique(y, return_counts=True)

# % of each class:
print("Percentage of each class :")
print(np.asarray((unique, counts*100/np.sum(counts))).T)

fs = 100
epoch_length = 30
samples_per_epoch = fs * epoch_length  # expected to equal x.shape[1]
num_classes = 5
epochs_per_class = 10

# How many recordings to pool into `eeg_data`. This replaces a hard-coded debug
# `break` that stopped after the first file (the default keeps that behaviour);
# set to None to load every recording — all epochs are then held in memory.
max_files_to_load = 1
files_to_load = npz_files if max_files_to_load is None else npz_files[:max_files_to_load]

# Collect one array chunk per file and label, then concatenate once, instead of
# extending a Python list one epoch row at a time.
eeg_chunks = {label: [] for label in range(num_classes)}

for file_idx, file in enumerate(files_to_load, start=1):
    with np.load(file) as data:
        x = data['x']  # EEG epochs
        y = data['y']  # Corresponding labels

        for label in range(num_classes):
            eeg_chunks[label].append(x[y == label])

    if file_idx % 10 == 0 or file_idx == len(files_to_load):
        print(f"Processed {file_idx}/{len(files_to_load)} files.")

# Pooled epochs per label (keys 0..num_classes-1), each an array of epoch rows.
eeg_data = {}
for label, chunks in eeg_chunks.items():
    eeg_data[label] = np.concatenate(chunks, axis=0) if chunks else np.empty((0, samples_per_epoch))
    print(f"#EEG Epochs for Label {label} : {len(eeg_data[label])}")
del eeg_chunks  # release the per-file copies now that they are concatenated

Found 152 npz files in ./dset\Sleep-EDF-2018\npz\Fpz-Cz.
Loaded data from ./dset\Sleep-EDF-2018\npz\Fpz-Cz\SC4001E0.npz:
EEG data shape (epochs, samples): (841, 3000)
Labels shape: (841,)
Percentage of each class :
[[ 0.         22.35434007]
 [ 1.          6.89655172]
 [ 2.         29.72651605]
 [ 3.         26.15933413]
 [ 4.         14.86325803]]
#EEG Epochs for Label 0 : 188
#EEG Epochs for Label 1 : 58
#EEG Epochs for Label 2 : 250
#EEG Epochs for Label 3 : 220
#EEG Epochs for Label 4 : 125


## CRL Pre-Training of Encoder

### Data Augmentations

In [2]:
from data_augmentations import (
    RandomAmplitudeScale,
    RandomDCShift,
    RandomTimeShift,
    RandomZeroMasking,
    RandomAdditiveGaussianNoise,
    RandomBandStopFilter,
    TimeWarping,
    TimeReverse,
    Permutation,
    CutoutResize,
    TailoredMixup,
    AverageFilter,
    SignFlip
)



# class RandomAmplitudeScale:

#     def __init__(self, range=(0.5, 2.0), p=0.5):
#         self.range = range
#         self.p = p
#         self.requires_x_random = False

#     def __call__(self, x, x_random=None):
#         if torch.rand(1) < self.p:
#             scale = random.uniform(self.range[0], self.range[1])
#             # print(f"Scaled by {scale}")
#             return x * scale
#         return x
    

# class RandomDCShift:
    
#     def __init__(self, range=(-10.0, 10.0), p=0.5):
#         self.range = range
#         self.p = p
#         self.requires_x_random = False


#     def __call__(self, x, x_random=None):
#         if torch.rand(1) < self.p:
#             shift_value = random.uniform(self.range[0], self.range[1])
#             # print(f"Shifted by {shift_value}")
#             return x + shift_value
#         return x
    
# class RandomTimeShift:

#     def __init__(self, range=(-300, 300), mode='constant', cval=0.0, p=0.5):
#         self.range = range
#         self.mode = mode
#         self.cval = cval
#         self.p = p
#         self.requires_x_random = False


#     def __call__(self, x, x_random=None):
#         if torch.rand(1) < self.p:
#             t_shift = random.randint(self.range[0], self.range[1])
#             x_shifted = shift(input=x, shift=t_shift, mode=self.mode, cval=self.cval)
#             # print(f"Time Shifted by : {t_shift}")
#             return x_shifted
#         return x
    

# class RandomZeroMasking:
        
#     def __init__(self, range=(0, 300), p=0.5):
#         self.range = range
#         self.p = p
#         self.requires_x_random = False


#     def __call__(self, x, x_random=None):
#         if torch.rand(1) < self.p:
#             mask_len = random.randint(self.range[0], self.range[1])
#             # print(f"Mask Length : {mask_len}")
#             random_pos = random.randint(0, len(x) - mask_len)
#             # print(f"Position : {random_pos}")
#             mask = np.ones_like(x)
#             mask[random_pos:random_pos + mask_len] = 0
#             return x * mask
#         return x
    

# class RandomAdditiveGaussianNoise:
        
#     def __init__(self, range=(0.0, 0.2), p=0.5):
#         self.range = range
#         self.p = p
#         self.requires_x_random = False


#     def __call__(self, x, x_random=None):
#         if torch.rand(1) < self.p:
#             sigma = random.uniform(self.range[0], self.range[1])
#             noise = np.random.normal(0, sigma, x.shape)
#             # print(f"Gaussian Noise std = {sigma}")
#             return x + noise
#         return x
    

# class RandomBandStopFilter:

#     def __init__(self, range=(0.5, 30.0), band_width=2.0, sampling_rate=100.0, p=0.5):
#         self.range = range
#         self.band_width = band_width
#         self.sampling_rate = sampling_rate
#         self.p = p
#         self.requires_x_random = False


#     def __call__(self, x, x_random=None):
#         if torch.rand(1) < self.p:
#             low_freq = random.uniform(self.range[0], self.range[1] - self.band_width)
#             center_freq = low_freq + self.band_width / 2.0
#             # print(f"Central Freq : {center_freq}")
#             b, a = signal.iirnotch(center_freq, Q=center_freq / self.band_width, fs=self.sampling_rate)
#             x_filtered = signal.lfilter(b, a, x)
#             return x_filtered
#         return x
    

# class TimeWarping:
    
#     def __init__(self, n_segments=4, scale_range=(0.5, 2.0), p=0.5):
#         self.n_segments = n_segments
#         self.scale_range = scale_range
#         self.p = p
#         self.requires_x_random = False

        
#     def __call__(self, x, x_random=None):
#         if torch.rand(1) < self.p:
#             L = len(x) #3000

#             segment_length = L // self.n_segments
#             segments = []
#             for i in range(self.n_segments):
#                 start = i * segment_length
#                 end = start + segment_length if i < self.n_segments -1 else L
#                 Si = x[start:end]


#                 omega = random.uniform(self.scale_range[0], self.scale_range[1])
#                 new_length = int(len(Si) * omega)

#                 if new_length < 1:
#                     new_length = 1

#                 Si_transformed = signal.resample(Si, new_length)
#                 segments.append(Si_transformed)
                
#             x_aug = np.concatenate(segments)
#             x_aug = signal.resample(x_aug, L)
#             return x_aug
        
#         else:
#             return x
        

# class TimeReverse:
    
#     def __init__(self, p=0.5):
#         self.p = p
#         self.requires_x_random = False

        
#     def __call__(self, x, x_random=None):
#         if torch.rand(1) < self.p:
#             return np.flip(x).copy()
#         else:
#             return x
        

# class Permutation:
    
#     def __init__(self, n_segments=4, p=0.5):
#         self.n_segments = n_segments
#         self.p = p
#         self.requires_x_random = False

        
#     def __call__(self, x, x_random=None):
#         if torch.rand(1) < self.p:
#             L = len(x)

#             segment_length = L // self.n_segments
#             segments = []
#             indices = []

#             for i in range(self.n_segments):
#                 start = i * segment_length
#                 end = start + segment_length if i < self.n_segments -1 else L
#                 Si = x[start:end]
#                 segments.append(Si)
#                 indices.append(i)

#             random.shuffle(indices)
#             shuffled_segments = [segments[i] for i in indices]
#             x_aug = np.concatenate(shuffled_segments)
#             return x_aug
        
#         else:
#             return x
        

# class CutoutResize:
    
#     def __init__(self, n_segments=4, p=0.5):
#         self.n_segments = n_segments
#         self.p = p
#         self.requires_x_random = False

        
#     def __call__(self, x, x_random=None):
#         if torch.rand(1) < self.p:
#             L = len(x)
#             segment_length = L // self.n_segments
#             segments = []
#             for i in range(self.n_segments):
#                 start = i * segment_length
#                 end = start + segment_length if i < self.n_segments -1 else L
#                 Si = x[start:end]
#                 segments.append(Si)

#             r = random.randint(0, self.n_segments - 1)
#             del segments[r]

#             x_aug = np.concatenate(segments)
#             x_aug = signal.resample(x_aug, L)
#             return x_aug
        
#         else:
#             return x
        

# class AverageFilter:
    
#     def __init__(self, k_range=(3, 10), p=0.5):
#         self.k_range = k_range
#         self.p = p
#         self.requires_x_random = False

        
#     def __call__(self, x, x_random=None):
#         if torch.rand(1) < self.p:
#             k = random.randint(self.k_range[0], self.k_range[1])
#             kernel = np.ones(k) / k
#             x_aug = np.convolve(x, kernel, mode='same')
#             return x_aug
#         else:
#             return x
        

# class SignFlip:
    
#     def __init__(self, p=0.5):
#         self.p = p
#         self.requires_x_random = False

        
#     def __call__(self, x, x_random=None):
#         if torch.rand(1) < self.p:
#             return (-x).copy()
#         else:
#             return x
        

# class TailoredMixup:
    
#     def __init__(self, p=0.5, fs=100, beta=0.5):
#         """
#         Parameters:
#         - p: probability of applying the mixup.
#         - fs: sampling frequency of the signal.
#         - beta: parameter for sampling mixup coefficients.
#         """
#         self.p = p
#         self.fs = fs
#         self.beta = beta  # For simplicity, we'll use a fixed beta value here.
#         self.requires_x_random = True  # Indicate that this augmentation requires x_random

        
#     def __call__(self, x_anchor, x_random=None):
#         """
#         Apply the tailored mixup between x_anchor and x_random.
        
#         Parameters:
#         - x_anchor: the anchor EEG signal (1D numpy array).
#         - x_random: a randomly chosen EEG signal to mix with (1D numpy array).
        
#         Returns:
#         - x_aug: the augmented EEG signal.
#         """
#         if torch.rand(1) < self.p and x_random is not None:
#             # Compute FFT of both signals
#             X_anchor = np.fft.fft(x_anchor)
#             X_random = np.fft.fft(x_random)
            
#             # Obtain magnitude and phase
#             A_anchor = np.abs(X_anchor)
#             P_anchor = np.angle(X_anchor)
#             A_random = np.abs(X_random)
#             P_random = np.angle(X_random)
            
#             # Mixup coefficients for magnitude and phase
#             lambda_A = random.uniform(self.beta, 1.0)
#             lambda_P = random.uniform(self.beta, 1.0)
            
#             # Mix magnitude
#             A_mix = lambda_A * A_anchor + (1 - lambda_A) * A_random
            
#             # Compute phase difference
#             delta_theta = P_anchor - P_random
#             delta_theta = (delta_theta + np.pi) % (2 * np.pi) - np.pi  # Wrap between [-π, π]
            
#             # Adjust phase
#             P_mix = P_anchor - delta_theta * (1 - lambda_P)
            
#             # Reconstruct the complex spectrum
#             X_mix = A_mix * np.exp(1j * P_mix)
            
#             # Inverse FFT to get the augmented signal
#             x_aug = np.fft.ifft(X_mix).real  # Take the real part
            
#             return x_aug
#         else:
#             return x_anchor.copy()

### Dataset Class

In [25]:
from sup_dataset import SupervisedEEGDataset

In [24]:
from crl_dataset import ContrastiveEEGDataset


# class ContrastiveEEGDataset(Dataset):
#     def __init__(self, eeg_data, augmentations=None, return_labels=False):
#         """
#         eeg_data: Dictionary containing EEG signals per class.
#         augmentations: A list of augmentation callables.
#         return_labels: Whether to return labels in __getitem__.
#         """
#         self.data = []
#         self.labels = []
#         for label, signals in eeg_data.items():
#             self.data.extend(signals)
#             self.labels.extend([label] * len(signals))
#         self.data = np.array(self.data)
#         self.labels = np.array(self.labels)
#         self.augmentations = augmentations if augmentations is not None else []
#         self.return_labels = return_labels
        
#         # Determine if any augmentation requires x_random
#         self.requires_x_random = any(getattr(aug, 'requires_x_random', False) for aug in self.augmentations)
        
#     def __len__(self):
#         return len(self.data)
    
#     def __getitem__(self, idx):
#         x = self.data[idx]
#         y = self.labels[idx]
#         x_i = x.copy()
#         x_j = x.copy()
        
#         # If any augmentation requires x_random, sample it
#         if self.requires_x_random:
#             idx_random = random.randint(0, len(self.data) - 1)
#             x_random = self.data[idx_random]
#         else:
#             x_random = None
        
#         # Apply augmentations to x_i
#         for aug in self.augmentations:
#             if getattr(aug, 'requires_x_random', False):
#                 x_i = aug(x_i, x_random.copy())
#             else:
#                 x_i = aug(x_i)
        
#         # Apply augmentations to x_j
#         for aug in self.augmentations:
#             if getattr(aug, 'requires_x_random', False):
#                 x_j = aug(x_j, x_random.copy())
#             else:
#                 x_j = aug(x_j)
        
#         # Convert to torch tensors and add channel dimension
#         x_i = torch.tensor(x_i, dtype=torch.float32).unsqueeze(0)
#         x_j = torch.tensor(x_j, dtype=torch.float32).unsqueeze(0)
        
#         if self.return_labels:
#             y = torch.tensor(y, dtype=torch.long)
#             return x_i, x_j, y
#         else:
#             return x_i, x_j

### SimpleSleepNet: Encoder Backbone

In [12]:
from models import SimpleSleepNet

# class SimpleSleepNet(nn.Module):
#     def __init__(self, latent_dim=128, dropout=0.5):
#         super(SimpleSleepNet, self).__init__()
        
#         self.latent_dim = latent_dim
#         self.dropout = nn.Dropout(p=dropout)
        
#         self.conv_path = nn.Sequential(
#             # First Convolutional Block
#             nn.Conv1d(in_channels=1, out_channels=32, kernel_size=64, stride=8, padding=32, bias=False),
#             nn.BatchNorm1d(32),
#             nn.ReLU(),
#             self.dropout,
            
#             # Second Convolutional Block
#             nn.Conv1d(in_channels=32, out_channels=64, kernel_size=32, stride=4, padding=16, bias=False),
#             nn.BatchNorm1d(64),
#             nn.ReLU(),
#             self.dropout,
            
#             # Third Convolutional Block
#             nn.Conv1d(in_channels=64, out_channels=128, kernel_size=16, stride=2, padding=8, bias=False),
#             nn.BatchNorm1d(128),
#             nn.ReLU(),
#             self.dropout
#         )
        
#         # Fully Connected Layer for Embedding
#         self.fc = nn.Sequential(
#             nn.Linear(128, self.latent_dim),
#             nn.BatchNorm1d(self.latent_dim),
#             nn.ReLU(),
#             self.dropout
#         )
        
#     def forward(self, x):
#         x = self.conv_path(x)          # (Batch, 128, L')
#         x = F.adaptive_avg_pool1d(x, 1)  # (Batch, 128, 1)
#         x = x.view(x.size(0), -1)     # (Batch, 128)
#         x = self.fc(x)                 # (Batch, latent_dim)
#         x = F.normalize(x, p=2, dim=1)  # Normalize embeddings
#         return x

### NT-Xent Loss

In [None]:
# from loss import nt_xent_loss


# def nt_xent_loss(z_i, z_j, temperature=0.5):
#     """
#     Computes the NT-Xent loss as introduced in SimCLR.

#     Parameters:
#     - z_i, z_j: Normalized embeddings of two augmented views (batch_size, embedding_dim).
#     - temperature: Temperature scaling factor.

#     Returns:
#     - loss: The computed NT-Xent loss.
#     """
#     batch_size = z_i.size(0)
#     device = z_i.device

#     # Normalize embeddings
#     z_i = nn.functional.normalize(z_i, dim=1)
#     z_j = nn.functional.normalize(z_j, dim=1)

#     # Concatenate embeddings
#     z = torch.cat([z_i, z_j], dim=0)  # Shape: (2*batch_size, embedding_dim)

#     # Compute similarity matrix
#     sim_matrix = torch.matmul(z, z.T) / temperature  # Shape: (2*batch_size, 2*batch_size)

#     # Remove self-similarities
#     mask = torch.eye(2 * batch_size, dtype=torch.bool).to(device)
#     sim_matrix.masked_fill_(mask, -float('inf'))

#     # Positive sample indices
#     labels = torch.arange(batch_size).to(device)
#     labels = torch.cat([labels + batch_size, labels], dim=0)

#     # Compute loss
#     loss = nn.CrossEntropyLoss()(sim_matrix, labels)

#     return loss

### CRL Pre-Training

In [6]:
from crl_pretraining import train_contrastive


# def train_contrastive(model, dataloader, optimizer, num_epochs=5): #add temperature and device parameters to this class
#     model.train()
#     for epoch in range(num_epochs):
#         total_loss = 0.0
#         for x_i, x_j in dataloader:
#             x_i = x_i.to(device)
#             x_j = x_j.to(device)
            
#             optimizer.zero_grad()
            
#             z_i = model(x_i)
#             z_j = model(x_j)
            
#             loss = nt_xent_loss(z_i, z_j, temperature=0.5)
#             loss.backward()
#             optimizer.step()
            
#             total_loss += loss.item()
        
#         avg_loss = total_loss / len(dataloader)
#         print(f"Contrastive Training - Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

### Benchmarking Latent Space

In [None]:
# class LatentSpaceEvaluator:
#     def __init__(
#         self, model, dataloader, device='cpu', umap=True, pca=True, tsne=True,
#         visualize=True, compute_metrics=True, n_clusters=5
#     ):
#         self.model = model.to(device)
#         self.dataloader = dataloader
#         self.device = device
#         self.use_umap = umap
#         self.use_pca = pca
#         self.use_tsne = tsne
#         self.visualize = visualize
#         self.compute_metrics = compute_metrics
#         self.n_clusters = n_clusters
#         self.embeddings = None
#         self.labels = None
#         self.embeddings_scaled = None
#         self.results = {}
#         self.unique_labels = None

#     def extract_embeddings(self):
#         self.model.eval()
#         embeddings = []
#         labels = []
#         with torch.no_grad():
#             for batch in self.dataloader:
#                 x_i, x_j, y = batch
#                 x_i = x_i.to(self.device)
#                 z_i = self.model(x_i)
#                 embeddings.append(z_i.cpu().numpy())
#                 labels.append(y.numpy())
#         self.embeddings = np.concatenate(embeddings, axis=0)
#         self.labels = np.concatenate(labels, axis=0)
#         self.unique_labels = np.unique(self.labels)
#         if len(self.unique_labels) != self.n_clusters:
#             raise ValueError(
#                 f"Expected {self.n_clusters} unique labels, but found {len(self.unique_labels)}: {self.unique_labels}"
#             )
#         scaler = StandardScaler()
#         self.embeddings_scaled = scaler.fit_transform(self.embeddings)

#     @staticmethod
#     def purity_score(y_true, y_pred):
#         contingency_matrix = np.zeros((len(np.unique(y_true)), len(np.unique(y_pred))))
#         for i, true_label in enumerate(np.unique(y_true)):
#             for j, pred_label in enumerate(np.unique(y_pred)):
#                 contingency_matrix[i, j] = np.sum(
#                     (y_true == true_label) & (y_pred == pred_label)
#                 )
#         row_ind, col_ind = linear_sum_assignment(-contingency_matrix)
#         purity = contingency_matrix[row_ind, col_ind].sum() / np.sum(contingency_matrix)
#         return purity

#     @staticmethod
#     def compute_average_entropy(y_true, y_pred):
#         cluster_labels = defaultdict(list)
#         for label, cluster in zip(y_true, y_pred):
#             cluster_labels[cluster].append(label)
#         entropy_total = 0
#         for labels_in_cluster in cluster_labels.values():
#             label_counts = np.bincount(labels_in_cluster)
#             probabilities = label_counts / len(labels_in_cluster)
#             cluster_entropy = entropy(probabilities, base=2)
#             entropy_total += cluster_entropy * len(labels_in_cluster)
#         average_entropy = entropy_total / len(y_true)
#         return average_entropy

#     def apply_dimensionality_reduction(self):
#         if self.use_tsne:
#             tsne = TSNE(
#                 n_components=2,
#                 random_state=42,
#                 perplexity=30,
#                 n_iter=1000,
#                 learning_rate='auto',
#                 init='random',
#                 verbose=1
#             )
#             embeddings_tsne = tsne.fit_transform(self.embeddings_scaled)
#             self.results['t-SNE'] = {'embeddings': embeddings_tsne}

#         if self.use_umap:
#             with warnings.catch_warnings():
#                 warnings.filterwarnings("ignore", category=UserWarning, module="umap")
#                 umap_reducer = umap.UMAP(
#                     n_components=2,
#                     random_state=42,
#                     n_neighbors=15,
#                     min_dist=0.1,
#                     metric='euclidean',
#                     n_jobs=1  # Explicitly set to suppress warning
#                 )
#                 embeddings_umap = umap_reducer.fit_transform(self.embeddings_scaled)
#             self.results['UMAP'] = {'embeddings': embeddings_umap}

#         if self.use_pca:
#             pca = PCA(n_components=2, random_state=42)
#             embeddings_pca = pca.fit_transform(self.embeddings_scaled)
#             self.results['PCA'] = {'embeddings': embeddings_pca}

#     def evaluate_metrics(self):
#         for key, result in self.results.items():
#             embeddings_2d = result['embeddings']
#             metrics = {}
#             # Clustering with KMeans
#             kmeans = KMeans(
#                 n_clusters=self.n_clusters,
#                 random_state=42,
#                 n_init='auto'  # Set explicitly to suppress FutureWarning
#             )
#             labels_pred = kmeans.fit_predict(embeddings_2d)
#             metrics['Silhouette Score'] = silhouette_score(embeddings_2d, labels_pred)
#             metrics['Davies-Bouldin Index'] = davies_bouldin_score(embeddings_2d, labels_pred)
#             metrics['Adjusted Rand Index'] = adjusted_rand_score(self.labels, labels_pred)
#             # Purity
#             purity = self.purity_score(self.labels, labels_pred)
#             metrics['Purity Score'] = purity
#             # Average Entropy
#             average_entropy = self.compute_average_entropy(self.labels, labels_pred)
#             metrics['Average Entropy'] = average_entropy
#             # Adjusted Mutual Information
#             ami = adjusted_mutual_info_score(self.labels, labels_pred)
#             metrics['Adjusted Mutual Information'] = ami
#             # Trustworthiness
#             trust = trustworthiness(self.embeddings_scaled, embeddings_2d, n_neighbors=5)
#             metrics['Trustworthiness'] = trust
#             # Continuity is not implemented
#             metrics['Continuity'] = None
#             self.results[key]['metrics'] = metrics

#     def visualize_embeddings(self):
#         # Use the new Matplotlib colormaps API to get the colormap
#         cmap = matplotlib.colormaps.get_cmap('tab10')
#         colors = [cmap(i / self.n_clusters) for i in range(self.n_clusters)]
        
#         for key, result in self.results.items():
#             embeddings_2d = result['embeddings']
#             plt.figure(figsize=(10, 8))
#             for i, label in enumerate(self.unique_labels):
#                 idxs = self.labels == label
#                 plt.scatter(
#                     embeddings_2d[idxs, 0],
#                     embeddings_2d[idxs, 1],
#                     label=f"Class {label}",
#                     color=colors[i],
#                     alpha=0.7,
#                     edgecolors='k',
#                     linewidth=0.5
#                 )
#             plt.legend(title="Classes")
#             plt.title(f"{key} Visualization of Embeddings")
#             plt.xlabel(f"{key} Dimension 1")
#             plt.ylabel(f"{key} Dimension 2")
#             plt.grid(True)
#             plt.show()


#     def run(self):
#         self.extract_embeddings()
#         self.apply_dimensionality_reduction()
#         if self.compute_metrics:
#             self.evaluate_metrics()
#             for key, result in self.results.items():
#                 print(f"\n--- {key} Evaluation Metrics ---")
#                 metrics = result['metrics']
#                 for metric, value in metrics.items():
#                     if value is None:
#                         print(f"{metric}: N/A")
#                     else:
#                         print(f"{metric}: {value:.4f}")
#         if self.visualize:
#             self.visualize_embeddings()

In [None]:
# import os
# import warnings
# from collections import defaultdict
# from pathlib import Path

# import matplotlib.pyplot as plt
# import matplotlib
# import numpy as np
# import pandas as pd
# import seaborn as sns
# import torch
# from sklearn.decomposition import PCA
# from sklearn.cluster import KMeans
# from sklearn.metrics import (
#     adjusted_rand_score,
#     adjusted_mutual_info_score,
#     davies_bouldin_score,
#     silhouette_score
# )
# from sklearn.manifold import TSNE, trustworthiness
# from sklearn.preprocessing import StandardScaler
# from scipy.optimize import linear_sum_assignment
# from scipy.stats import entropy
# import umap

# class LatentSpaceEvaluator:
#     def __init__(
#         self,
#         model,
#         dataloader,
#         device='cpu',
#         umap_enabled=True,
#         pca_enabled=True,
#         tsne_enabled=True,
#         visualize=True,
#         compute_metrics=True,
#         n_clusters=5,
#         output_image_dir='visualizations',
#         output_metrics_dir='metrics',
#         augmentation_strategy='default_strategy'
#     ):
#         """
#         Initializes the LatentSpaceEvaluator.

#         Parameters:
#             model: The neural network model to extract embeddings from.
#             dataloader: DataLoader providing the data batches.
#             device (str): Device to run the model on ('cpu' or 'cuda').
#             umap_enabled (bool): Whether to use UMAP for dimensionality reduction.
#             pca_enabled (bool): Whether to use PCA for dimensionality reduction.
#             tsne_enabled (bool): Whether to use t-SNE for dimensionality reduction.
#             visualize (bool): Whether to visualize embeddings.
#             compute_metrics (bool): Whether to compute evaluation metrics.
#             n_clusters (int): Number of expected clusters/classes.
#             output_image_dir (str): Directory path to save visualization images.
#             output_metrics_dir (str): Directory path to save metrics CSV files.
#             augmentation_strategy (str): Placeholder for data-augmentation strategy.
#         """
#         self.model = model.to(device)
#         self.dataloader = dataloader
#         self.device = device
#         self.use_umap = umap_enabled
#         self.use_pca = pca_enabled
#         self.use_tsne = tsne_enabled
#         self.visualize = visualize
#         self.compute_metrics = compute_metrics
#         self.n_clusters = n_clusters
#         self.embeddings = None
#         self.labels = None
#         self.embeddings_scaled = None
#         self.results = {}
#         self.unique_labels = None
#         self.augmentation_strategy = augmentation_strategy

#         # Set up output directories
#         self.output_image_dir = Path(output_image_dir)
#         self.output_metrics_dir = Path(output_metrics_dir)
#         self.output_image_dir.mkdir(parents=True, exist_ok=True)
#         self.output_metrics_dir.mkdir(parents=True, exist_ok=True)

#         # Set Seaborn style for better aesthetics
#         sns.set_theme(style='whitegrid', context='talk')
#         matplotlib.rcParams.update({'figure.autolayout': True})

#     def extract_embeddings(self):
#         """
#         Extracts embeddings from the model using the provided dataloader.
#         """
#         self.model.eval()
#         embeddings = []
#         labels = []
#         with torch.no_grad():
#             for batch in self.dataloader:
#                 # Assuming batch contains (x_i, x_j, y)
#                 x_i, _, y = batch
#                 x_i = x_i.to(self.device)
#                 z_i = self.model(x_i)
#                 embeddings.append(z_i.cpu().numpy())
#                 labels.append(y.numpy())
#         self.embeddings = np.concatenate(embeddings, axis=0)
#         self.labels = np.concatenate(labels, axis=0)
#         self.unique_labels = np.unique(self.labels)
#         if len(self.unique_labels) != self.n_clusters:
#             raise ValueError(
#                 f"Expected {self.n_clusters} unique labels, but found {len(self.unique_labels)}: {self.unique_labels}"
#             )
#         scaler = StandardScaler()
#         self.embeddings_scaled = scaler.fit_transform(self.embeddings)
#         print("Embeddings extracted and scaled.")

#     @staticmethod
#     def purity_score(y_true, y_pred):
#         """
#         Calculates the purity score for the clustering.

#         Parameters:
#             y_true (array-like): True labels.
#             y_pred (array-like): Predicted cluster labels.

#         Returns:
#             float: Purity score.
#         """
#         contingency_matrix = np.zeros((len(np.unique(y_true)), len(np.unique(y_pred))))
#         for i, true_label in enumerate(np.unique(y_true)):
#             for j, pred_label in enumerate(np.unique(y_pred)):
#                 contingency_matrix[i, j] = np.sum(
#                     (y_true == true_label) & (y_pred == pred_label)
#                 )
#         row_ind, col_ind = linear_sum_assignment(-contingency_matrix)
#         purity = contingency_matrix[row_ind, col_ind].sum() / np.sum(contingency_matrix)
#         return purity

#     @staticmethod
#     def compute_average_entropy(y_true, y_pred):
#         """
#         Computes the average entropy for the clustering.

#         Parameters:
#             y_true (array-like): True labels.
#             y_pred (array-like): Predicted cluster labels.

#         Returns:
#             float: Average entropy.
#         """
#         cluster_labels = defaultdict(list)
#         for label, cluster in zip(y_true, y_pred):
#             cluster_labels[cluster].append(label)
#         entropy_total = 0
#         for labels_in_cluster in cluster_labels.values():
#             label_counts = np.bincount(labels_in_cluster)
#             probabilities = label_counts / len(labels_in_cluster)
#             # Filter out zero probabilities to avoid log2(0)
#             probabilities = probabilities[probabilities > 0]
#             cluster_entropy = entropy(probabilities, base=2)
#             entropy_total += cluster_entropy * len(labels_in_cluster)
#         average_entropy = entropy_total / len(y_true)
#         return average_entropy

#     def apply_dimensionality_reduction(self):
#         """
#         Applies the selected dimensionality reduction techniques to the embeddings.
#         """
#         if self.use_tsne:
#             print("Applying t-SNE...")
#             tsne = TSNE(
#                 n_components=2,
#                 random_state=42,
#                 perplexity=30,
#                 n_iter=1000,
#                 learning_rate='auto',
#                 init='random',
#                 verbose=0  # Set to 1 for detailed output
#             )
#             embeddings_tsne = tsne.fit_transform(self.embeddings_scaled)
#             self.results['t-SNE'] = {'embeddings': embeddings_tsne}
#             print("t-SNE completed.")

#         if self.use_umap:
#             print("Applying UMAP...")
#             with warnings.catch_warnings():
#                 warnings.filterwarnings("ignore", category=UserWarning, module="umap")
#                 umap_reducer = umap.UMAP(
#                     n_components=2,
#                     random_state=42,
#                     n_neighbors=15,
#                     min_dist=0.1,
#                     metric='euclidean',
#                     n_jobs=-1  # Use all available cores
#                 )
#                 embeddings_umap = umap_reducer.fit_transform(self.embeddings_scaled)
#             self.results['UMAP'] = {'embeddings': embeddings_umap}
#             print("UMAP completed.")

#         if self.use_pca:
#             print("Applying PCA...")
#             pca = PCA(n_components=2, random_state=42)
#             embeddings_pca = pca.fit_transform(self.embeddings_scaled)
#             self.results['PCA'] = {'embeddings': embeddings_pca}
#             print("PCA completed.")

#     def evaluate_metrics(self):
#         """
#         Evaluates clustering metrics for each dimensionality reduction result.
#         """
#         print("Evaluating metrics...")
#         for key, result in self.results.items():
#             embeddings_2d = result['embeddings']
#             metrics = {}
#             # Clustering with KMeans
#             kmeans = KMeans(
#                 n_clusters=self.n_clusters,
#                 random_state=42,
#                 n_init='auto'  # Set explicitly to suppress FutureWarning
#             )
#             labels_pred = kmeans.fit_predict(embeddings_2d)
#             metrics['Silhouette Score'] = silhouette_score(embeddings_2d, labels_pred)
#             metrics['Davies-Bouldin Index'] = davies_bouldin_score(embeddings_2d, labels_pred)
#             metrics['Adjusted Rand Index'] = adjusted_rand_score(self.labels, labels_pred)
#             # Purity
#             purity = self.purity_score(self.labels, labels_pred)
#             metrics['Purity Score'] = purity
#             # Average Entropy
#             average_entropy = self.compute_average_entropy(self.labels, labels_pred)
#             metrics['Average Entropy'] = average_entropy
#             # Adjusted Mutual Information
#             ami = adjusted_mutual_info_score(self.labels, labels_pred)
#             metrics['Adjusted Mutual Information'] = ami
#             # Trustworthiness
#             trust = trustworthiness(self.embeddings_scaled, embeddings_2d, n_neighbors=5)
#             metrics['Trustworthiness'] = trust
#             # Continuity is not implemented
#             metrics['Continuity'] = None
#             self.results[key]['metrics'] = metrics

#             # Save metrics to CSV
#             metrics_df = pd.DataFrame([metrics])
#             metrics_filename = f"{key.lower()}_metrics_{self.augmentation_strategy}.csv"
#             metrics_path = self.output_metrics_dir / metrics_filename
#             metrics_df.to_csv(metrics_path, index=False)
#             print(f"Metrics for {key} saved to {metrics_path}.")

#     def visualize_embeddings(self):
#         """
#         Visualizes the embeddings and saves the plots as image files.
#         """
#         print("Visualizing embeddings...")
#         # Define color palette
#         palette = sns.color_palette("tab10", n_colors=self.n_clusters)

#         for key, result in self.results.items():
#             embeddings_2d = result['embeddings']
#             plt.figure(figsize=(12, 10))
#             sns.scatterplot(
#                 x=embeddings_2d[:, 0],
#                 y=embeddings_2d[:, 1],
#                 hue=self.labels,
#                 palette=palette,
#                 legend='full',
#                 alpha=0.7,
#                 edgecolor='k',
#                 linewidth=0.5
#             )
#             plt.title(f"{key} Visualization of Embeddings", fontsize=16)
#             plt.xlabel(f"{key} Dimension 1", fontsize=14)
#             plt.ylabel(f"{key} Dimension 2", fontsize=14)
#             plt.legend(title="Classes", fontsize=12, title_fontsize=13, loc='best')
#             plt.tight_layout()

#             # Save the plot
#             image_filename = f"{key.lower()}_embeddings_{self.augmentation_strategy}.png"
#             image_path = self.output_image_dir / image_filename
#             plt.savefig(image_path, dpi=300)
#             plt.close()
#             print(f"Visualization for {key} saved to {image_path}.")

#     def run(self):
#         """
#         Executes the full evaluation pipeline: extraction, dimensionality reduction,
#         metric evaluation, and visualization.
#         """
#         print("Starting Latent Space Evaluation...")
#         self.extract_embeddings()
#         self.apply_dimensionality_reduction()
#         if self.compute_metrics:
#             self.evaluate_metrics()
#             print("Evaluation metrics computed and saved.")
#         if self.visualize:
#             self.visualize_embeddings()
#             print("Embeddings visualized and saved.")
#         print("Latent Space Evaluation Completed.")


### Pre-train Encoder with CRL

In [17]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [10]:
import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Pre-training hyperparameters, named so they can be tuned in one place.
PRETRAIN_SEED = 42
PRETRAIN_BATCH_SIZE = 128
PRETRAIN_LATENT_DIM = 128   # the downstream SleepStageClassifier is built with input_dim=128
PRETRAIN_LR = 1e-3
PRETRAIN_EPOCHS = 10
PRETRAIN_TEMPERATURE = 0.5  # temperature passed to train_contrastive

# Seed every RNG pre-training may draw from (model initialisation, DataLoader
# shuffling, and presumably the augmentations -- confirm which RNG they use),
# so that Restart & Run All reproduces the same encoder.
random.seed(PRETRAIN_SEED)
np.random.seed(PRETRAIN_SEED)
torch.manual_seed(PRETRAIN_SEED)

# p=1.0 presumably means "always apply" -- confirm in the augmentations module.
augmentations = [
    RandomAmplitudeScale(p=1.0),
    TailoredMixup(p=1.0),
    RandomAdditiveGaussianNoise(p=1.0),
]

contrastive_dataset = ContrastiveEEGDataset(eeg_data, augmentations=augmentations)
contrastive_loader = DataLoader(contrastive_dataset, batch_size=PRETRAIN_BATCH_SIZE, shuffle=True)

model = SimpleSleepNet(latent_dim=PRETRAIN_LATENT_DIM).to(device)
optimizer = optim.Adam(model.parameters(), lr=PRETRAIN_LR)

train_contrastive(
    model,
    contrastive_loader,
    optimizer,
    device=device,
    num_epochs=PRETRAIN_EPOCHS,
    temperature=PRETRAIN_TEMPERATURE,
)


Using device: cuda
Contrastive Training - Epoch [1/10], Loss: 5.3588
Contrastive Training - Epoch [2/10], Loss: 5.3384
Contrastive Training - Epoch [3/10], Loss: 5.3210
Contrastive Training - Epoch [4/10], Loss: 5.2958
Contrastive Training - Epoch [5/10], Loss: 5.2447
Contrastive Training - Epoch [6/10], Loss: 5.2447
Contrastive Training - Epoch [7/10], Loss: 5.1947
Contrastive Training - Epoch [8/10], Loss: 5.1642
Contrastive Training - Epoch [9/10], Loss: 5.1218
Contrastive Training - Epoch [10/10], Loss: 5.0927


### Assess Latent Space

In [11]:
from latent_space_evaluator import LatentSpaceEvaluator

# Un-augmented, labelled view of the data, so the learned embeddings can be
# compared against the sleep-stage labels.
visualization_dataset = ContrastiveEEGDataset(
    eeg_data=eeg_data,
    augmentations=[],
    return_labels=True,
)

# shuffle=False keeps the embedding order aligned with the dataset order.
visualization_loader = DataLoader(visualization_dataset, batch_size=128, shuffle=False)

evaluator = LatentSpaceEvaluator(
    model=model,
    dataloader=visualization_loader,
    device=str(device),                         # follow the runtime device ('cuda' or 'cpu') instead of hard-coding 'cuda'
    umap_enabled=True,
    pca_enabled=True,
    tsne_enabled=True,
    visualize=True,
    compute_metrics=True,
    n_clusters=5,                               # number of expected clusters/classes
    output_image_dir='latent_space_viz',        # where the 2-D projection plots are written
    output_metrics_dir='latent_space_metrics',  # where the per-method metric CSVs are written
    augmentation_strategy='1',                  # identifier embedded in the output file names
)

evaluator.run()


Starting Latent Space Evaluation...
Embeddings extracted and scaled.
Applying t-SNE...
t-SNE completed.
Applying UMAP...
UMAP completed.
Applying PCA...
PCA completed.
Evaluating metrics...
Metrics for t-SNE saved to latent_space_metrics\t-sne_metrics_1.csv.
Metrics for UMAP saved to latent_space_metrics\umap_metrics_1.csv.




Metrics for PCA saved to latent_space_metrics\pca_metrics_1.csv.
Evaluation metrics computed and saved.
Visualizing embeddings...
Visualization for t-SNE saved to latent_space_viz\t-sne_embeddings_1.png.
Visualization for UMAP saved to latent_space_viz\umap_embeddings_1.png.
Visualization for PCA saved to latent_space_viz\pca_embeddings_1.png.
Embeddings visualized and saved.
Latent Space Evaluation Completed.


# Downstream Classification

### SleepStageClassifier

In [None]:
# class SleepStageClassifier(nn.Module):
#     def __init__(self, input_dim=128, num_classes=5):
#         super(SleepStageClassifier, self).__init__()
#         self.classifier = nn.Sequential(
#             nn.Linear(input_dim, 256),
#             nn.BatchNorm1d(256),
#             nn.ReLU(),
#             nn.Dropout(p=0.5),
#             nn.Linear(256, 128),
#             nn.BatchNorm1d(128),
#             nn.ReLU(),
#             nn.Dropout(p=0.5),
#             nn.Linear(128, num_classes)
#         )
        
#     def forward(self, x):
#         x = self.classifier(x)
#         return x
    

# import torch
# import torch.nn as nn

# class SleepStageClassifier(nn.Module):
#     def __init__(self, input_dim=128, num_classes=5, dropout_probs=(0.5, 0.5)):
#         """
#         Sleep stage classifier with configurable dropout rates.

#         Parameters:
#         - input_dim (int): Dimensionality of the input features.
#         - num_classes (int): Number of classes to predict.
#         - dropout_probs (tuple): Dropout probabilities for the layers.
#         """
#         super(SleepStageClassifier, self).__init__()
        
#         # Ensure dropout_probs is a tuple of the correct size
#         if not isinstance(dropout_probs, (tuple, list)) or len(dropout_probs) != 2:
#             raise ValueError("dropout_probs must be a tuple or list of length 2.")
        
#         self.classifier = nn.Sequential(
#             nn.Linear(input_dim, 256),
#             nn.BatchNorm1d(256),
#             nn.ReLU(),
#             nn.Dropout(p=dropout_probs[0]),  # First dropout rate
#             nn.Linear(256, 128),
#             nn.BatchNorm1d(128),
#             nn.ReLU(),
#             nn.Dropout(p=dropout_probs[1]),  # Second dropout rate
#             nn.Linear(128, num_classes)
#         )
        
#     def forward(self, x):
#         """
#         Forward pass of the classifier.

#         Parameters:
#         - x (torch.Tensor): Input tensor of shape (Batch, input_dim).

#         Returns:
#         - torch.Tensor: Output logits of shape (Batch, num_classes).
#         """
#         x = self.classifier(x)
#         return x

In [13]:
from models import SimpleSleepNet, SleepStageClassifier

### Dataset Class

In [None]:
# class SupervisedEEGDataset(Dataset):
#     def __init__(self, eeg_data):

#         self.data = []
#         self.labels = []

#         for label, signals in eeg_data.items():
#             self.data.extend(signals)
#             self.labels.extend([label] * len(signals))

#         self.data = np.array(self.data)
#         self.labels = np.array(self.labels)
        
#     def __len__(self):
#         return len(self.data)
    
#     def __getitem__(self, idx):
#         x = self.data[idx]
#         y = self.labels[idx]
#         x = torch.tensor(x, dtype=torch.float32).unsqueeze(0)
#         y = torch.tensor(y, dtype=torch.long)
#         return x, y

### Train-Test-Val Split

In [26]:
from sklearn.model_selection import train_test_split

# Flatten the label-keyed eeg_data dict into parallel arrays of signals and labels.
all_data = []
all_labels = []

for label, signals in eeg_data.items():
    all_data.extend(signals)
    all_labels.extend([label] * len(signals))

all_data = np.array(all_data)
all_labels = np.array(all_labels)

# NOTE(review): this is an epoch-level random split. Epochs from the same subject
# can land in train, val and test, and the encoder was pre-trained on all of
# eeg_data. Test scores are therefore optimistic. The subject-wise split at the
# start of the notebook avoids this.

# 80/20 split into train+val and test, stratified to keep class proportions.
X_train_val, X_test, y_train_val, y_test = train_test_split(
    all_data, all_labels, test_size=0.2, stratify=all_labels, random_state=42
)

# 75/25 split of the remaining 80% -> 60% train, 20% val, 20% test overall.
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.25, stratify=y_train_val, random_state=42
)
assert len(X_train) + len(X_val) + len(X_test) == len(all_data)

# Labels index the classifier's logits (CrossEntropyLoss), so they must lie in
# [0, num_classes). Deriving num_classes from the largest label, rather than
# counting the labels that happen to occur, means a class with no epochs cannot
# shrink num_classes and push a higher label outside the per-class dicts below.
num_classes = int(all_labels.max()) + 1

# Per-split, label-keyed dicts of signal arrays (the format SupervisedEEGDataset consumes).
train_data = {label: X_train[y_train == label] for label in range(num_classes)}
val_data = {label: X_val[y_val == label] for label in range(num_classes)}
test_data = {label: X_test[y_test == label] for label in range(num_classes)}

# Create datasets
train_dataset = SupervisedEEGDataset(train_data)
val_dataset = SupervisedEEGDataset(val_data)
test_dataset = SupervisedEEGDataset(test_data)

# Create dataloaders; only the training loader is shuffled.
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### Supervised Training with validation

In [None]:
# def evaluate_classifier(encoder, classifier, data_loader, criterion):
#     encoder.eval()
#     classifier.eval()
#     total_loss = 0.0
#     correct = 0
#     total = 0
    
#     with torch.no_grad():
#         for x, y in data_loader:
#             x = x.to(device)
#             y = y.to(device)
            
#             z = encoder(x)
#             outputs = classifier(z)
#             loss = criterion(outputs, y)
#             total_loss += loss.item()
            
#             _, preds = torch.max(outputs, 1)
#             correct += (preds == y).sum().item()
#             total += y.size(0)
    
#     avg_loss = total_loss / len(data_loader)
#     accuracy = correct / total
#     return avg_loss, accuracy

# def train_classifier(encoder, classifier, train_loader, val_loader, criterion, optimizer, num_epochs=50):
#     encoder.eval()  # Set encoder to evaluation mode
    
#     best_val_loss = float('inf')
    
#     for epoch in range(num_epochs):
#         classifier.train()  # Set classifier to training mode
#         total_loss = 0.0
        
#         for x, y in train_loader:
#             x = x.to(device)
#             y = y.to(device)
            
#             with torch.no_grad():
#                 z = encoder(x)  # Get embeddings from frozen encoder
            
#             optimizer.zero_grad()
#             outputs = classifier(z)
#             loss = criterion(outputs, y)
#             loss.backward()
#             optimizer.step()
#             total_loss += loss.item()
        
#         avg_train_loss = total_loss / len(train_loader)
        
#         # Evaluate on validation set
#         val_loss, val_accuracy = evaluate_classifier(encoder, classifier, val_loader, criterion)
        
#         print(f"Epoch [{epoch+1}/{num_epochs}], "
#               f"Train Loss: {avg_train_loss:.4f}, "
#               f"Val Loss: {val_loss:.4f}, "
#               f"Val Acc: {val_accuracy:.4f}")
        
#         # Save the best model based on validation loss
#         if val_loss < best_val_loss:
#             best_val_loss = val_loss
#             torch.save(classifier.state_dict(), 'best_classifier.pth')




# import torch

# def evaluate_classifier(encoder, classifier, data_loader, criterion, device='cuda'):
#     """
#     Evaluates the classifier on a given dataset.

#     Parameters:
#     - encoder (nn.Module): Frozen encoder to generate embeddings.
#     - classifier (nn.Module): Classifier to evaluate.
#     - data_loader (DataLoader): DataLoader providing the evaluation dataset.
#     - criterion (nn.Module): Loss function.
#     - device (str): Device to run the evaluation on ('cuda' or 'cpu').

#     Returns:
#     - avg_loss (float): Average loss over the dataset.
#     - accuracy (float): Classification accuracy.
#     """
#     encoder.eval()
#     classifier.eval()
#     total_loss = 0.0
#     correct = 0
#     total = 0
    
#     with torch.no_grad():
#         for x, y in data_loader:
#             x = x.to(device)
#             y = y.to(device)
            
#             z = encoder(x)
#             outputs = classifier(z)
#             loss = criterion(outputs, y)
#             total_loss += loss.item()
            
#             _, preds = torch.max(outputs, 1)
#             correct += (preds == y).sum().item()
#             total += y.size(0)
    
#     avg_loss = total_loss / len(data_loader)
#     accuracy = correct / total
#     return avg_loss, accuracy


# def train_classifier(
#     encoder,
#     classifier,
#     train_loader,
#     val_loader,
#     criterion,
#     optimizer,
#     num_epochs=50,
#     device='cuda',
#     save_path='best_classifier.pth'
# ):
#     """
#     Trains the classifier while keeping the encoder frozen.

#     Parameters:
#     - encoder (nn.Module): Frozen encoder to generate embeddings.
#     - classifier (nn.Module): Classifier to train.
#     - train_loader (DataLoader): DataLoader for the training set.
#     - val_loader (DataLoader): DataLoader for the validation set.
#     - criterion (nn.Module): Loss function.
#     - optimizer (Optimizer): Optimizer for the classifier.
#     - num_epochs (int): Number of training epochs.
#     - device (str): Device to run the training on ('cuda' or 'cpu').
#     - save_path (str): Path to save the best classifier model.

#     Returns:
#     - best_val_loss (float): Best validation loss achieved during training.
#     """
#     encoder.eval()  # Set encoder to evaluation mode
#     classifier.to(device)
#     best_val_loss = float('inf')
    
#     for epoch in range(num_epochs):
#         classifier.train()  # Set classifier to training mode
#         total_loss = 0.0
        
#         for x, y in train_loader:
#             x = x.to(device)
#             y = y.to(device)
            
#             with torch.no_grad():
#                 z = encoder(x)  # Get embeddings from frozen encoder
            
#             optimizer.zero_grad()
#             outputs = classifier(z)
#             loss = criterion(outputs, y)
#             loss.backward()
#             optimizer.step()
#             total_loss += loss.item()
        
#         avg_train_loss = total_loss / len(train_loader)
        
#         # Evaluate on validation set
#         val_loss, val_accuracy = evaluate_classifier(encoder, classifier, val_loader, criterion, device)
        
#         print(f"Epoch [{epoch+1}/{num_epochs}], "
#               f"Train Loss: {avg_train_loss:.4f}, "
#               f"Val Loss: {val_loss:.4f}, "
#               f"Val Acc: {val_accuracy:.4f}")
        
#         # Save the best model based on validation loss
#         if val_loss < best_val_loss:
#             best_val_loss = val_loss
#             torch.save(classifier.state_dict(), save_path)
#             print(f"Saved best model to {save_path}")
    
#     return best_val_loss

### Evaluation

In [None]:
# def evaluate_classifier(encoder, classifier, data_loader, criterion):
#     encoder.eval()
#     classifier.eval()
#     total_loss = 0.0
#     correct = 0
#     total = 0
    
#     with torch.no_grad():
#         for x, y in data_loader:
#             x = x.to(device)
#             y = y.to(device)
            
#             z = encoder(x)
#             outputs = classifier(z)
#             loss = criterion(outputs, y)
#             total_loss += loss.item()
            
#             _, preds = torch.max(outputs, 1)
#             correct += (preds == y).sum().item()
#             total += y.size(0)
    
#     avg_loss = total_loss / len(data_loader)
#     accuracy = correct / total
#     return avg_loss, accuracy

### Predictions

In [None]:
# import numpy as np
# import torch

# def get_predictions(encoder, classifier, data_loader, device='cuda'):
#     """
#     Generates predictions for a dataset using the encoder and classifier.

#     Parameters:
#     - encoder (nn.Module): Pretrained encoder model.
#     - classifier (nn.Module): Classifier model.
#     - data_loader (DataLoader): DataLoader providing the dataset.
#     - device (str): Device to run the models on ('cuda' or 'cpu').

#     Returns:
#     - all_preds (np.ndarray): Predicted labels for the dataset.
#     - all_labels (np.ndarray): True labels for the dataset.
#     """
#     encoder.eval()
#     classifier.eval()
#     all_preds = []
#     all_labels = []
    
#     with torch.no_grad():
#         for x, y in data_loader:
#             x = x.to(device)
#             y = y.to(device)
            
#             # Generate embeddings with the encoder
#             z = encoder(x)
            
#             # Classifier outputs
#             outputs = classifier(z)
#             _, preds = torch.max(outputs, 1)
            
#             # Collect predictions and labels
#             all_preds.extend(preds.cpu().numpy())
#             all_labels.extend(y.cpu().numpy())
    
#     return np.array(all_preds), np.array(all_labels)


In [72]:
from get_predictions import get_predictions

### Train Classifier

In [15]:
from sup_training_classifier import train_classifier

In [29]:
# Linear evaluation: train only the classifier head on top of the frozen, pre-trained encoder.
CLASSIFIER_EPOCHS = 10  # previously run with 100

# input_dim must match the encoder's latent_dim (128 during pre-training).
classifier = SleepStageClassifier(input_dim=128, num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=1e-3)

# Freeze the encoder: no gradients for its parameters, and eval mode so any
# dropout / batch-norm layers behave deterministically while embedding.
model.requires_grad_(False)
model.eval()

# Use the runtime device rather than hard-coding 'cuda'; the returned value is displayed.
train_classifier(
    model,
    classifier,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs=CLASSIFIER_EPOCHS,
    device=device,
    save_path='best_classifier.pth',
)

Epoch [1/10], Train Loss: 1.4623, Val Loss: 1.5836, Val Acc: 0.2976
Saved best model to best_classifier.pth
Epoch [2/10], Train Loss: 1.0011, Val Loss: 1.5375, Val Acc: 0.6548
Saved best model to best_classifier.pth
Epoch [3/10], Train Loss: 0.8928, Val Loss: 1.4806, Val Acc: 0.6845
Saved best model to best_classifier.pth
Epoch [4/10], Train Loss: 0.8572, Val Loss: 1.4162, Val Acc: 0.6786
Saved best model to best_classifier.pth
Epoch [5/10], Train Loss: 0.8035, Val Loss: 1.3507, Val Acc: 0.7024
Saved best model to best_classifier.pth
Epoch [6/10], Train Loss: 0.8038, Val Loss: 1.2815, Val Acc: 0.7024
Saved best model to best_classifier.pth
Epoch [7/10], Train Loss: 0.7611, Val Loss: 1.2128, Val Acc: 0.7083
Saved best model to best_classifier.pth
Epoch [8/10], Train Loss: 0.7659, Val Loss: 1.1394, Val Acc: 0.7083
Saved best model to best_classifier.pth
Epoch [9/10], Train Loss: 0.7156, Val Loss: 1.0674, Val Acc: 0.7024
Saved best model to best_classifier.pth
Epoch [10/10], Train Loss: 0

0.9930530488491058

### Measure Performance

In [31]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report

# Restore the checkpoint saved by train_classifier.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
# weights_only=True restricts unpickling to tensors/containers, which also
# silences torch's FutureWarning.
classifier.load_state_dict(
    torch.load('best_classifier.pth', map_location=device, weights_only=True)
)
predictions, true_labels = get_predictions(model, classifier, test_loader, device=device)

class_ids = list(range(num_classes))

# Overall Accuracy
accuracy = accuracy_score(true_labels, predictions)
print(f"Overall Accuracy: {accuracy:.4f}")

# Per-class precision / recall / F1. zero_division=0 scores a never-predicted
# class as 0, which matches the default value but without the UndefinedMetricWarning.
precision, recall, f1_per_class, support = precision_recall_fscore_support(
    true_labels, predictions, labels=class_ids, zero_division=0
)

# Macro F1 = unweighted mean of the per-class F1 scores.
macro_f1 = float(np.mean(f1_per_class))
print(f"Macro F1 Score: {macro_f1:.4f}")

# Print detailed classification report
print("\nClassification Report:")
print(
    classification_report(
        true_labels,
        predictions,
        labels=class_ids,
        target_names=[f"Class {i}" for i in class_ids],
        zero_division=0,
    )
)

Overall Accuracy: 0.7219
Macro F1 Score: 0.5888

Classification Report:
              precision    recall  f1-score   support

     Class 0       0.86      0.50      0.63        38
     Class 1       0.00      0.00      0.00        12
     Class 2       0.83      0.78      0.80        50
     Class 3       0.91      0.89      0.90        44
     Class 4       0.44      1.00      0.61        25

    accuracy                           0.72       169
   macro avg       0.61      0.63      0.59       169
weighted avg       0.74      0.72      0.70       169



  classifier.load_state_dict(torch.load('best_classifier.pth'))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# Loop

In [None]:
# NOTE(review): bare scratch tuple with no visible use in this notebook; its meaning is unclear from here -- document what these numbers are or delete the cell.
12,16,11, 10,7,6,4

TODO: go through each module, suggest improvements, add docstrings, and ask for an explanation of everything in it. Also improve the logging.

In [None]:
import logging


In [2]:
import torch
import json
from torch.utils.data import DataLoader
import torch.optim as optim
import numpy as np
import torch.nn as nn
from sklearn.model_selection import train_test_split
import logging

from datasets import ContrastiveEEGDataset, SupervisedEEGDataset
from models import SimpleSleepNet, SleepStageClassifier
from training import train_contrastive, train_classifier
from evaluation import LatentSpaceEvaluator, get_predictions, save_classification_results
from augmentations import *
from utils import load_eeg_data
from logging_config import setup_logging


# --- Experiment configuration ---
# Everything tunable for this run (dataset, hyper-parameters, output paths) is read from
# config.json. os/random are imported here so this cell runs on a fresh kernel.
import os
import random

config_path = 'config.json'
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)

# Create the log directory up front, in case setup_logging does not create it itself.
os.makedirs('logs', exist_ok=True)
setup_logging(log_level=logging.INFO, log_file=f'logs/experiment_{config["experiment_num"]}.log')
logger = logging.getLogger(__name__)
logger.info("Starting the EEG Project")

# Seed Python, NumPy and PyTorch (torch.manual_seed covers all devices) so augmentation,
# shuffling and weight initialisation are repeatable across runs (cuDNN
# non-determinism aside). Override with a top-level "seed" key in config.json.
SEED = config.get("seed", 42)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
logger.info(f"Random seed set to {SEED}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
logger.info(f"Using device: {device}")

BATCH_SIZE = config["pretraining_params"]["batch_size"]  # shared by every DataLoader below
FS = 100                      # assumed sampling rate (Hz): 100 x 30 s = 3000 samples per epoch
EPOCH_LENGTH = 30             # seconds per scored EEG epoch
SAMPLES_PER_EPOCH = FS * EPOCH_LENGTH
NUM_CLASSES = 5               # sleep-stage labels 0-4
EPOCHS_PER_CLASS = 10         # NOTE(review): not used in this cell



# --- Load raw EEG epochs (grouped by label) ---
dataset_cfg = config['dataset']
eeg_data = load_eeg_data(
    dset_path=dataset_cfg['dset_path'],
    num_files=dataset_cfg['max_files'],
)
logger.info(f"EEG data loaded. Number of classes: {len(eeg_data)}")


# Registry of augmentation names accepted under config["augmentations"].
AUGMENTATION_CLASSES = {
    "RandomAmplitudeScale": RandomAmplitudeScale,
    "RandomDCShift": RandomDCShift,
    "RandomTimeShift": RandomTimeShift,
    "RandomZeroMasking": RandomZeroMasking,
    "RandomAdditiveGaussianNoise": RandomAdditiveGaussianNoise,
    "RandomBandStopFilter": RandomBandStopFilter,
    "TimeWarping": TimeWarping,
    "TimeReverse": TimeReverse,
    "Permutation": Permutation,
    "CutoutResize": CutoutResize,
    "TailoredMixup": TailoredMixup,
    "AverageFilter": AverageFilter,
    "SignFlip": SignFlip
}

# Instantiate each configured augmentation, in config order, with its parameter dict as
# keyword arguments. Unknown names are logged and skipped rather than raising.
augmentations = []
for aug_name, aug_params in config.get("augmentations", {}).items():
    aug_cls = AUGMENTATION_CLASSES.get(aug_name)
    if aug_cls is None:
        logger.warning(f"Warning: Augmentation '{aug_name}' not recognized. Skipping.")
        continue
    augmentations.append(aug_cls(**aug_params))

# Contrastive views are produced by applying the augmentations above to each epoch.
contrastive_dataset = ContrastiveEEGDataset(eeg_data, augmentations=augmentations)
contrastive_loader = DataLoader(
    contrastive_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
logger.info("Contrastive dataset created and loaded.")

# --- Contrastive pre-training of the encoder ---
pretrain_cfg = config["pretraining_params"]
model = SimpleSleepNet(latent_dim=pretrain_cfg["latent_dim"]).to(device)
optimizer = optim.Adam(model.parameters(), lr=pretrain_cfg["learning_rate"])
logger.info("Model and optimizer created.")

# NOTE(review): no device is passed; train_contrastive presumably moves each batch to the
# model's device itself — confirm.
train_contrastive(model, contrastive_loader, optimizer, num_epochs=pretrain_cfg["num_epochs"])
logger.info("Contrastive pretraining completed.")

# --- Latent-space inspection ---
# Same epochs, but un-augmented and with labels returned, so the evaluator can relate
# the embeddings to sleep stages.
visualization_dataset = ContrastiveEEGDataset(eeg_data=eeg_data, augmentations=[], return_labels=True)
visualization_loader = DataLoader(visualization_dataset, batch_size=BATCH_SIZE, shuffle=False)

latent_cfg = config["latent_space_params"]
evaluator = LatentSpaceEvaluator(
    model=model,
    dataloader=visualization_loader,
    device=device,
    umap_enabled=latent_cfg["umap_enabled"],
    pca_enabled=latent_cfg["pca_enabled"],
    tsne_enabled=latent_cfg["tsne_enabled"],
    visualize=latent_cfg["visualize"],
    compute_metrics=latent_cfg["compute_metrics"],
    n_clusters=latent_cfg["n_clusters"],
    output_image_dir=latent_cfg["output_image_dir"],
    output_metrics_dir=latent_cfg["output_metrics_dir"],
    # Run identifier; appears as the suffix of the output files (e.g. umap_metrics_1.csv).
    augmentation_strategy=config["experiment_num"]
)

evaluator.run()
# Use this module's logger: calling logging.info directly goes through the root logger,
# so the line would be tagged "root" instead of "__main__" like the rest of the log.
logger.info("Latent space evaluation completed.")



# --- Train/val/test split for the supervised (linear-evaluation) stage ---
# NOTE(review): the split is per EEG epoch, not per subject/recording, so epochs from the
# same night can land in both train and test. The contrastive pre-training above also
# saw every epoch (test ones included, without labels). Test scores are therefore likely
# optimistic; consider a subject-wise split.

# Flatten {label: epochs} into parallel sample / label arrays.
all_data = []
all_labels = []
for label, signals in eeg_data.items():
    all_data.extend(signals)
    all_labels.extend([label] * len(signals))

all_data = np.array(all_data)
all_labels = np.array(all_labels)

# First split into train+val and test sets (80% train+val, 20% test)
X_train_val, X_test, y_train_val, y_test = train_test_split(
    all_data, all_labels, test_size=0.2, stratify=all_labels, random_state=42
)

# Then split train_val into train and val sets (75% / 25% of the 80%)
# -> 60% train, 20% val, 20% test overall.
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.25, stratify=y_train_val, random_state=42
)

# Take the class set from eeg_data's keys, not np.unique(all_labels). If a stage has no
# epochs, np.unique under-counts: the classifier head shrinks, and indexing the per-label
# dicts with a higher label raises KeyError.
label_ids = sorted(eeg_data)
num_classes = len(label_ids)
# CrossEntropyLoss expects integer targets in [0, num_classes).
assert label_ids == list(range(num_classes)), (
    f"Expected labels 0..{num_classes - 1}, got {label_ids}"
)

# Regroup each split into {label: array of epochs}, the layout SupervisedEEGDataset is
# built from. Boolean masks keep the original sample order within each label.
train_data = {label: X_train[y_train == label] for label in label_ids}
val_data = {label: X_val[y_val == label] for label in label_ids}
test_data = {label: X_test[y_test == label] for label in label_ids}

# Create datasets
train_dataset = SupervisedEEGDataset(train_data)
val_dataset = SupervisedEEGDataset(val_data)
test_dataset = SupervisedEEGDataset(test_data)

# Create dataloaders (only the training loader is shuffled)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
logger.info("Supervised dataset created and loaded.")

### --- Train Classifier ---###
sup_cfg = config["sup_training_params"]

# Classifier head on the encoder's latent vectors (input_dim must equal latent_dim).
classifier = SleepStageClassifier(input_dim=config["pretraining_params"]["latent_dim"], num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=sup_cfg["learning_rate"])
logger.info("Classifier created.")

# Freeze the encoder so only the classifier's parameters receive gradients.
# NOTE(review): requires_grad=False does not put BatchNorm/Dropout layers into eval mode;
# confirm that train_classifier calls model.eval() on the encoder.
for param in model.parameters():
    param.requires_grad = False
logger.info("Encoder frozen.")

best_model_path = sup_cfg["best_model_pth"]
train_classifier(model, classifier, train_loader, val_loader, criterion, optimizer, num_epochs=sup_cfg["num_epochs"], device=device, save_path=best_model_path)
logger.info("Classification training completed." if False else "Classifier training completed.")

# Reload the best checkpoint written by train_classifier.
# - weights_only=True restricts unpickling to tensors/plain containers; this is safer and
#   likely silences the torch.load FutureWarning seen in the output.
# - map_location=device lets a checkpoint saved on GPU load on a CPU-only machine.
classifier.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
predictions, true_labels = get_predictions(model, classifier, test_loader, device=device)
save_classification_results(results_folder=config["results_folder"], predictions=predictions, true_labels=true_labels, num_classes=num_classes, experiment_num=config["experiment_num"])
logger.info("Classification results saved.")

2024-11-24 19:54:29,916 - __main__ - INFO - Starting the EEG Project
2024-11-24 19:54:30,479 - __main__ - INFO - Using device: cuda
2024-11-24 19:54:30,482 - __main__ - INFO - EEG data loaded. Number of classes: 5
2024-11-24 19:54:30,498 - __main__ - INFO - Contrastive dataset created and loaded.


Using device: cuda
Found 152 npz files in ./dset/Sleep-EDF-2018/npz/Fpz-Cz.
Processed 1/1 files.
#EEG Epochs for Label 0 : 188
#EEG Epochs for Label 1 : 58
#EEG Epochs for Label 2 : 250
#EEG Epochs for Label 3 : 220
#EEG Epochs for Label 4 : 125


2024-11-24 19:54:31,249 - __main__ - INFO - Model and optimizer created.
2024-11-24 19:54:32,883 - __main__ - INFO - Contrastive pretraining completed.


Contrastive Training - Epoch [1/1], Loss: 5.3580
Starting Latent Space Evaluation...
Embeddings extracted and scaled.
Applying t-SNE...
t-SNE completed.
Applying UMAP...
UMAP completed.
Applying PCA...
PCA completed.
Evaluating metrics...
Metrics for t-SNE saved to latent_space_metrics\t-sne_metrics_1.csv.
Metrics for UMAP saved to latent_space_metrics\umap_metrics_1.csv.




Metrics for PCA saved to latent_space_metrics\pca_metrics_1.csv.
Evaluation metrics computed and saved.
Visualizing embeddings...
Visualization for t-SNE saved to latent_space_viz\t-sne_embeddings_1.png.
Visualization for UMAP saved to latent_space_viz\umap_embeddings_1.png.


2024-11-24 19:54:42,150 - root - INFO - Latent space evaluation completed.
2024-11-24 19:54:42,166 - __main__ - INFO - Supervised dataset created and loaded.
2024-11-24 19:54:42,179 - __main__ - INFO - Classifier created.
2024-11-24 19:54:42,179 - __main__ - INFO - Encoder frozen.


Visualization for PCA saved to latent_space_viz\pca_embeddings_1.png.
Embeddings visualized and saved.
Latent Space Evaluation Completed.


2024-11-24 19:54:42,403 - __main__ - INFO - Classifier training completed.
  classifier.load_state_dict(torch.load(config["sup_training_params"]["best_model_pth"]))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
2024-11-24 19:54:42,462 - __main__ - INFO - Classification results saved.


Epoch [1/1], Train Loss: 1.5006, Val Loss: 1.5879, Val Acc: 0.5298
Saved best model to best_classifier.pth
Overall Accuracy: 0.5385
Macro F1 Score: 0.2867

Classification Report:
              precision    recall  f1-score     support
Class 0        0.000000  0.000000  0.000000   38.000000
Class 1        0.000000  0.000000  0.000000   12.000000
Class 2        0.442478  1.000000  0.613497   50.000000
Class 3        0.732143  0.931818  0.820000   44.000000
Class 4        0.000000  0.000000  0.000000   25.000000
accuracy       0.538462  0.538462  0.538462    0.538462
macro avg      0.234924  0.386364  0.286699  169.000000
weighted avg   0.321528  0.538462  0.394999  169.000000
Results saved to results\overall_1.csv and results\perclass_1.csv


In [None]:
# NOTE(review): commented-out draft of save_classification_results, which is now imported
# from `evaluation` by the pipeline cell. If re-enabled, it also needs
# `import numpy as np` (it calls np.mean). Consider deleting this cell.
# import os
# import pandas as pd
# from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report

# def save_classification_results(results_folder, predictions, true_labels, num_classes, experiment_num):
#     """
#     Saves classification results to CSV files.

#     Parameters:
#     - results_folder (str): Path to the folder where results will be saved.
#     - predictions (np.ndarray): Predicted labels.
#     - true_labels (np.ndarray): True labels.
#     - num_classes (int): Number of classes.
#     - experiment_num (int): Experiment number for file naming.

#     Outputs:
#     - Saves overall metrics and per-class metrics to separate CSV files in the specified folder.
#     """
#     # Create the folder if it doesn't exist
#     os.makedirs(results_folder, exist_ok=True)

#     # Calculate metrics
#     accuracy = accuracy_score(true_labels, predictions)
#     precision, recall, f1_score, support = precision_recall_fscore_support(
#         true_labels, predictions, labels=range(num_classes)
#     )
#     macro_f1 = np.mean(f1_score)

#     # Save overall metrics
#     overall_metrics = {
#         "Metric": ["Overall Accuracy", "Macro F1"],
#         "Value": [accuracy, macro_f1]
#     }
#     overall_metrics_df = pd.DataFrame(overall_metrics)
#     overall_metrics_path = os.path.join(results_folder, f'overall_{experiment_num}.csv')
#     overall_metrics_df.to_csv(overall_metrics_path, index=False)

#     print(f"Overall Accuracy: {accuracy:.4f}")
#     print(f"Macro F1 Score: {macro_f1:.4f}")

#     # Save per-class metrics
#     print("\nClassification Report:")
#     class_report = classification_report(
#         true_labels,
#         predictions,
#         labels=range(num_classes),
#         target_names=[f"Class {i}" for i in range(num_classes)],
#         output_dict=True
#     )
#     print(pd.DataFrame(class_report).transpose())

#     class_metrics_df = pd.DataFrame(class_report).transpose()
#     class_metrics_path = os.path.join(results_folder, f'perclass_{experiment_num}.csv')
#     class_metrics_df.to_csv(class_metrics_path, index=True)

#     print(f"Results saved to {overall_metrics_path} and {class_metrics_path}")


In [24]:
### --- Load EEG Data --- ###
# Imported here as well so this cell does not depend on an earlier cell having run.
import glob
import os

config_path = 'sample_config_2018.json'
with open(config_path, 'r') as f:
    config = json.load(f)

dset_cfg = config['dataset']
root_dir = dset_cfg['root_dir']
dset_name = dset_cfg['name']
eeg_channel = dset_cfg['eeg_channel']
num_splits = dset_cfg['num_splits']
seq_len = dset_cfg['seq_len']
target_idx = dset_cfg['target_idx']

print("Configuration Loaded:")
print(json.dumps(dset_cfg, indent=4))


# Recordings are expected under <root_dir>/dset/<name>/npz/<eeg_channel>/*.npz
dataset_path = os.path.join(root_dir, 'dset', dset_name, 'npz', eeg_channel)
npz_files = sorted(glob.glob(os.path.join(dataset_path, '*.npz')))

print(f"Found {len(npz_files)} npz files in {dataset_path}.")
if not npz_files:
    # Fail with a clear message instead of an IndexError on npz_files[0] below.
    raise FileNotFoundError(f"No .npz files found in {dataset_path}.")

# Peek at the first recording: 'x' holds (epochs, samples), 'y' one label per epoch.
# The context manager closes the underlying .npz (zip) file handle.
sample_file = npz_files[0]
with np.load(sample_file) as data:
    x = data['x']
    y = data['y']
print(f"Loaded data from {sample_file}:")
print("EEG data shape (epochs, samples):", x.shape)
print("Labels shape:", y.shape)

unique, counts = np.unique(y, return_counts=True)

# % of each class:
print("Percentage of each class :")
print(np.asarray((unique, counts*100/np.sum(counts))).T)

fs = 100                      # assumed sampling rate (Hz): 100 x 30 s = 3000 samples per epoch
epoch_length = 30             # seconds per scored EEG epoch
samples_per_epoch = fs * epoch_length
num_classes = 5               # sleep-stage labels 0-4
epochs_per_class = 10         # NOTE(review): not used in this cell

# Number of recordings to load. 1 keeps this loop fast while checking that the pipeline
# runs end to end; set to None to load every file. This replaces a hard-coded `break`,
# which silently stopped after the first file and skewed the progress message.
MAX_FILES = 1
files_to_load = npz_files if MAX_FILES is None else npz_files[:MAX_FILES]

# {label: list of epochs}, one entry per sleep-stage class.
eeg_data = {label: [] for label in range(num_classes)}

for file_idx, npz_path in enumerate(files_to_load, start=1):
    with np.load(npz_path) as data:
        x = data['x']  # EEG epochs
        y = data['y']  # Corresponding labels
        for label in eeg_data:
            eeg_data[label].extend(x[y == label])

    if file_idx % 10 == 0 or file_idx == len(files_to_load):
        print(f"Processed {file_idx}/{len(files_to_load)} files.")

for label in eeg_data:
    eeg_data[label] = np.array(eeg_data[label])
    print(f"#EEG Epochs for Label {label} : {len(eeg_data[label])}")



### --- Pre-Train Encoder with CRL --- ###
# Augmentations used to create the contrastive views (p=1.0, presumably the
# probability of applying each one).
augmentations = [
    RandomAmplitudeScale(p=1.0),
    TailoredMixup(p=1.0),
    RandomAdditiveGaussianNoise(p=1.0),
]

contrastive_dataset = ContrastiveEEGDataset(eeg_data, augmentations=augmentations)
contrastive_loader = DataLoader(
    contrastive_dataset,
    batch_size=128,
    shuffle=True,
)

use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(f"Using device: {device}")

model = SimpleSleepNet(latent_dim=128).to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# 10 epochs for a quick run (an earlier note suggested 500 for a full run).
train_contrastive(model, contrastive_loader, optimizer, num_epochs=10)


### --- Assess Latent Space --- ###

# Same epochs as pre-training, but un-augmented and with labels returned, so the
# evaluator can relate the embeddings to sleep stages.
visualization_dataset = ContrastiveEEGDataset(
    eeg_data=eeg_data,
    augmentations=[],
    return_labels=True
)

visualization_loader = DataLoader(visualization_dataset, batch_size=128, shuffle=False)


evaluator = LatentSpaceEvaluator(
    model=model,
    dataloader=visualization_loader,
    device=device,                  # device selected above; a hard-coded 'cuda' fails on CPU-only machines
    umap_enabled=True,
    pca_enabled=True,
    tsne_enabled=True,
    visualize=True,
    compute_metrics=True,
    n_clusters=num_classes,         # one expected cluster per sleep-stage class
    output_image_dir='latent_space_viz',
    output_metrics_dir='latent_space_metrics',
    augmentation_strategy='1'       # run identifier; appears as the suffix of output files (e.g. umap_metrics_1.csv)
)

evaluator.run()

### --- Train Test Val Split --- ###
# NOTE(review): the split is per EEG epoch, not per subject/recording, so epochs from the
# same night can land in both train and test. The contrastive pre-training above also saw
# every epoch (test ones included, without labels). Test scores are therefore likely
# optimistic; consider a subject-wise split.

# Combine all data and labels
all_data = []
all_labels = []
for label, signals in eeg_data.items():
    all_data.extend(signals)
    all_labels.extend([label] * len(signals))

all_data = np.array(all_data)
all_labels = np.array(all_labels)

# First split into train+val and test sets (80% train+val, 20% test)
X_train_val, X_test, y_train_val, y_test = train_test_split(
    all_data, all_labels, test_size=0.2, stratify=all_labels, random_state=42
)

# Then split train_val into train and val sets (75% / 25% of the 80%)
# -> 60% train, 20% val, 20% test overall.
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.25, stratify=y_train_val, random_state=42
)

# Take the class set from eeg_data's keys, not np.unique(all_labels). If a stage has no
# epochs, np.unique under-counts: the classifier head shrinks, and indexing the per-label
# dicts with a higher label raises KeyError.
label_ids = sorted(eeg_data)
num_classes = len(label_ids)
# CrossEntropyLoss expects integer targets in [0, num_classes).
assert label_ids == list(range(num_classes)), (
    f"Expected labels 0..{num_classes - 1}, got {label_ids}"
)

# Regroup each split into {label: array of epochs}, the layout SupervisedEEGDataset is
# built from. Boolean masks keep the original sample order within each label.
train_data = {label: X_train[y_train == label] for label in label_ids}
val_data = {label: X_val[y_val == label] for label in label_ids}
test_data = {label: X_test[y_test == label] for label in label_ids}

# Create datasets
train_dataset = SupervisedEEGDataset(train_data)
val_dataset = SupervisedEEGDataset(val_data)
test_dataset = SupervisedEEGDataset(test_data)

# Create dataloaders (only the training loader is shuffled)
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### --- Train Classifier ---###
# input_dim must equal the encoder's latent_dim (128 in the pre-training step).
classifier = SleepStageClassifier(input_dim=128, num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=1e-3)

# Freeze the encoder so only the classifier's parameters receive gradients.
# NOTE(review): requires_grad=False does not put BatchNorm/Dropout layers into eval mode;
# confirm that train_classifier calls model.eval() on the encoder.
for param in model.parameters():
    param.requires_grad = False

# Pass the device explicitly, as the config-driven pipeline does, rather than relying on
# train_classifier's default. The best checkpoint goes to its default save path
# (best_classifier.pth, per the run output).
train_classifier(model, classifier, train_loader, val_loader, criterion, optimizer, num_epochs=100, device=device)


### ---- Measure Performance of Network --- ###
# Imported here so the cell does not depend on hidden kernel state from an earlier cell.
from sklearn.metrics import accuracy_score, classification_report, precision_recall_fscore_support

# Reload the best checkpoint saved during classifier training.
# - weights_only=True restricts unpickling to tensors/plain containers; this is safer and
#   likely silences the torch.load FutureWarning seen in the output.
# - map_location=device lets a checkpoint saved on GPU load on a CPU-only machine.
classifier.load_state_dict(torch.load('best_classifier.pth', map_location=device, weights_only=True))
predictions, true_labels = get_predictions(model, classifier, test_loader, device=device)

# Overall Accuracy
accuracy = accuracy_score(true_labels, predictions)
print(f"Overall Accuracy: {accuracy:.4f}")

# Precision, Recall, F1 Score per Class. labels=class_ids forces one entry per class.
# zero_division=0 reports 0 for classes that are never predicted (the same value sklearn
# uses by default) without emitting an UndefinedMetricWarning for each one.
# The per-class F1 array is named f1_per_class so it does not shadow sklearn's f1_score.
class_ids = list(range(num_classes))
precision, recall, f1_per_class, support = precision_recall_fscore_support(
    true_labels, predictions, labels=class_ids, zero_division=0
)

# Macro F1 Score: unweighted mean of the per-class F1 values
macro_f1 = np.mean(f1_per_class)
print(f"Macro F1 Score: {macro_f1:.4f}")

# Print detailed classification report
print("\nClassification Report:")
print(classification_report(
    true_labels,
    predictions,
    labels=class_ids,
    target_names=[f"Class {i}" for i in class_ids],
    zero_division=0,
))

Configuration Loaded:
{
    "name": "Sleep-EDF-2018",
    "eeg_channel": "Fpz-Cz",
    "num_splits": 10,
    "seq_len": 1,
    "target_idx": 0,
    "root_dir": "./"
}
Found 152 npz files in ./dset\Sleep-EDF-2018\npz\Fpz-Cz.
Loaded data from ./dset\Sleep-EDF-2018\npz\Fpz-Cz\SC4001E0.npz:
EEG data shape (epochs, samples): (841, 3000)
Labels shape: (841,)
Percentage of each class :
[[ 0.         22.35434007]
 [ 1.          6.89655172]
 [ 2.         29.72651605]
 [ 3.         26.15933413]
 [ 4.         14.86325803]]
#EEG Epochs for Label 0 : 188
#EEG Epochs for Label 1 : 58
#EEG Epochs for Label 2 : 250
#EEG Epochs for Label 3 : 220
#EEG Epochs for Label 4 : 125
Using device: cuda
Contrastive Training - Epoch [1/10], Loss: 5.3529
Contrastive Training - Epoch [2/10], Loss: 5.3276
Contrastive Training - Epoch [3/10], Loss: 5.3214
Contrastive Training - Epoch [4/10], Loss: 5.2779
Contrastive Training - Epoch [5/10], Loss: 5.2712
Contrastive Training - Epoch [6/10], Loss: 5.2412
Contrastive Tr

  classifier.load_state_dict(torch.load('best_classifier.pth'))
