In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
import math

from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader, ConcatDataset
from matplotlib.lines import Line2D
from collections import Counter
from datetime import datetime
from utils import train, compute_accuracy

seed = 265
torch.manual_seed(seed)

device = (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
print(f"Training on device {device}.")

Training on device cpu.


In [2]:
# ----------------------- Load Vocab, Embedding and Texts ----------------------------
PATH_GENERATED = './generated/'

vocab     = torch.load(PATH_GENERATED + 'vocabulary.pt', map_location=torch.device(device))
embedding = torch.load(PATH_GENERATED + 'embedding.pt', map_location=torch.device(device))
(VOCAB_SIZE, embedding_dim) = embedding.weight.shape  

words_train = torch.load(PATH_GENERATED + "words_train.pt", map_location=torch.device(device))
words_val   = torch.load(PATH_GENERATED + "words_val.pt", map_location=torch.device(device))
words_test  = torch.load(PATH_GENERATED + "words_test.pt", map_location=torch.device(device))

In [3]:
# ----------------------- Datasets ----------------------------
CONTEXT_SIZE = 10
not_words = [',', '.', '(', ')', '?', '!', '<unk>']
not_words_idx = vocab.lookup_indices(not_words)

def create_dataset(text, vocab, context_size=CONTEXT_SIZE, dataset_limit=50):
    """
    Create targets-contexts pairs where the targets are valid words from the given text up to dataset_limit.
    """
    txt = [vocab[w] for w in text]
    n_text = len(text)
    word_bank = {}

    contexts = []
    targets = []
    for i in range(n_text - context_size):
        
        t = txt[i + context_size]
        w = vocab.lookup_token(t)

        if w in not_words: continue
        if w not in word_bank: word_bank[w] = 0
        if word_bank[w] > dataset_limit: continue
        word_bank[w] += 1

        c = txt[i:i + context_size]
        
        targets.append(t) 
        contexts.append(torch.tensor(c).to(device=device))
            
    contexts = torch.stack(contexts)
    targets = torch.tensor(targets).to(device=device)
    return TensorDataset(contexts, targets)

def load_dataset(words, vocab, fname):
    """
    Load dataset if its already generated, otherwise, create it and save it
    """
    
    if os.path.isfile(PATH_GENERATED + fname):
        dataset = torch.load(PATH_GENERATED + fname, map_location=torch.device(device))
    else:
        dataset = create_dataset(words, vocab)
        torch.save(dataset, PATH_GENERATED + fname)
    return dataset

data_train_gen = load_dataset(words_train, vocab, "gen_data_train.pt")
data_val_gen   = load_dataset(words_val, vocab, "gen_data_val.pt")
data_test_gen  = load_dataset(words_test, vocab, "gen_data_test.pt")

In [4]:
# ----------------------- Balance the training set ----------------------------
def count_freqs(words, vocab):
    freqs = torch.zeros(len(vocab), dtype=torch.int)
    for w in words:
        freqs[vocab[w]] += 1
    return freqs

def calculate_word_weights(freqs):
    """
    Calculate the weight of each word so that the loss function can weigh 
    frequent words less and unfrequent words more.
    """
    total_words = sum(freqs)
    word_weights = [total_words / (len(freqs)* freq) for freq in freqs]
    word_weights = torch.tensor(word_weights, dtype=torch.float).to(device=device)
    return word_weights

target_words_idx = data_train_gen[:][1].tolist()
target_words = [vocab.lookup_token(i) for i in target_words_idx]
freqs = count_freqs(target_words, vocab)
word_weigts = calculate_word_weights(freqs)

In [5]:
# ---------------- RNN hyper-parameters ----------------------- 
lrs = [0.01, 0.001]
rnn_hparams = [{
    'lr': lr,
} for lr in lrs]

# ---------------- Dataloaders -----------------------
train_loader = DataLoader(data_train_gen, batch_size=128, shuffle=True)
val_loader = DataLoader(data_val_gen, batch_size=128, shuffle=True)
test_loader = DataLoader(data_test_gen, batch_size=128, shuffle=True)

In [6]:
# ----------------------- RNN model ----------------------------
class RNN(nn.Module):
    def __init__(self, embedding):
        super().__init__()

        self.embedding = nn.Embedding(VOCAB_SIZE, embedding_dim)
        self.embedding.load_state_dict(embedding.state_dict())
        for p in self.embedding.parameters():
            p.requires_grad = False

        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=VOCAB_SIZE, batch_first=True)

    def forward(self, x):
        x = self.embedding(x)
        out, (h_n, c_n) = self.lstm(x)
        return h_n[-1]

In [7]:
# ----------------------- Train all models ----------------------------
def train_all_models():
    models = []
    train_losses = []
    accuracies = []
    params = []
    print("Now training a RNN model")
    for param in rnn_hparams:
        print(f"Training using parameters: {param}")
        torch.manual_seed(seed)
        model = RNN(embedding)
        loss_fn = nn.CrossEntropyLoss(weight=word_weigts)
        optimizer = optim.Adam(model.parameters(), **param)
        n_epochs = 5
        
        loss = train(n_epochs, optimizer, model, loss_fn, train_loader, device, yield_tokens=not_words_idx)
        accuracy = compute_accuracy(model, val_loader, device)

        models.append(model)
        train_losses.append(loss)
        accuracies.append(accuracy)
        params.append(param)
        print()
    return models, train_losses, accuracies, params

