In [None]:
import sys
from pathlib import Path

from omegaconf import OmegaConf

from saiva.model.shared.demographics import DemographicFeatures
from saiva.model.shared.labs import LabFeatures
from saiva.model.shared.meds import MedFeatures
from saiva.model.shared.orders import OrderFeatures
from saiva.model.shared.vitals import VitalFeatures
from saiva.model.shared.alerts import AlertFeatures
from saiva.model.shared.rehosp import RehospFeatures
from saiva.model.shared.notes import NoteFeatures
from saiva.model.shared.diagnosis import DiagnosisFeatures
from saiva.model.shared.patient_census import PatientCensus
from saiva.model.shared.admissions import AdmissionFeatures
from saiva.model.shared.immunizations import ImmunizationFeatures
from saiva.model.shared.risks import RiskFeatures
from saiva.model.shared.assessments import AssessmentFeatures
from saiva.model.shared.adt import AdtFeatures
from saiva.model.shared.mds import MDSFeatures
import pandas as pd
import time

from saiva.model.shared.load_raw_data import fetch_training_cache_data
from saiva.model.shared.utils import get_client_class, get_memory_usage
from eliot import start_action, start_task, to_file, log_message
# Send eliot log messages to stdout so they show up in the cell output.
to_file(sys.stdout)

## Load config

In [None]:
from saiva.model.shared.constants import saiva_api, LOCAL_TRAINING_CONFIG_PATH
from saiva.training.utils import load_config

# Load the training configuration from the local config path and keep a handle
# on its `training_config` section, which the following cells read from.
config = load_config(LOCAL_TRAINING_CONFIG_PATH)
training_config = config.training_config

In [None]:
# Lift pandas' display limits so frames shown in this notebook are not truncated.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
# Use None (not -1) to disable cell-text truncation: -1 was deprecated in
# pandas 1.0 and is rejected by newer pandas releases.
pd.set_option('display.max_colwidth', None)

In [None]:
# Load the cached raw tables for one client from the local directory cache.

# Directory that receives the processed feature files; created on first use.
processed_path = Path('/data/processed')
processed_path.mkdir(parents=True, exist_ok=True)

# Replace this if necessary
CLIENT = training_config.organization_configs[0].organization_id

# NOTE(review): this rebinds `config` from a hard-coded directory instead of
# LOCAL_TRAINING_CONFIG_PATH used earlier in the notebook - confirm the two
# locations are meant to be the same.
config = load_config("/src/saiva/conf/training/")

result_dict = fetch_training_cache_data(client=CLIENT, generic=True)
# Quick sanity check: one line per loaded table with its shape.
for table_name, table in result_dict.items():
    print(f'{table_name} : {table.shape}')

In [None]:
# Flag handed to every feature group generated further down.
training = True

# Date bounds of the experiment, read from the training config.
TRAIN_START_DATE = training_config.training_metadata.experiment_dates.train_start_date
TEST_END_DATE = training_config.training_metadata.experiment_dates.test_end_date

# Look up the model-type record for the configured model type and version.
model_version = saiva_api.model_types.get_by_model_type_id(
    model_type_id=training_config.model_type,
    version=training_config.model_version,
)

# Echo the resolved settings, one value per line.
print(TRAIN_START_DATE, TEST_END_DATE, training, sep='\n')
print(model_version.model_type_id, model_version.id)

In [None]:
# Record the resolved model-type version id in the training metadata and write
# the result to generated/training_metadata.yaml under the local config path.
training_metadata = training_config.training_metadata
training_metadata['model_type_version_id'] = model_version.id

print(training_metadata)

# Build the output path with pathlib rather than string concatenation so it does
# not depend on LOCAL_TRAINING_CONFIG_PATH ending in a path separator, and make
# sure the target directory exists before writing.
generated_config_dir = Path(LOCAL_TRAINING_CONFIG_PATH) / 'generated'
generated_config_dir.mkdir(parents=True, exist_ok=True)

conf = OmegaConf.create({'training_config': {'training_metadata': training_metadata}})
OmegaConf.save(conf, generated_config_dir / 'training_metadata.yaml')

In [None]:
# read from parquet file
# census_df = pd.read_parquet(processed_path/'census_df.parquet')

In [None]:
%%time

print(TRAIN_START_DATE)

# Build the census frame from the cached 'patient_census' table and the
# train/test date bounds. A missing table is passed through as None.
patient_census = PatientCensus(
    census_df=result_dict.get('patient_census'),
    train_start_date=TRAIN_START_DATE,
    test_end_date=TEST_END_DATE,
)
census_df = patient_census.generate_features()

# Persist the census features next to the other processed files.
census_df.to_parquet(processed_path / 'census_df.parquet')

print(census_df.shape)
census_df.head(3)

