## Check GPU properties

In [1]:
# !nvidia-smi

## Libraries to install before starting

In [2]:
# !pip install overrides==3.1.0
# !pip install allennlp==0.8.4
# !pip install pytorch-pretrained-bert
# !pip install transformers==4.4.2
# !pip install entmax

## Static arguments and model hyper-parameter declarations

In [3]:
# Notebook-wide configuration read by the data-processing and model cells.
# Absolute paths below are environment-specific; adjust them before running.
args = dict(
    origin_path='/content/drive/MyDrive/Thesis/DeepDifferentialDiagnosis/data/MIMIC-III/mimic-iii-clinical-database-1.4/',
    out_path='/content/drive/MyDrive/Thesis/DeepDifferentialDiagnosis/data/25.10.2021-Old-Compare/test/',
    min_sentence_len=3,
    random_seed=1,
)
# Derived entries are filled in right after the keys they depend on, keeping key order stable.
args['vocab'] = '%s%s.csv' % (args['out_path'], 'vocab')
args.update(
    vocab_min=3,
    Y='full',  # label set: '50' or 'full'
)
args['data_path'] = '%s%s_%s.csv' % (args['out_path'], 'train', args['Y'])
args.update(
    version='mimic3',
    model='KG_MultiResCNN',  # alternatives: 'KG_MultiResCNNLSTM', 'bert_we_kg'
    gpu=0,
)
args['embed_file'] = '%s%s_%s.embed' % (args['out_path'], 'processed', 'full')
args.update(
    use_ext_emb=False,
    dropout=0.2,
    num_filter_maps=50,
    conv_layer=2,
    filter_size='3,5,7,9,13,15,17,23,29',  # also tried: '3,5,9,15,19,25'
    test_model=None,
    weight_decay=0,  # 0 for Adam, 0.01 for AdamW
    lr=0.001,  # tried 0.0005 / 0.001 / 0.00146 for Adam(W); 1e-5 for BERT
    tune_wordemb=True,
    # Token budget per admission (512 for BERT). Discharge summaries alone average
    # 1878 tokens (max 10504); DS+PHY+NUR combined average 3056 (max 238438).
    MAX_LENGTH=3000,
    batch_size=6,  # also tried 8 and 16
    n_epochs=15,
    MODEL_DIR='/content/drive/MyDrive/Thesis/DeepDifferentialDiagnosis/data/model_output',
    criterion='prec_at_8',
    for_test=False,
    bert_dir='/content/drive/MyDrive/Thesis/DeepDifferentialDiagnosis/data/Bert/',
    # alternatives: 'emilyalsentzer/Bio_ClinicalBERT', 'dmis-lab/biobert-base-cased-v1.1'
    pretrained_bert='bert-base-uncased',
    instance_count='full',  # any other value is the number of samples to use
    graph_embedding_file='/home/pgoswami/DifferentialEHR/data/Pytorch-BigGraph/wikidata_translation_v1.tsv.gz',
    entity_dimention=200,  # PyTorch-BigGraph entity vectors have 200 dimensions
    MAX_ENT_LENGTH=30,  # mean entity count 27.33; DS+PHY+NUR: max 49, avg 29
    use_embd_layer=True,
    add_with_wordrap=True,
    step_size=8,
    gamma=0.1,
    patience=10,  # break after this many epochs without improvement
    use_schedular=True,
    grad_clip=False,
    use_entmax15=False,
    use_sentiment=False,
    sentiment_bert='siebert/sentiment-roberta-large-english',
    use_tfIdf=True,
    use_proc_label=True,
    # comma-separated NOTEEVENTS categories, e.g. 'Discharge summary,Nursing,Physician '
    notes_type='Discharge summary',
    comment=""" My changes with 3000 token+30 ent embed. Tf-idf weight. Diag_ICD+Prod_ICD used, 50 codes """,
    save_everything=True,
)

###############

## Data Processing ##

In [None]:
# Wikipedia dump download and indexing (one-off setup for the WikiMapper index database used during entity linking)
# import wikimapper
# wikimapper.download_wikidumps(dumpname="enwiki-latest", path="/home/pgoswami/DifferentialEHR/data/Wikidata_dump/")
# wikimapper.create_index(dumpname="enwiki-latest",path_to_dumps="/home/pgoswami/DifferentialEHR/data/Wikidata_dump/", 
#                         path_to_db= "/home/pgoswami/DifferentialEHR/data/Wikidata_dump/index_enwiki-latest.db")

In [None]:
# The PyTorch-BigGraph pre-trained Wikidata embedding file was downloaded from
# https://github.com/facebookresearch/PyTorch-BigGraph#pre-trained-embeddings
# to '/home/pgoswami/DifferentialEHR/data/Pytorch-BigGraph/wikidata_translation_v1.tsv.gz'

In [None]:
class ProcessedIter(object):
    """Re-iterable corpus of tokenized texts for gensim training.

    Every ``iter()`` re-opens ``filename``, so gensim can make several passes
    (vocabulary building plus each training epoch). Each item is the
    whitespace-split third column of one data row; the header row is skipped.

    Args:
        Y: label-set name ('full'/'50'); accepted for interface compatibility, unused.
        filename: path of the CSV file to stream (e.g. disch_full.csv with
            columns SUBJECT_ID, HADM_ID, TEXT).
    """

    def __init__(self, Y, filename):
        self.filename = filename

    def __iter__(self):
        # Imported here so this cell does not depend on the later import cell.
        import csv
        # NOTE(review): very long TEXT fields need csv.field_size_limit raised
        # beforehand (the processing pipeline does this before training).
        with open(self.filename, newline='') as f:  # newline='' as the csv module requires
            r = csv.reader(f)
            # A bare next() on an empty file would raise StopIteration inside this
            # generator, which Python converts into RuntimeError (PEP 479).
            if next(r, None) is None:
                return
            for row in r:
                yield row[2].split()  # after group-by with subj_id and hadm_id, text is in 3rd column

In [None]:
# Imports and shared NLP helpers for the data-processing cell below.
import pandas as pd
import numpy as np
from collections import Counter, defaultdict
import csv
import sys
import operator
from scipy.sparse import csr_matrix
from tqdm import tqdm
import gensim.models
import gensim.models.word2vec as w2v
import gensim.models.fasttext as fasttext
import nltk
# Fetch the punkt sentence-tokenizer data loaded just below.
nltk.download('punkt')
from nltk.tokenize import RegexpTokenizer
# Sentence splitter; DataProcessing.write_discharge_summaries uses its span_tokenize().
nlp_tool = nltk.data.load('tokenizers/punkt/english.pickle')
# Word tokenizer: keeps runs of word characters (\w+), dropping punctuation.
tokenizer = RegexpTokenizer(r'\w+')

import re
from transformers import pipeline #for entity extraction (NER pipeline)
from wikimapper import WikiMapper #maps page titles to wikidata entity ids

import pickle
import smart_open as smart  # smart_open can stream .gz files

