In [24]:
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [25]:
import pandas as pd
import numpy as np

import commonfunc

In [26]:
import os

# Root of the project data directory. Set the SYNTH_DATA_DIR environment variable
# to point elsewhere instead of editing this machine-specific default.
data_set_path = os.environ.get(
    'SYNTH_DATA_DIR', 'C:/Users/shrus/Documents/Synthetic-data-generation/'
)
# Later cells build paths by plain string concatenation, so make sure the
# directory path ends with a separator.
if not data_set_path.endswith(('/', '\\')):
    data_set_path += '/'

In [27]:
# Load the prescriptions extract through the project helper
# (presumably strips a CSV row-id column, per its name -- confirm in commonfunc).
prescriptions_csv_path = data_set_path + "sampled_data_csv_100/prescriptions.csv"
prescription_df = commonfunc.read_csv_no_rowid(prescriptions_csv_path)

In [28]:
# Down-sample to a fixed-size subset. random_state makes the chosen rows
# reproducible across kernel restarts; capping n avoids the ValueError that
# DataFrame.sample raises when n exceeds the number of available rows.
SAMPLE_SIZE = 1400
RANDOM_SEED = 42
prescription_df = prescription_df.sample(
    n=min(SAMPLE_SIZE, len(prescription_df)),
    random_state=RANDOM_SEED,
)

In [29]:
# Preview five rows of the (shuffled) sample.
prescription_df.head(n=5)

Unnamed: 0,subject_id,hadm_id,icustay_id,startdate,enddate,drug_type,drug,drug_name_poe,drug_name_generic,formulary_drug_cd,gsn,ndc,prod_strength,dose_val_rx,dose_unit_rx,form_val_disp,form_unit_disp,route
8243,27074,181147,228048.0,2179-12-10,2179-12-14,MAIN,Guaifenesin,Guaifenesin,Guaifenesin,GUAI10,759.0,121174400.0,10mL Cup,5-10,mL,0.5-1,UDCUP,PO
5923,1709,127294,207018.0,2118-01-06,2118-01-06,MAIN,PredniSONE,PredniSONE,PredniSONE,PRED20,6751.0,54001820.0,20 mg Tablet,60,mg,3,TAB,PO
7216,15728,193047,289569.0,2175-01-21,2175-01-21,MAIN,Dolasetron Mesylate,Dolasetron Mesylate,Dolasetron Mesylate,DOLA12.5I,50268.0,88120810.0,12.5mg Vial,12.5,mg,1,VIAL,IV
4453,11816,138518,,2118-06-27,2118-06-28,MAIN,Dextrose 50%,Dextrose 50%,Dextrose 50%,DEX50SY,1989.0,409490200.0,50mL Syringe,12.5,gm,0.5,SYR,IV
8319,32008,181524,228908.0,2163-01-21,2163-01-21,MAIN,Chlorhexidine Gluconate 0.12% Oral Rinse,Chlorhexidine Gluconate 0.12% Oral Rinse,Chlorhexidine Gluconate,CHLO15L,57959.0,54569520000.0,15ml Cup,15,mL,1,UDCUP,ORAL


## Data preprocessing

In [30]:
# Drop columns that are not synthesized: the subject/admission/ICU-stay identifiers,
# the gsn/ndc codes, and the dose/form quantity columns.
# NOTE(review): presumably removed to avoid modelling identifiers and
# high-cardinality codes -- confirm the intent for the dose columns.
# Reassign instead of inplace=True so the cell's effect is explicit.
prescription_df = prescription_df.drop(columns=[
    'subject_id', 'hadm_id', 'icustay_id',
    'gsn', 'ndc',
    'dose_val_rx', 'dose_unit_rx', 'form_val_disp',
])

In [31]:
# Inspect column dtypes after the drop; startdate/enddate are still object
# (string) columns at this point and are converted in the next cell.
prescription_df.dtypes

