<header style="padding:1px;background:#f9f9f9;border-top:3px solid #00b2b1"><img id="Teradata-logo" src="https://www.teradata.com/Teradata/Images/Rebrand/Teradata_logo-two_color.png" alt="Teradata" width="220" align="right" />

<b style = 'font-size:28px;font-family:Arial;color:#E37C4D'>Survival Analysis using teradataml</b>
</header>

<p style = 'font-size:18px;font-family:Arial;color:#E37C4D'><b>Introduction:</b></p>

<p style = 'font-size:16px;font-family:Arial'>Machine learning can be useful in heart failure prediction because it can analyze large amounts of data from multiple sources and identify complex patterns that may be difficult for humans to recognize. This can potentially improve the accuracy of prediction models and help healthcare professionals identify patients who are at high risk for heart failure, allowing for earlier intervention and better outcomes.</p>

<p style = 'font-size:18px;font-family:Arial;color:#E37C4D'><b>Data:</b></p>

<p style = 'font-size:16px;font-family:Arial'>This is a simulated dataset based on real hospital administrative data, which contains a random sample of emergency (unplanned) admissions for heart failure. The original records were linked to the national death registry in order to capture deaths that occur after discharge.</p>

<hr>
<p style = 'font-size:18px;font-family:Arial;color:#E37C4D'><b>Downloading and installing additional software needed</b></p>

In [None]:
%%capture
!pip install --upgrade teradataml

<p style = 'font-size:16px;font-family:Arial'>
    <i><b>*BEFORE proceeding, please RESTART the kernel to bring new software into Jupyter.</b></i>
</p>
<p style = 'font-size:16px;font-family:Arial'>Here, we import the required libraries, set environment variables and environment paths (if required).</p>

In [None]:
# system packages
import os
import sys
import warnings
warnings.filterwarnings("ignore")

from teradataml import *
from teradataml import valib

# Dataset packages 
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score

# plotting packages
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="whitegrid")

%matplotlib inline
configure.val_install_location = "val"

<hr>
<b style = 'font-size:28px;font-family:Arial;color:#E37C4D'>1. Initiate a connection to Vantage</b>
<p style = 'font-size:18px;font-family:Arial;color:#E37C4D'> <b>Let's start by connecting to the Teradata system </b></p>
<p style = 'font-size:16px;font-family:Arial'>You will be prompted to provide the password. Enter your password, press the Enter key, and then use the down arrow to go to the next cell.</p>

In [None]:
%run -i ../startup.ipynb
eng = create_context(host = 'host.docker.internal', username='demo_user', password = password)
print(eng)
eng.execute('''SET query_band='DEMO=SurvivalAnalysis_Python.ipynb;' UPDATE FOR SESSION;''')

<p style = 'font-size:16px;font-family:Arial'>Begin running steps with Shift + Enter keys. </p>

<p style = 'font-size:20px;font-family:Arial;color:#E37C4D'><b>Getting Data for This Demo</b></p>
<p style = 'font-size:16px;font-family:Arial'>We have provided data for this demo on cloud storage. You can either run the demo using foreign tables, which access the data without using any storage in your environment, or download the data to local storage, which may yield faster execution but requires available storage space. The following cell contains two statements, one of which is commented out. To switch modes, change which statement is commented out.</p>

In [None]:
%run -i ../run_procedure.py "call get_data('DEMO_SurvivalAnalysis_cloud');"        # Takes 10 seconds
# %run -i ../run_procedure.py "call get_data('DEMO_SurvivalAnalysis_local');"        # Takes 20 seconds

<p style = 'font-size:16px;font-family:Arial'>Next is an optional step – if you want to see the status of databases/tables created and space used.</p>

In [None]:
%run -i ../run_procedure.py "call space_report();"        # Takes 5 seconds

<hr>
<b style = 'font-size:28px;font-family:Arial;color:#E37C4D'>2. Read the data from Vantage as a teradataml DataFrame</b>

In [None]:
heart_failure = DataFrame(in_schema('DEMO_SurvivalAnalysis', 'heart_failure'))

In [None]:
print(heart_failure.shape)
heart_failure.head(5)

<p style = 'font-size:16px;font-family:Arial'>The dataset above has 31 columns in total. The 'death' column is the target column, where 1 means the patient died and 0 means the patient did not.
<br>
Let's check whether the mortality rate is balanced with respect to gender.
</p>

In [None]:
grp_gen = heart_failure.select(['gender','death']).groupby(['gender']).agg(['mean', 'count']).to_pandas()
sns.barplot(x='gender', y='mean_death', data=grp_gen)
plt.title('Mortality rate by gender')
plt.show()

<p style = 'font-size:16px;font-family:Arial'>The graph above shows that the mortality rate is similar across genders.</p>
<hr>
<b style = 'font-size:28px;font-family:Arial;color:#E37C4D'>2. Data Preparation</b>
<p style = 'font-size:16px;font-family:Arial'>Feature scaling is performed during data pre-processing to handle features whose magnitudes, ranges, or units differ widely. Without it, many machine learning algorithms give more weight to features with larger numeric values and less weight to features with smaller ones, regardless of the units involved.
<br>
<br>
Scaling puts all features on a similar scale, so that algorithms relying on distances between observations treat every feature comparably. This can help the machine learning algorithm converge faster and produce more accurate predictions.
</p>

In [None]:
from teradataml import ScaleFit, ScaleTransform

sf_fit = ScaleFit(data = heart_failure, scale_method = 'STD', target_columns = ['2:30'])
sf_trns = ScaleTransform(data = heart_failure, object = sf_fit.output, accumulate = ['"id"','death'])
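
<p style = 'font-size:16px;font-family:Arial'>As a rough local illustration of the standardization requested with <b>scale_method = 'STD'</b>, the sketch below rescales two toy columns to zero mean and unit standard deviation in plain Python. It assumes the population standard deviation is used, which may differ from what <b>ScaleFit</b> computes in Vantage. The helper <b>standardize</b> and the sample values are hypothetical.</p>

```python
import statistics


def standardize(values):
    """Return z-scores: (value - mean) / population standard deviation."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)
    if std == 0:
        # A constant column has no spread; map it to zeros instead of dividing by 0.
        return [0.0 for _ in values]
    return [(v - mean) / std for v in values]


# Hypothetical toy columns measured in very different units.
age = [52, 67, 74, 81, 90]   # years
los = [2, 5, 9, 14, 40]      # days

age_scaled = standardize(age)
los_scaled = standardize(los)

# Both columns are now centred on 0 with unit spread,
# so neither dominates purely because of its units.
print([round(v, 2) for v in age_scaled])
print([round(v, 2) for v in los_scaled])
```

<p style = 'font-size:16px;font-family:Arial'>In the notebook itself the same idea runs in-database: the fit step computes the statistics once, and the transform step applies them to the selected columns.</p>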

<p style = 'font-size:16px;font-family:Arial'>Split the data into training and testing datasets in a 70:30 ratio.</p>

In [None]:
tdf_samples = sf_trns.result.sample(frac = [0.3, 0.7], randomize = True)
copy_to_sql(tdf_samples[tdf_samples['sampleid'] == 2], table_name = 'heart_failure_train', schema_name = 'demo_user', if_exists = 'replace')
copy_to_sql(tdf_samples[tdf_samples['sampleid'] == 1], table_name = 'heart_failure_test', schema_name = 'demo_user', if_exists = 'replace')

In [None]:
heart_failure_train = DataFrame('heart_failure_train')
heart_failure_test = DataFrame('heart_failure_test')
print("Training Set = "+str(heart_failure_train.shape[0])+". Testing Set = "+str(heart_failure_test.shape[0]))
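
<p style = 'font-size:16px;font-family:Arial'>The split above tags each sampled row with a sample id and then filters on that id. A minimal local sketch of the same 70:30 idea, using only the Python standard library, is shown below. The function <b>split_ids</b>, the seed, and the patient ids are hypothetical and are not part of teradataml.</p>

```python
import random


def split_ids(ids, test_frac=0.3, seed=42):
    """Shuffle ids and cut them into disjoint (train, test) lists."""
    rng = random.Random(seed)  # fixed seed so the split is repeatable
    shuffled = list(ids)
    rng.shuffle(shuffled)
    n_test = round(len(shuffled) * test_frac)
    return shuffled[n_test:], shuffled[:n_test]


patient_ids = range(1, 1001)  # hypothetical patient ids
train_ids, test_ids = split_ids(patient_ids)

print(len(train_ids), len(test_ids))
# No patient should appear in both sets; overlap would make the test score optimistic.
print(set(train_ids).isdisjoint(test_ids))
```

<p style = 'font-size:16px;font-family:Arial'>Fixing the seed is a design choice for reproducibility. The notebook's in-database sample is randomized, so its row counts and membership can vary between runs.</p>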

<hr>
<b style = 'font-size:28px;font-family:Arial;color:#E37C4D'>3. Model Training</b>

<p style = 'font-size:16px;font-family:Arial'>The <b>DecisionForest</b> function is an ensemble algorithm used for classification and regression predictive modeling problems. It extends bootstrap aggregation (bagging) of decision trees. Typically, constructing a decision tree involves evaluating every input feature in the data to select a split point.
<br>
<br>
The function instead restricts each split point to a random subset of the features. This forces the decision trees in the forest to differ from one another, which can improve prediction accuracy.</p>

In [None]:
from teradataml import DecisionForest

DecisionForest_out = DecisionForest(data = heart_failure_train,
                            input_columns = ['2:30'],
                            response_column = 'death',
                            max_depth = 5,
                            num_trees = 20,
                            min_node_size = 2,
                            seed = 2,
                            tree_type = 'CLASSIFICATION')
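
<p style = 'font-size:16px;font-family:Arial'>The two sources of randomness described above, bootstrap sampling of rows and a random feature subset at each split, can be sketched in plain Python as follows. This is only an illustration and not how <b>DecisionForest</b> is implemented. The helper names are hypothetical, and the square-root subset size is a common rule of thumb assumed here.</p>

```python
import math
import random


def bootstrap_rows(n_rows, rng):
    """Bagging: draw n_rows row indices with replacement."""
    return [rng.randrange(n_rows) for _ in range(n_rows)]


def candidate_features(features, rng, k=None):
    """Pick the random subset of features considered at one split point."""
    if k is None:
        k = max(1, int(math.sqrt(len(features))))  # assumed rule of thumb
    return rng.sample(features, k)


rng = random.Random(2)
features = ["los", "age", "gender", "cancer", "diabetes",
            "hypertension", "copd", "pneumonia", "dementia"]

for tree in range(3):
    rows = bootstrap_rows(10, rng)               # each tree sees its own resample
    subset = candidate_features(features, rng)   # candidates for that tree's first split
    print(f"tree {tree}: rows={rows} first-split candidates={subset}")
```

<p style = 'font-size:16px;font-family:Arial'>Because every tree trains on different rows and weighs different features at each split, the trees make partly independent errors, which is what the ensemble averages out.</p>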

<hr>
<b style = 'font-size:28px;font-family:Arial;color:#E37C4D'>4. Model Testing</b>
<p style = 'font-size:16px;font-family:Arial'>DecisionForestPredict outputs the probability that each observation is in the predicted class.</p>

In [None]:
from teradataml import DecisionForestPredict

decision_forest_predict_out = DecisionForestPredict(
                                                        object = DecisionForest_out,
                                                        newdata = heart_failure_test,
                                                        id_column = "id",
                                                        detailed = False,
                                                        output_response_probdist = True,
                                                        output_prob = True,
                                                        output_responses =  ['0', '1'],
                                                        terms = 'death'
                                                    )

In [None]:
rf_pred=decision_forest_predict_out.result.to_pandas()
rf_pred['prediction'] = rf_pred['prediction'].astype('int64')

In [None]:
rf_pred.head()

<p style = 'font-size:16px;font-family:Arial'>In the result above, the column <b>death</b> is the ground truth, <b>prediction</b> is the predicted class, and <b>(prob_0, prob_1)</b> are the predicted probabilities of classes 0 and 1.
<br>
<br>
The accuracy of the model is as follows:
</p>

In [None]:
print("Accuracy:", round((rf_pred['prediction'] == rf_pred['death']).mean() * 100, 2))
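
<p style = 'font-size:16px;font-family:Arial'>Accuracy alone does not show which kind of error the model makes. The sketch below breaks accuracy into confusion-matrix counts in plain Python. The labels are hypothetical toy values and not results from this notebook, and <b>confusion_counts</b> is an illustrative helper.</p>

```python
def confusion_counts(y_true, y_pred):
    """Count true/false positives and negatives for 0/1 labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    return tp, tn, fp, fn


# Hypothetical ground truth and predictions (1 = died, 0 = survived).
death = [1, 0, 0, 1, 0, 1, 0, 0]
prediction = [1, 0, 1, 0, 0, 1, 0, 0]

tp, tn, fp, fn = confusion_counts(death, prediction)
accuracy = (tp + tn) / len(death)
sensitivity = tp / (tp + fn)  # share of actual deaths the model caught

print("TP, TN, FP, FN:", tp, tn, fp, fn)
print("Accuracy:", round(accuracy * 100, 2))
print("Sensitivity:", round(sensitivity * 100, 2))
```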

<p style = 'font-size:16px;font-family:Arial'>The ROC curve plots the true positive rate (TPR) against the false positive rate (FPR) as the classification threshold varies. The area under the ROC curve (AUC) measures how well the model distinguishes between the positive and negative classes: the higher the AUC, the better the separation. As a rough rule of thumb, an AUC above 0.75 is generally considered decent.</p>

In [None]:
false_pos_rate, true_pos_rate, _ = roc_curve(rf_pred.death, rf_pred['prob_1'].astype('float'))
auc = roc_auc_score(rf_pred.death, rf_pred['prob_1'].astype('float'))
plt.plot(false_pos_rate, true_pos_rate, label = 'DecisionForest'+": auc="+str(round(auc,3)))
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc='lower right')
plt.show()

<p style = 'font-size:16px;font-family:Arial'>Looking at the ROC curve above, we can say that the model has performed reasonably well on the test data.</p>
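
<p style = 'font-size:16px;font-family:Arial'>To make the AUC figure more concrete, the sketch below computes one ROC point and the AUC by hand on four toy observations. It uses the pairwise-ranking interpretation of AUC: the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative case. The scores are hypothetical, and the helpers <b>tpr_fpr</b> and <b>roc_auc</b> are illustrative, not part of teradataml or scikit-learn.</p>

```python
def tpr_fpr(y_true, scores, threshold):
    """One ROC point: rates when scores >= threshold are called positive."""
    tp = sum(1 for y, s in zip(y_true, scores) if y == 1 and s >= threshold)
    fp = sum(1 for y, s in zip(y_true, scores) if y == 0 and s >= threshold)
    n_pos = sum(1 for y in y_true if y == 1)
    n_neg = len(y_true) - n_pos
    return tp / n_pos, fp / n_neg


def roc_auc(y_true, scores):
    """AUC as the share of (positive, negative) pairs ranked correctly."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative case")
    # A tie between a positive and a negative score counts as half a win.
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Hypothetical labels and predicted probabilities of class 1.
death = [0, 0, 1, 1]
prob_1 = [0.1, 0.4, 0.35, 0.8]

print("TPR, FPR at 0.5:", tpr_fpr(death, prob_1, 0.5))
print("AUC:", roc_auc(death, prob_1))
```

<p style = 'font-size:16px;font-family:Arial'>The pairwise form is quadratic in the number of rows, so it suits small illustrations only. For real data, a library routine such as the <b>roc_auc_score</b> call used above is the practical choice.</p>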

<hr>
<b style = 'font-size:28px;font-family:Arial;color:#E37C4D'>9. Cleanup</b>
<p style = 'font-size:18px;font-family:Arial;color:#E37C4D'><b>Work Tables</b></p>
<p style = 'font-size:16px;font-family:Arial'>Clean up work tables to prevent errors next time.</p>

In [None]:
eng.execute('DROP TABLE heart_failure_train;')

In [None]:
eng.execute('DROP TABLE heart_failure_test;')

<p style = 'font-size:18px;font-family:Arial;color:#E37C4D'> <b>Databases and Tables </b></p>
<p style = 'font-size:16px;font-family:Arial'>The following code will clean up tables and databases created above.</p>

In [None]:
%run -i ../run_procedure.py "call remove_data('DEMO_SurvivalAnalysis');"        # Takes 5 seconds

In [None]:
remove_context()

<b style = 'font-size:28px;font-family:Arial;color:#E37C4D'>Dataset:</b>

- `id`: patient id
- `death`: If the patient is deceased (boolean)
- `los`: length of stay (in days)
- `age`: age of the patient (in years)
- `gender`: gender of the patient (1-female, 2-male)
- `cancer`: If the patient has cancer (boolean)
- `cabg`: If the patient has gone through Coronary Artery Bypass Graft procedure (boolean)
- `crt`: If the patient has gone through Cardiac Resynchronization Therapy (boolean)
- `defib`: If the patient has a defibrillator (boolean)
- `dementia`: If the patient has dementia (boolean)
- `diabetes`: If the patient has diabetes (boolean)
- `hypertension`: If the patient has hypertension (boolean)
- `ihd`: If the patient has Ischemic Heart Disease (boolean)
- `mental_health`: If the patient has been diagnosed with mental health issues (boolean)
- `arrhythmias`: If the patient has arrhythmia (boolean)
- `copd`: If the patient has Chronic Obstructive Pulmonary Disease (boolean)
- `obesity`: If the patient has obesity (boolean)
- `pvd`: If the patient has Peripheral Vascular Disease (boolean)
- `renal_disease`: If the patient has Renal Disease (boolean)
- `valvular_disease`: If the patient has Valvular Disease (boolean)
- `metastatic_cancer`: If the patient has Metastatic Cancer (boolean)
- `pacemaker`: If the patient has a pacemaker (boolean)
- `pneumonia`: If the patient has pneumonia (boolean)
- `prior_appts_attended`: Number of prior appointments attended by the patient
- `prior_dnas`: Number of prior appointments the patient did not attend (DNA)
- `pci`: If the patient has gone through Percutaneous Coronary Intervention procedure (boolean)
- `senile`: If the patient has Senile amyloidosis (SSA) (boolean)
- `fu_time`: Time since last follow-up (in days)

<footer style="padding:10px;background:#f9f9f9;border-bottom:3px solid #394851">Copyright © Teradata Corporation - 2023. All Rights Reserved.</footer>