# IRN model impleted by Pytorch

In [692]:
import torch
import torch.nn as nn
import os
from IPython import embed
import torch
import torch.optim as optim
import torch.nn as nn
import argparse
import shutil
import numpy as np
import time
from sklearn import metrics
from sklearn.model_selection import train_test_split
import re
import json

In [693]:
# Preprocess the data

def read_KB(KB_file):
    # example in KB_file: KBs.txt h \t r \t t
    entities = set()
    relations = set()
    if os.path.isfile(KB_file):
        with open(KB_file) as f:
            lines = f.readlines()
    else:
        raise Exception("!! %s is not found!!" % KB_file)

    for line in lines:
        line = line.strip().split('\t')
        entities.add(line[0])
        entities.add(line[2])
        relations.add(line[1])
    return entities, relations


def get_KB(KB_file, ent2id, rel2id):
    nwords = len(ent2id)
    nrels = len(rel2id)
    tails = np.zeros([nwords*nrels, 1], 'int32')
    KBmatrix = np.zeros([nwords * nrels, nwords], 'int32')
    Triples = []

    f = open(KB_file)
    for line in f.readlines():
        line = line.strip().split('\t')
        h = ent2id[line[0]]
        r = rel2id[line[1]]
        t = ent2id[line[2]]
        Triples.append([h, r, t])
        lenlist = tails[h*nrels+r]
        KBmatrix[h*nrels+r, lenlist] = t
        tails[h*nrels+r] += 1
    
    return np.array(Triples), KBmatrix[:, :np.max(tails)], np.max(tails)


def read_data(data_file):
    if os.path.isfile(data_file):
        with open(data_file) as f:
            lines = f.readlines()
    else:
        raise Exception("!! %s is not found!!" % data_file)

    words = set()
    data = []
    questions = []
    doc = []

    for line in lines:
        line = line.strip().split('\t')
        qlist = line[0].strip().split()
        k = line[1].find('(')
        if not k == -1:
            if line[1][k-1] == '_':
                k += (line[1][k+1:-1].find('(') + 1)
            asset = line[1][k+1:-1]
            line[1] = line[1][:k]
        else:
            asset = line[3]
        data.append([line[0], line[1], line[2], asset])

        for w in qlist:
            words.add(w)
        questions.append(qlist)

    sentence_size = max(len(i) for i in questions)

    return words, data, sentence_size


def tokenize(sent):
    '''Return the tokens of a sentence including punctuation.
    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    '''
    return [x.strip() for x in re.split('(\W+)?', sent) if x.strip()]


def MultiAcc(labels,preds,length):
    #length = path = 2 * hop + 1   (hop == path_l + cons_l + final == path_l * 2 + 1 )
    #compare path and final answer accuracy
    Acc = []

    for i in range(length):
        Acc.append(round(metrics.accuracy_score(labels[:,i],preds[:,i]),3))

    batch_size = preds.shape[0]
    correct = 0.0
    for j in range(batch_size):
        k = length - 1
        while(labels[j,k]==0):
            k -= 2
        if(labels[j,k]==preds[j,k]):
            correct += 1.0   #final answer accuracy 
    Acc.append(round( correct/batch_size ,3))
    return Acc

def InSet(labels,anset,preds):
    #get accuracy(whether in answer set or not)
    #labels does not matter
    #preds is path-list
    #labels is path-labels
    right = 0.0
    for i in range(len(anset)):
        if type(preds[i]) is np.int64:
            ans_pred = preds[i]
        else:
            ans_pred = preds[i,-1]
            '''
            k = len(labels[0]) - 1
            while(labels[i,k]==0):
                k -= 2
            ans_pred = preds[i,k]
            '''
        if ans_pred in anset[i]:
            right += 1
    return round(right/len(anset), 3)

