In [1]:
!pip install datasets
!pip install bpemb

Collecting datasets
  Downloading datasets-2.14.5-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.6/519.6 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0.0,>=0.14.0 (from datasets)
  Downloading huggingface_hub-0.17.3-py3-none-a

In [2]:
from torch.utils.data import DataLoader, Dataset
import nltk
from bpemb import BPEmb
import string
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from torch import nn
import torch

In [3]:
dataset = load_dataset("copenlu/answerable_tydiqa")
nltk.download('punkt')

train_set = dataset["train"]
validation_set = dataset["validation"]

def get_answer_start(row):
  return row["annotations"]["answer_start"][0]

def get_answer(row):
  return row["annotations"]["answer_text"][0]

def get_document(row):
  return row["document_plaintext"]

def get_question(row):
  return row["question_text"]

def oracle(answer, document):
  return answer != "" and answer in document

def get_language(dataset, lang):
  return [row for row in dataset if row['language'] == lang]

train_arabic = get_language(train_set, "arabic")
val_arabic = get_language(validation_set, "arabic")

train_bengali = get_language(train_set, "bengali")
val_bengali = get_language(validation_set, "bengali")

train_indonesian = get_language(train_set, "indonesian")
val_indonesian = get_language(validation_set, "indonesian")

Downloading readme:   0%|          | 0.00/4.94k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.47k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/71.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.49M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/116067 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13325 [00:00<?, ? examples/s]

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [4]:
vocab_size = 25000
encoding_dim = 100

bpemb_ar = BPEmb(lang='ar', dim=encoding_dim, vs=vocab_size)
bpemb_bn = BPEmb(lang='bn', dim=encoding_dim, vs=vocab_size)
bpemb_in = BPEmb(lang='id', dim=encoding_dim, vs=vocab_size)

downloading https://nlp.h-its.org/bpemb/ar/ar.wiki.bpe.vs25000.model


100%|██████████| 742254/742254 [00:00<00:00, 990256.26B/s]


downloading https://nlp.h-its.org/bpemb/ar/ar.wiki.bpe.vs25000.d100.w2v.bin.tar.gz


100%|██████████| 9491724/9491724 [00:01<00:00, 6259641.50B/s] 


downloading https://nlp.h-its.org/bpemb/bn/bn.wiki.bpe.vs25000.model


100%|██████████| 863227/863227 [00:01<00:00, 824026.12B/s] 


downloading https://nlp.h-its.org/bpemb/bn/bn.wiki.bpe.vs25000.d100.w2v.bin.tar.gz


100%|██████████| 9517491/9517491 [00:01<00:00, 6296496.18B/s] 


downloading https://nlp.h-its.org/bpemb/id/id.wiki.bpe.vs25000.model


100%|██████████| 650018/650018 [00:01<00:00, 624019.23B/s] 


downloading https://nlp.h-its.org/bpemb/id/id.wiki.bpe.vs25000.d100.w2v.bin.tar.gz


100%|██████████| 9465922/9465922 [00:01<00:00, 6243395.11B/s]


In [5]:
train_arabic_doc = [get_document(row) for row in train_arabic]
train_arabic_question = [get_question(row) for row in train_arabic]

train_bengali_doc = [get_document(row) for row in train_bengali]
train_bengali_question = [get_question(row) for row in train_bengali]

train_indonesian_doc = [get_document(row) for row in train_indonesian]
train_indonesian_question = [get_question(row) for row in train_indonesian]

val_arabic_doc = [get_document(row) for row in val_arabic]
val_arabic_question = [get_question(row) for row in val_arabic]

val_bengali_doc = [get_document(row) for row in val_bengali]
val_bengali_question = [get_question(row) for row in val_bengali]

val_indonesian_doc = [get_document(row) for row in val_indonesian]
val_indonesian_question = [get_question(row) for row in val_indonesian]

In [6]:
def text_to_ids(text, tokenizer):
    input_ids = tokenizer.encode_ids_with_eos(text)
    return input_ids, len(input_ids)

def pad_input(input):
    input_ids = [i[0] for i in input]
    seq_lens = [i[1] for i in input]

    max_length = max(seq_lens)

    input_ids = [(i + [25000] * (max_length - len(i))) for i in input_ids]

    # Make sure each sample is max_length long
    assert (all(len(i) == max_length for i in input_ids))
    return torch.tensor(input_ids), torch.tensor(seq_lens)

