In [1]:
import os
import json
import random
import string
import sys
import regex as re
from estnltk import Text, Layer
from estnltk.taggers.standard.text_segmentation.compound_token_tagger import CompoundTokenTagger, ALL_1ST_LEVEL_PATTERNS
from estnltk.taggers.standard.text_segmentation.patterns import MACROS
from collections import Counter, defaultdict

In [2]:
"""
NER Preprocessing Pipeline for Tartu City Council Protocols
Handles filtering, tokenization correction, and diagnostics
"""
# ============================================================================
# CONFIGURATION
# ============================================================================

# Directory scanned for the original annotated *.json protocol files.
DATA_DIR = "../data/raw"  # Input: original annotated JSONs
# Directory that receives one filtered copy per input file; each copy keeps
# the ne_gold_a layer, or ne_gold_b when ne_gold_a is absent.
FILTERED_DIR = "json_ne_gold_a"           # Output: filtered JSONs (ne_gold_a, else ne_gold_b)

# ============================================================================
# TOKENIZATION CORRECTION (from words_tokenization.py)
# ============================================================================

def make_adapted_cp_tagger(**kwargs):
    """Create a CompoundTokenTagger with an adapted list of 1st-level patterns.

    Starting from ``ALL_1ST_LEVEL_PATTERNS`` the adapted list:
      1) replaces both "name with initial(s)" patterns with variants whose
         initials exclude the roman-numeral letters I, V and X (and, for the
         two-initial pattern, a handful of titles such as "Dr.");
      2) replaces two long-number patterns with variants that do not join
         date-like token sequences as numbers;
      3) drops three stock patterns entirely (uppercase-letter + number
         abbreviations, roman-month dates and dd/mm/yy dates).

    Args:
        **kwargs: Passed through unchanged to ``CompoundTokenTagger``.
            ``patterns_1`` is reserved for the adapted list.

    Returns:
        A ``CompoundTokenTagger`` built with the adapted patterns and with
        ``do_not_join_on_strings`` set to the double and the single newline.

    Raises:
        ValueError: If ``patterns_1`` is given in ``kwargs``.
        AssertionError: If the adapted list is not exactly three patterns
            shorter than ``ALL_1ST_LEVEL_PATTERNS`` (i.e. the three patterns
            to drop were not all found by their ``comment`` text).
    """
    # Fail fast, before doing any work, if the caller tries to override the
    # very argument this factory exists to provide.
    if 'patterns_1' in kwargs:
        raise ValueError("Cannot overwrite 'patterns_1' in adapted CompoundTokenTagger.")

    # Pattern 1: Names with 2 initials (exclude titles and roman numerals I, V, X).
    # The initial's character class is the uppercase alphabet without I, V, X,
    # plus the Estonian letters Š Ž Õ Ä Ö Ü.
    redefined_pat_1 = {
        'comment': '*) Names starting with 2 initials (exclude titles and roman numerals I, V, X);',
        'pattern_type': 'name_with_initial',
        'example': 'A. H. Tammsaare',
        '_regex_pattern_': re.compile(r'''
            (?!(Dr\.|Lb\.|Lh\.|Lm\.|Ln\.|Lv\.|Lw\.|Pr\.))     # exclude titles
            ([ABCDEFGHJKLMNOPQRSTUWYZŠŽÕÄÖÜ][{LOWERCASE}]?)   # first initial
            \s?\.\s?-?                                        # period (and hyphen potentially)
            ([ABCDEFGHJKLMNOPQRSTUWYZŠŽÕÄÖÜ][{LOWERCASE}]?)   # second initial
            \s?\.\s?                                          # period
            ((\.[{UPPERCASE}]\.)?[{UPPERCASE}][{LOWERCASE}]+) # last name
        '''.format(**MACROS), re.X),
        '_group_': 0,
        '_priority_': (4, 1),
        # NOTE(review): '\1.\2. \3' is a non-raw string, so the search pattern
        # consists of control characters and is unlikely to ever match; the
        # matched text is then returned unchanged. Kept as-is to preserve the
        # current normalisation behaviour - confirm the intent before changing.
        'normalized': lambda m: re.sub('\1.\2. \3', '', m.group(0)),
    }

    # Pattern 2: Names with 1 initial (exclude roman numerals I, V, X)
    redefined_pat_2 = {
        'comment': '*) Names starting with one initial (exclude roman numerals I, V, X);',
        'pattern_type': 'name_with_initial',
        'example': 'A. Hein',
        '_regex_pattern_': re.compile(r'''
            ([ABCDEFGHJKLMNOPQRSTUWYZŠŽÕÄÖÜ])   # first initial
            \s?\.\s?                            # period
            ([{UPPERCASE}][{LOWERCASE}]+)       # last name
        '''.format(**MACROS), re.X),
        '_group_': 0,
        '_priority_': (4, 2),
        # NOTE(review): same non-raw pattern caveat as in pattern 1.
        'normalized': lambda m: re.sub('\1. \2', '', m.group(0)),
    }

    # Pattern 3: Long numbers (1 group, corrected for timex tagger).
    # 'normalized' is given as lambda source text rather than a callable;
    # presumably estnltk evaluates it - TODO confirm against the tagger.
    redefined_number_pat_1 = {
        'comment': '*) A generic pattern for detecting long numbers (1 group) (corrected for timex tagger).',
        'example': '12,456',
        'pattern_type': 'numeric',
        '_group_': 0,
        '_priority_': (2, 1, 5),
        '_regex_pattern_': re.compile(r'''
            \d+           # 1 group of numbers
            (,\d+|\ *\.)  # + comma-separated numbers or period-ending
        ''', re.X),
        'normalized': r"lambda m: re.sub(r'[\s]' ,'' , m.group(0))"
    }

    # Pattern 4: Long numbers (2 groups, point-separated, followed by comma-separated)
    redefined_number_pat_2 = {
        'comment': '*) A generic pattern for detecting long numbers (2 groups, point-separated, followed by comma-separated numbers) (corrected for timex tagger).',
        'example': '67.123,456',
        'pattern_type': 'numeric',
        '_group_': 0,
        '_priority_': (2, 1, 3, 1),
        '_regex_pattern_': re.compile(r'''
            \d+\.+\d+   # 2 groups of numbers
            (,\d+)      # + comma-separated numbers
        ''', re.X),
        'normalized': r"lambda m: re.sub(r'[\s\.]' ,'' , m.group(0))"
    }

    # Stock patterns are identified by their exact 'comment' text.
    skipped_comments = {
        '*) Abbreviations of type <uppercase letter> + <numbers>;',
        '*) Date patterns that contain month as a Roman numeral: "dd. roman_mm yyyy";',
        '*) Date patterns in the commonly used form "dd/mm/yy";',
    }
    replacements = {
        '*) Names starting with 2 initials;': redefined_pat_1,
        '*) Names starting with one initial;': redefined_pat_2,
        '*) A generic pattern for detecting long numbers (1 group).': redefined_number_pat_1,
        '*) A generic pattern for detecting long numbers (2 groups, point-separated, followed by comma-separated numbers).': redefined_number_pat_2,
    }

    # Build the new pattern list: drop the skipped ones, swap in the
    # redefined ones, keep everything else in its original position.
    new_1st_level_patterns = [
        replacements.get(pat['comment'], pat)
        for pat in ALL_1ST_LEVEL_PATTERNS
        if pat['comment'] not in skipped_comments
    ]

    # Sanity check that all three patterns to drop were actually present.
    assert len(new_1st_level_patterns) + 3 == len(ALL_1ST_LEVEL_PATTERNS), \
        "Expected exactly 3 patterns to be removed from ALL_1ST_LEVEL_PATTERNS."

    return CompoundTokenTagger(
        patterns_1=new_1st_level_patterns,
        do_not_join_on_strings=('\n\n', '\n'),
        **kwargs
    )


