Installation

In [1]:
!pip install torch==2.0.1 torchtext==0.15.2

Collecting torch==2.0.1
  Downloading torch-2.0.1-cp310-cp310-manylinux1_x86_64.whl.metadata (24 kB)
Collecting torchtext==0.15.2
  Downloading torchtext-0.15.2-cp310-cp310-manylinux1_x86_64.whl.metadata (7.4 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.7.99 (from torch==2.0.1)
  Downloading nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu11==11.7.99 (from torch==2.0.1)
  Downloading nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cuda-cupti-cu11==11.7.101 (from torch==2.0.1)
  Downloading nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu11==8.5.0.96 (from torch==2.0.1)
  Downloading nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu11==11.10.3.66 (from torch==2.0.1)
  Downloading nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Co

In [2]:
!pip install datasets



Import statements

In [3]:
import torchtext
import string
import nltk
import re
import html
import random
import subprocess
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from datasets import load_dataset
from tqdm import tqdm
from collections import defaultdict
import zipfile
import os
import math
from random import shuffle

In [4]:
def split_dataset(file_path, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
    sentences = []
    
    with open(file_path, 'r') as f:
        para = ""
        for line in tqdm(f, desc="Splitting dataset"):
            if line.strip():
                para += line.strip() + " "
            else:
                if para:
                    sentences.extend(sent_tokenize(para))
                    para = ""
        if para:
            sentences.extend(sent_tokenize(para))

    shuffle(sentences)

    total_sentences = len(sentences)
    train_size = int(total_sentences * train_ratio)
    val_size = int(total_sentences * val_ratio)
    
    train_sentences = sentences[:train_size]
    val_sentences = sentences[train_size:train_size + val_size]
    test_sentences = sentences[train_size + val_size:]

    return train_sentences, val_sentences, test_sentences

In [5]:
def save_datasets(train_sentences, val_sentences, test_sentences):
    with open('train.txt', 'w') as f:
        f.writelines([s + '\n' for s in train_sentences])
    with open('dev.txt', 'w') as f:
        f.writelines([s + '\n' for s in val_sentences])
    with open('test.txt', 'w') as f:
        f.writelines([s + '\n' for s in test_sentences])

In [6]:
train_sentences, val_sentences, test_sentences = split_dataset('/kaggle/input/auguste-maquet/Auguste_Maquet.txt')
save_datasets(train_sentences, val_sentences, test_sentences)

Splitting dataset: 128612it [00:02, 46253.79it/s]


In [7]:
def get_embeddings(emb_file='glove.6B.300d.txt'):
    unk_emb = torch.zeros(300)
    embeddings = defaultdict(lambda: unk_emb)

    with open(emb_file, 'r', encoding='ISO-8859-1') as f:
        for line in tqdm(f, desc="Reading embeddings"):
            try:
                split = line.strip().split()
                word = split[0]
                vector = torch.tensor([float(x) for x in split[1:]])
                embeddings[word] = vector
            except ValueError as e:
                continue

    return embeddings

In [11]:
embeddings = get_embeddings('/kaggle/input/glove/pytorch/default/1/glove.6B.300d.txt')

Reading embeddings: 400000it [01:15, 5326.25it/s]


In [12]:
class TextData(Dataset):
    def __init__(self, file_path='train.txt', pretrained_emb_dict=embeddings,
                 frequency_cutoff=1, context_size=5, vocab=None):
        self.file_path = file_path
        self.frequency_cutoff = frequency_cutoff
        self.context_size = context_size

        self.contexts = []
        self.words = []

        self.frequency_dictionary = defaultdict(lambda: 0)
        self.vocab = vocab if vocab else []

        self.words2indices = {}
        self.embeddings = pretrained_emb_dict

        with open(self.file_path, 'r') as f:
            for line in tqdm(f, desc="Obtaining vocabulary and freq counts"):
                words = [word.lower() for word in word_tokenize(line)]
                if not vocab:
                    self.vocab += words
                for word in words:
                    self.frequency_dictionary[word] += 1

            if not vocab:
                self.vocab = list(set(self.vocab))
                self.vocab = [word for word in self.vocab if self.frequency_dictionary[word] > self.frequency_cutoff]
                self.vocab.append('<unk>')
            self.words2indices = {w: i for i, w in enumerate(self.vocab)}

        embeddings_list = []
        for word in self.vocab:
            embeddings_list.append(self.embeddings[word])
        embeddings_list.append(self.embeddings['<unk>'])
        self.embeddings = torch.stack(embeddings_list)

        with open(self.file_path, 'r') as f:
            for line in tqdm(f, desc="Creating dataset"):
                words = [word.lower() for word in word_tokenize(line)]
                indices = [self.words2indices[word] if word in self.vocab else (len(self.vocab) - 1)
                           for word in words]
                embeddings = [self.embeddings[i] for i in indices]

                for i in range(len(embeddings) - self.context_size):
                    self.contexts.append(torch.stack(embeddings[i:i + self.context_size]))
                    self.words.append(indices[i + self.context_size])

        self.contexts = torch.stack(self.contexts)
        self.words = torch.tensor(self.words)

    def __getitem__(self, idx):
        return (self.contexts[idx], self.words[idx])

    def __len__(self):
        return len(self.contexts)

In [13]:
train_ds = TextData()
with open('vocab.txt', 'w') as f:
    for word in train_ds.vocab:
        f.write(word + '\n')

test_ds = TextData('test.txt', vocab=train_ds.vocab)
dev_ds = TextData('dev.txt', vocab=train_ds.vocab)

Obtaining vocabulary and freq counts: 39555it [00:09, 3985.07it/s]
Creating dataset: 39555it [02:14, 294.45it/s]
Obtaining vocabulary and freq counts: 5652it [00:01, 4217.45it/s]
Creating dataset: 5652it [00:18, 298.13it/s]
Obtaining vocabulary and freq counts: 11301it [00:02, 4161.19it/s]
Creating dataset: 11301it [00:37, 297.72it/s]


In [24]:
class TransformerLanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=300, num_heads=4, num_layers=2, ff_dim=512, dropout=0.1, padding_idx=0):
        super(TransformerLanguageModel, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.positional_encoding = self._generate_positional_encoding(embedding_dim, max_len=5000)

        self.decoder_layers = nn.TransformerDecoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout
        )
        self.transformer_decoder = nn.TransformerDecoder(self.decoder_layers, num_layers=num_layers)

        self.fc = nn.Linear(embedding_dim, vocab_size)

    def _generate_positional_encoding(self, embedding_dim, max_len):
        pos_enc = torch.zeros(max_len, embedding_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        pos_enc[:, 1::2] = torch.cos(position * div_term)
        return pos_enc.unsqueeze(0)

    def forward(self, batch_ctx):
        batch_size, seq_len, embedding_dim = batch_ctx.size()

        pos_enc = self.positional_encoding[:, :seq_len, :].to(batch_ctx.device)

        embedded = batch_ctx + pos_enc
        embedded = embedded.permute(1, 0, 2)

        tgt_mask = self.generate_square_subsequent_mask(seq_len).to(embedded.device)

        transformer_output = self.transformer_decoder(embedded, embedded, tgt_mask=tgt_mask)

        last_hidden_state = transformer_output[-1, :, :]
        logits = self.fc(last_hidden_state)
        return logits

    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz)) == 1
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def train_epoch(self, dl, optimiser, loss_fn):
        self.train()
        for batch in tqdm(dl):
            optimiser.zero_grad()
            contexts, words = batch
            logits = self.forward(contexts)
            loss = loss_fn(logits, words)
            loss.backward()
            optimiser.step()

    def train_model(self, num_epochs, lr=0.1):
        optimiser = torch.optim.SGD(self.parameters(), lr=lr)
        loss_fn = nn.CrossEntropyLoss()
        train_dl = DataLoader(train_ds, batch_size=128)
        dev_dl = DataLoader(dev_ds, batch_size=128)

        for epoch in range(num_epochs):
            print(f"Epoch: {epoch + 1}")
            self.train_epoch(train_dl, optimiser, loss_fn)
            train_loss = self.get_loss(train_dl, loss_fn)
            print(f"Loss on train set: {train_loss}")
            val_loss = self.get_loss(dev_dl, loss_fn)
            print(f"Loss on validation set: {val_loss}")
            train_perp = self.get_perp(train_dl, filename='train_perplexity.txt')
            print(f"Perplexity on train set: {train_perp}")
            val_perp = self.get_perp(dev_dl, filename='test_perplexity.txt')
            print(f"Perplexity on validation set: {val_perp}")
            print("==========================")

    def get_loss(self, dl, loss_fn):
        total_loss = 0
        total_samples = 0
        self.eval()

        with torch.no_grad():
            for batch in tqdm(dl):
                contexts, words = batch
                pred = self.forward(contexts)
                loss = loss_fn(pred, words)
                total_loss += loss.item() * len(words)
                total_samples += len(words)

        avg_loss = total_loss / total_samples
        return avg_loss

    def get_perp(self, dl, filename='perplexity_output.txt'):
        total_loss = 0
        total_samples = 0
        loss_fn = nn.CrossEntropyLoss(reduction='sum')
        self.eval()

        sentence_perplexities = []
        with open(filename, 'w') as f, torch.no_grad():
            for batch in tqdm(dl):
                contexts, words = batch

                pred = self.forward(contexts)
                loss = loss_fn(pred, words)
                perplexity = torch.exp(loss / len(words))

                sentence = ' '.join([train_ds.vocab[idx] for idx in words.tolist()])
                f.write(f"{sentence}\t{perplexity.item()}\n")
                
                total_loss += loss.item()
                total_samples += len(words)

                sentence_perplexities.append(perplexity.item())

            avg_loss = total_loss / total_samples
            avg_perplexity = math.exp(avg_loss)

            f.write(f"Average perplexity: {avg_perplexity}\n")
            print(f"Average perplexity: {avg_perplexity}")

        return avg_perplexity

