In [1]:
import mysql.connector
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import datetime
import math
import os
import sys
import logging
mf_module_path = os.path.abspath(os.path.join('../python'))
if mf_module_path not in sys.path:
    sys.path.append(mf_module_path)
import mf
import mf_random
import synergy_tree
from ontology import Ontology
import pickle
from tqdm import tqdm, tqdm_notebook
import time

**Connect to MySQL database**

In [2]:
mydb = mysql.connector.connect(host='localhost',
                               user='mimicuser',
                               passwd='mimic',
                               database='mimiciiiv13',
                              auth_plugin='mysql_native_password')

First approach to query mysql from python

Check that MySQL connection works properly

In [3]:
df = pd.read_sql_query("SELECT * FROM LABEVENTS LIMIT 5;", mydb)
df

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,ITEMID,CHARTTIME,VALUE,VALUENUM,VALUEUOM,FLAG
0,1,2,163353,51143,2138-07-17 20:48:00,0,0.0,%,
1,2,2,163353,51144,2138-07-17 20:48:00,0,0.0,%,
2,3,2,163353,51146,2138-07-17 20:48:00,0,0.0,%,
3,4,2,163353,51200,2138-07-17 20:48:00,0,0.0,%,
4,5,2,163353,51221,2138-07-17 20:48:00,0,0.0,%,abnormal


Get a cursor so that it can be used later

In [4]:
cursor = mydb.cursor(buffered=True)

We explored several method to compute the synergy score for different diseases. Method 1-3 all worked but the time and space requirements are too high. See the archived file. Here, we use method 4 to compute phenotype pairwise synergies.


# Synergy between Lab-derived Abnormalities

## Algorithm

This method relies on the power of MySQL for doing queies and joins, return a batch of phenotype profiles a time, and then use the power of Numpy to do numeric computation.

Specificially, the method runs the following algorithm:

    1. For one diagnosis code, specify the phenotypes to analyze--a list of HPO terms.
    2. For a batch of patient encounters, return a list of diagnosis codes (1 or 0)
    3. For the same batch of patient*encounters, return a list of phenotypes.
    4. Create a numpy array with dimension (N x P)
    5. Perform numeric computation with Numpy:
        outer product for ++ of PxP.T
        outer product for +- of Px(1-P).T
        outer product for -+ of (1-P)xP
        outer product for -- of (1-P)x(1-P).T
        combine the above with - and + of diagnosis value
        stack them together as a (N x P x P x 8) matrix.
        Step 1 - 5 are performed at each site. The resulting matrix is returned to JAX for final analyze.
    6. Compute pairwise synergy:
        use the multi-dimension array to calculate p(D = 1), p(D = 0), p(P1 * P2)
        compute mutual information of each phenotype in regarding to one diagnosis I(P:D)
        compute mutual information of two phenotypes in regarding to one diagnosis I(P:D)
        compute pairwise synergy
        

# Synergy between lab-derived and radiology report-derived Abnormalities

## Algorithm

The algorithm is about the same. Briefly, 

    * Select encounterOfInterest, temp table: JAX_encounterOfInterest(SUBJECT_ID, HADM_ID)
    * Init diagnosisProfile: temp table: JAX_diagnosisProfile(SUBJECT_ID, HADM_ID, ICD, N)
    * Init textHpoProfile: temp table: JAX_textHpoProfile(SUBJECT_ID, HADM_ID, MAP_TO, N)
    * Init labHpoProfile: temp table: JAX_labHpoProfile(SUBJECT_ID, HADM_ID, MAP_TO, N)
    
    * Rank ICD frequency, temp table: JAX_diagFrequencyRank(ICD, N)
      select diagOfInterest
    * Rank textHPO frequency, temp table: JAX_textHpoFrequencyRank(MAP_TO, N)
      select textHpoOfInterest
    * Rank labHPO frequency, temp table: JAX_labHpoFrequencyRank(MAP_TO, N)
      select labHpoOfInterest
    
    * Iteratation
      for diagnosis in diagOfInterest
          for textHpo in textHpoOfInterest
              for labHpo in labHpoOfInterest
                 Assign diagnosis value: assignDiagnosis(), table: (SUBJECT_ID, HADM_ID, DIAGNOSIS)
                 Assign text2hpo phenotype value: table: SUBJECT_ID, HADM_ID, PHEN_TEXT
                 Assign lab2hpo phenotype value: table: SUBJECT_ID, HADM_ID, PHEN_LAB

## Algorithm implemention
The mutual information theory and algorithm is described in the following paper:


    Anastassiou D, Computational analysis of the synergy among multiple
    interacting genes. Molecular System Biology 3:83

