In [1]:
import os

import pandas as pd
import spacy
import wikipedia

from FlagEmbedding import BGEM3FlagModel
# Shared BGE-M3 embedding model, used by the embedding code in later cells.
# use_fp16=True requests half-precision inference.
model = BGEM3FlagModel('BAAI/bge-m3', use_fp16 = True)

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [2]:
# Location of the English (unlabelled) training split. Set the MUSHROOM_EN_TRAIN
# environment variable to point at a local copy instead of editing this path.
EN_TRAIN_PATH = os.environ.get(
    "MUSHROOM_EN_TRAIN",
    r'C:\Users\FLopezP\Documents\GitHub\Mu-SHROOM-GIL\Datasets\train_ds\mushroom.en-train_nolabel.v1.jsonl',
)
df_en = pd.read_json(EN_TRAIN_PATH, lines=True)
df_en.head(5)

Unnamed: 0,lang,model_id,model_input,model_output_text,model_output_logits,model_output_tokens
0,EN,togethercomputer/Pythia-Chat-Base-7B,Do all arthropods have antennae?,"Yes, all insects and arachnids (including spi...","[-2.57427001, 5.1865358353, 5.4173498154, 2.32...","[ĠYes, ,, Ġall, Ġinsects, Ġand, Ġar, ach, n, i..."
1,EN,togethercomputer/Pythia-Chat-Base-7B,Do all arthropods have antennae?,"Yes, all insects and arachnids have at least ...","[-2.57427001, 5.1865358353, 5.4173498154, 2.32...","[ĠYes, ,, Ġall, Ġinsects, Ġand, Ġar, ach, n, i..."
2,EN,togethercomputer/Pythia-Chat-Base-7B,Do all arthropods have antennae?,"Yes, all insects and arachnids (including spi...","[-2.57427001, 5.1865358353, 5.4173498154, 2.32...","[ĠYes, ,, Ġall, Ġinsects, Ġand, Ġar, ach, n, i..."
3,EN,togethercomputer/Pythia-Chat-Base-7B,Do all arthropods have antennae?,"Yes, all insects and arachnids (including spi...","[-2.57427001, 5.1865358353, 5.4173498154, 2.32...","[ĠYes, ,, Ġall, Ġinsects, Ġand, Ġar, ach, n, i..."
4,EN,togethercomputer/Pythia-Chat-Base-7B,Do all arthropods have antennae?,"Yes, all insects and arachnids (including spi...","[-2.57427001, 5.1865358353, 5.4173498154, 2.32...","[ĠYes, ,, Ġall, Ġinsects, Ġand, Ġar, ach, n, i..."


In [3]:
# Estas tres partes están bastante bien.
# El problema siento que radica en la siguiente celda donde recuperamos el resumen de cada página en vez de las secciones relevantes.
def noun_list(a, lang):
    """
    Extract the content-bearing tokens of each question (nouns, proper nouns, numbers).

    a = iterable of str; questions from the dataset (e.g. the "model_input" column).
    lang = 'es' or 'en'; 'es' loads the Spanish spaCy pipeline, any other value the English one.

    Returns a list with one list of token strings per question, in input order.
    Tokens tagged ADJ whose text starts with a digit are kept as well.
    """
    model_name = "es_core_news_sm" if lang == 'es' else "en_core_web_sm"
    nlp = spacy.load(model_name)
    keep_pos = {"NOUN", "PROPN", "NUM"}
    digits = set("0123456789")

    tokens_per_question = []
    # nlp.pipe streams the texts through the pipeline in batches instead of
    # calling nlp() once per question; the per-document annotations are the same.
    for doc in nlp.pipe(a):
        tokens_per_question.append([
            token.text
            for token in doc
            if token.pos_ in keep_pos
            # Numeric-looking tokens that spaCy tags as ADJ (presumably ordinals
            # such as "3rd"); the set lookup is False for an empty text.
            or (token.pos_ == "ADJ" and token.text[:1] in digits)
        ])
    return tokens_per_question

def keyword_por_preg(n_list):
    """
    Turn each question's token list into one search string for the Wikipedia API.

    n_list = list of lists of str; as produced by noun_list().

    Every token is followed by a single space, so non-empty results end with a
    trailing space and an empty token list yields ''.
    """
    return [''.join(word + ' ' for word in words) for words in n_list]

