Skip to content

Commit

Permalink
Bugfix: GMLVQ: Add missing transpose in computing final distance matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
andreArtelt committed Jul 22, 2019
1 parent a79a375 commit 609501e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
6 changes: 3 additions & 3 deletions ceml/sklearn/lvq.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ def __init__(self, model, dist="l2"):

# The model might have learned its own distance metric
if isinstance(model, sklearn_lvq.GmlvqModel) or isinstance(model, sklearn_lvq.MrslvqModel):
self.dist_mat = create_tensor(np.dot(model.omega_, model.omega_))
self.dist_mat = create_tensor(np.dot(model.omega_.T, model.omega_))
elif isinstance(model, sklearn_lvq.LgmlvqModel) or isinstance(model, sklearn_lvq.LmrslvqModel):
if model.classwise == True:
self.omegas = [create_tensor(np.dot(omega, omega)) for omega in model.omegas_]
self.omegas = [create_tensor(np.dot(omega.T, omega)) for omega in model.omegas_]
self.classwise = True
self.dist_mats = None
else:
self.dist_mats = [create_tensor(np.dot(omega, omega)) for omega in model.omegas_]
self.dist_mats = [create_tensor(np.dot(omega.T, omega)) for omega in model.omegas_]
self.classwise = False

super(LVQ, self).__init__()
Expand Down
9 changes: 3 additions & 6 deletions tests/sklearn/test_sklearn_lvq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ def test_glvq():

# Create and fit model
model = GlvqModel(prototypes_per_class=3, max_iter=100, random_state=4242)
#model = GmlvqModel(prototypes_per_class=3, max_iter=200, random_state=4242)
#model = LgmlvqModel(prototypes_per_class=3, max_iter=100, random_state=4242)
model.fit(X_train, y_train)

# Select data point for explaining its prediction
Expand Down Expand Up @@ -73,8 +71,7 @@ def test_gmlvq():
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=4242)

# Create and fit model
model = GmlvqModel(prototypes_per_class=3, max_iter=200, random_state=4242)
#model = LgmlvqModel(prototypes_per_class=3, max_iter=100, random_state=4242)
model = GmlvqModel(prototypes_per_class=3, max_iter=200, random_state=4242, dim=2)
model.fit(X_train, y_train)

# Select data point for explaining its prediction
Expand All @@ -84,7 +81,7 @@ def test_gmlvq():
# Compute counterfactual
features_whitelist = None

x_cf, y_cf, delta = generate_counterfactual(model, x_orig, 0, features_whitelist=features_whitelist, regularization="l1", C=1.0, optimizer="bfgs", return_as_dict=False)
x_cf, y_cf, delta = generate_counterfactual(model, x_orig, 0, features_whitelist=features_whitelist, regularization="l1", C=0.01, optimizer="bfgs", return_as_dict=False)
assert y_cf == 0
assert model.predict(np.array([x_cf])) == 0

Expand All @@ -102,7 +99,7 @@ def test_gmlvq():


features_whitelist = [0, 1, 2, 3]
x_cf, y_cf, delta = generate_counterfactual(model, x_orig, 0, features_whitelist=features_whitelist, regularization="l1", C=1.0, optimizer="bfgs", return_as_dict=False)
x_cf, y_cf, delta = generate_counterfactual(model, x_orig, 0, features_whitelist=features_whitelist, regularization="l1", C=0.01, optimizer="bfgs", return_as_dict=False)
assert y_cf == 0
assert model.predict(np.array([x_cf])) == 0
assert all([True if i in features_whitelist else delta[i] == 0. for i in range(x_orig.shape[0])])
Expand Down

0 comments on commit 609501e

Please sign in to comment.