In [1]:
import pandas as pd 
import numpy as np 
import pickle as pk

import os 

In [None]:
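# NOTE: `check_missing_for_mex` (defined below) reads the module-level `notes`
# DataFrame; uncomment these loads and set the real MIMIC path before calling it.
#  adm = pd.read_csv('/MIMIC_PATH/ADMISSIONS.csv.gz')
#  notes = pd.read_csv('/MIMIC_PATH/NOTEEVENTS.csv.gz')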

class Checker:
    """Cross-check a train/val/test split against the admissions that have inputs.

    The split is a mapping with 'train', 'val' and 'test' entries, each a
    DataFrame carrying an ``HADM_ID`` column.  The structured side supplies
    admission IDs through a ``hadm_id`` index level, the text side through an
    ``HADM_ID`` column.

    Attributes:
        split: the mapping passed to the constructor (kept as-is).
        train, val, test: the three folds taken from ``split``.
        all: the three folds concatenated in train, val, test order.
        s_hadms: list of unique admission IDs found on the structured side.
        t_hadms: list of unique admission IDs found on the text side.
    """

    def __init__(self, split, s_df, t_df):
        """Store the folds, collect both ID lists and print their sizes.

        Args:
            split: mapping with 'train', 'val' and 'test' DataFrames.
            s_df: structured input; see ``get_s_hadms`` for accepted shapes.
            t_df: text DataFrame with an ``HADM_ID`` column.
        """
        self.split = split

        self.train = split['train']
        self.val = split['val']
        self.test = split['test']

        self.all = pd.concat([self.train, self.val, self.test])

        self.s_hadms = self.get_s_hadms(s_df)
        self.t_hadms = self.get_t_hadms(t_df)

        print('struct hadms:', len(self.s_hadms))
        print('text hadms:', len(self.t_hadms))
        print()

    def trim_split(self, main_mode='text'):
        """Drop from every fold the admissions missing on the *other* modality.

        Args:
            main_mode: 'text' keeps rows whose HADM_ID is in ``s_hadms``;
                'struct' keeps rows whose HADM_ID is in ``t_hadms``.

        Returns:
            (new_split, missing_hadms): two dicts keyed like ``self.split``.
            ``new_split[k]`` is the filtered fold, ``missing_hadms[k]`` the
            list of HADM_IDs that were dropped from it.

        Raises:
            ValueError: if ``main_mode`` is neither 'text' nor 'struct'.
        """
        if main_mode == 'text':
            required = self.s_hadms
        elif main_mode == 'struct':
            required = self.t_hadms
        else:
            # Previously this fell through to the return statement and failed
            # with an UnboundLocalError that did not mention the bad argument.
            raise ValueError(
                f"main_mode must be 'text' or 'struct', got {main_mode!r}"
            )

        new_split = {}
        missing_hadms = {}
        for k, v in self.split.items():
            mask = v['HADM_ID'].isin(required)
            new_split[k] = v[mask]
            missing_hadms[k] = v.loc[~mask, 'HADM_ID'].tolist()

        return new_split, missing_hadms

    def check_all(self):
        """Print, for all rows and for each fold, the original/struct/text counts."""
        print('Cases, Orig, Struct, Text')
        folds = [
            ('All', self.all),
            ('Train', self.train),
            ('Val', self.val),
            ('Test', self.test),
        ]
        for name, fold in folds:
            self._check_fold(fold, name)
        print()

    def _check_fold(self, fold, name='Train'):
        """Print one line: fold size, rows with a struct ID, rows with a text ID."""
        orig = len(fold)
        struct = int(fold['HADM_ID'].isin(self.s_hadms).sum())
        text = int(fold['HADM_ID'].isin(self.t_hadms).sum())

        print(f"{name}: {orig}, {struct}, {text}")

    def get_s_hadms(self, s_df):
        """Return the unique ``hadm_id`` index values across the structured frames.

        Args:
            s_df: an iterable of DataFrames, a dict whose values are DataFrames,
                or a single DataFrame.  Every frame must have an index level
                named 'hadm_id'.

        Returns:
            A list of IDs in first-seen order with no repeats, so that its
            length is the number of distinct admissions.
        """
        if isinstance(s_df, pd.DataFrame):
            # Iterating a DataFrame yields column labels, not frames.
            frames = [s_df]
        elif isinstance(s_df, dict):
            frames = list(s_df.values())
        else:
            frames = s_df

        # A dict keeps insertion order, giving an order-preserving de-duplication
        # when the same admission appears in more than one frame.
        seen = {}
        for df in frames:
            ids = df.index.get_level_values('hadm_id').unique().tolist()
            seen.update(dict.fromkeys(ids))

        return list(seen)

    def get_t_hadms(self, t_df):
        """Return the unique values of the ``HADM_ID`` column of ``t_df`` as a list."""
        return t_df['HADM_ID'].unique().tolist()
    
    
    
def _print_split_nums(split):
    train = split['train']
    val = split['val']
    test = split['test']
    
    
    for s in [train, val, test]:
        print(len(s))
        
    print()
    

def check_missing_for_mex(missing_mex, notes_df=None):
    """Report which dropped admissions have any note row at all.

    For every fold in ``missing_mex`` this prints how many of its admission IDs
    occur in the notes table and how many do not, and collects the latter.

    Args:
        missing_mex: dict mapping a fold name to a list of HADM_IDs.
        notes_df: DataFrame with an ``HADM_ID`` column.  When omitted, the
            module-level ``notes`` DataFrame is used, as before.

    Returns:
        List of HADM_IDs (across all folds, in iteration order) that have no
        row in the notes table.

    Raises:
        NameError: if ``notes_df`` is omitted and no global ``notes`` exists.
    """
    if notes_df is None:
        try:
            notes_df = notes
        except NameError:
            raise NameError(
                "check_missing_for_mex needs a notes table: pass notes_df or "
                "define the global `notes` DataFrame first"
            ) from None

    # Build the lookup once instead of scanning the whole notes table for every
    # admission.  NaN IDs are dropped because they never compare equal anyway.
    noted_hadms = set(notes_df['HADM_ID'].dropna().unique().tolist())

    empty_hadms = []

    for name, hadms in missing_mex.items():
        print('Check', name)

        has_note, no_note = 0, 0
        for hadm in hadms:
            if hadm in noted_hadms:
                has_note += 1
            else:
                no_note += 1
                empty_hadms.append(hadm)

        # Escaped backslash: same printed text, without the invalid "\ " escape.
        print('later note \\ no note')
        print(has_note, no_note)

    return empty_hadms

# New cohort w/o missing input

In [25]:
# Load the pickled split object of each cohort. Each one is later handed to
# Checker as `split`, which indexes it by 'train' / 'val' / 'test'.
# NOTE(review): unpickling can execute arbitrary code — only load trusted files.
ms_drg = pd.read_pickle('splits_drg_ms.p')
apr_drg = pd.read_pickle('splits_drg_apr.p')
mex = pd.read_pickle('splits_mextract.p')

In [26]:
# Load the per-cohort objects passed to Checker as `s_df`. Checker.get_s_hadms
# iterates over each object and reads the 'hadm_id' index level of every element.
# Presumably each is a collection of hourly measurement DataFrames — TODO confirm.
s_ms_drg = pd.read_pickle('../measurements/drg_ms_hourly.p')
s_apr_drg = pd.read_pickle('../measurements/drg_apr_hourly.p')
s_mex = pd.read_pickle('../measurements/mextract_hourly.p')


In [27]:
# Load the per-cohort objects passed to Checker as `t_df`. Checker.get_t_hadms
# reads the unique values of their HADM_ID column.
t_ms_drg = pd.read_pickle('../notes_raw/drg_ms_df.p')
t_apr_drg = pd.read_pickle('../notes_raw/drg_apr_df.p')
t_mex = pd.read_pickle('../notes_raw/mextract_df.p')


In [82]:
# Build one Checker per cohort. Each constructor prints the number of admission
# IDs found on the structured side and on the text side.
mex_check = Checker(mex, s_mex, t_mex)

ms_check = Checker(ms_drg, s_ms_drg, t_ms_drg)

apr_check = Checker(apr_drg, s_apr_drg, t_apr_drg)

struct hadms: 23944
text hadms: 23661

struct hadms: 12845
text hadms: 19132

struct hadms: 17270
text hadms: 25371



In [83]:
# For each cohort, print per fold (All/Train/Val/Test): the fold size, the rows
# whose HADM_ID has structured data, and the rows whose HADM_ID has text data.
mex_check.check_all()

ms_check.check_all()

apr_check.check_all()

Cases, Orig, Struct, Text
All: 23944, 23944, 23661
Train: 16760, 16760, 16557
Val: 2394, 2394, 2372
Test: 4790, 4790, 4732

Cases, Orig, Struct, Text
All: 19132, 12845, 19132
Train: 16294, 10915, 16294
Val: 972, 645, 972
Test: 1866, 1285, 1866

Cases, Orig, Struct, Text
All: 25371, 17270, 25371
Train: 21610, 14679, 21610
Val: 1251, 862, 1251
Test: 2510, 1729, 2510



In [127]:
# Trim every split. The mextract split runs in 'struct' mode (rows are kept when
# their HADM_ID is among the text IDs); both DRG splits run in 'text' mode (rows
# are kept when their HADM_ID is among the structured IDs). The second return
# value holds, per fold, the HADM_IDs that were dropped.
new_mex, missing_mex = mex_check.trim_split('struct')

new_ms, missing_ms = ms_check.trim_split('text')

new_apr, missing_apr = apr_check.trim_split('text')

# Print the train / val / test sizes of each trimmed split.
_print_split_nums(new_mex)
_print_split_nums(new_ms)
_print_split_nums(new_apr)


16557
2372
4732

10915
645
1285

14679
862
1729



In [120]:
# Count, per fold, how many dropped mextract admissions still have rows in the
# notes table; `empty` collects the HADM_IDs that have none.
# NOTE(review): this call relies on the module-level `notes` DataFrame, whose
# load is commented out near the top of the notebook.
empty = check_missing_for_mex(missing_mex)

Check train
later note \ no note
91 112
Check val
later note \ no note
13 9
Check test
later note \ no note
18 40


In [129]:
# Save the trimmed splits: each dict returned by trim_split is pickled to its own
# 'trim_splits_*.p' file in the current working directory (existing files are
# overwritten because the files are opened in 'wb' mode).

with open('trim_splits_mextract.p', 'wb') as outf:
    pk.dump(new_mex, outf)
    
with open('trim_splits_drg_ms.p', 'wb') as outf:
    pk.dump(new_ms, outf)
    
with open('trim_splits_drg_apr.p', 'wb') as outf:
    pk.dump(new_apr, outf)  
    


  exec(code_obj, self.user_global_ns, self.user_ns)
