In [1]:
import json 
import os
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import pandas as pd
import re
import unicodedata
import itertools
import random

# Pick the compute device once; later cells move tensors to `device`.
USE_CUDA = torch.cuda.is_available()
# NOTE(review): "cuda:1" pins the second GPU. On a machine where CUDA is
# available but only one GPU is present this index may not exist — confirm
# the target hardware or make the index configurable.
device = torch.device("cuda:1" if USE_CUDA else "cpu")

In [2]:
# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token
UNK_token = 3  # Unknown-word token


class Voc:
    """Vocabulary mapping words to integer indexes and back.

    Indexes 0-3 are reserved for the PAD/SOS/EOS/UNK tokens. Until `trim`
    is called, `word2index` holds only words added through `addWord`; the
    special tokens are present in `index2word` only. After `trim`, the
    special tokens are also present in `word2index`.
    """

    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS", UNK_token:"UNK"}
        self.num_words = 4  # Count PAD, SOS, EOS, UNK

    def addSentence(self, sentence):
        """Add every single-space-separated token of `sentence` to the vocabulary."""
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Register `word` with the next free index, or bump its count if known."""
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        """Drop words seen fewer than `min_count` times and re-index the rest.

        Only the first call has an effect; later calls return immediately.
        Surviving words keep their relative order and their real counts.
        """
        if self.trimmed:
            return
        self.trimmed = True

        # dict preserves insertion order, so kept words are re-added in the
        # same order they were first seen.
        kept_counts = {word: count for word, count in self.word2count.items()
                       if count >= min_count}

        total_words = len(self.word2index)
        # Guard the ratio: an empty vocabulary used to raise ZeroDivisionError.
        kept_ratio = len(kept_counts) / total_words if total_words else 0.0
        print('keep_words {} / {} = {:.4f}'.format(
            len(kept_counts), total_words, kept_ratio
        ))

        # Reinitialize dictionaries
        self.word2index = {"PAD": PAD_token, "SOS": SOS_token, "EOS": EOS_token, "UNK":UNK_token}
        self.word2count = {"UNK": 0}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS", UNK_token:"UNK"}
        self.num_words = 4 # Count default tokens

        for word, count in kept_counts.items():
            self.addWord(word)
            # addWord restarts the count at 1; restore the real frequency.
            self.word2count[word] = count

In [3]:
# Strip accents from a Unicode string, thanks to
# http://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Return `s` NFD-decomposed with all combining marks (category 'Mn') removed.

    Characters that have no decomposition are left as they are.
    """
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    """Normalise one utterance into space-separated lowercase tokens.

    Steps, in order: lowercase and strip, remove accents, expand common
    English contractions, put spaces around '.', '!' and '?', replace every
    run of other non-alphanumeric characters with a single space, and
    collapse whitespace.
    """
    s = unicodeToAscii(s.lower().strip())
    # Order matters: "can't" must be handled before the generic "n't" rule.
    s = re.sub(r"can't", r"can not", s)
    s = re.sub(r"n't", r" not", s)
    # The original pattern was "'ve'" (stray trailing quote) and never matched.
    s = re.sub(r"'ve", r" have", s)
    s = re.sub(r"cannot", r"can not", s)
    s = re.sub(r"what's", r"what is", s)
    s = re.sub(r"'re", r" are", s)
    s = re.sub(r"'d", r" would", s)
    # Same stray-quote defect as "'ve'": the original was "'ll'".
    s = re.sub(r"'ll", r" will", s)
    s = re.sub(r" im ", r" i am ", s)
    s = re.sub(r"'m", r" am", s)
    s = re.sub(r"([.!?])", r" \1 ", s)
    s = re.sub(r"[^a-zA-Z.!?0-9]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

# Load every wiki document under `path` into the slot given by the
# 'wikiDocumentIdx' field stored inside each JSON file.
wiki_docs = [0]*30
path = '../WikiData'
for file in os.listdir(path):
    # The context manager closes each handle; the original leaked one per file.
    with open(os.path.join(path,file)) as wiki_file:
        wiki_data = json.load(wiki_file)
    wiki_docs[wiki_data['wikiDocumentIdx']] = wiki_data

# wiki_strings[i][j] is the normalised text stored under key str(j)
# (j = 0..3) of document i.
wiki_strings = []
for i in range(len(wiki_docs)):
    doc = []
    for j in range(4):
        doc.append(normalizeString(str(wiki_docs[i][str(j)])))
    wiki_strings.append(doc)

# print(wiki_strings[20][0])

In [4]:
def readConvsFile(path,file):
    """Read one conversation JSON file and flatten it into a list of row dicts.

    Args:
        path: Directory containing the conversation file.
        file: File name inside `path`.

    Returns:
        One dict per entry of the file's 'history' list, with keys
        'wikiIdx', 'docIdx', 'uid', 'text' (normalised) and 'saw'.
        'saw' is 2 when 'whoSawDoc' has two entries, 0 when it equals
        ['user1'], and 1 otherwise.
    """
    # Use the `path` argument: the original read the global `conv_path`,
    # which is only assigned in commented-out code.
    with open(os.path.join(path, file)) as conv_file:
        conv_data = json.load(conv_file)

    wikiIndex = conv_data['wikiDocumentIdx']
    if len(conv_data['whoSawDoc']) == 2:
        saw = 2
    elif conv_data['whoSawDoc'] == ['user1']:
        saw = 0
    else:
        saw = 1

    convs = []
    for utter in conv_data['history']:
        # Key order is kept because it fixes the column order of the
        # DataFrame built from these rows.
        convs.append({
            'wikiIdx': wikiIndex,
            'docIdx': utter['docIdx'],
            'uid': utter['uid'],
            'text': normalizeString(utter['text']),
            'saw': saw,
        })

    return convs

def saveNewConvs(path):
    """Convert every '.json' conversation in `path` to a tab-separated CSV.

    Output files are written next to the inputs as 'train<N>.csv', where
    N counts the converted files in `os.listdir` order.
    """
    index = 0
    for file in os.listdir(path):
        # splitext handles names without a dot (which made
        # `file.split('.')[1]` raise IndexError) and names with several dots.
        if os.path.splitext(file)[1] != '.json':
            continue
        convs = readConvsFile(path,file)
        new_file = 'train'+str(index)+'.csv'
        data_file = os.path.join(path,new_file)
        print('Writing to new formatted line...',index)
        df = pd.DataFrame(convs)
        df.to_csv(data_file,encoding='utf-8',sep='\t')
        index += 1

# conv_path = '../Conversations/train'
# saveNewConvs(conv_path)

In [5]:
MAX_LENGTH = 25  # Exclusive upper bound on the space-separated token count of a sentence kept by filterPair

def buildPairs(df):
    """Build [wikiIdx, docIdx, text, reply] rows from consecutive utterances.

    A pair is emitted for each row whose following row has a different
    `uid`. The last row never starts a pair.

    Args:
        df: DataFrame with columns wikiIdx, docIdx, uid and text, one row
            per utterance in conversation order.

    Returns:
        List of 4-element lists: [wikiIdx, docIdx, text, next_text].
    """
    pairs = []
    # Iterate by position: the original looped over index labels but
    # looked rows up with positional .iloc, which breaks on any
    # non-default index.
    for pos in range(len(df) - 1):
        current = df.iloc[pos]
        following = df.iloc[pos + 1]
        if current.uid != following.uid:
            pairs.append([current.wikiIdx, current.docIdx, current.text, following.text])
    return pairs

# True only when both the utterance (p[2]) and the reply (p[3]) are plain
# strings with fewer than MAX_LENGTH space-separated tokens
def filterPair(p):
    """Decide whether pair `p` is short enough to keep.

    On any exception the offending pair is printed and the exception is
    re-raised.
    """
    try:
        # Checked one field at a time so p[3] is only touched once p[2] passed.
        if type(p[2]) != str:
            return False
        if type(p[3]) != str:
            return False
        # Strictly below the limit: one position stays free for the EOS token.
        if not len(p[2].split(' ')) < MAX_LENGTH:
            return False
        return len(p[3].split(' ')) < MAX_LENGTH
    except Exception as e:
        print(p)
        raise e
        

# Keep only the pairs accepted by filterPair, preserving their order
def filterPairs(pairs):
    """Return a new list with the pairs for which filterPair is true."""
    accepted = []
    for candidate in pairs:
        if filterPair(candidate):
            accepted.append(candidate)
    return accepted

def loadPrepareData(dataPath,corpus_name,wiki_strings):
    """Build the vocabulary and the filtered utterance/reply pairs.

    Every regular file in `dataPath` that is not a '.json' file is read as
    a tab-separated CSV and turned into pairs with buildPairs. Pairs are
    filtered with filterPairs, then the words of the surviving pairs and
    of every wiki section are added to a fresh Voc.

    Args:
        dataPath: Directory holding the converted conversation CSV files.
        corpus_name: Name given to the vocabulary.
        wiki_strings: List of documents, each a list of section strings.

    Returns:
        (voc, pairs)
    """
    voc = Voc(corpus_name)
    pairs = []
    print("Starting preparing training data...")
    for file in os.listdir(dataPath):
        file_path = os.path.join(dataPath, file)
        # splitext avoids the IndexError `file.split('.')[1]` raised on
        # names without a dot; sub-directories are skipped as well.
        if os.path.splitext(file)[1] == '.json' or not os.path.isfile(file_path):
            continue
        df = pd.read_csv(file_path,sep='\t',encoding='utf-8')
        pairs += buildPairs(df)

    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[2])
        voc.addSentence(pair[3])

    for wiki_doc in wiki_strings:
        for s in wiki_doc:
            voc.addSentence(s)

    print("Counted words:", voc.num_words)
    return voc, pairs

