In [1]:
import mysql.connector
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import datetime
import math
import os
import sys
import logging
mf_module_path = os.path.abspath(os.path.join('../python'))
if mf_module_path not in sys.path:
    sys.path.append(mf_module_path)
import mf
import mf_random
import hpoutil
import networkx
import obonet
import pickle
from tqdm import tqdm
import time

**Connect to MySQL database**

In [2]:
mydb = mysql.connector.connect(host='localhost',
                               user='mimicuser',
                               passwd='mimic',
                               database='mimiciiiv13',
                              auth_plugin='mysql_native_password')

First approach to query mysql from python

Check that MySQL connection works properly

In [3]:
df = pd.read_sql_query("SELECT * FROM LABEVENTS LIMIT 5;", mydb)
df

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,ITEMID,CHARTTIME,VALUE,VALUENUM,VALUEUOM,FLAG
0,1,2,163353,51143,2138-07-17 20:48:00,0,0.0,%,
1,2,2,163353,51144,2138-07-17 20:48:00,0,0.0,%,
2,3,2,163353,51146,2138-07-17 20:48:00,0,0.0,%,
3,4,2,163353,51200,2138-07-17 20:48:00,0,0.0,%,
4,5,2,163353,51221,2138-07-17 20:48:00,0,0.0,%,abnormal


Get a cursor so that it can be used later

In [4]:
cursor = mydb.cursor(buffered=True)

We explored several method to compute the synergy score for different diseases. Method 1-3 all worked but the time and space requirements are too high. See the archived file. Here, we use method 4 to compute phenotype pairwise synergies.


# Synergy between Lab-derived Abnormalities

## Algorithm

This method relies on the power of MySQL for doing queies and joins, return a batch of phenotype profiles a time, and then use the power of Numpy to do numeric computation.

Specificially, the method runs the following algorithm:

    1. For one diagnosis code, specify the phenotypes to analyze--a list of HPO terms.
    2. For a batch of patient*encounters, return a list of diagnosis codes (1 or 0)
    3. For the same batch of patient*encounters, return a list of phenotypes.
    4. Create a numpy array with dimension (N x P)
    5. Perform numeric computation with Numpy:
        outer product for ++ of PxP.T
        outer product for +- of Px(1-P).T
        outer product for -+ of (1-P)xP
        outer product for -- of (1-P)x(1-P).T
        combine the above with - and + of diagnosis value
        stack them together as a (N x P x P x 8) matrix.
        Step 1 - 5 are performed at each site. The resulting matrix is returned to JAX for final analyze.
    6. Compute pairwise synergy:
        use the multi-dimension array to calculate p(D = 1), p(D = 0), p(P1 * P2)
        compute mutual information of each phenotype in regarding to one diagnosis I(P:D)
        compute mutual information of two phenotypes in regarding to one diagnosis I(P:D)
        compute pairwise synergy
        

In [None]:
#TODO: rewrite to be backward compatible
def diagnosis_set():
    """Aggregate ICD9 codes with the first three digit and count how many times they appear. 
    Note this function uses encounters as the unit, meaning a code will counted twice if same patient was 
    diagnosed again at a later encounter."""
    diagnosis_count = pd.read_sql_query("SELECT SUBJECT_ID, HADM_ID, \
        CASE \
        WHEN(ICD9_CODE LIKE 'V%') THEN SUBSTRING(ICD9_CODE, 1, 3) \
        WHEN(ICD9_CODE LIKE 'E%') THEN SUBSTRING(ICD9_CODE, 1, 4) \
        ELSE SUBSTRING(ICD9_CODE, 1, 3) END AS ICD9 \
        FROM DIAGNOSES_ICD", mydb)
    diagnosisSet = diagnosis_count.drop_duplicates().groupby('ICD9').size().sort_values(ascending=False)
    return diagnosisSet

#TODO: rewrite to be backward compatible
def createAbnormalPhenotypeTable(threshold, include_inferred=True, force_update=True):
    """
    This is the abnormal phenotypes. 
    @include_inferred whether to include inferred HPO. Default true.
    @force_update whether current table, if present, should be forced to update
    """
    if force_update:
        cursor.execute('''DROP TEMPORARY TABLE IF EXISTS p''')
    if include_inferred:
        cursor.execute('''
                    CREATE TEMPORARY TABLE IF NOT EXISTS p
                    WITH abnorm AS (
                        SELECT
                            LABEVENTS.SUBJECT_ID, LABEVENTS.HADM_ID, LabHpo.MAP_TO
                        FROM 
                            LABEVENTS 
                        JOIN LabHpo on LABEVENTS.ROW_ID = LabHpo.ROW_ID
                        WHERE LabHpo.NEGATED = 'F'
                        
                        UNION ALL
                        
                        SELECT 
                            LABEVENTS.SUBJECT_ID, LABEVENTS.HADM_ID, INFERRED_LABHPO.INFERRED_TO AS MAP_TO 
                        FROM 
                            INFERRED_LABHPO 
                        JOIN 
                            LABEVENTS ON INFERRED_LABHPO.LABEVENT_ROW_ID = LABEVENTS.ROW_ID
                        )
                    SELECT SUBJECT_ID, HADM_ID, MAP_TO
                    FROM abnorm 
                    GROUP BY SUBJECT_ID, HADM_ID, MAP_TO
                    HAVING COUNT(*) > {}
                    -- parameter to control how to define an abnormal phenotype is present.
                '''.format(threshold))
    else:       
        cursor.execute('''
                    CREATE TEMPORARY TABLE IF NOT EXISTS p
                    WITH abnorm AS (
                        SELECT
                            LABEVENTS.SUBJECT_ID, LABEVENTS.HADM_ID, LabHpo.MAP_TO
                        FROM 
                            LABEVENTS 
                        JOIN LabHpo on LABEVENTS.ROW_ID = LabHpo.ROW_ID
                        WHERE LabHpo.NEGATED = 'F')
                    SELECT SUBJECT_ID, HADM_ID, MAP_TO
                    FROM abnorm 
                    GROUP BY SUBJECT_ID, HADM_ID, MAP_TO
                    HAVING COUNT(*) > {}
                    -- parameter to control how to define an abnormal phenotype is present.
                '''.format(threshold))
    cursor.execute('CREATE INDEX p_idx01 ON p (SUBJECT_ID, HADM_ID)')
    cursor.execute('CREATE INDEX p_idx02 ON p (MAP_TO);')


#TODO: rewrite to be backward compatible
def encountersWithDiagnosis(diagnosis):
    cursor.execute('''DROP TEMPORARY TABLE IF EXISTS d''')
    cursor.execute('''
        CREATE TEMPORARY TABLE IF NOT EXISTS d
        SELECT 
            DISTINCT SUBJECT_ID, HADM_ID, 1 AS DIAGNOSIS
        FROM 
            DIAGNOSES_ICD 
        WHERE ICD9_CODE LIKE '{}%'
        -- This is encounters with positive diagnosis
    '''.format(diagnosis))
    cursor.execute('CREATE INDEX d_idx01 ON d(SUBJECT_ID, HADM_ID)')

    
