# SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking

In [2]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from splade.models.transformer_rep import Splade

import pandas as pd
import os
import ast

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# set the dir for trained weights

##### v2
# model_type_or_dir = "naver/splade_v2_max"
# model_type_or_dir = "naver/splade_v2_distil"

### v2bis, directly download from Hugging Face
# model_type_or_dir = "naver/splade-cocondenser-selfdistil"
model_type_or_dir = "naver/splade-cocondenser-ensembledistil"

In [4]:
# loading model and tokenizer

model = Splade(model_type_or_dir, agg="max")
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}

In [5]:
def get_splade_bow(doc, top=None):
    # now compute the document representation
    with torch.no_grad():
        doc_rep = model(d_kwargs=tokenizer(doc, return_tensors="pt"))["d_rep"].squeeze()  # (sparse) doc rep in voc space, shape (30522,)

    # get the number of non-zero dimensions in the rep:
    col = torch.nonzero(doc_rep).squeeze().cpu().tolist()
    #print("number of actual dimensions: ", len(col))

    # now let's inspect the bow representation:
    weights = doc_rep[col].cpu().tolist()
    d = {k: v for k, v in zip(col, weights)}
    sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
    
    bow_rep = []
    
    sorted_d_items = list(sorted_d.items())[:top] if top is not None else sorted_d.items()
    
    for k, v in sorted_d_items:
        bow_rep.append((reverse_voc[k], round(v, 2)))
    return bow_rep

In [6]:
if not os.getcwd().endswith("\\_data_mounir\\DEFT2021"):
    os.chdir("_data_mounir\\DEFT2021")

In [None]:
def get_next_phrase(text:str, idx_start=None):
    idx_point = text.find(".", idx_start)
    
    if idx_start is None:
        idx_start = 0
    return text[idx_start:idx_point+1], idx_point+1 

def diviser_passages(text_fr) -> list[str]:
    passages = []
    indice = 0

    while indice < len(text_fr):
        
        passage = str()
        while(len(passage) < 512) and indice < len(text_fr):
        
            indice_precedent = indice
        
            phrase, indice = get_next_phrase(text_fr, indice)
            if len(passage) + len(phrase) < 512: 
                passage += " " + phrase
            else:
                indice = indice_precedent
                break
        passages.append(passage)

    return passages

# Obsolètes

Avec blocs de 2 phrases

In [14]:
for filename in os.listdir("txt_trad"):
    with open (os.path.join("txt_trad", filename), "r") as f:
        text_eng = f.read()
    
    doc_split = text_eng.split(sep='.')
    doc_split2 = [doc_split[i]+doc_split[i+1] for i in range(0, len(doc_split)-1, 2)]
    
    bows = []
    for passage in doc_split2:
        bows.append(get_splade_bow(passage, top=20))
        
    df = pd.DataFrame(zip(doc_split2, bows), columns=["text", "bow_rep"])
    df.text = df.text.apply(lambda x: x.replace("\n",""))
    
    df.to_csv(filename+"-vectorized.csv")
    



Avec blocs de 512 caracteres max

In [16]:
for filename in os.listdir("txt_trad"):
    with open (os.path.join("txt_trad", filename), "r") as f:
        text_eng = f.read()
    
    
    doc_split2 = diviser_passages(text_eng)
    
    bows = []
    for passage in doc_split2:
        bows.append(get_splade_bow(passage, top=20))
        
    df = pd.DataFrame(zip(doc_split2, bows), columns=["text", "bow_rep"])
    df.text = df.text.apply(lambda x: x.replace("\n",""))
    
    df.to_csv(filename+"-vectorized2.csv")
    



# Extraction

Extraire les 50 termes les plus représentatifs pour chaque passage

In [29]:
for filename in os.listdir("txt_trad"):
    with open (os.path.join("txt_trad", filename), "r") as f:
        text_eng = f.read()
    
    
    doc_split2 = diviser_passages(text_eng)
    
    bows = []
    for passage in doc_split2:
        bows.append(get_splade_bow(passage, top=50))
        
    df = pd.DataFrame(zip(doc_split2, bows), columns=["text", "bow_rep"])
    df.text = df.text.apply(lambda x: x.replace("\n",""))
    
    #df.to_csv(filename+"-vectorized3.csv")
    



