## LegalBERT model:

In [1]:
# For retrieving the clauses and labels.
import os
import json
# For the duplicates
from collections import defaultdict, Counter
# For BERT
import torch
import numpy as np
from datasets import load_dataset, Dataset
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, TrainerCallback
from sklearn.metrics import accuracy_score, classification_report
# For deduplication
import nltk
from nltk.util import ngrams
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
nltk.download('punkt')
# For metrics
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
# For plots 
import matplotlib.pyplot as plt
# For hyperparameter tuning
from itertools import product

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/benjaminward/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


### 1. Extracting the individual clauses and labels

#### a. Data extraction

In [2]:
data_dir = "data_all_202503120623106"  # Data directory containing company folders
clause_pairs = []

# Step 1: Check if data directory exists
if not os.path.exists(data_dir):
    print(f"❌ ERROR: Data directory '{data_dir}' does not exist.")
    exit()

# Step 2: Initialize a counter for the companies
company_counter = 0

# Step 3: Loop through all company folders inside the data directory
for company in os.listdir(data_dir):
    company_path = os.path.join(data_dir, company)

    # Check if it's a directory (company folder)
    if os.path.isdir(company_path):
        clause_file = os.path.join(company_path, "clauses.json")

        # Step 4: Check if clauses.json exists
        if not os.path.isfile(clause_file):
            print(f"❌ ERROR: 'clauses.json' not found in '{company}' folder")
            continue

        try:
            # Step 5: Check if clauses.json is valid JSON
            with open(clause_file, "r", encoding="utf-8") as f:
                data = json.load(f)

            # Step 6: Check if 'clauses' key exists
            if "clauses" not in data:
                print(f"⚠️ WARNING: No 'clauses' key found in '{company}/clauses.json'")
                continue

            clauses = data["clauses"]
            if not clauses:
                print(f"⚠️ WARNING: 'clauses' list is empty in '{company}/clauses.json'")
                continue

            # Step 7: Extract (clause_text, rating) pairs
            for clause in clauses:
                clause_text = clause.get("clause_text", "").strip() if clause.get("clause_text") else ""
                rating = clause.get("rating", "").strip() if clause.get("rating") else ""

                if clause_text and rating:
                    clause_pairs.append((clause_text, rating))
                else:
                    print(f"⚠️ WARNING: Skipping a clause in '{company}' due to missing clause_text or rating.")
        
        except json.JSONDecodeError:
            print(f"❌ ERROR: Invalid JSON in '{company}/clauses.json'")

        # Step 8: Stop after the first 2 companies
        company_counter += 1
        if company_counter >= 900:
            break  # Exit the loop after processing the first two companies

# Final results
print(f"\n✅ Extracted {len(clause_pairs)} clause-rating pairs from the first two companies.\n")
for pair in clause_pairs[:5]:  # Print first 5 for checking
    print(pair)


✅ Extracted 14407 clause-rating pairs from the first two companies.

('Instead of asking directly, this Service will assume your consent merely from your usage.', 'bad')
('This service tracks which web page referred you to it', 'bad')
('The service can sell or otherwise transfer your personal data as part of a bankruptcy proceeding or other type of financial transaction.', 'bad')
('You must provide your legal name, pseudonyms are not allowed', 'bad')
('This service employs third-party cookies, but with opt-out instructions', 'bad')


#### b. Duplicate analysis

In [3]:
clause_pairs = []  # Reset to include folder info

# Step 3: Loop through all company folders inside the data directory
for company in os.listdir(data_dir):
    company_path = os.path.join(data_dir, company)

    if os.path.isdir(company_path):
        clause_file = os.path.join(company_path, "clauses.json")

        if not os.path.isfile(clause_file):
            print(f"❌ ERROR: 'clauses.json' not found in '{company}' folder")
            continue

        try:
            with open(clause_file, "r", encoding="utf-8") as f:
                data = json.load(f)

            if "clauses" not in data:
                print(f"⚠️ WARNING: No 'clauses' key found in '{company}/clauses.json'")
                continue

            clauses = data["clauses"]
            if not clauses:
                print(f"⚠️ WARNING: 'clauses' list is empty in '{company}/clauses.json'")
                continue

            # Store (clause_text, rating, folder)
            for clause in clauses:
                clause_text = clause.get("clause_text", "").strip() if clause.get("clause_text") else ""
                rating = clause.get("rating", "").strip() if clause.get("rating") else ""

                if clause_text and rating:
                    clause_pairs.append((clause_text, rating, company))  # Store the folder name
                else:
                    print(f"⚠️ WARNING: Skipping a clause in '{company}' due to missing clause_text or rating.")

        except json.JSONDecodeError:
            print(f"❌ ERROR: Invalid JSON in '{company}/clauses.json'")