def process_data(KB_file, data_file):
    entities, relations = read_KB(KB_file)
    words, data, sentence_size = read_data(data_file)

    word2id = {}
    ent2id = {}
    rel2id = {}

    word2id['<unk>'] = 0
    rel2id['<end>'] = 0
    ent2id['<unk>'] = 0

    for r in relations:
        # same r_id in rel2id and word2id
        if r not in rel2id.keys():
            rel2id[r] = len(rel2id)
        if r not in word2id.keys():
            word2id[r] = len(word2id)
    for e in entities:
        if e not in ent2id.keys():
            ent2id[e] = len(ent2id)
    for word in words:
        if word not in word2id.keys():
            word2id[word] = len(word2id)

    print('here are %d words in word2id(vocab)' % len(word2id))
    print('here are %d relations in rel2id(rel_vocab)' % len(rel2id))
    print('here are %d entities in ent2id(ent_vocab)' % len(ent2id))

    Triples, KBs, tails_size = get_KB(KB_file, ent2id, rel2id)

    print("The number of records or triples", len(np.nonzero(KBs)[0]))

    Q = []
    QQ = []
    A = []
    AA = []
    P = []
    PP = []
    S = []
    SS = []

    for query, answer, path, answerset in data:
        path = path.strip().split('#')[0:5]  # path = [s,r1,m,r2,t]

        query = query.strip().split()
        ls = max(0, sentence_size-len(query))
        q = [word2id[w] for w in query] + [0] * ls
        Q.append(q)
        QQ.append(query)


        a = np.zeros(len(ent2id))  # if use new ans-vocab, add 0 for 'end'
        a[ent2id[answer]] = 1
        A.append(a)
        AA.append(ent2id[answer])

        p = []
        for i in range(len(path)):
            if i % 2 == 0:
                e = ent2id[path[i]]
                p.append(e)
            else:
                r = rel2id[path[i]]
                p.append(r)

        P.append(p)
        PP.append(path)

        anset = answerset.split('/')
        anset = anset[:-1]
        ass = []
        for a in anset:
            ass.append(ent2id[a])
        S.append(ass)
        SS.append(anset)

    return np.array(Q), np.array(A), np.array(P), np.array(S), Triples, sentence_size, word2id, ent2id, rel2id