def createPhenotypeSet(diagnosis, threshold=1000):
    """
    Create the phenotypes that we should analyze. Exemely less frequently observed phenotypes are excluded.
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS ps')
    cursor.execute('''
            CREATE TEMPORARY TABLE ps
            WITH pd AS(
                SELECT p.*
                FROM 
                    p JOIN (SELECT 
                                DISTINCT SUBJECT_ID, HADM_ID, 1 AS DIAGNOSIS
                            FROM 
                                DIAGNOSES_ICD 
                            WHERE ICD9_CODE LIKE '{}%') AS d
                    ON p.SUBJECT_ID = d.SUBJECT_ID AND p.HADM_ID = d.HADM_ID)
            SELECT 
                MAP_TO, COUNT(*) AS N, 1 AS PHENOTYPE
            FROM pd
            GROUP BY MAP_TO
            HAVING N > {}
            ORDER BY N DESC'''.format(diagnosis, threshold))
    phenoSet = pd.read_sql_query('SELECT * FROM ps', mydb)
    return phenoSet


def batch_query(start_index, end_index):
    batch_size_actual = pd.read_sql_query('''
                SELECT 
                    COUNT(DISTINCT SUBJECT_ID, HADM_ID) 
                FROM admissions 
                WHERE SUBJECT_ID BETWEEN {} AND {}
                '''.format(start_index, end_index), mydb).iloc[0,0]
    # create diagnosis table
    diagnosisList = pd.read_sql_query('''
                WITH a AS (
                    SELECT DISTINCT SUBJECT_ID, HADM_ID 
                    FROM admissions 
                    WHERE SUBJECT_ID BETWEEN {} AND {})
                SELECT 
                    a.SUBJECT_ID, a.HADM_ID, IF(d.DIAGNOSIS IS NULL, 0, 1) AS DIAGNOSIS
                FROM 
                    a
                LEFT JOIN
                    d ON a.SUBJECT_ID = d.SUBJECT_ID AND a.HADM_ID = d.HADM_ID         
                '''.format(start_index, end_index), mydb)
    # create phenotype profile table
    phenotyle_profile = pd.read_sql_query('''
        WITH 
            a AS (
                    SELECT 
                        DISTINCT SUBJECT_ID, HADM_ID 
                    FROM 
                        admissions 
                    WHERE SUBJECT_ID BETWEEN {} AND {}), 
            c as (
                SELECT a.*, ps.MAP_TO
                FROM a
                JOIN ps),
                -- cross product of all patient*encounter and phenotypes list
            pp as (
                SELECT p.*, 1 AS PHENOTYPE 
                FROM p RIGHT JOIN a 
                ON p.SUBJECT_ID = a.SUBJECT_ID AND p.HADM_ID = a.HADM_ID)

        SELECT c.SUBJECT_ID, c.HADM_ID, c.MAP_TO, IF(pp.PHENOTYPE IS NULL, 0, 1) AS PHENOTYPE 
        FROM pp 
        RIGHT JOIN c ON pp.SUBJECT_ID = c.SUBJECT_ID and pp.HADM_ID = c.HADM_ID AND pp.MAP_TO = c.MAP_TO
        '''.format(start_index, end_index), mydb)
    return batch_size_actual, diagnosisList, phenotyle_profile

In [None]:
def iterate_in_batch(logger):
    logger.info('starting iterate_in_batch()')
    batch_size = 100
    # find the set of diagnosis that are worthy to analyze
    diagnosisSet = diagnosis_set()
    logger.info('diagnosis set completed')

    # create a temp table for abnormal phenotypes of each patient*encounter that met the threshold
    #createAbnormalPhenotypeTable(threshold=1, force_update=True)
    logger.info('createAbnormalPhenotypeTable() completed')
    
    synergies = {}
    
    for diagnosis in diagnosisSet.keys():
        if (diagnosisSet[diagnosis] > 5000):
            # create a temp table for diagnosis of all patient*encouter to analyze
            encountersWithDiagnosis(diagnosis)
            logger.info('encountersWithDiagnosis() completed')

            ## create a list of phenotypes that we want to analyze for the specified disease and preset threshold
            phenoSet = createPhenotypeSet(diagnosis, threshold=100)
            logger.info('phenoSet completed')
            P_SIZE = len(phenoSet)

            ## find the start and end ROW_ID for patient*encounter
            ADM_ID_START, ADM_ID_END = pd.read_sql_query('SELECT MIN(ROW_ID) AS min, MAX(ROW_ID) AS max FROM admissions', mydb).iloc[0]
            batch_N = ADM_ID_END - ADM_ID_START + 1
            TOTAL_BATCH = math.ceil(batch_N / batch_size) # total number of batches
            synergies[diagnosis] = mf.SynergyWithinSet(diagnosis, phenoSet.MAP_TO)
            logger.info('starting batch queries for {}'.format(diagnosis))
            for i in np.arange(TOTAL_BATCH):
                start_index = i * batch_size + ADM_ID_START
                if i < TOTAL_BATCH - 1:
                    end_index = start_index + batch_size - 1
                else:
                    end_index = batch_N
                
                batch_size_actual, diagnosisList, phenotyle_profile = batch_query(start_index, end_index)
                
                if batch_size_actual > 0 :
                    diagnosisVector = diagnosisList.DIAGNOSIS
                    phenotypeProfileMatrix = phenotyle_profile.PHENOTYPE.values.reshape([batch_size_actual, P_SIZE])
                    if i % 100 == 0:
                        logger.info('new batch: start_index={}, end_index={}, batch_size= {}, phenotype_size = {}'.format(start_index, end_index, batch_size_actual, len(phenoSet)))
                    synergies[diagnosis].add_batch(phenotypeProfileMatrix, diagnosisVector)
    
    return synergies

It takes about 10 minutes to set up the phenotype table (p). Afterward, each disease takes about 10 minutes to complete the summary statistics.

In [None]:
logger = logging.getLogger()
logging.basicConfig(format='%(asctime)s | %(levelname)s : %(message)s',level=logging.DEBUG, stream=sys.stdout)
logger.setLevel(logging.DEBUG)

start = datetime.datetime.now()

synergies = iterate_in_batch(logger)
   
end = datetime.datetime.now()


In [None]:
print('running time: {}s'.format((end - start).total_seconds()))

In [None]:
with open('synergies.obj', 'wb') as synergies_file:
    pickle.dump(synergies, synergies_file)

close database connection

In [None]:
cursor.close()
mydb.close()

# Synergy between lab-derived and radiology report-derived Abnormalities

## Algorithm

The algorithm is about the same with the Method 3 in mutual_info_archive. Briefly, 

    * Select encounterOfInterest, temp table: JAX_encounterOfInterest(SUBJECT_ID, HADM_ID)
    * Init diagnosisProfile: temp table: JAX_diagnosisProfile(SUBJECT_ID, HADM_ID, ICD, N)
    * Init textHpoProfile: temp table: JAX_textHpoProfile(SUBJECT_ID, HADM_ID, MAP_TO, N)
    * Init labHpoProfile: temp table: JAX_labHpoProfile(SUBJECT_ID, HADM_ID, MAP_TO, N)
    
    * Rank ICD frequency, temp table: JAX_diagFrequencyRank(ICD, N)
      select diagOfInterest
    * Rank textHPO frequency, temp table: JAX_textHpoFrequencyRank(MAP_TO, N)
      select textHpoOfInterest
    * Rank labHPO frequency, temp table: JAX_labHpoFrequencyRank(MAP_TO, N)
      select labHpoOfInterest
    
    * Iteratation
      for diagnosis in diagOfInterest
          for textHpo in textHpoOfInterest
              for labHpo in labHpoOfInterest
                 Assign diagnosis value: assignDiagnosis(), table: (SUBJECT_ID, HADM_ID, DIAGNOSIS)
                 Assign text2hpo phenotype value: table: SUBJECT_ID, HADM_ID, PHEN_TEXT
                 Assign lab2hpo phenotype value: table: SUBJECT_ID, HADM_ID, PHEN_LAB

In [5]:
# define encounters of interest
def encounterOfInterest(debug=False, N=100):
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_encounterOfInterest')
    if debug:
        limit = 'LIMIT {}'.format(N)
    else:
        limit = ''
    # This is admissions that we want to analyze, 'LIMIT 100' in debug mode
    cursor.execute('''
                CREATE TEMPORARY TABLE IF NOT EXISTS JAX_encounterOfInterest(
                    ROW_ID MEDIUMINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY)
                
                SELECT 
                    DISTINCT SUBJECT_ID, HADM_ID 
                FROM admissions
                {}
                '''.format(limit))
    
def indexEncounterOfInterest():
    cursor.execute('CREATE INDEX JAX_encounterOfInterest_idx01 ON JAX_encounterOfInterest (SUBJECT_ID, HADM_ID)')
    
def diagnosisProfile():
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_diagnosisProfile')
    cursor.execute('''
                CREATE TEMPORARY TABLE IF NOT EXISTS JAX_diagnosisProfile
                SELECT 
                    DIAGNOSES_ICD.SUBJECT_ID, DIAGNOSES_ICD.HADM_ID, DIAGNOSES_ICD.ICD9_CODE, DIAGNOSES_ICD.SEQ_NUM
                FROM
                    DIAGNOSES_ICD
                RIGHT JOIN
                    JAX_encounterOfInterest
                ON 
                    DIAGNOSES_ICD.SUBJECT_ID = JAX_encounterOfInterest.SUBJECT_ID 
                    AND 
                    DIAGNOSES_ICD.HADM_ID = JAX_encounterOfInterest.HADM_ID
                ''')
    
def textHpoProfile(include_inferred=True, threshold=1):
    if include_inferred:
        cursor.execute('''
                    CREATE TEMPORARY TABLE IF NOT EXISTS JAX_textHpoProfile
                    WITH abnorm AS (
                        SELECT
                            NOTEEVENTS.SUBJECT_ID, NOTEEVENTS.HADM_ID, NoteHpoClinPhen.MAP_TO
                        FROM 
                            NOTEEVENTS 
                        JOIN NoteHpoClinPhen on NOTEEVENTS.ROW_ID = NoteHpoClinPhen.NOTES_ROW_ID
                        
                        UNION ALL
                        
                        SELECT
                            NOTEEVENTS.SUBJECT_ID, NOTEEVENTS.HADM_ID, Inferred_NoteHpo.INFERRED_TO AS MAP_TO
                        FROM 
                            NOTEEVENTS 
                        JOIN Inferred_NoteHpo on NOTEEVENTS.ROW_ID = Inferred_NoteHpo.NOTEEVENT_ROW_ID
                        )
                    SELECT SUBJECT_ID, HADM_ID, MAP_TO, COUNT(*) AS OCCURRANCE, 1 AS dummy
                    FROM abnorm 
                    GROUP BY SUBJECT_ID, HADM_ID, MAP_TO
                    HAVING COUNT(*) >= {}
                    -- parameter to control how to define an abnormal phenotype is present.
                '''.format(threshold))
        
    else:
        cursor.execute('''
                    CREATE TEMPORARY TABLE IF NOT EXISTS JAX_p_text
                    WITH abnorm AS (
                        SELECT
                            NOTEEVENTS.SUBJECT_ID, NOTEEVENTS.HADM_ID, NoteHpoClinPhen.MAP_TO
                        FROM 
                            NOTEEVENTS 
                        JOIN NoteHpoClinPhen on NOTEEVENTS.ROW_ID = NoteHpoClinPhen.NOTES_ROW_ID)
                    SELECT SUBJECT_ID, HADM_ID, MAP_TO
                    FROM abnorm 
                    GROUP BY SUBJECT_ID, HADM_ID, MAP_TO, COUNT(*) AS OCCURRANCE, 1 AS dummy
                    HAVING COUNT(*) >= {}
                    -- parameter to control how to define an abnormal phenotype is present.
                '''.format(threshold))
        
def indexTextHpoProfile():
    cursor.execute('CREATE INDEX JAX_textHpoProfile_idx01 ON JAX_textHpoProfile (SUBJECT_ID, HADM_ID)')
    cursor.execute('CREATE INDEX JAX_textHpoProfile_idx02 ON JAX_textHpoProfile (MAP_TO);')
    cursor.execute('CREATE INDEX JAX_textHpoProfile_idx03 ON JAX_textHpoProfile (SUBJECT_ID, HADM_ID, MAP_TO)')
    
def labHpoProfile(threshold, include_inferred=True, force_update=True):
    # TODO: refactor the method 
    #createAbnormalPhenotypeTable(threshold, include_inferred=True, force_update=True)
    
    if force_update:
        cursor.execute('''DROP TEMPORARY TABLE IF EXISTS JAX_labHpoProfile''')
    if include_inferred:
        cursor.execute('''
                    CREATE TEMPORARY TABLE IF NOT EXISTS JAX_labHpoProfile
                    WITH abnorm AS (
                        SELECT
                            LABEVENTS.SUBJECT_ID, LABEVENTS.HADM_ID, LabHpo.MAP_TO
                        FROM 
                            LABEVENTS 
                        JOIN LabHpo on LABEVENTS.ROW_ID = LabHpo.ROW_ID
                        WHERE LabHpo.NEGATED = 'F'
                        
                        UNION ALL
                        
                        SELECT 
                            LABEVENTS.SUBJECT_ID, LABEVENTS.HADM_ID, INFERRED_LABHPO.INFERRED_TO AS MAP_TO 
                        FROM 
                            INFERRED_LABHPO 
                        JOIN 
                            LABEVENTS ON INFERRED_LABHPO.LABEVENT_ROW_ID = LABEVENTS.ROW_ID
                        )
                    SELECT SUBJECT_ID, HADM_ID, MAP_TO, COUNT(*) AS OCCURRANCE, 1 AS dummy
                    FROM abnorm 
                    GROUP BY SUBJECT_ID, HADM_ID, MAP_TO
                    HAVING COUNT(*) >= {}
                    -- parameter to control how to define an abnormal phenotype is present.
                '''.format(threshold))
    else:       
        cursor.execute('''
                    CREATE TEMPORARY TABLE IF NOT EXISTS JAX_labHpoProfile
                    WITH abnorm AS (
                        SELECT
                            LABEVENTS.SUBJECT_ID, LABEVENTS.HADM_ID, LabHpo.MAP_TO
                        FROM 
                            LABEVENTS 
                        JOIN LabHpo on LABEVENTS.ROW_ID = LabHpo.ROW_ID
                        WHERE LabHpo.NEGATED = 'F')
                    SELECT SUBJECT_ID, HADM_ID, MAP_TO, COUNT(*) AS OCCURRANCE, 1 AS dummy
                    FROM abnorm 
                    GROUP BY SUBJECT_ID, HADM_ID, MAP_TO
                    HAVING COUNT(*) >= {}
                    -- parameter to control how to define an abnormal phenotype is present.
                '''.format(threshold))

def indexLabHpoProfile():
    cursor.execute('CREATE INDEX JAX_labHpoProfile_idx01 ON JAX_labHpoProfile (SUBJECT_ID, HADM_ID)')
    cursor.execute('CREATE INDEX JAX_labHpoProfile_idx02 ON JAX_labHpoProfile (MAP_TO);')
    cursor.execute('CREATE INDEX JAX_labHpoProfile_idx03 ON JAX_labHpoProfile (SUBJECT_ID, HADM_ID, MAP_TO)')
    
def rankICD():
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_diagFrequencyRank')
    cursor.execute("""
        CREATE TEMPORARY TABLE IF NOT EXISTS JAX_diagFrequencyRank
        WITH JAX_temp_diag AS (
            SELECT DISTINCT SUBJECT_ID, HADM_ID, 
                CASE 
                    WHEN(ICD9_CODE LIKE 'V%') THEN SUBSTRING(ICD9_CODE, 1, 3) 
                    WHEN(ICD9_CODE LIKE 'E%') THEN SUBSTRING(ICD9_CODE, 1, 4) 
                ELSE 
                    SUBSTRING(ICD9_CODE, 1, 3) END AS ICD9_CODE 
            FROM JAX_diagnosisProfile)
        SELECT 
            ICD9_CODE, COUNT(*) AS N
        FROM
            JAX_temp_diag
        GROUP BY 
            ICD9_CODE
        ORDER BY N
        DESC
        """)

def rankHpoFromText(diagnosis):
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_textHpoFrequencyRank')
    cursor.execute('''
            CREATE TEMPORARY TABLE JAX_textHpoFrequencyRank            
            WITH pd AS(
                SELECT 
                    JAX_textHpoProfile.*
                FROM 
                    JAX_textHpoProfile 
                JOIN (
                    SELECT 
                        DISTINCT SUBJECT_ID, HADM_ID
                    FROM 
                        JAX_diagnosisProfile 
                    WHERE 
                        ICD9_CODE LIKE '{}%') AS d
                ON 
                    JAX_textHpoProfile.SUBJECT_ID = d.SUBJECT_ID AND JAX_textHpoProfile.HADM_ID = d.HADM_ID)
            SELECT 
                MAP_TO, COUNT(*) AS N, 1 AS PHENOTYPE
            FROM pd
            GROUP BY MAP_TO
            ORDER BY N DESC'''.format(diagnosis))
    
def rankHpoFromLab(diagnosis):
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_labHpoFrequencyRank')
    cursor.execute('''
            CREATE TEMPORARY TABLE JAX_labHpoFrequencyRank            
            WITH pd AS(
                SELECT 
                    JAX_labHpoProfile.*
                FROM 
                    JAX_labHpoProfile 
                JOIN (
                    SELECT 
                        DISTINCT SUBJECT_ID, HADM_ID
                    FROM 
                        JAX_diagnosisProfile 
                    WHERE 
                        ICD9_CODE LIKE '{}%') AS d
                ON 
                    JAX_labHpoProfile.SUBJECT_ID = d.SUBJECT_ID AND JAX_labHpoProfile.HADM_ID = d.HADM_ID)
            SELECT 
                MAP_TO, COUNT(*) AS N, 1 AS PHENOTYPE
            FROM pd
            GROUP BY MAP_TO
            ORDER BY N DESC'''.format(diagnosis))           

In [6]:
# assign 0 or 1 to each encouter whether a diagnosis is observed
def createDiagnosisTable(diagnosis, primary_diagnosis_only):
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag')
    if primary_diagnosis_only:
        limit = 'AND SEQ_NUM=1'
    else:
        limit = ''
    cursor.execute('''
                CREATE TEMPORARY TABLE IF NOT EXISTS JAX_mf_diag 
                WITH 
                    d AS (
                        SELECT 
                            DISTINCT SUBJECT_ID, HADM_ID, '1' AS DIAGNOSIS
                        FROM 
                            JAX_diagnosisProfile 
                        WHERE ICD9_CODE LIKE '{}%' {})
                    -- This is encounters with positive diagnosis

                SELECT 
                    DISTINCT a.SUBJECT_ID, a.HADM_ID, IF(d.DIAGNOSIS IS NULL, '0', '1') AS DIAGNOSIS
                FROM 
                    JAX_encounterOfInterest AS a
                LEFT JOIN
                    d ON a.SUBJECT_ID = d.SUBJECT_ID AND a.HADM_ID = d.HADM_ID       
                /* -- This is the first join for diagnosis (0, or 1) */    
                '''.format(diagnosis, limit))
    cursor.execute('CREATE INDEX JAX_mf_diag_idx01 ON JAX_mf_diag (SUBJECT_ID, HADM_ID)')

# assign 0 or 1 to each encounter whether a phenotype is observed from radiology reports
def diagnosisTextHpo(phenotype):
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag_textHpo')
    """
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_mf_diag_textHpo
        SELECT 
            L.*, IF(R.MAP_TO IS NULL, '0', '1') AS PHEN_TXT
        FROM JAX_mf_diag AS L 
        LEFT JOIN 
            (SELECT * 
            FROM JAX_textHpoProfile 
            WHERE JAX_textHpoProfile.MAP_TO = '{}') AS R 
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID 
    '''.format(phenotype))
    """
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_mf_diag_textHpo
        WITH L AS (SELECT JAX_mf_diag.*, '{}' AS PHEN_TXT FROM JAX_mf_diag)
        SELECT 
            L.*, IF(R.dummy IS NULL, '0', '1') AS PHEN_TXT_VALUE
        FROM L 
        LEFT JOIN 
            JAX_textHpoProfile AS R
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.PHEN_TXT = R.MAP_TO
    '''.format(phenotype))
    cursor.execute('CREATE INDEX JAX_mf_diag_textHpo_idx01 ON JAX_mf_diag_textHpo (SUBJECT_ID, HADM_ID)')

