# Tabnet Survival anaysis


In [57]:
!pip install -U sentence-transformers > /dev/null 2>&1
!pip install xgboost > /dev/null 2>&1
!pip install scikit-learn==1.4.2 scikit-survival==0.23.1 > /dev/null 2>&1
!pip install torchtuples > /dev/null 2>&1
!pip install pycox > /dev/null 2>&1
!pip install numpy==1.21.5  > /dev/null 2>&1
!pip install interpret-core  > /dev/null 2>&1
!pip install lightgbm > /dev/null 2>&1
!pip install shap > /dev/null 2>&1
!pip install lifelines pycox > /dev/null 2>&1

In [63]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import gc
import kagglehub
import contextlib
import logging

from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.impute import SimpleImputer
from pytorch_tabnet.tab_model import TabNetRegressor
from sklearn.preprocessing import StandardScaler


import lightgbm as lgb
import shap
import torch
from lifelines import CoxPHFitter





In [3]:
# Life Expectancy dataset
life_exp_path = kagglehub.dataset_download("kumarajarshi/life-expectancy-who")
life_exp_file = os.path.join(life_exp_path, "Life Expectancy Data.csv")
life_exp_df = pd.read_csv(life_exp_file)
print("Life Expectancy Sample:")
print(life_exp_df.head())

# Heart Failure dataset (not used in LightGBM, but kept for context)
heart_path = kagglehub.dataset_download("fedesoriano/heart-failure-prediction")
heart_file = os.path.join(heart_path, "heart.csv")
heart_df = pd.read_csv(heart_file)
print("Heart Failure Sample:")
print(heart_df.head())

# Age Dataset
age_path = kagglehub.dataset_download("imoore/age-dataset")
age_file = os.path.join(age_path, "AgeDataset-V1.csv")
age_df = pd.read_csv(age_file)
print("Age Dataset Sample:")
print(age_df.head())

# World important events Dataset
events_path = kagglehub.dataset_download("saketk511/world-important-events-ancient-to-modern")
events_file = os.path.join(events_path, "World Important Dates.csv")
events_df = pd.read_csv(events_file)
print("World Important Events Sample:")
print(events_df.head())

# Plane Crash Dataset
plane_crash_path = kagglehub.dataset_download("nguyenhoc/plane-crash")
plane_crash_file = os.path.join(plane_crash_path, "planecrashinfo_20181121001952.csv")  
planes_df = pd.read_csv(plane_crash_file)
print("Historical Plane Crashes Sample:")
print(planes_df.head())

# Gloabl Life Expectancy dataset
global_le_path = kagglehub.dataset_download("hasibalmuzdadid/global-life-expectancy-historical-dataset")
global_le_file = os.path.join(global_le_path, "global life expectancy dataset.csv")
global_le_df = pd.read_csv(global_le_file)
print("Global Life Expectancy Historical Dataset Sample:")
print(global_le_df.head())

# US death rate Dataset
death_rates_path = kagglehub.dataset_download("melissamonfared/death-rates-united-states")
death_rates_file = os.path.join(death_rates_path, "Death_rates.csv")
death_rates_df = pd.read_csv(death_rates_file)
print("Death Rates United States Dataset Sample:")
print(death_rates_df.head())

Life Expectancy Sample:
       Country  Year      Status  Life expectancy   Adult Mortality  \
0  Afghanistan  2015  Developing              65.0            263.0   
1  Afghanistan  2014  Developing              59.9            271.0   
2  Afghanistan  2013  Developing              59.9            268.0   
3  Afghanistan  2012  Developing              59.5            272.0   
4  Afghanistan  2011  Developing              59.2            275.0   

   infant deaths  Alcohol  percentage expenditure  Hepatitis B  Measles   ...  \
0             62     0.01               71.279624         65.0      1154  ...   
1             64     0.01               73.523582         62.0       492  ...   
2             66     0.01               73.219243         64.0       430  ...   
3             69     0.01               78.184215         67.0      2787  ...   
4             71     0.01                7.097109         68.0      3013  ...   

   Polio  Total expenditure  Diphtheria    HIV/AIDS         GD

