In [42]:
!pip install datasets
!pip install bpemb



In [43]:
import nltk
from bpemb import BPEmb
import string
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from torch import nn
import torch

nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

<a href="https://colab.research.google.com/github/Axel0087/NLP2023/blob/main/project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [44]:
dataset = load_dataset("copenlu/answerable_tydiqa")

In [45]:
train_set = dataset["train"]
validation_set = dataset["validation"]

In [46]:
def get_answer_start(row):
  return row["annotations"]["answer_start"][0]

def get_answer(row):
  return row["annotations"]["answer_text"][0]

def get_document(row):
  return row["document_plaintext"]

def get_question(row):
  return row["question_text"]

def oracle(answer, document):
  return answer != "" and answer in document

def get_language(dataset, lang):
  return [row for row in dataset if row['language'] == lang]

In [47]:
train_arabic = get_language(train_set, "arabic")
val_arabic = get_language(validation_set, "arabic")

train_bengali = get_language(train_set, "bengali")
val_bengali = get_language(validation_set, "bengali")

train_indonesian = get_language(train_set, "indonesian")
val_indonesian = get_language(validation_set, "indonesian")

In [48]:
def ratio_string(train, val):
  val_ratio = round(len(val)/len(train)*100)
  train_ratio = 100-val_ratio
  return f"{train_ratio} / {val_ratio}"

def answerable_ratio(ds):
  answerable = round(sum([1 for row in ds if get_answer_start(row) == -1])/len(ds)*100)
  nonansw = 100-answerable
  return f"{answerable} / {nonansw}"

print(f"""
Dataset features:

{train_set.column_names}

Dataset sizes:

(Arabic) Training set:                                          {len(train_arabic)}
(Arabic) Validation set:                                        {len(val_arabic)}
(Arabic) Ratio (Training/Val):                                  {ratio_string(train_arabic, val_arabic)}
(Arabic) Training balance (Answerable / Not answerable):        {answerable_ratio(train_arabic)}
(Arabic) Validation balance (Answerable / Not answerable):      {answerable_ratio(val_arabic)}

(Bengali) Training set:                                         {len(train_bengali)}
(Bengali) Validation set:                                       {len(val_bengali)}
(Bengali) Ratio (Training/Val):                                 {ratio_string(train_bengali, val_bengali)}
(Bengali) Training balance (Answerable / Not answerable):       {answerable_ratio(train_bengali)}
(Bengali) Validation balance (Answerable / Not answerable):     {answerable_ratio(val_bengali)}

(Indonesian) Training set:                                      {len(train_indonesian)}
(Indonesian) Validation set:                                    {len(val_indonesian)}
(Indonesian) Ratio (Training/Val):                              {ratio_string(train_indonesian, val_indonesian)}
(Indonesian) Training balance (Answerable / Not answerable):    {answerable_ratio(train_indonesian)}
(Indonesian) Validation balance (Answerable / Not answerable):  {answerable_ratio(val_indonesian)}
""")


Dataset features:

['question_text', 'document_title', 'language', 'annotations', 'document_plaintext', 'document_url']

Dataset sizes:

(Arabic) Training set:                                          29598
(Arabic) Validation set:                                        1902
(Arabic) Ratio (Training/Val):                                  94 / 6
(Arabic) Training balance (Answerable / Not answerable):        50 / 50
(Arabic) Validation balance (Answerable / Not answerable):      50 / 50

(Bengali) Training set:                                         4779
(Bengali) Validation set:                                       224
(Bengali) Ratio (Training/Val):                                 95 / 5
(Bengali) Training balance (Answerable / Not answerable):       50 / 50
(Bengali) Validation balance (Answerable / Not answerable):     50 / 50

(Indonesian) Training set:                                      11394
(Indonesian) Validation set:                                    1191
(Indonesian) Ra

In [8]:
def bag_of_words(dataset, column):
  bag = {}
  for row in dataset:
    tokens = nltk.word_tokenize(row[column])

    for token in tokens:

      if not token in bag:
        bag[token] = 0

      bag[token] += 1
      #print(bag)
  return sorted(bag.items(), key=lambda item: item[1], reverse=True)

#def sort_bag(bag):
#  return sorted(bag.items(), key=lambda item: item[1], reverse=True)

In [9]:
arabic_doc_bow = bag_of_words(train_arabic, "document_plaintext")
arabic_question_bow = bag_of_words(train_arabic, "question_text")

bengali_doc_bow = bag_of_words(train_bengali, "document_plaintext")
bengali_question_bow = bag_of_words(train_bengali, "question_text")

indonesian_doc_bow = bag_of_words(train_indonesian, "document_plaintext")
indonesian_question_bow = bag_of_words(train_indonesian, "question_text")

In [10]:
print(f"""

Most common words:

(Arabic) Documents: {arabic_doc_bow[0:5]}
(Arabic) Questions: {arabic_question_bow[0:5]}

(Bengali) Documents: {bengali_doc_bow[0:5]}
(Bengali) Questions: {bengali_question_bow[0:5]}

(Indonesian) Documents: {indonesian_doc_bow[0:5]}
(Indonesian) Questions: {indonesian_question_bow[0:5]}
""")



Most common words:

