In [45]:
import csv
import pandas as pd
import torch
import pickle
import numpy as np
from torch.utils.data import Dataset, DataLoader
from os.path import abspath, dirname, join, exists
from sentence_transformers import SentenceTransformer, util
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import *
import argparse
import logging
from datetime import datetime
import torch.optim as optim
import random
import os

data_dir = '/home/mrcactus/Thesis/ACEA/data'
TOKEN_LEN = 50
VOCAB_SIZE = 100000
LaBSE_DIM = 768
EMBED_DIM = 300
BATCH_SIZE = 96
FASTTEXT_DIM = 300
NEIGHBOR_SIZE = 20 
ATTENTION_DIM = 300
MULTI_HEAD_DIM = 1

In [39]:
def parse_options(parser):
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--time', type=str, default=datetime.now().strftime("%Y%m%d%H%M%S"))
    parser.add_argument('--language', type=str, default='zh_en')
    parser.add_argument('--model_language', type=str, default='zh_en')
    parser.add_argument('--model', type=str, default='LaBSE')

    parser.add_argument('--epoch', type=int, default=300)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--queue_length', type=int, default=64)

    parser.add_argument('--center_norm', type=bool, default=False)
    parser.add_argument('--neighbor_norm', type=bool, default=True)
    parser.add_argument('--emb_norm', type=bool, default=True)
    parser.add_argument('--combine', type=bool, default=True)

    parser.add_argument('--gat_num', type=int, default=1)

    parser.add_argument('--t', type=float, default=0.08)
    parser.add_argument('--momentum', type=float, default=0.9999)
    parser.add_argument('--lr', type=float, default=1e-6)
    parser.add_argument('--dropout', type=float, default=0.3)

    return parser.parse_args()
# parser = argparse.ArgumentParser()
args = {
    'device':'cuda:0',
    'time':datetime.now().strftime("%Y%m%d%H%M%S"),
    'language':'zh_en',
    'model_language':'zh_en',
    'epoch':300,
    'batch_size':64,
    'queue_length':64,
    'center_norm':False,
    'neighbor_norm':True,
    'emb_norm':True,
    'combine':True,
    'gat_num':1,
    't':0.08,
    'momentum':0.9999,
    'lr':1e-6,
    'dropout':0.3
}

In [41]:
class NCESoftmaxLoss(nn.Module):

    def __init__(self, device):
        super(NCESoftmaxLoss, self).__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.device = device

    def forward(self, x):
        batch_size = x.shape[0]
        x = x.squeeze()
        label = torch.zeros([batch_size]).to(self.device).long()
        loss = self.criterion(x, label)
        return loss
class MyEmbedder(nn.Module):
    def __init__(self, args, vocab_size, padding=ord(' ')):
        super(MyEmbedder, self).__init__()

        self.args = args

        self.device = torch.device(self.args['device'])

        self.attn = BatchMultiHeadGraphAttention(self.device, self.args)
        
        self.attn_mlp = nn.Sequential(
            nn.Linear(LaBSE_DIM * 2, LaBSE_DIM),
        )

        # loss
        self.criterion = NCESoftmaxLoss(self.device)

        # batch queue
        self.batch_queue = []

    def contrastive_loss(self, pos_1, pos_2, neg_value):
        bsz = pos_1.shape[0]
        l_pos = torch.bmm(pos_1.view(bsz, 1, -1), pos_2.view(bsz, -1, 1))
        l_pos = l_pos.view(bsz, 1)
        l_neg = torch.mm(pos_1.view(bsz, -1), neg_value.t())
        logits = torch.cat((l_pos, l_neg), dim=1)
        logits = logits.squeeze().contiguous()
        return self.criterion(logits / self.args['t'])

    def update(self, network: nn.Module):
        for key_param, query_param in zip(self.parameters(), network.parameters()):
            key_param.data *= self.args['momentum']
            key_param.data += (1 - self.args['momentum']) * query_param.data
        self.eval()

    def forward(self, batch):
        batch = batch.to(self.device)
        batch_in = batch[:, :, :LaBSE_DIM]
        adj = batch[:, :, LaBSE_DIM:]

        center = batch_in[:, 0].to(self.device)
        center_neigh = batch_in.to(self.device)

        for i in range(0, self.args['gat_num']):
            center_neigh = self.attn(center_neigh, adj.bool()).squeeze(1)
        
        center_neigh = center_neigh[:, 0]

        if self.args['center_norm']:
            center = F.normalize(center, p=2, dim=1)
        if self.args['neighbor_norm']:
            center_neigh = F.normalize(center_neigh, p=2, dim=1)
        if self.args['combine']:
            out_hat = torch.cat((center, center_neigh), dim=1)
            out_hat = self.attn_mlp(out_hat)
            if self.args['emb_norm']:
                out_hat = F.normalize(out_hat, p=2, dim=1)
        else:
            out_hat = center_neigh

        return out_hat


