In [182]:
# from google.colab import drive
# drive.mount('/content/drive')

In [183]:
# ! pip install pgmpy

In [184]:
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [185]:
import pandas as pd
import numpy as np

import common

In [186]:
date_set_path = "../temp_sets_100/"

### Read samples

In [187]:
patients_df = common.read_csv_no_rowid(date_set_path+"patients.csv")

## Preprocess patients data

### Drop useless columns

In [188]:
# Drop useless columns
patients_df.drop(['expire_flag'], axis=1, inplace=True)

### Deal with null values

In [189]:
# Check null value in table
common.nan_count(patients_df)

Total columns: 6
Total rows: 100
--------------
subject_id     0
gender         0
dob            0
dod           59
dod_hosp      78
dod_ssn       63
dtype: int64


In [190]:
# Sentinel timestamp used to replace null datetime values.
# pd.to_datetime(0) interprets 0 as an offset from the Unix epoch, so the
# sentinel is Timestamp('1970-01-01 00:00:00').
nan_datetime=pd.to_datetime(0)

In [191]:
# Replace missing dates of birth with the sentinel timestamp.
# The result is assigned back to the column instead of calling
# fillna(inplace=True) on the column selection: that chained in-place form
# works on an intermediate object and is not guaranteed to update
# patients_df (it does not under pandas Copy-on-Write).
patients_df['dob'] = patients_df['dob'].fillna(value=nan_datetime)
# dod_hosp and dod_ssn are deliberately NOT filled: later cells test them
# with pd.isna() to tell which death-date source was recorded.

In [192]:
common.nan_count(patients_df)

Total columns: 6
Total rows: 100
--------------
subject_id     0
gender         0
dob            0
dod           59
dod_hosp      78
dod_ssn       63
dtype: int64


### Set the column types

In [193]:
patients_df.dtypes

subject_id     int64
gender        object
dob           object
dod           object
dod_hosp      object
dod_ssn       object
dtype: object

In [194]:
# Convert the object-typed date columns to datetime64 so that missing
# values become NaT and the columns are handled as datetimes downstream.
for date_column in ('dob', 'dod', 'dod_hosp', 'dod_ssn'):
    patients_df[date_column] = pd.to_datetime(patients_df[date_column])

In [195]:
patients_df.dtypes

subject_id             int64
gender                object
dob           datetime64[ns]
dod           datetime64[ns]
dod_hosp      datetime64[ns]
dod_ssn       datetime64[ns]
dtype: object

### Process dod_hosp and dod_ssn

In [196]:
from pandas import NaT

# Define a method to deal with death time
def dod_process(df):
    '''
    Choose the dod_ssn value to keep for a single patient row.

    DOD merges DOD_HOSP and DOD_SSN, giving priority to DOD_HOSP when both
    were recorded. So when a row carries both dates, dod_ssn is blanked out
    (NaT); in every other case the existing dod_ssn value is returned as is.

    df: one row of the patients table (as passed by DataFrame.apply with
        axis=1); it must support df['dod_hosp'] and df['dod_ssn'].
    Returns NaT or the row's own dod_ssn value.
    '''
    hosp_missing = pd.isna(df['dod_hosp'])
    ssn_missing = pd.isna(df['dod_ssn'])

    # At most one source recorded: keep whatever dod_ssn already holds.
    if hosp_missing or ssn_missing:
        return df['dod_ssn']

    # Both sources recorded: the hospital date wins, so drop the SSN date.
    return NaT

In [197]:
patients_df['dod_ssn'] = patients_df.apply(dod_process, axis=1)

In [198]:
patients_df

Unnamed: 0,subject_id,gender,dob,dod,dod_hosp,dod_ssn
0,569,M,2021-11-04,2107-11-30,NaT,2107-11-30
1,26282,F,2074-03-20,NaT,NaT,NaT
2,1762,F,2120-03-19,NaT,NaT,NaT
3,14481,M,1816-05-08,2121-02-14,NaT,2121-02-14
4,21470,M,2195-05-20,NaT,NaT,NaT
...,...,...,...,...,...,...
95,23647,M,2077-02-13,NaT,NaT,NaT
96,26485,F,2075-08-15,2164-05-08,2164-05-08,NaT
97,26884,F,2067-01-31,2153-06-23,NaT,2153-06-23
98,49024,M,2055-02-17,2137-01-12,NaT,2137-01-12


