In [1]:
import mysql.connector
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import datetime
import math
import os
import sys
import logging
mf_module_path = os.path.abspath(os.path.join('../python'))
if mf_module_path not in sys.path:
    sys.path.append(mf_module_path)
import mf
import mf_random
import hpoutil
import networkx
import obonet
import pickle
from tqdm import tqdm
import time

**Connect to MySQL database**

In [2]:
mydb = mysql.connector.connect(host='localhost',
                               user='mimicuser',
                               passwd='mimic',
                               database='mimiciiiv13',
                              auth_plugin='mysql_native_password')

First approach to query mysql from python

Check that MySQL connection works properly

In [3]:
df = pd.read_sql_query("SELECT * FROM LABEVENTS LIMIT 5;", mydb)
df

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,ITEMID,CHARTTIME,VALUE,VALUENUM,VALUEUOM,FLAG
0,1,2,163353,51143,2138-07-17 20:48:00,0,0.0,%,
1,2,2,163353,51144,2138-07-17 20:48:00,0,0.0,%,
2,3,2,163353,51146,2138-07-17 20:48:00,0,0.0,%,
3,4,2,163353,51200,2138-07-17 20:48:00,0,0.0,%,
4,5,2,163353,51221,2138-07-17 20:48:00,0,0.0,%,abnormal


Get a cursor so that it can be used later

In [4]:
cursor = mydb.cursor(buffered=True)

We explored several method to compute the synergy score for different diseases. Method 1-3 all worked but the time and space requirements are too high. See the archived file. Here, we use method 4 to compute phenotype pairwise synergies.


# Synergy between Lab-derived Abnormalities

## Algorithm

This method relies on the power of MySQL for doing queies and joins, return a batch of phenotype profiles a time, and then use the power of Numpy to do numeric computation.

Specificially, the method runs the following algorithm:

    1. For one diagnosis code, specify the phenotypes to analyze--a list of HPO terms.
    2. For a batch of patient encounters, return a list of diagnosis codes (1 or 0)
    3. For the same batch of patient*encounters, return a list of phenotypes.
    4. Create a numpy array with dimension (N x P)
    5. Perform numeric computation with Numpy:
        outer product for ++ of PxP.T
        outer product for +- of Px(1-P).T
        outer product for -+ of (1-P)xP
        outer product for -- of (1-P)x(1-P).T
        combine the above with - and + of diagnosis value
        stack them together as a (N x P x P x 8) matrix.
        Step 1 - 5 are performed at each site. The resulting matrix is returned to JAX for final analyze.
    6. Compute pairwise synergy:
        use the multi-dimension array to calculate p(D = 1), p(D = 0), p(P1 * P2)
        compute mutual information of each phenotype in regarding to one diagnosis I(P:D)
        compute mutual information of two phenotypes in regarding to one diagnosis I(P:D)
        compute pairwise synergy
        

# Synergy between lab-derived and radiology report-derived Abnormalities

## Algorithm

The algorithm is about the same. Briefly, 

    * Select encounterOfInterest, temp table: JAX_encounterOfInterest(SUBJECT_ID, HADM_ID)
    * Init diagnosisProfile: temp table: JAX_diagnosisProfile(SUBJECT_ID, HADM_ID, ICD, N)
    * Init textHpoProfile: temp table: JAX_textHpoProfile(SUBJECT_ID, HADM_ID, MAP_TO, N)
    * Init labHpoProfile: temp table: JAX_labHpoProfile(SUBJECT_ID, HADM_ID, MAP_TO, N)
    
    * Rank ICD frequency, temp table: JAX_diagFrequencyRank(ICD, N)
      select diagOfInterest
    * Rank textHPO frequency, temp table: JAX_textHpoFrequencyRank(MAP_TO, N)
      select textHpoOfInterest
    * Rank labHPO frequency, temp table: JAX_labHpoFrequencyRank(MAP_TO, N)
      select labHpoOfInterest
    
    * Iteratation
      for diagnosis in diagOfInterest
          for textHpo in textHpoOfInterest
              for labHpo in labHpoOfInterest
                 Assign diagnosis value: assignDiagnosis(), table: (SUBJECT_ID, HADM_ID, DIAGNOSIS)
                 Assign text2hpo phenotype value: table: SUBJECT_ID, HADM_ID, PHEN_TEXT
                 Assign lab2hpo phenotype value: table: SUBJECT_ID, HADM_ID, PHEN_LAB

## Algorithm implemention
The mutual information theory and algorithm is described in the following paper:


    Anastassiou D, Computational analysis of the synergy among multiple
    interacting genes. Molecular System Biology 3:83

