In [1]:
# https://www.ptsd.va.gov/understand_tx/meds_for_ptsd.asp

In [176]:
import sdv
import pandas as pd
from random import sample, choice, randint
import random
import uuid
import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel  
from scipy.spatial import distance

# Get ICD and prescription tables

In [177]:
icd = pd.read_csv("../mimic-iv-demo/hosp/d_icd_diagnoses.csv")
prescritpton = pd.read_csv("../mimic-iv-demo/hosp/prescriptions.csv")

# Create Population Table

size between 200-300,  ages are chosen randomly between of 22 and 88

In [178]:
## Personal demographics table

random.seed(9)

person = [{"person_id":randint(0,9999999999),"age":randint(22,88)} for y in range(0, 300)]

person_df = pd.DataFrame(person)
bins = [22,33, 44,55, 66,77,88]
labels = ["'22-33'","'33-44'","'44-55'","'55-66'","'66-77'","'77-88'"]

person_df['age_bracket'] = pd.cut(person_df['age'], bins=bins, labels=labels)
person_df['age_bracket'] = person_df['age_bracket'].astype(str)

print("POPULATION SIZE: ",len(person), "\nSAMPLE: ") 
print(person_df.head(n=5))


POPULATION SIZE:  300 
SAMPLE: 
    person_id  age age_bracket
0  5898329840   39     '33-44'
1  2906162014   65     '55-66'
2  6454399878   32     '22-33'
3  3006899318   70     '66-77'
4  9317563040   79     '77-88'


This data defines the VASRD codes and medication. Since medicaton and VA Disability benefits are 

In [179]:
vasrd = {
    "9412": "Panic disorder and/or agoraphobia",
    "9411": "Posttraumatic stress disorder",
    "6602": "Asthma"
}

medication_ptsd = [
    "Sertraline",
    "Paroxetine"
]

medication_panic = [
    "FLUoxetine",
    "Sertraline",
    "Lorazepam"
]

medication_asthma = [
    "PredniSONE",
    "Salmeterol Xinafoate Diskus (50 mcg)",
    "Albuterol 0.083% Neb Soln",
    "Fluticasone Propionate NASAL"
]

In [180]:
random.seed(4)
benefits_dim = [{"benefit_id":randint(0,9999999999),"vasrd":vasrd_row, "vasrd desc": vasrd[vasrd_row]} for vasrd_row in vasrd]
benefits_dim = pd.DataFrame(benefits_dim)
benefits_dim

Unnamed: 0,benefit_id,vasrd,vasrd desc
0,5308786135,9412,Panic disorder and/or agoraphobia
1,9033029319,9411,Posttraumatic stress disorder
2,5996024489,6602,Asthma


In [181]:
random.seed(3)
drug_list = list(set(prescritpton["drug"].str.lower().tolist()))
medication_dim = [{"drug_id":random.randint(0,9999999999),"drug":drug} for drug in drug_list]
medication_dim = pd.DataFrame(medication_dim)
medication_dim

Unnamed: 0,drug_id,drug
0,9611984893,cefpodoxime proxetil
1,2337446730,albumin 5% (25g / 500ml)
2,6888784125,acetazolamide sodium
3,8871378905,baclofen
4,7891869609,chlorpheniramine maleate
...,...,...
599,5682947920,magnesium sulfate
600,9956284806,doxycycline hyclate
601,8315705722,magnesium oxide
602,7882235805,0.83% sodium chloride


In [182]:
random.seed(5)
icd_list = icd["icd_code"].tolist()
diagnosis_dim = [{"diagnosis_id":randint(0,9999999999),"icd_code":icd_code} for icd_code in icd_list]
diagnosis_dim = pd.DataFrame(diagnosis_dim)
diagnosis_dim

Unnamed: 0,diagnosis_id,icd_code
0,6970309701,0090
1,7480918169,01160
2,4051686260,01186
3,2787324501,01200
4,3869338171,01236
...,...,...
109770,3506338458,Z88
109771,1071824235,Z89012
109772,4009172070,Z90410
109773,2624060216,Z948


In [183]:
random.seed(6)

n = 4
print("3/",n," of the population has a VA Disability rating")
segment = int(len(person_df["person_id"]) - (len(person_df["person_id"])/n))

benefits_id_list = benefits_dim["benefit_id"].to_list()

benefits_table = [{"person_id": person_id, "benefit_id":benefits_id_list[choice(range(0,len(benefits_id_list)))]} for person_id in person_df["person_id"][0:segment]]
benefits_table = benefits_table + [{"person_id": person_id, "benefit_id":"None"} for person_id in person_df["person_id"][segment:]]
benefits_table = pd.DataFrame(benefits_table)
benefits_table.head(n=10)

3/ 4  of the population has a VA Disability rating


