Skip to content

Commit

Permalink
Merge pull request #82 from ihincks/fix-serial-threshold
Browse files Browse the repository at this point in the history
Fixed an index bug in DirectViewParallelizedModel
  • Loading branch information
cgranade committed Sep 8, 2016
2 parents de46ca5 + f919c68 commit ea90440
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/qinfer/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ class DirectViewParallelizedModel(DerivedModel):
This :class:`Model` assumes that it has ownership over the DirectView, such
that no other processes will send tasks during the lifetime of the Model.
If you are having trouble pickling your model, consider switching to
``dill`` by calling ``direct_view.use_dill()``. This mode gives more support
for closures.
:param qinfer.Model serial_model: Model to be parallelized. This
model will be distributed to the engines in the direct view, such that
Expand All @@ -93,7 +97,7 @@ class DirectViewParallelizedModel(DerivedModel):

## INITIALIZER ##

def __init__(self, serial_model, direct_view, purge_client=False, serial_theshold=None):
def __init__(self, serial_model, direct_view, purge_client=False, serial_threshold=None):
if ipp is None:
raise RuntimeError(
"This model requires IPython parallelization support, "
Expand All @@ -104,7 +108,7 @@ def __init__(self, serial_model, direct_view, purge_client=False, serial_theshol
self._purge_client = purge_client
self._serial_threshold = (
10 * self.n_engines
if serial_theshold is None else int(serial_theshold)
if serial_threshold is None else int(serial_threshold)
)

super(DirectViewParallelizedModel, self).__init__(serial_model)
Expand Down Expand Up @@ -177,7 +181,7 @@ def likelihood(self, outcomes, modelparams, expparams):
"""
super(DirectViewParallelizedModel, self).likelihood(outcomes, modelparams, expparams)

if modelparams.shape[1] <= self._serial_threshold:
if modelparams.shape[0] <= self._serial_threshold:
return self._serial_model.likelihood(outcomes, modelparams, expparams)

if self._dv is None:
Expand Down

0 comments on commit ea90440

Please sign in to comment.