In [1]:
import multiprocessing as mp
import os
import pickle
import string
import sys
from functools import partial
from typing import *

import numpy as np
import pandas as pd
from tqdm.autonotebook import tqdm
from tqdm.autonotebook import tqdm


In [2]:
df_patients = pd.read_stata('../.data/ADRIAN_2018/patients_active_revised.dta')
df_script_item = pd.read_stata('../.data/ADRIAN_2018/script_item_active_WITHOUT_gold_standard_revised.dta')
df_reason_prescription = pd.read_stata('../.data/ADRIAN_2018/reason_prescription_active_WITHOUT_gold_standard_revised.dta')
df_pathology = pd.read_stata('../.data/ADRIAN_2018/pathology_active_revised.dta')
df_immunisations = pd.read_stata('../.data/ADRIAN_2018/immunisations_active_revised.dta')
df_diagnosis_encounter = pd.read_stata('../.data/ADRIAN_2018/diagnosis_encounter_active_WITHOUT_gold_standard_revised.dta')
df_observations = pd.read_stata('../.data/ADRIAN_2018/observations_active_WITHOUT_gold_standard_revised.dta')

# Columns to keep per table: the fields we need, plus whichever of the date/label
# columns in `dates` that table happens to have.
dates = set(['visit_date', 'result_date', 'given_date', 'observation_date', 'hpt', 'hpt_date'])
df_patients_columns = set(['patientid', 'aborig', 'gender2', 'smoke', 'irsad', 'ieo', 'ieo_q', 'irsad_q']) | (set(df_patients.columns) & dates)
df_script_item_columns = set(['patientid', 'dose', 'frequency', 'med_active_ingr', 'med_name', 'quantity', 'repeats', 'strength']) | (set(df_script_item.columns) & dates)
df_reason_prescription_columns = set(['patientid', 'first_name_presc', 'first_ingred_presc', 'reason']) | (set(df_reason_prescription.columns) & dates)
df_pathology_columns = set(['patientid', 'result_name', 'result_value']) | (set(df_pathology.columns) & dates)
df_immunisations_columns = set(['patientid', 'vaccine_name']) | (set(df_immunisations.columns) & dates)
df_observations_columns = set(['patientid', 'observation_name', 'observation_value']) | (set(df_observations.columns) & dates)
df_diagnosis_encounter_columns = set(['patientid', 'reason']) | (set(df_diagnosis_encounter.columns) & dates)


def _assert_has_columns(df: pd.DataFrame, required: set, table_name: str) -> None:
    """Raise ``ValueError`` naming any column in ``required`` that ``df`` lacks."""
    # A real subset test: `set(df.columns) and cols` evaluates to `cols` itself,
    # so comparing it with `set(cols)` could never fail.
    missing = set(required) - set(df.columns)
    if missing:
        raise ValueError(f'{table_name} is missing columns: {sorted(missing)}')


for _table_name, _table, _required in [
    ('patients', df_patients, df_patients_columns),
    ('script_item', df_script_item, df_script_item_columns),
    ('reason_prescription', df_reason_prescription, df_reason_prescription_columns),
    ('pathology', df_pathology, df_pathology_columns),
    ('immunisations', df_immunisations, df_immunisations_columns),
    ('observations', df_observations, df_observations_columns),
    ('diagnosis_encounter', df_diagnosis_encounter, df_diagnosis_encounter_columns),
]:
    _assert_has_columns(_table, _required, _table_name)

# Only have columns I want
df_patients = df_patients.drop(columns=set(df_patients.columns) - df_patients_columns)
df_script_item = df_script_item.drop(columns=set(df_script_item.columns) - df_script_item_columns)
df_reason_prescription = df_reason_prescription.drop(columns=set(df_reason_prescription.columns) - df_reason_prescription_columns)
df_pathology = df_pathology.drop(columns=set(df_pathology.columns) - df_pathology_columns)
df_immunisations = df_immunisations.drop(columns=set(df_immunisations.columns) - df_immunisations_columns)
df_observations = df_observations.drop(columns=set(df_observations.columns) - df_observations_columns)
df_diagnosis_encounter = df_diagnosis_encounter.drop(columns=set(df_diagnosis_encounter.columns) - df_diagnosis_encounter_columns)



