In [None]:
import polars as pl
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np
import shap

# Input CSV: positional feature columns named "pos_*" plus a "phenotype" label
# column (both are read below). The path is machine-specific; the commented-out
# line is the same file on the author's personal computer.
DATA_PATH = r'C:\Users\Noah Legall\LegallLab\SHAP-mTB-AMR\scripts\rpoBC_fullseq.csv'
# DATA_PATH = r'C:\Users\noah_\SHAP-mTB-AMR\scripts\rpoBC_fullseq.csv'
data = pl.read_csv(DATA_PATH)

# Number of most-important features kept for the reduced model.
N_TOP_FEATURES = 100

# Identify all positional columns (those starting with "pos_")
pos_cols = [col for col in data.columns if col.startswith("pos_")]

# Keep only columns where more than one distinct value appears (constant columns
# carry no signal for the classifier). All counts are computed in a single query
# instead of one query per column. Note: polars' n_unique counts null as a value.
n_unique = data.select(pl.col(pos_cols).n_unique()).row(0) if pos_cols else ()
variable_pos = [col for col, n in zip(pos_cols, n_unique) if n > 1]
if not variable_pos:
    raise ValueError('no variable "pos_" columns found in the input data')

var_columns = list(variable_pos)  # feature names, aligned with the columns of X
var_columnns = var_columns        # original (misspelled) name, kept for later cells
X = data.select(variable_pos).to_numpy()
y = data["phenotype"].to_numpy()

# Split once by row index so both models below use exactly the same train/test
# rows. train_test_split's partition depends only on the row count, test_size and
# random_state, so this reproduces the original per-model splits.
train_idx, test_idx = train_test_split(
    np.arange(X.shape[0]), test_size=0.2, random_state=42
)
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

# Baseline Random Forest on all variable positions
rf = RandomForestClassifier(n_estimators=50, random_state=42)
rf.fit(X_train, y_train)
rf_acc = accuracy_score(y_test, rf.predict(X_test))
print("🌲 Random Forest accuracy:", rf_acc)

# Rank features by the baseline model's impurity importances. That model was fit
# on the training rows only, so held-out rows do not influence the selection.
importance = rf.feature_importances_
# argsort is ascending, so the last N indices are the most important ones; they
# are re-sorted so the kept columns stay in their original column order.
top_idx = sorted(np.argsort(importance)[-N_TOP_FEATURES:])
X_top = X[:, top_idx]

# Names of the selected columns, in the same order as X_top's columns
top_columns = [var_columns[i] for i in top_idx]
print(top_columns)

# Reduced Random Forest on the top features, evaluated on the same held-out rows
X_train, X_test = X_top[train_idx], X_top[test_idx]
rf_small = RandomForestClassifier(n_estimators=50, random_state=42)
rf_small.fit(X_train, y_train)
rf_acc = accuracy_score(y_test, rf_small.predict(X_test))
print("🌲 Random Forest accuracy:", rf_acc)

# Explain the reduced model on all rows (train + test).
# NOTE(review): no background `data` is passed to TreeExplainer. Without it shap
# cannot compute truly "interventional" values; depending on the installed shap
# version it falls back to tree_path_dependent or rejects the combination. If
# interventional SHAP is intended, pass a background sample (e.g. data=X_train)
# — confirm against the shap version in use.
explainer = shap.TreeExplainer(
    rf_small,
    feature_perturbation="interventional",
    feature_names=top_columns,
    approximate=True,
)
# Additivity check disabled, as in the original analysis.
shap_values = explainer(X_top, check_additivity=False)