In [None]:
# conda install -c conda-forge numpy=1.26.4 pandas=2.3.3 scipy=1.15.3 scikit-learn=1.7.2 matplotlib=3.10.6 seaborn=0.13.2 pytorch=2.5.1 torchvision=0.20.1 torchaudio=2.5.1 ipython ipykernel tqdm psutil pyyaml -y && pip install accelerate==1.11.0 aiohappyeyeballs==2.6.1 aiohttp==3.12.15 aiosignal==1.4.0 alembic==1.17.0 annotated-types==0.7.0 anyio==4.11.0 attrs==25.3.0 Authlib==1.6.5 blinker==1.9.0 cachetools==6.2.2 certifi==2025.11.12 charset-normalizer==3.4.4 click==8.3.0 cloudpickle==3.1.2 colorama==0.4.6 colorlog==6.10.1 cryptography==45.0.7 cyclopts==3.24.0 datasets==4.1.1 dill==0.4.0 dnspython==2.8.0 docker==7.1.0 docstring_parser==0.17.0 evaluate==0.4.6 fastapi==0.119.0 Flask==3.1.2 flask-cors==6.0.1 fsspec==2025.9.0 ftfy==6.3.1 GitPython==3.1.45 google-api-core==2.29.0 google-auth==2.41.1 google-cloud-bigquery==3.40.0 google-cloud-core==2.5.0 graphene==3.4.3 grpcio==1.76.0 httpx==0.28.1 huggingface-hub==0.35.1 imbalanced-learn==0.14.0 joblib==1.5.2 kornia==0.8.1 lightning-utilities==0.15.2 mlflow==3.6.0 multiprocess==0.70.16 networkx open_clip_torch==3.2.0 optuna==4.5.0 pillow==12.0.0 protobuf==6.33.4 pydantic==2.12.0 pydantic-settings==2.11.0 pyarrow==21.0.0 regex==2025.9.18 rich==14.2.0 safetensors==0.6.2 SQLAlchemy==2.0.44 starlette==0.48.0 sympy==1.13.1 timm==1.0.20 tokenizers==0.22.1 transformers==4.57.1 uvicorn==0.37.0 waitress==3.0.2 xgboost==3.0.5


In [1]:
import pyarrow.parquet as pq
import pandas as pd

train_path = "/root/df_train30.parquet"

df_train = pq.read_table(train_path)
df_train = df_train.to_pandas()

print("Shape:", df_train.shape)
print(df_train.head())

val_path = "/root/df_val30.parquet"

df_val = pq.read_table(val_path)
df_val = df_val.to_pandas()

print("Shape:", df_val.shape)
print(df_val.head())

Shape: (1600488, 135)
    row_id  subject_id   stay_id  hr           starttime             endtime  \
0  3907981    17371178  30001396   0 2147-10-18 13:00:00 2147-10-18 14:00:00   
1  1645739    17371178  30001396   1 2147-10-18 14:00:00 2147-10-18 15:00:00   
2  4154575    17371178  30001396   1 2147-10-18 14:00:00 2147-10-18 15:00:00   
3  4550179    17371178  30001396   1 2147-10-18 14:00:00 2147-10-18 15:00:00   
4  2506459    17371178  30001396   2 2147-10-18 15:00:00 2147-10-18 16:00:00   

   age  height  weight  gender  ...  respiration  coagulation  liver  \
0   40     NaN   162.1     NaN  ...            0            0      0   
1   40     NaN   162.1     NaN  ...            0            0      0   
2   40     NaN   162.1     NaN  ...            0            0      0   
3   40     NaN   162.1     NaN  ...            0            0      0   
4   40     NaN   162.1     NaN  ...            0            0      0   

   cardiovascular  cns  renal  hours_beforesepsis  sepsis  fod  

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import pandas as pd
import warnings
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
from torch.amp import autocast, GradScaler