In [199]:
# from pandas import NaT

# patients_df['dod_hosp'] = patients_df['dod_hosp'].apply(lambda x: 1 if not pd.isna(x) else 0)
# patients_df['dod_ssn'] = patients_df['dod_ssn'].apply(lambda x: 1 if not pd.isna(x) else 0)
# patients_df['dod_live'] = patients_df['dod_live'].apply(lambda x: 1 if not pd.isna(x) else 0)

---

##  Build Network

In [200]:
from sdv.tabular import CTGAN
from sdv.constraints import FixedCombinations

In [201]:
patients_df.columns

Index(['subject_id', 'gender', 'dob', 'dod', 'dod_hosp', 'dod_ssn'], dtype='object')

### Set constraint

In [202]:
# Fixed constraints
# fixed_subject_hadm_icustay_constraint = FixedCombinations(
#     column_names=['subject_id', 'hadm_id']
# )

In [203]:
# dod_hosp_dod_ssn_constraint = OneHotEncoding(
#     column_names=['dod_hosp', 'dod_ssn', 'dod_live']
# )

In [204]:
# patients_constraints = [dod_hosp_dod_ssn_constraint]

In [205]:
# Custom constraints
def dod_data_unique_is_valid(column_names, data):
    """Flag the rows whose death dates are recorded consistently.

    A row is valid when exactly one of these patterns holds:
      * 'dod' and the first named column are set, the second is null;
      * 'dod' and the second named column are set, the first is null;
      * 'dod' and both named columns are null.

    column_names: the two column names to check against 'dod'
        (the first two entries are used).
    data: table holding 'dod' and the two named columns.
    Returns a boolean mask with one entry per row.
    """
    dod_set = pd.notna(data['dod'])
    dod_null = pd.isna(data['dod'])
    first_set = pd.notna(data[column_names[0]])
    first_null = pd.isna(data[column_names[0]])
    second_set = pd.notna(data[column_names[1]])
    second_null = pd.isna(data[column_names[1]])

    only_first_recorded = dod_set & first_set & second_null
    only_second_recorded = dod_set & first_null & second_set
    nothing_recorded = dod_null & first_null & second_null

    return only_first_recorded | only_second_recorded | nothing_recorded

def dod_data_unique_trasform(column_names, data):
    """Forward transform of the dod constraint: a deliberate no-op.

    The data is handed back untouched. ``column_names`` is accepted so the
    function can be passed as ``transform_fn`` but is not used.
    """
    return data

def dod_data_unique_reverse_transform(column_names, transformed_data):
    """Rebuild the two death-date source columns from 'dod'.

    For every row, the named column that is the only one recorded (together
    with 'dod') receives the row's 'dod' value; all other entries of the two
    named columns become null. Rows where both or neither source is recorded
    therefore end up with both columns null.

    column_names: the two source column names (first two entries are used).
    transformed_data: table holding 'dod' and the two named columns. It is
        modified in place and also returned.
    Returns the same ``transformed_data`` object.
    """
    dod = transformed_data['dod']
    dod_set = pd.notna(dod)
    first_set = pd.notna(transformed_data[column_names[0]])
    second_set = pd.notna(transformed_data[column_names[1]])

    # Both masks are computed before either column is overwritten, and are
    # converted to plain arrays so they are applied by position.
    only_first = (dod_set & first_set & ~second_set).to_numpy()
    only_second = (dod_set & ~first_set & second_set).to_numpy()

    # Series.where keeps 'dod' where the mask is True and nulls the rest.
    # Unlike assigning a filtered Series (df[mask]['dod']), this does not
    # depend on index re-alignment, so it stays correct when the index
    # contains repeated labels.
    transformed_data[column_names[0]] = dod.where(only_first)
    transformed_data[column_names[1]] = dod.where(only_second)

    return transformed_data

In [206]:
# Build custom constraints
from sdv.constraints import create_custom_constraint

DodUniqueProcess = create_custom_constraint(
    is_valid_fn=dod_data_unique_is_valid,
    transform_fn=dod_data_unique_trasform,
    reverse_transform_fn=dod_data_unique_reverse_transform
)

In [207]:
dod_data_unique_constraint = DodUniqueProcess(
    column_names=['dod_hosp', 'dod_ssn']
)