In [7]:
class DatasetReader(Dataset):
  def __init__(self, data, tokenizer):
    self.data = data
    self.tokenizer = tokenizer

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    row = self.data[idx]
    input_ids, seq_lens = text_to_ids(row, self.tokenizer)
    return input_ids, seq_lens


In [8]:
class LSTMNetwork(nn.Module):
    def __init__(
            self,
            pretrained_embeddings,
            vocab_size: int,
            num_layers,
            hidden_dim: int,
            dropout_rate: float = 0.1,
            bidirectional: bool = False
    ):
        super(LSTMNetwork, self).__init__()

        self.vocab_size = vocab_size

        self.embeddings = nn.Embedding.from_pretrained(pretrained_embeddings, padding_idx=pretrained_embeddings.shape[0] - 1)

        self.lstm = nn.LSTM(
                pretrained_embeddings.shape[1],
                hidden_dim,
                num_layers,
                batch_first=True,
                dropout=dropout_rate,
                bidirectional=bidirectional)

        self.dropout = nn.Dropout(dropout_rate)

        self.output_layer = nn.Linear(2*hidden_dim if bidirectional else hidden_dim, vocab_size)

        # Initialize the weights of the model
        self._init_weights()

    def _init_weights(self):
        all_params = list(self.lstm.named_parameters()) + list(self.output_layer.named_parameters())
        for n,p in all_params:
            if 'weight' in n:
                nn.init.xavier_normal_(p)
            elif 'bias' in n:
                nn.init.zeros_(p)

    def forward(self, inputs, input_lens):

        embeds = self.embeddings(inputs)

        # Pack padded: This is necessary for padded batches input to an RNN
        lstm_in = nn.utils.rnn.pack_padded_sequence(
            embeds,
            input_lens.cpu(),
            batch_first=True,
            enforce_sorted=False
        )

        lstm_out, _ = self.lstm(lstm_in)

        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)

        lin_in = self.dropout(lstm_out)
        output = self.output_layer(lin_in)

        targets = torch.flatten(inputs.clone())
        logits = output.view(-1, self.vocab_size)

        loss_fn = nn.CrossEntropyLoss()
        loss = loss_fn(logits, targets)

        return (logits, loss)


In [13]:
from math import exp

def train(
    model,
    dl,
    optimizer: torch.optim.Optimizer,
    n_epochs: int,
    device,
    patience: int = 10
):
  # Keep track of the loss and best accuracy
  best_ppl = float("inf")
  pcounter = 0

  # Iterate through epochs
  for ep in range(n_epochs):

    losses_epoch = []

    for batch in tqdm(dl):
      model.train()
      optimizer.zero_grad()

      batch = tuple(t.to(device) for t in batch)
      inputs = batch[0]
      seq_lens = batch[1]

      (_, loss) = model(inputs, seq_lens)

      losses_epoch.append(loss.item())

      loss.backward()

      #torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

      optimizer.step()

    avg_loss = sum(losses_epoch) / len(losses_epoch)
    ppl = exp(avg_loss)

    print(f'Avg Loss / Perplexity at epoch {ep}: {avg_loss}/{ppl}')

    # Keep track of the best model based on the accuracy
    if ppl < best_ppl:
      torch.save(model.state_dict(), 'best_model')
      best_ppl = ppl
      pcounter = 0
    else:
      pcounter += 1
      if pcounter == patience:
        break

    if best_ppl < 1.02:
      break

  model.load_state_dict(torch.load('best_model'))
  return model, best_ppl

In [14]:
from torch.optim import Adam

num_layers = 2
hidden_dim = 100
dropout_rate = 0.1
lr = 0.01
n_epochs = 100
#batch_size = 8

bidirectional = True

patience = 2

runs = [("Arabic document", train_arabic_doc, bpemb_ar, 4, val_arabic_doc),
        ("Arabic question", train_arabic_question, bpemb_ar, 32, val_arabic_question),
        ("Bengali document", train_bengali_doc, bpemb_bn, 4, val_bengali_doc),
        ("Bengali question", train_bengali_question, bpemb_bn, 32, val_bengali_question),
        ("Indonesian document", train_indonesian_doc, bpemb_in, 4, val_indonesian_doc),
        ("Indonesian question", train_indonesian_question, bpemb_in, 32, val_indonesian_question),
]