# Run label; loadPrepareData passes it to Voc as the vocabulary name.
corpus_name = 'seq+att'
# Path built from the label; nothing is created on disk here.
save_dir = os.path.join('../Conversations',corpus_name)
# Build the vocabulary and the filtered pairs from ../Conversations/train
# plus the normalised wiki sections.
voc, pairs = loadPrepareData('../Conversations/train',corpus_name,wiki_strings)

# print("\npairs:")
# for pair in pairs[:30]:
#     print(pair)

Starting preparing training data...
Read 72922 sentence pairs
Trimmed to 58307 sentence pairs
Counting words...
Counted words: 17425


In [6]:
MIN_COUNT = 2    # Minimum word count threshold for trimming

def _replaceRareWords(sentence, known_words):
    """Return `sentence` with every token missing from `known_words` replaced by 'UNK'."""
    return ' '.join(
        word if word in known_words else 'UNK'
        for word in sentence.split(' ')
    )

def trimRareWords(voc, pairs, MIN_COUNT,wiki_docs):
    """Trim the vocabulary and replace trimmed words with 'UNK' in all text.

    Args:
        voc: Vocabulary; `voc.trim(MIN_COUNT)` is called first and
            `voc.word2index` then decides which tokens are kept.
        pairs: List of [wikiIdx, docIdx, utterance, reply]; elements 2 and
            3 are rewritten in place.
        MIN_COUNT: Threshold forwarded to `voc.trim`.
        wiki_docs: List of documents, each a list of section strings;
            sections are rewritten in place.

    Returns:
        (pairs, wiki_docs) — the same objects that were passed in.
    """
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)
    known_words = voc.word2index

    # Tokens are replaced one by one instead of through regexes built from
    # the raw word: a token such as '?' or '.' is a regex metacharacter and
    # made the original substitutions raise or rewrite the wrong text.
    for pair in pairs:
        pair[2] = _replaceRareWords(pair[2], known_words)
        pair[3] = _replaceRareWords(pair[3], known_words)

    for wiki_doc in wiki_docs:
        for si, section in enumerate(wiki_doc):
            wiki_doc[si] = _replaceRareWords(section, known_words)

    return pairs, wiki_docs


# Trim voc and pairs; trimmed words in the pairs and wiki sections become "UNK"
pairs,wiki_strings = trimRareWords(voc, pairs, MIN_COUNT,wiki_strings)
# Preview the first 30 pairs and the sections of the first wiki document.
print("\npairs:")
for pair in pairs[:30]:
    print(pair)
print(wiki_strings[0])

keep_words 12746 / 17421 = 0.7316

