In [1]:
import torch
import torch.optim as optim
from tqdm import tqdm
import os
import sys
sys.path.append(os.getcwd().split('/pretrain_comparison/fine_tune/ahi_diagnosis')[0] + '/pretrain_comparison')
from fine_tune.models.model import SleepEventLSTMClassifier
from fine_tune.models.dataset import SleepEventClassificationDataset, finetune_collate_fn
from fine_tune.utils import *
from comparison.utils import *
import json
import math
import pandas as pd
import torch.nn.functional as F
import glob
import h5py
from torch.utils.data import Dataset, DataLoader

In [3]:
config_path = os.getcwd().split('/pretrain_comparison/fine_tune/ahi_diagnosis')[0] + '/pretrain_comparison/fine_tune/config_fine_tune.yaml'
config = load_data(config_path)
#config["batch_size"] = config["batch_size"] // 2

In [4]:




class SleepEventClassificationDataset(Dataset):
    def __init__(self, 
                 config,
                 channel_groups=None,
                 hdf5_paths=[],
                 split="train",
                 pretrain_type = "MAE",
                 specific_files = None):

        self.config = config
        #self.max_channels = self.config["max_channels"]
        self.context = int(self.config["context"])
        self.channel_like = self.config["channel_like"]

        #diagnosis, death, and demographics
        self.df_demographics = pd.read_csv(config['demographics_labels_path'])
        self.df_diagnosis_presence = pd.read_csv(os.path.join(config['diagnosis_labels_path'], 'is_event.csv'))
        self.df_diagnosis_time = pd.read_csv(os.path.join(config['diagnosis_labels_path'], 'time_to_event.csv'))
        self.df_death_presence = pd.read_csv(os.path.join(config['death_labels_path'], 'is_event.csv'), usecols=['Study ID','death'])
        self.df_death_time = pd.read_csv(os.path.join(config['death_labels_path'], 'time_to_event.csv'), usecols=['Study ID','death'])
        self.df_ahi = pd.read_csv(config['ahi_labels_path'])
        self.df_ahi['diagnosis'] = self.df_ahi.ahi.apply(lambda x: 1 if x >= 15 else 0)

        unique_study_ids_in_demo_diag_death = set(self.df_demographics['Study ID'].values).intersection(set(self.df_diagnosis_presence['Study ID'].values)).intersection(set(self.df_diagnosis_time['Study ID'].values)).intersection(set(self.df_death_presence['Study ID'].values)).intersection(set(self.df_death_time['Study ID'].values))
        unique_study_ids_in_demo_diag_death = set(self.df_ahi['Study ID'].values).intersection(unique_study_ids_in_demo_diag_death)

        labels_path = self.config["labels_path"]
        dataset = self.config["dataset"]
        dataset = dataset.split(",")

        label_files = []

        for dataset_name in dataset:
            label_files += glob.glob(os.path.join(labels_path, dataset_name, "**", "*.csv"), recursive=True)

        # label_files = [label_file for label_file in os.listdir(labels_path) if label_file.endswith(".csv")]

        hdf5_paths = load_data(config["split_path"])[split]
        #print(f'first hdf5_paths: {hdf5_paths[0]}')
        #print(f'len hdf5_paths: {len(hdf5_paths)}')
        #print(f'first label_files: {label_files[0]}')
        #print(f'len label_files: {len(label_files)}')
        study_ids = set([os.path.basename(label_file).split(".")[0] for label_file in label_files])
        #print(f'first study_ids: {list(study_ids)[0]}')
        #print(f'len study_ids: {len(study_ids)}')

        hdf5_paths = [f for f in hdf5_paths if os.path.exists(f)]
        #print(f'len hdf5_paths: {len(hdf5_paths)}')
        hdf5_paths = [f for f in hdf5_paths if f.split("/")[-1].split(".")[0] in study_ids]
        hdf5_paths = [f for f in hdf5_paths if f.split("/")[-1].split(".")[0] in unique_study_ids_in_demo_diag_death]
        #print(f'len hdf5_paths: {len(hdf5_paths)}')

        hdf5_paths_ids = set([os.path.basename(hdf5_path).split(".")[0] for hdf5_path in hdf5_paths])
        #print(f'first hdf5_paths_ids: {list(hdf5_paths_ids)[0]}')
        #print(f'len hdf5_paths_ids: {len(hdf5_paths_ids)}')

        hdf5_paths_new = []
        #print(f'dataset: {dataset}')
        #for dataset_name in dataset:
            #hdf5_paths_new += glob.glob(os.path.join(config["embedding_path"], dataset_name, "**", "*.hdf5"), recursive=True)
        hdf5_paths_new += glob.glob(os.path.join(config["embedding_path"], pretrain_type, "**", "*.hdf5"), recursive=True)
        #print(f'first hdf5_paths_new: {hdf5_paths_new[0]}')
        
        #print(f'len hdf5_paths_new: {len(hdf5_paths_new)}')
        
        hdf5_paths_new = [item for item in hdf5_paths_new if os.path.basename(item).split(".")[0] in hdf5_paths_ids]
        #print(f'len hdf5_paths_new: {len(hdf5_paths_new)}')
        hdf5_paths = hdf5_paths_new
        hdf5_paths = [f for f in hdf5_paths if os.path.exists(f)]
        #print(f'len hdf5_paths: {len(hdf5_paths)}')

        if config["max_files"]:
            hdf5_paths = hdf5_paths[:config["max_files"]]
        else:
            hdf5_paths = hdf5_paths

        labels_dict = {
            os.path.basename(item).split(".")[0]: item for item in label_files
        }
        if specific_files:
            #print(f'hdf5_paths[0] {hdf5_paths[0]}')
            #print(f'specific_files[0] {specific_files[0]}')
            
            # Extract base names from specific_files (without extension) for proper comparison
            specific_files_base = [os.path.splitext(f)[0] for f in specific_files]
            
            # Filter hdf5_paths to only include files whose base names are in specific_files
            hdf5_paths = [f for f in hdf5_paths if os.path.splitext(os.path.basename(f))[0] in specific_files_base]
            
            #print(f'number of specific_files: {len(hdf5_paths)}')
            repeats = max(1024 // len(specific_files), 1)
            
            # Repeat the hdf5 files
            hdf5_paths = [f for f in hdf5_paths for _ in range(repeats)]
            #print(f'number of training items per epoch: {len(hdf5_paths)}')
        if self.context == -1:
            self.index_map = [(path, labels_dict[path.split("/")[-1].split(".")[0]], -1) for path in hdf5_paths]
        else:
            self.index_map = []
            loop = tqdm(hdf5_paths[:], total=len(hdf5_paths), desc=f"Indexing {split} data")
            for hdf5_file_path in loop:
                file_prefix = os.path.basename(hdf5_file_path).split(".")[0]
                with h5py.File(hdf5_file_path, "r") as file:
                    dataset_names = list(file.keys())[:]
                    dataset_name = dataset_names[0]
                    dataset_length = file[dataset_name].shape[0]
                    for i in range(0, dataset_length, self.context):
                        self.index_map.append((hdf5_file_path, labels_dict[file_prefix], i))           
            
        #logger.info(f"Number of files in {split} set: {len(hdf5_paths)}")
        #logger.info(f"Number of files to be processed in {split} set: {len(self.index_map)}")
        self.total_len = len(self.index_map)
        self.max_seq_len = config["model_params"]["max_seq_length"]

    def __len__(self):
        return self.total_len

    def get_index_map(self):
        return self.index_map

    def __getitem__(self, idx):
        hdf5_path, label_path, start_index = self.index_map[idx]
        labels_df = pd.read_csv(label_path)
        y_data = labels_df["StageNumber"].to_numpy()
        if self.context != -1:
            y_data = y_data[start_index:start_index+self.context]
        x_data = []
        with h5py.File(hdf5_path, 'r') as hf:
            dset_names = list(hf.keys())[:]
            for dataset_name in dset_names:
                x_data.append(hf[dataset_name][:])
        x_data = np.array(x_data)
        # Convert x_data to tensor
        x_data = torch.tensor(x_data, dtype=torch.float32)
        y_data = torch.tensor(y_data, dtype=torch.float32)
        min_length = min(x_data.shape[1], len(y_data))
        x_data = x_data[:, :min_length, :].squeeze()
        y_data = y_data[:min_length]
        
        #diagnosis, death, and demographics
        study_id = os.path.basename(hdf5_path).split(".")[0]
        try:
            diagnosis_presence = torch.tensor(self.df_diagnosis_presence[self.df_diagnosis_presence['Study ID'] == study_id].values[0][1:].astype(np.float32))
            diagnosis_time = torch.tensor(self.df_diagnosis_time[self.df_diagnosis_time['Study ID'] == study_id].values[0][1:].astype(np.float32))
            death_presence = torch.tensor(self.df_death_presence[self.df_death_presence['Study ID'] == study_id].values[0][1:].astype(np.float32))
            death_time = torch.tensor(self.df_death_time[self.df_death_time['Study ID'] == study_id].values[0][1:].astype(np.float32))
            age = torch.tensor(self.df_demographics[self.df_demographics['Study ID'] == study_id]['Age at Study Date'].values) / 100
            ahi_diagnosis = torch.tensor(self.df_ahi[self.df_ahi['Study ID'] == study_id]['diagnosis'].values)
        except:
            print(f'Study ID {study_id} not found in demographics, diagnosis, or death data')

        
        return x_data, y_data, self.max_seq_len, hdf5_path, diagnosis_presence, diagnosis_time, death_presence, death_time, age, ahi_diagnosis

def finetune_collate_fn(batch):

    x_data, y_data, max_seq_len_list, hdf5_path_list, diagnosis_presence, diagnosis_time, death_presence, death_time, age, ahi_diagnosis = zip(*batch)

    # padding the temporal as in sleep_event_finetune_full_collate_fn
    max_seq_len_temp = max([item.size(0) for item in x_data])
    # Determine the max sequence length for padding
    if max_seq_len_list[0] is None:
        max_seq_len = max_seq_len_temp
    else:
        max_seq_len = min(max_seq_len_temp, max_seq_len_list[0])
    
    padded_x_data = []
    padded_y_data = []
    padded_mask = []
    diagnosis_presence_list = []
    diagnosis_time_list = []
    death_presence_list = []
    death_time_list = []
    age_list = []
    ahi_diagnosis_list = []

    for x_item, y_item, diagnosis_presence_item, diagnosis_time_item, death_presence_item, death_time_item, age_item, ahi_diagnosis_item  in zip(x_data, y_data, diagnosis_presence, diagnosis_time, death_presence, death_time, age, ahi_diagnosis):
        # Get the shape of x_item
        s, e = x_item.size()

        s = min(s, max_seq_len)

        # Create a padded tensor and a mask tensor for x_data
        padded_x_item = torch.zeros((max_seq_len, e))
        mask = torch.ones((max_seq_len))

        # Copy the actual data to the padded tensor and set the mask for real data
        padded_x_item[:s, :e] = x_item[:s, :e]
        mask[:s] = 0  # 0 for real data, 1 for padding

        # Pad y_data with zeros to match max_seq_len
        padded_y_item = torch.zeros(max_seq_len)
        padded_y_item[:s] = y_item[:s]

        # Append padded items to lists
        padded_x_data.append(padded_x_item)
        padded_y_data.append(padded_y_item)
        padded_mask.append(mask)
        diagnosis_presence_list.append(diagnosis_presence_item)
        diagnosis_time_list.append(diagnosis_time_item)
        death_presence_list.append(death_presence_item)
        death_time_list.append(death_time_item)
        age_list.append(age_item)
        ahi_diagnosis_list.append(ahi_diagnosis_item)



    # Stack all tensors into a batch
    x_data = torch.stack(padded_x_data)
    y_data = torch.stack(padded_y_data)
    padded_mask = torch.stack(padded_mask)

    diagnosis_presence = torch.stack(diagnosis_presence_list)
    diagnosis_time = torch.stack(diagnosis_time_list)
    death_presence = torch.tensor(death_presence_list).unsqueeze(1)
    death_time = torch.tensor(death_time_list).unsqueeze(1)
    age = torch.tensor(age_list).unsqueeze(1)
    ahi_diagnosis = torch.tensor(ahi_diagnosis_list).unsqueeze(1)
    
    return x_data, y_data, padded_mask, hdf5_path_list, diagnosis_presence, diagnosis_time, death_presence, death_time, age, ahi_diagnosis


In [5]:

import os
import sys
sys.path.append('/oak/stanford/groups/jamesz/magnusrk/pretraining_comparison')
from comparison.utils import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


#model classes
class PositionalEncoding(nn.Module):
    def __init__(self, max_seq_len, d_model):
        super().__init__()
        position = torch.arange(max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # Shape: (1, max_seq_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return x


class AttentionPooling(nn.Module):
    def __init__(self, input_dim, num_heads=1, dropout=0.1):
        super(AttentionPooling, self).__init__()
        self.transformer_layer = nn.TransformerEncoderLayer(
            d_model=input_dim, 
            nhead=num_heads, 
            dropout=dropout, 
            batch_first=True
        )

    def forward(self, x, key_padding_mask=None):
        batch_size, seq_len, input_dim = x.size()
        
        if key_padding_mask is not None:
            if key_padding_mask.size(1) == 1:
                return x.mean(dim=1)
            if key_padding_mask.dtype != torch.bool:
                key_padding_mask = key_padding_mask.to(dtype=torch.bool)
                
        transformer_output = self.transformer_layer(x, src_key_padding_mask=key_padding_mask)
        pooled_output = transformer_output.mean(dim=1)  # Average pooling over the sequence length
        
        return pooled_output

class SleepEventLSTMClassifier(nn.Module):
    def __init__(self, embed_dim, num_heads, num_layers, num_classes, pooling_head=4, dropout=0.1, max_seq_length=128):
        super(SleepEventLSTMClassifier, self).__init__()
        
        # Define spatial pooling
        #self.spatial_pooling = AttentionPooling(embed_dim, num_heads=pooling_head, dropout=dropout)

        # Set max sequence length
        if max_seq_length is None:
            max_seq_length = 20000
            
        self.positional_encoding = PositionalEncoding(max_seq_length, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)

        # Transformer encoder for spatial modeling
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, batch_first=True, norm_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # LSTM for temporal modeling
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(input_size=embed_dim, hidden_size=embed_dim//2, num_layers=num_layers, batch_first=True, dropout=lstm_dropout, bidirectional=True)
        
        # Fully connected layer for sleep stage classification
        self.fc_sleep_stage = nn.Linear(embed_dim, num_classes)

        self.temporal_pooling = AttentionPooling(embed_dim, num_heads=pooling_head, dropout=dropout)

        self.fc_age = nn.Sequential(
            nn.Linear(embed_dim, 1),
            nn.Softplus()  # Ensures smooth, non-negative outputs
        )

        self.fc_ahi_diagnosis = nn.Linear(embed_dim, 1)

        self.fc_death = nn.Linear(embed_dim, 1)

        self.fc_diagnosis = nn.Linear(embed_dim, 12)

    def forward(self, x, mask):
        B, S, E = x.shape       
        device = x.device 

        # Apply positional encoding and layer normalization
        x = self.positional_encoding(x)
        x = self.layer_norm(x)

        # Apply transformer encoder for spatial modeling
        mask_temporal = mask[:, :]
        x = self.transformer_encoder(x, src_key_padding_mask=mask_temporal)

        # Apply LSTM for temporal modeling
        x, _ = self.lstm(x)  # Shape: (B, S, E)

        # Apply the final fully connected layer for classification
        sleep_stage = self.fc_sleep_stage(x)  # Shape: (B, S, num_classes)

        
        #x_diagnosis = self.temporal_pooling_diagnosis(x, mask_temporal)
        #x_death = self.temporal_pooling_death(x, mask_temporal)
        #x_age = self.temporal_pooling_age(x, mask_temporal)
        x = self.temporal_pooling(x, mask_temporal)
        hazards_death = self.fc_death(x)
        hazards_diagnosis = self.fc_diagnosis(x)
        age = self.fc_age(x)
        ahi_diagnosis = self.fc_ahi_diagnosis(x)

        return sleep_stage.to(device), mask[:, :].to(device), age.to(device), hazards_diagnosis.to(device), hazards_death.to(device), ahi_diagnosis.to(device)  # Return mask along temporal dimension

In [6]:
def run_iteration(model, data, optimizer=None, scaler=None, config=None, device=None, mode='train'):
    """
    Run one iteration (batch) of training or validation.
    
    Args:
        model: The PyTorch model
        data: Tuple of batch data
        optimizer: PyTorch optimizer (only needed for training)
        scaler: Gradient scaler for mixed precision training
        config: Configuration dictionary
        device: PyTorch device
        mode: Either 'train' or 'val'
    """
    is_training = mode == 'train'
    
    # Unpack the batch data
    x_data, y_data, mask, _, diagnosis_presence, diagnosis_time, death_presence, death_time, age_target, ahi_diagnosis_target = data
    
    # Move data to device
    x_data = x_data.to(device)
    y_data = y_data.to(device)
    mask = mask.bool().to(device)
    diagnosis_presence = diagnosis_presence.to(device)
    diagnosis_time = diagnosis_time.to(device)
    death_presence = death_presence.to(device)
    death_time = death_time.to(device)
    age_target = age_target.to(device)
    ahi_diagnosis_target = ahi_diagnosis_target.to(device)

    if is_training:
        optimizer.zero_grad()
        
    # Context manager for mixed precision training
    with torch.cuda.amp.autocast() if is_training else torch.no_grad():
        output, mask, age_out, hazards_diagnosis, hazards_death, ahi_diagnosis = model(x_data, mask)
        
        # Reshape outputs and targets
        output_reshaped = output.reshape(-1, config['model_params']['num_classes'])
        targets_reshaped = y_data.reshape(-1).long()
        
        # Handle masking
        if mask is not None:
            mask_reshaped = mask.reshape(-1)
            valid_targets = targets_reshaped != -1
            valid_mask = ~mask_reshaped & valid_targets
            # Force contiguous memory layout before indexing

            # If using DataParallel, ensure tensors are on the same device
            if isinstance(model, torch.nn.DataParallel):
                device = torch.device(f'cuda:{model.device_ids[0]}')
                output_reshaped = output_reshaped.to(device)
                valid_mask = valid_mask.to(device)

            valid_mask = valid_mask.contiguous()
            output_reshaped = output_reshaped.contiguous()
            targets_reshaped = targets_reshaped.contiguous()
            
            
            
            # if no valid targets set losses to 0 and return
            if targets_reshaped.size(0) == 0:
                loss = torch.tensor(0.0).to(device)
                metrics = {
                    'loss': loss.item(),
                    'loss_sleep_staging': loss.item(),
                    'loss_diagnosis': loss.item(),
                    'loss_death': loss.item(),
                    'loss_age': loss.item(),
                    'loss_ahi_diagnosis': loss.item(),
                    'correct': 0,
                    'total': 0,
                    'tp': torch.zeros(config['model_params']['num_classes']).to(device),
                    'fp': torch.zeros(config['model_params']['num_classes']).to(device),
                    'fn': torch.zeros(config['model_params']['num_classes']).to(device)
                }
                return metrics
        
        # Calculate losses
        if mode == 'train':
            loss_sleep_staging = torch.tensor(0.0, device=device, requires_grad=True)#masked_cross_entropy_loss(output, y_data, None)
            loss_diagnosis = cox_ph_loss(hazards_diagnosis, diagnosis_time, diagnosis_presence)
            loss_death = cox_ph_loss(hazards_death, death_time, death_presence)
            loss_age = F.mse_loss(age_target.float(), age_out.float())
            loss_ahi_diagnosis = F.binary_cross_entropy_with_logits(ahi_diagnosis, ahi_diagnosis_target.float())
            loss = loss_ahi_diagnosis
        else:
            with torch.no_grad():
                loss_sleep_staging = torch.tensor(0.0, device=device)#masked_cross_entropy_loss(output, y_data, mask)
                loss_diagnosis = cox_ph_loss(hazards_diagnosis, diagnosis_time, diagnosis_presence)
                loss_death = cox_ph_loss(hazards_death, death_time, death_presence)
                loss_age = F.mse_loss(age_target.float(), age_out.float())
                loss_ahi_diagnosis = F.binary_cross_entropy_with_logits(ahi_diagnosis, ahi_diagnosis_target.float())
                loss = loss_ahi_diagnosis

    # Handle backpropagation for training
    if is_training:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

    # Calculate metrics
    with torch.no_grad():
        _, predicted = torch.max(output_reshaped, 1)
        total = targets_reshaped.size(0)
        correct = (predicted == targets_reshaped).sum().item()
        
        # Calculate F1 components
        tp = torch.zeros(config['model_params']['num_classes']).to(device)
        fp = torch.zeros(config['model_params']['num_classes']).to(device)
        fn = torch.zeros(config['model_params']['num_classes']).to(device)
        
        for class_idx in range(config['model_params']['num_classes']):
            pred_mask = predicted == class_idx
            target_mask = targets_reshaped == class_idx
            
            tp[class_idx] += (pred_mask & target_mask).sum()
            fp[class_idx] += (pred_mask & ~target_mask).sum()
            fn[class_idx] += (~pred_mask & target_mask).sum()
    # Before returning, check for NaN values
    metrics = {
        'loss': loss.item(),
        'loss_sleep_staging': loss_sleep_staging.item(),
        'loss_diagnosis': loss_diagnosis.item(),
        'loss_death': loss_death.item(),
        'loss_age': loss_age.item(),
        'loss_ahi_diagnosis': loss_ahi_diagnosis.item(),
        'correct': correct,
        'total': total,
        'tp': tp,
        'fp': fp,
        'fn': fn
    }

    # Check for NaN values
    
    for key, value in metrics.items():
        if isinstance(value, (float, int)):
            if math.isnan(value):
                print(f"NaN detected in {key}")
                print(f"Debug info:")
                print(f"loss: {loss}")
                print(f"loss_sleep_staging: {loss_sleep_staging}")
                print(f"loss_diagnosis: {loss_diagnosis}")
                print(f"loss_death: {loss_death}")
                print(f"loss_age: {loss_age}")
                print(f"y_data: {y_data}")
                print(f"output: {output}")
                print(f"valid_mask: {valid_mask}")
                print(f"nan in y data: {torch.isnan(y_data).any()}")
                print(f"nan in output: {torch.isnan(output).any()}")
                print(f"nan in valid_mask: {torch.isnan(valid_mask).any()}")
                unique_targets_reshaped = torch.unique(targets_reshaped)
                print(f"unique targets: {unique_targets_reshaped}")
                unique_valid_mask = torch.unique(valid_mask)
                print(f"unique valid mask: {unique_valid_mask}")
                raise ValueError(f"NaN detected in {key}")
                

    return metrics


In [7]:
def train(model, train_loader, validation_loader, optimizer, scaler, config, device, patience=10):
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    for epoch in range(config['epochs']):
        # Training metrics
        train_metrics = {
            'running_loss': 0.0,
            'running_sleep_staging_loss': 0.0,
            'running_diagnosis_loss': 0.0,
            'running_death_loss': 0.0,
            'running_age_loss': 0.0,
            'running_ahi_diagnosis_loss': 0.0,
            'correct': 0,
            'total': 0,
            'tp': torch.zeros(config['model_params']['num_classes']).to(device),
            'fp': torch.zeros(config['model_params']['num_classes']).to(device),
            'fn': torch.zeros(config['model_params']['num_classes']).to(device)
        }

        # Training loop
        model.train()
        train_loop = tqdm(enumerate(train_loader), 
                            total=len(train_loader), 
                            desc=f'Epoch {epoch}/{config["epochs"]-1}',
                            leave=True,
                            ncols=250)
        
        for i, batch_data in train_loop:
            batch_metrics = run_iteration(model, batch_data, optimizer, scaler, config, device, mode='train')
            
            # Update running metrics
            train_metrics['running_loss'] += batch_metrics['loss']
            train_metrics['running_sleep_staging_loss'] += batch_metrics['loss_sleep_staging']
            train_metrics['running_diagnosis_loss'] += batch_metrics['loss_diagnosis']
            train_metrics['running_death_loss'] += batch_metrics['loss_death']
            train_metrics['running_age_loss'] += batch_metrics['loss_age']
            train_metrics['running_ahi_diagnosis_loss'] += batch_metrics['loss_ahi_diagnosis']
            train_metrics['correct'] += batch_metrics['correct']
            train_metrics['total'] += batch_metrics['total']
            train_metrics['tp'] += batch_metrics['tp']
            train_metrics['fp'] += batch_metrics['fp']
            train_metrics['fn'] += batch_metrics['fn']

            # Calculate current metrics for progress bar
            batch_count = i + 1
            avg_loss = train_metrics['running_loss'] / batch_count
            accuracy = train_metrics['correct'] / train_metrics['total'] if train_metrics['total'] > 0 else 0
            
            # Calculate F1 score
            precision = train_metrics['tp'] / (train_metrics['tp'] + train_metrics['fp'] + 1e-7)
            recall = train_metrics['tp'] / (train_metrics['tp'] + train_metrics['fn'] + 1e-7)
            f1 = 2 * (precision * recall) / (precision + recall + 1e-7)
            macro_f1 = f1.mean().item()

            train_loop.set_postfix({
                'loss': f'cur:{batch_metrics["loss"]:.3f}/avg:{avg_loss:.3f}',
                'sleep': f'cur:{batch_metrics["loss_sleep_staging"]:.3f}/acc:{accuracy:.3f}/f1:{macro_f1:.3f}',
                'diag': f'cur:{batch_metrics["loss_diagnosis"]:.3f}',
                'death': f'cur:{batch_metrics["loss_death"]:.3f}',
                'age': f'cur:{batch_metrics["loss_age"]:.3f}',
                'ahi': f'cur:{batch_metrics["loss_ahi_diagnosis"]:.3f}/avg:{train_metrics["running_ahi_diagnosis_loss"] / batch_count:.3f}'
            })

        # Validation loop
        val_metrics = {
            'running_loss': 0.0,
            'running_sleep_staging_loss': 0.0,
            'running_diagnosis_loss': 0.0,
            'running_death_loss': 0.0,
            'running_age_loss': 0.0,
            'running_ahi_diagnosis_loss': 0.0,
            'correct': 0,
            'total': 0,
            'tp': torch.zeros(config['model_params']['num_classes']).to(device),
            'fp': torch.zeros(config['model_params']['num_classes']).to(device),
            'fn': torch.zeros(config['model_params']['num_classes']).to(device)
        }

        model.eval()
        val_loop = tqdm(enumerate(validation_loader), 
                        total=len(validation_loader), 
                        desc=f'Validation Epoch {epoch}/{config["epochs"]-1}',
                        leave=True,
                        ncols=250)
        with torch.no_grad():
            for i, batch_data in val_loop:
                batch_metrics = run_iteration(model, batch_data, None, None, config, device, mode='val')
                    
                
                # Update validation metrics
                val_metrics['running_loss'] += batch_metrics['loss']
                val_metrics['running_sleep_staging_loss'] += batch_metrics['loss_sleep_staging']
                val_metrics['running_diagnosis_loss'] += batch_metrics['loss_diagnosis']
                val_metrics['running_death_loss'] += batch_metrics['loss_death']
                val_metrics['running_age_loss'] += batch_metrics['loss_age']
                val_metrics['running_ahi_diagnosis_loss'] += batch_metrics['loss_ahi_diagnosis']
                val_metrics['correct'] += batch_metrics['correct']
                val_metrics['total'] += batch_metrics['total']
                val_metrics['tp'] += batch_metrics['tp']
                val_metrics['fp'] += batch_metrics['fp']
                val_metrics['fn'] += batch_metrics['fn']

                # Calculate current metrics
                batch_count = i + 1
                avg_val_loss = val_metrics['running_loss'] / batch_count
                val_accuracy = val_metrics['correct'] / val_metrics['total'] if val_metrics['total'] > 0 else 0
                
                # Calculate F1 score
                precision = val_metrics['tp'] / (val_metrics['tp'] + val_metrics['fp'] + 1e-7)
                recall = val_metrics['tp'] / (val_metrics['tp'] + val_metrics['fn'] + 1e-7)
                f1 = 2 * (precision * recall) / (precision + recall + 1e-7)
                val_macro_f1 = f1.mean().item()

                val_loop.set_postfix({
                    'val_loss': f'cur:{batch_metrics["loss"]:.3f}/avg:{avg_val_loss:.3f}',
                    'sleep': f'cur:{batch_metrics["loss_sleep_staging"]:.3f}/acc:{val_accuracy:.3f}/f1:{val_macro_f1:.3f}',
                    'diag': f'cur:{batch_metrics["loss_diagnosis"]:.3f}',
                    'death': f'cur:{batch_metrics["loss_death"]:.3f}',
                    'age': f'cur:{batch_metrics["loss_age"]:.3f}',
                    'ahi': f'cur:{batch_metrics["loss_ahi_diagnosis"]:.3f}/avg:{val_metrics["running_ahi_diagnosis_loss"] / batch_count:.3f}'
                })

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            best_model_state = model.state_dict()
        else:
            patience_counter += 1

        # Early stopping trigger
        if patience_counter >= patience:
            print(f'\nEarly stopping triggered after {epoch + 1} epochs')
            model.load_state_dict(best_model_state)
            break

        print(f'\nEpoch {epoch} Summary: Training Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}, F1: {macro_f1:.4f} Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}, F1: {val_macro_f1:.4f} Best validation loss: {best_val_loss:.4f} Patience counter: {patience_counter}/{patience}')

    print('\nTraining finished!')
    print(f'Best validation loss: {best_val_loss:.4f}') 
    return model

In [8]:
def save_model(model, optimizer, scaler, config, model_path):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'config': config
    }, model_path)
    print(f'Model saved at {model_path}')

In [9]:
def evaluate_and_save(model, test_loader, output_path, device):
    """
    Evaluate model on test set and save predictions and targets.
    
    Args:
        model: PyTorch model
        test_loader: DataLoader for test set
        output_path: Path to save results
        device: PyTorch device
    """
    model.eval()
    
    # Initialize lists to store predictions and targets
    #sleep_preds = []
    #sleep_targets = []
    #age_preds = []
    #age_targets = []
    ahi_diagnosis_preds = []
    ahi_diagnosis_targets = []
    
    with torch.no_grad():
        test_loop = tqdm(test_loader, desc='Evaluating', ncols=100)
        
        for x_data, y_data, mask, _, _, _, _, _, age_target, ahi_dignosis_target in test_loop:
            # Move data to device
            x_data = x_data.to(device)
            y_data = y_data.to(device)
            mask = mask.bool().to(device)
            age_target = age_target.to(device)
            
            # Forward pass
            output, mask, age_out, _, _, ahi_diagnosis = model(x_data, mask)
            
            # Process sleep staging predictions
            output_reshaped = output.reshape(-1, output.size(-1))
            targets_reshaped = y_data.reshape(-1).long()
            
            # Apply masking
            if mask is not None:
                mask_reshaped = mask.reshape(-1)
                valid_targets = targets_reshaped != -1
                valid_mask = ~mask_reshaped & valid_targets
                
                output_reshaped = output_reshaped[valid_mask]
                targets_reshaped = targets_reshaped[valid_mask]
            
            # Get predictions
            _, predicted = torch.max(output_reshaped, 1)
            
            # Store predictions and targets
            #sleep_preds.extend(predicted.cpu().numpy().tolist())
            #sleep_targets.extend(targets_reshaped.cpu().numpy().tolist())
            #age_preds.extend(age_out.cpu().numpy().flatten().tolist())
            #age_targets.extend(age_target.cpu().numpy().flatten().tolist())
            ahi_diagnosis_preds.extend(torch.sigmoid(ahi_diagnosis).cpu().numpy().flatten().tolist())
            ahi_diagnosis_targets.extend(ahi_dignosis_target.cpu().numpy().flatten().tolist())
    
    # Save results
    results = {
        #'sleep_predictions': sleep_preds,
        #'sleep_targets': sleep_targets,
        #'age_predictions': age_preds,
        #'age_targets': age_targets,
        'ahi_diagnosis_predictions': ahi_diagnosis_preds,
        'ahi_diagnosis_targets': ahi_diagnosis_targets
    }
    
    # Save as numpy arrays
    np.save(output_path, results)
    print(f'Results saved to {output_path}')
    
    
    return results

In [10]:
config['lr'] = config['lr'] / 2
config['epochs'] = 40
config['patience'] = 8
config['wandb'] = False
config['num_workers'] = 8

In [11]:
# Replace the model initialization section with:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Get the number of available GPUs
num_gpus = torch.cuda.device_count()
print(f"Number of available GPUs: {num_gpus}")
for pretrain_type in config['pretrain_type']:
    print(f'Fine-tuning model with pretrain type: {pretrain_type} for AHI diagnosis')
    
    # Create datasets and dataloaders - note the increased batch size
    train_dataset = SleepEventClassificationDataset(config, split="pretrain", pretrain_type=pretrain_type)
    validation_dataset = SleepEventClassificationDataset(config, split="validation", pretrain_type=pretrain_type)
    test_dataset = SleepEventClassificationDataset(config, split="test", pretrain_type=pretrain_type)
    
    # Multiply batch size by number of GPUs since DataParallel splits it automatically
    train_loader = DataLoader(train_dataset, 
                            batch_size=config['batch_size'], 
                            shuffle=True, 
                            num_workers=config['num_workers'], 
                            collate_fn=finetune_collate_fn,
                            #pin_memory=True, 
                           drop_last=True)
    
    validation_loader = DataLoader(validation_dataset, 
                                 batch_size=(config['batch_size']), 
                                 shuffle=False, 
                                 num_workers=config['num_workers'], 
                                 collate_fn=finetune_collate_fn,
                                 #pin_memory=True, 
                                 drop_last=True)
    
    test_loader = DataLoader(test_dataset, 
                           batch_size=(config['batch_size']), 
                           shuffle=False, 
                           num_workers=config['num_workers'], 
                           collate_fn=finetune_collate_fn,
                           #pin_memory=True, 
                           drop_last=True)

    model = SleepEventLSTMClassifier(
        embed_dim=config['model_params']['embed_dim'],
        num_heads=config['model_params']['num_heads'],
        num_layers=config['model_params']['num_layers'],
        num_classes=config['model_params']['num_classes'],
        pooling_head=config['model_params']['pooling_head'],
        dropout=config['model_params']['dropout'],
        max_seq_length=config['model_params']['max_seq_length']
    )
    
    # Wrap model with DataParallel before moving to device
    if num_gpus > 1:
        model = torch.nn.DataParallel(model, device_ids=[0, 1])  # Explicitly specify GPU devices
    model = model.to(device)
    
    # Scale learning rate with number of GPUs
    optimizer = optim.AdamW(model.parameters(), lr=config['lr'])
    model.train()
    scaler = torch.cuda.amp.GradScaler()

    best_model = train(model, train_loader, validation_loader, optimizer, scaler, config, device, patience=config['patience'])
    
    #save_path = f'/scratch/users/magnusrk/pretraining_comparision/final_embeddings/{pretrain_type}/ahi_diagnosis_model.pt'
    save_path = os.path.join(config['save_path'], f'{pretrain_type}/ahi_diagnosis_model.pt')
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    save_model(best_model, optimizer, scaler, config, save_path)

    #output_path = f'/oak/stanford/groups/jamesz/magnusrk/pretraining_comparison_data/ahi_results/{pretrain_type}_ahi_diagnosis_results.npy'
    output_path = os.path.join(config['save_path'], f'{pretrain_type}/ahi_diagnosis_results.npy')
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    evaluate_and_save(best_model, test_loader, output_path, device)

Using device: cuda
Number of available GPUs: 1
Fine-tuning model with pretrain type: CL_pairwise_epochs_36 for AHI diagnosis


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast() if is_training else torch.no_grad():
Epoch 0/39: 100%|██████████████████████████████████████████████████████████████████| 18/18 [00:49<00:00,  2.72s/it, loss=cur:0.848/avg:1.156, sleep=cur:0.000/acc:0.237/f1:0.176, diag=cur:0.000, death=cur:0.000, age=cur:0.261, ahi=cur:0.848/avg:1.156]
Validation Epoch 0/39: 100%|█████████████████████████████████████████████████████| 9/9 [00:09<00:00,  1.09s/it, val_loss=cur:0.585/avg:0.958, sleep=cur:0.000/acc:0.212/f1:0.154, diag=cur:0.208, death=cur:0.000, age=cur:0.130, ahi=cur:0.585/avg:0.958]


Epoch 0 Summary: Training Loss: 1.1565, Accuracy: 0.2374, F1: 0.1759 Validation Loss: 0.9580, Accuracy: 0.2122, F1: 0.1544 Best validation loss: 0.9580 Patience counter: 0/8



Epoch 1/39: 100%|██████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.30it/s, loss=cur:0.687/avg:0.586, sleep=cur:0.000/acc:0.197/f1:0.153, diag=cur:0.355, death=cur:0.000, age=cur:0.243, ahi=cur:0.687/avg:0.586]
Validation Epoch 1/39: 100%|█████████████████████████████████████████████████████| 9/9 [00:02<00:00,  3.67it/s, val_loss=cur:1.198/avg:1.427, sleep=cur:0.000/acc:0.228/f1:0.156, diag=cur:0.200, death=cur:0.000, age=cur:0.085, ahi=cur:1.198/avg:1.427]


