In [2]:
# Import necessary libraries
import os
import re
import logging
import json
from typing import List, Dict, Tuple

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def split_text_by_length(text: str, target_length: int) -> List[str]:
    """
    Split text into segments of approximately target_length characters.
    Try to split on word boundaries.
    """
    words = text.split()
    segments = []
    current_segment = ""
    
    for word in words:
        # If adding this word would exceed target length and we already have content,
        # finish the current segment
        if current_segment and len(current_segment) + len(word) + 1 > target_length:
            segments.append(current_segment.strip())
            current_segment = ""
        
        # Add word to current segment
        current_segment += " " + word if current_segment else word
    
    # Add any remaining text
    if current_segment:
        segments.append(current_segment.strip())
    
    return segments

def split_text_on_sentences(text: str, target_length: int) -> List[str]:
    """
    Split text on sentence boundaries, trying to keep segments close to target_length.
    """
    # Simple sentence boundary detection (period, question mark, exclamation mark followed by space)
    sentence_end_pattern = r'[.!?](?=\s|$)'
    
    # Split into sentences
    sentences = re.split(sentence_end_pattern, text)
    sentences = [s + '.' if not s.endswith(('.', '!', '?')) else s for s in sentences if s.strip()]
    
    segments = []
    current_segment = ""
    
    for sentence in sentences:
        # If adding this sentence would exceed target length and we already have content,
        # finish the current segment
        if current_segment and len(current_segment) + len(sentence) > target_length:
            segments.append(current_segment.strip())
            current_segment = ""
        
        # If the sentence itself is longer than target_length, split it further
        if len(sentence) > target_length:
            if current_segment:
                segments.append(current_segment.strip())
                current_segment = ""
            
            # Split long sentence by length
            sentence_segments = split_text_by_length(sentence, target_length)
            segments.extend(sentence_segments)
        else:
            current_segment += " " + sentence if current_segment else sentence
    
    # Add any remaining text
    if current_segment:
        segments.append(current_segment.strip())
    
    return segments

def split_long_text_lines(
    input_dir: str,
    output_dir: str,
    target_length: int = 100,
    split_on_sentence: bool = True
) -> Dict:
    """
    Split long TEXT lines into shorter segments to balance the dataset.
    
    Args:
        input_dir: Directory containing *_lines.txt and *_labels.txt files
        output_dir: Directory to save processed files
        target_length: Target character length for TEXT lines
        split_on_sentence: Whether to try to split on sentence boundaries
        
    Returns:
        Dict containing statistics about the processing
    """
    os.makedirs(output_dir, exist_ok=True)
    
    stats = {
        'files_processed': 0,
        'original_lines': {'TEXT': 0, 'FORM': 0, 'TABLE': 0},
        'processed_lines': {'TEXT': 0, 'FORM': 0, 'TABLE': 0},
        'split_lines': 0,
        'original_char_counts': {'TEXT': 0, 'FORM': 0, 'TABLE': 0},
        'processed_char_counts': {'TEXT': 0, 'FORM': 0, 'TABLE': 0}
    }

    # Process each file pair in the directory
    for filename in os.listdir(input_dir):
        if not filename.endswith('_lines.txt'):
            continue
            
        lines_file = os.path.join(input_dir, filename)
        labels_file = os.path.join(input_dir, filename.replace('_lines.txt', '_labels.txt'))
        
        if not os.path.exists(labels_file):
            logger.warning(f"Labels file not found for {lines_file}, skipping")
            continue
        
        # Read files
        with open(lines_file, 'r', encoding='utf-8') as f:
            lines = [line.strip() for line in f.readlines()]
        
        with open(labels_file, 'r', encoding='utf-8') as f:
            labels = [label.strip() for label in f.readlines()]
        
        # Process lines and labels
        new_lines = []
        new_labels = []
        
        for i, (line, label) in enumerate(zip(lines, labels)):
            # Count original stats
            stats['original_lines'][label] = stats['original_lines'].get(label, 0) + 1
            stats['original_char_counts'][label] = stats['original_char_counts'].get(label, 0) + len(line)
            
            # Only split TEXT lines that are longer than target_length
            if label == 'TEXT' and len(line) > target_length:
                # Try to split on sentence boundaries if requested
                if split_on_sentence:
                    segments = split_text_on_sentences(line, target_length)
                else:
                    segments = split_text_by_length(line, target_length)
                
                new_lines.extend(segments)
                new_labels.extend(['TEXT'] * len(segments))
                
                stats['split_lines'] += 1
                stats['processed_lines']['TEXT'] += len(segments)
                stats['processed_char_counts']['TEXT'] += sum(len(segment) for segment in segments)
            else:
                # Keep non-TEXT lines or short TEXT lines as is
                new_lines.append(line)
                new_labels.append(label)
                stats['processed_lines'][label] = stats['processed_lines'].get(label, 0) + 1
                stats['processed_char_counts'][label] = stats['processed_char_counts'].get(label, 0) + len(line)
        
        # Write processed files
        output_lines_file = os.path.join(output_dir, filename)
        output_labels_file = os.path.join(output_dir, filename.replace('_lines.txt', '_labels.txt'))
        
        with open(output_lines_file, 'w', encoding='utf-8') as f:
            for line in new_lines:
                f.write(f"{line}\n")
        
        with open(output_labels_file, 'w', encoding='utf-8') as f:
            for label in new_labels:
                f.write(f"{label}\n")
        
        stats['files_processed'] += 1
        logger.info(f"Processed {filename}: {len(lines)} original lines -> {len(new_lines)} processed lines")
    
    # Save statistics
    stats_file = os.path.join(output_dir, "line_splitting_stats.json")
    with open(stats_file, 'w', encoding='utf-8') as f:
        json.dump(stats, f, indent=2)
    
    # Log summary statistics
    logger.info("\nText line splitting complete:")
    logger.info(f"Files processed: {stats['files_processed']}")
    logger.info(f"TEXT lines: {stats['original_lines']['TEXT']} original -> {stats['processed_lines']['TEXT']} processed")
    logger.info(f"FORM lines: {stats['original_lines']['FORM']} original -> {stats['processed_lines']['FORM']} processed")
    logger.info(f"TABLE lines: {stats['original_lines']['TABLE']} original -> {stats['processed_lines']['TABLE']} processed")
    
    # Calculate and log average lengths
    for label in ['TEXT', 'FORM', 'TABLE']:
        orig_avg_len = stats['original_char_counts'][label] / max(1, stats['original_lines'][label])
        proc_avg_len = stats['processed_char_counts'][label] / max(1, stats['processed_lines'][label])
        logger.info(f"{label} average length: {orig_avg_len:.1f} chars -> {proc_avg_len:.1f} chars")
    
    return stats

