In [11]:
import pandas as pd
import numpy as np
import re
import json
from tqdm.autonotebook import tqdm
from typing import List, Dict, Any


from pymilvus import CollectionSchema, FieldSchema, DataType, MilvusClient

In [5]:
# Initialize Milvus client
client = MilvusClient(uri="http://localhost:19530")

In [6]:
client.list_collections()

['articles_collection_L2',
 'articles_collection_IP',
 'articles_collection_COSINE']

In [7]:
def benchmark_metric_types(entities, metric_types, embeddings_dim, batch_size):
    # Dictionary to store the created collections
    collections = {}

    # Iterate over different metric types (L2, IP, COSINE)
    for metric_type in metric_types:
        collection_name = f"articles_collection_{metric_type}"
        
        # Check if the collection already exists and drop it if necessary
        if client.has_collection(collection_name):
            print(f"Collection {collection_name} already exists. Dropping the collection...\n")
            client.drop_collection(collection_name)
            
        # Define the fields of the collection
        id_field = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True)
        text_field = FieldSchema(name="article", dtype=DataType.VARCHAR, max_length=65535)
        reference_field = FieldSchema(name="reference", dtype=DataType.VARCHAR, max_length=1000)
        embedding_field = FieldSchema(name="embedding_articles", dtype=DataType.FLOAT16_VECTOR, dim=embeddings_dim)
        
        # Define the schema
        schema = CollectionSchema(fields=[id_field, text_field, reference_field, embedding_field], description=f"Collection for {metric_type} benchmark")
        
        # Create the collection
        client.create_collection(collection_name=collection_name, schema=schema)
        
        # Insert all entities into the collection
        print(f"\nInserting entities into collection: {collection_name}")
        # client.insert(data=entities, collection_name=collection_name)
        for i in range(0, len(entities), batch_size):
            batch = entities[i:i + batch_size]
            try:
                client.insert(data=batch, collection_name=collection_name)
                print(f"Inserted batch {i // batch_size + 1} successfully.")
            except Exception as e:
                print(f"Error during insertion for batch {i // batch_size + 1}: {e}")

        
        # Create the index with the metric type (only using FLAT index)
        index_params = MilvusClient.prepare_index_params()
        index_params.add_index(
            field_name="embedding_articles",
            metric_type=metric_type,  # Change only the metric type (L2, IP, COSINE)
            index_type="FLAT",  # Always use FLAT index
            index_name="vector_index",
            params={}  # No additional parameters needed for FLAT index
        )
        
        # Create the index on the collection after inserting the entities
        try:
            client.create_index(
                collection_name=collection_name,
                index_params=index_params,
                sync=True  # Wait for index creation to complete
            )
        except Exception as e:
            print(f"Failed to create an index on collection: {collection_name}")
            print(e)
            continue  # Skip this index and continue with the next one

        # Store the created collection in the dictionary
        collections[metric_type] = collection_name

    print("\n\nBenchmark completed for all metric types.")
    # Return the dictionary containing the collections
    return collections


In [7]:
# Load the embeddings from the JSON file
with open('embeddings.json', 'r', encoding='utf-8') as f:
    loaded_embeddings = json.load(f)

In [8]:
def load_embeddings_from_file(filepath: str) -> Dict[int, np.ndarray]:
    try:
        # Load the npz file
        data = np.load(filepath)
        
        # Convert arrays back to dictionary
        embeddings_dict = {
            int(id_): emb for id_, emb in zip(data['ids'], data['embeddings'])
        }
        
        print(f"Successfully loaded {len(embeddings_dict)} embeddings")
        
        return embeddings_dict
        
    except Exception as e:
        print(f"Error loading embeddings: {str(e)}")
        raise

In [9]:
embeddings = load_embeddings_from_file('embeddings_bel.npz')


Successfully loaded 22633 embeddings


In [10]:
df_articles = pd.read_csv('articles.csv')

