In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm 
import wandb
from torch.optim.lr_scheduler import CosineAnnealingLR
import sys
from torch.amp import autocast, GradScaler


sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..", "..", "..")))
from src.utils.k_folds_creator import KFoldCreator
from src.utils.utils import get_models_save_path, set_seeds
from src.utils.constants import Constants 
from src.datasets.eeg_dataset_montage import EEGDatasetMontage
from src.models.gru_convolution_attention import NodeAttentionModel

# Log in to Weights & Biases once, up front: train_one_stage and the artifact
# upload in run_two_stage_pipeline both open runs with wandb.init later on.
wandb.login()

2025-12-02 23:20:04,179 :: root :: INFO :: Initialising Utils
2025-12-02 23:20:04,760 :: root :: INFO :: Initialising Datasets
2025-12-02 23:20:04,785 :: root :: INFO :: Initialising Models


Skipping module cbramod_dataset due to missing dependency: No module named 'mne'


[34m[1mwandb[0m: Currently logged in as: [33mmaikotrede[0m ([33mhms-hslu-aicomp-hs25[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:

class CFG:
    """Static experiment configuration, read as class attributes throughout the notebook.

    ``CFG.sequence_length`` is not declared here; it is attached right after
    the class definition from ``sequence_duration * sampling_rate``.
    """

    # --- Reproducibility and cross-validation ---
    seed = 42
    n_splits = 5
    folds_to_train = [4]  # specify the list of folds to train, e.g. [0, 1, 2, 3, 4]

    # --- Data location ---
    data_path = '../../../data/'

    # --- Input signal ---
    sampling_rate = 200  # Hz
    sequence_duration = 50  # presumably seconds (multiplied by sampling_rate) — TODO confirm
    downsample_factor = 1
    num_channels = 19

    # --- Model ---
    model_name = 'GRUConvNodeAttentionModel'
    hidden_units = 256
    num_layers = 1
    num_cnn_blocks = 3
    target_size = 6
    dropout = 0.4
    # NOTE(review): in the visible code this flag is only logged to W&B and
    # used in the run group name; it is not passed to the model constructor.
    use_attention = True

    # --- Data loading ---
    batch_size = 32
    num_workers = 8

    # --- Stage 1 schedule ---
    stage1_epochs = 50
    stage1_lr = 10**-4

    # --- Stage 2 schedule ---
    stage2_epochs = 15
    stage2_lr = 10**-4.5

    # --- Early stopping ---
    patience = 10      # epochs without sufficient improvement before stopping
    min_delta = 0.001  # minimum validation-loss decrease that counts as improvement

    # --- Mixup augmentation ---
    use_mixup = True
    mixup_alpha = 0.5
    

# Samples per sequence: duration multiplied by sampling rate
# (50 * 200 = 10000 with the values declared on CFG).
CFG.sequence_length = CFG.sequence_duration * CFG.sampling_rate 

# Seed via the project helper; which RNGs it covers is defined in
# src.utils.utils and is not visible from this notebook.
set_seeds(CFG.seed)

# Target column names; used further down to derive 'expert_consensus' when
# that column is missing from the loaded CSV.
TARGETS = Constants.TARGETS

In [3]:
def mixup_data(x, y, alpha=1.0, device='cuda'):
    """Blend each sample of a batch with a randomly chosen partner (mixup).

    One coefficient ``lam`` is drawn per call and applied to the whole batch:
    ``mixed = lam * batch + (1 - lam) * batch[perm]`` for both inputs and targets.

    Args:
        x: Input batch tensor; dimension 0 is the batch dimension.
        y: Target tensor with the same batch dimension as ``x``.
        alpha: Parameter of the Beta(alpha, alpha) distribution ``lam`` is
            drawn from. ``alpha <= 0`` disables mixing (``lam = 1``).
        device: Kept for backward compatibility and no longer used. The
            permutation index is always placed on ``x.device``, so a CPU batch
            no longer fails when the default ``'cuda'`` is left in place.

    Returns:
        Tuple ``(mixed_x, mixed_y)`` with the same shapes as ``x`` and ``y``.
    """
    if alpha > 0:
        # Plain Python float so the arithmetic below does not depend on how
        # numpy scalars and torch tensors dispatch against each other.
        lam = float(np.random.beta(alpha, alpha))
    else:
        lam = 1

    batch_size = x.size(0)
    # Draw on the CPU generator (as before, so seeded runs keep the same
    # permutation stream) and then move to wherever the batch actually lives.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    mixed_y = lam * y + (1 - lam) * y[index]

    return mixed_x, mixed_y


def get_dataloaders(train_df, valid_df):
    """Create the training and validation DataLoaders from two metadata frames.

    A loader is only built for a frame that is neither ``None`` nor empty;
    otherwise ``None`` is returned in its place.

    Args:
        train_df: Frame for the training split (or ``None`` / empty).
        valid_df: Frame for the validation split (or ``None`` / empty).

    Returns:
        Tuple ``(train_loader, valid_loader)``; either element may be ``None``.
    """

    def _usable(frame):
        return frame is not None and not frame.empty

    def _build(frame, mode, shuffle, drop_last, **dataset_extras):
        # Dataset and loader settings shared by both splits come from CFG.
        dataset = EEGDatasetMontage(
            df=frame,
            data_path=CFG.data_path,
            mode=mode,
            downsample_factor=CFG.downsample_factor,
            **dataset_extras,
        )
        return DataLoader(
            dataset,
            batch_size=CFG.batch_size,
            shuffle=shuffle,
            num_workers=CFG.num_workers,
            pin_memory=True,
            drop_last=drop_last,
            persistent_workers=CFG.num_workers > 0,
        )

    # Training split: shuffled, incomplete final batch dropped, augmentations on.
    if _usable(train_df):
        train_loader = _build(
            train_df, 'train', shuffle=True, drop_last=True,
            augmentations=["channel_mask", "time_shift"],
        )
    else:
        train_loader = None

    # Validation split: fixed order, every sample kept, dataset defaults otherwise.
    if _usable(valid_df):
        valid_loader = _build(valid_df, 'valid', shuffle=False, drop_last=False)
    else:
        valid_loader = None

    return train_loader, valid_loader

In [4]:
def train_one_stage(fold, stage_name, train_df, valid_df, group_name, starting_checkpoint=None):
    """Train one stage of the two-stage schedule for a single fold, logging to W&B.

    The model is optimised with AdamW and a cosine-annealed learning rate
    against a KL-divergence loss on log-softmax outputs. After every epoch the
    model is validated; the weights with the lowest validation loss are saved,
    and training stops early after ``CFG.patience`` epochs without an
    improvement of at least ``CFG.min_delta``.

    Args:
        fold: Fold index; used in run names, tags and the checkpoint filename.
        stage_name: ``"Stage1"`` selects the stage-1 learning rate and epoch
            count; any other value selects the stage-2 settings.
        train_df: Frame with the training rows for this stage.
        valid_df: Frame with the validation rows for this stage.
        group_name: W&B group the run is filed under.
        starting_checkpoint: Optional path to a state dict to initialise the
            model from (e.g. the stage-1 checkpoint when running stage 2).

    Returns:
        Tuple ``(best_model_path, best_val_loss)``. The path is only written
        when some epoch improves on the initial ``inf``; if none does, the file
        may be missing or left over from an earlier run.

    Raises:
        ValueError: If either split is empty, so no DataLoader could be built.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Mixed precision (autocast + GradScaler) is only switched on for CUDA.
    autocast_enabled = (device.type == 'cuda')

    if stage_name == "Stage1":
        lr, epochs = CFG.stage1_lr, CFG.stage1_epochs
    else:
        lr, epochs = CFG.stage2_lr, CFG.stage2_epochs

    print(f"\n--- Starting {stage_name} | Fold {fold} ---")

    experiment_name = f"{group_name}_{stage_name}_fold{fold}"

    config = {
        "architecture": CFG.model_name,
        "fold": fold,
        "stage": stage_name,
        "optimizer": "AdamW",
        "learning_rate": lr,
        "batch_size": CFG.batch_size,
        "epochs": epochs,
        "num_cnn_blocks": CFG.num_cnn_blocks,
        "use_attention": CFG.use_attention,
        "seed": CFG.seed
    }

    wandb.init(
        project="hms-aicomp-gru-conv",
        name=experiment_name,
        group=group_name,
        job_type=stage_name,
        tags=['two-stage', stage_name, f'fold{fold}', f'blocks_{CFG.num_cnn_blocks}'],
        config=config,
        reinit=True
    )

    # Everything after wandb.init is guarded so that a crash or a manual
    # interrupt still closes the run instead of leaving it open in the kernel.
    try:
        # NOTE(review): CFG.use_attention is logged above but not forwarded to
        # the model here — confirm whether NodeAttentionModel should receive it.
        model = NodeAttentionModel(
            num_nodes=CFG.num_channels,
            node_embed_size=256,
            hidden_size=CFG.hidden_units,
            num_layers=CFG.num_layers,
            num_classes=CFG.target_size,
            num_cnn_blocks=CFG.num_cnn_blocks,
            dropout=CFG.dropout,
            use_inception=True
        )
        model.to(device)

        if starting_checkpoint:
            print(f"Loading weights from {starting_checkpoint}...")
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            state_dict = torch.load(starting_checkpoint, map_location=device)
            model.load_state_dict(state_dict)

        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
        scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
        loss_fn = nn.KLDivLoss(reduction='batchmean')
        scaler = GradScaler(enabled=autocast_enabled)

        train_loader, valid_loader = get_dataloaders(train_df, valid_df)
        if train_loader is None or valid_loader is None:
            raise ValueError(
                f"{stage_name} fold {fold}: training and validation data must both be non-empty"
            )

        best_val_loss = float('inf')
        best_model_path = get_models_save_path() / "TwoStage" / f"{stage_name}_fold{fold}.pth"
        best_model_path.parent.mkdir(parents=True, exist_ok=True)

        patience_counter = 0

        for epoch in range(epochs):
            # ---- training pass ----
            model.train()
            train_loss = 0.0
            train_samples = 0

            for signals, labels in tqdm(train_loader, desc=f"{stage_name} E{epoch+1}", leave=False):
                signals, labels = signals.to(device), labels.to(device)

                if CFG.use_mixup:
                    signals, labels = mixup_data(signals, labels, alpha=CFG.mixup_alpha, device=device)

                optimizer.zero_grad()
                with autocast(enabled=autocast_enabled, device_type=device.type):
                    outputs = model(signals)
                    log_probs = F.log_softmax(outputs, dim=1)
                    loss = loss_fn(log_probs, labels)

                scaler.scale(loss).backward()
                # Gradients must be unscaled before clipping so max_norm applies
                # to their true magnitude.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()

                batch_count = signals.size(0)
                train_loss += loss.item() * batch_count
                train_samples += batch_count

            # The train loader drops its last incomplete batch, so average over
            # the samples actually seen rather than over the dataset length.
            train_loss /= max(train_samples, 1)

            # ---- validation pass ----
            model.eval()
            valid_loss = 0.0
            valid_samples = 0
            with torch.no_grad():
                for signals, labels in valid_loader:
                    signals, labels = signals.to(device), labels.to(device)
                    with autocast(enabled=autocast_enabled, device_type=device.type):
                        outputs = model(signals)
                        log_probs = F.log_softmax(outputs, dim=1)
                        loss = loss_fn(log_probs, labels)
                    batch_count = signals.size(0)
                    valid_loss += loss.item() * batch_count
                    valid_samples += batch_count

            valid_loss /= max(valid_samples, 1)

            epoch_lr = optimizer.param_groups[0]['lr']
            print(f"  Ep {epoch+1}: Train={train_loss:.4f} | Val={valid_loss:.4f} | LR={epoch_lr:.6f}")

            wandb.log({
                "epoch": epoch + 1,
                "train/epoch_loss": train_loss,
                "val/loss": valid_loss,
                "val/kl_div": valid_loss,
                "train/epoch_lr": epoch_lr
            })

            # ---- checkpointing / early stopping ----
            if valid_loss < best_val_loss - CFG.min_delta:
                best_val_loss = valid_loss
                patience_counter = 0
                torch.save(model.state_dict(), best_model_path)
            else:
                patience_counter += 1

            if patience_counter >= CFG.patience:
                print(f"  Early stopping at epoch {epoch+1}")
                break

            scheduler.step()
    except BaseException:
        # Close the run as failed (also on KeyboardInterrupt), then re-raise.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()
    return best_model_path, best_val_loss

In [5]:
def _predict_fold_oof(checkpoint_path, valid_df, device):
    """Load a saved state dict and return ``(probabilities, labels)`` for ``valid_df`` as numpy arrays."""
    model = NodeAttentionModel(
        num_nodes=CFG.num_channels, node_embed_size=256, hidden_size=CFG.hidden_units,
        num_layers=CFG.num_layers, num_classes=CFG.target_size, num_cnn_blocks=CFG.num_cnn_blocks,
        dropout=CFG.dropout, use_inception=True
    )
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    model.eval()

    # Only the validation loader is needed; no training frame is supplied.
    _, valid_loader = get_dataloaders(None, valid_df)

    probs_list = []
    labels_list = []
    with torch.no_grad():
        for signals, labels in valid_loader:
            outputs = model(signals.to(device))
            probs_list.append(F.softmax(outputs, dim=1).cpu().numpy())
            labels_list.append(labels.numpy())

    return np.concatenate(probs_list), np.concatenate(labels_list)


def run_two_stage_pipeline(df, vote_threshold=10):
    """Run the two-stage training schedule over the selected folds and score the OOF predictions.

    For every fold listed in ``CFG.folds_to_train``:
      1. Stage 1 trains on the out-of-fold rows with ``total_votes`` below
         ``vote_threshold`` and validates on the whole fold.
      2. Stage 2 starts from the stage-1 checkpoint, trains on the out-of-fold
         rows at or above the threshold and validates on the fold's rows at
         or above the threshold.
      3. The stage-2 checkpoint predicts the whole validation fold
         (out-of-fold predictions) and is uploaded as a W&B artifact.

    Args:
        df: Frame with at least the ``total_votes`` and ``fold`` columns, plus
            whatever the dataset class reads.
        vote_threshold: Vote count separating stage-1 (below) from stage-2
            (at or above) data. Defaults to 10.

    Returns:
        Tuple ``(overall_score, fold_scores)``: the KL divergence over all
        collected out-of-fold predictions, and the list of best stage-2
        validation losses (one per trained fold). The two are measured on
        different subsets: all fold rows versus high-vote rows only.

    Raises:
        ValueError: If no fold in ``range(CFG.n_splits)`` is selected by
            ``CFG.folds_to_train``, so there is nothing to score.
    """
    print("Total Votes Distribution (Head):")
    print(df['total_votes'].head())

    mask_low_votes = df['total_votes'] < vote_threshold
    mask_high_votes = df['total_votes'] >= vote_threshold

    print(f"Stage 1 Data (Low Votes < {vote_threshold}): {mask_low_votes.sum()} samples")
    print(f"Stage 2 Data (High Votes >= {vote_threshold}): {mask_high_votes.sum()} samples")
    group_name = f"TwoStage_montages_block_{CFG.num_cnn_blocks}_attention_{CFG.use_attention}_AUG(MU+CU)"

    # Same device selection as train_one_stage, so inference also runs on CPU-only hosts.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    all_oof_preds = []
    all_oof_labels = []
    fold_scores = []

    for fold in range(CFG.n_splits):
        print(f"\n{'='*20} Processing FOLD {fold} {'='*20}")
        if fold not in CFG.folds_to_train:
            print(f"Skipping Fold {fold}...")
            continue

        in_fold = df['fold'] == fold
        valid_df = df[in_fold].reset_index(drop=True)
        valid_stage2_df = valid_df[valid_df['total_votes'] >= vote_threshold].reset_index(drop=True)

        # ---- Stage 1: low-vote rows, validated on the whole fold ----
        train_stage1 = df[~in_fold & mask_low_votes].reset_index(drop=True)
        stage1_path, _ = train_one_stage(
            fold=fold,
            stage_name="Stage1",
            train_df=train_stage1,
            valid_df=valid_df,
            group_name=group_name
        )

        # ---- Stage 2: high-vote rows, warm-started from stage 1 ----
        train_stage2 = df[~in_fold & mask_high_votes].reset_index(drop=True)
        stage2_path, best_val_loss = train_one_stage(
            fold=fold,
            stage_name="Stage2",
            train_df=train_stage2,
            valid_df=valid_stage2_df,
            starting_checkpoint=stage1_path,
            group_name=group_name
        )

        fold_scores.append(best_val_loss)

        # ---- Out-of-fold predictions on the whole validation fold ----
        print(f"Generating OOF predictions for Fold {fold}...")
        fold_probs, fold_labels = _predict_fold_oof(stage2_path, valid_df, device)
        all_oof_preds.append(fold_probs)
        all_oof_labels.append(fold_labels)

        if device.type == 'cuda':
            # The helper's model is out of scope by now; release cached blocks.
            torch.cuda.empty_cache()

        # ---- Upload the stage-2 checkpoint as a W&B artifact ----
        run = wandb.init(project="hms-aicomp-gru-conv", job_type="artifact_upload", name=f"artifact_fold{fold}")
        try:
            artifact = wandb.Artifact(f'model-fold{fold}-stage2', type='model')
            artifact.add_file(str(stage2_path))
            run.log_artifact(artifact)
        finally:
            # Close the upload run even if adding or logging the artifact fails.
            wandb.finish()

    if not all_oof_preds:
        raise ValueError(
            "No folds were trained: CFG.folds_to_train selects no fold in range(CFG.n_splits)"
        )

    all_oof_preds = np.concatenate(all_oof_preds)
    all_oof_labels = np.concatenate(all_oof_labels)

    oof_tensor = torch.tensor(all_oof_preds, dtype=torch.float32)
    true_tensor = torch.tensor(all_oof_labels, dtype=torch.float32)

    # Keep probabilities away from zero so the log below stays finite.
    oof_tensor = torch.clamp(oof_tensor, 1e-6, 1.0)

    kl_loss = nn.KLDivLoss(reduction="batchmean")
    overall_score = kl_loss(torch.log(oof_tensor), true_tensor).item()

    return overall_score, fold_scores


In [None]:
# Driver cell: load the metadata, assign folds, run the two-stage pipeline and
# print the summary scores.
print("Loading Data...")
df = pd.read_csv(CFG.data_path + 'processed_data_max_vote_window.csv') 

# Derive a hard label when it is missing: the name of the target column with
# the largest value in each row.
if 'expert_consensus' not in df.columns:
    df['expert_consensus'] = df[TARGETS].idxmax(axis=1)

print('Train shape:', df.shape)

# Fold assignment is delegated to the project helper; the arguments request
# stratification on 'expert_consensus' and grouping by 'patient_id'.
# The pipeline below expects the returned frame to carry a 'fold' column.
fold_creator = KFoldCreator(n_splits=CFG.n_splits, seed=CFG.seed)
df = fold_creator.create_folds(df, stratify_col='expert_consensus', group_col='patient_id')

# Only the folds listed in CFG.folds_to_train are trained and scored.
overall_cv, fold_scores = run_two_stage_pipeline(df)

# overall_cv: KL divergence over the collected out-of-fold predictions.
# fold_scores: best stage-2 validation loss for each trained fold.
print("\n" + "="*50)
print("FINAL RESULTS")
print("="*50)
print(f"Overall OOF KL-Divergence: {overall_cv:.4f}")
print(f"Average Fold Score: {np.mean(fold_scores):.4f}")
print("="*50)

Loading Data...
Train shape: (17089, 13)
Total Votes Distribution (Head):
0    12
1    14
2     1
3     1
4     2
Name: total_votes, dtype: int64
Stage 1 Data (Low Votes < 10): 11150 samples
Stage 2 Data (High Votes >= 10): 5939 samples

Skipping Fold 0...

Skipping Fold 1...

Skipping Fold 2...


--- Starting Stage1 | Fold 3 ---




Stage1 E1:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 1: Train=1.0887 | Val=1.0827 | LR=0.000100


Stage1 E2:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 2: Train=0.9243 | Val=1.0095 | LR=0.000100


Stage1 E3:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 3: Train=0.8355 | Val=0.8018 | LR=0.000100


Stage1 E4:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 4: Train=0.7855 | Val=0.7728 | LR=0.000099


Stage1 E5:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 5: Train=0.7479 | Val=0.7372 | LR=0.000098


Stage1 E6:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 6: Train=0.7151 | Val=0.7210 | LR=0.000098


Stage1 E7:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 7: Train=0.6949 | Val=0.7231 | LR=0.000096


Stage1 E8:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 8: Train=0.6821 | Val=0.7027 | LR=0.000095


Stage1 E9:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 9: Train=0.6657 | Val=0.6854 | LR=0.000094


Stage1 E10:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 10: Train=0.6528 | Val=0.7433 | LR=0.000092


Stage1 E11:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 11: Train=0.6431 | Val=0.7171 | LR=0.000090


Stage1 E12:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 12: Train=0.6306 | Val=0.6790 | LR=0.000089


Stage1 E13:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 13: Train=0.6232 | Val=0.6618 | LR=0.000086


Stage1 E14:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 14: Train=0.6134 | Val=0.6591 | LR=0.000084


Stage1 E15:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 15: Train=0.6178 | Val=0.6433 | LR=0.000082


Stage1 E16:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 16: Train=0.6071 | Val=0.6501 | LR=0.000079


Stage1 E17:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 17: Train=0.5886 | Val=0.6864 | LR=0.000077


Stage1 E18:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 18: Train=0.5873 | Val=0.6625 | LR=0.000074


Stage1 E19:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 19: Train=0.5857 | Val=0.6808 | LR=0.000071


Stage1 E20:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 20: Train=0.5819 | Val=0.6629 | LR=0.000068


Stage1 E21:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 21: Train=0.5680 | Val=0.6430 | LR=0.000065


Stage1 E22:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 22: Train=0.5651 | Val=0.6918 | LR=0.000062


Stage1 E23:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 23: Train=0.5622 | Val=0.6493 | LR=0.000059


Stage1 E24:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 24: Train=0.5611 | Val=0.6455 | LR=0.000056


Stage1 E25:   0%|          | 0/297 [00:00<?, ?it/s]

  Ep 25: Train=0.5471 | Val=0.6485 | LR=0.000053
  Early stopping at epoch 25


0,1
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
train/epoch_loss,█▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
train/epoch_lr,██████▇▇▇▇▇▆▆▆▅▅▅▄▄▃▃▂▂▁▁
val/kl_div,█▇▄▃▂▂▂▂▂▃▂▂▁▁▁▁▂▁▂▁▁▂▁▁▁
val/loss,█▇▄▃▂▂▂▂▂▃▂▂▁▁▁▁▂▁▂▁▁▂▁▁▁

0,1
epoch,25.0
train/epoch_loss,0.54714
train/epoch_lr,5e-05
val/kl_div,0.64851
val/loss,0.64851



--- Starting Stage2 | Fold 3 ---


Loading weights from /home/maiko/Documents/HSLU/AICOMP/HSLU.AICOMP.HMS/models/TwoStage/Stage1_fold3.pth...


Stage2 E1:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 1: Train=0.3801 | Val=0.3508 | LR=0.000032


Stage2 E2:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 2: Train=0.3528 | Val=0.3610 | LR=0.000031


Stage2 E3:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 3: Train=0.3439 | Val=0.3465 | LR=0.000030


Stage2 E4:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 4: Train=0.3317 | Val=0.3596 | LR=0.000029


Stage2 E5:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 5: Train=0.3288 | Val=0.3378 | LR=0.000026


Stage2 E6:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 6: Train=0.3270 | Val=0.3335 | LR=0.000024


Stage2 E7:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 7: Train=0.3263 | Val=0.3244 | LR=0.000021


Stage2 E8:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 8: Train=0.3211 | Val=0.3320 | LR=0.000017


Stage2 E9:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 9: Train=0.3188 | Val=0.3265 | LR=0.000014


Stage2 E10:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 10: Train=0.3105 | Val=0.3245 | LR=0.000011


Stage2 E11:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 11: Train=0.3104 | Val=0.3288 | LR=0.000008


Stage2 E12:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 12: Train=0.3071 | Val=0.3268 | LR=0.000005


Stage2 E13:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 13: Train=0.3070 | Val=0.3260 | LR=0.000003


Stage2 E14:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 14: Train=0.3108 | Val=0.3205 | LR=0.000001


Stage2 E15:   0%|          | 0/155 [00:00<?, ?it/s]

  Ep 15: Train=0.3086 | Val=0.3222 | LR=0.000000


0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train/epoch_loss,█▅▅▃▃▃▃▂▂▁▁▁▁▁▁
train/epoch_lr,███▇▇▆▆▅▄▃▃▂▂▁▁
val/kl_div,▆█▅█▄▃▂▃▂▂▂▂▂▁▁
val/loss,▆█▅█▄▃▂▃▂▂▂▂▂▁▁

0,1
epoch,15.0
train/epoch_loss,0.30859
train/epoch_lr,0.0
val/kl_div,0.32218
val/loss,0.32218


Generating OOF predictions for Fold 3...




--- Starting Stage1 | Fold 4 ---


Stage1 E1:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 1: Train=1.1444 | Val=1.0437 | LR=0.000100


Stage1 E2:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 2: Train=0.9697 | Val=0.8928 | LR=0.000100


Stage1 E3:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 3: Train=0.8719 | Val=0.8184 | LR=0.000100


Stage1 E4:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 4: Train=0.8003 | Val=0.7732 | LR=0.000099


Stage1 E5:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 5: Train=0.7530 | Val=0.7653 | LR=0.000098


Stage1 E6:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 6: Train=0.7276 | Val=0.7607 | LR=0.000098


Stage1 E7:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 7: Train=0.6974 | Val=0.7105 | LR=0.000096


Stage1 E8:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 8: Train=0.6891 | Val=0.7340 | LR=0.000095


Stage1 E9:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 9: Train=0.6701 | Val=0.7453 | LR=0.000094


Stage1 E10:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 10: Train=0.6609 | Val=0.6807 | LR=0.000092


Stage1 E11:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 11: Train=0.6526 | Val=0.7060 | LR=0.000090


Stage1 E12:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 12: Train=0.6431 | Val=0.7309 | LR=0.000089


Stage1 E13:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 13: Train=0.6341 | Val=0.7110 | LR=0.000086


Stage1 E14:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 14: Train=0.6262 | Val=0.7023 | LR=0.000084


Stage1 E15:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 15: Train=0.6148 | Val=0.7091 | LR=0.000082


Stage1 E16:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 16: Train=0.6090 | Val=0.8030 | LR=0.000079


Stage1 E17:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 17: Train=0.5987 | Val=0.6851 | LR=0.000077


Stage1 E18:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 18: Train=0.5959 | Val=0.6721 | LR=0.000074


Stage1 E19:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 19: Train=0.5899 | Val=0.6671 | LR=0.000071


Stage1 E20:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 20: Train=0.5893 | Val=0.7044 | LR=0.000068


Stage1 E21:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 21: Train=0.5844 | Val=0.6794 | LR=0.000065


Stage1 E22:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 22: Train=0.5672 | Val=0.6688 | LR=0.000062


Stage1 E23:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 23: Train=0.5611 | Val=0.6378 | LR=0.000059


Stage1 E24:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 24: Train=0.5681 | Val=0.6938 | LR=0.000056


Stage1 E25:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 25: Train=0.5580 | Val=0.6537 | LR=0.000053


Stage1 E26:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 26: Train=0.5574 | Val=0.6544 | LR=0.000050


Stage1 E27:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 27: Train=0.5497 | Val=0.6562 | LR=0.000047


Stage1 E28:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 28: Train=0.5397 | Val=0.6386 | LR=0.000044


Stage1 E29:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 29: Train=0.5378 | Val=0.6433 | LR=0.000041


Stage1 E30:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 30: Train=0.5352 | Val=0.6923 | LR=0.000038


Stage1 E31:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 31: Train=0.5392 | Val=0.6899 | LR=0.000035


Stage1 E32:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 32: Train=0.5272 | Val=0.7022 | LR=0.000032


Stage1 E33:   0%|          | 0/276 [00:00<?, ?it/s]

  Ep 33: Train=0.5148 | Val=0.6770 | LR=0.000029
  Early stopping at epoch 33


0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇███
train/epoch_loss,█▆▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
train/epoch_lr,████████▇▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁
val/kl_div,█▅▄▃▃▃▂▃▃▂▂▃▂▂▂▄▂▂▂▂▂▂▁▂▁▁▁▁▁▂▂▂▂
val/loss,█▅▄▃▃▃▂▃▃▂▂▃▂▂▂▄▂▂▂▂▂▂▁▂▁▁▁▁▁▂▂▂▂

0,1
epoch,33.0
train/epoch_loss,0.51483
train/epoch_lr,3e-05
val/kl_div,0.67703
val/loss,0.67703



--- Starting Stage2 | Fold 4 ---


Loading weights from /home/maiko/Documents/HSLU/AICOMP/HSLU.AICOMP.HMS/models/TwoStage/Stage1_fold4.pth...


Stage2 E1:   0%|          | 0/152 [00:00<?, ?it/s]

  Ep 1: Train=0.3685 | Val=0.3632 | LR=0.000032


Stage2 E2:   0%|          | 0/152 [00:00<?, ?it/s]

  Ep 2: Train=0.3351 | Val=0.3500 | LR=0.000031


Stage2 E3:   0%|          | 0/152 [00:00<?, ?it/s]

  Ep 3: Train=0.3334 | Val=0.3566 | LR=0.000030


Stage2 E4:   0%|          | 0/152 [00:00<?, ?it/s]

  Ep 4: Train=0.3251 | Val=0.3656 | LR=0.000029


Stage2 E5:   0%|          | 0/152 [00:00<?, ?it/s]

  Ep 5: Train=0.3190 | Val=0.3465 | LR=0.000026


Stage2 E6:   0%|          | 0/152 [00:00<?, ?it/s]

  Ep 6: Train=0.3129 | Val=0.3574 | LR=0.000024


Stage2 E7:   0%|          | 0/152 [00:00<?, ?it/s]

  Ep 7: Train=0.3079 | Val=0.3430 | LR=0.000021


Stage2 E8:   0%|          | 0/152 [00:00<?, ?it/s]