In [80]:
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import pickle
import pandas as pd

In [81]:
cases_df = pd.read_csv('../cases.csv')
cases_sample = cases_df[:10].dropna(subset=['body']).fillna('')
index_to_case = {i: row for i, row in enumerate(cases_sample.index)}

In [83]:
cases_sample.to_csv('cases_sample.csv', index=False)

In [84]:
with open('index_to_case.pkl', 'rb') as f:
    index_to_case = pickle.load(f)

FileNotFoundError: [Errno 2] No such file or directory: 'index_to_case.pkl'

In [9]:
sbert_model = SentenceTransformer("all-MiniLM-L6-v2")

In [29]:
def generate_query_embedding(user_query):
    """
    Generates query embeddings using SBERT.
    
    :param user_query: str, the user input query
    :return: np.ndarray, the query embeddings as vector representations
    """
    query_embedding = sbert_model.encode(user_query).astype("float32")

    return query_embedding

In [69]:
# Define example user query and create a query embedding
user_query = "I am looking for cases about election in Canada"
query_embedding = generate_query_embedding(user_query).reshape(1, -1)


In [34]:
# Load case embeddings from the np file
case_embeddings = np.load('sample_case_embeddings.npy').astype('float32')

In [76]:
# Create and add FAISS index to case embeddings
dimension = case_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(case_embeddings)
faiss.write_index(index, "case_index.faiss")

In [71]:
def faiss_search(query_embedding, top_n):
    """
    Searches FAISS index for the top N most similar cases to a query.

    :param query: str, the user query
    :param top_n: int, number of top matches to return
    :return: list of top matching cases with similarity scores

    """

    # Perform search in FAISS index
    faiss_result = index.search(query_embedding, top_n)

    return faiss_result

In [72]:
faiss_result_tuple = faiss_search(query_embedding, 1)
distances, faiss_indices = faiss_result_tuple

In [73]:
def faiss_results(distances, faiss_indices, index_to_case, cases_df): 
    results = {}

    for i in range(len(faiss_indices[0])):
        faiss_idx = faiss_indices[0][i] 
        dist = distances[0][i]
        
        original_row = index_to_case.get(faiss_idx)

        results[i] = {
            'faiss_index': faiss_idx,
            'distance': dist,
            'original_row': original_row,
            'case_id': cases_df.at[original_row, 'id'],
            'title': cases_df.at[original_row, 'title'],
            'url': cases_df.at[original_row, 'url']
        }
        print(results[i]['url'])
    return results

In [74]:
results = faiss_results(distances, faiss_indices, index_to_case, cases_sample)


https://participedia.net/case/1
