In [58]:
# from google.colab import drive
# drive.mount('/content/drive')

In [59]:
# ! pip install pgmpy

In [60]:
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [61]:
import pandas as pd
import numpy as np

import common

In [62]:
date_set_path = "../temp_sets_100/"

# Patients  

## Data processing

### Read samples

In [63]:
patients_df = common.read_csv_no_rowid(date_set_path+"patients.csv")

### Drop useless columns

In [64]:
# Drop useless columns
patients_df.drop(['expire_flag'], axis=1, inplace=True)

### Deal with null values

In [65]:
# Check null value in table
common.nan_count(patients_df)

Total columns: 6
Total rows: 100
--------------
subject_id     0
gender         0
dob            0
dod           59
dod_hosp      78
dod_ssn       63
dtype: int64


In [66]:
# Set a value replacing the null time value
nan_datetime=pd.to_datetime(0)

In [67]:
# Assign the filled column back to the frame. Calling fillna(inplace=True) on
# the result of patients_df['dob'] acts on an intermediate object, so the frame
# itself is not guaranteed to be updated (and never is under pandas
# copy-on-write). With warnings silenced, that failure would go unnoticed.
patients_df['dob'] = patients_df['dob'].fillna(value=nan_datetime)
# patients_df['dod_hosp'] = patients_df['dod_hosp'].fillna(value=nan_datetime)
# patients_df['dod_ssn'] = patients_df['dod_ssn'].fillna(value=nan_datetime)

In [68]:
common.nan_count(patients_df)

Total columns: 6
Total rows: 100
--------------
subject_id     0
gender         0
dob            0
dod           59
dod_hosp      78
dod_ssn       63
dtype: int64


### Set the column types

In [69]:
patients_df.dtypes

subject_id     int64
gender        object
dob           object
dod           object
dod_hosp      object
dod_ssn       object
dtype: object

In [70]:
# Parse every date-like column into pandas datetime values
for date_column in ('dob', 'dod', 'dod_hosp', 'dod_ssn'):
    patients_df[date_column] = pd.to_datetime(patients_df[date_column])

In [71]:
patients_df.dtypes

subject_id             int64
gender                object
dob           datetime64[ns]
dod           datetime64[ns]
dod_hosp      datetime64[ns]
dod_ssn       datetime64[ns]
dtype: object

### Process dod_hosp and dod_ssn

In [72]:
from pandas import NaT

# Define a method to deal with death time
def dod_process(df):
    '''
    Decide which 'dod_ssn' value a single patient row should keep.

    'dod' merges 'dod_hosp' and 'dod_ssn', giving priority to 'dod_hosp' when
    both were recorded. For such rows the 'dod_ssn' value is blanked out;
    every other row keeps its 'dod_ssn' unchanged.

    df: one row of the patients table (indexable by column name).
    Returns NaT or the row's original 'dod_ssn'.
    '''
    hosp_recorded = not pd.isna(df['dod_hosp'])
    ssn_recorded = not pd.isna(df['dod_ssn'])

    # Both sources present: the hospital date wins, so drop the SSN date
    if hosp_recorded and ssn_recorded:
        return pd.NaT
    return df['dod_ssn']

In [73]:
patients_df['dod_ssn'] = patients_df.apply(dod_process, axis=1)

In [74]:
patients_df

Unnamed: 0,subject_id,gender,dob,dod,dod_hosp,dod_ssn
0,569,M,2021-11-04,2107-11-30,NaT,2107-11-30
1,26282,F,2074-03-20,NaT,NaT,NaT
2,1762,F,2120-03-19,NaT,NaT,NaT
3,14481,M,1816-05-08,2121-02-14,NaT,2121-02-14
4,21470,M,2195-05-20,NaT,NaT,NaT
...,...,...,...,...,...,...
95,23647,M,2077-02-13,NaT,NaT,NaT
96,26485,F,2075-08-15,2164-05-08,2164-05-08,NaT
97,26884,F,2067-01-31,2153-06-23,NaT,2153-06-23
98,49024,M,2055-02-17,2137-01-12,NaT,2137-01-12


### Process dob (age > 300)

In [75]:
patients_df.dtypes

