In [1]:
DATA_PREPARATION_VOTE_METHOD = "max_vote_window" # "max_vote_window" or "sum_and_normalize". Decides how to aggregate the predictions of the overlapping windows

In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm 
import wandb
from torch.optim.lr_scheduler import CosineAnnealingLR
import sys

if bool(os.environ.get("KAGGLE_URL_BASE", "")):
  import sys
  # running on kaggle
  sys.path.insert(0, "/kaggle/input/hsm-source-files")
else:
  # running locally
  sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..", "..", "..")))

from src.datasets.cbramod_dataset import CBraModDataset
from src.utils.k_folds_creator import KFoldCreator
from src.utils.utils import get_models_save_path, set_seeds, get_raw_data_dir, get_processed_data_dir
from src.models.cbramod_model import CBraModModel
from src.utils.constants import Constants 
from src.datasets.eeg_processor import EEGDataProcessor
from huggingface_hub import hf_hub_download

set_seeds(Constants.SEED)

2025-11-09 16:14:00,669 :: root :: INFO :: Initialising Utils
2025-11-09 16:14:02,419 :: root :: INFO :: Initialising Datasets
2025-11-09 16:14:03,023 :: root :: INFO :: Initialising Models


Skipping module tcn due to missing dependency: No module named 'pytorch_tcn'


