# GPU-Aware Data Preprocessing for Nanotron Training

This notebook processes parquet files and splits them into training and evaluation datasets (80/20 split) with GPU-aware optimizations.

## Features:
- 🔍 Automatic GPU detection and configuration
- 💾 Memory-optimized processing for large datasets 
- ⚙️ Device-specific optimizations (CPU vs GPU)
- 📈 Progress tracking and memory monitoring
- 📦 Efficient data loading with chunked processing
- 📊 Data validation and integrity checks

## Hardware Requirements:
- **CPU**: Minimum 8GB RAM recommended for large datasets
- **GPU**: Optional but recommended for faster processing
- **Storage**: SSD recommended for better I/O performance

## Configuration:
The notebook automatically detects your hardware and configures optimal settings for your environment.

# GPU-Aware Data Preprocessing for Training

This notebook processes parquet files and splits them into training and evaluation datasets (80/20 split).
Includes GPU detection and device configuration for efficient processing on training devices.

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
import os
from sklearn.model_selection import train_test_split
import json
from typing import List, Dict, Any
import warnings
from tqdm.auto import tqdm
import time

# GPU and device management
import torch
import psutil

warnings.filterwarnings('ignore')

# Enable progress bars for pandas operations
tqdm.pandas()

print("Libraries imported successfully!")

# GPU Detection and Device Configuration
def detect_gpu_setup():
    """Detect available GPUs and system configuration"""
    print("\n🔍 GPU and System Detection:")
    
    # Check PyTorch installation
    print(f"   PyTorch version: {torch.__version__}")
    
    # Check CUDA availability
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        print(f"   ✅ CUDA available with {gpu_count} GPU(s)")
        
        for i in range(gpu_count):
            gpu_props = torch.cuda.get_device_properties(i)
            memory_gb = gpu_props.total_memory / 1024**3
            print(f"      GPU {i}: {gpu_props.name} ({memory_gb:.1f} GB)")
            
        # Get current GPU
        current_device = torch.cuda.current_device()
        print(f"   Current device: cuda:{current_device}")
        
        return True, gpu_count
    else:
        print("   ⚠️  CUDA not available - using CPU for data processing")
        return False, 0

# System resources
def check_system_resources():
    """Check system memory and CPU cores"""
    print("\n💻 System Resources:")
    
    # Memory
    memory = psutil.virtual_memory()
    memory_gb = memory.total / 1024**3
    available_gb = memory.available / 1024**3
    print(f"   RAM: {memory_gb:.1f} GB total, {available_gb:.1f} GB available")
    
    # CPU
    cpu_count = psutil.cpu_count()
    print(f"   CPU cores: {cpu_count}")
    
    return memory_gb, cpu_count

# Run detection
has_gpu, gpu_count = detect_gpu_setup()
memory_gb, cpu_count = check_system_resources()

print("\n📊 Recommendations for data processing:")
if has_gpu:
    print(f"   • Use GPU acceleration for large datasets")
    print(f"   • Enable GPU-accelerated pandas operations")
    print(f"   • Consider GPU memory when processing large files")
else:
    print(f"   • Optimize for CPU processing")
    print(f"   • Use chunked processing for large datasets")
    print(f"   • Increase num_workers for parallel processing")

Libraries imported successfully!


## Configuration

Set the path to your parquet files and output directories.

In [None]:
# Configuration with GPU and Device Settings
INPUT_DATA_PATH = "/Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini/data"  # Change this to your actual path
OUTPUT_DIR = "/Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini/data"
TRAIN_SPLIT = 0.8
EVAL_SPLIT = 0.2
RANDOM_SEED = 42

# GPU Configuration (set these for your training device)
GPU_DEVICE = "cuda:0"  # Change to your specific GPU (cuda:0, cuda:1, etc.)
USE_GPU_PROCESSING = has_gpu  # Enable GPU-accelerated processing if available
CHUNK_SIZE = 10000 if not has_gpu else 50000  # Larger chunks if GPU available
NUM_WORKERS = min(4, cpu_count)  # Parallel processing workers

# Memory management
MAX_MEMORY_GB = min(memory_gb * 0.8, 32)  # Use up to 80% of available RAM, max 32GB
GPU_MEMORY_FRACTION = 0.9  # Use 90% of GPU memory if available

print(f"🔧 Configuration:")
print(f"   Input path: {INPUT_DATA_PATH}")
print(f"   Output directory: {OUTPUT_DIR}")
print(f"   Train split: {TRAIN_SPLIT}, Eval split: {EVAL_SPLIT}")
print(f"   Random seed: {RANDOM_SEED}")
print(f"\n🎯 Device Configuration:")
print(f"   Target GPU device: {GPU_DEVICE}")
print(f"   GPU processing: {'Enabled' if USE_GPU_PROCESSING else 'Disabled'}")
print(f"   Chunk size: {CHUNK_SIZE:,} rows")
print(f"   Parallel workers: {NUM_WORKERS}")
print(f"   Max memory usage: {MAX_MEMORY_GB:.1f} GB")

if USE_GPU_PROCESSING and has_gpu:
    # Set GPU memory fraction
    torch.cuda.set_per_process_memory_fraction(GPU_MEMORY_FRACTION, device=torch.cuda.current_device())
    print(f"   GPU memory fraction: {GPU_MEMORY_FRACTION}")

# Create output directories
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(f"{OUTPUT_DIR}/train", exist_ok=True)
os.makedirs(f"{OUTPUT_DIR}/eval", exist_ok=True)

print(f"\n✅ Configuration complete!")

Input path: /Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini/data
Output directory: /Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini/data
Train split: 0.8, Eval split: 0.2


## Data Loading

Load all parquet files from the specified directory.

In [None]:
from typing import List
import pandas as pd
from pathlib import Path
from tqdm import tqdm

def find_parquet_files(data_path: str) -> List[str]:
    """Find all parquet files in the given directory."""
    parquet_files = []
    data_path = Path(data_path)
    
    if data_path.is_file() and data_path.suffix == '.parquet':
        return [str(data_path)]
    
    for file_path in data_path.rglob("*.parquet"):
        parquet_files.append(str(file_path))
    
    return sorted(parquet_files)

def load_parquet_robust(file_path: str) -> pd.DataFrame:
    """Load a parquet file with robust error handling for Arrow extension issues."""
    try:
        # First attempt: Default pyarrow engine
        return pd.read_parquet(file_path, engine='pyarrow')
    except Exception as e:
        print(f"    ⚠️  PyArrow failed ({str(e)[:50]}...), trying alternatives...")
        
        try:
            # Second attempt: Use fastparquet engine
            return pd.read_parquet(file_path, engine='fastparquet')
        except Exception as e2:
            print(f"    ⚠️  FastParquet failed ({str(e2)[:50]}...), trying PyArrow with ignore_metadata...")
            
            try:
                # Third attempt: PyArrow without metadata
                import pyarrow.parquet as pq
                table = pq.read_table(file_path, use_pandas_metadata=False)
                return table.to_pandas()
            except Exception as e3:
                print(f"    ⚠️  PyArrow no-metadata failed ({str(e3)[:50]}...), trying manual conversion...")
                
                try:
                    # Fourth attempt: Manual Arrow to Pandas conversion
                    import pyarrow.parquet as pq
                    import pyarrow as pa
                    
                    # Read table without pandas metadata
                    table = pq.read_table(file_path)
                    
                    # Convert to pandas manually, handling extension types
                    df_dict = {}
                    for i, column in enumerate(table.column_names):
                        col_data = table.column(i)
                        
                        # Handle different Arrow types
                        if pa.types.is_string(col_data.type) or pa.types.is_large_string(col_data.type):
                            df_dict[column] = col_data.to_pandas()
                        elif pa.types.is_integer(col_data.type) or pa.types.is_floating(col_data.type):
                            df_dict[column] = col_data.to_pandas()
                        else:
                            # For extension types, convert to string as fallback
                            try:
                                df_dict[column] = col_data.to_pandas()
                            except:
                                # Last resort: convert to string
                                df_dict[column] = [str(x) for x in col_data.to_pylist()]
                    
                    return pd.DataFrame(df_dict)
                    
                except Exception as e4:
                    print(f"    ❌ All methods failed. Final error: {e4}")
                    raise e4

