In [3]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

In [4]:
import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math

In [5]:
USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda' if USE_CUDA else 'cpu')

In [6]:
device

device(type='cpu')

In [7]:
corpus_name = 'cornell movie-dialogs corpus'
corpus = os.path.join('data', corpus_name)

In [8]:
def printLines(file, n=10):
    with open(file, 'rb') as datafile:
        lines = datafile.readlines()
    for line in lines[:n]:
        print(line)

In [10]:
printLines(os.path.join(corpus, 'movie_lines.txt'))

b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


In [11]:
def loadLines(fileName, fields):
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(' +++$+++ ')
            lineObj = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj
    return lines

In [12]:
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
lines = loadLines(os.path.join(corpus, 'movie_lines.txt'), MOVIE_LINES_FIELDS)

In [13]:
lines['L1045']

{'lineID': 'L1045',
 'characterID': 'u0',
 'movieID': 'm0',
 'character': 'BIANCA',
 'text': 'They do not!\n'}

In [14]:
def loadConversations(fileName, lines, fields):
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(' +++$+++ ')
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            lineIds = eval(convObj['utteranceIDs'])
            convObj['lines'] = []
            for lineId in lineIds:
                convObj['lines'].append(lines[lineId])
            conversations.append(convObj)
    return conversations

In [15]:
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

In [16]:
conversations = loadConversations(os.path.join(corpus, 'movie_conversations.txt'),
                                  lines, MOVIE_CONVERSATIONS_FIELDS)

In [17]:
conversations[0]

{'character1ID': 'u0',
 'character2ID': 'u2',
 'movieID': 'm0',
 'utteranceIDs': "['L194', 'L195', 'L196', 'L197']\n",
 'lines': [{'lineID': 'L194',
   'characterID': 'u0',
   'movieID': 'm0',
   'character': 'BIANCA',
   'text': 'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\n'},
  {'lineID': 'L195',
   'characterID': 'u2',
   'movieID': 'm0',
   'character': 'CAMERON',
   'text': "Well, I thought we'd start with pronunciation, if that's okay with you.\n"},
  {'lineID': 'L196',
   'characterID': 'u0',
   'movieID': 'm0',
   'character': 'BIANCA',
   'text': 'Not the hacking and gagging and spitting part.  Please.\n'},
  {'lineID': 'L197',
   'characterID': 'u2',
   'movieID': 'm0',
   'character': 'CAMERON',
   'text': "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"}]}

In [18]:
def extractSentencePairs(conversations):
    qa_pairs = []
    for conversation in conversations:
        for i in range(len(conversation['lines']) - 1):
            inputLine = conversation['lines'][i]['text'].strip()
            targetLine = conversation['lines'][i + 1]['text'].strip()
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs

In [19]:
datafile = os.path.join(corpus, 'formatted_movie_lines.txt')

In [20]:
delimiter = '\t'
delimiter = str(codecs.decode(delimiter, 'unicode_escape'))

In [21]:
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter)
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)

In [22]:
pairs = extractSentencePairs(conversations)

In [23]:
printLines(datafile)

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\r\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\r\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\r\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\r\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\r\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\r\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough...\r\n"
b'Why?\tUnsolved myster

## Load and trim data

In [24]:
PAD_token = 0
SOS_token = 1
EOS_token = 2

In [25]:
class Voc:
    
    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: 'PAD', SOS_token: 'SOS', EOS_token: 'EOS'}
        self.num_words = 3
    
    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)
    
    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1
    
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed = True
        
        keep_words = []
        
        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)
        
        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index),
            len(keep_words) / len(self.word2index)))
        
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3
        
        for word in keep_words:
            self.addWord(word)

In [26]:
MAX_LENGTH = 10

In [27]:
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

In [28]:
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

In [29]:
def readVocs(datafile, corpus_name):
    print('Reading lines...')
    lines = open(datafile, encoding='utf-8').read().strip().split('\n')
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs

In [30]:
def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

In [31]:
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

In [32]:
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    print('Start preparing training data ...')
    voc, pairs = readVocs(datafile, corpus_name)
    print('Read {!s} sentence pairs'.format(len(pairs)))
    pairs = filterPairs(pairs)
    print('Trimmed to {!s} sentence pairs'.format(len(pairs)))
    print('Counting words ...')
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print('Counted words:', voc.num_words)
    return voc, pairs

In [33]:
save_dir = os.path.join('data', 'save')

In [34]:
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)

