In [0]:
%pip install -r requirements.txt --force-reinstall
# Restart the Python process after the install so that later cells import the
# packages that were just (re)installed rather than previously loaded versions.
dbutils.library.restartPython()

In [0]:
# Notebook parameters exposed as Databricks text widgets (widget name -> default value).
_widget_defaults = {
    "input_question": "Summarize the following legal case",
    "embedding_endpoint": "databricks-bge-large-en",
    "llm_endpoint": "databricks-meta-llama-3-1-8b-instruct",
    "k": "5",  # number of retrieved chunks
}
for _widget_name, _widget_default in _widget_defaults.items():
    dbutils.widgets.text(_widget_name, _widget_default)

# Read the current widget values back into notebook-level variables.
input_question, embedding_endpoint, llm_endpoint = (
    dbutils.widgets.get(name)
    for name in ("input_question", "embedding_endpoint", "llm_endpoint")
)
# The "k" widget is a text widget, so its value is converted to an integer here.
k_val = int(dbutils.widgets.get("k"))

print(
    "📌 Parameters set:\n"
    f"- Question: {input_question}\n"
    f"- Chunks: {k_val}\n"
    f"- Embedding: {embedding_endpoint}\n"
    f"- LLM: {llm_endpoint}"
)


In [0]:
import numpy as np
import faiss

# Report the installed numpy and faiss versions (numpy first, then faiss).
for _lib in (np, faiss):
    print(_lib.__version__)

In [0]:
import os

# Hugging Face cache location on the Unity Catalog volume.
# The environment variable is set BEFORE importing `datasets`, because the
# library resolves HF_DATASETS_CACHE when it is imported; assigning it afterwards
# has no effect. If `datasets` was already imported earlier in this Python
# process the variable is still ignored, which is why `cache_dir` is also passed
# explicitly to `load_dataset` below.
cache_path = "/Volumes/legal/bronze/casesumm_volume/huggingface_cache"
os.environ["HF_DATASETS_CACHE"] = cache_path

from datasets import load_dataset

# Download (or reuse from cache) the CaseSumm dataset.
dataset = load_dataset("ChicagoHAI/CaseSumm", cache_dir=cache_path)

# Convert the 'train' split to pandas, keep only the opinion/syllabus columns,
# and rename them to the text/summary names used by the rest of the notebook.
df = (
    dataset['train']
    .to_pandas()[['opinion', 'syllabus']]
    .rename(columns={
        'opinion': 'text',
        'syllabus': 'summary'
    })
)

# Persist as a Delta table, replacing any previous contents.
spark_df = spark.createDataFrame(df)

spark_df.write.format("delta") \
    .mode("overwrite") \
    .saveAsTable("legal.bronze.casesumm")

In [0]:
%sql
-- Preview two rows of the bronze CaseSumm table.
SELECT * FROM legal.bronze.casesumm LIMIT 2

In [0]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.embeddings import DatabricksEmbeddings
import os

# --- Config ---
TABLE_NAME = "legal.bronze.casesumm"
BATCH_SIZE = 5000  # number of rows per batch (tune for serverless memory)
MAX_ROWS = 1000000  # optional: safety limit during first run
FAISS_DIR = "/Volumes/legal/bronze/casesumm_volume/casesumm_faiss"

# --- Setup ---
splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=150)
embeddings = DatabricksEmbeddings(endpoint=embedding_endpoint)


def _index_batch(texts, store):
    """Chunk and embed `texts`, fold them into `store`, and return the index.

    Args:
        texts: list of raw document strings to split into chunks and embed.
        store: the FAISS vector store accumulated so far, or None for the
            first batch.

    Returns:
        The vector store containing everything in `store` plus the new chunks.
    """
    docs = splitter.create_documents(texts)
    batch_store = FAISS.from_documents(docs, embeddings)
    if store is None:
        return batch_store
    store.merge_from(batch_store)
    return store


# Accumulated index; stays None until the first batch has been embedded.
main_vectorstore = None

# --- Process in batches ---
# toLocalIterator() streams Row objects to the driver one at a time (it does
# NOT yield batches), so rows are collected into batches of BATCH_SIZE here.
batch_texts = []
batch_num = 0
row_iterator = (
    spark.table(TABLE_NAME)
    .select("text")
    .where("text IS NOT NULL")
    .limit(MAX_ROWS)
    .toLocalIterator()
)
for row in row_iterator:
    batch_texts.append(row["text"])
    if len(batch_texts) >= BATCH_SIZE:
        main_vectorstore = _index_batch(batch_texts, main_vectorstore)
        batch_texts = []
        batch_num += 1
        print(f"Processed batch {batch_num}")

# Process any leftover texts (the final, partially filled batch)
if batch_texts:
    main_vectorstore = _index_batch(batch_texts, main_vectorstore)

# Fail with a clear message instead of calling save_local on None.
if main_vectorstore is None:
    raise ValueError(f"No non-null 'text' rows found in {TABLE_NAME}; nothing to index.")

# --- Save final FAISS index ---
os.makedirs(FAISS_DIR, exist_ok=True)
main_vectorstore.save_local(FAISS_DIR)
print(f"✅ FAISS index saved to {FAISS_DIR}")


In [0]:
from langchain_community.vectorstores import FAISS
from langchain.embeddings import DatabricksEmbeddings

# Load the persisted index in this cell so retrieval works when the notebook is
# run top to bottom on a fresh kernel. Without this, `vectorstore` is only bound
# by a later cell and the call below raises NameError.
FAISS_DIR = "/Volumes/legal/bronze/casesumm_volume/casesumm_faiss"
# NOTE(security): allow_dangerous_deserialization enables pickle-based loading;
# only load index directories that this workspace wrote itself.
vectorstore = FAISS.load_local(
    FAISS_DIR,
    DatabricksEmbeddings(endpoint=embedding_endpoint),
    allow_dangerous_deserialization=True,
)