startdate            object
enddate              object
drug_type            object
drug                 object
drug_name_poe        object
drug_name_generic    object
formulary_drug_cd    object
prod_strength        object
form_unit_disp       object
route                object
dtype: object

In [32]:
# Parse both date columns (YYYY-MM-DD strings in the preview) into datetime64
# so metadata detection treats them as datetime columns.
for date_col in ('startdate', 'enddate'):
    prescription_df[date_col] = pd.to_datetime(prescription_df[date_col])


In [33]:
# Per-column count of missing values.
prescription_df.isna().sum()

startdate              0
enddate                1
drug_type              0
drug                   0
drug_name_poe        511
drug_name_generic    511
formulary_drug_cd      0
prod_strength          0
form_unit_disp         0
route                  0
dtype: int64

In [42]:
# Remove rows with no enddate (one such row in the recorded null counts).
# Missing drug_name_poe / drug_name_generic values are intentionally kept.
# Reassign instead of inplace=True so the cell's effect is explicit.
prescription_df = prescription_df.dropna(subset=['enddate'])
assert prescription_df['enddate'].notna().all(), "enddate still has missing values"


## Metadata

In [43]:
from sdv.metadata import SingleTableMetadata

# Empty single-table metadata object; column sdtypes are filled in by
# detect_from_dataframe() in the following cell.
metadata = SingleTableMetadata()

In [44]:
# Infer each column's sdtype from the cleaned frame.
metadata.detect_from_dataframe(prescription_df)

In [45]:
# Display the detected metadata: the two date columns are datetime and all
# remaining columns were detected as categorical.
metadata

{
    "METADATA_SPEC_VERSION": "SINGLE_TABLE_V1",
    "columns": {
        "startdate": {
            "sdtype": "datetime"
        },
        "enddate": {
            "sdtype": "datetime"
        },
        "drug_type": {
            "sdtype": "categorical"
        },
        "drug": {
            "sdtype": "categorical"
        },
        "drug_name_poe": {
            "sdtype": "categorical"
        },
        "drug_name_generic": {
            "sdtype": "categorical"
        },
        "formulary_drug_cd": {
            "sdtype": "categorical"
        },
        "prod_strength": {
            "sdtype": "categorical"
        },
        "form_unit_disp": {
            "sdtype": "categorical"
        },
        "route": {
            "sdtype": "categorical"
        }
    }
}

## Constraints

In [53]:
# Business rule: enddate must not precede startdate.
# Boundaries are non-strict because the real data contains same-day prescriptions
# (startdate == enddate, e.g. several rows in the head() preview); with
# strict_boundaries=True those valid rows would violate the constraint.
# NOTE(review): this dict is not registered with any synthesizer in this notebook;
# pass it via synthesizer.add_constraints([startdate_enddate_constraint]) before fit().
startdate_enddate_constraint = {
    'constraint_class': 'Inequality',
    'constraint_parameters': {
        'low_column_name': 'startdate',
        'high_column_name': 'enddate',
        'strict_boundaries': False,
    },
}

## Modelling

In [54]:
# Alias (same object, not a copy) of the cleaned frame; every synthesizer
# below is fit on it and every quality report compares against it.
data = prescription_df

In [56]:
from sdv.lite import SingleTablePreset

# Model 1: SDV's FAST_ML preset, trained on the prepared table.
synthesizer1 = SingleTablePreset(metadata, name='FAST_ML')
synthesizer1.fit(data)
# Generate as many synthetic rows as the real table has.
synthetic_data1 = synthesizer1.sample(num_rows=data.shape[0])

In [61]:
from sdv.single_table import GaussianCopulaSynthesizer

# Model 2: Gaussian copula synthesizer with default settings.
synthesizer2 = GaussianCopulaSynthesizer(metadata)
synthesizer2.fit(data)
# Generate as many synthetic rows as the real table has.
synthetic_data2 = synthesizer2.sample(num_rows=data.shape[0])