In [None]:
def generate_feature_group(feature_group_class, feature_group_kwargs, processed_file_path, to_print=False):
    """Build one feature group, write it as parquet and optionally preview it.

    Args:
        feature_group_class: Class instantiated with ``feature_group_kwargs``;
            the instance must provide ``generate_features()``.
        feature_group_kwargs: Keyword arguments forwarded to the constructor.
        processed_file_path: Destination handed to ``to_parquet``.
        to_print: When true, print the frame's shape, display its first three
            rows and print the elapsed wall time.

    Returns:
        None. The generated frame is only written to ``processed_file_path``.
    """
    started_at = time.time()

    generated = feature_group_class(**feature_group_kwargs).generate_features()
    # Anything that is not a DataFrame is indexed with [0] to obtain the frame
    # (presumably a tuple/list result - confirm against the feature classes).
    features = generated if isinstance(generated, pd.DataFrame) else generated[0]
    features.to_parquet(processed_file_path)

    # The measured time covers feature generation and the parquet write.
    duration = time.time() - started_at
    if not to_print:
        return

    print(features.shape)
    display(features.head(3))
    elapsed = time.strftime("%H:%M:%S", time.gmtime(duration))
    print(f'Wall time for {feature_group_class.__name__}: {elapsed}', '\n')

In [None]:
# One 4-tuple per feature group:
#   (feature class,
#    {constructor kwarg name: key of the raw table in result_dict},
#    {extra constructor kwargs passed through unchanged},
#    file name of the parquet output inside processed_path)
feature_groups_spec = [
    (
        DemographicFeatures,
        {'demo_df': 'patient_demographics'},
        {},
        'demo_df.parquet',
    ),
    (
        VitalFeatures,
        {'vitals': 'patient_vitals'},
        {'config': config},
        'vitals_df.parquet',
    ),
    (
        OrderFeatures,
        {'orders': 'patient_orders'},
        {'config': config},
        'orders_df.parquet',
    ),
    (
        MedFeatures,
        {'meds': 'patient_meds'},
        {'config': config},
        'meds_df.parquet',
    ),
    (
        AlertFeatures,
        {'alerts': 'patient_alerts'},
        {'config': config},
        'alerts_df.parquet',
    ),
    (
        LabFeatures,
        {'labs': 'patient_lab_results'},
        {'config': config},
        'labs_df.parquet',
    ),
    (
        RehospFeatures,
        {'rehosps': 'patient_rehosps', 'adt_df': 'patient_adt'},
        {'config': config, 'train_start_date': TRAIN_START_DATE},
        'rehosp_df.parquet',
    ),
    (
        AdmissionFeatures,
        {'admissions': 'patient_admissions'},
        {},
        'admissions_df.parquet',
    ),
    (
        DiagnosisFeatures,
        {'diagnosis': 'patient_diagnosis'},
        {
            'diagnosis_lookup_ccs_s3_file_path': model_version.diagnosis_lookup_ccs_s3_uri,
            'config': config,
        },
        'diagnosis_df.parquet',
    ),
    (
        NoteFeatures,
        {'notes': 'patient_progress_notes'},
        {'client': CLIENT, 'vector_model': training_metadata.vector_model},
        'notes_df.parquet',
    ),
    (
        ImmunizationFeatures,
        {'immuns_df': 'patient_immunizations'},
        {'config': config},
        'immuns_df.parquet',
    ),
    (
        RiskFeatures,
        {'risks_df': 'patient_risks'},
        {'config': config},
        'risks_df.parquet',
    ),
    (
        AssessmentFeatures,
        {'assessments_df': 'patient_assessments'},
        {'config': config},
        'assessments_df.parquet',
    ),
    (
        AdtFeatures,
        {'adt_df': 'patient_adt'},
        {'config': config},
        'adt_df.parquet',
    ),
    # MDS features are switched off by default; uncomment to enable.
    # (MDSFeatures, {'mds_df': 'patient_mds', 'adt_df': 'patient_adt'}, {'config': config}, 'mds_df.parquet'),
]

In [None]:
import gc

# Generate every configured feature group and write each one to its parquet
# file. Groups whose raw input tables are missing or empty are skipped.
for feature_group_class, result_keys, params, file_name in feature_groups_spec:

    # Raw tables this group needs that are absent, None or empty.
    missing_tables = [
        result_key
        for result_key in result_keys.values()
        if result_dict.get(result_key) is None or result_dict[result_key].empty
    ]
    if missing_tables:
        # Report the skip instead of dropping the group silently: no parquet
        # file is (re)written for it, so a file from an earlier run may remain.
        print(f'Skipping {feature_group_class.__name__}: no data for {missing_tables}')
        continue

    # Map constructor argument names to the corresponding raw tables.
    kwargs = {key: result_dict[value] for key, value in result_keys.items()}
    kwargs.update(params)
    # Each group gets its own copy so it cannot mutate the shared census frame.
    kwargs['census_df'] = census_df.copy()
    kwargs['training'] = training
    generate_feature_group(
        feature_group_class,
        kwargs,
        processed_path/file_name,
        True
    )
    # Free the previous group's intermediates before building the next one.
    gc.collect()