Skip to content

Commit

Permalink
[Feature] Update KG methods (#381)
Browse files Browse the repository at this point in the history
  • Loading branch information
QingFei1 committed Aug 23, 2022
1 parent c7f35aa commit 4c8737c
Showing 1 changed file with 4 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import torch
import torch.nn as nn
import os
import json
import numpy as np
from .. import ModelWrapper

from cogdl.utils.link_prediction_utils import cal_mrr, DistMultLayer, ConvELayer
from cogdl.datasets.kg_data import BidirectionalOneShotIterator, TestDataset, TrainDataset
from tqdm import tqdm
import torch.nn.functional as F


Expand Down Expand Up @@ -89,12 +84,11 @@ def val_step(self, subgraph):
def eval_step(self, subgraph):
test_dataloader_head, test_dataloader_tail = subgraph
logs = []
step = 0
test_dataset_list = [test_dataloader_head, test_dataloader_tail]
total_steps = sum([len(dataset) for dataset in test_dataset_list])

for test_dataset in test_dataset_list:
for positive_sample, negative_sample, filter_bias, mode in test_dataset:
pbar = tqdm(test_dataset)
for positive_sample, negative_sample, filter_bias, mode in pbar:
pbar.set_description("Evaluating the model: Use mode({})".format(mode))
positive_sample = positive_sample.to(self.device)
negative_sample = negative_sample.to(self.device)
filter_bias = filter_bias.to(self.device)
Expand Down Expand Up @@ -131,11 +125,6 @@ def eval_step(self, subgraph):
}
)

if step % 1000 == 0:
print("Evaluating the model... (%d/%d)" % (step, total_steps))

step += 1

metrics = {}
for metric in logs[0].keys():
metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
Expand Down

0 comments on commit 4c8737c

Please sign in to comment.