### Importing Libraries

In [None]:
!pip install datasets

In [3]:
import torch
import torch.nn as nn
import math
import torch.optim as optim
import torchtext
from tqdm import tqdm
import datasets # Loading publicaly dataset(by hugging face)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)

<torch._C.Generator at 0x7f2d1c778990>

In [5]:
device

device(type='cuda')

### loading Dataset

In [None]:
dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1')

In [7]:
dataset['train'][88]['text']

' This ammunition , and that which I brought with me , was rapidly prepared for use at the Laboratory established at the Little Rock Arsenal for that purpose . As illustrating as the pitiful scarcity of material in the country , the fact may be stated that it was found necessary to use public documents of the State Library for cartridge paper . Gunsmiths were employed or conscripted , tools purchased or impressed , and the repair of the damaged guns I brought with me and about an equal number found at Little Rock commenced at once . But , after inspecting the work and observing the spirit of the men I decided that a garrison 500 strong could hold out against Fitch and that I would lead the remainder - about 1500 - to Gen \'l Rust as soon as shotguns and rifles could be obtained from Little Rock instead of pikes and lances , with which most of them were armed . Two days elapsed before the change could be effected . " \n'

### Tokenizing the Dataset

In [12]:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

In [13]:
tokenize_data = lambda example, tokenizer: {'tokens': tokenizer(example['text'])}  

In [15]:
tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [17]:
print(tokenized_dataset['train'][88]['tokens'])

['this', 'ammunition', ',', 'and', 'that', 'which', 'i', 'brought', 'with', 'me', ',', 'was', 'rapidly', 'prepared', 'for', 'use', 'at', 'the', 'laboratory', 'established', 'at', 'the', 'little', 'rock', 'arsenal', 'for', 'that', 'purpose', '.', 'as', 'illustrating', 'as', 'the', 'pitiful', 'scarcity', 'of', 'material', 'in', 'the', 'country', ',', 'the', 'fact', 'may', 'be', 'stated', 'that', 'it', 'was', 'found', 'necessary', 'to', 'use', 'public', 'documents', 'of', 'the', 'state', 'library', 'for', 'cartridge', 'paper', '.', 'gunsmiths', 'were', 'employed', 'or', 'conscripted', ',', 'tools', 'purchased', 'or', 'impressed', ',', 'and', 'the', 'repair', 'of', 'the', 'damaged', 'guns', 'i', 'brought', 'with', 'me', 'and', 'about', 'an', 'equal', 'number', 'found', 'at', 'little', 'rock', 'commenced', 'at', 'once', '.', 'but', ',', 'after', 'inspecting', 'the', 'work', 'and', 'observing', 'the', 'spirit', 'of', 'the', 'men', 'i', 'decided', 'that', 'a', 'garrison', '500', 'strong', 'co

### Constructing The Vocabulary

In [23]:
vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_dataset['train']['tokens'], min_freq=3) 

In [24]:
vocab.insert_token('<unk>', 0)  

In [25]:
vocab.insert_token('<eos>', 1)            

In [26]:
vocab.set_default_index(vocab['<unk>'])   

In [27]:
print(len(vocab))                         

29473


In [28]:
print(vocab.get_itos()[:10])              

['<unk>', '<eos>', 'the', ',', '.', 'of', 'and', 'in', 'to', 'a']


### Implementing DataLoader

In [29]:
def get_data(dataset, vocab, batch_size):
    data = []                                                   
    for example in dataset:
        if example['tokens']:                                      
            tokens = example['tokens'].append('<eos>')             
            tokens = [vocab[token] for token in example['tokens']] 
            data.extend(tokens)                                    
    data = torch.LongTensor(data)                                 
    num_batches = data.shape[0] // batch_size 
    data = data[:num_batches * batch_size]                       
    data = data.view(batch_size, num_batches)          
    return data

#### Splitting the Data

In [30]:
batch_size = 128

In [31]:
train_data = get_data(tokenized_dataset['train'], vocab, batch_size)
valid_data = get_data(tokenized_dataset['validation'], vocab, batch_size)
test_data = get_data(tokenized_dataset['test'], vocab, batch_size)

### Defining LSTM Model Architecture

