In [1]:
import numpy as np
import pandas as pd
import mne
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score ,cohen_kappa_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os
from scipy import signal


In [2]:

# Set random seeds for reproducibility
# np.random.seed(42)
# torch.manual_seed(42)

# Define constants
N_CHANNELS_PSG = 6
N_CHANNELS_HEADBAND = 2
EPOCH_LENGTH_SEC = 30  
SAMPLE_RATE = 256  
FFT_WINDOW_SIZE = 4  # seconds
BATCH_SIZE = 64
NUM_EPOCHS = 50
LEARNING_RATE = 0.0001

# Sleep stage labels (standard)
SLEEP_STAGES = {
    0: 'Wake',
    1: 'N1',
    2: 'N2',
    3: 'N3',
    4: 'REM'
}
SLEEP_STAGES_reverse = {
    0:'W',    # Wake
    1:'N1',   # Non-REM stage 1
    2: 'N2',   # Non-REM stage 2
    3: 'N3',   # Non-REM stage 3
    4: 'R'     # REM sleep
}
def load_edf_files(psg_path, headband_path):
    """Load PSG and headband EDF files"""
    print("Loading EDF files...")
    
    # Load PSG data
    psg_raw = mne.io.read_raw_edf(psg_path, preload=True)
    
    # Load headband data
    headband_raw = mne.io.read_raw_edf(headband_path, preload=True)
    
    print(f"PSG channels: {psg_raw.ch_names}")
    print(f"Headband channels: {headband_raw.ch_names}")
    
    return psg_raw, headband_raw
    
def load_hypnogram(hypno_path):
    """Load sleep stages from a hypnogram file"""
    # This function would need to be adapted based on your hypnogram format
    # For this example, we'll assume a simple CSV with epoch number and stage
    # .map(SLEEP_STAGES_reverse).fillna('W')
    hypno_df = pd.read_csv(hypno_path, sep='\t')
    sleep_stages = hypno_df['majority'].apply(lambda s: s if s in [0, 1, 2, 3, 4] else 0).values
    
    return sleep_stages
    
def synchronize_recordings(psg_raw, headband_raw):
    """Synchronize PSG and headband recordings based on timestamps"""
    # Extract start times
    psg_start = psg_raw.info['meas_date']
    headband_start = headband_raw.info['meas_date']
    
    print(f"PSG start time: {psg_start}")
    print(f"Headband start time: {headband_start}")
    
    # Calculate offset
    if psg_start and headband_start:
        time_diff = (headband_start - psg_start).total_seconds()
        print(f"Time difference: {time_diff} seconds")
        
        # Determine which recording started first and crop accordingly
        if time_diff > 0:
            # Headband started later, crop PSG
            psg_raw.crop(tmin=time_diff)
        else:
            # PSG started later, crop headband
            headband_raw.crop(tmin=-time_diff)
    
    # Ensure both recordings have the same duration
    duration = min(psg_raw.times[-1], headband_raw.times[-1])
    psg_raw.crop(tmax=duration)
    headband_raw.crop(tmax=duration)
    
    print(f"Synchronized duration: {duration} seconds")
    
    return psg_raw, headband_raw

