In [1]:
!pip install datasets



In [2]:
import math
import time
import random

import torch
import torch.nn as nn
import torch.optim as optim

import torchtext

import datasets

In [3]:
torch.manual_seed(0)
random.seed(0)

In [4]:
dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1')

Reusing dataset wikitext (/root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20)


In [5]:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

In [6]:
def tokenize_data(example, tokenizer):
    tokens = {'tokens': tokenizer(example['text'])}
    return tokens

In [7]:
tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})

Loading cached processed dataset at /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20/cache-f2efa5c633011b8b.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20/cache-7f46f3973245ea2f.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20/cache-3451e0738dc204d3.arrow


In [8]:
vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_dataset['train']['tokens'],
                                                  min_freq=3)

In [9]:
vocab.insert_token('<unk>', 0)
vocab.insert_token('<eos>', 1)

In [None]:
vocab.get_itos()[:10]

In [10]:
unk_index = vocab['<unk>']
vocab.set_default_index(unk_index)

In [11]:
def get_data(dataset, vocab, batch_size):
    data = []
    for example in dataset:
        if example['tokens']:
            tokens = example['tokens'].append('<eos>')
            tokens = [vocab[token] for token in example['tokens']]
            data.extend(tokens)
    data = torch.LongTensor(data)
    n_batches = data.shape[0] // batch_size
    data = data.narrow(0, 0, n_batches * batch_size)
    data = data.view(batch_size, -1)
    return data

In [12]:
batch_size = 80

train_data = get_data(tokenized_dataset['train'], vocab, batch_size)
valid_data = get_data(tokenized_dataset['validation'], vocab, batch_size)
test_data = get_data(tokenized_dataset['test'], vocab, batch_size)

In [13]:
class LockedDropout(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        # x = [batch size, seq len, hidden dim]
        if not self.training or not self.p:
            return x
        x = x.clone()
        mask = x.new_empty(x.shape[0], 1, x.shape[2], requires_grad=False).bernoulli_(1 - self.p)
        mask = mask.div_(1 - self.p)
        mask = mask.expand_as(x)
        return x * mask

In [14]:
def _setup_weight_drop(module, weights, dropout):
    for name_w in weights:
        w = getattr(module, name_w)
        del module._parameters[name_w]
        module.register_parameter(name_w + '_raw', nn.Parameter(w))

    original_module_forward = module.forward

    def forward(*args, **kwargs):
        for name_w in weights:
            raw_w = getattr(module, name_w + '_raw')
            w = nn.Parameter(torch.nn.functional.dropout(raw_w, p=dropout, training=module.training))
            setattr(module, name_w, w)

        return original_module_forward(*args, **kwargs)

    setattr(module, 'forward', forward)

In [15]:
class WeightDropLSTM(torch.nn.LSTM):
    """
    Wrapper around :class:`torch.nn.LSTM` that adds ``weight_dropout`` named argument.

    Args:
        weight_dropout (float): The probability a weight will be dropped.
    """

    def __init__(self, *args, weight_dropout=0.0, **kwargs):
        super().__init__(*args, **kwargs)
        weights = ['weight_hh_l' + str(i) for i in range(self.num_layers)]
        _setup_weight_drop(self, weights, weight_dropout)

In [16]:
class AWDLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers, 
                 embedding_dropout_rate, weight_dropout_rate, lstm_dropout_rate, output_dropout_rate, 
                 tie_weights):
        super().__init__()
        self.n_layers = n_layers
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.tie_weights = tie_weights
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        lstms = []
        for n in range(n_layers):
            input_dim = embedding_dim if n == 0 else hidden_dim
            output_dim = hidden_dim if n != n_layers - 1 else (embedding_dim if tie_weights else hidden_dim)
            lstm = WeightDropLSTM(input_dim, output_dim, batch_first=True, weight_dropout=weight_dropout_rate)
            lstms.append(lstm)
        self.lstms = nn.ModuleList(lstms)
        self.fc = nn.Linear(embedding_dim if tie_weights else hidden_dim, vocab_size)

        self.embedding_dropout = LockedDropout(embedding_dropout_rate)
        self.lstm_dropout = LockedDropout(lstm_dropout_rate)
        self.output_dropout = LockedDropout(output_dropout_rate)

        if tie_weights:
            self.embedding.weight = self.fc.weight

        self.init_weights()
    
    def init_weights(self):
        init_range = 0.1
        self.embedding.weight.data.uniform_(-init_range, init_range)
        self.fc.weight.data.uniform_(-init_range, init_range)
        self.fc.bias.data.zero_()

    def init_hidden(self, batch_size, device):
        hiddens = []
        for n in range(self.n_layers):
            dim = self.hidden_dim if n != n_layers - 1 else (self.embedding_dim if self.tie_weights else self.hidden_dim)
            hidden = torch.zeros(1, batch_size, dim).to(device)
            cell = torch.zeros(1, batch_size, dim).to(device)
            hiddens.append((hidden, cell))
        return hiddens

    def detach_hidden(self, hidden):
        if isinstance(hidden, torch.Tensor):
            return hidden.detach()
        else:
            return tuple(self.detach_hidden(h) for h in hidden)

    def forward(self, input, hidden):
        # input = [batch size, seq len]
        # hidden = list([1, batch size, hidden dim])
        embedding = self.embedding_dropout(self.embedding(input))
        # embedding = [batch size, seq len, embedding dim]
        lstm_input = embedding
        new_hiddens = []
        for n, lstm in enumerate(self.lstms):
            lstm_output, new_hidden = lstm(lstm_input, hidden[n])
            # lstm_output = [batch size, seq len, hidden dim]
            # new_hidden = [1, batch size, hidden dim]
            if n != self.n_layers - 1:
                lstm_output = self.lstm_dropout(lstm_output)
            lstm_input = lstm_output
            new_hiddens.append(new_hidden)
        output = self.output_dropout(lstm_output)
        prediction = self.fc(output)
        # prediction = [batch size, seq len, vocab size]
        # output = [batch size, seq len, hidden dim]
        # lstm_output = [bath size, seq len, hidden dim]
        # new_hiddens = list([1, batch size, hidden dim])
        return prediction, output, lstm_output, new_hiddens

