In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!apt update && apt install cuda-11-8

In [None]:
#coding:utf-8
import numpy as np
import json
# from model import ConceptFlow, use_cuda
# from preprocession import prepare_data, build_vocab, gen_batched_data
import torch
import warnings
import yaml
import os
from torch.autograd import Variable
import torch.nn as nn
from torch.nn import utils as nn_utils
warnings.filterwarnings('ignore')

In [None]:
print(torch.cuda.get_device_name(0))

NVIDIA A100-SXM4-40GB


In [None]:
# Pick the compute device once; the model code below reads the module-level
# `device` when it allocates tensors and parameters.
use_cuda = True
if use_cuda and torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    # Without this fallback `device` stays undefined on a CPU-only runtime and
    # every later `device=device` raises NameError.
    device = torch.device('cpu')

In [None]:
print(device)

cuda:0


In [None]:
!ls /usr/local/

bin    cuda	cuda-12.2  games	       include	lib64	   man	 share
colab  cuda-12	etc	   _gcs_config_ops.so  lib	licensing  sbin  src


In [None]:
!nvidia-smi

Fri Apr 12 09:32:10 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0              43W / 400W |      5MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
print(torch.cuda.device_count())

1


In [None]:
print(torch.cuda.get_device_name(0))


NVIDIA A100-SXM4-40GB


In [None]:
torch.cuda.is_available()

True

# Embedding

In [None]:
# #coding:utf-8
# import torch
# import numpy as np
# from torch.autograd import Variable
# import torch.nn as nn
# from torch.nn import utils as nn_utils

VERY_SMALL_NUMBER = 1e-10
VERY_NEG_NUMBER = -100000000000

def use_cuda(var):
    """Move ``var`` to the module-level ``device`` when CUDA is available.

    Args:
        var: any object exposing ``.to(device=...)`` (tensor or module).

    Returns:
        ``var.to(device=device)`` if ``torch.cuda.is_available()``, otherwise
        ``var`` itself, untouched.
    """
    if not torch.cuda.is_available():
        return var
    return var.to(device=device)
'''
def use_cuda(var):
    if torch.cuda.is_available():
        return var.cuda()
    else:
        return var
'''


class EntityEmbedding(nn.Module):
    """Trainable entity embedding table followed by a linear projection.

    Seven all-zero rows are prepended to the pretrained table, so pretrained
    row ``i`` ends up at index ``i + 7``. (The entity vocabulary built in
    ``build_vocab`` starts with seven special tokens; presumably these rows
    belong to them.)

    Args:
        entity_embed: pretrained embeddings of shape ``(num_entities, trans_units)``
            as a NumPy array or tensor of any floating dtype.
        trans_units: embedding width; also the in/out width of ``entity_linear``.
    """

    def __init__(self, entity_embed, trans_units):
        super(EntityEmbedding, self).__init__()
        self.trans_units = trans_units
        self.entity_embedding = nn.Embedding(num_embeddings = entity_embed.shape[0] + 7, embedding_dim = self.trans_units, padding_idx = 0)

        # Force float32: a float64 array would otherwise produce a float64
        # weight that cannot be multiplied with the float32 `entity_linear`.
        pretrained = torch.as_tensor(entity_embed, dtype=torch.float32, device=device).detach()
        reserved_rows = torch.zeros(7, self.trans_units, dtype=torch.float32, device=device)
        # torch.cat allocates fresh storage, so the parameter never aliases the
        # caller's array and no second copy is needed.
        self.entity_embedding.weight = nn.Parameter(torch.cat((reserved_rows, pretrained), 0), requires_grad=True)
        self.entity_linear = nn.Linear(in_features = self.trans_units, out_features = self.trans_units)

    def forward(self, entity):
        """Look up ``entity`` ids and project them through ``entity_linear``.

        Args:
            entity: integer id tensor of any shape.

        Returns:
            Tensor of shape ``entity.shape + (trans_units,)``.
        """
        entity_emb = self.entity_embedding(entity)
        return self.entity_linear(entity_emb)



class WordEmbedding(nn.Module):
    """Trainable word embedding table initialised from pretrained vectors.

    Args:
        word_embed: pretrained vectors of shape ``(vocab_size, embed_units)``
            as a NumPy array or tensor of any floating dtype.
        embed_units: embedding width.
    """

    def __init__(self, word_embed, embed_units):
        super(WordEmbedding, self).__init__()

        self.embed_units = embed_units
        self.word_embedding = nn.Embedding(num_embeddings = word_embed.shape[0], embedding_dim = self.embed_units, padding_idx = 0)
        # Always float32 (the recurrent layers fed by this table are float32)
        # and always a private copy, so training never writes into the
        # caller's array.
        pretrained = torch.as_tensor(word_embed, dtype=torch.float32).detach().clone().to(device)
        self.word_embedding.weight = nn.Parameter(pretrained, requires_grad=True)

    def forward(self, query_text):
        """Return the embeddings for the integer word ids in ``query_text``."""
        return self.word_embedding(query_text)

# Utils

In [None]:
def padding(sent, l):
    """Terminate a token list with ``'_EOS'`` and right-pad it with ``'_PAD'``.

    Args:
        sent: list of tokens.
        l: target length including the ``'_EOS'`` marker.

    Returns:
        A new list ``sent + ['_EOS'] + ['_PAD'] * k`` with ``k = l - len(sent) - 1``.
        When ``k`` is not positive no padding is added, so the result can be
        longer than ``l``.
    """
    pad_count = l - len(sent) - 1
    tail = ['_EOS']
    tail.extend(['_PAD'] * pad_count)
    return sent + tail

def padding_triple_id(entity2id, triple, num, l):
    """Map nested triples to entity ids and pad them to a ``num x l`` grid.

    Args:
        entity2id: dict from entity/relation string to integer id. Must contain
            ``'_PAD_H'``, ``'_PAD_R'`` and ``'_PAD_T'``; ``'_NONE'`` is required
            only when some element of ``triple`` is missing from the dict.
        triple: list of groups; each group is a list of facts; each fact is a
            sequence of strings. The argument is NOT modified.
        num: minimum number of groups in the result.
        l: minimum number of facts per group.

    Returns:
        A new list of groups. Every group shorter than ``l`` is extended with
        ``[_PAD_H, _PAD_R, _PAD_T]`` id facts, and all-padding groups are
        appended until there are ``num`` groups. Groups or group lists already
        longer than the target are left at their length.
    """
    def to_id(name):
        # Unknown names fall back to the '_NONE' id (looked up lazily).
        return entity2id[name] if name in entity2id else entity2id['_NONE']

    pad_fact = [entity2id['_PAD_H'], entity2id['_PAD_R'], entity2id['_PAD_T']]

    newtriple = []
    for group in triple:
        # Build fresh lists instead of writing ids back into the caller's data;
        # the in-place variant broke any second call on the same input.
        facts = [[to_id(name) for name in fact] for fact in group]
        # Each padding fact is its own list so rows never alias one another.
        facts.extend(list(pad_fact) for _ in range(l - len(group)))
        newtriple.append(facts)

    for _ in range(num - len(newtriple)):
        newtriple.append([list(pad_fact) for _ in range(l)])
    return newtriple

def build_kb_adj_mat(kb_adj_mats, fact_dropout):
    """Flatten per-example sparse adjacency data into batched COO components.

    Args:
        kb_adj_mats: indexable batch exposing ``.shape[0]``; item ``i`` unpacks as
            ``((mat0_0, mat0_1, val0), (mat1_0, mat1_1, val1))`` where all six
            members are 1-D NumPy arrays of equal length (one entry per fact).
        fact_dropout: fraction of facts to drop per example. The kept subset is
            a random sample drawn with ``np.random.permutation`` and is shared
            by both matrices of the example.

    Returns:
        ``((batch0, mats0_0, mats0_1, vals0), (batch1, mats1_0, mats1_1, vals1))``
        where each ``batch*`` array holds the example index of every kept fact.
    """
    # Collect per-example chunks and concatenate once at the end. Calling
    # np.append inside the loop re-copies everything accumulated so far on
    # every iteration. The typed empty seeds keep the result dtypes identical
    # to the np.append version, including for an empty batch.
    def new_columns():
        return [[np.array([], dtype=int)],
                [np.array([], dtype=int)],
                [np.array([], dtype=int)],
                [np.array([], dtype=float)]]

    columns0 = new_columns()
    columns1 = new_columns()

    for i in range(kb_adj_mats.shape[0]):
        (mat0_0, mat0_1, val0), (mat1_0, mat1_1, val1) = kb_adj_mats[i]
        assert len(val0) == len(val1)
        num_fact = len(val0)
        num_keep_fact = int(np.floor(num_fact * (1 - fact_dropout)))
        # One permutation per example, applied to both matrices.
        mask_index = np.random.permutation(num_fact)[ : num_keep_fact]
        batch_ids = np.full(len(mask_index), i, dtype=int)

        for columns, parts in ((columns0, (mat0_0, mat0_1, val0)),
                               (columns1, (mat1_0, mat1_1, val1))):
            columns[0].append(batch_ids)
            for column, part in zip(columns[1:], parts):
                # ravel mirrors np.append, which always flattens its inputs.
                column.append(np.ravel(part[mask_index]))

    return (tuple(np.concatenate(column) for column in columns0),
            tuple(np.concatenate(column) for column in columns1))

# Central encoder

