TF_IDF_SRT

- Uses the LegalBERT tokenizer
- Removes duplicate subword tokens
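- Computes IDF weights with a TF-IDF vectorizer fitted on the training split only
- Sorts each document's unique tokens by descending IDF (term frequency is not used)
- Truncates to 510 tokens so that, with [CLS] and [SEP] added, each sequence is 512 long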
- Returns padded input_ids and attention_mask (binary)


In [None]:
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import Counter
from transformers import AutoTokenizer
from datasets import DatasetDict
import pandas as pd

def _idf_sort_truncate(tokens, idf_dict, budget):
    """Deduplicate `tokens` (first occurrence wins), order by descending IDF, keep `budget` tokens."""
    # dict.fromkeys preserves insertion order, so this keeps the first occurrence of each token.
    unique_tokens = list(dict.fromkeys(tokens))
    # The sort is stable: tokens with equal IDF (including unseen tokens, which score 0)
    # stay in document order.
    unique_tokens.sort(key=lambda t: idf_dict.get(t, 0), reverse=True)
    return unique_tokens[:budget]


def tfidf_srt_preprocess(dataset: DatasetDict, tokenizer_name="nlpaueb/legal-bert-base-uncased", max_length=512):
    """Build fixed-length model inputs by keeping each document's highest-IDF unique tokens.

    For every split, each text is tokenized, duplicate tokens are dropped, the remaining
    tokens are sorted by descending IDF (fitted on the ``'train'`` split only), truncated to
    ``max_length - 2`` tokens, wrapped in the tokenizer's CLS/SEP tokens, converted to ids and
    padded with the pad token id up to ``max_length``.

    Args:
        dataset: Mapping of split name to a dataset exposing a ``'text'`` column and,
            optionally, a ``'label'`` feature. Must contain a ``'train'`` split.
        tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
        max_length: Total length of every output sequence, including CLS and SEP.

    Returns:
        dict mapping each split name to a ``pandas.DataFrame`` with the columns
        ``input_ids``, ``attention_mask`` (1 for real tokens, 0 for padding) and ``label``
        (``None`` for every row when the split has no ``'label'`` feature).

    Raises:
        ValueError: If ``max_length`` is smaller than 2, i.e. there is no room for CLS and SEP.
    """
    if max_length < 2:
        # A negative slice bound below would silently drop tokens from the wrong end and
        # produce rows that are not `max_length` long.
        raise ValueError(f"max_length must be at least 2 to hold [CLS] and [SEP], got {max_length}")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Step 1: Tokenize train texts for the TF-IDF vocabulary.
    # NOTE(review): tokenizing long documents may log the tokenizer's "sequence length is
    # longer than the specified maximum" warning; truncation happens later in this function.
    print("Tokenizing and preparing training corpus for TF-IDF...")
    train_token_lists = [tokenizer.tokenize(text) for text in dataset['train']['text']]

    # Step 2: Fit the vectorizer on the space-joined tokens of the training split.
    # lowercase=False keeps the fitted vocabulary byte-identical to the tokenizer's tokens;
    # with the default lowercasing, tokens such as "[UNK]" (or any cased token) would be
    # stored in lowercase and never match the lookup below.
    print("Fitting TF-IDF vectorizer...")
    tfidf_vectorizer = TfidfVectorizer(analyzer='word', token_pattern=r'\S+', lowercase=False)
    tfidf_vectorizer.fit([" ".join(tokens) for tokens in train_token_lists])
    idf_dict = dict(zip(tfidf_vectorizer.get_feature_names_out(), tfidf_vectorizer.idf_))

    # Step 3: Preprocess each split using TFIDF-SRT
    token_budget = max_length - 2  # leave room for CLS and SEP
    processed_data = {}
    for split in dataset.keys():
        print(f"Processing split: {split}")
        input_ids_list = []
        attention_mask_list = []
        labels = dataset[split]['label'] if 'label' in dataset[split].features else [None] * len(dataset[split])

        # Reuse the tokenization already done for the training split instead of repeating it.
        if split == 'train':
            token_lists = train_token_lists
        else:
            token_lists = (tokenizer.tokenize(text) for text in dataset[split]['text'])

        for tokens in token_lists:
            # Only IDF is used for scoring (TF not needed as per paper).
            kept_tokens = _idf_sort_truncate(tokens, idf_dict, token_budget)

            # Add special tokens and convert to ids
            tokens_final = [tokenizer.cls_token] + kept_tokens + [tokenizer.sep_token]
            input_ids = tokenizer.convert_tokens_to_ids(tokens_final)

            # Pad up to max_length; the mask marks real tokens with 1 and padding with 0.
            real_length = len(input_ids)
            padding_length = max_length - real_length
            input_ids_list.append(input_ids + [tokenizer.pad_token_id] * padding_length)
            attention_mask_list.append([1] * real_length + [0] * padding_length)

        processed_data[split] = pd.DataFrame({
            "input_ids": input_ids_list,
            "attention_mask": attention_mask_list,
            "label": labels
        })

    return processed_data

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets import load_dataset
# Load the "scotus" configuration of the "coastalcph/lex_glue" dataset.
dataset = load_dataset("coastalcph/lex_glue", "scotus")
# Run TF-IDF-SRT preprocessing on every split; the result is a dict mapping
# split name -> pandas DataFrame with input_ids, attention_mask and label columns.
tfidf_srt_processed = tfidf_srt_preprocess(dataset)

Tokenizing and preparing training corpus for TF-IDF...


Token indices sequence length is longer than the specified maximum sequence length for this model (4330 > 512). Running this sequence through the model will result in indexing errors


Fitting TF-IDF vectorizer...
Processing split: train
Processing split: test
Processing split: validation


In [10]:
from datasets import DatasetDict, Dataset

# Wrap the per-split results in a DatasetDict.
# Splits that are already `Dataset` objects are passed through untouched; every other
# value is handed to `Dataset.from_pandas`, so it is expected to be a pandas DataFrame.
if isinstance(tfidf_srt_processed, dict):
    tfidf_srt_processed = DatasetDict({
        split_name: split_data if isinstance(split_data, Dataset) else Dataset.from_pandas(split_data)
        for split_name, split_data in tfidf_srt_processed.items()
    })


In [None]:
# Upload the processed splits to the Hugging Face Hub dataset repository named below.
# NOTE(review): presumably requires prior Hub authentication with write access — confirm.
tfidf_srt_processed.push_to_hub("victorambrose11/tfidf_srt_processed")


Creating parquet from Arrow format: 100%|██████████| 5/5 [00:00<00:00, 52.93ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:03<00:00,  3.66s/it]
Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 74.67ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:02<00:00,  2.50s/it]
Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 79.88ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:02<00:00,  2.61s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/victorambrose11/tfidf_srt_processed/commit/8e823dbb918ac5eddbe421047fc2e68ac6038edc', commit_message='Upload dataset', commit_description='', oid='8e823dbb918ac5eddbe421047fc2e68ac6038edc', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/victorambrose11/tfidf_srt_processed', endpoint='https://huggingface.co', repo_type='dataset', repo_id='victorambrose11/tfidf_srt_processed'), pr_revision=None, pr_num=None)