In [1]:
from comet_ml import Experiment
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from core.model import SAnD
from data.mimiciii import get_mimic_iii
from utils.trainer import NeuralNetworkClassifier
from pyhealth.tasks import mortality_prediction_mimic3_fn
from pyhealth.tokenizer import Tokenizer

In [2]:
def collate_fn(data):
    sequences, labels = zip(*data)
    y = torch.tensor(labels, dtype=torch.long)
    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    max_num_visits = max(num_visits)

    x = torch.zeros((num_patients, max_num_visits, len(freq_codes)), dtype=torch.float)
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            for code in visit:
                """
                TODO: 1. check if code is in freq_codes;
                      2. obtain the code index using code2idx;
                      3. set the correspoindg element in x to 1.
                """
                if code in freq_codes:
                    x[i_patient, j_visit, code2idx[code]] = 1

                y[i_patient] = labels[i_patient]

    masks = torch.sum(x, dim=-1) > 0

    return x, masks, y

In [3]:
def get_dataloader(dataset, batch_size, shuffle=False):
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn
    )
    return dataloader

In [4]:
def split_data(data, train: float, val: float, test: float):
    err = 1e-5
    if 1 - err < (train + val + test) < 1 + err == False:
        raise Exception(f"{train=} + {val=} + {test=} = {train+val+test}. Needs to be 1.")
    length = len(data)
    end_train = int(len(data) * train)
    end_val = int(len(data) * val) + end_train

    return data[:end_train], data[end_train:end_val], data[end_val:]

In [5]:
tokenizers = {}

def tokenizer_helper(sample, key: str) -> np.array:
    if key not in tokenizers:
        alls =  {s for l in [sample[key][0] for sample in dataset.samples] for s in l}
        tokenizers[key] = Tokenizer(list(alls))
    tokenizer = tokenizers[key]
    items = sample[key][0]
    item_table = np.zeros(shape=(tokenizer.get_vocabulary_size()))
    item_indicies = tokenizer.convert_tokens_to_indices(items)
    item_table[item_indicies] = True
    return item_table

In [18]:
dataset = get_mimic_iii().set_task(mortality_prediction_mimic3_fn)
sample = dataset.samples[0]
(tokenizer_helper(sample, "drugs"), tokenizer_helper(sample, "conditions"), tokenizer_helper(sample, "procedures"))
n_tokens = sum(v.get_vocabulary_size() for v in tokenizers.values())

Generating samples for mortality_prediction_mimic3_fn: 100%|██████████| 46520/46520 [00:00<00:00, 121588.35it/s]


28

In [31]:
"""
In Hospital Mortality: Mortality prediction is vital during rapid triage and risk/severity assessment. In Hospital
Mortality is defined as the outcome of whether a patient dies
during the period of hospital admission or lives to be discharged. This problem is posed as a binary classification one
where each data sample spans a 24-hour time window. True
mortality labels were curated by comparing date of death
(DOD) with hospital admission and discharge times. The
mortality rate within the benchmark cohort is only 13%.
"""
in_feature = n_tokens
n_heads = 32
factor = 32
num_class = 2
num_layers = 6

patients = len(dataset.patient_to_index)
seq_len = max(len(v) for v in dataset.patient_to_index.values())
new_dataset =torch.zeros((patients, seq_len, n_tokens,))
new_labels = torch.zeros((patients,))

i = 0
for p_id, visits in dataset.patient_to_index.items():
    for n_visit, sample_idx in enumerate(visits):
        sample = dataset.samples[sample_idx]
        sample_data = np.concatenate((tokenizer_helper(sample, "drugs"), tokenizer_helper(sample, "conditions"), tokenizer_helper(sample, "procedures")))
        new_dataset[i][n_visit] = torch.tensor(sample_data)
        new_labels[i] = max(new_labels[i], sample["label"])
    i += 1
# create dataloaders (they are <torch.data.DataLoader> object)
train_data, val_data, test_data = split_data(new_dataset, 0.8, .1, .1)
train_labels, val_labels, test_labels = split_data(new_labels, 0.8, .1, .1)

train_dataset = TensorDataset(train_data, train_labels)
val_dataset = TensorDataset(val_data, val_labels)
test_dataset = TensorDataset(test_data, test_labels)

train_loader = get_dataloader(train_dataset, batch_size=seq_len, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=seq_len, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=seq_len, shuffle=False)


In [32]:
clf = NeuralNetworkClassifier(
    SAnD(in_feature, seq_len, n_heads, factor, num_class, num_layers),
    nn.CrossEntropyLoss(),
    optim.Adam, optimizer_config={
        "lr": 1e-5, "betas": (0.9, 0.98), "eps": 4e-09, "weight_decay": 5e-4},
    experiment=Experiment(
        api_key="eQ3INeSsFGUYKahSdEtjhry42",
        project_name="general",
        workspace="samdoud"
    )
)

# training network
clf.fit(
    {
        "train": train_loader,
        "val": val_loader
    },
    epochs=1
)

# evaluating
clf.evaluate(test_loader)

# save
clf.save_to_file("./save_params/")

COMET INFO: Experiment is live on comet.com https://www.comet.com/samdoud/general/84247acd0256413d80b0637f72131f73

  0%|          | 0/4948 [00:00<?, ?it/s]

TypeError: only integer tensors of a single element can be converted to an index