In [3]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mdavidhodel[0m ([33mhms-hslu-aicomp-hs25[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
DATA_PATH = get_raw_data_dir()

processor = EEGDataProcessor(raw_data_path=DATA_PATH, processed_data_path=get_processed_data_dir())
train_df = processor.process_data(vote_method=DATA_PREPARATION_VOTE_METHOD, skip_parquet=True)

kl_score = nn.KLDivLoss(reduction="batchmean")

Processor initialized.
Raw data path: '/home/david/git/aicomp/data'
Processed data path: '/home/david/git/aicomp/data/processed'
Starting EEG Data Processing Pipeline
Skipping Parquet file creation as requested.
Using 'max_vote_window' vote aggregation strategy.

Processed train data saved to '/home/david/git/aicomp/data/processed/train_processed.csv'.
Shape of the final dataframe: (17089, 12)

Pipeline finished successfully!


In [5]:
pretrained_weights_path = hf_hub_download("weighting666/CBraMod", "pretrained_weights.pth", repo_type="model")
print(f"Pretrained weights downloaded to: {pretrained_weights_path}")

Pretrained weights downloaded to: /home/david/.cache/huggingface/hub/models--weighting666--CBraMod/snapshots/500543c7e30bda1b22bfd51a49301b238dee21fd/pretrained_weights.pth


In [6]:
fold_creator = KFoldCreator(n_splits=5, seed=Constants.SEED)
train_df = fold_creator.create_folds(train_df, stratify_col='expert_consensus', group_col='patient_id')

In [7]:
class Config:
  window_size_seconds = 5
  eeg_normalization = "naive"
  batch_size = 64
  num_dataset_workers = 8
  dropout_prob = 0.1
  learning_rate = 1*10**(-4)
  weight_decay = 5e-2
  num_epochs = 15
  clip_grad_norm = 1.0
  apply_preprocessing = False

class WandbConfig:
  entity = "hms-hslu-aicomp-hs25"
  project_name = "hms-aicomp-cbramod"
  run_name = f"cbramod-finetune_window_{Config.window_size_seconds}s_bs{Config.batch_size}_lr{Config.learning_rate}_wd{Config.weight_decay}_dropout{Config.dropout_prob}_norm-{Config.eeg_normalization}"

In [8]:
def get_dataloaders(df, fold_id):
    fold_train_df = df[df['fold'] != fold_id].reset_index(drop=True)
    fold_valid_df = df[df['fold'] == fold_id].reset_index(drop=True)

    train_dataset = CBraModDataset(fold_train_df, DATA_PATH, Config.window_size_seconds, eeg_frequency=200, mode='train', normalization=Config.eeg_normalization, apply_preprocessing=Config.apply_preprocessing)
    valid_dataset = CBraModDataset(fold_valid_df, DATA_PATH, Config.window_size_seconds, eeg_frequency=200, mode='train', normalization=Config.eeg_normalization, apply_preprocessing=Config.apply_preprocessing)

    train_loader = DataLoader(
        train_dataset,
        batch_size=Config.batch_size,
        shuffle=True,
        num_workers=Config.num_dataset_workers,
        pin_memory=True,
        drop_last=True,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=Config.batch_size,
        shuffle=False,
        num_workers=Config.num_dataset_workers,
        pin_memory=True,
        drop_last=False,
    )

    return train_loader, valid_loader

In [9]:
def create_model(device):
  model = CBraModModel(
    pretrained_weights_path=pretrained_weights_path,
    classifier_type="all_patch_reps",
    num_of_classes=len(Constants.TARGETS),
    dropout_prob=Config.dropout_prob,
    num_eeg_channels=len(Constants.EEG_FEATURES),
    seq_len_seconds=Config.window_size_seconds,
    device=device,
  ).to(device)

  backbone_params = []
  other_params = []
  for name, param in model.named_parameters():
      if "backbone" in name:
          backbone_params.append(param)
          param.requires_grad = True
      else:
          other_params.append(param)

  return model, backbone_params, other_params

def create_optimizer(backbone_params, other_params, lr, weight_decay):
  optimizer = torch.optim.AdamW([
      {'params': backbone_params, 'lr': lr},
      {'params': other_params, 'lr': lr * 5}
  ], weight_decay=weight_decay)
  return optimizer

# def create_lr_scheduler(optimizer, data_length, num_epochs):
#   total_steps = num_epochs * data_length
#   scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-6)
#   return scheduler

def create_lr_scheduler(optimizer, data_length, num_epochs, lr):
   total_steps = num_epochs * data_length
   warmup_steps = total_steps // 10
   return torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            total_steps=total_steps,
            pct_start=warmup_steps / total_steps,
            div_factor=100,
            final_div_factor=100,
            anneal_strategy="linear"
    )

In [10]:
loss_fn = nn.KLDivLoss(reduction='batchmean')

def run_training():
  device = torch.device("cuda")

  all_oof_preds = []
  all_oof_labels = []
  for fold_id in range(fold_creator.n_splits):
    print(f"\n========== FOLD {fold_id} ==========")
    torch.cuda.empty_cache()

    run = wandb.init(
        entity=WandbConfig.entity,
        project=WandbConfig.project_name,
        name=f"{WandbConfig.run_name}_fold{fold_id}", 
        tags=[f'fold{fold_id}'],
        config= {
            "window_size_seconds": Config.window_size_seconds,
            "eeg_normalization": Config.eeg_normalization,
            "batch_size": Config.batch_size,
            "dropout_prob": Config.dropout_prob,
            "learning_rate": Config.learning_rate,
            "weight_decay": Config.weight_decay,
            "num_epochs": Config.num_epochs,
            "clip_grad_norm": Config.clip_grad_norm,
            "data_preparation_vote_method": DATA_PREPARATION_VOTE_METHOD,
        }
    )


    model, backbone_params, other_params = create_model(device)

    optimizer = create_optimizer(backbone_params, other_params, Config.learning_rate, Config.weight_decay)
    train_loader, valid_loader = get_dataloaders(train_df, fold_id)
    scheduler = create_lr_scheduler(optimizer, len(train_loader), Config.num_epochs, Config.learning_rate)

    best_val_loss = float('inf')
    best_model_path = None

    for epoch in range(Config.num_epochs):
      model.train()
      train_loss = 0

      for eeg_windows, labels in tqdm(train_loader, desc=f"Fold {fold_id} Epoch {epoch} Training"):
          eeg_windows, labels = eeg_windows.to(device), labels.to(device)

          optimizer.zero_grad()
          outputs = model(eeg_windows)
          log_probs = F.log_softmax(outputs, dim=1)
          loss = loss_fn(log_probs, labels)
          loss.backward()
          if Config.clip_grad_norm > 0:
              torch.nn.utils.clip_grad_norm_(model.parameters(), Config.clip_grad_norm)
          optimizer.step()
          scheduler.step()

          train_loss += loss.item() * eeg_windows.size(0)

          run.log({
             "train/loss": loss.item(),
             "train/lr": scheduler.get_last_lr()[0],
          })

      train_loss /= len(train_loader.dataset)

      model.eval()
      valid_loss = 0
      with torch.no_grad():
        for eeg_windows, labels in tqdm(valid_loader, desc=f"Fold {fold_id} Epoch {epoch} Validation"):
            eeg_windows, labels = eeg_windows.to(device), labels.to(device)

            outputs = model(eeg_windows)
            log_probs = F.log_softmax(outputs, dim=1)
            loss = loss_fn(log_probs, labels)

            valid_loss += loss.item() * eeg_windows.size(0)

      valid_loss /= len(valid_loader.dataset)
      print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Valid Loss = {valid_loss:.4f}")

      run.log({
          "epoch": epoch + 1,
          "train/epoch_loss": train_loss,
          "val/loss": valid_loss,
          "val/kl_div": valid_loss,
      })

      if valid_loss < best_val_loss:
          best_val_loss = valid_loss
          best_model_path = get_models_save_path() / "cbramod" / DATA_PREPARATION_VOTE_METHOD / f"best_model_fold_{fold_id}.pth"
          best_model_path.parent.mkdir(parents=True, exist_ok=True)
          torch.save(model.state_dict(), best_model_path)

      if best_model_path is not None:
          model.load_state_dict(torch.load(best_model_path))
          model.eval()

          fold_oof_preds = []
          fold_oof_labels = []

          with torch.no_grad():
            for eeg_windows, labels in tqdm(valid_loader, desc=f"Fold {fold_id} OOF Predictions"):
                eeg_windows = eeg_windows.to(device)

                outputs = model(eeg_windows)
                probs = F.softmax(outputs, dim=1).cpu().numpy()

                fold_oof_preds.append(probs)
                fold_oof_labels.append(labels.numpy())

          all_oof_preds.append(np.concatenate(fold_oof_preds, axis=0))
          all_oof_labels.append(np.concatenate(fold_oof_labels, axis=0))
      else:
          raise RuntimeError("Best model path is None, cannot generate OOF predictions.")
      
    if best_model_path is not None:
        artifact = wandb.Artifact(f"{run.name}-{run.id}", type='model')
        artifact.add_file(str(best_model_path))
        run.log_artifact(artifact)
    else:
       raise RuntimeError("Best model path is None, cannot log model artifact.")

    run.summary['best_val_kl_div'] = best_val_loss
    run.finish()
    # break # used to debugging for only one fold

  if all_oof_preds and all_oof_labels:
      print("\nCalculating final OOF score...")
      final_oof_preds = np.concatenate(all_oof_preds)
      final_oof_labels = np.concatenate(all_oof_labels)

      oof_preds_tensor = torch.tensor(final_oof_preds, dtype=torch.float32)
      oof_labels_tensor = torch.tensor(final_oof_labels, dtype=torch.float32)

      log_oof_preds_tensor = torch.log(oof_preds_tensor + 1e-8)

      overall_oof_score = loss_fn(log_oof_preds_tensor, oof_labels_tensor).item()

      print(f"\nOverall OOF KL Score: {overall_oof_score:.4f}")
  else:
      print("\nCould not calculate OOF score because no predictions were generated.")
      
  return overall_oof_score

run_training()




Fold 0 Epoch 0 Training: 100%|██████████| 203/203 [00:42<00:00,  4.73it/s]
Fold 0 Epoch 0 Validation: 100%|██████████| 64/64 [00:07<00:00,  8.24it/s]


Epoch 0: Train Loss = 1.1326, Valid Loss = 1.3967


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 12.51it/s]
Fold 0 Epoch 1 Training: 100%|██████████| 203/203 [00:40<00:00,  5.00it/s]
Fold 0 Epoch 1 Validation: 100%|██████████| 64/64 [00:05<00:00, 12.58it/s]


Epoch 1: Train Loss = 0.9474, Valid Loss = 1.1923


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 11.91it/s]
Fold 0 Epoch 2 Training: 100%|██████████| 203/203 [00:40<00:00,  5.03it/s]
Fold 0 Epoch 2 Validation: 100%|██████████| 64/64 [00:05<00:00, 11.55it/s]


Epoch 2: Train Loss = 0.8161, Valid Loss = 1.1765


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 11.96it/s]
Fold 0 Epoch 3 Training: 100%|██████████| 203/203 [00:40<00:00,  5.07it/s]
Fold 0 Epoch 3 Validation: 100%|██████████| 64/64 [00:05<00:00, 12.29it/s]