Epoch 1 Summary: Training Loss: 0.5858, Accuracy: 0.1969, F1: 0.1531 Validation Loss: 1.4273, Accuracy: 0.2284, F1: 0.1559 Best validation loss: 0.9580 Patience counter: 1/8



Epoch 2/39: 100%|██████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.28it/s, loss=cur:0.052/avg:0.670, sleep=cur:0.000/acc:0.249/f1:0.170, diag=cur:0.461, death=cur:0.000, age=cur:0.270, ahi=cur:0.052/avg:0.670]
Validation Epoch 2/39: 100%|█████████████████████████████████████████████████████| 9/9 [00:02<00:00,  3.65it/s, val_loss=cur:1.630/avg:2.054, sleep=cur:0.000/acc:0.275/f1:0.166, diag=cur:0.212, death=cur:0.000, age=cur:0.261, ahi=cur:1.630/avg:2.054]


Epoch 2 Summary: Training Loss: 0.6699, Accuracy: 0.2494, F1: 0.1697 Validation Loss: 2.0537, Accuracy: 0.2749, F1: 0.1657 Best validation loss: 0.9580 Patience counter: 2/8



Epoch 3/39: 100%|██████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.25it/s, loss=cur:0.127/avg:0.460, sleep=cur:0.000/acc:0.253/f1:0.169, diag=cur:0.076, death=cur:0.000, age=cur:0.290, ahi=cur:0.127/avg:0.460]
Validation Epoch 3/39: 100%|█████████████████████████████████████████████████████| 9/9 [00:02<00:00,  3.64it/s, val_loss=cur:1.138/avg:1.710, sleep=cur:0.000/acc:0.267/f1:0.168, diag=cur:0.214, death=cur:0.000, age=cur:0.116, ahi=cur:1.138/avg:1.710]


