In [31]:
import time
import torch
import json
import numpy as np
from torch.utils.data import DataLoader

In [32]:
from src.loss import *
from src.dataset import KGData, Dataset
from src.utils import *
from src.model import GAEA

In [33]:
# Task identifier; it selects the hyper-parameter file under args/.
current_task = "en_fr_15k"
# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The context manager closes the file on exit, so no explicit close() is
# needed. The encoding is pinned so the JSON is read the same way on every
# platform.
with open(f"args/{current_task}.json", "r", encoding="utf-8") as f:
    args_dict = json.load(f)
class ARGs:
    """Attribute-style view over a settings dictionary.

    Each key of the mapping given to the constructor becomes an instance
    attribute that holds the corresponding value.
    """

    def __init__(self, dic):
        for name in dic:
            setattr(self, name, dic[name])
# Expose the loaded hyper-parameters as attributes (args.lr, args.epoch, ...).
args = ARGs(args_dict)

In [34]:
# Build the knowledge-graph container for the task and fold named in the config.
kgdata = KGData(
    model="gaea",
    task=args.task,
    fold=args.fold,
    neg_samples_size=args.neg_samples_size,
    train_ratio=0.3,
    val=True,
    direct=True,
    device=device,
)
kgdata.data_summary()

# The batch size equals the number of training pairs and shuffling is off,
# so the loader serves the training pairs in their stored order.
batchsize = kgdata.train_pair_size
train_dataset = Dataset(np.array(kgdata.mapped_train_pair))
train_loader = DataLoader(train_dataset, batch_size=batchsize, shuffle=False)

[loading KG data...]

--------------dataset summary--------------

current task: en_fr_15k, file direction: data/OpenEA_dataset_v2.0/EN_FR_15K_V1

current fold: 1

entity num: 30000, duplicated entity num: 0

entity num of kg1: 15000, min index: 0, max index: 14999

entity num of kg2: 15000, min index: 15000, max index: 29999

relation num: 342

triple num: 88198

training samples: 3000, test samples: 10500

validation samples: 1500

-------------------------------------------



In [35]:
# Model, optimizer and loss set-up.
# The seed helper is called before the model is constructed.
set_random_seed(args.seed)

model = GAEA(
    num_sr=kgdata.kg1_ent_num,
    num_tg=kgdata.kg2_ent_num,
    adj_sr=kgdata.tensor_adj1,
    adj_tg=kgdata.tensor_adj2,
    rel_num=kgdata.rel_num,
    rel_adj_sr=kgdata.tensor_rel_adj1,
    rel_adj_tg=kgdata.tensor_rel_adj2,
    args=args,
)
model = model.to(device)

# Only parameters that require gradients are handed to Adam.
optimizer = torch.optim.Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=args.lr,
    weight_decay=args.weight_decay,
)
loss_fn = margin_based_loss

- current seed is 42



In [36]:
def evaluate(test_pair, k, eval_metric, phase="test", model_path=None):
    """Score the current model on a set of aligned entity pairs.

    Relies on the notebook globals ``model``, ``args`` and ``device``.

    Args:
        test_pair: iterable of ``(e1, e2)`` index pairs; ``e1`` indexes the
            source embedding matrix and ``e2`` the target embedding matrix.
        k: cut-off forwarded to the metric helpers for the hit@k score.
        eval_metric: ``"euclidean"`` routes to ``cal_metrics_faiss``; any
            other value is forwarded to ``cal_distance`` (with
            ``args.csls_k``) and the result is scored by ``cal_metrics``.
        phase: ``"test"`` additionally reloads a saved checkpoint when both
            ``args.val`` and ``args.save`` are set; any other value scores
            the weights currently in memory.
        model_path: checkpoint file to load in the test phase. When omitted,
            the notebook-level ``model_path`` variable is used if it exists.

    Returns:
        The tuple ``(hit_1_score, hit_k_score, mrr, mean)`` produced by the
        metric helper.
    """
    if phase == "test" and args.val and args.save:
        # The training cell writes its checkpoint to the notebook-level
        # ``model_path``; fall back to it when the caller passes nothing
        # instead of failing on ``open(None)``.
        if model_path is None:
            model_path = globals().get("model_path")
        if model_path is None:
            print("[warning: no checkpoint path available, evaluating the weights currently in memory]\n")
        else:
            with open(model_path, "rb") as f:
                # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
                model.load_state_dict(torch.load(f, map_location=device))

    model.eval()

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        sr_embedding, tg_embedding = model(phase="eval")
    # Tensors must be moved to the CPU before they can be converted to numpy.
    sr_embedding = sr_embedding.detach().cpu().numpy()
    tg_embedding = tg_embedding.detach().cpu().numpy()

    # Row i of Lvec / Rvec holds the embeddings of the i-th pair.
    Lvec = np.array([sr_embedding[e1] for e1, _ in test_pair])
    Rvec = np.array([tg_embedding[e2] for _, e2 in test_pair])
    del sr_embedding, tg_embedding

    if eval_metric == "euclidean":
        hit_1_score, hit_k_score, mrr, mean = cal_metrics_faiss(Lvec, Rvec, test_pair, k=k)
        del Lvec, Rvec
    else:
        # Step 1: build the similarity matrix. Row i and column i both belong
        # to pair i, so the ground-truth alignment lies on the diagonal.
        similarity_matrix = cal_distance(Lvec, Rvec, eval_metric, csls_k=args.csls_k)
        # Step 2: derive hit@1, hit@k, MRR and mean rank from it.
        hit_1_score, hit_k_score, mrr, mean = cal_metrics(similarity_matrix, Lvec, Rvec, test_pair, k=k)
        del similarity_matrix, Lvec, Rvec

    return hit_1_score, hit_k_score, mrr, mean

