In [1]:
import pandas as pd
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# Seed the RNGs used by this notebook so model initialisation is reproducible
# under Restart & Run All (torch.manual_seed also seeds every CUDA device).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

DATA_PATH = "/content/drive/MyDrive/世界モデル_最終課題/timeseries.csv"
df_raw = pd.read_csv(DATA_PATH)

ID_COL = "patient_id"
TIME_COL = "time_days"

# Model inputs: every column except the id/time keys and the met_* / tx_*
# column families, which are excluded from the feature set.
feature_cols = [
    c for c in df_raw.columns
    if c not in (ID_COL, TIME_COL)
    and not c.startswith(("met_", "tx_"))
]

F = len(feature_cols)
print("num features:", F)

# Standardise into a copy so the unscaled data stays available as df_raw;
# downstream cells read the scaled frame as `df`.
# NOTE(review): the scaler is fit on all rows (there is no train/validation
# split), so any later held-out evaluation would leak scaling statistics.
scaler = StandardScaler()
df = df_raw.copy()
df[feature_cols] = scaler.fit_transform(df_raw[feature_cols])

def make_triples(df, cols=None, id_col=None, time_col=None):
    """Build per-patient sliding windows (x_t, x_{t+1}, x_{t+2}).

    (x_t, x_{t+1}) is used to build the item graph and x_{t+2} is the
    prediction target.

    Args:
        df: long-format frame with one row per (patient, time point).
        cols: feature columns to extract; defaults to the notebook-level
            ``feature_cols``.
        id_col: patient-id column; defaults to ``ID_COL``.
        time_col: time column; defaults to ``TIME_COL``.

    Returns:
        A list of dicts with keys "pid", "x_t", "x_tp1", "x_tp2" (1-D feature
        arrays) and "dt01", "dt12" (float time gaps). Windows are in time order
        within each patient, and patients are visited in groupby order.
        Patients with fewer than 3 rows contribute nothing.
    """
    # Fall back to the notebook globals only when the caller does not pass a value.
    cols = feature_cols if cols is None else cols
    id_col = ID_COL if id_col is None else id_col
    time_col = TIME_COL if time_col is None else time_col

    triples = []
    for pid, g in df.groupby(id_col):
        # A stable sort keeps rows with duplicate timestamps in a deterministic order.
        g = g.sort_values(time_col, kind="stable")
        x = g[cols].to_numpy()
        t = g[time_col].to_numpy()

        # Each window needs rows i, i+1 and i+2, so the last start index is len(g) - 3.
        for i in range(len(g) - 2):
            triples.append({
                "pid": pid,
                "x_t": x[i],
                "x_tp1": x[i + 1],
                "x_tp2": x[i + 2],
                "dt01": float(t[i + 1] - t[i]),
                "dt12": float(t[i + 2] - t[i + 1]),
            })
    return triples

# Flatten every patient's time series into overlapping 3-step windows.
triples = make_triples(df)
n_triples = len(triples)
print("num triples:", n_triples)

import torch.nn as nn
import torch.nn.functional as torch_F

class CrossTimeAttention(nn.Module):
    """Scaled dot-product attention scores from items at time t to items at time t+1.

    Row i of the returned matrix is a softmax distribution over the items of
    ``x_tp1``, i.e. how strongly item i at time t attends to each item at
    time t+1. Downstream code uses this matrix as a dense, weighted, directed
    adjacency matrix.

    Args:
        in_dim: size of each item's input vector.
        attn_dim: size of the query/key projections.
    """

    def __init__(self, in_dim, attn_dim):
        super().__init__()
        self.q = nn.Linear(in_dim, attn_dim, bias=False)
        self.k = nn.Linear(in_dim, attn_dim, bias=False)

    def forward(self, x_t, x_tp1):
        """
        Args:
            x_t: (..., N, in_dim) item vectors at time t (queries).
            x_tp1: (..., M, in_dim) item vectors at time t+1 (keys).
                Leading batch dims are optional and must broadcast.

        Returns:
            (..., N, M) attention matrix whose rows each sum to 1.
        """
        Q = self.q(x_t)
        K = self.k(x_tp1)
        # transpose(-2, -1) rather than .T, which reverses *all* dims of an N-D
        # tensor; for 2-D inputs the two are identical.
        A = Q @ K.transpose(-2, -1) / np.sqrt(Q.size(-1))
        return torch_F.softmax(A, dim=-1)  # normalise each query row

!pip -q install torch-geometric

from torch_geometric.nn import GCNConv, dense_diff_pool

