In [None]:
class EnvelopeLayer:
    """Batched cvxpy layer returning a quadratic optimal value per sample.

    A cvxpy problem is built once from user-supplied callables. ``forward``
    re-solves it for every sample of a batch, pulls the primal solution
    ``z*`` back into torch (detached), and re-evaluates

        z*^T (A A^T) z* - mu^T z*,   with A = params[0], mu = params[1],

    in torch. Because ``z*`` is held fixed, gradients reach the parameters
    only through this closed-form expression (envelope-theorem style).
    NOTE(review): that equals the gradient of the true optimal value only if
    the constraints do not depend on the parameters and the expression
    matches ``objective`` (presumably a mean-variance portfolio objective)
    — confirm for the problems this layer is used with.

    Args:
        variables: cvxpy Variables. ``variables[0]`` is used as ``z``.
        parameters: cvxpy Parameters, filled positionally from the inputs
            to ``forward``.
        objective: callable ``(*variables, *parameters)`` returning the
            cvxpy expression to minimise.
        inequalities: callables ``g``, each imposed as ``g(...) <= 0``.
        equalities: callables ``h``, each imposed as ``h(...) == 0``.
        **cvxpy_opts: keyword arguments forwarded to ``Problem.solve``.
            They override the default ``max_iter``.
    """

    # NOTE(review): `max_iter` is a solver-specific option name; confirm it
    # is accepted by the solver actually in use.
    DEFAULT_MAX_ITER = 100000

    def __init__(self, variables, parameters, objective, inequalities, equalities, **cvxpy_opts):
        self.variables = variables
        self.parameters = parameters
        self.objective = objective
        self.inequalities = inequalities
        self.equalities = equalities
        self.cvxpy_opts = cvxpy_opts

        self.cp_inequalities = [ineq(*variables, *parameters) <= 0 for ineq in inequalities]
        self.cp_equalities = [eq(*variables, *parameters) == 0 for eq in equalities]
        self.problem = cp.Problem(
            cp.Minimize(objective(*variables, *parameters)),
            self.cp_inequalities + self.cp_equalities,
        )

    @staticmethod
    def _conditioning_report(matrix):
        """Describe ``matrix`` for error messages (eigenvalues when square)."""
        if matrix.ndim == 2 and matrix.shape[0] == matrix.shape[1]:
            return f"Eigenvalues: {torch.linalg.eigvals(matrix.detach())}"
        return f"First parameter shape: {tuple(matrix.shape)}"

    def _solve(self, params):
        """Load one sample's parameters into cvxpy, solve, and return the
        primal solution as detached tensors matching ``params[0]``.

        Raises:
            RuntimeError: if the solver fails or leaves a variable unsolved.
        """
        # User-supplied options win over the default instead of clashing
        # with it as a duplicate keyword argument.
        solve_opts = {"max_iter": self.DEFAULT_MAX_ITER, **self.cvxpy_opts}
        for i, cp_param in enumerate(self.parameters):
            # detach(): .numpy() rejects tensors that require grad, and
            # .cpu()/.double() are no-ops for CPU float64 input.
            cp_param.value = params[i].detach().cpu().double().numpy()
        try:
            self.problem.solve(**solve_opts)
        except cp.error.SolverError as err:
            raise RuntimeError(
                "Ill conditioned case. " + self._conditioning_report(params[0])
            ) from err
        # A non-raising solve can still end infeasible/unbounded with no
        # values; fail loudly rather than reuse a previous sample's solution.
        if any(v.value is None for v in self.variables):
            raise RuntimeError(
                f"cvxpy returned no solution (status: {self.problem.status}). "
                + self._conditioning_report(params[0])
            )
        return [torch.tensor(v.value).type_as(params[0]) for v in self.variables]

    def forward(self, *batch_params):
        """Solve the problem per sample and return the optimal values.

        Args:
            *batch_params: tensors, one per cvxpy Parameter, sharing a
                leading batch dimension. ``batch_params[0]`` supplies ``A``
                (``Sigma = A @ A.T``) and ``batch_params[1]`` supplies ``mu``.

        Returns:
            One value per sample, stacked along dim 0. This is an empty
            tensor for an empty batch.

        Raises:
            RuntimeError: if cvxpy fails or returns no solution for a sample.
        """
        batch_size = batch_params[0].shape[0]
        if batch_size == 0:
            # torch.stack([]) would raise.
            return batch_params[0].new_empty(0)

        values = []
        for sample in range(batch_size):
            params = [p[sample] for p in batch_params]
            z_star = self._solve(params)[0]

            sigma_hat = torch.matmul(params[0], params[0].transpose(0, 1))
            mu_hat = params[1]
            quad_form = torch.matmul(
                z_star.unsqueeze(0), torch.matmul(sigma_hat, z_star.unsqueeze(1))
            ).squeeze()
            values.append(quad_form - torch.sum(mu_hat * z_star))
        return torch.stack(values, dim=0)