Epoch 3: Train Loss = 0.7401, Valid Loss = 1.1369


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 12.34it/s]
Fold 0 Epoch 4 Training: 100%|██████████| 203/203 [00:40<00:00,  5.06it/s]
Fold 0 Epoch 4 Validation: 100%|██████████| 64/64 [00:05<00:00, 12.48it/s]


Epoch 4: Train Loss = 0.6737, Valid Loss = 1.1069


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 12.29it/s]
Fold 0 Epoch 5 Training: 100%|██████████| 203/203 [00:39<00:00,  5.09it/s]
Fold 0 Epoch 5 Validation: 100%|██████████| 64/64 [00:05<00:00, 12.55it/s]


Epoch 5: Train Loss = 0.6181, Valid Loss = 1.1469


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 12.40it/s]
Fold 0 Epoch 6 Training: 100%|██████████| 203/203 [00:39<00:00,  5.08it/s]
Fold 0 Epoch 6 Validation: 100%|██████████| 64/64 [00:05<00:00, 11.80it/s]


Epoch 6: Train Loss = 0.6091, Valid Loss = 1.1294


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 11.75it/s]
Fold 0 Epoch 7 Training: 100%|██████████| 203/203 [00:40<00:00,  5.07it/s]
Fold 0 Epoch 7 Validation: 100%|██████████| 64/64 [00:05<00:00, 11.47it/s]


Epoch 7: Train Loss = 0.6072, Valid Loss = 1.1341


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 11.77it/s]
Fold 0 Epoch 8 Training: 100%|██████████| 203/203 [00:40<00:00,  5.02it/s]
Fold 0 Epoch 8 Validation: 100%|██████████| 64/64 [00:05<00:00, 11.78it/s]


Epoch 8: Train Loss = 0.6031, Valid Loss = 1.0277


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 11.99it/s]
Fold 0 Epoch 9 Training: 100%|██████████| 203/203 [00:40<00:00,  5.04it/s]
Fold 0 Epoch 9 Validation: 100%|██████████| 64/64 [00:05<00:00, 10.93it/s]


Epoch 9: Train Loss = 0.5560, Valid Loss = 1.1302


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 11.04it/s]
Fold 0 Epoch 10 Training: 100%|██████████| 203/203 [00:40<00:00,  4.99it/s]
Fold 0 Epoch 10 Validation: 100%|██████████| 64/64 [00:05<00:00, 12.06it/s]


