In [None]:
import os

import rootutils

# Locate the project root by searching upward from the current working directory
# (os.path.abspath("") is the cwd). Per rootutils' documented flags this also loads a
# .env file (dotenv=True), puts the root on the Python path (pythonpath=True) and
# changes the working directory to the root (cwd=True). The returned root path is
# reused later to build a Hydra search path.
root = rootutils.setup_root(os.path.abspath(""), dotenv=True, pythonpath=True, cwd=True)

from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path

import numpy as np
import polars as pl
from dateutil.relativedelta import relativedelta
from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict
from omegaconf import DictConfig


@dataclass
class DummyConfig:
    """Minimal stand-in for the dataset config consumed by ``PytorchDataset`` in this notebook.

    The first three fields are required paths; everything else has a default.
    ``code_metadata_fp`` is not an ``__init__`` argument: it is derived from
    ``data_dir`` after construction.
    """

    # Required locations of the generated files.
    schema_files_root: str
    task_label_path: str
    data_dir: str
    # Task / sequence options.
    task_name: str = "dummy_task"
    max_seq_len: int = 64
    do_prepend_static_data: bool = False
    postpend_eos_token: bool = False
    do_flatten_tensors: bool = True
    EOS_TOKEN_ID: int = 5
    # Which extra fields each dataset item carries.
    do_include_subject_id: bool = True
    do_include_subsequence_indices: bool = True
    do_include_start_time_min: bool = True
    do_include_end_time: bool = True
    do_include_prediction_time: bool = True
    subsequence_sampling_strategy: str = "from_start"
    # Derived in __post_init__; excluded from the generated __init__.
    code_metadata_fp: str = field(init=False)

    def __post_init__(self):
        # The code-metadata table always lives directly inside data_dir.
        self.code_metadata_fp = "".join([self.data_dir, "/metadata.parquet"])


def _write_schema_shard(schema_dir: Path, n_subjects: int, n_events: int, base_datetime: datetime) -> None:
    """Write ``shard_0.parquet`` into ``schema_dir`` with one schema row per subject."""
    static_df = pl.DataFrame(
        [
            {
                "subject_id": subject_id,
                "start_time": base_datetime,
                # One timestamp per day covering the full (untruncated) repeated pattern.
                "time": [base_datetime + relativedelta(days=i) for i in range(n_events)],
                "code": [1, 2, 3],
                "numeric_value": [0.1, 0.2, 0.3],
            }
            for subject_id in range(n_subjects)
        ]
    )
    static_df.write_parquet(schema_dir / "shard_0.parquet", use_pyarrow=True)


def _write_dynamic_shard(nrt_output_dir: Path, n_subjects: int, n_repeats: int, pattern: list[int]) -> None:
    """Write ``shard_0.nrt``: each subject's events repeat ``pattern``, truncated to a random length."""
    pattern_len = len(pattern)
    repeated_codes = pattern * n_repeats
    subject_dynamic_data = []
    for _ in range(n_subjects):
        # Random length in [pattern_len, pattern_len * n_repeats), drawn from NumPy's global RNG.
        # np.random.randint requires low < high, so with a single repeat the full pattern is kept.
        if n_repeats > 1:
            length = int(np.random.randint(pattern_len, pattern_len * n_repeats))
        else:
            length = pattern_len
        subject_dynamic_data.append(
            JointNestedRaggedTensorDict(
                raw_tensors={
                    "code": [[code] for code in repeated_codes[:length]],
                    "numeric_value": [[np.nan] for _ in range(length)],
                    "time_delta_days": [1] * length,
                }
            )
        )
    dynamic_data = JointNestedRaggedTensorDict.vstack(subject_dynamic_data)
    nrt_output_dir.mkdir(parents=True, exist_ok=True)
    dynamic_data.save(nrt_output_dir / "shard_0.nrt")


def _write_task_labels(task_fp: Path, n_subjects: int, prediction_time: datetime) -> None:
    """Write the task-label parquet: one shared prediction time per subject, alternating 0/1 labels."""
    task_df = pl.DataFrame(
        {
            "subject_id": list(range(n_subjects)),
            "prediction_time": [prediction_time] * n_subjects,
            # NOTE(review): labels are stored as 0/1 integers, not booleans — confirm the reader accepts this.
            "boolean_value": [i % 2 for i in range(n_subjects)],
        }
    )
    task_df.write_parquet(task_fp, use_pyarrow=True)


