In [None]:
# Imports necessary libraries 

import time
from datasets import load_dataset,DatasetDict, load_metric
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import random_split
from transformers import DataCollatorWithPadding,AutoModelForSequenceClassification 
from transformers import Trainer, TrainingArguments,AutoTokenizer,AutoModel,AutoConfig
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import pickle
import json

from transformers import AdamW,get_scheduler, get_linear_schedule_with_warmup
from accelerate import Accelerator, DistributedType, notebook_launcher
from accelerate.utils import set_seed
import random
import math
import statistics
import copy
from functools import partial

from ihan import *

import os
# Restrict this process to the GPU with index 1.
# NOTE(review): this only takes effect if it runs before CUDA is first initialised — confirm
# that nothing imported above (e.g. `ihan`) touches the GPU at import time.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
# Make CUDA kernel launches synchronous so errors are reported at the failing call.
# This is a debugging aid and slows training; consider removing it for production runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
#os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator in this process.
torch.cuda.empty_cache()
# IPython shell escape: show the current GPU utilisation / memory as a sanity check.
!nvidia-smi

## MODEL TRAINING

In [None]:
# Name of the label column in the modelling tables.
label = 'TARGET'
# NOTE(review): `target` is not referenced anywhere else in the visible notebook — confirm it is still needed.
target = "risk_level"

# Medical code families and the number of unique codes of each type, plus one.
# NOTE(review): the "+ 1" presumably reserves one extra index (e.g. padding) — confirm against ihan.
cd1, num_codes1 = 'diag', 23086 + 1
cd2, num_codes2 = 'proc', 8626 + 1
cd3, num_codes3 = 'rvnu', 398 + 1
cd4, num_codes4 = 'gpi', 91 + 1

# Directory holding the original data files.
indir = '/home/jovyan/vol-2/Roja/ds-ihan/v2/high_cost_newborn_mothers_pred_ihan/iter4_v1_train_oot/'
# Directory where the outputs (models, train/val/test data) are stored.
outdir = '/home/jovyan/vol-2/Roja/ds-ihan/v2/high_cost_newborn_mothers_pred_ihan/iter4_v1_train_oot/output'



In [None]:
# Load the training and validation tables, keeping only the columns the notebook needs:
# the member id, the label, the code-sequence columns (DOSn / DATLISTn) and the static features.

cols = ['MCID']+[label]+ ['DOS1', 'DATLIST1',
       'DOS2', 'DATLIST2', 'DOS3', 'DATLIST3', 'DOS4', 'DATLIST4', 'BABY_BIRTH_DT', 'MOM_AGE','HIGH_RISK_PROC_CD_CNT',
       'TRANSVAGINAL_SONOGRAMS', 'VASA_PRIVIA', 'ANATOMICAL_FETAL_SURVEY',
       'MLTPL_GEST', 'EPILEPSY', 'SUM_INPATIENT_VISIT', 'BLOOD_PRESR',
       'NT_TEST', 'COMPL_SMKNG_ALCHL', 'TRANSVAGINAL_ULTRASOUND_EXAM',
       'NST_TEST', 'SUM_EMERGENCY_VISIT', 'PAID_9_MONTHS', 'DIABETES',
       'INFERTILITY', 'PREECLAMPSIA', 'PAPP', 'HIGH_RISK_DIAG_CNT',
       'DOPPLER_FLOW_STUDIES', 'IVF', 'COMPLICATIONS_MLTPL_GEST',
       'TRANSCERVICAL_AND_TRANSABDOMINAL_CHORIONIC_VILLUS_SAMPLING',
       'FETAL_ECHOCARDIOGRAPHY', 'ENDOMETROSIS', 'PCOS', 'AMNIOCENTESIS',
       'SUM_OUTPATIENT_VISIT', 'GEST_DIABETES', 'HCG_SCREENING',
       'PAID_12_MONTHS', 'PAID_6_MONTHS','PAID_3_MONTHS', 'CANCER_HIST_PRSNL', 'AFI_INDEX',
       'GROWTH_RESTRICTION', 'OBESITY_HIGH_BMI', 'INFERTILITY_MEDICINE',
       'CANCER_HIST_FAMLY', 'BPP_PROFILE']

print(cols)

# `usecols` makes pandas parse only the requested columns instead of reading the whole
# file and discarding most of it. `usecols` keeps the file's column order, so the frame
# is re-indexed with `[cols]` to preserve the order requested above.
fname = indir+'/train_data_iter4_v1'+".csv"
print(fname)
train_data = pd.read_csv(fname, usecols=cols)[cols]

