# Step 0: Preparations and code from Assignment 1

Group 17: Jakob Svensson, Mahdi Afarideh, Maximilian Forsell

In [1]:
!git clone https://github.com/MahdiTheGreat/Intro-to-language-modeling.git
%cd Intro-to-language-modeling

Cloning into 'Intro-to-language-modeling'...
remote: Enumerating objects: 76, done.[K
remote: Counting objects: 100% (76/76), done.[K
remote: Compressing objects: 100% (75/75), done.[K
remote: Total 76 (delta 40), reused 2 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (76/76), 31.83 MiB | 9.61 MiB/s, done.
Resolving deltas: 100% (40/40), done.
/content/Intro-to-language-modeling


In [2]:
import sklearn
import nltk
import torch
import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd
from tqdm import tqdm

In [3]:
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [4]:
# Set random seed for reproducibility
def set_seed(seed=2024):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

set_seed(1998)

In [5]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))
print(f'Using device: {device}')

Using device: cuda


In [6]:
dataset='lmdemo'
zip_file = f"{dataset}.zip"
!unzip -q $zip_file
!rm $zip_file

In [7]:
training_set=open(f'{dataset}/train.txt','r',encoding='utf-8').read()
val_set=open(f'{dataset}/val.txt','r',encoding='utf-8').read()

In [8]:
from collections import Counter
class VocabularyBuilder:
    def __init__(self, max_voc_size):
        self.max_voc_size = max_voc_size
        self.str_to_int = {}
        self.int_to_str = {}
        self.special_tokens = ["BEGINNING", "END", "UNKNOWN", "PADDING"] #Added padding
        self.token_counter = Counter()

    def build_vocabulary(self, text):

        sents=nltk.word_tokenize(text.lower())

        for token in sents:
            self.token_counter[token] += 1

    def create_vocabulary(self):
        for idx, token in enumerate(self.special_tokens):
            self.str_to_int[token] = idx
            self.int_to_str[idx] = token

        max_words = self.max_voc_size - len(self.special_tokens)
        most_common_tokens = self.token_counter.most_common(max_words)

        for idx, (token, _) in enumerate(most_common_tokens, start=len(self.special_tokens)):
            self.str_to_int[token] = idx
            self.int_to_str[idx] = token

    def create_premade_vocabulary(self, c):
        for idx, token in enumerate(self.special_tokens):
            self.str_to_int[token] = idx
            self.int_to_str[idx] = token

        max_words = self.max_voc_size - len(self.special_tokens)
        most_common_tokens = c.most_common(max_words) # Here we can use a premade counter from a previous run

        for idx, (token, _) in enumerate(most_common_tokens, start=len(self.special_tokens)):
            self.str_to_int[token] = idx
            self.int_to_str[idx] = token

    def get_token_id(self, token):
        return self.str_to_int.get(token.lower(), self.str_to_int["UNKNOWN"])

    def get_token_str(self, token_id):
        return self.int_to_str.get(token_id, "UNKNOWN")

    def sanity_check(self): # Here we run the sanity tests recommended in the assignment
        assert len(self.str_to_int) <= self.max_voc_size, "Vocabulary size exceeds max_voc_size."

        for token in self.special_tokens:
            assert token in self.str_to_int, f"Missing special token: {token}"

        common_words = ["the", "and"]
        rare_words = ["cuboidal", "epiglottis"]

        for word in common_words:
            assert word in self.str_to_int, f"Common word '{word}' not in vocabulary."

        for word in rare_words:
            assert word not in self.str_to_int, f"Rare word '{word}' should not be in vocabulary."

        test_word = "the"
        token_id = self.get_token_id(test_word)
        assert self.get_token_str(token_id) == test_word.lower(), "Round-trip token mapping failed."

        print("Sanity check passed!")

vocab_builder = VocabularyBuilder(max_voc_size=16384)


In [9]:
# Run only once!
for paragraph in tqdm(training_set.splitlines()):
  vocab_builder.build_vocabulary(paragraph)
vocab_builder.create_vocabulary()

100%|██████████| 294118/294118 [01:24<00:00, 3466.54it/s]


In [10]:
# Save vocab so we don't have to rerun it
counter= vocab_builder.token_counter
with open("full_vocab", 'w') as f:
    for k,v in  counter.most_common():
        f.write( "{} {}\n".format(k,v) )

In [11]:
# Run this using full_vocab from first run
premade_counter = Counter()

with open("/content/Intro-to-language-modeling/full_vocab", 'r') as file:
    for line in file:
        parts = line.split(" ")
        if len(parts) == 2:
            word, freq = parts[0], int(parts[1])
            premade_counter[word] = freq
vocab_builder.create_premade_vocabulary(premade_counter)


In [12]:
# Perform sanity check
vocab_builder.sanity_check()

Sanity check passed!


In [13]:
# Modified for assignment 2
class TrainingDataPreparerRNN:
    def __init__(self, vocab_builder, max_sequence_length):
        self.vocab_builder = vocab_builder
        self.max = max_sequence_length

    def encode_text(self, text):
        """Tokenizes and encodes a single string with special symbols.

        Parameters:
        - text (str): The input string to encode.

        Returns:
        - List[int]: A list of token IDs including BEGINNING and END tokens.
        """
        # Tokenize the text
        tokens = nltk.word_tokenize(text.lower())

        token_ids = [self.vocab_builder.get_token_id(token) for token in tokens]
        modified_tokens = [0] # Add 1 BEGINNING
        modified_tokens.extend(token_ids)
        modified_tokens.append(1) # Add 1 END

        return modified_tokens

    def create_training_sequences(self, text):
        """
        Creates training sequences from a single string by generating sequences of length N+1.

        Parameters:
        - text (str): The input string to create sequences from.

        Returns:
        - List[Tuple[List[int], int]]: A list of (context, target) pairs.
        """
        encoded_text = self.encode_text(text)

        # Taken from: https://www.geeksforgeeks.org/break-list-chunks-size-n-python/
        training_sequences = [encoded_text[i * self.max:(i + 1) * self.max] for i in range((len(encoded_text) + self.max - 1) // self.max )]

        return training_sequences


# Step 1

In [14]:
# Splitting
preparer = TrainingDataPreparerRNN(vocab_builder, max_sequence_length=50)

training_sequences = []
split_training_set = list(filter(''.__ne__, training_set.splitlines())) # Split and remove empty lines
for paragraph in tqdm(split_training_set):
  training_sequences.append(preparer.create_training_sequences(paragraph))
flattened_training_sequences =  [
    x
    for xs in training_sequences
    for x in xs
]

100%|██████████| 147059/147059 [01:26<00:00, 1694.62it/s]


In [15]:
# Prepare validation data also
val_sequences = []
split_val_set = list(filter(''.__ne__, val_set.splitlines())) # Split and remove empty lines
for paragraph in tqdm(split_val_set):
  val_sequences.append(preparer.create_training_sequences(paragraph))
flattened_val_sequences =  [
    x
    for xs in val_sequences
    for x in xs
]

100%|██████████| 17874/17874 [00:11<00:00, 1620.33it/s]


In [16]:
# Sanity check
for context in flattened_training_sequences[:10]:  # Show the first few sequences
    print([vocab_builder.get_token_str(id) for id in context])

['BEGINNING', 'anatomy', 'END']
['BEGINNING', 'anatomy', '(', 'greek', 'UNKNOWN', ',', '“', 'dissection', '”', ')', 'is', 'the', 'branch', 'of', 'biology', 'concerned', 'with', 'the', 'study', 'of', 'the', 'structure', 'of', 'organisms', 'and', 'their', 'parts', '.', 'anatomy', 'is', 'a', 'branch', 'of', 'natural', 'science', 'dealing', 'with', 'the', 'structural', 'organization', 'of', 'living', 'things', '.', 'it', 'is', 'an', 'old', 'science', ',']
['having', 'its', 'beginnings', 'in', 'prehistoric', 'times', '.', 'anatomy', 'is', 'inherently', 'tied', 'to', 'UNKNOWN', ',', 'comparative', 'anatomy', ',', 'evolutionary', 'biology', ',', 'and', 'phylogeny', ',', 'as', 'these', 'are', 'the', 'processes', 'by', 'which', 'anatomy', 'is', 'generated', 'over', 'immediate', '(', 'UNKNOWN', ')', 'and', 'long', '(', 'evolution', ')', 'UNKNOWN', '.', 'human', 'anatomy', 'is', 'one', 'of']
['the', 'basic', 'essential', 'sciences', 'of', 'medicine', '.', 'END']
['BEGINNING', 'the', 'discipline',

In [17]:
# Sanity check nr. 2
print(len(flattened_training_sequences))
print(len(flattened_val_sequences))

323198
40499


In [18]:
#Adapted batcher
from torch.utils.data import DataLoader, TensorDataset
def TorchDataLoaderRNN(training_sequences, max_sequence_length, batch_size):
  # Padding
  padded_sequences = [sequence +([3] * (max_sequence_length - len(sequence))) for sequence in training_sequences] # PADDING has integer code 3

  # Convert lists to tensors
  context_tensor = torch.tensor(padded_sequences, dtype=torch.long)  # Shape: (num_samples, 3)

  # Create a TensorDataset
  dataset = TensorDataset(context_tensor)

  # Create a DataLoader for batching
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

  return dataloader

In [19]:
trainloader = TorchDataLoaderRNN(flattened_training_sequences, 50, 64)

In [20]:
valloader = TorchDataLoaderRNN(flattened_val_sequences, 50, 64)

In [21]:
# Sanity check
for batch_context in trainloader:
    print(batch_context[0])
    print(batch_context[0].shape)
    break

tensor([[   14,   524,     4,  ...,     1,     3,     3],
        [    2,   265,    10,  ...,   319,   787,     8],
        [   11,   678,     7,  ...,     4,  6738,  1639],
        ...,
        [ 2302,     2,    10,  ...,   121,     6,   262],
        [    0,    61,    29,  ..., 11305,    47,  2404],
        [   38,    50,   488,  ...,     3,     3,     3]])
torch.Size([64, 50])


# Step 2: RNN model

In [22]:
import torch
import torch.nn as nn
import torch.optim as optim


# EarlyStopping class remains the same
class EarlyStopping:
    def __init__(self, patience=5, delta=0, verbose=False, path='checkpoint.pth'):
        self.patience = patience  # Number of epochs to wait for improvement
        self.delta = delta  # Minimum change to qualify as an improvement
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.path = path  # Path to save the best model

    def __call__(self, val_loss, model):
        if self.best_score is None:
            self.best_score = val_loss
            self.save_checkpoint(val_loss, model)
        elif val_loss < self.best_score - self.delta:
            self.best_score = val_loss
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        '''Save model when validation loss decreases.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

class RNN(nn.Module):

    def __init__(self, vocab_size, embed_size, layer_sizes,activation=nn.ReLU,last_layer_activation=nn.Softmax,dropout=0):

        super(RNN, self).__init__()

        self.embeddings = nn.Embedding(vocab_size, embed_size)
        self.layers = nn.ModuleList()

        for i in range(len(layer_sizes)-2):
          self.layers.append(nn.LSTM(layer_sizes[i], layer_sizes[i+1]))
          self.layers.append(nn.Dropout(dropout))
          self.layers.append(activation())

        self.layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1]))
        if last_layer_activation is not None:
         self.layers.append(nn.Dropout(dropout))
         self.layers.append(last_layer_activation())

    def forward(self, x):
        embeddings = self.embeddings(x)  # Get word embeddings for each word in the batch

        # Flatten the input embeddings
        x = embeddings.view(-1, np.prod(embeddings.shape[1:]))

        x = x.float()
        for layer in self.layers:
            if isinstance(layer, nn.LSTM):
                x = layer(x)[0]
            else:
                x = layer(x)
        return x


In [24]:
model = RNN(layer_sizes=[6272, 2048, 16384], vocab_size=16384, embed_size=128)
model.to(device)
criterion = nn.CrossEntropyLoss(ignore_index=3) # Ignore padding
optimizer = optim.Adam(model.parameters(), lr=0.001)

patience = 5
early_stopping = EarlyStopping(patience=patience, verbose=True)

number_of_epochs = 50

for epoch in range(number_of_epochs):
    model.train()  # Set model to training mode
    for batch_context in tqdm(trainloader):
        #FORWARD PASS:
        X = batch_context[0][:,:-1]
        Y = batch_context[0][:,1]
        X, Y = X.to(device), Y.to(device)
        logits = model(X)  # Model output for X
        targets = Y.view(-1)                      # 2-dimensional -> 1-dimensional
        logits = logits.view(-1, logits.shape[-1])  # 3-dimensional -> 2-dimensional
        loss = criterion(logits, targets) # Compute the loss between model output and Y

        #BACKWARD PASS (updating the model parameters):
        optimizer.zero_grad()  # Clear gradients
        loss.backward()        # Compute gradients
        optimizer.step()       # Update model parameters

    print(f"Epoch [{epoch+1}/{number_of_epochs}], Loss: {loss.item():.4f}")

    # Validation loop
    model.eval()  # Set model to evaluation mode
    val_loss = 0.0
    with torch.no_grad():  # No gradient computation for validation
        for batch_context in valloader:
        #FORWARD PASS:
          X = batch_context[0][:,:-1]
          Y = batch_context[0][:,1]
          X, Y = X.to(device), Y.to(device)
          logits = model(X)  # Model output for X
          targets = Y.view(-1)                      # 2-dimensional -> 1-dimensional
          logits = logits.view(-1, logits.shape[-1])  # 3-dimensional -> 2-dimensional
          loss = criterion(logits, targets) # Compute the loss between model output and Y
          val_loss += loss.item()

    avg_val_loss = val_loss / len(valloader)  # Average validation loss
    print(f"Epoch {epoch+1}/{number_of_epochs} - Perplexity: {np.exp(avg_val_loss):.6f}")

    # Call early stopping after each epoch
    early_stopping(avg_val_loss, model)

    if early_stopping.early_stop:
        print("Early stopping triggered!")
        break

# Optionally, load the best model after training
model.load_state_dict(torch.load('checkpoint.pth'))

100%|██████████| 1563/1563 [02:08<00:00, 12.19it/s]


Epoch [1/50], Loss: 9.3220
Epoch 1/50 - Perplexity: 10561.115899
Validation loss decreased (inf --> 9.264934).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.19it/s]


Epoch [2/50], Loss: 9.1418
Epoch 2/50 - Perplexity: 9141.887963
Validation loss decreased (9.264934 --> 9.120622).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.18it/s]


Epoch [3/50], Loss: 8.9855
Epoch 3/50 - Perplexity: 8797.847437
Validation loss decreased (9.120622 --> 9.082262).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.17it/s]


Epoch [4/50], Loss: 9.1425
Epoch 4/50 - Perplexity: 8575.104197
Validation loss decreased (9.082262 --> 9.056618).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.17it/s]


Epoch [5/50], Loss: 9.0167
Epoch 5/50 - Perplexity: 8346.672884
Validation loss decreased (9.056618 --> 9.029618).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.17it/s]


Epoch [6/50], Loss: 9.0408
Epoch 6/50 - Perplexity: 8197.973643
Validation loss decreased (9.029618 --> 9.011642).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.17it/s]


Epoch [7/50], Loss: 8.9946
Epoch 7/50 - Perplexity: 8086.192458
Validation loss decreased (9.011642 --> 8.997913).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.17it/s]


Epoch [8/50], Loss: 9.0698
Epoch 8/50 - Perplexity: 7959.843709
Validation loss decreased (8.997913 --> 8.982165).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.17it/s]


Epoch [9/50], Loss: 8.9187
Epoch 9/50 - Perplexity: 7882.849640
Validation loss decreased (8.982165 --> 8.972445).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.17it/s]


Epoch [10/50], Loss: 9.0476
Epoch 10/50 - Perplexity: 7772.631272
Validation loss decreased (8.972445 --> 8.958364).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.17it/s]


Epoch [11/50], Loss: 8.9230
Epoch 11/50 - Perplexity: 7696.322012
Validation loss decreased (8.958364 --> 8.948498).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.19it/s]


Epoch [12/50], Loss: 8.8927
Epoch 12/50 - Perplexity: 7624.042373
Validation loss decreased (8.948498 --> 8.939062).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.19it/s]


Epoch [13/50], Loss: 8.9311
Epoch 13/50 - Perplexity: 7594.505649
Validation loss decreased (8.939062 --> 8.935180).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [14/50], Loss: 8.9774
Epoch 14/50 - Perplexity: 7529.155276
Validation loss decreased (8.935180 --> 8.926538).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [15/50], Loss: 8.7668
Epoch 15/50 - Perplexity: 7489.050019
Validation loss decreased (8.926538 --> 8.921197).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [16/50], Loss: 8.8468
Epoch 16/50 - Perplexity: 7459.713686
Validation loss decreased (8.921197 --> 8.917272).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [17/50], Loss: 8.8608
Epoch 17/50 - Perplexity: 7449.566681
Validation loss decreased (8.917272 --> 8.915911).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [18/50], Loss: 8.8518
Epoch 18/50 - Perplexity: 7410.085304
Validation loss decreased (8.915911 --> 8.910597).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [19/50], Loss: 8.7979
Epoch 19/50 - Perplexity: 7360.674195
Validation loss decreased (8.910597 --> 8.903907).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [20/50], Loss: 8.8067
Epoch 20/50 - Perplexity: 7339.012128
Validation loss decreased (8.903907 --> 8.900960).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.21it/s]


Epoch [21/50], Loss: 8.8293
Epoch 21/50 - Perplexity: 7341.449532
EarlyStopping counter: 1 out of 5


100%|██████████| 1563/1563 [02:08<00:00, 12.21it/s]


Epoch [22/50], Loss: 8.7980
Epoch 22/50 - Perplexity: 7315.444738
Validation loss decreased (8.900960 --> 8.897743).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [23/50], Loss: 8.7979
Epoch 23/50 - Perplexity: 7273.188920
Validation loss decreased (8.897743 --> 8.891950).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [24/50], Loss: 8.8292
Epoch 24/50 - Perplexity: 7268.474944
Validation loss decreased (8.891950 --> 8.891302).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [25/50], Loss: 8.7991
Epoch 25/50 - Perplexity: 7238.942326
Validation loss decreased (8.891302 --> 8.887230).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [26/50], Loss: 8.7050
Epoch 26/50 - Perplexity: 7251.909553
EarlyStopping counter: 1 out of 5


100%|██████████| 1563/1563 [02:08<00:00, 12.21it/s]


Epoch [27/50], Loss: 8.7667
Epoch 27/50 - Perplexity: 7232.273846
Validation loss decreased (8.887230 --> 8.886309).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.21it/s]


Epoch [28/50], Loss: 8.7364
Epoch 28/50 - Perplexity: 7229.802467
Validation loss decreased (8.886309 --> 8.885967).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [29/50], Loss: 8.7667
Epoch 29/50 - Perplexity: 7213.790645
Validation loss decreased (8.885967 --> 8.883750).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [30/50], Loss: 8.7666
Epoch 30/50 - Perplexity: 7205.643353
Validation loss decreased (8.883750 --> 8.882620).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [31/50], Loss: 8.7979
Epoch 31/50 - Perplexity: 7184.845195
Validation loss decreased (8.882620 --> 8.879729).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [32/50], Loss: 8.7354
Epoch 32/50 - Perplexity: 7195.504554
EarlyStopping counter: 1 out of 5


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [33/50], Loss: 8.7043
Epoch 33/50 - Perplexity: 7205.505086
EarlyStopping counter: 2 out of 5


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [34/50], Loss: 8.7042
Epoch 34/50 - Perplexity: 7200.551284
EarlyStopping counter: 3 out of 5


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [35/50], Loss: 8.7354
Epoch 35/50 - Perplexity: 7181.211498
Validation loss decreased (8.879729 --> 8.879223).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [36/50], Loss: 8.7175
Epoch 36/50 - Perplexity: 7147.298655
Validation loss decreased (8.879223 --> 8.874490).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.18it/s]


Epoch [37/50], Loss: 8.7354
Epoch 37/50 - Perplexity: 7141.136954
Validation loss decreased (8.874490 --> 8.873627).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.19it/s]


Epoch [38/50], Loss: 8.7667
Epoch 38/50 - Perplexity: 7119.064970
Validation loss decreased (8.873627 --> 8.870532).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.18it/s]


Epoch [39/50], Loss: 8.7042
Epoch 39/50 - Perplexity: 7133.990782
EarlyStopping counter: 1 out of 5


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [40/50], Loss: 8.7354
Epoch 40/50 - Perplexity: 7147.621151
EarlyStopping counter: 2 out of 5


100%|██████████| 1563/1563 [02:08<00:00, 12.19it/s]


Epoch [41/50], Loss: 8.7042
Epoch 41/50 - Perplexity: 7108.413663
Validation loss decreased (8.870532 --> 8.869034).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [42/50], Loss: 8.7042
Epoch 42/50 - Perplexity: 7134.516708
EarlyStopping counter: 1 out of 5


100%|██████████| 1563/1563 [02:08<00:00, 12.19it/s]


Epoch [43/50], Loss: 8.7042
Epoch 43/50 - Perplexity: 7134.201174
EarlyStopping counter: 2 out of 5


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [44/50], Loss: 8.7351
Epoch 44/50 - Perplexity: 7107.376060
Validation loss decreased (8.869034 --> 8.868888).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [45/50], Loss: 8.7042
Epoch 45/50 - Perplexity: 7111.834872
EarlyStopping counter: 1 out of 5


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [46/50], Loss: 8.7042
Epoch 46/50 - Perplexity: 7112.617610
EarlyStopping counter: 2 out of 5


100%|██████████| 1563/1563 [02:08<00:00, 12.21it/s]


Epoch [47/50], Loss: 8.7042
Epoch 47/50 - Perplexity: 7126.282333
EarlyStopping counter: 3 out of 5


100%|██████████| 1563/1563 [02:07<00:00, 12.21it/s]


Epoch [48/50], Loss: 8.7042
Epoch 48/50 - Perplexity: 7095.597942
Validation loss decreased (8.868888 --> 8.867230).  Saving model ...


100%|██████████| 1563/1563 [02:08<00:00, 12.20it/s]


Epoch [49/50], Loss: 8.7042
Epoch 49/50 - Perplexity: 7115.973756
EarlyStopping counter: 1 out of 5


100%|██████████| 1563/1563 [02:07<00:00, 12.21it/s]


Epoch [50/50], Loss: 8.7042
Epoch 50/50 - Perplexity: 7135.911362
EarlyStopping counter: 2 out of 5


  model.load_state_dict(torch.load('checkpoint.pth'))


<All keys matched successfully>

# Step 3 generating text

Regular test sentence using argmax

In [31]:
# We want padding at the start of the sentence for testing:
from torch.utils.data import DataLoader, TensorDataset
def TorchTestLoaderRNN(training_sequences, max_sequence_length, batch_size = 1):
  # Padding
  padded_sequences = [([3] * (max_sequence_length - len(sequence)-1) + sequence) for sequence in training_sequences] # PADDING has integer code 3

  # Convert lists to tensors
  context_tensor = torch.tensor(padded_sequences, dtype=torch.long)  # Shape: (num_samples, 3)

  # Create a TensorDataset
  dataset = TensorDataset(context_tensor)

  # Create a DataLoader for batching
  dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

  return dataloader


In [56]:
test_sentence = "he lives in san"

encoded_sentence = [vocab_builder.get_token_id(word) for word in test_sentence.split(" ")]

test_sentence_loader= TorchTestLoaderRNN([encoded_sentence], 50)
for sentence in test_sentence_loader:
  encoded_sentence = sentence[0][:,:]
  encoded_sentence = encoded_sentence.to(device)
  output = model(encoded_sentence).detach()

# Predict
prediction = torch.argmax(output)

print(vocab_builder.get_token_str(prediction.item()))

32


Random algorithm

In [58]:
from torch.distributions import Categorical
def random_sampling(model, prompt, max_length, temperature, topk):
    # First, encode the input
    encoded_prompt = [vocab_builder.get_token_id(word) for word in prompt.split(" ")]

    test_sentence_loader= TorchTestLoaderRNN([encoded_prompt], 50)
    for sentence in test_sentence_loader:
      encoded_sentence = sentence[0][:,:]
      encoded_sentence = encoded_sentence.to(device)
      logits = model(encoded_sentence).detach()

    # Apply temperature
    softmax = torch.nn.Softmax()
    tempered_logits = softmax(logits / temperature)

    # Apply topk
    topk_logits = torch.topk(tempered_logits, k=topk)

    # Sample from the distribution
    distribution = Categorical(logits=topk_logits[0])
    prediction = distribution.sample()
    encoded_prompt.append(prediction.item())

    end_of_sentence = (prediction == 1)
    words_generated = 1

    # Repeat with its own outputs:
    while words_generated < max_length and not end_of_sentence:
        test_sentence_loader= TorchTestLoaderRNN([encoded_prompt], 50)
        # The logits
        for sentence in test_sentence_loader:
          encoded_sentence = sentence[0][:,:]
          encoded_sentence = encoded_sentence.to(device)
          logits = model(encoded_sentence).detach()

        # Apply temperature
        tempered_logits = softmax(logits / temperature)

        # Apply topk
        topk_logits = torch.topk(tempered_logits, k=topk)

        # Sample from the distribution
        distribution = Categorical(logits=topk_logits[0])
        prediction = distribution.sample()
        encoded_prompt.append(prediction.item())

        # Check if end of sentence and update word counter
        if prediction == 1:
            end_of_sentence = True
        words_generated += 1
    return [vocab_builder.get_token_str(word) for word in encoded_prompt]

# Test it
print(random_sampling(model, "he lives in san", 30, 0.001, 10)) # Sanity check
print(random_sampling(model, "he lives in san", 30, 0.5, 5))
print(random_sampling(model, "which is very", 30, 1, 5))
print(random_sampling(model, "which is very", 30, 2, 10))

['he', 'lives', 'in', 'san', 'PADDING', '.', 'PADDING', 'BEGINNING', 'BEGINNING', 'BEGINNING', 'END']
['he', 'lives', 'in', 'san', 'PADDING', 'END']
['which', 'is', 'very', 'PADDING', 'BEGINNING', 'PADDING', 'UNKNOWN', 'the', 'the', 'UNKNOWN', 'UNKNOWN', 'PADDING', 'the', 'PADDING', 'UNKNOWN', 'END']
['which', 'is', 'very', 'END']


  return self._call_impl(*args, **kwargs)