def create_dummy_dataset(
    base_dir: str | Path, n_subjects: int = 3, split: str = "train", seed: int | None = 42, n_repeats: int = 3
) -> DummyConfig:
    """Write a tiny synthetic dataset under ``base_dir`` and return a ``DummyConfig`` pointing at it.

    Files written:
        schema/<split>/shard_0.parquet  one row per subject (start time, daily event times, static codes)
        data/<split>/shard_0.nrt         nested ragged tensors of codes / numeric values / time deltas
        task_labels.parquet              one prediction time (start + 3 years) and label per subject
        metadata.parquet                 code -> vocab index table (path taken from the config)

    Args:
        base_dir: Root directory for all generated files (created if missing).
        n_subjects: Number of synthetic subjects (ids ``0 .. n_subjects - 1``).
        split: Split name used for the schema and data sub-directories.
        seed: If not None, seeds NumPy's *global* RNG before sequence lengths are drawn.
        n_repeats: How many times the 8-event code pattern is repeated before random truncation.

    Returns:
        A ``DummyConfig`` whose paths point at the generated files.

    Raises:
        ValueError: If ``n_repeats`` is less than 1.
    """
    if n_repeats < 1:
        raise ValueError(f"n_repeats must be >= 1, got {n_repeats}")
    if seed is not None:
        np.random.seed(seed)

    base_dir = Path(base_dir)
    base_datetime = datetime(1995, 1, 1)
    # Every subject's timeline repeats this code pattern, one event per day.
    code_pattern = [1, 2, 1, 2, 1, 2, 1, 3]

    schema_dir = base_dir / "schema" / split
    schema_dir.mkdir(parents=True, exist_ok=True)
    _write_schema_shard(schema_dir, n_subjects, len(code_pattern) * n_repeats, base_datetime)

    _write_dynamic_shard(base_dir / "data" / split, n_subjects, n_repeats, code_pattern)

    task_fp = base_dir / "task_labels.parquet"
    _write_task_labels(task_fp, n_subjects, base_datetime + relativedelta(years=3))

    config = DummyConfig(
        schema_files_root=str(base_dir / "schema"),
        task_label_path=str(task_fp),
        data_dir=str(base_dir),
    )

    metadata_df = pl.DataFrame(
        {
            "code": ["a", "b", "c"],
            "code/vocab_index": [1, 2, 3],
        }
    )
    metadata_df.write_parquet(config.code_metadata_fp)

    return config

In [None]:
import os
import tempfile

from meds_torch.data.components.pytorch_dataset import PytorchDataset

# Pin the run to GPU 3 by default, but respect a CUDA_VISIBLE_DEVICES already set in the
# environment (index "3" only exists on multi-GPU machines). Must be set before CUDA initialises.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

# Build a throwaway synthetic dataset with 64 subjects and wrap its train split.
tmp_dir = tempfile.TemporaryDirectory()
data_config = create_dummy_dataset(tmp_dir.name, n_subjects=64)
dataset = PytorchDataset(data_config, split="train")

# Inspect the first subject. Subjects share the same repeating code pattern, but
# create_dummy_dataset truncates each one to a random length, so they are not identical.
dynamic_data, subject_id, st, end = dataset.load_subject_dynamic_data(0)
first_subject_codes = dynamic_data.flatten().to_dense()["code"]
print("every patient follows the same repeating code pattern:")
print(first_subject_codes)
print("the length of the data is:", str(len(first_subject_codes)))

In [None]:
"""This file prepares config fixtures for other tests."""

from pathlib import Path

import hydra
from hydra import compose, initialize

from meds_torch.utils.resolvers import setup_resolvers

setup_resolvers()


def create_cfg(overrides, config_name="train.yaml") -> DictConfig:
    """Compose a Hydra config from the project's config directory.

    Args:
        overrides: Hydra override strings (e.g. ``"trainer.max_epochs=30"``).
        config_name: Primary config file to compose; defaults to ``train.yaml``.

    Returns:
        The composed ``DictConfig``. It includes the ``hydra`` node because ``return_hydra_config=True``.
    """
    # NOTE(review): Hydra resolves a relative config_path against the caller's location, not the
    # cwd. Confirm this still points at src/meds_torch/configs from wherever the notebook runs.
    with initialize(config_path="../src/meds_torch/configs", version_base="1.3"):
        return compose(config_name=config_name, overrides=overrides, return_hydra_config=True)


