# Extract Keywords - BART X BERT

In [22]:
import os
import json
from tqdm.auto import tqdm
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, DebertaV2Tokenizer, DebertaV2Model

In [17]:
# Location of the context corpus. Judging by the preview below and later
# cells, the file holds a list of {'context': <str>, 'id': <int>} records.
file_path = './data/context_id_train&dev_en.json'

# Decode explicitly as UTF-8: without `encoding=` Python falls back to the
# platform locale, which can fail on (or silently garble) non-ASCII
# characters in the contexts on systems whose default is not UTF-8.
with open(file_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

print(f"Total num: {len(data)}")
# Bare last expression: show only the first few records as a sanity check.
data[:5]

Total num: 13932


[{'context': 'WikiLeaks () is an international non-profit organisation that publishes secret information, news leaks, and classified media provided by anonymous sources. Its website, initiated in 2006 in Iceland by the organisation Sunshine Press, claims a database of 10 million documents in 10 years since its launch. Julian Assange, an Australian Internet activist, is generally described as its founder and director. Kristinn Hrafnsson is its editor-in-chief.',
  'id': 0},
 {'context': 'The war in Europe concluded with an invasion of Germany by the Western Allies and the Soviet Union, culminating in the capture of Berlin by Soviet troops, the suicide of Adolf Hitler and the German unconditional surrender on 8 May 1945. Following the Potsdam Declaration by the Allies on 26 July 1945 and the refusal of Japan to surrender under its terms, the United States dropped atomic bombs on the Japanese cities of Hiroshima and Nagasaki on 6 and 9 August respectively. With an invasion of the Japanese

# BART method

**Set up BART**

In [21]:
# Set the TOKENIZERS_PARALLELISM flag for the HuggingFace tokenizers backend.
os.environ["TOKENIZERS_PARALLELISM"] = "True"

# The tokenizer and the seq2seq weights are loaded from the same checkpoint.
bart_tokenizer = AutoTokenizer.from_pretrained("Andyrasika/bart_tech_keywords")
bart_model = AutoModelForSeq2SeqLM.from_pretrained("Andyrasika/bart_tech_keywords")

# Text-to-text generation pipeline used as the keyword extractor; every call
# generates at most 50 new tokens.
# To run on GPU, additionally pass `device=0` to `pipeline(...)`.
keyword_extractor = pipeline(
    "text2text-generation",
    tokenizer=bart_tokenizer,
    model=bart_model,
    max_new_tokens=50,
)

**Handle long context**

In [23]:
# The limit this splitter protects against is expressed in *tokens* (1024),
# so chunk length is measured with the BART tokenizer instead of `len`.
# A fixed budget of 4000 characters can still tokenize to more than 1024
# tokens for dense text (numbers, names, non-English words).
#
# chunk_size=1000 leaves headroom below 1024 for the special tokens the
# tokenizer adds at encoding time and for small differences between the sum
# of per-piece token counts and the token count of the merged chunk.
# chunk_overlap is consequently expressed in tokens as well.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=20,
    length_function=lambda text: len(
        bart_tokenizer.encode(text, add_special_tokens=False)
    ),
    is_separator_regex=False,
)

**Avoid IndexError when the input exceeds the model's limit (1024 tokens)**

In [24]:
# Build one keyword string per context, in the same order as `data`.
#
# Every entry of `keywords` is a single comma-separated string. Contexts that
# are too long for the model are split into chunks, each chunk is processed
# separately, and the per-chunk results are joined with ", ". Appending the
# list of per-chunk results instead would mix str and list entries and break
# the `.strip()` / `.split(',')` post-processing in the following cells.
#
# Results are appended one at a time so that a partial run (e.g. after a
# manual interrupt) still leaves the already-computed entries in `keywords`.
keywords = []
for item in tqdm(data, desc="Extracting keywords"):
    text = item['context']
    num_tokens = len(bart_tokenizer.encode(text))
    if num_tokens >= 1024:
        # Too long for a single pass: extract keywords chunk by chunk.
        chunk_keywords = []
        for chunk in text_splitter.create_documents([text]):
            generated = keyword_extractor(chunk.page_content)[0]['generated_text'].strip()
            if generated:  # skip chunks for which the model produced nothing
                chunk_keywords.append(generated)
        keywords.append(', '.join(chunk_keywords))
    else:
        keywords.append(keyword_extractor(text)[0]['generated_text'])

  0%|          | 0/13932 [00:00<?, ?it/s]

KeyboardInterrupt: 

**Save results**

In [25]:
# Normalise every extraction result to one stripped, comma-separated string.
# The extraction loop stores a plain string for short contexts but a list of
# per-chunk strings for contexts that had to be split; calling `.strip()` on
# such a list raises AttributeError, so list entries are joined first.
key_words = [
    entry.strip() if isinstance(entry, str)
    else ', '.join(part.strip() for part in entry)
    for entry in keywords
]
ids = [item['id'] for item in data]

# Each context id must have exactly one keyword entry; a mismatch usually
# means the extraction loop did not run to completion.
assert len(ids) == len(key_words), (
    f"{len(key_words)} keyword entries for {len(ids)} contexts"
)

print(f"Total num: {len(key_words)}")

Total num: 7


['WikiLeaks, News Leaks, Sunshine Press',
 'World War II, Europe, Japan, Soviet Union, War Crimes Trials',
 'Arab War, Casualties, Henry Laurens',
 'Capitalism, European Transformation, Adam Smith, Max Weber, Fernand Braudel, Henri Pirenne, Paul Sweezy',
 'World War I, WWI, Great War, Genocides, Influenza Pandemic']

In [41]:
# Pair every context id with its list of individual keywords.
# Empty fragments (from an empty keyword string, a trailing comma or ", ,")
# are dropped so that no "" keyword ends up in the output. The loop variable
# is named `context_id` to avoid shadowing the builtin `id`.
documents = [
    {
        "context_id": context_id,
        "keywords": [kw.strip() for kw in keyword_line.split(',') if kw.strip()],
    }
    for context_id, keyword_line in zip(ids, key_words)
]

# Full serialisation; this string is what gets written to disk afterwards.
json_data = json.dumps(documents, indent=4)

# Show only a short preview: printing the whole corpus floods the notebook
# output and bloats the saved file.
print(json.dumps(documents[:5], indent=4))

[
    {
        "context_id": 0,
        "keywords": [
            "WikiLeaks",
            "News Leaks",
            "Sunshine Press"
        ]
    },
    {
        "context_id": 1,
        "keywords": [
            "World War II",
            "Europe",
            "Japan",
            "Soviet Union",
            "War Crimes Trials"
        ]
    },
    {
        "context_id": 2,
        "keywords": [
            "Arab War",
            "Casualties",
            "Henry Laurens"
        ]
    },
    {
        "context_id": 3,
        "keywords": [
            "Capitalism",
            "European Transformation",
            "Adam Smith",
            "Max Weber",
            "Fernand Braudel",
            "Henri Pirenne",
            "Paul Sweezy"
        ]
    },
    {
        "context_id": 4,
        "keywords": [
            "World War I",
            "WWI",
            "Great War",
            "Genocides",
            "Influenza Pandemic"
        ]
    },
    {
        "context_id": 

In [None]:
# Persist the serialised context-id / keyword records as a JSON file.
with open('./data/context_keyword_pair.json', mode='w') as out_file:
    out_file.write(json_data)