In [12]:
transformer_lm = TransformerLanguageModel(len(train_ds.vocab))
transformer_lm.train_model(num_epochs=10)

torch.save(transformer_lm, '10epochs_transformer.pth')

test_dl = DataLoader(test_ds, batch_size=128)
perp = transformer_lm.get_perp(test_dl)
print(perp)

Epoch: 1


100%|██████████| 5013/5013 [07:14<00:00, 11.55it/s]
100%|██████████| 5013/5013 [02:41<00:00, 31.01it/s]


Loss on train set: 4.906577545409246


100%|██████████| 1432/1432 [00:45<00:00, 31.16it/s]


Loss on validation set: 4.913811779087805


100%|██████████| 5013/5013 [02:47<00:00, 29.99it/s]


Perplexity on train set: 135.17598815194546


100%|██████████| 1432/1432 [00:48<00:00, 29.82it/s]


Perplexity on validation set: 136.1574285422968
Epoch: 2


100%|██████████| 5013/5013 [07:49<00:00, 10.67it/s]
100%|██████████| 5013/5013 [02:42<00:00, 30.88it/s]


Loss on train set: 4.671896824147495


100%|██████████| 1432/1432 [00:45<00:00, 31.13it/s]


Loss on validation set: 4.722732277158658


100%|██████████| 5013/5013 [02:42<00:00, 30.91it/s]