In [62]:
from sdv.single_table import CTGANSynthesizer

# Model 3: CTGAN synthesizer with default settings.
synthesizer3 = CTGANSynthesizer(metadata)
synthesizer3.fit(data)
# Generate as many synthetic rows as the real table has.
synthetic_data3 = synthesizer3.sample(num_rows=data.shape[0])

In [63]:
from sdv.single_table import TVAESynthesizer

# Model 4: TVAE synthesizer with default settings.
synthesizer4 = TVAESynthesizer(metadata)
synthesizer4.fit(data)
# Generate as many synthetic rows as the real table has.
synthetic_data4 = synthesizer4.sample(num_rows=data.shape[0])

In [64]:
from sdv.single_table import CopulaGANSynthesizer

# Model 5: CopulaGAN synthesizer with default settings.
synthesizer5 = CopulaGANSynthesizer(metadata)
synthesizer5.fit(data)
# Generate as many synthetic rows as the real table has.
synthetic_data5 = synthesizer5.sample(num_rows=data.shape[0])

In [57]:
from sdv.evaluation.single_table import evaluate_quality

# Quality report for model 1 (FAST_ML preset) against the real table.
quality_report = evaluate_quality(real_data=data, synthetic_data=synthetic_data1, metadata=metadata)

Creating report: 100%|██████████| 4/4 [00:02<00:00,  1.74it/s]



Overall Quality Score: 60.77%

Properties:
Column Shapes: 80.93%
Column Pair Trends: 40.62%


In [65]:
from sdv.evaluation.single_table import evaluate_quality

# Quality report for model 2 (Gaussian copula) against the real table.
quality_report = evaluate_quality(real_data=data, synthetic_data=synthetic_data2, metadata=metadata)

Creating report: 100%|██████████| 4/4 [00:01<00:00,  2.10it/s]



Overall Quality Score: 48.51%

Properties:
Column Shapes: 69.25%
Column Pair Trends: 27.77%


In [66]:
from sdv.evaluation.single_table import evaluate_quality

# Quality report for model 3 (CTGAN) against the real table.
quality_report = evaluate_quality(real_data=data, synthetic_data=synthetic_data3, metadata=metadata)

Creating report: 100%|██████████| 4/4 [00:01<00:00,  2.48it/s]



Overall Quality Score: 54.82%

Properties:
Column Shapes: 74.23%
Column Pair Trends: 35.41%


In [67]:
from sdv.evaluation.single_table import evaluate_quality

# Quality report for model 4 (TVAE) against the real table.
quality_report = evaluate_quality(real_data=data, synthetic_data=synthetic_data4, metadata=metadata)

Creating report: 100%|██████████| 4/4 [00:02<00:00,  1.75it/s]



Overall Quality Score: 57.11%

Properties:
Column Shapes: 69.09%
Column Pair Trends: 45.14%


In [68]:
from sdv.evaluation.single_table import evaluate_quality

# Quality report for model 5 (CopulaGAN) against the real table.
quality_report = evaluate_quality(real_data=data, synthetic_data=synthetic_data5, metadata=metadata)

Creating report: 100%|██████████| 4/4 [00:01<00:00,  2.29it/s]



Overall Quality Score: 51.68%

Properties:
Column Shapes: 72.11%
Column Pair Trends: 31.25%


In [None]:
import os

# Persist model 1 (FAST_ML preset), which had the highest overall quality score
# in the recorded run (60.77%).
# The file name previously read "labevents_best_singletablepreset.pkl", which does
# not match this prescriptions model (looks like a copy-paste slip -- update any
# loader that expects the old name).
# Reuse data_set_path instead of repeating the absolute path, and create the
# target folder if it does not exist yet.
model_dir = data_set_path + 'models/prescription/'
os.makedirs(model_dir, exist_ok=True)
synthesizer1.save(model_dir + 'prescription_best_singletablepreset.pkl')
