# How to load tokens from patient data

Generally we start off with some hydra command that defines where data is stored and the general model training recipe. Here we directly convert this to a python configuration so we can tinker with everything locally.


If you want to prepend static data to the sequences, add `data.do_prepend_static_data=true` as an arg in the cmd variable.

By default, the dataset extracts subsequences of limited length from the full patient data. However, setting `data.max_seq_len=1000000` effectively disables this behavior by making the maximum sequence length larger than any actual patient sequence, ensuring you get the complete data for each patient.

In [None]:
import os
import shlex

import hydra

from meds_torch.latest_dir import get_latest_directory

# We need to import the resolvers so the cfg object has access to them
from meds_torch.utils.resolvers import setup_resolvers

# Register the meds_torch resolvers before composing the config so that the
# cfg object has access to them.
setup_resolvers()

# --- Filesystem layout -----------------------------------------------------
# Every path below is derived from ROOT_DIR. ROOT_DIR already ends with "/",
# so the derived strings contain a doubled separator ("//").
ROOT_DIR = "/storage/shared/mimic-iv/meds_v0.3.2/"  # Replace with your actual root directory
PRETRAIN_OUTPUT_DIR = f"{ROOT_DIR}/results/zero_shot/eic_hparam_sweep"
# NOTE(review): presumably picks the most recent run directory inside the
# sweep output folder -- confirm against meds_torch.latest_dir.
MODEL_SWEEP_DIR = get_latest_directory(PRETRAIN_OUTPUT_DIR)
BEST_CHECKPOINT = f"{MODEL_SWEEP_DIR}/checkpoints/best_model.ckpt"
BEST_CONFIG = f"{MODEL_SWEEP_DIR}/best_config.json"
MEDS_DIR = f"{ROOT_DIR}/meds/"
TENSOR_DIR = f"{ROOT_DIR}/eic_tensors/"
# N is not referenced anywhere in the visible cells; M is substituted into the
# command below as `model.max_tokens_budget`.
N = 10
M = 10
OUTPUT_DIR = f"{ROOT_DIR}/results/zero_shot/inference/eic/debug/"
# Directory the notebook runs from; its `configs/` subfolder is appended to the
# Hydra search path in the command below.
TUTORIAL_DIR = os.getcwd()

# The CLI invocation we want to mimic. Because this is a regular (non-raw)
# triple-quoted string, each trailing backslash joins the next line, so the
# whole command is a single logical line by the time shlex sees it.
cmd = f"""
meds-torch-generate model=eic_forecasting experiment=eic_forecast_mtr \
    model.max_tokens_budget={M} data.subsequence_sampling_strategy=from_start data.max_seq_len=1000000 \
    data.dataloader.batch_size=512 model.generate_id=0 trainer.devices=[0] data.predict_dataset=test \
    data.do_include_subject_id=true data.do_include_prediction_time=true data.do_include_end_time=true \
    paths.meds_cohort_dir={MEDS_DIR} ckpt_path={BEST_CHECKPOINT} \
    paths.data_dir={TENSOR_DIR} paths.output_dir={OUTPUT_DIR} \
    "hydra.searchpath=[pkg://meds_torch.configs,{TUTORIAL_DIR}/configs/]"
"""
# Compose the config in-process instead of launching the CLI.
# config_path is given as a relative path ("../src/meds_torch/configs").
with hydra.initialize(version_base="1.3", config_path="../src/meds_torch/configs"):
    # Shell-style tokenisation; [1:] drops the first token (the program name
    # `meds-torch-generate`), leaving only the override arguments.
    args = shlex.split(cmd)[1:]
    # NOTE(review): `hydra._internal` is a private Hydra module and may change
    # between releases; only the parsed `.overrides` list is used here.
    overrides = hydra._internal.utils.get_args(args).overrides
    # `cfg` is the composed config that the following cells read from
    # (cfg.data, cfg.model, cfg.ckpt_path).
    cfg = hydra.compose(config_name="generate_trajectories", return_hydra_config=True, overrides=overrides)

Boom! We have a config! Let's use this to load the dataset and pull some data

In [None]:
from meds_torch.data.components.pytorch_dataset import PytorchDataset

# Build the datamodule object described by the `data` section of the config.
datamodule = hydra.utils.instantiate(cfg.data)
# NOTE(review): presumably prepares the per-split datasets; confirm against
# the datamodule implementation.
datamodule.setup()
# Keep a handle on the datamodule's `data_val` attribute; it is indexed and
# collated in the next cell.
val_pytorch_dataset: PytorchDataset = datamodule.data_val