In [72]:
def enhanced_feature_engineering(df, life_exp_df, global_le_df, death_rates_df):
    """
    Feature engineering for survival analysis with proper censoring handling
    """
    # -------- Set Observation Year --------
    current_year = 2019  # Single source of truth for current year
    
    # -------- Basic Cleaning --------
    df['Country'] = df['Country'].str.split(';').str[0].str.strip()
    df['Gender'] = np.where(df['Gender'] == 'Male', 1, 
                          np.where(df['Gender'] == 'Female', 0, 0.5))

    # -------- Clinical Features --------
    # Stress Score
    stress_map = {'Politician': 9, 'Military personnel': 8, 'Journalist': 7,
                  'Businessperson': 6, 'Artist': 5, 'Teacher': 4, 
                  'Researcher': 3, 'Other': 5, 'Unknown': 5}
    df['stress_score'] = df['Occupation'].map(stress_map).fillna(5).astype('float32') / 9.0

    # BMI from Country Data
    life_exp_df[' BMI '] = pd.to_numeric(life_exp_df[' BMI '], errors='coerce')
    country_bmi = life_exp_df.groupby('Country')[' BMI '].median().to_dict()
    df['avg_bmi'] = df['Country'].map(country_bmi).fillna(25).astype('float32')

    # Smoking Prevalence (Synthetic)
    df['smoking_prev'] = (1 / (1 + np.exp((df['Birth year'] - 1950) / 10))).astype('float32')
    df['smoking_prev'] = np.clip(df['smoking_prev'], 0.1, 0.6)

    # -------- Country-Level Features --------
    # Life Expectancy (Global)
    global_le_df.columns = global_le_df.columns.str.strip()
    global_le_melted = global_le_df.melt(
        id_vars=['Country Name', 'Country Code'],
        value_vars=[str(y) for y in range(1960, current_year+1)],
        var_name='Year',
        value_name='Life_Exp_Value'
    )
    global_le_melted['Year'] = pd.to_numeric(global_le_melted['Year'])
    global_le_melted['Life_Exp_Value'] = pd.to_numeric(global_le_melted['Life_Exp_Value'], errors='coerce')
    
    global_le_agg = (
        global_le_melted
        .sort_values(['Country Name', 'Year'], ascending=[True, False])
        .groupby('Country Name')
        ['Life_Exp_Value']
        .first()
        .reset_index()
        .rename(columns={'Country Name': 'Country'})
    )
    df = df.merge(global_le_agg, on='Country', how='left')
    df['global_life_exp'] = df['Life_Exp_Value'].fillna(df['Life_Exp_Value'].median())
    
    # Death Rates
    death_rates_df['Death_Rate'] = (
        death_rates_df['ESTIMATE']
        .astype(str).str.replace(',', '')
        .replace(['nan', 'None'], np.nan)
        .astype(float)
    )
    death_rates_agg = death_rates_df.groupby('Country')['Death_Rate'].mean().reset_index()
    df = df.merge(death_rates_agg, on='Country', how='left')
    df['avg_death_rate'] = df['Death_Rate'].fillna(df['Death_Rate'].median())

    # -------- Survival Setup --------
    # Censoring Logic
    df['censored'] = (df['Death year'] > current_year).astype(int)
    df['T'] = np.where(
        df['censored'] == 1,
        current_year - df['Birth year'],
        df['Age of death']
    )
    
    # Ensure realistic survival times
    df['T'] = df['T'].clip(lower=0, upper=120)
    
    # Force censoring if none exists
    if df['censored'].sum() == 0:
        print("⚠️ Adding synthetic censoring")
        rng = np.random.default_rng(42)
        mask = (df['T'] > 0) & (df['T'] < current_year - df['Birth year'].min())
        eligible_indices = df[mask].index
        censored_idx = rng.choice(eligible_indices, size=int(len(df)*0.2), replace=False)
        df.loc[censored_idx, 'censored'] = 1
        df.loc[censored_idx, 'T'] = current_year - df.loc[censored_idx, 'Birth year']

    # Cleanup
    df = df.drop(columns=['Life_Exp_Value', 'Death_Rate'], errors='ignore')
    return df