In [None]:
class _LeftMMFixed(torch.autograd.Function):
    """Sparse @ dense matrix product that back-propagates only into the dense operand.

    The sparse operand is treated as a constant, so ``backward`` returns
    ``None`` for it. Defined once at module level instead of being re-created
    on every ``sparse_bmm`` call.
    """

    @staticmethod
    def forward(ctx, sparse_weights, x):
        ctx.save_for_backward(sparse_weights)
        return torch.mm(sparse_weights, x)

    @staticmethod
    def backward(ctx, grad_output):
        sparse_weights, = ctx.saved_tensors
        return None, torch.mm(sparse_weights.t(), grad_output)


class CentralEncoder(nn.Module):
    """Graph encoder over the local entity/fact graph, conditioned on the query.

    The query is encoded with an LSTM. Entity embeddings are then refined for
    ``gnn_layers`` rounds by passing messages entity -> fact -> entity through
    two sparse adjacency matrices, weighted by a query-aware fact attention and
    a per-entity score (``pagerank_f``) that is updated each round.

    Args:
        config: object providing ``pagerank_lambda``, ``fact_scale``,
            ``lstm_dropout`` and ``linear_dropout``.
        gnn_layers: number of propagation rounds.
        embed_units: width of the word embeddings (LSTM input size).
        trans_units: width of entity embeddings and of every hidden state here.
        word_embedding: module mapping word ids to ``embed_units`` vectors.
        entity_embedding: module mapping entity/relation ids to ``trans_units`` vectors.
    """

    def __init__(self, config, gnn_layers, embed_units, trans_units, word_embedding, entity_embedding):
        super(CentralEncoder, self).__init__()
        # Number of trans_units-wide pieces concatenated per entity in each
        # round: current embedding, query message, fact message.
        self.k = 2 + 1
        self.gnn_layers = gnn_layers
        self.WordEmbedding = word_embedding
        self.EntityEmbedding = entity_embedding
        self.embed_units = embed_units
        self.trans_units = trans_units
        self.pagerank_lambda = config.pagerank_lambda
        self.fact_scale = config.fact_scale

        self.node_encoder = nn.LSTM(input_size = self.embed_units, hidden_size = self.trans_units, batch_first=True, bidirectional=False)
        self.lstm_drop = nn.Dropout(p = config.lstm_dropout)
        self.softmax_d1 = nn.Softmax(dim = 1)
        self.linear_drop = nn.Dropout(p = config.linear_dropout)
        self.relu = nn.ReLU()

        # Registration order is kept unchanged so seeded initialisation and
        # state_dict layout stay the same. d2e_linear* and e2d_linear* are not
        # used by forward(); they are still registered so parameter names match.
        for i in range(self.gnn_layers):
            self.add_module('q2e_linear' + str(i), nn.Linear(in_features=self.trans_units, out_features=self.trans_units))
            self.add_module('d2e_linear' + str(i), nn.Linear(in_features=self.trans_units, out_features=self.trans_units))
            self.add_module('e2q_linear' + str(i), nn.Linear(in_features=self.k * self.trans_units, out_features=self.trans_units))
            self.add_module('e2d_linear' + str(i), nn.Linear(in_features=self.k * self.trans_units, out_features=self.trans_units))
            self.add_module('e2e_linear' + str(i), nn.Linear(in_features=self.k * self.trans_units, out_features=self.trans_units))

            # knowledge-base projections
            self.add_module('kb_head_linear' + str(i), nn.Linear(in_features=self.trans_units, out_features=self.trans_units))
            self.add_module('kb_tail_linear' + str(i), nn.Linear(in_features=self.trans_units, out_features=self.trans_units))
            self.add_module('kb_self_linear' + str(i), nn.Linear(in_features=self.trans_units, out_features=self.trans_units))

    @staticmethod
    def _sparse_adjacency(index_rows, values, size):
        """Build a float32 sparse COO tensor on ``device`` from NumPy index rows and values."""
        # Stack in NumPy first: building a tensor from a Python list of arrays
        # is slow and triggers a warning.
        indices = torch.from_numpy(np.stack([np.asarray(row) for row in index_rows]).astype(np.int64))
        vals = torch.as_tensor(np.asarray(values), dtype=torch.float32)
        return torch.sparse_coo_tensor(indices, vals, size=size, device=device)

    def forward(self, batch_size, max_local_entity, max_fact, query_text, local_entity, q2e_adj_mat, kb_adj_mat, kb_fact_rel, query_mask):
        """Encode the local entities of a batch.

        Args:
            batch_size: number of examples.
            max_local_entity: padded number of local entities per example.
            max_fact: padded number of facts per example.
            query_text: word-id tensor, ``(batch, query_len)``.
            local_entity: entity-id tensor, ``(batch, max_local_entity)``.
            q2e_adj_mat: NumPy array ``(batch, max_local_entity)`` with the
                initial per-entity score.
            kb_adj_mat: ``((batch, fact, entity, val), (batch, entity, fact, val))``
                COO components as produced by ``build_kb_adj_mat``.
            kb_fact_rel: id tensor ``(batch, max_fact)`` looked up in the entity
                embedding to represent each fact.
            query_mask: float tensor ``(batch, query_len)``, 1 for real tokens.

        Returns:
            Tensor ``(batch, max_local_entity, trans_units)``.
        """
        # Per-entity propagation score, seeded from the query->entity adjacency.
        pagerank_f = torch.from_numpy(q2e_adj_mat).float().to(device).requires_grad_(True)

        # encode query
        query_word_emb = self.WordEmbedding(query_text)
        query_hidden_emb, (query_node_emb, _) = self.node_encoder(self.lstm_drop(query_word_emb), self.init_hidden(1, batch_size, self.trans_units))
        # (1, batch, units) -> (batch, 1, units)
        query_node_emb = query_node_emb.squeeze(dim=0).unsqueeze(dim=1)

        # sparse entity<->fact adjacency, both built the same way on `device`
        (e2f_batch, e2f_f, e2f_e, e2f_val), (f2e_batch, f2e_e, f2e_f, f2e_val) = kb_adj_mat
        entity2fact_mat = self._sparse_adjacency((e2f_batch, e2f_f, e2f_e), e2f_val, (batch_size, max_fact, max_local_entity))
        fact2entity_mat = self._sparse_adjacency((f2e_batch, f2e_e, f2e_f), f2e_val, (batch_size, max_local_entity, max_fact))

        local_fact_emb = self.EntityEmbedding(kb_fact_rel)
        local_fact_emb = local_fact_emb.to(device=device)

        # attention of each fact over the query tokens (softmax over tokens;
        # masked tokens are pushed to a very negative score)
        div = float(np.sqrt(self.trans_units))
        fact2query_sim = torch.bmm(query_hidden_emb, local_fact_emb.transpose(1, 2)) / div
        fact2query_sim = self.softmax_d1(fact2query_sim + (1 - query_mask.unsqueeze(dim=2)) * VERY_NEG_NUMBER)
        fact2query_att = torch.sum(fact2query_sim.unsqueeze(dim=3) * query_hidden_emb.unsqueeze(dim=2), dim=1)

        # unnormalised fact weights; subtracting the row max keeps exp() stable
        W = torch.sum(fact2query_att * local_fact_emb, dim=2) / div
        W_max = torch.max(W, dim=1, keepdim=True)[0]
        W_tilde = torch.exp(W - W_max)
        # per-entity sum of the weights of its facts, clamped to avoid division by zero
        e2f_softmax = self.sparse_bmm(entity2fact_mat.transpose(1, 2), W_tilde.unsqueeze(dim=2)).squeeze(dim=2)
        e2f_softmax = torch.clamp(e2f_softmax, min=VERY_SMALL_NUMBER)

        # load entity embedding
        local_entity_emb = self.EntityEmbedding(local_entity)

        # propagation rounds
        for i in range(self.gnn_layers):
            q2e_linear = getattr(self, 'q2e_linear' + str(i))
            e2q_linear = getattr(self, 'e2q_linear' + str(i))
            e2e_linear = getattr(self, 'e2e_linear' + str(i))

            kb_self_linear = getattr(self, 'kb_self_linear' + str(i))
            kb_head_linear = getattr(self, 'kb_head_linear' + str(i))
            kb_tail_linear = getattr(self, 'kb_tail_linear' + str(i))

            next_local_entity_emb = local_entity_emb

            # STEP 1: query -> entity
            q2e_emb = q2e_linear(self.linear_drop(query_node_emb)).expand(batch_size, max_local_entity, self.trans_units)
            next_local_entity_emb = torch.cat((next_local_entity_emb, q2e_emb), dim=2)

            # entity -> fact -> entity
            e2f_emb = self.relu(kb_self_linear(local_fact_emb) + self.sparse_bmm(entity2fact_mat, kb_head_linear(self.linear_drop(local_entity_emb))))
            e2f_softmax_normalized = W_tilde.unsqueeze(dim=2) * self.sparse_bmm(entity2fact_mat, (pagerank_f / e2f_softmax).unsqueeze(dim=2))
            e2f_emb = e2f_emb * e2f_softmax_normalized
            f2e_emb = self.relu(kb_self_linear(local_entity_emb) + self.sparse_bmm(fact2entity_mat, kb_tail_linear(self.linear_drop(e2f_emb))))

            pagerank_f = self.pagerank_lambda * self.sparse_bmm(fact2entity_mat, e2f_softmax_normalized).squeeze(dim=2) + (1 - self.pagerank_lambda) * pagerank_f

            # STEP 2: append the fact message
            next_local_entity_emb = torch.cat((next_local_entity_emb, self.fact_scale * f2e_emb), dim=2)

            # STEP 3: entity -> query, then update the entities
            query_node_emb = torch.bmm(pagerank_f.unsqueeze(dim=1), e2q_linear(self.linear_drop(next_local_entity_emb)))
            local_entity_emb = self.relu(e2e_linear(self.linear_drop(next_local_entity_emb)))

        return local_entity_emb

    def init_hidden(self, num_layer, batch_size, hidden_size):
        """Return zero ``(h0, c0)`` LSTM states of shape ``(num_layer, batch_size, hidden_size)`` on ``device``."""
        return (torch.zeros(num_layer, batch_size, hidden_size, device=device),
                torch.zeros(num_layer, batch_size, hidden_size, device=device))

    def sparse_bmm(self, X, Y):
        """Batch multiply X and Y where X is sparse, Y is dense.

        Args:
            X: Sparse COO tensor of size BxMxN.
            Y: Dense tensor of size BxNxK.
        Returns:
            batched-matmul(X, Y): BxMxK. Gradients flow into ``Y`` only.
        """
        I = X._indices()
        V = X._values()
        B, M, N = X.size()
        _, _, K = Y.size()
        Z = I.size()[1]
        # One row of Y per non-zero of X, then a (B*M x Z) sparse matrix that
        # sums those rows, scaled by the non-zero values, into the right
        # (batch, row) slot.
        lookup = Y[I[0, :], I[2, :], :]
        X_I = torch.stack((I[0, :] * M + I[1, :], torch.arange(Z, device=X.device, dtype=torch.long)), 0)
        S = torch.sparse_coo_tensor(X_I, V, (B * M, Z), device=X.device)
        prod = _LeftMMFixed.apply(S, lookup)
        return prod.view(B, M, K)

