<a href="https://colab.research.google.com/github/ShubhamT2720/Natural-Language-Processing-for-Legal-Documents/blob/main/Text_Sum_Final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers
!pip install accelerate
!pip install sentencepiece
!pip install pandas numpy
!pip install nltk





In [1]:
# Mount Google Drive at /content/drive; the dataset and output paths used by
# this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [15]:
# Root folder of the IN-Abs dataset on the mounted Drive.
# NOTE: `dataset` is reassigned to the plain name "IN-Abs" in a later cell.
dataset = "/content/drive/MyDrive/IN-Abs"
# Folder into which the generated summaries are written (one file per document).
output_path = "/content/drive/MyDrive/Output"

In [16]:
import glob
from nltk import tokenize
import nltk
import transformers
import pandas as pd
import numpy as np
import sys
sys.path.insert(0, '../')
import os


In [17]:
# Sentence-tokenizer data for nltk.sent_tokenize. Recent NLTK releases load the
# 'punkt_tab' resource instead of the pickled 'punkt' models, and NLTK is
# installed unpinned above, so fetch both to work on either version.
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [18]:
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [19]:
# Create the output folder if it is missing. `exist_ok=True` replaces the
# check-then-create pattern, so the cell is safe to re-run and has no window
# between the existence check and the creation.
os.makedirs(output_path, exist_ok=True)

In [20]:
import os
import glob
import nltk

def get_root_path():
    '''
    Return the root directory of the dataset.

    The location is hard-coded; change the literal below if the dataset
    lives somewhere else.
    '''
    return "/content/drive/MyDrive/IN-Abs"

def get_summary_data(dataset, data_type):
    '''
    Load document names, judgment texts, and reference summaries.

    Args:
    - dataset: The name of the dataset (currently not used as there is only one dataset).
    - data_type: Sub-folder of the dataset root to read, e.g. 'train-data' or 'test-data'.

    Returns:
    - names: List of judgment file names (base names, without directories).
    - data_source: List of judgment texts, in the same order as `names`.
    - data_summary: List of summary texts read from the summary folder.

    Both folders are read in sorted file-name order. `glob` gives no ordering
    guarantee, so without sorting the i-th summary is not guaranteed to belong
    to the i-th judgment. The lists line up index-by-index provided both
    folders contain the same set of file names (not checked here).
    '''
    # Set the path for the judgment and summary folders based on data_type.
    root = get_root_path()
    path_judgment = os.path.join(root, f'{data_type}/judgement')
    path_summary = os.path.join(root, f'{data_type}/summary')

    # Load all judgment files in a deterministic order.
    names = []
    data_source = []
    for filename in sorted(glob.glob(path_judgment + "/*.txt")):
        names.append(os.path.basename(filename))  # File name without the path.
        with open(filename, 'r') as f:
            data_source.append(f.read())

    # Load all summary files in the same (sorted) order.
    data_summary = []
    for filename in sorted(glob.glob(path_summary + "/*.txt")):
        with open(filename, 'r') as f:
            data_summary.append(f.read())

    return names, data_source, data_summary

def get_req_len_dict(dataset, data_type):
    '''
    Read the required summary length for each document.

    The lengths come from the file `stats-IN-test.txt` inside the `data_type`
    folder of the dataset root. Each non-blank line is expected to hold a
    document name and an integer length separated by a tab.

    Args:
    - dataset: The name of the dataset (currently not used).
    - data_type: Sub-folder of the dataset root that holds the stats file.

    Returns:
    - dict_names: A dictionary mapping document names to their required summary lengths.

    Blank lines (such as the one produced by a trailing newline) are skipped
    silently. Lines with no tab or a non-integer length are reported and skipped.
    '''
    length_file_path = os.path.join(get_root_path(), f"{data_type}/stats-IN-test.txt")

    with open(length_file_path, "r") as f:
        lines = f.read().splitlines()

    dict_names = {}
    for line in lines:
        if not line.strip():
            # Empty line, typically the end of the file: nothing to parse.
            continue
        fields = line.split("\t")
        try:
            dict_names[fields[0]] = int(fields[1])
        except (IndexError, ValueError):
            # IndexError: no tab on the line; ValueError: length is not an integer.
            print(f"Error parsing line: {line}")

    return dict_names

