#  <span style="color: #4daafc">Legal Case Similarity Detection - RAG</span>

- [Environment](#environment)
- [Load data](#load-data)
- [Hybrid search](#hybrid-search)
- [RAG QA Chat](#rag-qa-chat)

# Environment

### Prepare python environment

In [1]:
import numpy as np
from utils.embedding import get_embedding_handler
from utils.db import load_vector_db
from utils.str_utils import str_to_arr
from langchain.schema import Document
from sklearn.metrics.pairwise import pairwise_distances
from langchain_core.tools import tool
from typing import Dict, Optional
from dotenv import load_dotenv
import os
import boto3
from botocore.config import Config
from langchain_aws.chat_models import ChatBedrock
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain.vectorstores.base import VectorStoreRetriever
from langchain.memory import ConversationBufferMemory

# ignore warnings
import warnings
from sklearn.exceptions import DataConversionWarning

In [2]:
# Silence only scikit-learn's DataConversionWarning (presumably emitted when pairwise_distances casts the
# sparse vectors to bool for the Jaccard metric in hybrid_search - confirm); all other warnings stay visible.
warnings.filterwarnings('ignore', category=DataConversionWarning)

### Global variables

In [3]:
# Embedding server endpoint and embedding model tag; both are passed to get_embedding_handler below.
# The model must be the one the stored vector DB was built with, otherwise dense scores are meaningless.
base_url = 'http://localhost:11434'
emb_model = "nomic-embed-text:latest"

# Load data

### Load dense vector DB (FAISS)

In [4]:
# Location of the saved dense (FAISS) vector DB, loaded via utils.db.load_vector_db
db_name = "db/vectors/db_legal_cases_summary_100"
# Embedding handler for the configured model/server; hybrid_search reuses it to re-embed a stored case
embedding_handler = get_embedding_handler(model=emb_model, base_url=base_url)
# NOTE(review): trust_source=True presumably opts into deserializing the stored docstore (pickle-based
# in LangChain's FAISS) - only load DB files you produced yourself. Confirm in utils.db.
vector_store = load_vector_db(db_name, embedding_handler, trust_source=True)
#retriever = vector_store.as_retriever(search_type = 'similarity', search_kwargs = {'k': 3})
# Retriever used by the legal_rag_qa tool; k=5 documents are fetched per query
retriever = vector_store.as_retriever(search_kwargs = {'k': 5})

load_vector_db: successfully loaded vector db from db/vectors/db_legal_cases_summary_100


In [5]:
# Inverse of vector_store.index_to_docstore_id: docstore UUID -> FAISS index position.
# hybrid_search uses it to locate documents inside the FAISS index and the sparse-vector matrix.
docstore_id_to_indx = {docstore_uuid: faiss_pos for faiss_pos, docstore_uuid in vector_store.index_to_docstore_id.items()}
# peek at the first 10 entries (the cell's last expression is what the notebook displays)
list(docstore_id_to_indx.items())[:10]

[('e4b649c3-1e33-47e2-a9da-3785d1b98105', 0),
 ('5c409d5e-9449-4af6-84f8-676cf1150094', 1),
 ('884a094b-0e06-49b2-b976-b70692aa5a97', 2),
 ('d9787e86-b812-4256-ab32-e8c1f6735e48', 3),
 ('691ee6a2-c4be-470c-9e17-a0b79cf213d2', 4),
 ('af0dfd09-d4fd-400a-8f59-1677d34c98f3', 5),
 ('3591af72-0202-4f09-bd76-338ab6aa8b81', 6),
 ('b0efc033-aeda-4db2-b5f2-a865869dd280', 7),
 ('0ac4dbe2-508d-47e3-8212-502b514aa060', 8),
 ('27a82aa3-4369-4577-8f3c-08e4c87a4c90', 9)]

### Load sparse vectors (One-hot)

In [6]:
# Sparse (binary) legal-reference vectors, one row per stored document.
# hybrid_search indexes these rows with FAISS index positions, so the matrix must be 2-D and have
# exactly one row per document in the dense store - otherwise sparse scores silently get attached
# to the wrong cases. Fail fast here instead.
f_path_sparse_vec = "db/vectors/sparse_vectors.npy"
sparse_vectors = np.load(f_path_sparse_vec)
if sparse_vectors.ndim != 2 or sparse_vectors.shape[0] != len(vector_store.index_to_docstore_id):
    raise ValueError(
        f"{f_path_sparse_vec} has shape {sparse_vectors.shape}, but the dense vector DB holds "
        f"{len(vector_store.index_to_docstore_id)} documents; expected a 2-D array with one row per "
        f"document in FAISS index order."
    )

# Hybrid search

### Helper function to find a legal case (aka document) by its case number

In [7]:
def get_doc_by_id(doc_id: str) -> Optional[Document]:
    """Return the first stored document whose metadata "case_number" equals doc_id, or None.

    Performs a linear scan over the vector store's in-memory docstore.
    """
    for candidate in vector_store.docstore._dict.values():
        if candidate.metadata.get("case_number") == doc_id:
            return candidate
    return None

### Define hybrid search tool function

In [8]:
@tool
def hybrid_search(doc_id: str, alpha: float = 0.5, k: int = 5, sparse_exp: float = 0.1) -> Dict[str, float] | str:
    """
    Find similar documents/legal cases in the database.

    Args:
        doc_id (str): document identifier (legal case number). Example: 3015/09
        alpha (float): Weighting factor for dense vs sparse search (0.0 = only sparse, 1.0 = only dense).
            Must be between 0.0 and 1.0.
        k (int): Number of top results to return. Must be at least 1.
        sparse_exp (float): Power exponent to apply to sparse scores to boost their impact.
            Values < 1 (e.g., 0.5) boost small similarity scores; values > 1 reduce them. Must be greater than 0.

    Returns:
        Dict[str, float]: Mapping of case_number to hybrid similarity score, best match first.
            The queried case itself is never included and each case number appears at most once.
        str: An error message when an argument is invalid or the case is not found.
    """
    # validate arguments up front so the caller (the LLM) gets an actionable message instead of junk results
    if not doc_id:
        return "Error: 'doc_id' is required as input."
    if not 0.0 <= alpha <= 1.0:
        return f"Error: 'alpha' must be between 0.0 and 1.0, got {alpha}."
    if k < 1:
        return f"Error: 'k' must be at least 1, got {k}."
    if sparse_exp <= 0:
        # 0 ** 0 == 1 and 0 ** negative == inf would make every document look similar
        return f"Error: 'sparse_exp' must be greater than 0, got {sparse_exp}."

    # check if the document ID exists in the database
    doc = get_doc_by_id(doc_id)
    if not doc:
        return f"Document ID '{doc_id}' not found in the database."

    # FAISS index position of the queried document (excluded from the results below)
    query_doc_idx = docstore_id_to_indx.get(doc.id)

    # dense query vector: re-embed the stored content with the same embedding handler the DB was loaded with
    # (whether the handler L2-normalizes its output is not visible here - confirm in utils.embedding)
    query_dense_vector = embedding_handler.embed_documents([doc.page_content])[0]

    # sparse query vector decoded from the document's metadata
    query_sparse_vector = str_to_arr(doc.metadata['legal_refs_sparse_vec'])

    # dense search (FAISS); one extra hit is requested because the query document itself is normally among them.
    # NOTE(review): LangChain's FAISS returns the raw index score here - an L2 distance (lower = more similar)
    # under the default EUCLIDEAN_DISTANCE strategy, an inner product (higher = more similar) under
    # MAX_INNER_PRODUCT. The weighted sum below treats it as a similarity; confirm how load_vector_db builds the store.
    dense_results = vector_store.similarity_search_with_score_by_vector(query_dense_vector, k=k + 1)

    # FAISS index position -> dense score, skipping the query document and ids missing from the mapping
    dense_scores = {}
    for result_doc, score in dense_results:
        result_idx = docstore_id_to_indx.get(result_doc.id)
        if result_idx is not None and result_idx != query_doc_idx:
            dense_scores[result_idx] = score

    # sparse search against every stored document using Jaccard similarity (1 - Jaccard distance).
    # .reshape(1, -1) turns the 1D query vector of shape (n,) into a single-row matrix of shape (1, n)
    jaccard_distances = pairwise_distances(query_sparse_vector.reshape(1, -1), sparse_vectors, metric="jaccard")[0]
    sparse_scores = 1 - jaccard_distances

    # sparse scores are relatively small compared to the dense scores, so scale them up with a power
    # exponent (for 0 < sparse_exp < 1 values in (0, 1] move towards 1 while 0 stays 0)
    sparse_scores = np.power(sparse_scores, sparse_exp)

    # merge scores using a weighted sum. Row i of sparse_vectors is assumed to belong to FAISS index position i;
    # documents outside the dense top-(k+1) contribute a dense score of 0.
    combined_scores = {}
    for idx, sparse_score in enumerate(sparse_scores):
        if idx == query_doc_idx:
            continue
        dense_score = dense_scores.get(idx, 0)
        combined_scores[idx] = alpha * dense_score + (1 - alpha) * sparse_score
        # uncomment below line for debug
        #print(f"indx={idx}, dense score={float(dense_score)}, sparse_score={float(sparse_score)}, combined score={combined_scores[idx]}")

    # rank documents by combined score, best first
    ranked_docs = sorted(combined_scores.items(), key=lambda item: item[1], reverse=True)

    # collect the top k *distinct* case numbers. A dict comprehension over ranked_docs would let a
    # lower-scoring entry with the same case number overwrite the better one and return fewer than k cases;
    # other entries of the queried case itself are skipped as well.
    res: Dict[str, float] = {}
    for d_idx, d_scr in ranked_docs:
        case_number = vector_store.docstore._dict[vector_store.index_to_docstore_id[d_idx]].metadata['case_number']
        if case_number == doc_id or case_number in res:
            continue
        res[case_number] = float(d_scr)
        if len(res) == k:
            break
    return res

# RAG QA Chat

### Initialize the LangChain AWS Bedrock chat model

In [9]:
# load environment variables from the .env file into the process environment, so that the AWS
# credentials read with os.getenv in the next cell are available (the cell displays the boolean result)
load_dotenv()

True

In [10]:
# define params - model id, region, keys, etc.
# The runtime endpoint is derived from the region so the two settings cannot drift apart.
region_name = 'us-west-2'
model_id = 'anthropic.claude-3-5-sonnet-20241022-v2:0'
endpoint_url = f'https://bedrock-runtime.{region_name}.amazonaws.com'

# get secret keys from environment variables (loaded from .env by load_dotenv).
# AWS_SESSION_TOKEN is forwarded as well: temporary (STS/SSO) credentials are rejected without it,
# and it stays None for long-term access keys.
aws_access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
aws_secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
aws_session_token = os.getenv('AWS_SESSION_TOKEN')

# timeout configuration in seconds; the long read timeout leaves room for slow model responses
timeout_config = Config(connect_timeout=30, read_timeout=120)

# initialize Bedrock client
bedrock_client = boto3.client(
    'bedrock-runtime',
    region_name=region_name,
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
    aws_session_token=aws_session_token,
    endpoint_url=endpoint_url,
    config=timeout_config,
)

# initialize the LangChain LLM Chat Bedrock
llm = ChatBedrock(
    client=bedrock_client,
    model_id=model_id,
    model_kwargs={
        "max_tokens": 1024,
        "temperature": 0.5
    }
)

Helper function that formats the documents' content

In [11]:
def format_docs(docs):
    """Join the page content of the given documents into one string, separated by blank lines."""
    contents = [document.page_content for document in docs]
    return "\n\n".join(contents)

### Define RAG question answering tool

In [12]:
@tool
def legal_rag_qa(query: str) -> str:
    """Answer general legal questions using retrieved documents.

    Args:
        query: The search query.
    """
    # retrieve the most relevant documents (k is configured on the retriever) and concatenate their content
    #docs = retriever.get_relevant_documents(query)
    docs = retriever.invoke(query)
    context = format_docs(docs)

    # uncomment below line for debug
    #print(f"Context ===>> {context}")
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a legal assistant. Use the provided context to answer the user's query.\nContext:\n{context}"),
        ("human", "Query: {query}")
    ])

    chain = prompt | llm
    response = chain.invoke({"context": context, "query": query})

    # The chain yields a chat message, not a string. Return its text so the agent receives the answer
    # itself rather than a stringified message object. Content may be a plain string or a list of blocks
    # (bare strings or dicts with a "type" key); only text parts are kept, joined with newlines.
    content = response.content
    if isinstance(content, str):
        return content
    texts = []
    for block in content:
        if isinstance(block, str):
            texts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            texts.append(block.get("text", ""))
    return "\n".join(texts)

@tool
def case_lookup(case_number: str) -> str:
    """Look up a specific legal case by its exact case number and return its content."""
    # exact match on the "case_number" metadata via get_doc_by_id; unknown numbers get a readable message
    found = get_doc_by_id(case_number)
    if not found:
        return f"No case found for case number {case_number}."
    return found.page_content

### Create agent

In [13]:
# Conversation memory passed to the AgentExecutor; history is stored under "chat_history", the same
# variable name as the MessagesPlaceholder in the agent prompt. chat() clears it on "reset".
# NOTE(review): ConversationBufferMemory is deprecated in recent LangChain releases (the notebook output
# echoes this line, presumably from a deprecation warning) - consider migrating to a supported memory API.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)


In [14]:
# Agent prompt: fixed system instructions, the running conversation ("chat_history"), the current user
# input, and the agent's intermediate tool-calling steps ("agent_scratchpad").
prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(
            "You are a legal assistant that uses tools to retrieve similar cases or answer legal questions."
        ),
        MessagesPlaceholder(variable_name="chat_history"),
        HumanMessagePromptTemplate.from_template("{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)

In [15]:
# Tools exposed to the agent: similar-case search, RAG question answering and exact case lookup.
# Their docstrings serve as the tool descriptions the model sees.
tools = [hybrid_search, legal_rag_qa, case_lookup]
# tool-calling agent built from the Bedrock chat model, the tools and the agent prompt
agent_runnable = create_tool_calling_agent(llm=llm, tools=tools, prompt=prompt)

In [16]:
# Executor wrapping the agent with its tools and the conversation memory;
# chat() drives it via agent.invoke({"input": ...}).
agent = AgentExecutor(
    agent=agent_runnable,
    tools=tools,
    memory=memory,
    verbose=False  # put True for debug
)

Debug function that prints the agent's results

In [17]:
def print_result(result):
    """Print the text parts of an agent result (debug helper).

    Args:
        result: dict returned by agent.invoke. Its "output" entry may be a plain string or a list of
            content blocks; blocks may be bare strings or dicts, and only dicts with type "text" are printed.
    """
    output = result.get("output", [])

    if isinstance(output, str):
        print("LLM output:", output)
    elif isinstance(output, list):
        for block in output:
            # bare-string blocks have no .get(); print them as text instead of crashing
            if isinstance(block, str):
                print("LLM output:", block)
            elif isinstance(block, dict) and block.get("type") == "text":
                print("LLM output:", block.get("text", ""))
    else:
        print("Unknown output format:", output)

In [18]:
def format_result(result) -> str:
    """Echo the user query and return the agent's answer as plain text.

    Prints "User query: ..." (taken from result["input"]) as a side effect.

    Args:
        result: dict returned by agent.invoke. Its "output" entry may be a plain string or a list of
            content blocks (bare strings, or dicts where only type "text" blocks carry answer text).

    Returns:
        The answer text with surrounding whitespace stripped (the model often starts with blank lines),
        or "Unrecognized response format." for any other output type.
    """
    output = result.get("output", [])

    print(f"User query: {result.get('input')}")

    # if it's a plain string (some tools may return direct output)
    if isinstance(output, str):
        return output.strip()

    # if LLM returned a list of message blocks, keep only the text parts
    if isinstance(output, list):
        texts = []
        for block in output:
            if isinstance(block, str):
                texts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                texts.append(block.get("text", ""))
        return "\n".join(texts).strip()

    return "Unrecognized response format."

In [19]:
def chat():
    """Run an interactive console chat with the legal agent until the user leaves.

    Commands (case-insensitive, surrounding whitespace ignored):
        reset        clear the conversation memory
        exit / quit  leave the chat
    End of input (Ctrl-D) or Ctrl-C at the prompt also leaves the chat, and blank lines are ignored.
    Errors raised while answering are printed and the chat continues.
    """
    print("Legal Assistant (RAG). Type 'reset' to clear memory, 'exit' or 'quit' to leave the chat.\n")
    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin closed or the user interrupted at the prompt: leave cleanly instead of dumping a traceback
            print()
            break

        # don't send empty prompts to the model
        if not user_input:
            continue

        command = user_input.lower()
        if command in ("exit", "quit"):
            break
        if command == "reset":
            memory.clear()
            print("Memory cleared.\n")
            continue

        try:
            result = agent.invoke({"input": user_input})
            print(f"LLM output: {format_result(result)}")
            # uncomment below lines for debug
            #print_result(result)
            #print(f"LLM output: {result['output']}\n")
            #print(f"test: {result['output']}\n")
        except Exception as e:
            # top-level boundary of the chat loop: report the failure and keep the session alive
            print(f"Error: {e}\n")

### Invoke Chat

In [20]:
# start the interactive chat loop; it returns when the user types 'exit' or 'quit'
chat()

Legal Assistant (RAG). Type 'reset' to clear memory, 'exit' or 'quit' to leave the chat.

User query: Please provide the offenses in 3015/09 in one short sentence
LLM output: 

The offenses in case 3015/09 involved terrorism-related charges including conspiracy, attempted murder, possession of firearms, and violations of the Prevention of Terrorism Ordinance.
User query: Find 3 similar cases to case number 3015/09
LLM output: 

I found three similar cases to 3015/09. Here they are listed in order of similarity (with their similarity scores):

1. Case 6068/21 (similarity: 0.81)
2. Case 4182/10 (similarity: 0.80)
3. Case 9387/16 (similarity: 0.44)

Would you like me to look up the details of any of these specific cases to understand why they're similar?
User query: Please print main offenses for the 1st similar case and the verdict. make it short.
LLM output: 

Main offenses: Unlawful possession and firing of a firearm. 
Verdict: Appeal granted, sentence increased from 14 to 25 months im