In [None]:
# Colab environment setup. Run this cell once; the last line kills the kernel,
# after which execution continues from the next cell.
import os
# Remove numpy/transformers/datasets and force-reinstall them without pip's
# cache. NOTE(review): presumably done to clear a version/binary mismatch in the
# preinstalled image — TODO confirm. Versions are unpinned, so later runs may
# install different releases; consider pinning (e.g. `%pip install pkg==X.Y`).
!pip uninstall -y numpy transformers datasets
!pip install numpy --force-reinstall --no-cache-dir
!pip install transformers datasets --force-reinstall --no-cache-dir


# Send signal 9 (SIGKILL) to the current kernel process so the reinstalled
# packages are picked up by a fresh interpreter. Re-running this cell after the
# restart would reinstall everything and kill the kernel again.
os.kill(os.getpid(), 9)  # Restart the Colab runtime (REQUIRED)

In [None]:
!git clone https://github.com/babylm/baseline-pretraining.git
%cd baseline-pretraining
!wget -O babylm_data.zip "https://files.osf.io/v1/resources/ad7qg/providers/osfstorage/661517db943bee3731dfec25/?zip="
!unzip babylm_data.zip -d babylm_data
!unzip babylm_data/train_10M.zip -d babylm_data/train_10M
!unzip babylm_data/dev.zip -d babylm_data/dev
!unzip babylm_data/test.zip -d babylm_data/test
!cat babylm_data/train_10M/train_10M/*.train > babylm_data/babylm_train.txt
!cat babylm_data/dev/dev/*.dev > babylm_data/babylm_dev.txt
!cat babylm_data/test/test/*.test > babylm_data/babylm_test.txt

In [None]:
# =============================================================================
# ENHANCED DOMAIN-AWARE T5 TRAINING SYSTEM WITH ROBUST SAVE/LOAD
# =============================================================================

import os
import re
import pickle
import torch
import numpy as np
import random
import hashlib
import json
import time
from typing import Dict, List, Tuple, Optional, Union, Any
from collections import Counter, defaultdict
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader, Sampler
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from tqdm import tqdm
import shutil
from datetime import datetime
from pathlib import Path

# Set random seeds for reproducibility
def set_random_seeds(seed: int = 42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with one value.

    When CUDA is available, the current device and all CUDA devices are
    seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_random_seeds()

# =============================================================================
# ROBUST DATA MANAGEMENT SYSTEM
# =============================================================================
class RobustDataManager:
    """File-system manager for a training project: directory layout, pickle caches, checkpoints.

    Everything lives under ``<base_dir>/<project_name>/``; the sub-directories listed in
    ``self.structure`` are created eagerly by ``__init__``. Preprocessed objects are
    pickled into ``datasets/preprocessed`` with a JSON sidecar in ``metadata``; model
    checkpoints are ``save_pretrained`` directories under ``models/checkpoints``.
    """

    def __init__(self, base_dir: str, project_name: str = "t5_training"):
        """Create (if needed) the project directory tree and print it.

        Args:
            base_dir: Parent directory of the project root.
            project_name: Name of the project root directory inside ``base_dir``.
        """
        self.base_dir = Path(base_dir)
        self.project_name = project_name

        root = self.base_dir / project_name
        # Structured directory hierarchy; the keys are the names accepted by get_path().
        self.structure = {
            'root': root,
            'datasets': root / 'datasets',
            'preprocessed': root / 'datasets' / 'preprocessed',
            'complexity_cache': root / 'datasets' / 'complexity_cache',
            'models': root / 'models',
            'checkpoints': root / 'models' / 'checkpoints',
            'final_models': root / 'models' / 'final',
            'logs': root / 'logs',
            'metadata': root / 'metadata'
        }

        for path in self.structure.values():
            path.mkdir(parents=True, exist_ok=True)

        # Mixed into every cache key (see generate_cache_key); bumping it makes
        # exact-match lookups miss all previously written caches.
        self.version = "v3.0_stable"

        print(f"Data Manager initialized at: {self.structure['root']}")
        self._log_directory_structure()

    def _log_directory_structure(self):
        """Print every named project directory and its path."""
        print("Directory structure:")
        for name, path in self.structure.items():
            print(f"  {name}: {path}")

    def get_path(self, path_type: str) -> Path:
        """Return the directory registered under ``path_type``.

        Raises:
            ValueError: If ``path_type`` is not a key of ``self.structure``.
        """
        if path_type not in self.structure:
            raise ValueError(f"Unknown path type: {path_type}")
        return self.structure[path_type]

    def generate_cache_key(self, data_source: str, config_params: Dict) -> str:
        """Return a 16-hex-character cache key for a data source plus preprocessing config.

        The source file is fingerprinted by its size only (cheap, but two different
        files of identical size yield the same key). Only the preprocessing-relevant
        entries of ``config_params`` are hashed, together with ``self.version``.
        """
        try:
            if os.path.exists(data_source):
                source_hash = f"size_{os.stat(data_source).st_size}"
            else:
                source_hash = "no_file"
        except (OSError, TypeError, ValueError):
            # Unusable path (e.g. None). Narrowed from a bare ``except`` so that
            # KeyboardInterrupt is no longer swallowed.
            source_hash = "error"

        essential_config = {
            'max_source_length': config_params.get('max_source_length', 512),
            'max_target_length': config_params.get('max_target_length', 256),
            'split': config_params.get('split', 'train'),
            'max_examples': config_params.get('max_examples', None)
        }

        config_str = json.dumps(essential_config, sort_keys=True)
        combined = f"{source_hash}_{config_str}_{self.version}"

        return hashlib.md5(combined.encode()).hexdigest()[:16]

    def save_preprocessed_data(self, data: Any, cache_key: str, data_type: str) -> bool:
        """Pickle ``data`` to ``preprocessed/<data_type>_<cache_key>.pkl`` plus a JSON sidecar.

        The pickle is first written to a ``.pkl.tmp`` sibling and then moved into place
        with ``os.replace``, so an interrupted save never leaves a truncated ``.pkl``
        that a later load (including the pattern-matching fallback) would pick up.

        Returns:
            True on success, False if anything failed (the error is printed).
        """
        file_path = self.get_path('preprocessed') / f"{data_type}_{cache_key}.pkl"
        # The ".tmp" suffix keeps the partial file out of the "<data_type>_*.pkl" glob.
        tmp_path = file_path.with_name(file_path.name + '.tmp')
        try:
            file_path.parent.mkdir(parents=True, exist_ok=True)

            with open(tmp_path, 'wb') as f:
                pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
            os.replace(tmp_path, file_path)

            file_size = os.path.getsize(file_path)
            metadata = {
                'cache_key': cache_key,
                'data_type': data_type,
                'timestamp': datetime.now().isoformat(),
                'version': self.version,
                'file_size': file_size,
                'data_info': self._get_data_info(data),
                'absolute_path': str(file_path.absolute())
            }

            metadata_path = self.get_path('metadata') / f"{data_type}_{cache_key}_meta.json"
            metadata_path.parent.mkdir(parents=True, exist_ok=True)

            with open(metadata_path, 'w') as f:
                # default=str: a non-JSON value must not fail the save after the pickle landed.
                json.dump(metadata, f, indent=2, default=str)

            print(f"✓ Saved {data_type} data: {file_path.name} ({file_size/1024/1024:.1f}MB)")
            return True

        except Exception as e:
            print(f"✗ Failed to save {data_type} data: {e}")
            return False

        finally:
            # Only present if pickling failed before os.replace(); don't leave it behind.
            if tmp_path.exists():
                try:
                    tmp_path.unlink()
                except OSError:
                    pass

    def load_preprocessed_data(self, cache_key: str, data_type: str,
                               strict: bool = False) -> Optional[Any]:
        """Load cached data for ``(cache_key, data_type)``.

        Lookup order:
          1. the exact file ``<data_type>_<cache_key>.pkl`` (its load result is
             returned as-is, even if loading fails);
          2. unless ``strict``: the newest loadable ``<data_type>_*.pkl``.

        NOTE: the fallback ignores ``cache_key`` entirely, so it can return data built
        with a different configuration; because the glob is a prefix match it can also
        return data saved under a longer data_type (e.g. ``dataset_train_...`` when
        asked for ``dataset``). Pass ``strict=True`` when that is not acceptable.

        Returns:
            The unpickled object, or None if nothing could be loaded.
        """
        try:
            file_path = self.get_path('preprocessed') / f"{data_type}_{cache_key}.pkl"
            print(f"🔍 Looking for exact match: {file_path.name}")

            if file_path.exists():
                return self._load_cache_file(file_path, data_type)

            if strict:
                print(f"   No exact match and strict=True; not searching other files")
                return None

            print(f"🔍 Searching for pattern: {data_type}_*.pkl")
            return self._find_and_load_best_match(data_type, cache_key)

        except Exception as e:
            print(f"   Cache loading error: {e}")
            return None

    def _find_and_load_best_match(self, data_type: str, cache_key: str) -> Optional[Any]:
        """Load the newest loadable ``<data_type>_*.pkl`` (``cache_key`` is not used).

        Files are tried newest-first by modification time; the first one that unpickles
        successfully is returned. Returns None when none can be loaded.
        """
        try:
            preprocessed_dir = self.get_path('preprocessed')
            pattern = f"{data_type}_*.pkl"

            cache_files = list(preprocessed_dir.glob(pattern))
            print(f"   Found {len(cache_files)} potential cache files")

            if not cache_files:
                return None

            # Sort by modification time (newest first)
            cache_files.sort(key=lambda f: f.stat().st_mtime, reverse=True)

            # Try loading each file until one works
            for cache_file in cache_files:
                print(f"   Trying: {cache_file.name}")
                data = self._load_cache_file(cache_file, data_type)
                if data is not None:
                    print(f"✓ Successfully loaded: {cache_file.name}")
                    return data

            print(f"   No compatible cache files found")
            return None

        except Exception as e:
            print(f"   Pattern matching failed: {e}")
            return None

    def _load_cache_file(self, file_path: Path, data_type: str) -> Optional[Any]:
        """Unpickle ``file_path``; return None (and print why) on any failure.

        SECURITY: ``pickle.load`` can execute arbitrary code, so only point this at
        files this manager wrote itself.
        """
        try:
            with open(file_path, 'rb') as f:
                return pickle.load(f)
        except Exception as e:
            print(f"   Failed to load {file_path.name}: {e}")
            return None

    def _get_data_info(self, data: Any) -> Dict:
        """Summarize ``data`` for the JSON metadata sidecar.

        Dicts report their length and up to 20 stringified keys, other sized objects
        their length, and anything else only its type name.
        """
        type_name = type(data).__name__
        if isinstance(data, dict):
            # Checked before __len__ (which dicts also have) so this branch is reachable.
            # zip(range(20), ...) takes the first 20 keys without copying them all.
            first_keys = [str(key) for _, key in zip(range(20), data)]
            return {'length': len(data), 'type': type_name, 'keys': first_keys}
        if hasattr(data, '__len__'):
            return {'length': len(data), 'type': type_name}
        return {'type': type_name}

    def save_model_checkpoint(self, model, tokenizer, optimizer, scheduler,
                            epoch: int, metrics: Dict, is_best: bool = False) -> bool:
        """Save weights, tokenizer, optimizer/scheduler state and a JSON summary.

        Regular checkpoints go to ``checkpoints/epoch_<epoch+1>_<timestamp>``; the best
        model is written into the fixed ``checkpoints/best_model`` directory.

        Args:
            model: Object exposing ``save_pretrained`` and ``config``.
            tokenizer: Object exposing ``save_pretrained``.
            optimizer: Its ``state_dict()`` is stored in ``training_state.pt``.
            scheduler: LR scheduler whose ``state_dict()`` is stored, or None.
            epoch: Treated as 0-based: stored as-is in ``training_state.pt`` and as
                ``epoch + 1`` in the directory name and ``checkpoint_metadata.json``.
            metrics: Stored in both ``training_state.pt`` and the JSON metadata.
            is_best: Save into ``best_model`` instead of a new epoch directory.

        Returns:
            True on success, False on any error (printed).
        """
        try:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            if is_best:
                checkpoint_dir = self.get_path('checkpoints') / 'best_model'
            else:
                checkpoint_dir = self.get_path('checkpoints') / f'epoch_{epoch+1}_{timestamp}'

            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            # Save model and tokenizer
            model.save_pretrained(checkpoint_dir)
            tokenizer.save_pretrained(checkpoint_dir)

            training_state = {
                'epoch': epoch,
                'optimizer_state': optimizer.state_dict(),
                'scheduler_state': scheduler.state_dict() if scheduler else None,
                'metrics': metrics,
                'timestamp': timestamp,
                # This class never assigns global_step itself; 0 unless set externally.
                'global_step': getattr(self, 'global_step', 0)
            }

            torch.save(training_state, checkpoint_dir / 'training_state.pt')

            metadata = {
                'checkpoint_name': checkpoint_dir.name,
                'epoch': epoch + 1,
                'timestamp': timestamp,
                'is_best': is_best,
                'metrics': metrics,
                'model_config': model.config.to_dict() if hasattr(model.config, 'to_dict') else {},
                'directory_size_mb': self._get_directory_size_mb(checkpoint_dir),
                'absolute_path': str(checkpoint_dir.absolute())
            }

            with open(checkpoint_dir / 'checkpoint_metadata.json', 'w') as f:
                # default=str: numpy/torch scalars in `metrics` would otherwise raise
                # mid-write, leaving a truncated JSON next to already-saved weights.
                json.dump(metadata, f, indent=2, default=str)

            print(f"✓ Checkpoint saved: {checkpoint_dir.name}")
            return True

        except Exception as e:
            print(f"✗ Failed to save checkpoint: {e}")
            return False

    def _get_directory_size_mb(self, directory: Path) -> float:
        """Return the total size of all files under ``directory`` in MiB (0.0 on error)."""
        try:
            total_size = sum(f.stat().st_size for f in directory.rglob('*') if f.is_file())
            return total_size / (1024 * 1024)
        except OSError:
            return 0.0

    def list_checkpoints(self) -> List[Dict]:
        """Return summaries of all valid checkpoint directories, sorted by (epoch, name).

        Sub-directories are examined in the order: ``epoch_*`` dirs, dirs whose name
        contains "best", then everything else; each is validated by
        _analyze_checkpoint_directory. The secondary sort on the directory name keeps
        same-epoch checkpoints in timestamp order (names end in ``_YYYYmmdd_HHMMSS``).
        """
        checkpoints = []
        checkpoint_dir = self.get_path('checkpoints')

        print(f"🔍 Scanning for checkpoints in: {checkpoint_dir}")

        if not checkpoint_dir.exists():
            print("   Checkpoint directory doesn't exist")
            return checkpoints

        epoch_dirs = []
        best_dirs = []
        other_dirs = []

        for subdir in checkpoint_dir.iterdir():
            if subdir.is_dir():
                dir_name = subdir.name.lower()
                if dir_name.startswith('epoch_'):
                    epoch_dirs.append(subdir)
                elif 'best' in dir_name:
                    best_dirs.append(subdir)
                else:
                    other_dirs.append(subdir)

        print(f"   Found: {len(epoch_dirs)} epoch dirs, {len(best_dirs)} best dirs, {len(other_dirs)} other dirs")

        for subdir in epoch_dirs + best_dirs + other_dirs:
            print(f"   Examining directory: {subdir.name}")
            checkpoint_info = self._analyze_checkpoint_directory(subdir)
            if checkpoint_info:
                checkpoints.append(checkpoint_info)
                print(f"     ✓ Valid checkpoint: Epoch {checkpoint_info.get('epoch', 'N/A')}")

        print(f"   Total valid checkpoints: {len(checkpoints)}")
        return sorted(checkpoints, key=lambda x: (x.get('epoch', 0), x.get('checkpoint_name', '')))

    def _infer_checkpoint_metadata(self, checkpoint_dir: Path) -> Optional[Dict]:
        """Build a minimal checkpoint summary from file presence and the directory name.

        Requires ``config.json`` plus ``pytorch_model.bin`` or ``model.safetensors``.
        The epoch comes from an ``epoch_<n>`` pattern in the name (0 if absent), the
        timestamp from the directory's modification time, and the loss is a 0.0
        placeholder. Returns None if required files are missing or on error.
        """
        try:
            has_model = (
                (checkpoint_dir / 'pytorch_model.bin').exists() or
                (checkpoint_dir / 'model.safetensors').exists()
            )
            has_config = (checkpoint_dir / 'config.json').exists()

            if not (has_model and has_config):
                print(f"     Missing required files in {checkpoint_dir.name}")
                return None

            dir_name = checkpoint_dir.name
            epoch_match = re.search(r'epoch_(\d+)', dir_name)
            epoch = int(epoch_match.group(1)) if epoch_match else 0

            return {
                'checkpoint_name': dir_name,
                'epoch': epoch,
                'timestamp': self._directory_timestamp(checkpoint_dir),
                'is_best': 'best' in dir_name.lower(),
                'metrics': {'loss': 0.0},  # Placeholder
                'path': str(checkpoint_dir),
                'inferred': True,
                'has_training_state': (checkpoint_dir / 'training_state.pt').exists(),
                'has_safetensors': (checkpoint_dir / 'model.safetensors').exists()
            }

        except Exception as e:
            print(f"     Failed to infer metadata: {e}")
            return None

    def cleanup_old_checkpoints(self, keep_last_n: int = 5):
        """Delete all but the ``keep_last_n`` last non-best checkpoints.

        "Last" follows list_checkpoints() order (epoch, then directory name). Any
        checkpoint whose directory name contains "best" is never removed.

        Raises:
            ValueError: If ``keep_last_n`` is negative.
        """
        if keep_last_n < 0:
            raise ValueError(f"keep_last_n must be >= 0, got {keep_last_n}")

        checkpoints = self.list_checkpoints()
        regular_checkpoints = [cp for cp in checkpoints if not cp.get('is_best', False)]

        # Slice by count rather than ``[:-keep_last_n]``: for keep_last_n == 0 that
        # slice is ``[:0]`` and nothing would be removed.
        n_to_remove = max(len(regular_checkpoints) - keep_last_n, 0)

        for checkpoint in regular_checkpoints[:n_to_remove]:
            try:
                shutil.rmtree(checkpoint['path'])
                print(f"✓ Removed old checkpoint: {Path(checkpoint['path']).name}")
            except Exception as e:
                print(f"✗ Failed to remove checkpoint {checkpoint['path']}: {e}")

    @staticmethod
    def _directory_timestamp(directory: Path) -> str:
        """Format ``directory``'s modification time like the checkpoint timestamps."""
        return datetime.fromtimestamp(directory.stat().st_mtime).strftime("%Y%m%d_%H%M%S")

    @staticmethod
    def _is_real_number(value: Any) -> bool:
        """True for int/float values, excluding bool."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def _analyze_checkpoint_directory(self, checkpoint_dir: Path) -> Optional[Dict]:
        """Validate a checkpoint directory and summarize it for list_checkpoints().

        A directory is valid when it has ``config.json`` and weights
        (``model.safetensors`` or ``pytorch_model.bin``). Epoch, loss and timestamp are
        read from ``checkpoint_metadata.json`` when available. Otherwise the epoch
        comes from an ``epoch<n>`` pattern in the name, the timestamp from the
        directory mtime, and the loss from ``training_state.pt``. That file also holds
        the full optimizer state, so it is loaded only when the metadata has no
        numeric loss.

        Returns:
            A summary dict, or None if the directory is not a usable checkpoint.
        """
        try:
            has_safetensors = (checkpoint_dir / 'model.safetensors').exists()
            has_pytorch_model = (checkpoint_dir / 'pytorch_model.bin').exists()
            has_config = (checkpoint_dir / 'config.json').exists()
            has_tokenizer = (
                (checkpoint_dir / 'tokenizer.json').exists() or
                (checkpoint_dir / 'spiece.model').exists() or
                (checkpoint_dir / 'tokenizer_config.json').exists()
            )
            has_training_state = (checkpoint_dir / 'training_state.pt').exists()

            print(f"     File check for {checkpoint_dir.name}:")
            print(f"       model.safetensors: {'✓' if has_safetensors else '✗'}")
            print(f"       pytorch_model.bin: {'✓' if has_pytorch_model else '✗'}")
            print(f"       config.json: {'✓' if has_config else '✗'}")
            print(f"       tokenizer files: {'✓' if has_tokenizer else '✗'}")
            print(f"       training_state.pt: {'✓' if has_training_state else '✗'}")

            # Accept either safetensors OR pytorch_model.bin
            has_model = has_safetensors or has_pytorch_model

            if not (has_model and has_config):
                print(f"     ❌ Missing essential files (model or config)")
                return None

            dir_name = checkpoint_dir.name

            metadata: Dict = {}
            metadata_file = checkpoint_dir / 'checkpoint_metadata.json'
            if metadata_file.exists():
                try:
                    with open(metadata_file, 'r') as f:
                        loaded = json.load(f)
                    if isinstance(loaded, dict):
                        metadata = loaded
                except (OSError, ValueError) as e:
                    print(f"     ⚠ Unreadable checkpoint_metadata.json: {e}")

            epoch = metadata.get('epoch', 0)
            if not isinstance(epoch, int) or isinstance(epoch, bool):
                epoch = 0  # keep the sort key in list_checkpoints() comparable
            if epoch:
                print(f"     📋 Metadata epoch: {epoch}")
            else:
                epoch_match = re.search(r'epoch[_\-]?(\d+)', dir_name, re.I)
                if epoch_match:
                    epoch = int(epoch_match.group(1))
                    print(f"     📁 Directory name epoch: {epoch}")

            metrics = metadata.get('metrics')
            loss = metrics.get('loss') if isinstance(metrics, dict) else None
            if self._is_real_number(loss):
                print(f"     📊 Metadata loss: {loss}")
            else:
                loss = 0.0
                if has_training_state:
                    try:
                        training_state = torch.load(checkpoint_dir / 'training_state.pt', map_location='cpu')
                        loss = training_state.get('metrics', {}).get('loss', 0.0)
                        print(f"     📊 Training state loss: {loss}")
                    except Exception as e:
                        print(f"     ⚠ Could not read training_state.pt: {e}")

            timestamp = metadata.get('timestamp')
            if not isinstance(timestamp, str):
                timestamp = self._directory_timestamp(checkpoint_dir)

            return {
                'checkpoint_name': dir_name,
                'epoch': epoch,
                'timestamp': timestamp,
                'is_best': 'best' in dir_name.lower(),
                'metrics': {'loss': loss},
                'path': str(checkpoint_dir),
                'has_safetensors': has_safetensors,
                'has_pytorch_model': has_pytorch_model,
                'has_training_state': has_training_state,
                'has_tokenizer': has_tokenizer,
                'model_format': 'safetensors' if has_safetensors else 'pytorch_bin'
            }

        except Exception as e:
            print(f"     ❌ Error analyzing checkpoint: {e}")
            return None


# =============================================================================
# ENHANCED COMPLEXITY ANALYZER
# =============================================================================

class EnhancedComplexityAnalyzer:
    """Heuristic, regex-based text-complexity scorer with an in-memory result cache.

    Each text receives seven scores, each clamped to [0.01, 1.0] (see
    _analyze_single_text). Results are memoized per full text for the lifetime of
    the instance; the memo is never evicted.
    """

    def __init__(self, tokenizer: "T5Tokenizer", data_manager: "RobustDataManager"):
        """
        Args:
            tokenizer: Stored as ``self.tokenizer``; the scoring code in this class does
                not use it.
            data_manager: Used to load/save the word-rank table under data_type
                ``'complexity'``.
        """
        self.tokenizer = tokenizer
        self.data_manager = data_manager

        # Compiled once; every pattern except 'words'/'sentences' contributes to the
        # syntactic score (see _calculate_syntactic_complexity).
        self.patterns = {
            'words': re.compile(r'\b\w+\b'),
            'sentences': re.compile(r'[.!?]+'),
            'subordinate_clauses': re.compile(r'\b(?:although|because|since|while|if|when|after|before|until|unless|whereas)\b', re.I),
            'relative_clauses': re.compile(r'\b(?:which|that|who|whom|whose|where)\s+\w+', re.I),
            'passive_voice': re.compile(r'\b(?:was|were|is|are|been)\s+\w+ed\b', re.I),
            'complex_tenses': re.compile(r'\b(?:have|has|had|will|would|could|should|might|must)\s+(?:been|have)\b', re.I),
            'discourse_markers': re.compile(r'\b(?:however|therefore|furthermore|moreover|nevertheless|consequently)\b', re.I)
        }

        # word -> 1-based frequency rank (1 = most frequent)
        self.word_frequencies = self._load_or_create_frequency_data()

        # Morphological complexity indicators.
        # NOTE: 'es' in 'inflectional' is never reached because 's' precedes it and
        # matches first; harmless, since both award the same inflectional bonus.
        self.complex_suffixes = {
            'derivational': ['tion', 'sion', 'ment', 'ness', 'ity', 'ism', 'ance', 'ence'],
            'inflectional': ['ing', 'ed', 'er', 'est', 's', 'es'],
            'academic': ['ological', 'istically', 'ification']
        }

        self.complex_prefixes = {'un', 're', 'pre', 'dis', 'in', 'im', 'non', 'over', 'under', 'mis', 'anti', 'inter', 'multi'}

        # Memo: md5(full text) -> score dict, plus hit/miss counters.
        self._cache = {}
        self._cache_stats = {'hits': 0, 'misses': 0}

        print("Enhanced Complexity Analyzer initialized")

    @staticmethod
    def _is_valid_frequency_table(table: Any) -> bool:
        """True if ``table`` is a non-empty ``Dict[str, int]`` (bool values rejected)."""
        return (
            isinstance(table, dict)
            and len(table) > 0
            and all(
                isinstance(word, str) and isinstance(rank, int) and not isinstance(rank, bool)
                for word, rank in table.items()
            )
        )

    def _load_or_create_frequency_data(self) -> Dict[str, int]:
        """Load the cached word->rank table, or build and cache a small built-in one.

        Ranks are 1-based with 1 = most frequent, matching how
        _calculate_lexical_complexity reads them (larger rank = rarer word). A cached
        object that is not a non-empty ``Dict[str, int]`` is discarded and rebuilt:
        the data manager's lookup may fall back to an unrelated ``complexity_*.pkl``
        file when the exact one is missing.
        """
        cache_key = "word_frequencies"

        frequencies = self.data_manager.load_preprocessed_data(cache_key, 'complexity')

        if not self._is_valid_frequency_table(frequencies):
            if frequencies is not None:
                print("⚠ Ignoring cached 'complexity' data that is not a word->rank table")

            high_freq_words = [
                'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by',
                'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did',
                'will', 'would', 'can', 'could', 'should', 'may', 'might', 'I', 'you', 'he', 'she', 'it', 'we', 'they',
                'this', 'that', 'these', 'those', 'time', 'person', 'year', 'way', 'day', 'thing', 'man', 'world',
                'life', 'hand', 'part', 'child', 'eye', 'woman', 'place', 'work', 'week', 'case', 'point'
            ]

            # Previously stored len(list) - i (higher = MORE frequent), the opposite of
            # what _calculate_lexical_complexity expects.
            frequencies = {}
            for rank, word in enumerate(high_freq_words, start=1):
                frequencies.setdefault(word.lower(), rank)

            self.data_manager.save_preprocessed_data(frequencies, cache_key, 'complexity')

        return frequencies

    def analyze_text_complexity_batch(self, texts: List[str]) -> List[Dict[str, float]]:
        """Return one complexity dict per text, memoized on the full text content.

        Cached results are returned as the same dict objects, so callers should not
        mutate them.
        """
        results = []

        for text in texts:
            # Hash the WHOLE text: keying on text[:500] made any two texts sharing
            # their first 500 characters return each other's scores.
            text_hash = hashlib.md5(text.encode('utf-8', 'surrogatepass')).hexdigest()

            cached = self._cache.get(text_hash)
            if cached is not None:
                self._cache_stats['hits'] += 1
                results.append(cached)
            else:
                self._cache_stats['misses'] += 1
                result = self._analyze_single_text(text)
                self._cache[text_hash] = result
                results.append(result)

        return results

    def _analyze_single_text(self, text: str) -> Dict[str, float]:
        """Score one text on seven dimensions, each clamped to [0.01, 1.0].

        Texts that are empty, under 10 non-whitespace-trimmed characters, or without
        words/sentences get _get_default_scores(). 'overall_complexity' weights
        morphological 0.25, syntactic 0.35, semantic 0.25 and lexical 0.15.
        """
        if not text or len(text.strip()) < 10:
            return self._get_default_scores()

        words = self.patterns['words'].findall(text.lower())
        sentences = [s.strip() for s in self.patterns['sentences'].split(text) if s.strip()]

        if not words or not sentences:
            return self._get_default_scores()

        morphological = self._calculate_morphological_complexity(words)
        syntactic = self._calculate_syntactic_complexity(text, words, sentences)
        semantic = self._calculate_semantic_complexity(words)
        lexical = self._calculate_lexical_complexity(words)

        overall = (
            morphological * 0.25 +
            syntactic * 0.35 +
            semantic * 0.25 +
            lexical * 0.15
        )

        return {
            'overall_complexity': self._normalize_score(overall),
            'morphological_complexity': self._normalize_score(morphological),
            'syntactic_complexity': self._normalize_score(syntactic),
            'semantic_complexity': self._normalize_score(semantic),
            'lexical_complexity': self._normalize_score(lexical),
            # 15 words per sentence maps to 1.0
            'sentence_length_complexity': self._normalize_score(len(words) / len(sentences) / 15),
            'vocabulary_diversity': self._normalize_score(len(set(words)) / len(words))
        }

    def _calculate_morphological_complexity(self, words: List[str]) -> float:
        """Mean per-word morphology score (each word capped at 1.0).

        A word earns points for length (>6 / >8 chars), for one matching suffix per
        suffix category (the inner ``break`` only exits that category, so a word can
        collect several category bonuses) and for at most one prefix.
        """
        if not words:
            return 0.1

        complexity_score = 0.0

        for word in words:
            word_score = 0.0

            if len(word) > 8:
                word_score += 0.3
            elif len(word) > 6:
                word_score += 0.15

            for suffix_type, suffixes in self.complex_suffixes.items():
                for suffix in suffixes:
                    if word.endswith(suffix) and len(word) > len(suffix) + 2:
                        if suffix_type == 'derivational':
                            word_score += 0.4
                        elif suffix_type == 'academic':
                            word_score += 0.6
                        else:
                            word_score += 0.2
                        break

            for prefix in self.complex_prefixes:
                if word.startswith(prefix) and len(word) > len(prefix) + 2:
                    word_score += 0.2
                    break

            complexity_score += min(word_score, 1.0)

        return complexity_score / len(words)

    def _calculate_syntactic_complexity(self, text: str, words: List[str], sentences: List[str]) -> float:
        """Unbounded syntactic score (clamped later by _normalize_score).

        Sums: construction-pattern matches per word (x2), the coefficient of variation
        of sentence lengths (x0.5, only for >1 sentence) and an average-length term
        that starts above 10 words per sentence (x0.3).
        """
        if not words:
            return 0.1

        complexity_score = 0.0

        for pattern_name, pattern in self.patterns.items():
            if pattern_name not in ['words', 'sentences']:
                matches = len(pattern.findall(text))
                complexity_score += matches / len(words) * 2

        if len(sentences) > 1:
            sent_lengths = [len(s.split()) for s in sentences]
            avg_length = np.mean(sent_lengths)
            length_variation = np.std(sent_lengths) / (avg_length + 1)
            complexity_score += length_variation * 0.5

        avg_sent_length = len(words) / len(sentences)
        length_complexity = min((avg_sent_length - 10) / 20, 1.0)
        complexity_score += max(length_complexity, 0) * 0.3

        return complexity_score

    def _calculate_semantic_complexity(self, words: List[str]) -> float:
        """Weighted mix of type/token ratio, out-of-table word density and abstract-noun density."""
        if not words:
            return 0.1

        unique_words = set(words)
        diversity = len(unique_words) / len(words)

        # "Rare" here means absent from self.word_frequencies.
        rare_words = sum(1 for word in words if word not in self.word_frequencies)
        rare_density = rare_words / len(words)

        abstract_indicators = {'concept', 'idea', 'theory', 'principle', 'notion', 'aspect', 'factor', 'element'}
        abstract_count = sum(1 for word in words if word in abstract_indicators)
        abstract_density = abstract_count / len(words)

        semantic_complexity = (
            diversity * 0.4 +
            min(rare_density * 3, 1.0) * 0.4 +
            min(abstract_density * 5, 1.0) * 0.2
        )

        return semantic_complexity

    def _calculate_lexical_complexity(self, words: List[str]) -> float:
        """Mean per-word sophistication from frequency rank and length (each word capped at 1.0).

        Words missing from the rank table score 0.7. The >500 / >1000 rank tiers only
        matter for an externally supplied table; the built-in one has 64 entries.
        """
        if not words:
            return 0.1

        sophistication_score = 0.0

        for word in words:
            word_score = 0.0

            if word in self.word_frequencies:
                freq_rank = self.word_frequencies[word]
                if freq_rank > 1000:  # Low frequency = high complexity
                    word_score += 0.5
                elif freq_rank > 500:
                    word_score += 0.3
            else:
                word_score += 0.7  # Unknown words are complex

            if len(word) > 10:
                word_score += 0.3
            elif len(word) > 7:
                word_score += 0.15

            sophistication_score += min(word_score, 1.0)

        return sophistication_score / len(words)

    def _normalize_score(self, score: float) -> float:
        """Clamp ``score`` to the range [0.01, 1.0]."""
        return max(0.01, min(score, 1.0))

    def _get_default_scores(self) -> Dict[str, float]:
        """Return 0.1 for every dimension (used for empty/too-short/unparseable texts)."""
        return {
            'overall_complexity': 0.1,
            'morphological_complexity': 0.1,
            'syntactic_complexity': 0.1,
            'semantic_complexity': 0.1,
            'lexical_complexity': 0.1,
            'sentence_length_complexity': 0.1,
            'vocabulary_diversity': 0.1
        }

    def get_cache_stats(self) -> Dict:
        """Return memo hits, misses, hit rate (percent) and current number of entries."""
        total = self._cache_stats['hits'] + self._cache_stats['misses']
        hit_rate = (self._cache_stats['hits'] / total * 100) if total > 0 else 0
        return {
            'cache_hits': self._cache_stats['hits'],
            'cache_misses': self._cache_stats['misses'],
            'hit_rate_percent': hit_rate,
            'cache_size': len(self._cache)
        }


# =============================================================================
# IMPROVED TASK CREATOR
# =============================================================================

class ImprovedTaskCreator:
    """Turn raw text into tokenized seq2seq examples for several linguistic domains.

    Per document this builds one T5-style span-corruption example (domain
    ``'general'``) plus at most one morphological, one syntactic and one
    semantic task. Every example is a dict with 1-D ``input_ids``,
    ``attention_mask`` and ``labels`` tensors (padded to ``max_source_length`` /
    ``max_target_length``) and the raw ``source_text``, ``target_text`` and
    ``task_type`` strings.

    Only prompts whose answer can actually be derived from the text are
    emitted, so a prompt is never paired with the answer to another question.

    Args:
        tokenizer: T5 tokenizer used to encode sources and targets.
        max_source_length: Token length inputs are truncated/padded to.
        max_target_length: Token length labels are truncated/padded to.
        label_ignore_index: Written over padding positions in ``labels`` so the
            loss ignores them (Hugging Face seq2seq models ignore -100). Pass
            ``None`` to keep raw pad ids. Labels containing -100 must be mapped
            back to the pad id before decoding them.
        num_sentinels: Number of ``<extra_id_N>`` tokens in the vocabulary
            (100 for stock T5 checkpoints). Must be >= 2. It caps the number of
            corrupted spans.
    """

    # Character budgets; longer strings are rejected by _tokenize_example.
    MAX_SOURCE_CHARS = 2000
    MAX_TARGET_CHARS = 1000
    # Texts longer than this are cut to a random window before span corruption.
    # The target of a span-corruption example is roughly 0.7x the source length,
    # so this keeps both source and target well inside their budgets.
    SPAN_WINDOW_CHARS = 1000

    def __init__(self, tokenizer: 'T5Tokenizer', max_source_length: int = 512, max_target_length: int = 256,
                 label_ignore_index: Optional[int] = -100, num_sentinels: int = 100):
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.label_ignore_index = label_ignore_index
        self.num_sentinels = num_sentinels

        # Precompiled patterns
        self.sentence_pattern = re.compile(r'(?<=[.!?])\s+')
        self.word_pattern = re.compile(r'\b\w+\b')

        # (task name, prompt template) pairs per domain. Only 'inflect'/'analyze',
        # 'complete' and 'continue' are used, because their answers can be
        # computed from the text; the others are kept for reference.
        self.task_templates = {
            'morphological': [
                ('inflect', 'inflect {word} with {suffix}:'),
                ('derive', 'derive word from {word} meaning {meaning}:'),
                ('analyze', 'analyze morphemes in {word}:')
            ],
            'syntactic': [
                ('complete', 'complete sentence: {prefix}'),
                ('transform', 'transform to {structure}: {sentence}'),
                ('parse', 'identify main clause in: {sentence}')
            ],
            'semantic': [
                ('continue', 'continue logically: {context}'),
                ('summarize', 'summarize in one sentence: {text}'),
                ('infer', 'what can be inferred from: {statement}')
            ]
        }

        print("Improved Task Creator initialized")

    def create_domain_specific_tasks(self, text: str, complexity_analysis: Dict[str, float]) -> List[Dict]:
        """Create one span-corruption task plus up to one task per linguistic domain.

        Args:
            text: Raw document. Texts shorter than 100 characters, or with no
                sentence of 5-40 words, yield no tasks.
            complexity_analysis: Scores keyed ``'overall_complexity'`` and
                ``'<domain>_complexity'``. Missing keys default to 0.1.

        Returns:
            Tokenized task dicts, each tagged with ``domain``,
            ``primary_domain`` and ``complexity_score``.
        """
        if len(text) < 100:
            return []

        sentences = self._extract_sentences(text)
        if not sentences:
            return []

        primary_domain = self._determine_primary_domain(complexity_analysis)
        tasks = []

        # Span corruption (T5's pre-training objective) is always attempted.
        span_task = self._create_span_corruption_task(text)
        if span_task:
            span_task.update({
                'domain': 'general',
                'primary_domain': primary_domain,
                'complexity_score': complexity_analysis.get('overall_complexity', 0.1)
            })
            tasks.append(span_task)

        for domain in ('morphological', 'syntactic', 'semantic'):
            task = self._create_domain_task(domain, text, sentences, complexity_analysis)
            if task:
                task.update({
                    'domain': domain,
                    'primary_domain': primary_domain,
                    'complexity_score': complexity_analysis.get(f'{domain}_complexity', 0.1)
                })
                tasks.append(task)

        return tasks

    @staticmethod
    def _is_well_formed(sentence: str) -> bool:
        """True for sentences of 5-40 whitespace-separated words."""
        return 5 <= len(sentence.split()) <= 40

    def _split_sentences(self, text: str) -> List[str]:
        """Split on whitespace following '.', '!' or '?', dropping empty pieces."""
        return [s.strip() for s in self.sentence_pattern.split(text) if s.strip()]

    def _extract_sentences(self, text: str) -> List[str]:
        """Return the well-formed (5-40 word) sentences of ``text`` in order."""
        return [s for s in self._split_sentences(text) if self._is_well_formed(s)]

    def _adjacent_sentence_pairs(self, text: str) -> List[Tuple[str, str]]:
        """Return (sentence, next sentence) pairs that are adjacent in ``text``.

        Both sentences must be well-formed. Adjacency is checked before
        filtering, so a dropped sentence in between breaks the pair.
        """
        raw = self._split_sentences(text)
        return [(first, second) for first, second in zip(raw, raw[1:])
                if self._is_well_formed(first) and self._is_well_formed(second)]

    def _determine_primary_domain(self, complexity_analysis: Dict[str, float]) -> str:
        """Return the domain whose complexity score is highest (ties: first listed)."""
        domain_scores = {
            'morphological': complexity_analysis.get('morphological_complexity', 0.1),
            'syntactic': complexity_analysis.get('syntactic_complexity', 0.1),
            'semantic': complexity_analysis.get('semantic_complexity', 0.1)
        }
        return max(domain_scores, key=domain_scores.get)

    def _fit_word_window(self, words: List[str], max_chars: int) -> List[str]:
        """Return a random contiguous run of ``words`` whose joined length is <= ``max_chars``.

        ``words`` itself is returned when it already fits.
        """
        total_chars = len(' '.join(words))
        if total_chars <= max_chars:
            return words
        # Estimate how many words fit, then trim until the window really fits.
        width = max(1, int(max_chars * len(words) / total_chars))
        start = random.randint(0, len(words) - width)
        window = words[start:start + width]
        while window and len(' '.join(window)) > max_chars:
            window.pop()
        return window

    def _create_span_corruption_task(self, text: str, corruption_rate: float = 0.15) -> Optional[Dict]:
        """Create a T5-style span-corruption example.

        Spans of 1-3 words are each replaced by one sentinel ``<extra_id_i>``
        (numbered left to right) in the source. The target lists each sentinel
        followed by the words it hides, then one closing sentinel.

        Spans never overlap and are separated by at least one kept word, so
        every input word appears exactly once in source or target. Texts over
        ``SPAN_WINDOW_CHARS`` are reduced to a random contiguous window rather
        than being dropped by the length check.

        Args:
            text: Raw text to corrupt.
            corruption_rate: Spans requested per word, before the caps of
                ``len(words) // 4`` and ``num_sentinels - 1``.

        Returns:
            Tokenized example, or ``None`` when fewer than 10 words remain or
            encoding fails.
        """
        try:
            words = self._fit_word_window(text.split(), self.SPAN_WINDOW_CHARS)
            n_words = len(words)
            if n_words < 10:
                return None

            # The closing sentinel uses id == number of spans, so keep it < num_sentinels.
            num_spans = max(1, min(int(n_words * corruption_rate), n_words // 4, self.num_sentinels - 1))

            covered = [False] * n_words
            spans = []
            candidate_starts = list(range(n_words))
            random.shuffle(candidate_starts)
            for start in candidate_starts:
                if len(spans) >= num_spans:
                    break
                end = min(start + random.randint(1, 3), n_words)
                # Reject spans that overlap or touch an already chosen span.
                if any(covered[max(0, start - 1):min(n_words, end + 1)]):
                    continue
                covered[start:end] = [True] * (end - start)
                spans.append((start, end))
            spans.sort()

            source_parts, target_parts = [], []
            cursor = 0
            for sentinel_id, (start, end) in enumerate(spans):
                sentinel = f'<extra_id_{sentinel_id}>'
                source_parts.extend(words[cursor:start])
                source_parts.append(sentinel)
                target_parts.append(sentinel)
                target_parts.extend(words[start:end])
                cursor = end
            source_parts.extend(words[cursor:])
            target_parts.append(f'<extra_id_{len(spans)}>')

            return self._tokenize_example(' '.join(source_parts), ' '.join(target_parts), 'span_corruption')

        except Exception as e:
            print(f"Span corruption failed: {e}")
            return None

    def _create_domain_task(self, domain: str, text: str, sentences: List[str], complexity_analysis: Dict) -> Optional[Dict]:
        """Dispatch to the task builder for ``domain``; unknown domains yield ``None``."""
        if domain == 'morphological':
            return self._create_morphological_task(text, sentences)
        elif domain == 'syntactic':
            return self._create_syntactic_task(text, sentences)
        elif domain == 'semantic':
            return self._create_semantic_task(text, sentences)
        return None

    @staticmethod
    def _apply_suffix(word: str, suffix: str) -> str:
        """Attach a regular English suffix using basic spelling rules.

        Handles e-deletion (make -> making, bake -> baker) and y -> i after a
        consonant (carry -> carried/carries). It also handles -es after
        sibilants (box -> boxes) and ie -> y before -ing (lie -> lying).
        Consonant doubling and irregular forms are not handled.
        """
        vowels = 'aeiou'
        consonant_y = len(word) > 1 and word.endswith('y') and word[-2] not in vowels
        if suffix == 's':
            if word.endswith(('s', 'x', 'z', 'ch', 'sh')):
                return word + 'es'
            if consonant_y:
                return word[:-1] + 'ies'
            return word + 's'
        if suffix == 'ing':
            if word.endswith('ie'):
                return word[:-2] + 'ying'
            if word.endswith('e') and not word.endswith(('ee', 'ye', 'oe')):
                return word[:-1] + 'ing'
            return word + 'ing'
        if suffix.startswith('e'):  # 'ed', 'er', 'est', ...
            if word.endswith('e'):
                return word + suffix[1:]
            if consonant_y:
                return word[:-1] + 'i' + suffix
        return word + suffix

    def _create_morphological_task(self, text: str, sentences: List[str]) -> Optional[Dict]:
        """Create an inflection or morpheme-segmentation task from a word of ``text``.

        A base word (3-8 letters, not already ending in a common suffix) gets
        a regular suffix. With equal probability the task is either
        'inflect <base> with <suffix>:' -> inflected form, or
        'analyze morphemes in <inflected>:' -> '<base> + <suffix>'.
        ``sentences`` is unused; it is kept for interface compatibility.
        """
        words = self.word_pattern.findall(text.lower())
        base_words = [w for w in words if 3 <= len(w) <= 8 and w.isalpha()
                      and not w.endswith(('ing', 'ed', 'er', 'ly', 'est'))]
        if not base_words:
            return None

        base_word = random.choice(base_words)
        suffix = random.choice(['ing', 'ed', 'er', 's'])
        inflected = self._apply_suffix(base_word, suffix)
        templates = dict(self.task_templates['morphological'])

        if random.random() < 0.5:
            task_name = 'inflect'
            source_text = templates['inflect'].format(word=base_word, suffix=suffix)
            target_text = inflected
        else:
            task_name = 'analyze'
            source_text = templates['analyze'].format(word=inflected)
            target_text = f'{base_word} + {suffix}'

        return self._tokenize_example(source_text, target_text, f'morphological_{task_name}')

    def _create_syntactic_task(self, text: str, sentences: List[str]) -> Optional[Dict]:
        """Create a sentence-completion task: a sentence prefix -> the rest of the sentence.

        The split leaves at least 3 words in the prefix and 2 in the target.
        Sentences under 6 words yield ``None``.
        """
        if not sentences:
            return None

        sentence = random.choice(sentences)
        words = sentence.split()
        if len(words) < 6:
            return None

        split_point = random.randint(3, len(words) - 2)
        prefix = ' '.join(words[:split_point])
        completion = ' '.join(words[split_point:])

        # 'transform'/'parse' are not used: their answers are not the completion.
        source_text = dict(self.task_templates['syntactic'])['complete'].format(prefix=prefix)
        return self._tokenize_example(source_text, completion, 'syntactic_complete')

    def _create_semantic_task(self, text: str, sentences: List[str]) -> Optional[Dict]:
        """Create a continuation task: one sentence -> the sentence that follows it in ``text``.

        The pair is drawn at random from sentences that are truly adjacent in
        the text, not merely adjacent after length filtering.
        """
        if len(sentences) < 2:
            return None

        pairs = self._adjacent_sentence_pairs(text)
        if not pairs:
            return None

        context, continuation = random.choice(pairs)
        # 'summarize'/'infer' are not used: the next sentence does not answer them.
        source_text = dict(self.task_templates['semantic'])['continue'].format(context=context)
        return self._tokenize_example(source_text, continuation, 'semantic_continue')

    def _tokenize_example(self, source_text: str, target_text: str, task_type: str) -> Optional[Dict]:
        """Tokenize a (source, target) pair into a padded training example.

        Returns ``None`` for empty strings and for strings over
        ``MAX_SOURCE_CHARS`` / ``MAX_TARGET_CHARS``. Padding in ``labels`` is
        replaced by ``label_ignore_index`` unless that is ``None``.
        """
        try:
            source_text = source_text.strip()
            target_text = target_text.strip()

            if not source_text or not target_text:
                return None

            if len(source_text) > self.MAX_SOURCE_CHARS or len(target_text) > self.MAX_TARGET_CHARS:
                return None

            source_encoding = self.tokenizer(
                source_text,
                max_length=self.max_source_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            target_encoding = self.tokenizer(
                target_text,
                max_length=self.max_target_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            labels = target_encoding['input_ids'].squeeze(0)
            pad_id = getattr(self.tokenizer, 'pad_token_id', None)
            if self.label_ignore_index is not None and pad_id is not None:
                # Without this the model is trained to emit padding.
                labels = labels.masked_fill(labels == pad_id, self.label_ignore_index)

            return {
                'input_ids': source_encoding['input_ids'].squeeze(0),
                'attention_mask': source_encoding['attention_mask'].squeeze(0),
                'labels': labels,
                'source_text': source_text,
                'target_text': target_text,
                'task_type': task_type
            }

        except Exception as e:
            print(f"Tokenization error for {task_type}: {e}")
            return None


# =============================================================================
# IMPROVED DATASET WITH ROBUST CACHING
# =============================================================================

class ImprovedDomainAwareDataset(Dataset):
    """Map-style dataset of tokenized multi-domain tasks with caching and curricula.

    Documents are read from ``data_path`` (split on blank lines), scored by
    ``complexity_analyzer`` and turned into tasks by ``task_creator``. The
    results are cached through ``data_manager`` under a key derived from the
    configuration.

    Attributes:
        examples: Task dicts from the task creator, plus
            ``complexity_analysis`` and ``document_id``.
        complexity_scores: ``complexity_score`` of each example, in the same order.
        domain_indices: Domain name ('morphological', 'syntactic', 'semantic',
            'general') -> list of example indices.
        domain_levels: Domain -> {level: example indices}. Lower levels hold
            lower ``complexity_score`` values.
    """

    # Highest curriculum level per domain (levels run 0..max inclusive).
    DOMAIN_MAX_LEVELS = {'morphological': 2, 'syntactic': 3, 'semantic': 4}

    def __init__(self, data_path: str, tokenizer: 'T5Tokenizer',
                 complexity_analyzer: 'EnhancedComplexityAnalyzer',
                 task_creator: 'ImprovedTaskCreator',
                 data_manager: 'RobustDataManager',
                 max_source_length: int = 512, max_target_length: int = 256,
                 split: str = "train", max_examples: Optional[int] = None):
        """Load the split from cache or build it from ``data_path``.

        Args:
            max_examples: Upper bound on the number of examples built from raw
                text. ``None`` or 0 means no limit.
        """
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.complexity_analyzer = complexity_analyzer
        self.task_creator = task_creator
        self.data_manager = data_manager
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.split = split
        self.max_examples = max_examples

        # Generate cache key based on configuration
        config_params = {
            'data_path': data_path,
            'max_source_length': max_source_length,
            'max_target_length': max_target_length,
            'split': split,
            'max_examples': max_examples
        }

        self.cache_key = self.data_manager.generate_cache_key(data_path, config_params)

        # Initialize data structures
        self.examples = []
        self.complexity_scores = []
        self.domain_levels = {'morphological': {}, 'syntactic': {}, 'semantic': {}}
        self.domain_indices = {'morphological': [], 'syntactic': [], 'semantic': [], 'general': []}

        # Load or process data
        self._load_or_process_data()

        print(f"Dataset initialized: {len(self.examples)} examples")
        self._print_dataset_stats()

    def _load_or_process_data(self):
        """Fill the dataset from cache when possible, otherwise from raw text.

        Lookup order:
          1. the exact cache entry for this configuration;
          2. any entry stored under ``dataset_<split>`` (fallback);
          3. fresh processing, whose result is saved under the exact key.

        A cached payload is only adopted when it is internally consistent and
        its tensors match this dataset's ``max_source_length`` /
        ``max_target_length``. This stops a fallback built with other settings
        from being reused silently.
        """
        base_cache_key = f'dataset_{self.split}'

        cached_data = self.data_manager.load_preprocessed_data(self.cache_key, base_cache_key)
        if cached_data is not None and self._adopt_cached_data(cached_data, "cache"):
            return

        print(f"🔍 No exact cache match, trying alternative patterns...")
        alt_cached_data = self.data_manager.load_preprocessed_data('', base_cache_key)
        if alt_cached_data is not None and self._adopt_cached_data(alt_cached_data, "alternative cache"):
            return

        print(f"📊 No usable cache found, processing {self.split} data from scratch...")
        start_time = time.time()

        success = self._process_raw_data()

        if success and len(self.examples) > 0:
            self._create_domain_curricula()

            cached_data = {
                'examples': self.examples,
                'complexity_scores': self.complexity_scores,
                'domain_levels': self.domain_levels,
                'domain_indices': self.domain_indices
            }
            self.data_manager.save_preprocessed_data(cached_data, self.cache_key, base_cache_key)

            processing_time = time.time() - start_time
            print(f"✅ Processing completed in {processing_time:.2f} seconds")
        else:
            print("❌ Failed to process data or no examples created")

    def _adopt_cached_data(self, cached_data: Dict, source_label: str) -> bool:
        """Install a cached payload if it is usable and report whether it was adopted.

        ``self`` is only modified after every check passes, so a bad payload
        cannot leave half-loaded state behind.
        """
        try:
            examples = list(cached_data['examples'])
            complexity_scores = list(cached_data['complexity_scores'])
            domain_levels = cached_data.get('domain_levels', self.domain_levels)
            domain_indices = cached_data.get('domain_indices', self.domain_indices)
            if not examples:
                raise ValueError("no examples")
            if len(examples) != len(complexity_scores):
                raise ValueError(f"{len(examples)} examples but {len(complexity_scores)} complexity scores")
            src_len = len(examples[0]['input_ids'])
            tgt_len = len(examples[0]['labels'])
            if (src_len, tgt_len) != (self.max_source_length, self.max_target_length):
                raise ValueError(f"tensor lengths {src_len}/{tgt_len} do not match "
                                 f"{self.max_source_length}/{self.max_target_length}")
        except Exception as e:
            print(f"⚠️ Ignoring {source_label}: {e}")
            return False

        self.examples = examples
        self.complexity_scores = complexity_scores
        self.domain_levels = domain_levels
        self.domain_indices = domain_indices
        print(f"✅ Loaded {len(self.examples)} examples from {source_label}")
        return True

    def _process_raw_data(self) -> bool:
        """Build examples from the raw text file.

        The file is split on blank lines, and documents of 150 characters or
        fewer are dropped. When ``max_examples`` is set (truthy), processing
        stops once that many examples exist. A failing document is reported
        and skipped.

        Returns:
            True when at least one example was created.
        """
        try:
            if not os.path.exists(self.data_path):
                print(f"Data file not found: {self.data_path}")
                return False

            with open(self.data_path, 'r', encoding='utf-8') as f:
                text = f.read()

            documents = [doc.strip() for doc in re.split(r'\n\s*\n', text)
                         if len(doc.strip()) > 150]

            # None and 0 both mean "no limit".
            limit = self.max_examples or None

            print(f"Processing {len(documents)} documents...")

            for doc_idx, document in enumerate(tqdm(documents, desc='Processing documents')):
                if limit is not None and len(self.examples) >= limit:
                    break
                try:
                    complexity_analysis = self.complexity_analyzer.analyze_text_complexity_batch([document])[0]
                    tasks = self.task_creator.create_domain_specific_tasks(document, complexity_analysis)

                    for task in tasks:
                        if limit is not None and len(self.examples) >= limit:
                            break
                        if task and self._validate_task(task):
                            self._add_example(task, complexity_analysis, doc_idx)

                except Exception as e:
                    print(f"Error processing document {doc_idx}: {e}")
                    continue

            return len(self.examples) > 0

        except Exception as e:
            print(f"Error in raw data processing: {e}")
            return False

    def _add_example(self, task: Dict, complexity_analysis: Dict, doc_idx: int):
        """Append a validated task and register its index under its domain."""
        task['complexity_analysis'] = complexity_analysis
        task['document_id'] = f"{self.split}_{doc_idx}"

        self.examples.append(task)
        self.complexity_scores.append(task['complexity_score'])

        domain = task.get('domain', 'general')
        bucket = domain if domain in self.domain_indices else 'general'
        self.domain_indices[bucket].append(len(self.examples) - 1)

    def _validate_task(self, task: Dict) -> bool:
        """Check that a task has every field the dataset and training loop read."""
        required_fields = ['input_ids', 'attention_mask', 'labels', 'task_type', 'domain', 'complexity_score']
        return all(field in task for field in required_fields)

    def _create_domain_curricula(self):
        """Split each domain's examples into difficulty levels by ``complexity_score``.

        Examples are sorted by score and partitioned into ``max_level + 1``
        contiguous groups whose sizes differ by at most one. Earlier (easier)
        levels receive the extra items, and every example lands in some level.
        When there are fewer examples than levels, the trailing levels are
        left absent.
        """
        print("Creating domain curricula...")

        for domain, max_level in self.DOMAIN_MAX_LEVELS.items():
            domain_examples = [(i, ex) for i, ex in enumerate(self.examples)
                               if ex.get('domain') == domain]
            if not domain_examples:
                continue

            domain_examples.sort(key=lambda pair: pair[1].get('complexity_score', 0))

            num_levels = max_level + 1
            base_size, extra = divmod(len(domain_examples), num_levels)
            levels = {}
            start = 0
            for level in range(num_levels):
                size = base_size + (1 if level < extra else 0)
                if size == 0:
                    break
                levels[level] = [idx for idx, _ in domain_examples[start:start + size]]
                start += size
            self.domain_levels[domain] = levels

        print("Domain curricula created")

    def _print_dataset_stats(self):
        """Print per-domain example counts and complexity mean/std."""
        print(f"Dataset Statistics:")
        for domain in ['morphological', 'syntactic', 'semantic', 'general']:
            count = len(self.domain_indices.get(domain, []))
            percentage = (count / len(self.examples) * 100) if self.examples else 0
            print(f"  {domain}: {count} ({percentage:.1f}%)")

        if self.complexity_scores:
            print(f"Complexity Statistics:")
            print(f"  Mean: {np.mean(self.complexity_scores):.3f}")
            print(f"  Std: {np.std(self.complexity_scores):.3f}")

    def get_domain_curriculum_data(self, domain: str, level: int) -> List[int]:
        """Indices of ``domain`` examples at curriculum ``level`` (empty if absent)."""
        return self.domain_levels.get(domain, {}).get(level, [])

    def get_domain_indices(self, domain: str) -> List[int]:
        """All example indices registered under ``domain`` (empty if unknown)."""
        return self.domain_indices.get(domain, [])

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


# =============================================================================
# IMPROVED CURRICULUM SAMPLER
# =============================================================================

class ImprovedCurriculumSampler(Sampler):
    """Epoch-dependent sampler implementing a three-phase curriculum.

    By 0-based ``current_epoch`` the phase is:
      - ``foundation`` for the first ``foundation_epochs`` epochs, with
        morphology making up ``foundation_morph_share`` of the sampled
        examples;
      - ``curriculum`` for the next ``curriculum_epochs`` epochs, unlocking
        harder per-domain levels as the phase progresses;
      - ``integration`` afterwards, using every example.

    The index order is drawn once at construction, so a new sampler is
    expected per epoch. ``total_epochs`` and ``integration_epochs`` are stored
    but not used for phase selection.
    """

    # Highest curriculum level per domain (levels run 0..max inclusive).
    DOMAIN_MAX_LEVELS = {'morphological': 2, 'syntactic': 3, 'semantic': 4}

    def __init__(self, dataset: 'ImprovedDomainAwareDataset', current_epoch: int, total_epochs: int,
                 foundation_epochs: int = 3, curriculum_epochs: int = 4, integration_epochs: int = 3,
                 foundation_morph_share: float = 0.6):
        if not 0.0 < foundation_morph_share <= 1.0:
            raise ValueError(f"foundation_morph_share must be in (0, 1], got {foundation_morph_share}")

        self.dataset = dataset
        self.current_epoch = current_epoch
        self.total_epochs = total_epochs

        # Phase configuration
        self.foundation_epochs = foundation_epochs
        self.curriculum_epochs = curriculum_epochs
        self.integration_epochs = integration_epochs
        self.foundation_morph_share = foundation_morph_share

        self.phase = self._get_current_phase()
        self.active_indices = self._get_phase_indices()

        print(f"Curriculum Sampler - Epoch {current_epoch + 1}, Phase: {self.phase}")
        print(f"Active examples: {len(self.active_indices)}")

    def _get_current_phase(self) -> str:
        """Map the current epoch to 'foundation', 'curriculum' or 'integration'."""
        if self.current_epoch < self.foundation_epochs:
            return "foundation"
        elif self.current_epoch < self.foundation_epochs + self.curriculum_epochs:
            return "curriculum"
        else:
            return "integration"

    def _get_phase_indices(self) -> List[int]:
        """Build this epoch's index list for the current phase."""
        if self.phase == "foundation":
            return self._get_foundation_indices()
        elif self.phase == "curriculum":
            return self._get_curriculum_indices()
        else:
            return self._get_integration_indices()

    def _get_foundation_indices(self) -> List[int]:
        """Foundation phase: morphology makes up ``foundation_morph_share`` of the epoch.

        The morphological examples used are capped at that share of the whole
        dataset. Just enough random non-morphological examples are then added
        to reach the configured ratio (e.g. 60/40). Without morphological
        examples every other example is used; with no examples at all, every
        dataset index is used.
        """
        morph_indices = self.dataset.get_domain_indices('morphological')
        other_indices = (self.dataset.get_domain_indices('syntactic') +
                         self.dataset.get_domain_indices('semantic') +
                         self.dataset.get_domain_indices('general'))
        share = self.foundation_morph_share

        if morph_indices:
            morph_cap = max(1, int(len(self.dataset) * share))
            selected = random.sample(morph_indices, min(len(morph_indices), morph_cap))
            n_other = min(len(other_indices), int(round(len(selected) * (1.0 - share) / share)))
            selected += random.sample(other_indices, n_other)
        else:
            selected = list(other_indices)

        if not selected:
            selected = list(range(len(self.dataset)))

        random.shuffle(selected)
        return selected

    def _get_curriculum_indices(self) -> List[int]:
        """Curriculum phase: unlock harder levels per domain as the phase progresses.

        ``progress`` goes from 0 on the first curriculum epoch to 1 on the
        last. Each domain contributes all levels up to
        ``min(int(progress * (max_level + 1)), max_level)``. A random sample of
        general examples (at most a quarter of the dataset) is added. Indices
        are de-duplicated and shuffled.
        """
        curriculum_epoch = self.current_epoch - self.foundation_epochs
        progress = curriculum_epoch / max(self.curriculum_epochs - 1, 1)

        all_indices = []
        for domain, max_level in self.DOMAIN_MAX_LEVELS.items():
            current_level = min(int(progress * (max_level + 1)), max_level)
            for level in range(current_level + 1):
                all_indices.extend(self.dataset.get_domain_curriculum_data(domain, level))

        general_indices = self.dataset.get_domain_indices('general')
        if general_indices:
            sample_size = min(len(general_indices), len(self.dataset) // 4)
            all_indices.extend(random.sample(general_indices, sample_size))

        # Fallback to all examples if curriculum is empty
        if not all_indices:
            all_indices = list(range(len(self.dataset)))

        # De-duplicate *before* shuffling. Building a set afterwards would
        # reorder the ints and undo the shuffle.
        unique_indices = list(dict.fromkeys(all_indices))
        random.shuffle(unique_indices)
        return unique_indices

    def _get_integration_indices(self) -> List[int]:
        """Integration phase: every example, in random order."""
        all_indices = list(range(len(self.dataset)))
        random.shuffle(all_indices)
        return all_indices

    def __iter__(self):
        return iter(self.active_indices)

    def __len__(self):
        return len(self.active_indices)

class TrainingConfig:
    """Hyperparameters for the domain-aware T5 training run.

    The attributes fall into these groups:
      - model and sequence lengths;
      - optimiser schedule;
      - curriculum phase lengths (in epochs);
      - example caps;
      - regularization knobs;
      - target device.
    """

    def __init__(self):
        # Model and sequence lengths
        self.model_name = 't5-small'
        self.max_source_length, self.max_target_length = 512, 256

        # Optimisation schedule
        self.num_epochs, self.batch_size = 10, 16
        self.learning_rate, self.weight_decay = 5e-5, 0.01
        self.warmup_ratio, self.max_grad_norm = 0.1, 1.0
        self.patience = 4

        # Curriculum phase lengths in epochs: foundation -> curriculum -> integration
        self.foundation_epochs, self.curriculum_epochs, self.integration_epochs = 3, 4, 3

        # Example caps per split (None = uncapped)
        self.max_train_examples, self.max_val_examples = None, 20000

        # Regularization
        self.lambda_accel = 1e-4                            # weight of the acceleration (smoothness) penalty
        self.regularization_strategy = 'first_middle_last'  # which layers that penalty targets
        self.temporal_alpha_init = 0.1                      # starting value of the temporal-difference weight
        self.enable_temporal_embedding = True
        self.temporal_alpha_range = (0.01, 0.5)             # clamp bounds for that weight

        # Hardware
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        print("Enhanced training configuration initialized with regularization")


# =============================================================================
# ENHANCED REGULARIZATION LOSSES
# =============================================================================

class TargetedSmoothnessLoss(torch.nn.Module):
    """
    Penalise the discrete "acceleration" of hidden states along the sequence
    axis for selected transformer layers of an encoder/decoder.

    For a layer output ``h`` (indexed ``[batch, position, feature]``) the layer
    penalty is the mean over position triples of
    ``||h[t+2] - 2*h[t+1] + h[t]||^2``. It is averaged over the selected
    layers. Encoder and decoder penalties are summed and scaled by
    ``lambda_coeff``.

    ``hidden_states[0]`` is treated as the embedding output and is never
    penalised. Optional attention masks restrict the mean to triples whose
    three positions are all unmasked, so padding does not dominate the penalty.
    """
    def __init__(self, lambda_coeff: float, strategy: str = 'first_middle_last'):
        super().__init__()
        self.lambda_coeff = lambda_coeff
        self.strategy = strategy

        print(f"Initialized TargetedSmoothnessLoss with lambda={lambda_coeff}, strategy='{strategy}'")

    def _get_layer_indices(self, total_layers: int) -> list:
        """Return ``hidden_states`` indices (1..total_layers) to penalise under ``self.strategy``.

        Unknown strategies fall back to the first and last transformer layers.
        """
        if total_layers <= 0:
            return []
        if total_layers < 3:
            # Too few layers for first/middle/last: use the middle transformer
            # layer, never index 0 (the embedding output).
            return [(total_layers + 1) // 2]

        if self.strategy == 'first_middle_last':
            return [1, (total_layers // 2) + 1, total_layers]
        elif self.strategy == 'middle_only':
            return [total_layers // 2]
        elif self.strategy == 'all_layers':
            return list(range(1, total_layers + 1))
        else:
            return [1, total_layers]

    def _calculate_penalty(self, hidden_states: tuple,
                           attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Average second-order finite-difference penalty over the selected layers.

        Args:
            hidden_states: Tuple whose element 0 is the embedding output and
                whose later elements are layer outputs indexed
                ``[batch, position, feature]``.
            attention_mask: Optional ``[batch, position]`` mask (1 = real
                token). If omitted, every position counts.

        Returns:
            Scalar tensor. It is 0 when no selected layer has more than two
            positions.
        """
        if not hidden_states:
            return torch.tensor(0.0, device=torch.device('cpu'))
        device = hidden_states[0].device
        if len(hidden_states) < 2:
            return torch.tensor(0.0, device=device)

        num_transformer_layers = len(hidden_states) - 1  # exclude the embedding output
        layer_penalties = []

        for layer_idx in sorted(set(self._get_layer_indices(num_transformer_layers))):
            if not 0 < layer_idx < len(hidden_states):
                continue
            layer_states = hidden_states[layer_idx]
            # Need at least 3 positions for a second-order difference.
            if layer_states.shape[1] <= 2:
                continue

            acceleration = layer_states[:, 2:, :] - 2 * layer_states[:, 1:-1, :] + layer_states[:, :-2, :]
            # Squared L2 norm per triple; avoids the sqrt-then-square of norm()**2.
            sq_norm = acceleration.pow(2).sum(dim=-1)

            if attention_mask is None:
                layer_penalties.append(sq_norm.mean())
            else:
                mask = attention_mask.to(device=sq_norm.device, dtype=sq_norm.dtype)
                valid = mask[:, :-2] * mask[:, 1:-1] * mask[:, 2:]
                layer_penalties.append((sq_norm * valid).sum() / valid.sum().clamp_min(1.0))

        if not layer_penalties:
            return torch.tensor(0.0, device=device)
        return torch.stack(layer_penalties).mean()

    def forward(self, outputs, original_loss: torch.Tensor,
                encoder_attention_mask: Optional[torch.Tensor] = None,
                decoder_attention_mask: Optional[torch.Tensor] = None) -> dict:
        """Add ``lambda_coeff * (encoder_penalty + decoder_penalty)`` to ``original_loss``.

        Hidden states are read from ``outputs.encoder_hidden_states`` /
        ``outputs.decoder_hidden_states`` when those attributes are present and
        non-empty. The optional masks are forwarded to the matching penalty.

        Returns:
            Dict with ``total_loss``, ``original_loss``, ``smoothness_penalty``,
            ``encoder_penalty`` and ``decoder_penalty``. When
            ``lambda_coeff <= 0`` the penalties are zero and no hidden states
            are read.
        """
        zero = torch.tensor(0.0, device=original_loss.device)
        if self.lambda_coeff <= 0:
            return {
                'total_loss': original_loss,
                'original_loss': original_loss,
                'smoothness_penalty': zero,
                'encoder_penalty': torch.tensor(0.0, device=original_loss.device),
                'decoder_penalty': torch.tensor(0.0, device=original_loss.device)
            }

        encoder_states = getattr(outputs, 'encoder_hidden_states', None)
        decoder_states = getattr(outputs, 'decoder_hidden_states', None)

        encoder_penalty = self._calculate_penalty(encoder_states, encoder_attention_mask) if encoder_states else zero
        decoder_penalty = (self._calculate_penalty(decoder_states, decoder_attention_mask)
                           if decoder_states else torch.tensor(0.0, device=original_loss.device))

        total_smoothness_penalty = encoder_penalty + decoder_penalty
        total_loss = original_loss + (self.lambda_coeff * total_smoothness_penalty)

        return {
            'total_loss': total_loss,
            'original_loss': original_loss,
            'smoothness_penalty': total_smoothness_penalty,
            'encoder_penalty': encoder_penalty,
            'decoder_penalty': decoder_penalty
        }


class TemporalEmbeddingEnhancement(torch.nn.Module):
    """
    WARNING: This implementation has significant architectural concerns.

    Experimental: mixes a learnable multiple of the previous-position
    difference into the first hidden state (index 0) of the encoder/decoder
    hidden-state tuples:

        e'[t] = e[t] + alpha * (e[t] - e[t-1])   for t >= 1,   e'[0] = e[0]

    ``alpha`` is a learnable scalar clamped to ``alpha_range`` on every forward
    call. This module only returns the enhanced tensors; it does not feed them
    back into any model.
    """
    def __init__(self, alpha_init: float = 0.1, alpha_range: tuple = (0.01, 0.5)):
        super().__init__()
        self.alpha = torch.nn.Parameter(torch.tensor(alpha_init))
        self.alpha_range = alpha_range

        print(f"WARNING: TemporalEmbeddingEnhancement initialized with alpha={alpha_init}")
        print(f"This approach has known architectural inconsistencies with T5.")

    def _clamp_alpha(self):
        """Project ``alpha`` in place onto ``[alpha_range[0], alpha_range[1]]`` without tracking gradients."""
        lower_bound = self.alpha_range[0]
        upper_bound = self.alpha_range[1]
        with torch.no_grad():
            self.alpha.clamp_(lower_bound, upper_bound)

    def forward(self, encoder_hidden_states: tuple, decoder_hidden_states: tuple = None) -> dict:
        """
        Clamp ``alpha``, then enhance element 0 of each non-empty hidden-state tuple.

        Returns:
            Dict containing:
              - ``enhanced_encoder`` / ``enhanced_decoder``: the enhanced
                tensors, or ``None`` when the corresponding tuple is
                missing/empty;
              - ``temporal_penalty``: ``0.01 * alpha**2``;
              - ``alpha_value``: ``alpha`` as a Python float.
        """
        self._clamp_alpha()

        encoder_out = (self._enhance_embeddings(encoder_hidden_states[0])
                       if encoder_hidden_states and len(encoder_hidden_states) > 0 else None)
        decoder_out = (self._enhance_embeddings(decoder_hidden_states[0])
                       if decoder_hidden_states and len(decoder_hidden_states) > 0 else None)

        # Small L2 term on alpha itself.
        alpha_penalty = 0.01 * (self.alpha ** 2)

        result = dict(enhanced_encoder=encoder_out, enhanced_decoder=decoder_out)
        result['temporal_penalty'] = alpha_penalty
        result['alpha_value'] = self.alpha.item()
        return result

    def _enhance_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Add ``alpha`` times the difference to the previous position at every position t >= 1.

        ``embeddings`` is indexed ``[batch, position, feature]`` (dim 1 is the
        sequence axis). Inputs with fewer than two positions are returned
        as-is (the same object).
        """
        if embeddings.shape[1] < 2:
            return embeddings

        head = embeddings[:, :1, :]
        tail = embeddings[:, 1:, :]
        step = tail - embeddings[:, :-1, :]
        return torch.cat((head, tail + (self.alpha * step)), dim=1)


class CombinedRegularizationLoss(torch.nn.Module):
    """
    Single loss module wrapping the hidden-state smoothness penalty and the
    optional temporal embedding enhancement.

    Config attributes are read with ``getattr`` and all have defaults:
      lambda_accel (0.0), regularization_strategy ('first_middle_last'),
      enable_temporal_embedding (False), temporal_alpha_init (0.1),
      temporal_alpha_range ((0.01, 0.5)).
    """
    def __init__(self, config):
        super().__init__()

        lambda_coeff = getattr(config, 'lambda_accel', 0.0)
        strategy = getattr(config, 'regularization_strategy', 'first_middle_last')
        self.smoothness_loss = TargetedSmoothnessLoss(lambda_coeff=lambda_coeff, strategy=strategy)

        # The temporal module is only built (and registered) when explicitly enabled
        use_temporal = getattr(config, 'enable_temporal_embedding', False)
        self.temporal_enhancement = (
            TemporalEmbeddingEnhancement(
                alpha_init=getattr(config, 'temporal_alpha_init', 0.1),
                alpha_range=getattr(config, 'temporal_alpha_range', (0.01, 0.5))
            )
            if use_temporal else None
        )

        smoothing_state = 'Enabled' if self.smoothness_loss.lambda_coeff > 0 else 'Disabled'
        temporal_state = 'Enabled' if self.temporal_enhancement else 'Disabled'
        print(f"Combined regularization initialized:")
        print(f"  Smoothness penalty: {smoothing_state}")
        print(f"  Temporal enhancement: {temporal_state}")

    def forward(self, outputs, original_loss: torch.Tensor) -> dict:
        """
        Add the regularization terms to ``original_loss``.

        outputs: model output object; ``encoder_hidden_states`` /
            ``decoder_hidden_states`` are read with ``getattr`` (None if absent).
        original_loss: the task loss.

        Returns a dict with keys total_loss, original_loss, smoothness_penalty,
        encoder_penalty, decoder_penalty, temporal_penalty and alpha_value.
        When the temporal module is disabled, temporal_penalty is a zero tensor
        on ``original_loss.device`` and alpha_value is 0.0.
        """
        smooth = self.smoothness_loss(outputs, original_loss)
        total = smooth['total_loss']

        if self.temporal_enhancement:
            temporal = self.temporal_enhancement(
                encoder_hidden_states=getattr(outputs, 'encoder_hidden_states', None),
                decoder_hidden_states=getattr(outputs, 'decoder_hidden_states', None)
            )
            total = total + temporal['temporal_penalty']
        else:
            temporal = {
                'temporal_penalty': torch.tensor(0.0, device=original_loss.device),
                'alpha_value': 0.0
            }

        result = {'total_loss': total, 'original_loss': original_loss}
        for key in ('smoothness_penalty', 'encoder_penalty', 'decoder_penalty'):
            result[key] = smooth[key]
        for key in ('temporal_penalty', 'alpha_value'):
            result[key] = temporal[key]
        return result

# =============================================================================
# IMPROVED TRAINER WITH BETTER ERROR HANDLING
# =============================================================================

class ImprovedT5Trainer:
    """Improved trainer with robust error handling and better save/load.

    Responsibilities:
      * AdamW optimisation; weight decay is disabled for biases and
        normalisation weights.
      * Linear warmup/decay LR schedule, sized in optimizer steps.
      * Gradient accumulation so that roughly ``config.effective_batch_size``
        (default 32) examples contribute to each optimizer step.
      * The combined smoothness / temporal regularisation of
        ``CombinedRegularizationLoss``.
      * Curriculum sampling, validation-based early stopping and
        checkpointing through ``data_manager``.
    """

    # Substrings of parameter names that are excluded from weight decay.
    # T5 names its normalisation weights ``layer_norm.weight`` and
    # ``final_layer_norm.weight``. The BERT-style ``LayerNorm.weight`` alone
    # never matched them, so they were being decayed.
    _NO_DECAY = ("bias", "LayerNorm.weight", "layer_norm.weight")

    # Examples per optimizer step when config.effective_batch_size is not set
    # (the value that was previously hard-coded).
    _DEFAULT_EFFECTIVE_BATCH_SIZE = 32

    def __init__(self, model: T5ForConditionalGeneration, tokenizer: T5Tokenizer,
                data_manager: RobustDataManager, config):
        """
        Args:
            model: T5 model to train (moved to ``config.device`` in ``train``).
            tokenizer: tokenizer whose ``pad_token_id`` is masked out of labels.
            data_manager: provides paths and checkpoint persistence.
            config: training configuration. Here it reads ``weight_decay``,
                ``learning_rate`` and ``device``; other methods read more.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.data_manager = data_manager
        self.config = config

        # Split parameters into decayed / non-decayed groups
        decay_params, no_decay_params = [], []
        for name, param in model.named_parameters():
            if any(nd in name for nd in self._NO_DECAY):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": config.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        self.optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate)
        self.scheduler = None
        self.global_step = 0

        # Combined regularization (smoothness penalty + optional temporal enhancement)
        self.regularization_loss = CombinedRegularizationLoss(config).to(config.device)

        # The temporal module owns learnable parameters that must be optimized too
        if self.regularization_loss.temporal_enhancement:
            temporal_params = list(self.regularization_loss.temporal_enhancement.parameters())
            if temporal_params:
                self.optimizer.add_param_group({
                    'params': temporal_params,
                    'lr': config.learning_rate * 0.1,  # Lower learning rate for alpha
                    'weight_decay': 0.0
                })
                print(f"Added {len(temporal_params)} temporal parameters to optimizer")

        self.training_stats = {
            'epoch_losses': [],
            'val_losses': [],
            'learning_rates': [],
            'phase_transitions': [],
            'regularization_stats': []  # one summary dict per trained epoch
        }

        print("Enhanced T5 Trainer initialized with combined regularization")

    def _accumulation_steps(self) -> int:
        """Return how many batches are accumulated per optimizer step (at least 1)."""
        effective = getattr(self.config, 'effective_batch_size', self._DEFAULT_EFFECTIVE_BATCH_SIZE)
        return max(1, effective // self.config.batch_size)

    def train(self, train_dataset: ImprovedDomainAwareDataset,
              val_dataset: Optional[ImprovedDomainAwareDataset] = None,
              resume_epoch: int = 0):
        """Run the epoch loop with curriculum sampling, validation and checkpointing.

        Args:
            train_dataset: training examples. Each epoch wraps them in an
                ``ImprovedCurriculumSampler``.
            val_dataset: optional validation set. When truthy, it drives
                best-model checkpointing and early stopping after
                ``config.patience`` epochs without improvement.
            resume_epoch: zero-based epoch index to start from.

        Returns:
            ``self.training_stats``.
        """
        print(f"Starting training with {len(train_dataset)} examples")

        # The scheduler advances once per optimizer step, i.e. once every
        # `accumulation_steps` batches (plus one for a trailing partial window).
        # Sizing it in batches stretched warmup and decay by that factor.
        accumulation_steps = self._accumulation_steps()
        batches_per_epoch = max(1, len(train_dataset) // self.config.batch_size)
        steps_per_epoch = max(1, (batches_per_epoch + accumulation_steps - 1) // accumulation_steps)
        total_steps = steps_per_epoch * self.config.num_epochs
        warmup_steps = int(total_steps * self.config.warmup_ratio)

        # NOTE(review): on resume the schedule restarts at step 0 (warmup repeats)
        # and best_val_loss restarts at +inf. Restore both from the saved
        # training state if exact resumption is required.
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
        )

        self.model.to(self.config.device)
        # Accept 'cuda', 'cuda:0', torch.device('cuda', 1), ...
        use_pinned_memory = str(self.config.device).startswith('cuda')

        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(resume_epoch, self.config.num_epochs):
            print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")

            sampler = ImprovedCurriculumSampler(train_dataset, epoch, self.config.num_epochs)
            if len(sampler) == 0:
                print("Warning: No examples in sampler, using all dataset")
                sampler.active_indices = list(range(len(train_dataset)))

            train_loader = DataLoader(
                train_dataset,
                batch_size=self.config.batch_size,
                sampler=sampler,
                collate_fn=self._collate_fn,
                num_workers=0,
                pin_memory=use_pinned_memory
            )

            epoch_loss = self._train_epoch(train_loader)

            val_loss = None
            stop_early = False
            if val_dataset:
                val_loss = self._evaluate(val_dataset)

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience_counter = 0
                    self._save_checkpoint(epoch, val_loss, is_best=True)
                else:
                    patience_counter += 1

                stop_early = patience_counter >= self.config.patience

            # Record stats *before* a possible early-stop break so the last
            # trained epoch is not dropped from the history.
            self.training_stats['epoch_losses'].append(epoch_loss)
            self.training_stats['val_losses'].append(val_loss or 0)
            self.training_stats['learning_rates'].append(self.scheduler.get_last_lr()[0])

            print(f"Train Loss: {epoch_loss:.4f}")
            if val_loss is not None:
                print(f"Val Loss: {val_loss:.4f}, Best: {best_val_loss:.4f}")

            if stop_early:
                print(f"Early stopping at epoch {epoch + 1}")
                break

            # Regular checkpoint every second epoch
            if (epoch + 1) % 2 == 0:
                self._save_checkpoint(epoch, val_loss or epoch_loss)

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        self._save_final_model()
        return self.training_stats

    def _optimizer_step(self):
        """Clip gradients, step optimizer and scheduler, reset gradients, count the step."""
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

        # Regularization parameters are clipped separately from the model
        if self.regularization_loss.temporal_enhancement:
            torch.nn.utils.clip_grad_norm_(
                self.regularization_loss.temporal_enhancement.parameters(),
                self.config.max_grad_norm
            )

        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
        self.global_step += 1

    def _train_epoch(self, train_loader) -> float:
        """Train one epoch with combined regularization and gradient accumulation.

        A batch is skipped, and never back-propagated, if it raises, has no
        supervised target tokens, or produces a non-finite loss.

        Returns:
            Mean regularized loss (task loss + penalties) over the batches that
            were trained on, or 0.0 if there were none.
        """
        self.model.train()
        epoch_loss = 0.0
        num_batches = 0

        epoch_reg_stats = {
            'smoothness_penalty': 0.0,
            'temporal_penalty': 0.0,
            'alpha_values': []
        }

        accumulation_steps = self._accumulation_steps()
        # Micro-batches back-propagated since the last optimizer step
        pending_micro_batches = 0

        # Never let gradients from a previous epoch (or a failed step) leak in
        self.optimizer.zero_grad()

        for batch_idx, batch in enumerate(tqdm(train_loader, desc='Enhanced Training')):
            try:
                batch = {k: v.to(self.config.device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}

                # If every label is -100 (e.g. the collate fallback's dummy batch),
                # cross-entropy averages over zero tokens and returns NaN. That
                # would poison the weights on the next optimizer step.
                if not (batch['labels'] != -100).any():
                    print(f"Skipping batch {batch_idx}: no supervised target tokens")
                    continue

                outputs = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels'],
                    output_hidden_states=True  # hidden states feed the regularizers
                )

                loss_results = self.regularization_loss(outputs, outputs.loss)
                total_loss = loss_results['total_loss']
                if not torch.isfinite(total_loss).all():
                    print(f"Skipping batch {batch_idx}: non-finite loss")
                    continue

                (total_loss / accumulation_steps).backward()
                pending_micro_batches += 1

                epoch_reg_stats['smoothness_penalty'] += loss_results['smoothness_penalty'].item()
                epoch_reg_stats['temporal_penalty'] += loss_results['temporal_penalty'].item()
                if loss_results['alpha_value'] > 0:
                    epoch_reg_stats['alpha_values'].append(loss_results['alpha_value'])

                if pending_micro_batches == accumulation_steps:
                    self._optimizer_step()
                    pending_micro_batches = 0

                epoch_loss += total_loss.item()
                num_batches += 1

                if batch_idx % 100 == 0 and batch_idx > 0:
                    avg_smooth = epoch_reg_stats['smoothness_penalty'] / max(num_batches, 1)
                    avg_temporal = epoch_reg_stats['temporal_penalty'] / max(num_batches, 1)
                    current_alpha = epoch_reg_stats['alpha_values'][-1] if epoch_reg_stats['alpha_values'] else 0.0

                    print(f"Batch {batch_idx}: Smoothness={avg_smooth:.6f}, Temporal={avg_temporal:.6f}, Alpha={current_alpha:.4f}")

            except Exception as e:
                print(f"Error in enhanced training batch {batch_idx}: {e}")
                continue

        # Apply the trailing partial accumulation window instead of carrying its
        # gradients into the next epoch. Those gradients are still scaled by
        # 1/accumulation_steps, so this last step is proportionally smaller.
        if pending_micro_batches:
            try:
                self._optimizer_step()
            except Exception as e:
                print(f"Error applying final accumulated gradients: {e}")

        final_reg_stats = {
            'avg_smoothness_penalty': epoch_reg_stats['smoothness_penalty'] / max(num_batches, 1),
            'avg_temporal_penalty': epoch_reg_stats['temporal_penalty'] / max(num_batches, 1),
            'final_alpha': epoch_reg_stats['alpha_values'][-1] if epoch_reg_stats['alpha_values'] else 0.0,
            'alpha_std': np.std(epoch_reg_stats['alpha_values']) if len(epoch_reg_stats['alpha_values']) > 1 else 0.0
        }
        self.training_stats['regularization_stats'].append(final_reg_stats)

        return epoch_loss / max(num_batches, 1)

    def _evaluate(self, val_dataset) -> float:
        """Return the mean task loss on a fixed subset of at most 1000 validation examples.

        The subset is drawn with a dedicated RNG seeded by ``config.seed``
        (default 42). The same examples are therefore scored every epoch, which
        keeps best-model and early-stopping comparisons meaningful, and the
        global RNG used by training is not consumed.

        Returns ``float('inf')`` if no batch produced a finite loss, so that a
        fully failed validation can never be mistaken for a perfect one.
        """
        self.model.eval()

        val_size = min(1000, len(val_dataset))
        subset_rng = random.Random(getattr(self.config, 'seed', 42))
        val_indices = subset_rng.sample(range(len(val_dataset)), val_size)

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.config.batch_size,
            sampler=val_indices,
            collate_fn=self._collate_fn,
            num_workers=0
        )

        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validation'):
                try:
                    batch = {k: v.to(self.config.device) if isinstance(v, torch.Tensor) else v
                             for k, v in batch.items()}

                    # All-ignored labels would give a NaN loss
                    if not (batch['labels'] != -100).any():
                        continue

                    # Hidden states are not needed for the plain task loss
                    outputs = self.model(
                        input_ids=batch['input_ids'],
                        attention_mask=batch['attention_mask'],
                        labels=batch['labels']
                    )

                    batch_loss = outputs.loss.item()
                    if not np.isfinite(batch_loss):
                        continue
                    total_loss += batch_loss
                    num_batches += 1

                except Exception as e:
                    print(f"Error in validation batch: {e}")
                    continue

        if num_batches == 0:
            print("Warning: no valid validation batches; reporting infinite loss")
            return float('inf')
        return total_loss / num_batches

    def _collate_fn(self, batch):
        """Stack dataset items into a batch, masking pad tokens in labels with -100.

        ``None`` items are dropped. If nothing remains, or stacking fails, a
        dummy batch whose labels are all -100 is returned. The training and
        evaluation loops skip such batches.
        """
        try:
            batch = [item for item in batch if item is not None]

            if not batch:
                return self._create_dummy_batch(1)

            input_ids = torch.stack([item['input_ids'] for item in batch])
            attention_mask = torch.stack([item['attention_mask'] for item in batch])
            labels = torch.stack([item['labels'] for item in batch])

            # -100 is the index ignored by the model's loss
            labels[labels == self.tokenizer.pad_token_id] = -100

            return {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': labels
            }

        except Exception as e:
            print(f"Collate function error: {e}")
            return self._create_dummy_batch(len(batch))

    def _create_dummy_batch(self, batch_size: int):
        """Build a placeholder batch (zeros input, all labels -100) for error recovery."""
        return {
            'input_ids': torch.zeros(batch_size, self.config.max_source_length, dtype=torch.long),
            'attention_mask': torch.ones(batch_size, self.config.max_source_length, dtype=torch.long),
            'labels': torch.full((batch_size, self.config.max_target_length), -100, dtype=torch.long)
        }

    def _save_checkpoint(self, epoch: int, loss: float, is_best: bool = False):
        """Save a checkpoint through ``data_manager``, plus ``regularization_state.pt``.

        Args:
            epoch: zero-based epoch index (stored as ``epoch + 1``).
            loss: loss value recorded in the checkpoint metrics.
            is_best: whether this is the best-validation checkpoint.

        Failures are reported with ``print`` and never abort training.
        """
        metrics = {
            'epoch': epoch + 1,
            'loss': loss,
            'global_step': self.global_step,
            'regularization_stats': self.training_stats['regularization_stats'][-1] if self.training_stats['regularization_stats'] else {}
        }

        temporal = self.regularization_loss.temporal_enhancement
        regularization_state = {
            'smoothness_lambda': self.regularization_loss.smoothness_loss.lambda_coeff,
            'alpha_value': getattr(temporal, 'alpha', torch.tensor(0.0)).item() if temporal else 0.0
        }

        success = self.data_manager.save_model_checkpoint(
            self.model, self.tokenizer, self.optimizer, self.scheduler,
            epoch, metrics, is_best
        )

        if not success:
            print(f"Failed to save enhanced checkpoint for epoch {epoch + 1}")
            return

        # NOTE(review): assumes data_manager names its checkpoint directories
        # 'best_model' / 'epoch_<n>'. Confirm against save_model_checkpoint.
        checkpoint_dir = self.data_manager.get_path('checkpoints') / ('best_model' if is_best else f'epoch_{epoch+1}')
        try:
            torch.save(regularization_state, checkpoint_dir / 'regularization_state.pt')
        except Exception as e:
            # A missing or unwritable directory must not crash training
            print(f"Failed to save regularization state to {checkpoint_dir}: {e}")

    def _save_final_model(self):
        """Save model, tokenizer and ``training_stats.json`` under ``final_models/final_model``.

        Any error, including failing to create the directory, is reported
        with ``print`` rather than raised.
        """
        final_dir = self.data_manager.get_path('final_models') / 'final_model'

        try:
            # parents=True: the final_models directory itself may not exist yet
            final_dir.mkdir(parents=True, exist_ok=True)

            self.model.save_pretrained(final_dir)
            self.tokenizer.save_pretrained(final_dir)

            # default=float converts numpy/torch scalars that json cannot encode natively
            with open(final_dir / 'training_stats.json', 'w') as f:
                json.dump(self.training_stats, f, indent=2, default=float)

            print(f"Final model saved to {final_dir}")

        except Exception as e:
            print(f"Failed to save final model: {e}")



# =============================================================================
# CHECKPOINT LOADING FUNCTIONALITY
# =============================================================================

def load_existing_checkpoint(data_manager: RobustDataManager, checkpoint_name: Optional[str] = None):
    """Select a checkpoint known to ``data_manager`` and load it.

    Args:
        data_manager: object whose ``list_checkpoints()`` returns a list of
            dicts with at least ``checkpoint_name`` and ``path``. ``epoch`` and
            ``is_best`` are read when present.
        checkpoint_name: checkpoint to load. An exact name match wins.
            Otherwise the first checkpoint whose name contains this string is
            used. When omitted, the selection priority is:
            1) the ``epoch_*`` checkpoint with the highest epoch,
            2) the first best checkpoint,
            3) the checkpoint with the highest epoch.

    Returns:
        The result of ``_load_checkpoint_files``, or ``None`` if no suitable
        checkpoint exists or loading fails.
    """
    checkpoints = data_manager.list_checkpoints()

    if not checkpoints:
        print("❌ No checkpoints found")
        return None

    if checkpoint_name:
        # Exact match first: a bare substring test lets 'epoch_1' select 'epoch_10'.
        target_checkpoint = next(
            (cp for cp in checkpoints if cp['checkpoint_name'] == checkpoint_name), None
        )
        if target_checkpoint is None:
            target_checkpoint = next(
                (cp for cp in checkpoints if checkpoint_name in cp['checkpoint_name']), None
            )

        if target_checkpoint is None:
            print(f"❌ Checkpoint '{checkpoint_name}' not found")
            print("Available checkpoints:")
            for cp in checkpoints:
                print(f"  - {cp['checkpoint_name']} (Epoch {cp.get('epoch', '?')})")
            return None
    else:
        epoch_checkpoints = [cp for cp in checkpoints if cp['checkpoint_name'].startswith('epoch_')]
        best_checkpoints = [cp for cp in checkpoints if cp.get('is_best', False)]

        if epoch_checkpoints:
            target_checkpoint = max(epoch_checkpoints, key=lambda x: x.get('epoch', 0))
            print(f"🏃 Using latest epoch checkpoint: {target_checkpoint['checkpoint_name']} (Epoch {target_checkpoint.get('epoch', '?')})")
        elif best_checkpoints:
            target_checkpoint = best_checkpoints[0]
            print(f"🏆 Using best checkpoint: {target_checkpoint['checkpoint_name']} (Epoch {target_checkpoint.get('epoch', '?')})")
        else:
            target_checkpoint = max(checkpoints, key=lambda x: x.get('epoch', 0))
            print(f"📦 Using available checkpoint: {target_checkpoint['checkpoint_name']} (Epoch {target_checkpoint.get('epoch', '?')})")

    return _load_checkpoint_files(target_checkpoint)


def _load_checkpoint_files(checkpoint_info: Dict) -> Optional[Dict[str, Any]]:
    """Load model, tokenizer and (optionally) training state for one checkpoint.

    Args:
        checkpoint_info: checkpoint descriptor dict. ``path`` and
            ``checkpoint_name`` are required. The flags ``has_safetensors``,
            ``has_pytorch_model`` and ``has_training_state`` default to False.

    Returns:
        A dict with keys ``model``, ``tokenizer``, ``training_state`` and
        ``checkpoint_info`` on success. ``training_state`` is None when it is
        absent or fails to load. Returns ``None`` if the model or tokenizer
        cannot be loaded, or if any unexpected error occurs (its traceback is
        printed).

    Model loading tries ``from_pretrained`` on the directory first, then a
    manual fallback chosen by the format flags. Tokenizer loading tries the
    checkpoint directory first, then the ``t5-small`` tokenizer.
    """
    try:
        checkpoint_path = Path(checkpoint_info['path'])
        print(f"📂 Loading checkpoint from: {checkpoint_path}")

        # Check what model format is available
        has_safetensors = checkpoint_info.get('has_safetensors', False)
        has_pytorch_model = checkpoint_info.get('has_pytorch_model', False)

        print(f"   Model format available: {'safetensors' if has_safetensors else 'pytorch_bin' if has_pytorch_model else 'none'}")

        # Load model - FIXED to handle safetensors properly
        print("🔄 Loading model...")
        try:
            # Try direct loading first (works for both formats)
            model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)
            print("✅ Model loaded successfully with from_pretrained")
        except Exception as e:
            print(f"⚠️ from_pretrained failed: {e}")
            # Try manual loading: build an untrained model from the saved config
            try:
                config = T5Config.from_pretrained(checkpoint_path)
                model = T5ForConditionalGeneration(config)

                if has_safetensors:
                    # For safetensors, transformers should handle it automatically
                    # NOTE(review): this repeats the from_pretrained call that just
                    # failed, differing only in use_safetensors=True. The `config` /
                    # `model` built above are discarded on this branch.
                    model = T5ForConditionalGeneration.from_pretrained(checkpoint_path, use_safetensors=True)
                    print("✅ Model loaded with safetensors")
                elif has_pytorch_model:
                    state_dict = torch.load(checkpoint_path / 'pytorch_model.bin', map_location='cpu')
                    model.load_state_dict(state_dict)
                    print("✅ Model loaded with pytorch state_dict")
                else:
                    print("❌ No compatible model file found")
                    return None

            except Exception as e2:
                print(f"❌ Manual loading also failed: {e2}")
                return None

        # Load tokenizer - FIXED to be more flexible
        print("🔄 Loading tokenizer...")
        try:
            tokenizer = T5Tokenizer.from_pretrained(checkpoint_path)
            print("✅ Tokenizer loaded successfully")
        except Exception as e:
            print(f"⚠️ Local tokenizer loading failed: {e}")
            # Fallback to original model tokenizer
            try:
                tokenizer = T5Tokenizer.from_pretrained('t5-small')
                print("✅ Using fallback tokenizer (t5-small)")
            except Exception as e2:
                print(f"❌ Even fallback tokenizer failed: {e2}")
                return None

        # Load training state if available (best-effort: failure leaves it None)
        training_state = None
        if checkpoint_info.get('has_training_state', False):
            try:
                print("🔄 Loading training state...")
                training_state_path = checkpoint_path / 'training_state.pt'
                # NOTE(review): newer PyTorch releases default torch.load to
                # weights_only=True. If training_state.pt holds non-tensor objects
                # this may fail and training continues without it. Confirm against
                # the code that writes this file.
                training_state = torch.load(training_state_path, map_location='cpu')
                print("✅ Training state loaded successfully")
            except Exception as e:
                print(f"⚠️ Failed to load training state: {e}")
                print("   Will continue without training state")

        print(f"✅ Successfully loaded checkpoint: {checkpoint_info['checkpoint_name']}")
        return {
            'model': model,
            'tokenizer': tokenizer,
            'training_state': training_state,
            'checkpoint_info': checkpoint_info
        }

    except Exception as e:
        print(f"❌ Failed to load checkpoint: {e}")
        import traceback
        traceback.print_exc()
        return None


# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def validate_training_data(data_dir: str) -> bool:
    """Check that the BabyLM train/dev text files exist and look non-trivial.

    Each of ``babylm_train.txt`` and ``babylm_dev.txt`` in ``data_dir`` must
    exist, be readable as UTF-8, and have at least 100 non-whitespace-trimmed
    characters within its first 1000 characters. The first problem found is
    printed and ``False`` is returned; otherwise ``True``.
    """
    for required_name in ('babylm_train.txt', 'babylm_dev.txt'):
        candidate = os.path.join(data_dir, required_name)

        if not os.path.exists(candidate):
            print(f"Required file missing: {candidate}")
            return False

        try:
            with open(candidate, 'r', encoding='utf-8') as handle:
                preview = handle.read(1000)  # only sample the head of the file
        except Exception as err:
            print(f"Cannot read file {candidate}: {err}")
            return False

        if len(preview.strip()) < 100:
            print(f"File appears empty or too small: {candidate}")
            return False

    print("Training data validation passed")
    return True


def estimate_training_time(num_examples: int, batch_size: int, num_epochs: int,
                         device: str = 'cpu',
                         examples_per_minute: Optional[float] = None) -> Dict[str, float]:
    """Estimate training time from a fixed throughput guess.

    Args:
        num_examples: examples per epoch.
        batch_size: batch size; values below 1 are treated as 1.
        num_epochs: number of epochs.
        device: device string or ``torch.device``. Any CUDA device
            ('cuda', 'cuda:0', ...) selects the GPU throughput estimate.
        examples_per_minute: optional throughput override. By default this is
            1000 on CUDA and 100 otherwise (rough estimates).

    Returns:
        Dict with total_examples, estimated_minutes, estimated_hours,
        examples_per_epoch and batches_per_epoch (at least 1).

    Raises:
        ValueError: if ``examples_per_minute`` is given and not positive.
    """
    if examples_per_minute is None:
        # Rough estimates based on typical hardware.
        # str() also covers torch.device and indexed devices like 'cuda:0'.
        examples_per_minute = 1000 if str(device).startswith('cuda') else 100
    elif examples_per_minute <= 0:
        raise ValueError("examples_per_minute must be positive")

    total_examples = num_examples * num_epochs
    estimated_minutes = total_examples / examples_per_minute

    return {
        'total_examples': total_examples,
        'estimated_minutes': estimated_minutes,
        'estimated_hours': estimated_minutes / 60,
        'examples_per_epoch': num_examples,
        # Guard batch_size <= 0, which previously raised ZeroDivisionError
        'batches_per_epoch': max(1, num_examples // max(1, batch_size))
    }


def monitor_system_resources(cpu_interval: float = 0.1):
    """Print a snapshot of CPU, RAM, root-disk and (if available) GPU-0 usage.

    Args:
        cpu_interval: seconds over which ``psutil.cpu_percent`` samples CPU
            usage. With ``None``/0, psutil compares against its previous call,
            and the very first such call returns a meaningless 0.0.

    Missing ``psutil`` and any other error are reported via ``print``, never raised.
    """
    gib = 1024 ** 3
    try:
        import psutil

        cpu_percent = psutil.cpu_percent(interval=cpu_interval)
        memory = psutil.virtual_memory()
        disk = psutil.disk_usage('/')

        print(f"System Resources:")
        print(f"  CPU: {cpu_percent:.1f}%")
        # True division: floor division truncated the figures to whole GiB
        print(f"  Memory: {memory.percent:.1f}% ({memory.used / gib:.1f}GB / {memory.total / gib:.1f}GB)")
        # disk_usage().percent is the percentage *used*, not free
        print(f"  Disk: {disk.percent:.1f}% used")

        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(0).total_memory
            gpu_allocated = torch.cuda.memory_allocated(0)
            gpu_percent = (gpu_allocated / gpu_memory) * 100
            print(f"  GPU Memory: {gpu_percent:.1f}% ({gpu_allocated / gib:.1f}GB / {gpu_memory / gib:.1f}GB)")

    except ImportError:
        print("psutil not available for resource monitoring")
    except Exception as e:
        print(f"Resource monitoring error: {e}")


def create_training_report(training_stats: Dict, config: TrainingConfig,
                         final_metrics: Dict = None) -> str:
    """Render a plain-text summary of a training run.

    Args:
        training_stats: dict from the trainer. ``epoch_losses`` and
            ``val_losses`` are used when present and non-empty; zero or
            negative validation entries are ignored.
        config: provides model_name, num_epochs, batch_size, learning_rate
            and device.
        final_metrics: optional mapping, printed one ``name: value`` per line.

    Returns:
        The report as a single newline-joined string, ending with a
        timestamp and a separator line.
    """
    lines: List[str] = ["=== T5 DOMAIN-AWARE TRAINING REPORT ===", "", "CONFIGURATION:"]

    config_rows = (
        ("Model", config.model_name),
        ("Epochs", config.num_epochs),
        ("Batch Size", config.batch_size),
        ("Learning Rate", config.learning_rate),
        ("Device", config.device),
    )
    lines.extend(f"  {label}: {value}" for label, value in config_rows)
    lines += ["", "TRAINING STATISTICS:"]

    train_losses = training_stats.get('epoch_losses')
    if train_losses:
        first_loss, last_loss = train_losses[0], train_losses[-1]
        lines.append(f"  Best Training Loss: {min(train_losses):.4f}")
        lines.append(f"  Final Training Loss: {last_loss:.4f}")
        lines.append(f"  Loss Improvement: {first_loss - last_loss:.4f}")

    # A 0 entry marks "no validation this epoch"
    recorded_val = [v for v in (training_stats.get('val_losses') or []) if v > 0]
    if recorded_val:
        lines.append(f"  Best Validation Loss: {min(recorded_val):.4f}")

    if final_metrics:
        lines += ["", "FINAL METRICS:"]
        lines.extend(f"  {name}: {val}" for name, val in final_metrics.items())

    generated_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    lines += ["", f"Report generated: {generated_at}", "=" * 50]

    return "\n".join(lines)




def debug_cache_files(data_manager: RobustDataManager):
    """Print size, mtime, type and length of every ``*.pkl`` in the preprocessed dir.

    Every file is unpickled in full to report its type and length, so this can
    be slow and memory-hungry for large caches. Per-file errors are printed and
    the scan continues.
    """
    print("🔍 DEBUG: Cache Files Analysis")
    print("=" * 50)

    preprocessed_dir = data_manager.get_path('preprocessed')
    if not preprocessed_dir.exists():
        print("❌ Preprocessed directory doesn't exist")
        return

    pickle_files = sorted(preprocessed_dir.glob('*.pkl'))
    print(f"Found {len(pickle_files)} cache files:")

    for pkl_path in pickle_files:
        try:
            file_stat = pkl_path.stat()
            size_mb = file_stat.st_size / (1024 * 1024)
            modified = datetime.fromtimestamp(file_stat.st_mtime)
            print(f"  📄 {pkl_path.name}")
            print(f"      Size: {size_mb:.1f}MB")
            print(f"      Modified: {modified}")

            # pickle.load can execute arbitrary code: only use on caches this project wrote
            try:
                with open(pkl_path, 'rb') as fh:
                    payload = pickle.load(fh)
                if hasattr(payload, '__len__'):
                    print(f"      Length: {len(payload)}")
                print(f"      Type: {type(payload).__name__}")
            except Exception as e:
                print(f"      Error loading: {e}")

            print()

        except Exception as e:
            print(f"  ❌ Error analyzing {pkl_path.name}: {e}")


def debug_checkpoint_files(data_manager: RobustDataManager):
    """Print each checkpoint sub-directory and the size of every file inside it."""
    print("🔍 DEBUG: Checkpoint Files Analysis")
    print("=" * 50)

    checkpoint_root = data_manager.get_path('checkpoints')
    if not checkpoint_root.exists():
        print("❌ Checkpoint directory doesn't exist")
        return

    # Lazily filter so directory checks interleave with printing, as before
    for subdir in filter(lambda entry: entry.is_dir(), sorted(checkpoint_root.iterdir())):
        print(f"📁 {subdir.name}/")

        for file_path in sorted(subdir.iterdir()):
            if not file_path.is_file():
                continue
            size_mb = file_path.stat().st_size / (1024 * 1024)
            print(f"  📄 {file_path.name} ({size_mb:.1f}MB)")

        print()


# =============================================================================
# MAIN TRAINING FUNCTION WITH RESUME CAPABILITY
# =============================================================================

def train_improved_t5_with_resume(data_dir: str, save_dir: str, project_name: str = "t5_improved",
                                resume_from: Optional[str] = None, validate_data: bool = True):
    """
    Train (or continue training) a domain-aware T5 model.

    Args:
        data_dir: Directory that must contain ``babylm_train.txt`` and may
            contain ``babylm_dev.txt``.
        save_dir: Root directory for checkpoints, caches and logs (created if missing).
        project_name: Project name passed to ``RobustDataManager``.
        resume_from: ``None`` lets ``load_existing_checkpoint`` pick the
            best/latest checkpoint, ``"fresh"`` ignores existing checkpoints,
            any other value is passed through as a checkpoint name. If loading
            fails, training starts from a new model.
        validate_data: Run ``validate_training_data(data_dir)`` first.

    Returns:
        dict with keys ``model``, ``tokenizer``, ``data_manager``,
        ``training_stats``, ``config`` and ``report``.

    Raises:
        ValueError: if ``data_dir`` does not exist, data validation fails, or
            no non-empty training split could be built.
    """
    print("🚀 Starting improved T5 training with resume capability")
    print(f"📁 Data directory: {data_dir}")
    print(f"💾 Save directory: {save_dir}")

    if not os.path.exists(data_dir):
        raise ValueError(f"Data directory does not exist: {data_dir}")
    os.makedirs(save_dir, exist_ok=True)

    if validate_data and not validate_training_data(data_dir):
        raise ValueError("Training data validation failed")

    config = TrainingConfig()
    data_manager = RobustDataManager(save_dir, project_name)

    checkpoint_data, resume_epoch = _resolve_checkpoint(data_manager, resume_from)
    model, tokenizer = _load_or_init_model(checkpoint_data, config)
    datasets = _build_datasets(data_dir, tokenizer, data_manager, config)

    print("🏋️ Initializing trainer...")
    trainer = ImprovedT5Trainer(model, tokenizer, data_manager, config)
    if checkpoint_data and checkpoint_data.get('training_state'):
        _restore_training_state(trainer, checkpoint_data['training_state'])

    monitor_system_resources()
    # _build_datasets guarantees a 'train' split, so no guard is needed here.
    time_estimate = estimate_training_time(
        len(datasets['train']), config.batch_size, config.num_epochs, config.device
    )
    print(f"⏰ Training time estimate: {time_estimate['estimated_hours']:.1f} hours")

    if resume_epoch > 0:
        remaining_epochs = config.num_epochs - resume_epoch
        print(f"📈 Resuming from epoch {resume_epoch + 1}, {remaining_epochs} epochs remaining")
        if remaining_epochs <= 0:
            print(f"⚠️ Checkpoint epoch ({resume_epoch}) already reaches num_epochs "
                  f"({config.num_epochs}); the trainer may have no epochs left to run")

    print("🎯 Starting training...")
    training_stats = trainer.train(
        train_dataset=datasets['train'],
        val_dataset=datasets.get('dev'),
        resume_epoch=resume_epoch
    )

    report = create_training_report(training_stats, config)
    report_path = data_manager.get_path('logs') / f'training_report_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt'
    with open(report_path, 'w') as f:
        f.write(report)

    print("🎉 Training completed successfully!")
    print(f"📋 Training report saved to: {report_path}")
    print(report)

    return {
        'model': model,
        'tokenizer': tokenizer,
        'data_manager': data_manager,
        'training_stats': training_stats,
        'config': config,
        'report': report
    }


def _resolve_checkpoint(data_manager, resume_from):
    """Return ``(checkpoint_data, resume_epoch)``; ``(None, 0)`` means start fresh."""
    print("🔍 Checking for existing checkpoints...")
    checkpoints = data_manager.list_checkpoints()

    if not checkpoints or resume_from == "fresh":
        print("🆕 Starting fresh training (no checkpoints found or fresh requested)")
        return None, 0

    # A falsy resume_from is normalised to None so load_existing_checkpoint
    # performs its own best/latest selection.
    checkpoint_data = load_existing_checkpoint(data_manager, resume_from or None)
    if not checkpoint_data:
        if resume_from:
            print(f"⚠️ Requested checkpoint '{resume_from}' could not be loaded")
        print("⚠️ Checkpoint loading failed, starting fresh")
        return None, 0

    info = checkpoint_data.get('checkpoint_info') or {}
    print(f"📦 Resuming from checkpoint: {info.get('checkpoint_name', '<unknown>')}")
    return checkpoint_data, info.get('epoch') or 0


def _load_or_init_model(checkpoint_data, config):
    """Return ``(model, tokenizer)`` from the checkpoint, or build new ones from ``config.model_name``."""
    if checkpoint_data:
        model = checkpoint_data['model']
        tokenizer = checkpoint_data['tokenizer']
        print("✅ Model and tokenizer loaded from checkpoint")
        return model, tokenizer

    print("🔄 Initializing new model and tokenizer...")
    tokenizer = T5Tokenizer.from_pretrained(config.model_name)
    model_config = T5Config.from_pretrained(config.model_name)
    # Built from the config only: weights are randomly initialised, not pretrained.
    model = T5ForConditionalGeneration(model_config)
    print("✅ New model and tokenizer initialized")
    return model, tokenizer


def _build_datasets(data_dir, tokenizer, data_manager, config):
    """Build the non-empty train/dev datasets; raise ValueError if there is no train split."""
    complexity_analyzer = EnhancedComplexityAnalyzer(tokenizer, data_manager)
    task_creator = ImprovedTaskCreator(tokenizer, config.max_source_length, config.max_target_length)

    datasets = {}
    file_mapping = {'train': 'babylm_train.txt', 'dev': 'babylm_dev.txt'}
    for split, filename in file_mapping.items():
        filepath = os.path.join(data_dir, filename)
        if not os.path.exists(filepath):
            print(f"⚠️ {split} file not found: {filepath}")
            continue

        print(f"📊 Creating {split} dataset...")
        max_examples = config.max_train_examples if split == 'train' else config.max_val_examples
        dataset = ImprovedDomainAwareDataset(
            data_path=filepath,
            tokenizer=tokenizer,
            complexity_analyzer=complexity_analyzer,
            task_creator=task_creator,
            data_manager=data_manager,
            max_source_length=config.max_source_length,
            max_target_length=config.max_target_length,
            split=split,
            max_examples=max_examples
        )

        if len(dataset) > 0:
            datasets[split] = dataset
            print(f"✅ Created {split} dataset: {len(dataset)} examples")
        else:
            print(f"⚠️ {split} dataset is empty; skipping")

    if 'train' not in datasets:
        raise ValueError("No training dataset available")
    return datasets


def _restore_training_state(trainer, training_state):
    """Best-effort restore of optimizer, scheduler and global step into ``trainer``."""
    try:
        print("🔄 Restoring optimizer state...")
        trainer.optimizer.load_state_dict(training_state['optimizer_state'])

        # getattr: the trainer may not have created a scheduler yet. A plain
        # attribute access would raise here and skip the global_step restore.
        scheduler = getattr(trainer, 'scheduler', None)
        if training_state.get('scheduler_state') and scheduler:
            try:
                scheduler.load_state_dict(training_state['scheduler_state'])
                print("✅ Scheduler state restored")
            except Exception as e:
                print(f"⚠️ Failed to restore scheduler state: {e}")

        trainer.global_step = training_state.get('global_step', 0)
        print(f"✅ Restored training state from epoch {training_state.get('epoch', '?')}")
        print(f"   Global step: {trainer.global_step}")

    except Exception as e:
        print(f"⚠️ Failed to restore training state: {e}")
        print("   Continuing with fresh optimizer state")


# =============================================================================
# SIMPLE TRAINING FUNCTION (FOR FRESH START)
# =============================================================================

def train_improved_t5(data_dir: str, save_dir: str, project_name: str = "t5_improved"):
    """
    Simple training function for fresh start without resume.

    Delegates to ``train_improved_t5_with_resume`` with ``resume_from="fresh"``,
    so existing checkpoints are ignored, and with data validation enabled.
    (Passing ``None`` instead would auto-resume from an existing checkpoint.)

    NOTE(review): this function is defined twice in this notebook, and the later
    definition shadows this one. Both must stay behaviourally identical.
    """
    return train_improved_t5_with_resume(
        data_dir=data_dir,
        save_dir=save_dir,
        project_name=project_name,
        resume_from="fresh",
        validate_data=True
    )


# =============================================================================
# SETUP AND INITIALIZATION FUNCTIONS
# =============================================================================

def setup_training_environment(base_dir: str, project_name: str = "t5_improved"):
    """Create the project's data manager and a default config, then report what is on disk.

    Args:
        base_dir: Root directory handed to ``RobustDataManager``.
        project_name: Project name handed to ``RobustDataManager``.

    Returns:
        ``(data_manager, config)``: the ``RobustDataManager`` and a new ``TrainingConfig``.
    """
    manager = RobustDataManager(base_dir, project_name)
    training_config = TrainingConfig()

    print("✅ Training environment setup complete")
    print(f"📁 Project directory: {manager.get_path('root')}")
    print(f"🚀 GPU available: {torch.cuda.is_available()}")

    print("📋 Existing files:")
    for label, location in manager.structure.items():
        if not location.exists():
            continue
        # Counts every direct child (sub-directories included); a non-directory reports 0.
        entry_count = sum(1 for _ in location.glob('*')) if location.is_dir() else 0
        print(f"  {label}: {entry_count} files")

    return manager, training_config

def list_available_checkpoints(data_manager: RobustDataManager):
    """Print a one-line summary for every checkpoint reported by ``data_manager``.

    Each entry (a dict from ``data_manager.list_checkpoints()``) is shown with
    its name, epoch and ``metrics['loss']``. The loss is formatted to 4 decimals
    when it is numeric, otherwise printed as-is ('N/A' when absent). Entries are
    tagged BEST / INFERRED when those keys are truthy, and inferred entries also
    show whether training state is available.
    """
    entries = data_manager.list_checkpoints()
    if not entries:
        print("❌ No checkpoints found")
        return

    print("📋 Available checkpoints:")
    for entry in entries:
        raw_loss = entry.get('metrics', {}).get('loss', 'N/A')
        if isinstance(raw_loss, (int, float)):
            loss_text = f"{raw_loss:.4f}"
        else:
            loss_text = str(raw_loss)

        tags = ""
        if entry.get('is_best', False):
            tags += " 🏆 (BEST)"
        if entry.get('inferred', False):
            tags += " ⚠️ (INFERRED)"

        print(f"  📦 {entry['checkpoint_name']}: Epoch {entry['epoch']}, Loss: {loss_text}{tags}")

        if entry.get('inferred'):
            state_icon = "✅" if entry.get('has_training_state') else "❌"
            print(f"      Training state: {state_icon}")


def cleanup_old_checkpoints(data_manager: RobustDataManager, keep_last_n: int = 5):
    """Delete old checkpoints by delegating to the data manager.

    Args:
        data_manager: Object whose ``cleanup_old_checkpoints`` performs the deletion.
        keep_last_n: How many of the last checkpoints (as ordered by the data
            manager) to keep; defaults to 5.
    """
    message = f"🧹 Cleaning up old checkpoints (keeping last {keep_last_n})"
    print(message)
    data_manager.cleanup_old_checkpoints(keep_last_n)


def force_fresh_start(data_manager: RobustDataManager, confirm: bool = False,
                      clear_checkpoints: Optional[bool] = None):
    """Wipe cached preprocessing data and metadata, and optionally all checkpoints.

    Args:
        data_manager: Supplies ``get_path('preprocessed' | 'metadata' | 'checkpoints')``.
        confirm: Safety switch; nothing is deleted unless this is True.
        clear_checkpoints: ``True``/``False`` decides non-interactively whether
            the checkpoints directory is also wiped. ``None`` (the default, and
            the original behaviour) asks via ``input()``. If no stdin is
            available (e.g. a non-interactive "Run All"), the answer is "no".

    Each wiped directory is recreated empty. Errors are printed, not raised.
    """
    if not confirm:
        print("⚠️ This will delete all cached data and checkpoints!")
        print("⚠️ Call with confirm=True if you're sure")
        return

    try:
        if _reset_directory(data_manager.get_path('preprocessed')):
            print("🗑️ Cleared preprocessed data cache")

        if _reset_directory(data_manager.get_path('metadata')):
            print("🗑️ Cleared metadata cache")

        if clear_checkpoints is None:
            clear_checkpoints = _ask_yes_no("❓ Also clear all checkpoints? (y/N): ")
        # Short-circuit: the checkpoints path is only looked up when clearing.
        if clear_checkpoints and _reset_directory(data_manager.get_path('checkpoints')):
            print("🗑️ Cleared all checkpoints")

        print("✅ Fresh start prepared")

    except Exception as e:
        print(f"❌ Error during cleanup: {e}")


def _reset_directory(path) -> bool:
    """Delete ``path`` recursively and recreate it empty; return False if it did not exist."""
    if not path.exists():
        return False
    shutil.rmtree(path)
    path.mkdir(parents=True, exist_ok=True)
    return True


def _ask_yes_no(prompt: str) -> bool:
    """Prompt on stdin; only answers starting with 'y'/'Y' count as yes, and no stdin means no."""
    try:
        response = input(prompt)
    except (EOFError, NotImplementedError):
        # EOFError: stdin closed. IPython's StdinNotImplementedError subclasses
        # NotImplementedError when the frontend cannot provide input.
        return False
    return response.lower().startswith('y')


def train_improved_t5(data_dir: str, save_dir: str, project_name: str = "t5_improved"):
    """
    Start a brand-new training run, ignoring any checkpoints already on disk.

    Thin wrapper around ``train_improved_t5_with_resume`` that passes
    ``resume_from="fresh"`` (which skips checkpoint loading) and keeps data
    validation enabled. Returns whatever that function returns.
    """
    options = dict(
        data_dir=data_dir,
        save_dir=save_dir,
        project_name=project_name,
        resume_from="fresh",
        validate_data=True,
    )
    return train_improved_t5_with_resume(**options)


def resume_training(data_dir: str, save_dir: str, project_name: str = "t5_improved",
                   checkpoint_name: str = None):
    """
    Continue a previous training run from a saved checkpoint.

    First lists the checkpoints known for ``save_dir``/``project_name``, then
    delegates to ``train_improved_t5_with_resume`` with data validation on.
    ``checkpoint_name=None`` lets that function choose the checkpoint itself.

    Returns:
        Whatever ``train_improved_t5_with_resume`` returns.
    """
    print("🔄 Resume Training Mode")

    # This manager is only used to report what already exists on disk.
    manager = RobustDataManager(save_dir, project_name)
    print("🔍 Scanning for checkpoints...")
    list_available_checkpoints(manager)

    target_msg = (f"🎯 Targeting specific checkpoint: {checkpoint_name}"
                  if checkpoint_name
                  else "🎯 Will auto-select best/latest checkpoint")
    print(target_msg)

    return train_improved_t5_with_resume(
        data_dir=data_dir,
        save_dir=save_dir,
        project_name=project_name,
        resume_from=checkpoint_name,
        validate_data=True,
    )


# =============================================================================
# EXAMPLE USAGE AND MAIN EXECUTION
# =============================================================================


if __name__ == "__main__":
    import traceback

    print("=" * 60)
    print("🚀 T5 DOMAIN-AWARE TRAINING SYSTEM - FIXED VERSION")
    print("=" * 60)

    # Configuration. Defaults target Kaggle; override with the BABYLM_DATA_DIR /
    # T5_SAVE_DIR environment variables when running elsewhere.
    DATA_DIR = os.environ.get('BABYLM_DATA_DIR', '/kaggle/working/baseline-pretraining/babylm_data')
    if not os.path.isdir(DATA_DIR) and os.path.isdir('babylm_data'):
        # The download cell `%cd`s into baseline-pretraining and unpacks into
        # ./babylm_data (e.g. on Colab), so fall back to that location.
        DATA_DIR = os.path.abspath('babylm_data')
    SAVE_DIR = os.environ.get('T5_SAVE_DIR', '/kaggle/working')
    PROJECT_NAME = 't5_domain_aware_v3'

    # Pre-bound so the optional sections below never raise NameError when
    # environment setup or training fails.
    data_manager = None
    results = None

    # Example 1: Setup and check environment
    print("\n=== 🔧 ENVIRONMENT SETUP ===")
    try:
        data_manager, config = setup_training_environment(SAVE_DIR, PROJECT_NAME)
        list_available_checkpoints(data_manager)
        print("✅ Environment setup completed!")
    except Exception as e:
        print(f"❌ Environment setup failed: {e}")

    # Example 2: Resume training (recommended)
    print("\n=== 🔄 RESUME TRAINING ===")
    try:
        results = resume_training(
            data_dir=DATA_DIR,
            save_dir=SAVE_DIR,
            project_name=PROJECT_NAME,
            checkpoint_name=None  # Auto-select best checkpoint
        )
        print("✅ Training completed successfully!")
    except Exception as e:
        print(f"❌ Resume training failed: {e}")
        traceback.print_exc()

    # Example 3: Fresh training (if needed)
    print("\n=== 🆕 FRESH TRAINING (if resume failed) ===")
    try:
        # Uncomment the following lines if you want to force fresh start
        # if data_manager is not None:
        #     force_fresh_start(data_manager, confirm=True, clear_checkpoints=False)
        # results = train_improved_t5(DATA_DIR, SAVE_DIR, PROJECT_NAME)
        # print("✅ Fresh training completed successfully!")
        print("⚠️ Fresh training commented out - uncomment if needed")
    except Exception as e:
        print(f"❌ Fresh training failed: {e}")

    # Example 4: Cleanup (optional)
    print("\n=== 🧹 CLEANUP (optional) ===")
    try:
        # if data_manager is not None:
        #     cleanup_old_checkpoints(data_manager, keep_last_n=3)
        print("⚠️ Cleanup commented out - uncomment if needed")
    except Exception as e:
        print(f"❌ Cleanup failed: {e}")

    print("\n" + "=" * 60)
    if results is None:
        print("⚠️ No training run completed; see the errors above")
    print("🎉 SCRIPT EXECUTION COMPLETED")
    print("=" * 60)