def extract_features(raw_data, channel_names, window_size=FFT_WINDOW_SIZE, sample_rate=SAMPLE_RATE):
    """Extract frequency domain features from EEG signals"""
    # Select only the channels we want
    data = raw_data.get_data(picks=channel_names)
    
    # Calculate number of epochs
    n_epochs = int(data.shape[1] / (EPOCH_LENGTH_SEC * sample_rate))
    
    # Initialize feature arrays
    n_channels = len(channel_names)
    
    # Define frequency bands
    bands = {
        'delta': (0.5, 4),
        'theta': (4, 8),
        'alpha': (8, 13),
        'beta': (13, 30),
        'gamma': (30, 45)
    }
    
    # Initialize feature matrix
    n_bands = len(bands)
    features = np.zeros((n_epochs, n_channels * n_bands * 2))  # 2 features per band (power, rel_power)
    
    # Process each epoch
    for epoch_idx in range(n_epochs):
        start_sample = epoch_idx * EPOCH_LENGTH_SEC * sample_rate
        end_sample = (epoch_idx + 1) * EPOCH_LENGTH_SEC * sample_rate
        
        # Extract epoch data
        epoch_data = data[:, start_sample:end_sample]
        
        # Calculate features for each channel
        for ch_idx in range(n_channels):
            channel_data = epoch_data[ch_idx, :]
            
            # Compute power spectral density
            freqs, psd = signal.welch(channel_data, fs=sample_rate, nperseg=window_size*sample_rate)
            
            # Calculate band powers
            feature_idx = ch_idx * n_bands * 2
            total_power = np.sum(psd)
            
            for band_idx, (band_name, (low_freq, high_freq)) in enumerate(bands.items()):
                # Find indices corresponding to the frequency band
                idx_band = np.logical_and(freqs >= low_freq, freqs <= high_freq)
                
                # Calculate absolute and relative band power
                band_power = np.sum(psd[idx_band])
                rel_power = band_power / total_power if total_power > 0 else 0
                
                # Store features
                features[epoch_idx, feature_idx + band_idx*2] = band_power
                features[epoch_idx, feature_idx + band_idx*2 + 1] = rel_power
    
    return features


In [3]:
class SleepDataset(Dataset):
    """PyTorch dataset for sleep staging"""
    def __init__(self, X, y):
        # Convert X and y to numpy arrays with proper types first
        if isinstance(X, np.ndarray) and X.dtype == object:
            X = np.stack(X).astype(np.float32)
        if isinstance(y, np.ndarray) and y.dtype == object:
            y = np.array(y, dtype=np.int64)
            
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


In [4]:
class TransformerModel(nn.Module):
    """Simple transformer model for sleep stage classification"""
    def __init__(self, input_dim, num_classes, d_model=64, nhead=4, num_layers=2, dropout=0.1):
        super(TransformerModel, self).__init__()
        
        # Reshape input to sequence for transformer
        # For spectral features, we'll treat each channel-band combination as a token
        self.input_dim = input_dim
        self.d_model = d_model
        
        # Project input to d_model dimensions
        self.input_projection = nn.Linear(input_dim, d_model)
        
        # Create positional encoding (fixed)
        self.pos_encoder = nn.Sequential(
            nn.Linear(1, d_model),
            nn.GELU()
        )
        
        # Transformer encoder
        encoder_layers = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, 
                                                  dim_feedforward=d_model*4,
                                                  dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, num_classes)
        )
    
    def forward(self, x):
        # Reshape input to sequence form
        # Here we're treating the whole feature vector as a single token
        # For more complex models, you could reshape to have multiple tokens
        batch_size = x.size(0)
        
        # Project input
        x = self.input_projection(x).unsqueeze(1)  # [batch, 1, d_model]
        
        # Add positional encoding
        positions = torch.zeros(1, 1, 1).to(x.device)
        pos_encoding = self.pos_encoder(positions)
        x = x + pos_encoding
        
        # Apply transformer
        x = self.transformer_encoder(x)
        
        # Take the output corresponding to the first position
        x = x[:, 0, :]
        
        # Classify
        x = self.classifier(x)
        
        return x

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    """Train the model"""
    model.to(device)
    best_val_acc = 0.0
    best_model = None
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]"):
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            # Statistics
            train_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            train_correct += (predicted == labels).sum().item()
            train_total += labels.size(0)
        
        train_loss = train_loss / train_total
        train_acc = train_correct / train_total
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]"):
                inputs, labels = inputs.to(device), labels.to(device)
                
                # Forward pass
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                # Statistics
                val_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                val_correct += (predicted == labels).sum().item()
                val_total += labels.size(0)
        
        val_loss = val_loss / val_total
        val_acc = val_correct / val_total
        
        print(f"Epoch {epoch+1}/{num_epochs}: "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = model.state_dict().copy()
    
    # Load best model
    model.load_state_dict(best_model)
    
    return model

def evaluate_model(model, test_loader, device):
    """Evaluate the model on test data"""
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Forward pass
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Calculate metrics
    acc = accuracy_score(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, labels=list(SLEEP_STAGES.keys()), target_names=list(SLEEP_STAGES.values()), output_dict=True)
    
    # Create confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    print("Cohen’s Kappa:", cohen_kappa_score(all_labels, all_preds))
    return acc, report, cm, all_preds, all_labels

def visualize_results(cm, report, sleep_stages):
    """Visualize classification results"""
    # Plot confusion matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=list(sleep_stages.values()),
                yticklabels=list(sleep_stages.values()))
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.close()
    
    # Plot performance metrics
    metrics = ['precision', 'recall', 'f1-score']
    plt.figure(figsize=(12, 6))
    
    for i, metric in enumerate(metrics):
        plt.subplot(1, 3, i+1)
        values = [report[stage][metric] for stage in sleep_stages.values()]
        sns.barplot(x=list(sleep_stages.values()), y=values)
        plt.title(f'{metric.capitalize()}')
        plt.ylim(0, 1)
        plt.xticks(rotation=45)
    
    plt.tight_layout()
    plt.savefig('performance_metrics.png')
    plt.close()


