In [1]:
import torch.nn as nn
import torch
import math
#parser.add_argument('--nhid', type=int, default=200,
#                    help='number of hidden units per layer')

#parser.add_argument('--ninp', type=int, default=200,
#                    help='size of word embeddings')#

class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
        
def get_std_opt(model):
    return NoamOpt(model.src_embed[0].d_model, 2, 4000,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)
    
class PositionalEncoding(nn.Module):
    "Implement the PE function."
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0.0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0.0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], 
                         requires_grad=False)
        return self.dropout(x)

    
class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))

    
class TransformerCell(nn.Module):
        """Transformer Cell with Attention and PositionwiseFeedForward"""
        def __init__(self, ntoken, d_model, nff, nheads=8,dropout=0.2, tie_weights=False):
            super(TransformerCell, self).__init__()
            
            self.att = nn.MultiheadAttention(d_model,num_heads=nheads,dropout=dropout)
            self.ff = PositionwiseFeedForward(d_model,nff)
        
            self.norm = LayerNorm(d_model)
            self.dropnorm1 = nn.Dropout(dropout)
            
            self.norm2 = LayerNorm(d_model)
            self.dropnorm2 = nn.Dropout(dropout)
            
        def forward(self,x):
            
            x,_= self.att(x,x,x)
            
            x = x + self.dropnorm1(self.norm(x))
            
            x = self.ff(x)
            x = x + self.dropnorm2(self.norm2(x))
            
            return x
            



class TransformerModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""
    
    def __init__(self, ntoken, d_model,nheads=8,nff=512, nlayers=1, dropout=0.0, tie_weights=False):
        super(TransformerModel, self).__init__()
        
        self.emb = Embeddings(d_model,ntoken)
        self.pos = PositionalEncoding(d_model,dropout)
        
        assert(d_model%nheads==0)
        
        self.trans = TransformerCell(ntoken, d_model, nff=nff, nheads=nheads,dropout=0.2, tie_weights=True)
        
        
        self.decoder = nn.Sequential(nn.Linear(d_model, ntoken),nn.LogSoftmax())
        
        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            self.decoder.weight = self.emb.lut.weight

        self.nlayers = nlayers



    def forward(self, input):
        
        emb = self.pos(self.emb(input))
        
        output =  self.trans(emb)
        
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        
        #return decoded.view(output.size(0), output.size(1), decoded.size(1))
        return decoded


In [13]:
from torch.autograd import Variable
import torch.nn.functional as F

###############################################################################
# Training code
###############################################################################



# get_batch subdivides the source data into chunks of length args.bptt.
# If source is equal to the example output of the batchify function, with
# a bptt-limit of 2, we'd get the following two Variables for i = 0:
# ┌ a g m s ┐ ┌ b h n t ┐
# └ b h n t ┘ └ c i o u ┘
# Note that despite the name of the function, the subdivison of data is not
# done along the batch dimension (i.e. dimension 1), since that was handled
# by the batchify function. The chunks are along dimension 0, corresponding
# to the seq_len dimension in the LSTM.

def get_batch(source, i):
    seq_len2 = min(seq_len, len(source) - 1 - i)
    data = source[i:i+seq_len2]
    target = source[i+1:i+1+seq_len2].view(-1)
    return data, target


def evaluate(data_source):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0.
    ntokens = len(corpus.dictionary)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, seq_len):
            input_data, targets = get_batch(data_source, i)
            output = model(input_data)
            output_flat = output.view(-1, ntokens)
            total_loss += len(input_data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)


def train():
    # Turn on training mode which enables dropout.
    model.train()
    total_loss = 0.
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    
    for batch, i in enumerate(range(0, train_data.size(0) - 1, seq_len)):
        input_data, targets = get_batch(train_data, i)
        
        optimizer.zero_grad()
        output = model(input_data)
        loss = criterion(output.view(-1, ntokens), targets)
        #print(output.view(-1, ntokens).shape,targets.shape)
        loss.backward()
        optimizer.step()


        total_loss += loss.item()

        if batch % 200 == 0 and batch > 0:
            cur_loss = total_loss / 200
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches |  ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_data) // seq_len,
                elapsed * 1000 / 200, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()


