In [None]:
# !pip install -U langchain-community pypdf transformers nltk sentence_transformers faiss-cpu numpy langchain_groq gradio neo4j

In [None]:
# !pip install chromadb

In [None]:
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Input PDF file(s) to ingest.
pdf_paths = [
    "/BNS.pdf",
]

class PDFFunctions():
    """Turn PDF files into overlapping text chunks for LLM extraction."""

    def __init__(self):
        pass

    def pdf_to_chunks(self, pdf_paths, chunk_size=2000, chunk_overlap=300):
        """
        Load every page of each PDF and split the pages into chunks.

        Args:
            pdf_paths: a single path (str or path-like) or an iterable of paths.
            chunk_size: target chunk length, passed to RecursiveCharacterTextSplitter.
            chunk_overlap: overlap between consecutive chunks, same units as chunk_size.

        Returns:
            List of chunked documents, in input-file order.
        """
        # A bare path string is itself iterable; without this wrap each
        # character would be treated as a separate file path.
        if isinstance(pdf_paths, str) or hasattr(pdf_paths, "__fspath__"):
            pdf_paths = [pdf_paths]

        # Load: PDF -> per-page documents.
        documents = []
        for path in pdf_paths:
            documents.extend(PyPDFLoader(path).load())

        # Split the pages into overlapping chunks (text splitting, not tokenization).
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        return splitter.split_documents(documents)

In [None]:
# from sentence_transformers import SentenceTransformer
# from chromadb import PersistentClient

# class Vectorize():
#     def __init__(self, persist_directory="./chroma_db", collection_name="my_collection"):
#         # Load embedding model once
#         self.model = SentenceTransformer('all-MiniLM-L6-v2')

#         # New ChromaDB persistent client
#         self.client = PersistentClient(path=persist_directory)

#         # Create or get the collection
#         self.collection = self.client.get_or_create_collection(name=collection_name)

#     def documents_to_vector(self, documents):
#         texts = [doc.page_content for doc in documents]
#         ids = [f"doc_{i}" for i in range(len(texts))]

#         embeddings = self.model.encode(texts, convert_to_numpy=True).tolist()

#         self.collection.add(
#             documents=texts,
#             embeddings=embeddings,
#             ids=ids
#         )

#         return self.collection

#     def top_k(self, query, k=5):
#         results = self.collection.query(
#             query_texts=[query],
#             n_results=k
#         )
#         top_texts = results.get('documents', [[]])[0]
#         return "\n".join(top_texts)


In [None]:
import ast
import re
import time
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate

class Inference:
    """
    LLM helper around a Groq chat model.

    - extract_custom_tuples: batched extraction of
      (offence, chapter, section, punishment_clause) tuples from document chunks.
    - save_tuples_to_file / load_tuples_from_file: persist tuples as a Python list literal.
    - answer_user_question: answer a query from retrieved knowledge-graph context.
    """

    def __init__(self, model="deepseek-r1-distill-llama-70b", api_key="", batch_size=10):
        """
        Args:
            model: Groq model name passed to ChatGroq.
            api_key: Groq API key. When empty, the GROQ_API_KEY environment
                variable is used so the key need not be pasted into the notebook.
            batch_size: number of chunks combined into one extraction request (>= 1).

        Raises:
            ValueError: if batch_size < 1.
        """
        import os  # local import keeps this cell self-contained

        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.llm = ChatGroq(model=model, api_key=api_key or os.environ.get("GROQ_API_KEY", ""))
        self.batch_size = batch_size
        self.prompt = ChatPromptTemplate.from_template("""
Extract legal information tuples from the following text.
Return ONLY a valid Python list of tuples in this format:

(offence, chapter, section, punishment_clause)

Here are some detailed examples:

Text:
\"\"\"
1. Offence: Theft; Chapter: 5; Section: 378; Punishment: Imprisonment up to 3 years or fine, or both.
2. Offence: Criminal Breach of Trust; Chapter: 7; Section: 405; Punishment: Imprisonment for up to 2 years, or fine, or both.
\"\"\"
Output:
[
    ("Theft", "5", "378", "Imprisonment up to 3 years or fine, or both"),
    ("Criminal Breach of Trust", "7", "405", "Imprisonment for up to 2 years, or fine, or both")
]

Text:
\"\"\"
The offence of Murder falls under Chapter 6, Section 302. The punishment prescribed is death penalty or life imprisonment.
Attempt to murder is dealt with in Chapter 6, Section 307, punishable by imprisonment for up to 10 years and fine.
\"\"\"
Output:
[
    ("Murder", "6", "302", "Death penalty or life imprisonment"),
    ("Attempt to murder", "6", "307", "Imprisonment for up to 10 years and fine")
]

Text:
\"\"\"
Offence: Cheating; Chapter: 8; Section: 420; Punishment: Imprisonment which may extend to seven years and fine.
\"\"\"
Output:
[
    ("Cheating", "8", "420", "Imprisonment which may extend to seven years and fine")
]

Now extract from the following text:

\"\"\"{input_text}\"\"\"
""")

    def _safe_llm_invoke(self, prompt, retries=3, backoff=1.0):
        """
        Invoke the LLM, retrying on failure with exponential backoff.

        Args:
            prompt: anything accepted by self.llm.invoke (messages or a string).
            retries: total number of attempts.
            backoff: seconds to wait before the second attempt; doubles each retry.

        Returns:
            The model response, or None if every attempt failed. The last error
            is printed instead of being silently discarded.
        """
        last_error = None
        for attempt in range(retries):
            try:
                return self.llm.invoke(prompt)
            except Exception as exc:  # broad on purpose: client/transport errors vary by type
                last_error = exc
                if attempt < retries - 1:  # no pointless sleep after the final attempt
                    time.sleep(backoff * 2 ** attempt)
        print(f"[Warning] LLM call failed after {retries} attempt(s): {last_error!r}")
        return None

    @staticmethod
    def _iter_list_literals(text):
        """
        Yield each outermost "[...]" span in text, ignoring brackets that sit
        inside quoted strings, so that parentheses or brackets inside a
        punishment clause do not cut a list short.
        """
        depth = 0
        start = 0
        quote = None
        escaped = False
        for pos, ch in enumerate(text):
            if quote is not None:
                if escaped:
                    escaped = False
                elif ch == "\n":
                    # Single-line string literals cannot contain a raw newline, so a
                    # stray apostrophe was mis-tracked: drop the partial span and resync.
                    quote, depth = None, 0
                elif ch == "\\":
                    escaped = True
                elif ch == quote:
                    quote = None
                continue
            if ch in ("'", '"'):
                # Quotes only matter inside a candidate list; apostrophes in the
                # surrounding prose would otherwise desynchronise the scan.
                if depth:
                    quote = ch
            elif ch == "[":
                if depth == 0:
                    start = pos
                depth += 1
            elif ch == "]" and depth:
                depth -= 1
                if depth == 0:
                    yield text[start:pos + 1]

    @classmethod
    def _parse_tuples(cls, output_str):
        """
        Extract every well-formed list of 4-tuples from raw LLM output.

        Returns:
            A list of tuples whose fields are converted with str() and stripped,
            so "5" and 5 deduplicate together.
        """
        # Reasoning models (e.g. deepseek-r1) may emit <think> drafts; lists there are not answers.
        visible = re.sub(r"<think>.*?</think>", "", output_str, flags=re.DOTALL)
        found = []
        for literal in cls._iter_list_literals(visible):
            try:
                parsed = ast.literal_eval(literal)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                continue  # bracketed prose such as "[CHUNK 3]", or malformed output
            if isinstance(parsed, list) and all(isinstance(t, tuple) and len(t) == 4 for t in parsed):
                found.extend(tuple(str(field).strip() for field in t) for t in parsed)
        return found

    def extract_custom_tuples(self, chunks):
        """
        Run batched LLM extraction over chunks and return the unique tuples found.

        Args:
            chunks: sequence of objects exposing `.page_content` (e.g. the output of
                PDFFunctions.pdf_to_chunks).

        Returns:
            List of (offence, chapter, section, punishment_clause) string tuples,
            de-duplicated in first-seen order so repeated runs write the same file.
        """
        all_tuples = []
        total_batches = (len(chunks) + self.batch_size - 1) // self.batch_size

        for batch_idx in range(0, len(chunks), self.batch_size):
            batch_no = batch_idx // self.batch_size + 1
            batch_chunks = chunks[batch_idx:batch_idx + self.batch_size]

            # Label chunks with their global (1-based) index.
            combined_text = "\n\n".join(
                f"[CHUNK {batch_idx + i + 1}]\n{chunk.page_content}"
                for i, chunk in enumerate(batch_chunks)
            )

            formatted_prompt = self.prompt.format_prompt(input_text=combined_text).to_messages()
            response = self._safe_llm_invoke(formatted_prompt)
            if response is None:
                print(f"[Warning] Batch {batch_no}/{total_batches} skipped: no LLM response")
                continue

            batch_tuples = self._parse_tuples(str(response.content or ""))
            all_tuples.extend(batch_tuples)
            print(f"Batch {batch_no}/{total_batches} finished ({len(batch_tuples)} tuples)")

        # dict keeps insertion order (unlike set), giving deterministic output.
        return list(dict.fromkeys(all_tuples))

    def save_tuples_to_file(self, tuples, filename="graphData.txt"):
        """Write tuples as a Python list literal, one tuple per line."""
        with open(filename, 'w', encoding='utf-8') as f:
            f.write("[\n")
            for t in tuples:
                f.write(f"    {repr(t)},\n")
            f.write("]\n")

    def load_tuples_from_file(self, filename="graphData.txt"):
        """
        Load list of tuples from a file containing a Python list literal.

        Returns [] (with a printed message) when the content cannot be parsed
        or is not a list of tuples.
        """
        with open(filename, 'r', encoding='utf-8') as f:
            content = f.read()
        try:
            tuples = ast.literal_eval(content)
        except Exception as e:
            print(f"[Error] Failed to parse file '{filename}': {e}")
            return []
        if isinstance(tuples, list) and all(isinstance(t, tuple) for t in tuples):
            return tuples
        print("[Warning] File content is not a list of tuples.")
        return []

    def answer_user_question(self, context_for_llm, matched_node):
        """
        Given legal context and a matched node (offense), prompt the LLM to extract
        relevant legal references: chapter, section, and punishment clause from the BNS.

        Args:
            context_for_llm: retrieved graph context text.
            matched_node: the node name, or the (name, score) pair returned by
                Graphclass.find_most_similar_node (only the name is used).

        Returns:
            The raw LLM response text (may include <think> sections).

        Raises:
            RuntimeError: if the LLM call fails on every retry.
        """
        # Unwrap (name, score) so the prompt never contains a tuple repr.
        if isinstance(matched_node, tuple) and matched_node:
            matched_node = matched_node[0]

        prompt = f"""
    You are a legal assistant AI trained to interpret and extract structured legal information from the Bharatiya Nyaya Sanhita (BNS) based on a legal knowledge graph.

    Given the retrieved legal graph context and the user's target offense or concept, return the relevant:
    - **Chapter Number and Name**
    - **Section Number and Title**
    - **Punishment Clause(s)**
    - **Any Directly Related Offenses**

    Ensure the answer is structured, cited exactly as in BNS, and only use information from the retrieved context.

    ---

    ### Example 1:

    **Matched Node:** Theft

    **Retrieved Context:**
    Chapter XVII – Of Offenses Against Property
    Section 303 – Theft
    Whoever commits theft shall be punished with imprisonment of either description for a term which may extend to three years, or with fine, or with both.

    **Answer:**
    - Chapter: Chapter XVII – Of Offenses Against Property
    - Section: Section 303 – Theft
    - Punishment: Imprisonment up to 3 years, or fine, or both
    - Related Offenses: Attempt to commit theft, aggravated theft under Section 304

    ---

    ### Example 2:

    **Matched Node:** Dacoity

    **Retrieved Context:**
    Chapter XVII – Of Offenses Against Property
    Section 310 – Dacoity
    When five or more persons conjointly commit or attempt to commit a robbery, it is called “dacoity”. Punishment is imprisonment for life, or rigorous imprisonment for not less than 10 years.

    **Answer:**
    - Chapter: Chapter XVII – Of Offenses Against Property
    - Section: Section 310 – Dacoity
    - Punishment: Life imprisonment or rigorous imprisonment of not less than 10 years
    - Related Offenses: Robbery (Section 309), preparation to commit dacoity (Section 311)

    ---

    ### User Query:

    **Matched Node:** {matched_node}

    **Retrieved Context:**
    {context_for_llm}

    ---

    ### Answer:
    """

        # Same retry policy as the extraction path.
        response = self._safe_llm_invoke(prompt)
        if response is None:
            raise RuntimeError("LLM call failed while answering the user question")
        return response.content