Epoch 3 Summary: Training Loss: 0.4597, Accuracy: 0.2531, F1: 0.1694 Validation Loss: 1.7100, Accuracy: 0.2666, F1: 0.1678 Best validation loss: 0.9580 Patience counter: 3/8



Epoch 4/39: 100%|██████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.24it/s, loss=cur:0.228/avg:0.329, sleep=cur:0.000/acc:0.250/f1:0.176, diag=cur:0.000, death=cur:0.000, age=cur:0.168, ahi=cur:0.228/avg:0.329]
Validation Epoch 4/39: 100%|█████████████████████████████████████████████████████| 9/9 [00:02<00:00,  3.66it/s, val_loss=cur:1.029/avg:2.368, sleep=cur:0.000/acc:0.226/f1:0.155, diag=cur:0.209, death=cur:0.000, age=cur:0.113, ahi=cur:1.029/avg:2.368]


Epoch 4 Summary: Training Loss: 0.3287, Accuracy: 0.2496, F1: 0.1763 Validation Loss: 2.3684, Accuracy: 0.2255, F1: 0.1550 Best validation loss: 0.9580 Patience counter: 4/8



Epoch 5/39: 100%|██████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.27it/s, loss=cur:0.718/avg:0.240, sleep=cur:0.000/acc:0.229/f1:0.167, diag=cur:0.690, death=cur:0.000, age=cur:0.270, ahi=cur:0.718/avg:0.240]
Validation Epoch 5/39: 100%|█████████████████████████████████████████████████████| 9/9 [00:02<00:00,  3.65it/s, val_loss=cur:0.600/avg:1.795, sleep=cur:0.000/acc:0.262/f1:0.170, diag=cur:0.246, death=cur:0.000, age=cur:0.126, ahi=cur:0.600/avg:1.795]