In [4]:
# Count occurrences of each (clause, rating) pair
pair_counts = Counter((clause_text, rating) for clause_text, rating, _ in clause_pairs)

# Get the top 5 most duplicated pairs
most_common_pairs = pair_counts.most_common(10)

# Map each (clause, rating) to its folder locations
pair_locations = defaultdict(set)
for clause_text, rating, folder in clause_pairs:
    pair_locations[(clause_text, rating)].add(folder)

# Print the results
print("\n🔍 Top 10 most duplicated clause-rating pairs:")
for (clause_text, rating), count in most_common_pairs:
    folders = list(pair_locations[(clause_text, rating)])  # Convert set to list
    print(f"Clause: {clause_text}\nRating: {rating}\nOccurrences: {count}\nFound in folders: {', '.join(folders)}")
    print("-" * 80)


🔍 Top 10 most duplicated clause-rating pairs:
Clause: There is a date of the last update of the agreements
Rating: neutral
Occurrences: 198
Found in folders: Notion, Translate, Apple Services, Audacity, Guilded, Linguee, RethinkDNS, How-To Geek, Yello, HideMyAss!, IVPN, OpenStreetMap, Dark Reader, Privacy.com, Однокла́ссники (Ok.ru), OsmAnd, LBRY, Free, EthanMcBloxxer, Getty Images, Credit Karma, Yahoo!, DrugBank, Symbaloo, Condé Nast, xda-developers, JojoYou (PriEco), Wise, Gfycat, Replika, Fedora Email, Represent, Orange, F-List, Toggl Track, BFM TV, Simitless, Vivaldi, FanFiction, Weblate, Medium, Mozilla Thunderbird, Douban, Encyclopedia Britannica, TubeBuddy, NBC News, FairTec, Nslookup, Avast, kik-messenger, ePlus Technology, Bilibili, VPN.AC, Instagram, Unity, Speedtest by Ookla, Booking.com, WikiTree, Brilliant, YNAB. (You Need a Budget), Free Code Camp, Nextcloud, Consumer News & Business Channel, The Filipino Channel, MuseScore, Wallapop, The Walt Disney Company, Leetify, ST

#### c. Processing code:

In [5]:
data_dir = "data_all_202503120623106"  # Data directory containing company folders
clause_pairs = []

# Step 1: Check if data directory exists
if not os.path.exists(data_dir):
    print(f"❌ ERROR: Data directory '{data_dir}' does not exist.")
    exit()

# Step 2: Initialize a counter for the companies
company_counter = 0

# Step 3: Loop through all company folders inside the data directory
for company in os.listdir(data_dir):
    company_path = os.path.join(data_dir, company)

    # Check if it's a directory (company folder)
    if os.path.isdir(company_path):
        clause_file = os.path.join(company_path, "clauses.json")

        # Step 4: Check if clauses.json exists
        if not os.path.isfile(clause_file):
            print(f"❌ ERROR: 'clauses.json' not found in '{company}' folder")
            continue

        try:
            # Step 5: Check if clauses.json is valid JSON
            with open(clause_file, "r", encoding="utf-8") as f:
                data = json.load(f)

            # Step 6: Check if 'clauses' key exists
            if "clauses" not in data:
                print(f"⚠️ WARNING: No 'clauses' key found in '{company}/clauses.json'")
                continue

            clauses = data["clauses"]
            if not clauses:
                print(f"⚠️ WARNING: 'clauses' list is empty in '{company}/clauses.json'")
                continue

            # Step 7: Extract (clause_text, rating) pairs
            for clause in clauses:
                clause_text = clause.get("clause_text", "").strip() if clause.get("clause_text") else ""
                rating = clause.get("rating", "").strip() if clause.get("rating") else ""

                if clause_text and rating:
                    clause_pairs.append((clause_text, rating))
                else:
                    print(f"⚠️ WARNING: Skipping a clause in '{company}' due to missing clause_text or rating.")
        
        except json.JSONDecodeError:
            print(f"❌ ERROR: Invalid JSON in '{company}/clauses.json'")

        # Step 8: Stop after the first 2 companies
        company_counter += 1
        if company_counter >= 900:
            break  # Exit the loop after processing the first two companies

