# Extract Keywords - BART X BERT

In [8]:
import os
import re
import json
from tqdm.auto import tqdm
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification

In [10]:
file_path = './data/context_id_train_dev_en.json'

with open(file_path, 'r') as f:
    data = json.load(f)

print(f"Total num: {len(data)}")
data[:5]

Total num: 13932


[{'context': 'WikiLeaks () is an international non-profit organisation that publishes secret information, news leaks, and classified media provided by anonymous sources. Its website, initiated in 2006 in Iceland by the organisation Sunshine Press, claims a database of 10 million documents in 10 years since its launch. Julian Assange, an Australian Internet activist, is generally described as its founder and director. Kristinn Hrafnsson is its editor-in-chief.',
  'id': 0},
 {'context': 'The war in Europe concluded with an invasion of Germany by the Western Allies and the Soviet Union, culminating in the capture of Berlin by Soviet troops, the suicide of Adolf Hitler and the German unconditional surrender on 8 May 1945. Following the Potsdam Declaration by the Allies on 26 July 1945 and the refusal of Japan to surrender under its terms, the United States dropped atomic bombs on the Japanese cities of Hiroshima and Nagasaki on 6 and 9 August respectively. With an invasion of the Japanese

# BART method

**Set up BART**

In [11]:
os.environ["TOKENIZERS_PARALLELISM"] = "True"

bart_tokenizer = AutoTokenizer.from_pretrained("Andyrasika/bart_tech_keywords")
bart_model = AutoModelForSeq2SeqLM.from_pretrained("Andyrasika/bart_tech_keywords")

keyword_extractor = pipeline("text2text-generation", 
                             model=bart_model, 
                             tokenizer=bart_tokenizer,
                             max_new_tokens=50)
                            #  device=0) # GPU

**Handle long context**

In [12]:
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=4000,
    chunk_overlap=20,
    length_function=len,
    is_separator_regex=False,
)

**Avoid IndexError when input exceeds model limits(1024)**

In [None]:
keywords = []
for i in tqdm(range(len(data))):
    text = data[i]['context']
    num_tokens = len(bart_tokenizer.encode(text))
    if num_tokens >= 1024:
        split_texts = text_splitter.create_documents([data[i]['context']])
        sub_keys = []
        for t in split_texts:
            sub_keys.append(keyword_extractor(t.page_content)[0]['generated_text'])
        keywords.append(sub_keys)
    else:
        keywords.append(keyword_extractor(text)[0]['generated_text'])

  0%|          | 0/13932 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1374 > 1024). Running this sequence through the model will result in indexing errors


**Save results**

In [74]:
def flatten_nested_list(nested_list):
    result = []
    for item in nested_list:
        if isinstance(item, list):
            flattened_item = ', '.join(item)
            result.append(flattened_item)
        else:
            result.append(item)
    return result

In [None]:
keywords[6159:6162]

['United States of America, Federal Republic',
 ['Dram, National Currency, Central Bank of Armenia, ISO, UNICODE',
  'Armenia, Banknotes, Coins, Security Printing',
  "Collector Coins, Banknotes, Coins, Central Bank of Armenia, Noah's Ark"],
 'Buenos Aires, Argentina, Río de la Plata']

In [None]:
flattened_list = flatten_nested_list(keywords)
flattened_list[6159:6162]

['United States of America, Federal Republic',
 "Dram, National Currency, Central Bank of Armenia, ISO, UNICODE, Armenia, Banknotes, Coins, Security Printing, Collector Coins, Banknotes, Coins, Central Bank of Armenia, Noah's Ark",
 'Buenos Aires, Argentina, Río de la Plata']

In [None]:
len(flattened_list), len(keywords)

(13932, 13932)

In [None]:
strip_key_words = [line.strip() for line in flattened_list]
ids = [item['id'] for item in data]

assert len(ids) == len(strip_key_words)

print(f"Total num: {len(strip_key_words)}")

Total num: 13932


In [None]:
documents = [
    {"context_id": id, "keywords": [k.strip() for k in kw.split(',')]} 
    for id, kw in zip(ids, strip_key_words)
]

json_data = json.dumps(documents, indent=4)
# print(json_data)

In [None]:
with open('./data/BART_Keywords.json', 'w') as f:
    f.write(json_data)

# BERT method

**Set up BART**

In [None]:
tokenizer = AutoTokenizer.from_pretrained("yanekyuk/bert-uncased-keyword-extractor")
model = AutoModelForTokenClassification.from_pretrained("yanekyuk/bert-uncased-keyword-extractor")
kw_extractor = pipeline("token-classification", 
                        model="yanekyuk/bert-uncased-keyword-extractor",
                        tokenizer = tokenizer)

