In [15]:
import sqlite3
import pandas as pd
import re
from tqdm import tqdm
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import json
import os
import csv

## Helper Functions

## SQLite3 Setup

In [None]:
# Location of the SQLite build of MIMIC-III. Set the MIMIC3_DB_PATH environment variable to
# point elsewhere; the default is the original hard-coded location.
db_path = os.environ.get('MIMIC3_DB_PATH', 'F:/mimic-iii-clinical-database-1.4/mimic3.db')
# sqlite3.connect silently creates an empty database for a missing file, which would only
# surface later as "no such table" errors, so fail early instead.
if not os.path.isfile(db_path):
    raise FileNotFoundError(f'MIMIC-III database not found at {db_path!r}')
sqliteConnection = sqlite3.connect(db_path)
mimiciii = sqliteConnection.cursor()

In [23]:
def get_df_from_files(path, filter_for=None):
    '''Load per-admission note files named ``<HADM_ID>.txt`` into a DataFrame.

    path: directory holding the note files (a trailing separator is optional)
    filter_for: optional list of HADM_IDs (without the ``.txt`` suffix) to load;
        when empty or None, every ``.txt`` file in ``path`` is loaded
    returns: DataFrame with columns HADM_ID (file name without extension, str) and
        TEXT (file contents), one row per file that could be read
    '''
    if filter_for:
        files = [str(hadm_id) + '.txt' for hadm_id in filter_for]
    else:
        # Only note files; sorted so the row order does not depend on the OS listing order.
        files = sorted(name for name in os.listdir(path) if name.endswith('.txt'))

    hadm_ids = []
    texts = []
    for file in files:
        try:
            # os.path.join works whether or not ``path`` ends with a separator.
            with open(os.path.join(path, file)) as f:
                text = f.read()
        except (OSError, UnicodeDecodeError) as err:
            # Best effort: report and skip missing/unreadable files instead of aborting the load.
            print('Something went wrong with', file, '-', err)
            continue
        hadm_ids.append(os.path.splitext(file)[0])
        texts.append(text)

    return pd.DataFrame({'HADM_ID': hadm_ids, 'TEXT': texts})

In [24]:
# A handful of admissions, loaded from their per-admission note files.
filter_for = '100001 100003 100006 100007 100009 100010'.split()

get_df_from_files('all_notes/', filter_for=filter_for)

Unnamed: 0,HADM_ID,TEXT
0,100001,----CATEGORY:Discharge summary----\n\n--NEW:Re...
1,100003,----CATEGORY:Discharge summary----\n\n--NEW:Re...
2,100006,----CATEGORY:Discharge summary----\n\n--NEW:Re...
3,100007,----CATEGORY:Discharge summary----\n\n--NEW:Re...
4,100009,----CATEGORY:Discharge summary----\n\n--NEW:Re...
5,100010,----CATEGORY:Discharge summary----\n\n--NEW:Re...


In [None]:
def condense_notes(admission_df, noteevents_df):
    '''Concatenate all notes of each admission into a single chart string.

    admission_df: DataFrame with a HADM_ID column; one output row per unique HADM_ID
    noteevents_df: DataFrame with HADM_ID, CATEGORY, DESCRIPTION and TEXT columns
    returns: DataFrame with columns HADM_ID (sorted ascending) and TEXT. Inside a chart the
        notes are grouped under '----CATEGORY: <category>----' headers (categories in order
        of first appearance), each note is preceded by '--NEW: <description>--', and each
        category section ends with two extra newlines. Admissions without notes get ''.
    '''
    admission_ids = sorted(admission_df.HADM_ID.unique())

    # Group the notes once by admission (row order is kept inside each group) instead of
    # scanning the whole notes table again for every admission.
    notes_by_admission = {
        adm_id: adm_notes
        for adm_id, adm_notes in noteevents_df.groupby('HADM_ID', sort=False)
    }

    # NOTE(review): the ARF cell further down renders the same layout without the space
    # after 'CATEGORY:' / 'NEW:'; confirm which form downstream consumers expect.
    notes_list = []
    for adm_id in tqdm(admission_ids):
        curr_adm = notes_by_admission.get(adm_id)
        chart_parts = []
        if curr_adm is not None:
            for category in curr_adm.CATEGORY.unique():
                chart_parts.append('----CATEGORY: ' + category + '----\n\n')
                curr_category_notes = curr_adm[curr_adm.CATEGORY == category]
                for description, note in zip(curr_category_notes.DESCRIPTION,
                                             curr_category_notes.TEXT):
                    chart_parts.append('--NEW: ' + description + '--\n')
                    chart_parts.append(note + '\n')
                chart_parts.append('\n\n')
        notes_list.append(''.join(chart_parts))

    return pd.DataFrame({'HADM_ID': admission_ids, 'TEXT': notes_list})

