# Legal Case Outcome Prediction - Model Training

This notebook loads, cleans, and trains models on legal case data.


In [1]:
# Fix for importlib metadata issue in conda environments
# This patches both importlib.metadata and importlib_metadata to handle None metadata
import sys
import importlib.metadata
import importlib_metadata

# Patch importlib.metadata.version (used by datasets/config.py)
_original_version = importlib.metadata.version
def _safe_version(name):
    try:
        result = _original_version(name)
        if result is None:
            # Return a valid version string instead of 'unknown'
            return '0.0.0'
        return result
    except Exception:
        # Return a valid version string for packaging.version.parse()
        return '0.0.0'

importlib.metadata.version = _safe_version
sys.modules['importlib.metadata'].version = _safe_version

# Patch importlib_metadata.distribution
_original_distribution = importlib_metadata.distribution
def _safe_distribution(name):
    try:
        dist = _original_distribution(name)
        if hasattr(dist, 'metadata') and dist.metadata is None:
            class MockDist:
                metadata = {'Version': 'unknown'}
                @property
                def version(self):
                    return 'unknown'
            return MockDist()
        return dist
    except Exception:
        class MockDist:
            metadata = {'Version': 'unknown'}
            @property
            def version(self):
                return 'unknown'
        return MockDist()

importlib_metadata.distribution = _safe_distribution
sys.modules['importlib_metadata'].distribution = _safe_distribution

import pandas as pd
import numpy as np
import re
import os
import pickle
from pathlib import Path
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import warnings
warnings.filterwarnings('ignore')

# Set paths
BASE_DIR = Path('../')
DATA_DIR = BASE_DIR / 'data'
MODELS_DIR = BASE_DIR / 'models'
MODELS_DIR.mkdir(exist_ok=True)

print("Setup complete!")


Setup complete!


## 1. Load CSV Files


In [2]:
# Load both CSV files
dockets = pd.read_csv(DATA_DIR / 'courtlistener_dockets_partial.csv')
opinions = pd.read_csv(DATA_DIR / 'opinions_checkpoint.csv')

print(f"Dockets shape: {dockets.shape}")
print(f"Opinions shape: {opinions.shape}")
print(f"\nDockets columns: {list(dockets.columns)}")
print(f"\nOpinions columns: {list(opinions.columns)}")


Dockets shape: (73000, 58)
Opinions shape: (18500, 8)