(Arabic) Documents: [('في', 89705), ('.', 88299), ('من', 61719), ('[', 38120), (']', 38119)]
(Arabic) Questions: [('؟', 10061), ('ما', 7451), ('متى', 7130), ('هو', 6760), ('من', 6309)]

(Bengali) Documents: [(',', 12184), (']', 7123), ('[', 7120), ('ও', 5195), ('এবং', 5102)]
(Bengali) Questions: [('?', 4777), ('কী', 940), ('নাম', 837), ('কত', 802), ('হয়', 800)]

(Indonesian) Documents: [(',', 54165), ('.', 43063), ('yang', 24077), ('dan', 23741), ('di', 16604)]
(Indonesian) Questions: [('?', 11368), ('yang', 1814), ('Kapan', 1811), ('Apa', 1633), ('Apakah', 1227)]



In [11]:
def get_ratio(question, document, stop_words):
  tokens = nltk.word_tokenize(question)
  count = 0
  stripped_tokens = set(tokens) - stop_words
  for token in stripped_tokens:
    if token in document:
      count += 1
  return count/len(stripped_tokens)


def avg(lst):
  return sum(lst)/len(lst)

def get_average_ratios(training, stop_words):
  answerable_ratios = []
  nonanswerable_ratios = []
  for row in training:
    ratio = get_ratio(get_question(row), get_document(row), stop_words)
    lst = answerable_ratios if oracle(get_answer(row), get_document(row)) else nonanswerable_ratios
    lst.append(ratio)
  return avg(answerable_ratios), avg(nonanswerable_ratios)

class NaiveModel:
  def __init__(self, stop_words):
    self.stop_words = stop_words
    self.ratio = -1

  def train(self, training):
    answerable_ratio, nonanswerable_ratio = get_average_ratios(training, self.stop_words)
    self.ratio = (answerable_ratio + nonanswerable_ratio)/2

  def classify(self, question, document):
    return get_ratio(question, document, self.stop_words) > self.ratio

def evaluate(validation, model):
  res = [int(oracle(get_answer(row), get_document(row)) == model.classify(get_question(row), get_document(row))) for row in validation]
  acc = avg(res)

  ### Manual generation of confusion matrix for scores like Balanced Accuray and F-score
  #tp, fp, tn, fn = 0, 0, 0, 0
  #for row in validation:
  #  gt = oracle(get_answer(row), get_document(row))
  #  cl = model.classify(get_question(row), get_document(row))
  #  if (cl):
  #    if (gt):
  #      tp += 1
  #    else:
  #      fp += 1
  #  else:
  #    if (gt):
  #      fn += 1
  #    else:
  #      tn += 1
  #tpr = tp / (tp + fn)
  #tnr = tn / (tn + fp)
  #ba = (tpr + tnr) / 2

  print(f"Accuracy: {round(acc*100, 4)}%\n")

In [12]:
from nltk.corpus import stopwords

nltk.download('stopwords')

arabic_stop_words = set(stopwords.words('indonesian')) | set(string.punctuation) | set("؟")
bengali_stop_words = set(stopwords.words('bengali')) | set(string.punctuation)
indonesian_stop_words = set(stopwords.words('indonesian')) | set(string.punctuation)

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


In [13]:
print("Evaluating arabic:")

arabic_model = NaiveModel(arabic_stop_words)
arabic_model.train(train_arabic)
evaluate(val_arabic, arabic_model)

print("Evaluating bengali:")

bengali_model = NaiveModel(bengali_stop_words)
bengali_model.train(train_bengali)
evaluate(val_bengali, bengali_model)

print("Evaluating indonesian:")

indonesian_model = NaiveModel(indonesian_stop_words)
indonesian_model.train(train_indonesian)
evaluate(val_indonesian, indonesian_model)

Evaluating arabic:
Accuracy: 71.6614%

Evaluating bengali:
Accuracy: 72.3214%

Evaluating indonesian:
Accuracy: 71.2007%



In [49]:
vocab_size = 25000
encoding_dim = 100

bpemb_ar = BPEmb(lang='ar', dim=encoding_dim, vs=vocab_size)
bpemb_bn = BPEmb(lang='bn', dim=encoding_dim, vs=vocab_size)
bpemb_in = BPEmb(lang='id', dim=encoding_dim, vs=vocab_size)


In [15]:
train_arabic_doc = [get_document(row) for row in train_arabic]
train_arabic_question = [get_question(row) for row in train_arabic]

train_bengali_doc = [get_document(row) for row in train_bengali]
train_bengali_question = [get_question(row) for row in train_bengali]

train_indonesian_doc = [get_document(row) for row in train_indonesian]
train_indonesian_question = [get_question(row) for row in train_indonesian]

In [126]:
def get_bpemb_features(dataset, bpemb):
  return [bpemb.embed(x) for x in tqdm(dataset)]

def text_to_ids(text, tokenizer):
    input_ids = tokenizer.encode_ids_with_eos(text)
    return input_ids, len(input_ids)

def pad_input(input):
    input_ids = [i[0] for i in input]
    seq_lens = [i[1] for i in input]

    max_length = max(seq_lens)

    input_ids = [(i + [25000] * (max_length - len(i))) for i in input_ids]

    # Make sure each sample is max_length long
    assert (all(len(i) == max_length for i in input_ids))
    return torch.tensor(input_ids), torch.tensor(seq_lens)

