# Deepsurv


In [1]:
!pip install -U sentence-transformers > /dev/null 2>&1
!pip install xgboost > /dev/null 2>&1
!pip install scikit-learn==1.4.2 scikit-survival==0.23.1 > /dev/null 2>&1
!pip install torchtuples > /dev/null 2>&1
!pip install pycox > /dev/null 2>&1
!pip install numpy==1.21.5  > /dev/null 2>&1
!pip install interpret-core  > /dev/null 2>&1
!pip install lightgbm > /dev/null 2>&1
!pip install shap > /dev/null 2>&1
!pip install lifelines pycox > /dev/null 2>&1
!pip install pycountry > /dev/null 2>&1
!pip install -U sentence-transformers xgboost scikit-learn==1.4.2 scikit-survival==0.23.1 torchtuples pycox numpy==1.21.5 interpret-core lightgbm shap lifelines pycox pycountry > /dev/null 2>&1


In [13]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torchtuples as tt
import kagglehub
import os

from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from lifelines import CoxPHFitter
from sksurv.metrics import concordance_index_censored
from pycox.models.cox import CoxPH
from pycox.evaluation import EvalSurv
import torch
import torch.nn as nn
from torch.optim import Adam


In [14]:
# Life Expectancy dataset
life_exp_path = kagglehub.dataset_download("kumarajarshi/life-expectancy-who")
life_exp_file = os.path.join(life_exp_path, "Life Expectancy Data.csv")
life_exp_df = pd.read_csv(life_exp_file)
print("Life Expectancy Sample:")
print(life_exp_df.head())

# Heart Failure dataset (not used in LightGBM, but kept for context)
heart_path = kagglehub.dataset_download("fedesoriano/heart-failure-prediction")
heart_file = os.path.join(heart_path, "heart.csv")
heart_df = pd.read_csv(heart_file)
print("Heart Failure Sample:")
print(heart_df.head())

# Age Dataset
age_path = kagglehub.dataset_download("imoore/age-dataset")
age_file = os.path.join(age_path, "AgeDataset-V1.csv")
age_df = pd.read_csv(age_file)
print("Age Dataset Sample:")
print(age_df.head())

# World important events Dataset
events_path = kagglehub.dataset_download("saketk511/world-important-events-ancient-to-modern")
events_file = os.path.join(events_path, "World Important Dates.csv")
events_df = pd.read_csv(events_file)
print("World Important Events Sample:")
print(events_df.head())

# Plane Crash Dataset
plane_crash_path = kagglehub.dataset_download("nguyenhoc/plane-crash")
plane_crash_file = os.path.join(plane_crash_path, "planecrashinfo_20181121001952.csv")  
planes_df = pd.read_csv(plane_crash_file)
print("Historical Plane Crashes Sample:")
print(planes_df.head())

# Gloabl Life Expectancy dataset
global_le_path = kagglehub.dataset_download("hasibalmuzdadid/global-life-expectancy-historical-dataset")
global_le_file = os.path.join(global_le_path, "global life expectancy dataset.csv")
global_le_df = pd.read_csv(global_le_file)
print("Global Life Expectancy Historical Dataset Sample:")
print(global_le_df.head())

# US death rate Dataset
death_rates_path = kagglehub.dataset_download("melissamonfared/death-rates-united-states")
death_rates_file = os.path.join(death_rates_path, "Death_rates.csv")
death_rates_df = pd.read_csv(death_rates_file)
print("Death Rates United States Dataset Sample:")
print(death_rates_df.head())

Life Expectancy Sample:
       Country  Year      Status  Life expectancy   Adult Mortality  \
0  Afghanistan  2015  Developing              65.0            263.0   
1  Afghanistan  2014  Developing              59.9            271.0   
2  Afghanistan  2013  Developing              59.9            268.0   
3  Afghanistan  2012  Developing              59.5            272.0   
4  Afghanistan  2011  Developing              59.2            275.0   

   infant deaths  Alcohol  percentage expenditure  Hepatitis B  Measles   ...  \
0             62     0.01               71.279624         65.0      1154  ...   
1             64     0.01               73.523582         62.0       492  ...   
2             66     0.01               73.219243         64.0       430  ...   
3             69     0.01               78.184215         67.0      2787  ...   
4             71     0.01                7.097109         68.0      3013  ...   

   Polio  Total expenditure  Diphtheria    HIV/AIDS         GD

In [15]:
def enhanced_feature_engineering(df, life_exp_df, global_le_df, death_rates_df):
    """
    Feature engineering for survival analysis with proper censoring handling
    """
    # -------- Validate Input Columns --------
    required_columns = {
        'Country', 'Gender', 'Occupation', 'Birth year',
        'Death year', 'Age of death'
    }
    missing = required_columns - set(df.columns)
    if missing:
        raise KeyError(f"Missing required columns: {missing}")

    # -------- Set Observation Year --------
    current_year = 2019
    
    # -------- Basic Cleaning --------
    df['Country'] = df['Country'].str.split(';').str[0].str.strip()
    df['Gender'] = np.where(df['Gender'] == 'Male', 1, 
                          np.where(df['Gender'] == 'Female', 0, 0.5))

    # -------- Clinical Features --------
    stress_map = {'Politician': 9, 'Military personnel': 8, 'Journalist': 7,
                  'Businessperson': 6, 'Artist': 5, 'Teacher': 4, 
                  'Researcher': 3, 'Other': 5, 'Unknown': 5}
    df['stress_score'] = df['Occupation'].map(stress_map).fillna(5).astype('float32') / 9.0

    life_exp_df[' BMI '] = pd.to_numeric(life_exp_df[' BMI '], errors='coerce')
    country_bmi = life_exp_df.groupby('Country')[' BMI '].median().to_dict()
    df['avg_bmi'] = df['Country'].map(country_bmi).fillna(25).astype('float32')

    df['smoking_prev'] = (1 / (1 + np.exp((df['Birth year'] - 1950) / 10))).astype('float32')
    df['smoking_prev'] = np.clip(df['smoking_prev'], 0.1, 0.6)

    # -------- Country-Level Features --------
    global_le_df.columns = global_le_df.columns.str.strip()
    global_le_melted = global_le_df.melt(
        id_vars=['Country Name', 'Country Code'],
        value_vars=[str(y) for y in range(1960, current_year+1)],
        var_name='Year',
        value_name='Life_Exp_Value'
    )
    
    global_le_agg = (
        global_le_melted
        .sort_values(['Country Name', 'Year'], ascending=[True, False])
        .groupby('Country Name')
        ['Life_Exp_Value']
        .first()
        .reset_index()
        .rename(columns={'Country Name': 'Country'})
    )
    
    df = df.merge(global_le_agg, on='Country', how='left')
    df['global_life_exp'] = df['Life_Exp_Value'].fillna(df['Life_Exp_Value'].median())

    # -------- Survival Data Setup --------
    df['censored'] = (df['Death year'] > current_year).astype(int)
    df['T'] = np.where(
        df['censored'] == 1,
        current_year - df['Birth year'],
        df['Age of death']
    ).clip(0, 120)

    return df


