<a href="https://colab.research.google.com/github/avkornaev/Sleep_Stages/blob/main/SleepStages_vol_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [56]:
!pip install pytorch-lightning clearml



In [57]:
# ---------------------------
# 1. Imports & Environment Setup
# ---------------------------
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import pytorch_lightning as pl
from typing import Optional, List, Dict, Tuple
from collections import defaultdict
from clearml import Task
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.utils.data import Dataset, DataLoader, TensorDataset

In [58]:
# Set up reproducibility
SEED = 42
pl.seed_everything(SEED)
torch.set_float32_matmul_precision('high')

INFO:lightning_fabric.utilities.seed:Seed set to 42


In [59]:
#Enter your code here to implement Step 2 of the logging instruction as it is shown below
%env CLEARML_WEB_HOST=https://app.clear.ml/
%env CLEARML_API_HOST=https://api.clear.ml
%env CLEARML_FILES_HOST=https://files.clear.ml
%env CLEARML_API_ACCESS_KEY=ZP02U03C6V5ER4K9VWRNZT7EWA5ZTV
%env CLEARML_API_SECRET_KEY=BtA5GXZufr6QGpaqhX1GSKPTvaCt56OLqaNqUGLNoxx2Ye8Ctwbui0Ln5OXVnzUgH4I

env: CLEARML_WEB_HOST=https://app.clear.ml/
env: CLEARML_API_HOST=https://api.clear.ml
env: CLEARML_FILES_HOST=https://files.clear.ml
env: CLEARML_API_ACCESS_KEY=ZP02U03C6V5ER4K9VWRNZT7EWA5ZTV
env: CLEARML_API_SECRET_KEY=BtA5GXZufr6QGpaqhX1GSKPTvaCt56OLqaNqUGLNoxx2Ye8Ctwbui0Ln5OXVnzUgH4I


In [60]:
DATASET = 'Sleep_Stages' # dataset with the real-world noise
#Clone the GitHub repository
repo_url = "https://github.com/avkornaev/Sleep_Stages"  # Replace with your repo URL
!git clone {repo_url}

#Navigate to the data folder
repo_name = repo_url.split("/")[-1].replace(".git", "")  # Extract repo name
data_dir = os.path.join(repo_name, "data")  # Replace "data" with your folder name
# os.chdir(data_dir)  # Change working directory to the data folder

# Verify the data directory
if os.path.exists(data_dir):
    print(f"Data directory found: {data_dir}")
else:
    print(f"Data directory not found: {data_dir}")

fatal: destination path 'Sleep_Stages' already exists and is not an empty directory.
Data directory found: Sleep_Stages/data


In [61]:
# ---------------------------
# 2. Configuration Constants
# ---------------------------
CONFIG = {
    "data_dir": "Sleep_Stages/data",
    "checkpoint_path": "saved_models/",
    "batch_size": 128,
    "window_size": 600,
    "learning_rate": 1e-3,
    "max_epochs": 50,
    "num_workers": 2
}

