perf: parallelise temporal predictor loading
MartinBernstorff committed Nov 18, 2022
1 parent 1a3e5de commit 8d53f16
Showing 3 changed files with 56 additions and 45 deletions.
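
This commit moves the multiprocessing one level up: rather than opening a Pool inside create_specs_from_group for every permuted spec dict, the application now collects the unresolved PredictorGroupSpec objects and resolves them in parallel, one group per worker, inside get_temporal_predictor_specs. A minimal sketch of the pattern under assumed names (N_WORKERS and the spec classes come from the project; resolve_all is illustrative):

    from multiprocessing import Pool

    N_WORKERS = 4  # assumption: the project exposes this constant in psycop_feature_generation.utils

    def resolve_group_spec(group_spec):
        # Each worker expands one group spec into its concrete combinations.
        return group_spec.create_combinations()

    def resolve_all(groups):
        # Cap the pool at the number of groups so no idle workers are spawned.
        with Pool(min(N_WORKERS, len(groups))) as pool:
            per_group = pool.map(resolve_group_spec, groups)
        # Flatten the list of per-group lists into a single spec list.
        return [spec for specs in per_group for spec in specs]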
77 changes: 49 additions & 28 deletions src/application/t2d/generate_features_and_write_to_disk.py
@@ -8,8 +8,9 @@
import tempfile
import time
from collections.abc import Sequence
from multiprocessing import Pool
from pathlib import Path
from typing import Optional, Union
from typing import Callable, Optional, Union

import numpy as np
import pandas as pd
@@ -43,6 +44,7 @@
)
from psycop_feature_generation.utils import (
FEATURE_SETS_PATH,
N_WORKERS,
PROJECT_ROOT,
write_df_to_file,
)
@@ -308,6 +310,13 @@ def add_predictors_to_ds(
return flattened_dataset


def resolve_spec_set_component(spec_set_component: dict[str, Callable]):
for k, v in spec_set_component.items():
spec_set_component[k] = v()

return spec_set_component
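
resolve_spec_set_component replaces each zero-argument loader in the dict with its return value. Hypothetical usage, reusing factory functions defined further down in this file:

    resolved = resolve_spec_set_component(
        {
            "static_predictors": get_static_predictor_specs,
            "metadata": get_metadata_specs,
        }
    )
    # resolved["metadata"] now holds the list returned by get_metadata_specs()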


class SpecSet(BaseModel):
"""A set of unresolved specs, ready for resolving."""

@@ -316,7 +325,6 @@ class SpecSet(BaseModel):
outcomes: list[OutcomeSpec]
metadata: list[AnySpec]


def create_flattened_dataset(
prediction_times: pd.DataFrame,
birthdays: pd.DataFrame,
@@ -407,7 +415,7 @@ def get_static_predictor_specs():

def get_metadata_specs() -> list[AnySpec]:
"""Get metadata specs."""
return [
metadata_specs = [
StaticSpec(
values_loader="t2d",
input_col_name_override="timestamp",
@@ -426,17 +434,20 @@ def get_metadata_specs() -> list[AnySpec]:
allowed_nan_value_prop=0.0,
prefix="eval",
),
OutcomeGroupSpec(
values_loader=["hba1c"],
interval_days=[year * 365 for year in LOOKAHEAD_YEARS],
resolve_multiple_fn=["count"],
fallback=[0],
incident=[False],
allowed_nan_value_prop=[0.0],
prefix="eval",
).create_combinations(),
]

metadata_specs += OutcomeGroupSpec(
values_loader=["hba1c"],
interval_days=[year * 365 for year in LOOKAHEAD_YEARS],
resolve_multiple_fn=["count"],
fallback=[0],
incident=[False],
allowed_nan_value_prop=[0.0],
prefix="eval",
).create_combinations()

return metadata_specs
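
Splitting the OutcomeGroupSpec out of the list literal matters because create_combinations() itself returns a list: as a literal element it would end up nested inside metadata_specs, whereas += extends the outer list. A standalone illustration:

    combos = ["spec_c", "spec_d"]          # stand-in for OutcomeGroupSpec(...).create_combinations()
    nested = ["spec_a", "spec_b", combos]  # ['spec_a', 'spec_b', ['spec_c', 'spec_d']], wrong shape
    flat = ["spec_a", "spec_b"]
    flat += combos                         # ['spec_a', 'spec_b', 'spec_c', 'spec_d']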


def get_outcome_specs():
"""Get outcome specs."""
@@ -450,15 +461,19 @@ def get_outcome_specs():
).create_combinations()


def resolve_group_spec(group_spec: Union[PredictorGroupSpec, OutcomeGroupSpec]):
return group_spec.create_combinations()


def get_temporal_predictor_specs() -> list[PredictorSpec]:
"""Generate predictor spec list."""
base_resolve_multiple = ["max", "min", "mean", "latest", "count"]
base_interval_days = [30, 90, 180, 365, 730]
base_allowed_nan_value_prop = [0]

temporal_predictor_specs: list[PredictorSpec] = []
temporal_predictor_groups: list[PredictorGroupSpec] = []

temporal_predictor_specs += PredictorGroupSpec(
temporal_predictor_groups += PredictorGroupSpec(
values_loader=(
"hba1c",
"alat",
@@ -476,9 +491,9 @@ def get_temporal_predictor_specs() -> list[PredictorSpec]:
interval_days=base_interval_days,
fallback=[np.nan],
allowed_nan_value_prop=base_allowed_nan_value_prop,
).create_combinations()
)

temporal_predictor_specs += PredictorGroupSpec(
temporal_predictor_groups += PredictorGroupSpec(
values_loader=(
"essential_hypertension",
"hyperlipidemia",
@@ -490,14 +505,10 @@
interval_days=base_interval_days,
fallback=[0],
allowed_nan_value_prop=base_allowed_nan_value_prop,
).create_combinations()
)

temporal_predictor_specs += PredictorGroupSpec(
temporal_predictor_groups += PredictorGroupSpec(
values_loader=(
"essential_hypertension",
"hyperlipidemia",
"polycystic_ovarian_syndrome",
"sleep_apnea",
"f0_disorders",
"f1_disorders",
"f2_disorders",
@@ -508,16 +519,16 @@
"f7_disorders",
"f8_disorders",
"hyperkinetic_disorders",
"gerd_drugs",
),
resolve_multiple_fn=base_resolve_multiple,
interval_days=base_interval_days,
fallback=[0],
allowed_nan_value_prop=base_allowed_nan_value_prop,
).create_combinations()
)

temporal_predictor_specs += PredictorGroupSpec(
temporal_predictor_groups += PredictorGroupSpec(
values_loader=(
"gerd_drugs",
"antipsychotics",
"clozapine",
"top_10_weight_gaining_antipsychotics",
@@ -539,15 +550,25 @@
resolve_multiple_fn=base_resolve_multiple,
fallback=[0],
allowed_nan_value_prop=base_allowed_nan_value_prop,
).create_combinations()
)

temporal_predictor_specs += PredictorGroupSpec(
temporal_predictor_groups += PredictorGroupSpec(
values_loader=["weight_in_kg", "height_in_cm", "bmi"],
interval_days=base_interval_days,
resolve_multiple_fn=["latest"],
fallback=[np.nan],
allowed_nan_value_prop=base_allowed_nan_value_prop,
).create_combinations()
)

with Pool(min(N_WORKERS, len(temporal_predictor_groups))) as p:
temporal_predictor_specs: list[PredictorSpec] = p.map(
func=resolve_group_spec, iterable=temporal_predictor_groups
)

# Unpack list of lists
temporal_predictor_specs = [
item for sublist in temporal_predictor_specs for item in sublist
]

return temporal_predictor_specs
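
The pool size is capped with min(N_WORKERS, len(temporal_predictor_groups)) so no more workers are started than there are groups to resolve. Pool.map returns one list per group, which the nested comprehension flattens; itertools.chain.from_iterable is an equivalent way to do that flattening (shown as a sketch, not the commit's code):

    from itertools import chain

    per_group = [["spec_a", "spec_b"], ["spec_c"]]  # the shape Pool.map returns here
    flat = list(chain.from_iterable(per_group))     # ['spec_a', 'spec_b', 'spec_c']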

5 changes: 3 additions & 2 deletions src/psycop_feature_generation/loaders/raw/utils.py
@@ -134,8 +134,9 @@ def load_from_codes(

df = sql_load(sql, database="USR_PS_FORSK", chunksize=None, n_rows=n_rows)

# Drop all rows whose code_col_name is in exclude_codes
df = df[~df[code_col_name].isin(exclude_codes)]
if exclude_codes:
# Drop all rows whose code_col_name is in exclude_codes
df = df[~df[code_col_name].isin(exclude_codes)]

if output_col_name is None:
if isinstance(codes_to_match, list):
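
In loaders/raw/utils.py the exclusion filter is now guarded: when exclude_codes is None or an empty list, the isin lookup is skipped instead of being evaluated against an empty collection. A small sketch of the guarded pattern with illustrative data:

    import pandas as pd

    df = pd.DataFrame({"code": ["A10", "B20", "C30"]})  # illustrative frame and column name
    exclude_codes = None                                # an empty list would also skip the filter
    if exclude_codes:
        df = df[~df["code"].isin(exclude_codes)]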
(Third changed file; filename not captured in this view.)
@@ -2,8 +2,7 @@

import itertools
from collections.abc import Sequence
from functools import cache, partial
from multiprocessing import Pool
from functools import cache
from typing import Any, Callable, Optional, Union

import pandas as pd
@@ -15,7 +14,7 @@
from psycop_feature_generation.timeseriesflattener.resolve_multiple_functions import (
resolve_multiple_fns,
)
from psycop_feature_generation.utils import N_WORKERS, data_loaders
from psycop_feature_generation.utils import data_loaders

msg = Printer(timestamp=True)

@@ -302,14 +301,10 @@ def create_feature_combinations_from_dict(

# Create all combinations of top level elements
permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]

return permutations_dicts


def create_output_class_from_kwargs(output_class, kwargs):
return output_class(**kwargs)


def create_specs_from_group(
feature_group_spec: MinGroupSpec,
output_class: AnySpec,
@@ -322,13 +317,7 @@ def create_specs_from_group(

permuted_dicts = create_feature_combinations_from_dict(d=feature_group_spec_dict)

with Pool(min(N_WORKERS, len(feature_group_spec.values_loader))) as p:
output_list: list[AnySpec] = p.map(
iterable=permuted_dicts,
func=partial(create_output_class_from_kwargs, output_class),
)

return output_list # type: ignore
return [output_class(**d) for d in permuted_dicts] # type: ignore
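
With the inner Pool removed, create_specs_from_group builds one spec object per permuted kwargs dict serially; the parallelism now sits one level up, per group, in the application code above. A sketch of the serial construction with a stand-in pydantic model (the real output_class is one of the project's spec classes):

    from pydantic import BaseModel

    class DummySpec(BaseModel):  # stand-in for PredictorSpec and friends
        values_loader: str
        interval_days: int

    permuted_dicts = [
        {"values_loader": "hba1c", "interval_days": days} for days in (30, 90, 180)
    ]
    specs = [DummySpec(**d) for d in permuted_dicts]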


class PredictorGroupSpec(MinGroupSpec):
