In [1]:
from __future__ import annotations

import argparse
import datetime
import logging
import math
import os
import queue
import random
import sys
import threading
from typing import Any, List, Optional, Tuple, TypeVar

import jax
import msgpack
import numpy as np

import femr.datasets
import femr.extension.dataloader
import femr.labelers

T = TypeVar("T")

BatchLoader = femr.extension.dataloader.BatchLoader


def _index_thread(
    index_queue: queue.Queue[Optional[Tuple[int, int]]],
    seed: int,
    num_epochs: int,
    num_batch_threads: int,
    split: str,
    data_path: str,
    batch_info_path: str,
    num_batches: int,
) -> None:
    """Generate indices in random order and add them to the queue."""
    rng = random.Random(seed)
    step = 0
    for _ in range(num_epochs):
        order: List[int] = list(range(num_batches))
        rng.shuffle(order)

        for i in order:
            index_queue.put((i, step))
            step += 1

    for _ in range(num_batch_threads):
        index_queue.put(None)


def _batch_thread(
    index_queue: queue.Queue[Optional[Tuple[int, int]]],
    batch_queue: queue.Queue[Optional[Tuple[Any, int]]],
    data_path: str,
    batch_info_path: str,
    token_dropout: float,
    split: str,
) -> None:
    """Load batches according to the indices in the index thread and add them to the batch queue."""
    thread_loader = BatchLoader(data_path, batch_info_path, token_dropout=token_dropout)
    while True:
        next_item = index_queue.get()
        if next_item is None:
            batch_queue.put(None)
            break

        batch_index, step = next_item

        batch = thread_loader.get_batch(split, batch_index)
        if batch["num_indices"] == 0:
            batch_queue.put((None, step))
        else:
            batch = jax.tree_map(lambda a: jax.device_put(a, device=jax.devices("cpu")[0]), batch)
            batch_queue.put((batch, step))

    batch_queue.put(None)


class Batches:
    def __init__(
        self,
        data_path: str,
        batch_info_path: str,
        token_dropout: float,
        seed: int,
        num_epochs: int,
        num_batch_threads: int,
        num_batches: int,
        split: str = "train",
    ):
        print("Working with seed", seed, file=sys.stderr)
        """Create a multithreaded batch loader for the given batch info."""
        index_queue: queue.Queue[Optional[int]] = queue.Queue(maxsize=300)
        _ = index_queue

        self.batch_queue: queue.Queue[Optional[Any]] = queue.Queue(maxsize=5)

        batch_queue = self.batch_queue
        _ = batch_queue

        local = locals()

        batcher_thread = threading.Thread(
            target=_index_thread,
            kwargs={
                k: local[k]
                for k in (
                    "index_queue",
                    "seed",
                    "num_batch_threads",
                    "num_epochs",
                    "data_path",
                    "batch_info_path",
                    "num_batches",
                    "split",
                )
            },
            name="batch_thread",
            daemon=True,
        )
        batcher_thread.start()

        batcher_threads = [
            threading.Thread(
                target=_batch_thread,
                kwargs={
                    k: local[k]
                    for k in (
                        "index_queue",
                        "batch_queue",
                        "data_path",
                        "batch_info_path",
                        "data_path",
                        "token_dropout",
                        "split",
                    )
                },
                name="batch_thread",
                daemon=True,
            )
            for _ in range(num_batch_threads)
        ]

        for t in batcher_threads:
            t.start()

        self.remaining_threads = num_batch_threads

    def get_next(self) -> Optional[Any]:
        """Get the next batch, or None if we are out of batches."""
        next_item = None

        while next_item is None:
            next_item = self.batch_queue.get()
            if next_item is not None:
                return next_item
            else:
                self.remaining_threads -= 1
                if self.remaining_threads == 0:
                    return None



In [2]:
# import os
# os.listdir('EHRSHOT_ASSETS/models/clmbr/dictionary')

In [3]:
dict_path = "EHRSHOT_ASSETS/models/clmbr/dictionary"
is_hierarchical = False
transformer_vocab_size = 1024 * 64
data_path = 'EHRSHOT_ASSETS/femr/extract'
with open(dict_path, "rb") as f:
        dictionary = msgpack.load(f, use_list=False)

        if is_hierarchical:
            dict_len = len(dictionary["ontology_rollup"])
        else:
            dict_len = len(dictionary["regular"])

        assert (
            transformer_vocab_size <= dict_len
        ), f"Transformer vocab size ({transformer_vocab_size}) must be <= len(dictionary) ({dict_len})"

data = femr.datasets.PatientDatabase(data_path)

In [4]:
data[115967098].events

(Event(start=1943-12-15 00:00:00, code=SNOMED/3950001, value=None, =109886, omop_table=person),
 Event(start=1943-12-15 23:59:00, code=Race/5, value=None, =63, omop_table=person),
 Event(start=1943-12-15 23:59:00, code=Gender/F, value=None, =123, omop_table=person),
 Event(start=1943-12-15 23:59:00, code=Ethnicity/Not Hispanic, value=None, =169495, omop_table=person),
 Event(start=1997-01-04 23:59:00, code=LOINC/68848-1, value=None, =68985, omop_table=note),
 Event(start=1998-07-13 23:59:00, code=LOINC/11506-3, value=None, =35301, omop_table=note),
 Event(start=1998-11-07 11:15:00, code=LOINC/2714-4, value=95.9000015258789, =505718, unit=%, visit_id=28923412.0, omop_table=measurement),
 Event(start=1998-11-07 11:15:00, code=LOINC/19218-7, value=14.300000190734863, =589358, unit=mL/dL, visit_id=28923412.0, omop_table=measurement),
 Event(start=1998-11-07 11:15:00, code=LOINC/3150-0, value=21.0, =261347, unit=%, visit_id=28923412.0, omop_table=measurement),
 Event(start=1998-11-07 11:15:

In [7]:
data.get_patient(0)

AttributeError: 'femr.extension.datasets.PatientDatabase' object has no attribute 'get_patient'

In [25]:
import inspect
from femr.extension.dataloader import create_batches, BatchLoader

In [26]:
# inspect.getsource(create_batches)
inspect.getsource(BatchLoader)

TypeError: <class 'femr.extension.dataloader.BatchLoader'> is a built-in class