class BatchMultiHeadGraphAttention(nn.Module):
    def __init__(self, device, args, n_head=MULTI_HEAD_DIM, f_in=LaBSE_DIM, f_out=LaBSE_DIM, bias=True):
        super(BatchMultiHeadGraphAttention, self).__init__()
        self.device = device
        self.n_head = n_head
        self.w = Parameter(torch.Tensor(n_head, f_in, f_out))
        self.a_src = Parameter(torch.Tensor(n_head, f_out, 1))
        self.a_dst = Parameter(torch.Tensor(n_head, f_out, 1))

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(args['dropout'])
        if bias:
            self.bias = Parameter(torch.Tensor(f_out))
            nn.init.constant_(self.bias, 0)
        else:
            self.register_parameter('bias', None)

        nn.init.xavier_uniform_(self.w)
        nn.init.xavier_uniform_(self.a_src)
        nn.init.xavier_uniform_(self.a_dst)

    def forward(self, h, adj):
        bs, n = h.size()[:2]  # h is of size bs x n x f_in
        h_prime = torch.matmul(h.unsqueeze(1), self.w)  # bs x n_head x n x f_out
        attn_src = torch.matmul(torch.tanh(h_prime), self.a_src)  # bs x n_head x n x 1
        attn_dst = torch.matmul(torch.tanh(h_prime), self.a_dst)  # bs x n_head x n x 1
        attn = attn_src.expand(-1, -1, -1, n) + attn_dst.expand(-1, -1, -1, n).permute(0, 1, 3, 2)  # bs x n_head x n x n

        attn = self.leaky_relu(attn)
        mask = ~(adj.unsqueeze(1) | torch.eye(adj.shape[-1]).bool().to(self.device))  # bs x 1 x n x n
        attn.data.masked_fill_(mask, float("-inf"))
        attn = self.softmax(attn)  # bs x n_head x n x n
        attn = self.dropout(attn)
        # logging.info("attn: ", attn)
        # logging.info("attn.shape: ", attn.shape)
        output = torch.matmul(attn, h_prime)  # bs x n_head x n x f_out
        if self.bias is not None:
            return output + self.bias
        else:
            return output

In [2]:
path = join(data_dir, 'DBP15K', 'ja_en')
path

'/home/mrcactus/Thesis/ACEA/data/DBP15K/ja_en'

In [3]:
def load_dict(data_dir, file_num=2):
    if file_num == 2:
        file_names = [data_dir + str(i) for i in range(1, 3)]
    else:
        file_names = [data_dir]
    what2id, id2what, ids = {}, {}, []
    for file_name in file_names:
        with open(file_name, "r", encoding="utf-8") as f:
            data = f.read().strip().split("\n")
            data = [i.split("\t") for i in data]
            what2id = {**what2id, **dict([[i[1], int(i[0])] for i in data])}
            id2what = {**id2what, **dict([[int(i[0]), i[1]] for i in data])}
            ids.append(set([int(i[0]) for i in data]))
    return what2id, id2what, ids

In [4]:
path == "/home/mrcactus/Thesis/ACEA/data/DBP15K/ja_en"

True

In [5]:
ent2id_dict, id2ent_dict, [kg1_ent_ids, kg2_ent_ids] = load_dict(path + "/cleaned_ent_ids_", file_num=2)
# /home/mrcactus/Thesis/ACEA/data/DBP15K/ja_en

In [6]:
rel2id_dict, id2rel_dict, [kg1_rel_ids, kg2_rel_ids] = load_dict(path + "/cleaned_rel_ids_", file_num=2)