# Final results
print(f"\n✅ Extracted {len(clause_pairs)} clause-rating pairs from the first two companies.\n")
for pair in clause_pairs[:5]:  # Print first 5 for checking
    print(pair)



✅ Extracted 14407 clause-rating pairs from the first two companies.

('Instead of asking directly, this Service will assume your consent merely from your usage.', 'bad')
('This service tracks which web page referred you to it', 'bad')
('The service can sell or otherwise transfer your personal data as part of a bankruptcy proceeding or other type of financial transaction.', 'bad')
('You must provide your legal name, pseudonyms are not allowed', 'bad')
('This service employs third-party cookies, but with opt-out instructions', 'bad')


In [6]:
# Step 1: Initialize a set to store unique clause-rating pairs, and non-unique clause-rating pais.
unique_clause_pairs = set()
non_unique_clause_pairs = set()
occurrences_of_non_unique_clause_pairs = 0

# Step 2: Filter out duplicates by checking if the pair already exists in the set
filtered_clause_pairs = []

for clause_text, rating in clause_pairs:
    # Use a tuple of (description, rating) as the set key
    pair = (clause_text, rating)
    
    # If the pair is not in the set, add it to the filtered list and the set
    if pair not in unique_clause_pairs:
        filtered_clause_pairs.append(pair)
        unique_clause_pairs.add(pair)
    else:
        non_unique_clause_pairs.add(pair)
        occurrences_of_non_unique_clause_pairs +=1

unique_clause_pairs = list(unique_clause_pairs)
non_unique_clause_pairs = list(non_unique_clause_pairs)

# Step 3: Check how many unique pairs there are
print(f"Before removal of duplicates: {len(clause_pairs)} clause-rating pairs.")
print(f"✅ Removed exact duplicates. {len(filtered_clause_pairs)} unique clause-rating pairs.")
print(f"Number of clauses which appear more than once in our dataset: {len(non_unique_clause_pairs)}.")
print(f"Number of clauses which we removed because there were already present once: {occurrences_of_non_unique_clause_pairs}")


Before removal of duplicates: 14407 clause-rating pairs.
✅ Removed exact duplicates. 1123 unique clause-rating pairs.
Number of clauses which appear more than once in our dataset: 447.
Number of clauses which we removed because there were already present once: 13284


In [7]:
# Example of accessing the descriptions and ratings
for pair in unique_clause_pairs[:5]:  # Print first 5 pairs for checking
    clause_text = pair[0]  # Clause description (x)
    rating = pair[1]       # Clause rating (y)
    print(f"clause_text: {clause_text}\nRating: {rating}\n")

Title: If you are the target of a copyright claim, your content may be removed
Rating: neutral

Title: When the service wants to change its terms, you are notified at least 30 days in advance
Rating: good

Title: The service will not allow third parties to access your personal information without a legal basis
Rating: good

Title: The cookies used by this service do not contain information that would personally identify you
Rating: good

Title: [EU] Information is provided about how they collect personal data
Rating: good



### 2. BERT model

#### a. Filtering: removing pairs with "unknown" as label.

In [8]:
# Step 3.1: Filter out clauses with 'unknown' ratings
filtered_clause_pairs = [(clause_text, rating) for clause_text, rating in unique_clause_pairs if rating != "unknown"]

# Split the filtered data
clauses, ratings = zip(*filtered_clause_pairs)  # Extract clauses and their ratings

# Map ratings to integers
rating_dict = {"blocker": 0, "bad": 1, "neutral": 2, "good": 3}  # Modify if you have different ratings
ratings_int = [rating_dict[r] for r in ratings]

# Step 3.2: Split data into train, dev, and test sets (80% train, 10% dev, 10% test)
X_temp, X_test, y_temp, y_test = train_test_split(clauses, ratings_int, test_size=0.1, random_state=42)

# Print size
print(f"Data set size: {len(filtered_clause_pairs)}")
print(f"Train+dev set size: {len(y_temp)}")
print(f"Test set size: {len(y_test)}")

Data set size: 1074
Train+dev set size: 859
Test set size: 215


#### b. Deduplication based on n-gram similarity

In [9]:
def get_ngrams(text, n=3):
    """Convert text into a set of n-grams."""
    tokens = nltk.word_tokenize(text.lower())  # Tokenize and lowercase
    tokens = [t for t in tokens if t.isalnum() and t not in ENGLISH_STOP_WORDS]  # Remove stopwords and non-alphanumeric
    return set(ngrams(tokens, n))

