# MLQA Extract `English` Keywords (Query/Context)

In [None]:
import os
import re
import json
from tqdm.auto import tqdm
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification

from torch.utils.data import DataLoader, Dataset
import torch
from torch.nn.functional import softmax

2024-02-21 23:06:13.986858: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-21 23:06:13.986980: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-21 23:06:14.143813: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
file_path = '/kaggle/input/m-l-q-a/Context_EN.json'

# Decode explicitly as UTF-8 so loading does not depend on the platform's default
# locale encoding (the dataset text may contain non-ASCII characters).
with open(file_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

print(f"Total num: {len(data)}")
data[:1]

Total num: 9916


[{'id': 0,
  'title': 'Area 51',
  'context': 'In 1994, five unnamed civilian contractors and the widows of contractors Walter Kasza and Robert Frost sued the USAF and the United States Environmental Protection Agency. Their suit, in which they were represented by George Washington University law professor Jonathan Turley, alleged they had been present when large quantities of unknown chemicals had been burned in open pits and trenches at Groom. Biopsies taken from the complainants were analyzed by Rutgers University biochemists, who found high levels of dioxin, dibenzofuran, and trichloroethylene in their body fat. The complainants alleged they had sustained skin, liver, and respiratory injuries due to their work at Groom, and that this had contributed to the deaths of Frost and Kasza. The suit sought compensation for the injuries they had sustained, claiming the USAF had illegally handled toxic materials, and that the EPA had failed in its duty to enforce the Resource Conservation an

# BART method (Query/Context)

**Set up BART**

In [None]:
class TextDataset(Dataset):
    """Map-style dataset over pre-tokenized encodings.

    ``encodings`` maps field names (e.g. "input_ids", "attention_mask") to
    per-example sequences indexable by example position: either tensors (as
    returned by a tokenizer called with ``return_tensors="pt"``) or plain lists.
    Item ``idx`` is a dict with the same keys, each holding that example's row
    as an independent tensor.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        item = {}
        for key, val in self.encodings.items():
            row = val[idx]
            # torch.tensor(<tensor>) copy-constructs and emits a UserWarning on every
            # fetch; clone().detach() is the recommended equivalent copy for tensors.
            item[key] = row.clone().detach() if torch.is_tensor(row) else torch.tensor(row)
        return item

    def __len__(self):
        # Subscript (not attribute) access so plain dicts work as well as BatchEncoding.
        return len(self.encodings["input_ids"])

# Keyword-extraction checkpoint shared by the tokenizer and the seq2seq model.
checkpoint = "Andyrasika/bart_tech_keywords"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

# Pick the source field to extract keywords from:
#   * "Question" -> [item['question'] for item in data]
#   * "Context"  -> [item['context'] for item in data]
texts = [entry['context'] for entry in data]

# Pad to the longest sequence, truncate at 512 tokens, return PyTorch tensors.
encodings = tokenizer(
    texts,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
)

# Batch the encodings 16 examples at a time (DataLoader keeps the original order by default).
dataset = TextDataset(encodings)
dataloader = DataLoader(dataset, batch_size=16)

tokenizer_config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/279 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.71k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/292 [00:00<?, ?B/s]

In [None]:
# Use the GPU when present, otherwise fall back to CPU so the cell still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # disable dropout explicitly; generate() does not switch modes itself
text_outputs = []

for batch in tqdm(dataloader, desc="Generating text"):
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model.generate(**batch)
    # One decoded keyword string per input example, in dataloader order.
    text_outputs.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))

Generating text:   0%|          | 0/620 [00:00<?, ?it/s]

  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


**Save results**

In [None]:
# Generation must have produced exactly one output string per input text.
assert len(text_outputs) == len(texts)
len(texts), len(text_outputs)

(9916, 9916)

In [None]:
# Trim surrounding whitespace from every generated keyword string.
strip_key_words = list(map(str.strip, text_outputs))

# Identifier source:
#   * "Question" -> item['id']
#   * "Context"  -> item['context_id']
# NOTE(review): the sample record printed after loading Context_EN.json exposes 'id', which is what is read here.
ids = [record['id'] for record in data]
titles = [record['title'] for record in data]

# Each example needs exactly one id, one title and one keyword string.
assert len(ids) == len(strip_key_words)
assert len(ids) == len(titles)

print(f"Total num: {len(strip_key_words)}")

Total num: 9916


In [None]:
# One JSON record per example. Keywords come from splitting the generated string
# on commas, trimming whitespace and dropping empty entries.
# `doc_id` (not `id`) avoids shadowing the builtin.
documents = [
    {
        "id": doc_id,
        "title": title,
        "keywords": [k.strip() for k in kw.split(',') if k.strip()],
    }
    for doc_id, title, kw in zip(ids, titles, strip_key_words)
]

json_data = json.dumps(documents, indent=4)


In [None]:
# Persist the Context keyword records to Kaggle's writable working directory.
output_path = '/kaggle/working/Context_BART_Keywords_MLQA.json'
with open(output_path, 'w') as out_file:
    out_file.write(json_data)

# Local (non-Kaggle) alternative:
# with open('./data/BART_Keywords.json', 'w') as f:
#     f.write(json_data)

## Artificially assign probabilities (Query keywords)

In [None]:
# Keyword records, presumably written by a Query run of this notebook.
file_path_ = '/kaggle/working/Query_BART_Keywords_MLQA.json'

with open(file_path_, 'r') as handle:
    Query_BART_Keywords = json.loads(handle.read())

print(f"Total num: {len(Query_BART_Keywords)}")
Query_BART_Keywords[:5]

Total num: 11590


[{'id': 'a4968ca8a18de16aa3859be760e43dbd3af3fce9',
  'keywords': ['Biopsies', 'Analysis']},
 {'id': 'f251ea56c4f1aa1df270137f7e6d89c0cc1b6ef4',
  'keywords': ['Robert Frost', 'Walter Kasza', 'Lawsuit']},
 {'id': '04ecd5555635bc05fd2f379d1b9027edd663cebf',
  'keywords': ['Lawsuit', 'Groom']},
 {'id': 'd066a75dbe8cd3e2b57c415a8eb54a08dc7e72a7',
  'keywords': ['Complaints', 'Allegations']},
 {'id': 'c5f545baccd8ea8adb83f75756f4832340600bd9',
  'keywords': ['Aerospace Magazine']}]

In [None]:
# Heuristic confidence by position: the k-th keyword (0-based) gets 0.9 - 0.1*k,
# floored at 0.1. Working in integer tenths avoids float noise in the saved JSON
# (0.9 - 0.1*6 evaluates to 0.29999999999999993, 0.9 - 0.1*7 to 0.19999999999999996).
for item in Query_BART_Keywords:
    item["prob"] = [max(9 - rank, 1) / 10 for rank in range(len(item["keywords"]))]

Query_BART_Keywords[:5]

[{'id': 'a4968ca8a18de16aa3859be760e43dbd3af3fce9',
  'keywords': ['Biopsies', 'Analysis'],
  'prob': [0.9, 0.8]},
 {'id': 'f251ea56c4f1aa1df270137f7e6d89c0cc1b6ef4',
  'keywords': ['Robert Frost', 'Walter Kasza', 'Lawsuit'],
  'prob': [0.9, 0.8, 0.7]},
 {'id': '04ecd5555635bc05fd2f379d1b9027edd663cebf',
  'keywords': ['Lawsuit', 'Groom'],
  'prob': [0.9, 0.8]},
 {'id': 'd066a75dbe8cd3e2b57c415a8eb54a08dc7e72a7',
  'keywords': ['Complaints', 'Allegations'],
  'prob': [0.9, 0.8]},
 {'id': 'c5f545baccd8ea8adb83f75756f4832340600bd9',
  'keywords': ['Aerospace Magazine'],
  'prob': [0.9]}]

In [None]:
# Serialize first so a serialization error never truncates an existing output file.
json_data = json.dumps(Query_BART_Keywords, indent=4)

target_path = '/kaggle/working/Query_BART_Keywords_artificially_prob_MLQA.json'
with open(target_path, 'w') as out_file:
    out_file.write(json_data)

## BART-generated keyword probabilities

**Set model to evaluate mode**

In [None]:
model.eval()
model.to("cuda")


def _greedy_decode_with_probs(seq2seq, input_ids, attention_mask, start_token_id, eos_token_id, max_len=30):
    """Greedily decode one example, recording the softmax probability of each chosen token.

    Args:
        seq2seq: encoder-decoder model exposing ``get_encoder()`` and accepting
            ``encoder_outputs`` in its forward call.
        input_ids, attention_mask: one example, with a leading batch dimension of 1.
        start_token_id: id placed at decoder position 0.
        eos_token_id: decoding stops when this id is the argmax; it is not returned.
        max_len: maximum decoder length, counting the start token.

    Returns:
        (token_ids, token_probs): generated ids (start/EOS excluded) and the
        probability the model assigned to each of them.
    """
    token_ids, token_probs = [], []
    with torch.no_grad():
        # The source sequence is fixed, so encode it once instead of on every step.
        encoder_outputs = seq2seq.get_encoder()(input_ids=input_ids, attention_mask=attention_mask)
        decoded_ids = torch.full((1, 1), start_token_id, dtype=torch.long, device=input_ids.device)
        while decoded_ids.size(1) < max_len:
            outputs = seq2seq(
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
                decoder_input_ids=decoded_ids,
                use_cache=False,
            )
            probs = softmax(outputs.logits[:, -1, :], dim=-1)
            next_token_id = torch.argmax(probs, dim=-1, keepdim=True)  # shape (1, 1)
            token_id = next_token_id.item()
            if token_id == eos_token_id:
                break
            token_ids.append(token_id)
            token_probs.append(probs[0, token_id].item())
            decoded_ids = torch.cat([decoded_ids, next_token_id], dim=-1)
    return token_ids, token_probs


def _keyword_probabilities(tok, token_ids, token_probs):
    """Split generated tokens into comma-separated keywords, each scored by mean token probability.

    Each token is attributed to the keyword whose character span contains the token's
    first non-comma, non-whitespace character in the decoded text. The following are
    attributed to no keyword:
    - special tokens;
    - tokens that decode to only commas/whitespace (the separators).
    Keywords with no text or no attributed tokens are dropped.

    Returns:
        (keywords, scores): equal-length lists of stripped, comma-free strings and floats.
    """
    special_ids = set(tok.all_special_ids)
    kept = [(t, p) for t, p in zip(token_ids, token_probs) if t not in special_ids]
    kept_ids = [t for t, _ in kept]

    def decode(ids):
        return tok.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    text = decode(kept_ids)
    # Character offset where each token starts, from decoding successive prefixes.
    # Assumes prefix-stable decoding (holds for ASCII output; may drift slightly on
    # multi-byte characters split across byte-level BPE tokens).
    offsets = [len(decode(kept_ids[:i])) for i in range(len(kept_ids) + 1)]

    segment_probs = [[] for _ in range(text.count(",") + 1)]
    for i, (_, prob) in enumerate(kept):
        piece = text[offsets[i]:offsets[i + 1]]
        content_at = next(
            (offsets[i] + j for j, ch in enumerate(piece) if ch != "," and not ch.isspace()),
            None,
        )
        if content_at is None:
            continue  # separator or whitespace-only token
        segment_probs[text.count(",", 0, content_at)].append(prob)

    keywords, scores = [], []
    for segment, probs in zip(text.split(","), segment_probs):
        keyword = segment.strip()
        if keyword and probs:
            keywords.append(keyword)
            scores.append(sum(probs) / len(probs))
    return keywords, scores


# Seed the decoder the way `model.generate` does: with decoder_start_token_id
# (falling back to BOS if the config does not define one).
decoder_start_id = model.config.decoder_start_token_id
if decoder_start_id is None:
    decoder_start_id = tokenizer.bos_token_id

all_input_texts = []
all_generated_texts = []
all_probabilities = []

for batch in tqdm(dataloader):
    input_ids = batch['input_ids'].to(model.device)
    attention_mask = batch['attention_mask'].to(model.device)

    for idx in range(input_ids.size(0)):  # one example at a time within the batch
        all_input_texts.append(tokenizer.decode(input_ids[idx].tolist(), skip_special_tokens=True))

        token_ids, token_probs = _greedy_decode_with_probs(
            model,
            input_ids[idx:idx + 1],
            attention_mask[idx:idx + 1],
            decoder_start_id,
            tokenizer.eos_token_id,
        )
        keywords, scores = _keyword_probabilities(tokenizer, token_ids, token_probs)

        # Rebuild the text from the scored keywords so that splitting it on ',' later
        # yields exactly one keyword per probability.
        all_generated_texts.append(", ".join(keywords))
        all_probabilities.append(",".join(f"{score:.6f}" for score in scores))

  0%|          | 0/725 [00:00<?, ?it/s]

  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


In [None]:
# Whitespace-trimmed keyword strings from the probability-tracking decoder.
strip_key_words = list(map(str.strip, all_generated_texts))

# Identifier source:
#   * "Question" -> [item['id'] for item in data]
#   * "Context"  -> [item['context_id'] for item in data]
ids = [record['id'] for record in data]

# Display both counts; they should match.
len(ids), len(strip_key_words)

In [None]:
# One record per example that produced both keywords and probabilities. Keywords
# are comma-split, trimmed and de-emptied, as in the probability-free export.
# (A Context variant of this cell used "context_id" as the record key.)
documents = []
mismatched = 0
for doc_id, kw, prob in zip(ids, strip_key_words, all_probabilities):
    if kw.strip() == '' or prob.strip() == '':
        continue  # nothing generated/scored for this example
    keywords = [k.strip() for k in kw.split(',') if k.strip()]
    probs = [float(p) for p in prob.split(',') if p.strip()]
    # Downstream code pairs keywords[i] with prob[i]; surface any misalignment
    # instead of saving it silently.
    if len(keywords) != len(probs):
        mismatched += 1
    documents.append({"id": doc_id, "keywords": keywords, "prob": probs})

if mismatched:
    print(f"Warning: {mismatched} records have differing keyword and probability counts")

json_data = json.dumps(documents, indent=4)


In [None]:
# NOTE(review): the file name says "Query", but which texts were encoded depends on
# the `texts` toggle in the setup cell (item['question'] vs item['context']) — confirm before reuse.
destination = '/kaggle/working/Query_BART_Keywords_with_prob_MLQA.json'
with open(destination, 'w') as out_file:
    out_file.write(json_data)

In [None]:
# Re-load the file just written, to sanity-check its contents.
file_path_ = '/kaggle/working/Query_BART_Keywords_with_prob_MLQA.json'

with open(file_path_, 'r') as saved:
    data_ = json.loads(saved.read())

print(f"Total num: {len(data_)}")
data_[:5]