In [None]:
%%time
import warnings
warnings.filterwarnings('ignore')

import os
from tensorflow import keras
import numpy as np
import numba as nb 
from tqdm import *
from utils import *
import tensorflow as tf
from tensorflow.keras.layers import *
from my_layer import *
import multiprocessing
from tensorflow.keras.mixed_precision import experimental as mixed_precision


# Expose only GPU 0 to this process and lower TensorFlow's log verbosity
# (both the TF_CPP_MIN_LOG_LEVEL environment switch and the Python logger).
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0",
    "TF_CPP_MIN_LOG_LEVEL": "2",
})
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Physical GPUs visible to TensorFlow; only referenced by the disabled memory cap below.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
#tf.config.experimental.set_virtual_device_configuration(gpus[0],[tf.config.experimental.VirtualDeviceConfiguration(memory_limit=10000)])

In [None]:
# --- model -----------------------------------------------------------------
activation = "elu"              # activation name handed to the Keras layers
node_hidden = rel_hidden = 300  # widths passed to get_main_model
dropout_rate = 0.3
depth = 2                       # number of attention layers; depth + 2 is used as the BFS max depth

# --- optimisation / batch preparation ---------------------------------------
lr = 0.005
batch_size = 512                # aligned pairs per training batch
core_num = 16                   # worker processes used to prepare batches
info = 1                        # forwarded unchanged to select_path

In [None]:
%%time
# Seed NumPy's global RNG; the np.random.shuffle calls later in the notebook draw from it.
np.random.seed(12306)
# Dataset directory handed to the loaders from utils.
file_path = "data/fr_en/"
# load_triples returns the triples plus the entity and relation counts
# (meaning of the boolean flag is defined in utils — TODO confirm).
all_triples,node_size,rel_size = load_triples(file_path,True)
# Aligned entity pairs: train_pair is used for training, dev_pair for evaluation/bootstrapping.
train_pair,dev_pair = load_aligned_pair(file_path)
# generate_map rebinds all_triples and also yields the entity-entity and entity-relation
# index arrays that are fed to the model inputs `ent_ent` / `ent_rel`.
all_triples,all_ent_ent,all_ent_rel = generate_map(all_triples,node_size)

In [None]:
def update_and_obtain_feature():
    """Copy the trained weights into ``feature_model`` and embed the full graph.

    Reads the module-level ``train_model``, ``feature_model``, ``node_size``,
    ``all_triples``, ``all_ent_ent`` and ``all_ent_rel``.

    Returns:
        Whatever ``feature_model.predict_on_batch`` yields for the whole graph,
        evaluated under the ``/cpu:0`` device scope.
    """
    feature_model.set_weights(train_model.get_weights())

    node_ids = list(range(node_size))
    batched = []
    for item in (node_ids, all_triples, all_ent_ent, all_ent_rel):
        # The model expects a leading batch axis of size one on every input.
        batched.append(np.expand_dims(item, axis=0))

    with tf.device("/cpu:0"):
        return feature_model.predict_on_batch(batched)

def gather_embeddings(indexs,vec):
    """Look up and L2-normalise embedding rows for several index collections.

    Args:
        indexs: iterable of index collections (e.g. one per side of an alignment).
        vec: embedding table indexable by entity id.

    Returns:
        A list with one array per entry of ``indexs``; each row is the selected
        embedding divided by its Euclidean norm plus ``1e-6`` (the epsilon keeps
        all-zero rows from dividing by zero).
    """
    normalised = []
    for ids in indexs:
        rows = np.array([vec[entity] for entity in ids])
        scale = np.linalg.norm(rows, axis=-1, keepdims=True) + 1e-6
        normalised.append(rows / scale)
    return normalised

