In [16]:
import torch
import torch.optim as optim
from tqdm import tqdm
import os
import sys
sys.path.append(os.getcwd().split('/pretrain_comparison/fine_tune/sleep_stage_and_age')[0] + '/pretrain_comparison')
from fine_tune.models.model import SleepEventLSTMClassifier
from fine_tune.models.dataset import SleepEventClassificationDataset, finetune_collate_fn
from fine_tune.utils import *
from comparison.utils import *
import json
from torch.utils.data import DataLoader
import math

In [17]:
config_path = os.getcwd().split('/pretrain_comparison/fine_tune/sleep_stage_and_age')[0] + '/pretrain_comparison/fine_tune/config_fine_tune.yaml'
config = load_data(config_path)

In [18]:
def run_iteration(model, data, optimizer=None, scaler=None, config=None, device=None, mode='train'):
    """
    Run one iteration (batch) of training or validation.
    
    Args:
        model: The PyTorch model
        data: Tuple of batch data
        optimizer: PyTorch optimizer (only needed for training)
        scaler: Gradient scaler for mixed precision training
        config: Configuration dictionary
        device: PyTorch device
        mode: Either 'train' or 'val'
    """
    is_training = mode == 'train'
    
    # Unpack the batch data
    x_data, y_data, mask, _, diagnosis_presence, diagnosis_time, death_presence, death_time, age_target = data

    # Compute norms along the feature dimension
    norms = x_data.norm(dim=2, keepdim=True)  # Shape: (batch, sequence, 1)

    # Normalize the vectors along the feature dimension
    x_data_normalized = x_data / (norms + 1e-8)
    
    # Move data to device
    x_data = x_data.to(device)
    y_data = y_data.to(device)
    mask = mask.bool().to(device)
    diagnosis_presence = diagnosis_presence.to(device)
    diagnosis_time = diagnosis_time.to(device)
    death_presence = death_presence.to(device)
    death_time = death_time.to(device)
    age_target = age_target.to(device)

    if is_training:
        optimizer.zero_grad()
        
    # Context manager for mixed precision training
    with torch.cuda.amp.autocast() if is_training else torch.no_grad():
        output, mask, age_out, hazards_diagnosis, hazards_death = model(x_data, mask)
        
        # Reshape outputs and targets
        output_reshaped = output.reshape(-1, config['model_params']['num_classes'])
        targets_reshaped = y_data.reshape(-1).long()
        
        # Handle masking
        if mask is not None:
            mask_reshaped = mask.reshape(-1)
            valid_targets = targets_reshaped != -1
            valid_mask = ~mask_reshaped & valid_targets
            
            output_reshaped = output_reshaped[valid_mask]
            targets_reshaped = targets_reshaped[valid_mask]
            # if no valid targets set losses to 0 and return
            if targets_reshaped.size(0) == 0:
                loss = torch.tensor(0.0).to(device)
                metrics = {
                    'loss': loss.item(),
                    'loss_sleep_staging': loss.item(),
                    'loss_diagnosis': loss.item(),
                    'loss_death': loss.item(),
                    'loss_age': loss.item(),
                    'correct': 0,
                    'total': 0,
                    'tp': torch.zeros(config['model_params']['num_classes']).to(device),
                    'fp': torch.zeros(config['model_params']['num_classes']).to(device),
                    'fn': torch.zeros(config['model_params']['num_classes']).to(device)
                }
                return metrics
        
        # Calculate losses
        loss_sleep_staging = masked_cross_entropy_loss(output, y_data, valid_mask)
        loss_diagnosis = cox_ph_loss(hazards_diagnosis, diagnosis_time, diagnosis_presence)
        loss_death = cox_ph_loss(hazards_death, death_time, death_presence)
        loss_age = F.mse_loss(age_target.float(), age_out.float())
        loss = loss_sleep_staging + loss_age * 10

    # Handle backpropagation for training
    if is_training:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

    # Calculate metrics
    with torch.no_grad():
        _, predicted = torch.max(output_reshaped, 1)
        total = targets_reshaped.size(0)
        correct = (predicted == targets_reshaped).sum().item()
        
        # Calculate F1 components
        tp = torch.zeros(config['model_params']['num_classes']).to(device)
        fp = torch.zeros(config['model_params']['num_classes']).to(device)
        fn = torch.zeros(config['model_params']['num_classes']).to(device)
        
        for class_idx in range(config['model_params']['num_classes']):
            pred_mask = predicted == class_idx
            target_mask = targets_reshaped == class_idx
            
            tp[class_idx] += (pred_mask & target_mask).sum()
            fp[class_idx] += (pred_mask & ~target_mask).sum()
            fn[class_idx] += (~pred_mask & target_mask).sum()
    # Before returning, check for NaN values
    metrics = {
        'loss': loss.item(),
        'loss_sleep_staging': loss_sleep_staging.item(),
        'loss_diagnosis': loss_diagnosis.item(),
        'loss_death': loss_death.item(),
        'loss_age': loss_age.item(),
        'correct': correct,
        'total': total,
        'tp': tp,
        'fp': fp,
        'fn': fn
    }

    # Check for NaN values
    for key, value in metrics.items():
        if isinstance(value, (float, int)):
            if math.isnan(value):
                print(f"NaN detected in {key}")
                print(f"Debug info:")
                print(f"loss: {loss}")
                print(f"loss_sleep_staging: {loss_sleep_staging}")
                print(f"loss_diagnosis: {loss_diagnosis}")
                print(f"loss_death: {loss_death}")
                print(f"loss_age: {loss_age}")
                print(f"y_data: {y_data}")
                print(f"output: {output}")
                print(f"valid_mask: {valid_mask}")
                print(f"nan in y data: {torch.isnan(y_data).any()}")
                print(f"nan in output: {torch.isnan(output).any()}")
                print(f"nan in valid_mask: {torch.isnan(valid_mask).any()}")
                unique_targets_reshaped = torch.unique(targets_reshaped)
                print(f"unique targets: {unique_targets_reshaped}")
                unique_valid_mask = torch.unique(valid_mask)
                print(f"unique valid mask: {unique_valid_mask}")
                raise ValueError(f"NaN detected in {key}")

    return metrics