# Outer encoder

In [None]:
class OuterEncoder(nn.Module):
    """Attention-pools groups of (head, relation, tail) triples into one vector per group.

    Args:
        trans_units: width of a single entity/relation embedding.
        entity_embedding: module mapping ids to ``trans_units``-wide vectors.
    """

    def __init__(self, trans_units, entity_embedding):
        super(OuterEncoder, self).__init__()
        self.EntityEmbedding = entity_embedding
        self.trans_units = trans_units

        # Scores a triple from its concatenated head/tail and its relation.
        self.head_tail_linear = nn.Linear(in_features = self.trans_units * 2, out_features = self.trans_units)
        self.one_two_entity_linear = nn.Linear(in_features = self.trans_units, out_features = self.trans_units)
        self.softmax_d2 = nn.Softmax(dim = 2)

    def forward(self, batch_size, one_two_triples_id, one_two_triple_num):
        """Encode every triple group of the batch.

        Args:
            batch_size: number of examples.
            one_two_triples_id: id tensor whose embedding can be reshaped to
                ``(batch_size, one_two_triple_num, triples_per_group, 3, trans_units)``.
            one_two_triple_num: number of triple groups per example.

        Returns:
            Tensor ``(batch_size, one_two_triple_num, 2 * trans_units)``: the
            attention-weighted sum of the head/tail concatenations in each group.
        """
        embedded = self.EntityEmbedding(one_two_triples_id)
        per_slot = embedded.reshape(batch_size, one_two_triple_num, -1, 3, self.trans_units)
        head, relation, tail = per_slot.unbind(dim=3)

        head_tail = torch.cat((head, tail), dim=3)
        head_tail_score = torch.tanh(self.head_tail_linear(head_tail))
        relation_score = self.one_two_entity_linear(relation)

        # Softmax over the triples inside each group.
        alpha_weight = self.softmax_d2((relation_score * head_tail_score).sum(dim=3))

        return (alpha_weight.unsqueeze(3) * head_tail).sum(dim=2)

# ConceptFlow

