In [None]:
import pandas as pd
from collections import Counter
import json
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

df = pd.read_parquet("los_dataset_24h.parquet")

print(df.shape)
print(df.columns)

(424803, 17)
Index(['subject_id', 'hadm_id', 'admittime', 'dischtime', 'race', 'los_hours',
       'gender', 'anchor_age', 'curr_service', 'hcpcs_cd_list',
       'diagnoses_icd_code_list', 'procedures_icd_code_list', 'drg_code',
       'drg_severity', 'drg_mortality', 'medication_list', 'order_type_list'],
      dtype='object')


In [None]:
def build_label_encoder(series):
    classes = sorted(series.unique())
    stoi = {c: i for i, c in enumerate(classes)}
    return stoi

def build_vocab_from_list_column(df, col, min_freq=1, add_unk=True):
    counter = Counter()
    for lst in df[col]:
        counter.update(lst)

    stoi = {}
    idx = 0
    if add_unk:
        stoi["<UNK>"] = idx
        idx += 1

    for token, freq in counter.items():
        if freq >= min_freq:
            if token not in stoi:
                stoi[token] = idx
                idx += 1

    return stoi

In [3]:
# Encoder for gender, race, curr_service, drg_code
gender_stoi = build_label_encoder(df["gender"])
race_stoi = build_label_encoder(df["race"])
service_stoi = build_label_encoder(df["curr_service"])

# drg_code is already an int, but treated as a "category" and re-mapped to an index
drg_code_stoi = build_label_encoder(df["drg_code"].astype(int))

print("num_genders:", len(gender_stoi))
print("num_races:", len(race_stoi))
print("num_services:", len(service_stoi))
print("num_drg_codes:", len(drg_code_stoi))

num_genders: 2
num_races: 33
num_services: 21
num_drg_codes: 301


In [4]:
list_cols = [
    "diagnoses_icd_code_list",
    "procedures_icd_code_list",
    "hcpcs_cd_list",
    "medication_list",
    "order_type_list",
]

diag_stoi = build_vocab_from_list_column(df, "diagnoses_icd_code_list", min_freq=1)
proc_stoi = build_vocab_from_list_column(df, "procedures_icd_code_list", min_freq=1)
hcpcs_stoi = build_vocab_from_list_column(df, "hcpcs_cd_list", min_freq=1)
med_stoi = build_vocab_from_list_column(df, "medication_list", min_freq=1)
order_stoi = build_vocab_from_list_column(df, "order_type_list", min_freq=1)

print("diag vocab size:", len(diag_stoi))
print("proc vocab size:", len(proc_stoi))
print("hcpcs vocab size:", len(hcpcs_stoi))
print("med vocab size:", len(med_stoi))
print("order vocab size:", len(order_stoi))

diag vocab size: 27701
proc vocab size: 12184
hcpcs vocab size: 1925
med vocab size: 3358
order vocab size: 17


In [5]:
# Combine proc_stoi and hcpcs_stoi into a single unified vocab
proc_all_stoi = {}
idx = 0

# Only one UNK token
proc_all_stoi["<UNK>"] = idx
idx += 1

# Procedures first
for k in proc_stoi.keys():
    if k == "<UNK>":
        continue
    proc_all_stoi["PROC_" + k] = idx
    idx += 1

# HCPCS next
for k in hcpcs_stoi.keys():
    if k == "<UNK>":
        continue
    key = "HCPCS_" + k
    if key not in proc_all_stoi:
        proc_all_stoi[key] = idx
        idx += 1

print("combined proc vocab size:", len(proc_all_stoi))

combined proc vocab size: 14108


In [6]:
UNK_DIAG = diag_stoi["<UNK>"]
UNK_PROC = proc_all_stoi["<UNK>"]
UNK_MED = med_stoi["<UNK>"]
UNK_ORDER = order_stoi["<UNK>"]

def map_list_to_ids(lst, stoi, unk_token="<UNK>"):
    unk_idx = stoi.get(unk_token, None)
    out = []
    for x in lst:
        idx = stoi.get(x)
        if idx is None:
            if unk_idx is not None:
                out.append(unk_idx)
        else:
            out.append(idx)
    return out