Start preparing training data ...
Reading lines...
Read 221282 sentence pairs
Trimmed to 64271 sentence pairs
Counting words ...
Counted words: 18008


In [35]:
for pair in pairs[:10]:
    print(pair)

['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [36]:
MIN_COUNT = 3

def trimRareWords(voc, pairs, MIN_COUNT):
    voc.trim(MIN_COUNT)
    
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break
        
        if keep_input and keep_output:
            keep_pairs.append(pair)
    
    print('Trimmed from {} pairs to {}, {:.4f} of total'.format(
        len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)
    ))
    
    return keep_pairs

In [37]:
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 7823 / 18005 = 0.4345
Trimmed from 64271 pairs to 53165, 0.8272 of total


## Prepare Data for Models

In [38]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]

In [39]:
def zeroPadding(l, fillvalue=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

In [40]:
def binaryMatrix(l, value=PAD_token):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

In [41]:
l = ["hi", "good morning sir", "thank you"]

In [42]:
indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
indexes_batch

[[16, 2], [51, 665, 572, 2], [383, 7, 2]]

In [43]:
max_target_len = max([len(indexes) for indexes in indexes_batch])
max_target_len

4

In [44]:
padList = zeroPadding(indexes_batch)
padList

[(16, 51, 383), (2, 665, 7), (0, 572, 2), (0, 2, 0)]

In [45]:
list(itertools.zip_longest(*indexes_batch))

[(16, 51, 383), (2, 665, 7), (None, 572, 2), (None, 2, None)]

In [46]:
def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

In [47]:
mask = binaryMatrix(padList)
mask

[[1, 1, 1], [1, 1, 1], [0, 1, 1], [0, 1, 0]]

In [48]:
mask = torch.ByteTensor(mask)
mask

tensor([[1, 1, 1],
        [1, 1, 1],
        [0, 1, 1],
        [0, 1, 0]], dtype=torch.uint8)

In [49]:
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.ByteTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

In [50]:
def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key=lambda x: len(x[0].split(' ')), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len

In [51]:
small_batch_size = 5

In [52]:
bs = [random.choice(pairs) for _ in range(small_batch_size)]
bs

[['oh boy ! a jacket !', 'your mom made that all by herself .'],
 ['yes .', 'nobody . accounting .'],
 ['what s the matter with you ?', 'toothache .'],
 ['what are you working on ?', 'just a little experiment .'],
 ['feel okay ?', 'yeah .']]

In [53]:
bs.sort(key=lambda x: len(x[0].split(' ')), reverse=True)
bs

[['what s the matter with you ?', 'toothache .'],
 ['oh boy ! a jacket !', 'your mom made that all by herself .'],
 ['what are you working on ?', 'just a little experiment .'],
 ['feel okay ?', 'yeah .'],
 ['yes .', 'nobody . accounting .']]

In [54]:
batches = batch2TrainData(voc, bs)

In [55]:
batches

(tensor([[  50,  124,   50, 1026,  318],
         [  37,  519,   92,   62,    4],
         [  53,   66,    7,    6,    2],
         [ 341,   12,  937,    2,    0],
         [ 169, 2468,  177,    0,    0],
         [   7,   66,    6,    0,    0],
         [   6,    2,    2,    0,    0],
         [   2,    0,    0,    0,    0]]),
 tensor([8, 7, 7, 4, 3]),
 tensor([[6772,   70,  112,  167,  898],
         [   4, 1119,   12,    4,    4],
         [   2,  782,  201,    2, 5288],
         [   0,   36, 2395,    0,    4],
         [   0,   38,    4,    0,    2],
         [   0,  234,    2,    0,    0],
         [   0,  173,    0,    0,    0],
         [   0,    4,    0,    0,    0],
         [   0,    2,    0,    0,    0]]),
 tensor([[1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [0, 1, 1, 0, 1],
         [0, 1, 1, 0, 1],
         [0, 1, 1, 0, 0],
         [0, 1, 0, 0, 0],
         [0, 1, 0, 0, 0],
         [0, 1, 0, 0, 0]], dtype=torch.uint8),
 9)

In [56]:
input_variable, length, target_variable, mask, max_length_len = batches

In [57]:
input_variable

tensor([[  50,  124,   50, 1026,  318],
        [  37,  519,   92,   62,    4],
        [  53,   66,    7,    6,    2],
        [ 341,   12,  937,    2,    0],
        [ 169, 2468,  177,    0,    0],
        [   7,   66,    6,    0,    0],
        [   6,    2,    2,    0,    0],
        [   2,    0,    0,    0,    0]])

In [58]:
length

tensor([8, 7, 7, 4, 3])

In [59]:
target_variable

tensor([[6772,   70,  112,  167,  898],
        [   4, 1119,   12,    4,    4],
        [   2,  782,  201,    2, 5288],
        [   0,   36, 2395,    0,    4],
        [   0,   38,    4,    0,    2],
        [   0,  234,    2,    0,    0],
        [   0,  173,    0,    0,    0],
        [   0,    4,    0,    0,    0],
        [   0,    2,    0,    0,    0]])

In [60]:
mask

tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [0, 1, 1, 0, 1],
        [0, 1, 1, 0, 1],
        [0, 1, 1, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0]], dtype=torch.uint8)

