In [63]:
!pip install duckdb




[notice] A new release of pip is available: 23.1.2 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [64]:
#################################################################
# Code written by Sajad Darabi (sajad.darabi@cs.ucla.edu)
# For bug report, please contact author using the email address
#################################################################

from datetime import datetime
from collections import Counter, OrderedDict
import math
import pickle
import pandas as pd
import os
import sys
# from data_loader.utils.vocab import Vocab
from tqdm import tqdm
import duckdb



def convert_to_med2vec(patient_data):
    data = []
    for k, vv in patient_data.items():
        for v in vv:
            data.append(v[0])
        data.append([-1])
    return data

class Vocab:
    
    def __init__(self):
        self.vocabulary = {}
        self.codes = []
    
    def convert_to_icd9(self, dxStr):
        if dxStr.startswith('E'):
            if len(dxStr) > 4: return dxStr[:4] + '.' + dxStr[4:]
            else: return dxStr
        else:
            if len(dxStr) > 3: return dxStr[:3] + '.' + dxStr[3:]
            else: return dxStr

    def convert_to_ids(self, codes_list, icd9=True):
        # print(codes_list)
        res = []
        for code in codes_list:
            if icd9:
                standard_code = self.convert_to_icd9(code)
            else:
                standard_code = code

            if standard_code in self.vocabulary:
                res.append(self.vocabulary[standard_code])
            
            else:
                self.codes.append(standard_code)
                self.vocabulary[standard_code] = len(self.codes) # Should be len(self.vocabulary)? 
                res.append(self.vocabulary[standard_code])
        return res

class Duck:
    SELECT = """SELECT {col} from read_parquet({table})
                    WHERE HADM_ID = {hadm_id}"""
    PROCEDURE_COLUMN = "ICD9_CODE"
    PRESCRIPTION_COLUMN = "NDC"
    
    def execute(self, query):
        con = duckdb.connect()
        data = con.execute(query=query).df()
        con.close()
        return data
        

In [65]:
diagnosis_file = 'DIAGNOSES_ICD.csv'
admission_file = 'ADMISSIONS.csv'
procedure_file = 'PROCEDURES_ICD.parquet'
prescription_file = 'PRESCRIPTIONS.parquet'
outfile = 'model.seq'
med2vec_format = False

if (len(sys.argv) > 4 and 'med2vec' == sys.argv[5]): #(len(sys.argv) > 4):
    med2vec_format = True
df_diagnosis = pd.read_csv(diagnosis_file)
df_admission = pd.read_csv(admission_file)

df_admission

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,ADMITTIME,DISCHTIME,DEATHTIME,ADMISSION_TYPE,ADMISSION_LOCATION,DISCHARGE_LOCATION,INSURANCE,LANGUAGE,RELIGION,MARITAL_STATUS,ETHNICITY,EDREGTIME,EDOUTTIME,DIAGNOSIS,HOSPITAL_EXPIRE_FLAG,HAS_CHARTEVENTS_DATA
0,21,22,165315,2196-04-09 12:26:00,2196-04-10 15:54:00,,EMERGENCY,EMERGENCY ROOM ADMIT,DISC-TRAN CANCER/CHLDRN H,Private,,UNOBTAINABLE,MARRIED,WHITE,2196-04-09 10:06:00,2196-04-09 13:24:00,BENZODIAZEPINE OVERDOSE,0,1
1,22,23,152223,2153-09-03 07:15:00,2153-09-08 19:10:00,,ELECTIVE,PHYS REFERRAL/NORMAL DELI,HOME HEALTH CARE,Medicare,,CATHOLIC,MARRIED,WHITE,,,CORONARY ARTERY DISEASE\CORONARY ARTERY BYPASS...,0,1
2,23,23,124321,2157-10-18 19:34:00,2157-10-25 14:00:00,,EMERGENCY,TRANSFER FROM HOSP/EXTRAM,HOME HEALTH CARE,Medicare,ENGL,CATHOLIC,MARRIED,WHITE,,,BRAIN MASS,0,1
3,24,24,161859,2139-06-06 16:14:00,2139-06-09 12:48:00,,EMERGENCY,TRANSFER FROM HOSP/EXTRAM,HOME,Private,,PROTESTANT QUAKER,SINGLE,WHITE,,,INTERIOR MYOCARDIAL INFARCTION,0,1
4,25,25,129635,2160-11-02 02:06:00,2160-11-05 14:55:00,,EMERGENCY,EMERGENCY ROOM ADMIT,HOME,Private,,UNOBTAINABLE,MARRIED,WHITE,2160-11-02 01:01:00,2160-11-02 04:27:00,ACUTE CORONARY SYNDROME,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
58971,58594,98800,191113,2131-03-30 21:13:00,2131-04-02 15:02:00,,EMERGENCY,CLINIC REFERRAL/PREMATURE,HOME,Private,ENGL,NOT SPECIFIED,SINGLE,WHITE,2131-03-30 19:44:00,2131-03-30 22:41:00,TRAUMA,0,1
58972,58595,98802,101071,2151-03-05 20:00:00,2151-03-06 09:10:00,2151-03-06 09:10:00,EMERGENCY,CLINIC REFERRAL/PREMATURE,DEAD/EXPIRED,Medicare,ENGL,CATHOLIC,WIDOWED,WHITE,2151-03-05 17:23:00,2151-03-05 21:06:00,SAH,1,1
58973,58596,98805,122631,2200-09-12 07:15:00,2200-09-20 12:08:00,,ELECTIVE,PHYS REFERRAL/NORMAL DELI,HOME HEALTH CARE,Private,ENGL,NOT SPECIFIED,MARRIED,WHITE,,,RENAL CANCER/SDA,0,1
58974,58597,98813,170407,2128-11-11 02:29:00,2128-12-22 13:11:00,,EMERGENCY,EMERGENCY ROOM ADMIT,SNF,Private,ENGL,CATHOLIC,MARRIED,WHITE,2128-11-10 23:48:00,2128-11-11 03:16:00,S/P FALL,0,0


