In [1]:
import torch
import torch.nn as nn
import math
import numpy as np
import copy
from torch.autograd import Variable
import torch.nn.functional as F
import time
from attention import clones, future_mask, MultiHeadedAttention, FeedForwardNetwork, Embeddings, PositionalEncoding, Generator, LayerNorm, SublayerConnection, Encoder, Decoder, EncoderDecoder, make_model, Batch

In [2]:
# Notebook-level settings for the toy data used in the cells below.
vocab_size = 10  # exclusive upper bound of the random token ids (ids are drawn from 1..vocab_size-1)
batch_size = 16  # number of sequences per batch
nbatches = 20  # number of batches per pass over the data
d_model = 512  # last dimension of the random `data` tensor built further down

In [3]:
class zrange:
    """Re-iterable source of random integer sequences, served in batches.

    On construction a single tensor of shape
    ``(batch_size * n_batches, seq_len)`` is filled with integers drawn
    uniformly from ``[1, vocab_size)``. Every call to ``iter()`` returns a
    fresh ``zrange_iter`` over that same tensor, so the object can be
    iterated more than once (e.g. one pass per epoch).

    Args:
        batch_size: number of sequences per batch.
        n_batches: number of batches produced per iteration pass.
        vocab_size: exclusive upper bound for the random token ids.
        seq_len: length of every sequence.
    """

    def __init__(self,batch_size, n_batches, vocab_size, seq_len):
        # Size the data from the n_batches argument. Reading a notebook-level
        # global here would silently ignore the argument and could leave the
        # iterator slicing past the end of the data.
        total_rows = batch_size * n_batches
        self.data = torch.from_numpy(
            np.random.randint(1, vocab_size, size=(total_rows, seq_len)))
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.seq_len = seq_len

    def __iter__(self):
        """Return a new iterator positioned at the first batch."""
        return zrange_iter(self.n_batches, self.batch_size, self.data, self.seq_len)

class zrange_iter:
    """Iterator that walks a 2-D tensor in consecutive row blocks.

    Each step takes the next ``batch_size`` rows of ``data`` and wraps them
    in a ``Batch`` in which the source and the target are the same rows.

    Args:
        n_batches: number of batches to yield before stopping.
        batch_size: number of rows per batch.
        data: 2-D tensor whose rows are the sequences.
        seq_len: sequence length (stored, not used while iterating).
    """

    def __init__(self, n_batches, batch_size, data, seq_len):
        self.i = 0  # index of the next batch to yield
        self.n_batches = n_batches
        self.batch_size = batch_size
        self.data = data
        self.seq_len = seq_len

    def __iter__(self):
        # Iterators are iterables too; returning self makes this usable
        # directly in a for-loop.
        return self

    def __next__(self):
        """Return the next ``Batch`` or raise ``StopIteration`` when exhausted."""
        if self.i >= self.n_batches:
            raise StopIteration()
        # Slice with this iterator's own batch size, not a notebook global,
        # so the batch boundaries always match the configured size.
        start = self.i * self.batch_size
        stop = start + self.batch_size
        # Plain tensor slices are enough: the data is integer-typed and never
        # requires grad, so no autograd wrapper is needed.
        src = self.data[start:stop, :]
        tgt = self.data[start:stop, :]
        self.i += 1
        # Third argument is passed through to Batch unchanged
        # (presumably the padding index -- confirm against attention.Batch).
        return Batch(src, tgt, 0)

In [4]:
# Smoke-test data source: nbatches batches of batch_size random sequences, each of length 10.
z = zrange(batch_size, nbatches, vocab_size, 10)

In [5]:
# Exhaust one pass of the iterator to inspect the Batch objects it yields.
list(z)

