# Remove admissions with los<=2

In [32]:

# Load admissions file
df = pd.read_excel("final_admissions.xlsx")
# Drop admissions with LOS <= 2
df_filtered = df[df["LOS"] > 2].reset_index(drop=True)
# Save filtered file
df_filtered.to_excel("final_admissions1.xlsx", index=False)
print(f"Original admissions: {len(df)}")
print(f"Remaining admissions (LOS > 2): {len(df_filtered)}")

Original admissions: 1777
Remaining admissions (LOS > 2): 1660


In [80]:
import pandas as pd
import numpy as np
import ast
import random

# Data Loading

In [81]:
emb_df = pd.read_excel("with_note_embedding/final-Q&A_based_Note_embeddings.xlsx")
adm_df = pd.read_excel("final_admissions.xlsx")

In [82]:
emb_df.head(5)

Unnamed: 0,Note_id,Embedding,HADM_ID,DAY
0,176176_Day_2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1,...",176176,2
1,176176_Day_3,"[-1, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, -1, 0, ...",176176,3
2,185910_Day_1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1,...",185910,1
3,185910_Day_2,"[0, -1, 0, 0, 0, 0, 1, 0, 0, 0, 1, -1, -1, 0, ...",185910,2
4,185910_Day_3,"[0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, -1, 0, 0...",185910,3


In [83]:
adm_df.head(5)

Unnamed: 0,SUBJECT_ID,HADM_ID,ADMITTIME,DISCHTIME,MORTALITY_STATUS,LOS,DIAGNOSIS,AGE,GENDER
0,33,176176,2116-12-23 22:30:00,2116-12-27 12:05:00,0,5,SEPSIS;TELEMETRY,82,M
1,38,185910,2166-08-10 00:28:00,2166-09-04 11:30:00,0,26,ACUTE MYOCARDIAL INFARCTION-SEPSIS,76,M
2,357,122609,2198-11-01 22:36:00,2198-11-14 14:20:00,0,14,SEPSIS,64,M
3,366,134462,2164-11-18 20:27:00,2164-11-22 15:18:00,0,5,SEPSIS,53,M
4,62,116009,2113-02-15 00:19:00,2113-02-19 15:30:00,0,5,"SEPSIS,URINARY TRACT INFECTION",69,M


# Data Preprocessing

In [84]:
emb_df["embedding"] = emb_df["Embedding"].apply(ast.literal_eval)


In [85]:
EMB_DIM = len(emb_df["embedding"].iloc[0])
print(EMB_DIM)

2424


In [86]:
emb_groups = emb_df.groupby("HADM_ID")

In [87]:
emb_groups.head(5)

Unnamed: 0,Note_id,Embedding,HADM_ID,DAY,embedding
0,176176_Day_2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1,...",176176,2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1,..."
1,176176_Day_3,"[-1, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, -1, 0, ...",176176,3,"[-1, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, -1, 0, ..."
2,185910_Day_1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1,...",185910,1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1,..."
3,185910_Day_2,"[0, -1, 0, 0, 0, 0, 1, 0, 0, 0, 1, -1, -1, 0, ...",185910,2,"[0, -1, 0, 0, 0, 0, 1, 0, 0, 0, 1, -1, -1, 0, ..."
4,185910_Day_3,"[0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, -1, 0, 0...",185910,3,"[0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, -1, 0, 0..."
...,...,...,...,...,...
8461,153703_Day_1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ...",153703,1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
8462,153703_Day_2,"[1, -1, 1, 0, 0, 0, -1, 0, -1, 0, 1, 0, 1, 1, ...",153703,2,"[1, -1, 1, 0, 0, 0, -1, 0, -1, 0, 1, 0, 1, 1, ..."
8463,153703_Day_3,"[0, -1, 1, 0, 0, 0, -1, 0, -1, 0, 1, 0, 1, -1,...",153703,3,"[0, -1, 1, 0, 0, 0, -1, 0, -1, 0, 1, 0, 1, -1,..."
8464,153703_Day_4,"[1, 0, -1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, ...",153703,4,"[1, 0, -1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 1, -1, ..."


