In [1]:
import pandas as pd
from collections import Counter
import json
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

df = pd.read_parquet("los_dataset_24h.parquet")

print(df.shape)
print(df.columns)

(424803, 17)
Index(['subject_id', 'hadm_id', 'admittime', 'dischtime', 'race', 'los_hours',
       'gender', 'anchor_age', 'curr_service', 'hcpcs_cd_list',
       'diagnoses_icd_code_list', 'procedures_icd_code_list', 'drg_code',
       'drg_severity', 'drg_mortality', 'medication_list', 'order_type_list'],
      dtype='object')


In [2]:
def build_label_encoder(series):
    classes = sorted(series.unique())
    stoi = {c: i for i, c in enumerate(classes)}
    return stoi


def build_vocab_from_list_column(df, col, min_freq=1, add_unk=True):
    counter = Counter()
    for lst in df[col]:
        counter.update(lst)

    stoi = {}
    idx = 0
    if add_unk:
        stoi["<UNK>"] = idx
        idx += 1

    for token, freq in counter.items():
        if freq >= min_freq:
            if token not in stoi:
                stoi[token] = idx
                idx += 1

    return stoi

In [3]:
# Encoders for gender, race, curr_service, drg_code
gender_stoi = build_label_encoder(df["gender"])
race_stoi = build_label_encoder(df["race"])
service_stoi = build_label_encoder(df["curr_service"])

# drg_code is already an int, but treat it as a "category" and map it to an index
drg_code_stoi = build_label_encoder(df["drg_code"].astype(int))

print("num_genders:", len(gender_stoi))
print("num_races:", len(race_stoi))
print("num_services:", len(service_stoi))
print("num_drg_codes:", len(drg_code_stoi))

num_genders: 2
num_races: 33
num_services: 21
num_drg_codes: 301


In [4]:
list_cols = [
    "diagnoses_icd_code_list",
    "procedures_icd_code_list",
    "hcpcs_cd_list",
    "medication_list",
    "order_type_list",
]

diag_stoi = build_vocab_from_list_column(df, "diagnoses_icd_code_list", min_freq=1)
proc_stoi = build_vocab_from_list_column(df, "procedures_icd_code_list", min_freq=1)
hcpcs_stoi = build_vocab_from_list_column(df, "hcpcs_cd_list", min_freq=1)
med_stoi = build_vocab_from_list_column(df, "medication_list", min_freq=1)
order_stoi = build_vocab_from_list_column(df, "order_type_list", min_freq=1)

print("diag vocab size:", len(diag_stoi))
print("proc vocab size:", len(proc_stoi))
print("hcpcs vocab size:", len(hcpcs_stoi))
print("med vocab size:", len(med_stoi))
print("order vocab size:", len(order_stoi))

diag vocab size: 27701
proc vocab size: 12184
hcpcs vocab size: 1925
med vocab size: 3358
order vocab size: 17


In [5]:
# Combine proc_stoi and hcpcs_stoi into a single unified vocabulary
proc_all_stoi = {}
idx = 0

# Only one UNK
proc_all_stoi["<UNK>"] = idx
idx += 1

# Procedures first
for k in proc_stoi.keys():
    if k == "<UNK>":
        continue
    proc_all_stoi["PROC_" + k] = idx
    idx += 1

# Then hcpcs
for k in hcpcs_stoi.keys():
    if k == "<UNK>":
        continue
    key = "HCPCS_" + k
    if key not in proc_all_stoi:
        proc_all_stoi[key] = idx
        idx += 1

print("combined proc vocab size:", len(proc_all_stoi))

combined proc vocab size: 14108


In [6]:
UNK_DIAG = diag_stoi["<UNK>"]
UNK_PROC = proc_all_stoi["<UNK>"]
UNK_MED = med_stoi["<UNK>"]
UNK_ORDER = order_stoi["<UNK>"]

def map_list_to_ids(lst, stoi, unk_token="<UNK>"):
    unk_idx = stoi.get(unk_token, None)
    out = []
    for x in lst:
        idx = stoi.get(x)
        if idx is None:
            if unk_idx is not None:
                out.append(unk_idx)
        else:
            out.append(idx)
    return out

