In [2]:
import toons
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import sys


# --- Configuration ---
# Sentence-embedding model used to encode user queries. It must be the same
# model that produced the embeddings stored in CASES_FILE; otherwise query
# vectors and case vectors are not comparable.
MODEL_NAME = 'BAAI/bge-small-en-v1.5'
# Classified cases (TOON format) carrying precomputed per-section embeddings.
CASES_FILE = '9700cases_classified.toon'
# Per-case keys holding an embedding vector, one per document section.
# Order matters only for the row order inside each case's stacked matrix.
EMBEDDING_FIELDS = [
    f"{section}_embedding"
    for section in (
        "Issue",
        "Precedent_Analysis",
        "Analysis_of_the_law",
        "Fact",
        "Respondents_Argument",
        "Petitioners_Argument",
        "Courts_Reasoning",
        "Conclusion",
    )
]


def load_model():
    """
    Load the SentenceTransformer named by ``MODEL_NAME`` onto the CUDA device.

    Returns:
        The loaded ``SentenceTransformer`` instance.

    If construction fails for any reason, the error is printed along with an
    installation hint, and the process exits with status 1.
    """
    print(f"Loading model '{MODEL_NAME}' to GPU...")
    try:
        model = SentenceTransformer(MODEL_NAME, device='cuda')
    except Exception as e:
        # Top-level script boundary: report and abort. Use a non-zero status;
        # a bare sys.exit() would tell the shell the run succeeded.
        print(f"Error loading model: {e}")
        print("Please ensure 'sentence-transformers' and a CUDA-enabled PyTorch are installed.")
        sys.exit(1)
    print("Model loaded successfully.")
    return model


def load_cases(filename):
    """
    Read and parse the ``.toon`` file of classified cases.

    Args:
        filename: Path to the TOON file, read as UTF-8 text.

    Returns:
        Whatever ``toons.loads`` produces for the file. Presumably this is a
        list of case dicts, since callers index it by position and call
        ``.get`` on its entries; confirm against the data.

    Exits the process with status 1 after printing a message if the file does
    not exist or cannot be read or parsed.
    """
    print(f"Loading classified cases from {filename}...")
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            cases_data = toons.loads(f.read())
        print(f"Loaded {len(cases_data)} cases.")
        return cases_data
    except FileNotFoundError:
        print(f"FATAL ERROR: Cases file not found at {filename}")
        # Non-zero status so callers/shells can detect the failure.
        sys.exit(1)
    except Exception as e:
        print(f"Error loading cases file: {e}")
        sys.exit(1)


def build_search_index(cases_data):
    """
    Gather each case's usable embeddings into a single stacked NumPy matrix.

    This runs once at startup, so each query compares against ready-made
    arrays instead of re-reading the raw case records.

    Args:
        cases_data: Sequence of case records, as returned by ``load_cases``.

    Returns:
        A list with one dict per case that has at least one non-empty,
        list-valued field from ``EMBEDDING_FIELDS``:
          - 'case_index': position of that case in ``cases_data``
          - 'chunk_stack': ``np.vstack`` of those field embeddings
        Cases with no such field are left out of the index.
    """
    print("Building in-memory search index...")
    index_entries = []
    for position, record in enumerate(tqdm(cases_data, desc="Indexing cases")):
        vectors = [
            np.array(record[name])
            for name in EMBEDDING_FIELDS
            if name in record and isinstance(record[name], list) and len(record[name]) > 0
        ]
        if not vectors:
            # Nothing searchable for this case.
            continue
        index_entries.append({
            'case_index': position,  # back-reference into cases_data
            'chunk_stack': np.vstack(vectors),
        })

    print(f"Index built with {len(index_entries)} searchable cases.")
    return index_entries


