In [1]:
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext, datasets, math
from tqdm import tqdm

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# making this program comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# torch.cuda.get_device_name(0)

cuda


# 1. Loading Data - Wiki Text

In [4]:
# import os
# os.environ['http_proxy']  = 'http://192.41.170.23:3128'
# os.environ['https_proxy'] = 'http://192.41.170.23:3128'

# there are raw and preprocessed version; we used the raw one and preprocessed ourselves for fun
dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1')
print(dataset)



  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})


In [5]:
print(dataset['train'][333]['text'])

'''
If you try to change the index you might notice that sometimes there is no paragraph 
and rather an empty string so we will have to care of that later.
'''

 During the same time frame as the Hitchcock rumors , goaltender Curtis Sanford returned from his groin injury on November 13 . He made his first start of the season against the Boston Bruins , losing 2 – 1 in a shootout . Sanford continued his strong play , posting a 3 – 1 – 2 record , 1 @.@ 38 goals against average and .947 save percentage over his next six games . Sanford started 12 consecutive games before Steve Mason made his next start . The number of starts might not have been as numerous , but prior to the November 23 game , Mason was hit in the head by a shot from Rick Nash during pre @-@ game warm @-@ ups and suffered a concussion . Mason returned from his concussion after two games , making a start against the Vancouver Canucks . Mason allowed only one goal in the game despite suffering from cramping in the third period , temporarily being replaced by Sanford for just over three minutes . Columbus won the game 2 – 1 in a shootout , breaking a nine @-@ game losing streak to t

'\nIf you try to change the index you might notice that sometimes there is no paragraph \nand rather an empty string so we will have to care of that later.\n'

# 2. Preprocessing

### Tokenizing

In [8]:
# tokenization
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

# function to tokenize
tokenize_data = lambda example, tokenizer: {'tokens': tokenizer(example['text'])}  

# mapping the function to each example
tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})
print(tokenized_dataset['train'][333]['tokens'])



['during', 'the', 'same', 'time', 'frame', 'as', 'the', 'hitchcock', 'rumors', ',', 'goaltender', 'curtis', 'sanford', 'returned', 'from', 'his', 'groin', 'injury', 'on', 'november', '13', '.', 'he', 'made', 'his', 'first', 'start', 'of', 'the', 'season', 'against', 'the', 'boston', 'bruins', ',', 'losing', '2', '–', '1', 'in', 'a', 'shootout', '.', 'sanford', 'continued', 'his', 'strong', 'play', ',', 'posting', 'a', '3', '–', '1', '–', '2', 'record', ',', '1', '@', '.', '@', '38', 'goals', 'against', 'average', 'and', '.', '947', 'save', 'percentage', 'over', 'his', 'next', 'six', 'games', '.', 'sanford', 'started', '12', 'consecutive', 'games', 'before', 'steve', 'mason', 'made', 'his', 'next', 'start', '.', 'the', 'number', 'of', 'starts', 'might', 'not', 'have', 'been', 'as', 'numerous', ',', 'but', 'prior', 'to', 'the', 'november', '23', 'game', ',', 'mason', 'was', 'hit', 'in', 'the', 'head', 'by', 'a', 'shot', 'from', 'rick', 'nash', 'during', 'pre', '@-@', 'game', 'warm', '@-@

### Numericalizing

In [9]:
# numericalization
vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_dataset['train']['tokens'], 
min_freq=3) 
vocab.insert_token('<unk>', 0)           
vocab.insert_token('<eos>', 1)            
vocab.set_default_index(vocab['<unk>'])   
print(len(vocab))                         
print(vocab.get_itos()[:10])  

29473
['<unk>', '<eos>', 'the', ',', '.', 'of', 'and', 'in', 'to', 'a']


# 3. Preparing the BatchLoader

### Prepare Data

In [11]:
def get_data(dataset, vocab, batch_size):
    data = []                                                   
    for example in dataset:
        if example['tokens']:         
            # appending eos so we know it ends....so model learn how to end...                             
            tokens = example['tokens'].append('<eos>')   
            # numericalization          
            tokens = [vocab[token] for token in example['tokens']] 
            data.extend(tokens)                                    
    data = torch.LongTensor(data)                                 
    num_batches = data.shape[0] // batch_size # getting the int number of batches...
    data = data[:num_batches * batch_size] # making the batch evenly, and cut out any remaining                      
    data = data.view(batch_size, num_batches)          
    return data # [batch size, bunch of tokens]

In [12]:
batch_size = 128
train_data = get_data(tokenized_dataset['train'], vocab, batch_size)
valid_data = get_data(tokenized_dataset['validation'], vocab, batch_size)
test_data  = get_data(tokenized_dataset['test'], vocab, batch_size)

# 4. Modeling