The algorithm is implemented in Python [link](https://github.com/TheJacksonLaboratory/MIMIC_HPO/blob/export_intermediate_data/src/main/python/mf.py). Python is chosen over Java because the numeric computation library (Numpy) in python is better (?) and easier to use (sure) than Java counterparts.

Methods defined below are to prepare MySql database, query data, format them and call the mutual information algorithm implementation. 

The results are summary statistics that do not contain any PHI. 






In [5]:
def encounterOfInterest(debug=False, N=100):
    """
    Define encounters of interest. The method is not finalized yet. Currently, it will use all encounters in our database. 
    @param debug: set to True to select a small subset for testing
    @param N: limit the number of encounters when debug is set to True. If debug is set to False, N is ignored.  
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_encounterOfInterest')
    if debug:
        limit = 'LIMIT {}'.format(N)
    else:
        limit = ''
    # This is admissions that we want to analyze, 'LIMIT 100' in debug mode
    cursor.execute('''
                CREATE TEMPORARY TABLE IF NOT EXISTS JAX_encounterOfInterest(
                    ROW_ID MEDIUMINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY)
                
                SELECT 
                    DISTINCT SUBJECT_ID, HADM_ID 
                FROM admissions
                {}
                '''.format(limit))
    
def indexEncounterOfInterest():
    """
    Create index on encounters table.
    """
    cursor.execute('CREATE INDEX JAX_encounterOfInterest_idx01 ON JAX_encounterOfInterest (SUBJECT_ID, HADM_ID)')
    
def diagnosisProfile():
    """
    For encounters of interest, find all of their diagnosis codes
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_diagnosisProfile')
    cursor.execute('''
                CREATE TEMPORARY TABLE IF NOT EXISTS JAX_diagnosisProfile
                SELECT 
                    DIAGNOSES_ICD.SUBJECT_ID, DIAGNOSES_ICD.HADM_ID, DIAGNOSES_ICD.ICD9_CODE, DIAGNOSES_ICD.SEQ_NUM
                FROM
                    DIAGNOSES_ICD
                RIGHT JOIN
                    JAX_encounterOfInterest
                ON 
                    DIAGNOSES_ICD.SUBJECT_ID = JAX_encounterOfInterest.SUBJECT_ID 
                    AND 
                    DIAGNOSES_ICD.HADM_ID = JAX_encounterOfInterest.HADM_ID
                ''')
    
def textHpoProfile(include_inferred=True):
    """
    Set up a table for patient phenotypes from text mining. By default, merge directly mapped HPO terms and inferred terms.
    It is currently defined as a temporary table. But in reality, it is created as a perminent table as it takes a long time to init, and it is going to be used multiple times. 
    """
    if include_inferred:
        cursor.execute('''
                    CREATE TEMPORARY TABLE IF NOT EXISTS JAX_textHpoProfile
                    WITH abnorm AS (
                        SELECT
                            NOTEEVENTS.SUBJECT_ID, NOTEEVENTS.HADM_ID, NoteHpoClinPhen.MAP_TO
                        FROM 
                            NOTEEVENTS 
                        JOIN NoteHpoClinPhen on NOTEEVENTS.ROW_ID = NoteHpoClinPhen.NOTES_ROW_ID
                        
                        UNION ALL
                        
                        SELECT
                            NOTEEVENTS.SUBJECT_ID, NOTEEVENTS.HADM_ID, Inferred_NoteHpo.INFERRED_TO AS MAP_TO
                        FROM 
                            NOTEEVENTS 
                        JOIN Inferred_NoteHpo on NOTEEVENTS.ROW_ID = Inferred_NoteHpo.NOTEEVENT_ROW_ID
                        )
                    SELECT SUBJECT_ID, HADM_ID, MAP_TO, COUNT(*) AS OCCURRANCE, 1 AS dummy
                    FROM abnorm 
                    GROUP BY SUBJECT_ID, HADM_ID, MAP_TO
                ''')
        
    else:
        cursor.execute('''
                    CREATE TEMPORARY TABLE IF NOT EXISTS JAX_p_text
                    WITH abnorm AS (
                        SELECT
                            NOTEEVENTS.SUBJECT_ID, NOTEEVENTS.HADM_ID, NoteHpoClinPhen.MAP_TO
                        FROM 
                            NOTEEVENTS 
                        JOIN NoteHpoClinPhen on NOTEEVENTS.ROW_ID = NoteHpoClinPhen.NOTES_ROW_ID)
                    SELECT SUBJECT_ID, HADM_ID, MAP_TO
                    FROM abnorm 
                    GROUP BY SUBJECT_ID, HADM_ID, MAP_TO, COUNT(*) AS OCCURRANCE, 1 AS dummy
                ''')
        
def indexTextHpoProfile():
    """
    Create indeces to speed up query
    """
    #_idx01 is unnecessary if _idx3 exists
    #cursor.execute('CREATE INDEX JAX_textHpoProfile_idx01 ON JAX_textHpoProfile (SUBJECT_ID, HADM_ID)')
    cursor.execute('CREATE INDEX JAX_textHpoProfile_idx02 ON JAX_textHpoProfile (MAP_TO);')
    cursor.execute('CREATE INDEX JAX_textHpoProfile_idx03 ON JAX_textHpoProfile (SUBJECT_ID, HADM_ID, MAP_TO)')
    cursor.execute('CREATE INDEX JAX_textHpoProfile_idx04 ON JAX_textHpoProfile (OCCURRANCE)')
    
def labHpoProfile(include_inferred=True):
    """
    Set up a table for lab tests-derived phenotypes. By default, also include phenotypes that are inferred from direct mapping.
    Similar to textHpoProfile, this could be created as a perminent table. 
    """
    cursor.execute('''DROP TEMPORARY TABLE IF EXISTS JAX_labHpoProfile''')
    if include_inferred:
        cursor.execute('''
                    CREATE TEMPORARY TABLE IF NOT EXISTS JAX_labHpoProfile
                    WITH abnorm AS (
                        SELECT
                            LABEVENTS.SUBJECT_ID, LABEVENTS.HADM_ID, LabHpo.MAP_TO
                        FROM 
                            LABEVENTS 
                        JOIN LabHpo on LABEVENTS.ROW_ID = LabHpo.ROW_ID
                        WHERE LabHpo.NEGATED = 'F'
                        
                        UNION ALL
                        
                        SELECT 
                            LABEVENTS.SUBJECT_ID, LABEVENTS.HADM_ID, INFERRED_LABHPO.INFERRED_TO AS MAP_TO 
                        FROM 
                            INFERRED_LABHPO 
                        JOIN 
                            LABEVENTS ON INFERRED_LABHPO.LABEVENT_ROW_ID = LABEVENTS.ROW_ID
                        )
                    SELECT SUBJECT_ID, HADM_ID, MAP_TO, COUNT(*) AS OCCURRANCE, 1 AS dummy
                    FROM abnorm 
                    GROUP BY SUBJECT_ID, HADM_ID, MAP_TO
                ''')
    else:       
        cursor.execute('''
                    CREATE TEMPORARY TABLE IF NOT EXISTS JAX_labHpoProfile
                    WITH abnorm AS (
                        SELECT
                            LABEVENTS.SUBJECT_ID, LABEVENTS.HADM_ID, LabHpo.MAP_TO
                        FROM 
                            LABEVENTS 
                        JOIN LabHpo on LABEVENTS.ROW_ID = LabHpo.ROW_ID
                        WHERE LabHpo.NEGATED = 'F')
                    SELECT SUBJECT_ID, HADM_ID, MAP_TO, COUNT(*) AS OCCURRANCE, 1 AS dummy
                    FROM abnorm 
                    GROUP BY SUBJECT_ID, HADM_ID, MAP_TO
                ''')

def indexLabHpoProfile():
    #_idx01 is not necessary if _idx3 exists
    #cursor.execute('CREATE INDEX JAX_labHpoProfile_idx01 ON JAX_labHpoProfile (SUBJECT_ID, HADM_ID)')
    cursor.execute('CREATE INDEX JAX_labHpoProfile_idx02 ON JAX_labHpoProfile (MAP_TO);')
    cursor.execute('CREATE INDEX JAX_labHpoProfile_idx03 ON JAX_labHpoProfile (SUBJECT_ID, HADM_ID, MAP_TO)')
    cursor.execute('CREATE INDEX JAX_labHpoProfile_idx04 ON JAX_labHpoProfile (OCCURRANCE)')
    
def rankICD():
    """
    Rank frequently seen ICD-9 codes (first three or four digits) among encounters of interest.
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_diagFrequencyRank')
    cursor.execute("""
        CREATE TEMPORARY TABLE IF NOT EXISTS JAX_diagFrequencyRank
        WITH JAX_temp_diag AS (
            SELECT DISTINCT SUBJECT_ID, HADM_ID, 
                CASE 
                    WHEN(ICD9_CODE LIKE 'V%') THEN SUBSTRING(ICD9_CODE, 1, 3) 
                    WHEN(ICD9_CODE LIKE 'E%') THEN SUBSTRING(ICD9_CODE, 1, 4) 
                ELSE 
                    SUBSTRING(ICD9_CODE, 1, 3) END AS ICD9_CODE 
            FROM JAX_diagnosisProfile)
        SELECT 
            ICD9_CODE, COUNT(*) AS N
        FROM
            JAX_temp_diag
        GROUP BY 
            ICD9_CODE
        ORDER BY N
        DESC
        """)

def rankHpoFromText(diagnosis, hpo_min_occurrence_per_encounter):
    """
    Rank frequently seen phenotypes (HPO term) from text mining among encounters of interest. 
    An encounter may have multiple occurrances of a phenotype term. A phenotype is called if its occurrance
    meets a minimum threshold. 
    @param hpo_min_occurrence_per_encounter: threshold for a phenotype abnormality to be called. Usually use 1. 
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_textHpoFrequencyRank')
    cursor.execute('''
            CREATE TEMPORARY TABLE JAX_textHpoFrequencyRank            
            WITH pd AS(
                SELECT 
                    JAX_textHpoProfile.*
                FROM 
                    JAX_textHpoProfile 
                JOIN (
                    SELECT 
                        DISTINCT SUBJECT_ID, HADM_ID
                    FROM 
                        JAX_diagnosisProfile 
                    WHERE 
                        ICD9_CODE LIKE '{}%') AS d
                ON 
                    JAX_textHpoProfile.SUBJECT_ID = d.SUBJECT_ID AND JAX_textHpoProfile.HADM_ID = d.HADM_ID
                WHERE 
                    OCCURRANCE >= {})
            SELECT 
                MAP_TO, COUNT(*) AS N, 1 AS PHENOTYPE
            FROM pd
            GROUP BY MAP_TO
            ORDER BY N DESC'''.format(diagnosis, hpo_min_occurrence_per_encounter))
    
def rankHpoFromLab(diagnosis, hpo_min_occurrence_per_encounter):
    """
    Rank frequently seen phenotypes (HPO term) from lab texts among encounters of interest. 
    An encounter may have multiple occurrances of a phenotype term, such as from lab tests that are frequently ordered.
    A phenotype is called if its occurrance meets a minimum threshold.
    @param hpo_min_occurrence_per_encounter: threshold for a phenotype abnormality to be called. 
    For example, if the parameter is set to 3, HP:0002153 Hyperkalemia is assigned iff three or more lab tests return higher than normal values for blood potassium concentrations
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_labHpoFrequencyRank')
    cursor.execute('''
            CREATE TEMPORARY TABLE JAX_labHpoFrequencyRank            
            WITH pd AS(
                SELECT 
                    JAX_labHpoProfile.*
                FROM 
                    JAX_labHpoProfile 
                JOIN (
                    SELECT 
                        DISTINCT SUBJECT_ID, HADM_ID
                    FROM 
                        JAX_diagnosisProfile 
                    WHERE 
                        ICD9_CODE LIKE '{}%') AS d
                ON 
                    JAX_labHpoProfile.SUBJECT_ID = d.SUBJECT_ID AND JAX_labHpoProfile.HADM_ID = d.HADM_ID
                WHERE
                    OCCURRANCE >= {})
            SELECT 
                MAP_TO, COUNT(*) AS N, 1 AS PHENOTYPE
            FROM pd
            GROUP BY MAP_TO
            ORDER BY N DESC'''.format(diagnosis, hpo_min_occurrence_per_encounter))           

## mutual information between radiology and lab

In [6]:
def createDiagnosisTable(diagnosis, primary_diagnosis_only):
    """
    Create a temporary table JAX_mf_diag. For encounters of interest, assign 0 or 1 to each encouter whether a diagnosis is observed.
    @param diagnosis: diagnosis code. An encounter is considered to be 1 if same or more detailed code is called. 
    @prarm primary_diagnosis_only: an encounter may be associated with one primary diagnosis and many secondary ones. 
    if value is set true, only primary diagnosis counts.  
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag')
    if primary_diagnosis_only:
        limit = 'AND SEQ_NUM=1'
    else:
        limit = ''
    cursor.execute('''
                CREATE TEMPORARY TABLE IF NOT EXISTS JAX_mf_diag 
                WITH 
                    d AS (
                        SELECT 
                            DISTINCT SUBJECT_ID, HADM_ID, '1' AS DIAGNOSIS
                        FROM 
                            JAX_diagnosisProfile 
                        WHERE ICD9_CODE LIKE '{}%' {})
                    -- This is encounters with positive diagnosis

                SELECT 
                    DISTINCT a.SUBJECT_ID, a.HADM_ID, IF(d.DIAGNOSIS IS NULL, '0', '1') AS DIAGNOSIS
                FROM 
                    JAX_encounterOfInterest AS a
                LEFT JOIN
                    d ON a.SUBJECT_ID = d.SUBJECT_ID AND a.HADM_ID = d.HADM_ID       
                /* -- This is the first join for diagnosis (0, or 1) */    
                '''.format(diagnosis, limit))
    cursor.execute('CREATE INDEX JAX_mf_diag_idx01 ON JAX_mf_diag (SUBJECT_ID, HADM_ID)')


def diagnosisTextHpo(phenotype):
    """
    Assign 0 or 1 to each encounter whether a phenotype is observed from radiology reports
    @phenotype: an HPO term id
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag_textHpo')
    """
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_mf_diag_textHpo
        SELECT 
            L.*, IF(R.MAP_TO IS NULL, '0', '1') AS PHEN_TXT
        FROM JAX_mf_diag AS L 
        LEFT JOIN 
            (SELECT * 
            FROM JAX_textHpoProfile 
            WHERE JAX_textHpoProfile.MAP_TO = '{}') AS R 
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID 
    '''.format(phenotype))
    """
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_mf_diag_textHpo
        WITH L AS (SELECT JAX_mf_diag.*, '{}' AS PHEN_TXT FROM JAX_mf_diag)
        SELECT 
            L.*, IF(R.dummy IS NULL, '0', '1') AS PHEN_TXT_VALUE
        FROM L 
        LEFT JOIN 
            JAX_textHpoProfile AS R
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.PHEN_TXT = R.MAP_TO
    '''.format(phenotype))
    cursor.execute('CREATE INDEX JAX_mf_diag_textHpo_idx01 ON JAX_mf_diag_textHpo (SUBJECT_ID, HADM_ID)')

