In [7]:
#!pip install psycopg2
#!pip install pysolr

In [101]:
import psycopg2
import pysolr
from datetime import datetime

# PART 1: Data Extraction and Indexing
To create a new core in Solr, run `./bin/solr create -c test -d path_to/conf`, replacing `path_to/conf` with the path to your conf folder. For example, mine is: `./bin/solr create -c test -d ~/Documents/uml/digital_health/conf`.
The files in conf are based on the conf folder provided on the course GitHub page; we changed schema.xml to match our indexing schema.

In [None]:
# Open the PostgreSQL connection and cursor used by the extraction cells.
conn = psycopg2.connect(host = "172.16.34.1", port = "5432", user = "mimic_demo", password = "mimic_demo", database = "mimic")
cur = conn.cursor()
# The schema is spelled "mimiciii" (the same name the extraction query uses to
# qualify its tables). A misspelled name here is not rejected by the server, so
# the mistake would go unnoticed until an unqualified table name failed.
cur.execute('SET search_path to mimiciii')
#before connecting to the solr, please create a new core with configuration files in conf folder
solr = pysolr.Solr('http://localhost:8983/solr/test')

In [27]:
# Query for task 1a. For every note whose category is 'Discharge summary' it
# selects the note's row_id, chartdate and text, the hospital_expire_flag of
# the admission with the same hadm_id, and an array of all icd9_code values
# recorded for that hadm_id (aggregated in the subquery).
# Both joins are LEFT JOINs, so a note without a matching admission or
# diagnosis row is still returned, with NULL in the joined columns.
query = """
SELECT DISTINCT ne.row_id, ne.chartdate, ne.text, ad.hospital_expire_flag, d.diagnosis
FROM mimiciii.noteevents ne
LEFT JOIN (
  SELECT hadm_id, array_agg(icd9_code) AS diagnosis
  FROM mimiciii.diagnoses_icd
  GROUP BY hadm_id
) d ON ne.hadm_id = d.hadm_id
LEFT JOIN mimiciii.admissions ad ON ne.hadm_id = ad.hadm_id 
WHERE ne.category = 'Discharge summary';
"""

In [28]:
cur.execute(query)

# Number of documents sent to Solr per request. Sending one request per note
# costs an HTTP round trip for every row; batching keeps the request count low.
BATCH_SIZE = 500

batch = []
try:
    # Iterate the cursor directly instead of materialising a second list of
    # all rows with fetchall().
    for note_id, chartdate, text, expire_flag, diagnoses in cur:
        batch.append({
            'id': str(note_id),
            # Format the date the way the Solr date query in Part 2 expects;
            # guard against a NULL chartdate instead of raising AttributeError.
            'chartdate': chartdate.strftime('%Y-%m-%dT%H:%M:%SZ') if chartdate is not None else None,
            'text': text,
            'expire_flag': expire_flag,
            'diagnoses': diagnoses   #all the ICD9 codes
        })
        if len(batch) >= BATCH_SIZE:
            solr.add(batch)
            batch = []
    # Flush the final, partially filled batch.
    if batch:
        solr.add(batch)
    solr.commit()
    print("Finished indexing!")
finally:
    # Release the database connection even if indexing fails part-way.
    conn.close()

'<?xml version="1.0" encoding="UTF-8"?>\n<response>\n\n<lst name="responseHeader">\n  <int name="status">0</int>\n  <int name="QTime">2859</int>\n</lst>\n</response>\n'

# Part 2: A Command line search system

In [68]:
import re
#This function is used to check the date format
def check_date_format(date_string):
    """Return True only if ``date_string`` is a real calendar date in YYYY-MM-DD form.

    Args:
        date_string: The text to validate.

    Returns:
        bool: True when the whole string has the shape ``dddd-dd-dd`` and
        names an existing date; False otherwise.
    """
    # fullmatch rejects trailing junk such as "2022-01-01abc", which a
    # start-anchored match() would let through.
    if not re.fullmatch(r'\d{4}-\d{2}-\d{2}', date_string):
        return False
    # The shape check alone accepts impossible dates such as 2022-13-45.
    try:
        datetime.strptime(date_string, '%Y-%m-%d')
    except ValueError:
        return False
    return True