In [19]:
def train(model, train_loader, validation_loader, optimizer, scaler, config, device, patience=10):
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    for epoch in range(config['epochs']):
        # Training metrics
        train_metrics = {
            'running_loss': 0.0,
            'running_sleep_staging_loss': 0.0,
            'running_diagnosis_loss': 0.0,
            'running_death_loss': 0.0,
            'running_age_loss': 0.0,
            'correct': 0,
            'total': 0,
            'tp': torch.zeros(config['model_params']['num_classes']).to(device),
            'fp': torch.zeros(config['model_params']['num_classes']).to(device),
            'fn': torch.zeros(config['model_params']['num_classes']).to(device)
        }

        # Training loop
        model.train()
        train_loop = tqdm(enumerate(train_loader), 
                            total=len(train_loader), 
                            desc=f'Epoch {epoch}/{config["epochs"]-1}',
                            leave=True,
                            ncols=250)
        
        for i, batch_data in train_loop:
            batch_metrics = run_iteration(model, batch_data, optimizer, scaler, config, device, mode='train')
            
            # Update running metrics
            train_metrics['running_loss'] += batch_metrics['loss']
            train_metrics['running_sleep_staging_loss'] += batch_metrics['loss_sleep_staging']
            train_metrics['running_diagnosis_loss'] += batch_metrics['loss_diagnosis']
            train_metrics['running_death_loss'] += batch_metrics['loss_death']
            train_metrics['running_age_loss'] += batch_metrics['loss_age']
            train_metrics['correct'] += batch_metrics['correct']
            train_metrics['total'] += batch_metrics['total']
            train_metrics['tp'] += batch_metrics['tp']
            train_metrics['fp'] += batch_metrics['fp']
            train_metrics['fn'] += batch_metrics['fn']

            # Calculate current metrics for progress bar
            batch_count = i + 1
            avg_loss = train_metrics['running_loss'] / batch_count
            accuracy = train_metrics['correct'] / train_metrics['total'] if train_metrics['total'] > 0 else 0
            
            # Calculate F1 score
            precision = train_metrics['tp'] / (train_metrics['tp'] + train_metrics['fp'] + 1e-7)
            recall = train_metrics['tp'] / (train_metrics['tp'] + train_metrics['fn'] + 1e-7)
            f1 = 2 * (precision * recall) / (precision + recall + 1e-7)
            macro_f1 = f1.mean().item()

            train_loop.set_postfix({
                'loss': f'cur:{batch_metrics["loss"]:.3f}/avg:{avg_loss:.3f}',
                'sleep': f'cur:{batch_metrics["loss_sleep_staging"]:.3f}/acc:{accuracy:.3f}/f1:{macro_f1:.3f}',
                'diag': f'cur:{batch_metrics["loss_diagnosis"]:.3f}',
                'death': f'cur:{batch_metrics["loss_death"]:.3f}',
                'age': f'cur:{batch_metrics["loss_age"]:.3f}'
            })

        # Validation loop
        val_metrics = {
            'running_loss': 0.0,
            'running_sleep_staging_loss': 0.0,
            'running_diagnosis_loss': 0.0,
            'running_death_loss': 0.0,
            'running_age_loss': 0.0,
            'correct': 0,
            'total': 0,
            'tp': torch.zeros(config['model_params']['num_classes']).to(device),
            'fp': torch.zeros(config['model_params']['num_classes']).to(device),
            'fn': torch.zeros(config['model_params']['num_classes']).to(device)
        }

        model.eval()
        val_loop = tqdm(enumerate(validation_loader), 
                        total=len(validation_loader), 
                        desc=f'Validation Epoch {epoch}/{config["epochs"]-1}',
                        leave=True,
                        ncols=250)
        with torch.no_grad():
            for i, batch_data in val_loop:
                batch_metrics = run_iteration(model, batch_data, None, None, config, device, mode='val')
                    
                
                # Update validation metrics
                val_metrics['running_loss'] += batch_metrics['loss']
                val_metrics['running_sleep_staging_loss'] += batch_metrics['loss_sleep_staging']
                val_metrics['running_diagnosis_loss'] += batch_metrics['loss_diagnosis']
                val_metrics['running_death_loss'] += batch_metrics['loss_death']
                val_metrics['running_age_loss'] += batch_metrics['loss_age']
                val_metrics['correct'] += batch_metrics['correct']
                val_metrics['total'] += batch_metrics['total']
                val_metrics['tp'] += batch_metrics['tp']
                val_metrics['fp'] += batch_metrics['fp']
                val_metrics['fn'] += batch_metrics['fn']

                # Calculate current metrics
                batch_count = i + 1
                avg_val_loss = val_metrics['running_loss'] / batch_count
                val_accuracy = val_metrics['correct'] / val_metrics['total'] if val_metrics['total'] > 0 else 0
                
                # Calculate F1 score
                precision = val_metrics['tp'] / (val_metrics['tp'] + val_metrics['fp'] + 1e-7)
                recall = val_metrics['tp'] / (val_metrics['tp'] + val_metrics['fn'] + 1e-7)
                f1 = 2 * (precision * recall) / (precision + recall + 1e-7)
                val_macro_f1 = f1.mean().item()

                val_loop.set_postfix({
                    'val_loss': f'cur:{batch_metrics["loss"]:.3f}/avg:{avg_val_loss:.3f}',
                    'sleep': f'cur:{batch_metrics["loss_sleep_staging"]:.3f}/acc:{val_accuracy:.3f}/f1:{val_macro_f1:.3f}',
                    'diag': f'cur:{batch_metrics["loss_diagnosis"]:.3f}',
                    'death': f'cur:{batch_metrics["loss_death"]:.3f}',
                    'age': f'cur:{batch_metrics["loss_age"]:.3f}'
                })

        # Log to wandb if enabled
        if config['wandb']:
            wandb.log({
                'train/loss': avg_loss,
                'train/accuracy': accuracy,
                'train/f1_score': macro_f1,
                'val/loss': avg_val_loss,
                'val/accuracy': val_accuracy,
                'val/f1_score': val_macro_f1,
                'epoch': epoch
            })

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            best_model_state = model.state_dict()
        else:
            patience_counter += 1

        # Early stopping trigger
        if patience_counter >= patience:
            print(f'\nEarly stopping triggered after {epoch + 1} epochs')
            model.load_state_dict(best_model_state)
            break

        print(f'\nEpoch {epoch} Summary: Training Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}, F1: {macro_f1:.4f} Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}, F1: {val_macro_f1:.4f} Best validation loss: {best_val_loss:.4f} Patience counter: {patience_counter}/{patience}')

    print('\nTraining finished!')
    print(f'Best validation loss: {best_val_loss:.4f}') 
    return model

