In [3]:
#!/usr/bin/env python3
"""
Script to split JSON dataset into train/validation/test sets
ensuring all labels appear in test set with good representation.
"""

import json
import os
import random
from collections import defaultdict, Counter
from typing import List, Dict, Tuple, Set

def load_json_data(filepath: str) -> List[Dict]:
    """Load JSON data from file."""
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)

def save_json_data(data: List[Dict], filepath: str) -> None:
    """Save data to JSON file."""
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

def get_labels_from_sample(sample: Dict) -> Set[str]:
    """Extract unique labels from a sample."""
    return {label['label'] for label in sample.get('labels', [])}

def analyze_label_distribution(data: List[Dict]) -> Dict[str, int]:
    """Analyze label frequency across the dataset."""
    label_counts = Counter()
    for sample in data:
        labels = get_labels_from_sample(sample)
        label_counts.update(labels)
    return dict(label_counts)

def get_samples_with_labels(data: List[Dict], target_labels: Set[str]) -> List[int]:
    """Get indices of samples that contain any of the target labels."""
    indices = []
    for i, sample in enumerate(data):
        sample_labels = get_labels_from_sample(sample)
        if sample_labels.intersection(target_labels):
            indices.append(i)
    return indices

def stratified_split(data: List[Dict], test_size: int = 50, val_size: int = 30) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """
    Split data ensuring all labels appear in test set with good representation.

    Strategy:
    1. Identify all unique labels
    2. For rare labels, ensure they appear multiple times in test set
    3. Use stratified sampling to maintain label distribution
    4. Remaining samples go to train/val
    """
    print(f"Total samples: {len(data)}")

    # Analyze label distribution
    label_counts = analyze_label_distribution(data)
    all_labels = set(label_counts.keys())
    print(f"Total unique labels: {len(all_labels)}")
    print(f"Label distribution: {label_counts}")

    # Find rare labels (appearing in <5% of samples)
    rare_threshold = len(data) * 0.05
    rare_labels = {label for label, count in label_counts.items() if count < rare_threshold}
    print(f"Rare labels (< {rare_threshold:.1f} samples): {rare_labels}")

    # Create sample-to-labels mapping
    sample_labels = [get_labels_from_sample(sample) for sample in data]

    # Initialize sets
    test_indices = set()
    used_indices = set()
    labels_in_test = set()

    # Step 1: Ensure each label appears at least twice in test set
    for label in all_labels:
        samples_with_label = [i for i, labels in enumerate(sample_labels)
                             if label in labels and i not in used_indices]

        if len(samples_with_label) == 0:
            print(f"Warning: Label '{label}' not found in available samples")
            continue

        # For rare labels, try to get more samples in test set
        target_count = min(3 if label in rare_labels else 2, len(samples_with_label))

        # Randomly select samples for this label
        selected = random.sample(samples_with_label, target_count)
        test_indices.update(selected)
        used_indices.update(selected)
        labels_in_test.add(label)

        print(f"Added {len(selected)} samples for label '{label}' to test set")

    # Step 2: Fill remaining test set slots with diverse samples
    remaining_samples = [i for i in range(len(data)) if i not in used_indices]
    remaining_needed = test_size - len(test_indices)

    if remaining_needed > 0 and len(remaining_samples) >= remaining_needed:
        # Prefer samples with multiple labels for diversity
        remaining_samples.sort(key=lambda i: len(sample_labels[i]), reverse=True)
        additional_test = remaining_samples[:remaining_needed]
        test_indices.update(additional_test)
        used_indices.update(additional_test)

        print(f"Added {len(additional_test)} additional samples to test set")

    # Step 3: Split remaining samples into train/val
    remaining_samples = [i for i in range(len(data)) if i not in used_indices]
    random.shuffle(remaining_samples)

    val_indices = set(remaining_samples[:val_size])
    train_indices = set(remaining_samples[val_size:])

    # Create final datasets
    test_data = [data[i] for i in test_indices]
    val_data = [data[i] for i in val_indices]
    train_data = [data[i] for i in train_indices]

    # Verify all labels are in test set
    test_labels = set()
    for sample in test_data:
        test_labels.update(get_labels_from_sample(sample))

    missing_labels = all_labels - test_labels
    if missing_labels:
        print(f"Warning: Labels missing from test set: {missing_labels}")
    else:
        print("✓ All labels are present in test set")

    return train_data, val_data, test_data

