# Tabnet Survival anaysis


In [1]:
!pip install -U sentence-transformers > /dev/null 2>&1
!pip install xgboost > /dev/null 2>&1
!pip install scikit-learn==1.4.2 scikit-survival==0.23.1 > /dev/null 2>&1
!pip install torchtuples > /dev/null 2>&1
!pip install pycox > /dev/null 2>&1
!pip install numpy==1.21.5  > /dev/null 2>&1
!pip install interpret-core  > /dev/null 2>&1
!pip install lightgbm > /dev/null 2>&1
!pip install shap > /dev/null 2>&1

In [26]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import gc
import kagglehub
import contextlib
import logging

from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.impute import SimpleImputer
from pytorch_tabnet.tab_model import TabNetRegressor


import lightgbm as lgb
import shap
import torch




In [3]:
# Life Expectancy dataset
life_exp_path = kagglehub.dataset_download("kumarajarshi/life-expectancy-who")
life_exp_file = os.path.join(life_exp_path, "Life Expectancy Data.csv")
life_exp_df = pd.read_csv(life_exp_file)
print("Life Expectancy Sample:")
print(life_exp_df.head())

# Heart Failure dataset (not used in LightGBM, but kept for context)
heart_path = kagglehub.dataset_download("fedesoriano/heart-failure-prediction")
heart_file = os.path.join(heart_path, "heart.csv")
heart_df = pd.read_csv(heart_file)
print("Heart Failure Sample:")
print(heart_df.head())

# Age Dataset
age_path = kagglehub.dataset_download("imoore/age-dataset")
age_file = os.path.join(age_path, "AgeDataset-V1.csv")
age_df = pd.read_csv(age_file)
print("Age Dataset Sample:")
print(age_df.head())

# World important events Dataset
events_path = kagglehub.dataset_download("saketk511/world-important-events-ancient-to-modern")
events_file = os.path.join(events_path, "World Important Dates.csv")
events_df = pd.read_csv(events_file)
print("World Important Events Sample:")
print(events_df.head())

# Plane Crash Dataset
plane_crash_path = kagglehub.dataset_download("nguyenhoc/plane-crash")
plane_crash_file = os.path.join(plane_crash_path, "planecrashinfo_20181121001952.csv")  
planes_df = pd.read_csv(plane_crash_file)
print("Historical Plane Crashes Sample:")
print(planes_df.head())

# Gloabl Life Expectancy dataset
global_le_path = kagglehub.dataset_download("hasibalmuzdadid/global-life-expectancy-historical-dataset")
global_le_file = os.path.join(global_le_path, "global life expectancy dataset.csv")
global_le_df = pd.read_csv(global_le_file)
print("Global Life Expectancy Historical Dataset Sample:")
print(global_le_df.head())

# US death rate Dataset
death_rates_path = kagglehub.dataset_download("melissamonfared/death-rates-united-states")
death_rates_file = os.path.join(death_rates_path, "Death_rates.csv")
death_rates_df = pd.read_csv(death_rates_file)
print("Death Rates United States Dataset Sample:")
print(death_rates_df.head())

Life Expectancy Sample:
       Country  Year      Status  Life expectancy   Adult Mortality  \
0  Afghanistan  2015  Developing              65.0            263.0   
1  Afghanistan  2014  Developing              59.9            271.0   
2  Afghanistan  2013  Developing              59.9            268.0   
3  Afghanistan  2012  Developing              59.5            272.0   
4  Afghanistan  2011  Developing              59.2            275.0   

   infant deaths  Alcohol  percentage expenditure  Hepatitis B  Measles   ...  \
0             62     0.01               71.279624         65.0      1154  ...   
1             64     0.01               73.523582         62.0       492  ...   
2             66     0.01               73.219243         64.0       430  ...   
3             69     0.01               78.184215         67.0      2787  ...   
4             71     0.01                7.097109         68.0      3013  ...   

   Polio  Total expenditure  Diphtheria    HIV/AIDS         GD

