In [1]:
# coding=utf-8
import torch
from torch_geometric import transforms as T
# Seed PyTorch's global random number generator so repeated runs draw the
# same random numbers.
torch.manual_seed(3407)

# Use the first CUDA device when CUDA is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


### Step 1 定义模型和图数据

In [2]:
from src.model import GCN, DeepGCN, GAT
from src.tools import homo_data, Study, split_homo_graph

# Preprocessing pipeline handed to the loader: NormalizeFeatures first, then
# a move of the graph onto `device`.
transform = T.Compose(
    [
        T.NormalizeFeatures(),
        T.ToDevice(device),
    ]
)

# Load the graph from ./data/ and also get back the id mapping.
# NOTE(review): fill_mode="stats" presumably selects how missing values are
# filled -- confirm against src.tools.homo_data.
data, id_mapping = homo_data(
    "./data/",
    transform=transform,
    fill_mode="stats",
    return_id_mapping=True,
)

# Look up ids 1..790 in the "patient" mapping and pass the results to
# split_homo_graph.
# NOTE(review): presumably these are the patients carved out into
# `test_graph`; confirm the (test, train) return order in src.tools.
test_graph, train_graph = split_homo_graph(
    data,
    [id_mapping["patient"][patient_id] for patient_id in range(1, 791)],
)

# Random link-level split configured with 10% validation and 10% test edges.
split = T.RandomLinkSplit(num_val=0.1, num_test=0.1)

In [10]:
# Pick exactly one architecture. The alternatives are left commented out so
# they can be swapped in by hand.
# NOTE(review): 128 and 64 are presumably the hidden and output sizes; the
# extra trailing argument of DeepGCN / GAT is not documented here -- confirm
# against src.model.
# model = DeepGCN(data.num_features, 128, 64, 14).to(device)
# model = GAT(data.num_features, 128, 64, 5).to(device)
model = GCN(data.num_features, 128, 64)
model = model.to(device)

# Bundle the model with the training graph and the link-split transform.
study = Study(model, train_graph, split)

### Step 2 训练

In [14]:
# Number of training epochs.
NUM_EPOCH = 100
# Directory passed to train() as save_dir, namespaced by the model's `name`.
CKPT_DIR = f"./ckpt/{model.name}/"

# Binary cross-entropy computed on raw logits.
criterion = torch.nn.BCEWithLogitsLoss()
# Adam with weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,
    weight_decay=5e-4,
)

# NOTE(review): presumably writes one checkpoint per epoch into CKPT_DIR;
# confirm in src.tools.Study.train.
study.train(NUM_EPOCH, optimizer, criterion, save_dir=CKPT_DIR)

Epoch: 1/100, Loss: 0.6847, Val: 0.8977, Test: 0.8916
Epoch: 2/100, Loss: 0.6846, Val: 0.8977, Test: 0.8916
Epoch: 3/100, Loss: 0.6844, Val: 0.8977, Test: 0.8916
Epoch: 4/100, Loss: 0.6846, Val: 0.8977, Test: 0.8916
Epoch: 5/100, Loss: 0.6844, Val: 0.8977, Test: 0.8916
Epoch: 6/100, Loss: 0.6844, Val: 0.8977, Test: 0.8916
Epoch: 7/100, Loss: 0.6844, Val: 0.8977, Test: 0.8916
Epoch: 8/100, Loss: 0.6842, Val: 0.8977, Test: 0.8916
Epoch: 9/100, Loss: 0.6844, Val: 0.8977, Test: 0.8916
Epoch: 10/100, Loss: 0.6844, Val: 0.8976, Test: 0.8916
Epoch: 11/100, Loss: 0.6842, Val: 0.8976, Test: 0.8916
Epoch: 12/100, Loss: 0.6844, Val: 0.8976, Test: 0.8916
Epoch: 13/100, Loss: 0.6844, Val: 0.8976, Test: 0.8915
Epoch: 14/100, Loss: 0.6842, Val: 0.8976, Test: 0.8915
Epoch: 15/100, Loss: 0.6842, Val: 0.8975, Test: 0.8915
Epoch: 16/100, Loss: 0.6841, Val: 0.8975, Test: 0.8914
Epoch: 17/100, Loss: 0.6839, Val: 0.8974, Test: 0.8914
Epoch: 18/100, Loss: 0.6840, Val: 0.8974, Test: 0.8913
Epoch: 19/100, Loss

### Step 3 测试

In [12]:
# Evaluate the model as it stands after the last training epoch.
# NOTE(review): the literal 7 is presumably the top-K cut-off that
# find_best reports as `topk` -- confirm against Study.test.
(precision, recall, f1_score) = study.test(test_graph, 7)

# Report the three metrics, one per line.
print(f"precision: {precision:.4f}\nrecall: {recall:.4f}\nf1_score: {f1_score}")

precision: 0.3058
recall: 0.4287
f1_score: 0.33144882321357727


### Step 4 寻找最佳模型和最佳的topK
这里的 `topK` 指的是:在病人与药品之间的候选边中,取预测概率最高的前 `K` 条边。

In [13]:
# Search for the best checkpoint / top-K combination on `test_graph`.
# NOTE(review): the selection is scored on the same graph used for the final
# test numbers, so the reported best score may be optimistic -- consider a
# separate validation graph.
study.find_best(
    CKPT_DIR,
    test_graph,
    NUM_EPOCH * 2,  # presumably the number of search trials -- confirm
    range(1, NUM_EPOCH + 1),  # candidate values 1..NUM_EPOCH (logged as 'model_id')
    range(1, 20),  # candidate values 1..19 (logged as 'topk')
)

[I 2024-04-25 12:06:51,537] A new study created in memory with name: no-name-a4c98bb7-ef2c-44f1-be21-266dc20588f5
[I 2024-04-25 12:06:51,553] Trial 0 finished with value: 0.32252299785614014 and parameters: {'model_id': 38, 'topk': 10}. Best is trial 0 with value: 0.32252299785614014.
[I 2024-04-25 12:06:51,581] Trial 1 finished with value: 0.2961112856864929 and parameters: {'model_id': 20, 'topk': 14}. Best is trial 0 with value: 0.32252299785614014.
[I 2024-04-25 12:06:51,603] Trial 2 finished with value: 0.27908918261528015 and parameters: {'model_id': 34, 'topk': 17}. Best is trial 0 with value: 0.32252299785614014.
[I 2024-04-25 12:06:51,616] Trial 3 finished with value: 0.2736777067184448 and parameters: {'model_id': 61, 'topk': 18}. Best is trial 0 with value: 0.32252299785614014.
[I 2024-04-25 12:06:51,629] Trial 4 finished with value: 0.29640209674835205 and parameters: {'model_id': 35, 'topk': 14}. Best is trial 0 with value: 0.32252299785614014.
[I 2024-04-25 12:06:51,662] 

Best F1 score: 0.33144882321357727
Optimal parameters: {'model_id': 39, 'topk': 7}
