In [1]:
!pip install -q langchain nltk chromadb langchain-experimental langchain-community \
    sentence-transformers google-generativeai python-dotenv rich chromadb
    

In [2]:
import json
import chromadb
from chromadb.utils import embedding_functions

In [None]:
import os
from pathlib import Path

# Path is resolved relative to the kernel's current working directory
# (normally the folder the notebook was opened from).
file_path = Path("../data/processed/processed_nj_statutes.json")

# Fail early with the absolute path so a wrong working directory is easy to spot.
if not file_path.exists():
    raise FileNotFoundError(f"Couldn't find file at: {file_path.resolve()}")

# Read explicitly as UTF-8: the platform default encoding (e.g. cp1252 on Windows) can
# mis-decode non-ASCII characters such as "§" in the statute text.
with file_path.open("r", encoding="utf-8") as f:
    statutes = json.load(f)

print("✅ File loaded successfully.")

In [None]:
from google import generativeai as genai
from dotenv import load_dotenv
import os
from typing import List, Sequence

# Load GOOGLE_API_KEY from a local .env file (if present) and configure the Gemini client.
load_dotenv()
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
    # Fail fast: without a key every embedding/generation call below fails, and those
    # failures would only surface as per-batch retry warnings or "[ERROR ...]" answer strings.
    raise RuntimeError(
        "GOOGLE_API_KEY is not set. Add it to your environment or to a .env file."
    )
genai.configure(api_key=api_key)

class GeminiEmbeddingFunction:
    """Chroma-compatible embedding function backed by Gemini's ``models/embedding-001``.

    Instances are callable with a sequence of strings (Chroma calls it with document
    batches on ``collection.add``) and embed each string with a separate
    ``genai.embed_content`` request.
    """

    def __call__(self, input: Sequence[str]) -> List[List[float]]:
        """Embed every string in ``input`` using the "retrieval_document" task type (indexing)."""
        vectors: List[List[float]] = []
        for passage in input:
            vectors.append(self.embed_text(passage, task_type="retrieval_document"))
        return vectors

    def embed_text(self, text: str, task_type="retrieval_document") -> list[float]:
        """Return the embedding vector for a single ``text``.

        Args:
            text: The string to embed.
            task_type: Gemini task type, e.g. "retrieval_document" for indexed chunks or
                "retrieval_query" for user questions.
        """
        result = genai.embed_content(
            model="models/embedding-001",
            content=text,
            task_type=task_type,
        )
        return result["embedding"]

In [5]:
import re

def clean_statute_text(text: str) -> str:
    """Normalize raw statute text before chunking.

    Steps:
      - Convert CRLF / bare CR line endings to LF so the newline rules apply uniformly.
      - Remove lines containing nothing but a number and a period (e.g. "12."),
        presumably leftover numbering from extraction.
      - Replace tabs with single spaces.
      - Collapse runs of blank or whitespace-only lines into a single newline, keeping
        the indentation of the following non-blank line.

    Args:
        text: Raw section text. ``None`` or ``""`` yields ``""``.

    Returns:
        The cleaned text with leading/trailing whitespace stripped.
    """
    if not text:
        return ""

    text = text.replace("\r\n", "\n").replace("\r", "\n")

    # [ \t] instead of \s so the match can never swallow neighbouring newlines.
    text = re.sub(r'^[ \t]*\d+\.[ \t]*$', '', text, flags=re.MULTILINE)

    text = text.replace("\t", " ")

    # A newline followed by one or more (optional spaces + newline) is a run of blank lines;
    # the trailing indentation of the next content line is not consumed.
    text = re.sub(r'\n(?:[ \t]*\n)+', '\n', text)
    return text.strip()


In [6]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document

# Chunking parameters. RecursiveCharacterTextSplitter measures length in characters
# (default length_function=len), so 2000 characters is roughly 500 tokens of English text.
CHUNK_SIZE = 2000     # max characters per chunk
CHUNK_OVERLAP = 100   # characters shared by consecutive chunks to preserve context

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)

documents = []

# Assumes each title dict has "title" and "sections", and each section has
# "text", "heading" and "section" keys (as read below) — TODO confirm against the JSON schema.
for title_entry in statutes:
    for section in title_entry["sections"]:
        cleaned_text = clean_statute_text(section["text"])
        heading = section["heading"]
        section_id = section["section"]

        # split_text yields plain strings; chunk_id is the chunk's position within its section.
        for chunk_id, chunk_text in enumerate(text_splitter.split_text(cleaned_text)):
            documents.append(
                Document(
                    page_content=chunk_text.strip(),
                    metadata={
                        "title": title_entry["title"],
                        "section": section_id,
                        "heading": heading,
                        "chunk_id": chunk_id,
                    },
                )
            )



In [None]:
# Spot-check the chunking by printing the first few chunks with their metadata.
# Printing every chunk can flood the notebook output and bloat the saved file.
PREVIEW_LIMIT = 5  # set to None to print every chunk

