Installation

In [1]:
!pip install torch==2.0.1 torchtext==0.15.2

Collecting torch==2.0.1
  Downloading torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl.metadata (24 kB)
Collecting torchtext==0.15.2
  Downloading torchtext-0.15.2-cp310-cp310-manylinux1_x86_64.whl.metadata (7.4 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.7.99 (from torch==2.0.1)
  Downloading nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu11==11.7.99 (from torch==2.0.1)
  Downloading nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cuda-cupti-cu11==11.7.101 (from torch==2.0.1)
  Downloading nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu11==8.5.0.96 (from torch==2.0.1)
  Downloading nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu11==11.10.3.66 (from torch==2.0.1)
  Downloading nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Co

In [2]:
!pip install datasets



Import statements

In [3]:
import torchtext
import string
import nltk
import re
import html
import random
import subprocess
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from datasets import load_dataset
from tqdm import tqdm
from collections import defaultdict
import zipfile
import os
import math
from random import shuffle

In [4]:
def split_dataset(file_path, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
    sentences = []
    
    with open(file_path, 'r') as f:
        para = ""
        for line in tqdm(f, desc="Splitting dataset"):
            if line.strip():
                para += line.strip() + " "
            else:
                if para:
                    sentences.extend(sent_tokenize(para))
                    para = ""
        if para:
            sentences.extend(sent_tokenize(para))

    shuffle(sentences)

    total_sentences = len(sentences)
    train_size = int(total_sentences * train_ratio)
    val_size = int(total_sentences * val_ratio)
    
    train_sentences = sentences[:train_size]
    val_sentences = sentences[train_size:train_size + val_size]
    test_sentences = sentences[train_size + val_size:]

    return train_sentences, val_sentences, test_sentences

In [5]:
def save_datasets(train_sentences, val_sentences, test_sentences):
    with open('train.txt', 'w') as f:
        f.writelines([s + '\n' for s in train_sentences])
    with open('dev.txt', 'w') as f:
        f.writelines([s + '\n' for s in val_sentences])
    with open('test.txt', 'w') as f:
        f.writelines([s + '\n' for s in test_sentences])

In [6]:
train_sentences, val_sentences, test_sentences = split_dataset('/kaggle/input/dataset/Auguste_Maquet.txt')
save_datasets(train_sentences, val_sentences, test_sentences)

Splitting dataset: 128612it [00:02, 46237.92it/s]


In [7]:
def get_embeddings(emb_file='glove.6B.300d.txt'):
    unk_emb = torch.zeros(300)  # Placeholder embedding for unknown words
    embeddings = defaultdict(lambda: unk_emb)

    with open(emb_file, 'r', encoding='ISO-8859-1') as f:
        for line in tqdm(f, desc="Reading embeddings"):
            try:
                split = line.strip().split()
                word = split[0]
                vector = torch.tensor([float(x) for x in split[1:]])  # Corrected line
                embeddings[word] = vector
            except ValueError as e:
                continue

    return embeddings

In [21]:
embeddings = get_embeddings('/kaggle/input/glove/pytorch/default/1/glove.6B.300d.txt')

Reading embeddings: 400000it [01:13, 5433.58it/s]


In [22]:
class TextData(Dataset):
    def __init__(self, file_path='train.txt', pretrained_emb_dict=embeddings,
                 frequency_cutoff=1, context_size=5, vocab=None):
        self.file_path = file_path
        self.frequency_cutoff = frequency_cutoff
        self.context_size = context_size

        self.contexts = []
        self.words = []

        self.frequency_dictionary = defaultdict(lambda: 0)
        self.vocab = vocab if vocab else []

        self.words2indices = {}
        self.embeddings = pretrained_emb_dict

        with open(self.file_path, 'r') as f:
            for line in tqdm(f, desc="Obtaining vocabulary and freq counts"):
                words = [word.lower() for word in word_tokenize(line)]
                if not vocab:
                    self.vocab += words
                for word in words:
                    self.frequency_dictionary[word] += 1

            if not vocab:
                self.vocab = list(set(self.vocab))
                self.vocab = [word for word in self.vocab if self.frequency_dictionary[word] > self.frequency_cutoff]
                self.vocab.append('<unk>')
            self.words2indices = {w: i for i, w in enumerate(self.vocab)}

        embeddings_list = []
        for word in self.vocab:
            embeddings_list.append(self.embeddings[word])
        embeddings_list.append(self.embeddings['<unk>'])
        self.embeddings = torch.stack(embeddings_list)

        with open(self.file_path, 'r') as f:
            for line in tqdm(f, desc="Creating dataset"):
                words = [word.lower() for word in word_tokenize(line)]
                indices = [self.words2indices[word] if word in self.vocab else (len(self.vocab) - 1)
                           for word in words]
                embeddings = [self.embeddings[i] for i in indices]

                for i in range(len(embeddings) - self.context_size):
                    self.contexts.append(torch.stack(embeddings[i:i + self.context_size]))
                    self.words.append(indices[i + self.context_size])

        self.contexts = torch.stack(self.contexts)
        self.words = torch.tensor(self.words)

    def __getitem__(self, idx):
        return (self.contexts[idx], self.words[idx])

    def __len__(self):
        return len(self.contexts)

In [23]:
train_ds = TextData()
with open('vocab.txt', 'w') as f:
    for word in train_ds.vocab:
        f.write(word + '\n')

test_ds = TextData('test.txt', vocab=train_ds.vocab)
dev_ds = TextData('dev.txt', vocab=train_ds.vocab)

Obtaining vocabulary and freq counts: 39555it [00:09, 4059.13it/s]
Creating dataset: 39555it [02:02, 322.10it/s]
Obtaining vocabulary and freq counts: 5652it [00:01, 4002.83it/s]
Creating dataset: 5652it [00:17, 321.76it/s]
Obtaining vocabulary and freq counts: 11301it [00:02, 3972.54it/s]
Creating dataset: 11301it [00:36, 312.03it/s]


In [35]:
class LSTM_LanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=300, hidden_dim=300, num_layers=2, padding_idx=0):
        super(LSTM_LanguageModel, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.padding_idx = padding_idx
        
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, batch_ctx):
        batch_size = batch_ctx.size(0)
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(batch_ctx.device)  # Hidden state
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(batch_ctx.device)  # Cell state

        lstm_out, (hn, cn) = self.lstm(batch_ctx, (h0, c0))  # lstm_out shape: (batch_size, context_size, hidden_dim)

        last_hidden_state = lstm_out[:, -1, :]  # Last time step
        logits = self.fc(last_hidden_state)

        return logits

    def train_epoch(self, dl, optimiser, loss_fn):
        super().train()  # Ensure model is in training mode
        for batch in tqdm(dl):
            optimiser.zero_grad()
            contexts, words = batch

            # Forward pass
            logits = self.forward(contexts)

            # Compute loss, ignoring padding tokens
            loss = loss_fn(logits, words)
            loss.backward()

            optimiser.step()

    def train(self, num_epochs, lr=0.1):
        optimiser = torch.optim.SGD(self.parameters(), lr=lr)
        # Use ignore_index to skip padding tokens in the loss computation
        loss_fn = nn.CrossEntropyLoss(ignore_index=self.padding_idx)
        train_dl = DataLoader(train_ds, batch_size=128)
        dev_dl = DataLoader(dev_ds, batch_size=128)

        for epoch in range(num_epochs):
            print("Epoch:", epoch + 1)
            self.train_epoch(train_dl, optimiser, loss_fn)
            train_loss = self.get_loss(train_dl, loss_fn)
            print("Loss on train set:", train_loss)
            val_loss = self.get_loss(dev_dl, loss_fn)
            print("Loss on validation set:", val_loss)
            train_perp = self.get_perp(train_dl)
            print("Perplexity on train set:", train_perp)
            val_perp = self.get_perp(dev_dl)
            print("Perplexity on validation set:", val_perp)

        return train_perp, val_perp

    def get_loss(self, dl, loss_fn):
        total_loss = 0
        total_samples = 0

        for batch in tqdm(dl):
            contexts, words = batch
            pred = self.forward(contexts)
            loss = loss_fn(pred, words)
            total_loss += loss.item() * len(words)
            total_samples += (words != self.padding_idx).sum().item()  # Exclude padding tokens

        avg_loss = total_loss / total_samples
        return avg_loss

    def get_perp(self, dl, filename='perplexity_output.txt'):
        total_loss = 0
        total_samples = 0
        loss_fn = nn.CrossEntropyLoss(reduction='sum')
        self.eval()

        sentence_perplexities = []
        with open(filename, 'w') as f, torch.no_grad():
            for batch in tqdm(dl):
                contexts, words = batch

                # Forward pass
                pred = self.forward(contexts)

                # Calculate loss for the batch
                loss = loss_fn(pred, words)
                
                # Calculate perplexity for the batch
                perplexity = torch.exp(loss / len(words))

                # Format the contexts (as a sentence) and their corresponding perplexity
                sentence = ' '.join([train_ds.vocab[idx] for idx in words.tolist()])
                f.write(f"{sentence}\t{perplexity.item()}\n")
                
                total_loss += loss.item()
                total_samples += len(words)

                sentence_perplexities.append(perplexity.item())

            avg_loss = total_loss / total_samples
            avg_perplexity = math.exp(avg_loss)

            f.write(f"Average perplexity: {avg_perplexity}\n")
            print(f"Average perplexity: {avg_perplexity}")

        return avg_perplexity


In [12]:
lstm_lm = LSTM_LanguageModel(len(train_ds.vocab))
lstm_lm.train(10)

torch.save(lstm_lm, '10epochs_lstm.pth')

test_dl = DataLoader(test_ds, batch_size=128)
perp = lstm_lm.get_perp(test_dl)
print(perp)

Epoch: 1


100%|██████████| 5012/5012 [04:46<00:00, 17.48it/s]
100%|██████████| 5012/5012 [01:37<00:00, 51.48it/s]


Loss on train set: 5.747939432516416


100%|██████████| 1452/1452 [00:28<00:00, 51.01it/s]


Loss on validation set: 5.69088221584073


100%|██████████| 5012/5012 [01:38<00:00, 51.02it/s]


Perplexity on train set: 313.54391578664263


100%|██████████| 1452/1452 [00:28<00:00, 51.23it/s]


Perplexity on validation set: 296.1547778613643
Epoch: 2


100%|██████████| 5012/5012 [04:51<00:00, 17.19it/s]
100%|██████████| 5012/5012 [01:37<00:00, 51.22it/s]


Loss on train set: 5.328265357257279


100%|██████████| 1452/1452 [00:28<00:00, 50.98it/s]


Loss on validation set: 5.276204497284666


100%|██████████| 5012/5012 [01:38<00:00, 50.84it/s]


Perplexity on train set: 206.0801884275392


100%|██████████| 1452/1452 [00:28<00:00, 50.80it/s]


Perplexity on validation set: 195.6259656054754
Epoch: 3


100%|██████████| 5012/5012 [04:46<00:00, 17.49it/s]
100%|██████████| 5012/5012 [01:38<00:00, 50.72it/s]


Loss on train set: 5.113164464993


100%|██████████| 1452/1452 [00:28<00:00, 50.30it/s]


Loss on validation set: 5.068552115595233


100%|██████████| 5012/5012 [01:39<00:00, 50.22it/s]


Perplexity on train set: 166.19544328587557


100%|██████████| 1452/1452 [00:29<00:00, 49.80it/s]


Perplexity on validation set: 158.94402808068557
Epoch: 4


100%|██████████| 5012/5012 [04:47<00:00, 17.42it/s]
100%|██████████| 5012/5012 [01:38<00:00, 50.69it/s]


Loss on train set: 4.982197419356609


100%|██████████| 1452/1452 [00:28<00:00, 50.58it/s]


Loss on validation set: 4.945078656192627


100%|██████████| 5012/5012 [01:38<00:00, 50.67it/s]


Perplexity on train set: 145.79440135571838


100%|██████████| 1452/1452 [00:28<00:00, 50.74it/s]


Perplexity on validation set: 140.48190018829288
Epoch: 5


100%|██████████| 5012/5012 [04:47<00:00, 17.42it/s]
100%|██████████| 5012/5012 [01:39<00:00, 50.26it/s]


Loss on train set: 4.885199653671509


100%|██████████| 1452/1452 [00:28<00:00, 50.11it/s]


Loss on validation set: 4.8569092500773685


100%|██████████| 5012/5012 [01:40<00:00, 49.63it/s]


Perplexity on train set: 132.3168802482339


100%|██████████| 1452/1452 [00:28<00:00, 50.39it/s]


Perplexity on validation set: 128.62603619728955
Epoch: 6


100%|██████████| 5012/5012 [04:50<00:00, 17.28it/s]
100%|██████████| 5012/5012 [01:38<00:00, 50.79it/s]


Loss on train set: 4.80463574523882


100%|██████████| 1452/1452 [00:28<00:00, 50.21it/s]


Loss on validation set: 4.785755237756645


100%|██████████| 5012/5012 [01:40<00:00, 49.98it/s]


Perplexity on train set: 122.07501651566169


100%|██████████| 1452/1452 [00:29<00:00, 49.72it/s]


Perplexity on validation set: 119.79180021925261
Epoch: 7


100%|██████████| 5012/5012 [04:54<00:00, 17.00it/s]
100%|██████████| 5012/5012 [01:40<00:00, 50.08it/s]


Loss on train set: 4.737497588246888


100%|██████████| 1452/1452 [00:28<00:00, 50.42it/s]


Loss on validation set: 4.728671559450408


100%|██████████| 5012/5012 [01:40<00:00, 49.83it/s]


Perplexity on train set: 114.14819818741476


100%|██████████| 1452/1452 [00:30<00:00, 48.13it/s]


Perplexity on validation set: 113.14515585470903
Epoch: 8


100%|██████████| 5012/5012 [04:51<00:00, 17.21it/s]
100%|██████████| 5012/5012 [01:39<00:00, 50.35it/s]


Loss on train set: 4.680494013938674


100%|██████████| 1452/1452 [00:28<00:00, 50.41it/s]


Loss on validation set: 4.682366941078469


100%|██████████| 5012/5012 [01:38<00:00, 50.70it/s]


Perplexity on train set: 107.82332564221923


100%|██████████| 1452/1452 [00:28<00:00, 50.09it/s]


Perplexity on validation set: 108.02546010759241
Epoch: 9


100%|██████████| 5012/5012 [04:46<00:00, 17.49it/s]
100%|██████████| 5012/5012 [01:38<00:00, 50.87it/s]


Loss on train set: 4.631396518729677


100%|██████████| 1452/1452 [00:28<00:00, 50.26it/s]


Loss on validation set: 4.644315670896575


100%|██████████| 5012/5012 [01:39<00:00, 50.56it/s]


Perplexity on train set: 102.65732693240271


100%|██████████| 1452/1452 [00:28<00:00, 50.51it/s]


Perplexity on validation set: 103.99217655681672
Epoch: 10


100%|██████████| 5012/5012 [04:45<00:00, 17.54it/s]
100%|██████████| 5012/5012 [01:37<00:00, 51.33it/s]


Loss on train set: 4.588082728900541


100%|██████████| 1452/1452 [00:28<00:00, 51.47it/s]


Loss on validation set: 4.612161095689533


100%|██████████| 5012/5012 [01:37<00:00, 51.42it/s]


Perplexity on train set: 98.30577055114044


100%|██████████| 1452/1452 [00:27<00:00, 52.31it/s]


Perplexity on validation set: 100.70154031545744


100%|██████████| 699/699 [00:13<00:00, 50.76it/s]

99.41911908367919





In [27]:
test_dl = DataLoader(test_ds, batch_size=128)

In [38]:
def experiment(train_ds, dev_ds, test_ds, vocab_size):
    train_dl = DataLoader(train_ds, batch_size=128)
    dev_dl = DataLoader(dev_ds, batch_size=128)
    test_dl = DataLoader(test_ds, batch_size=128)

    hyperparams = [
        {'lr': 0.01, 'hidden_dim': 300, 'num_layers': 2, 'optim': 'SGD'},
        {'lr': 0.001, 'hidden_dim': 400, 'num_layers': 2, 'optim': 'Adam'},
        {'lr': 0.01, 'hidden_dim': 500, 'num_layers': 3, 'optim': 'SGD'}
    ]

    train_perplexities, dev_perplexities, test_perplexities = [], [], []

    for params in hyperparams:
        model = LSTM_LanguageModel(vocab_size, hidden_dim=params['hidden_dim'],
                                   num_layers=params['num_layers'])
        
        print(f"Training with hyperparams: {params}")
        train_perp, val_perp = model.train(num_epochs=10, lr=params['lr'])
        
        train_perplexities.append(train_perp)
        dev_perplexities.append(val_perp)
        
        test_perp = model.get_perp(test_dl)
        test_perplexities.append(test_perp)
        print(f"Test perplexity: {test_perp:.4f}")
        print("==========================")

    return hyperparams, train_perplexities, dev_perplexities, test_perplexities


In [39]:
vocab_size = len(train_ds.vocab)
hyperparams, train_perplexities, dev_perplexities, test_perplexities = experiment(train_ds, dev_ds, test_ds, vocab_size)

Training with hyperparams: {'lr': 0.01, 'hidden_dim': 300, 'num_layers': 2, 'optim': 'SGD'}
Epoch: 1


100%|██████████| 5009/5009 [04:37<00:00, 18.02it/s]
100%|██████████| 5009/5009 [01:32<00:00, 54.36it/s]


Loss on train set: 6.3540164599362505


100%|██████████| 1437/1437 [00:26<00:00, 53.47it/s]


Loss on validation set: 6.306828181843887


100%|██████████| 5009/5009 [01:36<00:00, 51.97it/s]


Perplexity on train set: 574.7967269437164


100%|██████████| 1437/1437 [00:26<00:00, 54.69it/s]


Perplexity on validation set: 548.3030701781943
Epoch: 2


100%|██████████| 5009/5009 [04:35<00:00, 18.21it/s]
100%|██████████| 5009/5009 [01:33<00:00, 53.73it/s]


Loss on train set: 6.201055275868876


100%|██████████| 1437/1437 [00:26<00:00, 53.82it/s]


Loss on validation set: 6.152244640936316


100%|██████████| 5009/5009 [01:34<00:00, 53.10it/s]


Perplexity on train set: 493.26930172669944


100%|██████████| 1437/1437 [00:27<00:00, 53.14it/s]


Perplexity on validation set: 469.77067069620705
Epoch: 3


100%|██████████| 5009/5009 [04:35<00:00, 18.17it/s]
100%|██████████| 5009/5009 [01:33<00:00, 53.38it/s]


Loss on train set: 6.157831512032196


100%|██████████| 1437/1437 [00:26<00:00, 53.74it/s]


Loss on validation set: 6.108448728524875


100%|██████████| 5009/5009 [01:30<00:00, 55.50it/s]


Perplexity on train set: 472.40256405668475


100%|██████████| 1437/1437 [00:26<00:00, 54.91it/s]


Perplexity on validation set: 449.6406591418234
Epoch: 4


100%|██████████| 5009/5009 [04:32<00:00, 18.40it/s]
100%|██████████| 5009/5009 [01:29<00:00, 55.87it/s]


Loss on train set: 6.12933931142993


100%|██████████| 1437/1437 [00:25<00:00, 55.51it/s]


Loss on validation set: 6.0792917560087405


100%|██████████| 5009/5009 [01:29<00:00, 55.70it/s]


Perplexity on train set: 459.1327167119147


100%|██████████| 1437/1437 [00:25<00:00, 55.57it/s]


Perplexity on validation set: 436.7197810003642
Epoch: 5


100%|██████████| 5009/5009 [04:24<00:00, 18.92it/s]
100%|██████████| 5009/5009 [01:32<00:00, 54.20it/s]


Loss on train set: 6.087814528303627


100%|██████████| 1437/1437 [00:27<00:00, 52.74it/s]


Loss on validation set: 6.037240589737055


100%|██████████| 5009/5009 [01:35<00:00, 52.30it/s]


Perplexity on train set: 440.4577505557706


100%|██████████| 1437/1437 [00:28<00:00, 51.23it/s]


Perplexity on validation set: 418.735974872696
Epoch: 6


100%|██████████| 5009/5009 [04:36<00:00, 18.12it/s]
100%|██████████| 5009/5009 [01:35<00:00, 52.36it/s]


Loss on train set: 6.014528247838417


100%|██████████| 1437/1437 [00:27<00:00, 52.13it/s]


Loss on validation set: 5.963981943245675


100%|██████████| 5009/5009 [01:35<00:00, 52.30it/s]


Perplexity on train set: 409.3326897797428


100%|██████████| 1437/1437 [00:27<00:00, 52.20it/s]


Perplexity on validation set: 389.15664278503954
Epoch: 7


100%|██████████| 5009/5009 [04:43<00:00, 17.64it/s]
100%|██████████| 5009/5009 [01:37<00:00, 51.46it/s]


Loss on train set: 5.93778224093169


100%|██████████| 1437/1437 [00:28<00:00, 51.32it/s]


Loss on validation set: 5.886528737381191


100%|██████████| 5009/5009 [01:38<00:00, 50.96it/s]


Perplexity on train set: 379.0932590589364


100%|██████████| 1437/1437 [00:27<00:00, 52.02it/s]


Perplexity on validation set: 360.1529266072926
Epoch: 8


100%|██████████| 5009/5009 [04:46<00:00, 17.49it/s]
100%|██████████| 5009/5009 [01:36<00:00, 51.70it/s]


Loss on train set: 5.854466985510365


100%|██████████| 1437/1437 [00:28<00:00, 51.09it/s]


Loss on validation set: 5.802483651614813


100%|██████████| 5009/5009 [01:38<00:00, 50.68it/s]


Perplexity on train set: 348.7889409395917


100%|██████████| 1437/1437 [00:28<00:00, 49.94it/s]


Perplexity on validation set: 331.1209285195081
Epoch: 9


100%|██████████| 5009/5009 [04:47<00:00, 17.44it/s]
100%|██████████| 5009/5009 [01:36<00:00, 52.09it/s]


Loss on train set: 5.7694255164288375


100%|██████████| 1437/1437 [00:27<00:00, 51.36it/s]


Loss on validation set: 5.717125480677383


100%|██████████| 5009/5009 [01:38<00:00, 50.71it/s]


Perplexity on train set: 320.35364186963955


100%|██████████| 1437/1437 [00:28<00:00, 51.01it/s]


Perplexity on validation set: 304.0297263517197
Epoch: 10


100%|██████████| 5009/5009 [04:49<00:00, 17.28it/s]
100%|██████████| 5009/5009 [01:39<00:00, 50.14it/s]


Loss on train set: 5.7050656540903555


100%|██████████| 1437/1437 [00:28<00:00, 49.80it/s]


Loss on validation set: 5.652790435707794


100%|██████████| 5009/5009 [01:39<00:00, 50.15it/s]


Perplexity on train set: 300.38520091834016


100%|██████████| 1437/1437 [00:28<00:00, 50.31it/s]


Perplexity on validation set: 285.08587073401924


100%|██████████| 717/717 [00:13<00:00, 53.02it/s]


Test perplexity: 289.0636
Training with hyperparams: {'lr': 0.001, 'hidden_dim': 400, 'num_layers': 2, 'optim': 'Adam'}
Epoch: 1


100%|██████████| 5009/5009 [06:37<00:00, 12.59it/s]
100%|██████████| 5009/5009 [02:16<00:00, 36.77it/s]


Loss on train set: 9.319851136009891


100%|██████████| 1437/1437 [00:37<00:00, 38.44it/s]


Loss on validation set: 9.31948883237611


100%|██████████| 5009/5009 [02:13<00:00, 37.42it/s]


Perplexity on train set: 11157.32080648237


100%|██████████| 1437/1437 [00:38<00:00, 36.97it/s]


Perplexity on validation set: 11153.27920079933
Epoch: 2


100%|██████████| 5009/5009 [06:21<00:00, 13.14it/s]
100%|██████████| 5009/5009 [02:14<00:00, 37.18it/s]


Loss on train set: 9.039482193311608


100%|██████████| 1437/1437 [00:37<00:00, 37.83it/s]


Loss on validation set: 9.03722899414485


100%|██████████| 5009/5009 [02:12<00:00, 37.83it/s]


Perplexity on train set: 8429.411120290535


100%|██████████| 1437/1437 [00:36<00:00, 39.60it/s]


Perplexity on validation set: 8410.439359782013
Epoch: 3


100%|██████████| 5009/5009 [06:19<00:00, 13.20it/s]
100%|██████████| 5009/5009 [02:15<00:00, 37.10it/s]


Loss on train set: 7.761687332494042


100%|██████████| 1437/1437 [00:39<00:00, 36.54it/s]


Loss on validation set: 7.745292163062528


100%|██████████| 5009/5009 [02:16<00:00, 36.76it/s]


Perplexity on train set: 2348.8645789684147


100%|██████████| 1437/1437 [00:38<00:00, 37.24it/s]


Perplexity on validation set: 2310.6685172691136
Epoch: 4


100%|██████████| 5009/5009 [06:35<00:00, 12.66it/s]
100%|██████████| 5009/5009 [02:16<00:00, 36.61it/s]


Loss on train set: 7.039148729672666


100%|██████████| 1437/1437 [00:39<00:00, 36.26it/s]


Loss on validation set: 6.996763251996599


100%|██████████| 5009/5009 [02:18<00:00, 36.06it/s]


Perplexity on train set: 1140.4163906697722


100%|██████████| 1437/1437 [00:40<00:00, 35.74it/s]


Perplexity on validation set: 1093.0893715089544
Epoch: 5


100%|██████████| 5009/5009 [06:27<00:00, 12.92it/s]
100%|██████████| 5009/5009 [02:15<00:00, 37.10it/s]


Loss on train set: 6.712701361051769


100%|██████████| 1437/1437 [00:38<00:00, 37.29it/s]


Loss on validation set: 6.665415054755244


100%|██████████| 5009/5009 [02:14<00:00, 37.16it/s]


Perplexity on train set: 822.7902937102484


100%|██████████| 1437/1437 [00:37<00:00, 37.90it/s]


Perplexity on validation set: 784.7891278517193
Epoch: 6


100%|██████████| 5009/5009 [06:24<00:00, 13.04it/s]
100%|██████████| 5009/5009 [02:12<00:00, 37.76it/s]


Loss on train set: 6.56958583691224


100%|██████████| 1437/1437 [00:38<00:00, 37.60it/s]


Loss on validation set: 6.522708934409646


100%|██████████| 5009/5009 [02:14<00:00, 37.35it/s]


Perplexity on train set: 713.0744528500194


100%|██████████| 1437/1437 [00:38<00:00, 37.64it/s]


Perplexity on validation set: 680.4191017414829
Epoch: 7


100%|██████████| 5009/5009 [06:25<00:00, 12.98it/s]
100%|██████████| 5009/5009 [02:14<00:00, 37.22it/s]


Loss on train set: 6.479155372473592


100%|██████████| 1437/1437 [00:38<00:00, 37.72it/s]


Loss on validation set: 6.4321083494501545


100%|██████████| 5009/5009 [02:16<00:00, 36.66it/s]


Perplexity on train set: 651.4205061546937


100%|██████████| 1437/1437 [00:38<00:00, 37.35it/s]


Perplexity on validation set: 621.4828707634249
Epoch: 8


100%|██████████| 5009/5009 [06:36<00:00, 12.65it/s]
100%|██████████| 5009/5009 [02:16<00:00, 36.59it/s]


Loss on train set: 6.4132475092227175


100%|██████████| 1437/1437 [00:37<00:00, 37.91it/s]


Loss on validation set: 6.36604584621064


100%|██████████| 5009/5009 [02:13<00:00, 37.50it/s]


Perplexity on train set: 609.87103038943


100%|██████████| 1437/1437 [00:38<00:00, 37.43it/s]


Perplexity on validation set: 581.7529343708285
Epoch: 9


100%|██████████| 5009/5009 [06:31<00:00, 12.79it/s]
100%|██████████| 5009/5009 [02:15<00:00, 36.93it/s]


Loss on train set: 6.363260813948391


100%|██████████| 1437/1437 [00:37<00:00, 38.00it/s]


Loss on validation set: 6.3158497136498255


100%|██████████| 5009/5009 [02:12<00:00, 37.73it/s]


Perplexity on train set: 580.1349877425141


100%|██████████| 1437/1437 [00:37<00:00, 37.91it/s]


Perplexity on validation set: 553.2719836698568
Epoch: 10


100%|██████████| 5009/5009 [06:23<00:00, 13.07it/s]
100%|██████████| 5009/5009 [02:09<00:00, 38.60it/s]


Loss on train set: 6.325324059687017


100%|██████████| 1437/1437 [00:36<00:00, 39.36it/s]


Loss on validation set: 6.277608202430008


100%|██████████| 5009/5009 [02:08<00:00, 39.13it/s]


Perplexity on train set: 558.5387842443762


100%|██████████| 1437/1437 [00:37<00:00, 38.62it/s]


Perplexity on validation set: 532.5134750012134


100%|██████████| 717/717 [00:18<00:00, 39.41it/s]


Test perplexity: 537.2356
Training with hyperparams: {'lr': 0.01, 'hidden_dim': 500, 'num_layers': 3, 'optim': 'SGD'}
Epoch: 1


100%|██████████| 5009/5009 [12:10<00:00,  6.86it/s]
100%|██████████| 5009/5009 [03:59<00:00, 20.89it/s]


Loss on train set: 6.3593084640194


100%|██████████| 1437/1437 [01:08<00:00, 20.95it/s]


Loss on validation set: 6.3123027113068835


100%|██████████| 5009/5009 [03:58<00:00, 20.98it/s]


Perplexity on train set: 577.8466164658313


100%|██████████| 1437/1437 [01:08<00:00, 21.04it/s]


Perplexity on validation set: 551.3130029559511
Epoch: 2


100%|██████████| 5009/5009 [12:33<00:00,  6.65it/s]
100%|██████████| 5009/5009 [04:07<00:00, 20.27it/s]


Loss on train set: 6.19059168739768


100%|██████████| 1437/1437 [01:09<00:00, 20.57it/s]


Loss on validation set: 6.14129875589718


100%|██████████| 5009/5009 [04:04<00:00, 20.51it/s]


Perplexity on train set: 488.1348440203312


100%|██████████| 1437/1437 [01:11<00:00, 20.21it/s]


Perplexity on validation set: 464.656654715623
Epoch: 3


100%|██████████| 5009/5009 [12:21<00:00,  6.75it/s]
100%|██████████| 5009/5009 [04:01<00:00, 20.72it/s]


Loss on train set: 6.151679601956162


100%|██████████| 1437/1437 [01:09<00:00, 20.67it/s]


Loss on validation set: 6.101709119257059


100%|██████████| 5009/5009 [04:01<00:00, 20.71it/s]


Perplexity on train set: 469.5053069330258


100%|██████████| 1437/1437 [01:08<00:00, 20.99it/s]


Perplexity on validation set: 446.62044574945537
Epoch: 4


100%|██████████| 5009/5009 [12:15<00:00,  6.81it/s]
100%|██████████| 5009/5009 [03:58<00:00, 21.00it/s]


Loss on train set: 6.135954491012101


100%|██████████| 1437/1437 [01:10<00:00, 20.47it/s]


Loss on validation set: 6.085518631935898


100%|██████████| 5009/5009 [04:09<00:00, 20.11it/s]


Perplexity on train set: 462.1800302354338


100%|██████████| 1437/1437 [01:09<00:00, 20.61it/s]


Perplexity on validation set: 439.44766517550084
Epoch: 5


100%|██████████| 5009/5009 [12:34<00:00,  6.64it/s]
100%|██████████| 5009/5009 [04:02<00:00, 20.64it/s]


Loss on train set: 6.127915213369135


100%|██████████| 1437/1437 [01:08<00:00, 20.95it/s]


Loss on validation set: 6.077179918638881


100%|██████████| 5009/5009 [03:56<00:00, 21.19it/s]


Perplexity on train set: 458.47933205273574


100%|██████████| 1437/1437 [01:07<00:00, 21.29it/s]


Perplexity on validation set: 435.79847301541497
Epoch: 6


100%|██████████| 5009/5009 [12:11<00:00,  6.85it/s]
100%|██████████| 5009/5009 [04:00<00:00, 20.79it/s]


Loss on train set: 6.123072539876984


100%|██████████| 1437/1437 [01:09<00:00, 20.68it/s]


Loss on validation set: 6.072145927230588


100%|██████████| 5009/5009 [04:05<00:00, 20.42it/s]


Perplexity on train set: 456.26443368805735


100%|██████████| 1437/1437 [01:10<00:00, 20.45it/s]


Perplexity on validation set: 433.6101797922894
Epoch: 7


100%|██████████| 5009/5009 [12:22<00:00,  6.75it/s]
100%|██████████| 5009/5009 [03:59<00:00, 20.89it/s]


Loss on train set: 6.119444999933063


100%|██████████| 1437/1437 [01:08<00:00, 20.86it/s]


Loss on validation set: 6.068382822261607


100%|██████████| 5009/5009 [04:00<00:00, 20.85it/s]


Perplexity on train set: 454.6123146055378


100%|██████████| 1437/1437 [01:09<00:00, 20.64it/s]


Perplexity on validation set: 431.98152549060165
Epoch: 8


100%|██████████| 5009/5009 [12:39<00:00,  6.60it/s]
100%|██████████| 5009/5009 [04:08<00:00, 20.19it/s]


Loss on train set: 6.115317517471806


100%|██████████| 1437/1437 [01:12<00:00, 19.78it/s]


Loss on validation set: 6.064094317892639


100%|██████████| 5009/5009 [04:02<00:00, 20.64it/s]


Perplexity on train set: 452.73977734108587


100%|██████████| 1437/1437 [01:07<00:00, 21.31it/s]


Perplexity on validation set: 430.1329375032078
Epoch: 9


100%|██████████| 5009/5009 [12:08<00:00,  6.88it/s]
100%|██████████| 5009/5009 [04:06<00:00, 20.32it/s]


Loss on train set: 6.109190127126965


100%|██████████| 1437/1437 [01:09<00:00, 20.69it/s]


Loss on validation set: 6.057757713867216


100%|██████████| 5009/5009 [04:04<00:00, 20.45it/s]


Perplexity on train set: 449.97414570591434


100%|██████████| 1437/1437 [01:09<00:00, 20.81it/s]


Perplexity on validation set: 427.4159726562338
Epoch: 10


100%|██████████| 5009/5009 [12:16<00:00,  6.81it/s]
100%|██████████| 5009/5009 [04:01<00:00, 20.72it/s]


Loss on train set: 6.100343896105659


100%|██████████| 1437/1437 [01:08<00:00, 21.01it/s]


Loss on validation set: 6.048758329737336


100%|██████████| 5009/5009 [03:56<00:00, 21.15it/s]


Perplexity on train set: 446.0111252009286


100%|██████████| 1437/1437 [01:08<00:00, 21.01it/s]


Perplexity on validation set: 423.58674830915595


100%|██████████| 717/717 [00:34<00:00, 21.08it/s]

Test perplexity: 427.2778





In [None]:
model = LSTM_LanguageModel(vocab_size=len(train_ds.vocab))

train_dl = DataLoader(train_ds, batch_size=128)
dev_dl = DataLoader(dev_ds, batch_size=128)
test_dl = DataLoader(test_ds, batch_size=128)

train_perplexities, val_perplexities = model.train_model(
    num_epochs=10, 
    lr=0.1, 
    train_dl=train_dl, 
    dev_dl=dev_dl
)

model.get_perp(train_dl, filename='train_perplexity.txt')
model.get_perp(test_dl, filename='test_perplexity.txt')
print("Perplexity files generated for both train and test sets.")