class DataProcessing:
    """End-to-end MIMIC-III preprocessing pipeline; the constructor runs every step.

    All artifacts are written under ``args['out_path']``:

    1. reformat diagnosis (and optionally procedure) ICD-9 codes -> ALL_CODES.csv
    2. tokenize the selected NOTEEVENTS categories -> disch_full_acc.csv (one row
       per note) and disch_full.csv (one row per admission)
    3. keep the codes of subjects that have notes -> ALL_CODES_filtered.csv
    4. join notes and codes per admission -> notes_labeled.csv
    5. print corpus statistics
    6. extract NER entities, map them to Wikidata ids and collect their
       PyTorch-BigGraph vectors -> notes_labeled_entity.csv, entity2embedding.pickle
    7. 70/15/15 split sorted by note length -> {train,dev,test}_full.csv, then vocab.csv
    8. train word2vec / fastText -> processed_full.{w2v,embed,fasttext,fasttext.embed}
    9. keep admissions with a top-50 code -> TOP_50_CODES.csv, {train,dev,test}_50.csv

    Splits are seeded with ``args['random_seed']`` so reruns are reproducible.

    Methods rely on notebook globals defined in earlier cells: ``nlp_tool``,
    ``tokenizer``, ``tqdm``, ``pipeline``, ``WikiMapper``, ``smart``, ``gensim``,
    ``w2v`` and ``fasttext``.
    """

    # De-identification placeholders such as "[**Hospital 1234**]". Brackets and
    # asterisks are escaped: unescaped, "[[**]" would be a character class and the
    # match could start at any lone '[' or '*' in the sentence.
    _DEID_PATTERN = re.compile(r'\[\*\*.*?\*\*\]')

    def __init__(self, args):
        """Run the whole pipeline described in the class docstring.

        Args:
            args: configuration dict; reads origin_path, out_path, use_proc_label,
                min_sentence_len, notes_type, vocab_min, graph_embedding_file and
                random_seed.
        """
        out_path = args['out_path']
        seed = args['random_seed']

        # step 1: process code-related files
        self._write_all_codes(args['origin_path'], out_path, args['use_proc_label'])

        # step 2: process notes
        subj_ids = self._write_admission_notes(args)

        # step 3: filter out the codes of subjects that have no notes
        self.code_filter(out_path, subj_ids)
        # Codes stay strings so values such as '250.00' are not re-parsed as floats.
        dfcodes_filtered = pd.read_csv(out_path + 'ALL_CODES_filtered_acc.csv', index_col=None,
                                       dtype={'ICD9_CODE': str})
        dfcodes_filtered = dfcodes_filtered.sort_values(['SUBJECT_ID', 'HADM_ID'])
        # columns: 'SUBJECT_ID', 'HADM_ID', 'ICD9_CODE', 'ADMITTIME', 'DISCHTIME'
        dfcodes_filtered.to_csv(out_path + 'ALL_CODES_filtered.csv', index=False)
        del dfcodes_filtered

        # step 4: link notes with their codes -> columns: 'SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS'
        labeled = self.concat_data_new(out_path + 'ALL_CODES_filtered.csv',
                                       out_path + 'disch_full.csv',
                                       out_path + 'notes_labeled.csv')
        labeled_notes = pd.read_csv(labeled, dtype={'LABELS': str}).drop_duplicates()
        labeled_notes.to_csv(labeled, index=False)

        # step 5: statistic unique word, total word, HADM_ID number
        self._print_statistics(labeled_notes)
        del labeled_notes

        # Entity extraction and Wikidata id lookup -> adds column 'ENTITY_ID'.
        fname_entity = self.extract_entity(out_path + 'notes_labeled.csv',
                                           out_path + 'notes_labeled_entity.csv')
        # PyTorch-BigGraph vectors for the entity ids found above.
        self.extract_biggraph_embedding(fname_entity, args['graph_embedding_file'],
                                        out_path + 'entity2embedding.pickle')

        # step 6 + 7: split data into train/dev/test, sorted by note length (adds column 'length')
        tr, dv, te = self.split_length_sort_data(fname_entity, out_path, 'full', random_state=seed)

        self.build_vocab(args['vocab_min'], tr, out_path + 'vocab.csv')

        # step 8: train word embeddings via word2vec and fasttext.
        # NOTE(review): embeddings are trained on every admission in disch_full.csv,
        # including those that end up in dev/test.
        notes_file = out_path + 'disch_full.csv'
        vocab_file = out_path + 'vocab.csv'
        w2v_file = self.word_embeddings('full', notes_file, 100, 0, 5)
        self.gensim_to_embeddings(w2v_file, vocab_file, 'full')
        self.fasttext_file = self.fasttext_embeddings('full', notes_file, 100, 0, 5)
        self.gensim_to_fasttext_embeddings(self.fasttext_file, vocab_file, 'full')

        # step 9: keep admissions carrying one of the 50 most frequent codes and split them
        self._build_top_codes_dataset(fname_entity, out_path, 50, random_state=seed)

    def _write_all_codes(self, origin_path, out_path, use_proc_label):
        """Step 1: add decimal points to ICD-9 codes and write ALL_CODES.csv.

        Output columns: ROW_ID, SUBJECT_ID, HADM_ID, SEQ_NUM, ICD9_CODE (reformatted).
        """
        # Read codes as strings: procedure codes are all digits, so default parsing
        # would turn them into integers and drop leading zeros ('0040' -> 40).
        code_dtype = {'ICD9_CODE': str}
        dfdiag = pd.read_csv(origin_path + 'DIAGNOSES_ICD.csv', dtype=code_dtype)
        dfdiag['absolute_code'] = dfdiag['ICD9_CODE'].map(lambda code: self.reformat(str(code), True))
        frames = [dfdiag]
        if use_proc_label:
            dfproc = pd.read_csv(origin_path + 'PROCEDURES_ICD.csv', dtype=code_dtype)
            dfproc['absolute_code'] = dfproc['ICD9_CODE'].map(lambda code: self.reformat(str(code), False))
            frames.append(dfproc)
        dfcodes = pd.concat(frames)

        dfcodes.to_csv(out_path + 'ALL_CODES.csv', index=False,
                       columns=['ROW_ID', 'SUBJECT_ID', 'HADM_ID', 'SEQ_NUM', 'absolute_code'],
                       header=['ROW_ID', 'SUBJECT_ID', 'HADM_ID', 'SEQ_NUM', 'ICD9_CODE'])
        print("unique ICD9 code: {}".format(dfcodes['absolute_code'].nunique()))

    def _write_admission_notes(self, args):
        """Step 2: tokenize notes and merge all notes of an admission into one text.

        Writes disch_full_acc.csv (one row per note) and disch_full.csv (columns
        SUBJECT_ID, HADM_ID, TEXT; one row per admission).

        Returns:
            set of SUBJECT_IDs that have at least one note.
        """
        out_path = args['out_path']
        acc_file = self.write_discharge_summaries(out_path + 'disch_full_acc.csv',
                                                  args['min_sentence_len'],
                                                  args['origin_path'] + 'NOTEEVENTS.csv',
                                                  note_types=args['notes_type'].split(','))
        dfnotes = pd.read_csv(acc_file).sort_values(['SUBJECT_ID', 'HADM_ID']).drop_duplicates()
        dfnotes = (dfnotes.groupby(['SUBJECT_ID', 'HADM_ID'])['TEXT']
                   .agg(lambda texts: ' '.join(str(v) for v in texts))
                   .reset_index())
        # groupby drops notes without HADM_ID; the remaining ids are whole numbers,
        # so store them as integers rather than floats like '100001.0'.
        dfnotes['HADM_ID'] = dfnotes['HADM_ID'].astype('int64')
        dfnotes.to_csv(out_path + 'disch_full.csv', index=False)  # columns: 'SUBJECT_ID', 'HADM_ID', 'TEXT'
        return set(dfnotes['SUBJECT_ID'])

    def _print_statistics(self, labeled_notes):
        """Step 5: print unique/total token counts and the number of admissions and subjects."""
        types = set()
        num_tok = 0
        for text in labeled_notes['TEXT']:
            words = str(text).split()
            types.update(words)
            num_tok += len(words)

        print("num types", len(types), "num tokens", num_tok)
        print("HADM_ID: {}".format(labeled_notes['HADM_ID'].nunique()))
        print("SUBJECT_ID: {}".format(labeled_notes['SUBJECT_ID'].nunique()))

    def _build_top_codes_dataset(self, fname_entity, out_path, top_n=50, random_state=None):
        """Step 9: keep admissions that carry at least one of the ``top_n`` most frequent codes.

        Writes TOP_<top_n>_CODES.csv and notes_labeled_<top_n>.csv (rows keep all
        their original LABELS), then splits the latter like the full data.

        Returns:
            (train, dev, test) file names.
        """
        counts = Counter()
        dfnl = pd.read_csv(fname_entity, dtype={'LABELS': str})
        for labels in dfnl['LABELS']:
            counts.update(str(labels).split(';'))
        del dfnl
        # most_common keeps first-seen order among equal counts, like a stable sort.
        top_codes = [code for code, _ in counts.most_common(top_n)]

        with open('%sTOP_%s_CODES.csv' % (out_path, str(top_n)), 'w', newline='') as of:
            w = csv.writer(of)
            for code in top_codes:
                w.writerow([code])

        csv.field_size_limit(sys.maxsize)
        fname_top = '%snotes_labeled_%s.csv' % (out_path, top_n)
        top_set = set(top_codes)
        # columns: 'SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS', 'ENTITY_ID'
        with open(fname_entity, 'r', newline='') as f, open(fname_top, 'w', newline='') as fl:
            r = csv.reader(f)
            w = csv.writer(fl)
            w.writerow(next(r))  # header
            for row in r:
                if top_set.intersection(row[3].split(';')):
                    w.writerow(row)

        return self.split_length_sort_data(fname_top, out_path, str(top_n), random_state=random_state)

    def reformat(self, code, is_diag):
        """
            Put a period in the right place because the MIMIC-3 data files exclude them.
            Generally, procedure codes have dots after the first two digits,
            while diagnosis codes have dots after the first three digits
            (after the first four characters for E codes, e.g. 'E8497' -> 'E849.7').
        """
        code = ''.join(code.split('.'))
        if is_diag:
            if code.startswith('E'):
                if len(code) > 4:
                    code = code[:4] + '.' + code[4:]
            else:
                if len(code) > 3:
                    code = code[:3] + '.' + code[3:]
        else:
            code = code[:2] + '.' + code[2:]
        return code

    def write_discharge_summaries(self, out_file, min_sentence_len, notes_file, note_types=None):
        """Tokenize the selected NOTEEVENTS categories and write one row per note.

        Each sentence found by ``nlp_tool`` has de-identification placeholders
        removed, is tokenized with ``tokenizer``, lower-cased, stripped of purely
        numeric tokens and wrapped as ``[CLS] ... [SEP]``.

        Args:
            out_file: destination CSV (columns SUBJECT_ID, HADM_ID, CHARTTIME, TEXT).
            min_sentence_len: currently unused; kept for interface compatibility.
            notes_file: path to NOTEEVENTS.csv.
            note_types: CATEGORY values to keep, matched exactly (whitespace
                included). Defaults to the comma-separated notebook setting
                ``args['notes_type']``.

        Returns:
            ``out_file``.
        """
        if note_types is None:
            note_types = args['notes_type'].split(',')
        note_types = set(note_types)

        print("processing notes file")
        with open(notes_file, 'r', newline='') as csvfile, open(out_file, 'w', newline='') as outfile:
            print("writing to %s" % (out_file))
            writer = csv.writer(outfile)
            writer.writerow(['SUBJECT_ID', 'HADM_ID', 'CHARTTIME', 'TEXT'])
            notereader = csv.reader(csvfile)
            next(notereader)  # header

            for line in tqdm(notereader):
                if line[6] not in note_types:  # CATEGORY column
                    continue
                note = line[10]  # TEXT column
                sentences = []
                for start, end in nlp_tool.span_tokenize(note):
                    sentence_txt = self._DEID_PATTERN.sub('', note[start:end])
                    tokens = [t.lower() for t in tokenizer.tokenize(sentence_txt) if not t.isnumeric()]
                    sentences.append('[CLS] ' + ' '.join(tokens) + ' [SEP]')
                writer.writerow([line[1], line[2], line[4], ' '.join(sentences)])

        return out_file

    def code_filter(self, out_path, subj_ids):
        """Write ALL_CODES_filtered_acc.csv with the codes of subjects in ``subj_ids``.

        Output columns: SUBJECT_ID, HADM_ID, ICD9_CODE, ADMITTIME, DISCHTIME
        (the last two are written empty).
        """
        with open(out_path + 'ALL_CODES.csv', 'r', newline='') as lf, \
                open(out_path + 'ALL_CODES_filtered_acc.csv', 'w', newline='') as of:
            w = csv.writer(of)
            w.writerow(['SUBJECT_ID', 'HADM_ID', 'ICD9_CODE', 'ADMITTIME', 'DISCHTIME'])
            r = csv.reader(lf)
            next(r)  # header
            for row in r:
                if int(float(row[1])) in subj_ids:
                    w.writerow(row[1:3] + [row[-1], '', ''])

    def concat_data_new(self, labelsfile, notes_file, outfilename):
        """Join per-admission notes with their ';'-joined codes and write ``outfilename``.

        Args:
            labelsfile: columns 'SUBJECT_ID', 'HADM_ID', 'ICD9_CODE', 'ADMITTIME', 'DISCHTIME'.
            notes_file: columns 'SUBJECT_ID', 'HADM_ID', 'TEXT' (already grouped per admission).
            outfilename: destination; columns 'SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS'.

        Returns:
            ``outfilename``.
        """
        print("labelsfile=", labelsfile)
        print("notes_file=", notes_file)

        # Codes stay strings so '250.00' is not turned into the float 250.0.
        mydf_label = pd.read_csv(labelsfile, dtype={'ICD9_CODE': str})
        mydf_label = (mydf_label.groupby(['SUBJECT_ID', 'HADM_ID'])['ICD9_CODE']
                      .agg(lambda codes: ';'.join(str(v) for v in codes))
                      .reset_index())

        mydf_notes = pd.read_csv(notes_file)

        merged_df = pd.merge(mydf_notes, mydf_label, how='inner',
                             on=['SUBJECT_ID', 'HADM_ID']).rename(columns={"ICD9_CODE": "LABELS"})
        merged_df.to_csv(outfilename, index=False)
        return outfilename

    #used in old data process.
    def concat_data(self, labelsfile, notes_file, outfilename):
        """
            INPUTS:
                labelsfile: sorted by hadm id, contains one label per line
                notes_file: sorted by hadm id, contains one note per line
                outfilename: destination; columns 'SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS'
            Stops (with a message) at the first note whose HADM_ID does not match
            the next label group.
        """
        csv.field_size_limit(sys.maxsize)
        print("labelsfile=", labelsfile)
        print("notes_file=", notes_file)
        with open(labelsfile, 'r', newline='') as lf:
            print("CONCATENATING")
            with open(notes_file, 'r', newline='') as notesfile, open(outfilename, 'w', newline='') as outfile:
                w = csv.writer(outfile)
                w.writerow(['SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS'])

                labels_gen = self.next_labels(lf)
                notes_gen = self.next_notes(notesfile)

                for i, (subj_id, text, hadm_id) in enumerate(notes_gen):
                    if i % 10000 == 0:
                        print(str(i) + " done")
                    _, cur_labels, cur_hadm = next(labels_gen)

                    if cur_hadm == hadm_id:
                        w.writerow([subj_id, str(hadm_id), text, ';'.join(cur_labels)])
                    else:
                        print("couldn't find matching hadm_id. data is probably not sorted correctly")
                        break

        return outfilename

    def next_labels(self, labelsfile):
        """
            Generator for label sets from the label file
            (columns 'SUBJECT_ID', 'HADM_ID', 'ICD9_CODE', ...).
            Yields (subject_id, [codes], hadm_id) per consecutive admission; yields
            nothing for a file without data rows.
        """
        labels_reader = csv.reader(labelsfile)
        next(labels_reader, None)  # header

        first_label_line = next(labels_reader, None)
        if first_label_line is None:
            return

        cur_subj = int(first_label_line[0])
        cur_hadm = int(first_label_line[1])
        cur_labels = [first_label_line[2]]

        for row in labels_reader:
            subj_id = int(row[0])
            hadm_id = int(row[1])
            code = row[2]
            # keep reading until you hit a new hadm id
            if hadm_id != cur_hadm or subj_id != cur_subj:
                yield cur_subj, cur_labels, cur_hadm
                cur_labels = [code]
                cur_subj = subj_id
                cur_hadm = hadm_id
            else:
                # add to the labels and move on
                cur_labels.append(code)
        yield cur_subj, cur_labels, cur_hadm

    def next_notes(self, notesfile):
        """
            Generator for notes from the notes file (columns 'SUBJECT_ID', 'HADM_ID', 'TEXT').
            This will also concatenate discharge summaries and their addenda, which have
            the same subject and hadm id. Yields nothing for a file without data rows.
        """
        nr = csv.reader(notesfile)
        next(nr, None)  # header

        first_note = next(nr, None)
        if first_note is None:
            return

        cur_subj = int(first_note[0])
        cur_hadm = int(first_note[1])
        cur_text = first_note[2]

        for row in nr:
            subj_id = int(row[0])
            hadm_id = int(row[1])
            text = row[2]
            # keep reading until you hit a new hadm id
            if hadm_id != cur_hadm or subj_id != cur_subj:
                yield cur_subj, cur_text, cur_hadm
                cur_text = text
                cur_subj = subj_id
                cur_hadm = hadm_id
            else:
                # concatenate to the discharge summary and move on
                cur_text += " " + text
        yield cur_subj, cur_text, cur_hadm

    def extract_entity(self, data_file, out_file,
                       ner_model='samrawal/bert-base-uncased_clinical-ner',
                       wikimapper_db='/home/pgoswami/DifferentialEHR/data/Wikidata_dump/index_enwiki-latest.db',
                       max_entities=50):
        """Append a space-separated ENTITY_ID column of Wikidata ids to every note.

        Runs the Hugging Face NER pipeline on TEXT, keeps the first
        ``max_entities`` recognized tokens, merges '##' word pieces, title-cases
        each word and looks it up with WikiMapper; words without an id are dropped.

        Args:
            data_file: columns 'SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS'.
            out_file: same columns plus 'ENTITY_ID'.
            ner_model: Hugging Face model name for the 'ner' pipeline.
            wikimapper_db: path of the indexed Wikipedia dump database.
            max_entities: number of NER results kept per note.

        Returns:
            ``out_file``.
        """
        ner = pipeline('ner', model=ner_model)
        mapper = WikiMapper(wikimapper_db)

        csv.field_size_limit(sys.maxsize)

        with open(data_file, 'r', newline='') as lf, open(out_file, 'w', newline='') as of:
            w = csv.writer(of)
            w.writerow(['SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS', 'ENTITY_ID'])
            r = csv.reader(lf)
            next(r)  # header
            for i, row in enumerate(r):
                if i % 1000 == 0:
                    print(str(i) + " entity extraction done")

                text = str(row[2])
                pieces = ' '.join(obj['word'] for obj in ner(text)[0:max_entities])
                # NER output is word-piece level; re-join '##' continuations into whole words.
                words = pieces.replace(' ##', '').split()
                entity_ids = []
                for word in words:
                    qid = mapper.title_to_id(word.title())  # look each word up only once
                    if qid is not None:
                        entity_ids.append(qid)
                w.writerow(row + [' '.join(entity_ids)])
        return out_file

    def extract_biggraph_embedding(self, data_file, embedding_file_path, out_file):
        """Pickle a {QID: vector} dict for every entity id used in ``data_file``.

        Args:
            data_file: columns 'SUBJECT_ID', 'HADM_ID', 'TEXT', 'LABELS', 'ENTITY_ID'.
            embedding_file_path: tab-separated embedding file (may be .gz); the first
                column is '<http://www.wikidata.org/entity/Q...>', the rest are floats.
            out_file: destination pickle.
        """
        csv.field_size_limit(sys.maxsize)
        selected_entity_ids = set()
        with open(data_file, 'r', newline='') as lf:
            r = csv.reader(lf)
            next(r)  # header
            for row in r:
                selected_entity_ids.update(str(row[4]).split())

        print(f'Total {len(selected_entity_ids)} QIDs for Entities')

        entity2embedding = {}
        prefix = '<http://www.wikidata.org/entity/'

        with smart.open(embedding_file_path, encoding='utf-8') as fp:  # smart open can read .gz files
            for i, line in enumerate(fp):
                # Strip the line break so the last column parses cleanly as a float.
                cols = line.rstrip('\n').split('\t')
                entity_id = cols[0]

                if entity_id.startswith(prefix + 'Q') and entity_id.endswith('>'):
                    entity_id = entity_id.replace(prefix, '').replace('>', '')
                    if entity_id in selected_entity_ids:
                        # builtin float: the np.float alias was removed in NumPy 1.24
                        entity2embedding[entity_id] = np.array(cols[1:]).astype(float)

                if not i % 100000:
                    print(f'Lines completed {i}')

        with open(out_file, 'wb') as f:
            pickle.dump(entity2embedding, f)

        print(f'Embeddings Saved to {out_file}')

    def split_length_sort_data(self, labeledfile, base_name, datsetType, random_state=None):
        """Split ``labeledfile`` 70/15/15 into train/dev/test, each sorted by note length.

        Adds a 'length' column (number of whitespace tokens in TEXT).

        Args:
            labeledfile: input CSV with a TEXT column.
            base_name: output prefix; files are <base_name><split>_<datsetType>.csv.
            datsetType: label-set name, e.g. 'full' or '50'.
            random_state: seed passed to ``DataFrame.sample``; None gives a
                different split on every call.

        Returns:
            (train, dev, test) file names.
        """
        print("SPLITTING")
        labeledDf = pd.read_csv(labeledfile)
        labeledDf['length'] = labeledDf['TEXT'].map(lambda text: len(str(text).split()))

        train_df = labeledDf.sample(frac=0.7, random_state=random_state)  # 70% train data
        remain_df = labeledDf.drop(train_df.index)
        dev_df = remain_df.sample(frac=0.5, random_state=random_state)  # 15% val data
        test_df = remain_df.drop(dev_df.index)  # 15% test data

        filename_list = []
        for splt, split_df in (('train', train_df), ('dev', dev_df), ('test', test_df)):
            filename = '%s%s_%s.csv' % (base_name, splt, datsetType)
            print('saving to ..' + filename)
            filename_list.append(filename)
            split_df.sort_values(['length']).to_csv(filename, index=False)

        return filename_list[0], filename_list[1], filename_list[2]

    def build_vocab(self, vocab_min, infile, vocab_filename):
        """
            Write every term that occurs in at least ``vocab_min`` documents of ``infile``.

            INPUTS:
                vocab_min: how many documents a word must appear in to be kept
                infile: (training) data file; TEXT is read from the third column
                vocab_filename: name for the file to output (one term per line,
                    in order of first appearance in ``infile``)
        """
        csv.field_size_limit(sys.maxsize)
        # Counter keeps insertion order, i.e. the order in which terms are first seen.
        doc_freq = Counter()
        with open(infile, 'r', newline='') as csvfile:
            reader = csv.reader(csvfile)
            next(reader)  # header

            print("reading in data...")
            for row in reader:
                # dict.fromkeys de-duplicates the note's terms, keeping first-occurrence order
                for term in dict.fromkeys(row[2].split()):
                    doc_freq[term] += 1

        print("removing rare terms")
        kept = [term for term, n_docs in doc_freq.items() if n_docs >= vocab_min]
        print(str(len(kept)) + " terms qualify out of " + str(len(doc_freq)) + " total")

        print("writing output")
        with open(vocab_filename, 'w') as vocab_file:
            for word in kept:
                vocab_file.write(word + "\n")

    def word_embeddings(self, Y, notes_file, embedding_size, min_count, n_iter, outfile=None):
        """Train word2vec on the tokenized third column of ``notes_file``.

        The model is saved next to ``notes_file`` as processed_<Y>.w2v. If
        ``outfile`` is given, the model's vocabulary is also written there.

        Returns:
            path of the saved model.
        """
        modelname = "processed_%s.w2v" % (Y)
        sentences = ProcessedIter(Y, notes_file)
        print("Model name %s..." % (modelname))

        model = w2v.Word2Vec(vector_size=embedding_size, min_count=min_count, workers=4, epochs=n_iter)
        print("building word2vec vocab on %s..." % (notes_file))

        model.build_vocab(sentences)
        print("training...")
        model.train(sentences, total_examples=model.corpus_count, epochs=model.epochs)
        out_file = '/'.join(notes_file.split('/')[:-1] + [modelname])
        print("writing embeddings to %s" % (out_file))
        model.save(out_file)

        if outfile is not None:
            print("writing vocab to %s" % (outfile))
            with open(outfile, 'w') as vocab_file:
                for word in model.wv.key_to_index:
                    vocab_file.write(word + "\n")

        return out_file

    def _load_vocab_index(self, vocab_file):
        """Map 1-based indices to the sorted unique non-empty lines of ``vocab_file`` (0 is the pad)."""
        vocab = set()
        with open(vocab_file, 'r') as vocabfile:
            for line in vocabfile:
                line = line.strip()
                if line != '':
                    vocab.add(line)
        return {i + 1: w for i, w in enumerate(sorted(vocab))}

    def gensim_to_embeddings(self, wv_file, vocab_file, Y, outfile=None):
        """Export word2vec vectors for the vocabulary as a text embedding file.

        ``outfile`` defaults to ``wv_file`` with '.w2v' replaced by '.embed'.
        """
        model = gensim.models.Word2Vec.load(wv_file)
        wv = model.wv
        del model  # free up memory

        W, words = self.build_matrix(self._load_vocab_index(vocab_file), wv)

        if outfile is None:
            outfile = wv_file.replace('.w2v', '.embed')
        self.save_embeddings(W, words, outfile)

    def build_matrix(self, ind2w, wv):
        """
            Stack the vector of every vocab word into one matrix, in index order.
            Note: ind2w starts at 1 (row 0 is an all-zero pad row), while gensim
            word vectors are indexed from 0.
            Returns (W, words) where words[i] labels row i of W.
        """
        W = np.zeros((len(ind2w) + 1, wv.vector_size))
        print("W shape=", W.shape)
        words = ["**PAD**"]  # row 0 stays zero for padding
        for idx, word in tqdm(ind2w.items()):
            if idx >= W.shape[0]:
                break
            W[idx][:] = wv.get_vector(word)
            words.append(word)
        print("W shape final=", W.shape)
        print("Word list length=", len(words))
        return W, words

    def fasttext_embeddings(self, Y, notes_file, embedding_size, min_count, n_iter):
        """Train fastText on the tokenized third column of ``notes_file``.

        The model is saved next to ``notes_file`` as processed_<Y>.fasttext.

        Returns:
            path of the saved model.
        """
        modelname = "processed_%s.fasttext" % (Y)
        sentences = ProcessedIter(Y, notes_file)
        print("Model name %s..." % (modelname))

        model = fasttext.FastText(vector_size=embedding_size, min_count=min_count, epochs=n_iter)
        print("building fasttext vocab on %s..." % (notes_file))

        model.build_vocab(sentences)
        print("training...")
        model.train(sentences, total_examples=model.corpus_count, epochs=model.epochs)
        out_file = '/'.join(notes_file.split('/')[:-1] + [modelname])
        print("writing embeddings to %s" % (out_file))
        model.save(out_file)
        return out_file

    def gensim_to_fasttext_embeddings(self, wv_file, vocab_file, Y, outfile=None):
        """Export fastText vectors for the vocabulary as a text embedding file.

        ``outfile`` defaults to ``wv_file`` with '.fasttext' replaced by '.fasttext.embed'.
        """
        model = gensim.models.FastText.load(wv_file)
        wv = model.wv
        del model  # free up memory

        W, words = self.build_matrix(self._load_vocab_index(vocab_file), wv)

        if outfile is None:
            outfile = wv_file.replace('.fasttext', '.fasttext.embed')
        self.save_embeddings(W, words, outfile)

    def save_embeddings(self, W, words, outfile):
        """Write one line per row: the word followed by its space-separated vector (pad row included)."""
        with open(outfile, 'w') as o:
            for word, vector in zip(words, W):
                o.write(" ".join([word] + [str(d) for d in vector]) + "\n")