def diagnosisAllTextHpo(threshold_min, threshold_max):
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag_allTextHpo')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_mf_diag_allTextHpo
        WITH 
            P AS (SELECT MAP_TO AS PHEN_TXT FROM JAX_textHpoFrequencyRank WHERE N BETWEEN {} AND {}),
            L AS (SELECT * FROM JAX_mf_diag JOIN P)
        SELECT 
            L.*, IF(R.dummy IS NULL, '0', '1') AS PHEN_TXT_VALUE
        FROM L 
        LEFT JOIN 
            JAX_textHpoProfile AS R
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.PHEN_TXT = R.MAP_TO
    '''.format(threshold_min, threshold_max))
    cursor.execute('CREATE INDEX JAX_mf_diag_allTextHpo_idx01 ON JAX_mf_diag_allTextHpo (SUBJECT_ID, HADM_ID, PHEN_TXT)')
    
def diagnosisLabHpo(phenotype):
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag_labHpo')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_mf_diag_labHpo
        WITH L AS (SELECT JAX_mf_diag.*, '{}' AS PHEN_LAB FROM JAX_mf_diag)
        SELECT 
            L.*, IF(R.dummy IS NULL, '0', '1') AS PHEN_LAB_VALUE
        FROM L 
        LEFT JOIN 
             JAX_labHpoProfile AS R 
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.PHEN_LAB = R.MAP_TO
    '''.format(phenotype))
    cursor.execute('CREATE INDEX JAX_mf_diag_labHpo_idx01 ON JAX_mf_diag_labHpo (SUBJECT_ID, HADM_ID)')
    