Dockets columns: ['resource_uri', 'id', 'court', 'court_id', 'original_court_info', 'idb_data', 'clusters', 'audio_files', 'assigned_to', 'referred_to', 'absolute_url', 'date_created', 'date_modified', 'source', 'appeal_from_str', 'assigned_to_str', 'referred_to_str', 'panel_str', 'date_last_index', 'date_cert_granted', 'date_cert_denied', 'date_argued', 'date_reargued', 'date_reargument_denied', 'date_filed', 'date_terminated', 'date_last_filing', 'case_name_short', 'case_name', 'case_name_full', 'slug', 'docket_number', 'docket_number_core', 'docket_number_raw', 'federal_dn_office_code', 'federal_dn_case_type', 'federal_dn_judge_initials_assigned', 'federal_dn_judge_initials_referred', 'federal_defendant_number', 'pacer_case_id', 'cause', 'nature_of_suit', 'jury_demand', 'jurisdiction_type', 'appellate_fee_status', 'appellate_case_type_information', 'mdl_status', 'filepath_ia', 'filepath_ia_json', 'ia_upload_failure_count', 'ia_n

## 2. Merge Docket and Opinion Data


In [3]:
# Extract docket_id from dockets (it's in the id column)
dockets['docket_id'] = dockets['id'].astype(str)
opinions['docket_id'] = opinions['docket_id'].astype(str)

# Merge on docket_id
df = opinions.merge(
    dockets[['docket_id', 'case_name', 'court']],
    on='docket_id',
    how='inner'
)

print(f"Merged dataset shape: {df.shape}")
print(f"\nMissing values:")
print(df.isnull().sum())
print(f"\nFirst few rows:")
df.head()


Merged dataset shape: (18500, 10)

Missing values:
docket_id           0
case_name_x         0
court_x             0
date_filed      18500
opinion_id          0
opinion_type        0
opinion_text        1
outcome             0
case_name_y         0
court_y             0
dtype: int64

First few rows:


Unnamed: 0,docket_id,case_name_x,court_x,date_filed,opinion_id,opinion_type,opinion_text,outcome,case_name_y,court_y
0,71884389,Trump v. Orr,https://www.courtlistener.com/api/rest/v4/cour...,,11198703,010combined,Cite as: 607 U. S. ____ (202...,granted,Trump v. Orr,https://www.courtlistener.com/api/rest/v4/cour...
1,71735833,Boyd v. Hamm,https://www.courtlistener.com/api/rest/v4/cour...,,11177002,010combined,...,granted,Boyd v. Hamm,https://www.courtlistener.com/api/rest/v4/cour...
2,71735833,Boyd v. Hamm,https://www.courtlistener.com/api/rest/v4/cour...,,11177003,010combined,Cite as: 607 U. S. ____ (202...,granted,Boyd v. Hamm,https://www.courtlistener.com/api/rest/v4/cour...
3,71735833,Boyd v. Hamm,https://www.courtlistener.com/api/rest/v4/cour...,,11176501,010combined,Cite as: 607 U. S. ____ (202...,granted,Boyd v. Hamm,https://www.courtlistener.com/api/rest/v4/cour...
4,71659774,Crawford v. Mississippi,https://www.courtlistener.com/api/rest/v4/cour...,,11171208,010combined,Cite as: 607 U. S. ____ (202...,denied,Crawford v. Mississippi,https://www.courtlistener.com/api/rest/v4/cour...


## 3. Clean Text Data


In [4]:
def clean_legal_text(text):
    """
    Clean legal text by removing outcome-revealing words and procedural boilerplate.
    Applies tail-scrubbing to last 2000 characters.
    """
    if pd.isna(text) or text == '':
        return ''
    
    text = str(text)
    
    # Outcome-revealing words to remove
    outcome_words = [
        'AFFIRMED', 'REVERSED', 'VACATED', 'REMANDED',
        'GRANTED', 'DISMISSED', 'DENIED',
        'affirmed', 'reversed', 'vacated', 'remanded',
        'granted', 'dismissed', 'denied'
    ]
    
    # Remove outcome words
    for word in outcome_words:
        text = text.replace(word, '')
    
    # Tail-scrubbing: clean last 2000 characters
    if len(text) > 2000:
        tail = text[-2000:]
        main_text = text[:-2000]
        
        # Remove procedural boilerplate patterns from tail
        tail = re.sub(r'Judgment\s+vacated[^.]*\.', '', tail, flags=re.IGNORECASE)
        tail = re.sub(r'and\s+remanded[^.]*\.', '', tail, flags=re.IGNORECASE)
        tail = re.sub(r'Certiorari\s+granted[^.]*\.', '', tail, flags=re.IGNORECASE)
        tail = re.sub(r'The\s+petition\s+for\s+rehearing\s+is\s+denied[^.]*\.', '', tail, flags=re.IGNORECASE)
        tail = re.sub(r'\bremanded\b[^.]*\.', '', tail, flags=re.IGNORECASE)
        tail = re.sub(r'\bvacated\b[^.]*\.', '', tail, flags=re.IGNORECASE)
        tail = re.sub(r'\breversed\b[^.]*\.', '', tail, flags=re.IGNORECASE)
        tail = re.sub(r'\baffirmed\b[^.]*\.', '', tail, flags=re.IGNORECASE)
        
        text = main_text + tail
    else:
        # Apply same cleaning to entire text if shorter than 2000 chars
        text = re.sub(r'Judgment\s+vacated[^.]*\.', '', text, flags=re.IGNORECASE)
        text = re.sub(r'and\s+remanded[^.]*\.', '', text, flags=re.IGNORECASE)
        text = re.sub(r'Certiorari\s+granted[^.]*\.', '', text, flags=re.IGNORECASE)
        text = re.sub(r'The\s+petition\s+for\s+rehearing\s+is\s+denied[^.]*\.', '', text, flags=re.IGNORECASE)
    
    # Clean up extra whitespace
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    
    return text

# Apply cleaning
df['clean_text'] = df['opinion_text'].apply(clean_legal_text)

# Remove rows with empty clean_text
df = df[df['clean_text'].str.len() > 0]

print(f"After cleaning, dataset shape: {df.shape}")
print(f"\nSample cleaned text (first 500 chars):")
print(df['clean_text'].iloc[0][:500] if len(df) > 0 else 'No data')


After cleaning, dataset shape: (18499, 11)

Sample cleaned text (first 500 chars):
Cite as: 607 U. S. ____ (2025) 1 SUPREME COURT OF THE UNITED STATES _________________ No. 25A319 _________________ DONALD J. TRUMP, PRESIDENT OF THE UNITED STATES, ET AL. v. ASHTON ORR, ET AL. ON APPLICATION FOR STAY [November 6, 2025] This case concerns an Executive Branch policy requiring all new passports to display an individual’s biological sex at birth. The United States District Court for the District of Massachusetts preliminarily enjoined the Government from enforcing the policy, and th


## 4. Create Binary Labels


In [5]:
def create_binary_label(outcome):
    """
    Create binary label:
    win → {reversed, granted}
    lose → {affirmed, denied, dismissed, remanded}
    unknown → drop
    """
    if pd.isna(outcome):
        return 'unknown'
    
    outcome_str = str(outcome).lower().strip()
    
    # Win cases
    if 'reversed' in outcome_str or 'granted' in outcome_str:
        return 'win'
    
    # Lose cases
    if 'affirmed' in outcome_str or 'denied' in outcome_str or \
       'dismissed' in outcome_str or 'remanded' in outcome_str:
        return 'lose'
    
    return 'unknown'

# Create labels
df['winlose'] = df['outcome'].apply(create_binary_label)

# Drop unknown labels
df = df[df['winlose'] != 'unknown']

print(f"After labeling, dataset shape: {df.shape}")
print(f"\nLabel distribution:")
print(df['winlose'].value_counts())


After labeling, dataset shape: (18100, 12)

Label distribution:
winlose
lose    17467
win       633
Name: count, dtype: int64


## 5. Generate LegalBERT Embeddings


In [6]:
# Load LegalBERT model
print("Loading LegalBERT model...")
model = SentenceTransformer('nlpaueb/legal-bert-base-uncased')

# Generate embeddings
print("Generating embeddings...")
texts = df['clean_text'].tolist()
embeddings = model.encode(texts, batch_size=8, show_progress_bar=True)

print(f"\nEmbeddings shape: {embeddings.shape}")
print(f"Embedding dimension: {embeddings.shape[1]}")


Loading LegalBERT model...


No sentence-transformers model found with name nlpaueb/legal-bert-base-uncased. Creating a new one with mean pooling.


Generating embeddings...


Batches:   0%|          | 0/2263 [00:00<?, ?it/s]


Embeddings shape: (18100, 768)
Embedding dimension: 768


## 6. Split Dataset


In [7]:
# Prepare labels
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(df['winlose'])

# Split 80/20
X_train, X_test, y_train, y_test = train_test_split(
    embeddings, y, test_size=0.2, random_state=42, stratify=y
)

print(f"Training set: {X_train.shape[0]} samples")
print(f"Test set: {X_test.shape[0]} samples")
print(f"\nTraining label distribution:")
print(pd.Series(y_train).value_counts())
print(f"\nTest label distribution:")
print(pd.Series(y_test).value_counts())


Training set: 14480 samples
Test set: 3620 samples

Training label distribution:
0    13974
1      506
Name: count, dtype: int64

Test label distribution:
0    3493
1     127
Name: count, dtype: int64


## 7. Train and Compare Models


In [8]:
# Define models to test
models = {
    'LogisticRegression': LogisticRegression(random_state=42, max_iter=1000),
    'RandomForest': RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1),
    'GradientBoosting': GradientBoostingClassifier(random_state=42),
    'SVC': SVC(kernel='rbf', random_state=42, probability=True),
    'MLPClassifier': MLPClassifier(hidden_layer_sizes=(256,), random_state=42, max_iter=500)
}

