# Notebook 5: Token/NER Classification for Contract Entities

**Objective**: Ensure accurate entity extraction from contracts for automated validation:
- Detect token-level label issues (wrong entity type)
- Identify entity boundary errors
- Validate extracted amounts, dates, and parties

---

## Flow Diagram

```mermaid
flowchart TD
    subgraph Input["📄 Contract Document"]
        A[Contract Text]
        B[Tokenized Sequence]
    end

    subgraph NER["🏷️ Entity Extraction"]
        C[NER Model]
        D[Token Predictions]
        E[Entity Spans]
        F[Prediction Probabilities]
    end

    subgraph Cleanlab["🧹 Cleanlab Token Classification"]
        G[find_label_issues - token]
        H[get_label_quality_scores - token]
        I[Per-Token Quality]
        J[Entity-Level Quality]
    end

    subgraph Quality["🔍 Entity Quality Analysis"]
        K{Token Issue Type?}
        L[Wrong Entity Type → N High]
        M[Boundary Error → S Moderate]
        N[Low Confidence → N/S]
        O[Clean Extraction → R High]
    end

    subgraph Critical["⚠️ Critical Entity Check"]
        P{Entity Type?}
        Q[AMOUNT → Verify value]
        R[DATE → Verify format]
        S[PARTY → Verify name]
        T[POISONING if critical wrong]
    end

    subgraph Routing["🚦 Temperature Routing"]
        U[Entity-Level τ = 1/R]
        V{Trust Extraction?}
        W[🟢 GREEN: Use extracted values]
        X[🟡 YELLOW: Human verify]
        Y[🔴 RED: Manual extraction]
    end

    A --> B
    B --> C
    C --> D
    C --> E
    C --> F
    D --> G
    F --> G
    F --> H
    G --> I
    H --> I
    I --> J
    J --> K
    K --> L
    K --> M
    K --> N
    K --> O
    E --> P
    P --> Q
    P --> R
    P --> S
    L --> T
    Q --> T
    R --> T
    S --> T
    O --> U
    M --> U
    N --> U
    T --> U
    U --> V
    V -->|High R| W
    V -->|Medium R| X
    V -->|Low R or Critical| Y

    style NER fill:#e1f5fe
    style Cleanlab fill:#fff3e0
    style Critical fill:#ffebee
    style Routing fill:#e8f5e9
```

---

**Collapse Type Focus**: POISONING (wrong amount/date extracted)

**Difficulty**: ⭐⭐⭐ Hard

## 1. Setup

In [None]:
# Install dependencies
!pip install cleanlab scikit-learn pandas numpy --quiet

In [None]:
import numpy as np
import pandas as pd
from typing import List, Dict, Tuple

# Cleanlab token classification
from cleanlab.token_classification import (
    filter as token_filter,
    rank as token_rank
)
from cleanlab.token_classification.summary import display_issues

# Import YRSN adapter
import sys
sys.path.append('../src')
from yrsn_iars.adapters.cleanlab_adapter import CleanlabAdapter, YRSNResult, CollapseType
from yrsn_iars.adapters.temperature import compute_temperature

print("Dependencies loaded successfully")

## 2. Define Entity Schema

Contract entities we need to extract and validate.

In [None]:
# Entity types for contract analysis
ENTITY_TYPES = {
    'O': 0,        # Outside any entity
    'B-AMOUNT': 1, # Beginning of amount
    'I-AMOUNT': 2, # Inside amount
    'B-DATE': 3,   # Beginning of date
    'I-DATE': 4,   # Inside date
    'B-PARTY': 5,  # Beginning of party name
    'I-PARTY': 6,  # Inside party name
    'B-TERM': 7,   # Beginning of term/duration
    'I-TERM': 8,   # Inside term
}

ID_TO_ENTITY = {v: k for k, v in ENTITY_TYPES.items()}
NUM_CLASSES = len(ENTITY_TYPES)

# Critical entities (errors here are POISONING)
CRITICAL_ENTITIES = ['B-AMOUNT', 'I-AMOUNT', 'B-DATE', 'I-DATE']

print(f"Entity types: {list(ENTITY_TYPES.keys())}")
print(f"Critical entities: {CRITICAL_ENTITIES}")
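
The BIO scheme above implies a consistency rule: an `I-X` tag should only continue an entity that was opened by `B-X` (or extended by another `I-X`) of the same type. The following is a minimal sketch of that check. The helper name `find_bio_violations` is hypothetical and is not part of this notebook or of Cleanlab.

```python
from typing import List

def find_bio_violations(tags: List[str]) -> List[int]:
    """Return indices where an I- tag does not continue an entity of the same type."""
    violations = []
    prev = 'O'
    for i, tag in enumerate(tags):
        if tag.startswith('I-'):
            entity_type = tag[2:]
            # Valid predecessors are B-X or I-X with the same type X
            if prev not in (f'B-{entity_type}', f'I-{entity_type}'):
                violations.append(i)
        prev = tag
    return violations

example_tags = ['B-PARTY', 'I-PARTY', 'O', 'I-AMOUNT', 'B-DATE', 'I-TERM']
bio_violations = find_bio_violations(example_tags)
print(bio_violations)  # [3, 5]
```

Index 3 is flagged because `I-AMOUNT` follows `O`, and index 5 because `I-TERM` follows a `DATE` entity. A check like this can be run on both gold labels and model predictions before any quality scoring.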

## 3. Generate Synthetic Contract Data

In [None]:
np.random.seed(42)

def generate_contract_sentence():
    """Generate a synthetic contract sentence with entities.

    Returns the filled-in text and a list of (token_index, entity_type) pairs.
    Each token_index is the position of the entity's first token in text.split(),
    computed after the template is filled so that multi-token values
    (e.g. "Acme Corp") do not shift later entities out of alignment.
    """
    templates = [
        "Agreement between {party1} and {party2} dated {date} for {amount} .",
        "{party1} agrees to pay {amount} to {party2} by {date} .",
        "Contract term : {term} starting {date} with value {amount} .",
        "Payment of {amount} due on {date} from {party1} .",
    ]
    slot_types = {'party1': 'PARTY', 'party2': 'PARTY', 'amount': 'AMOUNT',
                  'date': 'DATE', 'term': 'TERM'}
    
    parties = ["Acme Corp", "Tech Inc", "Global LLC", "Alpha Partners", "Beta Services"]
    amounts = ["$50,000", "$100,000", "$250,000", "$1,000,000", "$75,500"]
    dates = ["January 15, 2024", "March 1, 2024", "December 31, 2023", "July 20, 2024"]
    terms = ["12 months", "24 months", "36 months", "6 months"]
    
    template = templates[np.random.randint(len(templates))]
    values = {
        'party1': np.random.choice(parties),
        'party2': np.random.choice(parties),
        'amount': np.random.choice(amounts),
        'date': np.random.choice(dates),
        'term': np.random.choice(terms),
    }
    
    # Fill the template token by token, recording where each entity starts
    out_tokens, entity_positions = [], []
    for piece in template.split():
        slot = piece.strip('{}')
        if piece.startswith('{') and slot in slot_types:
            entity_positions.append((len(out_tokens), slot_types[slot]))
            out_tokens.extend(str(values[slot]).split())
        else:
            out_tokens.append(piece)
    
    return ' '.join(out_tokens), entity_positions

# Generate synthetic dataset
n_sentences = 100
sentences_data = []

for i in range(n_sentences):
    text, entity_pos = generate_contract_sentence()
    tokens = text.split()
    
    # Create labels
    labels = [0] * len(tokens)  # Default: O
    for pos, entity_type in entity_pos:
        if pos < len(tokens):
            # Handle multi-token entities
            labels[pos] = ENTITY_TYPES[f'B-{entity_type}']
            # Check if next token is also part of entity (simplified)
            if pos + 1 < len(tokens) and tokens[pos + 1] not in ['.', ',', 'and', 'to', 'by', 'from', 'for', 'with']:
                if np.random.random() > 0.5:  # 50% chance of multi-token
                    labels[pos + 1] = ENTITY_TYPES[f'I-{entity_type}']
    
    sentences_data.append({
        'sentence_id': i,
        'text': text,
        'tokens': tokens,
        'labels': labels,
        'n_tokens': len(tokens)
    })

sentences_df = pd.DataFrame(sentences_data)
print(f"Generated {len(sentences_df)} contract sentences")
print(f"\nSample:")
print(f"Text: {sentences_df.iloc[0]['text']}")
print(f"Tokens: {sentences_df.iloc[0]['tokens']}")
print(f"Labels: {[ID_TO_ENTITY[l] for l in sentences_df.iloc[0]['labels']]}")

## 4. Simulate NER Predictions with Errors

In [None]:
def simulate_ner_predictions(tokens: List[str], true_labels: List[int], 
                              error_rate: float = 0.15) -> Tuple[List[int], np.ndarray]:
    """
    Simulate NER model predictions with controlled error rate.
    Returns predicted labels and prediction probabilities.
    """
    n_tokens = len(tokens)
    pred_labels = true_labels.copy()
    pred_probs = np.zeros((n_tokens, NUM_CLASSES))
    
    for i in range(n_tokens):
        true_label = true_labels[i]
        
        if np.random.random() < error_rate:
            # Introduce error
            error_type = np.random.choice(['wrong_type', 'boundary', 'miss'])
            
            if error_type == 'wrong_type' and true_label > 0:
                # Predict wrong entity type
                wrong_labels = [l for l in range(NUM_CLASSES) if l != true_label and l != 0]
                pred_labels[i] = np.random.choice(wrong_labels)
                # Lower confidence
                pred_probs[i, pred_labels[i]] = 0.4 + np.random.random() * 0.3
                pred_probs[i, true_label] = 0.2 + np.random.random() * 0.2
                
            elif error_type == 'boundary':
                # Boundary error: B/I confusion
                if ID_TO_ENTITY[true_label].startswith('B-'):
                    pred_labels[i] = ENTITY_TYPES[ID_TO_ENTITY[true_label].replace('B-', 'I-')]
                elif ID_TO_ENTITY[true_label].startswith('I-'):
                    pred_labels[i] = ENTITY_TYPES[ID_TO_ENTITY[true_label].replace('I-', 'B-')]
                pred_probs[i, pred_labels[i]] = 0.5 + np.random.random() * 0.3
                pred_probs[i, true_label] = 0.3 + np.random.random() * 0.2
                
            else:  # miss
                # Miss entity entirely (predict O)
                pred_labels[i] = 0
                pred_probs[i, 0] = 0.5 + np.random.random() * 0.3
                pred_probs[i, true_label] = 0.2 + np.random.random() * 0.2
        else:
            # Correct prediction
            pred_labels[i] = true_label
            pred_probs[i, true_label] = 0.85 + np.random.random() * 0.14
        
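        # Distribute remaining probability over the other classes.
        # The error branches can assign more than 1.0 in total, so clip the
        # remainder at zero and renormalize to keep each row a valid distribution.
        remaining = max(0.0, 1.0 - pred_probs[i].sum())
        noise = np.random.random(NUM_CLASSES)
        noise[pred_probs[i] > 0] = 0
        if noise.sum() > 0:
            pred_probs[i] += remaining * (noise / noise.sum())
        pred_probs[i] /= pred_probs[i].sum()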
    
    return pred_labels, pred_probs

# Generate predictions for all sentences
all_tokens = []
all_true_labels = []
all_pred_labels = []
all_pred_probs = []
sentence_boundaries = [0]

for _, row in sentences_df.iterrows():
    pred_labels, pred_probs = simulate_ner_predictions(
        row['tokens'], row['labels'], error_rate=0.12
    )
    
    all_tokens.extend(row['tokens'])
    all_true_labels.extend(row['labels'])
    all_pred_labels.extend(pred_labels)
    all_pred_probs.append(pred_probs)
    sentence_boundaries.append(sentence_boundaries[-1] + len(row['tokens']))

all_pred_probs = np.vstack(all_pred_probs)
all_true_labels = np.array(all_true_labels)
all_pred_labels = np.array(all_pred_labels)

print(f"Total tokens: {len(all_tokens)}")
print(f"Prediction accuracy: {(all_pred_labels == all_true_labels).mean():.3f}")
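
Label-quality scoring assumes that every row of the probability matrix is a valid distribution: non-negative entries that sum to one. Hand-built simulators can violate this quietly, so a quick check before scoring is cheap insurance. This is a minimal sketch; `check_pred_probs` and the demo matrix are hypothetical and not part of the notebook.

```python
import numpy as np

def check_pred_probs(pred_probs: np.ndarray, atol: float = 1e-6) -> dict:
    """Report whether each row is a valid probability distribution."""
    row_sums = pred_probs.sum(axis=1)
    has_negative = (pred_probs < 0).any(axis=1)
    bad_sum = np.abs(row_sums - 1.0) > atol
    return {
        'all_non_negative': bool(not has_negative.any()),
        'rows_sum_to_one': bool(not bad_sum.any()),
        'n_bad_rows': int((has_negative | bad_sum).sum()),
    }

# Row 0 is valid, row 1 sums to one but has a negative entry,
# row 2 is non-negative but sums to 0.6
demo_probs = np.array([
    [0.90, 0.05, 0.05],
    [0.80, 0.50, -0.30],
    [0.20, 0.20, 0.20],
])
report = check_pred_probs(demo_probs)
print(report)
```

Applied to this notebook, the call would be `check_pred_probs(all_pred_probs)`, and any non-zero `n_bad_rows` would be worth investigating before running Cleanlab.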

## 5. Cleanlab Token Classification Analysis

In [None]:
# Convert to format expected by Cleanlab token classification
# Cleanlab expects list of sentences, each with list of token labels/probs

labels_per_sentence = []
pred_probs_per_sentence = []

for i in range(len(sentences_df)):
    start = sentence_boundaries[i]
    end = sentence_boundaries[i + 1]
    labels_per_sentence.append(all_true_labels[start:end].tolist())
    pred_probs_per_sentence.append(all_pred_probs[start:end])

print(f"Prepared {len(labels_per_sentence)} sentences for Cleanlab")

In [None]:
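# Find token-level label issues.
# find_label_issues returns a list of (sentence_index, token_index) tuples.
issue_indices = token_filter.find_label_issues(
    labels=labels_per_sentence,
    pred_probs=pred_probs_per_sentence
)

# Convert to one boolean mask per sentence so results align with sentences_df
token_issues = [[False] * len(sent) for sent in labels_per_sentence]
for sent_idx, tok_idx in issue_indices:
    token_issues[sent_idx][tok_idx] = True

# Get quality scores.
# get_label_quality_scores returns (sentence_scores, per-sentence token scores).
sentence_quality, token_scores = token_rank.get_label_quality_scores(
    labels=labels_per_sentence,
    pred_probs=pred_probs_per_sentence
)
token_quality = [list(scores) for scores in token_scores]

print(f"Found token issues in {sum(1 for issues in token_issues if any(issues))} sentences")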

In [None]:
# Add results to sentences dataframe
sentences_df['token_issues'] = token_issues
sentences_df['token_quality'] = token_quality
sentences_df['n_issues'] = [sum(issues) for issues in token_issues]
sentences_df['min_quality'] = [min(q) if len(q) > 0 else 1.0 for q in token_quality]
sentences_df['mean_quality'] = [np.mean(q) if len(q) > 0 else 1.0 for q in token_quality]

print("Token Quality Statistics:")
print(sentences_df[['n_issues', 'min_quality', 'mean_quality']].describe())
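
To make the quality scores less of a black box, here is a minimal NumPy sketch of the idea behind self-confidence scoring: a token's score is the probability the model assigns to its given label, and a sentence's score is the minimum over its tokens. To my understanding these correspond to Cleanlab's default `token_score_method='self_confidence'` and `sentence_score_method='min'`, but the function below is a hypothetical re-implementation for illustration, not Cleanlab's code.

```python
import numpy as np

def self_confidence_scores(labels, pred_probs: np.ndarray) -> np.ndarray:
    """Score each token by the predicted probability of its given label."""
    labels = np.asarray(labels)
    return pred_probs[np.arange(len(labels)), labels]

# Toy sentence with 4 tokens and 3 classes (values are made up)
sentence_labels = [0, 1, 2, 0]
sentence_probs = np.array([
    [0.90, 0.05, 0.05],
    [0.10, 0.80, 0.10],
    [0.60, 0.30, 0.10],  # given label 2 gets only 0.10: likely label issue
    [0.95, 0.03, 0.02],
])

token_scores_demo = self_confidence_scores(sentence_labels, sentence_probs)
sentence_score_demo = float(token_scores_demo.min())

print(token_scores_demo.tolist())  # [0.9, 0.8, 0.1, 0.95]
print(sentence_score_demo)         # 0.1
```

The third token drags the whole sentence score down to 0.1, which is why a single mislabeled token is enough to surface a sentence for review.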

## 6. Aggregate to Entity-Level YRSN

In [None]:
def extract_entities_with_quality(tokens, labels, quality_scores, issues):
    """Extract entities and compute per-entity quality."""
    entities = []
    current_entity = None
    
    for i, (token, label, quality, is_issue) in enumerate(zip(tokens, labels, quality_scores, issues)):
        entity_name = ID_TO_ENTITY[label]
        
        if entity_name.startswith('B-'):
            # Save previous entity
            if current_entity:
                entities.append(current_entity)
            
            # Start new entity
            entity_type = entity_name[2:]
            current_entity = {
                'type': entity_type,
                'tokens': [token],
                'start': i,
                'end': i,
                'qualities': [quality],
                'has_issue': [is_issue],
                'is_critical': entity_name in CRITICAL_ENTITIES
            }
            
        elif entity_name.startswith('I-') and current_entity:
            # Continue entity
            current_entity['tokens'].append(token)
            current_entity['end'] = i
            current_entity['qualities'].append(quality)
            current_entity['has_issue'].append(is_issue)
            
        else:
            # O tag - save current entity
            if current_entity:
                entities.append(current_entity)
                current_entity = None
    
    # Don't forget last entity
    if current_entity:
        entities.append(current_entity)
    
    # Compute entity-level quality
    for entity in entities:
        entity['text'] = ' '.join(entity['tokens'])
        entity['min_quality'] = min(entity['qualities'])
        entity['mean_quality'] = np.mean(entity['qualities'])
        entity['any_issue'] = any(entity['has_issue'])
    
    return entities

# Extract entities for all sentences
all_entities = []
for _, row in sentences_df.iterrows():
    entities = extract_entities_with_quality(
        row['tokens'], row['labels'], 
        row['token_quality'], row['token_issues']
    )
    for entity in entities:
        entity['sentence_id'] = row['sentence_id']
        all_entities.append(entity)

entities_df = pd.DataFrame(all_entities)
print(f"Extracted {len(entities_df)} entities")
print(f"\nEntity type distribution:")
print(entities_df['type'].value_counts())
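
The entity records keep both `min_quality` and `mean_quality`. A small worked example shows why the minimum is the safer signal for multi-token entities: one bad token can hide inside a healthy-looking mean. The scores below are invented for illustration.

```python
import numpy as np

# Hypothetical per-token quality scores for two extracted entities
entity_token_quality = {
    '$250,000': [0.97],
    'January 15, 2024': [0.95, 0.20, 0.93],
}

entity_summary = {
    text: {'min': min(scores), 'mean': float(np.mean(scores))}
    for text, scores in entity_token_quality.items()
}

for text, stats in entity_summary.items():
    print(f"{text!r}: min={stats['min']:.2f}, mean={stats['mean']:.2f}")
# '$250,000': min=0.97, mean=0.97
# 'January 15, 2024': min=0.20, mean=0.69
```

The date entity has a mean of 0.69, which might pass a lenient threshold, while its minimum of 0.20 flags the suspicious middle token.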

## 7. YRSN Decomposition for Entities

In [None]:
adapter = CleanlabAdapter()

def compute_entity_yrsn(row):
    """Compute YRSN for a single entity extraction."""
    
    # N (Noise): Issues in extraction
    n_from_quality = 1 - row['min_quality']
    n_from_issues = 0.3 if row['any_issue'] else 0.0
    
    # Critical entities get higher N if issues exist
    if row['is_critical'] and row['any_issue']:
        n_from_issues += 0.3  # POISONING risk
    
    N = min(1.0, n_from_quality * 0.6 + n_from_issues)
    
    # S (Superfluous): Boundary uncertainty
    # Multi-token entities with quality variance → boundary issues
    quality_variance = np.var(row['qualities']) if len(row['qualities']) > 1 else 0
    S = min(0.5, quality_variance * 2)
    S = min(1.0 - N, S)
    
    # R (Relevant): Clean extraction
    R = max(0, 1.0 - N - S)
    
    # Normalize
    total = R + S + N
    return YRSNResult(R=R/total, S=S/total, N=N/total)

# Apply YRSN computation
yrsn_results = entities_df.apply(compute_entity_yrsn, axis=1)
entities_df['R'] = [y.R for y in yrsn_results]
entities_df['S'] = [y.S for y in yrsn_results]
entities_df['N'] = [y.N for y in yrsn_results]
entities_df['collapse_type'] = [y.collapse_type.value for y in yrsn_results]

# Mark POISONING for critical entities with elevated noise (N > 0.3)
entities_df.loc[
    (entities_df['is_critical']) & (entities_df['N'] > 0.3),
    'collapse_type'
] = 'poisoning'

print("Entity YRSN Statistics:")
print(entities_df[['R', 'S', 'N']].describe())
print(f"\nCollapse distribution:")
print(entities_df['collapse_type'].value_counts())
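
The arithmetic in `compute_entity_yrsn` is easier to follow with concrete numbers. The sketch below restates the same formulas as a standalone function that returns a plain `(R, S, N)` tuple instead of a `YRSNResult`; the function name and the three example entities are hypothetical. The final normalization is omitted because R + S + N already equals 1 by construction.

```python
import numpy as np

def yrsn_from_quality(qualities, any_issue: bool, is_critical: bool):
    """Standalone restatement of the entity-level R/S/N formulas."""
    n_from_quality = 1 - min(qualities)
    n_from_issues = 0.3 if any_issue else 0.0
    if is_critical and any_issue:
        n_from_issues += 0.3  # extra penalty for critical entities
    N = min(1.0, n_from_quality * 0.6 + n_from_issues)

    # Quality variance across tokens stands in for boundary uncertainty
    variance = float(np.var(qualities)) if len(qualities) > 1 else 0.0
    S = min(0.5, variance * 2, 1.0 - N)
    R = max(0.0, 1.0 - N - S)
    return R, S, N

clean_amount = yrsn_from_quality([0.95], any_issue=False, is_critical=True)
flagged_party = yrsn_from_quality([0.9, 0.5], any_issue=True, is_critical=False)
flagged_date = yrsn_from_quality([0.9, 0.3], any_issue=True, is_critical=True)

for name, (R, S, N) in [('clean AMOUNT', clean_amount),
                        ('flagged PARTY', flagged_party),
                        ('flagged DATE', flagged_date)]:
    print(f"{name}: R={R:.2f}, S={S:.2f}, N={N:.2f}")
# clean AMOUNT: R=0.97, S=0.00, N=0.03
# flagged PARTY: R=0.32, S=0.08, N=0.60
# flagged DATE: R=0.00, S=0.00, N=1.00
```

The flagged DATE saturates at N = 1.0 because the quality term (0.7 × 0.6 = 0.42) and the two issue penalties (0.3 + 0.3) together exceed 1, which leaves no room for S or R.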

## 8. Temperature-Based Routing for Entity Extraction

In [None]:
# Compute temperature per entity
entities_df['temperature'] = entities_df['R'].apply(lambda r: compute_temperature(r))

# NER-specific routing (tighter thresholds for critical entities)
def route_entity(row):
    tau = row['temperature']
    is_critical = row['is_critical']
    collapse = row['collapse_type']
    
    # POISONING on critical entity: always red
    if collapse == 'poisoning' and is_critical:
        return 'red'
    
    # Critical entities use tighter thresholds
    if is_critical:
        if tau < 1.2 and row['R'] > 0.7:
            return 'green'
        elif tau < 1.5:
            return 'yellow'
        else:
            return 'red'
    else:
        # Non-critical entities
        if tau < 1.5:
            return 'green'
        elif tau < 2.5:
            return 'yellow'
        else:
            return 'red'

entities_df['stream'] = entities_df.apply(route_entity, axis=1)

print("Entity Routing Distribution:")
print(entities_df['stream'].value_counts())

print("\nRouting by Entity Type:")
print(pd.crosstab(entities_df['type'], entities_df['stream'], normalize='index').round(2))
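
The routing thresholds are stated in terms of τ, but it can help to see them in terms of R. The flow diagram gives τ = 1/R at the entity level; the sketch below assumes that relationship holds exactly, which may not match the actual `compute_temperature` implementation. Both helper names are hypothetical.

```python
def temperature_from_r(r: float, eps: float = 1e-6) -> float:
    """Hypothetical stand-in for compute_temperature, assuming tau = 1 / R."""
    return 1.0 / max(r, eps)  # eps guards against division by zero

def min_r_for_threshold(tau_max: float) -> float:
    """Under tau = 1 / R, tau < tau_max is equivalent to R > 1 / tau_max."""
    return 1.0 / tau_max

r_thresholds = {tau: round(min_r_for_threshold(tau), 3) for tau in (1.2, 1.5, 2.5)}
print(r_thresholds)  # {1.2: 0.833, 1.5: 0.667, 2.5: 0.4}
```

Under this assumption, a critical entity needs R above roughly 0.83 to reach green, so the additional `R > 0.7` condition in `route_entity` would already be implied by `tau < 1.2`. If `compute_temperature` applies clipping or scaling, the equivalence no longer holds and both conditions matter.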

## 9. Identify POISONING Cases (Critical Extraction Errors)

In [None]:
# Find POISONING cases (critical entity extraction errors)
poisoning_cases = entities_df[
    (entities_df['collapse_type'] == 'poisoning') | 
    ((entities_df['is_critical']) & (entities_df['N'] > 0.25))
]

print("POISONING Cases (Critical Entity Extraction Errors):")
print("="*60)
for _, row in poisoning_cases.head(10).iterrows():
    print(f"\n[Sentence {row['sentence_id']}] Entity: '{row['text']}'")
    print(f"  Type: {row['type']} (CRITICAL)")
    print(f"  YRSN: R={row['R']:.2f}, S={row['S']:.2f}, N={row['N']:.2f}")
    print(f"  Min Quality: {row['min_quality']:.3f}")
    print(f"  Has Issue: {row['any_issue']}")
    print(f"  Stream: {row['stream'].upper()}, τ={row['temperature']:.2f}")
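
Label-quality scores flag suspicious tokens, but the objective also calls for validating the extracted values themselves. A minimal sketch of format checks for AMOUNT and DATE entities follows. It assumes the formats used by the synthetic data in this notebook (`$1,000,000` and `January 15, 2024`); the pattern and both helper names are hypothetical, and real contracts would need broader rules.

```python
import re
from datetime import datetime
from typing import Optional

# Dollar amounts with comma-separated thousands and optional cents
AMOUNT_PATTERN = re.compile(r'^\$\d{1,3}(,\d{3})*(\.\d{2})?$')

def parse_amount(text: str) -> Optional[float]:
    """Return the numeric value, or None if the text is not a well-formed amount."""
    if not AMOUNT_PATTERN.match(text):
        return None
    return float(text.replace('$', '').replace(',', ''))

def parse_date(text: str) -> Optional[datetime]:
    """Return the parsed date, or None if the text is not 'Month D, YYYY'."""
    try:
        return datetime.strptime(text, '%B %d, %Y')
    except ValueError:
        return None

print(parse_amount('$1,000,000'))      # 1000000.0
print(parse_amount('$50,00'))          # None
print(parse_date('January 15, 2024'))  # 2024-01-15 00:00:00
print(parse_date('January 15,'))       # None
```

A truncated span such as `January 15,` is exactly what a boundary error produces, so a failed parse on a critical entity is a reasonable additional trigger for RED routing.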

## 10. Aggregate to Sentence-Level Decision

In [None]:
# For each sentence, determine overall routing based on worst entity
def aggregate_sentence_routing(sentence_entities):
    """Aggregate entity-level routing to sentence-level."""
    if len(sentence_entities) == 0:
        return {'stream': 'green', 'worst_entity': None, 'n_entities': 0}
    
    # Sentence stream is worst of all entities
    stream_priority = {'red': 0, 'yellow': 1, 'green': 2}
    worst_idx = sentence_entities['stream'].map(stream_priority).idxmin()
    worst_row = sentence_entities.loc[worst_idx]
    
    return {
        'stream': worst_row['stream'],
        'worst_entity_type': worst_row['type'],
        'worst_entity_text': worst_row['text'],
        'worst_R': worst_row['R'],
        'n_entities': len(sentence_entities),
        'n_critical': sentence_entities['is_critical'].sum(),
        'has_poisoning': (sentence_entities['collapse_type'] == 'poisoning').any()
    }

# Aggregate
sentence_routing = entities_df.groupby('sentence_id').apply(
    aggregate_sentence_routing
).apply(pd.Series)

sentences_df = sentences_df.merge(sentence_routing, left_on='sentence_id', right_index=True)

print("Sentence-Level Routing:")
print(sentences_df['stream'].value_counts())
print(f"\nSentences with POISONING risk: {sentences_df['has_poisoning'].sum()}")
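
The sentence-level rule is "worst entity wins". Stripped of the pandas machinery, it reduces to taking the most severe stream among a sentence's entities. The sketch below is a plain-Python illustration; `worst_stream` and the severity mapping are hypothetical names, and the empty case defaults to green to match `aggregate_sentence_routing`.

```python
from typing import List

# Higher number means more severe
STREAM_SEVERITY = {'green': 0, 'yellow': 1, 'red': 2}

def worst_stream(entity_streams: List[str]) -> str:
    """Return the most severe stream, or 'green' when there are no entities."""
    if not entity_streams:
        return 'green'
    return max(entity_streams, key=lambda s: STREAM_SEVERITY[s])

print(worst_stream(['green', 'yellow', 'green']))  # yellow
print(worst_stream(['green', 'red', 'yellow']))    # red
print(worst_stream([]))                            # green
```

Treating entity-free sentences as green is a design choice worth revisiting: a sentence where the model found nothing may also be one where it missed everything.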

## 11. Export Results

In [None]:
# Save entity-level results
entity_cols = ['sentence_id', 'type', 'text', 'is_critical', 
               'min_quality', 'R', 'S', 'N', 'collapse_type', 'temperature', 'stream']
entities_df[entity_cols].to_csv('entity_yrsn_results.csv', index=False)

# Save sentence-level results
sentence_cols = ['sentence_id', 'text', 'n_issues', 'mean_quality', 
                 'stream', 'worst_entity_type', 'has_poisoning', 'n_critical']
sentences_df[sentence_cols].to_csv('sentence_yrsn_results.csv', index=False)

print(f"Saved {len(entities_df)} entities and {len(sentences_df)} sentences")

## Summary

In this notebook we:
1. Generated synthetic contract text with entity annotations
2. Simulated NER model predictions with controlled error rates
3. Used Cleanlab token classification to find label issues
4. Aggregated token quality to entity-level YRSN
5. Applied stricter routing for critical entities (AMOUNT, DATE)
6. Identified POISONING collapse cases (critical extraction errors)
7. Aggregated to sentence-level routing decisions

**Key Insight**: For NER in contracts, AMOUNT and DATE entities are critical. Any extraction error on these is POISONING and forces RED routing for manual verification. Temperature is kept tight (τ < 1.2) for critical entities.

**Next**: Notebook 6 - Production Pipeline (Full IARS Integration)