In [1]:
import torch
from torch.utils.data import Dataset
from settings import *
import pickle
import pandas as pd
import torch.utils.data as Data
from model import MyEmbedder
from datetime import datetime
from load import *

args = {
    'device':'cuda:0',
    'time':datetime.now().strftime("%Y%m%d%H%M%S"),
    'language':'ja_en',
    'model_language':'ja_en',
    'epoch':300,
    'batch_size':64,
    'queue_length':64,
    'center_norm':False,
    'neighbor_norm':True,
    'emb_norm':True,
    'combine':True,
    'gat_num':1,
    't': 0.08,
    'momentum':0.9999,
    'lr':1e-6,
    'dropout':0.3
}
device = torch.device('cuda')
path = "/home/mrcactus/Thesis/ACEA/data/DBP15K/ja_en"

In [2]:
ill_idx = load_triples(path + "/ref_ent_ids", file_num=1) # ground truth
rate, val = 0.3, 0.0
ill_train_idx, ill_val_idx, ill_test_idx = np.array(ill_idx[:int(len(ill_idx) // 1 * rate)], dtype=np.int32), np.array(ill_idx[int(len(ill_idx) // 1 * rate) : int(len(ill_idx) // 1 * (rate+val))], dtype=np.int32), np.array(ill_idx[int(len(ill_idx) // 1 * (rate+val)):], dtype=np.int32)
ill_train_idx = list(zip(*ill_train_idx))

In [3]:
# ill_test_idx[0][0]
link = {}
for [k, v] in ill_test_idx:
    link[k] = v 


In [4]:
seedset = SeedDataset(ill_train_idx)

In [5]:
seedloader = Data.DataLoader(
            dataset=seedset,  # torch TensorDataset format
            batch_size=64,  # all test data
            shuffle=True,
            drop_last=True,
        )

In [6]:
all_data_batches = []
i = 0
for batch_id, (token_data, id_data) in enumerate(seedloader):
    all_data_batches.append([torch.Tensor(list(zip(*token_data)))[0], \
                             torch.Tensor(list(zip(*id_data)))[0]])

In [8]:
loader1 = DBP15KRawNeighbors(path, 'ja_en', "1")
loader2 = DBP15KRawNeighbors(path, 'ja_en', "2")

In [9]:
myset1 = MyRawdataset(loader1.id_neighbors_dict, loader1.id_adj_tensor_dict)
myset2 = MyRawdataset(loader2.id_neighbors_dict, loader2.id_adj_tensor_dict)

In [10]:
random.shuffle(all_data_batches)

In [11]:
eval_loader1 = Data.DataLoader(
            dataset=myset1,  # torch TensorDataset format
            batch_size=64,  # all test data
            shuffle=True,
            drop_last=False,
        )
eval_loader2 = Data.DataLoader(
            dataset=myset2,  # torch TensorDataset format
            batch_size=64,  # all test data
            shuffle=True,
            drop_last=False,
        )

In [12]:
model = MyEmbedder(args, VOCAB_SIZE).to(device)

In [13]:
def fix_seed(seed=37):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
fix_seed(37)

In [14]:
import torch.optim as optim
optimizer = optim.Adam(params=model.parameters(), lr=args['lr'])

In [19]:
import faiss
def cal_sim(v1, v2, link, ids_1, inverse_ids_2):
    source = [_id for _id in ids_1 if _id in link]
    target = np.array(
        [inverse_ids_2[link[_id]] if link[_id] in inverse_ids_2 else 99999 for _id in source])
    src_idx = [idx for idx in range(len(ids_1)) if ids_1[idx] in link]
    v1 = np.concatenate(tuple(v1), axis=0)[src_idx, :]
    v2 = np.concatenate(tuple(v2), axis=0)
    index = faiss.IndexFlatIP(v2.shape[1])
    index.add(np.ascontiguousarray(v2))
    D, I = index.search(np.ascontiguousarray(v1), 10)
    return source, target, D, I # D是相似性矩阵， I是ID矩阵
def evaluate(model, eval_loader1, eval_loader2, link, step):
    print("Evaluate at epoch {}...".format(step))
    ids_1, ids_2, vector_1, vector_2 = list(), list(), list(), list()
    inverse_ids_2 = dict()
    with torch.no_grad():
        model.eval()
        for sample_id_1, (token_data_1, id_data_1) in tqdm(enumerate(eval_loader1)):
            entity_vector_1 = model(token_data_1).squeeze().detach().cpu().numpy()
            ids_1.extend(id_data_1.squeeze().tolist())
            vector_1.append(entity_vector_1)

        for sample_id_2, (token_data_2, id_data_2) in tqdm(enumerate(eval_loader2)):
            entity_vector_2 = model(token_data_2).squeeze().detach().cpu().numpy()
            ids_2.extend(id_data_2.squeeze().tolist())
            vector_2.append(entity_vector_2)

    for idx, _id in enumerate(ids_2):
        inverse_ids_2[_id] = idx
    def cal_hit(v1, v2, link):
        source, target, D, I = cal_sim(v1, v2, link, ids_1, inverse_ids_2)
        print(D)
        hit1 = (I[:, 0] == target).astype(np.int32).sum() / len(source)
        hit10 = (I == target[:, np.newaxis]).astype(np.int32).sum() / len(source)
        print("#Entity: {}".format(len(source)))
        print("Hit@1: {}".format(round(hit1, 3)))
        print("Hit@10:{}".format(round(hit10, 3)))
        return round(hit1, 3), round(hit10, 3)
    print('===========Test===========')
    print("len v1:" + str(len(vector_1)))
    hit1_test, hit10_test = cal_hit(vector_1, vector_2, link)
    return hit1_test, hit10_test


In [20]:
from tqdm import tqdm
start_time = datetime.now()
evaluate(model, eval_loader1, eval_loader2, link, 0)
best_hit1_valid_epoch = 0
best_hit10_valid_epoch = 0
best_hit1_test_epoch = 0
best_hit10_test_epoch = 0
best_hit1_valid = 0
best_hit10_valid = 0
best_hit1_valid_hit10 = 0
best_hit10_valid_hit1 = 0
best_hit1_test = 0
best_hit10_test = 0
best_hit1_test_hit10 = 0
best_hit10_test_hit1 = 0
record_hit1 = 0
record_hit10 = 0
record_epoch = 0
record_batch_id = 0
for epoch in range(1):
    for batch_id, (x_ids, y_ids) in tqdm(enumerate(all_data_batches)):
        kg1_train_ent_idx = list(map(lambda x: int(x), list(x_ids)))
        kg1_train_ent_emb = None 
        kg2_train_ent_idx = list(map(lambda x: int(x), list(y_ids)))
        kg2_train_ent_emb = None 
        with torch.no_grad():
            for idx in kg1_train_ent_idx:
                if kg1_train_ent_emb==None:
                    kg1_train_ent_emb = myset1.id_emb[idx].unsqueeze(0)
                else:
                    kg1_train_ent_emb = torch.cat((kg1_train_ent_emb,\
                                                myset1.id_emb[idx].unsqueeze(0)),0)
            for idx in kg2_train_ent_idx:
                if kg2_train_ent_emb==None:
                    kg2_train_ent_emb = myset2.id_emb[idx].unsqueeze(0)
                else:
                    kg2_train_ent_emb = torch.cat((kg2_train_ent_emb,\
                                                myset2.id_emb[idx].unsqueeze(0)),0)
            # kg1_train_ent_emb.append(myset1.id_emb[idx])
            idx = [i for i in range(kg2_train_ent_emb.size(0)-1,-1,-1)]
            idx = torch.LongTensor(idx)
            neg_queue = kg2_train_ent_emb.index_select(0, idx)
        
        optimizer.zero_grad()
        pos_1 = model(kg1_train_ent_emb)
        pos_2 = model(kg2_train_ent_emb)
        neg = model(neg_queue)
        contrastive_loss = model.contrastive_loss(pos_1, pos_2, neg)

        contrastive_loss.backward(retain_graph=True)
        optimizer.step()

        if batch_id == len(all_data_batches) - 1:
        # if batch_id % 200 == 0 or batch_id == len(all_data_batches) - 1:
            print('epoch: {} batch: {} loss: {}'.format(epoch, batch_id,
                                                        contrastive_loss.detach().cpu().data / 64))
            hit1_test, hit10_test = evaluate(model, eval_loader1, eval_loader2, link, str(epoch)+": batch "+str(batch_id))

            if hit1_test > best_hit1_test:
                best_hit1_test = hit1_test
                best_hit1_test_hit10 = hit10_test
                best_hit1_test_epoch = epoch
            if hit10_test  > best_hit10_test:
                best_hit10_test = hit10_test
                best_hit10_test_hit1 = hit1_test
                best_hit10_test_epoch = epoch
            
            print('Test Hit@1(10)    = {}({}) at epoch {} batch {}'.format(hit1_test, hit10_test, epoch, batch_id))
            print('Best Valid Hit@1  = {}({}) at epoch {}'.format(best_hit1_valid, best_hit1_valid_hit10, best_hit1_valid_epoch))
            print('Best Valid Hit@10 = {}({}) at epoch {}'.format(best_hit10_valid,best_hit10_valid_hit1, best_hit10_valid_epoch))
            print('Test @ Best Valid = {}({}) at epoch {} batch {}'.format(record_hit1, record_hit10, record_epoch, record_batch_id))

            print('Best Test  Hit@1  = {}({}) at epoch {}'.format(best_hit1_test, best_hit1_test_hit10, best_hit1_test_epoch))
            print('Best Test  Hit@10 = {}({}) at epoch {}'.format(best_hit10_test,best_hit10_test_hit1, best_hit10_test_epoch))
            print("====================================")
end_time = datetime.now()
print("start: "+start_time.strftime("%Y-%m-%d %H:%M:%S"))
print("end: "+end_time.strftime("%Y-%m-%d %H:%M:%S"))
print("used_time: "+ str(end_time - start_time))

0it [00:00, ?it/s]

3it [00:00, 20.68it/s]

Evaluate at epoch 0...


310it [00:01, 186.01it/s]
310it [00:01, 205.83it/s]


len v1:310


2it [00:00, 16.76it/s]

[[0.84657776 0.7415437  0.7158799  ... 0.68777066 0.68646765 0.68589044]
 [0.88217133 0.7332629  0.72316325 ... 0.6892617  0.68702865 0.68601644]
 [0.91740835 0.89709264 0.8920222  ... 0.8713179  0.8687122  0.86766076]
 ...
 [0.8061072  0.76103747 0.734789   ... 0.6914282  0.69129246 0.6908588 ]
 [0.8454975  0.79202276 0.78825045 ... 0.7599765  0.750582   0.74111545]
 [0.88704264 0.8526741  0.8517467  ... 0.8420168  0.8414339  0.8401853 ]]
#Entity: 10500
Hit@1: 0.79
Hit@10:0.895


68it [00:04, 21.42it/s]

epoch: 0 batch: 69 loss: 0.01619906537234783
Evaluate at epoch 0: batch 69...


310it [00:01, 216.77it/s]
310it [00:01, 179.93it/s]


len v1:310


70it [00:08,  8.15it/s]

[[0.8784205  0.8738854  0.8728077  ... 0.86947155 0.8633952  0.8631546 ]
 [0.62848747 0.61246824 0.6124333  ... 0.5788074  0.5685188  0.5635835 ]
 [0.8903419  0.87099636 0.8699853  ... 0.8631131  0.86285746 0.8619568 ]
 ...
 [0.5716382  0.5633813  0.5631603  ... 0.5536786  0.55255115 0.551604  ]
 [0.78205824 0.77389795 0.7531215  ... 0.73259354 0.7198845  0.71742606]
 [0.7528471  0.61708367 0.6132751  ... 0.6035112  0.60108244 0.5987027 ]]
#Entity: 10500
Hit@1: 0.793
Hit@10:0.898
Test Hit@1(10)    = 0.793(0.898) at epoch 0 batch 69
Best Valid Hit@1  = 0(0) at epoch 0
Best Valid Hit@10 = 0(0) at epoch 0
Test @ Best Valid = 0(0) at epoch 0 batch 0
Best Test  Hit@1  = 0.793(0.898) at epoch 0
Best Test  Hit@10 = 0.898(0.793) at epoch 0
start: 2024-01-25 15:52:18
end: 2024-01-25 15:52:31
used_time: 0:00:13.195266





: 