Skip to content

Commit

Permalink
SDP (GNB and QDA) supports affine preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
andreArtelt committed Jan 15, 2021
1 parent 1eaaafa commit b459d93
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
9 changes: 7 additions & 2 deletions ceml/sklearn/naivebayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ def _build_constraints(self, var_X, var_x, y):
b = (self.mymodel.means[j, :] / self.mymodel.variances[j, :]) - (self.mymodel.means[i, :] / self.mymodel.variances[i, :])
c = np.log(self.mymodel.class_priors[j] / self.mymodel.class_priors[i]) + np.sum([np.log(1. / np.sqrt(2.*np.pi*self.mymodel.variances[j,k])) - ((self.mymodel.means[j,k]**2) / (2.*self.mymodel.variances[j,k])) for k in range(self.mymodel.dim)]) - np.sum([np.log(1. / np.sqrt(2.*np.pi*self.mymodel.variances[i,k])) - ((self.mymodel.means[i,k]**2) / (2.*self.mymodel.variances[i,k])) for k in range(self.mymodel.dim)])

if self.is_affine_preprocessing_set(): # If necessary, apply affine preprocessing
c = c + self.b.T @ b + self.b.T @ A @ self.b
b = self.A.T @ b + (self.b.T @ A @ self.A).T + self.A.T @ A @ self.b
A = self.A.T @ A @ self.A

return [cp.trace(A @ var_X) + b @ var_x + c + self.epsilon <= 0]

def _build_solve_dcqp(self, x_orig, y_target, regularization, features_whitelist, optimizer_args):
Expand Down Expand Up @@ -160,8 +165,8 @@ def _build_solve_dcqp(self, x_orig, y_target, regularization, features_whitelist

def solve(self, x_orig, y_target, regularization, features_whitelist, return_as_dict, optimizer_args):
xcf = None
if self.mymodel.is_binary and not self.is_affine_preprocessing_set() and regularization != "l1":
xcf = self.build_solve_opt(x_orig, y_target, optimizer_args)
if self.mymodel.is_binary and regularization != "l1":
xcf = self.build_solve_opt(x_orig, y_target, features_whitelist, optimizer_args)
else:
xcf = self._build_solve_dcqp(x_orig, y_target, regularization, features_whitelist, optimizer_args)
delta = x_orig - xcf
Expand Down
7 changes: 6 additions & 1 deletion ceml/sklearn/qda.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ def _build_constraints(self, var_X, var_x, y):
b = np.dot(self.mymodel.sigma_inv[j], self.mymodel.means[j]) - np.dot(self.mymodel.sigma_inv[i], self.mymodel.means[i])
c = np.log(self.mymodel.class_priors[j] / self.mymodel.class_priors[i]) + 0.5 * np.log(np.linalg.det(self.mymodel.sigma_inv[j]) / np.linalg.det(self.mymodel.sigma_inv[i])) + 0.5 * (self.mymodel.means[i].T.dot(self.mymodel.sigma_inv[i]).dot(self.mymodel.means[i]) - self.mymodel.means[j].T.dot(self.mymodel.sigma_inv[j]).dot(self.mymodel.means[j]))

if self.is_affine_preprocessing_set(): # If necessary, apply affine preprocessing
c = c + self.b.T @ b + self.b.T @ A @ self.b
b = self.A.T @ b + (self.b.T @ A @ self.A).T + self.A.T @ A @ self.b
A = self.A.T @ A @ self.A

return [cp.trace(A @ var_X) + var_x.T @ b + c + self.epsilon <= 0]

def _build_solve_dcqp(self, x_orig, y_target, regularization, features_whitelist, optimizer_args):
Expand Down Expand Up @@ -165,7 +170,7 @@ def _build_solve_dcqp(self, x_orig, y_target, regularization, features_whitelist

def solve(self, x_orig, y_target, regularization, features_whitelist, return_as_dict, optimizer_args):
xcf = None
if self.mymodel.is_binary and not self.is_affine_preprocessing_set() and regularization != "l1":
if self.mymodel.is_binary and regularization != "l1":
xcf = self.build_solve_opt(x_orig, y_target, features_whitelist, optimizer_args)
else:
xcf = self._build_solve_dcqp(x_orig, y_target, regularization, features_whitelist, optimizer_args)
Expand Down

0 comments on commit b459d93

Please sign in to comment.