In [1]:
import time, random
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("use: " + str(device))

use: cpu


In [2]:
# --- Training hyperparameters ---
n_epoch = 100               # number of passes over the training data
batch_size = 4              # sequences per training mini-batch
learning_rate = 1e-4        # learning rate given to the optimizer
teacher_forcing_rate = 0.0  # per-step probability of feeding the target token to the decoder (0.0: never)

## Toy Problem (Addition Dataset)
3桁以下の自然数同士の足し算を学習させる。
情報源アルファベットは
- 数字（0~9）
- 加算記号（+）
- 特殊記号
    - PAD( )：パディング用の記号
    - BOS(_)：出力開始記号
    - EOS(.)：出力終了記号
の14種類

### ミニバッチ学習とパディング
ミニバッチ学習ではエンコーダーに`(seq_length, batch_size)`サイズのテンソルを入力して、デコーダーではエンコーダーから得られたコンテキストベクトルと、用意したBOS記号を入力として学習する。

このエンコーダーへの入力において、各データごとに系列長(`seq_length`)が異なる（例えば`3+3=6`,  `111+111=222`だと前者は`seq_length=3`、後者は`seq_length=7`）。しかし、`torch.Tensor`は行ごとに違う長さの配列を扱えないため、各ミニバッチ内で系列長を揃える必要がある（データセット全体で揃えても良いが、それでは余計な計算コストが増える）。そのためにパディングを行う。

パディングをするインデックスは`PAD`変数に記録して誤差関数の算出段階などでそのインデックスを無視する(`ignore_index=PAD`)

In [3]:
# Assign an integer id to every symbol of the source alphabet.
# The four special symbols take ids 0-3; the digits "0".."9" follow at ids 4-13.
PAD_TOKEN, PAD = " ", 0
BOS_TOKEN, BOS = "_", 1
EOS_TOKEN, EOS = ".", 2
ADD_TOKEN, ADD = "+", 3

alphabet_to_id = {
    PAD_TOKEN: PAD,
    BOS_TOKEN: BOS,
    EOS_TOKEN: EOS,
    ADD_TOKEN: ADD,
}
# Digit ids are shifted by the four special symbols registered above.
for i in range(10):
    alphabet_to_id[str(i)] = i + 4

# Reverse lookup (id -> symbol), derived from the forward table.
id_to_alphabet = {symbol_id: symbol for symbol, symbol_id in alphabet_to_id.items()}

def seq_to_ids(seq):
    """Map each character of ``seq`` to its id, dropping padding characters.

    Raises:
        KeyError: if ``seq`` contains a character missing from ``alphabet_to_id``.
    """
    ids = []
    for symbol in seq:
        symbol_id = alphabet_to_id[symbol]
        if symbol_id != PAD:
            ids.append(symbol_id)
    return ids

def ids_to_str(ids):
    """Decode a sequence of symbol ids back into a string.

    Args:
        ids: iterable of ids. Elements may be single-element tensors (what
            iterating over a 1-D tensor yields) or plain Python integers.

    Returns:
        The concatenation of the symbols found in ``id_to_alphabet``;
        an empty string for an empty input.

    Raises:
        KeyError: if an id has no entry in ``id_to_alphabet``.
    """
    # int() accepts both single-element tensors and plain ints; the previous
    # tensor-only .item() failed on the lists returned by seq_to_ids.
    return "".join(id_to_alphabet[int(i)] for i in ids)


# Read the text file and turn every "<question>_<answer>" line into a pair of id lists.
with open("./data/addition.txt", 'r') as f:
    lines = f.readlines()

X, Y = [], []
for line in lines:
    # Remove only the line terminator. Slicing off the last character would
    # truncate the answer when the final line has no trailing newline.
    line = line.rstrip("\n")
    if not line:
        continue  # tolerate blank lines instead of failing on the unpack below
    x, y = line.split("_")
    X.append(seq_to_ids(x))
    # Every target sequence is terminated with EOS.
    Y.append(seq_to_ids(y + EOS_TOKEN))

# Hold out 10% of the pairs for evaluation.
train_X, test_X, train_Y, test_Y = train_test_split(X, Y, test_size=0.1)