In [None]:
def fix_dot_zero(value):
    '''Normalise an ID such as 100001.0 or '100001.0' to the string '100001'.

    Missing values (None, NaN, pd.NA) become the sentinel string '0'.
    '''
    if pd.isna(value):
        return '0'
    return str(int(float(value)))

In [None]:
def get_col_names(cursor, table_name):
    '''Return the column names of ``table_name`` in declaration order.

    cursor: sqlite3 cursor used for the lookup
    table_name: table (or view) to inspect; matched case-insensitively by SQLite
    raises ValueError: if the table does not exist
    '''
    # pragma_table_info takes the table name as a bound parameter (never spliced into the
    # SQL text) and reports the actual column names, rather than every double-quoted
    # identifier in the CREATE statement (which can include the quoted table name itself).
    cursor.execute('SELECT name FROM pragma_table_info(?) ORDER BY cid;', (table_name,))
    cols = [row[0] for row in cursor.fetchall()]
    if not cols:
        raise ValueError(f'table {table_name!r} not found or has no columns')
    return cols

In [None]:
def get_df_from_table_from_db(cursor, table_name, num_rows='*', skip_cols=None):
    '''
    Load a table (or its first rows) into a DataFrame with values converted to str.

    cursor: sqliteConnection cursor object
    table_name: name of table to get from cursor db
    num_rows: number of rows to retrieve (or '*' for all rows)
    skip_cols: list of columns to skip in the retrieval (default: skip none)
    returns: DataFrame with the retained columns in table order (dtype=str)
    raises ValueError: if skip_cols removes every column
    '''
    def quote_identifier(name):
        # SQLite identifier quoting: wrap in double quotes and double any embedded quote,
        # so names that are SQL keywords (e.g. a column called "index") still parse.
        return '"' + str(name).replace('"', '""') + '"'

    skip = set(skip_cols or ())
    use_cols = [col for col in get_col_names(cursor, table_name) if col not in skip]
    if not use_cols:
        raise ValueError(f'no columns left to select from table {table_name!r}')

    select_list = ', '.join(quote_identifier(col) for col in use_cols)
    query = f'select {select_list} from {quote_identifier(table_name)}'
    params = ()
    if num_rows != '*':
        # Bind the row limit as a parameter instead of formatting it into the SQL text.
        query += ' limit ?'
        params = (int(num_rows),)

    cursor.execute(query + ';', params)
    rows = cursor.fetchall()

    return pd.DataFrame(rows, columns=use_cols, dtype=str)

In [None]:
def get_tables_list_from_db(cursor):
    '''Return the names of all tables listed in sqlite_master, in the order SQLite reports them.'''
    tables_query = "\n    select name from sqlite_master where type='table';\n    "
    cursor.execute(tables_query)
    rows = cursor.fetchall()
    return [name for (name,) in rows]

