In [None]:
import pandas as pd
# import modin.pandas as pd
from pathlib import Path
import os
import timeit
from datetime import timedelta

In [None]:
import sys
from omegaconf import OmegaConf

from saiva.model.shared.load_raw_data import get_genric_file_names
from eliot import to_file
# Register stdout as an eliot log destination so eliot log messages show up in the cell output
to_file(sys.stdout)

## Load config

In [None]:
from saiva.model.shared.constants import saiva_api, LOCAL_TRAINING_CONFIG_PATH
from saiva.training.utils import load_config

# Load the local training configuration; the cells below read everything they need
# (organization configs, training metadata) from its `training_config` section.
config = load_config(LOCAL_TRAINING_CONFIG_PATH)
training_config = config.training_config

In [None]:
# Organization whose raw files are inspected: the first configured one by default.
# Replace this if necessary.
first_org = training_config.organization_configs[0]
CLIENT = first_org.organization_id

# Generic dataframe names for CLIENT, i.e. its raw file names with the client name stripped.
data_path = Path('/data', 'raw')
client_file_types = get_genric_file_names(client=CLIENT, data_path=data_path)

print(client_file_types)

In [None]:
"""
- When using other clients' data, merge the respective files across all
  given clients and rename them, removing the client name
- Prefix masterpatientid with the client name (`<client>_<masterpatientid>`)
- Add a client column
"""
def _merge_client_files(file_type, raw_path):
    """Concatenate every non-empty ``*_<file_type>.parquet`` file found in ``raw_path``.

    The client name is taken from the file name (text before the first ``_``). Each
    client's ``masterpatientid`` values are prefixed with ``<client>_`` and a ``client``
    column is added, so rows from different clients stay distinguishable.

    Returns the merged DataFrame, or ``None`` when no non-empty client file exists.
    """
    frames = []
    merged_count = 0
    # Sorted so the merged row order does not depend on the filesystem's listing order
    for f in sorted(raw_path.glob(f'*_{file_type}.parquet')):
        client = f.name.split('_')[0]
        client_df = pd.read_parquet(f)
        if len(client_df) == 0:
            continue
        client_df['masterpatientid'] = client_df['masterpatientid'].apply(lambda x: client + '_' + str(x))
        client_df['client'] = client
        frames.append(client_df)
        merged_count += len(client_df)
        print(f, len(client_df), merged_count)

    if not frames:
        return None
    # One concat at the end instead of DataFrame.append in the loop: append was removed
    # in pandas 2.0 and re-copied the accumulated frame on every iteration.
    return pd.concat(frames, ignore_index=True)


start_time = timeit.default_timer()
failed_file_types = {}
for ft in client_file_types:
    try:
        df = _merge_client_files(ft, data_path)
        if df is None:
            # Writing a zero-column parquet file would only break downstream readers
            print(ft, 'skipped: no non-empty client files found')
            continue

        if ft == 'patient_demographics':
            df['dateofbirth'] = df['dateofbirth'].astype('datetime64[ms]')
            # If this cast fails on invalid values, an alternative is
            # pd.to_datetime(df['dateofbirth'], errors='coerce'), which turns them into NaT.

        if ft == 'patient_diagnosis':
            df['onsetdate'] = df['onsetdate'].astype('datetime64[ms]')

        df.to_parquet(data_path/f'{ft}.parquet')
        print('============================')
    except Exception as e:
        # Best effort: keep merging the remaining file types, but remember what failed
        failed_file_types[ft] = e
        print(ft, 'failed:', e)
print(f"{timeit.default_timer() - start_time} seconds")
if failed_file_types:
    print('Failed file types:', list(failed_file_types))

### Validation + test sets together make up 25% of the census rows
## Note: The obtained training, validation and test dates need to be added to the client's respective file under the `get_experiment_dates` function!!

In [None]:
# Validation + test sets together make up 25% of the census rows
# Note: The obtained validation and test dates need to be added to the client's respective file under the `get_experiment_dates` function

def get_prior_date_as_str(date_as_str):
    """Return the calendar day before ``date_as_str`` as a ``'%Y-%m-%d'`` string.

    ``date_as_str`` may be anything ``pd.to_datetime`` parses; any time-of-day part
    is dropped after stepping back one day.
    """
    one_day = timedelta(days=1)
    day_before = (pd.to_datetime(date_as_str) - one_day).date()
    return day_before.strftime('%Y-%m-%d')


data_path = Path('/data/raw')

# Share of census rows held out for validation + test, and the share of that
# hold-out assigned to validation (the remainder goes to test).
HOLDOUT_PERCENT = 25
VALIDATION_PERCENT_OF_HOLDOUT = 50


def _to_date_str(timestamp):
    """Format a pandas Timestamp as a 'YYYY-MM-DD' string."""
    return timestamp.date().strftime('%Y-%m-%d')


# One row per (patient, census date), in chronological order
df = (
    pd.read_parquet(data_path/'patient_census.parquet')
    .drop_duplicates(subset=['masterpatientid', 'censusdate'], keep='last')
    .sort_values(by=['censusdate'])
)

total_count = df.shape[0]
test_count = int((total_count * HOLDOUT_PERCENT) / 100)
test_split_count = int((test_count * VALIDATION_PERCENT_OF_HOLDOUT) / 100)  # rows for validation

holdout_df = df.tail(test_count)  # the chronologically last rows
validation_df = holdout_df.head(test_split_count)
# Test gets every hold-out row not used for validation, so an odd hold-out size
# does not silently drop a row.
test_df = holdout_df.tail(test_count - test_split_count)

if validation_df.empty or test_df.empty:
    raise ValueError(
        f'Not enough census rows ({total_count}) to build non-empty validation and test sets'
    )

train_start_date = _to_date_str(df.censusdate.min())
validation_start_date = _to_date_str(validation_df.censusdate.min())
test_start_date = _to_date_str(test_df.censusdate.min())
test_end_date = _to_date_str(test_df.censusdate.max())

# Each period ends the day before the next one starts, so periods cannot overlap
train_end_date = get_prior_date_as_str(validation_start_date)
validation_end_date = get_prior_date_as_str(test_start_date)

print(f'train_start_date: {train_start_date}')
print(f'train_end_date: {train_end_date}')
print(f'validation_start_date: {validation_start_date}')
print(f'validation_end_date: {validation_end_date}')
print(f'test_start_date: {test_start_date}')
print(f'test_end_date: {test_end_date}')

# 'YYYY-MM-DD' strings sort chronologically. If many rows share a single census date,
# a split boundary can collapse a period so that it ends before it starts; catch that.
ordered_dates = [
    train_start_date, train_end_date,
    validation_start_date, validation_end_date,
    test_start_date, test_end_date,
]
if ordered_dates != sorted(ordered_dates):
    raise ValueError(f'Experiment dates are not in chronological order: {ordered_dates}')

## Double-check that the train, validation and test dates do not overlap
```
Example:
return {
    'train_start_date': '2019-09-01',
    'train_end_date': '2021-02-21',
    'validation_start_date': '2021-02-22',
    'validation_end_date': '2021-06-06',
    'test_start_date': '2021-06-07',
    'test_end_date': '2021-09-15'
}
```

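The cell below writes these experiment dates into the `training_metadata` of the OrganizationMlModelConfig and saves the result to `generated/training_metadata.yaml` under the local training config path.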

In [None]:
# Store the computed split dates in the training metadata (in-memory config)
training_metadata = training_config.training_metadata
training_metadata['experiment_dates'] = {
    'train_start_date': train_start_date,
    'train_end_date': train_end_date,
    'validation_start_date': validation_start_date,
    'validation_end_date': validation_end_date,
    'test_start_date': test_start_date,
    'test_end_date': test_end_date,
}

print(training_metadata)

conf = OmegaConf.create({'training_config': {'training_metadata': training_metadata}})
# os.path.join works whether LOCAL_TRAINING_CONFIG_PATH is a str with or without a
# trailing separator, or a Path; plain string concatenation required a trailing '/'.
generated_dir = os.path.join(LOCAL_TRAINING_CONFIG_PATH, 'generated')
# OmegaConf.save does not create missing parent directories
os.makedirs(generated_dir, exist_ok=True)
OmegaConf.save(conf, os.path.join(generated_dir, 'training_metadata.yaml'))

### ======================== TESTING ================================

In [None]:
# Load generic named Training data which is cached in local folders
# from shared.load_raw_data import fetch_training_cache_data

# result_dict = fetch_training_cache_data(client=CLIENT, generic=True)
# for key, value in result_dict.items():
#     print(f'{key} : {result_dict[key].info()}')

In [None]:
# Remove all newly generated parquet files

# for ft in client_file_types:
#     os.remove(data_path/f'{ft}.parquet')