In [None]:
from neo4j import GraphDatabase
import re
from sentence_transformers import SentenceTransformer
from scipy.spatial.distance import cosine
import numpy as np

class Graphclass:
    """
    Neo4j-backed legal knowledge graph.

    Builds Offense -> Chapter / Section / Punishment nodes from extracted tuples,
    matches free text to the closest named node via sentence embeddings, and
    gathers that node's neighbourhood as LLM context.
    """

    def __init__(self, database="neo4j", uri=None, user=None, password=None):
        """
        Connect to Neo4j and load the embedding model.

        Args:
            database: Neo4j database name used for every session.
            uri, user, password: connection settings; when omitted they are read
                from NEO4J_URI / NEO4J_USERNAME / NEO4J_PASSWORD so credentials
                are not hard-coded in the notebook.
        """
        import os  # local import keeps this cell self-contained

        self.uri = uri or os.environ.get("NEO4J_URI", "")
        self.user = user or os.environ.get("NEO4J_USERNAME", "")
        self.password = password or os.environ.get("NEO4J_PASSWORD", "")
        self.database = database
        self.driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password))

        # Initialize embedding model once
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

        # Lazily populated cache of node names and their embeddings
        self.node_names = []
        self.node_embeddings = None

    def close(self):
        """Close the underlying Neo4j driver."""
        self.driver.close()

    def create_knowledge_graph(self, tuples):
        """
        Merge every (offence, chapter, section, punishment_clause) tuple into the graph.

        Fields are converted with str() and stripped. A malformed tuple (wrong
        arity, not iterable) or a failed write is reported and skipped instead
        of aborting the whole build.
        """
        with self.driver.session(database=self.database) as session:
            for i, item in enumerate(tuples):
                try:
                    # Unpack inside the try so a bad tuple cannot abort the loop.
                    offence, chapter, section, punishment = (str(field).strip() for field in item)
                    session.execute_write(
                        self._create_offense_graph,
                        offence, chapter, section, punishment
                    )
                except Exception as e:
                    print(f"[Warning] Failed to create graph for tuple {i}: {item}")
                    print("  Reason:", e)
        # The node set may have changed: drop the cached names/embeddings.
        self.node_names = []
        self.node_embeddings = None
        print("Graph creation complete.")

    @staticmethod
    def _create_offense_graph(tx, offence, chapter, section, punishment):
        """MERGE one offence with its chapter, section and punishment nodes."""
        query = """
  MERGE (o:Offense {name: $offence})
  MERGE (c:Chapter {name: 'Chapter No.: ' + $chapter})
  MERGE (s:Section {number: 'Section No.: ' + $section})
  MERGE (p:Punishment {description: $punishment})

  MERGE (o)-[:refersToChapter]->(c)
  MERGE (o)-[:refersToSection]->(s)
  MERGE (o)-[:hasPunishment]->(p)
  """
        tx.run(query, offence=offence, chapter=chapter, section=section, punishment=punishment)

    def _fetch_all_node_names(self):
        """Fetch distinct node names from the Neo4j graph."""
        def fetch_all_node_names(tx):
            query = """
            MATCH (n)
            WHERE n.name IS NOT NULL
            RETURN DISTINCT n.name AS name
            """
            result = tx.run(query)
            return [record["name"] for record in result]

        with self.driver.session(database=self.database) as session:
            node_names = session.execute_read(fetch_all_node_names)

        # Convert all to strings to be safe
        self.node_names = [str(name) for name in node_names]

    def _encode_node_names(self):
        """Encode all node names to embeddings; leaves the cache empty if there are none."""
        if not self.node_names:
            self._fetch_all_node_names()
        if not self.node_names:
            # Nothing to embed; keep the cache unset so the next query refetches.
            self.node_embeddings = None
            return
        self.node_embeddings = self.model.encode(self.node_names, convert_to_numpy=True)

    def find_most_similar_node(self, input_text):
        """
        Given input_text, find the most similar node name in the graph.

        Returns:
            (best_node_name, cosine_similarity) with the score as a Python float.

        Raises:
            ValueError: if the graph has no named nodes to compare against.
        """
        if self.node_embeddings is None:
            self._encode_node_names()
        if self.node_embeddings is None or len(self.node_embeddings) == 0:
            raise ValueError("The knowledge graph has no named nodes to match against.")

        query_vec = np.asarray(self.model.encode([input_text], convert_to_numpy=True)[0], dtype=float)
        node_matrix = np.asarray(self.node_embeddings, dtype=float)

        # Cosine similarity against all nodes at once (== 1 - scipy cosine distance).
        with np.errstate(divide="ignore", invalid="ignore"):
            norms = np.linalg.norm(node_matrix, axis=1) * np.linalg.norm(query_vec)
            similarities = (node_matrix @ query_vec) / norms

        best_idx = int(np.argmax(similarities))
        return self.node_names[best_idx], float(similarities[best_idx])

    def fetch_related_info(self, node_name):
        """
        Fetch nodes 1 to 2 hops from any node whose `name` or `number`
        contains node_name (case-insensitive).

        Returns:
            A list of dicts: [{"info": str, "labels": list}, ...]
        """
        def fetch_related_info_tx(tx, node_name):
            query = """
          MATCH (n)
          WHERE toLower(n.name) CONTAINS toLower($node_name)
            OR toLower(n.number) CONTAINS toLower($node_name)
          WITH n
          MATCH (n)-[*1..2]-(related)
          RETURN DISTINCT coalesce(related.name, related.number, related.description, '') AS info,
                          labels(related) AS labels
          """
            result = tx.run(query, node_name=node_name)
            return [{"info": record["info"], "labels": record["labels"]} for record in result]

        # Use the configured database like the other methods, and the
        # execute_read API (read_transaction is deprecated in the 5.x driver).
        with self.driver.session(database=self.database) as session:
            return session.execute_read(fetch_related_info_tx, node_name)

    def get_context_text_for_llm(self, node_name):
        """
        Given a node name, fetch related info and combine their 'info' texts into
        one newline-separated string suitable as LLM context.

        Returns a "No context found..." / "No usable info..." placeholder string
        when nothing useful is retrieved.
        """
        related_infos = self.fetch_related_info(node_name)

        if not related_infos:
            print(f"[DEBUG] No related info found for node: {node_name}")
            return f"No context found for: {node_name}"

        context_texts = [
            str(item["info"])
            for item in related_infos
            if item["info"] and str(item["info"]).strip() != ""
        ]

        if not context_texts:
            print(f"[DEBUG] Related nodes found but no valid 'info' properties for: {node_name}")
            return f"No usable info for: {node_name}"

        return "\n".join(context_texts)