# Diagnostic codes
df["diag_ids"] = df["diagnoses_icd_code_list"].apply(
    lambda lst: map_list_to_ids(lst, diag_stoi)
)

# Combined procedure + hcpcs IDs
def build_proc_ids(row):
    ids = []
    for code in row["procedures_icd_code_list"]:
        tok = "PROC_" + code
        ids.append(proc_all_stoi.get(tok, UNK_PROC))
    for code in row["hcpcs_cd_list"]:
        tok = "HCPCS_" + code
        ids.append(proc_all_stoi.get(tok, UNK_PROC))
    return ids

df["proc_ids"] = df.apply(build_proc_ids, axis=1)

# Medication
df["med_ids"] = df["medication_list"].apply(
    lambda lst: map_list_to_ids(lst, med_stoi)
)

# Order type
df["order_ids"] = df["order_type_list"].apply(
    lambda lst: map_list_to_ids(lst, order_stoi)
)

# Single categorical → ID
df["gender_id"] = df["gender"].map(gender_stoi)
df["race_id"] = df["race"].map(race_stoi)
df["service_id"] = df["curr_service"].map(service_stoi)
df["drg_code_id"] = df["drg_code"].astype(int).map(drg_code_stoi)

In [7]:
def summarize_list(lst, max_len=10):
    """
    Summarizes a list by truncating it if it's too long.
    """
    if lst is None:
        return None
    if len(lst) <= max_len:
        return lst
    return lst[:max_len] + ["...(+{} more)".format(len(lst) - max_len)]


cols_to_show = [
    "gender", "race", "curr_service", "drg_code",
    "gender_id", "race_id", "service_id", "drg_code_id",
    "diagnoses_icd_code_list", "diag_ids",
    "procedures_icd_code_list", "hcpcs_cd_list", "proc_ids",
    "medication_list", "med_ids",
    "order_type_list", "order_ids",
]

print("\n================ SAMPLE DATA (after encoding) ================\n")

for i in range(3):
    row = df.iloc[i]
    print(f"--- Row {i} ---")

    for c in cols_to_show:
        val = row[c]

        # Summarize list
        if isinstance(val, list):
            val = summarize_list(val)

        print(f"{c:25} : {val}")

    print("\n")



--- Row 0 ---
gender                    : F
race                      : WHITE
curr_service              : MED
drg_code                  : 279
gender_id                 : 0
race_id                   : 28
service_id                : 7
drg_code_id               : 135
diagnoses_icd_code_list   : ['07071' '78959' '2875' '2761' '496' '5715' 'V08' '3051']
diag_ids                  : [1, 2, 3, 4, 5, 6, 7, 8]
procedures_icd_code_list  : ['5491']
hcpcs_cd_list             : []
proc_ids                  : [1]
medication_list           : ['Raltegravir' 'Rifaximin' 'Sodium Chloride 0.9%  Flush'
 'Calcium Carbonate' 'Rifaximin' 'Raltegravir'
 'Emtricitabine-Tenofovir (Truvada)' 'Sulfameth/Trimethoprim DS'
 'Furosemide' 'Tiotropium Bromide' 'Albuterol Inhaler' 'Lactulose'
 'Heparin' 'Sodium Chloride 0.9%  Flush' 'Acetaminophen' 'Heparin'
 'Lactulose' 'Albumin 25% (12.5g / 50mL)' 'Sodium Chloride 0.9%  Flush'
 'Albumin 25% (12.5g / 50mL)' 'Albumin 25% (12.5g / 50mL)']