## Approach

## Everything in one cell for easy running. 

In [4]:
#############Imports###############

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_ as xavier_uniform
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from torch.utils.data import Dataset
from torchsummary import summary
from entmax import sparsemax, entmax15, entmax_bisect

from pytorch_pretrained_bert.modeling import BertLayerNorm
from pytorch_pretrained_bert import BertModel, BertConfig
from pytorch_pretrained_bert import WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert import BertAdam
from pytorch_pretrained_bert import BertTokenizer

import transformers as tr
from transformers import AdamW

from allennlp.data.token_indexers.elmo_indexer import ELMoTokenCharactersIndexer
from allennlp.data import Token, Vocabulary, Instance
from allennlp.data.fields import TextField
from allennlp.data.dataset import Batch

from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import roc_curve, auc

from keras.preprocessing.sequence import pad_sequences
from keras.preprocessing.text import Tokenizer

from typing import Tuple,Callable,IO,Optional
from collections import defaultdict
from urllib.parse import urlparse
from functools import wraps
from hashlib import sha256
from typing import List
from math import floor
from tqdm import tqdm
import pandas as pd
import numpy as np

import requests
import tempfile
import tarfile
import random
import shutil
import struct
import pickle
import time
import json
import csv
import sys
import re
import os



#Models


class ModelHub:
    """Factory that builds the network named by ``args['model']``.

    The model built during construction is kept on ``self.model``.
    ``pick_model`` can still be called directly; it builds and returns a new
    instance each time.
    """

    def __init__(self, args, dicts):
        # Keep the constructed model; it used to be built and then discarded.
        self.model = self.pick_model(args, dicts)

    def pick_model(self, args, dicts):
        """Instantiate, optionally restore, and place the requested model.

        Args:
            args: config dict; reads 'model', 'test_model' and 'gpu' here (the
                model constructors read further keys).
            dicts: vocabulary dicts; ``len(dicts['ind2c'])`` is used as the
                number of output labels ``Y``.

        Returns:
            The model, with weights loaded from ``args['test_model']`` when that
            is set, and moved to ``cuda(args['gpu'])`` when ``args['gpu'] >= 0``.

        Raises:
            RuntimeError: if ``args['model']`` is not a known model name.
        """
        Y = len(dicts['ind2c'])
        # An if/elif chain (rather than a dict) means each model class name is
        # only resolved when it is actually selected.
        if args['model'] == 'KG_MultiResCNN':
            model = KG_MultiResCNN(args, Y, dicts)
        elif args['model'] == 'KG_MultiResCNNLSTM':
            model = KG_MultiResCNNLSTM(args, Y, dicts)
        elif args['model'] == 'bert_se_kg':
            model = Bert_SE_KG(args, Y, dicts)
        elif args['model'] == 'bert_we_kg':
            model = Bert_WE_KG(args, Y, dicts)
        elif args['model'] == 'bert_l4_we_kg':
            model = Bert_L4_WE_KG(args, Y, dicts)
        elif args['model'] == 'bert_mcnn_kg':
            model = Bert_MCNN_KG(args, Y, dicts)
        else:
            raise RuntimeError("wrong model name")

        if args['test_model']:
            # Load onto the CPU first so checkpoints saved from a GPU also load
            # on CPU-only hosts. load_state_dict copies the values into the
            # model's own parameters, and the .cuda() call below moves them
            # when a GPU is requested.
            sd = torch.load(args['test_model'], map_location='cpu')
            model.load_state_dict(sd)
        if args['gpu'] >= 0:
            model.cuda(args['gpu'])
        return model


class WordRep(nn.Module):
    """Word-embedding front end shared by the CNN models.

    Builds an ``nn.Embedding`` with padding index 0, either from
    ``args['embed_file']`` or randomly (vocabulary size + 2 rows for PAD and
    UNK). Token embeddings are optionally scaled by per-token TF-IDF weights,
    and dropout is applied. ``self.conv_dict`` maps a residual-layer count
    (1-4) to the channel widths used by the CNN stacks built on top of it.
    """

    def __init__(self, args, Y, dicts):
        super(WordRep, self).__init__()

        self.gpu = args['gpu']

        # When set, forward() multiplies each token embedding by its TF-IDF weight.
        self.isTfIdf = bool(args['use_tfIdf'])

        if args['embed_file']:
            print("loading pretrained embeddings from {}".format(args['embed_file']))
            if args['use_ext_emb']:
                pretrain_word_embedding, _ = self.build_pretrain_embedding(args['embed_file'], dicts['w2ind'], True)
                W = torch.from_numpy(pretrain_word_embedding)
            else:
                W = torch.Tensor(self.load_embeddings(args['embed_file']))

            self.embed = nn.Embedding(W.size()[0], W.size()[1], padding_idx=0)
            self.embed.weight.data = W.clone()
        else:
            # add 2 to include UNK and PAD
            self.embed = nn.Embedding(len(dicts['w2ind']) + 2, args['embed_size'], padding_idx=0)
        self.feature_size = self.embed.embedding_dim

        self.embed_drop = nn.Dropout(p=args['dropout'])

        # Channel widths per residual layer, keyed by args['conv_layer'].
        self.conv_dict = {
                    1: [self.feature_size, args['num_filter_maps']],
                    2: [self.feature_size, 100, args['num_filter_maps']],
                    3: [self.feature_size, 150, 100, args['num_filter_maps']],
                    4: [self.feature_size, 200, 150, 100, args['num_filter_maps']]
                     }

    def forward(self, x, tfIdf_inputs):
        """Embed token ids and apply optional TF-IDF weighting plus dropout.

        Args:
            x: LongTensor of token ids; presumably (batch, seq) - TODO confirm.
            tfIdf_inputs: per-token weights with the same leading shape as
                ``x`` (broadcast over the embedding dim via ``unsqueeze(2)``),
                or None. Ignored unless the module was built with use_tfIdf.

        Returns:
            The (optionally weighted) embeddings after dropout.
        """
        use_tfidf = self.isTfIdf and tfIdf_inputs is not None

        if self.gpu >= 0:
            x = x if x.is_cuda else x.cuda(self.gpu)
            if use_tfidf:
                tfIdf_inputs = tfIdf_inputs if tfIdf_inputs.is_cuda else tfIdf_inputs.cuda(self.gpu)

        try:
            out = self.embed(x)
        except Exception:
            # Show the offending ids (e.g. indices outside the vocabulary)
            # before propagating the original error.
            print(x)
            raise

        if use_tfidf:
            out = out * tfIdf_inputs.unsqueeze(dim=2)

        return self.embed_drop(out)

    def load_embeddings(self, embed_file):
        """Read a text embedding file and L2-normalise every row.

        Each line is ``token v1 v2 ...``. The token is dropped and rows are kept
        in file order. A Gaussian-random, normalised UNK row is appended last.

        Returns:
            np.ndarray of shape (number of lines + 1, dim).
        """
        W = []
        with open(embed_file) as ef:
            for line in ef:
                line = line.rstrip().split()
                vec = np.array(line[1:]).astype(float)  # np.float was removed in NumPy 1.24
                vec = vec / float(np.linalg.norm(vec) + 1e-6)
                W.append(vec)
            #UNK embedding, gaussian randomly initialized
            print("adding unk embedding")
            vec = np.random.randn(len(W[-1]))
            vec = vec / float(np.linalg.norm(vec) + 1e-6)
            W.append(vec)
        W = np.array(W)
        return W

    def build_pretrain_embedding(self, embedding_path, word_alphabet, norm):
        """Build an embedding matrix for ``word_alphabet`` from a pretrained file.

        Each word is looked up in this order: as-is, lower-cased, with digits
        replaced by '0', then lower-cased with digits replaced. Words with no
        match get a uniform random vector in [-scale, scale], where
        scale = sqrt(3 / dim). Row 0 (PAD) is zero and the last row (UNK) is
        random.

        Args:
            embedding_path: word2vec binary ('.bin' in the path) or text file.
            word_alphabet: mapping word -> row index.
            norm: if truthy, every filled row is L2-normalised.

        Returns:
            (float32 matrix of shape (len(word_alphabet) + 2, dim), dim)
        """
        embedd_dict, embedd_dim = self.load_pretrain_emb(embedding_path)

        scale = np.sqrt(3.0 / embedd_dim)
        pretrain_emb = np.zeros([len(word_alphabet) + 2, embedd_dim], dtype=np.float32)  # add UNK (last) and PAD (0)

        def fill(index, vec):
            # Optionally normalise, then write the vector into row `index`.
            pretrain_emb[index, :] = self.norm2one(vec) if norm else vec

        counts = dict.fromkeys(('perfect', 'case', 'digits', 'case_digits', 'none'), 0)
        for word, index in word_alphabet.items():
            lowered = word.lower()
            candidates = (
                ('perfect', word),
                ('case', lowered),
                ('digits', re.sub(r'\d', '0', word)),
                ('case_digits', re.sub(r'\d', '0', lowered)),
            )
            for kind, key in candidates:
                if key in embedd_dict:
                    fill(index, embedd_dict[key])
                    break
            else:
                kind = 'none'
                fill(index, np.random.uniform(-scale, scale, [1, embedd_dim]))
            counts[kind] += 1

        # initialize pad and unknown
        pretrain_emb[0, :] = 0.0
        fill(-1, np.random.uniform(-scale, scale, [1, embedd_dim]))

        total = max(len(word_alphabet), 1)  # avoid ZeroDivisionError on an empty alphabet
        print("pretrained word emb size {}".format(len(embedd_dict)))
        print("prefect match:%.2f%%, case_match:%.2f%%, dig_zero_match:%.2f%%, "
                     "case_dig_zero_match:%.2f%%, not_match:%.2f%%"
                     %(counts['perfect']*100.0/total, counts['case']*100.0/total, counts['digits']*100.0/total,
                       counts['case_digits']*100.0/total, counts['none']*100.0/total))

        return pretrain_emb, embedd_dim

    def load_pretrain_emb(self, embedding_path):
        """Load pretrained vectors into a ``{word: np.ndarray}`` dict.

        Paths containing '.bin' are read in binary word2vec layout: a header
        of word count and dimension, then for each word the word, dim native
        float32 values and one separator byte. Other paths are read as UTF-8
        text with one ``word v1 v2 ...`` per line. Two-token lines (a header)
        and rows whose width differs from the first row are skipped.

        Returns:
            (embedd_dict, embedd_dim); embedd_dim is -1 if nothing was read.
        """
        embedd_dim = -1
        embedd_dict = dict()

        if '.bin' in embedding_path:
            with open(embedding_path, 'rb') as f:
                wordTotal = int(self._readString(f, 'utf-8'))
                embedd_dim = int(self._readString(f, 'utf-8'))

                for _word_idx in range(wordTotal):
                    word = self._readString(f, 'utf-8')
                    word_vector = np.array([self._readFloat(f) for _dim_idx in range(embedd_dim)], np.float64)
                    f.read(1)  # a line break
                    embedd_dict[word] = word_vector
        else:
            with open(embedding_path, 'r', encoding='utf-8') as file:
                for line in file:
                    line = line.strip()
                    if not line:
                        continue
                    tokens = re.split(r"\s+", line)
                    if len(tokens) == 2:
                        continue  # "<count> <dim>" header line
                    if embedd_dim < 0:
                        embedd_dim = len(tokens) - 1
                    elif embedd_dim + 1 != len(tokens):
                        continue  # malformed row
                    embedd_dict[tokens[0]] = np.asarray(tokens[1:], dtype=np.float64).reshape(1, embedd_dim)

        return embedd_dict, embedd_dim

    def _readString(self, f, code):
        """Read one token from a binary stream, up to a space/newline or EOF.

        Characters are read one UTF-8 sequence at a time, using the lead byte
        to decide how many continuation bytes follow. The terminating space or
        newline is consumed but not returned.

        Raises:
            RuntimeError: on a 0x00 byte or a lead byte above 0xF4.
        """
        chars = []
        c = f.read(1)
        while c and c not in (b'\n', b' '):
            value = ord(c)
            # Inclusive lead-byte ranges; the old exclusive bounds rejected
            # valid leads such as 0xE0 and 0xF0 (e.g. every emoji).
            if 0x00 < value < 0xC0:
                extra = 0
            elif 0xC0 <= value <= 0xDF:
                extra = 1
            elif 0xE0 <= value <= 0xEF:
                extra = 2
            elif 0xF0 <= value <= 0xF4:
                extra = 3
            else:
                raise RuntimeError("not valid utf-8 code")
            chars.append((c + f.read(extra)).decode(code))
            c = f.read(1)
        return ''.join(chars)

    def _readFloat(self,f):
        """Read one native-endian 32-bit float from the binary stream ``f``."""
        bytes4 = f.read(4)
        f_num = struct.unpack('f', bytes4)[0]
        return f_num

    def norm2one(self,vec):
        """Return ``vec`` scaled to unit L2 norm (an all-zero input yields NaNs)."""
        root_sum_square = np.sqrt(np.sum(np.square(vec)))
        return vec/root_sum_square