In [19]:
def enhanced_feature_engineering(df, life_exp_df, global_le_df, death_rates_df):
    """
    Feature engineering for survival analysis using TabNet:
    - Handles categorical encoding
    - Creates meaningful health-related features
    - Adds censoring and survival time variables
    """
    # -------- Basic Cleaning --------
    df['Country'] = df['Country'].str.split(';').str[0].str.strip()
    df['Gender'] = np.where(df['Gender'] == 'Male', 1, np.where(df['Gender'] == 'Female', 0, 0.5))

    # -------- Essential Features --------
    stress_map = {'Politician': 9, 'Military personnel': 8, 'Journalist': 7,
                  'Businessperson': 6, 'Artist': 5, 'Teacher': 4, 
                  'Researcher': 3, 'Other': 5, 'Unknown': 5}
    df['stress_score'] = df['Occupation'].map(stress_map).fillna(5).astype('float32') / 9.0

    # Get BMI from country statistics
    life_exp_df[' BMI '] = pd.to_numeric(life_exp_df[' BMI '], errors='coerce')
    country_bmi = life_exp_df.groupby('Country')[' BMI '].median().to_dict()
    df['avg_bmi'] = df['Country'].map(country_bmi).fillna(25).astype('float32')

    # Heart disease risk
    df['heart_disease_risk'] = (0.4 * df['Gender'] +
                                0.3 * df['stress_score'] +
                                0.3 * df['avg_bmi']).astype('float32')

    # Smoking Prevalence
    df['smoking_prev'] = (1 / (1 + np.exp((df['Birth year'] - 1950) / 10))).astype('float32')
    df['smoking_prev'] = np.clip(df['smoking_prev'], 0.1, 0.6)

    # -------- Merge Country-Level Health Data --------
    life_exp_filtered = life_exp_df.groupby('Country')[['Alcohol', 'GDP', 'Schooling']].median()
    df = df.join(life_exp_filtered.add_prefix('country_'), on='Country', how='left')

    # -------------------- FIX: Handle Global Life Expectancy Data --------------------
    print("Global Life Expectancy Columns:", global_le_df.columns)  # Debugging

    # Rename for consistency
    global_le_df = global_le_df.rename(columns={'Country Name': 'Country'})

    # Convert from wide to long format (yearly data to single column)
    global_le_df = global_le_df.melt(
        id_vars=['Country'], 
        var_name='Year', 
        value_name='Life_Expectancy'
    )

    # Ensure numeric conversion and remove invalid values
    global_le_df["Year"] = pd.to_numeric(global_le_df["Year"], errors='coerce')
    global_le_df["Life_Expectancy"] = pd.to_numeric(global_le_df["Life_Expectancy"], errors='coerce')
    global_le_df = global_le_df.dropna(subset=['Life_Expectancy'])

    # Compute average life expectancy per country
    global_le_agg = (
        global_le_df.groupby("Country")["Life_Expectancy"]
        .mean()
        .reset_index()
    )

    df = df.merge(global_le_agg.rename(columns={'Life_Expectancy': 'global_life_exp'}), on='Country', how='left')

    # -------------------- FIX: Handle Death Rates Data --------------------
    print("Death Rates Dataset Columns:", death_rates_df.columns)  # Debugging

    # Use the "ESTIMATE" column as Death Rate
    correct_col = "ESTIMATE"

    # Ensure numeric conversion
    death_rates_df[correct_col] = pd.to_numeric(death_rates_df[correct_col], errors="coerce")

    # Group by Country and take the average death rate
    death_rates_agg = death_rates_df.groupby("Country")[correct_col].mean().reset_index()
    death_rates_agg.rename(columns={correct_col: "avg_death_rate"}, inplace=True)

    # Merge death rates with main dataset
    df = df.merge(death_rates_agg, on="Country", how="left")


    # -------------------- Compute Censoring & Survival Time --------------------
    df['censored'] = (df['Death year'] > 2024).astype(int)  # 1 = alive (censored), 0 = dead
    df['T'] = np.where(df['censored'] == 1, 2024 - df['Birth year'], df['Age of death'])
    df['T'] = df['T'].clip(lower=0)  # Ensure no negative survival times

    # -------------------- Remove Unnecessary Features --------------------
    drop_cols = ['plane_crash_count', 'num_events', 'avg_plane_fatalities']
    df = df.drop(columns=[col for col in drop_cols if col in df.columns])

    return df


