# problem

Let:
1. $X$ be the dataset, with datapoints $\{x_i\}_{i=1}^{n}$.
2. The features of datapoint $x_i$ be $\{z^i_j\}_{j=1}^{k}$.
3. $Y$ be the dependent variable, with values $\{y_i\}_{i=1}^{n}$.
4. $f : X \rightarrow \hat{Y}$ be a model that maps datapoints to predictions.

In [None]:
import sys
from pathlib import Path
import os
# Move the working directory one level up for the rest of the notebook.
# NOTE(review): presumably the parent is the project root, so that the later
# `from config import DATA_DIR` and relative paths resolve — confirm.
os.chdir(Path.cwd().parent)

In [None]:
# IPython shell escape: runs the shell command `dir` (the Windows directory
# listing) — presumably a sanity check of the working directory set above.
!dir

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import shap
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_absolute_error

import sys
from pathlib import Path
from config import DATA_DIR
import pickle
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load dataset pickles that come from a trusted source.
with open(DATA_DIR / 'boston_housing_dataset.pkl', 'rb') as rf:
    data_dict = pickle.load(rf)


# The pickled bundle is a dict with the data table and its metadata.
# NOTE(review): `data` is presumably a pandas DataFrame (it is used with
# .drop(columns=...) below) — confirm against the code that writes the pickle.
data = data_dict['data']
metadata = data_dict['metadata']

# === Separate features and target ===
# 'MEDV' is the target column; every other column is used as a feature.
X = data.drop(columns=['MEDV'])
y = data['MEDV']

# === Train/test split ===
# Hold out 20% of the rows; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# === Fit a baseline model ===
model = RandomForestRegressor(random_state=42, n_estimators=200)
model.fit(X_train, y_train)

# === Report hold-out performance ===
y_pred = model.predict(X_test)
# The label previously rendered as mis-encoded text ("RÂ²"); it is meant to be R².
print(f"R²: {r2_score(y_test, y_pred):.3f}")
print(f"MAE: {mean_absolute_error(y_test, y_pred):.3f}")


# Ceteris Paribus (CP) Plots

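A CP plot is a line plot for a single datapoint $x_i$ and a single feature $z_j$. The x-axis is a value $v$ that sweeps the observed range of the feature in the dataset, $[\min(z_j), \max(z_j)]$, and the y-axis is $f(x^i_{z_j = v})$: the model's prediction for $x_i$ when feature $z_j$ is set to $v$ and all other features are left unchanged.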

In [None]:
# Efficient multi-point Ceteris Paribus (ICE) implementation
# - Handles multiple datapoints at once
# - Predicts in batches per feature (vectorized) for efficiency
# - Returns {feature: (png_bytes, dataframe_used)}
#
# Minimal demonstration included at the bottom.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from io import BytesIO


def ceteris_paribus_bytes_multi(model, X: pd.DataFrame, num_datapoints: int = 1, grid_points: int = 50):
    """
    Build Ceteris Paribus (ICE) plots for the first 'num_datapoints' rows of X
    and return each plot as PNG bytes together with the data behind it.

    For every numeric column of X, the selected rows are copied 'grid_points'
    times, the column is overwritten with an evenly spaced grid spanning its
    observed [min, max] in X, and the model is evaluated on all copies in a
    single call. Non-numeric columns are skipped (they get no entry in the
    result) but are still passed to the model with their original dtypes.

    Args:
        model: fitted estimator with predict() and optionally predict_proba().
            When predict_proba() exists and model.classes_ has exactly two
            entries, the probability in column 1 is plotted instead of
            predict().
        X (pd.DataFrame): data defining grids and base rows
        num_datapoints (int): number of rows (taken from the top of X);
            clamped to the range [1, len(X)]
        grid_points (int): grid resolution per feature

    Returns:
        dict[str, tuple[bytes, pd.DataFrame]]:
            feature -> (png_bytes, data)
            where data columns: ['row_id','feature_value','prediction'].
            Empty when X has no rows or no numeric columns.
    """
    out = {}
    if len(X) == 0:
        # No base rows to perturb, so there is nothing to predict or plot.
        return out

    num_datapoints = max(1, min(num_datapoints, len(X)))

    # Plot a probability only for binary classifiers that expose predict_proba.
    classes_ = getattr(model, "classes_", None)
    is_binary = (
        getattr(model, "predict_proba", None) is not None
        and classes_ is not None
        and len(classes_) == 2
    )

    # Positional index of the base row behind every grid point:
    # [0, 0, ..., 1, 1, ..., num_datapoints - 1, ...].
    row_ids = np.repeat(np.arange(num_datapoints), grid_points)
    # Repeat rows by position instead of going through `.values`, which would
    # collapse a mixed-dtype frame into a single object-dtype array.
    repeated_rows = X.iloc[row_ids].reset_index(drop=True)

    for feature in X.columns:
        if not np.issubdtype(X[feature].dtype, np.number):
            continue

        grid = np.linspace(X[feature].min(), X[feature].max(), grid_points)

        # Same grid for every base row, laid out to line up with row_ids.
        base_df = repeated_rows.copy()
        base_df[feature] = np.tile(grid, num_datapoints)

        if is_binary:
            preds = model.predict_proba(base_df)[:, 1]
        else:
            preds = model.predict(base_df)
        preds = np.asarray(preds, dtype=float)

        df_plot = pd.DataFrame(
            {"row_id": row_ids, "feature_value": base_df[feature].to_numpy(), "prediction": preds}
        )

        fig = plt.figure()
        try:
            ax = fig.gca()

            # One solid line per base row; predictions are stored row-major,
            # so reshaping yields one curve per row without re-filtering.
            for row_preds in preds.reshape(num_datapoints, grid_points):
                ax.plot(grid, row_preds, alpha=0.9, linewidth=1.6)

            # Dashed curve: mean prediction across the base rows per grid value.
            mean_curve = df_plot.groupby("feature_value", as_index=False)["prediction"].mean()
            ax.plot(mean_curve["feature_value"], mean_curve["prediction"], linestyle="--", linewidth=2.0)

            ax.set_xlabel(feature)
            ax.set_ylabel("predicted probability" if is_binary else "prediction")
            ax.set_title(f"Ceteris Paribus (n={num_datapoints}): {feature}")
            ax.grid(True, alpha=0.3)

            # Serialize figure to PNG bytes
            with BytesIO() as buf:
                fig.savefig(buf, format="png", bbox_inches="tight", dpi=160)
                png_bytes = buf.getvalue()
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

        out[feature] = (png_bytes, df_plot)

    return out


def load_and_show_png_bytes(img_bytes: bytes, title: str | None = None):
    """
    Decode PNG bytes and display the image in a new matplotlib figure.

    Args:
        img_bytes (bytes): PNG-encoded image data
        title (str | None): optional title drawn above the image; omitted
            when None or empty
    """
    from io import BytesIO
    import matplotlib.image as mpimg

    pixels = mpimg.imread(BytesIO(img_bytes), format="png")

    _, ax = plt.subplots()
    ax.imshow(pixels)
    # Hide ticks and frame so only the decoded picture is visible.
    ax.axis("off")
    if title:
        ax.set_title(title)
    plt.show()

# -------- Minimal example to illustrate --------
# Build CP plots as PNG bytes for the first 3 datapoints of the test set
cp_bytes = ceteris_paribus_bytes_multi(model, X_test, num_datapoints=3, grid_points=100)

print("Features:", list(cp_bytes))

# Show every rendered image and peek at the data behind it
for feat, (img_bytes, df_data) in cp_bytes.items():
    print(f"\nData sample for '{feat}':")
    # Print the first rows so the "Data sample" label is followed by data.
    print(df_data.head())
    load_and_show_png_bytes(img_bytes, title=feat)

In [18]:
# List the features that received a plot (only numeric columns get an entry).
print(cp_bytes.keys())

dict_keys(['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT'])


In [19]:
# Check the type of the first element of the CRIM entry (the serialized image).
print(type(cp_bytes['CRIM'][0]))

<class 'bytes'>


In [None]:
def load_and_save_png_bytes(img_bytes_or_tuple, title=None, out_path='temp.jpg'):
    """
    Re-render a PNG image with matplotlib and save it to out_path.

    Accepts either raw PNG bytes or a (png_bytes, data_df) tuple/list, in
    which case only the first element is used. The image is drawn into a new
    figure with the axes hidden, saved, and the figure is closed; nothing is
    displayed.

    Args:
        img_bytes_or_tuple: PNG bytes, or a tuple/list whose first element is
            PNG bytes
        title: optional title drawn above the image; omitted when None or empty
        out_path: destination passed to savefig(); matplotlib infers the
            output format from the file extension
    """
    from io import BytesIO
    import matplotlib.image as mpimg

    # accept (bytes, df) or bytes
    if isinstance(img_bytes_or_tuple, (tuple, list)):
        img_bytes = img_bytes_or_tuple[0]
    else:
        img_bytes = img_bytes_or_tuple

    arr = mpimg.imread(BytesIO(img_bytes), format='png')

    # Work on an explicit figure handle so that saving and closing cannot
    # accidentally act on some other "current" pyplot figure.
    fig = plt.figure()
    try:
        ax = fig.gca()
        ax.imshow(arr)
        ax.axis('off')
        if title:
            ax.set_title(title)
        fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even when saving fails (e.g. unwritable path).
        plt.close(fig)

# Save the CRIM plot to the default out_path ('temp.jpg'); the (png_bytes, df)
# tuple stored in cp_bytes is accepted directly.
load_and_save_png_bytes(cp_bytes['CRIM'])