In [1]:
import torch
import torch.nn.functional as F
from collections import defaultdict
import random
import numpy as np

In [121]:
# Parallel corpus layout: line k of a source file is the translation of line k
# of the matching target file, and every line is "word1 word2 ..." (tokens
# separated by whitespace).
train_src_file, train_trg_file = "data/parallel/train.ja", "data/parallel/train.en"
dev_src_file, dev_trg_file = "data/parallel/dev.ja", "data/parallel/dev.en"

# Word -> integer id tables. A word that has not been seen yet is assigned the
# current table size as its id the first time it is looked up.
w2i_src = defaultdict(lambda: len(w2i_src))
w2i_trg = defaultdict(lambda: len(w2i_trg))
# The very first lookup reserves id 0 for "<s>" in each table.
w2i_src["<s>"]
w2i_trg["<s>"]


# Longest sentence (in tokens) seen so far on each side; -1 until data is read.
MAX_SRC_LENGTH = MAX_TRG_LENGTH = -1
def read(fname_src, fname_trg, tot = 1024):
    """
    Lazily read two line-aligned parallel files as sentences of word ids.

    Args:
        fname_src: path of the source-language file; one whitespace-tokenized
            sentence per line.
        fname_trg: path of the target-language file; line k is paired with
            line k of fname_src.
        tot: maximum number of sentence pairs to yield.

    Yields:
        (sent_src, sent_trg): two lists of integer ids obtained from the
        module-level w2i_src / w2i_trg tables.

    Side effects:
        Raises the module-level MAX_SRC_LENGTH / MAX_TRG_LENGTH to the length
        of the longest sentence seen. Because this is a generator, the tables
        and the maxima are only updated as pairs are actually consumed.
    """
    global MAX_SRC_LENGTH
    global MAX_TRG_LENGTH
    with open(fname_src, "r") as f_src, open(fname_trg, "r") as f_trg:
        # Pull one line from each file per step. Python 2's zip() would build
        # the full list of line pairs for both files before yielding anything,
        # even though at most `tot` pairs are wanted.
        trg_lines = iter(f_trg)
        n_pairs = 0
        for line_src in f_src:
            if n_pairs >= tot:
                break
            line_trg = next(trg_lines, None)
            if line_trg is None:
                # Target file ran out first: stop at the shorter file.
                break
            n_pairs += 1
            sent_src = [w2i_src[x] for x in line_src.strip().split()]
            MAX_SRC_LENGTH = max(len(sent_src), MAX_SRC_LENGTH)
            sent_trg = [w2i_trg[x] for x in line_trg.strip().split()]
            MAX_TRG_LENGTH = max(len(sent_trg), MAX_TRG_LENGTH)
            yield (sent_src, sent_trg)

# Read the data
train = list(read(train_src_file, train_trg_file))
unk_src = w2i_src["<unk>"]
w2i_src = defaultdict(lambda: unk_src, w2i_src)
unk_trg = w2i_trg["<unk>"]
w2i_trg = defaultdict(lambda: unk_trg, w2i_trg)
nwords_src = len(w2i_src)
nwords_trg = len(w2i_trg)
#dev = list(read(dev_src_file, dev_trg_file, 512))
print MAX_SRC_LENGTH, MAX_TRG_LENGTH
print nwords_src, nwords_trg

51 31
2153 2056


In [122]:
# Sanity check: show one (source ids, target ids) pair and the number of pairs loaded.
train[1], len(train)

(([10, 11, 12, 13, 14, 15, 16, 7, 2, 17, 18, 19, 20, 21, 22, 6, 9],
  [7, 8, 9, 10, 11, 12, 13, 14, 6]),
 1024)

In [123]:
def read_print(fname_src, fname_trg):
    """
    Read parallel files where each line lines up
    """
    with open(fname_src, "r") as f_src, open(fname_trg, "r") as f_trg:
        for line_src, line_trg in zip(f_src, f_trg):
            print line_src, line_trg
            break
read_print(train_src_file, train_trg_file)

ステーキ は 中位 で 焼 い て くださ い 。
i like my steak medium .



In [124]:
torch.manual_seed(7)