def jaccard_similarity(set1, set2):
    """Compute Jaccard similarity between two sets."""
    intersection = len(set1 & set2)
    union = len(set1 | set2)
    return intersection / union if union != 0 else 0

def deduplicate_test_set(train_clauses, test_clauses, test_labels, threshold=0.7, n=3):
    """Remove test clauses that are too similar to any train clause based on n-gram Jaccard similarity."""
    train_ngrams = [get_ngrams(clause, n) for clause in train_clauses]
    
    filtered_test_clauses = []
    filtered_test_labels = []
    
    for test_clause, test_label in zip(test_clauses, test_labels):
        test_ngram_set = get_ngrams(test_clause, n)
        
        # Check similarity with each train clause
        max_similarity = max(jaccard_similarity(test_ngram_set, train_set) for train_set in train_ngrams)
        
        if max_similarity < threshold:
            filtered_test_clauses.append(test_clause)
            filtered_test_labels.append(test_label)
    
    return filtered_test_clauses, filtered_test_labels

In [10]:
# Apply deduplication
print(f"Before deduplication: test+dev set size is {len(y_test)}")
X_test, y_test = deduplicate_test_set(X_temp, X_test, y_test, threshold=0.7, n=3)
print(f"After deduplication: test+dev set size is {len(y_test)}")

Before deduplication: test+dev set size is 215
After deduplication: test+dev set size is 151


#### c. Preparing for training LegalBERT

In [11]:
X_train, X_dev, y_train, y_dev = train_test_split(X_temp, y_temp, test_size=0.15, random_state=42)
print(f"Train set size: {len(y_train)}")
print(f"Dev set size: {len(y_dev)}")

# Step 3.3: Convert into a format that Hugging Face can use
train_data = Dataset.from_dict({"text": X_train, "label": y_train})
dev_data = Dataset.from_dict({"text": X_dev, "label": y_dev})
test_data = Dataset.from_dict({"text": X_test, "label": y_test})

Train set size: 730
Dev set size: 129


In [12]:
# Step 4.1: Load the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Step 4.2: Define a function to tokenize the input texts
def tokenize_function(examples):
    return tokenizer(examples.get('text', ""), padding='max_length', truncation=True, max_length=512)

# Step 4.3: Apply the tokenizer to the train, dev, and test datasets
train_data = train_data.map(tokenize_function, batched=True)
dev_data = dev_data.map(tokenize_function, batched=True)
test_data = test_data.map(tokenize_function, batched=True)

# Step 4.4: Set the format for PyTorch
train_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'token_type_ids', 'label'])
dev_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'token_type_ids', 'label'])
test_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'token_type_ids', 'label'])

# Step 4.5: Remove the original text filed
train_data = train_data.map(tokenize_function, batched=True, remove_columns=["text"])
dev_data = dev_data.map(tokenize_function, batched=True, remove_columns=["text"])
test_data = test_data.map(tokenize_function, batched=True, remove_columns=["text"])

# Step 4.6: Make sure we are working with longs
train_data = train_data.map(lambda x: {"label": torch.tensor(x["label"]).long()})
dev_data = dev_data.map(lambda x: {"label": torch.tensor(x["label"]).long()})
test_data = test_data.map(lambda x: {"label": torch.tensor(x["label"]).long()})

Map:   0%|          | 0/730 [00:00<?, ? examples/s]

Map:   0%|          | 0/129 [00:00<?, ? examples/s]

Map:   0%|          | 0/151 [00:00<?, ? examples/s]

Map:   0%|          | 0/730 [00:00<?, ? examples/s]

Map:   0%|          | 0/129 [00:00<?, ? examples/s]

Map:   0%|          | 0/151 [00:00<?, ? examples/s]

Map:   0%|          | 0/730 [00:00<?, ? examples/s]

  train_data = train_data.map(lambda x: {"label": torch.tensor(x["label"]).long()})


Map:   0%|          | 0/129 [00:00<?, ? examples/s]

  dev_data = dev_data.map(lambda x: {"label": torch.tensor(x["label"]).long()})


Map:   0%|          | 0/151 [00:00<?, ? examples/s]

  test_data = test_data.map(lambda x: {"label": torch.tensor(x["label"]).long()})


In [13]:
# Step 5.1: Take a smaller sample (e.g., 5%) of the training data
train_sample = train_data.shuffle(seed=42).select(range(int(0.01 * len(train_data))))
dev_sample = dev_data.shuffle(seed=42).select(range(int(0.05*len(dev_data))))

