In [1]:
import argparse
import collections
import datetime
import functools
import json
import logging
import os
import pickle
import random
from typing import Any, List, Tuple, TypeVar

import haiku as hk
import jax
import jax.numpy as jnp
import jmp
import msgpack
import numpy as np
import pandas as pd
import optax
import sklearn.metrics

import femr.datasets
import femr.extension.dataloader
import femr.models.dataloader
import femr.models.transformer
from femr.models.transformer import TransformerFeaturizer
# Generic type variable; not referenced anywhere in the visible part of this notebook.
T = TypeVar("T")

In [2]:
# Directory holding the pre-built CLMBR batches. The two msgpack files below live
# inside it, so they are derived from it instead of repeating the prefix by hand
# (the setup cell already builds batch_info_path the same way).
batches_path = 'EHRSHOT_ASSETS/features/clmbr_batches/'
batch_info_path = os.path.join(batches_path, 'batch_info.msgpack')
loader_config_path = os.path.join(batches_path, 'loader_config.msgpack')
# FEMR extract directory, passed to both the batch loader and the patient database.
data_path = 'EHRSHOT_ASSETS/femr/extract/'
# Directory containing the model's "config.msgpack" and its "best" checkpoint.
model_dir = 'EHRSHOT_ASSETS/models/clmbr/clmbr_model/'

In [3]:
# --- Model configuration -------------------------------------------------
# Read the msgpack config stored next to the checkpoint and freeze it with
# Haiku's immutable-dict helper.
config_file_path = os.path.join(model_dir, "config.msgpack")
with open(config_file_path, "rb") as config_file:
    raw_config = msgpack.load(config_file, use_list=False)
config = hk.data_structures.to_immutable_dict(raw_config)

# Python's RNG is seeded with a fixed value; the config-driven seed is left
# disabled, exactly as it was.
# random.seed(config["seed"])
random.seed(9721)

# --- Batch metadata ------------------------------------------------------
batch_info_path = os.path.join(batches_path, "batch_info.msgpack")
with open(batch_info_path, "rb") as batch_info_file:
    batch_info = msgpack.load(batch_info_file, use_list=False)

# Group the task's (pid, age, label) triples by patient id -> [(age, label), ...].
task_labels = batch_info["config"]["task"]["labels"]
patient_labels = collections.defaultdict(list)
for pid, age, label in task_labels:
    patient_labels[pid].append((age, label))

# --- Batch loader --------------------------------------------------------
loader = femr.extension.dataloader.BatchLoader(data_path, batch_info_path)

def model_fn(config, batch):
    """Forward function that gets wrapped by ``hk.transform``.

    Builds ``femr.models.transformer.EHRTransformer`` from ``config`` and
    immediately applies it to ``batch`` with ``no_task=True``. Whatever that
    call returns is handed straight back to the caller.
    """
    transformer = femr.models.transformer.EHRTransformer(config)
    return transformer(batch, no_task=True)

When mapping codes, dropped 4736 out of 39811


In [4]:
# Haiku-transformed version of model_fn; exposes .apply(params, rng, ...).
model = hk.transform(model_fn)
# Fixed PRNG key that is passed to model.apply.
rng = jax.random.PRNGKey(42)

# NOTE(security): pickle.load can execute arbitrary code while deserialising.
# Only point model_dir at checkpoints that come from a trusted source.
checkpoint_path = os.path.join(model_dir, "best")
with open(checkpoint_path, "rb") as checkpoint_file:
    params = pickle.load(checkpoint_file)

@functools.partial(jax.jit, static_argnames="config")
def compute_repr(params, rng, config, batch):
    """JIT-compiled wrapper around ``model.apply``.

    ``config`` is declared static for ``jax.jit``: it is treated as a
    compile-time constant instead of being traced like ``params``, ``rng``
    and ``batch``.

    Returns whatever ``model.apply(params, rng, config, batch)`` returns,
    unchanged.
    """
    outputs = model.apply(params, rng, config, batch)
    return outputs

# Patient database opened from the same extract directory the batch loader reads.
database = femr.datasets.PatientDatabase(data_path)
# Per-patient accumulator: the extraction cell appends
# (label_age, offset, representation) tuples to it. Because it is created here,
# re-running the extraction cell without re-running this cell appends duplicates.
results = collections.defaultdict(list)

In [5]:
split = 'train'
dev_index = 0
# for dev_index in range(loader.get_number_of_batches(split)):
# Only batch `dev_index` of `split` is processed; the commented-out loop header
# above is the hook for iterating over every batch of the split.
raw_batch = loader.get_batch(split, dev_index)
# jax.tree_util.tree_map is the stable spelling; the top-level jax.tree_map alias
# is deprecated and removed in newer JAX releases.
batch = jax.tree_util.tree_map(jnp.array, raw_batch)

# NOTE(review): `repr` shadows the builtin of the same name for the rest of the
# session; the name is kept so that cells which may rely on it keep working.
# `mask` is not used in the visible part of the notebook.
repr, mask = compute_repr(
        femr.models.transformer.convert_params(params, dtype=jnp.float16),
        rng,
        config,
        batch,
    )

# Bring the representations to host memory as a NumPy array.
repr = np.array(repr)

# Integer-divide each label index by the per-row length to get the row it
# belongs to (presumably label_indices are flattened positions - TODO confirm).
p_index = batch["transformer"]["label_indices"] // batch["transformer"]["length"]

