In [37]:
import os

# SECURITY: credentials must never be written into a notebook, because the
# file (and its history) is routinely shared and committed. Any key that was
# previously stored in this cell must be treated as leaked and revoked.
#
# The key is read from the GOOGLE_API_KEY environment variable; the value is
# None when the variable is not set, so export it before running the model cells.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")

In [38]:
from typing import Optional, List

from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pydantic import BaseModel, Field

In [39]:
class DataLoader:
    """Load a ``.pdf`` or ``.txt`` file through the matching LangChain loader.

    The loader is selected from the file extension, compared case-insensitively.
    """

    def __init__(self, file_path: str):
        """Remember the path; nothing is read until ``load`` is called."""
        self.path = file_path

    def _check_file_type(self) -> str:
        """Return the lower-cased extension of ``self.path`` without the dot.

        Only the final path component is inspected, so a dot in a directory
        name (``release.v2/LICENSE``) is not mistaken for an extension.
        Returns ``""`` when the file has no extension.
        """
        import os  # local import so this cell does not depend on another cell

        extension = os.path.splitext(self.path)[1]
        return extension.lstrip(".").lower()

    def load(self):
        """Load the file and return whatever the selected loader's ``load()`` returns.

        Raises:
            ValueError: if the extension is neither ``pdf`` nor ``txt``.
        """
        ext = self._check_file_type()

        if ext == "pdf":
            loader = PyPDFLoader(self.path)
        elif ext == "txt":
            loader = TextLoader(self.path, encoding="utf-8")
        else:
            raise ValueError(f"Unsupported file type: {ext or '(no extension)'}")

        return loader.load()


In [51]:
import sqlite3
import json
import uuid
from datetime import datetime

class DocumentDatabase:
    """SQLite-backed store for documents and the summaries generated for them.

    Each row of the ``documents`` table holds the raw text, a status string
    (``"processing"`` after insert, ``"completed"`` after update), the final
    summary, the risk flag, the per-chunk summaries serialised as JSON, and
    the chunk count. Timestamps are ISO-8601 strings produced by
    ``datetime.now()`` (naive local time).

    The instance can be used as a context manager so the connection is closed
    deterministically::

        with DocumentDatabase("documents.db") as db:
            doc_id = db.insert_document(text)
    """

    def __init__(self, db_name="documents.db"):
        """Open (or create) the database file and make sure the table exists."""
        self.conn = sqlite3.connect(db_name)
        self.create_table()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never swallow an exception raised inside the ``with`` body.
        return False

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()

    def create_table(self):
        """Create the ``documents`` table if it does not exist yet."""
        query = """
        CREATE TABLE IF NOT EXISTS documents (
            document_id TEXT PRIMARY KEY,
            created_at TEXT,
            updated_at TEXT,
            status TEXT,
            raw_text TEXT,
            final_summary TEXT,
            risk_flag TEXT,
            chunk_summaries TEXT,
            num_chunks INTEGER
        )
        """
        # ``with conn`` commits on success and rolls back if the statement fails.
        with self.conn:
            self.conn.execute(query)

    def insert_document(self, raw_text):
        """Insert a new document in ``"processing"`` state.

        Args:
            raw_text: full text of the document.

        Returns:
            The generated document id (a UUID4 string).
        """
        document_id = str(uuid.uuid4())
        now = datetime.now().isoformat()

        query = """
        INSERT INTO documents (
            document_id, created_at, updated_at, status,
            raw_text, final_summary, risk_flag,
            chunk_summaries, num_chunks
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        """

        with self.conn:
            self.conn.execute(query, (
                document_id,
                now,
                now,
                "processing",
                raw_text,
                None,   # final_summary: filled in by update_document
                None,   # risk_flag
                None,   # chunk_summaries
                0,      # num_chunks
            ))

        return document_id

    def update_document(self, document_id, final_summary,
                        risk_flag, chunk_summaries, num_chunks):
        """Store the summarisation results and mark the document ``"completed"``.

        Args:
            document_id: id returned by ``insert_document``.
            final_summary: summary text for the whole document.
            risk_flag: risk text, or ``None``.
            chunk_summaries: JSON-serialisable object; stored via ``json.dumps``.
            num_chunks: number of chunk summaries.

        Raises:
            KeyError: if no row with ``document_id`` exists. Without this check
                the UPDATE would match nothing and the results would be lost
                silently.
        """
        now = datetime.now().isoformat()

        query = """
        UPDATE documents
        SET updated_at = ?,
            status = ?,
            final_summary = ?,
            risk_flag = ?,
            chunk_summaries = ?,
            num_chunks = ?
        WHERE document_id = ?
        """

        with self.conn:
            cursor = self.conn.execute(query, (
                now,
                "completed",
                final_summary,
                risk_flag,
                json.dumps(chunk_summaries),
                num_chunks,
                document_id,
            ))

        if cursor.rowcount == 0:
            raise KeyError(f"Unknown document_id: {document_id}")