# Diagnostic codes
df["diag_ids"] = df["diagnoses_icd_code_list"].apply(
    lambda lst: map_list_to_ids(lst, diag_stoi)
)

# Combined procedure + hcpcs IDs
def build_proc_ids(row):
    ids = []
    for code in row["procedures_icd_code_list"]:
        tok = "PROC_" + code
        ids.append(proc_all_stoi.get(tok, UNK_PROC))
    for code in row["hcpcs_cd_list"]:
        tok = "HCPCS_" + code
        ids.append(proc_all_stoi.get(tok, UNK_PROC))
    return ids

df["proc_ids"] = df.apply(build_proc_ids, axis=1)

# medication
df["med_ids"] = df["medication_list"].apply(
    lambda lst: map_list_to_ids(lst, med_stoi)
)

# order_type
df["order_ids"] = df["order_type_list"].apply(
    lambda lst: map_list_to_ids(lst, order_stoi)
)

# Single categorical -> id
df["gender_id"] = df["gender"].map(gender_stoi)
df["race_id"] = df["race"].map(race_stoi)
df["service_id"] = df["curr_service"].map(service_stoi)
df["drg_code_id"] = df["drg_code"].astype(int).map(drg_code_stoi)

In [None]:
def summarize_list(lst, max_len=10):
    if lst is None:
        return None
    if len(lst) <= max_len:
        return lst
    return lst[:max_len] + ["...(+{} more)".format(len(lst) - max_len)]


# Output 3 samples
cols_to_show = [
    "gender", "race", "curr_service", "drg_code",
    "gender_id", "race_id", "service_id", "drg_code_id",
    "diagnoses_icd_code_list", "diag_ids",
    "procedures_icd_code_list", "hcpcs_cd_list", "proc_ids",
    "medication_list", "med_ids",
    "order_type_list", "order_ids",
]

print("\n================ SAMPLE DATA (after encoding) ==============\n")

for i in range(3):
    row = df.iloc[i]
    print(f"--- Row {i} ---")

    for c in cols_to_show:
        val = row[c]

        # Summarize list
        if isinstance(val, list):
            val = summarize_list(val)

        print(f"{c:25} : {val}")

    print("\n")



--- Row 0 ---
gender                    : F
race                      : WHITE
curr_service              : MED
drg_code                  : 279
gender_id                 : 0
race_id                   : 28
service_id                : 7
drg_code_id               : 135
diagnoses_icd_code_list   : ['07071' '78959' '2875' '2761' '496' '5715' 'V08' '3051']
diag_ids                  : [1, 2, 3, 4, 5, 6, 7, 8]
procedures_icd_code_list  : ['5491']
hcpcs_cd_list             : []
proc_ids                  : [1]
medication_list           : ['Raltegravir' 'Rifaximin' 'Sodium Chloride 0.9%  Flush'
 'Calcium Carbonate' 'Rifaximin' 'Raltegravir'
 'Emtricitabine-Tenofovir (Truvada)' 'Sulfameth/Trimethoprim DS'
 'Furosemide' 'Tiotropium Bromide' 'Albuterol Inhaler' 'Lactulose'
 'Heparin' 'Sodium Chloride 0.9%  Flush' 'Acetaminophen' 'Heparin'
 'Lactulose' 'Albumin 25% (12.5g / 50mL)' 'Sodium Chloride 0.9%  Flush'
 'Albumin 25% (12.5g / 50mL)' 'Albumin 25% (12.5g / 50mL)']
