# DOJ Press Release NER Model - Interactive Notebook

This notebook provides an interactive, step-by-step walkthrough of building a Named Entity Recognition (NER) model for analyzing Department of Justice press releases.

## Project Goals
- Process Prodigy-annotated training data
- Build a custom spaCy NER model
- Evaluate performance with detailed metrics and visualizations
- Extract entities: DEFENDANT, PROSECUTOR, JUDGE, SENTENCE, FRAUD MECHANISM, FRAUD AMOUNT, GOV PROGRAM, BUSINESS

## Notebook Structure
1. **Setup & Data Loading**: Import libraries and load annotated data
2. **Data Exploration**: Analyze entity distribution and dataset statistics
3. **Data Preparation**: Convert Prodigy format to spaCy format
4. **Model Training**: Train custom NER model with spaCy
5. **Evaluation**: Calculate detailed metrics and visualize performance
6. **Interactive Demo**: Test the model on new text

## 1. Setup & Import Required Libraries

Import all libraries needed for data processing, model training, and visualization. Install any missing packages with `pip` before running this cell.

In [1]:
# Import core libraries
import json
import os
from pathlib import Path
from collections import Counter, defaultdict
from typing import List, Dict, Tuple

# Data processing
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# spaCy for NER
import spacy
from spacy.tokens import DocBin, Doc
from spacy.training import Example

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Progress bars
from tqdm.auto import tqdm

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 6)

print("✅ All libraries imported successfully!")
print(f"spaCy version: {spacy.__version__}")

✅ All libraries imported successfully!
spaCy version: 3.8.11




## 2. Load and Explore Prodigy Annotated Data

Load the training data annotated in Prodigy format from `data/raw/2025_11_27.jsonl`.

In [2]:
# Load Prodigy annotated data
data_file = Path("data/raw/2025_11_27.jsonl")

def load_prodigy_data(file_path: Path) -> List[Dict]:
    """Load data from Prodigy JSONL format."""
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line))
    return data

# Load data
prodigy_data = load_prodigy_data(data_file)
print(f"✅ Loaded {len(prodigy_data)} annotated examples")

# Display first example
print("\n📄 Sample annotation:")
sample = prodigy_data[0]
print(json.dumps(sample, indent=2)[:500] + "...")

✅ Loaded 1013 annotated examples

📄 Sample annotation:
{
  "text": "\u2022 Five defendants indicted for operating shell companies to \u201crent\u201d workers\u2019 compensation to work crews that unlawfully employed illegal aliens and for cashing approximately $292 million in payroll checks and failing to pay more than $52 million in payroll taxes",
  "meta": {
    "source_row": 2550,
    "paragraph_index": 23,
    "source_csv": "all_2025.csv",
    "press_release_date": "2/18/2025",
    "unique_id": "c66e8800-ad77-4931-b67f-66292d436232"
  },
  "_in...

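The loader above keeps every record, but a Prodigy export can also contain `reject` and `ignore` answers, and their spans would flow into the statistics and training data below. A minimal sketch of a stricter loader, assuming each record carries Prodigy's `answer` field (`parse_prodigy_lines` and `load_accepted` are hypothetical helpers, not Prodigy APIs):

```python
import json
from pathlib import Path
from typing import Dict, Iterable, Iterator, List, Tuple


def parse_prodigy_lines(lines: Iterable[str],
                        keep_answers: Tuple[str, ...] = ("accept",)) -> Iterator[Dict]:
    """Yield parsed records whose "answer" is in keep_answers.

    Assumption: a record without an "answer" key is treated as "accept".
    Blank lines are skipped; malformed JSON raises ValueError with the line number.
    """
    for line_no, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue
        try:
            record = json.loads(line)
        except json.JSONDecodeError as err:
            raise ValueError(f"line {line_no}: invalid JSON ({err})") from err
        if record.get("answer", "accept") in keep_answers:
            yield record


def load_accepted(file_path: Path) -> List[Dict]:
    """Hypothetical drop-in alternative to load_prodigy_data."""
    with open(file_path, "r", encoding="utf-8") as f:
        return list(parse_prodigy_lines(f))


sample_lines = [
    '{"text": "Andre Lane was sentenced.", "answer": "accept"}',
    "",
    '{"text": "Unclear paragraph.", "answer": "ignore"}',
    '{"text": "Bad spans.", "answer": "reject"}',
    '{"text": "Unreviewed record."}',
]
print([r["text"] for r in parse_prodigy_lines(sample_lines)])
# ['Andre Lane was sentenced.', 'Unreviewed record.']
```

Swapping `load_accepted(data_file)` in for `load_prodigy_data(data_file)` may change the counts reported below; the numbers in this notebook come from the unfiltered loader.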

### 2.1 Entity Distribution Analysis

Analyze what entity types are present and their frequency distribution.

In [3]:
# Extract entity statistics
entity_counts = Counter()
entity_examples = defaultdict(list)

for item in prodigy_data:
    if 'spans' in item:
        for span in item['spans']:
            label = span['label']
            entity_counts[label] += 1
            # Store first 3 examples per entity type
            if len(entity_examples[label]) < 3:
                text = item['text'][span['start']:span['end']]
                entity_examples[label].append(text)

# Create DataFrame for visualization
entity_df = pd.DataFrame([
    {'Entity Type': label, 'Count': count} 
    for label, count in entity_counts.most_common()
])

print(f"📊 Found {len(entity_counts)} entity types:")
print(entity_df.to_string(index=False))
print(f"\n📈 Total entities annotated: {sum(entity_counts.values())}")

# Show examples
print("\n📝 Example entities:")
for label, examples in sorted(entity_examples.items()):
    print(f"\n{label}:")
    for ex in examples[:3]:
        print(f"  - {ex}")

📊 Found 8 entity types:
    Entity Type  Count
      DEFENDANT    809
     PROSECUTOR    187
FRAUD MECHANISM    108
       BUSINESS     94
          JUDGE     70
       SENTENCE     70
    GOV PROGRAM     57
   FRAUD AMOUNT     43

📈 Total entities annotated: 1438

📝 Example entities:

BUSINESS:
  - Partex Oman Corp.
  - Renewable Energy Campus Arkansas, Inc.
  - Stonetek Global Corp.

DEFENDANT:
  - Alvis Alexander Briceno-Yajures
  - Antoinette Kennedy
  - Andre Lane

FRAUD AMOUNT:
  - $161,900,000
  - 5,000.00
  - $600,000

FRAUD MECHANISM:
  - Medicare fraud scheme
  - COVID test kit fraud
  - COVID testing fraud

GOV PROGRAM:
  - the Paycheck Protection Program (“PPP”)
  - Affordable Care Act
  - Paycheck Protection Program (PPP)

JUDGE:
  - Marcia Crone
  - Shanlyn A.S. Park
  - Robert F. Rossiter, Jr.

PROSECUTOR:
  - Matthew J. Del Mastro
  - Matt Quinn
  - Rebecca A. Perlmutter

SENTENCE:
  - 151 months
  - 21 years and 6 months
  - 4 years and 3 months


### 2.2 Visualize Entity Distribution

In [4]:
# Create interactive bar chart with Plotly
fig = px.bar(
    entity_df, 
    x='Entity Type', 
    y='Count',
    title='Entity Type Distribution in Training Data',
    color='Count',
    color_continuous_scale='viridis',
    text='Count'
)
fig.update_traces(textposition='outside')
fig.update_layout(
    xaxis_title="Entity Type",
    yaxis_title="Number of Annotations",
    showlegend=False,
    height=500
)
fig.show()

# Also create a pie chart
fig2 = px.pie(
    entity_df, 
    values='Count', 
    names='Entity Type',
    title='Entity Type Proportion',
    hole=0.3
)
fig2.show()

## 3. Data Preparation: Convert to spaCy Format

Convert the Prodigy annotations to spaCy's `(text, {"entities": [...]})` training format, split them into train/dev/test sets, and save each split as a binary `DocBin` file.

In [5]:
# Initialize blank spaCy model for creating training data
nlp = spacy.blank("en")

def convert_to_spacy_format(prodigy_data: List[Dict]) -> List[Tuple[str, Dict]]:
    """Convert Prodigy format to spaCy training format."""
    training_data = []
    skipped = 0
    
    for item in tqdm(prodigy_data, desc="Converting to spaCy format"):
        text = item['text']
        entities = []
        
        if 'spans' in item:
            for span in item['spans']:
                start = span['start']
                end = span['end']
                label = span['label']
                
                # Validate entity
                if start < end and start >= 0 and end <= len(text):
                    entities.append((start, end, label))
                else:
                    skipped += 1
        
        if entities:  # Only include examples with entities
            training_data.append((text, {"entities": entities}))
    
    print(f"✅ Converted {len(training_data)} examples")
    if skipped > 0:
        print(f"⚠️ Skipped {skipped} invalid entities")
    
    return training_data

# Convert data
spacy_data = convert_to_spacy_format(prodigy_data)
print(f"\n📦 Total training examples: {len(spacy_data)}")

Converting to spaCy format: 100%|██████████| 1013/1013 [00:00<00:00, 1010663.64it/s]

✅ Converted 544 examples

📦 Total training examples: 544
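Assigning `doc.ents` raises an error if two spans overlap, and `convert_to_spacy_format` only checks that offsets fall inside the text. A sketch of an overlap check on the raw `(start, end, label)` tuples, using the same preference rule as spaCy's `spacy.util.filter_spans` (longest span wins, earlier start breaks ties); `filter_overlapping` is a hypothetical helper that works on character offsets rather than tokens:

```python
from typing import List, Tuple

Span = Tuple[int, int, str]  # (start_char, end_char, label)


def filter_overlapping(spans: List[Span]) -> Tuple[List[Span], List[Span]]:
    """Keep the longest span in each overlapping group; return (kept, dropped)."""
    # Longest first; for equal lengths, the earlier start comes first.
    ordered = sorted(spans, key=lambda s: (-(s[1] - s[0]), s[0]))
    kept: List[Span] = []
    dropped: List[Span] = []
    for span in ordered:
        start, end, _label = span
        # Half-open intervals [start, end) overlap if each starts before the other ends.
        if any(start < k_end and k_start < end for k_start, k_end, _ in kept):
            dropped.append(span)
        else:
            kept.append(span)
    return sorted(kept), dropped


spans = [(0, 10, "DEFENDANT"), (5, 8, "BUSINESS"), (12, 20, "JUDGE"), (15, 25, "PROSECUTOR")]
kept, dropped = filter_overlapping(spans)
print(kept)     # [(0, 10, 'DEFENDANT'), (15, 25, 'PROSECUTOR')]
print(dropped)  # [(12, 20, 'JUDGE'), (5, 8, 'BUSINESS')]
```

Running it per record and printing any `dropped` spans is one way to surface conflicting annotations for review instead of discarding them silently.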





### 3.1 Split Data into Train/Dev/Test Sets

Split the data: 70% training, 15% development, 15% testing.

In [6]:
# Split data: 70% train, 15% dev, 15% test
train_data, temp_data = train_test_split(spacy_data, test_size=0.3, random_state=42)
dev_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)

print(f"📊 Data Split:")
print(f"  Training set:   {len(train_data)} examples ({len(train_data)/len(spacy_data)*100:.1f}%)")
print(f"  Development set: {len(dev_data)} examples ({len(dev_data)/len(spacy_data)*100:.1f}%)")
print(f"  Test set:       {len(test_data)} examples ({len(test_data)/len(spacy_data)*100:.1f}%)")

# Visualize split
split_df = pd.DataFrame({
    'Split': ['Train', 'Dev', 'Test'],
    'Count': [len(train_data), len(dev_data), len(test_data)],
    'Percentage': [
        len(train_data)/len(spacy_data)*100,
        len(dev_data)/len(spacy_data)*100,
        len(test_data)/len(spacy_data)*100
    ]
})

fig = px.bar(
    split_df, 
    x='Split', 
    y='Count',
    title='Train/Dev/Test Data Split',
    text='Count',
    color='Percentage',
    color_continuous_scale='blues'
)
fig.update_traces(textposition='outside')
fig.show()

📊 Data Split:
  Training set:   380 examples (69.9%)
  Development set: 82 examples (15.1%)
  Test set:       82 examples (15.1%)
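The split prints 69.9/15.1/15.1 rather than an exact 70/15/15 because the counts must be whole numbers. The helper below mirrors the rounding in scikit-learn's `train_test_split` implementation for a float `test_size` with `train_size` unset (test count rounded up, train gets the remainder), which reproduces the 380/82/82 sizes above; `split_sizes` is a hypothetical name, not a scikit-learn function:

```python
from math import ceil
from typing import Tuple


def split_sizes(n_samples: int, test_size: float) -> Tuple[int, int]:
    """(n_train, n_test) for a float test_size when train_size is not given."""
    n_test = ceil(test_size * n_samples)  # test share rounded up
    return n_samples - n_test, n_test      # train takes what is left


n_train, n_temp = split_sizes(544, 0.3)   # first split: train vs. temp
n_dev, n_test = split_sizes(n_temp, 0.5)  # second split: dev vs. test
print(n_train, n_dev, n_test)  # 380 82 82
```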


### 3.2 Create spaCy DocBin Files

Convert training data to spaCy's efficient binary format for training.

In [7]:
def create_docbin(data: List[Tuple[str, Dict]], nlp) -> DocBin:
    """Convert training data to spaCy DocBin format."""
    db = DocBin()
    
    for text, annotations in tqdm(data, desc="Creating DocBin"):
        doc = nlp.make_doc(text)
        ents = []
        
        for start, end, label in annotations["entities"]:
            span = doc.char_span(start, end, label=label, alignment_mode="contract")
            if span is not None:
                ents.append(span)
        
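        # filter_spans keeps the longest span when annotations overlap;
        # assigning overlapping spans to doc.ents would raise an error.
        doc.ents = spacy.util.filter_spans(ents)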
        db.add(doc)
    
    return db

# Create output directory
output_dir = Path("data/processed")
output_dir.mkdir(parents=True, exist_ok=True)

# Create DocBin files for each split
print("📦 Creating DocBin files...")
train_db = create_docbin(train_data, nlp)
dev_db = create_docbin(dev_data, nlp)
test_db = create_docbin(test_data, nlp)

# Save to disk
train_db.to_disk(output_dir / "train.spacy")
dev_db.to_disk(output_dir / "dev.spacy")
test_db.to_disk(output_dir / "test.spacy")

print(f"\n✅ DocBin files saved to {output_dir}")
print(f"   - train.spacy: {len(train_data)} examples")
print(f"   - dev.spacy: {len(dev_data)} examples")
print(f"   - test.spacy: {len(test_data)} examples")

📦 Creating DocBin files...


Creating DocBin: 100%|██████████| 380/380 [00:00<00:00, 2259.49it/s]
Creating DocBin: 100%|██████████| 82/82 [00:00<00:00, 3380.31it/s]
Creating DocBin: 100%|██████████| 82/82 [00:00<00:00, 3064.81it/s]


✅ DocBin files saved to data\processed
   - train.spacy: 380 examples
   - dev.spacy: 82 examples
   - test.spacy: 82 examples
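`create_docbin` uses `alignment_mode="contract"`, so an annotation whose edges fall inside a token is shrunk to the tokens it fully covers, or dropped without warning when it covers none. The stdlib sketch below illustrates that rule on character offsets; `contract_span` is a hypothetical helper, and counting how often `char_span` returns `None` inside `create_docbin` would show whether this affects the real data:

```python
import re
from typing import List, Optional, Tuple


def contract_span(token_offsets: List[Tuple[int, int]],
                  start: int, end: int) -> Optional[Tuple[int, int]]:
    """Shrink [start, end) to the tokens that lie entirely inside it.

    Returns None when no token fits, the case in which doc.char_span(...)
    returns None and create_docbin skips the annotation.
    """
    inside = [(s, e) for s, e in token_offsets if s >= start and e <= end]
    if not inside:
        return None
    return inside[0][0], inside[-1][1]


text = "paid $600,000 to Stonetek Global Corp."
# Whitespace tokens stand in for spaCy's tokenizer (an assumption for the demo;
# spaCy's rule-based tokenizer splits differently, e.g. it separates a leading "$").
tokens = [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]

amount_start = text.index("600,000")
amount = contract_span(tokens, amount_start, amount_start + len("600,000"))
print(amount)  # None -> the annotation would be dropped

biz_start = text.index("Stonetek")
biz = contract_span(tokens, biz_start, biz_start + len("Stonetek Global Corp"))
print(text[biz[0]:biz[1]])  # Stonetek Global
```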





## 4. Train Custom NER Model

Train the spaCy NER model using the prepared data. This will take several minutes.

**Note:** Training uses the `config.cfg` file in the `config/` directory. You can adjust hyperparameters there if needed.
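The notebook assumes `config/config.cfg` already exists. If it needs to be recreated, spaCy's `init config` command generates a starting config for a given language and pipeline. The sketch below only assembles that command (`init_config_cmd` is a hypothetical helper); whether the notebook's own config was produced this way is not stated:

```python
import sys
from pathlib import Path
from typing import List


def init_config_cmd(config_path: Path, optimize: str = "efficiency") -> List[str]:
    """Build a `spacy init config` command for an English, NER-only pipeline.

    Run the result with subprocess.run(cmd, check=True).
    """
    if optimize not in ("efficiency", "accuracy"):
        raise ValueError("optimize must be 'efficiency' or 'accuracy'")
    return [
        sys.executable, "-m", "spacy", "init", "config", str(config_path),
        "--lang", "en",
        "--pipeline", "ner",
        "--optimize", optimize,
        "--force",  # overwrite an existing file at config_path
    ]


cmd = init_config_cmd(Path("config/config.cfg"))
print(" ".join(cmd[1:]))
```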

In [8]:
# Train model using spaCy CLI
import subprocess
import sys

config_path = Path("config/config.cfg")
output_path = Path("models/ner_model")

print("🚀 Starting model training...")
print(f"   Config: {config_path}")
print(f"   Output: {output_path}")
print(f"   Training data: {output_dir / 'train.spacy'}")
print(f"   Dev data: {output_dir / 'dev.spacy'}")
print("\n⏳ This will take 10-20 minutes...\n")

# Run training command
cmd = [
    sys.executable, "-m", "spacy", "train",
    str(config_path),
    "--output", str(output_path),
    "--paths.train", str(output_dir / "train.spacy"),
    "--paths.dev", str(output_dir / "dev.spacy")
]

try:
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    print(result.stdout)
    print("\n✅ Training completed successfully!")
except subprocess.CalledProcessError as e:
    print(f"❌ Training failed:\n{e.stderr}")
    raise

🚀 Starting model training...
   Config: config\config.cfg
   Output: models\ner_model
   Training data: data\processed\train.spacy
   Dev data: data\processed\dev.spacy

⏳ This will take 10-20 minutes...

ℹ Saving to output directory: models\ner_model
ℹ Using CPU
✔ Initialized pipeline
ℹ Pipeline: ['tok2vec', 'ner']
ℹ Initial learn rate: 0.001
E    #       LOSS TOK2VEC  LOSS NER  ENTS_F  ENTS_P  ENTS_R  SCORE 
---  ------  ------------  --------  ------  ------  ------  ------
  0       0          0.00     95.35    0.00    0.00    0.00    0.00
  0     200        241.88   2115.45   34.00   44.74   27.42    0.34
  1     400        675.36   1226.75   64.83   75.18   56.99    0.65
  2     600        361.50    924.36   66.09   71.70   61.29    0.66
  3     800        669.79    863.67   68.92   80.58   60.22    0.69
  5    1000        464.40    782.90   67.46   75.84   60.75    0.67
  7    1200        9
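The training log above stops mid-row. Because `subprocess.run(..., capture_output=True)` holds all output until the process exits, nothing is shown while training runs, and an interrupted run can lose the log entirely. A sketch of a streaming alternative built on `subprocess.Popen` (`run_streaming` is a hypothetical helper):

```python
import subprocess
import sys
from typing import List, Tuple


def run_streaming(cmd: List[str]) -> Tuple[int, List[str]]:
    """Run cmd, echoing merged stdout/stderr line by line as it arrives.

    Returns (return code, collected lines without trailing newlines).
    """
    lines: List[str] = []
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr so error messages are not lost
        text=True,
        bufsize=1,  # line-buffered on the reading side
    ) as proc:
        assert proc.stdout is not None
        for line in proc.stdout:
            print(line, end="")
            lines.append(line.rstrip("\n"))
        returncode = proc.wait()
    return returncode, lines


code, log = run_streaming([sys.executable, "-c", "print('epoch 0'); print('epoch 1')"])
```

It could replace the `subprocess.run` call as `returncode, log = run_streaming(cmd)`. A Python child process may still buffer its own output when writing to a pipe; adding `-u` right after `sys.executable` in `cmd` disables that buffering.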

## 5. Load Trained Model and Evaluate Performance

Load the best checkpoint (`model-best`, the one with the highest dev-set score during training) and evaluate it on the held-out test set.

In [9]:
# Load the best trained model
model_path = output_path / "model-best"
print(f"📂 Loading trained model from: {model_path}")

trained_nlp = spacy.load(model_path)
print(f"✅ Model loaded successfully!")
print(f"   Pipeline components: {trained_nlp.pipe_names}")
print(f"   Entity labels: {trained_nlp.get_pipe('ner').labels}")

📂 Loading trained model from: models\ner_model\model-best
✅ Model loaded successfully!
   Pipeline components: ['tok2vec', 'ner']
   Entity labels: ('BUSINESS', 'DEFENDANT', 'FRAUD AMOUNT', 'FRAUD MECHANISM', 'GOV PROGRAM', 'JUDGE', 'PROSECUTOR', 'SENTENCE')


### 5.1 Calculate Detailed Metrics on Test Set

In [10]:
def evaluate_model(nlp, test_data: List[Tuple[str, Dict]]) -> Tuple[Dict, List[Dict]]:
    """Evaluate with exact (start, end, label) matching; return per-label metrics and raw predictions."""
    
    # Track predictions per entity type
    tp = defaultdict(int)  # True positives
    fp = defaultdict(int)  # False positives
    fn = defaultdict(int)  # False negatives
    
    all_predictions = []
    
    for text, annotations in tqdm(test_data, desc="Evaluating on test set"):
        # Get ground truth entities
        gold_ents = set()
        for start, end, label in annotations["entities"]:
            gold_ents.add((start, end, label))
        
        # Get predicted entities
        doc = nlp(text)
        pred_ents = set()
        for ent in doc.ents:
            pred_ents.add((ent.start_char, ent.end_char, ent.label_))
        
        # Store for confusion matrix
        all_predictions.append({
            'text': text,
            'gold': gold_ents,
            'pred': pred_ents
        })
        
        # Calculate metrics per entity
        for ent in pred_ents:
            label = ent[2]
            if ent in gold_ents:
                tp[label] += 1
            else:
                fp[label] += 1
        
        for ent in gold_ents:
            label = ent[2]
            if ent not in pred_ents:
                fn[label] += 1
    
    # Calculate precision, recall, F1 per entity
    metrics = {}
    all_labels = set(list(tp.keys()) + list(fp.keys()) + list(fn.keys()))
    
    for label in all_labels:
        precision = tp[label] / (tp[label] + fp[label]) if (tp[label] + fp[label]) > 0 else 0
        recall = tp[label] / (tp[label] + fn[label]) if (tp[label] + fn[label]) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        
        metrics[label] = {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'tp': tp[label],
            'fp': fp[label],
            'fn': fn[label],
            'support': tp[label] + fn[label]
        }
    
    # Calculate overall metrics
    total_tp = sum(tp.values())
    total_fp = sum(fp.values())
    total_fn = sum(fn.values())
    
    overall_precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0
    overall_recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0
    overall_f1 = 2 * overall_precision * overall_recall / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0
    
    metrics['OVERALL'] = {
        'precision': overall_precision,
        'recall': overall_recall,
        'f1': overall_f1,
        'tp': total_tp,
        'fp': total_fp,
        'fn': total_fn,
        'support': total_tp + total_fn
    }
    
    return metrics, all_predictions

# Run evaluation
print("🔍 Evaluating model on test set...\n")
metrics, predictions = evaluate_model(trained_nlp, test_data)

# Display results
print("=" * 80)
print("MODEL EVALUATION RESULTS")
print("=" * 80)

# Create DataFrame for better display
metrics_data = []
for label, values in sorted(metrics.items()):
    metrics_data.append({
        'Entity Type': label,
        'Precision': f"{values['precision']:.3f}",
        'Recall': f"{values['recall']:.3f}",
        'F1-Score': f"{values['f1']:.3f}",
        'Support': values['support']
    })

metrics_df = pd.DataFrame(metrics_data)
print(metrics_df.to_string(index=False))
print("=" * 80)

🔍 Evaluating model on test set...

Evaluating on test set: 100%|██████████| 82/82 [00:00<00:00, 195.71it/s]

MODEL EVALUATION RESULTS
    Entity Type Precision Recall F1-Score  Support
       BUSINESS     0.444  0.235    0.308       17
      DEFENDANT     0.810  0.810    0.810      116
   FRAUD AMOUNT     1.000  0.111    0.200        9
FRAUD MECHANISM     0.154  0.095    0.118       21
    GOV PROGRAM     0.333  0.455    0.385       11
          JUDGE     0.769  1.000    0.870       10
        OVERALL     0.718  0.658    0.687      225
     PROSECUTOR     0.793  0.793    0.793       29
       SENTENCE     0.900  0.750    0.818       12
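The OVERALL row pools tp/fp/fn across labels (micro-averaging), so it is dominated by DEFENDANT, which accounts for 116 of the 225 test entities. Averaging the per-type F1 scores in the table instead (macro-averaging) gives roughly 0.54, a figure that weights the weak rare types equally. A minimal sketch of both averages on toy counts; `prf` and `micro_macro_f1` are hypothetical helpers:

```python
from typing import Dict, Tuple

Counts = Dict[str, Tuple[int, int, int]]  # label -> (tp, fp, fn)


def prf(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Precision, recall, F1 with the same zero-division guards as evaluate_model."""
    p = tp / (tp + fp) if tp + fp else 0.0
    r = tp / (tp + fn) if tp + fn else 0.0
    f = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f


def micro_macro_f1(counts: Counts) -> Tuple[float, float]:
    """Micro-F1 pools counts over labels; macro-F1 averages per-label F1."""
    tp = sum(c[0] for c in counts.values())
    fp = sum(c[1] for c in counts.values())
    fn = sum(c[2] for c in counts.values())
    micro = prf(tp, fp, fn)[2]
    macro = sum(prf(*c)[2] for c in counts.values()) / len(counts)
    return micro, macro


# Toy counts: one frequent, well-learned label and one rare, poorly recalled one.
toy = {"DEFENDANT": (8, 2, 2), "FRAUD AMOUNT": (1, 0, 9)}
micro, macro = micro_macro_f1(toy)
print(f"micro-F1={micro:.3f}  macro-F1={macro:.3f}")  # micro-F1=0.581  macro-F1=0.491
```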





### 5.2 Visualize Performance Metrics

In [11]:
# Prepare data for visualization (exclude OVERALL)
viz_metrics = {k: v for k, v in metrics.items() if k != 'OVERALL'}

# Create comparison chart for Precision, Recall, F1
entity_types = list(viz_metrics.keys())
precision_scores = [viz_metrics[et]['precision'] for et in entity_types]
recall_scores = [viz_metrics[et]['recall'] for et in entity_types]
f1_scores = [viz_metrics[et]['f1'] for et in entity_types]

fig = go.Figure()
fig.add_trace(go.Bar(name='Precision', x=entity_types, y=precision_scores, marker_color='#1f77b4'))
fig.add_trace(go.Bar(name='Recall', x=entity_types, y=recall_scores, marker_color='#ff7f0e'))
fig.add_trace(go.Bar(name='F1-Score', x=entity_types, y=f1_scores, marker_color='#2ca02c'))

fig.update_layout(
    title='Model Performance by Entity Type',
    xaxis_title='Entity Type',
    yaxis_title='Score',
    barmode='group',
    height=500,
    yaxis_range=[0, 1.1]
)
fig.show()

# Create support chart (number of entities per type)
support_data = [(et, viz_metrics[et]['support']) for et in entity_types]
support_df = pd.DataFrame(support_data, columns=['Entity Type', 'Support'])

fig2 = px.bar(
    support_df,
    x='Entity Type',
    y='Support',
    title='Test Set Entity Distribution (Support)',
    text='Support',
    color='Support',
    color_continuous_scale='viridis'
)
fig2.update_traces(textposition='outside')
fig2.show()

# Create radar chart for overall model performance
categories = ['Precision', 'Recall', 'F1-Score']
overall_values = [
    metrics['OVERALL']['precision'],
    metrics['OVERALL']['recall'],
    metrics['OVERALL']['f1']
]

fig3 = go.Figure()
fig3.add_trace(go.Scatterpolar(
    r=overall_values,
    theta=categories,
    fill='toself',
    name='Overall Performance'
))
fig3.update_layout(
    polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
    title='Overall Model Performance Radar Chart',
    height=500
)
fig3.show()

### 5.3 Detailed Error Analysis

Examine specific predictions to understand model strengths and weaknesses.
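Exact matching counts a prediction with a slightly wrong boundary as both a false positive and a false negative. A sketch that sorts errors into buckets using the `(start, end, label)` tuples stored in `predictions`; `classify_errors` and the bucket names are our own, not a spaCy API:

```python
from typing import Dict, Set, Tuple

Ent = Tuple[int, int, str]  # (start_char, end_char, label), as in evaluate_model


def overlaps(a: Ent, b: Ent) -> bool:
    return a[0] < b[1] and b[0] < a[1]


def classify_errors(gold: Set[Ent], pred: Set[Ent]) -> Dict[str, int]:
    """Bucket each prediction and each gold span that no prediction touches.

    exact    - same offsets and label
    label    - same offsets, different label
    boundary - same label, overlapping but different offsets
    other    - overlaps gold but differs in both offsets and label
    spurious - prediction overlapping no gold span
    missed   - gold span overlapped by no prediction
    """
    buckets = {k: 0 for k in ("exact", "label", "boundary", "spurious", "other", "missed")}
    for p in pred:
        hits = [g for g in gold if overlaps(p, g)]
        if p in gold:
            buckets["exact"] += 1
        elif any(g[:2] == p[:2] for g in hits):
            buckets["label"] += 1
        elif any(g[2] == p[2] for g in hits):
            buckets["boundary"] += 1
        elif hits:
            buckets["other"] += 1
        else:
            buckets["spurious"] += 1
    buckets["missed"] = sum(1 for g in gold if not any(overlaps(g, p) for p in pred))
    return buckets


gold = {(0, 10, "DEFENDANT"), (20, 30, "JUDGE"), (40, 50, "BUSINESS"), (60, 70, "FRAUD AMOUNT")}
pred = {(0, 10, "DEFENDANT"), (20, 30, "PROSECUTOR"), (40, 45, "BUSINESS"), (80, 90, "SENTENCE")}
print(classify_errors(gold, pred))
# {'exact': 1, 'label': 1, 'boundary': 1, 'spurious': 1, 'other': 0, 'missed': 1}
```

Totals over the test set follow from a loop of `totals.update(classify_errors(p["gold"], p["pred"]))` with `totals = Counter()`. A large `boundary` bucket would point toward inconsistent span edges in the annotations rather than entities the model cannot recognize.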

In [12]:
# Find examples with errors
errors = []
perfect = []

for pred in predictions:
    gold = pred['gold']
    predicted = pred['pred']
    
    if gold == predicted:
        perfect.append(pred)
    else:
        # Calculate error metrics
        missed = gold - predicted  # False negatives
        incorrect = predicted - gold  # False positives
        
        if missed or incorrect:
            errors.append({
                'text': pred['text'][:200] + "..." if len(pred['text']) > 200 else pred['text'],
                'missed': missed,
                'incorrect': incorrect,
                'gold_count': len(gold),
                'pred_count': len(predicted)
            })

print(f"📊 Error Analysis Summary:")
print(f"   Perfect predictions: {len(perfect)}/{len(predictions)} ({len(perfect)/len(predictions)*100:.1f}%)")
print(f"   Predictions with errors: {len(errors)}/{len(predictions)} ({len(errors)/len(predictions)*100:.1f}%)")

# Show sample errors
print(f"\n❌ Sample Errors (showing first 3):")
for i, error in enumerate(errors[:3], 1):
    print(f"\n{i}. Text: {error['text']}")
    print(f"   Expected {error['gold_count']} entities, predicted {error['pred_count']}")
    
    if error['missed']:
        print(f"   Missed (False Negatives):")
        for start, end, label in list(error['missed'])[:3]:
            print(f"      - {label}")
    
    if error['incorrect']:
        print(f"   Incorrect (False Positives):")
        for start, end, label in list(error['incorrect'])[:3]:
            print(f"      - {label}")

# Show sample perfect predictions
print(f"\n✅ Sample Perfect Predictions (showing first 2):")
for i, pred in enumerate(perfect[:2], 1):
    print(f"\n{i}. Text: {pred['text'][:150]}...")
    print(f"   Entities found: {len(pred['gold'])}")
    for start, end, label in list(pred['gold'])[:3]:
        print(f"      - {label}")

📊 Error Analysis Summary:
   Perfect predictions: 30/82 (36.6%)
   Predictions with errors: 52/82 (63.4%)

❌ Sample Errors (showing first 3):

1. Text: According to court records, filed plea documents, and court proceedings, from April 2020 to November 2021, Blackmon executed a scheme to defraud the U.S. Small Business Administration (SBA) and SBA-ba...
   Expected 8 entities, predicted 7
   Missed (False Negatives):
      - BUSINESS
      - FRAUD AMOUNT
      - BUSINESS
   Incorrect (False Positives):
      - DEFENDANT
      - GOV PROGRAM
      - DEFENDANT

2. Text: Henson faces a maximum possible sentence of 30 years in federal prison for each count of Bank Fraud, and a maximum possible sentence of 5 years in prison for each count of False Statements. U.S. Distr...
   Expected 3 entities, predicted 3
   Missed (False Negatives):
      - FRAUD MECHANISM
      - FRAUD MECHANISM
   Incorrect (False Positives):
      - JUDGE
      - BUSINESS

3. Text: Exum and Wandland conspired to 

## 6. Interactive Demo: Test on Custom Text

Try the model on a short sample paragraph written in the style of a DOJ press release. Replace `sample_text` in the cell below to extract entities from other text.

In [13]:
# Sample DOJ press release text (you can replace this with your own)
sample_text = """
John Smith was sentenced to 5 years in prison by Judge Mary Johnson in the 
United States District Court for the Southern District of New York. Assistant 
U.S. Attorney Robert Davis prosecuted the case. Smith defrauded Medicare of 
approximately $2.5 million through a fraudulent billing scheme operated by 
his company, ABC Medical Services Inc.
"""

def extract_and_display_entities(text: str):
    """Extract entities and display with highlighting."""
    doc = trained_nlp(text)
    
    print("🔍 ENTITY EXTRACTION RESULTS")
    print("=" * 80)
    print(f"Text: {text.strip()}\n")
    print("=" * 80)
    
    if not doc.ents:
        print("No entities found.")
        return {}  # same return type as the normal path
    
    # Group by entity type
    entities_by_type = defaultdict(list)
    for ent in doc.ents:
        entities_by_type[ent.label_].append(ent.text)
    
    print("\n📋 Extracted Entities:")
    for label in sorted(entities_by_type.keys()):
        print(f"\n{label}:")
        for entity in entities_by_type[label]:
            print(f"  • {entity}")
    
    print("\n" + "=" * 80)
    
    # Create visualization with spaCy's displaCy
    from spacy import displacy
    
    colors = {
        "DEFENDANT": "#ff9999",
        "PROSECUTOR": "#99ccff",
        "JUDGE": "#ffcc99",
        "SENTENCE": "#cc99ff",
        "FRAUD MECHANISM": "#ffff99",
        "FRAUD AMOUNT": "#99ff99",
        "GOV PROGRAM": "#ff99cc",
        "BUSINESS": "#99ffff"
    }
    
    options = {"colors": colors}
    
    # Display in notebook
    html = displacy.render(doc, style="ent", options=options, jupyter=False)
    from IPython.display import HTML, display
    display(HTML(html))
    
    return entities_by_type

# Run extraction
entities = extract_and_display_entities(sample_text)

🔍 ENTITY EXTRACTION RESULTS
Text: John Smith was sentenced to 5 years in prison by Judge Mary Johnson in the 
United States District Court for the Southern District of New York. Assistant 
U.S. Attorney Robert Davis prosecuted the case. Smith defrauded Medicare of 
approximately $2.5 million through a fraudulent billing scheme operated by 
his company, ABC Medical Services Inc.


📋 Extracted Entities:

DEFENDANT:
  • Smith
  • Smith

FRAUD AMOUNT:
  • $2.5 million

PROSECUTOR:
  • Robert Davis

SENTENCE:
  • 5 years
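The demo lists `Smith` twice because each mention is a separate span. If a per-document summary should list each name once, an order-preserving de-duplication is enough. A sketch with a hypothetical helper that takes `(text, label)` pairs, e.g. `group_unique((ent.text, ent.label_) for ent in doc.ents)`:

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def group_unique(ents: Iterable[Tuple[str, str]]) -> Dict[str, List[str]]:
    """Group (text, label) pairs by label, keeping first-seen order and dropping repeats."""
    grouped: Dict[str, Dict[str, None]] = defaultdict(dict)
    for text, label in ents:
        grouped[label].setdefault(text, None)  # dicts preserve insertion order
    return {label: list(texts) for label, texts in grouped.items()}


mentions = [("Smith", "DEFENDANT"), ("Smith", "DEFENDANT"), ("Robert Davis", "PROSECUTOR")]
print(group_unique(mentions))  # {'DEFENDANT': ['Smith'], 'PROSECUTOR': ['Robert Davis']}
```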



### 6.1 Try Your Own Text

Edit the text below and run the cell to extract entities from your own press release!

In [14]:
# YOUR CUSTOM TEXT HERE - Replace with any DOJ press release text
custom_text = """
Paste your own DOJ press release text here to test the model!
"""

# Extract entities from your custom text
if custom_text.strip() and "Paste your own" not in custom_text:
    extract_and_display_entities(custom_text)
else:
    print("⚠️ Please replace the placeholder text with your own DOJ press release text.")

⚠️ Please replace the placeholder text with your own DOJ press release text.


## 7. Batch Processing: Process Multiple Press Releases

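Run the model over every record in a JSONL file. Here the input is the annotated file `data/raw/2025_11_27.jsonl`, whose records are individual press-release paragraphs.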

In [15]:
# Process a batch of press releases
def batch_process(input_file: Path, nlp) -> pd.DataFrame:
    """Process multiple press releases and return results as DataFrame."""
    
    results = []
    
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Processing press releases"):
            item = json.loads(line)
            text = item.get('text', '')
            
            if not text:
                continue
            
            # Extract entities
            doc = nlp(text)
            
            # Count entities by type
            entity_counts = Counter()
            all_entities = defaultdict(list)
            
            for ent in doc.ents:
                entity_counts[ent.label_] += 1
                all_entities[ent.label_].append(ent.text)
            
            result = {
                'text_preview': text[:100] + "...",
                'text_length': len(text),
                'total_entities': len(doc.ents),
                **{f'{label}_count': entity_counts.get(label, 0) 
                   for label in ['DEFENDANT', 'PROSECUTOR', 'JUDGE', 'SENTENCE', 
                                 'FRAUD MECHANISM', 'FRAUD AMOUNT', 'GOV PROGRAM', 'BUSINESS']},
                'entities': dict(all_entities)
            }
            results.append(result)
    
    return pd.DataFrame(results)

# Example: run the model over every record in the annotated source file.
# This includes the training and dev examples, so treat the counts as a demo, not an evaluation.
print("📊 Processing all annotated records...")
test_results_df = batch_process(Path("data/raw/2025_11_27.jsonl"), trained_nlp)

print(f"\n✅ Processed {len(test_results_df)} press releases")
print(f"\nSample results:")
print(test_results_df[['text_preview', 'total_entities', 'DEFENDANT_count', 'PROSECUTOR_count']].head())

📊 Processing all annotated records...


Processing press releases: 1013it [00:04, 214.84it/s]




✅ Processed 1013 press releases

Sample results:
                                        text_preview  total_entities  \
0  • Five defendants indicted for operating shell...               0   
1  • Concrete company owner pleads guilty to harb...               0   
2  The U.S. Attorney’s Office has charged more th...               0   
3  • Administrator of webhosting domain indicted ...               0   
4  • Defendant indicted for a $70 million Medicar...               2   

   DEFENDANT_count  PROSECUTOR_count  
0                0                 0  
1                0                 0  
2                0                 0  
3                0                 0  
4                0                 0  
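The `entities` column holds one dict per row, which is awkward to filter or count. Reshaping it to one row per entity ("long" format) turns questions such as "which businesses are mentioned most often" into a single groupby. A sketch with a hypothetical helper that expects rows shaped like those `batch_process` builds:

```python
from typing import Dict, Iterable, List


def to_long_rows(results: Iterable[Dict]) -> List[Dict]:
    """Flatten {'entities': {label: [texts]}} rows into one dict per entity mention."""
    rows = []
    for doc_index, result in enumerate(results):
        for label, texts in result.get("entities", {}).items():
            for text in texts:
                rows.append({"doc_index": doc_index, "label": label, "text": text})
    return rows


sample = [
    {"entities": {"DEFENDANT": ["Smith", "Smith"], "SENTENCE": ["5 years"]}},
    {"entities": {}},
]
for row in to_long_rows(sample):
    print(row)
```

With pandas, `long_df = pd.DataFrame(to_long_rows(test_results_df.to_dict("records")))` followed by `long_df.groupby("label")["text"].value_counts()` lists the most frequent strings per entity type.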


### 7.1 Visualize Batch Processing Results

In [16]:
# Visualize entity extraction across all processed documents

# Calculate total entities found per type
entity_cols = [col for col in test_results_df.columns if col.endswith('_count')]
entity_totals = test_results_df[entity_cols].sum()

entity_summary = pd.DataFrame({
    'Entity Type': [col.replace('_count', '') for col in entity_cols],
    'Total Found': entity_totals.values
})

# Create bar chart
fig = px.bar(
    entity_summary,
    x='Entity Type',
    y='Total Found',
    title='Total Entities Extracted Across All Documents',
    color='Total Found',
    color_continuous_scale='plasma',
    text='Total Found'
)
fig.update_traces(textposition='outside')
fig.update_layout(height=500)
fig.show()

# Distribution of entities per document
fig2 = px.histogram(
    test_results_df,
    x='total_entities',
    title='Distribution of Total Entities per Document',
    labels={'total_entities': 'Number of Entities', 'count': 'Number of Documents'},
    nbins=30
)
fig2.show()

# Correlation heatmap between entity types
entity_corr = test_results_df[entity_cols].corr()
entity_corr.index = [col.replace('_count', '') for col in entity_cols]
entity_corr.columns = [col.replace('_count', '') for col in entity_cols]

fig3 = px.imshow(
    entity_corr,
    title='Entity Co-occurrence Correlation Matrix',
    color_continuous_scale='RdBu',
    aspect='auto',
    text_auto='.2f'
)
fig3.update_layout(height=600)
fig3.show()

## 8. Summary and Next Steps

### 🎉 Congratulations!

You've successfully:
1. ✅ Loaded and explored Prodigy-annotated DOJ press release data
2. ✅ Converted data to spaCy format and created train/dev/test splits
3. ✅ Trained a custom NER model with 8 entity types
4. ✅ Evaluated model performance with detailed metrics and visualizations
5. ✅ Tested the model interactively on custom text
6. ✅ Performed batch processing on multiple documents

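### 📊 Model Performance Summary

The model is trained to tag the following entity types in DOJ press releases. On the held-out test set, per-type F1 ranges from 0.118 (FRAUD MECHANISM) to 0.870 (JUDGE), so the weaker types likely need more work before their output can be relied on: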
- **DEFENDANT**: Individuals or entities charged with crimes
- **PROSECUTOR**: U.S. Attorneys and prosecutors handling cases
- **JUDGE**: Federal judges presiding over cases
- **SENTENCE**: Prison terms, fines, and other sentencing details
- **FRAUD MECHANISM**: Methods used to commit fraud
- **FRAUD AMOUNT**: Dollar amounts involved in fraud cases
- **GOV PROGRAM**: Government programs targeted (Medicare, Medicaid, etc.)
- **BUSINESS**: Companies and business entities involved

### 🚀 Next Steps

1. **Improve Performance**: 
   - Add more training data
   - Adjust hyperparameters in `config.cfg`
   - Try different spaCy architectures (transformers)

2. **Deploy Model**:
   - Package the saved model for production use (for example with `spacy package`)
   - Create API endpoint for real-time extraction
   - Build web interface for non-technical users

3. **Extend Functionality**:
   - Add relationship extraction between entities
   - Implement document classification
   - Build time-series analysis of DOJ cases

### 💾 Model Location

Your trained model is saved at: `models/ner_model/model-best/`

You can load it anytime with:
```python
import spacy

nlp = spacy.load("models/ner_model/model-best")
```