In [17]:
vocab_size = len(vocab)
embedding_dim = 400
hidden_dim = 1150
n_layers = 3
embedding_dropout_rate = 0.65
weight_dropout_rate = 0.5
lstm_dropout_rate = 0.2
output_dropout_rate = 0.4
tie_weights = True

model = AWDLSTM(vocab_size, embedding_dim, hidden_dim, n_layers,
                embedding_dropout_rate, weight_dropout_rate, lstm_dropout_rate, output_dropout_rate, 
                tie_weights)

In [18]:
criterion = nn.CrossEntropyLoss()

In [20]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 32,030,273 trainable parameters


In [21]:
lr = 1e-3
weight_decay = 1e-6

optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

In [22]:
criterion = nn.CrossEntropyLoss()

In [23]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

cuda


In [24]:
model = model.to(device)
criterion = criterion.to(device)

In [25]:
def train(model, data, optimizer, criterion, batch_size, base_seq_len, alpha, beta, clip, device):
    
    epoch_loss = 0
    model.train()
    n_tokens = data.shape[-1]
    base_lr = optimizer.param_groups[0]["lr"]

    hidden = model.init_hidden(batch_size, device)
    
    for input, target in get_batches(data, base_seq_len):
        optimizer.zero_grad()
        input = input.to(device)
        target = target.to(device)
        # input = [batch size, seq len]
        # target = [batch size, seq len]
        batch_size, seq_len = input.shape
        scaled_lr = base_lr * seq_len / base_seq_len
        optimizer.param_groups[0]["lr"] = scaled_lr
        hidden = model.detach_hidden(hidden)
        # hidden = list([1, batch size, hidden dim])
        prediction, output, raw_output, hidden = model(input, hidden)
        # prediction = [batch size, seq len, vocab size]
        # output = [batch size, seq len, hidden dim]
        # hidden = list([1, batch size, hidden dim])
        prediction = prediction.reshape(batch_size * seq_len, -1)
        target = target.reshape(-1)
        # output = [batch size * seq len, vocab size]
        # target = [batch size * seq len]
        loss = criterion(prediction, target)
        alpha_loss = (alpha * output.pow(2).mean()).sum()
        beta_loss = (beta * (output[:,1:] - output[:,:-1]).pow(2).mean()).sum()
        loss = loss + alpha_loss + beta_loss 
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item() * seq_len
    return epoch_loss / n_tokens

