**Author: Sandhya Nayar**

This notebook develops a data pipeline designed to process Terms of Service (ToS) documents and generate user-friendly explanations for clauses identified as potentially problematic or "shady."

The pipeline begins by reading and processing the document, followed by tokenization. The resulting tokens are then passed through a pre-trained LegalBERT model tailored for sequence classification. Using retrieval-augmented generation (RAG), LegalBERT analyzes each clause of the document along with its contextual information, classifying each clause as either "shady" or "not shady."

Attention weights extracted from LegalBERT’s output are used to identify the key terms that contributed to the model’s classification decision.

In cases where a clause is classified as "shady," the clause, its classification, and the key terms are passed to a 1-billion-parameter LLaMA model. This model then generates a concise, user-friendly explanation detailing the reasons for the shady classification.

In [110]:
# Open the example file
with open("/Discord.txt", "r") as file:
    tos_text = file.read()

print(tos_text)

Login

Discord Terms of Service
Last modified: May 7, 2020.
IMPORTANT NOTICE: THESE TERMS OF SERVICE CONTAIN A BINDING ARBITRATION PROVISION AND CLASS ACTION WAIVER. IT
AFFECTS YOUR LEGAL RIGHTS AS DETAILED IN THE DISPUTE RESOLUTION AND CLASS ACTION WAIVER SECTION BELOW. PLEASE
READ CAREFULLY.

INTRODUCTION AND ACCEPTING THE TERMS
Welcome to Discord! These Terms of Service (“Terms”), which include and hereby incorporate the Privacy Policy at https://discord.com/privacy
(“Privacy Policy”), are a legal agreement between Discord Inc. and its related companies (the “Company,” “us,” “our,” or "we") and you ("you" or
“your”). By using or accessing the Discord application (the “App”) or the website located at https://discord.com (the "Site"), which are
collectively referred to as the “Service,” you agree (i) that you are 13 years of age and the minimum age of digital consent in your country, (ii) if
you are the age of majority in your jurisdiction or over, that you have read, understood, and 

In [80]:
# Pre Processing

import nltk

# Download punkt
nltk.download('punkt')
nltk.download('punkt_tab')

from nltk.tokenize import sent_tokenize

def segment_document(text):
    return sent_tokenize(text)

segments = segment_document(tos_text)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [81]:
# Clean text

import re

def clean_text(text):
    text = text.lower()  # Converting to lowercase
    text = re.sub(r'\s+', ' ', text)  # Remove extra whitespace
    text = re.sub(r'[^\w\s.,;\'\"()-]', '', text)  # Keep legal characters
    return text

cleaned_segments = [clean_text(segment) for segment in segments]
print("Example Segments: ", cleaned_segments[:5])

Example Segments:  ['login discord terms of service last modified may 7, 2020.', 'important notice these terms of service contain a binding arbitration provision and class action waiver.', 'it affects your legal rights as detailed in the dispute resolution and class action waiver section below.', 'please read carefully.', 'introduction and accepting the terms welcome to discord']


In [111]:
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
!pip install faiss-gpu
import faiss
import torch

# Initialize tokenizer and models
bert_tokenizer = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")
classification_model = AutoModelForSequenceClassification.from_pretrained("nlpaueb/legal-bert-base-uncased")
embedding_model = AutoModel.from_pretrained("nlpaueb/legal-bert-base-uncased")

# Create FAISS index for Retrieval Augmented Generation
# Create and embed document segments
dimension = embedding_model.config.hidden_size
index = faiss.IndexFlatL2(dimension)

# Function to generate embeddings using the embedding model
def embed_texts(texts):
    embeddings = []
    for text in texts:
        inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            outputs = embedding_model(**inputs).last_hidden_state.mean(dim=1)
            embeddings.append(outputs)
    return torch.cat(embeddings, dim=0)

document_embeddings = embed_texts(cleaned_segments)
index.add(document_embeddings.numpy())




Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [83]:
def top_attention_weights(attentions, clause_token_indices):
    print(clause_token_indices)
    layer_idx = len(attentions) - 1  # Last layer index
    attention_weights_last_layer = attentions[layer_idx]

    # Average across all attention heads (dim=1 corresponds to the num_heads axis)
    average_attention = attention_weights_last_layer[0].mean(dim=0)

    # Extract the attention for clause tokens
    clause_attention = average_attention[clause_token_indices, :]

    clause_attention_sum = clause_attention.sum(dim=1)  # Sum across the columns for each clause token to create an average

    # Get the indices of the top 5 attention scores
    top_5_indices = torch.argsort(clause_attention_sum, descending=True)[:5]

    # Retrieve the tokens corresponding to the top 5 indices
    top_5_tokens = [clause_token_indices[i] for i in top_5_indices]

    print(top_5_tokens)

    return top_5_tokens

In [84]:
import numpy as np