In [13]:
class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers, dropout_rate):
                
        super().__init__()
        self.num_layers = num_layers
        self.hid_dim = hid_dim
        self.emb_dim = emb_dim

        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, 
                    dropout=dropout_rate, batch_first=True)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hid_dim, vocab_size)
        
        self.init_weights()
        
    def init_weights(self):
        init_range_emb = 0.1
        init_range_other = 1/math.sqrt(self.hid_dim)
        self.embedding.weight.data.uniform_(-init_range_emb, init_range_emb)
        self.fc.weight.data.uniform_(-init_range_other, init_range_other)
        self.fc.bias.data.zero_()
        for i in range(self.num_layers):
            self.lstm.all_weights[i][0] = torch.FloatTensor(self.emb_dim,
                    self.hid_dim).uniform_(-init_range_other, init_range_other) 
            self.lstm.all_weights[i][1] = torch.FloatTensor(self.hid_dim, 
                    self.hid_dim).uniform_(-init_range_other, init_range_other) 

    def init_hidden(self, batch_size, device):
        hidden = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        cell   = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        return hidden, cell
    
    def detach_hidden(self, hidden):
        hidden, cell = hidden
        hidden = hidden.detach()
        cell = cell.detach()
        return hidden, cell

        
    def forward(self, src, hidden):
        # src: [batch size, seq len]
        embedding = self.dropout(self.embedding(src))
        # embedding: [batch size, seq len, emb_dim]
        output, hidden = self.lstm(embedding, hidden)      
        # output: [batch size, seq len, hid_dim]
        # hidden = h, c = [num_layers * direction, seq len, hid_dim)
        output = self.dropout(output) 
        prediction = self.fc(output)
        # prediction: [batch size, seq_len, vocab size]
        return prediction, hidden

# 5. Training

In [14]:
vocab_size = len(vocab)
emb_dim = 1024                # 400 in the paper
hid_dim = 1024                # 1150 in the paper
num_layers = 2                # 3 in the paper
dropout_rate = 0.65              
lr = 1e-3             

In [15]:
model = LSTMLanguageModel(vocab_size, emb_dim, hid_dim, num_layers, dropout_rate).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

The model has 77,183,777 trainable parameters


In [17]:
def get_batch(data, seq_len, idx):
    # data # [batch size, bunch of tokens]
    src    = data[:, idx:idx+seq_len]                   
    target = data[:, idx+1:idx+seq_len+1]  # target simply is ahead of src by 1            
    return src, target

In [18]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    
    epoch_loss = 0
    model.train()
    # drop all batches that are not a multiple of seq_len
    # data # [batch size, bunch of tokens]
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]  # we need to -1 because we start at 0
    num_batches = data.shape[-1]
    
    # reseting the hidden every epoch
    hidden = model.init_hidden(batch_size, device)
    
    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):
        optimizer.zero_grad()
        
        # hidden does not need to be in the computational graph for efficiency
        hidden = model.detach_hidden(hidden)

        src, target = get_batch(data, seq_len, idx) # src, target: [batch size, seq len]
        src, target = src.to(device), target.to(device)
        batch_size = src.shape[0]
        prediction, hidden = model(src, hidden)               

        # need to reshape because criterion expects pred to be 2d and target to be 1d
        prediction = prediction.reshape(batch_size * seq_len, -1)  # prediction: [batch size * seq len, vocab size]  
        target = target.reshape(-1)
        loss = criterion(prediction, target)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [19]:
def evaluate(model, data, criterion, batch_size, seq_len, device):

    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            hidden = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, idx)
            src, target = src.to(device), target.to(device)
            batch_size= src.shape[0]

            prediction, hidden = model(src, hidden)
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [20]:
n_epochs = 50
seq_len  = 50 # <----decoding length
clip    = 0.25

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

best_valid_loss = float('inf')