In [26]:
def get_batches(data, seq_len):
    data_len = data.shape[-1]
    total_seq_len = 0
    sampled_seq_lens = []
    min_seq_len, max_seq_len = int(seq_len * 0.9), int(seq_len * 1.1)
    while total_seq_len < data_len:
        sampled_seq_len = random.randint(min_seq_len, max_seq_len)
        sampled_seq_lens.append(sampled_seq_len)
        total_seq_len += sampled_seq_len
    sampled_seq_lens = sampled_seq_lens[:-1]
    remainder = data_len - sum(sampled_seq_lens)
    if remainder > min_seq_len:
        sampled_seq_lens.append(remainder - 1)
    pos = 0
    for sampled_seq_len in sampled_seq_lens:
        input = data[:,pos:pos+sampled_seq_len]
        target = data[:,pos+1:pos+sampled_seq_len+1]
        pos += sampled_seq_len
        yield input, target

In [28]:
def evaluate(model, data, criterion, batch_size, base_seq_len, device):

    epoch_loss = 0
    model.eval()
    n_tokens = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad():
        for input, target in get_batches(data, base_seq_len):
            input = input.to(device)
            target = target.to(device)
            # input = [batch size, seq len]
            # target = [batch size, seq len]
            batch_size, seq_len = input.shape
            hidden = model.detach_hidden(hidden)
            # hidden = list([1, batch size, hidden dim])
            prediction, _, _, hidden = model(input, hidden)
            # prediction = [batch size, seq len, vocab size]
            # hidden = list([1, batch size, hidden dim])
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)
            # prediction = [batch size * seq len, vocab size]
            # target = [batch size * seq len]
            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / n_tokens

In [29]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [30]:
n_epochs = 50
seq_len = 70
clip = 0.25
alpha = 2
beta = 1

best_valid_loss = float('inf')

for epoch in range(n_epochs):

    start_time = time.monotonic()

    train_loss = train(model, train_data, optimizer, criterion, batch_size, seq_len, alpha, beta, clip, device)
    valid_loss = evaluate(model, valid_data, criterion, batch_size, seq_len, device)
    
    end_time = time.monotonic()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'awd-lstm_lm.pt')

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
    print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')

  self.dropout, self.training, self.bidirectional, self.batch_first)


Epoch: 01 | Epoch Time: 1m 27s
	Train Perplexity: 956.350
	Valid Perplexity: 372.991
Epoch: 02 | Epoch Time: 1m 27s
	Train Perplexity: 459.650
	Valid Perplexity: 269.607
Epoch: 03 | Epoch Time: 1m 27s
	Train Perplexity: 342.260
	Valid Perplexity: 217.015
Epoch: 04 | Epoch Time: 1m 27s
	Train Perplexity: 272.380
	Valid Perplexity: 174.195
Epoch: 05 | Epoch Time: 1m 27s
	Train Perplexity: 228.770
	Valid Perplexity: 152.187
Epoch: 06 | Epoch Time: 1m 27s
	Train Perplexity: 202.017
	Valid Perplexity: 143.913
Epoch: 07 | Epoch Time: 1m 27s
	Train Perplexity: 178.614
	Valid Perplexity: 127.991
Epoch: 08 | Epoch Time: 1m 27s
	Train Perplexity: 165.137
	Valid Perplexity: 130.775
Epoch: 09 | Epoch Time: 1m 27s
	Train Perplexity: 151.638
	Valid Perplexity: 126.324
Epoch: 10 | Epoch Time: 1m 27s
	Train Perplexity: 142.756
	Valid Perplexity: 117.159
Epoch: 11 | Epoch Time: 1m 27s
	Train Perplexity: 133.374
	Valid Perplexity: 116.873
Epoch: 12 | Epoch Time: 1m 27s
	Train Perplexity: 128.099
	Valid 

In [31]:
model.load_state_dict(torch.load('awd-lstm_lm.pt'))

test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)

print(f'Test Perplexity: {math.exp(test_loss):.3f}')

  self.dropout, self.training, self.bidirectional, self.batch_first)


Test Perplexity: 80.583


