# **Survival Analysis Revamp: Death Prediction 2.0**
## **Project Overview**
This project aims to revamp the original **death prediction model** into a **production-grade survival analysis system**. Instead of directly predicting an age of death, we model the **probability of survival over time**, accounting for censoring (individuals still alive).

## **Why Survival Analysis?**
Survival analysis is widely used in **healthcare, finance, and customer retention**:
- **Healthcare:** Predict patient survival rates.
- **Finance:** Credit risk and loan default probabilities.
- **Subscription Businesses:** Customer churn prediction (e.g., Netflix, Spotify).

## **Key Steps**
### **1️⃣ Reframe as Survival Analysis**
- Convert the dataset to survival format.
- Use Python’s `lifelines` and PyTorch-based `pycox`.
- Handle **censored data** (people still alive in 2024).

### **2️⃣ Train Survival Models**
- **Traditional Cox Proportional Hazards Model (`lifelines`)**
- **DeepSurv (Neural Networks for Survival Analysis)**
- **Transformer-based Time-to-Event Models (TFTs, Hugging Face Transformers)**

### **3️⃣ Deploy as an API**
- Wrap the trained model in a **FastAPI** server.
- Package with **Docker**.
- Deploy using **Google Cloud Run / AWS Lambda**.

## **Technologies Used**
- **Libraries:** `lifelines`, `pycox`, `FastAPI`, `Hugging Face Transformers`
- **Model Training:** Traditional (Cox Model) & Deep Learning (DeepSurv, TFT)
- **Deployment:** FastAPI, Docker, Google Cloud Run/AWS Lambda

---

> 📌 **Next Steps:** Run the first code cell to preprocess the dataset and train the baseline Cox Proportional Hazards Model.


In [1]:
!pip install -U sentence-transformers > /dev/null 2>&1
!pip install xgboost > /dev/null 2>&1
!pip install scikit-learn==1.4.2 scikit-survival==0.23.1 > /dev/null 2>&1


In [11]:
import numpy as np 
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import kagglehub
import os
import gc

from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, mean_absolute_percentage_error, mean_squared_log_error, explained_variance_score
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from scipy.stats import norm
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

import xgboost as xgb
from sksurv.util import Surv
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored




In [3]:
# Download Life Expectancy dataset
life_exp_path = kagglehub.dataset_download("kumarajarshi/life-expectancy-who")
life_exp_file = os.path.join(life_exp_path, "Life Expectancy Data.csv")
life_exp_df = pd.read_csv(life_exp_file)

print(life_exp_df.head())

heart_path = kagglehub.dataset_download("fedesoriano/heart-failure-prediction")
heart_file = os.path.join(heart_path, "heart.csv")
heart_df = pd.read_csv(heart_file)

print(heart_df.head())

age_path = kagglehub.dataset_download("imoore/age-dataset")
age_file = os.path.join(age_path, "AgeDataset-V1.csv")  #
age_df = pd.read_csv(age_file)

print(age_df.head())

       Country  Year      Status  Life expectancy   Adult Mortality  \
0  Afghanistan  2015  Developing              65.0            263.0   
1  Afghanistan  2014  Developing              59.9            271.0   
2  Afghanistan  2013  Developing              59.9            268.0   
3  Afghanistan  2012  Developing              59.5            272.0   
4  Afghanistan  2011  Developing              59.2            275.0   

   infant deaths  Alcohol  percentage expenditure  Hepatitis B  Measles   ...  \
0             62     0.01               71.279624         65.0      1154  ...   
1             64     0.01               73.523582         62.0       492  ...   
2             66     0.01               73.219243         64.0       430  ...   
3             69     0.01               78.184215         67.0      2787  ...   
4             71     0.01                7.097109         68.0      3013  ...   

   Polio  Total expenditure  Diphtheria    HIV/AIDS         GDP  Population  \
0    6.

# Exploration and Data Cleaning

In [4]:
print("Life Expectancy Columns:", life_exp_df.columns.tolist())
print("Years:", life_exp_df['Year'].unique())
print("Missing Values:\n", life_exp_df.isnull().sum())

