In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import trange

class SimpleSTRAFE(nn.Module):
    def __init__(self, vocab_size=1000, embedding_dim=64, max_visits=20, max_time=12):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_visits = max_visits
        self.max_time = max_time

        # 每个诊断 code 的嵌入
        self.concept_embedder = nn.Embedding(vocab_size, embedding_dim)

        # 每次 visit 的时间位置嵌入（可选的 index-based 时间编码）
        self.time_embedder = nn.Embedding(max_visits, embedding_dim)

        # visit 序列 → Transformer 表示
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=4, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)

        # Conv 将 visit 序列对齐到时间轴（如：48个月）
        self.conv = nn.Conv1d(in_channels=max_visits, out_channels=max_time, kernel_size=1)

        # MLP → 每月的生存风险概率
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, visit_codes, visit_times):
        """
        visit_codes: [B, V, C] -> 每次 visit 的多个诊断 code id
        visit_times: [B, V]    -> 每次 visit 的时间（整数）
        """
        B, V, C = visit_codes.shape

        # 将每个 code 嵌入后，对每次 visit 内 code 求和
        code_embed = self.concept_embedder(visit_codes)  # [B, V, C, D]
        visit_embed = code_embed.sum(dim=2)              # [B, V, D]

        # 加上 visit 时间嵌入
        time_embed = self.time_embedder(visit_times)     # [B, V, D]
        x = visit_embed + time_embed                     # [B, V, D]

        # visit 序列经过 self-attention 建模
        x = self.transformer(x)                          # [B, V, D]

        # Conv 将 V 维映射到 T 维（月）
        x = self.conv(x)                                 # [B, T, D]

        # 对每月的表示进行 MLP → 得到每月生存概率
        x = self.mlp(x).squeeze(-1)                      # [B, T]
        return x


def strafe_loss(pred_q, event_time, event_indicator):
    B, T = pred_q.shape
    eps = 1e-7
    S_hat = torch.cumprod(pred_q, dim=1)  # [B, T]

    loss = 0.0
    for i in range(B):
        T_i = event_time[i]
        S_i = S_hat[i]
        if event_indicator[i] == 1:
            pre_event = torch.log(S_i[:T_i] + eps).sum()
            post_event = torch.log(1 - S_i[T_i:] + eps).sum()
            loss -= pre_event + post_event
        else:
            censored = torch.log(S_i[:T_i] + eps).sum()
            loss -= censored
    return loss / B

def generate_toy_data(B=32, V=20, C=5, T=12, vocab_size=1000):
    visit_codes = torch.randint(0, vocab_size, (B, V, C))
    visit_times = torch.arange(V).unsqueeze(0).repeat(B, 1)
    event_time = torch.randint(4, T, (B,))
    event_indicator = torch.randint(0, 2, (B,))
    return visit_codes, visit_times, event_time, event_indicator


In [2]:
B, V, C, T = 4, 20, 5, 12  # 4个病人, 每人20次visit, 每次5个code, 输出12个月预测
dummy_codes = torch.randint(0, 1000, (B, V, C))
dummy_times = torch.arange(V).unsqueeze(0).repeat(B, 1)

model = SimpleSTRAFE()
output = model(dummy_codes, dummy_times)

print(output.shape)  # 应该是 [4, 12]


torch.Size([4, 12])


In [3]:
output

tensor([[0.4505, 0.4602, 0.4558, 0.4589, 0.4776, 0.4017, 0.4875, 0.5181, 0.4545,
         0.4272, 0.4611, 0.4728],
        [0.4362, 0.4495, 0.4750, 0.4641, 0.4978, 0.4602, 0.4743, 0.4571, 0.4080,
         0.4644, 0.4563, 0.4638],
        [0.4627, 0.4417, 0.4815, 0.4510, 0.4749, 0.4213, 0.4380, 0.5073, 0.4812,
         0.5123, 0.4559, 0.4555],
        [0.4153, 0.4191, 0.4433, 0.4569, 0.4865, 0.4641, 0.4861, 0.4699, 0.4525,
         0.4798, 0.4738, 0.4324]], grad_fn=<SqueezeBackward1>)

In [6]:

model = SimpleSTRAFE()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in trange(10, desc="Training"):
    model.train()
    visit_codes, visit_times, event_time, event_indicator = generate_toy_data()
    pred_q = model(visit_codes, visit_times)
    loss = strafe_loss(pred_q, event_time, event_indicator)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")


Training:  70%|███████   | 7/10 [00:00<00:00, 32.03it/s]

Epoch 1, Loss: 25.1079
Epoch 2, Loss: 17.0588
Epoch 3, Loss: 16.3907
Epoch 4, Loss: 15.0230
Epoch 5, Loss: 12.8963
Epoch 6, Loss: 10.8966
Epoch 7, Loss: 10.5710
Epoch 8, Loss: 9.9704


Training: 100%|██████████| 10/10 [00:00<00:00, 33.54it/s]

Epoch 9, Loss: 9.1118
Epoch 10, Loss: 7.8756





In [7]:
model.eval()
with torch.no_grad():
    visit_codes, visit_times, event_time, event_indicator = generate_toy_data(B=4)
    pred_q = model(visit_codes, visit_times)
    S_hat = torch.cumprod(pred_q, dim=1)
    print("Predicted survival probabilities:\n", S_hat)

Predicted survival probabilities:
 tensor([[0.8595, 0.6830, 0.6702, 0.5311, 0.3914, 0.2687, 0.1859, 0.1058, 0.0726,
         0.0467, 0.0347, 0.0206],
        [0.8587, 0.6879, 0.6751, 0.5367, 0.3915, 0.2686, 0.1863, 0.1052, 0.0717,
         0.0476, 0.0353, 0.0211],
        [0.8549, 0.6867, 0.6738, 0.5355, 0.3922, 0.2689, 0.1875, 0.1062, 0.0728,
         0.0484, 0.0358, 0.0213],
        [0.8652, 0.6903, 0.6776, 0.5379, 0.3903, 0.2663, 0.1855, 0.1046, 0.0714,
         0.0467, 0.0343, 0.0202]])