In [3]:
# (cleanup) Give each table's event-date column the shared name `field_date`, so
# later code can sort/group on one key whatever table a record came from.
_VISIT_TO_FIELD = {'visit_date': 'field_date'}
df_script_item = df_script_item.rename(_VISIT_TO_FIELD, axis='columns')
df_reason_prescription = df_reason_prescription.rename(_VISIT_TO_FIELD, axis='columns')
df_diagnosis_encounter = df_diagnosis_encounter.rename(_VISIT_TO_FIELD, axis='columns')
df_pathology = df_pathology.rename({'result_date': 'field_date'}, axis='columns')
df_immunisations = df_immunisations.rename({'given_date': 'field_date'}, axis='columns')
df_observations = df_observations.rename({'observation_date': 'field_date'}, axis='columns')

In [4]:
# Re-derive the kept-column sets now that every date column is named `field_date`,
# and verify each table really has them.
dates = set(['field_date', 'hpt', 'hpt_date'])
df_patients_columns = set(['patientid', 'aborig', 'gender2', 'smoke', 'irsad', 'ieo', 'ieo_q', 'irsad_q']) | (set(df_patients.columns) & dates)
df_script_item_columns = set(['patientid', 'dose', 'frequency', 'med_active_ingr', 'med_name', 'quantity', 'repeats', 'strength']) | (set(df_script_item.columns) & dates)
df_reason_prescription_columns = set(['patientid', 'first_name_presc', 'first_ingred_presc', 'reason']) | (set(df_reason_prescription.columns) & dates)
df_pathology_columns = set(['patientid', 'result_name', 'result_value']) | (set(df_pathology.columns) & dates)
df_immunisations_columns = set(['patientid', 'vaccine_name']) | (set(df_immunisations.columns) & dates)
df_observations_columns = set(['patientid', 'observation_name', 'observation_value']) | (set(df_observations.columns) & dates)
df_diagnosis_encounter_columns = set(['patientid', 'reason']) | (set(df_diagnosis_encounter.columns) & dates)

# Subset test per table. (`set(a) and b` evaluates to `b`, so the old
# `(set(df.columns) and cols) == set(cols)` comparison could never be False.)
_missing_columns = {}
for _table_name, _table, _required in [
    ('patients', df_patients, df_patients_columns),
    ('script_item', df_script_item, df_script_item_columns),
    ('reason_prescription', df_reason_prescription, df_reason_prescription_columns),
    ('pathology', df_pathology, df_pathology_columns),
    ('immunisations', df_immunisations, df_immunisations_columns),
    ('observations', df_observations, df_observations_columns),
    ('diagnosis_encounter', df_diagnosis_encounter, df_diagnosis_encounter_columns),
]:
    _absent = set(_required) - set(_table.columns)
    if _absent:
        _missing_columns[_table_name] = sorted(_absent)

if _missing_columns:
    raise ValueError(f'Tables missing required columns: {_missing_columns}')

# Displays True once every table has passed.
not _missing_columns

True

In [5]:
def patient__get_patient_entry(patient_id: str) -> List[dict]:
    # get dataframe and only this patientid
    df = df_patients[
        (df_patients["patientid"] == patient_id)
    ].copy()

    # for every sample in dataset
    dataset_entries = []
    for idx, dataset_row in df.iterrows():
        dataset_entries.append(dataset_row.to_json())

    return dataset_entries

def script_item_active__get_patient_entry(patient_id: str) -> List[dict]:
    # get dataframe and only this patientid
    df = df_script_item[
        (df_script_item["patientid"] == patient_id)
    ].copy()
    
    # for every sample in dataset
    dataset_entries = []
    for idx, dataset_row in df.iterrows():
        dataset_entries.append(dataset_row.to_json())

    return dataset_entries


def reason_prescription_active__get_patient_entry(patient_id: str) -> List[dict]:
    # get dataframe and only this patientid
    df = df_reason_prescription[
        (df_reason_prescription["patientid"] == patient_id)
    ].copy()

    # for every sample in dataset
    dataset_entries = []
    for idx, dataset_row in df.iterrows():
        dataset_entries.append(dataset_row.to_json())

    return dataset_entries


def pathology_result2015_2017_active__get_patient_entry(patient_id: str) -> List[dict]:
    # get dataframe and only this patientid
    df = df_pathology[
        (
            df_pathology["patientid"]
            == patient_id
        )
    ].copy()

    # for every sample in dataset
    dataset_entries = []
    for idx, dataset_row in df.iterrows():
        dataset_entries.append(dataset_row.to_json())

    return dataset_entries


