# Data Preprocess

<hr>

In [1]:
import os
import re
import torch
import json
import numpy as np
import pandas as pd
import urllib.request
import tarfile
from pathlib import Path
from typing import List, Dict, Tuple, Optional
from collections import Counter
from torch.utils.data import Dataset, DataLoader
import html
from sklearn.model_selection import train_test_split

## DataDownloader - Data Download and Validation Module

**Function Methods:**
- download_imdb() - Download IMDb dataset archive and extract
- validate_data_directory() - Validate data directory structure integrity
- setup_data() - Main entry point to ensure data availability (download or use existing data)

<hr>

In [2]:
class DataDownloader:
    """Download, extract and validate the IMDb (aclImdb) sentiment dataset.

    All methods are static; the dataset always lives in ``./aclImdb``
    relative to the current working directory.
    """

    @staticmethod
    def download_imdb() -> Path:
        """Download and extract the IMDb archive unless a usable copy exists.

        Returns:
            Path to the dataset directory (``./aclImdb``).

        Raises:
            Exception: If the download or the extraction fails. The original
                error is chained as ``__cause__``.
        """
        data_dir = Path('./aclImdb')

        # Only skip when the existing copy passes validation. Checking mere
        # existence would make a partially extracted directory permanent,
        # because the re-download requested by setup_data() would be a no-op.
        if DataDownloader.validate_data_directory(data_dir):
            print("IMDb dataset already exists, skipping download")
            return data_dir

        print("Downloading IMDb dataset...")
        url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
        tar_path = Path('./aclImdb_v1.tar.gz')

        try:
            urllib.request.urlretrieve(url, tar_path)
            print("Download completed, extracting...")

            with tarfile.open(tar_path, 'r:gz') as tar:
                # The archive comes from the network: use the 'data'
                # extraction filter when this Python provides it so members
                # cannot escape the target directory.
                if hasattr(tarfile, 'data_filter'):
                    tar.extractall(path='./', filter='data')
                else:
                    tar.extractall(path='./')
        except Exception as e:
            raise Exception(f"Dataset download failed: {e}") from e
        finally:
            # Remove the archive on success and on failure alike so a
            # truncated download is never left behind.
            if tar_path.exists():
                tar_path.unlink()

        print("IMDb dataset preparation completed")
        return data_dir

    @staticmethod
    def validate_data_directory(data_dir: Path) -> bool:
        """Return True when every train/test x pos/neg folder holds .txt files.

        Args:
            data_dir: Root of the extracted dataset.

        Returns:
            False if any of the four folders is missing or contains no
            ``*.txt`` file, True otherwise.
        """
        required_dirs = [
            data_dir / split / sentiment
            for split in ('train', 'test')
            for sentiment in ('pos', 'neg')
        ]
        # any() stops at the first matching file instead of listing them all.
        return all(
            dir_path.is_dir() and any(dir_path.glob('*.txt'))
            for dir_path in required_dirs
        )

    @staticmethod
    def setup_data() -> Path:
        """Ensure the dataset is available, downloading it when necessary.

        Returns:
            Path to a dataset directory that passed validation.

        Raises:
            Exception: If the download fails or the directory is still
                invalid after downloading.
        """
        data_dir = Path('./aclImdb')

        if DataDownloader.validate_data_directory(data_dir):
            print("‚úÖ Data directory validation passed")
            return data_dir

        print("Data directory incomplete, re-downloading...")
        data_dir = DataDownloader.download_imdb()

        if not DataDownloader.validate_data_directory(data_dir):
            raise Exception("Data directory validation failed, please check dataset manually")

        return data_dir

## TextProcessor - Text Processing Module

**Function Methods:**

- robust_text_cleaning(text) - Core text cleaning, includes:

    1. HTML entity decoding and tag removal
    2. URL processing
    3. Contraction expansion (e.g., "can't" → "can not")
    4. Punctuation standardization
    5. Case normalization

- process_in_batches() - Batch process text to avoid memory overflow

<hr>

In [3]:
class TextProcessor:
    """Stateless helpers that normalise raw review text before tokenisation."""

    # Ordered (pattern, replacement) rules. The specific forms ("won't",
    # "can't") must run before the generic suffix rules ("n't", "'t").
    # Matching is case-insensitive because lowercasing happens at the very
    # end of the cleaning routine.
    _CONTRACTIONS = tuple(
        (re.compile(pattern, re.IGNORECASE), replacement)
        for pattern, replacement in (
            (r"won't", "will not"), (r"can't", "can not"), (r"n't", " not"),
            (r"'re", " are"), (r"'s", " is"), (r"'d", " would"),
            (r"'ll", " will"), (r"'t", " not"), (r"'ve", " have"),
            (r"'m", " am"),
        )
    )

    @staticmethod
    def robust_text_cleaning(text: str) -> str:
        """Clean one review and return it lowercased with single spaces.

        Steps, in order: decode HTML entities, strip HTML tags, replace URLs
        with a placeholder, expand contractions, drop every character that is
        not a letter, whitespace or one of ``. ! ? , ; : '``, collapse runs of
        ``! ? .`` to a single mark surrounded by spaces, collapse whitespace.

        Args:
            text: Raw review text.

        Returns:
            The cleaned, stripped, lowercased text.
        """
        # Decode HTML entities
        text = html.unescape(text)

        # Remove HTML tags
        text = re.sub(r'<[^>]+>', ' ', text)

        # Process URLs (the angle brackets of the placeholder are removed by
        # the character filter below, leaving the plain token "url")
        text = re.sub(r'http\S+', ' <URL> ', text)

        # Process common contractions and negations
        for pattern, replacement in TextProcessor._CONTRACTIONS:
            text = pattern.sub(replacement, text)

        # Preserve basic punctuation for sentiment analysis
        text = re.sub(r'[^a-zA-Z\s\.!?,;:\']', ' ', text)

        # Handle repeated punctuation: keep one mark per run...
        text = re.sub(r'([!?.]){2,}', r'\1', text)
        # ...and pad it with spaces so it becomes its own token. The whole
        # match must be referenced as \g<0>; "\0" in a replacement template
        # is an octal escape and would insert a NUL character instead.
        text = re.sub(r'[!?.]+', r' \g<0> ', text)

        # Standardize whitespace handling
        text = re.sub(r'\s+', ' ', text)

        return text.strip().lower()

    @staticmethod
    def process_in_batches(texts: List[str], process_func, batch_size: int = 1000, process_name: str = "Processing") -> List[str]:
        """Apply ``process_func`` to every text, slice by slice, with progress output.

        Args:
            texts: Input strings.
            process_func: Callable applied to each string.
            batch_size: Number of texts handled per slice.
            process_name: Label used in the progress messages.

        Returns:
            The processed strings, in input order.
        """
        processed: List[str] = []
        total = len(texts)

        for start in range(0, total, batch_size):
            processed.extend(process_func(text) for text in texts[start:start + batch_size])

            # Report only at slice starts that are positive multiples of 5000.
            if start > 0 and start % 5000 == 0:
                print(f"  {process_name} {start}/{total} samples")

        print(f"‚úÖ {process_name} completed")
        return processed

## LabelProcessor - Label Processing Module

**Function Methods:**
- parse_rating_from_filename() - Parse star rating from filename
- convert_to_binary() - Convert star rating to binary classification
- get_label_schema_config() - Get label schema configuration

<hr>

In [4]:
class LabelProcessor:
    """Derive labels from review file names and describe the label schemas."""

    @staticmethod
    def parse_rating_from_filename(file_path: Path) -> Optional[int]:
        """Read the integer after the last underscore of the file stem.

        Args:
            file_path: Path whose stem is expected to end in ``_<rating>``.

        Returns:
            The parsed integer, or None when the stem has no underscore or
            the trailing part is not an integer. The value is not
            range-checked.
        """
        _, separator, rating_str = file_path.stem.rpartition('_')
        if not separator:
            # No underscore at all: nothing to parse.
            return None

        try:
            rating = int(rating_str)  # 1..10
        except ValueError:
            return None
        return rating

    @staticmethod
    def convert_to_binary(rating: int) -> int:
        """Map a star rating to 0 (negative) or 1 (positive).

        Ratings of 4 or less are negative; every other value, including 5
        and 6, is treated as positive.
        """
        return 0 if rating <= 4 else 1

    @staticmethod
    def get_label_schema_config(schema_type: str) -> Dict:
        """Return the configuration dictionary for a label schema.

        Args:
            schema_type: ``"binary"`` or ``"multiclass"``.

        Returns:
            A dict with ``name``, ``num_classes``, ``output_dir`` and
            ``label_range``. Unknown schema names fall back to the binary
            configuration.
        """
        binary_schema = {
            "name": "Binary Classification",
            "num_classes": 2,
            "output_dir": "processed_data_binary",
            "label_range": "0-1 (Negative/Positive)"
        }
        multiclass_schema = {
            "name": "Multi-class (1-10 stars)",
            # Actually only 8 classes (1-4, 7-10)
            "num_classes": 8,
            "output_dir": "processed_data_multiclass",
            "label_range": "1-10 stars (Missing 5-6 stars)"
        }
        known_schemas = {"binary": binary_schema, "multiclass": multiclass_schema}

        try:
            return known_schemas[schema_type]
        except KeyError:
            return binary_schema

## DataLoaderManager - Data Loading Manager

**Function Methods:**
- load_raw_data() - Load raw IMDb data according to label schema

<hr>

In [5]:
class DataLoaderManager:
    """Read the raw IMDb review files and attach labels for a given schema."""

    @staticmethod
    def load_raw_data(data_dir: Path, label_schema: str) -> Tuple[List[str], List[int], List[str], List[int]]:
        """Load the ``train`` and ``test`` folders found under ``data_dir``.

        Args:
            data_dir: Dataset root containing ``train`` and ``test`` folders,
                each with ``pos`` and ``neg`` sub-folders of ``*.txt`` files.
            label_schema: ``"binary"`` converts the rating parsed from the
                file name to 0/1; any other value keeps the rating as label.

        Returns:
            ``(train_texts, train_labels, test_texts, test_labels)``.
        """
        print(f"Loading IMDb dataset - {LabelProcessor.get_label_schema_config(label_schema)['name']}...")

        def load_from_directory(directory: Path, schema: str) -> Tuple[List[str], List[int]]:
            """Collect texts and labels from ``directory/pos`` and ``directory/neg``."""
            texts, labels = [], []

            for label_type in ['pos', 'neg']:
                dir_name = directory / label_type
                if not dir_name.exists():
                    print(f"Warning: Missing directory: {dir_name}")
                    continue

                # glob() yields entries in filesystem order, which differs
                # between machines. Sorting keeps the sample order stable so
                # a seeded split downstream is reproducible.
                for file_path in sorted(dir_name.glob('*.txt')):
                    # Parse star rating; files without one are skipped
                    rating = LabelProcessor.parse_rating_from_filename(file_path)
                    if rating is None:
                        continue

                    # Select label based on schema
                    if schema == "binary":
                        label = LabelProcessor.convert_to_binary(rating)
                    else:  # multiclass
                        label = rating

                    # Read text; unreadable or undecodable files are skipped
                    try:
                        with open(file_path, 'r', encoding='utf-8') as f:
                            text = f.read().strip()
                    except (OSError, UnicodeError) as e:
                        print(f"Skipping unreadable file: {file_path.name} ({e})")
                        continue

                    texts.append(text)
                    labels.append(label)

            # Distribution check
            dist_sorted = dict(sorted(Counter(labels).items()))
            print(f"{directory.name} label distribution: {dist_sorted}")
            return texts, labels

        train_texts, train_labels = load_from_directory(data_dir / 'train', label_schema)
        test_texts, test_labels = load_from_directory(data_dir / 'test', label_schema)

        print(f"Loaded {len(train_texts)} training samples, {len(test_texts)} test samples")
        return train_texts, train_labels, test_texts, test_labels

## DataSplitter - Data Splitter

**Function Methods:**
- split_data_with_seeds() - Use random seeds for reproducible data splitting

    1. Automatic stratified sampling protection
    2. Handle small sample classes
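    3. Train/validation/test set split (64%/16%/20% with the default sizes: 20% is held out for testing, then 20% of the remaining data becomes the validation set)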
    4. Distribution statistics output

<hr>

In [6]:
class DataSplitter:
    """Seeded train/validation/test splitting with a stratification guard."""

    @staticmethod
    def split_data_with_seeds(texts: List[str], labels: List[int], seed: int = 42,
                             test_size: float = 0.2, val_size: float = 0.2) -> Tuple:
        """Split the data into train, validation and test parts.

        The test part is taken from the full data first; the validation part
        is then taken from what remains, so ``val_size`` is a fraction of the
        remaining data rather than of the whole. Each step is stratified by
        label unless some class has fewer than two samples.

        Args:
            texts: Samples to split.
            labels: One label per sample.
            seed: ``random_state`` used for both splits.
            test_size: Fraction of all samples held out for testing.
            val_size: Fraction of the non-test samples used for validation.

        Returns:
            ``(train_texts, train_labels, val_texts, val_labels,
            test_texts, test_labels)``.
        """

        def _smallest_class_size(y: List[int]) -> int:
            tally = Counter(y)
            return min(tally.values()) if tally else 0

        def _report(name: str, y: List[int]):
            ordered = dict(sorted(Counter(y).items()))
            print(f"{name} class distribution: {ordered}")

        print("Performing reproducible data splitting (with class checking)...")

        # Stage 1: carve the test set out of the full data.
        min_count = _smallest_class_size(labels)
        if min_count >= 2:
            first_stratify = labels
        else:
            first_stratify = None
            print(f"Warning: Minimum class sample count {min_count} < 2, initial split not using stratified sampling")

        remaining_texts, test_texts, remaining_labels, test_labels = train_test_split(
            texts, labels, test_size=test_size, random_state=seed,
            stratify=first_stratify
        )

        # Stage 2: carve the validation set out of what is left.
        min_train_count = _smallest_class_size(remaining_labels)
        if min_train_count >= 2:
            second_stratify = remaining_labels
        else:
            second_stratify = None
            print(f"Warning: Minimum class sample count in training set {min_train_count} < 2, validation split not using stratified sampling")

        train_texts, val_texts, train_labels, val_labels = train_test_split(
            remaining_texts, remaining_labels, test_size=val_size, random_state=seed,
            stratify=second_stratify
        )

        print(f"Data split: Training set={len(train_texts)}, Validation set={len(val_texts)}, Test set={len(test_texts)}")
        for part_name, part_labels in (("Training set", train_labels),
                                       ("Validation set", val_labels),
                                       ("Test set", test_labels)):
            _report(part_name, part_labels)

        return train_texts, train_labels, val_texts, val_labels, test_texts, test_labels

## VocabularyBuilder - Vocabulary Management Module

**Function Methods:**
- build_vocabulary(texts) - Build vocabulary based on training texts:

    1. Word frequency statistics
    2. Filter low-frequency words
    3. Add special tokens (PAD/UNK/BOS/EOS)
    4. Calculate vocabulary coverage

- text_to_sequence(text) - Convert text to numerical sequence:

    1. Add start/end tokens
    2. Handle unknown words (OOV)
    3. Sequence padding/truncation
    4. Generate attention masks

<hr>

In [7]:
class VocabularyBuilder:
    """Frequency-based whitespace-token vocabulary with four special tokens.

    Ids 0-3 are reserved for ``<PAD>``, ``<UNK>``, ``<BOS>`` and ``<EOS>``;
    regular words follow in descending frequency order.
    """

    def __init__(self, max_vocab_size: int = 30000, min_freq: int = 2):
        """Store the size limits; the vocabulary stays empty until built.

        Args:
            max_vocab_size: Upper bound on the vocabulary size, special
                tokens included.
            min_freq: Minimum number of occurrences a word needs to be kept.
        """
        self.max_vocab_size = max_vocab_size
        self.min_freq = min_freq
        self.special_tokens = ['<PAD>', '<UNK>', '<BOS>', '<EOS>']
        self.token_to_id = {}
        self.id_to_token = {}
        self.vocab_size = 0

    def build_vocabulary(self, texts: List[str]) -> None:
        """Build ``token_to_id`` / ``id_to_token`` from whitespace-split texts.

        Args:
            texts: Already cleaned texts; tokens are obtained with
                ``str.split()``.
        """
        print("Building vocabulary...")

        # Count word frequency
        word_freq = Counter()
        for text in texts:
            word_freq.update(text.split())

        # Filter low-frequency words. Words that collide with a special
        # token are dropped as well so they cannot overwrite a reserved id.
        reserved = set(self.special_tokens)
        filtered_words = [(word, freq) for word, freq in word_freq.items()
                          if freq >= self.min_freq and word not in reserved]

        # Sort by frequency and select top N words. The budget is clamped at
        # zero: a negative slice bound would keep almost every word instead
        # of none when max_vocab_size is below the number of special tokens.
        sorted_words = sorted(filtered_words, key=lambda x: x[1], reverse=True)
        word_budget = max(0, self.max_vocab_size - len(self.special_tokens))
        selected_words = [word for word, _ in sorted_words[:word_budget]]

        # Build vocabulary: special tokens first, then regular words
        self.token_to_id = {
            token: idx for idx, token in enumerate(self.special_tokens + selected_words)
        }
        self.id_to_token = {idx: token for token, idx in self.token_to_id.items()}
        self.vocab_size = len(self.token_to_id)

        # OOV analysis (share of corpus tokens that have their own id)
        total_tokens = sum(word_freq.values())
        covered_tokens = sum(word_freq[word] for word in selected_words)
        # An empty corpus has no tokens to cover; avoid dividing by zero.
        coverage = covered_tokens / total_tokens * 100 if total_tokens else 0.0

        print(f"Vocabulary building completed: {self.vocab_size} tokens")
        print(f"Vocabulary coverage: {coverage:.2f}%")

    def text_to_sequence(self, text: str, max_length: int = 512, add_special_tokens: bool = True) -> Tuple[List[int], List[int], int]:
        """Encode one text as a fixed-length id sequence with attention mask.

        Args:
            text: Cleaned text, tokenised with ``str.split()``.
            max_length: Length of the returned sequence and mask.
            add_special_tokens: Whether to wrap the tokens in ``<BOS>`` and
                ``<EOS>``.

        Returns:
            ``(sequence, attention_mask, original_length)`` where
            ``original_length`` is the length before padding or truncation
            and may therefore exceed ``max_length``.

        Raises:
            KeyError: If a needed special token is missing, e.g. because the
                vocabulary has not been built yet.
        """
        sequence = []

        # Add beginning token
        if add_special_tokens:
            sequence.append(self.token_to_id['<BOS>'])

        # Convert tokens, mapping out-of-vocabulary words to <UNK>
        for token in text.split():
            sequence.append(self.token_to_id.get(token, self.token_to_id['<UNK>']))

        # Add end token
        if add_special_tokens:
            sequence.append(self.token_to_id['<EOS>'])

        original_length = len(sequence)

        # Padding or truncation
        if original_length < max_length:
            padding = max_length - original_length
            sequence.extend([self.token_to_id['<PAD>']] * padding)
            attention_mask = [1] * original_length + [0] * padding
        else:
            # NOTE(review): truncation cuts from the end, so <EOS> is dropped
            # for over-long inputs.
            sequence = sequence[:max_length]
            attention_mask = [1] * max_length

        return sequence, attention_mask, original_length

## IMDBDataset - Dataset Interface Module

**Function Methods:**
- __init__() - Initialize dataset, validate data consistency
- __len__() - Return dataset size
- __getitem__() - Get single sample, return dictionary format


<hr>

In [8]:
class IMDBDataset(Dataset):
    """Tensor-backed dataset of encoded reviews.

    Each sample is a dict with the keys ``input_ids``, ``attention_mask``,
    ``labels`` and ``lengths``.
    """

    def __init__(self, sequences: List[List[int]], attention_masks: List[List[int]],
                 labels: List[int], lengths: List[int]):
        """Convert the inputs to ``torch.long`` tensors and check their sizes.

        Args:
            sequences: Token-id sequences, one per sample.
            attention_masks: Masks matching ``sequences``.
            labels: One label per sample.
            lengths: One length value per sample.

        Raises:
            AssertionError: If the four inputs do not hold the same number
                of samples.
        """
        def _as_long(values):
            return torch.tensor(values, dtype=torch.long)

        self.sequences = _as_long(sequences)
        self.attention_masks = _as_long(attention_masks)
        self.labels = _as_long(labels)
        self.lengths = _as_long(lengths)

        # Validate data shapes: every tensor must have one row per sample
        sample_count = len(self.sequences)
        assert sample_count == len(self.labels), "Sequence and label count mismatch"
        assert sample_count == len(self.attention_masks), "Sequence and attention mask count mismatch"
        assert sample_count == len(self.lengths), "Sequence and length count mismatch"

    def __len__(self):
        """Return the number of samples."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict of tensors."""
        fields = (
            ('input_ids', self.sequences),
            ('attention_mask', self.attention_masks),
            ('labels', self.labels),
            ('lengths', self.lengths),
        )
        return {key: tensor[idx] for key, tensor in fields}

## DataSaver - Data Saving Module

**Function Methods:**
- save_processed_data() - Main save function
- _save_vocabulary() - Save vocabulary as JSON
- _save_config() - Save preprocessing configuration
- _save_tensor_data() - Save tensor data as .pt files
- _save_metadata() - Save statistical metadata
- _verify_files() - Verify generated files

<hr>

In [9]:
class DataSaver:
    """Write the vocabulary, configuration, tensors and metadata to disk."""

    @staticmethod
    def save_processed_data(datasets: Dict, vocab, metadata: Dict, output_dir: str, label_schema: str,
                            preprocessing_params: Optional[Dict] = None):
        """Save every artefact of one processing run into ``output_dir``.

        Args:
            datasets: Dict with ``'train'``, ``'val'`` and ``'test'`` entries
                exposing ``sequences``, ``attention_masks``, ``labels`` and
                ``lengths`` attributes.
            vocab: JSON-serialisable vocabulary information.
            metadata: Statistics to store next to the data.
            output_dir: Target directory; created when missing.
            label_schema: Name of the label schema that was processed.
            preprocessing_params: Optional overrides for the settings written
                to the configuration file (for example ``max_length`` or
                ``seed``). When omitted, the historical defaults are written.
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        schema_config = LabelProcessor.get_label_schema_config(label_schema)
        print(f"Starting to save {schema_config['name']} data to: {output_path}")

        # 1. Save vocabulary
        DataSaver._save_vocabulary(vocab, output_path)

        # 2. Save configuration
        DataSaver._save_config(output_path, label_schema, schema_config, preprocessing_params)

        # 3. Save tensor data
        DataSaver._save_tensor_data(datasets, output_path)

        # 4. Save metadata
        DataSaver._save_metadata(metadata, output_path)

        # 5. Verify files
        DataSaver._verify_files(output_path)

    @staticmethod
    def _save_vocabulary(vocab, output_path):
        """Write ``vocab`` to ``vocabulary.json`` (non-ASCII kept verbatim)."""
        vocab_path = output_path / 'vocabulary.json'
        with open(vocab_path, 'w', encoding='utf-8') as f:
            json.dump(vocab, f, indent=2, ensure_ascii=False)
        print(f"‚úÖ Vocabulary saved to: {vocab_path}")

    @staticmethod
    def _save_config(output_path, label_schema, schema_config, preprocessing_params: Optional[Dict] = None):
        """Write ``preprocessing_config.json``.

        Args:
            output_path: Directory that receives the file.
            label_schema: Name of the label schema.
            schema_config: Dict providing ``name``, ``num_classes`` and
                ``label_range``.
            preprocessing_params: Optional values that replace the default
                ``max_vocab_size`` / ``max_length`` / ``min_freq`` / ``seed``
                entries, so the file can reflect a non-default run.
        """
        config_path = output_path / 'preprocessing_config.json'
        config = {
            'label_schema': label_schema,
            'schema_name': schema_config['name'],
            'num_classes': schema_config['num_classes'],
            'label_range': schema_config['label_range'],
            # Defaults, used when the caller does not say otherwise.
            'max_vocab_size': 30000,
            'max_length': 512,
            'min_freq': 2,
            'seed': 42,
        }
        if preprocessing_params:
            config.update(preprocessing_params)
        config['saved_time'] = str(pd.Timestamp.now())

        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)
        print(f"‚úÖ Configuration saved to: {config_path}")

    @staticmethod
    def _save_tensor_data(datasets, output_path):
        """Store all split tensors in a single ``all_data.pt`` file.

        Keys follow the pattern ``<split>_sequences``, ``<split>_masks``,
        ``<split>_labels`` and ``<split>_lengths`` for train, val and test.
        """
        data_path = output_path / 'all_data.pt'

        save_data = {}
        for split in ('train', 'val', 'test'):
            dataset = datasets[split]
            save_data[f'{split}_sequences'] = dataset.sequences
            save_data[f'{split}_masks'] = dataset.attention_masks
            save_data[f'{split}_labels'] = dataset.labels
            save_data[f'{split}_lengths'] = dataset.lengths

        torch.save(save_data, data_path)
        file_size = data_path.stat().st_size / 1024 / 1024
        print(f"‚úÖ Tensor data saved to: {data_path} (Size: {file_size:.2f} MB)")

    @staticmethod
    def _save_metadata(metadata, output_path):
        """Write ``metadata`` to ``preprocessing_metadata.json``.

        Values that JSON cannot encode are written as their ``str()`` form.
        """
        metadata_path = output_path / 'preprocessing_metadata.json'
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2, default=str, ensure_ascii=False)
        print(f"‚úÖ Metadata saved to: {metadata_path}")

    @staticmethod
    def _verify_files(output_path):
        """Print every entry of ``output_path`` with its size in KB."""
        print("üìÅ Verifying generated files:")
        # Sorted so the listing does not depend on directory iteration order.
        for file in sorted(output_path.iterdir()):
            size_kb = file.stat().st_size / 1024
            print(f"  {file.name} ({size_kb:.1f} KB)")

## UnifiedDataProcessingPipeline - Process Coordination Module

**Function Methods:**
- run_complete_pipeline() - Execute complete multi-schema data processing pipeline
- process_single_schema() - Process data for single label schema
- _ensure_data_ready() - Ensure data preparation is complete (execute only once)
- _vectorize_texts() - Batch vectorize texts
- _create_datasets() - Create train/validation/test datasets
- _check_data_size() - Check data scale

<hr>

In [10]:
class UnifiedDataProcessingPipeline:
    """Run download, loading, splitting, cleaning, encoding and saving.

    One instance can process several label schemas; the dataset location is
    resolved once and cached in ``self.data_dir``.
    """

    def __init__(self, max_vocab_size: int = 30000, max_length: int = 512,
                 min_freq: int = 2, seed: int = 42):
        """Store the preprocessing settings.

        Args:
            max_vocab_size: Vocabulary size limit passed to the vocabulary
                builder.
            max_length: Length every encoded sequence is padded/truncated to.
            min_freq: Minimum word frequency passed to the vocabulary builder.
            seed: Random seed used for the data split.
        """
        self.max_vocab_size = max_vocab_size
        self.max_length = max_length
        self.min_freq = min_freq
        self.seed = seed
        self.data_dir = None  # Data directory cache, filled on first use

    def _ensure_data_ready(self) -> bool:
        """Resolve the dataset directory once; return False on failure."""
        if self.data_dir is not None:
            return True

        try:
            print("üîç Checking data source...")
            self.data_dir = DataDownloader.setup_data()
            print("‚úÖ Data source preparation completed")
            return True
        except Exception as e:
            print(f"‚ùå Data source preparation failed: {e}")
            return False

    def process_single_schema(self, label_schema: str) -> bool:
        """Process and save the data for one label schema.

        Args:
            label_schema: Schema name understood by ``LabelProcessor``.

        Returns:
            True on success; False when any step raised (the error and its
            traceback are printed rather than propagated).
        """
        try:
            # 0. Ensure data is ready
            if not self._ensure_data_ready():
                return False

            schema_config = LabelProcessor.get_label_schema_config(label_schema)
            print(f"{'='*50}")
            print(f"=== Processing {schema_config['name']} Data ===")
            print(f"Label range: {schema_config['label_range']}")
            print(f"Number of classes: {schema_config['num_classes']}")
            print(f"Output directory: {schema_config['output_dir']}")
            print(f"{'='*50}")

            # 1. Initialize components
            vocab_builder = VocabularyBuilder(self.max_vocab_size, self.min_freq)
            text_processor = TextProcessor()

            # 2. Load raw data
            print("üì• Loading raw data...")
            train_texts, train_labels, test_texts, test_labels = DataLoaderManager.load_raw_data(
                self.data_dir, label_schema
            )

            # 3. Merge and split data (the shipped train/test folders are
            #    pooled and re-split with the configured seed)
            print("üìä Data splitting...")
            all_texts = train_texts + test_texts
            all_labels = train_labels + test_labels

            train_texts, train_labels, val_texts, val_labels, test_texts, test_labels = \
                DataSplitter.split_data_with_seeds(all_texts, all_labels, self.seed)

            # 4. Text cleaning
            print("üßπ Cleaning text data...")
            train_texts_clean = text_processor.process_in_batches(
                train_texts, text_processor.robust_text_cleaning, process_name="Text cleaning"
            )
            val_texts_clean = text_processor.process_in_batches(
                val_texts, text_processor.robust_text_cleaning, process_name="Text cleaning"
            )
            test_texts_clean = text_processor.process_in_batches(
                test_texts, text_processor.robust_text_cleaning, process_name="Text cleaning"
            )

            # 5. Build vocabulary (training texts only)
            print("üìö Building vocabulary...")
            vocab_builder.build_vocabulary(train_texts_clean)

            # 6. Text vectorization
            print("üî¢ Vectorizing data...")
            train_sequences, train_masks, train_lengths = self._vectorize_texts(
                train_texts_clean, vocab_builder, self.max_length
            )
            val_sequences, val_masks, val_lengths = self._vectorize_texts(
                val_texts_clean, vocab_builder, self.max_length
            )
            test_sequences, test_masks, test_lengths = self._vectorize_texts(
                test_texts_clean, vocab_builder, self.max_length
            )

            # 7. Create datasets
            print("üóÇÔ∏è Creating datasets...")
            datasets = self._create_datasets(
                train_sequences, train_masks, train_labels, train_lengths,
                val_sequences, val_masks, val_labels, val_lengths,
                test_sequences, test_masks, test_labels, test_lengths
            )

            # 8. Prepare data for saving
            vocab_info = {
                'token_to_id': vocab_builder.token_to_id,
                'id_to_token': vocab_builder.id_to_token,
                'vocab_size': vocab_builder.vocab_size,
                'special_tokens': vocab_builder.special_tokens
            }

            metadata = {
                'label_schema': label_schema,
                'schema_config': schema_config,
                # The settings this run actually used.
                'preprocessing_params': {
                    'max_vocab_size': self.max_vocab_size,
                    'max_length': self.max_length,
                    'min_freq': self.min_freq,
                    'seed': self.seed
                },
                'sequence_length_analysis': {
                    # Cast to builtin numbers: numpy integers are not JSON
                    # serialisable and would otherwise be written as strings
                    # by a json.dump(..., default=str) call.
                    'mean_length': float(np.mean(train_lengths)),
                    'max_length': int(np.max(train_lengths)),
                    'min_length': int(np.min(train_lengths)),
                    'total_samples': len(train_sequences)
                },
                'class_distribution': {
                    'train': dict(Counter(train_labels)),
                    'val': dict(Counter(val_labels)),
                    'test': dict(Counter(test_labels))
                }
            }

            # 9. Save data
            DataSaver.save_processed_data(
                datasets, vocab_info, metadata,
                f"./{schema_config['output_dir']}",
                label_schema
            )

            print(f"üéâ {schema_config['name']} data processing completed!")
            return True

        except Exception as e:
            print(f"‚ùå {label_schema} processing failed: {e}")
            import traceback
            traceback.print_exc()
            return False

    def run_complete_pipeline(self) -> bool:
        """Process the ``binary`` and ``multiclass`` schemas in turn.

        Returns:
            True only when every schema was processed successfully.
        """
        print("üöÄ Starting unified data processing pipeline")

        # First uniformly check data source
        print("üì¶ Preparing data source...")
        if not self._ensure_data_ready():
            print("‚ùå Data source preparation failed, terminating processing")
            return False

        success_count = 0
        schemas = ["binary", "multiclass"]

        for schema in schemas:
            if self.process_single_schema(schema):
                success_count += 1
                self._check_data_size(schema)

        print(f"{'='*60}")
        print("üéØ Data processing completion summary:")
        print(f"Successfully processed: {success_count}/{len(schemas)} data schemas")

        if success_count != len(schemas):
            print("‚ùå Partial data processing failed")
            return False

        print("‚úÖ All data schemas processed successfully!")
        print("üìÅ Generated data directories:")
        for schema in schemas:
            config = LabelProcessor.get_label_schema_config(schema)
            print(f"  - {config['output_dir']}: {config['name']}")
        return True

    def _vectorize_texts(self, texts: List[str], vocab_builder: VocabularyBuilder, max_length: int):
        """Encode every text with ``vocab_builder.text_to_sequence``.

        Returns:
            Three parallel lists: sequences, attention masks and lengths.
        """
        sequences, masks, lengths = [], [], []

        for i, text in enumerate(texts):
            seq, mask, length = vocab_builder.text_to_sequence(text, max_length)
            sequences.append(seq)
            masks.append(mask)
            lengths.append(length)

            if i % 5000 == 0 and i > 0:
                print(f"  Vectorized {i}/{len(texts)} samples")

        return sequences, masks, lengths

    def _create_datasets(self, train_sequences, train_masks, train_labels, train_lengths,
                        val_sequences, val_masks, val_labels, val_lengths,
                        test_sequences, test_masks, test_labels, test_lengths):
        """Wrap the three splits in ``IMDBDataset`` objects keyed by split name."""
        train_dataset = IMDBDataset(train_sequences, train_masks, train_labels, train_lengths)
        val_dataset = IMDBDataset(val_sequences, val_masks, val_labels, val_lengths)
        test_dataset = IMDBDataset(test_sequences, test_masks, test_labels, test_lengths)

        return {
            'train': train_dataset,
            'val': val_dataset,
            'test': test_dataset
        }

    def _check_data_size(self, label_schema: str):
        """Reload the saved tensors of a schema and print their shapes.

        Failures are reported with a message instead of being raised.
        """
        config = LabelProcessor.get_label_schema_config(label_schema)
        data_path = f"./{config['output_dir']}/all_data.pt"

        try:
            # The file is a plain dict of tensors, so the restricted
            # unpickler is sufficient and avoids executing arbitrary
            # pickled code.
            data = torch.load(data_path, weights_only=True)

            print(f"=== {config['name']} Data Scale Check ===")
            print(f"Training sequences: {data['train_sequences'].shape}")
            print(f"Training labels: {data['train_labels'].shape}")
            print(f"Validation sequences: {data['val_sequences'].shape}")
            print(f"Test sequences: {data['test_sequences'].shape}")

            total_samples = (data['train_sequences'].shape[0] +
                            data['val_sequences'].shape[0] +
                            data['test_sequences'].shape[0])

            print(f"Total samples: {total_samples}")

            if total_samples >= 40000:
                print("‚úÖ Processed complete dataset")
            else:
                print("‚ùå Abnormal data volume, please check processing")

        except Exception as e:
            print(f"Check failed: {e}")


## Utility Functions

**Function Methods:**
Check processed data scale and statistical information

<hr>

In [11]:
def check_data_size(data_path: str):
    """Load a saved ``all_data.pt`` file and print the size of each split.

    Args:
        data_path: Path to a file holding a dict with the keys
            ``train_sequences``, ``train_labels``, ``val_sequences`` and
            ``test_sequences``.

    Any failure (missing file, missing key, ...) is reported with a
    "Check failed" message instead of being raised.
    """
    try:
        # The file only contains tensors, so the restricted unpickler is
        # enough; it refuses arbitrary pickled objects and avoids the
        # FutureWarning newer torch versions emit for the unrestricted call.
        data = torch.load(data_path, weights_only=True)

        print("=== Data Scale Check ===")
        print(f"Training sequences: {data['train_sequences'].shape}")
        print(f"Training labels: {data['train_labels'].shape}")
        print(f"Validation sequences: {data['val_sequences'].shape}")
        print(f"Test sequences: {data['test_sequences'].shape}")

        total_samples = sum(
            data[key].shape[0]
            for key in ('train_sequences', 'val_sequences', 'test_sequences')
        )

        print(f"Total samples: {total_samples}")

        if total_samples >= 40000:
            print("‚úÖ Processed complete dataset")
        elif total_samples >= 20000:
            print("‚ö†Ô∏è  Processed partial dataset")
        else:
            print("‚ùå Abnormal data volume, please check processing")

    except Exception as e:
        print(f"Check failed: {e}")

## Main Program Entry

<hr>

In [12]:
if __name__ == "__main__":
    # Build the pipeline from one explicit settings mapping.
    pipeline_settings = {
        'max_vocab_size': 30000,
        'max_length': 512,
        'min_freq': 2,
        'seed': 42,
    }
    pipeline = UnifiedDataProcessingPipeline(**pipeline_settings)

    # Process every label schema; the result is True only if all succeeded.
    success = pipeline.run_complete_pipeline()

    if not success:
        print("‚ùå Data processing failed, please check error messages")
    else:
        # Re-open both saved tensor files and report their sizes.
        for saved_file in ('./processed_data_binary/all_data.pt',
                           './processed_data_multiclass/all_data.pt'):
            check_data_size(saved_file)

        for line in (
            "üéâ All data processing completed! Now you can start training models for different tasks!",
            "   - Binary sentiment analysis: Use processed_data_binary/",
            "   - Multi-class star rating prediction: Use processed_data_multiclass/",
        ):
            print(line)

üöÄ Starting unified data processing pipeline
üì¶ Preparing data source...
üîç Checking data source...
Data directory incomplete, re-downloading...
Downloading IMDb dataset...
Download completed, extracting...
IMDb dataset preparation completed
‚úÖ Data source preparation completed
=== Processing Binary Classification Data ===
Label range: 0-1 (Negative/Positive)
Number of classes: 2
Output directory: processed_data_binary
üì• Loading raw data...
Loading IMDb dataset - Binary Classification...
train label distribution: {0: 12500, 1: 12500}
test label distribution: {0: 12500, 1: 12500}
Loaded 25000 training samples, 25000 test samples
üìä Data splitting...
Performing reproducible data splitting (with class checking)...
Data split: Training set=32000, Validation set=8000, Test set=10000
Training set class distribution: {0: 16000, 1: 16000}
Validation set class distribution: {0: 4000, 1: 4000}
Test set class distribution: {0: 5000, 1: 5000}
üßπ Cleaning text data...
  Text cleaning 

  data = torch.load(data_path)


=== Binary Classification Data Scale Check ===
Training sequences: torch.Size([32000, 512])
Training labels: torch.Size([32000])
Validation sequences: torch.Size([8000, 512])
Test sequences: torch.Size([10000, 512])
Total samples: 50000
‚úÖ Processed complete dataset
=== Processing Multi-class (1-10 stars) Data ===
Label range: 1-10 stars (Missing 5-6 stars)
Number of classes: 8
Output directory: processed_data_multiclass
üì• Loading raw data...
Loading IMDb dataset - Multi-class (1-10 stars)...
train label distribution: {1: 5100, 2: 2284, 3: 2420, 4: 2696, 7: 2496, 8: 3009, 9: 2263, 10: 4732}
test label distribution: {1: 5022, 2: 2302, 3: 2541, 4: 2635, 7: 2307, 8: 2850, 9: 2344, 10: 4999}
Loaded 25000 training samples, 25000 test samples
üìä Data splitting...
Performing reproducible data splitting (with class checking)...
Data split: Training set=32000, Validation set=8000, Test set=10000
Training set class distribution: {1: 6478, 2: 2935, 3: 3175, 4: 3412, 7: 3074, 8: 3750, 9: 294

  data = torch.load(data_path)
  data = torch.load(data_path)


=== Multi-class (1-10 stars) Data Scale Check ===
Training sequences: torch.Size([32000, 512])
Training labels: torch.Size([32000])
Validation sequences: torch.Size([8000, 512])
Test sequences: torch.Size([10000, 512])
Total samples: 50000
‚úÖ Processed complete dataset
üéØ Data processing completion summary:
Successfully processed: 2/2 data schemas
‚úÖ All data schemas processed successfully!
üìÅ Generated data directories:
  - processed_data_binary: Binary Classification
  - processed_data_multiclass: Multi-class (1-10 stars)
=== Data Scale Check ===
Training sequences: torch.Size([32000, 512])
Training labels: torch.Size([32000])
Validation sequences: torch.Size([8000, 512])
Test sequences: torch.Size([10000, 512])
Total samples: 50000
‚úÖ Processed complete dataset
=== Data Scale Check ===
Training sequences: torch.Size([32000, 512])
Training labels: torch.Size([32000])
Validation sequences: torch.Size([8000, 512])
Test sequences: torch.Size([10000, 512])
Total samples: 50000
‚úÖ