In [694]:
class IRN(nn.Module):
    def __init__(self, args):
        super(IRN, self).__init__()
        self._margin = 4  # 外边距
        self._batch_size = args["batch_size"]
        self._vocab_size = args["nwords"]
        self._rel_size = args["nrels"]
        self._ent_size = args["nents"]
        self._sentence_size = args["query_size"]
        self._embedding_size = args["edim"]
        self._path_size = args["path_size"]
        self._hops = int(args["nhop"])
        self._max_grad_norm = args["max_grad_norm"]
        self._name = "IRN"
        self._inner_epochs = args["inner_nepoch"]
        self._checkpoint_dir = args["checkpoint_dir"] + '/' + self._name
        self.build_vars()


    def forward(self, KBs, queries, answers, answers_id, paths):
        nexample = queries.shape[0]
        keys = np.repeat(np.reshape(np.arange(self._rel_size), [1, -1]),
                         nexample,
                         axis=0)
        pad = np.arange(nexample)
        ones = np.ones(nexample)
        zeros = np.zeros(nexample)
        
        loss = torch.Tensor(zeros).unsqueeze(1)
        s_index = torch.Tensor(paths[:, 0]).unsqueeze(1)
        q_emb = self.Q[torch.LongTensor(queries)]
        q = torch.sum(q_emb, dim=1)
        state = self.E[s_index.long()].squeeze(1)
        p = s_index
        for hop in range(self._hops):
            step = 2 * hop
            gate = torch.matmul(q, torch.matmul(self.R, self.Mrq).t())+torch.matmul(state, torch.matmul(self.R, self.Mrs).t())
            rel_logits = gate
            r_index = torch.argmax(rel_logits, dim=1)
            gate = torch.softmax(gate,dim=1)
            real_rel_onehot = torch.Tensor(paths[:, step + 1])
            
            predict_rel_onehot = torch.nn.functional.one_hot(r_index, num_classes=self._rel_size)
            state = state + torch.matmul(gate, torch.matmul(self.R, self.Mrs))
            critrion = nn.CrossEntropyLoss(reduce=False)
            loss += critrion(rel_logits, real_rel_onehot.long()).unsqueeze(1)
            
            q = q - torch.matmul(gate, torch.matmul(self.R, self.Mrq))
            value = torch.matmul(state, self.Mse)
            ans = torch.matmul(value, self.E.t())
            t_index = torch.argmax(ans, dim=1).float()
            r_index = r_index.float()
            t_index = r_index/(r_index+1e-15)*t_index + \
                (1-r_index/(r_index+1e-15)) * p[:, -1].float()

            p = torch.cat((p, r_index.float().view(-1, 1)), dim=1)
            p = torch.cat((p, t_index.float().view(-1, 1)), dim=1)

            real_ans_onehot = torch.Tensor(paths[:, step + 2])
            loss += critrion(ans, real_ans_onehot.long()).unsqueeze(1)
        
        loss = torch.sum(loss)
        self.E.data = self.E.data / (torch.pow(self.E.data, 2).sum(dim=1, keepdim=True))
        self.R.data = self.R.data / (torch.pow(self.R.data, 2).sum(dim=1, keepdim=True))
        self.Q.data = self.Q.data / (torch.pow(self.Q.data, 2).sum(dim=1, keepdim=True))
        return loss

    def build_vars(self):
        nil_word_slot = torch.zeros(1, self._embedding_size)
        nil_rel_slot = torch.zeros(1, self._embedding_size)
        '''
        self.E = nn.Parameter(
            torch.cat(
                (nil_word_slot,
                 nn.init.xavier_normal_(
                     torch.Tensor(self._ent_size - 1, self._embedding_size))),
                dim=0))
        self.Q = nn.Parameter(
            torch.cat((nil_word_slot,
                       nn.init.xavier_normal_(
                           torch.Tensor(self._vocab_size - 1,
                                        self._embedding_size))),
                      dim=0))
        self.R = nn.Parameter(
            torch.cat(
                (nil_rel_slot,
                 nn.init.xavier_normal_(
                     torch.Tensor(self._rel_size - 1, self._embedding_size))),
                dim=0))
        '''
        self.E = nn.Parameter(nn.init.xavier_normal_(torch.Tensor(self._ent_size, self._embedding_size)))
        self.Q = nn.Parameter(nn.init.xavier_normal_(torch.Tensor(self._vocab_size, self._embedding_size)))
        self.R = nn.Parameter(nn.init.xavier_normal_(torch.Tensor(self._rel_size, self._embedding_size)))

        self.Mrq = nn.Parameter(
            nn.init.xavier_normal_(
                torch.Tensor(self._embedding_size, self._embedding_size)))
        self.Mrs = nn.Parameter(
            nn.init.xavier_normal_(
                torch.Tensor(self._embedding_size, self._embedding_size)))
        self.Mse = nn.Parameter(
            nn.init.xavier_normal_(
                torch.Tensor(self._embedding_size, self._embedding_size)))

        self._zeros = torch.zeros(1)

    def match(self):
        Similar = torch.matmul(torch.matmul(self.R, self.Mrq), self.Q.t())
        _, idx = torch.topk(Similar, 5)
        return idx

    def batch_pretrain(self, KBs, queries, answers, answers_id, paths):
        nexample = KBs.shape[0]
        keys = np.repeat(np.reshape(np.arange(self._rel_size), [1, -1]),
                         nexample,
                         axis=0)
        pad = np.random.randint(low=0, high=self._ent_size, size=nexample)
        ones = np.ones(nexample)
        zeros = np.zeros(nexample)
    
        h = torch.Tensor(KBs[:, 0])
        r = torch.Tensor(KBs[:, 1])
        t = torch.Tensor(KBs[:, 2])
        tt = torch.Tensor(pad)
        h_emb = self.E[h.long()]
        r_emb = self.R[r.long()]
        t_emb = self.E[t.long()]
        tt_emb = self.E[tt.long()]
        l_emb = torch.matmul((h_emb + r_emb), self.Mse)
        s = (l_emb - t_emb) * (l_emb - t_emb)
        ss = (l_emb - tt_emb) * (l_emb - tt_emb)
        
        loss = self._margin + torch.sum(s, dim=1) - torch.sum(ss, dim=1)
        loss = torch.clamp(loss, min=0)
        
        loss = torch.sum(loss)
        return loss

    def batch_fit(self, KBs, queries, answers, answers_id, paths):
        nexample = queries.shape[0]
        keys = np.repeat(np.reshape(np.arange(self._rel_size), [1, -1]),
                         nexample,
                         axis=0)
        pad = np.arange(nexample)
        ones = np.ones(nexample)
        zeros = np.zeros(nexample)
        print(self.Q)
        loss = torch.Tensor(zeros).unsqueeze(1)
        s_index = torch.Tensor(paths[:, 0]).unsqueeze(1)
        q_emb = self.Q[torch.LongTensor(queries)]
        q = torch.sum(q_emb, dim=1)
        state = self.E[s_index.long()].squeeze(1)
        p = s_index
        for hop in range(self._hops):
            step = 2 * hop
            gate = torch.matmul(q, torch.matmul(self.R, self.Mrq).t())+torch.matmul(state, torch.matmul(self.R, self.Mrs).t())
            rel_logits = gate
            r_index = torch.argmax(rel_logits, dim=1)
            gate = torch.softmax(gate,dim=1)
            real_rel_onehot = torch.Tensor(paths[:, step + 1])
            
            predict_rel_onehot = torch.nn.functional.one_hot(r_index, num_classes=self._rel_size)
            state = state + torch.matmul(gate, torch.matmul(self.R, self.Mrs))
            critrion = nn.CrossEntropyLoss(reduce=False)
            loss += critrion(rel_logits, real_rel_onehot.long()).unsqueeze(1)
            
            q = q - torch.matmul(gate, torch.matmul(self.R, self.Mrq))
            value = torch.matmul(state, self.Mse)
            ans = torch.matmul(value, self.E.t())
            t_index = torch.argmax(ans, dim=1).float()
            r_index = r_index.float()
            t_index = r_index/(r_index+1e-15)*t_index + \
                (1-r_index/(r_index+1e-15)) * p[:, -1].float()

            p = torch.cat((p, r_index.float().view(-1, 1)), dim=1)
            p = torch.cat((p, t_index.float().view(-1, 1)), dim=1)

            real_ans_onehot = torch.Tensor(paths[:, step + 2])
            loss += critrion(ans, real_ans_onehot.long()).unsqueeze(1)
        
        loss = torch.sum(loss)
        
        self.E.data = self.E.data / (torch.pow(self.E.data, 2).sum(dim=1, keepdim=True))
        self.R.data = self.R.data / (torch.pow(self.R.data, 2).sum(dim=1, keepdim=True))
        self.Q.data = self.Q.data / (torch.pow(self.Q.data, 2).sum(dim=1, keepdim=True))
        
        return loss

    def predict(self, KBs, queries, paths):
        nexample = queries.shape[0]
        keys = np.repeat(np.reshape(np.arange(self._rel_size), [1, -1]),
                         nexample,
                         axis=0)
        pad = np.arange(nexample)
        ones = np.ones(nexample)
        zeros = np.zeros(nexample)

        loss = torch.Tensor(zeros).unsqueeze(1)
        s_index = torch.Tensor(paths[:, 0]).unsqueeze(1)
        q_emb = self.Q[torch.LongTensor(queries)]
        q = torch.sum(q_emb, dim=1)
        state = self.E[s_index.long()].squeeze(1)
        p = s_index
        for hop in range(self._hops):
            step = 2 * hop
            gate = torch.matmul(q, torch.matmul(self.R, self.Mrq).t())+torch.matmul(state, torch.matmul(self.R, self.Mrs).t())
            
            rel_logits = gate
            
            r_index = torch.argmax(rel_logits, dim=1)
            gate = torch.softmax(gate,dim=1)
            real_rel_onehot = torch.Tensor(paths[:, step + 1])
            
            predict_rel_onehot = torch.nn.functional.one_hot(r_index, num_classes=self._rel_size)
            state = state + torch.matmul(gate, torch.matmul(self.R, self.Mrs))
            critrion = nn.CrossEntropyLoss(reduce=False)
            
            loss += critrion(rel_logits, real_rel_onehot.long()).unsqueeze(1)
            
            q = q - torch.matmul(gate, torch.matmul(self.R, self.Mrq))
            value = torch.matmul(state, self.Mse)
            ans = torch.matmul(value, self.E.t())
            t_index = torch.argmax(ans, dim=1).float()
            r_index = r_index.float()
            t_index = r_index/(r_index+1e-15)*t_index + \
                (1-r_index/(r_index+1e-15)) * p[:, -1].float()

            p = torch.cat((p, r_index.float().view(-1, 1)), dim=1)
            p = torch.cat((p, t_index.float().view(-1, 1)), dim=1)

            real_ans_onehot = torch.Tensor(paths[:, step + 2])
            loss += critrion(ans, real_ans_onehot.long()).unsqueeze(1)
        return p