preview_docs = documents[:PREVIEW_LIMIT]
for i, doc in enumerate(preview_docs):
    meta = doc.metadata
    print(f"--- Chunk #{i+1} ---")
    print(f"Title:   {meta.get('title')}")
    print(f"Section: {meta.get('section')}")
    print(f"Heading: {meta.get('heading')}")
    print(f"Chunk ID: {meta.get('chunk_id')}\n")
    print(doc.page_content.strip())
    print("\n" + "="*80 + "\n")

hidden_count = len(documents) - len(preview_docs)
if hidden_count > 0:
    print(f"... {hidden_count} more chunks not shown (raise PREVIEW_LIMIT to see them).")


In [8]:

import chromadb
from chromadb.utils import embedding_functions

# On-disk Chroma store; get_or_create_collection reuses the collection if it already exists.
# Documents added to it are embedded with GeminiEmbeddingFunction.
CHROMA_PATH = "./chroma_db"
COLLECTION_NAME = "nj_statutes_test_chunks"

chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
gemini_embedder = GeminiEmbeddingFunction()
collection = chroma_client.get_or_create_collection(
    name=COLLECTION_NAME,
    embedding_function=gemini_embedder,
)

In [None]:
import time

from collections import Counter

# -------------------------------
# Configuration
# -------------------------------
BATCH_SIZE = 15             # Chunks per collection.upsert call
RETRY_LIMIT = 3             # Max attempts per batch
RETRY_DELAY = 5             # Seconds to wait between attempts


def _chunk_id(doc) -> str:
    """Chroma ID for a chunk: its section identifier plus its position within that section."""
    return f"statute_{doc.metadata['section']}_{doc.metadata['chunk_id']}"


def _index_batch(batch_ids, batch_texts, batch_metadatas, batch_num) -> bool:
    """Upsert one batch with retries; return True on success, False once RETRY_LIMIT is exhausted."""
    for attempt in range(1, RETRY_LIMIT + 1):
        try:
            # upsert (not add) so re-running this cell refreshes changed chunk text
            # instead of ignoring IDs that already exist in the persistent collection.
            collection.upsert(
                ids=batch_ids,
                documents=batch_texts,
                metadatas=batch_metadatas
            )
            return True
        except Exception as e:
            print(f"⚠️ Error on batch {batch_num} (attempt {attempt}): {e}")
            if attempt < RETRY_LIMIT:  # no point waiting after the final attempt
                time.sleep(RETRY_DELAY)
    return False


# -------------------------------
# Data Preparation
# -------------------------------
print("🔄 Preparing documents for indexing...")

document_ids = [_chunk_id(doc) for doc in documents]
document_texts = [doc.page_content for doc in documents]
document_metadatas = [doc.metadata for doc in documents]

# Colliding IDs (e.g. the same section number under two titles) would fail a whole batch
# or silently keep only one of the chunks, so stop before spending embedding quota.
duplicate_ids = sorted(doc_id for doc_id, n in Counter(document_ids).items() if n > 1)
if duplicate_ids:
    raise ValueError(
        f"{len(duplicate_ids)} chunk IDs are not unique, e.g. {duplicate_ids[:5]}"
    )

print(f"📄 Prepared {len(document_ids)} documents.")

# -------------------------------
# Batch Indexing with Retry Logic
# -------------------------------
print("\n🚀 Starting ChromaDB indexing...")

failed_batches = []
indexed_count = 0

