Skip to content

Commit

Permalink
Print a warning if the prediction of the original input is already consistent with the requested prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
andreArtelt committed Jul 12, 2019
1 parent 55cc681 commit f1b6881
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 6 deletions.
14 changes: 11 additions & 3 deletions ceml/sklearn/counterfactual.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-

from abc import ABC, abstractmethod
import logging

from ..optim import InputWrapper, desc_to_optim
from ..model import Counterfactual
from ..backend.jax.costfunctions import RegularizedCost
Expand Down Expand Up @@ -70,6 +71,10 @@ def build_loss(self, regularization, x_orig, y_target, pred, grad_mask, C, input

return loss, loss_grad

def warn_if_already_done(self, x, done):
    """Emit a warning when `x` already satisfies the requested prediction.

    The model is evaluated on the single sample `x`; if the `done`
    predicate accepts the resulting prediction, a warning is logged
    because searching for a counterfactual is then likely pointless.

    Parameters
    ----------
    x : array-like
        The input sample whose prediction is checked.
    done : callable
        Predicate that returns True when a prediction is consistent
        with the requested target `y_target`.
    """
    y_pred = self.model.predict([x])[0]
    if done(y_pred):
        logging.warning("The prediction of the input 'x' is already consistent with the requested prediction 'y_target' - It might not make sense to search for a counterfactual!")

def __build_result_dict(self, x_cf, y_cf, delta):
    """Pack a computed counterfactual into the standard result dictionary."""
    return dict(x_cf=x_cf, y_cf=y_cf, delta=delta)

Expand Down Expand Up @@ -155,11 +160,14 @@ def compute_counterfactual(self, x, y_target, features_whitelist=None, regulariz
# Hide the input in a wrapper if we can use a subset of features only
input_wrapper, x_orig, pred, grad_mask = self.wrap_input(features_whitelist, x, optimizer)

# Check if the prediction of the given input is already consistent with y_target
done = done if done is not None else y_target if callable(y_target) else lambda y: y == y_target
self.warn_if_already_done(x, done)

# Repeat for all C
if not type(C) == list:
C = [C]
done = done if done is not None else y_target if callable(y_target) else lambda y: y == y_target


for c in C:
# Build loss
loss, loss_grad = self.build_loss(regularization, x_orig, y_target, pred, grad_mask, c, input_wrapper)
Expand Down
5 changes: 5 additions & 0 deletions ceml/sklearn/decisiontree.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ def compute_counterfactual(self, x, y_target, features_whitelist=None, regulariz
(x_cf, y_cf, delta) : triple if `return_as_dict` is False
"""
# Check if the prediction of the given input is already consistent with y_target
done = y_target if callable(y_target) else lambda y: y == y_target
self.warn_if_already_done(x, done)

# Compute all counterfactuals
counterfactuals = self.compute_all_counterfactuals(x, y_target, features_whitelist, regularization)

# Select the one with the smallest score
Expand Down
5 changes: 4 additions & 1 deletion ceml/sklearn/randomforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,13 @@ def compute_counterfactual(self, x, y_target, features_whitelist=None, regulariz
# Try to compute a counterfactual for each of the models and use this counterfactual as a starting point
x_start = self.__compute_initial_values(x, y_target, features_whitelist)

# Check if the prediction of the given input is already consistent with y_target
done = done if done is not None else y_target if callable(y_target) else lambda y: y == y_target
self.warn_if_already_done(x, done)

# Repeat for all C
if not type(C) == list:
C = [C]
done = done if done is not None else y_target if callable(y_target) else lambda y: y == y_target

for x0 in x_start:
input_wrapper, x_orig, pred, grad_mask = self.wrap_input(features_whitelist, x0, optimizer)
Expand Down
11 changes: 10 additions & 1 deletion ceml/tfkeras/counterfactual.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
import logging
import tensorflow as tf
import numpy as np

from ..backend.tensorflow.layer import create_tensor, create_mutable_tensor
from ..backend.tensorflow.costfunctions import RegularizedCost
from ..backend.tensorflow.optimizer import desc_to_optim
Expand Down Expand Up @@ -55,6 +57,10 @@ def loss_grad_npy(x):

return loss, loss_npy, loss_grad_npy

def warn_if_already_done(self, x, done):
    """Emit a warning when `x` already satisfies the requested prediction.

    Wraps `x` in a one-sample numpy batch, runs the model on it and,
    if the `done` predicate accepts the prediction, logs a warning that
    searching for a counterfactual is then likely pointless.

    Parameters
    ----------
    x : array-like
        The input sample whose prediction is checked.
    done : callable
        Predicate that returns True when a prediction is consistent
        with the requested target `y_target`.
    """
    prediction = self.model.predict(np.array([x]))
    if done(prediction):
        logging.warning("The prediction of the input 'x' is already consistent with the requested prediction 'y_target' - It might not make sense to search for a counterfactual!")

def __build_result_dict(self, x_cf, y_cf, delta):
    """Assemble the standard result dictionary for a counterfactual."""
    result = {"x_cf": x_cf, "y_cf": y_cf}
    result["delta"] = delta
    return result

Expand Down Expand Up @@ -145,10 +151,13 @@ def compute_counterfactual(self, x, y_target, features_whitelist=None, regulariz
# Hide the input in a wrapper if we can use a subset of features only
input_wrapper, x_orig, _, grad_mask = self.wrap_input(features_whitelist, x, optimizer)

# Check if the prediction of the given input is already consistent with y_target
done = y_target if callable(y_target) else lambda y: y == y_target
self.warn_if_already_done(x, done)

# Repeat for all C
if not type(C) == list:
C = [C]
done = y_target if callable(y_target) else lambda y: y == y_target

for c in C:
# Build loss
Expand Down
11 changes: 10 additions & 1 deletion ceml/torch/counterfactual.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
import logging
import torch
import numpy as np

from ..backend.torch.layer import create_tensor
from ..backend.torch.costfunctions import RegularizedCost
from ..backend.torch.optimizer import desc_to_optim
Expand Down Expand Up @@ -60,6 +62,10 @@ def loss_grad_npy(x):

return loss, loss_npy, loss_grad_npy

def warn_if_already_done(self, x, done):
    """Emit a warning when `x` already satisfies the requested prediction.

    Converts `x` to a tensor on the model's device, runs the model on it
    and, if the `done` predicate accepts the (numpy-converted) prediction,
    logs a warning that searching for a counterfactual is then likely
    pointless.

    Parameters
    ----------
    x : array-like
        The input sample whose prediction is checked.
    done : callable
        Predicate that returns True when a prediction is consistent
        with the requested target `y_target`.
    """
    x_tensor = create_tensor(x, self.device)
    prediction = self.model.predict(x_tensor, dim=0).numpy()
    if done(prediction):
        logging.warning("The prediction of the input 'x' is already consistent with the requested prediction 'y_target' - It might not make sense to search for a counterfactual!")

def __build_result_dict(self, x_cf, y_cf, delta):
    """Collect the counterfactual, its prediction and the delta in a dict."""
    keys = ("x_cf", "y_cf", "delta")
    return dict(zip(keys, (x_cf, y_cf, delta)))

Expand Down Expand Up @@ -156,10 +162,13 @@ def compute_counterfactual(self, x, y_target, features_whitelist=None, regulariz
# Hide the input in a wrapper if we can use a subset of features only
input_wrapper, x_orig, _, grad_mask = self.wrap_input(features_whitelist, x, optimizer)

# Check if the prediction of the given input is already consistent with y_target
done = y_target if callable(y_target) else lambda y: y == y_target
self.warn_if_already_done(x, done)

# Repeat for all C
if not type(C) == list:
C = [C]
done = y_target if callable(y_target) else lambda y: y == y_target

for c in C:
# Build loss
Expand Down

0 comments on commit f1b6881

Please sign in to comment.