Epoch 10: Train Loss = 0.5490, Valid Loss = 1.0527


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 10.94it/s]
Fold 0 Epoch 11 Training: 100%|██████████| 203/203 [00:40<00:00,  4.97it/s]
Fold 0 Epoch 11 Validation: 100%|██████████| 64/64 [00:05<00:00, 11.81it/s]


Epoch 11: Train Loss = 0.5460, Valid Loss = 1.1091


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 11.32it/s]
Fold 0 Epoch 12 Training: 100%|██████████| 203/203 [00:40<00:00,  4.99it/s]
Fold 0 Epoch 12 Validation: 100%|██████████| 64/64 [00:05<00:00, 12.73it/s]


Epoch 12: Train Loss = 0.5366, Valid Loss = 1.0944


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 12.62it/s]
Fold 0 Epoch 13 Training: 100%|██████████| 203/203 [00:40<00:00,  5.03it/s]
Fold 0 Epoch 13 Validation: 100%|██████████| 64/64 [00:05<00:00, 12.18it/s]


Epoch 13: Train Loss = 0.5349, Valid Loss = 1.0533


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 12.09it/s]
Fold 0 Epoch 14 Training: 100%|██████████| 203/203 [00:40<00:00,  5.02it/s]
Fold 0 Epoch 14 Validation: 100%|██████████| 64/64 [00:05<00:00, 12.46it/s]


Epoch 14: Train Loss = 0.5311, Valid Loss = 1.0121


Fold 0 OOF Predictions: 100%|██████████| 64/64 [00:05<00:00, 12.34it/s]


0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train/epoch_loss,█▆▄▃▃▂▂▂▂▁▁▁▁▁▁
train/loss,██▆▆▄▄▅▅▄▅▃▃▃▃▃▃▂▂▄▃▂▃▂▂▁▃▂▂▃▁▂▃▁▂▂▂▃▂▂▂
train/lr,▁▂▄▅█████▇▇▇▇▆▆▆▅▅▅▅▅▅▅▄▄▄▄▄▄▄▃▃▂▂▂▂▂▂▁▁
val/kl_div,█▄▄▃▃▃▃▃▁▃▂▃▂▂▁
val/loss,█▄▄▃▃▃▃▃▁▃▂▃▂▂▁

0,1
best_val_kl_div,1.01205
epoch,15.0
train/epoch_loss,0.53108
train/loss,0.5482
train/lr,-0.0
val/kl_div,1.01205
val/loss,1.01205





Fold 1 Epoch 0 Training: 100%|██████████| 209/209 [00:42<00:00,  4.87it/s]
Fold 1 Epoch 0 Validation: 100%|██████████| 58/58 [00:04<00:00, 12.67it/s]


Epoch 0: Train Loss = 1.1425, Valid Loss = 1.2313


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:04<00:00, 12.83it/s]
Fold 1 Epoch 1 Training: 100%|██████████| 209/209 [00:43<00:00,  4.77it/s]
Fold 1 Epoch 1 Validation: 100%|██████████| 58/58 [00:04<00:00, 12.23it/s]


Epoch 1: Train Loss = 0.9343, Valid Loss = 1.1268


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:04<00:00, 11.74it/s]
Fold 1 Epoch 2 Training: 100%|██████████| 209/209 [00:43<00:00,  4.82it/s]
Fold 1 Epoch 2 Validation: 100%|██████████| 58/58 [00:04<00:00, 12.82it/s]


Epoch 2: Train Loss = 0.8027, Valid Loss = 1.1576


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:04<00:00, 12.48it/s]
Fold 1 Epoch 3 Training: 100%|██████████| 209/209 [00:43<00:00,  4.82it/s]
Fold 1 Epoch 3 Validation: 100%|██████████| 58/58 [00:04<00:00, 11.72it/s]


Epoch 3: Train Loss = 0.8084, Valid Loss = 1.1367


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:04<00:00, 12.60it/s]
Fold 1 Epoch 4 Training: 100%|██████████| 209/209 [00:43<00:00,  4.77it/s]
Fold 1 Epoch 4 Validation: 100%|██████████| 58/58 [00:04<00:00, 12.85it/s]


Epoch 4: Train Loss = 0.8044, Valid Loss = 1.0755


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:04<00:00, 12.45it/s]
Fold 1 Epoch 5 Training: 100%|██████████| 209/209 [00:43<00:00,  4.79it/s]
Fold 1 Epoch 5 Validation: 100%|██████████| 58/58 [00:04<00:00, 12.45it/s]


Epoch 5: Train Loss = 0.7292, Valid Loss = 1.0758


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:05<00:00, 10.60it/s]
Fold 1 Epoch 6 Training: 100%|██████████| 209/209 [00:43<00:00,  4.83it/s]
Fold 1 Epoch 6 Validation: 100%|██████████| 58/58 [00:04<00:00, 12.74it/s]


Epoch 6: Train Loss = 0.7264, Valid Loss = 1.0748


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:04<00:00, 12.63it/s]
Fold 1 Epoch 7 Training: 100%|██████████| 209/209 [00:43<00:00,  4.80it/s]
Fold 1 Epoch 7 Validation: 100%|██████████| 58/58 [00:04<00:00, 12.51it/s]