fname = indir+'/val_data_iter4_v1'+".csv"
print(fname)
val_data = pd.read_csv(fname, usecols=cols)[cols]

print(train_data.shape, val_data.shape)
print("model data size:", len(train_data) + len(val_data))

In [None]:
# Training configuration: loss, loop settings and model hyper-parameters.

# --- task / loss ---
task='binary'
num_classes = 1
criterion = nn.BCELoss(reduction='mean')     # binary cross-entropy with mean reduction

# --- training loop ---
n_epochs = 20
batch_size = 32
singleGPU=True
# NOTE(review): the meaning of the value 2 is defined by ihan's `train` — confirm before changing.
continueTrain = 2

# --- architecture ---
hiddenMethod = 'gru'
num_heads = 1
p_dropout=0.1

## IHAN Model Training

In [None]:
# Build the datasets, model and optimizer, then launch IHAN training.

# Seed the Python, NumPy and PyTorch RNGs (accelerate's set_seed, imported at the top of
# the notebook) before the model is created, so repeated runs start from the same state.
SEED = 42
set_seed(SEED)

cols_seqCode = ['DATLIST1','DATLIST2']  ## specify medical code columns
cols_seqPair = []
cols_statFeature = ['MOM_AGE']  ## static features

num_seqs = len(cols_seqCode)
num_seqPairs = 0 if cols_seqPair is None else len(cols_seqPair)
num_static = 0 if cols_statFeature is None else len(cols_statFeature)
# True whenever a static-feature list is supplied (note: also True for an empty list).
statFeature = cols_statFeature is not None
print(num_seqs, num_seqPairs, num_static, statFeature)

train_dataset = CustomDataset(train_data, tgt=label, cols_seqCode = cols_seqCode,
                              cols_seqPair = cols_seqPair, cols_statFeature = cols_statFeature)
val_dataset = CustomDataset(val_data, tgt=label, cols_seqCode = cols_seqCode,
                              cols_seqPair = cols_seqPair, cols_statFeature = cols_statFeature)

print(train_dataset, '\n', val_dataset, '\n')

# One vocabulary size per entry of cols_seqCode, in the same order.
# NOTE(review): this assumes DATLIST1 holds 'diag' codes and DATLIST2 holds 'proc' codes — keep
# this list in sync by hand whenever cols_seqCode changes.
model=CustomModel(num_classes,num_seqCode=[num_codes1,num_codes2], num_seqPair = [] , num_static=num_static, embedding_dim=128, hiddenMethod = hiddenMethod,num_heads = num_heads,p_dropout=p_dropout)
print(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
# Bind the dataset layout to the collate function handed to the training routine.
collate_fn_arg = partial(collate_fn, num_seq_vars=num_seqs, num_seqPairs_vars=num_seqPairs, statFeature=statFeature, max_ncodes_perVisit = 512)

print("start training ")
start_time = time.time()
# notebook_launcher runs `train` with the given positional arguments in a single process.
notebook_launcher(train, (model, train_dataset, val_dataset, batch_size, collate_fn_arg, criterion, optimizer, 
                                 n_epochs, outdir,singleGPU,task,continueTrain),num_processes=1) 
print("Training took (minutes): ", (time.time() - start_time)/60)

## PREDICTION ON OOT TEST DATA

In [None]:
# Location of the OOT test set; the prediction cells read this same path again in chunks.
dataFile = '/home/jovyan/vol-2/Roja/ds-ihan/v2/high_cost_newborn_mothers_pred_ihan/iter4_v1_train_oot/test_oot_iter_4_v1.csv'
df = pd.read_csv(dataFile)
# Quick sanity check: dimensions first, then the column names.
print(df.shape, df.columns, sep='\n')

In [None]:
# Load the trained model checkpoint and switch it to inference mode.
modelFile = '/home/jovyan/vol-2/Roja/ds-ihan/v2/high_cost_newborn_mothers_pred_ihan/iter4_v1_train_oot/output/ihanModel_201_7epochs_auc0.89268_1731570392.sav'
# Map the checkpoint onto the GPU when one is visible, otherwise fall back to CPU so the
# cell does not fail on hosts without CUDA.
map_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): torch.load unpickles arbitrary Python objects — only load checkpoints from a
# trusted location. Recent PyTorch releases default to weights_only=True, which rejects
# fully pickled models like this one; pass weights_only=False explicitly if that applies.
model = torch.load(modelFile, map_location=map_device)
# eval() disables training-only behaviour such as dropout.
model.eval()