In [695]:
args = {}
args["edim"] = 50
args["nhop"] = 3
args["batch_size"] = 50
args["nepoch"] = 5000
args["inner_nepoch"] = 3
args["init_lr"] = 0.001
args["epsilon"] = 1e-8
args["max_grad_norm"] = 20
args["checkpoint_dir"] = "checkpoint"

In [696]:
KB_file = 'data/2H-kb.txt'
data_file = 'data/2H.txt'
start = time.time()
Q,A,P,S,Triples,args["query_size"], word2id, ent2id, rel2id = process_data(KB_file, data_file)
args["path_size"] = len(P[0])
args["nhop"] = args["path_size"] / 2


print ("read data cost %f seconds" %(time.time()-start))
args["nwords"] = len(word2id) 
args["nrels"] = len(rel2id) 
args["nents"] = len(ent2id)

trainQ, testQ, trainA, testA, trainP, testP, trainS, testS = train_test_split(Q, A, P, S, test_size=.1, random_state=123)
trainQ, validQ, trainA, validA, trainP, validP, trainS, validS = train_test_split(trainQ, trainA, trainP, trainS, test_size=.11, random_state=0)

n_train = trainQ.shape[0]     
n_test = testQ.shape[0]
n_val = validQ.shape[0]
print(trainQ.shape, trainA.shape,trainP.shape,trainS.shape)