In [5]:
"""Main function to compare PSG and headband recordings and evaluate sleep stage prediction"""
# File paths 
sub_no = 94
psg_path = f"Dataset_clean_for_jupyter/sub-{sub_no}/eeg/sub-{sub_no}_task-Sleep_acq-psg_eeg_6-channels.edf"
headband_path = f"Dataset_clean_for_jupyter/sub-{sub_no}/eeg/sub-{sub_no}_task-Sleep_acq-headband_eeg_2-channels.edf" 
hypnogram_path = f"Dataset_clean_for_jupyter/sub-{sub_no}/eeg/sub-{sub_no}_task-Sleep_acq-psg_events.tsv"

# 1. Load EDF files
psg_raw, headband_raw = load_edf_files(psg_path, headband_path)

# 2. Synchronize recordings
# psg_raw, headband_raw = synchronize_recordings(psg_raw, headband_raw)

# 3. Extract features
psg_channels = ['PSG_F3', 'PSG_F4', 'PSG_C3', 'PSG_C4', 'PSG_O1', 'PSG_O2']
headband_channels = ['HB_1', 'HB_2']

print("Extracting PSG features...")
psg_features = extract_features(psg_raw, psg_channels)

print("Extracting headband features...")
headband_features = extract_features(headband_raw, headband_channels)

# 4. Load sleep stages (ground truth)
try:
    sleep_stages = load_hypnogram(hypnogram_path)
    print((sleep_stages))
    
    print(f"Loaded {len(sleep_stages)} sleep stage labels")
except Exception as e:
    print(f"Error loading hypnogram: {e}")
    print("Generating synthetic labels for demonstration purposes")
    # Create synthetic labels for demonstration
    n_epochs = psg_features.shape[0]
    sleep_stages = np.random.randint(0, 5, size=n_epochs)


Loading EDF files...
Extracting EDF parameters from C:\Users\naikh\SleepResearchCode_experient\Dataset_clean_for_jupyter\sub-94\eeg\sub-94_task-Sleep_acq-psg_eeg_6-channels.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 7720703  =      0.000 ... 30158.996 secs...
Extracting EDF parameters from C:\Users\naikh\SleepResearchCode_experient\Dataset_clean_for_jupyter\sub-94\eeg\sub-94_task-Sleep_acq-headband_eeg_2-channels.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 7720703  =      0.000 ... 30158.996 secs...
PSG channels: ['PSG_F3', 'PSG_F4', 'PSG_C3', 'PSG_C4', 'PSG_O1', 'PSG_O2']
Headband channels: ['HB_1', 'HB_2']
Extracting PSG features...
Extracting headband features...
[0 0 0 ... 2 2 0]
Loaded 1005 sleep stage labels


In [6]:

# 5. Split data into train, validation, and test sets
X_train, X_temp, y_train, y_temp = train_test_split(
    headband_features, sleep_stages, test_size=0.3, random_state=42, stratify=sleep_stages
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp
)

print(f"Training set: {X_train.shape[0]} samples")
print(f"Validation set: {X_val.shape[0]} samples")
print(f"Test set: {X_test.shape[0]} samples")

# 6. Scale features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)



Training set: 703 samples
Validation set: 151 samples
Test set: 151 samples


In [7]:

# 7. Create data loaders
train_dataset = SleepDataset(X_train, y_train)
val_dataset = SleepDataset(X_val, y_val)
test_dataset = SleepDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [8]:

# 8. Initialize model
input_dim = X_train.shape[1]
num_classes = len(SLEEP_STAGES)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TransformerModel(input_dim, num_classes)
print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

# 9. Train model
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

model = train_model(model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS, device)

# 10. Evaluate model
accuracy, report, cm, predictions, true_labels = evaluate_model(model, test_loader, device)

print(f"\nTest Accuracy: {accuracy:.4f}")
print("\nClassification Report:")
for stage, metrics in report.items():
    if stage in SLEEP_STAGES.values():
        print(f"{stage}: Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}, F1: {metrics['f1-score']:.4f}")

# 11. Visualize results
visualize_results(cm, report, SLEEP_STAGES)

# 12. Compare with PSG-based model (optional)
print("\nComparing with PSG-based model:")
# Here you could implement a similar pipeline for PSG features
# and compare the performance

# 13. Generate report on headband accuracy
print("\nHeadband Accuracy Assessment:")
print(f"Overall accuracy: {accuracy:.4f}")

weighted_f1 = report['weighted avg']['f1-score']
print(f"Weighted F1 score: {weighted_f1:.4f}")

# Sleep architecture analysis
true_distribution = np.bincount(true_labels, minlength=len(SLEEP_STAGES)) / len(true_labels)
pred_distribution = np.bincount(predictions, minlength=len(SLEEP_STAGES)) / len(predictions)

print("\nSleep Architecture Comparison:")
for i, stage in SLEEP_STAGES.items():
    print(f"{stage}: True: {true_distribution[i]:.4f}, Predicted: {pred_distribution[i]:.4f}, " 
          f"Difference: {abs(true_distribution[i] - pred_distribution[i]):.4f}")

# Final assessment
if accuracy > 0.8:
    assessment = "Excellent"
elif accuracy > 0.7:
    assessment = "Good"
elif accuracy > 0.6:
    assessment = "Moderate"
else:
    assessment = "Poor"

print(f"\nOverall assessment: The headband shows {assessment} accuracy for sleep stage prediction.")


Model created with 103685 parameters


Epoch 1/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 64.71it/s]
Epoch 1/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 230.73it/s]


Epoch 1/50: Train Loss: 1.5185, Train Acc: 0.3499, Val Loss: 1.4231, Val Acc: 0.6093


Epoch 2/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 72.37it/s]
Epoch 2/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 249.88it/s]


Epoch 2/50: Train Loss: 1.3194, Train Acc: 0.6686, Val Loss: 1.2417, Val Acc: 0.6623


Epoch 3/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 72.85it/s]
Epoch 3/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 272.76it/s]


Epoch 3/50: Train Loss: 1.1686, Train Acc: 0.7041, Val Loss: 1.0991, Val Acc: 0.6556


Epoch 4/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 73.83it/s]
Epoch 4/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 250.06it/s]


Epoch 4/50: Train Loss: 1.0436, Train Acc: 0.6956, Val Loss: 0.9990, Val Acc: 0.6623


Epoch 5/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 80.89it/s]
Epoch 5/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 272.39it/s]


Epoch 5/50: Train Loss: 0.9557, Train Acc: 0.6970, Val Loss: 0.9300, Val Acc: 0.6623


Epoch 6/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 76.39it/s]
Epoch 6/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 272.46it/s]


Epoch 6/50: Train Loss: 0.8930, Train Acc: 0.6970, Val Loss: 0.8782, Val Acc: 0.6623