def diagnosisAllLabHpo(threshold_min, threshold_max):
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag_allLabHpo')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_mf_diag_allLabHpo
        WITH 
            P AS (SELECT MAP_TO AS PHEN_LAB FROM JAX_labHpoFrequencyRank WHERE N BETWEEN {} AND {}),
            L AS (SELECT * FROM JAX_mf_diag JOIN P)
        SELECT 
            L.*, IF(R.dummy IS NULL, '0', '1') AS PHEN_LAB_VALUE
        FROM L 
        LEFT JOIN 
             JAX_labHpoProfile AS R 
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.PHEN_LAB = R.MAP_TO
    '''.format(threshold_min, threshold_max))
    cursor.execute('CREATE INDEX JAX_mf_diag_allLabHpo_idx01 ON JAX_mf_diag_allLabHpo (SUBJECT_ID, HADM_ID)')

def diagnosisTextLab(phenotype):
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag_txtHpo_labHpo')
    result = cursor.execute('''
        CREATE TEMPORARY TABLE JAX_mf_diag_txtHpo_labHpo 
        WITH L AS (SELECT JAX_mf_diag_textHpo.*, '{}' AS PHEN_LAB FROM JAX_mf_diag_textHpo)
        SELECT L.*, IF(R.dummy IS NULL, '0', '1') AS PHEN_LAB_VALUE
        FROM L 
        LEFT JOIN 
            JAX_labHpoProfile AS R 
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.PHEN_LAB = R.MAP_TO
    '''.format(phenotype))
    
    
def diagnosisAllTextAllLab():
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag_allTxtHpo_allLabHpo')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_mf_diag_allTxtHpo_allLabHpo 
        SELECT L.SUBJECT_ID, L.HADM_ID, L.DIAGNOSIS, L.PHEN_TXT, L.PHEN_TXT_VALUE, R.PHEN_LAB, R.PHEN_LAB_VALUE 
        FROM JAX_mf_diag_allTextHpo AS L 
        JOIN JAX_mf_diag_allLabHpo AS R
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID
    ''')
    

def initSummaryStatisticTables():
    # define empty columns to store summary statistics
    summary_statistics1_radiology = pd.DataFrame(data={'DIAGNOSIS_CODE':[], 
                       'PHENOTYPE':[], 
                       'DIAGNOSIS_VALUE':[], 
                       'PHENOTYPE_VALUE':[], 
                       'N':[]},
                columns = ['DIAGNOSIS_CODE', 'PHENOTYPE', 'DIAGNOSIS_VALUE', 'PHENOTYPE_VALUE', 'N'])
    
    summary_statistics1_lab = pd.DataFrame(data={'DIAGNOSIS_CODE':[], 
                       'PHENOTYPE':[], 
                       'DIAGNOSIS_VALUE':[], 
                       'PHENOTYPE_VALUE':[], 
                       'N':[]},
                columns = ['DIAGNOSIS_CODE', 'PHENOTYPE', 'DIAGNOSIS_VALUE', 'PHENOTYPE_VALUE', 'N'])

    summary_statistics2 = pd.DataFrame(data={'DIAGNOSIS_CODE':[], 
                       'PHEN_TXT':[], 
                       'PHEN_LAB':[], 
                       'DIAGNOSIS_VALUE':[], 
                       'PHEN_TXT_VALUE':[], 
                       'PHEN_LAB_VALUE':[], 
                       'N':[]},
                columns = ['DIAGNOSIS_CODE', 'PHEN_TXT', 'PHEN_LAB', 'DIAGNOSIS_VALUE', 'PHEN_TXT_VALUE', 'PHEN_LAB_VALUE', 'N']) 

    return summary_statistics1_radiology, summary_statistics1_lab, summary_statistics2

def initTables(debug=False):
    """
    This combines LabHpo and Inferred_LabHpo, and combines TextHpo and Inferred_TextHpo. 
    Only need to run once. For efficiency consideration, the tables can also be created as perminent. 
    It is time-consuming, so call it with caution. 
    """
    #init textHpoProfile and index it
    #I create perminant tables to save time; other users should enable them
    #textHpoProfile(include_inferred=True, threshold=1)
    #indexTextHpoProfile()
    #init labHpoProfile and index it
    #labHpoProfile(threshold=1, include_inferred=True, force_update=True)
    #indexLabHpoProfile()
    
    #define encounters to analyze
    encounterOfInterest(debug)
    indexEncounterOfInterest()
    #init diagnosisProfile
    diagnosisProfile()
    

def iterate(primary_diagnosis_only, diagnosis_threshold_min, textHpo_threshold_min, labHpo_threshold_min, logger): 
    logger.info('starting iterating...................................')
    N = pd.read_sql_query("SELECT count(*) FROM JAX_encounterOfInterest", mydb)
    # init empty tables to hold summary statistics
    summary_statistics1_radiology, summary_statistics1_lab, summary_statistics2 = initSummaryStatisticTables()
    
    # define a set of diseases that we want to analyze
    rankICD()
    
    diseaseOfInterest = pd.read_sql_query("SELECT * FROM JAX_diagFrequencyRank WHERE N > {}".format(diagnosis_threshold_min), mydb).ICD9_CODE.values
    diseaseOfInterest = ['428']
    # define encounters to analyze
    logger.info('diseases of interest established: {}'.format(len(diseaseOfInterest)))
    for diagnosis in diseaseOfInterest:
        logger.info("start analyzing disease {}".format(diagnosis))
        
        # assign each encounter whether a diagnosis code is observed
        # create a table j1 (joint 1)
        createDiagnosisTable(diagnosis, primary_diagnosis_only)
        # for every diagnosis, find phenotypes of interest to look at from radiology reports
        # for every diagnosis, find phenotypes of interest to look at from laboratory tests
        rankHpoFromText(diagnosis)
        rankHpoFromLab(diagnosis)
        
        textHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_textHpoFrequencyRank WHERE N > {}".format(textHpo_threshold_min), mydb).MAP_TO.values
        labHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_labHpoFrequencyRank WHERE N > {}".format(labHpo_threshold_min), mydb).MAP_TO.values
        logger.info("TextHpo of interest established, size: {}".format(len(textHpoOfInterest)))
        logger.info("LabHpo of interest established, size: {}".format(len(labHpoOfInterest)))
        for textHpo in textHpoOfInterest:
            logger.info("iteration: TextHpo--{}".format(textHpo))
            # assign each encounter whether a phenotype is observed from radiology reports
            diagnosisTextHpo(textHpo)            
            result1_text = pd.read_sql_query('''
                SELECT 
                    '{}' AS DIAGNOSIS_CODE, '{}' AS PHENOTYPE, DIAGNOSIS AS DIAGNOSIS_VALUE, PHEN_TXT_VALUE AS PHENOTYPE_VALUE, COUNT(*) AS N 
                FROM JAX_mf_diag_textHpo 
                GROUP BY 
                    DIAGNOSIS, PHEN_TXT_VALUE
            '''.format(diagnosis, textHpo), mydb)
            summary_statistics1_radiology = summary_statistics1_radiology.append(result1_text)
            # summary statistics for p1
            # calculate I(p1;D)
            for labHpo in labHpoOfInterest:
                logger.info(".........LabHpo--{}".format(labHpo))
                diagnosisLabHpo(labHpo)
                result1_lab = pd.read_sql_query('''
                    SELECT 
                        '{}' AS DIAGNOSIS_CODE, '{}' AS PHENOTYPE, DIAGNOSIS AS DIAGNOSIS_VALUE, PHEN_LAB_VALUE AS PHENOTYPE_VALUE, COUNT(*) AS N 
                    FROM 
                        JAX_mf_diag_labHpo 
                    GROUP BY DIAGNOSIS, PHEN_LAB_VALUE
                '''.format(diagnosis, labHpo), mydb)
                summary_statistics1_lab = summary_statistics1_lab.append(result1_lab)
            
                # assign each encounter whether a phenotype is observed from lab tests
                diagnosisTextLab(labHpo)
                result2 = pd.read_sql_query('''
                    SELECT 
                        '{}' AS DIAGNOSIS_CODE, 
                        '{}' AS PHEN_TXT, 
                        '{}' AS PHEN_LAB,  
                        DIAGNOSIS AS DIAGNOSIS_VALUE, 
                        PHEN_TXT_VALUE, 
                        PHEN_LAB_VALUE, 
                        COUNT(*) AS N
                    FROM JAX_mf_diag_txtHpo_labHpo 
                    GROUP BY DIAGNOSIS, PHEN_TXT_VALUE, PHEN_LAB_VALUE
                '''.format(diagnosis, textHpo, labHpo), mydb)
                summary_statistics2 = summary_statistics2.append(result2)
    logger.info('end iterating.....................................')            
    return N, summary_statistics1_radiology, summary_statistics1_lab, summary_statistics2 


