# Naive Bayes — Mushroom Edibility Notebook

This notebook mirrors the production code in `src/` and offers an exploratory workspace for understanding and extending the mushroom Naive Bayes classifier. Work through the sections sequentially to validate the data pipeline, interrogate model behaviour, and document findings you want to port back into the FastAPI service.

**Roadmap**

- Inspect the dataset, normalise missing markers, and compute baseline statistics.
- Rebuild the scikit-learn preprocessing + `CategoricalNB` pipeline from the `src/` package.
- Evaluate the model with accuracy, precision, recall, F1, ROC-AUC, and confusion matrices.
- Probe feature influence via class-conditional likelihood ratios and partial dependence sketches.
- Capture follow-up experiments (alpha sweeps, grouped categories, monitoring ideas).


In [None]:
"""Environment imports mirroring the production pipeline."""
from pathlib import Path
import json

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from IPython.display import display

from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    RocCurveDisplay,
    PrecisionRecallDisplay,
    ConfusionMatrixDisplay,
)
from sklearn.model_selection import train_test_split

from src.config import CONFIG as NAIVE_CONFIG, NaiveBayesConfig
from src.data import load_dataset, build_features, train_validation_split
from src.pipeline import MushroomPipeline, train_and_persist

In [None]:
sns.set_theme(style="whitegrid")

config: NaiveBayesConfig = NAIVE_CONFIG
raw_df = load_dataset(config)
display(raw_df.head())
print(f"Total rows: {len(raw_df):,}")
print("Missing values per column:")
display(raw_df.isna().sum().sort_values(ascending=False))

## 1. Dataset Overview

The mushroom dataset encodes each categorical descriptor as a single-character token (for example, `f` for a foul odor). Missing stalk-root values appear as `NaN` after the normalisation performed during loading (see the missing-value counts above), and the pipeline imputes them with the modal category. Use the preview and the missing-value counts to verify that the column order matches `NaiveBayesConfig.feature_columns`.
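To see the normalise-then-impute step in isolation, the sketch below reproduces it on a hypothetical four-row frame. It assumes missing markers arrive as `?` (the convention in the original UCI mushroom file) and that imputation uses scikit-learn's `SimpleImputer(strategy="most_frequent")`; neither detail is confirmed by `src/`, and `expected_columns` is a made-up stand-in for `NaiveBayesConfig.feature_columns`.

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

# Hypothetical rows; "?" marks a missing stalk-root value, as in the UCI source file.
toy_raw = pd.DataFrame(
    {
        "odor": ["f", "n", "a", "n"],
        "stalk-root": ["b", "?", "b", "e"],
    }
)

# Assumed stand-in for NaiveBayesConfig.feature_columns (not read from src/).
expected_columns = ["odor", "stalk-root"]
assert list(toy_raw.columns) == expected_columns, "column order drifted from the config"

# Step 1: normalise the missing marker to NaN so pandas and scikit-learn recognise it.
toy_normalised = toy_raw.replace("?", np.nan)
print(toy_normalised.isna().sum())

# Step 2: impute each column with its modal category ("b" for stalk-root here).
toy_imputer = SimpleImputer(strategy="most_frequent")
toy_imputed = pd.DataFrame(
    toy_imputer.fit_transform(toy_normalised), columns=toy_normalised.columns
)
print(toy_imputed)
```

Keeping the imputer inside the pipeline, rather than filling `NaN` during loading, means the modal category is learned from the training fold only.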

In [None]:
X, y = build_features(raw_df, config)
print(f"Features shape: {X.shape}")
print(f"Target distribution:\n{y.value_counts().rename(index={0: 'edible', 1: 'poisonous'})}")

### Odor signal strength

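Naive Bayes performs well when individual features carry strong class signals. Odor categories such as foul (`f`) and creosote (`c`) nearly guarantee toxicity. The odor plot shows the empirical share of each class within every odor category, P(class | odor). Naive Bayes estimates the reverse direction, P(odor | class), and combines it with the class prior, so an odor category that is almost entirely poisonous in this plot should also carry a strongly positive poisonous log-likelihood ratio in the fitted model.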
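To make the link between the two directions concrete, the toy example below (hypothetical odor tokens and labels, unsmoothed counts) estimates P(odor | class) and P(class) the way Naive Bayes does and recovers the row-normalised crosstab P(class | odor) through Bayes' rule.

```python
import numpy as np
import pandas as pd

# Hypothetical odor tokens and labels (0 = edible, 1 = poisonous).
toy_odor = pd.Series(["f", "f", "n", "n", "n", "a", "f", "n"], name="odor")
toy_label = pd.Series([1, 1, 0, 0, 1, 0, 1, 0], name="label")

# P(class | odor): what the odor crosstab plot shows.
posterior = pd.crosstab(toy_odor, toy_label, normalize="index")

# P(odor | class) and P(class): what Naive Bayes estimates (here without smoothing).
likelihood = pd.crosstab(toy_odor, toy_label, normalize="columns")
prior = toy_label.value_counts(normalize=True).sort_index()

# Bayes' rule: multiply likelihood by prior (aligned on class columns), then renormalise each row.
joint = likelihood * prior
reconstructed = joint.div(joint.sum(axis=1), axis=0)

print(reconstructed.round(3))
print(np.allclose(posterior.to_numpy(), reconstructed.to_numpy()))
```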

In [None]:
odor_counts = (
    pd.crosstab(X["odor"], y, normalize="index")
    .rename(columns={0: "edible", 1: "poisonous"})
    .sort_values("poisonous", ascending=False)
)
display(odor_counts.head())
odor_counts.plot(kind="bar", stacked=True, figsize=(10, 4))
plt.title("Class proportions by odor category")
plt.ylabel("Proportion")
plt.xlabel("Odor category")
plt.tight_layout()

## 2. Train/Validation Split

Re-create the deterministic 80/20 stratified split used in the scripted pipeline. Metrics in the notebook should match `src/train.py` when using the same configuration.
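For reference, a deterministic 80/20 stratified split with `train_test_split` can be sketched as follows. The seed of 42 and the helper name `stratified_split` are assumptions; check `train_validation_split` in `src.data` for the values the service actually uses.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def stratified_split(features, target, test_size=0.2, seed=42):
    """Hypothetical helper: an 80/20 split that preserves the class ratio and is repeatable."""
    return train_test_split(
        features, target, test_size=test_size, stratify=target, random_state=seed
    )


# Toy data with a 60/40 edible/poisonous balance.
toy_y = pd.Series([0] * 60 + [1] * 40, name="class")
toy_X = pd.DataFrame({"odor": ["n"] * 60 + ["f"] * 40})

split_a = stratified_split(toy_X, toy_y)
split_b = stratified_split(toy_X, toy_y)

_, toy_X_val, _, toy_y_val = split_a
# Stratification keeps the validation fold at the full data's 40% poisonous share.
print(len(toy_X_val), toy_y_val.mean())
# Same seed, same rows: the split is deterministic.
print(split_a[3].index.equals(split_b[3].index))
```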

In [None]:
X_train, X_val, y_train, y_val = train_validation_split(config)
print(f"Train size: {X_train.shape[0]:,} | Validation size: {X_val.shape[0]:,}")
print("Training class balance:")
display(y_train.value_counts(normalize=True).rename(index={0: 'edible', 1: 'poisonous'}))
print("Validation class balance:")
display(y_val.value_counts(normalize=True).rename(index={0: 'edible', 1: 'poisonous'}))

## 3. Rebuild the Production Pipeline

Instantiate `MushroomPipeline`, fit on the training fold, and persist artefacts to mirror the CLI workflow. If you rerun this notebook after modifying the pipeline or configuration, ensure artefacts are refreshed for consistency with the FastAPI service.
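For orientation, here is a minimal sketch of a preprocessing + `CategoricalNB` pipeline whose step names match the ones this notebook reaches into later (`preprocessor`, `categorical`, `encoder`, `classifier`). The `imputer` step, the choice of `OrdinalEncoder`, `joblib` persistence, and the toy columns are assumptions; `MushroomPipeline` in `src.pipeline` remains the source of truth.

```python
import tempfile
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.naive_bayes import CategoricalNB
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OrdinalEncoder


def build_sketch_pipeline(columns, alpha=1.0):
    """Hypothetical stand-in: impute -> ordinal codes -> CategoricalNB."""
    categorical = Pipeline(
        [
            ("imputer", SimpleImputer(strategy="most_frequent")),  # modal imputation
            ("encoder", OrdinalEncoder(dtype=np.int64)),  # CategoricalNB expects integer codes
        ]
    )
    preprocessor = ColumnTransformer([("categorical", categorical, columns)])
    return Pipeline(
        [("preprocessor", preprocessor), ("classifier", CategoricalNB(alpha=alpha))]
    )


sketch_columns = ["odor", "stalk-root"]
sketch_X = pd.DataFrame(
    {
        "odor": ["f", "n", "n", "a", "f", "n"],
        "stalk-root": ["b", np.nan, "e", "b", "b", "e"],
    }
)
sketch_y = pd.Series([1, 0, 0, 0, 1, 0])
sketch = build_sketch_pipeline(sketch_columns).fit(sketch_X, sketch_y)

# Persist and reload, mirroring the "train, then save an artefact" flow.
with tempfile.TemporaryDirectory() as tmp:
    artefact = Path(tmp) / "sketch.joblib"
    joblib.dump(sketch, artefact)
    reloaded = joblib.load(artefact)

print(reloaded.predict(sketch_X))
```

Ordinal codes are the natural input for `CategoricalNB`, which models each column as one categorical distribution per class; if the production encoder is one-hot instead, every one-hot column becomes its own binary feature.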

In [None]:
pipeline = MushroomPipeline(config)
metrics = pipeline.train()
artifact_path = pipeline.save()
metrics_path = pipeline.write_metrics(metrics)
print("Metrics:")
display(metrics)
print(f"Model artifact: {artifact_path}")
print(f"Metrics file: {metrics_path}")

In [None]:
y_val_pred = pipeline.pipeline.predict(X_val)
y_val_proba = pipeline.pipeline.predict_proba(X_val)[:, 1]

metric_frame = {
    "accuracy": float(accuracy_score(y_val, y_val_pred)),
    "precision": float(precision_score(y_val, y_val_pred)),
    "recall": float(recall_score(y_val, y_val_pred)),
    "f1": float(f1_score(y_val, y_val_pred)),
    "roc_auc": float(roc_auc_score(y_val, y_val_proba)),
}
display(metric_frame)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

ConfusionMatrixDisplay.from_predictions(
    y_val,
    y_val_pred,
    display_labels=["edible", "poisonous"],
    cmap="Blues",
    colorbar=False,
    ax=axes[0],
)
axes[0].set_title("Confusion matrix")

RocCurveDisplay.from_predictions(
    y_val,
    y_val_proba,
    name="Naive Bayes",
    ax=axes[1],
)
axes[1].set_title("ROC curve")
axes[1].plot([0, 1], [0, 1], linestyle="--", color="grey", alpha=0.6)

plt.tight_layout()

## 4. Inspect Class-Conditional Likelihoods

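`CategoricalNB` stores a smoothed log-likelihood for every category token under each class. Surfacing these values helps sanity-check whether the model relies on botanically plausible signals. The extraction cell in this section ranks every token by its log-likelihood ratio, log P(token | poisonous) - log P(token | edible), and displays the ten strongest indicators for each class.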
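Because `CategoricalNB` keeps one `(n_classes, n_categories)` array per input column in `feature_log_prob_` (a list, not a single matrix), it is worth checking the indexing on a toy column before trusting the ranking. The hypothetical example below fits one ordinal-encoded column and reproduces the stored log-probabilities from the smoothing formula (N_tc + alpha) / (N_c + alpha * n_categories).

```python
import numpy as np
from sklearn.naive_bayes import CategoricalNB

# One ordinal-encoded column with three categories (think: three odor tokens).
toy_codes = np.array([[0], [0], [1], [2], [2], [2]])
toy_labels = np.array([1, 1, 0, 0, 0, 1])  # 0 = edible, 1 = poisonous
toy_nb = CategoricalNB(alpha=1.0).fit(toy_codes, toy_labels)

print(len(toy_nb.feature_log_prob_))  # 1: one entry per input column
column_log_prob = toy_nb.feature_log_prob_[0]  # shape (n_classes, n_categories) = (2, 3)

# Recompute by hand: rows are classes, columns are category codes.
category_counts = np.array([[0, 1, 2], [2, 0, 1]])
toy_alpha, n_categories = 1.0, 3
manual = np.log(
    (category_counts + toy_alpha)
    / (category_counts.sum(axis=1, keepdims=True) + toy_alpha * n_categories)
)
print(np.allclose(column_log_prob, manual))  # True

# Log-likelihood ratio per category: positive values favour "poisonous".
toy_log_ratio = column_log_prob[1] - column_log_prob[0]
print(toy_log_ratio.round(3))
```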

In [None]:
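from sklearn.preprocessing import OneHotEncoder

encoder = pipeline.pipeline.named_steps["preprocessor"].named_transformers_["categorical"].named_steps["encoder"]
classifier = pipeline.pipeline.named_steps["classifier"]

# CategoricalNB.feature_log_prob_ is a list holding one (n_classes, n_categories) array per
# input column, so index it as [column][class, category] rather than by class.
if isinstance(encoder, OneHotEncoder):
    # Each one-hot column is a binary feature; category 1 means "token present".
    token_rows = [
        (name, log_prob[0, 1], log_prob[1, 1])
        for name, log_prob in zip(encoder.get_feature_names_out(config.feature_columns), classifier.feature_log_prob_)
    ]
else:
    # Ordinal codes: category k of column j is encoder.categories_[j][k].
    token_rows = [
        (f"{column}_{category}", log_prob[0, code], log_prob[1, code])
        for column, categories, log_prob in zip(config.feature_columns, encoder.categories_, classifier.feature_log_prob_)
        for code, category in enumerate(categories)
    ]
likelihood_frame = pd.DataFrame(token_rows, columns=["feature", "log_prob_edible", "log_prob_poisonous"])
likelihood_frame["log_ratio"] = likelihood_frame["log_prob_poisonous"] - likelihood_frame["log_prob_edible"]
top_poisonous = likelihood_frame.sort_values("log_ratio", ascending=False).head(10)
top_edible = likelihood_frame.sort_values("log_ratio").head(10)
display(top_poisonous)
display(top_edible)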
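The same arrays also explain individual predictions, which is what the monitoring idea in the notes below would log. For `CategoricalNB`, the posterior log-odds of one row equal the prior log-odds plus the sum of that row's per-column log-likelihood ratios. The sketch below checks this identity on hypothetical ordinal-encoded data; `token_contributions` is an illustrative helper, not part of `src/`.

```python
import numpy as np
from sklearn.naive_bayes import CategoricalNB


def token_contributions(model, encoded_row):
    """Hypothetical helper: log P(x_j | class 1) - log P(x_j | class 0) for each column j of one row."""
    return np.array(
        [
            log_prob[1, code] - log_prob[0, code]
            for log_prob, code in zip(model.feature_log_prob_, encoded_row)
        ]
    )


# Two ordinal-encoded columns; labels use 0 = edible, 1 = poisonous.
contrib_codes = np.array([[0, 1], [0, 0], [1, 1], [2, 0], [2, 1], [1, 0]])
contrib_labels = np.array([1, 1, 0, 0, 0, 1])
contrib_model = CategoricalNB().fit(contrib_codes, contrib_labels)

sample_row = contrib_codes[0]
contributions = token_contributions(contrib_model, sample_row)
prior_log_odds = contrib_model.class_log_prior_[1] - contrib_model.class_log_prior_[0]
sample_log_proba = contrib_model.predict_log_proba(sample_row.reshape(1, -1))[0]

print("per-column contributions:", contributions.round(3))
# Prior log-odds plus summed contributions reproduces the model's posterior log-odds.
print(np.isclose(prior_log_odds + contributions.sum(), sample_log_proba[1] - sample_log_proba[0]))
```

If the production encoder is one-hot, the same identity holds with each one-hot column treated as its own binary feature (code 0 or 1).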

## 5. Experiment Ideas & Notes

- **Smoothing sweep**: iterate over `alpha` values (e.g., 0.1 → 10) to observe calibration vs. accuracy trade-offs.
- **Category grouping**: cluster rare habitats or cap colors and re-evaluate the likelihood tables for stability.
- **Monitoring hooks**: log the top contributing one-hot tokens per prediction when serving via FastAPI to aid alerting workflows.
- **Alternative models**: prototype tree-based baselines (Random Forest, Gradient Boosting) inside the same train/serve pattern for A/B testing.
- **Data quality checks**: add assertions that ensure no unseen categories slip into production by comparing encoder vocabularies across training runs.
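A starting point for the smoothing sweep might look like the sketch below. It runs on synthetic ordinal-encoded data (an assumption; swap in the encoded mushroom features to use it for real), reports accuracy alongside log loss as one proxy for calibration, and uses a hypothetical helper name, `alpha_sweep`.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, log_loss
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import CategoricalNB


def alpha_sweep(codes, labels, alphas, n_categories, seed=0):
    """Hypothetical helper: fit CategoricalNB per alpha and score a held-out fold."""
    X_fit, X_hold, y_fit, y_hold = train_test_split(
        codes, labels, test_size=0.25, stratify=labels, random_state=seed
    )
    rows = []
    for alpha in alphas:
        # min_categories keeps every code valid even if one is absent from the fit fold.
        model = CategoricalNB(alpha=alpha, min_categories=n_categories).fit(X_fit, y_fit)
        proba = model.predict_proba(X_hold)
        rows.append(
            {
                "alpha": alpha,
                "accuracy": accuracy_score(y_hold, model.classes_[proba.argmax(axis=1)]),
                "log_loss": log_loss(y_hold, proba, labels=model.classes_),
            }
        )
    return pd.DataFrame(rows)


# Synthetic data: 3 columns with 4 categories each; the label follows column 0 with 10% noise.
rng = np.random.default_rng(0)
synthetic_codes = rng.integers(0, 4, size=(400, 3))
synthetic_labels = (synthetic_codes[:, 0] >= 2).astype(int)
noise = rng.random(400) < 0.1
synthetic_labels[noise] = 1 - synthetic_labels[noise]

sweep = alpha_sweep(synthetic_codes, synthetic_labels, [0.1, 0.3, 1.0, 3.0, 10.0], n_categories=4)
print(sweep)
```

Small `alpha` trusts the observed counts more; large `alpha` pulls every class-conditional distribution toward uniform, which usually makes predicted probabilities less extreme.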