def immunisation_active__get_patient_entry(patient_id: str) -> List[dict]:
    # get dataframe and only this patientid
    df = df_immunisations[
        (
            df_immunisations["patientid"]
            == patient_id
        )
    ].copy()

    # for every sample in dataset
    dataset_entries = []
    for idx, dataset_row in df.iterrows():
        dataset_entries.append(dataset_row.to_json())

    return dataset_entries


def observation_active__get_patient_entry(patient_id: str) -> List[dict]:
    # get dataframe and only this patientid
    df = df_observations[
        (
            df_observations["patientid"]
            == patient_id
        )
    ].copy()

    # for every sample in dataset
    dataset_entries = []
    for idx, dataset_row in df.iterrows():
        dataset_entries.append(dataset_row.to_json())

    return dataset_entries


def encounter_reason_active__get_patient_entry(patient_id: str) -> List[dict]:
    # get dataframe and only this patientid
    df = df_diagnosis_encounter[
        (
            df_diagnosis_encounter["patientid"]
            == patient_id
        )
    ].copy()

    # for every sample in dataset
    dataset_entries = []
    for idx, dataset_row in df.iterrows():
        dataset_entries.append(dataset_row.to_json())

    return dataset_entries

# Maps the `hpt` answer to the integer class label stored on each patient sample.
diag_conversion = dict(Yes=1, No=0, Maybe=-1)

def generate_patient_data(active_patient_row: pd.core.series.Series) -> dict:
    """
    Build one patient's raw sample from a single row (a Series) of patient data.

    The row must expose ``patientid`` and ``hpt``. The returned dict has, in order:
    ``id`` (the patient id), ``data`` (per-table lists of JSON-string rows for
    that patient) and ``label`` (``diag_conversion[hpt]``).
    """
    pid = active_patient_row.patientid
    label = diag_conversion[active_patient_row.hpt]

    # (section name, per-table row getter), in the order sections are stored.
    section_getters = (
        ("patient", patient__get_patient_entry),
        ("scripts", script_item_active__get_patient_entry),
        ("prescriptions", reason_prescription_active__get_patient_entry),
        ("pathology", pathology_result2015_2017_active__get_patient_entry),
        ("immunisations", immunisation_active__get_patient_entry),
        ("observations", observation_active__get_patient_entry),
        ("encounter_reasons", encounter_reason_active__get_patient_entry),
    )

    data = {}
    for section, getter in section_getters:
        data[section] = getter(pid)

    return {"id": pid, "data": data, "label": label}

In [1]:
import json
import pathlib
# Cache of the generated per-patient corpus (written by the generation cell below).
medical_corpus_pathlib = pathlib.Path('../.data') / 'processed_medical_corpus.json'

In [2]:
# Resume support: start from whatever corpus a previous run already wrote to disk.
if not medical_corpus_pathlib.is_file():
    results = []
    print('not file')
else:
    with medical_corpus_pathlib.open('r') as corpus_file:
        results = json.load(corpus_file)
    print(len(results))

735333


In [30]:
multiprocess = True
# Worker processes used when `multiprocess` is True.
n_workers = 12

def f(job) -> dict:
    """Pool worker: unpack an ``(index, row)`` pair from ``iterrows`` and build that patient's sample."""
    _index, row = job
    return generate_patient_data(row)

if __name__ == "__main__":
    # Resume by position: skip the first len(results) patients. This is only correct
    # if `results` is a prefix of df_patients in row order. That is why the pool below
    # uses the order-preserving `imap`: with `imap_unordered`, an interrupted run
    # would leave an arbitrary subset, and re-running would skip or duplicate patients.
    job_generator = list(df_patients[['patientid', 'hpt']][len(results):].iterrows())

    if multiprocess:
        # NOTE(review): workers need the notebook-defined `f` and the df_* globals,
        # which in practice requires the 'fork' start method (Linux).
        with mp.Pool(n_workers) as pool:
            for result in tqdm(pool.imap(f, job_generator, chunksize=8), total=len(job_generator)):
                results.append(result)
    else:
        for _row_idx, patient in tqdm(job_generator, total=len(job_generator)):
            results.append(generate_patient_data(patient))

    # Distinct handle name so the worker function `f` is not rebound to a file object.
    with open(medical_corpus_pathlib, 'w') as corpus_file:
        json.dump(results, corpus_file)

  0%|          | 0/685333 [00:00<?, ?it/s]