Life Expectancy Columns: ['Country', 'Year', 'Status', 'Life expectancy ', 'Adult Mortality', 'infant deaths', 'Alcohol', 'percentage expenditure', 'Hepatitis B', 'Measles ', ' BMI ', 'under-five deaths ', 'Polio', 'Total expenditure', 'Diphtheria ', ' HIV/AIDS', 'GDP', 'Population', ' thinness  1-19 years', ' thinness 5-9 years', 'Income composition of resources', 'Schooling']
Years: [2015 2014 2013 2012 2011 2010 2009 2008 2007 2006 2005 2004 2003 2002
 2001 2000]
Missing Values:
 Country                              0
Year                                 0
Status                               0
Life expectancy                     10
Adult Mortality                     10
infant deaths                        0
Alcohol                            194
percentage expenditure               0
Hepatitis B                        553
Measles                              0
 BMI                                34
under-five deaths                    0
Polio                               19
Total

In [5]:
print("Heart Failure Columns:", heart_df.columns.tolist())
print("Missing Values:\n", heart_df.isnull().sum())

Heart Failure Columns: ['Age', 'Sex', 'ChestPainType', 'RestingBP', 'Cholesterol', 'FastingBS', 'RestingECG', 'MaxHR', 'ExerciseAngina', 'Oldpeak', 'ST_Slope', 'HeartDisease']
Missing Values:
 Age               0
Sex               0
ChestPainType     0
RestingBP         0
Cholesterol       0
FastingBS         0
RestingECG        0
MaxHR             0
ExerciseAngina    0
Oldpeak           0
ST_Slope          0
HeartDisease      0
dtype: int64


In [6]:
print("Age Dataset Columns:", age_df.columns.tolist())
print("Missing Values:\n", age_df.isnull().sum())

Age Dataset Columns: ['Id', 'Name', 'Short description', 'Gender', 'Country', 'Occupation', 'Birth year', 'Death year', 'Manner of death', 'Age of death']
Missing Values:
 Id                         0
Name                       0
Short description      67900
Gender                133646
Country               335509
Occupation            206914
Birth year                 0
Death year                 1
Manner of death      1169406
Age of death               1
dtype: int64


In [7]:
# -------------------------- Life Expectancy Dataset --------------------------
# Drop rows with missing target
life_exp_df = life_exp_df.dropna(subset=['Life expectancy '])

# Fill Alcohol: country/year median → global median if still missing
life_exp_df['Alcohol'] = life_exp_df.groupby(['Country', 'Year'])['Alcohol'].transform(
    lambda x: x.fillna(x.median())
)
life_exp_df['Alcohol'] = life_exp_df['Alcohol'].fillna(life_exp_df['Alcohol'].median())

# Fill GDP: country median → global median
life_exp_df['GDP'] = life_exp_df.groupby('Country')['GDP'].transform(
    lambda x: x.fillna(x.median())
)
life_exp_df['GDP'] = life_exp_df['GDP'].fillna(life_exp_df['GDP'].median())

# Drop unnecessary columns
life_exp_df = life_exp_df.drop(columns=[
    'Hepatitis B', 'Population', 'Income composition of resources',
    ' thinness  1-19 years', ' thinness 5-9 years'
])

# Final fill for any remaining nulls
life_exp_df = life_exp_df.fillna(method='ffill').fillna(method='bfill')

# -------------------------- Heart Failure Dataset ----------------------------
# Convert categoricals
heart_df = pd.get_dummies(
    heart_df, 
    columns=['ChestPainType', 'RestingECG', 'ST_Slope'],
    drop_first=True
)
heart_df['ExerciseAngina'] = heart_df['ExerciseAngina'].map({'Y': 1, 'N': 0})

# ---------------------------- Age Dataset ------------------------------------
# Drop death-related missingness
age_df = age_df.dropna(subset=['Death year', 'Age of death'])

# Clean categorical columns
for col in ['Gender', 'Country', 'Occupation', 'Short description']:
    age_df[col] = age_df[col].fillna('Unknown')

# Simplify country names
age_df['Country'] = age_df['Country'].str.split(';').str[0]