[<attention.Batch at 0x13f879cd0>,
 <attention.Batch at 0x13f042690>,
 <attention.Batch at 0x13f878e00>,
 <attention.Batch at 0x13f879280>,
 <attention.Batch at 0x13f879d00>,
 <attention.Batch at 0x13f87a180>,
 <attention.Batch at 0x13f879970>,
 <attention.Batch at 0x13f879940>,
 <attention.Batch at 0x13f879b50>,
 <attention.Batch at 0x13f879520>,
 <attention.Batch at 0x13f879490>,
 <attention.Batch at 0x13f8794f0>,
 <attention.Batch at 0x13f879c10>,
 <attention.Batch at 0x13f879be0>,
 <attention.Batch at 0x13f878ce0>,
 <attention.Batch at 0x13f879040>,
 <attention.Batch at 0x13f879d30>,
 <attention.Batch at 0x13f4536e0>,
 <attention.Batch at 0x13f87a210>,
 <attention.Batch at 0x13f87a270>]

In [6]:
# Random integer tensor of shape (1, batch_size*nbatches, d_model) with values in [1, vocab_size).
data = torch.from_numpy(np.random.randint(1, vocab_size, size=(1,batch_size*nbatches, d_model)))
# Overwrite index 0 along the second dimension with ones.
data[:, 0] = 1

In [7]:
## testing make_model
# model = make_model(vocab_size, vocab_size)


