In [1]:
import time
import torch
import json
import numpy as np
from torch.utils.data import DataLoader

In [2]:
from src.dataset import *
from src.utils import *
from src.model import KGEncoder

In [3]:
current_task = "en_de_15k"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with open(f"args/{current_task}.json", 'r') as f:
    args_dict = json.load(f)
    f.close()
class ARGs:
    def __init__(self, dic):
        for k, v in dic.items():
            setattr(self, k, v)
args = ARGs(args_dict)

In [4]:
# load knowledge graph data
kgdata = KGData(args.task, args.fold, setting=args.setting)
train_pair, valid_pair, test_pair = kgdata.load_pair_data()
adj_matrix, r_index, r_val, adj_features, rel_features, rel_features_top, rdict, rel_in, rel_out = kgdata.load_matrix_data()
unlabeled_pair = np.concatenate((valid_pair, test_pair), axis=0)
unlabeled_pair = copy.deepcopy(unlabeled_pair)
unlabeled_s = [e1 for e1, e2 in unlabeled_pair]
unlabeled_t = [e2 for e1, e2 in unlabeled_pair]
np.random.shuffle(unlabeled_s)
np.random.shuffle(unlabeled_t)

ent_num = kgdata.ent_num
rel_num = kgdata.rel_num
triple_num = kgdata.triple_num
adj_matrix = sparse_mx_to_torch_sparse_tensor(normalize_adj(adj_matrix)).to(device)
rel_in = sparse_mx_to_torch_sparse_tensor(rel_in).to(device)
rel_out = sparse_mx_to_torch_sparse_tensor(rel_out).to(device)

KG 1 info: #ent. 15000, #rel. 215, #tri. 47676
KG 2 info: #ent. 15000, #rel. 131, #tri. 50419