def print_split_statistics(train_data: List[Dict], val_data: List[Dict], test_data: List[Dict]) -> None:
    """Print statistics about the data split."""
    print("\n" + "="*50)
    print("DATASET SPLIT STATISTICS")
    print("="*50)

    datasets = [
        ("Training", train_data),
        ("Validation", val_data),
        ("Test", test_data)
    ]

    for name, data in datasets:
        print(f"\n{name} Set:")
        print(f"  Samples: {len(data)}")

        # Count labels
        label_counts = analyze_label_distribution(data)
        print(f"  Unique labels: {len(label_counts)}")
        print(f"  Total label instances: {sum(label_counts.values())}")

        # Show label distribution
        sorted_labels = sorted(label_counts.items(), key=lambda x: x[1], reverse=True)
        print(f"  Label distribution: {dict(sorted_labels)}")

def main():
    """Main function to execute the dataset split."""
    # Set random seed for reproducibility
    random.seed(42)

    # Define paths
    input_path = "./original_with_spans.json"
    output_dir = "./granular_dataset_split"

    # Load data
    print("Loading data...")
    try:
        data = load_json_data(input_path)
    except FileNotFoundError:
        print(f"Error: Could not find input file at {input_path}")
        return
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON in file {input_path}")
        return

    print(f"Loaded {len(data)} samples")

    # Split data
    print("\nSplitting data...")
    train_data, val_data, test_data = stratified_split(data, test_size=50, val_size=30)

    # Save split datasets
    print(f"\nSaving datasets to {output_dir}...")
    save_json_data(train_data, os.path.join(output_dir, "train.json"))
    save_json_data(val_data, os.path.join(output_dir, "validation.json"))
    save_json_data(test_data, os.path.join(output_dir, "test.json"))

    # Print statistics
    print_split_statistics(train_data, val_data, test_data)

    print(f"\n✓ Dataset split completed successfully!")
    print(f"Files saved in: {output_dir}")

if __name__ == "__main__":
    main()

Loading data...
Loaded 160 samples

Splitting data...
Total samples: 160
Total unique labels: 21
Label distribution: {'ZAHLUNG': 22, 'GESENDET_MIT': 25, 'LINK': 7, 'WOHNORT': 58, 'HAUSNUMMER': 57, 'VORNAME': 145, 'NACHNAME': 153, 'POSTLEITZAHL': 58, 'VERTRAGSNUMMER': 63, 'STRASSE': 59, 'TITEL': 13, 'DATUM': 53, 'IBAN': 6, 'ZÄHLERNUMMER': 36, 'TELEFONNUMMER': 33, 'ZÄHLERSTAND': 10, 'FIRMA': 19, 'FAX': 5, 'BANK': 3, 'BIC': 1, 'EMAIL': 13}
Rare labels (< 8.0 samples): {'BANK', 'BIC', 'LINK', 'IBAN', 'FAX'}
Added 2 samples for label 'ZAHLUNG' to test set
Added 3 samples for label 'LINK' to test set
Added 2 samples for label 'TELEFONNUMMER' to test set
Added 2 samples for label 'VERTRAGSNUMMER' to test set
Added 2 samples for label 'STRASSE' to test set
Added 2 samples for label 'WOHNORT' to test set
Added 3 samples for label 'BANK' to test set
Added 2 samples for label 'DATUM' to test set
Added 2 samples for label 'EMAIL' to test set
Added 2 samples for label 'GESENDET_MIT' to test set
Add