# Find all parquet files
parquet_files = find_parquet_files(INPUT_DATA_PATH)
print(f"Found {len(parquet_files)} parquet files:")
for i, file in enumerate(parquet_files[:10]):  # Show first 10 files
    print(f"  {i+1}. {file}")
if len(parquet_files) > 10:
    print(f"  ... and {len(parquet_files) - 10} more files")

Found 1 parquet files:
  1. /Users/zhang/Desktop/huawei/untitled folder 5/nanotron-infini/data/000_00000.parquet


In [None]:
def load_parquet_advanced_fallback(file_path: str) -> pd.DataFrame:
    """Advanced parquet loading with specific handling for pandas.period extension type errors."""
    import pyarrow.parquet as pq
    import pyarrow as pa
    
    strategies = [
        # Strategy 1: Direct pandas read
        lambda: pd.read_parquet(file_path),
        
        # Strategy 2: PyArrow with ignore_metadata
        lambda: pd.read_parquet(file_path, use_pandas_metadata=False),
        
        # Strategy 3: FastParquet engine
        lambda: pd.read_parquet(file_path, engine='fastparquet'),
        
        # Strategy 4: PyArrow table with metadata reset
        lambda: load_with_metadata_reset(file_path),
        
        # Strategy 5: Manual column-by-column conversion
        lambda: load_with_manual_conversion(file_path)
    ]
    
    for i, strategy in enumerate(strategies, 1):
        try:
            print(f"    Trying strategy {i}...")
            df = strategy()
            print(f"    ✅ Strategy {i} succeeded!")
            return df
        except Exception as e:
            print(f"    ❌ Strategy {i} failed: {str(e)[:80]}...")
            if i == len(strategies):
                raise e
            continue

def load_with_metadata_reset(file_path: str) -> pd.DataFrame:
    """Load parquet by completely resetting Arrow metadata."""
    import pyarrow.parquet as pq
    import pyarrow as pa
    
    # Read table
    table = pq.read_table(file_path)
    
    # Create completely new schema without any extension metadata
    new_fields = []
    for i, (name, field) in enumerate(zip(table.schema.names, table.schema)):
        # Get the base type, stripping any extension metadata
        base_type = field.type
        
        # Handle specific problematic types
        if hasattr(base_type, 'value_type'):
            # This is likely an extension type, use the underlying type
            base_type = base_type.value_type
        
        # Create new field with clean type
        new_field = pa.field(name, base_type)
        new_fields.append(new_field)
    
    # Create new schema
    new_schema = pa.schema(new_fields)
    
    # Create new table with clean schema
    new_table = pa.table(table.columns, schema=new_schema)
    
    # Convert to pandas
    return new_table.to_pandas()

def load_with_manual_conversion(file_path: str) -> pd.DataFrame:
    """Load parquet with manual type conversion to avoid extension type issues."""
    import pyarrow.parquet as pq
    import pyarrow as pa
    
    # Read table
    table = pq.read_table(file_path)
    
    # Convert each column manually
    data_dict = {}
    
    for i, column_name in enumerate(table.schema.names):
        column = table.column(i)
        column_type = column.type
        
        try:
            # Try direct conversion first
            data_dict[column_name] = column.to_pandas()
        except Exception:
            # If that fails, convert to basic types
            if pa.types.is_string(column_type) or pa.types.is_large_string(column_type):
                data_dict[column_name] = [str(x) if x is not None else None for x in column.to_pylist()]
            elif pa.types.is_integer(column_type):
                data_dict[column_name] = [int(x) if x is not None else None for x in column.to_pylist()]
            elif pa.types.is_floating(column_type):
                data_dict[column_name] = [float(x) if x is not None else None for x in column.to_pylist()]
            elif pa.types.is_boolean(column_type):
                data_dict[column_name] = [bool(x) if x is not None else None for x in column.to_pylist()]
            else:
                # Fallback: convert everything to string
                data_dict[column_name] = [str(x) if x is not None else None for x in column.to_pylist()]
    
    return pd.DataFrame(data_dict)

def load_parquet_files_robust(file_paths: List[str]) -> pd.DataFrame:
    """Load multiple parquet files with advanced error handling."""
    dataframes = []
    failed_files = []
    
    print(f"🔧 Loading {len(file_paths)} parquet files with advanced fallback strategies...")
    
    for file_path in tqdm(file_paths, desc="Loading files", unit="file"):
        file_name = Path(file_path).name
        try:
            print(f"\n📁 Loading {file_name}...")
            df = load_parquet_advanced_fallback(file_path)
            dataframes.append(df)
            print(f"  ✅ Successfully loaded {file_name}: {len(df):,} rows, {len(df.columns)} columns")
        except Exception as e:
            failed_files.append((file_path, str(e)))
            print(f"  ❌ Failed to load {file_name}: {e}")
    
    if not dataframes:
        print("\n💥 CRITICAL ERROR: No files could be loaded!")
        print("\n🔍 Detailed error analysis:")
        for file_path, error in failed_files:
            print(f"   {Path(file_path).name}: {error}")
        
        print("\n🛠️  Advanced troubleshooting for pandas.period error:")
        print("   1. This error occurs when Arrow extension types conflict")
        print("   2. Try recreating the parquet files:")
        print("      import pandas as pd")
        print("      df = pd.read_csv('your_data.csv')  # or other format")
        print("      df.to_parquet('fixed_file.parquet', engine='pyarrow', compression='snappy')")
        print("   3. Check PyArrow version compatibility:")
        print("      pip install pyarrow==12.0.0 pandas==2.0.3")
        print("   4. Alternative: Convert to different format temporarily")
        
        raise ValueError("All loading strategies failed. See troubleshooting guide above.")
    
    print(f"\n🔗 Concatenating {len(dataframes)} successfully loaded dataframes...")
    
    # Ensure all dataframes have the same columns
    if len(dataframes) > 1:
        first_cols = set(dataframes[0].columns)
        for i, df in enumerate(dataframes[1:], 1):
            if set(df.columns) != first_cols:
                print(f"   ⚠️  Warning: Column mismatch in file {i+1}. Aligning columns...")
                # Align columns
                all_cols = sorted(set().union(*[df.columns for df in dataframes]))
                for j, df in enumerate(dataframes):
                    for col in all_cols:
                        if col not in df.columns:
                            df[col] = None
                    dataframes[j] = df[all_cols]
    
    combined_df = pd.concat(dataframes, ignore_index=True)
    
    print(f"✅ Loading complete!")
    print(f"   📊 Total rows: {len(combined_df):,}")
    print(f"   📋 Total columns: {len(combined_df.columns)}")
    print(f"   📈 Successfully loaded: {len(dataframes)}/{len(file_paths)} files")
    print(f"   💾 Memory usage: {combined_df.memory_usage(deep=True).sum() / 1024**2:.1f} MB")
    
    if failed_files:
        print(f"   ⚠️  {len(failed_files)} files failed to load")
    
    return combined_df