In [None]:
# Extract (offence, chapter, section, punishment) tuples once and cache them on disk.
# The LLM pass over every chunk is slow and costs API calls, so reruns reuse the file.
from pathlib import Path

TUPLES_FILE = "graphData.txt"
FORCE_REEXTRACT = False  # set True to redo the LLM extraction even if TUPLES_FILE exists

PDF = PDFFunctions()
infer = Inference()

tuples = []
if Path(TUPLES_FILE).exists() and not FORCE_REEXTRACT:
    tuples = infer.load_tuples_from_file(TUPLES_FILE)
    print(f"Loaded {len(tuples)} cached tuples from {TUPLES_FILE}")

if not tuples:
    # chunking
    chunks = PDF.pdf_to_chunks(pdf_paths)
    # get tuples
    tuples = infer.extract_custom_tuples(chunks)
    # store tuples
    infer.save_tuples_to_file(tuples, TUPLES_FILE)




In [None]:
# Build the Neo4j knowledge graph from the extracted tuples
# (all writes are MERGEs, so re-running does not duplicate nodes).
graph = Graphclass()
graph.create_knowledge_graph(tuples=tuples)


In [None]:
import re

def split_think_sections(llm_output):
    """
    Separate a reasoning model's <think> sections from its visible answer.

    Handles the following cases:
    - paired <think>...</think> blocks (possibly several),
    - a dangling "</think>" with no opener (templates that pre-fill "<think>"),
      where everything before it is treated as reasoning,
    - an unclosed trailing "<think>" (truncated output), where everything
      after it is treated as reasoning.

    Returns:
        (think_text, non_think_text), both stripped. Parts of think_text are
        joined by blank lines. ("", "") is returned for None or empty input.
    """
    if not llm_output:
        return "", ""

    # Paired blocks first.
    think_parts = re.findall(r"<think>(.*?)</think>", llm_output, re.DOTALL)
    remainder = re.sub(r"<think>.*?</think>", "", llm_output, flags=re.DOTALL)

    # Dangling closer: the reasoning started before the visible output.
    if "</think>" in remainder:
        head, _, remainder = remainder.rpartition("</think>")
        think_parts.append(head)

    # Unclosed opener: the output was cut off mid-reasoning.
    if "<think>" in remainder:
        remainder, _, tail = remainder.partition("<think>")
        think_parts.append(tail)

    return "\n\n".join(think_parts).strip(), remainder.strip()

