# CatBoost Model: Drift & Segmentation Analysis
This notebook demonstrates how to train a CatBoost model, analyze drift, and perform segmentation analysis with interactive Plotly visualizations using the `tab-right` package.

In [None]:
# Install dependencies if running in Colab or a fresh environment.
# Use %pip (not !pip) so the packages are installed into the kernel's own environment.
# %pip install catboost plotly pandas scikit-learn tab-right

In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported modules
# are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import plotly.io as pio
from catboost import CatBoostClassifier
from sklearn.metrics import accuracy_score, log_loss
from sklearn.tree import DecisionTreeRegressor

# Make "notebook" the default Plotly renderer for every figure shown below.
pio.renderers.default = "notebook"

## Load Example Dataset
We'll use the UCI Adult dataset (census income) from OpenML.

In [None]:
from sklearn.datasets import fetch_openml

data = fetch_openml("adult", version=2, as_frame=True)

# Build the modelling frame in one chain: shuffle, drop incomplete rows,
# derive the binary label from the income class, then discard the raw class column.
df = (
    data.frame.sample(frac=1, random_state=42)  # Shuffle
    .reset_index(drop=True)
    .dropna()  # Drop missing for simplicity
    .assign(target=lambda frame: (frame["class"] == ">50K").astype(int))
    .drop(columns=["class"])
)
df.head()

## Split Data: Reference vs. Current
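We'll simulate a reference/current split by row position: the first 70% of the (shuffled) rows form the reference set and the last 30% form the current set. Note that the dataset has no time column and the rows were shuffled above, so both sets come from the same distribution and little genuine drift should be expected.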

In [None]:
# Position of the 70/30 boundary between the reference and current sets
split_idx = int(0.7 * len(df))

# Slice both parts and give each a fresh 0..n-1 index
df_ref, df_cur = (
    part.reset_index(drop=True) for part in (df.iloc[:split_idx], df.iloc[split_idx:])
)
print(f"Reference: {df_ref.shape}, Current: {df_cur.shape}")

## Train CatBoost Model
We'll train on the reference data and predict on the current data.

In [None]:
# Separate features from the label for both sets
X_ref = df_ref.drop(columns=["target"])
y_ref = df_ref["target"]
X_cur = df_cur.drop(columns=["target"])
y_cur = df_cur["target"]

# Treat every category/object-typed feature as categorical. Selecting from the
# feature frame keeps the columns in their original order, so the list is the
# same on every run (building it through a set made the order hash-dependent),
# and the label can never end up in it.
cat_features = X_ref.select_dtypes(include=["category", "object"]).columns.tolist()

# A small, shallow model is enough for this demo; the fixed seed keeps training reproducible.
model = CatBoostClassifier(
    cat_features=cat_features, iterations=50, depth=3, learning_rate=0.1, random_seed=42, verbose=0
)
model.fit(X_ref, y_ref)

# Hard class predictions on the current set
y_pred = model.predict(X_cur)

## Segmentation Analysis
Let's segment the predictions by features and visualize the results using tab_right.

In [None]:
# Import required modules from tab_right
import numpy as np

from tab_right.plotting.plot_segmentations import DoubleSegmPlotting
from tab_right.segmentations.double_seg import DoubleSegmentationImp
from tab_right.segmentations.find_seg import FindSegmentationImp

# Import specific modules from tab_right
from tab_right.task_detection import detect_task

In [None]:
# Class probabilities for the current set
y_pred_proba = model.predict_proba(X_cur)

# Assemble features, true labels, and model outputs in a single new frame
df_analysis = X_cur.assign(
    target=y_cur,
    pred_class=y_pred,
    pred_prob_0=y_pred_proba[:, 0],
    pred_prob_1=y_pred_proba[:, 1],
)

# Let tab_right infer the task type from the labels (binary classification is expected)
task_type = detect_task(y_cur)
print(f"Detected task type: {task_type.value}")

### Define Error Metrics
Let's define some error metrics for our segmentation analysis.

In [None]:
def binary_log_loss(y_true, y_pred_df):
    """Return the per-row binary log loss.

    Parameters
    ----------
    y_true
        True 0/1 labels.
    y_pred_df
        Frame holding the positive-class probability in column ``pred_prob_1``.

    Returns
    -------
    One loss value per row: ``-(y * log(p) + (1 - y) * log(1 - p))``.
    """
    eps = 1e-15
    # Keep probabilities strictly inside (0, 1) so neither logarithm hits log(0)
    p_positive = np.clip(y_pred_df["pred_prob_1"], eps, 1 - eps)
    positive_term = y_true * np.log(p_positive)
    negative_term = (1 - y_true) * np.log(1 - p_positive)
    return -(positive_term + negative_term)


def binary_error(y_true, y_pred_df):
    """Return the per-row 0/1 classification loss.

    Parameters
    ----------
    y_true
        True class labels.
    y_pred_df
        Frame holding the predicted class in column ``pred_class``.

    Returns
    -------
    ``1.0`` where the predicted class differs from the true label, ``0.0`` elsewhere.
    """
    predicted_labels = y_pred_df["pred_class"]
    is_misclassified = y_true != predicted_labels
    return is_misclassified.astype(float)

### Single Feature Segmentation
Let's analyze how model performance varies across different segments of a single feature.

In [None]:
# List the columns available for segmentation: the original features plus the
# `target`, `pred_class`, `pred_prob_0` and `pred_prob_1` columns added above.
df_analysis.columns

In [None]:
# Initialize the segmentation finder on the analysis frame: `target` holds the
# true labels and the two `pred_prob_*` columns hold the predicted class
# probabilities. The finder is reused by the double segmentation further down.
segmentation_finder = FindSegmentationImp(
    df=df_analysis, label_col="target", prediction_col=["pred_prob_0", "pred_prob_1"]
)

### Double Feature Segmentation
Now let's analyze how model performance varies across segments defined by two features.

In [None]:
def error_func(y_true, y_pred):
    """Return the element-wise absolute difference ``|y_pred - y_true|``.

    No log loss and no averaging is computed here; the result has one value
    per element of the inputs.

    NOTE(review): ``.abs()`` requires a pandas object, and subtracting a Series
    from a multi-column DataFrame aligns the Series on column labels rather
    than on rows — confirm what tab_right actually passes as ``y_pred``.
    """
    return (y_pred - y_true).abs()


# Initialize the double segmentation
double_segmentation = DoubleSegmentationImp(segmentation_finder)

# Tree model handed to tab_right for finding the segments. It is defined here
# so the cell runs on a fresh kernel (Restart & Run All).
# NOTE(review): a shallow DecisionTreeRegressor is an assumption — confirm the
# expected estimator type and a sensible depth against the tab_right docs.
tree_model = DecisionTreeRegressor(max_depth=3, random_state=42)

# Define feature pairs to analyze
feature_pairs = [("age", "education-num")]

# Analyze each feature pair and visualize
for feature1, feature2 in feature_pairs:
    print(f"\nAnalyzing feature pair: {feature1} and {feature2}")

    # Find double segmentation
    double_segments = double_segmentation(
        feature1,
        feature2,
        error_func,
        tree_model,
        score_metric=log_loss,
    )
    print(f"Found {len(double_segments)} segment combinations")

    # Display the segments
    display(double_segments.head())

    # Create double segmentation plotter using the default column name 'score'
    double_plotter = DoubleSegmPlotting(df=double_segments)

    # Plot the heatmap
    heatmap_fig = double_plotter.plotly_heatmap()
    heatmap_fig.update_layout(
        title=f"Log Loss Heatmap: {feature1} vs {feature2}", xaxis_title=feature1, yaxis_title=feature2
    )
    heatmap_fig.show()

## Performance Analysis by Features
Let's take a deeper look at how the model performs across categorical features.

In [None]:
# Analyze categorical features
import plotly.express as px
from sklearn.metrics import roc_auc_score


def safe_roc_auc_score(y_true, y_pred):
    """Compute ROC AUC, returning ``None`` when it cannot be computed.

    ``None`` is returned when ``y_true`` contains fewer than two distinct
    classes, or when the underlying computation raises ``ValueError``.
    """
    try:
        has_both_classes = len(set(y_true)) >= 2
        # ROC AUC is only defined when both classes are present
        return roc_auc_score(y_true, y_pred) if has_both_classes else None
    except ValueError:
        return None


# Select some categorical features to analyze
cat_features_to_analyze = ["workclass", "education", "marital-status", "occupation", "relationship"]

for cat_feature in cat_features_to_analyze:
    # Group by the categorical feature. observed=True restricts the groups to
    # categories that actually occur in the current data, so the metric lambdas
    # are never evaluated for categories with no rows.
    # Each lambda receives one column of one group; the matching true labels
    # are looked up through the group's index (unique after reset_index above).
    grouped = df_analysis.groupby(cat_feature, observed=True).agg({
        "target": ["count", "mean"],
        "pred_class": lambda x: accuracy_score(df_analysis.loc[x.index, "target"], x),
        "pred_prob_1": lambda x: safe_roc_auc_score(df_analysis.loc[x.index, "target"], x),
    })

    # Flatten the column hierarchy
    grouped.columns = [f"{col[0]}_{col[1]}" if col[1] else col[0] for col in grouped.columns]
    grouped = grouped.rename(columns={"pred_class_<lambda>": "accuracy", "pred_prob_1_<lambda>": "auc"})
    grouped = grouped.reset_index()

    # Filter out categories with no AUC score (single-class segments) for better visualization
    grouped_filtered = grouped[grouped["auc"].notna()]

    # Plot accuracy by category, colored by segment size
    if not grouped_filtered.empty:
        fig = px.bar(
            grouped_filtered,
            x=cat_feature,
            y="accuracy",
            color="target_count",
            hover_data=["target_mean", "auc", "target_count"],
            labels={
                "accuracy": "Accuracy",
                "target_count": "Sample Count",
                "target_mean": "Positive Rate",
                "auc": "AUC",
            },
            title=f"Model Performance by {cat_feature}",
        )
        fig.update_layout(xaxis_tickangle=-45)
        fig.show()