def get_date_range():
    """Prompt until the user enters a valid ``START to END`` date range.

    The two dates must be separated by the literal `` to `` and each must pass
    ``check_date_format`` and parse as a ``%Y-%m-%d`` date. The start date may
    not be later than the end date.

    Returns:
        tuple[str, str]: The start and end dates as ISO strings with a
        trailing ``Z`` (e.g. ``2022-01-01T00:00:00Z``), the form used in the
        Solr ``chartdate`` range query.
    """
    while True:
        date_range = input("Enter date range (e.g. 2022-01-01 to 2022-12-31): ")
        try:
            # Unpacking raises ValueError unless there are exactly two parts.
            start_text, end_text = (part.strip() for part in date_range.split(' to '))
        except ValueError:
            print("Invalid date range format. Please enter again.")
            continue
        if not (check_date_format(start_text) and check_date_format(end_text)):
            print("Invalid date range format. Please enter again.")
            continue
        try:
            start_date = datetime.strptime(start_text, '%Y-%m-%d')
            end_date = datetime.strptime(end_text, '%Y-%m-%d')
        except ValueError:
            # Correctly shaped but not a real date (e.g. 2022-13-45): ask
            # again instead of crashing the search loop.
            print("Invalid date range format. Please enter again.")
            continue
        if start_date > end_date:
            # A reversed range can never match anything.
            print("Start date must not be after end date. Please enter again.")
            continue
        #format the dates as strings in the ISO format expected by Solr
        return start_date.isoformat() + 'Z', end_date.isoformat() + 'Z'
def get_expiration_flag():
    """Ask for the hospital expiration flag until the answer is "0" or "1".

    Returns:
        str: The accepted answer, exactly ``"0"`` or ``"1"``.
    """
    accepted = ("0", "1")
    answer = input("Enter hospital expiration flag(0 or 1): ")
    # Any other answer triggers the error message and a fresh prompt.
    while answer not in accepted:
        print("Invalid format. Please enter again(0 or 1).")
        answer = input("Enter hospital expiration flag(0 or 1): ")
    return answer
def get_icd9_codes():
    """Prompt for ICD9 codes and return the entry with every "." removed."""
    entered = input("Enter ICD9 codes: ")
    # Splitting on "." and re-joining drops all dots and nothing else.
    dotless = "".join(entered.split("."))
    return dotless
def get_text():
    """Prompt for the note-text search term and return it exactly as typed."""
    return input("Enter a word in note text: ")

In [96]:
import mysql.connector
def get_synonyms(text):
    """Look up synonyms of ``text`` in the UMLS MySQL database.

    Finds MRCONSO rows whose STR equals ``text`` (SAB 'CHV', LAT 'ENG'),
    follows MRREL rows with REL 'SY' from their CUI, and returns the distinct
    STR values of the related MRCONSO rows (also SAB 'CHV', LAT 'ENG').

    Args:
        text: The term to look up; it is matched against ``m1.STR``.

    Returns:
        list: Up to 30 synonym strings; empty when nothing matches.
    """
    # %s is a driver-side placeholder: the value is passed separately to
    # execute(), so quotes or SQL fragments in user input cannot alter the
    # statement (an apostrophe in the term used to break the query).
    query_syn = """
    select distinct m2.STR
    from MRCONSO as m1
    join MRREL on m1.CUI = MRREL.CUI1
    join MRCONSO as m2 on MRREL.CUI2 = m2.CUI
    where m1.SAB = 'CHV'
    and m1.LAT = 'ENG'
    and m2.SAB = 'CHV'
    and m2.LAT = 'ENG'
    and MRREL.REL = 'SY'
    and m1.STR = %s limit 30;"""
    cnx = mysql.connector.connect(host='172.16.34.1', port='3307',
                            user='umls', password='umls', database='umls2022')
    try:
        cur = cnx.cursor()
        cur.execute(query_syn, (text,))
        return [row[0] for row in cur.fetchall()]
    finally:
        # Close the connection even when the query raises.
        cnx.close()
#synonyms = get_synonyms('lung cancer')


In [113]:
def _solr_phrase(term):
    """Build a quoted ``text:`` phrase clause, escaping backslashes and double quotes in ``term``."""
    escaped = term.replace('\\', '\\\\').replace('"', '\\"')
    return f'text:"{escaped}"'