# Install required package if needed
try:
    import fastparquet
except ImportError:
    print("📦 Installing fastparquet for additional loading support...")
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "fastparquet"])
    import fastparquet
    print("✅ fastparquet installed successfully")

# Load data with advanced error handling
if parquet_files:
    start_time = time.time()
    print(f"\n🚀 Starting advanced parquet loading...")
    print(f"   Target files: {len(parquet_files)}")
    print(f"   Device: {GPU_DEVICE}")
    
    try:
        df = load_parquet_files_robust(parquet_files)
        load_time = time.time() - start_time
        
        print(f"\n🎉 SUCCESS! Data loaded in {load_time:.2f} seconds")
        print(f"📊 Final dataset info:")
        print(f"   Shape: {df.shape}")
        print(f"   Columns: {list(df.columns)}")
        print(f"   Data types: {df.dtypes.value_counts().to_dict()}")
        
        # Show sample data
        print(f"\n👀 Sample data (first 3 rows):")
        print(df.head(3).to_string())
        
    except Exception as e:
        print(f"\n💥 FINAL ERROR: {e}")
        print(f"\n📞 Contact support with this error information:")
        print(f"   File: {parquet_files[0] if parquet_files else 'None'}")
        print(f"   Error: {type(e).__name__}: {e}")
        print(f"   PyArrow version: {pa.__version__ if 'pa' in locals() else 'Unknown'}")
        print(f"   Pandas version: {pd.__version__}")
else:
    print("❌ No parquet files found! Check your INPUT_DATA_PATH.")

In [1]:
def load_parquet_with_fallbacks(file_path: str) -> pd.DataFrame:
    """Load a parquet file with multiple fallback strategies for Arrow extension type errors."""
    strategies = [
        # Strategy 1: Default pandas read_parquet
        lambda: pd.read_parquet(file_path),
        
        # Strategy 2: Use pyarrow directly without extension metadata
        lambda: pd.read_parquet(file_path, engine='pyarrow'),
        
        # Strategy 3: Load with pyarrow and ignore metadata
        lambda: load_parquet_ignore_metadata(file_path),
        
        # Strategy 4: Load with fastparquet engine
        lambda: pd.read_parquet(file_path, engine='fastparquet'),
        
        # Strategy 5: Direct pyarrow table conversion
        lambda: load_parquet_direct_arrow(file_path)
    ]
    
    last_error = None
    for i, strategy in enumerate(strategies, 1):
        try:
            return strategy()
        except Exception as e:
            last_error = e
            if i < len(strategies):  # Don't print for last attempt
                continue
    
    # If all strategies failed, raise the last error
    raise last_error

def load_parquet_ignore_metadata(file_path: str) -> pd.DataFrame:
    """Load parquet file ignoring extension metadata that causes Arrow type errors."""
    import pyarrow.parquet as pq
    import pyarrow as pa
    
    # Read parquet file without extension metadata
    table = pq.read_table(file_path)
    
    # Remove extension metadata that might cause issues
    schema_without_metadata = pa.schema([
        pa.field(name, field.type.value_type if hasattr(field.type, 'value_type') else field.type)
        for name, field in zip(table.schema.names, table.schema)
    ])
    
    # Create new table with clean schema
    clean_table = pa.table(table.columns, schema=schema_without_metadata)
    
    return clean_table.to_pandas()

def load_parquet_direct_arrow(file_path: str) -> pd.DataFrame:
    """Load parquet using direct PyArrow table conversion."""
    import pyarrow.parquet as pq
    
    # Read as PyArrow table
    table = pq.read_table(file_path)
    
    # Convert to pandas with safe conversion
    return table.to_pandas(safe=False, ignore_metadata=True)

def load_parquet_files(file_paths: List[str]) -> pd.DataFrame:
    """Load and concatenate multiple parquet files with robust error handling."""
    dataframes = []
    total_rows = 0
    failed_files = []
    
    print("Loading parquet files with robust error handling...")
    
    # Use tqdm for progress bar
    for file_path in tqdm(file_paths, desc="Loading files", unit="file"):
        try:
            df = load_parquet_with_fallbacks(file_path)
            dataframes.append(df)
            total_rows += len(df)
            tqdm.write(f"  ✓ Loaded {Path(file_path).name}: {len(df):,} rows")
        except Exception as e:
            failed_files.append((file_path, str(e)))
            tqdm.write(f"  ✗ Error loading {Path(file_path).name}: {e}")
    
    if not dataframes:
        print("\n❌ No parquet files could be loaded successfully!")
        print("\n🔍 Failed files details:")
        for file_path, error in failed_files:
            print(f"   {Path(file_path).name}: {error}")
        print("\n💡 Troubleshooting suggestions:")
        print("   1. Check if files are corrupted: Try opening one file manually")
        print("   2. Check Arrow/PyArrow versions: pip install --upgrade pyarrow pandas")
        print("   3. Try converting files: Use a different tool to re-save the parquet files")
        print("   4. Check file permissions: Ensure files are readable")
        raise ValueError("No parquet files could be loaded successfully! See troubleshooting above.")
    
    if failed_files:
        print(f"\n⚠️  Warning: {len(failed_files)} files failed to load:")
        for file_path, error in failed_files[:3]:  # Show first 3 failures
            print(f"   {Path(file_path).name}: {error}")
        if len(failed_files) > 3:
            print(f"   ... and {len(failed_files) - 3} more")
    
    print(f"\n📊 Concatenating {len(dataframes)} dataframes...")
    # Show progress for concatenation
    with tqdm(total=1, desc="Concatenating", unit="operation") as pbar:
        combined_df = pd.concat(dataframes, ignore_index=True)
        pbar.update(1)
    
    print(f"✅ Total rows after concatenation: {len(combined_df):,}")
    print(f"📈 Successfully loaded {len(dataframes)}/{len(file_paths)} files")
    
    return combined_df