subject_id             int64
gender                object
dob           datetime64[ns]
dod           datetime64[ns]
dod_hosp      datetime64[ns]
dod_ssn       datetime64[ns]
dtype: object

In [76]:
import datetime
import random

def adjust_age_over_90(df):
    '''
    Replace an invalid 'dob' (one that falls in 18xx) with a plausible birth date.

    df: one row of the patients table, indexable by 'dob' and 'dod'
        (e.g. the row handed over by DataFrame.apply(axis=1)).
    Returns the birth date to store for that row.

    Rules:
    - A valid 'dob' (on or after 1970-01-01) is returned unchanged, whether or
      not the patient has a death date.
    - Invalid 'dob', no 'dod': the original date is moved forward by 300 years
      plus a random 0-10 years.
    - Invalid 'dob', with 'dod': the birth date is set to 'dod' minus 100 years
      plus a random 0-10 years, so the age at death is between 90 and 100.
    '''
    dob = df['dob']
    dod = df['dod']

    # A date before the epoch is treated as invalid; a missing dob is treated
    # the same way. Comparing timestamps directly avoids depending on the
    # resolution of the underlying numpy value.
    dob_is_invalid = pd.isna(dob) or dob < pd.Timestamp(0)
    if not dob_is_invalid:
        # Valid birth dates must never be overwritten, even for deceased patients
        return dob

    years_100 = datetime.timedelta(days=(365 * 100 + 100 / 4))
    random_days_10_years = datetime.timedelta(days=random.randint(0, 10 * 365))

    if pd.isna(dod):
        # NOTE(review): the 300-310 year shift is kept from the original code;
        # confirm it really yields the intended 90-100 age range for living patients.
        return dob + 3 * years_100 + random_days_10_years

    # Deceased patient: derive the birth date from the death date
    return dod - years_100 + random_days_10_years

In [77]:
patients_df['dob'] = patients_df.apply(adjust_age_over_90, axis=1)

In [78]:
# from pandas import NaT

# patients_df['dod_hosp'] = patients_df['dod_hosp'].apply(lambda x: 1 if not pd.isna(x) else 0)
# patients_df['dod_ssn'] = patients_df['dod_ssn'].apply(lambda x: 1 if not pd.isna(x) else 0)
# patients_df['dod_live'] = patients_df['dod_live'].apply(lambda x: 1 if not pd.isna(x) else 0)

---

##  Build Model

In [79]:
from sdv.tabular import CTGAN
from sdv.constraints import FixedCombinations

In [80]:
patients_df.columns

Index(['subject_id', 'gender', 'dob', 'dod', 'dod_hosp', 'dod_ssn'], dtype='object')

### Set constraint

In [81]:
# Fixed constraints
# fixed_subject_hadm_icustay_constraint = FixedCombinations(
#     column_names=['subject_id', 'hadm_id']
# )

In [82]:
# dod_hosp_dod_ssn_constraint = OneHotEncoding(
#     column_names=['dod_hosp', 'dod_ssn', 'dod_live']
# )

In [83]:
# patients_constraints = [dod_hosp_dod_ssn_constraint]

#### Build custom constraints

In [84]:
from sdv.constraints import create_custom_constraint

In [85]:
# DOD
def dod_data_unique_is_valid(column_names, data):
    '''
    Row-wise validity check for the three death-date columns.

    A row is valid when it matches exactly one of these patterns:
    - 'dod' and 'dod_hosp' are set, 'dod_ssn' is missing
    - 'dod' and 'dod_ssn' are set, 'dod_hosp' is missing
    - all three are missing

    column_names is not used. Returns a boolean Series aligned with data.
    '''
    has_dod = data['dod'].notna()
    has_hosp = data['dod_hosp'].notna()
    has_ssn = data['dod_ssn'].notna()

    hospital_only = has_dod & has_hosp & ~has_ssn
    ssn_only = has_dod & ~has_hosp & has_ssn
    no_death = ~has_dod & ~has_hosp & ~has_ssn
    return hospital_only | ssn_only | no_death

def dod_data_unique_trasform(column_names, data):
    '''
    Forward transform of the death-date constraint: a no-op.

    The table is handed back untouched (the very same object, not a copy);
    column_names is not used.
    '''
    return data