In [41]:
class TextSplitter:
    """Split document text into chunks for summarisation or for embedding."""

    def __init__(self, text: str):
        """Store the text to split.

        ``embed_split`` also accepts a sequence of already-loaded documents
        here, which is forwarded unchanged to ``split_documents``.
        """
        self.text = text

    def summary_split(self):
        """Return the text as a list of strings (chunk_size=2500, overlap=150)."""
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=2500,
            chunk_overlap=150
        )
        return splitter.split_text(self.text)

    def embed_split(self):
        """Return document chunks for embedding (chunk_size=1200, overlap=150).

        ``split_documents`` iterates over its argument and reads
        ``page_content`` from each item, so handing it a plain string would
        iterate over single characters and fail. A string is therefore wrapped
        into documents with ``create_documents``; any other input is assumed
        to already be a sequence of documents.
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1200,
            chunk_overlap=150
        )
        if isinstance(self.text, str):
            return splitter.create_documents([self.text])
        return splitter.split_documents(self.text)


In [42]:
from pydantic import BaseModel, Field
from typing import Optional

# Schema for the structured result of analysing a single text chunk.
# (Kept as comments rather than a class docstring so the generated JSON schema
# of the model is unchanged.)
class ChunkLegalResponse(BaseModel):
    # Summary of the chunk; required.
    Summary: str = Field(description="Detailed summary of this section")
    # Risky clause found in the chunk; None when nothing was flagged.
    Flag: Optional[str] = Field(description="Risky clause if present", default=None)


# Schema for the structured result of the document-level merge step.
# (Kept as comments rather than a class docstring so the generated JSON schema
# of the model only changes by the new default below.)
class FinalLegalResponse(BaseModel):
    # Whole-document summary; validation rejects anything under 300 characters.
    Summary: str = Field(
        min_length=300,
        description="Full multi-paragraph legal summary"
    )
    # An Optional annotation alone does not make a field optional: without a
    # default the key itself is still required. Defaulting to None lets a
    # response that reports no risk validate, matching ChunkLegalResponse.Flag.
    Flag: Optional[str] = Field(
        default=None,
        description="Most significant risky clause for bearer"
    )


In [43]:
class ChatModel:
    """Factory for the Google Generative AI chat and embedding clients.

    Both factory methods build a new client on every call using the API key
    given at construction time.
    """

    def __init__(self, api_key: str):
        """Keep the API key that is handed to every client this factory builds."""
        self.api = api_key

    def get_chat_model(self):
        """Return a ``ChatGoogleGenerativeAI`` client for ``gemini-2.5-pro``.

        Configured with temperature 0.4 and ``max_tokens=2000``.
        """
        return ChatGoogleGenerativeAI(
            model="gemini-2.5-pro",
            google_api_key=self.api,
            temperature=0.4,
            max_tokens=2000,
        )

    def get_embed_model(self):
        """Return a ``GoogleGenerativeAIEmbeddings`` client for ``gemini-embedding-001``.

        ``output_dimensionality=512`` is passed through to the client.
        The credential uses the same ``google_api_key`` keyword as
        ``get_chat_model`` so both clients are configured identically.
        """
        return GoogleGenerativeAIEmbeddings(
            model="gemini-embedding-001",
            google_api_key=self.api,
            output_dimensionality=512,
        )


In [44]:
from langchain_core.prompts import ChatPromptTemplate

# System instructions for analysing one text segment in isolation.
_CHUNK_SYSTEM_MESSAGE = """
You are a legal document analysis assistant.

Analyze ONLY the provided text segment.

Tasks:
1. Produce a detailed legal summary of this segment.
2. Identify any clause that may be risky or harmful to the bearer.

Rules:
- Do not infer beyond the given text.
- If no risky clause exists, explicitly say so.
- Preserve exact legal language where relevant.
"""

# Per-chunk prompt: the segment is supplied through the {text} placeholder.
chunk_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", _CHUNK_SYSTEM_MESSAGE),
        ("human", "{text}"),
    ]
)


In [45]:
# System instructions for combining per-section results into one answer.
_MERGE_SYSTEM_MESSAGE = """
You are a senior legal analyst.

You are given multiple legal summaries and identified risks
from different sections of the same document.

Tasks:
1. Produce a cohesive, multi-paragraph legal summary of the ENTIRE document.
2. Identify the SINGLE most significant risky clause affecting the bearer.
3. If multiple risks exist, choose the most severe.
4. If no risks exist, state this clearly.