In [42]:
class LSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate, 
                tie_weights):
                
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, 
                    dropout=dropout_rate, batch_first=True)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        
        if tie_weights:
            assert embedding_dim == hidden_dim, 'cannot tie, check dims'
            self.embedding.weight = self.fc.weight
        self.init_weights()

    def forward(self, src, hidden):
        embedding = self.dropout(self.embedding(src))
        output, hidden = self.lstm(embedding, hidden)          
        output = self.dropout(output) 
        prediction = self.fc(output)
        return prediction, hidden

    def init_weights(self):
        init_range_emb = 0.1
        init_range_other = 1/math.sqrt(self.hidden_dim)
        self.embedding.weight.data.uniform_(-init_range_emb, init_range_emb)
        self.fc.weight.data.uniform_(-init_range_other, init_range_other)
        self.fc.bias.data.zero_()
        for i in range(self.num_layers):
            self.lstm.all_weights[i][0] = torch.FloatTensor(self.embedding_dim,
                    self.hidden_dim).uniform_(-init_range_other, init_range_other) 
            self.lstm.all_weights[i][1] = torch.FloatTensor(self.hidden_dim, 
                    self.hidden_dim).uniform_(-init_range_other, init_range_other) 

    def init_hidden(self, batch_size, device):
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        cell = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)
        return hidden, cell
      

    def detach_hidden(self, hidden):
        hidden, cell = hidden
        hidden = hidden.detach()
        cell = cell.detach()
        return hidden, cell
        

### Hyperparameter Tuning & Model Initialization

In [43]:
vocab_size = len(vocab)
embedding_dim = 1024             # 400 in the paper
hidden_dim = 1024                # 1150 in the paper
num_layers = 2                   # 3 in the paper
dropout_rate = 0.65              
tie_weights = True                  
lr = 1e-3                        # They used 30 and a different optimizer

In [44]:
model = LSTM(vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate, tie_weights).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

The model has 47,003,425 trainable parameters


In [45]:
def get_batch(data, seq_len, num_batches, idx):
    src = data[:, idx:idx+seq_len]                   
    target = data[:, idx+1:idx+seq_len+1]             
    return src, target

### Model Training & Evaluating

In [46]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    
    epoch_loss = 0
    model.train()
    
    # drop all batches that are not a multiple of seq_len
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)
    
    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):  # The last batch can't be a src
        optimizer.zero_grad()
        hidden = model.detach_hidden(hidden)

        src, target = get_batch(data, seq_len, num_batches, idx)
        src, target = src.to(device), target.to(device)
        batch_size = src.shape[0]
        prediction, hidden = model(src, hidden)               

        prediction = prediction.reshape(batch_size * seq_len, -1)   
        target = target.reshape(-1)
        loss = criterion(prediction, target)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [47]:
def evaluate(model, data, criterion, batch_size, seq_len, device):

    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            hidden = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, num_batches, idx)
            src, target = src.to(device), target.to(device)
            batch_size= src.shape[0]

            prediction, hidden = model(src, hidden)
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [48]:
n_epochs = 50
seq_len = 50
clip = 0.25
saved = False

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

if saved:
    model.load_state_dict(torch.load('best-val-lstm_lm.pt',  map_location=device))
    test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)
    print(f'Test Perplexity: {math.exp(test_loss):.3f}')
