In [0]:
# Install the project dependencies listed in requirements.txt, then restart the Python
# process so later cells import the newly installed versions.
# NOTE(review): --force-reinstall reinstalls every listed package even when it is already
# up to date, which is slow and may replace runtime-provided packages — confirm it is needed.
%pip install -r requirements.txt --force-reinstall
dbutils.library.restartPython()

In [0]:
# Widgets for easy evaluator input
dbutils.widgets.text("input_question", "Summarize the following legal case")
dbutils.widgets.text("embedding_endpoint", "databricks-bge-large-en")
dbutils.widgets.text("llm_endpoint", "databricks-meta-llama-3-1-8b-instruct")
dbutils.widgets.text("k", "3")  # number of retrieved chunks

# Retrieve widget values
input_question = dbutils.widgets.get("input_question")
embedding_endpoint = dbutils.widgets.get("embedding_endpoint")
llm_endpoint = dbutils.widgets.get("llm_endpoint")

# Validate `k` here so a typo or non-positive value in the widget fails immediately with
# a clear message, instead of surfacing later inside the retrieval calls.
k_raw = dbutils.widgets.get("k")
try:
    k = int(k_raw)
except ValueError:
    raise ValueError(f"Widget 'k' must be a positive integer, got {k_raw!r}") from None
if k < 1:
    raise ValueError(f"Widget 'k' must be a positive integer, got {k}")

print(f"📌 Parameters set:\n- Question: {input_question}\n- Chunks: {k}\n- Embedding: {embedding_endpoint}\n- LLM: {llm_endpoint}")


In [0]:
from datasets import load_dataset

# Pull the CaseSumm corpus from the Hugging Face Hub.
dataset = load_dataset("ChicagoHAI/CaseSumm")

# Keep only the train split's opinion/syllabus pair, renamed to the text/summary
# schema that the rest of the notebook reads.
df = (
    dataset['train']
    .to_pandas()
    .loc[:, ['opinion', 'syllabus']]
    .rename(columns={'opinion': 'text', 'syllabus': 'summary'})
)

# Persist as a Delta table in Unity Catalog, replacing any previous contents.
spark_df = spark.createDataFrame(df)
(
    spark_df.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("legal.bronze.casesumm")
)

In [0]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.embeddings import DatabricksEmbeddings
import os

# --- Config ---
TABLE_NAME = "legal.bronze.casesumm"
BATCH_SIZE = 5000  # number of rows per batch (tune for serverless memory)
MAX_ROWS = 5000  # cap on rows read from TABLE_NAME; set to None to index the whole table
FAISS_DIR = "/Volumes/legal/bronze/casesumm_volume/casesumm_faiss"

# --- Setup ---
splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=150)
embeddings = DatabricksEmbeddings(endpoint=embedding_endpoint)

# FAISS store that every batch is merged into; created from the first non-empty batch.
main_vectorstore = None
total_chunks = 0


def _index_batch(texts, store):
    """Chunk and embed `texts` into a new FAISS store and merge it into `store`.

    Args:
        texts: List of raw document strings for this batch.
        store: The accumulated FAISS store, or None if nothing has been indexed yet.

    Returns:
        (store, n_chunks): the accumulated store (the new batch store when `store` was None)
        and the number of chunks this batch contributed.
    """
    docs = splitter.create_documents(texts)
    if not docs:
        # Blank texts can yield no chunks; FAISS.from_documents cannot build an empty index.
        return store, 0
    batch_store = FAISS.from_documents(docs, embeddings)
    if store is None:
        return batch_store, len(docs)
    store.merge_from(batch_store)
    return store, len(docs)


# --- Process in batches ---
source_rows = spark.table(TABLE_NAME).select("text").where("text IS NOT NULL")
if MAX_ROWS is not None:
    source_rows = source_rows.limit(MAX_ROWS)

