# In [None]:
import os
import json
from sentence_transformers import SentenceTransformer
import numpy as np
import faiss
from rank_bm25 import BM25Okapi
import pickle

# ---------CONFIGURATION---------
# Directory scanned for per-case JSON files; point this at your own data.
cases_folder = "D:/AIP/data/"

# ---------LOAD & EMBED CASES---------
# One sentence-embedding model is shared by indexing and by every query.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Parallel accumulators: entry i of each list describes the same case
# (its vector, its metadata record, and its BM25 token list).
embeddings, meta, bm25_corpus = [], [], []

def extract_text(case, min_chars=20):
    """Collect the searchable text of one case record into a single string.

    Reads ``case["casebody"]`` and gathers, in this order: the ``text`` of
    every opinion dict in ``opinions``, the ``head_matter`` string, and every
    string entry of ``parties``. Only strings strictly longer than
    ``min_chars`` characters are kept.

    Missing keys, explicit JSON ``null`` values and values of an unexpected
    type are skipped rather than raising, so one odd record cannot abort a
    whole indexing run.

    Args:
        case: Parsed case JSON; expected to be a dict.
        min_chars: Fragments must be longer than this many characters to be
            kept. Defaults to 20.

    Returns:
        The kept fragments joined with newlines, or ``""`` when none qualify.
    """
    def usable(value):
        # Guard the type first: len() on a number would raise, and a long
        # list would otherwise slip through and break the final join.
        return isinstance(value, str) and len(value) > min_chars

    # dict.get(key, default) only applies the default when the key is absent;
    # `"casebody": null` yields None, so the type is checked explicitly.
    casebody = case.get("casebody") if isinstance(case, dict) else None
    if not isinstance(casebody, dict):
        return ""

    text_parts = []

    # Opinion texts come first.
    opinions = casebody.get("opinions")
    if isinstance(opinions, (list, tuple)):
        text_parts.extend(
            item.get("text")
            for item in opinions
            if isinstance(item, dict) and usable(item.get("text"))
        )

    # Then the head matter, when present.
    head_matter = casebody.get("head_matter")
    if usable(head_matter):
        text_parts.append(head_matter)

    # Finally the party strings.
    parties = casebody.get("parties")
    if isinstance(parties, (list, tuple)):
        text_parts.extend(party for party in parties if usable(party))

    return "\n".join(text_parts)

# Read every *.json case file, embed its text and record it in the three
# parallel lists. The listing is sorted because os.listdir() makes no ordering
# guarantee, and the row number in the index is what maps back to `meta`.
for filename in sorted(os.listdir(cases_folder)):
    if not filename.endswith('.json'):
        continue
    try:
        with open(os.path.join(cases_folder, filename), 'r', encoding='utf-8') as f:
            case = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as load_error:
        # One corrupt file should not abort the whole indexing run.
        print(f"Skipping {filename}: could not parse as UTF-8 JSON ({load_error})")
        continue
    if not isinstance(case, dict):
        # The code below reads the record with .get(), so it must be an object.
        print(f"Skipping {filename}: top-level JSON value is not an object")
        continue
    full_text = extract_text(case)
    # Records with almost no text are not indexed.
    if len(full_text) > 50:
        embeddings.append(model.encode(full_text))
        meta.append({'id': case.get("id"), 'name': case.get("name"), 'filename': filename, 'text': full_text})
        # Same tokenisation as the query side: lower-case, whitespace split.
        bm25_corpus.append(full_text.lower().split())
    print(f"Processing {filename} | Text length: {len(full_text)}")

print(f"Total cases processed: {len(embeddings)}")
if not embeddings:
    print("No valid cases found. Check your folder and extraction logic.")
    # exit() is an interactive helper injected by the `site` module and is not
    # guaranteed to exist; SystemExit is always available. A non-zero status
    # marks the run as failed.
    raise SystemExit(1)

# Stack the per-case vectors into one (n_cases, dim) matrix. FAISS works on
# float32, so the dtype is made explicit (no copy when it already matches).
embeddings = np.vstack(embeddings).astype(np.float32, copy=False)

# ---------FAISS SEMANTIC INDEX---------
# Flat L2 index sized to the embedding width (last axis of the matrix),
# then filled with every case vector in row order.
index = faiss.IndexFlatL2(embeddings.shape[-1])
index.add(embeddings)

# ---------BM25 KEYWORD INDEX---------
# Okapi BM25 built from the per-case token lists gathered during loading.
bm25 = BM25Okapi(bm25_corpus)

