In [31]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import shap

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import accuracy_score, classification_report
from sklearn.ensemble import RandomForestClassifier

# -------------------------------
# Load dataset
# -------------------------------
data = pd.read_csv("crop_recommendation.csv")

# Encode the string crop labels as integer codes 0..n_classes-1.
# Note: this overwrites data['label'] in place; le.classes_ maps the codes back to crop names.
le = LabelEncoder()
data['label'] = le.fit_transform(data['label'])

# Features (every column except the target) and target
X = data.drop('label', axis=1)
y = data['label']

# Stratified 80/20 split so class proportions are preserved in the test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Scale features: fit on the training split only, to avoid leaking test-set statistics.
# Tree ensembles such as RandomForest do not need scaling; the step is kept so the fitted
# `scaler` stays available to any later code that reuses it.
# StandardScaler returns bare ndarrays, which would drop the column names that downstream
# plots (e.g. SHAP) display, so the results are wrapped back into DataFrames.
scaler = StandardScaler()
X_train = pd.DataFrame(scaler.fit_transform(X_train), columns=X.columns, index=X_train.index)
X_test = pd.DataFrame(scaler.transform(X_test), columns=X.columns, index=X_test.index)

# -------------------------------
# Train model
# -------------------------------
# Random forest with library-default hyperparameters; the fixed seed makes the
# bootstrap draws and per-split feature sampling reproducible between runs.
model = RandomForestClassifier(random_state=42).fit(X_train, y_train)

# ---- Evaluate on the held-out test split ----
y_pred = model.predict(X_test)
print("Accuracy:", accuracy_score(y_true=y_test, y_pred=y_pred))
# target_names turns the integer class codes back into crop names in the report
print(classification_report(y_true=y_test, y_pred=y_pred, target_names=le.classes_))

# -------------------------------
# Feature importance plot
# -------------------------------
# Impurity-based (mean decrease in impurity) importances from the fitted forest,
# with feature indices ordered from most to least important.
importances = model.feature_importances_
indices = np.argsort(importances)[::-1]

# Explicit figure/axes so the plot does not depend on pyplot's implicit current figure
fig, ax = plt.subplots(figsize=(8, 5))
sns.barplot(
    x=importances[indices],
    y=X.columns[indices],
    hue=X.columns[indices],
    palette="viridis",
    legend=False,
    ax=ax,
)
ax.set(
    title="Feature Importance (RandomForest)",
    xlabel="Mean decrease in impurity",
    ylabel="Feature",
)
fig.tight_layout()
fig.savefig("crop_feature_importance.png")
plt.close(fig)

# -------------------------------
# SHAP Explainability
# -------------------------------
explainer = shap.TreeExplainer(model)
# NOTE(review): SHAP's additivity check (SHAP values must sum to the model output) is
# disabled; the reason is not recorded. Confirm the mismatch it suppressed is only
# floating-point noise and not a model/data mismatch.
shap_values = explainer(X_test, check_additivity=False)

print("SHAP shape:", shap_values.values.shape)

# Pass feature names explicitly: if X_test is a bare (scaled) ndarray, SHAP would
# otherwise label the features "Feature 0", "Feature 1", ...
shap_feature_names = list(X.columns)

# A 3-D result is (n_samples, n_features, n_outputs): one attribution matrix per class of
# this classifier, where output k corresponds to model.classes_[k] == le.classes_[k].
# A 2-D result means a single output.
if shap_values.values.ndim == 3:
    shap_per_class = [shap_values.values[:, :, k] for k in range(shap_values.values.shape[2])]
else:
    shap_per_class = [shap_values.values]

# A beeswarm can show only one output at a time; this selects which class it displays.
SHAP_CLASS_INDEX = 0
shap_to_plot = shap_per_class[SHAP_CLASS_INDEX]

# ---- SHAP summary plot (bar) ----
# Global importance: mean |SHAP| per feature, stacked across ALL classes
# (rather than only class 0).
if len(shap_per_class) > 1:
    shap.summary_plot(
        shap_per_class,
        X_test,
        feature_names=shap_feature_names,
        class_names=list(le.classes_),
        plot_type="bar",
        show=False,
    )
else:
    shap.summary_plot(
        shap_to_plot, X_test, feature_names=shap_feature_names, plot_type="bar", show=False
    )
plt.savefig("crop_shap_summary.png", bbox_inches="tight")
plt.close()

# ---- SHAP summary plot (beeswarm) ----
shap.summary_plot(shap_to_plot, X_test, feature_names=shap_feature_names, show=False)
if len(shap_per_class) > 1:
    # Make explicit which class the per-sample attributions belong to
    plt.title(f"SHAP values for class '{le.classes_[SHAP_CLASS_INDEX]}'")
plt.savefig("crop_shap_beeswarm.png", bbox_inches="tight")
plt.close()


Accuracy: 0.9954545454545455
              precision    recall  f1-score   support

       apple       1.00      1.00      1.00        20
      banana       1.00      1.00      1.00        20
   blackgram       1.00      0.95      0.97        20
    chickpea       1.00      1.00      1.00        20
     coconut       1.00      1.00      1.00        20
      coffee       1.00      1.00      1.00        20
      cotton       1.00      1.00      1.00        20
      grapes       1.00      1.00      1.00        20
        jute       0.95      1.00      0.98        20
 kidneybeans       1.00      1.00      1.00        20
      lentil       1.00      1.00      1.00        20
       maize       0.95      1.00      0.98        20
       mango       1.00      1.00      1.00        20
   mothbeans       1.00      1.00      1.00        20
    mungbean       1.00      1.00      1.00        20
   muskmelon       1.00      1.00      1.00        20
      orange       1.00      1.00      1.00        2