def dod_data_unique_reverse_transform(column_names, transformed_data):
    '''
    Rebuild 'dod_hosp' and 'dod_ssn' from 'dod' after sampling.

    - Rows where 'dod' and 'dod_hosp' are set and 'dod_ssn' is missing get
      'dod_hosp' = 'dod'; all other rows get a missing 'dod_hosp'.
    - Rows where 'dod' and 'dod_ssn' are set and 'dod_hosp' is missing get
      'dod_ssn' = 'dod'; all other rows get a missing 'dod_ssn'.

    Both patterns are evaluated on the incoming values, before either column
    is overwritten. The table is modified in place and returned;
    column_names is not used.
    '''
    dod = transformed_data['dod']
    has_dod = dod.notna()
    has_hosp = transformed_data['dod_hosp'].notna()
    has_ssn = transformed_data['dod_ssn'].notna()

    hospital_only = has_dod & has_hosp & ~has_ssn
    ssn_only = has_dod & ~has_hosp & has_ssn

    # where() keeps 'dod' on matching rows and leaves every other row missing
    transformed_data['dod_hosp'] = dod.where(hospital_only)
    transformed_data['dod_ssn'] = dod.where(ssn_only)

    return transformed_data


# Bundle the validity check, the (identity) transform and the reverse transform
# into a custom constraint class.
DodUniqueProcess = create_custom_constraint(
    is_valid_fn=dod_data_unique_is_valid,
    transform_fn=dod_data_unique_trasform,
    reverse_transform_fn=dod_data_unique_reverse_transform
)

# Instantiate the constraint for the two source columns.
# NOTE(review): the three callbacks also read 'dod', which is not listed in
# column_names — confirm the callbacks receive the full table.
dod_data_unique_constraint = DodUniqueProcess(
    column_names=['dod_hosp', 'dod_ssn']
)

In [86]:
# DOB
def dob_nut_null(column_names, data):
    '''
    Flag, row by row, whether 'dob' holds a value (True) or is missing (False).

    column_names is not used. Returns a boolean Series aligned with data.
    '''
    return data['dob'].notna()

# Validation-only custom constraint: only is_valid_fn is supplied, no
# transform or reverse transform.
DobNotNull = create_custom_constraint(
    is_valid_fn=dob_nut_null
)

# Apply the not-null check to the 'dob' column
dob_not_null_constraint = DobNotNull(
    column_names=['dob']
)

#### Predefined Constraints

In [87]:
from sdv.constraints import ScalarInequality

dob_before_2200 = ScalarInequality(
    column_name='dob',
    relation='<=',
    value="2200-01-01"
)

In [88]:
dob_after_2000 = ScalarInequality(
    column_name='dob',
    relation='>=',
    value="2000-01-01"
)

In [89]:
from sdv.constraints import Inequality

dob_before_dod = Inequality(
    low_column_name='dob',
    high_column_name='dod'
)

### Build and train model

In [90]:
constrains = [dod_data_unique_constraint, dob_not_null_constraint, dob_before_2200, dob_after_2000, dob_before_dod]

In [91]:
model = CTGAN(
    constraints=constrains, 
    batch_size=10000,
    cuda=True, 
    verbose=True, 
    epochs=60)

In [92]:
len(patients_df)

100

In [93]:
train_data = patients_df
model.fit(train_data)

Epoch 1, Loss G:  0.7311,Loss D:  0.0065
Epoch 2, Loss G:  0.7188,Loss D: -0.0103
Epoch 3, Loss G:  0.7024,Loss D: -0.0182
Epoch 4, Loss G:  0.6973,Loss D: -0.0448
Epoch 5, Loss G:  0.6799,Loss D: -0.0446
Epoch 6, Loss G:  0.6687,Loss D: -0.0713
Epoch 7, Loss G:  0.6498,Loss D: -0.0919
Epoch 8, Loss G:  0.6371,Loss D: -0.1080
Epoch 9, Loss G:  0.6209,Loss D: -0.1294
Epoch 10, Loss G:  0.5864,Loss D: -0.1379
Epoch 11, Loss G:  0.5596,Loss D: -0.1504
Epoch 12, Loss G:  0.5139,Loss D: -0.1593
Epoch 13, Loss G:  0.4827,Loss D: -0.1727
Epoch 14, Loss G:  0.4288,Loss D: -0.1897
Epoch 15, Loss G:  0.3778,Loss D: -0.1607
Epoch 16, Loss G:  0.3184,Loss D: -0.1639
Epoch 17, Loss G:  0.2459,Loss D: -0.1337
Epoch 18, Loss G:  0.1574,Loss D: -0.1251
Epoch 19, Loss G:  0.0873,Loss D: -0.0739
Epoch 20, Loss G: -0.0153,Loss D: -0.0693
Epoch 21, Loss G: -0.1000,Loss D: -0.0342
Epoch 22, Loss G: -0.1913,Loss D:  0.0344
Epoch 23, Loss G: -0.2648,Loss D:  0.0370
Epoch 24, Loss G: -0.3474,Loss D:  0.1080
E