In [9]:
class DeepSurv(nn.Module):
    def __init__(self, input_dim):
        super(DeepSurv, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.model(x)


In [16]:
def train_deepsurv_model(df):
    """
    Train DeepSurv Model using Pycox for survival analysis
    """
    feature_list = ['stress_score', 'avg_bmi', 'smoking_prev', 'global_life_exp']
    X = df[feature_list].values.astype('float32')
    durations = df['T'].values.astype('float32')
    events = df['censored'].values.astype('bool')

    X_train, X_val, durations_train, durations_val, events_train, events_val = train_test_split(
        X, durations, events, test_size=0.2, random_state=42
    )

    in_features = X_train.shape[1]
    num_nodes = [64, 64]
    out_features = 1
    batch_norm = True
    dropout = 0.1
    net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout)

    model = CoxPH(net, tt.optim.Adam)
    y_train_tuple = (durations_train, events_train)
    y_val_tuple = (durations_val, events_val)

    model.fit(X_train, y_train_tuple, batch_size=256, epochs=100, verbose=True, val_data=(X_val, y_val_tuple))

    model.compute_baseline_hazards()

    return model, X_val, durations_val, events_val


In [17]:
def evaluate_deepsurv_model(model, X_val, durations_val, events_val):
    """
    Evaluate DeepSurv Model Performance using Concordance Index
    """
    surv = model.predict_surv_df(X_val)
    surv.index = pd.to_numeric(surv.index, errors='coerce')

    if 0 not in surv.index:
        new_row = pd.DataFrame(np.ones((1, surv.shape[1])), index=[0], columns=surv.columns)
        surv = pd.concat([new_row, surv])
        surv = surv.sort_index()

    ev = EvalSurv(surv, durations_val, events_val, censor_surv='km')
    c_index = ev.concordance_td('antolini')

    print(f"\n📊 DeepSurv Concordance Index: {c_index:.4f}")
    return c_index


In [18]:
# 1️⃣ Process the Dataset
processed_batch = enhanced_feature_engineering(age_df, life_exp_df, global_le_df, death_rates_df)

# 2️⃣ Train DeepSurv
print("\n🚀 Training DeepSurv Model...")
deepsurv_model, X_val, durations_val, events_val = train_deepsurv_model(processed_batch)

# 3️⃣ Evaluate Model
evaluate_deepsurv_model(deepsurv_model, X_val, durations_val, events_val)



🚀 Training DeepSurv Model...
0:	[7s / 7s],	
1:	[7s / 15s],	
2:	[7s / 22s],	
3:	[7s / 30s],	
4:	[7s / 37s],	
5:	[7s / 44s],	
6:	[7s / 52s],	
7:	[7s / 59s],	
8:	[7s / 1m:7s],	
9:	[7s / 1m:14s],	
10:	[7s / 1m:22s],	
11:	[7s / 1m:29s],	
12:	[7s / 1m:36s],	
13:	[7s / 1m:44s],	
14:	[7s / 1m:51s],	
15:	[7s / 1m:59s],	
16:	[7s / 2m:6s],	
17:	[7s / 2m:14s],	
18:	[7s / 2m:21s],	
19:	[7s / 2m:28s],	
20:	[7s / 2m:36s],	
21:	[7s / 2m:43s],	
22:	[7s / 2m:51s],	
23:	[7s / 2m:58s],	
24:	[7s / 3m:6s],	
25:	[7s / 3m:13s],	
26:	[7s / 3m:20s],	
27:	[7s / 3m:28s],	
28:	[7s / 3m:35s],	
29:	[7s / 3m:43s],	
30:	[7s / 3m:50s],	
31:	[7s / 3m:58s],	
32:	[7s / 4m:5s],	
33:	[7s / 4m:12s],	
34:	[7s / 4m:20s],	
35:	[7s / 4m:27s],	
36:	[7s / 4m:35s],	
37:	[7s / 4m:42s],	
38:	[7s / 4m:49s],	
39:	[7s / 4m:57s],	
40:	[7s / 5m:4s],	
41:	[7s / 5m:12s],	
42:	[7s / 5m:19s],	
43:	[7s / 5m:27s],	
44:	[7s / 5m:34s],	
45:	[7s / 5m:42s],	
46:	[7s / 5m:49s],	
47:	[7s / 5m:57s],	
48:	[7s / 6m:4s],	
49:	[7s / 6m:12s],	
50:	[7s / 6

0.0