<a href="https://colab.research.google.com/github/ShadowMonarch9871/Legal-Summarizer/blob/main/Legal%20Summarizer%20using%20LLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install gradio langchain transformers torch pypdf
!pip install -U langchain-community

Collecting langchain-community
  Downloading langchain_community-0.3.23-py3-none-any.whl.metadata (2.5 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain-community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting pydantic-settings<3.0.0,>=2.4.0 (from langchain-community)
  Downloading pydantic_settings-2.9.1-py3-none-any.whl.metadata (3.8 kB)
Collecting httpx-sse<1.0.0,>=0.4.0 (from langchain-community)
  Downloading httpx_sse-0.4.0-py3-none-any.whl.metadata (9.0 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7,>=0.5.7->langchain-community)
  Downloading marshmallow-3.26.1-py3-none-any.whl.metadata (7.3 kB)
Collecting typing-inspect<1,>=0.4.0 (from dataclasses-json<0.7,>=0.5.7->langchain-community)
  Downloading typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)
Collecting python-dotenv>=0.21.0 (from pydantic-settings<3.0.0,>=2.4.0->langchain-community)
  Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB

In [15]:
import gradio as gr
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import re
import tempfile
import nltk
# Import NLTK's sentence tokenizer when available. combine_summaries() catches the
# NameError raised when this import failed and uses custom_sent_tokenize() instead.
try:
    from nltk.tokenize import sent_tokenize
except ImportError:
    print("NLTK import failed, will use custom sentence tokenizer")

# Download NLTK resources (only needed once).
# Newer NLTK releases load the "punkt_tab" tables in sent_tokenize(), older ones the
# "punkt" models, so both are requested; whichever is missing is simply skipped.
for _nltk_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_nltk_resource}")
    except LookupError:
        try:
            # nltk.download() reports failure through its return value, so only
            # announce success when it actually succeeded.
            if nltk.download(_nltk_resource, quiet=True):
                print(f"Successfully downloaded NLTK {_nltk_resource} resource")
        except Exception as e:
            print(f"Warning: Could not download NLTK resources: {e}")
            print("Will use custom sentence splitting instead")

# Constants shared by the functions below.
# MAX_CHUNK_LENGTH is passed to the tokenizer as `max_length` (with truncation) in
# summarize_chunks(); chunk_text() and summarize_target_sections() multiply it by 4
# to get an approximate character budget per chunk.
MAX_CHUNK_LENGTH = 1024  # Token limit for T5-based models
# Identifier handed to AutoTokenizer/AutoModelForSeq2SeqLM.from_pretrained() in load_model().
MODEL_NAME = "manjunathainti/fine_tuned_t5_summarizer"  # Pretrained Legal T5 summarizer
# Used as `max_length` for generation of each chunk's summary; callers also pass
# twice these values as the `max_length` argument of combine_summaries().
SHORT_SUMMARY_LENGTH = 150  # For concise summaries
LONG_SUMMARY_LENGTH = 300  # For detailed summaries

# Cache of (tokenizer, model) pairs keyed by model name, so the weights are loaded
# once per kernel session instead of once per summarization request.
_LOADED_MODELS = {}