In [20]:
def evaluate_and_save(model, test_loader, output_path, device):
    """
    Evaluate model on test set and save predictions and targets.
    
    Args:
        model: PyTorch model
        test_loader: DataLoader for test set
        output_path: Path to save results
        device: PyTorch device
    """
    model.eval()
    
    # Initialize lists to store predictions and targets
    sleep_preds = []
    sleep_targets = []
    age_preds = []
    age_targets = []
    
    with torch.no_grad():
        test_loop = tqdm(test_loader, desc='Evaluating', ncols=100)
        
        for x_data, y_data, mask, _, _, _, _, _, age_target in test_loop:
            # Move data to device
            x_data = x_data.to(device)
            y_data = y_data.to(device)
            mask = mask.bool().to(device)
            age_target = age_target.to(device)
            
            # Forward pass
            output, mask, age_out, _, _ = model(x_data, mask)
            
            # Process sleep staging predictions
            output_reshaped = output.reshape(-1, output.size(-1))
            targets_reshaped = y_data.reshape(-1).long()
            
            # Apply masking
            if mask is not None:
                mask_reshaped = mask.reshape(-1)
                valid_targets = targets_reshaped != -1
                valid_mask = ~mask_reshaped & valid_targets
                
                output_reshaped = output_reshaped[valid_mask]
                targets_reshaped = targets_reshaped[valid_mask]
            
            # Get predictions
            _, predicted = torch.max(output_reshaped, 1)
            
            # Store predictions and targets
            sleep_preds.extend(predicted.cpu().numpy().tolist())
            sleep_targets.extend(targets_reshaped.cpu().numpy().tolist())
            age_preds.extend(age_out.cpu().numpy().flatten().tolist())
            age_targets.extend(age_target.cpu().numpy().flatten().tolist())
    
    # Save results
    results = {
        'sleep_predictions': sleep_preds,
        'sleep_targets': sleep_targets,
        'age_predictions': age_preds,
        'age_targets': age_targets
    }
    
    # Save as numpy arrays
    np.save(output_path, results)
    print(f'Results saved to {output_path}')
    
    
    return results

