### Setup

In [64]:
from typing import Any

import cooper
import torch
import math
from cooper import CMPState

In [2]:
# Mean-squared-error criterion used by VolumeConstrainedIntervalMinimizer.compute_cmp_state;
# "mean" is MSELoss's default reduction, spelled out here for clarity.
loss_fn = torch.nn.MSELoss(reduction="mean")

In [65]:
# Template for a single lower-bound constraint on a user-supplied "volume" function.
class VolumeConstrainedIntervalMinimizer(cooper.ConstrainedMinimizationProblem):
    """Minimize an MSE loss subject to ``volume_function(model.weight) >= volume_threshold``.

    The lower bound is registered as one cooper inequality constraint whose
    violation is ``volume_threshold - volume``.
    """

    def __init__(self, volume_threshold: float, volume_function):
        """
        Args:
            volume_threshold: Lower bound the volume is constrained to reach.
            volume_function: Callable applied to ``model.weight`` that returns the
                volume (presumably a scalar tensor -- confirm against callers).
        """
        super().__init__()
        self.volume_threshold = volume_threshold
        self.volume_function = volume_function
        # One multiplier for the single volume constraint.
        multiplier = cooper.multipliers.DenseMultiplier(num_constraints=1)
        self.volume_constraint = cooper.Constraint(
            multiplier=multiplier,
            constraint_type=cooper.ConstraintType.INEQUALITY,
            formulation_type=cooper.formulations.Lagrangian,
        )

    def compute_cmp_state(self, model, inputs, targets) -> cooper.CMPState:
        """Evaluate the loss and the volume constraint for one batch.

        Args:
            model: Module called as ``model(*inputs)``; must expose ``.weight``.
            inputs: Sequence of positional arguments forwarded to ``model``.
            targets: Values compared with the model output by ``loss_fn``.

        Returns:
            A ``cooper.CMPState`` with the loss and the observed volume constraint.
        """
        logits = model(*inputs)
        loss = loss_fn(logits, targets)
        volume = self.volume_function(model.weight)
        volume_constraint_state = cooper.ConstraintState(
            violation=self.volume_threshold - volume
        )
        observed_constraints = {self.volume_constraint: volume_constraint_state}
        return cooper.CMPState(loss=loss, observed_constraints=observed_constraints)




### A dummy example

##### Minimizing x+y subject to x² + y² = 1

In [31]:
class SphereSurfaceConstrainedMinimizer(cooper.ConstrainedMinimizationProblem):
    """Minimize ``sum(model.weight)`` subject to ``||model.weight||^2 == radius^2``.

    With ``radius=1`` and a two-element weight this is the toy problem
    "minimize x + y subject to x^2 + y^2 = 1".
    """

    def __init__(self, radius):
        """
        Args:
            radius: Radius of the sphere the weight is constrained to lie on.
        """
        super().__init__()
        self.radius = radius
        # One multiplier for the single equality constraint.
        multiplier = cooper.multipliers.DenseMultiplier(num_constraints=1)
        self.radius_constraint = cooper.Constraint(
            multiplier=multiplier,
            constraint_type=cooper.ConstraintType.EQUALITY,
            formulation_type=cooper.formulations.Lagrangian,
        )

    def compute_cmp_state(self, model) -> cooper.CMPState:
        """Return the objective ``sum(w)`` and the sphere-surface constraint for ``w = model.weight``."""
        loss = model.weight.sum()
        squared_norm = model.weight.pow(2).sum()
        # The weight lies on the sphere of radius r exactly when ||w||^2 == r^2.
        # Comparing ||w||^2 to r itself only agrees with this when r == 1.
        radius_constraint_state = cooper.ConstraintState(
            violation=self.radius**2 - squared_norm
        )
        observed_constraints = {self.radius_constraint: radius_constraint_state}
        return cooper.CMPState(loss=loss, observed_constraints=observed_constraints)


In [32]:
class DummyModel(torch.nn.Module):
    """Parameter-only module whose forward pass returns its weight.

    The weight is a learnable 2-element tensor initialised to ``[1.0, 1.0]``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        initial_values = torch.ones(2)
        self.weight = torch.nn.Parameter(initial_values)

    def forward(self):
        # No inputs: the "prediction" is the parameter vector itself.
        out = self.weight
        return out


In [33]:
# Toy problem: minimise x + y on the unit circle x^2 + y^2 = 1.
cmp = SphereSurfaceConstrainedMinimizer(radius=1.0)
# Bias-free 1 -> 2 linear layer; the CMP only reads its (2, 1) weight.
# NOTE(review): DummyModel above is not used here, so the starting point is
# Linear's random initialisation rather than [1, 1].
model = torch.nn.Linear(in_features=1, out_features=2, bias=False)

# Plain SGD for both players; the dual optimizer maximises over the multipliers.
primal_optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)
dual_optimizer = torch.optim.SGD(params=cmp.dual_parameters(), lr=0.01, maximize=True)

cooper_optimizer = cooper.optim.AlternatingDualPrimalOptimizer(
    cmp=cmp,
    primal_optimizers=primal_optimizer,
    dual_optimizers=dual_optimizer,
)

In [34]:
# 1000 optimisation steps; each roll() evaluates cmp.compute_cmp_state(model=model)
# and updates the primal/dual parameters via cooper_optimizer.
for epoch_num in range(1000):
    compute_cmp_state_kwargs = dict(model=model)
    roll_out = cooper_optimizer.roll(
        compute_cmp_state_kwargs=compute_cmp_state_kwargs,
    )

In [35]:
# Inspect the trained weights. The minimiser of x + y on the unit circle is
# x = y = -1/sqrt(2) (about -0.7071), so both entries should approach that value.
model.weight

Parameter containing:
tensor([[-0.7056],
        [-0.7056]], requires_grad=True)