Skip to content

Commit

Permalink
fix: misc. fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff committed Nov 11, 2022
1 parent 930fe77 commit 45f8348
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 33 deletions.
35 changes: 18 additions & 17 deletions src/application/t2d/generate_features_and_write_to_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def create_unresolved_specs() -> dict[str, list[UnresolvedAnySpec]]:

unresolved_specs["static_predictors"] = [
UnresolvedStaticSpec(
values_lookup_name="sex_female", input_col_name_override="sex_female"
values_lookup_name="sex_female", input_col_name_override="sex_female", prefix="pred_"
)
]

Expand All @@ -564,12 +564,13 @@ def get_unresolved_temporal_predictor_specs() -> list[UnresolvedPredictorSpec]:
unresolved_temporal_predictor_specs: list[UnresolvedPredictorSpec] = []

unresolved_temporal_predictor_specs += UnresolvedLabPredictorGroupSpec(
values_lookup_name=("hba1c",),
values_lookup_name=["hba1c"],
fallback=[np.nan],
lab_values_to_load=["numerical_and_coerce"],
interval_days=[9999],
resolve_multiple_fn_name=["count"],
allowed_nan_value_prop=allowed_nan_value_prop,
output_col_name_override="eval_hba1c_count_within_9999_days",
).create_combinations()

unresolved_temporal_predictor_specs += UnresolvedLabPredictorGroupSpec(
Expand Down Expand Up @@ -606,21 +607,21 @@ def get_unresolved_temporal_predictor_specs() -> list[UnresolvedPredictorSpec]:
allowed_nan_value_prop=allowed_nan_value_prop,
).create_combinations()

unresolved_temporal_predictor_specs += UnresolvedPredictorGroupSpec(
values_lookup_name=("antipsychotics",),
interval_days=interval_days,
resolve_multiple_fn_name=resolve_multiple,
fallback=[0],
allowed_nan_value_prop=allowed_nan_value_prop,
).create_combinations()

unresolved_temporal_predictor_specs += UnresolvedPredictorGroupSpec(
values_lookup_name=["weight_in_kg", "height_in_cm", "bmi"],
interval_days=interval_days,
resolve_multiple_fn_name=["latest"],
fallback=[np.nan],
allowed_nan_value_prop=allowed_nan_value_prop,
).create_combinations()
# unresolved_temporal_predictor_specs += UnresolvedPredictorGroupSpec(
# values_lookup_name=("antipsychotics",),
# interval_days=interval_days,
# resolve_multiple_fn_name=resolve_multiple,
# fallback=[0],
# allowed_nan_value_prop=allowed_nan_value_prop,
# ).create_combinations()

# unresolved_temporal_predictor_specs += UnresolvedPredictorGroupSpec(
# values_lookup_name=["weight_in_kg", "height_in_cm", "bmi"],
# interval_days=interval_days,
# resolve_multiple_fn_name=["latest"],
# fallback=[np.nan],
# allowed_nan_value_prop=allowed_nan_value_prop,
# ).create_combinations()

return unresolved_temporal_predictor_specs

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Templates for feature specifications."""

import itertools
from abc import abstractmethod
from typing import Callable, Iterable, Literal, Optional, Sequence, Union

import pandas as pd
Expand Down Expand Up @@ -30,13 +31,23 @@ class AnySpec(BaseModel):
"""

values_df: pd.DataFrame
feature_name: str
prefix: str
# Used for column name generation, e.g. pred_<feature_name>.

input_col_name_override: Optional[str] = None
# An override for the input column name. If None, will attempt
# to infer it by looking for the only column that doesn't match id_col_name or timestamp_col_name.

output_col_name_override: Optional[str] = None
# If none, will output pred_<input_col_name>
def get_col_str(self) -> str:
    """Build the output column name for this spec.

    The name is ``<prefix>_<feature_name>``. If the prefix is an empty
    string, the feature name is used on its own, so that a fully specified
    column name stored in ``feature_name`` is not given a dangling leading
    underscore. Dichotomous outcome specs get a ``_dichotomous`` suffix.

    Returns:
        str: The column name for the feature described by this spec.
    """
    # Only join with "_" when there is a prefix to join; otherwise an empty
    # prefix would produce "_<feature_name>".
    if self.prefix:
        col_str = f"{self.prefix}_{self.feature_name}"
    else:
        col_str = self.feature_name

    # Only outcome specs define is_dichotomous(), hence the isinstance guard.
    if isinstance(self, OutcomeSpec) and self.is_dichotomous():
        col_str += "_dichotomous"

    return col_str