# Classify clauses
def classify_clause(segment, similar_clauses):
    input_text = f"Context: {' '.join(similar_clauses)}\nClause: {segment}\nClassify if this clause is shady."
    inputs = bert_tokenizer(input_text, return_tensors="pt", truncation=True, padding=True, return_offsets_mapping=True)

    # Get token indexes
    offsets = inputs.pop("offset_mapping")[0]

    # Identify the segments based on offsets
    clause_start = input_text.find(f"Clause: {segment}")  # Adjust start to the actual clause
    clause_end_index = input_text.find("Classify if this clause is shady.")  # Exclude prompt

    # clause_start_index is just after the "Clause: "
    clause_only_start_index = clause_start + len(f"Clause: ")
    clause_length = len(f"Clause: {segment}")
    print(clause_length)

    # Identify indices of Clause tokens that fall within the "Clause: {segment}" part
    clause_token_indices = [
        i for i, (start, end) in enumerate(offsets)
        if start >= clause_only_start_index and end <= clause_end_index
    ]


    # Calculate output
    outputs = classification_model(**inputs, output_attentions=True)
    prediction = torch.argmax(outputs.logits, dim=1).item()
    classification_result = "Shady" if prediction == 1 else "Not Shady"

    # Attention Mechanisms
    attentions = outputs.attentions

    top_5_tokens = top_attention_weights(attentions, clause_token_indices)

    # Retrieve the corresponding token labels (words) from the token IDs
    token_labels = bert_tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    top_5_words = [token_labels[token_idx] for token_idx in top_5_tokens]

    return top_5_words, classification_result

In [85]:
# Smaller test for only n clauses
n = 2
stop = 0
results = []
for i, segment in enumerate(cleaned_segments):
    print(f"Analyzing Clause {i+1}: {segment}")
    # Retrieve top-k similar clauses for the current clause
    segment_embedding = embed_texts([segment]).numpy()
    _, indices = index.search(segment_embedding, k=3)  # Retrieve top 3 similar clauses
    similar_clauses = [cleaned_segments[idx] for idx in indices[0] if idx != i]

    # Classify the clause and get the key tokens
    top_5_words, classification = classify_clause(segment, similar_clauses)

    # Add the results to the output
    results.append({
        "clause": segment,
        "classification": classification,
        "top_terms": top_5_words  # Include the top 5 terms with their attention scores
    })
    stop = stop + 1
    if stop == n:
        break


Analyzing Clause 1: login discord terms of service last modified may 7, 2020.
65
Clause start: 307
Clause-only start index: 315
Clause end index: 373
[66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80]
[77, 78, 80, 66, 67]
Analyzing Clause 2: important notice these terms of service contain a binding arbitration provision and class action waiver.
112
Clause start: 294
Clause-only start index: 302
Clause end index: 407
[54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69]
[54, 55, 57, 59, 60]
Analyzing Clause 3: it affects your legal rights as detailed in the dispute resolution and class action waiver section below.
113
Clause start: 243
Clause-only start index: 251
Clause end index: 357
[47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]
[56, 65, 64, 62, 61]
Analyzing Clause 4: please read carefully.
30
Clause start: 158
Clause-only start index: 166
Clause end index: 189
[29, 30, 31, 32]
[32, 29, 30, 31]
Analyzing Clause 5: introduction and a

In [86]:
# Process and classify all segments
results = []
for i, segment in enumerate(cleaned_segments):
    # Retrieve top-k similar clauses for the current clause (RAG)
    segment_embedding = embed_texts([segment]).numpy()
    _, indices = index.search(segment_embedding, k=3)  # Retrieve top 3 similar clauses
    similar_clauses = [cleaned_segments[idx] for idx in indices[0] if idx != i]

    # Classify the clause and get the key tokens
    top_5_tokens, classification = classify_clause(segment, similar_clauses)

    # Add the results to the output
    results.append({
        "clause": segment,
        "classification": classification,
        "top_terms": top_5_tokens
    })