In [90]:
def build_admission_sequence(hadm_id, emb_group, los, emb_dim):
   """
   Returns:
     embeddings: (T, D)
     deltas: (T,)
     mask: (T,)  -> 1 if observed, 0 if missing
   """
   day_to_emb = {
       row.DAY: np.array(row.embedding)
       for _, row in emb_group.iterrows()
   }
   embeddings = []
   deltas = []
   mask = []
   prev_day = None
   for day in range(1, los + 1):
       if day in day_to_emb:
           emb = day_to_emb[day]
           m = 1
       else:
           emb = np.zeros(emb_dim)
           m = 0
       if prev_day is None:
           delta = 0
       else:
           delta = day - prev_day
       embeddings.append(emb)
       deltas.append(delta)
       mask.append(m)
       if m == 1:
           prev_day = day
   return (
       np.stack(embeddings),
       np.array(deltas),
       np.array(mask)
   )

In [91]:
admission_data = {}
for _, row in adm_df.iterrows():
   hadm_id = row.HADM_ID
   los = int(row.LOS)
   if hadm_id not in emb_groups.groups:
       continue
   emb_group = emb_groups.get_group(hadm_id)
   emb, delta, mask = build_admission_sequence(
       hadm_id, emb_group, los, EMB_DIM
   )
   admission_data[hadm_id] = {
       "HADM_ID": hadm_id,
       "emb": emb,
       "delta": delta,
       "mask": mask
   }

In [92]:
len(admission_data.keys())

1211

In [93]:
admission_data[199855]

{'HADM_ID': 199855,
 'emb': array([[ 0., -1.,  1., ...,  0.,  0.,  0.],
        [ 0.,  0.,  1., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ...,  0.,  0.,  0.]]),
 'delta': array([0, 1, 1, 1, 2, 3]),
 'mask': array([1, 1, 1, 0, 0, 0])}

# Train/Val Split

In [94]:
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [95]:

hadm_ids = list(admission_data.keys())
train_ids, val_ids = train_test_split(
   hadm_ids,
   test_size=0.2,
   random_state=42
)
train_data = {k: admission_data[k] for k in train_ids}
val_data   = {k: admission_data[k] for k in val_ids}
print(len(train_data), len(val_data))

968 243


In [96]:
class AdmissionTrainDataset(Dataset):
   def __init__(self, data, max_prefix=3):
       self.data = list(data.values())
       self.max_prefix = max_prefix
   def __len__(self):
       return len(self.data)
   def __getitem__(self, idx):
       item = self.data[idx]
       emb = torch.tensor(item["emb"], dtype=torch.float)
       delta = torch.tensor(item["delta"], dtype=torch.float)
       mask = torch.tensor(item["mask"], dtype=torch.float)
       T = emb.shape[0]
       max_k = min(self.max_prefix, T)
       # -----------------------------
       # CRITICAL FIX: mask-aware prefix
       # -----------------------------
       valid_idx = torch.where(mask == 1)[0]
       if len(valid_idx) == 0:
           raise RuntimeError("Admission has no valid observed days")
       # Case 1: at least one valid day occurs within max_prefix
       valid_within = valid_idx[valid_idx < max_k]
       if len(valid_within) > 0:
           last_valid_pos = valid_within.max().item()
           k = torch.randint(
               low=last_valid_pos + 1,
               high=max_k + 1,
               size=(1,)
           ).item()
       else:
           # Case 2: first valid day occurs after max_prefix
           # Extend prefix to include first valid observation
           first_valid_pos = valid_idx.min().item()
           k = first_valid_pos + 1
       prefix_emb = emb[:k]
       prefix_delta = delta[:k]
       prefix_mask = mask[:k]
       # FINAL SAFETY CHECK
       assert prefix_mask.sum() > 0, "Prefix is fully masked â€” this should never happen"
       return {
           "hadm_id": item["HADM_ID"],
           "full_emb": emb,
           "full_delta": delta,
           "full_mask": mask,
           "prefix_emb": prefix_emb,
           "prefix_delta": prefix_delta,
           "prefix_mask": prefix_mask,
       }