pairs:
[5, 0, 'hey have you seen the inception ?', 'no i have not but have heard of it . what is it about']
[5, 0, 'no i have not but have heard of it . what is it about', 'it s about extractors that perform experiments using military technology o n people to retrieve info about their targets .']
[5, 0, 'it s about extractors that perform experiments using military technology o n people to retrieve info about their targets .', 'sounds interesting do you know which actors are in it ?']
[5, 0, 'he plays as don cobb', 'oh okay yeah i am not a big scifi fan but there are a few movies i still enjoy in that genre .']
[5, 0, 'is it a long movie ?', 'does not say how long it is .']
[5, 2, 'ellen page', 'oh cool . i am familiar with her . she s in a number of good movies and is great .']
[5, 2, 'oh cool . i am familiar with her . she s in a number of good movies and is great .', 'she plays ariadne she is a graduate student that constructs the dreamscapes they 

In [7]:
def indexesFromSentence(voc, sentence):
    """Map each space-separated token of `sentence` to its index and append EOS_token.

    Raises KeyError when a token is missing from `voc.word2index`.
    """
    indexes = []
    for token in sentence.split(' '):
        indexes.append(voc.word2index[token])
    indexes.append(EOS_token)
    return indexes


def zeroPadding(l, fillvalue=PAD_token):
    """Transpose a list of sequences, padding shorter ones with `fillvalue`.

    Returns a list of tuples: tuple t holds element t of every sequence,
    so the result is indexed [time step][sequence].
    """
    time_steps = itertools.zip_longest(*l, fillvalue=fillvalue)
    return [step for step in time_steps]

def binaryMatrix(l, value=PAD_token):
    """Return a 0/1 matrix shaped like `l`: 0 where a token equals `value`, 1 elsewhere.

    Args:
        l: Iterable of token sequences.
        value: Token treated as padding. Defaults to PAD_token.
    """
    # Compare against `value`: the original accepted the parameter but
    # always compared with PAD_token.
    return [[0 if token == value else 1 for token in seq] for seq in l]

def docPre(docs,voc):
    """Convert every section of every document into a list of word indexes.

    Returns a nested list indexed [document][section]. If a conversion
    fails, the document being processed is printed and the exception is
    re-raised.
    """
    try:
        indexes_docs = []
        for doc in docs:
            section_indexes = []
            for sentence in doc:
                section_indexes.append(indexesFromSentence(voc, sentence))
            indexes_docs.append(section_indexes)
        return indexes_docs
    except Exception as e:
        print(doc)
        raise e

def docVar(l,voc,docs):
    """Build the padded section tensor and lengths for one batch.

    Args:
        l: List of [docIdx, secIdx] entries, one per batch element.
        voc: Vocabulary used to index the section text.
        docs: Documents as lists of section strings, indexed
            docs[docIdx][secIdx].

    Returns:
        (padVar, lengths): LongTensor built from the zero-padded,
        transposed index lists, and a tensor with the unpadded length of
        each section (EOS included).
    """
    # Index only the sections this batch refers to. The original
    # converted every section of every document on each call and then
    # discarded all but the batch's entries.
    indexes_batch = [indexesFromSentence(voc, docs[docIdx][secIdx]) for docIdx,secIdx in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths
    

# Builds the padded input tensor and the per-sentence lengths
def inputVar(l, voc):
    """Index and pad a batch of input sentences.

    Returns:
        (padVar, lengths): LongTensor built from the zero-padded,
        transposed index lists, and a tensor with the unpadded length of
        each sentence (EOS included).
    """
    indexes_batch = list(map(lambda sentence: indexesFromSentence(voc, sentence), l))
    lengths = torch.tensor(list(map(len, indexes_batch)))
    padVar = torch.LongTensor(zeroPadding(indexes_batch))
    return padVar, lengths

# Builds the padded target tensor, its padding mask and the longest target length
def outputVar(l, voc):
    """Index and pad a batch of target sentences.

    Returns:
        (padVar, mask, max_target_len): LongTensor of the zero-padded,
        transposed index lists; a ByteTensor of the same layout holding 0
        at padding positions and 1 elsewhere; and the longest unpadded
        length (EOS included).
    """
    indexes_batch = []
    for sentence in l:
        indexes_batch.append(indexesFromSentence(voc, sentence))
    max_target_len = max(map(len, indexes_batch))
    padList = zeroPadding(indexes_batch)
    # NOTE(review): the mask is uint8; recent PyTorch releases expect
    # boolean masks in masked_select — confirm against the installed version.
    mask = torch.ByteTensor(binaryMatrix(padList))
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

# Turns a batch of pairs into every tensor one training step needs
def batch2TrainData(voc, pair_batch,wiki_docs):
    """Convert a batch of [wikiIdx, docIdx, utterance, reply] pairs to tensors.

    `pair_batch` is sorted in place by utterance token count, longest
    first, before conversion.

    Returns:
        (doc_inp, doc_lengths, inp, lengths, output, mask, max_target_len)

    On any exception the batch is printed and the exception is re-raised.
    """
    try:
        pair_batch.sort(key=lambda x: len(x[2].split(" ")), reverse=True)
        doc_batch = []
        input_batch = []
        output_batch = []
        for wiki_idx, doc_idx, utterance, reply in (
            (pair[0], pair[1], pair[2], pair[3]) for pair in pair_batch
        ):
            doc_batch.append([wiki_idx, doc_idx])
            input_batch.append(utterance)
            output_batch.append(reply)
        doc_inp, doc_lengths = docVar(doc_batch, voc,wiki_docs)
        inp, lengths = inputVar(input_batch, voc)
        output, mask, max_target_len = outputVar(output_batch, voc)
        return doc_inp,doc_lengths,inp, lengths, output, mask, max_target_len
    except Exception as e:
        print(pair_batch)
        raise e


# Example for validation: build one random batch and inspect its tensors
small_batch_size = 64
# Pairs are drawn with replacement (random.choice), so duplicates are possible.
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)],wiki_strings)
doc_input, doc_lengths, input_variable, lengths, target_variable, mask, max_target_len = batches

# doc_lengths is NOT sorted here: the batch is ordered by utterance length only.
print("doc_input:", doc_input.shape)
print("doc_lengths:", doc_lengths)
print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)
# indexes_batch = [indexesFromSentence(voc, sentence) for sentence in wiki_strings[0]]


doc_input: torch.Size([472, 64])
doc_lengths: tensor([122, 384, 143, 367, 139, 126, 235, 399, 399, 399, 361, 206, 234, 139,
        133, 172, 168, 192, 138, 220, 108, 204, 103, 331, 133, 137, 315, 220,
        191, 122, 241, 203, 350,  97, 472, 114, 234, 419, 167, 291, 350, 131,
        220, 150, 291, 384, 234, 188, 141, 211, 204, 331, 140, 172, 361, 419,
        235, 188, 228,  95, 126, 121, 209, 235])