def load_parquet_files_gpu_aware(file_paths: List[str]) -> pd.DataFrame:
    """Load and concatenate multiple parquet files with GPU-aware processing and robust error handling."""
    dataframes = []
    total_rows = 0
    failed_files = []
    
    print(f"📚 Loading {len(file_paths)} parquet files with GPU-aware processing...")
    print(f"   Device target: {GPU_DEVICE}")
    print(f"   Memory limit: {MAX_MEMORY_GB:.1f} GB")
    
    # Memory monitoring
    def get_memory_usage():
        if USE_GPU_PROCESSING and has_gpu:
            gpu_memory = torch.cuda.memory_allocated() / 1024**3
            return f"GPU: {gpu_memory:.1f}GB"
        else:
            ram_usage = psutil.virtual_memory().used / 1024**3
            return f"RAM: {ram_usage:.1f}GB"
    
    # Process files with memory monitoring and robust loading
    for file_path in tqdm(file_paths, desc="Loading files", unit="file"):
        try:
            # Check memory before loading
            memory_info = get_memory_usage()
            
            # Load with robust fallback strategies
            df = load_parquet_with_fallbacks(file_path)
            
            # Memory optimization
            if USE_GPU_PROCESSING and has_gpu:
                # Convert to GPU-friendly format if needed
                # Note: pandas doesn't directly support GPU, but we prepare for downstream GPU processing
                pass
            
            dataframes.append(df)
            total_rows += len(df)
            
            tqdm.write(f"  ✓ Loaded {Path(file_path).name}: {len(df):,} rows | {memory_info}")
            
            # Memory management - garbage collection if needed
            if len(dataframes) % 5 == 0:  # Every 5 files
                import gc
                gc.collect()
                if USE_GPU_PROCESSING and has_gpu:
                    torch.cuda.empty_cache()
                    
        except Exception as e:
            failed_files.append((file_path, str(e)))
            tqdm.write(f"  ✗ Error loading {Path(file_path).name}: {e}")
    
    if not dataframes:
        print("\n❌ No parquet files could be loaded successfully!")
        print("\n🔍 Failed files details:")
        for file_path, error in failed_files:
            print(f"   {Path(file_path).name}: {error}")
        print("\n💡 Advanced troubleshooting:")
        print("   1. Arrow extension types: This error suggests incompatible Arrow metadata")
        print("   2. Install compatible versions: pip install pyarrow==12.0.0 pandas==2.0.3")
        print("   3. Check source: Verify how the parquet files were created")
        print("   4. Manual inspection: Try loading with: pyarrow.parquet.read_table(file)")
        raise ValueError("No parquet files could be loaded successfully! See troubleshooting above.")
    
    if failed_files:
        print(f"\n⚠️  Warning: {len(failed_files)} files failed to load:")
        for file_path, error in failed_files[:3]:  # Show first 3 failures
            print(f"   {Path(file_path).name}: {error}")
        if len(failed_files) > 3:
            print(f"   ... and {len(failed_files) - 3} more")
    
    print(f"\n🔗 Concatenating {len(dataframes)} dataframes...")
    
    # Efficient concatenation with progress tracking
    with tqdm(total=1, desc="Concatenating", unit="operation") as pbar:
        # Use efficient concatenation
        combined_df = pd.concat(dataframes, ignore_index=True, copy=False)
        pbar.update(1)
    
    # Clear intermediate dataframes to free memory
    del dataframes
    import gc
    gc.collect()
    if USE_GPU_PROCESSING and has_gpu:
        torch.cuda.empty_cache()
    
    print(f"✅ Total rows after concatenation: {len(combined_df):,}")
    print(f"📈 Successfully loaded {len(file_paths) - len(failed_files)}/{len(file_paths)} files")
    print(f"   Final memory usage: {get_memory_usage()}")
    
    return combined_df

# Load all data with enhanced error handling
if parquet_files:
    print("\n🔧 Installing required packages for robust parquet loading...")
    try:
        import fastparquet
    except ImportError:
        print("Installing fastparquet for additional fallback support...")
        import subprocess
        import sys
        subprocess.check_call([sys.executable, "-m", "pip", "install", "fastparquet"])
        import fastparquet
    
    start_time = time.time()
    print(f"\n🚀 Starting robust data loading with fallback strategies...")
    
    try:
        df = load_parquet_files_gpu_aware(parquet_files)
        load_time = time.time() - start_time
        print(f"⏱️  Loading completed in {load_time:.2f} seconds")
        print(f"📈 Dataset shape: {df.shape}")
        print(f"📋 Columns: {list(df.columns)}")
        print(f"💾 Memory usage: {df.memory_usage(deep=True).sum() / 1024**2:.1f} MB")
    except ValueError as e:
        print(f"\n❌ All loading strategies failed. Error: {e}")
        print("\n🛠️  Additional debugging steps:")
        print("   1. Check one file manually:")
        if parquet_files:
            print(f"      import pandas as pd")
            print(f"      df = pd.read_parquet('{parquet_files[0]}')")
        print("   2. Check pyarrow version: import pyarrow; print(pyarrow.__version__)")
        print("   3. Try recreating parquet files with: df.to_parquet(path, engine='pyarrow')")
else:
    print("❌ No parquet files found! Please check your INPUT_DATA_PATH.")

NameError: name 'pd' is not defined

## Data Exploration

Explore the structure and content of the loaded data.

In [7]:
# Data exploration
if 'df' in locals():
    print("Dataset Info:")
    print(f"Shape: {df.shape}")
    print(f"Memory usage: {df.memory_usage(deep=True).sum() / 1024**2:.2f} MB")
    print("\nColumn types:")
    print(df.dtypes)
    
    print("\nFirst few rows:")
    display(df.head())
    
    print("\nDataset statistics:")
    print(df.describe(include='all'))
    
    # Check for missing values
    missing_values = df.isnull().sum()
    if missing_values.sum() > 0:
        print("\nMissing values:")
        print(missing_values[missing_values > 0])
    else:
        print("\nNo missing values found.")

Dataset Info:
Shape: (1048581, 9)
Memory usage: 6094.63 MB

Column types:
text               object
id                 object
dump               object
url                object
date               object
file_path          object
language           object
language_score    float64
token_count         int64
dtype: object

First few rows:


Unnamed: 0,text,id,dump,url,date,file_path,language,language_score,token_count
0,|Viewing Single Post From: Spoilers for the We...,<urn:uuid:39147604-bfbe-4ed5-b19c-54105f8ae8a7>,CC-MAIN-2013-20,http://daytimeroyaltyonline.com/single/?p=8906...,2013-05-18T05:48:59Z,s3://commoncrawl/crawl-data/CC-MAIN-2013-20/se...,en,0.82321,142
1,"*sigh* Fundamentalist community, let me pass o...",<urn:uuid:ba819eb7-e6e6-415a-87f4-0347b6a4f017>,CC-MAIN-2013-20,http://endogenousretrovirus.blogspot.com/2007/...,2013-05-18T06:43:03Z,s3://commoncrawl/crawl-data/CC-MAIN-2013-20/se...,en,0.973771,703
2,A novel two-step immunotherapy approach has sh...,<urn:uuid:07b8e00d-b445-4736-a593-cd1c147dce21>,CC-MAIN-2013-20,http://news.cancerconnect.com/,2013-05-18T05:23:15Z,s3://commoncrawl/crawl-data/CC-MAIN-2013-20/se...,en,0.872709,576
3,Free the Cans! Working Together to Reduce Wast...,<urn:uuid:c970d9a2-a5ce-4050-9ea3-58d7bbd609a8>,CC-MAIN-2013-20,http://sharingsolution.com/2009/05/23/free-the...,2013-05-18T05:49:03Z,s3://commoncrawl/crawl-data/CC-MAIN-2013-20/se...,en,0.93236,575
4,"ORLANDO, Fla. — While the Rapid Recall Exchang...",<urn:uuid:5c2cac9e-2fda-4194-959b-6ede0668ad2a>,CC-MAIN-2013-20,http://supermarketnews.com/food-safety/more-su...,2013-05-18T05:25:43Z,s3://commoncrawl/crawl-data/CC-MAIN-2013-20/se...,en,0.955206,708



Dataset statistics:
                                                     text  \
count                                             1048581   
unique                                            1048417   
top     |Track & Field Profile - Embed| Suggest a Corr...   
freq                                                    5   
mean                                                  NaN   
std                                                   NaN   
min                                                   NaN   
25%                                                   NaN   
50%                                                   NaN   
75%                                                   NaN   
max                                                   NaN   

                                                     id             dump  \
count                                           1048581          1048581   
unique                                          1048581                8   
top     <urn:uuid:

## Data Preprocessing

Clean and preprocess the data for training.