In [None]:
def segment_text(id, text, max_length=512, step_size=256):
    """
    After exceeding the maximum bert input limit, use the slide window to truncate to size=256
    """
    segments = []
    start = 0
    while start < len(text):
        if start + max_length > len(text):
            segment = text[start:]
            segments.append(segment)
            id.append(id[-1]+1)
            break
        
        segment = text[start:start+max_length]
        segments.append(segment)
        id.append(id[-1])
        start += step_size

    return segments

def preprocess(data):
    preprocessed_data = []
    id_context=[0]
    for item in data:
        context = item['context']
        context = re.sub(r'\s+', ' ', context).strip()
        segmented_contexts = segment_text(id_context, context)
        preprocessed_data.extend(segmented_contexts) 
    return preprocessed_data, id_context

def KeyWords_generator(outputs, text):
    keywords = []
    current_keyword = None
    current_start = None
    current_end = None

    for item in outputs:
        if item['entity'] == 'B-KEY':
            if current_keyword:
                keyword = text[current_start:current_end]
                if keyword.count(' ') < 4:
                    keywords.append(keyword) 
                current_keyword = None
            current_start = item['start']
            current_keyword = item['word']
        elif item['entity'] == 'I-KEY' and current_keyword:
            current_keyword = item['word']
        current_end = item['end']

    if current_keyword:
        keyword = text[current_start:current_end]
        if keyword.count(' ') < 4:
            keywords.append(keyword)

    return keywords

In [None]:
texts, id_context = preprocess(data)
FILE_WRITE = "./data/BERT_Keywords.json"

def main():
    doc_kw = []
    with open(FILE_WRITE, 'w') as file:
        keywords = KeyWords_generator(kw_extractor(texts[0]), texts[0])
        for i in tqdm(range(1, len(texts))):
            if id_context[i] != id_context[i-1]:
                kw_list = list(set(keywords))
                print(kw_list)
                data_dic = {
                    'context': data[id_context[i] - 1],
                    'keywords': kw_list
                }
                # file.write(f"Text {id_context[i]-1} Keywords: {', '.join(list(set(keywords)))}\n\n")
                doc_kw.append(data_dic)
                keywords = KeyWords_generator(kw_extractor(texts[i]), texts[i])
            else:
                keywords += KeyWords_generator(kw_extractor(texts[i]), texts[i])
        json.dump(doc_kw, file, ensure_ascii=False, indent=4)
    
       
main()



  0%|          | 0/30042 [00:00<?, ?it/s]

['anonymous', 'editor-in-chief', 'Julian Assange', 'WikiLeaks', 'Sunshine Press', 'Kristinn Hrafnsson', 'Iceland', 'Internet']


  0%|          | 3/30042 [00:01<3:34:17,  2.34it/s]

['Soviet Union', 'Hiroshima', 'Nagasaki', 'Japan', 'war crimes', 'unconditional surrender', 'Adolf Hitler', 'atomic bombs', 'Potsdam Declaration', 'Western Allies', 'Soviet', 'United States']


  0%|          | 6/30042 [00:02<3:17:44,  2.53it/s]

['marriage equality', 'Loving v. Virginia', 'Due Process Clause', 'gay marriage', 'separate marriage', 'Same-sex marriage']


  0%|          | 6/30042 [00:03<4:17:08,  1.95it/s]


# Load output keywords again

In [2]:
bart_file_path = './data/BART_Keywords.json'
bert_file_path = './data/BERT_Keywords.json'

**BART**

In [3]:
with open(bart_file_path, 'r') as f:
    data_bart = json.load(f)

print(f"Total num: {len(data_bart)}")

Total num: 13932


**BERT**

In [4]:
with open(bert_file_path, 'r') as f:
    data_bert = json.load(f)

print(f"Total num: {len(data_bert)}")

Total num: 13932


# BART X BERT

In [6]:
merged_data = {}
order = []

# Bart -> merged_data
for item in data_bart:
    context_id = item['context_id']
    merged_data[context_id] = item['keywords']
    order.append(context_id)

# Bert -> merged_data
for item in data_bert:
    context_id = item['context_id']
    if context_id not in merged_data:
        merged_data[context_id] = item['keywords']
        order.append(context_id)
    else:
        # merge the keywords and drop duplicates.
        merged_data[context_id].extend([kw for kw in item['keywords'] if kw not in merged_data[context_id]])

merged_list = [{'context_id': key, 'keywords': merged_data[key]} for key in order]

print(f"Total num: {len(merged_list)}")
merged_list


Total num: 13932


