# Reproduce ANCHOR results from the paper

### 1. Init

In [1]:
from anchor import anchor_tabular
import numpy as np
import pickle
import pandas as pd

np.random.seed(1)

### 2. Load Train/Test Sets

In [2]:
name = 'adult'

# Load train set
title = "scripts/datasets/train_set_"+name+"_strat.p"
train = open(title,"rb")
train_set: pd.DataFrame = pickle.load(train)

title = "scripts/datasets/train_label_"+name+"_strat.p"
train_l = open(title,"rb")
train_label: pd.DataFrame = pickle.load(train_l)

# Load Test set
title = "scripts/datasets/test_set_" + name + "_strat.p"
test = open(title, "rb")
test_set: pd.DataFrame = pickle.load(test)

title = "scripts/datasets/test_label_" + name + "_strat.p"
test_l = open(title, "rb")
test_label: pd.DataFrame = pickle.load(test_l)

# Reset indices
train_set = train_set.reset_index(drop=True)
train_label = train_label.reset_index(drop=True)
test_set = test_set.reset_index(drop=True)
test_label = test_label.reset_index(drop=True)

# Convert object columns to int
train_set = train_set.astype(int)
test_set = test_set.astype(int)

print("Length train set:", len(train_set))
print("Length test set:", len(test_set))

Length train set: 21113
Length test set: 9049


### 3a. Fit models

In [3]:
### Fit xgboost, lightgbm and catboost ###
### The models provided with the repo don't match the paper's own prepared dataset

from xgboost import XGBClassifier
xgb = XGBClassifier(
    C= 1,
    penalty='l2',
    objective='binary:logistic',
    seed = 42,
    bootstrap=True,
    max_depth=90,
    learning_rate=0.1,
    n_estimators=500,
    tree_method='auto'
)
xgb.fit(train_set.values, train_label.values)

from sklearn.linear_model import LogisticRegression
lg = LogisticRegression(
    penalty='l2',
    C=1.0,
    solver='lbfgs',     # or 'liblinear' (both support L2)
    max_iter=1000,
    random_state=42
)
lg.fit(train_set.values, train_label.values)

from catboost import CatBoostClassifier
cat = CatBoostClassifier(
    iterations=500,         # Equivalent to n_estimators
    depth=10,               # Equivalent to max_depth
    learning_rate=0.1,      # Same learning rate
    loss_function='Logloss',  # For binary classification
    bootstrap_type='Bayesian',  # Equivalent to bootstrap in XGBoost
    random_seed=42,
    od_type='Iter',         # Early stopping after 'iterations'
    od_wait=50              # How many iterations to wait for improvement
)
cat.fit(train_set.values, train_label.values)

Parameters: { C, bootstrap, penalty } might not be used.

  This may not be accurate due to some parameters are only used in language bindings but
  passed down to XGBoost core.  Or some parameters are not used but slip through this
  verification. Please open an issue if you find above cases.


0:	learn: 0.6058914	total: 73.6ms	remaining: 36.7s
1:	learn: 0.5310385	total: 90.9ms	remaining: 22.6s
2:	learn: 0.4791864	total: 97.3ms	remaining: 16.1s
3:	learn: 0.4388315	total: 110ms	remaining: 13.7s
4:	learn: 0.4083986	total: 125ms	remaining: 12.4s
5:	learn: 0.3848253	total: 138ms	remaining: 11.4s
6:	learn: 0.3676019	total: 150ms	remaining: 10.5s
7:	learn: 0.3561208	total: 155ms	remaining: 9.55s
8:	learn: 0.3459560	total: 168ms	remaining: 9.17s
9:	learn: 0.3372917	total: 181ms	remaining: 8.87s
10:	learn: 0.3312674	total: 194ms	remaining: 8.62s
11:	learn: 0.3257599	total: 206ms	remaining: 8.39s
12:	learn: 0.3224239	total: 218ms	remaining: 8.17s
13:	learn: 0.3185056	total: 231ms	remaining: 8.

<catboost.core.CatBoostClassifier at 0x7f6a29788f90>

### 3b. Define explainer

In [4]:
# Define Anchor explainer
explainer = anchor_tabular.AnchorTabularExplainer(
    class_names=train_label.unique(),
    feature_names=train_set.columns,
    train_data=train_set.values
)

### 4. Select the samples from the paper