results = {}

for name, model in models.items():
    print(f"\n{'='*50}")
    print(f"Training {name}...")
    print(f"{'='*50}")
    
    # Train
    model.fit(X_train, y_train)
    
    # Evaluate
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    
    results[name] = {
        'model': model,
        'accuracy': accuracy,
        'predictions': y_pred
    }
    
    print(f"\n{name} Accuracy: {accuracy:.4f}")
    print(f"\nClassification Report:")
    print(classification_report(y_test, y_pred, target_names=label_encoder.classes_))



Training LogisticRegression...

LogisticRegression Accuracy: 0.9845

Classification Report:
              precision    recall  f1-score   support

        lose       0.99      0.99      0.99      3493
         win       0.83      0.70      0.76       127

    accuracy                           0.98      3620
   macro avg       0.91      0.85      0.88      3620
weighted avg       0.98      0.98      0.98      3620


Training RandomForest...

RandomForest Accuracy: 0.9837

Classification Report:
              precision    recall  f1-score   support

        lose       0.99      1.00      0.99      3493
         win       0.83      0.67      0.74       127

    accuracy                           0.98      3620
   macro avg       0.91      0.83      0.87      3620
weighted avg       0.98      0.98      0.98      3620


Training GradientBoosting...

GradientBoosting Accuracy: 0.9843

Classification Report:
              precision    recall  f1-score   support

        lose       0.99     

## 8. Select Best Model and Save


In [9]:
# Find best model
best_model_name = max(results, key=lambda x: results[x]['accuracy'])
best_model = results[best_model_name]['model']
best_accuracy = results[best_model_name]['accuracy']

print(f"\n{'='*50}")
print(f"Best Model: {best_model_name}")
print(f"Accuracy: {best_accuracy:.4f}")
print(f"{'='*50}")

# Save model
with open(MODELS_DIR / 'model.pkl', 'wb') as f:
    pickle.dump(best_model, f)

# Save label encoder
with open(MODELS_DIR / 'label_encoder.pkl', 'wb') as f:
    pickle.dump(label_encoder, f)

# Save embeddings
np.save(MODELS_DIR / 'embeddings.npy', embeddings)

# Save clean dataset
df.to_csv(MODELS_DIR / 'clean_dataset.csv', index=False)

print(f"\nAll artifacts saved to {MODELS_DIR}")
print(f"\nSaved files:")
print(f"  - model.pkl")
print(f"  - label_encoder.pkl")
print(f"  - embeddings.npy")
print(f"  - clean_dataset.csv")



Best Model: SVC
Accuracy: 0.9851

All artifacts saved to ../models

Saved files:
  - model.pkl
  - label_encoder.pkl
  - embeddings.npy
  - clean_dataset.csv
