In [None]:
import argparse

import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import softmax, degree
from torch_sparse import spmm
from RAGAutils import composeS, add_inverse_rels, get_hits, get_hits_stable, get_hits_from_S

from data import DBP15K
from model import *

from toolbox.RandomSeeds import set_seeds
from toolbox.DataSchema import cache_data, read_cache
from pathlib import Path
from utils import load_alignment_pair



# Call the project's seeding helper (toolbox.RandomSeeds.set_seeds) with its default arguments.
# Presumably this fixes the RNG seeds for reproducibility; which generators it seeds is not visible here.
set_seeds()

In [None]:
class CFG:
    """Run options for this notebook, held as plain instance attributes.

    ``CFG()`` yields the defaults below.  Any option may also be overridden at
    construction time by keyword, e.g. ``CFG(epoch=1000, r_hidden=300)``;
    attributes can still be reassigned afterwards.

    Raises:
        TypeError: if a keyword does not name an existing option, so that a
            misspelt option cannot silently create an attribute nothing reads.
    """

    def __init__(self, **overrides):
        # Not read anywhere in the visible part of the notebook.
        self.cuda = True

        # Dataset selection, forwarded to DBP15K(data, lang, rate=rate).
        self.data = 'data'    # dataset root directory
        self.lang = 'zh_en'   # which dataset to use
        self.rate = 0.3

        # Model / loss / batch construction.
        self.r_hidden = 100   # second constructor argument of the model
        self.k = 5            # passed to get_train_batch
        self.gamma = 3        # passed to the loss constructor

        # Schedule, all counted in epochs.
        self.epoch = 80         # total number of training epochs
        self.neg_epoch = 10     # embeddings and training batch are rebuilt every neg_epoch epochs
        self.test_epoch = 10    # evaluation runs every test_epoch epochs
        self.reset_epoch = 10   # reset() is called every reset_epoch epochs
        self.stable_test = True  # forwarded as the `stable` flag of test()

        # Forwarded as keyword arguments to reset().
        self.keep_seeds = 2000
        self.new_train_seeds = 7000
        self.neg_seeds = 1000

        # Apply keyword overrides, accepting only names that exist above.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError("CFG got unexpected option(s): " + ", ".join(unknown))
        vars(self).update(overrides)

In [None]:
# Notebook stand-in for the command-line run
#   CUDA_VISIBLE_DEVICES=0 python RAGAtrainV10.py --reset_epoch 20 --epoch 160 --r_hidden 300 --test_epoch 40
# Note that the epoch and test_epoch values set below differ from that command line.
args = CFG()
vars(args).update(
    lang='zh_en',
    neg_epoch=10,
    reset_epoch=20,
    epoch=1000,
    r_hidden=300,
    test_epoch=100,
)

# Model parallelism: the two GPU identifiers used by this notebook.
device, second_device = 'cuda:0', 'cuda:1'

In [None]:
# Pre-compose stage: load the reference alignment pairs and the cached similarity data.
root = Path(f"data/{args.lang}/cache")
pairs = load_alignment_pair(f"data/{args.lang}/ref_ent_ids")

# Everything after the first `ratio` share of the pairs is kept as the test pairs.
# NOTE(review): 0.3 equals the default args.rate but is set independently here —
# confirm the two are meant to stay in sync.
ratio = 0.3
test_pair = pairs[int(ratio * len(pairs)):]

test_seeds = torch.LongTensor(test_pair).to(second_device)

# S1: attribute similarity (set similarity)
# S2: attribute-value similarity (set similarity)
S1, S2 = (
    read_cache(root / cache_name).to(second_device)
    for cache_name in ("attr_similarity", "value_similarity")
)

