In [4]:
from interpretml_utils import *
from interpret.glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor
import pandas as pd
import numpy as np  
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# Loading the dataset
### German Credit (UCI Statlog)

In [5]:
# Load the German Credit (UCI Statlog) dataset: space-separated rows, no header
# line, 20 attributes followed by the target code.
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/statlog/german/german.data"
columns = [
    'checking_status', 'duration', 'credit_history', 'purpose', 'credit_amount',
    'savings_account', 'employment', 'installment_rate', 'personal_status_sex',
    'other_debtors', 'present_residence', 'property', 'age', 'other_installment_plans',
    'housing', 'existing_credits', 'job', 'num_maintenance', 'telephone', 'foreign_worker', 'target'
]

df = pd.read_csv(url, sep=' ', names=columns, header=None)

# Preprocessing
# Derive a string-valued sex column from the combined personal-status code:
# A91, A93 and A94 are treated as male, every other code as female.
# NOTE: `personal_status_sex` stays in the feature set and fully determines
# `sex`, so the models can still see sex through it.
MALE_STATUS_CODES = ['A91', 'A93', 'A94']
df['sex'] = np.where(df['personal_status_sex'].isin(MALE_STATUS_CODES), 'male', 'female')

# Convert target to binary (Good credit=1, Bad credit=0). `map` turns any code
# other than 1/2 into NaN, which the assertion catches instead of letting an
# unexpected value pass through silently (as `replace` would).
df['target'] = df['target'].map({1: 1, 2: 0})
assert df['target'].notna().all(), "target column contains codes other than 1 and 2"

features = [col for col in df.columns if col != 'target']

X = df[features]
y = df['target']

# Split data (fixed seed so the split is reproducible)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# The per-group models and per-group metrics below need both groups on each side.
assert set(X_train['sex']) == {'male', 'female'}
assert set(X_test['sex']) == {'male', 'female'}

# Training the baseline models (male-only, female-only, and pooled)

In [6]:
# Baselines: one EBM fit only on male rows, one only on female rows, and one on
# the whole training split. All three use the same feature list.
_train_is_male = X_train['sex'] == 'male'
_train_is_female = X_train['sex'] == 'female'

male_model = ExplainableBoostingClassifier(feature_names=X.columns.tolist())
male_model.fit(X_train[_train_is_male], y_train[_train_is_male])

female_model = ExplainableBoostingClassifier(feature_names=X.columns.tolist())
female_model.fit(X_train[_train_is_female], y_train[_train_is_female])

normal_model = ExplainableBoostingClassifier(feature_names=X.columns.tolist())
normal_model.fit(X_train, y_train)

print("done")

done


In [7]:
# Inspect the learned bins (category maps / cut points) of both group models, one after the other.
display(male_model.bins_, female_model.bins_)

