In [None]:
from torch.utils.data import DataLoader, Dataset
import nltk
from bpemb import BPEmb
import string
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from torch import nn
import torch

In [19]:
!pip install datasets
!pip install bpemb

dataset = load_dataset("copenlu/answerable_tydiqa")
nltk.download('punkt')

train_set = dataset["train"]
validation_set = dataset["validation"]

def get_answer_start(row):
  return row["annotations"]["answer_start"][0]

def get_answer(row):
  return row["annotations"]["answer_text"][0]

def get_document(row):
  return row["document_plaintext"]

def get_question(row):
  return row["question_text"]

def oracle(answer, document):
  return answer != "" and answer in document

def get_language(dataset, lang):
  return [row for row in dataset if row['language'] == lang]

train_arabic = get_language(train_set, "arabic")
val_arabic = get_language(validation_set, "arabic")

train_bengali = get_language(train_set, "bengali")
val_bengali = get_language(validation_set, "bengali")

train_indonesian = get_language(train_set, "indonesian")
val_indonesian = get_language(validation_set, "indonesian")



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:
vocab_size = 25000
encoding_dim = 100

bpemb_ar = BPEmb(lang='ar', dim=encoding_dim, vs=vocab_size)
bpemb_bn = BPEmb(lang='bn', dim=encoding_dim, vs=vocab_size)
bpemb_in = BPEmb(lang='id', dim=encoding_dim, vs=vocab_size)

In [20]:
train_arabic_doc = [get_document(row) for row in train_arabic]
train_arabic_question = [get_question(row) for row in train_arabic]

train_bengali_doc = [get_document(row) for row in train_bengali]
train_bengali_question = [get_question(row) for row in train_bengali]

train_indonesian_doc = [get_document(row) for row in train_indonesian]
train_indonesian_question = [get_question(row) for row in train_indonesian]

val_arabic_doc = [get_document(row) for row in val_arabic]
val_arabic_question = [get_question(row) for row in val_arabic]

val_bengali_doc = [get_document(row) for row in val_bengali]
val_bengali_question = [get_question(row) for row in val_bengali]

val_indonesian_doc = [get_document(row) for row in val_indonesian]
val_indonesian_question = [get_question(row) for row in val_indonesian]

In [None]:
def text_to_ids(text, tokenizer):
    input_ids = tokenizer.encode_ids_with_eos(text)
    return input_ids, len(input_ids)

def pad_input(input):
    input_ids = [i[0] for i in input]
    seq_lens = [i[1] for i in input]

    max_length = max(seq_lens)

    input_ids = [(i + [25000] * (max_length - len(i))) for i in input_ids]

    # Make sure each sample is max_length long
    assert (all(len(i) == max_length for i in input_ids))
    return torch.tensor(input_ids), torch.tensor(seq_lens)

In [None]:
class DatasetReader(Dataset):
  def __init__(self, data, tokenizer):
    self.data = data
    self.tokenizer = tokenizer

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    row = self.data[idx]
    input_ids, seq_lens = text_to_ids(row, self.tokenizer)
    return input_ids, seq_lens