Unnamed: 0,person_id,benefit_id
0,5898329840,5996024489
1,2906162014,5308786135
2,6454399878,9033029319
3,3006899318,9033029319
4,9317563040,5308786135
5,1814995991,5308786135
6,1022254636,5308786135
7,9158698095,5996024489
8,1250600339,5996024489
9,5196438545,9033029319


## Randomly sample the segmented population

The number of benefits to assign is chosen randomly. The population to be addigned medication is randomly sampled from the population with the correlated VA Disability rating. the medication is also chosen at random from the identified list of medicaitons. The rest of the population is sampled and assigned a random medication from the `prescription` table.

In [184]:
random.seed(7)

list_of_panic_population = [choice(benefits_table.loc[benefits_table["benefit_id"] == benefits_dim["benefit_id"].to_list()[0]]["person_id"].tolist()) for x in range(0, randint(40,60))]
list_of_ptsd_population = [choice(benefits_table.loc[benefits_table["benefit_id"] == benefits_dim["benefit_id"].to_list()[1]]["person_id"].tolist()) for x in range(0, randint(20,50))]
list_of_asthma_population = [choice(benefits_table.loc[benefits_table["benefit_id"] == benefits_dim["benefit_id"].to_list()[2]]["person_id"].tolist()) for x in range(0, randint(50,60))]
list_of_none_population = [choice(benefits_table.loc[benefits_table["benefit_id"] == "None"]["person_id"].tolist()) for x in range(0, randint(20,50))]

In [185]:
random.seed(8)

medication = [{"person_id":person_id, "drug_id":medication_dim.loc[medication_dim["drug"] == medication_ptsd[choice(range(0,len(medication_ptsd)))].lower()]["drug_id"].values[0]} for person_id in list_of_ptsd_population]+\
[{"person_id":person_id, "drug_id":medication_dim.loc[medication_dim["drug"] == medication_panic[choice(range(0,len(medication_panic)))].lower()]["drug_id"].values[0]} for person_id in list_of_panic_population]+\
[{"person_id":person_id, "drug_id":medication_dim.loc[medication_dim["drug"] == medication_asthma[choice(range(0,len(medication_asthma)))].lower()]["drug_id"].values[0]} for person_id in list_of_asthma_population]+\
[{"person_id":person_id, "drug_id":medication_dim.loc[choice(range(0,len(medication_dim)))]["drug_id"]} for person_id in list_of_none_population]

medication_df = pd.DataFrame(medication)

with pd.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also
    print(medication_df)

      person_id     drug_id
0    7203296599  5016170936
1    9844394925  4878949115
2    2882940332  4878949115
3    1586389040  5016170936
4    7404235780  5016170936
5    3524766129  5016170936
6    3338789861  5016170936
7     475050657  5016170936
8    2879345340  5016170936
9    7562317188  5016170936
10   5622189500  4878949115
11   1727308309  5016170936
12   1990296459  4878949115
13   5146124388  4878949115
14   4593520508  4878949115
15   3437173300  4878949115
16   8993275381  4878949115
17   1608139723  5016170936
18   3213880348  4878949115
19   5466667634  5016170936
20   5790571337  4878949115
21   7404235780  5016170936
22   5628259747  5016170936
23   3816180335  4878949115
24   9355053022  4878949115
25   1395967689  4878949115
26   2882940332  4878949115
27   1395967689  5016170936
28   4404021194  4878949115
29   3524766129  5016170936
30   9355053022  5016170936
31   2828729447  4878949115
32   8993275381  4878949115
33   9495272472  5016170936
34   3095391157  501

# ICD disease table

In [186]:
icd_panic = [
    {"icd_code":"F400"},
    {"icd_code":"F4001"},
    {"icd_code":"F4000"},
    {"icd_code":"F4002"}
]
icd_ptsd = [
    {"icd_code":"F431"},
    {"icd_code":"F4312"},
    {"icd_code":"F4310"},
    {"icd_code":"F4311"}
]
icd_asthma = [
    {"icd_code":"J4599"},
    {"icd_code":"J45901"},
    {"icd_code":"J4590"},
    {"icd_code":"J45998"}
]

In [187]:
random.seed(2)

list_of_panic_population = [choice(benefits_table.loc[benefits_table["benefit_id"] == benefits_dim["benefit_id"].to_list()[0]]["person_id"].tolist()) for x in range(0, randint(40,60))]
list_of_ptsd_population = [choice(benefits_table.loc[benefits_table["benefit_id"] == benefits_dim["benefit_id"].to_list()[1]]["person_id"].tolist()) for x in range(0, randint(20,50))]
list_of_asthma_population = [choice(benefits_table.loc[benefits_table["benefit_id"] == benefits_dim["benefit_id"].to_list()[2]]["person_id"].tolist()) for x in range(0, randint(50,60))]
list_of_none_population = [choice(benefits_table.loc[benefits_table["benefit_id"] == "None"]["person_id"].tolist()) for x in range(0, randint(20,50))]

