# Competition Overview

**Competition Summary**

This Kaggle competition challenges participants to develop predictive models for allogeneic Hematopoietic Cell Transplantation (HCT) survival. The core objective is to create models that are not only accurate but also fair and equitable for patients across diverse backgrounds, particularly different racial groups. Participants will use synthetic data that mirrors real-world patient information to build their models. The competition runs from December 4, 2024, to March 5, 2025.

### [Link to competition](https://www.kaggle.com/competitions/equity-post-HCT-survival-predictions/overview)

# Evaluation Criteria

**Evaluation Criteria**

Model performance will be evaluated using a custom metric called the Stratified Concordance Index (C-index). This metric is an adaptation of the standard C-index, designed to ensure that the model's predictive power is equitable across different racial groups. A higher C-index (closer to 1.0) indicates a more accurate and fairer model. The final score is calculated as the mean C-index across all racial groups minus the standard deviation of those scores.

**Concordance index**

It represents the global assessment of the model discrimination power: this is the model’s ability to correctly provide a reliable ranking of the survival times based on the individual risk scores. It can be computed with the following formula:

![image.png](attachment:e4a0494a-786d-4991-a946-893a6838640c.png)

The concordance index is a value between 0 and 1 where:

- 0.5 is the expected result from random predictions,
- 1.0 is a perfect concordance and,
- 0.0 is perfect anti-concordance (multiply predictions with -1 to get 1.0)