In [5]:
# load model/optimizer/loss function
set_random_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\ncurrent device is \033[92m{device}\033[0m \n")

model = KGEncoder(args, ent_num=ent_num, adj_matrix=adj_matrix, rel_features=(rel_in, rel_out), device=device, name="student")
_model = KGEncoder(args, ent_num=ent_num, adj_matrix=adj_matrix, rel_features=(rel_in, rel_out), device=device, name="teacher")
model = model.to(device=device)
_model = _model.to(device=device)

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
train_dataset = Dataset(np.array(train_pair))
train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=False)
unlabeled_dataset = Dataset(np.array(unlabeled_pair))
unlabeled_loader = DataLoader(dataset=unlabeled_dataset, batch_size=len(unlabeled_pair), shuffle=False)

current seed is set to [92m42[0m 


current device is [92mcuda[0m 



In [6]:
print("--------------------INFO--------------------\n")
print(f'- current task: \033[93m{args.task}\033[0m\n')
print(f'- #entity: \033[93m{ent_num}\033[0m\n')
print(f'- #relation: \033[93m{rel_num}\033[0m\n')
print(f'- #triple: \033[93m{triple_num}\033[0m\n')
print(f'- #labeled number: \033[93m{len(train_pair)+len(valid_pair)+len(test_pair)}\033[0m\n')
print(f'- #batch size: \033[93m{args.batch_size}\033[0m\n')
print(f'- #total params: \033[93m{pytorch_total_params}\033[0m\n')
print("--------------------------------------------\n")

--------------------INFO--------------------

- current task: [93men_de_15k[0m

- #entity: [93m30000[0m

- #relation: [93m239[0m

- #triple: [93m172388[0m

- #labeled number: [93m15000[0m

- #batch size: [93m5000[0m

- #total params: [93m7712128[0m

--------------------------------------------



In [7]:
def eval_entity_alignment_faiss(Lvec, Rvec, test_num, k, eval_metric, eval_normalize=True):
    '''
        calculate evaluation metrics: hit@1, hit@k, MRR, Mean etc.
        using faiss accelerate alignment inference: https://github.com/facebookresearch/faiss
    '''
    if eval_normalize:
        Lvec = preprocessing.normalize(Lvec)
        Rvec = preprocessing.normalize(Rvec)
    assert test_num == Lvec.shape[0]
    mrr = 0
    mean = 0
    hit_1_score = 0
    hit_k_score = 0
    if eval_metric == "l2":
        index = faiss.IndexFlatL2(Rvec.shape[1]) # create index base with fixed dimension
    elif eval_metric == "inner":
        index = faiss.IndexFlatIP(Rvec.shape[1])
    else:
        assert ValueError
    index.add(np.ascontiguousarray(Rvec)) # add key to index base
    del Rvec;
    _, I = index.search(np.ascontiguousarray(Lvec), test_num) # search query in index base
    for idx in range(Lvec.shape[0]):
        rank_index = np.where(I[idx,:]==idx)[0][0]
        rank_index += 1
        mean += (rank_index)
        mrr += 1.0 / (rank_index)
        if rank_index <= 1: # hit@1
            hit_1_score += 1
        if rank_index <= k: # hit@k
            hit_k_score += 1
    mrr = mrr / test_num
    hit_1_score = hit_1_score / test_num
    hit_k_score = hit_k_score / test_num
    mean = mean / test_num
    return hit_1_score, hit_k_score, mrr, mean

In [8]:
# begin training
t_total_start = time.time()
best_score = 0.0
hit1_st = 1.0
hit1_ts = 0.0
model_path = f"save/{args.task}"

for e in range(1, args.epoch):
    model.train()
    _model.train()
    global bad_count
    align_total_loss = 0.0
    pseudo_total_loss = 0.0
    t_start = time.time()
    adjust_learning_rate(optimizer, e, args.lr)

    if e % args.neg_iter == 1:
        neg_sample_list = list()

    # supervised alignment learning with labeled data
    for idx, data in enumerate(train_loader):
        model.train()
        kg1_align, kg2_align = data
        kg1_align, kg2_align = np.array(kg1_align), np.array(kg2_align)
        vec = model()
        # negtive sampling
        if e % args.neg_iter == 1:
            neg_left, neg_right = kgdata.negative_sampling(
                bsize=len(kg1_align), 
                kg1_align=kg1_align, 
                kg2_align=kg2_align, 
                neg_samples_size=args.neg_samples_size, 
                target_left=unlabeled_s, 
                target_right=unlabeled_t, 
                vec=vec.detach().cpu().numpy(), 
                e=args.truncated_epsilon
            )
            neg_sample_list.append([neg_left, neg_right])
        else:
            neg_left, neg_right = neg_sample_list[idx]
        
        align_loss = model.alignment_loss(
            vec, 
            kg1_align, 
            kg2_align, 
            neg_left, 
            neg_right, 
            neg_samples_size=args.neg_samples_size, 
            neg_margin=args.neg_margin, 
            dist=args.dist
        )
        align_total_loss += align_loss
    
    # pseudo mapping learning with unlabeled data
    for _, data in enumerate(unlabeled_loader):
        torch.cuda.empty_cache()
        kg1_ids, kg2_ids = data
        kg1_ids, kg2_ids = np.array(kg1_ids), np.array(kg2_ids)
        vec = model()
        with torch.no_grad():
            _vec = _model()
        pseudo_loss = model.pseudo_ce_loss(ent_embedding1=vec[kg1_ids], ent_embedding2=vec[kg2_ids], ent_embedding3=_vec[kg1_ids], ent_embedding4=_vec[kg2_ids], hit1_st=hit1_st, hit1_ts=hit1_ts)
        r = args.consistency * sigmoid_rampup(e, args.consistency_rampup)
        pseudo_loss = pseudo_loss * r
        pseudo_total_loss += pseudo_loss
    
    # the final objective
    if not args.il:
        loss = align_total_loss + pseudo_total_loss
    else:
        loss = align_total_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # update teacher model
    _model.update(model, epoch=e)

    if e >= args.val_start and e % args.val_iter == 0 and args.val:
        with torch.no_grad():
            model.eval()
            vec = model()
            vec = vec.detach().cpu().numpy()
            if args.record:
                Lvec = np.array([vec[e] for e in test_pair[:,0]])
                Rvec = np.array([vec[e] for e in test_pair[:,1]])
                del vec;
                hit1, hitk, mrr, mr = eval_entity_alignment_faiss(Lvec, Rvec, test_num=len(test_pair), k=args.k, eval_metric=args.eval_metric)
                print(f"[test: epoch: {e}, hit@1: {round(hit1, 3)}, hit@{args.k}: {round(hitk, 3)}, mrr is {round(mrr, 3)}]\n")
                with open(f'result/{args.task}_test.csv', 'a', encoding='utf-8') as file:
                    file.write('\n')
                    file.write(f"{e}, {round(hit1, 3)}, {round(hitk, 3)}, {round(mrr, 3)}, {round(mr, 3)}")
            else:
                Lvec = np.array([vec[e] for e in valid_pair[:,0]])
                Rvec = np.array([vec[e] for e in valid_pair[:,1]])
                del vec;
                hit1_st, hitk, mrr, mr = eval_entity_alignment_faiss(Lvec, Rvec, test_num=len(valid_pair), k=args.k, eval_metric=args.eval_metric)
                print(f"[validation: epoch: {e}, hit@1: {round(hit1_st, 3)}, hit@{args.k}: {round(hitk, 3)}, mrr is {round(mrr, 3)}]\n")
                hit1_ts, _, _, _ = eval_entity_alignment_faiss(Rvec, Lvec, test_num=len(valid_pair), k=args.k, eval_metric=args.eval_metric)
                with open(model_path, "wb") as f:
                    torch.save(model.state_dict(), f)
                if hit1_st > best_score:
                    bad_count = 0
                    best_score = hit1_st
                else:
                    bad_count = bad_count + 1
                if bad_count == args.patience:
                    break;

[generate negative samples...]

[validation: epoch: 10, hit@1: 0.85, hit@5: 0.953, mrr is 0.894]

[generate negative samples...]

[validation: epoch: 20, hit@1: 0.86, hit@5: 0.958, mrr is 0.903]

[generate negative samples...]

[validation: epoch: 30, hit@1: 0.864, hit@5: 0.959, mrr is 0.905]

[generate negative samples...]

[validation: epoch: 40, hit@1: 0.863, hit@5: 0.961, mrr is 0.905]

[generate negative samples...]

[validation: epoch: 50, hit@1: 0.865, hit@5: 0.96, mrr is 0.906]

[generate negative samples...]

[validation: epoch: 60, hit@1: 0.871, hit@5: 0.959, mrr is 0.91]

[generate negative samples...]

[validation: epoch: 70, hit@1: 0.864, hit@5: 0.96, mrr is 0.906]

[generate negative samples...]

[validation: epoch: 80, hit@1: 0.867, hit@5: 0.959, mrr is 0.907]

[generate negative samples...]

[validation: epoch: 90, hit@1: 0.867, hit@5: 0.959, mrr is 0.907]

[generate negative samples...]

[validation: epoch: 100, hit@1: 0.865, hit@5: 0.959, mrr is 0.906]

[generate nega

In [9]:
# test phase
total_time = int(time.time()-t_total_start)
with torch.no_grad():
    with open(model_path, "rb") as f:
        model.load_state_dict(torch.load(f))
model.eval()
vec = model()
vec = vec.detach().cpu().numpy()
Lvec = np.array([vec[e] for e in test_pair[:,0]])
Rvec = np.array([vec[e] for e in test_pair[:,1]])
del vec;
hit1, hitk, mrr, mr = eval_entity_alignment_faiss(Lvec, Rvec, test_num=len(test_pair), k=args.k, eval_metric=args.eval_metric)
hit1, hitk, mrr, mr = round(hit1, 3), round(hitk, 3), round(mrr, 3), round(mr, 3)
print(f'+ task: {args.task}')
print(f'+ Hit@1: \033[94m{hit1}\033[0m')
print(f'+ Hit@k: \033[94m{hitk}\033[0m')
print(f'+ MRR: \033[94m{mrr}\033[0m')
print(f'+ MR: \033[94m{mr}\033[0m')

+ task: en_de_15k
+ Hit@1: [94m0.725[0m
+ Hit@k: [94m0.878[0m
+ MRR: [94m0.792[0m
+ MR: [94m31.602[0m