### パディングを適用したデータローダー

In [4]:
# max_length = 5
class Dataloader(object):
    """Iterate over (X, Y) sequence pairs in PAD-padded mini-batches.

    Each step yields ``(seqs_X, seqs_Y, lengths_X)``: two long tensors of shape
    ``(seq_length, batch_size)`` padded with ``PAD``, plus the unpadded input
    lengths in descending order. The last batch of an epoch may hold fewer
    than ``batch_size`` pairs.
    """

    def __init__(self, X, Y, batch_size, shuffle=False):
        """
        Args:
            X: list of input id lists.
            Y: list of target id lists, aligned with ``X``.
            batch_size: maximum number of pairs per batch.
            shuffle: if true, the pairs are reshuffled each time an epoch completes.
        """
        self.data = list(zip(X, Y))
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.pointer = 0

    def __iter__(self):
        # Always restart from the first batch. Without this reset, a loop that
        # was abandoned mid-epoch (break, exception, KeyboardInterrupt) would
        # make the next ``for`` silently resume from the stale position.
        self.pointer = 0
        return self

    def __next__(self):
        if self.pointer >= len(self.data):
            if self.shuffle:
                # ``shuffle`` here is the module-level function, not the flag.
                self.data = shuffle(self.data)
            self.pointer = 0
            raise StopIteration()

        batch = self.data[self.pointer:self.pointer + self.batch_size]

        # Longest input first: pack_padded_sequence expects lengths sorted in
        # descending order by default.
        batch = sorted(batch, key=lambda pair: len(pair[0]), reverse=True)
        seqs_X, seqs_Y = zip(*batch)

        # Right-pad every sequence up to the longest one in this batch.
        lengths_X = [len(s) for s in seqs_X]
        max_length_X = max(lengths_X)
        max_length_Y = max(len(s) for s in seqs_Y)
        seqs_X = [s + [PAD] * (max_length_X - len(s)) for s in seqs_X]
        seqs_Y = [s + [PAD] * (max_length_Y - len(s)) for s in seqs_Y]

        # Transpose to (seq_length, batch_size).
        # (torch's pad_sequence could do the padding as well, but it needs a
        # per-sequence tensor conversion loop, which the original author
        # noted as slower.)
        seqs_X = torch.tensor(seqs_X, dtype=torch.long, device=device).transpose(0, 1)
        seqs_Y = torch.tensor(seqs_Y, dtype=torch.long, device=device).transpose(0, 1)

        self.pointer += self.batch_size
        return seqs_X, seqs_Y, lengths_X  # lengths_X is consumed by pack_padded_sequence
    
# Training batches use the configured batch size; evaluation uses fixed batches of 100.
# Neither loader shuffles its data.
trainloader = Dataloader(X=train_X, Y=train_Y, batch_size=batch_size, shuffle=False)
testloader = Dataloader(X=test_X, Y=test_Y, batch_size=100, shuffle=False)

In [5]:
# Model dimensions.
n_alphabet = len(alphabet_to_id)  # number of distinct symbols (vocabulary size)
embed_dim = 16                    # embedding dimension
hid_dim = 128                     # hidden-state dimension

In [6]:
class Encoder(nn.Module):
    """Embed a padded batch of input ids and run it through an LSTM."""

    def __init__(self, n_alphabet, embed_dim, hid_dim):
        super().__init__()

        # padding_idx=PAD marks the PAD row as the padding embedding.
        self.embedding = nn.Embedding(
            num_embeddings=n_alphabet,
            embedding_dim=embed_dim,
            padding_idx=PAD,
        )
        self.lstm = nn.LSTM(input_size=embed_dim, hidden_size=hid_dim)

    def forward(self, xs, lengths_xs, hidden=None):
        """Encode a batch and return the LSTM's final state.

        Args:
            xs: padded input ids; the training loop passes (seq_length, batch_size).
            lengths_xs: unpadded length of each sequence, handed to
                pack_padded_sequence.
            hidden: optional initial LSTM state.

        Returns:
            The second value returned by the LSTM (its final state), used as
            the context for decoding.
        """
        packed = pack_padded_sequence(self.embedding(xs), lengths_xs)
        _, final_state = self.lstm(packed, hidden)
        return final_state