# Building train/val & schema

In [31]:
import json

# Reload the full corpus from disk so everything below starts from the saved file.
with medical_corpus_pathlib.open('r') as corpus_file:
    medical_dataset = json.load(corpus_file)
print(len(medical_dataset))

735333


## Raw Data sanitisation
For each sample, take its data and parse every dataset's rows from JSON. Within each dataset, group records that share the same date (so that same-day records can later be shuffled), and order the groups from the earliest record up to either the latest record or, for hypertensive patients, the hypertension diagnosis date (exclusive).

From every record, remove the keys `field_date`, `hpt`, `hpt_date` and `patientid`, and check whether anything else needs cleaning.

In [32]:
from loguru import logger
from itertools import groupby
from operator import itemgetter
from typing import Dict, Any, List, Set
from collections import OrderedDict
import dataclasses
from src.datasets.medical_data import PatientData

Json = Dict[str, Any]

class RemoveKeysAndCastToString:
    """
    Callable that drops a fixed set of keys from a record and casts every
    remaining value to ``str``. The record's key order is preserved.
    """

    def __init__(self, disallowed_keys: Set[str]):
        # Keys that must not appear in the output record.
        self.disallowed_keys = disallowed_keys

    def __call__(self, dic: dict) -> OrderedDict:
        return OrderedDict(
            (k, str(v)) for k, v in dic.items() if k not in self.disallowed_keys
        )


def make_time_grouped_data(data: List[str]) -> Optional[Dict[Any, List[Json]]]:
    """
    Parse one table's JSON-string rows for a patient and bucket them by date.

    Records that share a ``field_date`` go into one bucket, so same-day records can
    later be shuffled among themselves. Buckets are ordered from earliest to latest
    date. If the first record has no ``field_date``, all records go into a single
    bucket keyed ``0``.

    Sanitisation:
      * if ``hpt`` is ``'Yes'``, only buckets whose key is strictly less than
        ``hpt_date`` are kept. This includes the undated bucket ``0``, which is kept
        only when ``0 < hpt_date``.
      * ``patientid``, ``field_date``, ``hpt`` and ``hpt_date`` are removed from
        every record, and the remaining values are cast to ``str``.

    Args:
        data: rows as produced by ``pandas.Series.to_json`` (one JSON object each).

    Returns:
        An ``OrderedDict`` of bucket key -> list of sanitised records, or ``None``
        when ``data`` is empty or nothing survives the diagnosis-date filter.
    """
    parsed_data = [json.loads(x) for x in data]
    if not parsed_data:
        return None

    if 'field_date' in parsed_data[0]:
        # groupby merges only *adjacent* equal keys, so sort first. The sort is
        # stable, so same-day records keep their original relative order.
        # NOTE(review): a record whose field_date is null (None) would make this sort raise TypeError.
        parsed_data.sort(key=itemgetter('field_date'))
        time_ordered_and_grouped = OrderedDict(
            (date, list(group))
            for date, group in groupby(parsed_data, key=itemgetter('field_date'))
        )
    else:
        time_ordered_and_grouped = OrderedDict([(0, parsed_data)])

    # Sanitise: drop everything recorded on or after the hypertension diagnosis.
    # hpt/hpt_date are read from the first record only (presumably constant per patient).
    if parsed_data[0].get('hpt') == 'Yes':
        diagnosis_date = parsed_data[0].get('hpt_date')
        time_ordered_and_grouped = OrderedDict(
            (time, group)
            for time, group in time_ordered_and_grouped.items()
            if time < diagnosis_date
        )

    if not time_ordered_and_grouped:
        return None

    # Sanitise: remove keys we don't want the model to see.
    key_remover = RemoveKeysAndCastToString({'patientid', 'field_date', 'hpt', 'hpt_date'})
    return OrderedDict(
        (time, [key_remover(record) for record in group])
        for time, group in time_ordered_and_grouped.items()
    )


