In this script I build the "teacher" observations.

I use an LLM, in this case `stabilityai/StableBeluga-7B`, to produce the labels that will be used to train our lightweight BERT model.

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import json
from time import time
from sentence_transformers import SentenceTransformer
import nltk
import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from copy import deepcopy
import ast
# local tools
from local_tools import get_most_relevant_sentences, print_progress_bar, options_to_numerated_list

In [2]:
print("Available devices ", torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    print(f'Device {i}:', torch.cuda.get_device_name(i))

Available devices  1
Device 0: NVIDIA GeForce RTX 4090


## Load Stable Beluga to the GPU

In [3]:
tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/StableBeluga-7B", 
    use_fast=True
)
model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/StableBeluga-7B", 
    torch_dtype=torch.bfloat16,
    #low_cpu_mem_usage=True, 
    #device_map="auto"
)#.to("cuda")
model = model.to('cuda:0')


In [4]:
# We use this embedding model to shorten the news article when the prompt is too long
embedding_model = SentenceTransformer("all-MiniLM-L6-v2").to('cuda:0')

In [5]:
SYS_PROMPT = "### System:\nYou are StableBeluga, an AI that follows instructions extremely well. Help as much as you can. Remember, be safe, and don't do anything illegal.\n\n"

In [6]:
# Query string that will be used to ask questions to Stable Beluga
Q = """
Given this news article:

"{news_text}"

When the article mentions "{entity}", which of the following options is the news article most likely referring to? Provide only one option.
options:

{search_options}
"""

In [56]:
# function to make the complete prompt that will be sent to Stable Beluga
def make_prompt(message):
    return "{sys_prompt}### User: {message}\n\n### Assistant:\n The news article is most likely referring to: ".format(
        sys_prompt=SYS_PROMPT,
        message=message
    )

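def find_answer(sb_response):
    # marker that precedes the model's answer in the decoded output
    assis_token = "### Assistant:\n "
    end_of_answer = "</s>"
    soa_ix = sb_response.find(assis_token)
    if soa_ix == -1:
        # marker not found: no answer can be extracted
        return ""
    start = soa_ix + len(assis_token)
    eoa_ix = sb_response.find(end_of_answer, start)
    # if generation stopped at max_new_tokens there is no "</s>"; keep the full tail
    sb_answer = sb_response[start:eoa_ix] if eoa_ix != -1 else sb_response[start:]
    return sb_answer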

def make_tokenized_prompt(article_text, entity, search_options, top_n_perc=0.6):
    if top_n_perc < 0.1:
        return None
    question = Q.format(
        news_text = article_text,
        entity = entity,
        search_options=search_options
    )
    prompt=make_prompt(question)
    inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
    tokens_length = len(inputs[0])
    if tokens_length <= 4096:
        return inputs
    else:
        #print("token lengths were greater than 4096, summarizing the most relevant sentences")
        # keep only the most important sentences of the article
        most_relevant_pieces = get_most_relevant_sentences(
            article_text, 
            embedding_model, 
            top_n_perc = top_n_perc
        )
        return make_tokenized_prompt(
            article_text=most_relevant_pieces, 
            entity=entity, 
            search_options=search_options,
            top_n_perc=top_n_perc-0.2
        )

def ask_sb(input_tokens):
    if input_tokens is None:
        return ""
    inputs = input_tokens.to("cuda:0")
    with torch.no_grad():
        output = model.generate(
            inputs, 
            do_sample=False, 
            #top_p=0.95, 
            #top_k=0, 
            max_new_tokens=300
        )
    response = tokenizer.decode(output[0], skip_special_tokens=False)
    answer = find_answer(response)
    return answer
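
The helper `get_most_relevant_sentences` lives in `local_tools` and is not shown in this notebook. As a rough illustration only, the sketch below shows one way such a helper could work: score each sentence by cosine similarity to the article's mean embedding and keep the top fraction in the original order. The function name, the `embed` parameter, the regex sentence split, and the toy bag-of-words encoder (a stand-in for `embedding_model.encode`) are all assumptions; the real helper may differ.

```python
import math
import re

def toy_embed(sentences):
    """Hypothetical stand-in for embedding_model.encode: bag-of-words counts."""
    vocab = sorted({w for s in sentences for w in re.findall(r"\w+", s.lower())})
    return [[re.findall(r"\w+", s.lower()).count(w) for w in vocab] for s in sentences]

def cosine(u, v):
    """Cosine similarity between two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def most_relevant_sentences_sketch(text, embed=toy_embed, top_n_perc=0.6):
    """Keep the top_n_perc fraction of sentences closest to the article centroid."""
    # Naive sentence split (assumption); the real helper may use nltk instead.
    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
    if not sentences:
        return ""
    vectors = embed(sentences)
    centroid = [sum(col) / len(vectors) for col in zip(*vectors)]
    scores = [cosine(v, centroid) for v in vectors]
    keep = max(1, round(len(sentences) * top_n_perc))
    # Pick the best-scoring indices, then restore the original sentence order.
    top = sorted(sorted(range(len(sentences)), key=lambda i: -scores[i])[:keep])
    return " ".join(sentences[i] for i in top)
```

Because `make_tokenized_prompt` calls itself with `top_n_perc - 0.2` until the prompt fits in 4096 tokens, a helper of this shape would be applied to progressively shorter text on each retry.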

## Load DataSets
`ners_extracted_sofar.json` is a 2 GB file that will not be uploaded to GitHub.
It contains 2,953,563 named entities for a subset of 112,522 news articles.
I extracted the entities using the Stanford CoreNLP Python client, focusing only on PERSON, ORGANIZATION, COUNTRY and LOCATION entities.

For each entity of interest, I ran a query-string search against the Wikidata knowledge graph, which sometimes returned several results for a single entity.


Columns:

 - `text`: The text of the named entity, for example, Donald Trump.
 - `ner`: The type of the named entity, such as PERSON, LOCATION, etc.
 - `nerConfidences`: Confidence that the named entity belongs to a given NER type.
 - `clean_text`: The text of the named entity with special characters cleaned and converted to upper case.
 - `h1`: The title of the news article. Used as the key to merge with the news article information.
 - `wikidata_search_entries`: A list of all instances of the entity that were found in Wikidata. Only one of those instances corresponds to the correct option.
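
To make the shape of `wikidata_search_entries` concrete, here is a minimal sketch of how a list of candidates could be turned into the numbered options shown to the LLM. The real helper, `options_to_numerated_list`, lives in `local_tools` and is not shown, so the function below, its numbering format, and the sample entries are assumptions; only the `id` and `label_desc` keys come from this notebook.

```python
def options_to_numbered_list_sketch(descriptions):
    """Render candidate descriptions as a 1-based numbered list, one per line."""
    # Skip missing descriptions: entry.get("label_desc") may return None.
    cleaned = [d for d in descriptions if d]
    return "\n".join(f"{n}. {d}" for n, d in enumerate(cleaned, start=1))

# Made-up entries that only mimic the shape of `wikidata_search_entries`.
entries = [
    {"id": "Q0000001", "label_desc": "Example City (city in Example Country)"},
    {"id": "Q0000002", "label_desc": "Example City (2010 film)"},
    {"id": "Q0000003"},  # no description available
]
options = options_to_numbered_list_sketch([e.get("label_desc") for e in entries])
print(options)
```

Numbering the options gives the model a short, unambiguous way to name its choice, which should also make the free-text answer easier to map back to a Wikidata `id` later.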


`news_with_metadata.json` is a 1.5 GB file that will not be uploaded to GitHub.
It contains 117,786 Spanish news articles and their associated metadata, which was extracted using other NLP techniques such as translation and summarization. The data was scraped from the digital newspaper "El Universal".

Relevant columns:
 - `h1`: The title of the news article. Used as the key to merge with the named-entity information.
 - `date`: Date and time when the news article was published.
 - `author`: Name of the author who published the article.
 - `content`: Article text in Spanish.
 - `h1_en`: Article title translated to English.
 - `content_en`: Article text translated to English.

In [8]:
path_file = 'datasets/ners_extracted_sofar.json'
with open(path_file, 'r') as jfile:
    ners_data = json.load(jfile)

In [9]:
ners_df = pd.DataFrame.from_dict(ners_data)

In [32]:
# free memory
try:
    # delete variable
    del ners_data
except NameError:
    # variable has already been deleted
    pass

In [11]:
ners_df.shape

(2953563, 6)

In [12]:
# 'h1' is the title of the news article
len(ners_df['h1'].unique())

112522

In [13]:
ners_df.head(20)

Unnamed: 0,text,ner,nerConfidences,clean_text,h1,wikidata_search_entries
0,American,NATIONALITY,{'MISC': 0.94065111011726},AMERICAN,A Caleb Plant no le desagrada ser visto como u...,[]
1,Caleb Plant,PERSON,{'PERSON': 0.697600692564},CALEB PLANT,A Caleb Plant no le desagrada ser visto como u...,"[{'id': 'Q41500289', 'display_label': 'Caleb P..."
2,owner,TITLE,,OWNER,A Caleb Plant no le desagrada ser visto como u...,[]
3,International Boxing Federation,ORGANIZATION,{'ORGANIZATION': 0.89281201154313},INTERNATIONAL BOXING FEDERATION,A Caleb Plant no le desagrada ser visto como u...,"[{'id': 'Q742944', 'display_label': 'Internati..."
4,three,NUMBER,{'NUMBER': -1},THREE,A Caleb Plant no le desagrada ser visto como u...,[]
5,Canelo,PERSON,{'PERSON': 0.90461856292689},CANELO,A Caleb Plant no le desagrada ser visto como u...,"[{'id': 'Q37195125', 'display_label': 'Canelo'..."
6,Mexicans,MISC,{'MISC': 0.99871178240681},MEXICANS,A Caleb Plant no le desagrada ser visto como u...,[]
7,strategist,TITLE,,STRATEGIST,A Caleb Plant no le desagrada ser visto como u...,[]
8,Justin Gamber,PERSON,{'PERSON': 0.9997918079006},JUSTIN GAMBER,A Caleb Plant no le desagrada ser visto como u...,[]
9,Mexican,NATIONALITY,{'MISC': 0.99998711367514},MEXICAN,A Caleb Plant no le desagrada ser visto como u...,[]


In [17]:
path_file = 'datasets/news_with_metadata.json'
with open(path_file, 'r') as jfile:
    processed_news_articles = json.load(jfile)

In [18]:
processed_news_articles_df = pd.DataFrame.from_dict(processed_news_articles)

In [22]:
processed_news_articles_df.shape

(117786, 18)

In [23]:
processed_news_articles_df.head(5)

Unnamed: 0,h1,h2,h3,date,author,content,h1_en,h2_en,content_en,content_len,summary,article_summary_similarity,content_en_embedding,Topic,entities,keypoints,main_entity,sentiment_towards_main_entity
0,Cancelan “Noche de Rábanos” en Oaxaca por ries...,"A través de un video, el gobernador Alejandro ...",Más Información,23/12/2021 09:09,Fernando Miranda / Corresponsal,.– Ante el aumento de casos de Covid-19 y en ...,“Night of Radishes” Canceled in Oaxaca by Risk...,"Through a video, Governor Alejandro Murat repo...",.– In view of the increase in Covid-19 cases a...,,The government of Oaxaca cancelled the traditi...,0.834634,"[-0.0039373683, 0.1377171576, 0.0249814242, 0....",1193,"[Oaxaca government, Governor Alejandro Murat, ...","[Cancellation of ""Night of Rábanos"" due to Cov...",Oaxaca,NEGATIVE
1,"Pepenadores se casan en relleno sanitario, don...",Con apoyo de las autoridades de Ciudad Victori...,Más Información,17/12/2019 00:02,Redacción El Universal,Jesús Gallegos y Juana Martínez se conocieron ...,"Pepenadores get married in sanitary filling, w...",With the support of the authorities of Ciudad ...,Jesús Gallegos and Juana Martínez met five yea...,,Jesús Gallegos and Juana Martínez met five yea...,0.769818,"[-0.0183615051, 0.0783702508, 0.0043704007, 0....",2702,"[Jesús Gallegos, Juana Martínez, Ciudad Victor...","[Landfill couple's wedding in Ciudad Victoria,...",Jesús Gallegos and Juana Martínez,POSITIVE
2,Ladrón se mete a casa en Coyoacán de diputada ...,La legisladora Edna Laura Huerta Ruiz dijo a l...,Más Información,29/03/2021 14:41,Redacción,"La diputada federal por Morena, Edna Laura Hue...",Thief enters home in Coyoacán as federal deput...,Legislator Edna Laura Huerta Ruiz told the aut...,"The federal deputy for Morena, Edna Laura Huer...",,Federal deputy Edna Laura Huerta Ruiz was a vi...,0.718933,"[-0.0195583869, 0.0726484135, -0.0551280603, -...",3528,"[Federal deputy for Morena, Edna Laura Huerta ...","[Federal deputy for Morena, Edna Laura Huerta ...",Edna Laura Huerta Ruiz,NEGATIVE
3,Automovilistas evitan atraco y tunden a golpes...,Los conductores que atestiguaron un robo sobre...,Más Información,15/09/2021 09:31,Redacción El Universal,Tras asaltar a punta de pistola a tres automo...,Motorists avoid robbery and beat up thief in I...,The drivers who testified to a robbery over th...,"After assaulting three motorists at gunpoint, ...",,A repeat thief who assaulted three motorists a...,0.819494,"[-0.0368850492, 0.1118888482, 0.0172117595, 0....",937,"[Repeat thief, Motorists, Ignacio Zaragoza Cau...",[Repeat thief assaults three motorists at gunp...,The main entity mentioned in the news article ...,NEGATIVE
4,"Detienen a ladrón, lo amarran a un poste y le ...",Habitantes del fraccionamiento Villas del Pedr...,Más Información,30/09/2021 16:01,Redacción,Morelia.- Vecinos del fraccionamiento Villas ...,"They arrest a thief, tie him to a pole, and wh...",Residents of the Villas del Pedregal division ...,Morelia.- Neighbors of the Villas del Pedregal...,,Neighbors of the Villas del Pedregal division ...,0.779302,"[-0.0057766289, 0.0313035548, -0.0007194448, -...",3735,"[Villas del Pedregal division of Morelia, Mich...",[Neighbors arrest a man stealing inside a hous...,The main entity mentioned in the news article ...,NEGATIVE


In [21]:
# free memory
try:
    # delete variable
    del processed_news_articles
except NameError:
    # variable has already been deleted
    pass

In [27]:
# temporary file that holds checkpoints in case the process breaks.
path_file = 'ask_sb_sofar.json'
try:
    # open the checkpoint
    with open(path_file, 'r') as jfile:
        asked_sb_sofar = json.load(jfile)
    already_asked_sb = pd.DataFrame.from_dict(asked_sb_sofar)
    # index to identify where the process stopped, becomes entity + title
    already_asked_sb['ix'] = already_asked_sb['text']+'-'+already_asked_sb['h1']
except FileNotFoundError:
    # no checkpoint has been saved so far, process will start from the beginning
    asked_sb_sofar = []

In [28]:
ners_df['ix'] = ners_df['text']+'-'+ners_df['h1']

In [29]:
if len(asked_sb_sofar) > 0:
    to_ask_sb = ners_df[
        (ners_df['wikidata_search_entries'].apply(len) >1 )&
        ~(ners_df['ix'].isin(already_asked_sb['ix']))
    ]
else:
    to_ask_sb = ners_df[
        (ners_df['wikidata_search_entries'].apply(len) >1 )
    ]
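
The resume logic above can also be expressed without pandas. The sketch below is an illustration under stated assumptions, not the notebook's implementation: the function names are hypothetical, records are plain dicts with the `text`, `h1` and `wikidata_search_entries` fields, and the checkpoint is written through a temporary file so an interrupted write cannot leave a truncated JSON file behind.

```python
import json
import os
import tempfile

def load_checkpoint(path):
    """Return the saved results, or an empty list when no checkpoint exists."""
    try:
        with open(path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        return []

def save_checkpoint(path, results):
    """Write to a temp file first, then rename it over the checkpoint."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(results, f)
    os.replace(tmp_path, path)

def pending(records, done):
    """Keep ambiguous records (more than one candidate) not yet in the checkpoint."""
    seen = {(r["text"], r["h1"]) for r in done}
    return [
        r for r in records
        if len(r["wikidata_search_entries"]) > 1 and (r["text"], r["h1"]) not in seen
    ]
```

Using the tuple `(text, h1)` as the key instead of the concatenated string `text + '-' + h1` avoids accidental collisions when an entity or a title itself contains a hyphen.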

In [30]:
to_ask_sb.head(5)

Unnamed: 0,text,ner,nerConfidences,clean_text,h1,wikidata_search_entries,ix
1,Caleb Plant,PERSON,{'PERSON': 0.697600692564},CALEB PLANT,A Caleb Plant no le desagrada ser visto como u...,"[{'id': 'Q41500289', 'display_label': 'Caleb P...",Caleb Plant-A Caleb Plant no le desagrada ser ...
3,International Boxing Federation,ORGANIZATION,{'ORGANIZATION': 0.89281201154313},INTERNATIONAL BOXING FEDERATION,A Caleb Plant no le desagrada ser visto como u...,"[{'id': 'Q742944', 'display_label': 'Internati...",International Boxing Federation-A Caleb Plant ...
5,Canelo,PERSON,{'PERSON': 0.90461856292689},CANELO,A Caleb Plant no le desagrada ser visto como u...,"[{'id': 'Q37195125', 'display_label': 'Canelo'...",Canelo-A Caleb Plant no le desagrada ser visto...
11,MGM Grand,LOCATION,{'LOCATION': 0.48660150460337},MGM GRAND,A Caleb Plant no le desagrada ser visto como u...,"[{'id': 'Q713960', 'display_label': 'MGM Grand...",MGM Grand-A Caleb Plant no le desagrada ser vi...
12,Las Vegas,CITY,{'LOCATION': 0.99967540705336},LAS VEGAS,A Caleb Plant no le desagrada ser visto como u...,"[{'id': 'Q23768', 'display_label': 'Las Vegas'...",Las Vegas-A Caleb Plant no le desagrada ser vi...


In [31]:
# free memory
try:
    # delete variable
    del ners_df
except NameError:
    # variable has already been deleted
    pass

In [44]:
# Example for one entry in the process. We will use "Andres Manuel Lopez Obrador" as an example
example = to_ask_sb[to_ask_sb['text'] == 'Andres Manuel Lopez Obrador'].iloc[0]
print(example)

text                                             Andres Manuel Lopez Obrador
ner                                                                   PERSON
nerConfidences                                       {'PERSON': 0.998525444}
clean_text                                       ANDRES MANUEL LOPEZ OBRADOR
h1                         Anuncia Monreal su participación en "legítima"...
wikidata_search_entries    [{'id': 'Q318508', 'display_label': 'Andrés Ma...
ix                         Andres Manuel Lopez Obrador-Anuncia Monreal su...
Name: 2140677, dtype: object


In [54]:
title = example['h1']
entity = example['text']
text = processed_news_articles_df[
    processed_news_articles_df['h1'] == title
]['content_en'].iloc[0]
options = options_to_numerated_list(
    [
        i.get("label_desc") for i in example['wikidata_search_entries']
    ]
)
sb_question_tokens = make_tokenized_prompt(
    article_text=text, 
    entity=entity, 
    search_options=options, 
    top_n_perc=0.6
)
question = Q.format(
    news_text = text,
    entity = entity,
    search_options=options
)
print("Stable Beluga Query: \n{0}".format(question))
sb_answer = ask_sb(sb_question_tokens)
print("Stable Beluga Disambiguation answer: \n{0}".format(sb_answer))

Stable Beluga Query: 

Given this news article:

"President Andres Manuel Lopez Obrador's call for the next scheduled on 27 November for the Angel of Independence to the Zócalo is legitimate, said Morena's leader in the Senate, Ricardo Monreal Ávila, who announced that he will participate in that mobilization, because he is part of the Movement that won the Presidency of the Republic in 2018. In an interview, the legislator said that the president is in all his right to march in the streets as he did many times as an opposition leader. “He believes that this must be the case and we must respect him, not forgetting that he is head of state, the President is President of all Mexicans and of all Mexicans, including opponents, because the rules of democracy are so and so they must be observed. When López Obrador was elected in 2018 he won democratically,” he said. Monreal Ávila said that he is going to march with President López Obrador, although he acknowledged that he could suffer a disa

## Trigger the Process

Now we run the above process in a loop over every entity in `to_ask_sb`, that is, the entities with more than one Wikidata candidate among the 2,953,563 named entities extracted from 112,522 news articles. This way, Stable Beluga helps us create a large dataset of disambiguated observations that we can later use to fine-tune a BERT model. The objective of fine-tuning BERT is to have a lightweight model that can disambiguate as well as Stable Beluga 7B.

The whole Stable Beluga disambiguation process took around 1 week on my 4090 GPU.

A lightweight model, as we will later see, uses 3x less GPU memory (which translates to lower GPU costs in production) and performs the disambiguation 3x faster (which also translates into faster production systems).

In [58]:
sb_results = asked_sb_sofar
# convert the DataFrame to a list of records once, instead of once per use
records = to_ask_sb.to_dict(orient='records')
total_entries = len(records)
for i, entry in enumerate(records):
    if i%1000 == 0:
        torch.cuda.empty_cache()
        
    new_entry = deepcopy(entry)
    title = entry['h1']
    entity = entry['text']
    text = processed_news_articles_df[
        processed_news_articles_df['h1'] == title
    ]['content_en'].iloc[0]
    options = options_to_numerated_list(
        [
            i.get("label_desc") for i in entry['wikidata_search_entries']
        ]
    )
    sb_question_tokens = make_tokenized_prompt(
        article_text=text, 
        entity=entity, 
        search_options=options, 
        top_n_perc=0.6
    )
    sb_answer = ask_sb(sb_question_tokens)
    new_entry['sb_answer'] = sb_answer
    new_entry['options_given'] = options
    sb_results.append(new_entry)

    if i%1000 == 0:
        print_progress_bar(
            iteration=i, 
            total=total_entries, 
            bar_length=50
        )
        with open("ask_sb_sofar.json", "w") as file:
            json.dump(sb_results, file)

with open("datasets/sb_disambiguation_result.json", "w") as file:
    json.dump(sb_results, file)
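
The loop above filters the whole articles DataFrame by title for every entity. One possible speed-up, sketched here as an assumption rather than as what the notebook does, is to build a title-to-text dictionary once and look each title up in constant time. The function name and the sample articles are hypothetical; `setdefault` keeps the first text seen for a title, mirroring the `.iloc[0]` used in the loop.

```python
def build_title_index(articles):
    """Map each title to the first article text seen for it."""
    index = {}
    for article in articles:
        # setdefault leaves an existing entry untouched, so duplicates are ignored.
        index.setdefault(article["h1"], article["content_en"])
    return index

# Made-up articles that only mimic the `h1` and `content_en` columns.
articles = [
    {"h1": "title A", "content_en": "first text for A"},
    {"h1": "title B", "content_en": "text for B"},
    {"h1": "title A", "content_en": "duplicate title, ignored"},
]
title_to_text = build_title_index(articles)
print(title_to_text["title A"])
```

Inside the loop, `title_to_text.get(title)` would return `None` for a title with no matching article, which makes it easy to skip that entry instead of failing on an empty selection.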