def find_best_case(query, model, cases_data, search_index):
    """
    Return the case whose closest embedding chunk best matches ``query``.

    Each case is scored by the maximum cosine similarity between the query
    embedding and any row of its ``chunk_stack``. The highest-scoring case
    wins, and on a tie the case appearing first in ``search_index`` is kept.

    Args:
        query: Free-text query. Empty or whitespace-only strings are not
            embedded.
        model: Object with ``encode(text, show_progress_bar=...)`` returning
            the query embedding, e.g. a SentenceTransformer.
        cases_data: Case records indexed by each entry's ``case_index``.
        search_index: List of ``{'case_index': int, 'chunk_stack': ndarray}``
            entries, as built by ``build_search_index``.

    Returns:
        ``(title, category, score)`` for the best case, where ``score`` is a
        Python float. Returns ``(None, None, 0.0)`` when the query is blank
        or the index is empty.
    """
    if not query or (isinstance(query, str) and not query.strip()):
        return None, None, 0.0

    # Embed and normalise the query once, rather than once per case.
    query_vec = np.asarray(
        model.encode(query, show_progress_bar=False), dtype=np.float64
    ).ravel()
    query_norm = np.linalg.norm(query_vec)
    # Like sklearn's cosine_similarity, a zero vector is left unscaled, so its
    # similarities come out as 0 rather than NaN.
    unit_query = query_vec / query_norm if query_norm else query_vec

    # Start at -inf so a case whose best similarity is exactly -1 can still win.
    best_score = -np.inf
    best_case_index = -1

    for item in search_index:
        chunks = np.asarray(item['chunk_stack'], dtype=np.float64)
        chunk_norms = np.linalg.norm(chunks, axis=1)
        chunk_norms[chunk_norms == 0] = 1.0  # zero rows score 0, as in sklearn
        max_sim_in_case = float(np.max((chunks @ unit_query) / chunk_norms))

        # Strict '>' keeps the earliest case on ties.
        if max_sim_in_case > best_score:
            best_score = max_sim_in_case
            best_case_index = item['case_index']

    if best_case_index == -1:
        # Same falsy shape as the blank-query case, so callers testing the
        # title do not mistake "no match" for a real match.
        return None, None, 0.0

    best_case = cases_data[best_case_index]
    return best_case.get('Title'), best_case.get('case_category'), best_score


def main():
    """
    Run the interactive command-line search loop.

    The model, the classified cases, and the search index are loaded once.
    The loop then repeatedly reads a query and prints the best match's title,
    category, and similarity score. It ends on 'q' (any case, surrounding
    whitespace ignored), on Ctrl-C, or when stdin is closed (EOF).
    Any other error during a query is printed, and the loop continues.
    """
    model = load_model()
    cases_data = load_cases(CASES_FILE)
    search_index = build_search_index(cases_data)

    print("\n--- Legal Case Search Engine ---")
    print("Enter a query (case title, fact, etc.) to find its category.")

    while True:
        try:
            query = input("\nQuery (or 'q' to quit): ").strip()
            if query.lower() == 'q':
                print("Exiting...")
                break

            title, category, score = find_best_case(query, model, cases_data, search_index)

            if title:
                print("\n--- Best Match Found ---")
                print(f" Title: {title}")
                print(f" Category: {category}")
                print(f" Similarity Score: {score:.4f}")
            else:
                print("Could not find a match.")

        except (KeyboardInterrupt, EOFError):
            # On EOF (piped or closed stdin), input() raises EOFError every
            # time. The generic handler below would catch it and loop
            # forever, so treat it as a request to quit.
            print("\nExiting...")
            break
        except Exception as e:
            print(f"An error occurred: {e}")


if __name__ == "__main__":
    main()

Loading model 'BAAI/bge-small-en-v1.5' to GPU...
Model loaded successfully.
Loading classified cases from 9700cases_classified.toon...
Loaded 9760 cases.
Building in-memory search index...


Indexing cases: 100%|██████████| 9760/9760 [00:01<00:00, 6068.47it/s]


Index built with 9760 searchable cases.

--- Legal Case Search Engine ---
Enter a query (case title, fact, etc.) to find its category.
Exiting...