In [97]:
class AdmissionValDataset(Dataset):
   def __init__(self, data, prefix_len=2):
       self.data = list(data.values())
       self.prefix_len = prefix_len
   def __len__(self):
       return len(self.data)
   def __getitem__(self, idx):
       item = self.data[idx]
       emb = torch.tensor(item["emb"], dtype=torch.float)
       delta = torch.tensor(item["delta"], dtype=torch.float)
       mask = torch.tensor(item["mask"], dtype=torch.float)
       T = emb.shape[0]
       k = min(self.prefix_len, T)
       # -------- FIX: ensure prefix has â‰¥1 valid timestep --------
       if mask[:k].sum() == 0:
           valid_idx = torch.where(mask == 1)[0]
           if len(valid_idx) == 0:
               raise RuntimeError("Admission has no valid observed days")
           k = valid_idx.min().item() + 1
       prefix_emb = emb[:k]
       prefix_delta = delta[:k]
       prefix_mask = mask[:k]
       # Safety check (can remove later)
       assert prefix_mask.sum() > 0, "Validation prefix is fully masked"
       return {
           "hadm_id": item["HADM_ID"],
           "full_emb": emb,
           "full_delta": delta,
           "full_mask": mask,
           "prefix_emb": prefix_emb,
           "prefix_delta": prefix_delta,
           "prefix_mask": prefix_mask,
       }

In [98]:
def collate_fn(batch):
   def pad(key, padding_value=0.0):
       return pad_sequence(
           [b[key] for b in batch],
           batch_first=True,
           padding_value=padding_value
       )
   out = {
       "hadm_id": [b["hadm_id"] for b in batch],
       "full_emb": pad("full_emb", 0.0),
       "full_delta": pad("full_delta", 0.0),
       "full_mask": pad("full_mask", 0.0),
       "prefix_emb": pad("prefix_emb", 0.0),
       "prefix_delta": pad("prefix_delta", 0.0),
       "prefix_mask": pad("prefix_mask", 0.0),
   }
   # -------- sanity checks --------
   assert (out["full_mask"].sum(dim=1) > 0).all(), \
       "Found full sequence with all-masked timesteps"
   assert (out["prefix_mask"].sum(dim=1) > 0).all(), \
       "Found prefix with all-masked timesteps"
   return out

In [99]:

train_loader = DataLoader(
   AdmissionTrainDataset(train_data),
   batch_size=16,
   shuffle=True,
   collate_fn=collate_fn
)
val_loader = DataLoader(
   AdmissionValDataset(val_data),
   batch_size=16,
   shuffle=False,
   collate_fn=collate_fn
)

In [100]:

class TimeEmbedding(nn.Module):
   def __init__(self, dim):
       super().__init__()
       self.net = nn.Sequential(
           nn.Linear(1, dim),
           nn.ReLU(),
           nn.Linear(dim, dim)
       )
   def forward(self, delta):
       return self.net(delta.unsqueeze(-1))

class MaskEmbedding(nn.Module):
   def __init__(self, dim):
       super().__init__()
       self.emb = nn.Embedding(2, dim)
   def forward(self, mask):
       return self.emb(mask.long())

In [101]:
class TemporalTransformer(nn.Module):
   def __init__(self, dim, n_heads=4, n_layers=4):
       super().__init__()
       self.time_emb = TimeEmbedding(dim)
       self.mask_emb = MaskEmbedding(dim)
       encoder_layer = nn.TransformerEncoderLayer(
           d_model=dim,
           nhead=n_heads,
           batch_first=True
       )
       self.encoder = nn.TransformerEncoder(
           encoder_layer,
           num_layers=n_layers
       )
       self.query = nn.Parameter(torch.randn(dim))
   def forward(self, emb, delta, mask):
       """
       emb:   (B, T, D)
       delta: (B, T)
       mask:  (B, T)   with values {0,1}
       """
       x = emb + self.time_emb(delta) + self.mask_emb(mask)
       key_padding_mask = (mask == 0)
       h = self.encoder(
           x,
           src_key_padding_mask=key_padding_mask
       )  # (B, T, D)
       # -------- FIX: mask-aware attention pooling --------
       scores = torch.matmul(h, self.query)          # (B, T)
       scores = scores.masked_fill(mask == 0, -1e9)  # mask padded steps
       attn = torch.softmax(scores, dim=1)
       attn = attn * mask                             # zero out padded
       attn = attn / (attn.sum(dim=1, keepdim=True) + 1e-8)
       adm_emb = torch.sum(h * attn.unsqueeze(-1), dim=1)
       return adm_emb

# prefix-to-full Contrastive Loss

In [102]:

def contrastive_loss(x, y, temperature=0.1):
   """
   x, y: (B, D) embeddings
   """
   x = F.normalize(x, dim=1)
   y = F.normalize(y, dim=1)
   logits = torch.matmul(x, y.T) / temperature
   labels = torch.arange(x.size(0)).to(x.device)
   return F.cross_entropy(logits, labels)

