In [3]:
# ! pip install lifelines
# ! pip install scikit-survival

In [4]:
import torch
import torch.nn.functional as F
import numpy as np
from lifelines.utils import concordance_index
from sksurv.metrics import integrated_brier_score
from sksurv.util import Surv
from torch.utils.data import DataLoader
import pandas as pd

import modules.DataAnalysis as DataAnalysis
import modules.ModelAnalysis as ModelAnalysis
import modules.DataModify as DataModify
from modules.DataSelect import DataPreprocessing

import modules.Models as Models

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
def evaluate_deephit(model, test_loader, y_train, y_test, device='cuda'):
    """
    Evaluate a DeepHit model on a test set and print/return two metrics.

    Metrics
    -------
    - Concordance index (C-index), via ``concordance_index``.
    - Integrated Brier Score (IBS), via ``integrated_brier_score``.

    The last time bin of the model output is treated as a dummy bin and is
    dropped before any metric is computed.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(x)``; must return a 3-tuple whose second and third
        items are the PMF and CIF tensors, indexed as
        (batch, event, time_bin).
    test_loader : iterable
        Yields ``(x, times, events)`` batches of tensors.
    y_train : structured array
        Passed unchanged as the first argument of ``integrated_brier_score``.
    y_test : structured array
        Not used by this function; kept so existing callers keep working.
        The test survival array is rebuilt from the batches of
        ``test_loader`` instead.
    device : str or torch.device
        Device the input batches are moved to before the forward pass.

    Returns
    -------
    tuple
        ``(c_index, ibs)``.

    Raises
    ------
    ValueError
        If fewer than two evaluation time points remain after restricting
        the time grid to the observed follow-up range and the available
        time bins.
    """
    model.eval()
    all_risk = []
    all_surv = []
    all_times = []
    all_events = []

    with torch.no_grad():
        for x, times, events in test_loader:
            x = x.to(device)
            _, pmf, cif = model(x)  # pmf, cif: (B, num_events, time_bins)

            # Drop the trailing dummy time bin.
            pmf = pmf[:, :, :-1]       # (B, num_events, time_bins-1)
            cif = cif[:, :, :-1]

            # Overall survival = 1 - sum of the per-event CIFs.
            survival = 1 - cif.sum(dim=1)  # (B, time_bins-1)

            # Risk score: total PMF mass over all events and remaining bins.
            risk_score = pmf.sum(dim=(1, 2))  # (B,)

            all_risk.append(risk_score.cpu())
            all_surv.append(survival.cpu())
            all_times.append(times.cpu())
            all_events.append(events.cpu())

    # Tensor -> NumPy
    risk_score = torch.cat(all_risk).numpy()
    survival = torch.cat(all_surv).numpy()
    times = torch.cat(all_times).numpy()
    events = torch.cat(all_events).numpy()

    # Concordance index. The score is negated so that a higher risk maps to
    # a lower predicted score (i.e. an earlier expected event).
    c_index = concordance_index(
        event_times=times,
        predicted_scores=-risk_score,
        event_observed=events
    )

    # Integrated Brier Score.
    y_test_surv = Surv.from_arrays(
        event=events.astype(bool),
        time=times.astype(float)
    )

    # Build the evaluation grid.
    # NOTE(review): column t of `survival` is paired with time t, which
    # assumes observed times are expressed in time-bin units - TODO confirm.
    observed_times = y_test_surv["time"]
    n_bins = survival.shape[1]
    # Lower bound: never evaluate before the earliest observed time.
    first_time = max(0, int(np.ceil(observed_times.min())))
    # Upper bound (exclusive): stay strictly below the largest observed time
    # and never index past the available survival columns.
    last_time = min(int(observed_times.max()), n_bins)
    if last_time - first_time < 2:
        raise ValueError(
            "Need at least two evaluation time points for the IBS; got the "
            f"range [{first_time}, {last_time}) from observed times in "
            f"[{observed_times.min()}, {observed_times.max()}] and "
            f"{n_bins} time bins."
        )

    eval_times = np.arange(first_time, last_time)
    survival = survival[:, first_time:last_time]

    ibs = integrated_brier_score(y_train, y_test_surv, survival, eval_times)

    # Report
    print(f"Concordance Index (C-index): {c_index:.4f}")
    print(f"Integrated Brier Score (IBS): {ibs:.4f}")

    return c_index, ibs

In [6]:
# Read the raw test CSV, strip its first column, and save the cleaned copy.
df = pd.read_csv('./data/test dataset.csv')
df = df.drop(columns=df.columns[0])
df.to_csv('./data/test dataset_fixed.csv', index=False)

# Build the dataset and a fixed-order loader from the cleaned file.
test_file = ['./data/test dataset_fixed.csv']
test_dataset = DataModify.CancerDataset(
    file_paths=test_file,
    time_column='time',
    target_column='event',
)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)

# Structured survival array of the test split, used for the IBS computation.
# With only y_test available, a dummy of the same format can be passed as
# the training array when computing the IBS.
test_times = test_dataset.time.numpy()
test_events = test_dataset.target.numpy()
y_test = Surv.from_arrays(
    time=test_times.astype(float),
    event=test_events.astype(bool),
)


In [7]:
# Checkpoint holding the trained weights.
input_params_path = './parameters/deephit_model_2D_CNN.pth'

# Architecture hyper-parameters.
input_dim = 17            # number of input features
hidden_size = (128, 64)   # sizes of the 1st and 2nd hidden layers
time_bins = 91            # time split into 3-month intervals; 270+ months share one bin
num_events = 4            # number of event types

# Build the network on the target device.
model = Models.DeepHitSurvWithSEBlockAnd2DCNN(
    input_dim, hidden_size, time_bins, num_events, dropout=0.2
).to(device)

# Restore the trained weights, mapping stored tensors onto the target device.
state_dict = torch.load(input_params_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # evaluation mode

DeepHitSurvWithSEBlockAnd2DCNN(
  (se_block): Sequential(
    (0): Linear(in_features=17, out_features=4, bias=True)
    (1): ReLU()
    (2): Linear(in_features=4, out_features=17, bias=True)
    (3): Sigmoid()
  )
  (se_block_event): ModuleList(
    (0-3): 4 x Sequential(
      (0): Linear(in_features=64, out_features=16, bias=True)
      (1): ReLU()
      (2): Linear(in_features=16, out_features=64, bias=True)
      (3): Sigmoid()
    )
  )
  (shared): Sequential(
    (0): Linear(in_features=17, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.2, inplace=False)
  )
  (heads): ModuleList(
    (0-3): 4 x Linear(in_features=64, out_features=91, bias=True)
  )
  (conv2d_block): Sequential(
    (0): Conv2d(1, 8, kernel_size=(2, 5), stride=(1, 1), padding=(1, 2))
    (1): ReLU()
    (2): Conv2d(8, 16, kernel_size=(2, 3), stride=(1, 1), padding=(0, 1))
    (3)

In [8]:
# A copy of the test survival array stands in for the training array.
y_train_dummy = y_test.copy()

# Run the evaluation.
c_index, ibs = evaluate_deephit(
    model=model,
    test_loader=test_loader,
    y_train=y_train_dummy,
    y_test=y_test,
    device=device,
)


Concordance Index (C-index): 0.8263
Integrated Brier Score (IBS): 0.2005
