In [1]:
import torch
import torch.nn as nn
import math
import numpy as np
import copy
from torch.autograd import Variable
import torch.nn.functional as F
import time
from attention import clones, future_mask, MultiHeadedAttention, FeedForwardNetwork, Embeddings, PositionalEncoding, Generator, LayerNorm, SublayerConnection, Encoder, Decoder, EncoderDecoder, make_model, Batch

In [2]:
# Problem dimensions for the synthetic copy task.
vocab_size = 10   # random token ids are drawn from [1, vocab_size)
batch_size = 16   # rows of `data` per batch
nbatches = 20     # batches per pass over the data
d_model = 512     # number of columns in each random data row

# Seed both RNGs so the random data (NumPy) and the model initialisation (torch)
# are reproducible under Restart Kernel -> Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
class zrange:
    """Re-iterable source of random batches for the copy task.

    Holds one tensor of random integers in [1, vocab_size) with
    ``batch_size * n_batches`` rows and ``d_model`` columns. Every call to
    ``iter()`` returns a fresh ``zrange_iter`` over the same data, so the
    object can be consumed once per epoch.

    Args:
        d_model: number of columns in each data row.
        batch_size: number of rows per batch.
        n_batches: number of batches produced per iteration.
        vocab_size: exclusive upper bound of the random token ids.
    """
    def __init__(self, d_model, batch_size, n_batches, vocab_size):
        # Size the data from the constructor arguments, not from the
        # notebook-level `nbatches` global, so the object is self-contained.
        n_rows = batch_size * n_batches
        self.data = torch.from_numpy(
            np.random.randint(1, vocab_size, size=(n_rows, d_model)))
        self.d_model = d_model
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.vocab_size = vocab_size

    def __iter__(self):
        # A new iterator each time: iterating twice yields the same batches twice.
        return zrange_iter(self.n_batches, self.batch_size, self.data)

class zrange_iter:
    """Single-pass iterator that cuts `data` into consecutive row batches.

    Args:
        n_batches: number of batches to produce before stopping.
        batch_size: number of rows in each batch.
        data: tensor whose rows are sliced into batches.
    """
    def __init__(self, n_batches, batch_size, data):
        self.i = 0  # index of the next batch to hand out
        self.n_batches = n_batches
        self.batch_size = batch_size
        self.data = data

    def __iter__(self):
        # Iterators are iterables too.
        # Adding this function to make them so.
        return self

    def __next__(self):
        """Return the next Batch, or raise StopIteration after `n_batches` batches."""
        if self.i >= self.n_batches:
            raise StopIteration()
        # Slice with the iterator's own batch size rather than the notebook-level
        # `batch_size` global, so rows are neither skipped nor overlapped when
        # the two differ.
        start = self.i * self.batch_size
        stop = start + self.batch_size
        # Plain tensor slices: the deprecated Variable wrapper is a no-op on the
        # PyTorch versions that provide `.item()`, which this notebook relies on.
        src = self.data[start:stop]
        tgt = self.data[start:stop]
        self.i += 1
        # Source and target are the same rows (copy task); 0 is passed as Batch's third argument.
        return Batch(src, tgt, 0)

In [4]:
# Build a re-iterable batch source from the notebook-level settings.
z = zrange(d_model, batch_size, nbatches, vocab_size)

In [5]:
# Smoke test: exhaust the iterator once; the result holds one Batch per batch.
list(z)

