Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trip destination alts preprocessor for both sample and simulate steps #869

Merged
merged 9 commits into from
May 14, 2024
31 changes: 20 additions & 11 deletions activitysim/abm/models/trip_destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ class TripDestinationSettings(LocationComponentSettings, extra="forbid"):
PRIMARY_DEST: str = "tour_leg_dest" # must be created in preprocessor
REDUNDANT_TOURS_MERGED_CHOOSER_COLUMNS: list[str] | None = None
preprocessor: PreprocessorSettings | None = None
alts_preprocessor: PreprocessorSettings | None = None
alts_preprocessor_sample: PreprocessorSettings | None = None
alts_preprocessor_simulate: PreprocessorSettings | None = None
CLEANUP: bool
fail_some_trips_for_testing: bool = False
"""This setting is used by testing code to force failed trip_destination."""
Expand Down Expand Up @@ -202,6 +203,15 @@ def _destination_sample(

log_alt_losers = state.settings.log_alt_losers

if model_settings.alts_preprocessor_sample:
expressions.assign_columns(
state,
df=alternatives,
model_settings=model_settings.alts_preprocessor_sample,
locals_dict=locals_dict,
trace_label=tracing.extend_trace_label(trace_label, "alts"),
)

choices = interaction_sample(
state,
choosers=trips,
Expand Down Expand Up @@ -936,6 +946,15 @@ def trip_destination_simulate(
)
locals_dict.update(skims)

if model_settings.alts_preprocessor_simulate:
expressions.assign_columns(
state,
df=destination_sample,
model_settings=model_settings.alts_preprocessor_simulate,
locals_dict=locals_dict,
trace_label=tracing.extend_trace_label(trace_label, "alts"),
)

log_alt_losers = state.settings.log_alt_losers
destinations = interaction_sample_simulate(
state,
Expand Down Expand Up @@ -1246,7 +1265,6 @@ def run_trip_destination(
state.filesystem, model_settings_file_name
)
preprocessor_settings = model_settings.preprocessor
alts_preprocessor_settings = model_settings.alts_preprocessor
logsum_settings = state.filesystem.read_model_settings(
model_settings.LOGSUM_SETTINGS
)
Expand Down Expand Up @@ -1369,15 +1387,6 @@ def run_trip_destination(
trace_label=nth_trace_label,
)

if alts_preprocessor_settings:
expressions.assign_columns(
state,
df=alternatives,
model_settings=alts_preprocessor_settings,
locals_dict=locals_dict,
trace_label=tracing.extend_trace_label(nth_trace_label, "alts"),
)

if isinstance(
nth_trips["trip_period"].dtype, pd.api.types.CategoricalDtype
):
Expand Down
Loading