In [1]:
import torch
import torch.nn as nn
import math
import numpy as np
import copy
from torch.autograd import Variable
import torch.nn.functional as F
import time
from attention import clones, future_mask, MultiHeadedAttention, FeedForwardNetwork, Embeddings, PositionalEncoding, Generator, LayerNorm, SublayerConnection, Encoder, Decoder, EncoderDecoder, make_model, Batch

In [2]:
vocab_size = 10
batch_size = 16
nbatches = 20
d_model = 512

In [3]:
class zrange:
    def __init__(self, d_model, batch_size, n_batches, vocab_size):
        self.data = torch.from_numpy(np.random.randint(1, vocab_size, size=(batch_size*nbatches, d_model)))
        self.d_model = d_model
        self.batch_size = batch_size
        self.n_batches = n_batches

    def __iter__(self):
        return zrange_iter(self.n_batches, self.batch_size, self.data)

class zrange_iter:
    def __init__(self, n_batches, batch_size, data):
        self.i = 0
        self.n_batches = n_batches
        self.batch_size = batch_size
        self.data = data

    def __iter__(self):
        # Iterators are iterables too.
        # Adding this functions to make them so.
        return self

    def __next__(self):
        if self.i < self.n_batches:
            idx = self.i*self.batch_size
            src = Variable(self.data[idx:idx+batch_size], requires_grad=False)
            tgt = Variable(self.data[idx:idx+batch_size], requires_grad=False)
            self.i += 1
            return Batch(src, tgt, 0)
        else:
            raise StopIteration()

In [4]:
z = zrange(d_model, batch_size, nbatches, vocab_size)

In [5]:
list(z)

[<attention.Batch at 0x14106ea20>,
 <attention.Batch at 0x119ac37a0>,
 <attention.Batch at 0x119b00710>,
 <attention.Batch at 0x110b14ef0>,
 <attention.Batch at 0x140fed2b0>,
 <attention.Batch at 0x119ca7410>,
 <attention.Batch at 0x14106d9a0>,
 <attention.Batch at 0x14106e870>,
 <attention.Batch at 0x14106e8a0>,
 <attention.Batch at 0x14106e840>,
 <attention.Batch at 0x14106eae0>,
 <attention.Batch at 0x14106eb10>,
 <attention.Batch at 0x14106e4b0>,
 <attention.Batch at 0x140fef1a0>,
 <attention.Batch at 0x14106e270>,
 <attention.Batch at 0x14106e240>,
 <attention.Batch at 0x14106e7b0>,
 <attention.Batch at 0x14106e360>,
 <attention.Batch at 0x14106e7e0>,
 <attention.Batch at 0x14106e540>]

In [6]:
data = torch.from_numpy(np.random.randint(1, vocab_size, size=(batch_size*nbatches, d_model)))
data[:, 0] = 1

In [7]:
def data_gen(V, total_data, batch_size, nbatches, d_model=10):
    "Generate random data for a src-tgt copy task."
    i = 0
    while i < nbatches:
        idx = i*batch_size 
        src = Variable(data[idx:idx+batch_size], requires_grad=False)
        tgt = Variable(data[idx:idx+batch_size], requires_grad=False)
        yield Batch(src, tgt, 0)
        i += 1

In [8]:
## testing make_model
model = make_model(vocab_size, vocab_size)
data_generator = data_gen(data,vocab_size,batch_size,nbatches,d_model)

In [9]:
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
        
def get_std_opt(model):
    return NoamOpt(model.src_embed[0].d_model, 2, 4000,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

In [10]:
def run_epoch(data_iter, model, loss_compute):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    for i, batch in enumerate(data_iter):
        print("batch: ", i)
        print("batch total tokens: ", batch.ntokens)
        #print("src shape: ", batch.src.size())
        #print("tgt shape: ", batch.trg.size())
        #print("tgt_y shape: ", batch.trg_y.size())
        out = model.forward(batch.src, batch.trg, 
                            batch.src_mask, batch.trg_mask)
        #print("out shape: ", out.size())
        #print("tgt_y shape: ", batch.trg_y.size())
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens.item()
        print("total_tokens: ", total_tokens)
        tokens += batch.ntokens.item()
        if i % 1 == 0:
            elapsed = time.time() - start
            print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                    (i, loss / batch.ntokens, tokens / elapsed))
            start = time.time()
            tokens = 0
    print("total_loss: ", total_loss)
    print("total_tokens: ", total_tokens)
    #return total_loss / total_tokens

In [11]:
class SimpleLossCompute:
    "A simple loss compute and train function."
    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt
        
    def __call__(self, x, y, norm):
        # x = self.generator(x)
        loss = self.criterion(x.contiguous().view(-1, x.size(-1)), 
                              y.contiguous().view(-1)) / norm
        loss.backward()
        if self.opt is not None:
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return loss.item() * norm


In [12]:
class LabelSmoothing(nn.Module):
    "Implement label smoothing."
    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(size_average=False)
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        
    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1).long(), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        return self.criterion(x, Variable(true_dist, requires_grad=False))

In [None]:
vocab_size = 10
criterion = LabelSmoothing(size=vocab_size, padding_idx=0, smoothing=0.0)
#criterion = nn.CrossEntropyLoss()
model = make_model(vocab_size, vocab_size, N=2)
#data_generator = data_gen(data,vocab_size,batch_size,nbatches,d_model)
data_generator = zrange(d_model, batch_size, nbatches, vocab_size)
model_opt = NoamOpt(model.src_embed[0].d_model, 1, 400,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

for epoch in range(10):
    print("epoch: ", epoch)
    model.train()
    run_epoch(data_generator, model, 
              SimpleLossCompute(model.generator, criterion, model_opt))
    model.eval()
    print(run_epoch(data_generator, model, 
                    SimpleLossCompute(model.generator, criterion, None)))



epoch:  0
batch:  0
batch total tokens:  tensor(8176)
total_tokens:  8176
Epoch Step: 0 Loss: 3.041877 Tokens per Sec: 1315.607821
batch:  1
batch total tokens:  tensor(8176)
total_tokens:  16352
Epoch Step: 1 Loss: 2.911628 Tokens per Sec: 1469.657500
batch:  2
batch total tokens:  tensor(8176)
total_tokens:  24528
Epoch Step: 2 Loss: 2.684096 Tokens per Sec: 1413.302226
batch:  3
batch total tokens:  tensor(8176)
total_tokens:  32704
Epoch Step: 3 Loss: 2.368209 Tokens per Sec: 1535.154629
batch:  4
batch total tokens:  tensor(8176)
total_tokens:  40880
Epoch Step: 4 Loss: 2.059685 Tokens per Sec: 1498.584779
batch:  5
batch total tokens:  tensor(8176)
total_tokens:  49056
Epoch Step: 5 Loss: 1.786254 Tokens per Sec: 1438.671663
batch:  6
batch total tokens:  tensor(8176)
total_tokens:  57232
Epoch Step: 6 Loss: 1.589987 Tokens per Sec: 1523.301866
batch:  7
batch total tokens:  tensor(8176)
total_tokens:  65408
Epoch Step: 7 Loss: 1.339727 Tokens per Sec: 1488.751735
batch:  8
batch