In [21]:
def save_model(model, optimizer, scaler, config, model_path):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'config': config
    }, model_path)
    print(f'Model saved at {model_path}')

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
for pretrain_type in config['pretrain_type']:
    
            
        print(f'Fine-tuning model with pretrain type: {pretrain_type}')
        output_path = os.path.join(config['save_path'], f'{pretrain_type}/sleep_stage_and_age_results.npy')
        #output_path = f'/oak/stanford/groups/jamesz/magnusrk/pretraining_comparison_data/sleep_stage_results/{pretrain_type}_sleep_stage_and_age.npy'
        #if output_path folder does not exist, create it
        if not os.path.exists(os.path.dirname(output_path)):
            os.makedirs(os.path.dirname(output_path))
        
        train_dataset = SleepEventClassificationDataset(config, split="pretrain",pretrain_type = pretrain_type)
        config['max_files'] = config['val_size']
        validation_dataset = SleepEventClassificationDataset(config, split="validation",pretrain_type = pretrain_type)
        config['max_files'] = None
        test_dataset = SleepEventClassificationDataset(config, split="test",pretrain_type = pretrain_type)
        train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=config['num_workers'], collate_fn=finetune_collate_fn)
        validation_loader = DataLoader(validation_dataset, batch_size=config['batch_size']//2, shuffle=False, num_workers=config['num_workers'], collate_fn=finetune_collate_fn)
        test_loader = DataLoader(test_dataset, batch_size=config['batch_size']//2, shuffle=False, num_workers=config['num_workers'], collate_fn=finetune_collate_fn)

        model = SleepEventLSTMClassifier(embed_dim=config['model_params']['embed_dim']
                            , num_heads=config['model_params']['num_heads']
                            , num_layers=config['model_params']['num_layers']
                            , num_classes=config['model_params']['num_classes']
                            , pooling_head=config['model_params']['pooling_head']
                            , dropout=config['model_params']['dropout']
                            , max_seq_length=config['model_params']['max_seq_length'])
        model = model.to(device)
        optimizer = optim.AdamW(model.parameters(), lr=config['lr'])
        model.train()
        scaler = torch.cuda.amp.GradScaler()

        train(model, train_loader, validation_loader, optimizer, scaler, config, device, patience=5)
        
        save_path = os.path.join(config['save_path'], f'{pretrain_type}/sleep_stage_and_age_model.pt')
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        save_model(model, optimizer, scaler, config, save_path)
        
        print(f'Saving results to {output_path}')
        evaluate_and_save(model, test_loader, output_path, device)


            


Using device: cuda
Fine-tuning model with pretrain type: CL_pairwise_epochs_36


embs_path: /oak/stanford/groups/mignot/projects/SleepBenchTest/pretrain_comparison/output/final_embeddings
first hdf5_paths_new: /oak/stanford/groups/mignot/projects/SleepBenchTest/pretrain_comparison/output/final_embeddings/CL_pairwise_epochs_36/20250514_043440_epoch_35/SSC_2010_5313980369.hdf5
embs_path: /oak/stanford/groups/mignot/projects/SleepBenchTest/pretrain_comparison/output/final_embeddings
first hdf5_paths_new: /oak/stanford/groups/mignot/projects/SleepBenchTest/pretrain_comparison/output/final_embeddings/CL_pairwise_epochs_36/20250514_043440_epoch_35/SSC_2010_5313980369.hdf5
embs_path: /oak/stanford/groups/mignot/projects/SleepBenchTest/pretrain_comparison/output/final_embeddings
first hdf5_paths_new: /oak/stanford/groups/mignot/projects/SleepBenchTest/pretrain_comparison/output/final_embeddings/CL_pairwise_epochs_36/20250514_043440_epoch_35/SSC_2010_5313980369.hdf5


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast() if is_training else torch.no_grad():
Epoch 0/29: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:29<00:00, 29.45s/it, loss=cur:4.498/avg:4.498, sleep=cur:3.807/acc:0.087/f1:0.054, diag=cur:0.599, death=cur:0.000, age=cur:0.069]
Validation Epoch 0/29: 100%|██████████████████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00,  2.03s/it, val_loss=cur:4.209/avg:4.504, sleep=cur:3.565/acc:0.061/f1:0.034, diag=cur:0.000, death=cur:0.000, age=cur:0.064]