# Initialize the adapted tagger once at module level; preprocess_words()
# uses this shared instance. It is configured to read the 'tokens' layer
# and to write its result to the 'compound_tokens' layer.
adapted_cp_tokens_tagger = make_adapted_cp_tagger(
    input_tokens_layer='tokens',
    output_layer='compound_tokens'
)


def preprocess_words(input_text):
    """Pre-process a Text object: add word segmentation.

    Tags the 'tokens' layer, runs the module-level adapted compound-token
    tagger on the text, then tags the 'words' layer. The order matters: the
    adapted tagger is applied before 'words' is requested.

    Args:
        input_text: The object to segment (callers in this file pass an
            estnltk ``Text``).

    Returns:
        The same ``input_text`` object that was passed in.
    """
    input_text.tag_layer('tokens')
    # Apply the adapted compound-token rules before the 'words' layer is built.
    adapted_cp_tokens_tagger.tag(input_text)
    input_text.tag_layer('words') 
    return input_text


# ============================================================================
# STEP 1: FILTER JSON FILES (keep only ne_gold_a or ne_gold_b)
# ============================================================================

def filter_json_to_gold_a(input_dir, output_dir):
    """Write a copy of every JSON file that keeps only one gold NER layer.

    For each ``*.json`` file in ``input_dir`` the ``layers`` list is reduced
    to the layer(s) named ``ne_gold_a`` (first annotator). When a file has
    no ``ne_gold_a`` layer, the layer(s) named ``ne_gold_b`` (second
    annotator) are kept instead. Files with neither layer are still written,
    with an empty ``layers`` list, and a warning is printed.

    Args:
        input_dir: Directory containing the annotated JSON files.
        output_dir: Directory for the filtered copies (created if missing).
            Output files keep the input file names.

    Returns:
        None. Prints a one-line summary with the number of files written.
    """
    os.makedirs(output_dir, exist_ok=True)

    written = 0
    for fname in sorted(os.listdir(input_dir)):
        if not fname.endswith(".json"):
            continue

        input_path = os.path.join(input_dir, fname)
        output_path = os.path.join(output_dir, fname)

        with open(input_path, encoding="utf-8") as f:
            data = json.load(f)

        layers = data.get("layers", [])

        # Prioritize ne_gold_a over ne_gold_b. ``.get`` keeps a layer entry
        # without a "name" key from aborting the whole run.
        filtered_layers = [layer for layer in layers if layer.get("name") == "ne_gold_a"]
        if not filtered_layers:
            filtered_layers = [layer for layer in layers if layer.get("name") == "ne_gold_b"]

        data["layers"] = filtered_layers

        if not filtered_layers:
            print(f"⚠ No NER layer in: {fname}")

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        written += 1

    # Report what was actually processed instead of re-listing the directory.
    print(f"✓ Filtered {written} files to {output_dir}")


