In [3]:
# ---------------------------------------------------------
# Assignment 7: Naïve Bayes Classifier (Using sklearn)
# ---------------------------------------------------------

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import numpy as np
import pandas as pd

# ---------------------------------------------------------
# Step 1: Load the dataset
# ---------------------------------------------------------
iris = load_iris()
X, y = iris.data, iris.target          # feature matrix, integer labels (0,1,2)
feature_names = iris.feature_names     # column names for the features
class_names = iris.target_names        # label index -> class name

# Optional: tabular preview (features plus a trailing 'target' column)
df = pd.DataFrame(X, columns=feature_names).assign(target=y)
print(" Dataset loaded")
preview = df.head()
print(preview, "\n")

# ---------------------------------------------------------
# Step 2: Train–test split (stratified to keep class balance)
# ---------------------------------------------------------
# 70/30 split; the fixed seed makes the partition reproducible, and
# stratifying on y keeps the class proportions equal in both parts.
split_options = {"test_size": 0.30, "random_state": 42, "stratify": y}
X_train, X_test, y_train, y_test = train_test_split(X, y, **split_options)

print("Train shape: {}, Test shape: {}\n".format(X_train.shape, X_test.shape))

# ---------------------------------------------------------
# Step 3: Create and train the model
# ---------------------------------------------------------
# fit() returns the estimator itself, so construction and training chain.
model = GaussianNB().fit(X_train, y_train)

# ---------------------------------------------------------
# Step 4: Predict + Probabilities
# ---------------------------------------------------------
y_pred = model.predict(X_test)            # hard class labels
y_proba = model.predict_proba(X_test)     # one probability per class, per sample

# ---------------------------------------------------------
# Step 5: Evaluation
# ---------------------------------------------------------
# Pin the label order to the indices of class_names (labels are 0..n-1) so
# every row/column of the confusion matrix and every line of the report lines
# up with its class name, even if a class is absent from y_test and y_pred.
# Without this, the labels are inferred from the data and can silently shift.
label_ids = np.arange(len(class_names))

acc = accuracy_score(y_test, y_pred)
cm = confusion_matrix(y_test, y_pred, labels=label_ids)
report = classification_report(
    y_test, y_pred, labels=label_ids, target_names=class_names
)

print("------------------------------------------------------")
print("Naïve Bayes Classifier — Evaluation")
print("------------------------------------------------------")
print(f"Accuracy: {acc*100:.2f}%\n")

print("Confusion Matrix (rows=true, cols=pred):")
print(cm, "\n")

print("Classification Report:")
print(report)

# Class-wise accuracy (diagonal / row sum): the fraction of each true class
# that was predicted correctly, i.e. the per-class recall.
row_sums = cm.sum(axis=1, keepdims=True)   # kept as an (n_classes, 1) column
support = row_sums.ravel()                 # true-sample count per class
# Divide only where a class actually has test samples; classes with zero
# support stay NaN instead of triggering a 0/0 RuntimeWarning.
class_acc = np.divide(
    cm.diagonal(),
    support,
    out=np.full(support.shape, np.nan),
    where=support > 0,
)
print("Class-wise Accuracy:")
for name, a in zip(class_names, class_acc):
    print(f"  {name:>10}: {a*100:5.2f}%")

print("------------------------------------------------------")

# ---------------------------------------------------------
# Quick peek: first 5 test predictions with probs
# ---------------------------------------------------------
peek_n = 5
print(f"\nSample predictions (first {peek_n}):")
# Slicing caps the preview at the number of available test rows.
for true_idx, pred_idx, proba_row in zip(
    y_test[:peek_n], y_pred[:peek_n], y_proba[:peek_n]
):
    probs = ", ".join(
        f"{cls}={p:.2f}" for cls, p in zip(class_names, proba_row)
    )
    true_name = class_names[true_idx]
    pred_name = class_names[pred_idx]
    print(f"True: {true_name:>10} | Pred: {pred_name:>10} | Proba -> {probs}")


 Dataset loaded
   sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \
0                5.1               3.5                1.4               0.2   
1                4.9               3.0                1.4               0.2   
2                4.7               3.2                1.3               0.2   
3                4.6               3.1                1.5               0.2   
4                5.0               3.6                1.4               0.2   

   target  
0       0  
1       0  
2       0  
3       0  
4       0   

Train shape: (105, 4), Test shape: (45, 4)

------------------------------------------------------
Naïve Bayes Classifier — Evaluation
------------------------------------------------------
Accuracy: 91.11%

Confusion Matrix (rows=true, cols=pred):
[[15  0  0]
 [ 0 14  1]
 [ 0  3 12]] 

Classification Report:
              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        15
  versicolor   