Skip to content

Commit

Permalink
Bugfix: Gaussian Naive Bayes: Invalid shape
Browse files Browse the repository at this point in the history
  • Loading branch information
andreArtelt committed May 8, 2020
1 parent 4976ac2 commit 3e8e3cf
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion ceml/sklearn/naivebayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _build_constraints(self, var_X, var_x, y):
b = (self.mymodel.means[j, :] / self.mymodel.variances[j, :]) - (self.mymodel.means[i, :] / self.mymodel.variances[i, :])
c = np.log(self.mymodel.class_priors[j] / self.mymodel.class_priors[i]) + np.sum([np.log(1. / np.sqrt(2.*np.pi*self.mymodel.variances[j,k])) - ((self.mymodel.means[j,k]**2) / (2.*self.mymodel.variances[j,k])) for k in range(self.mymodel.dim)]) - np.sum([np.log(1. / np.sqrt(2.*np.pi*self.mymodel.variances[i,k])) - ((self.mymodel.means[i,k]**2) / (2.*self.mymodel.variances[i,k])) for k in range(self.mymodel.dim)])

return [cp.trace(A @ var_X) + var_x.T @ b + c + self.epsilon <= 0]
return [cp.trace(A @ var_X) + b @ var_x + c + self.epsilon <= 0]

def _build_solve_dcqp(self, x_orig, y_target, regularization, features_whitelist):
Q0 = np.eye(self.mymodel.dim) # TODO: Can be ignored if regularization != l2
Expand Down

0 comments on commit 3e8e3cf

Please sign in to comment.