In [82]:
# Silence every warning category for the remainder of the session
import warnings
warnings.filterwarnings(action="ignore")

In [83]:
import pandas as pd
import numpy as np

import commonfunc

In [84]:
# Project root folder; the trailing slash is required because later cells
# build file paths from it by plain string concatenation
data_set_path = "C:/Users/shrus/Documents/Synthetic-data-generation/"

In [85]:
# Load the prescriptions table through the project's CSV helper
prescriptions_csv = data_set_path + "sampled_data_csv_100/prescriptions.csv"
prescription_df = commonfunc.read_csv_no_rowid(prescriptions_csv)

In [86]:
# Down-sample to at most 1400 rows. The fixed seed keeps the subset (and every
# score computed from it) identical across kernel restarts, and the min()
# guard avoids a ValueError when the table holds fewer than 1400 rows.
prescription_df = prescription_df.sample(n=min(1400, len(prescription_df)), random_state=42)

In [87]:
# Preview the first five sampled rows
prescription_df.head(n=5)

Unnamed: 0,subject_id,hadm_id,icustay_id,startdate,enddate,drug_type,drug,drug_name_poe,drug_name_generic,formulary_drug_cd,gsn,ndc,prod_strength,dose_val_rx,dose_unit_rx,form_val_disp,form_unit_disp,route
8779,17668,151926,257517.0,2188-09-25,2188-09-26,MAIN,Sodium Bicarbonate,,,NABC50I,1185.0,74662502.0,50mEq Vial,300,mEq,6,VIAL,IV
6930,11003,167847,223836.0,2119-06-13,2119-06-14,BASE,Iso-Osmotic Dextrose,,,FRBD100,,0.0,100ml Bag,100,mL,100,mL,IV
6632,69093,189496,255781.0,2179-08-22,2179-08-22,BASE,0.9% Sodium Chloride,,,NS250,1210.0,338004902.0,250mL Bag,250,mL,250,mL,IV
4676,44486,178101,,2137-04-30,2137-04-30,MAIN,PredniSONE,PredniSONE,PredniSONE,PRED20,6751.0,54001820.0,20 mg Tablet,40,mg,2,TAB,PO
1275,93162,180765,241558.0,2154-03-28,2154-03-29,BASE,Bag,,,BAG50,,0.0,50 mL Bag,100,mL,2,BAG,IV


## Data preprocessing

In [88]:
# Remove the columns that are not used for modelling
columns_to_drop = [
    'subject_id', 'hadm_id', 'icustay_id',
    'gsn', 'ndc',
    'dose_val_rx', 'dose_unit_rx', 'form_val_disp',
]
prescription_df.drop(columns=columns_to_drop, inplace=True)

In [89]:
# Check the dtype of each remaining column
prescription_df.dtypes

startdate            object
enddate              object
drug_type            object
drug                 object
drug_name_poe        object
drug_name_generic    object
formulary_drug_cd    object
prod_strength        object
form_unit_disp       object
route                object
dtype: object

In [90]:
# Parse both date columns into pandas datetime values
for date_column in ('startdate', 'enddate'):
    prescription_df[date_column] = pd.to_datetime(prescription_df[date_column])


In [91]:
# Count missing values per column
prescription_df.isna().sum()

startdate              0
enddate                0
drug_type              0
drug                   0
drug_name_poe        540
drug_name_generic    540
formulary_drug_cd      0
prod_strength          0
form_unit_disp         0
route                  0
dtype: int64

In [92]:
# Discard rows that have no end date
prescription_df.dropna(axis='index', subset=['enddate'], inplace=True)


In [93]:
# List the columns that remain after preprocessing
prescription_df.columns

Index(['startdate', 'enddate', 'drug_type', 'drug', 'drug_name_poe',
       'drug_name_generic', 'formulary_drug_cd', 'prod_strength',
       'form_unit_disp', 'route'],
      dtype='object')

## Metadata

In [94]:
from sdv.metadata import SingleTableMetadata

# Start from an empty metadata object; the column types are detected from the
# preprocessed DataFrame in a later cell
metadata = SingleTableMetadata()

