# Section 1: Preprocessing
This section involves all the steps related to data cleaning, formatting, and preparing the dataset.
The goal is to ensure that the input data is in the right format for the model to process.

In [None]:
#!pip install faiss-gpu
#!pip install pdfplumber
#!pip install transformers==4.44.0
#!pip install accelerate==0.33.0

import json
from typing import List, Dict, Tuple
from tqdm import tqdm
import os
import re
from transformers import AutoTokenizer, AutoModel
import torch
from transformers import pipeline
from datasets import Dataset
import faiss
import pdfplumber
from datetime import datetime


# Read the text of every page of a PDF, skipping pages that yield no text
def extract_and_clean_text_from_pdf(pdf_path):
    """Return the extracted text of each page of the PDF at ``pdf_path``.

    Pages for which ``extract_text()`` gives back nothing (``None`` or an
    empty string) are left out. Despite the function's name, no cleaning is
    applied here: the page texts are returned exactly as extracted.
    """
    with pdfplumber.open(pdf_path) as document:
        page_texts = (page.extract_text() for page in document.pages)
        # Materialise the list inside the ``with`` block, while the PDF is open.
        return [page_text for page_text in page_texts if page_text]

# Function to further clean the extracted text
def clean_extracted_text(raw_text):
    """Normalise every extracted text and return the cleaned strings.

    Each string in ``raw_text`` goes through these steps, in order:

    1. caret-style superscripts such as ``10^2`` are removed;
    2. runs of whitespace are collapsed into a single space;
    3. every character outside the allow-list of the character filter
       (ASCII letters, digits, whitespace and a few punctuation marks)
       is deleted;
    4. words broken by a hyphen followed by whitespace are re-joined.

    Args:
        raw_text: iterable of strings (for example one string per PDF page).

    Returns:
        A list with one cleaned string per input string, in the same order.
    """
    cleaned_text = []
    for text in raw_text:
        # Superscripts must be removed BEFORE the character filter: the filter
        # deletes '^', after which this pattern could never match.
        text = re.sub(r'\d+\^\d+', '', text)  # Remove superscript numbers
        text = re.sub(r'\s+', ' ', text)  # Replace whitespace runs with a single space
        text = re.sub(r'[^A-Za-z0-9.,;:()\\/\-\s]', '', text)  # Drop characters outside the allow-list
        text = re.sub(r'(\w)-\s+(\w)', r'\1\2', text)  # Fix words broken across a hyphen
        cleaned_text.append(text)
    return cleaned_text

# Cut the cleaned texts into sections at headings and numbered list points
def split_into_sections(cleaned_text):
    """Split each text of ``cleaned_text`` into sections and return them all.

    Every text is split on ``'. '`` into fragments. A fragment that starts
    with a heading (CHAPTER / Section / SCHEDULE followed by a number, any
    letter case) or with a numbered point such as ``(1)`` closes the section
    collected so far, if any, and begins a new one. A fragment of the form
    ``<number> - <title>`` is stored as ``Section <number>: <title>``.
    Fragments of one section are joined with single spaces, and a section
    never continues from one text into the next.
    """
    heading_re = re.compile(r'\b(?:CHAPTER|Section|SECTION|SCHEDULE)\s+\d+\b', re.IGNORECASE)
    numbered_title_re = re.compile(r'^\s*(\d+)\s*[-–—]\s*(.*)$')
    list_point_re = re.compile(r'^\(\d+\)')

    sections = []
    for page in cleaned_text:
        pending = []
        for fragment in page.split('. '):
            fragment = fragment.strip()

            # Headings and numbered points both end the section in progress.
            starts_new = heading_re.match(fragment) or list_point_re.match(fragment)
            if starts_new and pending:
                sections.append(' '.join(pending))
                pending = []

            numbered_title = numbered_title_re.match(fragment)
            pending.append(
                f"Section {numbered_title.group(1)}: {numbered_title.group(2)}"
                if numbered_title
                else fragment
            )

        # Flush whatever is left at the end of this text.
        if pending:
            sections.append(' '.join(pending))
    return sections

# Statute text: extract the Negotiable Instruments Act from its PDF, clean it, split it into sections
pdf_path = './data/The Negotiable Instruments Act, 1881.pdf'
raw_text = extract_and_clean_text_from_pdf(pdf_path)
cleaned_text = clean_extracted_text(raw_text)
sections = split_into_sections(cleaned_text)