device = torch.device("cpu")
if torch.cuda.is_available():
  device = torch.device("cuda")

results = []

for (tag, ds, embs, batch_size, val_set) in runs:

  pretrained_embeddings = torch.Tensor(np.concatenate([embs.emb.vectors, np.zeros(shape=(1,100))], axis=0))
  vocabulary = embs.emb.index_to_key + ['[PAD]']

  model = LSTMNetwork(pretrained_embeddings, len(vocabulary), num_layers, hidden_dim, dropout_rate, bidirectional).to(device)

  reader = DatasetReader(ds, embs)

  dl = DataLoader(reader, batch_size=batch_size, collate_fn=pad_input, shuffle=True, num_workers=2)

  optimizer = Adam(model.parameters(), lr=lr)

  best_model, train_ppl = train(model, dl, optimizer, n_epochs, device, patience)

  results.append((tag, train_ppl, best_model, val_set))

100%|██████████| 7400/7400 [04:12<00:00, 29.33it/s]


Avg Loss / Perplexity at epoch 0: 0.6593144763545511/1.9334664430246353


100%|██████████| 7400/7400 [04:11<00:00, 29.38it/s]


Avg Loss / Perplexity at epoch 1: 0.031223017147087768/1.0317155684940846


100%|██████████| 7400/7400 [04:10<00:00, 29.50it/s]


Avg Loss / Perplexity at epoch 2: 0.022950443326805504/1.0232158311182071


100%|██████████| 7400/7400 [04:09<00:00, 29.65it/s]


Avg Loss / Perplexity at epoch 3: 0.02135931419702231/1.0215890571501913


100%|██████████| 7400/7400 [04:10<00:00, 29.54it/s]


Avg Loss / Perplexity at epoch 4: 0.018862567251164513/1.0190415893064644


100%|██████████| 925/925 [00:11<00:00, 81.19it/s]


Avg Loss / Perplexity at epoch 0: 3.0467719804918443/21.047293572650155


100%|██████████| 925/925 [00:12<00:00, 76.80it/s]


Avg Loss / Perplexity at epoch 1: 0.22473098478204495/1.2519858676341435


100%|██████████| 925/925 [00:11<00:00, 79.35it/s]


Avg Loss / Perplexity at epoch 2: 0.043221522678394575/1.0441691763852643


100%|██████████| 925/925 [00:11<00:00, 80.55it/s]


Avg Loss / Perplexity at epoch 3: 0.029633323295494996/1.0300767595456346


100%|██████████| 925/925 [00:11<00:00, 80.26it/s]


Avg Loss / Perplexity at epoch 4: 0.029663597824102318/1.030107945106021


100%|██████████| 925/925 [00:11<00:00, 80.33it/s]


Avg Loss / Perplexity at epoch 5: 0.025268548682734773/1.025590504524784


100%|██████████| 925/925 [00:11<00:00, 80.49it/s]


Avg Loss / Perplexity at epoch 6: 0.024196985656445895/1.0244921082655347


100%|██████████| 925/925 [00:11<00:00, 80.39it/s]


Avg Loss / Perplexity at epoch 7: 0.020502395139302354/1.0207140129925147


100%|██████████| 925/925 [00:11<00:00, 79.99it/s]


Avg Loss / Perplexity at epoch 8: 0.02262828030950717/1.022886241912284


100%|██████████| 925/925 [00:11<00:00, 80.46it/s]


Avg Loss / Perplexity at epoch 9: 0.01921955629302239/1.019405440928701


100%|██████████| 1195/1195 [00:40<00:00, 29.82it/s]


Avg Loss / Perplexity at epoch 0: 3.0086901445269087/20.260843759580446


100%|██████████| 1195/1195 [00:39<00:00, 30.50it/s]


Avg Loss / Perplexity at epoch 1: 0.2354856172958428/1.2655231794808202


100%|██████████| 1195/1195 [00:39<00:00, 29.95it/s]


Avg Loss / Perplexity at epoch 2: 0.07401893644059802/1.0768271965773943


100%|██████████| 1195/1195 [00:39<00:00, 30.22it/s]


Avg Loss / Perplexity at epoch 3: 0.045263920660485284/1.0463039647056676