In [None]:
def preprocess_data(df: pd.DataFrame) -> pd.DataFrame:
    """
    Preprocess the dataset with progress tracking.
    Modify this function based on your specific data requirements.
    """
    print("🚀 Starting data preprocessing...")
    original_shape = df.shape
    
    # Create a progress bar for preprocessing steps
    preprocessing_steps = [
        "Removing duplicates",
        "Handling missing values", 
        "Filtering text columns",
        "Custom preprocessing"
    ]
    
    with tqdm(total=len(preprocessing_steps), desc="Preprocessing", unit="step") as pbar:
        # 1. Remove duplicates
        pbar.set_description("Removing duplicates")
        df = df.drop_duplicates()
        duplicates_removed = original_shape[0] - df.shape[0]
        tqdm.write(f"  ✓ After removing duplicates: {df.shape} (removed {duplicates_removed:,} rows)")
        pbar.update(1)
        
        # 2. Handle missing values
        pbar.set_description("Handling missing values")
        before_na = len(df)
        df = df.dropna()
        na_removed = before_na - len(df)
        tqdm.write(f"  ✓ After removing missing values: {df.shape} (removed {na_removed:,} rows)")
        pbar.update(1)
        
        # 3. Filter out empty text fields
        pbar.set_description("Filtering text columns")
        text_columns = df.select_dtypes(include=['object']).columns
        
        if len(text_columns) > 0:
            for col in tqdm(text_columns, desc="Processing text cols", leave=False):
                if col in df.columns:
                    initial_len = len(df)
                    df = df[df[col].str.strip().str.len() > 0]
                    removed = initial_len - len(df)
                    if removed > 0:
                        tqdm.write(f"    ✓ Filtered empty {col}: removed {removed:,} rows")
        else:
            tqdm.write("  ℹ️  No text columns found")
        pbar.update(1)
        
        # 4. Custom preprocessing steps
        pbar.set_description("Custom preprocessing")
        # Add any custom preprocessing steps here
        # Example: text length filtering, tokenization, etc.
        tqdm.write("  ✓ Custom preprocessing completed")
        pbar.update(1)
    
    total_removed = original_shape[0] - len(df)
    print(f"✅ Preprocessing complete!")
    print(f"   📊 Original shape: {original_shape}")
    print(f"   📊 Final shape: {df.shape}")
    print(f"   📊 Total rows removed: {total_removed:,} ({total_removed/original_shape[0]*100:.1f}%)")
    
    return df

def preprocess_data_gpu_aware(df: pd.DataFrame) -> pd.DataFrame:
    """
    GPU-aware preprocessing with memory optimization and device management.
    Modify this function based on your specific data requirements.
    """
    print("🚀 Starting GPU-aware data preprocessing...")
    print(f"   Device target: {GPU_DEVICE}")
    print(f"   GPU processing: {'Enabled' if USE_GPU_PROCESSING else 'Disabled'}")
    
    original_shape = df.shape
    initial_memory = df.memory_usage(deep=True).sum() / 1024**2
    print(f"   Initial memory usage: {initial_memory:.1f} MB")
    
    # Memory monitoring function
    def monitor_memory(step_name):
        current_memory = df.memory_usage(deep=True).sum() / 1024**2
        if USE_GPU_PROCESSING and has_gpu:
            gpu_memory = torch.cuda.memory_allocated() / 1024**2
            return f"RAM: {current_memory:.1f}MB, GPU: {gpu_memory:.1f}MB"
        return f"RAM: {current_memory:.1f}MB"
    
    # Create a progress bar for preprocessing steps
    preprocessing_steps = [
        "Memory optimization",
        "Removing duplicates",
        "Handling missing values", 
        "Filtering text columns",
        "GPU preparation",
        "Custom preprocessing"
    ]
    
    with tqdm(total=len(preprocessing_steps), desc="GPU-aware preprocessing", unit="step") as pbar:
        # 0. Memory optimization
        pbar.set_description("Memory optimization")
        # Optimize data types to reduce memory usage
        for col in df.select_dtypes(include=['int64']).columns:
            df[col] = pd.to_numeric(df[col], downcast='integer')
        for col in df.select_dtypes(include=['float64']).columns:
            df[col] = pd.to_numeric(df[col], downcast='float')
        # Optimize object columns
        for col in df.select_dtypes(include=['object']).columns:
            if df[col].nunique() / len(df) < 0.5:  # If less than 50% unique values
                df[col] = df[col].astype('category')
        
        memory_after_opt = df.memory_usage(deep=True).sum() / 1024**2
        memory_saved = initial_memory - memory_after_opt
        tqdm.write(f"  ✓ Memory optimized: {memory_saved:.1f} MB saved | {monitor_memory('optimization')}")
        pbar.update(1)
        
        # 1. Remove duplicates
        pbar.set_description("Removing duplicates")
        df = df.drop_duplicates()
        duplicates_removed = original_shape[0] - df.shape[0]
        tqdm.write(f"  ✓ After removing duplicates: {df.shape} (removed {duplicates_removed:,} rows) | {monitor_memory('duplicates')}")
        pbar.update(1)
        
        # 2. Handle missing values
        pbar.set_description("Handling missing values")
        before_na = len(df)
        df = df.dropna()
        na_removed = before_na - len(df)
        tqdm.write(f"  ✓ After removing missing values: {df.shape} (removed {na_removed:,} rows) | {monitor_memory('missing')}")
        pbar.update(1)
        
        # 3. Filter out empty text fields
        pbar.set_description("Filtering text columns")
        text_columns = df.select_dtypes(include=['object', 'category']).columns
        
        if len(text_columns) > 0:
            for col in tqdm(text_columns, desc="Processing text cols", leave=False):
                if col in df.columns and df[col].dtype in ['object', 'category']:
                    initial_len = len(df)
                    # Convert categorical back to string for filtering
                    if df[col].dtype.name == 'category':
                        df[col] = df[col].astype('str')
                    df = df[df[col].str.strip().str.len() > 0]
                    removed = initial_len - len(df)
                    if removed > 0:
                        tqdm.write(f"    ✓ Filtered empty {col}: removed {removed:,} rows")
        else:
            tqdm.write("  ℹ️  No text columns found")
        
        tqdm.write(f"  ✓ Text filtering complete | {monitor_memory('text_filter')}")
        pbar.update(1)
        
        # 4. GPU preparation
        pbar.set_description("GPU preparation")
        if USE_GPU_PROCESSING and has_gpu:
            tqdm.write(f"  ✓ Data prepared for GPU device: {GPU_DEVICE}")
            tqdm.write(f"  ✓ GPU memory management enabled")
            # Clear any existing GPU cache
            torch.cuda.empty_cache()
        else:
            tqdm.write(f"  ✓ Data optimized for CPU processing")
        pbar.update(1)
        
        # 5. Custom preprocessing steps
        pbar.set_description("Custom preprocessing")
        
        # Text length filtering (if text column exists)
        if 'text' in df.columns:
            initial_len = len(df)
            min_text_length = 50  # Minimum characters
            df = df[df['text'].str.len() >= min_text_length]
            filtered_short = initial_len - len(df)
            if filtered_short > 0:
                tqdm.write(f"    ✓ Filtered short texts (<{min_text_length} chars): removed {filtered_short:,} rows")
        
        # Final memory cleanup
        import gc
        gc.collect()
        if USE_GPU_PROCESSING and has_gpu:
            torch.cuda.empty_cache()
        
        tqdm.write(f"  ✓ Custom preprocessing completed | {monitor_memory('custom')}")
        pbar.update(1)
    
    total_removed = original_shape[0] - len(df)
    final_memory = df.memory_usage(deep=True).sum() / 1024**2
    memory_reduction = initial_memory - final_memory
    
    print(f"✅ GPU-aware preprocessing complete!")
    print(f"   📊 Original shape: {original_shape}")
    print(f"   📊 Final shape: {df.shape}")
    print(f"   📊 Total rows removed: {total_removed:,} ({total_removed/original_shape[0]*100:.1f}%)")
    print(f"   💾 Memory reduction: {memory_reduction:.1f} MB ({memory_reduction/initial_memory*100:.1f}%)")
    print(f"   🎯 Ready for GPU training on device: {GPU_DEVICE}")
    
    return df

