In [1]:
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
from FlagEmbedding import BGEM3FlagModel
from pymilvus import connections, Collection, CollectionSchema, FieldSchema, DataType
import time
from tqdm import tqdm
import os

In [19]:
model_name = "onlplab/alephbert-base"
alephbert_model = AutoModel.from_pretrained(model_name, add_pooling_layer=False)
alephbert_tokenizer = AutoTokenizer.from_pretrained(model_name, clean_up_tokenization_spaces=True)

bgem3_model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [20]:
def get_alephbert_embedding(text):
    inputs = alephbert_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = alephbert_model(**inputs)
    return outputs.last_hidden_state[:, 0, :].numpy().flatten()

def get_bgem3_embedding(text):
    return bgem3_model.encode([text], batch_size=1, max_length=8192)['dense_vecs'][0]

def load_sentences(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        return [(line.strip(), os.path.basename(file_path)) for line in file if line.strip()]

In [21]:
# Connect to Milvus
connections.connect("default", host="localhost", port="19530")

alephbert_dim = 768
bgem3_dim = 1024

def create_collection(name, dim):
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
        FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
        FieldSchema(name="filename", dtype=DataType.VARCHAR, max_length=256),
    ]
    schema = CollectionSchema(fields, f"{name} embeddings")
    collection = Collection(name=name, schema=schema)
    index_params = {
        "metric_type": "L2",
        "index_type": "IVF_FLAT",
        "params": {"nlist": 1024}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection

alephbert_collection = create_collection("alephbert_test", alephbert_dim)
bgem3_collection = create_collection("bgem3_test", bgem3_dim)

In [22]:


def insert_embeddings(collection, texts_with_metadata, get_embedding_func):
    data = []
    for i, (text, filename) in enumerate(tqdm(texts_with_metadata, desc=f"Creating embeddings for {collection.name}")):
        embedding = get_embedding_func(text)
        data.append({
            "id": i,
            "text": text,
            "embedding": embedding.tolist(),
            "filename": filename,
        })
    
    collection.insert(data)
    
hebrew_texts_1 = load_sentences('ישראל_wikipedia.txt')
hebrew_texts_2 = load_sentences('תכנות מונחה עצמים_wikipedia.txt')
texts_3 = load_sentences('mathematics_wikipedia.txt')
text_src = [hebrew_texts_1, hebrew_texts_2, texts_3]
for i in text_src:   
    insert_embeddings(alephbert_collection, i, get_alephbert_embedding)
    insert_embeddings(bgem3_collection, i, get_bgem3_embedding)

Creating embeddings for alephbert_test: 100%|██████████| 358/358 [00:16<00:00, 22.05it/s]
Creating embeddings for bgem3_test: 100%|██████████| 358/358 [00:05<00:00, 63.98it/s]
Creating embeddings for alephbert_test: 100%|██████████| 98/98 [00:03<00:00, 25.10it/s]
Creating embeddings for bgem3_test: 100%|██████████| 98/98 [00:01<00:00, 57.36it/s]
Creating embeddings for alephbert_test: 100%|██████████| 211/211 [00:10<00:00, 19.22it/s]
Creating embeddings for bgem3_test: 100%|██████████| 211/211 [00:03<00:00, 62.60it/s]


In [23]:
def similarity_search(collection, query_text, get_embedding_func, top_k=5):
    start_time = time.time()
    query_embedding = get_embedding_func(query_text)
    extraction_time = time.time() - start_time
    
    collection.load()
    
    start_time = time.time()
    results = collection.search(
        data=[query_embedding.tolist()],
        anns_field="embedding",
        param={"metric_type": "L2", "params": {"nprobe": 10}},
        limit=top_k,
        output_fields=["text", "filename"]
    )
    search_time = time.time() - start_time
    
    print(f"Query embedding extraction time: {extraction_time:.4f} seconds")
    print(f"Milvus search time: {search_time:.4f} seconds")
    
    return results

In [24]:
query = "מתמטיקה דיסקרטית"

print("AlephBERT Results:")
alephbert_results = similarity_search(alephbert_collection, query, get_alephbert_embedding)
print(f"Top {len(alephbert_results[0])} similar sentences to '{query}':")
for i, result in enumerate(alephbert_results[0]):
    print(f"{i+1}. {result.entity.get('text')} (Distance: {result.distance})")
    print(f"   File: {result.entity.get('filename')}")

print("\nBGEM3 Results:")
bgem3_results = similarity_search(bgem3_collection, query, get_bgem3_embedding)
print(f"Top {len(bgem3_results[0])} similar sentences to '{query}':")
for i, result in enumerate(bgem3_results[0]):
    print(f"{i+1}. {result.entity.get('text')} (Distance: {result.distance})")
    print(f"   File: {result.entity.get('filename')}")

AlephBERT Results:
Query embedding extraction time: 0.0307 seconds
Milvus search time: 0.6455 seconds
Top 3 similar sentences to 'מתמטיקה דיסקרטית':
1. שפה (Distance: 280.8265686035156)
   File: ישראל_wikipedia.txt
2. היסטוריה (Distance: 317.5686340332031)
   File: ישראל_wikipedia.txt
3. היסטוריה (Distance: 317.5686340332031)
   File: תכנות מונחה עצמים_wikipedia.txt

BGEM3 Results:
Query embedding extraction time: 0.0711 seconds
Milvus search time: 0.9041 seconds
Top 3 similar sentences to 'מתמטיקה דיסקרטית':
1. Discrete mathematics (Distance: 0.3976154625415802)
   File: mathematics_wikipedia.txt
2. Discrete geometry (Distance: 0.5477763414382935)
   File: mathematics_wikipedia.txt
3. Discrete probability distributions (Distance: 0.5578590631484985)
   File: mathematics_wikipedia.txt


In [25]:
query = "Israel state"

print("AlephBERT Results:")
alephbert_results = similarity_search(alephbert_collection, query, get_alephbert_embedding)
print(f"Top {len(alephbert_results[0])} similar sentences to '{query}':")
for i, result in enumerate(alephbert_results[0]):
    print(f"{i+1}. {result.entity.get('text')} (Distance: {result.distance})")
    print(f"   File: {result.entity.get('filename')}")

print("\nBGEM3 Results:")
bgem3_results = similarity_search(bgem3_collection, query, get_bgem3_embedding)
print(f"Top {len(bgem3_results[0])} similar sentences to '{query}':")
for i, result in enumerate(bgem3_results[0]):
    print(f"{i+1}. {result.entity.get('text')} (Distance: {result.distance})")
    print(f"   File: {result.entity.get('filename')}")

AlephBERT Results:
Query embedding extraction time: 0.0260 seconds
Milvus search time: 0.0030 seconds
Top 3 similar sentences to 'Israel state':
1. גבולות ישראל (Distance: 225.12490844726562)
   File: ישראל_wikipedia.txt
2. Reality (Distance: 255.7299041748047)
   File: mathematics_wikipedia.txt
3. שם המדינה (Distance: 258.9053955078125)
   File: ישראל_wikipedia.txt

BGEM3 Results:
Query embedding extraction time: 0.1291 seconds
Milvus search time: 0.0030 seconds
Top 3 similar sentences to 'Israel state':
1. ישראל, סרטונים בערוץ היוטיוב (Distance: 0.8721590042114258)
   File: ישראל_wikipedia.txt
2. גבולות ישראל (Distance: 0.8847226500511169)
   File: ישראל_wikipedia.txt
3. שם המדינה (Distance: 0.8852741122245789)
   File: ישראל_wikipedia.txt


In [26]:
alephbert_collection.release()
bgem3_collection.release()
connections.disconnect("default")

In [27]:
#run to delete a collection

# from pymilvus import connections, utility
# 
# # Connect to Milvus
# connections.connect("default", host="localhost", port="19530")
# 
# def delete_collection(collection_name):
#     try:
#         # Check if the collection exists
#         if utility.has_collection(collection_name):
#             # Drop the collection
#             utility.drop_collection(collection_name)
#             print(f"Collection '{collection_name}' has been successfully deleted.")
#         else:
#             print(f"Collection '{collection_name}' does not exist.")
#     except Exception as e:
#         print(f"An error occurred while deleting the collection: {e}")
#     finally:
#         # Disconnect from Milvus
#         connections.disconnect("default")
# 
# # Usage
# collection_name_to_delete = "alephbert_test"  # Replace with your collection name
# delete_collection(collection_name_to_delete)