In [8]:
class NoamOpt:
    """Optimizer wrapper that sets the learning rate from a warmup schedule.

    The rate used at a given step is::

        factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

    i.e. it grows linearly for the first ``warmup`` steps and then decays
    with the inverse square root of the step number.

    Args:
        model_size: model width used to scale the rate.
        factor: constant multiplier applied to the schedule.
        warmup: number of warmup steps.
        optimizer: wrapped optimizer; must expose ``param_groups`` and ``step()``.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0  # number of completed calls to step()
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0  # most recent learning rate written to the optimizer

    def step(self):
        """Advance the schedule, write the new rate to every param group, and step."""
        self._step += 1
        rate = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step = None):
        """Return the learning rate for ``step`` (defaults to the current step).

        A non-positive step (nothing trained yet) yields 0.0, the value the
        linear warmup term takes at step 0. Without this guard,
        ``0 ** -0.5`` raises ``ZeroDivisionError`` when the rate is queried
        before the first call to ``step()``.
        """
        if step is None:
            step = self._step
        if step <= 0:
            return 0.0
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
        
def get_std_opt(model):
    """Build a NoamOpt (factor 2, warmup 4000) around Adam for ``model``.

    The schedule is scaled by the ``d_model`` attribute of the first module
    in ``model.src_embed``. Adam starts at lr=0 with betas=(0.9, 0.98) and
    eps=1e-9.
    """
    width = model.src_embed[0].d_model
    adam = torch.optim.Adam(
        model.parameters(),
        lr=0,
        betas=(0.9, 0.98),
        eps=1e-9,
    )
    return NoamOpt(width, 2, 4000, adam)

In [9]:
def run_epoch(data_iter, model, loss_compute, log_every=1):
    """Run one pass over ``data_iter``, logging progress, and return the mean loss.

    For every batch the model's ``forward`` is called with the batch's
    source, target and masks, and ``loss_compute`` is called with the
    output, ``batch.trg_y`` and ``batch.ntokens``.

    Args:
        data_iter: iterable of batches exposing ``src``, ``trg``,
            ``src_mask``, ``trg_mask``, ``trg_y`` and ``ntokens``
            (``ntokens`` must support ``.item()``).
        model: object with a ``forward(src, trg, src_mask, trg_mask)`` method.
        loss_compute: callable ``(out, trg_y, ntokens) -> loss``.
        log_every: print a progress line every this many batches
            (default 1 = every batch); a falsy value disables logging.

    Returns:
        Total loss divided by the total token count, or 0.0 when no tokens
        were processed (e.g. an empty ``data_iter``).
    """
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0  # tokens seen since the last progress line
    for i, batch in enumerate(data_iter):
        out = model.forward(batch.src, batch.trg,
                            batch.src_mask, batch.trg_mask)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        ntokens = batch.ntokens.item()
        total_loss += loss
        total_tokens += ntokens
        tokens += ntokens
        if log_every and i % log_every == 0:
            # Clamp so a zero-length interval (coarse clocks) cannot divide by zero.
            elapsed = max(time.time() - start, 1e-9)
            print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                    (i, loss / batch.ntokens, tokens / elapsed))
            start = time.time()
            tokens = 0
    if total_tokens == 0:
        return 0.0
    return total_loss / total_tokens

In [10]:
class SimpleLossCompute:
    """Callable that computes a normalized loss, backpropagates, and optionally steps.

    ``generator`` is stored but not applied here; the scores in ``x`` are
    passed to the criterion as given. When ``opt`` is provided it must expose
    ``step()`` and an ``optimizer`` attribute with ``zero_grad()``.
    """

    def __init__(self, generator, criterion, opt=None):
        self.generator, self.criterion, self.opt = generator, criterion, opt

    def __call__(self, x, y, norm):
        """Return the un-normalized loss for scores ``x`` against targets ``y``.

        ``x`` is flattened to two dimensions (keeping its last dimension) and
        ``y`` to one; the criterion's result is divided by ``norm`` before the
        backward pass and multiplied back for the returned value.
        """
        flat_scores = x.contiguous().view(-1, x.size(-1))
        flat_targets = y.contiguous().view(-1)
        scaled = self.criterion(flat_scores, flat_targets) / norm
        scaled.backward()
        wrapper = self.opt
        if wrapper is None:
            # Evaluation path: gradients were computed but no update is applied.
            return scaled.item() * norm
        wrapper.step()
        wrapper.optimizer.zero_grad()
        return scaled.item() * norm


In [11]:
class LabelSmoothing(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    For each row the target distribution puts ``1 - smoothing`` on the true
    class and ``smoothing / (size - 2)`` on every other class, with the
    padding class always set to zero. Rows whose target is the padding index
    are zeroed entirely so they contribute nothing to the loss.

    Args:
        size: number of classes; must equal ``x.size(1)`` in ``forward``.
        padding_idx: class index treated as padding.
        smoothing: probability mass spread over the non-target classes.
    """

    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        # reduction='sum' is the supported spelling of the deprecated
        # size_average=False and yields the same summed loss.
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None  # last smoothed target distribution, kept for inspection

    def forward(self, x, target):
        """Return the summed KL divergence between ``x`` and the smoothed targets.

        Args:
            x: 2-D tensor with ``size`` columns, passed to ``nn.KLDivLoss`` as
                its input (which expects log-probabilities).
            target: 1-D tensor of class indices, one per row of ``x``.
        """
        assert x.size(1) == self.size
        target = target.detach().long()
        # Start from the uniform smoothing mass, built without autograd history.
        true_dist = torch.full_like(x.detach(), self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        # Zero every row whose target is padding. A boolean mask behaves the
        # same for zero, one or many padded rows.
        padded_rows = (target == self.padding_idx).unsqueeze(1)
        true_dist.masked_fill_(padded_rows, 0.0)
        self.true_dist = true_dist
        return self.criterion(x, true_dist)

In [12]:
# Training setup: build the criterion, a 2-layer model, the data source and the optimizer wrapper.
vocab_size = 10
# smoothing=0.0 puts all target mass on the true class; index 0 is the padding class.
criterion = LabelSmoothing(size=vocab_size, padding_idx=0, smoothing=0.0)
#criterion = nn.CrossEntropyLoss()
model = make_model(vocab_size, vocab_size, N=2)
#data_generator = data_gen(data,vocab_size,batch_size,nbatches,d_model)
# Random sequences of length 20.
data_generator = zrange(batch_size, nbatches, vocab_size,20)
# Schedule with factor 1 and 400 warmup steps around Adam.
model_opt = NoamOpt(model.src_embed[0].d_model, 1, 400,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))