Epoch 7: Train Loss = 0.6734, Valid Loss = 1.0353


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:05<00:00, 10.93it/s]
Fold 1 Epoch 8 Training: 100%|██████████| 209/209 [00:43<00:00,  4.80it/s]
Fold 1 Epoch 8 Validation: 100%|██████████| 58/58 [00:04<00:00, 12.31it/s]


Epoch 8: Train Loss = 0.6237, Valid Loss = 1.0571


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:05<00:00,  9.97it/s]
Fold 1 Epoch 9 Training: 100%|██████████| 209/209 [00:43<00:00,  4.83it/s]
Fold 1 Epoch 9 Validation: 100%|██████████| 58/58 [00:04<00:00, 11.73it/s]


Epoch 9: Train Loss = 0.6181, Valid Loss = 0.9977


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:04<00:00, 11.79it/s]
Fold 1 Epoch 10 Training: 100%|██████████| 209/209 [00:43<00:00,  4.83it/s]
Fold 1 Epoch 10 Validation: 100%|██████████| 58/58 [00:04<00:00, 12.44it/s]


Epoch 10: Train Loss = 0.5812, Valid Loss = 0.9990


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:04<00:00, 12.18it/s]
Fold 1 Epoch 11 Training: 100%|██████████| 209/209 [00:43<00:00,  4.78it/s]
Fold 1 Epoch 11 Validation: 100%|██████████| 58/58 [00:05<00:00, 11.24it/s]


Epoch 11: Train Loss = 0.5756, Valid Loss = 1.0047


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:04<00:00, 12.20it/s]
Fold 1 Epoch 12 Training: 100%|██████████| 209/209 [00:43<00:00,  4.84it/s]
Fold 1 Epoch 12 Validation: 100%|██████████| 58/58 [00:04<00:00, 12.25it/s]


Epoch 12: Train Loss = 0.5674, Valid Loss = 1.0041


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:04<00:00, 11.94it/s]
Fold 1 Epoch 13 Training: 100%|██████████| 209/209 [00:44<00:00,  4.74it/s]
Fold 1 Epoch 13 Validation: 100%|██████████| 58/58 [00:04<00:00, 11.99it/s]


Epoch 13: Train Loss = 0.5619, Valid Loss = 1.0167


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:05<00:00, 10.78it/s]
Fold 1 Epoch 14 Training: 100%|██████████| 209/209 [00:44<00:00,  4.71it/s]
Fold 1 Epoch 14 Validation: 100%|██████████| 58/58 [00:04<00:00, 12.42it/s]


Epoch 14: Train Loss = 0.5580, Valid Loss = 1.0106


Fold 1 OOF Predictions: 100%|██████████| 58/58 [00:04<00:00, 12.43it/s]


0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train/epoch_loss,█▆▄▄▄▃▃▂▂▂▁▁▁▁▁
train/loss,█▅▆▄▄▄▄▄▄▃▂▂▃▃▂▄▃▂▂▃▂▂▂▃▃▃▂▂▂▂▂▂▂▂▂▃▁▂▁▂
train/lr,▂▃█████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▁▁▁
val/kl_div,█▅▆▅▃▃▃▂▃▁▁▁▁▂▁
val/loss,█▅▆▅▃▃▃▂▃▁▁▁▁▂▁

0,1
best_val_kl_div,0.99768
epoch,15.0
train/epoch_loss,0.55803
train/loss,0.68784
train/lr,-0.0
val/kl_div,1.01059
val/loss,1.01059





Fold 2 Epoch 0 Training: 100%|██████████| 214/214 [00:43<00:00,  4.89it/s]
Fold 2 Epoch 0 Validation: 100%|██████████| 53/53 [00:04<00:00, 11.74it/s]


Epoch 0: Train Loss = 1.1472, Valid Loss = 1.2286


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:05<00:00, 10.52it/s]
Fold 2 Epoch 1 Training: 100%|██████████| 214/214 [00:44<00:00,  4.86it/s]
Fold 2 Epoch 1 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.02it/s]


Epoch 1: Train Loss = 0.9470, Valid Loss = 1.1372


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 11.29it/s]
Fold 2 Epoch 2 Training: 100%|██████████| 214/214 [00:44<00:00,  4.83it/s]
Fold 2 Epoch 2 Validation: 100%|██████████| 53/53 [00:04<00:00, 11.46it/s]


Epoch 2: Train Loss = 0.8191, Valid Loss = 1.0544


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.00it/s]
Fold 2 Epoch 3 Training: 100%|██████████| 214/214 [00:44<00:00,  4.79it/s]
Fold 2 Epoch 3 Validation: 100%|██████████| 53/53 [00:04<00:00, 11.81it/s]


Epoch 3: Train Loss = 0.7334, Valid Loss = 1.0439


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 10.67it/s]
Fold 2 Epoch 4 Training: 100%|██████████| 214/214 [00:44<00:00,  4.78it/s]
Fold 2 Epoch 4 Validation: 100%|██████████| 53/53 [00:04<00:00, 10.73it/s]


