In [13]:
# Change the working directory to the sample-bills folder. The path is relative, so it
# resolves against whatever the current directory is when this cell runs (the recorded
# output shows it ending up under /content).
%cd LawGorithmML/citations/citations/sample_bills

/content/LawGorithmML/citations/citations/sample_bills


In [7]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
import torch
from datasets import Dataset
import os
import glob
import pandas as pd
import json
import re
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
import concurrent.futures

In [14]:
# Define the path to the full-text bills
bills_path = "/content/LawGorithmML/citations/citations/sample_bills"

# Upper bound on how many bill files are selected for loading
MAX_BILLS = 10

# Get the text files in the directory; sorting first makes the selected subset deterministic
bill_files = sorted(glob.glob(os.path.join(bills_path, "*.txt")))[:MAX_BILLS]

# Mapping of bill file name -> full text; starts empty and is filled once the files are read
bill_texts = {}
def load_file(file):
    """Read one bill and return a ``(file name, contents)`` pair.

    The name is the base name of ``file`` (directory part stripped) and the
    contents are the whole file decoded as UTF-8.
    """
    with open(file, mode="r", encoding="utf-8") as handle:
        contents = handle.read()
    name = os.path.basename(file)
    return name, contents

# Read the selected bill files concurrently on a thread pool
with concurrent.futures.ThreadPoolExecutor() as executor:
    results = executor.map(load_file, bill_files)
    # Every result is a (file name, text) pair, which maps straight into the dict
    bill_texts = {name: text for name, text in results}

print(f"Loaded {len(bill_texts)} bills.")

Loaded 10 bills.


In [None]:
# NOTE(review): this is an absolute Windows path, while the bills are read from a
# /content/... path — confirm which environment this notebook is meant to run in.
json_path = "C:/Users/ralph/Downloads/citations/citations/citations/"


def list_json_files(json_dir):
    """Return the ``*.json`` paths directly inside ``json_dir``, sorted.

    ``glob.glob`` returns matches in arbitrary order, so the result is sorted to
    keep the order of the combined citation rows the same from run to run.
    """
    return sorted(glob.glob(os.path.join(json_dir, "*.json")))


def load_citation_records(json_files):
    """Concatenate the citation records stored in the given JSON files.

    Args:
        json_files: Iterable of file paths. Each file must contain a JSON array.

    Returns:
        A list holding the elements of every array, in the order the files
        were given.

    Raises:
        ValueError: If a file's top-level JSON value is not a list. Extending
            with a dict or string would silently add keys/characters instead
            of records, so this is reported explicitly.
    """
    records = []
    for path in json_files:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise ValueError(
                f"Expected a JSON array in {path}, got {type(data).__name__}"
            )
        records.extend(data)  # Combine all JSON data
    return records


# Get all JSON files in the directory
json_files = list_json_files(json_path)

# Load citation data
citations = load_citation_records(json_files)

# Convert to DataFrame
df_citations = pd.DataFrame(citations)

# Show structure
df_citations.head()


Unnamed: 0,text,startPosition,endPosition,normCite,citeType,altCite,pinCiteStr,pageRangeStr,nodeId,section,sectionAndSubSection,isShortCite,chunk_id
0,PROVISIONS Sec. 201,2581,2600,,,,,,0,,,False,0.0
1,Sec. 202. Recycling Infrastructure and Accessi...,2643,2713,,,,,,0,,,False,0.0
2,Sec. 204. Reauthorization of Diesel Emissions ...,2766,2893,,,,,,0,,,False,0.0
3,TITLE IV—VETERANS Sec. 401. Protecting Regular...,3293,3434,,,,,,0,,,False,0.0
4,Sec. 111,4039,4047,,,,,,0,,,True,0.0


In [None]:
# Define context window size (adjust as needed)
CONTEXT_WINDOW = 200  # Number of characters before and after citation

# Function to find citation context in any bill
def extract_context_any_bill(citation_text, bills=None, window=None):
    """Return the text surrounding the first occurrence of a citation in any bill.

    The bills are scanned in mapping order and the search stops at the first
    bill that contains ``citation_text`` verbatim. Only the citation string is
    used for the lookup, so a short citation that appears in several places
    always resolves to its first occurrence.
    NOTE(review): the citation table also carries startPosition/endPosition
    columns that are not used here — confirm first-match lookup is intended.

    Args:
        citation_text: Citation string to look for. Non-string values (for
            example NaN or None from a DataFrame column) and empty strings
            yield ``None`` instead of raising or matching at position 0.
        bills: Optional mapping of bill name -> bill text. Defaults to the
            module-level ``bill_texts``.
        window: Optional number of characters to keep on each side of the
            citation. Defaults to ``CONTEXT_WINDOW``.

    Returns:
        The citation together with up to ``window`` characters before and
        after it, or ``None`` when no bill contains the citation.
    """
    if not isinstance(citation_text, str) or not citation_text:
        return None
    if bills is None:
        bills = bill_texts
    if window is None:
        window = CONTEXT_WINDOW

    for bill_text in bills.values():
        # Literal substring search; no regex needed for an exact match
        start = bill_text.find(citation_text)
        if start == -1:
            continue
        end = start + len(citation_text)
        # One slice covers before + citation + after; slicing clamps at the text end
        return bill_text[max(0, start - window): end + window]
    return None  # No match found

