Necessary modules

In [6]:
import encoder
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import heapq
import time
import torch.optim as optim
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction
import pickle

In [7]:
# get vocab
model_dir = 'gpt_vocab'
enc = encoder.get_encoder(model_dir)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# define parameters
class HParams():
    n_vocab = 50257
    n_embed = 512
    start_token = 50256
    batch_size = 25
    device = device
 
hparams = HParams()

In [8]:
class Attn(nn.Module):
    def __init__(self, input_dim = 512, hid_dim = 50, k = 5):
        super(Attn,self).__init__()
        self.input_dim = input_dim
        self.hid_dim = hid_dim
        # Number of triples selected
        self.k = k
        self.attn = nn.Sequential(nn.Linear(input_dim*2,hid_dim), nn.Linear(hid_dim,1))

    def forward(self, h_k, h_c):
        #h_k = [batch_size, k, hid_dim]
        #h_c = [batch_size, hid_dim]

        h_c =  h_c.unsqueeze(1) #[batch_size, 1, hid_dim]
        h_c = torch.cat([h_c]*self.k,dim=1)    #[batch_size, k, hid_dim]
        h_comb = torch.cat((h_k,h_c),dim=2)   #[batch_size, k, hid_dim*2]

        attn_logits = self.attn(h_comb).squeeze(2)  #[batch_size,k]
        attn_weight = F.softmax(attn_logits).unsqueeze(1)
        h_k_comb = torch.bmm(attn_weight,h_k).squeeze(1)    #[batch_size,hid_dim]
        return h_k_comb

In [9]:
class DecoderLSTM(nn.Module):
    def __init__(self,embedding_size = 256, num_units = 512, vocab_size = 50257, dropout_p = 0.1, num_layers = 2):
        super(DecoderLSTM, self).__init__()
        self.embedding_size = embedding_size
        self.num_units = num_units
        self.vocab_size = vocab_size
        self.dropout_p = dropout_p

        self.embedding = nn.Embedding(vocab_size, embedding_size) # !

        self.lstm = nn.LSTM(embedding_size,hidden_size=num_units,num_layers=num_layers, batch_first = True)

        self.Linear = nn.Linear(num_units,vocab_size)

        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input, hidden, cell):
        #input = [batch_size]
        #hidden = [batch_size, 2, hid_dim]
        #cell = [batch_size, 2, hid_dim]

        input = input.unsqueeze(1)  #[batch_size, 1]
        # print("Input shape:",input.shape)

        embedding = self.dropout(self.embedding(input)) #[batch_size, 1, emb_dim]
        # print("Embedding shape:",embedding.shape)
        # print("hidden shape",hidden.shape)

        output, (hidden, cell) = self.lstm(embedding, (hidden, cell)) 

        #output = [batch_size, 1, hid_dim]
        #hidden = [2, batch_size hid_dim]
        #cell  = [2, batch_size, hid_dim]

        logits = self.Linear(output)  #[batch_size, vocab_size]

        return logits, hidden, cell