In [None]:
def to_int(x):
    '''Convert ``x`` to int; missing values (None, NaN, pd.NA) and falsy values give 0.'''
    # pd.isna instead of np.isnan: numpy is not imported in this notebook's import cell, and
    # np.isnan raises TypeError on strings, while pd.isna handles None/NaN/pd.NA/str alike.
    # It is checked first because `not pd.NA` itself raises TypeError.
    if pd.isna(x) or not x:
        return 0
    return int(x)

Steps for building the dataset:

1. Filter the ADMISSIONS table for ARF-related diagnoses.
2. Filter the DRGCODES table for ARF-related DRGs.
3. Combine (union) the HADM_IDs found in the previous two tables.
4. Remove rows that look similar to ARF but are not (e.g. chronic respiratory failure).
5. Join the other tables (NOTEEVENTS, PRESCRIPTIONS, LABEVENTS) on HADM_ID.

#### 1. filter ADMISSIONS table for ARF-related diagnoses

In [None]:
# All admissions: every column as upper-cased text, except SUBJECT_ID which is int64.
admission = (
    get_df_from_table_from_db(mimiciii, 'admissions')
    .astype(str)
    .apply(lambda column: column.str.upper())
)
admission['SUBJECT_ID'] = admission['SUBJECT_ID'].astype('int64')

In [None]:
# Admissions whose free-text DIAGNOSIS mentions respiratory failure, either spelled out or
# abbreviated as 'RESP. FAILURE' / 'RESP FAILURE'. The period is escaped so it only matches
# a literal '.', and na=False treats a missing diagnosis as "no match".
arf_pattern = r'RESPIRATORY FAILURE|RESP\.? FAILURE'
arf_adm = admission[admission.DIAGNOSIS.str.contains(arf_pattern, regex=True, na=False)]
to_exclude = ['CHRONIC RESPIRATORY FAILURE;AIRWAY OBSTRUCTION', 'CHRONIC RESPIRATORY FAILURE', 'CHRONIC RESPIRATORY FAILURE; TRAC OBSTRUCTED AIRWAY']
arf_adm = arf_adm[~arf_adm.DIAGNOSIS.isin(to_exclude)]
# Admission columns are strings here (cast with astype(str) when loaded), so the sentinel
# must be compared as '0'; an integer 0 would never match and remove nothing.
arf_adm = arf_adm[arf_adm.HADM_ID != '0']

arf_adm_ID = arf_adm.HADM_ID.to_list()

arf_adm.head()

#### 2. filter drg table for ARF-related DRGs

In [None]:
# All DRG codes; DESCRIPTION is cast to str so the .str accessor also works on missing values.
drgcodes = get_df_from_table_from_db(mimiciii, 'drgcodes')
drgcodes.DESCRIPTION = drgcodes.DESCRIPTION.astype(str)

# Keep DRGs whose description mentions respiratory failure. case=False mirrors the
# upper-casing applied to ADMISSIONS, so mixed-case descriptions are not missed.
# (An alternative that was considered also kept MS-DRG 193: DRG_CODE 193 with DRG_TYPE 'MS'.)
arf_drg = drgcodes[drgcodes.DESCRIPTION.str.contains('RESPIRATORY FAILURE', case=False, regex=False)]
# HADM_ID is str (get_df_from_table_from_db builds the frame with dtype=str), so compare with
# '0'; an integer 0 would never match and remove nothing.
arf_drg = arf_drg[arf_drg.HADM_ID != '0']

arf_drg_ID = arf_drg.HADM_ID.to_list()

arf_drg.head()

In [None]:
# (rows, columns) of the diagnosis-based and DRG-based ARF selections, one per line.
print(arf_adm.shape, arf_drg.shape, sep='\n')

#### 3. combine (union) the HADM_IDs of the previous two tables

In [None]:
# Union of the ARF admissions found via ADMISSIONS.DIAGNOSIS and via DRG codes. Sorted so the
# list (and arf_hadm_ids.json written from it) is identical across runs: the iteration order
# of a set of str depends on per-process hash randomization.
hadm_ids = sorted(set(arf_drg_ID) | set(arf_adm_ID))
print(len(hadm_ids))

