In [None]:
#default_exp topo_solvers

In [None]:
#export
import time
import torch
from collections import defaultdict

from dl4to.solution import Solution
from dl4to.criteria import VolumeFraction, Binariness

In [None]:
#hide
from nbdev.showdoc import show_doc

# SIMP iterator

In [None]:
#export
class SIMPIterator:
    """
    Performs the actual SIMP optimization.

    Each call runs one iteration:
    1. build a solution from the current latent density;
    2. solve the PDE;
    3. evaluate the objective;
    4. take one optimizer step on the density representer's parameters;
    5. steepen the representer's binarizer.

    Per-iteration statistics are accumulated in `self.logs`. That dict is attached to every returned solution.
    """
    def __init__(
        self,
        problem:"dl4to.problem.Problem", # The problem that should be solved by SIMP.
        criterion:"dl4to.criteria.Criterion", # The objective function that should be optimized for in the optimization process.
        density_representer:"dl4to.density_representers.DensityRepresenter", # The density representer that is used for the latent density representation. The density representer also performs the projection, smoothing and filtering.
        lr:float, # The learning rate of the `torch.optim.Adam` optimizer.
        binarizer_steepening_factor:float # The factor at which the binarizer should be steepened in each iteration. E.g., a value of 1.1 corresponds to a steepening of 10% per iteration.
    ):
        self.lr = lr
        # One shared log dict for all iterations; every returned solution references this same object.
        self.logs = defaultdict(list)
        self.binarizer_steepening_factor = binarizer_steepening_factor
        self.problem = problem

        self.criterion = criterion
        self.volume_crit = VolumeFraction()
        self.binariness_crit = Binariness()
        self.density_representer = density_representer
        self.optimizer = torch.optim.Adam(self.density_representer.parameters(), lr=self.lr)


    def _extend_logs(self, solution, loss, volume, tick, σ_vm):
        """
        Append this iteration's statistics to `self.logs`.

        The statistics are loss, volume fraction, wall-clock duration since `tick`, binariness of
        `solution`, and the maximum von Mises stress relative to `self.problem.σ_ys`.
        """
        self.logs["losses"].append(loss.item())
        self.logs["volumes"].append(volume.item())
        self.logs["durations"].append(time.time() - tick)
        # `solution.θ` carries an autograd graph from the density representer. Evaluating the logged
        # binariness without autograd keeps the log from holding on to one graph per iteration.
        with torch.no_grad():
            self.logs["binarinesses"].append(self.binariness_crit([solution]))
        self.logs["relative_max_σ_vm"].append(σ_vm.max().item() / self.problem.σ_ys)


    def _perform_optimizer_step(self, loss):
        """
        Backpropagate `loss`, apply one optimizer step, and return the gradient w.r.t. θ.

        The returned gradient is that of `loss` with respect to `self.density_representer.θ`,
        detached from the graph.

        Raises `RuntimeError` if `loss` does not depend on θ.
        """
        θ = self.density_representer.θ
        self.optimizer.zero_grad()
        if θ.is_leaf and θ.requires_grad:
            # For a leaf θ, the single backward pass below already yields ∂loss/∂θ in `θ.grad`.
            # This avoids traversing the (PDE-dependent, expensive) graph a second time.
            # θ may not be managed by the optimizer, so clear any stale gradient explicitly.
            θ.grad = None
            loss.backward()
            if θ.grad is None:
                raise RuntimeError("The loss does not depend on the density representer's θ.")
            # Clone: a later in-place `zero_grad(set_to_none=False)` would otherwise overwrite the logged value.
            grads = θ.grad.detach().clone()
        else:
            # A non-leaf θ does not get `.grad` populated by `backward`, so ask autograd explicitly.
            loss.backward(retain_graph=True)
            grads = torch.autograd.grad(loss, θ)[0].detach()
        self.optimizer.step()
        return grads


    def __call__(self, 
                 p:float # The SIMP exponent
                ):
        """
        Creates the SIMP solution objects, solves the PDE and communicates with the density representer.
        Returns a `dl4to.solution.Solution` object.

        Note: the logged loss, volume and stress belong to the density *before* the optimizer step.
        The returned solution's `θ` is re-evaluated *after* the step.
        """
        tick = time.time()
        solution = Solution(
            problem=self.problem,
            θ=self.density_representer(),
        )

        _, _, σ_vm = solution.solve_pde(p=p)
        loss = self.criterion([solution])
        volume = self.volume_crit([solution])

        grads = self._perform_optimizer_step(loss)
        solution.θ = self.density_representer()
        self.density_representer.steepen_binarizer(self.binarizer_steepening_factor)

        self._extend_logs(solution, loss, volume, tick, σ_vm)
        solution.logs = self.logs

        solution.logs['grads'].append(grads)
        return solution