100%|██████████| 1195/1195 [00:39<00:00, 30.10it/s]


Avg Loss / Perplexity at epoch 4: 0.03554546828712941/1.036184760603209


100%|██████████| 1195/1195 [00:39<00:00, 30.21it/s]


Avg Loss / Perplexity at epoch 5: 0.028679167841868983/1.0290943749362085


100%|██████████| 1195/1195 [00:39<00:00, 30.49it/s]


Avg Loss / Perplexity at epoch 6: 0.024850891132850713/1.0251622483459184


100%|██████████| 1195/1195 [00:39<00:00, 29.89it/s]


Avg Loss / Perplexity at epoch 7: 0.023342295554680866/1.0236168590879484


100%|██████████| 1195/1195 [00:39<00:00, 30.21it/s]


Avg Loss / Perplexity at epoch 8: 0.023187149561908503/1.0234580613528614


100%|██████████| 1195/1195 [00:39<00:00, 30.38it/s]


Avg Loss / Perplexity at epoch 9: 0.02089466873198849/1.0211144906887735


100%|██████████| 1195/1195 [00:39<00:00, 30.03it/s]


Avg Loss / Perplexity at epoch 10: 0.020651221916052245/1.020865933873719


100%|██████████| 1195/1195 [00:39<00:00, 30.15it/s]


Avg Loss / Perplexity at epoch 11: 0.019685054981080564/1.0198800832881798


100%|██████████| 150/150 [00:02<00:00, 59.61it/s]


Avg Loss / Perplexity at epoch 0: 6.643865003585815/768.0578098721537


100%|██████████| 150/150 [00:02<00:00, 66.79it/s]


Avg Loss / Perplexity at epoch 1: 3.9484508721033733/51.85497459425442


100%|██████████| 150/150 [00:02<00:00, 66.76it/s]


Avg Loss / Perplexity at epoch 2: 2.2776615460713705/9.753844803229041


100%|██████████| 150/150 [00:02<00:00, 66.79it/s]


Avg Loss / Perplexity at epoch 3: 1.1785211598873138/3.2495650606294406


100%|██████████| 150/150 [00:02<00:00, 62.71it/s]


Avg Loss / Perplexity at epoch 4: 0.52049398680528/1.682858754424437


100%|██████████| 150/150 [00:02<00:00, 58.05it/s]


Avg Loss / Perplexity at epoch 5: 0.21822793612877528/1.2438705583487124


100%|██████████| 150/150 [00:02<00:00, 67.54it/s]


Avg Loss / Perplexity at epoch 6: 0.11412502740820249/1.120892258377646


100%|██████████| 150/150 [00:02<00:00, 67.29it/s]


Avg Loss / Perplexity at epoch 7: 0.07130158871412277/1.0739050546788762


100%|██████████| 150/150 [00:02<00:00, 67.95it/s]


Avg Loss / Perplexity at epoch 8: 0.04930769401292006/1.0505435469744209


100%|██████████| 150/150 [00:02<00:00, 66.01it/s]


Avg Loss / Perplexity at epoch 9: 0.037343331364293894/1.0380493545654919


100%|██████████| 150/150 [00:02<00:00, 56.99it/s]


Avg Loss / Perplexity at epoch 10: 0.028620222633083663/1.0290337165411965


100%|██████████| 150/150 [00:02<00:00, 66.81it/s]


Avg Loss / Perplexity at epoch 11: 0.02327784028525154/1.0235508837137566


100%|██████████| 150/150 [00:02<00:00, 67.78it/s]


Avg Loss / Perplexity at epoch 12: 0.018680062703788282/1.0188556265524906


100%|██████████| 2849/2849 [01:21<00:00, 35.08it/s]


Avg Loss / Perplexity at epoch 0: 1.6224353987241094/5.065411603807285


100%|██████████| 2849/2849 [01:20<00:00, 35.35it/s]


Avg Loss / Perplexity at epoch 1: 0.1387501439790965/1.1488370202942442


100%|██████████| 2849/2849 [01:20<00:00, 35.29it/s]


Avg Loss / Perplexity at epoch 2: 0.06388797077015905/1.065972971940905


100%|██████████| 2849/2849 [01:22<00:00, 34.70it/s]


Avg Loss / Perplexity at epoch 3: 0.03717300641020203/1.0378725639131876


100%|██████████| 2849/2849 [01:20<00:00, 35.45it/s]