In [None]:
# Persist the ARF admission IDs as an indented JSON list.
with open('arf_hadm_ids.json', 'w') as json_out:
    json_out.write(json.dumps(hadm_ids, indent=4))

#### 4. remove rows that look similar to ARF but are not

Already done in step 1: chronic respiratory failure diagnoses are excluded there via `to_exclude`.

#### 5. join other tables on HADM_ID: NOTEEVENTS, PRESCRIPTIONS, LABEVENTS

#### noteevents

In [None]:
# Every note, with HADM_ID normalised by fix_dot_zero (e.g. '100001.0' -> '100001').
noteevents = get_df_from_table_from_db(mimiciii, 'noteevents')
noteevents['HADM_ID'] = noteevents['HADM_ID'].map(fix_dot_zero)

##### compile all notes

In [None]:
# Build one condensed chart per admission, write each to all_notes/<HADM_ID>.txt, and also
# keep all of them in a single CSV.
condensed_notes = condense_notes(admission, noteevents)

# open(..., 'w') does not create missing directories, so make sure the folder exists.
os.makedirs('all_notes', exist_ok=True)
for hadm_id, chart in zip(condensed_notes.HADM_ID, condensed_notes.TEXT):
    with open(f'all_notes/{hadm_id}.txt', 'w') as f:
        f.write(chart)

condensed_notes.to_csv('all_notes_raw.csv', index=False)

In [None]:
# Reload the condensed charts with every column as str. keep_default_na=False keeps empty
# charts (admissions without notes) as '' instead of NaN, which astype(str) would otherwise
# turn into the literal text 'nan'.
condensed_notes = pd.read_csv('all_notes_raw.csv', dtype=str, keep_default_na=False)
condensed_notes.head()

In [None]:
def filter_notes(df, file_to_write, subset=False, start_at=0, log_file='log.txt'):
    '''Tokenize each note, drop English stop words (and '*'), and append it to a CSV.

    df: DataFrame with HADM_ID and TEXT columns; it is not modified
    file_to_write: CSV file opened in append mode; one [HADM_ID, filtered text] row per
        note, without a header
    subset: if True, process only the first 1000 notes (start_at is then ignored)
    start_at: position of the first note to process, used to resume an interrupted run
    log_file: checkpoint file; after each note it holds the position of the next note
    '''
    notes_list = df.TEXT.astype(str).to_list()
    hadm_ids = df.HADM_ID.to_list()
    if subset:
        start_at = 0
        notes_list, hadm_ids = notes_list[:1000], hadm_ids[:1000]
    else:
        start_at = int(start_at)
        # Slice both lists so notes_list[i] and hadm_ids[i] still describe the same admission.
        notes_list, hadm_ids = notes_list[start_at:], hadm_ids[start_at:]

    stop_words = set(stopwords.words('english') + ['*'])

    with open(file_to_write, 'a', newline='') as file:
        csv_writer = csv.writer(file)
        for i, (hadm_id, note) in tqdm(enumerate(zip(hadm_ids, notes_list)), total=len(notes_list)):
            tokens = word_tokenize(note)
            filtered_text = ' '.join([word for word in tokens if word.lower() not in stop_words])

            csv_writer.writerow([str(hadm_id), filtered_text])
            # Flush first so the checkpoint never points past rows that are not on disk yet.
            file.flush()
            with open(log_file, 'w') as log:
                # Position of the NEXT note, so a resumed run neither repeats nor skips one.
                log.write(str(start_at + i + 1))
    
# Resume from the checkpoint that filter_notes writes to log.txt, or start from the first note.
try:
    with open('log.txt', 'r') as file:
        # An empty log means nothing has been processed yet.
        start_at = int(file.read().strip() or 0)