65
Clause start: 307
Clause-only start index: 315
Clause end index: 373
[66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80]
[77, 78, 80, 66, 67]
112
Clause start: 294
Clause-only start index: 302
Clause end index: 407
[54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69]
[54, 55, 57, 59, 60]
113
Clause start: 243
Clause-only start index: 251
Clause end index: 357
[47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]
[56, 65, 64, 62, 61]
30
Clause start: 158
Clause-only start index: 166
Clause end index: 189
[29, 30, 31, 32]
[32, 29, 30, 31]
63
Clause start: 222
Clause-only start index: 230
Clause end index: 286
[52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62]
[55, 52, 53, 54, 56]
266
Clause start: 901
Clause-only start index: 909
Clause end index: 1168
[215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 25

In [112]:
# Count the number of clauses classified as "Shady"
shady_count = sum(1 for result in results if result["classification"] == "Shady")

# Output the count
print(f"There were {shady_count} shady clauses out of {len(result)} total clauses")

247
Number of shady clauses: 100


In [1]:
# Check if GPU is available
import torch
print("GPU available:", torch.cuda.is_available())
print("GPU name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "NO GPU")

GPU available: True
GPU name: Tesla T4


In [2]:
# Login to access transformers
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
!pip install bitsandbytes>=0.39.0
!pip install --upgrade accelerate transformers datasets peft trl



In [88]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

In [89]:
model_id = "huihui-ai/Llama-3.2-1B-Instruct-abliterated"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [90]:
# Import tokenizers

lama_tokenizer = AutoTokenizer.from_pretrained(model_id)
llama_tokenizer.pad_token_id = llama_tokenizer.eos_token_id

In [91]:
# Configure BitsAndBytes
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [92]:
# Load model

llama_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True
)
llama_model.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear4bit(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear4bit(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear4bit(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear4bit(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), 

In [103]:
# Chat template for format of input from LegalBERT classification

def apply_chat_template(system_prompt, clause, top_terms, prompt_, max_new_tokens=150):

    user_message = (
        f"Clause: {clause}\n\n"
        f"Top terms: {', '.join(top_terms)}\n\n"
        f"Prompt: {prompt_}"
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message}
    ]

    inputs = llama_tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)
    outputs = llama_model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=llama_tokenizer.pad_token_id
    )
    return llama_tokenizer.decode(outputs[0], skip_special_tokens=True)


In [105]:
# Practice Run

system_prompt = "Your task is to evaluate legal clauses for potentially shady terms and explain why in less than three sentences, using provided top terms as guidance. Follow this format for your output: 'The following clause was considered shady because (insert explanation)'"
clause = "By uploading, distributing, transmitting or otherwise using Your Content with the Service, you grant to us a perpetual, nonexclusive, transferable, royalty-free, sublicensable, and worldwide license to use, host, reproduce, modify, adapt, publish, translate, create derivative works from, distribute, perform, and display Your Content in connection with operating and providing the Service."
top_terms = ["grant", "perpetual", "nonexclusive", "adapt"]
prompt_ = "Generate a user explanation for why this clause might be shady using the top terms to identify why it was shady."
output = apply_chat_template(system_prompt, clause, top_terms, prompt_)
print(output)


system

Cutting Knowledge Date: December 2023
Today Date: 11 Jan 2025

Your task is to evaluate legal clauses for potentially shady terms and explain why in less than three sentences, using provided top terms as guidance. Follow this format for your output: 'The following clause was considered shady because (insert explanation)'user

Clause: By uploading, distributing, transmitting or otherwise using Your Content with the Service, you grant to us a perpetual, nonexclusive, transferable, royalty-free, sublicensable, and worldwide license to use, host, reproduce, modify, adapt, publish, translate, create derivative works from, distribute, perform, and display Your Content in connection with operating and providing the Service.

Top terms: grant, perpetual, nonexclusive, adapt

Prompt: Generate a user explanation for why this clause might be shady using the top terms to identify why it was shady.assistant

The following clause was considered shady because it grants a "royalty-free" licens

In [106]:
# System prompt
system_prompt = "Your task is to evaluate legal clauses for potentially shady terms and explain why in less than four sentences, using provided top terms as guidance. Follow this format for your output: 'The following clause was considered shady because (insert explanation)'"
chat_template = "{% for message in messages %}{{message.role}}: {{message.content}}\n{% endfor %}"


# Loop through results and feed only shady clauses into Llama
final_outputs = []  # To store the responses
for result in results:
    clause = result["clause"]
    classification = result.get("classification", "Unclassified")
    top_terms = result.get("top_terms", None)

    # Format the top terms
    if isinstance(top_terms, (list, np.ndarray)) and len(top_terms) > 0:
        top_terms_str = ", ".join(top_terms)
    else:
        top_terms_str = "None"

    if classification.lower() == "shady":
        # Construct user prompt
        user_prompt = (
            f"Clause: {clause}\n\n"
            f"Classification: {classification}\n\n"
            f"Top Terms: {top_terms_str}\n\n"
            f"Prompt: Generate a user explanation for why this clause might be shady using the top terms to identify why it was shady."
        )

        # Get Llama's response
        llama_output = apply_chat_template(system_prompt, clause, top_terms, user_prompt)
    else:
        llama_output = "N/A - Clause not classified as shady."

    final_outputs.append({
        "clause": clause,
        "classification": classification,
        "top_terms": top_terms,
        "explanation": llama_output
    })

# Print final outputs
for output in final_outputs:
    print(f"Clause: {output['clause']}")
    print(f"Classification: {output['classification']}")
    print(f"Top Terms: {output['top_terms']}")
    print(f"Explanation: {output['explanation']}")
    print("-" * 80)


Clause: login discord terms of service last modified may 7, 2020.
Classification: Shady
Top Terms: ['7', ',', '.', 'log', '##in']
Explanation: system

Cutting Knowledge Date: December 2023
Today Date: 11 Jan 2025

Your task is to evaluate legal clauses for potentially shady terms and explain why in less than four sentences, using provided top terms as guidance. Follow this format for your output: 'The following clause was considered shady because (insert explanation)'user

Clause: login discord terms of service last modified may 7, 2020.

Top terms: 7,,,., log, ##in

Prompt: Clause: login discord terms of service last modified may 7, 2020.

Classification: Shady

Top Terms: 7,,,., log, ##in

Prompt: Generate a user explanation for why this clause might be shady using the top terms to identify why it was shady.assistant

The following clause was considered shady because it:

* Requires users to agree to a specific date (May 7, 2020) that has already passed, which could be seen as an att