input_variable: tensor([[  118,   128,   177,  ...,   123,  3755,   112],
        [   17,    18,     8,  ...,    72,  1390,     2],
        [ 2944,    12, 12300,  ...,     2,     2,     0],
        ...,
        [ 2368,    18,  2258,  ...,     0,     0,     0],
        [  362,     2,     2,  ...,     0,     0,     0],
        [    2,     0,     0,  ...,     0,     0,     0]])
lengths: tensor([24, 23, 23, 21, 21, 21, 20, 20, 20, 19, 17, 16, 16, 16, 16, 15, 15, 15,
        14, 14, 13, 13, 13, 12, 12, 11, 11, 10, 10, 10, 10, 10,  9,  9,  9,  8,
         8,  8,  8,  7,  7,  7,  7,  7,  7,  7

In [8]:
class EncoderRNN(nn.Module):
    """Bidirectional GRU encoder over a padded batch of token indexes.

    Args:
        embedding_size: Size of the vectors produced by `embedding`.
        hidden_size: GRU hidden size per direction.
        embedding: Module mapping token indexes to embedding vectors.
        n_layers: Number of stacked GRU layers.
        dropout: Dropout between GRU layers; forced to 0 when n_layers == 1.
    """

    def __init__(self, embedding_size,hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        self.embedding_size = embedding_size

        self.gru = nn.GRU(embedding_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        """Encode a padded batch.

        Args:
            input_seq: (L, B) token indexes, sequences sorted longest first.
            input_lengths: Unpadded length of each sequence (tensor or list).
            hidden: Optional initial GRU hidden state.

        Returns:
            (outputs, hidden): outputs is (L, B, H) with the two directions
            summed; hidden is the GRU's final state, (layers * 2, B, H).
        """
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq) # (L,B,E)
        # pack_padded_sequence requires a lengths tensor to be on the CPU,
        # even when the data is on the GPU.
        if torch.is_tensor(input_lengths):
            input_lengths = input_lengths.cpu()
        # Pack padded batch of sequences for RNN module
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        # Forward pass through GRU
        outputs, hidden = self.gru(packed, hidden)# (L,B,direc*H)  (layer*direc,B,H)
        # Unpack padding
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional GRU outputs
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:] #(L,B,H)
        # Return output and final hidden state
        return outputs, hidden

In [9]:
# Luong attention layer
class Attn(torch.nn.Module):
    """Attention weights over encoder outputs for one decoder step.

    Args:
        method: Scoring function, one of 'dot', 'general' or 'concat'.
        hidden_size: Size of the decoder output and encoder output vectors.

    Raises:
        ValueError: If `method` is not one of the supported names.
    """

    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
            # torch.FloatTensor(n) returns uninitialised memory, so the
            # original parameter could start with arbitrary values (including
            # NaN/inf). Draw it from the same range nn.Linear uses instead.
            bound = hidden_size ** -0.5
            self.v = torch.nn.Parameter(torch.empty(hidden_size).uniform_(-bound, bound))

    def dot_score(self, hidden, encoder_output): #(1,B,H) (L,B,H)
        """Dot product between the decoder output and each encoder output."""
        return torch.sum(hidden * encoder_output, dim=2) #(L,B)

    def general_score(self, hidden, encoder_output):
        """Dot product after a linear map of the encoder outputs."""
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        """Score from tanh(Linear([hidden; encoder_output])) weighted by `v`."""
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        """Return softmax-normalised attention weights of shape (B, 1, L)."""
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t() #(B,L)

        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1) #(B,1,L)

In [10]:
class LuongAttnDecoderRNN(nn.Module):
    """One-step GRU decoder with Luong attention, conditioned on a section encoding.

    At each step the embedded input word is concatenated with the section
    encoder's final hidden state before entering the GRU; attention is then
    computed over the utterance encoder's outputs.
    """
    def __init__(self, attn_model, embedding, embedding_size, encoder_n_layers, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout
        self.encoder_n_layers = encoder_n_layers

        # Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        # GRU input = word embedding + flattened section hidden state (H per encoder layer)
        self.gru = nn.GRU(embedding_size+hidden_size*encoder_n_layers, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, sec_hidden, last_hidden, encoder_outputs):
        """Decode one time step.

        Returns (output, hidden): output is (B, output_size) softmax
        probabilities (not log-probabilities); hidden is the GRU's new
        hidden state.
        """
        # Note: we run this one step (word) at a time
        # Get embedding of current input word
        embedded = self.embedding(input_step) #(1,B,E)
        embedded = self.embedding_dropout(embedded)
        
        # Add the first encoder_n_layers slices of sec_hidden to the remaining ones.
        # NOTE(review): if sec_hidden comes from a bidirectional GRU whose
        # h_n is ordered layer-major (layer0 fwd, layer0 bwd, layer1 fwd, ...),
        # this adds slices from different layers rather than the two
        # directions of one layer — confirm the intended pairing.
        sec_hidden = sec_hidden[:self.encoder_n_layers] + sec_hidden[self.encoder_n_layers:] #(encoderlayer,B,H)
        sec_hidden = sec_hidden.transpose(0,1).contiguous().view(1,embedded.shape[1],-1) #(1,B,encoderlayer*H)
        # Forward through unidirectional GRU
        rnn_input = torch.cat((embedded,sec_hidden),2) # fuse section and word (1,B,E+H*encoderlayer)
        rnn_output, hidden = self.gru(rnn_input, last_hidden) #(1,B,H) (layer,B,H)
        # Calculate attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs) #(B,1,L)
        # Multiply attention weights to encoder outputs to get new "weighted sum" context vector
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))# (B,1,L) * (B,L,H) = (B,1,H)
        # Concatenate weighted context vector and GRU output using Luong eq. 5
        rnn_output = rnn_output.squeeze(0) #(B,H)
        context = context.squeeze(1) #(B,H)
        concat_input = torch.cat((rnn_output, context), 1) #(B,2*H)
        concat_output = torch.tanh(self.concat(concat_input)) #(B,H)
        # Predict next word using Luong eq. 6
        output = self.out(concat_output)#(B,Out)
        output = F.softmax(output, dim=1)#(B,Out)
        # Return output and final hidden state
        return output, hidden #(B,Out) (layer,B,H)