except FileNotFoundError:
    start_at = 0

filter_notes(condensed_notes, file_to_write='notes_filtered_condensed.csv', start_at=start_at)

In [None]:
# HADM_IDs of all ARF admissions (diagnosis-based or DRG-based). A left merge against
# noteevents only ever returns the left frame's HADM_IDs, so take them directly instead of
# materialising a join with the full notes table.
notes_set = set(arf_adm.HADM_ID) | set(arf_drg.HADM_ID)

In [None]:
#arf notes
arf_notes = noteevents[noteevents.HADM_ID.isin(notes_set)]
print(arf_notes.shape)
arf_notes.head()

In [None]:
#combine admission notes into one

arf_admission_ids = pd.unique(arf_notes.HADM_ID)
arf_admission_ids.sort()

condensed_arf_notes = pd.DataFrame()

condensed_arf_notes['HADM_ID'] = arf_admission_ids

arf_notes_list = []

for adm_id in tqdm(arf_admission_ids):
    curr_adm = arf_notes[arf_notes.HADM_ID == adm_id]
    categories = pd.unique(curr_adm.CATEGORY)
    curr_chart = ''
    
    for category in categories:
        curr_chart += '----CATEGORY:' + category + '----\n\n'
        curr_category_notes = curr_adm[curr_adm.CATEGORY == category][['DESCRIPTION', 'TEXT']]
        curr_descriptions = curr_category_notes.DESCRIPTION.to_list()
        curr_notes = curr_category_notes.TEXT.to_list()

        for i in range(len(curr_descriptions)):
            curr_chart += '--NEW:' + curr_descriptions[i] + '--\n'
            curr_chart += curr_notes[i] + '\n'

        curr_chart += '\n\n'
    
    arf_notes_list.append(curr_chart)
    
condensed_arf_notes['TEXT'] = arf_notes_list
print(len(condensed_arf_notes))
print(condensed_arf_notes.head())

In [None]:
# Render the chart of a single admission for manual inspection, in the same layout as the
# condensed ARF charts above.
# HADM_ID values are strings (normalised with fix_dot_zero when noteevents was loaded), so the
# ID has to be a str; an integer literal would match no rows.
inspect_hadm_id = '134727'
single_adm = arf_notes[arf_notes.HADM_ID == inspect_hadm_id]

curr_chart = ''
for category in pd.unique(single_adm.CATEGORY):
    curr_chart += '----CATEGORY:' + category + '----\n\n'
    curr_category_notes = single_adm[single_adm.CATEGORY == category]
    for description, note in zip(curr_category_notes.DESCRIPTION, curr_category_notes.TEXT):
        curr_chart += '--NEW:' + description + '--\n'
        curr_chart += note + '\n'
    curr_chart += '\n\n'

if curr_chart:
    print(curr_chart)
else:
    print(f'No ARF notes found for HADM_ID {inspect_hadm_id}')

In [None]:
#pre-proc arf notes
arf_notes_list = condensed_arf_notes.TEXT.to_list()

filtered_notes = []
stop_words = set(stopwords.words('english') + ['*'])

for note in tqdm(arf_notes_list):
    tokens = word_tokenize(note)
    filtered_text = ' '.join([word for word in tokens if word.lower() not in stop_words])
    filtered_notes.append(filtered_text)

In [None]:
# Store the stop-word-filtered text next to the raw chart and save both without the index.
condensed_arf_notes['FILTERED_TEXT'] = filtered_notes
condensed_arf_notes.to_csv('arf_notes_filtered_condensed.csv', index=False)

In [None]:
print(condensed_arf_notes['TEXT'].iloc[0])  # raw text of the first condensed ARF chart

In [None]:
# Attach the condensed chart to each ARF admission; admissions without notes drop out (inner join).
master_df = arf_adm.merge(condensed_arf_notes, how='inner', on='HADM_ID')
master_df.head()