In [None]:
class LSTMNetwork(nn.Module):
    def __init__(
            self,
            pretrained_embeddings,
            vocab_size: int,
            num_layers,
            hidden_dim: int,
            dropout_rate: float = 0.1,
            bidirectional: bool = False
    ):
        super(LSTMNetwork, self).__init__()

        self.vocab_size = vocab_size

        self.embeddings = nn.Embedding.from_pretrained(pretrained_embeddings, padding_idx=pretrained_embeddings.shape[0] - 1)

        self.lstm = nn.LSTM(
                pretrained_embeddings.shape[1],
                hidden_dim,
                num_layers,
                batch_first=True,
                dropout=dropout_rate,
                bidirectional=bidirectional)

        self.dropout = nn.Dropout(dropout_rate)

        self.output_layer = nn.Linear(2*hidden_dim if bidirectional else hidden_dim, vocab_size)

        # Initialize the weights of the model
        self._init_weights()

    def _init_weights(self):
        all_params = list(self.lstm.named_parameters()) + list(self.output_layer.named_parameters())
        for n,p in all_params:
            if 'weight' in n:
                nn.init.xavier_normal_(p)
            elif 'bias' in n:
                nn.init.zeros_(p)

    def forward(self, inputs, input_lens):

        embeds = self.embeddings(inputs)

        # Pack padded: This is necessary for padded batches input to an RNN
        lstm_in = nn.utils.rnn.pack_padded_sequence(
            embeds,
            input_lens.cpu(),
            batch_first=True,
            enforce_sorted=False
        )

        lstm_out, _ = self.lstm(lstm_in)

        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)

        lin_in = self.dropout(lstm_out)

        targets = torch.flatten(inputs.clone())

        output = self.output_layer(lin_in)
        logits = output.view(-1, self.vocab_size)

        loss_fn = nn.CrossEntropyLoss()
        loss = loss_fn(logits, targets)

        return (output, loss)


In [None]:
from math import exp

def train(
    model,
    dl,
    optimizer: torch.optim.Optimizer,
    n_epochs: int,
    device,
    patience: int = 10
):
  # Keep track of the loss and best accuracy
  best_ppl = float("inf")
  pcounter = 0

  # Iterate through epochs
  for ep in range(n_epochs):

    losses_epoch = []

    for batch in tqdm(dl):
      model.train()
      optimizer.zero_grad()

      batch = tuple(t.to(device) for t in batch)
      inputs = batch[0]
      seq_lens = batch[1]

      (output, loss) = model(inputs, seq_lens)

      losses_epoch.append(loss.item())

      loss.backward()

      #torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

      optimizer.step()

    avg_loss = sum(losses_epoch) / len(losses_epoch)
    ppl = exp(avg_loss)

    print(f'Avg Loss / Perplexity at epoch {ep}: {avg_loss}/{ppl}')

    # Keep track of the best model based on the accuracy
    if ppl < best_ppl:
      torch.save(model.state_dict(), 'best_model')
      best_ppl = ppl
      pcounter = 0
    else:
      pcounter += 1
      if pcounter == patience:
        break

  model.load_state_dict(torch.load('best_model'))
  return model, best_ppl

In [16]:
from torch.optim import Adam

num_layers = 2
hidden_dim = 100
dropout_rate = 0.1
lr = 0.01
n_epochs = 100
#batch_size = 8

bidirectional = True

patience = 5

runs = [#("Arabic document", train_arabic_doc, bpemb_ar, 4),
        #("Arabic question", train_arabic_question, bpemb_ar, 32),
        #("Bengali document", train_bengali_doc, bpemb_bn, 4),
        ("Bengali question", train_bengali_question, bpemb_bn, 32),
        #("Indonesian document", train_indonesian_doc, bpemb_in, 4),
        ("Indonesian question", train_indonesian_question, bpemb_in, 32),
]

device = torch.device("cpu")
if torch.cuda.is_available():
  device = torch.device("cuda")

results = []

for (tag, ds, embs, batch_size) in runs:

  pretrained_embeddings = torch.Tensor(np.concatenate([embs.emb.vectors, np.zeros(shape=(1,100))], axis=0))
  vocabulary = embs.emb.index_to_key + ['[PAD]']

  model = LSTMNetwork(pretrained_embeddings, len(vocabulary), num_layers, hidden_dim, dropout_rate, bidirectional).to(device)

  reader = DatasetReader(ds, embs)

  dl = DataLoader(reader, batch_size=batch_size, collate_fn=pad_input, shuffle=True, num_workers=2)

  optimizer = Adam(model.parameters(), lr=lr)

  best_model, ppl = train(model, dl, optimizer, n_epochs, device, patience)

  results.append((tag, ppl, best_model))

100%|██████████| 150/150 [00:02<00:00, 58.53it/s]


Avg Loss / Perplexity at epoch 0: 6.747271849314372/851.7319369361895


100%|██████████| 150/150 [00:02<00:00, 51.29it/s]