In [None]:
def get_emb(model, data):
    """Compute the model output for both graphs without tracking gradients.

    The model is switched to evaluation mode (and left in it) and is called
    once per graph with that graph's ``x``, ``edge_index``, ``rel``,
    ``edge_index_all`` and ``rel_all`` attributes of ``data``.

    Args:
        model: callable module exposing ``eval()``.
        data: object carrying the ``*1`` and ``*2`` graph attributes.

    Returns:
        Tuple ``(x1, x2)``: the model output for graph 1 and for graph 2.
    """
    def _embed(kg):
        # One forward pass over the attributes of graph `kg` (1 or 2).
        return model(
            getattr(data, f"x{kg}"),
            getattr(data, f"edge_index{kg}"),
            getattr(data, f"rel{kg}"),
            getattr(data, f"edge_index_all{kg}"),
            getattr(data, f"rel_all{kg}"),
        )

    model.eval()
    with torch.no_grad():
        return _embed(1), _embed(2)


def train(model, criterion, optimizer, data, train_batch, false_pair=None):
    """Perform a single optimisation step and return the loss it was based on.

    Args:
        model: callable module exposing ``train()``; it is called once per
            graph with that graph's attributes of ``data``.
        criterion: callable invoked as
            ``criterion(x1, x2, data.new_train_set, train_batch, false_pair)``.
        optimizer: object exposing ``zero_grad()`` and ``step()``.
        data: object carrying the graph attributes and ``new_train_set``.
        train_batch: forwarded unchanged to ``criterion``.
        false_pair: forwarded unchanged to ``criterion`` (default ``None``).

    Returns:
        The object returned by ``criterion``, after ``backward()`` was called
        on it.
    """
    model.train()

    # Forward both graphs, graph 1 first, reading the attributes in call order.
    graph_inputs = ("x", "edge_index", "rel", "edge_index_all", "rel_all")
    emb1, emb2 = (
        model(*[getattr(data, name + side) for name in graph_inputs])
        for side in ("1", "2")
    )
    step_loss = criterion(emb1, emb2, data.new_train_set, train_batch, false_pair)

    # Gradients are cleared only now, right before back-propagation.
    optimizer.zero_grad()
    step_loss.backward()
    optimizer.step()
    return step_loss


def init_data(args, device):
    """Load the first item of the DBP15K dataset and prepare it for training.

    Args:
        args: options object; ``args.data`` is the dataset root directory,
            ``args.lang`` the dataset to use, and ``args.rate`` is forwarded
            as ``rate``.
        device: device the entity features ``x1`` / ``x2`` are moved to.

    Returns:
        The data object with ``x1`` / ``x2`` L2-normalised along dim 1, moved
        to ``device`` and flagged as requiring gradients, plus
        ``edge_index_all*`` / ``rel_all*`` as returned by ``add_inverse_rels``.
    """
    dataset = DBP15K(args.data, args.lang, rate=args.rate)
    data = dataset[0]

    # Entity features: unit L2 norm per row, on `device`, with gradient tracking enabled.
    for attr in ("x1", "x2"):
        features = F.normalize(getattr(data, attr), p=2, dim=1)
        setattr(data, attr, features.to(device).requires_grad_())

    # Edge / relation lists extended by add_inverse_rels, one pair per graph.
    edges_all1, rels_all1 = add_inverse_rels(data.edge_index1, data.rel1)
    data.edge_index_all1 = edges_all1
    data.rel_all1 = rels_all1
    edges_all2, rels_all2 = add_inverse_rels(data.edge_index2, data.rel2)
    data.edge_index_all2 = edges_all2
    data.rel_all2 = rels_all2
    return data

def test(x1, x2, data, stable=False):
    """Print hit statistics for the train, new-train and test seed sets.

    ``get_hits`` is called on ``data.train_set``, ``data.new_train_set`` and
    ``data.test_set`` in that order, each under its own printed header, with
    gradient tracking disabled.

    Args:
        x1: embeddings of the first graph.
        x2: embeddings of the second graph.
        data: object carrying ``train_set``, ``new_train_set`` and ``test_set``.
        stable: when true, ``get_hits_stable`` is additionally run on the
            test set.

    Returns:
        Tuple ``(S, hits1)`` from the test-set ``get_hits`` call, with ``S``
        detached and moved to the CPU.  The training loop stores ``S`` as the
        similarity matrix and uses ``hits1`` to choose checkpoints.
    """
    with torch.no_grad():
        print('-' * 16 + 'Train_set' + '-' * 16)
        get_hits(x1, x2, data.train_set)
        # This header used to repeat 'Train_set', which made the two result
        # blocks indistinguishable in the log.  Same 41-character width.
        print('-' * 14 + 'New_train_set' + '-' * 14)
        get_hits(x1, x2, data.new_train_set)
        print('-' * 16 + 'Test_set' + '-' * 17)
        S, hits1 = get_hits(x1, x2, data.test_set)
        if stable:
            get_hits_stable(x1, x2, data.test_set)
        print()
    return S.detach().cpu(), hits1