med_ids                   : [1,

In [8]:
class LOSDataset(Dataset):
    def __init__(self, df, use_log_target=True):
        self.df = df.reset_index(drop=True)
        self.use_log_target = use_log_target

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        sample = {
            # tabular features
            "age": float(row["anchor_age"]),
            "gender_id": int(row["gender_id"]),
            "race_id": int(row["race_id"]),
            "service_id": int(row["service_id"]),
            "drg_code_id": int(row["drg_code_id"]),
            "drg_severity": float(row["drg_severity"]),
            "drg_mortality": float(row["drg_mortality"]),

            # list ids (codes)
            "diag_ids": row["diag_ids"],
            "proc_ids": row["proc_ids"],
            "med_ids": row["med_ids"],
            "order_ids": row["order_ids"],
        }

        # target variable
        los = float(row["los_hours"])
        if self.use_log_target:
            sample["target"] = np.log1p(los)
        else:
            sample["target"] = los

        return sample

In [None]:
def los_collate_fn(batch):
    B = len(batch)

    # ----- Tabular stack -----
    age = torch.tensor([b["age"] for b in batch], dtype=torch.float32)
    gender_id = torch.tensor([b["gender_id"] for b in batch], dtype=torch.long)
    race_id = torch.tensor([b["race_id"] for b in batch], dtype=torch.long)
    service_id = torch.tensor([b["service_id"] for b in batch], dtype=torch.long)
    drg_code_id = torch.tensor([b["drg_code_id"] for b in batch], dtype=torch.long)
    drg_severity = torch.tensor([b["drg_severity"] for b in batch], dtype=torch.float32)
    drg_mortality = torch.tensor([b["drg_mortality"] for b in batch], dtype=torch.float32)

    target = torch.tensor([b["target"] for b in batch], dtype=torch.float32)

    # ----- List-type: diag / proc / med / order -----
    def build_bag_inputs(key):
        codes_all = []
        offsets = [0]
        for b in batch:
            ids = b[key]
            codes_all.extend(ids)
            offsets.append(len(codes_all))
        if len(codes_all) == 0:
            # In case the list is empty for all admissions
            codes_tensor = torch.empty(0, dtype=torch.long)
        else:
            codes_tensor = torch.tensor(codes_all, dtype=torch.long)
        offsets_tensor = torch.tensor(offsets, dtype=torch.long)
        return codes_tensor, offsets_tensor

    diag_codes, diag_offsets = build_bag_inputs("diag_ids")
    proc_codes, proc_offsets = build_bag_inputs("proc_ids")
    med_codes, med_offsets = build_bag_inputs("med_ids")
    order_codes, order_offsets = build_bag_inputs("order_ids")

    batch_out = {
        "age": age,
        "gender_id": gender_id,
        "race_id": race_id,
        "service_id": service_id,
        "drg_code_id": drg_code_id,
        "drg_severity": drg_severity,
        "drg_mortality": drg_mortality,
        "diag_codes": diag_codes,
        "diag_offsets": diag_offsets,
        "proc_codes": proc_codes,
        "proc_offsets": proc_offsets,
        "med_codes": med_codes,
        "med_offsets": med_offsets,
        "order_codes": order_codes,
        "order_offsets": order_offsets,
        "target": target,
    }

    return batch_out

In [10]:
from torch.utils.data import random_split, DataLoader
import torch

# Generate Dataset
dataset = LOSDataset(df, use_log_target=True)

n_total = len(dataset)
n_train = int(n_total * 0.7)
n_val = int(n_total * 0.15)
n_test = n_total - n_train - n_val

# Use generator for reproducibility (optional)
g = torch.Generator().manual_seed(42)
train_ds, val_ds, test_ds = random_split(dataset, [n_train, n_val, n_test], generator=g)

train_loader = DataLoader(
    train_ds,
    batch_size=256,
    shuffle=True,
    collate_fn=los_collate_fn,
)

val_loader = DataLoader(
    val_ds,
    batch_size=256,
    shuffle=False,
    collate_fn=los_collate_fn,
)

test_loader = DataLoader(
    test_ds,
    batch_size=256,
    shuffle=False,
    collate_fn=los_collate_fn,
)

print(f"#train: {len(train_ds)}, #val: {len(val_ds)}, #test: {len(test_ds)}")

#train: 297362, #val: 63720, #test: 63721


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


@dataclass
class ModelConfig:
    # ----- Tabular categorical vocab sizes -----
    num_genders: int          # e.g. {"M", "F"} -> 2
    num_races: int            # unique race categories
    num_services: int         # curr_service (MED, ORTHO, ...)
    num_drg_codes: int        # drg_code vocab size

    # ----- Code vocab sizes -----
    diag_vocab_size: int      # diagnoses_icd_code_list vocab
    proc_vocab_size: int      # procedures_icd_code_list + hcpcs_cd_list unified vocab
    med_vocab_size: int       # medication_list vocab
    order_vocab_size: int     # order_type_list vocab

    # ----- Embedding dimensions -----
    emb_dim_gender: int = 4
    emb_dim_race: int = 8
    emb_dim_service: int = 8
    emb_dim_drg: int = 16

    emb_dim_diag: int = 32
    emb_dim_proc: int = 32
    emb_dim_med: int = 32
    emb_dim_order: int = 16

    # ----- Hidden dimensions -----
    tabular_hidden_dim: int = 64
    diag_hidden_dim: int = 64
    proc_hidden_dim: int = 64
    med_hidden_dim: int = 64
    order_hidden_dim: int = 32

    fusion_hidden_dim: int = 128
    dropout: float = 0.2


class MultiModalLOSModel(nn.Module):
    """
    Multi-branch model for LOS prediction.

    Branches:
      - Tabular branch: age, gender, race, drg, severity, mortality, curr_service
      - Diagnostic branch: diagnoses_icd_code_list
      - Procedure branch: procedures_icd_code_list + hcpcs_cd_list
      - Medication branch: medication_list
      - Order-type branch: order_type_list

    Fusion:
      concat([h_tab, h_diag, h_proc, h_med, h_order]) -> Attention MLP -> Fusion MLP -> scalar LOS prediction
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg

        # ------------- Tabular embeddings -------------
        self.gender_emb = nn.Embedding(cfg.num_genders, cfg.emb_dim_gender)
        self.race_emb = nn.Embedding(cfg.num_races, cfg.emb_dim_race)
        self.service_emb = nn.Embedding(cfg.num_services, cfg.emb_dim_service)
        self.drg_emb = nn.Embedding(cfg.num_drg_codes, cfg.emb_dim_drg)

        # Tabular MLP input dim:
        #   age(1) + severity(1) + mortality(1)
        # + gender_emb + race_emb + service_emb + drg_emb
        tab_in_dim = (
            3
            + cfg.emb_dim_gender
            + cfg.emb_dim_race
            + cfg.emb_dim_service
            + cfg.emb_dim_drg
        )

        self.tabular_mlp = nn.Sequential(
            nn.Linear(tab_in_dim, cfg.tabular_hidden_dim),
            nn.ReLU(),
            nn.Dropout(cfg.dropout),
            nn.Linear(cfg.tabular_hidden_dim, cfg.tabular_hidden_dim),
            nn.ReLU(),
        )

        # ------------- Code embeddings (EmbeddingBag: bag-of-codes) -------------
        self.diag_emb = nn.EmbeddingBag(
            cfg.diag_vocab_size, cfg.emb_dim_diag,
            mode="mean", include_last_offset=True
        )
        self.proc_emb = nn.EmbeddingBag(
            cfg.proc_vocab_size, cfg.emb_dim_proc,
            mode="mean", include_last_offset=True
        )
        self.med_emb = nn.EmbeddingBag(
            cfg.med_vocab_size, cfg.emb_dim_med,
            mode="mean", include_last_offset=True
        )
        self.order_emb = nn.EmbeddingBag(
            cfg.order_vocab_size, cfg.emb_dim_order,
            mode="mean", include_last_offset=True
        )

        # Small projection MLP for each branch
        self.diag_mlp = nn.Sequential(
            nn.Linear(cfg.emb_dim_diag, cfg.diag_hidden_dim),
            nn.ReLU(),
            nn.Dropout(cfg.dropout),
        )
        self.proc_mlp = nn.Sequential(
            nn.Linear(cfg.emb_dim_proc, cfg.proc_hidden_dim),
            nn.ReLU(),
            nn.Dropout(cfg.dropout),
        )
        self.med_mlp = nn.Sequential(
            nn.Linear(cfg.emb_dim_med, cfg.med_hidden_dim),
            nn.ReLU(),
            nn.Dropout(cfg.dropout),
        )
        self.order_mlp = nn.Sequential(
            nn.Linear(cfg.emb_dim_order, cfg.order_hidden_dim),
            nn.ReLU(),
            nn.Dropout(cfg.dropout),
        )

        # ------------- Attention mechanism -------------
        # fusion_in_dim is the concatenated size of all branch hidden dimensions
        fusion_in_dim = (
            cfg.tabular_hidden_dim
            + cfg.diag_hidden_dim
            + cfg.proc_hidden_dim
            + cfg.med_hidden_dim
            + cfg.order_hidden_dim
        )
        self.attention_mlp = nn.Linear(fusion_in_dim, 5) # 5 branches

        # ------------- Fusion MLP -------------
        # Fusion MLP input dim is still the sum of hidden dims (after scaling)
        self.fusion_mlp = nn.Sequential(
            nn.Linear(fusion_in_dim, cfg.fusion_hidden_dim),
            nn.ReLU(),
            nn.Dropout(cfg.dropout),
            nn.Linear(cfg.fusion_hidden_dim, cfg.fusion_hidden_dim),
            nn.ReLU(),
            nn.Dropout(cfg.dropout),
        )

        # Final regression output layer (LOS prediction)
        self.out = nn.Linear(cfg.fusion_hidden_dim, 1)

    # ------------------------------------------------------------------
    # Forward: batch input format
    # ------------------------------------------------------------------
    def forward(
        self,
        # ----- Tabular -----
        age,               # (B,) float tensor (anchor_age or normalized)
        gender_idx,        # (B,) long tensor
        race_idx,          # (B,) long tensor
        service_idx,       # (B,) long tensor (curr_service)
        drg_code_idx,      # (B,) long tensor
        drg_severity,      # (B,) float or long (recommend normalizing to float)
        drg_mortality,     # (B,) float or long

        # ----- Diagnoses (EmbeddingBag) -----
        diag_codes,        # (N_diag_codes,) long tensor (flattened)
        diag_offsets,      # (B+1,) long tensor, EmbeddingBag offset

        # ----- Procedures (EmbeddingBag) -----
        proc_codes,        # (N_proc_codes,) long tensor
        proc_offsets,      # (B+1,) long tensor

        # ----- Medications (EmbeddingBag) -----
        med_codes,         # (N_med_codes,) long tensor
        med_offsets,       # (B+1,) long tensor

        # ----- Order types (EmbeddingBag) -----
        order_codes,       # (N_order_codes,) long tensor
        order_offsets,     # (B+1,) long tensor
    ):
        # --------- 1. Tabular branch ---------
        # Embeddings
        g_emb = self.gender_emb(gender_idx)   # (B, emb_dim_gender)
        r_emb = self.race_emb(race_idx)       # (B, emb_dim_race)
        s_emb = self.service_emb(service_idx) # (B, emb_dim_service)
        d_emb = self.drg_emb(drg_code_idx)    # (B, emb_dim_drg)

        # Cast continuous features to float
        age = age.float().unsqueeze(-1)                 # (B, 1)
        sev = drg_severity.float().unsqueeze(-1)        # (B, 1)
        mort = drg_mortality.float().unsqueeze(-1)      # (B, 1)

        tabular_feat = torch.cat(
            [age, sev, mort, g_emb, r_emb, s_emb, d_emb],
            dim=-1
        )  # (B, tab_in_dim)

        h_tab = self.tabular_mlp(tabular_feat)  # (B, tabular_hidden_dim)

        # --------- 2. Diagnostic branch ---------
        diag_bag = self.diag_emb(diag_codes, diag_offsets)  # (B, emb_dim_diag)
        h_diag = self.diag_mlp(diag_bag)  # (B, diag_hidden_dim)

        # --------- 3. Procedure branch ---------
        proc_bag = self.proc_emb(proc_codes, proc_offsets)  # (B, emb_dim_proc)
        h_proc = self.proc_mlp(proc_bag)  # (B, proc_hidden_dim)

        # --------- 4. Medication branch ---------
        med_bag = self.med_emb(med_codes, med_offsets)  # (B, emb_dim_med)
        h_med = self.med_mlp(med_bag)  # (B, med_hidden_dim)

        # --------- 5. Order-type branch ---------
        order_bag = self.order_emb(order_codes, order_offsets)  # (B, emb_dim_order)
        h_order = self.order_mlp(order_bag)  # (B, order_hidden_dim)

        # --------- 6. Apply Self-Attention Mechanism ---------
        # Concatenate hidden representations for attention input
        h_pre_attention = torch.cat([h_tab, h_diag, h_proc, h_med, h_order], dim=-1)

        # Compute raw attention scores
        alpha_raw = self.attention_mlp(h_pre_attention) # (B, 5)

        # Apply softmax to get normalized attention weights
        alpha_weights = F.softmax(alpha_raw, dim=-1) # (B, 5)

        # Split alpha_weights into individual branch weights
        alpha_tab = alpha_weights[:, 0].unsqueeze(-1)
        alpha_diag = alpha_weights[:, 1].unsqueeze(-1)
        alpha_proc = alpha_weights[:, 2].unsqueeze(-1)
        alpha_med = alpha_weights[:, 3].unsqueeze(-1)
        alpha_order = alpha_weights[:, 4].unsqueeze(-1)

        # Scale each branch's hidden representation by its attention weight
        h_tab_scaled = h_tab * alpha_tab
        h_diag_scaled = h_diag * alpha_diag
        h_proc_scaled = h_proc * alpha_proc
        h_med_scaled = h_med * alpha_med
        h_order_scaled = h_order * alpha_order

        # --------- 7. Fusion ---------
        h = torch.cat([h_tab_scaled, h_diag_scaled, h_proc_scaled, h_med_scaled, h_order_scaled], dim=-1)
        h = self.fusion_mlp(h)
        out = self.out(h).squeeze(-1)  # (B,)

        # out = Predicted LOS (can directly use hours, or design to predict log1p)
        return out

In [None]:
import torch
import torch.nn as nn

# 1. Model / Configuration Generation

cfg = ModelConfig(
    num_genders=len(gender_stoi),
    num_races=len(race_stoi),
    num_services=len(service_stoi),
    num_drg_codes=len(drg_code_stoi),
    diag_vocab_size=len(diag_stoi),
    proc_vocab_size=len(proc_all_stoi),
    med_vocab_size=len(med_stoi),
    order_vocab_size=len(order_stoi),
)

model = MultiModalLOSModel(cfg)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.MSELoss()  # MSE based on log(1+LOS)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

NUM_EPOCHS = 20 

best_val_loss = float("inf")
best_model_path = "los_multibranch_attention_best.pt" 
print("Using device:", device)

# 2. Training Loop
for epoch in range(1, NUM_EPOCHS + 1):
    # ---------- Train ----------
    model.train()
    train_loss_sum = 0.0
    train_count = 0

    print(f"\n========== Epoch {epoch}/{NUM_EPOCHS} ==========")

    for batch_idx, batch in enumerate(train_loader):
        # Move tensors to device
        batch = {k: (v.to(device) if torch.is_tensor(v) else v)
                 for k, v in batch.items()}

        optimizer.zero_grad()

        y_pred = model(
            age=batch["age"],
            gender_idx=batch["gender_id"],
            race_idx=batch["race_id"],
            service_idx=batch["service_id"],
            drg_code_idx=batch["drg_code_id"],
            drg_severity=batch["drg_severity"],
            drg_mortality=batch["drg_mortality"],
            diag_codes=batch["diag_codes"],
            diag_offsets=batch["diag_offsets"],
            proc_codes=batch["proc_codes"],
            proc_offsets=batch["proc_offsets"],
            med_codes=batch["med_codes"],
            med_offsets=batch["med_offsets"],
            order_codes=batch["order_codes"],
            order_offsets=batch["order_offsets"],
        )

        y_true = batch["target"]  # log(1+LOS)
        loss = criterion(y_pred, y_true)

        loss.backward()
        optimizer.step()

        bs = y_true.size(0)
        train_loss_sum += loss.item() * bs
        train_count += bs

        if batch_idx % 100 == 0:
            avg_loss = train_loss_sum / train_count
            print(f"  [Epoch {epoch} | Step {batch_idx}/{len(train_loader)}] "
                  f"AvgTrainLoss={avg_loss:.4f}")

    train_loss = train_loss_sum / train_count

    # ---------- Validation ----------
    model.eval()
    val_loss_sum = 0.0
    val_count = 0
    val_mae_hours_sum = 0.0  # MAE based on actual LOS (hours)

    with torch.no_grad():
        for batch in val_loader:
            batch = {k: (v.to(device) if torch.is_tensor(v) else v)
                     for k, v in batch.items()}

            y_pred = model(
                age=batch["age"],
                gender_idx=batch["gender_id"],
                race_idx=batch["race_id"],
                service_idx=batch["service_id"],
                drg_code_idx=batch["drg_code_id"],
                drg_severity=batch["drg_severity"],
                drg_mortality=batch["drg_mortality"],
                diag_codes=batch["diag_codes"],
                diag_offsets=batch["diag_offsets"],
                proc_codes=batch["proc_codes"],
                proc_offsets=batch["proc_offsets"],
                med_codes=batch["med_codes"],
                med_offsets=batch["med_offsets"],
                order_codes=batch["order_codes"],
                order_offsets=batch["order_offsets"],
            )

            y_true = batch["target"]

            loss = criterion(y_pred, y_true)

            bs = y_true.size(0)
            val_loss_sum += loss.item() * bs
            val_count += bs

            # Convert log(1+LOS) -> actual LOS (hours) and calculate MAE
            y_true_hours = torch.expm1(y_true)
            y_pred_hours = torch.expm1(y_pred)

            mae_hours = torch.abs(y_pred_hours - y_true_hours).sum().item()
            val_mae_hours_sum += mae_hours

    val_loss = val_loss_sum / val_count
    val_mae_hours = val_mae_hours_sum / val_count

    print(f"[Epoch {epoch:03d}] "
          f"train_loss(log-MSE)={train_loss:.4f} | "
          f"val_loss(log-MSE)={val_loss:.4f} | "
          f"val_MAE(hours)={val_mae_hours:.2f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"  \u21B3 Best model updated, saved to {best_model_path}")

print("Training finished. Best val_loss:", best_val_loss)

Using device: cuda

  [Epoch 1 | Step 0/1162] AvgTrainLoss=22.0654
  [Epoch 1 | Step 100/1162] AvgTrainLoss=3.8523
  [Epoch 1 | Step 200/1162] AvgTrainLoss=2.4147
  [Epoch 1 | Step 300/1162] AvgTrainLoss=1.8707
  [Epoch 1 | Step 400/1162] AvgTrainLoss=1.5802
  [Epoch 1 | Step 500/1162] AvgTrainLoss=1.3983
  [Epoch 1 | Step 600/1162] AvgTrainLoss=1.2692
  [Epoch 1 | Step 700/1162] AvgTrainLoss=1.1789
  [Epoch 1 | Step 800/1162] AvgTrainLoss=1.1078
  [Epoch 1 | Step 900/1162] AvgTrainLoss=1.0511
  [Epoch 1 | Step 1000/1162] AvgTrainLoss=1.0012
  [Epoch 1 | Step 1100/1162] AvgTrainLoss=0.9592
[Epoch 001] train_loss(log-MSE)=0.9361 | val_loss(log-MSE)=0.4046 | val_MAE(hours)=73.24
  ↳ Best model updated, saved to los_multibranch_attention_best.pt

  [Epoch 2 | Step 0/1162] AvgTrainLoss=0.4641
  [Epoch 2 | Step 100/1162] AvgTrainLoss=0.4890
  [Epoch 2 | Step 200/1162] AvgTrainLoss=0.4801
  [Epoch 2 | Step 300/1162] AvgTrainLoss=0.4727
  [Epoch 2 | Step 400/1162] AvgTrainLoss=0.4672
  [Epoch

In [None]:
# 3. Calculate MAE (hours) on Test set

best_model = MultiModalLOSModel(cfg).to(device)
best_model.load_state_dict(torch.load(best_model_path, map_location=device))
best_model.eval()

test_abs_error_sum = 0.0
test_count = 0

with torch.no_grad():
    for batch in test_loader:
        batch = {k: (v.to(device) if torch.is_tensor(v) else v)
                 for k, v in batch.items()}

        y_pred = best_model(
            age=batch["age"],
            gender_idx=batch["gender_id"],
            race_idx=batch["race_id"],
            service_idx=batch["service_id"],
            drg_code_idx=batch["drg_code_id"],
            drg_severity=batch["drg_severity"],
            drg_mortality=batch["drg_mortality"],
            diag_codes=batch["diag_codes"],
            diag_offsets=batch["diag_offsets"],
            proc_codes=batch["proc_codes"],
            proc_offsets=batch["proc_offsets"],
            med_codes=batch["med_codes"],
            med_offsets=batch["med_offsets"],
            order_codes=batch["order_codes"],
            order_offsets=batch["order_offsets"],
        )

        y_true = batch["target"]  # log(1+LOS)

        # Convert log(1+LOS) -> actual hours
        y_true_hours = torch.expm1(y_true)
        y_pred_hours = torch.expm1(y_pred)

        abs_err = torch.abs(y_pred_hours - y_true_hours)
        test_abs_error_sum += abs_err.sum().item()
        test_count += y_true.size(0)

test_mae_hours = test_abs_error_sum / test_count

print(f"\n===== Test set MAE =====")
print(f"Test MAE (hours): {test_mae_hours:.2f}")
print(f"Test MAE (days) : {test_mae_hours / 24:.2f}")


===== Test set MAE =====
Test MAE (hours): 57.73
Test MAE (days) : 2.41


In [None]:
import numpy as np

best_model = MultiModalLOSModel(cfg).to(device)
best_model.load_state_dict(torch.load(best_model_path, map_location=device))
best_model.eval()

all_attention_weights = {
    "tabular": [],
    "diag": [],
    "proc": [],
    "med": [],
    "order": [],
}

with torch.no_grad():
    for batch in val_loader:
        batch = {k: (v.to(device) if torch.is_tensor(v) else v)
                 for k, v in batch.items()}

        # Tabular branch
        g_emb = best_model.gender_emb(batch["gender_id"])
        r_emb = best_model.race_emb(batch["race_id"])
        s_emb = best_model.service_emb(batch["service_id"])
        d_emb = best_model.drg_emb(batch["drg_code_id"])

        age = batch["age"].float().unsqueeze(-1)
        sev = batch["drg_severity"].float().unsqueeze(-1)
        mort = batch["drg_mortality"].float().unsqueeze(-1)

        tabular_feat = torch.cat(
            [age, sev, mort, g_emb, r_emb, s_emb, d_emb],
            dim=-1
        )
        h_tab = best_model.tabular_mlp(tabular_feat)

        # Diagnostic branch
        diag_bag = best_model.diag_emb(batch["diag_codes"], batch["diag_offsets"])
        h_diag = best_model.diag_mlp(diag_bag)

        # Procedure branch
        proc_bag = best_model.proc_emb(batch["proc_codes"], batch["proc_offsets"])
        h_proc = best_model.proc_mlp(proc_bag)

        # Medication branch
        med_bag = best_model.med_emb(batch["med_codes"], batch["med_offsets"])
        h_med = best_model.med_mlp(med_bag)

        # Order-type branch
        order_bag = best_model.order_emb(batch["order_codes"], batch["order_offsets"])
        h_order = best_model.order_mlp(order_bag)

        # Apply self-attention mechanism to get weights
        h_pre_attention = torch.cat([h_tab, h_diag, h_proc, h_med, h_order], dim=-1)
        alpha_raw = best_model.attention_mlp(h_pre_attention)
        alpha_weights = F.softmax(alpha_raw, dim=-1)

        # Split alpha_weights into individual branch weights and store
        all_attention_weights["tabular"].extend(alpha_weights[:, 0].cpu().numpy())
        all_attention_weights["diag"].extend(alpha_weights[:, 1].cpu().numpy())
        all_attention_weights["proc"].extend(alpha_weights[:, 2].cpu().numpy())
        all_attention_weights["med"].extend(alpha_weights[:, 3].cpu().numpy())
        all_attention_weights["order"].extend(alpha_weights[:, 4].cpu().numpy())

# 4. Calculate and print the mean of all collected attention weights
print("\n====== Average Attention Weights ======")
for branch_name, weights in all_attention_weights.items():
    avg_weight = np.mean(weights)
    print(f"Average {branch_name.capitalize()} Attention Weight: {avg_weight:.4f}")


Average Tabular Attention Weight: 0.3788
Average Diag Attention Weight: 0.2500
Average Proc Attention Weight: 0.1966
Average Med Attention Weight: 0.1420
Average Order Attention Weight: 0.0326