class DiffPoolItemNet(nn.Module):
    """One DiffPool layer applied to a single dense item graph.

    Two GCN branches read the same node features and the same edges. One
    produces node embeddings, the other per-node cluster-assignment scores.
    ``dense_diff_pool`` then pools the embeddings into ``num_clusters`` nodes.

    Args:
        in_dim: node feature size.
        hidden_dim: embedding size produced by the embedding GCN.
        num_clusters: number of pooled clusters (output size of the pooling GCN).
    """

    def __init__(self, in_dim, hidden_dim, num_clusters):
        super().__init__()
        self.gnn_embed = GCNConv(in_dim, hidden_dim)
        self.gnn_pool = GCNConv(in_dim, num_clusters)

    def forward(self, x, adj):
        """
        Args:
            x: (N, in_dim) node features.
            adj: (N, N) dense weighted adjacency; every nonzero entry becomes an edge.

        Returns:
            A tuple (x_pool, adj_pool, s):
            - x_pool: pooled node embeddings, with the batch axis removed.
            - adj_pool: pooled adjacency, with the batch axis removed.
            - s: (N, num_clusters) raw (un-normalised) scores from ``gnn_pool``.
            The two auxiliary outputs of ``dense_diff_pool`` are discarded.
        """
        # Sparse view of the dense adjacency for GCNConv: (row, col) pairs of
        # the nonzero entries in row-major order, with their values as edge
        # weights. Indexing keeps the weights differentiable w.r.t. adj.
        rows, cols = adj.nonzero(as_tuple=True)
        edge_index = torch.stack((rows, cols))
        edge_weight = adj[rows, cols]

        embeddings = self.gnn_embed(x, edge_index, edge_weight)
        assign_scores = self.gnn_pool(x, edge_index, edge_weight)

        # dense_diff_pool works on batched tensors: add a batch axis of size 1,
        # then strip it from the pooled outputs.
        pooled_x, pooled_adj, _aux_a, _aux_b = dense_diff_pool(
            embeddings[None], adj[None], assign_scores[None]
        )
        return pooled_x[0], pooled_adj[0], assign_scores

device: cuda
num features: 42
num triples: 123742


In [2]:
class ItemGraphEncoder(nn.Module):
    """Encode one (x_t, x_{t+1}) window into a fixed-size latent vector.

    A cross-time attention matrix between the items (features) at t and t+1
    serves as a dense item graph. Each item carries 4 channels
    [x_t, x_{t+1}, x_{t+1} - x_t, dt01]. DiffPool pools that graph into
    ``num_clusters`` nodes, which are flattened into the latent vector.

    Args:
        num_clusters: number of DiffPool clusters.
        F_features: number of items per time point (kept for compatibility;
            forward derives the item count from the input itself).
        hidden_dim: DiffPool embedding size (default 16, the previous value).
        attn_dim: query/key size of the cross-time attention (default 16).
    """

    def __init__(self, num_clusters, F_features, hidden_dim=16, attn_dim=16):
        super().__init__()
        self.F_features = F_features
        # Each item is a scalar at each time point, hence in_dim=1.
        self.attn = CrossTimeAttention(1, attn_dim)
        # 4 input channels per item: x_t, x_{t+1}, delta, dt01 (see forward).
        self.diffpool = DiffPoolItemNet(4, hidden_dim, num_clusters)
        # Derived from hidden_dim so the two values cannot drift apart.
        self.latent_dim = num_clusters * hidden_dim

    def forward(self, triple):
        """
        Args:
            triple: one window dict as built by make_triples; the keys
                "x_t", "x_tp1" and "dt01" are read.

        Returns:
            A tuple (z, S):
            - z: flattened pooled embeddings; this is the per-step GRU input.
            - S: the assignment scores returned by DiffPoolItemNet.
        """
        # Build inputs on the module's own device rather than a notebook
        # global, so the encoder keeps working after .to() / .cpu().
        dev = self.attn.q.weight.device
        x_t = torch.as_tensor(triple["x_t"], dtype=torch.float, device=dev)
        x_tp1 = torch.as_tensor(triple["x_tp1"], dtype=torch.float, device=dev)

        # Each item becomes a length-1 vector: (F, 1) -> (F, F) attention.
        A = self.attn(x_t.unsqueeze(1), x_tp1.unsqueeze(1))

        # NOTE(review): dt01 is in raw time units while the features are
        # standardised; consider scaling it too.
        dt_channel = torch.full_like(x_t, float(triple["dt01"]))
        H = torch.stack([x_t, x_tp1, x_tp1 - x_t, dt_channel], dim=1)  # (F, 4)

        x_pool, _, S = self.diffpool(H, A)
        z = x_pool.flatten()  # fed to the GRU as one time step
        return z, S