def split_to_sentences(para):
    '''
    Break a paragraph into sentences with NLTK's sentence tokenizer.

    Args:
    - para: The paragraph text as a single string.

    Returns:
    - The list of sentences produced by `nltk.sent_tokenize`.
    '''
    return nltk.sent_tokenize(para)

def nest_sentences(document, chunk_length, sentence_splitter=None):
    '''
    Group the sentences of a document into consecutive chunks.

    Sentences are appended to the current chunk while its running word count
    (words are counted by splitting on single spaces) stays below
    `chunk_length`. The sentence that would reach the limit starts a new chunk.
    A single sentence that is itself longer than `chunk_length` becomes a chunk
    of its own.

    Args:
    - document: The input document as a string.
    - chunk_length: The maximum length of each chunk in words.
    - sentence_splitter: Optional callable mapping a string to a list of
      sentences. Defaults to `nltk.sent_tokenize`.

    Returns:
    - nested: A list of chunks, where each chunk is a non-empty list of sentences.
    '''
    if sentence_splitter is None:
        sentence_splitter = nltk.sent_tokenize

    nested = []
    sent = []
    length = 0

    for sentence in sentence_splitter(document):
        n_words = len(sentence.split(" "))
        if length + n_words < chunk_length:
            sent.append(sentence)
            length += n_words
        else:
            # Close the current chunk. It is empty when the very first sentence
            # already reaches the limit; never emit an empty chunk in that case.
            if sent:
                nested.append(sent)
            sent = [sentence]
            length = n_words  # Reset length for the new chunk

    if sent:
        nested.append(sent)

    return nested

# Load the test split: file names, judgment texts, and reference summaries
dataset = "IN-Abs"
names, data_source, data_summary = get_summary_data(dataset, "test-data")
print(
    f"Number of names: {len(names)}",
    f"Number of documents: {len(data_source)}",
    f"Number of summaries: {len(data_summary)}",
    sep="\n",
)

# Required summary length per document; both names refer to the same dict
dict_names = get_req_len_dict(dataset, "test-data")
len_dic = dict_names


Number of names: 100
Number of documents: 100
Number of summaries: 100
Error parsing line: 


In [34]:
# Choose the device dynamically instead of pinning the second GPU ("cuda:1"),
# which is not available on single-GPU or CPU-only runtimes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [35]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Use the GPU when PyTorch can see one, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fetch the tokenizer and weights of the `nsi319/legal-pegasus` checkpoint,
# then move the model onto the selected device
tokenizer = AutoTokenizer.from_pretrained("nsi319/legal-pegasus")
model = AutoModelForSeq2SeqLM.from_pretrained("nsi319/legal-pegasus")
model = model.to(device)

print(f"Model loaded on device: {device}")

Model loaded on device: cuda


