# Extract Keywords (Query)

In [2]:
import os
import re
import json
from tqdm.auto import tqdm
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification

from torch.utils.data import DataLoader, Dataset
import torch
from torch.nn.functional import softmax

In [2]:
file_path = './data/context_qa_en.json'

with open(file_path, 'r') as f:
    data = json.load(f)

print(f"Total num: {len(data)}")
data[:5]

Total num: 14656


[{'question': 'In what year did WikiLeaks first display information on the Internet?',
  'answer': '2006',
  'context': 'WikiLeaks () is an international non-profit organisation that publishes secret information, news leaks, and classified media provided by anonymous sources. Its website, initiated in 2006 in Iceland by the organisation Sunshine Press, claims a database of 10 million documents in 10 years since its launch. Julian Assange, an Australian Internet activist, is generally described as its founder and director. Kristinn Hrafnsson is its editor-in-chief.',
  'context_id': 0},
 {'question': 'Which country was defeated in the Second World War?',
  'answer': 'Germany',
  'context': 'The war in Europe concluded with an invasion of Germany by the Western Allies and the Soviet Union, culminating in the capture of Berlin by Soviet troops, the suicide of Adolf Hitler and the German unconditional surrender on 8 May 1945. Following the Potsdam Declaration by the Allies on 26 July 1945 

# BART method

**Set up BART**

In [None]:
class TextDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

tokenizer = AutoTokenizer.from_pretrained("Andyrasika/bart_tech_keywords")
model = AutoModelForSeq2SeqLM.from_pretrained("Andyrasika/bart_tech_keywords")

'''
    * For "Question", set `texts` with `[item['question'] for item in data]`
    * For "Context", set `texts` with `[item['context'] for item in data]`
'''
texts = [item['question'] for item in data]

encodings = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")

# Use dataloader
dataset = TextDataset(encodings)
dataloader = DataLoader(dataset, batch_size=16)

tokenizer_config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/279 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.71k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/292 [00:00<?, ?B/s]

In [3]:
model.to("cuda")
text_outputs = []

for batch in tqdm(dataloader, desc="Generating text"):
    batch = {k: v.to("cuda") for k, v in batch.items()} 
    with torch.no_grad():
        outputs = model.generate(**batch)
        decoded_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
        text_outputs.extend(decoded_texts)

**Save results**

In [None]:
assert len(texts) == len(text_outputs)
len(texts), len(text_outputs)

In [74]:
strip_key_words = [line.strip() for line in text_outputs]

'''
    * For "Question", set `ids` with `item['id']`
    * For "Context", set `ids` with `item['context_id']`
'''
ids = [item['id'] for item in data]

assert len(ids) == len(strip_key_words)

print(f"Total num: {len(strip_key_words)}")

In [None]:
documents = [
    {   
        "id": id, 
        "keywords": [k.strip() for k in kw.split(',')]
    } 
    for id, kw in zip(ids, strip_key_words)
]

json_data = json.dumps(documents, indent=4)

In [None]:
with open('/kaggle/working/Query_BART_Keywords.json', 'w') as f:
    f.write(json_data)

# with open('./data/BART_Keywords.json', 'w') as f:
#     f.write(json_data)

## Artificial add prob

In [4]:
file_path_ = '/kaggle/working/Query_BART_Keywords.json'

with open(file_path_, 'r') as f:
    Query_BART_Keywords = json.load(f)

print(f"Total num: {len(Query_BART_Keywords)}")
Query_BART_Keywords[:5]

Total num: 14656


[{'context_id': 0, 'keywords': ['WikiLeaks', 'Information Displaying']},
 {'context_id': 1, 'keywords': ['Second World War', 'Defeat']},
 {'context_id': 2, 'keywords': ['Arab-Israeli War', 'Deaths']},
 {'context_id': 3, 'keywords': ['Capitalist Society']},
 {'context_id': 4, 'keywords': ['First World War']}]

In [5]:
for item in Query_BART_Keywords:
    num_keywords = len(item["keywords"])
    probs = [max(0.9 - 0.1 * i, 0.1) for i in range(num_keywords)]
    item["prob"] = probs

Query_BART_Keywords[:5]

[{'context_id': 0,
  'keywords': ['WikiLeaks', 'Information Displaying'],
  'prob': [0.9, 0.8]},
 {'context_id': 1,
  'keywords': ['Second World War', 'Defeat'],
  'prob': [0.9, 0.8]},
 {'context_id': 2,
  'keywords': ['Arab-Israeli War', 'Deaths'],
  'prob': [0.9, 0.8]},
 {'context_id': 3, 'keywords': ['Capitalist Society'], 'prob': [0.9]},
 {'context_id': 4, 'keywords': ['First World War'], 'prob': [0.9]}]

In [9]:
json_data = json.dumps(Query_BART_Keywords, indent=4)

with open('/kaggle/working/Query_BART_Keywords_artificially_prob.json', 'w') as f:
    f.write(json_data)

## Bart generate prob

In [None]:
# class TextDataset(Dataset):
#     def __init__(self, encodings):
#         self.encodings = encodings

#     def __getitem__(self, idx):
#         return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

#     def __len__(self):
#         return len(self.encodings.input_ids)

# tokenizer = AutoTokenizer.from_pretrained("Andyrasika/bart_tech_keywords")
# model = AutoModelForSeq2SeqLM.from_pretrained("Andyrasika/bart_tech_keywords")

# texts = [item['question'] for item in data]

# encodings = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")

# # Use dataloader
# dataset = TextDataset(encodings)
# dataloader = DataLoader(dataset, batch_size=32)

**Set model to evaluate mode**

In [None]:
model.eval()
model.to("cuda")

all_input_texts = []
all_generated_texts = []
all_probabilities = []

for batch in tqdm(dataloader):
    input_ids = batch['input_ids'].to(model.device)
    attention_mask = batch['attention_mask'].to(model.device)

    for idx in range(input_ids.size(0)):  # Iter all in a batch
        decoded_ids = torch.full((1, 1), tokenizer.bos_token_id, dtype=torch.long, device=model.device)
        seq_probabilities = []

        input_text = tokenizer.decode(input_ids[idx].tolist(), skip_special_tokens=True)
        all_input_texts.append(input_text) 
        
        while True:
            with torch.no_grad():
                outputs = model(input_ids[idx:idx+1], attention_mask=attention_mask[idx:idx+1], decoder_input_ids=decoded_ids)
                logits = outputs.logits[:, -1, :]
                probs = softmax(logits, dim=-1)
                next_token_id = torch.argmax(probs, dim=-1).unsqueeze(-1)
                next_token_prob = probs[0, next_token_id.item()].item()

                if next_token_id == tokenizer.eos_token_id or decoded_ids.size(1) >= 30:
                    break

                decoded_ids = torch.cat([decoded_ids, next_token_id], dim=-1)
                seq_probabilities.append(next_token_prob)

        generated_text = tokenizer.decode(decoded_ids[:, 1:].squeeze().tolist(), skip_special_tokens=True)

        # Splitting the generated text and probability
        segments = generated_text.split(',')
        segment_probs = []
        start_idx = 0
        for segment in segments:
            segment_tokens = tokenizer.tokenize(segment.strip())
            segment_length = len(segment_tokens)
            # Calc avg prob of the segment
            if segment_length > 0:
                segment_prob = sum(seq_probabilities[start_idx:start_idx+segment_length]) / segment_length
                segment_probs.append(f"{segment_prob:.6f}")
            start_idx += segment_length

        all_generated_texts.append(generated_text)
        all_probabilities.append(",".join(segment_probs))

  0%|          | 0/916 [00:00<?, ?it/s]

  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


In [None]:
strip_key_words = [line.strip() for line in all_generated_texts]

'''
    * For "Question", set `ids` with `[item['id'] for item in data]`
    * For "Context", set `ids` with `[item['context_id'] for item in data]`
'''
ids = [item['id'] for item in data]

len(ids),len(strip_key_words)

# print(f"Total num: {len(strip_key_words)}")

(14656, 14656)

In [None]:
# documents = [
#     {"context_id": id, "keywords": [k.strip() for k in kw.split(',')], "prob": [k.strip() for k in p.split(',')]} 
#     for id, kw, p in zip(ids, strip_key_words, all_probabilities)
# ]

documents = [
    {
        "id": id, 
        "keywords": [k.strip() for k in kw.split(',')],
        "prob": [float(p) for p in prob.split(',')]
    } 
    for id, kw, prob in zip(ids, strip_key_words, all_probabilities)
    if kw.strip() != '' and prob.strip() != ''
]

json_data = json.dumps(documents, indent=4)
# print(json_data)


In [None]:
with open('/kaggle/working/Query_BART_Keywords_with_prob.json', 'w') as f:
    f.write(json_data)

In [None]:
file_path_ = '/kaggle/working/BART_Keywords_with_prob.json'

with open(file_path_, 'r') as f:
    data_ = json.load(f)

print(f"Total num: {len(data_)}")
data_[:5]

Total num: 14439


[{'id': -2745676532874802308,
  'keywords': ['WikiLeaks', 'Internet'],
  'prob': [0.938878, 0.677631]},
 {'id': -1876566410524927695,
  'keywords': ['SecondWorld War II'],
  'prob': [0.774945]},
 {'id': 4765871946129153693,
  'keywords': ['ArabIsraeli War', 'Arab Soldiers'],
  'prob': [0.779647, 0.733196]},
 {'id': -420328556017322612,
  'keywords': ['CapitalCapitalist Society'],
  'prob': [0.905467]},
 {'id': -4747113494541566005,
  'keywords': ['FirstWorld War I'],
  'prob': [0.709858]}]

**Check empty output**

In [8]:

empty_keywords_context_ids = [entry['context_id'] for entry in data_ if not entry['keywords']]

print(empty_keywords_context_ids)


[]


In [None]:
empty_id = [d['id'] for d in data_ if d['keywords']==['']]
len(empty_id), empty_id[:5]

(0, [])