for start in range(0, len(document_ids), BATCH_SIZE):
    end_idx = min(start + BATCH_SIZE, len(document_ids))
    batch_num = (start // BATCH_SIZE) + 1

    ok = _index_batch(
        document_ids[start:end_idx],
        document_texts[start:end_idx],
        document_metadatas[start:end_idx],
        batch_num,
    )
    if ok:
        indexed_count += end_idx - start
        print(f"✅ Indexed batch {batch_num}, documents {start + 1} to {end_idx}")
    else:
        failed_batches.append(batch_num)
        print(f"❌ Failed to index batch {batch_num} after {RETRY_LIMIT} attempts.")

# -------------------------------
# Completion
# -------------------------------
if failed_batches:
    print(
        f"\n⚠️ Indexed {indexed_count}/{len(document_ids)} documents; "
        f"failed batches: {failed_batches}"
    )
else:
    print(f"\n🎉 Successfully indexed all {len(document_ids)} documents.")

In [None]:
# Embed the user's question
def embed_query(text):
    """Embed a user question for similarity search against the indexed statute chunks.

    Delegates to ``GeminiEmbeddingFunction.embed_text`` so queries and indexed documents
    always use the same embedding model. Only the task type differs:
    "retrieval_query" here vs "retrieval_document" at indexing time.

    Returns:
        The embedding vector (list of floats) for ``text``.
    """
    return GeminiEmbeddingFunction().embed_text(text, task_type="retrieval_query")

# Retrieve from Chroma
def retrieve_context(question, k=5):
    """Return the raw Chroma query result for the ``k`` chunks closest to ``question``.

    The question is embedded here (query task type) and passed as ``query_embeddings``,
    so the collection's document embedding function is not used for the query.
    """
    return collection.query(
        query_embeddings=[embed_query(question)],
        n_results=k,
    )

# Format the context from Chroma results
def format_context(results, include_metadata=False):
    """Join retrieved chunk texts into a single context string for the LLM prompt.

    Args:
        results: A Chroma ``collection.query`` result for one query: a dict whose
            "documents" (and optionally "metadatas") values hold one inner list per query.
        include_metadata: If True, prefix each chunk with its title/section/heading so the
            model can cite the statute it relies on. Defaults to False (text only).

    Returns:
        The chunks joined by a "---" separator line; "" when nothing was retrieved.
        ``None`` chunk entries are skipped.
    """
    separator = "\n\n---\n\n"

    docs_by_query = results.get("documents") or []
    chunks = (docs_by_query[0] if docs_by_query else None) or []

    if not include_metadata:
        return separator.join(chunk for chunk in chunks if chunk is not None)

    metas_by_query = results.get("metadatas") or []
    metas = list((metas_by_query[0] if metas_by_query else None) or [])
    # Pad so every chunk has a (possibly missing) metadata entry.
    metas += [None] * (len(chunks) - len(metas))

    parts = []
    for chunk, meta in zip(chunks, metas):
        if chunk is None:
            continue
        if meta:
            header = (
                f"[Title: {meta.get('title')} | Section: {meta.get('section')}"
                f" | {meta.get('heading')}]"
            )
            parts.append(f"{header}\n{chunk}")
        else:
            parts.append(chunk)
    return separator.join(parts)

# Ask Gemini to answer using context
def ask_gemini(question, context, model_name="gemini-2.5-pro-exp-03-25", temperature=0.2):
    """Answer ``question`` with Gemini, grounding the answer in retrieved statute ``context``.

    Args:
        question: The user's question in plain language.
        context: Retrieved statute text (e.g. the output of ``format_context``).
        model_name: Gemini model ID. The default is the experimental model this notebook was
            written against; experimental IDs may be retired, so override it with a current
            model if calls start failing.
        temperature: Sampling temperature; kept low for factual, repeatable answers.

    Returns:
        The answer text. Failures are returned as strings starting with "[ERROR" rather
        than raised.
    """
    prompt = f"""
You are a legal assistant that answers questions about official New Jersey law.

Base your answer on the context provided. If the context does not contain the information needed, say so clearly, then answer using your best judgment and make explicit which parts are not drawn from the provided statutes.

Be detailed: this is a legal chatbot for the general public, so keep the tone natural and conversational.

Context:
{context}

Question:
{question}
"""

    try:
        model = genai.GenerativeModel(model_name)
        response = model.generate_content(
            prompt,
            generation_config=genai.GenerationConfig(temperature=temperature),
        )

        # No candidates usually means the prompt itself was blocked; prompt_feedback says why.
        if not response.candidates:
            return (
                "[ERROR: No response candidates returned.]\n"
                f"Details: {response.prompt_feedback}"
            )

        try:
            return response.text.strip()
        except ValueError:
            # .text raises when the candidate has no text parts (e.g. stopped by a safety filter).
            return (
                "[ERROR: Gemini returned no text.]\n"
                f"Finish reason: {response.candidates[0].finish_reason}"
            )

    except Exception as e:
        return f"[ERROR: Gemini call failed]\nDetails: {e}"

In [None]:
# Smoke test: answer one adoption question end to end (retrieve -> format -> generate).
question = "Who conducts the investigation in a private adoption case in NJ?"
results = retrieve_context(question)
context = format_context(results)
answer = ask_gemini(question, context)

for label, body in (("🔍 Context Used:\n", context), ("\n🧠 Gemini's Answer:\n", answer)):
    print(label, body)


In [None]:
import datetime

# Ask a fixed set of adoption questions and log the retrieved context and answers to a file.
questions = [
    "What happens if a child is not adopted through an approved agency in New Jersey?",
    "Who conducts the investigation in a private adoption case in NJ?",
    "Can a stepparent skip the agency investigation in an adoption case?",
    "What does the preliminary hearing determine in a New Jersey adoption case?",
    "When can the court immediately issue a judgment of adoption?",
    "What responsibilities does the agency have after the preliminary hearing?"
]

timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = Path("../logs")
# Create the log directory on first run; open() would otherwise raise FileNotFoundError.
log_dir.mkdir(parents=True, exist_ok=True)
# Dedicated name so the statutes `file_path` from the loading cell is not overwritten.
log_path = log_dir / f"gemini_adoption_test_{timestamp}.log"
log_file = str(log_path)

with log_path.open("w", encoding="utf-8") as f:
    for i, test_question in enumerate(questions, start=1):
        results = retrieve_context(test_question)
        context = format_context(results)
        response = ask_gemini(test_question, context)

        f.write(f"--- Test #{i} ---\n")
        f.write(f"Q: {test_question}\n\n")
        f.write(f"Context Used:\n{context}\n\n")
        f.write(f"Response:\n{response}\n")
        f.write("="*80 + "\n\n")

print(f"✅ All responses saved to '{log_file}'")