# **Survival Analysis Revamp: Death Prediction 2.0**
## **Project Overview**
This project aims to revamp the original **death prediction model** into a **production-grade survival analysis system**. Instead of directly predicting an age of death, we model the **probability of survival over time**, accounting for censoring (individuals still alive).
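To make "probability of survival over time, accounting for censoring" concrete, here is a minimal Kaplan-Meier estimator in plain NumPy. It is an illustrative sketch, not the project's code. The helper name `kaplan_meier` and the toy data are hypothetical, and `lifelines` ships a full implementation (`KaplanMeierFitter`).

```python
import numpy as np

def kaplan_meier(time, event):
    """Return (event_times, survival_probabilities) for right-censored data.

    time  : observed duration per subject (death time, or last time seen alive)
    event : 1 if the death was observed, 0 if the subject is censored
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    event_times = np.unique(time[event])          # distinct times with at least one death
    surv, s = [], 1.0
    for t in event_times:
        at_risk = np.sum(time >= t)               # still under observation just before t
        deaths = np.sum((time == t) & event)      # observed deaths at t
        s *= 1.0 - deaths / at_risk               # product-limit update
        surv.append(s)
    return event_times, np.array(surv)

# Toy data: the subject observed until t=2 is censored (still alive)
times, surv = kaplan_meier([1, 2, 3, 4], [1, 0, 1, 1])
print(times, surv)
```

The censored subject counts in the risk set at t=1 but causes no drop of its own, so the curve steps only at t = 1, 3 and 4.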

## **Why Survival Analysis?**
Survival analysis is widely used in **healthcare, finance, and customer retention**:
- **Healthcare:** Predict patient survival rates.
- **Finance:** Credit risk and loan default probabilities.
- **Subscription Businesses:** Customer churn prediction (e.g., Netflix, Spotify).

## **Key Steps**
### **1️⃣ Reframe as Survival Analysis**
- Convert the dataset to survival format.
- Use Python’s `lifelines` and PyTorch-based `pycox`.
- Handle **censored data** (people still alive in 2024).
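Here is a minimal sketch of the conversion. It assumes the Age dataset's `Birth year` / `Death year` columns and treats anyone without a death year as right-censored at 2024. The helper `to_survival_format` and its `ref_year` parameter are illustrative names, not part of `lifelines` or `pycox`, and the durations are whole-year approximations.

```python
import numpy as np
import pandas as pd

def to_survival_format(df, birth_col="Birth year", death_col="Death year", ref_year=2024):
    """Return a frame with `duration` (years observed) and `event` (1 = death observed)."""
    died = df[death_col].notna()
    # Censored rows are followed up to the reference year
    end_year = df[death_col].where(died, ref_year)
    return pd.DataFrame({
        "duration": end_year - df[birth_col],
        "event": died.astype(int),
    }, index=df.index)

people = pd.DataFrame({"Birth year": [1900, 1980], "Death year": [1970, np.nan]})
print(to_survival_format(people))
```

The resulting `(duration, event)` pair is the shape that `lifelines` fitters consume through their `duration_col` / `event_col` arguments.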

### **2️⃣ Train Survival Models**
- **Traditional Cox Proportional Hazards Model (`lifelines`)**
- **DeepSurv (Neural Networks for Survival Analysis)**
- **Transformer-based Time-to-Event Models (Temporal Fusion Transformers, Hugging Face Transformers)**
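Both the Cox model and DeepSurv are fitted by minimising the negative log partial likelihood. DeepSurv replaces the Cox model's linear predictor with a neural-network output. The NumPy sketch below computes that loss under the simplifying assumption of no tied event times. Real implementations handle ties: `lifelines`, for example, uses Efron's method. The function name is hypothetical.

```python
import numpy as np

def cox_neg_log_partial_likelihood(risk, time, event):
    """-sum over observed deaths i of [ r_i - log sum_{j: t_j >= t_i} exp(r_j) ].

    risk  : predicted log-risk per subject (beta^T x for Cox, network output for DeepSurv)
    time  : observed duration; event: 1 if death observed, 0 if censored
    Assumes no tied event times.
    """
    risk = np.asarray(risk, dtype=float)
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    # Sort by descending time so the risk set of position k is positions 0..k
    order = np.argsort(-time, kind="stable")
    risk, event = risk[order], event[order]
    # Cumulative log-sum-exp gives log sum_{j in risk set} exp(r_j) for every position
    log_risk_set = np.logaddexp.accumulate(risk)
    # Censored subjects contribute only through the risk sets of others
    return -np.sum((risk - log_risk_set)[event])

# Toy check: predicting higher risk for the subject who dies first gives a lower loss
t, e = [1, 2], [1, 1]
loss_good = cox_neg_log_partial_likelihood([2.0, 0.0], t, e)
loss_bad = cox_neg_log_partial_likelihood([0.0, 2.0], t, e)
print(loss_good, loss_bad)
```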

### **3️⃣ Deploy as an API**
- Wrap the trained model in a **FastAPI** server.
- Package with **Docker**.
- Deploy using **Google Cloud Run / AWS Lambda**.

## **Technologies Used**
- **Libraries:** `lifelines`, `pycox`, `FastAPI`, `Hugging Face Transformers`
- **Model Training:** Traditional (Cox Model) & Deep Learning (DeepSurv, TFT)
- **Deployment:** FastAPI, Docker, Google Cloud Run/AWS Lambda

---

> 📌 **Next Steps:** Run the code cells below in order to install dependencies, preprocess the datasets, and train the baseline Cox Proportional Hazards Model.


In [1]:
!pip install -U sentence-transformers > /dev/null 2>&1
!pip install xgboost scikit-survival > /dev/null 2>&1

In [2]:
import os

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import kagglehub

from sklearn.metrics import (
    mean_squared_error, mean_absolute_error, r2_score,
    mean_absolute_percentage_error, mean_squared_log_error, explained_variance_score,
)
from sklearn.preprocessing import LabelEncoder, StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
from sklearn.feature_extraction.text import TfidfVectorizer
from scipy.stats import norm
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

import xgboost as xgb


In [3]:
# Download Life Expectancy dataset
life_exp_path = kagglehub.dataset_download("kumarajarshi/life-expectancy-who")
life_exp_file = os.path.join(life_exp_path, "Life Expectancy Data.csv")
life_exp_df = pd.read_csv(life_exp_file)

print(life_exp_df.head())

heart_path = kagglehub.dataset_download("fedesoriano/heart-failure-prediction")
heart_file = os.path.join(heart_path, "heart.csv")
heart_df = pd.read_csv(heart_file)

print(heart_df.head())

age_path = kagglehub.dataset_download("imoore/age-dataset")
age_file = os.path.join(age_path, "AgeDataset-V1.csv")
age_df = pd.read_csv(age_file)

print(age_df.head())

       Country  Year      Status  Life expectancy   Adult Mortality  \
0  Afghanistan  2015  Developing              65.0            263.0   
1  Afghanistan  2014  Developing              59.9            271.0   
2  Afghanistan  2013  Developing              59.9            268.0   
3  Afghanistan  2012  Developing              59.5            272.0   
4  Afghanistan  2011  Developing              59.2            275.0   

   infant deaths  Alcohol  percentage expenditure  Hepatitis B  Measles   ...  \
0             62     0.01               71.279624         65.0      1154  ...   
1             64     0.01               73.523582         62.0       492  ...   
2             66     0.01               73.219243         64.0       430  ...   
3             69     0.01               78.184215         67.0      2787  ...   
4             71     0.01                7.097109         68.0      3013  ...   

   Polio  Total expenditure  Diphtheria    HIV/AIDS         GDP  Population  \
0    6.

# Exploration and Data Cleaning

In [4]:
print("Life Expectancy Columns:", life_exp_df.columns.tolist())
print("Years:", life_exp_df['Year'].unique())
print("Missing Values:\n", life_exp_df.isnull().sum())

Life Expectancy Columns: ['Country', 'Year', 'Status', 'Life expectancy ', 'Adult Mortality', 'infant deaths', 'Alcohol', 'percentage expenditure', 'Hepatitis B', 'Measles ', ' BMI ', 'under-five deaths ', 'Polio', 'Total expenditure', 'Diphtheria ', ' HIV/AIDS', 'GDP', 'Population', ' thinness  1-19 years', ' thinness 5-9 years', 'Income composition of resources', 'Schooling']
Years: [2015 2014 2013 2012 2011 2010 2009 2008 2007 2006 2005 2004 2003 2002
 2001 2000]
Missing Values:
 Country                              0
Year                                 0
Status                               0
Life expectancy                     10
Adult Mortality                     10
infant deaths                        0
Alcohol                            194
percentage expenditure               0
Hepatitis B                        553
Measles                              0
 BMI                                34
under-five deaths                    0
Polio                               19
Total

In [5]:
print("Heart Failure Columns:", heart_df.columns.tolist())
print("Missing Values:\n", heart_df.isnull().sum())

Heart Failure Columns: ['Age', 'Sex', 'ChestPainType', 'RestingBP', 'Cholesterol', 'FastingBS', 'RestingECG', 'MaxHR', 'ExerciseAngina', 'Oldpeak', 'ST_Slope', 'HeartDisease']
Missing Values:
 Age               0
Sex               0
ChestPainType     0
RestingBP         0
Cholesterol       0
FastingBS         0
RestingECG        0
MaxHR             0
ExerciseAngina    0
Oldpeak           0
ST_Slope          0
HeartDisease      0
dtype: int64


In [6]:
print("Age Dataset Columns:", age_df.columns.tolist())
print("Missing Values:\n", age_df.isnull().sum())

Age Dataset Columns: ['Id', 'Name', 'Short description', 'Gender', 'Country', 'Occupation', 'Birth year', 'Death year', 'Manner of death', 'Age of death']
Missing Values:
 Id                         0
Name                       0
Short description      67900
Gender                133646
Country               335509
Occupation            206914
Birth year                 0
Death year                 1
Manner of death      1169406
Age of death               1
dtype: int64


In [7]:
# -------------------------- Life Expectancy Dataset --------------------------
# Drop rows with missing target (explicit copy so later column assignments modify an independent frame)
life_exp_df = life_exp_df.dropna(subset=['Life expectancy ']).copy()

# Fill Alcohol: country median → global median if still missing
# (grouping by Country *and* Year cannot help: each pair is a single row)
life_exp_df['Alcohol'] = life_exp_df['Alcohol'].fillna(
    life_exp_df.groupby('Country')['Alcohol'].transform('median')
)
life_exp_df['Alcohol'] = life_exp_df['Alcohol'].fillna(life_exp_df['Alcohol'].median())

# Fill GDP: country median → global median
life_exp_df['GDP'] = life_exp_df['GDP'].fillna(
    life_exp_df.groupby('Country')['GDP'].transform('median')
)
life_exp_df['GDP'] = life_exp_df['GDP'].fillna(life_exp_df['GDP'].median())

# Drop unnecessary columns
life_exp_df = life_exp_df.drop(columns=[
    'Hepatitis B', 'Population', 'Income composition of resources',
    ' thinness  1-19 years', ' thinness 5-9 years'
])

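# Final fill for any remaining nulls: forward/backward fill within each country,
# then the global median (avoids copying values across country boundaries and
# replaces the deprecated fillna(method=...))
num_cols = life_exp_df.select_dtypes('number').columns.tolist()
life_exp_df[num_cols] = life_exp_df.groupby('Country')[num_cols].transform(lambda g: g.ffill().bfill())
life_exp_df[num_cols] = life_exp_df[num_cols].fillna(life_exp_df[num_cols].median())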

# -------------------------- Heart Failure Dataset ----------------------------
# Convert categoricals
heart_df = pd.get_dummies(
    heart_df, 
    columns=['ChestPainType', 'RestingECG', 'ST_Slope'],
    drop_first=True
)
heart_df['ExerciseAngina'] = heart_df['ExerciseAngina'].map({'Y': 1, 'N': 0})
heart_df['Sex'] = heart_df['Sex'].map({'M': 1, 'F': 0})

# ---------------------------- Age Dataset ------------------------------------
# Drop rows missing the death year or the age at death
age_df = age_df.dropna(subset=['Death year', 'Age of death']).copy()

# Clean categorical columns
for col in ['Gender', 'Country', 'Occupation', 'Short description']:
    age_df[col] = age_df[col].fillna('Unknown')

# Keep only the first country when several are listed (separated by ';')
age_df['Country'] = age_df['Country'].str.split(';').str[0]

# Group rare occupations (threshold = 1000)
occupation_counts = age_df['Occupation'].value_counts()
age_df['Occupation'] = np.where(
    age_df['Occupation'].isin(occupation_counts[occupation_counts >= 1000].index),
    age_df['Occupation'],
    'Other'
)

# Encode gender (handle unknowns)
age_df['Gender'] = np.where(
    age_df['Gender'] == 'Male', 1,
    np.where(age_df['Gender'] == 'Female', 0, 0.5)
)

# Drop unnecessary column
age_df = age_df.drop(columns=['Manner of death'])

# ---------------------------- Validation -------------------------------------
print("\nFinal Missing Values:")
print("Life Expectancy:\n", life_exp_df.isnull().sum())
print("\nHeart Failure:\n", heart_df.isnull().sum())
print("\nAge Dataset:\n", age_df.isnull().sum())

print("\nSample Categories:")
print("Occupations:", age_df['Occupation'].unique()[:10])
print("Countries:", age_df['Country'].unique()[:10])



Final Missing Values:
Life Expectancy:
 Country                   0
Year                      0
Status                    0
Life expectancy           0
Adult Mortality           0
infant deaths             0
Alcohol                   0
percentage expenditure    0
Measles                   0
 BMI                      0
under-five deaths         0
Polio                     0
Total expenditure         0
Diphtheria                0
 HIV/AIDS                 0
GDP                       0
Schooling                 0
dtype: int64

Heart Failure:
 Age                  0
Sex                  0
RestingBP            0
Cholesterol          0
FastingBS            0
MaxHR                0
ExerciseAngina       0
Oldpeak              0
HeartDisease         0
ChestPainType_ATA    0
ChestPainType_NAP    0
ChestPainType_TA     0
RestingECG_Normal    0
RestingECG_ST        0
ST_Slope_Flat        0
ST_Slope_Up          0
dtype: int64

Age Dataset:
 Id                   0
Name                 0
Short descr
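The cleaning cell fills `Alcohol` and `GDP` from a per-country median rather than a per-(Country, Year) median. The WHO table has one row per country and year, so each (Country, Year) group holds a single value, and the median of a missing value is still missing. The toy frame below is hypothetical and only illustrates the pandas behaviour.

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame: one row per (Country, Year), like the WHO table
df = pd.DataFrame({
    "Country": ["A", "A", "A", "B", "B"],
    "Year":    [2013, 2014, 2015, 2014, 2015],
    "Alcohol": [1.0, np.nan, 3.0, np.nan, np.nan],
})

# Per-(Country, Year) median: every group is a single row, so NaN stays NaN
by_country_year = df.groupby(["Country", "Year"])["Alcohol"].transform("median")

# Per-country median, then a global median for countries with no data at all
filled = df["Alcohol"].fillna(df.groupby("Country")["Alcohol"].transform("median"))
filled = filled.fillna(df["Alcohol"].median())

print(by_country_year.isna().sum(), filled.tolist())
```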

In [8]:
print("Age Dataset Occupations:", age_df['Occupation'].unique()[:20])
print("Life Expectancy Countries:", life_exp_df['Country'].unique()[:20])

print("Life Expectancy Missing After Cleaning:", life_exp_df.isnull().sum())
print("Age Dataset Missing After Cleaning:", age_df.isnull().sum())

Age Dataset Occupations: ['Politician' 'Artist' 'Other' 'Astronomer' 'Athlete' 'Researcher'
 'Military personnel' 'Philosopher' 'Businessperson' 'Explorer'
 'Architect' 'Teacher' 'Aristocrat' 'Entrepreneur' 'Journalist' 'Engineer'
 'Author' 'Unknown' 'Religious figure' 'Official']
Life Expectancy Countries: ['Afghanistan' 'Albania' 'Algeria' 'Angola' 'Antigua and Barbuda'
 'Argentina' 'Armenia' 'Australia' 'Austria' 'Azerbaijan' 'Bahamas'
 'Bahrain' 'Bangladesh' 'Barbados' 'Belarus' 'Belgium' 'Belize' 'Benin'
 'Bhutan' 'Bolivia (Plurinational State of)']
Life Expectancy Missing After Cleaning: Country                   0
Year                      0
Status                    0
Life expectancy           0
Adult Mortality           0
infant deaths             0
Alcohol                   0
percentage expenditure    0
Measles                   0
 BMI                      0
under-five deaths         0
Polio                     0
Total expenditure         0
Diphtheria                0
 HIV/AI

In [9]:
# ===================== SYNTHETIC (HEURISTIC) FEATURES =====================
# --------------------- Age Dataset: Clinical Proxies ---------------------
# 1. Stress score: hand-assigned per occupation (heuristic, not measured data)
stress_map = {
    'Politician': 9, 'Military personnel': 8, 'Journalist': 7,
    'Businessperson': 6, 'Artist': 5, 'Teacher': 4,
    'Researcher': 3, 'Other': 5, 'Unknown': 5
}
# Occupations not in the map default to the neutral score of 5
age_df['stress_score'] = age_df['Occupation'].map(stress_map).fillna(5).astype('int8')

# 2. Average BMI of the person's country, taken from the most recent year available
country_bmi = life_exp_df.sort_values('Year').groupby('Country')[' BMI '].last().to_dict()
age_df['avg_bmi'] = age_df['Country'].map(country_bmi).fillna(25).astype('float32')

# 3. Heart disease risk: fixed heuristic by gender (0.65 if male, else 0.35)
age_df['heart_disease_risk'] = np.where(age_df['Gender'] == 1, 0.65, 0.35).astype('float32')

# 4. Smoking prevalence: heuristic linear trend by birth year, clipped to [0.1, 0.6]
birth_years = age_df['Birth year'].to_numpy()
age_df['smoking_prev'] = np.clip(0.5 - 0.0035*(birth_years - 1950), 0.1, 0.6).astype('float32')


# --------------------- Country Features (Optimized Merge) -----------------
# Keep each country's most recent year (sort by Year within Country, then take the last row)
life_exp_filtered = life_exp_df[['Country', 'Year', 'Alcohol', 'GDP', 'Schooling']]\
    .sort_values(['Country', 'Year'])\
    .groupby('Country').last()\
    .drop(columns='Year')\
    .add_prefix('country_')

# Store Country as a categorical (saves memory), then left-join the country features
age_df['Country'] = age_df['Country'].astype('category')
age_df = age_df.join(life_exp_filtered, on='Country', how='left')

# Fill missing values in-place
age_df['country_Alcohol'] = age_df['country_Alcohol'].fillna(age_df['country_Alcohol'].mean())
age_df['country_GDP'] = age_df['country_GDP'].fillna(age_df['country_GDP'].median())
age_df['country_Schooling'] = age_df['country_Schooling'].fillna(age_df['country_Schooling'].median())

# --------------------- Text Features (Lightweight Alternative) ------------
# Instead of BERT sentence embeddings, use TF-IDF on occupation + description
# (TfidfVectorizer is already imported in the setup cell)
text_data = age_df['Occupation'] + " " + age_df['Short description'].fillna('')
# 100 dims vs 384 from BERT; float32 halves memory once the matrix is densified below
tfidf = TfidfVectorizer(max_features=100, dtype=np.float32)
text_features = tfidf.fit_transform(text_data)

# Convert to DataFrame and merge
text_df = pd.DataFrame(text_features.toarray(), 
                      columns=[f"tfidf_{i}" for i in range(text_features.shape[1])],
                      index=age_df.index)
age_df = pd.concat([age_df, text_df], axis=1)
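The smoking-prevalence proxy is a hand-set linear rule, so its effective range is worth checking. The sketch below re-evaluates the same expression for a few birth years. `smoking_prevalence` is a hypothetical wrapper around the cell's one-liner, not an existing function.

```python
import numpy as np

def smoking_prevalence(birth_year):
    # Same rule as the feature cell: 0.5 for a 1950 birth year, minus 0.0035
    # (0.35 percentage points) per later birth year, clipped to [0.1, 0.6]
    birth_year = np.asarray(birth_year, dtype=float)
    return np.clip(0.5 - 0.0035 * (birth_year - 1950), 0.1, 0.6)

for year in [1800, 1900, 1921, 1950, 2000]:
    print(year, round(float(smoking_prevalence(year)), 3))
```

Because of the clip, everyone born in 1921 or earlier receives the same 0.6 ceiling. The 0.1 floor applies only to birth years after about 2064. For the historical figures in this dataset, the feature therefore carries no variation.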

# XGBoost Survival Model (Accelerated Failure Time)
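XGBoost's `survival:aft` objective models log survival time. According to the XGBoost AFT documentation, it takes each label as an interval `[label_lower_bound, label_upper_bound]` set with `DMatrix.set_float_info`. An observed death has equal bounds, while a right-censored subject (still alive) has an upper bound of `+inf`. The helper below is a hypothetical sketch that builds those arrays from `(duration, event)` pairs. With it, censored rows could be added later without changing the training code.

```python
import numpy as np

def aft_label_bounds(duration, event):
    """Interval labels for XGBoost's survival:aft objective (hypothetical helper).

    duration : observed time (age at death, or age when last seen alive)
    event    : 1 if death observed, 0 if right-censored
    """
    duration = np.asarray(duration, dtype=float)
    event = np.asarray(event, dtype=bool)
    lower = duration.copy()                        # death happened no earlier than this
    upper = np.where(event, duration, np.inf)      # censored: no known upper limit
    return lower, upper

lower, upper = aft_label_bounds([70, 44], [1, 0])
print(lower, upper)
# With xgboost installed, these would be attached as:
#   dtrain.set_float_info("label_lower_bound", lower)
#   dtrain.set_float_info("label_upper_bound", upper)
```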

In [10]:
# --------------------- Data Preparation ---------------------
# Features: the synthetic features above plus the TF-IDF columns
features = ['stress_score', 'avg_bmi', 'heart_disease_risk',
            'smoking_prev', 'country_Alcohol', 'country_GDP'] + \
           [c for c in age_df if c.startswith('tfidf_')]

# AFT models log(time), so keep strictly positive ages at death
model_df = age_df[age_df['Age of death'] > 0]
X = model_df[features]
y = model_df['Age of death']

# Quantile-based age bins so each stratum holds roughly len(y) / n_bins rows
n_bins = 20  # Increased from 10 for finer stratification
age_bins = np.quantile(y, np.linspace(0, 1, n_bins + 1))
# Digitize on the interior edges only: with all edges, values equal to the
# maximum age form an extra class that can hold a single row, and the
# stratified split below then fails
stratify_col = np.digitize(y, age_bins[1:-1])

# Draw a stratified 100k-row sample
X_sample, _, y_sample, _ = train_test_split(
    X, y,
    train_size=100000,
    stratify=stratify_col,
    random_state=42
)

# Split train/val
X_train, X_val, y_train, y_val = train_test_split(
    X_sample, y_sample, test_size=0.2, random_state=42
)

# --------------------- Train XGBoost (Accelerated Failure Time) ---------------------
params = {
    'objective': 'survival:aft',
    'eval_metric': 'aft-nloglik',
    'tree_method': 'hist',  # Histogram-based split finding, faster on large data
    'learning_rate': 0.1,
    'max_depth': 6,
    'verbosity': 1
}

# survival:aft reads interval labels instead of `label`. Every remaining row has a
# recorded death, so the lower and upper bounds are both the age at death.
dtrain = xgb.DMatrix(X_train)
dtrain.set_float_info('label_lower_bound', y_train.to_numpy())
dtrain.set_float_info('label_upper_bound', y_train.to_numpy())
dval = xgb.DMatrix(X_val)
dval.set_float_info('label_lower_bound', y_val.to_numpy())
dval.set_float_info('label_upper_bound', y_val.to_numpy())

model = xgb.train(
    params,
    dtrain,
    num_boost_round=1000,
    evals=[(dtrain, 'train'), (dval, 'val')],
    early_stopping_rounds=50,
    verbose_eval=20
)

# --------------------- Evaluation ---------------------
# Predict with the trees up to the best early-stopping iteration (output is in years)
preds = model.predict(dval, iteration_range=(0, model.best_iteration + 1))
print(f"MAE: {mean_absolute_error(y_val, preds):.2f} years")
print(f"RMSE: {np.sqrt(mean_squared_error(y_val, preds)):.2f} years")

# Feature importance
xgb.plot_importance(model, max_num_features=20)
plt.show()
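Stratifying on `np.digitize(y, age_bins)` with the full set of quantile edges can make `train_test_split` fail with "The least populated class in y has only 1 member". The cause is that values equal to the last edge (the maximum age) land in an extra class of their own. The toy example below uses the hypothetical ages 1 to 100. It shows the effect, and why passing only the interior edges yields exactly `n_bins` balanced classes.

```python
import numpy as np

y = np.arange(1, 101)            # 100 toy "ages", one per value
n_bins = 4
edges = np.quantile(y, np.linspace(0, 1, n_bins + 1))

# All edges: the maximum equals the last edge and gets a class to itself
full = np.digitize(y, edges)
# Interior edges only: exactly n_bins classes, labelled 0..n_bins-1
inner = np.digitize(y, edges[1:-1])

print(np.bincount(full))    # class counts: 0, 25, 25, 25, 24, 1
print(np.bincount(inner))   # class counts: 25, 25, 25, 25
```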