In [None]:
# Build the data, model, optimizer and loss, then run the training loop.
# NOTE(review): `device` is not used in this cell and the model is never moved to a GPU
# here; presumably RAGA5 places its own layers — confirm.
data = init_data(args, second_device).to(second_device)
# Ablation name, appended to the name of every saved artefact.
ablation = "_precompose_globalpair"

model = RAGA5(data.x1.size(1), args.r_hidden)
# The entity features data.x1 / data.x2 are optimised together with the model parameters.
optimizer = torch.optim.Adam(itertools.chain(model.parameters(), iter([data.x1, data.x2])))
# model, optimizer = apex.amp.initialize(model, optimizer)
criterion = L1_Loss(args.gamma)
data.new_train_set = data.train_set
false_pair = None

# The best checkpoint is selected by hits1 on the test set.
max_hits1 = 0
# Nothing is saved before this (1-based) epoch.
save_epoch = 200

# Output locations.  Both are created on demand so that a missing directory cannot
# abort the run at the first checkpoint, after a long stretch of training.
cache_dir = Path("data/%s/cache/" % args.lang)
model_dir = Path("saved_models")


for epoch in range(args.epoch):
    # Every neg_epoch epochs: recompute the embeddings and rebuild the training batch from them.
    if epoch % args.neg_epoch == 0:
        x1, x2 = get_emb(model, data)
        train_batch = get_train_batch(x1, x2, data.new_train_set, args.k)
    loss = train(model, criterion, optimizer, data, train_batch, false_pair)
    print('Epoch:', epoch + 1, '/', args.epoch, '\tLoss: %.3f' % loss, '\r', end='')

    if (epoch + 1) % args.test_epoch == 0:
        print()
        # Evaluate embeddings of the *current* weights.  x1 / x2 date from the last
        # refresh above (up to neg_epoch - 1 steps ago), so using them would make the
        # reported hits1 and the saved similarity matrix disagree with the weights
        # saved under the same epoch tag.  Separate names keep x1 / x2 untouched for
        # the reset() call below.
        eval_x1, eval_x2 = get_emb(model, data)
        S, hits1 = test(eval_x1, eval_x2, data, args.stable_test)
        del eval_x1, eval_x2

        if hits1 >= max_hits1 and epoch + 2 > save_epoch:
            max_hits1 = hits1
            # Suffix shared by the three artefacts of this checkpoint.
            run_tag = (
                "_e" + str(epoch + 1)
                + "_n" + str(args.neg_epoch)
                + "_r" + str(args.reset_epoch)
                + "_hit" + str(int(hits1 // 1e-4))
                + ablation
            )

            # Save the similarity matrix.
            cache_dir.mkdir(parents=True, exist_ok=True)
            cache_data(S, cache_dir / ("rel_graph_similarity" + run_tag))

            # Save the entity representations and the model parameters.
            model_dir.mkdir(parents=True, exist_ok=True)
            torch.save(
                [data.x1.clone().to('cpu'), data.x2.clone().to('cpu')],
                str(model_dir / (args.lang + "_emb" + run_tag + ".pt")),
            )
            torch.save(
                model.state_dict(),
                str(model_dir / (args.lang + "_model" + run_tag + ".pt")),
            )

        # Release the similarity matrix whether or not it was saved.
        del S

    if (epoch + 1) % args.reset_epoch == 0:
        # NOTE(review): reset() receives the embeddings from the last refresh above,
        # not from the current weights — confirm this is intended.
        false_pair = reset(
            x1,
            x2,
            data,
            keep_seeds=args.keep_seeds,
            new_train_seeds=args.new_train_seeds,
            neg_seeds=args.neg_seeds,
        )