In [1]:
import torch
import torch_directml
from torch.utils.data import DataLoader, Dataset

from transformers import BertModel, BertTokenizer,AutoModel, AutoTokenizer

from tqdm import tqdm
import json

import torch.nn as nn
import lightning as L

# Device handed to embed_descriptions() for the encoder model and its inputs.
# CPU is the active choice; to use a DirectML device instead, replace the
# assignment below with: dml = torch_directml.device()
dml = "cpu"


In [2]:
# Load the patient records. The encoding is given explicitly so the file is
# decoded as UTF-8 regardless of the platform's locale default.
with open("data/sub_1000.json", "r", encoding="utf-8") as j_file:
    data = json.load(j_file)

In [3]:
def embed_descriptions(descriptions, model_directory, dml, batch_size=32):
    """Embed every description with a pretrained Hugging Face encoder.

    Args:
        descriptions: Iterable of strings to embed. It is materialised into a
            list, so sets, tuples and generators are accepted as well.
        model_directory: Identifier or path forwarded to
            ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.
        dml: Device the model and the tokenized inputs are moved to.
        batch_size: Number of descriptions per forward pass.

    Returns:
        A dict mapping each description to its row of the model's
        ``pooler_output``. The tensors stay on ``dml``. An empty input yields
        an empty dict without loading the model.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    descriptions = list(descriptions)
    if not descriptions:
        # Nothing to embed; tokenizing or concatenating an empty batch would fail.
        return {}

    tokenizer = AutoTokenizer.from_pretrained(model_directory)

    # Load pre-trained model (weights)
    model = AutoModel.from_pretrained(model_directory)
    model.to(dml)  # Move model to the specified device

    model.eval()  # Put the model in "evaluation" mode, which turns off dropout
    print(f"Model loaded | Generating embeddings with batch size={batch_size}")

    # Tokenize everything in one call so all rows are padded to a common length
    # (capped at 64 tokens), then move the tensors to the target device.
    inputs = tokenizer(descriptions, padding=True, truncation=True, return_tensors="pt", max_length=64)
    inputs = {k: v.to(dml) for k, v in inputs.items()}

    # Run the encoder batch by batch; gradients are not needed for inference.
    embeddings = []
    with torch.no_grad():
        for start in tqdm(range(0, len(descriptions), batch_size), desc="Generating Embeddings"):
            batch = {k: v[start:start + batch_size] for k, v in inputs.items()}
            outputs = model(**batch)
            # Keep the pooled sentence-level output of each batch.
            embeddings.append(outputs.pooler_output)

    # One (num_descriptions, dim) tensor, rows aligned with `descriptions`.
    embeddings = torch.cat(embeddings, dim=0)

    return dict(zip(descriptions, embeddings))

In [4]:
# Generate embeddings.
#
# Collect every distinct description / reason-description string that occurs
# in the patient records so each one is embedded exactly once.
texts = set()

for patient in tqdm(data):
    for encounter in patient["encounters"]:
        # The encounter's own description and reason.
        texts.add(encounter["encounter"]["Description"])
        texts.add(encounter["encounter"]["ReasonDescription"])

        # Conditions contribute only a description.
        texts.update(item["Description"] for item in encounter["conditions"])

        # Care plans and procedures contribute a description and a reason.
        for item_type in ("careplans", "procedures"):
            for item in encounter[item_type]:
                texts.add(item["Description"])
                texts.add(item["ReasonDescription"])

# Empty / missing texts are skipped when PatientDataset looks embeddings up,
# so there is no point in embedding them.
texts.discard("")
texts.discard(None)

# Sorting gives the encoder a deterministic input order across kernel restarts
# (set iteration order for strings is not stable between Python processes).
text2embeddings = embed_descriptions(sorted(texts), "FremyCompany/BioLORD-2023", dml=dml)

100%|██████████| 1000/1000 [00:00<00:00, 1005.68it/s]


Model loaded | Generating embeddings with batch size=32


Generating Embeddings: 100%|██████████| 20/20 [00:25<00:00,  1.28s/it]


In [5]:
class PatientDataset(Dataset):
    """Per-patient sequence of text embeddings plus the patient's coordinates.

    Args:
        data: Sequence of patient records. Each record is indexed with
            ``'encounters'``, ``'lat'`` and ``'lon'``; each encounter is indexed
            with ``'encounter'``, ``'conditions'``, ``'careplans'`` and
            ``'procedures'``.
        text2embedding: Mapping from a text to its embedding tensor.
    """

    # Text fields read from the encounter record and from every nested item.
    _TEXT_KEYS = ('Description', 'ReasonDescription')
    # Nested item lists visited for each encounter, in this order.
    _ITEM_TYPES = ('conditions', 'careplans', 'procedures')

    def __init__(self, data, text2embedding):
        self.data = data
        self.text2embedding = text2embedding

    def __len__(self):
        """Number of patient records."""
        return len(self.data)

    def _lookup(self, record):
        """Return the embeddings of ``record``'s non-empty texts found in the table."""
        found = []
        for key in self._TEXT_KEYS:
            text = record.get(key, '')
            if text and text in self.text2embedding:
                found.append(self.text2embedding[text])
        return found

    def _placeholder(self):
        """Return a single all-zero row matching the table's embeddings."""
        reference = next(iter(self.text2embedding.values()), None)
        if reference is None:
            # A bare StopIteration escaping __getitem__ can be mistaken for the
            # end of iteration by callers, so fail with an explicit error.
            raise ValueError(
                "text2embedding is empty: cannot infer the embedding size for a "
                "patient without any known text"
            )
        # new_zeros keeps the dtype and device of the real embeddings so the
        # placeholder can be batched together with them.
        return reference.new_zeros((1, reference.shape[-1]))

    def __getitem__(self, idx):
        """Build the sample for patient ``idx``.

        Returns:
            A dict with ``'embeddings'``, a ``(num_texts, dim)`` tensor of the
            embeddings found for the patient in visiting order (a single zero
            row if none were found), and ``'features'``, a tensor holding the
            patient's ``lat`` and ``lon``.

        Raises:
            ValueError: If the patient has no known text and the embedding
                table is empty.
        """
        patient = self.data[idx]
        embeddings = []
        for encounter in patient['encounters']:
            embeddings.extend(self._lookup(encounter['encounter']))
            for item_type in self._ITEM_TYPES:
                for item in encounter[item_type]:
                    embeddings.extend(self._lookup(item))

        embeddings_tensor = torch.stack(embeddings) if embeddings else self._placeholder()

        return {
            'embeddings': embeddings_tensor,  # stacked (num_texts, dim) tensor
            'features': torch.tensor([patient['lat'], patient['lon']])
        }