def diagnosisAllTextHpo(threshold_min, threshold_max):
    """
    For phenotypes of interest, defined with two parameters, assign 0 or 1 to each encounter whether a phenotype is observated from text data
    @param threshold_min: minimum threshold of encounter count for a phenotype to be interesting. 
    @param threshold_max: maximum threshold of encounter count for a phenotype to be interesting. 
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag_allTextHpo')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_mf_diag_allTextHpo
        WITH 
            P AS (SELECT MAP_TO AS PHEN_TXT FROM JAX_textHpoFrequencyRank WHERE N BETWEEN {} AND {}),
            L AS (SELECT * FROM JAX_mf_diag JOIN P)
        SELECT 
            L.*, IF(R.dummy IS NULL, '0', '1') AS PHEN_TXT_VALUE
        FROM L 
        LEFT JOIN 
            JAX_textHpoProfile AS R
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.PHEN_TXT = R.MAP_TO
    '''.format(threshold_min, threshold_max))
    cursor.execute('CREATE INDEX JAX_mf_diag_allTextHpo_idx01 ON JAX_mf_diag_allTextHpo (SUBJECT_ID, HADM_ID, PHEN_TXT)')
    
def diagnosisLabHpo(phenotype):
    """
    Assign 0 or 1 to each encounter whether a phenotype is observed from lab tests
    @phenotype: an HPO term id
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag_labHpo')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_mf_diag_labHpo
        WITH L AS (SELECT JAX_mf_diag.*, '{}' AS PHEN_LAB FROM JAX_mf_diag)
        SELECT 
            L.*, IF(R.dummy IS NULL, '0', '1') AS PHEN_LAB_VALUE
        FROM L 
        LEFT JOIN 
             JAX_labHpoProfile AS R 
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.PHEN_LAB = R.MAP_TO
    '''.format(phenotype))
    cursor.execute('CREATE INDEX JAX_mf_diag_labHpo_idx01 ON JAX_mf_diag_labHpo (SUBJECT_ID, HADM_ID)')
    
def diagnosisAllLabHpo(threshold_min, threshold_max):
    """
    For phenotypes of interest, defined with two parameters, assign 0 or 1 to each encounter whether a phenotype is observated from lab tests
    @param threshold_min: minimum threshold of encounter count for a phenotype to be interesting. 
    @param threshold_max: maximum threshold of encounter count for a phenotype to be interesting. 
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag_allLabHpo')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_mf_diag_allLabHpo
        WITH 
            P AS (SELECT MAP_TO AS PHEN_LAB FROM JAX_labHpoFrequencyRank WHERE N BETWEEN {} AND {}),
            L AS (SELECT * FROM JAX_mf_diag JOIN P)
        SELECT 
            L.*, IF(R.dummy IS NULL, '0', '1') AS PHEN_LAB_VALUE
        FROM L 
        LEFT JOIN 
             JAX_labHpoProfile AS R 
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.PHEN_LAB = R.MAP_TO
    '''.format(threshold_min, threshold_max))
    cursor.execute('CREATE INDEX JAX_mf_diag_allLabHpo_idx01 ON JAX_mf_diag_allLabHpo (SUBJECT_ID, HADM_ID)')

