In [4]:
#!/usr/bin/env python3
"""
Script to split JSON dataset into train/validation/test sets
ensuring all labels appear in test set with good representation.
"""

import json
import os
import random
from collections import defaultdict, Counter
from typing import List, Dict, Tuple, Set

def load_json_data(filepath: str) -> List[Dict]:
    """Load JSON data from file."""
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)

def save_json_data(data: List[Dict], filepath: str) -> None:
    """Save data to JSON file."""
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

def get_labels_from_sample(sample: Dict) -> Set[str]:
    """Extract unique labels from a sample."""
    return {label['label'] for label in sample.get('labels', [])}

def analyze_label_distribution(data: List[Dict]) -> Dict[str, int]:
    """Analyze label frequency across the dataset."""
    label_counts = Counter()
    for sample in data:
        labels = get_labels_from_sample(sample)
        label_counts.update(labels)
    return dict(label_counts)

def get_samples_with_labels(data: List[Dict], target_labels: Set[str]) -> List[int]:
    """Get indices of samples that contain any of the target labels."""
    indices = []
    for i, sample in enumerate(data):
        sample_labels = get_labels_from_sample(sample)
        if sample_labels.intersection(target_labels):
            indices.append(i)
    return indices

def stratified_split(data: List[Dict], test_size: int = 40, val_size: int = 25) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """
    Split data ensuring all labels appear in all sets with good representation.

    Strategy:
    1. Identify all unique labels and their frequencies
    2. Ensure each label appears in all sets (except BIC stays in training)
    3. Use stratified approach to maintain proportional distribution
    4. Prioritize rare labels to appear in multiple sets
    """
    print(f"Total samples: {len(data)}")

    # Analyze label distribution
    label_counts = analyze_label_distribution(data)
    all_labels = set(label_counts.keys())
    print(f"Total unique labels: {len(all_labels)}")
    print(f"Label distribution: {label_counts}")

    # BIC should only be in training set
    bic_only_labels = {'BIC'}
    labels_for_all_sets = all_labels - bic_only_labels

    # Find rare labels (appearing in <=5 samples)
    rare_labels = {label for label, count in label_counts.items() if count <= 5}
    print(f"Rare labels (<=5 samples): {rare_labels}")

    # Create sample-to-labels mapping
    sample_labels = [get_labels_from_sample(sample) for sample in data]

    # Initialize sets
    test_indices = set()
    val_indices = set()
    train_indices = set()

    # Step 1: Distribute samples with rare labels across all sets
    for label in rare_labels:
        if label in bic_only_labels:
            continue  # BIC will be handled separately

        samples_with_label = [i for i, labels in enumerate(sample_labels)
                             if label in labels]

        if len(samples_with_label) == 0:
            continue

        # Ensure this rare label appears in all sets
        random.shuffle(samples_with_label)

        # Distribute: at least 1 in each set, rest proportionally
        if len(samples_with_label) >= 3:
            # At least 1 in each set
            test_indices.add(samples_with_label[0])
            val_indices.add(samples_with_label[1])
            train_indices.add(samples_with_label[2])

            # Distribute remaining proportionally
            remaining = samples_with_label[3:]
            if remaining:
                # Rough proportion: 40% test, 25% val, 35% train
                n_test = max(1, int(len(remaining) * 0.4))
                n_val = max(1, int(len(remaining) * 0.25))

                test_indices.update(remaining[:n_test])
                val_indices.update(remaining[n_test:n_test+n_val])
                train_indices.update(remaining[n_test+n_val:])
        else:
            # Very rare labels: put in training set primarily
            train_indices.update(samples_with_label)

        print(f"Distributed {len(samples_with_label)} samples for rare label '{label}'")

    # Step 2: Handle BIC samples (only in training)
    bic_samples = [i for i, labels in enumerate(sample_labels)
                   if 'BIC' in labels and i not in train_indices]
    train_indices.update(bic_samples)
    print(f"Added {len(bic_samples)} BIC samples to training set")

    # Step 3: Fill remaining slots with common labels
    used_indices = test_indices | val_indices | train_indices
    remaining_samples = [i for i in range(len(data)) if i not in used_indices]

    # Shuffle for randomness
    random.shuffle(remaining_samples)

    # Calculate how many more samples we need for each set
    test_needed = max(0, test_size - len(test_indices))
    val_needed = max(0, val_size - len(val_indices))

    # Fill test set
    if test_needed > 0:
        test_additional = remaining_samples[:test_needed]
        test_indices.update(test_additional)
        remaining_samples = remaining_samples[test_needed:]

    # Fill validation set
    if val_needed > 0:
        val_additional = remaining_samples[:val_needed]
        val_indices.update(val_additional)
        remaining_samples = remaining_samples[val_needed:]

    # Rest goes to training
    train_indices.update(remaining_samples)

    # Create final datasets
    test_data = [data[i] for i in test_indices]
    val_data = [data[i] for i in val_indices]
    train_data = [data[i] for i in train_indices]

    # Verify label distribution
    test_labels = set()
    val_labels = set()
    train_labels = set()

    for sample in test_data:
        test_labels.update(get_labels_from_sample(sample))
    for sample in val_data:
        val_labels.update(get_labels_from_sample(sample))
    for sample in train_data:
        train_labels.update(get_labels_from_sample(sample))

    print(f"\nLabel coverage:")
    print(f"Test set: {len(test_labels)} labels")
    print(f"Validation set: {len(val_labels)} labels")
    print(f"Training set: {len(train_labels)} labels")

    # Check for labels missing from train/val (except BIC)
    missing_from_test = labels_for_all_sets - test_labels
    missing_from_val = labels_for_all_sets - val_labels
    missing_from_train = (all_labels - train_labels)

    if missing_from_test:
        print(f"Warning: Labels missing from test set: {missing_from_test}")
    if missing_from_val:
        print(f"Warning: Labels missing from validation set: {missing_from_val}")
    if missing_from_train:
        print(f"Warning: Labels missing from training set: {missing_from_train}")

    return train_data, val_data, test_data

