# Variant 1: TFIDF-SRT-LegalBERT
#### Overview:
• The input document is first tokenized using the LegalBERT tokenizer.
• Duplicate tokens are removed, keeping only the first occurrence of each token.
• The remaining tokens are sorted in descending order of their TF-IDF weight, precomputed on a training corpus (in the code below, this is the IDF value learned by a TfidfVectorizer).
• The resulting ordered token string is re-tokenized (if needed) and fed into LegalBERT for classification.
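The deduplication and sorting steps can be illustrated with a small pure-Python sketch. The token list and IDF values below are made up for illustration (they are not taken from the SCOTUS corpus), and dedupe_and_sort is a hypothetical helper that mirrors the corresponding logic inside preprocess_document_bow.

```python
def dedupe_and_sort(tokens, idf):
    """Keep the first occurrence of each token, then sort by IDF, highest first.

    Tokens missing from `idf` get a score of 0, mirroring idf_dict.get(token, 0).
    Python's sort is stable (also with reverse=True), so tokens with equal
    scores keep their first-occurrence order.
    """
    unique_tokens = list(dict.fromkeys(tokens))
    return sorted(unique_tokens, key=lambda t: idf.get(t, 0.0), reverse=True)


# Hypothetical sub-word tokens and IDF values, for illustration only.
sample_tokens = ["the", "court", "held", "certiorari", "the", "petition", "##er", "court"]
sample_idf = {"the": 1.0, "court": 1.5, "held": 2.2, "petition": 2.9, "##er": 1.8}

print(dedupe_and_sort(sample_tokens, sample_idf))
# → ['petition', 'held', '##er', 'court', 'the', 'certiorari']
```

Because the sort ignores position, a continuation piece such as "##er" is separated from the word it belonged to, and unseen tokens ("certiorari" here) fall to the end; the model therefore sees a bag of high-IDF pieces rather than running text.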


## Explanation of Variant 1:

• The TF-IDF vectorizer, fitted with the LegalBERT tokenizer, learns a vocabulary of sub-word tokens and their inverse document frequency (IDF) values, which are then collected into a dictionary (idf_dict).
• The preprocess_document_bow function deduplicates the tokens of each document and sorts them in descending order of their IDF score.
• The resulting ordered token string is then tokenized again to produce input IDs suitable for LegalBERT, and these inputs are fed into the model for classification.
• This variant does not modify the internal architecture of LegalBERT; it only changes the input text.
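The IDF values in idf_dict come from scikit-learn's default smoothed formula, idf(t) = ln((1 + n) / (1 + df(t))) + 1, where n is the number of training documents and df(t) the number of documents containing token t. The sketch below recomputes it in plain Python on two hypothetical, already-tokenized documents; it assumes the vectorizer's defaults (smooth_idf=True, no min_df or max_df filtering), under which the values should agree with tfidf_vectorizer.idf_ for the same tokenization.

```python
import math


def smoothed_idf(tokenized_docs):
    """Smoothed IDF as in scikit-learn's TfidfVectorizer default (smooth_idf=True).

    idf(t) = ln((1 + n) / (1 + df(t))) + 1
    """
    n = len(tokenized_docs)
    df = {}
    for tokens in tokenized_docs:
        # Document frequency counts each token at most once per document.
        for token in set(tokens):
            df[token] = df.get(token, 0) + 1
    return {token: math.log((1 + n) / (1 + count)) + 1 for token, count in df.items()}


# Two hypothetical pre-tokenized documents.
two_docs = [
    ["legal", "document", "one"],
    ["another", "legal", "document", "for", "classification"],
]
idf_values = smoothed_idf(two_docs)

print(round(idf_values["legal"], 3))  # → 1.0   (appears in both documents)
print(round(idf_values["one"], 3))    # → 1.405 (appears in one document)
```

A token that occurs in every training document receives the minimum value of 1.0, so frequent sub-words sink to the end of the sorted sequence while rarer, more distinctive ones move to the front.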

In [3]:
from datasets import load_dataset

dataset = load_dataset("victorambrose11/normalized_scotus")
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 5000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1400
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 1400
    })
})

In [4]:
# Document lengths in characters (len() of each raw text string).
lengths = [len(text) for text in dataset['train']['text']]
average_length = round(sum(lengths) / len(lengths))
longest = max(lengths)
print(f"The average length of documents in the training dataset is {average_length} characters\n"
      f"The longest document in the training dataset contains {longest} characters")

The average length of documents in the training dataset is 37956 characters
The longest document in the training dataset contains 584365 characters
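The figures above are character counts, whereas LegalBERT's 512-position input limit (including the [CLS] and [SEP] tokens) is measured in sub-word tokens. A sketch of the same summary in tokenizer units follows; length_stats is a hypothetical helper, and tokenize can be any callable that returns a list of tokens.

```python
from statistics import mean, median


def length_stats(texts, tokenize, limit=512):
    """Summarise document lengths in the units produced by `tokenize`.

    `tokenize` is any callable returning a list of tokens (for example
    tokenizer.tokenize for LegalBERT, or str.split for whitespace words).
    `limit` is the model's maximum input length in those units.
    """
    lengths = [len(tokenize(text)) for text in texts]
    return {
        "mean": mean(lengths),
        "median": median(lengths),
        "max": max(lengths),
        # Fraction of documents longer than `limit`, i.e. that would be truncated.
        "frac_over_limit": sum(n > limit for n in lengths) / len(lengths),
    }
```

Once the tokenizer from the next cell is loaded, length_stats(dataset['train']['text'], tokenizer.tokenize) would report lengths in LegalBERT sub-word tokens; tokenizing the full training split of long opinions may take some time.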


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.feature_extraction.text import TfidfVectorizer

# LegalBERT checkpoint used in this notebook
model_name = "nlpaueb/legal-bert-base-uncased"

# Load tokenizer and model. The classification head is newly initialised;
# pass num_labels=<number of classes> to match the dataset's label set.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Example training documents (in practice, you would use your full training corpus)
documents = [
    "Legal document text example number one.",
    "This is another long legal document for classification."
]

# ----- Preprocessing Step: Compute IDF Scores -----
# Fit scikit-learn's TfidfVectorizer on the raw documents, using the LegalBERT
# tokenizer so that the vocabulary consists of the sub-word tokens the model sees.
# token_pattern=None avoids the warning that the default pattern is ignored
# when a custom tokenizer is supplied.
tfidf_vectorizer = TfidfVectorizer(tokenizer=tokenizer.tokenize, lowercase=True, token_pattern=None)
tfidf_matrix = tfidf_vectorizer.fit_transform(documents)
feature_names = tfidf_vectorizer.get_feature_names_out()

# Map each sub-word token to its IDF value (the corpus-level part of TF-IDF).
idf_dict = dict(zip(feature_names, tfidf_vectorizer.idf_))

# ----- Preprocessing Function for TFIDF-SRT-LegalBERT -----
def preprocess_document_bow(doc, tokenizer, idf_dict, max_length=512):
    # Tokenize the document with the LegalBERT tokenizer.
    tokens = tokenizer.tokenize(doc)

    # Deduplicate tokens, keeping only the first occurrence (dicts preserve insertion order).
    unique_tokens = list(dict.fromkeys(tokens))

    # Look up each token's IDF score; tokens unseen when fitting the vectorizer get 0.
    token_scores = [(token, idf_dict.get(token, 0)) for token in unique_tokens]

    # Sort by score in descending order (stable, so ties keep first-occurrence order).
    token_scores_sorted = sorted(token_scores, key=lambda x: x[1], reverse=True)

    # Keep only the highest-scoring tokens.
    ordered_tokens = [token for token, _ in token_scores_sorted][:max_length]

    # Join tokens into a single string for re-tokenization by LegalBERT
    # (alternatively, convert the tokens directly to input IDs). Re-tokenizing the
    # joined string splits "##" continuation pieces, so the tokenizer call that
    # follows must truncate again.
    processed_text = " ".join(ordered_tokens)
    return processed_text

# Example: preprocess the documents with the Variant 1 strategy.
processed_docs_variant1 = [preprocess_document_bow(doc, tokenizer, idf_dict) for doc in documents]

# Tokenize the preprocessed documents to obtain model inputs.
inputs_variant1 = tokenizer(processed_docs_variant1, padding=True, truncation=True,
                            max_length=512, return_tensors="pt")

# Inference only: disable dropout and gradient tracking.
model.eval()
with torch.no_grad():
    outputs_variant1 = model(**inputs_variant1)

# loss is None here because no labels were passed; add labels=... to the model call to get it.
loss_variant1 = outputs_variant1.loss
logits_variant1 = outputs_variant1.logits

print("Variant 1 logits:", logits_variant1)
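Joining WordPiece tokens with spaces and re-tokenizing them does not round-trip: the BERT tokenizer treats "#" as punctuation, so a continuation piece such as "##er" comes back as separate "#" tokens followed by the remainder. One way to skip the string round-trip is to convert the ordered tokens directly to input IDs. The sketch below assumes a Hugging Face BERT-style tokenizer; tokens_to_model_inputs is a hypothetical helper that expects the ordered token list (ordered_tokens inside preprocess_document_bow) rather than the joined string.

```python
def tokens_to_model_inputs(ordered_tokens, tokenizer, max_length=512):
    """Convert an ordered list of WordPiece tokens straight into model inputs.

    Assumes `tokenizer` is a Hugging Face BERT-style tokenizer providing
    num_special_tokens_to_add, convert_tokens_to_ids and
    build_inputs_with_special_tokens.
    """
    # Reserve room for the special tokens ([CLS] and [SEP] for BERT).
    num_special = tokenizer.num_special_tokens_to_add(pair=False)
    kept_tokens = ordered_tokens[: max_length - num_special]

    # Map tokens to vocabulary IDs without re-tokenizing, so "##" pieces stay intact.
    ids = tokenizer.convert_tokens_to_ids(kept_tokens)

    # Add [CLS] ... [SEP] around the sequence.
    input_ids = tokenizer.build_inputs_with_special_tokens(ids)
    return {"input_ids": input_ids, "attention_mask": [1] * len(input_ids)}
```

A list of such dictionaries can then be padded into a batch with tokenizer.pad(features, return_tensors="pt") before being passed to the model.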