In [61]:
max_length_len

9

## Define Models

- input_seqはembeddingする前なので2D tensor (max_length, batch_size)

In [118]:
class EncoderRNN(nn.Module):
    
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout),
                          bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        embedded = self.embedding(input_seq)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden

In [119]:
# Luong attention layer
class Attn(torch.nn.Module):
    
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, 'is not an appropriate attention method.')
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))
        
    # encoder_outupt = (seq_len, batch, hidden_size)
    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)

        attn_energies = attn_energies.t()

        return F.softmax(attn_energies, dim=1).unsqueeze(1)

In [120]:
class LuongAttnDecoderRNN(nn.Module):
    
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()
        
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout
        
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.attn = Attn(attn_model, hidden_size)
    
    def forward(self, input_step, last_hidden, encoder_outputs):
        # 1単語ずつ入力する
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        
        rnn_output, hidden = self.gru(embedded, last_hidden)
        
        # 現在のRNNの出力とエンコーダーの出力を使う
        attn_weights = self.attn(rnn_output, encoder_outputs)
        
        # エンコーダーの出力をattention重みで重みづけしてコンテキストベクトルとする
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        
        return output, hidden

## Masked loss

In [121]:
def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = - torch.log(torch.gather(inp, 1, target.view(-1, 1)))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

## Single training iteration

In [135]:
def train(input_variable, lengths, target_variable,
          mask, max_target_len, encoder, decoder, embedding,
          encoder_optimizer, decoder_optimizer, batch_size, clip,
          max_length=MAX_LENGTH):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    
    input_variable = input_variable.to(device)
    lengths = lengths.to(device)
    target_variable = target_variable.to(device)
    mask = mask.to(device)
    
    loss = 0
    print_losses = []
    n_totals = 0

    # forward pass through encoder
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)

    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)
    
    decoder_hidden = encoder_hidden[:decoder.n_layers]
    
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            decoder_input = target_variable[t].view(1, -1)
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    
    loss.backward()
    
    # clip gradients
    _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)
    
    encoder_optimizer.step()
    decoder_optimizer.step()
    
    return sum(print_losses) / n_totals          