[{'context_id': 0,
  'keywords': ['WikiLeaks',
   'News Leaks',
   'Sunshine Press',
   'Kristinn Hrafnsson',
   'Julian Assange',
   'Iceland',
   'editor-in-chief',
   'anonymous',
   'Internet']},
 {'context_id': 1,
  'keywords': ['World War II',
   'Europe',
   'Japan',
   'Soviet Union',
   'War Crimes Trials',
   'atomic bombs',
   'United States',
   'Germans',
   'Western Allies',
   'Hiroshima',
   'Potsdam Declaration']},
 {'context_id': 2,
  'keywords': ['Arab War',
   'Casualties',
   'Henry Laurens',
   'Egypt',
   'Jordan',
   'Aref al-Aref',
   'Palestinians']},
 {'context_id': 3,
  'keywords': ['Capitalism',
   'European Transformation',
   'Adam Smith',
   'Max Weber',
   'Fernand Braudel',
   'Henri Pirenne',
   'Paul Sweezy',
   'Venice',
   'nation',
   'capitalism',
   'Malacca',
   'Werner Sombart',
   'Genoa',
   'United Provinces',
   'Holland',
   'Dutch Republic',
   'Netherlands']},
 {'context_id': 4,
  'keywords': ['World War I',
   'WWI',
   'Great War',
  

In [7]:
with open('./data/BART_X_BERT_Keywords.json', 'w', encoding='utf-8') as file:
    json.dump(merged_list, file, ensure_ascii=False, indent=4)


In [26]:
text = data[0]['context']
text

'WikiLeaks () is an international non-profit organisation that publishes secret information, news leaks, and classified media provided by anonymous sources. Its website, initiated in 2006 in Iceland by the organisation Sunshine Press, claims a database of 10 million documents in 10 years since its launch. Julian Assange, an Australian Internet activist, is generally described as its founder and director. Kristinn Hrafnsson is its editor-in-chief.'

In [15]:
keyword_extractor(text)

[{'generated_text': 'WikiLeaks, News Leaks, Sunshine Press'}]

In [27]:
# text = "Your input text here"
inputs = bart_tokenizer.encode(text, return_tensors="pt")
outputs = bart_model.generate(inputs, output_scores=True, return_dict_in_generate=True)
scores = outputs['scores']  # 每步的分数（概率分布）

In [28]:
import torch
import torch.nn.functional as F

# 假设 scores 是您的模型输出
scores_example = scores[0]  # 以第一个时间步的输出为例

# 应用 Softmax 来获取概率分布
probabilities = F.softmax(scores_example, dim=-1)

# 找出每个时间步概率最高的词的索引
max_indices = torch.argmax(probabilities, dim=-1)

# 解码为词
decoded_words = [bart_tokenizer.decode([idx], skip_special_tokens=True) for idx in max_indices]

print(decoded_words)



['', '', '', '']


In [29]:
import torch
import torch.nn.functional as F

# 假设我们处理第一个时间步的输出，这里采用循环来处理所有时间步
for scores_tensor in scores:
    # 应用 Softmax 来获取概率分布
    probabilities = F.softmax(scores_tensor, dim=-1)

    # 对每个时间步，找出概率最高的词的索引
    max_indices = torch.argmax(probabilities, dim=-1)

    decoded_words = []
    for idx in max_indices:
        # 解码每个索引，跳过特殊标记
        word = bart_tokenizer.decode(idx, skip_special_tokens=True)
        if word:  # 确保解码的词不为空
            decoded_words.append(word)
    
    print(decoded_words)


[]
['Wiki', 'Wiki', 'Wiki', 'Wiki']
['Leaks', 'Wiki', 'ileaks', 'Wiki']
[',', 'Leaks', ',', 'Leaks']
[' News', ',', ' News', ' News']
[' Le', ' Le', ' Information', '-']
['aks', 'ak', ',', 'profit']
[',', ' News', ',', ',']
[' Sunshine', ' Le', ' News', ' News']
['aks', ' Press', ' Le', ' Le']
[',', 'aks', 'aks']
[' Sunshine', ' Julian', ',']
[' Sunshine', ' Assange', ' Press', ' Media']
[' Press', ' Media']


In [None]:
keywords = []
for i in tqdm(range(len(data))):
    text = data[i]['context']
    num_tokens = len(bart_tokenizer.encode(text))
    if num_tokens >= 1024:
        split_texts = text_splitter.create_documents([data[i]['context']])
        sub_keys = []
        for t in split_texts:
            sub_keys.append(keyword_extractor(t.page_content)[0]['generated_text'])
        keywords.append(sub_keys)
    else:
        keywords.append(keyword_extractor(text)[0]['generated_text'])