def dfs_unpack_json(
    parent: Any, branch: str = "", separator: str = "."
) -> Generator[Tuple[str, str], None, None]:
    """
    Depth-first walk of nested dicts/lists/tuples, yielding ``(path, str(leaf))``
    for every leaf in encounter order.

    String dict keys extend ``path`` (joined with ``separator``). Integer dict keys
    and list/tuple positions do not. Every other value, including ``None``, is a leaf.
    """
    if type(parent) in (list, tuple):
        # Exact-type check: list/tuple subclasses are treated as leaves.
        children = ((branch, child) for child in parent)
    elif isinstance(parent, dict):
        # Integer keys are date-grouping buckets, not field names, so they leave the path unchanged.
        children = (
            (branch if isinstance(key, int) else (branch + separator if branch else branch) + key, child)
            for key, child in parent.items()
        )
    else:
        yield branch, str(parent)
        return

    for child_branch, child in children:
        yield from dfs_unpack_json(parent=child, branch=child_branch, separator=separator)


def unpack_flattened_keys(sample):
    """Return the set of distinct leaf paths produced by ``dfs_unpack_json(sample)``."""
    return {path for path, _leaf in dfs_unpack_json(sample)}



sanitised_medical_data = []
all_medical_data_keys = set()
for sample in tqdm(medical_dataset, total=len(medical_dataset)):
    patient_id = sample.get('id')
    patient_label = sample.get('label')

    # Per table: date-bucketed, sanitised records, or None when the table has
    # nothing usable for this patient.
    data = OrderedDict((k, make_time_grouped_data(v)) for k, v in sample.get('data').items())

    if all(table_data is None for table_data in data.values()):
        # Explicit raise rather than `assert False`, so the check survives `python -O`.
        raise ValueError(f'No data parsed for patient {patient_id}')

    # Accumulate every flattened field path seen so far; this becomes the schema.
    all_medical_data_keys |= unpack_flattened_keys(data)

    sanitised_medical_data.append(PatientData(
        uid=patient_id,
        label=patient_label,
        data=data,
    ))

  0%|          | 0/735333 [00:00<?, ?it/s]

In [33]:
import random
# Seeded RNG so the shuffles, and therefore the train/val split built from them,
# are reproducible across runs.
SPLIT_SEED = 0
split_rng = random.Random(SPLIT_SEED)

positives = [patient_data for patient_data in sanitised_medical_data if patient_data.human_readable_label == "Yes"]
negatives = [patient_data for patient_data in sanitised_medical_data if patient_data.human_readable_label == "No"]
maybies = [patient_data for patient_data in sanitised_medical_data if patient_data.human_readable_label == "Maybe"]

for label_bucket in (positives, negatives, maybies):
    split_rng.shuffle(label_bucket)

len(positives), len(negatives), len(maybies)

(59016, 626682, 49635)

In [34]:
import itertools
from pathlib import Path
TRAIN_DATASET_FILEPATH = '../.data/medical_data_train_with_metadata'
VAL_DATASET_FILEPATH = '../.data/medical_data_val_with_metadata'

# 80/20 split taken separately per class, so train and val have roughly the same
# positive/negative ratio. "Maybe" patients are not used in either split.
num_positives = int(0.8 * len(positives))
num_negatives = int(0.8 * len(negatives))
train_dataset = [*positives[:num_positives], *negatives[:num_negatives]]
val_dataset = [*positives[num_positives:], *negatives[num_negatives:]]

## Schema
We want each field to come with a token that represents the dataset start and end, the field, and field_runes_padded.

In [35]:
# Every flattened field path seen across the corpus, sorted for a stable ordering.
schema = {'schema': sorted(all_medical_data_keys)}

# Written relative to the current working directory.
with open('schema.json', 'w') as schema_file:
    json.dump(schema, schema_file)

## Saving samples

In [36]:
import torch
def write_data_to_disk(data: PatientData, file_path: pathlib.Path) -> bool:
    """
    Save ``data`` to ``file_path`` with ``torch.save``, creating any missing parent
    directories first. Returns ``True``; any failure surfaces as an exception.
    """
    target_dir = file_path.parent
    target_dir.mkdir(exist_ok=True, parents=True)
    torch.save(data, file_path)
    return True

In [37]:
import random
import itertools
from pathlib import Path
TRAIN_DATASET_FILEPATH = '../.data/medical_data_train_with_metadata'
VAL_DATASET_FILEPATH = '../.data/medical_data_val_with_metadata'