In [10]:

# coding: utf-8
import argparse
import time
import math
import os
import torch
import torch.nn as nn
import torch.onnx
import data
import model


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)
eval_batch_size = 40

batch_size = 100
torch.cuda.empty_cache()

In [11]:
corpus = data.Corpus('./data/pokeCorpusBulba')

train_data = batchify(corpus.train, batch_size)
val_data = batchify(corpus.valid, eval_batch_size)

seq_len = 50



In [None]:
#overfit one batch



ntokens = len(corpus.dictionary)


model = TransformerModel(ntokens, 1024).to(device)


criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001,betas=(0.9, 0.98), eps=1e-9)

input_data, targets = get_batch(train_data, 1)

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

output = None
for i in range(1000):
    optimizer.zero_grad()

    output = model(input_data)

    loss = criterion(output.view(-1, ntokens), targets)

    loss.backward()
    optimizer.step()
    
    if (i%10 == 0):
        print(loss.detach())



In [56]:

###############################################################################
# Build the model
###############################################################################

ntokens = len(corpus.dictionary)


model = TransformerModel(ntokens, 512).to(device)

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

        
optimizer = torch.optim.Adam(model.parameters(), lr=0.001,betas=(0.9, 0.98), eps=1e-9)

criterion = nn.CrossEntropyLoss()





# At any point you can hit Ctrl + C to break out of training early.
try:
    for epoch in range(1, 100):
        epoch_start_time = time.time()
        train()
        val_loss = evaluate(val_data)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                           val_loss, math.exp(val_loss)))
        print('-' * 89)

except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')





  input = module(input)


| epoch   1 |   200/ 1064 batches |  ms/batch 109.58 | loss  5.79 | ppl   326.92
| epoch   1 |   400/ 1064 batches |  ms/batch 106.37 | loss  4.44 | ppl    84.89
| epoch   1 |   600/ 1064 batches |  ms/batch 104.61 | loss  3.89 | ppl    48.69
| epoch   1 |   800/ 1064 batches |  ms/batch 104.66 | loss  3.57 | ppl    35.36
| epoch   1 |  1000/ 1064 batches |  ms/batch 104.70 | loss  3.35 | ppl    28.56
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 112.67s | valid loss  2.85 | valid ppl    17.36
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 1064 batches |  ms/batch 105.20 | loss  3.12 | ppl    22.61
| epoch   2 |   400/ 1064 batches |  ms/batch 104.74 | loss  2.97 | ppl    19.41
| epoch   2 |   600/ 1064 batches |  ms/batch 104.81 | loss  2.85 | ppl    17.29
| epoch   2 |   800/ 1064 batches |  ms/batch 104.77 | loss  2.76 | ppl    15.73
| epoch   2 |  1

In [57]:
# generate

model.eval()


n_words = 100
corpus = data.Corpus('./data/pokeCorpusBulba')

ntokens = len(corpus.dictionary)





In [105]:
input_sentence = "ash and his friends"

indices = [corpus.dictionary.word2idx[w] for w in input_sentence.split(' ')]

input = torch.LongTensor(indices).to(device).view(1,-1)
output= model(input)
#torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)




  input = module(input)


In [107]:
with torch.no_grad():  # no tracking history
        for i in range(n_words+50):
            
            output= model(input)
            
            word_weights = output.squeeze().exp().cpu()
            
            word_idx = torch.multinomial(word_weights, 1)[-1].to(device)
            
            input = torch.cat((input.squeeze(),word_idx)).view(1,-1)
            
            word = corpus.dictionary.idx2word[word_idx]

            print(word,end =" ") 


  input = module(input)


pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon pokémon 

In [108]:
input

tensor([[  5,   6,   7,   8, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153,
         153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153,
         153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153,
         153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153,
         153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153,
         153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153,
         153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153,
         153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153,
         153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153,
         153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153,
         153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153,
         153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153, 153,
         153, 153, 153, 153, 153, 153, 153, 153, 153