Epoch 7/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 80.23it/s]
Epoch 7/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 272.62it/s]


Epoch 7/50: Train Loss: 0.8534, Train Acc: 0.7183, Val Loss: 0.8353, Val Acc: 0.6821


Epoch 8/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 82.09it/s]
Epoch 8/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 332.94it/s]


Epoch 8/50: Train Loss: 0.8037, Train Acc: 0.7368, Val Loss: 0.7963, Val Acc: 0.7152


Epoch 9/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 80.29it/s]
Epoch 9/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 300.16it/s]


Epoch 9/50: Train Loss: 0.7689, Train Acc: 0.7568, Val Loss: 0.7609, Val Acc: 0.7285


Epoch 10/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 82.70it/s]
Epoch 10/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 299.79it/s]


Epoch 10/50: Train Loss: 0.7275, Train Acc: 0.7667, Val Loss: 0.7288, Val Acc: 0.7483


Epoch 11/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 81.48it/s]
Epoch 11/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 272.78it/s]


Epoch 11/50: Train Loss: 0.6945, Train Acc: 0.7752, Val Loss: 0.7006, Val Acc: 0.7550


Epoch 12/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 90.91it/s]
Epoch 12/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 272.92it/s]


Epoch 12/50: Train Loss: 0.6770, Train Acc: 0.7767, Val Loss: 0.6752, Val Acc: 0.7616


Epoch 13/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 76.94it/s]
Epoch 13/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 272.58it/s]


Epoch 13/50: Train Loss: 0.6456, Train Acc: 0.7866, Val Loss: 0.6520, Val Acc: 0.7550


Epoch 14/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 68.74it/s]
Epoch 14/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 250.21it/s]


Epoch 14/50: Train Loss: 0.6175, Train Acc: 0.7966, Val Loss: 0.6328, Val Acc: 0.7682


Epoch 15/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 69.62it/s]
Epoch 15/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 230.99it/s]


Epoch 15/50: Train Loss: 0.5967, Train Acc: 0.7966, Val Loss: 0.6145, Val Acc: 0.7682


Epoch 16/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 71.43it/s]
Epoch 16/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 230.68it/s]


Epoch 16/50: Train Loss: 0.5751, Train Acc: 0.7966, Val Loss: 0.5972, Val Acc: 0.7815


Epoch 17/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 67.90it/s]
Epoch 17/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 272.52it/s]


Epoch 17/50: Train Loss: 0.5585, Train Acc: 0.8193, Val Loss: 0.5816, Val Acc: 0.7881


Epoch 18/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 67.91it/s]
Epoch 18/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 142.86it/s]


Epoch 18/50: Train Loss: 0.5430, Train Acc: 0.8222, Val Loss: 0.5632, Val Acc: 0.7881


Epoch 19/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 63.58it/s]
Epoch 19/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 230.60it/s]


Epoch 19/50: Train Loss: 0.5271, Train Acc: 0.8250, Val Loss: 0.5467, Val Acc: 0.7881


Epoch 20/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 87.31it/s]
Epoch 20/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 230.65it/s]


Epoch 20/50: Train Loss: 0.5042, Train Acc: 0.8265, Val Loss: 0.5296, Val Acc: 0.7881


Epoch 21/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 70.31it/s]
Epoch 21/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 166.55it/s]


Epoch 21/50: Train Loss: 0.4863, Train Acc: 0.8350, Val Loss: 0.5216, Val Acc: 0.7947


Epoch 22/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 62.50it/s]
Epoch 22/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 176.61it/s]


Epoch 22/50: Train Loss: 0.4764, Train Acc: 0.8336, Val Loss: 0.5034, Val Acc: 0.7881


Epoch 23/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 62.50it/s]
Epoch 23/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 230.65it/s]


Epoch 23/50: Train Loss: 0.4700, Train Acc: 0.8236, Val Loss: 0.4917, Val Acc: 0.8278


Epoch 24/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 77.16it/s]
Epoch 24/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 200.05it/s]


