<a href="https://colab.research.google.com/github/avkornaev/Sleep_Stages/blob/main/SleepStages.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Sleep Stages
*March, 21, 2025.*

## Problem Statement

Given data files for a several volunteers. Each file contain time series multi-sensory data and their processing results (19 columns total) and a column for label at each time step. The equipment used for data collection implements non-invasive optical diagnostic methods for recording peripheral blood flow parameters using laser Doppler flowmetry (LDF) and tissue oxidative metabolism through fluorescence spectroscopy (FS). The data were recorded during several hour of sleep. The time step was 0.05 s for all the sensors. The label is a sleep stage (NREM1, NREM2, REM, Wakefulness). Some of the data is missed. [Repo
](https://github.com/avkornaev/Sleep_Stages) of the project.

## Tasks and Requirements  

- Review the [Lightning framework](https://lightning.ai/docs/pytorch/stable/) (Level Up, Core API, Optional API sections of the manual).  
- Briefly review the [ClearML](https://clear.ml/docs/latest/docs/integrations/pytorch_lightning/) documentation.

# Preparation of simulation models

## Import and Install Libraries

In [242]:
!pip install pytorch-lightning clearml



In [243]:
#Pytorch modules
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, random_split, TensorDataset
from torchvision import datasets, transforms, models
#scipy
from scipy.stats import mode
#sklearn
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import accuracy_score
from sklearn.model_selection import LeaveOneGroupOut

#Numpy
import numpy as np
#Pandas
import pandas as pd
#Lightning & logging
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
#Data observation
import os
import sys
import pickle
import requests
from pathlib import Path
from collections import defaultdict
#Plotting
import matplotlib.pyplot as plt
import seaborn as sns
#Logging
from clearml import Task

## Set the Models

### Simulation Settings

Check the current directory

In [244]:
os.getcwd() #returns the current working directory

'/content'

In [245]:
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/")
print(f'CHECKPOINT_PATH: {CHECKPOINT_PATH}')

os.makedirs(CHECKPOINT_PATH, exist_ok=True)

CHECKPOINT_PATH: saved_models/


Set the reproducibility options

In [246]:
# Function for setting the seed to implement parallel tests
SEEDS =  [42] #[42, 0, 17, 9, 3, 16, 2]
SEED = 42 # random seed by default
pl.seed_everything(SEED)

# Determine the device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prioritizes speed but may reduce precision
torch.set_float32_matmul_precision('high')

INFO:lightning_fabric.utilities.seed:Seed set to 42


### Logging

To configure ClearML in your Colab environment, follow these steps:

---

*Step 1: Create a ClearML Account*
1. Go to the [ClearML website](https://clear.ml/).
2. Sign up for a free account if you don’t already have one.
3. Once registered, log in to your ClearML account.

---

*Step 2: Get Your ClearML Credentials*
1. After logging in, navigate to the **Settings** page (click on your profile icon in the top-right corner and select **Settings**).
2. Under the **Workspace** section, find your **+ Create new credentials**.
3. Copy these credentials for a Jupiter notebook into the code cell below.

---

*Step 3: Accessing the ClearML Dashboard*
1. Go to your ClearML dashboard (https://app.clear.ml).
2. Navigate to the **Projects** section to see your experiments.
3. Click on the experiment (e.g., `Lab_1`) to view detailed metrics, logs, and artifacts.

---

In [247]:
#Enter your code here to implement Step 2 of the logging instruction as it is shown below
%env CLEARML_WEB_HOST=https://app.clear.ml/
%env CLEARML_API_HOST=https://api.clear.ml
%env CLEARML_FILES_HOST=https://files.clear.ml
%env CLEARML_API_ACCESS_KEY=ZP02U03C6V5ER4K9VWRNZT7EWA5ZTV
%env CLEARML_API_SECRET_KEY=BtA5GXZufr6QGpaqhX1GSKPTvaCt56OLqaNqUGLNoxx2Ye8Ctwbui0Ln5OXVnzUgH4I

env: CLEARML_WEB_HOST=https://app.clear.ml/
env: CLEARML_API_HOST=https://api.clear.ml
env: CLEARML_FILES_HOST=https://files.clear.ml
env: CLEARML_API_ACCESS_KEY=ZP02U03C6V5ER4K9VWRNZT7EWA5ZTV
env: CLEARML_API_SECRET_KEY=BtA5GXZufr6QGpaqhX1GSKPTvaCt56OLqaNqUGLNoxx2Ye8Ctwbui0Ln5OXVnzUgH4I


### Dataset

Summary

In [248]:
DATASET = 'Sleep_Stages' # dataset with the real-world noise
#Clone the GitHub repository
repo_url = "https://github.com/avkornaev/Sleep_Stages"  # Replace with your repo URL
!git clone {repo_url}

#Navigate to the data folder
repo_name = repo_url.split("/")[-1].replace(".git", "")  # Extract repo name
data_dir = os.path.join(repo_name, "data")  # Replace "data" with your folder name
# os.chdir(data_dir)  # Change working directory to the data folder

# Verify the data directory
if os.path.exists(data_dir):
    print(f"Data directory found: {data_dir}")
else:
    print(f"Data directory not found: {data_dir}")

fatal: destination path 'Sleep_Stages' already exists and is not an empty directory.
Data directory found: Sleep_Stages/data


### Collect parameters

In [249]:
#Model parameters
LOSS_FUN = 'CE' # 'CE','CELoss'(custom), 'N', 'B', etc.
ARCHITECTURE = '1DCNN' #

#Collect the parameters (hyperparams and others)
hparams = {
    "seed": SEED,
    "lr": 0.001,
    'weight_decay': 0.0,
    "dropout": 0.0,
    "bs": 128,
    "num_workers": 2,
    "num_epochs": 2,
    "criterion": LOSS_FUN,
    "architecture": ARCHITECTURE,
    "window_size": 600,
    'label_smoothing': 0.0,
    }

#Visualization
vis_params = {
    'fig_size': 5,
    'num_samples': 5,
    'num_bins': 50,
}

## Functions

### Lightning

Data module

In [250]:
def load_and_preprocess_data(directory, allowed_labels=None):
    """
    Load and merge volunteer data with NaN handling.

    Args:
        directory (str): Path to the directory containing the data files.
        allowed_labels (list, optional): List of labels to include. If None, only labels
                                         present for all volunteers are used.

    Returns:
        dict: A dictionary where keys are volunteer IDs (e.g., "Vol_01") and values are
              tuples of (features, labels).
    """
    volunteer_data = defaultdict(list)

    # Load data for all volunteers
    for filename in os.listdir(directory):
        if filename.endswith(".csv.gz") and filename.startswith("Vol_"):
            volunteer_id = filename.split("_")[1].split(".")[0]
            file_path = os.path.join(directory, filename)

            try:
                df = pd.read_csv(file_path, compression='gzip')
                # Check for empty/nan-only data
                if df.drop(columns=['label']).isna().all().all():
                    print(f"Warning: {filename} contains only NaN values. Skipping.")
                    continue

                # Forward fill missing values
                df.ffill(inplace=True)
                # Drop remaining NaN rows
                df.dropna(inplace=True)

                volunteer_data[volunteer_id].append(df)

            except Exception as e:
                print(f"Error loading {filename}: {str(e)}")
                continue

    # Determine common labels if not provided
    if allowed_labels is None:
        all_labels = set()
        for volunteer_id, dfs in volunteer_data.items():
            if dfs:
                df = pd.concat(dfs, ignore_index=True)
                all_labels.update(df['label'].unique())

        # Find labels present for all volunteers
        common_labels = all_labels.copy()
        for volunteer_id, dfs in volunteer_data.items():
            if dfs:
                df = pd.concat(dfs, ignore_index=True)
                common_labels.intersection_update(df['label'].unique())

        if not common_labels:
            raise ValueError("No common labels found across all volunteers.")

        allowed_labels = list(common_labels)
        print(f"Using common labels: {allowed_labels}")

    # Process each volunteer's data
    merged_data = {}
    for volunteer_id, dfs in volunteer_data.items():
        if not dfs:  # Skip volunteers with all files invalid
            print(f"Warning: No valid data for Vol_{volunteer_id}")
            continue

        df = pd.concat(dfs, ignore_index=True)

        # Filter rows with allowed labels
        df = df[df['label'].isin(allowed_labels)]

        # Final NaN check
        if df.empty:
            print(f"Warning: No valid data remaining for Vol_{volunteer_id}")
            continue

        # Normalize features (excluding label column)
        features = df.iloc[:, :-1].values.astype(np.float32)
        features = (features - np.mean(features, axis=0)) / np.std(features, axis=0)

        # Convert labels to integers
        label_to_code = {label: idx for idx, label in enumerate(allowed_labels)}
        labels = df['label'].map(label_to_code).values.astype(np.int64)

        merged_data[f"Vol_{volunteer_id}"] = (features, labels)

    return merged_data

In [251]:
class SleepStageDataset(Dataset):
    def __init__(self, features, labels, window_size=600, stride=300):
        """
        Args:
            features (np.ndarray): Input features of shape (num_samples, num_features).
            labels (np.ndarray): Corresponding labels of shape (num_samples,).
            window_size (int): Size of the sliding window.
            stride (int): Stride for the sliding window.
        """
        self.features = features
        self.labels = labels
        self.window_size = window_size
        self.stride = stride
        self.windows, self.window_labels = self._create_windows()

    def _create_windows(self):
        """
        Creates sliding windows from the time-series data.
        Returns:
            windows (np.ndarray): Windows of shape (num_windows, num_features, window_size).
            window_labels (np.ndarray): Labels for each window of shape (num_windows,).
        """
        windows = []
        window_labels = []
        n_samples = self.features.shape[0]

        for start in range(0, n_samples - self.window_size + 1, self.stride):
            end = start + self.window_size
            window = self.features[start:end]  # Shape: (window_size, num_features)
            label = self.labels[end - 1]  # Use the label at the end of the window
            windows.append(window.T)  # Transpose to (num_features, window_size)
            window_labels.append(label)

        return np.array(windows), np.array(window_labels)

    def __len__(self):
        return len(self.windows)

    def __getitem__(self, idx):
        """
        Returns:
            x (torch.Tensor): Input tensor of shape (num_features, window_size).
            y (torch.Tensor): Label tensor (scalar).
        """
        x = torch.tensor(self.windows[idx], dtype=torch.float32)
        y = torch.tensor(self.window_labels[idx], dtype=torch.long)
        return x, y

class SleepDataModule(pl.LightningDataModule):
    def __init__(self, merged_data, batch_size=32, window_size=600):
        super().__init__()
        self.merged_data = merged_data
        self.batch_size = batch_size
        self.window_size = window_size
        self.subjects = list(merged_data.keys())

    def setup(self, stage=None):
        """Prepares all subject datasets."""
        self.datasets = {}
        for vol_id, (features, labels) in self.merged_data.items():
            self.datasets[vol_id] = SleepStageDataset(features, labels, self.window_size)

    def get_loso_splits(self, test_subject):
        """
        Generates LOSO splits for a given test subject.
        Returns:
            train_loader (DataLoader): DataLoader for training data.
            val_loader (DataLoader): DataLoader for validation data.
        """
        train_subs = [s for s in self.subjects if s != test_subject]

        # Combine training subjects
        train_data = torch.cat([torch.tensor(self.datasets[s].windows, dtype=torch.float32) for s in train_subs])
        train_labels = torch.cat([torch.tensor(self.datasets[s].window_labels, dtype=torch.long) for s in train_subs])

        # Test subject
        test_data = torch.tensor(self.datasets[test_subject].windows, dtype=torch.float32)
        test_labels = torch.tensor(self.datasets[test_subject].window_labels, dtype=torch.long)

        # Create DataLoaders
        train_loader = DataLoader(
            TensorDataset(train_data, train_labels),
            batch_size=self.batch_size,
            shuffle=True
        )
        val_loader = DataLoader(
            TensorDataset(test_data, test_labels),
            batch_size=self.batch_size,
            shuffle=False
        )

        return train_loader, val_loader

Training module

In [252]:
class SleepClassifier(pl.LightningModule):
    def __init__(self, input_channels=19, num_classes=4):
        super().__init__()
        self.save_hyperparameters()
        self.model = torch.nn.Sequential(
            torch.nn.Conv1d(input_channels, 64, kernel_size=3, padding=1),
            torch.nn.BatchNorm1d(64),
            torch.nn.ReLU(),
            torch.nn.MaxPool1d(2),

            torch.nn.Conv1d(64, 128, kernel_size=3, padding=1),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.MaxPool1d(2),

            torch.nn.Conv1d(128, 256, kernel_size=3, padding=1),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool1d(1),

            torch.nn.Flatten(),
            torch.nn.Linear(256, num_classes)
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("val_loss", loss, prog_bar=True)

        # Track predictions and labels
        preds = torch.argmax(y_hat, dim=1)
        return {"preds": preds, "labels": y}

    def validation_epoch_end(self, outputs):
        # Aggregate predictions and labels
        preds = torch.cat([x["preds"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])

        # Compute accuracy
        acc = (preds == labels).float().mean()
        self.log("val_acc", acc, prog_bar=True)
        return {"val_acc": acc}

### Models

In [253]:
# class SleepStageCNN(pl.LightningModule):
#     def __init__(self, input_channels=19, num_classes=4):  # Update input_channels
#         super().__init__()
#         self.model = torch.nn.Sequential(
#             torch.nn.Conv1d(input_channels, 64, kernel_size=3, padding=1),
#             torch.nn.BatchNorm1d(64),
#             torch.nn.ReLU(),
#             torch.nn.MaxPool1d(2),

#             torch.nn.Conv1d(64, 128, kernel_size=3, padding=1),
#             torch.nn.BatchNorm1d(128),
#             torch.nn.ReLU(),
#             torch.nn.MaxPool1d(2),

#             torch.nn.Conv1d(128, 256, kernel_size=3, padding=1),
#             torch.nn.BatchNorm1d(256),
#             torch.nn.ReLU(),
#             torch.nn.AdaptiveAvgPool1d(1),

#             torch.nn.Flatten(),
#             torch.nn.Linear(256, num_classes)
#         )
#         self.loss_fn = torch.nn.CrossEntropyLoss()

### Loss functions

Create a loss function class, or use a standart one.

In [254]:
# # Cross entropy loss maden from scratch (just in case)
# class CELoss(nn.Module):
#     def __init__(self, params=hparams):
#         super(CELoss, self).__init__()
#         self.smoothing = params.get('label_smoothing', 0.1)  # Default smoothing value
#         self.num_classes = params.get('n_classes', 10)
#         self.inv_smoothing = 1.0 - self.smoothing  # Probability for the correct class

#     def forward(self, x, y):
#         """
#         x: Model output (logits)
#             - Shape: (batch_size, num_classes)
#         y: Labels
#             - Shape: (batch_size,)
#         """
#         # Apply label smoothing to the one-hot encoded labels
#         with torch.no_grad():
#             yoh = torch.zeros_like(x)  # Create a one-hot encoded version of y
#             yoh.fill_(self.smoothing / (self.num_classes - 1))  # Fill with smoothed values
#             yoh.scatter_(1, y.unsqueeze(1), self.inv_smoothing)  # Set correct class to 1 - smoothing

#         # Compute the cross-entropy loss between logits and smoothed labels
#         log_probs = F.log_softmax(x, dim=1)  # Log probabilities
#         loss = -(yoh * log_probs).sum(dim=1).mean()  # Sum over classes and mean over batch

#         return loss

In [255]:
# class NLoss(nn.Module):
#     def __init__(self, params=hparams):
#         super(NLoss, self).__init__()
#         self.smoothing =   params.get('label_smoothing', 0.0)
#         self.num_classes = params.get('n_classes', 10)
#         self.inv_smoothing = 1.0 - self.smoothing  # Probability for the correct class

#     def forward(self, x, y):
#         """
#         x: Model output (logits + log variance)
#             - x[:, :self.num_classes]: Logits for class probabilities (h)
#             - x[:, self.num_classes:]: Logarithmic variance (s)
#         y: Labels
#         """
#         # Split the model output into predictions (h) and log variance (s)
#         logits = x[:, :self.num_classes]  # Predictions (h)
#         log_var = x[:, self.num_classes:]  # Logarithmic variance (s)

#         # Apply label smoothing to the one-hot encoded labels
#         with torch.no_grad():
#             yoh = torch.zeros_like(logits)
#             yoh.fill_(self.smoothing / (self.num_classes - 1))
#             yoh.scatter_(1, y.data.unsqueeze(1), self.inv_smoothing)

#         # Compute the squared differences between predictions and smoothed labels
#         squared_diff = torch.pow(yoh - logits, 2)  # (y_k - h_k)^2

#         # Compute the exponential of the negative log variance (e^{-s})
#         exp_neg_log_var = torch.exp(-log_var)

#         # Compute the first term of the loss: e^{-s} * sum((y_k - h_k)^2)
#         term1 = exp_neg_log_var * squared_diff.sum(dim=1)

#         # Compute the second term of the loss: N * s
#         term2 = self.num_classes * log_var

#         # Combine the terms and compute the mean over the batch
#         loss = (term1 + term2).mean()

#         return loss

In [256]:
class BLoss(nn.Module):
    def __init__(self, params=hparams):
        super(BLoss, self).__init__()
        self.smoothing =   params.get('label_smoothing', 0.0)
        self.num_classes = params.get('n_classes', 10)
        self.inv_smoothing = 1.0 - self.smoothing  # Probability for the correct class


    def forward(self, x, y):
        # Extract certainty and probabilities from the model output
        certainty = torch.sigmoid(x[:, self.num_classes:])  # Certainty values
        logits = x[:, :self.num_classes]  # Logits for class probabilities
        prob = F.softmax(logits, dim=1)  # Softmax probabilities

        # Compute cosine similarity between predictions and labels
        cos = nn.CosineSimilarity(dim=1)

        # Apply label smoothing to the one-hot encoded labels
        with torch.no_grad():
            yoh = torch.zeros_like(logits)
            yoh.fill_(self.smoothing / (self.num_classes - 1))
            yoh.scatter_(1, y.data.unsqueeze(1), self.inv_smoothing)


        # Compute the terms of the loss
        cosyh = cos(yoh, prob)
        delta = yoh * prob  # Element-wise product of one-hot labels and probabilities
        entropy_term = delta * torch.log(delta + 1e-10)  # Entropy term (avoid log(0))

        # Loss terms
        loss0 = -cosyh * torch.log(certainty / self.num_classes + 1e-10)  # First term
        loss1 = -(self.num_classes - 1) * (1 - cosyh) * torch.log((1 - certainty) / self.num_classes + 1e-10)  # Second term

        # Combine the terms and compute the mean over the batch
        loss = (entropy_term.sum(dim=1) + loss0 + loss1).mean()

        return loss

### Models zoo

Architectures and loss functions

In [257]:
# def get_arch_and_loss(hparams):
#     """
#     Returns the architecture and loss function based on the provided hparams.

#     Args:
#         hparams (dict): Hyperparameters dictionary, including 'ARCHITECTURE' and 'criterion'.

#     Returns:
#         arch: The model architecture.
#         loss: The loss function.
#     """
#     # Determine the number of outputs based on the loss function
#     if hparams['criterion'] in ['B', 'N']:
#         n_outputs = hparams['n_classes'] + 1  # Add 1 output neuron for BLoss or NLoss
#     else:
#         n_outputs = hparams['n_classes']  # Default number of outputs

#     # Define the architectures
#     architectures = {
#         'CNN': CNN(n_outputs=n_outputs),
#         'ResNet50': ResNet50(n_outputs=n_outputs, freeze=hparams.get('freeze', True)),
#         'ViT': ViT(n_outputs=n_outputs, freeze=hparams.get('freeze', True)),
#     }

#     # Define the loss functions
#     losses = {
#         'CE':CELoss(),
#         'B': BLoss(),
#         'N': NLoss(),
#     }

#     # Get the architecture and loss based on hparams
#     arch = architectures.get(hparams['architecture'])
#     loss = losses.get(hparams['criterion'])

#     if arch is None:
#         raise ValueError(f"Architecture '{hparams['ARCHITECTURE']}' is not supported.")
#     if loss is None:
#         raise ValueError(f"Loss function '{hparams['criterion']}' is not supported.")

#     return arch, loss


### Metrics

In [258]:
# def metrics(dataloader,model,hparams=hparams,loss_fn_red=None):
#     # Collect images, predictions, and losses
#     # images = []
#     preds  = []
#     labels = []
#     losses = []
#     correct= 0
#     total  = 0
#     for batch in dataloader:
#         x, y, _ = batch
#         with torch.no_grad():
#             logits = model(x)
#             # loss = loss_fn_red(h,y)
#             pred = torch.argmax(logits[:,:hparams['n_classes']], dim=1)
#         correct += (pred == y).sum().item()  # Number of correct predictions
#         total += y.size(0)  # Total number of samples

#         # images.extend(x.cpu())
#         preds.extend(pred.cpu().numpy())
#         labels.extend(y.cpu().numpy())
#         # losses.extend(loss.cpu().numpy())
#     acc = correct / total
#     return preds, labels, acc

# Train with LOSO Cross-Validation

## Create Dataset and Data Loaders

Initialization of the dataset, the dataloader, and the training module

In [259]:
# Load data with NaN handling
merged_data = load_and_preprocess_data(data_dir)

# Verify data existence
if not merged_data:
    raise ValueError("No valid data found in directory!")

# Check available volunteers
print(f"Available volunteers: {list(merged_data.keys())}")

# Proceed with model training...

Using common labels: ['N2', 'R']



invalid value encountered in divide


invalid value encountered in divide



Available volunteers: ['Vol_02', 'Vol_03', 'Vol_01']



invalid value encountered in divide



## Train

In [261]:
# Initialize ClearML task
task = Task.init(
    project_name="ICML-2025",
    task_name=f"arch_{ARCHITECTURE}_loss_{LOSS_FUN}"
)

# Log hyperparameters
task.connect({
    "architecture": ARCHITECTURE,
    "loss_function": LOSS_FUN,
})

ClearML Task: created new task id=d1eec303ee994a83a60f6fc16b959aa3
ClearML results page: https://app.clear.ml/projects/ccaa059e6de442b6abe578eab9e214c8/experiments/d1eec303ee994a83a60f6fc16b959aa3/output/log
ClearML Monitor: GPU monitoring failed getting GPU reading, switching off GPU monitoring


{'architecture': 'SleepStageCNN', 'loss_function': 'CE'}

In [262]:
# Initialize components
datamodule = SleepDataModule(merged_data)
model = SleepClassifier()

# Setup data
datamodule.setup()

# Track LOSO accuracy
loso_accuracies = []

# LOSO Cross-Validation
for test_subject in datamodule.subjects:
    print(f"\n=== Training on {len(datamodule.subjects)-1} subjects, validating on {test_subject} ===")

    # Get data loaders for this split
    train_loader, val_loader = datamodule.get_loso_splits(test_subject)

    # Configure trainer
    trainer = Trainer(
        max_epochs=50,
        callbacks=[EarlyStopping(monitor="val_loss", patience=5)],
        enable_checkpointing=False
    )

    # Train/validate
    trainer.fit(model, train_loader, val_loader)

    # Get validation accuracy for this subject
    val_acc = trainer.callback_metrics["val_acc"].item()
    loso_accuracies.append(val_acc)

    # Log validation performance
    task.get_logger().report_scalar(
        title="Validation Accuracy",
        series=test_subject,
        value=val_acc,
        iteration=trainer.current_epoch
    )

    # Update task title with validation info
    task.set_name(
        f"arch_{ARCHITECTURE}_loss_{LOSS_FUN}_seed_{seed}_noise_{NOISE_TYPE}_val_acc_{val_acc:.4f}"
    )

# Compute final LOSO accuracy
final_loso_accuracy = np.mean(loso_accuracies)
print(f"\n=== Final LOSO Accuracy: {final_loso_accuracy:.4f} ===")

# Log final LOSO accuracy
task.get_logger().report_scalar(
    title="LOSO Accuracy",
    series="Final",
    value=final_loso_accuracy,
    iteration=0
)



=== Training on 2 subjects, validating on Vol_02 ===


INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


MisconfigurationException: No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.

In [None]:
# Compute confusion matrix and classification report
all_preds = []
all_labels = []

for test_subject in datamodule.subjects:
    val_loader = datamodule.get_loso_splits(test_subject)[1]  # Get validation loader
    preds = []
    labels = []
    for batch in val_loader:
        x, y = batch
        with torch.no_grad():
            y_hat = model(x)
            preds.extend(torch.argmax(y_hat, dim=1).cpu().numpy())
            labels.extend(y.cpu().numpy())

    all_preds.extend(preds)
    all_labels.extend(labels)

# Log confusion matrix
cm = confusion_matrix(all_labels, all_preds)
task.get_logger().report_matrix(
    title="Confusion Matrix",
    series="Final",
    matrix=cm,
    iteration=0
)

# Log classification report
report = classification_report(all_labels, all_preds, target_names=["NREM1", "NREM2", "REM", "Wake"])
task.get_logger().report_text(
    title="Classification Report",
    series="Final",
    value=report
)

In [None]:
# List to store predictions from each model
# all_predictions = []

In [None]:
# for seed in SEEDS:
#     # Set seed for reproducibility at the VERY BEGINNING
#     pl.seed_everything(seed)

#     # Reinitialize the model architecture for each seed
#     arch, loss_fn = get_arch_and_loss(hparams)
#     # archs_and_losses = get_arch_and_loss(hparams)

#     # arch, loss_fn = archs_and_losses[hparams['criterion']]['arch']
#     # loss_fn = archs_and_losses[hparams['criterion']]['loss']


#     checkpoint_callback_img = ModelCheckpoint(
#         monitor='val_loss',       # Monitor validation loss
#         dirpath=CHECKPOINT_PATH,  # Directory to save checkpoints
#         filename=f'best_model_{ARCHITECTURE}_{LOSS_FUN}_{seed}_{NOISE_TYPE}',  # Checkpoint filename
#         save_top_k=1,             # Save only the best model
#         mode='min',               # Minimize validation loss
#     )

#     task = Task.init(project_name="ICML-2025",
#                      task_name=f'arch_{ARCHITECTURE}_loss_{LOSS_FUN}_seed_{seed}_noise_{NOISE_TYPE}')

#     # Initialize the model with the reinitialized architecture
#     model = train_model(model=arch, loss=loss_fn)

#     # Log hyperparameters to ClearML
#     task.connect(model.hparams)

#     trainer = Trainer(max_epochs=hparams['num_epochs'],
#                       callbacks=[checkpoint_callback_img],
#                       accelerator="auto", devices="auto")
#     trainer.fit(model, data_module)

#     best_model_path = checkpoint_callback_img.best_model_path
#     task.update_output_model(model_path=best_model_path, auto_delete_file=False)
#     best_model = train_model.load_from_checkpoint(best_model_path,
#                                                   model=arch,
#                                                   loss=loss_fn)

#     # Test set
#     test_dataloader = data_module.test_dataloader()
#     # Move the model to the correct device
#     best_model = best_model.to(device)
#     predictions = []
#     with torch.no_grad():
#         for batch in test_dataloader:
#             x, _, _, = batch  # We only need the input data, not the labels
#             logits = best_model(x.to(device))
#             preds = torch.argmax(logits[:, :NUM_CLASSES], dim=1)
#             predictions.append(preds.cpu().numpy())
#     predictions = np.concatenate(predictions)  # Combine all batch predictions
#     all_predictions.append(predictions)

#     if seed != SEEDS[-1]:
#         task.close()
#         del[model, best_model, task, arch, loss_fn]

## Test the models and the ensemble of the models

In [None]:
# all_predictions

Individual models

In [None]:
# # List to store individual model accuracies
# individual_accuracies = []

# # Compute accuracy for each model
# for i, predictions in enumerate(all_predictions):
#     # Get predictions for the current model
#     model_predictions = predictions  # Shape: (num_samples,)

#     # Get true labels (already collected earlier)
#     true_labels = np.array(data_module.cifar10_test.targets)

#     # Calculate accuracy for the current model
#     accuracy = accuracy_score(true_labels, model_predictions)
#     individual_accuracies.append(accuracy)
#     print(f'Model {i+1} Accuracy: {accuracy:.4f}')

# # Convert to numpy array for easier calculations
# individual_accuracies = np.array(individual_accuracies)

# # Compute mean accuracy
# mean_accuracy = np.mean(individual_accuracies)

# # Compute standard deviation of accuracy
# std_accuracy = np.std(individual_accuracies)

# print(f'Mean Accuracy: {mean_accuracy:.4f}')
# print(f'Standard Deviation of Accuracy: {std_accuracy:.4f}')

Ensemble

In [None]:
# # Stack predictions from all models
# all_predictions = np.stack(all_predictions)  # Shape: (num_models, num_samples, num_classes)

# # Ensemble predictions (e.g., by averaging)
# ensemble_predictions = np.mean(all_predictions, axis=0)  # Shape: (num_samples, num_classes)
# final_predictions, _ = mode(all_predictions, axis=0)  # Majority voting
# final_predictions = final_predictions.flatten()  # Flatten to 1D array

# # Get true labels from the CIFAR-10 data set
# test_labels = np.array(data_module.cifar10_test.targets)
# # test_labels = data_module.test_dataset.labels  # Adjust this based on your dataset

# # Calculate accuracy
# accuracy = accuracy_score(test_labels, final_predictions)
# print(f'Ensemble Accuracy: {accuracy:.4f}')

# # Compute confusion matrix
# cm = confusion_matrix(test_labels, final_predictions)

In [None]:
# # Simulated test metrics
# test_metrics = {
#     "Mean Accuracy (individual)": mean_accuracy,
#     "Standard Deviation of Accuracy (individual)": std_accuracy,
#     "Ensemble Accuracy": accuracy,
# }

# task.connect(test_metrics)

In [None]:
task.close()