In [None]:
# Settings shared by the prediction and interpretation cells that follow.
task = 'binary'
batch_size = 16
tgt ='TARGET'
ID='MCID'

# Model inputs: medical-code sequence columns, paired sequences (none) and static features.
cols_seqCode = ['DATLIST1','DATLIST2']
cols_seqPair = []
cols_statFeature = ['MOM_AGE']

# Derived counts; a column list of None is treated as "not used".
num_seqs = len(cols_seqCode)
num_seqPairs = len(cols_seqPair) if cols_seqPair is not None else 0
num_static = len(cols_statFeature) if cols_statFeature is not None else 0
statFeature = cols_statFeature is not None
print(num_seqs, num_seqPairs, num_static, statFeature)

# Date-of-service column and code type for each entry of cols_seqCode, in the same order.
dos_cols=['DOS1', 'DOS2']
med_type = ['diag','proc']
# The paired-sequence counterparts are empty because cols_seqPair is empty.
dos_cols_p = []
med_type_p = []
cols_pairList = []

# Flags passed to pred_datainchunks.
evaluation=True
interpret = False

In [None]:
# Score the full OOT test file with the loaded model.
# With interpret=False (set in the settings cell) the helper's single return value is kept in pred_df.
# NOTE(review): chunksize is presumably the number of CSV rows processed per chunk — confirm in ihan.
chunksize = 400
pred_df=pred_datainchunks(model, dataFile, chunksize, ID, tgt,cols_seqCode,cols_seqPair,cols_statFeature,batch_size,interpret,task,dos_cols, med_type, dos_cols_p,med_type_p,cols_pairList,evaluation)

## Babies born between Jan 2024 and Jun 2024

In [None]:
# Location of the 6-month OOT subset; the prediction cell reads this same path again in chunks.
dataFile1 = '/home/jovyan/vol-2/Roja/ds-ihan/v2/high_cost_newborn_mothers_pred_ihan/iter4_v1_train_oot/test_oot_iter_4_v1_6mnth.csv'
df = pd.read_csv(dataFile1)
# Quick sanity check: dimensions first, then the column names.
print(df.shape, df.columns, sep='\n')

In [None]:
# Score the 6-month OOT subset (dataFile1) with the same settings as the full OOT run.
# Note: this overwrites pred_df from the previous prediction cell.
chunksize = 400
pred_df=pred_datainchunks(model, dataFile1, chunksize, ID, tgt,cols_seqCode,cols_seqPair,cols_statFeature,batch_size,interpret,task,dos_cols, med_type, dos_cols_p,med_type_p,cols_pairList,evaluation)

## Babies born between Jul 2024 and Sep 2024

In [None]:
# Read the 3-month OOT subset from the specified path; the prediction cell reads it again in chunks.
dataFile2 = '/home/jovyan/vol-2/Roja/ds-ihan/v2/high_cost_newborn_mothers_pred_ihan/iter4_v1_train_oot/test_oot_iter_4_v1_3mnth.csv'
df = pd.read_csv(dataFile2)
# Quick sanity check: dimensions first, then the column names.
print(df.shape, df.columns, sep='\n')

In [None]:
# Score the 3-month OOT subset (dataFile2) with the same settings as the full OOT run.
# Note: this overwrites pred_df from the previous prediction cell.
chunksize = 400
pred_df=pred_datainchunks(model, dataFile2, chunksize, ID, tgt,cols_seqCode,cols_seqPair,cols_statFeature,batch_size,interpret,task,dos_cols, med_type, dos_cols_p,med_type_p,cols_pairList,evaluation)

## Interpretation on OOT data

In [None]:
# Switch pred_datainchunks into interpretation mode for the cells below.
# With interpret=True the helper's result is unpacked into three contribution tables.
interpret = True
evaluation=True

In [None]:
# Run interpretation on the full OOT test file (dataFile).
# The result is unpacked into three tables:
#   df_contribution               - used below to build the per-record table saved to Snowflake
#   df_contribution_mcid_code_sum - NOTE(review): presumably contributions summed per MCID and code — confirm
#   df_contribution_code_summary  - per-code summary that is ranked by contribCoef below
chunksize = 400
df_contribution, df_contribution_mcid_code_sum, df_contribution_code_summary = pred_datainchunks(model, dataFile, chunksize, ID, tgt,
                    cols_seqCode,cols_seqPair,cols_statFeature,batch_size,interpret,task,
                    dos_cols, med_type, dos_cols_p,med_type_p,cols_pairList,evaluation)

In [None]:
df_contribution_code_summary

## Read the code dictionary table