# Group rare occupations (threshold = 1000)
occupation_counts = age_df['Occupation'].value_counts()
age_df['Occupation'] = np.where(
    age_df['Occupation'].isin(occupation_counts[occupation_counts >= 1000].index),
    age_df['Occupation'],
    'Other'
)

# Encode gender (handle unknowns)
age_df['Gender'] = np.where(
    age_df['Gender'] == 'Male', 1,
    np.where(age_df['Gender'] == 'Female', 0, 0.5)
)

# Drop unnecessary column
age_df = age_df.drop(columns=['Manner of death'])

# ---------------------------- Validation -------------------------------------
print("\nFinal Missing Values:")
print("Life Expectancy:\n", life_exp_df.isnull().sum())
print("\nHeart Failure:\n", heart_df.isnull().sum())
print("\nAge Dataset:\n", age_df.isnull().sum())

print("\nSample Categories:")
print("Occupations:", age_df['Occupation'].unique()[:10])
print("Countries:", age_df['Country'].unique()[:10])

  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, ou


Final Missing Values:
Life Expectancy:
 Country                   0
Year                      0
Status                    0
Life expectancy           0
Adult Mortality           0
infant deaths             0
Alcohol                   0
percentage expenditure    0
Measles                   0
 BMI                      0
under-five deaths         0
Polio                     0
Total expenditure         0
Diphtheria                0
 HIV/AIDS                 0
GDP                       0
Schooling                 0
dtype: int64

Heart Failure:
 Age                  0
Sex                  0
RestingBP            0
Cholesterol          0
FastingBS            0
MaxHR                0
ExerciseAngina       0
Oldpeak              0
HeartDisease         0
ChestPainType_ATA    0
ChestPainType_NAP    0
ChestPainType_TA     0
RestingECG_Normal    0
RestingECG_ST        0
ST_Slope_Flat        0
ST_Slope_Up          0
dtype: int64

Age Dataset:
 Id                   0
Name                 0
Short descr

In [8]:
print("Age Dataset Occupations:", age_df['Occupation'].unique()[:20])
print("Life Expectancy Countries:", life_exp_df['Country'].unique()[:20])

print("Life Expectancy Missing After Cleaning:", life_exp_df.isnull().sum())
print("Age Dataset Missing After Cleaning:", age_df.isnull().sum())

Age Dataset Occupations: ['Politician' 'Artist' 'Other' 'Astronomer' 'Athlete' 'Researcher'
 'Military personnel' 'Philosopher' 'Businessperson' 'Explorer'
 'Architect' 'Teacher' 'Aristocrat' 'Entrepreneur' 'Journalist' 'Engineer'
 'Author' 'Unknown' 'Religious figure' 'Official']
Life Expectancy Countries: ['Afghanistan' 'Albania' 'Algeria' 'Angola' 'Antigua and Barbuda'
 'Argentina' 'Armenia' 'Australia' 'Austria' 'Azerbaijan' 'Bahamas'
 'Bahrain' 'Bangladesh' 'Barbados' 'Belarus' 'Belgium' 'Belize' 'Benin'
 'Bhutan' 'Bolivia (Plurinational State of)']
Life Expectancy Missing After Cleaning: Country                   0
Year                      0
Status                    0
Life expectancy           0
Adult Mortality           0
infant deaths             0
Alcohol                   0
percentage expenditure    0
Measles                   0
 BMI                      0
under-five deaths         0
Polio                     0
Total expenditure         0
Diphtheria                0
 HIV/AI

In [9]:
# Global settings for batch processing
BATCH_SIZE = 50000
N_ITERATIONS = 5
BASE_RANDOM_STATE = 42