[[{'A11': 1, 'A12': 2, 'A13': 3, 'A14': 4}],
 [array([ 4.5,  5.5,  6.5,  7.5,  8.5,  9.5, 10.5, 11.5, 12.5, 13.5, 14.5,
         15.5, 17. , 19. , 20.5, 21.5, 23. , 25. , 26.5, 27.5, 29. , 33. ,
         37.5, 39.5, 41. , 43.5, 46.5, 51. , 57. ]),
  array([ 4.5,  5.5,  6.5,  7.5,  8.5,  9.5, 10.5, 11.5, 12.5, 13.5, 14.5,
         15.5, 17. , 19. , 20.5, 21.5, 23. , 25. , 26.5, 27.5, 29. , 33. ,
         37.5, 41. , 43.5, 46.5, 51. , 57. ])],
 [{'A30': 1, 'A31': 2, 'A32': 3, 'A33': 4, 'A34': 5}],
 [{'A40': 1,
   'A41': 2,
   'A410': 3,
   'A42': 4,
   'A43': 5,
   'A44': 6,
   'A45': 7,
   'A46': 8,
   'A48': 9,
   'A49': 10}],
 [array([  338.5,   382.5,   440. ,   488. ,   546.5,   580.5,   604. ,
           621.5,   627. ,   634. ,   639.5,   647. ,   664. ,   679.5,
           688. ,   694. ,   698.5,   700.5,   704. ,   707.5,   708.5,
           713. ,   718. ,   723. ,   728.5,   742. ,   756.5,   767.5,
           778.5,   791.5,   803. ,   835. ,   883. ,   901. ,   905.5,
     

[[{'A11': 1, 'A12': 2, 'A13': 3, 'A14': 4}],
 [array([ 5. ,  6.5,  7.5,  8.5,  9.5, 10.5, 11.5, 12.5, 14. , 16.5, 19.5,
         22.5, 25.5, 28.5, 31.5, 34.5, 39. , 44.5, 47.5, 54. ])],
 [{'A30': 1, 'A31': 2, 'A32': 3, 'A33': 4, 'A34': 5}],
 [{'A40': 1,
   'A41': 2,
   'A410': 3,
   'A42': 4,
   'A43': 5,
   'A44': 6,
   'A45': 7,
   'A46': 8,
   'A48': 9,
   'A49': 10}],
 [array([  296.5,   352.5,   377. ,   400.5,   418.5,   430.5,   440.5,
           528.5,   630.5,   659. ,   669. ,   677. ,   682.5,   712. ,
           745.5,   751.5,   756.5,   761.5,   776.5,   792.5,   796. ,
           801.5,   821. ,   838.5,   850.5,   867. ,   880. ,   890. ,
           904.5,   923.5,   941.5,   955.5,   968. ,   979.5,   989.5,
          1012. ,  1035. ,  1044.5,  1048. ,  1052. ,  1073.5,  1102.5,
          1118. ,  1124.5,  1128.5,  1158. ,  1186.5,  1190.5,  1195.5,
          1202. ,  1206.5,  1211.5,  1222. ,  1229.5,  1233.5,  1236.5,
          1238.5,  1249. ,  1266. ,  1274.5,  127

In [8]:
# Equal-weight (50/50) combination of the two group models, plus the model
# object that is handed to the visualizer.
_equal_weights = [0.5, 0.5]
ff_model = CombinedEBM([male_model, female_model], _equal_weights)
ff_model_obj = ff_model.get_model_object()

# Single model built from both group models via `merge_ebms`
# (presumably interpret's merge utility re-exported by interpretml_utils — TODO confirm).
merged_model = merge_ebms([male_model, female_model])

# Comparing the models with the custom `InterpretmlEBMVisualizer`

In [9]:
# Labelled models for side-by-side comparison; the widget offers a per-feature
# dropdown over all of them.
_labelled_models = [
    ("Male Model", male_model),
    ("Female Model", female_model),
    ("Normal Model", normal_model),
    ("50-50 Model", ff_model_obj),
    ("Merged Model", merged_model),
]
visualizer = InterpretmlEBMVisualizer(
    [model for _, model in _labelled_models],
    [label for label, _ in _labelled_models],
)
visualizer.show()

HBox(children=(VBox(children=(Dropdown(description='Feature:', options=(('checking_status', 0), ('duration', 1…

# Group Performance Plots

In [10]:
# NOTE(review): this cell repeats the baseline-training cell verbatim. It may be
# here to get fresh models in case CombinedEBM / merge_ebms modified the fitted
# ones — TODO confirm; if so, say so in markdown, otherwise drop the duplicate.
def _fit_ebm(X_part, y_part):
    """Fit a fresh EBM (same feature list as the baselines) on the given rows."""
    ebm = ExplainableBoostingClassifier(feature_names=X.columns.tolist())
    ebm.fit(X_part, y_part)
    return ebm


_train_sex = X_train['sex']
male_model = _fit_ebm(X_train[_train_sex == 'male'], y_train[_train_sex == 'male'])
female_model = _fit_ebm(X_train[_train_sex == 'female'], y_train[_train_sex == 'female'])
normal_model = _fit_ebm(X_train, y_train)

print("done")

done


In [11]:
# Evaluation data for the group-performance analyzers below.
# Use the held-out split: every model was fit on X_train, so scoring on it would
# report in-sample (optimistic) metrics, and each group model would be scored on
# the very rows it was trained on.
foi = 'sex'  # feature of interest that defines the groups
_x = X_test
_y = y_test

male_mask = _x[foi] == 'male'
female_mask = _x[foi] == 'female'

In [12]:
%matplotlib widget
plt.ioff()  # Avoids duplicate plots
analyzer = GroupPerformanceAnalyzer(
    male_model, female_model, normal_model,
    _x, _y,
    male_mask=male_mask, female_mask=female_mask,
    feature_of_interest='sex',
    combine_strategy='post',
    metric='log_likelihood',
)
analyzer.generate_plot(n_combinations=100)

Evaluating combinations:   0%|          | 0/100 [00:00<?, ?it/s]

Evaluating combinations: 100%|██████████| 100/100 [00:00<00:00, 131.90it/s]


HBox(children=(VBox(children=(HTML(value='<b>Model Details:</b>'), Output()), layout=Layout(margin='0 20px', w…

In [None]:
%matplotlib widget
plt.ioff()
analyzer = GenericGroupPerformanceAnalyzer(
    models_to_combine=[
        ("Male Model", male_model),
        ("Female Model", female_model),
        ("Normal Model", normal_model),
    ],
    baseline_models=[
    ],
    X_test=_x, y_test=_y,
    male_mask=male_mask, female_mask=female_mask,
    feature_of_interest='sex',
    metric='log_likelihood'
)
analyzer.generate_plot(n_combinations=100)

Evaluating combinations: 100%|██████████| 100/100 [00:01<00:00, 88.22it/s]


HBox(children=(VBox(children=(HTML(value='<b>Model Details:</b>'), Output()), layout=Layout(margin='0 20px', w…

In [22]:
import ipywidgets as widgets


class GenericGroupPerformanceAnalyzer:
    """Interactive male-vs-female performance plot for weighted model combinations.

    Random weight vectors (Dirichlet samples, so non-negative and summing to 1)
    are drawn over ``models_to_combine``. Each vector is turned into a
    ``CombinedEBM`` and scored separately on the male rows, the female rows and
    all rows of ``X_test``/``y_test``. Every combination becomes one scatter point
    (x = male metric, y = female metric). Baseline models and the individual
    models being combined are drawn as larger, black-outlined markers. With three
    or more models, extra "Without <model>" groups give one model zero weight.

    Parameters
    ----------
    models_to_combine : list of (label, model)
        Models whose weighted combinations are explored. Must be non-empty.
    baseline_models : list of (label, model)
        Extra reference models that are plotted but never combined.
    X_test, y_test :
        Evaluation data. For ``log_likelihood`` the labels are used directly as
        column indices into ``predict_proba`` (assumes classes are 0..k-1 — TODO
        confirm for other label encodings).
    X_train, y_train :
        Stored on the instance but not used by this class.
    male_mask, female_mask :
        Boolean row selectors aligned positionally with ``X_test``. If either is
        omitted, both are derived from ``X_test[feature_of_interest]``.
    feature_of_interest : str
        Column used to build the default masks.
    metric : {"accuracy", "log_likelihood", "auc"}
    male_value, female_value :
        Values of ``feature_of_interest`` that identify each group when the masks
        are derived (defaults 1 / 0, the previous hard-coded encoding).

    Raises
    ------
    ValueError
        If ``metric`` is unknown or ``models_to_combine`` is empty.
    """

    _METRICS = ("accuracy", "log_likelihood", "auc")
    # Colours for the "Without <model>" groups, cycled when there are more models.
    _ZERO_WEIGHT_COLORS = ("red", "green", "purple", "orange", "brown",
                           "olive", "cyan", "pink", "gray")

    def __init__(self, models_to_combine: List[tuple[str, EBMModel]],
                 baseline_models: List[tuple[str, EBMModel]],
                 X_test: pd.DataFrame, y_test: np.ndarray,
                 X_train: pd.DataFrame = None, y_train: np.ndarray = None,
                 male_mask: np.ndarray = None, female_mask: np.ndarray = None,
                 feature_of_interest: str = 'sex',
                 metric: Literal["accuracy", "log_likelihood", "auc"] = "accuracy",
                 male_value=1, female_value=0):
        # Fail fast instead of only after the plot's first evaluation.
        if metric not in self._METRICS:
            raise ValueError(f"Unknown metric: {metric}")
        if len(models_to_combine) == 0:
            raise ValueError("models_to_combine must contain at least one (label, model) pair")

        self.models_to_combine = self._as_pair_array(models_to_combine)
        self.baseline_models = self._as_pair_array(baseline_models)
        self.X_test = X_test
        self.y_test = y_test
        self.X_train = X_train
        self.y_train = y_train
        self.feature_of_interest = feature_of_interest
        self.metric = metric

        if male_mask is None or female_mask is None:
            # Derive both masks from the group column when either is missing.
            self.male_mask = X_test[feature_of_interest] == male_value
            self.female_mask = X_test[feature_of_interest] == female_value
        else:
            self.male_mask = male_mask
            self.female_mask = female_mask

        self.fig = None
        self.ax = None
        self.cursor = None
        self.scatter_plots = {}   # group id -> scatter artist
        self.info_output = Output()
        self.metrics_data = []    # metrics of every evaluated combination, group by group
        self.combination_groups = []
        self.group_data = {}      # group id -> list of metrics dicts

    @staticmethod
    def _as_pair_array(pairs) -> np.ndarray:
        """Return (label, model) pairs as an (n, 2) object array.

        Built element by element rather than with ``np.array(pairs)`` so numpy
        never tries to unpack a sequence-like model into extra dimensions.
        ``None`` is treated as an empty list.
        """
        pairs = [] if pairs is None else list(pairs)
        arr = np.empty((len(pairs), 2), dtype=object)
        for row, (label, model) in enumerate(pairs):
            arr[row, 0] = label
            arr[row, 1] = model
        return arr

    def _metric_label(self) -> str:
        """Human-readable metric name, e.g. 'log_likelihood' -> 'Log Likelihood'."""
        return "AUC" if self.metric == "auc" else self.metric.replace('_', ' ').title()

    def _get_weighed_model(self, model: EBMModel, weight: float) -> EBMModel:
        """Return a shallow copy of ``model`` whose predictions are scaled by ``weight``.

        Scales ``predict_proba`` when the model has it, otherwise ``predict``.
        The passed-in model is left untouched. (Not called elsewhere in this class.)
        """
        import copy

        new_model = copy.copy(model)
        # Bind the ORIGINAL model's method before overriding it on the copy. The
        # previous version aliased `new_model = model` and looked the method up
        # inside the lambda, so the lambda called itself (infinite recursion) and
        # the caller's model was mutated.
        if hasattr(model, 'predict_proba'):
            base_predict_proba = model.predict_proba
            new_model.predict_proba = lambda X: base_predict_proba(X) * weight
        else:
            base_predict = model.predict
            new_model.predict = lambda X: base_predict(X) * weight
        return new_model

    def _combine_models(self, weights: list[float]) -> "CombinedEBM":
        """Build a ``CombinedEBM`` of ``models_to_combine`` with the given weights."""
        return CombinedEBM(self.models_to_combine[:, 1], weights)

    @staticmethod
    def _mean_or_nan(values) -> float:
        """Mean of ``values``, or NaN for an empty selection (instead of a numpy warning)."""
        return float(np.mean(values)) if len(values) else float('nan')

    def _evaluate_model(self, model) -> dict:
        """Score ``model`` on the male rows, the female rows and all rows.

        Returns a dict with keys ``male_<metric>``, ``female_<metric>`` and
        ``overall_<metric>``. A group with no rows, or (for AUC) with a single
        class, scores NaN instead of warning or raising.

        Raises
        ------
        ValueError
            If ``self.metric`` is not a supported metric.
        """
        # Work positionally throughout: predictions come back as numpy arrays, so
        # labels and masks are converted too. Mixing positional predictions with
        # label-aligned pandas indexing can pair the wrong rows.
        y_true = np.asarray(self.y_test)
        groups = {
            'male': np.asarray(self.male_mask, dtype=bool),
            'female': np.asarray(self.female_mask, dtype=bool),
            'overall': np.ones(len(y_true), dtype=bool),
        }

        if self.metric == "accuracy":
            correct = (y_true == np.asarray(model.predict(self.X_test))).astype(float)

            def score(rows):
                return self._mean_or_nan(correct[rows])

        elif self.metric == "log_likelihood":
            y_probs = np.asarray(model.predict_proba(self.X_test))
            eps = 1e-10  # keeps log() finite when a true-class probability is 0
            # Probability assigned to each row's true class (labels used as column indices).
            true_class_probs = y_probs[np.arange(len(y_true)), y_true.astype(int)]
            log_probs = np.log(true_class_probs + eps)

            def score(rows):
                return self._mean_or_nan(log_probs[rows])

        elif self.metric == "auc":
            y_score = np.asarray(model.predict_proba(self.X_test))[:, 1]  # class-1 probability

            def score(rows):
                labels = y_true[rows]
                if np.unique(labels).size < 2:
                    return float('nan')  # AUC is undefined with fewer than two classes
                return roc_auc_score(labels, y_score[rows])

        else:
            raise ValueError(f"Unknown metric: {self.metric}")

        return {f'{name}_{self.metric}': score(rows) for name, rows in groups.items()}

    def _plot_baseline_models(self):
        """Plot every baseline and every individual combinable model as one outlined marker."""
        labelled_models = self.baseline_models.tolist() + self.models_to_combine.tolist()
        for label, model in labelled_models:
            metrics = self._evaluate_model(model)
            self.ax.scatter(metrics[f'male_{self.metric}'], metrics[f'female_{self.metric}'],
                            s=100, edgecolors='black', label=label, zorder=10)

    def _generate_zero_weight_combinations(self, n_combinations, zero_index, rng=None):
        """Draw ``n_combinations`` weight vectors in which model ``zero_index`` gets weight 0.

        The other weights are Dirichlet(1, ..., 1) samples (non-negative, summing
        to 1). ``rng`` defaults to numpy's global random state. Returns an array
        of shape (n_combinations, number of models).
        """
        rng = np.random if rng is None else rng
        num_models = len(self.models_to_combine)
        non_zero = rng.dirichlet(np.ones(num_models - 1), size=n_combinations)
        return np.insert(non_zero, zero_index, 0.0, axis=1)

    def generate_plot(self, n_combinations: int = 100, random_state=None):
        """Evaluate random weight combinations, draw them, and display the dashboard.

        Parameters
        ----------
        n_combinations : int
            Size of the "All Models" group; each "Without <model>" group gets
            ``n_combinations // 3`` combinations.
        random_state : int, numpy Generator or None
            Seed for the weight sampling. ``None`` draws from numpy's global
            random state (the previous behaviour).
        """
        rng = np.random if random_state is None else np.random.default_rng(random_state)
        num_models = len(self.models_to_combine)

        self.combination_groups = [{
            "id": "all_models",
            "weights": rng.dirichlet(np.ones(num_models), n_combinations),
            "color": "blue",
            "label": "All Models",
        }]
        # Leave-one-out groups, one per model. Skipped for fewer than three models,
        # where dropping one leaves a single fixed model.
        if num_models >= 3:
            for i, model_name in enumerate(self.models_to_combine[:, 0]):
                self.combination_groups.append({
                    "id": f"without_{i}",
                    "weights": self._generate_zero_weight_combinations(
                        n_combinations // 3, i, rng=rng),
                    "color": self._ZERO_WEIGHT_COLORS[i % len(self._ZERO_WEIGHT_COLORS)],
                    "label": f"Without {model_name}",
                })

        self._evaluate_groups()
        self._draw_figure()
        display(self._create_display())

    def _evaluate_groups(self):
        """Score every weight vector of every group into ``group_data`` / ``metrics_data``."""
        self.metrics_data = []
        self.group_data = {}
        for group in self.combination_groups:
            group_metrics = []
            for weights in tqdm(group["weights"], desc=f"Evaluating {group['label']}"):
                metrics = self._evaluate_model(self._combine_models(weights))
                metrics.update({
                    'weights': weights,
                    'group_id': group["id"],
                    'group_label': group["label"],
                    'color': group["color"],
                })
                group_metrics.append(metrics)
            self.group_data[group["id"]] = group_metrics
            self.metrics_data.extend(group_metrics)

    def _draw_figure(self):
        """Create a fresh figure: one scatter per non-empty group, then the baselines."""
        # Close the figure of a previous call and drop its artists so repeated
        # calls neither accumulate figures nor keep stale scatter handles.
        if self.fig is not None:
            plt.close(self.fig)
        self.scatter_plots = {}

        self.fig, self.ax = plt.subplots(figsize=(10, 8))
        self.fig.subplots_adjust(right=0.75)  # room for the legend outside the axes

        male_key, female_key = f'male_{self.metric}', f'female_{self.metric}'
        for group in self.combination_groups:
            group_metrics = self.group_data[group["id"]]
            if not group_metrics:
                continue
            self.scatter_plots[group["id"]] = self.ax.scatter(
                [m[male_key] for m in group_metrics],
                [m[female_key] for m in group_metrics],
                c=group["color"], alpha=0.6, label=group["label"])

        self._plot_baseline_models()
        self._configure_plot()
        self._setup_interactivity()

    def _configure_plot(self):
        """Label the axes, add grid and (for accuracy) a parity line, place the legend outside."""
        label = self._metric_label()
        self.ax.set_xlabel(f"Male {label}", fontsize=12)
        self.ax.set_ylabel(f"Female {label}", fontsize=12)

        if self.metric == "accuracy":
            # Parity line: points on it score equally for both groups.
            self.ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)

        self.ax.grid(True, alpha=0.3)

        # Legend outside the axes so it does not cover points.
        legend = self.ax.legend(loc='center left', bbox_to_anchor=(1.02, 0.5),
                                frameon=False, fontsize=10)
        legend.set_title("Groups", prop={'size': 12})

    def _setup_interactivity(self):
        """Show the weights and scores of a selected combination point in the side panel."""
        # Keep a reference so the cursor lives as long as the analyzer.
        self.cursor = mplcursors.cursor(list(self.scatter_plots.values()))
        metric_label = self._metric_label()
        model_names = self.models_to_combine[:, 0]

        @self.cursor.connect("add")
        def on_add(sel):
            with self.info_output:
                clear_output(wait=True)

                # Map the selected artist back to its combination group.
                group_id = next((gid for gid, artist in self.scatter_plots.items()
                                 if artist is sel.artist), None)
                if group_id is None or sel.index >= len(self.group_data[group_id]):
                    return

                metrics = self.group_data[group_id][sel.index]
                weights_str = ', '.join(f"{name}: {w:.2f}"
                                        for name, w in zip(model_names, metrics['weights']))
                display(HTML(
                    f"<div style='border: 1px solid #ccc; padding: 10px; border-radius: 5px;'>"
                    f"<b>Group:</b> {metrics['group_label']}<br>"
                    f"<b>Weights:</b> {weights_str}<br>"
                    f"<b>Male {metric_label}:</b> {metrics[f'male_{self.metric}']:.3f}<br>"
                    f"<b>Female {metric_label}:</b> {metrics[f'female_{self.metric}']:.3f}<br>"
                    f"<b>Overall {metric_label}:</b> {metrics[f'overall_{self.metric}']:.3f}"
                    "</div>"
                ))

    def _toggle_group_visibility(self, group_id, change):
        """Show or hide one group's scatter according to a checkbox change event."""
        if group_id in self.scatter_plots:
            self.scatter_plots[group_id].set_visible(change['new'])
            self.fig.canvas.draw_idle()

    def _create_checkboxes(self):
        """Create one visibility checkbox per combination group."""
        checkbox_widgets = []
        for group in self.combination_groups:
            checkbox = Checkbox(
                value=True,
                description=group["label"],
                style={'description_width': 'initial'},
                layout=widgets.Layout(margin='5px 0'),
            )
            # Default argument binds the current group id (avoids late binding in the loop).
            checkbox.observe(lambda change, gid=group["id"]: self._toggle_group_visibility(gid, change),
                             names='value')
            checkbox_widgets.append(checkbox)
        return VBox(checkbox_widgets)

    def _create_display(self):
        """Assemble the side panel (details + group checkboxes) next to the figure canvas."""
        control_panel = VBox([
            HTML("<b>Model Details:</b>"),
            self.info_output,
            HTML("<b>Show/Hide Groups:</b>"),
            self._create_checkboxes(),
        ], layout={'width': '300px', 'margin': '0 20px'})

        # Widen the canvas so the outside legend fits.
        fig_canvas = self.fig.canvas
        fig_canvas.layout.width = '800px'

        return HBox([control_panel, fig_canvas])

In [None]:
%matplotlib widget
plt.ioff()
analyzer = GenericGroupPerformanceAnalyzer(
    models_to_combine=[
        ("Male Model", male_model),
        ("Female Model", female_model),
    ],
    baseline_models=[
    ],
    X_test=_x, y_test=_y,
    male_mask=male_mask, female_mask=female_mask,
    feature_of_interest='sex',
    metric='log_likelihood'
)
analyzer.generate_plot(n_combinations=100)

Evaluating All Models: 100%|██████████| 100/100 [00:01<00:00, 87.01it/s]
Evaluating Without Male Model: 100%|██████████| 33/33 [00:00<00:00, 87.89it/s]
Evaluating Without Female Model: 100%|██████████| 33/33 [00:00<00:00, 88.42it/s]
Evaluating Without Normal Model: 100%|██████████| 33/33 [00:00<00:00, 88.82it/s]


HBox(children=(VBox(children=(HTML(value='<b>Model Details:</b>'), Output(), HTML(value='<b>Show/Hide Groups:<…