In [66]:
full_digit_icd9 = True #flag to extrat short
# REMOVE 'ORGAN DONOR ACCOUNT' , 'DONOR ACCOUNT' , AND 'ORGAN DONOR' DIAGNOSIS ROWS
REMOVE_DIAGNOSIS = ~((df_admission['DIAGNOSIS'] == 'ORGAN DONOR ACCOUNT') | (df_admission['DIAGNOSIS'] == 'ORGAN DONOR') | \
                    (df_admission['DIAGNOSIS'] == 'DONOR ACCOUNT'))
df = df_admission[REMOVE_DIAGNOSIS]

patient_data = {}
patient_id = set(df['SUBJECT_ID'])
vocab = Vocab()


In [67]:
vocab.vocabulary

{}

In [68]:
duck_instance = Duck()

path_to_parquet = r"PROCEDURES_ICD.parquet"


query = f"SELECT * FROM read_parquet('{path_to_parquet}') LIMIT 10"

data = duck_instance.execute(query)

print(data)


   ROW_ID  SUBJECT_ID  HADM_ID  SEQ_NUM  ICD9_CODE
0     944       62641   154460        3       3404
1     945        2592   130856        1       9671
2     946        2592   130856        2       3893
3     947       55357   119355        1       9672
4     948       55357   119355        2        331
5     949       55357   119355        3       3893
6     950        9545   158060        1         34
7     951       28600   189217        1       3613
8     952       28600   189217        2       3615
9     953       28600   189217        3       3961


In [69]:
duck_instance = Duck()

path_to_parquet = r"PRESCRIPTIONS.parquet"


query = f"SELECT * FROM read_parquet('{path_to_parquet}') LIMIT 10"

data = duck_instance.execute(query)

print(data)


    ROW_ID  SUBJECT_ID  HADM_ID  ICUSTAY_ID            STARTDATE  \
0  2214776           6   107064         NaN  2175-06-11 00:00:00   
1  2214775           6   107064         NaN  2175-06-11 00:00:00   
2  2215524           6   107064         NaN  2175-06-11 00:00:00   
3  2216265           6   107064         NaN  2175-06-11 00:00:00   
4  2214773           6   107064         NaN  2175-06-11 00:00:00   
5  2214774           6   107064         NaN  2175-06-11 00:00:00   
6  2215525           6   107064         NaN  2175-06-12 00:00:00   
7  2216266           6   107064         NaN  2175-06-12 00:00:00   
8  2215526           6   107064         NaN  2175-06-12 00:00:00   
9  2214778           6   107064         NaN  2175-06-12 00:00:00   

               ENDDATE DRUG_TYPE            DRUG DRUG_NAME_POE  \
0  2175-06-12 00:00:00      MAIN      Tacrolimus    Tacrolimus   
1  2175-06-12 00:00:00      MAIN        Warfarin      Warfarin   
2  2175-06-12 00:00:00      MAIN  Heparin Sodium     

In [70]:
for pid in tqdm(patient_id):
    pid_df = df[df['SUBJECT_ID'] == pid]
    if (len(pid_df) < 2):
        continue
    adm_list = pid_df[['HADM_ID', 'ADMITTIME', 'DEATHTIME']] # add DISCHATIME ?
    patient_data[pid] = []
    for i, r in adm_list.iterrows():
        admid = r['HADM_ID']
        admitime = r['ADMITTIME']
        icd9_raw = df_diagnosis[df_diagnosis['HADM_ID'] == admid]['ICD9_CODE'].values
        icd9_raw = list(map(str, icd9_raw))
        # duck = Duck()
        # pro_query = duck.SELECT.format(
        #     col = duck.PROCEDURE_COLUMN,
        #     table = procedure_file,
        #     hadm_id =  admid
        # )
        # procedure =  duck.execute(pro_query)
        # # print(procedure)
        # procedure_icd9_raw  = list([str(i[0]) for i in procedure.values])
        # # print(procedure_icd9_raw)
        
        # ### ADD PRESCRIPTION CODES ###
        # pres_query = duck.SELECT.format(
        #     col = duck.PRESCRIPTION_COLUMN,
        #     table = prescription_file,
        #     hadm_id = admid
        # )
        
        # prescription = duck.execute(pres_query)
        # # print(prescription)
        # prescription_ndc_raw = list([str(i[0]) for i in prescription.values])
        # # print(prescription_ndc_raw)
        
        # icd9_raw.extend(procedure_icd9_raw)
        icd9 = vocab.convert_to_ids(icd9_raw)
        # print(icd9)
        
        # prescription_ndc = vocab.convert_to_ids(prescription_ndc_raw, icd9=False)
        # # print(prescription_ndc)
        
        # icd9.extend(prescription_ndc) #ALL CODES 
        # # print(icd9)
        # ###############################################
        
        
        mortality = r['DEATHTIME'] == r['DEATHTIME'] # check not nan
        admtime = datetime.strptime(r['ADMITTIME'], '%Y-%m-%d %H:%M:%S') # TODO: convert date time to integers.. ?!?
        tup = (icd9, admtime, mortality)
        patient_data[pid].append(tup)
        
        ## {10: [(), (), ()]}
        ## [[()], [()], [()]]
        ## Add -1 after interval 
        
    


