<a href="https://colab.research.google.com/github/ShadowMonarch9871/Legal-Summarizer/blob/main/Legal%20Summarizer%20using%20LLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab setup: install the runtime dependencies (versions are not pinned).
# NOTE(review): `rouge` is not imported by any code shown in this notebook — confirm it is still needed.
!pip install gradio langchain transformers torch pypdf
!pip install -U langchain-community
!pip install rouge

Collecting gradio
  Downloading gradio-5.29.0-py3-none-any.whl.metadata (16 kB)
Collecting pypdf
  Downloading pypdf-5.4.0-py3-none-any.whl.metadata (7.3 kB)
Collecting aiofiles<25.0,>=22.0 (from gradio)
  Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.10.0 (from gradio)
  Downloading gradio_client-1.10.0-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.11.8-py3-none-manylinux_2_17_x86_6

In [3]:
import gradio as gr
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import re
import tempfile
import nltk
try:
    from nltk.tokenize import sent_tokenize
except ImportError:
    print("NLTK import failed, will use custom sentence tokenizer")


# sent_tokenize needs Punkt model data. Older NLTK releases read the "punkt"
# resource, while newer ones (3.8.2/3.9 onward) read "punkt_tab". Fetching only
# "punkt" makes sent_tokenize raise LookupError on current NLTK, which silently
# pushes combine_summaries() onto its regex fallback. Make sure both are present.
for _punkt_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_punkt_resource}")
    except LookupError:
        try:
            # nltk.download() reports most failures by returning False rather than raising.
            if nltk.download(_punkt_resource, quiet=True):
                print(f"Successfully downloaded NLTK {_punkt_resource} resource")
            else:
                print(f"Warning: Could not download NLTK resource {_punkt_resource}")
                print("Will use custom sentence splitting instead")
        except Exception as e:
            print(f"Warning: Could not download NLTK resources: {e}")
            print("Will use custom sentence splitting instead")


# Tokenizer truncation length (in tokens) for each chunk fed to the model.
# Elsewhere it is multiplied by 4 as a rough character budget when packing text into chunks.
MAX_CHUNK_LENGTH = 1024
# Model identifier passed to from_pretrained() in load_model().
MODEL_NAME = "manjunathainti/fine_tuned_t5_summarizer"
# max_length passed to model.generate() for the short / long summaries.
# Each is also doubled and used as the character limit when combining per-chunk summaries.
SHORT_SUMMARY_LENGTH = 150
LONG_SUMMARY_LENGTH = 300

# (tokenizer, model) pairs keyed by (model name, device). Repeated requests
# reuse the loaded weights instead of re-reading them on every call.
_MODEL_CACHE = {}