def preprocess_and_engineer_features(df, life_exp_df):
    """
    Given a raw batch (df) from the Age dataset and the life expectancy dataframe,
    perform data cleaning and synthetic feature engineering.
    """
    # --- Preprocessing & Cleaning ---
    df = df.dropna(subset=['Death year', 'Age of death'])
    for col in ['Gender', 'Country', 'Occupation', 'Short description']:
        df[col] = df[col].fillna('Unknown')
    df['Country'] = df['Country'].str.split(';').str[0]
    occupation_counts = df['Occupation'].value_counts()
    df['Occupation'] = np.where(
        df['Occupation'].isin(occupation_counts[occupation_counts >= 1000].index),
        df['Occupation'],
        'Other'
    )
    df['Gender'] = np.where(df['Gender'] == 'Male', 1,
                            np.where(df['Gender'] == 'Female', 0, 0.5))
    if 'Manner of death' in df.columns:
        df = df.drop(columns=['Manner of death'])
    
    # --- Synthetic Feature Engineering ---
    # 1. Stress Score
    stress_map = {
        'Politician': 9, 'Military personnel': 8, 'Journalist': 7,
        'Businessperson': 6, 'Artist': 5, 'Teacher': 4, 
        'Researcher': 3, 'Other': 5, 'Unknown': 5
    }
    df['stress_score'] = df['Occupation'].map(stress_map).fillna(5).astype('int8')
    
    # 2. BMI from Country (lookup from life_exp_df)
    country_bmi = life_exp_df.groupby('Country')[' BMI '].last().to_dict()
    df['avg_bmi'] = df['Country'].map(country_bmi).fillna(25).astype('float32')
    
    # 3. Heart Disease Risk (binary proxy)
    df['heart_disease_risk'] = np.where(df['Gender'] == 1, 0.65, 0.35).astype('float32')
    
    # 4. Smoking Prevalence
    birth_years = df['Birth year'].to_numpy()
    df['smoking_prev'] = np.clip(0.5 - 0.0035*(birth_years - 1950), 0.1, 0.6).astype('float32')
    
    # 5. Country Features: Merge with life expectancy data
    life_exp_filtered = (life_exp_df[['Country', 'Alcohol', 'GDP', 'Schooling']]
                         .sort_values('Country')
                         .groupby('Country').last()
                         .add_prefix('country_'))
    df['Country'] = df['Country'].astype('category')
    df = df.join(life_exp_filtered, on='Country', how='left')
    df['country_Alcohol'] = df['country_Alcohol'].fillna(df['country_Alcohol'].mean())
    df['country_GDP'] = df['country_GDP'].fillna(df['country_GDP'].median())
    df['country_Schooling'] = df['country_Schooling'].fillna(df['country_Schooling'].median())
    
    # 6. Text Features: Use TF-IDF on Occupation + Short description
    text_data = df['Occupation'] + " " + df['Short description'].fillna('')
    tfidf = TfidfVectorizer(max_features=100)
    text_features = tfidf.fit_transform(text_data)
    text_df = pd.DataFrame(
        text_features.toarray(),
        columns=[f"tfidf_{i}" for i in range(text_features.shape[1])],
        index=df.index
    )
    df = pd.concat([df, text_df], axis=1)
    
    return df

