In [2]:
import datasets

# Load the "scotus" configuration of the LexGLUE benchmark from the Hugging Face Hub.
dataset = datasets.load_dataset(
    path="coastalcph/lex_glue",
    name="scotus",
)

  from .autonotebook import tqdm as notebook_tqdm


#### Explore the document lengths in the dataset

In [3]:
# Measure every training document exactly once by reading the "text" column
# directly, instead of materialising whole rows several times per iteration.
# NOTE: len() of a string counts characters, so these figures are character
# counts even though the printed message calls them "tokens".
train_doc_lengths = [len(text) for text in dataset["train"]["text"]]

total_length = sum(train_doc_lengths)
highest = max(train_doc_lengths, default=0)

# Guard the average so an empty split reports 0 instead of dividing by zero.
num_train_docs = len(train_doc_lengths)
average_length = round(total_length / num_train_docs) if num_train_docs else 0

# Double-quoted f-strings with no nested quotes parse on every Python 3 version
# (reusing the enclosing quote inside an f-string requires Python 3.12+).
print(
    f"The average length of documents in training dataset is {average_length}\n"
    f"The lengthy document in the dataset contains {highest} number of tokens"
)

The average length of documents in training dataset is 35723
The lengthy document in the dataset contains 562772 number of tokens


In [8]:
import math
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from datasets import DatasetDict

# ----- Step 1: Load the Summarization Model -----
# Checkpoint used for summarizing the long legal documents; the model and its
# tokenizer are both loaded from the same Hub identifier.
sum_model_name = "nsi319/legal-led-base-16384"
sum_model = AutoModelForSeq2SeqLM.from_pretrained(sum_model_name)
sum_tokenizer = AutoTokenizer.from_pretrained(sum_model_name)

# Wrap model + tokenizer in a PyTorch summarization pipeline; generation
# parameters (e.g. max_length) are supplied per call further down.
summarizer = pipeline(
    task="summarization",
    model=sum_model,
    tokenizer=sum_tokenizer,
    framework="pt",
)

# ----- Step 2: Define a Function to Summarize a Document -----
def _summarize_text(text, summarizer, max_length):
    """Run the summarization pipeline on one text and return the summary string."""
    return summarizer(text, max_length=max_length, truncation=True)[0]['summary_text']


def summarize_document(document, summarizer, tokenizer, chunk_size=15000, final_max_length=400):
    """
    Summarize a document of arbitrary length with a map-reduce strategy.

      - If the document has at most ``chunk_size`` tokens, summarize it directly.
      - Otherwise split the token sequence into consecutive chunks of at most
        ``chunk_size`` tokens, summarize each chunk, join the chunk summaries
        with spaces and summarize the joined text.
      - If the joined text is itself longer than ``chunk_size`` tokens (and
        shorter than the input, so progress is guaranteed), the same procedure
        is applied to it recursively rather than letting the pipeline silently
        truncate it.

    Args:
        document (str): The full text of the document.
        summarizer: Callable with the Hugging Face summarization pipeline
            interface: ``summarizer(text, max_length=..., truncation=True)``
            returning ``[{"summary_text": str}]``.
        tokenizer: The tokenizer corresponding to the summarization model.
        chunk_size (int): Maximum number of content tokens per chunk
            (special tokens are not counted). Must be positive.
        final_max_length (int): ``max_length`` passed to every summarizer call.

    Returns:
        str: Final summary text, or ``""`` when the document is empty or
        contains only whitespace.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    # Nothing to summarize: skip the (expensive) model call entirely.
    if not document or not document.strip():
        return ""

    # Tokenize without truncation and WITHOUT special tokens. With special
    # tokens included, a document just over the limit would produce a trailing
    # chunk holding only the end-of-sequence token, which decodes to "".
    tokens = tokenizer.encode(document, truncation=False, add_special_tokens=False)
    num_tokens = len(tokens)

    # Short document: a single pass is enough.
    if num_tokens <= chunk_size:
        return _summarize_text(document, summarizer, final_max_length)

    # Map step: summarize each consecutive window of chunk_size tokens.
    chunk_summaries = []
    for start in range(0, num_tokens, chunk_size):
        chunk_text = tokenizer.decode(tokens[start:start + chunk_size], skip_special_tokens=True)
        if not chunk_text.strip():
            # Never feed an empty chunk to the model.
            continue
        chunk_summaries.append(_summarize_text(chunk_text, summarizer, final_max_length))

    if not chunk_summaries:
        return ""

    # Reduce step: merge the partial summaries into one text.
    combined_summary_text = " ".join(chunk_summaries)
    combined_num_tokens = len(
        tokenizer.encode(combined_summary_text, truncation=False, add_special_tokens=False)
    )

    # Still too long for one pass: reduce again, but only while the text is
    # strictly shrinking so the recursion is guaranteed to terminate.
    if chunk_size < combined_num_tokens < num_tokens:
        return summarize_document(
            combined_summary_text,
            summarizer,
            tokenizer,
            chunk_size=chunk_size,
            final_max_length=final_max_length,
        )

    return _summarize_text(combined_summary_text, summarizer, final_max_length)

# ----- Step 3: Apply Summarization to the Training Split Only -----
def summarize_example(example):
    """Overwrite the example's "text" field with its summary and return the example.

    Relies on the module-level ``summarizer`` pipeline and ``sum_tokenizer``.
    The example mapping is modified in place.
    """
    full_text = example["text"]
    example["text"] = summarize_document(
        full_text,
        summarizer,
        sum_tokenizer,
        chunk_size=15000,
        final_max_length=400,
    )
    return example

# Build a new DatasetDict with the same three splits. Only the training split
# is passed through summarize_example; "test" and "validation" keep their
# original, unsummarized text.
summarized_dataset = DatasetDict(
    (
        split_name,
        dataset[split_name].map(summarize_example) if split_name == "train" else dataset[split_name],
    )
    for split_name in ("train", "test", "validation")
)



Device set to use mps:0
Map:   0%|          | 1/5000 [00:14<20:04:54, 14.46s/ examples]Input ids are automatically padded from 5490 to 6144 to be a multiple of `config.attention_window`: 1024
Map:   0%|          | 2/5000 [00:28<19:19:39, 13.92s/ examples]Input ids are automatically padded from 6651 to 7168 to be a multiple of `config.attention_window`: 1024
Map:   0%|          | 3/5000 [00:42<19:26:52, 14.01s/ examples]Input ids are automatically padded from 11472 to 12288 to be a multiple of `config.attention_window`: 1024
Map:   0%|          | 4/5000 [01:15<29:54:43, 21.55s/ examples]Input ids are automatically padded from 14719 to 15360 to be a multiple of `config.attention_window`: 1024
Map:   0%|          | 5/5000 [05:51<157:28:28, 113.50s/ examples]Input ids are automatically padded from 1638 to 2048 to be a multiple of `config.attention_window`: 1024
Map:   0%|          | 6/5000 [06:41<127:14:40, 91.73s/ examples] Input ids are automatically padded from 9527 to 10240 to be a mul

KeyboardInterrupt: 

In [None]:
# Publish the summarized splits to the Hugging Face Hub. This needs
# `summarized_dataset` from the summarization cell, so that cell must have
# run to completion first.
summarized_dataset.push_to_hub(repo_id="victorambrose11/summarized_scotus")