def print_split_statistics(train_data: List[Dict], val_data: List[Dict], test_data: List[Dict]) -> None:
    """Print statistics about the data split."""
    print("\n" + "="*50)
    print("DATASET SPLIT STATISTICS")
    print("="*50)

    datasets = [
        ("Training", train_data),
        ("Validation", val_data),
        ("Test", test_data)
    ]

    for name, data in datasets:
        print(f"\n{name} Set:")
        print(f"  Samples: {len(data)}")

        # Count labels
        label_counts = analyze_label_distribution(data)
        print(f"  Unique labels: {len(label_counts)}")
        print(f"  Total label instances: {sum(label_counts.values())}")

        # Show label distribution
        sorted_labels = sorted(label_counts.items(), key=lambda x: x[1], reverse=True)
        print(f"  Label distribution: {dict(sorted_labels)}")

def main():
    """Main function to execute the dataset split."""
    # Set random seed for reproducibility
    random.seed(42)

    # Define paths
    input_path = "../../../data/original/golden_dataset_with_spans_norm.json"
    output_dir = "../../../data/original/granular_dataset_split_norm"

    # Load data
    print("Loading data...")
    try:
        data = load_json_data(input_path)
    except FileNotFoundError:
        print(f"Error: Could not find input file at {input_path}")
        return
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON in file {input_path}")
        return

    print(f"Loaded {len(data)} samples")

    # Split data
    print("\nSplitting data...")
    train_data, val_data, test_data = stratified_split(data, test_size=40, val_size=25)

    # Save split datasets
    print(f"\nSaving datasets to {output_dir}...")
    save_json_data(train_data, os.path.join(output_dir, "train_norm.json"))
    save_json_data(val_data, os.path.join(output_dir, "validation_norm.json"))
    save_json_data(test_data, os.path.join(output_dir, "test_norm.json"))

    # Print statistics
    print_split_statistics(train_data, val_data, test_data)

    print(f"\n✓ Dataset split completed successfully!")
    print(f"Files saved in: {output_dir}")

if __name__ == "__main__":
    main()

Loading data...
Loaded 160 samples

Splitting data...
Total samples: 160
Total unique labels: 21
Label distribution: {'ZAHLUNG': 12, 'LINK': 7, 'GESENDET_MIT': 25, 'WOHNORT': 58, 'NACHNAME': 153, 'VORNAME': 145, 'HAUSNUMMER': 56, 'STRASSE': 59, 'VERTRAGSNUMMER': 61, 'POSTLEITZAHL': 55, 'DATUM': 53, 'IBAN': 6, 'TITEL': 13, 'TELEFONNUMMER': 32, 'ZÄHLERNUMMER': 35, 'ZÄHLERSTAND': 10, 'FAX': 4, 'FIRMA': 18, 'BIC': 1, 'BANK': 3, 'EMAIL': 13}
Rare labels (<=5 samples): {'FAX', 'BIC', 'BANK'}
Distributed 4 samples for rare label 'FAX'
Distributed 3 samples for rare label 'BANK'
Added 0 BIC samples to training set

Label coverage:
Test set: 20 labels
Validation set: 20 labels
Training set: 21 labels

Saving datasets to ../../../data/original/granular_dataset_split_norm...

DATASET SPLIT STATISTICS

Training Set:
  Samples: 95
  Unique labels: 21
  Total label instances: 510
  Label distribution: {'NACHNAME': 93, 'VORNAME': 88, 'STRASSE': 42, 'WOHNORT': 40, 'HAUSNUMMER': 39, 'POSTLEITZAHL': 38,