Before starting the training:
1. Configure `CLIENT` in `constants.py`
2. Create or edit the client-specific file in `src/clients/`, using `_template.py` as a starting point and following the instructions in it

In [None]:
import sys
import pandas as pd
import shutil
import os
import glob

from omegaconf import OmegaConf

from saiva.model.shared.load_raw_data import fetch_training_data, fetch_training_cache_data
from saiva.model.shared.database import DbEngine
from saiva.model.shared.constants import LOCAL_TRAINING_CONFIG_PATH
from saiva.training.utils import load_config

from eliot import to_file
# Send Eliot log messages to stdout so they show up in the notebook cell output.
to_file(sys.stdout)

## Load config

In [None]:
# Load the training configuration from the local config path. The active cells
# below only read its `training_config` section.
config = load_config(LOCAL_TRAINING_CONFIG_PATH)
training_config = config.training_config

In [None]:
# Sanity-check the experiment window and the organizations to be processed
# before any database is queried.
experiment_dates = training_config.training_metadata.experiment_dates
print("TRAIN_START_DATE:", experiment_dates.train_start_date)
print("TEST_END_DATE:", experiment_dates.test_end_date)
# Read from `organization_configs` (the same list the fetch loop iterates)
# rather than `organization_configs_setup`, so the printed clients are exactly
# the ones whose data will be fetched.
print("CLIENTS:", [organization_config.organization_id for organization_config in training_config.organization_configs])

### ======================== Load Database ========================

In [None]:
# Database helper; the fetch cell below uses it to build one SQL engine per
# organization and to verify connectivity.
engine = DbEngine()
# Engine returned by get_postgresdb_engine(). NOTE(review): not referenced by
# any active cell in this notebook (only by commented-out code) — kept as-is.
saiva_engine = engine.get_postgresdb_engine()

### ======================== Fetch data for all organizations ============================

In [None]:
# Load the data for every configured organization from its SQL database and
# store it in the local directory as a cache (via fetch_training_data).
# Any dataset that comes back empty for at least one organization is recorded
# in `missing_datasets`; the next cell writes it into the training metadata.

missing_datasets = set()

for organization_config in training_config.organization_configs:
    organization_id = organization_config.organization_id
    print(f'*** Fetching data for {organization_id} ***')

    client_sql_engine = engine.get_sqldb_engine(
        db_name=organization_config.datasource.source_database_name,
        credentials_secret_id=organization_config.datasource.source_database_credentials_secret_id,
        query={"driver": "ODBC Driver 17 for SQL Server"},
    )

    # verify connectivity before running the (potentially long) fetch
    engine.verify_connectivity(client_sql_engine)

    result_dict = fetch_training_data(
        client=organization_id,
        client_sql_engine=client_sql_engine,
        train_start_date=training_config.training_metadata.experiment_dates.train_start_date,
        test_end_date=training_config.training_metadata.experiment_dates.test_end_date,
    )

    for dataset in training_config.all_datasets:
        # A dataset counts as missing when it is absent from result_dict or empty.
        if result_dict.get(dataset, pd.DataFrame()).empty:
            missing_datasets.add(dataset)
            continue
        print(dataset, result_dict[dataset].shape)

    # Re-save demographics with `dateofbirth` parsed to datetimes (unparseable
    # values become NaT). Treat a missing/empty dataset the same way as above
    # instead of indexing it directly, which raised KeyError and aborted the
    # loop for every remaining organization.
    demographics = result_dict.get('patient_demographics')
    if demographics is None or demographics.empty:
        print(f'{organization_id}: patient_demographics is missing or empty; skipping parquet export')
        continue
    demographics['dateofbirth'] = pd.to_datetime(demographics['dateofbirth'], errors='coerce')
    demographics.to_parquet(f'/data/raw/{organization_id}_patient_demographics.parquet', index=False)


In [None]:
# Attach the datasets found missing during the fetch to the training metadata
# and persist it under `<LOCAL_TRAINING_CONFIG_PATH>/generated/`.
training_metadata = training_config.training_metadata
# Sorted so the generated YAML is deterministic: iteration order of a set of
# strings changes between interpreter runs because of hash randomization.
training_metadata['missing_datasets'] = sorted(missing_datasets)