class SentimentOutput():
    """Wraps a pretrained HuggingFace sequence-classification model and
    returns its last hidden layer as token features.

    NOTE(review): this is a plain object, not an ``nn.Module``. The wrapped
    model's parameters are therefore not part of the parent network's
    ``parameters()``/``state_dict()`` and are not switched by the parent's
    ``.train()``/``.eval()``.
    """

    def __init__(self, args):

        self.gpu = args['gpu']

        cache_path = os.path.join(args['bert_dir'], args['sentiment_bert'])

        if os.path.exists(cache_path):
            print("model path exist")
            savedModel = tr.AutoModelForSequenceClassification.from_pretrained(cache_path)
        else:
            print("Downloading and saving model")
            savedModel = tr.AutoModelForSequenceClassification.from_pretrained(str(args['sentiment_bert']))
            savedModel.save_pretrained(save_directory = cache_path, save_config=True)
        self.bert = savedModel
        self.config = savedModel.config

    def forward(self, x):
        """Run the wrapped model and return its last hidden state.

        Args:
            x: indexable pair ``(input_ids, attention_mask)`` (list or tuple).
                It is not modified.

        Returns:
            ``hidden_states[-1]`` of the wrapped model's output.
        """
        input_ids, attention_mask = x[0], x[1]
        # Bind the model unconditionally; previously it was only bound inside
        # the GPU branch, so CPU runs (gpu < 0) raised NameError.
        model = self.bert
        if self.gpu >= 0:
            input_ids = input_ids if input_ids.is_cuda else input_ids.cuda(self.gpu)
            attention_mask = attention_mask if attention_mask.is_cuda else attention_mask.cuda(self.gpu)
            model = model.cuda(self.gpu)

        senti_output = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        return senti_output.hidden_states[-1]  # last hidden state

class OutputLayer(nn.Module):
    """Label-wise attention head with per-label linear scoring.

    Each row of ``U.weight`` acts as one label's attention query over the
    sequence, normalised with softmax (or entmax15 when
    ``args['use_entmax15']``). The attended vector for each label is scored
    against the matching row of ``final.weight`` plus ``final.bias``.
    ``forward`` returns the logits and their BCE-with-logits loss.
    """

    def __init__(self, args, Y, dicts, input_size):
        super(OutputLayer, self).__init__()

        self.gpu = args['gpu']
        self.use_entmax15 = bool(args['use_entmax15'])

        # Attention queries, one per label (only .weight is used).
        self.U = nn.Linear(input_size, Y)
        xavier_uniform(self.U.weight)

        # Per-label scoring vectors and bias.
        self.final = nn.Linear(input_size, Y)
        xavier_uniform(self.final.weight)

        self.loss_function = nn.BCEWithLogitsLoss()

    def forward(self, x, target):
        """Return ``(logits, loss)`` for features ``x`` and multi-hot ``target``."""
        if self.gpu >= 0:
            if not target.is_cuda:
                target = target.cuda(self.gpu)
            if not x.is_cuda:
                x = x.cuda(self.gpu)

        scores = self.U.weight.matmul(x.transpose(1, 2))
        normalise = entmax15 if self.use_entmax15 else F.softmax
        attention = normalise(scores, dim=2)

        attended = attention.matmul(x)
        logits = (self.final.weight * attended).sum(dim=2) + self.final.bias

        return logits, self.loss_function(logits, target)