def iterate_batch(primary_diagnosis_only, diagnosis_threshold_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max, logger): 
    logger.info('starting iterating...................................')
    N = pd.read_sql_query("SELECT count(*) FROM JAX_encounterOfInterest", mydb)
    # init empty tables to hold summary statistics
    summary_statistics1_radiology, summary_statistics1_lab, summary_statistics2 = initSummaryStatisticTables()
    
    # define a set of diseases that we want to analyze
    rankICD()
    
    diseaseOfInterest = pd.read_sql_query("SELECT * FROM JAX_diagFrequencyRank WHERE N > {}".format(diagnosis_threshold_min), mydb).ICD9_CODE.values
    diseaseOfInterest = ['428']
    logger.info('diseases of interest established: {}'.format(len(diseaseOfInterest)))
    
    for diagnosis in diseaseOfInterest:
        logger.info("start analyzing disease {}".format(diagnosis))
        
        logger.info(".......assigning values of diagnosis")
        # assign each encounter whether a diagnosis code is observed
        # create a table j1 (joint 1)
        createDiagnosisTable(diagnosis, primary_diagnosis_only)
        # for every diagnosis, find phenotypes of interest to look at from radiology reports
        # for every diagnosis, find phenotypes of interest to look at from laboratory tests
        rankHpoFromText(diagnosis)
        rankHpoFromLab(diagnosis)
        logger.info("..............diagnosis values found")
        
        logger.info(".......assigning values of TextHpo")
        diagnosisAllTextHpo(textHpo_threshold_min, textHpo_threshold_max)
        result1_text = pd.read_sql_query("""
            SELECT '{}' AS DIAGNOSIS_CODE, 
                PHEN_TXT AS PHENOTYPE, 
                DIAGNOSIS AS DIAGNOSIS_VALUE, 
                PHEN_TXT_VALUE AS PHENOTYPE_VALUE, 
                COUNT(*) AS N 
            FROM JAX_mf_diag_allTextHpo 
            GROUP BY DIAGNOSIS, PHEN_TXT, PHEN_TXT_VALUE
        """.format(diagnosis), mydb)
        logger.info("..............TextHpo values found")
        summary_statistics1_radiology = summary_statistics1_radiology.append(result1_text)

        
        logger.info(".......assigning values of LabHpo")
        diagnosisAllLabHpo(labHpo_threshold_min, labHpo_threshold_max)
        result1_lab = pd.read_sql_query("""
            SELECT 
                '{}' AS DIAGNOSIS_CODE, 
                PHEN_LAB AS PHENOTYPE, 
                DIAGNOSIS AS DIAGNOSIS_VALUE, 
                PHEN_LAB_VALUE AS PHENOTYPE_VALUE, 
                COUNT(*) AS N 
            FROM JAX_mf_diag_allLabHpo 
            GROUP BY DIAGNOSIS, PHEN_LAB, PHEN_LAB_VALUE
        """.format(diagnosis), mydb)
        logger.info("..............LabHpo values found")
        summary_statistics1_lab = summary_statistics1_lab.append(result1_lab)

        logger.info(".......building diagnosis-TextHpo-LabHpo joint distribution")
        diagnosisAllTextAllLab()
        result2 = pd.read_sql_query("""
            SELECT 
                '{}' AS DIAGNOSIS_CODE, 
                PHEN_TXT, 
                PHEN_LAB, 
                DIAGNOSIS AS DIAGNOSIS_VALUE,
                PHEN_TXT_VALUE, 
                PHEN_LAB_VALUE, 
                COUNT(*) AS N 
            FROM JAX_mf_diag_allTxtHpo_allLabHpo 
            GROUP BY DIAGNOSIS, PHEN_LAB, PHEN_LAB_VALUE, PHEN_TXT, PHEN_TXT_VALUE
        """.format(diagnosis) , mydb)
        logger.info("..............diagnosis-TextHpo-LabHpo joint distribution built")
        summary_statistics2 = summary_statistics2.append(result2)

    logger.info('end iterating.....................................')            
    return N, summary_statistics1_radiology, summary_statistics1_lab, summary_statistics2 

In [None]:
# how to run this
# Again, it take either too long or too much memory space to run
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# 1. build the temp tables for Lab converted HPO, Text convert HPO
# Read the comments within the method!
initTables(debug=False)
# 2. iterate the database t (for debug, use parameter values: 0, 10, 15, for production, use parameter values: 0, 10000, 10000
#N, summary_statistics1_radiology, summary_statistics1_lab, summary_statistics2 = iterate(diagnosis_threshold_min=0, textHpo_threshold_min=10, labHpo_threshold_min=15, logger=logger)
#N, summary_statistics1_radiology, summary_statistics1_lab, summary_statistics2 = iterate(diagnosis_threshold_min=0, textHpo_threshold_min=1000, labHpo_threshold_min=1000, logger=logger)

# 2b. use the batch method
#N2, summary_statistics1_radiology2, summary_statistics1_lab2, summary_statistics22 = iterate_batch(diagnosis_threshold_min=0, textHpo_threshold_min=0, textHpo_threshold_max=100, labHpo_threshold_min=0, labHpo_threshold_max=100, logger=logger)
N2, summary_statistics1_radiology2, summary_statistics1_lab2, summary_statistics22 = iterate_batch(diagnosis_threshold_min=0, textHpo_threshold_min=1000, textHpo_threshold_max=100000, labHpo_threshold_min=1000, labHpo_threshold_max=100000, logger=logger)

In [None]:
N

In [None]:
summary_statistics1_radiology.head()
#summary_statistics1_radiology
#summary_statistics1_radiology2.groupby('PHENOTYPE').agg({'N': sum})

In [None]:
summary_statistics1_radiology2.head()
summary_statistics1_radiology.merge(summary_statistics1_radiology2, on = ['DIAGNOSIS_CODE', 'PHENOTYPE', 'DIAGNOSIS_VALUE', 'PHENOTYPE_VALUE'])
c =summary_statistics1_lab.merge(summary_statistics1_lab2, on = ['DIAGNOSIS_CODE', 'PHENOTYPE', 'DIAGNOSIS_VALUE', 'PHENOTYPE_VALUE'])

In [None]:
#summary_statistics2.groupby(['PHEN_TXT', 'PHEN_LAB']).agg({'N': sum})
c = summary_statistics2.merge(summary_statistics22, on = ['DIAGNOSIS_CODE', 'PHEN_TXT', 'PHEN_LAB', 'DIAGNOSIS_VALUE', 'PHEN_TXT_VALUE', 'PHEN_LAB_VALUE'] )
#summary_statistics2.head()
c.loc[c.N_x != c.N_y, :]

In [91]:
def indexDiagnosisTable():
    cursor.execute("ALTER TABLE JAX_mf_diag ADD COLUMN ROW_ID INT AUTO_INCREMENT PRIMARY KEY;")
    
def batch_query(start_index, end_index, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max):
    
    diagnosisVector = pd.read_sql_query('''
        SELECT * FROM JAX_mf_diag WHERE ROW_ID BETWEEN {} AND {}
    '''.format(start_index, end_index), mydb)
    
    textHpoFlat = pd.read_sql_query('''
        WITH encounters AS (
            SELECT SUBJECT_ID, HADM_ID
            FROM JAX_mf_diag 
            WHERE ROW_ID BETWEEN {} AND {}
        ), 
        textHpoOfInterest AS (
            SELECT MAP_TO 
            FROM JAX_textHpoFrequencyRank 
            WHERE N BETWEEN {} AND {}
        ), 
        joint as (
            SELECT *
            FROM encounters 
            JOIN textHpoOfInterest)
        
        SELECT L.SUBJECT_ID, L.HADM_ID, L.MAP_TO, IF(R.dummy IS NULL, 0, 1) AS VALUE
        FROM joint as L
        LEFT JOIN JAX_textHpoProfile AS R
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.MAP_TO = R.MAP_TO
    '''.format(start_index, end_index, textHpo_threshold_min, textHpo_threshold_max), mydb)
    
    labHpoFlat = pd.read_sql_query('''
        WITH encounters AS (
            SELECT SUBJECT_ID, HADM_ID
            FROM JAX_mf_diag 
            WHERE ROW_ID BETWEEN {} AND {}
        ), 
        labHpoOfInterest AS (
            SELECT MAP_TO 
            FROM JAX_labHpoFrequencyRank 
            WHERE N BETWEEN {} AND {}
        ), 
        joint as (
            SELECT *
            FROM encounters 
            JOIN labHpoOfInterest)
        
        SELECT L.SUBJECT_ID, L.HADM_ID, L.MAP_TO, IF(R.dummy IS NULL, 0, 1) AS VALUE
        FROM joint as L
        LEFT JOIN JAX_labHpoProfile AS R
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.MAP_TO = R.MAP_TO
    '''.format(start_index, end_index, labHpo_threshold_min, labHpo_threshold_max), mydb)
    
    return diagnosisVector, textHpoFlat, labHpoFlat

def iterate_in_batch(primary_diagnosis_only, diagnosis_threshold_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max, logger):
    logger.info('starting iterate_in_batch()')
    batch_size = 100
    
    # define a set of diseases that we want to analyze
    rankICD()
    
    diseaseOfInterest = pd.read_sql_query("SELECT * FROM JAX_diagFrequencyRank WHERE N > {}".format(diagnosis_threshold_min), mydb).ICD9_CODE.values
    diseaseOfInterest = ['428', '584', '038']
    logger.info('diagnosis of interest: {}'.format(len(diseaseOfInterest)))
    
    synergies = {}
    
    pbar = tqdm(total=len(diseaseOfInterest))
    for diagnosis in diseaseOfInterest:
        logger.info("start analyzing disease {}".format(diagnosis))
        
        logger.info(".......assigning values of diagnosis")
        # assign each encounter whether a diagnosis code is observed
        # create a table j1 (joint 1)
        createDiagnosisTable(diagnosis, primary_diagnosis_only)
        indexDiagnosisTable()
        # for every diagnosis, find phenotypes of interest to look at from radiology reports
        # for every diagnosis, find phenotypes of interest to look at from laboratory tests
        rankHpoFromText(diagnosis)
        rankHpoFromLab(diagnosis)
        logger.info("..............diagnosis values found")
        
        textHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_textHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(textHpo_threshold_min, textHpo_threshold_max), mydb).MAP_TO.values
        labHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_labHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(labHpo_threshold_min, labHpo_threshold_max), mydb).MAP_TO.values
        logger.info("TextHpo of interest established, size: {}".format(len(textHpoOfInterest)))
        logger.info("LabHpo of interest established, size: {}".format(len(labHpoOfInterest)))

        ## find the start and end ROW_ID for patient*encounter
        ADM_ID_START, ADM_ID_END = pd.read_sql_query('SELECT MIN(ROW_ID) AS min, MAX(ROW_ID) AS max FROM JAX_mf_diag', mydb).iloc[0]
        batch_N = ADM_ID_END - ADM_ID_START + 1
        TOTAL_BATCH = math.ceil(batch_N / batch_size) # total number of batches
        
        synergies[diagnosis] = mf.MutualInfoXYz(textHpoOfInterest, labHpoOfInterest, diagnosis)
        
        logger.info('starting batch queries for {}'.format(diagnosis))
        for i in np.arange(TOTAL_BATCH):
            start_index = i * batch_size + ADM_ID_START
            if i < TOTAL_BATCH - 1:
                end_index = start_index + batch_size - 1
            else:
                end_index = batch_N

            diagnosisFlat, textHpoFlat, labHpoFlat =  batch_query(start_index, end_index, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max)

            batch_size_actual = len(diagnosisFlat)
            textHpoOfInterest_size = len(textHpoOfInterest)
            labHpoOfInterest_size = len(labHpoOfInterest)
            assert(len(textHpoFlat) == batch_size_actual * textHpoOfInterest_size)
            assert(len(labHpoFlat) == batch_size_actual * labHpoOfInterest_size)
            
            if batch_size_actual > 0:
                diagnosisVector = diagnosisFlat.DIAGNOSIS.values.astype(int)
                # reformat the flat vector into N x M matrix, N is batch size, i.e. number of encounters, M is the length of HPO terms  
                textHpoMatrix = textHpoFlat.VALUE.values.astype(int).reshape([batch_size_actual, textHpoOfInterest_size], order='F')
                labHpoMatrix = labHpoFlat.VALUE.values.astype(int).reshape([batch_size_actual, labHpoOfInterest_size], order='F')
                # check the matrix formatting is correct
                # disable the following 4 lines to speed things up
                textHpoLabelsMatrix = textHpoFlat.MAP_TO.values.reshape([batch_size_actual, textHpoOfInterest_size], order='F')
                labHpoLabelsMatrix = labHpoFlat.MAP_TO.values.reshape([batch_size_actual, labHpoOfInterest_size], order='F')
                assert (textHpoLabelsMatrix[0, :] == textHpoOfInterest).all()
                assert (labHpoLabelsMatrix[0, :] == labHpoOfInterest).all()
                if i % 100 == 0:
                    logger.info('new batch: start_index={}, end_index={}, batch_size= {}, textHpo_size = {}, labHpo_size = {}'.format(start_index, end_index, batch_size_actual, textHpoMatrix.shape[1], labHpoMatrix.shape[1]))
                synergies[diagnosis].add_batch(textHpoMatrix,labHpoMatrix, diagnosisVector)
         
        pbar.update(1)
        
    pbar.close()
    
    return synergies