Epoch 24/50: Train Loss: 0.4550, Train Acc: 0.8378, Val Loss: 0.4801, Val Acc: 0.8477


Epoch 25/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 68.75it/s]
Epoch 25/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 374.93it/s]


Epoch 25/50: Train Loss: 0.4485, Train Acc: 0.8464, Val Loss: 0.4671, Val Acc: 0.8477


Epoch 26/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 87.31it/s]
Epoch 26/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 300.12it/s]


Epoch 26/50: Train Loss: 0.4428, Train Acc: 0.8464, Val Loss: 0.4519, Val Acc: 0.8477


Epoch 27/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 84.62it/s]
Epoch 27/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 374.67it/s]


Epoch 27/50: Train Loss: 0.4222, Train Acc: 0.8620, Val Loss: 0.4555, Val Acc: 0.8477


Epoch 28/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 88.71it/s]
Epoch 28/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 333.77it/s]


Epoch 28/50: Train Loss: 0.4156, Train Acc: 0.8435, Val Loss: 0.4409, Val Acc: 0.8477


Epoch 29/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 87.28it/s]
Epoch 29/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 374.75it/s]


Epoch 29/50: Train Loss: 0.4065, Train Acc: 0.8478, Val Loss: 0.4335, Val Acc: 0.8477


Epoch 30/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 82.71it/s]
Epoch 30/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 374.90it/s]


Epoch 30/50: Train Loss: 0.4047, Train Acc: 0.8620, Val Loss: 0.4301, Val Acc: 0.8411


Epoch 31/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 94.83it/s]
Epoch 31/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 374.97it/s]


Epoch 31/50: Train Loss: 0.3834, Train Acc: 0.8649, Val Loss: 0.4211, Val Acc: 0.8477


Epoch 32/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 93.50it/s]
Epoch 32/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 374.71it/s]


Epoch 32/50: Train Loss: 0.3869, Train Acc: 0.8706, Val Loss: 0.4262, Val Acc: 0.8411


Epoch 33/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 92.44it/s]
Epoch 33/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 374.96it/s]


Epoch 33/50: Train Loss: 0.3911, Train Acc: 0.8634, Val Loss: 0.4160, Val Acc: 0.8477


Epoch 34/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 87.30it/s]
Epoch 34/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 375.13it/s]


Epoch 34/50: Train Loss: 0.3862, Train Acc: 0.8606, Val Loss: 0.4141, Val Acc: 0.8411


Epoch 35/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 86.62it/s]
Epoch 35/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 375.33it/s]


Epoch 35/50: Train Loss: 0.3733, Train Acc: 0.8706, Val Loss: 0.4022, Val Acc: 0.8411


Epoch 36/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 93.22it/s]
Epoch 36/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 428.73it/s]


Epoch 36/50: Train Loss: 0.3769, Train Acc: 0.8620, Val Loss: 0.3990, Val Acc: 0.8477


Epoch 37/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 97.35it/s]
Epoch 37/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 374.94it/s]


Epoch 37/50: Train Loss: 0.3735, Train Acc: 0.8634, Val Loss: 0.3977, Val Acc: 0.8543


Epoch 38/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 94.02it/s]
Epoch 38/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 375.13it/s]


Epoch 38/50: Train Loss: 0.3555, Train Acc: 0.8805, Val Loss: 0.3942, Val Acc: 0.8543


Epoch 39/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 92.44it/s]
Epoch 39/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 333.30it/s]


Epoch 39/50: Train Loss: 0.3745, Train Acc: 0.8720, Val Loss: 0.3906, Val Acc: 0.8609


Epoch 40/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 96.49it/s]
Epoch 40/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 374.84it/s]


Epoch 40/50: Train Loss: 0.3683, Train Acc: 0.8663, Val Loss: 0.4013, Val Acc: 0.8543


Epoch 41/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 92.89it/s]
Epoch 41/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 300.18it/s]


Epoch 41/50: Train Loss: 0.3612, Train Acc: 0.8748, Val Loss: 0.3842, Val Acc: 0.8609