In [None]:
# graph2 = Graphclass()
# context_for_llm = graph2.get_context_text_for_llm("Attempt to commit culpable homicide")
# print("--->>>", context_for_llm)

In [None]:
import gradio as gr
import traceback

def process_input_stream(news_article):
    """
    Streaming Gradio handler that maps a news article to BNS provisions.

    Steps:
    1. Embed the article and pick the closest graph node.
    2. Pull that node's graph context.
    3. Ask the LLM for chapter, section and punishment details.

    Yields:
        (input_value, logs, answer_markdown) after each step. input_value is
        always "", and answer_markdown stays "" until the final step. Errors are
        reported in the logs rather than raised.
    """
    logs = ""
    try:
        logs += "Starting process...\n"
        yield "", logs, ""

        # Guard: blank input would still "match" some arbitrary node.
        if not news_article or not news_article.strip():
            logs += "No article text provided; nothing to analyze.\n"
            yield "", logs, ""
            return

        # Step 1: Find most similar node
        logs += "Finding most similar node...\n"
        yield "", logs, ""
        node_name, similarity_score = graph.find_most_similar_node(news_article)
        logs += f"Similar Node Found: {node_name} (Score: {similarity_score:.2f})\n"
        yield "", logs, ""

        # Step 2: Get context for LLM
        logs += "Fetching context for LLM...\n"
        yield "", logs, ""
        context = graph.get_context_text_for_llm(node_name)
        logs += "Context fetched successfully.\n"
        logs += "---------- Retrieved Context ----------\n"
        logs += context + "\n"
        logs += "---------------------------------------\n"
        yield "", logs, ""

        # Step 3: Call inference engine. Pass only the node name: the
        # (name, score) pair would leak a tuple repr into the prompt.
        logs += "Running inference...\n"
        yield "", logs, ""
        answer = infer.answer_user_question(context, node_name)
        _, answer = split_think_sections(answer)  # drop the model's <think> reasoning
        logs += "Inference completed.\n"
        yield "", logs, ""

        # Final output
        output = f"**Answer:**\n{answer}"
        logs += "Done.\n"
        yield "", logs, output

    except Exception as e:
        error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
        logs += error_msg + "\n"
        yield "", logs, ""

# Gradio UI: article in; streamed step logs and the final LLM interpretation out.
with gr.Blocks(title="Legal Assistant: News Article to Law Interpretation") as demo:
    # Input section
    gr.Markdown("## 📰 News Article Input")
    input_box = gr.Textbox(label="Enter News Article", lines=6)

    # Read-only progress log, refreshed by each yield of process_input_stream
    gr.Markdown("## ⚙️ Under the Hood (Status for Nerds)")
    logs_box = gr.Textbox(
        lines=18,
        show_copy_button=True,
        interactive=False,
    )

    # Final answer, rendered as Markdown
    gr.Markdown("## 🧠 Output (LLM Interpretation)")
    output_box = gr.Markdown()

    submit_btn = gr.Button("🔍 Analyze")

    # Output order must match the handler's (input, logs, answer) yields;
    # concurrency_limit=1 runs one analysis at a time for this event.
    submit_btn.click(
        fn=process_input_stream,
        inputs=input_box,
        outputs=[input_box, logs_box, output_box],
        concurrency_limit=1,
    )

demo.launch()