In [5]:
# Filter out the same sample as in the paper
x1 = test_set[
    (test_set["education-num"] == 13) &     # Bachelors
    (test_set["occupation"] == 3) &         # Prof-speciality
    (test_set["sex"] == 0) &                # Male
    (test_set["native-country"] == 36) &    # Vietnam
    (test_set["age"] == 35) &               
    (test_set["workclass"] == 3) &
    (test_set["hours-per-week"] == 40) &
    (test_set["race"] == 3) &               # Asian-Pac-Islander
    (test_set["marital-status"] == 1) &     # Married-civ
    (test_set["relationship"] == 1) &       # Husband
    (test_set["capital-gain"] == 0) &
    (test_set["capital-loss"] == 0)
]
y1 = test_label.loc[x1.index]

print("Sample x1:\n", x1)
print('>50k?')
print("Label:", y1.item())
print("xgb:", xgb.predict(x1.values).item())
print("lg:", lg.predict(x1.values).item())
print("cat:", cat.predict(x1.values).item())


x2 = test_set[
    (test_set["education-num"] == 9) &     # College
    (test_set["occupation"] == 10) &         # Sales
    #(test_set["sex"] == 0) &                # Male
    #(test_set["native-country"] == 41) &    # US
    (test_set["age"] == 19) &               
    #(test_set["workclass"] == 2) &
    (test_set["hours-per-week"] == 15) &
    #(test_set["race"] == 1) &                 # White
    #(test_set["marital-status"] == 1) &     # Married-civ
    #(test_set["relationship"] == 1) &       # Husband
    (test_set["capital-gain"] == 0) &    
    (test_set["capital-loss"] == 0)
]
y2 = test_label.loc[x2.index]

print("Sample x1:\n", x2)
print('>50k?')
print("Label:", y2.item())
print("xgb:", xgb.predict(x1.values).item())
print("lg:", lg.predict(x1.values).item())
print("cat:", cat.predict(x1.values).item())

Sample x1:
      age  workclass  fnlwgt  education-num  marital-status  occupation  \
166   35          3  110188             13               1           3   

     relationship  race  sex  capital-gain  capital-loss  hours-per-week  \
166             1     3    0             0             0              40   

     native-country  
166              36  
>50k?
Label: 0
xgb: 0
lg: 0
cat: 0
Sample x1:
       age  workclass  fnlwgt  education-num  marital-status  occupation  \
5609   19          2  119964              9               7          10   

      relationship  race  sex  capital-gain  capital-loss  hours-per-week  \
5609             4     1    1             0             0              15   

      native-country  
5609              41  
>50k?
Label: 0
xgb: 0
lg: 0
cat: 0


### 5. Explain!

In [6]:
exp = explainer.explain_instance(
    data_row=x1.values,
    classifier_fn=xgb.predict,
    threshold=0.95
)

#exp.show_in_notebook()
print(exp.names())
print('Anchor explanation x1: %s' % (' AND '.join(exp.names())))
print('Precision: %.2f' % exp.precision())
print('Coverage: %.2f' % exp.coverage())


exp = explainer.explain_instance(
    data_row=x2.values,
    classifier_fn=xgb.predict,
    threshold=0.95
)

#exp.show_in_notebook()
print('Anchor explanation x2: %s' % (' AND '.join(exp.names())))
print('Precision: %.2f' % exp.precision())
print('Coverage: %.2f' % exp.coverage())

['age <= 37.00', 'race > 1.00', 'hours-per-week <= 40.00', 'capital-gain <= 0.00', 'capital-loss <= 0.00', 'occupation <= 10.00', 'fnlwgt <= 179171.00']
Anchor explanation x1: age <= 37.00 AND race > 1.00 AND hours-per-week <= 40.00 AND capital-gain <= 0.00 AND capital-loss <= 0.00 AND occupation <= 10.00 AND fnlwgt <= 179171.00
Precision: 0.96
Coverage: 0.02
Anchor explanation x2: age <= 28.00 AND marital-status > 2.00
Precision: 0.98
Coverage: 0.20


### 6a. Calculate Fidelity

In [None]:
import re
from tqdm import tqdm

def conv_rule_to_pandas_query(rule):
    """
    Converts Anchor rules (list of strings) into a valid pandas .query() expression.
    Supports compound range rules like '3.00 < occupation <= 5.00'.
    """
    query_parts = []
    
    for cond in rule:
        # Match compound ranges like '3.00 < occupation <= 5.00'
        match = re.match(r"([\d\.]+)\s*<\s*(.+)\s*<=\s*([\d\.]+)", cond)
        if match:
            low, col, high = match.groups()
            col = col.strip()
            expr = f"({float(low)} < `{col}` <= {float(high)})"
            query_parts.append(expr)
            continue
        
        # Match simple binary comparisons
        match = re.match(r"(.+?)\s*(<=|>=|<|>|==)\s*([\d\.]+)", cond)
        if match:
            col, op, val = match.groups()
            col = col.strip()
            expr = f"(`{col}` {op} {float(val)})"
            query_parts.append(expr)
            continue

        raise ValueError(f"Unsupported condition format: {cond}")
    
    return " & ".join(query_parts)