In [62]:
# ---------------------------
# 3. Data Loading & Preprocessing
# ---------------------------
def load_and_preprocess_data(directory: str) -> Tuple[Dict[str, tuple], List[str]]:
    """Robust data loader with full header handling"""
    volunteer_data = defaultdict(list)
    label_sets = []
    feature_means = None

    # First pass: Calculate global feature means
    print("Calculating global feature means...")
    all_features = []
    valid_files = 0

    for filename in os.listdir(directory):
        if not (filename.startswith("Vol_") and filename.endswith(".csv.gz")):
            continue

        file_path = os.path.join(directory, filename)
        try:
            # Skip header and handle types
            df = pd.read_csv(
                file_path,
                compression='gzip',
                skiprows=1,  # Skip header row
                header=None,
                dtype={i: np.float32 for i in range(19)},  # Enforce numeric types
                on_bad_lines='warn'
            )
            df.columns = [f"feature_{i}" for i in range(19)] + ["label"]

            # Validate label column
            df['label'] = df['label'].str.strip().replace('', np.nan)
            df = df.dropna(subset=['label'])

            if not df.empty:
                all_features.append(df.iloc[:, :-1].values)
                valid_files += 1

        except Exception as e:
            print(f"Skipping {filename}: {str(e)}")
            continue

    if not all_features:
        raise ValueError("No valid data found in any files")

    # Calculate global statistics
    feature_means = np.nanmean(np.concatenate(all_features), axis=0)
    print(f"Processed {valid_files} files with global feature means:\n{feature_means}")

    # Second pass: Process files with imputation
    for filename in os.listdir(directory):
        if not (filename.startswith("Vol_") and filename.endswith(".csv.gz")):
            continue

        file_path = os.path.join(directory, filename)
        print(f"\nProcessing: {filename}")

        try:
            # Read with consistent header handling
            df = pd.read_csv(
                file_path,
                compression='gzip',
                skiprows=1,  # Skip header row
                header=None,
                dtype={i: np.float32 for i in range(19)},
                on_bad_lines='warn'
            )
            df.columns = [f"feature_{i}" for i in range(19)] + ["label"]

            # Clean labels
            df['label'] = df['label'].str.strip().replace('', np.nan)
            df = df.dropna(subset=['label'])

            if df.empty:
                print("  No valid rows with labels - skipping")
                continue

            # Impute missing values
            for i in range(19):
                df[f'feature_{i}'] = df[f'feature_{i}'].fillna(feature_means[i])

            # Store data
            features = df.iloc[:, :-1].values.astype(np.float32)
            labels = df['label'].values

            volunteer_id = filename.split("_")[1].split(".")[0]
            volunteer_data[volunteer_id].append((features, labels))
            label_sets.append(set(labels))

            print(f"  Successfully processed {len(features)} samples")
            print(f"  Unique labels: {sorted(set(labels))}")

        except Exception as e:
            print(f"  Failed to process: {str(e)}")
            continue

    # Handle labels with threshold
    label_counts = {}
    for labels in label_sets:
        for lbl in labels:
            label_counts[lbl] = label_counts.get(lbl, 0) + 1

    # Require labels present in at least 60% of volunteers
    min_volunteers = int(1.0 * len(volunteer_data))
    common_labels = [lbl for lbl, count in label_counts.items() if count >= min_volunteers]

    if not common_labels:
        print("\nLabel distribution across volunteers:")
        for i, labels in enumerate(label_sets):
            print(f"Volunteer {i+1}: {sorted(labels)}")
        raise ValueError(f"No labels common across at least {min_volunteers} volunteers")

    print(f"\nFinal labels: {common_labels}")

    # After determining common_labels:
    label_to_id = {label: idx for idx, label in enumerate(common_labels)}
    print(f"\nLabel encoding: {label_to_id}")

    # Process and normalize data
    merged_data = {}
    for vol_id, data_list in volunteer_data.items():
        all_features = np.concatenate([x[0] for x in data_list])
        all_labels = np.concatenate([x[1] for x in data_list])

        # Filter to common labels and encode to integers
        valid_mask = np.isin(all_labels, common_labels)
        features = all_features[valid_mask]
        str_labels = all_labels[valid_mask]
        labels = np.vectorize(label_to_id.get)(str_labels)  # Convert to integers

        if len(features) == 0:
            print(f"Skipping {vol_id} - no common label data")
            continue

        # Normalize features
        features = (features - feature_means) / (np.std(features, axis=0) + 1e-8)

        merged_data[f"Vol_{vol_id}"] = (features, labels)
        print(f"\nVolunteer {vol_id} summary:")
        print(f"- Total samples: {len(features)}")
        print("- Label distribution:")
        for lbl in common_labels:
            print(f"  {lbl}: {np.sum(labels == lbl)} samples")

    return merged_data, common_labels