In [7]:
class Decoder(nn.Module):
    """Embed input ids, advance an LSTM, and score every symbol of the alphabet."""

    def __init__(self, n_alphabet, embed_dim, hid_dim):
        super().__init__()

        self.embedding = nn.Embedding(
            num_embeddings=n_alphabet,
            embedding_dim=embed_dim,
            padding_idx=PAD,
        )
        self.lstm = nn.LSTM(input_size=embed_dim, hidden_size=hid_dim)
        # Projects each LSTM output onto one score per symbol.
        self.fc = nn.Linear(in_features=hid_dim, out_features=n_alphabet)

    def forward(self, x, hidden):
        """Run the decoder on ``x`` starting from ``hidden``.

        Returns:
            (score, hidden): per-symbol scores for every input position and
            the updated LSTM state.
        """
        lstm_out, next_hidden = self.lstm(self.embedding(x), hidden)
        return self.fc(lstm_out), next_hidden

In [8]:
class Seq2Seq(nn.Module):
    """Encoder-decoder model that decodes one symbol per step, starting from BOS."""

    def __init__(self, n_alphabet, embed_dim, hid_dim):
        super(Seq2Seq, self).__init__()
        self.encoder = Encoder(n_alphabet, embed_dim, hid_dim)
        self.decoder = Decoder(n_alphabet, embed_dim, hid_dim)
        self.n_alphabet = n_alphabet

    def forward(self, seqs_X, lengths_X, seqs_Y=None, max_length=None, teacher_forcing_rate=0.0, device=device):
        """Encode ``seqs_X`` and decode ``max_length`` steps.

        Args:
            seqs_X: padded input ids, (seq_length, batch_size).
            lengths_X: unpadded input lengths, forwarded to the encoder.
            seqs_Y: optional target ids, (target_length, batch_size). Used for
                teacher forcing and as the default number of decoding steps.
            max_length: number of decoding steps; defaults to ``seqs_Y.size(0)``.
            teacher_forcing_rate: per-step probability of feeding the target
                token instead of the model's own greedy prediction.
            device: device on which the score tensor and BOS input are created.

        Returns:
            Tensor of shape (max_length, batch_size, n_alphabet) holding the
            decoder scores of every step.

        Raises:
            ValueError: if neither ``seqs_Y`` nor ``max_length`` is given.
        """
        if max_length is None:
            if seqs_Y is None:
                # Previously this surfaced as an AttributeError on None.size.
                raise ValueError("either seqs_Y or max_length must be provided")
            max_length = seqs_Y.size(0)
        # Number of steps for which a target token exists at all.
        target_length = 0 if seqs_Y is None else seqs_Y.size(0)

        context = self.encoder(seqs_X, lengths_X)
        batch_size = seqs_X.size(1)

        # (max_length, batch_size, n_alphabet)
        scores = torch.zeros(max_length, batch_size, self.n_alphabet, device=device)

        # (1, batch_size): every sequence starts decoding from BOS.
        decoder_input = torch.ones(1, batch_size, dtype=torch.long, device=device) * BOS

        hidden = context
        for t in range(max_length):
            score, hidden = self.decoder(decoder_input, hidden)
            scores[t] = score.squeeze(0)
            # Teacher forcing is only possible while a target token exists;
            # when max_length exceeds the target length, fall back to the
            # model's own prediction instead of indexing past seqs_Y.
            if t < target_length and random.random() < teacher_forcing_rate:
                decoder_input = seqs_Y[t].unsqueeze(0)
            else:
                # Greedy decoding: feed back the highest-scoring symbol.
                decoder_input = torch.argmax(score, dim=-1)

        return scores

# Build the model, then move its parameters to the selected device.
model = Seq2Seq(n_alphabet, embed_dim, hid_dim)
model = model.to(device)

