Skip to content

Commit

Permalink
Bugfix: Invalid shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
andreArtelt committed May 8, 2020
1 parent 3b72e36 commit 4976ac2
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions ceml/sklearn/lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ def _build_constraints(self, var_x, y):
constraints = []

i = y
q_i = np.dot(self.mymodel.sigma_inv, self.mymodel.means[i])
b_i = np.log(self.mymodel.class_priors[i]) - .5 * np.dot( self.mymodel.means[i], np.dot(self.mymodel.sigma_inv, self.mymodel.means[i]))
q_i = np.dot(self.mymodel.sigma_inv, self.mymodel.means[i].T)
b_i = np.log(self.mymodel.class_priors[i]) - .5 * np.dot( self.mymodel.means[i], np.dot(self.mymodel.sigma_inv, self.mymodel.means[i].T))

for j in range(len(self.mymodel.means)):
if i == j:
Expand Down
2 changes: 1 addition & 1 deletion ceml/sklearn/qda.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def rebuild_model(self, model):
return Qda(model)

def _build_constraints(self, var_X, var_x, y):
i = y
i = int(y)
j = 0 if y == 1 else 1

A = .5 * ( self.mymodel.sigma_inv[i] - self.mymodel.sigma_inv[j])
Expand Down

0 comments on commit 4976ac2

Please sign in to comment.