# ============================================================================
# STEP 2: DIAGNOSTIC - CHECK TOKENIZATION ALIGNMENT
# ============================================================================

def diagnose_tokenization(file_path, max_spans=None, verbose=True, all_spans_log=None, error_log=None):
    """Check alignment between NER spans and word tokens of one JSON file.

    The file's text is re-tokenized with ``preprocess_words`` and every gold
    span (from ``ne_gold_a``, else ``ne_gold_b``) is compared with the word
    tokens that overlap its character range. A span with no overlapping
    token counts as a missing alignment.

    Args:
        file_path: Path to JSON file.
        max_spans: Maximum spans to check (None = all spans).
        verbose: Print detailed per-span output and a per-file summary.
        all_spans_log: File handle to write ALL checked spans to (optional),
            one tab-separated row per span.
        error_log: File handle to write only the unaligned spans to (optional).

    Returns:
        ``(total_spans, missing_alignments)``. ``total_spans`` is the number
        of spans in the layer, even when ``max_spans`` limits how many were
        checked; ``(0, 0)`` when the file has no gold NER layer.
    """
    with open(file_path, encoding="utf-8") as f:
        data = json.load(f)

    text = data.get("text", "")
    txt = Text(text)
    txt = preprocess_words(txt)

    # Build word token list; fall back to the surface text when a word has
    # no (non-empty) normalized form.
    words = []
    for w in txt['words']:
        norm = w.annotations[0].get('normalized_form') if w.annotations else None
        norm = norm or w.text
        words.append({
            "text": w.text,
            "start": w.start,
            "end": w.end,
            "norm": norm
        })

    # Find NER layer (ne_gold_a or ne_gold_b)
    layers = data.get("layers", [])
    ne_layer = None
    for candidate in ("ne_gold_a", "ne_gold_b"):
        ne_layer = next((l for l in layers if l.get("name") == candidate), None)
        if ne_layer:
            break

    if not ne_layer:
        if verbose:
            print(f"⚠ No NER layer in: {file_path}")
        return 0, 0

    spans = ne_layer.get("spans", [])

    def tokens_overlapping(start, end):
        # Half-open interval overlap test on character offsets.
        return [w for w in words if w["start"] < end and w["end"] > start]

    missing = 0
    base_name = os.path.basename(file_path)

    if verbose:
        print(f"\n{file_path}")
        print("=" * 80)

    spans_to_check = spans if max_spans is None else spans[:max_spans]

    for span in spans_to_check:
        start, end = span["base_span"]
        # Tolerate a missing or empty annotation list instead of crashing.
        annotations = span.get("annotations") or [{}]
        etype = annotations[0].get("tag", "?")
        etext = text[start:end]
        covered = tokens_overlapping(start, end)

        # Get token strings
        token_strings = [w["norm"] for w in covered]
        tokens_str = "|".join(token_strings)
        has_error = "YES" if not covered else "NO"

        # Write to all_spans_log if provided
        if all_spans_log:
            all_spans_log.write(f"{base_name}\t{repr(etext)}\t{etype}\t({start},{end})\t{tokens_str}\t{has_error}\n")

        # Track and log errors separately
        if not covered:
            missing += 1

            if error_log:
                error_log.write(f"{base_name}\t{repr(etext)}\t{etype}\t({start},{end})\n")

            if verbose:
                print(f"⚠ MISSING: {repr(etext)} → {etype} | span=({start},{end})")
        elif verbose:
            print(f"✓ {repr(etext)} → {etype} | tokens={token_strings}")

    if verbose:
        checked_count = len(spans_to_check)
        print(f"\n📊 Checked {checked_count}/{len(spans)} spans, missing alignments: {missing}")
        print("-" * 80)

    return len(spans), missing


def diagnose_all_files(input_dir, max_spans=None, verbose=True, all_spans_file="all_spans.tsv", error_file="tokenization_errors.tsv"):
    """Run tokenization diagnostics on all JSON files in a directory.

    Calls ``diagnose_tokenization`` for every ``*.json`` file (in sorted
    order), accumulates the span and missing-alignment counts and prints a
    total line. Optionally writes two TSV logs, each starting with a header
    row.

    Args:
        input_dir: Directory with JSON files.
        max_spans: Maximum spans per file (None = all spans).
        verbose: Print detailed output.
        all_spans_file: File to write ALL spans to (None = don't write).
        error_file: File to write errors only (None = don't write).

    Returns:
        None. The totals are printed; the percentage is reported as 0.0%
        when no spans were found at all.
    """
    # Local import: keeps the block self-contained without touching the
    # file-level import list.
    from contextlib import ExitStack

    total_spans = 0
    total_missing = 0

    # ExitStack closes whichever logs were opened, also when opening the
    # second file (or anything later) raises.
    with ExitStack() as stack:
        all_spans_log = None
        error_log = None

        if all_spans_file:
            all_spans_log = stack.enter_context(open(all_spans_file, "w", encoding="utf-8"))
            all_spans_log.write("file\tentity_text\tentity_type\tspan_positions\ttokens\thas_error\n")

        if error_file:
            error_log = stack.enter_context(open(error_file, "w", encoding="utf-8"))
            error_log.write("file\tentity_text\tentity_type\tspan_positions\n")

        for fname in sorted(f for f in os.listdir(input_dir) if f.endswith(".json")):
            file_path = os.path.join(input_dir, fname)
            spans, missing = diagnose_tokenization(file_path, max_spans, verbose, all_spans_log, error_log)
            total_spans += spans
            total_missing += missing

        # Guard against an empty directory / files without any spans.
        missing_pct = 100 * total_missing / total_spans if total_spans else 0.0

        print(f"\n{'=' * 80}")
        print(f"TOTAL: {total_spans} spans, {total_missing} missing alignments ({missing_pct:.1f}%)")
        print(f"{'=' * 80}")

        if all_spans_file:
            print(f"✓ All spans written to: {all_spans_file}")
        if error_file and total_missing > 0:
            print(f"⚠ Errors written to: {error_file}")