output_dir = Path(tmp_dir.name).joinpath("output")
# Note: the dataloaders use `dataset`, which is built from `data_config`, so the `data.*`
# overrides below only affect `cfg`, not the data actually fed to the model.
overrides = [
    # model / backbone / trainer / experiment selection
    "model=eic_forecasting",
    "model/backbone=eic_transformer_decoder_alibi",
    "trainer=gpu",
    "experiment=eic_forecast_mtr",
    # data settings and the dummy dataset's code metadata
    "data.subsequence_sampling_strategy=random",
    f"data.code_metadata_fp={data_config.code_metadata_fp}",
    # optimisation schedule
    "model.optimizer.lr=0.001",
    "trainer.max_epochs=30",
    # outputs, metrics, and the extra config search path under the project root
    f"paths.output_dir={output_dir}",
    "model.top_k_acc=[1]",
    f"hydra.searchpath=[pkg://meds_torch.configs,{root}/MIMICIV_INDUCTIVE_EXPERIMENTS/configs/meds-torch-configs/]",
]
cfg = create_cfg(overrides=overrides)
model = hydra.utils.instantiate(cfg.model)

In [None]:
from torch.utils.data.dataloader import DataLoader

# Train and validation both read the same synthetic split. This is presumably intentional for
# this toy demo, which only checks that the model can learn the repeating pattern.
_loader_kwargs = dict(batch_size=8, collate_fn=dataset.collate)
train_dataloader = DataLoader(dataset, shuffle=True, **_loader_kwargs)
val_dataloader = DataLoader(dataset, shuffle=False, **_loader_kwargs)

In [None]:
# Build the trainer from the composed config, fit on the dummy data, then run a final
# validation pass. ckpt_path comes from the config; cfg.get returns None when the key is absent.
ckpt_path = cfg.get("ckpt_path")
trainer = hydra.utils.instantiate(cfg.trainer)
trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    ckpt_path=ckpt_path,
)
trainer.validate(model=model, dataloaders=val_dataloader, ckpt_path=ckpt_path)

In [None]:
from meds_torch.latest_dir import get_latest_directory

# Point the user at the lightning_logs folder of the most recent run under the output dir.
print("Enter the following in terminal to view the tensorboard logs:")
latest_run_dir = get_latest_directory(cfg.paths.output_dir)
print(f"tensorboard --logdir={latest_run_dir}/lightning_logs/")

In [None]:
from meds_torch.input_encoder import INPUT_ENCODER_MASK_KEY, INPUT_ENCODER_TOKENS_KEY

# Collate the first 8 dataset items into one batch and run only the model's input encoder on it.
batch = dataset.collate([dataset[i] for i in range(8)])
# Call the module rather than `.forward(...)` directly so any registered nn.Module hooks still run.
input_batch = model.input_encoder(batch)
prompts, mask = input_batch[INPUT_ENCODER_TOKENS_KEY], input_batch[INPUT_ENCODER_MASK_KEY]

In [None]:
import torch

# Count mask entries along the last axis (presumably the sequence axis). This assumes mask
# is 1/True for real, non-padding tokens; confirm against the input encoder.
prompt_lengths = mask.sum(-1)
prompt_lengths

In [None]:
from x_transformers.autoregressive_wrapper import align_right

# Presumably shifts each prompt so its real tokens end at the right edge, padding with id 0;
# the result is displayed for inspection only.
right_aligned_prompts = align_right(prompts, prompt_lengths, pad_id=0)
right_aligned_prompts

In [None]:
# NOTE(review): the "24" in this message is not derived from any value visible here
# (DummyConfig.max_seq_len defaults to 64). Confirm it against the experiment config.
print(
    "Let's generate the future conditioned on the past 24 tokens (using a sliding window with a max sequence length of 24):"
)
# Generate from the first sample only; the length-1 slices keep the batch dimension.
first_prompt = prompts[:1]
first_mask = mask[:1]
future = model.model.generate(
    prompts=first_prompt,
    mask=first_mask,
    get_next_token_time=None,
    time_offset_years=None,
    temperature=model.cfg.temperature,
    eos_tokens=model.cfg.eos_tokens,
)

In [None]:
# Display which keys the input encoder returned; the next cell picks some of them out.
input_batch.keys()

In [None]:
# Per-position predicted token ids from the model's output logits on the encoded batch.
test_batch = dict(
    time_delta_days=input_batch["time_delta_days"], code=input_batch["code"], mask=input_batch["mask"]
)
# Inference only: skip building an autograd graph.
with torch.no_grad():
    logits = model(test_batch)["MODEL//LOGITS_SEQUENCE"]
# softmax is monotonic, so argmax over the raw logits yields the same indices. This avoids
# relying on the private `torch.functional.F` alias and an unnecessary softmax pass.
logits.argmax(dim=-1)

In [None]:
# NOTE(review): the claim below is not checked programmatically; compare the printed tokens
# against the 1,2,1,2,1,2,1,3 pattern by eye.
print("We observe that the future follows the repeating pattern 1,2,1,2,1,2,1,3, great!")
print(future[0][0, :24])