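# Preprocess data for Hierarchical LegalBERT:
1. Split each document into paragraphs (on blank lines).
2. Tokenize each paragraph separately, truncating/padding it to a fixed token length.
3. Truncate or pad each document to a fixed number of paragraphs, giving a structure suitable for hierarchical processing.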

In [1]:
from datasets import load_dataset
# Load the LexGLUE SCOTUS configuration (a DatasetDict with train/validation/test splits).
original_dataset = load_dataset(path="coastalcph/lex_glue", name="scotus")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import numpy as np
from transformers import AutoTokenizer
import pandas as pd
from concurrent.futures import ThreadPoolExecutor, as_completed

def process_document(text, label, tokenizer, max_paragraphs, max_paragraph_length):
    """Split one document into paragraphs and tokenize each to a fixed shape.

    The document is split on blank lines (``'\\n\\n'``); whitespace-only
    fragments are dropped and at most ``max_paragraphs`` paragraphs are kept
    (later ones are discarded). Each kept paragraph is encoded with special
    tokens and truncated/padded to ``max_paragraph_length``. Documents with
    fewer paragraphs are padded with placeholder paragraphs so every document
    yields exactly ``max_paragraphs`` rows.

    Args:
        text: Raw document text.
        label: Passed through unchanged so callers keep it paired with the
            encodings.
        tokenizer: Object exposing ``encode_plus`` and the ``cls_token_id``,
            ``sep_token_id`` and ``pad_token_id`` attributes (e.g. a Hugging
            Face tokenizer).
        max_paragraphs: Number of paragraph rows produced per document.
        max_paragraph_length: Token length of every paragraph row.

    Returns:
        Tuple ``(paragraph_input_ids, paragraph_attention_masks,
        actual_paragraph_count, label)``. The first two are lists with
        ``max_paragraphs`` rows each; ``actual_paragraph_count`` is the number
        of real (non-placeholder) rows at the front.
    """
    # Split on blank lines and drop whitespace-only fragments.
    paragraphs = [p for p in text.split('\n\n') if p.strip()]
    # Keep only the first max_paragraphs paragraphs.
    paragraphs = paragraphs[:max_paragraphs]
    actual_paragraph_count = len(paragraphs)

    paragraph_input_ids = []
    paragraph_attention_masks = []
    for paragraph in paragraphs:
        encoding = tokenizer.encode_plus(
            paragraph,
            add_special_tokens=True,
            max_length=max_paragraph_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
        )
        paragraph_input_ids.append(encoding['input_ids'])
        paragraph_attention_masks.append(encoding['attention_mask'])

    missing = max_paragraphs - actual_paragraph_count
    if missing > 0:
        # Placeholder row: [CLS] [SEP] followed by padding, with the mask
        # attending to exactly those two special tokens (the layout a
        # BERT-style tokenizer emits for an empty string). Ids and mask must
        # agree position-by-position, otherwise the encoder attends to PAD.
        pad_input_ids = (
            [tokenizer.cls_token_id, tokenizer.sep_token_id]
            + [tokenizer.pad_token_id] * (max_paragraph_length - 2)
        )
        pad_attention_mask = [1, 1] + [0] * (max_paragraph_length - 2)

        # Fresh list per row so rows never alias each other.
        paragraph_input_ids.extend([list(pad_input_ids) for _ in range(missing)])
        paragraph_attention_masks.extend([list(pad_attention_mask) for _ in range(missing)])

    return paragraph_input_ids, paragraph_attention_masks, actual_paragraph_count, label

def _column_names(split_data):
    """Return the column names of a split (datasets.Dataset, DataFrame or dict of lists)."""
    names = getattr(split_data, 'column_names', None)  # datasets.Dataset
    if names is None:
        names = getattr(split_data, 'columns', None)  # pandas.DataFrame
    if names is None:
        names = split_data.keys()  # plain mapping of column -> values
    return list(names)


def _resolve_label_column(split_data, label_column):
    """Return the label column to read: explicit if given, else 'labels' then 'label', else None."""
    names = _column_names(split_data)
    if label_column is not None:
        if label_column not in names:
            raise KeyError(f"Label column {label_column!r} not found; available columns: {names}")
        return label_column
    for candidate in ('labels', 'label'):
        if candidate in names:
            return candidate
    return None


def preprocess_hierarchical_legalbert(dataset, tokenizer_name="nlpaueb/legal-bert-base-uncased",
                                     max_paragraphs=20, max_paragraph_length=512, max_workers=None,
                                     label_column=None):
    """Tokenize every split of ``dataset`` into fixed-shape paragraph blocks.

    Args:
        dataset: Mapping of split name -> split data (e.g. a ``DatasetDict``,
            or dicts of lists / DataFrames). Each split needs a ``'text'`` column.
        tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
        max_paragraphs: Paragraph rows kept/padded per document.
        max_paragraph_length: Tokens per paragraph row.
        max_workers: Thread count for ``ThreadPoolExecutor`` (``None`` uses
            the executor's default).
        label_column: Column to read labels from. ``None`` auto-detects
            ``'labels'`` then ``'label'``; if neither exists, labels are ``None``.

    Returns:
        ``{split: pandas.DataFrame}`` with columns ``paragraph_input_ids``,
        ``paragraph_attention_masks``, ``paragraph_counts`` and ``labels``,
        one row per input document, in the input order.

    Raises:
        KeyError: If ``label_column`` is given but missing from a split.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def encode(text, label):
        return process_document(text, label, tokenizer, max_paragraphs, max_paragraph_length)

    result_dataset = {}
    for split in dataset.keys():
        print(f"Processing {split} set for hierarchical encoding...")
        split_data = dataset[split]

        texts = split_data['text']
        # Check real column names: `in` on a datasets.Dataset is not a column
        # membership test, and some LexGLUE configs (e.g. scotus) use 'label'.
        column = _resolve_label_column(split_data, label_column)
        labels = split_data[column] if column is not None else [None] * len(texts)

        # executor.map yields results in submission order, so output rows stay
        # aligned with the input split (as_completed would yield them in
        # completion order, shuffling rows nondeterministically).
        # NOTE(review): all workers share one tokenizer; HF fast tokenizers have
        # been reported to raise "Already borrowed" under concurrent calls that
        # set padding/truncation — if that appears, use max_workers=1.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            results = list(executor.map(encode, texts, labels))

        result_dataset[split] = pd.DataFrame({
            'paragraph_input_ids': [r[0] for r in results],
            'paragraph_attention_masks': [r[1] for r in results],
            'paragraph_counts': [r[2] for r in results],
            'labels': [r[3] for r in results],
        })

    return result_dataset

# Example usage:
# preprocessed_data = preprocess_hierarchical_legalbert(dataset, max_workers=4)

In [4]:
# Tokenize every split using 4 worker threads; the result is {split: DataFrame}.
preprocessed_data = preprocess_hierarchical_legalbert(dataset=original_dataset, max_workers=4)

Processing train set for hierarchical encoding...
Processing test set for hierarchical encoding...
Processing validation set for hierarchical encoding...


In [5]:
from datasets import DatasetDict, Dataset

# Convert the per-split results into a DatasetDict so they can be pushed to the Hub.
# Each split may be a pandas DataFrame (what preprocess_hierarchical_legalbert
# returns), a dict of column lists, or already a datasets.Dataset. DatasetDict is
# itself a dict, so re-running this cell leaves existing Datasets untouched.
def _to_hf_dataset(data):
    """Return ``data`` as a ``datasets.Dataset``."""
    if isinstance(data, Dataset):
        return data
    if isinstance(data, pd.DataFrame):
        return Dataset.from_pandas(data)
    # Assumed to be a mapping of column name -> list of values.
    return Dataset.from_dict(data)


if isinstance(preprocessed_data, dict):
    preprocessed_data = DatasetDict(
        {split: _to_hf_dataset(data) for split, data in preprocessed_data.items()}
    )


In [6]:
# NOTE(review): the repo id spells "Hirearchy" and mentions "TF_IDF" although this
# data holds hierarchical LegalBERT token ids — confirm the intended name before
# renaming, since existing consumers may already reference this id.
preprocessed_data.push_to_hub(repo_id="victorambrose11/TF_IDF_EMB_Hirearchy")


Creating parquet from Arrow format: 100%|██████████| 3/3 [00:00<00:00,  3.30ba/s]
Creating parquet from Arrow format: 100%|██████████| 3/3 [00:00<00:00,  3.35ba/s]
Uploading the dataset shards: 100%|██████████| 2/2 [00:09<00:00,  4.56s/it]
Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00,  3.56ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:03<00:00,  3.25s/it]
Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00,  4.16ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:05<00:00,  5.17s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/victorambrose11/TF_IDF_EMB_Hirearchy/commit/b2afbee1b8f5d9dc4d3d72f0e148c095904e84a4', commit_message='Upload dataset', commit_description='', oid='b2afbee1b8f5d9dc4d3d72f0e148c095904e84a4', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/victorambrose11/TF_IDF_EMB_Hirearchy', endpoint='https://huggingface.co', repo_type='dataset', repo_id='victorambrose11/TF_IDF_EMB_Hirearchy'), pr_revision=None, pr_num=None)