# Preview dataset and loader. This loader is built without a custom collate
# function; the training cell further down creates its own loader that passes
# collate_fn for batching the per-patient sequences.
dataset = PatientDataset(data=data, text2embedding=text2embeddings)
dataloader = DataLoader(dataset=dataset, batch_size=2)


In [6]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    """Collate variable-length patient samples into one padded batch.

    Args:
        batch: List of samples, each a dict with an ``'embeddings'`` tensor of
            shape ``(num_texts, dim)`` and a ``'features'`` tensor.

    Returns:
        A dict with:
          * ``'embeddings'``: ``(batch, max_num_texts, dim)`` tensor, shorter
            sequences zero-padded at the end;
          * ``'features'``: the per-sample feature tensors stacked on dim 0;
          * ``'lengths'``: int64 tensor with each sequence's true length, so
            consumers can mask or pack the padded positions.
    """
    embeddings = [item['embeddings'] for item in batch]
    features = torch.stack([item['features'] for item in batch])

    # Record the true lengths before padding erases that information.
    lengths = torch.tensor([sequence.shape[0] for sequence in embeddings], dtype=torch.long)

    # Pad the embeddings sequences
    embeddings_padded = pad_sequence(embeddings, batch_first=True)

    return {
        'embeddings': embeddings_padded,
        'features': features,
        'lengths': lengths
    }

In [20]:
class EncounterAutoencoder(nn.Module):
    """Summarise a sequence of embeddings with a single batch-first LSTM.

    Args:
        embedding_dim: Size of each input vector fed to the LSTM.
        hidden_dim: Size of the LSTM hidden state, i.e. of the returned summary.
    """

    def __init__(self, embedding_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )

    def forward(self, x):
        """Return the last layer's final hidden state for every sequence in ``x``."""
        # The per-step outputs and the cell state are not needed here.
        _, (final_hidden, _) = self.lstm(x)
        return final_hidden[-1]

