In [1]:
import torch
import torch.nn as nn

import onmt
import onmt.io
import onmt.modules

We begin by loading in the vocabulary for the model of interest. This will let us check vocab size and to get the special ids for padding.

In [2]:
vocab = dict(torch.load("../../data/data.vocab.pt"))
src_padding = vocab["src"].stoi[onmt.io.PAD_WORD]
tgt_padding = vocab["tgt"].stoi[onmt.io.PAD_WORD]

Next we specify the core model itself. Here we will build a small model with an encoder and an attention based input feeding decoder. Both models will be RNNs and the encoder will be bidirectional

In [3]:
emb_size = 10
rnn_size = 6
# Specify the core model. 
encoder_embeddings = onmt.modules.Embeddings(emb_size, len(vocab["src"]),
                                             word_padding_idx=src_padding)

encoder = onmt.modules.RNNEncoder(hidden_size=rnn_size, num_layers=1, 
                                 rnn_type="LSTM", bidirectional=True,
                                 embeddings=encoder_embeddings)

decoder_embeddings = onmt.modules.Embeddings(emb_size, len(vocab["tgt"]),
                                            word_padding_idx=tgt_padding)
decoder = onmt.modules.InputFeedRNNDecoder(hidden_size=rnn_size, num_layers=1, 
                                           bidirectional_encoder=True,
                                           rnn_type="LSTM", embeddings=decoder_embeddings)

model = onmt.modules.NMTModel(encoder, decoder)

# Specify the tgt word generator and loss computation module
model.generator = nn.Sequential(                                                                                                                        
            nn.Linear(rnn_size, len(vocab["tgt"])),                                                                                           
            nn.LogSoftmax())
loss = onmt.Loss.NMTLossCompute(model.generator, vocab["tgt"]) 

Now we set up the optimizer. This could be a core torch optim class, or our wrapper which handles learning rate updates and gradient normalization automatically.

In [4]:
optim = onmt.Optim(method="sgd", lr=1, max_grad_norm=2)
optim.set_parameters(model.parameters())

Now we load the data from disk. Currently will need to call a function to load the fields into the data as well. 

In [5]:
# Load some data
data = torch.load("../../data/data.train.pt")
valid_data = torch.load("../../data/data.valid.pt")
data.examples = data.examples[:100]
                                        
# TODO: This is a bit hacky, need to clean up                                                                                                                       
fields = onmt.io.load_fields_from_vocab(vocab.items(), "text")
fields = dict([(k, f) for (k, f) in fields.items()                                                                                                         
                  if k in data.examples[0].__dict__])
data.fields = valid_data.fields = fields                                                                           


To iterate through the data itself we use a torchtext iterator class. We specify one for both the training and test data. 

In [6]:
train_iter = onmt.io.OrderedIterator(                                                                                                                            
                dataset=data, batch_size=10, 
                device=-1,                                                                                                                                                                                 
                repeat=False)
valid_iter = onmt.io.OrderedIterator(                                                                                                                            
                dataset=valid_data, batch_size=10,                                                                                                                                                                                       
                device=-1,
                train=False) 

Finally we train.

In [7]:
trainer = onmt.Trainer(model, train_iter, valid_iter,                                                                                                      
                       loss, loss, optim)

def report_func(*args):
    stats = args[-1]
    stats.output(args[0], args[1], 10, 0)
    return stats

for epoch in range(2):
    trainer.train(epoch, report_func)
    val_stats = trainer.validate()
    print("Validation")
    val_stats.output(epoch, 11, 10, 0)
    trainer.epoch_step(val_stats.ppl(), epoch)

Epoch  0,     0/   10; acc:   0.00; ppl: 957.49; 1272 src tok/s; 1212 tgt tok/s; 1514089083 s elapsed
Epoch  0,     1/   10; acc:  19.42; ppl: 650.16; 1369 src tok/s; 1406 tgt tok/s; 1514089083 s elapsed
Epoch  0,     2/   10; acc:  22.40; ppl: 526.48; 1344 src tok/s; 1369 tgt tok/s; 1514089083 s elapsed
Epoch  0,     3/   10; acc:  24.31; ppl: 422.44; 1431 src tok/s; 1409 tgt tok/s; 1514089084 s elapsed
Epoch  0,     4/   10; acc:  25.30; ppl: 379.79; 1430 src tok/s; 1431 tgt tok/s; 1514089084 s elapsed
Epoch  0,     5/   10; acc:  26.38; ppl: 333.62; 1400 src tok/s; 1398 tgt tok/s; 1514089084 s elapsed
Epoch  0,     6/   10; acc:  27.15; ppl: 299.44; 1451 src tok/s; 1447 tgt tok/s; 1514089084 s elapsed
Epoch  0,     7/   10; acc:  27.63; ppl: 262.90; 1564 src tok/s; 1505 tgt tok/s; 1514089084 s elapsed
Epoch  0,     8/   10; acc:  27.67; ppl: 254.80; 1575 src tok/s; 1524 tgt tok/s; 1514089084 s elapsed
Epoch  0,     9/   10; acc:  28.37; ppl: 232.36; 1618 src tok/s; 1564 tgt tok/s; 1

To use the model, we need to load up the translation functions 

In [8]:
import onmt.translate

In [9]:
data.src_vocabs

[]

In [16]:
translator = onmt.translate.Translator(beam_size=10, fields=fields, model=model)
builder = onmt.translate.TranslationBuilder(data=valid_data, fields=fields)

valid_data.src_vocabs
for batch in valid_iter:
    trans_batch = translator.translate_batch(batch=batch, data=valid_data)
    translations = builder.from_batch(trans_batch)
    for trans in translations:
        print(trans.log(0))

  return self.add_(other)


PRED SCORE: -2.4383

SENT 0: ('The', 'competitors', 'have', 'other', 'advantages', ',', 'too', '.')
PRED 0: <unk>

PRED SCORE: -2.4250

SENT 0: ('The', 'company', '&apos;s', 'durability', 'goes', 'back', 'to', 'its', 'first', 'boss', ',', 'a', 'visionary', ',', 'Thomas', 'J.', 'Watson', 'Sr.')
PRED 0: <unk>

PRED SCORE: -2.3067

SENT 0: ('&quot;', 'From', 'what', 'we', 'know', 'today', ',', 'you', 'have', 'to', 'ask', 'how', 'I', 'could', 'be', 'so', 'wrong', '.', '&quot;')
PRED 0: <unk>

PRED SCORE: -2.3162

SENT 0: ('Boeing', 'Co', 'shares', 'rose', '1.5%', 'to', '$', '67.94', '.')
PRED 0: <unk>

PRED SCORE: -2.3477

SENT 0: ('Some', 'did', 'not', 'believe', 'him', ',', 'they', 'said', 'that', 'he', 'got', 'dizzy', 'even', 'in', 'the', 'truck', ',', 'but', 'always', 'wanted', 'to', 'fulfill', 'his', 'dream', ',', 'that', 'of', 'becoming', 'a', 'pilot', '.')
PRED 0: <unk>

PRED SCORE: -2.3305

SENT 0: ('In', 'your', 'opinion', ',', 'the', 'council', 'should', 'ensure', 'that', 'the', 

IndexError: list index out of range