In [None]:
len(noteevents), len(noteevents.columns)  # (rows, columns) of the full NOTEEVENTS table

In [None]:
# Zero-shot, step-by-step prompt asking whether a chart shows signs of sepsis.
# `septic_notes` is not defined anywhere in this notebook; the condensed ARF charts are the
# frame here that carries a FILTERED_TEXT column.
# TODO(review): confirm the ARF charts are the intended sample for the sepsis question.
# NOTE(review): PromptTemplate, LLMChain, ChatOpenAI, HuggingFaceHub, SQLDatabase and
# SQLDatabaseChain are not imported in this notebook's import cell (presumably langchain).
sample_chart = condensed_arf_notes.iloc[0].FILTERED_TEXT

template = '''
Question: {question}

Clinical Chart: {chart}

Answer: Let's think step by step: 
'''

question = '''Given is some text found in a clinical chart. Is there sign or indication that this patient had sepsis?'''

prompt = PromptTemplate(
    input_variables=['question', 'chart'],
    template=template
)

In [None]:
# Preview only the beginning of the chart; whole charts can be very long.
CHART_PREVIEW_CHARS = 2000
sample_chart[:CHART_PREVIEW_CHARS]

In [None]:
# Crude keyword check. FILTERED_TEXT keeps each token's original case, so lower-case the chart
# to also catch 'Sepsis' / 'SEPSIS'.
'sepsis' in sample_chart.lower()

In [None]:
# Answer the sepsis prompt with gpt-3.5-turbo, temperature set close to 0.
# `septic_notes` is not defined anywhere in this notebook; the condensed ARF charts are the
# frame that carries FILTERED_TEXT. TODO(review): confirm this is the intended sample.
sample_chart = condensed_arf_notes.iloc[0].FILTERED_TEXT
llm_chain = LLMChain(prompt=prompt,
                     llm=ChatOpenAI(model_name='gpt-3.5-turbo', temperature=1e-5))

In [None]:
# Same prompt against google/flan-t5-xxl on the HuggingFace Hub; only the first 1000
# characters of the chart are sent.
# `septic_notes` is not defined anywhere in this notebook; the condensed ARF charts are the
# frame that carries FILTERED_TEXT. TODO(review): confirm this is the intended sample.
sample_chart = condensed_arf_notes.iloc[0].FILTERED_TEXT[:1000]
llm_chain = LLMChain(prompt=prompt,
                     llm=HuggingFaceHub(repo_id='google/flan-t5-xxl',
                                        model_kwargs={
                                            'temperature': 1e-5
                                        }))

In [None]:
chart_answer = llm_chain.run(question=question, chart=sample_chart)
print(chart_answer)

In [None]:
# Sanity check of the HuggingFace Hub model on a general-knowledge question
# (unrelated to the clinical data).
template = "Question: {question}\n\nAnswer: "
prompt = PromptTemplate(
    template=template,
    input_variables=['question'],
)

# user question
question = "Which NFL team won the Super Bowl in the 2010 season?"

In [None]:
# initialize Hub LLM
hub_llm = HuggingFaceHub(
        repo_id='google/flan-t5-xxl',
    model_kwargs={'temperature':1e-10}
)

# create prompt template > LLM chain
llm_chain = LLMChain(
    prompt=prompt,
    llm=hub_llm
)

# ask the user question about NFL 2010
print(llm_chain.run(question))

In [None]:
# Natural-language querying of the MIMIC-III database through an LLM-backed SQL chain.
# Reuse db_path from the SQLite setup cell so both connections point at the same file.
db = SQLDatabase.from_uri(f'sqlite:///{db_path}')
db_chain = SQLDatabaseChain(llm=HuggingFaceHub(repo_id='google/flan-t5-xxl',
                                               model_kwargs={
                                                   'temperature': .01
                                               }),
                            database=db,
                            verbose=True)

db_chain.run("How many tables are there?")