# Copy the row indices to host once and turn the loop bound into a plain int,
# so the loop below does not index a device array on every iteration.
p_index_host = np.asarray(p_index)
num_labels = int(batch["num_indices"])

for i in range(num_labels):
    r = repr[i, :]
    patient_row = p_index_host[i]

    label_pid = raw_batch["patient_ids"][patient_row]
    label_age = raw_batch["task"]["label_ages"][i]

    offset = raw_batch["offsets"][patient_row]
    # `results` is created in an earlier cell, so entries accumulate across runs.
    results[label_pid].append((label_age, offset, r))

Compiling the transformer ... (131072,) (4096,)
WITHOUT AGE


In [7]:
# Inspect the value that bounds the extraction loop over label rows.
batch['num_indices']

Array(1334, dtype=int32, weak_type=True)

In [8]:
# Inspect the label indices; the recorded output ends in repeated 131072 values
# (presumably padding past the last real label - TODO confirm).
batch["transformer"]["label_indices"]

Array([     3,     64,    224, ..., 131072, 131072, 131072], dtype=uint32)

In [9]:
# NOTE(review): this passes the *string* 'TransformerFeaturizer', which has no
# directory component, so the result is ''. It does not locate the class's module;
# the inspect.getfile cell that follows does.
os.path.dirname('TransformerFeaturizer')

''

In [10]:
# NOTE(review): notebook convention is to keep all imports in the first cell;
# this one was added ad hoc in a debugging cell.
import inspect
# Ask Python which source file defines TransformerFeaturizer.
file_path = inspect.getfile(TransformerFeaturizer)
print("Function is defined in:", file_path)

Function is defined in: /home/ubuntu/anaconda3/envs/EHRSHOT_ENV/lib/python3.10/site-packages/femr/models/transformer.py


In [11]:
# Same inspection as the earlier batch['num_indices'] cell (duplicate).
batch["num_indices"]

Array(1334, dtype=int32, weak_type=True)

In [12]:
# Dump every array under the batch's 'transformer' key to see their dtypes and values.
batch['transformer']

{'ages': Array([0.        , 0.99930555, 0.99930555, ..., 0.        , 0.        ,
        0.        ], dtype=float32),
 'label_indices': Array([     3,     64,    224, ..., 131072, 131072, 131072], dtype=uint32),
 'length': Array(32768, dtype=int32, weak_type=True),
 'normalized_ages': Array([-1.3069556, -1.3069555, -1.3069555, ...,  0.       ,  0.       ,
         0.       ], dtype=float32),
 'tokens': Array([ 0,  3, 21, ...,  0,  0,  0], dtype=uint32),
 'valid_tokens': Array([ True,  True,  True, ..., False, False, False], dtype=bool)}

In [13]:
# Look up the "note_embedding_data" entry of the config. No output was recorded
# for this cell, which is what a None result looks like.
config.get("note_embedding_data")

In [14]:
# Show the first 100 label indices of the batch.
batch["transformer"]["label_indices"][:100]

Array([   3,   64,  224,  357,  575,  679,  830,  834,  947, 1009, 1034,
       1053, 1232, 1253, 1333, 1398, 1440, 1451, 1532, 1598, 1729, 1878,
       2039, 2099, 2144, 2173, 2214, 2267, 2293, 2339, 2400, 2463, 2529,
       2672, 2755, 2796, 2910, 3001, 3106, 3196, 3204, 3306, 3392, 3497,
       3582, 3663, 3731, 3823, 3904, 3938, 4018, 4106, 4185, 4257, 4353,
       4407, 4427, 4512, 4594, 4673, 4737, 4819, 4838, 4944, 4951, 5010,
       5073, 5161, 5283, 5303, 5455, 5565, 5651, 5726, 5738, 5831, 5954,
       5992, 6293, 6334, 6414, 6549, 6654, 6698, 6770, 6821, 6831, 6866,
       6920, 6999, 7115, 7301, 7406, 7422, 7485, 7501, 7576, 7662, 7716,
       7771], dtype=uint32)

In [15]:
# Show the first 100 normalized ages of the batch.
batch["transformer"]["normalized_ages"][:100]

Array([-1.3069556, -1.3069555, -1.3069555, -1.3069555, -1.3056909,
       -1.3056909, -1.3056909, -1.3056909, -1.3056909, -1.3056909,
       -1.3056909, -1.3056278, -1.3056278, -1.3056278, -1.3056278,
       -1.3056278, -1.3056278, -1.3056278, -1.3056278, -1.3056275,
       -1.3056275, -1.3056275, -1.3056275, -1.3056275, -1.3056275,
       -1.3056275, -1.3056275, -1.3056275, -1.3056275, -1.3056275,
       -1.3056275, -1.3056275, -1.3056275, -1.3056275, -1.3056275,
       -1.3056275, -1.3056275, -1.3056275, -1.3056275, -1.3056275,
       -1.3056275, -1.3056275, -1.3056275, -1.3056275, -1.3056275,
       -1.3056275, -1.3056275, -1.3056275, -1.3056275, -1.3056275,
       -1.3056275, -1.3056275, -1.3056275, -1.3056275, -1.3056275,
       -1.3056272, -1.3056272, -1.3056272, -1.3056272, -1.3056272,
       -1.3056272, -1.3056272, -1.3056272, -1.3056272, -1.3056272,
       -1.305626 , -1.305626 , -1.305626 , -1.305626 , -1.305626 ,
       -1.305626 , -1.305626 , -1.305626 , -1.305626 , -1.3056