In [95]:
# Infer an sdtype for every column from the preprocessed table
metadata.detect_from_dataframe(prescription_df)

In [96]:
# Display the detected metadata to review the inferred sdtypes
metadata

{
    "METADATA_SPEC_VERSION": "SINGLE_TABLE_V1",
    "columns": {
        "startdate": {
            "sdtype": "datetime"
        },
        "enddate": {
            "sdtype": "datetime"
        },
        "drug_type": {
            "sdtype": "categorical"
        },
        "drug": {
            "sdtype": "categorical"
        },
        "drug_name_poe": {
            "sdtype": "categorical"
        },
        "drug_name_generic": {
            "sdtype": "categorical"
        },
        "formulary_drug_cd": {
            "sdtype": "categorical"
        },
        "prod_strength": {
            "sdtype": "categorical"
        },
        "form_unit_disp": {
            "sdtype": "categorical"
        },
        "route": {
            "sdtype": "categorical"
        }
    }
}

## Constraints

In [53]:
# startdate must not fall after enddate. The boundaries are non-strict because
# the real table contains prescriptions that start and end on the same day,
# which a strict inequality would reject.
# NOTE(review): this dict only takes effect once it is passed to a
# synthesizer's add_constraints(...) before fit; no cell shown here does that.
startdate_enddate_constraint = {
    'constraint_class': 'Inequality',
    'constraint_parameters': {
        'low_column_name': 'startdate',
        'high_column_name': 'enddate',
        'strict_boundaries': False
    }
}

## Modelling

In [54]:
# Alias (not a copy) of the preprocessed frame; used as the training table below
data = prescription_df

In [56]:
from sdv.lite import SingleTablePreset

# Candidate 1: the FAST_ML preset, fitted on the real table
synthesizer1 = SingleTablePreset(metadata, name='FAST_ML')
synthesizer1.fit(data)

# Generate as many synthetic rows as there are real rows
synthetic_data1 = synthesizer1.sample(num_rows=data.shape[0])

In [61]:
from sdv.single_table import GaussianCopulaSynthesizer

# Candidate 2: Gaussian copula model, fitted on the real table
synthesizer2 = GaussianCopulaSynthesizer(metadata)
synthesizer2.fit(data)

# Generate as many synthetic rows as there are real rows
synthetic_data2 = synthesizer2.sample(num_rows=data.shape[0])

In [62]:
from sdv.single_table import CTGANSynthesizer

# Candidate 3: CTGAN model, fitted on the real table
synthesizer3 = CTGANSynthesizer(metadata)
synthesizer3.fit(data)

# Generate as many synthetic rows as there are real rows
synthetic_data3 = synthesizer3.sample(num_rows=data.shape[0])

In [63]:
from sdv.single_table import TVAESynthesizer

# Candidate 4: TVAE model, fitted on the real table
synthesizer4 = TVAESynthesizer(metadata)
synthesizer4.fit(data)

# Generate as many synthetic rows as there are real rows
synthetic_data4 = synthesizer4.sample(num_rows=data.shape[0])

In [64]:
from sdv.single_table import CopulaGANSynthesizer

# Candidate 5: CopulaGAN model, fitted on the real table
synthesizer5 = CopulaGANSynthesizer(metadata)
synthesizer5.fit(data)

# Generate as many synthetic rows as there are real rows
synthetic_data5 = synthesizer5.sample(num_rows=data.shape[0])

In [57]:
from sdv.evaluation.single_table import evaluate_quality

# Quality of synthetic_data1 against the real table
# (quality_report is overwritten by each evaluation cell)
quality_report = evaluate_quality(data, synthetic_data1, metadata)

Creating report: 100%|██████████| 4/4 [00:02<00:00,  1.74it/s]



Overall Quality Score: 60.77%

Properties:
Column Shapes: 80.93%
Column Pair Trends: 40.62%


In [65]:
from sdv.evaluation.single_table import evaluate_quality

# Quality of synthetic_data2 against the real table
# (quality_report is overwritten by each evaluation cell)
quality_report = evaluate_quality(data, synthetic_data2, metadata)

Creating report: 100%|██████████| 4/4 [00:01<00:00,  2.10it/s]