class PatientAutoencoder(L.LightningModule):
    """Compress the LSTM summary of a patient's texts through a linear bottleneck.

    The batch is a mapping with an ``'embeddings'`` tensor laid out
    batch-first. If it also carries ``'lengths'`` (true sequence lengths), the
    padded positions are excluded from the LSTM pass.

    Args:
        embedding_dim: Size of each input embedding.
        hidden_dim: Size of the LSTM summary and of the reconstruction.
        patient_latent_dim: Size of the bottleneck representation.
    """

    def __init__(self, embedding_dim, hidden_dim, patient_latent_dim):
        super().__init__()
        self.encounter_autoencoder = EncounterAutoencoder(embedding_dim, hidden_dim)
        self.patient_encoder = nn.Linear(hidden_dim, patient_latent_dim)
        self.patient_decoder_1 = nn.Linear(patient_latent_dim, hidden_dim)
        # NOTE(review): this layer is never called by forward() or
        # training_step(). It is kept so the parameter set (and any saved
        # checkpoints) stay unchanged. Confirm whether the reconstruction was
        # meant to go back to embedding_dim.
        self.patient_decoder_2 = nn.Linear(hidden_dim, embedding_dim)

    def _summarise(self, x):
        """Run the sequence encoder, skipping padding when lengths are given."""
        embeddings = x['embeddings']
        lengths = x['lengths'] if 'lengths' in x else None
        if lengths is not None:
            # Packing makes the LSTM stop at each sequence's true end, so the
            # final hidden state is not computed over zero padding.
            # pack_padded_sequence expects the lengths on the CPU.
            embeddings = nn.utils.rnn.pack_padded_sequence(
                embeddings,
                torch.as_tensor(lengths).cpu(),
                batch_first=True,
                enforce_sorted=False,
            )
        return self.encounter_autoencoder(embeddings)

    def _reconstruct(self, encounter_representation):
        """Project to the patient latent space and back to hidden_dim."""
        patient_encoded = self.patient_encoder(encounter_representation)
        return self.patient_decoder_1(patient_encoded)

    def forward(self, x):
        """Return the ``(batch, hidden_dim)`` reconstruction of the LSTM summary."""
        return self._reconstruct(self._summarise(x))

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE between the reconstruction and the LSTM summary."""
        # Encode once and reuse the result for both sides of the loss instead
        # of running the LSTM a second time to obtain the target.
        encounter_representation = self._summarise(batch)
        reconstruction = self._reconstruct(encounter_representation)
        # The target is detached so it acts as a fixed value for this step;
        # otherwise the loss could also be lowered by pulling the target
        # itself towards the reconstruction.
        # NOTE(review): the objective reconstructs the LSTM summary, not the
        # input embeddings. Confirm this is the intended training signal.
        loss = nn.functional.mse_loss(reconstruction, encounter_representation.detach())
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with a learning rate of 0.001."""
        return torch.optim.Adam(self.parameters(), lr=0.001)


In [21]:
from torch.utils.data import DataLoader
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

# Seed Python, NumPy and torch so shuffling and weight initialisation are
# reproducible across kernel restarts.
L.seed_everything(42)

dataset = PatientDataset(data, text2embeddings)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

# Initialize the model
# The input width is read from the embedding table so it always matches the
# encoder that produced the embeddings.
embedding_dim = next(iter(text2embeddings.values())).shape[-1]
hidden_dim = 128  # Intermediate representation size
patient_latent_dim = 64  # Final latent space dimension for the patient

model = PatientAutoencoder(embedding_dim, hidden_dim, patient_latent_dim)

logger = CSVLogger("pl_logs", name="hae")

# Initialize PyTorch Lightning trainer and train the model.
# Only `train_loss` is logged by the model, so the checkpoint filename formats
# that metric (a `val_loss` placeholder would never receive a real value).
trainer = L.Trainer(max_epochs=1, accelerator="cpu", callbacks=[
    ModelCheckpoint(
        monitor="train_loss", mode="min", save_last=True, save_top_k=1,
        dirpath="checkpoints/", filename="hae-{epoch:02d}-{train_loss:.2f}"
    ),
    EarlyStopping(monitor="train_loss", patience=3, mode="min")
], logger=logger)
trainer.fit(model, dataloader)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: pl_logs/hae

  | Name                  | Type                 | Params
---------------------------------------------------------------
0 | encounter_autoencoder | EncounterAutoencoder | 459 K 
1 | patient_encoder       | Linear               | 8.3 K 
2 | patient_decoder_1     | Linear               | 8.3 K 
3 | patient_decoder_2     | Linear               | 99.1 K
---------------------------------------------------------------
575 K     Trainable params
0         Non-trainable params
575 K     Total params
2.302     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