Avg Loss / Perplexity at epoch 1: 4.029033211072286/56.206545105597435


100%|██████████| 150/150 [00:02<00:00, 66.31it/s]


Avg Loss / Perplexity at epoch 2: 2.3011905566851296/9.98606435605


100%|██████████| 150/150 [00:02<00:00, 64.81it/s]


Avg Loss / Perplexity at epoch 3: 1.1971918721993764/3.310806688386495


100%|██████████| 150/150 [00:02<00:00, 64.58it/s]


Avg Loss / Perplexity at epoch 4: 0.5199468057354292/1.6819381778547884


100%|██████████| 150/150 [00:02<00:00, 65.40it/s]


Avg Loss / Perplexity at epoch 5: 0.22499335542321205/1.2523143950650586


100%|██████████| 150/150 [00:03<00:00, 41.61it/s]


Avg Loss / Perplexity at epoch 6: 0.11646402743955453/1.1235170939536918


100%|██████████| 150/150 [00:02<00:00, 58.91it/s]


Avg Loss / Perplexity at epoch 7: 0.07172693945467472/1.0743619381501393


100%|██████████| 150/150 [00:02<00:00, 63.95it/s]


Avg Loss / Perplexity at epoch 8: 0.05135816500832637/1.0526998660262317


100%|██████████| 150/150 [00:02<00:00, 64.46it/s]


Avg Loss / Perplexity at epoch 9: 0.03772750331709782/1.038448220624813


100%|██████████| 150/150 [00:02<00:00, 64.46it/s]


Avg Loss / Perplexity at epoch 10: 0.028972336016595365/1.0293961168841366


100%|██████████| 150/150 [00:02<00:00, 53.54it/s]


Avg Loss / Perplexity at epoch 11: 0.023064561436573663/1.0233326052376335


100%|██████████| 150/150 [00:02<00:00, 55.73it/s]


Avg Loss / Perplexity at epoch 12: 0.019114929965386787/1.0192987898603971


100%|██████████| 150/150 [00:02<00:00, 64.84it/s]


Avg Loss / Perplexity at epoch 13: 0.015841628002623717/1.0159677718190352


100%|██████████| 150/150 [00:02<00:00, 65.63it/s]


Avg Loss / Perplexity at epoch 14: 0.013691009401033322/1.0137851604538897


100%|██████████| 150/150 [00:02<00:00, 65.11it/s]


Avg Loss / Perplexity at epoch 15: 0.013021395479639371/1.0131065430283415


100%|██████████| 150/150 [00:02<00:00, 57.51it/s]


Avg Loss / Perplexity at epoch 16: 0.023377641427020233/1.0236530403581996


100%|██████████| 150/150 [00:02<00:00, 53.53it/s]


Avg Loss / Perplexity at epoch 17: 0.12793596578141053/1.1364802267437477


100%|██████████| 150/150 [00:02<00:00, 65.78it/s]


Avg Loss / Perplexity at epoch 18: 0.03482951524977883/1.0354431664815584


100%|██████████| 150/150 [00:02<00:00, 64.90it/s]


Avg Loss / Perplexity at epoch 19: 0.013382690418511629/1.0134726394252813


100%|██████████| 150/150 [00:02<00:00, 65.53it/s]


Avg Loss / Perplexity at epoch 20: 0.00830096528865397/1.008335513830266


100%|██████████| 150/150 [00:02<00:00, 65.70it/s]


Avg Loss / Perplexity at epoch 21: 0.006221350265356401/1.006240743060517


100%|██████████| 150/150 [00:02<00:00, 50.29it/s]


Avg Loss / Perplexity at epoch 22: 0.005530728097073734/1.0055460508092762


100%|██████████| 150/150 [00:02<00:00, 61.07it/s]


Avg Loss / Perplexity at epoch 23: 0.005134079226603111/1.0051472811950033


100%|██████████| 150/150 [00:02<00:00, 65.17it/s]


Avg Loss / Perplexity at epoch 24: 0.004514882074048121/1.004525089510128


100%|██████████| 150/150 [00:02<00:00, 64.50it/s]