Epoch 42/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 94.83it/s]
Epoch 42/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 333.76it/s]


Epoch 42/50: Train Loss: 0.3531, Train Acc: 0.8777, Val Loss: 0.3816, Val Acc: 0.8609


Epoch 43/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 94.02it/s]
Epoch 43/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 375.18it/s]


Epoch 43/50: Train Loss: 0.3533, Train Acc: 0.8734, Val Loss: 0.3815, Val Acc: 0.8543


Epoch 44/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 98.46it/s]
Epoch 44/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 300.02it/s]


Epoch 44/50: Train Loss: 0.3561, Train Acc: 0.8762, Val Loss: 0.3807, Val Acc: 0.8675


Epoch 45/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 88.72it/s]
Epoch 45/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 374.74it/s]


Epoch 45/50: Train Loss: 0.3546, Train Acc: 0.8777, Val Loss: 0.3749, Val Acc: 0.8609


Epoch 46/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 86.62it/s]
Epoch 46/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 374.76it/s]


Epoch 46/50: Train Loss: 0.3388, Train Acc: 0.8777, Val Loss: 0.3812, Val Acc: 0.8543


Epoch 47/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 93.22it/s]
Epoch 47/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 375.11it/s]


Epoch 47/50: Train Loss: 0.3453, Train Acc: 0.8720, Val Loss: 0.3727, Val Acc: 0.8609


Epoch 48/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 91.02it/s]
Epoch 48/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 375.14it/s]


Epoch 48/50: Train Loss: 0.3410, Train Acc: 0.8777, Val Loss: 0.3729, Val Acc: 0.8609


Epoch 49/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 86.61it/s]
Epoch 49/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 374.89it/s]


Epoch 49/50: Train Loss: 0.3499, Train Acc: 0.8748, Val Loss: 0.3675, Val Acc: 0.8675


Epoch 50/50 [Train]: 100%|██████████| 11/11 [00:00<00:00, 49.77it/s]
Epoch 50/50 [Val]: 100%|██████████| 3/3 [00:00<00:00, 265.94it/s]


Epoch 50/50: Train Loss: 0.3395, Train Acc: 0.8791, Val Loss: 0.3695, Val Acc: 0.8675