In [3]:
output_cols = [
    "respiration",
    "coagulation",
    "liver",
    "cardiovascular",
    "cns",
    "renal",
    "hours_beforesepsis",
    "sepsis",
    "fod",
    "hours_beforedeath"
]

non_output_cols = df_train.columns.difference(output_cols)

# df_train["hours_beforesepsis"] = df_train["hours_beforesepsis"].fillna(0)
# df_train['gender'] = df_train['gender'].map({'M':0, 'F':1})

# for col in df_train.select_dtypes(include=['object']).columns:
#     df_train[col] = pd.to_numeric(df_train[col], errors='coerce')

# df_val["hours_beforesepsis"] = df_val["hours_beforesepsis"].fillna(0)
# df_val['gender'] = df_val['gender'].map({'M':0, 'F':1})

# for col in df_val.select_dtypes(include=['object']).columns:
#     df_val[col] = pd.to_numeric(df_val[col], errors='coerce')

In [4]:
regression_cols = ["respiration", "coagulation", "liver", "cardiovascular", 
                   "cns", "renal", "hours_beforesepsis", "hours_beforedeath"]
binary_cols = ["sepsis"]

In [5]:
# Train
X_train = df_train[non_output_cols].drop(columns=["starttime", "endtime", "subject_id", "row_id"]).to_numpy(dtype=np.float32)
y_train_reg = df_train[regression_cols].to_numpy(dtype=np.float32)
y_train_bin = df_train[binary_cols].to_numpy(dtype=np.int32)
stay_ids_train = df_train['stay_id'].values
times_train = df_train.groupby("stay_id").cumcount().values

# Val
X_val = df_val[non_output_cols].drop(columns=["starttime", "endtime", "subject_id", "row_id"]).to_numpy(dtype=np.float32)
y_val_reg = df_val[regression_cols].to_numpy(dtype=np.float32)
y_val_bin = df_val[binary_cols].to_numpy(dtype=np.int32)
stay_ids_val = df_val['stay_id'].values
times_val = df_val.groupby("stay_id").cumcount().values

In [None]:
# from sklearn.preprocessing import StandardScaler
# import numpy as np

# n_train, n_features = X_train.shape

# X_train_scaled = np.memmap(
#     "X_train_scaled2.dat",   
#     dtype="float32",
#     mode="w+",
#     shape=(n_train, n_features)
# )

# scaler_X = StandardScaler()
# scaler_X.fit(X_train.astype(np.float32))

# chunk_size = 100_000
# for start in range(0, n_train, chunk_size):
#     end = min(start + chunk_size, n_train)
#     X_train_scaled[start:end] = scaler_X.transform(
#         X_train[start:end].astype(np.float32)
#     )


In [None]:
# n_val = X_val.shape[0]

# X_val_scaled = np.memmap(
#     "X_val_scaled3.dat",
#     dtype="float32",
#     mode="w+",
#     shape=(n_val, n_features)
# )

# for start in range(0, n_val, chunk_size):
#     end = min(start + chunk_size, n_val)
#     X_val_scaled[start:end] = scaler_X.transform(
#         X_val[start:end].astype(np.float32)
#     )


In [None]:
# from sklearn.preprocessing import StandardScaler

# scaler_y_reg = StandardScaler()
# y_train_reg_scaled = scaler_y_reg.fit_transform(y_train_reg)
# y_val_reg_scaled = scaler_y_reg.transform(y_val_reg)

# # Binary targets
# y_train_bin_scaled = y_train_bin
# y_val_bin_scaled = y_val_bin

# # Gabungkan kembali
# y_train_scaled = np.concatenate([y_train_reg_scaled, y_train_bin_scaled], axis=1)
# y_val_scaled = np.concatenate([y_val_reg_scaled, y_val_bin_scaled], axis=1)

In [None]:
# # Global feature mean
# global_feat_mean = np.nanmean(X_train_scaled, axis=0)
# global_feat_mean = np.nan_to_num(global_feat_mean, nan=0.0)

In [6]:
import joblib