def calculate_fidelity(model, explainer, dataset, tresh=0.95, num_samples=100):
    fidelity_scores = []
    
    pbar = tqdm(range(num_samples), desc="Calculating fidelity")
    for i in pbar:
        # Get the instance and its true label
        x = dataset.iloc[i]
        y_true = model.predict(x.values.reshape(1, -1)).item()  # model prediction for the instance

        # Explain the instance with Anchor
        exp = explainer.explain_instance(x.values, classifier_fn=model.predict, threshold=tresh)
        rule = exp.names()  # e.g., ["age > 30", "capital-gain <= 0"]

        if not rule:
            pbar.set_postfix({"skipped": i})
            continue  # Skip if no rule found

        # Convert the Anchor rule to a pandas query expression
        query_str = conv_rule_to_pandas_query(rule)
        
        # Sample instances that satisfy the rule condition
        covered = dataset.query(query_str)
        
        # Check how many of the covered instances match the model's original prediction
        pred_match = model.predict(covered.values) == y_true
        fidelity = pred_match.mean() if len(covered) > 0 else 0
        
        fidelity_scores.append(fidelity)
        pbar.set_postfix({"current_fid": f"{fidelity:.2f}"})

    avg_fidelity = np.mean(fidelity_scores)
    print("Final fidelity:", avg_fidelity)
    return avg_fidelity

thresh = 0.4 # default: 0.95
fidelity_lg = calculate_fidelity(lg, explainer, test_set, tresh=thresh)
fidelity_xgb = calculate_fidelity(xgb, explainer, test_set, tresh=thresh)
fidelity_cat = calculate_fidelity(cat, explainer, test_set, tresh=thresh)

Calculating fidelity: 100%|██████████| 100/100 [00:24<00:00,  4.03it/s, skipped=99]     


Final fidelity: 0.8147183402477902


Calculating fidelity: 100%|██████████| 100/100 [00:06<00:00, 14.70it/s, skipped=99]     


Final fidelity: 0.7841283062729928


Calculating fidelity: 100%|██████████| 100/100 [00:19<00:00,  5.24it/s, current_fid=0.98]

Final fidelity: 0.854246580338278





### 6b. Calculate Faithfulness

### 7. Calculate Stability

In [27]:
from tqdm import tqdm
from copy import deepcopy
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import MinMaxScaler

def jaccard_distance(exp_x: set, exp_x_neigh: set, debug):
    """Distance between two explanations."""
    union = exp_x | exp_x_neigh
    if not union:
        return 0.0  # If both are empty, treat as identical
    jacc = 1 - len(exp_x & exp_x_neigh) / len(union)
    if debug:
        print("\tJacc:\t", jacc)
    return jacc

def hamming_distance(exp_x: set, exp_x_neigh: set, debug=False):
    """Distance between two explanations using Hamming distance."""
    # Calculate the symmetric difference between the two sets
    diff = exp_x ^ exp_x_neigh  # ^ is symmetric difference
    # Normalize the distance by the maximum possible number of differences (size of the union)
    max_diff = max(len(exp_x), len(exp_x_neigh))
    if max_diff == 0:
        return 0.0  # If both are empty, treat as identical
    
    hamming = len(diff) / max_diff
    
    if debug:
        print("\tHamming:\t", hamming)
    
    return hamming

def eucl_distance(x1: np.ndarray, x2: np.ndarray, debug):
    """Euclidean distance between inputs."""
    eucl = np.linalg.norm(x1 - x2)
    if debug:
        print("\tEucl:\t", eucl)
    return eucl

def norm_eucl_distance(x1: np.ndarray, x2: np.ndarray, scaler: MinMaxScaler, debug):
    """Euclidean distance between inputs, first normalizing features."""
    
    # Normalize the input vectors (x1 and x2)
    x1_normalized = scaler.transform([x1])
    x2_normalized = scaler.transform([x2])

    # Calculate the Euclidean distance between the normalized vectors
    eucl_dist = np.linalg.norm(x1_normalized - x2_normalized)

    if debug:
        print("\tEuc_norm:\t", eucl_dist)
        
    return eucl_dist

def calc_lipschitz(exp_x, exp_xp, x, xp, debug, sim, scaler=None):
    if sim == "jaccard":
        exp_dist = jaccard_distance(exp_x, exp_xp, debug)
    else:
        exp_dist = hamming_distance(exp_x, exp_xp, debug)

    if scaler is not None:
        input_dist = norm_eucl_distance(x, xp, scaler, debug)
    else:
        input_dist = eucl_distance(x, xp, debug)

    lip = exp_dist/input_dist
    if debug:
        print("\tLip:\t", lip)
    return lip