In [51]:
!pip install torcheval



In [60]:
from torch.utils.data import DataLoader, Dataset

class DatasetReader(Dataset):
  def __init__(self, data, tokenizer):
    self.data = data
    self.tokenizer = tokenizer

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    row = self.data[idx]
    # Calls the text_to_batch function
    input_ids, seq_lens = text_to_ids(row, self.tokenizer)
    return input_ids, seq_lens


In [63]:

reader[0]

([913, 20783, 1062, 425, 9712, 1614, 326, 20301, 2], 9)

In [127]:
from torcheval.metrics.text import Perplexity

class LSTMNetwork(nn.Module):
    def __init__(
            self,
            pretrained_embeddings,
            vocab_size: int,
            num_layers,
            hidden_dim: int,
            dropout_rate: float = 0.1
    ):
        super(LSTMNetwork, self).__init__()

        self.vocab_size = vocab_size

        self.embeddings = nn.Embedding.from_pretrained(pretrained_embeddings, padding_idx=pretrained_embeddings.shape[0] - 1)

        self.lstm = nn.LSTM(
                pretrained_embeddings.shape[1],
                hidden_dim,
                num_layers,
                batch_first=True,
                dropout=dropout_rate)

        self.dropout = nn.Dropout(dropout_rate)

        self.output_layer = nn.Linear(hidden_dim, vocab_size)

        # Initialize the weights of the model
        self._init_weights()

    def _init_weights(self):
        all_params = list(self.lstm.named_parameters()) + list(self.output_layer.named_parameters())
        for n,p in all_params:
            if 'weight' in n:
                nn.init.xavier_normal_(p)
            elif 'bias' in n:
                nn.init.zeros_(p)

    def forward(self, inputs, input_lens):

        embeds = self.embeddings(inputs)

        # Pack padded: This is necessary for padded batches input to an RNN
        lstm_in = nn.utils.rnn.pack_padded_sequence(
            embeds,
            input_lens.cpu(),
            batch_first=True,
            enforce_sorted=False
        )

        lstm_out, _ = self.lstm(lstm_in)

        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)

        targets = torch.flatten(inputs.clone())

        output = self.output_layer(lstm_out)
        logits = output.view(-1, self.vocab_size)

        loss_fn = nn.NLLLoss()
        loss = loss_fn(logits, targets)

        return (output, loss)


In [166]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [186]:
def train(
    model,
    dl,
    optimizer: torch.optim.Optimizer,
    n_epochs: int,
    device,
    patience: int = 10
):
  # Keep track of the loss and best accuracy
  best_ppl = float("inf")
  pcounter = 0

  # Iterate through epochs
  for ep in range(n_epochs):

    logits_epoch = torch.tensor([]).to(device)
    inputs_epoch = torch.tensor([]).int().to(device)

    for batch in tqdm(dl):
      model.train()
      optimizer.zero_grad()

      batch = tuple(t.to(device) for t in batch)
      inputs = batch[0]

      seq_lens = batch[1]

      (logits, loss) = model(inputs, seq_lens)

      logits_epoch = torch.cat((logits_epoch, logits), 0)
      inputs_epoch = torch.cat((inputs_epoch, inputs), 0)

      loss.backward()
      optimizer.step()

    targets = inputs_epoch.clone().to(device)

    metric=Perplexity()

    metric.update(logits_epoch.cpu(), targets.cpu())
    ppl = metric.compute()

    print(f'Perplexity at epoch {ep}: {ppl}')

    # Keep track of the best model based on the accuracy
    if ppl < best_ppl:
      torch.save(model.state_dict(), 'best_model')
      best_ppl = ppl
      pcounter = 0
    else:
      pcounter += 1
      if pcounter == patience:
        break

  model.load_state_dict(torch.load('best_model'))
  return model

In [187]:
from torch.optim import Adam

num_layers = 2
hidden_dim = 100
dropout_rate = 0.1
lr = 3e-4
n_epochs = 200
batch_size = 32

device = torch.device("cpu")
if torch.cuda.is_available():
  device = torch.device("cuda")

pretrained_embeddings = torch.Tensor(np.concatenate([bpemb_bn.emb.vectors, np.zeros(shape=(1,100))], axis=0))
vocabulary = bpemb_bn.emb.index_to_key + ['[PAD]']

model = LSTMNetwork(pretrained_embeddings, len(vocabulary), num_layers, hidden_dim, dropout_rate).to(device)

reader = DatasetReader(train_bengali_question[0:5], bpemb_bn)

dl = DataLoader(reader, batch_size=batch_size, collate_fn=pad_input, shuffle=True, num_workers=2)

optimizer = Adam(model.parameters(), lr=lr)

train(model, dl, optimizer, n_epochs, device)

100%|██████████| 1/1 [00:00<00:00,  1.48it/s]


Perplexity at epoch 0: 25004.836478282516


100%|██████████| 1/1 [00:00<00:00,  1.78it/s]


Perplexity at epoch 1: 24981.7307099388


100%|██████████| 1/1 [00:00<00:00,  3.62it/s]


Perplexity at epoch 2: 24959.580634338283


100%|██████████| 1/1 [00:00<00:00,  3.18it/s]