med_ids                   : [1,

In [8]:
class LOSDataset(Dataset):
    def __init__(self, df, use_log_target=True):
        self.df = df.reset_index(drop=True)
        self.use_log_target = use_log_target

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        sample = {
            # Tabular features
            "age": float(row["anchor_age"]),
            "gender_id": int(row["gender_id"]),
            "race_id": int(row["race_id"]),
            "service_id": int(row["service_id"]),
            "drg_code_id": int(row["drg_code_id"]),
            "drg_severity": float(row["drg_severity"]),
            "drg_mortality": float(row["drg_mortality"]),

            # List-type IDs
            "diag_ids": row["diag_ids"],
            "proc_ids": row["proc_ids"],
            "med_ids": row["med_ids"],
            "order_ids": row["order_ids"],
        }

        # Target
        los = float(row["los_hours"])
        if self.use_log_target:
            sample["target"] = np.log1p(los)
        else:
            sample["target"] = los

        return sample

In [9]:
def los_collate_fn(batch):
    B = len(batch)

    # ----- Tabular Stack -----
    age = torch.tensor([b["age"] for b in batch], dtype=torch.float32)
    gender_id = torch.tensor([b["gender_id"] for b in batch], dtype=torch.long)
    race_id = torch.tensor([b["race_id"] for b in batch], dtype=torch.long)
    service_id = torch.tensor([b["service_id"] for b in batch], dtype=torch.long)
    drg_code_id = torch.tensor([b["drg_code_id"] for b in batch], dtype=torch.long)
    drg_severity = torch.tensor([b["drg_severity"] for b in batch], dtype=torch.float32)
    drg_mortality = torch.tensor([b["drg_mortality"] for b in batch], dtype=torch.float32)

    target = torch.tensor([b["target"] for b in batch], dtype=torch.float32)

    # ----- List-type: diag / proc / med / order -----
    def build_bag_inputs(key):
        codes_all = []
        offsets = [0]
        for b in batch:
            ids = b[key]
            codes_all.extend(ids)
            offsets.append(len(codes_all))
        if len(codes_all) == 0:
            codes_tensor = torch.empty(0, dtype=torch.long)
        else:
            codes_tensor = torch.tensor(codes_all, dtype=torch.long)
        offsets_tensor = torch.tensor(offsets, dtype=torch.long)
        return codes_tensor, offsets_tensor

    diag_codes, diag_offsets = build_bag_inputs("diag_ids")
    proc_codes, proc_offsets = build_bag_inputs("proc_ids")
    med_codes, med_offsets = build_bag_inputs("med_ids")
    order_codes, order_offsets = build_bag_inputs("order_ids")

    batch_out = {
        "age": age,
        "gender_id": gender_id,
        "race_id": race_id,
        "service_id": service_id,
        "drg_code_id": drg_code_id,
        "drg_severity": drg_severity,
        "drg_mortality": drg_mortality,
        "diag_codes": diag_codes,
        "diag_offsets": diag_offsets,
        "proc_codes": proc_codes,
        "proc_offsets": proc_offsets,
        "med_codes": med_codes,
        "med_offsets": med_offsets,
        "order_codes": order_codes,
        "order_offsets": order_offsets,
        "target": target,
    }

    return batch_out

In [10]:
from torch.utils.data import random_split, DataLoader
import torch

# Create Dataset
dataset = LOSDataset(df, use_log_target=True)

n_total = len(dataset)
n_train = int(n_total * 0.7)
n_val = int(n_total * 0.15)
n_test = n_total - n_train - n_val   # Remaining

g = torch.Generator().manual_seed(42)
train_ds, val_ds, test_ds = random_split(dataset, [n_train, n_val, n_test], generator=g)

train_loader = DataLoader(
    train_ds,
    batch_size=256,
    shuffle=True,
    collate_fn=los_collate_fn,
)

val_loader = DataLoader(
    val_ds,
    batch_size=256,
    shuffle=False,
    collate_fn=los_collate_fn,
)

test_loader = DataLoader(
    test_ds,
    batch_size=256,
    shuffle=False,
    collate_fn=los_collate_fn,
)

print(f"#train: {len(train_ds)}, #val: {len(val_ds)}, #test: {len(test_ds)}")

#train: 297362, #val: 63720, #test: 63721


In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

@dataclass
class ModelConfig:
    # ----- Tabular categorical vocab sizes -----
    num_genders: int          # e.g. {"M", "F"} -> 2
    num_races: int            # unique race categories
    num_services: int         # curr_service (MED, ORTHO, ...)
    num_drg_codes: int        # drg_code vocab size

    # ----- Code vocab sizes -----
    diag_vocab_size: int      # diagnoses_icd_code_list vocab
    proc_vocab_size: int      # Combined vocab for procedures_icd_code_list + hcpcs_cd_list
    med_vocab_size: int       # medication_list vocab
    order_vocab_size: int     # order_type_list vocab

    # ----- Embedding dimensions (ADDED THESE) -----
    emb_dim_gender: int = 4
    emb_dim_race: int = 8
    emb_dim_service: int = 8
    emb_dim_drg: int = 16

    emb_dim_diag: int = 32
    emb_dim_proc: int = 32
    emb_dim_med: int = 32
    emb_dim_order: int = 16

    # ----- Transformer parameters -----
    transformer_d_model: int = 64 # Embedding dimension for transformer, output dim of individual branch MLPs
    transformer_n_heads: int = 8  # Number of attention heads
    transformer_n_layers: int = 2 # Number of transformer encoder layers
    transformer_dim_feedforward: int = 256 # Feedforward network dimension

    # ----- Dropout -----
    dropout: float = 0.2

    # ----- Max sequence lengths for codes -----
    max_seq_len_diag: int = 80
    max_seq_len_proc: int = 50
    max_seq_len_med: int = 110
    max_seq_len_order: int = 50

    # ----- Feature counts for modality embeddings -----
    num_scalar_features: int = 3 # age, drg_severity, drg_mortality
    num_categorical_features: int = 4 # gender, race, service, drg_code
    num_code_modalities: int = 4 # diag, proc, med, order
    num_token_types: int = 1 + 3 + 4 + 4 # 1 for CLS, 3 for scalar, 4 for categorical, 4 for code modalities (total 12)


class MultiModalTransformerLOSModel(nn.Module):
    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg

        # 3. Initialize a learnable [CLS] token
        self.cls_token = nn.Parameter(torch.randn(1, 1, cfg.transformer_d_model))

        # 4. Define nn.Linear layers for scalar features
        self.age_proj = nn.Linear(1, cfg.transformer_d_model)
        self.drg_severity_proj = nn.Linear(1, cfg.transformer_d_model)
        self.drg_mortality_proj = nn.Linear(1, cfg.transformer_d_model)

        # 5. Define nn.Embedding layers for categorical features
        self.gender_emb = nn.Embedding(cfg.num_genders, cfg.transformer_d_model)
        self.race_emb = nn.Embedding(cfg.num_races, cfg.transformer_d_model)
        self.service_emb = nn.Embedding(cfg.num_services, cfg.transformer_d_model)
        self.drg_code_emb = nn.Embedding(cfg.num_drg_codes, cfg.transformer_d_model)

        # 6. Define nn.Embedding layers for code features
        self.diag_emb = nn.Embedding(cfg.diag_vocab_size, cfg.transformer_d_model, padding_idx=0)
        self.proc_emb = nn.Embedding(cfg.proc_vocab_size, cfg.transformer_d_model, padding_idx=0)
        self.med_emb = nn.Embedding(cfg.med_vocab_size, cfg.transformer_d_model, padding_idx=0)
        self.order_emb = nn.Embedding(cfg.order_vocab_size, cfg.transformer_d_model, padding_idx=0)

        # 7. Initialize nn.Embedding for modality type embeddings
        # Modality types: CLS=0, age=1, severity=2, mortality=3, gender=4, race=5, service=6, drg_code=7, diag_codes=8, proc_codes=9, med_codes=10, order_codes=11
        self.modality_type_emb = nn.Embedding(cfg.num_token_types, cfg.transformer_d_model)

        # 8. Calculate the maximum total sequence length
        max_total_seq_len = (
            1  # CLS token
            + cfg.num_scalar_features # age, severity, mortality
            + cfg.num_categorical_features # gender, race, service, drg_code
            + cfg.max_seq_len_diag
            + cfg.max_seq_len_proc
            + cfg.max_seq_len_med
            + cfg.max_seq_len_order
        )
        self.max_total_seq_len = max_total_seq_len

        # 9. Initialize nn.Embedding for positional embeddings
        self.positional_emb = nn.Embedding(max_total_seq_len, cfg.transformer_d_model)

        # 10. Create a nn.TransformerEncoderLayer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=cfg.transformer_d_model,
            nhead=cfg.transformer_n_heads,
            dim_feedforward=cfg.transformer_dim_feedforward,
            dropout=cfg.dropout,
            batch_first=True
        )

        # 11. Instantiate the nn.TransformerEncoder
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=cfg.transformer_n_layers
        )

        # 12. Define the regression head
        self.regression_head = nn.Sequential(
            nn.Linear(cfg.transformer_d_model, cfg.transformer_d_model // 2),
            nn.ReLU(),
            nn.Linear(cfg.transformer_d_model // 2, 1)
        )

    def _get_modality_emb_with_type(self, embedding_tensor, modality_type_idx):
        """
        Adds the modality type embedding to the given embedding tensor.
        Ensures correct device for the modality type index tensor.
        """
        modality_type_tensor = torch.tensor([modality_type_idx], device=embedding_tensor.device)
        modality_emb = self.modality_type_emb(modality_type_tensor).unsqueeze(0) # (1, 1, d_model)
        # modality_emb will broadcast correctly to embedding_tensor's shape (B, 1, d_model) or (B, max_len, d_model)
        return embedding_tensor + modality_emb

    def forward(
        self,
        age: torch.Tensor,
        gender_idx: torch.Tensor,
        race_idx: torch.Tensor,
        service_idx: torch.Tensor,
        drg_code_idx: torch.Tensor,
        drg_severity: torch.Tensor,
        drg_mortality: torch.Tensor,
        diag_codes: torch.Tensor,
        diag_offsets: torch.Tensor,
        proc_codes: torch.Tensor,
        proc_offsets: torch.Tensor,
        med_codes: torch.Tensor,
        med_offsets: torch.Tensor,
        order_codes: torch.Tensor,
        order_offsets: torch.Tensor,
    ):
        B = age.size(0) # Batch size
        device = age.device

        all_embeddings = []
        src_key_padding_mask_components = []

        # 1. Prepare the [CLS] token
        cls_embedding = self.cls_token.expand(B, -1, -1) # (B, 1, d_model)
        cls_embedding = self._get_modality_emb_with_type(cls_embedding, 0) # CLS modality type 0
        all_embeddings.append(cls_embedding)
        src_key_padding_mask_components.append(torch.zeros(B, 1, dtype=torch.bool, device=device))

        # 2. Process scalar features
        # age (modality type 1)
        age_proj_emb = self.age_proj(age.float().unsqueeze(-1)).unsqueeze(1) # Project (B,1) to (B, d_model), then add 1 for sequence dim: (B, 1, d_model)
        age_emb = self._get_modality_emb_with_type(age_proj_emb, 1)
        all_embeddings.append(age_emb)
        src_key_padding_mask_components.append(torch.zeros(B, 1, dtype=torch.bool, device=device))

        # drg_severity (modality type 2)
        sev_proj_emb = self.drg_severity_proj(drg_severity.float().unsqueeze(-1)).unsqueeze(1) # (B, 1, d_model)
        sev_emb = self._get_modality_emb_with_type(sev_proj_emb, 2)
        all_embeddings.append(sev_emb)
        src_key_padding_mask_components.append(torch.zeros(B, 1, dtype=torch.bool, device=device))

        # drg_mortality (modality type 3)
        mort_proj_emb = self.drg_mortality_proj(drg_mortality.float().unsqueeze(-1)).unsqueeze(1) # (B, 1, d_model)
        mort_emb = self._get_modality_emb_with_type(mort_proj_emb, 3)
        all_embeddings.append(mort_emb)
        src_key_padding_mask_components.append(torch.zeros(B, 1, dtype=torch.bool, device=device))

        # 3. Process categorical features
        # gender (modality type 4)
        gender_proj_emb = self.gender_emb(gender_idx).unsqueeze(1) # (B, 1, d_model)
        gender_emb = self._get_modality_emb_with_type(gender_proj_emb, 4)
        all_embeddings.append(gender_emb)
        src_key_padding_mask_components.append(torch.zeros(B, 1, dtype=torch.bool, device=device))

        # race (modality type 5)
        race_proj_emb = self.race_emb(race_idx).unsqueeze(1) # (B, 1, d_model)
        race_emb = self._get_modality_emb_with_type(race_proj_emb, 5)
        all_embeddings.append(race_emb)
        src_key_padding_mask_components.append(torch.zeros(B, 1, dtype=torch.bool, device=device))

        # service (modality type 6)
        service_proj_emb = self.service_emb(service_idx).unsqueeze(1) # (B, 1, d_model)
        service_emb = self._get_modality_emb_with_type(service_proj_emb, 6)
        all_embeddings.append(service_emb)
        src_key_padding_mask_components.append(torch.zeros(B, 1, dtype=torch.bool, device=device))

        # drg_code (modality type 7)
        drg_code_proj_emb = self.drg_code_emb(drg_code_idx).unsqueeze(1) # (B, 1, d_model)
        drg_code_emb = self._get_modality_emb_with_type(drg_code_proj_emb, 7)
        all_embeddings.append(drg_code_emb)
        src_key_padding_mask_components.append(torch.zeros(B, 1, dtype=torch.bool, device=device))

        # Helper for processing code sequences
        def process_code_sequence(codes, offsets, emb_layer, max_len, modality_type_idx):
            if codes.numel() == 0: # Handle empty codes tensor for a whole batch
                padded_seqs = torch.zeros(B, max_len, self.cfg.transformer_d_model, device=device)
                padding_mask = torch.ones(B, max_len, dtype=torch.bool, device=device)
            else:
                embedded_codes_flat = emb_layer(codes) # (N_codes, d_model)

                batch_padded_sequences = []
                batch_padding_masks = []

                # Split by offsets to get individual sequences for each batch item
                for i in range(B):
                    start = offsets[i]
                    end = offsets[i+1]

                    sequence = embedded_codes_flat[start:end] # (seq_len_i, d_model)

                    # Truncate
                    if sequence.size(0) > max_len:
                        sequence = sequence[:max_len]

                    # Pad
                    padding_needed = max_len - sequence.size(0)
                    if padding_needed > 0:
                        padding = torch.zeros(padding_needed, self.cfg.transformer_d_model, device=device)
                        sequence = torch.cat([sequence, padding], dim=0)
                        mask_row = torch.cat([
                            torch.zeros(max_len - padding_needed, dtype=torch.bool, device=device),
                            torch.ones(padding_needed, dtype=torch.bool, device=device)
                        ])
                    else:
                        mask_row = torch.zeros(max_len, dtype=torch.bool, device=device)

                    batch_padded_sequences.append(sequence.unsqueeze(0)) # (1, max_len, d_model)
                    batch_padding_masks.append(mask_row.unsqueeze(0)) # (1, max_len)

                # Stack sequences and masks for the batch
                padded_seqs = torch.cat(batch_padded_sequences, dim=0) # (B, max_len, d_model)
                padding_mask = torch.cat(batch_padding_masks, dim=0) # (B, max_len)

            # Add modality type embedding using the helper method
            padded_seqs = self._get_modality_emb_with_type(padded_seqs, modality_type_idx)

            return padded_seqs, padding_mask

        # 4. Process sequence-type code features
        # diag_codes (modality type 8)
        h_diag_sequence, diag_padding_mask = process_code_sequence(
            diag_codes, diag_offsets, self.diag_emb, self.cfg.max_seq_len_diag, 8
        )
        all_embeddings.append(h_diag_sequence)
        src_key_padding_mask_components.append(diag_padding_mask)

        # proc_codes (modality type 9)
        h_proc_sequence, proc_padding_mask = process_code_sequence(
            proc_codes, proc_offsets, self.proc_emb, self.cfg.max_seq_len_proc, 9
        )
        all_embeddings.append(h_proc_sequence)
        src_key_padding_mask_components.append(proc_padding_mask)

        # med_codes (modality type 10)
        h_med_sequence, med_padding_mask = process_code_sequence(
            med_codes, med_offsets, self.med_emb, self.cfg.max_seq_len_med, 10
        )
        all_embeddings.append(h_med_sequence)
        src_key_padding_mask_components.append(med_padding_mask)

        # order_codes (modality type 11)
        h_order_sequence, order_padding_mask = process_code_sequence(
            order_codes, order_offsets, self.order_emb, self.cfg.max_seq_len_order, 11
        )
        all_embeddings.append(h_order_sequence)
        src_key_padding_mask_components.append(order_padding_mask)

        # 5. Concatenate all processed tokens
        transformer_input = torch.cat(all_embeddings, dim=1) # (B, total_seq_len, d_model)

        # 6. Create the src_key_padding_mask
        src_key_padding_mask = torch.cat(src_key_padding_mask_components, dim=1) # (B, total_seq_len)

        # 7. Add positional embeddings
        positions = torch.arange(self.max_total_seq_len, device=device).unsqueeze(0) # (1, total_seq_len)
        positional_embeddings = self.positional_emb(positions)
        transformer_input = transformer_input + positional_embeddings # (B, total_seq_len, d_model)

        # 8. Pass through the Transformer Encoder
        transformer_output = self.transformer_encoder(
            transformer_input,
            src_key_padding_mask=src_key_padding_mask
        ) # (B, total_seq_len, d_model)

        # 9. Extract the output corresponding to the [CLS] token
        cls_output = transformer_output[:, 0, :] # (B, d_model)

        # 10. Pass to the regression head
        prediction = self.regression_head(cls_output).squeeze(-1) # (B,)

        return prediction

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


# 1. Model / Config Creation

cfg = ModelConfig(
    num_genders=len(gender_stoi),
    num_races=len(race_stoi),
    num_services=len(service_stoi),
    num_drg_codes=len(drg_code_stoi),
    diag_vocab_size=len(diag_stoi),
    proc_vocab_size=len(proc_all_stoi),
    med_vocab_size=len(med_stoi),
    order_vocab_size=len(order_stoi),
)

model = MultiModalTransformerLOSModel(cfg)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.MSELoss()  # MSE based on log(1+LOS)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

NUM_EPOCHS = 10

best_val_loss = float("inf")
best_model_path = "los_multibranch_best.pt"
print("Using device:", device)

# 2. Training Loop
for epoch in range(1, NUM_EPOCHS + 1):
    # ---------- Train ----------
    model.train()
    train_loss_sum = 0.0
    train_count = 0

    print(f"\n========== Epoch {epoch}/{NUM_EPOCHS} ==========")

    for batch_idx, batch in enumerate(train_loader):
        # Move tensors to device
        batch = {k: (v.to(device) if torch.is_tensor(v) else v)
                 for k, v in batch.items()}

        optimizer.zero_grad()

        y_pred = model(
            age=batch["age"],
            gender_idx=batch["gender_id"],
            race_idx=batch["race_id"],
            service_idx=batch["service_id"],
            drg_code_idx=batch["drg_code_id"],
            drg_severity=batch["drg_severity"],
            drg_mortality=batch["drg_mortality"],
            diag_codes=batch["diag_codes"],
            diag_offsets=batch["diag_offsets"],
            proc_codes=batch["proc_codes"],
            proc_offsets=batch["proc_offsets"],
            med_codes=batch["med_codes"],
            med_offsets=batch["med_offsets"],
            order_codes=batch["order_codes"],
            order_offsets=batch["order_offsets"],
        )

        y_true = batch["target"]  # log(1+LOS)
        loss = criterion(y_pred, y_true)

        loss.backward()
        optimizer.step()

        bs = y_true.size(0)
        train_loss_sum += loss.item() * bs
        train_count += bs

        if batch_idx % 100 == 0:
            avg_loss = train_loss_sum / train_count
            print(f"  [Epoch {epoch} | Step {batch_idx}/{len(train_loader)}] "
                  f"AvgTrainLoss={avg_loss:.4f}")

    train_loss = train_loss_sum / train_count

    # ---------- Validation ----------
    model.eval()
    val_loss_sum = 0.0
    val_count = 0
    val_mae_hours_sum = 0.0  # MAE based on actual LOS (hours)

    with torch.no_grad():
        for batch in val_loader:
            batch = {k: (v.to(device) if torch.is_tensor(v) else v)
                     for k, v in batch.items()}

            y_pred = model(
                age=batch["age"],
                gender_idx=batch["gender_id"],
                race_idx=batch["race_id"],
                service_idx=batch["service_id"],
                drg_code_idx=batch["drg_code_id"],
                drg_severity=batch["drg_severity"],
                drg_mortality=batch["drg_mortality"],
                diag_codes=batch["diag_codes"],
                diag_offsets=batch["diag_offsets"],
                proc_codes=batch["proc_codes"],
                proc_offsets=batch["proc_offsets"],
                med_codes=batch["med_codes"],
                med_offsets=batch["med_offsets"],
                order_codes=batch["order_codes"],
                order_offsets=batch["order_offsets"],
            )

            y_true = batch["target"]

            loss = criterion(y_pred, y_true)

            bs = y_true.size(0)
            val_loss_sum += loss.item() * bs
            val_count += bs

            # Convert log(1+LOS) -> actual LOS (hours) and calculate MAE
            y_true_hours = torch.expm1(y_true)
            y_pred_hours = torch.expm1(y_pred)

            mae_hours = torch.abs(y_pred_hours - y_true_hours).sum().item()
            val_mae_hours_sum += mae_hours

    val_loss = val_loss_sum / val_count
    val_mae_hours = val_mae_hours_sum / val_count

    print(f"[Epoch {epoch:03d}] "
          f"train_loss(log-MSE)={train_loss:.4f} | "
          f"val_loss(log-MSE)={val_loss:.4f} | "
          f"val_MAE(hours)={val_mae_hours:.2f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"  → Best model updated, saved to {best_model_path}")

print("Training finished. Best val_loss:", best_val_loss)

# 3. Calculate MAE (hours) on Test set
best_model = MultiModalTransformerLOSModel(cfg).to(device)
best_model.load_state_dict(torch.load(best_model_path, map_location=device))
best_model.eval()

test_abs_error_sum = 0.0
test_count = 0

with torch.no_grad():
    for batch in test_loader:
        batch = {k: (v.to(device) if torch.is_tensor(v) else v)
                 for k, v in batch.items()}

        y_pred = best_model(
            age=batch["age"],
            gender_idx=batch["gender_id"],
            race_idx=batch["race_id"],
            service_idx=batch["service_id"],
            drg_code_idx=batch["drg_code_id"],
            drg_severity=batch["drg_severity"],
            drg_mortality=batch["drg_mortality"],
            diag_codes=batch["diag_codes"],
            diag_offsets=batch["diag_offsets"],
            proc_codes=batch["proc_codes"],
            proc_offsets=batch["proc_offsets"],
            med_codes=batch["med_codes"],
            med_offsets=batch["med_offsets"],
            order_codes=batch["order_codes"],
            order_offsets=batch["order_offsets"],
        )

        y_true = batch["target"]  # log(1+LOS)

        # Convert log(1+LOS) → actual hours
        y_true_hours = torch.expm1(y_true)
        y_pred_hours = torch.expm1(y_pred)

        abs_err = torch.abs(y_pred_hours - y_true_hours)
        test_abs_error_sum += abs_err.sum().item()
        test_count += y_true.size(0)

test_mae_hours = test_abs_error_sum / test_count

print(f"\n===== Test set MAE =====")
print(f"Test MAE (hours): {test_mae_hours:.2f}")
print(f"Test MAE (days) : {test_mae_hours / 24:.2f}")


Using device: cuda

  [Epoch 1 | Step 0/1162] AvgTrainLoss=17.8877
  [Epoch 1 | Step 100/1162] AvgTrainLoss=1.9337
  [Epoch 1 | Step 200/1162] AvgTrainLoss=1.2955
  [Epoch 1 | Step 300/1162] AvgTrainLoss=1.0749
  [Epoch 1 | Step 400/1162] AvgTrainLoss=0.9304
  [Epoch 1 | Step 500/1162] AvgTrainLoss=0.8344
  [Epoch 1 | Step 600/1162] AvgTrainLoss=0.7665
  [Epoch 1 | Step 700/1162] AvgTrainLoss=0.7160
  [Epoch 1 | Step 800/1162] AvgTrainLoss=0.6762
  [Epoch 1 | Step 900/1162] AvgTrainLoss=0.6439
  [Epoch 1 | Step 1000/1162] AvgTrainLoss=0.6161
  [Epoch 1 | Step 1100/1162] AvgTrainLoss=0.5940
[Epoch 001] train_loss(log-MSE)=0.5817 | val_loss(log-MSE)=0.3852 | val_MAE(hours)=75.70
  → Best model updated, saved to los_multibranch_best.pt

  [Epoch 2 | Step 0/1162] AvgTrainLoss=0.3576
  [Epoch 2 | Step 100/1162] AvgTrainLoss=0.3432
  [Epoch 2 | Step 200/1162] AvgTrainLoss=0.3386
  [Epoch 2 | Step 300/1162] AvgTrainLoss=0.3323
  [Epoch 2 | Step 400/1162] AvgTrainLoss=0.3290
  [Epoch 2 | Step 