Skip to content

Commit

Permalink
llvm/execution: Use 'evaluate' function params when constructing simu…
Browse files Browse the repository at this point in the history
…lation input (#2467)

The original code assumed 'run' variant would be called.
This is not the case for parallel evaluate that only needs
'run, simulation' variant, resulting in redundant code
generation and compiler calls.

Instead, use the 'evaluate' compiled function that provides the
same binary type of the argument at a different offset.

Tested by asserting that Python-{LLVM,PTX} tests only generate
'run, simulation' variant of the composition function.

Signed-off-by: Jan Vesely <jan.vesely@rutgers.edu>
  • Loading branch information
jvesely committed Aug 17, 2022
1 parent d81df22 commit ab6e911
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions psyneulink/core/llvm/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,11 @@ def cuda_execute(self, inputs):

# Methods used to accelerate "Run"

def _get_run_input_struct(self, inputs, num_input_sets):
input_type = self._bin_run_func.byref_arg_types[3]
def _get_run_input_struct(self, inputs, num_input_sets, arg=3):
# Callers that override input arg, should ensure that _bin_func is not None
bin_f = self._bin_run_func if arg == 3 else self._bin_func

input_type = bin_f.byref_arg_types[arg]
c_input = (input_type * num_input_sets) * len(self._execution_contexts)
if len(self._execution_contexts) == 1:
inputs = [inputs]
Expand Down Expand Up @@ -694,8 +697,8 @@ def _prepare_evaluate(self, inputs, num_input_sets, num_evaluations):
ct_comp_state = self._get_compilation_param('_eval_state', '_get_state_initializer', 1)
ct_comp_data = self._get_compilation_param('_eval_data', '_get_data_initializer', 6)

# Construct input variable
ct_inputs = self._get_run_input_struct(inputs, num_input_sets)
# Construct input variable, the 5th parameter of the evaluate function
ct_inputs = self._get_run_input_struct(inputs, num_input_sets, 5)

# Output ctype
out_ty = bin_func.byref_arg_types[4] * num_evaluations
Expand Down

0 comments on commit ab6e911

Please sign in to comment.