# ============================================================================
# MAIN EXECUTION
# ============================================================================

if __name__ == "__main__":
    # Output files of the diagnostics step; named once so the summary below
    # cannot drift from what was actually written.
    all_spans_path = "all_spans.tsv"
    errors_path = "tokenization_errors.tsv"

    # Step 1: Filter JSON files (keep only ne_gold_a, else ne_gold_b)
    print("Step 1: Filtering JSON files...")
    filter_json_to_gold_a(DATA_DIR, FILTERED_DIR)

    # Step 2: Run diagnostics on ALL spans and save to files
    print("\nStep 2: Running tokenization diagnostics on ALL spans...")
    diagnose_all_files(
        FILTERED_DIR,
        max_spans=None,  # Check ALL spans
        verbose=False,   # Don't print every span (too much output)
        all_spans_file=all_spans_path,  # Write ALL spans here
        error_file=errors_path  # Write only errors here
    )

    print("\n✓ Done!")
    print(f"  - {all_spans_path}: Contains ALL spans with their tokenization")
    print(f"  - {errors_path}: Contains only problematic spans")

Step 1: Filtering JSON files...
✓ Filtered 44 files to json_ne_gold_a

Step 2: Running tokenization diagnostics on ALL spans...

TOTAL: 22660 spans, 0 missing alignments (0.0%)
✓ All spans written to: all_spans.tsv

✓ Done!
  - all_spans.tsv: Contains ALL spans with their tokenization
  - tokenization_errors.tsv: Contains only problematic spans


In [3]:
"""
Step 1.3: Convert corpus to BIO format
Unifies entity categories and converts to BIO tagging format
"""
# ============================================================================
# CONFIGURATION
# ============================================================================

# Directory read by convert_corpus_to_bio() in this cell.
FILTERED_DIR = "json_ne_gold_a"           # Input: filtered JSONs
# TSV file written by convert_corpus_to_bio(): one "token<TAB>tag" line per
# token, with a blank line after each sentence.
BIO_OUTPUT_FILE = "corpus_bio.tsv"        # Output: BIO format corpus

# Category mapping for unification: fine-grained tag -> coarse tag.
# Looked up by unify_category(); tags without an entry pass through unchanged.
CATEGORY_MAPPING = {
    'LOC_ADDRESS': 'LOC',
    'ORG_GPE': 'ORG',
    'ORG_POL': 'ORG',
    # Keep others as-is: PER, LOC, ORG, POSITION, etc.
}

# ============================================================================
# BIO CONVERSION FUNCTIONS
# ============================================================================

def unify_category(category):
    """Map a fine-grained entity category to its unified category.

    Categories listed in ``CATEGORY_MAPPING`` are replaced by their mapped
    value; every other category is returned unchanged.
    """
    if category in CATEGORY_MAPPING:
        return CATEGORY_MAPPING[category]
    return category


def convert_file_to_bio(file_path):
    """Turn one annotated JSON file into BIO-tagged sentences.

    The text is re-tokenized (words and sentences), every word starts as
    'O', and each gold span (``ne_gold_a``, else ``ne_gold_b``) tags the
    words overlapping its character range: the first with ``B-<type>``, the
    remaining ones with ``I-<type>``. Types go through ``unify_category``.

    Returns:
        A list of sentences, each a list of ``(token, bio_tag)`` tuples,
        where ``token`` is the word's normalized form if it has one and its
        surface text otherwise. Sentences without any word are omitted.
        An empty list is returned when the file has no gold NER layer.
    """
    with open(file_path, encoding="utf-8") as fh:
        doc = json.load(fh)

    tagged = preprocess_words(Text(doc.get("text", "")))
    tagged.tag_layer('sentences')

    # One bucket of word records per sentence that contains at least one word.
    buckets = []
    for sent_id, sentence in enumerate(tagged['sentences']):
        bucket = []
        for token in tagged['words']:
            # Skip words that are not fully inside this sentence.
            if token.start < sentence.start or token.end > sentence.end:
                continue
            normalized = token.annotations[0].get('normalized_form') if token.annotations else None
            bucket.append({
                "text": token.text,
                "norm": normalized or token.text,
                "start": token.start,
                "end": token.end,
                "sent_id": sent_id,
            })
        if bucket:
            buckets.append(bucket)

    # Locate the gold layer: first annotator preferred, second as fallback.
    gold_layer = None
    for layer_name in ("ne_gold_a", "ne_gold_b"):
        gold_layer = next((l for l in doc.get("layers", []) if l["name"] == layer_name), None)
        if gold_layer:
            break
    if not gold_layer:
        return []

    # Flat, document-ordered view of the same records; default tag is 'O'.
    records = [record for bucket in buckets for record in bucket]
    for record in records:
        record['bio_tag'] = 'O'

    for span in gold_layer.get("spans", []):
        span_start, span_end = span["base_span"]
        label = unify_category(span["annotations"][0].get("tag", "?"))

        hits = [r for r in records if r["start"] < span_end and r["end"] > span_start]
        for position, record in enumerate(hits):
            prefix = 'B' if position == 0 else 'I'
            record['bio_tag'] = f'{prefix}-{label}'

    # Buckets already are the per-sentence grouping, in sentence order.
    return [
        [(record['norm'], record['bio_tag']) for record in bucket]
        for bucket in buckets
    ]