##  --TEST

In [92]:
# how to run this

# Again, it take either too long or too much memory space to run
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# 1. build the temp tables for Lab converted HPO, Text convert HPO
# Read the comments within the method!
initTables(debug=True)

# 2. iterate throw the dataset
primary_diagnosis_only = True
diagnosis_threshold_min = 5
textHpo_threshold_min, textHpo_threshold_max = 7, 100
labHpo_threshold_min, labHpo_threshold_max = 7, 100

synergies = iterate_in_batch(primary_diagnosis_only, diagnosis_threshold_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max, logger)

2019-10-04 16:09:59,274 - 10502 - root - INFO - starting iterate_in_batch()
2019-10-04 16:09:59,279 - 10502 - root - INFO - diagnosis of interest: 3


  0%|          | 0/3 [00:00<?, ?it/s]

2019-10-04 16:09:59,281 - 10502 - root - INFO - start analyzing disease 428
2019-10-04 16:09:59,282 - 10502 - root - INFO - .......assigning values of diagnosis
2019-10-04 16:09:59,292 - 10502 - root - INFO - ..............diagnosis values found
2019-10-04 16:09:59,297 - 10502 - root - INFO - TextHpo of interest established, size: 28
2019-10-04 16:09:59,298 - 10502 - root - INFO - LabHpo of interest established, size: 93
2019-10-04 16:09:59,300 - 10502 - root - INFO - starting batch queries for 428
2019-10-04 16:09:59,439 - 10502 - root - INFO - new batch: start_index=1, end_index=100, batch_size= 100, textHpo_size = 28, labHpo_size = 93


 33%|███▎      | 1/3 [00:00<00:00,  6.05it/s]

2019-10-04 16:09:59,447 - 10502 - root - INFO - start analyzing disease 584
2019-10-04 16:09:59,448 - 10502 - root - INFO - .......assigning values of diagnosis
2019-10-04 16:09:59,458 - 10502 - root - INFO - ..............diagnosis values found
2019-10-04 16:09:59,462 - 10502 - root - INFO - TextHpo of interest established, size: 9
2019-10-04 16:09:59,463 - 10502 - root - INFO - LabHpo of interest established, size: 42
2019-10-04 16:09:59,466 - 10502 - root - INFO - starting batch queries for 584
2019-10-04 16:09:59,535 - 10502 - root - INFO - new batch: start_index=1, end_index=100, batch_size= 100, textHpo_size = 9, labHpo_size = 42
2019-10-04 16:09:59,537 - 10502 - root - INFO - start analyzing disease 038
2019-10-04 16:09:59,537 - 10502 - root - INFO - .......assigning values of diagnosis
2019-10-04 16:09:59,548 - 10502 - root - INFO - ..............diagnosis values found
2019-10-04 16:09:59,553 - 10502 - root - INFO - TextHpo of interest established, size: 14
2019-10-04 16:09:59,

100%|██████████| 3/3 [00:00<00:00,  6.88it/s]


## --PRODUCTION

In [93]:
# how to run this
# Again, it take either too long or too much memory space to run
logger = logging.getLogger()
logger.setLevel(logging.WARN)

# 1. build the temp tables for Lab converted HPO, Text convert HPO
# Read the comments within the method!
initTables(debug=False)

# 2. iterate throw the dataset
primary_diagnosis_only = True
diagnosis_threshold_min = 3000
textHpo_threshold_min, textHpo_threshold_max = 500, 100000
labHpo_threshold_min, labHpo_threshold_max = 1000, 100000

synergies = iterate_in_batch(primary_diagnosis_only, diagnosis_threshold_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max, logger)

100%|██████████| 3/3 [23:39<00:00, 466.51s/it]


In [94]:
if primary_diagnosis_only:
    fName = 'synergies_radiology_lab_primary_only.obj'
else:
    fName = 'synergies_radiology_lab_primary_and_secondary.obj'

with open(fName, 'wb') as synergies_file:
    pickle.dump(synergies, synergies_file)

In [None]:
synergies

array([], dtype=object)

In [45]:
encounterOfInterest(debug=True)
indexEncounterOfInterest()
diagnosisProfile()
rankICD()
diagnosis = '038'
rankHpoFromText(diagnosis)
rankHpoFromLab(diagnosis)
createDiagnosisTable(diagnosis, primary_diagnosis_only=True)

In [46]:
pd.read_sql_query("SELECT COUNT(*) FROM JAX_encounterOfInterest", mydb)

Unnamed: 0,COUNT(*)
0,100


In [47]:
pd.read_sql_query("SELECT count(*) FROM JAX_diagnosisProfile LIMIT 4", mydb)

Unnamed: 0,count(*)
0,778


In [48]:
pd.read_sql_query("SELECT * FROM JAX_diagFrequencyRank WHERE ICD9_CODE='038'", mydb)

Unnamed: 0,ICD9_CODE,N
0,38,8


In [56]:
pd.read_sql_query("SELECT * FROM JAX_labHpoFrequencyRank WHERE N >= 7", mydb)

Unnamed: 0,MAP_TO,N,PHENOTYPE
0,HP:0003138,8,1
1,HP:0001943,8,1
2,HP:0000118,8,1
3,HP:0001881,8,1
4,HP:0011277,8,1
5,HP:0011014,8,1
6,HP:0031851,8,1
7,HP:0003074,8,1
8,HP:0002148,8,1
9,HP:0000001,8,1


In [54]:
pd.read_sql_query("SELECT * FROM JAX_textHpoFrequencyRank WHERE N >= 7", mydb)

Unnamed: 0,MAP_TO,N,PHENOTYPE
0,HP:0000118,8,1
1,HP:0002088,8,1
2,HP:0001626,8,1
3,HP:0001939,8,1
4,HP:0000001,8,1
5,HP:0012252,8,1
6,HP:0002086,8,1
7,HP:0002202,7,1
8,HP:0011032,7,1
9,HP:0100750,7,1


In [None]:
#createDiagnosisTable(diagnosis, primary_diagnosis_only=False )
pd.read_sql_query("SELECT * FROM JAX_mf_diag WHERE DIAGNOSIS=1", mydb)

In [None]:
#createDiagnosisTable(diagnosis, primary_diagnosis_only=True)
pd.read_sql_query("SELECT * FROM JAX_mf_diag", mydb)

In [None]:
indexDiagnosisTable()
pd.read_sql_query("SELECT * FROM JAX_mf_diag", mydb)

In [None]:
pd.read_sql_query('''
        SELECT DIAGNOSIS FROM JAX_mf_diag WHERE ROW_ID BETWEEN {} AND {}
    '''.format(1, 25), mydb).reset_index().DIAGNOSIS

In [79]:
start_index, end_index, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max = 0, 9, 7, 100, 7, 100

diagnosisFlat, textHpoFlat, labHpoFlat =  batch_query(start_index, end_index, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max)

print(diagnosisFlat)
print(textHpoFlat.head(14))
print(textHpoFlat.MAP_TO.values.reshape([9, 14], order='F'))
print(textHpoFlat.VALUE.values.reshape([9, 14], order='F'))

   SUBJECT_ID  HADM_ID DIAGNOSIS  ROW_ID
0           2   163353         0       1
1           3   145834         1       2
2           4   185777         0       3
3           5   178980         0       4
4           6   107064         0       5
5           7   118037         0       6
6           8   159514         0       7
7           9   150750         0       8
8          10   184167         0       9
    SUBJECT_ID  HADM_ID      MAP_TO  VALUE
0            2   163353  HP:0000118      0
1            3   145834  HP:0000118      1
2            4   185777  HP:0000118      1
3            5   178980  HP:0000118      0
4            6   107064  HP:0000118      1
5            7   118037  HP:0000118      0
6            8   159514  HP:0000118      1
7            9   150750  HP:0000118      1
8           10   184167  HP:0000118      1
9            2   163353  HP:0002088      0
10           3   145834  HP:0002088      1
11           4   185777  HP:0002088      1
12           5   178980  HP:000