Epoch 4: Train Loss = 0.6762, Valid Loss = 1.0122


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:05<00:00,  9.93it/s]
Fold 2 Epoch 5 Training: 100%|██████████| 214/214 [00:44<00:00,  4.83it/s]
Fold 2 Epoch 5 Validation: 100%|██████████| 53/53 [00:04<00:00, 11.40it/s]


Epoch 5: Train Loss = 0.6145, Valid Loss = 1.0090


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 11.54it/s]
Fold 2 Epoch 6 Training: 100%|██████████| 214/214 [00:44<00:00,  4.82it/s]
Fold 2 Epoch 6 Validation: 100%|██████████| 53/53 [00:04<00:00, 11.76it/s]


Epoch 6: Train Loss = 0.5678, Valid Loss = 0.9731


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 11.49it/s]
Fold 2 Epoch 7 Training: 100%|██████████| 214/214 [00:44<00:00,  4.78it/s]
Fold 2 Epoch 7 Validation: 100%|██████████| 53/53 [00:04<00:00, 11.52it/s]


Epoch 7: Train Loss = 0.5203, Valid Loss = 1.0213


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.22it/s]
Fold 2 Epoch 8 Training: 100%|██████████| 214/214 [00:44<00:00,  4.81it/s]
Fold 2 Epoch 8 Validation: 100%|██████████| 53/53 [00:04<00:00, 11.68it/s]


Epoch 8: Train Loss = 0.5169, Valid Loss = 1.0186


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 11.64it/s]
Fold 2 Epoch 9 Training: 100%|██████████| 214/214 [00:44<00:00,  4.85it/s]
Fold 2 Epoch 9 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.18it/s]


Epoch 9: Train Loss = 0.5109, Valid Loss = 0.9705


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 11.37it/s]
Fold 2 Epoch 10 Training: 100%|██████████| 214/214 [00:44<00:00,  4.81it/s]
Fold 2 Epoch 10 Validation: 100%|██████████| 53/53 [00:04<00:00, 11.84it/s]


Epoch 10: Train Loss = 0.4704, Valid Loss = 1.0184


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 11.82it/s]
Fold 2 Epoch 11 Training: 100%|██████████| 214/214 [00:44<00:00,  4.80it/s]
Fold 2 Epoch 11 Validation: 100%|██████████| 53/53 [00:04<00:00, 10.91it/s]


Epoch 11: Train Loss = 0.4667, Valid Loss = 1.0001


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.04it/s]
Fold 2 Epoch 12 Training: 100%|██████████| 214/214 [00:44<00:00,  4.82it/s]
Fold 2 Epoch 12 Validation: 100%|██████████| 53/53 [00:04<00:00, 11.97it/s]


Epoch 12: Train Loss = 0.4625, Valid Loss = 0.9678


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 11.99it/s]
Fold 2 Epoch 13 Training: 100%|██████████| 214/214 [00:44<00:00,  4.76it/s]
Fold 2 Epoch 13 Validation: 100%|██████████| 53/53 [00:04<00:00, 11.03it/s]


Epoch 13: Train Loss = 0.4350, Valid Loss = 0.9963


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.23it/s]
Fold 2 Epoch 14 Training: 100%|██████████| 214/214 [00:43<00:00,  4.89it/s]
Fold 2 Epoch 14 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.36it/s]


Epoch 14: Train Loss = 0.4281, Valid Loss = 0.9886


Fold 2 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 11.47it/s]


0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train/epoch_loss,█▆▅▄▃▃▂▂▂▂▁▁▁▁▁
train/loss,▇█▆▅▆▅▄▅▄▅▅▃▂▄▃▄▂▂▄▂▃▁▂▃▂▂▂▂▂▂▁▂▁▂▂▁▁▁▁▁
train/lr,▂▃▅▆█▇▇▇▇▇▇▇▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁
val/kl_div,█▆▃▃▂▂▁▂▂▁▂▂▁▂▂
val/loss,█▆▃▃▂▂▁▂▂▁▂▂▁▂▂

0,1
best_val_kl_div,0.96778
epoch,15.0
train/epoch_loss,0.4281
train/loss,0.40737
train/lr,-0.0
val/kl_div,0.98864
val/loss,0.98864





Fold 3 Epoch 0 Training: 100%|██████████| 226/226 [00:44<00:00,  5.07it/s]
Fold 3 Epoch 0 Validation: 100%|██████████| 42/42 [00:03<00:00, 11.70it/s]


Epoch 0: Train Loss = 1.1414, Valid Loss = 1.2482


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 12.16it/s]
Fold 3 Epoch 1 Training: 100%|██████████| 226/226 [00:44<00:00,  5.08it/s]
Fold 3 Epoch 1 Validation: 100%|██████████| 42/42 [00:03<00:00, 11.07it/s]


Epoch 1: Train Loss = 0.9531, Valid Loss = 0.9891


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 11.19it/s]
Fold 3 Epoch 2 Training: 100%|██████████| 226/226 [00:44<00:00,  5.03it/s]
Fold 3 Epoch 2 Validation: 100%|██████████| 42/42 [00:03<00:00, 11.76it/s]