class StaticSpec(AnySpec):
Expand Down Expand Up @@ -65,7 +76,7 @@ class TemporalSpec(AnySpec):
allowed_nan_value_prop: float = 0.0

# Input col names
prefix: Optional[str] = None
prefix: str
id_col_name: str = "dw_ek_borger"
timestamp_col_name: str = "timestamp"

Expand All @@ -87,12 +98,7 @@ def __init__(self, **kwargs):

def get_col_str(self, col_main_override: Optional[str] = None) -> str:
"""."""
if self.output_col_name_override:
return self.output_col_name_override

col_main = col_main_override if col_main_override else self.feature_name

col_str = f"{self.prefix}_{col_main}_within_{self.interval_days}_days_{self.resolve_multiple_fn_name}_fallback_{self.fallback}"
col_str = f"{self.prefix}_{self.feature_name}_within_{self.interval_days}_days_{self.resolve_multiple_fn_name}_fallback_{self.fallback}"

if isinstance(self, OutcomeSpec):
if self.is_dichotomous():
Expand Down Expand Up @@ -189,7 +195,7 @@ def create_specs_from_group(

permuted_dicts = create_feature_combinations_from_dict(d=feature_group_spec_dict)

return [output_class(**d) for d in permuted_dicts]
return [output_class(**d) for d in permuted_dicts] # type: ignore


class PredictorGroupSpec(MinGroupSpec):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,10 @@ def add_age_and_date_of_birth(
static_spec=StaticSpec(
values_df=id2date_of_birth,
input_col_name_override=date_of_birth_col_name,
output_col_name_override=date_of_birth_col_name,
prefix="eval",
# We typically don't want to use date of birth as a predictor,
# but might want to use transformations - e.g. "year of birth" or "age at prediction time".
feature_name=date_of_birth_col_name,
),
)

Expand Down Expand Up @@ -695,10 +698,7 @@ def add_static_info(
else:
value_col_name = static_spec.input_col_name_override

if static_spec.output_col_name_override is None:
output_col_name = f"pred_{value_col_name}"
else:
output_col_name = f"{static_spec.output_col_name_override}"
output_col_name = f"{static_spec}_{value_col_name}"

df = pd.DataFrame(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,20 @@ def resolve_spec(
# Infer feature_name from values_lookup_name
kwargs_dict["feature_name"] = kwargs_dict["values_lookup_name"]

# Remove the attributes that are not allowed in the outcome specs,
# Remove the attributes that are not allowed in the resolve_to_class,
# or which are inferred in the return statement.

# This implementation is super brittle - whenever a new key is added to
# any class that is resolved, but which isn't added to feature_spec_objects,
# it breaks. Alternative ideas are very welcome.
# We can get around it by allowing extras (i.e. attributes which aren't specified) in the feature_spec_objects,
# but that leaves us open to typos.
for redundant_key in (
"values_df",
"resolve_multiple_fn",
"lab_values_to_load",
"values_lookup_name",
"output_col_name_override",
):
if redundant_key in kwargs_dict:
kwargs_dict.pop(redundant_key)
Expand All @@ -56,6 +63,10 @@ def resolve_spec(
elif isinstance(self, UnresolvedStaticSpec):
resolve_to_class = StaticSpec

if self.output_col_name_override:
kwargs_dict["feature_name"] = self.output_col_name_override
kwargs_dict["prefix"] = ""

return resolve_to_class(
values_df=str2df[self.values_lookup_name], **kwargs_dict
)
Expand Down Expand Up @@ -162,3 +173,4 @@ def create_combinations(self):
class UnresolvedStaticSpec(UnresolvedAnySpec):
    """Static feature specification whose values dataframe is not yet resolved."""

    # Default column-name prefix for static features.
    prefix: str = "pred"

0 comments on commit 45f8348

Please sign in to comment.