In [None]:
# ----------------------------
# 1) Dataset: per-split lab percentiles and their missingness indicators
# ----------------------------
from scipy.signal import medfilt, iirnotch, filtfilt, butter, resample
import numpy as np
import torch
import os
from torch.utils.data import Dataset, DataLoader

class LabPercentiles(Dataset):
    """Dataset over the lab percentile and lab missingness arrays of one split.

    Reads ``labs_percentiles_{split}.npy`` and ``labs_missingness_{split}.npy``
    from ``<data_npy_root>/<split>/`` and serves them one row at a time.
    """

    def __init__(self, data_npy_root, split):
        """Load both arrays for ``split`` from below ``data_npy_root``."""
        folder = os.path.join(data_npy_root, split)
        # Both arrays are loaded eagerly and kept on the instance.
        self.lab = np.load(os.path.join(folder, f"labs_percentiles_{split}.npy"))
        self.lab_missingness = np.load(
            os.path.join(folder, f"labs_missingness_{split}.npy")
        )

    def __len__(self):
        """Return the size of the first axis of the percentile array."""
        n_rows = self.lab.shape[0]
        return n_rows

    def __getitem__(self, idx):
        """Return ``(lab, lab_missingness)`` at ``idx``, each as a float tensor."""
        row_values = self.lab[idx]
        row_missing = self.lab_missingness[idx]
        return torch.tensor(row_values).float(), torch.tensor(row_missing).float()

# One LabPercentiles dataset per split, all read from the same data_npy root.
_LAB_DATA_NPY_ROOT = "../../../scratch/physionet.org/files/symile-mimic/1.0.0/data_npy"
train_lab_ds, val_lab_ds, test_lab_ds = [
    LabPercentiles(data_npy_root=_LAB_DATA_NPY_ROOT, split=split_name)
    for split_name in ("train", "val", "test")
]

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning.pytorch as pl
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW


class LabsDenoisingAE(nn.Module):
    """
    Denoising autoencoder for lab percentiles that is aware of the mask.

    The encoder receives the values ``x`` and the mask ``m`` concatenated along
    dim 1; the decoder maps the latent code back to ``input_dim`` outputs that
    pass through a sigmoid, i.e. it reconstructs ``x`` only.
    """

    def __init__(self, input_dim: int, latent_dim: int = 256):
        super().__init__()
        hidden_dim = 512

        # The first layer is twice as wide as x because x and m are concatenated.
        encoder_layers = [
            nn.Linear(2 * input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x, m):
        """Return the latent code for values ``x`` and mask ``m``."""
        joint = torch.cat([x, m], dim=1)
        return self.encoder(joint)

    def forward(self, x, m):
        """Return ``(z, x_hat)``: the latent code and the reconstruction decoded from it."""
        z = self.encode(x, m)
        return z, self.decoder(z)
    
# Load the trained LabsDenoisingAE weights from the checkpoint.
model = LabsDenoisingAE(input_dim=50, latent_dim=256)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source; consider weights_only=True if the
# checkpoint's contents allow it.
checkpoint = torch.load("../src/epoch=49-step=1950.ckpt", map_location=torch.device('cpu'))
state = checkpoint['state_dict']

# NOTE(review): the `cxr_` variable names below do not match this lab model; they
# are kept unchanged in case a later cell refers to them.

# ---- 1. Keep only the entries stored under the "model." prefix ----
ckpt_prefix = "model."
cxr_state = {k: v for k, v in state.items() if k.startswith(ckpt_prefix)}

# ---- 2. Strip the leading "model." so the keys line up with LabsDenoisingAE ----
# Slice the prefix off rather than using str.replace: replace would also delete
# any later "model." occurrence inside a key (e.g. "model.submodel.w").
cxr_state_stripped = {k[len(ckpt_prefix):]: v for k, v in cxr_state.items()}

# strict=False tolerates extra checkpoint entries, but a missing key would leave
# part of the model randomly initialised, so fail loudly in that case.
load_result = model.load_state_dict(cxr_state_stripped, strict=False)
if load_result.missing_keys:
    raise RuntimeError(
        f"Checkpoint is missing weights for: {load_result.missing_keys}"
    )
# Keep the load result as the last expression so the cell still displays it.
load_result

  checkpoint = torch.load("../../epoch=49-step=1950.ckpt", map_location=torch.device('cpu'))


<All keys matched successfully>

In [5]:
# Use CUDA when it is available and fall back to the CPU otherwise.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
model.to(device)
# eval() stays as the cell's final expression, so its return value is displayed.
model.eval()

LabsDenoisingAE(
  (encoder): Sequential(
    (0): Linear(in_features=100, out_features=512, bias=True)
    (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=512, out_features=256, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=512, out_features=50, bias=True)
    (3): Sigmoid()
  )
)

In [6]:
from torch.utils.data import DataLoader

# shuffle=False on every split: batches come out in dataset order.
train_loader, val_loader, test_loader = [
    DataLoader(lab_ds, batch_size=64, shuffle=False)
    for lab_ds in (train_lab_ds, val_lab_ds, test_lab_ds)
]

In [None]:
import os
import torch
import numpy as np
from tqdm import tqdm


@torch.no_grad()
def extract_and_save_features(
    model,
    loader,
    device,
    save_path,
):
    """Run ``model`` over every batch of ``loader`` and save the features to ``save_path``.

    Args:
        model: Module called as ``model(pct, missing)``; the first element of its
            return value is taken as the feature tensor. It is put in eval mode.
        loader: Iterable yielding ``(pct, missing)`` tensor pairs. Features are
            concatenated in iteration order.
        device: Device each batch is moved to before the forward pass. The model
            is not moved here, so it must already be on ``device``.
        save_path: Destination passed to ``np.save``. Missing parent directories
            are created.

    Returns:
        None. The float32 feature array is written to disk and its shape printed.
    """
    model.eval()

    # Create the output directory before the loop so a missing directory does not
    # surface only after the whole extraction has run.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    all_feats = []

    for x in tqdm(loader, desc=f"Extracting → {os.path.basename(save_path)}"):
        pct, missing = x
        pct = pct.to(device)
        missing = missing.to(device)

        # Only the first output is kept; the second one is discarded.
        z, _ = model(pct, missing)

        if z.ndim == 3:
            z = z.mean(dim=1)                # average 3-D features over dim 1

        # Move each batch off the device so results accumulate in host memory.
        all_feats.append(z.cpu())

    feats = torch.cat(all_feats, dim=0)       # (N, D): all batches stacked on dim 0
    feats = feats.numpy().astype(np.float32)

    np.save(save_path, feats)
    print(f"Saved {feats.shape} → {save_path}")
    
root = "../../../scratch/physionet.org/files/symile-mimic/1.0.0/data_npy"

# Write labs_features.npy into each split's directory, in train / val / test order.
for split_name, split_loader in (
    ("train", train_loader),
    ("val", val_loader),
    ("test", test_loader),
):
    extract_and_save_features(
        model,
        split_loader,
        device,
        save_path=os.path.join(root, split_name, "labs_features.npy"),
    )


Extracting → labs_features.npy: 100%|██████████| 157/157 [00:35<00:00,  4.41it/s]


Saved (10000, 256) → ../../../scratch/1.0.0/data_npy/train/labs_features.npy


Extracting → labs_features.npy: 100%|██████████| 12/12 [00:02<00:00,  4.05it/s]


Saved (750, 256) → ../../../scratch/1.0.0/data_npy/val/labs_features.npy


Extracting → labs_features.npy: 100%|██████████| 73/73 [00:36<00:00,  1.98it/s]

Saved (4640, 256) → ../../../scratch/1.0.0/data_npy/test/labs_features.npy