### Save model

In [94]:
import cloudpickle

with open(date_set_path + "100_models/" + "patients_model.pkl", 'wb') as f:
    cloudpickle.dump(model, f)

In [95]:
# model.save(date_set_path + "100_models/" + "patients_model.pkl")

## Generate synthetic data

In [96]:
import cloudpickle

with open(date_set_path + "100_models/" + "patients_model.pkl", 'rb') as f:
    model = cloudpickle.load(f)

In [97]:
sample = model.sample(num_rows=len(train_data))

Sampling rows: 100%|██████████| 100/100 [00:00<00:00, 249.77it/s]


In [98]:
sample.head(30)

Unnamed: 0,subject_id,gender,dob,dod,dod_hosp,dod_ssn
0,26284,F,2072-09-28,NaT,NaT,NaT
1,34981,F,2099-12-23,NaT,NaT,NaT
2,29076,F,2096-01-16,NaT,NaT,NaT
3,27157,M,2094-05-20,2184-12-27,2184-12-27,NaT
4,59536,M,2065-01-16,2162-04-02,2162-04-02,NaT
5,47957,M,2055-02-03,NaT,NaT,NaT
6,69407,F,2117-03-08,NaT,NaT,NaT
7,99503,M,2048-06-20,NaT,NaT,NaT
8,12732,M,2101-01-14,NaT,NaT,NaT
9,21093,M,2032-02-25,NaT,NaT,NaT


---

## Evaluate data

In [99]:
from sdv.evaluation import evaluate

In [100]:
evaluate(sample, train_data, metrics=['ContinuousKLDivergence'])

nan

In [101]:
evaluate(sample, train_data, metrics=['DiscreteKLDivergence'])

nan

---

### Evaluate timeseries data (not accurate in this table)

In [102]:
from sdv.metrics.timeseries import LSTMDetection, TSFCDetection

In [103]:
metadata_2 = model.get_metadata().to_dict()

In [104]:
metadata_2['entity_columns']=['subject_id']

In [105]:
metadata_2

{'fields': {'subject_id': {'type': 'numerical',
   'subtype': 'integer',
   'transformer': 'integer'},
  'gender': {'type': 'categorical', 'transformer': None},
  'dob': {'type': 'datetime', 'transformer': 'datetime'},
  'dod': {'type': 'datetime', 'transformer': 'datetime'},
  'dod_hosp': {'type': 'datetime', 'transformer': 'datetime'},
  'dod_ssn': {'type': 'datetime', 'transformer': 'datetime'}},
 'constraints': [{'constraint': 'sdv.constraints.tabular.CustomConstraint',
   'column_names': ['dod_hosp', 'dod_ssn']},
  {'constraint': 'sdv.constraints.tabular.CustomConstraint',
   'column_names': ['dob']},
  {'constraint': 'sdv.constraints.tabular.ScalarInequality',
   'column_name': 'dob',
   'relation': '<=',
   'value': '2200-01-01'},
  {'constraint': 'sdv.constraints.tabular.ScalarInequality',
   'column_name': 'dob',
   'relation': '>=',
   'value': '2000-01-01'},
  {'constraint': 'sdv.constraints.tabular.Inequality',
   'low_column_name': 'dob',
   'high_column_name': 'dod'}],
 '

In [106]:
LSTMDetection.compute(sample, train_data, metadata_2)

0.44680851063829785