# Look up a context snippet for every citation string
citation_contexts = df_citations["text"].apply(extract_context_any_bill)
df_citations["context"] = citation_contexts

# Keep only the citations for which a context was found
found_context = df_citations["context"].notna()
df_citations = df_citations[found_context]

# Preview the citations together with their extracted contexts
df_citations.head()

Unnamed: 0,text,startPosition,endPosition,normCite,citeType,altCite,pinCiteStr,pageRangeStr,nodeId,section,sectionAndSubSection,isShortCite,chunk_id,context
4,Sec. 111,4039,4047,,,,,,0,,,True,0.0,State studies and HHS report on costs of provi...
5,Sec. 112,4145,4153,,,,,,0,,,True,0.0,proportionate share hospital allotments.\n ...
7,Sec. 602,4648,4656,,,,,,0,,,True,0.0,gnment within the Community Service Employment...
13,Sec. 604,6331,6339,33 usc 1384,USC,33 usc 1384,,,0,,,True,0.0,ERS\n Sec. 601. Older Americans ...
17,Sec. 1003,6813,6822,,,,,,0,,,True,0.0,"partments.\n <toc-entry align=""l..."


In [None]:
# Select the model input and the target, renamed to the names used from here on
df = df_citations[["context", "citeType"]].rename(
    columns={"context": "input_text", "citeType": "label"}
)

# Replace every label with its integer code
label_encoder = LabelEncoder()
encoded_labels = label_encoder.fit_transform(df["label"])
df["label"] = encoded_labels

# Hold out 20% of the rows for testing, with a fixed seed
split = train_test_split(
    df["input_text"], df["label"], test_size=0.2, random_state=42
)
train_texts, test_texts, train_labels, test_labels = split

# Re-assemble each split as a DataFrame with "text" and "label" columns
train_df = pd.DataFrame({"text": train_texts, "label": train_labels})
test_df = pd.DataFrame({"text": test_texts, "label": test_labels})

# Preview a few training rows
train_df.head()

Unnamed: 0,text,label
19023,\n \n \n ...,17
3758,enforcement activities authorized or approved...,17
13016,For payments to departments of agric...,17
34372,fference between—\n ...,0
25509,0101 of the Disaster Relief and Recovery Suppl...,0


In [None]:
# Pretrained checkpoint to fine-tune: either LegalBERT or LegalRoBERTa
MODEL_NAME = "nlpaueb/legal-bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def tokenize_function(examples):
    """Tokenize the ``"text"`` field of ``examples`` with the module-level tokenizer.

    The tokenizer is called with ``padding="max_length"`` and
    ``truncation=True``; its output is returned unchanged.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, padding="max_length", truncation=True)
    return encoded

# Wrap the pandas splits as Hugging Face datasets (train first, then test)
train_dataset, test_dataset = (
    Dataset.from_pandas(frame) for frame in (train_df, test_df)
)

# Tokenize both datasets in batches
tokenized_train, tokenized_test = (
    split_dataset.map(tokenize_function, batched=True)
    for split_dataset in (train_dataset, test_dataset)
)

# Load the pretrained checkpoint with one output per encoded label class
num_classes = len(label_encoder.classes_)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=num_classes)

Map:   0%|          | 0/28271 [00:00<?, ? examples/s]

Map:   0%|          | 0/7068 [00:00<?, ? examples/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [None]:
# Define training arguments
# Settings that are spelled the same in every transformers release
common_training_kwargs = dict(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='./logs',
)

# Evaluate once per epoch. Newer transformers releases call this argument
# `eval_strategy` and reject the old `evaluation_strategy` spelling, while older
# releases only know the old one — try the new name first, then fall back.
try:
    training_args = TrainingArguments(eval_strategy="epoch", **common_training_kwargs)
except TypeError:
    training_args = TrainingArguments(evaluation_strategy="epoch", **common_training_kwargs)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
)

# Train the model
trainer.train()




Epoch,Training Loss,Validation Loss


In [None]:
pip install --upgrade transformers accelerate torch

Note: you may need to restart the kernel to use updated packages.