for epoch in range(n_epochs):
    train_loss = train(model, train_data, optimizer, criterion, 
                batch_size, seq_len, clip, device)
    valid_loss = evaluate(model, valid_data, criterion, batch_size, 
                seq_len, device)

    lr_scheduler.step(valid_loss)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best-val-lstm_lm.pt')

    print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
    print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')



	Train Perplexity: 855.537
	Valid Perplexity: 454.836




	Train Perplexity: 417.077
	Valid Perplexity: 284.078




	Train Perplexity: 303.390
	Valid Perplexity: 239.372




	Train Perplexity: 249.412
	Valid Perplexity: 214.365




	Train Perplexity: 212.226
	Valid Perplexity: 199.935




	Train Perplexity: 183.642
	Valid Perplexity: 179.601




	Train Perplexity: 162.090
	Valid Perplexity: 168.163




	Train Perplexity: 145.432
	Valid Perplexity: 159.212




	Train Perplexity: 132.125
	Valid Perplexity: 154.703




	Train Perplexity: 120.825
	Valid Perplexity: 150.575




	Train Perplexity: 111.898
	Valid Perplexity: 146.667




	Train Perplexity: 103.964
	Valid Perplexity: 144.243




	Train Perplexity: 97.229
	Valid Perplexity: 143.917




	Train Perplexity: 91.508
	Valid Perplexity: 143.385




	Train Perplexity: 86.274
	Valid Perplexity: 141.730




	Train Perplexity: 81.964
	Valid Perplexity: 139.952




	Train Perplexity: 77.800
	Valid Perplexity: 137.541




	Train Perplexity: 74.186
	Valid Perplexity: 137.217




	Train Perplexity: 71.241
	Valid Perplexity: 138.234




	Train Perplexity: 66.134
	Valid Perplexity: 138.012




	Train Perplexity: 63.357
	Valid Perplexity: 138.002




	Train Perplexity: 62.434
	Valid Perplexity: 136.502




	Train Perplexity: 61.910
	Valid Perplexity: 135.423




	Train Perplexity: 61.361
	Valid Perplexity: 135.076




	Train Perplexity: 60.714
	Valid Perplexity: 134.792




	Train Perplexity: 59.886
	Valid Perplexity: 134.513




	Train Perplexity: 59.129
	Valid Perplexity: 134.640




	Train Perplexity: 59.002
	Valid Perplexity: 135.290




	Train Perplexity: 59.428
	Valid Perplexity: 134.578




	Train Perplexity: 60.350
	Valid Perplexity: 132.425




	Train Perplexity: 60.973
	Valid Perplexity: 133.736




	Train Perplexity: 61.965
	Valid Perplexity: 132.810




	Train Perplexity: 62.798
	Valid Perplexity: 132.254




	Train Perplexity: 62.762
	Valid Perplexity: 132.014




	Train Perplexity: 62.858
	Valid Perplexity: 131.933




	Train Perplexity: 62.647
	Valid Perplexity: 132.110




	Train Perplexity: 63.129
	Valid Perplexity: 132.138




	Train Perplexity: 63.275
	Valid Perplexity: 132.103




	Train Perplexity: 63.382
	Valid Perplexity: 132.078




	Train Perplexity: 63.362
	Valid Perplexity: 132.065




	Train Perplexity: 63.453
	Valid Perplexity: 132.056




	Train Perplexity: 63.364
	Valid Perplexity: 132.052




	Train Perplexity: 63.164
	Valid Perplexity: 132.049




	Train Perplexity: 63.308
	Valid Perplexity: 132.048




	Train Perplexity: 63.310
	Valid Perplexity: 132.047




	Train Perplexity: 63.348
	Valid Perplexity: 132.046




	Train Perplexity: 63.375
	Valid Perplexity: 132.045




	Train Perplexity: 63.391
	Valid Perplexity: 132.044




	Train Perplexity: 63.392
	Valid Perplexity: 132.042




	Train Perplexity: 63.459
	Valid Perplexity: 132.042


# 6. Testing

In [21]:
model.load_state_dict(torch.load('best-val-lstm_lm.pt',  map_location=device))
test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)
print(f'Test Perplexity: {math.exp(test_loss):.3f}')

Test Perplexity: 124.413


# 7. Real-world inference

In [22]:
def generate(prompt, max_seq_len, temperature, model, tokenizer, vocab, device, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens]
    batch_size = 1
    hidden = model.init_hidden(batch_size, device)
    with torch.no_grad():
        for i in range(max_seq_len):
            src = torch.LongTensor([indices]).to(device)
            prediction, hidden = model(src, hidden)
            
            #prediction: [batch size, seq len, vocab size]
            #prediction[:, -1]: [batch size, vocab size] #probability of last vocab
            
            probs = torch.softmax(prediction[:, -1] / temperature, dim=-1)  
            prediction = torch.multinomial(probs, num_samples=1).item()    
            
            while prediction == vocab['<unk>']: #if it is unk, we sample again
                prediction = torch.multinomial(probs, num_samples=1).item()

            if prediction == vocab['<eos>']:    #if it is eos, we stop
                break

            indices.append(prediction) #autoregressive, thus output becomes input

    itos = vocab.get_itos()
    tokens = [itos[i] for i in indices]
    return tokens

In [23]:
prompt = 'Harry Potter is '
max_seq_len = 30
seed = 0

#smaller the temperature, more diverse tokens but comes 
#with a tradeoff of less-make-sense sentence
temperatures = [0.5, 0.7, 0.75, 0.8, 1.0]
for temperature in temperatures:
    generation = generate(prompt, max_seq_len, temperature, model, tokenizer, 
                          vocab, device, seed)
    print(str(temperature)+'\n'+' '.join(generation)+'\n')

0.5
harry potter is a part of the series ' s final games . the game was released in the united states on october 25 , 2002 , and was released on july 22

0.7
harry potter is a part of his bachelor ' s domestic movement . the first of the first two years were married , the 7th dynasty ( and later the scottish union )

0.75
harry potter is a part of his bachelor ' s domestic movement . the cathedral is a major figure in the commercial , and the site ' s status is added to this

0.8
harry potter is a part of his bachelor ' s domestic movement . a theme of the assassination of the goat is sung by the fact that he is a supporter of the

1.0
harry potter is said to have been the most magnificent 3 – 11 with the friend