In [None]:
def get_main_model(node_hidden,rel_hidden,n_attn_heads = 2,dropout_rate = 0,lr = 0.005,depth = 2):
    """Build the trainable alignment model and the matching feature extractor.

    Reads the module-level ``node_size``, ``rel_size`` and ``activation``.

    Args:
        node_hidden: width passed to both ``TokenEmbedding`` layers.
        rel_hidden: accepted for interface compatibility but not read here; the
            relation embedding is also built with ``node_hidden``.
            NOTE(review): confirm whether that is intentional.
        n_attn_heads: forwarded to ``NR_GraphAttention`` as ``attn_heads``.
        dropout_rate: dropout applied to the concatenated features and to the
            fixed embeddings.
        lr: learning rate of the RMSprop optimizer.
        depth: number of ``NR_GraphAttention`` layers stacked.

    Returns:
        ``(train_model, feature_model)``. ``train_model`` takes the four graph
        inputs plus the alignment pairs and the fixed embeddings and outputs the
        loss itself (its compiled loss just returns ``y_pred``).
        ``feature_model`` takes the four graph inputs and outputs the
        concatenated feature tensor.
    """
    # Graph inputs (shapes exclude the batch axis, which callers set to 1).
    used_node = Input(shape=(None,))
    input_triples = Input(shape=(None,3))
    ent_ent = Input(shape=(None,2))
    ent_rel = Input(shape=(None,2))

    ent_emb = TokenEmbedding(node_size,node_hidden)(input_triples)
    # Keep only the embedding rows of the nodes listed in `used_node`.
    ent_emb = Lambda(lambda x:K.squeeze(K.gather(indices=K.cast(x[0],"int32"),reference=x[1]),axis=0))([used_node,ent_emb])
    rel_emb = TokenEmbedding(rel_size,node_hidden)(input_triples)

    def avg(tensor):
        """Mean of the embedding rows selected by adj[:,1], grouped by adj[:,0]."""
        adj,emb = tensor; adj = K.squeeze(K.cast(adj,"int32"),axis=0)
        embeds = K.gather(indices=adj[:,1],reference=emb)
        # Per-segment element counts, obtained by summing ones of the same shape.
        # NOTE(review): tf.math.segment_sum expects sorted segment ids, so adj is
        # presumably sorted by its first column — confirm in the data pipeline.
        sums = tf.math.segment_sum(segment_ids=adj[:,0],data=tf.ones_like(embeds))
        embeds = tf.math.segment_sum(segment_ids=adj[:,0],data=embeds)
        return embeds/sums

    ent_feature = Activation(activation)(Lambda(avg)([ent_ent,ent_emb]))
    rel_feature = Activation(activation)(Lambda(avg)([ent_rel,rel_emb]))
    results = [ent_feature,rel_feature]
    for i in range(depth):
        # One attention layer per depth level; the same layer instance is
        # applied to both the entity and the relation feature streams.
        encoder = NR_GraphAttention(node_size,
                                    rel_size,
                                    activation = activation,
                                    attn_heads=n_attn_heads,
                                    attn_heads_reduction='average')

        ent_feature = encoder([input_triples,ent_feature,rel_emb])
        rel_feature = encoder([input_triples,rel_feature,rel_emb])
        results.extend([ent_feature,rel_feature])

    # Final node representation: all intermediate outputs side by side.
    out_feature = Concatenate()(results)
    out_feature = Dropout(dropout_rate)(out_feature)

    alignment_input = Input(shape=(None,2))
    def gather_pair_emb(tensor):
        """Pick the feature rows for the left and right column of the pair input."""
        emb = tensor[1]
        l,r = K.cast(tensor[0][0,:,0],'int32'),K.cast(tensor[0][0,:,1],'int32')
        l_emb,r_emb = K.gather(reference=emb,indices=l),K.gather(reference=emb,indices=r)
        return [l_emb,r_emb]
    lemb,remb = Lambda(gather_pair_emb)([alignment_input,out_feature])

    # Externally supplied embeddings for each pair: axis 1 selects left / right.
    fixed_emb = Input((None,2,None))
    fixed_features = Lambda(lambda x:K.squeeze(x,axis=0))(fixed_emb)
    fixed_features = Dropout(dropout_rate)(fixed_features)
    fixed_lemb,fixed_remb = Lambda(lambda x:[x[:,0,:],x[:,1,:]])(fixed_features)

    def align_loss(tensor):
        """Mean of the cubed negative similarities between live and fixed embeddings."""

        def normalize(x):
            # Standardise every feature over axis 0, then scale rows to unit length.
            x = (x-K.mean(x,0))/K.std(x,0)
            x = K.l2_normalize(x,axis=-1)
            return x

        l_emb,r_emb,fixed_lemb,fixed_remb = [normalize(x) for x in tensor]

        # Left live embedding against the fixed embedding of its right partner.
        lpos_dis = - K.sum(l_emb*fixed_remb,axis=-1)
        lpos_dis = K.pow(lpos_dis,3)

        # Right live embedding against the fixed embedding of its left partner.
        rpos_dis = - K.sum(r_emb*fixed_lemb,axis=-1)
        rpos_dis = K.pow(rpos_dis,3)
        return K.mean(lpos_dis+rpos_dis,keepdims=True)

    loss = Lambda(align_loss)([lemb,remb,fixed_lemb,fixed_remb])

    inputs = [used_node,input_triples,ent_ent,ent_rel]
    train_model = tf.keras.Model(inputs = inputs + [alignment_input,fixed_emb],outputs = loss)
    # The model output already is the loss, so the Keras loss simply passes it through.
    # `learning_rate` replaces the deprecated `lr` keyword of tf.keras optimizers.
    train_model.compile(loss=lambda y_true,y_pred: y_pred,optimizer=tf.keras.optimizers.RMSprop(learning_rate=lr,rho=0.95,centered=True))

    feature_model = tf.keras.Model(inputs = inputs,outputs = out_feature)

    return train_model,feature_model