scaler_X = joblib.load("scaler_X.pkl")
scaler_y_reg = joblib.load("scaler_y_reg.pkl")
global_feat_mean = np.load("global_feat_mean.npy")

print("All joblib pkl loaded ✅")

All joblib pkl loaded ✅


In [7]:
n_train, n_features = X_train.shape

X_train_scaled = np.memmap(
    "X_train_scaled2.dat",
    dtype="float32",
    mode="r",
    shape=(n_train, n_features)
)

In [8]:
n_val, n_features = X_val.shape

X_val_scaled = np.memmap(
    "X_val_scaled2.dat",
    dtype="float32",
    mode="r",
    shape=(n_val, n_features)
)

In [9]:
scaler_y_reg.fit(y_train_reg)

y_train_bin = y_train_bin

y_train_reg_scaled = scaler_y_reg.transform(y_train_reg)
y_train_bin_scaled = y_train_bin

y_train_scaled = np.concatenate(
    [y_train_reg_scaled, y_train_bin_scaled],
    axis=1
)


y_val_bin = y_val_bin

y_val_reg_scaled = scaler_y_reg.transform(y_val_reg)
y_val_bin_scaled = y_val_bin

y_val_scaled = np.concatenate(
    [y_val_reg_scaled, y_val_bin_scaled],
    axis=1
)

In [10]:
class TemporalWindowDataset(Dataset):
    def __init__(
        self,
        X, y,
        stay_ids,
        times,
        global_feat_mean,
        window_sizes=(6, 12, 24),
        horizon=1
    ):
        self.X = X
        self.y = y
        self.global_mean = global_feat_mean
        self.samples = []

        df = pd.DataFrame({
            "stay_id": stay_ids,
            "time": times,
            "idx": np.arange(len(stay_ids))
        })

        for stay_id, g in df.groupby("stay_id"):
            g = g.sort_values("time")
            idxs = g["idx"].values
            tvals = g["time"].values

            for w in window_sizes:
                if len(idxs) <= w + horizon:
                    continue
                for i in range(w, len(idxs) - horizon):
                    hist = idxs[i - w:i]
                    target = idxs[i + horizon]
                    times_hist = tvals[i - w:i]
                    self.samples.append((hist, target, times_hist, w))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        hist_idx, target_idx, times, w = self.samples[idx]
        window_id = {6: 0, 12: 1, 24: 2}[w]

        X_seq = self.X[hist_idx].astype(float)
        y_target = self.y[target_idx]

        # Pisahin regresi & binary
        n_reg = y_train_reg.shape[1]  # jumlah kolom regresi
        n_bin = y_train_bin.shape[1]  # jumlah kolom binary

        y_reg = y_target[:n_reg].astype(float)
        y_bin = y_target[n_reg:].astype(int)  # pastikan integer 0/1

        mask = ~np.isnan(X_seq)

        T, F = X_seq.shape
        X_filled = np.zeros_like(X_seq)
        delta = np.zeros_like(X_seq)

        for f in range(F):
            last_val = self.global_mean[f]
            last_time = times[0]
            for t in range(T):
                if mask[t, f]:
                    delta[t, f] = 0.0
                    last_val = X_seq[t, f]
                    last_time = times[t]
                    X_filled[t, f] = last_val
                else:
                    delta[t, f] = times[t] - last_time
                    gamma = np.exp(-delta[t, f])
                    X_filled[t, f] = gamma * last_val + (1 - gamma) * self.global_mean[f]
                    last_val = X_filled[t, f]

        return {
            "X": torch.tensor(X_filled, dtype=torch.float32),
            "mask": torch.tensor(mask.astype(float), dtype=torch.float32),
            "delta": torch.tensor(delta, dtype=torch.float32),
            "y_reg": torch.tensor(y_reg, dtype=torch.float32),
            "y_bin": torch.tensor(y_bin, dtype=torch.float32),  # bisa float32 untuk BCE
            "window_id": torch.tensor(window_id, dtype=torch.long),
        }


