Skip to content

Commit

Permalink
Composition: set most_recent_context after compiled execution
Browse files Browse the repository at this point in the history
  • Loading branch information
kmantel committed Mar 16, 2021
1 parent 5641c3d commit 2330f4d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 8 deletions.
16 changes: 16 additions & 0 deletions psyneulink/core/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -3475,6 +3475,22 @@ def log_values(self, entries):
"""
self.log.log_values(entries)

def _propagate_most_recent_context(self, context=None, visited=None):
if visited is None:
visited = set([self])

if context is None:
context = self.most_recent_context

self.most_recent_context = context

# TODO: avoid duplicating objects in _dependent_components
# throughout psyneulink or at least condense these methods
for obj in self._dependent_components:
if obj not in visited:
visited.add(obj)
obj._propagate_most_recent_context(context, visited)

@property
def _dict_summary(self):
from psyneulink.core.compositions.composition import Composition
Expand Down
12 changes: 6 additions & 6 deletions psyneulink/core/compositions/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -8397,9 +8397,9 @@ def run(
# copies back matrix to pnl from param struct (after learning)
_comp_ex._copy_params_to_pnl(context=context)

self._propagate_most_recent_context(context)
# KAM added the [-1] index after changing Composition run()
# behavior to return only last trial of run (11/7/18)
self.most_recent_context = context
return results[-1]

except Exception as e:
Expand Down Expand Up @@ -8727,7 +8727,7 @@ def _execute_controller(self,
self.controller.execute(context=context)

if execution_mode:
_comp_ex.execute_node(self.controller)
_comp_ex.execute_node(self.controller, context=context)

context.remove_flag(ContextFlags.PROCESSING)

Expand Down Expand Up @@ -8943,6 +8943,7 @@ def execute(
if report._rich_diverted_reports:
self.rich_diverted_reports = report._rich_diverted_reports

self._propagate_most_recent_context(context)
return _comp_ex.extract_node_output(self.output_CIM)

except Exception as e:
Expand Down Expand Up @@ -9036,7 +9037,7 @@ def execute(
build_CIM_input = self._build_variable_for_input_CIM(inputs)

if execution_mode:
_comp_ex.execute_node(self.input_CIM, inputs)
_comp_ex.execute_node(self.input_CIM, inputs, context)
# FIXME: parameter_CIM should be executed here as well,
# but node execution of nested compositions with
# outside control is not supported yet.
Expand Down Expand Up @@ -9282,7 +9283,7 @@ def execute(

# Execute Mechanism
if execution_mode:
_comp_ex.execute_node(node)
_comp_ex.execute_node(node, context=context)
else:
if node is not self.controller:
mech_context = copy(context)
Expand Down Expand Up @@ -9476,14 +9477,13 @@ def execute(
# Extract result here
if execution_mode:
_comp_ex.freeze_values()
_comp_ex.execute_node(self.output_CIM)
_comp_ex.execute_node(self.output_CIM, context=context)
report.report_progress(self, run_report, context)
if context.source & ContextFlags.COMMAND_LINE:
if report._recorded_reports:
self.recorded_reports = report._recorded_reports
if report._rich_diverted_reports:
self.rich_diverted_reports = report._rich_diverted_reports

return _comp_ex.extract_node_output(self.output_CIM)

# Reset context flags
Expand Down
7 changes: 5 additions & 2 deletions psyneulink/core/llvm/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,12 +460,13 @@ def _get_input_struct(self, inputs):
def freeze_values(self):
self.__frozen_vals = copy.deepcopy(self._data_struct)

def execute_node(self, node, inputs=None):
def execute_node(self, node, inputs=None, context=None):
# We need to reconstruct the input dictionary here if it was not provided.
# This happens during node execution of nested compositions.
assert len(self._execution_contexts) == 1
if inputs is None and node is self._composition.input_CIM:
context = self._execution_contexts[0]
if context is None:
context = self._execution_contexts[0]
port_inputs = {origin_port:[proj.parameters.value._get(context) for proj in p[0].path_afferents] for (origin_port, p) in self._composition.input_CIM_ports.items()}
inputs = {}
for p, v in port_inputs.items():
Expand Down Expand Up @@ -493,6 +494,8 @@ def execute_node(self, node, inputs=None):
print("RAN: {}. Params: {}".format(node, self.extract_node_params(node)))
print("RAN: {}. Results: {}".format(node, self.extract_node_output(node)))

node._propagate_most_recent_context(context)

@property
def _bin_exec_func(self):
if self.__bin_exec_func is None:
Expand Down

0 comments on commit 2330f4d

Please sign in to comment.