# Apply preprocessing
if 'df' in locals():
    start_time = time.time()
    df_processed = preprocess_data(df.copy())
    process_time = time.time() - start_time
    print(f"⏱️  Preprocessing completed in {process_time:.2f} seconds")
else:
    print("❌ No data to preprocess. Please load data first.")

# Apply GPU-aware preprocessing
if 'df' in locals():
    start_time = time.time()
    print(f"\n🔧 Starting preprocessing with GPU configuration:")
    print(f"   Target device: {GPU_DEVICE}")
    print(f"   GPU processing: {'Enabled' if USE_GPU_PROCESSING else 'Disabled'}")
    
    df_processed = preprocess_data_gpu_aware(df.copy())
    process_time = time.time() - start_time
    print(f"⏱️  GPU-aware preprocessing completed in {process_time:.2f} seconds")
else:
    print("❌ No data to preprocess. Please load data first.")

Starting data preprocessing...
After removing duplicates: (1048581, 9) (removed 0 rows)
After removing missing values: (1048581, 9)
After filtering empty text: (1048581, 9) (removed 0 rows)
After filtering empty id: (1048581, 9) (removed 0 rows)
After filtering empty dump: (1048581, 9) (removed 0 rows)
After filtering empty url: (1048581, 9) (removed 0 rows)
After filtering empty date: (1048581, 9) (removed 0 rows)
After filtering empty file_path: (1048581, 9) (removed 0 rows)
After filtering empty language: (1048581, 9) (removed 0 rows)
Preprocessing complete. Final shape: (1048581, 9)


## Train/Eval Split

Split the data into training and evaluation sets.

In [9]:
def split_dataset(df: pd.DataFrame, train_size: float = 0.8, random_state: int = 42) -> tuple:
    """Split dataset into train and eval sets with progress tracking."""
    print(f"🔀 Splitting dataset with train_size={train_size}, random_state={random_state}")
    
    with tqdm(total=3, desc="Dataset splitting", unit="step") as pbar:
        # Shuffle the dataset
        pbar.set_description("Shuffling dataset")
        df_shuffled = df.sample(frac=1, random_state=random_state).reset_index(drop=True)
        pbar.update(1)
        
        # Split the data
        pbar.set_description("Splitting data")
        train_df, eval_df = train_test_split(
            df_shuffled, 
            train_size=train_size, 
            random_state=random_state,
            shuffle=False  # Already shuffled above
        )
        pbar.update(1)
        
        # Validate split
        pbar.set_description("Validating split")
        assert len(train_df) + len(eval_df) == len(df), "Split validation failed!"
        pbar.update(1)
    
    print(f"✅ Split completed!")
    print(f"   📊 Train set: {train_df.shape} ({len(train_df) / len(df)*100:.1f}%)")
    print(f"   📊 Eval set: {eval_df.shape} ({len(eval_df) / len(df)*100:.1f}%)")
    
    return train_df, eval_df

# Split the data
if 'df_processed' in locals():
    start_time = time.time()
    train_df, eval_df = split_dataset(df_processed, TRAIN_SPLIT, RANDOM_SEED)
    split_time = time.time() - start_time
    print(f"⏱️  Splitting completed in {split_time:.2f} seconds")
else:
    print("❌ No processed data to split. Please run preprocessing first.")

Splitting dataset with train_size=0.8, random_state=42
Train set: (838864, 9)
Eval set: (209717, 9)
Train ratio: 0.800
Eval ratio: 0.200


## Save Processed Data

Save the train and eval datasets to parquet files.

In [None]:
# Clean dataframes to prevent Arrow extension type errors during save
if 'train_df' in locals() and 'eval_df' in locals():
    print("🧹 Cleaning dataframes to prevent Arrow extension type errors...")
    
    def clean_dataframe_types(df: pd.DataFrame) -> pd.DataFrame:
        """Clean DataFrame by converting problematic types to basic pandas types."""
        df_clean = df.copy()
        
        print(f"   Cleaning DataFrame with {len(df_clean.columns)} columns...")
        
        for col in df_clean.columns:
            original_dtype = df_clean[col].dtype
            
            # Handle different problematic types
            if 'period' in str(original_dtype).lower():
                print(f"     Converting period column '{col}' to string")
                df_clean[col] = df_clean[col].astype(str)
            elif 'category' in str(original_dtype).lower():
                print(f"     Converting category column '{col}' to string")
                df_clean[col] = df_clean[col].astype(str)
            elif hasattr(original_dtype, 'name') and 'extension' in str(original_dtype).lower():
                print(f"     Converting extension type column '{col}' to string")
                df_clean[col] = df_clean[col].astype(str)
            # Convert any remaining object columns to string to be safe
            elif original_dtype == 'object':
                try:
                    # Try to keep as object if it's already basic
                    test_save = df_clean[col].iloc[:5]
                    pd.DataFrame({col: test_save}).to_parquet('/tmp/test_col.parquet', index=False)
                    os.remove('/tmp/test_col.parquet')
                except:
                    print(f"     Converting problematic object column '{col}' to string")
                    df_clean[col] = df_clean[col].astype(str)
        
        return df_clean
    
    # Clean both dataframes
    print("\n🔧 Cleaning training dataframe...")
    train_df_clean = clean_dataframe_types(train_df)
    
    print("\n🔧 Cleaning evaluation dataframe...")
    eval_df_clean = clean_dataframe_types(eval_df)
    
    print("\n✅ Dataframes cleaned successfully!")
    print(f"   Train DataFrame: {train_df_clean.shape}")
    print(f"   Eval DataFrame: {eval_df_clean.shape}")
    print(f"   Data types summary:")
    print(f"     Train: {train_df_clean.dtypes.value_counts().to_dict()}")
    print(f"     Eval: {eval_df_clean.dtypes.value_counts().to_dict()}")
    
    # Replace original dataframes with cleaned versions
    train_df = train_df_clean
    eval_df = eval_df_clean
    
else:
    print("❌ No data to clean. Please run preprocessing and splitting first.")