def generate_perturbation_neighborhood(x: pd.Series, dataset: pd.DataFrame, num_samples=10, max_perturbation=2):
    """Generates a synthetic neighborhood around x by applying small perturbations."""
    neighborhood = []

    for _ in range(num_samples):
        x_prime = deepcopy(x)

        for col in dataset.columns:
            # Apply a small perturbation, limiting to max_perturbation in magnitude
            perturbation = np.random.randint(-max_perturbation, max_perturbation + 1)
            new_value = x[col] + perturbation

            # Ensure the perturbed value is non-negative (not below 0)
            x_prime[col] = max(0, new_value)

        neighborhood.append(x_prime.values)

    return np.array(neighborhood)

def sample_neighborhood(x: np.ndarray, nn: NearestNeighbors, dataset: pd.DataFrame, k=10):
    _, indices = nn.kneighbors(x, n_neighbors=k)
    neigh_idcs = indices[0][1:]
    neigh_vals = dataset.iloc[neigh_idcs].values
    return neigh_vals


def calculate_stability(model, explainer, dataset, k_neighbors=10, thresh=0.95, num_samples=1, debug=False, neigh="gen", sim="jaccard", norm=False):
    assert neigh in ["gen","sampled"]
    assert sim in ["jaccard","hamming"]

    stability_scores = []
    if neigh == "sampled":
        nn = NearestNeighbors(n_neighbors=k_neighbors + 1).fit(dataset.values)

    if norm:
        scaler = MinMaxScaler()
        scaler.fit_transform(dataset.values)

    for i in tqdm(range(num_samples), desc="Calculating stability..."):
        x = dataset.iloc[i]
        x_val = x.values.reshape(1, -1)
        if debug:
            print("Row:\t", list(x.values))

        try:
            anchor_x = explainer.explain_instance(x_val[0], classifier_fn=model.predict, threshold=thresh)
            rule_x = set(anchor_x.names())
            if not rule_x:
                if debug:
                    print(f"No rule on index {i}. Skipping instance...")
                continue  # Skip if no rule was found
            if debug:
                print("Exp:\t", anchor_x.names())
        except:
            if debug:
                print(f"Anchor failed on index {i}. Skipping instance...")
            continue  # Skip instances where anchor fails

        # Find neighbors (excluding self)
        if neigh == "gen":
            neigh_vals = generate_perturbation_neighborhood(x, dataset, num_samples=k_neighbors)
        else:
            neigh_vals = sample_neighborhood(x_val, nn, dataset)

        lipschitz_vals = []
        for j, x_prime in enumerate(neigh_vals):
            if debug:
                print(f"\n\tNN{j}:\t", list(x_prime))
            #try:
            anchor_xp = explainer.explain_instance(x_prime.reshape(1, -1), classifier_fn=model.predict, threshold=thresh)
            if debug:
                print("\tExp:\t", anchor_xp.names())
            
            rule_xp = set(anchor_xp.names())

            lip = calc_lipschitz(
                exp_x=rule_x,
                exp_xp=rule_xp,
                x=x.values,
                xp=x_prime,
                debug=debug,
                sim=sim,
                scaler=scaler if norm else None
            )
            lipschitz_vals.append(lip)
            #except:
                #print(f"Error occured in neighbor explanation. Skipping instance.")
                #continue  # If explanation fails, skip

        if lipschitz_vals:
            max_lip = np.max(lipschitz_vals)
            if debug:
                print("\nMax Lipstein:\t", max_lip)
            stability_scores.append(max_lip)

    return np.mean(stability_scores) if stability_scores else 0.0

num_samples = 5
debug = False
neigh = "sampled" # "gen" | "sampled"
sim = "hamming" # "hamming" | "jaccard"
norm = True

stabilities = []
for model in [lg, xgb, cat]:
    stability = calculate_stability(
        xgb, explainer, test_set, 
        num_samples=num_samples,
        debug=debug,
        neigh=neigh,
        sim=sim,
        norm=norm
    )
    stabilities.append(stability)

print("LG Stability:", stabilities[0])
print("XGB Stability:", stabilities[1])
print("CAT Stability:", stabilities[2])

Calculating stability...: 100%|██████████| 5/5 [00:33<00:00,  6.74s/it]
Calculating stability...: 100%|██████████| 5/5 [00:32<00:00,  6.42s/it]
Calculating stability...: 100%|██████████| 5/5 [00:34<00:00,  6.92s/it]

LG Stability: 2.4909891679452683
XGB Stability: 2.618334095670109
CAT Stability: 2.7905672606973715