In [11]:
def maskNLLLoss(inp, target, mask):
    """Mean negative log-likelihood over the unmasked positions of one step.

    Args:
        inp: (B, V) probabilities (already softmax-normalised).
        target: (B,) target word indexes.
        mask: (B,) mask, non-zero where the target is a real token.

    Returns:
        (loss, nTotal): scalar loss tensor on `device`, and the number of
        unmasked positions as a Python number.
    """
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    # masked_select expects a boolean mask on current PyTorch; uint8 masks
    # are deprecated. Convert only when this torch build has a bool dtype
    # so older installs keep working.
    bool_dtype = getattr(torch, 'bool', None)
    if bool_dtype is not None and mask.dtype != bool_dtype:
        mask = mask.to(bool_dtype)
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

In [12]:
def train(input_variable, lengths, section_variable, sec_lengths, sec_idx, target_variable, mask, max_target_len, encoder, sec_encoder, decoder, embedding,
          encoder_optimizer, sec_encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):
    """Run one optimisation step on a batch and return its average per-token loss.

    Args:
        input_variable, lengths: Padded utterance batch and its lengths.
        section_variable, sec_lengths: Padded section batch, sorted by
            section length, and its lengths.
        sec_idx: Index tensor that puts the section encoder's hidden state
            back into the utterance batch order.
        target_variable, mask, max_target_len: Padded targets, padding
            mask and longest target length.
        encoder, sec_encoder, decoder: Models to train.
        embedding, max_length: Accepted but not used by this function.
        encoder_optimizer, sec_encoder_optimizer, decoder_optimizer:
            One optimizer per model.
        batch_size: Number of sequences in the batch.
        clip: Maximum gradient norm, applied per model.

    Reads the module-level names `device`, `SOS_token` and
    `teacher_forcing_ratio`.
    """

    # Zero gradients
    encoder_optimizer.zero_grad()
    sec_encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Set device options
    input_variable = input_variable.to(device)
    section_variable = section_variable.to(device)
    sec_idx = sec_idx.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)
    # Lengths are only consumed by pack_padded_sequence inside the encoders,
    # which requires a lengths tensor to be on the CPU. The original moved
    # them to `device`.
    lengths = lengths.cpu()
    sec_lengths = sec_lengths.cpu()

    # Initialize variables
    loss = 0
    print_losses = []
    n_totals = 0

    # Forward pass through the utterance encoder
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths) # (L,B,H)  (layer*direc,B,H)
    try:
        # Forward pass through the section encoder
        sec_outputs, sec_hidden = sec_encoder(section_variable, sec_lengths) # (secL,B,H)  (layer*direc,B,H)
        # Restore the within-batch order that is sorted by utterance length
        sec_hidden = sec_hidden.index_select(1,sec_idx)
    except Exception as e:
        print(section_variable,section_variable.shape)
        print(sec_lengths,sec_lengths.shape)
        print(input_variable,input_variable.shape)
        print(lengths,lengths.shape)
        raise e

    # Create initial decoder input (start with SOS tokens for each sentence)
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]]) # (1,B)
    decoder_input = decoder_input.to(device)

    # Set initial decoder hidden state to the encoder's final hidden state
    decoder_hidden = encoder_hidden[:decoder.n_layers] # (layer,B,H)

    # Determine if we are using teacher forcing this iteration
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    # Forward batch of sequences through decoder one time step at a time.
    # The two copies of this loop in the original differed only in how the
    # next decoder input is chosen.
    for t in range(max_target_len):
        decoder_output, decoder_hidden = decoder(
            decoder_input, sec_hidden, decoder_hidden, encoder_outputs
        )
        if use_teacher_forcing:
            # Teacher forcing: next input is current target
            decoder_input = target_variable[t].view(1, -1)
        else:
            # No teacher forcing: next input is decoder's own current output
            _, topi = decoder_output.topk(1)
            # Same values as collecting topi[i][0] for i < batch_size one by
            # one, without a Python-level loop over tensor elements.
            decoder_input = topi[:batch_size, 0].reshape(1, -1)
            decoder_input = decoder_input.to(device)
        # Calculate and accumulate loss
        mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
        loss += mask_loss
        print_losses.append(mask_loss.item() * nTotal)
        n_totals += nTotal

    # Perform backpropagation
    loss.backward()

    # Clip gradients: gradients are modified in place
    _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = torch.nn.utils.clip_grad_norm_(sec_encoder.parameters(), clip)
    _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Adjust model weights
    encoder_optimizer.step()
    sec_encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