In [None]:
# Collate the first dataset item into a batch of size one.
first_sample = val_pytorch_dataset[0]
data = val_pytorch_dataset.collate([first_sample])
print(data.keys())

# Fields of interest, printed in this order:
#   code       - sequence of observations for the patient
#   subject_id - patient id
#   end_time   - datetime time of the last observation
for field in ("code", "subject_id", "end_time"):
    print(data[field])

To interpret the codes printed above, we can load the metadata that maps each vocabulary index to three pieces of information:

- The `code` - the human-readable string representation
- The `description` - a detailed explanation of what the code means
- The `values/mean` - for codes that represent binned numeric values (like vital signs or lab results), this shows the mean value within that bin

In [None]:
import polars as pl

# Load the code metadata table whose path is given by the config.
metadata_df = pl.read_parquet(cfg.data.code_metadata_fp)
# polars DataFrames are immutable: `with_columns` returns a NEW frame, so the
# result must be rebound or the derived "values/mean" column is silently lost.
metadata_df = metadata_df.with_columns(
    (pl.col("values/sum") / pl.col("values/n_occurrences")).alias("values/mean")
)

# Show the vocabulary together with the per-code mean computed above.
print(metadata_df.select(["code", "description", "code/vocab_index", "values/mean"]))

# Map each vocabulary index to its code string, then translate the sample
# loaded earlier (`data["code"]`) into readable code names.
index_to_code = dict(zip(metadata_df["code/vocab_index"], metadata_df["code"]))
interpetable_data = [index_to_code[index] for index in data["code"].flatten().tolist()]
interpetable_data

Let's now load a model, run a custom sequence of codes through it, and see what it generates!

In [None]:
import torch

# Build the model described by the `model` section of the config.
model = hydra.utils.instantiate(cfg.model)
# Switch to evaluation mode: we only run inference in this notebook, so
# train-time behaviour such as dropout must be disabled.
model.eval()

# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints that come from a source you trust.
checkpoint = torch.load(cfg.ckpt_path, map_location="cpu")
# Left as the last expression so the notebook displays the key-matching result.
model.load_state_dict(checkpoint["state_dict"])

Now let's get some generated data. Each row in the generated table represents one of the next 10 predicted medical observations per patient, with these columns:

- `subject_id`: Unique identifier for each patient
- `prediction_time`: Timestamp of the last real observation in the patient's history used to make predictions - remains constant across all predictions for the same generation sequence.
- `time`: Predicted timestamp when this observation would occur (calculated by adding the model's predicted time delta to `prediction_time`)
- `code`: Human-readable name of the medical code (e.g., "Heart Rate", "Blood Pressure") 
- `code/vocab_index`: Internal integer id used by the model to represent each medical code
- `numeric_value`: For quantitative measurements (like lab values or vital signs), represents the average value within the predicted categorical bin (e.g., if heart rates are binned into ranges of 60-70, 70-80, etc., shows the average within the predicted bin)

In [None]:
import datetime

# Two hand-written patients with different history lengths.
subject_id = torch.tensor([1, 2])
custom_codes_patient_1 = torch.tensor([4, 5, 8, 182])
custom_codes_patient_2 = torch.tensor([188, 48, 78, 178, 1827, 2973])

print("We pad the codes with padding tokens (which are 0's)")
custom_codes = torch.nn.utils.rnn.pad_sequence(
    [custom_codes_patient_1, custom_codes_patient_2],
    batch_first=True,
)
print(custom_codes)
# Index 0 is the padding token, so every non-zero position is a real observation.
mask = custom_codes != 0
print("The mask is True for non-padding tokens")
print(mask)

print("Add some end_time as the generated data will use this as the starting time.")
patient_1_end_time = datetime.datetime(year=2000, month=1, day=1)
patient_2_end_time = datetime.datetime(year=2000, month=1, day=1)
end_time = [patient_1_end_time, patient_2_end_time]

print("Pass this to the model and we get the following keys")
# The same timestamps are supplied as both `prediction_time` and `end_time`.
batch = dict(code=custom_codes, mask=mask, prediction_time=end_time, end_time=end_time, subject_id=subject_id)
# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    output = model(batch)
print(output.keys())

print("And we can interpret the generated trajectory:")
# Note that the time is the time of the observation, code is the interpretable
# code name, the code/vocab_index is the vocab index, and the numeric_value is
# the average numeric value for the code numeric value bin.
output["GENERATE//0"]