# In [None]:
# Importing necessary libraries
import pandas as pd
import joblib
import shap
import lime
import lime.lime_tabular
import numpy as np
import matplotlib.pyplot as plt

# Load the saved model plus the predictions and metrics written next to it.
# All three inputs live in one directory; change ``data_dir`` to run elsewhere.
data_dir = "C:/Users/bam/Desktop/Week-8/notebooks"

model_name = "Random Forest"  # Change as needed
model_path = f"{data_dir}/trained_model_{model_name}.joblib"
# NOTE: joblib.load unpickles the file -- only load model files you trust.
model = joblib.load(model_path)

# Load the predictions
predictions_path = f"{data_dir}/predictions.csv"
predictions = pd.read_csv(predictions_path)

# Load the model metrics
metrics_path = f"{data_dir}/model_metrics.csv"
metrics = pd.read_csv(metrics_path)

# Echo the inputs for a quick sanity check: the stored evaluation metrics,
# the first rows of the predictions, and the full list of prediction columns.
for heading, table in (
    ("Model Metrics:", metrics),
    ("Predictions DataFrame:", predictions.head()),
):
    print(heading)
    print(table)

prediction_columns = predictions.columns.tolist()
print("Columns in Predictions DataFrame:", prediction_columns)

# Build the model's feature matrix from the predictions file:
#   1. drop the prediction column (if present),
#   2. one-hot encode categorical columns,
#   3. align to the model's training columns when it records them,
#   4. cast everything to float.
if 'predicted_class' in predictions.columns:
    X = predictions.drop(columns=['predicted_class'])
else:
    print("Column 'predicted_class' not found in predictions DataFrame.")
    X = predictions

X = pd.get_dummies(X, drop_first=True)

# get_dummies on this file alone need not reproduce the training columns:
# - a category absent here gets no dummy column;
# - any non-feature column stays in;
# - the column order can differ.
# SHAP/LIME consume features by position, so a mismatch silently
# mis-attributes them.
# scikit-learn estimators fitted on a DataFrame expose ``feature_names_in_``;
# when available, use it to reorder, drop extras and zero-fill missing dummies.
expected_features = getattr(model, "feature_names_in_", None)
if expected_features is not None:
    expected_features = list(expected_features)
    expected_set = set(expected_features)
    missing_features = [c for c in expected_features if c not in X.columns]
    extra_features = [c for c in X.columns if c not in expected_set]
    if missing_features:
        print("Columns missing from predictions (filled with 0):", missing_features)
    if extra_features:
        print("Columns not used by the model (dropped):", extra_features)
    X = X.reindex(columns=expected_features, fill_value=0)

# pandas >= 2 creates bool dummy columns; a mixed bool/float frame turns into
# an object-dtype array under np.array(...). A uniform float frame avoids that.
X = X.astype(float)

# SHAP Explanation for the positive class (output index 1).
explainer = shap.TreeExplainer(model)

# Raw output, kept under its original name for any later cell that uses it.
# Its layout depends on the installed SHAP version:
# - older releases return a list with one (n_samples, n_features) array per
#   class;
# - newer releases (0.45 onward) return one (n_samples, n_features, n_classes)
#   array, where ``shap_values[1]`` would be the second *sample*, not the
#   positive class.
shap_values = explainer.shap_values(X)

positive_class = 1
if isinstance(shap_values, list):
    shap_values_pos = shap_values[positive_class]
elif np.ndim(shap_values) == 3:
    shap_values_pos = shap_values[:, :, positive_class]
else:
    # Single-output model: already one (n_samples, n_features) array.
    shap_values_pos = shap_values

# expected_value is per-output for multi-output models, scalar otherwise.
expected_values = np.atleast_1d(explainer.expected_value)
if expected_values.size > 1:
    base_value = expected_values[positive_class]
else:
    base_value = expected_values[0]

# SHAP Summary Plot
shap.summary_plot(shap_values_pos, X)

# SHAP Force Plot for a specific instance.
# In a notebook only a cell's last expression is rendered, so a JS force plot
# returned mid-cell is silently discarded. matplotlib=True draws it
# immediately (and also works outside Jupyter).
shap.initjs()
shap.force_plot(base_value, shap_values_pos[0], X.iloc[0], matplotlib=True)

# SHAP Dependence Plot for a specific feature (skipped if it is not in X).
dependence_feature = 'purchase_value'
if dependence_feature in X.columns:
    shap.dependence_plot(dependence_feature, shap_values_pos, X)
else:
    print(f"Feature '{dependence_feature}' not found in X; skipping SHAP dependence plot.")

# LIME Explanation
# LIME operates on plain numeric arrays. np.array(X) on a frame that mixes
# bool dummy columns with numeric ones is object-dtype, so build an explicit
# float array once and reuse it for both training data and the instance row.
lime_feature_names = list(X.columns)
lime_training_data = X.to_numpy(dtype=float)


def _lime_predict_proba(rows):
    """Return ``model.predict_proba`` for LIME's perturbed rows.

    LIME passes a bare 2-D array. If the model was fitted on a DataFrame
    (it exposes ``feature_names_in_``), the column names are restored first.
    This avoids scikit-learn's missing-feature-names warning. It also avoids
    failures for models that select columns by name.
    """
    if hasattr(model, "feature_names_in_"):
        rows = pd.DataFrame(rows, columns=lime_feature_names)
    return model.predict_proba(rows)


# Create LIME explainer
lime_explainer = lime.lime_tabular.LimeTabularExplainer(
    training_data=lime_training_data,
    feature_names=lime_feature_names,
    mode='classification',
)

# Choose a specific instance (positional row index into X) to explain
instance_index = 0
exp = lime_explainer.explain_instance(
    data_row=lime_training_data[instance_index],
    predict_fn=_lime_predict_proba,
)

# LIME Feature Importance Plot
exp.as_pyplot_figure()
plt.title(f'LIME Explanation for Instance {instance_index}')
plt.show()