Avg Loss / Perplexity at epoch 25: 0.0038707724787915747/1.00387827359383


100%|██████████| 150/150 [00:02<00:00, 64.30it/s]


Avg Loss / Perplexity at epoch 26: 0.003488742501164476/1.0034948352465638


100%|██████████| 150/150 [00:02<00:00, 55.32it/s]


Avg Loss / Perplexity at epoch 27: 0.0031826921071236333/1.0031877622491214


100%|██████████| 150/150 [00:02<00:00, 55.87it/s]


Avg Loss / Perplexity at epoch 28: 0.00349169574600334/1.0034977988168827


100%|██████████| 150/150 [00:02<00:00, 64.88it/s]


Avg Loss / Perplexity at epoch 29: 0.002795260138809681/1.0027991705210946


100%|██████████| 150/150 [00:02<00:00, 65.63it/s]


Avg Loss / Perplexity at epoch 30: 0.0029605665934892994/1.0029649533988414


100%|██████████| 150/150 [00:02<00:00, 65.52it/s]


Avg Loss / Perplexity at epoch 31: 0.004531638770519445/1.0045419221731802


100%|██████████| 150/150 [00:02<00:00, 62.52it/s]


Avg Loss / Perplexity at epoch 32: 0.03284717819808672/1.033392602253945


100%|██████████| 150/150 [00:02<00:00, 51.66it/s]


Avg Loss / Perplexity at epoch 33: 0.06812263429164886/1.0704965800192168


100%|██████████| 150/150 [00:02<00:00, 65.38it/s]


Avg Loss / Perplexity at epoch 34: 0.026230774347980816/1.0265778289685517


100%|██████████| 357/357 [00:04<00:00, 84.10it/s]


Avg Loss / Perplexity at epoch 0: 5.268200474292958/194.06642052227778


100%|██████████| 357/357 [00:04<00:00, 79.93it/s]


Avg Loss / Perplexity at epoch 1: 1.8746310766337633/6.518413881491305


100%|██████████| 357/357 [00:04<00:00, 74.34it/s]


Avg Loss / Perplexity at epoch 2: 0.4513021941278495/1.5703557626870857


100%|██████████| 357/357 [00:04<00:00, 84.82it/s]


Avg Loss / Perplexity at epoch 3: 0.12116547602982748/1.1288116881900414


100%|██████████| 357/357 [00:04<00:00, 76.93it/s]


Avg Loss / Perplexity at epoch 4: 0.05204807088843414/1.0534263804382575


100%|██████████| 357/357 [00:04<00:00, 76.69it/s]


Avg Loss / Perplexity at epoch 5: 0.02818694447248089/1.0285879552815782


100%|██████████| 357/357 [00:04<00:00, 84.43it/s]


Avg Loss / Perplexity at epoch 6: 0.020153442387278675/1.0203578941665876


100%|██████████| 357/357 [00:04<00:00, 76.88it/s]


Avg Loss / Perplexity at epoch 7: 0.01469258344875706/1.0148010500199558


100%|██████████| 357/357 [00:04<00:00, 75.36it/s]


Avg Loss / Perplexity at epoch 8: 0.014019147653383118/1.0141178767301824


100%|██████████| 357/357 [00:04<00:00, 84.70it/s]


Avg Loss / Perplexity at epoch 9: 0.028905011410890174/1.0293268155293223


100%|██████████| 357/357 [00:04<00:00, 75.20it/s]


Avg Loss / Perplexity at epoch 10: 0.02461078486872279/1.0249161300168017


100%|██████████| 357/357 [00:04<00:00, 77.11it/s]


Avg Loss / Perplexity at epoch 11: 0.01516036729456163/1.0152758686054344


100%|██████████| 357/357 [00:04<00:00, 84.72it/s]


Avg Loss / Perplexity at epoch 12: 0.011260169272569297/1.0113238035987338


100%|██████████| 357/357 [00:04<00:00, 74.61it/s]


Avg Loss / Perplexity at epoch 13: 0.012462137405257099/1.0125401134188732