References : [CIBMTR - Equity in post-HCT Survival Predictions](https://www.kaggle.com/competitions/equity-post-HCT-survival-predictions/overview)

## Submission Requirements

**Submission Requirements**

The submission file should contain a real-valued risk score for each patient in the test dataset, with a higher score indicating a higher likelihood of the target event.

# Dataset Description

The dataset for this competition contains 59 features related to hematopoietic stem cell transplantation (HSCT). These variables cover a wide range of demographic and medical information for both the transplant recipients and donors, such as age, sex, ethnicity, and disease status.

**Data and Target Variables**

The primary goal is to predict event-free survival (efs), which is the key target variable. The time to this event is captured in the efs_time variable. Together, these two variables are used for a censored time-to-event analysis.

**Dataset Generation**

The data is not real patient data; it was synthetically generated to protect patient privacy while mirroring real-world situations. The generation process used a model called SurvivalGAN, a type of Generative Adversarial Network (GAN) specifically designed to create realistic time-to-event data and handle the complexity of censored data. The synthetic dataset has equal representation across different racial categories, including White, Asian, African-American, Native American, Pacific Islander, and More than One Race.

**Included Files**

The competition provides four main files:
- train.csv: The training dataset with the efs target.
- test.csv: The dataset for which you must predict efs.
- sample_submission.csv: An example file showing the correct submission format.
- data_dictionary.csv: A list of all variables and their descriptions.

# Data Dictionary

<style type="text/css">
.tg  {border-collapse:collapse;border-spacing:0;}
.tg td{border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;
  overflow:hidden;padding:10px 5px;word-break:normal;}
.tg th{border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;
  font-weight:normal;overflow:hidden;padding:10px 5px;word-break:normal;}
.tg .tg-za14{border-color:inherit;text-align:left;vertical-align:bottom}
.tg .tg-7zrl{text-align:left;vertical-align:bottom}
</style>
<table class="tg"><thead>
  <tr>
    <th class="tg-za14">variable</th>
    <th class="tg-7zrl">description</th>
    <th class="tg-7zrl">type</th>
  </tr></thead>
<tbody>
  <tr>
    <td class="tg-7zrl">dri_score</td>
    <td class="tg-7zrl">Refined disease risk index</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">psych_disturb</td>
    <td class="tg-7zrl">Psychiatric disturbance</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">cyto_score</td>
    <td class="tg-7zrl">Cytogenetic score</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">diabetes</td>
    <td class="tg-7zrl">Diabetes</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_c_high</td>
    <td class="tg-7zrl">Recipient / 1st donor allele level (high resolution) matching at HLA-C</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_high_res_8</td>
    <td class="tg-7zrl">Recipient / 1st donor allele-level (high resolution) matching at&nbsp;&nbsp;&nbsp;HLA-A,-B,-C,-DRB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">tbi_status</td>
    <td class="tg-7zrl">TBI</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">arrhythmia</td>
    <td class="tg-7zrl">Arrhythmia</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_low_res_6</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen-level (low resolution) matching at&nbsp;&nbsp;&nbsp;HLA-A,-B,-DRB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">graft_type</td>
    <td class="tg-7zrl">Graft type</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">vent_hist</td>
    <td class="tg-7zrl">History of mechanical ventilation</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">renal_issue</td>
    <td class="tg-7zrl">Renal, moderate / severe</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">pulm_severe</td>
    <td class="tg-7zrl">Pulmonary, severe</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">prim_disease_hct</td>
    <td class="tg-7zrl">Primary disease for HCT</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_high_res_6</td>
    <td class="tg-7zrl">Recipient / 1st donor allele-level (high resolution) matching at&nbsp;&nbsp;&nbsp;HLA-A,-B,-DRB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">cmv_status</td>
    <td class="tg-7zrl">Donor/recipient CMV serostatus</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_high_res_10</td>
    <td class="tg-7zrl">Recipient / 1st donor allele-level (high resolution) matching at&nbsp;&nbsp;&nbsp;HLA-A,-B,-C,-DRB1,-DQB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_dqb1_high</td>
    <td class="tg-7zrl">Recipient / 1st donor allele level (high resolution) matching at HLA-DQB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">tce_imm_match</td>
    <td class="tg-7zrl">T-cell epitope immunogenicity/diversity match</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_nmdp_6</td>
    <td class="tg-7zrl">Recipient / 1st donor matching at HLA-A(lo),-B(lo),-DRB1(hi)</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_c_low</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen level (low resolution) matching at HLA-C</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">rituximab</td>
    <td class="tg-7zrl">Rituximab given in conditioning</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_drb1_low</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen level (low resolution) matching at HLA-DRB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_dqb1_low</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen level (low resolution) matching at HLA-DQB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">prod_type</td>
    <td class="tg-7zrl">Product type</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">cyto_score_detail</td>
    <td class="tg-7zrl">Cytogenetics for DRI (AML/MDS)</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">conditioning_intensity</td>
    <td class="tg-7zrl">Computed planned conditioning intensity</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">ethnicity</td>
    <td class="tg-7zrl">Ethnicity</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">year_hct</td>
    <td class="tg-7zrl">Year of HCT</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">obesity</td>
    <td class="tg-7zrl">Obesity</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">mrd_hct</td>
    <td class="tg-7zrl">MRD at time of HCT (AML/ALL)</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">in_vivo_tcd</td>
    <td class="tg-7zrl">In-vivo T-cell depletion (ATG/alemtuzumab)</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">tce_match</td>
    <td class="tg-7zrl">T-cell epitope matching</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_a_high</td>
    <td class="tg-7zrl">Recipient / 1st donor allele level (high resolution) matching at HLA-A</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hepatic_severe</td>
    <td class="tg-7zrl">Hepatic, moderate / severe</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">donor_age</td>
    <td class="tg-7zrl">Donor age</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">prior_tumor</td>
    <td class="tg-7zrl">Solid tumor, prior</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_b_low</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen level (low resolution) matching at HLA-B</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">peptic_ulcer</td>
    <td class="tg-7zrl">Peptic ulcer</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">age_at_hct</td>
    <td class="tg-7zrl">Age at HCT</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_a_low</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen level (low resolution) matching at HLA-A</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">gvhd_proph</td>
    <td class="tg-7zrl">Planned GVHD prophylaxis</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">rheum_issue</td>
    <td class="tg-7zrl">Rheumatologic</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">sex_match</td>
    <td class="tg-7zrl">Donor/recipient sex match</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_b_high</td>
    <td class="tg-7zrl">Recipient / 1st donor allele level (high resolution) matching at HLA-B</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">race_group</td>
    <td class="tg-7zrl">Race</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">comorbidity_score</td>
    <td class="tg-7zrl">Sorror comorbidity score</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">karnofsky_score</td>
    <td class="tg-7zrl">KPS at HCT</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hepatic_mild</td>
    <td class="tg-7zrl">Hepatic, mild</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">tce_div_match</td>
    <td class="tg-7zrl">T-cell epitope matching</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">donor_related</td>
    <td class="tg-7zrl">Related vs. unrelated donor</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">melphalan_dose</td>
    <td class="tg-7zrl">Melphalan dose (mg/m^2)</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_low_res_8</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen-level (low resolution) matching at&nbsp;&nbsp;&nbsp;HLA-A,-B,-C,-DRB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">cardiac</td>
    <td class="tg-7zrl">Cardiac</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_drb1_high</td>
    <td class="tg-7zrl">Recipient / 1st donor allele level (high resolution) matching at HLA-DRB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">pulm_moderate</td>
    <td class="tg-7zrl">Pulmonary, moderate</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_low_res_10</td>
    <td class="tg-7zrl"></td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">efs</td>
    <td class="tg-7zrl">Event-free survival</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">efs_time</td>
    <td class="tg-7zrl">Time to event-free survival, months</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
</tbody></table>

### [Link to dataset](https://www.kaggle.com/competitions/equity-post-HCT-survival-predictions/data)

# Pipeline Overview

**Competition Pipeline Overview**

The core goal is to predict the efs_time (event-free survival time) for patients, while also considering the efs (event status). The pipeline is built using Python, with a heavy emphasis on PyTorch and scikit-learn for modeling and data processing. The pipeline can be broken down into the following key stages:

**1. Data Preparation** 

The data is loaded, and new features like <code>is_cyto_score_same</code> and <code>year_hct</code> are created. Categorical features are handled using LabelEncoder, and numerical features are preprocessed with SimpleImputer and StandardScaler. The data is then organized into PyTorch DataLoaders for efficient batch processing.

**2. Model Architecture and Model Training** 

A custom neural network (NN) is used, which incorporates a specialized <code>CatEmbeddings</code> module to handle categorical data by creating and projecting embeddings. The main network then combines these embeddings with the numerical data. The network uses an ODST layer, batch normalization, and dropout.

- Loss Functions: The model uses a unique composite loss function for training. This includes:
  - A primary margin-based hinge loss (<code>calc_loss</code>) for survival data, which ranks patient pairs based on their survival times.
  - An auxiliary MSE loss on an auxiliary task to improve the learned representations.
  - A fairness loss that penalizes the standard deviation of the main loss across different racial groups (<code>race_group</code>), aiming for equitable performance.

**3. Hyperparameter Tuning** 

Optuna is used to find the optimal hyperparameters. It defines an objective function that maximizes a composite metric of the c-index and minimizes validation loss on a single data fold.

**4. Final training and Test Model Inference**

- **Final Training and Cross-Validation:** The model is trained using a 5-fold stratified cross-validation strategy. This ensures that the distribution of race and newborn status is maintained across folds, which is crucial for the fairness objective. In each fold, the model is trained with the best hyperparameters found by Optuna. PyTorch Lightning handles the training loop with callbacks for save the best model check point, learning rate monitoring, and stochastic weight averaging to prevent overfitting and improve generalization.
- **Prediction and Evaluation:** The trained models from all five folds are used to make predictions. The final prediction is an ensemble average of the predictions from each model, enhancing the overall robustness and accuracy. The performance is evaluated by the c-index and test loss on an unseen test dataset.

# Related Notebooks

### [Model submission notebook](https://www.kaggle.com/code/misterfour/cibmtr-challenge-submission)
### [Inspiration notebook](https://www.kaggle.com/code/zyh1104/cibmtr-nn-negative-log-likelihood-loss)
### [Package install lightning and tarbular data](https://www.kaggle.com/code/dreamingtree/download-lightning-and-pytorch-tabular)
### [Package install evaluation metric](https://www.kaggle.com/code/cdeotte/pip-install-lifelines)

# Import necessary libraries

In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/download-lightning-and-pytorch-tabular/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
/kaggle/input/download-lightning-and-pytorch-tabular/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl
/kaggle/input/download-lightning-and-pytorch-tabular/jedi-0.19.2-py2.py3-none-any.whl
/kaggle/input/download-lightning-and-pytorch-tabular/matplotlib-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
/kaggle/input/download-lightning-and-pytorch-tabular/aiohttp-3.11.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
/kaggle/input/download-lightning-and-pytorch-tabular/pexpect-4.9.0-py2.py3-none-any.whl
/kaggle/input/download-lightning-and-pytorch-tabular/numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
/kaggle/input/download-lightning-and-pytorch-tabular/antlr4-python3-runtime-4.9.3.tar.gz
/kaggle/input/download-lightning-and-pytorch-tabular/tzdata-2025.1-py2.py3-none-any.whl
/kaggle/input/downl

In [2]:
!pip install /kaggle/input/pip-install-lifelines/autograd-1.7.0-py3-none-any.whl
!pip install /kaggle/input/pip-install-lifelines/autograd-gamma-0.5.0.tar.gz
!pip install /kaggle/input/pip-install-lifelines/interface_meta-1.3.0-py3-none-any.whl
!pip install /kaggle/input/pip-install-lifelines/formulaic-1.0.2-py3-none-any.whl
!pip install /kaggle/input/pip-install-lifelines/lifelines-0.30.0-py3-none-any.whl
!pip install /kaggle/input/download-lightning-and-pytorch-tabular/pytorch_lightning-2.4.0-py3-none-any.whl
!pip install /kaggle/input/download-lightning-and-pytorch-tabular/scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
!pip install /kaggle/input/download-lightning-and-pytorch-tabular/torchmetrics-1.5.2-py3-none-any.whl
!pip install /kaggle/input/download-lightning-and-pytorch-tabular/pytorch_tabnet-4.1.0-py3-none-any.whl
!pip install /kaggle/input/download-lightning-and-pytorch-tabular/einops-0.7.0-py3-none-any.whl
!pip install /kaggle/input/download-lightning-and-pytorch-tabular/pytorch_tabular-1.1.1-py2.py3-none-any.whl
!pip install optuna

Processing /kaggle/input/pip-install-lifelines/autograd-1.7.0-py3-none-any.whl
Installing collected packages: autograd
  Attempting uninstall: autograd
    Found existing installation: autograd 1.8.0
    Uninstalling autograd-1.8.0:
      Successfully uninstalled autograd-1.8.0
Successfully installed autograd-1.7.0
Processing /kaggle/input/pip-install-lifelines/autograd-gamma-0.5.0.tar.gz
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: autograd-gamma
  Building wheel for autograd-gamma (setup.py) ... [?25l[?25hdone
  Created wheel for autograd-gamma: filename=autograd_gamma-0.5.0-py3-none-any.whl size=4030 sha256=4ba87750683cb48a2e54b5bc4218188c0450beefc0a2f4145da3328d4f28e3fa
  Stored in directory: /root/.cache/pip/wheels/b7/2e/d3/0a27e12c854a64e909e65029126b4464927060f5700bb48dc9
Successfully built autograd-gamma
Installing collected packages: autograd-gamma
Successfully installed autograd-gamma-0.5.0
Processing /kaggle/input/pip-install

In [None]:
import numpy as np
import pandas as pd
import torch
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import LabelEncoder, StandardScaler
from torch.utils.data import TensorDataset
from warnings import filterwarnings
import functools
from typing import List
import pytorch_lightning as pl
from lifelines.utils import concordance_index
from pytorch_lightning.cli import ReduceLROnPlateau
from pytorch_tabular.models.common.layers import ODST
from torch import nn
import matplotlib.pyplot as plt
from pytorch_lightning.callbacks import LearningRateMonitor, TQDMProgressBar, StochasticWeightAveraging, ModelCheckpoint, EarlyStopping
from pytorch_lightning.utilities import grad_norm
from sklearn.model_selection import StratifiedKFold
import optuna
import torch.nn.functional as F
import json

pl.seed_everything(42)
# Suppress warnings for cleaner output
filterwarnings('ignore')

# Data Loading and Feature Engineering

**Data Loading and Feature Engineering**

The <code>load_data()</code> function is the starting point. It reads in training and testing data from CSV files. The <code>add_features()</code> function is called right after loading. This is where new features are created from existing ones to help the model learn more effectively. For example, it checks if cyto_score and <code>cyto_score_detail</code> are the same, creating a new binary feature. It also modifies the <code>year_hct</code> column and comments out other potential features, suggesting this is a stage of active experimentation.

**Data Preprocessing**

The core of the code lies in the preprocessing functions.

- <code>get_feature_types()</code>: This utility function separates the columns into numerical and categorical types based on their data type and the number of unique values.

- <code>get_categoricals()</code>: This function handles the categorical features. It uses `LabelEncoder` to transform string-based categories into numerical integer labels. It also checks for and removes any categorical columns that have only one unique value (constant columns), as they won't provide any useful information for the model.

- <code>preprocess_data()</code>: This is the main orchestrator. It uses SimpleImputer to fill in missing values (NaN) in the numerical data with the mean of the column. It then uses `StandardScaler` to standardize the numerical features, which is a crucial step for many machine learning algorithms, especially neural networks, to ensure all features are on a similar scale.

**PyTorch Data Preparation**

The final step is to prepare the preprocessed data for a deep learning model.

- <code>init_dl()</code>: This function takes the preprocessed categorical and numerical data, along with the target variables, and bundles them into a TensorDataset. It then uses a DataLoader to create batches of data. This is a standard practice in PyTorch to efficiently feed data to a neural network during training and validation. The log-transformation of <code>efs_time</code> suggests that the target variable might have a skewed distribution, which is a common issue in many datasets. The code sets a fixed batch size of 2048 and shuffles the data for training.

**Theories and Concepts**

- **Data Preprocessing:** This is a foundational step in any machine learning project. Real-world data is often messy, with missing values, inconsistent formats, and features of different scales. Preprocessing transforms this raw data into a clean, structured format that models can effectively learn from.
- **Feature Engineering:** The `add_features` function demonstrates this concept. Feature engineering is the process of using domain knowledge to create new input features that are more informative and useful for the model. By creating `is_cyto_score_same`, the code provides the model with a direct and explicit signal that might be difficult to learn from the raw `cyto_score` and `cyto_score_detail` columns alone.
- **Categorical Data Encoding:** Machine learning algorithms, at their core, are mathematical and operate on numerical data. Categorical features (e.g., "red," "green," "blue") must be converted into a numerical representation. Label Encoding assigns a unique integer to each category. This is a simple and common method, though it can sometimes impose a false sense of ordinality on the data.
- **Handling Missing Data (Imputation):** Missing data can cause errors or lead to biased results. SimpleImputer with the mean strategy fills in missing values for numerical features with the average value of that column. This is a simple and effective method when the amount of missing data is small.
- **Feature Scaling:** The StandardScaler is used to standardize numerical features. This is a critical step for many algorithms, particularly those based on gradient descent (like neural networks) or those that rely on distance calculations (like k-nearest neighbors). Scaling ensures that no single feature dominates the learning process simply because it has a larger magnitude.
- **Data Leakage:** This is a major risk in machine learning pipelines. It occurs when information from the validation or test set inadvertently influences the training process. The code correctly avoids this by using `fit_transform` only on the training data and then applying the same fitted transformation (`transform`) to the validation data. This ensures the model is evaluated on a truly unseen dataset.
- **Batching and Data Loading:** The DataLoader is a key component for training deep learning models. Training a model on the entire dataset at once is computationally infeasible for large datasets. DataLoader breaks the dataset into smaller "mini-batches." This allows for more efficient memory usage, enables parallel processing on GPUs, and introduces a form of stochasticity that helps the optimization process converge.

In [2]:
# Utility function to handle categorical data transformations
def get_X_cat(df, cat_cols, transformers=None):
    """
    Apply a specific categorical data transformer or a LabelEncoder if None.
    """
    if transformers is None:
        transformers = [LabelEncoder().fit(df[col]) for col in cat_cols]
    return transformers, np.array(
        [transformer.transform(df[col]) for col, transformer in zip(cat_cols, transformers)]
    ).T


# Utility function to get feature types
def get_feature_types(train):
    """
    Utility function to return categorical and numerical column names.
    """
    categorical_cols = [col for i, col in enumerate(train.columns) if ((train[col].dtype == "object") | (2 < train[col].nunique() < 25))]
    RMV = ["ID", "efs", "efs_time", "y"]
    FEATURES = [c for c in train.columns if not c in RMV]
    numerical = [i for i in FEATURES if i not in categorical_cols]
    return categorical_cols, numerical


# Preprocessing function for categorical features
def get_categoricals(train, val):
    """
    Remove constant categorical columns and transform them using LabelEncoder.
    Return the label-transformers for each categorical column, categorical dataframes and numerical columns.
    """
    categorical_cols, numerical = get_feature_types(train)
    remove = []
    for col in categorical_cols:
        if train[col].nunique() == 1:
            remove.append(col)
        ind = ~val[col].isin(train[col])
        if ind.any():
            val.loc[ind, col] = np.nan
    categorical_cols = [col for col in categorical_cols if col not in remove]
    transformers, X_cat_train = get_X_cat(train, categorical_cols)
    _, X_cat_val = get_X_cat(val, categorical_cols, transformers)
    return X_cat_train, X_cat_val, numerical, transformers


# PyTorch Data Preparation
def init_dl(X_cat, X_num, df, training=False):
    """
    Initialize data loaders with 4 dimensions : categorical dataframe, numerical dataframe and target values (efs and efs_time).
    Notice that efs_time is log-transformed.
    Fix batch size to 2048 and return dataloader for training or validation depending on training value.
    """
    ds_train = TensorDataset(
        torch.tensor(X_cat, dtype=torch.long),
        torch.tensor(X_num, dtype=torch.float32),
        torch.tensor(df.efs_time.values, dtype=torch.float32).log(),
        torch.tensor(df.efs.values, dtype=torch.long)
    )
    bs = 2048
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=bs, pin_memory=True, shuffle=training)
    return dl_train


# Feature Engineering
def add_features(df):
    """
    Create some new features to help the model focus on specific patterns.
    """
    df['is_cyto_score_same'] = (df['cyto_score'] == df['cyto_score_detail']).astype(int)
    df['year_hct'] -= 2000
    return df


# Data Loading
def load_data():
    """
    Load data and add features.
    """
    test = pd.read_csv("/kaggle/input/equity-post-HCT-survival-predictions/test.csv")
    test = add_features(test)
    print("Test shape:", test.shape)
    train = pd.read_csv("/kaggle/input/equity-post-HCT-survival-predictions/train.csv")
    train = add_features(train)
    print("Train shape:", train.shape)
    return test, train


# Main Preprocessing Orchestrator
def preprocess_data(train, val):
    """
    Standardize numerical variables and transform (Label-encode) categoricals.
    Fill NA values with mean for numerical.
    Create torch dataloaders to prepare data for training and evaluation.
    """
    X_cat_train, X_cat_val, numerical, transformers = get_categoricals(train, val)
    scaler = StandardScaler()
    imp = SimpleImputer(missing_values=np.nan, strategy='mean', add_indicator=True)
    X_num_train = imp.fit_transform(train[numerical])
    X_num_train = scaler.fit_transform(X_num_train)
    X_num_val = imp.transform(val[numerical])
    X_num_val = scaler.transform(X_num_val)
    dl_train = init_dl(X_cat_train, X_num_train, train, training=True)
    dl_val = init_dl(X_cat_val, X_num_val, val)
    return X_cat_val, X_num_train, X_num_val, dl_train, dl_val, transformers

# Model architecture and model training

## Model Architecture Description

The code defines a neural network (NN) that combines both numerical and categorical data.

**CatEmbeddings:** This module handles the categorical features. It learns a numerical representation (embedding) for each category, which is a common practice for feeding non-numeric data to a neural network. These embeddings are then concatenated and passed through a small neural network to create a combined representation.

**NN:** This is the main model. It takes the projected categorical embeddings and the standardized numerical data, concatenates them, and feeds the combined tensor through a Multi-Layer Perceptron (MLP) to generate the final prediction.

**CatEmbeddings Class**

This class handles all the categorical features in the dataset.

**Layers:** It contains multiple nn.Embedding layers, one for each categorical feature. An embedding layer converts a categorical variable into a dense vector of real numbers, which a neural network can process more effectively.

**Functionality:** It takes the categorical data, passes each feature through its corresponding embedding layer, and then concatenates all the resulting embedding vectors. This combined vector is then fed through a small sequential network (<code>nn.Sequential</code>) with two <code>nn.Linear</code> (fully connected) layers and a GELU activation function to project the data into a lower-dimensional space. This process transforms a patient's categorical information into a single, meaningful vector representation.

**NN Class**

This is the main neural network that takes both the categorical and continuous data as input. NN class consists of following:

**self.embeddings:** This is an instance of the CatEmbeddings class, which processes the categorical features as described above.

**concatenation:** It takes the output from self.embeddings and concatenates it with the continuous (numerical) input features.

**self.mlp:** This is a multi-layer perceptron (MLP). It's a sequential block of layers that includes:

- An ODST (Orthogonal Decision-Tree) layer: This is a specialized layer that helps the model learn complex, non-linear relationships in the data.

- <code>nn.BatchNorm1d</code>: A batch normalization layer that helps stabilize and speed up the training process by normalizing the inputs to the next layer.

- <code>nn.Dropout</code>: A dropout layer that randomly "turns off" a fraction of neurons during training to prevent the model from overfitting.

- <code>self.out</code>: A final <code>nn.Linear</code> layer that maps the output of the MLP to a single value, which represents the model's risk prediction for survival.

**Loss Function (LitNN Class)**

The LitNN class is the core of the model. It's built on PyTorch Lightning, which simplifies the training process. The most important part of this class is its custom loss function, defined in <code>calc_loss</code> and <code>get_full_loss</code>.

**Survival Loss:** The model uses a margin-based hinge loss to handle the survival data. This loss function compares pairs of patients, aiming to ensure that the patient with the longer survival time has a lower predicted risk score. This is a common approach for censored time-to-event data where not all patients have experienced the event.

**Fairness Loss:** The <code>get_race_losses</code> function is a critical component for addressing the fairness objective of the competition. It calculates the loss for each individual racial group and then computes the standard deviation of these losses. This standard deviation is added to the main loss function. By minimizing this term, the model is penalized for performing inconsistently across different racial groups, thereby encouraging more equitable predictions.

This is the PyTorch Lightning module that orchestrates the entire training and evaluation process. It's not a layer itself but a framework for managing the model. LitNN class consists of following:

**<code>self.model</code>:** An instance of the NN class, which is the core predictive model.

**<code>self.aux_cls</code>:** A small auxiliary network that performs an additional task to help the main model learn better representations. It consists of two nn.Linear layers with a GELU activation function.

**Training and Evaluation**

**Training Loop (<code>training_step</code>):** It defines how the model learns from data. It calculates a primary loss based on pairs of patient data (calc_loss) and an additional "fairness" loss (get_race_losses) that penalizes the model if its performance varies significantly across different racial groups. This is a crucial component for meeting the competition's objectives.

**Loss Calculation (<code>calc_loss</code>):** This function is the heart of the survival analysis. It compares all possible pairs of patients and uses a margin-based hinge loss to ensure that patients who survived longer receive a lower risk score prediction.

**Optimizer:** The configure_optimizers method uses the Adam optimizer with weight decay and a Cosine Annealing scheduler to manage the learning rate, which are standard practices for deep learning models.

**Training (<code>training_step</code>):** During training, the model calculates both the main survival loss and the race-based fairness loss. It also includes an auxiliary loss on a side task to improve the model's feature representation.

**Validation and Testing:** The validation_step and test_step functions calculate the main loss and, more importantly, the Stratified Concordance Index (C-index). As described in your Canvas document, this custom metric evaluates the model's accuracy while explicitly accounting for fairness across racial groups, confirming that the code is built specifically to address the competition's unique evaluation criteria.

**Fairness Metric (<code>get_race_losses</code>):** This function explicitly calculates the variance of the loss across different racial groups. By adding the square root of this variance to the main loss, the model is pushed to have a more consistent and fair performance for all groups.

**Evaluation (<code>validation_step</code>, <code>on_validation_epoch_end</code>):** It uses the Stratified Concordance Index (`C-index`), the competition's official metric, to evaluate the model's performance on the validation set. This ensures the model is not only accurate but also equitable.

Thanks for competition metric : [Reference](https://www.kaggle.com/code/cdeotte/pip-install-lifelines)

**Theories and Concepts**

**Hybrid Model Architecture**

The model is designed to process both categorical and continuous data, a common challenge with tabular datasets.

- **Categorical Embeddings:** The `CatEmbeddings` class is a crucial component. Instead of using one-hot encoding, which can create a very large and sparse feature space, it represents each category (e.g., 'race' or 'city') as a dense vector. This technique allows the model to learn meaningful, low-dimensional representations where categories with similar properties are closer in the embedding space. These learned embeddings are then concatenated and passed through a projection layer, a small neural network, to learn complex interactions between them.
- **Multilayer Perceptron (MLP):** The NN class acts as the main model. It takes the concatenated categorical embeddings and the numerical features and combines them. This combined vector is then fed into an MLP, which uses layers like `nn.Linear` to learn complex, non-linear relationships in the data. The model includes a specialized ODST layer (likely a form of Optimal Decision Tree Split), which is a modern, interpretable layer for tabular data.
- **Data Concatenation:** The NN class showcases a hybrid architecture. It takes the processed categorical embeddings and the numerical features (which have been separately normalized) and concatenates them into a single, comprehensive feature vector. This combined vector is then fed into a shared Multilayer Perceptron (MLP) to learn non-linear relationships between all the input features.
- **Advanced Layers:** The model includes specialized layers like ODST, which suggests the use of a modern, tree-based layer for tabular data. It also uses Batch Normalization (`nn.BatchNorm1d`) to stabilize and accelerate training, and Dropout (`nn.Dropout`) to prevent overfitting by randomly deactivating neurons during training.
- **Weight Initialization:** The NN class includes a loop to initialize weights using Xavier Normal initialization (`nn.init.xavier_normal_`). This is a common practice that helps the network converge faster and avoids problems with vanishing or exploding gradients.

**Advanced Loss Function for Survival Data**

The most unique part of this code is the custom `calc_loss` function, which is designed specifically for time-to-event data.

- **Ranking-Based Loss:** Standard Mean Squared Error (MSE) is not suitable for survival analysis because the model needs to learn to correctly rank subjects based on their survival time, not just predict the exact time. For example, it is more important to predict that Subject A will survive longer than Subject B than it is to get the exact survival time of each. The code achieves this by using a pairwise comparison approach, creating all possible pairs of subjects and comparing their predicted survival times.
- **Handling Censored Data:** A key challenge in survival analysis is handling censored data—cases where the event (e.g., death) has not yet occurred at the end of the study. The `get_mask` function identifies invalid pairs for comparison. For example, if Subject A is still alive after 5 years and Subject B dies after 3 years, we know Subject A outlived Subject B. However, if Subject C is still alive after 2 years and Subject D is still alive after 4 years, we cannot definitively rank them, and these pairs are masked out from the loss calculation.
- **Margin-Based Hinge Loss:** The loss function itself is a variation of a Hinge Loss, which is commonly used in support vector machines and ranking problems. The margin hyperparameter ensures that the model is penalized only when the difference between the predicted outcomes of a pair of subjects is less than this margin, forcing the model to be confident in its rankings.

**Fairness and Regularization**

The model incorporates concepts to promote fairness and prevent overfitting.

- **Fairness Regularization:** The `get_race_losses` function is a sophisticated attempt to ensure the model performs equitably across different subgroups. By calculating the loss for each racial group and then adding the standard deviation of these losses to the main loss function, the model is penalized for having widely different performance metrics across these groups. This encourages the model to be more robust and fair, preventing it from favoring one group over another.
- **Regularization and Activation:** The model uses several key techniques to prevent overfitting:
  - **Batch Normalization (nn.BatchNorm1d):** This layer normalizes the activations of the previous layer, helping to stabilize and accelerate training.
  - **Dropout (nn.Dropout):** During training, this layer randomly "drops out" a percentage of neurons, forcing the network to learn more robust and redundant features.
  - **GELU Activation:** The `GELU` (Gaussian Error Linear Unit) is an activation function used to introduce non-linearity into the network. It's often seen as a smoother alternative to the popular ReLU function.
- **Auxiliary Loss:** The `aux_cls` (auxiliary classifier) and aux_loss are part of a multi-task learning approach. The model is trained on a primary task (survival time prediction) but also on a secondary, related task (the `efs_time` prediction on embeddings). This can help the model learn more robust and generalizable features in the embeddings.

**PyTorch Lightning and Optimizer**

The code uses PyTorch Lightning to streamline the training process and includes advanced optimization techniques.

- **PyTorch Lightning:** LitNN continues from `pl.LightningModule`. This framework abstracts away the standard boilerplate code for training and evaluation loops, simplifying the code and making it easier to manage.
- **Adam Optimizer and Cosine Annealing**: The configure_optimizers method uses the popular Adam optimizer, which is efficient for deep learning. It also includes `CosineAnnealingLR`, a learning rate scheduler that cyclically adjusts the learning rate. This can help the model escape local minima and find a better, more generalized solution during training.

## Model Architecture

In [4]:
class CatEmbeddings(nn.Module):
    """
    Embedding module for the categorical dataframe.
    """
    def __init__(
        self,
        projection_dim: int,
        categorical_cardinality: List[int],
        embedding_dim: int
    ):
        """
        projection_dim: The dimension of the final output after projecting the concatenated embeddings into a lower-dimensional space.
        categorical_cardinality: A list where each element represents the number of unique categories (cardinality) in each categorical feature.
        embedding_dim: The size of the embedding space for each categorical feature.
        self.embeddings: list of embedding layers for each categorical feature.
        self.projection: sequential neural network that goes from the embedding to the output projection dimension with GELU activation.
        """
        super(CatEmbeddings, self).__init__()
        self.embeddings = nn.ModuleList([
            nn.Embedding(cardinality, embedding_dim)
            for cardinality in categorical_cardinality
        ])
        self.projection = nn.Sequential(
            nn.Linear(embedding_dim * len(categorical_cardinality), projection_dim),
            nn.GELU(),
            nn.Linear(projection_dim, projection_dim)
        )

    def forward(self, x_cat):
        """
        Apply the projection on concatened embeddings that contains all categorical features.
        """
        x_cat = [embedding(x_cat[:, i]) for i, embedding in enumerate(self.embeddings)]
        x_cat = torch.cat(x_cat, dim=1)
        return self.projection(x_cat)


class NN(nn.Module):
    """
    Train a model on both categorical embeddings and numerical data.
    """
    def __init__(
            self,
            continuous_dim: int,
            categorical_cardinality: List[int],
            embedding_dim: int,
            projection_dim: int,
            hidden_dim: int,
            dropout: float = 0
    ):
        """
        continuous_dim: The number of continuous features.
        categorical_cardinality: A list of integers representing the number of unique categories in each categorical feature.
        embedding_dim: The dimensionality of the embedding space for each categorical feature.
        projection_dim: The size of the projected output space for the categorical embeddings.
        hidden_dim: The number of neurons in the hidden layer of the MLP.
        dropout: The dropout rate applied in the network.
        self.embeddings: previous embeddings for categorical data.
        self.mlp: defines an MLP model with an ODST layer followed by batch normalization and dropout.
        self.out: linear output layer that maps the output of the MLP to a single value
        self.dropout: defines dropout
        Weights initialization with xavier normal algorithm and biases with zeros.
        """
        super(NN, self).__init__()
        self.embeddings = CatEmbeddings(projection_dim, categorical_cardinality, embedding_dim)
        self.mlp = nn.Sequential(
            ODST(projection_dim + continuous_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout)
        )
        self.out = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(dropout)

        # initialize weights
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x_cat, x_cont):
        """
        Create embedding layers for categorical data, concatenate with continous variables.
        Add dropout and goes through MLP and return raw output and 1-dimensional output as well.
        """
        x = self.embeddings(x_cat)
        x = torch.cat([x, x_cont], dim=1)
        x = self.dropout(x)
        x = self.mlp(x)
        return self.out(x), x


@functools.lru_cache
def combinations(N):
    """
    calculates all possible 2-combinations (pairs) of a tensor of indices from 0 to N-1, 
    and caches the result using functools.lru_cache for optimization
    """
    ind = torch.arange(N)
    comb = torch.combinations(ind, r=2)
    return comb.cuda()


class LitNN(pl.LightningModule):
    """
    Main Model creation and losses definition to fully train the model.
    """
    def __init__(
            self,
            continuous_dim: int,
            categorical_cardinality: List[int],
            embedding_dim: int,
            projection_dim: int,
            hidden_dim: int,
            lr: float = 1e-3,
            dropout: float = 0.2,
            weight_decay: float = 1e-3,
            aux_weight: float = 0.1,
            margin: float = 0.5,
            race_index: int = 0
    ):
        """
        continuous_dim: The number of continuous input features.
        categorical_cardinality: A list of integers, where each element corresponds to the number of unique categories for each categorical feature.
        embedding_dim: The dimension of the embeddings for the categorical features.
        projection_dim: The dimension of the projected space after embedding concatenation.
        hidden_dim: The size of the hidden layers in the feedforward network (MLP).
        lr: The learning rate for the optimizer.
        dropout: Dropout probability to avoid overfitting.
        weight_decay: The L2 regularization term for the optimizer.
        aux_weight: Weight used for auxiliary tasks.
        margin: Margin used in some loss functions.
        race_index: An index that refer to race_group in the input data.
        """
        super(LitNN, self).__init__()
        self.save_hyperparameters()

        # Creates an instance of the NN model defined above
        self.model = NN(
            continuous_dim=self.hparams.continuous_dim,
            categorical_cardinality=self.hparams.categorical_cardinality,
            embedding_dim=self.hparams.embedding_dim,
            projection_dim=self.hparams.projection_dim,
            hidden_dim=self.hparams.hidden_dim,
            dropout=self.hparams.dropout
        )
        self.targets = []

        # Defines a small feedforward neural network that performs an auxiliary task with 1-dimensional output
        self.aux_cls = nn.Sequential(
            nn.Linear(self.hparams.hidden_dim, self.hparams.hidden_dim // 3),
            nn.GELU(),
            nn.Linear(self.hparams.hidden_dim // 3, 1)
        )

    def on_before_optimizer_step(self, optimizer):
        """
        Compute the 2-norm for each layer
        If using mixed precision, the gradients are already unscaled here
        """
        norms = grad_norm(self.model, norm_type=2)
        self.log_dict(norms)

    def forward(self, x_cat, x_cont):
        """
        Forward pass that outputs the 1-dimensional prediction and the embeddings (raw output)
        """
        x, emb = self.model(x_cat, x_cont)
        return x.squeeze(1), emb

    def training_step(self, batch, batch_idx):
        """
        defines how the model processes each batch of data during training.
        A batch is a combination of : categorical data, continuous data, efs_time (y) and efs event.
        y_hat is the efs_time prediction on all data and aux_pred is auxiliary prediction on embeddings.
        Calculates loss and race_group loss on full data.
        Auxiliary loss is calculated with an event mask, ignoring efs=0 predictions and taking the average.
        Returns loss and aux_loss multiplied by weight defined above.
        """
        x_cat, x_cont, y, efs = batch
        y_hat, emb = self(x_cat, x_cont)
        aux_pred = self.aux_cls(emb).squeeze(1)
        loss, race_loss = self.get_full_loss(efs, x_cat, y, y_hat)
        aux_loss = nn.functional.mse_loss(aux_pred, y, reduction='none')
        aux_mask = efs == 1
        aux_loss = (aux_loss * aux_mask).sum() / aux_mask.sum()
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("race_loss", race_loss, on_epoch=True, prog_bar=True, logger=True, on_step=False)
        self.log("aux_loss", aux_loss, on_epoch=True, prog_bar=True, logger=True, on_step=False)
        return loss + aux_loss * self.hparams.aux_weight

    def get_full_loss(self, efs, x_cat, y, y_hat):
        """
        Output loss and race_group loss.
        """
        loss = self.calc_loss(y, y_hat, efs)
        race_loss = self.get_race_losses(efs, x_cat, y, y_hat)
        loss += 0.1 * race_loss
        return loss, race_loss

    def get_race_losses(self, efs, x_cat, y, y_hat):
        """
        Calculate loss for each race_group based on deviation/variance.
        """
        races = torch.unique(x_cat[:, self.hparams.race_index])
        race_losses = []
        for race in races:
            ind = x_cat[:, self.hparams.race_index] == race
            race_losses.append(self.calc_loss(y[ind], y_hat[ind], efs[ind]))
        race_loss = sum(race_losses) / len(race_losses)
        races_loss_std = sum((r - race_loss)**2 for r in race_losses) / len(race_losses)
        return torch.sqrt(races_loss_std)

    def calc_loss(self, y, y_hat, efs):
        """
        Most important part of the model : loss function used for training.
        We face survival data with event indicators along with time-to-event.

        This function computes the main loss by the following the steps :
        * create all data pairs with "combinations" function (= all "two subjects" combinations)
        * make sure that we have at least 1 event in each pair
        * convert y to +1 or -1 depending on the correct ranking
        * loss is computed using a margin-based hinge loss
        * mask is applied to ensure only valid pairs are being used (censored data can't be ranked with event in some cases)
        * average loss on all pairs is returned
        """
        N = y.shape[0]
        comb = combinations(N)
        comb = comb[(efs[comb[:, 0]] == 1) | (efs[comb[:, 1]] == 1)]
        pred_left = y_hat[comb[:, 0]]
        pred_right = y_hat[comb[:, 1]]
        y_left = y[comb[:, 0]]
        y_right = y[comb[:, 1]]
        y = 2 * (y_left > y_right).int() - 1
        loss = nn.functional.relu(-y * (pred_left - pred_right) + self.hparams.margin)
        mask = self.get_mask(comb, efs, y_left, y_right)
        loss = (loss.double() * (mask.double())).sum() / mask.sum()
        return loss

    def get_mask(self, comb, efs, y_left, y_right):
        """
        Defines all invalid comparisons :
        * Case 1: "Left outlived Right" but Right is censored
        * Case 2: "Right outlived Left" but Left is censored
        Masks for case 1 and case 2 are combined using |= operator and inverted using ~ to create a "valid pair mask"
        """
        left_outlived = y_left >= y_right
        left_1_right_0 = (efs[comb[:, 0]] == 1) & (efs[comb[:, 1]] == 0)
        mask2 = (left_outlived & left_1_right_0)
        right_outlived = y_right >= y_left
        right_1_left_0 = (efs[comb[:, 1]] == 1) & (efs[comb[:, 0]] == 0)
        mask2 |= (right_outlived & right_1_left_0)
        mask2 = ~mask2
        mask = mask2
        return mask

    def validation_step(self, batch, batch_idx):
        """
        This method defines how the model processes each batch during validation
        """
        x_cat, x_cont, y, efs = batch
        y_hat, emb = self(x_cat, x_cont)
        loss, race_loss = self.get_full_loss(efs, x_cat, y, y_hat)
        self.targets.append([y, y_hat.detach(), efs, x_cat[:, self.hparams.race_index]])
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_validation_epoch_end(self):
        """
        At the end of the validation epoch, it computes and logs the concordance index
        """
        cindex, metric = self._calc_cindex()
        self.log("cindex", metric, on_epoch=True, prog_bar=True, logger=True)
        self.log("cindex_simple", cindex, on_epoch=True, prog_bar=True, logger=True)
        self.targets.clear()

    def _calc_cindex(self):
        """
        Calculate c-index accounting for each race_group or global.
        """
        y = torch.cat([t[0] for t in self.targets]).cpu().numpy()
        y_hat = torch.cat([t[1] for t in self.targets]).cpu().numpy()
        efs = torch.cat([t[2] for t in self.targets]).cpu().numpy()
        races = torch.cat([t[3] for t in self.targets]).cpu().numpy()
        metric = self._metric(efs, races, y, y_hat)
        cindex = concordance_index(y, y_hat, efs)
        return cindex, metric

    def _metric(self, efs, races, y, y_hat):
        """
        Calculate c-index accounting for each race_group
        """
        metric_list = []
        for race in np.unique(races):
            y_ = y[races == race]
            y_hat_ = y_hat[races == race]
            efs_ = efs[races == race]
            metric_list.append(concordance_index(y_, y_hat_, efs_))
        metric = float(np.mean(metric_list) - np.sqrt(np.var(metric_list)))
        return metric

    def test_step(self, batch, batch_idx):
        """
        Same as training step but to log test data
        """
        x_cat, x_cont, y, efs = batch
        y_hat, emb = self(x_cat, x_cont)
        loss, race_loss = self.get_full_loss(efs, x_cat, y, y_hat)
        self.targets.append([y, y_hat.detach(), efs, x_cat[:, self.hparams.race_index]])
        self.log("test_loss", loss)
        return loss

    def on_test_epoch_end(self) -> None:
        """
        At the end of the test epoch, calculates and logs the concordance index for the test set
        """
        cindex, metric = self._calc_cindex()
        self.log("test_cindex", metric, on_epoch=True, prog_bar=True, logger=True)
        self.log("test_cindex_simple", cindex, on_epoch=True, prog_bar=True, logger=True)
        self.targets.clear()


    def configure_optimizers(self):
        """
        configures the optimizer and learning rate scheduler:
        * Optimizer: Adam optimizer with weight decay (L2 regularization).
        * Scheduler: Cosine Annealing scheduler, which adjusts the learning rate according to a cosine curve.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler_config = {
            "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=45,
                eta_min=6e-3
            ),
            "interval": "epoch",
            "frequency": 1,
            "strict": False,
        }

        return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

## Model Training and Validation

**Main Function Workflow**

The main function is the primary entry point of the script, controlling the entire model training process. Here's what it does step-by-step:

**Data Preparation:** It loads the training and testing datasets. It then adds placeholder columns, <code>efs_time</code> and <code>efs</code>, to the test data with a value of 1. This is done to ensure the test data has the same structure as the training data for later processing, even though these values will be overwritten during the actual prediction phase.

**Stratified K-Fold Cross-Validation:** This is a crucial part of the script. It uses StratifiedKFold with 5 splits to divide the original training data. The stratification criterion is a combination of <code>race_group</code> and whether the patient's <code>age_at_hct</code> is equal to 0.044 (indicating a newborn). This ensures that each fold has a similar distribution of these key features, which is important for training a robust and fair model, especially given the <code>race_group</code> fairness metric.

**Training Loop:** The code iterates through each of the 5 folds.

- **Data Preprocessing:** For each fold, it splits the data into a training set (<code>train</code>) and a validation set (<code>val</code>). The <code>preprocess_data</code> function is then called to handle missing values, standardize numerical features, and perform label encoding for categorical features, creating PyTorch DataLoaders.
- **Model Training and Validation:** The <code>train_final</code> function is called to train the LitNN model on the current fold's training data (<code>dl_train</code>). After fitting, the <code>trainer.test()</code> method is called on the validation data (<code>dl_val)</code>.
- **Model Storage:** The trained model from each fold is appended to the models list, and the corresponding data transformers are saved to <code>transformers_list</code>.

**Final Return:** After the loop completes, the function returns a tuple containing the list of trained models, the preprocessed test data, the original training data, the list of categorical columns, and the list of data transformers. The script does not make predictions on the test dataset within this function; it only prepares and trains the models, which can then be used later for prediction.

**Helper function**

The `train_final` Function covers the model training logic for a single data fold.

**Hyperparameter Initialization**
It defines the default hyperparameters for the `LitNN` model, as `embedding_dim`, `hidden_dim`, `lr`, `dropout`, `aux_weight`, `margin`, and `weight_decay`. These values can be overridden if a specific `hparams` dictionary is provided.

**Model Initialization**
It creates an instance of the `LitNN` model, passing in the number of continuous features, categorical cardinalities, and the hyperparameters. It also finds the index of the `race_group` column, which is crucial for calculating the fairness-aware loss.

**PyTorch Lightning Trainer**
It sets up the `pl.Trainer` with specific configurations for training on a GPU and logging.

- `accelerator='cuda'`: Configures the trainer to use a **GPU** for faster computation.
- `max_epochs=50`: The model will train for a maximum of **50 epochs**.
- `log_every_n_steps=6`: It logs training metrics every **6 batches**.
- `Callbacks`: It includes several callbacks to monitor and improve training:
    - `ModelCheckpoint`: Saves the best-performing model based on the lowest **validation loss** (`val_loss`).
    - `LearningRateMonitor`: Logs the learning rate changes over the training epochs.
    - `TQDMProgressBar`: Displays a progress bar during training.
    - `StochasticWeightAveraging`: A technique that averages the model's weights from a specific epoch (`swa_epoch_start=40`) to improve generalization.

**Model Fitting and Testing**
The `trainer.fit` method initiates the training process using the training data (`dl_train`). After training is complete, the `trainer.test` method evaluates the final model on the validation set (`dl_val`) and logs its performance. Finally, it returns the trained model in evaluation mode (`model.eval()`) for future inference on new data.

In [5]:
def train_final(X_num_train, dl_train, dl_val, transformers, hparams=None, categorical_cols=None):
    """
    Defines model hyperparameters and fit the model.
    """
    if hparams is None:
        hparams = {            
            "embedding_dim": 16,
            "projection_dim": 128,
            "hidden_dim": 64,
            "lr": 0.01,
            "dropout": 0.01,
            "aux_weight": 0.2,
            "margin": 0.2,
            "weight_decay": 0.0002
        }
    model = LitNN(
        continuous_dim=X_num_train.shape[1],
        categorical_cardinality=[len(t.classes_) for t in transformers],
        race_index=categorical_cols.index("race_group"),
        **hparams
    )
    checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=1)
    # Set up trainer
    trainer = pl.Trainer(
        accelerator='cuda',
        max_epochs=50,
        log_every_n_steps=6,
        callbacks=[
            checkpoint_callback,
            LearningRateMonitor(logging_interval='epoch'),
            TQDMProgressBar(),
            StochasticWeightAveraging(swa_lrs=1e-5, swa_epoch_start=40, annealing_epochs=15)
        ],
    )
    trainer.fit(model, dl_train)
    trainer.test(model, dl_val)
    return model.eval()


def main(hparams=None):
    """
    Main function to train the model.
    The steps are as follows:
    * Load data and fill efs and efs_time for test data with 1
    * Get categorical and numerical columns
    * Split the train data on the stratified criterion: race_group * newborns yes/no
    * Preprocess the fold data (create dataloaders)
    * Train the model
    * Return necessary data for prediction
    """
    test, train_original = load_data()
    test['efs_time'] = 1
    test['efs'] = 1
    categorical_cols, numerical = get_feature_types(train_original)
    kf = StratifiedKFold(n_splits=5, shuffle=True)
    models = []
    transformers_list = []
    
    for i, (train_index, val_index) in enumerate(
        kf.split(
            train_original, 
            train_original.race_group.astype(str) + (train_original.age_at_hct == 0.044).astype(str)
        )
    ):
        train = train_original.iloc[train_index].copy()
        val = train_original.iloc[val_index].copy()
        X_cat_val, X_num_train, X_num_val, dl_train, dl_val, transformers = preprocess_data(train, val)
        model = train_final(X_num_train, dl_train, dl_val, transformers, categorical_cols=categorical_cols, hparams=hparams)
        models.append(model)
        transformers_list.append(transformers)
    
    return models, test, train_original, categorical_cols, transformers_list

hparams = None
res = main(hparams)

Test shape: (3, 59)
Train shape: (28800, 61)


2025-09-18 23:06:35.757557: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1758236795.780704     375 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1758236795.787794     375 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Training: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

# Model Hyperparameters Tuning

**Hyperparameter Tuning Steps with Optuna**

Hyperparameter tuning is managed by the **Optuna** framework. The `objective` function integrates the model with Optuna. This function is called repeatedly, with each call representing a "trial" with a different set of hyperparameters. The goal is to maximize a composite metric to find the best combination of parameters. The process works as follows:

1.  **Optuna Study**: An Optuna study is created with the objective of **"maximizing"** a metric. In this case, the target is a **composite metric** that combines the concordance index and validation loss.
2.  **Trials**: The `study.optimize()` function iteratively calls the `objective` function for a specified number of `n_trials` (**20** in this example).
3.  **Hyperparameter Suggestion**: Inside each trial, the `trial` object suggests values for the following hyperparameters from predefined ranges.
4.  **Model Training**: A new `LitNN` model is instantiated with the suggested hyperparameters. The model is trained using a `PyTorch Lightning Trainer` on a **single data fold** to save time. An `EarlyStopping` callback is used to stop training if the validation loss does not improve for a certain number of epochs (`patience=5`), which makes the process more efficient.
5.  **Metric Calculation**: Once training is complete, the `objective` function retrieves the final `val_loss`, `cindex_simple`, and `cindex` from the trainer's metrics. It then computes a composite metric as `(cindex + cindex_simple) / (1 + val_loss)`.
6.  **Optimization**: Optuna analyzes the returned composite metric from each trial and uses an intelligent algorithm to decide which set of hyperparameters to try next, aiming to find the combination that leads to the highest metric. After all trials are complete, the study identifies and reports the best set of hyperparameters found.

**List of Hyperparameters**

- `embedding_dim`: Suggested as an integer between 8 and 32.
- `projection_dim`: Suggested as an integer between 32 and 256.
- `hidden_dim`: Suggested as an integer between 32 and 256.
- `lr`: Suggested as a floating-point number between 1e-5 and 1e-1 on a logarithmic scale.
- `dropout`: Suggested as a floating-point number between 0.1 and 0.5.
- `aux_weight`: Suggested as a floating-point number between 0.1 and 0.5.
- `margin`: Suggested as a floating-point number between 0.1 and 0.5.
- `weight_decay`: Suggested as a floating-point number between 1e-5 and 1e-3 on a logarithmic scale.

**Theories and Concepts**

**Hyperparameter Tuning with Optuna**

This step uses optuna (bayasien hyperparameters tuning) to automated hyperparameter tuning. Instead of manually trying different combinations of hyperparameters (like learning rate, dropout, or layer dimensions), this process uses a library called Optuna to find the optimal set for you.

- *Optuna Framework:** Optuna is a specialized, open-source hyperparameter optimization framework. It automates the search for the best set of model parameters to maximize or minimize a specific objective. It uses sophisticated sampling algorithms to intelligently explore the hyperparameter space, making it much more efficient than a simple grid search or random search.
- **Objective Function:** The `objective` function is the core of the Optuna process. It acts as a black box that takes a trial (a single set of suggested hyperparameters from Optuna) and returns a performance metric. Optuna's goal is to find the combination of hyperparameters that maximizes the value returned by this function. In this case, the function trains a model with the trial's suggested parameters and returns a custom composite metric.
- **Trial:** The `optuna.Trial` object represents one single run of the training process with a unique combination of hyperparameters. Optuna suggests values for each hyperparameter within a defined range (e.g., `trial.suggest_float("lr", 1e-5, 1e-1, log=True`)). This allows the algorithm to explore a continuous or discrete range of values.

**Advanced Training Concepts**

The objective function uses several callbacks and metrics to make the tuning process more efficient and effective.

- **Early Stopping:** The `EarlyStopping` callback is a crucial technique to prevent overfitting. It monitors the validation loss (`val_loss`) and stops the training process early if the model's performance on the validation set stops improving for a certain number of epochs (`patience=5`). This saves significant computational resources and prevents the model from memorizing the training data.
- **Composite Metric:** The objective function returns a composite metric calculated as `(cindex + cindex_simple) / (1 + val_loss)`. This is a clever way to combine multiple, often conflicting, evaluation goals into a single value that Optuna can optimize. By maximizing this composite metric, you are encouraging the model to achieve a high concordance index (c-index), which is a key metric for survival analysis, while also keeping the validation loss low.
- **Logging:** The `TensorBoardLogger` is used to log the training and validation metrics for each trial. This creates a detailed record of the entire tuning process, which can be visualized in TensorBoard to analyze the performance of different hyperparameter combinations and better understand the model's behavior.

**Revise Model Architecture**

Revise model architecture for suggest hyperparameters for hyperparameters tuning

In [None]:
class CatEmbeddings(nn.Module):
    """
    Embedding module for the categorical dataframe.
    """
    def __init__(
        self,
        projection_dim: int,
        categorical_cardinality: List[int],
        embedding_dim: int
    ):
        super(CatEmbeddings, self).__init__()
        self.embeddings = nn.ModuleList([
            nn.Embedding(cardinality, embedding_dim)
            for cardinality in categorical_cardinality
        ])
        self.projection = nn.Sequential(
            nn.Linear(embedding_dim * len(categorical_cardinality), projection_dim),
            nn.GELU(),
            nn.Linear(projection_dim, projection_dim)
        )

    def forward(self, x_cat):
        x_cat = [embedding(x_cat[:, i]) for i, embedding in enumerate(self.embeddings)]
        x_cat = torch.cat(x_cat, dim=1)
        return self.projection(x_cat)

class NN(nn.Module):
    """
    Train a model on both categorical embeddings and numerical data.
    """
    def __init__(
            self,
            continuous_dim: int,
            categorical_cardinality: List[int],
            embedding_dim: int,
            projection_dim: int,
            hidden_dim: int,
            dropout: float = 0
    ):
        super(NN, self).__init__()
        self.embeddings = CatEmbeddings(projection_dim, categorical_cardinality, embedding_dim)
        # Replace ODST with nn.Linear as a fallback; replace with actual ODST if available
        self.mlp = nn.Sequential(
            nn.Linear(projection_dim + continuous_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout)
        )
        self.out = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(dropout)

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x_cat, x_cont):
        x = self.embeddings(x_cat)
        x = torch.cat([x, x_cont], dim=1)
        x = self.dropout(x)
        x = self.mlp(x)
        return self.out(x), x

@functools.lru_cache
def combinations(N):
    ind = torch.arange(N)
    comb = torch.combinations(ind, r=2)
    return comb.cuda()

class LitNN(pl.LightningModule):
    def __init__(
            self,
            continuous_dim: int,
            categorical_cardinality: List[int],
            embedding_dim: int,
            projection_dim: int,
            hidden_dim: int,
            lr: float = 1e-3,
            dropout: float = 0.2,
            weight_decay: float = 1e-3,
            aux_weight: float = 0.1,
            margin: float = 0.5,
            race_index: int = 0
    ):
        super(LitNN, self).__init__()
        self.save_hyperparameters()
        self.model = NN(
            continuous_dim=self.hparams.continuous_dim,
            categorical_cardinality=self.hparams.categorical_cardinality,
            embedding_dim=self.hparams.embedding_dim,
            projection_dim=self.hparams.projection_dim,
            hidden_dim=self.hparams.hidden_dim,
            dropout=self.hparams.dropout
        )
        self.targets = []
        self.aux_cls = nn.Sequential(
            nn.Linear(self.hparams.hidden_dim, self.hparams.hidden_dim // 3),
            nn.GELU(),
            nn.Linear(self.hparams.hidden_dim // 3, 1)
        )

    def on_before_optimizer_step(self, optimizer):
        norms = grad_norm(self.model, norm_type=2)
        self.log_dict(norms)

    def forward(self, x_cat, x_cont):
        x, emb = self.model(x_cat, x_cont)
        return x.squeeze(1), emb

    def training_step(self, batch, batch_idx):
        x_cat, x_cont, y, efs = batch
        y_hat, emb = self(x_cat, x_cont)
        aux_pred = self.aux_cls(emb).squeeze(1)
        loss, race_loss = self.get_full_loss(efs, x_cat, y, y_hat)
        aux_loss = F.mse_loss(aux_pred, y, reduction='none')
        aux_mask = efs == 1
        aux_loss = (aux_loss * aux_mask).sum() / aux_mask.sum()
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("race_loss", race_loss, on_epoch=True, prog_bar=True, logger=True, on_step=False)
        self.log("aux_loss", aux_loss, on_epoch=True, prog_bar=True, logger=True, on_step=False)
        return loss + aux_loss * self.hparams.aux_weight

    def get_full_loss(self, efs, x_cat, y, y_hat):
        loss = self.calc_loss(y, y_hat, efs)
        race_loss = self.get_race_losses(efs, x_cat, y, y_hat)
        loss += 0.1 * race_loss
        return loss, race_loss

    def get_race_losses(self, efs, x_cat, y, y_hat):
        races = torch.unique(x_cat[:, self.hparams.race_index])
        race_losses = []
        for race in races:
            ind = x_cat[:, self.hparams.race_index] == race
            race_losses.append(self.calc_loss(y[ind], y_hat[ind], efs[ind]))
        race_loss = sum(race_losses) / len(race_losses)
        races_loss_std = sum((r - race_loss)**2 for r in race_losses) / len(race_losses)
        return torch.sqrt(races_loss_std)

    def calc_loss(self, y, y_hat, efs):
        N = y.shape[0]
        comb = combinations(N)
        comb = comb[(efs[comb[:, 0]] == 1) | (efs[comb[:, 1]] == 1)]
        pred_left = y_hat[comb[:, 0]]
        pred_right = y_hat[comb[:, 1]]
        y_left = y[comb[:, 0]]
        y_right = y[comb[:, 1]]
        y = 2 * (y_left > y_right).int() - 1
        loss = F.relu(-y * (pred_left - pred_right) + self.hparams.margin)
        mask = self.get_mask(comb, efs, y_left, y_right)
        loss = (loss.double() * mask.double()).sum() / mask.sum()
        return loss

    def get_mask(self, comb, efs, y_left, y_right):
        left_outlived = y_left >= y_right
        left_1_right_0 = (efs[comb[:, 0]] == 1) & (efs[comb[:, 1]] == 0)
        mask2 = (left_outlived & left_1_right_0)
        right_outlived = y_right >= y_left
        right_1_left_0 = (efs[comb[:, 1]] == 1) & (efs[comb[:, 0]] == 0)
        mask2 |= (right_outlived & right_1_left_0)
        mask2 = ~mask2
        return mask2

    def validation_step(self, batch, batch_idx):
        x_cat, x_cont, y, efs = batch
        y_hat, emb = self(x_cat, x_cont)
        loss, race_loss = self.get_full_loss(efs, x_cat, y, y_hat)
        self.targets.append([y, y_hat.detach(), efs, x_cat[:, self.hparams.race_index]])
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_validation_epoch_end(self):
        cindex, metric = self._calc_cindex()
        self.log("cindex", metric, on_epoch=True, prog_bar=True, logger=True)
        self.log("cindex_simple", cindex, on_epoch=True, prog_bar=True, logger=True)
        self.targets.clear()

    def _calc_cindex(self):
        y = torch.cat([t[0] for t in self.targets]).cpu().numpy()
        y_hat = torch.cat([t[1] for t in self.targets]).cpu().numpy()
        efs = torch.cat([t[2] for t in self.targets]).cpu().numpy()
        races = torch.cat([t[3] for t in self.targets]).cpu().numpy()
        metric = self._metric(efs, races, y, y_hat)
        cindex = concordance_index(y, y_hat, efs)
        return cindex, metric

    def _metric(self, efs, races, y, y_hat):
        metric_list = []
        for race in np.unique(races):
            y_ = y[races == race]
            y_hat_ = y_hat[races == race]
            efs_ = efs[races == race]
            metric_list.append(concordance_index(y_, y_hat_, efs_))
        metric = float(np.mean(metric_list) - np.sqrt(np.var(metric_list)))
        return metric

    def test_step(self, batch, batch_idx):
        x_cat, x_cont, y, efs = batch
        y_hat, emb = self(x_cat, x_cont)
        loss, race_loss = self.get_full_loss(efs, x_cat, y, y_hat)
        self.targets.append([y, y_hat.detach(), efs, x_cat[:, self.hparams.race_index]])
        self.log("test_loss", loss)
        return loss

    def on_test_epoch_end(self):
        cindex, metric = self._calc_cindex()
        self.log("test_cindex", metric, on_epoch=True, prog_bar=True, logger=True)
        self.log("test_cindex_simple", cindex, on_epoch=True, prog_bar=True, logger=True)
        self.targets.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler_config = {
            "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=45,
                eta_min=6e-3
            ),
            "interval": "epoch",
            "frequency": 1,
            "strict": False,
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

**Define Objective and Tune Hyperparameters**

In [10]:
def objective(trial: optuna.Trial):
    """
    Optuna objective function to optimize hyperparameters for highest cindex and cindex_simple, lowest val_loss.
    """
    hparams = {
        "embedding_dim": trial.suggest_int("embedding_dim", 8, 32),
        "projection_dim": trial.suggest_int("projection_dim", 32, 256),
        "hidden_dim": trial.suggest_int("hidden_dim", 32, 256),
        "lr": trial.suggest_float("lr", 1e-5, 1e-1, log=True),
        "dropout": trial.suggest_float("dropout", 0.1, 0.5),
        "aux_weight": trial.suggest_float("aux_weight", 0.1, 0.5),
        "margin": trial.suggest_float("margin", 0.1, 0.5),
        "weight_decay": trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
    }

    # Load and preprocess data
    test, train_original = load_data()
    categorical_cols, numerical_cols = get_feature_types(train_original)

    # Use a single fold for tuning
    kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    train_index, val_index = next(iter(kf.split(
        train_original, 
        train_original.race_group.astype(str) + (train_original.age_at_hct == 0.044).astype(str)
    )))
    
    train_df = train_original.iloc[train_index].copy()
    val_df = train_original.iloc[val_index].copy()
    
    X_cat_val, X_num_train, X_num_val, dl_train, dl_val, transformers = preprocess_data(train_df, val_df)
    
    # Compute categorical cardinality
    categorical_cardinality = [len(train_df[col].unique()) for col in categorical_cols]
    
    # Define model with trial hyperparameters
    model = LitNN(
        continuous_dim=X_num_train.shape[1],
        categorical_cardinality=categorical_cardinality,
        race_index=categorical_cols.index("race_group"),
        **hparams
    )

    # Set up trainer
    checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=1)
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=0.00,
        patience=5,
        verbose=False,
        mode="min"
    )
    trainer = pl.Trainer(
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        max_epochs=50,
        callbacks=[
            early_stop_callback,
            LearningRateMonitor(logging_interval='epoch'),
            TQDMProgressBar(),
        ],
        logger=pl.loggers.TensorBoardLogger("optuna_logs", name="litnn_model_tuning")
    )
    
    # Fit the model
    trainer.fit(model, dl_train, dl_val)
    
    # Compute composite metric: maximize (cindex + cindex_simple) / (1 + val_loss)
    val_loss = trainer.callback_metrics["val_loss"].item()
    cindex = trainer.callback_metrics.get("cindex_simple", torch.tensor(0.0)).item()
    cindex_simple = trainer.callback_metrics.get("cindex", torch.tensor(0.0)).item()
    composite_metric = (cindex + cindex_simple) / (1 + val_loss)
    
    return composite_metric

def main():
    """
    Main function to perform hyperparameter tuning and train the final model.
    """
    # Step 1: Hyperparameter tuning with Optuna
    study = optuna.create_study(direction="maximize")  # Maximize composite metric
    print("Starting hyperparameter tuning with Optuna...")
    study.optimize(objective, n_trials=20)       ###########
    print("Tuning finished.")
    
    # Print best trial
    best_trial = study.best_trial
    print(f"Best Trial's Value (Composite Metric): {best_trial.value}")
    print("Best Hyperparameters:")
    for key, value in best_trial.params.items():
        print(f"  {key}: {value}")

# Run the main function
main()
print("done")

[I 2025-09-18 23:28:04,214] A new study created in memory with name: no-name-89bfb4c0-2072-4857-8da2-1fecdd038d5f


Starting hyperparameter tuning with Optuna...
Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:28:50,460] Trial 0 finished with value: 0.9994189485288872 and parameters: {'embedding_dim': 32, 'projection_dim': 234, 'hidden_dim': 149, 'lr': 0.054418283242056706, 'dropout': 0.2296179638417638, 'aux_weight': 0.3930663389566441, 'margin': 0.4769248607528731, 'weight_decay': 5.080728462171927e-05}. Best is trial 0 with value: 0.9994189485288872.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:29:41,150] Trial 1 finished with value: 1.0190926788409789 and parameters: {'embedding_dim': 10, 'projection_dim': 250, 'hidden_dim': 42, 'lr': 1.2287968540875901e-05, 'dropout': 0.20678992843543187, 'aux_weight': 0.21930189548659854, 'margin': 0.3614120628400075, 'weight_decay': 0.00018993187903505032}. Best is trial 1 with value: 1.0190926788409789.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:30:24,736] Trial 2 finished with value: 1.065190710674768 and parameters: {'embedding_dim': 14, 'projection_dim': 166, 'hidden_dim': 228, 'lr': 0.0014499935879020527, 'dropout': 0.4053521347270136, 'aux_weight': 0.15622465652675271, 'margin': 0.3394803798304513, 'weight_decay': 3.8100709912601955e-05}. Best is trial 2 with value: 1.065190710674768.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:30:52,290] Trial 3 finished with value: 1.0970779506771304 and parameters: {'embedding_dim': 32, 'projection_dim': 156, 'hidden_dim': 253, 'lr': 1.1547322494470239e-05, 'dropout': 0.11467411845325515, 'aux_weight': 0.4419376630490587, 'margin': 0.24625035095381154, 'weight_decay': 0.0008188138826680629}. Best is trial 3 with value: 1.0970779506771304.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:31:51,031] Trial 4 finished with value: 1.199035209325849 and parameters: {'embedding_dim': 9, 'projection_dim': 81, 'hidden_dim': 186, 'lr': 0.018940284547140813, 'dropout': 0.23198688820460645, 'aux_weight': 0.12378688654665725, 'margin': 0.17305265912236567, 'weight_decay': 0.0007548149916455488}. Best is trial 4 with value: 1.199035209325849.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:32:38,240] Trial 5 finished with value: 1.2453148359938238 and parameters: {'embedding_dim': 26, 'projection_dim': 126, 'hidden_dim': 36, 'lr': 0.0034666451532109847, 'dropout': 0.304655924809626, 'aux_weight': 0.4266769879317067, 'margin': 0.1260948074073776, 'weight_decay': 4.409986715453138e-05}. Best is trial 5 with value: 1.2453148359938238.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:33:21,844] Trial 6 finished with value: 0.9629996119623356 and parameters: {'embedding_dim': 22, 'projection_dim': 147, 'hidden_dim': 121, 'lr': 1.89405508722489e-05, 'dropout': 0.10260117055667682, 'aux_weight': 0.18383343312978684, 'margin': 0.410301586285286, 'weight_decay': 2.266014601802455e-05}. Best is trial 5 with value: 1.2453148359938238.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:34:07,124] Trial 7 finished with value: 1.0209154943385814 and parameters: {'embedding_dim': 25, 'projection_dim': 166, 'hidden_dim': 113, 'lr': 0.00036214559330899484, 'dropout': 0.317761110645121, 'aux_weight': 0.3812668116587311, 'margin': 0.3894963843770509, 'weight_decay': 0.0009743299481454929}. Best is trial 5 with value: 1.2453148359938238.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:34:35,053] Trial 8 finished with value: 0.9990877789240752 and parameters: {'embedding_dim': 19, 'projection_dim': 100, 'hidden_dim': 128, 'lr': 0.00021193359832392405, 'dropout': 0.12387806500650021, 'aux_weight': 0.33942301669548514, 'margin': 0.4423284060334375, 'weight_decay': 1.4667973434496696e-05}. Best is trial 5 with value: 1.2453148359938238.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:35:25,473] Trial 9 finished with value: 1.085162225429516 and parameters: {'embedding_dim': 31, 'projection_dim': 86, 'hidden_dim': 186, 'lr': 2.0517377026684104e-05, 'dropout': 0.275357493412029, 'aux_weight': 0.20555649502147713, 'margin': 0.3337095472536006, 'weight_decay': 4.1389369308736945e-05}. Best is trial 5 with value: 1.2453148359938238.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:36:33,599] Trial 10 finished with value: 1.2584780502824535 and parameters: {'embedding_dim': 26, 'projection_dim': 32, 'hidden_dim': 32, 'lr': 0.0032541057251796093, 'dropout': 0.44326888993714186, 'aux_weight': 0.4891574850426208, 'margin': 0.1032482635519678, 'weight_decay': 0.0001649172892397375}. Best is trial 10 with value: 1.2584780502824535.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:36:55,799] Trial 11 finished with value: 1.2355199728200166 and parameters: {'embedding_dim': 26, 'projection_dim': 34, 'hidden_dim': 49, 'lr': 0.004482340212923925, 'dropout': 0.48416036996353795, 'aux_weight': 0.4890633531363416, 'margin': 0.11474570976116434, 'weight_decay': 0.0001889406839306271}. Best is trial 10 with value: 1.2584780502824535.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:37:40,506] Trial 12 finished with value: 1.2609244782178108 and parameters: {'embedding_dim': 27, 'projection_dim': 43, 'hidden_dim': 76, 'lr': 0.00487099942060652, 'dropout': 0.3888623388174608, 'aux_weight': 0.4963105348090965, 'margin': 0.10024871755788853, 'weight_decay': 0.00011459557069590687}. Best is trial 12 with value: 1.2609244782178108.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:38:32,654] Trial 13 finished with value: 1.1655544609767619 and parameters: {'embedding_dim': 20, 'projection_dim': 33, 'hidden_dim': 78, 'lr': 0.009328951303661293, 'dropout': 0.4061491426047079, 'aux_weight': 0.48234280918914646, 'margin': 0.21746042074358918, 'weight_decay': 0.00013909420684867158}. Best is trial 12 with value: 1.2609244782178108.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:39:28,172] Trial 14 finished with value: 1.205614909930437 and parameters: {'embedding_dim': 28, 'projection_dim': 58, 'hidden_dim': 81, 'lr': 0.0007324421572275496, 'dropout': 0.4190499665036119, 'aux_weight': 0.273623223914815, 'margin': 0.17151114003359197, 'weight_decay': 0.00034459252575959256}. Best is trial 12 with value: 1.2609244782178108.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:40:20,513] Trial 15 finished with value: 1.1522418140647703 and parameters: {'embedding_dim': 23, 'projection_dim': 57, 'hidden_dim': 78, 'lr': 0.05933171874931837, 'dropout': 0.4903484193689721, 'aux_weight': 0.49639466537198634, 'margin': 0.23921917874756204, 'weight_decay': 8.710304029765529e-05}. Best is trial 12 with value: 1.2609244782178108.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:41:35,710] Trial 16 finished with value: 1.2409836379313164 and parameters: {'embedding_dim': 17, 'projection_dim': 205, 'hidden_dim': 61, 'lr': 8.672664267563006e-05, 'dropout': 0.3644395081429718, 'aux_weight': 0.28352437752445103, 'margin': 0.10238711418770544, 'weight_decay': 0.0003627721847073471}. Best is trial 12 with value: 1.2609244782178108.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:42:32,121] Trial 17 finished with value: 1.1127257891526439 and parameters: {'embedding_dim': 29, 'projection_dim': 113, 'hidden_dim': 91, 'lr': 0.0024015748484050715, 'dropout': 0.45115126022543284, 'aux_weight': 0.3407650396929911, 'margin': 0.28706232131429366, 'weight_decay': 8.3616648160187e-05}. Best is trial 12 with value: 1.2609244782178108.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:43:10,089] Trial 18 finished with value: 1.2008448132105694 and parameters: {'embedding_dim': 28, 'projection_dim': 62, 'hidden_dim': 150, 'lr': 0.016832151927085953, 'dropout': 0.3687370705054854, 'aux_weight': 0.43747608552729866, 'margin': 0.16954358695408578, 'weight_decay': 0.00032805821339264963}. Best is trial 12 with value: 1.2609244782178108.


Test shape: (3, 59)
Train shape: (28800, 61)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[I 2025-09-18 23:44:17,718] Trial 19 finished with value: 1.2233311239084892 and parameters: {'embedding_dim': 23, 'projection_dim': 71, 'hidden_dim': 98, 'lr': 0.007137481108727292, 'dropout': 0.3530740984795029, 'aux_weight': 0.3870907280491791, 'margin': 0.1492411166307448, 'weight_decay': 0.00012386538643094616}. Best is trial 12 with value: 1.2609244782178108.


Tuning finished.
Best Trial's Value (Composite Metric): 1.2609244782178108
Best Hyperparameters:
  embedding_dim: 27
  projection_dim: 43
  hidden_dim: 76
  lr: 0.00487099942060652
  dropout: 0.3888623388174608
  aux_weight: 0.4963105348090965
  margin: 0.10024871755788853
  weight_decay: 0.00011459557069590687
done


# Final Model Training and Model Inference Test

**Final Model Training and Inference**

After identifying the best hyperparameters through Optuna, the model is re-trained on the full training dataset using a robust cross-validation approach. The final prediction is an ensemble of the predictions from each cross-validation fold.

**Retraining with Best Hyperparameters**

The main function orchestrates this final stage. It initializes a new StratifiedKFold object with 5 splits to ensure the distribution of `race_group` and a newborn indicator is consistent across all folds. It then iterates through each fold:

- **Data Split:** The training data is split into a training and a validation set for the current fold.
- **Model Training:** The `train_final` function is called to train the LitNN model using the best hyperparameters identified during the Optuna tuning phase. The model is trained for a maximum of 100 epochs with various callbacks, including `ModelCheckpoint`, `LearningRateMonitor`, `TQDMProgressBar`, and `StochasticWeightAveraging`.
- **Model Inference:** After training on a fold, the trained model is used to make predictions on the entire held-out test dataset. The predictions are stored in a NumPy array.
- **Ensemble Prediction:** This process is repeated for all 5 folds. The predictions from each fold are accumulated. The final prediction is the average of the predictions from all 5 models, which generally leads to a more robust and stable result.

**Model Testing and Submission**

The final step involves creating the submission file with the ensembled predictions.

- The `test_pred` array, which contains the sum of predictions from all folds, is averaged by the number of folds (5 in this case) and assigned to the prediction column in the submission DataFrame.
- The sign of the predictions is flipped (`-test_pred`) before being saved, as the model's output is likely a time-to-event score, where lower values correspond to higher risk (and thus a more negative score is a higher-risk patient).
- The final submission file, `submission.csv`, is saved, containing the unique patient IDs and their final predicted scores.

**Best Hyperparameters:**

- `embedding_dim`: 27
- `projection_dim`: 43
- `hidden_dim`: 76
- `lr`: 0.00487099942060652
- `dropout`: 0.3888623388174608
- `aux_weight`: 0.4963105348090965
- `margin`: 0.10024871755788853
- `weight_decay`: 0.00011459557069590687

In [13]:
def train_final(X_num_train, dl_train, dl_val, transformers, hparams=None, categorical_cols=None):
    """
    Defines model hyperparameters and fit the model.
    """
    if hparams is None:
        hparams = {            
            "embedding_dim": 27,
            "projection_dim": 43,
            "hidden_dim": 76,
            "lr": 0.00487,
            "dropout": 0.38886,
            "aux_weight": 0.49631,
            "margin": 0.10025,
            "weight_decay": 0.000115
        }
    model = LitNN(
        continuous_dim=X_num_train.shape[1],
        categorical_cardinality=[len(t.classes_) for t in transformers],
        race_index=categorical_cols.index("race_group"),
        **hparams
    )
    checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=1)
    # Set up trainer
    trainer = pl.Trainer(
        accelerator='cuda',
        max_epochs=100,                                 #### 56 is used in this competition but 100 epoch achieve the lowest validation loss
        log_every_n_steps=6,
        callbacks=[
            checkpoint_callback,
            LearningRateMonitor(logging_interval='epoch'),
            TQDMProgressBar(),
            StochasticWeightAveraging(swa_lrs=1e-5, swa_epoch_start=40, annealing_epochs=15)
        ],
    )
    trainer.fit(model, dl_train)
    trainer.test(model, dl_val)
    return model.eval()

def main(hparams):
    """
    Main function to train the model.
    The steps are as following :
    * load data and fill efs and efs time for test data with 1
    * initialize pred array with 0
    * get categorical and numerical columns
    * split the train data on the stratified criterion : race_group * newborns yes/no
    * preprocess the fold data (create dataloaders)
    * train the model and create final submission output
    """
    test, train_original = load_data()
    test['efs_time'] = 1
    test['efs'] = 1
    test_pred = np.zeros(test.shape[0])
    categorical_cols, numerical = get_feature_types(train_original)
    kf = StratifiedKFold(n_splits=5, shuffle=True, )               ####
    for i, (train_index, test_index) in enumerate(
        kf.split(
            train_original, train_original.race_group.astype(str) + (train_original.age_at_hct == 0.044).astype(str)
        )
    ):
        tt = train_original.copy()
        train = tt.iloc[train_index]
        val = tt.iloc[test_index]
        X_cat_val, X_num_train, X_num_val, dl_train, dl_val, transformers = preprocess_data(train, val)
        model = train_final(X_num_train, dl_train, dl_val, transformers, categorical_cols=categorical_cols)
        # Create submission
        train = tt.iloc[train_index]
        X_cat_val, X_num_train, X_num_val, dl_train, dl_val, transformers = preprocess_data(train, test)
        pred, _ = model.cuda().eval()(
            torch.tensor(X_cat_val, dtype=torch.long).cuda(),
            torch.tensor(X_num_val, dtype=torch.float32).cuda()
        )
        test_pred += pred.detach().cpu().numpy()
        
    subm_data = pd.read_csv("/kaggle/input/equity-post-HCT-survival-predictions/sample_submission.csv")
    subm_data['prediction'] = -test_pred
    subm_data.to_csv('submission.csv', index=False)
    
    display(subm_data.head())
    return


hparams = None
res = main(hparams)
print("done")

Test shape: (3, 59)
Train shape: (28800, 61)


Training: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

Unnamed: 0,ID,prediction
0,28800,-0.740163
1,28801,-0.006642
2,28802,-1.153348


done