The algorithm is implemented in Python [link](https://github.com/TheJacksonLaboratory/MIMIC_HPO/blob/export_intermediate_data/src/main/python/mf.py). Python is chosen over Java because the numeric computation library (Numpy) in python is better (?) and easier to use (sure) than Java counterparts.

Methods defined below are to prepare MySql database, query data, format them and call the mutual information algorithm implementation. 

The results are summary statistics that do not contain any PHI. 






In [5]:
def encounterOfInterest(debug=False, N=100):
    """
    Define encounters of interest. The method is not finalized yet. Currently, it will use all encounters in our database. 
    @param debug: set to True to select a small subset for testing
    @param N: limit the number of encounters when debug is set to True. If debug is set to False, N is ignored.  
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_encounterOfInterest')
    if debug:
        limit = 'LIMIT {}'.format(N)
    else:
        limit = ''
    # This is admissions that we want to analyze, 'LIMIT 100' in debug mode
    cursor.execute('''
                CREATE TEMPORARY TABLE IF NOT EXISTS JAX_encounterOfInterest(
                    ROW_ID MEDIUMINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY)
                
                SELECT 
                    DISTINCT SUBJECT_ID, HADM_ID 
                FROM admissions
                {}
                '''.format(limit))
    
def indexEncounterOfInterest():
    """
    Create index on encounters table.
    """
    cursor.execute('CREATE INDEX JAX_encounterOfInterest_idx01 ON JAX_encounterOfInterest (SUBJECT_ID, HADM_ID)')
    
def diagnosisProfile():
    """
    For encounters of interest, find all of their diagnosis codes
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_diagnosisProfile')
    cursor.execute('''
                CREATE TEMPORARY TABLE IF NOT EXISTS JAX_diagnosisProfile
                SELECT 
                    DIAGNOSES_ICD.SUBJECT_ID, DIAGNOSES_ICD.HADM_ID, DIAGNOSES_ICD.ICD9_CODE, DIAGNOSES_ICD.SEQ_NUM
                FROM
                    DIAGNOSES_ICD
                RIGHT JOIN
                    JAX_encounterOfInterest
                ON 
                    DIAGNOSES_ICD.SUBJECT_ID = JAX_encounterOfInterest.SUBJECT_ID 
                    AND 
                    DIAGNOSES_ICD.HADM_ID = JAX_encounterOfInterest.HADM_ID
                ''')
    
def textHpoProfile(include_inferred=True):
    """
    Set up a table for patient phenotypes from text mining. By default, merge directly mapped HPO terms and inferred terms.
    It is currently defined as a temporary table. But in reality, it is created as a perminent table as it takes a long time to init, and it is going to be used multiple times. 
    """
    if include_inferred:
        cursor.execute('''
                    CREATE TEMPORARY TABLE IF NOT EXISTS JAX_textHpoProfile
                    WITH abnorm AS (
                        SELECT
                            NOTEEVENTS.SUBJECT_ID, NOTEEVENTS.HADM_ID, NoteHpoClinPhen.MAP_TO
                        FROM 
                            NOTEEVENTS 
                        JOIN NoteHpoClinPhen on NOTEEVENTS.ROW_ID = NoteHpoClinPhen.NOTES_ROW_ID
                        
                        UNION ALL
                        
                        SELECT
                            NOTEEVENTS.SUBJECT_ID, NOTEEVENTS.HADM_ID, Inferred_NoteHpo.INFERRED_TO AS MAP_TO
                        FROM 
                            NOTEEVENTS 
                        JOIN Inferred_NoteHpo on NOTEEVENTS.ROW_ID = Inferred_NoteHpo.NOTEEVENT_ROW_ID
                        )
                    SELECT SUBJECT_ID, HADM_ID, MAP_TO, COUNT(*) AS OCCURRANCE, 1 AS dummy
                    FROM abnorm 
                    GROUP BY SUBJECT_ID, HADM_ID, MAP_TO
                ''')
        
    else:
        cursor.execute('''
                    CREATE TEMPORARY TABLE IF NOT EXISTS JAX_p_text
                    WITH abnorm AS (
                        SELECT
                            NOTEEVENTS.SUBJECT_ID, NOTEEVENTS.HADM_ID, NoteHpoClinPhen.MAP_TO
                        FROM 
                            NOTEEVENTS 
                        JOIN NoteHpoClinPhen on NOTEEVENTS.ROW_ID = NoteHpoClinPhen.NOTES_ROW_ID)
                    SELECT SUBJECT_ID, HADM_ID, MAP_TO
                    FROM abnorm 
                    GROUP BY SUBJECT_ID, HADM_ID, MAP_TO, COUNT(*) AS OCCURRANCE, 1 AS dummy
                ''')
        
def indexTextHpoProfile():
    """
    Create indeces to speed up query
    """
    #_idx01 is unnecessary if _idx3 exists
    #cursor.execute('CREATE INDEX JAX_textHpoProfile_idx01 ON JAX_textHpoProfile (SUBJECT_ID, HADM_ID)')
    cursor.execute('CREATE INDEX JAX_textHpoProfile_idx02 ON JAX_textHpoProfile (MAP_TO);')
    cursor.execute('CREATE INDEX JAX_textHpoProfile_idx03 ON JAX_textHpoProfile (SUBJECT_ID, HADM_ID, MAP_TO)')
    cursor.execute('CREATE INDEX JAX_textHpoProfile_idx04 ON JAX_textHpoProfile (OCCURRANCE)')
    
def labHpoProfile(include_inferred=True):
    """
    Set up a table for lab tests-derived phenotypes. By default, also include phenotypes that are inferred from direct mapping.
    Similar to textHpoProfile, this could be created as a perminent table. 
    """
    cursor.execute('''DROP TEMPORARY TABLE IF EXISTS JAX_labHpoProfile''')
    if include_inferred:
        cursor.execute('''
                    CREATE TEMPORARY TABLE IF NOT EXISTS JAX_labHpoProfile
                    WITH abnorm AS (
                        SELECT
                            LABEVENTS.SUBJECT_ID, LABEVENTS.HADM_ID, LabHpo.MAP_TO
                        FROM 
                            LABEVENTS 
                        JOIN LabHpo on LABEVENTS.ROW_ID = LabHpo.ROW_ID
                        WHERE LabHpo.NEGATED = 'F'
                        
                        UNION ALL
                        
                        SELECT 
                            LABEVENTS.SUBJECT_ID, LABEVENTS.HADM_ID, INFERRED_LABHPO.INFERRED_TO AS MAP_TO 
                        FROM 
                            INFERRED_LABHPO 
                        JOIN 
                            LABEVENTS ON INFERRED_LABHPO.LABEVENT_ROW_ID = LABEVENTS.ROW_ID
                        )
                    SELECT SUBJECT_ID, HADM_ID, MAP_TO, COUNT(*) AS OCCURRANCE, 1 AS dummy
                    FROM abnorm 
                    GROUP BY SUBJECT_ID, HADM_ID, MAP_TO
                ''')
    else:       
        cursor.execute('''
                    CREATE TEMPORARY TABLE IF NOT EXISTS JAX_labHpoProfile
                    WITH abnorm AS (
                        SELECT
                            LABEVENTS.SUBJECT_ID, LABEVENTS.HADM_ID, LabHpo.MAP_TO
                        FROM 
                            LABEVENTS 
                        JOIN LabHpo on LABEVENTS.ROW_ID = LabHpo.ROW_ID
                        WHERE LabHpo.NEGATED = 'F')
                    SELECT SUBJECT_ID, HADM_ID, MAP_TO, COUNT(*) AS OCCURRANCE, 1 AS dummy
                    FROM abnorm 
                    GROUP BY SUBJECT_ID, HADM_ID, MAP_TO
                ''')

def indexLabHpoProfile():
    #_idx01 is not necessary if _idx3 exists
    #cursor.execute('CREATE INDEX JAX_labHpoProfile_idx01 ON JAX_labHpoProfile (SUBJECT_ID, HADM_ID)')
    cursor.execute('CREATE INDEX JAX_labHpoProfile_idx02 ON JAX_labHpoProfile (MAP_TO);')
    cursor.execute('CREATE INDEX JAX_labHpoProfile_idx03 ON JAX_labHpoProfile (SUBJECT_ID, HADM_ID, MAP_TO)')
    cursor.execute('CREATE INDEX JAX_labHpoProfile_idx04 ON JAX_labHpoProfile (OCCURRANCE)')
    
def rankICD():
    """
    Rank frequently seen ICD-9 codes (first three or four digits) among encounters of interest.
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_diagFrequencyRank')
    cursor.execute("""
        CREATE TEMPORARY TABLE IF NOT EXISTS JAX_diagFrequencyRank
        WITH JAX_temp_diag AS (
            SELECT DISTINCT SUBJECT_ID, HADM_ID, 
                CASE 
                    WHEN(ICD9_CODE LIKE 'V%') THEN SUBSTRING(ICD9_CODE, 1, 3) 
                    WHEN(ICD9_CODE LIKE 'E%') THEN SUBSTRING(ICD9_CODE, 1, 4) 
                ELSE 
                    SUBSTRING(ICD9_CODE, 1, 3) END AS ICD9_CODE 
            FROM JAX_diagnosisProfile)
        SELECT 
            ICD9_CODE, COUNT(*) AS N
        FROM
            JAX_temp_diag
        GROUP BY 
            ICD9_CODE
        ORDER BY N
        DESC
        """)

def rankHpoFromText(diagnosis, hpo_min_occurrence_per_encounter):
    """
    Rank frequently seen phenotypes (HPO term) from text mining among encounters of interest. 
    An encounter may have multiple occurrances of a phenotype term. A phenotype is called if its occurrance
    meets a minimum threshold. 
    @param hpo_min_occurrence_per_encounter: threshold for a phenotype abnormality to be called. Usually use 1. 
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_textHpoFrequencyRank')
    cursor.execute('''
            CREATE TEMPORARY TABLE JAX_textHpoFrequencyRank            
            WITH pd AS(
                SELECT 
                    JAX_textHpoProfile.*
                FROM 
                    JAX_textHpoProfile 
                JOIN (
                    SELECT 
                        DISTINCT SUBJECT_ID, HADM_ID
                    FROM 
                        JAX_diagnosisProfile 
                    WHERE 
                        ICD9_CODE LIKE '{}%') AS d
                ON 
                    JAX_textHpoProfile.SUBJECT_ID = d.SUBJECT_ID AND JAX_textHpoProfile.HADM_ID = d.HADM_ID
                WHERE 
                    OCCURRANCE >= {})
            SELECT 
                MAP_TO, COUNT(*) AS N, 1 AS PHENOTYPE
            FROM pd
            GROUP BY MAP_TO
            ORDER BY N DESC'''.format(diagnosis, hpo_min_occurrence_per_encounter))
    
def rankHpoFromLab(diagnosis, hpo_min_occurrence_per_encounter):
    """
    Rank frequently seen phenotypes (HPO term) from lab texts among encounters of interest. 
    An encounter may have multiple occurrances of a phenotype term, such as from lab tests that are frequently ordered.
    A phenotype is called if its occurrance meets a minimum threshold.
    @param hpo_min_occurrence_per_encounter: threshold for a phenotype abnormality to be called. 
    For example, if the parameter is set to 3, HP:0002153 Hyperkalemia is assigned iff three or more lab tests return higher than normal values for blood potassium concentrations
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_labHpoFrequencyRank')
    cursor.execute('''
            CREATE TEMPORARY TABLE JAX_labHpoFrequencyRank            
            WITH pd AS(
                SELECT 
                    JAX_labHpoProfile.*
                FROM 
                    JAX_labHpoProfile 
                JOIN (
                    SELECT 
                        DISTINCT SUBJECT_ID, HADM_ID
                    FROM 
                        JAX_diagnosisProfile 
                    WHERE 
                        ICD9_CODE LIKE '{}%') AS d
                ON 
                    JAX_labHpoProfile.SUBJECT_ID = d.SUBJECT_ID AND JAX_labHpoProfile.HADM_ID = d.HADM_ID
                WHERE
                    OCCURRANCE >= {})
            SELECT 
                MAP_TO, COUNT(*) AS N, 1 AS PHENOTYPE
            FROM pd
            GROUP BY MAP_TO
            ORDER BY N DESC'''.format(diagnosis, hpo_min_occurrence_per_encounter))  
    
def createDiagnosisTable(diagnosis, primary_diagnosis_only):
    """
    Create a temporary table JAX_mf_diag. For encounters of interest, assign 0 or 1 to each encouter whether a diagnosis is observed.
    @param diagnosis: diagnosis code. An encounter is considered to be 1 if same or more detailed code is called. 
    @prarm primary_diagnosis_only: an encounter may be associated with one primary diagnosis and many secondary ones. 
    if value is set true, only primary diagnosis counts.  
    """
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_mf_diag')
    if primary_diagnosis_only:
        limit = 'AND SEQ_NUM=1'
    else:
        limit = ''
    cursor.execute('''
                CREATE TEMPORARY TABLE IF NOT EXISTS JAX_mf_diag 
                WITH 
                    d AS (
                        SELECT 
                            DISTINCT SUBJECT_ID, HADM_ID, '1' AS DIAGNOSIS
                        FROM 
                            JAX_diagnosisProfile 
                        WHERE ICD9_CODE LIKE '{}%' {})
                    -- This is encounters with positive diagnosis

                SELECT 
                    DISTINCT a.SUBJECT_ID, a.HADM_ID, IF(d.DIAGNOSIS IS NULL, '0', '1') AS DIAGNOSIS
                FROM 
                    JAX_encounterOfInterest AS a
                LEFT JOIN
                    d ON a.SUBJECT_ID = d.SUBJECT_ID AND a.HADM_ID = d.HADM_ID       
                /* -- This is the first join for diagnosis (0, or 1) */    
                '''.format(diagnosis, limit))
    cursor.execute('CREATE INDEX JAX_mf_diag_idx01 ON JAX_mf_diag (SUBJECT_ID, HADM_ID)')


def initTables(debug=False):
    """
    This combines LabHpo and Inferred_LabHpo, and combines TextHpo and Inferred_TextHpo. 
    Only need to run once. For efficiency consideration, the tables can also be created as perminent. 
    It is time-consuming, so call it with caution. 
    """
    #init textHpoProfile and index it
    #I created perminant tables to save time; other users should enable them
    #textHpoProfile(include_inferred=True, threshold=1)
    #indexTextHpoProfile()
    #init labHpoProfile and index it
    #labHpoProfile(threshold=1, include_inferred=True, force_update=True)
    #indexLabHpoProfile()
    
    #define encounters to analyze
    encounterOfInterest(debug)
    indexEncounterOfInterest()
    #init diagnosisProfile
    diagnosisProfile()
   

In [6]:
def indexDiagnosisTable():
    cursor.execute("ALTER TABLE JAX_mf_diag ADD COLUMN ROW_ID INT AUTO_INCREMENT PRIMARY KEY;")
    
def batch_query(start_index, 
                end_index, 
                textHpo_occurrance_min, 
                labHpo_occurrance_min, 
                textHpo_threshold_min, 
                textHpo_threshold_max, 
                labHpo_threshold_min, 
                labHpo_threshold_max):
    """
    Queries databases in small batches, return diagnosis values, phenotypes from text data and phenotypes from lab data.
    @param start_index: minimum row_id
    @param end_index: maximum row_id
    @param textHpo_occurrance_min: minimum occurrances of a phenotype from text data for it to be called in one encounter
    @param labHpo_occurrance_max: maximum occurrances of a phenotype from lab tests for it to be called in one encounter
    @param textHpo_threshold_min: minimum number of encounters of a phenotypes from text data for it to be analyzed
    @param textHpo_threshold_max: maximum number of encounters of a phenotypes from text data for it to be analyzed
    @param labHpo_threshold_min: minimum number of encounters of a phenotype from lab tests for it to be analyzed
    @param labHpo_threshold_max: maximum number of encounters of a phenotype from lab tests for it to be analyzed
    """
    diagnosisVector = pd.read_sql_query('''
        SELECT * FROM JAX_mf_diag WHERE ROW_ID BETWEEN {} AND {}
    '''.format(start_index, end_index), mydb)
    
    textHpoFlat = pd.read_sql_query('''
        WITH encounters AS (
            SELECT SUBJECT_ID, HADM_ID
            FROM JAX_mf_diag 
            WHERE ROW_ID BETWEEN {} AND {}
        ), 
        textHpoOfInterest AS (
            SELECT MAP_TO 
            FROM JAX_textHpoFrequencyRank 
            WHERE N BETWEEN {} AND {}
        ), 
        joint as (
            SELECT *
            FROM encounters 
            JOIN textHpoOfInterest),
        JAX_textHpoProfile_filtered AS (
            SELECT * 
            FROM JAX_textHpoProfile 
            WHERE OCCURRANCE >= {}
        )
        
        SELECT L.SUBJECT_ID, L.HADM_ID, L.MAP_TO, IF(R.dummy IS NULL, 0, 1) AS VALUE
        FROM joint as L
        LEFT JOIN 
        JAX_textHpoProfile_filtered AS R
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.MAP_TO = R.MAP_TO  
    '''.format(start_index, end_index, textHpo_threshold_min, textHpo_threshold_max, textHpo_occurrance_min), mydb)
    
    labHpoFlat = pd.read_sql_query('''
        WITH encounters AS (
            SELECT SUBJECT_ID, HADM_ID
            FROM JAX_mf_diag 
            WHERE ROW_ID BETWEEN {} AND {}
        ), 
        labHpoOfInterest AS (
            SELECT MAP_TO 
            FROM JAX_labHpoFrequencyRank 
            WHERE N BETWEEN {} AND {}
        ), 
        joint as (
            SELECT *
            FROM encounters 
            JOIN labHpoOfInterest),
        JAX_labHpoProfile_filtered AS (
            SELECT * 
            FROM JAX_labHpoProfile 
            WHERE OCCURRANCE >= {}
        )
        
        SELECT L.SUBJECT_ID, L.HADM_ID, L.MAP_TO, IF(R.dummy IS NULL, 0, 1) AS VALUE
        FROM joint as L
        LEFT JOIN 
        JAX_labHpoProfile_filtered AS R
        ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.MAP_TO = R.MAP_TO
    '''.format(start_index, end_index, labHpo_threshold_min, labHpo_threshold_max, labHpo_occurrance_min), mydb)
    
    return diagnosisVector, textHpoFlat, labHpoFlat

def summarize_diagnosis_textHpo_labHpo(primary_diagnosis_only, 
                                       textHpo_occurrance_min, 
                                       labHpo_occurrance_min, 
                                       diagnosis_threshold_min, 
                                       textHpo_threshold_min, 
                                       textHpo_threshold_max, 
                                       labHpo_threshold_min, 
                                       labHpo_threshold_max,
                                       disease_of_interest,
                                       logger):
    """
    Iterate database to get summary statistics. For each disease of interest, automatically determine a list of phenotypes derived from labs (labHpo) and a list of phenotypes from text mining (textHpo). For each pair of phenotypes, count the number of encounters according to whether the phenotypes and diagnosis are observated.    
    @param primary_diagnosis_only: only primary diagnosis is analyzed
    @param textHpo_occurrance_min: minimum occurrances of a phenotype from text data for it to be called in one encounter
    @param labHpo_occurrance_max: maximum occurrances of a phenotype from lab tests for it to be called in one encounter
    @param textHpo_threshold_min: minimum number of encounters of a phenotypes from text data for it to be analyzed
    @param textHpo_threshold_max: maximum number of encounters of a phenotypes from text data for it to be analyzed
    @param labHpo_threshold_min: minimum number of encounters of a phenotype from lab tests for it to be analyzed
    @param labHpo_threshold_max: maximum number of encounters of a phenotype from lab tests for it to be analyzed
    @param disease_of_interest: either set to "calculated", or a list of ICD-9 codes (get all possible codes from temp table JAX_diagFrequencyRank)
    @param logger: logger for logging
    
    :return: three dictionaries of summary statistics, of which the keys are diagnosis codes and the values are instances of the SummaryXYz class. 
    First dictionary, X (a list of phenotype variables) are from textHpo and Y are from labHpo; 
    Secondary dictionary, all terms in X, Y are from textHpo;
    Third dictionary, all terms in X, Y are all from labHpo. Note that terms in X and Y are calculated separately for each diagnosis and may be different.
    """
    logger.info('starting iterate_in_batch()')
    batch_size = 100
    
    # define a set of diseases that we want to analyze
    rankICD()
    
    if disease_of_interest == 'calculated': 
        diseaseOfInterest = pd.read_sql_query("SELECT * FROM JAX_diagFrequencyRank WHERE N > {}".format(diagnosis_threshold_min), mydb).ICD9_CODE.values
    elif isinstance(disease_of_interest, list) and len(disease_of_interest) > 0:
        #disable the following line to analyze all diseases of interest
        #diseaseOfInterest = ['428', '584', '038', '493']
        diseaseOfInterest = disease_of_interest
    else:
        raise RuntimeError
    logger.info('diagnosis of interest: {}'.format(len(diseaseOfInterest)))
    
    summaries_diag_textHpo_labHpo = {}
    summaries_diag_textHpo_textHpo = {}
    summaries_diag_labHpo_labHpo = {}
    
    pbar = tqdm(total=len(diseaseOfInterest))
    for diagnosis in diseaseOfInterest:
        logger.info("start analyzing disease {}".format(diagnosis))
        
        logger.info(".......assigning values of diagnosis")
        # assign each encounter whether a diagnosis code is observed
        # create a table j1 (joint 1)
        createDiagnosisTable(diagnosis, primary_diagnosis_only)
        indexDiagnosisTable()
        # for every diagnosis, find phenotypes of interest to look at from radiology reports
        # for every diagnosis, find phenotypes of interest to look at from laboratory tests
        rankHpoFromText(diagnosis, textHpo_occurrance_min)
        rankHpoFromLab(diagnosis, labHpo_occurrance_min)
        logger.info("..............diagnosis values found")
        
        textHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_textHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(textHpo_threshold_min, textHpo_threshold_max), mydb).MAP_TO.values
        labHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_labHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(labHpo_threshold_min, labHpo_threshold_max), mydb).MAP_TO.values
        logger.info("TextHpo of interest established, size: {}".format(len(textHpoOfInterest)))
        logger.info("LabHpo of interest established, size: {}".format(len(labHpoOfInterest)))

        ## find the start and end ROW_ID for patient*encounter
        ADM_ID_START, ADM_ID_END = pd.read_sql_query('SELECT MIN(ROW_ID) AS min, MAX(ROW_ID) AS max FROM JAX_mf_diag', mydb).iloc[0]
        batch_N = ADM_ID_END - ADM_ID_START + 1
        TOTAL_BATCH = math.ceil(batch_N / batch_size) # total number of batches
        
        summaries_diag_textHpo_labHpo[diagnosis] = mf.SummaryXYz(textHpoOfInterest, labHpoOfInterest, diagnosis)
        summaries_diag_textHpo_textHpo[diagnosis] = mf.SummaryXYz(textHpoOfInterest, textHpoOfInterest, diagnosis)
        summaries_diag_labHpo_labHpo[diagnosis] = mf.SummaryXYz(labHpoOfInterest, labHpoOfInterest, diagnosis)
        
        logger.info('starting batch queries for {}'.format(diagnosis))
        for i in np.arange(TOTAL_BATCH):
            start_index = i * batch_size + ADM_ID_START
            if i < TOTAL_BATCH - 1:
                end_index = start_index + batch_size - 1
            else:
                end_index = batch_N

            diagnosisFlat, textHpoFlat, labHpoFlat =  batch_query(start_index, end_index, textHpo_occurrance_min, labHpo_occurrance_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max)
            
            batch_size_actual = len(diagnosisFlat)
            textHpoOfInterest_size = len(textHpoOfInterest)
            labHpoOfInterest_size = len(labHpoOfInterest)
            #print('len(textHpoFlat)= {}, batch_size_actual={}, textHpoOfInterest_size={}'.format(len(textHpoFlat), batch_size_actual, textHpoOfInterest_size))
            assert(len(textHpoFlat) == batch_size_actual * textHpoOfInterest_size)
            assert(len(labHpoFlat) == batch_size_actual * labHpoOfInterest_size)
            
            if batch_size_actual > 0:
                diagnosisVector = diagnosisFlat.DIAGNOSIS.values.astype(int)
                # reformat the flat vector into N x M matrix, N is batch size, i.e. number of encounters, M is the length of HPO terms  
                textHpoMatrix = textHpoFlat.VALUE.values.astype(int).reshape([batch_size_actual, textHpoOfInterest_size], order='F')
                labHpoMatrix = labHpoFlat.VALUE.values.astype(int).reshape([batch_size_actual, labHpoOfInterest_size], order='F')
                # check the matrix formatting is correct
                # disable the following 4 lines to speed things up
                textHpoLabelsMatrix = textHpoFlat.MAP_TO.values.reshape([batch_size_actual, textHpoOfInterest_size], order='F')
                labHpoLabelsMatrix = labHpoFlat.MAP_TO.values.reshape([batch_size_actual, labHpoOfInterest_size], order='F')
                assert (textHpoLabelsMatrix[0, :] == textHpoOfInterest).all()
                assert (labHpoLabelsMatrix[0, :] == labHpoOfInterest).all()
                if i % 100 == 0:
                    logger.info('new batch: start_index={}, end_index={}, batch_size= {}, textHpo_size = {}, labHpo_size = {}'.format(start_index, end_index, batch_size_actual, textHpoMatrix.shape[1], labHpoMatrix.shape[1]))
                summaries_diag_textHpo_labHpo[diagnosis].add_batch(textHpoMatrix,labHpoMatrix, diagnosisVector)
                summaries_diag_textHpo_textHpo[diagnosis].add_batch(textHpoMatrix,textHpoMatrix, diagnosisVector)
                summaries_diag_labHpo_labHpo[diagnosis].add_batch(labHpoMatrix,labHpoMatrix, diagnosisVector)
         
        pbar.update(1)
        
    pbar.close()
    
    return summaries_diag_textHpo_labHpo, summaries_diag_textHpo_textHpo, summaries_diag_labHpo_labHpo

##  --TEST

In [18]:
# how to run this

logger = logging.getLogger()
logger.setLevel(logging.INFO)

# 1. build the temp tables for Lab converted HPO, Text convert HPO
# Read the comments within the method!
initTables(debug=True)

# 2. iterate throw the dataset
primary_diagnosis_only = True
diagnosis_threshold_min = 5
textHpo_occurrance_min, labHpo_occurrance_min = 1, 3
textHpo_threshold_min, textHpo_threshold_max = 7, 100
labHpo_threshold_min, labHpo_threshold_max = 7, 100
disease_of_interest = ['428', '584', '038', '493', '*']

summaries_diag_textHpo_labHpo, summaries_diag_textHpo_textHpo, summaries_diag_labHpo_labHpo = summarize_diagnosis_textHpo_labHpo(primary_diagnosis_only, textHpo_occurrance_min, labHpo_occurrance_min, diagnosis_threshold_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max, disease_of_interest, logger)

2019-12-20 15:09:58,034 - 11200 - root - INFO - starting iterate_in_batch()
2019-12-20 15:09:58,038 - 11200 - root - INFO - diagnosis of interest: 5


  0%|          | 0/5 [00:00<?, ?it/s]

2019-12-20 15:09:58,040 - 11200 - root - INFO - start analyzing disease 428
2019-12-20 15:09:58,041 - 11200 - root - INFO - .......assigning values of diagnosis
2019-12-20 15:09:58,056 - 11200 - root - INFO - ..............diagnosis values found
2019-12-20 15:09:58,061 - 11200 - root - INFO - TextHpo of interest established, size: 28
2019-12-20 15:09:58,062 - 11200 - root - INFO - LabHpo of interest established, size: 52
2019-12-20 15:09:58,065 - 11200 - root - INFO - starting batch queries for 428
2019-12-20 15:09:58,173 - 11200 - root - INFO - new batch: start_index=1, end_index=100, batch_size= 100, textHpo_size = 28, labHpo_size = 52


 20%|██        | 1/5 [00:00<00:00,  6.84it/s]

2019-12-20 15:09:58,187 - 11200 - root - INFO - start analyzing disease 584
2019-12-20 15:09:58,188 - 11200 - root - INFO - .......assigning values of diagnosis
2019-12-20 15:09:58,200 - 11200 - root - INFO - ..............diagnosis values found
2019-12-20 15:09:58,204 - 11200 - root - INFO - TextHpo of interest established, size: 9
2019-12-20 15:09:58,205 - 11200 - root - INFO - LabHpo of interest established, size: 29
2019-12-20 15:09:58,208 - 11200 - root - INFO - starting batch queries for 584
2019-12-20 15:09:58,269 - 11200 - root - INFO - new batch: start_index=1, end_index=100, batch_size= 100, textHpo_size = 9, labHpo_size = 29
2019-12-20 15:09:58,274 - 11200 - root - INFO - start analyzing disease 038
2019-12-20 15:09:58,275 - 11200 - root - INFO - .......assigning values of diagnosis
2019-12-20 15:09:58,287 - 11200 - root - INFO - ..............diagnosis values found
2019-12-20 15:09:58,292 - 11200 - root - INFO - TextHpo of interest established, size: 14
2019-12-20 15:09:58,

 60%|██████    | 3/5 [00:00<00:00,  7.75it/s]

2019-12-20 15:09:58,365 - 11200 - root - INFO - start analyzing disease 493
2019-12-20 15:09:58,366 - 11200 - root - INFO - .......assigning values of diagnosis
2019-12-20 15:09:58,373 - 11200 - root - INFO - ..............diagnosis values found
2019-12-20 15:09:58,377 - 11200 - root - INFO - TextHpo of interest established, size: 0
2019-12-20 15:09:58,378 - 11200 - root - INFO - LabHpo of interest established, size: 0
2019-12-20 15:09:58,380 - 11200 - root - INFO - starting batch queries for 493
2019-12-20 15:09:58,388 - 11200 - root - INFO - new batch: start_index=1, end_index=100, batch_size= 100, textHpo_size = 0, labHpo_size = 0
2019-12-20 15:09:58,390 - 11200 - root - INFO - start analyzing disease *
2019-12-20 15:09:58,391 - 11200 - root - INFO - .......assigning values of diagnosis
2019-12-20 15:09:58,397 - 11200 - root - INFO - ..............diagnosis values found
2019-12-20 15:09:58,401 - 11200 - root - INFO - TextHpo of interest established, size: 0
2019-12-20 15:09:58,401 -

100%|██████████| 5/5 [00:00<00:00, 13.39it/s]


In [10]:
summaries_diag_textHpo_labHpo['038'].m2.shape

(14, 25, 8)

## --PRODUCTION

In [10]:
# how to run this
# Again, it take either too long or too much memory space to run
logger = logging.getLogger()
logger.setLevel(logging.WARN)

# 1. build the temp tables for Lab converted HPO, Text convert HPO
# Read the comments within the method!
initTables(debug=False)

# 2. iterate throw the dataset
primary_diagnosis_only = True
diagnosis_threshold_min = 3000
textHpo_threshold_min, textHpo_threshold_max = 500, 100000
labHpo_threshold_min, labHpo_threshold_max = 1000, 100000
textHpo_occurrance_min, labHpo_occurrance_min = 1, 3
disease_of_interest = ['428', '584', '038', '493']

summaries_diag_textHpo_labHpo, summaries_diag_textHpo_textHpo, summaries_diag_labHpo_labHpo = summarize_diagnosis_textHpo_labHpo(primary_diagnosis_only, textHpo_occurrance_min, labHpo_occurrance_min, diagnosis_threshold_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max, disease_of_interest, logger)

100%|██████████| 4/4 [00:00<00:00, 39.62it/s]


In [11]:
if primary_diagnosis_only:
    fName_diag_textHpo_labHpo = '../../../data/mf_regarding_diseases/primary_only/summaries_diagnosis_textHpo_labHpo.obj'
    fName_diag_textHpo_textHpo = '../../../data/mf_regarding_diseases/primary_only/summaries_diagnosis_textHpo_textHpo.obj'
    fName_diag_labHpo_labHpo = '../../../data/mf_regarding_diseases/primary_only/summaries_diagnosis_labHpo_labHpo.obj'
else:
    fName_diag_textHpo_labHpo = '../../../data/mf_regarding_diseases/primary_and_secondary/summaries_diagnosis_textHpo_labHpo.obj'
    fName_diag_textHpo_textHpo = '../../../data/mf_regarding_diseases/primary_and_secondary/summaries_diagnosis_textHpo_textHpo.obj'
    fName_diag_labHpo_labHpo = '../../../data/mf_regarding_diseases/primary_and_secondary/summaries_diagnosis_labHpo_labHpo.obj'

with open(fName_diag_textHpo_labHpo, 'wb') as f:
    pickle.dump(summaries_diag_textHpo_labHpo, f)
with open(fName_diag_textHpo_textHpo, 'wb') as f:
    pickle.dump(summaries_diag_textHpo_textHpo, f)
with open(fName_diag_labHpo_labHpo, 'wb') as f:
    pickle.dump(summaries_diag_labHpo_labHpo, f)

In [16]:
summaries_diag_textHpo_labHpo

{'428': <mf.SummaryXYz at 0xa17c01630>,
 '584': <mf.SummaryXYz at 0x1050d9cf8>,
 '038': <mf.SummaryXYz at 0xa1776c048>,
 '493': <mf.SummaryXYz at 0xa17b37b70>}

In [20]:
summaries_diag_textHpo_textHpo

{'428': <mf.SummaryXYz at 0xa17c01cc0>,
 '584': <mf.SummaryXYz at 0xa1776bf98>,
 '038': <mf.SummaryXYz at 0xa1776bef0>,
 '493': <mf.SummaryXYz at 0x81753fa90>}

In [21]:
summaries_diag_labHpo_labHpo

{'428': <mf.SummaryXYz at 0xa17c01a90>,
 '584': <mf.SummaryXYz at 0xa1776bb38>,
 '038': <mf.SummaryXYz at 0xa17bba978>,
 '493': <mf.SummaryXYz at 0xa1776bdd8>}

In [8]:
hpo = Ontology('/Users/zhangx/git/human-phenotype-ontology/hp.obo')

## Synergy network

In [9]:
def build_multi_variant_table(labHpos, radHpos, diag):
    """
    Compute the mutual information between multiple phenotypes and the diagnosis.
    Algorithm:
    1. create a temp table Jax_multivariant_synergy_table (SUBJECT_ID, HADM_ID, diag_value)
    2. for each phenotype, finding whether they are present, add a column 
    3. using the combined table, calculate the mutual information of 
    @param labHpos: a set of lab-derived HPO terms
    @param radHpos: a set of text-derived HPO terms
    @param diag: a diagnosis code
    """
    # diagnosis column
    # create diagnosis table
    
    # iterate labHpos
    
    # iterate textHpos
    
    
    pass


def add_diag_columns(diagnosis, primary_diagnosis_only):
    createDiagnosisTable(diagnosis, primary_diagnosis_only)
    # copy into a new table Jax_multivariant_synergy_table(SUBJECT_ID, HADM_ID, DIAGNOSIS)
    cursor.execute("""
        CREATE TEMPORARY TABLE IF NOT EXISTS Jax_multivariant_synergy_table AS (
            SELECT * 
            FROM JAX_mf_diag
        )""")
    cursor.execute('CREATE INDEX Jax_multivariant_synergy_table_idx01 ON JAX_mf_diag (SUBJECT_ID, HADM_ID)')
 
    
def add_phenotype_columns(labHpos, textHpos, labHpo_threshold_min, textHpo_threshold_min):
    # save the variable transformation for later use
    var_dict = {}
    i = 0
    for labHpo in labHpos:
        i = i + 1
        colName = 'V' + str(i)
        var_dict[colName] = ('LabHpo', labHpo)
        cursor.execute("""
            ALTER TABLE Jax_multivariant_synergy_table ADD COLUMN {} INT DEFAULT 0""".format(colName))
        cursor.execute("""
            UPDATE Jax_multivariant_synergy_table 
            LEFT JOIN JAX_labHpoProfile 
            ON Jax_multivariant_synergy_table.SUBJECT_ID = JAX_labHpoProfile.SUBJECT_ID AND 
            Jax_multivariant_synergy_table.HADM_ID = JAX_labHpoProfile.HADM_ID
            SET {} = IF(JAX_labHpoProfile.OCCURRANCE > {}, 1, 0)
            WHERE JAX_labHpoProfile.MAP_TO = '{}'
        """.format(colName, labHpo_threshold_min, labHpo))
        
    for textHpo in textHpos:
        i = i + 1
        colName = 'V' + str(i)
        var_dict[colName] = ('TextHpo', textHpo)
        cursor.execute("""
            ALTER TABLE Jax_multivariant_synergy_table ADD COLUMN {} INT DEFAULT 0""".format(colName))
        cursor.execute("""
            UPDATE Jax_multivariant_synergy_table 
            LEFT JOIN JAX_textHpoProfile 
            ON Jax_multivariant_synergy_table.SUBJECT_ID = JAX_textHpoProfile.SUBJECT_ID AND 
            Jax_multivariant_synergy_table.HADM_ID = JAX_textHpoProfile.HADM_ID
            SET {} = IF(JAX_textHpoProfile.OCCURRANCE > {}, 1, 0)
            WHERE JAX_textHpoProfile.MAP_TO = '{}'
        """.format(colName, textHpo_threshold_min, textHpo))
    return var_dict
        

def precompute_mf(variables):
    """
    Compute the mutual information between the joint distribution of all the variables and the medical outcome
    """ 
    summary_counts = pd.read_sql_query("""
        WITH summary AS (
        SELECT {}, DIAGNOSIS, COUNT(*) AS N
        FROM Jax_multivariant_synergy_table
        GROUP BY {}, DIAGNOSIS)
        SELECT *, SUM(N) OVER (PARTITION BY {}) AS V, SUM(N) OVER (PARTITION BY DIAGNOSIS) AS D
        FROM summary
    """.format(','.join(variables), ','.join(variables), ','.join(variables)), mydb)
    total = np.sum(summary_counts.N)
    p = summary_counts.N / total
    p_V = summary_counts.V / total
    p_D = summary_counts.D / total
    mf = np.sum(p * np.log2(p / (p_V * p_D)))
    return mf, summary_counts


def precompute_mf_dict(var_ids):
    var_subsets = synergy_tree.subsets(var_ids, include_self=True)
    mf_dict = {}
    summary_dict = {}
    pbar = tqdm_notebook(total=len(var_subsets))
    for var_subset in var_subsets: 
        mf, summary_count = precompute_mf(var_subset)
        mf_dict[var_subset] = mf
        summary_dict[var_subset] = summary_count
        pbar.update(1)
    pbar.close()
    
    return mf_dict, summary_dict       

## Test

In [10]:
diagnosis = '038'
textHpo_occurrance_min=1
labHpo_occurrance_min=3
textHpo_threshold_min=7
textHpo_threshold_max=7
labHpo_threshold_min=8
labHpo_threshold_max=8
primary_diagnosis_only=True

initTables(debug=True)

rankHpoFromText(diagnosis, textHpo_occurrance_min)
rankHpoFromLab(diagnosis, labHpo_occurrance_min)
#logger.info("..............diagnosis values found")

textHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_textHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(textHpo_threshold_min, textHpo_threshold_max), mydb).MAP_TO.values
labHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_labHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(labHpo_threshold_min, labHpo_threshold_max), mydb).MAP_TO.values

In [11]:
# manually trim phenotypes TODO: further filter them
print(labHpoOfInterest)
print(textHpoOfInterest)
textHpoOfInterest = ['HP:0001877', 'HP:0020058', 'HP:0010927', 'HP:0001871', 'HP:0010929']
labHpoOfInterest = ['HP:0002202', 'HP:0011032', 'HP:0100750']
print("filtered phenotypes:")
print(labHpoOfInterest)
print(textHpoOfInterest)


cursor.execute("""drop table if exists Jax_multivariant_synergy_table""", mydb)

add_diag_columns(diagnosis, primary_diagnosis_only)
var_dict = add_phenotype_columns(labHpos=labHpoOfInterest, \
                                 textHpos=textHpoOfInterest, \
                                 labHpo_threshold_min = labHpo_occurrance_min, \
                                 textHpo_threshold_min = textHpo_occurrance_min)
mf_dict, summary_dict = precompute_mf_dict(var_dict.keys())

['HP:0001877' 'HP:0020058' 'HP:0010927' 'HP:0001871' 'HP:0010929'
 'HP:0020061' 'HP:0003111' 'HP:0012337' 'HP:0032180' 'HP:0011015'
 'HP:0001939' 'HP:0011014' 'HP:0000118' 'HP:0000001' 'HP:0031850']
['HP:0002202' 'HP:0011032' 'HP:0100750' 'HP:0000969' 'HP:0002597'
 'HP:0012337' 'HP:0030680']
filtered phenotypes:
['HP:0002202', 'HP:0011032', 'HP:0100750']
['HP:0001877', 'HP:0020058', 'HP:0010927', 'HP:0001871', 'HP:0010929']


HBox(children=(IntProgress(value=0, max=255), HTML(value='')))

2019-12-20 14:40:39,890 - 11200 - numexpr.utils - INFO - NumExpr defaulting to 8 threads.



In [12]:
syntree_038 = synergy_tree.SynergyTree(var_dict.keys(), var_dict, mf_dict)
syntree_038.synergy_tree().show()
print(var_dict)

('V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8')
├── ('V1', 'V2', 'V4', 'V7')
│   ├── ('V1',)
│   ├── ('V2', 'V7')
│   │   ├── ('V2',)
│   │   └── ('V7',)
│   └── ('V4',)
├── ('V3', 'V5', 'V6')
│   ├── ('V3', 'V6')
│   │   ├── ('V3',)
│   │   └── ('V6',)
│   └── ('V5',)
└── ('V8',)

{'V1': ('LabHpo', 'HP:0002202'), 'V2': ('LabHpo', 'HP:0011032'), 'V3': ('LabHpo', 'HP:0100750'), 'V4': ('TextHpo', 'HP:0001877'), 'V5': ('TextHpo', 'HP:0020058'), 'V6': ('TextHpo', 'HP:0010927'), 'V7': ('TextHpo', 'HP:0001871'), 'V8': ('TextHpo', 'HP:0010929')}


## Production

In [None]:
diagnosis = '038'
textHpo_occurrance_min=1
labHpo_occurrance_min=3
textHpo_threshold_min=500
textHpo_threshold_max=100000
labHpo_threshold_min=1000
labHpo_threshold_max=100000
primary_diagnosis_only=True

initTables(debug=False)

rankHpoFromText(diagnosis, textHpo_occurrance_min)
rankHpoFromLab(diagnosis, labHpo_occurrance_min)
#logger.info("..............diagnosis values found")

textHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_textHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(textHpo_threshold_min, textHpo_threshold_max), mydb).MAP_TO.values
labHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_labHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(labHpo_threshold_min, labHpo_threshold_max), mydb).MAP_TO.values

In [14]:
# manually trim phenotypes TODO: further filter them
textHpoOfInterest = ['HP:0002107', 'HP:0000969', 'HP:0001945', 'HP:0001626']
labHpoOfInterest = ['HP:0020062', 'HP:0020063', 'HP:0012419', 'HP:0002151', 'HP:0012417', 
                    'HP:0020060', 'HP:0020059', 'HP:0002148']
print("filtered phenotypes:")
print(labHpoOfInterest)
print(textHpoOfInterest)


cursor.execute("""drop table if exists Jax_multivariant_synergy_table""", mydb)

add_diag_columns(diagnosis, primary_diagnosis_only)
var_dict = add_phenotype_columns(labHpos=labHpoOfInterest, \
                                 textHpos=textHpoOfInterest, \
                                 labHpo_threshold_min = labHpo_occurrance_min, \
                                 textHpo_threshold_min = textHpo_occurrance_min)
mf_dict, summary_dict = precompute_mf_dict(var_dict.keys())

filtered phenotypes:
['HP:0020062', 'HP:0020063', 'HP:0012419', 'HP:0002151', 'HP:0012417', 'HP:0020060', 'HP:0020059', 'HP:0002148']
['HP:0002107', 'HP:0000969', 'HP:0001945', 'HP:0001626']


HBox(children=(IntProgress(value=0, max=4095), HTML(value='')))




In [None]:
syntree_038 = synergy_tree.SynergyTree(var_dict.keys(), var_dict, mf_dict)
syntree_038.synergy_tree().show()
term_map = hpo.term_id_2_label_map
for k, v in var_dict.items():
    print('{}: {} {}'.format(k, v[1], term_map[v[1]]))

## Mutual information between phenotypes of radiology and lab tests regardless of diseases
Without considering diseases, we want to determine the mutual information between pairs of phenotypes from radiology reports and lab tests. The algorithm:

    * find phenotypes of interest from radiology reports
    * find phenotypes of interest from lab tests
    * in batches, create a matrix of text phenotype profiles and a matrix of lab phenotype profiles, update summary statistics

In [17]:
def batch_query_lab_text(start_index, end_index, textHpo_occurrance_min, labHpo_occurrance_min, textHpo_min, textHpo_max, labHpo_min, labHpo_max):
    
    textHpo_flat = pd.read_sql_query('''
        WITH encounters AS (
                SELECT *
                FROM JAX_encounterOfInterest
                WHERE ROW_ID BETWEEN {} AND {}),
            phenotypes AS (
                SELECT MAP_TO
                FROM JAX_textHpoFrequencyRank
                WHERE N BETWEEN {} AND {}
            ), 
            temp AS (
                SELECT * 
                FROM encounters 
                JOIN phenotypes)

            SELECT L.SUBJECT_ID, L.HADM_ID, L.MAP_TO AS PHEN_TEXT, IF(R.dummy IS NULL, 0, 1) AS PHEN_TEXT_VALUE
            FROM temp AS L
            LEFT JOIN 
                (SELECT * FROM JAX_textHpoProfile WHERE OCCURRANCE >= {}) AS R
            ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.MAP_TO = R.MAP_TO
        '''.format(start_index, end_index, textHpo_min, textHpo_max, textHpo_occurrance_min), mydb)
    
    labHpo_flat = pd.read_sql_query('''
        WITH encounters AS (
                SELECT *
                FROM JAX_encounterOfInterest
                WHERE ROW_ID BETWEEN {} AND {}),
            phenotypes AS (
                SELECT MAP_TO
                FROM JAX_labHpoFrequencyRank
                WHERE N BETWEEN {} AND {}
            ), 
            temp AS (
                SELECT * 
                FROM encounters 
                JOIN phenotypes)

            SELECT L.SUBJECT_ID, L.HADM_ID, L.MAP_TO AS PHEN_LAB, IF(R.dummy IS NULL, 0, 1) AS PHEN_LAB_VALUE
            FROM temp AS L
            LEFT JOIN 
                (SELECT * FROM JAX_labHpoProfile WHERE OCCURRANCE >= {}) AS R
            ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID AND L.MAP_TO = R.MAP_TO
        '''.format(start_index, end_index, labHpo_min, labHpo_max, labHpo_occurrance_min), mydb)
    
    return textHpo_flat, labHpo_flat 
    

def summary_textHpo_labHpo(batch_size, textHpo_occurrance_min, labHpo_occurrance_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max):
    
    textHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_textHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(textHpo_threshold_min, textHpo_threshold_max), mydb).MAP_TO.values
    labHpoOfInterest = pd.read_sql_query("SELECT * FROM JAX_labHpoFrequencyRank WHERE N BETWEEN {} AND {}".format(labHpo_threshold_min, labHpo_threshold_max), mydb).MAP_TO.values
    M1 = len(textHpoOfInterest)
    M2 = len(labHpoOfInterest)
    
    summary_rad_lab = mf.SummaryXY(textHpoOfInterest, labHpoOfInterest)
    summary_rad_rad = mf.SummaryXY(textHpoOfInterest, textHpoOfInterest)
    summary_lab_lab = mf.SummaryXY(labHpoOfInterest, labHpoOfInterest)
    
    ## find the start and end ROW_ID for patient*encounter
    
    ADM_ID_START, ADM_ID_END = pd.read_sql_query('SELECT MIN(ROW_ID) AS min, MAX(ROW_ID) AS max FROM JAX_encounterOfInterest', mydb).iloc[0]
    batch_N = ADM_ID_END - ADM_ID_START + 1
    TOTAL_BATCH = math.ceil(batch_N / batch_size) # total number of batches

    print('total batches: ' + str(batch_N))
    pbar = tqdm(total=TOTAL_BATCH)
    for i in np.arange(TOTAL_BATCH):
        start_index = i * batch_size + ADM_ID_START
        if i < TOTAL_BATCH - 1:
            end_index = start_index + batch_size - 1
        else:
            end_index = batch_N
        actual_batch_size = end_index - start_index + 1
        textHpo, labHpo = batch_query_lab_text(start_index, end_index, textHpo_occurrance_min, labHpo_occurrance_min, textHpo_threshold_min, textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max)
        textHpo_matrix = textHpo.PHEN_TEXT_VALUE.values.astype(int).reshape([actual_batch_size, M1], order='F')
        labHpo_matrix = labHpo.PHEN_LAB_VALUE.values.astype(int).reshape([actual_batch_size, M2], order='F')
        summary_rad_lab.add_batch(textHpo_matrix, labHpo_matrix)
        summary_rad_rad.add_batch(textHpo_matrix, textHpo_matrix)
        summary_lab_lab.add_batch(labHpo_matrix, labHpo_matrix)
        pbar.update(1)
        
    pbar.close()
    
    return summary_rad_lab, summary_rad_rad, summary_lab_lab

In [None]:
mf_all_df.head()
mf_all_df.sort_values(by='mf_P1_P2', ascending=False).head()

### Test

In [None]:
encounterOfInterest(debug=True)
indexEncounterOfInterest()
diagnosisProfile()
rankHpoFromText('', hpo_min_occurrence_per_encounter=1)
rankHpoFromLab('', hpo_min_occurrence_per_encounter=3)

batch_size = 11
textHpo_threshold_min = 45
textHpo_threshold_max = 65
labHpo_threshold_min = 75
labHpo_threshold_max = 85
textHpo_occurrance_min = 1
labHpo_occurrance_min = 3

summary_rad_lab, summary_rad_rad, summary_lab_lab = summary_textHpo_labHpo(batch_size, textHpo_occurrance_min, labHpo_occurrance_min, textHpo_threshold_min,textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max)

print(summary_rad_lab.m.shape)
print(summary_rad_rad.m.shape)
print(summary_lab_lab.m.shape)

### Production

In [18]:
encounterOfInterest(debug=False)
indexEncounterOfInterest()
diagnosisProfile()
rankHpoFromText('', hpo_min_occurrence_per_encounter=1)
rankHpoFromLab('', hpo_min_occurrence_per_encounter=3)

batch_size = 100
textHpo_threshold_min = 500
textHpo_threshold_max = 100000
labHpo_threshold_min = 1000
labHpo_threshold_max = 100000
textHpo_occurrance_min = 1
labHpo_occurrance_min = 3

summary_rad_lab, summary_rad_rad, summary_lab_lab = summary_textHpo_labHpo(batch_size, textHpo_occurrance_min, labHpo_occurrance_min, textHpo_threshold_min,textHpo_threshold_max, labHpo_threshold_min, labHpo_threshold_max)


  0%|          | 0/590 [00:00<?, ?it/s]

total batches: 58976


100%|██████████| 590/590 [32:42<00:00,  3.08s/it]


In [20]:
with open('../../../data/mf_regardless_of_diseases/summary_textHpo_labHpo.obj', 'wb') as file:
    pickle.dump(summary_rad_lab, file)
    
with open('../../../data/mf_regardless_of_diseases/summary_textHpo_textHpo.obj', 'wb') as file:
    pickle.dump(summary_rad_rad, file)
    
with open('../../../data/mf_regardless_of_diseases/summary_labHpo_labHpo.obj', 'wb') as file:
    pickle.dump(summary_lab_lab, file)

In [None]:
mf_all=mf.MutualInfoXY(summary_all)
textHpo_labels = mf_all.X_names
labHpo_labels = mf_all.Y_names
M1 = len(mf_all.X_names)
M2 = len(mf_all.Y_names)
entropy_x, entropy_y = mf_all.entropies().values()
mf_all_df = pd.DataFrame(data={'P1': np.repeat(textHpo_labels, M2), 'P2': np.tile(labHpo_labels, [M1]),'entropy_P1': np.repeat(entropy_x, M2), 'entropy_P2': np.tile(entropy_y, [M1]), 'mf_P1_P2': mf_all.mf().flat})
mf_all_df.head()
mf_all_df.to_csv('mutual_info_textHpo_labHpo.csv', index=False)

100%|██████████| 10/10 [00:00<00:00, 178.56it/s]

total batches: 100
(8, 2, 4)
(8, 8, 4)
(2, 2, 4)





TODO: move the following section to the machine learning notebook

## Feature generation for phenotype-based disease prediction

Prepare phenotypes for machine learning tasks. For specified disease of interest, find the patient info (gender, DOB), find the diagnosis info (case or control, date), and phenotype terms.  

In [None]:
encounterOfInterest(debug=False, N=100)
indexEncounterOfInterest()
diagnosisProfile()

In [None]:
patients = pd.read_sql_query('''
        SELECT 
            PATIENTS.SUBJECT_ID, PATIENTS.GENDER, PATIENTS.DOB
        FROM 
            PATIENTS
        WHERE 
            SUBJECT_ID IN (SELECT SUBJECT_ID FROM JAX_encounterOfInterest) 
    ''', mydb)

len(patients)

In [None]:
createDiagnosisTable(diagnosis='038', primary_diagnosis_only=False)

In [None]:
def first_diag_time(diagnosis):
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_first_diag_time')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_first_diag_time
        WITH diag_time AS (
            SELECT L.*, R.ADMITTIME 
            FROM JAX_mf_diag AS L
            JOIN ADMISSIONS AS R
            ON 
                L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID
        ), 
        overallDiagnosis AS (
            SELECT *, IF(SUM(DIAGNOSIS) OVER (PARTITION BY SUBJECT_ID) > 0, 1, 0) AS everDiagnosed
            FROM diag_time
        )
        
        SELECT 
            *, MIN(IF(everDiagnosed = 1, ADMITTIME, NULL) ) OVER (PARTITION BY SUBJECT_ID) AS first_diag, 
            MAX(IF(everDiagnosed = 0, ADMITTIME, NULL)) OVER (PARTITION BY SUBJECT_ID) AS last_visit_if_not_diagnosed
        FROM 
            overallDiagnosis  
    '''.format(diagnosis))

first_diag_time('038')
pd.read_sql_query('SELECT * FROM JAX_first_diag_time LIMIT 5', mydb)

In [None]:
def encountersAfterDiagnosis():
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_encounters_after_diagnosis')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_encounters_after_diagnosis
            SELECT *, 1 AS toIgnore
            FROM JAX_first_diag_time
            WHERE DIAGNOSIS = 1 AND ADMITTIME > first_diag
    ''')
    cursor.execute('CREATE INDEX JAX_encounters_after_diagnosis_idx01 ON JAX_encounters_after_diagnosis (SUBJECT_ID, HADM_ID)')
    
encountersAfterDiagnosis() 
pd.read_sql_query("SELECT * FROM JAX_encounters_after_diagnosis LIMIT 5", mydb)

In [None]:
diagnosis_vector = pd.read_sql_query('''
        SELECT
            SUBJECT_ID, IF(SUM(DIAGNOSIS)>0, 1, 0) AS DIAGNOSED, MAX(IF(everDiagnosed = 1, first_diag, last_visit_if_not_diagnosed)) AS LAST_VISIT
        FROM
            JAX_first_diag_time
        GROUP BY SUBJECT_ID
    ''', mydb)
diagnosis_vector.head()

In [None]:
def lab_phenotype_before_diagnosis():
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_phen_lab_before_diag')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_phen_lab_before_diag
        WITH temp as (
            SELECT L.*, W.toIgnore
            FROM JAX_LABHPOPROFILE AS L
            JOIN JAX_mf_diag AS R ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID
            LEFT JOIN JAX_encounters_after_diagnosis W on  L.SUBJECT_ID = W.SUBJECT_ID AND L.HADM_ID = W.HADM_ID
            WHERE L.OCCURRANCE >= 1
        )
        SELECT SUBJECT_ID, MAP_TO, COUNT(*) as N
        FROM temp
        WHERE toIgnore IS NULL
        GROUP BY SUBJECT_ID, MAP_TO
    ''')
    cursor.execute('CREATE INDEX JAX_phen_lab_before_diag_idx01 ON JAX_phen_lab_before_diag (N)')
    

def text_phenotype_before_diagnosis():
    cursor.execute('DROP TEMPORARY TABLE IF EXISTS JAX_phen_text_before_diag')
    cursor.execute('''
        CREATE TEMPORARY TABLE JAX_phen_text_before_diag
        WITH temp as (
            SELECT L.*, W.toIgnore
            FROM JAX_TEXTHPOPROFILE AS L
            JOIN JAX_mf_diag AS R ON L.SUBJECT_ID = R.SUBJECT_ID AND L.HADM_ID = R.HADM_ID
            LEFT JOIN JAX_encounters_after_diagnosis W on  L.SUBJECT_ID = W.SUBJECT_ID AND L.HADM_ID = W.HADM_ID
        )
        SELECT SUBJECT_ID, MAP_TO, COUNT(*) as N
        FROM temp
        WHERE toIgnore IS NULL
        GROUP BY SUBJECT_ID, MAP_TO
    ''')
    cursor.execute('CREATE INDEX JAX_phen_text_before_diag_idx01 ON JAX_phen_text_before_diag (N)')

In [None]:
#lab_phenotype_before_diagnosis()
lab_phenotype_vector = pd.read_sql_query('''
    SELECT * FROM JAX_phen_lab_before_diag
'''.format(1), mydb)
lab_phenotype_vector.head()

In [None]:
text_phenotype_before_diagnosis()
text_phenotype_vector = pd.read_sql_query('''
    SELECT * FROM JAX_phen_text_before_diag WHERE N >= {}
'''.format(1), mydb)
text_phenotype_vector.head()

In [None]:
phenotypes = lab_phenotype_vector.merge(text_phenotype_vector, on = ['SUBJECT_ID', 'MAP_TO'], how = 'outer').fillna(value = 0)
phenotypes['N'] = phenotypes['N_x'] + phenotypes['N_y']
phenotypes.head()

In [None]:
phenotypes_matrix = merged.loc[:, ['SUBJECT_ID', 'MAP_TO', 'N']].pivot_table(values='N', index='SUBJECT_ID', columns='MAP_TO', fill_value=0)

In [None]:
df = patients.merge(diagnosis_vector, on = 'SUBJECT_ID')
df['AGE'] = (pd.to_datetime(df.LAST_VISIT, format='%Y-%m-%d %H:%M:%S').dt.year - pd.to_datetime(df.DOB, format='%Y-%m-%d %H:%M:%S').dt.year)
df.head()


In [None]:
df = df.loc[:, ['SUBJECT_ID', 'GENDER', 'AGE', 'DIAGNOSED']].merge(phenotypes_matrix, on = 'SUBJECT_ID')

In [None]:
df.head()

In [None]:
df.columns

In [None]:
df.to_csv('ml_df_038_primary_only.csv', index=False)

In [None]:
X = patients.merge(m, on = 'SUBJECT_ID', how = 'left').sort_values(by = 'SUBJECT_ID').set_index('SUBJECT_ID')
X = X.drop('DOB', axis=1).fillna(value=0)
X.head(n = 3)

In [None]:
y = diagnosis_vector.sort_values(by = 'SUBJECT_ID').set_index('SUBJECT_ID')
y.head()

In [None]:
import sklearn
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.compose import make_column_transformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler

In [None]:
categorical = X.dtypes == object
categorical

In [None]:
sex = pd.get_dummies(X.GENDER)
X = X.drop('GENDER', axis=1).merge(sex, left_index=True, right_index=True).drop('F', axis=1)
X.M = X.M.astype(float)
X.head()

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X, y)

In [None]:
clf = LogisticRegression()

In [None]:
clf.fit(X_train, y_train)

In [None]:
clf.score(X_test, y_test)

In [None]:
np.sum(y_test.IsDiagnosed)/len(y_test)

In [26]:
cursor.close()
mydb.close()