else:
    best_valid_loss = float('inf')

    for epoch in range(n_epochs):
        train_loss = train(model, train_data, optimizer, criterion, 
                    batch_size, seq_len, clip, device)
        valid_loss = evaluate(model, valid_data, criterion, batch_size, 
                    seq_len, device)
        
        lr_scheduler.step(valid_loss)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'best-val-lstm_lm.pt')

        print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
        print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')



	Train Perplexity: 851.853
	Valid Perplexity: 3196.476




	Train Perplexity: 446.340
	Valid Perplexity: 279.173




	Train Perplexity: 303.284
	Valid Perplexity: 232.582




	Train Perplexity: 242.797
	Valid Perplexity: 196.832




	Train Perplexity: 204.218
	Valid Perplexity: 176.652




	Train Perplexity: 176.613
	Valid Perplexity: 158.888




	Train Perplexity: 155.544
	Valid Perplexity: 152.377




	Train Perplexity: 140.409
	Valid Perplexity: 141.557




	Train Perplexity: 127.846
	Valid Perplexity: 138.213




	Train Perplexity: 117.952
	Valid Perplexity: 133.010




	Train Perplexity: 109.634
	Valid Perplexity: 127.670




	Train Perplexity: 102.673
	Valid Perplexity: 124.481




	Train Perplexity: 96.929
	Valid Perplexity: 121.522




	Train Perplexity: 91.941
	Valid Perplexity: 119.525




	Train Perplexity: 87.743
	Valid Perplexity: 117.351




	Train Perplexity: 83.706
	Valid Perplexity: 117.142




	Train Perplexity: 80.249
	Valid Perplexity: 115.916




	Train Perplexity: 77.013
	Valid Perplexity: 114.979




	Train Perplexity: 74.261
	Valid Perplexity: 114.210




	Train Perplexity: 71.945
	Valid Perplexity: 112.512




	Train Perplexity: 69.695
	Valid Perplexity: 112.143




	Train Perplexity: 67.745
	Valid Perplexity: 110.570




	Train Perplexity: 65.864
	Valid Perplexity: 110.658




	Train Perplexity: 61.928
	Valid Perplexity: 110.520




	Train Perplexity: 59.596
	Valid Perplexity: 110.328




	Train Perplexity: 58.494
	Valid Perplexity: 109.484




	Train Perplexity: 57.683
	Valid Perplexity: 109.155




	Train Perplexity: 57.023
	Valid Perplexity: 109.067




	Train Perplexity: 56.261
	Valid Perplexity: 108.544




	Train Perplexity: 55.614
	Valid Perplexity: 109.064




	Train Perplexity: 54.977
	Valid Perplexity: 108.432




	Train Perplexity: 54.548
	Valid Perplexity: 108.229




	Train Perplexity: 54.337
	Valid Perplexity: 108.248




	Train Perplexity: 54.764
	Valid Perplexity: 108.618




	Train Perplexity: 55.211
	Valid Perplexity: 107.415




	Train Perplexity: 54.989
	Valid Perplexity: 107.086




	Train Perplexity: 54.633
	Valid Perplexity: 107.346




	Train Perplexity: 55.464
	Valid Perplexity: 106.919




	Train Perplexity: 55.246
	Valid Perplexity: 106.804




	Train Perplexity: 55.144
	Valid Perplexity: 106.829




	Train Perplexity: 55.821
	Valid Perplexity: 106.616




	Train Perplexity: 55.526
	Valid Perplexity: 106.514




	Train Perplexity: 55.476
	Valid Perplexity: 106.500




	Train Perplexity: 55.752
	Valid Perplexity: 106.440




	Train Perplexity: 55.821
	Valid Perplexity: 106.377




	Train Perplexity: 55.836
	Valid Perplexity: 106.334




	Train Perplexity: 55.803
	Valid Perplexity: 106.323




	Train Perplexity: 55.885
	Valid Perplexity: 106.293




	Train Perplexity: 56.023
	Valid Perplexity: 106.276




	Train Perplexity: 56.136
	Valid Perplexity: 106.268


### Inference

In [49]:
def generate(prompt, max_seq_len, temperature, model, tokenizer, vocab, device, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens]
    batch_size = 1
    hidden = model.init_hidden(batch_size, device)
    with torch.no_grad():
        for i in range(max_seq_len):
            src = torch.LongTensor([indices]).to(device)
            prediction, hidden = model(src, hidden)
            probs = torch.softmax(prediction[:, -1] / temperature, dim=-1)  
            prediction = torch.multinomial(probs, num_samples=1).item()    
            
            while prediction == vocab['<unk>']:
                prediction = torch.multinomial(probs, num_samples=1).item()

            if prediction == vocab['<eos>']:
                break

            indices.append(prediction)

    itos = vocab.get_itos()
    tokens = [itos[i] for i in indices]
    return tokens

In [52]:
prompt = 'Who is your'
max_seq_len = 30
seed = 0

temperatures = [0.5, 0.7, 0.75, 0.8, 1.0]
for temperature in temperatures:
    generation = generate(prompt, max_seq_len, temperature, model, tokenizer, 
                          vocab, device, seed)
    print(str(temperature)+'\n'+' '.join(generation)+'\n')

0.5
who is your own .

0.7
who is your teacher .

0.75
who is your teacher .

0.8
who is your teacher .

1.0
who is your teacher . just though i knew i ' m agree that they have to go as the death tom sample , you ' re not going to leave . however