In [None]:
class ConceptFlow(nn.Module):
    """Response generator attending over text, central-graph and outer-graph concepts.

    The query is encoded with a GRU. ``CentralEncoder`` produces embeddings for
    the local entities and ``OuterEncoder`` pools the triple groups. A GRU
    decoder attends over all of them at every step and, through a three-way
    selector, emits either a vocabulary word, a local entity or a two-hop
    entity.

    Attributes set after construction by the caller:
        is_inference: when true ``forward`` decodes greedily and returns token
            ids instead of losses.
        is_select: when true ``forward`` appends the ranked two-hop entities to
            ``selected_concept.txt`` (requires ``batch_data['entity2id']``).

    Args:
        config: object providing ``trans_units``, ``embed_units``, ``units``,
            ``layers``, ``gnn_layers``, ``symbols`` and the fields read by
            ``CentralEncoder``.
        word_embed: pretrained word vectors, ``(vocab, embed_units)``.
        entity_embed: pretrained entity vectors, ``(entities, trans_units)``.
        is_select: initial value of ``is_select``.
    """

    def __init__(self, config, word_embed, entity_embed, is_select=False):
        super(ConceptFlow, self).__init__()
        self.is_select = is_select
        self.is_inference = False

        self.trans_units = config.trans_units
        self.embed_units = config.embed_units
        self.units = config.units
        self.layers = config.layers
        self.gnn_layers = config.gnn_layers
        self.symbols = config.symbols

        self.WordEmbedding = WordEmbedding(word_embed, self.embed_units)
        self.EntityEmbedding = EntityEmbedding(entity_embed, self.trans_units)
        self.CentralEncoder = CentralEncoder(config, self.gnn_layers, self.embed_units, self.trans_units, self.WordEmbedding, self.EntityEmbedding)
        self.OuterEncoder = OuterEncoder(self.trans_units, self.EntityEmbedding)

        self.softmax_d1 = nn.Softmax(dim = 1)
        self.softmax_d2 = nn.Softmax(dim = 2)

        self.text_encoder = nn.GRU(input_size = self.embed_units, hidden_size = self.units, num_layers = self.layers, batch_first = True)
        self.decoder = nn.GRU(input_size = self.units + self.embed_units, hidden_size = self.units, num_layers = self.layers, batch_first = True)

        # Attention projections. The *_ce/_co/_ct layers emit keys and values
        # side by side (2 * units) and are split in forward().
        self.attn_c_linear = nn.Linear(in_features = self.units, out_features = self.units, bias = False)
        self.attn_ce_linear = nn.Linear(in_features = self.trans_units, out_features = 2 * self.units, bias = False)
        self.attn_co_linear = nn.Linear(in_features = 2 * self.trans_units, out_features = 2 * self.units, bias = False)
        self.attn_ct_linear = nn.Linear(in_features = self.trans_units, out_features = 2 * self.units, bias = False)

        self.context_linear = nn.Linear(in_features = 4 * self.units, out_features = self.units, bias = False)

        self.logits_linear = nn.Linear(in_features = self.units, out_features = self.symbols)
        self.selector_linear = nn.Linear(in_features = self.units, out_features = 3)

    @staticmethod
    def _length_mask(lengths, batch_size, width):
        """Return a ``(batch_size, width)`` long mask with the first ``lengths[i]`` entries of row ``i`` set to 1."""
        mask = np.zeros([batch_size, width])
        for i in range(batch_size):
            mask[i][0:int(lengths[i])] = 1
        return use_cuda(torch.from_numpy(mask).long())

    @staticmethod
    def _one_hot_targets(match_index, width):
        """One-hot encode matched entity positions.

        Args:
            match_index: long tensor ``(batch, steps)``; ``-1`` marks "no match".
            width: size of the one-hot dimension.

        Returns:
            Float tensor ``(batch, steps, width)`` with a single 1 per matched
            step and all zeros for unmatched steps.
        """
        rows, steps = match_index.shape
        one_hot = torch.zeros(rows, steps, width, device=match_index.device)
        if width == 0:
            return one_hot
        matched = (match_index != -1).unsqueeze(2).to(one_hot.dtype)
        # Unmatched steps scatter a 0 into column 0, which leaves the row empty.
        one_hot.scatter_(2, match_index.clamp(min=0).unsqueeze(2), matched)
        return one_hot

    @staticmethod
    def _stack_steps(steps):
        """Stack per-step tensors along dim 1; an empty list yields an empty tensor."""
        if not steps:
            return use_cuda(torch.empty(0))
        return torch.stack(steps, 1)

    def forward(self, batch_data):
        """Run one batch through the model.

        Args:
            batch_data: dict of NumPy arrays / lists with the keys read below.
                ``word2id`` and ``entity2id`` are required when
                ``is_inference`` is set; ``entity2id`` is also required when
                ``is_select`` is set.

        Returns:
            Training mode: ``(decoder_loss, sentence_ppx, sentence_ppx_word,
            sentence_ppx_local, sentence_ppx_only_two, word_neg_num,
            local_neg_num, only_two_neg_num)``.
            Inference mode: ``(word_index, selector)`` as nested Python lists,
            one row per example and one column per decoding step.
        """
        query_text = batch_data['query_text']
        answer_text = batch_data['answer_text']
        local_entity = batch_data['local_entity']
        responses_length = batch_data['responses_length']
        q2e_adj_mat = batch_data['q2e_adj_mat']
        kb_adj_mat = batch_data['kb_adj_mat']
        kb_fact_rel = batch_data['kb_fact_rel']
        match_entity_one_hop = batch_data['match_entity_one_hop']
        only_two_entity = batch_data['only_two_entity']
        match_entity_only_two = batch_data['match_entity_only_two']
        one_two_triples_id = batch_data['one_two_triples_id']
        local_entity_length = batch_data['local_entity_length']
        only_two_entity_length = batch_data['only_two_entity_length']

        word2id = None
        id2entity = None
        if self.is_inference:
            word2id = batch_data['word2id']
        # sort() needs the reverse map too, so build it for is_select even
        # outside inference instead of passing None and crashing later.
        if self.is_inference or self.is_select:
            id2entity = {idx: name for name, idx in batch_data['entity2id'].items()}

        batch_size, max_local_entity = local_entity.shape
        _, max_only_two_entity = only_two_entity.shape
        _, one_two_triple_num, _, _ = one_two_triples_id.shape
        _, max_fact = kb_fact_rel.shape

        # numpy to tensor
        def to_long(array):
            return use_cuda(torch.from_numpy(array).long())

        local_entity = to_long(local_entity)
        kb_fact_rel = to_long(kb_fact_rel)
        query_text = to_long(query_text)
        answer_text = to_long(answer_text)
        query_mask = (query_text != 0).float()
        match_entity_one_hop = to_long(match_entity_one_hop)
        only_two_entity = to_long(only_two_entity)
        match_entity_only_two = to_long(match_entity_only_two)
        one_two_triples_id = to_long(one_two_triples_id)

        decoder_len = answer_text.shape[1]
        responses_target = answer_text
        # Teacher-forcing input: word id 1 followed by the target shifted right
        # by one (id 1 is '_GO' in the vocabulary order used by build_vocab).
        go_column = torch.ones([batch_size, 1], dtype=torch.long, device=answer_text.device)
        responses_id = torch.cat((go_column, answer_text[:, :decoder_len - 1]), 1)

        # encode central graph
        local_entity_emb = self.CentralEncoder(batch_size, max_local_entity, max_fact, query_text, local_entity, q2e_adj_mat, kb_adj_mat, kb_fact_rel, query_mask)
        local_entity_emb = local_entity_emb.to(device=device)

        # encode text
        text_encoder_input = self.WordEmbedding(query_text)
        initial_state = use_cuda(torch.zeros(self.layers, batch_size, self.units))
        text_encoder_output, text_encoder_state = self.text_encoder(text_encoder_input, initial_state)

        # encode outer graph
        one_two_embed = self.OuterEncoder(batch_size, one_two_triples_id, one_two_triple_num)
        one_two_embed = one_two_embed.to(device=device)

        # decoder input for training
        decoder_input = self.WordEmbedding(responses_id)

        # attention keys and values
        c_attention_keys = self.attn_c_linear(text_encoder_output)
        c_attention_values = text_encoder_output
        ce_attention_keys, ce_attention_values = torch.split(self.attn_ce_linear(local_entity_emb), [self.units, self.units], 2)
        co_attention_keys, co_attention_values = torch.split(self.attn_co_linear(one_two_embed), [self.units, self.units], 2)
        only_two_entity_embed = self.EntityEmbedding(only_two_entity)
        # The value half for two-hop entities is never used by attention().
        ct_attention_keys, _ = torch.split(self.attn_ct_linear(only_two_entity_embed), [self.units, self.units], 2)
        attention_inputs = (c_attention_keys, c_attention_values, ce_attention_keys, ce_attention_values,
                            co_attention_keys, co_attention_values, ct_attention_keys)

        # entity masks from the true (unpadded) lengths
        local_entity_mask = self._length_mask(local_entity_length, batch_size, local_entity.shape[1])
        only_two_entity_mask = self._length_mask(only_two_entity_length, batch_size, only_two_entity.shape[1])

        decoder_state = text_encoder_state
        context = use_cuda(torch.zeros([batch_size, self.units]))
        ct_steps = []

        if self.is_inference:
            word_steps, selector_steps = [], []
            # start every sequence from word id 1
            decoder_input_t = self.WordEmbedding(use_cuda(torch.ones([batch_size], dtype=torch.long)))
            for t in range(decoder_len):
                decoder_input_t = torch.cat((decoder_input_t, context), 1).unsqueeze(1)
                decoder_output_t, decoder_state = self.decoder(decoder_input_t, decoder_state)
                context, ce_alignments_t, _, ct_alignments_t = self.attention(
                    *attention_inputs, decoder_output_t.squeeze(1), local_entity_mask, only_two_entity_mask)
                ct_steps.append(ct_alignments_t)

                decoder_input_t, word_index_t, selector_t = self.inference(
                    context.unsqueeze(1), ce_alignments_t, ct_alignments_t, word2id, local_entity, only_two_entity, id2entity)
                word_steps.append(word_index_t)
                selector_steps.append(selector_t)

            if self.is_select:
                self.sort(id2entity, self._stack_steps(ct_steps), only_two_entity)
            word_index = self._stack_steps(word_steps)
            selector = self._stack_steps(selector_steps)
            return word_index.cpu().numpy().tolist(), selector.cpu().numpy().tolist()

        # training: teacher forcing, one decoder step at a time
        context_steps, ce_steps = [], []
        for t in range(decoder_len):
            decoder_input_t = torch.cat((decoder_input[:, t, :], context), 1).unsqueeze(1)
            decoder_output_t, decoder_state = self.decoder(decoder_input_t, decoder_state)
            context, ce_alignments_t, _, ct_alignments_t = self.attention(
                *attention_inputs, decoder_output_t.squeeze(1), local_entity_mask, only_two_entity_mask)
            # The attention context, not the raw GRU output, is what gets scored.
            context_steps.append(context)
            ce_steps.append(ce_alignments_t)
            ct_steps.append(ct_alignments_t)
        decoder_output = self._stack_steps(context_steps)
        ce_alignments = self._stack_steps(ce_steps)
        ct_alignments = self._stack_steps(ct_steps)

        decoder_mask = self._length_mask(responses_length, batch_size, decoder_len)

        # Vectorised one-hot targets (the element-wise Python loop forced a
        # device sync for every batch/step pair).
        one_hot_entities_local = self._one_hot_targets(match_entity_one_hop[:, :decoder_len], max_local_entity)
        use_entities_local = torch.sum(one_hot_entities_local, [2])
        one_hot_entities_only_two = self._one_hot_targets(match_entity_only_two[:, :decoder_len], max_only_two_entity)
        use_entities_only_two = torch.sum(one_hot_entities_only_two, [2])

        decoder_loss, _, sentence_ppx, sentence_ppx_word, sentence_ppx_local, sentence_ppx_only_two, \
            word_neg_num, local_neg_num, only_two_neg_num = self.total_loss(decoder_output, responses_target, decoder_mask, \
            ce_alignments, ct_alignments, use_entities_local, one_hot_entities_local, use_entities_only_two, one_hot_entities_only_two)

        if self.is_select:
            self.sort(id2entity, ct_alignments, only_two_entity)

        return decoder_loss, sentence_ppx, sentence_ppx_word, sentence_ppx_local, sentence_ppx_only_two, word_neg_num, local_neg_num, only_two_neg_num

    def sort(self, id2entity, ct_alignments, only_two_entity):
        """Append the two-hop entities of each example, ranked by attention, to ``selected_concept.txt``.

        Entities are ordered by ascending attention summed over decoding
        steps. Entries whose id equals 1 are skipped (presumably padding;
        confirm against the entity vocabulary).

        Args:
            id2entity: dict from entity id to entity string.
            ct_alignments: tensor ``(batch, steps, num_two_hop)``.
            only_two_entity: long tensor ``(batch, num_two_hop)`` of entity ids.
        """
        only_two_score = torch.sum(ct_alignments, 1)
        _, sort_local_index = only_two_score.sort(1)
        sort_global_index = torch.gather(only_two_entity, 1, sort_local_index).cpu().numpy().tolist()

        # Resolve every name before touching the file so a bad id cannot leave
        # a partially written batch behind.
        sort_str = [[id2entity[idx] for idx in row if idx != 1] for row in sort_global_index]

        # `with` closes the handle even if a write fails.
        with open('selected_concept.txt', 'a') as sort_f:
            for line in sort_str:
                sort_f.write(str(line) + '\n')

    def inference(self, decoder_output_t, ce_alignments_t, ct_alignments_t, word2id, local_entity, only_two_entity, id2entity):
        """Greedily pick the next token for every example.

        For each example the best vocabulary word, best local entity and best
        two-hop entity are scored as ``selector_prob ** 2 * best_candidate_prob``
        and the highest-scoring source wins.

        Args:
            decoder_output_t: tensor ``(batch, 1, units)``.
            ce_alignments_t: attention over local entities, ``(batch, num_local)``.
            ct_alignments_t: attention over two-hop entities, ``(batch, num_two_hop)``.
            word2id: dict from word to id; must contain ``'_UNK'``.
            local_entity: long tensor ``(batch, num_local)`` of entity ids.
            only_two_entity: long tensor ``(batch, num_two_hop)`` of entity ids.
            id2entity: dict from entity id to entity string.

        Returns:
            ``(decoder_input_t, word_index_final_t, selector)``: the embedding
            of the chosen words, their ids, and the chosen source per example
            (0 = word, 1 = local entity, 2 = two-hop entity).
        """
        batch_size = decoder_output_t.shape[0]
        state = decoder_output_t.squeeze(1)

        logits = self.logits_linear(state)  # batch * num_symbols
        selector = self.softmax_d1(self.selector_linear(state))

        word_prob, word_t = torch.max(selector[:, 0].unsqueeze(1) * self.softmax_d1(logits), dim = 1)
        local_entity_prob, local_entity_l_index_t = torch.max(selector[:, 1].unsqueeze(1) * ce_alignments_t, dim = 1)
        only_two_entity_prob, only_two_entity_l_index_t = torch.max(selector[:, 2].unsqueeze(1) * ct_alignments_t, dim = 1)

        # Built out of place so the softmax output above is not mutated.
        source_score = torch.stack((selector[:, 0] * word_prob,
                                    selector[:, 1] * local_entity_prob,
                                    selector[:, 2] * only_two_entity_prob), 1)
        selector = torch.argmax(source_score, dim = 1)

        local_entity_l_index_t = local_entity_l_index_t.cpu().numpy().tolist()
        only_two_entity_l_index_t = only_two_entity_l_index_t.cpu().numpy().tolist()
        word_t = word_t.cpu().numpy().tolist()
        # One host transfer instead of one per example.
        chosen_source = selector.cpu().numpy().tolist()

        def entity_word_id(entity_id):
            # Entities without a vocabulary entry are emitted as '_UNK'.
            text = id2entity[entity_id]
            if text not in word2id:
                text = '_UNK'
            return word2id[text]

        word_index_final_t = []
        for i in range(batch_size):
            if chosen_source[i] == 0:
                word_index_final_t.append(word_t[i])
            elif chosen_source[i] == 1:
                word_index_final_t.append(entity_word_id(int(local_entity[i][local_entity_l_index_t[i]])))
            else:
                word_index_final_t.append(entity_word_id(int(only_two_entity[i][only_two_entity_l_index_t[i]])))

        word_index_final_t = use_cuda(torch.LongTensor(word_index_final_t))
        decoder_input_t = self.WordEmbedding(word_index_final_t)

        return decoder_input_t, word_index_final_t, selector

    def total_loss(self, decoder_output, responses_target, decoder_mask, ce_alignments, ct_alignments, use_entities_local, \
            entity_targets_local, use_entities_only_two, entity_targets_only_two):
        """Compute the training loss and per-sentence negative log-likelihoods.

        Args:
            decoder_output: ``(batch, steps, units)``.
            responses_target: long ``(batch, steps)`` target word ids.
            decoder_mask: ``(batch, steps)``, 1 for real target positions.
            ce_alignments / ct_alignments: attention over local / two-hop
                entities, ``(batch, steps, num_entities)``.
            use_entities_local / use_entities_only_two: ``(batch, steps)``, 1
                where the target token is a local / two-hop entity.
            entity_targets_local / entity_targets_only_two: one-hot targets
                matching the alignment shapes.

        Returns:
            A 9-tuple ``(loss, 0, sentence_ppx, sentence_ppx_word,
            sentence_ppx_local, sentence_ppx_only_two, word_neg_num,
            local_neg_num, only_two_neg_num)``. ``loss`` is generation loss plus
            selector loss divided by the number of unmasked positions. The
            ``sentence_ppx*`` values are per-example mean negative
            log-probabilities. The ``*_neg_num`` values count examples with no
            token of the respective kind.
        """
        batch_size = decoder_output.shape[0]

        # .float() keeps tensors on their current device; the previous
        # .type("torch.FloatTensor") copied to the CPU and back on every call.
        use_word = 1 - use_entities_local - use_entities_only_two
        decoder_mask_f = decoder_mask.float()

        local_masks = decoder_mask_f.reshape([-1])
        local_masks_word = use_word.reshape([-1]).float() * local_masks
        local_masks_local = use_entities_local.reshape([-1]).float()
        local_masks_only_two = use_entities_only_two.reshape([-1]).float()

        logits = self.logits_linear(decoder_output)  # batch * decoder_len * num_symbols
        word_prob = torch.gather(self.softmax_d2(logits), 2, responses_target.unsqueeze(2)).squeeze(2)

        # batch * decoder_len each
        selector_word, selector_local, selector_only_two = self.softmax_d2(self.selector_linear(decoder_output)).unbind(2)

        entity_prob_local = torch.sum(ce_alignments * entity_targets_local, [2])
        entity_prob_only_two = torch.sum(ct_alignments * entity_targets_only_two, [2])

        ppx_word = word_prob * use_word
        ppx_local = entity_prob_local * use_entities_local
        ppx_only_two = entity_prob_only_two * use_entities_only_two
        ppx_prob = ppx_word + ppx_local + ppx_only_two

        final_prob = word_prob * selector_word * use_word + entity_prob_local * selector_local * \
            use_entities_local + entity_prob_only_two * selector_only_two * use_entities_only_two

        def masked_nll(prob, mask):
            # 1e-12 guards log(0)
            return - torch.log(1e-12 + prob).reshape([-1]) * mask

        final_loss = torch.sum(masked_nll(final_prob, local_masks))

        sentence_ppx = torch.sum(masked_nll(ppx_prob, local_masks).reshape([batch_size, -1]), 1)
        sentence_ppx_word = torch.sum(masked_nll(ppx_word, local_masks_word).reshape([batch_size, -1]), 1)
        sentence_ppx_local = torch.sum(masked_nll(ppx_local, local_masks_local).reshape([batch_size, -1]), 1)
        sentence_ppx_only_two = torch.sum(masked_nll(ppx_only_two, local_masks_only_two).reshape([batch_size, -1]), 1)

        selector_loss = torch.sum(masked_nll(selector_local * use_entities_local + selector_only_two * use_entities_only_two + \
            selector_word * use_word, local_masks))

        loss = final_loss + selector_loss
        total_size = torch.sum(local_masks) + 1e-12

        sum_word = torch.sum((use_word * decoder_mask_f).float(), 1)
        sum_local = torch.sum(use_entities_local.float(), 1)
        sum_only_two = torch.sum(use_entities_only_two.float(), 1)

        word_neg_mask = (sum_word == 0).float()
        local_neg_mask = (sum_local == 0).float()
        only_two_neg_mask = (sum_only_two == 0).float()

        word_neg_num = torch.sum(word_neg_mask)
        local_neg_num = torch.sum(local_neg_mask)
        only_two_neg_num = torch.sum(only_two_neg_mask)

        # Avoid 0/0 for examples without any token of a kind.
        sum_word = sum_word + word_neg_mask
        sum_local = sum_local + local_neg_mask
        sum_only_two = sum_only_two + only_two_neg_mask

        return loss / total_size, 0, sentence_ppx / torch.sum(decoder_mask_f, 1), \
            sentence_ppx_word / sum_word, sentence_ppx_local / sum_local, sentence_ppx_only_two / sum_only_two, word_neg_num, \
            local_neg_num, only_two_neg_num

    def attention(self, c_attention_keys, c_attention_values, ce_attention_keys, ce_attention_values, co_attention_keys, \
            co_attention_values, ct_attention_keys, decoder_state, local_entity_mask, only_two_entity_mask):
        """Attend over text, local entities, outer triple groups and two-hop entities.

        Args:
            c_* / ce_* / co_*: keys and values ``(batch, n, units)`` for the text,
                local-entity and outer-graph memories.
            ct_attention_keys: keys ``(batch, num_two_hop, units)``.
            decoder_state: ``(batch, units)`` query vector.
            local_entity_mask / only_two_entity_mask: ``(batch, n)`` masks that
                zero the alignments of padded entities (applied after the
                softmax, so masked rows are not renormalised).

        Returns:
            ``(context, ce_alignments, co_alignments, ct_alignments)`` where
            ``context`` is ``(batch, units)``, computed from the decoder state
            and the text, local-entity and outer-graph contexts.
        """
        query = decoder_state.reshape([-1, 1, self.units])

        c_alignments = self.softmax_d1(torch.sum(c_attention_keys * query, 2))
        ce_alignments = self.softmax_d1(torch.sum(ce_attention_keys * query, 2))
        co_alignments = self.softmax_d1(torch.sum(co_attention_keys * query, 2))
        ct_alignments = self.softmax_d1(torch.sum(ct_attention_keys * query, 2))

        # .float() stays on the device; this runs once per decoding step.
        ce_alignments = ce_alignments * local_entity_mask.float()
        ct_alignments = ct_alignments * only_two_entity_mask.float()

        c_context = torch.sum(c_alignments.unsqueeze(2) * c_attention_values, 1)
        ce_context = torch.sum(ce_alignments.unsqueeze(2) * ce_attention_values, 1)
        co_context = torch.sum(co_alignments.unsqueeze(2) * co_attention_values, 1)

        context = self.context_linear(torch.cat((decoder_state, c_context, ce_context, co_context), 1))

        return context, ce_alignments, co_alignments, ct_alignments


