In [19]:
# from google.colab import drive
# drive.mount('/content/drive')

In [1]:
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

In [2]:
import pandas as pd
import numpy as np

In [3]:
date_set_path = "../temp_sets/"

## Read Tempdatasets

In [30]:
# read csv file as dataframe, and drop ROW_ID column
def read_csv_no_rowid(file_path):
    df = pd.read_csv(file_path)
    low_list = [x.lower() for x in df.columns]
    df.columns = low_list
    df.drop(['unnamed: 0.1', 'unnamed: 0', 'row_id'], axis=1, inplace=True)

    return df

In [31]:
# check NaN value

def nan_count(df):
    print("Total columns: " + str(len(df.columns)))
    print("Total rows: " + str(len(df)))
    print("--------------")
    print(df.isnull().sum())

### Chartevents

In [32]:
chartevents_sample_df = read_csv_no_rowid(date_set_path + "CHARTEVENTS_random_sample_1.csv")

In [33]:
chartevents_sample_df.head()

Unnamed: 0,subject_id,hadm_id,icustay_id,itemid,charttime,storetime,cgid,value,valuenum,valueuom,warning,error,resultstatus,stopped
0,10694,138159,294193.0,220210,2153-08-21 10:34:00,,,27.0,27.0,insp/min,0.0,0.0,,
1,1459,172420,212644.0,224162,2195-11-10 11:56:00,2195-11-10 11:57:00,15047.0,8.0,8.0,insp/min,0.0,0.0,,
2,8492,118470,225777.0,220210,2117-07-03 14:00:00,2117-07-03 15:01:00,19593.0,19.0,19.0,insp/min,0.0,0.0,,
3,10694,138159,294193.0,220293,2153-08-14 06:13:00,,,18.0,18.0,L/min,0.0,0.0,,
4,12831,119477,295273.0,224697,2195-09-19 12:02:00,,,11.0,11.0,cmH2O,0.0,0.0,,


In [34]:
# Drop useless colums
chartevents_sample_df.drop(['storetime', 'cgid', 'valuenum', 'valueuom', 'warning', 'error', 'resultstatus', 'stopped'], axis=1, inplace=True)

In [35]:
chartevents_sample_df.dropna(subset=['icustay_id'], inplace=True, axis=0)

In [36]:
chartevents_sample_df.loc[pd.isna(chartevents_sample_df['icustay_id']) == True]

Unnamed: 0,subject_id,hadm_id,icustay_id,itemid,charttime,value


In [37]:
# Transfer some date type
chartevents_sample_df['icustay_id'] = chartevents_sample_df['icustay_id'].astype(int)
chartevents_sample_df['charttime'] = pd.to_datetime(chartevents_sample_df['charttime'])

In [38]:
chartevents_sample_df.dtypes

subject_id             int64
hadm_id                int64
icustay_id             int32
itemid                 int64
charttime     datetime64[ns]
value                 object
dtype: object

In [39]:
nan_count(chartevents_sample_df)

Total columns: 6
Total rows: 825983
--------------
subject_id       0
hadm_id          0
icustay_id       0
itemid           0
charttime        0
value         5164
dtype: int64


In [40]:
chartevents_sample_df['value'].fillna(value='Na', inplace=True)

In [41]:
nan_count(chartevents_sample_df)

Total columns: 6
Total rows: 825983
--------------
subject_id    0
hadm_id       0
icustay_id    0
itemid        0
charttime     0
value         0
dtype: int64


In [42]:
chartevents_sample_df.head()

Unnamed: 0,subject_id,hadm_id,icustay_id,itemid,charttime,value
0,10694,138159,294193,220210,2153-08-21 10:34:00,27.0
1,1459,172420,212644,224162,2195-11-10 11:56:00,8.0
2,8492,118470,225777,220210,2117-07-03 14:00:00,19.0
3,10694,138159,294193,220293,2153-08-14 06:13:00,18.0
4,12831,119477,295273,224697,2195-09-19 12:02:00,11.0


##  Build Network

---

### CTGAN

In [43]:
from sdv.tabular import CTGAN
from sdv.evaluation import evaluate
from sdv.constraints import FixedCombinations

In [44]:
fixed_subject_hadm_icustay_constraint = FixedCombinations(
    column_names=['subject_id', 'hadm_id', 'icustay_id'],
    handling_strategy='transform'
)

fixed_item_value_constraint = FixedCombinations(
    column_names=['itemid', 'value'],
    handling_strategy='transform'
)

In [45]:
chartevents_constraints = [fixed_subject_hadm_icustay_constraint, fixed_item_value_constraint]

In [46]:
model = CTGAN(constraints=chartevents_constraints, cuda=True, verbose=True, epochs=100)

In [47]:
len(chartevents_sample_df)

825983

In [None]:
model.fit(chartevents_sample_df.sample(n=100000))

In [None]:
sample = model.sample(num_rows=10000)

In [None]:
sample.drop_duplicates(subset=['icustay_id'])

Unnamed: 0,subject_id,hadm_id,icustay_id,itemid,charttime,value,valuenum,valueuom
0,10139,103843,236308,2269,2126-04-21 15:14:00,97,-4.085128,Na
1,4394,132535,201596,228332,2165-03-10 01:31:00,-21,-27.002549,BPM
2,11279,103849,246752,3753,2166-11-19 00:32:00,-19,-56.898943,units
3,19696,143133,211294,1995,2136-08-24 08:20:00,Social Services,-1.706791,ml/min
4,57493,103850,202438,3415,2126-01-31 19:37:00,Standard,-2.061842,Na
...,...,...,...,...,...,...,...,...
9984,91125,143100,245975,221906,2114-12-22 15:52:00,21.420000076293945,42.476885,Breath
9985,19718,137883,267225,223616,2126-02-24 06:15:00,5.5,40.287828,mmHg
9994,57485,103615,259051,225682,2113-06-13 05:29:00,8.3999996185302734,-2.848200,Na
9996,57645,103824,258357,226396,2136-11-11 13:32:00,108,-13.787698,in


In [None]:
evaluate(sample, chartevents_sample_df.loc[0:1000, :], aggregate=False)

### Timeseries table (error)

In [None]:
# from sdv.timeseries import PAR
# from sdv.constraints import FixedCombinations

In [None]:
# entity_columns = ['subject_id', 'hadm_id', 'icustay_id']
# context_columns = []
# sequence_index = 'charttime'

In [None]:
# fixed_itemid_value_constraint = FixedCombinations(column_names=['itemid', 'value'], handling_strategy='transform')

# constraints = [fixed_itemid_value_constraint]

In [None]:
# model = PAR(entity_columns=entity_columns,context_columns=context_columns,sequence_index=sequence_index, constraints=constraints)

In [None]:
# model.fit(chartevents_sample_df.loc[18170:19000, :])

In [None]:
# model.sample(num_sequences=2)

In [None]:
# # 检查是否存在entity=1的情况, 这种情况下回报错(bugs)
# sequences = chartevents_sample_df[['subject_id', 'hadm_id', 'icustay_id', 'itemid']].groupby(['subject_id', 'hadm_id', 'icustay_id', 'itemid']).size().reset_index().rename(columns={0: 'sequence_length'})
# sequences[sequences['sequence_length'] == 1]