Overall Quality Score: 48.51%

Properties:
Column Shapes: 69.25%
Column Pair Trends: 27.77%


In [66]:
from sdv.evaluation.single_table import evaluate_quality

# Quality of synthetic_data3 against the real table
# (quality_report is overwritten by each evaluation cell)
quality_report = evaluate_quality(data, synthetic_data3, metadata)

Creating report: 100%|██████████| 4/4 [00:01<00:00,  2.48it/s]



Overall Quality Score: 54.82%

Properties:
Column Shapes: 74.23%
Column Pair Trends: 35.41%


In [67]:
from sdv.evaluation.single_table import evaluate_quality

# Quality of synthetic_data4 against the real table
# (quality_report is overwritten by each evaluation cell)
quality_report = evaluate_quality(data, synthetic_data4, metadata)

Creating report: 100%|██████████| 4/4 [00:02<00:00,  1.75it/s]



Overall Quality Score: 57.11%

Properties:
Column Shapes: 69.09%
Column Pair Trends: 45.14%


In [68]:
from sdv.evaluation.single_table import evaluate_quality

# Quality of synthetic_data5 against the real table
# (quality_report is overwritten by each evaluation cell)
quality_report = evaluate_quality(data, synthetic_data5, metadata)

Creating report: 100%|██████████| 4/4 [00:01<00:00,  2.29it/s]



Overall Quality Score: 51.68%

Properties:
Column Shapes: 72.11%
Column Pair Trends: 31.25%


In [69]:
# Persist the preset synthesizer (highest overall quality score in the
# recorded run). The path is built from data_set_path, matching the cell that
# loads the model, instead of repeating the absolute project root.
synthesizer1.save(data_set_path + 'models/prescription/prescription_best_singletablepreset.pkl')


## Visualization

In [104]:
from sdv.lite import SingleTablePreset

# Reload the saved preset synthesizer from the models folder
saved_model_path = data_set_path + 'models/prescription/prescription_best_singletablepreset.pkl'
prescription_synthesizer = SingleTablePreset.load(saved_model_path)

In [105]:
# Re-bind data to the preprocessed real table for the evaluation cells below
data=prescription_df

In [106]:
# Draw one synthetic row per real row from the reloaded synthesizer
syn_data = prescription_synthesizer.sample(num_rows=data.shape[0])

In [107]:
from sdmetrics.reports.single_table import DiagnosticReport

# SDMetrics reports are documented to take the metadata as a plain dictionary,
# so the SDV metadata object is converted explicitly rather than relying on
# version-specific coercion.
diag_report = DiagnosticReport()
diag_report.generate(data, syn_data, metadata.to_dict())


Creating report: 100%|██████████| 4/4 [00:11<00:00,  2.92s/it]


DiagnosticResults:

SUCCESS:
✓ The synthetic data covers over 90% of the numerical ranges present in the real data
✓ Over 90% of the synthetic rows are not copies of the real data

! The synthetic data is missing more than 10% of the categories present in the real data
! More than 10% the synthetic data does not follow the min/max boundaries set by the real data





In [108]:
# Show the per-property scores of the diagnostic report
diag_report.get_properties()

{'Coverage': 0.7179683752690604,
 'Synthesis': 1.0,
 'Boundaries': 0.8985714285714286}

In [109]:
from sdv.evaluation.single_table import evaluate_quality

# Quality of the rows sampled from the reloaded synthesizer; this report is
# the one visualized afterwards
quality_report = evaluate_quality(data, syn_data, metadata)

Creating report: 100%|██████████| 4/4 [00:06<00:00,  1.65s/it]



Overall Quality Score: 56.42%

Properties:
Column Shapes: 75.09%
Column Pair Trends: 37.74%


In [110]:
# Plot the per-column scores behind the "Column Shapes" property
quality_report.get_visualization(property_name='Column Shapes')

In [116]:
from sdv.evaluation.single_table import get_column_plot

# Compare the real and synthetic distributions of a single column
fig = get_column_plot(
    metadata=metadata,
    column_name='form_unit_disp',
    real_data=data,
    synthetic_data=syn_data,
)
fig.show()