def query_parser():
    """Interactively assemble a Solr query string from optional conditions.

    The user is asked, in order, whether to filter by date range, hospital
    expiration flag, ICD9 code and note text (optionally expanded with
    synonyms). Every chosen condition is combined with ``AND``. When synonyms
    are requested, the text term and its synonyms form one parenthesised
    ``OR`` group, so the other conditions still apply to synonym matches.

    Returns:
        str or None: The combined query, or None when no condition was chosen.
    """
    queries = []
    da = input('Type yes if you want to enter a date range(e.g. 2022-01-01 to 2022-12-31) in your search or type anthing else to proceed ')
    if da.lower() == 'yes':
        start_date, end_date = get_date_range()
        queries.append('chartdate:[{} TO {}]'.format(start_date, end_date))
    fa = input('Type yes if you want to enter a hospital expiration flag in your search or type anthing else to proceed: ')
    if fa.lower() == 'yes':
        queries.append('expire_flag:' + get_expiration_flag())
    ca = input('Type yes if you want to enter a ICD9 code in your search or type anthing else to proceed: ')
    if ca.lower() == 'yes':
        code = get_icd9_codes().strip()
        if code:
            queries.append('diagnoses:' + code)
        else:
            # A bare "diagnoses:" with no value is not a valid clause.
            print("No ICD9 code entered. Skipping this condition.")
    ta = input('Type yes if you want to enter a word in note text in your search or type anthing else to proceed: ')
    if ta.lower() == 'yes':
        text = get_text()
        text_clauses = [_solr_phrase(text)]
        sa = input('Type yes if you want to find the synonyms of the text in your search or type anthing else to skip: ')
        if sa.lower() == 'yes':
            synonyms = get_synonyms(text)
            print(f"The sysnonyms for {text} is:", synonyms)
            print("******************************************")
            text_clauses.extend(_solr_phrase(s) for s in synonyms)
        if len(text_clauses) == 1:
            # No synonyms found or requested: plain clause, no dangling OR.
            queries.append(text_clauses[0])
        else:
            # Group the alternatives so they stay inside the AND chain.
            queries.append('(' + ' OR '.join(text_clauses) + ')')
    if not queries:
        return None
    return ' AND '.join(queries)


def search_query():
    """Build a query interactively, run it against Solr and print the results.

    Keeps calling ``query_parser`` until it returns a query, then prints the
    ids of up to 20 matching documents followed by the total match count.
    """
    query_for_search = query_parser()
    while query_for_search is None:
        print("Missing query conditions. Please enter at least one query condition.")
        print("******************************************")
        query_for_search = query_parser()
    results = solr.search(query_for_search, fl='id', rows=20)
    print(f'Top 20 IDs matching "{query_for_search}":')
    for res in results:
        print(res['id'])
    print("******************************************")
    # The response to the search above already carries the total number of
    # matches, so a second rows=0 request to obtain it is unnecessary.
    num_matches = results.hits
    print(f'Total number of matching documents: {num_matches}')



In [None]:
# Run searches until the user types "exit" (case-insensitive).
# The answer is kept in `answer` rather than `exit` so the notebook's built-in
# exit() is not replaced by a string for the rest of the session.
while True:
    search_query()
    answer = input('Type exit to finish searching or enter anything else to start a new search: ')
    print("******************************************")
    if answer.lower()=='exit':
        break

Type yes if you want to enter a date range(e.g. 2022-01-01 to 2022-12-31) in your search or type anthing else to proceed 
Type yes if you want to enter a hospital expiration flag in your search or type anthing else to proceed: 
Type yes if you want to enter a ICD9 code in your search or type anthing else to proceed: 
Type yes if you want to enter a word in note text in your search or type anthing else to proceed: yes
Enter a word in note text: brain cancer
Type yes if you want to find the synonyms of the text in your search or type anthing else to skip: 
Top 20 IDs matching "text:"brain cancer"":
18463
16058
23397
43979
12030
3253
8412
7728
12319
14352
26653
46678
41557
49766
46804
11822
22272
35094
42633
15401
******************************************
Total number of matching documents: 137
Type exit to finish searching or enter anything else to start a new search: 
******************************************
Type yes if you want to enter a date range(e.g. 2022-01-01 to 2022-12-31) in