In [12]:
def process_embeddings_in_chunks(
    df_articles: pd.DataFrame,
    loaded_embeddings: Dict[int, np.ndarray],
    chunk_size: int = 1000
) -> List[Dict[str, Any]]:
    """
    Process embeddings in chunks to optimize memory usage and maintain consistent performance.
    
    Args:
        df_articles: DataFrame containing article information
        loaded_embeddings: Dictionary mapping article IDs to their embeddings
        chunk_size: Number of rows to process in each chunk
        
    Returns:
        List of processed entities
    """
    # Initialize empty list for all entities
    all_entities = []
    
    # Calculate total number of chunks
    total_chunks = (len(df_articles) + chunk_size - 1) // chunk_size
    
    # Process data in chunks with progress bar
    with tqdm(total=len(df_articles), desc="Processing entities") as pbar:
        for chunk_start in range(0, len(df_articles), chunk_size):
            # Get chunk of dataframe
            chunk_end = min(chunk_start + chunk_size, len(df_articles))
            df_chunk = df_articles.iloc[chunk_start:chunk_end]
            
            # Process chunk
            chunk_entities = [
                {
                    "id": row['id'],
                    "article": row['article'],
                    "reference": row['reference'],
                    "embedding_articles": np.array(loaded_embeddings[row['id']], dtype=np.float16)
                }
                for _, row in df_chunk.iterrows()
                if row['id'] in loaded_embeddings
            ]
            
            # Extend all_entities with chunk results
            all_entities.extend(chunk_entities)
            
            # Update progress bar
            pbar.update(len(df_chunk))
            
            # Optional: Clear memory
            del chunk_entities
    
    print(f"Processed {len(all_entities)} entities")
    
    return all_entities


In [13]:

# Process in chunks
processed_entities = process_embeddings_in_chunks(
    df_articles=df_articles,
    loaded_embeddings=embeddings,
    chunk_size=1000
)



Processing entities:   0%|          | 0/22633 [00:00<?, ?it/s]

Processed 22633 entities


In [14]:
metric_types = ["L2", "IP", "COSINE"]
index_types = ["FLAT"]
embeddings_dim = 1024  

collections = benchmark_metric_types(processed_entities, metric_types, embeddings_dim, batch_size=1000)

Collection articles_collection_L2 already exists. Dropping the collection...


Inserting entities into collection: articles_collection_L2
Inserted batch 1 successfully.
Inserted batch 2 successfully.
Inserted batch 3 successfully.
Inserted batch 4 successfully.
Inserted batch 5 successfully.
Inserted batch 6 successfully.
Inserted batch 7 successfully.
Inserted batch 8 successfully.
Inserted batch 9 successfully.
Inserted batch 10 successfully.
Inserted batch 11 successfully.
Inserted batch 12 successfully.
Inserted batch 13 successfully.
Inserted batch 14 successfully.
Inserted batch 15 successfully.
Inserted batch 16 successfully.
Inserted batch 17 successfully.
Inserted batch 18 successfully.
Inserted batch 19 successfully.
Inserted batch 20 successfully.
Inserted batch 21 successfully.
Inserted batch 22 successfully.
Inserted batch 23 successfully.
Collection articles_collection_IP already exists. Dropping the collection...


Inserting entities into collection: articles_collection_

### test search

In [15]:
from sentence_transformers import SentenceTransformer

model_bel = SentenceTransformer('Lajavaness/bilingual-embedding-large', trust_remote_code=True, device='cuda')

In [13]:
from FlagEmbedding import BGEM3FlagModel

bge_m3 = BGEM3FlagModel('BAAI/bge-m3',  
                       use_fp16=True, 
                       device='cuda')

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

  colbert_state_dict = torch.load(os.path.join(model_dir, 'colbert_linear.pt'), map_location='cpu')
  sparse_state_dict = torch.load(os.path.join(model_dir, 'sparse_linear.pt'), map_location='cpu')


In [16]:
def generate_embedding(article):
    embedding = bge_m3.encode([article], batch_size=12, max_length=8*1024)["dense_vecs"]
    return embedding[0]

In [27]:
from sentence_transformers import SentenceTransformer
import numpy as np
from typing import Union, List
import torch