In [38]:
def generate(prompt, n_gen_tokens, temperature, model, tokenizer, vocab, device, seed=None):
    if seed is not None:
        torch.manual_seed(0)
    model.eval()
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens]
    batch_size = 1
    hidden = model.init_hidden(batch_size, device)
    with torch.no_grad():
        for i in range(n_gen_tokens):
            input = torch.LongTensor([indices]).to(device)
            prediction, _, _, hidden = model(input, hidden)
            probs = torch.softmax(prediction[:, -1] / temperature, dim=-1) 
            prediction = torch.multinomial(probs, num_samples=1).item()
            indices.append(prediction)

    itos = vocab.get_itos()
    tokens = [itos[i] for i in indices]
    return tokens

In [43]:
prompt = 'the'
n_gen_tokens = 25
temperature = 0.5
seed = 0

generation = generate(prompt, n_gen_tokens, temperature, model, tokenizer, vocab, device, seed)

  self.dropout, self.training, self.bidirectional, self.batch_first)


In [44]:
generation

['the',
 '<unk>',
 '<unk>',
 ',',
 'which',
 'was',
 'the',
 'first',
 'to',
 'be',
 'named',
 'the',
 '<unk>',
 '.',
 '<eos>',
 '=',
 '=',
 '=',
 'death',
 '=',
 '=',
 '=',
 '<eos>',
 'the',
 'first',
 '@-@']

In [45]:
temperature = 0.1

generation = generate(prompt, n_gen_tokens, temperature, model, tokenizer, vocab, device, seed)

  self.dropout, self.training, self.bidirectional, self.batch_first)


In [46]:
generation

['the',
 '<unk>',
 '<unk>',
 '.',
 '<eos>',
 '=',
 '=',
 '=',
 '=',
 '<unk>',
 '=',
 '=',
 '=',
 '=',
 '<eos>',
 'the',
 '<unk>',
 '<unk>',
 '<unk>',
 '(',
 '<unk>',
 ')',
 'is',
 'a',
 '<unk>',
 '<unk>']

In [47]:
temperature = 1.5

generation = generate(prompt, n_gen_tokens, temperature, model, tokenizer, vocab, device, seed)

  self.dropout, self.training, self.bidirectional, self.batch_first)


In [48]:
generation

['the',
 'rigging',
 'swap',
 '228',
 'wansel',
 'and',
 'protestants',
 'arranged',
 'discussions',
 '3',
 'agree',
 'oldman',
 'ctesiphon',
 ',',
 'blown',
 'harvest',
 'manny',
 'friday',
 'the',
 'tom',
 'sample',
 'giger',
 'viewed',
 'accommodated',
 '138',
 'paces']

In [49]:
temperature = 0.75

generation = generate(prompt, n_gen_tokens, temperature, model, tokenizer, vocab, device, seed)

  self.dropout, self.training, self.bidirectional, self.batch_first)


In [50]:
generation

['the',
 'rigging',
 'of',
 'the',
 'hell',
 '.',
 'the',
 'general',
 "'",
 's',
 'movement',
 'was',
 'a',
 'theme',
 'of',
 'the',
 'jin',
 "'",
 's',
 'tom',
 '<unk>',
 ',',
 'who',
 'gathered',
 'the',
 'invading']

In [51]:
temperature = 0.8

generation = generate(prompt, n_gen_tokens, temperature, model, tokenizer, vocab, device, seed)

  self.dropout, self.training, self.bidirectional, self.batch_first)


In [52]:
generation

['the',
 'rigging',
 '<unk>',
 'the',
 'hell',
 '.',
 'the',
 'general',
 "'",
 's',
 'movement',
 'was',
 'later',
 'retired',
 'and',
 'the',
 'jin',
 'captains',
 'had',
 'lost',
 'the',
 'opportunity',
 'to',
 'deliver',
 'the',
 'invading']

In [53]:
temperature = 0.7

generation = generate(prompt, n_gen_tokens, temperature, model, tokenizer, vocab, device, seed)

  self.dropout, self.training, self.bidirectional, self.batch_first)


In [54]:
generation

['the',
 'rigging',
 'of',
 'the',
 'hell',
 '.',
 'the',
 'general',
 "'",
 's',
 'plan',
 'was',
 'to',
 'be',
 'blown',
 'over',
 'as',
 'the',
 'embattled',
 'structure',
 'of',
 'the',
 'rear',
 ',',
 'and',
 'the']