Perplexity at epoch 3: 24936.618151898765


100%|██████████| 1/1 [00:00<00:00,  2.93it/s]


Perplexity at epoch 4: 24915.217727247986


100%|██████████| 1/1 [00:00<00:00,  2.83it/s]


Perplexity at epoch 5: 24893.410240159577


100%|██████████| 1/1 [00:00<00:00,  3.47it/s]


Perplexity at epoch 6: 24870.22527981305


100%|██████████| 1/1 [00:00<00:00,  3.38it/s]


Perplexity at epoch 7: 24846.29354266121


100%|██████████| 1/1 [00:00<00:00,  3.19it/s]


Perplexity at epoch 8: 24818.97118298706


100%|██████████| 1/1 [00:00<00:00,  3.53it/s]


Perplexity at epoch 9: 24796.319672378333


100%|██████████| 1/1 [00:00<00:00,  3.57it/s]


Perplexity at epoch 10: 24767.056793595915


100%|██████████| 1/1 [00:00<00:00,  3.41it/s]


Perplexity at epoch 11: 24735.392633485895


100%|██████████| 1/1 [00:00<00:00,  3.29it/s]


Perplexity at epoch 12: 24703.226153877353


100%|██████████| 1/1 [00:00<00:00,  3.44it/s]


Perplexity at epoch 13: 24669.435136899552


100%|██████████| 1/1 [00:00<00:00,  3.89it/s]


Perplexity at epoch 14: 24637.11383196329


100%|██████████| 1/1 [00:00<00:00,  3.59it/s]


Perplexity at epoch 15: 24596.006109533744


100%|██████████| 1/1 [00:00<00:00,  5.50it/s]


Perplexity at epoch 16: 24559.463536901963


100%|██████████| 1/1 [00:00<00:00,  6.06it/s]


Perplexity at epoch 17: 24506.1773454834


100%|██████████| 1/1 [00:00<00:00,  5.44it/s]


Perplexity at epoch 18: 24459.99263342578


100%|██████████| 1/1 [00:00<00:00,  5.59it/s]


Perplexity at epoch 19: 24408.23321597825


100%|██████████| 1/1 [00:00<00:00,  5.72it/s]


Perplexity at epoch 20: 24340.374784884676


100%|██████████| 1/1 [00:00<00:00,  6.16it/s]


Perplexity at epoch 21: 24270.828530715236


100%|██████████| 1/1 [00:00<00:00,  5.56it/s]


Perplexity at epoch 22: 24201.618853677883


100%|██████████| 1/1 [00:00<00:00,  5.69it/s]


Perplexity at epoch 23: 24126.656598519316


100%|██████████| 1/1 [00:00<00:00,  5.43it/s]


Perplexity at epoch 24: 24020.160151166223


100%|██████████| 1/1 [00:00<00:00,  5.69it/s]


Perplexity at epoch 25: 23924.742585373828


100%|██████████| 1/1 [00:00<00:00,  6.21it/s]


Perplexity at epoch 26: 23794.803452550164


100%|██████████| 1/1 [00:00<00:00,  5.25it/s]


Perplexity at epoch 27: 23685.29956796015


100%|██████████| 1/1 [00:00<00:00,  5.31it/s]


Perplexity at epoch 28: 23533.935692135576


100%|██████████| 1/1 [00:00<00:00,  6.24it/s]


Perplexity at epoch 29: 23373.626820186746


100%|██████████| 1/1 [00:00<00:00,  6.01it/s]


Perplexity at epoch 30: 23177.9205401924


100%|██████████| 1/1 [00:00<00:00,  5.74it/s]


Perplexity at epoch 31: 22983.609744982616


100%|██████████| 1/1 [00:00<00:00,  5.05it/s]


Perplexity at epoch 32: 22705.14159235097


100%|██████████| 1/1 [00:00<00:00,  5.63it/s]


Perplexity at epoch 33: 22450.73823259747


100%|██████████| 1/1 [00:00<00:00,  5.66it/s]


Perplexity at epoch 34: 22147.28905310956


100%|██████████| 1/1 [00:00<00:00,  5.58it/s]


Perplexity at epoch 35: 21848.741470492627


100%|██████████| 1/1 [00:00<00:00,  5.58it/s]


Perplexity at epoch 36: 21480.88222682417


100%|██████████| 1/1 [00:00<00:00,  5.79it/s]


Perplexity at epoch 37: 21068.061697966477


100%|██████████| 1/1 [00:00<00:00,  5.94it/s]


Perplexity at epoch 38: 20680.805179034804


100%|██████████| 1/1 [00:00<00:00,  5.64it/s]


Perplexity at epoch 39: 20241.97118281577


100%|██████████| 1/1 [00:00<00:00,  5.74it/s]


Perplexity at epoch 40: 19767.32264334118


100%|██████████| 1/1 [00:00<00:00,  6.05it/s]


Perplexity at epoch 41: 19249.51001097536


100%|██████████| 1/1 [00:00<00:00,  5.48it/s]


Perplexity at epoch 42: 18775.567193161063


100%|██████████| 1/1 [00:00<00:00,  5.57it/s]


Perplexity at epoch 43: 18268.5782135634


100%|██████████| 1/1 [00:00<00:00,  5.40it/s]