# NOTE: range(0) is empty, so this loop body never runs as written;
# raise the bound to actually train.
for epoch in range(0):
    print("epoch: ", epoch)
    # Training pass: the optimizer wrapper is supplied, so parameters are updated.
    model.train()
    run_epoch(data_generator, model, 
              SimpleLossCompute(model.generator, criterion, model_opt))
    # Evaluation pass: no optimizer is supplied, so no update step is taken.
    model.eval()
    print(run_epoch(data_generator, model, 
                    SimpleLossCompute(model.generator, criterion, None)))




d_model:  512
h:  8


In [13]:
def greedy_decode(model, source, source_mask, max_decode_len, start_symbol):
    """Greedily decode one sequence, picking the highest-scoring symbol each step.

    The output starts as a 1x1 tensor holding ``start_symbol`` and grows by
    one column per iteration, so this handles a single sequence (batch size 1).

    Args:
        model: object providing ``encode``, ``greedy_decode`` and ``generator``.
        source: source token tensor; its dtype is used for the output and mask.
        source_mask: mask passed to the encoder and to every decode step.
        max_decode_len: total length of the returned sequence, start symbol included.
        start_symbol: token id placed in the first output position.

    Returns:
        Tensor of shape ``(1, max_decode_len)`` holding the decoded token ids.
    """
    encoder_outputs = model.encode(source, source_mask) # todo(annhe): check source_masking
    print("encoder outputs shape: ", encoder_outputs.size())
    # 1x1 tensor holding the start symbol, in the same dtype as the source.
    ys = torch.ones(1,1).fill_(start_symbol).type_as(source.data)
    for i in range(max_decode_len-1):
        # Take the mask dtype from the `source` argument rather than from a
        # notebook global, so the function works for any caller.
        output_mask = future_mask(ys.size(1)).type_as(source.data)
        # ys and the mask are already tensors; no re-wrapping is needed.
        output = model.greedy_decode(encoder_outputs, source_mask, ys,
                              output_mask,pos=i)
        # Only the last position's output is scored by the generator.
        prob = model.generator(output[:,-1,:])
        _, vocab_symbol = torch.max(prob, dim=1)
        vocab_symbol = vocab_symbol.detach().unsqueeze(0)
        ys = torch.cat([ys, vocab_symbol], dim=1)
    return ys
# Decoding demo: switch the model to eval mode and decode a single 7-token sequence.
model.eval()
src = Variable(torch.LongTensor([[1,2,3,4,5,6,7]]) )
# All-ones mask of shape (1, 1, 7), one entry per source position.
src_mask = Variable(torch.ones(1, 1, 7) )
result = greedy_decode(model, src, src_mask, max_decode_len=7, start_symbol=1)
print(result.detach())

encoder outputs shape:  torch.Size([1, 7, 512])
keys shape:  torch.Size([1, 0, 64])
new_key shape:  torch.Size([1, 1, 64])
keys shape:  torch.Size([1, 0, 64])
new_key shape:  torch.Size([1, 1, 64])
keys shape:  torch.Size([1, 0, 64])
new_key shape:  torch.Size([1, 1, 64])
keys shape:  torch.Size([1, 0, 64])
new_key shape:  torch.Size([1, 1, 64])
keys shape:  torch.Size([1, 0, 64])
new_key shape:  torch.Size([1, 1, 64])
keys shape:  torch.Size([1, 0, 64])
new_key shape:  torch.Size([1, 1, 64])
keys shape:  torch.Size([1, 0, 64])
new_key shape:  torch.Size([1, 1, 64])
keys shape:  torch.Size([1, 0, 64])
new_key shape:  torch.Size([1, 1, 64])
keys shape:  torch.Size([1, 0, 64])
new_key shape:  torch.Size([1, 1, 64])
keys shape:  torch.Size([1, 0, 64])
new_key shape:  torch.Size([1, 1, 64])
keys shape:  torch.Size([1, 0, 64])
new_key shape:  torch.Size([1, 1, 64])
keys shape:  torch.Size([1, 0, 64])
new_key shape:  torch.Size([1, 1, 64])
keys shape:  torch.Size([1, 0, 64])
new_key shape:  