In [37]:
# Training driver: full run over args.epoch epochs with periodic validation,
# early stopping, negative-sample refresh and graph re-augmentation.
# NOTE(review): `random` is not imported in any visible cell; presumably it
# arrives through one of the `from src... import *` lines -- confirm.
print("[start training...]\n")
t_start = time.time()
best_val = 0.0  # best validation MRR seen so far
bad_count = 0  # consecutive validation rounds without an MRR improvement
sr_embedding, tg_embedding = None, None
# Initial negative samples, used until the first refresh below.
neg1_left, neg1_right, neg2_left, neg2_right = kgdata.generate_neg_sample(neg_samples_size=args.neg_samples_size)
# Each graph gets its own `pr`, drawn uniformly from [0, args.pr].
pr1, pr2 = random.uniform(0, args.pr), random.uniform(0, args.pr)
aug_adj1, aug_rel_adj1 = kgdata.generate_aug_graph(kgdata.triples1, kgdata.kg1_ent_num, kgdata.rel_num, kgdata.kg1_ent_ids, kgdata.ent2node1, kgdata.d_v1, pr=pr1)
aug_adj2, aug_rel_adj2 = kgdata.generate_aug_graph(kgdata.triples2, kgdata.kg2_ent_num, kgdata.rel_num, kgdata.kg2_ent_ids, kgdata.ent2node2, kgdata.d_v2, pr=pr2)

