In [1]:
import datasets

# LexGLUE, "scotus" configuration; later cells use its train/test/validation splits.
dataset = datasets.load_dataset(path="coastalcph/lex_glue", name="scotus")

#### Explore document lengths in the training split

In [2]:
# Document length here is measured in characters (len of the raw string), not model tokens.
# Reading the "text" column once avoids re-materializing every row two or three times.
train_texts = dataset["train"]["text"]
doc_char_lengths = [len(text) for text in train_texts]
highest = max(doc_char_lengths, default=0)
total_length = sum(doc_char_lengths)
# Guard against an empty split (would otherwise divide by zero).
average_length = round(total_length / len(doc_char_lengths)) if doc_char_lengths else 0
# Outer double quotes: nesting the same quote type inside an f-string is a
# SyntaxError before Python 3.12 (PEP 701).
print(
    f"The average length of documents in the training dataset is {average_length} characters\n"
    f"The longest document in the training dataset contains {highest} characters"
)

The average length of documents in training dataset is 35723
The lengthy document in the dataset contains 562772 number of tokens


In [5]:
import math
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from datasets import DatasetDict

# ----- Step 1: Load the Summarization Model -----
# Long-input legal summarization checkpoint; the tokenizer warning in the output
# reports a 16384-token model maximum for it.
sum_model_name = "nsi319/legal-led-base-16384"
sum_tokenizer, sum_model = (
    AutoTokenizer.from_pretrained(sum_model_name),
    AutoModelForSeq2SeqLM.from_pretrained(sum_model_name),
)

# Wrap the already-loaded model and tokenizer in a PyTorch summarization pipeline;
# generation settings such as max_length are passed on each call.
summarizer = pipeline(
    "summarization",
    model=sum_model,
    tokenizer=sum_tokenizer,
    framework="pt",
)

# ----- Step 2: Define a Function to Summarize a Document -----
def _summary_text(summarizer, text, max_length):
    """Run the summarization pipeline on one text and return the generated summary string."""
    # truncation=True keeps an input that still exceeds the model window from raising.
    return summarizer(text, max_length=max_length, truncation=True)[0]['summary_text']


def summarize_document(document, summarizer, tokenizer, chunk_size=15000, final_max_length=512,
                       chunk_max_length=None):
    """
    Summarize a long document hierarchically.

      - If the document fits in ``chunk_size`` tokens, summarize it directly.
      - Otherwise split its token IDs into consecutive chunks of at most
        ``chunk_size`` tokens, summarize each chunk, join the chunk summaries,
        and summarize the joined text. If the joined text is itself still longer
        than ``chunk_size`` tokens, it is chunked again (recursively) instead of
        being silently truncated by the model.

    Args:
        document (str): The full text of the document.
        summarizer: The Hugging Face summarization pipeline, called as
            ``summarizer(text, max_length=..., truncation=True)``.
        tokenizer: The tokenizer corresponding to the summarization model.
        chunk_size (int): Maximum number of content tokens per chunk; must be > 0.
        final_max_length (int): ``max_length`` for the final summary.
        chunk_max_length (int | None): ``max_length`` for intermediate chunk
            summaries; defaults to ``final_max_length``.

    Returns:
        str: Final summary text. Empty or whitespace-only documents are returned unchanged.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if chunk_max_length is None:
        chunk_max_length = final_max_length
    # Nothing to summarize; avoid asking the model to generate from empty input.
    if not document or not document.strip():
        return document

    # Count content tokens only: the pipeline adds BOS/EOS itself when it re-tokenizes.
    # The tokenizer may warn that this exceeds the model's max length; that is expected,
    # since these IDs are only used for chunking and are never fed to the model directly.
    tokens = tokenizer.encode(document, truncation=False, add_special_tokens=False)
    num_tokens = len(tokens)

    if num_tokens <= chunk_size:
        return _summary_text(summarizer, document, final_max_length)

    # Decode each consecutive token slice back to text and summarize it independently.
    chunk_summaries = [
        _summary_text(
            summarizer,
            tokenizer.decode(tokens[start:start + chunk_size], skip_special_tokens=True),
            chunk_max_length,
        )
        for start in range(0, num_tokens, chunk_size)
    ]
    combined_summary_text = " ".join(chunk_summaries)

    # If the joined summaries still exceed one chunk, recurse rather than letting the
    # model truncate them. Only recurse when summarizing actually shrank the text, so
    # the recursion always terminates.
    combined_num_tokens = len(
        tokenizer.encode(combined_summary_text, truncation=False, add_special_tokens=False)
    )
    if chunk_size < combined_num_tokens < num_tokens:
        return summarize_document(
            combined_summary_text,
            summarizer,
            tokenizer,
            chunk_size=chunk_size,
            final_max_length=final_max_length,
            chunk_max_length=chunk_max_length,
        )
    return _summary_text(summarizer, combined_summary_text, final_max_length)

# ----- Step 3: Apply Summarization to the Training Split Only -----
def summarize_example(example):
    """Dataset.map callback: overwrite the example's "text" field with its summary.

    Uses 15000-token chunks and a 400-token max_length for the summary.
    """
    long_text = example["text"]
    example["text"] = summarize_document(
        long_text,
        summarizer,
        sum_tokenizer,
        final_max_length=400,
        chunk_size=15000,
    )
    return example

# Summarize only the training split; the test and validation splits keep their
# original full-length text.
# NOTE(review): a model trained on summaries but evaluated on full documents sees
# a train/eval input-distribution shift — confirm this is intended.
summarized_train = dataset["train"].map(summarize_example)
summarized_dataset = DatasetDict({
    "train": summarized_train,
    "test": dataset["test"],
    "validation": dataset["validation"],
})

Device set to use cuda:0


Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (19422 > 16384). Running this sequence through the model will result in indexing errors
Input ids are automatically padded from 762 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 3322 to 4096 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 11296 to 12288 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 9269 to 10240 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 8616 to 9216 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 15002 to 15360 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 2799 to 3072 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 1196 to 2048 to be a multiple of `c

In [None]:
# Publish all splits to the Hugging Face Hub; the returned CommitInfo is the cell output.
hub_repo_id = "victorambrose11/summarized_scotus_latest"
summarized_dataset.push_to_hub(hub_repo_id)

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/victorambrose11/summarized_scotus_latest/commit/4b3c0124fb3efce22e9a6f6c5c39bac595fe458d', commit_message='Upload dataset', commit_description='', oid='4b3c0124fb3efce22e9a6f6c5c39bac595fe458d', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/victorambrose11/summarized_scotus_latest', endpoint='https://huggingface.co', repo_type='dataset', repo_id='victorambrose11/summarized_scotus_latest'), pr_revision=None, pr_num=None)