# ---------SEARCH FUNCTIONS---------

def semantic_search(query, top_k=3):
    """Return the cases whose embeddings are closest to the query embedding.

    Args:
        query: Free-text search string.
        top_k: Maximum number of cases to return. Values <= 0 yield ``[]``.

    Returns:
        Up to ``top_k`` records from ``meta``, nearest first. Fewer are
        returned when fewer cases are indexed.
    """
    if top_k <= 0:
        return []
    # Never ask FAISS for more neighbours than there are cases: missing slots
    # come back labelled -1, and meta[-1] would silently repeat the last case.
    k = min(top_k, len(meta))
    if k == 0:
        return []
    query_emb = model.encode(query).reshape(1, -1)
    _, I = index.search(query_emb, k)
    # Defensive filter in case the index holds fewer vectors than `meta`.
    return [meta[i] for i in I[0] if i >= 0]

def keyword_search(query, top_k=3):
    """Return the cases with the highest BM25 score for the query.

    The query is tokenised the same way as the corpus: lower-cased and split
    on whitespace.

    Args:
        query: Free-text search string.
        top_k: Maximum number of cases to return. Values <= 0 yield ``[]``.

    Returns:
        Up to ``top_k`` records from ``meta``, best score first.
    """
    # Guard first: with top_k == 0 the slice [-0:] would select every case,
    # and a negative top_k would select an unrelated tail of the ranking.
    if top_k <= 0:
        return []
    tokenized_query = query.lower().split()
    scores = bm25.get_scores(tokenized_query)
    # Reverse the ascending argsort, then keep the leading top_k entries.
    top_indices = np.argsort(scores)[::-1][:top_k]
    return [meta[i] for i in top_indices]

def hybrid_search(query, top_k=3, weight_semantic=0.5, weight_keyword=0.5):
    """Rank cases by a weighted blend of semantic and BM25 keyword scores.

    Both score vectors are min-max normalised to [0, 1] before blending, so
    the two weights act on comparable scales.

    Args:
        query: Free-text search string.
        top_k: Maximum number of cases to return. Values <= 0 yield ``[]``.
        weight_semantic: Weight of the normalised negative L2 distance.
        weight_keyword: Weight of the normalised BM25 score.

    Returns:
        Up to ``top_k`` records from ``meta``, best blended score first.
    """
    # With top_k == 0 the slice [-0:] would return every case.
    if top_k <= 0:
        return []
    n_cases = len(meta)
    if n_cases == 0:
        return []

    query_emb = model.encode(query).reshape(1, -1)
    D, I = index.search(query_emb, n_cases)  # dists, indices

    # FAISS returns hits sorted by distance: D[0][j] belongs to case I[0][j],
    # not to case j. BM25 scores are in corpus order, so the distances are
    # scattered back into corpus order before the two are combined.
    found = I[0] >= 0  # -1 marks slots FAISS could not fill
    neg_dist = -D[0][found]  # negative distance (higher is better)
    # Cases FAISS did not return, if any, get the worst observed score.
    worst = neg_dist.min() if neg_dist.size else 0.0
    semantic_scores = np.full(n_cases, worst, dtype=np.float64)
    semantic_scores[I[0][found]] = neg_dist

    tokenized_query = query.lower().split()
    keyword_scores = np.asarray(bm25.get_scores(tokenized_query), dtype=np.float64)

    # Min-max normalise; the epsilon avoids 0/0 when all scores are equal.
    s_norm = (semantic_scores - semantic_scores.min()) / (np.ptp(semantic_scores) + 1e-9)
    k_norm = (keyword_scores - keyword_scores.min()) / (np.ptp(keyword_scores) + 1e-9)
    hybrid = weight_semantic * s_norm + weight_keyword * k_norm

    top_indices = np.argsort(hybrid)[::-1][:top_k]
    return [meta[i] for i in top_indices]

# ---------EXAMPLE USAGE---------
# Run one sample query through each search strategy and print the hits.
search_query = "insurance payout after accidental death"
for heading, search_fn in (
    ("\n--- Semantic Search ---", semantic_search),
    ("\n--- Keyword (BM25) Search ---", keyword_search),
    ("\n--- Hybrid Search ---", hybrid_search),
):
    print(heading)
    for r in search_fn(search_query):
        print(f"Case ID: {r['id']}, Name: {r['name']}, File: {r['filename']}")