for e in range(args.epoch):
    model.train()
    '''model training'''
    for _, data in enumerate(train_loader):
        optimizer.zero_grad()
        a1_align, a2_align = data
        # Two forward passes: one with phase="norm" and no graph arguments,
        # one with phase="augment" on the augmented adjacency tensors.
        sr_embedding, tg_embedding = model(phase="norm")
        aug_sr_embedding, aug_tg_embedding = model(aug_adj1, aug_rel_adj1, aug_adj2, aug_rel_adj2, phase="augment")
        '''alignment loss'''
        # The alignment loss is computed on the augmented embeddings.
        loss = loss_fn(aug_sr_embedding, aug_tg_embedding, a1_align, a2_align, neg1_left, neg1_right, neg2_left, neg2_right, neg_samples_size=args.neg_samples_size, loss_norm=args.loss_norm, pos_margin=args.pos_margin, neg_margin=args.neg_margin, neg_param=args.neg_param)
        '''contrastive loss'''
        # Contrast the "norm" embeddings with the augmented ones, per graph.
        aug_loss1 = model.contrastive_loss(sr_embedding, aug_sr_embedding, kgdata.kg1_ent_num)
        aug_loss2 = model.contrastive_loss(tg_embedding, aug_tg_embedding, kgdata.kg2_ent_num)
        loss = loss + args.aug_balance * (aug_loss1 + aug_loss2)
        '''multi-loss learning'''
        loss.backward()
        optimizer.step()
    # `loss` is the value from the last batch of the epoch. `t_start` is reset
    # at the end of every epoch, so the printed time covers this epoch only.
    if (e+1) % 10 == 0:
        print(f"epoch: {e+1}, loss: {round(loss.item(), 3)}, time: {round((time.time()-t_start), 2)}\n")
    
    '''validation phase'''
    if args.val and (e+1) % args.val_iter == 0 and e > 0 and (e+1) >= args.val_start:
        hit_1_score, hit_k_score, mrr, mean = evaluate(test_pair=kgdata.mapped_val_pair, k=args.k, eval_metric=args.eval_metric, phase="val")
        if best_val < mrr:
            bad_count = 0
            best_val = mrr
            # NOTE(review): `model_path` is not defined in any visible cell;
            # with args.save enabled this raises NameError unless a star
            # import provides it -- confirm.
            if args.save:
                with open(model_path, "wb") as f:
                    torch.save(model.state_dict(), f)
            print(f"[validation: epoch: {e+1}, hit@1: {round(hit_1_score, 3)}, hit@{args.k}: {round(hit_k_score, 3)}, mrr is {round(mrr, 3)}]\n")
        else:
            bad_count += 1
        # Early stopping after args.patience validation rounds without improvement.
        if bad_count == args.patience:
            print("[training end!]")
            break;
    
    '''update negative sampling'''
    # Refresh negatives from the latest "norm" embeddings; skipped on the final epoch.
    if (e+1) % args.neg_iter == 0 and e > 0 and e+1 != args.epoch:
        neg1_left, neg1_right, neg2_left, neg2_right = kgdata.update_neg_sample(sr_embedding.detach().cpu().numpy(), tg_embedding.detach().cpu().numpy(), neg_samples_size=args.neg_samples_size, eval_metric=args.neg_metric, csls_k=args.csls_k, e=args.truncated_epsilon)
    
    '''update augmented knowledge graph'''
    # Redraw `pr` and rebuild both augmented graphs; skipped when args.pr is 0
    # and on the final epoch.
    if (e+1) % args.aug_iter == 0 and e > 0 and args.pr != 0 and e+1 != args.epoch:
        pr1, pr2 = random.uniform(0, args.pr), random.uniform(0, args.pr)
        aug_adj1, aug_rel_adj1 = kgdata.generate_aug_graph(kgdata.triples1, kgdata.kg1_ent_num, kgdata.rel_num, kgdata.kg1_ent_ids, kgdata.ent2node1, kgdata.d_v1, pr=pr1)
        aug_adj2, aug_rel_adj2 = kgdata.generate_aug_graph(kgdata.triples2, kgdata.kg2_ent_num, kgdata.rel_num, kgdata.kg2_ent_ids, kgdata.ent2node2, kgdata.d_v2, pr=pr2)
    
    # Restart the timer and drop this epoch's embedding references.
    t_start = time.time()
    del sr_embedding, tg_embedding

[start training...]

[generate augmented knowledge graph with pr=0.064...]

[generate augmented knowledge graph with pr=0.003...]

epoch: 10, loss: 7026.826, time: 0.26

[updating negative samples...]

[generate augmented knowledge graph with pr=0.041...]

[generate augmented knowledge graph with pr=0.025...]

epoch: 20, loss: 4118.933, time: 0.25

[updating negative samples...]

[generate augmented knowledge graph with pr=0.02...]

[generate augmented knowledge graph with pr=0.092...]

epoch: 30, loss: 3404.04, time: 0.25

[updating negative samples...]

[generate augmented knowledge graph with pr=0.007...]

[generate augmented knowledge graph with pr=0.065...]

epoch: 40, loss: 2914.255, time: 0.25

[updating negative samples...]

[generate augmented knowledge graph with pr=0.019...]

[generate augmented knowledge graph with pr=0.023...]

epoch: 50, loss: 2540.283, time: 0.27

[validation: epoch: 50, hit@1: 0.69, hit@5: 0.877, mrr is 0.773]

[updating negative samples...]

[generate 

In [38]:
# Final evaluation on the held-out test pairs.
hit_1_score, hit_k_score, mrr, mean = evaluate(
    kgdata.mapped_test_pair,
    args.k,
    args.eval_metric,
    phase="test",
)
# A single print call; sep="\n" reproduces the line break that each separate
# print() call would have appended.
print(
    "----------------final score----------------\n",
    f"+ task: {current_task}\n",
    f"+ Hit@1: {round(hit_1_score, 3)}\n",
    f"+ Hit@{args.k}: {round(hit_k_score, 3)}\n",
    f"+ MRR: {round(mrr, 3)}\n",
    f"+ mean rank: {round(mean, 3)}\n",
    "-------------------------------------------\n",
    sep="\n",
)

----------------final score----------------

+ task: en_fr_15k

+ Hit@1: 0.488

+ Hit@5: 0.739

+ MRR: 0.601

+ mean rank: 65.704

-------------------------------------------