Epoch 2: Train Loss = 0.8111, Valid Loss = 1.0006


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 11.93it/s]
Fold 3 Epoch 3 Training: 100%|██████████| 226/226 [00:44<00:00,  5.04it/s]
Fold 3 Epoch 3 Validation: 100%|██████████| 42/42 [00:03<00:00, 12.27it/s]


Epoch 3: Train Loss = 0.8073, Valid Loss = 0.9965


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 12.28it/s]
Fold 3 Epoch 4 Training: 100%|██████████| 226/226 [00:44<00:00,  5.03it/s]
Fold 3 Epoch 4 Validation: 100%|██████████| 42/42 [00:03<00:00, 12.06it/s]


Epoch 4: Train Loss = 0.8022, Valid Loss = 0.9519


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 12.26it/s]
Fold 3 Epoch 5 Training: 100%|██████████| 226/226 [00:45<00:00,  4.98it/s]
Fold 3 Epoch 5 Validation: 100%|██████████| 42/42 [00:03<00:00, 12.27it/s]


Epoch 5: Train Loss = 0.7199, Valid Loss = 1.0238


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 12.31it/s]
Fold 3 Epoch 6 Training: 100%|██████████| 226/226 [00:44<00:00,  5.05it/s]
Fold 3 Epoch 6 Validation: 100%|██████████| 42/42 [00:03<00:00, 12.31it/s]


Epoch 6: Train Loss = 0.7183, Valid Loss = 0.9880


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 12.02it/s]
Fold 3 Epoch 7 Training: 100%|██████████| 226/226 [00:44<00:00,  5.05it/s]
Fold 3 Epoch 7 Validation: 100%|██████████| 42/42 [00:03<00:00, 12.21it/s]


Epoch 7: Train Loss = 0.7123, Valid Loss = 0.9724


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 12.25it/s]
Fold 3 Epoch 8 Training: 100%|██████████| 226/226 [00:45<00:00,  5.01it/s]
Fold 3 Epoch 8 Validation: 100%|██████████| 42/42 [00:03<00:00, 12.45it/s]


Epoch 8: Train Loss = 0.7069, Valid Loss = 0.9828


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 12.37it/s]
Fold 3 Epoch 9 Training: 100%|██████████| 226/226 [00:44<00:00,  5.03it/s]
Fold 3 Epoch 9 Validation: 100%|██████████| 42/42 [00:03<00:00, 12.26it/s]


Epoch 9: Train Loss = 0.7016, Valid Loss = 0.9462


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 12.31it/s]
Fold 3 Epoch 10 Training: 100%|██████████| 226/226 [00:44<00:00,  5.05it/s]
Fold 3 Epoch 10 Validation: 100%|██████████| 42/42 [00:03<00:00, 11.84it/s]


Epoch 10: Train Loss = 0.6524, Valid Loss = 0.9679


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 12.23it/s]
Fold 3 Epoch 11 Training: 100%|██████████| 226/226 [00:45<00:00,  5.02it/s]
Fold 3 Epoch 11 Validation: 100%|██████████| 42/42 [00:03<00:00, 12.37it/s]


Epoch 11: Train Loss = 0.6438, Valid Loss = 0.9395


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 12.33it/s]
Fold 3 Epoch 12 Training: 100%|██████████| 226/226 [00:44<00:00,  5.02it/s]
Fold 3 Epoch 12 Validation: 100%|██████████| 42/42 [00:03<00:00, 12.35it/s]


Epoch 12: Train Loss = 0.6090, Valid Loss = 0.9513


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 12.24it/s]
Fold 3 Epoch 13 Training: 100%|██████████| 226/226 [00:45<00:00,  5.00it/s]
Fold 3 Epoch 13 Validation: 100%|██████████| 42/42 [00:03<00:00, 12.35it/s]


Epoch 13: Train Loss = 0.6064, Valid Loss = 0.9465


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 12.37it/s]
Fold 3 Epoch 14 Training: 100%|██████████| 226/226 [00:45<00:00,  5.02it/s]
Fold 3 Epoch 14 Validation: 100%|██████████| 42/42 [00:03<00:00, 12.18it/s]


Epoch 14: Train Loss = 0.6002, Valid Loss = 0.9421


Fold 3 OOF Predictions: 100%|██████████| 42/42 [00:03<00:00, 12.35it/s]


0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train/epoch_loss,█▆▄▄▄▃▃▂▂▂▂▂▁▁▁
train/loss,█▇▅█▄▄▅▅▄▅▅▄▂▄▃▃▃▃▃▃▃▄▂▄▃▆▂▄▃▂▄▃▁▃▂▂▃▂▂▁
train/lr,▁▄▅███▇▇▇▇▆▆▆▆▆▆▆▆▆▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁
val/kl_div,█▂▂▂▁▃▂▂▂▁▂▁▁▁▁
val/loss,█▂▂▂▁▃▂▂▂▁▂▁▁▁▁

0,1
best_val_kl_div,0.93949
epoch,15.0
train/epoch_loss,0.6002
train/loss,0.72066
train/lr,-0.0
val/kl_div,0.94209
val/loss,0.94209





Fold 4 Epoch 0 Training: 100%|██████████| 214/214 [00:42<00:00,  5.05it/s]
Fold 4 Epoch 0 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.42it/s]