Avg Loss / Perplexity at epoch 4: 0.029910508805524528/1.0303623214725848


100%|██████████| 2849/2849 [01:22<00:00, 34.33it/s]


Avg Loss / Perplexity at epoch 5: 0.02378885342838449/1.0240740653329565


100%|██████████| 2849/2849 [01:20<00:00, 35.51it/s]


Avg Loss / Perplexity at epoch 6: 0.0219201924866445/1.0221622049913275


100%|██████████| 2849/2849 [01:20<00:00, 35.32it/s]


Avg Loss / Perplexity at epoch 7: 0.020569437789778044/1.020782446659285


100%|██████████| 2849/2849 [01:21<00:00, 35.03it/s]


Avg Loss / Perplexity at epoch 8: 0.01910096193998046/1.0192845523684386


100%|██████████| 357/357 [00:04<00:00, 87.10it/s]


Avg Loss / Perplexity at epoch 0: 5.46522851734936/236.32985441064181


100%|██████████| 357/357 [00:04<00:00, 77.07it/s]


Avg Loss / Perplexity at epoch 1: 1.9747678973093754/7.204947172109769


100%|██████████| 357/357 [00:04<00:00, 87.56it/s]


Avg Loss / Perplexity at epoch 2: 0.5311656764575413/1.7009138687182634


100%|██████████| 357/357 [00:04<00:00, 87.54it/s]


Avg Loss / Perplexity at epoch 3: 0.1511793348844312/1.1632052426557775


100%|██████████| 357/357 [00:04<00:00, 78.02it/s]


Avg Loss / Perplexity at epoch 4: 0.059599857782574595/1.0614117459110834


100%|██████████| 357/357 [00:04<00:00, 87.29it/s]


Avg Loss / Perplexity at epoch 5: 0.031833193178812995/1.0323452887060895


100%|██████████| 357/357 [00:04<00:00, 87.03it/s]


Avg Loss / Perplexity at epoch 6: 0.020014086511267834/1.0202157112056471


100%|██████████| 357/357 [00:04<00:00, 79.04it/s]


Avg Loss / Perplexity at epoch 7: 0.01784770414024508/1.0180079261909851


In [16]:
for (tag, train_ppl, model, val_set) in results:
  emb = bpemb_ar if "Arabic" in tag else bpemb_bn if "Bengali" in tag else bpemb_in
  batch_size = 32 if "question" in tag else 4

  reader = DatasetReader(val_set, emb)
  dl = DataLoader(reader, batch_size=batch_size, collate_fn=pad_input, shuffle=True, num_workers=2)

  losses = []
  for batch in tqdm(dl):
    model.eval()
    optimizer.zero_grad()
    batch = tuple(t.to(device) for t in batch)
    inputs = batch[0]
    seq_lens = batch[1]
    (output, loss) = model(inputs, seq_lens)
    losses.append(loss.item())

  avg_loss = sum(losses) / len(losses)
  ppl = exp(avg_loss)
  print(f"""
  {tag}:

  Training perplexity:    {train_ppl}
  Validation perplexity:  {ppl}

  """)

100%|██████████| 476/476 [00:08<00:00, 53.35it/s]



  Arabic document:

  Training perplexity:    1.0190415893064644
  Validation perplexity:  1.0053563051098684
   
  


100%|██████████| 60/60 [00:00<00:00, 86.35it/s] 



  Arabic question:

  Training perplexity:    1.019405440928701
  Validation perplexity:  1.6223864605254148
   
  


100%|██████████| 56/56 [00:01<00:00, 54.77it/s]



  Bengali document:

  Training perplexity:    1.0198800832881798
  Validation perplexity:  1.124875286964304
   
  


100%|██████████| 7/7 [00:00<00:00, 33.84it/s]



  Bengali question:

  Training perplexity:    1.0188556265524906
  Validation perplexity:  3.911623281653864
   
  


100%|██████████| 298/298 [00:04<00:00, 72.35it/s]



  Indonesian document:

  Training perplexity:    1.0192845523684386
  Validation perplexity:  1.0629749875333891
   
  


100%|██████████| 38/38 [00:00<00:00, 97.30it/s] 


  Indonesian question:

  Training perplexity:    1.0180079261909851
  Validation perplexity:  3.869788321353709
   
  