# Implement Temporal Augmentation

In [103]:
def random_day_drop(emb, delta, mask, drop_prob=0.2, min_len=3):
   """
   emb:   (T, D)
   delta: (T,)
   mask:  (T,)  with values {0,1}
   """
   # Only keep originally valid (observed) days
   valid_indices = [i for i in range(len(mask)) if mask[i] == 1]
   if len(valid_indices) < min_len:
       # Extremely rare (LOS > 2), but safety first
       valid_indices = valid_indices
   keep_indices = []
   for i in valid_indices:
       if random.random() > drop_prob:
           keep_indices.append(i)
   # Enforce minimum length
   if len(keep_indices) < min_len:
       keep_indices = valid_indices[:min_len]
   emb_new = emb[keep_indices]
   delta_new = delta[keep_indices]
   # ðŸ”‘ CRITICAL FIX: augmented mask must be all ones
   mask_new = torch.ones(len(keep_indices), dtype=torch.float32, device=emb.device)
   return emb_new, delta_new, mask_new


In [104]:
def jitter_delta(delta, noise_std=0.1):
   noise = torch.randn_like(delta) * noise_std
   delta_new = delta + noise
   delta_new = torch.clamp(delta_new, min=0)
   return delta_new

In [105]:
def temporal_augment(batch, drop_prob=0.2, min_len=3):
   emb = batch["full_emb"]     # (B, T, D)
   delta = batch["full_delta"] # (B, T)
   mask = batch["full_mask"]   # (B, T)
   emb_aug, delta_aug, mask_aug = [], [], []
   B = emb.size(0)
   device = emb.device
   for i in range(B):
       e, d, m = random_day_drop(
           emb[i], delta[i], mask[i],
           drop_prob=drop_prob,
           min_len=min_len
       )
       # Optional but recommended
       d = jitter_delta(d)
       emb_aug.append(e)
       delta_aug.append(d)
       mask_aug.append(m)
   return emb_aug, delta_aug, mask_aug

In [106]:
def pad_augmented(emb_list, delta_list, mask_list):
   emb_pad = pad_sequence(emb_list, batch_first=True, padding_value=0.0)
   delta_pad = pad_sequence(delta_list, batch_first=True, padding_value=0.0)
   mask_pad = pad_sequence(mask_list, batch_first=True, padding_value=0.0)
   return emb_pad, delta_pad, mask_pad

# Model traning

In [125]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TemporalTransformer(dim=EMB_DIM).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [126]:

num_epochs = 5
lambda_aug = 0.2
best_val_loss = float("inf")
best_epoch = -1
checkpoint_path = "best_temporal_transformer.pt"

