From d7dba98d8be048c315491050d88ae0d241f3f970 Mon Sep 17 00:00:00 2001 From: Henrik Mettler Date: Wed, 6 Jan 2021 09:51:18 +0100 Subject: [PATCH] Fix merge conflict --- cgp/genome.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/cgp/genome.py b/cgp/genome.py index 6cc9d433..d474c1fe 100644 --- a/cgp/genome.py +++ b/cgp/genome.py @@ -712,8 +712,24 @@ def update_parameters_from_torch_class(self, torch_cls: "torch.nn.Module") -> bo def _initialize_unknown_parameters(self) -> None: for region_idx, region in self.iter_hidden_regions(): - - self._initialize_parameter_values(region_idx, region) + node_id = region[0] + node_type = self._primitives[node_id] + assert issubclass(node_type, OperatorNode) + for parameter_name_with_idx in self._get_parameter_names_with_idx_of_node( + node_type, region_idx + ): + if parameter_name_with_idx not in self._parameter_names_to_values: + self._parameter_names_to_values[ + parameter_name_with_idx + ] = node_type.initial_value(parameter_name_with_idx) + + def _get_parameter_names_with_idx_of_node( + self, node_type: Type[OperatorNode], region_idx: int + ) -> List[str]: + parameter_names_with_idx: List[str] = [] + for parameter_name in node_type._parameter_names: + parameter_names_with_idx.append("<" + parameter_name[1:-1] + str(region_idx) + ">") + return parameter_names_with_idx def _initialize_parameter_values( self, region_idx: int, region: List[int], reinitialize: bool = False