batch_num = 0
batch_texts = []
print("Starting processing")
for row in source_rows.toLocalIterator():  # stream to driver row-by-row
    batch_texts.append(row["text"])  # row is a Row, access the column value

    if len(batch_texts) >= BATCH_SIZE:
        main_vectorstore, n_chunks = _index_batch(batch_texts, main_vectorstore)
        total_chunks += n_chunks
        batch_texts = []
        batch_num += 1
        print(f"Processed batch {batch_num}")

# Process any leftover texts (the final, partial batch).
if batch_texts:
    main_vectorstore, n_chunks = _index_batch(batch_texts, main_vectorstore)
    total_chunks += n_chunks
    batch_num += 1
    print(f"Processed batch {batch_num}")

print(f"Total chunks created: {total_chunks}")

if main_vectorstore is None:
    raise RuntimeError(f"No non-empty text found in {TABLE_NAME}; nothing to index.")

# --- Save final FAISS index ---
os.makedirs(FAISS_DIR, exist_ok=True)
main_vectorstore.save_local(FAISS_DIR)
print(f"✅ FAISS index saved to {FAISS_DIR}")


In [0]:
from langchain_community.vectorstores import FAISS
from langchain.embeddings import DatabricksEmbeddings
from langchain_community.chat_models import ChatDatabricks

# config - change these to your values
FAISS_DIR = "/Volumes/legal/bronze/casesumm_volume/casesumm_faiss"

# create embeddings wrapper (used for loading the FAISS index); it must use the same
# endpoint the index was built with, otherwise query vectors will not match the index.
embeddings = DatabricksEmbeddings(endpoint=embedding_endpoint)

# load the saved FAISS index.
# SECURITY: load_local unpickles the stored docstore. allow_dangerous_deserialization=True
# is acceptable only because this notebook wrote the index itself — never load an index
# from an untrusted location this way.
vectorstore = FAISS.load_local(FAISS_DIR, embeddings, allow_dangerous_deserialization=True)
print("Loaded FAISS with", vectorstore.index.ntotal, "vectors (if available).")

# Honour the `k` widget for summarization retrieval too. Without search_kwargs the
# retriever uses the vector store's own default k instead of the evaluator's choice.
retriever = vectorstore.as_retriever(search_kwargs={"k": k})

# Chat model shared by rag_pipeline and legal_qa; low temperature for more repeatable outputs.
llm = ChatDatabricks(
    endpoint=llm_endpoint,
    temperature=0.1,
    max_tokens=1024
)

def rag_pipeline(case_text: str, max_case_chars: int = 12000, max_query_chars: int = 2000) -> str:
    """Summarize one legal case with retrieval-augmented generation.

    A prefix of ``case_text`` is used as the retrieval query, and a (longer) prefix of the
    case is placed in the prompt. The LLM therefore summarizes the case itself. Chunks
    retrieved from the FAISS index, which may come from this case or from other cases,
    are added only as supporting context.

    Args:
        case_text: Full opinion text of the case to summarize.
        max_case_chars: Maximum number of leading characters of ``case_text`` placed in
            the prompt.
        max_query_chars: Maximum number of leading characters of ``case_text`` used as
            the retrieval query.

    Returns:
        The model's summary with surrounding whitespace stripped.

    Raises:
        ValueError: If ``case_text`` is empty or whitespace-only.
    """
    if not case_text or not case_text.strip():
        raise ValueError("case_text must be a non-empty string")

    # Embedding a whole opinion can exceed the embedding model's input limit (presumably
    # 512 tokens for bge-large-en — TODO confirm the endpoint's truncation behavior), so
    # retrieve with a bounded prefix instead.
    query = case_text[:max_query_chars]
    docs = retriever.invoke(query)

    # Combine retrieved chunks into a single context block
    context = "\n\n".join(d.page_content for d in docs)

    prompt = (
        "You are a legal assistant. Write a summary of the case below in 4–5 sentences, "
        "covering the main legal arguments and the conclusion. Use the related excerpts "
        "only as supporting context.\n\n"
        f"Case:\n{case_text[:max_case_chars]}\n\n"
        f"Related excerpts:\n{context}\n\n"
        "Summary:"
    )

    response = llm.invoke(prompt)
    return response.content.strip()