In [127]:
for epoch in range(1, num_epochs + 1):
   # =========================
   # Training
   # =========================
   model.train()
   train_loss_total = 0.0
   train_pf_total = 0.0
   train_aug_total = 0.0
   train_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [Train]")
   for batch in train_bar:
       for k in batch:
           if k != "hadm_id":
               batch[k] = batch[k].to(device)
       # ----- Prefixâ€“Full loss -----
       e_full = model(
           batch["full_emb"],
           batch["full_delta"],
           batch["full_mask"]
       )
       e_prefix = model(
           batch["prefix_emb"],
           batch["prefix_delta"],
           batch["prefix_mask"]
       )
       loss_pf = contrastive_loss(e_prefix, e_full)
       # ----- Temporal augmentation loss -----
       emb_aug1, delta_aug1, mask_aug1 = temporal_augment(batch)
       emb_aug2, delta_aug2, mask_aug2 = temporal_augment(batch)
       emb_aug1, delta_aug1, mask_aug1 = pad_augmented(
           emb_aug1, delta_aug1, mask_aug1
       )
       emb_aug2, delta_aug2, mask_aug2 = pad_augmented(
           emb_aug2, delta_aug2, mask_aug2
       )
       emb_aug1 = emb_aug1.to(device)
       delta_aug1 = delta_aug1.to(device)
       mask_aug1 = mask_aug1.to(device)
       emb_aug2 = emb_aug2.to(device)
       delta_aug2 = delta_aug2.to(device)
       mask_aug2 = mask_aug2.to(device)
       e_aug1 = model(emb_aug1, delta_aug1, mask_aug1)
       e_aug2 = model(emb_aug2, delta_aug2, mask_aug2)
       loss_aug = contrastive_loss(e_aug1, e_aug2)
       # ----- Total loss -----
       loss = loss_pf + lambda_aug * loss_aug
       optimizer.zero_grad()
       loss.backward()
       torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
       optimizer.step()
       train_loss_total += loss.item()
       train_pf_total += loss_pf.item()
       train_aug_total += loss_aug.item()
       train_bar.set_postfix({
           "loss": f"{loss.item():.4f}",
           "pf": f"{loss_pf.item():.4f}",
           "aug": f"{loss_aug.item():.4f}"
       })
   avg_train_loss = train_loss_total / len(train_loader)
   avg_train_pf = train_pf_total / len(train_loader)
   avg_train_aug = train_aug_total / len(train_loader)
   # =========================
   # Validation
   # =========================
   model.eval()
   val_loss_total = 0.0
   with torch.no_grad():
       for batch in val_loader:
           for k in batch:
               if k != "hadm_id":
                   batch[k] = batch[k].to(device)
           e_full = model(
               batch["full_emb"],
               batch["full_delta"],
               batch["full_mask"]
           )
           e_prefix = model(
               batch["prefix_emb"],
               batch["prefix_delta"],
               batch["prefix_mask"]
           )
           loss_val = contrastive_loss(e_prefix, e_full)
           val_loss_total += loss_val.item()
   avg_val_loss = val_loss_total / len(val_loader)
   # =========================
   # Save best checkpoint
   # =========================
   if avg_val_loss < best_val_loss:
       best_val_loss = avg_val_loss
       best_epoch = epoch
       torch.save(
           {
               "epoch": epoch,
               "model_state_dict": model.state_dict(),
               "optimizer_state_dict": optimizer.state_dict(),
               "val_loss": avg_val_loss,
           },
           checkpoint_path
       )
       print(f" Best model saved at epoch {epoch} (Val Loss = {avg_val_loss:.4f})")
   # =========================
   # Epoch summary
   # =========================
   print(
       f"\nEpoch {epoch}/{num_epochs} Summary\n"
       f"  Train Loss       : {avg_train_loss:.4f}\n"
       f"    â”œâ”€ Prefixâ€“Full : {avg_train_pf:.4f}\n"
       f"    â””â”€ Augmentation: {avg_train_aug:.4f}\n"
       f"  Val Loss         : {avg_val_loss:.4f}\n"
       f"{'-'*60}"
   )
print(
   f"\nTraining completed.\n"
   f"Best model from epoch {best_epoch} "
   f"with validation loss {best_val_loss:.4f}"
)

Epoch 1/5 [Train]:   0%|                                                                        | 0/61 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (2424) must match the size of tensor b (256) at non-singleton dimension 2

# Test

In [112]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TemporalTransformer(
   dim=EMB_DIM,
   n_heads=4,
   n_layers=4
).to(device)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
print(
   f"Loaded model from epoch {checkpoint['epoch']} "
   f"(val loss = {checkpoint['val_loss']:.4f})"
)

  checkpoint = torch.load(checkpoint_path, map_location=device)


Loaded model from epoch 3 (val loss = 0.0197)


In [113]:
class AdmissionTestDataset(Dataset):
   def __init__(self, data, prefix_len=2):
       self.data = list(data.values())
       self.prefix_len = prefix_len
   def __len__(self):
       return len(self.data)
   def __getitem__(self, idx):
       item = self.data[idx]
       emb = torch.tensor(item["emb"], dtype=torch.float)
       delta = torch.tensor(item["delta"], dtype=torch.float)
       mask = torch.tensor(item["mask"], dtype=torch.float)
       T = emb.shape[0]
       k = min(self.prefix_len, T)
       # -------- FIX: ensure prefix has â‰¥1 valid timestep --------
       if mask[:k].sum() == 0:
           valid_idx = torch.where(mask == 1)[0]
           if len(valid_idx) == 0:
               raise RuntimeError("Admission has no valid observed days")
           k = valid_idx.min().item() + 1
       prefix_emb = emb[:k]
       prefix_delta = delta[:k]
       prefix_mask = mask[:k]
       # Safety check (can remove later)
       assert prefix_mask.sum() > 0, "Validation prefix is fully masked"
       return {
           "hadm_id": item["HADM_ID"],
           "full_emb": emb,
           "full_delta": delta,
           "full_mask": mask,
           "prefix_emb": prefix_emb,
           "prefix_delta": prefix_delta,
           "prefix_mask": prefix_mask,
       }