Epoch 5 Summary: Training Loss: 0.2400, Accuracy: 0.2293, F1: 0.1667 Validation Loss: 1.7948, Accuracy: 0.2622, F1: 0.1697 Best validation loss: 0.9580 Patience counter: 5/8



Epoch 6/39: 100%|██████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.29it/s, loss=cur:0.571/avg:0.177, sleep=cur:0.000/acc:0.228/f1:0.163, diag=cur:0.625, death=cur:0.000, age=cur:0.055, ahi=cur:0.571/avg:0.177]
Validation Epoch 6/39: 100%|█████████████████████████████████████████████████████| 9/9 [00:02<00:00,  3.66it/s, val_loss=cur:1.608/avg:2.186, sleep=cur:0.000/acc:0.275/f1:0.168, diag=cur:0.242, death=cur:0.000, age=cur:0.128, ahi=cur:1.608/avg:2.186]


Epoch 6 Summary: Training Loss: 0.1771, Accuracy: 0.2279, F1: 0.1632 Validation Loss: 2.1863, Accuracy: 0.2755, F1: 0.1681 Best validation loss: 0.9580 Patience counter: 6/8



Epoch 7/39: 100%|██████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.29it/s, loss=cur:0.138/avg:0.081, sleep=cur:0.000/acc:0.235/f1:0.165, diag=cur:0.441, death=cur:1.302, age=cur:0.231, ahi=cur:0.138/avg:0.081]
Validation Epoch 7/39: 100%|█████████████████████████████████████████████████████| 9/9 [00:02<00:00,  3.68it/s, val_loss=cur:2.139/avg:3.041, sleep=cur:0.000/acc:0.245/f1:0.159, diag=cur:0.245, death=cur:0.000, age=cur:0.090, ahi=cur:2.139/avg:3.041]