# Load the T5 model and tokenizer
def load_model():
    """Load (or reuse) the tokenizer and seq2seq model named by MODEL_NAME.

    The first call loads both with ``from_pretrained``, moves the model to CUDA
    when ``torch.cuda.is_available()`` and to CPU otherwise, prints the chosen
    device, and stores the pair in ``_LOADED_MODELS``. Later calls return the
    cached pair without reloading.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    if MODEL_NAME not in _LOADED_MODELS:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
        _LOADED_MODELS[MODEL_NAME] = (tokenizer, model)
    return _LOADED_MODELS[MODEL_NAME]

# Clean and preprocess the input text
def preprocess_text(text):
    """Clean and standardize extracted legal text.

    Steps, in order:
      1. Remove ``Page N`` markers, slash-separated dates and anything from
         ``Footnote`` to the end of its line.
      2. Put an extra newline in front of lettered markers such as ``(a)`` that
         start a line, keeping the marker itself.
      3. Turn runs of three or more ``*``, ``-`` or ``_`` into a paragraph break.
      4. Collapse runs of spaces/tabs to one space (newlines are left alone so
         paragraph breaks survive).
      5. Insert a line break after a lowercase letter + period that is directly
         followed by a capital letter on the same line.
      6. Reduce three or more consecutive newlines to a single blank line.

    Args:
        text (str): Raw text extracted from the document.

    Returns:
        str: The cleaned text with leading/trailing whitespace stripped.
    """
    # Remove page numbers, dates, and footnotes
    text = re.sub(r"Page \d+|[0-9]{1,2}/[0-9]{1,2}/[0-9]{2,4}|Footnote.*", "", text)

    # Start section markers on a fresh line. "\g<0>" re-inserts the whole match;
    # a bare "\0" in a replacement template is an octal escape (a NUL character),
    # which would erase the marker.
    text = re.sub(r"^\s*\([a-z]\)\s*", r"\n\g<0>", text, flags=re.MULTILINE)

    # Replace dashed separators with paragraph breaks
    text = re.sub(r"[*\-_]{3,}", "\n\n", text)

    # Remove multiple spaces. Only horizontal whitespace is collapsed: matching
    # "\s" here would also swallow the "\n\n" paragraph breaks that later stages
    # split on.
    text = re.sub(r"[^\S\n]{2,}", " ", text)

    # Add a line break between sentences that run together on one line; newlines
    # are excluded from the gap so an existing paragraph break is not shortened.
    text = re.sub(r"(?<=[a-z])\.[^\S\n]*(?=[A-Z])", ".\n", text)

    # Standardize paragraph breaks last, so breaks introduced above are covered too
    text = re.sub(r"\n{3,}", "\n\n", text)

    return text.strip()

# Intelligently split text into semantic chunks
def chunk_text(text, max_tokens=MAX_CHUNK_LENGTH, overlap=150):
    """Group paragraphs of ``text`` into chunks sized for the summarizer.

    Paragraphs (separated by a blank line) are accumulated until adding the next
    one would push the buffer past ``max_tokens * 4`` characters. A buffer that is
    itself over that budget is handed to a ``RecursiveCharacterTextSplitter``
    configured with ``chunk_size=max_tokens`` and ``chunk_overlap=overlap``.
    Chunks of 20 words or fewer are dropped from the result.

    Args:
        text (str): Preprocessed document text.
        max_tokens (int): Nominal token limit; multiplied by 4 as a character estimate.
        overlap (int): ``chunk_overlap`` passed to the fallback splitter.

    Returns:
        list[str]: The retained chunks, in document order.
    """
    paragraphs = text.split("\n\n")

    # Rough characters-per-chunk budget derived from the token limit.
    char_budget = max_tokens * 4

    fallback_splitter = RecursiveCharacterTextSplitter(
        chunk_size=max_tokens,
        chunk_overlap=overlap,
        separators=["\n\n", "\n", ". ", ", ", " ", ""]
    )

    pieces = []

    def emit(buffer):
        # Oversized buffers are broken up by the splitter; others are kept whole.
        if len(buffer) > char_budget:
            pieces.extend(fallback_splitter.split_text(buffer))
        else:
            pieces.append(buffer)

    buffer = ""
    for paragraph in paragraphs:
        if len(buffer) + len(paragraph) > char_budget:
            # The buffer is full: flush it and start over with this paragraph.
            emit(buffer)
            buffer = paragraph
        elif buffer:
            buffer = buffer + "\n\n" + paragraph
        else:
            buffer = paragraph

    # Flush whatever is left after the last paragraph.
    if buffer:
        emit(buffer)

    # Very short chunks (20 words or fewer) are discarded.
    return [piece for piece in pieces if len(piece.split()) > 20]

# Generate both short and long summaries for text chunks
def summarize_chunks(chunks, tokenizer, model, summary_type="both", num_beams=4):
    """Summarize every chunk with the model, producing short and/or long variants.

    Each chunk is tokenized with truncation at ``MAX_CHUNK_LENGTH`` and moved to
    the model's device. Depending on ``summary_type`` a short summary
    (``max_length=SHORT_SUMMARY_LENGTH``, 3-gram repeat blocking) and/or a long
    summary (``max_length=LONG_SUMMARY_LENGTH``, 2-gram repeat blocking) is
    generated with beam search.

    Args:
        chunks (list[str]): Text chunks to summarize.
        tokenizer: Tokenizer matching ``model``.
        model: Seq2seq model exposing ``generate`` and ``parameters``.
        summary_type (str): ``"short"``, ``"long"`` or ``"both"``.
        num_beams (int): Beam width passed to ``generate``.

    Returns:
        tuple[list[str], list[str]]: ``(short_summaries, long_summaries)``; a list
        stays empty when its variant was not requested.
    """
    device = next(model.parameters()).device
    make_short = summary_type in ["short", "both"]
    make_long = summary_type in ["long", "both"]

    def run_generation(input_ids, max_len, min_len, ngram_block):
        # One beam-search pass followed by decoding of the best sequence.
        output_ids = model.generate(
            input_ids,
            max_length=max_len,
            min_length=min_len,
            num_beams=num_beams,
            early_stopping=True,
            length_penalty=1.0,
            no_repeat_ngram_size=ngram_block,
        )
        decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return decoded.strip()

    short_summaries, long_summaries = [], []

    for chunk in chunks:
        encoded = tokenizer(
            chunk, return_tensors="pt", truncation=True, max_length=MAX_CHUNK_LENGTH
        ).to(device)

        if make_short:
            # Minimum length: a tenth of the chunk's word count, capped at 50.
            short_floor = min(50, len(chunk.split()) // 10)
            short_summaries.append(
                run_generation(encoded.input_ids, SHORT_SUMMARY_LENGTH, short_floor, 3)
            )

        if make_long:
            # Minimum length: a fifth of the chunk's word count, capped at 100.
            long_floor = min(100, len(chunk.split()) // 5)
            long_summaries.append(
                run_generation(encoded.input_ids, LONG_SUMMARY_LENGTH, long_floor, 2)
            )

    return short_summaries, long_summaries

# Custom sentence tokenizer as fallback if NLTK is not available
def custom_sent_tokenize(text):
    """Regex-based sentence splitter used when NLTK's tokenizer is unavailable.

    The text is first cut wherever ``.``, ``!`` or ``?`` is followed by whitespace
    and a capital letter. Any resulting piece longer than 150 characters is cut
    again wherever ``;`` or ``:`` is followed by whitespace and a capital letter.

    Args:
        text (str): Text to split.

    Returns:
        list[str]: The pieces, in their original order.
    """
    sentence_boundary = r'(?<=[.!?])\s+(?=[A-Z])'
    clause_boundary = r'(?<=[;:])\s+(?=[A-Z])'

    pieces = []
    for segment in re.split(sentence_boundary, text):
        # Only over-long segments get the secondary split on semicolons/colons.
        if len(segment) > 150:
            pieces += re.split(clause_boundary, segment)
        else:
            pieces += [segment]
    return pieces

# Intelligently combine summaries for coherence
def combine_summaries(summaries, max_length=None):
    """Combine per-chunk summaries into one text, dropping repeated sentences.

    Every summary is split into sentences (NLTK's ``sent_tokenize`` when usable,
    otherwise ``custom_sent_tokenize``). Sentences are compared after lowercasing
    and whitespace normalization; duplicates and sentences whose normalized form
    is 10 characters or shorter are skipped.

    When ``max_length`` is given and the joined text is longer than that many
    characters, the result is condensed to the first sentence of each summary,
    provided those first sentences together reach at least ``max_length // 2``
    characters. Otherwise the full de-duplicated text is kept.

    Args:
        summaries (list[str]): Summaries in document order.
        max_length (int | None): Character length above which condensing is attempted.

    Returns:
        str: The combined text, or ``""`` when ``summaries`` is empty.
    """
    if not summaries:
        return ""

    def split_sentences(summary):
        # NameError: the NLTK import failed; LookupError: its data is missing.
        try:
            return sent_tokenize(summary)
        except (NameError, LookupError):
            return custom_sent_tokenize(summary)

    def normalize(sentence):
        return re.sub(r'\s+', ' ', sentence.lower()).strip()

    # Tokenize once and reuse the result for both passes below.
    tokenized = [split_sentences(summary) for summary in summaries]

    # Extract unique sentences to avoid redundancy
    all_sentences = []
    seen = set()
    for sentences in tokenized:
        for sentence in sentences:
            key = normalize(sentence)
            if key not in seen and len(key) > 10:
                seen.add(key)
                all_sentences.append(sentence)

    # If the text is too long, prioritize the first sentence from each summary
    # to maintain coverage of the entire document.
    if max_length and len(' '.join(all_sentences)) > max_length:
        first_sentences = []
        # A separate set is required here: `seen` already contains every first
        # sentence from the pass above, so checking against it would reject all
        # of them and this branch could never take effect.
        seen_first = set()
        for sentences in tokenized:
            if not sentences:
                continue
            key = normalize(sentences[0])
            if key not in seen_first and len(key) > 10:
                seen_first.add(key)
                first_sentences.append(sentences[0])

        # Use first sentences if they're representative enough
        if first_sentences and len(' '.join(first_sentences)) >= max_length // 2:
            all_sentences = first_sentences

    # Join sentences, ensuring proper spacing
    combined_text = ' '.join(all_sentences)

    # Clean up spacing and formatting
    combined_text = re.sub(r'\s+', ' ', combined_text)
    combined_text = re.sub(r'\s+\.', '.', combined_text)
    combined_text = re.sub(r'\s+,', ',', combined_text)

    return combined_text

# Extract legal terms from the document for improved highlighting
def extract_legal_terms(text):
    """Collect terms to highlight: a fixed legal vocabulary plus capitalized phrases.

    The fixed vocabulary is always included. In addition, every run of two or
    more consecutive capitalized words found in ``text`` is added.

    Args:
        text (str): Document text to scan.

    Returns:
        list[str]: Unique terms; the order is not defined.
    """
    # Two or more capitalized words in a row, e.g. a multi-word proper name.
    capitalized_phrase = r'\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)\b'

    vocabulary = {
        "Section", "Article", "Clause", "Amendment", "Schedule", "Act", "Law",
        "Constitution", "Provision", "Regulation", "Statute", "Directive",
        "Legislative", "Assembly", "Parliament", "Election", "Commission",
        "President", "Governor", "Cabinet", "Council", "Minister", "Bill",
    }
    # The set absorbs duplicates among the phrases and against the base terms.
    vocabulary.update(re.findall(capitalized_phrase, text))
    return list(vocabulary)

# Highlight keywords in the summary
def highlight_keywords(summary_text, keywords=None):
    """Wrap whole-word occurrences of the keywords in Markdown bold (``**...**``).

    Matching is case-insensitive and the matched text keeps its original casing.
    All keywords are applied in a single pass with the longest alternatives tried
    first, so a short keyword is never re-highlighted inside a longer phrase that
    was already wrapped. Empty or whitespace-only keywords are ignored.

    Args:
        summary_text (str): Text to annotate.
        keywords (list[str] | None): Terms to emphasize; ``None`` selects a
            built-in list of common legal terms, an empty list highlights nothing.

    Returns:
        str: ``summary_text`` with matches wrapped in ``**``.
    """
    if keywords is None:
        keywords = [
            "Section", "Article", "Clause", "Amendment", "Schedule", "Act", "Law",
            "Constitution", "Provision", "Regulation"
        ]

    # Longest first, without blanks or case-insensitive duplicates. A blank
    # keyword would otherwise match at every word boundary.
    ordered = []
    seen = set()
    for keyword in sorted(keywords, key=len, reverse=True):
        folded = keyword.lower()
        if keyword.strip() and folded not in seen:
            seen.add(folded)
            ordered.append(keyword)

    if not ordered:
        return summary_text

    # Highlight terms that are complete words (not parts of other words)
    alternation = '|'.join(re.escape(keyword) for keyword in ordered)
    pattern = re.compile(r'\b(' + alternation + r')\b', flags=re.IGNORECASE)
    return pattern.sub(r'**\1**', summary_text)

# Summarize sections matching specific keywords
def summarize_target_sections(text, tokenizer, model, keywords, summary_type="both", num_beams=4):
    """Summarize only the paragraphs of ``text`` that mention one of ``keywords``.

    Paragraphs (separated by a blank line) are kept when they contain any keyword,
    matched literally and case-insensitively. Kept paragraphs are merged, in
    order, into sections of less than ``MAX_CHUNK_LENGTH * 4`` characters, which
    are summarized with ``summarize_chunks`` and merged with ``combine_summaries``.

    Args:
        text (str): Preprocessed document text.
        tokenizer: Tokenizer passed through to ``summarize_chunks``.
        model: Model passed through to ``summarize_chunks``.
        keywords (list[str]): Terms to look for; blank entries are ignored.
        summary_type (str): ``"short"``, ``"long"`` or ``"both"``.
        num_beams (int): Beam width passed through to ``summarize_chunks``.

    Returns:
        tuple[str, str]: ``(short_summary, long_summary)``. When no paragraph
        matches (or no usable keyword was given) the first element is an
        explanatory message and the second is ``""``.
    """
    not_found = ("No sections containing the specified keywords were found.", "")

    # Drop blank keywords: an empty pattern matches every paragraph, which would
    # silently turn a keyword search (e.g. with a trailing comma) into a
    # whole-document summary.
    usable_keywords = [kw.strip() for kw in (keywords or []) if kw and kw.strip()]
    if not usable_keywords:
        return not_found

    # One escaped, case-insensitive alternation covering all keywords.
    keyword_pattern = re.compile(
        "|".join(re.escape(kw) for kw in usable_keywords), re.IGNORECASE
    )

    # Extract paragraphs containing keywords
    target_sections = [
        para for para in text.split("\n\n")
        if para.strip() and keyword_pattern.search(para)
    ]
    if not target_sections:
        return not_found

    # Combine related paragraphs to maintain context
    combined_sections = []
    current_section = ""
    for section in target_sections:
        if not current_section:
            current_section = section
        elif len(current_section) + len(section) < MAX_CHUNK_LENGTH * 4:
            current_section += "\n\n" + section
        else:
            combined_sections.append(current_section)
            current_section = section
    if current_section:
        combined_sections.append(current_section)

    # Summarize each combined section
    short_summaries, long_summaries = summarize_chunks(
        combined_sections, tokenizer, model, summary_type, num_beams
    )

    # Combine summaries with error handling
    try:
        short_combined = combine_summaries(short_summaries, SHORT_SUMMARY_LENGTH * 2)
        long_combined = combine_summaries(long_summaries, LONG_SUMMARY_LENGTH * 2)
    except Exception as e:
        print(f"Warning: Error combining summaries: {e}")
        # Simple concatenation as fallback
        short_combined = " ".join(short_summaries)
        long_combined = " ".join(long_summaries)

    return short_combined, long_combined

# Create a structured summary with sections
def create_structured_summary(short_summary, long_summary, metrics, keywords=None):
    """Assemble the Markdown report: title, metrics, executive and detailed summary.

    Both summaries are passed through ``highlight_keywords`` with ``keywords``
    (an empty list when ``keywords`` is falsy) before being placed in the report.

    Args:
        short_summary (str): Text for the "Executive Summary" section.
        long_summary (str): Text for the "Detailed Summary" section.
        metrics (str): Text for the "Metrics" section.
        keywords (list[str] | None): Terms to highlight in both summaries.

    Returns:
        str: The Markdown document, ending with a newline.
    """
    terms = keywords or []

    executive = highlight_keywords(short_summary, terms)
    detailed = highlight_keywords(long_summary, terms)

    # One entry per output line; the trailing "" yields the final newline.
    report_lines = [
        "# Legal Document Summary",
        "",
        "## Metrics",
        f"{metrics}",
        "",
        "## Executive Summary",
        f"{executive}",
        "",
        "## Detailed Summary",
        f"{detailed}",
        "",
    ]
    return "\n".join(report_lines)

# Process PDF and generate summaries
def process_pdf(file_obj, keywords=None, summary_type="both", num_beams=4):
    """Summarize an uploaded legal PDF and report compression metrics.

    The PDF is loaded page by page, cleaned with ``preprocess_text`` and then
    summarized either around the given keywords (``summarize_target_sections``)
    or as a whole (``chunk_text`` + ``summarize_chunks`` + ``combine_summaries``).
    The structured Markdown report is also written to a temporary ``.md`` file.

    Args:
        file_obj: Uploaded file; either an object with a ``name`` attribute
            holding the path, or the path itself. ``None`` means nothing uploaded.
        keywords (str | None): Optional comma-separated keywords.
        summary_type (str): ``"short"``, ``"long"`` or ``"both"``.
        num_beams (int | float): Beam width; converted to ``int`` before use.

    Returns:
        tuple[str, str | None]: ``(markdown_report, download_path)`` on success,
        or ``(message, None)`` when no file was given, no text could be
        extracted, or an exception occurred (the traceback is printed).
    """
    if file_obj is None:
        return "Please upload a PDF file.", None

    try:
        # Accept both a file-like wrapper exposing `.name` and a plain path string.
        pdf_path = getattr(file_obj, "name", file_obj)

        # load() yields the pages as-is. load_and_split() would additionally run a
        # default text splitter whose overlapping chunks duplicate text once joined.
        pages = PyPDFLoader(pdf_path).load()
        text = " ".join(page.page_content for page in pages)
        text = preprocess_text(text)

        # Bail out before loading the model when there is nothing to summarize;
        # this also prevents a division by zero in the metrics below.
        total_words = len(text.split())
        if total_words == 0:
            return "No extractable text was found in the uploaded PDF.", None

        # Extract document-specific legal terms for highlighting
        legal_terms = extract_legal_terms(text)

        # Load the T5 summarizer model
        tokenizer, model = load_model()

        # UI sliders may deliver the value as a float; generation needs an int.
        num_beams = int(num_beams)

        # Perform keyword-based summarization if keywords are provided
        if keywords and keywords.strip():
            keyword_list = [k.strip() for k in keywords.split(',') if k.strip()]
            short_summary, long_summary = summarize_target_sections(
                text, tokenizer, model, keyword_list, summary_type, num_beams
            )
        else:
            # Chunk the text and generate summaries
            chunks = chunk_text(text)
            short_summaries, long_summaries = summarize_chunks(
                chunks, tokenizer, model, summary_type, num_beams
            )

            # Combine summaries for coherence
            short_summary = combine_summaries(short_summaries, SHORT_SUMMARY_LENGTH * 2)
            long_summary = combine_summaries(long_summaries, LONG_SUMMARY_LENGTH * 2)

        # Calculate summarization metrics
        short_word_count = len(short_summary.split())
        long_word_count = len(long_summary.split())
        short_compression = round((short_word_count / total_words) * 100, 2)
        long_compression = round((long_word_count / total_words) * 100, 2)

        metrics = (
            f"Original Word Count: {total_words}\n"
            f"Short Summary Word Count: {short_word_count} (Compression: {short_compression}%)\n"
            f"Long Summary Word Count: {long_word_count} (Compression: {long_compression}%)"
        )

        # Create the final structured summary
        final_summary = create_structured_summary(
            short_summary, long_summary, metrics, legal_terms
        )

        # Save the summarized text in a temporary file for download. delete=False
        # keeps the file on disk after the handle is closed so it can be served.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".md") as tmp_file:
            tmp_file.write(final_summary.encode("utf-8"))
            download_path = tmp_file.name

        return final_summary, download_path

    except Exception as e:
        import traceback
        print(traceback.format_exc())
        return f"An error occurred: {str(e)}", None

# Gradio Interface
def main():
    """Build the Gradio interface, connect it to process_pdf, and launch it.

    The left column holds the PDF upload, the keyword box, the summary-type
    selector, the beam-count slider and the submit button; the right column shows
    the Markdown summary and the downloadable file. The app is launched with
    ``debug=True`` and ``share=True``.
    """
    intro_text = (
        "Upload a legal PDF to generate summaries with optional keyword-based selection. "
        "The model highlights critical legal terms and provides both short and detailed summaries."
    )

    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("# Legal Document Summarizer (Fine-Tuned T5)")
        gr.Markdown(intro_text)

        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column(scale=1):
                pdf_upload = gr.File(
                    label="Upload Legal Document (PDF)",
                    file_types=[".pdf"]
                )
                keyword_box = gr.Textbox(
                    label="Target Keywords (Optional, comma-separated)",
                    placeholder="e.g., Section, Clause, Election"
                )

                with gr.Row():
                    type_selector = gr.Radio(
                        ["short", "long", "both"],
                        label="Summary Type",
                        value="both"
                    )
                    beam_slider = gr.Slider(
                        1, 10, step=1, value=4,
                        label="Number of Beams (Higher = more diverse)"
                    )

                run_button = gr.Button("Generate Summary", variant="primary")

            # Right column: rendered summary and download link.
            with gr.Column(scale=2):
                summary_view = gr.Markdown(label="Summary")
                summary_file = gr.File(label="Download Summary")

        # Input order matches process_pdf's positional parameters.
        run_button.click(
            fn=process_pdf,
            inputs=[pdf_upload, keyword_box, type_selector, beam_slider],
            outputs=[summary_view, summary_file]
        )

    app.launch(debug=True, share=True)

if __name__ == "__main__":
    main()

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://dbf2265ca7775bd30e.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Using device: cpu
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://dbf2265ca7775bd30e.gradio.live
