## Get text embeddings

## Hardware check

In [8]:
# gpu check
! nvidia-smi

Mon Apr 24 19:31:40 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000000:3B:00.0 Off |                    0 |
| N/A   36C    P0    39W / 300W |     43MiB / 32768MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [9]:
! python --version

Python 3.9.7


In [1]:
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, Lambda
# from tensorflow.keras.models import Model
from tensorflow.keras import models
import pickle
import numpy as np
from tqdm import tqdm
# tqdm.pandas()
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import Callback, EarlyStopping
import pandas as pd
import json
from torch.utils.data import Dataset
from transformers import AutoTokenizer, pipeline, AutoModel
import smart_cond as sc
# from google.colab import files

2023-04-24 21:21:11.477425: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-04-24 21:21:11.521465: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Version Check

In [6]:
import tensorflow as tf
print(tf. __version__)

pickle.format_version

2.12.0


'4.0'

## Load Data

In [6]:
data_path = 'preprocessed_sepsis_data_with_text_sepsis_removed.pkl'
pkl = pickle.load(open(data_path, 'rb'))
data = pkl[0]
oc = pkl[1]
train_ind = pkl[2]
valid_ind = pkl[3]
test_ind = pkl[4]
del pkl

In [7]:
pred_window = 2  # hours
obs_windows = range(20, 124, 4)

# Remove train, val patients
data = data.merge(oc[['ts_ind', 'SUBJECT_ID']], on='ts_ind', how='left')
train_sub = oc.loc[oc.ts_ind.isin(train_ind)].SUBJECT_ID.unique()
valid_sub = oc.loc[oc.ts_ind.isin(valid_ind)].SUBJECT_ID.unique()
data = data.loc[~data.SUBJECT_ID.isin(train_sub)]
data = data.loc[~data.SUBJECT_ID.isin(valid_sub)]
oc = oc.loc[~oc.SUBJECT_ID.isin(train_sub)]
oc = oc.loc[~oc.SUBJECT_ID.isin(valid_sub)]

data.drop(columns=['SUBJECT_ID', 'TABLE'], inplace=True)
# Get static data with mean fill and missingness indicator.
static_varis = ['Age', 'Gender']
ii = data.variable.isin(static_varis)
static_data = data.loc[ii]
data = data.loc[~ii]


def inv_list(l, start=0):
    d = {}
    for i in range(len(l)):
        d[l[i]] = i+start
    return d


static_var_to_ind = inv_list(static_varis)
D = len(static_varis)
N = data.ts_ind.max()+1

# Get variable indices.
varis = sorted(list(set(data.variable)))
V = len(varis)
var_to_ind = inv_list(varis, start=1)
data['vind'] = data.variable.map(var_to_ind)
data = data[['ts_ind', 'vind', 'hour', 'value']
            ].sort_values(by=['ts_ind', 'vind', 'hour'])
# Find max_len.
fore_max_len = 880
# Get forecast inputs and outputs.
fore_texts_ip = []
fore_inds = []


def f(x):
    mask = [0 for i in range(V)]
    values = [0 for i in range(V)]
    for vv in x:
        v = int(vv[0])-1
        mask[v] = 1
        values[v] = vv[1]
    return values+mask


def pad(x):
    return x+[0]*(fore_max_len-len(x))


for w in tqdm(obs_windows):
    pred_data = data.loc[(data.hour >= w) & (data.hour <= w+pred_window)]
    pred_data = pred_data.groupby(['ts_ind', 'vind']).agg(
        {'value': 'first'}).reset_index()
    pred_data['vind_value'] = pred_data[['vind', 'value']].values.tolist()
    pred_data = pred_data.groupby('ts_ind').agg(
        {'vind_value': list}).reset_index()
    pred_data['vind_value'] = pred_data['vind_value'].apply(f)
    obs_data = data.loc[(data.hour < w) & (data.hour >= w-24)]
    obs_data = obs_data.loc[obs_data.ts_ind.isin(pred_data.ts_ind)]
    obs_data = obs_data.groupby('ts_ind').head(fore_max_len)
    obs_data = obs_data.groupby('ts_ind').agg(
        {'vind': list, 'hour': list, 'value': list}).reset_index()
    obs_data = obs_data.merge(pred_data, on='ts_ind')
    for col in ['vind', 'hour', 'value']:
        obs_data[col] = obs_data[col].apply(pad)
    fore_inds.append(np.array(list(obs_data.ts_ind)))

    matrix = list(obs_data.value)
    obs_strings = []
    for l in matrix:
        string_list = []
        for value in l:
            if isinstance(value, str):
                string_list.append(value)
        obs_strings.append(string_list)
    del matrix
    fore_texts_ip.append(np.array(obs_strings))
del data