def convert_corpus_to_bio(input_dir, output_file):
    """Convert all JSON files in a directory to BIO format and save as TSV.

    Files are processed in sorted name order with ``convert_file_to_bio``.
    The TSV has one ``token<TAB>bio_tag`` line per token and an empty line
    after every sentence. A summary (sentence, token and per-type entity
    counts, the latter counted from ``B-`` tags) is printed.

    Args:
        input_dir: Directory containing the ``*.json`` files to convert.
        output_file: Path of the TSV file to write (overwritten).

    Returns:
        A tuple ``(all_sentences, file_sentences)``: the concatenated list
        of sentences, and a dict mapping each file name to the sentences
        that came from it.
    """
    all_sentences = []
    file_sentences = {}  # Track which file each sentence came from

    print("Converting corpus to BIO format...")

    for fname in sorted(f for f in os.listdir(input_dir) if f.endswith(".json")):
        file_path = os.path.join(input_dir, fname)
        sentences = convert_file_to_bio(file_path)
        file_sentences[fname] = sentences
        all_sentences.extend(sentences)
        print(f"  ✓ {fname}: {len(sentences)} sentences")

    # Write to TSV
    with open(output_file, 'w', encoding='utf-8') as f:
        for sentence in all_sentences:
            f.writelines(f"{token}\t{bio_tag}\n" for token, bio_tag in sentence)
            f.write("\n")  # Empty line terminates the sentence

    print(f"\n✓ Saved BIO corpus to: {output_file}")
    print(f"  Total sentences: {len(all_sentences)}")
    print(f"  Total tokens: {sum(len(s) for s in all_sentences)}")

    # Count entities by type: every B- tag opens exactly one entity.
    entity_counts = Counter(
        bio_tag[2:]
        for sentence in all_sentences
        for _, bio_tag in sentence
        if bio_tag.startswith('B-')
    )

    print(f"\n  Entity counts:")
    for etype, count in sorted(entity_counts.items()):
        print(f"    {etype}: {count}")

    return all_sentences, file_sentences


# ============================================================================
# MAIN EXECUTION
# ============================================================================

if __name__ == "__main__":
    # Banner rule shared by the header and footer of this step.
    rule = "=" * 80

    print(rule)
    print("STEP 1.3: Converting corpus to BIO format")
    print(rule)

    # Both results are kept as module-level names so later cells can use them.
    all_sentences, file_sentences = convert_corpus_to_bio(FILTERED_DIR, BIO_OUTPUT_FILE)

    print("\n" + rule)
    print("✓ BIO conversion complete!")
    print(rule)
    print(f"Output file: {BIO_OUTPUT_FILE}")
    print("\nNext: Run Step 1.4 (train/dev/test split)")

STEP 1.3: Converting corpus to BIO format
Converting corpus to BIO format...
  ✓ 1918-01-22_manual_annotated_anonymized.json: 146 sentences
  ✓ 1918-12-09_manual_annotated_anonymized.json: 76 sentences
  ✓ 1919-06-16_manual_annotated_anonymized.json: 307 sentences
  ✓ 1919-07-28_manual_annotated_anonymized.json: 186 sentences
  ✓ 1919-08-25_manual_annotated_anonymized.json: 233 sentences
  ✓ 1921-05-19_manual_annotated_anonymized.json: 229 sentences
  ✓ 1921-08-29_manual_annotated_anonymized.json: 207 sentences
  ✓ 1921-09-26_manual_annotated_anonymized.json: 313 sentences
  ✓ 1922-04-24_manual_annotated_anonymized.json: 179 sentences
  ✓ 1922-05-29_manual_annotated_anonymized.json: 475 sentences
  ✓ 1923-09-24_manual_annotated_anonymized.json: 489 sentences
  ✓ 1924-11-26_manual_annotated_anonymized.json: 261 sentences
  ✓ 1925-11-11_manual_annotated_anonymized.json: 231 sentences
  ✓ 1926-02-17_manual_annotated_anonymized.json: 558 sentences
  ✓ 1926-03-

In [4]:
"""
Step 1.4: Split corpus into train/dev/test sets
- Test: 10% of tokens
- Dev: 8-10% of tokens
- Train: ~80% of tokens
Splits by complete protocols (files), not by individual sentences

This version EXCLUDES entity types: EVENT and UNK (everywhere):
- They are not converted to BIO tags
- They are not counted in per-file entity totals
- They do not appear in "ENTITY COUNTS BY TYPE"
- File distribution "Total entities" reflects the excluded setting
"""
# ====== CONFIG ======
# Output paths of the three BIO-format TSV files written by the main block.
TRAIN_FILE = "train.tsv"
DEV_FILE = "dev.tsv"
TEST_FILE = "test.tsv"

