In [6]:
import os
import pandas as pd
import torch
import sys; sys.path.append("..")

from typing import List
from tqdm import tqdm
from sklearn.metrics import roc_auc_score

from dataset.unified import SourceDataFrames, OneAdmOneHG
from model.backbone import BackBoneV2
from utils.metrics import convert2df, convert2df_v2, calc_gauc
from utils.misc import get_latest_model_ckpt, init_seed
from utils.config import HeteroGraphConfig, GNNConfig

In [2]:
# Hyperparameters
# Seed the run for reproducibility (the meaning of the second flag is defined
# by utils.misc.init_seed and is not visible here).
init_seed(3407, False)
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# end-to-end on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# GNN configuration
# Use every node/edge type offered by the heterogeneous-graph config.
node_types, edge_types = HeteroGraphConfig.use_all_edge_type()
# NOTE(review): "GINEConv" and 2 are presumably the conv operator and the
# number of GNN layers -- confirm against GNNConfig's signature.
gnn_conf = GNNConfig("GINEConv", 2, node_types, edge_types)

In [3]:
# dataset
# Build the path with os.path.join instead of a hard-coded backslash string:
# on POSIX systems "..\data\..." is a single (non-existent) file name, so the
# original literal only worked on Windows.
sources_dfs = SourceDataFrames(os.path.join("..", "data", "mimic-iii-clinical-database-1.4"))
# Test split served by OneAdmOneHG.
test_dataset = OneAdmOneHG(sources_dfs, "test")

In [4]:
# model init & load trained weight
model = BackBoneV2(
    sources_dfs, goal="drug", h_dim=256, gnn_conf=gnn_conf, device=device,
    num_enc_layers=3, embedding_size=10, is_gnn_only=False, init_method="xavier_normal"
).to(device)
# Join every path component separately so the checkpoint path is valid on
# both Windows and POSIX (the previous r"..\model\hub" literal was Windows-only).
sd_path = os.path.join("..", "model", "hub", "loss_0.0866_backbonev2_goal_drug.pt")
# NOTE: torch.load unpickles the file, which can execute arbitrary code. Only
# load checkpoints from a trusted source (consider weights_only=True on torch
# versions that support it).
# map_location keeps the load working when the checkpoint was saved on a
# different device than the one used here.
sd = torch.load(sd_path, map_location=device)
# Last expression of the cell: displays the key-matching report.
model.load_state_dict(sd)

<All keys matched successfully>

In [5]:
# Score the whole test split with the trained model (eval mode, no gradient
# tracking) and gather one prediction frame per dataset item.
model.eval()
collector: List[pd.DataFrame] = []
with torch.no_grad():
    progress = tqdm(enumerate(test_dataset), leave=False, ncols=100, total=len(test_dataset), ascii=True)
    for i, hg in progress:
        hg = hg.to(device)
        logits, labels, iids = model(hg)
        # `i` (the position in the dataset) is handed to convert2df_v2
        # together with the model outputs.
        frame = convert2df_v2(logits, labels, i, iids)
        collector.append(frame)
# Stack the per-item frames row-wise into a single table.
results: pd.DataFrame = pd.concat(collector)

                                                                                                    

In [7]:
# Peek at the first rows of the collected predictions as a sanity check.
results.head()

Unnamed: 0,user_id,item_id,score,label,day
0,0,2730,5.1e-05,0.0,1
1,0,1826,0.00317,0.0,1
2,0,328,0.00426,0.0,1
3,0,1956,0.945632,1.0,1
4,0,920,0.065332,0.0,1


In [None]:
# Per-user AUC together with the number of scored rows per user.
#
# roc_auc_score raises ValueError when y_true contains a single class, so a
# group is kept only if it has at least one positive label AND more than one
# distinct label value. The former `sum > 0` check alone let all-positive
# groups through, which would crash the apply below.
labels_by_user = results.groupby('user_id')['label']
has_both_classes = (labels_by_user.transform('sum') > 0) & (labels_by_user.transform('nunique') > 1)
filtered_df = results[has_both_classes]
group_counts = filtered_df.groupby('user_id').size().reset_index(name='cnt')

# Select only the columns the metric needs before `apply`, so the lambda does
# not operate on the grouping column.
auc_scores = (
    filtered_df.groupby('user_id')[['label', 'score']]
    .apply(lambda x: roc_auc_score(x['label'], x['score']))
    .reset_index()
)
auc_scores.columns = ['user_id', 'auc_score']

# Sort without inplace=True so re-running the cell is idempotent.
res = pd.merge(group_counts, auc_scores, on='user_id').sort_values(by='cnt', ascending=False)
res

Unnamed: 0,user_id,cnt,auc_score
4038,4040,8340,0.998513
23,23,7311,0.998027
92,92,7119,0.996751
4321,4323,7044,0.998691
2923,2924,7011,0.994444
...,...,...,...
1448,1448,6,1.000000
4163,4165,6,1.000000
2153,2154,6,1.000000
4327,4329,3,1.000000


In [14]:
# calc_gauc (grouped by user_id) evaluated separately for each day.
# NOTE(review): calc_gauc is presumably a group-AUC helper -- its exact
# weighting lives in utils.metrics and is not visible here.
res = results.groupby('day').apply(lambda x: calc_gauc(x, 'user_id')).reset_index()
# Give the metric column a real name instead of the default `0` produced by
# reset_index on an unnamed Series.
res.columns = ['day', 'auc_score']
res

Unnamed: 0,day,0
0,1,0.9852
1,2,0.9933
2,3,0.9944
3,4,0.9948
4,5,0.9954
5,6,0.9955
6,7,0.9953
7,8,0.9957
8,9,0.9957
9,10,0.996


In [None]:
# Single summary metric over all predictions, using 'day' as the grouping key.
# NOTE(review): the per-day table groups by 'user_id' while this call groups
# by 'day' -- confirm that 'day' is the intended grouping for the overall score.
gauc = calc_gauc(results, 'day')
print(gauc)