In [13]:
def trainIters(model_name, voc, pairs, wiki_strings, encoder, sec_encoder, decoder, encoder_optimizer, sec_encoder_optimizer, decoder_optimizer, embedding, encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):
    """Train for iterations start..n_iteration, logging and checkpointing periodically.

    All n_iteration batches are sampled up front (with replacement). Every
    `print_every` iterations the running average loss is printed; every
    `save_every` iterations a checkpoint with model, optimizer, vocabulary
    and embedding state is written under
    save_dir/model_name/<encoder_n_layers>-<decoder_n_layers>_<hidden_size>/.

    NOTE(review): this function reads two names that are not parameters:
    `checkpoint` (only when loadFilename is truthy) and `hidden_size`
    (for the checkpoint directory). Both must exist at module level when
    it runs. `corpus_name` is accepted but not used in this body.
    """

    # Load batches for each iteration
    # NOTE(review): every batch is materialised before training starts, and
    # batches before start_iteration are built even when resuming.
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)],wiki_strings)
                                       for _ in range(n_iteration)]

    # Initializations
    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1

    # Training loop
    print("Training...")
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        # Extract fields from batch
        doc_input, doc_lengths, input_variable, lengths, target_variable, mask, max_target_len = training_batch
        
        # Sort the docs by length in descending order, and keep idx2, the
        # permutation that restores the original (utterance-sorted) order
        doc_lengths,idx1 = torch.sort(doc_lengths,descending=True)
        doc_input = doc_input.index_select(1,idx1)
        _,idx2 = torch.sort(idx1)
        # Run a training iteration with batch
        loss = train(input_variable, lengths, doc_input, doc_lengths, idx2, target_variable, mask, max_target_len, encoder,sec_encoder,
                     decoder, embedding, encoder_optimizer,sec_encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        # Print progress
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        # Save checkpoint
        if (iteration % save_every == 0):
            directory = os.path.join(save_dir, model_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'sec_en':sec_encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'sec_en_opt': sec_encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

In [14]:
class GreedySearchDecoder(nn.Module):
    """Greedy decoder: at every step keep the single most probable word.

    Wraps the utterance encoder, the section encoder and the decoder, and
    decodes one sequence (batch size 1) for a fixed number of steps.
    """

    def __init__(self, encoder, sec_encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.sec_encoder = sec_encoder

    def forward(self, input_seq, input_length, sec_seq, sec_length, max_length):
        """Decode `max_length` tokens greedily.

        Returns:
            (all_tokens, all_scores): 1-D tensors holding the chosen token
            index and its probability for every step.
        """
        # Forward input and section through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)#(L,1,H) (layer*direc,1,H)
        sec_outputs, sec_hidden = self.sec_encoder(sec_seq, sec_length)

        # Prepare encoder's final hidden layer to be first hidden input to the decoder.
        # Use the wrapped decoder: the original read a module-level `decoder`,
        # which may be missing or a different model.
        decoder_hidden = encoder_hidden[:self.decoder.n_layers]
        # Initialize decoder input with SOS_token
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * SOS_token
        # Initialize tensors to append decoded words to
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # Iteratively decode one word token at a time
        for _ in range(max_length):
            # Forward pass through decoder
            decoder_output, decoder_hidden = self.decoder(decoder_input, sec_hidden, decoder_hidden, encoder_outputs)
            # Obtain most likely word token and its softmax score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # Record token and score
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Prepare current token to be next decoder input (add a dimension)
            decoder_input = torch.unsqueeze(decoder_input, 0)
        # Return collections of word tokens and scores
        return all_tokens, all_scores

In [None]:
class Beam(object):
    """One partial hypothesis: its tokens, per-token log-probabilities and
    the model state stored alongside them."""

    def __init__(self, tokens, log_probs, sec_hidden, decoder_hidden, encoder_outputs):
        self.tokens = tokens
        self.log_probs = log_probs
        self.sec_hidden = sec_hidden
        self.decoder_hidden = decoder_hidden
        self.encoder_outputs = encoder_outputs

    def extend(self, token, log_prob, sec_hidden, decoder_hidden, encoder_outputs):
        """Return a new Beam with `token`/`log_prob` appended and the given state.

        The current beam is left untouched.
        """
        longer_tokens = self.tokens + [token]
        longer_log_probs = self.log_probs + [log_prob]
        return Beam(longer_tokens, longer_log_probs, sec_hidden, decoder_hidden, encoder_outputs)

    @property
    def latest_token(self):
        """Most recently appended token."""
        return self.tokens[-1]

    @property
    def avg_log_prob(self):
        """Sum of the log-probabilities divided by the number of tokens."""
        total_log_prob = sum(self.log_probs)
        return total_log_prob / len(self.tokens)

class BeamSearchDecoder(object):
    """Work-in-progress beam-search decoder.

    Only the setup is implemented: both encoders are run, the initial
    beams are built, and the per-step inputs are gathered. The step that
    expands and prunes the beams has not been written, so `beam_search`
    raises NotImplementedError when it reaches that point.
    """

    def __init__(self, encoder, sec_encoder, decoder):
        self.encoder = encoder
        self.decoder = decoder
        self.sec_encoder = sec_encoder

    def sort_beams(self,beams):
        """Return `beams` sorted by average log-probability, best first."""
        return sorted(beams, key=lambda h: h.avg_log_prob, reverse=True)

    def beam_search(self, input_seq, input_length, sec_seq, sec_length, max_length, beam_size):
        """Set up a beam search over one input; the expansion step is missing.

        Raises:
            NotImplementedError: On the first decoding step (when
                max_length > 0 and beam_size > 0).
        """
        # Forward input and section through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)#(L,1,H) (layer*direc,1,H)
        sec_outputs, sec_hidden = self.sec_encoder(sec_seq, sec_length) #(secL,1,H) (layer*direc,1,H)

        # Prepare encoder's final hidden layer to be first hidden input to the decoder.
        # This line and the Beam(...) call below were garbled in the original
        # cell and did not parse; they are restored from the surrounding names.
        decoder_hidden = encoder_hidden[:self.decoder.n_layers]
        beams = [Beam(tokens=[SOS_token],
                     log_probs=[0.0],
                     sec_hidden=sec_hidden,
                     decoder_hidden=decoder_hidden,
                     encoder_outputs=encoder_outputs) for _ in range(beam_size)]

        results = []
        steps = 0
        while steps < max_length and len(results) < beam_size:
            latest_tokens = [h.latest_token for h in beams]
            # torch.LongTensor(...) does not accept a dtype keyword; torch.tensor does.
            latest_tokens = torch.tensor(latest_tokens, device=device, dtype=torch.long)

            all_decoder_hidden = []

            for h in beams:
                all_decoder_hidden.append(h.decoder_hidden.transpose(0,1).contiguous())

            sec_hidden_stack = beams[0].sec_hidden # (layer*direc,1,H)
            # NOTE(review): stacking the transposed per-beam states yields a
            # 4-D tensor, not the 3-D state the decoder's GRU takes — the
            # reshaping still has to be worked out.
            decoder_hidden_stack = torch.stack(all_decoder_hidden).transpose(0,1).contiguous()
            encoder_outputs_stack =  beams[0].encoder_outputs # (L,1,H), from the utterance encoder

            # The original loop stopped here and never advanced `steps`, so
            # it could not terminate. Fail loudly instead.
            raise NotImplementedError(
                "BeamSearchDecoder.beam_search: beam expansion step is not implemented"
            )

In [15]:
def evaluate(encoder, sec_encoder, decoder, searcher, voc, sentence, wikiSec, max_length=MAX_LENGTH):
    """Generate a reply for one normalised sentence given one wiki section.

    Args:
        encoder, sec_encoder, decoder: Accepted but not used directly; the
            models are reached through `searcher`.
        searcher: Callable taking (input_batch, lengths, sec_batch,
            sec_lengths, max_length) and returning (tokens, scores).
        voc: Vocabulary used to index the text and decode the result.
        sentence: Normalised input sentence.
        wikiSec: Normalised wiki section text.
        max_length: Number of tokens to decode.

    Returns:
        List of decoded words.

    Raises:
        KeyError: If `sentence` or `wikiSec` contains a word missing from
            the vocabulary.
    """
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)] #(1,L)
    sec_indexes = [indexesFromSentence(voc,wikiSec)]
    # Lengths stay on the CPU: they are only used by pack_padded_sequence,
    # which requires a lengths tensor to be on the CPU. The original moved
    # them to `device`.
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch]) #(1,)
    sec_lengths = torch.tensor([len(indexes) for indexes in sec_indexes])
    # Transpose dimensions of batch to match models' expectations
    input_batch = torch.LongTensor(indexes_batch).transpose(0, 1) #(L,1)
    sec_batch = torch.LongTensor(sec_indexes).transpose(0, 1)
    # Use appropriate device
    input_batch = input_batch.to(device)
    sec_batch = sec_batch.to(device)
    # Decode sentence with searcher
    tokens, scores = searcher(input_batch, lengths, sec_batch, sec_lengths, max_length)
    # indexes -> words
    decoded_words = [voc.index2word[token.item()] for token in tokens]
    return decoded_words