class Decoder(nn.Module):
    def __init__(self, hparams, embedding_size = 256, num_units=512, vocab_size = 50257, dropout_p = 0.1, seq_len = 100, batch_size = 32, teacher_forcing_ratio = 0.5):
        super(Decoder,self).__init__()

        self.LSTM = DecoderLSTM(embedding_size, num_units,vocab_size,dropout_p)
        self.start_token = hparams.start_token
        self.batch_size = hparams.batch_size
        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.hparams = hparams


    def forward(self, trg, h_c, h_k):
        #h_c = [batch_size, hid_dim]
        #h_k = [batch_size, hid_dim]
        #trg = [batch_size, seq_len]

        input = trg[:,0]
        hidden = torch.stack((h_c,h_k),dim=0)
        cell = torch.zeros(2, self.batch_size,self.LSTM.num_units).to(self.hparams.device)
        outputs = torch.zeros(self.batch_size, 1, self.LSTM.vocab_size).to(self.hparams.device)

        for t in range(1, trg.shape[1]):
            output, hidden, cell = self.LSTM.forward(input,hidden,cell)
            # print(output.shape)

            outputs = torch.cat([outputs,output],dim=1)
            # outputs[:,t,:] = output

            top1 = output.squeeze(1).argmax(1)
            # print("top1_shape",top1.shape)

            replace = np.random.random() < self.teacher_forcing_ratio

            input = trg[:,t] if replace else top1
            # print(input.shape)

        
        return outputs[:,1:]


    def decode(self, h_c, h_k, seq_len):
      with torch.no_grad():
        input = torch.LongTensor([self.hparams.start_token]*self.batch_size).to(self.hparams.device)
        hidden = torch.stack((h_c,h_k),dim=0)
        cell = torch.zeros(2,self.batch_size,self.LSTM.num_units).to(self.hparams.device)
        tokens = None
        # tokens = torch.LongTensor(np.zeros([self.batch_size,1]))
        for t in range(0, seq_len-1):
            output, hidden, cell = self.LSTM.forward(input,hidden,cell)
            input = output.squeeze(1).argmax(1)
            tokens = input.unsqueeze(1) if tokens == None else torch.cat([tokens,input.unsqueeze(1)],dim=1)
        return tokens

In [10]:
class Hashmodel(nn.Module):
    def __init__(self,hparams,k):
        super(Hashmodel,self).__init__()
        self.embedding_model = nn.Sequential(
                                            nn.Linear(768, 1024),
                                            nn.Tanh(),
                                            nn.Linear(1024, 512),
                                            nn.Tanh(),
                                            nn.Sigmoid(),
                                            )
        self.attn = Attn()
        self.decoder = Decoder(hparams)
        self.k = k
        self.hparams = hparams

    def find_k_tuples(self,tuples_emb, conv_emb, k):
        """ 
        tuples is output tuple embedding from model [tuple_size, 512]
        conv is output conversation embedding from model [batch_size,512]
        k is the number of selected tuples 
        """
        # hash = [batch_size, 512]
        # probability = [batch_size, 512]
        topk_indices_lst = torch.LongTensor(np.zeros([conv_emb.shape[0],k])).to(self.hparams.device)
        with torch.no_grad():
            tuples_hash = torch.bernoulli(tuples_emb)
            conv_hash = torch.bernoulli(conv_emb)
            for i in range(conv_hash.shape[0]):
                conv_hash_1 = conv_hash[i].detach()
                hamming_table = torch.logical_xor(conv_hash_1, tuples_hash)
                hamming_dist = torch.sum(hamming_table, 1)
                _,topk_indices =  torch.topk(hamming_dist,k, largest = True)    ##[5]
                topk_indices_lst[i] = topk_indices
        
            return topk_indices_lst, tuples_hash



    def forward(self,conv,tuples,trg):
        #conv = [batch_size,768]
        #tuples = [tuples_size,768]
        batch_size = self.hparams.batch_size
        seq_len = trg.shape[1]

        conv_emb = self.embedding_model(conv)   #[batch_size, 512]
        tuple_emb = self.embedding_model(tuples) #[tuples_size, 512]
        outputs = torch.zeros(1, batch_size , seq_len-1, self.hparams.n_vocab).to(self.hparams.device)
        tokens = torch.zeros(1, batch_size, seq_len-1).to(self.hparams.device)
        batch_pos = torch.zeros(1, batch_size).to(self.hparams.device)

        ##Do five samples
        for i in range(1):
            topk_indices, tuple_hash = self.find_k_tuples(tuple_emb,conv_emb,self.k)    #[batch_size, 5]
            topk_emb = tuple_emb[topk_indices]        #[batch_size, 5, 512]
            topk_hash = tuple_hash[topk_indices]       #[batch_size, 5, 512]
            temp_ones = torch.ones([batch_size, 5, 512]).to(self.hparams.device)
            sum1 = torch.log(topk_emb) * topk_hash + torch.log(1 - topk_emb) * (1-topk_hash) #[batch_size, 5, 512]
            sum2 = torch.sum(sum1,dim=2)
            batch_pos[i] = torch.sum(sum2,dim=1) #[batch_size]
            
            # now go next level
            after_attn = self.attn(topk_emb, conv_emb)
            after_decode = self.decoder(trg,conv_emb,after_attn)
            outputs[i] = after_decode
            tokens[i] = self.decoder.decode(conv_emb, after_attn, seq_len) # [batch_size, seq_len]


        return batch_pos, outputs, tokens
        