In [188]:
random.seed(1)

diagnosis = [{"person_id":person, "diagnosis_id":diagnosis_dim.loc[diagnosis_dim["icd_code"] == icd_ptsd[choice(range(0,len(icd_ptsd)))]["icd_code"]]["diagnosis_id"].values[0]} for person in list_of_ptsd_population]+\
[{"person_id":person, "diagnosis_id":diagnosis_dim.loc[diagnosis_dim["icd_code"] == icd_panic[choice(range(0,len(icd_panic)))]["icd_code"]]["diagnosis_id"].values[0]} for person in list_of_panic_population]+\
[{"person_id":person, "diagnosis_id":diagnosis_dim.loc[diagnosis_dim["icd_code"] == icd_asthma[choice(range(0,len(icd_asthma)))]["icd_code"]]["diagnosis_id"].values[0]} for person in list_of_asthma_population]+\
[{"person_id":person, "diagnosis_id":diagnosis_dim.loc[choice(range(0,len(diagnosis_dim)))]["diagnosis_id"]} for person in list_of_none_population]

diagnosis_df = pd.DataFrame(diagnosis)
diagnosis_df.head(n=10)

Unnamed: 0,person_id,diagnosis_id
0,3213880348,1072016970
1,2828729447,5154431470
2,1602635893,7747525874
3,9844394925,5154431470
4,592307545,9559604349
5,5710055747,9559604349
6,6278166942,9559604349
7,8071846395,9559604349
8,592307545,1072016970
9,6515193184,5154431470


In [189]:
person_df.to_csv("demographics_table.csv", index=False)
benefits_table.to_csv("benefits_table.csv", index=False)
medication_df.to_csv("medication_table.csv", index=False)
diagnosis_df.to_csv("diagnosis_table.csv", index=False)

diagnosis_dim.to_csv("diagnosis_dim.csv", index=False)
medication_dim.to_csv("medication_dim.csv", index=False)
benefits_dim.to_csv("benefits_dim.csv", index=False)

In [175]:
merged_df = person_df.merge(benefits_df, how="left", on="person_id")
merged_df= merged_df.merge(medication_df, how="left", on="person_id")
merged_df = merged_df.merge(diagnosis_df, how="left", on="person_id")

KeyError: 'person_id'

## Gaussian Synthesizer 

In [102]:
from sdv.metadata import SingleTableMetadata
metadata= {
    "columns": {
        "person_id": {
            "sdtype": "categorical"
        },
        "age": {
            "sdtype": "categorical"
        },
        "vasrd": {
            "sdtype": "categorical"
        },
        "drug": {
            "sdtype": "categorical"
        },
        "diagnosis": {
            "sdtype": "categorical"
        }
    },
    "METADATA_SPEC_VERSION": "SINGLE_TABLE_V1"
}
metadata_obj = SingleTableMetadata.load_from_dict(metadata)

In [155]:
from sdv.single_table import TVAESynthesizer

synthesizer = TVAESynthesizer(metadata_obj,epochs=500)
synthesizer.fit(merged_df)

synthetic_data = synthesizer.sample(num_rows=200)

In [156]:
synthetic_data

Unnamed: 0,person_id,age,vasrd,drug,diagnosis
0,7686347231,79,9411 Posttraumatic stress disorder,,
1,2155788657,23,9412 Panic disorder and/or agoraphobia,fluoxetine,F4001
2,1740462140,61,6602 Asthma,,
3,2155788657,73,6602 Asthma,fluoxetine,F4001
4,9313862409,41,9412 Panic disorder and/or agoraphobia,sertraline,
...,...,...,...,...,...
195,7228541575,42,9412 Panic disorder and/or agoraphobia,fluoxetine,F4001
196,7791492802,79,9411 Posttraumatic stress disorder,,
197,8466690470,47,,,
198,1233939245,34,9411 Posttraumatic stress disorder,,F431


In [157]:
from sdv.evaluation.single_table import run_diagnostic, evaluate_quality
from sdv.evaluation.single_table import get_column_plot
real_data = merged_df
# 1. perform basic validity checks
diagnostic = run_diagnostic(real_data, synthetic_data, metadata_obj)

# 2. measure the statistical similarity
quality_report = evaluate_quality(real_data, synthetic_data, metadata_obj)

# 3. plot the data
fig = get_column_plot(
    real_data=real_data,
    synthetic_data=synthetic_data,
    metadata=metadata_obj,
    column_name='drug'
)

fig.show()

Generating report ...