def collate_fn(batch):
    X_list, mask_list, delta_list, y_reg_list, y_bin_list, window_id_list = [], [], [], [], [], []

    max_len = max(item["X"].shape[0] for item in batch)
    F = batch[0]["X"].shape[1]

    for item in batch:
        T = item["X"].shape[0]
        pad_len = max_len - T

        # pad X, mask, delta
        X_padded = torch.cat([item["X"], torch.zeros(pad_len, F)], dim=0)
        mask_padded = torch.cat([item["mask"], torch.zeros(pad_len, F)], dim=0)
        delta_padded = torch.cat([item["delta"], torch.zeros(pad_len, F)], dim=0)

        X_list.append(X_padded)
        mask_list.append(mask_padded)
        delta_list.append(delta_padded)
        y_reg_list.append(item["y_reg"])
        y_bin_list.append(item["y_bin"])
        window_id_list.append(item["window_id"])

    return {
        "X": torch.stack(X_list),
        "mask": torch.stack(mask_list),
        "delta": torch.stack(delta_list),
        "y_reg": torch.stack(y_reg_list),
        "y_bin": torch.stack(y_bin_list),
        "window_id": torch.stack(window_id_list),
    }


In [11]:
train_dataset = TemporalWindowDataset(
    X=X_train_scaled,
    y=y_train_scaled,
    stay_ids=stay_ids_train,
    times=times_train,
    global_feat_mean=global_feat_mean,
    window_sizes=(6, 12, 24),
    horizon=1
)

val_dataset = TemporalWindowDataset(
    X=X_val_scaled,
    y=y_val_scaled,
    stay_ids=stay_ids_val,
    times=times_val,
    global_feat_mean=global_feat_mean,
    window_sizes=(6, 12, 24),
    horizon=1
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    collate_fn=collate_fn, 
    pin_memory=True
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=128,
    shuffle=False,
    collate_fn=collate_fn, 
    pin_memory=True
)

In [12]:
class TemporalAttnPool(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.score = nn.Linear(d_model, 1)

    def forward(self, z, padding_mask):
        scores = self.score(z).squeeze(-1)  # [B, T]
        scores = scores.masked_fill(~padding_mask, -1e9)
        alpha = torch.softmax(scores, dim=1)
        pooled = (z * alpha.unsqueeze(-1)).sum(dim=1)
        return pooled

# ===== GRUD Transformer =====
class GRUDTransformer(nn.Module):
    def __init__(
        self,
        n_features,
        hidden_size=96,
        d_model=128,
        nhead=4,
        num_layers=2,
        reg_dim=8,
        bin_dim=1
    ):
        super().__init__()

        self.input_size = n_features * 3  # x + mask + delta
        self.gru = nn.GRU(self.input_size, hidden_size, batch_first=True)
        self.to_dmodel = nn.Linear(hidden_size, d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.attn_pool = TemporalAttnPool(d_model)

        # ===== 6-head setup =====
        self.reg_heads = nn.ModuleList([nn.Linear(d_model, reg_dim) for _ in range(3)])  # regression per window
        self.bin_heads = nn.ModuleList([nn.Linear(d_model, bin_dim) for _ in range(3)])  # binary per window

        # buat loop mudah untuk freeze/unfreeze
        self.heads = nn.ModuleList(self.reg_heads + self.bin_heads)  # total 6 head

    def forward(self, x, mask, delta, window_id=None):
        inp = torch.cat([x, mask, delta], dim=-1)
        h, _ = self.gru(inp)
        z = self.to_dmodel(h)

        time_mask = mask.sum(dim=-1) > 0
        z = self.transformer(z, src_key_padding_mask=~time_mask)
        pooled = self.attn_pool(z, padding_mask=~time_mask)

        # Ambil output head per sample
        y_reg_out = torch.zeros(x.size(0), self.reg_heads[0].out_features, device=x.device)
        y_bin_out = torch.zeros(x.size(0), self.bin_heads[0].out_features, device=x.device)

        for i, w_id in enumerate(window_id):
            y_reg_out[i] = self.reg_heads[w_id](pooled[i])
            y_bin_out[i] = self.bin_heads[w_id](pooled[i])

        return y_reg_out, y_bin_out

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
print("Supports bf16:", torch.cuda.is_bf16_supported())

Using device: cuda
Supports bf16: True


In [14]:
n_features = X_train_scaled.shape[1]
reg_dim = y_train_scaled.shape[1] - 1
bin_dim = 1

model = GRUDTransformer(
    n_features=n_features,
    hidden_size=64,
    d_model=128,
    nhead=4,
    num_layers=2,
    reg_dim=reg_dim,
    bin_dim=bin_dim
).to(device)

In [None]:
use_amp = True
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for head in model.heads:
    for p in head.parameters():
        p.requires_grad = False
for p in model.reg_heads[0].parameters():
    p.requires_grad = True
for p in model.bin_heads[0].parameters():
    p.requires_grad = True

for p in model.reg_heads[1].parameters():
    p.requires_grad = True
for p in model.bin_heads[1].parameters():
    p.requires_grad = True

optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=3e-4,
    weight_decay=1e-4
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3
)

scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


In [20]:
def load_ckpt(path, model, optimizer, scheduler, scaler, device):
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    scaler.load_state_dict(ckpt["scaler"])
    return ckpt["epoch"]

In [None]:
last_epoch = load_ckpt(
    "ckpt_epoch_5.pt",
    model,
    optimizer,
    scheduler,
    scaler,
    device
)
start_epoch = last_epoch + 1

In [22]:
def save_ckpt(path, epoch, model, optimizer, scheduler, scaler):
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler.state_dict(),
    }, path)