Perplexity at epoch 44: 17758.218083663767


100%|██████████| 1/1 [00:00<00:00,  6.00it/s]


Perplexity at epoch 45: 17218.565833505305


100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


Perplexity at epoch 46: 16704.214655520984


100%|██████████| 1/1 [00:00<00:00,  3.36it/s]


Perplexity at epoch 47: 16146.569490281548


100%|██████████| 1/1 [00:00<00:00,  3.44it/s]


Perplexity at epoch 48: 15639.237327718487


100%|██████████| 1/1 [00:00<00:00,  3.08it/s]


Perplexity at epoch 49: 15139.63797046979


100%|██████████| 1/1 [00:00<00:00,  3.22it/s]


Perplexity at epoch 50: 14565.501968346956


100%|██████████| 1/1 [00:00<00:00,  3.31it/s]


Perplexity at epoch 51: 14117.586872375417


100%|██████████| 1/1 [00:00<00:00,  3.35it/s]


Perplexity at epoch 52: 13646.259081998885


100%|██████████| 1/1 [00:00<00:00,  3.35it/s]


Perplexity at epoch 53: 13162.743694501049


100%|██████████| 1/1 [00:00<00:00,  3.16it/s]


Perplexity at epoch 54: 12713.937371831224


100%|██████████| 1/1 [00:00<00:00,  3.70it/s]


Perplexity at epoch 55: 12257.599208464802


100%|██████████| 1/1 [00:00<00:00,  3.53it/s]


Perplexity at epoch 56: 11860.051200105498


100%|██████████| 1/1 [00:00<00:00,  3.35it/s]


Perplexity at epoch 57: 11455.159000774163


100%|██████████| 1/1 [00:00<00:00,  3.70it/s]


Perplexity at epoch 58: 11058.436427303424


100%|██████████| 1/1 [00:00<00:00,  5.44it/s]


Perplexity at epoch 59: 10695.949984835357


100%|██████████| 1/1 [00:00<00:00,  6.25it/s]


Perplexity at epoch 60: 10301.189453137922


100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


Perplexity at epoch 61: 9975.832886612216


100%|██████████| 1/1 [00:00<00:00,  5.31it/s]


Perplexity at epoch 62: 9643.730690248174


100%|██████████| 1/1 [00:00<00:00,  5.59it/s]


Perplexity at epoch 63: 9307.32616692924


100%|██████████| 1/1 [00:00<00:00,  6.12it/s]


Perplexity at epoch 64: 9031.10650800155


100%|██████████| 1/1 [00:00<00:00,  5.59it/s]


Perplexity at epoch 65: 8701.408902739777


100%|██████████| 1/1 [00:00<00:00,  5.30it/s]


Perplexity at epoch 66: 8416.437302965378


100%|██████████| 1/1 [00:00<00:00,  5.77it/s]


Perplexity at epoch 67: 8146.5444388774185


100%|██████████| 1/1 [00:00<00:00,  5.80it/s]


Perplexity at epoch 68: 7879.141922620394


100%|██████████| 1/1 [00:00<00:00,  5.45it/s]


Perplexity at epoch 69: 7614.844292958547


100%|██████████| 1/1 [00:00<00:00,  5.34it/s]


Perplexity at epoch 70: 7384.477320760649


100%|██████████| 1/1 [00:00<00:00,  6.04it/s]


Perplexity at epoch 71: 7126.0071366246575


100%|██████████| 1/1 [00:00<00:00,  5.42it/s]


Perplexity at epoch 72: 6895.143423824254


100%|██████████| 1/1 [00:00<00:00,  5.34it/s]


Perplexity at epoch 73: 6693.942777001021


100%|██████████| 1/1 [00:00<00:00,  5.33it/s]


Perplexity at epoch 74: 6476.238758293238


100%|██████████| 1/1 [00:00<00:00,  5.73it/s]


Perplexity at epoch 75: 6237.227274800519


100%|██████████| 1/1 [00:00<00:00,  5.61it/s]


Perplexity at epoch 76: 6062.729344697117


100%|██████████| 1/1 [00:00<00:00,  5.50it/s]


Perplexity at epoch 77: 5874.630465237158


100%|██████████| 1/1 [00:00<00:00,  5.29it/s]


Perplexity at epoch 78: 5691.672610666716


100%|██████████| 1/1 [00:00<00:00,  5.47it/s]


Perplexity at epoch 79: 5538.753417795491


100%|██████████| 1/1 [00:00<00:00,  6.02it/s]


Perplexity at epoch 80: 5346.231159182846


100%|██████████| 1/1 [00:00<00:00,  6.04it/s]


Perplexity at epoch 81: 5186.039516635037


100%|██████████| 1/1 [00:00<00:00,  5.42it/s]


Perplexity at epoch 82: 5018.438073013707


100%|██████████| 1/1 [00:00<00:00,  5.49it/s]


Perplexity at epoch 83: 4867.279945773427


100%|██████████| 1/1 [00:00<00:00,  5.59it/s]


Perplexity at epoch 84: 4712.092565605869


100%|██████████| 1/1 [00:00<00:00,  5.48it/s]


Perplexity at epoch 85: 4581.213682954755


100%|██████████| 1/1 [00:00<00:00,  5.45it/s]