train_dataset = list(itertools.chain(positives[:num_positives], negatives[:num_negatives]))
val_dataset = list(itertools.chain(positives[num_positives:], negatives[num_negatives:]))


def _write_split(samples, root: str) -> None:
    """Save each sample to ``<root>/class<label>/sample_patientid_<uid>.pt``."""
    # Existing files are overwritten per patient but never deleted. If an earlier run
    # used a different shuffle, a patient could end up in both the train and val dirs.
    if next(Path(root).glob('class*/*.pt'), None) is not None:
        logger.warning(f'{root} already contains .pt files; files from an earlier split are not removed')

    for sample in tqdm(samples, total=len(samples)):
        filename = Path(root) / f"class{sample.label}" / f"sample_patientid_{sample.uid}.pt"
        write_data_to_disk(data=sample, file_path=filename)


for _split_samples, _split_root in (
    (train_dataset, TRAIN_DATASET_FILEPATH),
    (val_dataset, VAL_DATASET_FILEPATH),
):
    _write_split(_split_samples, _split_root)

  0%|          | 0/548557 [00:00<?, ?it/s]

  0%|          | 0/137141 [00:00<?, ?it/s]

In [25]:
import glob
import pathlib

sum(1 for _ in pathlib.Path('../.data/medical_data_train_with_metadata/class0/').glob('*.pt'))

34108

In [26]:
sum(1 for _ in pathlib.Path('../.data/medical_data_train_with_metadata/class1/').glob('*.pt'))

3170

In [27]:
sum(1 for _ in pathlib.Path('../.data/medical_data_val_with_metadata/class0/').glob('*.pt'))

8528

In [28]:
sum(1 for _ in pathlib.Path('../.data/medical_data_val_with_metadata/class1/').glob('*.pt'))

793

In [12]:
import torch
import glob
import dataclasses
from src.datasets.medical_data import PatientData

# Spot-check a handful of saved training samples.
# NOTE(review): this prints full patient records; clear the cell output before
# sharing or committing the notebook.
DEBUG_PATIENT_IDS = [1611489, 1509283, 2569768, 2608285, 2650036, 1638824, 2783447, 1998712,
        3016693,  715960, 2445458, 2864352, 2088872,  516763, 1455156, 1497152,
        1362617, 2729893, 2546927, 2162700, 1028866, 1862622,  696967, 2509926,
        3162513,  856450,  774542, 1270402, 2593633, 1995239, 1564290,  568979]

for patient_id in DEBUG_PATIENT_IDS:
    # Glob once and reuse the result (previously the same pattern was globbed twice).
    matches = glob.glob(f'../.data/medical_data_train_with_metadata/*/sample_patientid_{patient_id}.pt')
    print(matches)
    for sample_path in matches:
        # These files hold pickled objects written by this notebook.
        # NOTE(review): newer torch releases default torch.load to weights_only=True,
        # which may refuse to unpickle PatientData. If so, pass weights_only=False
        # (only for trusted, locally produced files).
        print(torch.load(sample_path))

['../.data/medical_data_train_with_metadata/class1/sample_patientid_1611489.pt']
PatientData(uid=1611489, label=1, data=OrderedDict([('patient', OrderedDict([(0, [OrderedDict([('aborig', 'Not stated'), ('gender2', 'Male'), ('smoke', 'Non smoker'), ('irsad', '9'), ('ieo', '8.0'), ('ieo_q', '2nd upper quintile'), ('irsad_q', 'Upper quintile')])])])), ('scripts', OrderedDict([(1461283200000, [OrderedDict([('dose', '1'), ('frequency', '5'), ('med_active_ingr', 'TEMAZEPAM'), ('med_name', 'NORMISON'), ('quantity', '25'), ('repeats', '0'), ('strength', '10mg')]), OrderedDict([('dose', '½'), ('frequency', '2'), ('med_active_ingr', 'ESCITALOPRAM OXALATE'), ('med_name', 'LEXAPRO'), ('quantity', '28'), ('repeats', '0'), ('strength', '10mg')])])])), ('prescriptions', OrderedDict([(1461283200000, [OrderedDict([('first_name_presc', 'NORMISON'), ('first_ingred_presc', 'TEMAZEPAM'), ('reason', '')])])])), ('pathology', None), ('immunisations', None), ('observations', OrderedDict([(1461283200000, [Orde