In [None]:
dict_data=pd.read_csv('dict_indx_desc.csv')

In [None]:
dict_data.head()

In [None]:
dict_data.isnull().sum()

In [None]:
dict_data.info()

In [None]:
dict_data['CD_TYPE'].value_counts()

In [None]:
# Map the dictionary's code-type labels to the short names used on the model side
# ('diag', 'proc', 'rvnu', 'gpi'), so the later merges on type/CD_TYPE can match.
# Values not listed in the mapping are left unchanged.
dict_data['CD_TYPE']=dict_data['CD_TYPE'].replace({'DIAG_CD':'diag','RVNU_CD':'rvnu','HLTH_SRVC_CD':'proc','GPI_02_GRP_CD':'gpi'})

In [None]:
dict_data['CD_TYPE'].value_counts()

## Create a rank for each medical code based on its contribution coefficient score

In [None]:
# Rank the codes by contribution coefficient: rank 1 is the largest contribCoef.
# Any 'rank' column left by an earlier run is dropped first so the cell can be re-run
# safely; previously a second run produced two 'rank' columns.
df_contribution_code_summary = (
    df_contribution_code_summary
    .drop(columns=['rank'], errors='ignore')
    .sort_values(by='contribCoef', ascending=False)
    .reset_index(drop=True)
)
# Insert the 1-based rank as the first column, matching the original column layout.
df_contribution_code_summary.insert(0, 'rank', np.arange(1, len(df_contribution_code_summary) + 1))
df_contribution_code_summary.head()

## Top risk-increasing and risk-decreasing codes

In [None]:
# merge with dict table to get description and code for the medical code index

In [None]:
# Codes in ascending rank order, keeping only rows whose `count` is at least 5.
df_top = (
    df_contribution_code_summary
    .sort_values(by='rank', ascending=True)
    .loc[lambda ranked: ranked['count'] >= 5]
    .reset_index(drop=True)
)
# Left-join the dictionary so each (codeIndex, type) pair gets its code value and description;
# codes without a dictionary entry are kept with missing CD_VALUE / DESCRIPTION.
df_merged_top = df_top.merge(
    dict_data,
    how='left',
    left_on=['codeIndex','type'],
    right_on=['IDX','CD_TYPE'],
)
# Keep only the columns needed for the report.
top_report_cols = ['rank','type','codeIndex','contribCoef','count','CD_VALUE','DESCRIPTION']
df_merged_top_code = df_merged_top[top_report_cols]

In [None]:
df_merged_top_code.head()

In [None]:
#df_merged_top_code.to_csv('top_contribution_code_IHAN201.csv')

In [None]:
# Bottom contribution codes: descending rank order, keeping only rows whose `count` is at least 5.
df_bottom = (
    df_contribution_code_summary
    .sort_values(by='rank', ascending=False)
    .reset_index(drop=True)
    .loc[lambda ranked: ranked['count'] >= 5]
)
# Left-join the dictionary so each (codeIndex, type) pair gets its code value and description.
df_merged_bottom = df_bottom.merge(
    dict_data,
    how='left',
    left_on=['codeIndex','type'],
    right_on=['IDX','CD_TYPE'],
)
# Keep only the columns needed for the report.
bottom_report_cols = ['rank','type','codeIndex','contribCoef','count','CD_VALUE','DESCRIPTION']
df_merged_bottom_code = df_merged_bottom[bottom_report_cols]

In [None]:
#df_merged_bottom_code.to_csv('bottom_contribution_code_IHAN201.csv',index=False)

## Merge the interpretation results with the code dictionary to map each medical code index to its actual code and description