def train_and_evaluate_models(df):
    """
    Given a processed DataFrame with synthetic features, split it into training/validation sets,
    train an XGBoost survival model (AFT) and a Cox PH model, and return evaluation metrics.
    """
    # Define the features to use: synthetic features + any TF-IDF features.
    feature_list = ['stress_score', 'avg_bmi', 'heart_disease_risk', 'smoking_prev', 'country_Alcohol', 'country_GDP']
    tfidf_cols = [col for col in df.columns if col.startswith('tfidf_')]
    features = feature_list + tfidf_cols

    X = df[features]
    y = df['Age of death']

    # Split data
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=BASE_RANDOM_STATE)

    # ----- XGBoost AFT Model -----
    dtrain = xgb.DMatrix(X_train, label=y_train, feature_names=features)
    dval   = xgb.DMatrix(X_val, label=y_val, feature_names=features)
    dtrain.set_float_info("label_lower_bound", y_train)
    dtrain.set_float_info("label_upper_bound", y_train)
    dval.set_float_info("label_lower_bound", y_val)
    dval.set_float_info("label_upper_bound", y_val)

    params_aft = {
        'objective': 'survival:aft',
        'eval_metric': 'aft-nloglik',
        'aft_loss_distribution': 'normal',
        'aft_loss_distribution_scale': 0.1,
        'tree_method': 'approx',
        'learning_rate': 0.01,
        'max_depth': 4,
        'subsample': 0.7,
        'colsample_bytree': 0.7,
        'verbosity': 1
    }
    print("\nXGBoost AFT Training:")
    try:
        model_xgb = xgb.train(
            params_aft,
            dtrain,
            num_boost_round=100,
            evals=[(dtrain, 'train'), (dval, 'val')],
            verbose_eval=10
        )
        preds_xgb = model_xgb.predict(dval)
        rmse = np.sqrt(mean_squared_error(y_val, preds_xgb))
        mae = mean_absolute_error(y_val, preds_xgb)
    except Exception as e:
        print("XGBoost training failed:", e)
        rmse = None
        mae = None

    # ----- Cox Proportional Hazards Model -----
    y_surv_train = Surv.from_arrays(event=np.ones(len(y_train), dtype=bool), time=y_train.to_numpy())
    cox = CoxPHSurvivalAnalysis(alpha=0.5)
    cox.fit(X_train, y_surv_train)
    cindex = concordance_index_censored(
        np.ones(len(y_val), dtype=bool),
        y_val.to_numpy(),
        cox.predict(X_val)
    )[0]

    coef_df = pd.DataFrame({
        'feature': X_train.columns,
        'coef': cox.coef_,
        'abs_coef': np.abs(cox.coef_)
    }).sort_values('abs_coef', ascending=False)

    metrics = {
        'XGBoost_RMSE': rmse,
        'XGBoost_MAE': mae,
        'Cox_Concordance': cindex,
        'Cox_Feature_Coefficients': coef_df.head(10)
    }
    return metrics

def process_and_train_batch(random_seed):
    """
    Sample a batch from the full age_df, apply preprocessing/feature engineering,
    then train and evaluate the models.
    """
    batch_df = age_df.sample(n=BATCH_SIZE, random_state=random_seed).copy()
    print(f"\nProcessing batch with random seed {random_seed} (shape: {batch_df.shape})")
    
    processed_batch = preprocess_and_engineer_features(batch_df, life_exp_df)
    metrics = train_and_evaluate_models(processed_batch)
    
    # Free memory
    del batch_df, processed_batch
    gc.collect()
    
    return metrics

In [12]:
results = []
for i in range(N_ITERATIONS):
    seed = BASE_RANDOM_STATE + i
    metrics = process_and_train_batch(seed)
    results.append(metrics)
    print(f"Results for batch {i+1}:")
    print(f"  XGBoost RMSE: {metrics['XGBoost_RMSE']}")
    print(f"  XGBoost MAE: {metrics['XGBoost_MAE']}")
    print(f"  Cox Concordance Index: {metrics['Cox_Concordance']}")
    print("  Top Cox Features:")
    print(metrics['Cox_Feature_Coefficients'])
    print("-" * 50)

print("\nFinal aggregated results from all batches:")
print(results)



Processing batch with random seed 42 (shape: (50000, 9))

XGBoost AFT Training:
[0]	train-aft-nloglik:27.62989	val-aft-nloglik:27.62989
[10]	train-aft-nloglik:27.62862	val-aft-nloglik:27.62862
[20]	train-aft-nloglik:27.62755	val-aft-nloglik:27.62755
[30]	train-aft-nloglik:27.62669	val-aft-nloglik:27.62669
[40]	train-aft-nloglik:27.62602	val-aft-nloglik:27.62602
[50]	train-aft-nloglik:27.62555	val-aft-nloglik:27.62555
[60]	train-aft-nloglik:27.62529	val-aft-nloglik:27.62529
[70]	train-aft-nloglik:27.62481	val-aft-nloglik:27.62468
[80]	train-aft-nloglik:27.62448	val-aft-nloglik:27.62419
[90]	train-aft-nloglik:27.62440	val-aft-nloglik:27.62399
[99]	train-aft-nloglik:27.62456	val-aft-nloglik:27.62407
Results for batch 1:
  XGBoost RMSE: 69.79566694124222
  XGBoost MAE: 67.78838658621311
  Cox Concordance Index: 0.5665436742785513
  Top Cox Features:
         feature       coef   abs_coef
3   smoking_prev -12.191143  12.191143
22      tfidf_16   1.142863   1.142863
9        tfidf_3   1.050