def get_wikipage(text, lang, page_total):
    """
    Return the titles of the Wikipedia pages that best match a query.

    text = str; query string, usually produced by keyword_por_preg().
    lang = str; Wikipedia language prefix ('es', 'en', ...).
    page_total = int; maximum number of titles to return.

    Returns the list of titles given by wikipedia.search().
    """
    # wikipedia.set_lang() is a module-wide setting, so select the requested
    # edition on every call. Previously a language other than 'es'/'en' silently
    # searched whichever edition had been selected last.
    wikipedia.set_lang(lang)
    return wikipedia.search(text, results = page_total)

In [4]:
def get_content(wiki_names):
    """
    Regresa la lista content con las páginas filtradas y segmentadas.

    wiki_names = list ; Lista de los nombres de páginas de Wikipedia. Extraído del dataset. 
    """
    p_content = []
    content = []
    
    for _ in wiki_names:
        try:
            page = wikipedia.WikipediaPage(_)
        except wikipedia.exceptions.DisambiguationError:
            print("Error")
        text = page.content.replace('\n', '')
        text = text.replace('\t', '')
        p_content.append(text)

    for _ in p_content:
        _ = _.split("===")
        texto = [i.replace('=', '') for i in _]
        aux = []
        for j in texto:
            if len(j) > 60:
                aux.append(j)
        content.append(aux)

    return content

def embeddings_t1(seg_txt, batch_size=12, max_tokens=8192):
    """
    Embed every text segment of one page with the shared BGE-M3 `model`.

    seg_txt = list of str; one entry of the list returned by get_content().
    batch_size = int; number of segments encoded per forward pass.
    max_tokens = int; upper bound applied to the tokenizer max_length.
        8192 is BGE-M3's documented maximum input length — verify if the model changes.

    Returns one embedding per segment, in input order. Each embedding is a list
    of the vector's components, the same element form the original per-string
    loop produced. An empty seg_txt returns [].
    """
    if not seg_txt:
        # max() over an empty sequence raises; a page with no long segments has nothing to embed.
        return []
    # Length in characters of the longest segment, presumably a loose upper
    # bound on its token count, capped at max_tokens.
    max_len = min(max(len(seg) for seg in seg_txt), max_tokens)
    # Encode all segments in one call so batch_size actually batches them
    # (previously each string was encoded on its own).
    dense = model.encode(
        list(seg_txt),
        batch_size = batch_size,
        max_length = max_len,
    )["dense_vecs"]
    return [list(vec) for vec in dense]

def get_content_embs(content_list):
    """
    Embed every page of a get_content() result.

    content_list = list; one list of text segments per page.

    Returns a list holding the embeddings_t1() result of each page, in the same order.
    """
    page_embeddings = []
    for page_segments in content_list:
        page_embeddings.append(embeddings_t1(page_segments))
    return page_embeddings

In [6]:
def ds_procesamiento(dataset, lang, page_total=3):
    """
    Add keyword and Wikipedia-title columns to a Mu-SHROOM dataframe.

    dataset = pd.DataFrame; must contain a "model_input" column with the questions.
    lang = str; 'es' or 'en', passed to noun_list() and get_wikipage().
    page_total = int; number of Wikipedia titles to retrieve per question.

    Adds two columns to `dataset` in place and returns the same object:
      - "output_filtrado": keyword query built from each question.
      - "Wiki Asociado": list of Wikipedia titles for that query ([] if none found).
    """
    queries = keyword_por_preg(noun_list(dataset["model_input"], lang))
    dataset["output_filtrado"] = queries

    # Query Wikipedia once per distinct keyword string; dict.fromkeys keeps
    # first-seen order so the reported problems appear deterministically.
    titles_by_query = {}
    for query in dict.fromkeys(queries):
        # A question with no extracted keywords gives an empty query; don't send it to the API.
        titles = get_wikipage(query, lang, page_total) if query.strip() else []
        if not titles:
            print(f"Problema con: {query}")
        titles_by_query[query] = titles

    # Build the column positionally from `queries`. The previous label-based
    # lookup over range(len(...)) failed for DataFrames whose index is not 0..n-1.
    dataset["Wiki Asociado"] = [titles_by_query[q] for q in queries]
    return dataset

In [7]:
%%time
# Extract keywords for every English question and attach the matching Wikipedia titles.
# Note: ds_procesamiento adds its columns to df_en itself and returns it, so aux_n is df_en.
aux_n = ds_procesamiento(df_en, 'en')
aux_n.head(3)