Return output strictly in the required JSON format.
"""

# Merge prompt: the joined section summaries are supplied through {text}.
merge_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", _MERGE_SYSTEM_MESSAGE),
        ("human", "{text}"),
    ]
)


In [46]:
def cap_chunks(chunks: List[str], max_chunks: int = 9) -> List[str]:
    """
    Limit the number of chunks to at most max_chunks
    by evenly sampling across the document.

    When sampling is needed, the selected positions are spread from the first
    chunk to the last chunk inclusive, so the end of the document is always
    represented. (A floor-division step followed by a ``[:max_chunks]`` slice
    would instead drop the tail: 17 chunks capped at 9 would keep only the
    first 9.)

    Args:
        chunks: ordered chunks of one document.
        max_chunks: maximum number of chunks to keep; must be >= 1.

    Returns:
        ``chunks`` itself when it is already short enough, otherwise a new
        list of exactly ``max_chunks`` chunks in document order.

    Raises:
        ValueError: if ``max_chunks`` is smaller than 1.
    """
    if max_chunks < 1:
        raise ValueError(f"max_chunks must be at least 1, got {max_chunks}")

    total = len(chunks)
    if total <= max_chunks:
        return chunks

    if max_chunks == 1:
        # A single slot cannot span first-to-last; keep the opening chunk.
        return [chunks[0]]

    # total > max_chunks guarantees a spacing above 1, so the floored
    # positions are strictly increasing (no chunk is picked twice).
    last_index = total - 1
    slots = max_chunks - 1
    return [chunks[(slot * last_index) // slots] for slot in range(max_chunks)]


def truncate_text(text: str, max_chars: int = 8000) -> str:
    """
    Bound the size of ``text`` before it is sent to the LLM.

    Text no longer than ``max_chars`` is returned as-is; anything longer is
    cut down to the slice ``text[:max_chars]``.
    """
    fits = len(text) <= max_chars
    return text if fits else text[:max_chars]


In [49]:
class LegalSummarizer:
    """Summarise a legal document in two passes.

    Pass 1 analyses each (capped) text chunk on its own; pass 2 merges the
    per-chunk results into one document-level summary and risk flag.
    """

    def __init__(self, model, text: str):
        """
        Args:
            model: chat model exposing ``with_structured_output``.
            text: full text of the document to summarise.
        """
        self.base_model = model
        self.text = text

    def summarize(self) -> dict:
        """Run both passes and return the results as a plain dict.

        Returns:
            A dict with the keys ``"final_summary"``, ``"risk_flag"`` and
            ``"chunk_summaries"``; the latter is a list with one
            ``{"summary": ..., "risk": ...}`` dict per analysed chunk.

        Raises:
            ValueError: if the text yields no chunks. Without this guard the
                merge step would be invoked with an empty input.
        """
        chunks = cap_chunks(TextSplitter(self.text).summary_split(), max_chunks=9)
        if not chunks:
            raise ValueError(
                "Cannot summarize an empty document: no text chunks were produced"
            )

        # Build the chain once; it is identical for every chunk.
        chunk_chain = chunk_prompt | self.base_model.with_structured_output(
            ChunkLegalResponse
        )
        chunk_results = [chunk_chain.invoke({"text": chunk}) for chunk in chunks]

        merged_input = "\n\n".join(
            f"SUMMARY:\n{r.Summary}\nRISK:\n{r.Flag or 'None'}"
            for r in chunk_results
        )
        # Bound the size of the merge input.
        merged_input = truncate_text(merged_input, max_chars=8000)

        final_chain = merge_prompt | self.base_model.with_structured_output(
            FinalLegalResponse
        )
        final_result = final_chain.invoke({"text": merged_input})

        return {
            "final_summary": final_result.Summary,
            "risk_flag": final_result.Flag,
            "chunk_summaries": [
                {"summary": r.Summary, "risk": r.Flag}
                for r in chunk_results
            ],
        }




In [50]:
# Open (or create) the SQLite store that tracks processed documents.
db = DocumentDatabase()

# Load the sample document and join its parts into one string.
docs = DataLoader("sample-doc/sample_rent_doc.txt").load()
text = "\n".join(part.page_content for part in docs)

# Register the document first; the row starts with status "processing".
document_id = db.insert_document(text)

model = ChatModel(api_key=GOOGLE_API_KEY).get_chat_model()
summarizer = LegalSummarizer(model, text)
result = summarizer.summarize()

# Number each chunk summary; the "risk" key is kept only when a risk was reported.
chunk_summaries = [
    {
        "chunk_index": position,
        "summary": item["summary"],
        **({"risk": item["risk"]} if item.get("risk") else {}),
    }
    for position, item in enumerate(result["chunk_summaries"])
]

# Persist the results; this switches the row's status to "completed".
db.update_document(
    document_id=document_id,
    final_summary=result["final_summary"],
    risk_flag=result["risk_flag"],
    chunk_summaries=chunk_summaries,
    num_chunks=len(chunk_summaries),
)

print("Stored Document ID:", document_id)


Stored Document ID: 2177bd3b-272c-468b-a72f-c7487fad5d3e


  now = datetime.utcnow().isoformat()
