In [0]:
%pip install datasets langchain-community langchain-text-splitters faiss-cpu
# Notebook-scoped installs: Hugging Face `datasets` (dataset loader), LangChain community +
# text-splitters (chunking, FAISS wrapper, Databricks embeddings), and the CPU build of FAISS.
# NOTE(review): versions are unpinned — pin them (pkg==x.y.z) for reproducible runs.

In [0]:
# Restart the Python process so the packages installed by the %pip cell above are importable.
dbutils.library.restartPython()

In [0]:
from datasets import load_dataset

# Bronze ingest: copy the raw CaseSumm train split from Hugging Face into a Delta table.
SOURCE_DATASET = "ChicagoHAI/CaseSumm"
BRONZE_TABLE = "legal.bronze.casesumm"

# Load the dataset from Hugging Face
dataset = load_dataset(SOURCE_DATASET)

# Keep only the full opinion and its syllabus, renamed to generic text/summary columns.
df = (
    dataset["train"]
    .to_pandas()[["opinion", "syllabus"]]
    .rename(columns={"opinion": "text", "syllabus": "summary"})
)
assert len(df) > 0, f"{SOURCE_DATASET} train split is empty"

spark_df = spark.createDataFrame(df)

# Overwrite mode keeps the cell idempotent when the notebook is re-run.
(
    spark_df.write.format("delta")
    .mode("overwrite")
    .saveAsTable(BRONZE_TABLE)
)

In [0]:
%sql
-- Spot-check two (arbitrary) rows of the bronze CaseSumm table.
SELECT *
FROM legal.bronze.casesumm
LIMIT 2

In [0]:
import os
from itertools import islice

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import DatabricksEmbeddings

# --- Config ---
TABLE_NAME = "legal.bronze.casesumm"
BATCH_SIZE = 5000  # source rows embedded per FAISS batch (tune for driver memory)
ROW_LIMIT = 1_000_000  # safety cap for first runs; set to None to index every row
FAISS_DIR = "/Volumes/legal/bronze/casesumm_volume/casesumm_faiss"

# --- Setup ---
splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=150)
embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")


def _batched(iterable, size):
    """Yield consecutive lists of at most `size` items from `iterable` (last may be shorter)."""
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, size))
        if not batch:
            return
        yield batch


def _index_batch(texts, store):
    """Chunk and embed `texts`, merging the resulting vectors into `store`.

    Args:
        texts: list of raw document strings.
        store: FAISS store accumulated so far, or None before the first batch.

    Returns:
        The FAISS store holding all vectors indexed so far. `store` is returned
        unchanged when `texts` yields no chunks (e.g. only empty strings), which
        avoids calling FAISS.from_documents with an empty list.
    """
    docs = splitter.create_documents(texts)
    if not docs:
        return store
    batch_store = FAISS.from_documents(docs, embeddings)
    if store is None:
        return batch_store
    store.merge_from(batch_store)
    return store


# --- Process in batches ---
source_df = spark.table(TABLE_NAME).select("text").where("text IS NOT NULL")
if ROW_LIMIT is not None:
    source_df = source_df.limit(ROW_LIMIT)

# toLocalIterator() streams individual Row objects to the driver (not DataFrame
# batches), so the rows are regrouped into BATCH_SIZE-sized lists here.
text_stream = (row["text"] for row in source_df.toLocalIterator())

main_vectorstore = None
for batch_num, batch_texts in enumerate(_batched(text_stream, BATCH_SIZE), start=1):
    main_vectorstore = _index_batch(batch_texts, main_vectorstore)
    print(f"Processed batch {batch_num} ({len(batch_texts)} rows)")

if main_vectorstore is None:
    raise ValueError(f"No indexable text found in {TABLE_NAME}; nothing to save.")

# --- Save final FAISS index ---
os.makedirs(FAISS_DIR, exist_ok=True)
main_vectorstore.save_local(FAISS_DIR)
print(f"✅ FAISS index saved to {FAISS_DIR} ({main_vectorstore.index.ntotal} vectors)")


In [0]:
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import DatabricksEmbeddings

# config - change these to your values
FAISS_DIR = "/Volumes/legal/bronze/casesumm_volume/casesumm_faiss"

# Query-time embeddings must come from the same endpoint used to build the index.
embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")

# FAISS.save_local() stores the docstore as a pickle, so recent langchain_community
# versions refuse to load it without an explicit opt-in. Opting in is acceptable only
# because this notebook wrote the index itself; never enable it for index files from
# an untrusted source (unpickling can execute arbitrary code).
vectorstore = FAISS.load_local(
    FAISS_DIR,
    embeddings,
    allow_dangerous_deserialization=True,
)
print(f"Loaded FAISS index from {FAISS_DIR} with {vectorstore.index.ntotal} vectors.")