# ----------------------- Select the best model ----------------------------
def select_best_model(models, accuracies, params):
    best_idx = accuracies.index(max(accuracies))
    best_model = models[best_idx]
    best_param = params[best_idx]

    # ----------------------- Retrain the best model ----------------------------
    print(f"Training using parameters: {best_param}")
    loss_fn = nn.CrossEntropyLoss(weight=word_weigts)
    optimizer = optim.Adam(best_model.parameters(), **best_param)
    n_epochs = 20

    loss = train(n_epochs, optimizer, best_model, loss_fn, train_loader, device, yield_tokens=not_words_idx)
    accuracy = compute_accuracy(best_model, test_loader, device)

    return best_model, loss, accuracy

In [8]:
# ----------------------- Best Model -------------------------------
if os.path.isfile(PATH_GENERATED + 'best_model_generation.pt'):
    best_model = torch.load(PATH_GENERATED + 'best_model_generation.pt', map_location=torch.device(device))
    accuracy = torch.load(PATH_GENERATED + 'best_model_generation_accuracy.pt', map_location=torch.device(device))
else:
    models, train_losses, val_accs, params = train_all_models()
    best_model, best_model_loss, accuracy = select_best_model(models, val_accs, params)
    torch.save(best_model, PATH_GENERATED + 'best_model_generation.pt')
    torch.save(accuracy, PATH_GENERATED + 'best_model_generation_accuracy.pt')

print(f"The accuracy of the best model on the test set is {round(accuracy, 2)}")

The accuracy of the best model on the test set is 0.01


In [9]:
# ----------------------- Predicting sentence with Beam Search -------------------------------
def beam_search(input, n_output, model):
    """
    Return the n_output best predictions from the model given input.
    """
    input = torch.tensor(vocab.lookup_indices(input))
    output = model(input)
    
    out = torch.topk(output, n_output)
    return out

def predict_sentence(input, beam_width, n_predict, model):
    """
    Predicts the n_predict next words of the model given input using beam search with given beam width.
    """
    input = [input.split()]
    for _ in range(n_predict):
        values = torch.empty(0)
        indices = torch.empty(0)
        sentences = []
        for sentence in input:
            prediction = beam_search(sentence, beam_width, model)
            values = torch.cat((values, prediction.values))
            indices = torch.cat((indices, prediction.indices))
            sentences.append(sentence)

        new_input = []
        best_predictions = torch.topk(values, beam_width)
        origin_sen = [int(idx/beam_width) for idx in best_predictions.indices]
        words = [vocab.lookup_token(indices[idx]) for idx in best_predictions.indices]
        for i, word in enumerate(words):
            new_sentence = sentences[origin_sen[i]].copy()
            new_sentence.append(word)
            new_input.append(new_sentence)
        input = new_input
    return " ".join(input[0])

In [10]:
# ----------------------- Predicting a few sentences -------------------------------
print(f"Testing sentence generation by guessing the next word with beam width: 5\n")
print(predict_sentence("humans are", 5, 1, best_model))
print(predict_sentence("cats are", 5, 1, best_model))
print(predict_sentence("dogs are", 5, 1, best_model))
print(predict_sentence("birds are", 5, 1, best_model))
print(predict_sentence("horses are", 5, 1, best_model))
print("\n\n")

print(f"Testing sentence generation by guessing the next two words with beam width: 10\n")
print(predict_sentence("the tall mountain is", 10, 2, best_model))
print(predict_sentence("todays weather is", 10, 2, best_model))
print(predict_sentence("my old house is", 10, 2, best_model))
print(predict_sentence("the language model is", 10, 2, best_model))
print("\n\n")

print(f"Testing sentence generation by guessing the next ten words with beam width: 10\n")
print(predict_sentence("what", 10, 10, best_model))
print(predict_sentence("no", 10, 10, best_model))
print(predict_sentence("yes", 10, 10, best_model))
print(predict_sentence("husband", 10, 10, best_model))
print(predict_sentence("wife", 10, 10, best_model))

Testing sentence generation by guessing the next word with beam width: 5

humans are gentlemen
cats are gentlemen
dogs are edge
birds are torn
horses are enemy



Testing sentence generation by guessing the next two words with beam width: 10

the tall mountain is forty attack
todays weather is sown south
my old house is forty mother
the language model is locked strength



Testing sentence generation by guessing the next ten words with beam width: 10

what scarcely t do . exclaimed exclaimed exclaimed ye t ye
no slight influence occurred thank horror required importance required what o
yes eastward bring farther late sown sown south south south south
husband t re apart . sown sown sown sown ft pieces
wife addressed sorry moved kingdom other sorry sorry sorry sorry seat