In [None]:
def save_datasets(train_df: pd.DataFrame, eval_df: pd.DataFrame, output_dir: str):
    """Save train and eval datasets to parquet files with progress tracking and Arrow error handling."""
    print("💾 Saving datasets...")
    
    def save_parquet_robust(df: pd.DataFrame, path: str) -> None:
        """Save DataFrame to parquet with robust error handling for Arrow extension types."""
        try:
            # Strategy 1: Direct save with pyarrow
            df.to_parquet(path, index=False, engine='pyarrow')
        except Exception as e:
            if "pandas.period already defined" in str(e) or "extension" in str(e).lower():
                print(f"      ⚠️  Arrow extension error, trying fallback methods...")
                
                try:
                    # Strategy 2: Save with fastparquet
                    df.to_parquet(path, index=False, engine='fastparquet')
                except Exception as e2:
                    print(f"      ⚠️  FastParquet failed, cleaning DataFrame...")
                    
                    try:
                        # Strategy 3: Clean DataFrame and save
                        df_clean = clean_dataframe_for_save(df)
                        df_clean.to_parquet(path, index=False, engine='pyarrow')
                    except Exception as e3:
                        print(f"      ⚠️  Clean DataFrame failed, trying final fallback...")
                        
                        # Strategy 4: Convert to basic types and save
                        df_basic = convert_to_basic_types(df)
                        df_basic.to_parquet(path, index=False, engine='pyarrow')
            else:
                raise e
    
    def clean_dataframe_for_save(df: pd.DataFrame) -> pd.DataFrame:
        """Clean DataFrame by removing problematic extension types."""
        df_clean = df.copy()
        
        # Convert any remaining extension types to basic types
        for col in df_clean.columns:
            if hasattr(df_clean[col].dtype, 'name'):
                if 'period' in str(df_clean[col].dtype).lower():
                    # Convert period to string
                    df_clean[col] = df_clean[col].astype(str)
                elif 'category' in str(df_clean[col].dtype).lower():
                    # Convert category to string if it's causing issues
                    df_clean[col] = df_clean[col].astype(str)
                elif 'extension' in str(df_clean[col].dtype).lower():
                    # Convert any other extension types to string
                    df_clean[col] = df_clean[col].astype(str)
        
        return df_clean
    
    def convert_to_basic_types(df: pd.DataFrame) -> pd.DataFrame:
        """Convert DataFrame to only basic pandas types."""
        df_basic = pd.DataFrame()
        
        for col in df.columns:
            try:
                # Try to infer basic type
                series = df[col]
                if series.dtype == 'object':
                    df_basic[col] = series.astype(str)
                elif 'int' in str(series.dtype):
                    df_basic[col] = series.astype('int64')
                elif 'float' in str(series.dtype):
                    df_basic[col] = series.astype('float64')
                elif 'bool' in str(series.dtype):
                    df_basic[col] = series.astype(bool)
                else:
                    # Fallback to string for any complex types
                    df_basic[col] = series.astype(str)
            except Exception:
                # Ultimate fallback
                df_basic[col] = df[col].astype(str)
        
        return df_basic
    
    save_tasks = [
        ("Training data", train_df, f"{output_dir}/train/train_data.parquet"),
        ("Evaluation data", eval_df, f"{output_dir}/eval/eval_data.parquet")
    ]
    
    saved_paths = []
    
    with tqdm(total=len(save_tasks) + 1, desc="Saving datasets", unit="file") as pbar:
        for task_name, data, path in save_tasks:
            pbar.set_description(f"Saving {task_name.lower()}")
            
            # Save with robust error handling
            try:
                save_parquet_robust(data, path)
                file_size = os.path.getsize(path) / 1024**2
                
                tqdm.write(f"  ✓ {task_name} saved to: {path}")
                tqdm.write(f"    📊 Shape: {data.shape}")
                tqdm.write(f"    💾 Size: {file_size:.2f} MB")
                
                saved_paths.append(path)
                pbar.update(1)
                
            except Exception as e:
                tqdm.write(f"  ❌ Failed to save {task_name}: {e}")
                print(f"\n🛠️  Troubleshooting save error:")
                print(f"     Error: {type(e).__name__}: {e}")
                print(f"     Data types in {task_name}:")
                for col, dtype in data.dtypes.items():
                    print(f"       {col}: {dtype}")
                raise e
            
            tqdm.write(f"  ✓ {task_name} saved to: {path}")
            tqdm.write(f"    📊 Shape: {data.shape}")
            tqdm.write(f"    💾 Size: {file_size:.2f} MB")
            
            saved_paths.append(path)
            pbar.update(1)
        
        # Save metadata
        pbar.set_description("Saving metadata")
        metadata = {
            "total_samples": len(train_df) + len(eval_df),
            "train_samples": len(train_df),
            "eval_samples": len(eval_df),
            "train_split": len(train_df) / (len(train_df) + len(eval_df)),
            "eval_split": len(eval_df) / (len(train_df) + len(eval_df)),
            "columns": list(train_df.columns),
            "random_seed": RANDOM_SEED,
            "source_files": len(parquet_files) if 'parquet_files' in locals() else 0,
            "processing_timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }
        
        metadata_path = f"{output_dir}/metadata.json"
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        
        tqdm.write(f"  ✓ Metadata saved to: {metadata_path}")
        saved_paths.append(metadata_path)
        pbar.update(1)
    
    print("✅ All datasets saved successfully!")
    return saved_paths[0], saved_paths[1], saved_paths[2]

# Save the datasets
if 'train_df' in locals() and 'eval_df' in locals():
    start_time = time.time()
    train_path, eval_path, metadata_path = save_datasets(train_df, eval_df, OUTPUT_DIR)
    save_time = time.time() - start_time
    print(f"⏱️  Saving completed in {save_time:.2f} seconds")
else:
    print("❌ No data to save. Please run the previous cells first.")

Saving datasets...


## Verification

Verify the saved datasets by loading them back and checking their properties.

In [None]:
def verify_saved_data(train_path: str, eval_path: str, metadata_path: str):
    """Verify the saved datasets with progress tracking."""
    print("🔍 Verifying saved datasets...")
    
    verification_tasks = [
        ("Loading metadata", metadata_path),
        ("Loading train data", train_path), 
        ("Loading eval data", eval_path),
        ("Checking data integrity", None)
    ]
    
    with tqdm(total=len(verification_tasks), desc="Verification", unit="task") as pbar:
        # Load metadata
        pbar.set_description("Loading metadata")
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        tqdm.write("✓ Metadata loaded:")
        for key, value in metadata.items():
            tqdm.write(f"    {key}: {value}")
        pbar.update(1)
        
        # Load and verify train data
        pbar.set_description("Loading train data")
        train_loaded = pd.read_parquet(train_path)
        tqdm.write(f"✓ Train data loaded from: {Path(train_path).name}")
        tqdm.write(f"    📊 Shape: {train_loaded.shape}")
        tqdm.write(f"    📋 Columns: {list(train_loaded.columns)}")
        pbar.update(1)
        
        # Load and verify eval data
        pbar.set_description("Loading eval data")
        eval_loaded = pd.read_parquet(eval_path)
        tqdm.write(f"✓ Eval data loaded from: {Path(eval_path).name}")
        tqdm.write(f"    📊 Shape: {eval_loaded.shape}")
        tqdm.write(f"    📋 Columns: {list(eval_loaded.columns)}")
        pbar.update(1)
        
        # Data integrity checks
        pbar.set_description("Checking integrity")
        checks_passed = 0
        total_checks = 3
        
        # Check 1: Column consistency
        if list(train_loaded.columns) == list(eval_loaded.columns):
            tqdm.write("  ✓ Column names match between train and eval")
            checks_passed += 1
        else:
            tqdm.write("  ✗ Column names mismatch between train and eval")
        
        # Check 2: No empty datasets
        if len(train_loaded) > 0 and len(eval_loaded) > 0:
            tqdm.write("  ✓ Both datasets contain data")
            checks_passed += 1
        else:
            tqdm.write("  ✗ One or both datasets are empty")
        
        # Check 3: Metadata consistency
        expected_total = metadata['train_samples'] + metadata['eval_samples']
        actual_total = len(train_loaded) + len(eval_loaded)
        if expected_total == actual_total:
            tqdm.write("  ✓ Sample counts match metadata")
            checks_passed += 1
        else:
            tqdm.write(f"  ✗ Sample count mismatch: expected {expected_total}, got {actual_total}")
        
        pbar.update(1)
    print(f"✅ Verification complete! ({checks_passed}/{total_checks} checks passed)")
    return train_loaded, eval_loaded

# Verify the saved data
if all(var in locals() for var in ['train_path', 'eval_path', 'metadata_path']):
    start_time = time.time()
    train_verified, eval_verified = verify_saved_data(train_path, eval_path, metadata_path)
    verify_time = time.time() - start_time
    print(f"⏱️  Verification completed in {verify_time:.2f} seconds")
else:
    print("❌ No saved data to verify. Please run the saving step first.")

## Summary

Data preprocessing and splitting completed successfully! 

### Next Steps:
1. Review the processed data quality
2. Adjust preprocessing parameters if needed
3. Use the saved parquet files for training with nanotron
4. The data paths are ready to be used in your training configuration

### File Outputs:
- Training data: `{OUTPUT_DIR}/train/train_data.parquet`
- Evaluation data: `{OUTPUT_DIR}/eval/eval_data.parquet` 
- Metadata: `{OUTPUT_DIR}/metadata.json`

In [None]:
# Final summary with enhanced progress display
if all(var in locals() for var in ['train_df', 'eval_df']):
    print("🎉 " + "="*60 + " 🎉")
    print("📊 GPU-AWARE DATA PREPROCESSING SUMMARY")
    print("🎉 " + "="*60 + " 🎉")
    
    # Hardware configuration summary
    print(f"\n🖥️  Hardware Configuration:")
    print(f"   Target GPU device: {GPU_DEVICE}")
    print(f"   GPU processing: {'Enabled' if USE_GPU_PROCESSING else 'Disabled'}")
    print(f"   Processing chunk size: {CHUNK_SIZE:,} rows")
    print(f"   Parallel workers: {NUM_WORKERS}")
    print(f"   Max memory usage: {MAX_MEMORY_GB:.1f} GB")
    
    # Calculate total processing time if variables exist
    total_time = 0
    if 'load_time' in locals():
        total_time += load_time
        print(f"\n⏱️  Performance Metrics:")
        print(f"   Loading time: {load_time:.2f}s")
    if 'process_time' in locals():
        total_time += process_time  
        print(f"   Processing time: {process_time:.2f}s")
    if 'split_time' in locals():
        total_time += split_time
        print(f"   Splitting time: {split_time:.2f}s")
    if 'save_time' in locals():
        total_time += save_time
        print(f"   Saving time: {save_time:.2f}s")
    if 'verify_time' in locals():
        total_time += verify_time
        print(f"   Verification time: {verify_time:.2f}s")
    
    if total_time > 0:
        print(f"   Total processing time: {total_time:.2f}s")
    
    # Data summary
    print(f"\n📈 Data Summary:")
    print(f"   Input files processed: {len(parquet_files) if 'parquet_files' in locals() else 0}")ing and splitting completed successfully!** 
    print(f"   Total samples: {len(train_df) + len(eval_df):,}")
    print(f"   Training samples: {len(train_df):,} ({len(train_df)/(len(train_df)+len(eval_df))*100:.1f}%)")
    print(f"   Evaluation samples: {len(eval_df):,} ({len(eval_df)/(len(train_df)+len(eval_df))*100:.1f}%)")GPU or CPU
    print(f"   Output directory: {OUTPUT_DIR}")*: Dynamic memory management based on available hardware
    ing optimized for your system
    # File paths for trainingorker data loading when supported
    print(f"\n📁 Training Files:")
    print(f"   Train data: {OUTPUT_DIR}/train/train_data.parquet")
    print(f"   Eval data: {OUTPUT_DIR}/eval/eval_data.parquet")in_data.parquet`
    print(f"   Metadata: {OUTPUT_DIR}/metadata.json")Evaluation data**: `{OUTPUT_DIR}/eval/eval_data.parquet` 
    `{OUTPUT_DIR}/metadata.json`
    # Next steps for training
    print(f"\n🚀 Next Steps for Training:")g:
    print(f"   1. 📝 Review the processed data quality")Check the processed data statistics above













































            print(f"  ❌ {step}")        for step in missing_vars:        print("Missing steps:")    if missing_vars:                missing_vars.append("🔀 Data splitting")    if 'train_df' not in locals() or 'eval_df' not in locals():        missing_vars.append("🔧 Data preprocessing")    if 'df_processed' not in locals():        missing_vars.append("📈 Data loading")    if 'df' not in locals():        missing_vars.append("📁 File discovery")    if 'parquet_files' not in locals():    missing_vars = []    # Show which steps are missing        print("❌ Please run all cells above to complete the data preprocessing pipeline.")else:    print("🎉 " + "="*60 + " 🎉")    print("\n🎉 Data ready for GPU-accelerated training with Infini attention!")            print(f"   {key}: {value}")    for key, value in training_config.items():    print(f"\n📄 Training Configuration (copy to training script):")        }        "total_samples": len(train_df) + len(eval_df)        "flash_attention": USE_GPU_PROCESSING,        "mixed_precision": USE_GPU_PROCESSING,        "gradient_accumulation_steps": 4,        "batch_size_per_gpu": 16 if USE_GPU_PROCESSING else 4,        "use_gpu": USE_GPU_PROCESSING,        "device": GPU_DEVICE,        "eval_file": f"{OUTPUT_DIR}/eval/eval_data.parquet",        "train_file": f"{OUTPUT_DIR}/train/train_data.parquet",        "data_path": OUTPUT_DIR,    training_config = {    # Configuration for training script        print(f"   4. 📊 Monitor training progress and GPU utilization")    print(f"   3. 🎯 Run the training notebook (scripts/train.ipynb)")    print(f"      - Memory optimization: {MAX_MEMORY_GB:.1f} GB limit")    print(f"      - GPU processing: {'Enabled' if USE_GPU_PROCESSING else 'Disabled'}")    print(f"      - Device: {GPU_DEVICE}")    print(f"   2. 🔧 Configure your training environment with these settings:")    2. **Training Environment**: Use the detected GPU configuration for optimal performance
    3. **Nanotron Training**: Run `scripts/train.ipynb` with the generated data files
    4. **Monitoring**: Track GPU utilization and memory usage during training

    ### ⚙️ Configuration for Training Script:
    ```python
    # Use these paths in your training configuration
    TRAIN_DATA_PATH = "{OUTPUT_DIR}/train/train_data.parquet"
    EVAL_DATA_PATH = "{OUTPUT_DIR}/eval/eval_data.parquet"
    DEVICE = "cuda:0"  # or your detected GPU device
    USE_FLASH_ATTENTION = True  # if GPU supports it
    MIXED_PRECISION = True      # for faster training
    ```

    ### 📊 Performance Optimization:
    - **GPU Memory**: Optimized for available VRAM
    - **Batch Size**: Automatically configured based on your hardware
    - **Data Loading**: Parallel workers for efficient I/O
    - **Memory Management**: Garbage collection and cache clearing
else:
    print("❌ Please run all cells above to complete the data preprocessing pipeline.")
    
    # Show which steps are missing
    missing_vars = []
    if 'parquet_files' not in locals():
        missing_vars.append("📁 File discovery")
    if 'df' not in locals():
        missing_vars.append("📊 Data loading")
    if 'df_processed' not in locals():
        missing_vars.append("🔧 Data preprocessing")
    if 'train_df' not in locals() or 'eval_df' not in locals():
        missing_vars.append("🔀 Data splitting")
        
    if missing_vars:
        print("Missing steps:")
        for step in missing_vars:
            print(f"  ❌ {step}")