def evaluateInput(encoder, sec_encoder, decoder, searcher, voc, wiki_strings):
    """Run an interactive chat loop on stdin/stdout.

    Each round prompts for a document index and a section index (used to pick
    ``wiki_strings[doc_idx][sec_idx]``), then for an input sentence, and
    prints the decoded reply prefixed with ``Bot:``.

    Entering ``q`` or ``quit`` at any of the three prompts ends the loop.
    A non-integer or out-of-range index is reported and the round restarts
    instead of aborting the whole session.

    Args:
        encoder, sec_encoder, decoder, searcher, voc: Forwarded unchanged to
            ``evaluate``.
        wiki_strings: Container indexed as ``wiki_strings[doc_idx][sec_idx]``
            that yields the section text for a round.

    Returns:
        None.
    """
    quit_commands = ('q', 'quit')

    while True:
        # --- choose the wiki section for this round ---
        doc_answer = input('document index:').strip()
        if doc_answer in quit_commands:
            break
        sec_answer = input('section index:').strip()
        if sec_answer in quit_commands:
            break
        try:
            sec_sentence = wiki_strings[int(doc_answer)][int(sec_answer)]
        except (ValueError, IndexError, KeyError):
            # ValueError: not an integer; IndexError/KeyError: no such entry
            print("Error: Invalid document or section index.")
            continue

        # --- read the user's sentence ---
        input_sentence = input('> ')
        if input_sentence in quit_commands:
            break

        try:
            # Normalize, then decode a reply
            input_sentence = normalizeString(input_sentence)
            output_words = evaluate(encoder, sec_encoder, decoder, searcher, voc, input_sentence, sec_sentence)
        except KeyError:
            print("Error: Encountered unknown word.")
            continue

        # Drop the special end/padding markers before printing
        output_words = [x for x in output_words if x not in ('EOS', 'PAD')]
        print('Bot:', ' '.join(output_words))

In [18]:
# Configure models
model_name = 'cb_model'
attn_model = 'dot'
#attn_model = 'general'
#attn_model = 'concat'
hidden_size = 300
embedding_size = 100
encoder_n_layers = 2
decoder_n_layers = 1
dropout = 0.3
batch_size = 64

# Set checkpoint to load from; set to None if starting from scratch
loadFilename = None
checkpoint_iter = 4000
#loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                            '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                            '{}_checkpoint.tar'.format(checkpoint_iter))