In [9]:
def criterion(scores, target):
    """Summed cross-entropy between scores and target ids, skipping PAD targets.

    Both arguments are flattened over their leading dimensions so that every
    position counts as one classification example.
    """
    n_classes = scores.size(-1)
    flat_scores = scores.view(-1, n_classes)
    flat_target = target.contiguous().view(-1)
    return F.cross_entropy(flat_scores, flat_target, ignore_index=PAD, reduction='sum')

In [10]:
# Adam over all model parameters, using the learning rate from the config cell.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [11]:
def print_result(x, y, pred):
    """Print one line per batch element: input, target and prediction as text.

    All three tensors are transposed first, so each row handed to
    ``ids_to_str`` is one sequence of the batch.
    """
    x_rows, y_rows, pred_rows = (t.transpose(1, 0) for t in (x, y, pred))
    for row in range(x_rows.size(0)):
        line = ids_to_str(x_rows[row]) + " = " + ids_to_str(y_rows[row]) + " pred:" + ids_to_str(pred_rows[row])
        print(line)

In [12]:
for epoch in range(n_epoch):
    print("epoch: {} / {} start".format(epoch + 1, n_epoch))
    start_at = time.time()

    # ---- training ----
    model.train()
    for data in trainloader:
        x, y, lengths_x = data[0].to(device), data[1].to(device), data[2]
        scores = model(x, lengths_x, y, teacher_forcing_rate=teacher_forcing_rate)
        loss = criterion(scores, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ---- evaluation ----
    accuracy = 0
    model.eval()
    # Scoring needs no gradients, so skip building the autograd graph.
    with torch.no_grad():
        for data in testloader:
            x, y, lengths_x = data[0].to(device), data[1].to(device), data[2]
            # y is passed only to fix the number of decoding steps
            # (teacher forcing stays at its default of 0.0).
            scores = model(x, lengths_x, seqs_Y=y)
            pred = torch.argmax(scores, dim=-1)

            # Targets: replace EOS with PAD so that only the answer symbols remain.
            y = y.masked_fill(y == EOS, PAD)
            # Predictions: blank the first EOS and everything after it. The loss
            # ignores PAD positions, so what the decoder emits after EOS is
            # untrained and must not count against an otherwise correct answer.
            after_eos = (pred == EOS).long().cumsum(dim=0) > 0
            pred = pred.masked_fill(after_eos, PAD)

            # A sequence is correct only if every position matches.
            accuracy += torch.sum(torch.sum(y == pred, dim=0) == y.size(0)).item()
    accuracy /= len(test_Y)

    print("accuracy: {:.1f}%  process time:{:.1f}".format(accuracy * 100, time.time() - start_at))
    # Show the decoded results of the last evaluation batch.
    print_result(x, y, pred)

epoch: 1 / 100 start
accuracy: 0.2%  process time:100.2
203+289 = 492   pred:311  
911+760 = 1671  pred:1111 
175+172 = 347   pred:611  
957+320 = 1277  pred:1111 
559+489 = 1048  pred:101  
335+877 = 1212  pred:1011 
715+746 = 1461  pred:1011 
129+811 = 940   pred:101  
879+909 = 1788  pred:1111 
422+440 = 862   pred:555  
818+977 = 1795  pred:1111 
871+373 = 1244  pred:1011 
566+697 = 1263  pred:1011 
235+828 = 1063  pred:101  
538+885 = 1423  pred:1011 
124+916 = 1040  pred:101  
294+150 = 444   pred:411  
294+900 = 1194  pred:1011 
589+647 = 1236  pred:101  
410+162 = 572   pred:511  
818+249 = 1067  pred:101  
773+327 = 1100  pred:101  
756+886 = 1642  pred:1611 
415+672 = 1087  pred:101  
472+820 = 1292  pred:1011 
258+589 = 847   pred:101  
30+643  = 673   pred:611  
305+76  = 381   pred:611  
54+746  = 800   pred:555  
56+172  = 228   pred:311  
146+53  = 199   pred:511  
25+998  = 1023  pred:101  
526+83  = 609   pred:101  
731+88  = 819   pred:101  
26+101  = 127   pred:211  

KeyboardInterrupt: 

In [13]:
# Write the model's state_dict (parameters only, not the module object) to disk.
torch.save(model.state_dict(), "./parameters.pth")