In [63]:
# ---------------------------
# 4. Dataset & DataModule (Corrected)
# ---------------------------
class SleepDataset(Dataset):
    def __init__(self, features: np.ndarray, labels: np.ndarray, window_size: int):
        self.windows, self.labels = self._create_windows(features, labels, window_size)
        # Convert to tensors during initialization
        self.windows = torch.tensor(self.windows, dtype=torch.float32)
        self.labels = torch.tensor(self.labels, dtype=torch.long)

    def _create_windows(self, features, labels, window_size):
        windows = []
        window_labels = []
        for start in range(0, len(features) - window_size + 1, window_size//2):
            end = start + window_size
            window = features[start:end].T  # (features, time)
            label = labels[end-1]
            windows.append(window)
            window_labels.append(label)
        return np.array(windows), np.array(window_labels)

    def __len__(self):
        return len(self.windows)

    def __getitem__(self, idx):
        return self.windows[idx], self.labels[idx]

class SleepDataModule(pl.LightningDataModule):
    def __init__(self, merged_data: Dict[str, tuple], config: dict):
        super().__init__()  # Initialize parent class
        self.data = merged_data
        self.config = config
        self.batch_size = config['batch_size']
        self.window_size = config['window_size']
        self.datasets = {}

    def setup(self, stage: Optional[str] = None):
        """Create windowed datasets for all volunteers"""
        for vol_id, (features, labels) in self.data.items():
            try:
                # Convert to sliding windows
                windows, window_labels = self._create_windows(features, labels)
                self.datasets[vol_id] = TensorDataset(
                    torch.tensor(windows, dtype=torch.float32),
                    torch.tensor(window_labels, dtype=torch.long)
                )
            except Exception as e:
                print(f"Error processing {vol_id}: {str(e)}")

    def _create_windows(self, features: np.ndarray, labels: np.ndarray):
        """Generate time-series windows with overlap"""
        windows = []
        window_labels = []
        for start in range(0, len(features) - self.window_size + 1, self.window_size // 2):
            end = start + self.window_size
            window = features[start:end].T  # (features, time)
            label = labels[end-1]
            if not np.isnan(window).any():
                windows.append(window)
                window_labels.append(label)
        return np.array(windows), np.array(window_labels)

    def get_loso_splits(self, test_subject: str):
        """Generate LOSO splits for a given test subject"""
        train_data = []
        train_labels = []

        # Aggregate training data
        for vol_id, dataset in self.datasets.items():
            if vol_id != test_subject:
                train_data.append(dataset.tensors[0])
                train_labels.append(dataset.tensors[1])

        return (
            DataLoader(TensorDataset(torch.cat(train_data), torch.cat(train_labels)),
                     batch_size=self.batch_size, shuffle=True),
            DataLoader(self.datasets[test_subject], batch_size=self.batch_size)
        )

In [63]:
# ---------------------------
# 5. Model Architecture
# ---------------------------
class SleepClassifier(nn.Module):
    def __init__(self, input_channels: int, num_classes: int):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv1d(input_channels, 64, 3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(64, 128, 3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Conv1d(128, 256, 3, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        return self.layers(x)

class LightningWrapper(pl.LightningModule):
    def __init__(self, model: nn.Module, lr: float):
        super().__init__()
        self.model = model
        self.lr = lr
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [64]:
# ---------------------------
# 6. Training Pipeline
# ---------------------------
def run_training(config: dict):
    """Complete training workflow"""
    # Initialize ClearML
    task = Task.init(project_name="SleepStaging", task_name="1DCNN-Baseline")
    task.connect(config)

    try:
        # Load and validate data
        merged_data, class_labels = load_and_preprocess_data(config['data_dir'])
        print(f"Loaded data from {len(merged_data)} volunteers")
        print(f"Class distribution: {class_labels}")

        # Initialize components
        datamodule = SleepDataModule(merged_data, config)
        datamodule.setup()

        # Training loop
        for test_subject in datamodule.datasets.keys():
            print(f"\n{'='*40}\nTraining on {len(datamodule.datasets)-1} subjects | Validating on {test_subject}")

            # Model setup
            model = LightningWrapper(
                SleepClassifier(input_channels=19, num_classes=len(class_labels)),
                lr=config['learning_rate']
            )

            # Configure trainer
            trainer = Trainer(
                max_epochs=config['max_epochs'],
                callbacks=[
                    EarlyStopping(monitor="val_acc", patience=5, mode="max"),
                    ModelCheckpoint(
                        dirpath=config['checkpoint_path'],
                        filename=f"best-{test_subject}",
                        monitor="val_acc",
                        mode="max"
                    )
                ],
                deterministic=True
            )

            # Get data loaders
            train_loader, val_loader = datamodule.get_loso_splits(test_subject)

            # Train/validate
            trainer.fit(model, train_loader, val_loader)

    except Exception as e:
        task.get_logger().report_text(f"Training failed: {str(e)}")
        raise
    finally:
        task.close()

# ---------------------------
# 7. Main Execution
# ---------------------------
if __name__ == "__main__":
    os.makedirs(CONFIG['checkpoint_path'], exist_ok=True)
    run_training(CONFIG)

ClearML Task: created new task id=038d7dd85f754a778f066ed267e6f570
ClearML results page: https://app.clear.ml/projects/3bb8fcd9215b413abba86250d3b3ee7f/experiments/038d7dd85f754a778f066ed267e6f570/output/log
Calculating global feature means...
ClearML Task: created new task id=038d7dd85f754a778f066ed267e6f570
ClearML results page: https://app.clear.ml/projects/3bb8fcd9215b413abba86250d3b3ee7f/experiments/038d7dd85f754a778f066ed267e6f570/output/log
ClearML Monitor: GPU monitoring failed getting GPU reading, switching off GPU monitoring
Processed 5 files with global feature means:
[ 18.295462   36.243195  122.49296    44.21548     0.3589693  26.968536
  38.03891   109.97027    80.47851     5.042736   25.270573   45.850376
  17.166382    6.484191   61.160995   38.58385   116.55848    66.4239
   8.278388 ]

Processing: Vol_02.csv.gz
  Successfully processed 133328 samples
  Unique labels: ['N2', 'R', 'W']

Processing: Vol_03_1.csv.gz
  Successfully processed 491994 samples
  Unique labels:

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs



Training on 2 subjects | Validating on Vol_02


INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | model   | SleepClassifier  | 128 K  | train
1 | loss_fn | CrossEntropyLoss | 0      | train
-----------------------------------------------------
128 K     Trainable params
0         Non-trainable params
128 K     Total params
0.516     Total estimated model params size (MB)
17        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]


The number of training batches (48) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs

Checkpoint directory /content/saved_models exists and is not empty.

INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | model   | SleepClassifier  | 128 K  | train
1 | loss_fn | CrossEntropyLoss | 0      | train
-----------------------------------------------------
128 K     Trainable params
0         Non-trainable params
128 K     Total params
0.516     Total estimated model params size (MB)
17        Modules in train mode
0         Modules in eval mode



Training on 2 subjects | Validating on Vol_03


Sanity Checking: |          | 0/? [00:00<?, ?it/s]


The number of training batches (13) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs



Training on 2 subjects | Validating on Vol_01


INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | model   | SleepClassifier  | 128 K  | train
1 | loss_fn | CrossEntropyLoss | 0      | train
-----------------------------------------------------
128 K     Trainable params
0         Non-trainable params
128 K     Total params
0.516     Total estimated model params size (MB)
17        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]


The number of training batches (42) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]