Regrouper les termes des passages d'un même document dans un seul vecteur (on fait l'union de tout) (Quand les termes se répètent dans des passages différents, on prend celui avec le meilleur score)

In [96]:
if not os.path.exists("txt_trad_vectorized3all"):
    os.mkdir("txt_trad_vectorized3all")


for file in os.listdir("txt_trad_vectorized3"):
    bow_rep_all = set()
    df = pd.read_csv(os.path.join("txt_trad_vectorized3", file))
    df.bow_rep = df.bow_rep.apply(ast.literal_eval)
    
    for bow in df.bow_rep:
        bow_rep_all = bow_rep_all.union(bow)
    
    # Regrouper les termes identiques et considérer uniquement l'instance du terme avec le plus grand score
    df = pd.DataFrame(bow_rep_all, columns=["term", "weight"])
    idx = df.groupby('term')['weight'].idxmax()
    # Sélectionner les lignes correspondantes dans le DataFrame original
    df_clean = df.loc[idx].reset_index(drop=True)
    df.to_csv(os.path.join("txt_trad_vectorized3all", file+"-vectorized-all.csv"))
        
        
    

In [66]:
bow_rep_all

{('##ac', 0.92),
 ('##ain', 1.46),
 ('##ani', 1.22),
 ('##bag', 1.29),
 ('##bar', 1.26),
 ('##bet', 1.74),
 ('##bin', 0.91),
 ('##bit', 1.87),
 ('##ble', 0.92),
 ('##cation', 1.05),
 ('##cial', 1.15),
 ('##cr', 1.18),
 ('##cture', 1.43),
 ('##dic', 1.92),
 ('##dine', 2.27),
 ('##ea', 0.89),
 ('##ech', 1.41),
 ('##ection', 1.66),
 ('##elial', 1.3),
 ('##elo', 1.57),
 ('##ema', 1.44),
 ('##ema', 1.68),
 ('##ema', 1.79),
 ('##emia', 1.39),
 ('##eno', 1.72),
 ('##ergy', 0.89),
 ('##esis', 1.53),
 ('##eter', 1.56),
 ('##eth', 1.63),
 ('##eu', 1.72),
 ('##fa', 1.01),
 ('##far', 1.68),
 ('##fusion', 1.57),
 ('##gia', 1.41),
 ('##ginal', 1.4),
 ('##gonal', 1.57),
 ('##gram', 1.37),
 ('##gui', 2.22),
 ('##he', 1.24),
 ('##hg', 1.54),
 ('##hoe', 1.2),
 ('##hra', 1.83),
 ('##hyl', 1.08),
 ('##ia', 1.33),
 ('##ical', 1.47),
 ('##id', 1.21),
 ('##ina', 1.27),
 ('##ina', 1.31),
 ('##ination', 1.54),
 ('##ino', 1.52),
 ('##it', 1.41),
 ('##it', 1.71),
 ('##ital', 1.59),
 ('##ium', 1.64),
 ('##ivated'

In [86]:
df = pd.DataFrame(bow_rep_all, columns=["term", "weight"])
idx = df.groupby('term')['weight'].idxmax()
# Sélectionner les lignes correspondantes dans le DataFrame original
df_clean = df.loc[idx].reset_index(drop=True)
df_clean.head()

Unnamed: 0,term,weight
0,##ac,0.92
1,##ain,1.46
2,##ani,1.22
3,##bag,1.29
4,##bar,1.26


## Vectoriser une requête et calculer les similarités

In [101]:
query = "urine kidney blood pressure"

#filepdf-190-cas.txt-vectorized.csv

query_vectorized = get_splade_bow(passage, top=50)
df_query = pd.DataFrame(query_vectorized, columns=["term", "weight"])
df_query



Unnamed: 0,term,weight
0,##gm,2.26
1,cd,2.25
2,##fusion,1.82
3,distress,1.8
4,i,1.68
5,19,1.58
6,normal,1.58
7,respiratory,1.48
8,##mun,1.46
9,##bu,1.28


In [107]:
path = "txt_trad_vectorized3all"
for file in os.listdir(path):
    df = pd.read_csv(os.path.join(path, file))
    df_termes_communs = df[df['term'].isin(df_query.term.tolist())]
    


In [127]:
for term, weight in df_query.itertuples(index=False):
    print(term)

##gm
cd
##fusion
distress
i
19
normal
respiratory
##mun
##bu
%
im
max
l
breathing
dose
severe
grams
percent
g
cds
##lin
die
death
patient
17
mg
lung
##og
ml
injection
given
total
##lins
/
equation
died
administered
patients
level
calculated
weight
calculate
distressed
##lo
test
ratio
liter
##ulin
concentration


In [130]:
path = "txt_trad_vectorized3all"
scores_files = []
for file in os.listdir(path):
    score_doc = 0
    df = pd.read_csv(os.path.join(path, file))
    # Transformer le DataFrame en dictionnaire avec les colonnes
    dict = df.set_index('term')['weight'].to_dict()
    
    for term, weight in df_query.itertuples(index=False):
      score_mot_doc = dict.get(term, 0)
      score_doc += score_mot_doc*weight
    scores_files.append((file, score_doc))


In [131]:
scores_files

[('filepdf-119-cas.txt-vectorized3.csv-vectorized-all.csv', 9.4788),
 ('filepdf-144-2-cas.txt-vectorized3.csv-vectorized-all.csv',
  9.076799999999999),
 ('filepdf-176-cas.txt-vectorized3.csv-vectorized-all.csv', 16.3914),
 ('filepdf-190-cas.txt-vectorized3.csv-vectorized-all.csv', 4.002400000000001),
 ('filepdf-263-3-cas.txt-vectorized3.csv-vectorized-all.csv',
  3.7336000000000005),
 ('filepdf-32-2-cas.txt-vectorized3.csv-vectorized-all.csv',
  8.533999999999999),
 ('filepdf-472-cas.txt-vectorized3.csv-vectorized-all.csv', 18.536),
 ('filepdf-700-cas.txt-vectorized3.csv-vectorized-all.csv', 4.1617),
 ('filepdf-702-2-cas.txt-vectorized3.csv-vectorized-all.csv',
  4.8629999999999995),
 ('filepdf-705-cas.txt-vectorized3.csv-vectorized-all.csv',
  22.185299999999994),
 ('filepdf-798-5-cas.txt-vectorized-all.csv', 63.2388),
 ('filepdf-798-5-cas.txt-vectorized3.csv-vectorized-all.csv', 63.2388)]

In [133]:
sorted_scores = sorted(scores_files, key=lambda x: x[1], reverse=True)
sorted_scores

[('filepdf-798-5-cas.txt-vectorized-all.csv', 63.2388),
 ('filepdf-798-5-cas.txt-vectorized3.csv-vectorized-all.csv', 63.2388),
 ('filepdf-705-cas.txt-vectorized3.csv-vectorized-all.csv',
  22.185299999999994),
 ('filepdf-472-cas.txt-vectorized3.csv-vectorized-all.csv', 18.536),
 ('filepdf-176-cas.txt-vectorized3.csv-vectorized-all.csv', 16.3914),
 ('filepdf-119-cas.txt-vectorized3.csv-vectorized-all.csv', 9.4788),
 ('filepdf-144-2-cas.txt-vectorized3.csv-vectorized-all.csv',
  9.076799999999999),
 ('filepdf-32-2-cas.txt-vectorized3.csv-vectorized-all.csv',
  8.533999999999999),
 ('filepdf-702-2-cas.txt-vectorized3.csv-vectorized-all.csv',
  4.8629999999999995),
 ('filepdf-700-cas.txt-vectorized3.csv-vectorized-all.csv', 4.1617),
 ('filepdf-190-cas.txt-vectorized3.csv-vectorized-all.csv', 4.002400000000001),
 ('filepdf-263-3-cas.txt-vectorized3.csv-vectorized-all.csv',
  3.7336000000000005)]