Perplexity on train set: 106.90032135433643


100%|██████████| 1432/1432 [00:46<00:00, 30.91it/s]


Perplexity on validation set: 112.47514649106964
Epoch: 3


100%|██████████| 5013/5013 [07:42<00:00, 10.85it/s]
100%|██████████| 5013/5013 [02:35<00:00, 32.18it/s]


Loss on train set: 4.51844811830664


100%|██████████| 1432/1432 [00:44<00:00, 32.16it/s]


Loss on validation set: 4.606743531652158


100%|██████████| 5013/5013 [02:34<00:00, 32.40it/s]


Perplexity on train set: 91.69319052319733


100%|██████████| 1432/1432 [00:43<00:00, 32.57it/s]


Perplexity on validation set: 100.15745840217254
Epoch: 4


100%|██████████| 5013/5013 [07:32<00:00, 11.07it/s]
100%|██████████| 5013/5013 [02:38<00:00, 31.67it/s]


Loss on train set: 4.4310084936483145


100%|██████████| 1432/1432 [00:40<00:00, 35.20it/s]


Loss on validation set: 4.555966797985258


100%|██████████| 5013/5013 [02:25<00:00, 34.37it/s]


Perplexity on train set: 84.01610390708842


100%|██████████| 1432/1432 [00:43<00:00, 32.96it/s]


Perplexity on validation set: 95.19874871731507
Epoch: 5