In [3]:
class PatientGRUModel(nn.Module):
    """GRU over a patient's sequence of window embeddings.

    The GRU output at the last time step is decoded into a feature vector.

    Args:
        latent_dim: size of each step's input vector (the encoder's latent_dim).
        F: number of output features.
        hidden_size: GRU hidden size (default 128, the previously hard-coded value).
    """

    def __init__(self, latent_dim, F, hidden_size=128):
        super().__init__()
        self.gru = nn.GRU(
            input_size=latent_dim,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.decoder = nn.Linear(hidden_size, F)

    def forward(self, z_seq):
        """
        Args:
            z_seq: (B, T, latent_dim) batch-first sequence, or an unbatched
                (T, latent_dim) sequence, which is treated as B = 1.

        Returns:
            The prediction from the last time step: shape (F,) when B == 1,
            otherwise (B, F).
        """
        # Without a batch axis, out[:, -1] below would select the last hidden
        # unit instead of the last time step, so add the axis explicitly.
        if z_seq.dim() == 2:
            z_seq = z_seq.unsqueeze(0)
        out, _ = self.gru(z_seq)
        h_last = out[:, -1]  # (B, hidden_size)
        return self.decoder(h_last).squeeze(0)


In [4]:
from collections import defaultdict

# Group the flat window list by patient. Appending in list order keeps each
# patient's windows in the time order produced by make_triples.
triples_by_pid = defaultdict(list)
for window in triples:
    patient_key = window["pid"]
    triples_by_pid[patient_key].append(window)

In [5]:
# Model / optimiser hyper-parameters.
NUM_CLUSTERS = 3
LEARNING_RATE = 1e-3

encoder = ItemGraphEncoder(num_clusters=NUM_CLUSTERS, F_features=F).to(device)
gru_model = PatientGRUModel(latent_dim=encoder.latent_dim, F=F).to(device)

# One Adam optimiser updates the graph encoder and the GRU jointly.
trainable_params = [*encoder.parameters(), *gru_model.parameters()]
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE)


In [None]:
num_epochs = 50
LOG_EVERY = 5

for epoch in range(num_epochs):
    encoder.train()
    gru_model.train()

    total_epoch_loss = 0.0
    n_pred_epoch = 0

    for pid, patient_triples in triples_by_pid.items():
        # The first window is used only as GRU context (no prediction is made
        # from step 0), so a patient needs >= 2 windows to contribute any loss.
        if len(patient_triples) < 2:
            continue

        # --- 1. Encode every window of this patient: (1, T, D) ---
        z_seq = torch.stack([encoder(tr)[0] for tr in patient_triples]).unsqueeze(0)

        # --- 2. Predict x_{t+2} of window t from z_0..z_t, for t = 1..T-1 ---
        # The unidirectional GRU is causal: its output at step t depends only
        # on z_0..z_t. One pass over the full sequence therefore yields the
        # same predictions as calling gru_model(z_seq[:, :t+1]) for every t,
        # but in O(T) GRU steps instead of O(T^2).
        gru_out, _ = gru_model.gru(z_seq)              # (1, T, hidden)
        preds = gru_model.decoder(gru_out[0, 1:])      # (T-1, F)
        targets = torch.as_tensor(
            np.stack([tr["x_tp2"] for tr in patient_triples[1:]]),
            dtype=torch.float,
            device=device,
        )                                              # (T-1, F)

        # Every step has F targets, so the overall mean equals the mean of the
        # per-step MSEs.
        patient_loss = torch_F.mse_loss(preds, targets)
        n_pred_patient = len(patient_triples) - 1

        # One optimiser step per patient.
        optimizer.zero_grad()
        patient_loss.backward()
        optimizer.step()

        # Weight by step count so the epoch figure is a per-prediction average.
        total_epoch_loss += patient_loss.item() * n_pred_patient
        n_pred_epoch += n_pred_patient

    avg_epoch_loss = total_epoch_loss / n_pred_epoch if n_pred_epoch > 0 else 0.0

    # Also log the final epoch, which `epoch % LOG_EVERY == 0` alone can skip.
    if epoch % LOG_EVERY == 0 or epoch == num_epochs - 1:
        print(f"epoch {epoch}: loss={avg_epoch_loss:.4f}")

In [None]:
# Inspect the learned item -> cluster assignment for one example window.
# Select the patient explicitly instead of relying on `patient_triples`
# leaking out of the training loop. The last patient in `triples_by_pid` is
# the one that variable was left pointing to, so the result is unchanged.
example_pid = list(triples_by_pid)[-1]
example_triples = triples_by_pid[example_pid]

encoder.eval()
with torch.no_grad():
    _, S = encoder(example_triples[0])

# The encoder returns the pooling GCN's raw scores; normalise over clusters to
# get per-item assignment probabilities for the heatmap.
S = S.softmax(dim=1).cpu().numpy()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Heatmap of per-item cluster-assignment probabilities (rows: features, cols: clusters).
# Figure height grows with the number of features so the row labels stay readable.
fig, ax = plt.subplots(figsize=(8, max(4, F * 0.15)))
sns.heatmap(
    S,
    cmap="viridis",
    yticklabels=feature_cols,
    xticklabels=[f"C{k}" for k in range(S.shape[1])],
    ax=ax,
)
ax.set_title("Item → Cluster Assignment (from x_t, x_{t+1})")
ax.set_xlabel("cluster")
ax.set_ylabel("feature")
fig.tight_layout()
plt.show()