In [1]:
pip install datasets transformers


Collecting pyarrow>=21.0.0 (from datasets)
  Downloading pyarrow-22.0.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (3.2 kB)
Collecting huggingface-hub>=0.24.0 (from datasets)
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Downloading huggingface_hub-0.36.0-py3-none-any.whl (566 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m566.1/566.1 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pyarrow-22.0.0-cp311-cp311-manylinux_2_28_x86_64.whl (47.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m40.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyarrow, huggingface-hub
  Attempting uninstall: pyarrow
    Found existing installation: pyarrow 19.0.1
    Uninstalling pyarrow-19.0.1:
      Successfully uninstalled pyarrow-19.0.1
  Attempting uninstall: huggingface-hub
    Found existing installation: huggingface-hub 1.0.0rc2
    Uninstalling huggingface-hub-1

In [2]:
from datasets import load_dataset

ds = load_dataset("lucadiliello/newsqa")
print({split: len(ds[split]) for split in ds.keys()})

for split in ds.keys():
    print("Columns:", ds[split].column_names)
    
sample = ds["train"][0]
print("Example data: ")
for k, v in sample.items():
    if isinstance(v, str):
        print(f"{k}: {v[:200]}")
    else:
        print(f"{k}: {v}")

README.md:   0%|          | 0.00/681 [00:00<?, ?B/s]

data/train-00000-of-00001-ec54fbe500fc3b(…):   0%|          | 0.00/29.7M [00:00<?, ?B/s]

data/validation-00000-of-00001-3cf888b12(…):   0%|          | 0.00/1.63M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/74160 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4212 [00:00<?, ? examples/s]

{'train': 74160, 'validation': 4212}
Columns: ['context', 'question', 'answers', 'key', 'labels']
Columns: ['context', 'question', 'answers', 'key', 'labels']
Example data: 
context: NEW DELHI, India (CNN) -- A high court in northern India on Friday acquitted a wealthy businessman facing the death sentence for the killing of a teen in a case dubbed "the house of horrors."



Monin
question: What was the amount of children murdered?
answers: ['19']
key: da0e6b66e04d439fa1ba23c32de07e50
labels: [{'end': [295], 'start': [294]}]


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BiLSTMQA(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=128, layers=2, drop=0.2):
        super(BiLSTMQA, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=layers, bidirectional=True, batch_first=True, dropout=drop)
        self.drop = nn.Dropout(drop)
        self.start_linear = nn.Linear(hidden_dim*2, 1)
        self.end_linear = nn.Linear(hidden_dim*2, 1)

    def forward(self, tokens, mask):
        x = self.embed(tokens)
        x, _ = self.lstm(x)
        x = self.drop(x)
        start = self.start_linear(x).squeeze(-1)
        end = self.end_linear(x).squeeze(-1)
        start = start.masked_fill(mask==0, -1e9)
        end = end.masked_fill(mask==0, -1e9)
        return start, end


In [4]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset

data = load_dataset("lucadiliello/newsqa")
tok = AutoTokenizer.from_pretrained("bert-base-uncased")

def prep_example(item, max_len=128):
    enc = tok(
        item['question'],
        item['context'],
        truncation=True,
        padding='max_length',
        max_length=max_len,
        return_offsets_mapping=True
    )
    offs = enc.pop("offset_mapping")
    schar = item['labels'][0]['start'][0]
    echar = item['labels'][0]['end'][0]
    enc["stpos"] = 0
    enc["enpos"] = 0
    for i, (s, e) in enumerate(offs):
        if s <= schar < e:
            enc["stpos"] = i
        if s < echar <= e:
            enc["enpos"] = i
    return enc

train_subset = data['train'].select(range(2000))
train_tok = [prep_example(x) for x in train_subset]

TrainD = QAData(train_tok)
TL = DataLoader(TrainD, batch_size=8, shuffle=True)

vsize = tok.vocab_size
model = BiLSTMQA(vsize)
CEL = torch.nn.CrossEntropyLoss()
ADAM = optim.Adam(model.parameters(), lr=1e-3)

for ep in range(5):
    total_loss = 0
    for batch in TL:
        ids = batch['ids']
        mask = batch['mask']
        stpos = batch['stpos']
        enpos = batch['enpos']

        ADAM.zero_grad()
        st_logits, en_logits = model(ids, mask)
        loss_st = CEL(st_logits, stpos)
        loss_en = CEL(en_logits, enpos)
        loss = (loss_st + loss_en) / 2
        loss.backward()
        ADAM.step()
        total_loss += loss.item()

    print(f"Epoch {ep+1}, Loss: {total_loss/len(TL):.4f}")


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

NameError: name 'QAData' is not defined

In [None]:
from tqdm import tqdm

def logitsans(ids, st_logits, en_logits):
    st_idx = st_logits.argmax(dim=1)
    en_idx = en_logits.argmax(dim=1)
    answers = []
    for i in range(ids.size(0)):
        if en_idx[i] < st_idx[i]:
            en_idx[i] = st_idx[i]
        ans_tokens = ids[i, st_idx[i]:en_idx[i]+1]
        ans_text = tok.decode(ans_tokens, skip_special_tokens=True)
        answers.append(ans_text)
    return answers

def f1(pred, truth):
    pred_tokens = pred.split()
    truth_tokens = truth.split()
    common = set(pred_tokens) & set(truth_tokens)
    if len(common) == 0:
        return 0.0
    prec = len(common) / len(pred_tokens)
    rec = len(common) / len(truth_tokens)
    return 2 * prec * rec / (prec + rec)

def em(pred, truth):
    return int(pred.strip() == truth.strip())

val_subset = data['validation'].select(range(500))
val_tok = [prep_example(x) for x in val_subset]
ValD = QAData(val_tok)
VL = DataLoader(ValD, batch_size=2)

model.eval()
total_f1 = 0
total_em = 0
count = 0

with torch.no_grad():
    for batch_idx, batch in enumerate(VL):
        ids = batch['ids']
        mask = batch['mask']
        st_logits, en_logits = model(ids, mask)
        preds = logitsans(ids, st_logits, en_logits)
        for i, item in enumerate(val_tok[batch_idx*2 : batch_idx*2 + ids.size(0)]):
            true_ans = item['input_ids'][item['stpos']:item['enpos']+1]
            true_text = tok.decode(true_ans, skip_special_tokens=True)
            total_f1 += f1(preds[i], true_text)
            total_em += em(preds[i], true_text)
            count += 1

print(f"Validation EM: {total_em/count:.4f}, F1: {total_f1/count:.4f}")