# Seed passed to split_corpus(), which seeds the global `random` module
# before shuffling the file list.
RANDOM_SEED = 42
# Target fractions of the corpus *token* count. Whole files are assigned, so
# the achieved percentages can overshoot these targets.
TEST_SIZE = 0.10   # 10% for test
DEV_SIZE = 0.09    # 9% for dev
# Train will be remaining (~81%)

# Exclude these entity types everywhere: spans whose unified type is in this
# set are neither counted nor converted to BIO tags.
EXCLUDED_ENTITY_TYPES = {"EVENT", "UNK"}

# NOTE: These come from your environment. Keep as-is.
# - FILTERED_DIR
# - Text
# - preprocess_words


def unify_category(category: str) -> str:
    """Collapse fine-grained entity categories into their unified category.

    ``LOC_ADDRESS`` becomes ``LOC``; ``ORG_GPE`` and ``ORG_POL`` become
    ``ORG``. Any other category is returned unchanged.
    """
    merged_into = {
        'LOC_ADDRESS': 'LOC',
        'ORG_GPE': 'ORG',
        'ORG_POL': 'ORG',
    }
    try:
        return merged_into[category]
    except KeyError:
        # Not a category that gets merged: keep it as it is.
        return category


def get_ne_layer(data: dict):
    """Return the gold NER layer of a document, or None if it has none.

    ``ne_gold_a`` is preferred; ``ne_gold_b`` is only used when no layer
    named ``ne_gold_a`` exists. Within one name, the first matching entry
    of ``data["layers"]`` wins.
    """
    for wanted in ("ne_gold_a", "ne_gold_b"):
        for layer in data.get("layers", []):
            if layer.get("name") != wanted:
                continue
            if layer:
                return layer
            # First match for this name was falsy: fall through to the next name.
            break
    return None


def get_file_statistics(file_path: str) -> dict:
    """Count tokens, sentences and entities of one annotated JSON file.

    The text is re-tokenized with ``preprocess_words`` and sentence-tagged.
    Entities are the spans of the gold NER layer whose unified type is not
    in ``EXCLUDED_ENTITY_TYPES``.

    Returns:
        A dict with the integer counts ``'tokens'``, ``'sentences'`` and
        ``'entities'``.
    """
    with open(file_path, encoding="utf-8") as fh:
        doc = json.load(fh)

    tagged = preprocess_words(Text(doc.get("text", "")))
    tagged.tag_layer('sentences')

    token_count = len(tagged['words'])
    sentence_count = len(tagged['sentences'])

    # A file without a gold layer simply contributes zero entities.
    gold_layer = get_ne_layer(doc)
    gold_spans = gold_layer.get("spans", []) if gold_layer else []
    entity_count = sum(
        1
        for span in gold_spans
        if unify_category(span.get("annotations", [{}])[0].get("tag", "?"))
        not in EXCLUDED_ENTITY_TYPES
    )

    return {
        'tokens': token_count,
        'sentences': sentence_count,
        'entities': entity_count,
    }


def convert_file_to_bio_sentences(file_path: str):
    """Convert one annotated JSON file into BIO-tagged sentences.

    Works like a plain BIO conversion, except that spans whose unified type
    is in ``EXCLUDED_ENTITY_TYPES`` are ignored, so their words keep the tag
    'O' unless another span tags them.

    Returns:
        A list of sentences, each a list of ``(token, bio_tag)`` tuples in
        document order; ``token`` is the word's normalized form when it has
        one, otherwise its surface text. Sentences without words are omitted.
    """
    with open(file_path, encoding="utf-8") as fh:
        doc = json.load(fh)

    tagged = preprocess_words(Text(doc.get("text", "")))
    tagged.tag_layer('sentences')

    # One bucket of word records per sentence that contains at least one word.
    buckets = []
    for sent_id, sentence in enumerate(tagged['sentences']):
        bucket = []
        for token in tagged['words']:
            # Skip words that are not fully inside this sentence.
            if token.start < sentence.start or token.end > sentence.end:
                continue
            normalized = token.annotations[0].get('normalized_form') if token.annotations else None
            bucket.append({
                "norm": normalized or token.text,
                "start": token.start,
                "end": token.end,
                "sent_id": sent_id,
                "bio_tag": 'O',
            })
        if bucket:
            buckets.append(bucket)

    # Flat, document-ordered view of the same records for the overlap search.
    records = [record for bucket in buckets for record in bucket]

    gold_layer = get_ne_layer(doc)
    gold_spans = gold_layer.get("spans", []) if gold_layer else []
    for span in gold_spans:
        span_start, span_end = span["base_span"]
        label = unify_category(span.get("annotations", [{}])[0].get("tag", "?"))

        # Excluded types never produce B-/I- tags.
        if label in EXCLUDED_ENTITY_TYPES:
            continue

        hits = [r for r in records if r["start"] < span_end and r["end"] > span_start]
        for position, record in enumerate(hits):
            prefix = 'B' if position == 0 else 'I'
            record['bio_tag'] = f'{prefix}-{label}'

    # Buckets already are the per-sentence grouping, in sentence order.
    return [
        [(record['norm'], record['bio_tag']) for record in bucket]
        for bucket in buckets
    ]


