# PyHealth GAMENet Reproduction Study

## Imports

First, import the Python libraries, PyHealth, and related PyHealth modules.

In [1]:
# import python libraries
import argparse
import sys
import pandas
import json
import math
from typing import Tuple, List, Dict, Optional
# import torch
import torch
import torch.nn as nn
# import pyhealth libraries
import pyhealth
from pyhealth.datasets import MIMIC4Dataset, MIMIC3Dataset
from pyhealth.tasks import drug_recommendation_mimic4_fn, drug_recommendation_mimic3_fn
# import dataloader related functions
from pyhealth.datasets import split_by_patient, get_dataloader
# import gamenet model
from pyhealth.models import BaseModel, GAMENetLayer, RETAIN, GAMENet
# import trainer
from pyhealth.trainer import Trainer
from pyhealth.data import Patient, Visit

from pyhealth.datasets import SampleDataset
from pyhealth.medcode import ATC
from pyhealth.models.gamenet import GCN, GCNLayer
from pyhealth.models.utils import get_last_visit, batch_to_multihot



Next, import the custom modules we designed for this study.

In [2]:
# import our custom wrapper classes
from model import ModelWrapper
from mimic import MIMIC4, MIMICWrapper

# import our constants
from constants import (
    #DEV,
    #EPOCHS, LR, DECAY_WEIGHT,
    DRUG_REC_TN, NO_HIST_TN, NO_PROC_TN,#ALL_TASKS,
    GN_KEY, RT_KEY,
    #MODEL_TYPES_PER_TASK, RETAIN_FEATS_PER_TASK,
    GAMENET_EXP, RETAIN_EXP,
    SCORE_KEY, DPV_KEY, DDI_RATE_KEY,
    BASE_DDI_RATE
)

## Constants

Next, set up the constants, including the training hyperparameters.

In [3]:
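# whether to load the smaller "dev" subset of the data
DEV = True
# number of training epochs
EPOCHS = 5
# learning rate
LR = 1e-3
# weight decay (L2 regularization) strength
DECAY_WEIGHT = 1e-5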

## Methodology

GAMENet is implemented in [the `pyhealth` library](https://pyhealth.readthedocs.io/en/latest/api/models/pyhealth.models.GAMENet.html).
RETAIN, which serves as the baseline, is implemented there as well.
This makes it straightforward to reproduce parts of the original paper.
I used the PyHealth implementations of GAMENet and RETAIN, and implemented variants of GAMENet to support different data preparation tasks.
Specifically, I performed two ablations.

The first ablation was to remove the patient's history information.
The data preparation task did not build a history of the patient's procedures, conditions, or prescriptions.
Next, within the GAMENet model, I removed the Dynamic Memory component of the GAMENet layer.
This removed any possibility of using a patient's prior drugs to make recommendations.
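
The following toy sketch shows how one patient's samples differ between the full task and the no-history task. `build_samples` and the toy visits are hypothetical; the real task functions read codes from PyHealth `Visit` objects and also filter out visits and patients.

```python
from typing import Dict, List


def build_samples(visits: List[Dict[str, List[str]]], keep_history: bool) -> List[Dict]:
    """Hypothetical helper: wrap each visit's codes in a visit-level list,
    optionally prepending all earlier visits (the full-history behavior)."""
    samples = []
    for i, visit in enumerate(visits):
        if keep_history and i > 0:
            # full task: accumulate every earlier visit's codes
            conditions = samples[i - 1]["conditions"] + [visit["conditions"]]
            drugs_all = samples[i - 1]["drugs_all"] + [visit["drugs"]]
        else:
            # no-history ablation: only the current visit is visible
            conditions = [visit["conditions"]]
            drugs_all = [visit["drugs"]]
        samples.append({"conditions": conditions, "drugs_all": drugs_all})
    return samples


toy_visits = [
    {"conditions": ["c1"], "drugs": ["A01A"]},
    {"conditions": ["c2"], "drugs": ["B01A"]},
]
with_hist = build_samples(toy_visits, keep_history=True)
no_hist = build_samples(toy_visits, keep_history=False)
print(with_hist[1]["conditions"])  # [['c1'], ['c2']]
print(no_hist[1]["conditions"])    # [['c2']]
```

In both cases the inputs stay three levels deep ([visit][code] per sample), so the same models can consume them; only the number of visits per sample changes.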

The second ablation removed patient procedure information.
This involved simply omitting procedure information in the data preparation task.
Then, I modified GAMENet so that it did not intake procedures in its `forward` function.

### History Ablation

#### No Hist Drug Recommendation Task

First, we prepare the data without accounting for patient history:

In [4]:
def drug_recommendation_mimic4_no_hist(patient: Patient):
    samples = []
    for i in range(len(patient)):
        visit: Visit = patient[i]
        conditions = visit.get_code_list(table="diagnoses_icd")
        procedures = visit.get_code_list(table="procedures_icd")
        drugs = visit.get_code_list(table="prescriptions")
        # ATC 3 level
        drugs = [drug[:4] for drug in drugs]
        # exclude: visits without condition, procedure, or drug code
        if len(conditions) * len(procedures) * len(drugs) == 0:
            continue
        # TODO: should also exclude visit with age < 18
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                "conditions": conditions,
                "procedures": procedures,
                "drugs": drugs,
                "drugs_all": drugs,
            }
        )
    # exclude: patients with fewer than 2 visits
    if len(samples) < 2:
        return []
    # don't add history; just wrap each visit's codes in a single-visit list
    for sample in samples:
        sample["conditions"] = [sample["conditions"]]
        sample["procedures"] = [sample["procedures"]]
        sample["drugs_all"] = [sample["drugs_all"]]

    return samples

#### GAMENet without Hist

The RETAIN model did not need to be modified to handle this ablation.
GAMENet, however, did.
I needed to define a variant of the GAMENetLayer to remove the Dynamic Memory component.
Then, I made a variant of the GAMENet model to use that new GAMENetLayer.

In [5]:
# GAMENet Layer without DM component
class GAMENetLayerNoDM(GAMENetLayer):
    """GAMENet layer.
    Paper: Junyuan Shang et al. GAMENet: Graph Augmented MEmory Networks for
    Recommending Medication Combination AAAI 2019.
    This variant removes the Dynamic Memory (DM) component. It reads only from
    the graph-augmented memory bank, so it does not take prev_drugs, and its
    output layer sees only the concatenated query and memory-bank read
    (2 * hidden_size features).
    Args:
        hidden_size: hidden feature size.
        ehr_adj: an adjacency tensor of shape [num_drugs, num_drugs].
        ddi_adj: an adjacency tensor of shape [num_drugs, num_drugs].
        dropout: the dropout rate. Default is 0.5.
    Examples:
        >>> queries = torch.randn(3, 5, 32) # [patient, visit, hidden_size]
        >>> curr_drugs = torch.randint(0, 2, (3, 50)).float()
        >>> ehr_adj = torch.randint(0, 2, (50, 50)).float()
        >>> ddi_adj = torch.randint(0, 2, (50, 50)).float()
        >>> layer = GAMENetLayerNoDM(32, ehr_adj, ddi_adj)
        >>> loss, y_prob = layer(queries, curr_drugs)
        >>> loss.shape
        torch.Size([])
        >>> y_prob.shape
        torch.Size([3, 50])
    """

    def __init__(
        self,
        hidden_size: int,
        ehr_adj: torch.tensor,
        ddi_adj: torch.tensor,
        dropout: float = 0.5,
    ):
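        # bypass GAMENetLayer.__init__ and initialize nn.Module directly;
        # all submodules this variant needs are built below
        super(GAMENetLayer, self).__init__()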
        self.hidden_size = hidden_size
        self.ehr_adj = ehr_adj
        self.ddi_adj = ddi_adj

        num_labels = ehr_adj.shape[0]
        self.ehr_gcn = GCN(adj=ehr_adj, hidden_size=hidden_size, dropout=dropout)
        self.ddi_gcn = GCN(adj=ddi_adj, hidden_size=hidden_size, dropout=dropout)
        self.beta = nn.Parameter(torch.FloatTensor(1))
        self.fc = nn.Linear(2 * hidden_size, num_labels)
        self.bce_loss_fn = nn.BCEWithLogitsLoss()

    def forward(
        self,
        queries: torch.tensor,
        curr_drugs: torch.tensor,
        mask: Optional[torch.tensor] = None,
    ) -> Tuple[torch.tensor, torch.tensor]:
        """Forward propagation.
        Args:
            queries: query tensor of shape [patient, visit, hidden_size].
            curr_drugs: multihot tensor indicating drug usage in the current
                visit of shape [patient, num_drugs].
            mask: an optional mask tensor of shape [patient, visit] where 1
                indicates valid visits and 0 indicates invalid visits.
        Returns:
            loss: a scalar tensor representing the loss.
            y_prob: a tensor of shape [patient, num_labels] representing
                the probability of each drug.
        """
        if mask is None:
            mask = torch.ones_like(queries[:, :, 0])

        """I: Input memory representation"""
        query = get_last_visit(queries, mask)

        """G: Generalization"""
        # memory bank
        MB = self.ehr_gcn() - self.ddi_gcn() * torch.sigmoid(self.beta)

        """O: Output memory representation"""
        a_c = torch.softmax(torch.mm(query, MB.t()), dim=-1)
        o_b = torch.mm(a_c, MB)

        """R: Response"""
        memory_output = torch.cat([query, o_b], dim=-1)
        logits = self.fc(memory_output)

        loss = self.bce_loss_fn(logits, curr_drugs)
        y_prob = torch.sigmoid(logits)

        return loss, y_prob

# GAMENet Model with the custom layer without DM
class GAMENetNoHist(GAMENet):
    """GAMENet model.
    Paper: Junyuan Shang et al. GAMENet: Graph Augmented MEmory Networks for
    Recommending Medication Combination AAAI 2019.
    Note:
        This model is only for medication prediction which takes conditions
        and procedures as feature_keys, and drugs_all as label_key (i.e., both
        current and previous drugs). It only operates on the visit level.
    Note:
        This model only accepts ATC level 3 as medication codes.
    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        embedding_dim: the embedding dimension. Default is 128.
        hidden_dim: the hidden dimension. Default is 128.
        num_layers: the number of layers used in RNN. Default is 1.
        dropout: the dropout rate. Default is 0.5.
        **kwargs: other parameters for the GAMENet layer.
    """

    def __init__(
        self,
        dataset: SampleDataset,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        num_layers: int = 1,
        dropout: float = 0.5,
        **kwargs
    ):
        super(GAMENet, self).__init__(
            dataset=dataset,
            feature_keys=["conditions", "procedures"],
            label_key="drugs_all",
            mode="multilabel",
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout

        self.feat_tokenizers = self.get_feature_tokenizers()
        self.label_tokenizer = self.get_label_tokenizer()
        self.embeddings = self.get_embedding_layers(self.feat_tokenizers, embedding_dim)

        ehr_adj = self.generate_ehr_adj()
        ddi_adj = self.generate_ddi_adj()

        self.cond_rnn = nn.GRU(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True,
        )
        self.proc_rnn = nn.GRU(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True,
        )
        self.query = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )

        # validate kwargs for GAMENet layer
        if "hidden_size" in kwargs:
            raise ValueError("hidden_size is determined by hidden_dim")
        if "ehr_adj" in kwargs:
            raise ValueError("ehr_adj is determined by the dataset")
        if "ddi_adj" in kwargs:
            raise ValueError("ddi_adj is determined by the dataset")
        self.gamenet = GAMENetLayerNoDM(
            hidden_size=hidden_dim,
            ehr_adj=ehr_adj,
            ddi_adj=ddi_adj,
            dropout=dropout,
            **kwargs,
        )

    def generate_ehr_adj(self) -> torch.tensor:
        """Generates the EHR graph adjacency matrix."""
        label_size = self.label_tokenizer.get_vocabulary_size()
        ehr_adj = torch.zeros((label_size, label_size))
        for sample in self.dataset:
            curr_drugs = sample["drugs_all"][-1]
            encoded_drugs = self.label_tokenizer.convert_tokens_to_indices(curr_drugs)
            for idx1, med1 in enumerate(encoded_drugs):
                for idx2, med2 in enumerate(encoded_drugs):
                    if idx1 >= idx2:
                        continue
                    ehr_adj[med1, med2] = 1
                    ehr_adj[med2, med1] = 1
        return ehr_adj

    def generate_ddi_adj(self) -> torch.tensor:
        """Generates the DDI graph adjacency matrix."""
        atc = ATC()
        ddi = atc.get_ddi(gamenet_ddi=True)
        label_size = self.label_tokenizer.get_vocabulary_size()
        vocab_to_index = self.label_tokenizer.vocabulary
        ddi_adj = torch.zeros((label_size, label_size))
        ddi_atc3 = [
            [ATC.convert(l[0], level=3), ATC.convert(l[1], level=3)] for l in ddi
        ]
        for atc_i, atc_j in ddi_atc3:
            if atc_i in vocab_to_index and atc_j in vocab_to_index:
                ddi_adj[vocab_to_index(atc_i), vocab_to_index(atc_j)] = 1
                ddi_adj[vocab_to_index(atc_j), vocab_to_index(atc_i)] = 1
        return ddi_adj

    def forward(
        self,
        conditions: List[List[List[str]]],
        procedures: List[List[List[str]]],
        drugs_all: List[List[List[str]]],
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Forward propagation.
        Args:
            conditions: a nested list in three levels [patient, visit, condition].
            procedures: a nested list in three levels [patient, visit, procedure].
            drugs_all: a nested list in three levels [patient, visit, drug].
        Returns:
            A dictionary with the following keys:
                loss: a scalar tensor representing the loss.
                y_prob: a tensor of shape [patient, num_labels] representing
                    the probability of each drug.
                y_true: a tensor of shape [patient, num_labels] representing
                    the ground truth of each drug.
        """
        conditions = self.feat_tokenizers["conditions"].batch_encode_3d(conditions)
        # (patient, visit, code)
        conditions = torch.tensor(conditions, dtype=torch.long, device=self.device)
        # (patient, visit, code, embedding_dim)
        conditions = self.embeddings["conditions"](conditions)
        # (patient, visit, embedding_dim)
        conditions = torch.sum(conditions, dim=2)
        # (batch, visit, hidden_size)
        conditions, _ = self.cond_rnn(conditions)

        procedures = self.feat_tokenizers["procedures"].batch_encode_3d(procedures)
        # (patient, visit, code)
        procedures = torch.tensor(procedures, dtype=torch.long, device=self.device)
        # (patient, visit, code, embedding_dim)
        procedures = self.embeddings["procedures"](procedures)
        # (patient, visit, embedding_dim)
        procedures = torch.sum(procedures, dim=2)
        # (batch, visit, hidden_size)
        procedures, _ = self.proc_rnn(procedures)

        # (batch, visit, 2 * hidden_size)
        patient_representations = torch.cat([conditions, procedures], dim=-1)
        # (batch, visit, hidden_size)
        queries = self.query(patient_representations)

        label_size = self.label_tokenizer.get_vocabulary_size()
        drugs_all = self.label_tokenizer.batch_encode_3d(
            drugs_all, padding=(False, False), truncation=(True, False)
        )

        curr_drugs = [p[-1] for p in drugs_all]
        curr_drugs = batch_to_multihot(curr_drugs, label_size)
        curr_drugs = curr_drugs.to(self.device)

        # get mask
        mask = torch.sum(conditions, dim=2) != 0

        # process drugs
        loss, y_prob = self.gamenet(queries, curr_drugs, mask)


        return {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": curr_drugs,
        }
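
To see what this ablation removes, the sketch below walks through the two memory reads at the shape level with random tensors. The memory bank `MB` is a random stand-in for the GCN output, and the dynamic-memory read (attention from the current query over earlier visits' queries, weighting those visits' drug multihots) follows our reading of the GAMENet paper; it is an illustration, not code copied from `GAMENetLayer`.

```python
import torch

torch.manual_seed(0)
batch, visits, hidden, num_drugs = 3, 4, 8, 10

queries = torch.randn(batch, visits, hidden)  # per-visit patient queries
# multihot drug sets of the earlier visits (stand-in for real history)
prev_drugs = torch.randint(0, 2, (batch, visits - 1, num_drugs)).float()
MB = torch.randn(num_drugs, hidden)  # random stand-in for the GCN memory bank

query = queries[:, -1, :]  # last visit's query (no padding assumed)

# memory-bank read, shared by both variants
a_c = torch.softmax(query @ MB.t(), dim=-1)  # (batch, num_drugs)
o_b = a_c @ MB                               # (batch, hidden)

# dynamic-memory read, only in the full layer
DM_keys = queries[:, :-1, :]                                            # (batch, visits - 1, hidden)
a_s = torch.softmax(torch.einsum("bd,bvd->bv", query, DM_keys), dim=1)  # (batch, visits - 1)
a_m = torch.einsum("bv,bvz->bz", a_s, prev_drugs)                       # (batch, num_drugs)
o_d = a_m @ MB                                                          # (batch, hidden)

full_input = torch.cat([query, o_b, o_d], dim=-1)  # output-layer input with DM
no_dm_input = torch.cat([query, o_b], dim=-1)      # output-layer input without DM
print(full_input.shape, no_dm_input.shape)
```

Dropping `o_d` is why `GAMENetLayerNoDM` builds its output layer as `nn.Linear(2 * hidden_size, num_labels)`.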

### Procedure Ablation

The GAMENet model also needed a few changes for this ablation.
The only change on the RETAIN side is omitting the procedures feature from the model instantiation.

#### No Proc Drug Recommendation Task

First, we prepare the data without collecting patient procedure information:

In [6]:
def drug_recommendation_mimic4_no_proc(patient: Patient):
    samples = []
    for i in range(len(patient)):
        visit: Visit = patient[i]
        conditions = visit.get_code_list(table="diagnoses_icd")
        drugs = visit.get_code_list(table="prescriptions")
        # ATC 3 level
        drugs = [drug[:4] for drug in drugs]
        # exclude: visits without condition or drug codes (procedures are not required)
        if len(conditions) * len(drugs) == 0:
            continue
        # TODO: should also exclude visit with age < 18
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                "conditions": conditions,
                "drugs": drugs,
                "drugs_all": drugs,
            }
        )
    # exclude: patients with fewer than 2 visits
    if len(samples) < 2:
        return []
    # add history
    samples[0]["conditions"] = [samples[0]["conditions"]]
    samples[0]["drugs_all"] = [samples[0]["drugs_all"]]

    for i in range(1, len(samples)):
        samples[i]["conditions"] = samples[i - 1]["conditions"] + [
            samples[i]["conditions"]
        ]
        samples[i]["drugs_all"] = samples[i - 1]["drugs_all"] + [
            samples[i]["drugs_all"]
        ]

    return samples
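
One side effect worth checking: the no-procedure task only requires condition and drug codes, so it can keep visits (and therefore patients) that the other two tasks drop, and its sample set is not guaranteed to match theirs. The toy sketch below compares the two visit filters on made-up visits; `keep_full` and `keep_no_proc` are hypothetical helpers mirroring the conditions in the task functions.

```python
from typing import Dict, List

# made-up visits for one patient
toy_visits: List[Dict[str, List[str]]] = [
    {"conditions": ["c1"], "procedures": ["p1"], "drugs": ["A01A"]},
    {"conditions": ["c2"], "procedures": [], "drugs": ["B01A"]},  # no procedures recorded
    {"conditions": [], "procedures": ["p2"], "drugs": ["C03D"]},  # no conditions recorded
]


def keep_full(visit: Dict[str, List[str]]) -> bool:
    # baseline and no-history tasks: conditions, procedures, and drugs are all required
    return len(visit["conditions"]) * len(visit["procedures"]) * len(visit["drugs"]) > 0


def keep_no_proc(visit: Dict[str, List[str]]) -> bool:
    # no-procedure task: only conditions and drugs are required
    return len(visit["conditions"]) * len(visit["drugs"]) > 0


print([keep_full(v) for v in toy_visits])     # [True, False, False]
print([keep_no_proc(v) for v in toy_visits])  # [True, True, False]

full_kept = [v for v in toy_visits if keep_full(v)]
no_proc_kept = [v for v in toy_visits if keep_no_proc(v)]
# every task drops patients with fewer than 2 kept visits
print(len(full_kept) >= 2, len(no_proc_kept) >= 2)  # False True
```

Here the same patient is excluded from the baseline task but included in the no-procedure task.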

#### GAMENet with No Proc

Next, I redefine the GAMENet model so that its `forward` function does not accept procedure information.
I also remove all procedure processing from `forward`, drop the Gated Recurrent Unit (GRU) for procedures, and halve the input size of the query layer, since it now receives only the condition GRU output.

In [7]:
class GAMENetNoProc(GAMENet):
    """GAMENet model.

    Paper: Junyuan Shang et al. GAMENet: Graph Augmented MEmory Networks for
    Recommending Medication Combination AAAI 2019.

    Note:
        This model is only for medication prediction which takes conditions
        as feature_keys, and drugs_all as label_key (i.e., both
        current and previous drugs). It only operates on the visit level.

    Note:
        This model only accepts ATC level 3 as medication codes.

    Args:
        dataset: the dataset to train the model. It is used to query certain
            information such as the set of all tokens.
        embedding_dim: the embedding dimension. Default is 128.
        hidden_dim: the hidden dimension. Default is 128.
        num_layers: the number of layers used in RNN. Default is 1.
        dropout: the dropout rate. Default is 0.5.
        **kwargs: other parameters for the GAMENet layer.
    """

    def __init__(
        self,
        dataset: SampleDataset,
        embedding_dim: int = 128,
        hidden_dim: int = 128,
        num_layers: int = 1,
        dropout: float = 0.5,
        **kwargs
    ):
        super(GAMENet, self).__init__(
            dataset=dataset,
            feature_keys=["conditions"],
            label_key="drugs_all",
            mode="multilabel",
        )
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout

        self.feat_tokenizers = self.get_feature_tokenizers()
        self.label_tokenizer = self.get_label_tokenizer()
        self.embeddings = self.get_embedding_layers(self.feat_tokenizers, embedding_dim)

        ehr_adj = self.generate_ehr_adj()
        ddi_adj = self.generate_ddi_adj()

        self.cond_rnn = nn.GRU(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True,
        )
        self.query = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # validate kwargs for GAMENet layer
        if "hidden_size" in kwargs:
            raise ValueError("hidden_size is determined by hidden_dim")
        if "ehr_adj" in kwargs:
            raise ValueError("ehr_adj is determined by the dataset")
        if "ddi_adj" in kwargs:
            raise ValueError("ddi_adj is determined by the dataset")
        self.gamenet = GAMENetLayer(
            hidden_size=hidden_dim,
            ehr_adj=ehr_adj,
            ddi_adj=ddi_adj,
            dropout=dropout,
            **kwargs,
        )

    def generate_ehr_adj(self) -> torch.tensor:
        """Generates the EHR graph adjacency matrix."""
        label_size = self.label_tokenizer.get_vocabulary_size()
        ehr_adj = torch.zeros((label_size, label_size))
        for sample in self.dataset:
            curr_drugs = sample["drugs_all"][-1]
            encoded_drugs = self.label_tokenizer.convert_tokens_to_indices(curr_drugs)
            for idx1, med1 in enumerate(encoded_drugs):
                for idx2, med2 in enumerate(encoded_drugs):
                    if idx1 >= idx2:
                        continue
                    ehr_adj[med1, med2] = 1
                    ehr_adj[med2, med1] = 1
        return ehr_adj

    def generate_ddi_adj(self) -> torch.tensor:
        """Generates the DDI graph adjacency matrix."""
        atc = ATC()
        ddi = atc.get_ddi(gamenet_ddi=True)
        label_size = self.label_tokenizer.get_vocabulary_size()
        vocab_to_index = self.label_tokenizer.vocabulary
        ddi_adj = torch.zeros((label_size, label_size))
        ddi_atc3 = [
            [ATC.convert(l[0], level=3), ATC.convert(l[1], level=3)] for l in ddi
        ]
        for atc_i, atc_j in ddi_atc3:
            if atc_i in vocab_to_index and atc_j in vocab_to_index:
                ddi_adj[vocab_to_index(atc_i), vocab_to_index(atc_j)] = 1
                ddi_adj[vocab_to_index(atc_j), vocab_to_index(atc_i)] = 1
        return ddi_adj

    def forward(
        self,
        conditions: List[List[List[str]]],
        drugs_all: List[List[List[str]]],
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            conditions: a nested list in three levels [patient, visit, condition].
            drugs_all: a nested list in three levels [patient, visit, drug].

        Returns:
            A dictionary with the following keys:
                loss: a scalar tensor representing the loss.
                y_prob: a tensor of shape [patient, num_labels] representing
                    the probability of each drug.
                y_true: a tensor of shape [patient, num_labels] representing
                    the ground truth of each drug.

        """
        conditions = self.feat_tokenizers["conditions"].batch_encode_3d(conditions)
        # (patient, visit, code)
        conditions = torch.tensor(conditions, dtype=torch.long, device=self.device)
        # (patient, visit, code, embedding_dim)
        conditions = self.embeddings["conditions"](conditions)
        # (patient, visit, embedding_dim)
        conditions = torch.sum(conditions, dim=2)
        # (batch, visit, hidden_size)
        conditions, _ = self.cond_rnn(conditions)

        # (batch, visit, hidden_size): conditions are the only input stream
        patient_representations = conditions
        # (batch, visit, hidden_size)
        queries = self.query(patient_representations)

        label_size = self.label_tokenizer.get_vocabulary_size()
        drugs_all = self.label_tokenizer.batch_encode_3d(
            drugs_all, padding=(False, False), truncation=(True, False)
        )

        curr_drugs = [p[-1] for p in drugs_all]
        curr_drugs = batch_to_multihot(curr_drugs, label_size)
        curr_drugs = curr_drugs.to(self.device)

        prev_drugs = [p[:-1] for p in drugs_all]
        max_num_visit = max([len(p) for p in prev_drugs])
        prev_drugs = [p + [[]] * (max_num_visit - len(p)) for p in prev_drugs]
        prev_drugs = [batch_to_multihot(p, label_size) for p in prev_drugs]
        prev_drugs = torch.stack(prev_drugs, dim=0)
        prev_drugs = prev_drugs.to(self.device)

        # get mask
        mask = torch.sum(conditions, dim=2) != 0

        # process drugs
        loss, y_prob = self.gamenet(queries, prev_drugs, curr_drugs, mask)

        return {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": curr_drugs,
        }
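
A shape-level sketch of how the query changes when the procedure stream is removed, using random embeddings and arbitrary small dimensions (not the model's defaults):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, visits, embedding_dim, hidden_dim = 2, 3, 16, 8

cond_rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
proc_rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True)

cond_emb = torch.randn(batch, visits, embedding_dim)  # summed condition embeddings per visit
proc_emb = torch.randn(batch, visits, embedding_dim)  # summed procedure embeddings per visit

cond_out, _ = cond_rnn(cond_emb)  # (batch, visits, hidden_dim)
proc_out, _ = proc_rnn(proc_emb)  # (batch, visits, hidden_dim)

# full GAMENet: both streams are concatenated, so the query layer takes 2 * hidden_dim
query_full = nn.Sequential(nn.ReLU(), nn.Linear(hidden_dim * 2, hidden_dim))
queries_full = query_full(torch.cat([cond_out, proc_out], dim=-1))

# no-procedure variant: the condition stream alone feeds a hidden_dim -> hidden_dim layer
query_no_proc = nn.Sequential(nn.ReLU(), nn.Linear(hidden_dim, hidden_dim))
queries_no_proc = query_no_proc(cond_out)

print(queries_full.shape, queries_no_proc.shape)
```

Because both variants still produce `hidden_dim`-sized queries, the downstream `GAMENetLayer` can be reused unchanged.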

## Data Tasks

With the tasks and models defined, I create a "tasklist" that lets us run every task we are interested in.
I also create a dictionary mapping each data preparation task to the model variant it should use.
Finally, I create a dictionary selecting the features the RETAIN model uses for each task.

In [8]:
# ModelWrapper strictly checks which model types are allowed, so here we import
# the GAMENet variants defined in this repository's alt_gamenets module instead
# of using the classes defined above in this notebook.
# To use other custom variants, remove these imports and extend the allowed
# model types in the ModelWrapper class.
from alt_gamenets import GAMENetNoHist as GNNH
from alt_gamenets import GAMENetNoProc as GNNP
# create "tasklist", which is a dictionary of simple task name -> drug task
MIMIC4_TASKS = {
    DRUG_REC_TN: drug_recommendation_mimic4_fn,
    NO_HIST_TN: drug_recommendation_mimic4_no_hist,
    NO_PROC_TN: drug_recommendation_mimic4_no_proc,
}
# create dictionary showing what model corresponds to what task
MODEL_TYPES_PER_TASK = {
    DRUG_REC_TN: {GN_KEY: GAMENet, RT_KEY: RETAIN},
    NO_HIST_TN: {GN_KEY: GNNH, RT_KEY: RETAIN},
    NO_PROC_TN: {GN_KEY: GNNP, RT_KEY: RETAIN},
}
# define retain features per task
RETAIN_DEFAULT_FEATURES = ["conditions", "procedures"]

RETAIN_FEATS_PER_TASK = {
    DRUG_REC_TN: RETAIN_DEFAULT_FEATURES,
    NO_HIST_TN: RETAIN_DEFAULT_FEATURES,
    NO_PROC_TN: ["conditions"],
}
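
These three dictionaries are keyed on the same task names, so a run can look up the task function, the model variant, and the RETAIN features together. The sketch below shows that pattern with hypothetical stub tasks and classes, and hypothetical `"gamenet"`/`"retain"` keys standing in for `GN_KEY` and `RT_KEY`, plus a consistency check that every RETAIN feature exists in the task's samples.

```python
from typing import Callable, Dict, List


# hypothetical stand-ins for the real task functions
def full_task(patient: dict) -> List[dict]:
    return [patient]


def no_proc_task(patient: dict) -> List[dict]:
    return [{k: v for k, v in patient.items() if k != "procedures"}]


# hypothetical stand-ins for the real model classes
class GAMENetStub: ...
class GAMENetNoProcStub: ...
class RETAINStub: ...


TASKS: Dict[str, Callable[[dict], List[dict]]] = {
    "drug_recommendation": full_task,
    "no_proc": no_proc_task,
}
MODELS = {
    "drug_recommendation": {"gamenet": GAMENetStub, "retain": RETAINStub},
    "no_proc": {"gamenet": GAMENetNoProcStub, "retain": RETAINStub},
}
RETAIN_FEATS = {
    "drug_recommendation": ["conditions", "procedures"],
    "no_proc": ["conditions"],
}

patient = {"conditions": ["c1"], "procedures": ["p1"], "drugs": ["A01A"]}
for name, task_fn in TASKS.items():
    samples = task_fn(patient)
    gamenet_cls = MODELS[name]["gamenet"]
    feats = RETAIN_FEATS[name]
    # every RETAIN feature key must be present in the task's samples
    assert all(f in samples[0] for f in feats)
    print(name, gamenet_cls.__name__, feats)
```

Checking feature keys against a sample up front surfaces a mismatch (for example, requesting `procedures` on the no-procedure task) before any model is built.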

## Load Data

We use the `MIMIC4` data class from the `mimic` import.
Another option would be to import the `MIMIC3` data class and use that as the `dataset` below.
Here, we use `MIMIC4` to load the data and prepare it with the appropriate tasks.

The data class (either `MIMIC4` or `MIMIC3`) determines the data root.
By default, it reads data from `./hiddendata/extracted/{mimic3/4}/`.
So, for MIMIC4, the data must be placed in `./hiddendata/extracted/mimic4/`.
This can be changed by either modifying the MIMIC4 class directly, or modifying the data root default in the `constants` file.
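
For illustration, the default layout can be expressed with `pathlib`. `data_root` is a hypothetical helper, not the `MIMIC4` class's actual code; it only shows how the base directory and the dataset name combine.

```python
from pathlib import Path

# assumed default base directory, taken from the text above
DEFAULT_BASE = Path("./hiddendata/extracted")


def data_root(source: str, base: Path = DEFAULT_BASE) -> Path:
    """Hypothetical helper: resolve the data root for 'mimic3' or 'mimic4'."""
    if source not in {"mimic3", "mimic4"}:
        raise ValueError(f"unknown data source: {source}")
    return base / source


root = data_root("mimic4")
# show the resolved path and whether the directory is present
print(root, root.exists())
```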

Finally, we use the tasks we defined above as the tasklist for the `MIMICWrapper`.

In [9]:
# save data in the ./hiddendata/extracted/ directory
## this uses MIMIC4, so place data in ./hiddendata/extracted/mimic4/
## could use MIMIC3 with from mimic import MIMIC3 and placing data in ./hiddendata/extracted/mimic3/
dataset = MIMIC4
# we will run all tasks possible on the dataset
# can define custom tasklist here
mimic = MIMICWrapper(datasource=dataset, tasks=MIMIC4_TASKS)
mimic_data = mimic.load_data(dev=DEV)
drug_task_data = mimic.drug_task_data()
dataloaders = mimic.create_dataloaders()

-*-READING DEV DATA-*-
reading mimic4 data...
---DATA STATS FOR mimic4 DATA---
stat

Statistics of base dataset (dev=True):
	- Dataset: MIMIC4Dataset
	- Number of patients: 607
	- Number of visits: 1463
	- Number of visits per patient: 2.4102
	- Number of events per visit in diagnoses_icd: 11.6705
	- Number of events per visit in procedures_icd: 1.4846
	- Number of events per visit in prescriptions: 55.9850

info

dataset.patients: patient_id -> <Patient>

<Patient>
    - visits: visit_id -> <Visit> 
    - other patient-level info
    
    <Visit>
        - event_list_dict: table_name -> List[Event]
        - other visit-level info
    
        <Event>
            - code: str
            - other event-level info

***run task: drug_recommendation


Generating samples for drug_recommendation_mimic4_fn: 100%|████████████████| 607/607 [00:00<00:00, 24046.00it/s]


{'visit_id': '22595853', 'patient_id': '10000032', 'conditions': [['5723', '78959', '5715', '07070', '496', '29680', '30981', 'V1582']], 'procedures': [['5491']], 'drugs': ['B01A', 'J07B', 'A12B', 'C03D', 'C03C', 'N02B', 'J05A', 'R03A', 'N07B', 'R03B'], 'drugs_all': [['B01A', 'J07B', 'A12B', 'C03D', 'C03C', 'N02B', 'J05A', 'R03A', 'N07B', 'R03B']]}
***run task: no_hist


Generating samples for drug_recommendation_mimic4_no_hist: 100%|███████████| 607/607 [00:00<00:00, 23352.13it/s]


{'visit_id': '22595853', 'patient_id': '10000032', 'conditions': [['5723', '78959', '5715', '07070', '496', '29680', '30981', 'V1582']], 'procedures': [['5491']], 'drugs': ['B01A', 'J07B', 'A12B', 'C03D', 'C03C', 'N02B', 'J05A', 'R03A', 'N07B', 'R03B'], 'drugs_all': [['B01A', 'J07B', 'A12B', 'C03D', 'C03C', 'N02B', 'J05A', 'R03A', 'N07B', 'R03B']]}
***run task: no_proc


Generating samples for drug_recommendation_mimic4_no_proc: 100%|███████████| 607/607 [00:00<00:00, 25597.14it/s]


{'visit_id': '22595853', 'patient_id': '10000032', 'conditions': [['5723', '78959', '5715', '07070', '496', '29680', '30981', 'V1582']], 'drugs': ['B01A', 'J07B', 'A12B', 'C03D', 'C03C', 'N02B', 'J05A', 'R03A', 'N07B', 'R03B'], 'drugs_all': [['B01A', 'J07B', 'A12B', 'C03D', 'C03C', 'N02B', 'J05A', 'R03A', 'N07B', 'R03B']]}


In [10]:
drug_task_data

{'drug_recommendation': <pyhealth.datasets.sample_dataset.SampleDataset at 0x7f0b0bb52190>,
 'no_hist': <pyhealth.datasets.sample_dataset.SampleDataset at 0x7f0adf0e2130>,
 'no_proc': <pyhealth.datasets.sample_dataset.SampleDataset at 0x7f0adf0e2610>}

## Create DDI Matrices

To calculate the DDI rate, we need the DDI adjacency matrices.
GAMENet models build these internally, but RETAIN does not.
So, we build the DDI matrices ahead of time and use them to calculate the rate later.

In [11]:
ddi_mats = {}

for taskname in mimic.get_task_names():
    model_type = MODEL_TYPES_PER_TASK[taskname][GN_KEY]
    ddi_mats[taskname] = model_type(drug_task_data[taskname]).generate_ddi_adj()
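
For reference, the GAMENet paper defines the DDI rate as the fraction of recommended drug pairs that are known interacting pairs. Below is a minimal sketch of that definition, assuming predictions have already been thresholded into per-visit lists of drug indices; it is not necessarily how this study or PyHealth computes the rate.

```python
import itertools
from typing import List

import torch


def ddi_rate(predicted_sets: List[List[int]], ddi_adj: torch.Tensor) -> float:
    """Fraction of predicted drug pairs that are known interacting pairs.

    predicted_sets: one list of predicted drug indices per visit.
    ddi_adj: symmetric 0/1 matrix of shape [num_drugs, num_drugs].
    """
    interacting, total = 0, 0
    for drugs in predicted_sets:
        # each unordered pair of distinct drugs in the visit is counted once
        for i, j in itertools.combinations(sorted(set(drugs)), 2):
            total += 1
            if ddi_adj[i, j] > 0:
                interacting += 1
    return interacting / total if total > 0 else 0.0


ddi_adj = torch.zeros(4, 4)
ddi_adj[0, 1] = ddi_adj[1, 0] = 1  # drugs 0 and 1 interact

# pairs (0,1), (0,2), (1,2) in visit 1 and (2,3) in visit 2: 1 of 4 interacts
print(ddi_rate([[0, 1, 2], [2, 3]], ddi_adj))  # 0.25
```

The matrices in `ddi_mats` are square `[num_drugs, num_drugs]` tensors, so they could be passed as `ddi_adj`; predicted indices for one sample could come from `torch.nonzero(y_prob[k] >= 0.5).flatten().tolist()`, assuming a 0.5 threshold.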

## Train the Models

Previously, I defined a dictionary mapping each task name to its model variants.
Using the `ModelWrapper`, I pass the appropriate model from that dictionary in as the base model.
Then, I train both the RETAIN baseline and GAMENet on every task.

In [12]:
retain = {}
gamenet = {}
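
The training logs below report sample-averaged multilabel metrics such as `jaccard_samples`. As a sketch of what that metric measures, assuming a 0.5 decision threshold: for each sample, take the Jaccard index between the true and predicted drug sets, then average over samples. Scoring a sample whose true and predicted sets are both empty as 0 is our choice here.

```python
import torch


def jaccard_samples(y_true: torch.Tensor, y_prob: torch.Tensor, threshold: float = 0.5) -> float:
    """Mean per-sample Jaccard index between true and thresholded predicted label sets."""
    y_pred = y_prob >= threshold
    y_true = y_true.bool()
    intersection = (y_pred & y_true).sum(dim=1).float()
    union = (y_pred | y_true).sum(dim=1).float()
    # score 0 for samples where both sets are empty (union == 0)
    scores = torch.where(union > 0, intersection / union.clamp(min=1), torch.zeros_like(union))
    return scores.mean().item()


y_true = torch.tensor([[1, 1, 0, 0], [0, 1, 1, 0]])
y_prob = torch.tensor([[0.9, 0.2, 0.7, 0.1], [0.1, 0.8, 0.6, 0.3]])
# sample 1: pred {0, 2} vs true {0, 1} -> 1/3; sample 2: exact match -> 1
print(jaccard_samples(y_true, y_prob))
```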

In [13]:
# baseline
print("---RETAIN TRAINING---")
for taskname, dataloader in dataloaders.items():
    print("--training retain on {} data--".format(taskname))
    # create and train retain model
    retain[taskname] = ModelWrapper(
        drug_task_data[taskname],
        model=MODEL_TYPES_PER_TASK[taskname][RT_KEY],
        feature_keys=RETAIN_FEATS_PER_TASK[taskname],
        experiment="{}_task_{}".format(RETAIN_EXP, taskname)
    )
    retain[taskname].train_model(
        dataloader["train"], dataloader["val"],
        decay_weight=DECAY_WEIGHT,
        learning_rate=LR,
        epochs=EPOCHS
    )

# gamenet
print("---GAMENET TRAINING---")
for taskname, dataloader in dataloaders.items():
    print("--training gamenet on {} data--".format(taskname))
    # create and train gamenet model
    gamenet[taskname] = ModelWrapper(
        drug_task_data[taskname],
        model=MODEL_TYPES_PER_TASK[taskname][GN_KEY],
        experiment="{}_task_{}".format(GAMENET_EXP, taskname)
    )
    gamenet[taskname].train_model(
        dataloader["train"], dataloader["val"],
        decay_weight=DECAY_WEIGHT,
        learning_rate=LR,
        epochs=EPOCHS
    )

---RETAIN TRAINING---
--training retain on drug_recommendation data--
making retain model


RETAIN(
  (embeddings): ModuleDict(
    (conditions): Embedding(1908, 128, padding_idx=0)
    (procedures): Embedding(609, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (retain): ModuleDict(
    (conditions): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
    (procedures): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (fc): Linear(in_features=256, out_features=160, bias=True)
)
Metrics: ['jaccard_samples', 'accuracy', 'hamming_loss', 'precision_sa

Epoch 0 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-0, step-7 ---
loss: 0.6926
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 80.38it/s]
--- Eval epoch-0, step-7 ---
jaccard_samples: 0.1089
accuracy: 0.0000
hamming_loss: 0.4204
precision_samples: 0.1350
recall_samples: 0.4196
pr_auc_samples: 0.1674
f1_samples: 0.1891
loss: 0.6852
New best accuracy score (0.0000) at epoch-0, step-7



Epoch 1 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-1, step-14 ---
loss: 0.6796
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 83.45it/s]
--- Eval epoch-1, step-14 ---
jaccard_samples: 0.1234
accuracy: 0.0000
hamming_loss: 0.3492
precision_samples: 0.1617
recall_samples: 0.4171
pr_auc_samples: 0.2067
f1_samples: 0.2115
loss: 0.6746



Epoch 2 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-2, step-21 ---
loss: 0.6597
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 94.70it/s]
--- Eval epoch-2, step-21 ---
jaccard_samples: 0.1519
accuracy: 0.0000
hamming_loss: 0.2691
precision_samples: 0.2151
recall_samples: 0.4137
pr_auc_samples: 0.2741
f1_samples: 0.2542
loss: 0.6562



Epoch 3 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-3, step-28 ---
loss: 0.6380
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 90.67it/s]
--- Eval epoch-3, step-28 ---
jaccard_samples: 0.1906
accuracy: 0.0000
hamming_loss: 0.1983
precision_samples: 0.3131
recall_samples: 0.4154
pr_auc_samples: 0.3626
f1_samples: 0.3099
loss: 0.6221



Epoch 4 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-4, step-35 ---
loss: 0.5921
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 86.24it/s]
--- Eval epoch-4, step-35 ---
jaccard_samples: 0.2264
accuracy: 0.0000
hamming_loss: 0.1544
precision_samples: 0.4206
recall_samples: 0.4037
pr_auc_samples: 0.4328
f1_samples: 0.3579
loss: 0.5730
Loaded best model
RETAIN(
  (embeddings): ModuleDict(
    (conditions): Embedding(1908, 128, padding_idx=0)
    (procedures): Embedding(609, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (retain): ModuleDict(
    (conditions): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
    (procedures): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha

--training retain on no_hist data--
making retain model


Epoch 0 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-0, step-7 ---
loss: 0.7284
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 76.98it/s]
--- Eval epoch-0, step-7 ---
jaccard_samples: 0.1183
accuracy: 0.0000
hamming_loss: 0.4409
precision_samples: 0.1484
recall_samples: 0.4682
pr_auc_samples: 0.1708
f1_samples: 0.2021
loss: 0.6866
New best accuracy score (0.0000) at epoch-0, step-7



Epoch 1 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-1, step-14 ---
loss: 0.6945
Evaluation: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 139.18it/s]
--- Eval epoch-1, step-14 ---
jaccard_samples: 0.1283
accuracy: 0.0000
hamming_loss: 0.3811
precision_samples: 0.1667
recall_samples: 0.4596
pr_auc_samples: 0.1999
f1_samples: 0.2189
loss: 0.6640



Epoch 2 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-2, step-21 ---
loss: 0.6509
Evaluation: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 117.27it/s]
--- Eval epoch-2, step-21 ---
jaccard_samples: 0.1532
accuracy: 0.0000
hamming_loss: 0.3014
precision_samples: 0.2158
recall_samples: 0.4651
pr_auc_samples: 0.2557
f1_samples: 0.2584
loss: 0.6327



Epoch 3 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-3, step-28 ---
loss: 0.5692
Evaluation: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 124.53it/s]
--- Eval epoch-3, step-28 ---
jaccard_samples: 0.1846
accuracy: 0.0000
hamming_loss: 0.2228
precision_samples: 0.2890
recall_samples: 0.4424
pr_auc_samples: 0.3279
f1_samples: 0.3042
loss: 0.5938



Epoch 4 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-4, step-35 ---
loss: 0.4963
Evaluation: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 105.83it/s]
--- Eval epoch-4, step-35 ---
jaccard_samples: 0.2255
accuracy: 0.0000
hamming_loss: 0.1691
precision_samples: 0.3904
recall_samples: 0.4240
pr_auc_samples: 0.4038
f1_samples: 0.3604
loss: 0.5496
Loaded best model
RETAIN(
  (embeddings): ModuleDict(
    (conditions): Embedding(2575, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (retain): ModuleDict(
    (conditions): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (fc): Linear(in_features=128, out_features=165, bias=True)
)
Metrics: ['jaccard_samples', 'accuracy', 'hamming_loss', 'precision_samples', 'rec

--training retain on no_proc data--
making retain model


Epoch 0 / 5:   0%|          | 0/11 [00:00<?, ?it/s]

--- Train epoch-0, step-11 ---
loss: 0.6899
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 70.41it/s]
--- Eval epoch-0, step-11 ---
jaccard_samples: 0.1536
accuracy: 0.0000
hamming_loss: 0.4218
precision_samples: 0.1930
recall_samples: 0.5045
pr_auc_samples: 0.2205
f1_samples: 0.2602
loss: 0.6827
New best accuracy score (0.0000) at epoch-0, step-11



Epoch 1 / 5:   0%|          | 0/11 [00:00<?, ?it/s]

--- Train epoch-1, step-22 ---
loss: 0.6604
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 67.44it/s]
--- Eval epoch-1, step-22 ---
jaccard_samples: 0.1805
accuracy: 0.0000
hamming_loss: 0.3124
precision_samples: 0.2566
recall_samples: 0.4479
pr_auc_samples: 0.2907
f1_samples: 0.2994
loss: 0.6337



Epoch 2 / 5:   0%|          | 0/11 [00:00<?, ?it/s]

--- Train epoch-2, step-33 ---
loss: 0.5830
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 71.22it/s]
--- Eval epoch-2, step-33 ---
jaccard_samples: 0.2165
accuracy: 0.0000
hamming_loss: 0.2188
precision_samples: 0.3959
recall_samples: 0.3908
pr_auc_samples: 0.4033
f1_samples: 0.3496
loss: 0.5369



Epoch 3 / 5:   0%|          | 0/11 [00:00<?, ?it/s]

--- Train epoch-3, step-44 ---
loss: 0.4903
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 69.55it/s]
--- Eval epoch-3, step-44 ---
jaccard_samples: 0.2457
accuracy: 0.0000
hamming_loss: 0.1672
precision_samples: 0.5478
recall_samples: 0.3458
pr_auc_samples: 0.5001
f1_samples: 0.3871
loss: 0.4391



Epoch 4 / 5:   0%|          | 0/11 [00:00<?, ?it/s]

--- Train epoch-4, step-55 ---
loss: 0.4269
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 70.27it/s]
--- Eval epoch-4, step-55 ---
jaccard_samples: 0.2502
accuracy: 0.0000
hamming_loss: 0.1547
precision_samples: 0.6101
recall_samples: 0.3302
pr_auc_samples: 0.5532
f1_samples: 0.3923
loss: 0.4034
Loaded best model


---GAMENET TRAINING---
--training gamenet on drug_recommendation data--
making gamenet model


GAMENet(
  (embeddings): ModuleDict(
    (conditions): Embedding(1908, 128, padding_idx=0)
    (procedures): Embedding(609, 128, padding_idx=0)
  )
  (cond_rnn): GRU(128, 128, batch_first=True)
  (proc_rnn): GRU(128, 128, batch_first=True)
  (query): Sequential(
    (0): ReLU()
    (1): Linear(in_features=256, out_features=128, bias=True)
  )
  (gamenet): GAMENetLayer(
    (ehr_gcn): GCN(
      (gcn1): GCNLayer()
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (gcn2): GCNLayer()
    )
    (ddi_gcn): GCN(
      (gcn1): GCNLayer()
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (gcn2): GCNLayer()
    )
    (fc): Linear(in_features=384, out_features=160, bias=True)
    (bce_loss_fn): BCEWithLogitsLoss()
  )
)
Metrics: ['jaccard_samples', 'accuracy', 'hamming_loss', 'precision_samples', 'recall_samples', 'pr_auc_samples', 'f1_samples']
Device: cuda

Training:
Batch size: 64
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 1e-05


Epoch 0 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-0, step-7 ---
loss: 0.6223
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 95.80it/s]
--- Eval epoch-0, step-7 ---
jaccard_samples: 0.2637
accuracy: 0.0000
hamming_loss: 0.1107
precision_samples: 0.6063
recall_samples: 0.3564
pr_auc_samples: 0.4822
f1_samples: 0.4093
loss: 0.5519
New best accuracy score (0.0000) at epoch-0, step-7



Epoch 1 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-1, step-14 ---
loss: 0.4844
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 95.43it/s]
--- Eval epoch-1, step-14 ---
jaccard_samples: 0.2627
accuracy: 0.0000
hamming_loss: 0.1095
precision_samples: 0.6344
recall_samples: 0.3468
pr_auc_samples: 0.5452
f1_samples: 0.4079
loss: 0.4558



Epoch 2 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-2, step-21 ---
loss: 0.4151
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 96.45it/s]
--- Eval epoch-2, step-21 ---
jaccard_samples: 0.2568
accuracy: 0.0000
hamming_loss: 0.1082
precision_samples: 0.6521
recall_samples: 0.3301
pr_auc_samples: 0.5589
f1_samples: 0.4012
loss: 0.4049



Epoch 3 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-3, step-28 ---
loss: 0.3768
Evaluation: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 103.72it/s]
--- Eval epoch-3, step-28 ---
jaccard_samples: 0.2585
accuracy: 0.0000
hamming_loss: 0.1087
precision_samples: 0.6424
recall_samples: 0.3342
pr_auc_samples: 0.5635
f1_samples: 0.4029
loss: 0.3532



Epoch 4 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-4, step-35 ---
loss: 0.3345
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 93.89it/s]
--- Eval epoch-4, step-35 ---
jaccard_samples: 0.2519
accuracy: 0.0000
hamming_loss: 0.1070
precision_samples: 0.6855
recall_samples: 0.3172
pr_auc_samples: 0.5645
f1_samples: 0.3955
loss: 0.3226
Loaded best model


--training gamenet on no_hist data--
making gamenet model without hist...


GAMENetNoHist(
  (embeddings): ModuleDict(
    (conditions): Embedding(1908, 128, padding_idx=0)
    (procedures): Embedding(609, 128, padding_idx=0)
  )
  (cond_rnn): GRU(128, 128, batch_first=True)
  (proc_rnn): GRU(128, 128, batch_first=True)
  (query): Sequential(
    (0): ReLU()
    (1): Linear(in_features=256, out_features=128, bias=True)
  )
  (gamenet): GAMENetLayerNoDM(
    (ehr_gcn): GCN(
      (gcn1): GCNLayer()
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (gcn2): GCNLayer()
    )
    (ddi_gcn): GCN(
      (gcn1): GCNLayer()
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (gcn2): GCNLayer()
    )
    (fc): Linear(in_features=256, out_features=160, bias=True)
    (bce_loss_fn): BCEWithLogitsLoss()
  )
)
Metrics: ['jaccard_samples', 'accuracy', 'hamming_loss', 'precision_samples', 'recall_samples', 'pr_auc_samples', 'f1_samples']
Device: cuda

Training:
Batch size: 64
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight dec

Epoch 0 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-0, step-7 ---
loss: 0.6617
Evaluation: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 257.35it/s]
--- Eval epoch-0, step-7 ---
jaccard_samples: 0.2252
accuracy: 0.0000
hamming_loss: 0.1389
precision_samples: 0.4588
recall_samples: 0.3400
pr_auc_samples: 0.4303
f1_samples: 0.3588
loss: 0.5925
New best accuracy score (0.0000) at epoch-0, step-7



Epoch 1 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-1, step-14 ---
loss: 0.5050
Evaluation: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 181.83it/s]
--- Eval epoch-1, step-14 ---
jaccard_samples: 0.2813
accuracy: 0.0000
hamming_loss: 0.1239
precision_samples: 0.5550
recall_samples: 0.3908
pr_auc_samples: 0.5154
f1_samples: 0.4257
loss: 0.4254



Epoch 2 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-2, step-21 ---
loss: 0.3575
Evaluation: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 271.02it/s]
--- Eval epoch-2, step-21 ---
jaccard_samples: 0.3003
accuracy: 0.0000
hamming_loss: 0.1199
precision_samples: 0.5839
recall_samples: 0.4146
pr_auc_samples: 0.5557
f1_samples: 0.4475
loss: 0.3250



Epoch 3 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-3, step-28 ---
loss: 0.3296
Evaluation: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 284.78it/s]
--- Eval epoch-3, step-28 ---
jaccard_samples: 0.2853
accuracy: 0.0000
hamming_loss: 0.1260
precision_samples: 0.5362
recall_samples: 0.4116
pr_auc_samples: 0.5618
f1_samples: 0.4310
loss: 0.3134



Epoch 4 / 5:   0%|          | 0/7 [00:00<?, ?it/s]

--- Train epoch-4, step-35 ---
loss: 0.3127
Evaluation: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 267.00it/s]
--- Eval epoch-4, step-35 ---
jaccard_samples: 0.2818
accuracy: 0.0000
hamming_loss: 0.1279
precision_samples: 0.5293
recall_samples: 0.4136
pr_auc_samples: 0.5546
f1_samples: 0.4286
loss: 0.3138
Loaded best model


--training gamenet on no_proc data--
making gamenet model without procedures...


GAMENetNoProc(
  (embeddings): ModuleDict(
    (conditions): Embedding(2575, 128, padding_idx=0)
  )
  (cond_rnn): GRU(128, 128, batch_first=True)
  (query): Sequential(
    (0): ReLU()
    (1): Linear(in_features=128, out_features=128, bias=True)
  )
  (gamenet): GAMENetLayer(
    (ehr_gcn): GCN(
      (gcn1): GCNLayer()
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (gcn2): GCNLayer()
    )
    (ddi_gcn): GCN(
      (gcn1): GCNLayer()
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (gcn2): GCNLayer()
    )
    (fc): Linear(in_features=384, out_features=165, bias=True)
    (bce_loss_fn): BCEWithLogitsLoss()
  )
)
Metrics: ['jaccard_samples', 'accuracy', 'hamming_loss', 'precision_samples', 'recall_samples', 'pr_auc_samples', 'f1_samples']
Device: cuda

Training:
Batch size: 64
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 1e-05
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7f0ad

Epoch 0 / 5:   0%|          | 0/11 [00:00<?, ?it/s]

--- Train epoch-0, step-11 ---
loss: 0.5823
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 49.13it/s]
--- Eval epoch-0, step-11 ---
jaccard_samples: 0.2250
accuracy: 0.0000
hamming_loss: 0.1435
precision_samples: 0.7698
recall_samples: 0.2546
pr_auc_samples: 0.5918
f1_samples: 0.3584
loss: 0.4793
New best accuracy score (0.0000) at epoch-0, step-11



Epoch 1 / 5:   0%|          | 0/11 [00:00<?, ?it/s]

--- Train epoch-1, step-22 ---
loss: 0.4123
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 43.98it/s]
--- Eval epoch-1, step-22 ---
jaccard_samples: 0.2137
accuracy: 0.0000
hamming_loss: 0.1451
precision_samples: 0.7670
recall_samples: 0.2419
pr_auc_samples: 0.6116
f1_samples: 0.3426
loss: 0.4697



Epoch 2 / 5:   0%|          | 0/11 [00:00<?, ?it/s]

--- Train epoch-2, step-33 ---
loss: 0.3305
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 40.88it/s]
--- Eval epoch-2, step-33 ---
jaccard_samples: 0.2232
accuracy: 0.0000
hamming_loss: 0.1438
precision_samples: 0.7727
recall_samples: 0.2489
pr_auc_samples: 0.6263
f1_samples: 0.3557
loss: 0.4350



Epoch 3 / 5:   0%|          | 0/11 [00:00<?, ?it/s]

--- Train epoch-3, step-44 ---
loss: 0.2935
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 50.40it/s]
--- Eval epoch-3, step-44 ---
jaccard_samples: 0.2659
accuracy: 0.0000
hamming_loss: 0.1381
precision_samples: 0.7798
recall_samples: 0.2991
pr_auc_samples: 0.6294
f1_samples: 0.4088
loss: 0.3922



Epoch 4 / 5:   0%|          | 0/11 [00:00<?, ?it/s]

--- Train epoch-4, step-55 ---
loss: 0.2774
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 44.33it/s]
--- Eval epoch-4, step-55 ---
jaccard_samples: 0.2126
accuracy: 0.0000
hamming_loss: 0.1440
precision_samples: 0.7930
recall_samples: 0.2331
pr_auc_samples: 0.6254
f1_samples: 0.3426
loss: 0.3997
Loaded best model


## Evaluate the Models

Next, I evaluate every model on the test split of each task.
Besides the standard `pyhealth` metrics, this step also computes the DDI rate and the average number of drugs recommended per visit.
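The two extra metrics come from `ModelWrapper.calc_avg_drugs_per_visit` and `ModelWrapper.calc_ddi_rate`, whose code is not shown in this section. As a rough illustration only, the sketch below shows one common way to compute them from predicted probabilities: threshold each visit's probabilities to get its recommended drug set, average the set sizes, and report the DDI rate as the fraction of recommended drug pairs that appear in the DDI adjacency matrix (the pair-based definition used in the GAMENet paper). The function names, the 0.5 threshold, and the list-of-lists inputs are assumptions for illustration, not the wrapper's actual implementation.

```python
from itertools import combinations
from typing import List, Sequence


def recommended_sets(y_prob: Sequence[Sequence[float]], threshold: float = 0.5) -> List[List[int]]:
    """Hypothetical helper: turn per-visit drug probabilities into recommended drug indices."""
    return [[j for j, p in enumerate(row) if p >= threshold] for row in y_prob]


def avg_drugs_per_visit(drug_sets: List[List[int]]) -> float:
    """Mean size of the recommended drug set across visits (empty sets count as 0)."""
    if not drug_sets:
        return 0.0
    return sum(len(s) for s in drug_sets) / len(drug_sets)


def ddi_rate(drug_sets: List[List[int]], ddi_adj: Sequence[Sequence[int]]) -> float:
    """Fraction of recommended drug pairs that are known interacting pairs."""
    total_pairs = 0
    ddi_pairs = 0
    for drugs in drug_sets:
        for a, b in combinations(drugs, 2):
            total_pairs += 1
            # treat the DDI graph as undirected: check both directions
            if ddi_adj[a][b] or ddi_adj[b][a]:
                ddi_pairs += 1
    return ddi_pairs / total_pairs if total_pairs else 0.0


# toy example: 3 visits, 4 drugs; only drugs 0 and 2 are known to interact
probs = [
    [0.9, 0.1, 0.8, 0.2],  # -> [0, 2]: 1 pair, interacting
    [0.7, 0.6, 0.3, 0.9],  # -> [0, 1, 3]: 3 pairs, none interacting
    [0.2, 0.1, 0.4, 0.3],  # -> []: no pairs
]
ddi = [
    [0, 0, 1, 0],
    [0, 0, 0, 0],
    [1, 0, 0, 0],
    [0, 0, 0, 0],
]
sets = recommended_sets(probs)
print(avg_drugs_per_visit(sets))  # 5 drugs over 3 visits, about 1.67
print(ddi_rate(sets, ddi))        # 1 interacting pair out of 4 pairs: 0.25
```

This sketch counts a visit with an empty recommendation as zero drugs in the per-visit average; whether the notebook's wrapper does the same is not shown here.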

In [14]:
baseline_result = {}
gamenet_result = {}

In [15]:
# baseline: RETAIN
print("---RETAIN EVALUATION---")
for taskname in mimic.get_task_names():
    print("--eval retain on {} data--".format(taskname))
    test_loader = dataloaders[taskname]["test"]
    # standard pyhealth metrics, then the two extra metrics
    baseline_result[taskname] = retain[taskname].evaluate_model(test_loader)
    baseline_result[taskname][DPV_KEY] = retain[taskname].calc_avg_drugs_per_visit(test_loader)
    baseline_result[taskname][DDI_RATE_KEY] = retain[taskname].calc_ddi_rate(test_loader, ddi_mats[taskname])

# gamenet
print("---GAMENET EVALUATION---")
for taskname in mimic.get_task_names():
    print("--eval gamenet on {} data--".format(taskname))
    test_loader = dataloaders[taskname]["test"]
    # standard pyhealth metrics, then the two extra metrics
    gamenet_result[taskname] = gamenet[taskname].evaluate_model(test_loader)
    gamenet_result[taskname][DPV_KEY] = gamenet[taskname].calc_avg_drugs_per_visit(test_loader)
    gamenet_result[taskname][DDI_RATE_KEY] = gamenet[taskname].calc_ddi_rate(test_loader, ddi_mats[taskname])

---RETAIN EVALUATION---
--eval retain on drug_recommendation data--


Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 23.91it/s]


{'jaccard_samples': 0.14146940962999824, 'accuracy': 0.0, 'hamming_loss': 0.42834302325581397, 'precision_samples': 0.17616063019906833, 'recall_samples': 0.4827049565056121, 'pr_auc_samples': 0.20238118914231482, 'f1_samples': 0.241459547618023, 'loss': 0.6866651773452759}


Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 75.87it/s]
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 68.72it/s]


--eval retain on no_hist data--


Evaluation: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 104.18it/s]


{'jaccard_samples': 0.15062685408395096, 'accuracy': 0.0, 'hamming_loss': 0.43569711538461536, 'precision_samples': 0.18602211478706557, 'recall_samples': 0.5375429973384637, 'pr_auc_samples': 0.21858402598081533, 'f1_samples': 0.25050916474858786, 'loss': 0.6869425773620605}


Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 91.53it/s]
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 92.49it/s]


--eval retain on no_proc data--


Evaluation: 100%|████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 102.50it/s]


{'jaccard_samples': 0.12198150103349795, 'accuracy': 0.0, 'hamming_loss': 0.4001393242772553, 'precision_samples': 0.15764750921602252, 'recall_samples': 0.4667931620825103, 'pr_auc_samples': 0.1969090999965072, 'f1_samples': 0.2122508281576915, 'loss': 0.6740173697471619}


Evaluation: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 94.05it/s]
Evaluation: 100%|████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 105.45it/s]


---GAMENET EVALUATION---
--eval gamenet on drug_recommendation data--


Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 83.54it/s]


{'jaccard_samples': 0.2782539747315031, 'accuracy': 0.0, 'hamming_loss': 0.1311046511627907, 'precision_samples': 0.7051679586563306, 'recall_samples': 0.3310068391585073, 'pr_auc_samples': 0.5450597929886727, 'f1_samples': 0.4244387699606564, 'loss': 0.5369114279747009}


Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 69.90it/s]
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 75.61it/s]


--eval gamenet on no_hist data--


Evaluation: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 218.34it/s]


{'jaccard_samples': 0.2577125324103029, 'accuracy': 0.0, 'hamming_loss': 0.14927884615384615, 'precision_samples': 0.5576445189425958, 'recall_samples': 0.3712612252367469, 'pr_auc_samples': 0.5113222327062721, 'f1_samples': 0.3992516398878801, 'loss': 0.5893192887306213}


Evaluation: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 161.44it/s]
Evaluation: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 228.42it/s]


--eval gamenet on no_proc data--


Evaluation: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 84.71it/s]


{'jaccard_samples': 0.24298973179750794, 'accuracy': 0.0, 'hamming_loss': 0.12079414838035528, 'precision_samples': 0.6929392446633825, 'recall_samples': 0.29053999571629796, 'pr_auc_samples': 0.5123229705737585, 'f1_samples': 0.37743372050818086, 'loss': 0.4500429481267929}


Evaluation: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 82.76it/s]
Evaluation: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 87.58it/s]


## Results Summary

Finally, we generate some tables and visualizations that will be used in the final report and video.
We build these tables from the `taskname -> metrics` dictionaries (`baseline_result` and `gamenet_result`) filled in during the evaluation step.
Keep in mind that if `DEV=True`, these are only sample results and do not reflect the study's findings.
For results from a run on the full dataset, [please see this old version of the notebook](https://github.com/HeapsOfRam/GAMENet/blob/5fbb96e549e5aaf9ec2e4ffeef1bff06d0589e67/pyhealth/EDA.ipynb).
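`pandas.DataFrame.from_dict` with its default `orient="columns"` turns the outer keys of a nested dictionary into columns, which is why the cells below transpose the result with `.T` so that each task becomes a row. The toy sketch below uses made-up placeholder values (not results from this study) to show that shape, and how a subset of metrics, such as the `metrics_columns` list defined in the next cell, could be selected.

```python
import pandas as pd

# hypothetical results in the same nested shape as baseline_result / gamenet_result:
# {taskname: {metric_name: value}}; the numbers are placeholders, not study results
results = {
    "drug_recommendation": {"jaccard_samples": 0.5, "ddi_rate": 0.1, "loss": 0.3},
    "no_proc": {"jaccard_samples": 0.4, "ddi_rate": 0.2, "loss": 0.4},
}

df = pd.DataFrame.from_dict(results)  # default orient="columns": tasks become columns
print(df.shape)                       # (3, 2): 3 metrics x 2 tasks

by_task = df.T                        # transpose so each task is a row
subset = by_task[["jaccard_samples", "ddi_rate"]]  # keep only the reported metrics
print(subset)
```

Passing `orient="index"` to `from_dict` would produce the task-per-row layout directly, without the transpose.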

In [16]:
metrics_columns = [
    "jaccard_samples", "precision_samples", "recall_samples",
    "pr_auc_samples", "f1_samples",
    "avg_dpv", "ddi_rate"
]

In [17]:
retain_metrics = pandas.DataFrame.from_dict(baseline_result)

In [18]:
gamenet_metrics = pandas.DataFrame.from_dict(gamenet_result)

In [19]:
display(retain_metrics.T)

| task | jaccard_samples | accuracy | hamming_loss | precision_samples | recall_samples | pr_auc_samples | f1_samples | loss | avg_dpv | ddi_rate |
|---|---|---|---|---|---|---|---|---|---|---|
| drug_recommendation | 0.141469 | 0.0 | 0.428343 | 0.176161 | 0.482705 | 0.202381 | 0.24146 | 0.686665 | 67.372093 | 0.040289 |
| no_hist | 0.150627 | 0.0 | 0.435697 | 0.186022 | 0.537543 | 0.218584 | 0.250509 | 0.686943 | 71.307692 | 0.045094 |
| no_proc | 0.121982 | 0.0 | 0.400139 | 0.157648 | 0.466793 | 0.196909 | 0.212251 | 0.674017 | 62.816092 | 0.042028 |


In [20]:
display(gamenet_metrics.T)

| task | jaccard_samples | accuracy | hamming_loss | precision_samples | recall_samples | pr_auc_samples | f1_samples | loss | avg_dpv | ddi_rate |
|---|---|---|---|---|---|---|---|---|---|---|
| drug_recommendation | 0.278254 | 0.0 | 0.131105 | 0.705168 | 0.331007 | 0.54506 | 0.424439 | 0.536911 | 9.767442 | 0.116531 |
| no_hist | 0.257713 | 0.0 | 0.149279 | 0.557645 | 0.371261 | 0.511322 | 0.399252 | 0.589319 | 12.942308 | 0.070128 |
| no_proc | 0.24299 | 0.0 | 0.120794 | 0.692939 | 0.29054 | 0.512323 | 0.377434 | 0.450043 | 7.0 | 0.238095 |


In [21]:
print(retain_metrics.T.to_latex())

\begin{tabular}{lrrrrrrrrrr}
\toprule
 & jaccard_samples & accuracy & hamming_loss & precision_samples & recall_samples & pr_auc_samples & f1_samples & loss & avg_dpv & ddi_rate \\
\midrule
drug_recommendation & 0.141469 & 0.000000 & 0.428343 & 0.176161 & 0.482705 & 0.202381 & 0.241460 & 0.686665 & 67.372093 & 0.040289 \\
no_hist & 0.150627 & 0.000000 & 0.435697 & 0.186022 & 0.537543 & 0.218584 & 0.250509 & 0.686943 & 71.307692 & 0.045094 \\
no_proc & 0.121982 & 0.000000 & 0.400139 & 0.157648 & 0.466793 & 0.196909 & 0.212251 & 0.674017 & 62.816092 & 0.042028 \\
\bottomrule
\end{tabular}



In [22]:
print(gamenet_metrics.T.to_latex())

\begin{tabular}{lrrrrrrrrrr}
\toprule
 & jaccard_samples & accuracy & hamming_loss & precision_samples & recall_samples & pr_auc_samples & f1_samples & loss & avg_dpv & ddi_rate \\
\midrule
drug_recommendation & 0.278254 & 0.000000 & 0.131105 & 0.705168 & 0.331007 & 0.545060 & 0.424439 & 0.536911 & 9.767442 & 0.116531 \\
no_hist & 0.257713 & 0.000000 & 0.149279 & 0.557645 & 0.371261 & 0.511322 & 0.399252 & 0.589319 & 12.942308 & 0.070128 \\
no_proc & 0.242990 & 0.000000 & 0.120794 & 0.692939 & 0.290540 & 0.512323 & 0.377434 & 0.450043 & 7.000000 & 0.238095 \\
\bottomrule
\end{tabular}



# Reproducibility Summary

In the original study, GAMENet proves to be a powerful tool for recommending drug combinations while avoiding adverse drug-drug interactions (DDIs).
It also consistently recommends fewer drugs than the other methods, especially the baselines.
In general, my replication study seems to corroborate these findings.
Testing the GAMENet model against the RETAIN baseline and running a couple of ablations on the original model produced the results discussed below.

As in the original paper, GAMENet in my reproduction had a lower DDI rate than the other models.
Unlike the original paper, in my reproduction the RETAIN baseline actually recommends fewer drugs per visit on average.
However, the difference is only about 2 drugs per visit.
Additionally, GAMENet's evaluation metrics remain competitive with RETAIN's in many cases.
My conclusion is that the additional 2 drugs per visit recommended by GAMENet are worth the tradeoff of avoiding potentially harmful DDIs.

The ablations showed that procedure information is critical to making good recommendations.
Procedures help the model avoid DDIs and generally improve the performance metrics.
Patient history is also important; however, omitting it does not hurt model performance as much as omitting procedure information.
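The printed model summaries hint at what the history ablation removes. The full `GAMENetLayer` ends in `Linear(in_features=384, ...)`, while `GAMENetLayerNoDM` ends in `Linear(in_features=256, ...)`. With 128-dimensional embeddings, this is consistent with the GAMENet paper's output layer, which concatenates the patient query with two memory read-outs: one from the graph-augmented drug memory and one from the dynamic memory of past visits. The no-history variant presumably drops the dynamic-memory read-out. The numpy sketch below illustrates only that shape arithmetic. The random vectors, the names `query`, `fact_graph`, and `fact_dynamic`, and the hypothetical `output_logits` helper are stand-ins, not the `pyhealth` implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, dim, num_drugs = 4, 128, 160  # 160 drug labels, as in the drug_recommendation summaries

# hypothetical stand-ins for the vectors the final output layer sees
query = rng.normal(size=(batch, dim))         # patient representation for the current visit
fact_graph = rng.normal(size=(batch, dim))    # read-out from the graph-augmented drug memory
fact_dynamic = rng.normal(size=(batch, dim))  # read-out from the memory of past visits (history)


def output_logits(features: np.ndarray, num_labels: int) -> np.ndarray:
    """A randomly initialised linear output layer, mirroring the final `fc`."""
    weight = rng.normal(scale=0.01, size=(features.shape[1], num_labels))
    bias = np.zeros(num_labels)
    return features @ weight + bias


full = np.concatenate([query, fact_graph, fact_dynamic], axis=-1)  # full model input to fc
no_hist = np.concatenate([query, fact_graph], axis=-1)             # no dynamic memory

print(full.shape, output_logits(full, num_drugs).shape)        # (4, 384) (4, 160)
print(no_hist.shape, output_logits(no_hist, num_drugs).shape)  # (4, 256) (4, 160)
```

Both variants still emit one logit per drug label; only the width of the fused representation changes.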

## References

```
@inproceedings{GAMENet:2019,
    title="{GAMENet: Graph Augmented MEmory Networks for Recommending Medication Combination}",
    author={Junyuan Shang and Cao Xiao and Tengfei Ma and Hongyan Li and Jimeng Sun},
    journal={arXiv preprint arXiv:1809.01852},
    year={2019},
    eprint={1809.01852},
    archivePrefix={arXiv},
    primaryClass={cs.AI}
}

@inproceedings{Doctor2Vec:2020,
    title="{Doctor2Vec: Dynamic Doctor Representation Learning for Clinical Trial Recruitment}",
    author={Siddharth Biswal and Cao Xiao and Lucas M. Glass and Elizabeth Milkovits and Jimeng Sun},
    year={2019},
    eprint={1911.10395},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

@inproceedings{Leap:2017,
    author = {Zhang, Yutao and Chen, Robert and Tang, Jie and Stewart, Walter F. and Sun, Jimeng},
    title = {LEAP: Learning to Prescribe Effective and Safe Treatment Combinations for Multimorbidity},
    year = {2017},
    isbn = {9781450348874},
    publisher = {Association for Computing Machinery},
    address = {New York, NY, USA},
    url = {https://doi.org/10.1145/3097983.3098109},
    doi = {10.1145/3097983.3098109},
    abstract = {Managing patients with complex multimorbidity has long been recognized as a difficult problem due to complex disease and medication dependencies and the potential risk of adverse drug interactions. Existing work either uses complicated rule-based protocols which are hard to implement and maintain, or simple statistical models that treat each disease independently, which may lead to sub-optimal or even harmful drug combinations. In this work, we propose the LEAP (LEArn to Prescribe) algorithm to decompose the treatment recommendation into a sequential decision-making process while automatically determining the appropriate number of medications. A recurrent decoder is used to model label dependencies and content-based attention is used to capture label instance mapping. We further leverage reinforcement learning to fine tune the model parameters to ensure accuracy and completeness. We incorporate external clinical knowledge into the design of the reinforcement reward to effectively prevent generating unfavorable drug combinations. Both quantitative experiments and qualitative case studies are conducted on two real world electronic health record datasets to verify the effectiveness of our solution. On both datasets, LEAP significantly outperforms baselines by up to 10-30% in terms of mean Jaccard coefficient and removes 99.8% adverse drug interactions in the recommended treatment sets.},
    booktitle = {Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining},
    pages = {1315–1324},
    numpages = {10},
    keywords = {multi-instance multilabel learning, multimorbidity, treatment recommendation},
    location = {Halifax, NS, Canada},
    series = {KDD '17}
}

@inproceedings{DMNC:2018,
    title="{Dual Memory Neural Computer for Asynchronous Two-view Sequential Learning}",
    author={Hung Le and Truyen Tran and Svetha Venkatesh},
    year={2018},
    eprint={1802.00662},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

@inproceedings{RETAIN:2017,
    title="{RETAIN: An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism}",
    author={Edward Choi and Mohammad Taha Bahadori and Joshua A. Kulas and Andy Schuetz and Walter F. Stewart and Jimeng Sun},
    year={2017},
    eprint={1608.05745},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
```

## Cite 

The authors ask that anyone using their work cite their paper:

```
@article{shang2018gamenet,
  title="{GAMENet: Graph Augmented MEmory Networks for Recommending Medication Combination}",
  author={Shang, Junyuan and Xiao, Cao and Ma, Tengfei and Li, Hongyan and Sun, Jimeng},
  journal={arXiv preprint arXiv:1809.01852},
  year={2018}
}
```