In [208]:
constrains = [dod_data_unique_constraint]

### Build and train model

In [209]:
model = CTGAN(
    constraints=constrains, 
    batch_size=10000,
    cuda=True, 
    verbose=True, 
    epochs=60)

In [210]:
len(patients_df)

100

In [211]:
train_data = patients_df
model.fit(train_data)

Epoch 1, Loss G:  0.8444,Loss D:  0.0209
Epoch 2, Loss G:  0.8393,Loss D:  0.0155
Epoch 3, Loss G:  0.8298,Loss D:  0.0122
Epoch 4, Loss G:  0.8171,Loss D:  0.0044
Epoch 5, Loss G:  0.8284,Loss D: -0.0055
Epoch 6, Loss G:  0.8184,Loss D: -0.0160
Epoch 7, Loss G:  0.8191,Loss D: -0.0195
Epoch 8, Loss G:  0.8121,Loss D: -0.0367
Epoch 9, Loss G:  0.8063,Loss D: -0.0370
Epoch 10, Loss G:  0.8119,Loss D: -0.0571
Epoch 11, Loss G:  0.7917,Loss D: -0.0576
Epoch 12, Loss G:  0.7819,Loss D: -0.0755
Epoch 13, Loss G:  0.7899,Loss D: -0.0795
Epoch 14, Loss G:  0.7643,Loss D: -0.0807
Epoch 15, Loss G:  0.7415,Loss D: -0.0864
Epoch 16, Loss G:  0.7269,Loss D: -0.0893
Epoch 17, Loss G:  0.7238,Loss D: -0.0960
Epoch 18, Loss G:  0.6862,Loss D: -0.1017
Epoch 19, Loss G:  0.6790,Loss D: -0.1214
Epoch 20, Loss G:  0.6489,Loss D: -0.1088
Epoch 21, Loss G:  0.6088,Loss D: -0.1084
Epoch 22, Loss G:  0.5654,Loss D: -0.0893
Epoch 23, Loss G:  0.5237,Loss D: -0.1156
Epoch 24, Loss G:  0.4841,Loss D: -0.0945
E

In [212]:
# model.save(date_set_path + "100_models/" + "patients_model.pkl")

### Generate synthetic data

In [213]:
sample = model.sample(num_rows=len(train_data))

Sampling rows: 100%|██████████| 100/100 [00:00<00:00, 387.72it/s]


In [214]:
sample.head()

Unnamed: 0,subject_id,gender,dob,dod,dod_hosp,dod_ssn
0,569,M,2054-05-18,NaT,NaT,NaT
1,2813,M,2019-11-04,2182-07-29,NaT,2182-07-29
2,12008,M,2109-10-17,NaT,NaT,NaT
3,10617,M,2150-09-20,NaT,NaT,NaT
4,5574,F,2098-10-02,NaT,NaT,NaT


---

## Evaluate data

In [220]:
from sdv.evaluation import evaluate

In [221]:
evaluate(sample, train_data, metrics=['ContinuousKLDivergence'])

nan

In [222]:
evaluate(sample, train_data, metrics=['DiscreteKLDivergence'])

nan

---

### Evaluate timeseries data (not accurate in this table)

In [215]:
from sdv.metrics.timeseries import LSTMDetection, TSFCDetection

In [216]:
metadata_2 = model.get_metadata().to_dict()

In [217]:
metadata_2['entity_columns']=['subject_id']

In [218]:
metadata_2

{'fields': {'subject_id': {'type': 'numerical',
   'subtype': 'integer',
   'transformer': 'integer'},
  'gender': {'type': 'categorical', 'transformer': None},
  'dob': {'type': 'datetime', 'transformer': 'datetime'},
  'dod': {'type': 'datetime', 'transformer': 'datetime'},
  'dod_hosp': {'type': 'datetime', 'transformer': 'datetime'},
  'dod_ssn': {'type': 'datetime', 'transformer': 'datetime'}},
 'constraints': [{'constraint': 'sdv.constraints.tabular.CustomConstraint',
   'column_names': ['dod_hosp', 'dod_ssn']}],
 'model_kwargs': {},
 'name': None,
 'primary_key': None,
 'sequence_index': None,
 'entity_columns': ['subject_id'],
 'context_columns': []}

In [219]:
LSTMDetection.compute(sample, train_data, metadata_2)

0.32608695652173914