Perplexity at epoch 86: 4437.002767007644


100%|██████████| 1/1 [00:00<00:00,  5.69it/s]


Perplexity at epoch 87: 4308.351153354102


100%|██████████| 1/1 [00:00<00:00,  5.58it/s]


Perplexity at epoch 88: 4172.4104201485


100%|██████████| 1/1 [00:00<00:00,  5.06it/s]


Perplexity at epoch 89: 4061.8562195722607


100%|██████████| 1/1 [00:00<00:00,  5.38it/s]


Perplexity at epoch 90: 3929.5176920242334


100%|██████████| 1/1 [00:00<00:00,  5.71it/s]


Perplexity at epoch 91: 3811.5151407326407


100%|██████████| 1/1 [00:00<00:00,  4.60it/s]


Perplexity at epoch 92: 3704.5041180848525


100%|██████████| 1/1 [00:00<00:00,  3.85it/s]


Perplexity at epoch 93: 3585.4011768035125


100%|██████████| 1/1 [00:00<00:00,  3.76it/s]


Perplexity at epoch 94: 3481.028265838454


100%|██████████| 1/1 [00:00<00:00,  3.56it/s]


Perplexity at epoch 95: 3379.0281734249725


100%|██████████| 1/1 [00:00<00:00,  3.47it/s]


Perplexity at epoch 96: 3297.107280119091


100%|██████████| 1/1 [00:00<00:00,  3.34it/s]


Perplexity at epoch 97: 3197.926708648974


100%|██████████| 1/1 [00:00<00:00,  3.92it/s]


Perplexity at epoch 98: 3107.181593218172


100%|██████████| 1/1 [00:00<00:00,  3.50it/s]


Perplexity at epoch 99: 3011.0739355659002


100%|██████████| 1/1 [00:00<00:00,  3.42it/s]


Perplexity at epoch 100: 2922.173617254147


100%|██████████| 1/1 [00:00<00:00,  3.43it/s]


Perplexity at epoch 101: 2843.1539787250185


100%|██████████| 1/1 [00:00<00:00,  3.03it/s]


Perplexity at epoch 102: 2769.9700777729304


100%|██████████| 1/1 [00:00<00:00,  3.77it/s]


Perplexity at epoch 103: 2676.714470570769


100%|██████████| 1/1 [00:00<00:00,  4.16it/s]


Perplexity at epoch 104: 2620.116656364155


100%|██████████| 1/1 [00:00<00:00,  5.50it/s]


Perplexity at epoch 105: 2524.830059959898


100%|██████████| 1/1 [00:00<00:00,  5.45it/s]


Perplexity at epoch 106: 2450.475269153933


100%|██████████| 1/1 [00:00<00:00,  5.40it/s]


Perplexity at epoch 107: 2395.8245908536346


100%|██████████| 1/1 [00:00<00:00,  5.45it/s]


Perplexity at epoch 108: 2321.251255747653


100%|██████████| 1/1 [00:00<00:00,  5.31it/s]


Perplexity at epoch 109: 2262.1658232624864


100%|██████████| 1/1 [00:00<00:00,  5.26it/s]


Perplexity at epoch 110: 2203.2553351979554


100%|██████████| 1/1 [00:00<00:00,  5.67it/s]


Perplexity at epoch 111: 2139.5300291572003


100%|██████████| 1/1 [00:00<00:00,  5.83it/s]


Perplexity at epoch 112: 2079.745509071908


100%|██████████| 1/1 [00:00<00:00,  6.22it/s]


Perplexity at epoch 113: 2020.0889213483


100%|██████████| 1/1 [00:00<00:00,  6.00it/s]


Perplexity at epoch 114: 1988.9548591216997


100%|██████████| 1/1 [00:00<00:00,  5.57it/s]


Perplexity at epoch 115: 1918.909879244879


100%|██████████| 1/1 [00:00<00:00,  6.04it/s]


Perplexity at epoch 116: 1859.8016061465762


100%|██████████| 1/1 [00:00<00:00,  5.58it/s]


Perplexity at epoch 117: 1823.038860333832


100%|██████████| 1/1 [00:00<00:00,  6.21it/s]


Perplexity at epoch 118: 1765.3544501732442


100%|██████████| 1/1 [00:00<00:00,  5.57it/s]


Perplexity at epoch 119: 1721.2283244152393


100%|██████████| 1/1 [00:00<00:00,  5.68it/s]


Perplexity at epoch 120: 1675.5822832243966


100%|██████████| 1/1 [00:00<00:00,  5.51it/s]


Perplexity at epoch 121: 1629.6355073065108


100%|██████████| 1/1 [00:00<00:00,  5.40it/s]


Perplexity at epoch 122: 1596.720622575101


100%|██████████| 1/1 [00:00<00:00,  5.76it/s]


Perplexity at epoch 123: 1545.2287646317998


100%|██████████| 1/1 [00:00<00:00,  5.65it/s]


Perplexity at epoch 124: 1516.8626186821728


100%|██████████| 1/1 [00:00<00:00,  5.51it/s]


Perplexity at epoch 125: 1470.0883786250074


100%|██████████| 1/1 [00:00<00:00,  5.70it/s]


Perplexity at epoch 126: 1428.8087786376527