def load_model():
    """Return the (tokenizer, model) pair for MODEL_NAME, loading it only once.

    The model is moved to CUDA when available, otherwise it stays on CPU.
    Later calls return the cached objects, so the Gradio handler does not
    reload the checkpoint for every uploaded document.

    Returns:
        (tokenizer, model) as produced by AutoTokenizer / AutoModelForSeq2SeqLM.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cache_key = (MODEL_NAME, device)
    if cache_key not in _MODEL_CACHE:
        print(f"Using device: {device}")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
        _MODEL_CACHE[cache_key] = (tokenizer, model)
    return _MODEL_CACHE[cache_key]

def preprocess_text(text):
    """Clean raw PDF text before chunking.

    The cleaning steps, in order:
    - Remove page labels ("Page 12"), numeric dates ("01/02/2023"), and
      everything from "Footnote" to the end of that line.
    - Put "(a)"-style enumerations on a new line.
    - Turn runs of three or more of ``*``, ``-`` or ``_`` into paragraph breaks.
    - Collapse repeated spaces/tabs and trim spaces around line breaks.
    - Cap consecutive blank lines at one.
    - Start a new line after "word. Next" sentence boundaries.

    Blank-line paragraph breaks are preserved, because chunk_text() and
    summarize_target_sections() split the result on them.

    Args:
        text: Raw text extracted from the PDF.

    Returns:
        The cleaned text, stripped of leading/trailing whitespace.
    """
    # Drop page labels, numeric dates, and footnote text up to the end of the line.
    text = re.sub(r"Page \d+|[0-9]{1,2}/[0-9]{1,2}/[0-9]{2,4}|Footnote.*", "", text)
    # Start enumerated sub-clauses such as "(a)" on a new line. \g<0> is the whole
    # match; a bare "\0" in a replacement is an octal escape (a NUL character)
    # and would delete the marker instead of keeping it.
    text = re.sub(r"^\s*\([a-z]\)\s*", r"\n\g<0>", text, flags=re.MULTILINE)
    # Horizontal-rule style separators become paragraph breaks.
    text = re.sub(r"[*\-_]{3,}", "\n\n", text)
    # Collapse horizontal whitespace only; collapsing newlines as well would erase
    # the paragraph breaks that the chunking code splits on.
    text = re.sub(r"[^\S\n]{2,}", " ", text)
    text = re.sub(r"[^\S\n]*\n[^\S\n]*", "\n", text)
    text = re.sub(r"\n{3,}", "\n\n", text)
    # Break lines at "word. Next" boundaries. The gap may only contain spaces/tabs,
    # so an existing paragraph break is never folded into a single newline.
    text = re.sub(r"(?<=[a-z])\.[^\S\n]*(?=[A-Z])", ".\n", text)
    return text.strip()

def chunk_text(text, max_tokens=MAX_CHUNK_LENGTH, overlap=150):
    """Group the paragraphs of ``text`` into chunks sized for the summarization model.

    Paragraphs (separated by blank lines) are packed greedily into chunks of at
    most ``max_tokens * 4`` characters, using ~4 characters per token as a rough
    estimate. A chunk that is still larger than that budget (a single huge
    paragraph) is broken up with RecursiveCharacterTextSplitter. That splitter
    uses the same character budget, so every chunk targets the same size.

    Args:
        text: Preprocessed document text.
        max_tokens: Approximate token budget per chunk.
        overlap: Approximate token overlap between pieces of an oversized
            paragraph. It is converted to characters with the same ratio.

    Returns:
        List of chunk strings. Chunks of 20 words or fewer are dropped.
    """
    chars_per_token = 4
    char_budget = max_tokens * chars_per_token

    # RecursiveCharacterTextSplitter measures chunk_size/chunk_overlap in
    # characters, so convert the token budgets the same way as the packing loop.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=char_budget,
        chunk_overlap=overlap * chars_per_token,
        separators=["\n\n", "\n", ". ", ", ", " ", ""]
    )
    chunks = []

    def flush(chunk):
        # Emit a finished chunk, sub-splitting it if it exceeds the budget.
        if not chunk:
            return
        if len(chunk) > char_budget:
            chunks.extend(splitter.split_text(chunk))
        else:
            chunks.append(chunk)

    current_chunk = ""
    for para in text.split("\n\n"):
        if len(current_chunk) + len(para) > char_budget:
            flush(current_chunk)
            current_chunk = para
        elif current_chunk:
            current_chunk += "\n\n" + para
        else:
            current_chunk = para
    flush(current_chunk)

    # Very short fragments (headers, stray lines) are not worth summarizing.
    return [c for c in chunks if len(c.split()) > 20]

def summarize_chunks(chunks, tokenizer, model, summary_type="both", num_beams=4):
    """Generate a short and/or long summary for each chunk with beam search.

    Args:
        chunks: Text chunks to summarize.
        tokenizer: Tokenizer matching ``model``. Inputs are truncated to
            MAX_CHUNK_LENGTH tokens.
        model: Seq2seq model; its first parameter's device is used for the inputs.
        summary_type: "short", "long" or "both".
        num_beams: Beam count. It is coerced to an int >= 1, because UI sliders
            may deliver floats and generate() needs an integer.

    Returns:
        (short_summaries, long_summaries): lists aligned with ``chunks``. A list
        is empty when its summary type was not requested.
    """
    device = next(model.parameters()).device
    num_beams = max(1, int(num_beams))
    want_short = summary_type in ("short", "both")
    want_long = summary_type in ("long", "both")
    short_summaries = []
    long_summaries = []

    def generate(inputs, max_length, min_length, no_repeat_ngram_size):
        # One beam-search decode for a single tokenized chunk.
        summary_ids = model.generate(
            inputs.input_ids,
            # Pass the mask explicitly instead of letting generate() infer it.
            attention_mask=getattr(inputs, "attention_mask", None),
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            early_stopping=True,
            length_penalty=1.0,
            no_repeat_ngram_size=no_repeat_ngram_size,
        )
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True).strip()

    for chunk in chunks:
        inputs = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=MAX_CHUNK_LENGTH).to(device)
        word_count = len(chunk.split())

        if want_short:
            short_summaries.append(
                generate(inputs, SHORT_SUMMARY_LENGTH, min(50, word_count // 10), 3)
            )
        if want_long:
            long_summaries.append(
                generate(inputs, LONG_SUMMARY_LENGTH, min(100, word_count // 5), 2)
            )

    return short_summaries, long_summaries

def custom_sent_tokenize(text):
    """Regex sentence splitter, used when NLTK's sent_tokenize is unavailable.

    Text is split after ``.``, ``!`` or ``?`` when whitespace and a capital
    letter follow. Any resulting piece longer than 150 characters is further
    split after ``;`` or ``:`` under the same whitespace-then-capital condition.

    Args:
        text: Text to split.

    Returns:
        List of sentence strings.
    """
    pieces = []
    for candidate in re.split(r'(?<=[.!?])\s+(?=[A-Z])', text):
        if len(candidate) <= 150:
            pieces.append(candidate)
            continue
        # Long run-on sentence: break it at clause punctuation as well.
        pieces.extend(re.split(r'(?<=[;:])\s+(?=[A-Z])', candidate))
    return pieces

def combine_summaries(summaries, max_length=None):
    """Merge per-chunk summaries into a single de-duplicated string.

    Each summary is split into sentences with NLTK's sent_tokenize. It falls
    back to custom_sent_tokenize when NLTK or its Punkt data is unavailable.
    Sentences are de-duplicated ignoring case and whitespace, and sentences of
    10 or fewer characters are dropped.

    If ``max_length`` is given and the merged text exceeds that many characters,
    only the first sentence of each summary is kept. This happens only if those
    lead sentences together reach at least ``max_length // 2`` characters;
    otherwise the full merged text is kept.

    Args:
        summaries: List of summary strings, one per chunk.
        max_length: Optional soft character limit for the result.

    Returns:
        The combined text with whitespace normalized ("" for no summaries).
    """
    if not summaries:
        return ""

    def split_sentences(summary):
        # sent_tokenize is undefined if the NLTK import failed, and raises
        # LookupError if its Punkt data is missing.
        try:
            return sent_tokenize(summary)
        except (NameError, LookupError):
            return custom_sent_tokenize(summary)

    def normalize(sentence):
        return re.sub(r'\s+', ' ', sentence.lower()).strip()

    # Tokenize every summary once; both passes below reuse the result.
    per_summary_sentences = [split_sentences(summary) for summary in summaries]

    all_sentences = []
    seen = set()
    for sentences in per_summary_sentences:
        for sentence in sentences:
            key = normalize(sentence)
            if key not in seen and len(key) > 10:
                seen.add(key)
                all_sentences.append(sentence)

    if max_length and len(' '.join(all_sentences)) > max_length:
        # Fall back to the lead sentence of each summary. This pass needs its own
        # "seen" set: every lead sentence is already in `seen` from the pass
        # above, so reusing it would always produce an empty list.
        first_sentences = []
        seen_first = set()
        for sentences in per_summary_sentences:
            if not sentences:
                continue
            key = normalize(sentences[0])
            if key not in seen_first and len(key) > 10:
                seen_first.add(key)
                first_sentences.append(sentences[0])

        if first_sentences and len(' '.join(first_sentences)) >= max_length // 2:
            all_sentences = first_sentences

    combined_text = ' '.join(all_sentences)
    combined_text = re.sub(r'\s+', ' ', combined_text)
    combined_text = re.sub(r'\s+\.', '.', combined_text)
    combined_text = re.sub(r'\s+,', ',', combined_text)

    return combined_text

def extract_legal_terms(text):
    """Collect terms to highlight in the generated summaries.

    Combines a fixed list of common legal/governmental words with every run of
    two or more consecutive Capitalized words found in ``text`` (for example
    "Supreme Court" or "Election Commission").

    Args:
        text: Document text to scan.

    Returns:
        De-duplicated list of terms: the fixed words first, then phrases in the
        order they first appear in ``text``. The order is the same on every run.
    """
    base_legal_terms = [
        "Section", "Article", "Clause", "Amendment", "Schedule", "Act", "Law",
        "Constitution", "Provision", "Regulation", "Statute", "Directive",
        "Legislative", "Assembly", "Parliament", "Election", "Commission",
        "President", "Governor", "Cabinet", "Council", "Minister", "Bill"
    ]

    additional_terms = re.findall(r'\b([A-Z][a-z]+(?:\s+[A-Z][a-z]+)+)\b', text)

    # dict.fromkeys de-duplicates while keeping first-seen order. list(set(...))
    # gave an order that changed between interpreter runs (string hash seed).
    return list(dict.fromkeys(base_legal_terms + additional_terms))

def highlight_keywords(summary_text, keywords=None):
    """Wrap whole-word occurrences of the given keywords in Markdown bold.

    Matching is case-insensitive, and the matched text keeps its original
    casing. All keywords are matched in a single pass with longer keywords
    tried first. As a result, a keyword nested in a longer one (e.g. "Election"
    inside "Election Commission") is not bolded a second time, and the
    ``**`` markers never nest.

    Args:
        summary_text: Text to decorate.
        keywords: Terms to highlight. ``None`` selects a default list of legal
            terms; an empty list disables highlighting. Blank entries are ignored.

    Returns:
        The text with every matched keyword wrapped in ``**``.
    """
    if keywords is None:
        keywords = [
            "Section", "Article", "Clause", "Amendment", "Schedule", "Act", "Law",
            "Constitution", "Provision", "Regulation"
        ]

    # A blank keyword would match the empty string at every word boundary.
    unique_terms = dict.fromkeys(kw for kw in keywords if kw and kw.strip())
    terms = sorted(unique_terms, key=len, reverse=True)
    if not terms:
        return summary_text

    # Alternation tries options left to right, so longest-first ordering makes
    # the longest keyword win at each position.
    pattern = re.compile(
        r'\b(' + "|".join(re.escape(term) for term in terms) + r')\b',
        flags=re.IGNORECASE,
    )
    return pattern.sub(r'**\1**', summary_text)

def summarize_target_sections(text, tokenizer, model, keywords, summary_type="both", num_beams=4):
    """Summarize only the paragraphs of ``text`` that mention any of ``keywords``.

    Paragraphs (separated by blank lines) containing any keyword, matched as a
    case-insensitive literal substring, are packed into sections of fewer than
    MAX_CHUNK_LENGTH * 4 characters. The sections are summarized with
    summarize_chunks() and merged with combine_summaries().

    Args:
        text: Preprocessed document text.
        tokenizer: Tokenizer for ``model``.
        model: Seq2seq summarization model.
        keywords: Iterable of keyword strings. Blank entries are ignored.
        summary_type: "short", "long" or "both".
        num_beams: Beam count for generation.

    Returns:
        (short_summary, long_summary). When no paragraph matches, or no
        usable keyword was given, the first element is a user-facing message
        and the second is "".
    """
    not_found = "No sections containing the specified keywords were found."

    # Blank entries (e.g. from "Section,,Clause" or a trailing comma) would
    # match every paragraph, so drop them before matching.
    usable_keywords = [kw.strip() for kw in (keywords or []) if kw and kw.strip()]
    if not usable_keywords:
        return not_found, ""

    keyword_pattern = re.compile(
        "|".join(re.escape(kw) for kw in usable_keywords), re.IGNORECASE
    )
    target_sections = [
        para for para in text.split("\n\n")
        if para.strip() and keyword_pattern.search(para)
    ]
    if not target_sections:
        return not_found, ""

    # Greedily pack matching paragraphs into model-sized sections.
    combined_sections = []
    current_section = ""
    for section in target_sections:
        if not current_section:
            current_section = section
        elif len(current_section) + len(section) < MAX_CHUNK_LENGTH * 4:
            current_section += "\n\n" + section
        else:
            combined_sections.append(current_section)
            current_section = section
    if current_section:
        combined_sections.append(current_section)

    short_summaries, long_summaries = summarize_chunks(
        combined_sections, tokenizer, model, summary_type, num_beams
    )

    try:
        short_combined = combine_summaries(short_summaries, SHORT_SUMMARY_LENGTH * 2)
        long_combined = combine_summaries(long_summaries, LONG_SUMMARY_LENGTH * 2)
    except Exception as e:
        # Best effort: fall back to plain concatenation rather than failing the request.
        print(f"Warning: Error combining summaries: {e}")
        short_combined = " ".join(short_summaries)
        long_combined = " ".join(long_summaries)

    return short_combined, long_combined

def create_structured_summary(short_summary, long_summary, metrics, keywords=None):
    """Assemble the final Markdown report.

    Args:
        short_summary: Text for the "Executive Summary" section.
        long_summary: Text for the "Detailed Summary" section.
        metrics: Pre-formatted metrics text, inserted verbatim.
        keywords: Terms to bold in both summaries via highlight_keywords().
            A falsy value is replaced by an empty list, so nothing is bolded.

    Returns:
        Markdown with a title plus Metrics / Executive Summary / Detailed
        Summary sections, ending in a newline.
    """
    terms = keywords or []
    lines = [
        "# Legal Document Summary",
        "",
        "## Metrics",
        f"{metrics}",
        "",
        "## Executive Summary",
        f"{highlight_keywords(short_summary, terms)}",
        "",
        "## Detailed Summary",
        f"{highlight_keywords(long_summary, terms)}",
        "",
    ]
    return "\n".join(lines)

def process_pdf(file_obj, keywords=None, summary_type="both", num_beams=4):
    """Summarize an uploaded PDF and write the Markdown report to a temp file.

    Args:
        file_obj: The upload from ``gr.File``: a path string or an object that
            exposes the path as ``.name``. ``None`` when nothing was uploaded.
        keywords: Optional comma-separated keywords. When non-blank, only
            paragraphs mentioning them are summarized.
        summary_type: "short", "long" or "both".
        num_beams: Beam count for generation.

    Returns:
        (markdown_report, download_path). For a missing upload, an empty PDF,
        or any error, the first element is a message and download_path is None.
    """
    if file_obj is None:
        return "Please upload a PDF file.", None

    try:
        # Accept either a plain path or a file-like wrapper with a .name path.
        pdf_path = getattr(file_obj, "name", file_obj)
        # load() returns one Document per page. load_and_split() re-splits pages
        # with an overlapping splitter, so joining its output duplicated text.
        pages = PyPDFLoader(pdf_path).load()
        text = preprocess_text(" ".join(page.page_content for page in pages))

        total_words = len(text.split())
        if total_words == 0:
            # Avoids a division by zero in the metrics and a pointless model load.
            return ("No extractable text was found in this PDF "
                    "(it may be a scanned image)."), None

        legal_terms = extract_legal_terms(text)

        tokenizer, model = load_model()

        if keywords and keywords.strip():
            keyword_list = [k.strip() for k in keywords.split(',')]
            short_summary, long_summary = summarize_target_sections(
                text, tokenizer, model, keyword_list, summary_type, num_beams
            )
        else:
            chunks = chunk_text(text)
            short_summaries, long_summaries = summarize_chunks(
                chunks, tokenizer, model, summary_type, num_beams
            )

            short_summary = combine_summaries(short_summaries, SHORT_SUMMARY_LENGTH * 2)
            long_summary = combine_summaries(long_summaries, LONG_SUMMARY_LENGTH * 2)

        metrics = _format_metrics(total_words, short_summary, long_summary)

        final_summary = create_structured_summary(
            short_summary, long_summary, metrics, legal_terms
        )

        # delete=False: Gradio serves the file from this path after the handler returns.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".md") as tmp_file:
            tmp_file.write(final_summary.encode("utf-8"))
            download_path = tmp_file.name

        return final_summary, download_path

    except Exception as e:
        # Report the failure in the UI rather than crashing the handler; keep the traceback in the logs.
        import traceback
        print(traceback.format_exc())
        return f"An error occurred: {str(e)}", None


def _format_metrics(total_words, short_summary, long_summary):
    """Build the word-count / compression text for the "Metrics" section.

    ``total_words`` must be positive; process_pdf() guarantees this.
    """
    short_word_count = len(short_summary.split())
    long_word_count = len(long_summary.split())
    short_compression = round((short_word_count / total_words) * 100, 2)
    long_compression = round((long_word_count / total_words) * 100, 2)
    return (
        f"Original Word Count: {total_words}\n"
        f"Short Summary Word Count: {short_word_count} (Compression: {short_compression}%)\n"
        f"Long Summary Word Count: {long_word_count} (Compression: {long_compression}%)"
    )

def main():
    """Build the Gradio UI for the summarizer and launch it.

    Components are created inside the gr.Blocks context in the order they
    appear on the page. The button passes the file, keywords, summary type and
    beam count to process_pdf and shows its two return values: the Markdown
    summary and the downloadable file.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as iface:
        gr.Markdown("# Legal Document Summarizer (Fine-Tuned T5)")
        gr.Markdown(
            "Upload a legal PDF to generate summaries with optional keyword-based selection. "
            "The model highlights critical legal terms and provides both short and detailed summaries."
        )

        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column(scale=1):
                file_input = gr.File(
                    label="Upload Legal Document (PDF)",
                    file_types=[".pdf"]
                )
                keywords_input = gr.Textbox(
                    label="Target Keywords (Optional, comma-separated)",
                    placeholder="e.g., Section, Clause, Election"
                )

                with gr.Row():
                    summary_type = gr.Radio(
                        ["short", "long", "both"],
                        label="Summary Type",
                        value="both"
                    )
                    # NOTE(review): more beams widens the search rather than adding diversity;
                    # consider rewording this label.
                    beam_count = gr.Slider(
                        1, 10, step=1, value=4,
                        label="Number of Beams (Higher = more diverse)"
                    )

                submit_btn = gr.Button("Generate Summary", variant="primary")

            # Right column: rendered summary and download link.
            with gr.Column(scale=2):
                output_text = gr.Markdown(label="Summary")
                download_output = gr.File(label="Download Summary")

        # Input order must match process_pdf's positional parameters.
        submit_btn.click(
            fn=process_pdf,
            inputs=[file_input, keywords_input, summary_type, beam_count],
            outputs=[output_text, download_output]
        )

    # NOTE(review): share=True publishes the app on a public gradio.live URL, so anyone with
    # the link can upload documents. debug=True keeps the notebook cell blocking while the server runs.
    iface.launch(debug=True, share=True)

# Launch the UI when this code runs as the main module (including when the notebook cell runs).
if __name__ == "__main__":
    main()

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://8b93281d8921dfc6f4.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Using device: cuda




Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://8b93281d8921dfc6f4.gradio.live
