In [None]:
#default_exp topo_solvers

In [None]:
#exporti
import json
import torch
import warnings
from copy import deepcopy
from collections import defaultdict

from dl4to.utils import get_dataloader
from dl4to.preprocessing import TrivialPreprocessing
from dl4to.topo_solvers import TopoSolver, TrainModule

In [None]:
#hide
from nbdev.showdoc import show_doc

# Trainable topo solver

In [None]:
#export
class TrainableTopoSolver(TopoSolver):
    """
    A topo solver that is trainable and can be used for learned topology optimization.
    """
    def __init__(
        self,
        criterion:"dl4to.criteria.Criterion", # The loss criterion that should be used for the training.
        model:torch.nn.Module, # A PyTorch neural network. Make sure that the input and output dimensions are correct.
        optimizer:torch.optim.Optimizer, # A PyTorch optimizer, for instance torch.optim.Adam. Make sure to set `params=model.parameters()` if you want to use the optimizer to train the neural network.
        preprocessing:"dl4to.preprocessing.Preprocessing"=TrivialPreprocessing(), # The preprocessing that should be used in the pipeline.
        name:str=None # The name of the topo solver.
    ):
        # The solver lives on the device of the model's first parameter. A model
        # without parameters has no device of its own, so fall back to the CPU
        # instead of failing with a bare `StopIteration`.
        first_parameter = next(model.parameters(), None)
        device = torch.device('cpu') if first_parameter is None else first_parameter.device
        super().__init__(
            device=device,
            name=name,
            trainable=True,
            differentiable=True
        )

        self.criterion = criterion
        self.model = model
        # The model starts out in evaluation mode.
        self.model.eval()
        self.optimizer = optimizer
        # NOTE(review): the default `TrivialPreprocessing()` is created once, when the
        # class is defined, and is therefore shared by all solvers that rely on it.
        self.preprocessing = preprocessing

        self.logs = defaultdict(list)
        self._train_module = TrainModule(self)


    @property
    def device(self):
        """The device that is currently stored for this solver."""
        return self._device


    @device.setter
    def device(self, device):
        # Changing the solver's device also moves the model to that device.
        self._device = device
        self.model.to(device)


    def _get_copy_without_cluttering_entries(self, my_dict):
        """
        Returns a copy of `my_dict` without the model, the optimizer and the train module.
        Values that have a `name` attribute are replaced by that name.
        """
        # The train module is stored under `_train_module`, so the private key has to be
        # listed as well; otherwise it is never filtered out of `self.__dict__`.
        cluttering_keys = {"model", "optimizer", "train_module", "_train_module"}
        cleaned_dict = {}
        for key, value in my_dict.items():
            if key in cluttering_keys:
                continue
            cleaned_dict[key] = value.name if hasattr(value, 'name') else value
        return cleaned_dict


    def get_args_as_dict(self):
        """
        Returns basic properties and arguments of the topo solver as a dictionary.
        """
        internal_dict_wo_model = self._get_copy_without_cluttering_entries(self.__dict__)
        # Every remaining value is converted to its string representation.
        internal_dict_wo_model = {key: str(value) for key, value in internal_dict_wo_model.items()}
        return {'solver_name': self.name, **internal_dict_wo_model}


    def _postprocess_model_outputs(self, model_outputs, solutions):
        """
        Wraps each model output in a new `Solution` for the problem of the corresponding
        input solution and returns the new solutions as a list.
        """
        # `Solution` is not imported at module level; importing it here keeps this
        # method from failing with a `NameError`.
        # NOTE(review): module path taken from the dl4to package layout - confirm.
        from dl4to.solution import Solution

        return [
            Solution(problem=solution.problem, θ=model_output)
            for model_output, solution in zip(model_outputs, solutions)
        ]


    def _get_new_solutions(self, solutions, eval_mode):
        """
        Applies the preprocessing to every solution, concatenates the results along
        dimension 0, feeds them through the model and returns the resulting solutions.
        """
        model_inputs_list = [self.preprocessing(solution) for solution in solutions]
        model_inputs = torch.cat(model_inputs_list, dim=0).to(self.device)
        # Only a switch to evaluation mode happens here; with `eval_mode=False` the
        # model's current mode is left untouched.
        if eval_mode:
            self.model.eval()
        model_outputs = self.model(model_inputs)
        solutions = self._postprocess_model_outputs(model_outputs, solutions)
        return solutions


    def train(self,
              root:str, # The directory where the training results should be saved.
              dataloader_train:torch.utils.data.DataLoader, # The dataloader that contains the training data.
              dataloader_val:torch.utils.data.DataLoader=None, # The dataloader that contains the validation data.
              epochs:int=100, # The maximal number of training epochs.
              validation_interval:int=10, # The number of epochs after which a validation step is performed and printed.
              verbose:bool=True, # Whether to print information on the current training status, like the current loss and epoch.
              patience:int=None # If the validation score does not improve for `patience` epochs in a row, then the training is stopped and the best model is used.
             ):
        """
        Run the training for the topo solver.
        """
        # All arguments are forwarded unchanged to the solver's train module.
        self._train_module(
            root=root,
            dataloader_train=dataloader_train,
            dataloader_val=dataloader_val,
            epochs=epochs,
            validation_interval=validation_interval,
            verbose=verbose,
            patience=patience,
        )