In [16]:
def train_tabnet(df):
    """
    Trains a TabNet model for survival analysis using T (time-to-event) as the target.
    """
    # Define features and target
    features = ['stress_score', 'avg_bmi', 'heart_disease_risk', 'smoking_prev', 
                'country_Alcohol', 'global_life_exp']
    
    X = df[features]
    y = df[['T', 'censored']]  # We need both time and event status for survival analysis

    # Split data
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

    # Handle missing values with imputation
    imputer = SimpleImputer(strategy="median")
    X_train = imputer.fit_transform(X_train)
    X_val = imputer.transform(X_val)

    # Convert target variables to PyTorch tensors
    T_train = torch.tensor(y_train['T'].values, dtype=torch.float32).reshape(-1, 1)
    C_train = torch.tensor(y_train['censored'].values, dtype=torch.float32).reshape(-1, 1)
    T_val = torch.tensor(y_val['T'].values, dtype=torch.float32).reshape(-1, 1)
    C_val = torch.tensor(y_val['censored'].values, dtype=torch.float32).reshape(-1, 1)

    # TabNet Model
    tabnet = TabNetRegressor()
    
    # Loss function: Negative Log Likelihood for survival analysis
    def survival_loss(preds, T, C):
        """
        Implements a loss function that considers censoring using Negative Log Likelihood.
        """
        risk = -torch.log(preds) * (1 - C)  # Event observed
        risk += -torch.log(1 - preds) * C   # Censored
        return risk.mean()

    # Training Loop
    max_epochs = 100
    batch_size = 512
    for epoch in range(max_epochs):
        tabnet.fit(
            X_train, T_train,
            eval_set=[(X_val, T_val)],
            patience=10,
            batch_size=batch_size,
            virtual_batch_size=256
        )
        preds = tabnet.predict(X_val)
        loss = survival_loss(torch.tensor(preds, dtype=torch.float32), T_val, C_val)
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

    return tabnet, X_val, y_val


In [20]:
# -------------------- STEP 3: Evaluate Model with Concordance Index --------------------
def evaluate_tabnet(tabnet, X_val, y_val):
    """
    Evaluates the survival model using Concordance Index (C-index).
    """
    preds = tabnet.predict(X_val)  # Survival predictions
    c_index = concordance_index(y_val['T'], -preds, y_val['censored'])
    print(f"Concordance Index: {c_index:.3f}")
    return c_index



In [27]:
# -------------------- Execute Training & Evaluation --------------------
processed_batch = enhanced_feature_engineering(age_df, life_exp_df, global_le_df, death_rates_df)
tabnet_model, X_val, y_val = train_tabnet(processed_batch)
evaluate_tabnet(tabnet_model, X_val, y_val)

Global Life Expectancy Columns: Index(['Country Name', 'Country Code', '1960', '1961', '1962', '1963', '1964',
       '1965', '1966', '1967', '1968', '1969', '1970', '1971', '1972', '1973',
       '1974', '1975', '1976', '1977', '1978', '1979', '1980', '1981', '1982',
       '1983', '1984', '1985', '1986', '1987', '1988', '1989', '1990', '1991',
       '1992', '1993', '1994', '1995', '1996', '1997', '1998', '1999', '2000',
       '2001', '2002', '2003', '2004', '2005', '2006', '2007', '2008', '2009',
       '2010', '2011', '2012', '2013', '2014', '2015', '2016', '2017', '2018',
       '2019', '2020'],
      dtype='object')
Death Rates Dataset Columns: Index(['INDICATOR', 'UNIT', 'UNIT_NUM', 'STUB_NAME', 'STUB_NAME_NUM',
       'STUB_LABEL', 'STUB_LABEL_NUM', 'YEAR', 'YEAR_NUM', 'AGE', 'AGE_NUM',
       'ESTIMATE', 'Country'],
      dtype='object')




RuntimeError: index -1 is out of bounds for dimension 1 with size 6