def diagnosisTextLab(phenotype):
    """
    Merge temporary tables to create one in which each encounter is assigned with 0 or 1 for diagnosis code, phenotype from text data and phenotype from lab tests. 
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag_txtHpo_labHpo')
    result = cursor.execute('''
        CREATE TEMPORARY TABLE JAX_mf_diag_txtHpo_labHpo 
        WITH L AS (SELECT JAX_mf_diag_textHpo.*, '{}' AS PHEN_LAB FROM JAX_mf_diag_textHpo)
        SELECT L.*, IF(R.dummy IS NULL, '0', '1') AS PHEN_LAB_VALUE
        FROM L 
        LEFT JOIN 
            JAX_labHpoProfile AS R 
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.PHEN_LAB = R.MAP_TO
    '''.format(phenotype))
    
    
def diagnosisAllTextAllLab():
    """
    Merge temporary tables to create one in which each encounter is assigned with 0 or 1 for diagnosis code, all phenotypes of interest from text data, and all phenotypes from lab tests
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag_allTxtHpo_allLabHpo')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_mf_diag_allTxtHpo_allLabHpo 
        SELECT L.SUBJECT_ID, L.HADM_ID, L.DIAGNOSIS, L.PHEN_TXT, L.PHEN_TXT_VALUE, R.PHEN_LAB, R.PHEN_LAB_VALUE 
        FROM JAX_mf_diag_allTextHpo AS L 
        JOIN JAX_mf_diag_allLabHpo AS R
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID
    ''')
    

def initSummaryStatisticTables():
    """
    Init summary statistics tables.
    """
    # define empty columns to store summary statistics
    summary_statistics1_radiology = pd.DataFrame(data={'DIAGNOSIS_CODE':[], 
                       'PHENOTYPE':[], 
                       'DIAGNOSIS_VALUE':[], 
                       'PHENOTYPE_VALUE':[], 
                       'N':[]},
                columns = ['DIAGNOSIS_CODE', 'PHENOTYPE', 'DIAGNOSIS_VALUE', 'PHENOTYPE_VALUE', 'N'])
    
    summary_statistics1_lab = pd.DataFrame(data={'DIAGNOSIS_CODE':[], 
                       'PHENOTYPE':[], 
                       'DIAGNOSIS_VALUE':[], 
                       'PHENOTYPE_VALUE':[], 
                       'N':[]},
                columns = ['DIAGNOSIS_CODE', 'PHENOTYPE', 'DIAGNOSIS_VALUE', 'PHENOTYPE_VALUE', 'N'])

    summary_statistics2 = pd.DataFrame(data={'DIAGNOSIS_CODE':[], 
                       'PHEN_TXT':[], 
                       'PHEN_LAB':[], 
                       'DIAGNOSIS_VALUE':[], 
                       'PHEN_TXT_VALUE':[], 
                       'PHEN_LAB_VALUE':[], 
                       'N':[]},
                columns = ['DIAGNOSIS_CODE', 'PHEN_TXT', 'PHEN_LAB', 'DIAGNOSIS_VALUE', 'PHEN_TXT_VALUE', 'PHEN_LAB_VALUE', 'N']) 

    return summary_statistics1_radiology, summary_statistics1_lab, summary_statistics2

def initTables(debug=False):
    """
    This combines LabHpo and Inferred_LabHpo, and combines TextHpo and Inferred_TextHpo. 
    Only need to run once. For efficiency consideration, the tables can also be created as perminent. 
    It is time-consuming, so call it with caution. 
    """
    #init textHpoProfile and index it
    #I create perminant tables to save time; other users should enable them
    #textHpoProfile(include_inferred=True, threshold=1)
    #indexTextHpoProfile()
    #init labHpoProfile and index it
    #labHpoProfile(threshold=1, include_inferred=True, force_update=True)
    #indexLabHpoProfile()
    
    #define encounters to analyze
    encounterOfInterest(debug)
    indexEncounterOfInterest()
    #init diagnosisProfile
    diagnosisProfile()
    

def iterate(primary_diagnosis_only, diagnosis_threshold_min, textHpo_threshold_min, labHpo_threshold_min, logger): 
    logger.info('starting iterating...................................')
    N = pd.read_sql_query("SELECT count(*) FROM JAX_encounterOfInterest", mydb)
    # init empty tables to hold summary statistics
    summary_statistics1_radiology, summary_statistics1_lab, summary_statistics2 = initSummaryStatisticTables()
    
    # define a set of diseases that we want to analyze
    rankICD()
    
    diseaseOfInterest = pd.read_sql_query("SELECT * FROM JAX_diagFrequencyRank WHERE N > {}".format(diagnosis_threshold_min), mydb).ICD9_CODE.values
    diseaseOfInterest = ['428']
    # define encounters to analyze
    logger.info('diseases of interest established: {}'.format(len(diseaseOfInterest)))
    for diagnosis in diseaseOfInterest:
        logger.info("start analyzing disease {}".format(diagnosis))
        
        # assign each encounter whether a diagnosis code is observed
        # create a table j1 (joint 1)
        createDiagnosisTable(diagnosis, primary_diagnosis_only)
        # for every diagnosis, find phenotypes of interest to look at from radiology reports
        # for every diagnosis, find phenotypes of interest to look at from laboratory tests
        rankHpoFromText(diagnosis)
        rankHpoFromLab(diagnosis)
        
        textHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_textHpoFrequencyRank WHERE N > {}".format(textHpo_threshold_min), mydb).MAP_TO.values
        labHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_labHpoFrequencyRank WHERE N > {}".format(labHpo_threshold_min), mydb).MAP_TO.values
        logger.info("TextHpo of interest established, size: {}".format(len(textHpoOfInterest)))
        logger.info("LabHpo of interest established, size: {}".format(len(labHpoOfInterest)))
        for textHpo in textHpoOfInterest:
            logger.info("iteration: TextHpo--{}".format(textHpo))
            # assign each encounter whether a phenotype is observed from radiology reports
            diagnosisTextHpo(textHpo)            
            result1_text = pd.read_sql_query('''
                SELECT 
                    '{}' AS DIAGNOSIS_CODE, '{}' AS PHENOTYPE, DIAGNOSIS AS DIAGNOSIS_VALUE, PHEN_TXT_VALUE AS PHENOTYPE_VALUE, COUNT(*) AS N 
                FROM JAX_mf_diag_textHpo 
                GROUP BY 
                    DIAGNOSIS, PHEN_TXT_VALUE
            '''.format(diagnosis, textHpo), mydb)
            summary_statistics1_radiology = summary_statistics1_radiology.append(result1_text)
            # summary statistics for p1
            # calculate I(p1;D)
            for labHpo in labHpoOfInterest:
                logger.info(".........LabHpo--{}".format(labHpo))
                diagnosisLabHpo(labHpo)
                result1_lab = pd.read_sql_query('''
                    SELECT 
                        '{}' AS DIAGNOSIS_CODE, '{}' AS PHENOTYPE, DIAGNOSIS AS DIAGNOSIS_VALUE, PHEN_LAB_VALUE AS PHENOTYPE_VALUE, COUNT(*) AS N 
                    FROM 
                        JAX_mf_diag_labHpo 
                    GROUP BY DIAGNOSIS, PHEN_LAB_VALUE
                '''.format(diagnosis, labHpo), mydb)
                summary_statistics1_lab = summary_statistics1_lab.append(result1_lab)
            
                # assign each encounter whether a phenotype is observed from lab tests
                diagnosisTextLab(labHpo)
                result2 = pd.read_sql_query('''
                    SELECT 
                        '{}' AS DIAGNOSIS_CODE, 
                        '{}' AS PHEN_TXT, 
                        '{}' AS PHEN_LAB,  
                        DIAGNOSIS AS DIAGNOSIS_VALUE, 
                        PHEN_TXT_VALUE, 
                        PHEN_LAB_VALUE, 
                        COUNT(*) AS N
                    FROM JAX_mf_diag_txtHpo_labHpo 
                    GROUP BY DIAGNOSIS, PHEN_TXT_VALUE, PHEN_LAB_VALUE
                '''.format(diagnosis, textHpo, labHpo), mydb)
                summary_statistics2 = summary_statistics2.append(result2)
    logger.info('end iterating.....................................')            
    return N, summary_statistics1_radiology, summary_statistics1_lab, summary_statistics2 


def iterate_batch(primary_diagnosis_only, diagnosis_threshold_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max, logger): 
    logger.info('starting iterating...................................')
    N = pd.read_sql_query("SELECT count(*) FROM JAX_encounterOfInterest", mydb)
    # init empty tables to hold summary statistics
    summary_statistics1_radiology, summary_statistics1_lab, summary_statistics2 = initSummaryStatisticTables()
    
    # define a set of diseases that we want to analyze
    rankICD()
    
    diseaseOfInterest = pd.read_sql_query("SELECT * FROM JAX_diagFrequencyRank WHERE N > {}".format(diagnosis_threshold_min), mydb).ICD9_CODE.values
    diseaseOfInterest = ['428']
    logger.info('diseases of interest established: {}'.format(len(diseaseOfInterest)))
    
    for diagnosis in diseaseOfInterest:
        logger.info("start analyzing disease {}".format(diagnosis))
        
        logger.info(".......assigning values of diagnosis")
        # assign each encounter whether a diagnosis code is observed
        # create a table j1 (joint 1)
        createDiagnosisTable(diagnosis, primary_diagnosis_only)
        # for every diagnosis, find phenotypes of interest to look at from radiology reports
        # for every diagnosis, find phenotypes of interest to look at from laboratory tests
        rankHpoFromText(diagnosis)
        rankHpoFromLab(diagnosis)
        logger.info("..............diagnosis values found")
        
        logger.info(".......assigning values of TextHpo")
        diagnosisAllTextHpo(textHpo_threshold_min, textHpo_threshold_max)
        result1_text = pd.read_sql_query("""
            SELECT '{}' AS DIAGNOSIS_CODE, 
                PHEN_TXT AS PHENOTYPE, 
                DIAGNOSIS AS DIAGNOSIS_VALUE, 
                PHEN_TXT_VALUE AS PHENOTYPE_VALUE, 
                COUNT(*) AS N 
            FROM JAX_mf_diag_allTextHpo 
            GROUP BY DIAGNOSIS, PHEN_TXT, PHEN_TXT_VALUE
        """.format(diagnosis), mydb)
        logger.info("..............TextHpo values found")
        summary_statistics1_radiology = summary_statistics1_radiology.append(result1_text)

        
        logger.info(".......assigning values of LabHpo")
        diagnosisAllLabHpo(labHpo_threshold_min, labHpo_threshold_max)
        result1_lab = pd.read_sql_query("""
            SELECT 
                '{}' AS DIAGNOSIS_CODE, 
                PHEN_LAB AS PHENOTYPE, 
                DIAGNOSIS AS DIAGNOSIS_VALUE, 
                PHEN_LAB_VALUE AS PHENOTYPE_VALUE, 
                COUNT(*) AS N 
            FROM JAX_mf_diag_allLabHpo 
            GROUP BY DIAGNOSIS, PHEN_LAB, PHEN_LAB_VALUE
        """.format(diagnosis), mydb)
        logger.info("..............LabHpo values found")
        summary_statistics1_lab = summary_statistics1_lab.append(result1_lab)

        logger.info(".......building diagnosis-TextHpo-LabHpo joint distribution")
        diagnosisAllTextAllLab()
        result2 = pd.read_sql_query("""
            SELECT 
                '{}' AS DIAGNOSIS_CODE, 
                PHEN_TXT, 
                PHEN_LAB, 
                DIAGNOSIS AS DIAGNOSIS_VALUE,
                PHEN_TXT_VALUE, 
                PHEN_LAB_VALUE, 
                COUNT(*) AS N 
            FROM JAX_mf_diag_allTxtHpo_allLabHpo 
            GROUP BY DIAGNOSIS, PHEN_LAB, PHEN_LAB_VALUE, PHEN_TXT, PHEN_TXT_VALUE
        """.format(diagnosis) , mydb)
        logger.info("..............diagnosis-TextHpo-LabHpo joint distribution built")
        summary_statistics2 = summary_statistics2.append(result2)

    logger.info('end iterating.....................................')            
    return N, summary_statistics1_radiology, summary_statistics1_lab, summary_statistics2 

In [None]:
# how to run this
# it takes either too long or too much memory space to run
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# 1. build the temp tables for Lab converted HPO, Text convert HPO
# Read the comments within the method!
#initTables(debug=False)
# 2. iterate the database t (for debug, use parameter values: 0, 10, 15, for production, use parameter values: 0, 10000, 10000
#N, summary_statistics1_radiology, summary_statistics1_lab, summary_statistics2 = iterate(diagnosis_threshold_min=0, textHpo_threshold_min=10, labHpo_threshold_min=15, logger=logger)
#N, summary_statistics1_radiology, summary_statistics1_lab, summary_statistics2 = iterate(diagnosis_threshold_min=0, textHpo_threshold_min=1000, labHpo_threshold_min=1000, logger=logger)

# 2b. use the batch method
#N2, summary_statistics1_radiology2, summary_statistics1_lab2, summary_statistics22 = iterate_batch(diagnosis_threshold_min=0, textHpo_threshold_min=0, textHpo_threshold_max=100, labHpo_threshold_min=0, labHpo_threshold_max=100, logger=logger)
#N2, summary_statistics1_radiology2, summary_statistics1_lab2, summary_statistics22 = iterate_batch(diagnosis_threshold_min=0, textHpo_threshold_min=1000, textHpo_threshold_max=100000, labHpo_threshold_min=1000, labHpo_threshold_max=100000, logger=logger)

In [23]:
def indexDiagnosisTable():
    cursor.execute("ALTER TABLE JAX_mf_diag ADD COLUMN ROW_ID INT AUTO_INCREMENT PRIMARY KEY;")
    
def batch_query(start_index, end_index, textHpo_occurrance_min, labHpo_occurrance_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max):
    """
    Queries databases in small batches, return diagnosis values, phenotypes from text data and phenotypes from lab data.
    @param start_index: minimum row_id
    @param end_index: maximum row_id
    @param textHpo_occurrance_min: minimum occurrances of a phenotype from text data for it to be called in one encounter
    @param labHpo_occurrance_max: maximum occurrances of a phenotype from lab tests for it to be called in one encounter
    @param textHpo_threshold_min: minimum number of encounters of a phenotypes from text data for it to be analyzed
    @param textHpo_threshold_max: maximum number of encounters of a phenotypes from text data for it to be analyzed
    @param labHpo_threshold_min: minimum number of encounters of a phenotype from lab tests for it to be analyzed
    @param labHpo_threshold_max: maximum number of encounters of a phenotype from lab tests for it to be analyzed
    """
    diagnosisVector = pd.read_sql_query('''
        SELECT * FROM JAX_mf_diag WHERE ROW_ID BETWEEN {} AND {}
    '''.format(start_index, end_index), mydb)
    
    textHpoFlat = pd.read_sql_query('''
        WITH encounters AS (
            SELECT SUBJECT_ID, HADM_ID
            FROM JAX_mf_diag 
            WHERE ROW_ID BETWEEN {} AND {}
        ), 
        textHpoOfInterest AS (
            SELECT MAP_TO 
            FROM JAX_textHpoFrequencyRank 
            WHERE N BETWEEN {} AND {}
        ), 
        joint as (
            SELECT *
            FROM encounters 
            JOIN textHpoOfInterest),
        JAX_textHpoProfile_filtered AS (
            SELECT * 
            FROM JAX_textHpoProfile 
            WHERE OCCURRANCE >= {}
        )
        
        SELECT L.SUBJECT_ID, L.HADM_ID, L.MAP_TO, IF(R.dummy IS NULL, 0, 1) AS VALUE
        FROM joint as L
        LEFT JOIN 
        JAX_textHpoProfile_filtered AS R
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.MAP_TO = R.MAP_TO  
    '''.format(start_index, end_index, textHpo_threshold_min, textHpo_threshold_max, textHpo_occurrance_min), mydb)
    
    labHpoFlat = pd.read_sql_query('''
        WITH encounters AS (
            SELECT SUBJECT_ID, HADM_ID
            FROM JAX_mf_diag 
            WHERE ROW_ID BETWEEN {} AND {}
        ), 
        labHpoOfInterest AS (
            SELECT MAP_TO 
            FROM JAX_labHpoFrequencyRank 
            WHERE N BETWEEN {} AND {}
        ), 
        joint as (
            SELECT *
            FROM encounters 
            JOIN labHpoOfInterest),
        JAX_labHpoProfile_filtered AS (
            SELECT * 
            FROM JAX_labHpoProfile 
            WHERE OCCURRANCE >= {}
        )
        
        SELECT L.SUBJECT_ID, L.HADM_ID, L.MAP_TO, IF(R.dummy IS NULL, 0, 1) AS VALUE
        FROM joint as L
        LEFT JOIN 
        JAX_labHpoProfile_filtered AS R
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.MAP_TO = R.MAP_TO
    '''.format(start_index, end_index, labHpo_threshold_min, labHpo_threshold_max, labHpo_occurrance_min), mydb)
    
    return diagnosisVector, textHpoFlat, labHpoFlat

def summarize_diagnosis_textHpo_labHpo(primary_diagnosis_only, textHpo_occurrance_min, labHpo_occurrance_min, diagnosis_threshold_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max, logger):
    """
    Iterate database to get summary statistics. 
    
    @param primary_diagnosis_only: only primary diagnosis is analyzed
    @param textHpo_occurrance_min: minimum occurrances of a phenotype from text data for it to be called in one encounter
    @param labHpo_occurrance_max: maximum occurrances of a phenotype from lab tests for it to be called in one encounter
    @param textHpo_threshold_min: minimum number of encounters of a phenotypes from text data for it to be analyzed
    @param textHpo_threshold_max: maximum number of encounters of a phenotypes from text data for it to be analyzed
    @param labHpo_threshold_min: minimum number of encounters of a phenotype from lab tests for it to be analyzed
    @param labHpo_threshold_max: maximum number of encounters of a phenotype from lab tests for it to be analyzed
    @param logger: logger for logging
    
    """
    logger.info('starting iterate_in_batch()')
    batch_size = 100
    
    # define a set of diseases that we want to analyze
    rankICD()
    
    diseaseOfInterest = pd.read_sql_query("SELECT * FROM JAX_diagFrequencyRank WHERE N > {}".format(diagnosis_threshold_min), mydb).ICD9_CODE.values
    #disable the following line to analyze all diseases of interest
    diseaseOfInterest = ['428', '584', '038']
    logger.info('diagnosis of interest: {}'.format(len(diseaseOfInterest)))
    
    summaries_diag_textHpo_labHpo = {}
    summaries_diag_textHpo_textHpo = {}
    summaries_diag_labHpo_labHpo = {}
    
    pbar = tqdm(total=len(diseaseOfInterest))
    for diagnosis in diseaseOfInterest:
        logger.info("start analyzing disease {}".format(diagnosis))
        
        logger.info(".......assigning values of diagnosis")
        # assign each encounter whether a diagnosis code is observed
        # create a table j1 (joint 1)
        createDiagnosisTable(diagnosis, primary_diagnosis_only)
        indexDiagnosisTable()
        # for every diagnosis, find phenotypes of interest to look at from radiology reports
        # for every diagnosis, find phenotypes of interest to look at from laboratory tests
        rankHpoFromText(diagnosis, textHpo_occurrance_min)
        rankHpoFromLab(diagnosis, labHpo_occurrance_min)
        logger.info("..............diagnosis values found")
        
        textHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_textHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(textHpo_threshold_min, textHpo_threshold_max), mydb).MAP_TO.values
        labHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_labHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(labHpo_threshold_min, labHpo_threshold_max), mydb).MAP_TO.values
        logger.info("TextHpo of interest established, size: {}".format(len(textHpoOfInterest)))
        logger.info("LabHpo of interest established, size: {}".format(len(labHpoOfInterest)))

        ## find the start and end ROW_ID for patient*encounter
        ADM_ID_START, ADM_ID_END = pd.read_sql_query('SELECT MIN(ROW_ID) AS min, MAX(ROW_ID) AS max FROM JAX_mf_diag', mydb).iloc[0]
        batch_N = ADM_ID_END - ADM_ID_START + 1
        TOTAL_BATCH = math.ceil(batch_N / batch_size) # total number of batches
        
        summaries_diag_textHpo_labHpo[diagnosis] = mf.SummaryXYz(textHpoOfInterest, labHpoOfInterest, diagnosis)
        summaries_diag_textHpo_textHpo[diagnosis] = mf.SummaryXYz(textHpoOfInterest, textHpoOfInterest, diagnosis)
        summaries_diag_labHpo_labHpo[diagnosis] = mf.SummaryXYz(labHpoOfInterest, labHpoOfInterest, diagnosis)
        
        logger.info('starting batch queries for {}'.format(diagnosis))
        for i in np.arange(TOTAL_BATCH):
            start_index = i * batch_size + ADM_ID_START
            if i < TOTAL_BATCH - 1:
                end_index = start_index + batch_size - 1
            else:
                end_index = batch_N

            diagnosisFlat, textHpoFlat, labHpoFlat =  batch_query(start_index, end_index, textHpo_occurrance_min, labHpo_occurrance_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max)
            
            batch_size_actual = len(diagnosisFlat)
            textHpoOfInterest_size = len(textHpoOfInterest)
            labHpoOfInterest_size = len(labHpoOfInterest)
            #print('len(textHpoFlat)= {}, batch_size_actual={}, textHpoOfInterest_size={}'.format(len(textHpoFlat), batch_size_actual, textHpoOfInterest_size))
            assert(len(textHpoFlat) == batch_size_actual * textHpoOfInterest_size)
            assert(len(labHpoFlat) == batch_size_actual * labHpoOfInterest_size)
            
            if batch_size_actual > 0:
                diagnosisVector = diagnosisFlat.DIAGNOSIS.values.astype(int)
                # reformat the flat vector into N x M matrix, N is batch size, i.e. number of encounters, M is the length of HPO terms  
                textHpoMatrix = textHpoFlat.VALUE.values.astype(int).reshape([batch_size_actual, textHpoOfInterest_size], order='F')
                labHpoMatrix = labHpoFlat.VALUE.values.astype(int).reshape([batch_size_actual, labHpoOfInterest_size], order='F')
                # check the matrix formatting is correct
                # disable the following 4 lines to speed things up
                textHpoLabelsMatrix = textHpoFlat.MAP_TO.values.reshape([batch_size_actual, textHpoOfInterest_size], order='F')
                labHpoLabelsMatrix = labHpoFlat.MAP_TO.values.reshape([batch_size_actual, labHpoOfInterest_size], order='F')
                assert (textHpoLabelsMatrix[0, :] == textHpoOfInterest).all()
                assert (labHpoLabelsMatrix[0, :] == labHpoOfInterest).all()
                if i % 100 == 0:
                    logger.info('new batch: start_index={}, end_index={}, batch_size= {}, textHpo_size = {}, labHpo_size = {}'.format(start_index, end_index, batch_size_actual, textHpoMatrix.shape[1], labHpoMatrix.shape[1]))
                summaries_diag_textHpo_labHpo[diagnosis].add_batch(textHpoMatrix,labHpoMatrix, diagnosisVector)
                summaries_diag_textHpo_textHpo[diagnosis].add_batch(textHpoMatrix,textHpoMatrix, diagnosisVector)
                summaries_diag_labHpo_labHpo[diagnosis].add_batch(labHpoMatrix,labHpoMatrix, diagnosisVector)
         
        pbar.update(1)
        
    pbar.close()
    
    return summaries_diag_textHpo_labHpo, summaries_diag_textHpo_textHpo, summaries_diag_labHpo_labHpo

##  --TEST

In [24]:
# how to run this

# Again, it take either too long or too much memory space to run
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# 1. build the temp tables for Lab converted HPO, Text convert HPO
# Read the comments within the method!
initTables(debug=True)

# 2. iterate throw the dataset
primary_diagnosis_only = True
diagnosis_threshold_min = 5
textHpo_occurrance_min, labHpo_occurrance_min = 1, 3
textHpo_threshold_min, textHpo_threshold_max = 7, 100
labHpo_threshold_min, labHpo_threshold_max = 7, 100

summaries_diag_textHpo_labHpo, summaries_diag_textHpo_textHpo, summaries_diag_labHpo_labHpo = summarize_diagnosis_textHpo_labHpo(primary_diagnosis_only, textHpo_occurrance_min, labHpo_occurrance_min, diagnosis_threshold_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max, logger)

2019-10-21 14:38:58,960 - 6689 - root - INFO - starting iterate_in_batch()
2019-10-21 14:38:58,966 - 6689 - root - INFO - diagnosis of interest: 3


  0%|          | 0/3 [00:00<?, ?it/s]

2019-10-21 14:38:58,967 - 6689 - root - INFO - start analyzing disease 428
2019-10-21 14:38:58,968 - 6689 - root - INFO - .......assigning values of diagnosis
2019-10-21 14:38:58,997 - 6689 - root - INFO - ..............diagnosis values found
2019-10-21 14:38:59,002 - 6689 - root - INFO - TextHpo of interest established, size: 28
2019-10-21 14:38:59,002 - 6689 - root - INFO - LabHpo of interest established, size: 52
2019-10-21 14:38:59,005 - 6689 - root - INFO - starting batch queries for 428
2019-10-21 14:38:59,137 - 6689 - root - INFO - new batch: start_index=1, end_index=100, batch_size= 100, textHpo_size = 28, labHpo_size = 52


 33%|███▎      | 1/3 [00:00<00:00,  5.43it/s]

2019-10-21 14:38:59,152 - 6689 - root - INFO - start analyzing disease 584
2019-10-21 14:38:59,153 - 6689 - root - INFO - .......assigning values of diagnosis
2019-10-21 14:38:59,168 - 6689 - root - INFO - ..............diagnosis values found
2019-10-21 14:38:59,172 - 6689 - root - INFO - TextHpo of interest established, size: 9
2019-10-21 14:38:59,173 - 6689 - root - INFO - LabHpo of interest established, size: 29
2019-10-21 14:38:59,175 - 6689 - root - INFO - starting batch queries for 584
2019-10-21 14:38:59,241 - 6689 - root - INFO - new batch: start_index=1, end_index=100, batch_size= 100, textHpo_size = 9, labHpo_size = 29
2019-10-21 14:38:59,246 - 6689 - root - INFO - start analyzing disease 038
2019-10-21 14:38:59,247 - 6689 - root - INFO - .......assigning values of diagnosis
2019-10-21 14:38:59,269 - 6689 - root - INFO - ..............diagnosis values found
2019-10-21 14:38:59,274 - 6689 - root - INFO - TextHpo of interest established, size: 14
2019-10-21 14:38:59,274 - 6689 

100%|██████████| 3/3 [00:00<00:00,  6.23it/s]


In [28]:
summaries_diag_textHpo_labHpo['038'].m2.shape

(14, 25, 8)

## --PRODUCTION

In [32]:
# how to run this
# Again, it take either too long or too much memory space to run
logger = logging.getLogger()
logger.setLevel(logging.WARN)

# 1. build the temp tables for Lab converted HPO, Text convert HPO
# Read the comments within the method!
initTables(debug=False)

# 2. iterate throw the dataset
primary_diagnosis_only = False
diagnosis_threshold_min = 3000
textHpo_threshold_min, textHpo_threshold_max = 500, 100000
labHpo_threshold_min, labHpo_threshold_max = 1000, 100000
textHpo_occurrance_min, labHpo_occurrance_min = 1, 3

summaries_diag_textHpo_labHpo, summaries_diag_textHpo_textHpo, summaries_diag_labHpo_labHpo = summarize_diagnosis_textHpo_labHpo(primary_diagnosis_only, textHpo_occurrance_min, labHpo_occurrance_min, diagnosis_threshold_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max, logger)

100%|██████████| 3/3 [32:12<00:00, 641.69s/it]


In [33]:
if primary_diagnosis_only:
    fName_diag_textHpo_labHpo = '../../../data/mf_regarding_diseases/primary_only/summaries_diagnosis_textHpo_labHpo.obj'
    fName_diag_textHpo_textHpo = '../../../data/mf_regarding_diseases/primary_only/summaries_diagnosis_textHpo_textHpo.obj'
    fName_diag_labHpo_labHpo = '../../../data/mf_regarding_diseases/primary_only/summaries_diagnosis_labHpo_labHpo.obj'
else:
    fName_diag_textHpo_labHpo = '../../../data/mf_regarding_diseases/primary_and_secondary/summaries_diagnosis_textHpo_labHpo.obj'
    fName_diag_textHpo_textHpo = '../../../data/mf_regarding_diseases/primary_and_secondary/summaries_diagnosis_textHpo_textHpo.obj'
    fName_diag_labHpo_labHpo = '../../../data/mf_regarding_diseases/primary_and_secondary/summaries_diagnosis_labHpo_labHpo.obj'

with open(fName_diag_textHpo_labHpo, 'wb') as f:
    pickle.dump(summaries_diag_textHpo_labHpo, f)
with open(fName_diag_textHpo_textHpo, 'wb') as f:
    pickle.dump(summaries_diag_textHpo_textHpo, f)
with open(fName_diag_labHpo_labHpo, 'wb') as f:
    pickle.dump(summaries_diag_labHpo_labHpo, f)

In [31]:
summaries_diag_textHpo_labHpo

{'428': <mf.SummaryXYz at 0xa1aef1dd8>,
 '584': <mf.SummaryXYz at 0xa1ae5d358>,
 '038': <mf.SummaryXYz at 0x1083c4cf8>}

## mutual information among radiology phenotypes or among lab phenotypes

In [47]:
def summarize_intra_textHpo_or_intra_labHpo(primary_diagnosis_only, textHpo_occurrance_min, labHpo_occurrance_min, diagnosis_threshold_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max, logger):
    logger.info('starting iterate_in_batch()')
    batch_size = 100
    
    # define a set of diseases that we want to analyze
    rankICD()
    
    diseaseOfInterest = pd.read_sql_query("SELECT * FROM JAX_diagFrequencyRank WHERE N > {}".format(diagnosis_threshold_min), mydb).ICD9_CODE.values
    diseaseOfInterest = ['428', '584', '038']
    logger.info('diagnosis of interest: {}'.format(len(diseaseOfInterest)))
    
    summaries_text_text = {}
    summaries_lab_lab = {}
    
    pbar = tqdm(total=len(diseaseOfInterest))
    for diagnosis in diseaseOfInterest:
        logger.info("start analyzing disease {}".format(diagnosis))
        
        logger.info(".......assigning values of diagnosis")
        # assign each encounter whether a diagnosis code is observed
        # create a table j1 (joint 1)
        createDiagnosisTable(diagnosis, primary_diagnosis_only)
        indexDiagnosisTable()
        # for every diagnosis, find phenotypes of interest to look at from radiology reports
        # for every diagnosis, find phenotypes of interest to look at from laboratory tests
        rankHpoFromText(diagnosis, textHpo_occurrance_min)
        rankHpoFromLab(diagnosis, labHpo_occurrance_min)
        logger.info("..............diagnosis values found")
        
        textHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_textHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(textHpo_threshold_min, textHpo_threshold_max), mydb).MAP_TO.values
        labHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_labHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(labHpo_threshold_min, labHpo_threshold_max), mydb).MAP_TO.values
        logger.info("TextHpo of interest established, size: {}".format(len(textHpoOfInterest)))
        logger.info("LabHpo of interest established, size: {}".format(len(labHpoOfInterest)))

        ## find the start and end ROW_ID for patient*encounter
        ADM_ID_START, ADM_ID_END = pd.read_sql_query('SELECT MIN(ROW_ID) AS min, MAX(ROW_ID) AS max FROM JAX_mf_diag', mydb).iloc[0]
        batch_N = ADM_ID_END - ADM_ID_START + 1
        TOTAL_BATCH = math.ceil(batch_N / batch_size) # total number of batches
        
        summaries_text_text[diagnosis] = mf.SummaryXYz(textHpoOfInterest, textHpoOfInterest, diagnosis)
        summaries_lab_lab[diagnosis] = mf.SummaryXYz(labHpoOfInterest, labHpoOfInterest, diagnosis)
        
        logger.info('starting batch queries for {}'.format(diagnosis))
        for i in np.arange(TOTAL_BATCH):
            start_index = i * batch_size + ADM_ID_START
            if i < TOTAL_BATCH - 1:
                end_index = start_index + batch_size - 1
            else:
                end_index = batch_N

            diagnosisFlat, textHpoFlat, labHpoFlat =  batch_query(start_index, end_index, textHpo_occurrance_min, labHpo_occurrance_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max)

            batch_size_actual = len(diagnosisFlat)
            textHpoOfInterest_size = len(textHpoOfInterest)
            labHpoOfInterest_size = len(labHpoOfInterest)
            assert(len(textHpoFlat) == batch_size_actual * textHpoOfInterest_size)
            assert(len(labHpoFlat) == batch_size_actual * labHpoOfInterest_size)
            
            if batch_size_actual > 0:
                diagnosisVector = diagnosisFlat.DIAGNOSIS.values.astype(int)
                # reformat the flat vector into N x M matrix, N is batch size, i.e. number of encounters, M is the length of HPO terms  
                textHpoMatrix = textHpoFlat.VALUE.values.astype(int).reshape([batch_size_actual, textHpoOfInterest_size], order='F')
                labHpoMatrix = labHpoFlat.VALUE.values.astype(int).reshape([batch_size_actual, labHpoOfInterest_size], order='F')
                # check the matrix formatting is correct
                # disable the following 4 lines to speed things up
                textHpoLabelsMatrix = textHpoFlat.MAP_TO.values.reshape([batch_size_actual, textHpoOfInterest_size], order='F')
                labHpoLabelsMatrix = labHpoFlat.MAP_TO.values.reshape([batch_size_actual, labHpoOfInterest_size], order='F')
                assert (textHpoLabelsMatrix[0, :] == textHpoOfInterest).all()
                assert (labHpoLabelsMatrix[0, :] == labHpoOfInterest).all()
                if i % 100 == 0:
                    logger.info('new batch: start_index={}, end_index={}, batch_size= {}, textHpo_size = {}, labHpo_size = {}'.format(start_index, end_index, batch_size_actual, textHpoMatrix.shape[1], labHpoMatrix.shape[1]))
                
                summaries_text_text[diagnosis].add_batch(textHpoMatrix,textHpoMatrix, diagnosisVector)
                summaries_lab_lab[diagnosis].add_batch(labHpoMatrix, labHpoMatrix, diagnosisVector)
         
        pbar.update(1)
        
    pbar.close()
    
    return summaries_text_text, summaries_lab_lab

### Test

In [48]:
# how to run this
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# 1. build the temp tables for Lab converted HPO, Text convert HPO
# Read the comments within the method!
initTables(debug=True)

# 2. iterate throw the dataset
primary_diagnosis_only = True
diagnosis_threshold_min = 5
textHpo_threshold_min, textHpo_threshold_max = 7, 100
labHpo_threshold_min, labHpo_threshold_max = 7, 100
textHpo_occurrance_min, labHpo_occurrance_min = 1, 3

summaries_text_text, summaries_lab_lab = summarize_intra_textHpo_or_intra_labHpo(primary_diagnosis_only, textHpo_occurrance_min, labHpo_occurrance_min, diagnosis_threshold_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max, logger)

2019-10-16 15:43:03,917 - 2538 - root - INFO - starting iterate_in_batch()
2019-10-16 15:43:03,923 - 2538 - root - INFO - diagnosis of interest: 3



  0%|          | 0/3 [00:00<?, ?it/s][A

2019-10-16 15:43:03,926 - 2538 - root - INFO - start analyzing disease 428
2019-10-16 15:43:03,927 - 2538 - root - INFO - .......assigning values of diagnosis
2019-10-16 15:43:03,960 - 2538 - root - INFO - ..............diagnosis values found
2019-10-16 15:43:03,966 - 2538 - root - INFO - TextHpo of interest established, size: 28
2019-10-16 15:43:03,967 - 2538 - root - INFO - LabHpo of interest established, size: 52
2019-10-16 15:43:03,970 - 2538 - root - INFO - starting batch queries for 428
2019-10-16 15:43:04,111 - 2538 - root - INFO - new batch: start_index=1, end_index=100, batch_size= 100, textHpo_size = 28, labHpo_size = 52



 33%|███▎      | 1/3 [00:00<00:00,  5.01it/s][A

2019-10-16 15:43:04,128 - 2538 - root - INFO - start analyzing disease 584
2019-10-16 15:43:04,129 - 2538 - root - INFO - .......assigning values of diagnosis
2019-10-16 15:43:04,147 - 2538 - root - INFO - ..............diagnosis values found
2019-10-16 15:43:04,151 - 2538 - root - INFO - TextHpo of interest established, size: 9
2019-10-16 15:43:04,152 - 2538 - root - INFO - LabHpo of interest established, size: 29
2019-10-16 15:43:04,155 - 2538 - root - INFO - starting batch queries for 584
2019-10-16 15:43:04,225 - 2538 - root - INFO - new batch: start_index=1, end_index=100, batch_size= 100, textHpo_size = 9, labHpo_size = 29



 67%|██████▋   | 2/3 [00:00<00:00,  5.86it/s][A

2019-10-16 15:43:04,230 - 2538 - root - INFO - start analyzing disease 038
2019-10-16 15:43:04,231 - 2538 - root - INFO - .......assigning values of diagnosis
2019-10-16 15:43:04,255 - 2538 - root - INFO - ..............diagnosis values found
2019-10-16 15:43:04,260 - 2538 - root - INFO - TextHpo of interest established, size: 14
2019-10-16 15:43:04,261 - 2538 - root - INFO - LabHpo of interest established, size: 25
2019-10-16 15:43:04,264 - 2538 - root - INFO - starting batch queries for 038
2019-10-16 15:43:04,346 - 2538 - root - INFO - new batch: start_index=1, end_index=100, batch_size= 100, textHpo_size = 14, labHpo_size = 25



100%|██████████| 3/3 [00:00<00:00,  6.41it/s][A

### Production

In [37]:
# how to run this
logger = logging.getLogger()
logger.setLevel(logging.WARN)

# 1. build the temp tables for Lab converted HPO, Text convert HPO
# Read the comments within the method!
initTables(debug=False)

# 2. iterate throw the dataset
primary_diagnosis_only = True
diagnosis_threshold_min = 3000
textHpo_threshold_min, textHpo_threshold_max = 500, 100000
labHpo_threshold_min, labHpo_threshold_max = 1000, 100000
textHpo_occurrance_min, labHpo_occurrance_min = 1, 3

summaries_text_text, summaries_lab_lab = summarize_intra_textHpo_or_intra_labHpo(primary_diagnosis_only, textHpo_occurrance_min, labHpo_occurrance_min, diagnosis_threshold_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max, logger)


  0%|          | 0/3 [00:00<?, ?it/s][A
 33%|███▎      | 1/3 [10:00<20:01, 600.98s/it][A
 67%|██████▋   | 2/3 [21:15<10:23, 623.08s/it][A
100%|██████████| 3/3 [27:51<00:00, 555.06s/it][A

In [None]:
with open('summaries-intra-textHpo-primary_only.obj', 'wb') as summaries_file:
    pickle.dump(summaries_text_text, summaries_file)
    
with open('summaries-intra-labHpo-primary_only.obj', 'wb') as summaries_file:
    pickle.dump(summaries_lab_lab, summaries_file)

## Mutual information between phenotypes of radiology and lab tests regardless of diseases
Without considering diseases, we want to determine the mutual information between pairs of phenotypes from radiology reports and lab tests. The algorithm:

    * find phenotypes of interest from radiology reports
    * find phenotypes of interest from lab tests
    * in batches, create a matrix of text phenotype profiles and a matrix of lab phenotype profiles, update summary statistics

In [15]:
def batch_query_lab_text(start_index, end_index, textHpo_occurrance_min, labHpo_occurrance_min, textHpo_min, textHpo_max, labHpo_min, labHpo_max):
    
    textHpo_flat = pd.read_sql_query('''
        WITH encounters AS (
                SELECT *
                FROM JAX_encounterOfInterest
                WHERE ROW_ID BETWEEN {} AND {}),
            phenotypes AS (
                SELECT MAP_TO
                FROM JAX_textHpoFrequencyRank
                WHERE N BETWEEN {} AND {}
            ), 
            temp AS (
                SELECT * 
                FROM encounters 
                JOIN phenotypes)

            SELECT L.SUBJECT_ID, L.HADM_ID, L.MAP_TO AS PHEN_TEXT, IF(R.dummy IS NULL, 0, 1) AS PHEN_TEXT_VALUE
            FROM temp AS L
            LEFT JOIN 
                (SELECT * FROM JAX_textHpoProfile WHERE OCCURRANCE >= {}) AS R
            ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.MAP_TO = R.MAP_TO
        '''.format(start_index, end_index, textHpo_min, textHpo_max, textHpo_occurrance_min), mydb)
    
    labHpo_flat = pd.read_sql_query('''
        WITH encounters AS (
                SELECT *
                FROM JAX_encounterOfInterest
                WHERE ROW_ID BETWEEN {} AND {}),
            phenotypes AS (
                SELECT MAP_TO
                FROM JAX_labHpoFrequencyRank
                WHERE N BETWEEN {} AND {}
            ), 
            temp AS (
                SELECT * 
                FROM encounters 
                JOIN phenotypes)

            SELECT L.SUBJECT_ID, L.HADM_ID, L.MAP_TO AS PHEN_LAB, IF(R.dummy IS NULL, 0, 1) AS PHEN_LAB_VALUE
            FROM temp AS L
            LEFT JOIN 
                (SELECT * FROM JAX_labHpoProfile WHERE OCCURRANCE >= {}) AS R
            ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.MAP_TO = R.MAP_TO
        '''.format(start_index, end_index, labHpo_min, labHpo_max, labHpo_occurrance_min), mydb)
    
    return textHpo_flat, labHpo_flat 
    

def summary_textHpo_labHpo(batch_size, textHpo_occurrance_min, labHpo_occurrance_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max):
    
    textHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_textHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(textHpo_threshold_min, textHpo_threshold_max), mydb).MAP_TO.values
    labHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_labHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(labHpo_threshold_min, labHpo_threshold_max), mydb).MAP_TO.values
    M1 = len(textHpoOfInterest)
    M2 = len(labHpoOfInterest)
    
    summary_rad_lab = mf.SummaryXY(textHpoOfInterest, labHpoOfInterest)
    summary_rad_rad = mf.SummaryXY(textHpoOfInterest, textHpoOfInterest)
    summary_lab_lab = mf.SummaryXY(labHpoOfInterest, labHpoOfInterest)
    
    ## find the start and end ROW_ID for patient*encounter
    
    ADM_ID_START, ADM_ID_END = pd.read_sql_query('SELECT MIN(ROW_ID) AS min, MAX(ROW_ID) AS max FROM JAX_encounterOfInterest', mydb).iloc[0]
    batch_N = ADM_ID_END - ADM_ID_START + 1
    TOTAL_BATCH = math.ceil(batch_N / batch_size) # total number of batches

    print('total batches: ' + str(batch_N))
    pbar = tqdm(total=TOTAL_BATCH)
    for i in np.arange(TOTAL_BATCH):
        start_index = i * batch_size + ADM_ID_START
        if i < TOTAL_BATCH - 1:
            end_index = start_index + batch_size - 1
        else:
            end_index = batch_N
        actual_batch_size = end_index - start_index + 1
        textHpo, labHpo = batch_query_lab_text(start_index, end_index, textHpo_occurrance_min, labHpo_occurrance_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max)
        textHpo_matrix = textHpo.PHEN_TEXT_VALUE.values.astype(int).reshape([actual_batch_size, M1], order='F')
        labHpo_matrix = labHpo.PHEN_LAB_VALUE.values.astype(int).reshape([actual_batch_size, M2], order='F')
        summary_rad_lab.add_batch(textHpo_matrix, labHpo_matrix)
        summary_rad_rad.add_batch(textHpo_matrix, textHpo_matrix)
        summary_lab_lab.add_batch(labHpo_matrix, labHpo_matrix)
        pbar.update(1)
        
    pbar.close()
    
    return summary_rad_lab, summary_rad_rad, summary_lab_lab

### Test

In [16]:
encounterOfInterest(debug=True)
indexEncounterOfInterest()
diagnosisProfile()
rankHpoFromText('', hpo_min_occurrence_per_encounter=1)
rankHpoFromLab('', hpo_min_occurrence_per_encounter=3)

batch_size = 11
textHpo_threshold_min = 45
textHpo_threshold_max = 65
labHpo_threshold_min = 75
labHpo_threshold_max = 85
textHpo_occurrance_min = 1
labHpo_occurrance_min = 3

summary_rad_lab, summary_rad_rad, summary_lab_lab = summary_textHpo_labHpo(batch_size, textHpo_occurrance_min, labHpo_occurrance_min, textHpo_threshold_min,textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max)

print(summary_rad_lab.m.shape)
print(summary_rad_rad.m.shape)
print(summary_lab_lab.m.shape)

100%|██████████| 10/10 [00:00<00:00, 169.42it/s]

total batches: 100
(8, 2, 4)
(8, 8, 4)
(2, 2, 4)





### Production

In [18]:
encounterOfInterest(debug=False)
indexEncounterOfInterest()
diagnosisProfile()
rankHpoFromText('', hpo_min_occurrence_per_encounter=1)
rankHpoFromLab('', hpo_min_occurrence_per_encounter=3)

batch_size = 100
textHpo_threshold_min = 500
textHpo_threshold_max = 100000
labHpo_threshold_min = 1000
labHpo_threshold_max = 100000
textHpo_occurrance_min = 1
labHpo_occurrance_min = 3

summary_rad_lab, summary_rad_rad, summary_lab_lab = summary_textHpo_labHpo(batch_size, textHpo_occurrance_min, labHpo_occurrance_min, textHpo_threshold_min,textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max)


  0%|          | 0/590 [00:00<?, ?it/s]

total batches: 58976


100%|██████████| 590/590 [32:42<00:00,  3.08s/it]


In [20]:
with open('../../../data/mf_regardless_of_diseases/summary_textHpo_labHpo.obj', 'wb') as file:
    pickle.dump(summary_rad_lab, file)
    
with open('../../../data/mf_regardless_of_diseases/summary_textHpo_textHpo.obj', 'wb') as file:
    pickle.dump(summary_rad_rad, file)
    
with open('../../../data/mf_regardless_of_diseases/summary_labHpo_labHpo.obj', 'wb') as file:
    pickle.dump(summary_lab_lab, file)

In [None]:
mf_all=mf.MutualInfoXY(summary_all)
textHpo_labels = mf_all.X_names
labHpo_labels = mf_all.Y_names
M1 = len(mf_all.X_names)
M2 = len(mf_all.Y_names)
entropy_x, entropy_y = mf_all.entropies().values()
mf_all_df = pd.DataFrame(data={'P1': np.repeat(textHpo_labels, M2), 'P2': np.tile(labHpo_labels, [M1]),'entropy_P1': np.repeat(entropy_x, M2), 'entropy_P2': np.tile(entropy_y, [M1]), 'mf_P1_P2': mf_all.mf().flat})
mf_all_df.head()
mf_all_df.to_csv('mutual_info_textHpo_labHpo.csv', index=False)

In [None]:
mf_all_df.head()
mf_all_df.sort_values(by='mf_P1_P2', ascending=False).head()

TODO: move the following section to the machine learning notebook

## Feature generation for phenotype-based disease prediction

Prepare phenotypes for machine learning tasks. For specified disease of interest, find the patient info (gender, DOB), find the diagnosis info (case or control, date), and phenotype terms.  

In [None]:
encounterOfInterest(debug=False, N=100)
indexEncounterOfInterest()
diagnosisProfile()

In [None]:
patients = pd.read_sql_query('''
        SELECT 
            PATIENTS.SUBJECT_ID, PATIENTS.GENDER, PATIENTS.DOB
        FROM 
            PATIENTS
        WHERE 
            SUBJECT_ID IN (SELECT SUBJECT_ID FROM JAX_encounterOfInterest) 
    ''', mydb)

len(patients)

In [None]:
createDiagnosisTable(diagnosis='038', primary_diagnosis_only=False)

In [None]:
def first_diag_time(diagnosis):
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_first_diag_time')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_first_diag_time
        WITH diag_time AS (
            SELECT L.*, R.ADMITTIME 
            FROM JAX_mf_diag AS L
            JOIN ADMISSIONS AS R
            ON 
                L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID
        ), 
        overallDiagnosis AS (
            SELECT *, IF(SUM(DIAGNOSIS) OVER (PARTITION BY SUBJECT_ID) > 0, 1, 0) AS everDiagnosed
            FROM diag_time
        )
        
        SELECT 
            *, MIN(IF(everDiagnosed = 1, ADMITTIME, NULL) ) OVER (PARTITION BY SUBJECT_ID) AS first_diag, 
            MAX(IF(everDiagnosed = 0, ADMITTIME, NULL)) OVER (PARTITION BY SUBJECT_ID) AS last_visit_if_not_diagnosed
        FROM 
            overallDiagnosis  
    '''.format(diagnosis))