Evaluating: 100%|██████████| 3/3 [00:00<00:00, 375.31it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Cohen’s Kappa: 0.6852974611595302

Test Accuracy: 0.8543

Classification Report:
Wake: Precision: 0.5625, Recall: 0.9000, F1: 0.6923
N1: Precision: 1.0000, Recall: 0.5000, F1: 0.6667
N2: Precision: 0.9065, Recall: 0.9151, F1: 0.9108
N3: Precision: 0.0000, Recall: 0.0000, F1: 0.0000
REM: Precision: 0.8000, Recall: 0.6897, F1: 0.7407

Comparing with PSG-based model:

Headband Accuracy Assessment:
Overall accuracy: 0.8543
Weighted F1 score: 0.8540

Sleep Architecture Comparison:
Wake: True: 0.0662, Predicted: 0.1060, Difference: 0.0397
N1: True: 0.0397, Predicted: 0.0199, Difference: 0.0199
N2: True: 0.7020, Predicted: 0.7086, Difference: 0.0066
N3: True: 0.0000, Predicted: 0.0000, Difference: 0.0000
REM: True: 0.1921, Predicted: 0.1656, Difference: 0.0265

Overall assessment: The headband shows Excellent accuracy for sleep stage prediction.


In [9]:
print(f"\nTest Accuracy: {accuracy:.4f}")

print("\nClassification Report:")
for stage, metrics in report.items():
    if stage in SLEEP_STAGES.values():
        print(f"{stage}: Precision: {metrics['precision']:.4f}, "
              f"Recall: {metrics['recall']:.4f}, "
              f"F1: {metrics['f1-score']:.4f}")

# 11. Visualize results
visualize_results(cm, report, SLEEP_STAGES)

# 12. Compare with PSG-based model (optional)
print("\nComparing with PSG-based model:")
# Here you could implement a similar pipeline for PSG features
# and compare the performance

# 13. Generate report on headband accuracy
print("\nHeadband Accuracy Assessment:")
print(f"Overall accuracy: {accuracy:.4f}")
weighted_f1 = report['weighted avg']['f1-score']
print(f"Weighted F1 score: {weighted_f1:.4f}")



Test Accuracy: 0.8543

Classification Report:
Wake: Precision: 0.5625, Recall: 0.9000, F1: 0.6923
N1: Precision: 1.0000, Recall: 0.5000, F1: 0.6667
N2: Precision: 0.9065, Recall: 0.9151, F1: 0.9108
N3: Precision: 0.0000, Recall: 0.0000, F1: 0.0000
REM: Precision: 0.8000, Recall: 0.6897, F1: 0.7407

Comparing with PSG-based model:

Headband Accuracy Assessment:
Overall accuracy: 0.8543
Weighted F1 score: 0.8540


In [10]:
# Generate comprehensive sleep stage evaluation table
from tabulate import tabulate


# Sleep architecture analysis
true_distribution = np.bincount(true_labels, minlength=len(SLEEP_STAGES)) / len(true_labels)
pred_distribution = np.bincount(predictions, minlength=len(SLEEP_STAGES)) / len(predictions)


# Prepare data for the table
table_data = []
headers = ["Sleep Stage", "True %", "Predicted %", "Difference", "Precision", "Recall", "F1-Score"]

# Populate the table data
for i, stage in SLEEP_STAGES.items():
    if stage in report:  # Make sure the stage exists in the report
        row = [
            stage,
            f"{true_distribution[i]:.2%}",
            f"{pred_distribution[i]:.2%}",
            f"{abs(true_distribution[i] - pred_distribution[i]):.2%}",
            f"{report[stage]['precision']:.4f}",
            f"{report[stage]['recall']:.4f}",
            f"{report[stage]['f1-score']:.4f}"
        ]
        table_data.append(row)

# Add summary row
table_data.append([
    "Overall",
    "100.00%",
    "100.00%",
    f"{sum(abs(true_distribution - pred_distribution))/2:.2%}",  # Total distribution error
    f"{report['weighted avg']['precision']:.4f}",
    f"{report['weighted avg']['recall']:.4f}",
    f"{report['weighted avg']['f1-score']:.4f}"
])

# Print the table
print("\nComprehensive Sleep Stage Evaluation:")
print(tabulate(table_data, headers=headers, tablefmt="grid"))




Comprehensive Sleep Stage Evaluation:
+---------------+----------+---------------+--------------+-------------+----------+------------+
| Sleep Stage   | True %   | Predicted %   | Difference   |   Precision |   Recall |   F1-Score |
| Wake          | 6.62%    | 10.60%        | 3.97%        |      0.5625 |   0.9    |     0.6923 |
+---------------+----------+---------------+--------------+-------------+----------+------------+
| N1            | 3.97%    | 1.99%         | 1.99%        |      1      |   0.5    |     0.6667 |
+---------------+----------+---------------+--------------+-------------+----------+------------+
| N2            | 70.20%   | 70.86%        | 0.66%        |      0.9065 |   0.9151 |     0.9108 |
+---------------+----------+---------------+--------------+-------------+----------+------------+
| N3            | 0.00%    | 0.00%         | 0.00%        |      0      |   0      |     0      |
+---------------+----------+---------------+--------------+-------------+------

In [11]:
# Print overall assessment
# Final assessment
if accuracy > 0.8:
    assessment = "Excellent"
elif accuracy > 0.7:
    assessment = "Good"
elif accuracy > 0.6:
    assessment = "Moderate"
else:
    assessment = "Poor"
print(f"\nOverall assessment: The headband shows {assessment} accuracy for sleep stage prediction.")
print(f"Test Accuracy: {accuracy:.4f}")


Overall assessment: The headband shows Excellent accuracy for sleep stage prediction.
Test Accuracy: 0.8543