100%|██████████| 46518/46518 [00:36<00:00, 1262.89it/s]


In [71]:
len(vocab.vocabulary)

4876

In [72]:
patient_data

{17: [([1, 2, 3, 4], datetime.datetime(2134, 12, 27, 7, 15), False),
  ([5, 6, 7, 8, 9, 10, 11, 4], datetime.datetime(2135, 5, 9, 14, 11), False)],
 21: [([12, 7, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
   datetime.datetime(2134, 9, 11, 12, 17),
   False),
  ([29,
    30,
    15,
    19,
    31,
    6,
    32,
    33,
    34,
    35,
    36,
    37,
    22,
    38,
    28,
    26,
    21,
    39,
    40,
    41,
    42],
   datetime.datetime(2135, 1, 30, 20, 50),
   True)],
 23: [([21, 43, 44, 45, 4, 46, 47, 48],
   datetime.datetime(2153, 9, 3, 7, 15),
   False),
  ([49, 50, 51, 44, 46, 25, 4, 52, 53, 54],
   datetime.datetime(2157, 10, 18, 19, 34),
   False)],
 34: [([12, 16, 55, 19, 56, 57, 21, 58],
   datetime.datetime(2186, 7, 18, 16, 46),
   False),
  ([59, 60, 61, 21, 62, 16, 41, 63],
   datetime.datetime(2191, 2, 23, 5, 23),
   False)],
 36: [([21, 43, 64, 46, 65, 66, 47, 67, 68],
   datetime.datetime(2131, 4, 30, 7, 15),
   False),
  ([69, 70, 71, 72, 

In [73]:
# outdir = os.path.abspath(os.path.curdir)
# if not os.path.exists(os.path.join(outdir, 'data')):
#     os.mkdir(os.path.join(outdir, 'data'))
# outfile = os.path.join(outdir, 'data', outfile)
# if (med2vec_format):
#     patient_data = convert_to_med2vec(patient_data)
# if (med2vec_format):
#     pickle.dump(patient_data, open('med2vec.seqs', 'wb'), -1)
#     pickle.dump(vocab, open('med2vec.vocab', 'wb'), -1)
# else:
#     pickle.dump(patient_data, open(outfile + '_mimic_iii.seqs', 'wb'), -1)
#     pickle.dump(vocab, open(outfile + '.vocab', 'wb'), -1)


In [74]:
train_list = []

In [75]:
for visits in list(patient_data.values()):
    visits = sorted([(visit[1], visit[0]) for visit in visits])
    
    for visit in visits:
        train_list.append(visit[1])
    
    train_list.append([-1])



In [76]:
# train_list

In [77]:
len(train_list)

27405

In [78]:
train_list[:5]

[[1, 2, 3, 4],
 [5, 6, 7, 8, 9, 10, 11, 4],
 [-1],
 [12, 7, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
 [29,
  30,
  15,
  19,
  31,
  6,
  32,
  33,
  34,
  35,
  36,
  37,
  22,
  38,
  28,
  26,
  21,
  39,
  40,
  41,
  42]]

In [79]:
import torch
import torch.utils.data as data
import os
import pickle
import numpy as np

class Med2VecDataset(data.Dataset):

    def __init__(self, num_codes, train=True, transform=None, target_transform=None, download=False):
        # self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        print("Med2VecDataset.train ", train)
        self.train = train
        self.num_codes = num_codes
        if download:
            raise ValueError('cannot download')

        self.train_data = train_list
        print("Length of Training Data", self.train_data.__len__())
        # self.correct_train_data()
        
        self.test_data = []

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)
        
    # def get_codes(self, val):
    #     return list(map(lambda x: x[0], val))
        
    # def correct_train_data(self):
    #     res = []
    #     for data in self.train_data:
    #         codes_list = self.get_codes(data) ## [[1st visit codes], [second]] = a
    #         ## res = [[], [-1], a]
    #         res.extend(codes_list)
    #         res.append([-1])
        
    #     self.train_data = res 
        

    def __getitem__(self, index):
        if index >= len(self.train_data):
            raise IndexError(f"Index {index} out of range for dataset of length {len(self.train_data)}")
        
        x, ivec, jvec, d = self.preprocess(self.train_data[index])
        return x, ivec, jvec, d

    def preprocess(self, seq):
        """ create one hot vector of idx in seq, with length self.num_codes

            Args:
                seq: list of ideces where code should be 1

            Returns:
                x: one hot vector
                ivec: vector for learning code representation
                jvec: vector for learning code representation
        """
        x = torch.zeros((self.num_codes, ), dtype=torch.long)

        ivec = []
        jvec = []
        d = []
        if seq == [-1]:
            return x, torch.LongTensor(ivec), torch.LongTensor(jvec), torch.tensor(d, dtype=torch.int32)
        

        x[seq] = 1
        for i in seq:
            for j in seq:
                if i == j:
                    continue
                ivec.append(i)
                jvec.append(j)
                
        
        return x, torch.LongTensor(ivec), torch.LongTensor(jvec), torch.tensor(d, dtype=torch.int32)



In [80]:
x = torch.zeros((8, ), dtype=torch.long)
x

tensor([0, 0, 0, 0, 0, 0, 0, 0])

In [81]:
seq = train_list[0]

In [82]:
seq

[1, 2, 3, 4]

In [83]:
x[[s-1 for s in seq]] = 1
x

tensor([1, 1, 1, 1, 0, 0, 0, 0])

In [84]:
num_codes = 4876
dataset_med2vec = Med2VecDataset(num_codes, True)

Med2VecDataset.train  True
Length of Training Data 27405


In [85]:
dataset_med2vec.__getitem__(9)

(tensor([0, 0, 0,  ..., 0, 0, 0]),
 tensor([12, 12, 12, 12, 12, 12, 12, 16, 16, 16, 16, 16, 16, 16, 55, 55, 55, 55,
         55, 55, 55, 19, 19, 19, 19, 19, 19, 19, 56, 56, 56, 56, 56, 56, 56, 57,
         57, 57, 57, 57, 57, 57, 21, 21, 21, 21, 21, 21, 21, 58, 58, 58, 58, 58,
         58, 58]),
 tensor([16, 55, 19, 56, 57, 21, 58, 12, 55, 19, 56, 57, 21, 58, 12, 16, 19, 56,
         57, 21, 58, 12, 16, 55, 56, 57, 21, 58, 12, 16, 55, 19, 57, 21, 58, 12,
         16, 55, 19, 56, 21, 58, 12, 16, 55, 19, 56, 57, 58, 12, 16, 55, 19, 56,
         57, 21]),
 tensor([], dtype=torch.int32))

In [86]:
import logging
import torch.nn as nn
import numpy as np


class BaseModel(nn.Module):
    """
    Base class for all models
    """
    def __init__(self):
        super(BaseModel, self).__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

    def forward(self, *input):
        """
        Forward pass logic

        :return: Model output
        """
        raise NotImplementedError

    def summary(self):
        """
        Model summary
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        self.logger.info('Trainable parameters: {}'.format(params))
        self.logger.info(self)

    def __str__(self):
        """
        Model prints with number of trainable parameters
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters if p is not None])
        return super(BaseModel, self).__str__() + '\nTrainable parameters: {}'.format(params)
        # print(super(BaseModel, self))




In [87]:
# !pip install tensorboardX

base -> Utils.py

In [88]:
import torch
import torchvision.utils as vutils
from tensorboardX import SummaryWriter

class TensorboardWriter:
    def __init__(self, log_dir, logger, enabled=True):
        self.writer = SummaryWriter(log_dir) if enabled else None
        self.log_dir = log_dir
        self.step = 0
        self.mode = ''
        self.tensorboard_enabled = enabled
        self.logger = logger

    def set_step(self, step, mode='train'):
        self.step = step
        self.mode = mode

    def add_scalar(self, tag, value, step=None):
        if self.tensorboard_enabled:
            if step is None:
                step = self.step
            self.writer.add_scalar(f'{self.mode}/{tag}', value, step)
            self.logger.info(f'Scalar Summary - {tag}: {value} at step {step}')

    def add_image(self, tag, images, step=None):
        if self.tensorboard_enabled:
            if step is None:
                step = self.step
            img_grid = vutils.make_grid(images)
            self.writer.add_image(f'{self.mode}/{tag}', img_grid, step)
            self.logger.info(f'Image Summary - {tag} at step {step}')

    def histogram_summary(self, tag, values, step=None):
        if self.tensorboard_enabled:
            if step is None:
                step = self.step
            self.writer.add_histogram(f'{self.mode}/{tag}', values, step)
            self.logger.info(f'Histogram Summary - {tag} at step {step}')

    def close(self):
        if self.tensorboard_enabled:
            self.writer.close()
            self.logger.info('Tensorboard writer closed')

    def __del__(self):
        self.close()



base -> base_trainer.py

In [89]:
import os
import math
import json
import logging
import datetime
import torch
# from utils.util import ensure_dir
# from .utils import TensorboardWriter
logging.basicConfig(filename="all_logs.logs")
def ensure_dir(dir):
    if not os.path.isdir(dir):
        os.makedirs(dir)


class BaseTrainer:
    """
    Base class for all trainers
    """
    def __init__(self, model, loss, metrics, optimizer, resume, config, train_logger=None):
        self.config = config
        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.setLevel(logging.DEBUG)

        # setup GPU device if available, move model into configured device
        self.device, device_ids = self._prepare_device(config['n_gpu'])
        self.model = model.to(self.device)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(model, device_ids=device_ids)

        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer
        self.train_logger = train_logger

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.verbosity = cfg_trainer['verbosity']
        self.monitor = cfg_trainer.get('monitor', 'off')

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = math.inf if self.mnt_mode == 'min' else -math.inf
            self.early_stop = cfg_trainer.get('early_stop', math.inf)
        
        self.start_epoch = 1

        # setup directory for checkpoint saving
        start_time = datetime.datetime.now().strftime('%m%d_%H%M%S')
        self.checkpoint_dir = os.path.join(cfg_trainer['save_dir'], config['name'], start_time)
        # setup visualization writer instance
        writer_dir = os.path.join(cfg_trainer['log_dir'], config['name'], start_time)
        # self.writer = WriterTensorboardX(writer_dir, self.logger, cfg_trainer['tensorboardX'])
        self.writer = TensorboardWriter(writer_dir, self.logger, cfg_trainer['tensorboardX'] )

        # Save configuration file into checkpoint directory:
        ensure_dir(self.checkpoint_dir)
        config_save_path = os.path.join(self.checkpoint_dir, 'config.json')
        with open(config_save_path, 'w') as handle:
            json.dump(config, handle, indent=4, sort_keys=False)

        if resume:
            self._resume_checkpoint(resume)
    
    def _prepare_device(self, n_gpu_use):
        """ 
        setup GPU device if available, move model into configured device
        """ 
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            self.logger.warning("Warning: There\'s no GPU available on this machine, training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        list_ids = list(range(n_gpu_use))
        return device, list_ids

    def train(self):
        """
        Full training logic
        """
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)
            
            # save logged informations into log dict
            log = {'epoch': epoch}
            for key, value in result.items():
                if key == 'metrics':
                    log.update({mtr.__name__ : value[i] for i, mtr in enumerate(self.metrics)})
                elif key == 'val_metrics':
                    log.update({'val_' + mtr.__name__ : value[i] for i, mtr in enumerate(self.metrics)})
                else:
                    log[key] = value

            # print logged informations to the screen
            if self.train_logger is not None:
                # self.train_logger.add_entry(log)
                if self.verbosity >= 1:
                    for key, value in log.items():
                        self.logger.info('    {:15s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] < self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] > self.mnt_best)
                except KeyError:
                    self.logger.warning("Warning: Metric '{}' is not found. Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False
                    not_improved_count = 0

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info("Validation performance didn\'t improve for {} epochs. Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)
            

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'logger': self.train_logger,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            'config': self.config
        }
        filename = os.path.join(self.checkpoint_dir, 'checkpoint-epoch{}.pth'.format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: {} ...".format('model_best.pth'))

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['arch'] != self.config['arch']:
            self.logger.warning('Warning: Architecture configuration given in config file is different from that of checkpoint. ' + \
                                'This may yield an exception while state_dict is being loaded.')
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed. 
        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning('Warning: Optimizer type given in config file is different from that of checkpoint. ' + \
                                'Optimizer parameters not being resumed.')
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
    
        self.train_logger = checkpoint['logger']
        self.logger.info("Checkpoint '{}' (epoch {}) loaded".format(resume_path, self.start_epoch))


data_loader -> data_loaders.py

In [90]:
def collate_fn(data):
    """ Creates mini-batch from x, ivec, jvec tensors

    We should build custom collate_fn, as the ivec, and jvec have varying lengths. These should be appended
    in row form

    Args:
        data: list of tuples contianing (x, ivec, jvec)

    Returns:
        x: one hot encoded vectors stacked vertically
        ivec: long vector
        jvec: long vector
    """
    print(data)
    x, ivec, jvec, d = zip(*data)
    x = torch.stack(x, dim=0)
    mask = torch.sum(x, dim=1) > 0
    mask = mask[:, None]
    ivec = torch.cat(ivec, dim=0)
    jvec = torch.cat(jvec, dim=0)
    d = torch.stack(d, dim=0)
    
    # print(len(x), torch.sum(x, dim=1))

    return x, ivec, jvec, mask, d


In [91]:

def get_loader(num_codes, train=True, transform=None, target_transform=None, download=False, batch_size=1000):
    """ returns torch.utils.data.DataLoader for Med2Vec dataset """
    med2vec = dataset_med2vec
    data_loader = torch.utils.data.DataLoader(dataset=med2vec, batch_size=batch_size, shuffle=False, num_workers=1, collate_fn=collate_fn)
    return data_loader


In [92]:
dataset_med2vec.train_data

[[1, 2, 3, 4],
 [5, 6, 7, 8, 9, 10, 11, 4],
 [-1],
 [12, 7, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
 [29,
  30,
  15,
  19,
  31,
  6,
  32,
  33,
  34,
  35,
  36,
  37,
  22,
  38,
  28,
  26,
  21,
  39,
  40,
  41,
  42],
 [-1],
 [21, 43, 44, 45, 4, 46, 47, 48],
 [49, 50, 51, 44, 46, 25, 4, 52, 53, 54],
 [-1],
 [12, 16, 55, 19, 56, 57, 21, 58],
 [59, 60, 61, 21, 62, 16, 41, 63],
 [-1],
 [21, 43, 64, 46, 65, 66, 47, 67, 68],
 [69, 70, 71, 72, 73, 64, 21, 52, 46, 47, 66, 67],
 [74, 71, 75, 76, 77, 78, 79, 64, 46, 47, 80, 67],
 [-1],
 [81, 82, 83, 13, 84, 85, 86, 87, 88],
 [81, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 8, 19, 100, 101, 102],
 [-1],
 [103, 104, 46, 41, 84, 9],
 [105, 41, 46, 106, 9],
 [-1],
 [107, 76, 55, 108, 16, 14, 109, 85, 89, 110, 38, 111, 112],
 [107, 76, 55, 93, 14, 113, 109, 89, 83, 16, 38],
 [-1],
 [114],
 [114, 115, 116, 117],
 [-1],
 [44, 81, 118, 70, 119, 80, 66, 47, 120, 21, 121],
 [94,
  76,
  122,
  123,
  19,
  79,
  124,
  4,

In [93]:
def collate_fn(data):
    """ Creates mini-batch from x, ivec, jvec tensors

    We should build custom collate_fn, as the ivec, and jvec have varying lengths. These should be appended
    in row form

    Args:
        data: list of tuples contianing (x, ivec, jvec)

    Returns:
        x: one hot encoded vectors stacked vertically
        ivec: long vector
        jvec: long vector
    """
    # print(data)
    x, ivec, jvec, d = zip(*data)
    # print(x)
    x = torch.stack(x, dim=0)
    # print(x)
    y  = torch.sum(x, dim = 1)
    # print(y)
    mask = torch.sum(x, dim=1) > 0
    # print(mask)
    mask = mask[:, None]
    # print(mask)
    # print(ivec)
    ivec = torch.cat(ivec, dim=0)
    # print(ivec)
    jvec = torch.cat(jvec, dim=0)
    d = torch.stack(d, dim=0)
    
    # print(len(x), torch.sum(x, dim=1))

    return x, ivec, jvec, mask, d


In [94]:

data_loader = torch.utils.data.DataLoader(
    dataset=dataset_med2vec,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn
)

# Run the DataLoader to trigger collate_fn and see the output
for batch in data_loader:
    x, ivec, jvec, mask, d = batch
    print("Batch output:")
    print("x:", x)
    print("ivec:", ivec)
    print("jvec:", jvec)
    print("mask:", mask)
    print("d:", d)
    print("---")
    break


Batch output:
x: tensor([[0, 1, 1,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
ivec: tensor([ 1,  1,  1,  2,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,  5,  5,  5,  5,
         5,  6,  6,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7,  7,  7,  8,  8,  8,
         8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10, 10,
        11, 11, 11, 11, 11, 11, 11,  4,  4,  4,  4,  4,  4,  4, 12, 12, 12, 12,
        12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,  7,  7,  7,  7,  7,
         7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7, 13, 13, 13, 13, 13, 13,
        13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15,
        15, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16,
        16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17,
        17, 17, 17, 17, 17, 17, 17, 18, 

In [95]:
train_list[:2]

[[1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11, 4]]

In [96]:
# collate_fn(dataset_med2vec.train_data)

In [97]:
dataset_med2vec

<__main__.Med2VecDataset at 0x2cf844b7390>

In [98]:
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.sampler import SequentialSampler


class BaseDataLoader(DataLoader):
    """
    Base class for all data loaders
    """
    def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate):
        self.validation_split = validation_split
        self.shuffle = shuffle

        self.batch_idx = 0
        self.n_samples = len(dataset)

        self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)

        self.init_kwargs = {
            'dataset': dataset,
            'batch_size': batch_size,
            'shuffle': self.shuffle,
            'collate_fn': collate_fn,
            'num_workers': num_workers
            }
        super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)

    def _split_sampler(self, split):
        if split == 0.0:
            return None, None
        idx_full = np.arange(self.n_samples)
        # shuffle indexes only if shuffle is true
        # added for med2vec dataset where order matters
        if (self.shuffle):
            np.random.seed(0)
            np.random.shuffle(idx_full)

        len_valid = int(self.n_samples * split)

        valid_idx = idx_full[0:len_valid]
        train_idx = np.delete(idx_full, np.arange(0, len_valid))
        self.shuffle = False
        if (self.shuffle):
            train_sampler = SubsetRandomSampler(train_idx)
            valid_sampler = SubsetRandomSampler(valid_idx)
        else:
            train_sampler = SequentialSampler(train_idx)
            valid_sampler = SequentialSampler(valid_idx)

        # turn off shuffle option which is mutually exclusive with sampler
        
        self.n_samples = len(train_idx)

        return train_sampler, valid_sampler

    def split_validation(self):
        if self.valid_sampler is None:
            return None
        else:
            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
        



In [99]:
from torchvision import datasets
# from .med2vec_dataset import Med2VecDataset
# from .med2vec_dataset import collate_fn as med2vec_collate

# from base import BaseDataLoader

import os
import pickle

class Med2VecDataLoader(BaseDataLoader):
    """
    Med2Vec Dataloader
    """
    def __init__(self, data_dir, num_codes, batch_size, shuffle, validation_split, num_workers, med=False, diag=False, proc=False, file_name=None, training=True, dict_format=False):
        
        self.num_codes = num_codes
        
        self.train = training
                
        self.dataset = dataset_med2vec
        
        super(Med2VecDataLoader, self).__init__(dataset_med2vec, batch_size, shuffle, validation_split, num_workers, collate_fn=collate_fn)


In [100]:
import json

In [101]:
config = json.load(open('./configs/config.json'))

In [102]:
config['data_loader']

{'type': 'Med2VecDataLoader',
 'args': {'data_dir': './',
  'batch_size': 2,
  'num_codes': 4876,
  'shuffle': False,
  'validation_split': 0.1,
  'num_workers': 0,
  'training': True}}

In [103]:
data_loader_args = config['data_loader']['args']

In [104]:
data_loader = Med2VecDataLoader(**data_loader_args)

In [105]:
for batch in data_loader:
    x, ivec, jvec, mask, d = batch
    print("Batch output:")
    print("x:", x)
    print("ivec:", ivec)
    print("jvec:", jvec)
    print("mask:", mask)
    print("d:", d)
    print("---")
    break


Batch output:
x: tensor([[0, 1, 1,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
ivec: tensor([ 1,  1,  1,  2,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,  5,  5,  5,  5,
         5,  6,  6,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7,  7,  7,  8,  8,  8,
         8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10, 10,
        11, 11, 11, 11, 11, 11, 11,  4,  4,  4,  4,  4,  4,  4])
jvec: tensor([ 2,  3,  4,  1,  3,  4,  1,  2,  4,  1,  2,  3,  6,  7,  8,  9, 10, 11,
         4,  5,  7,  8,  9, 10, 11,  4,  5,  6,  8,  9, 10, 11,  4,  5,  6,  7,
         9, 10, 11,  4,  5,  6,  7,  8, 10, 11,  4,  5,  6,  7,  8,  9, 11,  4,
         5,  6,  7,  8,  9, 10,  4,  5,  6,  7,  8,  9, 10, 11])
mask: tensor([[True],
        [True]])
d: tensor([], size=(2, 0), dtype=torch.int32)
---


In [106]:
def get_loader(root, num_codes, train=True, transform=None, target_transform=None, download=False, batch_size=10):
    """ returns torch.utils.data.DataLoader for Med2Vec dataset """
    med2vec = dataset_med2vec
    data_loader = torch.utils.data.DataLoader(dataset=med2vec, batch_size=batch_size, shuffle=False, num_workers=1, collate_fn=collate_fn)
    return data_loader


In [107]:
data_load = get_loader(4876, dataset_med2vec)

In [108]:
next(iter(data_loader))

(tensor([[0, 1, 1,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 tensor([ 1,  1,  1,  2,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,  5,  5,  5,  5,
          5,  6,  6,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7,  7,  7,  8,  8,  8,
          8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10, 10,
         11, 11, 11, 11, 11, 11, 11,  4,  4,  4,  4,  4,  4,  4]),
 tensor([ 2,  3,  4,  1,  3,  4,  1,  2,  4,  1,  2,  3,  6,  7,  8,  9, 10, 11,
          4,  5,  7,  8,  9, 10, 11,  4,  5,  6,  8,  9, 10, 11,  4,  5,  6,  7,
          9, 10, 11,  4,  5,  6,  7,  8, 10, 11,  4,  5,  6,  7,  8,  9, 11,  4,
          5,  6,  7,  8,  9, 10,  4,  5,  6,  7,  8,  9, 10, 11]),
 tensor([[True],
         [True]]),
 tensor([], size=(2, 0), dtype=torch.int32))

In [109]:
# valid_data_loader = data_loader.split_validation()

In [110]:
# valid_data_loader

In [111]:
import logging
import torch.nn as nn
import numpy as np


class BaseModel(nn.Module):
    """
    Base class for all models
    """
    def __init__(self):
        super(BaseModel, self).__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

    def forward(self, *input):
        """
        Forward pass logic

        :return: Model output
        """
        raise NotImplementedError

    def summary(self):
        """
        Model summary
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        self.logger.info('Trainable parameters: {}'.format(params))
        self.logger.info(self)

    def __str__(self):
        """
        Model prints with number of trainable parameters
        """
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters if p is not None])
        return super(BaseModel, self).__str__() + '\nTrainable parameters: {}'.format(params)
        # print(super(BaseModel, self))



In [112]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from base import BaseModel
__all__ = ['Med2Vec']

class Med2Vec(BaseModel):
    def __init__(self, icd9_size, demographics_size=0, embedding_size=2000, hidden_size=100,):
        super(Med2Vec, self).__init__()
        self.embedding_size = embedding_size
        self.demographics_size = demographics_size
        self.hidden_size = hidden_size
        self.vocabulary_size = icd9_size
        self.embedding_demo_size = self.embedding_size + self.demographics_size
        
        self.embedding_w = torch.nn.Parameter(torch.Tensor(self.embedding_size, self.vocabulary_size))
        torch.nn.init.uniform_(self.embedding_w, a=-0.1, b=0.1)
        self.embedding_b = torch.nn.Parameter(torch.Tensor(1, self.embedding_size))
        self.embedding_b.data.fill_(0)
        self.embedding = nn.Embedding(self.vocabulary_size, self.embedding_size)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.linear = nn.Linear(self.embedding_demo_size, self.hidden_size)
        self.probits = nn.Linear(self.hidden_size, self.vocabulary_size)

        self.bce_loss = nn.BCEWithLogitsLoss()


    def embedding(self, x):
        return F.linear(x, self.embedding_w, self.embedding_b)

    def forward(self, x, d=torch.Tensor([])):
        x = self.embedding(x)
        x = self.relu1(x)
        emb = F.relu(self.embedding_w)

        if (self.demographics_size):
            x = torch.cat((x, d), dim=1)
        x = self.linear(x)
        x = self.relu2(x)
        probits = self.probits(x)
        return probits, emb


In [113]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def loss(inputs, mask, probits, bce_loss, emb_w, ivec, jvec, window=1, eps=1.0e-8):
    """ returns the med2vec loss
    """
    def visit_loss(x, mask, probits, window=1):
        loss = 0
        for i in range(0, window):
            if (i == 0):
                maski = mask[i + 1:] * mask[:-i - 1]
            else:
                maski = mask[i + 1:] * mask[1:-i] * mask[:-i - 1]
            backward_preds = probits[i+1:] * maski
            forward_preds = probits[:-i-1] * maski
            loss += bce_loss(forward_preds, x[i+1:].float()) + bce_loss(backward_preds, x[:-i-1].float())
        return loss

    def code_loss(emb_w, ivec, jvec, eps=1.e-6):
        norm = torch.sum(torch.exp(torch.mm(emb_w.t(), emb_w)), dim=1)

        cost = -torch.log((torch.exp(torch.sum(emb_w[:, ivec].t() * emb_w[:, jvec].t(), dim=1)) / norm[ivec]) + eps)
        cost = torch.mean(cost)
        return cost

    vl = visit_loss(inputs, mask, probits, window=window)
    cl = code_loss(emb_w, ivec, jvec, eps=1.e-6)
    return {'visit_loss': vl, 'code_loss': cl}


In [114]:
config['model']

{'type': 'Med2Vec',
 'module_name': 'med2vec',
 'args': {'icd9_size': 4876, 'embedding_size': 256, 'hidden_size': 512}}

In [115]:
def get_instance(module, name, config, *args):
    return getattr(module, config[name]['type'])(*args, **config[name]['args'])

In [116]:
model = Med2Vec(**config['model']['args'])

In [117]:
model

Med2Vec(
  (embedding): Embedding(4876, 256)
  (relu1): ReLU()
  (relu2): ReLU()
  (linear): Linear(in_features=256, out_features=512, bias=True)
  (probits): Linear(in_features=512, out_features=4876, bias=True)
  (bce_loss): BCEWithLogitsLoss()
)

In [118]:
model

Med2Vec(
  (embedding): Embedding(4876, 256)
  (relu1): ReLU()
  (relu2): ReLU()
  (linear): Linear(in_features=256, out_features=512, bias=True)
  (probits): Linear(in_features=512, out_features=4876, bias=True)
  (bce_loss): BCEWithLogitsLoss()
)

In [119]:
# model = get_instance(module_arch, 'model', config)
# print(model)

In [120]:
config['loss_window']

5

In [121]:
for batch_idx, (x, ivec, jvec, mask, d) in enumerate(data_loader):
    data, ivec, jvec, mask, d  = x, ivec, jvec, mask, d
    # self.optimizer.zero_grad()
    probits, emb_w = model(data.float(), d)
    print(probits)
    print(emb_w)
    loss_dict = loss(data, mask.float(), probits, model.bce_loss, emb_w, ivec, jvec, window= config["loss_window"])
    # loss = loss_dict['visit_loss'] + loss_dict['code_loss']
    # loss.backward()
    # self.optimizer.step()
    break

tensor([[-0.0036, -0.0022,  0.0182,  ..., -0.0220, -0.0208,  0.0332],
        [-0.0024,  0.0133,  0.0161,  ...,  0.0050, -0.0167,  0.0200]],
       grad_fn=<AddmmBackward0>)
tensor([[0.0590, 0.0979, 0.0000,  ..., 0.0388, 0.0000, 0.0427],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0584, 0.0894],
        [0.0732, 0.0585, 0.0000,  ..., 0.0244, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0479, 0.0115,  ..., 0.0000, 0.0000, 0.0708],
        [0.0178, 0.0000, 0.0182,  ..., 0.0000, 0.0000, 0.0557],
        [0.0000, 0.0000, 0.0148,  ..., 0.0000, 0.0915, 0.0000]],
       grad_fn=<ReluBackward0>)


In [None]:
loss_

In [122]:
emb_w.shape

torch.Size([256, 4876])

In [123]:
probits.shape

torch.Size([2, 4876])

In [124]:
total_loss = 0
total_metrics = np.zeros(len(self.metrics))
for batch_idx, (x, ivec, jvec, mask, d) in enumerate(self.data_loader):
    data, ivec, jvec, mask, d  = x.to(self.device), ivec.to(self.device), jvec.to(self.device), mask.to(self.device), d.to(self.device)
    self.optimizer.zero_grad()
    probits, emb_w = self.model(data.float(), d)
    loss_dict = self.loss(data, mask.float(), probits, self.model.bce_loss, emb_w, ivec, jvec, window=self.config["loss_window"])
    loss = loss_dict['visit_loss'] + loss_dict['code_loss']
    loss.backward()
    self.optimizer.step()

NameError: name 'self' is not defined

In [None]:
# build model architecture
    model = get_instance(module_arch, 'model', config)
    print(model)

    # get function handles of loss and metrics
    loss = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = get_instance(torch.optim, 'optimizer',
                             config, trainable_params)
    lr_scheduler = get_instance(
        torch.optim.lr_scheduler, 'lr_scheduler', config, optimizer)
    print(f"Length of the Data Set: {len(data_loader)}")

    trainer = Trainer(model, loss, metrics, optimizer,
                      resume=resume,
                      config=config,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=lr_scheduler,
                      train_logger=train_logger)

    trainer.train()