In [114]:
# def test_collate_fn(batch):
#    return {
#        "hadm_id": [b["hadm_id"] for b in batch],
#        "emb": pad_sequence([b["emb"] for b in batch], batch_first=True, padding_value=0.0),
#        "delta": pad_sequence([b["delta"] for b in batch], batch_first=True, padding_value=0.0),
#        "mask": pad_sequence([b["mask"] for b in batch], batch_first=True, padding_value=0.0),
#    }

def test_collate_fn(batch):
   def pad(key, padding_value=0.0):
       return pad_sequence(
           [b[key] for b in batch],
           batch_first=True,
           padding_value=padding_value
       )
   out = {
       "hadm_id": [b["hadm_id"] for b in batch],
       "full_emb": pad("full_emb", 0.0),
       "full_delta": pad("full_delta", 0.0),
       "full_mask": pad("full_mask", 0.0),
       "prefix_emb": pad("prefix_emb", 0.0),
       "prefix_delta": pad("prefix_delta", 0.0),
       "prefix_mask": pad("prefix_mask", 0.0),
   }
   return out


In [115]:
len(admission_data)

1211

In [116]:
test_dataset = AdmissionTestDataset(admission_data)
test_loader = torch.utils.data.DataLoader(
   test_dataset,
   batch_size=32,
   shuffle=False,
   collate_fn=test_collate_fn
)

In [117]:
all_hadm_ids = []
all_full_embeddings = []
all_pref_embeddings = []
with torch.no_grad():
   for batch in tqdm(test_loader, desc="Test"):
       #print(batch)
       full_emb = batch["full_emb"].to(device)
       full_delta = batch["full_delta"].to(device)
       full_mask = batch["full_mask"].to(device)
       full_adm_emb = model(full_emb, full_delta, full_mask)  # (B, D)
       all_full_embeddings.append(full_adm_emb.cpu().numpy())
       pref_emb = batch["prefix_emb"].to(device)
       pref_delta = batch["prefix_delta"].to(device)
       pref_mask = batch["prefix_mask"].to(device)
       pref_adm_emb = model(pref_emb, pref_delta, pref_mask)  # (B, D)
       all_pref_embeddings.append(pref_adm_emb.cpu().numpy())
       all_hadm_ids.extend(batch["hadm_id"])
all_full_embeddings = np.vstack(all_full_embeddings)
all_pref_embeddings = np.vstack(all_pref_embeddings)

Test: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 38/38 [02:55<00:00,  4.63s/it]


In [123]:
df_embeddings = pd.DataFrame({
   "HADM_ID": all_hadm_ids,
   "full_admission_embedding": all_full_embeddings.tolist()
})
df_embeddings.to_csv(
   "temporal_admission_embeddings.csv",
   index=False
)
print(f"Saved embeddings for {len(df_embeddings)} admissions.")

Saved embeddings for 1211 admissions.


In [121]:
len(all_full_embeddings[0].tolist())

2424

In [122]:
all_full_embeddings[0].tolist()

[-0.4708498418331146,
 -0.5475208759307861,
 0.28125613927841187,
 0.516147792339325,
 0.36070409417152405,
 -0.04363811016082764,
 0.9618723392486572,
 0.3940972089767456,
 2.942169666290283,
 1.1325111389160156,
 -0.9110435247421265,
 -0.5094019770622253,
 0.06209389865398407,
 0.30523550510406494,
 -0.14319802820682526,
 -1.0020883083343506,
 0.07837782055139542,
 1.7236154079437256,
 -0.23076598346233368,
 0.04637850075960159,
 0.5555213093757629,
 -0.9091050028800964,
 0.34936612844467163,
 0.40949851274490356,
 0.5398333668708801,
 -0.6962534785270691,
 1.7707552909851074,
 -0.9944214224815369,
 -0.9976385235786438,
 -0.011357879266142845,
 -0.3703729510307312,
 0.12053392082452774,
 -0.17379771173000336,
 -0.2549597918987274,
 -0.5007861852645874,
 -2.7270724773406982,
 0.7572588920593262,
 -1.2707879543304443,
 1.1494656801223755,
 -1.1632790565490723,
 -0.7993175387382507,
 -0.724682629108429,
 -2.9457409381866455,
 -2.882647752761841,
 0.8837493062019348,
 -0.1631938517093658