# Competition Overview

**Competition Summary**

This Kaggle competition challenges participants to develop predictive models for allogeneic Hematopoietic Cell Transplantation (HCT) survival. The core objective is to create models that are not only accurate but also fair and equitable for patients across diverse backgrounds, particularly different racial groups. Participants will use synthetic data that mirrors real-world patient information to build their models. The competition runs from December 4, 2024, to March 5, 2025.

### [Link to competition](https://www.kaggle.com/competitions/equity-post-HCT-survival-predictions/overview)

## Evaluation Criteria

Model performance will be evaluated using a custom metric called the Stratified Concordance Index (C-index). This metric is an adaptation of the standard C-index, designed to ensure that the model's predictive power is equitable across different racial groups. A higher C-index (closer to 1.0) indicates a more accurate and fairer model. The final score is calculated as the mean C-index across all racial groups minus the standard deviation of those scores.

**Concordance index**

The concordance index measures a model's overall discriminative power: its ability to rank patients' survival times correctly from their individual risk scores. It can be computed with the following formula:

![image.png](attachment:4c245296-7fc1-4b42-9904-6cfa62fbc938.png)

The concordance index is a value between 0 and 1 where:

- 0.5 is the expected result from random predictions,
- 1.0 is a perfect concordance and,
- 0.0 is perfect anti-concordance (multiply predictions by -1 to get 1.0)
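
The scoring rule described above (per-group C-index, then mean minus standard deviation) can be sketched in a few lines. This is a minimal illustration and not the official metric code: the column name `prediction`, the simplified handling of tied event times (tied pairs are skipped), and the use of the population standard deviation are assumptions.

```python
import numpy as np
import pandas as pd


def concordance_index(time, risk, event):
    """Harrell's C-index; a higher risk score should mean an earlier event."""
    time, risk, event = map(np.asarray, (time, risk, event))
    concordant, comparable = 0.0, 0
    for i in range(len(time)):
        if not event[i]:
            continue  # a censored patient cannot anchor a comparable pair
        later = time > time[i]  # patients known to have outlived patient i
        comparable += int(later.sum())
        concordant += float((risk[i] > risk[later]).sum())
        concordant += 0.5 * float((risk[i] == risk[later]).sum())  # ties count half
    return concordant / comparable


def stratified_c_index(df, risk_col="prediction"):
    """Mean of the per-race-group C-indexes minus their standard deviation."""
    scores = [
        concordance_index(g["efs_time"], g[risk_col], g["efs"])
        for _, g in df.groupby("race_group")
    ]
    return float(np.mean(scores) - np.std(scores))


# Toy data: group A is ranked perfectly, group B is ranked in reverse.
toy = pd.DataFrame({
    "race_group": ["A", "A", "A", "B", "B", "B"],
    "efs_time":   [1.0, 2.0, 3.0, 1.0, 2.0, 3.0],
    "efs":        [1, 1, 0, 1, 1, 1],
    "prediction": [3.0, 2.0, 1.0, 1.0, 2.0, 3.0],
})
score = stratified_c_index(toy)
```

Because the standard deviation is subtracted, a model that does very well on one group and poorly on another is penalized relative to a model with the same mean but even performance.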

Reference: [CIBMTR - Equity in post-HCT Survival Predictions](https://www.kaggle.com/competitions/equity-post-HCT-survival-predictions/overview)

## Submission Requirements

The submission file should contain a real-valued risk score for each patient in the test dataset, with a higher score indicating a higher likelihood of the target event.

# Dataset Description

The dataset for this competition contains 59 features related to hematopoietic stem cell transplantation (HSCT). These variables cover a wide range of demographic and medical information for both the transplant recipients and donors, such as age, sex, ethnicity, and disease status.

**Data and Target Variables**

The primary goal is to predict event-free survival (efs), which is the key target variable. The time to this event is captured in the efs_time variable. Together, these two variables are used for a censored time-to-event analysis.

**Included Files**

The competition provides four main files:
- train.csv: The training dataset with the efs target.
- test.csv: The dataset for which you must predict efs.
- sample_submission.csv: An example file showing the correct submission format.
- data_dictionary.csv: A list of all variables and their descriptions.

# Data Dictionary

<style type="text/css">
.tg  {border-collapse:collapse;border-spacing:0;}
.tg td{border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;
  overflow:hidden;padding:10px 5px;word-break:normal;}
.tg th{border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px;
  font-weight:normal;overflow:hidden;padding:10px 5px;word-break:normal;}
.tg .tg-za14{border-color:inherit;text-align:left;vertical-align:bottom}
.tg .tg-7zrl{text-align:left;vertical-align:bottom}
</style>
<table class="tg"><thead>
  <tr>
    <th class="tg-za14">variable</th>
    <th class="tg-7zrl">description</th>
    <th class="tg-7zrl">type</th>
  </tr></thead>
<tbody>
  <tr>
    <td class="tg-7zrl">dri_score</td>
    <td class="tg-7zrl">Refined disease risk index</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">psych_disturb</td>
    <td class="tg-7zrl">Psychiatric disturbance</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">cyto_score</td>
    <td class="tg-7zrl">Cytogenetic score</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">diabetes</td>
    <td class="tg-7zrl">Diabetes</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_c_high</td>
    <td class="tg-7zrl">Recipient / 1st donor allele level (high resolution) matching at HLA-C</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_high_res_8</td>
    <td class="tg-7zrl">Recipient / 1st donor allele-level (high resolution) matching at HLA-A,-B,-C,-DRB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">tbi_status</td>
    <td class="tg-7zrl">TBI</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">arrhythmia</td>
    <td class="tg-7zrl">Arrhythmia</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_low_res_6</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen-level (low resolution) matching at HLA-A,-B,-DRB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">graft_type</td>
    <td class="tg-7zrl">Graft type</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">vent_hist</td>
    <td class="tg-7zrl">History of mechanical ventilation</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">renal_issue</td>
    <td class="tg-7zrl">Renal, moderate / severe</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">pulm_severe</td>
    <td class="tg-7zrl">Pulmonary, severe</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">prim_disease_hct</td>
    <td class="tg-7zrl">Primary disease for HCT</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_high_res_6</td>
    <td class="tg-7zrl">Recipient / 1st donor allele-level (high resolution) matching at HLA-A,-B,-DRB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">cmv_status</td>
    <td class="tg-7zrl">Donor/recipient CMV serostatus</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_high_res_10</td>
    <td class="tg-7zrl">Recipient / 1st donor allele-level (high resolution) matching at HLA-A,-B,-C,-DRB1,-DQB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_dqb1_high</td>
    <td class="tg-7zrl">Recipient / 1st donor allele level (high resolution) matching at HLA-DQB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">tce_imm_match</td>
    <td class="tg-7zrl">T-cell epitope immunogenicity/diversity match</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_nmdp_6</td>
    <td class="tg-7zrl">Recipient / 1st donor matching at HLA-A(lo),-B(lo),-DRB1(hi)</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_c_low</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen level (low resolution) matching at HLA-C</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">rituximab</td>
    <td class="tg-7zrl">Rituximab given in conditioning</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_drb1_low</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen level (low resolution) matching at HLA-DRB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_dqb1_low</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen level (low resolution) matching at HLA-DQB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">prod_type</td>
    <td class="tg-7zrl">Product type</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">cyto_score_detail</td>
    <td class="tg-7zrl">Cytogenetics for DRI (AML/MDS)</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">conditioning_intensity</td>
    <td class="tg-7zrl">Computed planned conditioning intensity</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">ethnicity</td>
    <td class="tg-7zrl">Ethnicity</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">year_hct</td>
    <td class="tg-7zrl">Year of HCT</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">obesity</td>
    <td class="tg-7zrl">Obesity</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">mrd_hct</td>
    <td class="tg-7zrl">MRD at time of HCT (AML/ALL)</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">in_vivo_tcd</td>
    <td class="tg-7zrl">In-vivo T-cell depletion (ATG/alemtuzumab)</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">tce_match</td>
    <td class="tg-7zrl">T-cell epitope matching</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_a_high</td>
    <td class="tg-7zrl">Recipient / 1st donor allele level (high resolution) matching at HLA-A</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hepatic_severe</td>
    <td class="tg-7zrl">Hepatic, moderate / severe</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">donor_age</td>
    <td class="tg-7zrl">Donor age</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">prior_tumor</td>
    <td class="tg-7zrl">Solid tumor, prior</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_b_low</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen level (low resolution) matching at HLA-B</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">peptic_ulcer</td>
    <td class="tg-7zrl">Peptic ulcer</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">age_at_hct</td>
    <td class="tg-7zrl">Age at HCT</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_a_low</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen level (low resolution) matching at HLA-A</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">gvhd_proph</td>
    <td class="tg-7zrl">Planned GVHD prophylaxis</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">rheum_issue</td>
    <td class="tg-7zrl">Rheumatologic</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">sex_match</td>
    <td class="tg-7zrl">Donor/recipient sex match</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_b_high</td>
    <td class="tg-7zrl">Recipient / 1st donor allele level (high resolution) matching at HLA-B</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">race_group</td>
    <td class="tg-7zrl">Race</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">comorbidity_score</td>
    <td class="tg-7zrl">Sorror comorbidity score</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">karnofsky_score</td>
    <td class="tg-7zrl">KPS at HCT</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hepatic_mild</td>
    <td class="tg-7zrl">Hepatic, mild</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">tce_div_match</td>
    <td class="tg-7zrl">T-cell epitope matching</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">donor_related</td>
    <td class="tg-7zrl">Related vs. unrelated donor</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">melphalan_dose</td>
    <td class="tg-7zrl">Melphalan dose (mg/m^2)</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_low_res_8</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen-level (low resolution) matching at HLA-A,-B,-C,-DRB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">cardiac</td>
    <td class="tg-7zrl">Cardiac</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_match_drb1_high</td>
    <td class="tg-7zrl">Recipient / 1st donor allele level (high resolution) matching at HLA-DRB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">pulm_moderate</td>
    <td class="tg-7zrl">Pulmonary, moderate</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">hla_low_res_10</td>
    <td class="tg-7zrl">Recipient / 1st donor antigen-level (low resolution) matching at HLA-A,-B,-C,-DRB1,-DQB1</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">efs</td>
    <td class="tg-7zrl">Event-free survival</td>
    <td class="tg-7zrl">Categorical</td>
  </tr>
  <tr>
    <td class="tg-7zrl">efs_time</td>
    <td class="tg-7zrl">Time to event-free survival, months</td>
    <td class="tg-7zrl">Numerical</td>
  </tr>
</tbody></table>

### [Link to dataset](https://www.kaggle.com/competitions/equity-post-HCT-survival-predictions/data)

# Pipeline Overview

The submission pipeline covers the full survival prediction workflow, from data loading to generating the final submission file. It is built around a 5-fold stratified cross-validation strategy: a neural network is trained on each fold, and the per-fold test predictions are combined into an ensemble.

**1. Data Preparation and Preprocessing**

The pipeline begins in the `main` function, which loads the raw training and test datasets. Feature engineering follows: the `add_features` function creates new features such as `is_cyto_score_same` and rescales existing ones such as `year_hct`. Because the test targets (`efs` and `efs_time`) are unknown, they are initialized to a default value.

The core preprocessing logic, housed within `preprocess_data`, orchestrates several key steps within each cross-validation fold:

- **Feature Separation:** `get_feature_types` identifies categorical and numerical columns.
- **Categorical Transformation:** `get_categoricals` handles categorical features by removing columns with a single unique value and then applying Label Encoding to the rest.
- **Numerical Imputation and Scaling:** Missing values in numerical columns are filled with the column mean (`SimpleImputer`, which also appends missing-value indicator columns because `add_indicator=True` is set), and the data is then standardized using `StandardScaler`.
- **PyTorch DataLoaders:** Finally, the preprocessed data is converted into PyTorch TensorDataset and DataLoader objects, which are optimized for training the neural network.

**2. Model Architecture and Training**

The model is a neural network defined within the `LitNN` class, which inherits from `pytorch_lightning.LightningModule` for streamlined training.

- **Model Layers:** The `NN` class defines the main network architecture. It uses a `CatEmbeddings` module to create dense representations of categorical features. These embeddings are then concatenated with the numerical features and passed through a multi-layer perceptron (MLP) to produce the final prediction.
- **Custom Loss Function:** The `LitNN` module defines a specialized loss function for survival data. This loss is a margin-based hinge loss that operates on all possible pairs of data points. It is designed to learn a correct ranking of event times, particularly for uncensored data, and is a core component for handling survival data.
- **Auxiliary and Fairness Losses:** To improve performance and fairness, the model incorporates two additional loss components:
  - An auxiliary loss on a secondary prediction task, designed to improve the model's intermediate representations.
  - A fairness-aware loss (`get_race_losses`) that minimizes the variance of the main loss across different `race_group` categories.
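
To make the two loss ideas above concrete, here is a minimal NumPy sketch of a pairwise margin-based hinge loss and a per-group variance penalty. It is not the notebook's `LitNN` implementation: the function names, the margin of 0.5, the rule for which pairs are comparable (the earlier time must be an observed event), and the convention that a larger prediction means longer survival are all assumptions made for illustration.

```python
from itertools import combinations

import numpy as np


def pairwise_hinge_loss(pred, time, event, margin=0.5):
    """Mean hinge loss over comparable pairs; `pred` should grow with survival time."""
    losses = []
    for i, j in combinations(range(len(pred)), 2):
        if time[i] == time[j]:
            continue  # tied times carry no ordering information
        first, second = (i, j) if time[i] < time[j] else (j, i)
        if not event[first]:
            continue  # earlier patient is censored: the true order is unknown
        # Penalize the pair unless pred[second] exceeds pred[first] by `margin`.
        losses.append(max(0.0, margin - (pred[second] - pred[first])))
    return float(np.mean(losses)) if losses else 0.0


def race_variance_penalty(pred, time, event, groups, margin=0.5):
    """Variance of the per-group losses; zero when every group has the same loss."""
    pred, time, event, groups = map(np.asarray, (pred, time, event, groups))
    per_group = [
        pairwise_hinge_loss(pred[groups == g], time[groups == g], event[groups == g], margin)
        for g in np.unique(groups)
    ]
    return float(np.var(per_group))


# Toy batch: group "a" is ordered correctly, group "b" is ordered in reverse.
pred = np.array([0.0, 1.0, 2.0, 2.0, 1.0, 0.0])
time = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0])
event = np.array([1, 1, 0, 1, 1, 0])
race = np.array(["a", "a", "a", "b", "b", "b"])

main_loss = pairwise_hinge_loss(pred, time, event)
penalty = race_variance_penalty(pred, time, event, race)
```

In a trainable version the same computation would be written with differentiable tensor operations, and the penalty would be added to the main loss with some weight.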

The `train_final` function uses a PyTorch Lightning Trainer to handle the training loop, incorporating callbacks for monitoring and an Adam optimizer with a Cosine Annealing learning rate scheduler.

**3. Cross-Validation and Inference**

The `main` function runs the training and inference loop using `StratifiedKFold`, which keeps the distribution of key features (`race_group` and `age_at_hct`) similar across folds.

- **Fold Training:** For each of the 5 folds, a new model is trained on the training data and validated on the validation set.
- **Prediction Ensemble:** After a model is trained on a given fold, it is used to make predictions on the entire test dataset. The predictions from each of the 5 folds are then summed up.
- **Final Submission:** The summed predictions are averaged and saved to a `submission.csv` file, resulting in an ensemble-based prediction that is typically more robust than a single model's output.
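
The fold loop and prediction averaging described above can be sketched as follows. The stratification key (race group combined with a binned age) is an assumed scheme, since this section does not show how the two columns are combined, and `fit_predict` is a hypothetical callback standing in for the per-fold training and inference code.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold


def make_strata(df, n_age_bins=2):
    """Combine race_group with a binned age so folds balance both (assumed scheme)."""
    age_bin = pd.qcut(df["age_at_hct"], q=n_age_bins, labels=False, duplicates="drop")
    return df["race_group"].astype(str) + "_" + age_bin.astype(str)


def cross_val_ensemble(train, test, fit_predict, n_splits=5, seed=42):
    """Train one model per fold and average the resulting test predictions."""
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    test_pred = np.zeros(len(test))
    for train_idx, val_idx in skf.split(train, make_strata(train)):
        fold_train, fold_val = train.iloc[train_idx], train.iloc[val_idx]
        # fit_predict (hypothetical) trains on the fold and scores the test set.
        test_pred += np.asarray(fit_predict(fold_train, fold_val, test))
    return test_pred / n_splits  # sum of fold predictions divided by the fold count
```

Any callable with the signature `fit_predict(fold_train, fold_val, test)` that returns one score per test row can be plugged in, which keeps the ensembling logic separate from the model code.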

# Related Notebooks

### [Model submission notebook](https://www.kaggle.com/code/misterfour/cibmtr-challenge-submission)
### [Inspiration notebook](https://www.kaggle.com/code/zyh1104/cibmtr-nn-negative-log-likelihood-loss)
### [Package install lightning and tabular data](https://www.kaggle.com/code/dreamingtree/download-lightning-and-pytorch-tabular)
### [Package install evaluation metric](https://www.kaggle.com/code/cdeotte/pip-install-lifelines)

# Import Necessary Libraries

In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/pip-install-lifelines/fonttools-4.55.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
/kaggle/input/pip-install-lifelines/tzdata-2024.2-py2.py3-none-any.whl
/kaggle/input/pip-install-lifelines/kiwisolver-1.4.7-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl
/kaggle/input/pip-install-lifelines/interface_meta-1.3.0-py3-none-any.whl
/kaggle/input/pip-install-lifelines/scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
/kaggle/input/pip-install-lifelines/pillow-11.0.0-cp310-cp310-manylinux_2_28_x86_64.whl
/kaggle/input/pip-install-lifelines/contourpy-1.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
/kaggle/input/pip-install-lifelines/pyparsing-3.2.0-py3-none-any.whl
/kaggle/input/pip-install-lifelines/wrapt-1.17.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl
/kaggle/input/pip-install-lifelines/__results__.html
/kaggle/input/pip-install-lifelines/cycler-0.12.1-py3-n

In [2]:
!pip install /kaggle/input/pip-install-lifelines/autograd-1.7.0-py3-none-any.whl
!pip install /kaggle/input/pip-install-lifelines/autograd-gamma-0.5.0.tar.gz
!pip install /kaggle/input/pip-install-lifelines/interface_meta-1.3.0-py3-none-any.whl
!pip install /kaggle/input/pip-install-lifelines/formulaic-1.0.2-py3-none-any.whl
!pip install /kaggle/input/pip-install-lifelines/lifelines-0.30.0-py3-none-any.whl
!pip install  /kaggle/input/download-lightning-and-pytorch-tabular/pytorch_lightning-2.4.0-py3-none-any.whl
!pip install  /kaggle/input/download-lightning-and-pytorch-tabular/scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
!pip install  /kaggle/input/download-lightning-and-pytorch-tabular/torchmetrics-1.5.2-py3-none-any.whl
!pip install  /kaggle/input/download-lightning-and-pytorch-tabular/pytorch_tabnet-4.1.0-py3-none-any.whl
!pip install  /kaggle/input/download-lightning-and-pytorch-tabular/einops-0.7.0-py3-none-any.whl
!pip install  /kaggle/input/download-lightning-and-pytorch-tabular/pytorch_tabular-1.1.1-py2.py3-none-any.whl

Processing /kaggle/input/pip-install-lifelines/autograd-1.7.0-py3-none-any.whl
autograd is already installed with the same version as the provided wheel. Use --force-reinstall to force an installation of the wheel.
Processing /kaggle/input/pip-install-lifelines/autograd-gamma-0.5.0.tar.gz
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: autograd-gamma
  Building wheel for autograd-gamma (setup.py) ... done
  Created wheel for autograd-gamma: filename=autograd_gamma-0.5.0-py3-none-any.whl size=4031 sha256=9f8e0e05e05ea59826d45e61f6b83f9288b887020f9a2dc0dcc008dadc235092
  Stored in directory: /root/.cache/pip/wheels/6b/b5/e0/4c79e15c0b5f2c15ecf613c720bb20daab20a666eb67135155
Successfully built autograd-gamma
Installing collected packages: autograd-gamma
Successfully installed autograd-gamma-0.5.0
Processing /kaggle/input/pip-install-lifelines/interface_meta-1.3.0-py3-none-any.whl
Installing collected packages: interface-meta
Success

In [3]:
# Standard library
from warnings import filterwarnings
import functools
import json
from typing import List

# Third-party libraries (duplicate imports removed)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from lifelines.utils import concordance_index
from pytorch_lightning.callbacks import LearningRateMonitor, StochasticWeightAveraging, TQDMProgressBar
from pytorch_lightning.cli import ReduceLROnPlateau
from pytorch_lightning.utilities import grad_norm
from pytorch_tabular.models.common.layers import ODST
from sklearn.impute import SimpleImputer
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder, StandardScaler
from torch import nn
from torch.utils.data import TensorDataset

# Fix random seeds for reproducibility and silence warnings
pl.seed_everything(42)
filterwarnings('ignore')

# Data Loading and Feature Engineering

This step defines several utility functions that together manage the data preprocessing pipeline, so that data is consistently formatted, transformed, and ready for model training.

**Data Loading and Feature Engineering**

- **`load_data()`:** This function is the entry point for data handling. It reads the raw data files (train.csv and test.csv) and then calls `add_features()` to perform initial feature engineering on both datasets.
- **`add_features(df)`:** A simple but crucial step that creates new features to potentially improve model performance. It adds a new binary column (`is_cyto_score_same`) and modifies the `year_hct` column to be a smaller, relative value by subtracting 2000.

**Feature Separation and Transformation**

- **`get_feature_types(train)`:** This utility inspects the training data and separates column names into two lists: categorical and numerical. A column is treated as categorical when its dtype is `object` or when it has between 3 and 24 unique values. The numerical list excludes the predefined columns `ID`, `y`, and the target variables `efs` and `efs_time`.
- **`get_categoricals(train, val)`:** This function handles the preparation of categorical data. It first removes any categorical columns that have a single unique value. It then handles any categories in the validation set that don't appear in the training set by converting them to NaN. Finally, it uses `get_X_cat()` to apply LabelEncoder to transform all categorical features into a numerical format, which is a requirement for the model's embedding layers.
- **`get_X_cat(df, cat_cols, transformers=None)`:** A helper function for `get_categoricals`, responsible for applying LabelEncoder. It's designed to either create new transformers for the training data or use existing ones (from the training set) to ensure consistent encoding for the validation or test set.
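
The fit-on-train, reuse-on-validation pattern behind `get_categoricals` and `get_X_cat` can be illustrated on a toy column. This is a sketch rather than the notebook's code: the category values are made up, and mapping missing values to an explicit `"missing"` token (registered with the encoder up front) is an assumption added here so that blanked-out categories always have a class to encode to.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Toy data: "Cord blood" appears only in the validation fold.
train = pd.DataFrame({"graft_type": ["Bone marrow", "Peripheral blood", "Bone marrow"]})
val = pd.DataFrame({"graft_type": ["Peripheral blood", "Cord blood"]})

# Categories absent from the training column are blanked out.
unseen = ~val["graft_type"].isin(train["graft_type"])
val.loc[unseen, "graft_type"] = np.nan

# Assumption: missing values become an explicit token known to the encoder.
train_col = train["graft_type"].fillna("missing")
val_col = val["graft_type"].fillna("missing")

encoder = LabelEncoder().fit(list(train_col) + ["missing"])  # fitted on training data only
train_codes = encoder.transform(train_col)
val_codes = encoder.transform(val_col)  # same mapping reused for validation
```

Reusing the fitted encoder guarantees that a given category maps to the same integer in every split, which the embedding layers rely on.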

**Final Data Preparation**

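- **`preprocess_data(train, val)`:** This is the central orchestrator of the preprocessing pipeline. After handling categorical data, it processes the numerical features: `SimpleImputer` fills missing values with the column mean and, because `add_indicator=True` is set, appends a binary missing-value indicator column for each feature that had missing values in the training fold. The result is then standardized using `StandardScaler`.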
- **`init_dl(X_cat, X_num, df, training=False)`:** The final step in data preparation. This function takes the processed categorical and numerical data and creates a `torch.utils.data.DataLoader` object. A key step here is that it log-transforms the `efs_time` variable, which is a common practice for survival analysis tasks to handle skewed distributions. It sets a fixed batch size of 2048 and shuffles the data only for the training set.
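
The numerical branch of `preprocess_data` can be traced on a tiny array. The values below are invented for illustration; the point is that the imputer and scaler learn their statistics from the training fold only, and that `add_indicator=True` widens the feature matrix.

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

# Toy folds: the second training column has one missing value.
X_train = np.array([[1.0, 10.0], [3.0, np.nan], [5.0, 30.0]])
X_val = np.array([[np.nan, 20.0]])

imputer = SimpleImputer(missing_values=np.nan, strategy="mean", add_indicator=True)
scaler = StandardScaler()

# Means, indicator columns and scaling statistics come from the training fold ...
X_train_prep = scaler.fit_transform(imputer.fit_transform(X_train))
# ... and are reused unchanged on the validation fold.
X_val_prep = scaler.transform(imputer.transform(X_val))

# One indicator column is appended for the feature that had missing training values.
n_features_out = X_train_prep.shape[1]
```

Note that the indicator columns pass through `StandardScaler` as well, so they are no longer 0/1 when they reach the network.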

**Theories and Concepts**

- **Data Preprocessing:** This is a foundational step in any machine learning project. Real-world data is often messy, with missing values, inconsistent formats, and features of different scales. Preprocessing transforms this raw data into a clean, structured format that models can effectively learn from.
- **Feature Engineering:** The `add_features` function demonstrates this concept. Feature engineering is the process of using domain knowledge to create new input features that are more informative and useful for the model. By creating `is_cyto_score_same`, the code provides the model with a direct and explicit signal that might be difficult to learn from the raw `cyto_score` and `cyto_score_detail` columns alone.
- **Categorical Data Encoding:** Machine learning algorithms, at their core, are mathematical and operate on numerical data. Categorical features (e.g., "red," "green," "blue") must be converted into a numerical representation. Label Encoding assigns a unique integer to each category. This is a simple and common method, though it can sometimes impose a false sense of ordinality on the data.
- **Handling Missing Data (Imputation):** Missing data can cause errors or lead to biased results. SimpleImputer with the mean strategy fills in missing values for numerical features with the average value of that column. This is a simple and effective method when the amount of missing data is small.
- **Feature Scaling:** The StandardScaler is used to standardize numerical features. This is a critical step for many algorithms, particularly those based on gradient descent (like neural networks) or those that rely on distance calculations (like k-nearest neighbors). Scaling ensures that no single feature dominates the learning process simply because it has a larger magnitude.
- **Data Leakage:** This is a major risk in machine learning pipelines. It occurs when information from the validation or test set inadvertently influences the training process. The code correctly avoids this by using `fit_transform` only on the training data and then applying the same fitted transformation (`transform`) to the validation data. This ensures the model is evaluated on a truly unseen dataset.
- **Batching and Data Loading:** The DataLoader is a key component for training deep learning models. Training a model on the entire dataset at once is computationally infeasible for large datasets. DataLoader breaks the dataset into smaller "mini-batches." This allows for more efficient memory usage, enables parallel processing on GPUs, and introduces a form of stochasticity that helps the optimization process converge.

In [4]:
# Utility function to handle categorical data transformations
def get_X_cat(df, cat_cols, transformers=None):
    """
    Apply a specific categorical data transformer or a LabelEncoder if None.
    """
    if transformers is None:
        transformers = [LabelEncoder().fit(df[col]) for col in cat_cols]
    return transformers, np.array(
        [transformer.transform(df[col]) for col, transformer in zip(cat_cols, transformers)]
    ).T


# Utility function to get feature types
def get_feature_types(train):
    """
    Utility function to return categorical and numerical column names.
    """
    # A column is categorical if it holds strings or has between 3 and 24 unique values
    categorical_cols = [
        col for col in train.columns
        if train[col].dtype == "object" or 2 < train[col].nunique() < 25
    ]
    RMV = ["ID", "efs", "efs_time", "y"]
    FEATURES = [c for c in train.columns if c not in RMV]
    numerical = [i for i in FEATURES if i not in categorical_cols]
    return categorical_cols, numerical


# Preprocessing function for categorical features
def get_categoricals(train, val):
    """
    Remove constant categorical columns and transform them using LabelEncoder.
    Return the label-transformers for each categorical column, categorical dataframes and numerical columns.
    """
    categorical_cols, numerical = get_feature_types(train)
    remove = []
    for col in categorical_cols:
        if train[col].nunique() == 1:
            remove.append(col)
        ind = ~val[col].isin(train[col])
        if ind.any():
            val.loc[ind, col] = np.nan
    categorical_cols = [col for col in categorical_cols if col not in remove]
    transformers, X_cat_train = get_X_cat(train, categorical_cols)
    _, X_cat_val = get_X_cat(val, categorical_cols, transformers)
    return X_cat_train, X_cat_val, numerical, transformers


# PyTorch Data Preparation
def init_dl(X_cat, X_num, df, training=False):
    """
    Build a DataLoader over four tensors: categorical codes, numerical features,
    log-transformed efs_time and the efs event indicator.
    The batch size is fixed at 2048; data is shuffled only when training=True.
    """
    ds_train = TensorDataset(
        torch.tensor(X_cat, dtype=torch.long),
        torch.tensor(X_num, dtype=torch.float32),
        torch.tensor(df.efs_time.values, dtype=torch.float32).log(),
        torch.tensor(df.efs.values, dtype=torch.long)
    )
    bs = 2048
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=bs, pin_memory=True, shuffle=training)
    return dl_train


# Feature Engineering
def add_features(df):
    """
    Create some new features to help the model focus on specific patterns.
    """
    df['is_cyto_score_same'] = (df['cyto_score'] == df['cyto_score_detail']).astype(int)
    df['year_hct'] -= 2000
    return df


# Data Loading
def load_data():
    """
    Load data and add features.
    """
    test = pd.read_csv("/kaggle/input/equity-post-HCT-survival-predictions/test.csv")
    test = add_features(test)
    print("Test shape:", test.shape)
    train = pd.read_csv("/kaggle/input/equity-post-HCT-survival-predictions/train.csv")
    train = add_features(train)
    print("Train shape:", train.shape)
    return test, train


# Main Preprocessing Orchestrator
def preprocess_data(train, val):
    """
    Standardize numerical variables and transform (Label-encode) categoricals.
    Fill NA values with mean for numerical.
    Create torch dataloaders to prepare data for training and evaluation.
    """
    X_cat_train, X_cat_val, numerical, transformers = get_categoricals(train, val)
    scaler = StandardScaler()
    imp = SimpleImputer(missing_values=np.nan, strategy='mean', add_indicator=True)
    X_num_train = imp.fit_transform(train[numerical])
    X_num_train = scaler.fit_transform(X_num_train)
    X_num_val = imp.transform(val[numerical])
    X_num_val = scaler.transform(X_num_val)
    dl_train = init_dl(X_cat_train, X_num_train, train, training=True)
    dl_val = init_dl(X_cat_val, X_num_val, val)
    return X_cat_val, X_num_train, X_num_val, dl_train, dl_val, transformers


# Model architecture

**Model Architecture and Training Logic**

This step defines a complete neural network model and training pipeline using PyTorch and PyTorch Lightning. The model is designed for survival analysis: it combines categorical and numerical features and incorporates a fairness component into its loss function.

**1. CatEmbeddings Class: Categorical Feature Processing**

This module is responsible for handling categorical data. It converts each categorical feature into a dense, continuous vector representation, a process known as embedding.

- **Initialization:** It creates a list of `nn.Embedding` layers, one for each categorical feature, allowing the model to learn a unique vector for every category.
- **Feature Projection:** It concatenates the embeddings from all categorical features and passes the combined vector through a small neural network (`nn.Sequential`). This network, consisting of two linear layers with a GELU activation, projects the high-dimensional concatenated embeddings into a lower-dimensional space.
- **Forward Pass:** It applies the corresponding embedding layer to each categorical feature in the input tensor, concatenates the results, and then feeds the result through the projection network.
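
The lookup, concatenation, and projection steps listed above can be traced with plain NumPy. This is only a shape-level sketch: the sizes, the random weights, and the helper name `cat_embed` are assumptions for illustration, and the tanh approximation of GELU stands in for the exact form.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes, chosen only for illustration.
cardinalities = [3, 5]   # two categorical features with 3 and 5 categories
embedding_dim = 4
projection_dim = 6

# One lookup table per categorical feature (the role played by nn.Embedding).
tables = [rng.normal(size=(c, embedding_dim)) for c in cardinalities]

# Weights of the two linear layers in the projection network.
W1 = rng.normal(size=(embedding_dim * len(cardinalities), projection_dim))
b1 = np.zeros(projection_dim)
W2 = rng.normal(size=(projection_dim, projection_dim))
b2 = np.zeros(projection_dim)


def gelu(x):
    # tanh approximation of GELU (an assumption; the exact form uses erf)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def cat_embed(x_cat):
    # x_cat: integer array of shape (batch, n_categorical_features)
    looked_up = [tables[i][x_cat[:, i]] for i in range(len(tables))]
    concat = np.concatenate(looked_up, axis=1)   # (batch, 2 * embedding_dim)
    hidden = gelu(concat @ W1 + b1)              # first linear layer + GELU
    return hidden @ W2 + b2                      # (batch, projection_dim)


x_cat = np.array([[0, 4], [2, 1], [1, 1]])
out = cat_embed(x_cat)
print(out.shape)  # (3, 6)
```

Each row of `x_cat` holds one label-encoded category index per feature, which is why the categorical tensor is created with an integer dtype before it reaches the embedding layers.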

**2. NN Class: Core Model**

This is the main neural network class that combines both data types.

- **Input Handling:** It takes both the categorical (`x_cat`) and numerical (`x_cont`) data as input. The categorical data is first processed by the CatEmbeddings module.
- **Data Concatenation:** The output from the categorical embeddings is then concatenated with the numerical features.
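- **Multi-Layer Perceptron (MLP):** The combined features are passed through an MLP block consisting of an `ODST` layer (Oblivious Differentiable Sparsemax Trees, the differentiable tree layer introduced with NODE for tabular data), followed by `BatchNorm1d` and `Dropout`. A final `nn.Linear` layer maps the hidden representation to a single output value.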
- **Weight Initialization:** The weights of the linear layers are initialized using the `xavier_normal_` algorithm, and biases are set to zero to promote stable training.

**3. LitNN Class: PyTorch Lightning Training Module**

This class inherits from `pl.LightningModule` and encapsulates the full training and evaluation loop, including the loss function, optimizer, and logging.

- **Model Instantiation:** It creates an instance of the `NN` model and an auxiliary prediction head (`aux_cls`), which performs a secondary task on the hidden-layer representation.
- **Loss Function (`calc_loss`):** This is the most crucial part of the code. The model is trained using a margin-based hinge loss calculated on pairwise comparisons of data points. This loss is specifically designed for survival analysis, as it focuses on correctly ranking individuals based on their time-to-event. It includes a masking mechanism (`get_mask`) to handle censored data, ensuring that only valid comparisons are considered in the loss calculation.
- **Fairness Component (`get_race_losses`):** A distinctive feature of this model is a fairness term. It calculates the loss separately for each racial group in the batch and adds the standard deviation of these losses, weighted by 0.1, to the total loss. This encourages the model to learn a representation that performs consistently across all groups, rather than optimizing for the overall average performance at the expense of specific subgroups.
- **Training and Validation Steps:** The `training_step` and `validation_step` methods define the forward pass and loss calculation for each batch. They log the training and validation losses; the training step additionally logs the race-group and auxiliary losses.
- **Concordance Index (`on_validation_epoch_end`, `on_test_epoch_end`):** At the end of each validation and test epoch, the model computes and logs the concordance index (C-index), a standard metric for evaluating survival models. It also calculates a stratified metric (`_metric`): the mean of the per-race-group C-indices minus their standard deviation.
- **Optimizer and Scheduler:** It configures an Adam optimizer for training and uses a `CosineAnnealingLR` scheduler to dynamically adjust the learning rate over the course of training.
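
To make the stratified metric concrete, here is a minimal sketch of "mean of per-group C-indices minus their standard deviation". The function `simple_cindex` is a simplified stand-in written for this example (it ignores tied event times) and is not the `concordance_index` function the notebook calls; the toy data is hypothetical.

```python
import numpy as np


def simple_cindex(time, score, event):
    """Simplified C-index: a higher score should mean longer survival."""
    concordant, comparable = 0.0, 0
    n = len(time)
    for i in range(n):
        for j in range(n):
            # A pair is comparable when i had the event strictly before time[j].
            if event[i] == 1 and time[i] < time[j]:
                comparable += 1
                if score[i] < score[j]:
                    concordant += 1.0
                elif score[i] == score[j]:
                    concordant += 0.5
    return concordant / comparable


def stratified_cindex(time, score, event, group):
    """Mean of the per-group C-indices minus their (population) standard deviation."""
    per_group = [
        simple_cindex(time[group == g], score[group == g], event[group == g])
        for g in np.unique(group)
    ]
    return float(np.mean(per_group) - np.std(per_group))


# Hypothetical toy data: group "A" is ranked perfectly, group "B" mostly wrong.
time = np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0])
event = np.array([1, 1, 0, 1, 1, 1])
score = np.array([0.1, 0.2, 0.3, 0.3, 0.1, 0.2])
group = np.array(["A", "A", "A", "B", "B", "B"])

metric = stratified_cindex(time, score, event, group)
print(round(metric, 4))  # 0.3333
```

Group A scores 1.0 and group B scores 1/3, so the mean is 2/3 and the standard deviation is 1/3. The penalty for uneven performance pulls the final value well below the plain average.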

**Theories and Concepts**

**Hybrid Model Architecture**

The model is designed to process both categorical and continuous data, a common challenge with tabular datasets.

- **Categorical Embeddings:** The `CatEmbeddings` class is a crucial component. Instead of using one-hot encoding, which can create a very large and sparse feature space, it represents each category (e.g., 'race' or 'city') as a dense vector. This technique allows the model to learn meaningful, low-dimensional representations where categories with similar properties are closer in the embedding space. These learned embeddings are then concatenated and passed through a projection layer, a small neural network, to learn complex interactions between them.
- **Multilayer Perceptron (MLP):** The `NN` class acts as the main model. It takes the concatenated categorical embeddings and the numerical features and combines them. This combined vector is fed through an `ODST` layer, batch normalization, and dropout, and a final `nn.Linear` layer maps the result to a single score. ODST stands for Oblivious Differentiable Sparsemax Trees, the differentiable decision-tree layer from the NODE architecture for tabular data.
- **Data Concatenation:** The NN class showcases a hybrid architecture. It takes the processed categorical embeddings and the numerical features (which have been separately normalized) and concatenates them into a single, comprehensive feature vector. This combined vector is then fed into a shared Multilayer Perceptron (MLP) to learn non-linear relationships between all the input features.
- **Advanced Layers:** The `ODST` layer brings a differentiable, tree-based inductive bias to tabular data. The model also uses Batch Normalization (`nn.BatchNorm1d`) to stabilize and accelerate training, and Dropout (`nn.Dropout`) to prevent overfitting by randomly deactivating neurons during training.
- **Weight Initialization:** The NN class includes a loop to initialize weights using Xavier Normal initialization (`nn.init.xavier_normal_`). This is a common practice that helps the network converge faster and avoids problems with vanishing or exploding gradients.

**Advanced Loss Function for Survival Data**

The most distinctive part of this code is the custom `calc_loss` function, which is designed specifically for time-to-event data.

- **Ranking-Based Loss:** Standard Mean Squared Error (MSE) is not suitable for survival analysis because the model needs to learn to correctly rank subjects based on their survival time, not just predict the exact time. For example, it is more important to predict that Subject A will survive longer than Subject B than it is to get the exact survival time of each. The code achieves this with a pairwise comparison approach: it creates all possible pairs of subjects within a batch and compares their predicted scores.
- **Handling Censored Data:** A key challenge in survival analysis is handling censored data: cases where the event (e.g., death) has not yet occurred at the end of the study. For example, if Subject A is still alive after 5 years and Subject B dies after 3 years, we know Subject A outlived Subject B, so the pair is usable. If Subject C is still alive after 2 years and Subject D is still alive after 4 years, we cannot rank them; `calc_loss` drops such pairs up front by keeping only pairs with at least one observed event. The `get_mask` function then removes the remaining ambiguous pairs, in which one subject is censored at or before the other subject's event time (e.g., censored at 2 years versus an event at 3 years).
- **Margin-Based Hinge Loss:** The loss function itself is a variation of a Hinge Loss, which is commonly used in support vector machines and ranking problems. The margin hyperparameter ensures that the model is penalized only when the difference between the predicted outcomes of a pair of subjects is less than this margin, forcing the model to be confident in its rankings.
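
The three ideas above (pairwise ranking, censoring mask, margin) can be sketched in NumPy as follows. The function name `pairwise_hinge_loss` and the toy arrays are assumptions for illustration; the logic is meant to mirror the steps described for `calc_loss` and `get_mask`, with higher predictions standing for longer survival.

```python
from itertools import combinations

import numpy as np


def pairwise_hinge_loss(time, pred, event, margin=0.5):
    # All index pairs (left, right) with left < right.
    idx = np.array(list(combinations(range(len(time)), 2)))
    left, right = idx[:, 0], idx[:, 1]

    # Keep only pairs in which at least one subject had an observed event.
    keep = (event[left] == 1) | (event[right] == 1)
    left, right = left[keep], right[keep]

    # +1 if the left subject survived longer, -1 otherwise.
    sign = 2 * (time[left] > time[right]).astype(int) - 1
    hinge = np.maximum(0.0, -sign * (pred[left] - pred[right]) + margin)

    # Invalid pairs: one subject is censored at or before the other's event time.
    invalid = (time[left] >= time[right]) & (event[left] == 1) & (event[right] == 0)
    invalid |= (time[right] >= time[left]) & (event[right] == 1) & (event[left] == 0)
    valid = ~invalid

    # Average the hinge loss over the valid pairs only.
    return (hinge * valid).sum() / valid.sum()


time = np.array([1.0, 2.0, 3.0])
event = np.array([1, 1, 0])

well_ranked = pairwise_hinge_loss(time, np.array([0.0, 1.0, 2.0]), event)
badly_ranked = pairwise_hinge_loss(time, np.array([2.0, 1.0, 0.0]), event)
print(well_ranked)             # 0.0
print(round(badly_ranked, 4))  # 1.8333
```

When predictions are ordered like the survival times and separated by at least the margin, every pair contributes zero. Reversing the order makes every valid pair pay its score gap plus the margin.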

**Fairness and Regularization**

The model incorporates concepts to promote fairness and prevent overfitting.

- **Fairness Regularization:** The `get_race_losses` function is a sophisticated attempt to ensure the model performs equitably across different subgroups. By calculating the loss for each racial group and then adding the standard deviation of these losses to the main loss function, the model is penalized for having widely different performance metrics across these groups. This encourages the model to be more robust and fair, preventing it from favoring one group over another.
- **Regularization and Activation:** The model uses several key techniques to prevent overfitting:
  - **Batch Normalization (nn.BatchNorm1d):** This layer normalizes the activations of the previous layer, helping to stabilize and accelerate training.
  - **Dropout (nn.Dropout):** During training, this layer randomly "drops out" a percentage of neurons, forcing the network to learn more robust and redundant features.
  - **GELU Activation:** The `GELU` (Gaussian Error Linear Unit) is an activation function used to introduce non-linearity into the network. It's often seen as a smoother alternative to the popular ReLU function.
- **Auxiliary Loss:** The `aux_cls` (auxiliary head) and `aux_loss` are part of a multi-task learning approach. The model is trained on a primary task (pairwise ranking of survival times) and on a secondary, related task: regressing the log-transformed `efs_time` from the hidden representation, with a mean squared error computed only on patients with an observed event (`efs == 1`). This can help the model learn more robust and generalizable features.
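
As a small illustration of the fairness regularization described above, the sketch below adds the standard deviation of per-group losses to a main loss. The helper names and the example numbers are hypothetical; the 0.1 weight matches the factor used in `get_full_loss`, and the standard deviation is the population form (dividing by the number of groups).

```python
import numpy as np


def group_loss_std(per_group_losses):
    """Population standard deviation of the per-group losses."""
    losses = np.asarray(per_group_losses, dtype=float)
    mean = losses.mean()
    return float(np.sqrt(((losses - mean) ** 2).mean()))


def total_loss(main_loss, per_group_losses, weight=0.1):
    """Main loss plus a weighted penalty for uneven per-group losses."""
    return main_loss + weight * group_loss_std(per_group_losses)


# Hypothetical per-group losses for two batches with the same main loss.
even = total_loss(0.5, [0.5, 0.5, 0.5])
uneven = total_loss(0.5, [0.4, 0.6])
print(even)              # 0.5
print(round(uneven, 4))  # 0.51
```

The penalty vanishes when every group has the same loss and grows with the spread, so gradient descent is nudged toward solutions that treat the groups evenly.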

**PyTorch Lightning and Optimizer**

The code uses PyTorch Lightning to streamline the training process and includes advanced optimization techniques.

- **PyTorch Lightning:** `LitNN` inherits from `pl.LightningModule`. This framework abstracts away the standard boilerplate code for training and evaluation loops, simplifying the code and making it easier to manage.
- **Adam Optimizer and Cosine Annealing:** The `configure_optimizers` method uses the Adam optimizer with weight decay. It also attaches `CosineAnnealingLR`, a scheduler that moves the learning rate along a cosine curve from its initial value to `eta_min` over `T_max` epochs (45 here), stepping once per epoch.
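
The cosine annealing curve has a simple closed form, sketched below in plain Python. The function name is hypothetical, and the numbers are the ones that appear in this notebook (`lr=0.00487`, `eta_min=6e-3`, `T_max=45`). The sketch ignores anything else that may alter the learning rate during training, such as the SWA callback.

```python
import math


def cosine_annealing_lr(epoch, base_lr, eta_min, t_max):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


base_lr, eta_min, t_max = 0.00487, 6e-3, 45

for epoch in (0, 15, 30, 45):
    print(epoch, round(cosine_annealing_lr(epoch, base_lr, eta_min, t_max), 6))
```

Because `eta_min` is larger than the starting learning rate with these values, the curve rises from 0.00487 to 0.006 over the first 45 epochs rather than decaying.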

In [5]:
class CatEmbeddings(nn.Module):
    """
    Embedding module for the categorical dataframe.
    """
    def __init__(
        self,
        projection_dim: int,
        categorical_cardinality: List[int],
        embedding_dim: int
    ):
        """
        projection_dim: The dimension of the final output after projecting the concatenated embeddings into a lower-dimensional space.
        categorical_cardinality: A list where each element represents the number of unique categories (cardinality) in each categorical feature.
        embedding_dim: The size of the embedding space for each categorical feature.
        self.embeddings: list of embedding layers for each categorical feature.
        self.projection: sequential neural network that goes from the embedding to the output projection dimension with GELU activation.
        """
        super(CatEmbeddings, self).__init__()
        self.embeddings = nn.ModuleList([
            nn.Embedding(cardinality, embedding_dim)
            for cardinality in categorical_cardinality
        ])
        self.projection = nn.Sequential(
            nn.Linear(embedding_dim * len(categorical_cardinality), projection_dim),
            nn.GELU(),
            nn.Linear(projection_dim, projection_dim)
        )

    def forward(self, x_cat):
        """
        Apply the projection to the concatenated embeddings of all categorical features.
        """
        x_cat = [embedding(x_cat[:, i]) for i, embedding in enumerate(self.embeddings)]
        x_cat = torch.cat(x_cat, dim=1)
        return self.projection(x_cat)


class NN(nn.Module):
    """
    Train a model on both categorical embeddings and numerical data.
    """
    def __init__(
            self,
            continuous_dim: int,
            categorical_cardinality: List[int],
            embedding_dim: int,
            projection_dim: int,
            hidden_dim: int,
            dropout: float = 0
    ):
        """
        continuous_dim: The number of continuous features.
        categorical_cardinality: A list of integers representing the number of unique categories in each categorical feature.
        embedding_dim: The dimensionality of the embedding space for each categorical feature.
        projection_dim: The size of the projected output space for the categorical embeddings.
        hidden_dim: The number of neurons in the hidden layer of the MLP.
        dropout: The dropout rate applied in the network.
        self.embeddings: previous embeddings for categorical data.
        self.mlp: defines an MLP model with an ODST layer followed by batch normalization and dropout.
        self.out: linear output layer that maps the output of the MLP to a single value
        self.dropout: defines dropout
        Weights initialization with xavier normal algorithm and biases with zeros.
        """
        super(NN, self).__init__()
        self.embeddings = CatEmbeddings(projection_dim, categorical_cardinality, embedding_dim)
        self.mlp = nn.Sequential(
            ODST(projection_dim + continuous_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout)
        )
        self.out = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(dropout)

        # initialize weights
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x_cat, x_cont):
        """
        Embed the categorical data and concatenate it with the continuous variables.
        Apply dropout, pass through the MLP, and return the 1-dimensional output together with the hidden representation.
        """
        x = self.embeddings(x_cat)
        x = torch.cat([x, x_cont], dim=1)
        x = self.dropout(x)
        x = self.mlp(x)
        return self.out(x), x


@functools.lru_cache
def combinations(N):
    """
    calculates all possible 2-combinations (pairs) of a tensor of indices from 0 to N-1, 
    and caches the result using functools.lru_cache for optimization
    """
    ind = torch.arange(N)
    comb = torch.combinations(ind, r=2)
    return comb.cuda()


class LitNN(pl.LightningModule):
    """
    Main Model creation and losses definition to fully train the model.
    """
    def __init__(
            self,
            continuous_dim: int,
            categorical_cardinality: List[int],
            embedding_dim: int,
            projection_dim: int,
            hidden_dim: int,
            lr: float = 1e-3,
            dropout: float = 0.2,
            weight_decay: float = 1e-3,
            aux_weight: float = 0.1,
            margin: float = 0.5,
            race_index: int = 0
    ):
        """
        continuous_dim: The number of continuous input features.
        categorical_cardinality: A list of integers, where each element corresponds to the number of unique categories for each categorical feature.
        embedding_dim: The dimension of the embeddings for the categorical features.
        projection_dim: The dimension of the projected space after embedding concatenation.
        hidden_dim: The size of the hidden layers in the feedforward network (MLP).
        lr: The learning rate for the optimizer.
        dropout: Dropout probability to avoid overfitting.
        weight_decay: The L2 regularization term for the optimizer.
        aux_weight: Weight used for auxiliary tasks.
        margin: Margin of the pairwise hinge loss.
        race_index: Index of the race_group column within the categorical features.
        """
        super(LitNN, self).__init__()
        self.save_hyperparameters()

        # Creates an instance of the NN model defined above
        self.model = NN(
            continuous_dim=self.hparams.continuous_dim,
            categorical_cardinality=self.hparams.categorical_cardinality,
            embedding_dim=self.hparams.embedding_dim,
            projection_dim=self.hparams.projection_dim,
            hidden_dim=self.hparams.hidden_dim,
            dropout=self.hparams.dropout
        )
        self.targets = []

        # Defines a small feedforward neural network that performs an auxiliary task with 1-dimensional output
        self.aux_cls = nn.Sequential(
            nn.Linear(self.hparams.hidden_dim, self.hparams.hidden_dim // 3),
            nn.GELU(),
            nn.Linear(self.hparams.hidden_dim // 3, 1)
        )

    def on_before_optimizer_step(self, optimizer):
        """
        Compute the 2-norm for each layer
        If using mixed precision, the gradients are already unscaled here
        """
        norms = grad_norm(self.model, norm_type=2)
        self.log_dict(norms)

    def forward(self, x_cat, x_cont):
        """
        Forward pass that outputs the 1-dimensional prediction and the embeddings (raw output)
        """
        x, emb = self.model(x_cat, x_cont)
        return x.squeeze(1), emb

    def training_step(self, batch, batch_idx):
        """
        defines how the model processes each batch of data during training.
        A batch is a combination of : categorical data, continuous data, efs_time (y) and efs event.
        y_hat is the efs_time prediction on all data and aux_pred is auxiliary prediction on embeddings.
        Calculates loss and race_group loss on full data.
        Auxiliary loss is calculated with an event mask, ignoring efs=0 predictions and taking the average.
        Returns loss and aux_loss multiplied by weight defined above.
        """
        x_cat, x_cont, y, efs = batch
        y_hat, emb = self(x_cat, x_cont)
        aux_pred = self.aux_cls(emb).squeeze(1)
        loss, race_loss = self.get_full_loss(efs, x_cat, y, y_hat)
        aux_loss = nn.functional.mse_loss(aux_pred, y, reduction='none')
        aux_mask = efs == 1
        aux_loss = (aux_loss * aux_mask).sum() / aux_mask.sum()
        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        self.log("race_loss", race_loss, on_epoch=True, prog_bar=True, logger=True, on_step=False)
        self.log("aux_loss", aux_loss, on_epoch=True, prog_bar=True, logger=True, on_step=False)
        return loss + aux_loss * self.hparams.aux_weight

    def get_full_loss(self, efs, x_cat, y, y_hat):
        """
        Output loss and race_group loss.
        """
        loss = self.calc_loss(y, y_hat, efs)
        race_loss = self.get_race_losses(efs, x_cat, y, y_hat)
        loss += 0.1 * race_loss
        return loss, race_loss

    def get_race_losses(self, efs, x_cat, y, y_hat):
        """
        Calculate the loss for each race_group and return the standard deviation of those losses.
        """
        races = torch.unique(x_cat[:, self.hparams.race_index])
        race_losses = []
        for race in races:
            ind = x_cat[:, self.hparams.race_index] == race
            race_losses.append(self.calc_loss(y[ind], y_hat[ind], efs[ind]))
        race_loss = sum(race_losses) / len(race_losses)
        races_loss_std = sum((r - race_loss)**2 for r in race_losses) / len(race_losses)
        return torch.sqrt(races_loss_std)

    def calc_loss(self, y, y_hat, efs):
        """
        Most important part of the model : loss function used for training.
        We face survival data with event indicators along with time-to-event.

        This function computes the main loss through the following steps:
        * create all data pairs with "combinations" function (= all "two subjects" combinations)
        * make sure that we have at least 1 event in each pair
        * convert y to +1 or -1 depending on the correct ranking
        * loss is computed using a margin-based hinge loss
        * mask is applied to ensure only valid pairs are being used (censored data can't be ranked with event in some cases)
        * average loss on all pairs is returned
        """
        N = y.shape[0]
        comb = combinations(N)
        comb = comb[(efs[comb[:, 0]] == 1) | (efs[comb[:, 1]] == 1)]
        pred_left = y_hat[comb[:, 0]]
        pred_right = y_hat[comb[:, 1]]
        y_left = y[comb[:, 0]]
        y_right = y[comb[:, 1]]
        y = 2 * (y_left > y_right).int() - 1
        loss = nn.functional.relu(-y * (pred_left - pred_right) + self.hparams.margin)
        mask = self.get_mask(comb, efs, y_left, y_right)
        loss = (loss.double() * (mask.double())).sum() / mask.sum()
        return loss

    def get_mask(self, comb, efs, y_left, y_right):
        """
        Defines all invalid comparisons :
        * Case 1: "Left outlived Right" but Right is censored
        * Case 2: "Right outlived Left" but Left is censored
        Masks for case 1 and case 2 are combined using |= operator and inverted using ~ to create a "valid pair mask"
        """
        left_outlived = y_left >= y_right
        left_1_right_0 = (efs[comb[:, 0]] == 1) & (efs[comb[:, 1]] == 0)
        mask2 = (left_outlived & left_1_right_0)
        right_outlived = y_right >= y_left
        right_1_left_0 = (efs[comb[:, 1]] == 1) & (efs[comb[:, 0]] == 0)
        mask2 |= (right_outlived & right_1_left_0)
        mask2 = ~mask2
        mask = mask2
        return mask

    def validation_step(self, batch, batch_idx):
        """
        This method defines how the model processes each batch during validation
        """
        x_cat, x_cont, y, efs = batch
        y_hat, emb = self(x_cat, x_cont)
        loss, race_loss = self.get_full_loss(efs, x_cat, y, y_hat)
        self.targets.append([y, y_hat.detach(), efs, x_cat[:, self.hparams.race_index]])
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_validation_epoch_end(self):
        """
        At the end of the validation epoch, it computes and logs the concordance index
        """
        cindex, metric = self._calc_cindex()
        self.log("cindex", metric, on_epoch=True, prog_bar=True, logger=True)
        self.log("cindex_simple", cindex, on_epoch=True, prog_bar=True, logger=True)
        self.targets.clear()

    def _calc_cindex(self):
        """
        Calculate c-index accounting for each race_group or global.
        """
        y = torch.cat([t[0] for t in self.targets]).cpu().numpy()
        y_hat = torch.cat([t[1] for t in self.targets]).cpu().numpy()
        efs = torch.cat([t[2] for t in self.targets]).cpu().numpy()
        races = torch.cat([t[3] for t in self.targets]).cpu().numpy()
        metric = self._metric(efs, races, y, y_hat)
        cindex = concordance_index(y, y_hat, efs)
        return cindex, metric

    def _metric(self, efs, races, y, y_hat):
        """
        Calculate c-index accounting for each race_group
        """
        metric_list = []
        for race in np.unique(races):
            y_ = y[races == race]
            y_hat_ = y_hat[races == race]
            efs_ = efs[races == race]
            metric_list.append(concordance_index(y_, y_hat_, efs_))
        metric = float(np.mean(metric_list) - np.sqrt(np.var(metric_list)))
        return metric

    def test_step(self, batch, batch_idx):
        """
        Same as the validation step, but logs the loss under "test_loss"
        """
        x_cat, x_cont, y, efs = batch
        y_hat, emb = self(x_cat, x_cont)
        loss, race_loss = self.get_full_loss(efs, x_cat, y, y_hat)
        self.targets.append([y, y_hat.detach(), efs, x_cat[:, self.hparams.race_index]])
        self.log("test_loss", loss)
        return loss

    def on_test_epoch_end(self) -> None:
        """
        At the end of the test epoch, calculates and logs the concordance index for the test set
        """
        cindex, metric = self._calc_cindex()
        self.log("test_cindex", metric, on_epoch=True, prog_bar=True, logger=True)
        self.log("test_cindex_simple", cindex, on_epoch=True, prog_bar=True, logger=True)
        self.targets.clear()


    def configure_optimizers(self):
        """
        configures the optimizer and learning rate scheduler:
        * Optimizer: Adam optimizer with weight decay (L2 regularization).
        * Scheduler: Cosine Annealing scheduler, which adjusts the learning rate according to a cosine curve.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler_config = {
            "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=45,
                eta_min=6e-3
            ),
            "interval": "epoch",
            "frequency": 1,
            "strict": False,
        }

        return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

# Model Training, Generate Predictions, and Create Submission File

The `main` and `train_final` functions orchestrate the training, evaluation, and final submission generation. They implement a robust 5-fold stratified cross-validation strategy to ensure model reliability and fairness.

**Training and Optimization / `train_final()` Function**

This function is responsible for the actual training of a single model instance within the cross-validation loop.

- **Hyperparameter Initialization:** It defines a set of default hyperparameters like `embedding_dim`, `hidden_dim`, `lr`, and `dropout`. If a different set of hyperparameters is provided (e.g., from an Optuna study), it uses those instead.
- **Model Instantiation:** A LitNN model is created with the specified hyperparameters and the dimensions of the preprocessed data.
- **PyTorch Lightning Trainer:** A PyTorch Lightning Trainer is instantiated to manage the training process efficiently. The trainer is configured to use a GPU (`accelerator='cuda'`) and a maximum of 100 epochs.
- **Callbacks:** Several callbacks are used to enhance the training:
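  - `ModelCheckpoint`: Configured to keep the single best checkpoint according to validation loss (`val_loss`). Note that `trainer.fit()` receives only `dl_train` here, so `val_loss` is not logged during fitting, and the function returns the model as it stands after the final epoch rather than a reloaded checkpoint.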
  - `LearningRateMonitor`: Logs the learning rate at each epoch to monitor the scheduler's behavior.
  - `TQDMProgressBar`: Provides a visual progress bar during training.
  - `StochasticWeightAveraging` (SWA): A technique that averages model weights during the final epochs (from epoch 40 onwards) to improve generalization.
- **Training and Evaluation:** The `trainer.fit()` method trains the model using the provided training data loader (`dl_train`), and `trainer.test()` evaluates the final model on the validation data loader (`dl_val`).

**Data Splitting and Ensemble Learning / `main()` Function**

This function serves as the main driver of the entire pipeline.

- **Data Loading:** It loads the raw training and test data and applies initial feature engineering.
- **Cross-Validation Split:** It creates a `StratifiedKFold` object to split the training data into 5 folds. The stratification key combines `race_group` with a flag for `age_at_hct == 0.044` (which the notebook treats as identifying newborns), so the distribution of these features is preserved across all folds.
- **Iterative Training:** The code then iterates through each fold:
  - It creates the training and validation splits for the current fold.
  - It preprocesses the data for the current fold, creating the necessary dataloaders.
  - It calls `train_final()` to train a new model instance on the current fold's data.
  - It re-prepares the data to make predictions on the full test set using the just-trained model.
  - It adds the predictions from the current fold to a running total (`test_pred`), effectively creating an ensemble of predictions from all five models.

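- **Submission File Generation:** After all folds are completed, the final ensembled predictions are averaged. The script then creates the `submission.csv` file by assigning the negative of these predictions to the prediction column. The sign flip is needed because the ranking loss trains the model to output higher values for longer survival, whereas the competition expects a risk score in which higher values mean a higher likelihood of the event.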
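
The accumulation, averaging, and sign flip can be illustrated with a few lines of NumPy. The per-fold arrays below are made-up numbers standing in for model outputs, where a higher value means longer predicted survival.

```python
import numpy as np

# Hypothetical outputs of two fold models for three test patients.
fold_preds = [
    np.array([0.2, 1.5, -0.3]),
    np.array([0.4, 1.1, -0.1]),
]

# Accumulate the per-fold predictions, then average them.
test_pred = np.zeros(3)
for pred in fold_preds:
    test_pred += pred
test_pred /= len(fold_preds)

# Flip the sign so that a higher value means a higher risk of the event.
risk_score = -test_pred
print(risk_score)
```

The patient with the longest predicted survival (index 1) ends up with the lowest risk score, and the patient with the shortest predicted survival (index 2) with the highest.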

**Theories and Concepts**

**Training and Optimization**

The `train_final` function orchestrates the training process and utilizes several advanced techniques to ensure the model is robust and performs well.

- **PyTorch Lightning Framework:** The code leverages PyTorch Lightning, a framework that simplifies the training of complex models in PyTorch. It abstracts away the boilerplate code for the training loop, validation, and testing steps, allowing the developer to focus on the model and the data.
- **Hyperparameter Optimization:** The `hparams` dictionary contains a set of hyperparameters that are not learned from the data. These values (like `lr`, `dropout`, and `hidden_dim`) have a significant impact on the model's performance and are likely the result of a separate hyperparameter search or tuning process.
- **Model Checkpointing:** The `ModelCheckpoint` callback is a crucial component for saving the best-performing model. During training, it monitors a specific metric (in this case, `val_loss`) and saves a copy of the model's weights whenever a new minimum is achieved. This ensures that even if the model's performance degrades in later epochs, you can always retrieve the best version.
- **Stochastic Weight Averaging (SWA):** This is a powerful optimization technique used to find a better, more generalized solution. Instead of using the weights from the very last training epoch, SWA averages the weights of the model over a certain period of training (starting from `swa_epoch_start`). This technique helps the model escape sharp local minima and converge to a broader, flatter minimum, which often leads to better performance on unseen data.
- **GPU Acceleration:** The use of `accelerator='cuda'` in the `pl.Trainer` object indicates that the training process is accelerated using a GPU (Graphics Processing Unit). This is essential for deep learning models, as GPUs can perform the massive parallel computations required for training much faster than a CPU.
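
The core idea of Stochastic Weight Averaging, a running mean of weight snapshots taken from a given epoch onward, can be sketched in pure Python. This is a simplified illustration with a hypothetical helper and toy one-element "weights"; the Lightning callback additionally manages the learning rate and batch-norm statistics, which are not modelled here.

```python
def swa_average(snapshots, swa_epoch_start):
    """Equal-weight running average of the snapshots from swa_epoch_start on."""
    avg, n_averaged = None, 0
    for epoch, weights in enumerate(snapshots):
        if epoch < swa_epoch_start:
            continue  # epochs before the SWA phase are not averaged
        if avg is None:
            avg = list(weights)
        else:
            # Incremental mean: avg += (w - avg) / (n + 1)
            avg = [a + (w - a) / (n_averaged + 1) for a, w in zip(avg, weights)]
        n_averaged += 1
    return avg


# Toy weight snapshots, one per epoch (hypothetical values).
snapshots = [[1.0], [2.0], [3.0], [4.0], [5.0]]
print(swa_average(snapshots, swa_epoch_start=2))  # [4.0]
```

Only the snapshots from epochs 2, 3, and 4 enter the average, so the result is the mean of 3.0, 4.0, and 5.0.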

**Data Splitting and Ensemble Learning**

The `main` function is the core driver of the entire workflow, handling the data and implementing a robust training strategy.

- **Stratified K-Fold Cross-Validation:** The code uses `StratifiedKFold` to split the training data. This is a form of cross-validation where the data is divided into k subsets (or folds), and the model is trained and validated k times. The "stratified" part is particularly important: it ensures that the proportion of specific categories (in this case, a combination of `race_group` and the newborn flag) is maintained in each fold. This prevents a situation where one fold has an unrepresentative distribution of a key group, which could lead to a biased model.
- **Ensemble Learning:** The code trains a separate model for each of the 5 folds and then averages their predictions on the test set. This is a simple but effective form of ensemble learning. By combining the outputs of multiple models, the final prediction becomes more stable and less prone to the random fluctuations or overfitting of a single model, thereby improving overall generalization and accuracy.
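
To show what stratification on a combined key achieves, here is a pure-Python sketch that deals the members of each stratum round-robin into folds. It is a simplified stand-in for `StratifiedKFold`, not its actual algorithm, and the group labels and flags are made up for the example.

```python
import random
from collections import Counter, defaultdict


def stratified_folds(keys, n_splits=5, seed=0):
    """Assign indices to folds so every stratum is spread evenly across folds."""
    rng = random.Random(seed)
    by_key = defaultdict(list)
    for idx, key in enumerate(keys):
        by_key[key].append(idx)

    folds = [[] for _ in range(n_splits)]
    for key in sorted(by_key):
        members = by_key[key]
        rng.shuffle(members)
        # Deal the shuffled members of this stratum round-robin into the folds.
        for pos, idx in enumerate(members):
            folds[pos % n_splits].append(idx)
    return folds


# Hypothetical data: a group label combined with a boolean flag, as a string key.
race = ["A"] * 10 + ["B"] * 20
newborn = [i % 2 == 0 for i in range(30)]
keys = [r + str(n) for r, n in zip(race, newborn)]

folds = stratified_folds(keys)
for fold in folds:
    print(sorted(Counter(keys[i] for i in fold).items()))
```

Every fold receives one sample from each of the two "A" strata and two from each of the "B" strata, so the 1:2 ratio between the groups is preserved in all five folds.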

In [6]:
def train_final(X_num_train, dl_train, dl_val, transformers, hparams=None, categorical_cols=None):
    """
    Defines model hyperparameters and fit the model.
    """
    if hparams is None:
        hparams = {            
            "embedding_dim": 27,
            "projection_dim": 43,
            "hidden_dim": 76,
            "lr": 0.00487,
            "dropout": 0.38886,
            "aux_weight": 0.49631,
            "margin": 0.10025,
            "weight_decay": 0.000115
        }
    model = LitNN(
        continuous_dim=X_num_train.shape[1],
        categorical_cardinality=[len(t.classes_) for t in transformers],
        race_index=categorical_cols.index("race_group"),
        **hparams
    )
    checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=1)
    trainer = pl.Trainer(
        accelerator='cuda',
        max_epochs=100,  # 56 epochs were used for the competition run, but 100 epochs achieved the lowest validation loss
        log_every_n_steps=6,
        callbacks=[
            checkpoint_callback,
            LearningRateMonitor(logging_interval='epoch'),
            TQDMProgressBar(),
            StochasticWeightAveraging(swa_lrs=1e-5, swa_epoch_start=40, annealing_epochs=15)
        ],
    )
    trainer.fit(model, dl_train)
    trainer.test(model, dl_val)
    return model.eval()

def main(hparams):
    """
    Main function to train the model.
    The steps are as follows:
    * load data and fill efs and efs_time for the test data with 1
    * initialize pred array with 0
    * get categorical and numerical columns
    * split the train data on the stratified criterion : race_group * newborns yes/no
    * preprocess the fold data (create dataloaders)
    * train the model and create final submission output
    """
    test, train_original = load_data()
    test['efs_time'] = 1
    test['efs'] = 1
    test_pred = np.zeros(test.shape[0])
    categorical_cols, numerical = get_feature_types(train_original)
    kf = StratifiedKFold(n_splits=5, shuffle=True, )
    for i, (train_index, test_index) in enumerate(
        kf.split(
            train_original, train_original.race_group.astype(str) + (train_original.age_at_hct == 0.044).astype(str)
        )
    ):
        tt = train_original.copy()
        train = tt.iloc[train_index]
        val = tt.iloc[test_index]
        X_cat_val, X_num_train, X_num_val, dl_train, dl_val, transformers = preprocess_data(train, val)
        # Pass hparams through so that the argument given to main() is actually used.
        model = train_final(X_num_train, dl_train, dl_val, transformers, hparams=hparams, categorical_cols=categorical_cols)
        # Create submission: preprocess again with the test set in place of the
        # validation set, so the *_val outputs below hold the test data.
        X_cat_val, X_num_train, X_num_val, dl_train, dl_val, transformers = preprocess_data(train, test)
        pred, _ = model.cuda().eval()(
            torch.tensor(X_cat_val, dtype=torch.long).cuda(),
            torch.tensor(X_num_val, dtype=torch.float32).cuda()
        )
        test_pred += pred.detach().cpu().numpy()
        
    subm_data = pd.read_csv("/kaggle/input/equity-post-HCT-survival-predictions/sample_submission.csv")
    subm_data['prediction'] = -test_pred
    subm_data.to_csv('submission.csv', index=False)
    
    display(subm_data.head())
    return 


hparams = None
res = main(hparams)
print("done")

Test shape: (3, 59)
Train shape: (28800, 61)


|   | ID | prediction |
|---|-------|-----------|
| 0 | 28800 | -0.842534 |
| 1 | 28801 | -0.04208 |
| 2 | 28802 | -1.066475 |


done
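
Two details of `main` are easy to miss: the fold predictions are accumulated with `test_pred += ...` and never divided by the number of folds, and the submission stores `-test_pred`. The sketch below uses made-up per-fold outputs (the array `fold_preds` and its shape are assumptions, not values from the notebook) to check that summing and averaging produce the same patient ordering, and that negation simply reverses that ordering. The source does not state why the sign is flipped; presumably the network's output increases with survival rather than with risk.

```python
import numpy as np

rng = np.random.default_rng(0)
n_folds, n_patients = 5, 8

# Hypothetical per-fold model outputs: one row per fold, one column per patient.
fold_preds = rng.normal(size=(n_folds, n_patients))

summed = fold_preds.sum(axis=0)      # what `test_pred += pred` accumulates
averaged = fold_preds.mean(axis=0)   # the conventional ensemble average

# Dividing by a positive constant does not change the ordering of patients.
same_order = np.array_equal(np.argsort(summed), np.argsort(averaged))

# Negating the scores reverses the ordering, which is what `-test_pred` does.
reversed_order = np.array_equal(np.argsort(-summed), np.argsort(summed)[::-1])

print(same_order, reversed_order)
```

Because the concordance index depends only on how patients are ranked, the missing division by 5 does not affect the score, whereas the sign of the prediction does.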
