Skip to content

Commit

Permalink
Merge 0b32bb1 into 793c807
Browse files Browse the repository at this point in the history
  • Loading branch information
mxndrwgrdnr committed Dec 14, 2018
2 parents 793c807 + 0b32bb1 commit b8776aa
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions urbansim_templates/models/large_multinomial_logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# choicemodels imports are in the fit() and run() methods

from .. import modelmanager
from ..utils import get_data, version_greater_or_equal
from ..utils import get_data, version_greater_or_equal, to_list
from .shared import TemplateStep


Expand Down Expand Up @@ -475,9 +475,9 @@ def run(self, chooser_batch_size=None, interaction_terms=None):
Additional column(s) of interaction terms whose values depend on the
combination of observation and alternative, to be merged onto the final data
table. If passed as a Series or DataFrame, it should include a two-level
MultiIndex. One level's name and values should match an index or column from
the observations table, and the other should match an index or column from the
alternatives table.
MultiIndex. The outermost level's name and values should match an index or
column from the observations table, and the second should match an index or
column from the alternatives table.
Returns
-------
Expand All @@ -494,17 +494,24 @@ def run(self, chooser_batch_size=None, interaction_terms=None):
"choicemodels 0.2.dev4 or later. For installation instructions, see "
"https://github.com/udst/choicemodels.")

if interaction_terms is not None:
obs_extra_cols = to_list(self.chooser_size) + list(interaction_terms.index.names)[0]
alts_extra_cols = to_list(self.alt_capacity) + list(interaction_terms.index.names)[1]
else:
obs_extra_cols = self.chooser_size
alts_extra_cols = self.alt_capacity

observations = get_data(tables = self.out_choosers,
fallback_tables = self.choosers,
filters = self.out_chooser_filters,
model_expression = self.model_expression,
extra_columns = self.chooser_size)
extra_columns = obs_extra_cols)

alternatives = get_data(tables = self.out_alternatives,
fallback_tables = self.alternatives,
filters = self.out_alt_filters,
model_expression = self.model_expression,
extra_columns = self.alt_capacity)
extra_columns = alts_extra_cols)

model = MultinomialLogitResults(model_expression = self.model_expression,
fitted_parameters = self.fitted_parameters)
Expand Down

0 comments on commit b8776aa

Please sign in to comment.