def generate_embedding(
    article: Union[str, List[str]],
    model: SentenceTransformer,
    batch_size: int = 32,
    max_length: int = 512,
    device: str = None
) -> np.ndarray:
    """
    Generate embeddings using SentenceTransformer model
    
    Args:
        article: Single article or list of articles
        model: SentenceTransformer model instance
        batch_size: Batch size for processing
        max_length: Maximum sequence length
        device: Computing device ('cuda' or 'cpu')
        
    Returns:
        Numpy array of embeddings
    """
    # Set device
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Ensure model is on correct device
    model = model.to(device)
    
    # Convert single article to list if necessary
    if isinstance(article, str):
        articles = [article]
    else:
        articles = article
    
    # Generate embeddings
    with torch.no_grad():
        embeddings = model.encode(
            articles,
            batch_size=batch_size,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=True,  # Normalize for cosine similarity
            max_length=max_length,
            device=device
        )

    # Convert embeddings to float16
    embeddings = embeddings.astype(np.float16)
    
    # Return single embedding if input was single article
    if isinstance(article, str):
        return embeddings[0]
    
    return embeddings

In [19]:
df_questions = pd.read_csv('questions_train.csv')

#concat df_questions['question'] + df_questions['extra_description']

df_questions['complet_question'] = df_questions['question'] + df_questions['extra_description']

In [28]:

# For single article
# article = "Your article text here"
query_vector = generate_embedding(
    article= df_questions['complet_question'].iloc[0],
    model=model_bel,
    batch_size=32
)

# # For multiple articles
# articles = ["Article 1", "Article 2", "Article 3"]
# embeddings = generate_embedding(
#     article=articles,
#     model=model_bel,
#     batch_size=32
# )

# # Process DataFrame column
# df_embeddings = generate_embedding(
#     article=df_articles['articles'].tolist(),
#     model=model_bel,
#     batch_size=32
# )


In [29]:
query_vector

array([-0.03096 , -0.004353,  0.03195 , ..., -0.0366  , -0.02815 ,
       -0.0423  ], dtype=float16)

In [24]:
df_questions = pd.read_csv('questions_train.csv')

#concat df_questions['question'] + df_questions['extra_description']

df_questions['complet_question'] = df_questions['question'] + df_questions['extra_description']

# Vecteur de la question à tester
query_vector = generate_embedding(df_questions['complet_question'].iloc[0])

In [25]:


# Paramètres de recherche uniquement pour FLAT, en variant les métriques
search_params = {
    'L2': {"metric_type": "L2", "params": {}},      # Distance euclidienne
    'COSINE': {"metric_type": "COSINE", "params": {}},  # Distance cosinus
    'IP': {"metric_type": "IP", "params": {}},      # Produit scalaire
}

# Résultats de performance
performance_results = {}

# Boucle pour tester chaque métrique
for metric_type, collection_name in collections.items():
    print(f"Testing collection: {collection_name} with metric type: {metric_type}")

    # Charger les partitions pour la collection
    client.load_collection(collection_name=collection_name)

    # Récupérer les informations sur l'index
    index_info = client.describe_index(collection_name=collection_name, index_name="vector_index")
    print(f"Index info for {metric_type}: {index_info}")

    # Effectuer la recherche avec la métrique correspondante
    search_results = client.search(
        collection_name=collection_name,
        data=[query_vector],
        limit=3,
        search_params=search_params[metric_type],  # Utiliser les paramètres de la métrique en cours
        output_fields=['id', 'reference']
    )

    # Formater et afficher les résultats
    formatted_result = json.dumps(search_results[0], indent=3, ensure_ascii=False)
    print(f"Results for {metric_type}:\n{formatted_result}")
    
    # Enregistrer les résultats dans le dictionnaire
    performance_results[metric_type] = search_results