In [24]:
def train_phase(
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    device,
    epochs,
    phase_name="",
    use_amp=True,
    scaler=scaler,
    start_epoch=start_epoch
):
    criterion_reg = torch.nn.MSELoss(reduction='none')  # per output
    criterion_bin = torch.nn.BCEWithLogitsLoss(reduction='none')  # per output

    for epoch in range(start_epoch, epochs):
        model.train()
        train_losses = {0: {"reg": 0.0, "bin": 0.0}, 1: {"reg": 0.0, "bin": 0.0}, 2: {"reg": 0.0, "bin": 0.0}}
        train_counts = {0: 0, 1: 0, 2: 0}
        train_losses_per_output = {0: [], 1: [], 2: []}  # list of tensor (regression 8 dim)

        # ===== TRAIN =====
        for batch in tqdm(train_loader, desc=f"{phase_name} Epoch {epoch+1} [Train]"):
            X = batch["X"].to(device)
            mask = batch["mask"].to(device)
            delta = batch["delta"].to(device)
            y_reg = batch["y_reg"].to(device)
            y_bin = batch["y_bin"].to(device)
            window_id = batch["window_id"].to(device)

            optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
                y_reg_out, y_bin_out = model(X, mask, delta, window_id)
                losses = []

                for head_idx in window_id.unique():
                    idx = (window_id == head_idx)
                    if idx.any():
                        y_reg_head = y_reg_out[idx]
                        y_bin_head = y_bin_out[idx]
                        y_reg_target = y_reg[idx]
                        y_bin_target = y_bin[idx]

                        # per output loss
                        reg_loss_per_output = criterion_reg(y_reg_head, y_reg_target).mean(dim=0)  # 8 output
                        bin_loss_per_output = criterion_bin(y_bin_head.squeeze(-1), y_bin_target.squeeze(-1))
                        bin_loss_per_output = bin_loss_per_output.mean(dim=0) if bin_loss_per_output.dim() > 1 else bin_loss_per_output

                        train_losses_per_output[head_idx.item()].append(reg_loss_per_output.cpu())

                        # total head loss (avg)
                        reg_loss = reg_loss_per_output.mean()
                        bin_loss = bin_loss_per_output.mean()

                        train_losses[head_idx.item()]["reg"] += reg_loss.item()
                        train_losses[head_idx.item()]["bin"] += bin_loss.item()
                        train_counts[head_idx.item()] += 1

                        losses.append(reg_loss + bin_loss)

                loss = torch.stack(losses).mean()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # avg per head
        for h in train_losses:
            if train_counts[h] > 0:
                train_losses[h]["reg"] /= train_counts[h]
                train_losses[h]["bin"] /= train_counts[h]

        # --- PRINT TRAIN LOSS PER OUTPUT ---
        print(f"{phase_name} | Epoch {epoch+1} Train Loss per Head:")
        for h in train_losses:
            print(f"  Head {h}: Reg Loss = {train_losses[h]['reg']:.4f}, Bin Loss = {train_losses[h]['bin']:.4f}")
            # per output (gabungkan semua loss per output jadi satu baris)
            if train_losses_per_output[h]:
                per_output_mean = torch.stack(train_losses_per_output[h]).mean(dim=0)
                loss_str = ", ".join([f"{l.item():.4f}" for l in per_output_mean])
                print(f"    Reg Loss per Output: [{loss_str}]")

        # ===== VALIDATION =====
        model.eval()
        val_losses = {0: {"reg": 0.0, "bin": 0.0}, 1: {"reg": 0.0, "bin": 0.0}, 2: {"reg": 0.0, "bin": 0.0}}
        val_counts = {0: 0, 1: 0, 2: 0}
        val_losses_per_output = {0: [], 1: [], 2: []}
        sepsis_logits, sepsis_true = [], []
        reg_preds, reg_true = {0: [], 1: [], 2: []}, {0: [], 1: [], 2: []}

        with torch.no_grad():
            for batch in val_loader:
                X = batch["X"].to(device)
                mask = batch["mask"].to(device)
                delta = batch["delta"].to(device)
                y_reg = batch["y_reg"].to(device)
                y_bin = batch["y_bin"].to(device)
                window_id = batch["window_id"].to(device)

                y_reg_out, y_bin_out = model(X, mask, delta, window_id)

                for head_idx in window_id.unique():
                    idx = (window_id == head_idx)
                    if idx.any():
                        y_reg_head = y_reg_out[idx]
                        y_bin_head = y_bin_out[idx]
                        y_reg_target = y_reg[idx]
                        y_bin_target = y_bin[idx]

                        reg_loss_per_output = criterion_reg(y_reg_head, y_reg_target).mean(dim=0)
                        bin_loss_per_output = criterion_bin(y_bin_head.squeeze(-1), y_bin_target.squeeze(-1))
                        bin_loss_per_output = bin_loss_per_output.mean(dim=0) if bin_loss_per_output.dim() > 1 else bin_loss_per_output

                        val_losses_per_output[head_idx.item()].append(reg_loss_per_output.cpu())

                        reg_loss = reg_loss_per_output.mean()
                        bin_loss = bin_loss_per_output.mean()

                        val_losses[head_idx.item()]["reg"] += reg_loss.item()
                        val_losses[head_idx.item()]["bin"] += bin_loss.item()
                        val_counts[head_idx.item()] += 1

                        sepsis_logits.append(y_bin_head.squeeze(-1).cpu())
                        sepsis_true.append(y_bin_target.cpu())

                        reg_preds[head_idx.item()].append(y_reg_head.cpu())
                        reg_true[head_idx.item()].append(y_reg_target.cpu())

        # avg per head
        for h in val_losses:
            if val_counts[h] > 0:
                val_losses[h]["reg"] /= val_counts[h]
                val_losses[h]["bin"] /= val_counts[h]

        # --- compute regression metrics per head ---
        from sklearn.metrics import mean_squared_error
        reg_metrics = {}
        for h in reg_preds:
            if len(reg_preds[h]) > 0:
                preds = torch.cat(reg_preds[h], dim=0).numpy()
                targets = torch.cat(reg_true[h], dim=0).numpy()
                mse = mean_squared_error(targets, preds)
                reg_metrics[h] = mse

        # --- metrics sepsis ---
        sepsis_logits = torch.cat(sepsis_logits).view(-1)
        sepsis_true = torch.cat(sepsis_true).view(-1)
        sepsis_probs = torch.sigmoid(sepsis_logits)

        from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix
        sepsis_true_np = sepsis_true.cpu().numpy().astype(int)
        sepsis_probs_np = sepsis_probs.cpu().numpy().astype(float)
        sepsis_preds = (sepsis_probs_np > 0.5).astype(int)
        acc = accuracy_score(sepsis_true_np, sepsis_preds)
        auc = roc_auc_score(sepsis_true_np, sepsis_probs_np)
        cm = confusion_matrix(sepsis_true_np, sepsis_preds)

        print(f"{phase_name} | Epoch {epoch+1} Val Loss per Head:")
        for h in val_losses:
            print(f"  Head {h}: Reg Loss = {val_losses[h]['reg']:.4f}, Bin Loss = {val_losses[h]['bin']:.4f}, MSE = {reg_metrics.get(h, 0):.4f}")
            if val_losses_per_output[h]:
                # gabungkan semua loss per output jadi satu baris
                per_output_mean = torch.stack(val_losses_per_output[h]).mean(dim=0)
                loss_str = ", ".join([f"{l.item():.4f}" for l in per_output_mean])
                print(f"    Reg Loss per Output: [{loss_str}]")

        save_ckpt(
            path=f"ckpt_epoch_{epoch}.pt",
            epoch=epoch,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler
        )