Epoch 0 Summary: Training Loss: 4.4982, Accuracy: 0.0870, F1: 0.0538 Validation Loss: 4.5036, Accuracy: 0.0609, F1: 0.0341 Best validation loss: 4.5036 Patience counter: 0/5



Epoch 1/29: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.83s/it, loss=cur:4.496/avg:4.496, sleep=cur:3.808/acc:0.090/f1:0.057, diag=cur:0.612, death=cur:0.000, age=cur:0.069]
Validation Epoch 1/29: 100%|██████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.71it/s, val_loss=cur:5.105/avg:5.116, sleep=cur:2.856/acc:0.517/f1:0.136, diag=cur:0.000, death=cur:0.000, age=cur:0.225]


Epoch 1 Summary: Training Loss: 4.4961, Accuracy: 0.0904, F1: 0.0568 Validation Loss: 5.1156, Accuracy: 0.5169, F1: 0.1363 Best validation loss: 4.5036 Patience counter: 1/5



Epoch 2/29: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.00s/it, loss=cur:5.903/avg:5.903, sleep=cur:3.262/acc:0.535/f1:0.139, diag=cur:0.617, death=cur:0.000, age=cur:0.264]
Validation Epoch 2/29: 100%|██████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.67it/s, val_loss=cur:5.178/avg:5.395, sleep=cur:2.929/acc:0.517/f1:0.136, diag=cur:0.000, death=cur:0.000, age=cur:0.225]