In [11]:
model = Hashmodel(hparams, 5).to(device)
# init weights
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.8, 0.8)

model.apply(init_weights)

# calculate the number of trainable parameters in the model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# optimizer
optimizer = optim.Adam(model.parameters(),lr=1e-4)

# index of <pad>
PAD_ID = enc.encoder['<|endoftext|>']   
# criterion
# we ignore the loss whenever the target token is a padding token
criterion = nn.CrossEntropyLoss(ignore_index = hparams.start_token)

In [12]:
def train(model, criterion, optimizer, tuples_emb, conv_emb, trg_tokens, hparams):
#     model.to(device)
    '''
    tuples_emb is the embedding of all the tuples
    conv_emb is the embedding of all the conversation 
    trg_tokens is the targer response
    '''
    
    
    model.train()
    epoch_loss = 0
    tuples_input = torch.FloatTensor(tuples_emb).to(hparams.device)
    for iter in range(int(len(trg_tokens)/hparams.batch_size)):
      if iter == 75:
        continue
      batched_data = get_batched_data(trg_tokens,conv_emb, hparams.batch_size, iter)
      conv_input = torch.FloatTensor(batched_data['conv_emb']).to(hparams.device)
      trg_input = torch.LongTensor(batched_data['trg_tokens']).to(hparams.device)

      

      optimizer.zero_grad()

      batch_pos, outputs, tokens = model.forward(conv_input, tuples_input, trg_input)

      #trg = [trg sent len, batch size]
      #output = [trg sent len, batch size, output dim]

      # output = output[1:].view(-1, output.shape[-1])
      # trg = trg[1:].view(-1)

      #output = [(trg sent len - 1) * batch size, output dim]
      #trg = [(trg sent len - 1) * batch size]
      # lm_loss = torch.zeros([5]).to(hparams.device)
      # rl_loss = torch.zeros([5]).to(hparams.device)
      for i in range(1):
          output = outputs[i].view(-1, outputs.shape[-1])
          trg = trg_input[:,1:].reshape(-1)
          # output = outputs[i]
          # trg = trg_input[:,1:]
          lm_loss = criterion(output, trg) 

          token_text = batch_decode(tokens[i])
          target_text = batch_decode(trg_input)
          
          temp = torch.Tensor(cal_bleu(token_text, target_text)).to(device)
          rl_loss = -(torch.sum(batch_pos[i] * temp ).to(device) / len(token_text))
          
      # print(lm_loss.dtype)
      # print(rl_loss)
      # loss = (torch.sum(lm_loss) + torch.sum(rl_loss)) / 5
      loss = lm_loss + rl_loss

      loss.backward()

      torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

      optimizer.step()

      del(conv_input)
      del(trg_input)

      print("iter %d loss: %f"%(iter,loss.item()))
      torch.cuda.empty_cache()

    epoch_loss += loss.item()

    return epoch_loss  

In [13]:
def batch_decode(tokens):
    
    tokens = tokens.data.cpu().numpy() # [batch_size, seq_len]
    result = []
    for token in tokens:
        result.append(enc.decode(token).replace('<|endoftext|>',''))
    return result

In [14]:
def cal_bleu(token_text, target_text):
    smooth = SmoothingFunction()
    score = []
    for i in range(len(token_text)):
        reference = [target_text[i].strip().split()]
        candidate = token_text[i].strip().split()
        bleu = sentence_bleu(reference, candidate,smoothing_function=smooth.method1)
        bleu -= 0.2
        score.append(bleu)
    return score 

##Data Loader