In [73]:
print(labHpoFlat.head(14))
print(labHpoFlat.MAP_TO.values.reshape([9, 54], order='F'))
print(labHpoFlat.VALUE.values.reshape([9, 54], order='F'))

    SUBJECT_ID  HADM_ID      MAP_TO  VALUE
0            2   163353  HP:0003138      0
1            3   145834  HP:0003138      1
2            4   185777  HP:0003138      0
3            5   178980  HP:0003138      0
4            6   107064  HP:0003138      1
5            7   118037  HP:0003138      0
6            8   159514  HP:0003138      0
7            9   150750  HP:0003138      1
8           10   184167  HP:0003138      0
9            2   163353  HP:0001943      0
10           3   145834  HP:0001943      1
11           4   185777  HP:0001943      0
12           5   178980  HP:0001943      0
13           6   107064  HP:0001943      1
[['HP:0003138' 'HP:0001943' 'HP:0000118' 'HP:0001881' 'HP:0011277'
  'HP:0011014' 'HP:0031851' 'HP:0003074' 'HP:0002148' 'HP:0000001'
  'HP:0031970' 'HP:0010987' 'HP:0000079' 'HP:0004363' 'HP:0031850'
  'HP:0032251' 'HP:0000119' 'HP:0002901' 'HP:0010927' 'HP:0001877'
  'HP:0002715' 'HP:0020058' 'HP:0010929' 'HP:0004364' 'HP:0001871'
  'HP:0011893' 'HP:0

In [None]:
diagnosisTextHpo(phenotype='HP:0002086')
result = pd.read_sql_query("SELECT '{}' AS DIAGNOSIS_CODE, '{}' AS PHENOTYPE, DIAGNOSIS AS DIAGNOSIS_VALUE, PHEN_TXT_VALUE AS PHENOTYPE_VALUE, COUNT(*) AS N FROM JAX_mf_diag_textHpo GROUP BY DIAGNOSIS, PHEN_TXT_VALUE", mydb)
result.head()
#result.groupby(['DIAGNOSIS', 'PHEN_TXT_VALUE']).

In [None]:
diagnosisAllTextHpo(0, 100)
pd.read_sql_query("SELECT * FROM JAX_mf_diag_allTextHpo WHERE PHEN_TXT_VALUE = 1 LIMIT 5", mydb)
#result = pd.read_sql_query("SELECT '{}' AS DIAGNOSIS_CODE, PHEN_TXT AS PHENOTYPE, DIAGNOSIS AS DIAGNOSIS_VALUE, PHEN_TXT_VALUE AS PHENOTYPE_VALUE, COUNT(*) AS N FROM JAX_mf_diag_allTextHpo GROUP BY DIAGNOSIS, PHEN_TXT, PHEN_TXT_VALUE", mydb)
#result.groupby(['PHENOTYPE']).agg({'N':sum})
#result

## Mutual information between phenotypes in radiology and lab tests
Without considering diseases, we want to determine the mutual information between pairs of phenotypes from radiology reports and lab tests. The algorithm:

    * find phenotypes of interest from radiology reports
    * find phenotypes of interest from lab tests
    * in batches, create a matrix of text phenotype profiles and a matrix of lab phenotype profiles, update summary statistics

In [14]:
def batch_query_lab_text(start_index, end_index, textHpo_min, textHpo_max, labHpo_min, labHpo_max):
    
    textHpo_flat = pd.read_sql_query('''
        WITH encounters AS (
                SELECT *
                FROM JAX_encounterOfInterest
                WHERE ROW_ID BETWEEN {} AND {}),
            phenotypes AS (
                SELECT MAP_TO
                FROM JAX_textHpoFrequencyRank
                WHERE N BETWEEN {} AND {}
            ), 
            temp AS (
                SELECT * 
                FROM encounters 
                JOIN phenotypes)

            SELECT L.SUBJECT_ID, L.HADM_ID, L.MAP_TO AS PHEN_TEXT, IF(R.dummy IS NULL, 0, 1) AS PHEN_TEXT_VALUE
            FROM temp AS L
            LEFT JOIN 
                JAX_textHpoProfile AS R
            ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.MAP_TO = R.MAP_TO
        '''.format(start_index, end_index, textHpo_min, textHpo_max), mydb)
    
    labHpo_flat = pd.read_sql_query('''
        WITH encounters AS (
                SELECT *
                FROM JAX_encounterOfInterest
                WHERE ROW_ID BETWEEN {} AND {}),
            phenotypes AS (
                SELECT MAP_TO
                FROM JAX_labHpoFrequencyRank
                WHERE N BETWEEN {} AND {}
            ), 
            temp AS (
                SELECT * 
                FROM encounters 
                JOIN phenotypes)

            SELECT L.SUBJECT_ID, L.HADM_ID, L.MAP_TO AS PHEN_LAB, IF(R.dummy IS NULL, 0, 1) AS PHEN_LAB_VALUE
            FROM temp AS L
            LEFT JOIN 
                JAX_labHpoProfile AS R
            ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.MAP_TO = R.MAP_TO
        '''.format(start_index, end_index, labHpo_min, labHpo_max), mydb)
    
    return textHpo_flat, labHpo_flat 
    

def overall_mf(batch_size, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max):
    
    textHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_textHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(textHpo_threshold_min, textHpo_threshold_max), mydb).MAP_TO.values
    labHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_labHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(labHpo_threshold_min, labHpo_threshold_max), mydb).MAP_TO.values
    M1 = len(textHpoOfInterest)
    M2 = len(labHpoOfInterest)
    
    mf_rad_lab = mf.MutualInformationVectorized(textHpoOfInterest, labHpoOfInterest)
    
    ## find the start and end ROW_ID for patient*encounter
    
    ADM_ID_START, ADM_ID_END = pd.read_sql_query('SELECT MIN(ROW_ID) AS min, MAX(ROW_ID) AS max FROM JAX_encounterOfInterest', mydb).iloc[0]
    batch_N = ADM_ID_END - ADM_ID_START + 1
    TOTAL_BATCH = math.ceil(batch_N / batch_size) # total number of batches

    print('total batches: ' + str(batch_N))
    pbar = tqdm(total=TOTAL_BATCH)
    for i in np.arange(TOTAL_BATCH):
        start_index = i * batch_size + ADM_ID_START
        if i < TOTAL_BATCH - 1:
            end_index = start_index + batch_size - 1
        else:
            end_index = batch_N
        #print('start_index = {}, end_index = {}'.format(start_index, end_index))
        actual_batch_size = end_index - start_index + 1
        #print('actual batch size: {}'.format(actual_batch_size))
        textHpo, labHpo = batch_query_lab_text(start_index, end_index, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max)
        textHpo_matrix = textHpo.PHEN_TEXT_VALUE.values.astype(int).reshape([actual_batch_size, M1], order='F')
        labHpo_matrix = labHpo.PHEN_LAB_VALUE.values.astype(int).reshape([actual_batch_size, M2], order='F')
        #print(textHpo)
        #print(labHpo)
        #print(labHpo_matrix)
        mf_rad_lab.add_batch(textHpo_matrix, labHpo_matrix)
        pbar.update(1)
        
    pbar.close()
    
    return mf_rad_lab

In [12]:
# test

encounterOfInterest(debug=True)
indexEncounterOfInterest()
diagnosisProfile()
rankHpoFromText('')
rankHpoFromLab('')

batch_size = 11
textHpo_threshold_min = 45
textHpo_threshold_max = 65
labHpo_threshold_min = 75
labHpo_threshold_max = 85

mf_all = overall_mf(batch_size,textHpo_threshold_min,textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max)

print(mf_all.m)
#print(mf_all.X_names)
#print(mf_all.Y_names)

100%|██████████| 10/10 [00:00<00:00, 133.00it/s]