Epoch 2 Summary: Training Loss: 5.9027, Accuracy: 0.5346, F1: 0.1393 Validation Loss: 5.3946, Accuracy: 0.5169, F1: 0.1363 Best validation loss: 4.5036 Patience counter: 2/5



Epoch 3/29: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.01s/it, loss=cur:5.941/avg:5.941, sleep=cur:3.300/acc:0.535/f1:0.139, diag=cur:0.635, death=cur:0.000, age=cur:0.264]
Validation Epoch 3/29: 100%|██████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.67it/s, val_loss=cur:5.349/avg:5.328, sleep=cur:3.100/acc:0.517/f1:0.136, diag=cur:0.000, death=cur:0.000, age=cur:0.225]


Epoch 3 Summary: Training Loss: 5.9414, Accuracy: 0.5346, F1: 0.1393 Validation Loss: 5.3279, Accuracy: 0.5169, F1: 0.1363 Best validation loss: 4.5036 Patience counter: 3/5



Epoch 4/29: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.17it/s, loss=cur:5.851/avg:5.851, sleep=cur:3.209/acc:0.535/f1:0.139, diag=cur:0.631, death=cur:0.000, age=cur:0.264]
Validation Epoch 4/29: 100%|██████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.24it/s, val_loss=cur:5.211/avg:5.204, sleep=cur:2.961/acc:0.517/f1:0.136, diag=cur:0.000, death=cur:0.000, age=cur:0.225]


Epoch 4 Summary: Training Loss: 5.8510, Accuracy: 0.5346, F1: 0.1393 Validation Loss: 5.2043, Accuracy: 0.5169, F1: 0.1363 Best validation loss: 4.5036 Patience counter: 4/5



Epoch 5/29: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.21it/s, loss=cur:5.790/avg:5.790, sleep=cur:3.148/acc:0.535/f1:0.139, diag=cur:0.613, death=cur:0.000, age=cur:0.264]
Validation Epoch 5/29: 100%|██████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.23it/s, val_loss=cur:5.001/avg:5.109, sleep=cur:2.751/acc:0.517/f1:0.136, diag=cur:0.000, death=cur:0.000, age=cur:0.225]



Early stopping triggered after 6 epochs

Training finished!
Best validation loss: 4.5036
Model saved at /oak/stanford/groups/mignot/projects/SleepBenchTest/pretrain_comparison/output/results/CL_pairwise_epochs_36/sleep_stage_and_age_model.pt
Saving results to /oak/stanford/groups/mignot/projects/SleepBenchTest/pretrain_comparison/output/results/CL_pairwise_epochs_36/sleep_stage_and_age_results.npy


Evaluating: 100%|█████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.34it/s]

Results saved to /oak/stanford/groups/mignot/projects/SleepBenchTest/pretrain_comparison/output/results/CL_pairwise_epochs_36/sleep_stage_and_age_results.npy





In [23]:
#print(f'Saving results to {output_path}')
#evaluate_and_save(model, test_loader, output_path, device)
