In [None]:
from transformers import AutoModel, AutoTokenizer
import torch
from tqdm import tqdm
import pickle
import json

In [None]:
def _mean_pool(token_embeddings, attention_mask):
    """Average token embeddings over the positions where ``attention_mask`` is 1.

    Expects ``token_embeddings`` shaped (batch, seq, hidden) and ``attention_mask`` shaped
    (batch, seq), as returned by a Hugging Face encoder; padding positions are excluded.
    """
    mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    # clamp guards against division by zero for a row whose mask is all zeros
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


def embed_descriptions(descriptions, model_directory, batch_size=32, *, max_length=64, device=None, pooling="pooler"):
    """Embed each description with a Hugging Face encoder and return a text -> embedding dict.

    Args:
        descriptions: sequence of strings to embed. Duplicates are allowed; they collapse to
            a single key in the returned dict.
        model_directory: model id or local path accepted by ``from_pretrained``.
        batch_size: number of descriptions tokenized and encoded per forward pass.
        max_length: token length at which each description is truncated.
        device: torch device to run on; defaults to ``'cuda'`` when available, else ``'cpu'``.
        pooling: ``"pooler"`` (default) uses the model's ``pooler_output``; ``"mean"`` uses the
            attention-mask-weighted mean of ``last_hidden_state``.
            NOTE(review): sentence-embedding checkpoints such as BioLORD-2023 document mean
            pooling in their model card; ``pooler_output`` passes the first token through the
            model's pooler layer, which may not be trained for this use — confirm the intent.

    Returns:
        dict mapping each description to one row of the embedding matrix, left on ``device``.
        Rows are views into a single concatenated tensor, so clone them before serializing
        individually. An empty input returns ``{}`` without loading the model.

    Raises:
        ValueError: if ``pooling`` is not ``"pooler"`` or ``"mean"``.
    """
    if pooling not in ("pooler", "mean"):
        raise ValueError(f"pooling must be 'pooler' or 'mean', got {pooling!r}")
    if not descriptions:
        # Nothing to embed; torch.cat([]) would raise, and loading the model would be wasted work.
        return {}

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    tokenizer = AutoTokenizer.from_pretrained(model_directory)

    # Load pre-trained model (weights)
    model = AutoModel.from_pretrained(model_directory)
    model.to(device)
    model.eval()  # evaluation mode turns off dropout
    print(f"Model loaded | Generating embeddings with batch size={batch_size}")

    embeddings = []
    for start in tqdm(range(0, len(descriptions), batch_size), desc="Generating Embeddings"):
        # Tokenize per batch: padding only extends to the longest text in this batch and only
        # this batch's tensors live on the device, instead of the whole corpus at once.
        encoded = tokenizer(
            list(descriptions[start:start + batch_size]),
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=max_length,
        )
        batch = {k: v.to(device) for k, v in encoded.items()}
        with torch.no_grad():
            outputs = model(**batch)

        if pooling == "mean":
            embeddings.append(_mean_pool(outputs.last_hidden_state, batch["attention_mask"]))
        else:
            embeddings.append(outputs.pooler_output)

    embeddings = torch.cat(embeddings, dim=0)

    # Later duplicates overwrite earlier ones, matching the original index-based mapping.
    return {text: embeddings[i] for i, text in enumerate(descriptions)}

In [None]:
# Generate one embedding per unique description / reason-description in the dataset
# and cache the text -> embedding mapping to disk.

def _collect_texts(patients):
    """Return the set of description strings referenced by every encounter in ``patients``.

    Gathers the encounter Description and ReasonDescription, each condition's Description,
    and the Description and ReasonDescription of every careplan and procedure. Values that
    are not strings (e.g. None or NaN) are skipped because the tokenizer expects strings.
    """
    found = set()
    for patient in tqdm(patients):
        for encounter in patient["encounters"]:
            enc = encounter["encounter"]
            candidates = [enc["Description"], enc["ReasonDescription"]]
            candidates += [cond["Description"] for cond in encounter["conditions"]]
            for section in ("careplans", "procedures"):
                for item in encounter[section]:
                    candidates += [item["Description"], item["ReasonDescription"]]
            found.update(c for c in candidates if isinstance(c, str))
    return found


with open("data/all_v2.json", "r", encoding="utf-8") as j_file:
    data = json.load(j_file)

texts = _collect_texts(data)

# Sort so the order of texts (and therefore batch composition) is the same on every run;
# Python randomizes str hashing per interpreter, so raw set iteration order varies.
text2embeddings = embed_descriptions(sorted(texts), "FremyCompany/BioLORD-2023")

# Move to CPU and clone: each value is otherwise a view into one shared embedding matrix,
# and pickling a view can serialize its entire base storage for every entry.
text2embeddings = {text: emb.detach().cpu().clone() for text, emb in text2embeddings.items()}

with open('data/text2embeddings.pkl', 'wb') as f:
    pickle.dump(text2embeddings, f)