# Path of the processed output; whether it already exists gates the steps below
nia_cases_to_process_file = './data/nia_cases_to_process.json'
file_exists = os.path.exists(nia_cases_to_process_file)

# Only when the processed output is missing: filter the raw cases and save them
if not file_exists:
    with open('./data/cases.nia_cases.json', 'r') as file:
        data = json.load(file)

    # Keep a case only if every one of these fields is present in it
    required_fields = (
        "RLC",
        "RPC",
        "Facts",
        "Arg_Pet",
        "Arg Resp",
        "Formatted_JudgementText",
        "JudgmentDate",
        "Analysis",
    )
    filtered_cases = [case for case in data if all(field in case for field in required_fields)]

    with open('./data/cleaned_nia_cases.json', 'w') as outfile:
        json.dump(filtered_cases, outfile, indent=2)

    print(f"Saved {len(filtered_cases)} cases to 'cleaned_nia_cases.json'")

In [None]:
# Load the filtered cases JSON file
def load_json(file_path):
    """Parse the JSON document stored at ``file_path`` and return the result.

    Defined here because nothing visible in this script defines or imports a
    ``load_json`` helper, so the call below would raise ``NameError``.
    """
    with open(file_path, 'r') as json_file:
        return json.load(json_file)


cases_data = load_json('./data/cleaned_nia_cases.json')

