In [85]:
import pandas as pd

In [86]:
from transformers import AutoTokenizer, AutoModel
import torch

In [87]:
import os

In [88]:
from tqdm import tqdm

In [89]:
tokenizer = AutoTokenizer.from_pretrained('antoinelouis/biencoder-mMiniLMv2-L12-mmarcoFR')
model = AutoModel.from_pretrained('antoinelouis/biencoder-mMiniLMv2-L12-mmarcoFR')

In [90]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

In [91]:
def get_embeddings(chunck, model, tokenizer):
    """
    Get the embedding of a passage
    :param chunck: the passage
    :param model: the model
    :param tokenizer: the tokenizer
    :return:
    """
    # Tokenize sentences
    encoded_input = tokenizer(chunck, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    chunk_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    return chunk_embeddings.view(-1).numpy().tolist()

In [92]:
def read_file(filename):
    """
    To read the file
    :param filename: string the name of the file
    :return: string  the content of the file
    """
    with open(f'judilibre_json/data/{filename}') as file:
        content = file.read()
    return content

In [93]:
import nltk
nltk.download('punkt')

[nltk_data] Error loading punkt: <urlopen error [SSL:
[nltk_data]     CERTIFICATE_VERIFY_FAILED] certificate verify failed:
[nltk_data]     unable to get local issuer certificate (_ssl.c:1002)>


False

In [94]:
from nltk.tokenize import word_tokenize

In [95]:
def split_text_into_passages(text,tokenizer, prev_include=2):
    """
    Split text into passage with maximum number of char per passage and number of char of the previous
    passage to include in the following
    :param text:  string the text to split
    :param tokenizer: tokenizer utilise
    :param prev_include: int number of char of the previous to include in the following
    :return:
    """
    passages = []
    current_passage = []
    passage_ends = []
    lines = text.splitlines()

    for i, line in enumerate(lines):
        words = line.split()
        for j, word in enumerate(words):
            word_tokens = tokenizer.tokenize(word)
            current_passage_tokens = tokenizer.tokenize(" ".join(current_passage))
            if len(current_passage_tokens) + len(word_tokens) <= tokenizer.model_max_length:
                current_passage.append(word)
            else:
                passages.append(" ".join(current_passage))
                passage_ends.append(i + 1)
                if prev_include > 0:
                    current_passage = current_passage[-prev_include:]
                else:
                    current_passage = []

                current_passage.append(word)
    if current_passage:
        passages.append(" ".join(current_passage))
        passage_ends.append(len(lines))

    chunks = list(range(1, len(passages) + 1))

    df = pd.DataFrame({'chunck': chunks, 'line': passage_ends, 'passage': passages})
    return df


In [96]:
def get_all_tsv_file():
    """
    Get the name of all data file
    :return:
    """
    path = 'judilibre_json/data'
    tsv_files = []
    if os.path.exists(path) and os.path.isdir(path):
        tsv_files = [f for f in os.listdir(path) if f.endswith(".tsv")]
    return tsv_files

In [97]:
def generate_passages(tokenizer, prev_include=2):
    all_tsv_files = get_all_tsv_file()
    all_dfs = []

    for filename in tqdm(all_tsv_files, total=len(all_tsv_files), desc="Generating Passages", unit=" file"):
        id_dec, _ = os.path.splitext(filename)
        text = read_file(filename)
        df = split_text_into_passages(text, tokenizer, prev_include)
        n_rows, _ = df.shape
        df.insert(0, 'id_dec', [id_dec] * n_rows)
        all_dfs.append(df)

    df_judilibre_v = pd.concat(all_dfs, ignore_index=True)
    df_judilibre_v.to_csv(f'judilibre_v/judilibre_v_{len(all_tsv_files)}_fichiers.tsv', index=False)

    return df_judilibre_v

In [98]:
df_passages = generate_passages(tokenizer)

Generating Passages:  27%|██▋       | 30871/113426 [2:51:31<8:13:38,  2.79 file/s] Token indices sequence length is longer than the specified maximum sequence length for this model (134 > 128). Running this sequence through the model will result in indexing errors
Generating Passages: 100%|██████████| 113426/113426 [10:32:03<00:00,  2.99 file/s]  


In [99]:
def generate_embeddings(df, model, tokenizer):
    n_rows, _ = df.shape
    embeddings = []
    for i in tqdm(range(n_rows), desc="Generating Embeddings", unit=" passage"):
        chunk = df['passage'][i]
        embedding = get_embeddings(chunk, model, tokenizer)
        embeddings.append(embedding)
    df['embedding'] = embeddings
    df.to_csv(f'judilibre_v/judilibre_v.tsv', index=False)
    return df

In [100]:
#df_embeddings = generate_embeddings(df_passages, model, tokenizer)

In [101]:
df_passages

Unnamed: 0,id_dec,chunck,line,passage
0,JURITEXT63c8eef7dc5b777c90992fb9,1,10,COUR D'APPEL D'ORLÉANS Chambre des référés - P...
1,JURITEXT63c8eef7dc5b777c90992fb9,2,16,REPASS'CHIC prise en la personne de Me Delphin...
2,JURITEXT63c8eef7dc5b777c90992fb9,3,25,"deFatima HAJBI, greffier, Statuant en référé d..."
3,JURITEXT63c8eef7dc5b777c90992fb9,4,37,[E] 120 allée du Séquoia 45770 SARAN ni compar...
4,JURITEXT63c8eef7dc5b777c90992fb9,5,40,novembre 2022 L'avis du Ministère public a été...
...,...,...,...,...
4215708,JURITEXT6253ccdcbd3db21cbdd9185b,30,83,d'une prise d'acte de la rupture du contrat de...
4215709,JURITEXT6253ccdcbd3db21cbdd9185b,31,85,être fixée qu'en fonction du préjudice subi pa...
4215710,JURITEXT6253ccdcbd3db21cbdd9185b,32,88,prise d'acte de la rupture du contrat de trava...
4215711,JURITEXT6253ccdcbd3db21cbdd9185b,33,93,"une prise d'acte, et a condamné la Société ALM..."