100%|██████████| 5013/5013 [07:24<00:00, 11.27it/s]
100%|██████████| 5013/5013 [02:34<00:00, 32.41it/s]


Loss on train set: 4.347354815055359


100%|██████████| 1432/1432 [00:43<00:00, 32.69it/s]


Loss on validation set: 4.510279294564244


100%|██████████| 5013/5013 [02:30<00:00, 33.27it/s]


Perplexity on train set: 77.27378888128897


100%|██████████| 1432/1432 [00:42<00:00, 33.86it/s]


Perplexity on validation set: 90.94721602673972
Epoch: 6


100%|██████████| 5013/5013 [07:19<00:00, 11.40it/s]
100%|██████████| 5013/5013 [02:29<00:00, 33.43it/s]


Loss on train set: 4.287467683179539


100%|██████████| 1432/1432 [00:42<00:00, 33.52it/s]


Loss on validation set: 4.4882939310811825


100%|██████████| 5013/5013 [02:29<00:00, 33.44it/s]


Perplexity on train set: 72.78192804019207


100%|██████████| 1432/1432 [00:42<00:00, 33.44it/s]


Perplexity on validation set: 88.96952817772743
Epoch: 7


100%|██████████| 5013/5013 [07:24<00:00, 11.29it/s]
100%|██████████| 5013/5013 [02:27<00:00, 34.05it/s]


Loss on train set: 4.209724129290699


100%|██████████| 1432/1432 [00:41<00:00, 34.72it/s]


Loss on validation set: 4.450416825831654


100%|██████████| 5013/5013 [02:31<00:00, 33.02it/s]


Perplexity on train set: 67.33796067654029


100%|██████████| 1432/1432 [00:41<00:00, 34.28it/s]


Perplexity on validation set: 85.66264296395414
Epoch: 8


100%|██████████| 5013/5013 [07:17<00:00, 11.45it/s]
100%|██████████| 5013/5013 [02:34<00:00, 32.43it/s]


Loss on train set: 4.164285963914807


100%|██████████| 1432/1432 [00:43<00:00, 32.90it/s]


Loss on validation set: 4.4423787102413685


100%|██████████| 5013/5013 [02:28<00:00, 33.66it/s]


Perplexity on train set: 64.34672015589415


100%|██████████| 1432/1432 [00:42<00:00, 33.91it/s]


Perplexity on validation set: 84.9768367255146
Epoch: 9


100%|██████████| 5013/5013 [07:18<00:00, 11.44it/s]
100%|██████████| 5013/5013 [02:27<00:00, 34.07it/s]


Loss on train set: 4.119091849666926


100%|██████████| 1432/1432 [00:43<00:00, 33.08it/s]


Loss on validation set: 4.439525863883796


100%|██████████| 5013/5013 [02:31<00:00, 33.02it/s]


Perplexity on train set: 61.503362595472794


100%|██████████| 1432/1432 [00:44<00:00, 32.29it/s]


Perplexity on validation set: 84.73475633964333
Epoch: 10


100%|██████████| 5013/5013 [07:42<00:00, 10.85it/s]
100%|██████████| 5013/5013 [02:42<00:00, 30.83it/s]


Loss on train set: 4.087679990099404


100%|██████████| 1432/1432 [00:46<00:00, 30.49it/s]


Loss on validation set: 4.44571960182354


100%|██████████| 5013/5013 [02:40<00:00, 31.15it/s]


Perplexity on train set: 59.601455213039934


100%|██████████| 1432/1432 [00:45<00:00, 31.55it/s]


Perplexity on validation set: 85.26120988946136


100%|██████████| 719/719 [00:23<00:00, 30.87it/s]

83.03964072816288





In [None]:
model = TransformerLanguageModel(vocab_size=len(train_ds.vocab))

train_dl = DataLoader(train_ds, batch_size=128)
dev_dl = DataLoader(dev_ds, batch_size=128)
test_dl = DataLoader(test_ds, batch_size=128)

train_perplexities, val_perplexities = model.train_model(
    num_epochs=10, 
    lr=0.1, 
    train_dl=train_dl, 
    dev_dl=dev_dl
)

model.get_perp(train_dl, filename='train_perplexity.txt')
model.get_perp(test_dl, filename='test_perplexity.txt')
print("Perplexity files generated for both train and test sets.")