def split_corpus(input_dir: str, test_size=0.10, dev_size=0.09, random_seed=42):
    """Split the corpus into train/dev/test by whole files (protocols).

    The global ``random`` module is seeded with ``random_seed``, the file
    list is shuffled, and files are handed out greedily: to test until its
    token count reaches ``test_size`` of all tokens, then to dev until it
    reaches ``dev_size``, and the rest to train. Because whole files are
    assigned, test and dev can end up above their targets.

    Returns:
        A dict with the keys ``'train'``, ``'dev'`` and ``'test'``; each
        value is a dict with ``'files'`` (names, in assignment order),
        ``'data'`` (BIO sentences), ``'stats'`` (protocol / sentence / token
        / entity totals) and ``'entities'`` (count of ``B-`` tags per type).
    """
    random.seed(random_seed)

    # Per-file statistics, keyed by file name in sorted order.
    json_names = sorted(name for name in os.listdir(input_dir) if name.endswith(".json"))
    files_stats = {
        name: get_file_statistics(os.path.join(input_dir, name))
        for name in json_names
    }

    total_tokens = sum(entry['tokens'] for entry in files_stats.values())
    quota = {
        'test': int(total_tokens * test_size),
        'dev': int(total_tokens * dev_size),
    }

    shuffled = list(files_stats.keys())
    random.shuffle(shuffled)

    # Greedy allocation: the first split still below its quota takes the file.
    assigned = {'test': [], 'dev': [], 'train': []}
    filled = {'test': 0, 'dev': 0}
    for name in shuffled:
        size = files_stats[name]['tokens']
        for part in ('test', 'dev'):
            if filled[part] < quota[part]:
                assigned[part].append(name)
                filled[part] += size
                break
        else:
            assigned['train'].append(name)

    def load_files_data(file_list):
        """Load BIO sentences for the files and add up their statistics."""
        collected = []
        totals = defaultdict(int)
        per_type = defaultdict(int)

        for name in sorted(file_list):
            bio_sentences = convert_file_to_bio_sentences(os.path.join(input_dir, name))
            collected.extend(bio_sentences)

            totals['protocols'] += 1
            for key in ('sentences', 'tokens', 'entities'):
                # 'entities' already excludes EVENT/UNK.
                totals[key] += files_stats[name][key]

            # Entities by type are counted from the BIO output (B- tags only).
            for sentence in bio_sentences:
                for _, tag in sentence:
                    if tag.startswith('B-'):
                        per_type[tag[2:]] += 1

        return collected, dict(totals), dict(per_type)

    result = {}
    for part in ('train', 'dev', 'test'):
        part_data, part_stats, part_entities = load_files_data(assigned[part])
        result[part] = {
            'files': assigned[part],
            'data': part_data,
            'stats': part_stats,
            'entities': part_entities,
        }
    return result


def write_bio_file(sentences, output_file: str):
    """Save BIO-tagged sentences as a TSV file.

    Each ``(token, bio_tag)`` pair becomes one ``token<TAB>bio_tag`` line,
    and every sentence is followed by an empty line. An existing file at
    ``output_file`` is overwritten.
    """
    with open(output_file, 'w', encoding='utf-8') as out:
        for sentence in sentences:
            rows = [f"{token}\t{bio_tag}\n" for token, bio_tag in sentence]
            rows.append("\n")  # sentence separator
            out.write("".join(rows))