(1/2) Evaluating Data Validity: |██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 1054.80it/s]|
Data Validity Score: 100.0%

(2/2) Evaluating Data Structure: |██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 444.12it/s]|
Data Structure Score: 100.0%

Overall Score (Average): 100.0%

Generating report ...

(1/2) Evaluating Column Shapes: |███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 908.45it/s]|
Column Shapes Score: 68.57%

(2/2) Evaluating Column Pair Trends: |█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 19.79it/s]|
Column Pair Trends Score: 44.67%

Overall Score (Average): 56.62%



In [158]:
fig = get_column_plot(
    real_data=real_data,
    synthetic_data=synthetic_data,
    metadata=metadata_obj,
    column_name='age'
)
fig.show()

In [159]:
fig = get_column_plot(
    real_data=real_data,
    synthetic_data=synthetic_data,
    metadata=metadata_obj,
    column_name='vasrd'
)
fig.show()

In [160]:

fig = get_column_plot(
    real_data=merged_df,
    synthetic_data=synthetic_data,
    metadata=metadata_obj,
    column_name='diagnosis'
)

fig.show()

In [161]:
df = synthetic_data.merge(icd,right_on="icd_code",left_on="diagnosis")
df[["icd_code","long_title"]]


Unnamed: 0,icd_code,long_title
0,F4001,Agoraphobia with panic disorder
1,F4001,Agoraphobia with panic disorder
2,J4599,Other asthma
3,F4311,"Post-traumatic stress disorder, acute"
4,J45998,Other asthma
...,...,...
88,F4001,Agoraphobia with panic disorder
89,F4310,"Post-traumatic stress disorder, unspecified"
90,F4001,Agoraphobia with panic disorder
91,F431,Post-traumatic stress disorder (PTSD)


In [162]:
df = synthetic_data
with pd.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also
    print((df.groupby(['vasrd','diagnosis']).count() / df.groupby('vasrd').count())['person_id'])

vasrd                                   diagnosis
6602 Asthma                             F4001        0.061224
                                        F4002        0.040816
                                        J4590        0.183673
                                        J45901       0.020408
                                        J4599        0.081633
                                        J45998       0.408163
9411 Posttraumatic stress disorder      F431         0.041667
                                        F4310        0.020833
                                        F4311        0.250000
                                        F4312        0.041667
                                        J4599        0.020833
                                        h47032       0.020833
                                        m27          0.020833
                                        w2101xa      0.041667
9412 Panic disorder and/or agoraphobia  94120        0.019608
                    

In [163]:
df = merged_df
with pd.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also
    print((df.groupby(['vasrd','diagnosis']).count() / df.groupby('vasrd').count())['person_id'])

vasrd                                   diagnosis
6602 Asthma                             6825         0.007812
                                        J4590        0.171875
                                        J45901       0.148438
                                        J4599        0.187500
                                        J45998       0.250000
                                        a922         0.007812
                                        h59212       0.007812
                                        i69234       0.007812
                                        k051         0.015625
                                        m65132       0.007812
                                        s60446d      0.007812
                                        s61439a      0.007812
                                        s62115g      0.015625
                                        s72112f      0.015625
                                        s91254       0.007812
                    

In [164]:
df = synthetic_data
with pd.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also
    print((df.groupby(['vasrd','drug']).count() / df.groupby('vasrd').count())['person_id'])

vasrd                                   drug                                    
6602 Asthma                             albuterol 0.083% neb soln                   0.040816
                                        fluoxetine                                  0.081633
                                        fluticasone propionate nasal                0.102041
                                        lorazepam                                   0.040816
                                        prednisone                                  0.367347
                                        salmeterol xinafoate diskus (50 mcg)        0.040816
                                        sertraline                                  0.020408
9411 Posttraumatic stress disorder      paroxetine                                  0.020833
                                        sertraline                                  0.354167
9412 Panic disorder and/or agoraphobia  fluoxetine                                

In [166]:
df = merged_df
with pd.option_context('display.max_rows', None, 'display.max_columns', None):  # more options can be specified also
    print((df.groupby(['diagnosis','drug']).count() / df.groupby('diagnosis').count())['person_id'])

diagnosis  drug                                 
2865       magnesium oxide                          1.000000
43850      lorazepam                                1.000000
475        dobutamine                               0.500000
           idarubicin                               0.500000
F400       fluoxetine                               0.333333
           lorazepam                                0.166667
           sertraline                               0.166667
F4000      fluoxetine                               0.300000
           lorazepam                                0.150000
           sertraline                               0.300000
F4001      fluoxetine                               0.210526
           lorazepam                                0.157895
           sertraline                               0.263158
F4002      fluoxetine                               0.150000
           lorazepam                                0.250000
           sertraline               