In [None]:
show_doc(TrainableTopoSolver.get_args_as_dict)  # render the generated documentation for this method

<h4 id="TrainableTopoSolver.get_args_as_dict" class="doc_header"><code>TrainableTopoSolver.get_args_as_dict</code><a href="__main__.py#L54" class="source_link" style="float:right">[source]</a></h4>

> <code>TrainableTopoSolver.get_args_as_dict</code>()

Returns basic properties and arguments of the topo solver as a dictionary.

In [None]:
show_doc(TrainableTopoSolver.train)  # render the generated documentation for this method

<h4 id="TrainableTopoSolver.train" class="doc_header"><code>TrainableTopoSolver.train</code><a href="__main__.py#L81" class="source_link" style="float:right">[source]</a></h4>

> <code>TrainableTopoSolver.train</code>(**`root`**:`str`, **`dataloader_train`**:`DataLoader`, **`dataloader_val`**:`DataLoader`=*`None`*, **`epochs`**:`int`=*`100`*, **`validation_interval`**:`int`=*`10`*, **`verbose`**:`bool`=*`True`*, **`patience`**:`bool`=*`None`*)

Run the training for the topo solver.

||Type|Default|Details|
|---|---|---|---|
|**`root`**|`str`||The directory where the training results should be saved.|
|**`dataloader_train`**|`DataLoader`||The dataloader that contains the training data.|
|**`dataloader_val`**|`DataLoader`|`None`|The dataloader that contains the validation data.|
|**`epochs`**|`int`|`100`|The maximal number of training epochs.|
|**`validation_interval`**|`int`|`10`|The number of epochs after which a validation step is performed and printed.|
|**`verbose`**|`bool`|`True`|Whether to print information on the current training status, like the current loss and epoch.|
|**`patience`**|`bool`|`None`|If the validation score does not improve for `patience` epochs in a row, then the training is stopped and the best model is used.|


In [None]:
#hide
from torch.optim import Adam
from dl4to.criteria import WeightedBCE

In [None]:
%%time
#hide

class MockTopoSolver(TrainableTopoSolver):
    """
    Minimal subclass of `TrainableTopoSolver` that is only used for the instantiation test below.
    """
    def _get_name(self):
        # Fixed name for the mock; presumably used by the base class when no
        # explicit `name` is passed - verify against `TopoSolver`.
        return "MockTopoSolver"


def test_that_we_can_instanciate_a_mock():
    """
    Checks that a `MockTopoSolver` can be built from a small linear model and that
    it reports the expected name.
    """
    model = torch.nn.Linear(4, 4)
    optimizer = Adam(model.parameters(), lr=1e-3)
    topo_solver = MockTopoSolver(model=model, optimizer=optimizer, criterion=WeightedBCE())
    # Without `assert` the comparison result is discarded and the test can never fail.
    assert topo_solver.name == "MockTopoSolver"


test_that_we_can_instanciate_a_mock()

CPU times: user 0 ns, sys: 2.61 ms, total: 2.61 ms
Wall time: 22 ms