In [25]:
train_phase(
    model=model,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    epochs=10,
    phase_name="PHASE 1 (6h)",
    use_amp=True,
    scaler=scaler,
    start_epoch=start_epoch
)

  with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
PHASE 1 (6h) Epoch 6 [Train]: 100%|██████████| 31497/31497 [1:10:53<00:00,  7.41it/s]


PHASE 1 (6h) | Epoch 6 Train Loss per Head:
  Head 0: Reg Loss = 0.2887, Bin Loss = 0.0562
    Reg Loss per Output: [0.3334, 0.3400, 0.2651, 0.2900, 0.5159, 0.3596, 0.1071, 0.0989]
  Head 1: Reg Loss = 0.2071, Bin Loss = 0.0290
    Reg Loss per Output: [0.2562, 0.2304, 0.1825, 0.2047, 0.3804, 0.2655, 0.0696, 0.0673]
  Head 2: Reg Loss = 0.1523, Bin Loss = 0.0278
    Reg Loss per Output: [0.1922, 0.1561, 0.1337, 0.1379, 0.2697, 0.2133, 0.0560, 0.0596]


  output = torch._nested_tensor_from_mask(


PHASE 1 (6h) | Epoch 6 Val Loss per Head:
  Head 0: Reg Loss = 0.9729, Bin Loss = 0.9394, MSE = 1.1631
    Reg Loss per Output: [0.8163, 0.9773, 0.7414, 0.7995, 0.9163, 0.8696, 1.1869, 1.4756]
  Head 1: Reg Loss = 0.8745, Bin Loss = 1.0620, MSE = 1.0698
    Reg Loss per Output: [0.6665, 0.8230, 0.6294, 0.5763, 0.7898, 0.7495, 1.2113, 1.5506]
  Head 2: Reg Loss = 0.6160, Bin Loss = 1.0362, MSE = 0.8814
    Reg Loss per Output: [0.3062, 0.3845, 0.3582, 0.1946, 0.3553, 0.4365, 1.2183, 1.6742]


  with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
PHASE 1 (6h) Epoch 7 [Train]:   3%|▎         | 895/31497 [02:48<1:35:46,  5.33it/s]


KeyboardInterrupt: 