Problema con: Bischofsheim constitutent community Mainz 
CPU times: total: 2.28 s
Wall time: 20.2 s


Unnamed: 0,lang,model_id,model_input,model_output_text,model_output_logits,model_output_tokens,output_filtrado,Wiki Asociado
0,EN,togethercomputer/Pythia-Chat-Base-7B,Do all arthropods have antennae?,"Yes, all insects and arachnids (including spi...","[-2.57427001, 5.1865358353, 5.4173498154, 2.32...","[ĠYes, ,, Ġall, Ġinsects, Ġand, Ġar, ach, n, i...",arthropods,"[Arthropod, 2018 in arthropod paleontology, Ar..."
1,EN,togethercomputer/Pythia-Chat-Base-7B,Do all arthropods have antennae?,"Yes, all insects and arachnids have at least ...","[-2.57427001, 5.1865358353, 5.4173498154, 2.32...","[ĠYes, ,, Ġall, Ġinsects, Ġand, Ġar, ach, n, i...",arthropods,"[Arthropod, 2018 in arthropod paleontology, Ar..."
2,EN,togethercomputer/Pythia-Chat-Base-7B,Do all arthropods have antennae?,"Yes, all insects and arachnids (including spi...","[-2.57427001, 5.1865358353, 5.4173498154, 2.32...","[ĠYes, ,, Ġall, Ġinsects, Ġand, Ġar, ach, n, i...",arthropods,"[Arthropod, 2018 in arthropod paleontology, Ar..."


In [8]:
# Wikipedia titles retrieved for the first question (row label 0).
temp_wiki = aux_n["Wiki Asociado"][0]
temp_wiki

['Arthropod', '2018 in arthropod paleontology', 'Arthropod exoskeleton']

In [9]:
%%time
# Fetch and segment those pages, then embed every segment (one list per page).
content = get_content(temp_wiki)
content_embs = get_content_embs(content)
print(len(content), len(content_embs))

You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


3 3
CPU times: total: 5.64 s
Wall time: 4.58 s


In [8]:
# Hasta este punto tenemos un dataset con los outputs extraídos por pregunta, las páginas de wikipedia correspondientes
# Y los embeddings segmentados de dichas páginas.

# Tenemos que encontrar similitudes entre pregunta y top 5-10 párrafos segmentados.

# Tal vez tener todos los embeddings separados por página no es lo mejor

In [10]:
# Preg_Emb
def similarity(t1, t2):
    """
    Dot-product similarity between two embedding vectors, i.e. ``t1 @ t2``.

    If both vectors are L2-normalised (presumably true for BGE-M3 dense_vecs —
    verify), this equals their cosine similarity.
    """
    score = t1 @ t2
    return score

# Embed the first question of the dataset and compare it with the first
# segment of the first retrieved page.
preg = df_en["model_input"][0]

# A single string is passed, so dense_vecs is presumably one vector rather
# than a batch — TODO confirm against the installed FlagEmbedding version.
preg_emb = model.encode(
    preg,
    batch_size = 12,
    max_length = 1024,
)['dense_vecs']

print(similarity(preg_emb, content_embs[0][0]))

0.6157


In [19]:
# Flatten (segment text, segment embedding) pairs across all retrieved pages.
full = [
    pair
    for idx, page_texts in enumerate(content)
    for pair in zip(page_texts, content_embs[idx])
]

# Score every segment against the question embedding, most similar first.
simi_list = sorted(
    ((similarity(preg_emb, emb), text) for text, emb in full),
    reverse=True,
)
# simi_list now holds the segments ranked by relevance to the question.

[(0.6157, 'Arthropods ( ARTH-rə-pod) are invertebrates in the phylum Arthropoda. They possess an exoskeleton with a cuticle made of chitin, often mineralised with calcium carbonate, a body with differentiated (metameric) segments, and paired jointed appendages. In order to keep growing, they must go through stages of moulting, a process by which they shed their exoskeleton to reveal a new one. They form an extremely diverse group of up to ten million species.Haemolymph is the analogue of blood for most arthropods.  An arthropod has an open circulatory system, with a body cavity called a haemocoel through which haemolymph circulates to the interior organs. Like their exteriors, the internal organs of arthropods are generally built of repeated segments. They have ladder-like nervous systems, with paired ventral nerve cords running through all segments and forming paired ganglia in each segment. Their heads are formed by fusion of varying numbers of segments, and their brains are formed b