# Now use the function in your notebook
input_dir = "problem2_output"  # Your directory with lines and labels files
output_dir = "balanced_output"  # Where to save the processed files
target_length = 100  # Target character length for TEXT lines
split_on_sentence = True  # Try to split on sentence boundaries

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Run the splitting function
stats = split_long_text_lines(
    input_dir=input_dir,
    output_dir=output_dir,
    target_length=target_length,
    split_on_sentence=split_on_sentence
)

# Display results
print("\nOriginal dataset:")
print(f"TEXT: {stats['original_lines']['TEXT']} lines ({stats['original_lines']['TEXT']/sum(stats['original_lines'].values())*100:.1f}%)")
print(f"FORM: {stats['original_lines']['FORM']} lines ({stats['original_lines']['FORM']/sum(stats['original_lines'].values())*100:.1f}%)")
print(f"TABLE: {stats['original_lines']['TABLE']} lines ({stats['original_lines']['TABLE']/sum(stats['original_lines'].values())*100:.1f}%)")

print("\nProcessed dataset:")
print(f"TEXT: {stats['processed_lines']['TEXT']} lines ({stats['processed_lines']['TEXT']/sum(stats['processed_lines'].values())*100:.1f}%)")
print(f"FORM: {stats['processed_lines']['FORM']} lines ({stats['processed_lines']['FORM']/sum(stats['processed_lines'].values())*100:.1f}%)")
print(f"TABLE: {stats['processed_lines']['TABLE']} lines ({stats['processed_lines']['TABLE']/sum(stats['processed_lines'].values())*100:.1f}%)")

2025-03-22 15:53:20,658 - INFO - Processed document_part_classification_345_lines.txt: 9 original lines -> 65 processed lines
2025-03-22 15:53:20,659 - INFO - Processed document_part_classification_459_lines.txt: 12 original lines -> 36 processed lines
2025-03-22 15:53:20,660 - INFO - Processed document_part_classification_173_lines.txt: 9 original lines -> 9 processed lines
2025-03-22 15:53:20,661 - INFO - Processed document_part_classification_490_lines.txt: 52 original lines -> 52 processed lines
2025-03-22 15:53:20,662 - INFO - Processed document_part_classification_203_lines.txt: 13 original lines -> 44 processed lines
2025-03-22 15:53:20,663 - INFO - Processed document_part_classification_279_lines.txt: 27 original lines -> 27 processed lines
2025-03-22 15:53:20,663 - INFO - Processed document_part_classification_565_lines.txt: 14 original lines -> 74 processed lines
2025-03-22 15:53:20,664 - INFO - Processed document_part_classification_615_lines.txt: 5 original lines -> 5 proce


Original dataset:
TEXT: 891 lines (6.9%)
FORM: 7437 lines (57.9%)
TABLE: 4524 lines (35.2%)

Processed dataset:
TEXT: 7673 lines (39.1%)
FORM: 7437 lines (37.9%)
TABLE: 4524 lines (23.0%)