Epoch 7 Summary: Training Loss: 0.0813, Accuracy: 0.2349, F1: 0.1646 Validation Loss: 3.0409, Accuracy: 0.2452, F1: 0.1591 Best validation loss: 0.9580 Patience counter: 7/8



Epoch 8/39: 100%|██████████████████████████████████████████████████████████████████| 18/18 [00:05<00:00,  3.29it/s, loss=cur:0.001/avg:0.077, sleep=cur:0.000/acc:0.228/f1:0.155, diag=cur:0.540, death=cur:0.000, age=cur:0.117, ahi=cur:0.001/avg:0.077]
Validation Epoch 8/39: 100%|█████████████████████████████████████████████████████| 9/9 [00:02<00:00,  3.67it/s, val_loss=cur:2.561/avg:2.436, sleep=cur:0.000/acc:0.296/f1:0.171, diag=cur:0.236, death=cur:0.000, age=cur:0.088, ahi=cur:2.561/avg:2.436]



Early stopping triggered after 9 epochs

Training finished!
Best validation loss: 0.9580
Model saved at /oak/stanford/groups/mignot/projects/SleepBenchTest/pretrain_comparison/output/results/CL_pairwise_epochs_36/ahi_diagnosis_model.pt


Evaluating: 100%|█████████████████████████████████████████████████████| 9/9 [00:03<00:00,  2.32it/s]

Results saved to /oak/stanford/groups/mignot/projects/SleepBenchTest/pretrain_comparison/output/results/CL_pairwise_epochs_36/ahi_diagnosis_results.npy





In [19]:
#output_path = f'/oak/stanford/groups/jamesz/magnusrk/pretraining_comparison_data/ahi_results/{pretrain_type}_ahi_diagnosis_results.npy'
#evaluate_and_save(model, test_loader, output_path, device)