class MRCNNLayer(nn.Module):
    """Multi-resolution residual CNN over a sequence of feature vectors.

    One branch is built per kernel size in ``args['filter_size']``
    (comma-separated). Each branch is a tanh-activated base Conv1d that keeps
    ``feature_size`` channels, followed by ``args['conv_layer']``
    ResidualBlocks with widths from ``self.conv_dict``. ``forward``
    transposes the input to channels-first, runs every branch, transposes
    back and concatenates the branch outputs on the last axis.
    """

    def __init__(self, args, feature_size):
        super(MRCNNLayer, self).__init__()

        self.gpu = args['gpu']
        self.feature_size = feature_size

        n_maps = args['num_filter_maps']
        self.conv_dict = {
            1: [feature_size, n_maps],
            2: [feature_size, 100, n_maps],
            3: [feature_size, 150, 100, n_maps],
            4: [feature_size, 200, 150, 100, n_maps],
        }

        kernel_sizes = args['filter_size'].split(',')
        self.filter_num = len(kernel_sizes)

        self.conv = nn.ModuleList()
        for raw_size in kernel_sizes:
            ks = int(raw_size)
            branch = nn.ModuleList()

            base = nn.Conv1d(feature_size, feature_size, kernel_size=ks, padding=ks // 2)
            xavier_uniform(base.weight)
            branch.add_module('baseconv', base)

            depth = args['conv_layer']
            widths = self.conv_dict[depth]
            for layer_idx in range(depth):
                block = ResidualBlock(widths[layer_idx], widths[layer_idx + 1], ks, 1, True, args['dropout'])
                branch.add_module('resconv-{}'.format(layer_idx), block)

            self.conv.add_module('channel-{}'.format(ks), branch)

    def forward(self, x):
        """Map ``x`` (batch, seq, feature_size) to (batch, seq, filter_num * maps)."""
        if self.gpu >= 0 and not x.is_cuda:
            x = x.cuda(self.gpu)

        channels_first = x.transpose(1, 2)

        branch_outputs = []
        for branch in self.conv:
            base, *residuals = branch
            hidden = torch.tanh(base(channels_first))
            for block in residuals:
                hidden = block(hidden)
            branch_outputs.append(hidden.transpose(1, 2))

        return torch.cat(branch_outputs, dim=2)

class ResidualBlock(nn.Module):
    """Residual 1-D conv block: two same-padded Conv1d + BatchNorm layers
    with tanh between them, an optional 1x1-conv + BatchNorm shortcut, and a
    final tanh followed by dropout."""

    def __init__(self, inchannel, outchannel, kernel_size, stride, use_res, dropout):
        super(ResidualBlock, self).__init__()
        pad = kernel_size // 2
        self.left = nn.Sequential(
            nn.Conv1d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=pad, bias=False),
            nn.BatchNorm1d(outchannel),
            nn.Tanh(),
            nn.Conv1d(outchannel, outchannel, kernel_size=kernel_size, stride=1, padding=pad, bias=False),
            nn.BatchNorm1d(outchannel),
        )

        self.use_res = use_res
        if use_res:
            # 1x1 projection so the skip path matches the main path's channels.
            self.shortcut = nn.Sequential(
                nn.Conv1d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(outchannel),
            )

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply the block to a channels-first tensor ``x``."""
        main = self.left(x)
        if self.use_res:
            main = main + self.shortcut(x)
        return self.dropout(torch.tanh(main))


class KG_MultiResCNN(nn.Module):
    """Word embeddings -> optional sentiment gating -> optional appended KG
    entity tokens -> dropout -> MRCNNLayer -> label-attention OutputLayer."""

    def __init__(self, args, Y, dicts):
        super(KG_MultiResCNN, self).__init__()

        self.word_rep = WordRep(args, Y, dicts)
        self.feature_size = self.word_rep.feature_size

        self.is_sentiment = bool(args['use_sentiment'])
        if self.is_sentiment:
            self.sentiment_model = SentimentOutput(args)
            # Projects sentiment hidden states to the word-embedding width.
            self.S_U = nn.Linear(self.sentiment_model.config.hidden_size, self.feature_size)

        if args['use_embd_layer'] and args['add_with_wordrap']:
            self.kg_embd = EntityEmbedding(args, Y)
            self.kg_embd.dim_red = nn.Linear(self.kg_embd.feature_size, self.feature_size)
            # NOTE(review): feature_red is registered (and saved) but unused in forward().
            self.kg_embd.feature_red = nn.Linear(args['MAX_ENT_LENGTH'], args['MAX_LENGTH'])
            self.add_emb_with_wordrap = True

        self.dropout = nn.Dropout(args['dropout'])

        self.conv = MRCNNLayer(args, self.feature_size)

        # The output head sees all CNN branches concatenated.
        self.feature_size = self.conv.filter_num * args['num_filter_maps']

        self.output_layer = OutputLayer(args, Y, dicts, self.feature_size)

    def forward(self, x, target, text_inputs, embeddings, tfIdf_inputs):
        """Return ``(logits, loss)``; inputs are token ids, labels, sentiment
        inputs, entity ids and TF-IDF weights, in that order."""
        tokens = self.word_rep(x, tfIdf_inputs)

        if self.is_sentiment:
            gate = self.S_U(self.sentiment_model.forward(text_inputs))
            tokens = torch.mul(tokens, gate)

        if getattr(self, 'add_emb_with_wordrap', False):
            # Entity vectors are projected to the word width and appended
            # along dim 1 (the sequence axis).
            entity_tokens = self.kg_embd.dim_red(self.kg_embd(embeddings))
            tokens = torch.cat((tokens, entity_tokens), dim=1)

        tokens = self.conv(self.dropout(tokens))
        return self.output_layer(tokens, target)

    def freeze_net(self):
        """Stop gradient updates to the word-embedding table."""
        for p in self.word_rep.embed.parameters():
            p.requires_grad = False



class KG_MultiResCNNLSTM(nn.Module):
    """MultiResCNN variant with an LSTM after each branch's residual stack.

    Each branch is: tanh(base Conv1d) -> ``args['conv_layer']`` ResidualBlocks
    -> LSTM over the sequence. Branch outputs are concatenated and fed to the
    label-attention OutputLayer.
    """

    def __init__(self, args, Y, dicts):
        super(KG_MultiResCNNLSTM, self).__init__()

        self.word_rep = WordRep(args, Y, dicts)
        # NOTE(review): size()[0] is the number of embedding rows (vocabulary
        # size), not the embedding width; kept unchanged for existing readers.
        self.embedding_size = self.word_rep.embed.weight.data.size()[0]

        self.conv = nn.ModuleList()
        filter_sizes = args['filter_size'].split(',')

        self.filter_num = len(filter_sizes)
        for filter_size in filter_sizes:
            filter_size = int(filter_size)
            one_channel = nn.ModuleList()

            tmp = nn.Conv1d(self.word_rep.feature_size, self.word_rep.feature_size, kernel_size=filter_size,
                            padding=int(floor(filter_size / 2)))
            xavier_uniform(tmp.weight)
            one_channel.add_module('baseconv', tmp)

            conv_dimension = self.word_rep.conv_dict[args['conv_layer']]
            for idx in range(args['conv_layer']):
                tmp = ResidualBlock(conv_dimension[idx], conv_dimension[idx + 1], filter_size, 1, True,
                                    args['dropout'])
                one_channel.add_module('resconv-{}'.format(idx), tmp)

            # forward() feeds (batch, seq, channels). Without batch_first=True
            # the LSTM would recur across the batch instead of the sequence.
            # Parameter shapes are unaffected, so old state_dicts still load.
            lstm = torch.nn.LSTM(
                    input_size=args['num_filter_maps'],
                    hidden_size=args['num_filter_maps'],
                    num_layers=1,
                    batch_first=True
                )

            one_channel.add_module('LSTM', lstm)

            self.conv.add_module('channel-{}'.format(filter_size), one_channel)

        self.output_layer = OutputLayer(args, Y, dicts, self.filter_num * args['num_filter_maps'])

    def forward(self, x, target, text_inputs, embeddings, tfIdf_inputs):
        """Return ``(logits, loss)``. ``text_inputs`` and ``embeddings`` are
        accepted for interface parity with KG_MultiResCNN but unused."""
        x = self.word_rep(x, tfIdf_inputs)

        x = x.transpose(1, 2)  # channels-first for Conv1d

        conv_result = []
        for conv in self.conv:
            tmp = x
            for idx, md in enumerate(conv):
                if idx == 0:
                    tmp = torch.tanh(md(tmp))
                elif isinstance(md, nn.LSTM):
                    # Dispatch on module type: a fixed index (the old idx == 2)
                    # only found the LSTM when conv_layer == 1.
                    tmp, _ = md(tmp.transpose(1, 2))
                    tmp = tmp.transpose(1, 2)
                else:
                    tmp = md(tmp)
            conv_result.append(tmp.transpose(1, 2))
        x = torch.cat(conv_result, dim=2)

        y, loss = self.output_layer(x, target)

        return y, loss

    def freeze_net(self):
        """Stop gradient updates to the word-embedding table."""
        for p in self.word_rep.embed.parameters():
            p.requires_grad = False


class KGEntityToVec:
    """Loader for the pickled entity -> embedding mapping."""

    @staticmethod
    def getEntityToVec(out_path=None):
        """Load ``entity2embedding.pickle`` from ``out_path``.

        Args:
            out_path: directory prefix. It is string-concatenated with the file
                name, so it should end with a path separator. Defaults to the
                notebook's global ``args['out_path']`` (the previous behaviour).

        Returns:
            The unpickled object, which callers use as an entity -> vector mapping.
        """
        if out_path is None:
            out_path = args['out_path']
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # produced by this pipeline.
        with open('%sentity2embedding.pickle' % out_path, 'rb') as f:
            return pickle.load(f)


class EntityEmbedding(nn.Module):
    """Embedding table for knowledge-graph entities, initialised from the
    pickled mapping returned by ``KGEntityToVec().getEntityToVec()``.

    Entity number i (in the mapping's iteration order) occupies row i + 1.
    Row 0 and the last row are left as zeros; presumably they are the padding
    and unknown-entity slots (TODO confirm with the data pipeline).
    """

    def __init__(self, args, Y):
        super(EntityEmbedding, self).__init__()

        self.gpu = args['gpu']

        entity2vec = KGEntityToVec().getEntityToVec()

        embedding_matrix = self.create_embedding_matrix(entity2vec)
        vocab_size, vector_size = embedding_matrix.shape

        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=vector_size)
        # Replace the random init with the normalised pretrained vectors. The
        # weight stays trainable here; callers may freeze it afterwards.
        self.embed.weight = nn.Parameter(torch.tensor(embedding_matrix, dtype=torch.float32))
        self.feature_size = self.embed.embedding_dim

        self.conv_dict = {
                    1: [self.feature_size, args['num_filter_maps']],
                    2: [self.feature_size, 100, args['num_filter_maps']],
                    3: [self.feature_size, 150, 100, args['num_filter_maps']],
                    4: [self.feature_size, 200, 150, 100, args['num_filter_maps']]
                     }

        self.embed_drop = nn.Dropout(p=args['dropout'])

    def forward(self, x):
        """Look up entity ids ``x`` and apply dropout."""
        if self.gpu >= 0 and not x.is_cuda:
            x = x.cuda(self.gpu)

        return self.embed_drop(self.embed(x))

    def create_embedding_matrix(self, ent2vec, embedding_dim=None):
        """Stack L2-normalised entity vectors into a zero-padded matrix.

        Args:
            ent2vec: mapping entity -> vector (array-like).
            embedding_dim: vector width. By default it is inferred from the
                first vector, falling back to 200 when the mapping is empty.
                The width used to be hard-coded to 200 and crashed on any other
                size.

        Returns:
            np.ndarray of shape (len(ent2vec) + 2, embedding_dim). Rows 0 and
            -1 are zero; entity i is in row i + 1.
        """
        if embedding_dim is None:
            first = next(iter(ent2vec.values()), None)
            embedding_dim = 200 if first is None else np.asarray(first).size

        embedding_matrix = np.zeros((len(ent2vec) + 2, embedding_dim))

        for row, key in enumerate(ent2vec, start=1):
            vec = np.asarray(ent2vec[key])  # also accepts plain lists
            embedding_matrix[row] = vec / float(np.linalg.norm(vec) + 1e-6)

        return embedding_matrix

class Bert_SE_KG(nn.Module): #bert with sentence embedding
    """BERT pooled-output multi-label classifier.

    When ``args['use_embd_layer']`` is set, the pooled vector is optionally
    fused with a CNN max-pooled sentence vector built from frozen
    knowledge-graph entity embeddings.
    """

    def __init__(self, args, Y, dicts):
        super(Bert_SE_KG, self).__init__()

        cache_path = os.path.join(args['bert_dir'], args['pretrained_bert'])

        if os.path.exists(cache_path):
            print("model path exist")
            savedModel = tr.BertModel.from_pretrained(cache_path)
        else:
            print("Downloading and saving model")
            savedModel = tr.BertModel.from_pretrained(str(args['pretrained_bert']))
            savedModel.save_pretrained(save_directory = cache_path, save_config=True)
        self.bert = savedModel
        self.config = savedModel.config

        self.feature_size = self.config.hidden_size

        if args['use_embd_layer']:
            self.kg_embd = EntityEmbedding(args, Y)
            self.kg_embd.embed.weight.requires_grad = False
            filters = [3]
            self.convs = nn.ModuleList([nn.Conv1d(self.kg_embd.feature_size, self.kg_embd.feature_size, int(filter_size)) for filter_size in filters])
            self.dim_reduction = nn.Linear(self.feature_size, self.kg_embd.feature_size)
            # reduced BERT vector + pooled entity vector
            self.feature_size = self.kg_embd.feature_size*2

        self.dropout = nn.Dropout(args['dropout'])
        self.classifier = nn.Linear(self.feature_size, Y)

        # Initialise only the newly added layers. The old self.apply(...) also
        # re-randomised the pretrained BERT weights loaded above and the
        # pretrained entity embeddings.
        self._init_new_layers()

    def _init_new_layers(self):
        """Apply init_bert_weights to every direct child except the pretrained
        ``bert`` encoder and the ``kg_embd`` entity table."""
        for name, child in self.named_children():
            if name not in ('bert', 'kg_embd'):
                child.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids, attention_mask, entity_embeddings, target):
        """Return ``(logits, loss)`` with BCE-with-logits loss against ``target``."""
        last_hidden_state, x = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=False)

        if hasattr(self, 'kg_embd'):
            out = self.kg_embd(entity_embeddings)

            embedded = out.permute(0, 2, 1)  # channels-first for Conv1d
            conved = [torch.relu(conv(embedded)) for conv in self.convs]
            pooled = [conv.max(dim=-1).values for conv in conved]  # max over positions
            cat = self.dropout(torch.cat(pooled, dim=-1))

            x = self.dim_reduction(x)
            # NOTE(review): this divides by the norm of the whole batch tensor,
            # and float() detaches it, so each sample's scale depends on the
            # rest of the batch. Per-row normalisation may have been intended.
            x = x / float(torch.norm(x) + 1e-6)

            x = torch.cat((x, cat), dim=1)

        x = self.dropout(x)

        y = self.classifier(x)
        loss = F.binary_cross_entropy_with_logits(y, target)

        return y, loss

    def init_bert_weights(self, module):
        """BERT-style init: Linear/Embedding weights ~ N(0, initializer_range),
        LayerNorm to (weight=1, bias=0), and Linear biases to 0."""
        BertLayerNorm = torch.nn.LayerNorm

        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def freeze_net(self):
        """No-op: nothing is frozen for this model."""
        pass

class Bert_WE_KG(nn.Module): #bert with word embedding
    """BERT token states (optionally with appended KG entity vectors) fed to
    the label-attention OutputLayer."""

    def __init__(self, args, Y, dicts):
        super(Bert_WE_KG, self).__init__()

        cache_path = os.path.join(args['bert_dir'], args['pretrained_bert'])

        if os.path.exists(cache_path):
            savedModel = tr.BertModel.from_pretrained(cache_path)
        else:
            savedModel = tr.BertModel.from_pretrained(str(args['pretrained_bert']))
            savedModel.save_pretrained(save_directory = cache_path, save_config=True)
        self.bert = savedModel
        self.config = savedModel.config

        self.feature_size = self.config.hidden_size

        if args['use_embd_layer']:
            self.kg_embd = EntityEmbedding(args, Y)
            self.kg_embd.embed.weight.requires_grad = False
            self.dim_reduction = nn.Linear(self.feature_size, self.kg_embd.feature_size)
            self.feature_size = self.kg_embd.feature_size

        self.dropout = nn.Dropout(args['dropout'])

        self.output_layer = OutputLayer(args, Y, dicts, self.feature_size)

        # Initialise only the newly added layers. The old self.apply(...) also
        # re-randomised the pretrained BERT weights loaded above and the
        # pretrained entity embeddings.
        self._init_new_layers()

    def _init_new_layers(self):
        """Apply init_bert_weights to every direct child except the pretrained
        ``bert`` encoder and the ``kg_embd`` entity table."""
        for name, child in self.named_children():
            if name not in ('bert', 'kg_embd'):
                child.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids, attention_mask, entity_embeddings, target):
        """Return ``(logits, loss)`` from the OutputLayer."""
        last_hidden_state, pooled_output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=False)

        x = self.dropout(last_hidden_state)

        if hasattr(self, 'kg_embd'):
            out = self.kg_embd(entity_embeddings)

            x = self.dim_reduction(x)  # token states -> entity width

            # NOTE(review): this divides by the norm of the whole batch tensor,
            # and float() detaches it, so each sample's scale depends on the
            # rest of the batch. Per-row normalisation may have been intended.
            x = x / float(torch.norm(x) + 1e-6)
            x = torch.cat((x, out), dim=1)  # append entity vectors along the sequence axis

        y, loss = self.output_layer(x, target)

        return y, loss

    def init_bert_weights(self, module):
        """BERT-style init: Linear/Embedding weights ~ N(0, initializer_range),
        LayerNorm to (weight=1, bias=0), and Linear biases to 0."""
        BertLayerNorm = torch.nn.LayerNorm

        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def freeze_net(self):
        """No-op: nothing is frozen for this model."""
        pass

class Bert_L4_WE_KG(nn.Module): #concatenates the last 4 hidden layers of bert
    """Concatenated last-four BERT hidden layers (optionally reduced and
    extended with KG entity vectors) fed to the label-attention OutputLayer."""

    def __init__(self, args, Y, dicts):
        super(Bert_L4_WE_KG, self).__init__()

        cache_path = os.path.join(args['bert_dir'], args['pretrained_bert'])

        if os.path.exists(cache_path):
            savedModel = tr.BertModel.from_pretrained(cache_path, return_dict=True)
        else:
            savedModel = tr.BertModel.from_pretrained(str(args['pretrained_bert']), return_dict=True)
            savedModel.save_pretrained(save_directory = cache_path, save_config=True)
        self.bert = savedModel
        self.config = savedModel.config

        # four concatenated hidden layers
        self.feature_size = self.config.hidden_size*4

        if args['use_embd_layer']:
            self.kg_embd = EntityEmbedding(args, Y)
            self.kg_embd.embed.weight.requires_grad = False
            self.dim_reduction = nn.Linear(self.feature_size, self.kg_embd.feature_size)
            self.feature_size = self.kg_embd.feature_size

        self.dropout = nn.Dropout(args['dropout'])
        self.output_layer = OutputLayer(args, Y, dicts, self.feature_size)

        # Initialise only the newly added layers. The old self.apply(...) also
        # re-randomised the pretrained BERT weights loaded above and the
        # pretrained entity embeddings.
        self._init_new_layers()

    def _init_new_layers(self):
        """Apply init_bert_weights to every direct child except the pretrained
        ``bert`` encoder and the ``kg_embd`` entity table."""
        for name, child in self.named_children():
            if name not in ('bert', 'kg_embd'):
                child.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids, attention_mask, entity_embeddings, target):
        """Return ``(logits, loss)`` from the OutputLayer."""
        output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, output_hidden_states=True)

        hidden_states = output.hidden_states
        # concatenate last four layers on the feature axis (-> hidden_size*4)
        x = torch.cat([hidden_states[i] for i in [-1, -2, -3, -4]], dim=-1)

        x = self.dropout(x)

        if hasattr(self, 'kg_embd'):
            out = self.kg_embd(entity_embeddings)

            x = self.dim_reduction(x)  # token states -> entity width

            # NOTE(review): this divides by the norm of the whole batch tensor,
            # and float() detaches it, so each sample's scale depends on the
            # rest of the batch. Per-row normalisation may have been intended.
            x = x / float(torch.norm(x) + 1e-6)
            x = torch.cat((x, out), dim=1)  # append entity vectors along the sequence axis
            x = self.dropout(x)

        y, loss = self.output_layer(x, target)

        return y, loss

    def loss_fn(self, outputs, target):
        """BCE-with-logits loss (not used by forward, which uses the OutputLayer's loss)."""
        return nn.BCEWithLogitsLoss()(outputs, target)

    def init_bert_weights(self, module):
        """BERT-style init: Linear/Embedding weights ~ N(0, initializer_range),
        LayerNorm to (weight=1, bias=0), and Linear biases to 0."""
        BertLayerNorm = torch.nn.LayerNorm

        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def freeze_net(self):
        """No-op: nothing is frozen for this model."""
        pass


class Bert_MCNN_KG(nn.Module): #Bert with KG and CNN
    """BERT encoder + optional KG entity embeddings + multi-resolution CNN + output layer.

    ``forward`` pipeline:
      1. concatenate the last four BERT hidden layers along the feature axis;
      2. scale by the norm of the whole batch tensor and project back to
         ``hidden_size`` with ``dim_reduction1``;
      3. if ``args['use_embd_layer']`` is set: project to the entity-embedding width
         with ``dim_reduction2`` and append the (frozen) entity embeddings along the
         sequence axis (dim 1);
      4. dropout -> ``MRCNNLayer`` -> ``OutputLayer``.
    """

    def __init__(self, args, Y, dicts):
        super(Bert_MCNN_KG, self).__init__()

        # Prefer a local copy of the pretrained model; otherwise download it by name
        # and save it under cache_path for the next run.
        cache_path = os.path.join(args['bert_dir'], args['pretrained_bert'])
        if os.path.exists(cache_path):
            bert = tr.BertModel.from_pretrained(cache_path)
        else:
            bert = tr.BertModel.from_pretrained(str(args['pretrained_bert']))
            bert.save_pretrained(save_directory = cache_path, save_config=True)
        self.bert = bert
        self.config = bert.config
        hidden = self.config.hidden_size

        # Four concatenated hidden layers -> one hidden_size-wide vector per token.
        self.dim_reduction1 = nn.Linear(hidden * 4, hidden)
        self.feature_size = hidden

        if args['use_embd_layer']:
            # The entity-embedding table is kept frozen; BERT features are projected
            # to its width so both can be concatenated along the sequence axis.
            self.kg_embd = EntityEmbedding(args, Y)
            self.kg_embd.embed.weight.requires_grad = False
            self.dim_reduction2 = nn.Linear(self.feature_size, self.kg_embd.feature_size)
            self.feature_size = self.kg_embd.feature_size

        self.dropout = nn.Dropout(args['dropout'])
        self.conv = MRCNNLayer(args, self.feature_size)

        # The output layer consumes the CNN output: filter_num * num_filter_maps wide.
        self.feature_size = self.conv.filter_num * args['num_filter_maps']
        self.output_layer = OutputLayer(args, Y, dicts, self.feature_size)
        # No custom weight init is applied, so the BERT weights stay as loaded.

    def forward(self, input_ids, token_type_ids, attention_mask, entity_embeddings, target):
        """Run the pipeline described in the class docstring.

        Returns:
            ``(y, loss)`` as produced by ``self.output_layer``.
        """
        bert_out = self.bert(input_ids=input_ids, token_type_ids=token_type_ids,
                             attention_mask=attention_mask, output_hidden_states=True)
        layers = bert_out.hidden_states

        # hidden_states[-1], [-2], [-3], [-4]: (batch, seq, 4 * hidden_size (768 for BERT-base)).
        stacked = torch.cat([layers[-offset] for offset in range(1, 5)], dim=-1)

        # float() makes the divisor a Python number (no gradient through the norm);
        # the norm spans the whole batch tensor.
        stacked = stacked / float(torch.norm(stacked) + 1e-6)
        feats = self.dim_reduction1(stacked)

        if hasattr(self, 'kg_embd'):
            entity_feats = self.kg_embd(entity_embeddings)
            # Append entity vectors as extra sequence positions (dim 1).
            feats = torch.cat((self.dim_reduction2(feats), entity_feats), dim=1)

        feats = self.dropout(feats)
        feats = self.conv(feats)

        y, loss = self.output_layer(feats, target)
        return y, loss

    def freeze_net(self):
        """No-op: this model has no freezing logic."""
        return None

    
    


#############Train-Test###############

class Train_Test:
    """Training loop, evaluation loop and metric helpers for the multi-label models.

    Two batch layouts are supported, selected by whether ``args['model']`` contains "bert":

    * BERT models: batches are ``(input_ids, segments, masks, entity_embeddings, labels)``
      and the model is called as ``model(input_ids, segments, masks, entity_embeddings, labels)``.
    * Other models: batches are ``(input_ids, labels, text_inputs, embeddings, tfIdf_inputs)``
      and the model is called as ``model(input_ids, labels, text_inputs, embeddings, tfIdf_inputs)``.

    In both cases the model returns ``(scores, loss)``; ``test`` applies a sigmoid to the
    scores and thresholds them at 0.5 (``np.round``) to obtain binary predictions.
    """

    def __init__(self):
        print("Train--Test")

    @staticmethod
    def _prepare_batch(args, batch, gpu):
        """Turn one raw DataLoader batch into the positional arguments for ``model``.

        Returns:
            ``(model_inputs, labels)``: ``model_inputs`` is passed as ``model(*model_inputs)``
            and ``labels`` is the FloatTensor of targets taken from the batch.
        """
        if "bert" in args['model']:
            inputs_id, segments, masks, ent_embeddings, labels = batch
            inputs_id = torch.LongTensor(inputs_id)
            segments = torch.LongTensor(segments)
            masks = torch.LongTensor(masks)
            labels = torch.FloatTensor(labels)
            # Entity ids are only used when the model has an entity-embedding layer.
            ent_embeddings = torch.LongTensor(ent_embeddings) if args['use_embd_layer'] else None

            if gpu >= 0:
                if ent_embeddings is not None:
                    ent_embeddings = ent_embeddings.cuda(gpu)
                inputs_id, segments, masks, labels = (inputs_id.cuda(gpu), segments.cuda(gpu),
                                                      masks.cuda(gpu), labels.cuda(gpu))
            return (inputs_id, segments, masks, ent_embeddings, labels), labels

        # NOTE(review): non-BERT tensors are not moved to `gpu` here; presumably the
        # model handles device placement itself - confirm.
        inputs_id, labels, text_inputs, embeddings, tfIdf_inputs = batch

        if args['use_embd_layer']:
            embeddings = torch.LongTensor(embeddings)

        if args['use_sentiment']:
            # Stack entry [0] of each sample's (ids, attention) pair into batch tensors.
            input_ids = torch.stack([x_[0][0] for x_ in text_inputs])
            attention = torch.stack([x_[1][0] for x_ in text_inputs])
            text_inputs = [input_ids, attention]

        if args['use_tfIdf']:
            tfIdf_inputs = torch.FloatTensor(tfIdf_inputs)

        inputs_id, labels = torch.LongTensor(inputs_id), torch.FloatTensor(labels)
        return (inputs_id, labels, text_inputs, embeddings, tfIdf_inputs), labels

    def train(self, args, model, optimizer, scheduler, epoch, gpu, data_loader):
        """Run one training epoch over ``data_loader``.

        Args:
            args: config dict; reads 'model', 'use_embd_layer', 'grad_clip' and, for
                non-BERT models, 'use_sentiment' and 'use_tfIdf'.
            model: callable returning ``(output, loss)`` for a batch.
            optimizer: stepped once per batch.
            scheduler: only used to print the current LR (it is not stepped here); may be None.
            epoch: epoch number, used for logging only.
            gpu: CUDA device index; a negative value keeps BERT batches on the CPU.
            data_loader: iterable of batches (see the class docstring for the layouts).

        Returns:
            list of per-batch loss values (Python floats).
        """
        print('Epoch:', epoch,'LR:', optimizer.param_groups[0]['lr'])
        if scheduler is not None:
            print('Epoch:', epoch,'LR:', scheduler.get_last_lr())

        losses = []
        model.train()

        for batch in tqdm(data_loader):
            model_inputs, _ = self._prepare_batch(args, batch, gpu)

            optimizer.zero_grad()
            output, loss = model(*model_inputs)
            loss.backward()

            if args['grad_clip']:
                # Clip the global gradient norm to 1.0 before the update.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            losses.append(loss.item())

        return losses

    def test(self, args, model, data_path, fold, gpu, dicts, data_loader):
        """Evaluate ``model`` on one split, print all metric families and return the metrics.

        Args:
            args, gpu: as for ``train``.
            model: callable returning ``(scores, loss)``; run in eval mode without gradients.
            data_path: training-file path; 'train' is replaced by ``fold`` only for logging.
            fold: split name (e.g. 'dev'/'test'), used in the log line and loss key.
            dicts: lookup dict; ``len(dicts['ind2c'])`` is the number of labels.
            data_loader: iterable of batches.

        Returns:
            dict from ``all_metrics`` plus ``'loss_<fold>'`` (mean batch loss).
        """
        self.model_name = args['model']
        filename = data_path.replace('train', fold)
        print('file for evaluation: %s' % filename)
        num_labels = len(dicts['ind2c'])

        y, yhat, yhat_raw, losses = [], [], [], []

        model.eval()
        with torch.no_grad():
            for batch in tqdm(data_loader):
                model_inputs, labels = self._prepare_batch(args, batch, gpu)
                output, loss = model(*model_inputs)

                probs = torch.sigmoid(output).data.cpu().numpy()
                losses.append(loss.item())
                y.append(labels.data.cpu().numpy())
                yhat_raw.append(probs)
                yhat.append(np.round(probs))  # 0.5 threshold

        y = np.concatenate(y, axis=0)
        yhat = np.concatenate(yhat, axis=0)
        yhat_raw = np.concatenate(yhat_raw, axis=0)

        # @k cut-offs: 5 for the 50-label setting, 8 and 15 otherwise.
        k = 5 if num_labels == 50 else [8,15]

        self.new_metric_calc(y, yhat_raw)  # own TP/FP/FN implementation
        self.calculate_print_metrics(y, yhat_raw)  # sklearn, zero_division=1

        metrics = self.all_metrics(yhat, y, k=k, yhat_raw=yhat_raw)

        print()
        print("Metric calculation by Fei Li and Hong Yu start")
        self.print_metrics(metrics)
        print("Metric calculation by Fei Li and Hong Yu end")
        print()

        metrics['loss_%s' % fold] = np.mean(losses)

        print('loss_%s' % fold, metrics['loss_%s' % fold])
        return metrics

    @staticmethod
    def _safe_f1(prec, rec):
        """Harmonic mean of precision and recall, defined as 0 when both are 0."""
        if prec + rec == 0:
            return 0.
        return 2 * (prec * rec) / (prec + rec)

    def new_metric_calc(self, y, yhat):
        """Accuracy/precision/recall/F1 computed with this class's own TP/FP/FN counters.

        "macro" values here are averaged over instances (rows of ``y``), not over labels;
        "micro" values pool every (instance, label) cell. ``yhat`` holds scores and is
        rounded to 0/1 first. Prints and returns the metric dict.
        """
        names = ["acc", "prec", "rec", "f1"]

        yhat = np.round(yhat)

        # Per-instance averages.
        macro_accuracy = np.mean([accuracy_score(y[i], yhat[i]) for i in range(len(y))])
        macro_precision = np.mean([self.getPrecision(y[i], yhat[i]) for i in range(len(y))])
        macro_recall = np.mean([self.getRecall(y[i], yhat[i]) for i in range(len(y))])
        macro_f_score = np.mean([self.getFScore(y[i], yhat[i]) for i in range(len(y))])

        # Pooled over all cells.
        ymic = y.ravel()
        yhatmic = yhat.ravel()

        micro_accuracy = accuracy_score(ymic, yhatmic)
        micro_precision = self.getPrecision(ymic, yhatmic)
        micro_recall = self.getRecall(ymic, yhatmic)
        micro_f_score = self.getFScore(ymic, yhatmic)

        macro = (macro_accuracy, macro_precision, macro_recall, macro_f_score)
        micro = (micro_accuracy, micro_precision, micro_recall, micro_f_score)

        metrics = {names[i] + "_macro": macro[i] for i in range(len(macro))}
        metrics.update({names[i] + "_micro": micro[i] for i in range(len(micro))})

        print()
        print("Metric calculation for all labels together start")
        self.print_metrics(metrics)
        print("Metric calculation for all labels together end")
        print()

        return metrics

    def getFScore(self, y, yhat):
        """F1 from ``getPrecision``/``getRecall``; 0 when both are 0."""
        return self._safe_f1(self.getPrecision(y, yhat), self.getRecall(y, yhat))

    def getRecall(self, y, yhat):
        """TP / (TP + FN); the 1e-10 avoids division by zero."""
        return self.getTP(y, yhat)/(self.getTP(y, yhat) + self.getFN(y, yhat) + 1e-10)

    def getPrecision(self, y, yhat):
        """TP / (TP + FP); the 1e-10 avoids division by zero."""
        return self.getTP(y, yhat)/(self.getTP(y, yhat) + self.getFP(y, yhat) + 1e-10)

    def getTP(self, y, yhat):
        """Count of cells where truth and prediction are both 1."""
        return np.multiply(y, yhat).sum().item()

    def getFN(self, y, yhat):
        """Count of cells where truth is 1 and prediction is 0."""
        return np.multiply(y, np.logical_not(yhat).astype(float)).sum().item()

    def getFP(self, y, yhat):
        """Count of cells where truth is 0 and prediction is 1."""
        # Previously multiplied (not y) by y, which is always 0 and made precision ~1.
        return np.multiply(np.logical_not(y).astype(float), yhat).sum().item()

    def calculate_print_metrics(self, y, yhat):
        """Same metric families via scikit-learn (``zero_division=1``); prints and returns them.

        Here "macro" precision/recall/F1 are sklearn's label-averaged values, while
        "macro" accuracy is averaged over instances (rows).
        """
        names = ["acc", "prec", "rec", "f1"]

        yhat = np.round(yhat)

        macro_precision, macro_recall, macro_f_score, macro_support = precision_recall_fscore_support(y, yhat, average = 'macro', zero_division=1)
        macro_accuracy = np.mean([accuracy_score(y[i], yhat[i]) for i in range(len(y))])

        ymic = y.ravel()
        yhatmic = yhat.ravel()
        micro_precision, micro_recall, micro_f_score, micro_support = precision_recall_fscore_support(ymic, yhatmic, average='micro', zero_division=1)
        micro_accuracy = accuracy_score(ymic, yhatmic)

        macro = (macro_accuracy, macro_precision, macro_recall, macro_f_score)
        micro = (micro_accuracy, micro_precision, micro_recall, micro_f_score)

        metrics = {names[i] + "_macro": macro[i] for i in range(len(macro))}
        metrics.update({names[i] + "_micro": micro[i] for i in range(len(micro))})

        print()
        print("Sklearn Metric calculation start")
        self.print_metrics(metrics)
        print("Sklearn Metric calculation end")
        print()

        return metrics

    def all_metrics(self, yhat, y, k=8, yhat_raw=None, calc_auc=True):
        """
            Inputs:
                yhat: binary predictions matrix
                y: binary ground truth matrix
                k: for @k metrics (int, or a list/tuple of ints)
                yhat_raw: prediction scores matrix (floats)
                calc_auc: when False, skip the @k and AUC metrics
            Outputs:
                dict holding relevant metrics
        """
        names = ["acc", "prec", "rec", "f1"]

        #macro (label-averaged)
        macro = self.all_macro(yhat, y)
        #micro (pooled over all cells)
        ymic = y.ravel()
        yhatmic = yhat.ravel()
        micro = self.all_micro(yhatmic, ymic)

        metrics = {names[i] + "_macro": macro[i] for i in range(len(macro))}
        metrics.update({names[i] + "_micro": micro[i] for i in range(len(micro))})

        #AUC and @k
        if yhat_raw is not None and calc_auc:
            #allow k to be passed as int or list/tuple
            ks = list(k) if isinstance(k, (list, tuple)) else [k]
            for k_i in ks:
                rec_at_k = self.recall_at_k(yhat_raw, y, k_i)
                metrics['rec_at_%d' % k_i] = rec_at_k
                prec_at_k = self.precision_at_k(yhat_raw, y, k_i)
                metrics['prec_at_%d' % k_i] = prec_at_k
                # 0 (not nan) when both precision and recall at k are 0.
                metrics['f1_at_%d' % k_i] = self._safe_f1(prec_at_k, rec_at_k)

            metrics.update(self.auc_metrics(yhat_raw, y, ymic))

        return metrics

    def auc_metrics(self, yhat_raw, y, ymic):
        """Macro- and micro-averaged ROC AUC.

        Macro AUC averages per-label AUCs over labels that have at least one positive
        and a finite AUC. Returns an empty dict when fewer than two instances are given
        (previously ``None``, which made ``dict.update`` raise).
        """
        if yhat_raw.shape[0] <= 1:
            return {}
        fpr = {}
        tpr = {}
        roc_auc = {}
        #get AUC for each label individually
        auc_labels = {}
        for i in range(y.shape[1]):
            #only if there are true positives for this label
            if y[:,i].sum() > 0:
                fpr[i], tpr[i], _ = roc_curve(y[:,i], yhat_raw[:,i])
                if len(fpr[i]) > 1 and len(tpr[i]) > 1:
                    auc_score = auc(fpr[i], tpr[i])
                    if not np.isnan(auc_score):
                        auc_labels[i] = auc_score

        #macro-AUC: just average the auc scores (nan when no label qualifies)
        aucs = list(auc_labels.values())
        roc_auc['auc_macro'] = np.mean(aucs) if aucs else np.nan

        #micro-AUC: just look at each individual prediction
        yhatmic = yhat_raw.ravel()
        fpr["micro"], tpr["micro"], _ = roc_curve(ymic, yhatmic)
        roc_auc["auc_micro"] = auc(fpr["micro"], tpr["micro"])

        return roc_auc

    def precision_at_k(self, yhat_raw, y, k):
        """Mean over instances of (true labels among the top-k scores) / k."""
        sortd = np.argsort(yhat_raw)[:,::-1]
        topk = sortd[:,:k]

        vals = []
        for i, tk in enumerate(topk):
            if len(tk) > 0:
                num_true_in_top_k = y[i,tk].sum()
                denom = len(tk)
                vals.append(num_true_in_top_k / float(denom))

        return np.mean(vals)

    def recall_at_k(self,yhat_raw, y, k):
        """Mean over instances of (true labels among the top-k scores) / (number of true labels).

        Instances without any true label contribute 0.
        """
        sortd = np.argsort(yhat_raw)[:,::-1]
        topk = sortd[:,:k]

        vals = []
        for i, tk in enumerate(topk):
            num_true_in_top_k = y[i,tk].sum()
            denom = y[i,:].sum()
            vals.append(num_true_in_top_k / float(denom))

        vals = np.array(vals)
        vals[np.isnan(vals)] = 0.

        return np.mean(vals)

    def all_micro(self, yhatmic, ymic):
        """(accuracy, precision, recall, f1) over flattened predictions/labels."""
        return self.micro_accuracy(yhatmic, ymic), self.micro_precision(yhatmic, ymic), self.micro_recall(yhatmic, ymic), self.micro_f1(yhatmic, ymic)

    def micro_f1(self, yhatmic, ymic):
        """Harmonic mean of micro precision and recall; 0 when both are 0."""
        return self._safe_f1(self.micro_precision(yhatmic, ymic), self.micro_recall(yhatmic, ymic))

    def micro_recall(self, yhatmic, ymic):
        return self.intersect_size(yhatmic, ymic, 0) / (ymic.sum(axis=0) + 1e-10) #NaN fix

    def micro_precision(self, yhatmic, ymic):
        return self.intersect_size(yhatmic, ymic, 0) / (yhatmic.sum(axis=0) + 1e-10) #NaN fix

    def micro_accuracy(self, yhatmic, ymic):
        """Jaccard-style accuracy: |pred AND true| / |pred OR true|."""
        return self.intersect_size(yhatmic, ymic, 0) / (self.union_size(yhatmic, ymic, 0) + 1e-10) #NaN fix

    def all_macro(self,yhat, y):
        """(accuracy, precision, recall, f1), each averaged over labels (columns)."""
        return self.macro_accuracy(yhat, y), self.macro_precision(yhat, y), self.macro_recall(yhat, y), self.macro_f1(yhat, y)

    def macro_f1(self, yhat, y):
        """Harmonic mean of the macro precision and recall; 0 when both are 0."""
        return self._safe_f1(self.macro_precision(yhat, y), self.macro_recall(yhat, y))

    def macro_recall(self, yhat, y):
        num = self.intersect_size(yhat, y, 0) / (y.sum(axis=0) + 1e-10)
        return np.mean(num)

    def macro_precision(self, yhat, y):
        num = self.intersect_size(yhat, y, 0) / (yhat.sum(axis=0) + 1e-10)
        return np.mean(num)

    def macro_accuracy(self, yhat, y):
        num = self.intersect_size(yhat, y, 0) / (self.union_size(yhat, y, 0) + 1e-10)
        return np.mean(num)

    def intersect_size(self, yhat, y, axis):
        #axis=0 for label-level union (macro). axis=1 for instance-level
        return np.logical_and(yhat, y).sum(axis=axis).astype(float)

    def union_size(self, yhat, y, axis):
        #axis=0 for label-level union (macro). axis=1 for instance-level
        return np.logical_or(yhat, y).sum(axis=axis).astype(float)

    def print_metrics(self, metrics):
        """Print macro/micro rows (with AUC when present) and every ``rec_at_*`` entry."""
        print()
        if "auc_macro" in metrics:
            print("[MACRO] accuracy, precision, recall, f-measure, AUC")
            print("   %.4f, %.4f, %.4f, %.4f, %.4f" % (metrics["acc_macro"], metrics["prec_macro"], metrics["rec_macro"], metrics["f1_macro"], metrics["auc_macro"]))
        else:
            print("[MACRO] accuracy, precision, recall, f-measure")
            print("   %.4f, %.4f, %.4f, %.4f" % (metrics["acc_macro"], metrics["prec_macro"], metrics["rec_macro"], metrics["f1_macro"]))

        if "auc_micro" in metrics:
            print("[MICRO] accuracy, precision, recall, f-measure, AUC")
            print("   %.4f, %.4f, %.4f, %.4f, %.4f" % (metrics["acc_micro"], metrics["prec_micro"], metrics["rec_micro"], metrics["f1_micro"], metrics["auc_micro"]))
        else:
            print("[MICRO] accuracy, precision, recall, f-measure")
            print("   %.4f, %.4f, %.4f, %.4f" % (metrics["acc_micro"], metrics["prec_micro"], metrics["rec_micro"], metrics["f1_micro"]))
        for metric, val in metrics.items():
            if metric.find("rec_at") != -1:
                print("%s: %.4f" % (metric, val))
        print()

        
        
        


#############Model Summary###############

import torch
import torch.nn as nn
from torch.autograd import Variable

from collections import OrderedDict
import numpy as np


def My_Summary(model, input_size, batch_size=-1, device="cuda"):
    """Print a per-layer summary table of ``model`` (layer, output shape, #params).

    A dummy forward pass ``model(*x)`` is run with random inputs of batch size 2.
    A forward hook on every sub-module records the shape of its first non-scalar
    input/output tensor and the number of elements in its ``weight``/``bias``.
    ``nn.Sequential``/``nn.ModuleList`` containers and ``model`` itself get no hook.
    Parameters stored under other attribute names are not counted.

    Args:
        model: ``nn.Module`` to summarise; it is called positionally.
        input_size: one ``(shape, tensor_type)`` pair or a list of them, one per
            positional model argument. ``shape`` excludes the batch dimension and
            ``tensor_type`` is e.g. ``torch.LongTensor``. Use ``0`` as ``tensor_type``
            to pass ``None`` for that argument.
        batch_size: batch size shown in the shapes and used in the size estimates.
        device: must be "cuda" or "cpu". It is only validated; the dummy inputs'
            device follows ``tensor_type``.

    Hooks are always removed again, even if the dummy forward pass raises.
    """

    def _first_tensor(obj):
        # Depth-first search through (nested) lists/tuples/dict values for the first
        # tensor with at least one dimension; None if there is none.
        if torch.is_tensor(obj):
            return obj if obj.dim() > 0 else None
        if isinstance(obj, dict):
            obj = list(obj.values())
        if isinstance(obj, (list, tuple)):
            for item in obj:
                found = _first_tensor(item)
                if found is not None:
                    return found
        return None

    def _batched_shape(obj):
        # Shape of the first non-scalar tensor in obj with dim 0 replaced by
        # batch_size, or [-1] when obj holds no such tensor.
        tensor = _first_tensor(obj)
        if tensor is None:
            return [-1]
        shape = list(tensor.size())
        shape[0] = batch_size
        return shape

    def register_hook(module):

        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = _batched_shape(input)
            summary[m_key]["output_shape"] = _batched_shape(output)

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += int(np.prod(list(module.weight.size()), dtype=np.int64))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += int(np.prod(list(module.bias.size()), dtype=np.int64))
            summary[m_key]["nb_params"] = params

        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
            and module is not model
        ):
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    # A single (shape, tensor_type) pair means the model takes a single input.
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # batch_size of 2 for batchnorm; a tensor_type of 0 passes None for that argument.
    x = [torch.rand(2, *in_size[0]).type(in_size[1]) if in_size[1] != 0 else None for in_size in input_size]

    summary = OrderedDict()
    hooks = []

    model.apply(register_hook)

    try:
        # Only shapes are needed, so skip autograd bookkeeping for the dummy pass.
        with torch.no_grad():
            model(*x)
    finally:
        # Remove the hooks even when the dummy pass fails; otherwise they stay
        # attached and keep firing on every later forward call of the model.
        for h in hooks:
            h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20}  {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        line_new = "{:>20}  {:>25} {:>15}".format(
            layer,
            str(summary[layer]["output_shape"]),
            "{0:,}".format(summary[layer]["nb_params"]),
        )
        total_params += summary[layer]["nb_params"]
        total_output += np.prod(summary[layer]["output_shape"])
        if summary[layer].get("trainable", False):
            trainable_params += summary[layer]["nb_params"]
        print(line_new)

    # assume 4 bytes/number (float on cuda).
    # Input size: elements of every materialised (non-None) input, times batch size.
    input_elems = sum(int(np.prod(in_size[0])) for in_size in input_size if in_size[1] != 0)
    total_input_size = abs(input_elems * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")

    
    
    
    
    


#############Main###############

class MyDataset(Dataset):
    """Minimal map-style Dataset over an in-memory, indexable collection of instances.

    Items are returned exactly as stored in ``X``.
    """

    def __init__(self, X):
        # X: any sized, indexable collection (e.g. a list of prepared instances).
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        item = self.X[idx]
        return item

class Run:
    """End-to-end training / evaluation driver.

    Constructing ``Run(args)`` does all the work: it loads the word and label-code
    lookups, builds the model through ``ModelHub``, prepares the train/dev/test
    instances, trains for ``args['n_epochs']`` epochs (evaluating on dev after every
    epoch and additionally on test on the last epoch or after early stopping) and,
    when ``args['save_everything']`` is set, writes args, metric histories and the
    best model into a timestamped directory under ``args['MODEL_DIR']``.
    """

    def __init__(self, args):
        # Work on a shallow copy: early stopping below rewrites 'test_model', and mutating
        # the caller's (global) dict would turn the next ``Run(args)`` in the same kernel
        # into a silent test-only run of the previous best checkpoint.
        args = dict(args)
        # DataLoader calls the collate functions with the batch only, so they read the
        # configuration from the instance rather than from the module-level ``args``.
        self.args = args

        if args['random_seed'] != 0:
            random.seed(args['random_seed'])
            np.random.seed(args['random_seed'])
            torch.manual_seed(args['random_seed'])
            torch.cuda.manual_seed_all(args['random_seed'])

        print("loading lookups...")
        dicts = self.load_lookups(args)
        modelhub = ModelHub(args, dicts)
        model = modelhub.pick_model(args, dicts)
        print(model)

        My_Summary(model,
                [(tuple([args['MAX_LENGTH']]),torch.LongTensor), (tuple([len(dicts['ind2c'])]),torch.FloatTensor), (tuple([0]),0), (tuple([args['MAX_ENT_LENGTH']]),torch.LongTensor), (tuple([0]),0)],
                device="cpu") #inputs_id, labels, text_inputs, embeddings, tfIdf_inputs

        is_bert = "bert" in args['model']

        if not args['test_model']:
            optimizer = optim.Adam(model.parameters(), weight_decay=args['weight_decay'], lr=args['lr'])
            # alternative: optim.AdamW(model.parameters(), lr=args['lr'], betas=(0.9, 0.999), eps=1e-08, weight_decay=args['weight_decay'], amsgrad=True)
        else:
            optimizer = None

        if args['tune_wordemb'] == False:
            model.freeze_net()

        metrics_hist = defaultdict(list)
        metrics_hist_te = defaultdict(list)
        metrics_hist_tr = defaultdict(list)

        prepare_instance_func = self.prepare_instance_bert if is_bert else self.prepare_instance

        train_instances = prepare_instance_func(dicts, args['data_path'], args, args['MAX_LENGTH'])
        print("train_instances {}".format(len(train_instances)))

        dev_instances = prepare_instance_func(dicts, self._split_path(args['data_path'], 'dev'), args, args['MAX_LENGTH'])
        print("dev_instances {}".format(len(dev_instances)))

        test_instances = prepare_instance_func(dicts, self._split_path(args['data_path'], 'test'), args, args['MAX_LENGTH'])
        print("test_instances {}".format(len(test_instances)))

        # BERT batches also carry segment ids and attention masks, hence a dedicated collate.
        collate_func = self.my_collate_bert if is_bert else self.my_collate

        train_loader = DataLoader(MyDataset(train_instances), args['batch_size'], shuffle=True, collate_fn=collate_func)
        dev_loader = DataLoader(MyDataset(dev_instances), 1, shuffle=False, collate_fn=collate_func)
        test_loader = DataLoader(MyDataset(test_instances), 1, shuffle=False, collate_fn=collate_func)

        if not args['test_model'] and is_bert:
            # BERT fine-tuning recipe: pooler excluded, no weight decay on biases / LayerNorm.
            param_optimizer = [(n, p) for n, p in model.named_parameters() if 'pooler' not in n]
            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [
                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                 'weight_decay': 0.01},
                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
                 'weight_decay': 0.0},
            ]

            num_train_optimization_steps = int(
                len(train_instances) / args['batch_size'] + 1) * args['n_epochs']

            # Replaces the Adam optimizer created above; warmup is a fraction of t_total.
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args['lr'],
                                 warmup=0.1,
                                 t_total=num_train_optimization_steps)

        # In test-only mode there is no optimizer, and StepLR(None) would raise.
        scheduler = None
        if optimizer is not None:
            scheduler = StepLR(optimizer, step_size=args['step_size'], gamma=args['gamma'])

        test_only = args['test_model'] is not None

        train_test = Train_Test()
        # Output/checkpoint directory; stays None when nothing is saved or loaded.
        model_dir = None

        for epoch in range(args['n_epochs']):

            if epoch == 0 and not args['test_model'] and args['save_everything']:
                model_dir = os.path.join(args['MODEL_DIR'], '_'.join([args['model'], time.strftime('%b_%d_%H_%M_%S', time.localtime())]))
                os.makedirs(model_dir)
            elif args['test_model']:
                model_dir = os.path.dirname(os.path.abspath(args['test_model']))

            if not test_only:
                epoch_start = time.time()
                losses = train_test.train(args, model, optimizer, scheduler, epoch, args['gpu'], train_loader)
                loss = np.mean(losses)
                print("epoch finish in %.2fs, loss: %.4f" % (time.time() - epoch_start, loss))
            else:
                loss = np.nan

            if epoch == args['n_epochs'] - 1:
                print("last epoch: testing on dev and test sets")
                test_only = True

            # test on dev
            evaluation_start = time.time()
            metrics = train_test.test(args, model, args['data_path'], "dev", args['gpu'], dicts, dev_loader)
            print("evaluation finish in %.2fs" % (time.time() - evaluation_start))
            # The test split is scored only on the final pass (last epoch, early stop or test-only run).
            if test_only:
                metrics_te = train_test.test(args, model, args['data_path'], "test", args['gpu'], dicts, test_loader)
            else:
                metrics_te = defaultdict(float)

            if args['use_schedular'] and scheduler is not None:
                scheduler.step()

            metrics_tr = {'loss': loss}
            for history, current in zip((metrics_hist, metrics_hist_te, metrics_hist_tr),
                                        (metrics, metrics_te, metrics_tr)):
                for name, value in current.items():
                    history[name].append(value)
            metrics_hist_all = (metrics_hist, metrics_hist_te, metrics_hist_tr)

            if args['save_everything']:
                self.save_everything(args, metrics_hist_all, model, model_dir, None, args['criterion'], test_only)

            sys.stdout.flush()

            if test_only:
                break

            if args['criterion'] in metrics_hist and self.early_stop(metrics_hist, args['criterion'], args['patience']):
                # stop training; the next iteration evaluates on dev and test, then exits
                print("%s hasn't improved in %d epochs, early stopping..." % (args['criterion'], args['patience']))
                test_only = True
                if model_dir is not None:
                    # Reload the best checkpoint written by save_everything for the final evaluation.
                    args['test_model'] = '%s/model_best_%s.pth' % (model_dir, args['criterion'])
                    model = modelhub.pick_model(args, dicts)
                # else: nothing was checkpointed, so the current weights are evaluated.

    # ------------------------------------------------------------------ lookups

    @staticmethod
    def _split_path(train_path, split):
        """Return the CSV path of ``split`` ('train'/'dev'/'test') derived from the train-file path.

        Only the file name is rewritten, so directories whose names happen to contain
        'train' (e.g. '.../constraints/') are left untouched.
        """
        directory, filename = os.path.split(train_path)
        return os.path.join(directory, filename.replace('train', split))

    def load_lookups(self, args):
        """Build the word and label-code lookup tables.

        Returns a dict with 'ind2w'/'w2ind' (vocabulary, ids start at 1) and
        'ind2c'/'c2ind' (label codes, ids start at 0). With ``args['Y'] == 'full'``
        the codes are collected from the train/dev/test splits; otherwise they are
        the first column of ``TOP_<Y>_CODES.csv`` in ``args['out_path']``.
        """
        csv.field_size_limit(sys.maxsize)
        ind2w, w2ind = self.load_vocab_dict(args, args['vocab'])

        if args['Y'] == 'full':
            ind2c = self.load_full_codes(args['data_path'], version=args['version'])
        else:
            with open("%sTOP_%s_CODES.csv" % (args['out_path'], str(args['Y'])), 'r') as labelfile:
                codes = {row[0] for row in csv.reader(labelfile)}
            ind2c = {i: c for i, c in enumerate(sorted(codes))}

        c2ind = {c: i for i, c in ind2c.items()}
        return {'ind2w': ind2w, 'w2ind': w2ind, 'ind2c': ind2c, 'c2ind': c2ind}

    def load_vocab_dict(self, args, vocab_file):
        """Read one token per line from ``vocab_file`` and return ``(ind2w, w2ind)``.

        Tokens are whitespace-stripped, blank lines are skipped and duplicates collapse.
        Ids follow sorted order starting at 1, leaving 0 free for padding
        (``pad_sequence`` fills with zeros).
        """
        with open(vocab_file, 'r') as vocabfile:
            vocab = {line.strip() for line in vocabfile if line.strip()}

        ind2w = {i: w for i, w in enumerate(sorted(vocab), start=1)}
        w2ind = {w: i for i, w in ind2w.items()}
        return ind2w, w2ind

    def load_full_codes(self, train_path, version='mimic3'):
        """Collect every non-empty code of the LABELS column (index 3) across train/dev/test.

        Returns a ``defaultdict(str)`` mapping 0-based index -> code in sorted order.
        ``version`` is accepted for interface compatibility but is not used.
        """
        csv.field_size_limit(sys.maxsize)
        codes = set()
        for split in ['train', 'dev', 'test']:
            with open(self._split_path(train_path, split), 'r') as f:
                reader = csv.reader(f)
                next(reader)  # header
                for row in reader:
                    codes.update(row[3].split(';'))
        codes.discard('')
        return defaultdict(str, {i: c for i, c in enumerate(sorted(codes))})

    # ------------------------------------------------------- instance building

    @staticmethod
    def _multi_hot(label_field, c2ind, num_labels):
        """Return a multi-hot vector of the known codes in a ';'-separated label field, or None if none is known."""
        labels_idx = np.zeros(num_labels)
        labelled = False
        for code in label_field.split(';'):
            if code in c2ind:
                labels_idx[int(c2ind[code])] = 1
                labelled = True
        return labels_idx if labelled else None

    @staticmethod
    def _entity_lookup(ent2vec):
        """Map each KG entity key to a 1-based id; unknown entities get ``len(ent2vec) + 1``.

        A dict gives the same ids as ``list(keys).index(entity) + 1`` without the
        O(n) scan per entity.
        """
        ent2idx = {ent: i + 1 for i, ent in enumerate(ent2vec.keys())}
        return ent2idx, len(ent2idx) + 1

    @staticmethod
    def _entity_ids(entity_field, ent2idx, unk_id, max_ent_len):
        """Return the de-duplicated ids of the first ``max_ent_len`` whitespace-separated
        entities in ``entity_field`` (unknown -> ``unk_id``), or None when there are none.

        The list order is the set's iteration order, as in the original implementation.
        """
        entities = entity_field.split()[:max_ent_len]
        if not entities:
            return None
        return list({ent2idx.get(entity, unk_id) for entity in entities})

    @staticmethod
    def _tfidf_weights(filename):
        """Fit a 300-feature TfidfVectorizer on the TEXT column of ``filename``; return {term: weight + 1}.

        NOTE(review): the weights come from the first document's row only, and each
        split file is fitted separately, so every note shares document 0's weights.
        Corpus-level weights (e.g. ``vectorizer.idf_``) were probably intended; confirm
        before relying on this feature. The behaviour is kept unchanged here.
        """
        texts = [text.replace('[CLS]', '').replace('[SEP]', '') for text in pd.read_csv(filename)['TEXT']]
        vectorizer = TfidfVectorizer(max_features=300)
        matrix = vectorizer.fit_transform(texts)
        # Slice the sparse row instead of densifying every document.
        return dict(zip(vectorizer.get_feature_names_out(), matrix[0].toarray().ravel() + 1))

    def prepare_instance(self, dicts, filename, args, max_length):
        """Read a split CSV and turn each usable row into an instance for the non-BERT models.

        CSV columns: SUBJECT_ID, HADM_ID, TEXT, LABELS, ENTITY_ID, length (header skipped).
        Rows with no known label code, or (with ``args['use_embd_layer']``) no entities,
        are dropped. Each instance is a dict with keys:
          'label'       multi-hot numpy vector over dicts['c2ind']
          'tokens'      word tokens (at most ``max_length``), or the sentiment tokenizer's
                        encoding of the text when ``args['use_sentiment']`` is set
          'tokens_id'   vocabulary ids; unknown words map to ``len(w2ind) + 1``
          'entity_embd' list of KG entity ids, or None
          'tf_idf'      per-word weights aligned with the word tokens, or None
        At most ``int(args['instance_count'])`` instances are returned unless it is 'full'.
        """
        print("reading from file=",filename)
        csv.field_size_limit(sys.maxsize)
        w2ind, c2ind = dicts['w2ind'], dicts['c2ind']
        num_labels = len(dicts['ind2c'])
        unk_word_id = len(w2ind) + 1
        limit = None if args['instance_count'] == 'full' else int(args['instance_count'])
        instances = []

        if args['use_embd_layer']:
            ent2idx, unk_ent_id = self._entity_lookup(KGEntityToVec().getEntityToVec())

        if args['use_sentiment']:
            tokenizer = tr.AutoTokenizer.from_pretrained(str(args['sentiment_bert']))

        if args['use_tfIdf']:
            sequences_data = self._tfidf_weights(filename)

        with open(filename, 'r') as infile:
            r = csv.reader(infile)
            next(r)  # header
            for row in tqdm(r):
                # 0-based columns: 2 = TEXT, 3 = LABELS, 4 = ENTITY_ID
                labels_idx = self._multi_hot(row[3], c2ind, num_labels)
                if labels_idx is None:
                    continue

                if args['use_embd_layer']:
                    entity_embeddings = self._entity_ids(row[4], ent2idx, unk_ent_id, args['MAX_ENT_LENGTH'])
                    if entity_embeddings is None:
                        continue
                else:
                    entity_embeddings = None

                text = row[2]
                tokens = [t for t in text.split() if t not in ('[CLS]', '[SEP]')][:max_length]
                tokens_id = [w2ind.get(t, unk_word_id) for t in tokens]

                # Must use the word tokens, i.e. run before `tokens` is possibly replaced
                # by the sentiment tokenizer's encoding below.
                if args['use_tfIdf']:
                    tf_idf = [sequences_data.get(token, 1.0) for token in tokens]
                else:
                    tf_idf = None

                if args['use_sentiment']:
                    # BERT models can use at most 512 tokens
                    tokens = tokenizer(text.replace('[CLS]', '').replace('[SEP]', ''),
                                       padding='max_length',
                                       truncation=True,
                                       max_length=512,
                                       return_tensors='pt')

                instances.append({'label': labels_idx,
                                  'tokens': tokens,
                                  "entity_embd": entity_embeddings,
                                  "tokens_id": tokens_id,
                                  "tf_idf": tf_idf})

                if limit is not None and len(instances) == limit:
                    break

        return instances

    def prepare_instance_bert(self, dicts, filename, args, max_length):
        """Read a split CSV and turn each usable row into an instance for the BERT models.

        Same CSV layout and row filtering as ``prepare_instance``. The text is WordPiece
        tokenized with ``args['pretrained_bert']`` (lower-cased), any [CLS]/[SEP] are
        dropped, the sequence is cut to ``max_length - 2`` and re-wrapped as
        [CLS] ... [SEP]. Instance keys: 'label', 'tokens', 'tokens_id', 'segments'
        (all 0), 'masks' (all 1) and 'entity_embd' (KG entity ids, or None).
        """
        csv.field_size_limit(sys.maxsize)
        c2ind = dicts['c2ind']
        num_labels = len(dicts['ind2c'])
        limit = None if args['instance_count'] == 'full' else int(args['instance_count'])
        instances = []

        wp_tokenizer = tr.BertTokenizer.from_pretrained(args['pretrained_bert'], do_lower_case=True)

        # The KG lookup is only needed (and only loaded) when entity ids are used.
        if args['use_embd_layer']:
            ent2idx, unk_ent_id = self._entity_lookup(KGEntityToVec().getEntityToVec())

        with open(filename, 'r') as infile:
            r = csv.reader(infile)
            next(r)  # header
            for row in tqdm(r):
                # 0-based columns: 2 = TEXT, 3 = LABELS, 4 = ENTITY_ID
                labels_idx = self._multi_hot(row[3], c2ind, num_labels)
                if labels_idx is None:
                    continue

                if args['use_embd_layer']:
                    entity_embeddings = self._entity_ids(row[4], ent2idx, unk_ent_id, args['MAX_ENT_LENGTH'])
                    if entity_embeddings is None:
                        continue
                else:
                    entity_embeddings = None

                wordpieces = [t for t in wp_tokenizer.tokenize(row[2]) if t not in ("[CLS]", "[SEP]")]
                tokens = ['[CLS]'] + wordpieces[0:max_length - 2] + ['[SEP]']
                tokens_id = wp_tokenizer.convert_tokens_to_ids(tokens)

                instances.append({'label': labels_idx, 'tokens': tokens, "entity_embd": entity_embeddings,
                                  "tokens_id": tokens_id, "segments": [0] * len(tokens), "masks": [1] * len(tokens)})

                if limit is not None and len(instances) == limit:
                    break

        return instances

    # ---------------------------------------------------------------- batching

    def my_collate(self, x):
        """Collate a batch of ``prepare_instance`` dicts for the non-BERT models.

        Returns ``(inputs_id, labels, text_inputs, embeddings, tfIdf_inputs)``:
          inputs_id     int array (batch, max(MAX_LENGTH, longest)) of zero-padded token ids
          labels        list of the per-instance label vectors
          text_inputs   [[input_ids, attention_mask], ...] from the sentiment tokenizer, or []
          embeddings    int array (batch, max(MAX_ENT_LENGTH, longest)) of entity ids, or None
          tfIdf_inputs  float array padded like inputs_id, or None
        """
        args = self.args
        words = [x_['tokens_id'] for x_ in x]
        max_seq_len = max(args['MAX_LENGTH'], max(len(w) for w in words))
        inputs_id = self.pad_sequence(words, max_seq_len)

        labels = [x_['label'] for x_ in x]

        if args['use_sentiment']:
            text_inputs = [[x_['tokens']['input_ids'], x_['tokens']['attention_mask']] for x_ in x]
        else:
            text_inputs = []

        embeddings = None
        if args['use_embd_layer']:
            entity_ids = [x_['entity_embd'] for x_ in x]
            max_embd_len = max(args['MAX_ENT_LENGTH'], max(len(e) for e in entity_ids))
            embeddings = self.pad_sequence(entity_ids, max_embd_len)

        tfIdf_inputs = None
        if args['use_tfIdf']:
            tfIdf_inputs = self.pad_sequence([x_['tf_idf'] for x_ in x], max_seq_len, float)

        return inputs_id, labels, text_inputs, embeddings, tfIdf_inputs

    def my_collate_bert(self, x):
        """Collate a batch of ``prepare_instance_bert`` dicts.

        Token ids, segments and masks are zero-padded to the longest sequence in the
        batch; entity ids (with ``use_embd_layer``) to max(MAX_ENT_LENGTH, longest),
        otherwise ``embeddings`` is the list of per-instance values (None).
        Returns ``(inputs_id, segments, masks, embeddings, labels)``.
        """
        args = self.args
        words = [x_['tokens_id'] for x_ in x]
        segments = [x_['segments'] for x_ in x]
        masks = [x_['masks'] for x_ in x]
        embeddings = [x_['entity_embd'] for x_ in x]

        max_seq_len = max(len(w) for w in words)

        try:
            inputs_id = self.pad_sequence(words, max_seq_len)
            segments = self.pad_sequence(segments, max_seq_len)
            masks = self.pad_sequence(masks, max_seq_len)
            if args['use_embd_layer']:
                max_embd_len = max(args['MAX_ENT_LENGTH'], max(len(e) for e in embeddings))
                embeddings = self.pad_sequence(embeddings, max_embd_len)
        except Exception:
            # surface which batch-level error occurred, then propagate it
            print("Unexpected error:", sys.exc_info()[0])
            raise

        labels = [x_['label'] for x_ in x]

        return inputs_id, segments, masks, embeddings, labels

    def pad_sequence(self, x, max_len, type=int):
        """Right-pad with zeros, or truncate, every row of ``x`` to ``max_len``.

        Args:
            x: sequence of 1-D sequences.
            max_len: target row length.
            type: dtype of the result. Builtin ``int`` / ``float`` replace the
                ``np.int`` / ``np.float`` aliases removed in NumPy 1.24.
        Returns:
            numpy array of shape ``(len(x), max_len)``.
        """
        padded_x = np.zeros((len(x), max_len), dtype=type)
        for i, row in enumerate(x):
            kept = row[:max_len]  # truncate
            padded_x[i, :len(kept)] = kept
        return padded_x

    # ------------------------------------------------------------ persistence

    def save_metrics(self, metrics_hist_all, model_dir):
        """Write dev metrics plus test ('_te' suffix) and train ('_tr' suffix) metrics to ``model_dir/metrics.json``."""
        with open(model_dir + "/metrics.json", 'w') as metrics_file:
            data = metrics_hist_all[0].copy()
            data.update({"%s_te" % (name): val for (name, val) in metrics_hist_all[1].items()})
            data.update({"%s_tr" % (name): val for (name, val) in metrics_hist_all[2].items()})
            json.dump(data, metrics_file, indent=1)

    def save_everything(self, args, metrics_hist_all, model, model_dir, params, criterion, evaluate=False):
        """Persist args and metric histories; checkpoint the model when its latest dev score is the best.

        ``params`` is unused (kept for interface compatibility). With ``evaluate=True``
        only args and metrics are written. For criterion 'loss_dev' lower is better,
        otherwise higher. The model is moved to CPU to be saved and back to
        ``args['gpu']`` afterwards when that is >= 0.
        """
        self.save_args(args, model_dir)
        self.save_metrics(metrics_hist_all, model_dir)

        if not evaluate:
            history = metrics_hist_all[0][criterion]
            if not np.all(np.isnan(history)):
                best = np.nanargmin(history) if criterion == 'loss_dev' else np.nanargmax(history)
                if best == len(history) - 1:
                    print("saving model==")
                    sd = model.cpu().state_dict()
                    torch.save(sd, model_dir + "/model_best_%s.pth" % criterion)
                    if args['gpu'] >= 0:
                        model.cuda(args['gpu'])
        print("saved metrics, params, model to directory %s\n" % (model_dir))

    def save_args(self, args, model_path):
        """Write ``args`` to ``model_path/args.json`` unless that file already exists."""
        file_path = model_path + "/args.json"
        if not os.path.exists(file_path):
            with open(file_path, 'w') as args_file:
                json.dump(args, args_file)

    def early_stop(self, metrics_hist, criterion, patience):
        """Return True when the best ``criterion`` value is older than the last ``patience`` epochs.

        'loss_dev' is minimised, every other criterion maximised. Returns False when
        the history is all-NaN or shorter than ``patience``.
        """
        history = metrics_hist[criterion]
        if np.all(np.isnan(history)) or len(history) < patience:
            return False
        best = np.nanargmin(history) if criterion == 'loss_dev' else np.nanargmax(history)
        return bool(best < len(history) - patience)

