In [None]:
# CD4 Improvement Prediction Project

# Step 1: Import libraries
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import shap
import matplotlib.pyplot as plt
import seaborn as sns

# Step 2: Load the dataset (from prepared file)
df = pd.read_csv("HealthGymV2_CbdrhDatathon_ART4HIV.csv")

# Step 3: Prepare features for classification
# Take each patient's row at Timestep 0 (baseline) and at Timestep 6. The "6mo"
# naming assumes Timestep counts months -- NOTE(review): confirm against the
# dataset documentation.
baseline = df[df['Timestep'] == 0][['PatientID', 'CD4', 'VL', 'Gender', 'Ethnic', 'Base Drug Combo']]
month6 = df[df['Timestep'] == 6][['PatientID', 'CD4']]
baseline = baseline.rename(columns={'CD4': 'CD4_baseline', 'VL': 'VL_baseline'})
month6 = month6.rename(columns={'CD4': 'CD4_6mo'})
# Inner join: patients missing either timestep are dropped. validate= raises if
# a PatientID occurs more than once at a timestep, instead of silently
# multiplying rows (which could also leak one patient into both train and test).
merged_df = pd.merge(baseline, month6, on='PatientID', validate='one_to_one')
# A missing CD4 value compares as False and would be mislabelled "not improved",
# so only rows with both measurements are kept.
# Label: 1 if CD4 strictly increased from baseline to month 6, else 0 (ties -> 0).
merged_df = merged_df.dropna(subset=['CD4_baseline', 'CD4_6mo']).assign(
    CD4_improved=lambda d: (d['CD4_6mo'] > d['CD4_baseline']).astype(int)
)

# Step 4: Feature selection
# NOTE(review): Gender, Ethnic and Base Drug Combo are passed to the forest as-is,
# so they must already be numeric codes; if they are strings (or you want them
# treated as nominal rather than ordinal), one-hot encode them first.
features = ['CD4_baseline', 'VL_baseline', 'Gender', 'Ethnic', 'Base Drug Combo']
X = merged_df[features]
y = merged_df['CD4_improved']

# Step 5: Split the data
# stratify=y keeps the improved / not-improved ratio identical in train and test,
# so the held-out metrics are not skewed by an unlucky split.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Step 6: Train the model
clf = RandomForestClassifier(random_state=42)
clf.fit(X_train, y_train)

# Step 7: Evaluate the model
y_pred = clf.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
print("Classification Report:\n", classification_report(y_test, y_pred))

# Step 8: Interpret the model using SHAP
explainer = shap.TreeExplainer(clf)
shap_values = explainer.shap_values(X_test)
# Older shap releases return a list holding one (n_samples, n_features) array per
# class. Newer releases return a single (n_samples, n_features, n_classes) array,
# where shap_values[1] would pick the second *sample*, not the positive class.
# Look up the column of the "improved" class (label 1) instead of assuming it.
pos_idx = list(clf.classes_).index(1)
if isinstance(shap_values, list):
    shap_values_improved = shap_values[pos_idx]
elif shap_values.ndim == 3:
    shap_values_improved = shap_values[:, :, pos_idx]
else:
    shap_values_improved = shap_values
shap.summary_plot(shap_values_improved, X_test)

# Step 9: Show correct and incorrect predictions
# Up to three examples of each, alongside their input features.
results_df = X_test.copy()
results_df['Actual'] = y_test.values
results_df['Predicted'] = y_pred
correct_preds = results_df[results_df['Actual'] == results_df['Predicted']].head(3)
incorrect_preds = results_df[results_df['Actual'] != results_df['Predicted']].head(3)
print("\nCorrect Predictions:\n", correct_preds)
print("\nIncorrect Predictions:\n", incorrect_preds)