first_diag_time('038')
pd.read_sql_query('SELECT * FROM JAX_first_diag_time LIMIT 5', mydb)

In [None]:
def encountersAfterDiagnosis():
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_encounters_after_diagnosis')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_encounters_after_diagnosis
            SELECT *, 1 AS toIgnore
            FROM JAX_first_diag_time
            WHERE DIAGNOSIS = 1 AND ADMITTIME > first_diag
    ''')
    cursor.execute('CREATE INDEX JAX_encounters_after_diagnosis_idx01 ON JAX_encounters_after_diagnosis (SUBJECT_ID, HADM_ID)')
    
encountersAfterDiagnosis() 
pd.read_sql_query("SELECT * FROM JAX_encounters_after_diagnosis LIMIT 5", mydb)

In [None]:
diagnosis_vector = pd.read_sql_query('''
        SELECT
            SUBJECT_ID, IF(SUM(DIAGNOSIS)>0, 1, 0) AS DIAGNOSED, MAX(IF(everDiagnosed = 1, first_diag, last_visit_if_not_diagnosed)) AS LAST_VISIT
        FROM
            JAX_first_diag_time
        GROUP BY SUBJECT_ID
    ''', mydb)
diagnosis_vector.head()

In [None]:
def lab_phenotype_before_diagnosis():
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_phen_lab_before_diag')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_phen_lab_before_diag
        WITH temp as (
            SELECT L.*, W.toIgnore
            FROM JAX_LABHPOPROFILE AS L
            JOIN JAX_mf_diag AS R ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID
            LEFT JOIN JAX_encounters_after_diagnosis W on  L.SUBJECT_ID = W.SUBJECT_ID AND L.HADM_ID = W.HADM_ID
            WHERE L.OCCURRANCE >= 1
        )
        SELECT SUBJECT_ID, MAP_TO, COUNT(*) as N
        FROM temp
        WHERE toIgnore IS NULL
        GROUP BY SUBJECT_ID, MAP_TO
    ''')
    cursor.execute('CREATE INDEX JAX_phen_lab_before_diag_idx01 ON JAX_phen_lab_before_diag (N)')
    