def retrieve_top_k(query: str, k: int = k_val):
    """Return the `k` chunks from the FAISS index most similar to `query`.

    Args:
        query: free-text query to embed and search with.
        k: number of chunks to return; defaults to the value of the "k" widget
            at the time this function was defined.

    Returns:
        A list of Documents, each carrying `page_content` and `metadata`.
    """
    return vectorstore.similarity_search(query, k=k)


# Example: retrieve for the widget question and preview each hit.
q = input_question
top_docs = retrieve_top_k(q, k=k_val)
for i, d in enumerate(top_docs, 1):
    print(f"--- chunk #{i} ---")
    print(d.page_content[:500].replace("\n", " "))  # print first 500 chars
    print("metadata:", d.metadata)


In [0]:
from langchain_community.vectorstores import FAISS
from langchain.embeddings import DatabricksEmbeddings
from langchain_community.chat_models import ChatDatabricks

# config - change these to your values
FAISS_DIR = "/Volumes/legal/bronze/casesumm_volume/casesumm_faiss"

# Embeddings wrapper handed to the loaded FAISS store (used to embed queries).
embeddings = DatabricksEmbeddings(endpoint=embedding_endpoint)

# Load the saved FAISS index.
# NOTE(security): allow_dangerous_deserialization enables pickle-based loading;
# only load index directories that this workspace wrote itself.
vectorstore = FAISS.load_local(FAISS_DIR, embeddings, allow_dangerous_deserialization=True)
print("Loaded FAISS with", vectorstore.index.ntotal, "vectors (if available).")

# Honor the "k" widget so the RAG retriever fetches the same number of chunks as
# retrieve_top_k, instead of silently falling back to the library default.
retriever = vectorstore.as_retriever(search_kwargs={"k": k_val})

# Chat model used for summarization; temperature 0.0 and a 1024-token cap on
# the generated output.
llm = ChatDatabricks(
    endpoint=llm_endpoint,
    temperature=0.0,
    max_tokens=1024
)

def rag_pipeline(case_text: str) -> str:
    """Summarize a legal case from the chunks retrieved for it.

    The case text is used as the retrieval query; the retrieved chunks are
    joined into one context block and the LLM is asked for a short summary of
    that context.

    Args:
        case_text: text used as the retriever query.
            NOTE(review): the full case text is passed as a single query, which
            may exceed the embedding endpoint's input limit - confirm.

    Returns:
        The LLM response content with surrounding whitespace stripped.
    """
    # Retrieve relevant chunks. `invoke` is the Runnable entry point and
    # replaces the deprecated `get_relevant_documents`.
    docs = retriever.invoke(case_text)

    # Combine into a single context, one blank line between chunks
    context = "\n\n".join(d.page_content for d in docs)

    # Prompt for summarization
    prompt = f"""
    You are a legal assistant. Write in 4–5 sentences covering the main legal arguments and conclusion:

    {context}

    Summary:
    """

    # Call LLM
    response = llm.invoke(prompt)
    return response.content.strip()

In [0]:
from pyspark.sql import SparkSession

# Get the active Spark session (or create one if none exists yet).
spark = SparkSession.builder.getOrCreate()

# Take a single row from the Unity Catalog table as the example case.
df = spark.table("legal.bronze.casesumm").limit(1)
case = df.collect()[0]
case_text, ground_truth_summary = case["text"], case["summary"]

# Show the input size and the first 200 characters of the reference summary.
print("Case text length:", len(case_text))
print("Ground truth summary:", ground_truth_summary[:200], "...")

# Run the RAG pipeline on the case and show the first 200 characters of its output.
generated_summary = rag_pipeline(case_text)
print("Generated Summary:", generated_summary[:200], "...")



In [0]:
%pip install rouge-score
# NOTE(review): this install runs mid-notebook and is unpinned. Depending on the
# Databricks runtime, a %pip command that modifies the environment may reset
# notebook state, which would drop `ground_truth_summary` / `generated_summary`
# before the scoring cell - confirm, and consider adding rouge-score to
# requirements.txt so it is installed by the first cell instead.

In [0]:
from rouge_score import rouge_scorer
import pandas as pd

# Score the generated summary against the reference with ROUGE-1, ROUGE-2 and
# ROUGE-L (stemming enabled). score() takes the reference first, then the prediction.
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
scores = scorer.score(ground_truth_summary, generated_summary)

# One table row per metric, carrying its precision / recall / F-measure.
rows = [
    {
        "Metric": metric.upper(),
        "Precision": score.precision,
        "Recall": score.recall,
        "F1": score.fmeasure,
    }
    for metric, score in scores.items()
]

df_scores = pd.DataFrame(rows)
df_scores


In [0]:
import matplotlib.pyplot as plt

# Table display (Databricks automatically renders Pandas DataFrames nicely)
display(df_scores)

# Chart: one line per score type (precision / recall / F1) across the ROUGE metrics.
fig, ax = plt.subplots(figsize=(8, 5))
for score_column in ("Precision", "Recall", "F1"):
    ax.plot(df_scores["Metric"], df_scores[score_column], marker='o', label=score_column)

ax.set_title("ROUGE Scores")
ax.set_ylim(0, 1)  # ROUGE values lie in [0, 1]; a fixed axis keeps runs comparable
ax.set_ylabel("Score")
ax.legend()
ax.grid(True)
plt.show()