# 找到答案所在的坐标
train_labels = np.argmax(trainA, axis=1)
test_labels = np.argmax(testA, axis=1)
valid_labels = np.argmax(validA, axis=1)
batches = list(zip(range(0, n_train-args["batch_size"], args["batch_size"]), range(args["batch_size"], n_train, args["batch_size"])))
pre_batches = list(zip(range(0, Triples.shape[0]-args["batch_size"], args["batch_size"]), range(args["batch_size"], Triples.shape[0], args["batch_size"])))


model = IRN(args)
optimizer = optim.Adam(model.parameters(), args["init_lr"],weight_decay=1e-5)
pre_val_preds = model.predict(Triples, validQ, validP)
pre_test_preds = model.predict(Triples, testQ, testP)

here are 542 words in word2id(vocab)
here are 14 relations in rel2id(rel_vocab)
here are 1057 entities in ent2id(ent_vocab)
The number of records or triples 1211
read data cost 0.095364 seconds
(1528, 13) (1528, 1057) (1528, 5) (1528,)


In [697]:
for t in range(args["nepoch"]):
    np.random.shuffle(batches)
    for i in range(args["inner_nepoch"]):
        np.random.shuffle(pre_batches)
        pre_total_cost = 0.0
        for s, e in pre_batches:
            pretrain_loss = model.batch_pretrain(
                Triples[s:e], trainQ[0:args["batch_size"]],
                trainA[0:args["batch_size"]],
                np.argmax(trainA[0:args["batch_size"]],
                          axis=1), trainP[0:args["batch_size"]])
            optimizer.zero_grad()
            pretrain_loss.backward()
            optimizer.step()
    total_cost = 0.0
    
    for s, e in batches:
        total_cost = model(Triples[s:e], trainQ[s:e], trainA[s:e],
                                      np.argmax(trainA[s:e], axis=1),
                                      trainP[s:e])
        optimizer.zero_grad()
        total_cost.backward()
        optimizer.step()
    if t % 1 == 0:
        
        train_preds = model.predict(Triples,trainQ,trainP)
        train_acc = MultiAcc(trainP,train_preds, model._path_size)
        train_true_acc = InSet(trainP,trainS,train_preds)

        val_preds = model.predict(Triples,validQ, validP)
        val_acc = MultiAcc(validP,val_preds,model._path_size)
        val_true_acc = InSet(validP,validS,val_preds)


        print('-----------------------')
        print('Epoch', t)
        print('Train Accuracy:', train_true_acc)
        print('Validation Accuracy:', val_true_acc)               
        print('-----------------------')