In [36]:
def summarize(text, max_len, min_len):
    '''
    Produce a summary of `text` with the globally loaded tokenizer and model.

    The input is truncated to at most 512 tokens before generation, and
    generation uses beam search with 9 beams and a length penalty of 0.1.

    Args:
    - text: The input text to be summarized.
    - max_len: Value passed to the generator as `max_length`.
    - min_len: Value passed to the generator as `min_length`.

    Returns:
    - summary: The first decoded output sequence, or an empty string if any
      step raised an exception (the error is printed, not re-raised).
    '''
    try:
        encoded = tokenizer.encode(
            text,
            return_tensors='pt',
            max_length=512,
            truncation=True,
        )
        encoded = encoded.to(device)
        generated = model.generate(
            encoded,
            num_beams=9,
            length_penalty=0.1,
            min_length=min_len,
            max_length=max_len,
        )
        decoded = [
            tokenizer.decode(seq, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for seq in generated
        ]
        first = decoded[0]
    except Exception as e:
        print(f"Error during summarization: {e}")
        return ""
    else:
        return first

In [37]:
def summarize_doc(nested_sentences, p):
  '''
  Summarize a chunked document and concatenate the per-chunk summaries.

  Args:
  - nested_sentences: List of chunks, each chunk being a list of sentences.
  - p: Target compression ratio (required summary words / input words).

  Returns:
  - A single string: the per-chunk summaries joined by spaces.

  For each chunk the word target is `int(p * chunk_words)`. The generator is
  asked for between `target - 5` (never below 0) and `target + 10` units of
  output. Note that `summarize` forwards these as `min_length`/`max_length`
  to the model, while the target here is counted in space-separated words.
  The generated text is then cut at a sentence boundary so it does not exceed
  the word target, always keeping at least the first sentence.
  '''
  result = []
  for nested in nested_sentences:
    # Join sentences directly
    chunk_text = " ".join(nested)
    target_words = int(p * len(chunk_text.split(" ")))
    if target_words <= 0:
      # Nothing would be kept from this chunk, so skip the expensive generation.
      result.append("")
      continue
    max_len = target_words + 10  # Allow some flexibility for truncation
    min_len = max(0, target_words - 5)  # Small targets must not go negative
    summary = summarize(chunk_text, max_len, min_len)
    # Truncate on sentence boundaries against the WORD budget (comparing the
    # number of sentences with a word count would never truncate in practice).
    kept = []
    words_used = 0
    for sentence in nltk.sent_tokenize(summary):
      n_words = len(sentence.split(" "))
      if kept and words_used + n_words > target_words:
        break
      kept.append(sentence)
      words_used += n_words
    result.append(" ".join(kept))
  return " ".join(result)

In [40]:
# Gather, once up front, the base names of the .txt files already in the output folder
done_files = [
    os.path.basename(existing)
    for existing in glob.glob(os.path.join(output_path, "*.txt"))
]

In [41]:
# Main loop to generate and save summaries of each document in the test dataset.
# Documents whose summary file already exists in the output folder are skipped,
# so the cell can be re-run to resume an interrupted job.
for i in range(min(5, len(data_source))):  # Limit to the first 5 documents
    name = names[i]

    # Skip if file has already been processed
    if name in done_files:
        continue

    doc = data_source[i]
    input_len = len(doc.split(" "))

    # Check if name is in dict_names
    if name not in dict_names:
        print(f"Warning: Required length for '{name}' not found in dict_names.")
        continue

    req_len = dict_names[name]

    # Print information for debugging
    print(f"{i}: {name} - {input_len} : {req_len}", end=", ")

    # Avoid division by zero
    if input_len == 0:
        print("Error: Input length is zero, skipping this document.")
        continue

    nested = nest_sentences(doc, 512)
    # Fraction of the input each chunk summary should keep
    p = float(req_len) / input_len
    print(f"p: {p}")

    abs_summ = summarize_doc(nested, p)
    abs_summ = " ".join(abs_summ.split())  # Collapse runs of whitespace

    # Print the length of the summary for debugging
    print(f"Summary length before truncation: {len(abs_summ.split(' '))}")

    # Trim to the required length without cutting mid-sentence: keep leading
    # sentences while the running word count stays within req_len. The first
    # sentence is always kept so the output file is never empty.
    if len(abs_summ.split(" ")) > req_len:
        kept = []
        words_used = 0
        for sentence in split_to_sentences(abs_summ):
            n_words = len(sentence.split(" "))
            if kept and words_used + n_words > req_len:
                break
            kept.append(sentence)
            words_used += n_words
        abs_summ = " ".join(kept)

    # Print the final length of the summary
    print(f"Final summary length: {len(abs_summ.split(' '))}")

    # Write the summary to a file named after the source document
    path = os.path.join(output_path, name)
    try:
        with open(path, 'w') as file:
            file.write(abs_summ)
    except IOError as e:
        print(f"Error writing to file '{path}': {e}")

0: 1181.txt - 3387 : 3510, p: 1.0363153232949514
Summary length before truncation: 3137
Final summary length: 3137
1: 1195.txt - 4234 : 4389, p: 1.0366084081247047
Summary length before truncation: 3912
Final summary length: 3912
2: 1329.txt - 2990 : 3083, p: 1.0311036789297658
Summary length before truncation: 2818
Final summary length: 2818
3: 1378.txt - 2202 : 2281, p: 1.0358764759309718
Summary length before truncation: 1992
Final summary length: 1992
4: 1406.txt - 2089 : 2165, p: 1.0363810435615126
Summary length before truncation: 1973
Final summary length: 1973