100%|██████████| 1/1 [00:00<00:00,  6.11it/s]


Perplexity at epoch 127: 1403.3413184864623


100%|██████████| 1/1 [00:00<00:00,  5.62it/s]


Perplexity at epoch 128: 1361.107347727477


100%|██████████| 1/1 [00:00<00:00,  6.12it/s]


Perplexity at epoch 129: 1330.9673605238665


100%|██████████| 1/1 [00:00<00:00,  6.18it/s]


Perplexity at epoch 130: 1300.4148908267296


100%|██████████| 1/1 [00:00<00:00,  5.28it/s]


Perplexity at epoch 131: 1261.1923039918095


100%|██████████| 1/1 [00:00<00:00,  6.44it/s]


Perplexity at epoch 132: 1236.2461864715317


100%|██████████| 1/1 [00:00<00:00,  5.66it/s]


Perplexity at epoch 133: 1209.6800297099467


100%|██████████| 1/1 [00:00<00:00,  4.72it/s]


Perplexity at epoch 134: 1178.4474431641227


100%|██████████| 1/1 [00:00<00:00,  3.36it/s]


Perplexity at epoch 135: 1147.1069687782556


100%|██████████| 1/1 [00:00<00:00,  3.79it/s]


Perplexity at epoch 136: 1118.2004281206068


100%|██████████| 1/1 [00:00<00:00,  3.66it/s]


Perplexity at epoch 137: 1090.2041812714144


100%|██████████| 1/1 [00:00<00:00,  3.43it/s]


Perplexity at epoch 138: 1070.6370397864328


100%|██████████| 1/1 [00:00<00:00,  3.72it/s]


Perplexity at epoch 139: 1046.1396474991916


100%|██████████| 1/1 [00:00<00:00,  3.71it/s]


Perplexity at epoch 140: 1021.725401083097


100%|██████████| 1/1 [00:00<00:00,  3.31it/s]


Perplexity at epoch 141: 999.7886886597497


100%|██████████| 1/1 [00:00<00:00,  3.39it/s]


Perplexity at epoch 142: 978.0029587260494


100%|██████████| 1/1 [00:00<00:00,  3.16it/s]


Perplexity at epoch 143: 955.3366601568897


100%|██████████| 1/1 [00:00<00:00,  3.53it/s]


Perplexity at epoch 144: 930.9390870157976


100%|██████████| 1/1 [00:00<00:00,  3.67it/s]


Perplexity at epoch 145: 908.5764818158714


100%|██████████| 1/1 [00:00<00:00,  4.04it/s]


Perplexity at epoch 146: 891.1067172170129


100%|██████████| 1/1 [00:00<00:00,  5.67it/s]


Perplexity at epoch 147: 871.9620040385742


100%|██████████| 1/1 [00:00<00:00,  5.76it/s]


Perplexity at epoch 148: 850.5660760179427


100%|██████████| 1/1 [00:00<00:00,  5.74it/s]


Perplexity at epoch 149: 835.6491279826735


100%|██████████| 1/1 [00:00<00:00,  5.76it/s]


Perplexity at epoch 150: 817.9065163449154


100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


Perplexity at epoch 151: 800.0873153386246


100%|██████████| 1/1 [00:00<00:00,  6.20it/s]


Perplexity at epoch 152: 783.5494926878956


100%|██████████| 1/1 [00:00<00:00,  5.49it/s]


Perplexity at epoch 153: 770.5947960358668


100%|██████████| 1/1 [00:00<00:00,  6.12it/s]


Perplexity at epoch 154: 758.3006291296897


100%|██████████| 1/1 [00:00<00:00,  5.28it/s]


Perplexity at epoch 155: 738.6223662070493


100%|██████████| 1/1 [00:00<00:00,  5.52it/s]


Perplexity at epoch 156: 732.19148667338


100%|██████████| 1/1 [00:00<00:00,  5.73it/s]


Perplexity at epoch 157: 710.8368024899738


100%|██████████| 1/1 [00:00<00:00,  5.12it/s]


Perplexity at epoch 158: 694.8140152038205


100%|██████████| 1/1 [00:00<00:00,  5.83it/s]


Perplexity at epoch 159: 686.2434878144938


100%|██████████| 1/1 [00:00<00:00,  5.89it/s]


Perplexity at epoch 160: 669.6659268242587


100%|██████████| 1/1 [00:00<00:00,  5.46it/s]


Perplexity at epoch 161: 659.9429027626501


100%|██████████| 1/1 [00:00<00:00,  5.69it/s]


Perplexity at epoch 162: 653.9758834119712


100%|██████████| 1/1 [00:00<00:00,  5.66it/s]


Perplexity at epoch 163: 635.2866628823667


100%|██████████| 1/1 [00:00<00:00,  5.74it/s]


Perplexity at epoch 164: 625.2857977482323


100%|██████████| 1/1 [00:00<00:00,  4.97it/s]


Perplexity at epoch 165: 612.9367311301261


100%|██████████| 1/1 [00:00<00:00,  6.08it/s]


Perplexity at epoch 166: 606.6226301126214


100%|██████████| 1/1 [00:00<00:00,  5.35it/s]


Perplexity at epoch 167: 595.2577753009671


100%|██████████| 1/1 [00:00<00:00,  5.69it/s]