In [None]:
# Set the path entries in args{} above, then run this cell to process the data and save the results.
DataProcessing(args)

In [None]:
# Set the model/training entries in args{} above, then run this cell to train the model,
# evaluate it on the dev set after every epoch, and evaluate it on the test set at the end.
Run(args)

loading lookups...
loading pretrained embeddings from /content/drive/MyDrive/Thesis/DeepDifferentialDiagnosis/data/25.10.2021-Old-Compare/test/processed_full.embed
adding unk embedding
loading pretrained embeddings from /content/drive/MyDrive/Thesis/DeepDifferentialDiagnosis/data/25.10.2021-Old-Compare/test/processed_full.embed
adding unk embedding
KG_MultiResCNN(
  (word_rep): WordRep(
    (embed): Embedding(48497, 100, padding_idx=0)
    (embed_drop): Dropout(p=0.2, inplace=False)
  )
  (kg_embd): EntityEmbedding(
    (embed): Embedding(17815, 200)
    (embed_drop): Dropout(p=0.2, inplace=False)
    (dim_red): Linear(in_features=200, out_features=100, bias=True)
    (feature_red): Linear(in_features=30, out_features=3000, bias=True)
  )
  (dropout): Dropout(p=0.2, inplace=False)
  (conv): MRCNNLayer(
    (conv): ModuleList(
      (channel-3): ModuleList(
        (baseconv): Conv1d(100, 100, kernel_size=(3,), stride=(1,), padding=(1,))
        (resconv-0): ResidualBlock(
          (le

28789it [00:23, 1202.90it/s]


train_instances 28787
reading from file= /content/drive/MyDrive/Thesis/DeepDifferentialDiagnosis/data/25.10.2021-Old-Compare/test/dev_full.csv


4082it [00:03, 1300.26it/s]