In [None]:
%%time
# Trainable model on the default device; its companion feature model is discarded.
train_model, _ = get_main_model(
    node_hidden=node_hidden, rel_hidden=rel_hidden,
    n_attn_heads=1, dropout_rate=dropout_rate,
    lr=lr, depth=depth,
)

# Second, identically configured build under the CPU device scope; only its
# feature model is kept. Weights are copied over by update_and_obtain_feature().
with tf.device("/cpu:0"):
    _, feature_model = get_main_model(
        node_hidden=node_hidden, rel_hidden=rel_hidden,
        n_attn_heads=1, dropout_rate=dropout_rate,
        lr=lr, depth=depth,
    )

# Embeddings produced before any training step; used later as the fixed targets.
init_features = update_and_obtain_feature()
train_model.summary()

In [None]:
%%time
@nb.njit()
def find_start(triples):
    """Group triples by their first element into an adjacency dictionary.

    Compiled by numba via ``nb.njit``.

    Args:
        triples: iterable whose items unpack into ``(h, r, t)``.

    Returns:
        A dict mapping every ``h`` to a ``nb.typed.List`` of ``(r, t)`` tuples,
        in the order the triples were encountered.
    """
    adj_dic = {}
    # `i` from enumerate is not used; only the unpacked triple matters.
    for i,(h,r,t) in enumerate(triples):
        if h not in adj_dic:
            # First triple for this head: start a new typed list.
            temp = nb.typed.List()
            temp.append((r,t))
            adj_dic[h] = temp
        else:
            adj_dic[h].append((r,t))
    return adj_dic

# Build the head -> [(relation, tail), ...] adjacency once; bfs() reads it as a module-level global.
adj_dic = find_start(all_triples)


def bfs(pairs,rel_weights,max_depth,info):
    """Assemble one training batch for the given aligned pairs.

    Reads the module-level ``adj_dic`` and ``init_features``.

    Args:
        pairs: 2-D array of aligned entity ids, one pair per row.
        rel_weights, max_depth, info: forwarded unchanged to ``select_path``.

    Returns:
        A list of six arrays, each with a leading axis of size one, in the
        order expected by ``train_model``: used nodes, de-duplicated triples,
        entity-entity map, entity-relation map, pair positions and the initial
        embeddings of every pair member.
    """
    sampled = select_path(pairs.flatten(),adj_dic,rel_weights,max_depth,info)
    triples,ent_ent,ent_rel,node_dict,used_node = sampled
    unique_triples = np.unique(triples,axis=0)

    # Pair members are addressed by position: row k of `pairs` becomes (2k, 2k+1).
    pair_count = pairs.shape[0]
    mapped_pairs = np.array(list(range(pair_count*2))).reshape((-1,2))

    # Initial (pre-training) embedding for both members of every pair.
    fix_emb = np.array([[init_features[entity] for entity in pair] for pair in pairs])

    batch = []
    for item in (used_node,unique_triples,ent_ent,ent_rel,mapped_pairs,fix_emb):
        batch.append(np.array([item]))
    return batch

def generate_batch(train_pair,rel_weights,max_depth,info):
    """Shuffle the training pairs and prepare their batches in worker processes.

    Reads the module-level ``core_num`` and ``batch_size``.

    Args:
        train_pair: array of aligned pairs; shuffled IN PLACE along its first axis.
        rel_weights, max_depth, info: forwarded unchanged to ``bfs``.

    Returns:
        ``(inputs, pool)``: ``inputs`` is a list of ``AsyncResult`` objects, one
        per batch; ``pool`` is the already-closed pool, which the caller must
        ``join()`` before collecting the results.
    """
    np.random.shuffle(train_pair)
    pool = multiprocessing.Pool(processes=core_num)
    inputs = []
    # Step over batch start offsets. The previous `len // batch_size + 1` count
    # submitted an empty trailing batch whenever the number of pairs was an
    # exact multiple of batch_size (and a single empty batch for no pairs).
    for start in range(0,len(train_pair),batch_size):
        batch_pairs = train_pair[start:start+batch_size]
        inputs.append(pool.apply_async(bfs,(batch_pairs,rel_weights,max_depth,info)))
    # No further tasks will be submitted; workers exit once the queue drains.
    pool.close()
    return inputs,pool