In [70]:
def train_survival_model(df):
    """
    Trains CoxPH model with rigorous data validation
    """
    # Feature Selection
    features = [
        'stress_score', 'avg_bmi', 'smoking_prev',
        'global_life_exp', 'T', 'censored'
    ]
    
    # Data Preparation
    survival_df = df[features].dropna()
    
    # Data Validation
    print("\n📊 Censoring Distribution:")
    print(survival_df['censored'].value_counts())
    
    if survival_df['censored'].nunique() == 1:
        raise ValueError("All samples have same censoring status!")
    
    # Feature Processing
    scaler = StandardScaler()
    num_features = ['stress_score', 'avg_bmi', 'smoking_prev', 'global_life_exp']
    survival_df[num_features] = scaler.fit_transform(survival_df[num_features])

    # Model Training
    cph = CoxPHFitter()
    try:
        cph.fit(survival_df, duration_col='T', event_col='censored', fit_options={'step_size':0.1})
        print("✅ Training Successful!")
        print("\nModel Summary:")
        print(cph.print_summary())
        return cph
    except Exception as e:
        print(f"❌ Training Failed: {str(e)}")
        return None

In [73]:
processed_batch = enhanced_feature_engineering(age_df, life_exp_df, global_le_df, death_rates_df)
cox_model = train_survival_model(processed_batch)

# Show results
print("Cox Model Summary:")
cox_model.print_summary()

  return np.nanmean(a, axis, out=out, keepdims=keepdims)



📊 Censoring Distribution:
censored
0    1209713
1      13295
Name: count, dtype: int64




✅ Training Successful!

Model Summary:


0,1
model,lifelines.CoxPHFitter
duration col,'T'
event col,'censored'
baseline estimation,breslow
number of observations,1.22301e+06
number of events observed,13295
partial log-likelihood,-138544.25
time fit was run,2025-02-06 20:38:42 UTC

Unnamed: 0,coef,exp(coef),se(coef),coef lower 95%,coef upper 95%,exp(coef) lower 95%,exp(coef) upper 95%,cmp to,z,p,-log2(p)
stress_score,-0.03,0.97,0.01,-0.05,-0.01,0.95,0.99,0.0,-3.55,<0.005,11.36
avg_bmi,0.11,1.11,0.01,0.09,0.12,1.09,1.13,0.0,11.72,<0.005,103.05
smoking_prev,-11.11,0.0,0.15,-11.41,-10.81,0.0,0.0,0.0,-71.99,<0.005,inf
global_life_exp,-0.09,0.91,0.01,-0.1,-0.08,0.9,0.92,0.0,-15.38,<0.005,174.99

0,1
Concordance,0.82
Partial AIC,277096.50
log-likelihood ratio test,53437.14 on 4 df
-log2(p) of ll-ratio test,inf


None
Cox Model Summary:


0,1
model,lifelines.CoxPHFitter
duration col,'T'
event col,'censored'
baseline estimation,breslow
number of observations,1.22301e+06
number of events observed,13295
partial log-likelihood,-138544.25
time fit was run,2025-02-06 20:38:42 UTC

Unnamed: 0,coef,exp(coef),se(coef),coef lower 95%,coef upper 95%,exp(coef) lower 95%,exp(coef) upper 95%,cmp to,z,p,-log2(p)
stress_score,-0.03,0.97,0.01,-0.05,-0.01,0.95,0.99,0.0,-3.55,<0.005,11.36
avg_bmi,0.11,1.11,0.01,0.09,0.12,1.09,1.13,0.0,11.72,<0.005,103.05
smoking_prev,-11.11,0.0,0.15,-11.41,-10.81,0.0,0.0,0.0,-71.99,<0.005,inf
global_life_exp,-0.09,0.91,0.01,-0.1,-0.08,0.9,0.92,0.0,-15.38,<0.005,174.99

0,1
Concordance,0.82
Partial AIC,277096.50
log-likelihood ratio test,53437.14 on 4 df
-log2(p) of ll-ratio test,inf
