In [3]:
from pyhealth.datasets import MIMIC3Dataset
import random
import torch
from torch.utils.data import Dataset
import numpy as np
import copy

## Extract Data

This file gets the data from the MIMIC-III database and formats it to be used by our Pretraining method and Fine Tuning method.     

**Single Visit Task**: pyhealth task that finds all the patients with only one visit and makes an object for that visit of the form:
```
{
    "visit_id": id of given visit,
    "patient_id": id of patient,
    "conditions": list of ICD9 codes indicating the conditions recorded in the visit,
    "drugs": list of ATC codes indicating the drugs prescribed in the visit
}
```
reference:
<br/>
1. https://pyhealth.readthedocs.io/en/latest/_modules/pyhealth/tasks/drug_recommendation.html#drug_recommendation_mimic3_fn

In [4]:
def single_visit_fn(patient):
    """Return the sample for a patient that has exactly one usable visit.

    A visit is usable when it has at least one code in both the
    ``DIAGNOSES_ICD`` and the ``PRESCRIPTIONS`` table. Every drug code is
    shortened to its first four characters.

    Args:
        patient: object that supports ``len()``, integer indexing (yielding
            visits) and exposes ``patient_id``. Each visit must provide
            ``visit_id`` and ``get_code_list(table=...)``.

    Returns:
        A one-element list holding a dict with the keys ``visit_id``,
        ``patient_id``, ``conditions`` and ``drugs``; an empty list when the
        number of usable visits is anything other than one.
    """
    usable_visits = []
    for index in range(len(patient)):
        current = patient[index]
        diagnosis_codes = current.get_code_list(table="DIAGNOSES_ICD")
        raw_drug_codes = current.get_code_list(table="PRESCRIPTIONS")
        drug_codes = [code[:4] for code in raw_drug_codes]

        # A visit lacking either kind of code cannot form a sample.
        if len(diagnosis_codes) == 0 or len(drug_codes) == 0:
            continue

        usable_visits.append(dict(
            visit_id=current.visit_id,
            patient_id=patient.patient_id,
            conditions=diagnosis_codes,
            drugs=drug_codes,
        ))

    return usable_visits if len(usable_visits) == 1 else []

**Multi Visit Task**: pyhealth task that finds all the patients with multiple visits and makes an object for each of their visits of the form:
```
{
    "visit_id": id of given visit,
    "patient_id": id of patient,
    "conditions": list of ICD9 codes indicating the conditions recorded in the visit,
    "drugs": list of ATC codes indicating the drugs prescribed in the visit
}
```
reference:
<br/>
1. https://pyhealth.readthedocs.io/en/latest/_modules/pyhealth/tasks/drug_recommendation.html#drug_recommendation_mimic3_fn

In [5]:
def multi_visit_fn(patient):
    """Return one sample per usable visit for patients with several visits.

    A visit is usable when it has at least one code in both the
    ``DIAGNOSES_ICD`` and the ``PRESCRIPTIONS`` table.

    Args:
        patient: object that supports ``len()``, integer indexing (yielding
            visits) and exposes ``patient_id``. Each visit must provide
            ``visit_id`` and ``get_code_list(table=...)``.

    Returns:
        A list of dicts with the keys ``visit_id``, ``patient_id``,
        ``conditions`` and ``drugs`` in visit order; an empty list when
        fewer than two visits are usable.
    """
    collected = []
    for position in range(len(patient)):
        encounter = patient[position]
        icd_codes = encounter.get_code_list(table="DIAGNOSES_ICD")
        # ATC level 3: keep only the first four characters of each code.
        atc_codes = [
            code[:4]
            for code in encounter.get_code_list(table="PRESCRIPTIONS")
        ]

        # Keep the visit only when it has both a condition and a drug code.
        if len(icd_codes) != 0 and len(atc_codes) != 0:
            sample = {}
            sample["visit_id"] = encounter.visit_id
            sample["patient_id"] = patient.patient_id
            sample["conditions"] = icd_codes
            sample["drugs"] = atc_codes
            collected.append(sample)

    # Patients with fewer than two usable visits are dropped entirely.
    if len(collected) >= 2:
        return collected
    return []

**Extract Data**: Extracts the data from MIMIC-III tables using the Pyhealth MIMIC3Dataset. Returns a data object of the following: 
```
{
    "single_visit_patients": a list of visits for patients with only one visit,
    "multi_visit_patients": a list of visits for patients with multiple visits,
    "all_drugs": a list of all the unique ATC codes in the data,
    "multi_visit_drugs": a list of all the unique ATC codes in only the multi_visit_patients data,
    "all_conditions": a list of all the unique ICD9 codes in the data,
    "multi_visit_conditions": list of all unique ICD9 codes in only the multi_visit_patients data,
    "vocab": a list of both ATC and ICD9 codes in data as well as the special codes: [PAD], [CLS], [MASK] 
}
```

In [6]:
def extract_data():
    """Load MIMIC-III through pyhealth and collect visits and code lists.

    Runs ``single_visit_fn`` and ``multi_visit_fn`` over the dataset and
    gathers the unique drug and condition codes they produce.

    Returns:
        dict with the keys
            "single_visit_patients": list of single-visit samples,
            "multi_visit_patients": list of multi-visit samples,
            "all_drugs": sorted unique drug codes over both sample lists,
            "multi_visit_drugs": sorted unique drug codes of the multi-visit
                samples only,
            "all_conditions": sorted unique condition codes over both lists,
            "multi_visit_conditions": sorted unique condition codes of the
                multi-visit samples only,
            "vocab": all drug codes, then all condition codes, then the
                special tokens "[PAD]", "[CLS]" and "[MASK]".
    """
    dataset = MIMIC3Dataset(
        root="",
        tables=["DIAGNOSES_ICD", "PRESCRIPTIONS"],
        code_mapping={"NDC": "ATC"},
        # so the run time is shorter
        dev=True,
    )

    # Materialise each task's samples once instead of re-iterating the
    # task datasets for every code list built below.
    single_visit_patients = list(dataset.set_task(task_fn=single_visit_fn))
    multi_visit_patients = list(dataset.set_task(task_fn=multi_visit_fn))

    def unique_codes(visits, key):
        """Return the set of codes stored under ``key`` across ``visits``."""
        return {code for visit in visits for code in visit[key]}

    multi_visit_drugs = unique_codes(multi_visit_patients, "drugs")
    all_drugs = unique_codes(single_visit_patients, "drugs") | multi_visit_drugs

    multi_visit_conditions = unique_codes(multi_visit_patients, "conditions")
    all_conditions = (
        unique_codes(single_visit_patients, "conditions") | multi_visit_conditions
    )

    # A set of strings iterates in hash order, which changes between
    # interpreter runs. The positions in these lists are used as token ids
    # and label indices, so sort them to keep the ids stable across runs.
    all_drugs = sorted(all_drugs)
    multi_visit_drugs = sorted(multi_visit_drugs)
    all_conditions = sorted(all_conditions)
    multi_visit_conditions = sorted(multi_visit_conditions)

    vocab = all_drugs + all_conditions + ["[PAD]", "[CLS]", "[MASK]"]

    return {
        "single_visit_patients": single_visit_patients,
        "multi_visit_patients": multi_visit_patients,
        "all_drugs": all_drugs,
        "multi_visit_drugs": multi_visit_drugs,
        "all_conditions": all_conditions,
        "multi_visit_conditions": multi_visit_conditions,
        "vocab": vocab,
    }
    

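**Split data**: Splits the ids of the patients with multiple visits into 3 lists for training, evaluating, and testing the model. Shang et al. divided the dataset using a 0.6 : 0.2 : 0.2 ratio respectively and we aimed to do the same. Note that the code below places its cut points at 2/3 and 5/6 of the shuffled ids, i.e. roughly 0.67 : 0.17 : 0.17, and returns the lists in the order train, test, eval.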

In [7]:
def split_data(data):
    """Randomly partition the multi-visit patient ids into three groups.

    The unique patient ids found in ``data["multi_visit_patients"]`` are
    shuffled with the ``random`` module. The first two thirds become the
    training ids, the part up to five sixths the evaluation ids, and the
    remainder the test ids.

    Args:
        data: dict whose "multi_visit_patients" entry is an iterable of
            dicts that each carry a "patient_id".

    Returns:
        Tuple ``(train_ids, test_ids, eval_ids)`` of lists.
    """
    unique_ids = list({visit['patient_id'] for visit in data["multi_visit_patients"]})
    total = len(unique_ids)

    # Shuffling the ids themselves applies the same permutation as
    # shuffling a list of their indices would.
    shuffled_ids = list(unique_ids)
    random.shuffle(shuffled_ids)

    train_end = int(total * 2 / 3)
    eval_end = int(total * 5 / 6)

    train_ids = shuffled_ids[:train_end]
    eval_ids = shuffled_ids[train_end:eval_end]
    test_ids = shuffled_ids[eval_end:]

    return train_ids, test_ids, eval_ids
    

**Padding and Truncating the Data**: when getting a sample from the dataset, each list of drug or condition codes is brought to a fixed length so that, together with a leading \[CLS\] token, every sequence is 55 tokens long (54 code slots). If the list is shorter than this, the token \[PAD\] is appended until it reaches the desired length. If the list is longer, it is truncated to the desired length.

In [8]:
def pad_or_truncate_data(data, max_len):
    """Return a copy of ``data`` that is exactly ``max_len`` items long.

    Shorter inputs are extended with ``'[PAD]'`` tokens and longer inputs
    are cut off after ``max_len`` items. The input is never modified, so it
    is safe to pass a list that is still referenced elsewhere.

    Args:
        data: sequence of tokens.
        max_len: target length.

    Returns:
        A new list holding the padded or truncated tokens.
    """
    trimmed = list(data[:max_len])
    # A non-positive repeat count yields an empty list, so nothing is added
    # when the input already fills max_len.
    return trimmed + ['[PAD]'] * (max_len - len(trimmed))

**Format Data for G-BERT**: the data in the G-BERT dataset was formatted in the following way:
```
data = { patient_id: [[[list of ICD9 codes for first visit], [list of ATC codes for first visit]], 
                      [[list of ICD9 codes for second visit], [list of ATC codes for second visit]],
                      ...
                      [[list of ICD9 codes for last visit], [list of ATC codes for last visit]]], ... }
```

In [9]:
def format_data_for_GBERT(data):
    """Group visit samples by patient.

    Args:
        data: iterable of dicts with the keys "patient_id", "conditions"
            and "drugs".

    Returns:
        dict mapping each patient id to a list of ``[conditions, drugs]``
        pairs, in the order the visits appear in ``data``.
    """
    visits_by_patient = {}
    for sample in data:
        pair = [sample["conditions"], sample["drugs"]]
        # setdefault creates the patient's list on first sight.
        visits_by_patient.setdefault(sample["patient_id"], []).append(pair)
    return visits_by_patient

**EHR Dataset for G-BERT**: This is the dataset used for the Dataloaders when fine tuning the model.  Shang et al's code on github was used as a reference for this dataset. 
<br/>
reference:
1. https://github.com/jshang123/G-Bert/tree/f5375265ecad5724c273712e13f3afa0e6a0f932

In [17]:
class EHRDatasetForGBERT(Dataset):
    """Fine-tuning dataset that yields one patient (all visits) per item.

    Each item is a tuple of three tensors:
        * token ids, viewed as rows of ``max_seq_len`` ids; every visit
          contributes one row for its conditions and one for its drugs,
          each starting with ``[CLS]``;
        * multi-hot condition labels, one row per visit after the first;
        * multi-hot drug labels, one row per visit after the first.

    The lookup tables are built once in ``__init__``; ``visits`` and the
    code lists in ``data`` are expected not to change afterwards.
    """

    def __init__(self, visits, data, max_seq_len):
        """Store the visits and precompute the code lookup tables.

        Args:
            visits: dict mapping a patient id to that patient's list of
                ``[conditions, drugs]`` pairs.
            data: dict providing the lists "multi_visit_conditions",
                "multi_visit_drugs" and "vocab".
            max_seq_len: length of every token row, ``[CLS]`` included.
        """
        self.data = visits
        self.multi_visit_conditions = data["multi_visit_conditions"]
        self.multi_visit_drugs = data["multi_visit_drugs"]
        self.vocab = data["vocab"]
        self.seq_len = max_seq_len

        # Snapshot the patient order so an integer index does not have to
        # rebuild the key list on every access.
        self._subject_ids = list(visits.keys())
        # Dict lookups replace the linear list.index scans per token.
        self._token_to_id = self._first_positions(self.vocab)
        self._condition_to_label = self._first_positions(self.multi_visit_conditions)
        self._drug_to_label = self._first_positions(self.multi_visit_drugs)

    @staticmethod
    def _first_positions(codes):
        """Map each code to the index of its first occurrence in ``codes``."""
        positions = {}
        for index, code in enumerate(codes):
            # setdefault keeps the first index, matching list.index.
            positions.setdefault(code, index)
        return positions

    @staticmethod
    def _indices(lookup, codes):
        """Translate ``codes`` through ``lookup``.

        Raises:
            ValueError: if a code is unknown, as ``list.index`` would.
        """
        try:
            return [lookup[code] for code in codes]
        except KeyError as missing:
            raise ValueError(f"{missing.args[0]!r} is not in list") from None

    def __len__(self):
        """Return the number of patients."""
        return len(self.data)

    def __getitem__(self, item):
        """Return ``(input_ids, condition_labels, drug_labels)`` for a patient.

        Raises:
            ValueError: if a code is missing from the vocabulary or from
                the corresponding multi-visit code list.
        """
        admissions = self.data[self._subject_ids[item]]

        input_ids = []  # 2 * seq_len ids per admission
        for adm in admissions:
            for codes in (adm[0], adm[1]):
                # Copy before padding so the stored visit is never altered.
                tokens = ['[CLS]'] + pad_or_truncate_data(
                    list(codes), self.seq_len - 1)
                input_ids.extend(self._indices(self._token_to_id, tokens))

        # The first admission only serves as history, so labels start at
        # the second one. Labels use the full, untruncated code lists.
        targets = admissions[1:]
        condition_labels = np.zeros(
            (len(targets), len(self.multi_visit_conditions)))
        drug_labels = np.zeros((len(targets), len(self.multi_visit_drugs)))
        for row, adm in enumerate(targets):
            condition_labels[
                row, self._indices(self._condition_to_label, adm[0])] = 1
            drug_labels[row, self._indices(self._drug_to_label, adm[1])] = 1

        return (
            torch.tensor(input_ids, dtype=torch.long).view(-1, self.seq_len),
            torch.tensor(condition_labels, dtype=torch.float),
            torch.tensor(drug_labels, dtype=torch.float),
        )

**Get Data for G-BERT**: this code separates the data into groups based on the patient_id and initializes the training, testing, and validating datasets.

In [11]:
def get_data_for_GBERT(data, train_ids, test_ids, eval_ids, max_seq_len=55):
    """Build the train, test and eval fine-tuning datasets.

    Every visit in ``data["multi_visit_patients"]`` is routed to the split
    (or splits) whose id collection contains its patient id, grouped by
    patient and wrapped in an ``EHRDatasetForGBERT``.

    Args:
        data: dict as returned by ``extract_data``.
        train_ids: patient ids for the training split.
        test_ids: patient ids for the test split.
        eval_ids: patient ids for the evaluation split.
        max_seq_len: token row length handed to the datasets (default 55).

    Returns:
        Tuple ``(train_data, test_data, eval_data)`` of datasets.
    """
    # Sets give constant-time membership tests; scanning the id lists for
    # every visit would be quadratic in the number of patients.
    train_ids = set(train_ids)
    test_ids = set(test_ids)
    eval_ids = set(eval_ids)

    train_visits = []
    test_visits = []
    eval_visits = []
    for visit in data["multi_visit_patients"]:
        patient_id = visit["patient_id"]
        # Independent checks: an id listed in several splits lands in each.
        if patient_id in train_ids:
            train_visits.append(visit)
        if patient_id in test_ids:
            test_visits.append(visit)
        if patient_id in eval_ids:
            eval_visits.append(visit)

    train_data = EHRDatasetForGBERT(
        format_data_for_GBERT(train_visits), data, max_seq_len)
    test_data = EHRDatasetForGBERT(
        format_data_for_GBERT(test_visits), data, max_seq_len)
    eval_data = EHRDatasetForGBERT(
        format_data_for_GBERT(eval_visits), data, max_seq_len)

    return train_data, test_data, eval_data

**Format data for Pre-Training**: data was formatted for pretraining in the following way:
```
    data = [[[list of ICD9 codes for first visit], [list of ATC codes for first visit]], 
            [[list of ICD9 codes for next visit], [list of ATC codes for next visit]],
            ...
            [[list of ICD9 codes for last visit], [list of ATC codes for last visit]]]

```

In [12]:
def format_data_for_pretraining(data):
    """Reduce visit samples to ``[conditions, drugs]`` pairs.

    Args:
        data: iterable of dicts with the keys "conditions" and "drugs".

    Returns:
        List with one ``[conditions, drugs]`` pair per sample, in input
        order. The inner lists are the samples' own lists, not copies.
    """
    return [[sample["conditions"], sample["drugs"]] for sample in data]

**Random Word Masking**: in the pretraining step of G-BERT codes are masked at random with a 15% probability.  This code replaces those tokens with a \[MASK\] token. 

In [13]:
def random_word(tokens, vocab):
    """Mask tokens in place, each with 15% probability.

    One random number is drawn per token; when it falls below 0.15 the
    token is overwritten with "[MASK]".

    Args:
        tokens: mutable sequence of tokens; modified in place.
        vocab: not used by this function.

    Returns:
        The same ``tokens`` object that was passed in.
    """
    for position in range(len(tokens)):
        if random.random() < 0.15:
            tokens[position] = "[MASK]"
    return tokens

**EHR Dataset for Pretraining**: This is the dataset used for the Dataloaders when pretraining the model.  Shang et al's code on github was used as a reference for this dataset. 
<br/>
reference:
1. https://github.com/jshang123/G-Bert/tree/f5375265ecad5724c273712e13f3afa0e6a0f932

In [18]:
class EHRDatasetForPretraining(Dataset):
    """Pretraining dataset that yields one visit per item.

    Each item is a tuple of three tensors:
        * token ids viewed as rows of ``max_seq_len`` ids: one row for the
          randomly masked conditions and one for the randomly masked
          drugs, each starting with ``[CLS]``;
        * a multi-hot vector over ``all_conditions``;
        * a multi-hot vector over ``all_drugs``.

    The lookup tables are built once in ``__init__``; the code lists in
    ``data`` are expected not to change afterwards.
    """

    def __init__(self, visits, data, max_seq_len):
        """Store the visits and precompute the code lookup tables.

        Args:
            visits: sequence of ``[conditions, drugs]`` pairs.
            data: dict providing the lists "all_conditions", "all_drugs"
                and "vocab".
            max_seq_len: length of every token row, ``[CLS]`` included.
        """
        self.data = visits
        self.all_conditions = data["all_conditions"]
        self.all_drugs = data["all_drugs"]
        self.seq_len = max_seq_len
        self.vocab = data["vocab"]

        # Dict lookups replace the linear list.index scans per code.
        self._token_to_id = self._first_positions(self.vocab)
        self._condition_to_label = self._first_positions(self.all_conditions)
        self._drug_to_label = self._first_positions(self.all_drugs)

    @staticmethod
    def _first_positions(codes):
        """Map each code to the index of its first occurrence in ``codes``."""
        positions = {}
        for index, code in enumerate(codes):
            # setdefault keeps the first index, matching list.index.
            positions.setdefault(code, index)
        return positions

    @staticmethod
    def _indices(lookup, codes):
        """Translate ``codes`` through ``lookup``.

        Raises:
            ValueError: if a code is unknown, as ``list.index`` would.
        """
        try:
            return [lookup[code] for code in codes]
        except KeyError as missing:
            raise ValueError(f"{missing.args[0]!r} is not in list") from None

    def __len__(self):
        """Return the number of visits."""
        return len(self.data)

    def __getitem__(self, item):
        """Return ``(input_ids, y_conditions, y_drugs)`` for one visit.

        Raises:
            ValueError: if a code is missing from the vocabulary or from
                the corresponding code list.
        """
        # Work on copies: random_word masks its argument in place and the
        # stored visit must stay intact for later epochs.
        conditions = list(self.data[item][0])
        drugs = list(self.data[item][1])

        # Labels come from the unmasked, untruncated code lists.
        y_conditions = np.zeros(len(self.all_conditions))
        y_conditions[self._indices(self._condition_to_label, conditions)] = 1
        y_drugs = np.zeros(len(self.all_drugs))
        y_drugs[self._indices(self._drug_to_label, drugs)] = 1

        # Each code list is paired with its own vocabulary.
        conditions = random_word(conditions, self.all_conditions)
        drugs = random_word(drugs, self.all_drugs)

        input_ids = []  # 2 * seq_len ids
        for codes in (conditions, drugs):
            tokens = ['[CLS]'] + pad_or_truncate_data(
                list(codes), self.seq_len - 1)
            input_ids.extend(self._indices(self._token_to_id, tokens))

        return (
            torch.tensor(input_ids, dtype=torch.long).view(-1, self.seq_len),
            torch.tensor(y_conditions, dtype=torch.float),
            torch.tensor(y_drugs, dtype=torch.float),
        )

**Get Data for Pre-Training**: this code separates the data into groups based on the patient_id and initializes the training and validating datasets.

In [15]:
def get_data_for_pretraining(data, train_ids, eval_ids, max_seq_len=55):
    """Build the train and eval pretraining datasets.

    Multi-visit visits are routed to the split (or splits) whose id
    collection contains their patient id. All single-visit samples are
    appended to the training split only.

    Args:
        data: dict as returned by ``extract_data``.
        train_ids: patient ids for the training split.
        eval_ids: patient ids for the evaluation split.
        max_seq_len: token row length handed to the datasets (default 55).

    Returns:
        Tuple ``(train_data, eval_data)`` of ``EHRDatasetForPretraining``.
    """
    # Sets give constant-time membership tests; scanning the id lists for
    # every visit would be quadratic in the number of patients.
    train_ids = set(train_ids)
    eval_ids = set(eval_ids)

    train_visits = []
    eval_visits = []
    for visit in data["multi_visit_patients"]:
        patient_id = visit["patient_id"]
        # Independent checks: an id listed in both splits lands in both.
        if patient_id in train_ids:
            train_visits.append(visit)
        if patient_id in eval_ids:
            eval_visits.append(visit)

    train_visits.extend(data["single_visit_patients"])

    train_data = EHRDatasetForPretraining(
        format_data_for_pretraining(train_visits), data, max_seq_len)
    eval_data = EHRDatasetForPretraining(
        format_data_for_pretraining(eval_visits), data, max_seq_len)

    return train_data, eval_data