# Load model if a loadFilename is provided
if loadFilename:
    # map_location=device remaps every stored tensor onto the device chosen in
    # the first cell, so a checkpoint written on a GPU also loads on a CPU-only
    # machine (or on a machine with a different GPU layout).
    checkpoint = torch.load(loadFilename, map_location=device)
    encoder_sd = checkpoint['en']
    sec_encoder_sd = checkpoint['sec_en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    sec_encoder_optimizer_sd = checkpoint['sec_en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    # Replace the vocabulary's attributes wholesale with the saved ones
    voc.__dict__ = checkpoint['voc_dict']


print('Building encoder and decoder ...')
# Initialize word embeddings; the same embedding module is handed to both
# encoders and to the decoder below
embedding = nn.Embedding(voc.num_words, embedding_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
# Initialize encoder & decoder models
encoder = EncoderRNN(embedding_size,hidden_size, embedding, encoder_n_layers, dropout)
sec_encoder = EncoderRNN(embedding_size,hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding,embedding_size, encoder_n_layers,hidden_size, voc.num_words, decoder_n_layers, dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    sec_encoder.load_state_dict(sec_encoder_sd)
    decoder.load_state_dict(decoder_sd)
# Use appropriate device
encoder = encoder.to(device)
sec_encoder = sec_encoder.to(device)
decoder = decoder.to(device)
print('Models built and ready to go!')

Building encoder and decoder ...
Models built and ready to go!


In [19]:
# Training / optimization hyper-parameters
clip = 20.0                     # handed to trainIters below
teacher_forcing_ratio = 1       # not passed below; presumably read as a global by the training code -- TODO confirm
learning_rate, decoder_learning_ratio = 0.0001, 3.0
n_iteration = 30000
print_every, save_every = 100, 500

# One Adam optimizer per network; the decoder's learning rate is the base
# rate scaled by decoder_learning_ratio
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
sec_encoder_optimizer = optim.Adam(sec_encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(
    decoder.parameters(),
    lr=learning_rate * decoder_learning_ratio,
)

# When resuming from a checkpoint, restore each optimizer's saved state
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    sec_encoder_optimizer.load_state_dict(sec_encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)

# Dropout must be active while training, so switch every network to train mode
encoder.train()
sec_encoder.train()
decoder.train()

# Kick off the training loop
print("Starting Training!")
trainIters(
    model_name, voc, pairs, wiki_strings,
    encoder, sec_encoder, decoder,
    encoder_optimizer, sec_encoder_optimizer, decoder_optimizer,
    embedding, encoder_n_layers, decoder_n_layers,
    save_dir, n_iteration, batch_size,
    print_every, save_every, clip, corpus_name, loadFilename,
)

Building optimizers ...
Starting Training!
Initializing ...
Training...
Iteration: 100; Percent complete: 0.3%; Average loss: 6.6619
Iteration: 200; Percent complete: 0.7%; Average loss: 5.7941
Iteration: 300; Percent complete: 1.0%; Average loss: 5.6615
Iteration: 400; Percent complete: 1.3%; Average loss: 5.5169
Iteration: 500; Percent complete: 1.7%; Average loss: 5.4277
Iteration: 600; Percent complete: 2.0%; Average loss: 5.3053
Iteration: 700; Percent complete: 2.3%; Average loss: 5.1602
Iteration: 800; Percent complete: 2.7%; Average loss: 5.1044
Iteration: 900; Percent complete: 3.0%; Average loss: 5.0309
Iteration: 1000; Percent complete: 3.3%; Average loss: 4.9572
Iteration: 1100; Percent complete: 3.7%; Average loss: 4.9141
Iteration: 1200; Percent complete: 4.0%; Average loss: 4.8925
Iteration: 1300; Percent complete: 4.3%; Average loss: 4.8392
Iteration: 1400; Percent complete: 4.7%; Average loss: 4.8028
Iteration: 1500; Percent complete: 5.0%; Average loss: 4.7780
Iterati

Iteration: 13100; Percent complete: 43.7%; Average loss: 3.6621
Iteration: 13200; Percent complete: 44.0%; Average loss: 3.6487
Iteration: 13300; Percent complete: 44.3%; Average loss: 3.6334
Iteration: 13400; Percent complete: 44.7%; Average loss: 3.6191
Iteration: 13500; Percent complete: 45.0%; Average loss: 3.6021
Iteration: 13600; Percent complete: 45.3%; Average loss: 3.6255
Iteration: 13700; Percent complete: 45.7%; Average loss: 3.6198
Iteration: 13800; Percent complete: 46.0%; Average loss: 3.5954
Iteration: 13900; Percent complete: 46.3%; Average loss: 3.5918
Iteration: 14000; Percent complete: 46.7%; Average loss: 3.5905
Iteration: 14100; Percent complete: 47.0%; Average loss: 3.5924
Iteration: 14200; Percent complete: 47.3%; Average loss: 3.5920
Iteration: 14300; Percent complete: 47.7%; Average loss: 3.5719
Iteration: 14400; Percent complete: 48.0%; Average loss: 3.5749
Iteration: 14500; Percent complete: 48.3%; Average loss: 3.5974
Iteration: 14600; Percent complete: 48.7

Iteration: 26000; Percent complete: 86.7%; Average loss: 3.2042
Iteration: 26100; Percent complete: 87.0%; Average loss: 3.2171
Iteration: 26200; Percent complete: 87.3%; Average loss: 3.2186
Iteration: 26300; Percent complete: 87.7%; Average loss: 3.1725
Iteration: 26400; Percent complete: 88.0%; Average loss: 3.1854
Iteration: 26500; Percent complete: 88.3%; Average loss: 3.1973
Iteration: 26600; Percent complete: 88.7%; Average loss: 3.1571
Iteration: 26700; Percent complete: 89.0%; Average loss: 3.1927
Iteration: 26800; Percent complete: 89.3%; Average loss: 3.1708
Iteration: 26900; Percent complete: 89.7%; Average loss: 3.1778
Iteration: 27000; Percent complete: 90.0%; Average loss: 3.1752
Iteration: 27100; Percent complete: 90.3%; Average loss: 3.1813
Iteration: 27200; Percent complete: 90.7%; Average loss: 3.1813
Iteration: 27300; Percent complete: 91.0%; Average loss: 3.1691
Iteration: 27400; Percent complete: 91.3%; Average loss: 3.1627
Iteration: 27500; Percent complete: 91.7

In [20]:
# Inference mode for all three networks (dropout layers switch to eval behaviour)
decoder.eval()
sec_encoder.eval()
encoder.eval()

# Greedy decoding wrapper built around the three networks
searcher = GreedySearchDecoder(encoder, sec_encoder, decoder)

# Start the interactive chat loop; enter 'q' or 'quit' at the '> ' prompt to stop
evaluateInput(
    encoder, sec_encoder, decoder,
    searcher, voc, wiki_strings,
)

document index:5
section index:0
> what do you think of inception ?
Bot: i like the movie . about a bit of a few action . a great director . a great director .
document index:22
section index:2
> i love you
Bot: i am not sure . i like the movie and the story of the movie . the original and the plot .
document index:23
section index:2
> do you love me ?
Bot: yes i do like the movie the music the movie the animation the animation the animation . a
document index:23
section index:0
> you are stupid 
Bot: i am good movies my kids . kids and adults . kids are good
document index:12
section index:2
> aha
Bot: i think you are right the conversation
document index:12
section index:2
> you will not get hurt
Bot: i think i will be interested in watching it . like it . like that . man . . .
document index:21
section index:0
> do you love me ?
Bot: i do not think i ve seen it
document index:15 
section index:0
> do you think i am ok ?
Bot: i think it s a great movie . i think it was a great movie 

In [None]:
import torch

# Scratch check: sort a vector, then undo the sort with the inverse permutation.
l1 = torch.randint(10, (10,))

# l2 holds the values in descending order; idx1[k] is the position in l1 of
# the k-th sorted value
l2, idx1 = l1.sort(descending=True)

# Gathering l1 with idx1 reproduces the sorted vector
a = l1[idx1]

# Arg-sorting the permutation yields its inverse
idx2 = torch.argsort(idx1)

print(l1)
print(a)
# Gathering the sorted values with the inverse permutation restores l1's order
print(l2[idx2])

In [None]:
import torch.nn as nn
import torch.nn.functional as F

# rnn = nn.GRU(10, 20, 2,bidirectional=True)# 2 layers, bidirectional, 10 embedding, 20 hidden
# input = torch.randn(5, 3, 10)# sentence length 5, batch 3, embedding 10
# h0 = torch.randn(4, 3, 20) # layers 2 * bidirection, batch 3, hidden 20
# output, hn = rnn(input, h0) # sentence length 5, batch 3, hidden 20 * bidirection
# print(output.shape)
# print(hn.shape) # layers 2 * bidirection, batch 3, hidden 20
# decoder_input = torch.LongTensor([[3 for _ in range(10)]])
# print(decoder_input.shape)

# Scratch check of the tensor shapes produced by a dot-product score + softmax.
hidden = torch.randn(1, 5, 10)
outputs = torch.randn(3, 5, 10)

# Broadcast multiply, then reduce over the last axis -> shape (3, 5)
attn = (hidden * outputs).sum(dim=2)
print(attn.shape)

# Swap the two axes -> shape (5, 3)
attn = attn.transpose(0, 1)
print(attn.shape)

# Normalize along dim 1 and insert a singleton axis -> shape (5, 1, 3)
r = torch.softmax(attn, dim=1).unsqueeze(1)
print(r.shape)