# In [None]:
from cf_implemented_algorithms import *


# The single row to explain. Slicing 11:12 (instead of indexing X[11]) keeps
# the leading batch dimension, so the model is fed a batch of size one.
query_instance = X[11:12].to(device)

# Column positions of every categorical feature in the feature matrix.
cat_features_idx = list(map(feature_names.index, categorical_features))

# Intended counterfactual target; the wachter calls in this cell use 1.0 too.
target_logit = 1.0

# Host-side NumPy copy of the training data (used for the per-feature MADs).
X_train_np = X_train.to("cpu").numpy()

from scipy.stats import median_abs_deviation

# Per-feature median absolute deviation of the training data. scale='normal'
# rescales it to be consistent with the standard deviation under normality.
# One column-wise call (axis=0) replaces the former per-column Python loop and
# yields the same values.
mads = median_abs_deviation(X_train_np, axis=0, scale='normal')
# A constant column has MAD 0; substitute 1.0 so that feature never gets a zero
# scale (presumably used as a divisor inside wachter -- TODO confirm).
# Kept as a plain list, matching the original type of `mads`.
mads = np.where(mads == 0, 1.0, mads).tolist()

# float32 tensor on the model's device, passed to wachter as `mad`.
mad = torch.from_numpy(np.array(mads)).float().to(device)

# Generate one Wachter counterfactual for `query_instance`.
# NOTE(review): whether `wachter` freezes the categorical features and whether
# it reads `positive_data` cannot be seen from this notebook (the batch call
# passes positive_data=X_positive) -- confirm in cf_implemented_algorithms.
counterfactual = wachter(
    model=model,
    data=query_instance,
    target=target_logit,  # shared target so single and batch runs agree
    cat_features=cat_features_idx,
    mad=mad,
    positive_data=None,
    lmbda=1e-3,
)

# Minimum absolute change for a feature to be listed as "changed" below.
CHANGE_ATOL = 0.1

with torch.no_grad():
    # Predicted classes: round(sigmoid(output)) maps each output to 0/1.
    # .item() requires a single-element output, i.e. one logit for this row.
    original_output = torch.round(torch.sigmoid(model(query_instance))).item()
    cf_output = torch.round(torch.sigmoid(model(counterfactual))).item()

    # Unweighted L1 distance (sum of absolute differences over all features);
    # equivalent to the deprecated torch.norm(..., p=1).
    l1_dist = (query_instance - counterfactual).abs().sum().item()

    print("Original Output: ",original_output)
    print("CF Output : ",cf_output)
    print("L1 distance:", l1_dist)

# Compare changed features (visualization). detach() keeps .numpy() safe even
# if the returned counterfactual still carries autograd history.
original = query_instance.detach().cpu().numpy().flatten()
cf = counterfactual.detach().cpu().numpy().flatten()

changed_indices = np.where(~np.isclose(original, cf, atol=CHANGE_ATOL))[0]
print("\nChanged features:")
for idx in changed_indices:
    print(f"{feature_names[idx]}: {original[idx]:.2f} → {cf[idx]:.2f}")



# Batch run: one Wachter counterfactual per row of X_negative, using the same
# target as the single-instance run.
counterfactuals = wachter(
    model=model,
    data=X_negative,
    target=target_logit,
    cat_features=cat_features_idx,
    mad=mad,
    positive_data=X_positive,
    lmbda=1e-3,
)

# Per-row unweighted L1 distance between each input and its counterfactual
# (equivalent to the deprecated torch.norm(..., p=1, dim=1)). detach() is
# needed because this runs outside no_grad and .numpy() rejects tensors that
# require grad.
l1_distances = (X_negative - counterfactuals).abs().sum(dim=1).detach().cpu().numpy()
print(f"Mean L1 distance for negatively classified data using Wachter: {np.mean(l1_distances):.2f}")