In [None]:
import os
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedGroupKFold
import timm 
from tqdm.auto import tqdm 
import wandb
from torch.optim.lr_scheduler import CosineAnnealingLR
import sys
from pathlib import Path

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..", "..", "..")))
from src.utils.k_folds_creator import KFoldCreator
from src.utils.utils import get_models_save_path
from src.models.base_cnn import BaseCNN
from src.utils.constants import Constants 
from src.utils.eeg_spectrogram_creator import EEGSpectrogramGenerator
from src.datasets.multi_spectrogram import MultiSpectrogramDataset



2025-10-24 11:57:18,801 :: root :: INFO :: Initialising Utils
2025-10-24 11:57:18,806 :: root :: INFO :: Initialising Models


In [2]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmaikotrede[0m ([33mhms-hslu-aicomp-hs25[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:

def display_spectrogram(spectrogram_tensor, labels, targets):
    """
    Displays a spectrogram tensor with its corresponding labels.
    """
    img_np = spectrogram_tensor[0].numpy()
    
    plt.figure(figsize=(12, 5))
    plt.imshow(img_np, aspect='auto', origin='lower', cmap='viridis')
    plt.colorbar(label='Normalized Log Power')
    plt.xlabel('Time Steps')
    plt.ylabel('Frequency Bins')
    
    label_str = ", ".join([f"{name}: {val:.2f}" for name, val in zip(targets, labels)])
    plt.title(f"Spectrogram Sample\nLabels: {label_str}")
    
    plt.tight_layout()
    plt.show()

In [4]:
class CFG:
    seed = 42
    n_splits = 5
    data_path = '../../../data/'
    
    eeg_spec_path = '../../../data/custom_eegs/cwt'
    
    model_name = 'tf_efficientnet_b0_ns'
    in_channels = 8  
    target_size = 6 
    
    batch_size = 32
    num_workers = 6
    epochs = 5
    lr = 1e-3
    
    img_size = (128, 256)


def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        #torch.backends.cudnn.deterministic = True

set_seed(CFG.seed)



In [6]:
TARGETS = Constants.TARGETS

In [7]:
def get_dataloaders(df, fold_id):
    train_df = df[df['fold'] != fold_id].reset_index(drop=True)
    valid_df = df[df['fold'] == fold_id].reset_index(drop=True)

    train_dataset = MultiSpectrogramDataset(
        train_df, TARGETS, CFG.data_path, CFG.img_size, CFG.eeg_spec_path, mode='train'
    )
    valid_dataset = MultiSpectrogramDataset(
        valid_df, TARGETS, CFG.data_path, CFG.img_size, CFG.eeg_spec_path, mode='train'
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=CFG.batch_size,
        shuffle=True,
        num_workers=CFG.num_workers,
        pin_memory=True,
        drop_last=True,
        persistent_workers=True
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=CFG.num_workers,
        pin_memory=True,
        drop_last=False,
        persistent_workers=True
    )
    
    return train_loader, valid_loader

In [None]:
from torch.amp import GradScaler, autocast
def run_training(df, DATA_PREPARATION_VOTE_METHOD):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    scaler = GradScaler(enabled=(device.type == 'cuda'))
    
    all_oof_preds = []
    all_oof_labels = []
    for fold in range(CFG.n_splits):
        print(f"\n========== FOLD {fold} ==========")

        config = {
            # Model
            "architecture": CFG.model_name, "pretrained": True,
            # Data
            "fold": fold, "features": "spectrograms", "window_selection": "sum_and_normalize",
            # Training
            "optimizer": "AdamW", "learning_rate": CFG.lr, "batch_size": CFG.batch_size,
            "epochs": CFG.epochs, "seed": CFG.seed, "Scheduler": "CosineAnnealingLR" 
        }

        wandb.init(
            project="hms-aicomp",
            name=f"multi-spect-effnetb0-spec-fold{fold}", 
            tags=['baseline', f'fold{fold}'],
            config=config
        )

        model = BaseCNN(CFG.model_name, pretrained=True, num_classes=CFG.target_size, in_channels=CFG.in_channels)
        model.to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr)
        scheduler = CosineAnnealingLR(optimizer, T_max=CFG.epochs)
        loss_fn = nn.KLDivLoss(reduction='batchmean')
        train_loader, valid_loader = get_dataloaders(df, fold)

        best_val_loss = float('inf')
        best_model_path = None

        for epoch in range(CFG.epochs):
            print(f"   --- Epoch {epoch+1}/{CFG.epochs} ---")

            model.train()
            train_loss = 0
            for images, labels in tqdm(train_loader, desc="Training"):
                images, labels = images.to(device), labels.to(device)

                optimizer.zero_grad()

                with autocast(device_type=device.type, dtype=torch.float16):
                    outputs = model(images)
                    log_probs = F.log_softmax(outputs, dim=1)
                    loss = loss_fn(log_probs, labels)

                scaler.scale(loss).backward()

                scaler.step(optimizer)
                scaler.update()

                train_loss += loss.item() * images.size(0)
                wandb.log({"train/loss": loss.item()})

            train_loss /= len(train_loader.dataset)

            model.eval()
            valid_loss = 0
            with torch.no_grad():
                for images, labels in tqdm(valid_loader, desc="Validation"):
                    images, labels = images.to(device), labels.to(device)
                    
                    with autocast(device_type=device.type, dtype=torch.float16):
                        outputs = model(images)
                        log_probs = F.log_softmax(outputs, dim=1)
                        loss = loss_fn(log_probs, labels)
                        
                    valid_loss += loss.item() * images.size(0)

            valid_loss /= len(valid_loader.dataset)
            
            epoch_lr = optimizer.param_groups[0]['lr']
            print(f"   Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Valid Loss = {valid_loss:.4f}, LR = {epoch_lr:.6f}")

            wandb.log({
                "epoch": epoch + 1,
                "train/epoch_loss": train_loss,
                "val/loss": valid_loss,
                "val/kl_div": valid_loss,
                "train/epoch_lr": epoch_lr
            })

            if valid_loss < best_val_loss:
                best_val_loss = valid_loss
                best_model_path = get_models_save_path() / "MultiSpectCNN" / DATA_PREPARATION_VOTE_METHOD / f'best_model_fold{fold}.pth'
                best_model_path.parent.mkdir(parents=True, exist_ok=True)
                torch.save(model.state_dict(), best_model_path)
                print(f"  New best model saved with validation loss: {best_val_loss:.4f}")
            scheduler.step()
            
        print(f"   --- Generating OOF predictions for fold {fold} ---")
        if best_model_path and best_model_path.exists():
            model.load_state_dict(torch.load(best_model_path))
            model.eval()

            fold_oof_preds = []
            fold_oof_labels = []

            with torch.no_grad():
                for images, labels in tqdm(valid_loader, desc=f"OOF Prediction Fold {fold}"):
                    images = images.to(device)
                    outputs = model(images)
                    probs = F.softmax(outputs, dim=1).cpu()

                    fold_oof_preds.append(probs)
                    fold_oof_labels.append(labels.cpu())

            all_oof_preds.append(torch.cat(fold_oof_preds).numpy())
            all_oof_labels.append(torch.cat(fold_oof_labels).numpy())
            print(f"   Finished OOF predictions for fold {fold}")
        else:
            print(f"   WARNING: No model file found at {best_model_path}. Skipping OOF for this fold.")


        wandb.summary['best_val_kl_div'] = best_val_loss

        if best_model_path:
            artifact = wandb.Artifact(f'model-fold{fold}', type='model')
            artifact.add_file(best_model_path)
            wandb.log_artifact(artifact)
            print(f"\nLogged artifact for fold {fold} with best validation loss: {best_val_loss:.4f}")
        else:
            print("\nNo best model was saved during training for this fold.")

        wandb.finish()
        
    if all_oof_preds and all_oof_labels:
        print("\nCalculating final OOF score...")
        final_oof_preds = np.concatenate(all_oof_preds)
        final_oof_labels = np.concatenate(all_oof_labels)

        oof_preds_tensor = torch.tensor(final_oof_preds, dtype=torch.float32)
        oof_labels_tensor = torch.tensor(final_oof_labels, dtype=torch.float32)

        log_oof_preds_tensor = torch.log(oof_preds_tensor)

        kl_loss_fn = nn.KLDivLoss(reduction='batchmean')
        overall_oof_score = kl_loss_fn(log_oof_preds_tensor, oof_labels_tensor).item()

        print(f"\nOverall OOF KL Score: {overall_oof_score:.4f}")
    else:
        print("\nCould not calculate OOF score because no predictions were generated.")
        
    return overall_oof_score

# Sum Votes

In [9]:
DATA_PREPARATION_VOTE_METHOD = "sum_and_normalize"


In [10]:
print("Preparing data and creating folds...")
df = pd.read_csv(CFG.data_path + 'processed_data_sum_votes_window.csv')
print('Train shape:', df.shape)
print('Targets', list(TARGETS))

fold_creator = KFoldCreator(n_splits=CFG.n_splits, seed=CFG.seed)
df = fold_creator.create_folds(df, stratify_col='expert_consensus', group_col='patient_id')

print("Folds created. Value counts per fold:")
print(df['fold'].value_counts())


Preparing data and creating folds...
Train shape: (17089, 12)
Targets ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
Folds created. Value counts per fold:
fold
1    3938
2    3667
4    3417
0    3334
3    2733
Name: count, dtype: int64


In [11]:
if __name__ == '__main__':

    print("\nDisplaying a few sample spectrograms before training...")
    temp_dataset = MultiSpectrogramDataset(df, TARGETS, CFG.data_path, CFG.img_size, CFG.eeg_spec_path)
    overall_oof_score = run_training(df, DATA_PREPARATION_VOTE_METHOD)


Displaying a few sample spectrograms before training...
Using device: cuda



  model = create_fn(
2025-10-24 11:57:40,878 :: timm.models._builder :: INFO :: Loading pretrained weights from Hugging Face hub (timm/tf_efficientnet_b0.ns_jft_in1k)
2025-10-24 11:57:41,021 :: timm.models._hub :: INFO :: [timm/tf_efficientnet_b0.ns_jft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-10-24 11:57:41,032 :: timm.models._builder :: INFO :: Missing keys (classifier.weight, classifier.bias) discovered while loading pretrained weights. This is expected if model is being adapted.


   --- Epoch 1/5 ---


Training:   0%|          | 0/429 [00:00<?, ?it/s]

Validation:   0%|          | 0/105 [00:00<?, ?it/s]

   Epoch 1: Train Loss = 0.7287, Valid Loss = 0.7385, LR = 0.001000
  New best model saved with validation loss: 0.7385
   --- Epoch 2/5 ---


Training:   0%|          | 0/429 [00:00<?, ?it/s]

Validation:   0%|          | 0/105 [00:00<?, ?it/s]

   Epoch 2: Train Loss = 0.4881, Valid Loss = 0.7254, LR = 0.000905
  New best model saved with validation loss: 0.7254
   --- Epoch 3/5 ---


Training:   0%|          | 0/429 [00:00<?, ?it/s]

Validation:   0%|          | 0/105 [00:00<?, ?it/s]

   Epoch 3: Train Loss = 0.3485, Valid Loss = 0.6735, LR = 0.000655
  New best model saved with validation loss: 0.6735
   --- Epoch 4/5 ---


Training:   0%|          | 0/429 [00:00<?, ?it/s]

Validation:   0%|          | 0/105 [00:00<?, ?it/s]

   Epoch 4: Train Loss = 0.2046, Valid Loss = 0.6821, LR = 0.000345
   --- Epoch 5/5 ---


Training:   0%|          | 0/429 [00:00<?, ?it/s]

Validation:   0%|          | 0/105 [00:00<?, ?it/s]

   Epoch 5: Train Loss = 0.1215, Valid Loss = 0.6814, LR = 0.000095
   --- Generating OOF predictions for fold 0 ---


OOF Prediction Fold 0:   0%|          | 0/105 [00:00<?, ?it/s]

   Finished OOF predictions for fold 0

Logged artifact for fold 0 with best validation loss: 0.6735


0,1
epoch,▁▃▅▆█
train/epoch_loss,█▅▄▂▁
train/epoch_lr,█▇▅▃▁
train/loss,█▄▄▄▃▄▃▄▂▂▃▃▂▂▃▃▄▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁
val/kl_div,█▇▁▂▂
val/loss,█▇▁▂▂

0,1
best_val_kl_div,0.67348
epoch,5.0
train/epoch_loss,0.12155
train/epoch_lr,0.0001
train/loss,0.0831
val/kl_div,0.68139
val/loss,0.68139





  model = create_fn(
2025-10-24 12:06:43,831 :: timm.models._builder :: INFO :: Loading pretrained weights from Hugging Face hub (timm/tf_efficientnet_b0.ns_jft_in1k)
2025-10-24 12:06:44,025 :: timm.models._hub :: INFO :: [timm/tf_efficientnet_b0.ns_jft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-10-24 12:06:44,095 :: timm.models._builder :: INFO :: Missing keys (classifier.weight, classifier.bias) discovered while loading pretrained weights. This is expected if model is being adapted.


   --- Epoch 1/5 ---


Training:   0%|          | 0/410 [00:00<?, ?it/s]

Validation:   0%|          | 0/124 [00:00<?, ?it/s]

   Epoch 1: Train Loss = 0.7473, Valid Loss = 0.6823, LR = 0.001000
  New best model saved with validation loss: 0.6823
   --- Epoch 2/5 ---


Training:   0%|          | 0/410 [00:00<?, ?it/s]

Validation:   0%|          | 0/124 [00:00<?, ?it/s]

   Epoch 2: Train Loss = 0.4924, Valid Loss = 0.8389, LR = 0.000905
   --- Epoch 3/5 ---


Training:   0%|          | 0/410 [00:00<?, ?it/s]

Validation:   0%|          | 0/124 [00:00<?, ?it/s]

   Epoch 3: Train Loss = 0.3571, Valid Loss = 0.7470, LR = 0.000655
   --- Epoch 4/5 ---


Training:   0%|          | 0/410 [00:00<?, ?it/s]

Validation:   0%|          | 0/124 [00:00<?, ?it/s]

   Epoch 4: Train Loss = 0.2146, Valid Loss = 0.6869, LR = 0.000345
   --- Epoch 5/5 ---


Training:   0%|          | 0/410 [00:00<?, ?it/s]

Validation:   0%|          | 0/124 [00:00<?, ?it/s]

   Epoch 5: Train Loss = 0.1308, Valid Loss = 0.7022, LR = 0.000095
   --- Generating OOF predictions for fold 1 ---


OOF Prediction Fold 1:   0%|          | 0/124 [00:00<?, ?it/s]

   Finished OOF predictions for fold 1

Logged artifact for fold 1 with best validation loss: 0.6823


0,1
epoch,▁▃▅▆█
train/epoch_loss,█▅▄▂▁
train/epoch_lr,█▇▅▃▁
train/loss,█▄▄▄▄▃▄▄▄▃▃▃▃▄▄▃▃▂▃▂▂▂▂▃▃▂▂▁▂▂▁▂▂▁▁▁▁▁▁▁
val/kl_div,▁█▄▁▂
val/loss,▁█▄▁▂

0,1
best_val_kl_div,0.68232
epoch,5.0
train/epoch_loss,0.13075
train/epoch_lr,0.0001
train/loss,0.14524
val/kl_div,0.7022
val/loss,0.7022





  model = create_fn(
2025-10-24 12:16:11,286 :: timm.models._builder :: INFO :: Loading pretrained weights from Hugging Face hub (timm/tf_efficientnet_b0.ns_jft_in1k)
2025-10-24 12:16:11,460 :: timm.models._hub :: INFO :: [timm/tf_efficientnet_b0.ns_jft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-10-24 12:16:11,509 :: timm.models._builder :: INFO :: Missing keys (classifier.weight, classifier.bias) discovered while loading pretrained weights. This is expected if model is being adapted.


   --- Epoch 1/5 ---


Training:   0%|          | 0/419 [00:00<?, ?it/s]

Validation:   0%|          | 0/115 [00:00<?, ?it/s]

   Epoch 1: Train Loss = 0.7423, Valid Loss = 0.7431, LR = 0.001000
  New best model saved with validation loss: 0.7431
   --- Epoch 2/5 ---


Training:   0%|          | 0/419 [00:00<?, ?it/s]

Validation:   0%|          | 0/115 [00:00<?, ?it/s]

   Epoch 2: Train Loss = 0.5033, Valid Loss = 0.6996, LR = 0.000905
  New best model saved with validation loss: 0.6996
   --- Epoch 3/5 ---


Training:   0%|          | 0/419 [00:00<?, ?it/s]

Validation:   0%|          | 0/115 [00:00<?, ?it/s]

   Epoch 3: Train Loss = 0.3650, Valid Loss = 0.7193, LR = 0.000655
   --- Epoch 4/5 ---


Training:   0%|          | 0/419 [00:00<?, ?it/s]

Validation:   0%|          | 0/115 [00:00<?, ?it/s]

   Epoch 4: Train Loss = 0.2177, Valid Loss = 0.6747, LR = 0.000345
  New best model saved with validation loss: 0.6747
   --- Epoch 5/5 ---


Training:   0%|          | 0/419 [00:00<?, ?it/s]

Validation:   0%|          | 0/115 [00:00<?, ?it/s]

   Epoch 5: Train Loss = 0.1293, Valid Loss = 0.7166, LR = 0.000095
   --- Generating OOF predictions for fold 2 ---


OOF Prediction Fold 2:   0%|          | 0/115 [00:00<?, ?it/s]

   Finished OOF predictions for fold 2

Logged artifact for fold 2 with best validation loss: 0.6747


0,1
epoch,▁▃▅▆█
train/epoch_loss,█▅▄▂▁
train/epoch_lr,█▇▅▃▁
train/loss,█▆▄▃▄▃▃▂▃▃▄▂▃▂▃▂▂▂▂▂▃▂▃▃▂▂▁▂▂▂▂▁▂▂▁▁▂▁▁▁
val/kl_div,█▄▆▁▅
val/loss,█▄▆▁▅

0,1
best_val_kl_div,0.67473
epoch,5.0
train/epoch_loss,0.12925
train/epoch_lr,0.0001
train/loss,0.15601
val/kl_div,0.71664
val/loss,0.71664





  model = create_fn(
2025-10-24 12:25:58,055 :: timm.models._builder :: INFO :: Loading pretrained weights from Hugging Face hub (timm/tf_efficientnet_b0.ns_jft_in1k)
2025-10-24 12:25:58,201 :: timm.models._hub :: INFO :: [timm/tf_efficientnet_b0.ns_jft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-10-24 12:25:58,235 :: timm.models._builder :: INFO :: Missing keys (classifier.weight, classifier.bias) discovered while loading pretrained weights. This is expected if model is being adapted.


   --- Epoch 1/5 ---


Training:   0%|          | 0/448 [00:00<?, ?it/s]

Validation:   0%|          | 0/86 [00:00<?, ?it/s]

   Epoch 1: Train Loss = 0.7349, Valid Loss = 0.6939, LR = 0.001000
  New best model saved with validation loss: 0.6939
   --- Epoch 2/5 ---


Training:   0%|          | 0/448 [00:00<?, ?it/s]

Validation:   0%|          | 0/86 [00:00<?, ?it/s]

   Epoch 2: Train Loss = 0.4817, Valid Loss = 0.6908, LR = 0.000905
  New best model saved with validation loss: 0.6908
   --- Epoch 3/5 ---


Training:   0%|          | 0/448 [00:00<?, ?it/s]

Validation:   0%|          | 0/86 [00:00<?, ?it/s]

   Epoch 3: Train Loss = 0.3458, Valid Loss = 0.6709, LR = 0.000655
  New best model saved with validation loss: 0.6709
   --- Epoch 4/5 ---


Training:   0%|          | 0/448 [00:00<?, ?it/s]

Validation:   0%|          | 0/86 [00:00<?, ?it/s]

   Epoch 4: Train Loss = 0.2116, Valid Loss = 0.6835, LR = 0.000345
   --- Epoch 5/5 ---


Training:   0%|          | 0/448 [00:00<?, ?it/s]

Validation:   0%|          | 0/86 [00:00<?, ?it/s]

   Epoch 5: Train Loss = 0.1255, Valid Loss = 0.6347, LR = 0.000095
  New best model saved with validation loss: 0.6347
   --- Generating OOF predictions for fold 3 ---


OOF Prediction Fold 3:   0%|          | 0/86 [00:00<?, ?it/s]

   Finished OOF predictions for fold 3

Logged artifact for fold 3 with best validation loss: 0.6347


0,1
epoch,▁▃▅▆█
train/epoch_loss,█▅▄▂▁
train/epoch_lr,█▇▅▃▁
train/loss,▇▇▇▇█▅█▅▄▄▂▅▅▅▄▄▅▅▅▄▅▄▃▃▄▃▂▃▃▃▂▃▂▁▁▁▁▂▁▁
val/kl_div,██▅▇▁
val/loss,██▅▇▁

0,1
best_val_kl_div,0.63467
epoch,5.0
train/epoch_loss,0.12552
train/epoch_lr,0.0001
train/loss,0.10869
val/kl_div,0.63467
val/loss,0.63467





  model = create_fn(
2025-10-24 12:35:39,067 :: timm.models._builder :: INFO :: Loading pretrained weights from Hugging Face hub (timm/tf_efficientnet_b0.ns_jft_in1k)
2025-10-24 12:35:39,298 :: timm.models._hub :: INFO :: [timm/tf_efficientnet_b0.ns_jft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
2025-10-24 12:35:39,371 :: timm.models._builder :: INFO :: Missing keys (classifier.weight, classifier.bias) discovered while loading pretrained weights. This is expected if model is being adapted.


   --- Epoch 1/5 ---


Training:   0%|          | 0/427 [00:00<?, ?it/s]

Validation:   0%|          | 0/107 [00:00<?, ?it/s]

   Epoch 1: Train Loss = 0.7238, Valid Loss = 0.7401, LR = 0.001000
  New best model saved with validation loss: 0.7401
   --- Epoch 2/5 ---


Training:   0%|          | 0/427 [00:00<?, ?it/s]

Validation:   0%|          | 0/107 [00:00<?, ?it/s]

   Epoch 2: Train Loss = 0.4910, Valid Loss = 0.7232, LR = 0.000905
  New best model saved with validation loss: 0.7232
   --- Epoch 3/5 ---


Training:   0%|          | 0/427 [00:00<?, ?it/s]

Validation:   0%|          | 0/107 [00:00<?, ?it/s]

   Epoch 3: Train Loss = 0.3615, Valid Loss = 0.7341, LR = 0.000655
   --- Epoch 4/5 ---


Training:   0%|          | 0/427 [00:00<?, ?it/s]

Validation:   0%|          | 0/107 [00:00<?, ?it/s]

   Epoch 4: Train Loss = 0.2142, Valid Loss = 0.7659, LR = 0.000345
   --- Epoch 5/5 ---


Training:   0%|          | 0/427 [00:00<?, ?it/s]

Validation:   0%|          | 0/107 [00:00<?, ?it/s]

   Epoch 5: Train Loss = 0.1315, Valid Loss = 0.7068, LR = 0.000095
  New best model saved with validation loss: 0.7068
   --- Generating OOF predictions for fold 4 ---


OOF Prediction Fold 4:   0%|          | 0/107 [00:00<?, ?it/s]

   Finished OOF predictions for fold 4

Logged artifact for fold 4 with best validation loss: 0.7068


0,1
epoch,▁▃▅▆█
train/epoch_loss,█▅▄▂▁
train/epoch_lr,█▇▅▃▁
train/loss,█▄▃▃▃▃▃▂▃▃▃▂▂▂▂▂▂▃▂▂▁▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
val/kl_div,▅▃▄█▁
val/loss,▅▃▄█▁

0,1
best_val_kl_div,0.70684
epoch,5.0
train/epoch_loss,0.13145
train/epoch_lr,0.0001
train/loss,0.08772
val/kl_div,0.70684
val/loss,0.70684



Calculating final OOF score...

Overall OOF KL Score: 0.6762


In [12]:
print(f"Overall OOF KL Score from training: {overall_oof_score:.4f}")

Overall OOF KL Score from training: 0.6762