total batches: 100
start_index = 1, end_index = 11
start_index = 12, end_index = 22
start_index = 23, end_index = 33
start_index = 34, end_index = 44
start_index = 45, end_index = 55
start_index = 56, end_index = 66
start_index = 67, end_index = 77
start_index = 78, end_index = 88
start_index = 89, end_index = 99
start_index = 100, end_index = 100
[[[59.  5. 26. 10.]
  [58.  6. 26. 10.]
  [61.  3. 23. 13.]
  [53. 11. 28.  8.]
  [53. 11. 28.  8.]
  [53. 11. 28.  8.]
  [61.  3. 20. 16.]
  [53. 11. 28.  8.]
  [53. 11. 28.  8.]
  [46. 18. 30.  6.]
  [57.  7. 19. 17.]
  [58.  6. 18. 18.]
  [57.  7. 18. 18.]
  [57.  7. 18. 18.]]

 [[55.  8. 30.  7.]
  [57.  6. 27. 10.]
  [58.  5. 26. 11.]
  [51. 12. 30.  7.]
  [51. 12. 30.  7.]
  [51. 12. 30.  7.]
  [59.  4. 22. 15.]
  [51. 12. 30.  7.]
  [51. 12. 30.  7.]
  [46. 17. 30.  7.]
  [55.  8. 21. 16.]
  [56.  7. 20. 17.]
  [55.  8. 20. 17.]
  [55.  8. 20. 17.]]

 [[57.  5. 28. 10.]
  [56.  6. 28. 10.]
  [59.  3. 25. 13.]
  [51. 11. 30.  8.]
  [51.




In [15]:
# production
encounterOfInterest(debug=False)
indexEncounterOfInterest()
diagnosisProfile()
rankHpoFromText('')
rankHpoFromLab('')

batch_size = 100
textHpo_threshold_min = 500
textHpo_threshold_max = 100000
labHpo_threshold_min = 1000
labHpo_threshold_max = 100000


mf_all = overall_mf(batch_size,textHpo_threshold_min,textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max)

print(mf_all.m)
#print(mf_all.X_names)
#print(mf_all.Y_names)

  0%|          | 0/590 [00:00<?, ?it/s]

total batches: 58976


100%|██████████| 590/590 [14:22<00:00,  1.35s/it]

[[[5.4071e+04 7.5400e+02 3.7120e+03 4.3900e+02]
  [5.4071e+04 7.5400e+02 3.7120e+03 4.3900e+02]
  [5.3797e+04 1.0280e+03 3.5660e+03 5.8500e+02]
  ...
  [1.0940e+03 5.3731e+04 0.0000e+00 4.1510e+03]
  [1.0230e+03 5.3802e+04 3.0000e+00 4.1480e+03]
  [1.0230e+03 5.3802e+04 3.0000e+00 4.1480e+03]]

 [[5.4071e+04 7.5400e+02 3.7120e+03 4.3900e+02]
  [5.4071e+04 7.5400e+02 3.7120e+03 4.3900e+02]
  [5.3797e+04 1.0280e+03 3.5660e+03 5.8500e+02]
  ...
  [1.0940e+03 5.3731e+04 0.0000e+00 4.1510e+03]
  [1.0230e+03 5.3802e+04 3.0000e+00 4.1480e+03]
  [1.0230e+03 5.3802e+04 3.0000e+00 4.1480e+03]]

 [[4.8069e+04 5.0800e+02 9.7140e+03 6.8500e+02]
  [4.8069e+04 5.0800e+02 9.7140e+03 6.8500e+02]
  [4.7890e+04 6.8700e+02 9.4730e+03 9.2600e+02]
  ...
  [1.0270e+03 4.7550e+04 6.7000e+01 1.0332e+04]
  [9.5800e+02 4.7619e+04 6.8000e+01 1.0331e+04]
  [9.5800e+02 4.7619e+04 6.8000e+01 1.0331e+04]]

 ...

 [[5.1000e+02 0.0000e+00 5.7273e+04 1.1930e+03]
  [5.1000e+02 0.0000e+00 5.7273e+04 1.1930e+03]
  [5.1000e




In [17]:
mf_all=mf_all2
textHpo_labels = mf_all.X_names
labHpo_labels = mf_all.Y_names
M1 = len(mf_all.X_names)
M2 = len(mf_all.Y_names)
entropy_x, entropy_y = mf_all.entropies().values()
mf_all_df = pd.DataFrame(data={'P1': np.repeat(textHpo_labels, M2), 'P2': np.tile(labHpo_labels, [M1]),'entropy_P1': np.repeat(entropy_x, M2), 'entropy_P2': np.tile(entropy_y, [M1]), 'mf_P1_P2': mf_all.mf().flat})
mf_all_df.head()
mf_all_df.to_csv('mutual_info_textHpo_labHpo.csv', index=False)

In [18]:
mf_all_df.head()
mf_all_df.sort_values(by='mf_P1_P2', ascending=False).head()

Unnamed: 0,P1,P2,entropy_P1,entropy_P2,mf_P1_P2
8,HP:0000001,HP:0012337,0.367357,0.537873,0.191295
228,HP:0000118,HP:0012337,0.367357,0.537873,0.191295
230,HP:0000118,HP:0003111,0.367357,0.601487,0.173719
10,HP:0000001,HP:0003111,0.367357,0.601487,0.173719
236,HP:0000118,HP:0011015,0.367357,0.656426,0.159461


Prepare phenotypes for machine learning tasks. For specified disease of interest, find the patient info (gender, DOB), find the diagnosis info (case or control, date), and phenotype terms.  

In [None]:
encounterOfInterest(debug=False, N=100)
indexEncounterOfInterest()
diagnosisProfile()

In [None]:
patients = pd.read_sql_query('''
        SELECT 
            PATIENTS.SUBJECT_ID, PATIENTS.GENDER, PATIENTS.DOB
        FROM 
            PATIENTS
        WHERE 
            SUBJECT_ID IN (SELECT SUBJECT_ID FROM JAX_encounterOfInterest) 
    ''', mydb)

len(patients)

In [None]:
createDiagnosisTable(diagnosis='038', primary_diagnosis_only=False)

In [None]:
def first_diag_time(diagnosis):
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_first_diag_time')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_first_diag_time
        WITH diag_time AS (
            SELECT L.*, R.ADMITTIME 
            FROM JAX_mf_diag AS L
            JOIN ADMISSIONS AS R
            ON 
                L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID
        ), 
        overallDiagnosis AS (
            SELECT *, IF(SUM(DIAGNOSIS) OVER (PARTITION BY SUBJECT_ID) > 0, 1, 0) AS everDiagnosed
            FROM diag_time
        )
        
        SELECT 
            *, MIN(IF(everDiagnosed = 1, ADMITTIME, NULL) ) OVER (PARTITION BY SUBJECT_ID) AS first_diag, 
            MAX(IF(everDiagnosed = 0, ADMITTIME, NULL)) OVER (PARTITION BY SUBJECT_ID) AS last_visit_if_not_diagnosed
        FROM 
            overallDiagnosis  
    '''.format(diagnosis))

first_diag_time('038')
pd.read_sql_query('SELECT * FROM JAX_first_diag_time LIMIT 5', mydb)

In [None]:
def encountersAfterDiagnosis():
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_encounters_after_diagnosis')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_encounters_after_diagnosis
            SELECT *, 1 AS toIgnore
            FROM JAX_first_diag_time
            WHERE DIAGNOSIS = 1 AND ADMITTIME > first_diag
    ''')
    cursor.execute('CREATE INDEX JAX_encounters_after_diagnosis_idx01 ON JAX_encounters_after_diagnosis (SUBJECT_ID, HADM_ID)')
    
encountersAfterDiagnosis() 
pd.read_sql_query("SELECT * FROM JAX_encounters_after_diagnosis LIMIT 5", mydb)

In [None]:
diagnosis_vector = pd.read_sql_query('''
        SELECT
            SUBJECT_ID, IF(SUM(DIAGNOSIS)>0, 1, 0) AS DIAGNOSED, MAX(IF(everDiagnosed = 1, first_diag, last_visit_if_not_diagnosed)) AS LAST_VISIT
        FROM
            JAX_first_diag_time
        GROUP BY SUBJECT_ID
    ''', mydb)
diagnosis_vector.head()

In [None]:
def lab_phenotype_before_diagnosis():
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_phen_lab_before_diag')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_phen_lab_before_diag
        WITH temp as (
            SELECT L.*, W.toIgnore
            FROM JAX_LABHPOPROFILE AS L
            JOIN JAX_mf_diag AS R ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID
            LEFT JOIN JAX_encounters_after_diagnosis W on  L.SUBJECT_ID = W.SUBJECT_ID AND L.HADM_ID = W.HADM_ID
        )
        SELECT SUBJECT_ID, MAP_TO, COUNT(*) as N
        FROM temp
        WHERE toIgnore IS NULL
        GROUP BY SUBJECT_ID, MAP_TO
    ''')
    cursor.execute('CREATE INDEX JAX_phen_lab_before_diag_idx01 ON JAX_phen_lab_before_diag (N)')
    

def text_phenotype_before_diagnosis():
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_phen_text_before_diag')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_phen_text_before_diag
        WITH temp as (
            SELECT L.*, W.toIgnore
            FROM JAX_TEXTHPOPROFILE AS L
            JOIN JAX_mf_diag AS R ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID
            LEFT JOIN JAX_encounters_after_diagnosis W on  L.SUBJECT_ID = W.SUBJECT_ID AND L.HADM_ID = W.HADM_ID
        )
        SELECT SUBJECT_ID, MAP_TO, COUNT(*) as N
        FROM temp
        WHERE toIgnore IS NULL
        GROUP BY SUBJECT_ID, MAP_TO
    ''')
    cursor.execute('CREATE INDEX JAX_phen_text_before_diag_idx01 ON JAX_phen_text_before_diag (N)')

In [None]:
#lab_phenotype_before_diagnosis()
lab_phenotype_vector = pd.read_sql_query('''
    SELECT * FROM JAX_phen_lab_before_diag WHERE N >= {}
'''.format(1), mydb)
lab_phenotype_vector.head()

In [None]:
#text_phenotype_before_diagnosis()
text_phenotype_vector = pd.read_sql_query('''
    SELECT * FROM JAX_phen_text_before_diag WHERE N >= {}
'''.format(1), mydb)
text_phenotype_vector.head()

In [None]:
merged = lab_phenotype_vector.merge(text_phenotype_vector, on = ['SUBJECT_ID', 'MAP_TO'])
merged.head()

In [None]:
a = pd.DataFrame(data={'x': ['a', 'b', 'c'], 'v': [1, 2, 3]}).set_index('x')
b = pd.DataFrame(data={'x': ['a', 'c', 'd'], 'v': [4, 5, 6]}).set_index('x')
print(a)
print(b)
a


In [None]:
m = phenotype_vector.pivot_table(values='N', index='SUBJECT_ID', columns='MAP_TO', fill_value=0)

In [None]:
m.shape

In [None]:
m.head()

In [None]:
X = patients.merge(m, on = 'SUBJECT_ID', how = 'left').sort_values(by = 'SUBJECT_ID').set_index('SUBJECT_ID')
X = X.drop('DOB', axis=1).fillna(value=0)
X.head(n = 3)

In [None]:
y = diagnosis_vector.sort_values(by = 'SUBJECT_ID').set_index('SUBJECT_ID')
y.head()

In [None]:
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.compose import make_column_transformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler

In [None]:
categorical = X.dtypes == object
categorical

In [None]:
sex = pd.get_dummies(X.GENDER)
X = X.drop('GENDER', axis=1).merge(sex, left_index=True, right_index=True).drop('F', axis=1)
X.M = X.M.astype(float)
X.head()

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X, y)

In [None]:
clf = LogisticRegression()

In [None]:
clf.fit(X_train, y_train)

In [None]:
clf.score(X_test, y_test)

In [None]:
np.sum(y_test.IsDiagnosed)/len(y_test)

In [None]:
cursor.close()
mydb.close()