In [37]:
from __future__ import unicode_literals, print_function, division
import io
import unicodedata
import string
import re
import random
import codecs
import math

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from torchtext.utils import download_from_url, extract_archive
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Pick the compute device once for the whole notebook: the GPU when PyTorch
# reports CUDA as available, the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Training device:", device)

Training device: cuda


In [4]:
# Download and unpack the WikiText-2 archive.
# NOTE(review): the three-way unpacking assumes extract_archive returns exactly
# the test, validation and train file paths in that order -- confirm against the
# installed torchtext version. Despite the `_dir` suffix these names are used as
# file paths (they are opened directly below and in later cells).
url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
test_dir, valid_dir, train_dir = extract_archive(download_from_url(url))
tokenizer = get_tokenizer('basic_english')

# Build the vocabulary from the tokenized training lines. The file is opened in
# a `with` block so the handle is closed as soon as the vocabulary is built
# instead of being left open for the lifetime of the kernel.
with io.open(train_dir, encoding='utf-8') as train_file:
    vocab = build_vocab_from_iterator(map(tokenizer, train_file))

36718lines [00:01, 35358.41lines/s]


In [31]:
def preprocess_data(raw_text_iterator):
    """Encode an iterable of text items into one flat tensor of token ids.

    Each item is tokenized with the module-level ``tokenizer`` and every token
    is looked up in the module-level ``vocab``. Items that yield no tokens are
    dropped; the remaining per-item id tensors are concatenated in input order.

    Args:
        raw_text_iterator: iterable of strings (for example the lines of a file).

    Returns:
        A 1-D ``torch.long`` tensor holding the ids of all tokens.
    """
    encoded_items = []
    for raw_item in raw_text_iterator:
        token_ids = torch.tensor(
            [vocab[token] for token in tokenizer(raw_item)],
            dtype=torch.long,
        )
        # Blank items tokenize to nothing; skip them rather than concatenating
        # zero-length tensors.
        if token_ids.numel() > 0:
            encoded_items.append(token_ids)

    return torch.cat(tuple(encoded_items))

In [32]:
# Encode each corpus split into a flat tensor of token ids. Every file is
# opened in a `with` block so its handle is closed once the split has been
# read; preprocess_data consumes the whole iterator before returning, so
# closing the file afterwards is safe.
with io.open(train_dir, encoding='utf-8') as train_file:
    train_data = preprocess_data(train_file)

with io.open(valid_dir, encoding='utf-8') as valid_file:
    val_data = preprocess_data(valid_file)

with io.open(test_dir, encoding='utf-8') as test_file:
    test_data = preprocess_data(test_file)

In [34]:
def split_into_batch(data, batch_size):
    """Arrange a flat tensor into ``batch_size`` parallel columns.

    The leading ``(len // batch_size) * batch_size`` elements are kept and any
    remainder at the end is discarded. The kept elements are cut into
    ``batch_size`` consecutive chunks, and each chunk becomes one column of the
    result, so reading down a column follows the original order.

    Args:
        data: tensor to split along its first dimension.
        batch_size: number of columns in the result.

    Returns:
        A contiguous tensor with ``batch_size`` columns, moved to the
        module-level ``device``.
    """
    rows_per_column = data.size(0) // batch_size
    usable_length = rows_per_column * batch_size

    # Drop the tail that does not fill a complete row across all columns.
    trimmed = data.narrow(0, 0, usable_length)

    # One chunk per row first, then transpose so each chunk runs down a column.
    chunks_as_rows = trimmed.view(batch_size, -1)
    columns = chunks_as_rows.transpose(0, 1).contiguous()

    return columns.to(device)

In [35]:
# Number of parallel columns used for training and for evaluation.
batch_size = 32
eval_batch_size = 16

# Replace each flat corpus tensor with its column-batched form.
# NOTE: the names are rebound in place, so this cell must be run exactly once
# after the cell that builds the flat tensors; running it again would batch
# already-batched data.
train_data, val_data, test_data = (
    split_into_batch(flat_corpus, n_columns)
    for flat_corpus, n_columns in (
        (train_data, batch_size),
        (val_data, eval_batch_size),
        (test_data, eval_batch_size),
    )
)

In [38]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to an input, then apply dropout.

    A table of shape ``(max_len, 1, d_model)`` is computed once at construction
    time and stored as the non-trainable buffer ``pos_encoding``. Even feature
    columns hold sines and odd feature columns hold cosines of
    ``position * exp(-log(10000) * k / d_model)`` for ``k = 0, 2, 4, ...``.

    Args:
        d_model: size of the feature (last) dimension of the inputs. Both even
            and odd values are supported.
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(p=dropout)

        pos_encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pos_encoding[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the angle table to match; for even d_model the slice is a no-op.
        pos_encoding[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        # Insert a singleton middle dimension so the table broadcasts over the
        # second dimension of the input in forward().
        pos_encoding = pos_encoding.unsqueeze(1)

        self.register_buffer('pos_encoding', pos_encoding)

    def forward(self, x):
        """Return ``dropout(x + encoding)``.

        Args:
            x: tensor whose first dimension indexes positions and whose last
                dimension has size ``d_model``; the first dimension must not
                exceed ``max_len``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Only the first x.size(0) positions of the precomputed table are used.
        x = x + self.pos_encoding[:x.size(0), :]
        x = self.dropout(x)

        return x

In [39]:
class TransformerModel(nn.Module):
    """Token-level model: embedding -> positional encoding -> Transformer encoder -> linear.

    Args:
        n_token: vocabulary size; also the size of the output's last dimension.
        n_input: embedding / model feature size.
        n_head: number of attention heads in each encoder layer.
        n_hidden: feed-forward width inside each encoder layer.
        n_layers: number of stacked encoder layers.
        dropout: dropout probability used by the positional encoder and the
            encoder layers.
    """

    def __init__(self, n_token, n_input, n_head, n_hidden, n_layers, dropout=0.5):
        super(TransformerModel, self).__init__()

        self.n_input = n_input
        self.model_type = 'Transformer'

        self.pos_encoder = PositionalEncoding(n_input, dropout)
        encoder_layers = TransformerEncoderLayer(n_input, n_head, n_hidden, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, n_layers)
        self.encoder = nn.Embedding(n_token, n_input)
        self.decoder = nn.Linear(n_input, n_token)

        self.init_weights()

    def generate_square_subsequent_mask(self, size):
        """Return a ``(size, size)`` float mask for causal attention.

        Entries on and below the diagonal are ``0.0``; entries strictly above
        the diagonal are ``-inf``, so position ``i`` cannot attend to any
        position ``j > i``. The mask is created on the default device; callers
        move it where needed.
        """
        return torch.triu(torch.full((size, size), float('-inf')), diagonal=1)

    def init_weights(self):
        """Initialize embedding and output weights uniformly in [-0.1, 0.1]; zero the output bias."""
        init_range = 0.1

        nn.init.uniform_(self.encoder.weight, -init_range, init_range)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -init_range, init_range)

    def forward(self, src, src_mask):
        """Compute per-position scores over the vocabulary.

        Args:
            src: integer tensor of token ids.
            src_mask: attention mask passed to the Transformer encoder, e.g.
                the result of ``generate_square_subsequent_mask``.

        Returns:
            Tensor with the shape of ``src`` plus a trailing dimension of size
            ``n_token``.
        """
        # Scale embeddings by sqrt(n_input) before the positional encoding is added.
        src = self.encoder(src) * math.sqrt(self.n_input)
        src = self.pos_encoder(src)

        # The encoder result must be bound to the same name the decoder reads;
        # the previous misspelling left `output` undefined.
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)

        return output