In [None]:
# Train a K-Nearest-Neighbors color classifier on labelled RGB samples and
# persist it to disk so later cells (and other scripts) can reuse it.

import pickle

import pandas as pd
from sklearn.neighbors import KNeighborsClassifier

# Each CSV row is one color sample: its R, G, B channel values plus the
# color name ("Label") it was annotated with.
df = pd.read_csv("color_training_set.csv")

# Feature matrix (one row per sample, columns R, G, B) and target names.
X = df.loc[:, ["R", "G", "B"]].to_numpy()
y = df["Label"].to_numpy()

# k = 3: every prediction is a majority vote among the three training colors
# closest in RGB space. `fit` returns the estimator itself.
knn = KNeighborsClassifier(n_neighbors=3).fit(X, y)

# NOTE: only ever unpickle this file from a trusted source — pickle.load can
# execute arbitrary code.
with open("color_classifier.pkl", "wb") as model_file:
    pickle.dump(knn, model_file)

print("✅ Model saved successfully as 'color_classifier.pkl'")

✅ Model saved successfully as 'color_classifier.pkl'


In [None]:
import shap
import pandas as pd

# Explain the KNN's class probabilities with SHAP, using the training colors
# (X) as the background data the explainer masks features against.
explainer = shap.Explainer(knn.predict_proba, X)

# shap_values.values is indexed [sample, feature, output]: for every sample,
# one R/G/B attribution per predict_proba column. Column k of predict_proba
# corresponds to knn.classes_[k].
shap_values = explainer(X)

# The predicted label does not depend on which class is being explained, so
# run the KNN search once instead of once per class inside the loop.
predicted_labels = knn.predict(X)

# Build one long-format frame per explained class, then stack them.
shap_dfs = []
for class_idx in range(shap_values.values.shape[2]):
    shap_df = pd.DataFrame(
        shap_values.values[:, :, class_idx],
        columns=["SHAP_R", "SHAP_G", "SHAP_B"],
    )

    # Raw feature values alongside their attributions.
    shap_df["R"] = X[:, 0]
    shap_df["G"] = X[:, 1]
    shap_df["B"] = X[:, 2]

    shap_df["Predicted_Label"] = predicted_labels
    # "Class" is the predict_proba column index being explained; "Class_Label"
    # is its color name, so rows are interpretable without cross-referencing
    # knn.classes_.
    shap_df["Class"] = class_idx
    shap_df["Class_Label"] = knn.classes_[class_idx]

    shap_dfs.append(shap_df)

# Single concat of the collected frames (not a grow-in-loop pattern).
final_shap_df = pd.concat(shap_dfs, ignore_index=True)

# Preview the result
final_shap_df.head()

  from .autonotebook import tqdm as notebook_tqdm


Unnamed: 0,SHAP_R,SHAP_G,SHAP_B,R,G,B,Predicted_Label,Class
0,0.076667,-0.293333,0.116667,4,222,10,Green,0
1,0.111111,0.082778,-0.293889,19,17,234,Blue,0
2,-0.27,0.05,0.12,223,35,1,Red,0
3,0.070556,-0.292778,0.122222,48,243,32,Green,0
4,-0.27,0.05,0.12,227,34,39,Red,0


In [None]:
import shap
import matplotlib.pyplot as plt

shap_cols = ["SHAP_R", "SHAP_G", "SHAP_B"]
feature_cols = ["R", "G", "B"]

# Pooled Explanation over every (sample, class) row, kept so any later cell
# that references `shap_values_for_plot` still works.
# NOTE(review): don't summary-plot this object directly. It mixes attributions
# for different predict_proba outputs into one distribution.
shap_values_for_plot = shap.Explanation(
    values=final_shap_df[shap_cols].values,
    data=final_shap_df[feature_cols].values,
    feature_names=feature_cols,
)

# Each SHAP value explains one specific output, P(class = k). Evidence that
# raises one class's probability lowers the others', so pooling all classes
# into a single beeswarm blends opposite-signed attributions. Draw one
# summary plot per explained class instead.
for class_idx, class_df in final_shap_df.groupby("Class"):
    plt.figure()
    shap.summary_plot(
        class_df[shap_cols].values,
        features=class_df[feature_cols].values,
        feature_names=feature_cols,
        show=False,  # keep the figure open so we can add a title
    )
    # "Class" indexes predict_proba's columns, which follow knn.classes_.
    plt.title(f"SHAP summary for P(class = {knn.classes_[class_idx]})")
    plt.show()