fore_texts_ip = np.concatenate(fore_texts_ip, axis=0)

  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(ob

In [9]:
len(fore_texts_ip)

265307

In [10]:
fore_test_concat_text_ip = []
# concat train texts per instance
for text in tqdm(fore_texts_ip):
    concat_text = ' '.join(text)
    fore_test_concat_text_ip.append(concat_text)

100%|██████████| 265307/265307 [00:01<00:00, 238570.04it/s]


In [8]:
# pred_window = 2 # hours
# obs_windows = range(20, 124, 4)

# # Remove test patients.
# data = data.merge(oc[['ts_ind', 'SUBJECT_ID']], on='ts_ind', how='left')
# test_sub = oc.loc[oc.ts_ind.isin(test_ind)].SUBJECT_ID.unique()
# data = data.loc[~data.SUBJECT_ID.isin(test_sub)]
# oc = oc.loc[~oc.SUBJECT_ID.isin(test_sub)]
# data.drop(columns=['SUBJECT_ID', 'TABLE'], inplace=True)
# # Get static data with mean fill and missingness indicator.
# static_varis = ['Age', 'Gender']
# ii = data.variable.isin(static_varis)
# static_data = data.loc[ii]
# data = data.loc[~ii]
# def inv_list(l, start=0):
#     d = {}
#     for i in range(len(l)):
#         d[l[i]] = i+start
#     return d
# static_var_to_ind = inv_list(static_varis)
# D = len(static_varis)
# N = data.ts_ind.max()+1

# # Get variable indices.
# varis = sorted(list(set(data.variable)))
# V = len(varis)
# var_to_ind = inv_list(varis, start=1)
# data['vind'] = data.variable.map(var_to_ind)
# data = data[['ts_ind', 'vind', 'hour', 'value']].sort_values(by=['ts_ind', 'vind', 'hour'])
# # Find max_len.
# fore_max_len = 880
# # Get forecast inputs and outputs.
# fore_texts_ip = []
# fore_inds = []
# def f(x):
#     mask = [0 for i in range(V)]
#     values = [0 for i in range(V)]
#     for vv in x:
#         v = int(vv[0])-1
#         mask[v] = 1
#         values[v] = vv[1]
#     return values+mask
# def pad(x):
#     return x+[0]*(fore_max_len-len(x))
# for w in tqdm(obs_windows):
#     pred_data = data.loc[(data.hour>=w)&(data.hour<=w+pred_window)]
#     pred_data = pred_data.groupby(['ts_ind', 'vind']).agg({'value':'first'}).reset_index()
#     pred_data['vind_value'] = pred_data[['vind', 'value']].values.tolist()
#     pred_data = pred_data.groupby('ts_ind').agg({'vind_value':list}).reset_index()
#     pred_data['vind_value'] = pred_data['vind_value'].apply(f)    
#     obs_data = data.loc[(data.hour<w)&(data.hour>=w-24)]
#     obs_data = obs_data.loc[obs_data.ts_ind.isin(pred_data.ts_ind)]
#     obs_data = obs_data.groupby('ts_ind').head(fore_max_len)
#     obs_data = obs_data.groupby('ts_ind').agg({'vind':list, 'hour':list, 'value':list}).reset_index()
#     obs_data = obs_data.merge(pred_data, on='ts_ind')
#     for col in ['vind', 'hour', 'value']:
#         obs_data[col] = obs_data[col].apply(pad)
#     fore_inds.append(np.array(list(obs_data.ts_ind)))
    
#     matrix = list(obs_data.value)
#     obs_strings = []
#     for l in matrix:
#         string_list = []
#         for value in l:
#             if isinstance(value, str):
#                 string_list.append(value)
#         obs_strings.append(string_list)
#     del matrix
#     fore_texts_ip.append(np.array(obs_strings)) 
# del data
# fore_texts_ip = np.concatenate(fore_texts_ip, axis=0)
# fore_inds = np.concatenate(fore_inds, axis=0)
# # Get train and valid ts_ind for forecast task.
# train_sub = oc.loc[oc.ts_ind.isin(train_ind)].SUBJECT_ID.unique()
# valid_sub = oc.loc[oc.ts_ind.isin(valid_ind)].SUBJECT_ID.unique()
# rem_sub = oc.loc[~oc.SUBJECT_ID.isin(np.concatenate((train_ind, valid_ind)))].SUBJECT_ID.unique()
# bp = int(0.8*len(rem_sub))
# train_sub = np.concatenate((train_sub, rem_sub[:bp]))
# valid_sub = np.concatenate((valid_sub, rem_sub[bp:]))
# train_ind = oc.loc[oc.SUBJECT_ID.isin(train_sub)].ts_ind.unique() # Add remaining ts_ind s of train subjects.
# valid_ind = oc.loc[oc.SUBJECT_ID.isin(valid_sub)].ts_ind.unique() # Add remaining ts_ind s of train subjects.
# # Generate 3 sets of inputs and outputs.
# train_ind = np.argwhere(np.in1d(fore_inds, train_ind)).flatten()
# valid_ind = np.argwhere(np.in1d(fore_inds, valid_ind)).flatten()

# fore_train_ip = [ip[train_ind] for ip in [fore_texts_ip]]
# fore_valid_ip = [ip[valid_ind] for ip in [fore_texts_ip]]
# del fore_texts_ip

# fore_train_text_ip = fore_train_ip[0]
# fore_valid_text_ip = fore_valid_ip[0]
# del fore_train_ip, fore_valid_ip

  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(obs_strings))
  fore_texts_ip.append(np.array(ob

In [36]:
len(fore_train_text_ip)

452944

In [37]:
len(fore_valid_text_ip)

53331

In [29]:
fore_train_concat_text_ip = []
# concat train texts per instance
for text in tqdm(fore_train_text_ip):
    concat_text = ' '.join(text)
    fore_train_concat_text_ip.append(concat_text)

100%|██████████| 452944/452944 [00:01<00:00, 355007.87it/s]


In [30]:
fore_valid_concat_text_ip = []
# concat train texts per instance
for text in tqdm(fore_valid_text_ip):
    concat_text = ' '.join(text)
    fore_valid_concat_text_ip.append(concat_text)

100%|██████████| 53331/53331 [00:00<00:00, 285661.92it/s]


In [31]:
# dump to pkl
pickle.dump([fore_train_concat_text_ip, fore_valid_concat_text_ip], open('text_ip.pkl','wb'))

In [32]:
data_path = 'text_ip.pkl'
pkl = pickle.load(open(data_path, 'rb'))

In [34]:
len(pkl[0])

452944

In [35]:
len(pkl[1])

53331