Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ModuleContainer refactoring #183

Merged
merged 14 commits into from
Dec 17, 2022
4 changes: 3 additions & 1 deletion src/pymgrid/envs/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,12 @@ def from_microgrid(cls, microgrid):

"""
try:
return cls(microgrid.module_tuples(), add_unbalanced_module=False)
modules = microgrid.modules
except AttributeError:
assert isinstance(microgrid, NonModularMicrogrid)
return cls.from_nonmodular(microgrid)
else:
return cls(modules.to_tuples(), add_unbalanced_module=False)

@classmethod
def from_nonmodular(cls, nonmodular):
Expand Down
12 changes: 6 additions & 6 deletions src/pymgrid/microgrid/microgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def sample_action(self, strict_bound=False, sample_flex_modules=False):

"""

module_iterator = self._modules.module_dict() if sample_flex_modules else self._modules.controllable.module_dict()
module_iterator = self._modules.to_dict() if sample_flex_modules else self._modules.controllable.to_dict()
return {module_name: [module.sample_action(strict_bound=strict_bound) for module in module_list]
for module_name, module_list in module_iterator.items()
if module_list[0].action_space.shape[0]}
Expand All @@ -309,7 +309,7 @@ def get_empty_action(self, sample_flex_modules=False):
Empty action.

"""
module_iterator = self._modules.module_dict() if sample_flex_modules else self._modules.controllable.module_dict()
module_iterator = self._modules.to_dict() if sample_flex_modules else self._modules.controllable.to_dict()

return {module_name: [None]*len(module_list) for module_name, module_list in module_iterator.items()
if module_list[0].action_space.shape[0]}
Expand Down Expand Up @@ -419,7 +419,7 @@ def get_forecast_horizon(self):

"""
horizons = []
for module in self.iterlist():
for module in self._modules.iterlist():
try:
horizons.append(module.forecast_horizon)
except AttributeError:
Expand Down Expand Up @@ -543,7 +543,7 @@ def module_list(self):
The list of modules

"""
return self._modules.module_list()
return self._modules.to_list()

@property
def n_modules(self):
Expand Down Expand Up @@ -625,7 +625,7 @@ def serialize(self, dumper_stream):
"""
:meta private:
"""
data = {"modules": self._modules.module_tuples(),
data = {"modules": self._modules.to_tuples(),
"balance_log": self._balance_logger.serialize()}
return dump_data(data, dumper_stream, self.yaml_tag)

Expand Down Expand Up @@ -702,7 +702,7 @@ def from_scenario(cls, microgrid_number=0):
return cls.load(f)

def __getnewargs__(self):
return (self.module_tuples(), )
return (self.modules.to_tuples(), )

def __len__(self):
"""
Expand Down
Loading