class TranslationRetrieval(torch.nn.Module):
    """
    Two independent embedding + LSTM encoders, one per language.

    Each side embeds a batch of word-id sequences and runs it through a
    single-layer LSTM; the final hidden state of each LSTM is returned as the
    sentence vector. Id 0 is the padding id on both sides (its embedding is
    fixed to zeros via padding_idx).
    """
    def __init__(self, src_vocab_len, trg_vocab_len, embedding_dim, hidden_dim, batch_size):
        """
        Args:
            src_vocab_len: number of rows of the source embedding table.
            trg_vocab_len: number of rows of the target embedding table.
            embedding_dim: size of each word embedding.
            hidden_dim: size of each LSTM hidden state (= sentence vector size).
            batch_size: batch size used by initialize() for the stored state.
        """
        super(TranslationRetrieval, self).__init__()

        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        self.src_embed = torch.nn.Embedding(src_vocab_len, embedding_dim, padding_idx = 0)
        self.trg_embed = torch.nn.Embedding(trg_vocab_len, embedding_dim, padding_idx = 0)
        self.src_lstm = torch.nn.LSTM(embedding_dim, hidden_dim)
        self.trg_lstm = torch.nn.LSTM(embedding_dim, hidden_dim)

        self.src_hidden = None
        self.trg_hidden = None

        self.initialize()

    def forward(self, src_inputs, trg_inputs):
        """
        Encode a batch of source and target sentences.

        Args:
            src_inputs: integer tensor of word ids laid out (seq_len, batch),
                the default (non batch_first) LSTM layout.
            trg_inputs: same layout for the target side.

        Returns:
            (src_out, trg_out): final hidden states of the two LSTMs, each of
            shape (1, batch, hidden_dim). The state is taken after the last
            time step, so trailing padding positions are consumed as well.
        """
        src_embed = self.src_embed(src_inputs)
        trg_embed = self.trg_embed(trg_inputs)

        src_state = self._initial_state(self.src_hidden, src_embed)
        trg_state = self._initial_state(self.trg_hidden, trg_embed)

        _, (src_out, _) = self.src_lstm(src_embed, src_state)
        _, (trg_out, _) = self.trg_lstm(trg_embed, trg_state)
        return src_out, trg_out

    def _initial_state(self, stored, embedded):
        """Return `stored` if it fits this batch, otherwise an equivalent zero state."""
        batch = embedded.size(1)
        if stored is not None:
            h0 = stored[0]
            if (h0.size(1) == batch and h0.device == embedded.device
                    and h0.dtype == embedded.dtype):
                return stored
        # The stored state is sized for self.batch_size. For any other batch
        # (e.g. a final partial batch or a single sentence) build zeros that
        # match the actual input instead of letting the LSTM raise.
        h0 = embedded.new_zeros(1, batch, self.hidden_dim)
        return (h0, h0.clone())

    def initialize(self):
        """Reset the stored (h_0, c_0) of both LSTMs to zeros for self.batch_size rows."""
        self.src_hidden = (torch.zeros(1, self.batch_size, self.hidden_dim),
                           torch.zeros(1, self.batch_size, self.hidden_dim))
        self.trg_hidden = (torch.zeros(1, self.batch_size, self.hidden_dim),
                           torch.zeros(1, self.batch_size, self.hidden_dim))
        
model = TranslationRetrieval(src_vocab_len = 3, trg_vocab_len = 3, embedding_dim = 2, hidden_dim = 4, batch_size = 2)
print torch.tensor([[0,1,2],[0,1,2]]).t()
model.forward(torch.tensor([[0,1,0,0],[0,1,0,0]]).t(), torch.tensor([[1,0,1], [1,1,2]]).t())

tensor([[0, 0],
        [1, 1],
        [2, 2]])


(tensor([[[-0.0807, -0.1007, -0.1942, -0.0205],
          [-0.0807, -0.1007, -0.1942, -0.0205]]], grad_fn=<ViewBackward>),
 tensor([[[ 0.0066,  0.5216, -0.2422, -0.1972],
          [ 0.0460,  0.4311, -0.2423, -0.2371]]], grad_fn=<ViewBackward>))

In [125]:
batch_size = 16

model = TranslationRetrieval(src_vocab_len = nwords_src, trg_vocab_len = nwords_trg, 
                             embedding_dim = 32, hidden_dim = 32, batch_size = batch_size)