[<attention.Batch at 0x13da6a990>,
 <attention.Batch at 0x13d251dc0>,
 <attention.Batch at 0x13d252210>,
 <attention.Batch at 0x13ce3a330>,
 <attention.Batch at 0x13da6aa50>,
 <attention.Batch at 0x13da6a4b0>,
 <attention.Batch at 0x13da6aa20>,
 <attention.Batch at 0x13da6a1b0>,
 <attention.Batch at 0x13da6a240>,
 <attention.Batch at 0x13da6a180>,
 <attention.Batch at 0x13da6a540>,
 <attention.Batch at 0x13da6aab0>,
 <attention.Batch at 0x13da6ab40>,
 <attention.Batch at 0x13da6ab10>,
 <attention.Batch at 0x13da6aa80>,
 <attention.Batch at 0x13da6b380>,
 <attention.Batch at 0x13da6b2f0>,
 <attention.Batch at 0x13da6b230>,
 <attention.Batch at 0x13da6ab70>,
 <attention.Batch at 0x13da6aba0>]

In [6]:
# Random token ids in [1, vocab_size): batch_size*nbatches rows, d_model columns.
data = torch.from_numpy(np.random.randint(1, vocab_size, size=(batch_size*nbatches, d_model)))
# Force the first column of every row to 1 (presumably a start-of-sequence symbol — TODO confirm).
data[:, 0] = 1

In [7]:
def data_gen(V, total_data, batch_size, nbatches, d_model=10):
    """Generate batches for a src-tgt copy task from a pre-built data tensor.

    Args:
        V: vocabulary size. Not needed to slice the data; kept for interface
            compatibility.
        total_data: tensor whose rows are cut into consecutive batches.
        batch_size: number of rows per batch.
        nbatches: number of batches to yield.
        d_model: unused; kept for interface compatibility.

    Yields:
        Batch(src, tgt, 0), where src and tgt are the same `batch_size` rows.

    Raises:
        TypeError: on first iteration, if neither `total_data` nor `V` is a tensor.

    Note:
        The data is taken from the arguments instead of the notebook-level
        `data` global. For backward compatibility with calls that pass the
        tensor first (``data_gen(data, vocab_size, ...)``), a tensor given as
        `V` is accepted when `total_data` is not a tensor.
    """
    if torch.is_tensor(total_data):
        source = total_data
    elif torch.is_tensor(V):
        # Legacy argument order: the tensor was passed in the first position.
        source = V
    else:
        raise TypeError("data_gen expects `total_data` to be a tensor")

    for i in range(nbatches):
        start = i * batch_size
        stop = start + batch_size
        # Plain tensor slices: the deprecated Variable wrapper is a no-op on the
        # PyTorch versions that provide `.item()`, which this notebook relies on.
        src = source[start:stop]
        tgt = source[start:stop]
        yield Batch(src, tgt, 0)

In [8]:
## testing make_model
model = make_model(vocab_size, vocab_size)
# data_gen's signature is (V, total_data, batch_size, nbatches, d_model):
# the vocabulary size goes first and the data tensor second.
data_generator = data_gen(vocab_size, data, batch_size, nbatches, d_model)