# Normalise line endings, mid-sentence line breaks and spacing in a text
def clean_text(text):
    """Return ``text`` with line breaks and spacing tidied up.

    Applied in order: ``\\r\\n`` becomes ``\\n``; a newline that does not
    directly follow ``.``, ``!`` or ``?`` becomes a space; runs of spaces and
    tabs collapse to one space. The result is stripped at both ends.
    """
    substitutions = (
        (r'\r\n', '\n'),
        (r'(?<![\.\!\?])\n', ' '),
        (r'[ \t]+', ' '),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()

# Join a list of text pieces with spaces and clean the joined string
def combine_and_clean_text(text_list):
    """Return the items of ``text_list`` joined by single spaces, passed through ``clean_text``."""
    return clean_text(' '.join(text_list))
    
# Clean specific text fields in the case data.
# The loaded data is used as a list of case dicts elsewhere in this script, so
# the fields have to be looked up on each case: a test such as
# `'Facts' in cases_data` against the list itself never matches, which left
# every field uncleaned.
def _clean_case_field(value):
    """Return ``value`` with its text cleaned, whatever shape the field has.

    A string is cleaned directly. In a list, string items are cleaned one by
    one and list items are joined and cleaned. Any other value, or list item,
    is returned unchanged.
    """
    if isinstance(value, str):
        return clean_text(value)
    if not isinstance(value, list):
        return value
    cleaned_items = []
    for para in value:
        if isinstance(para, str):
            cleaned_items.append(clean_text(para))
        elif isinstance(para, list):
            cleaned_items.append(combine_and_clean_text(para))
        else:
            cleaned_items.append(para)
    return cleaned_items


# A single case stored as one dict is still handled, by treating it as a
# one-element collection.
for case in ([cases_data] if isinstance(cases_data, dict) else cases_data):
    for key in ['JudgmentText', 'Facts', 'RLC', 'Analysis', 'RPC', 'Arg Resp', 'Arg_Pet', 'Result']:
        if key in case:
            case[key] = _clean_case_field(case[key])

# Collect every "Section <number>" reference that appears in a text
def extract_citations(text: str) -> List[str]:
    """Return all ``Section <digits>`` mentions found in ``text``, lower-cased.

    Matching ignores letter case; matches are returned in order of appearance
    and repeated mentions are kept.
    """
    return [
        match.group(0).lower()
        for match in re.finditer(r'Section \d+', text, re.IGNORECASE)
    ]

# Gather the citations found in the argument-related fields of one case
def extract_from_sections(case: dict) -> Dict[str, List[str]]:
    """Map each inspected field that is present in ``case`` to its citations.

    The fields inspected are 'Facts', 'Arg_Pet', 'Arg Resp' and 'RLC'. For a
    field that is present, every item of its value is passed to
    ``extract_citations`` and the results are concatenated in order. Fields
    missing from ``case`` do not appear in the returned dict.
    """
    found = {}
    for field in ('Facts', 'Arg_Pet', 'Arg Resp', 'RLC'):
        if field not in case:
            continue
        found[field] = [
            citation
            for passage in case[field]
            for citation in extract_citations(passage)
        ]
    return found

# Merge the per-field citation lists into one collection without duplicates
def extract_unique_citations(citations):
    """Return the distinct citations found across all values of ``citations``.

    ``citations`` maps field names to lists of citation strings. The result
    is a list built from a set, so its order is not defined.
    """
    return list(set().union(*citations.values()))

# Summarization setup; runs only when the processed output file checked by `file_exists` is not on disk yet
if not file_exists:
    # Use the GPU when torch reports CUDA as available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Summarization pipeline loaded from the "sshleifer/distilbart-cnn-12-6" checkpoint and placed on `device`
    summarization_pipeline = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6", device=device)

    # Summarize a list of texts with the summarization pipeline, a slice at a time
    def summarize_texts(texts, max_length=150, min_length=50, batch_size=10):
        """Return one summary string per item of ``texts``, in order.

        Args:
            texts: list of strings to summarize.
            max_length: upper bound on summary length, forwarded to the pipeline.
            min_length: lower bound on summary length, forwarded to the pipeline.
            batch_size: how many texts are handed to the pipeline per call.

        Returns:
            A list with the ``summary_text`` of every input; empty for no input.
        """
        summaries = []
        for start in range(0, len(texts), batch_size):
            chunk = texts[start:start + batch_size]
            outputs = summarization_pipeline(
                chunk,
                max_length=max_length,
                min_length=min_length,
                do_sample=False,
                # Truncate inputs that exceed the model's maximum input length
                # instead of letting the model fail on them; shorter inputs
                # are unaffected.
                truncation=True,
            )
            summaries.extend(output['summary_text'] for output in outputs)
        return summaries

    # Wrap the statute sections in a Hugging Face Dataset with a single "section_text" column
    dataset = Dataset.from_dict({"section_text": sections})

    # Batch callback for Dataset.map: adds a 'summarized_text' entry alongside 'section_text'
    def summarize_batch(batch):
        """Summarize ``batch['section_text']`` and store the result under ``'summarized_text'``."""
        batch['summarized_text'] = summarize_texts(batch['section_text'], max_length=150, min_length=50, batch_size=10)
        return batch

    # Run the callback over the dataset, 10 rows per call (batched=True, batch_size=10)
    summarized_dataset = dataset.map(summarize_batch, batched=True, batch_size=10)

    # Pull out the 'summarized_text' column; it is searched later for cited sections
    summarized_sections = summarized_dataset['summarized_text']

    # Function to retrieve the sections that mention any of the query terms
    def retrieve_exact_sections(sections, query_text):
        """Return the entries of ``sections`` that mention a term of ``query_text``.

        Matching ignores letter case. A term that ends in a digit (for example
        ``"section 13"``) only matches where the number ends there, so it no
        longer matches inside a longer number such as ``"Section 138"``. Terms
        that do not end in a digit keep plain substring matching.

        Args:
            sections: iterable of paragraph strings to search.
            query_text: iterable of terms to look for.

        Returns:
            The matching paragraphs, in their original order; an empty list
            when ``query_text`` is empty.
        """
        term_patterns = [
            re.compile(
                # (?!\d) stops a numeric term from matching a prefix of a longer number.
                re.escape(name) + (r'(?!\d)' if name[-1:].isdigit() else ''),
                re.IGNORECASE,
            )
            for name in query_text
        ]
        return [
            paragraph
            for paragraph in sections
            if any(pattern.search(paragraph) for pattern in term_patterns)
        ]
    

In [None]:
# Embedding and similar-case search; runs only when the processed output file checked by `file_exists` is not on disk yet
if not file_exists:

    # Tokenizer and model are both loaded from the "sentence-transformers/all-MiniLM-L6-v2" checkpoint
    embedding_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    embedding_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    
    # Function to embed text using the tokenizer and model
    def embed_texts(texts, batch_size=16):
        """Embed ``texts`` and return one vector per text as a 2-D NumPy array.

        Each batch is tokenized (padded, truncated to 512 tokens), run through
        the model without gradients, and mean-pooled over its real tokens only.
        Padding positions are masked out, so a text gets the same embedding no
        matter which other texts share its batch.

        Args:
            texts: list of strings to embed.
            batch_size: number of texts per forward pass.

        Returns:
            Array of shape ``(len(texts), hidden_size)``.
        """
        # Pick the device at call time so the function also works without a GPU.
        run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Move the model once per call rather than once per batch.
        embedding_model.to(run_device)

        pooled_batches = []
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            inputs = embedding_tokenizer(
                batch_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(run_device)
            with torch.no_grad():
                outputs = embedding_model(**inputs)
                token_vectors = outputs.last_hidden_state
                # Attention-mask-weighted mean: padded positions contribute nothing.
                mask = inputs["attention_mask"].unsqueeze(-1).to(token_vectors.dtype)
                summed = (token_vectors * mask).sum(dim=1)
                token_counts = mask.sum(dim=1).clamp(min=1e-9)
                batch_embeddings = summed / token_counts
            pooled_batches.append(batch_embeddings.cpu())
        # Concatenate with torch so no NumPy import is needed in this script.
        return torch.cat(pooled_batches).numpy()

    # Flatten a list of strings into one string; leave any other value untouched
    def convert_to_string(value):
        """Return ``value`` joined with single spaces if it is a list, else ``value`` itself."""
        return ' '.join(value) if isinstance(value, list) else value

    # Summarize similar cases
    def summarize_similar_cases(text, max_length=200, min_length=50):
        """Return ``text`` as one string, summarized when longer than 200 characters.

        A list input is first joined with single spaces, so the length check
        counts characters rather than list items and the whole field is
        summarized rather than only its first item. Text of 200 characters or
        fewer is returned as is.

        Args:
            text: a string, or a list of strings, to shorten.
            max_length: upper bound on summary length, forwarded to the pipeline.
            min_length: lower bound on summary length, forwarded to the pipeline.

        Returns:
            The summary, or the (joined) input text when it is short enough.
        """
        if isinstance(text, list):
            text = ' '.join(text)
        if len(text) <= 200:
            return text
        result = summarization_pipeline(
            text,
            max_length=max_length,
            min_length=min_length,
            do_sample=False,
            # Truncate over-long inputs instead of letting the model fail on them.
            truncation=True,
        )
        return result[0]['summary_text']

    # Build a FAISS L2 index holding one embedding per case
    def create_faiss_index(cases_data, batch_size=16):
        """Embed every case and return ``(index, embeddings)``.

        The text embedded for a case is its 'Facts', 'RLC', 'Arg_Pet' and
        'Arg Resp' fields (lists flattened by ``convert_to_string``, missing
        fields as empty strings) joined with single spaces. The embeddings are
        added to a ``faiss.IndexFlatL2`` in case order.
        """
        text_fields = ('Facts', 'RLC', 'Arg_Pet', 'Arg Resp')
        case_texts = [
            ' '.join(convert_to_string(case.get(field, '')) for field in text_fields)
            for case in tqdm(cases_data, desc="Processing Cases")
        ]

        vectors = embed_texts(case_texts, batch_size=batch_size)

        index = faiss.IndexFlatL2(vectors.shape[1])
        index.add(vectors)
        return index, vectors

    # Index all loaded cases once; the embeddings are kept alongside the index
    faiss_index, all_case_embeddings = create_faiss_index(cases_data, batch_size=16)

    # Function to fetch similar cases using the FAISS index
    def fetch_similar_cases(case, k=4):
        """Return a list with at most one similar case decided before ``case``.

        The case text ('Facts', 'RLC', 'Arg_Pet', 'Arg Resp') is embedded and
        its ``k`` nearest neighbours are looked up in ``faiss_index``. The
        nearest neighbour whose 'JudgmentDate' compares lower than the query
        case's is returned, with its long text fields summarized. Only that
        one case is summarized, since no other neighbour is ever returned.

        Args:
            case: case dict; must contain 'JudgmentDate'.
            k: number of nearest neighbours to inspect.

        Returns:
            ``[similar_case_dict]``, or ``[]`` when no neighbour qualifies.

        Raises:
            KeyError: if ``case`` has no 'JudgmentDate'.
        """
        text_to_embed = ' '.join(
            convert_to_string(case.get(field, ''))
            for field in ('Facts', 'RLC', 'Arg_Pet', 'Arg Resp')
        )
        query_embedding = embed_texts([text_to_embed])
        _, neighbour_ids = faiss_index.search(query_embedding, k)

        # Keep only cases decided before the current one.
        # NOTE(review): the dates are compared as stored; if they are strings
        # this is chronological only for sortable formats such as ISO 8601 -
        # confirm against the data.
        current_judgement_date = case['JudgmentDate']
        for idx in neighbour_ids[0]:
            # FAISS pads the result with -1 when fewer than k vectors exist;
            # without this guard -1 would silently select the last case.
            if idx < 0:
                continue
            candidate = cases_data[idx]
            judgment_date = candidate.get('JudgmentDate', '')
            if judgment_date < current_judgement_date:
                return [{
                    '_id': candidate['_id'],
                    'Facts': summarize_similar_cases(candidate.get('Facts', '')),
                    'RLC': summarize_similar_cases(candidate.get('RLC', '')),
                    'Arg_Pet': summarize_similar_cases(candidate.get('Arg_Pet', '')),
                    'Arg_Resp': summarize_similar_cases(candidate.get('Arg Resp', '')),
                    'Citation_context': summarize_similar_cases(candidate.get('Citation_context', '')),
                    'Analysis': candidate.get('Analysis', ''),
                    'JudgmentDate': judgment_date,
                }]
        return []
    
    # Attach to every case the summarized statute sections it cites ('Citation_context')
    for case in tqdm(cases_data, desc="Extracting and Summarizing Citations"):
        cited_terms = extract_unique_citations(extract_from_sections(case))
        case['Citation_context'] = retrieve_exact_sections(summarized_sections, cited_terms)

    # Render each case's similar cases as one markdown-style string ('Similar_Cases_Analysis')
    for case in tqdm(cases_data, desc="Adding Similar Cases Analysis"):
        rendered = []
        for position, sim_case in enumerate(fetch_similar_cases(case), start=1):
            rendered.extend([
                f"\n### **Similar Case {position}:**\n",
                f"**Case ID**: {sim_case.get('_id', '')}\n",
                f"**Facts**: {convert_to_string(sim_case.get('Facts', ''))}\n",
                f"**Ruling by Lower Court**: {convert_to_string(sim_case.get('RLC', ''))}\n",
                f"**Argument by Petitioner**: {convert_to_string(sim_case.get('Arg_Pet', ''))}\n",
                f"**Argument by Respondent**: {convert_to_string(sim_case.get('Arg_Resp', ''))}\n",
                f"**Citation Context**: {convert_to_string(sim_case.get('Citation_context', ''))}\n",
                f"**Analysis**: {convert_to_string(sim_case.get('Analysis', ''))}\n",
            ])
        # An empty string is stored when no similar case was found
        case['Similar_Cases_Analysis'] = "".join(rendered)

# Concatenate the relevant sections into a single field for each case
if not file_exists:

    def concat_case_inputs(doc):
        """Render the input fields of ``doc`` as one labelled text block.

        For each of 'Facts', 'Arg_Pet', 'Arg Resp' and 'RLC' that is present
        (in that order) the output holds a heading line ending in a colon,
        followed by the field's items joined with single spaces and a newline.
        Missing fields are skipped; with none present the result is ''.
        """
        headings = (
            ('Facts', "Facts"),
            ('Arg_Pet', "Argument by the Petitioner"),
            ('Arg Resp', "Argument by the Respondent"),
            ('RLC', "Ruling by the Lower Court"),
        )
        return "".join(
            f"{heading}:\n" + " ".join(doc[key]) + "\n"
            for key, heading in headings
            if key in doc
        )

    # Build the Case_Result text from the 'Analysis' and 'Result' fields
    def concat_outputs(doc):
        """Return the 'Analysis' and 'Result' content of ``doc`` as one string.

        For each of the two fields that is present, its items are concatenated
        without a separator, literal backslash-n sequences are turned into
        real newlines, and the text is stripped. The pieces are then joined
        with a single space and the result is stripped again.
        """
        pieces = [
            "".join(doc[key]).replace("\\n", "\n").strip()
            for key in ('Analysis', 'Result')
            if key in doc
        ]
        return " ".join(pieces).strip()
    
    # Reduce every case to the fields that are written to the output file
    filtered_cases = [
        {
            '_id': str(doc['_id']),
            'JudgmentDate': doc.get('JudgmentDate', ''),
            'Case_Inputs': concat_case_inputs(doc),
            'Case_Result': concat_outputs(doc),
            'Citation_context': doc.get('Citation_context', ''),
            'Similar_Cases_Analysis': doc.get('Similar_Cases_Analysis', ''),
        }
        for doc in cases_data
    ]

    # Write the processed cases as one JSON object per line
    with open('./data/nia_cases_to_process.json', 'w') as f:
        f.writelines(json.dumps(doc) + '\n' for doc in filtered_cases)

    print("Filtered documents saved to nia_cases_to_process.json")