In [123]:
def trainIters(model_name, voc, pairs, encoder, decoder,
               encoder_optimizer, decoder_optimizer,
               embedding, encoder_n_layers, decoder_n_layers,
               save_dir, n_iteration, batch_size,
               print_every, save_every, clip, corpus_name, loadFilename):
    # 全iterationのバッチをまとめて生成（メモリ大丈夫？）
    training_batches = [batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)]) for _ in range(n_iteration)]
    
    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        start_iteration = checkpoint['iteration'] + 1
    
    print('Training ...')
    for iteration in range(start_iteration, n_iteration + 1):
        training_batch = training_batches[iteration - 1]
        
        input_variable, lengths, target_variable, mask, max_target_len = training_batch
        
        loss = train(input_variable, lengths, target_variable, mask,
                     max_target_len, encoder, decoder, embedding,
                     encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss
        
        # print progress
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print('Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}'.format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0
        
        # save checkpoint
        if iteration % save_every == 0:
            directory = os.path.join(save_dir, model_name, corpus_name,
                                     '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size))
            os.makedirs(directory, exist_ok=True)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

## Run Model

In [124]:
# configure models
model_name = 'cb_model'
attn_model = 'dot'
hidden_size = 500
encoder_n_layers = 2
decoder_n_layers = 2
dropout = 0.1
batch_size = 64

loadFilename = None
checkpoint_iter = 4000

# loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                             '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, hidden_size),
#                             '{}_checkpoint.tar'.format(checkpoint_iter))

In [125]:
loadFilename

In [126]:
if loadFilename:
    checkpoint = torch.load(loadFilename)
    #checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
    
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    encoder_optimizer_sd = checkpoint['en_opt']
    decoder_optimizer_sd = checkpoint['de_opt']
    embedding_sd = checkpoint['embedding']
    voc.__dict__ = checkpoint['voc_dict']

- 7826個の単語を500次元の蜜ベクトルにEmbeddingする

In [127]:
print('Building encoder and decoder ...')

# initialize word embeddings
embedding = nn.Embedding(voc.num_words, hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)

Building encoder and decoder ...


In [128]:
embedding

Embedding(7826, 500)

In [129]:
# initialize encoder & decoder models
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
encoder

EncoderRNN(
  (embedding): Embedding(7826, 500)
  (gru): GRU(500, 500, num_layers=2, dropout=0.1, bidirectional=True)
)

In [130]:
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size,
                              voc.num_words, decoder_n_layers, dropout)
decoder

LuongAttnDecoderRNN(
  (embedding): Embedding(7826, 500)
  (embedding_dropout): Dropout(p=0.1)
  (gru): GRU(500, 500, num_layers=2, dropout=0.1)
  (concat): Linear(in_features=1000, out_features=500, bias=True)
  (out): Linear(in_features=500, out_features=7826, bias=True)
  (attn): Attn()
)

In [131]:
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
encoder = encoder.to(device)
decoder = decoder.to(device)

## Run Training

In [132]:
# configure training/optimization
clip = 50.0
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0
n_iteration = 4000
print_every = 1
save_every = 500

encoder.train()
decoder.train()

print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)

if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)

Building optimizers ...


In [133]:
print(model_name)
print(voc)
print(pairs[:3])
print(encoder)
print(decoder)
print(embedding)
print(encoder_n_layers, decoder_n_layers)
print(save_dir)
print(n_iteration)
print(print_every, save_every)
print(clip)
print(corpus_name)
print(loadFilename)

cb_model
<__main__.Voc object at 0x12a4776a0>
[['there .', 'where ?'], ['you have my word . as a gentleman', 'you re sweet .'], ['hi .', 'looks like things worked out tonight huh ?']]
EncoderRNN(
  (embedding): Embedding(7826, 500)
  (gru): GRU(500, 500, num_layers=2, dropout=0.1, bidirectional=True)
)
LuongAttnDecoderRNN(
  (embedding): Embedding(7826, 500)
  (embedding_dropout): Dropout(p=0.1)
  (gru): GRU(500, 500, num_layers=2, dropout=0.1)
  (concat): Linear(in_features=1000, out_features=500, bias=True)
  (out): Linear(in_features=500, out_features=7826, bias=True)
  (attn): Attn()
)
Embedding(7826, 500)
2 2
data/save
4000
1 500
50.0
cornell movie-dialogs corpus
None


In [136]:
trainIters(model_name, voc, pairs,
           encoder, decoder,
           encoder_optimizer, decoder_optimizer,
           embedding,
           encoder_n_layers, decoder_n_layers,
           save_dir, n_iteration, batch_size,
           print_every, save_every, clip,
           corpus_name, loadFilename)

Initializing ...
Training ...
Iteration: 1; Percent complete: 0.0%; Average loss: 8.9574
Iteration: 2; Percent complete: 0.1%; Average loss: 8.7525
Iteration: 3; Percent complete: 0.1%; Average loss: 8.4271
Iteration: 4; Percent complete: 0.1%; Average loss: 8.1515
Iteration: 5; Percent complete: 0.1%; Average loss: 7.5345
Iteration: 6; Percent complete: 0.1%; Average loss: 7.0011
Iteration: 7; Percent complete: 0.2%; Average loss: 6.5914
Iteration: 8; Percent complete: 0.2%; Average loss: 6.7082
Iteration: 9; Percent complete: 0.2%; Average loss: 6.3685
Iteration: 10; Percent complete: 0.2%; Average loss: 6.1353
Iteration: 11; Percent complete: 0.3%; Average loss: 6.0705
Iteration: 12; Percent complete: 0.3%; Average loss: 5.8791
Iteration: 13; Percent complete: 0.3%; Average loss: 5.7090
Iteration: 14; Percent complete: 0.4%; Average loss: 4.8764
Iteration: 15; Percent complete: 0.4%; Average loss: 5.0021
Iteration: 16; Percent complete: 0.4%; Average loss: 5.0650
Iteration: 17; Perc

KeyboardInterrupt: 