In [7]:
def load_triples(data_dir, file_num=2):
    if file_num == 2:
        file_names = [data_dir + str(i) for i in range(1, 3)]
    else:
        file_names = [data_dir]
    triple = []
    for file_name in file_names:
        with open(file_name, "r", encoding="utf-8") as f:
            data = f.read().strip().split("\n")
            data = [tuple(map(int, i.split("\t"))) for i in data]
            triple += data
    np.random.shuffle(triple)
    return triple

In [8]:
triple_idx = load_triples(path + "/triples_", file_num=2)

In [9]:
ill_idx = load_triples(path + "/ref_ent_ids", file_num=1) # ground truth

In [10]:
def load_LaBSE_emb(data_dir, file_num):
    if file_num == 2:
        file_names = [data_dir + str(i) + '.pkl' for i in range(1, 3)]
    else:
        file_names = [data_dir + '.pkl']
    id_entity = []
    for file_name in file_names:
        with open(file_name, 'rb') as f:
            id_entity.append(pickle.load(f))
    return id_entity

In [11]:
[kg1_ids_ent_emb, kg2_ids_ent_emb] = load_LaBSE_emb(path + "/raw_LaBSE_emb_", file_num = 2)
# kg1_ids_ent_emb的格式：
# kg1_ids_ent_emb[i] = [[emb]]，二维数组，取的时候要在后面加[0]

In [12]:
s = util.pytorch_cos_sim(kg1_ids_ent_emb[ill_idx[0][0]],
                      kg2_ids_ent_emb[ill_idx[0][1]])
s

tensor([[0.8857]])

In [13]:
rate, val = 0.2, 0.1
ill_train_idx, ill_val_idx, ill_test_idx = np.array(ill_idx[:int(len(ill_idx) // 1 * rate)], dtype=np.int32), np.array(ill_idx[int(len(ill_idx) // 1 * rate) : int(len(ill_idx) // 1 * (rate+val))], dtype=np.int32), np.array(ill_idx[int(len(ill_idx) // 1 * (rate+val)):], dtype=np.int32)

In [19]:
ill_train_idx = list(zip(*ill_train_idx))

In [31]:
# 通过多个key，批量从dict中获取value
# 这里是获取ground truth中实体id对应的emb
kg1_train_ent_idx = ill_train_idx[0]
kg2_train_ent_idx = ill_train_idx[1]
# kg1_ids_ent_emb
# kg1_train_ent_idx
# from operator import itemgetter
# kg1_train_ent_emb = itemgetter(*kg1_train_ent_idx)(kg1_ids_ent_emb)
# kg2_train_ent_emb = itemgetter(*kg2_train_ent_idx)(kg2_ids_ent_emb)

kg1_train_ent_emb = []
kg2_train_ent_emb = []
for idx in kg1_train_ent_idx:
    kg1_train_ent_emb.append(kg1_ids_ent_emb[idx][0])
for idx in kg2_train_ent_idx:
    kg2_train_ent_emb.append(kg2_ids_ent_emb[idx][0])

In [35]:
ss = []
s = util.pytorch_cos_sim(kg1_train_ent_emb,kg2_train_ent_emb)

s

tensor([[0.8857, 0.6664, 0.5875,  ..., 0.5349, 0.6056, 0.5527],
        [0.6301, 0.7858, 0.5265,  ..., 0.4918, 0.6622, 0.5003],
        [0.5795, 0.6063, 0.9310,  ..., 0.6285, 0.5670, 0.5919],
        ...,
        [0.5152, 0.5741, 0.5957,  ..., 0.8379, 0.6220, 0.5366],
        [0.5934, 0.6929, 0.5839,  ..., 0.6201, 0.8585, 0.5554],
        [0.6172, 0.5440, 0.6283,  ..., 0.5713, 0.5970, 0.9448]])

In [42]:
device = torch.device('cuda')
model = MyEmbedder(args, VOCAB_SIZE).to(device)

In [44]:
iteration = 0
lr = args['lr']
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [46]:
def fix_seed(seed=37):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.

In [48]:
fix_seed(37)
optimizer.zero_grad()
pos_1 = model(torch.tensor(kg1_train_ent_emb))

IndexError: too many indices for tensor of dimension 2