In [1]:
import os
import random
import logging
from collections import defaultdict
import json

# Configure root logging once (INFO and above, timestamped messages) and
# create this module's logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

_LINES_SUFFIX = '_lines.txt'
_LABELS_SUFFIX = '_labels.txt'


def _read_documents(input_dir):
    """Load stripped lines and labels for every ``<doc_id>_lines.txt`` in *input_dir*.

    Returns:
        dict mapping doc_id -> {'lines': [...], 'labels': [...]} where both
        lists have the same length.
    """
    documents = {}
    # sorted(): os.listdir order is arbitrary, and it determines the order of
    # the per-class candidate lists, so sorting makes seeded sampling
    # reproducible across machines/filesystems.
    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith(_LINES_SUFFIX):
            continue

        # Strip only the trailing suffix (str.replace would also rewrite any
        # earlier occurrence of the suffix text inside the name).
        doc_id = filename[:-len(_LINES_SUFFIX)]
        lines_file = os.path.join(input_dir, filename)
        labels_file = os.path.join(input_dir, doc_id + _LABELS_SUFFIX)

        if not os.path.exists(labels_file):
            logger.warning(f"Labels file not found for {lines_file}, skipping")
            continue

        with open(lines_file, 'r', encoding='utf-8') as f:
            lines = [line.strip() for line in f]
        with open(labels_file, 'r', encoding='utf-8') as f:
            labels = [label.strip() for label in f]

        # Line i is labelled by label i; if the files disagree in length, keep
        # only the aligned prefix instead of crashing (labels longer) or
        # silently ignoring trailing lines (lines longer).
        if len(lines) != len(labels):
            n = min(len(lines), len(labels))
            logger.warning(
                f"Line/label count mismatch for {doc_id} "
                f"({len(lines)} lines, {len(labels)} labels); truncating to {n}"
            )
            lines, labels = lines[:n], labels[:n]

        documents[doc_id] = {'lines': lines, 'labels': labels}
    return documents


def _write_lines(path, items):
    """Write each item of *items* to *path* (UTF-8), one per line."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f"{item}\n" for item in items)


def create_balanced_dataset(input_dir, output_dir, target_count=4524, seed=None):
    """
    Create a class-balanced line dataset from per-document line/label files.

    The dataset is built by:
    1. Keeping all TABLE lines.
    2. Randomly sampling (without replacement) at most ``target_count`` TEXT
       lines and at most ``target_count`` FORM lines; a class with fewer
       lines is kept whole.
    3. Grouping the kept lines by document (documents in lexicographic ID
       order, lines in their original within-document order) so that a
       document-aware train/val/test split can be made later.

    Lines with any label other than TEXT, FORM or TABLE are dropped.

    Input: for each ``<doc_id>_lines.txt`` in ``input_dir`` a matching
    ``<doc_id>_labels.txt`` holding one label per line. Documents without a
    labels file are skipped with a warning.

    Output files written to ``output_dir``:
    - ``all_lines.txt`` / ``all_labels.txt``: aligned, one example per line.
    - ``document_boundaries.json``: ``{doc_id: {"start": i, "end": j}}`` with
      inclusive row indices into the two text files; documents with no kept
      lines are absent.

    Args:
        input_dir: Directory with processed ``*_lines.txt`` / ``*_labels.txt`` files.
        output_dir: Directory to save the balanced dataset (created if missing).
        target_count: Maximum number of TEXT and of FORM examples to keep
            (default 4524, the TABLE count of the corpus this was written for).
            Pass ``None`` to use the number of TABLE lines actually found.
        seed: Optional seed for reproducible sampling. When ``None`` the
            module-level ``random`` generator is used, so ``random.seed()``
            called by the caller still applies.

    Returns:
        dict mapping label -> number of examples in the balanced dataset.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Step 1: Read all files and organize by document
    documents = _read_documents(input_dir)
    logger.info(f"Read {len(documents)} documents")

    # Step 2: Index every line by class as (document_id, line_index)
    class_indices = defaultdict(list)
    for doc_id, doc in documents.items():
        for i, label in enumerate(doc['labels']):
            class_indices[label].append((doc_id, i))

    class_counts = {label: len(indices) for label, indices in class_indices.items()}
    logger.info(f"Class distribution: {class_counts}")

    # Step 3: Keep all TABLE lines and sample TEXT and FORM down to target_count
    if target_count is None:
        target_count = len(class_indices['TABLE'])

    rng = random if seed is None else random.Random(seed)

    all_indices = list(class_indices['TABLE'])
    for label in ['TEXT', 'FORM']:
        candidates = class_indices[label]
        if len(candidates) <= target_count:
            all_indices.extend(candidates)
        else:
            all_indices.extend(rng.sample(candidates, target_count))

    # Step 4: Sorting the (doc_id, line_index) tuples groups lines by document
    # and restores their original order inside each document.
    all_indices.sort()

    balanced_lines = []
    balanced_labels = []
    doc_boundaries = {}  # doc_id -> {'start': first row, 'end': last row}
    for doc_id, line_idx in all_indices:
        row = len(balanced_lines)
        if doc_id not in doc_boundaries:
            doc_boundaries[doc_id] = {'start': row}
        # Rows of one document are contiguous after sorting, so the last row
        # seen for a document is its end.
        doc_boundaries[doc_id]['end'] = row

        balanced_lines.append(documents[doc_id]['lines'][line_idx])
        balanced_labels.append(documents[doc_id]['labels'][line_idx])

    # Step 5: Save the balanced dataset
    _write_lines(os.path.join(output_dir, "all_lines.txt"), balanced_lines)
    _write_lines(os.path.join(output_dir, "all_labels.txt"), balanced_labels)

    # Save document boundaries for potential document-aware train/val/test split
    with open(os.path.join(output_dir, "document_boundaries.json"), 'w', encoding='utf-8') as f:
        json.dump(doc_boundaries, f, indent=2)

    final_counts = defaultdict(int)
    for label in balanced_labels:
        final_counts[label] += 1

    logger.info(f"Created balanced dataset with {len(balanced_lines)} total examples")
    logger.info(f"Final class distribution: {dict(final_counts)}")
    return dict(final_counts)

if __name__ == "__main__":
    # Example run: balance the output of the earlier processing step, keeping
    # every TABLE line and down-sampling TEXT and FORM to the same size.
    input_dir = "balanced_output"
    output_dir = "balanced_sampled_by_document"
    target_count = 4524  # TABLE line count of that corpus

    create_balanced_dataset(
        input_dir=input_dir,
        output_dir=output_dir,
        target_count=target_count,
    )

2025-03-22 16:02:31,313 - INFO - Read 656 documents
2025-03-22 16:02:31,315 - INFO - Class distribution: {'FORM': 7437, 'TEXT': 7673, 'TABLE': 4524}
2025-03-22 16:02:31,326 - INFO - Created balanced dataset with 13572 total examples
2025-03-22 16:02:31,326 - INFO - Final class distribution: {'FORM': 4524, 'TEXT': 4524, 'TABLE': 4524}