def text_phenotype_before_diagnosis():
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_phen_text_before_diag')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_phen_text_before_diag
        WITH temp as (
            SELECT L.*, W.toIgnore
            FROM JAX_TEXTHPOPROFILE AS L
            JOIN JAX_mf_diag AS R ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID
            LEFT JOIN JAX_encounters_after_diagnosis W on  L.SUBJECT_ID = W.SUBJECT_ID AND L.HADM_ID = W.HADM_ID
        )
        SELECT SUBJECT_ID, MAP_TO, COUNT(*) as N
        FROM temp
        WHERE toIgnore IS NULL
        GROUP BY SUBJECT_ID, MAP_TO
    ''')
    cursor.execute('CREATE INDEX JAX_phen_text_before_diag_idx01 ON JAX_phen_text_before_diag (N)')

In [None]:
#lab_phenotype_before_diagnosis()
lab_phenotype_vector = pd.read_sql_query('''
    SELECT * FROM JAX_phen_lab_before_diag
'''.format(1), mydb)
lab_phenotype_vector.head()

In [None]:
text_phenotype_before_diagnosis()
text_phenotype_vector = pd.read_sql_query('''
    SELECT * FROM JAX_phen_text_before_diag WHERE N >= {}
'''.format(1), mydb)
text_phenotype_vector.head()

In [None]:
phenotypes = lab_phenotype_vector.merge(text_phenotype_vector, on = ['SUBJECT_ID', 'MAP_TO'], how = 'outer').fillna(value = 0)
phenotypes['N'] = phenotypes['N_x'] + phenotypes['N_y']
phenotypes.head()

In [None]:
phenotypes_matrix = merged.loc[:, ['SUBJECT_ID', 'MAP_TO', 'N']].pivot_table(values='N', index='SUBJECT_ID', columns='MAP_TO', fill_value=0)

In [None]:
df = patients.merge(diagnosis_vector, on = 'SUBJECT_ID')
df['AGE'] = (pd.to_datetime(df.LAST_VISIT, format='%Y-%m-%d %H:%M:%S').dt.year - pd.to_datetime(df.DOB, format='%Y-%m-%d %H:%M:%S').dt.year)
df.head()


In [None]:
df = df.loc[:, ['SUBJECT_ID', 'GENDER', 'AGE', 'DIAGNOSED']].merge(phenotypes_matrix, on = 'SUBJECT_ID')

In [None]:
df.head()

In [None]:
df.columns

In [None]:
df.to_csv('ml_df_038_primary_only.csv', index=False)

In [None]:
X = patients.merge(m, on = 'SUBJECT_ID', how = 'left').sort_values(by = 'SUBJECT_ID').set_index('SUBJECT_ID')
X = X.drop('DOB', axis=1).fillna(value=0)
X.head(n = 3)

In [None]:
y = diagnosis_vector.sort_values(by = 'SUBJECT_ID').set_index('SUBJECT_ID')
y.head()

In [None]:
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.compose import make_column_transformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler

In [None]:
categorical = X.dtypes == object
categorical

In [None]:
sex = pd.get_dummies(X.GENDER)
X = X.drop('GENDER', axis=1).merge(sex, left_index=True, right_index=True).drop('F', axis=1)
X.M = X.M.astype(float)
X.head()

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X, y)

In [None]:
clf = LogisticRegression()

In [None]:
clf.fit(X_train, y_train)

In [None]:
clf.score(X_test, y_test)

In [None]:
np.sum(y_test.IsDiagnosed)/len(y_test)

In [None]:
cursor.close()
mydb.close()