100%|██████████| 357/357 [00:05<00:00, 68.05it/s]


Avg Loss / Perplexity at epoch 14: 0.011942329621027257/1.0120139239566104


100%|██████████| 357/357 [00:04<00:00, 80.65it/s]


Avg Loss / Perplexity at epoch 15: 0.011851343064743723/1.0119218484836372


100%|██████████| 357/357 [00:04<00:00, 71.79it/s]


Avg Loss / Perplexity at epoch 16: 0.011297803283447311/1.011361864485947


100%|██████████| 357/357 [00:04<00:00, 82.83it/s]


Avg Loss / Perplexity at epoch 17: 0.01023340930103087/1.010285949703531


100%|██████████| 357/357 [00:04<00:00, 84.35it/s]


Avg Loss / Perplexity at epoch 18: 0.011197443700627703/1.0112603697242186


100%|██████████| 357/357 [00:04<00:00, 71.43it/s]


Avg Loss / Perplexity at epoch 19: 0.0134607379051468/1.013551741504381


100%|██████████| 357/357 [00:04<00:00, 83.38it/s]


Avg Loss / Perplexity at epoch 20: 0.011267644878221201/1.0113313638849346


100%|██████████| 357/357 [00:04<00:00, 78.77it/s]


Avg Loss / Perplexity at epoch 21: 0.008636184839247427/1.0086735842688042


100%|██████████| 357/357 [00:07<00:00, 44.92it/s]


Avg Loss / Perplexity at epoch 22: 0.007266973195279914/1.0072934417215014


100%|██████████| 357/357 [00:05<00:00, 60.82it/s]


Avg Loss / Perplexity at epoch 23: 0.007233400076553708/1.0072596243068717


100%|██████████| 357/357 [00:06<00:00, 56.65it/s]


Avg Loss / Perplexity at epoch 24: 0.008953640201396905/1.008993843938174


100%|██████████| 357/357 [00:04<00:00, 81.32it/s]


Avg Loss / Perplexity at epoch 25: 0.009921901007913346/1.009971286264458


100%|██████████| 357/357 [00:04<00:00, 83.19it/s]


Avg Loss / Perplexity at epoch 26: 0.008926175974706626/1.0089661330830453


100%|██████████| 357/357 [00:05<00:00, 62.16it/s]


Avg Loss / Perplexity at epoch 27: 0.007686338283568986/1.0077159540118816


100%|██████████| 357/357 [00:04<00:00, 83.48it/s]

Avg Loss / Perplexity at epoch 28: 0.00898644798425488/1.0090269473321325





In [24]:

reader = DatasetReader(val_bengali_question, bpemb_bn)
dl = DataLoader(reader, batch_size=batch_size, collate_fn=pad_input, shuffle=True, num_workers=2)

_, t_ppl, model = results[0]

losses = []
for batch in tqdm(dl):
  model.eval()
  optimizer.zero_grad()
  batch = tuple(t.to(device) for t in batch)
  inputs = batch[0]
  seq_lens = batch[1]
  (output, loss) = model(inputs, seq_lens)
  losses.append(loss.item())

avg_loss = sum(losses) / len(losses)
ppl = exp(avg_loss)
print()
print(ppl)

100%|██████████| 7/7 [00:00<00:00, 24.29it/s]


4.035701842961149





In [26]:
reader = DatasetReader(val_indonesian_question, bpemb_in)
dl = DataLoader(reader, batch_size=batch_size, collate_fn=pad_input, shuffle=True, num_workers=2)

_, t_ppl, model = results[1]

losses = []
for batch in tqdm(dl):
  model.eval()
  optimizer.zero_grad()
  batch = tuple(t.to(device) for t in batch)
  inputs = batch[0]
  seq_lens = batch[1]
  (output, loss) = model(inputs, seq_lens)
  losses.append(loss.item())

avg_loss = sum(losses) / len(losses)
ppl = exp(avg_loss)
print()
print(ppl)

100%|██████████| 38/38 [00:01<00:00, 32.42it/s]


4.51219946742593