In [0]:

DEFAULT_K = 3


def legal_qa(query: str, top_k: int = DEFAULT_K) -> tuple[str, list]:
    """Answer a legal question from the top-k most similar chunks in the FAISS index.

    Args:
        query: Natural-language question; also used as the similarity-search query.
        top_k: Number of chunks to retrieve as context; must be >= 1.

    Returns:
        ``(answer, docs)``: the LLM answer with surrounding whitespace stripped, and the
        retrieved Documents that were used as context.

    Raises:
        ValueError: If ``query`` is blank or ``top_k`` is less than 1.
    """
    if not query or not query.strip():
        raise ValueError("query must be a non-empty string")
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    # similarity_search returns Documents with page_content and metadata
    docs = vectorstore.similarity_search(query, k=top_k)

    context = "\n\n".join(doc.page_content for doc in docs)

    # Construct prompt for LLM
    prompt = f"""
    You are a knowledgeable legal expert assistant. Use the following context excerpts from legal cases to answer the question below concisely and accurately.

    Context:
    {context}

    Question:
    {query}

    Answer:
    """
    # Generate answer from LLM; chat models return a message with .content, others a plain value.
    response = llm.invoke(prompt)
    answer = response.content if hasattr(response, "content") else str(response)

    return answer.strip(), docs

# Example: answer the widget question using `k` retrieved chunks, then preview the sources.
q = input_question
answer, source_chunks = legal_qa(q, top_k=k)

print("=== Answer ===", answer, sep="\n")

print("\n=== Source Chunks ===")
for i, chunk in enumerate(source_chunks, start=1):
    # First 500 characters of each chunk, flattened onto a single line.
    print(f"[Chunk {i}] " + chunk.page_content[:500].replace("\n", " ") + "...\n")


In [0]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Load one evaluation case from Unity Catalog. Rows missing either the opinion text or
# the reference summary cannot be summarized or scored, so filter them out first.
# NOTE: with no ORDER BY, which row limit(1) returns is not guaranteed to be stable.
eval_df = (
    spark.table("legal.bronze.casesumm")
    .where("text IS NOT NULL AND summary IS NOT NULL")
    .limit(1)
)
case = eval_df.first()
if case is None:
    raise RuntimeError("No row in legal.bronze.casesumm has both text and summary.")

case_text = case["text"]
ground_truth_summary = case["summary"]

print("Case text length:", len(case_text))
print("Ground truth summary:", ground_truth_summary[:200], "...")

generated_summary = rag_pipeline(case_text)
print("Generated Summary:", generated_summary[:200], "...")



In [0]:
# Install the ROUGE metric package imported by the evaluation cell that follows.
# NOTE(review): this is unpinned and not in requirements.txt; consider moving it there so
# the first cell installs everything. On some Databricks runtimes %pip also restarts
# Python, which would drop variables defined by earlier cells (e.g. generated_summary) — confirm.
%pip install rouge-score

In [0]:
from rouge_score import rouge_scorer
import matplotlib.pyplot as plt

# ROUGE-1/2/L with stemming. RougeScorer.score takes (target, prediction), so the
# reference syllabus goes first and the generated summary second.
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
scores = scorer.score(ground_truth_summary, generated_summary)

import pandas as pd

# One row per ROUGE variant with its precision, recall and F1.
df_scores = pd.DataFrame(
    [
        {
            "Metric": metric.upper(),
            "Precision": score.precision,
            "Recall": score.recall,
            "F1": score.fmeasure,
        }
        for metric, score in scores.items()
    ]
)

display(df_scores)

# Grouped bar chart: ROUGE variants are separate categories, so a line joining them
# would imply a trend between metrics that does not exist.
ax = df_scores.set_index("Metric")[["Precision", "Recall", "F1"]].plot.bar(
    figsize=(8, 5), rot=0, ylim=(0, 1), title="ROUGE Scores"
)
ax.set_ylabel("Score")
ax.grid(True, axis="y")
ax.legend()
plt.show()