Testing collection: articles_collection_L2 with metric type: L2
Index info for L2: {'index_type': 'FLAT', 'metric_type': 'L2', 'field_name': 'embedding_articles', 'index_name': 'vector_index', 'total_rows': 22633, 'indexed_rows': 22633, 'pending_index_rows': 0, 'state': 'Finished'}
Results for L2:
[
   {
      "id": 18554,
      "distance": 0.626326322555542,
      "entity": {
         "id": 18554,
         "reference": "Art. 392, Code de la Fonction Publique Wallonne (Livre III, Chapitre IV)"
      }
   },
   {
      "id": 22233,
      "distance": 0.7632614374160767,
      "entity": {
         "id": 22233,
         "reference": "Art. X.5-9, Code du Bien-être au Travail (Livre X, Titre 5)"
      }
   },
   {
      "id": 22225,
      "distance": 0.8184471726417542,
      "entity": {
         "id": 22225,
         "reference": "Art. X.5-1, Code du Bien-être au Travail (Livre X, Titre 5)"
      }
   }
]
Testing collection: articles_collection_IP with metric type: IP
Index info for IP: {'i

In [26]:
labels = df_questions['article_ids'].iloc[0]
print(f"Labels: {labels} \n")


# Affichage des résultats pour chaque index
for index_type, result in performance_results.items():
    list_ids = [result["entity"]["id"] for result in result[0]]
    list_prods = [result["distance"] for result in result[0]]
    print(f"metric Type: {index_type} - results: {list_ids} - distances: {list_prods}")

    print('\n')


Labels: 22225,22226,22227,22228,22229,22230,22231,22232,22233,22234 

metric Type: L2 - results: [18554, 22233, 22225] - distances: [0.626326322555542, 0.7632614374160767, 0.8184471726417542]


metric Type: IP - results: [18554, 22233, 22225] - distances: [0.6871737241744995, 0.6182908415794373, 0.5909006595611572]


metric Type: COSINE - results: [18554, 22233, 22225] - distances: [0.6869423985481262, 0.6183393597602844, 0.5908273458480835]




---------------

In [30]:


# Paramètres de recherche uniquement pour FLAT, en variant les métriques
search_params = {
    'L2': {"metric_type": "L2", "params": {}},      # Distance euclidienne
    'COSINE': {"metric_type": "COSINE", "params": {}},  # Distance cosinus
    'IP': {"metric_type": "IP", "params": {}},      # Produit scalaire
}

# Résultats de performance
performance_results = {}

# Boucle pour tester chaque métrique
for metric_type, collection_name in collections.items():
    print(f"Testing collection: {collection_name} with metric type: {metric_type}")

    # Charger les partitions pour la collection
    client.load_collection(collection_name=collection_name)

    # Récupérer les informations sur l'index
    index_info = client.describe_index(collection_name=collection_name, index_name="vector_index")
    print(f"Index info for {metric_type}: {index_info}")

    # Effectuer la recherche avec la métrique correspondante
    search_results = client.search(
        collection_name=collection_name,
        data=[query_vector],
        limit=3,
        search_params=search_params[metric_type],  # Utiliser les paramètres de la métrique en cours
        output_fields=['id', 'reference']
    )

    # Formater et afficher les résultats
    formatted_result = json.dumps(search_results[0], indent=3, ensure_ascii=False)
    print(f"Results for {metric_type}:\n{formatted_result}")
    
    # Enregistrer les résultats dans le dictionnaire
    performance_results[metric_type] = search_results


Testing collection: articles_collection_L2 with metric type: L2
Index info for L2: {'index_type': 'FLAT', 'metric_type': 'L2', 'field_name': 'embedding_articles', 'index_name': 'vector_index', 'total_rows': 22633, 'indexed_rows': 22633, 'pending_index_rows': 0, 'state': 'Finished'}
Results for L2:
[
   {
      "id": 18554,
      "distance": 0.6798719167709351,
      "entity": {
         "reference": "Art. 392, Code de la Fonction Publique Wallonne (Livre III, Chapitre IV)",
         "id": 18554
      }
   },
   {
      "id": 21092,
      "distance": 0.8355220556259155,
      "entity": {
         "reference": "Art. I.4-50, Code du Bien-être au Travail (Livre Ier, Titre 4, Chapitre V, Section 2)",
         "id": 21092
      }
   },
   {
      "id": 22225,
      "distance": 0.8465765714645386,
      "entity": {
         "reference": "Art. X.5-1, Code du Bien-être au Travail (Livre X, Titre 5)",
         "id": 22225
      }
   }
]
Testing collection: articles_collection_IP with metric type