Approach 3 of 3 for hierarchical supervised classification: use finetuned legalSmall_case_classifier as the case classification model, and use a cross encoder model to match quotes downstream

In [31]:
!pip uninstall -y sympy
!pip install sympy

Found existing installation: sympy 1.13.1
Uninstalling sympy-1.13.1:
  Successfully uninstalled sympy-1.13.1
Collecting sympy
  Downloading sympy-1.13.3-py3-none-any.whl.metadata (12 kB)
Downloading sympy-1.13.3-py3-none-any.whl (6.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.2/6.2 MB[0m [31m46.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sympy
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.5.1+cu121 requires sympy==1.13.1; python_version >= "3.9", but you have sympy 1.13.3 which is incompatible.[0m[31m
[0mSuccessfully installed sympy-1.13.3


In [1]:
!pip install -U sentence-transformers
from sentence_transformers import CrossEncoder

Collecting sympy==1.13.1 (from torch>=1.11.0->sentence-transformers)
  Downloading sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Downloading sympy-1.13.1-py3-none-any.whl (6.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.2/6.2 MB[0m [31m41.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sympy
  Attempting uninstall: sympy
    Found existing installation: sympy 1.13.3
    Uninstalling sympy-1.13.3:
      Successfully uninstalled sympy-1.13.3
Successfully installed sympy-1.13.1


In [2]:
from google.colab import drive
drive.mount("/content/drive", force_remount = True)

Mounted at /content/drive


In [3]:
!pip install datasets
from datasets import Dataset, load_dataset
import pandas as pd



In [4]:
# load the training data
df = pd.read_csv('/content/drive/MyDrive/6.8610/NLP Final Project/top_10000_training_data.csv.gz', compression='gzip')
print(df.head())
df.shape #currently should be the 2 million rows

    dest_id  source_id   dest_date  \
0  11582807   11781059  1999-08-16   
1  11582807   11781059  1999-08-16   
2  11582807   11781059  1999-08-16   
3  11665092   11781059  1999-07-13   
4  11123976   11781059  2001-03-16   

                                          dest_court  \
0  United States Court of Appeals for the Sixth C...   
1  United States Court of Appeals for the Sixth C...   
2  United States Court of Appeals for the Sixth C...   
3  United States Court of Appeals for the Fifth C...   
4  United States Court of Appeals for the Eighth ...   

                  dest_name                                      dest_cite  \
0  United States v. Houston  United States v. Houston, 187 F.3d 593 (1999)   
1  United States v. Houston  United States v. Houston, 187 F.3d 593 (1999)   
2  United States v. Houston  United States v. Houston, 187 F.3d 593 (1999)   
3     United States v. Ruiz     United States v. Ruiz, 180 F.3d 675 (1999)   
4   United States v. Nation   United States 

(2076241, 13)

In [5]:
baseline_df = df[['destination_context', 'quote', 'passage_id', 'source_name']]
baseline_df["label"], _ = pd.factorize(baseline_df["passage_id"])

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  baseline_df["label"], _ = pd.factorize(baseline_df["passage_id"])


In [6]:
baseline_df = baseline_df[:10000] #cut to top 10k rows
baseline_df.shape
baseline_df

Unnamed: 0,destination_context,quote,passage_id,source_name,label
0,If Houston’s escape is to be deemed a violent...,otherwise involve! ] conduct that presents a s...,11781059_7,United States v. Harris,0
1,If Houston’s escape is to be deemed a violent...,otherwise involves conduct that presents a ser...,11781059_7,United States v. Harris,0
2,Id. at 323 (finding that although the Ohio ki...,otherwise involves conduct that presents a ser...,11781059_7,United States v. Harris,0
3,” We rejected an identical argument in United ...,presenting] a serious potential risk of physic...,11781059_7,United States v. Harris,0
4,The court sentenced Nation pursuant to U.S.S....,involves conduct that presents a serious poten...,11781059_7,United States v. Harris,0
...,...,...,...,...,...
9995,"Lee, 236 F.R.D. at 204; see also Robidoux, 98...",the representative parties will fairly and ade...,11239507_17,"Baffa v. Donaldson, Lufkin & Jenrette Securiti...",77
9996,"To establish adequacy, plaintiffs must show t...",antagonistic to the interest of other members ...,11239507_3,"Baffa v. Donaldson, Lufkin & Jenrette Securiti...",76
9997,"To establish adequacy, plaintiffs must show t...","qualified, experienced and able to conduct the...",11239507_3,"Baffa v. Donaldson, Lufkin & Jenrette Securiti...",76
9998,"\nIn support of commonality, Maziarz asserts t...","qualified, experienced and able to conduct the...",11239507_3,"Baffa v. Donaldson, Lufkin & Jenrette Securiti...",76


In [7]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModel
from torch.nn import functional as F
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

In [8]:
#need to instantiate a copy of the model for case classification to substitute finetuned parameters

class LegalBERTSmallClassifier(torch.nn.Module):
    def __init__(self, n_classes):
        super(LegalBERTSmallClassifier, self).__init__()
        self.legalbertSmall = AutoModelForSequenceClassification.from_pretrained(
            "nlpaueb/legal-bert-small-uncased",
            num_labels=n_classes
        )

        # Freeze all parameters except classifier
        for param in self.legalbertSmall.parameters():
            param.requires_grad = False
        for param in self.legalbertSmall.classifier.parameters():
            param.requires_grad = True

    def forward(self, input_ids, attention_mask):
        output = self.legalbertSmall(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        return output.logits  # Note: need to return logits for classification

In [19]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F

class LegalPrecedentMatcher:
    def __init__(self, case_classifier_path, num_labels, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device

        # Initialize case classifier as before
        self.case_classifier = LegalBERTSmallClassifier(n_classes=num_labels)
        state_dict = torch.load(case_classifier_path, map_location=device)
        self.case_classifier.load_state_dict(state_dict)
        self.case_classifier.to(device)
        self.case_classifier.eval()

        # Initialize cross-encoder using sentence-transformers
        self.cross_encoder = CrossEncoder('cross-encoder/stsb-roberta-base', device=device)

        # Keep LegalBERT tokenizer for case classification
        self.case_tokenizer = AutoTokenizer.from_pretrained('nlpaueb/legal-bert-small-uncased')

    def classify_case(self, quote):
        with torch.no_grad():
            # Get tokenizer outputs
            tokenizer_output = self.case_tokenizer(quote,
                                                 return_tensors='pt',
                                                 padding=True,
                                                 truncation=True,
                                                 max_length=512)

            # Only pass the expected arguments
            inputs = {
                'input_ids': tokenizer_output['input_ids'].to(self.device),
                'attention_mask': tokenizer_output['attention_mask'].to(self.device)
            }

            outputs = self.case_classifier(**inputs)
            predicted_label = torch.argmax(outputs, dim=1)

        return predicted_label.item()

    def find_most_similar_quote(self, query_quote, candidate_quotes):
        # Score all pairs
        pairs = [[query_quote, candidate] for candidate in candidate_quotes]
        scores = self.cross_encoder.predict(pairs)

        # Find best match
        most_similar_idx = np.argmax(scores)

        return {
            'most_similar_quote': candidate_quotes[most_similar_idx],
            'similarity_score': scores[most_similar_idx],
            'all_similarities': scores
        }

In [20]:
baseline_df_labels = baseline_df["label"].to_list()[:10000]
num_labels_baseline = len(baseline_df["label"].unique())
num_labels_baseline

78

In [21]:
#inference (single example) pipeline
#using case_classifier_df (the top-10k rows version of baseline_df)

case_classifier_df = df[['destination_context', 'quote', 'passage_id', 'source_name', 'source_id']]
case_classifier_df["label"], _ = pd.factorize(case_classifier_df["source_id"])
training_labels = case_classifier_df["label"].to_list()[:10000]
num_labels = len(set(training_labels))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  case_classifier_df["label"], _ = pd.factorize(case_classifier_df["source_id"])


In [22]:
def main(df, query_quote, case_classifier_path):
    matcher = LegalPrecedentMatcher(case_classifier_path, num_labels)

    # Step 1: Classify the case
    predicted_label = matcher.classify_case(query_quote)

    # Step 2: Get all quotes from the predicted case
    case_quotes = df[df['label'] == predicted_label]['destination_context'].tolist()

    if not case_quotes:
        return {
            'predicted_case_label': predicted_label,
            'predicted_case_name': 'No matching case found',
            'most_similar_quote': None,
            'similarity_score': None
        }

    # Step 3: Find the most similar quote using cross encoder
    result = matcher.find_most_similar_quote(query_quote, case_quotes)

    return {
        'predicted_case_label': predicted_label,
        'predicted_case_name': df[df['label'] == predicted_label]['source_name'].iloc[0],
        'most_similar_quote': result['most_similar_quote'],
        'similarity_score': result['similarity_score']
    }

In [23]:
# Test with exact same file path as in training
baseline_df = baseline_df[:10000]
test_quote = baseline_df['quote'].iloc[2459]
results = main(baseline_df, test_quote, '/content/drive/MyDrive/6.8610/NLP Final Project/LegalSmall_Case_Classifier_top10k.pth')
print(f"Predicted Case: {results['predicted_case_name']}")
print(f"Most Similar Quote (Similarity: {results['similarity_score']:.3f}):")
print(results['most_similar_quote'])

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-small-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  state_dict = torch.load(case_classifier_path, map_location=device)


Predicted Case: Walczak v. Florida Union Free School District
Most Similar Quote (Similarity: 0.568):
 The IHO held that both IEPs “offered an appropriate public education in the least restrictive environment,” id. at 55, and that the procedures followed in their development adequately complied with the IDEA, id. An SRO affirmed the decisions of the IHO. Id.
DISCUSSION
I. Judicial Review Under the IDEA
Federal courts reviewing administrative determinations under the IDEA must base their decisions on “a preponderance of the evidence,” taking into account not only the record from the administrative proceedings, but also any further evidence presented before the District Court by the parties. See 20 U.S.C. § 1415(i)(2)(B). The Supreme Court and our Court have interpreted the IDEA as strictly limiting judicial review of state administrative decisions. Federal courts reviewing administrative decisions must give “due weight” to the administrative proceedings, “


Evaluations

In [24]:
import numpy as np
import random
from tqdm import tqdm

In [25]:
def evaluate_full_system(df, matcher, n_samples, n_distractors):
    """
    Evaluates both case classification and passage matching

    Args:
        df: DataFrame with columns 'quote', 'destination_context', 'source_name', 'label'
        matcher: Initialized LegalPrecedentMatcher instance
        n_samples: Number of test cases to evaluate
        n_distractors: Number of incorrect passages to include

    Returns:
        Dictionary with evaluation metrics
    """
    correct_case = 0
    correct_passage = 0
    correct_both = 0
    mrr = 0  # Mean Reciprocal Rank

    # Sample test cases
    all_indices = list(range(len(df)))
    test_indices = np.random.choice(all_indices, n_samples, replace=False)

    # For storing detailed results
    detailed_results = []

    # Progress bar
    for idx in tqdm(test_indices, desc="Evaluating"):
        # Get query quote and true information
        query = df.iloc[idx]['quote']
        true_case = df.iloc[idx]['source_name']
        true_passage = df.iloc[idx]['destination_context']

        # Get distractor passages (randomly sampled)
        # Exclude the current index from potential distractors
        possible_distractors = [i for i in all_indices if i != idx]
        distractor_indices = np.random.choice(
            possible_distractors,
            n_distractors,
            replace=False
        )

        # Create list of candidate passages (true passage + distractors)
        candidate_passages = [true_passage] + [
            df.iloc[i]['destination_context'] for i in distractor_indices
        ]

        # Shuffle candidates and keep track of true passage position
        random.shuffle(candidate_passages)
        true_passage_idx = candidate_passages.index(true_passage)

        # Get model predictions
        # First, classify the case
        predicted_label = matcher.classify_case(query)
        predicted_case = df[df['label'] == predicted_label]['source_name'].iloc[0]

        # Then, find most similar passage
        result = matcher.find_most_similar_quote(query, candidate_passages)
        similarities = result['all_similarities']

        # Find rank of true passage in similarity scores
        ranked_indices = np.argsort(-similarities)  # Sort in descending order
        passage_rank = np.where(ranked_indices == true_passage_idx)[0][0] + 1

        # Update metrics
        case_correct = (predicted_case == true_case)
        passage_correct = (passage_rank == 1)

        if case_correct:
            correct_case += 1
        if passage_correct:
            correct_passage += 1
        if case_correct and passage_correct:
            correct_both += 1
        mrr += 1.0 / passage_rank

        # Store detailed results for this sample
        detailed_results.append({
            'query': query,
            'true_case': true_case,
            'predicted_case': predicted_case,
            'case_correct': case_correct,
            'passage_rank': passage_rank,
            'true_passage': true_passage,
            'predicted_passage': candidate_passages[ranked_indices[0]],
            'similarity_score': similarities[ranked_indices[0]]
        })

    # Calculate final metrics
    results = {
        'case_accuracy': correct_case / n_samples,
        'passage_accuracy': correct_passage / n_samples,
        'full_system_accuracy': correct_both / n_samples,
        'mrr': mrr / n_samples,
        'detailed_results': detailed_results  # Store all individual results
    }

    return results

In [26]:
def print_evaluation_results(results):
    """Pretty prints the evaluation results"""
    print("\nEvaluation Results:")
    print("-" * 50)
    print(f"Case Classification Accuracy: {results['case_accuracy']:.3f}")
    print(f"Passage Matching Accuracy: {results['passage_accuracy']:.3f}")
    print(f"Full System Accuracy: {results['correct_both']:.3f}")
    print(f"Mean Reciprocal Rank: {results['mrr']:.3f}")

    print("\nDetailed Error Analysis:")
    print("-" * 50)
    # Print a few examples where the system failed
    print("Examples of errors:")
    errors = [r for r in results['detailed_results']
             if not (r['case_correct'] and r['passage_rank'] == 1)]

    for i, error in enumerate(errors[:3]):  # Show first 3 errors
        print(f"\nError Example {i+1}:")
        print(f"Query: {error['query'][:200]}...")
        print(f"True Case: {error['true_case']}")
        print(f"Predicted Case: {error['predicted_case']}")
        print(f"Passage Rank: {error['passage_rank']}")
        print(f"Similarity Score: {error['similarity_score']:.3f}")

In [27]:
matcher = LegalPrecedentMatcher('/content/drive/MyDrive/6.8610/NLP Final Project/LegalSmall_Case_Classifier_top10k.pth', num_labels)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-small-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  state_dict = torch.load(case_classifier_path, map_location=device)


In [36]:
results = evaluate_full_system(baseline_df, matcher, n_samples=20, n_distractors=9)
print_evaluation_results(results)

Evaluating: 100%|██████████| 20/20 [05:57<00:00, 17.85s/it]


Evaluation Results:
--------------------------------------------------
Case Classification Accuracy: 0.050
Passage Matching Accuracy: 0.200





KeyError: 'correct_both'