# Start from uniform relation weights (all ones) and kick off background
# preparation of the first epoch's batches; the training loop joins `pool` later.
rel_weights = np.ones(rel_size,dtype="float32")
next_inputs,pool = generate_batch(train_pair,rel_weights,depth+2,info)

In [None]:
# Per-pair similarity recorded after each turn, keyed by (left, right).
dis = {}

# Candidate pools for bootstrapping: the left and right members of dev_pair,
# each shuffled independently so positions no longer reveal the alignment.
rest_set_1 = [left for left, _right in dev_pair]
rest_set_2 = [right for _left, right in dev_pair]
for candidates in (rest_set_1, rest_set_2):
    np.random.shuffle(candidates)

In [None]:
%%time
def _cosine(vec_a, vec_b):
    """Cosine similarity between two embedding vectors."""
    return vec_a.dot(vec_b.T)/(np.linalg.norm(vec_a)*np.linalg.norm(vec_b))


# First turn trains for 7 epochs; every later turn for 5 (reset at the bottom of the loop).
epoch = 7
for turn in range(10):
    for i in trange(epoch):
        # Collect the batches prepared in the background for this epoch ...
        pool.join()
        inputs = [res.get() for res in next_inputs]
        # ... and immediately start preparing the next epoch's batches with
        # relation weights derived from the current model weights.
        # NOTE(review): relies on the position of entries in get_weights();
        # confirm which layers weights[1] and weights[2] belong to.
        weights = train_model.get_weights()
        rel_weights = np.squeeze(np.exp(np.dot(weights[1],weights[2])),axis=-1)
        next_inputs,pool = generate_batch(train_pair,rel_weights,depth+2,info)

        # Two passes over the same prepared batches, reshuffled each time.
        for _ in range(2):
            shuffle(inputs)
            for input_batch in inputs:
                # The target is a dummy: the model's output is the loss itself.
                train_model.train_on_batch(input_batch,np.zeros((1,1)))

    # --- evaluation on the dev pairs (current vs. initial embeddings) ---------
    now_features = update_and_obtain_feature()
    Lvec,Rvec = gather_embeddings([dev_pair[:,0],dev_pair[:,1]],now_features)
    ILvec,IRvec = gather_embeddings([dev_pair[:,0],dev_pair[:,1]],init_features)
    GPU_test(Lvec,IRvec,512)
    GPU_test(Rvec,ILvec,512)

    # --- bootstrapping: mutual nearest neighbours among unmatched entities ----
    # `now_features` is reused: no training happened since it was computed, so
    # recomputing it would only repeat the same weight copy and forward pass.
    Lvec,Rvec = gather_embeddings([rest_set_1,rest_set_2],now_features)
    ILvec,IRvec = gather_embeddings([rest_set_1,rest_set_2],init_features)

    A = CSLS_cal(Lvec,IRvec,False,512)
    B = CSLS_cal(Rvec,ILvec,False,512)
    A = np.argmax(A,1)
    B = np.argmax(B,1)

    # NOTE(review): `i` enumerates rows of A and `j` is its arg-max, yet the pair
    # is built as (rest_set_1[j], rest_set_2[i]). Whether that is right depends
    # on the row/column orientation of CSLS_cal's result — confirm in utils.
    new_pair = []
    for i,j in enumerate(A):
        if  B[j] == i:
            new_pair.append([rest_set_1[j],rest_set_2[i]])

    # Drop the newly matched entities from the candidate pools. Set membership
    # keeps this linear instead of the quadratic `list.remove` per pair.
    matched_1 = {e1 for e1,_e2 in new_pair}
    matched_2 = {e2 for _e1,e2 in new_pair}
    rest_set_1 = [e for e in rest_set_1 if e not in matched_1]
    rest_set_2 = [e for e in rest_set_2 if e not in matched_2]

    print(len(new_pair))

    # Re-add previously trained pairs whose similarity dropped by more than 0.05
    # since it was last recorded.
    for (i,j),prev_sim in dis.items():
        if prev_sim - _cosine(now_features[i],now_features[j]) > 0.05:
            new_pair.append([i,j])
    print(len(new_pair))

    # Record the current similarity of the pairs trained in this turn.
    for i,j in train_pair:
        dis[(i,j)] = _cosine(now_features[i],now_features[j])

    # NOTE(review): the batches already prefetched in `next_inputs` were built
    # from the previous train_pair, so the first epoch of the next turn still
    # trains on them — confirm this is intended.
    train_pair = np.array(new_pair)
    epoch = 5