Perplexity at epoch 168: 587.4759845602406


100%|██████████| 1/1 [00:00<00:00,  5.41it/s]


Perplexity at epoch 169: 577.6410394412325


100%|██████████| 1/1 [00:00<00:00,  5.64it/s]


Perplexity at epoch 170: 572.526021529626


100%|██████████| 1/1 [00:00<00:00,  5.35it/s]


Perplexity at epoch 171: 559.8716267654066


100%|██████████| 1/1 [00:00<00:00,  5.43it/s]


Perplexity at epoch 172: 547.8949532777644


100%|██████████| 1/1 [00:00<00:00,  5.55it/s]


Perplexity at epoch 173: 540.0834328019939


100%|██████████| 1/1 [00:00<00:00,  5.26it/s]


Perplexity at epoch 174: 535.2101669668049


100%|██████████| 1/1 [00:00<00:00,  5.23it/s]


Perplexity at epoch 175: 523.5155915491547


100%|██████████| 1/1 [00:00<00:00,  5.41it/s]


Perplexity at epoch 176: 519.5373163932777


100%|██████████| 1/1 [00:00<00:00,  5.82it/s]


Perplexity at epoch 177: 510.9806829839214


100%|██████████| 1/1 [00:00<00:00,  5.98it/s]


Perplexity at epoch 178: 502.3554120565905


100%|██████████| 1/1 [00:00<00:00,  5.49it/s]


Perplexity at epoch 179: 495.29432458565356


100%|██████████| 1/1 [00:00<00:00,  4.71it/s]


Perplexity at epoch 180: 493.28266008517545


100%|██████████| 1/1 [00:00<00:00,  3.63it/s]


Perplexity at epoch 181: 482.92710582990935


100%|██████████| 1/1 [00:00<00:00,  3.35it/s]


Perplexity at epoch 182: 478.2828193571497


100%|██████████| 1/1 [00:00<00:00,  3.54it/s]


Perplexity at epoch 183: 469.9439937598436


100%|██████████| 1/1 [00:00<00:00,  3.25it/s]


Perplexity at epoch 184: 466.34700970248383


100%|██████████| 1/1 [00:00<00:00,  3.52it/s]


Perplexity at epoch 185: 457.8295663259567


100%|██████████| 1/1 [00:00<00:00,  3.25it/s]


Perplexity at epoch 186: 455.49015589398556


100%|██████████| 1/1 [00:00<00:00,  3.29it/s]


Perplexity at epoch 187: 450.53262830781387


100%|██████████| 1/1 [00:00<00:00,  3.15it/s]


Perplexity at epoch 188: 445.31674135216696


100%|██████████| 1/1 [00:00<00:00,  3.02it/s]


Perplexity at epoch 189: 438.44307265143766


100%|██████████| 1/1 [00:00<00:00,  3.39it/s]


Perplexity at epoch 190: 435.5576644411009


100%|██████████| 1/1 [00:00<00:00,  3.25it/s]


Perplexity at epoch 191: 431.17642020953383


100%|██████████| 1/1 [00:00<00:00,  4.17it/s]


Perplexity at epoch 192: 424.75496561979975


100%|██████████| 1/1 [00:00<00:00,  5.53it/s]


Perplexity at epoch 193: 424.9352689444558


100%|██████████| 1/1 [00:00<00:00,  5.27it/s]


Perplexity at epoch 194: 418.65599171831855


100%|██████████| 1/1 [00:00<00:00,  5.45it/s]


Perplexity at epoch 195: 416.34645612864654


100%|██████████| 1/1 [00:00<00:00,  5.37it/s]


Perplexity at epoch 196: 411.10341791558994


100%|██████████| 1/1 [00:00<00:00,  5.44it/s]


Perplexity at epoch 197: 406.0929325236007


100%|██████████| 1/1 [00:00<00:00,  5.38it/s]


Perplexity at epoch 198: 399.90865773349054


100%|██████████| 1/1 [00:00<00:00,  5.55it/s]


Perplexity at epoch 199: 398.2570196480786


LSTMNetwork(
  (embeddings): Embedding(25001, 100, padding_idx=25000)
  (lstm): LSTM(100, 100, num_layers=2, batch_first=True, dropout=0.1)
  (dropout): Dropout(p=0.1, inplace=False)
  (output_layer): Linear(in_features=100, out_features=25001, bias=True)
)

In [134]:
for smt in dl:
  print(smt)
  break


(tensor([[ 8074,  4662,    22, 24797,  8988,  6078,    83, 11372, 11747,   376,
         13826, 12717,    78, 20301,     2],
        [ 6726,  4136, 24798,  1121,  7705,   779,   326, 20301,     2, 25000,
         25000, 25000, 25000, 25000, 25000],
        [  913, 20783,  1062,   425,  9712,  1614,   326, 20301,     2, 25000,
         25000, 25000, 25000, 25000, 25000],
        [ 4647,   145, 22525,  8841,  6520,  2049,    78, 20301,     2, 25000,
         25000, 25000, 25000, 25000, 25000],
        [ 4647,   145,  1998,  3416,   139,    78, 20301,     2, 25000, 25000,
         25000, 25000, 25000, 25000, 25000]]), tensor([15,  9,  9,  9,  8]))