-----------------------
Epoch 0
Train Accuracy: 0.099
Validation Accuracy: 0.063
-----------------------
-----------------------
Epoch 1
Train Accuracy: 0.098
Validation Accuracy: 0.058
-----------------------
-----------------------
Epoch 2
Train Accuracy: 0.205
Validation Accuracy: 0.148
-----------------------
-----------------------
Epoch 3
Train Accuracy: 0.241
Validation Accuracy: 0.18
-----------------------
-----------------------
Epoch 4
Train Accuracy: 0.266
Validation Accuracy: 0.243
-----------------------
-----------------------
Epoch 5
Train Accuracy: 0.259
Validation Accuracy: 0.212
-----------------------
-----------------------
Epoch 6
Train Accuracy: 0.269
Validation Accuracy: 0.222
-----------------------
-----------------------
Epoch 7
Train Accuracy: 0.307
Validation Accuracy: 0.265
-----------------------
-----------------------
Epoch 8
Train Accuracy: 0.329
Validation Accuracy: 0.275
-----------------------
-----------------------
Epoch 9
Train Accuracy: 0.338
Va

-----------------------
Epoch 78
Train Accuracy: 0.931
Validation Accuracy: 0.788
-----------------------
-----------------------
Epoch 79
Train Accuracy: 0.954
Validation Accuracy: 0.82
-----------------------
-----------------------
Epoch 80
Train Accuracy: 0.963
Validation Accuracy: 0.836
-----------------------
-----------------------
Epoch 81
Train Accuracy: 0.952
Validation Accuracy: 0.847
-----------------------
-----------------------
Epoch 82
Train Accuracy: 0.948
Validation Accuracy: 0.82
-----------------------
-----------------------
Epoch 83
Train Accuracy: 0.949
Validation Accuracy: 0.857
-----------------------
-----------------------
Epoch 84
Train Accuracy: 0.953
Validation Accuracy: 0.857
-----------------------
-----------------------
Epoch 85
Train Accuracy: 0.967
Validation Accuracy: 0.862
-----------------------
-----------------------
Epoch 86
Train Accuracy: 0.965
Validation Accuracy: 0.852
-----------------------
-----------------------
Epoch 87
Train Accuracy:

KeyboardInterrupt: 