loss_fn = torch.nn.MultiMarginLoss(reduction = "sum")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for t in range(10000):
    random.shuffle(train)
    tot_loss = 0
    for sid in range(0, len(train), batch_size):
        optimizer.zero_grad()
        model.initialize()

        src_list = [[0 for _ in range(MAX_SRC_LENGTH)] for _ in range(batch_size)]
        trg_list = [[0 for _ in range(MAX_TRG_LENGTH)] for _ in range(batch_size)]
        for i in range(batch_size):
            src_sent, trg_sent = train[sid+i]
            for j in range(len(src_sent)):
                src_list[i][j] = src_sent[j]
            for j in range(len(trg_sent)):
                trg_list[i][j] = trg_sent[j] 
        src_list = torch.tensor(src_list).t()
        trg_list = torch.tensor(trg_list).t()
        src_pred, trg_pred = model(src_list, trg_list)
        
        matrix = torch.mm(src_pred.view(batch_size, -1), trg_pred.view(batch_size, -1).t())
        loss = loss_fn(matrix, torch.tensor(range(batch_size)))
        loss.backward()
        optimizer.step()
        tot_loss += loss.item()
    if t%100 == 99:
        print t, tot_loss/len(train)

99 0.119174408144
199 0.0784459616407
299 0.069283990073
399 0.0646136462456
499 0.0516234857496
599 0.0524801849679
699 0.0482873771107
799 0.0439737692359
899 0.0418169904151
999 0.0392838096304
1099 0.0380574452865
1199 0.0377949607791
1299 0.0333538646228
1399 0.03543443323
1499 0.035748228227
1599 0.0338721896987
1699 0.0327862428967
1799 0.0337030412047
1899 0.030920356483
1999 0.0319915010186
2099 0.0296992468066
2199 0.0301666703017
2299 0.0260528535
2399 0.0262694187404
2499 0.0259705003991
2599 0.0326465129911
2699 0.0255683002761
2799 0.0231270299701
2899 0.0238952251093
2999 0.024560724647
3099 0.0251003823942
3199 0.022744486283
3299 0.0220748823776
3399 0.020469052688
3499 0.0191370627726
3599 0.0235361736704
3699 0.0249580150266
3799 0.0233411132649
3899 0.0222041972738
3999 0.0237440802011
4099 0.0185055304464
4199 0.0185835115699
4299 0.0184142313519
4399 0.0245127692615
4499 0.0194680496352
4599 0.0187962324417
4699 0.0201208339713
4799 0.0178982174693
4899 0.01751417

In [36]:
def retrieve(src, db_mtx):
    """
    Rank the rows of db_mtx by dot-product similarity to the query src.

    Returns:
        (ranks, scores): `scores` holds one dot product per row of db_mtx, in
        row order; `ranks` holds the row indices ordered from the highest
        score to the lowest.
    """
    similarity = np.dot(db_mtx, src)
    # Sorting the negated scores ascending gives a best-first ordering.
    best_first = np.argsort(np.negative(similarity))
    return best_first, similarity

retrieve([0,0,1], [[1,1,0.5],[0.1,0,1],[0,0,4]])

(array([2, 1, 0]), array([0.5, 1. , 4. ]))

In [126]:
# Encode every training pair with the trained model: row k of src_matrix /
# trg_matrix is the sentence vector of the k-th source / target sentence.
src_matrix = []
trg_matrix = []
with torch.no_grad():
    for sid in range(0, len(train), batch_size):
        model.initialize()

        # The last slice may hold fewer than batch_size pairs. Indexing
        # train[sid+i] for every i used to raise IndexError in that case.
        batch = train[sid:sid + batch_size]

        # Always build batch_size rows (the model's stored initial state
        # expects that many); rows beyond len(batch) stay all-padding and
        # their outputs are discarded below.
        src_list = [[0] * MAX_SRC_LENGTH for _ in range(batch_size)]
        trg_list = [[0] * MAX_TRG_LENGTH for _ in range(batch_size)]
        for i, (src_sent, trg_sent) in enumerate(batch):
            src_list[i][:len(src_sent)] = src_sent
            trg_list[i][:len(trg_sent)] = trg_sent
        # (batch, seq) -> (seq, batch), the layout the model's LSTMs consume.
        src_list = torch.tensor(src_list).t()
        trg_list = torch.tensor(trg_list).t()
        src_pred, trg_pred = model(src_list, trg_list)
        # Outputs are (1, batch, hidden): drop the leading axis and keep only
        # the rows that correspond to real sentences.
        src_matrix += (src_pred.numpy()[0][:len(batch)]).tolist()
        trg_matrix += (trg_pred.numpy()[0][:len(batch)]).tolist()


In [130]:
tot = 0.0
for i in range(len(src_matrix)):
    ranks, _ = retrieve(src_matrix[i], trg_matrix)
    if i in ranks[:5]:
        tot += 1
print tot/len(src_matrix)

0.642578125