In [14]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)  # Convert logits to predicted labels

    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average='weighted')
    recall = recall_score(labels, preds, average='weighted')
    f1 = f1_score(labels, preds, average='weighted')

    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1
    }

#### d. Hyperparameter tuning for Legal-BERT

In [15]:
# Step 1: Define the hyperparameter grid
param_grid = {
    "batch_size": [2],  # Try different batch sizes - use [4,8,16,32]
    "learning_rate": [3e-5, 5e-5],  # Common BERT learning rates - use [2e-5,3e-5,5e-5]
    "num_epochs": [3],  # Vary number of epochs - use [4,8,12,16]
    "dropout_rate": [0.1, 0.2]  # Try different dropout rates - use [0.1, 0.2]
}

# Step 2: Track best model
best_f1 = 0
best_params = None
best_model = None

# Step 3: Iterate over all hyperparameter combinations
for batch_size, lr, epochs, dropout in product(param_grid["batch_size"], 
                                               param_grid["learning_rate"], 
                                               param_grid["num_epochs"], 
                                               param_grid["dropout_rate"]):
    
    print(f"\nTraining with batch_size={batch_size}, lr={lr}, epochs={epochs}, dropout={dropout}\n")

    # Step 4: Modify model with dropout
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4)
    model.config.hidden_dropout_prob = dropout
    model.config.attention_probs_dropout_prob = dropout
    
    # Step 5: Define training arguments
    training_args = TrainingArguments(
        output_dir='./results',
        eval_strategy="epoch",
        save_strategy="no",  # Don't save all models to save space
        learning_rate=lr,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        report_to="none"
    )

    # Step 6: Create Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_sample,
        eval_dataset=dev_sample,
        compute_metrics=compute_metrics,
    )

    # Step 7: Train and evaluate
    trainer.train()
    metrics = trainer.evaluate()

    # Step 8: Get weighted F1-score
    f1 = metrics.get("eval_f1", 0)

    # Step 9: Track best model
    if f1 > best_f1:
        best_f1 = f1
        best_params = (batch_size, lr, epochs, dropout)
        best_model = model  # Save the best model in memory

print("\nBest Hyperparameters:", best_params)
print("Best Weighted F1 Score:", best_f1)


Training with batch_size=2, lr=3e-05, epochs=3, dropout=0.1



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.3767,1.325261,0.25,0.416667,0.25,0.3125
2,1.0957,1.114964,0.583333,0.791667,0.583333,0.537037



Training with batch_size=2, lr=3e-05, epochs=3, dropout=0.2



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.3198,1.123485,0.5,0.318182,0.5,0.388889
2,1.1035,0.990176,0.75,0.763889,0.75,0.751748



Training with batch_size=2, lr=5e-05, epochs=3, dropout=0.1



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.3126,1.020145,0.583333,0.791667,0.583333,0.537037
2,1.1243,0.919543,0.666667,0.814815,0.666667,0.647619



Training with batch_size=2, lr=5e-05, epochs=3, dropout=0.2



Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss



Best Hyperparameters: (2, 3e-05, 3, 0.2)
Best Weighted F1 Score: 0.7447552447552447


In [None]:
# Step 10: Evaluate the best model on the test set
if best_model is not None:
    print("\nEvaluating the best model on the test set...")

    trainer = Trainer(
        model=best_model,
        args=training_args,
        eval_dataset=test_data,  # Use the test dataset
        compute_metrics=compute_metrics,
    )

    test_metrics = trainer.evaluate()
    print("\nTest Set Metrics:", test_metrics)
else:
    print("No best model found. Check hyperparameter tuning process.")

In [None]:
# Evaluate on the train set
train_results = trainer.evaluate(train_data)
# Print loss and accuracy
print(f"Train Loss: {train_results['eval_loss']:.4f}")
print(f"Train Accuracy: {train_results['eval_accuracy']:.4f}")  # Accuracy from compute_metrics
# Evaluate on the dev set
dev_results = trainer.evaluate(dev_data)
# Print loss and accuracy
print(f"Dev Loss: {dev_results['eval_loss']:.4f}")
print(f"Dev Accuracy: {dev_results['eval_accuracy']:.4f}")  # Accuracy from compute_metrics

In [None]:
# Evaluate on the test set
test_results = trainer.evaluate(test_data)

# Print loss and accuracy
print(f"Test Loss: {test_results['eval_loss']:.4f}")
print(f"Test Accuracy: {test_results['eval_accuracy']:.4f}")  # Accuracy from compute_metrics