In [17]:
def EmbLoader(kg_path, qs_path):
    responds = []
    questions = []
    r_flag = False
    list_file = open(qs_path,'rb')
    embeddings = pickle.load(list_file)
    list_file.close()
    for line in embeddings:
        if r_flag == False:
            questions.append(line)
            r_flag = True
        else:
            responds.append(line)
            r_flag = False

    list_file = open(kg_path,'rb')
    tuples = pickle.load(list_file)
    list_file.close()
    return np.array(questions), np.array(tuples)
    
def TextLoader(trg_path,size):
  with open(trg_path,'r') as fin:
    r_flag = False
    response_lst = []
    count = 0
    for line in fin:
      if not r_flag:
        r_flag = True
        continue
      else:
        response_lst.append(line.strip())
        count += 1
        if count >= size:
          break
        r_flag = False
    response_token = [enc.encode(response) for response in response_lst]
    return response_token

def get_batched_data(tokens,conv_emb,batch_size,iter_num):
  assert(len(tokens) == len(conv_emb))
  batched_data = {}
  st = batch_size*iter_num
  ed = batch_size*(iter_num + 1)
  if ed >= len(tokens):
    ed = len(tokens)
    
  batched_tokens = tokens[st:ed]
  max_len = max([len(text) for text in batched_tokens]) + 2
  batched_pad_tokens = pad_text(batched_tokens,max_len)
  batched_data['trg_tokens'] = batched_pad_tokens
  batched_data['conv_emb'] = conv_emb[st:ed]
  return batched_data

def pad_text(text,max_len):
  pad_texts = [[PAD_ID] + line + [PAD_ID]*(max_len - len(line)) for line in text]
  return np.array(pad_texts)


In [18]:
questions_emb, tuples_emb = EmbLoader('tuples.pickle', 'dialogue_embeddings.pickle')
tokens = TextLoader("dialogue.txt",5000)
train(model, criterion, optimizer, tuples_emb, questions_emb, tokens, hparams)




iter 0 loss: -289.681976
iter 1 loss: -289.304779
iter 2 loss: -288.792542
iter 3 loss: -289.701904
iter 4 loss: -290.020966
iter 5 loss: -288.358795
iter 6 loss: -288.484650
iter 7 loss: -287.538391
iter 8 loss: -287.863190
iter 9 loss: -287.650574
iter 10 loss: -288.951904
iter 11 loss: -286.734344
iter 12 loss: -289.588959
iter 13 loss: -287.166016
iter 14 loss: -288.304291
iter 15 loss: -285.470520
iter 16 loss: -288.210571
iter 17 loss: -286.598755
iter 18 loss: -288.036438
iter 19 loss: -287.838379
iter 20 loss: -287.253845
iter 21 loss: -287.223053
iter 22 loss: -288.381622
iter 23 loss: -288.208893
iter 24 loss: -287.111786
iter 25 loss: -287.137085
iter 26 loss: -288.709259
iter 27 loss: -286.741669
iter 28 loss: -288.962646
iter 29 loss: -287.740601
iter 30 loss: -288.674866
iter 31 loss: -288.544434
iter 32 loss: -288.384216
iter 33 loss: -289.103394
iter 34 loss: -288.286926
iter 35 loss: -289.148132
iter 36 loss: -288.297394
iter 37 loss: -287.799255
iter 38 loss: -288.518

-288.3685607910156

In [None]:
tokens = get_batched_data(tokens,questions_emb, hparams.batch_size, 75)['trg_tokens']
for line in tokens:
  for id in line:
    if id > 50256 or id < 0:
      print("Error")

In [None]:
from torchviz import make_dot
tuples_input = torch.randn(500, 768).to(device)
conv_input = torch.randn(32, 768).to(device)
trg_input = torch.LongTensor(np.random.randint(0,50256, [32,20])).to(device)

vis_graph = make_dot(model(conv_input,tuples_input,trg_input), params=dict(model.named_parameters()))
vis_graph.view()

In [None]:
!nvidia-smi