print(training_metadata)

# os.path.join works whether or not LOCAL_TRAINING_CONFIG_PATH ends with a
# separator (plain string concatenation silently required one), and the
# output directory is created if it does not exist yet.
generated_dir = os.path.join(LOCAL_TRAINING_CONFIG_PATH, 'generated')
os.makedirs(generated_dir, exist_ok=True)

conf = OmegaConf.create({'training_config': {'training_metadata': training_metadata}})
OmegaConf.save(conf, os.path.join(generated_dir, 'training_metadata.yaml'))

### ==================== If Multiple clients data need to be merged ====================

In [None]:
# Load the data from the SQL DB for multiple clients and store it in the local directory as a cache

# for client in ['avante','gulfshore','palmgarden']:
#     print(f'*********************** Processing for {client} ******************************')
#     clientClass = get_client_class(client)
#     EXPERIMENT_DATES = getattr(clientClass(), 'get_experiment_dates')()
#     TRAIN_START_DATE, TEST_END_DATE = EXPERIMENT_DATES['train_start_date'], EXPERIMENT_DATES['test_end_date']
#     print(TRAIN_START_DATE, TEST_END_DATE)
    
#     engine = DbEngine()
#     saiva_engine = engine.get_postgresdb_engine()
#     client_sql_engine = engine.get_sqldb_engine(clientdb_name=client)
#     engine.verify_connectivity(client_sql_engine)
#     result_dict = fetch_training_data(client, client_sql_engine, TRAIN_START_DATE, TEST_END_DATE)
    
#     print('master_patient_lookup', result_dict['master_patient_lookup'].shape)
#     print('patient_census',result_dict['patient_census'].shape)
#     print('patient_rehosps',result_dict['patient_rehosps'].shape)
#     print('patient_demographics',result_dict['patient_demographics'].shape)
#     print('patient_diagnosis',result_dict['patient_diagnosis'].shape)
#     print('patient_vitals',result_dict['patient_vitals'].shape)
#     print('patient_meds',result_dict['patient_meds'].shape)
#     print('patient_orders',result_dict['patient_orders'].shape)
#     print('patient_alerts',result_dict['patient_alerts'].shape)
#     print('patient_progress_notes',result_dict['patient_progress_notes'].shape)
#     if not result_dict.get('patient_lab_results', pd.DataFrame()).empty:
#         print('patient_lab_results',result_dict['patient_lab_results'].shape)
#     print(result_dict.keys())

### ======================== TESTING ==========================

In [None]:
# Once fetch_training_data has loaded the data, reuse the same local cache

# result_dict = fetch_training_cache_data(CLIENT)

# print('master_patient_lookup', result_dict['master_patient_lookup'].shape)
# print('patient_census',result_dict['patient_census'].shape)
# print('patient_rehosps',result_dict['patient_rehosps'].shape)
# print('patient_demographics',result_dict['patient_demographics'].shape)
# print('patient_diagnosis',result_dict['patient_diagnosis'].shape)
# print('patient_vitals',result_dict['patient_vitals'].shape)
# print('patient_meds',result_dict['patient_meds'].shape)
# print('patient_orders',result_dict['patient_orders'].shape)
# print('patient_alerts',result_dict['patient_alerts'].shape)
# print('patient_progress_notes',result_dict['patient_progress_notes'].shape)
# if not result_dict.get('patient_lab_results', pd.DataFrame()).empty:
#     print('patient_lab_results',result_dict['patient_lab_results'].shape)
# print(result_dict.keys())

# have a max of 15042 master_patient_lookup rows, i.e. Infinity-Infinity

In [None]:
# TESTING specific queries

# query=f"""
#         select distinct patientid, facilityid, orderdate, gpiclass, 
#         gpisubclassdescription, orderdescription, pharmacymedicationname, a.PhysiciansOrderID
#         from view_ods_physician_order_list_v2 a
#         inner join view_ods_physician_order_list_med b
#         on a.PhysiciansOrderID = b.PhysiciansOrderID 
#         where orderdate between '{train_start_date}' and '{test_end_date}';
#         """

# df = pd.read_sql(query, con=client_sql_engine)
# print(df.shape)
# df.head()