In [None]:
def idx2code_merge(df,dict_df,cols=['IDX','TYPE','VALUE','DESCRIPTION']):
    """Attach the actual code and description to each row of an interpretation table.

    Rows whose ``type`` is ``'static'`` are not looked up in the dictionary: their
    ``codeIndex`` value is copied into both ``code`` and ``cd_desc``. All other rows are
    left-joined to the dictionary on ``(type, codeIndex)``.

    Neither ``df`` nor ``dict_df`` is modified; the function works on copies.

    Parameters
    ----------
    df : pandas.DataFrame
        Interpretation table with at least the columns ``type`` and ``codeIndex``.
    dict_df : pandas.DataFrame
        Code dictionary holding index, code type, code value and description.
    cols : list of str
        Names of the four dictionary columns, in the order
        (index, type, value, description). If all four are present in ``dict_df`` they
        are selected by name; otherwise ``dict_df`` must have exactly four columns,
        which are interpreted positionally (the previous behaviour).

    Returns
    -------
    pandas.DataFrame
        Dictionary-matched rows followed by the static rows, with ``code`` and
        ``cd_desc`` columns added and a fresh integer index.
    """
    # Build a private lookup table so the caller's dictionary keeps its column names.
    if all(name in dict_df.columns for name in cols):
        lookup = dict_df[list(cols)].copy()
    else:
        lookup = dict_df.copy()
    lookup.columns = ['codeIndex','type','code','cd_desc']
    lookup['codeIndex'] = lookup['codeIndex'].astype(int)

    is_static = df['type'] == 'static'

    # Static features have no dictionary entry; their codeIndex value doubles as
    # both code and description.
    df_static = df[is_static].copy()
    df_static['cd_desc'] = df_static['codeIndex']
    df_static['code'] = df_static['codeIndex']

    # Cast the remaining indices to int so they match the dictionary's key dtype.
    df_codes = df[~is_static].copy()
    df_codes['codeIndex'] = df_codes['codeIndex'].astype(int)

    merged = df_codes.merge(lookup, on=['type','codeIndex'], how='left')
    return pd.concat([merged, df_static], ignore_index=True)

In [None]:
df_merged=idx2code_merge(df_contribution,dict_data,cols=['IDX','CD_TYPE','CD_VALUE','DESCRIPTION'])

In [None]:
df_merged.columns

In [None]:
df_pred=df_merged[['mcid', 'y_obs', 'y_score', 'dos', 'contribCoef', 'type','value', 'code', 'cd_desc']]

In [None]:
df_pred.head()

In [None]:
df_pred['type'].value_counts()

## Save the interpretation results to a Snowflake table

In [None]:
import getpass
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from dateutil.relativedelta import relativedelta
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas
import warnings
# Ignore all warnings
# NOTE(review): this silences every warning for the rest of the session, including pandas ones.
warnings.filterwarnings('ignore')

import sys
print(sys.version)

# Establish snowflake connection.
# Credentials are never stored in the notebook: the user name can be overridden with
# SNOWFLAKE_USER, and the password is read from SNOWFLAKE_PASSWORD or, when that is
# unset, requested interactively without echoing it.
conn = snowflake.connector.connect(
    account = "carelon-edaprod1.privatelink",
    user = os.environ.get("SNOWFLAKE_USER", "HAID"),
    password = os.environ.get("SNOWFLAKE_PASSWORD") or getpass.getpass("Snowflake password: "),
    warehouse = "DL_AIFS_STAR_USER_WH_L",
    database = "NON_CRTFD_AIFS",
    schema = "DL_TS_STAR"
)

# Create a cursor object
cursor = conn.cursor()

In [None]:
# Fully qualified name of the table that receives the OOT interpretation scores.
test_table_nogbd='NON_CRTFD_AIFS.DL_TS_STAR.HIGH_RISK_NEWBORNS_MOTHER_OOT_TEST_SCORE_ITER4_V1_IHAN201'
# Remove any previous version of the table so the upload starts from an empty one.
query = f"""
DROP TABLE IF EXISTS {test_table_nogbd}
"""
print(query)
cursor.execute(query)

In [None]:
# Create the target table (if needed) and upload df_pred into it.
query = """
CREATE TABLE IF NOT EXISTS {0} 
(mcid STRING,y_obs NUMBER,y_score FLOAT,dos STRING,contribCoef FLOAT,type STRING,value FLOAT,code STRING,cd_desc STRING)
""".format(test_table_nogbd)
cursor.execute(query)
# quote_identifiers=False lets Snowflake resolve the mixed-case DataFrame column names
# (e.g. contribCoef) against the unquoted columns created above.
success, nchunks, nrows, _ = write_pandas(conn, df_pred, test_table_nogbd, quote_identifiers=False)
# Fail loudly instead of silently continuing with a partial or failed upload.
if not success or nrows != len(df_pred):
    raise RuntimeError(
        "Upload to {0} incomplete: success={1}, rows written={2}, rows expected={3}".format(
            test_table_nogbd, success, nrows, len(df_pred)
        )
    )
print("Wrote {0} rows in {1} chunk(s) to {2}".format(nrows, nchunks, test_table_nogbd))

In [None]:
# Read a few rows back from the table that was just written, as a sanity check.
# The name comes from test_table_nogbd so this query cannot drift from the table
# targeted by the DROP / CREATE / upload cells.
query = f"""select * from {test_table_nogbd} LIMIT 5 """
data=pd.read_sql_query(query,conn)

In [None]:
data.head()