def print_statistics_table(splits_data: dict):
    """Print the corpus statistics report to standard output.

    The report has three sections: per-split protocol / sentence / token /
    entity counts with a TOTAL row, each split's share of all tokens, and
    entity counts per type for train, dev and test with a row total.
    Missing statistics count as 0.

    Args:
        splits_data: Dict with the keys ``'train'``, ``'dev'`` and
            ``'test'``; each value holds a ``'stats'`` dict and an
            ``'entities'`` dict (entity type -> count).
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80
    split_order = ['train', 'dev', 'test']
    stat_columns = ('protocols', 'sentences', 'tokens', 'entities')

    def section(title):
        # Blank line, rule, title, rule.
        print("\n" + heavy_rule)
        print(title)
        print(heavy_rule)

    def stats_row(label, cells):
        return f"{label:<10} " + " ".join(f"{cell:<12}" for cell in cells)

    # --- Section 1: absolute counts -------------------------------------
    section("CORPUS STATISTICS")
    print(stats_row('Split', ('Protocols', 'Sentences', 'Tokens', 'Entities')))
    print(light_rule)
    for split_name in split_order:
        split_stats = splits_data[split_name]['stats']
        print(stats_row(split_name.capitalize(),
                        [split_stats.get(column, 0) for column in stat_columns]))
    print(light_rule)

    totals = {
        column: sum(part['stats'].get(column, 0) for part in splits_data.values())
        for column in stat_columns
    }
    print(stats_row('TOTAL', [totals[column] for column in stat_columns]))

    # --- Section 2: token shares ----------------------------------------
    section("PERCENTAGES (by tokens)")
    all_tokens = totals['tokens']
    for split_name in split_order:
        split_tokens = splits_data[split_name]['stats'].get('tokens', 0)
        if all_tokens:
            share = 100 * split_tokens / all_tokens
        else:
            share = 0.0
        print(f"{split_name.capitalize():<10} {share:>6.2f}%")

    # --- Section 3: entity counts per type ------------------------------
    section("ENTITY COUNTS BY TYPE (EVENT/UNK excluded)")
    seen_types = set()
    for part in splits_data.values():
        seen_types.update(part.get('entities', {}).keys())

    print(f"{'Type':<15} {'Train':<10} {'Dev':<10} {'Test':<10} {'Total':<10}")
    print(light_rule)
    for etype in sorted(seen_types):
        per_split = [splits_data[name]['entities'].get(etype, 0) for name in split_order]
        cells = per_split + [sum(per_split)]
        print(f"{etype:<15} " + " ".join(f"{cell:<10}" for cell in cells))

    print(heavy_rule)


def save_file_distribution(splits_data: dict, output_file='file_distribution.txt'):
    """Write a plain-text report listing the files of every split.

    After a header and a note naming the excluded entity types, the report
    has one section each for train, dev and test: the split's token,
    sentence and entity totals (entities exclude EVENT/UNK) followed by its
    file names in sorted order. A confirmation line is printed afterwards.

    Args:
        splits_data: Dict with the keys ``'train'``, ``'dev'`` and
            ``'test'``; each value holds ``'files'`` and ``'stats'``.
        output_file: Path of the report to write (overwritten).
    """
    heavy_rule = "=" * 80 + "\n"
    light_rule = "-" * 80 + "\n"

    with open(output_file, 'w', encoding='utf-8') as report:
        emit = report.write

        emit(heavy_rule)
        emit("FILE DISTRIBUTION ACROSS SPLITS\n")
        emit(heavy_rule)
        excluded = ', '.join(sorted(EXCLUDED_ENTITY_TYPES))
        emit(f"\nNOTE: Total entities EXCLUDE types: {excluded}\n\n")

        for split_name in ['train', 'dev', 'test']:
            split = splits_data[split_name]
            split_files = split['files']
            split_stats = split['stats']

            emit(f"{split_name.upper()} SET ({len(split_files)} files)\n")
            emit(light_rule)
            emit(f"Total tokens: {split_stats.get('tokens', 0)}\n")
            emit(f"Total sentences: {split_stats.get('sentences', 0)}\n")
            emit(f"Total entities: {split_stats.get('entities', 0)}\n\n")

            emit("Files:\n")
            emit("".join(f"  - {fname}\n" for fname in sorted(split_files)))
            emit("\n\n")

        emit(heavy_rule)

    print(f" File distribution saved to {output_file}")


if __name__ == "__main__":
    # Local import: only this entry point needs it.
    from contextlib import redirect_stdout

    print("=" * 80)
    print("STEP 1.4: Splitting corpus into train/dev/test (EVENT/UNK excluded)")
    print("=" * 80)
    print(f"Random seed: {RANDOM_SEED}")
    print(f"Excluded entity types: {', '.join(sorted(EXCLUDED_ENTITY_TYPES))}")
    print(f"Target split: Train ~{100*(1-TEST_SIZE-DEV_SIZE):.0f}%, Dev ~{100*DEV_SIZE:.0f}%, Test ~{100*TEST_SIZE:.0f}%")

    splits_data = split_corpus(FILTERED_DIR, TEST_SIZE, DEV_SIZE, RANDOM_SEED)

    # One BIO-format TSV per split.
    print("Writing split files...")
    for split_name, split_path in (('train', TRAIN_FILE), ('dev', DEV_FILE), ('test', TEST_FILE)):
        write_bio_file(splits_data[split_name]['data'], split_path)
        print(f"   {split_path}")

    # Show the statistics on screen ...
    print_statistics_table(splits_data)

    # ... and save the same report to a file. redirect_stdout restores
    # sys.stdout even if the table printer raises, which the manual
    # swap-and-restore did not guarantee.
    with open('corpus_statistics.txt', 'w', encoding='utf-8') as f, redirect_stdout(f):
        print_statistics_table(splits_data)
    print(f" Statistics saved to: corpus_statistics.txt")

    save_file_distribution(splits_data, 'file_distribution.txt')

    print("\n Train/dev/test split complete!")
    print("\nFiles created:")
    print(f"  - {TRAIN_FILE}")
    print(f"  - {DEV_FILE}")
    print(f"  - {TEST_FILE}")
    print(f"  - corpus_statistics.txt")
    print(f"  - file_distribution.txt")

STEP 1.4: Splitting corpus into train/dev/test (EVENT/UNK excluded)
Random seed: 42
Excluded entity types: EVENT, UNK
Target split: Train ~81%, Dev ~9%, Test ~10%
Writing split files...
   train.tsv
   dev.tsv
   test.tsv

CORPUS STATISTICS
Split      Protocols    Sentences    Tokens       Entities    
--------------------------------------------------------------------------------
Train      35           17643        225676       17483       
Dev        4            2170         27060        2121        
Test       5            2790         34773        2958        
--------------------------------------------------------------------------------
TOTAL      44           22603        287509       22562       

PERCENTAGES (by tokens)
Train       78.49%
Dev          9.41%
Test        12.09%

ENTITY COUNTS BY TYPE (EVENT/UNK excluded)
Type            Train      Dev        Test       Total     
--------------------------------------------------------------------------------
LAW            