Epoch 0: Train Loss = 1.1671, Valid Loss = 1.1696


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.89it/s]
Fold 4 Epoch 1 Training: 100%|██████████| 214/214 [00:42<00:00,  4.99it/s]
Fold 4 Epoch 1 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.74it/s]


Epoch 1: Train Loss = 0.9496, Valid Loss = 1.0039


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.65it/s]
Fold 4 Epoch 2 Training: 100%|██████████| 214/214 [00:43<00:00,  4.97it/s]
Fold 4 Epoch 2 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.80it/s]


Epoch 2: Train Loss = 0.8163, Valid Loss = 0.9949


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.73it/s]
Fold 4 Epoch 3 Training: 100%|██████████| 214/214 [00:43<00:00,  4.91it/s]
Fold 4 Epoch 3 Validation: 100%|██████████| 53/53 [00:04<00:00, 13.03it/s]


Epoch 3: Train Loss = 0.7276, Valid Loss = 0.9355


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.91it/s]
Fold 4 Epoch 4 Training: 100%|██████████| 214/214 [00:42<00:00,  5.00it/s]
Fold 4 Epoch 4 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.75it/s]


Epoch 4: Train Loss = 0.6688, Valid Loss = 0.9492


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.89it/s]
Fold 4 Epoch 5 Training: 100%|██████████| 214/214 [00:43<00:00,  4.92it/s]
Fold 4 Epoch 5 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.88it/s]


Epoch 5: Train Loss = 0.6638, Valid Loss = 0.9548


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.63it/s]
Fold 4 Epoch 6 Training: 100%|██████████| 214/214 [00:43<00:00,  4.96it/s]
Fold 4 Epoch 6 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.50it/s]


Epoch 6: Train Loss = 0.6550, Valid Loss = 0.9602


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.44it/s]
Fold 4 Epoch 7 Training: 100%|██████████| 214/214 [00:42<00:00,  4.99it/s]
Fold 4 Epoch 7 Validation: 100%|██████████| 53/53 [00:04<00:00, 13.06it/s]


Epoch 7: Train Loss = 0.6481, Valid Loss = 0.9463


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.59it/s]
Fold 4 Epoch 8 Training: 100%|██████████| 214/214 [00:42<00:00,  5.02it/s]
Fold 4 Epoch 8 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.97it/s]


Epoch 8: Train Loss = 0.6442, Valid Loss = 0.9744


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.84it/s]
Fold 4 Epoch 9 Training: 100%|██████████| 214/214 [00:43<00:00,  4.90it/s]
Fold 4 Epoch 9 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.94it/s]


Epoch 9: Train Loss = 0.6396, Valid Loss = 0.9338


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.94it/s]
Fold 4 Epoch 10 Training: 100%|██████████| 214/214 [00:43<00:00,  4.95it/s]
Fold 4 Epoch 10 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.85it/s]


Epoch 10: Train Loss = 0.5967, Valid Loss = 0.9701


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 13.01it/s]
Fold 4 Epoch 11 Training: 100%|██████████| 214/214 [00:43<00:00,  4.97it/s]
Fold 4 Epoch 11 Validation: 100%|██████████| 53/53 [00:04<00:00, 13.00it/s]


Epoch 11: Train Loss = 0.5904, Valid Loss = 0.9443


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.79it/s]
Fold 4 Epoch 12 Training: 100%|██████████| 214/214 [00:43<00:00,  4.93it/s]
Fold 4 Epoch 12 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.99it/s]


Epoch 12: Train Loss = 0.5849, Valid Loss = 0.9650


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.46it/s]
Fold 4 Epoch 13 Training: 100%|██████████| 214/214 [00:43<00:00,  4.93it/s]
Fold 4 Epoch 13 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.88it/s]


Epoch 13: Train Loss = 0.5770, Valid Loss = 0.9421


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.57it/s]
Fold 4 Epoch 14 Training: 100%|██████████| 214/214 [00:42<00:00,  5.00it/s]
Fold 4 Epoch 14 Validation: 100%|██████████| 53/53 [00:04<00:00, 12.87it/s]


Epoch 14: Train Loss = 0.5769, Valid Loss = 0.9422


Fold 4 OOF Predictions: 100%|██████████| 53/53 [00:04<00:00, 12.46it/s]


0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train/epoch_loss,█▅▄▃▂▂▂▂▂▂▁▁▁▁▁
train/loss,▅█▄▅▄▄▄▄▄▃▃▃▃▃▄▃▃▁▂▃▂▃▃▃▂▂▂▃▄▃▂▃▂▂▂▂▃▂▁▂
train/lr,▂▂█████▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▁▁
val/kl_div,█▃▃▁▁▂▂▁▂▁▂▁▂▁▁
val/loss,█▃▃▁▁▂▂▁▂▁▂▁▂▁▁

0,1
best_val_kl_div,0.93382
epoch,15.0
train/epoch_loss,0.57688
train/loss,0.53177
train/lr,-0.0
val/kl_div,0.94218
val/loss,0.94218



Calculating final OOF score...

Overall OOF KL Score: 1.0263


1.0262975692749023