In [6]:
import os
import random
import logging

# Configure the root logger once (timestamp - level - message, INFO and above)
# and obtain the logger named after this module for use by the code below.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Classes recognised in the label files, in the order they are concatenated
# into every output split.
_CLASS_NAMES = ('TEXT', 'FORM', 'TABLE')


def _sample_in_order(items, limit):
    """Return at most ``limit`` randomly chosen items, kept in original order."""
    if len(items) <= limit:
        return items
    # Sample indices rather than items so the survivors can be re-sorted into
    # their original relative order.
    chosen = sorted(random.sample(range(len(items)), limit))
    return [items[i] for i in chosen]


def _split_in_order(items, train_ratio, val_ratio):
    """Cut ``items`` into leading train, middle val and trailing test slices."""
    train_end = int(len(items) * train_ratio)
    val_end = train_end + int(len(items) * val_ratio)
    return {
        'train': items[:train_end],
        'val': items[train_end:val_end],
        'test': items[val_end:],
    }


def _write_rows(path, rows):
    """Write ``rows`` to ``path`` as UTF-8 text, one row per line."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f"{row}\n" for row in rows)


def simple_ordered_split(input_dir, output_dir, samples_per_class=4524, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_seed=42):
    """
    Build a class-balanced, order-preserving train/val/test split.

    1. Read every ``*_lines.txt`` / ``*_labels.txt`` pair in ``input_dir``
       (in sorted filename order) and group the lines by label
       (TEXT, FORM, TABLE). Lines with any other label are ignored.
    2. For each class, randomly pick at most ``samples_per_class`` lines
       (all of them if the class is smaller) and keep the picked lines in
       their original order.
    3. Per class, the first ``train_ratio`` share goes to train, the next
       ``val_ratio`` share to val, and whatever remains to test.
    4. Write ``lines.txt`` and ``labels.txt`` into ``output_dir/train``,
       ``output_dir/val`` and ``output_dir/test``.

    Args:
        input_dir: Directory containing the ``*_lines.txt`` and matching
            ``*_labels.txt`` files.
        output_dir: Directory in which the ``train``/``val``/``test``
            sub-directories are created.
        samples_per_class: Upper bound on the number of lines kept per class.
        train_ratio: Share of each class assigned to the training split.
        val_ratio: Share of each class assigned to the validation split.
        test_ratio: Expected share of the test split. The test split always
            receives the remainder, so this value is only used for
            validation; a warning is logged when the three ratios do not
            sum to 1.
        random_seed: Seed passed to ``random.seed`` before sampling.

    Returns:
        A dict with keys ``'train'``, ``'validation'`` and ``'test'``, each
        mapping to ``{'examples': int, 'class_counts': {class: int}}``.

    Raises:
        ValueError: If any ratio is negative or ``train_ratio + val_ratio``
            exceeds 1.
    """
    tolerance = 1e-9
    if min(train_ratio, val_ratio, test_ratio) < 0:
        raise ValueError("train_ratio, val_ratio and test_ratio must be non-negative")
    if train_ratio + val_ratio > 1 + tolerance:
        raise ValueError("train_ratio + val_ratio must not exceed 1")
    if abs(train_ratio + val_ratio + test_ratio - 1) > tolerance:
        logger.warning(
            f"Ratios sum to {train_ratio + val_ratio + test_ratio:.3f}, not 1; "
            "the test split receives whatever remains after train and val"
        )

    random.seed(random_seed)

    # Create output directories
    os.makedirs(output_dir, exist_ok=True)
    split_dirs = {name: os.path.join(output_dir, name) for name in ('train', 'val', 'test')}
    for directory in split_dirs.values():
        os.makedirs(directory, exist_ok=True)

    # Step 1: Load all lines and labels from all files
    lines_by_class = {cls: [] for cls in _CLASS_NAMES}
    lines_suffix = '_lines.txt'

    # os.listdir returns entries in arbitrary order; sorting makes the load
    # order (and therefore the seeded sampling) reproducible.
    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith(lines_suffix):
            continue

        lines_file = os.path.join(input_dir, filename)
        # Derive the labels name from the file name only, so a directory
        # component that happens to contain '_lines.txt' is left untouched.
        labels_file = os.path.join(input_dir, filename[:-len(lines_suffix)] + '_labels.txt')

        if not os.path.exists(labels_file):
            logger.warning(f"Labels file not found for {lines_file}, skipping")
            continue

        # Read both files
        with open(lines_file, 'r', encoding='utf-8') as f:
            lines = [line.strip() for line in f]

        with open(labels_file, 'r', encoding='utf-8') as f:
            labels = [label.strip() for label in f]

        if len(lines) != len(labels):
            # zip() below stops at the shorter file; surface the mismatch
            # instead of dropping the tail silently.
            logger.warning(
                f"{lines_file} has {len(lines)} lines but {labels_file} has "
                f"{len(labels)} labels; using the first {min(len(lines), len(labels))} pairs"
            )

        # Group by class
        for line, label in zip(lines, labels):
            bucket = lines_by_class.get(label)
            if bucket is not None:
                bucket.append(line)

    logger.info(
        f"Loaded {len(lines_by_class['TEXT'])} TEXT lines, "
        f"{len(lines_by_class['FORM'])} FORM lines, "
        f"{len(lines_by_class['TABLE'])} TABLE lines"
    )

    # Step 2: Randomly sample from each class but keep order.
    # Classes are processed in a fixed order so the random stream is consumed
    # identically on every run.
    sampled = {cls: _sample_in_order(lines_by_class[cls], samples_per_class) for cls in _CLASS_NAMES}

    logger.info(
        f"Sampled {len(sampled['TEXT'])} TEXT, {len(sampled['FORM'])} FORM, "
        f"{len(sampled['TABLE'])} TABLE lines"
    )

    # Step 3: Split each class in order (not randomly)
    parts = {cls: _split_in_order(sampled[cls], train_ratio, val_ratio) for cls in _CLASS_NAMES}

    # Steps 4-6: Assemble, save and report every split
    logger.info("\nTrain/Val/Test split complete:")

    # (sub-directory / part key, key in the returned dict, log header)
    split_specs = (
        ('train', 'train', "Train set"),
        ('val', 'validation', "\nValidation set"),
        ('test', 'test', "\nTest set"),
    )

    stats = {}
    for part_key, result_key, header in split_specs:
        split_lines = []
        split_labels = []
        counts = {}
        for cls in _CLASS_NAMES:
            chunk = parts[cls][part_key]
            split_lines.extend(chunk)
            split_labels.extend([cls] * len(chunk))
            counts[cls] = len(chunk)

        _write_rows(os.path.join(split_dirs[part_key], "lines.txt"), split_lines)
        _write_rows(os.path.join(split_dirs[part_key], "labels.txt"), split_labels)

        total = len(split_lines)
        logger.info(f"{header}: {total} examples")
        for cls, count in sorted(counts.items()):
            # An empty split would otherwise divide by zero.
            share = count / total * 100 if total else 0.0
            logger.info(f"  {cls}: {count} examples ({share:.1f}%)")

        stats[result_key] = {'examples': total, 'class_counts': counts}

    # Return statistics
    return stats

# Example usage: split the balanced dataset and print the size of each split.
input_dir = "balanced_output"  # Folder holding the *_lines.txt / *_labels.txt pairs
output_dir = "train_val_test"  # Destination folder for the train/val/test sub-folders

splits_info = simple_ordered_split(
    input_dir=input_dir,
    output_dir=output_dir,
    samples_per_class=4524,  # Keep at most 4524 lines of every class
    train_ratio=0.7,
    val_ratio=0.15,
    test_ratio=0.15,
    random_seed=42,
)

print("\nSplit complete!")
# One summary line per split, in train / validation / test order.
for title, key in (("Train", "train"), ("Validation", "validation"), ("Test", "test")):
    print(f"{title} set: {splits_info[key]['examples']} examples")

2025-03-22 16:16:51,655 - INFO - Loaded 7673 TEXT lines, 7437 FORM lines, 4524 TABLE lines
2025-03-22 16:16:51,658 - INFO - Sampled 4524 TEXT, 4524 FORM, 4524 TABLE lines
2025-03-22 16:16:51,662 - INFO - 
Train/Val/Test split complete:
2025-03-22 16:16:51,662 - INFO - Train set: 9498 examples
2025-03-22 16:16:51,662 - INFO -   FORM: 3166 examples (33.3%)
2025-03-22 16:16:51,662 - INFO -   TABLE: 3166 examples (33.3%)
2025-03-22 16:16:51,663 - INFO -   TEXT: 3166 examples (33.3%)
2025-03-22 16:16:51,663 - INFO - 
Validation set: 2034 examples
2025-03-22 16:16:51,663 - INFO -   FORM: 678 examples (33.3%)
2025-03-22 16:16:51,663 - INFO -   TABLE: 678 examples (33.3%)
2025-03-22 16:16:51,663 - INFO -   TEXT: 678 examples (33.3%)
2025-03-22 16:16:51,663 - INFO - 
Test set: 2040 examples
2025-03-22 16:16:51,663 - INFO -   FORM: 680 examples (33.3%)
2025-03-22 16:16:51,663 - INFO -   TABLE: 680 examples (33.3%)
2025-03-22 16:16:51,664 - INFO -   TEXT: 680 examples (33.3%)



Split complete!
Train set: 9498 examples
Validation set: 2034 examples
Test set: 2040 examples