In [9]:
class NoamOpt:
    """Optim wrapper that implements rate.

    Wraps an optimizer and, on every `step()`, sets the learning rate of all
    of its parameter groups to

        factor * model_size**-0.5 * min(step**-0.5, step * warmup**-1.5)

    before delegating to the wrapped optimizer's `step()`.

    Args:
        model_size: value used as `model_size` in the formula above.
        factor: constant multiplier applied to the schedule.
        warmup: value used as `warmup` in the formula above; must be non-zero.
        optimizer: object exposing `param_groups` (a list of dicts with an
            'lr' entry) and `step()`.
    """
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0   # number of step() calls so far
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0   # learning rate applied by the most recent step()

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step = None):
        """Return the learning rate for `step` (defaults to the current step count).

        Steps <= 0 return 0.0. The formula is undefined there (`0 ** -0.5`
        raises ZeroDivisionError), and 0.0 matches the initial `_rate`, so
        querying the rate before the first `step()` is safe.
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.0
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
        
def get_std_opt(model):
    """Build a NoamOpt for `model` with factor 2 and warmup 4000.

    The model size is read from `model.src_embed[0].d_model`; the wrapped
    optimizer is Adam over `model.parameters()` with lr=0, betas=(0.9, 0.98)
    and eps=1e-9.
    """
    width = model.src_embed[0].d_model
    adam = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(width, 2, 4000, adam)

In [10]:
def run_epoch(data_iter, model, loss_compute, log_every=1):
    """Standard Training and Logging Function

    Runs `model.forward` on every batch from `data_iter`, passes the output to
    `loss_compute`, and prints the per-token loss and throughput.

    Args:
        data_iter: iterable of batches exposing `src`, `trg`, `src_mask`,
            `trg_mask`, `trg_y` and `ntokens` (with an `.item()` method).
        model: object whose `forward(src, trg, src_mask, trg_mask)` produces
            the output handed to `loss_compute`.
        loss_compute: callable `(out, trg_y, ntokens) -> loss`; the result must
            be convertible with `float()`.
        log_every: print a progress line every `log_every` batches (positive
            integer; the default 1 logs every batch).

    Returns:
        Total loss divided by total token count over the whole pass, or NaN
        when no tokens were seen (for example, an empty `data_iter`).
    """
    start = time.time()
    total_tokens = 0
    total_loss = 0.0
    tokens = 0  # tokens seen since the last progress line
    for i, batch in enumerate(data_iter):
        out = model.forward(batch.src, batch.trg,
                            batch.src_mask, batch.trg_mask)
        # float() turns a 0-dim tensor or a plain number into a Python float so
        # the running total does not hold on to tensors.
        loss = float(loss_compute(out, batch.trg_y, batch.ntokens))
        n_tok = batch.ntokens.item()
        total_loss += loss
        total_tokens += n_tok
        tokens += n_tok
        if i % log_every == 0:
            # Guard against a zero interval on coarse clocks or very fast batches.
            elapsed = max(time.time() - start, 1e-9)
            per_token = loss / n_tok if n_tok else float("nan")
            print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                    (i, per_token, tokens / elapsed))
            start = time.time()
            tokens = 0
    if total_tokens == 0:
        return float("nan")
    return total_loss / total_tokens

In [11]:
class SimpleLossCompute:
    """A simple loss compute and train function.

    Args:
        generator: stored on the instance; not applied in `__call__` (see the
            note there).
        criterion: callable `(scores, targets) -> loss tensor`.
        opt: optional optimizer wrapper exposing `step()` and
            `optimizer.zero_grad()`. When None the call only computes the loss.
    """
    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt

    def __call__(self, x, y, norm):
        """Compute the loss for one batch and, if an optimizer is set, train on it.

        Args:
            x: model output; flattened to (-1, x.size(-1)) before scoring.
            y: targets; flattened to 1-D before scoring.
            norm: normaliser the criterion value is divided by.

        Returns:
            `loss.item() * norm`, i.e. the un-normalised loss for the batch
            (same type as `norm` multiplied by a float).
        """
        # NOTE(review): the generator is not applied here; presumably the model
        # output already has the form the criterion expects — confirm.
        # x = self.generator(x)
        loss = self.criterion(x.contiguous().view(-1, x.size(-1)),
                              y.contiguous().view(-1)) / norm
        if self.opt is not None:
            # Backpropagate only when training. Without an optimizer nobody
            # would clear the gradients, so they would pile up across an
            # evaluation pass and leak into the next training update.
            loss.backward()
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return loss.item() * norm


In [12]:
class LabelSmoothing(nn.Module):
    """Implement label smoothing.

    Builds a smoothed target distribution for each row of `x` and returns the
    summed KL divergence between it and `x`.

    Args:
        size: number of classes; must equal `x.size(1)` and be greater than 2,
            because the smoothing mass is divided by `size - 2`.
        padding_idx: class index whose target probability is always 0; rows
            whose target equals it get an all-zero distribution.
        smoothing: probability mass spread over the non-target classes; the
            target class receives `1.0 - smoothing`.
    """
    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        # reduction="sum" is the replacement for the deprecated size_average=False.
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None  # target distribution from the latest forward()

    def forward(self, x, target):
        """Return the summed KL divergence between `x` and the smoothed targets.

        Args:
            x: 2-D tensor with `size` columns, passed to nn.KLDivLoss as its
                input (which that loss expects to be log-probabilities).
            target: 1-D tensor of class indices, one per row of `x`.
        """
        assert x.size(1) == self.size
        labels = target.detach()
        # Start with the smoothing mass spread evenly over size - 2 classes.
        true_dist = torch.full_like(x.detach(), self.smoothing / (self.size - 2))
        # Put the confidence on each row's target class.
        true_dist.scatter_(1, labels.unsqueeze(1).long(), self.confidence)
        # The padding class never receives probability.
        true_dist[:, self.padding_idx] = 0
        # Rows whose target is padding contribute nothing. A boolean mask handles
        # zero, one or many padded rows alike, where nonzero().squeeze() would
        # collapse a single match into a 0-dim index.
        pad_rows = labels == self.padding_idx
        true_dist.masked_fill_(pad_rows.unsqueeze(1), 0.0)
        self.true_dist = true_dist
        return self.criterion(x, true_dist)

In [13]:
vocab_size = 10
criterion = LabelSmoothing(size=vocab_size, padding_idx=0, smoothing=0.0)
#criterion = nn.CrossEntropyLoss()
model = make_model(vocab_size, vocab_size, N=2)
#data_generator = data_gen(data,vocab_size,batch_size,nbatches,d_model)
# zrange is re-iterable, so the same object serves the training and the
# evaluation pass of every epoch.
data_generator = zrange(d_model, batch_size, nbatches, vocab_size)
model_opt = NoamOpt(model.src_embed[0].d_model, 1, 400,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

for epoch in range(10):
    print("epoch: ", epoch)
    model.train()
    # Start every training pass from clean gradients, so nothing accumulated
    # outside the training steps (e.g. during an evaluation pass) leaks into
    # the first parameter update of the epoch.
    model_opt.optimizer.zero_grad()
    run_epoch(data_generator, model, 
              SimpleLossCompute(model.generator, criterion, model_opt))
    model.eval()
    print(run_epoch(data_generator, model, 
                    SimpleLossCompute(model.generator, criterion, None)))



epoch:  0
Epoch Step: 0 Loss: 3.446250 Tokens per Sec: 1267.290997
Epoch Step: 1 Loss: 3.236492 Tokens per Sec: 1207.103062
Epoch Step: 2 Loss: 2.902675 Tokens per Sec: 1275.497088
Epoch Step: 3 Loss: 2.498465 Tokens per Sec: 1502.379176
Epoch Step: 4 Loss: 2.138298 Tokens per Sec: 1397.080464
Epoch Step: 5 Loss: 1.847563 Tokens per Sec: 1060.637148
Epoch Step: 6 Loss: 1.560208 Tokens per Sec: 969.015239
Epoch Step: 7 Loss: 1.327338 Tokens per Sec: 1277.639960
Epoch Step: 8 Loss: 1.134370 Tokens per Sec: 1265.149763
Epoch Step: 9 Loss: 0.895822 Tokens per Sec: 1063.316940
Epoch Step: 10 Loss: 0.702552 Tokens per Sec: 1184.823287
Epoch Step: 11 Loss: 0.526847 Tokens per Sec: 1355.335171
Epoch Step: 12 Loss: 0.398553 Tokens per Sec: 1187.262975
Epoch Step: 13 Loss: 0.303925 Tokens per Sec: 1531.441171
Epoch Step: 14 Loss: 0.244751 Tokens per Sec: 1370.089625
Epoch Step: 15 Loss: 0.187407 Tokens per Sec: 1362.523353
Epoch Step: 16 Loss: 0.137313 Tokens per Sec: 1392.820789
Epoch Step: 17 

KeyboardInterrupt: 