Skip to content

Commit

Permalink
PyTorch module: Critical bugfix cost functions
Browse files Browse the repository at this point in the history
  • Loading branch information
andreArtelt committed Aug 4, 2020
1 parent b8e3947 commit aaa0c83
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 13 deletions.
14 changes: 6 additions & 8 deletions ceml/backend/torch/costfunctions/costfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,34 +110,32 @@ class NegLogLikelihoodCost(CostFunctionDifferentiableTorch):
"""
Negative-log-likelihood cost function.
"""
def __init__(self, input_to_output, y_target, **kwds):
def __init__(self, y_target, **kwds):
self.y_target = y_target
self.input_to_output = input_to_output

super().__init__(**kwds)

def score_impl(self, x):
def score_impl(self, y):
"""
Computes the loss - negative-log-likelihood.
"""
return negloglikelihood(self.input_to_output(x), self.y_target)
return negloglikelihood(y, self.y_target)


class SquaredError(CostFunctionDifferentiableTorch):
"""
Squared error cost function.
"""
def __init__(self, input_to_output, y_target, **kwds):
def __init__(self, y_target, **kwds):
self.y_target = y_target
self.input_to_output = input_to_output

super().__init__(**kwds)

def score_impl(self, x):
def score_impl(self, y):
"""
Computes the loss - squared error.
"""
return l2(self.input_to_output(x), self.y_target)
return l2(y, self.y_target)


class RegularizedCost(CostFunctionDifferentiableTorch):
Expand Down
2 changes: 1 addition & 1 deletion ceml/torch/counterfactual.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def compute_counterfactual(self, x, y_target, features_whitelist=None, regulariz
"""
# Hide the input in a wrapper if we can use a subset of features only
input_wrapper, x_orig, _, grad_mask = self.wrap_input(features_whitelist, x, optimizer)

# Check if the prediction of the given input is already consistent with y_target
done = done = done if done is not None else y_target if callable(y_target) else lambda y: y == y_target
self.warn_if_already_done(x, done)
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def predict(self, x, dim=1):
return torch.argmax(self.forward(x), dim=dim)

def get_loss(self, y_target, pred=None):
return NegLogLikelihoodCost(self.predict_proba, y_target)
return NegLogLikelihoodCost(input_to_output=self.predict_proba, y_target=y_target)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/torch/test_torch_linearregression.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def predict(self, x, dim=None): # Note: In contrast to classification, the param
return self.forward(x)

def get_loss(self, y_target, pred=None):
return SquaredError(self.predict, y_target)
return SquaredError(input_to_output=self.predict, y_target=y_target)

# Load data
X, y = load_boston(True)
Expand Down
4 changes: 2 additions & 2 deletions tests/torch/test_torch_softmaxregression.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def predict(self, x, dim=1):
return torch.argmax(self.forward(x), dim=dim)

def get_loss(self, y_target, pred=None):
return NegLogLikelihoodCost(self.predict_proba, y_target)
return NegLogLikelihoodCost(input_to_output=self.predict_proba, y_target=y_target)

# Load data
X, y = load_iris(True)
Expand Down Expand Up @@ -164,7 +164,7 @@ def get_loss(self, y_target, pred=None):
assert all([True if i in features_whitelist else delta[i] == 0. for i in range(x_orig.shape[0])])

optimizer = torch.optim.SGD
x_cf, y_cf, delta = generate_counterfactual(model, x_orig, y_target=0, features_whitelist=features_whitelist, regularization="l2", C=0.001, optimizer=optimizer, optimizer_args=optimizer_args, return_as_dict=False)
x_cf, y_cf, delta = generate_counterfactual(model, x_orig, y_target=0, features_whitelist=features_whitelist, regularization="l2", C=0.0001, optimizer=optimizer, optimizer_args=optimizer_args, return_as_dict=False)
assert y_cf == 0
assert model.predict(torch.from_numpy(np.array([x_cf], dtype=np.float32))).detach().numpy() == 0
assert all([True if i in features_whitelist else delta[i] == 0. for i in range(x_orig.shape[0])])
Expand Down

0 comments on commit aaa0c83

Please sign in to comment.