# Data preparation

In [None]:
def prepare_data(config):
    """Load the knowledge-base resource file and the train / test dialogue sets.

    Reads `resource.txt`, `trainset3.txt` (only when `config.is_train` is
    truthy) and `testset.txt` from `config.data_dir`. The first line of
    `resource.txt` is one JSON object; the data files hold one JSON object
    per line. All files are read as UTF-8, and blank lines in the data files
    are skipped instead of crashing the JSON parser.

    Side effects:
        Rebinds the module-level `csk_triples`, `csk_entities`, `kb_dict`,
        `dict_csk_entities` and `dict_csk_triples` from the resource file.

    Args:
        config: object exposing `data_dir` and `is_train`.

    Returns:
        Tuple `(raw_vocab, data_train, data_test)`; `data_train` is an empty
        list when `config.is_train` is falsy.
    """
    global csk_entities, csk_triples, kb_dict, dict_csk_entities, dict_csk_triples

    with open('%s/resource.txt' % config.data_dir, encoding='utf-8') as f:
        d = json.loads(f.readline())

    csk_triples = d['csk_triples']
    csk_entities = d['csk_entities']
    raw_vocab = d['vocab_dict']
    kb_dict = d['dict_csk']
    dict_csk_entities = d['dict_csk_entities']
    dict_csk_triples = d['dict_csk_triples']

    data_train, data_test = [], []

    if config.is_train:
        with open('%s/trainset3.txt' % config.data_dir, encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx % 100000 == 0:
                    print('read train file line %d' % idx)
                # A trailing newline or stray blank line is not a sample.
                if line.strip():
                    data_train.append(json.loads(line))

    with open('%s/testset.txt' % config.data_dir, encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data_test.append(json.loads(line))

    return raw_vocab, data_train, data_test

def build_vocab(path, raw_vocab, config, trans='transE'):
    """Build word / entity vocabularies and load their pretrained embeddings.

    Files read from `path` (all as UTF-8): `entity.txt`, `relation.txt`,
    `glove.840B.300d.txt`, `entity_<trans>.txt` and `relation_<trans>.txt`.

    Only GloVe rows whose word is in the final vocabulary are kept while
    scanning the file; the rest of the rows are never looked up, so holding
    them would only cost memory.

    Args:
        path: directory containing the files listed above.
        raw_vocab: mapping word -> count; words are ranked by descending count.
        config: object exposing `symbols` (maximum vocabulary size, special
            tokens included) and `embed_units` (size of the zero vector used
            for words without a GloVe row).
        trans: suffix selecting the entity / relation embedding files.

    Returns:
        Tuple `(word2id, entity2id, vocab_list, embed, entity_list,
        entity_embed, relation_list, relation_embed, entity_relation_embed)`.
        The three `*_embed` values and `embed` are float32 numpy arrays;
        `entity2id` indexes entities first and relations after them.
    """
    print("Creating word vocabulary...")
    vocab_list = ['_PAD', '_GO', '_EOS', '_UNK'] + sorted(raw_vocab, key=raw_vocab.get, reverse=True)
    # Slicing is a no-op when the list is already short enough.
    vocab_list = vocab_list[:config.symbols]

    print("Creating entity vocabulary...")
    entity_list = ['_NONE', '_PAD_H', '_PAD_R', '_PAD_T', '_NAF_H', '_NAF_R', '_NAF_T']
    with open('%s/entity.txt' % path, encoding='utf-8') as f:
        entity_list.extend(line.strip() for line in f)

    print("Creating relation vocabulary...")
    with open('%s/relation.txt' % path, encoding='utf-8') as f:
        relation_list = [line.strip() for line in f]

    print("Loading word vectors...")
    wanted = set(vocab_list)
    vectors = {}
    with open('%s/glove.840B.300d.txt' % path, encoding='utf-8') as f:
        for i, line in enumerate(f):
            if i % 100000 == 0:
                print("    processing line %d" % i)
            s = line.strip()
            sep = s.find(' ')
            word = s[:sep]
            # Keep only rows that the lookup below can actually use.
            if word in wanted:
                vectors[word] = s[sep + 1:]

    embed = []
    for word in vocab_list:
        if word in vectors:
            vector = vectors[word].split()
        else:
            # Special tokens and out-of-GloVe words start from a zero vector.
            vector = np.zeros((config.embed_units), dtype=np.float32)
        embed.append(vector)
    # The string components are parsed to float32 by this conversion.
    embed = np.array(embed, dtype=np.float32)

    print("Loading entity vectors...")
    with open('%s/entity_%s.txt' % (path, trans), encoding='utf-8') as f:
        entity_embed = [line.strip().split('\t') for line in f]

    print("Loading relation vectors...")
    with open('%s/relation_%s.txt' % (path, trans), encoding='utf-8') as f:
        relation_embed = [line.strip().split('\t') for line in f]

    # Entities first, relations after: the same order used for entity2id below.
    entity_relation_embed = np.array(entity_embed + relation_embed, dtype=np.float32)
    entity_embed = np.array(entity_embed, dtype=np.float32)
    relation_embed = np.array(relation_embed, dtype=np.float32)

    word2id = {word: index for index, word in enumerate(vocab_list)}
    entity2id = dict()
    for entity in entity_list + relation_list:
        # len() rather than enumerate(): a repeated name keeps a dense id space.
        entity2id[entity] = len(entity2id)

    return word2id, entity2id, vocab_list, embed, entity_list, entity_embed, relation_list, relation_embed, entity_relation_embed

def gen_batched_data(data, config, word2id, entity2id):
    """Convert a list of raw dialogue samples into one padded batch dictionary.

    Each sample is a dict; the keys read here are 'post', 'response',
    'post_triples', 'all_triples_one_hop', 'all_entities_one_hop',
    'match_response_index_one_hop', 'only_two', 'match_response_index_only_two'
    and 'one_two_triple'. Entity and triple indices in a sample are resolved
    through the module-level `csk_entities` / `csk_triples` lists.

    Relies on `padding`, `padding_triple_id` and `build_kb_adj_mat`, which are
    defined elsewhere in the notebook.

    Args:
        data: non-empty list of sample dicts (the `max()` calls below raise
            ValueError on an empty list).
        config: only `config.fact_dropout` is read, and it is forwarded to
            `build_kb_adj_mat`.
        word2id: mapping word -> id; must contain '_UNK'.
        entity2id: mapping entity / relation name -> id.

    Returns:
        Dict with keys 'query_text', 'answer_text', 'local_entity',
        'responses_length', 'q2e_adj_mat', 'kb_adj_mat', 'kb_fact_rel',
        'match_entity_one_hop', 'only_two_entity', 'match_entity_only_two',
        'one_two_triples_id', 'word2id', 'entity2id', 'local_entity_length'
        and 'only_two_entity_length'.
    """
    global csk_entities, csk_triples, kb_dict, dict_csk_entities, dict_csk_triples

    # Batch-wide padded sizes. Posts and responses get one slot more than the
    # longest sequence in the batch.
    encoder_len = max([len(item['post']) for item in data])+1

    decoder_len = max([len(item['response']) for item in data])+1
    triple_num = max([len(item['all_triples_one_hop']) for item in data])
    # NOTE(review): this adds the largest *value* in post_triples to the
    # one-hop entity count, not the number of post entities -- confirm that
    # this always leaves room for the post entities added to local_entity below.
    entity_len = max([len(item['all_entities_one_hop']) + max(item['post_triples']) for item in data])
    only_two_entity_len = max([len(item['only_two']) for item in data])
    triple_num_one_two = max([len(item['one_two_triple']) for item in data])
    triple_len_one_two = max([len(tri) for item in data for tri in item['one_two_triple']])
    # Word-id matrices start out filled with 0.
    posts_id = np.full((len(data), encoder_len), 0, dtype=int)
    responses_id = np.full((len(data), decoder_len), 0, dtype=int)
    responses_length = []
    # posts_length = []
    local_entity_length = []
    only_two_entity_length = []
    local_entity = []
    only_two_entity = []
    # Relation ids per fact; slots without a fact keep the fill value 2.
    kb_fact_rels = np.full((len(data), triple_num), 2, dtype=int)
    # One (entity->fact, fact->entity) pair of index/value arrays per sample.
    kb_adj_mats = np.empty(len(data), dtype=object)
    q2e_adj_mats = np.full((len(data), entity_len), 0, dtype=int)
    # -1 marks response positions that match no entity.
    match_entity_one_hop = np.full((len(data), decoder_len), -1, dtype=int)
    match_entity_only_two = np.full((len(data), decoder_len), -1, dtype=int)
    one_two_triples_id = []
    # Collected per sample but not part of the returned dictionary.
    g2l_only_two_list = []
    # o2t_entity_index_list = []

    # Row index of the current sample inside the batch arrays.
    next_id = 0
    for item in data:
        # posts: map each (padded) token to its id, unknown tokens to '_UNK'
        for i, post_word in enumerate(padding(item['post'], encoder_len)):
            if post_word in word2id:
                posts_id[next_id, i] = word2id[post_word]

            else:
                posts_id[next_id, i] = word2id['_UNK']

        # responses: same mapping as for posts
        for i, response_word in enumerate(padding(item['response'], decoder_len)):
            if response_word in word2id:
                responses_id[next_id, i] = word2id[response_word]

            else:
                responses_id[next_id, i] = word2id['_UNK']

        # responses_length: original length plus one
        responses_length.append(len(item['response']) + 1)

        # local_entity: first the post tokens flagged in post_triples (non-zero)
        # that are known entities, each entity id kept once
        local_entity_tmp = []
        for i in range(len(item['post_triples'])):
            if item['post_triples'][i] == 0:
                continue
            elif item['post'][i] not in entity2id:
                continue
            elif entity2id[item['post'][i]] in local_entity_tmp:
                continue
            else:
                local_entity_tmp.append(entity2id[item['post'][i]])

        # ... then the one-hop entities, again skipping unknown and repeated ones
        for entity_index in item['all_entities_one_hop']:
            if csk_entities[entity_index] not in entity2id:
                continue
            if entity2id[csk_entities[entity_index]] in local_entity_tmp:
                continue
            else:
                local_entity_tmp.append(entity2id[csk_entities[entity_index]])
        # Remember the unpadded length, then pad with entity id 1 up to entity_len.
        local_entity_len_tmp = len(local_entity_tmp)
        local_entity_tmp += [1] * (entity_len - len(local_entity_tmp))
        local_entity.append(local_entity_tmp)

        # kb_adj_mat and kb_fact_rel
        # g2l maps a global entity id to its position in local_entity_tmp. It is
        # built over the padded list, so the pad id 1 ends up mapped to the
        # last padded position.
        g2l = dict()
        for i in range(len(local_entity_tmp)):
            g2l[local_entity_tmp[i]] = i

        entity2fact_e, entity2fact_f = [], []
        fact2entity_f, fact2entity_e = [], []

        # Index of the next fact that survives the filters below.
        tmp_count = 0
        for i in range(len(item['all_triples_one_hop'])):
            # Whitespace-split the triple string; the last character of the
            # first two fields is dropped (presumably a trailing comma).
            sbj = csk_triples[item['all_triples_one_hop'][i]].split()[0][:-1]
            rel = csk_triples[item['all_triples_one_hop'][i]].split()[1][:-1]
            obj = csk_triples[item['all_triples_one_hop'][i]].split()[2]

            # Keep a fact only when both ends are known entities of this sample.
            if (sbj not in entity2id) or (obj not in entity2id):
                continue
            if (entity2id[sbj] not in g2l) or (entity2id[obj] not in g2l):
                continue

            entity2fact_e += [g2l[entity2id[sbj]]]
            entity2fact_f += [tmp_count]
            fact2entity_f += [tmp_count]
            fact2entity_e += [g2l[entity2id[obj]]]
            # NOTE(review): `rel` is looked up without a membership check --
            # confirm every relation name is present in entity2id.
            kb_fact_rels[next_id, tmp_count] = entity2id[rel]
            tmp_count += 1

        # Two (index, index, value) triples: subject entity -> fact and fact -> object entity.
        kb_adj_mats[next_id] = (np.array(entity2fact_f, dtype=int), np.array(entity2fact_e, dtype=int), np.array([1.0] * len(entity2fact_f))), (np.array(fact2entity_e, dtype=int), np.array(fact2entity_f, dtype=int), np.array([1.0] * len(fact2entity_e)))

        # q2e_adj_mat: mark the local slots of the entities that occur in the post
        for i in range(len(item['post_triples'])):
            if item['post_triples'][i] == 0:
                continue
            elif item['post'][i] not in entity2id:
                continue
            else:
                q2e_adj_mats[next_id, g2l[entity2id[item['post'][i]]]] = 1

        # match_entity_one_hop: local index of the entity matched at each
        # response position; -1 entries and unknown entities are left at -1
        for i in range(len(item['match_response_index_one_hop'])):
            if item['match_response_index_one_hop'][i] == -1:
                continue
            if csk_entities[item['match_response_index_one_hop'][i]] not in entity2id:
                continue
            if entity2id[csk_entities[item['match_response_index_one_hop'][i]]] not in g2l:
                continue
            else:
                match_entity_one_hop[next_id, i] = g2l[entity2id[csk_entities[item['match_response_index_one_hop'][i]]]]

        # only_two_entity: known, de-duplicated entity ids from 'only_two',
        # padded with entity id 1 up to only_two_entity_len
        only_two_entity_tmp = []
        for entity_index in item['only_two']:
            if csk_entities[entity_index] not in entity2id:
                continue
            if entity2id[csk_entities[entity_index]] in only_two_entity_tmp:
                continue
            else:
                only_two_entity_tmp.append(entity2id[csk_entities[entity_index]])
        only_two_entity_len_tmp = len(only_two_entity_tmp)
        only_two_entity_tmp += [1] * (only_two_entity_len - len(only_two_entity_tmp))
        only_two_entity.append(only_two_entity_tmp)

        # match_entity_two_hop
        # Global id -> position in the padded only_two_entity_tmp list.
        g2l_only_two = dict()
        for i in range(len(only_two_entity_tmp)):
            g2l_only_two[only_two_entity_tmp[i]] = i

        for i in range(len(item['match_response_index_only_two'])):
            if item['match_response_index_only_two'][i] == -1:
                continue
            if csk_entities[item['match_response_index_only_two'][i]] not in entity2id:
                continue
            else:
                # NOTE(review): unlike the one-hop case there is no
                # `in g2l_only_two` guard, so a matched entity that is absent
                # from 'only_two' raises KeyError -- confirm the data guarantees it.
                match_entity_only_two[next_id, i] = g2l_only_two[entity2id[csk_entities[item['match_response_index_only_two'][i]]]]

        # one_two_triple: split every referenced triple string on ', ' and let
        # padding_triple_id turn the result into padded ids
        one_two_triples_id.append(padding_triple_id(entity2id, [[csk_triples[x].split(', ') for x in triple] for triple in item['one_two_triple']], triple_num_one_two, triple_len_one_two))

        # g2l_only_two: kept per sample (currently unused by the returned batch)
        g2l_only_two_list.append(g2l_only_two)

        # local_entity_length: number of real (unpadded) local entities
        local_entity_length.append(local_entity_len_tmp)

        # only_two_entity_length: number of real (unpadded) 'only_two' entities
        only_two_entity_length.append(only_two_entity_len_tmp)

        next_id += 1

    # Assemble the batch; the two vocabularies are passed through unchanged.
    batched_data = {'query_text': np.array(posts_id),
            'answer_text': np.array(responses_id),
            'local_entity': np.array(local_entity),
            'responses_length': responses_length,
            'q2e_adj_mat': np.array(q2e_adj_mats),
            'kb_adj_mat': build_kb_adj_mat(kb_adj_mats, config.fact_dropout),
            'kb_fact_rel': np.array(kb_fact_rels),
            'match_entity_one_hop': np.array(match_entity_one_hop),
            'only_two_entity': np.array(only_two_entity),
            'match_entity_only_two': np.array(match_entity_only_two),
            'one_two_triples_id': np.array(one_two_triples_id),
            'word2id': word2id,
            'entity2id': entity2id,
            'local_entity_length': local_entity_length,
            'only_two_entity_length': only_two_entity_length}

    return batched_data

# Train / Test

In [None]:
#coding:utf-8
# Module-level knowledge-base tables. They start out empty here and are
# rebound by prepare_data() from the contents of resource.txt.
csk_triples = []
csk_entities = []
kb_dict = []
dict_csk_entities = {}
dict_csk_triples = {}

class Config():
    """Run settings loaded from a YAML file and exposed as instance attributes.

    The file must be a mapping containing every key listed in `_KEYS`; a
    missing key raises KeyError.
    """

    # Location used when the path given to the constructor does not exist,
    # so existing notebook runs keep reading the Drive copy.
    DEFAULT_CONFIG_PATH = "/content/drive/MyDrive/Colab Notebooks/ConceptFlow-master/config.yml"

    # Keys copied from the YAML mapping onto the instance, in this order
    # (the order is what list_all_member() prints).
    _KEYS = (
        'is_train', 'test_model_path', 'embed_units', 'symbols', 'units',
        'layers', 'batch_size', 'data_dir', 'num_epoch', 'lr_rate',
        'lstm_dropout', 'linear_dropout', 'max_gradient_norm', 'trans_units',
        'gnn_layers', 'fact_dropout', 'fact_scale', 'pagerank_lambda',
        'result_dir_name',
    )

    def __init__(self, path):
        """Remember `path` as `config_path` and load the settings."""
        self.config_path = path
        self._get_config()

    def _get_config(self):
        """Read the YAML file and copy every key in `_KEYS` onto the instance.

        `self.config_path` is used when it names an existing file; otherwise
        the loader falls back to `DEFAULT_CONFIG_PATH`.
        """
        source = self.DEFAULT_CONFIG_PATH
        if self.config_path and os.path.isfile(self.config_path):
            source = self.config_path

        with open(source, "r") as setting:
            config = yaml.load(setting, Loader=yaml.FullLoader)

        for key in self._KEYS:
            setattr(self, key, config[key])

    def list_all_member(self):
        """Print every instance attribute as `name = value`, one per line."""
        for name, value in vars(self).items():
            print('%s = %s' % (name, value))


def run(model, data_train, config, word2id, entity2id):
    """Batch `data_train` and push it through `model` once.

    Returns the pair `(word_index, selector)` when `model.is_inference`
    equals True; otherwise the eight values produced by the model: decoder
    loss, four per-sentence perplexity terms and three negative counters.
    """
    batch = gen_batched_data(data_train, config, word2id, entity2id)

    # Training / scoring path: eight outputs.
    if not (model.is_inference == True):
        (loss, ppx, ppx_word, ppx_local, ppx_only_two,
         n_word_neg, n_local_neg, n_only_two_neg) = model(batch)
        return (loss, ppx, ppx_word, ppx_local, ppx_only_two,
                n_word_neg, n_local_neg, n_only_two_neg)

    # Inference path: two outputs.
    word_index, selector = model(batch)
    return word_index, selector

def train(config, model, data_train, data_test, word2id, entity2id, model_optimizer):
    """Train `model` for `config.num_epoch` epochs, checkpointing and evaluating after each.

    Per epoch: iterate over `data_train` in consecutive slices of
    `config.batch_size`, back-propagate the decoder loss with gradient-norm
    clipping at `config.max_gradient_norm`, print the epoch perplexities,
    save the state dict to `<result_dir_name>/_epoch_<n>.pkl`, call
    `evaluate` on `data_test` and append its four results to
    `<result_dir_name>/result.txt`.

    NOTE(review): a trailing partial batch is never processed, yet the
    perplexities are averaged over `len(data_train)`. Left unchanged so the
    reported numbers stay comparable with earlier runs.

    Args:
        config: settings object (`num_epoch`, `batch_size`,
            `max_gradient_norm`, `result_dir_name` are read here).
        model: module being trained.
        data_train, data_test: lists of samples.
        word2id, entity2id: vocabularies forwarded to `run` / `evaluate`.
        model_optimizer: optimizer built over `model.parameters()`.
    """
    batch_size = config.batch_size

    for epoch in range(config.num_epoch):
        print ("epoch: ", epoch)
        sentence_ppx_loss = 0
        sentence_ppx_word_loss = 0
        sentence_ppx_local_loss = 0
        sentence_ppx_only_two_loss = 0

        # Float accumulators created once on the target device, instead of an
        # integer tensor that is re-cast to float on every iteration.
        word_cut = torch.zeros(1, device=device)
        local_cut = torch.zeros(1, device=device)
        only_two_cut = torch.zeros(1, device=device)

        for iteration in range(len(data_train) // batch_size):
            start = iteration * batch_size
            decoder_loss, sentence_ppx, sentence_ppx_word, sentence_ppx_local, sentence_ppx_only_two, word_neg_num, \
                local_neg_num, only_two_neg_num = run(model, data_train[start:start + batch_size], config, word2id, entity2id)

            # Detached sums: the running totals must not keep the graph alive.
            sentence_ppx_loss += torch.sum(sentence_ppx).detach()
            sentence_ppx_word_loss += torch.sum(sentence_ppx_word).detach()
            sentence_ppx_local_loss += torch.sum(sentence_ppx_local).detach()
            sentence_ppx_only_two_loss += torch.sum(sentence_ppx_only_two).detach()

            word_cut += word_neg_num
            local_cut += local_neg_num
            only_two_cut += only_two_neg_num

            model_optimizer.zero_grad()
            decoder_loss.backward()
            # In-place variant; the un-suffixed clip_grad_norm is deprecated.
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_gradient_norm)
            model_optimizer.step()

            if iteration % 50 == 0:
                print ("iteration:", iteration, "Loss:", decoder_loss.detach())

        # The word / local / only_two averages exclude the samples counted in the *_cut totals.
        print ("perplexity for epoch", epoch + 1, ":", np.exp(sentence_ppx_loss.cpu() / len(data_train)), " ppx_word: ", \
            np.exp(sentence_ppx_word_loss.cpu() / (len(data_train) - int(word_cut))), " ppx_local: ", \
            np.exp(sentence_ppx_local_loss.cpu() / (len(data_train) - int(local_cut))), " ppx_only_two: ", \
            np.exp(sentence_ppx_only_two_loss.cpu() / (len(data_train) - int(only_two_cut))))

        torch.save(model.state_dict(), config.result_dir_name + '/' + '_epoch_' + str(epoch + 1) + '.pkl')
        ppx, ppx_word, ppx_local, ppx_only_two = evaluate(model, data_test, config, word2id, entity2id, epoch + 1)

        # Context manager so the line is flushed and the handle closed even on error.
        with open(config.result_dir_name + '/result.txt', 'a') as ppx_f:
            ppx_f.write("epoch " + str(epoch + 1) + " ppx: " + str(ppx) + " ppx_word: " + str(ppx_word) + " ppx_local: " + \
                str(ppx_local) + " ppx_only_two: " + str(ppx_only_two) + '\n')

def evaluate(model, data_test, config, word2id, entity2id, epoch = 0, model_path = None):
    """Score `model` on `data_test` and return four perplexities.

    The model is put in eval mode (dropout disabled) and run under
    `torch.no_grad()` for the duration of the scoring loop; its previous
    train / eval mode is restored afterwards, also on error. Unlike before,
    the function returns normally when `model_path` is given instead of
    calling `exit()`, which would terminate a notebook kernel.

    NOTE(review): a trailing partial batch is never processed, yet the
    perplexities are averaged over `len(data_test)`; with fewer samples than
    `config.batch_size` no batch runs and the final division fails.

    Args:
        model: module to score; `model.is_inference` is set to False at the end.
        data_test: list of samples, consumed in slices of `config.batch_size`.
        config: settings object (`batch_size` is read here).
        word2id, entity2id: vocabularies forwarded to `run`.
        epoch: accepted for interface compatibility; not used.
        model_path: optional checkpoint whose state dict is loaded first.

    Returns:
        Tuple `(ppx, ppx_word, ppx_local, ppx_only_two)`.
    """
    if model_path is not None:
        model.load_state_dict(torch.load(model_path))

    sentence_ppx_loss = 0
    sentence_ppx_word_loss = 0
    sentence_ppx_local_loss = 0
    sentence_ppx_only_two_loss = 0
    word_cut = use_cuda(torch.Tensor([0]))
    local_cut = use_cuda(torch.Tensor([0]))
    only_two_cut = use_cuda(torch.Tensor([0]))

    batch_size = config.batch_size
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for iteration in range(len(data_test) // batch_size):
                start = iteration * batch_size
                decoder_loss, sentence_ppx, sentence_ppx_word, sentence_ppx_local, sentence_ppx_only_two, word_neg_num, \
                    local_neg_num, only_two_neg_num = run(model, data_test[start:start + batch_size], config, word2id, entity2id)

                sentence_ppx_loss += torch.sum(sentence_ppx).detach()
                sentence_ppx_word_loss += torch.sum(sentence_ppx_word).detach()
                sentence_ppx_local_loss += torch.sum(sentence_ppx_local).detach()
                sentence_ppx_only_two_loss += torch.sum(sentence_ppx_only_two).detach()

                word_cut += word_neg_num
                local_cut += local_neg_num
                only_two_cut += only_two_neg_num

                if iteration % 50 == 0:
                    print ("iteration for evaluate:", iteration, "Loss:", decoder_loss.detach())
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    model.is_inference = False

    # The word / local / only_two averages exclude the samples counted in the *_cut totals.
    total = len(data_test)
    ppx = np.exp(sentence_ppx_loss.cpu() / total)
    ppx_word = np.exp(sentence_ppx_word_loss.cpu() / (total - int(word_cut)))
    ppx_local = np.exp(sentence_ppx_local_loss.cpu() / (total - int(local_cut)))
    ppx_only_two = np.exp(sentence_ppx_only_two_loss.cpu() / (total - int(only_two_cut)))

    print('    perplexity on test set:', ppx, ppx_word, ppx_local, ppx_only_two)
    return ppx, ppx_word, ppx_local, ppx_only_two

def main():
    """Load config and data, build the model, then train it or evaluate a checkpoint.

    The settings are appended to `<result_dir_name>/result.txt` first. When
    `config.is_train` equals False, the checkpoint at
    `config.test_model_path` is evaluated and the function returns;
    otherwise `train` is run.
    """
    config = Config('config.yml')
    config.list_all_member()
    raw_vocab, data_train, data_test = prepare_data(config)
    word2id, entity2id, vocab, embed, entity_vocab, entity_embed, relation_vocab, relation_embed, entity_relation_embed = build_vocab(config.data_dir, raw_vocab, config = config)

    model = ConceptFlow(config, embed, entity_relation_embed)
    model = model.to(device=device)
    model_optimizer = torch.optim.Adam(model.parameters(), lr = config.lr_rate)

    # Creates missing parent directories too and tolerates an existing one.
    os.makedirs(config.result_dir_name, exist_ok=True)

    # Context manager so the settings are flushed before training starts.
    with open(config.result_dir_name + '/result.txt', 'a') as ppx_f:
        for name, value in vars(config).items():
            ppx_f.write('%s = %s' % (name, value) + '\n')

    if config.is_train == False:
        evaluate(model, data_test, config, word2id, entity2id, 0, model_path = config.test_model_path)
        # Return rather than exit(): exit() would terminate a notebook kernel.
        return
    train(config, model, data_train, data_test, word2id, entity2id, model_optimizer)

In [None]:
# Entry point: runs main(), which loads the config and data, builds the model
# and then either trains it or evaluates a checkpoint.
main()

config_path = config.yml
is_train = True
test_model_path = None
embed_units = 300
symbols = 30000
units = 512
layers = 2
batch_size = 10
data_dir = /content/drive/MyDrive/Colab Notebooks/ConceptFlow(ECCF)_data
num_epoch = 20
lr_rate = 0.0001
lstm_dropout = 0.3
linear_dropout = 0.2
max_gradient_norm = 5
trans_units = 100
gnn_layers = 3
fact_dropout = 0.0
fact_scale = 1
pagerank_lambda = 0.8
result_dir_name = training_output
read train file line 0
Creating word vocabulary...
Creating entity vocabulary...
Creating relation vocabulary...
Loading word vectors...
    processing line 0
    processing line 100000
    processing line 200000
    processing line 300000
    processing line 400000
    processing line 500000
    processing line 600000
    processing line 700000
    processing line 800000
    processing line 900000
    processing line 1000000
    processing line 1100000
    processing line 1200000
    processing line 1300000
    processing line 1400000
    processing line 1500000
   

In [None]:
# Sanity check: add two tensors allocated on the selected `device`.
a = torch.ones(3, 2, device=device)  # first operand, created on `device`
b = torch.ones(3, 2, device=device)  # second operand, created on `device`
c = a + b
print(c)

NameError: name 'c' is not defined

In [None]:
# Scratch check: build the three single-element counters through use_cuda,
# the same way evaluate() initialises them. Each name gets its own tensor.
word_cut, local_cut, only_two_cut = (use_cuda(torch.Tensor([0])) for _ in range(3))

In [None]:
# Same sanity check with an explicitly constructed handle for GPU 0.
cuda0 = torch.device('cuda:0')
d = torch.ones(3, 2, device=cuda0)  # first operand, created on cuda:0
e = torch.ones(3, 2, device=cuda0)  # second operand, created on cuda:0
# Add the tensors created in this cell, not `a` / `b` left over from an earlier cell.
f = d + e
print(f)

tensor([[2., 2.],
        [2., 2.],
        [2., 2.]], device='cuda:0')
