In [None]:
import os
# Completely remove and reinstall NumPy, Transformers, and Datasets so the Colab
# runtime ends up with a consistent set of freshly installed binaries.
# NOTE(review): none of these installs are version-pinned, so the environment can
# drift between runs; consider pinning (e.g. `%pip install transformers==X.Y.Z`).
# NOTE(review): `--force-reinstall` on the transformers/datasets line likely also
# reinstalls their dependencies (numpy included) — confirm whether the separate
# numpy reinstall is still needed.
!pip uninstall -y numpy transformers datasets
!pip install numpy --force-reinstall --no-cache-dir
!pip install transformers datasets --force-reinstall --no-cache-dir
!pip install rouge-score
!pip install evaluate

# Signal 9 (SIGKILL) terminates the kernel process; Colab then restarts it so the
# newly installed packages are imported fresh. Every variable defined so far is
# lost, so later cells must not depend on state from this cell.
os.kill(os.getpid(), 9)  # Restart the Colab runtime (REQUIRED)


In [None]:
!git clone https://github.com/babylm/baseline-pretraining.git
%cd baseline-pretraining
!wget -O babylm_data.zip "https://files.osf.io/v1/resources/ad7qg/providers/osfstorage/661517db943bee3731dfec25/?zip="
!unzip babylm_data.zip -d babylm_data
!unzip babylm_data/train_10M.zip -d babylm_data/train_10M
!unzip babylm_data/dev.zip -d babylm_data/dev
!unzip babylm_data/test.zip -d babylm_data/test
!cat babylm_data/train_10M/train_10M/*.train > babylm_data/babylm_train.txt
!cat babylm_data/dev/dev/*.dev > babylm_data/babylm_dev.txt
!cat babylm_data/test/test/*.test > babylm_data/babylm_test.txt

# finetuning T5 small with Curriculum Learning with Hybrid Complexity Analysis


In [None]:
# T5-small curriculum-learning pipeline: imports, hybrid complexity analysis, and dataset preprocessing/preparation
import hashlib
import json
import math
import os
import pickle
import random
import re
import time
from collections import Counter, defaultdict
from typing import Dict, List, Optional, Tuple, Union

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
from scipy.spatial.distance import cosine
from scipy.stats import entropy
from sklearn.cluster import KMeans
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset, Sampler
from tqdm import tqdm
from transformers import T5Config, T5ForConditionalGeneration, T5Tokenizer
from transformers import get_linear_schedule_with_warmup
# =============================================================================
# HYBRID COMPLEXITY ANALYZER COMBINING BOTH APPROACHES
# =============================================================================

class HybridComplexityAnalyzer:
    """
    Hybrid text-complexity analyzer combining rule-based and data-driven signals.

    Rule-based signals are always available: lexical diversity, average word
    length, average sentence length, density of hand-written syntactic regex
    patterns and academic-word density.

    Data-driven signals need ``learn_from_data`` (or, with ``learning_mode=False``,
    a previously saved ``hybrid_complexity_models.pkl`` in ``save_dir``): learned
    per-word complexity, rare-word density, TF-IDF feature correlations and
    KMeans cluster complexity.

    ``analyze_text_complexity`` merges all signals into ``overall_complexity``,
    clamped to [0.01, 1.0] (0.0 for empty / very short texts).
    """

    # Frequency rank assigned to words missing from ``word_frequencies``,
    # i.e. such words are treated as rare.
    _UNSEEN_WORD_RANK = 10000
    # How many of the most strongly correlated TF-IDF features to keep per sign.
    _TOP_SEMANTIC_FEATURES = 200
    # Upper bounds for the clustering stage (shrunk automatically on small data).
    _SVD_COMPONENTS = 50
    _N_CLUSTERS = 5

    def __init__(self, tokenizer: "T5Tokenizer", save_dir: str, learning_mode: bool = True):
        """
        Args:
            tokenizer: T5 tokenizer; stored on the instance (no method of this
                class uses it).
            save_dir: Directory holding ``word_frequencies.pkl`` and
                ``hybrid_complexity_models.pkl``; created if missing.
            learning_mode: If True, start untrained and expect ``learn_from_data``.
                If False, try to load previously saved models from ``save_dir``.
        """
        self.tokenizer = tokenizer
        self.save_dir = save_dir
        self.learning_mode = learning_mode
        os.makedirs(save_dir, exist_ok=True)

        # Rule-based regex patterns; matches are counted case-insensitively.
        self.linguistic_patterns = {
            'subordinate_clauses': r'\b(although|because|since|while|whereas|if|unless|when|after|before|until|once|provided|given|assuming)\b',
            'relative_clauses': r'\b(which|that|who|whom|whose|where|when|why)\s+\w+',
            'passive_constructions': r'\b(was|were|is|are|been|being)\s+\w*ed\b|\b\w+\s+(was|were|is|are)\s+\w*ed\b',
            'complex_verb_forms': r'\b(have|has|had|will|would|could|should|might|must)\s+(been|have|had)\b',
            'discourse_markers': r'\b(furthermore|moreover|additionally|however|nevertheless|nonetheless|consequently|therefore|thus|hence)\b',
            'abstract_concepts': r'\b\w+(tion|sion|ness|ment|ity|ism|ance|ence|ship|hood|dom|age)\b',
            'modal_expressions': r'\b(possibly|probably|certainly|definitely|presumably|apparently|obviously)\b',
        }

        # Academic word list (lower-case; compared against lower-cased tokens).
        self.academic_words = {
            'analysis', 'approach', 'area', 'assessment', 'assume', 'authority',
            'available', 'benefit', 'concept', 'consistent', 'constitutional',
            'context', 'contract', 'create', 'data', 'definition', 'derived',
            'distribution', 'economic', 'environment', 'established', 'estimate',
            'evidence', 'export', 'factors', 'financial', 'formula', 'function',
            'identified', 'income', 'indicate', 'individual', 'interpretation',
            'involved', 'issues', 'labor', 'legal', 'legislation', 'major',
            'method', 'occur', 'percent', 'period', 'policy', 'principle',
            'procedure', 'process', 'required', 'research', 'response', 'role',
            'section', 'significant', 'similar', 'source', 'specific', 'structure',
            'theory', 'variables'
        }

        # Data-driven components (filled by learn_from_data or loaded from disk).
        self.word_complexity_scores = {}
        self.semantic_features = {}
        self.syntactic_patterns = {}
        self.cooccurrence_networks = {}
        self.complexity_clusters = {}

        # Models
        self.tfidf_vectorizer = None
        self.complexity_classifier = None
        self.is_trained = False

        self._initialize_components()

    def _initialize_components(self):
        """Load saved models (only when not in learning mode) and the word-frequency table."""
        models_file = os.path.join(self.save_dir, 'hybrid_complexity_models.pkl')

        if os.path.exists(models_file) and not self.learning_mode:
            print("📦 Loading pre-trained hybrid complexity models...")
            try:
                # Only load model files this class wrote itself: unpickling can run code.
                with open(models_file, 'rb') as f:
                    saved_models = pickle.load(f)
                    self.word_complexity_scores = saved_models.get('word_complexity_scores', {})
                    self.semantic_features = saved_models.get('semantic_features', {})
                    self.syntactic_patterns = saved_models.get('syntactic_patterns', {})
                    self.tfidf_vectorizer = saved_models.get('tfidf_vectorizer', None)
                    self.complexity_clusters = saved_models.get('complexity_clusters', {})
                    self.is_trained = True
                print("✅ Loaded pre-trained models")
            except Exception as e:
                print(f"⚠️ Failed to load models: {e}")

        self.word_frequencies = self._load_frequency_dict()

    def _load_frequency_dict(self) -> Dict[str, int]:
        """
        Return a word -> frequency-rank map (lower rank = more common word).

        Loads ``word_frequencies.pkl`` from ``save_dir`` when it is readable.
        Otherwise builds a small built-in table (very common words get ranks
        0, 1, 2, ...; a few mid-frequency words get 1000, 1050, 1100, ...) and
        tries to save it for next time. Failures are reported, never raised.
        """
        freq_file = os.path.join(self.save_dir, 'word_frequencies.pkl')

        if os.path.exists(freq_file):
            try:
                # Only load files this class wrote itself: unpickling can run code.
                with open(freq_file, 'rb') as f:
                    return pickle.load(f)
            except Exception as e:
                print(f"⚠️ Failed to load word frequencies ({e}); rebuilding defaults")

        high_freq = ['the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to',
                     'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be',
                     'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did',
                     'will', 'would', 'can', 'could', 'should', 'may', 'might',
                     'I', 'you', 'he', 'she', 'it', 'we', 'they', 'this', 'that',
                     'time', 'person', 'year', 'way', 'day', 'thing', 'man', 'world',
                     'life', 'hand', 'part', 'child', 'eye', 'woman', 'place', 'work']
        medium_freq = ['analysis', 'approach', 'structure', 'process', 'development',
                       'environment', 'significant', 'individual', 'particular',
                       'available', 'information', 'community', 'economic', 'political']

        frequencies = {word.lower(): rank for rank, word in enumerate(high_freq)}
        frequencies.update(
            {word.lower(): 1000 + i * 50 for i, word in enumerate(medium_freq)}
        )

        try:
            with open(freq_file, 'wb') as f:
                pickle.dump(frequencies, f)
        except OSError as e:
            print(f"⚠️ Failed to save word frequencies: {e}")

        return frequencies

    def learn_from_data(self, texts: List[str], complexity_scores: List[float] = None):
        """
        Learn word, semantic, syntactic and cluster complexity models from ``texts``.

        Args:
            texts: Training documents.
            complexity_scores: One target score per text. When None, scores come
                from ``_calculate_basic_complexity``.

        Raises:
            ValueError: if ``complexity_scores`` and ``texts`` differ in length.

        Saves the learned models to ``save_dir`` and marks the analyzer trained.
        """
        if not texts:
            print("⚠️ No training texts provided")
            return

        print("🧠 Learning complexity patterns from data...")

        if complexity_scores is None:
            complexity_scores = [self._calculate_basic_complexity(text) for text in texts]
        elif len(complexity_scores) != len(texts):
            raise ValueError(
                f"complexity_scores has {len(complexity_scores)} entries "
                f"but texts has {len(texts)}"
            )

        self._learn_word_complexity(texts, complexity_scores)
        self._learn_semantic_patterns(texts, complexity_scores)
        self._learn_syntactic_correlations(texts, complexity_scores)
        # Clustering reuses the TF-IDF vectorizer fitted by the semantic step.
        self._create_complexity_clusters(texts, complexity_scores)

        self._save_models()

        self.is_trained = True
        print("✅ Learning completed!")

    def _learn_word_complexity(self, texts: List[str], complexity_scores: List[float]):
        """
        Score every word seen in at least 3 texts.

        score = 0.7 * mean complexity of the texts containing the word
              + 0.3 * rarity, where rarity = min(frequency rank / 10000, 1).
        Replaces any previously learned word scores.
        """
        print("📊 Learning word-level complexity...")

        word_stats = defaultdict(list)
        for text, complexity in zip(texts, complexity_scores):
            for word in re.findall(r'\b\w+\b', text.lower()):
                word_stats[word].append(complexity)

        learned = {}
        for word, scores in word_stats.items():
            if len(scores) < 3:  # too few occurrences to trust the average
                continue
            avg_complexity = np.mean(scores)
            # word_frequencies stores ranks (0 = most common); unseen words
            # default to a high rank, i.e. they are treated as rare.
            rank = self.word_frequencies.get(word, self._UNSEEN_WORD_RANK)
            rarity = min(rank / float(self._UNSEEN_WORD_RANK), 1.0)
            learned[word] = {
                'score': avg_complexity * 0.7 + rarity * 0.3,
                'frequency': len(scores),
                'avg_complexity': avg_complexity
            }

        self.word_complexity_scores = learned

    def _learn_semantic_patterns(self, texts: List[str], complexity_scores: List[float]):
        """
        Fit a TF-IDF vectorizer and correlate each feature with complexity.

        Stores the per-feature Pearson correlations, plus the strongest positively
        and negatively correlated features. If fitting fails, the vectorizer is
        reset to None so later stages treat TF-IDF as unavailable.
        """
        print("🎯 Learning semantic patterns...")

        empty_features = {'high_complexity': [], 'low_complexity': [], 'correlations': {}}
        vectorizer = TfidfVectorizer(
            max_features=5000,
            ngram_range=(1, 2),
            stop_words='english',
            min_df=2,
            max_df=0.95
        )

        try:
            # CSC format makes the per-column slices below cheap.
            tfidf_matrix = vectorizer.fit_transform(texts).tocsc()
        except Exception as e:
            print(f"⚠️ TF-IDF learning failed: {e}")
            # Never keep (or pickle) an unfitted vectorizer.
            self.tfidf_vectorizer = None
            self.semantic_features = empty_features
            return

        self.tfidf_vectorizer = vectorizer

        try:
            feature_names = vectorizer.get_feature_names_out()
            scores = np.asarray(complexity_scores, dtype=float)

            feature_complexities = {}
            for i, feature in enumerate(feature_names):
                column = tfidf_matrix[:, i].toarray().ravel()
                if np.std(column) > 0:
                    correlation = np.corrcoef(column, scores)[0, 1]
                    if not np.isnan(correlation):
                        feature_complexities[feature] = correlation

            ranked = sorted(feature_complexities.items(),
                            key=lambda item: abs(item[1]), reverse=True)
            top_n = self._TOP_SEMANTIC_FEATURES

            # Take the strongest features of each sign. (Slicing the tail of the
            # |r|-ranked list would select the *weakest* correlations instead.)
            self.semantic_features = {
                'high_complexity': [f for f, c in ranked if c > 0][:top_n],
                'low_complexity': [f for f, c in ranked if c < 0][:top_n],
                'correlations': feature_complexities
            }

        except Exception as e:
            print(f"⚠️ TF-IDF learning failed: {e}")
            self.semantic_features = empty_features

    def _learn_syntactic_correlations(self, texts: List[str], complexity_scores: List[float]):
        """Correlate each regex pattern's per-word match density with the complexity scores."""
        print("🔍 Learning syntactic correlations...")

        # Word counts do not depend on the pattern; compute them once per text.
        word_counts = [max(len(re.findall(r'\b\w+\b', text)), 1) for text in texts]

        pattern_correlations = {}
        for pattern_name, pattern_regex in self.linguistic_patterns.items():
            pattern_densities = [
                len(re.findall(pattern_regex, text, re.IGNORECASE)) / n_words
                for text, n_words in zip(texts, word_counts)
            ]

            if np.std(pattern_densities) > 0:
                correlation = np.corrcoef(pattern_densities, complexity_scores)[0, 1]
                if not np.isnan(correlation):
                    pattern_correlations[pattern_name] = {
                        'correlation': correlation,
                        'avg_density': np.mean(pattern_densities)
                    }

        self.syntactic_patterns = pattern_correlations

    def _create_complexity_clusters(self, texts: List[str], complexity_scores: List[float]):
        """
        Cluster texts in a reduced TF-IDF space and record each cluster's mean complexity.

        Uses TruncatedSVD (at most ``_SVD_COMPONENTS`` components) followed by
        KMeans (at most ``_N_CLUSTERS`` clusters). Both are shrunk to fit small
        corpora and seeded for reproducibility. ``complexity_clusters`` is left
        empty when there is no fitted vectorizer, too little data, or fitting fails.
        """
        print("🎨 Creating complexity clusters...")

        if not self.tfidf_vectorizer:
            self.complexity_clusters = {}
            return

        try:
            tfidf_matrix = self.tfidf_vectorizer.transform(texts)
            n_samples, n_features = tfidf_matrix.shape

            # TruncatedSVD needs fewer components than features; KMeans needs at
            # least as many samples as clusters.
            n_components = min(self._SVD_COMPONENTS, n_features - 1)
            n_clusters = min(self._N_CLUSTERS, n_samples)
            if n_components < 1 or n_clusters < 2:
                print("⚠️ Clustering skipped: not enough data")
                self.complexity_clusters = {}
                return

            svd = TruncatedSVD(n_components=n_components, random_state=42)
            reduced_features = svd.fit_transform(tfidf_matrix)

            kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
            cluster_labels = kmeans.fit_predict(reduced_features)

            scores = np.asarray(complexity_scores, dtype=float)
            cluster_complexities = {}
            for cluster_id in range(n_clusters):
                mask = cluster_labels == cluster_id
                if np.any(mask):
                    cluster_complexities[cluster_id] = np.mean(scores[mask])

            self.complexity_clusters = {
                'svd': svd,
                'kmeans': kmeans,
                'cluster_complexities': cluster_complexities
            }

        except Exception as e:
            print(f"⚠️ Clustering failed: {e}")
            self.complexity_clusters = {}

    def analyze_text_complexity(self, text: str) -> Dict[str, float]:
        """
        Score ``text`` with all rule-based, data-driven and semantic features.

        Returns a dict of the individual feature values plus
        ``'overall_complexity'``. Texts under 10 characters (after stripping),
        or without any words or sentences, return ``{'overall_complexity': 0.0}``.
        """
        if not text or len(text.strip()) < 10:
            return {'overall_complexity': 0.0}

        words = re.findall(r'\b\w+\b', text.lower())
        sentences = [s.strip() for s in re.split(r'[.!?]+', text) if s.strip()]

        if not words or not sentences:
            return {'overall_complexity': 0.0}

        features = {}
        features.update(self._calculate_rule_based_features(text, words, sentences))
        features.update(self._calculate_data_driven_features(text, words))
        features.update(self._calculate_semantic_features(text))
        features['overall_complexity'] = self._calculate_overall_complexity(features)

        return features

    def _calculate_rule_based_features(self, text: str, words: List[str], sentences: List[str]) -> Dict[str, float]:
        """
        Compute the rule-based features; ``words`` and ``sentences`` must be non-empty.

        Only the upper bound (1.0) is clamped: very short words or sentences
        give negative morphological / sentence-length values, which lower the
        overall score.
        """
        n_words = len(words)
        features = {}

        # Carroll's corrected type-token ratio, capped at 1.
        unique_words = len(set(words))
        features['lexical_diversity'] = min(unique_words / math.sqrt(2 * n_words), 1.0)

        avg_word_length = sum(len(word) for word in words) / n_words
        features['morphological_complexity'] = min((avg_word_length - 3) / 6, 1.0)

        sent_lengths = [len(re.findall(r'\b\w+\b', s)) for s in sentences]
        avg_sent_length = sum(sent_lengths) / len(sent_lengths)
        features['sentence_length'] = min((avg_sent_length - 8) / 25, 1.0)

        # Pattern match density, weighted by |learned correlation| once trained.
        syntactic_complexity = 0.0
        for pattern_name, pattern_regex in self.linguistic_patterns.items():
            density = len(re.findall(pattern_regex, text, re.IGNORECASE)) / n_words
            if self.is_trained and pattern_name in self.syntactic_patterns:
                weight = abs(self.syntactic_patterns[pattern_name]['correlation'])
            else:
                weight = 1.0
            syntactic_complexity += density * weight
        features['syntactic_complexity'] = min(syntactic_complexity, 1.0)

        academic_count = sum(1 for word in words if word in self.academic_words)
        features['academic_density'] = min(academic_count / n_words * 10, 1.0)

        return features

    def _calculate_data_driven_features(self, text: str, words: List[str]) -> Dict[str, float]:
        """
        Average learned word score and rare-word density (0.5 each when untrained).

        A word counts as rare when it has no learned score and its frequency rank
        is unknown or above 5000. With the built-in frequency table (ranks
        <= 1650) that means only words absent from the table count.
        """
        features = {}

        if not self.is_trained:
            features['learned_word_complexity'] = 0.5
            features['rare_word_density'] = 0.5
            return features

        word_complexities = []
        rare_word_count = 0

        for word in words:
            if word in self.word_complexity_scores:
                word_complexities.append(self.word_complexity_scores[word]['score'])
            elif word not in self.word_frequencies or self.word_frequencies[word] > 5000:
                rare_word_count += 1

        if word_complexities:
            features['learned_word_complexity'] = min(np.mean(word_complexities), 1.0)
        else:
            features['learned_word_complexity'] = 0.5

        features['rare_word_density'] = min(rare_word_count / len(words) * 5, 1.0)

        return features

    def _calculate_semantic_features(self, text: str) -> Dict[str, float]:
        """
        TF-IDF- and cluster-based complexity of ``text`` (0.5 when unavailable).

        ``semantic_complexity`` is the TF-IDF-weighted mean of the positive
        learned feature correlations; ``cluster_complexity`` is the mean
        training complexity of the KMeans cluster the text falls into.
        """
        default = {'semantic_complexity': 0.5, 'cluster_complexity': 0.5}

        if not self.is_trained or not self.tfidf_vectorizer:
            return dict(default)

        features = dict(default)
        try:
            tfidf_vector = self.tfidf_vectorizer.transform([text])
            feature_names = self.tfidf_vectorizer.get_feature_names_out()
            correlations = self.semantic_features.get('correlations', {})

            # Visit only the non-zero entries of the sparse row instead of
            # indexing every column one by one.
            row = tfidf_vector.tocsr()
            semantic_score = 0.0
            total_weight = 0.0
            for col, weight in zip(row.indices, row.data):
                feature = feature_names[col]
                if weight > 0 and feature in correlations:
                    semantic_score += weight * max(correlations[feature], 0)  # positive only
                    total_weight += weight

            if total_weight > 0:
                features['semantic_complexity'] = min(semantic_score / total_weight, 1.0)

            if self.complexity_clusters:
                try:
                    reduced_vector = self.complexity_clusters['svd'].transform(tfidf_vector)
                    cluster_id = self.complexity_clusters['kmeans'].predict(reduced_vector)[0]
                    features['cluster_complexity'] = (
                        self.complexity_clusters['cluster_complexities'].get(cluster_id, 0.5)
                    )
                except Exception:
                    features['cluster_complexity'] = 0.5

        except Exception:
            return dict(default)

        return features

    def _calculate_overall_complexity(self, features: Dict[str, float]) -> float:
        """Weighted sum of ``features`` (weights depend on training state), clamped to [0.01, 1.0]."""
        if self.is_trained:
            weights = {
                'lexical_diversity': 0.12,
                'morphological_complexity': 0.08,
                'sentence_length': 0.10,
                'syntactic_complexity': 0.15,
                'academic_density': 0.08,
                'learned_word_complexity': 0.20,
                'rare_word_density': 0.10,
                'semantic_complexity': 0.12,
                'cluster_complexity': 0.05
            }
        else:
            # Untrained: rely mostly on the rule-based features.
            weights = {
                'lexical_diversity': 0.20,
                'morphological_complexity': 0.15,
                'sentence_length': 0.15,
                'syntactic_complexity': 0.25,
                'academic_density': 0.15,
                'learned_word_complexity': 0.05,
                'rare_word_density': 0.05,
                'semantic_complexity': 0.0,
                'cluster_complexity': 0.0
            }

        overall_score = sum(
            features.get(feature, 0.0) * weight
            for feature, weight in weights.items()
        )

        return min(max(overall_score, 0.01), 1.0)

    def _calculate_basic_complexity(self, text: str) -> float:
        """
        Heuristic bootstrap score in [0.05, 0.95] (0.1 without words/sentences).

        Combines average word length (30%), average words per sentence (40%)
        and academic-word ratio (30%).
        """
        words = re.findall(r'\b\w+\b', text.lower())
        sentences = [s.strip() for s in re.split(r'[.!?]+', text) if s.strip()]

        if not words or not sentences:
            return 0.1

        avg_word_len = sum(len(w) for w in words) / len(words)
        avg_sent_len = len(words) / len(sentences)
        academic_ratio = sum(1 for w in words if w in self.academic_words) / len(words)

        complexity = (
            min((avg_word_len - 4) / 6, 1.0) * 0.3 +
            min((avg_sent_len - 8) / 20, 1.0) * 0.4 +
            min(academic_ratio * 10, 1.0) * 0.3
        )

        return min(max(complexity, 0.05), 0.95)

    def _save_models(self):
        """Pickle the learned models to ``save_dir/hybrid_complexity_models.pkl`` (failures are reported)."""
        models_to_save = {
            'word_complexity_scores': self.word_complexity_scores,
            'semantic_features': self.semantic_features,
            'syntactic_patterns': self.syntactic_patterns,
            'tfidf_vectorizer': self.tfidf_vectorizer,
            'complexity_clusters': self.complexity_clusters
        }

        models_file = os.path.join(self.save_dir, 'hybrid_complexity_models.pkl')
        try:
            with open(models_file, 'wb') as f:
                pickle.dump(models_to_save, f)
            print(f"💾 Models saved to {models_file}")
        except Exception as e:
            print(f"⚠️ Failed to save models: {e}")

# =============================================================================
# ENHANCED T5 DATASET WITH HYBRID COMPLEXITY ANALYSIS
# =============================================================================

class T5CurriculumDataset(Dataset):
    """T5 dataset with hybrid complexity analysis and curriculum learning"""

    def __init__(self, data_path: str, tokenizer: T5Tokenizer,
                 complexity_analyzer: HybridComplexityAnalyzer,
                 max_source_length: int = 512, max_target_length: int = 256,
                 cache_dir: Optional[str] = None, split: str = "train",
                 max_examples: Optional[int] = None, corruption_probability: float = 0.15):
        """
        Build (or load from cache) the tokenized T5 examples for one data split.

        Args:
            data_path: Plain-text file the documents are read from.
            tokenizer: T5 tokenizer used to build and encode the tasks.
            complexity_analyzer: Scores each document; in learning mode it is also
                trained on documents of the ``train`` split.
            max_source_length: Encoder token budget per example.
            max_target_length: Decoder token budget per example.
            cache_dir: If given, processed examples are pickled here and reused.
            split: Split name; only "train" trains the analyzer / builds a curriculum.
            max_examples: Optional cap on the number of examples (falsy = no cap).
            corruption_probability: Fraction of tokens masked in span corruption.
        """
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.complexity_analyzer = complexity_analyzer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        self.split = split
        self.max_examples = max_examples
        self.corruption_probability = corruption_probability

        self.cache_dir = cache_dir
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
            # Fingerprint every setting that changes the processed examples so a
            # different data file, tokenizer, length budget or example cap never
            # silently reuses a cache built for other settings.
            fingerprint = repr((
                os.path.abspath(data_path),
                getattr(tokenizer, 'name_or_path', type(tokenizer).__name__),
                max_source_length,
                max_target_length,
                max_examples,
                corruption_probability,
            ))
            cache_key = hashlib.sha256(fingerprint.encode('utf-8')).hexdigest()[:12]
            self.cache_file = os.path.join(
                cache_dir, f"t5_curriculum_{split}_{cache_key}_data.pkl"
            )
        else:
            self.cache_file = None

        self.examples = []
        self.complexity_scores = []
        self.curriculum_levels = {}

        self._load_or_process_data()

    def __len__(self):
        """Return how many tokenized examples the dataset holds."""
        n_examples = len(self.examples)
        return n_examples

    def __getitem__(self, idx):
        """Return the stored example at position ``idx`` (as appended during processing)."""
        example = self.examples[idx]
        return example

    def _load_or_process_data(self):
        """
        Populate ``examples``, ``complexity_scores`` and ``curriculum_levels``.

        Uses the pickle cache when it exists and is well-formed. Otherwise the
        raw data is processed, a curriculum is built for a train split with more
        than 100 examples, and the result is cached. An empty result is not
        cached, so a missing or unreadable data file does not poison later runs.
        """
        if self.cache_file and os.path.exists(self.cache_file):
            try:
                print(f"📦 Loading cached data from {self.cache_file}")
                # Only load caches this class wrote itself: unpickling can run code.
                with open(self.cache_file, 'rb') as f:
                    cached_data = pickle.load(f)
                # Read into locals first so a malformed cache cannot leave the
                # dataset half-populated before we fall back to reprocessing.
                examples = cached_data['examples']
                complexity_scores = cached_data['complexity_scores']
                curriculum_levels = cached_data.get('curriculum_levels', {})
            except Exception as e:
                print(f"⚠️ Failed to load cache: {e}")
            else:
                self.examples = examples
                self.complexity_scores = complexity_scores
                self.curriculum_levels = curriculum_levels
                print(f"✅ Loaded {len(self.examples)} cached examples")
                return

        self._process_raw_data()

        if self.split == 'train' and len(self.examples) > 100:
            self._create_curriculum()

        if not self.cache_file:
            return
        if not self.examples:
            print("⚠️ No examples produced; skipping cache write")
            return

        try:
            cached_data = {
                'examples': self.examples,
                'complexity_scores': self.complexity_scores,
                'curriculum_levels': self.curriculum_levels
            }
            with open(self.cache_file, 'wb') as f:
                pickle.dump(cached_data, f)
            print(f"💾 Cached {len(self.examples)} examples")
        except Exception as e:
            print(f"⚠️ Failed to save cache: {e}")

    def _process_raw_data(self):
        """
        Read ``data_path``, split it into documents and turn each into T5 examples.

        For a ``train`` split whose analyzer is in learning mode, the analyzer is
        first fitted on early documents (see ``_fit_complexity_analyzer``). Each
        document of 20-500 words is then scored. On the train split, documents
        scoring below 0.02 are dropped. Every task built from a kept document is
        appended to ``examples``, with the document's score appended to
        ``complexity_scores``. Processing stops at ``max_examples`` examples
        (falsy = no cap). Documents that raise during processing are skipped and
        reported in a summary line instead of being dropped silently.
        """
        if not os.path.exists(self.data_path):
            print(f"❌ Data file not found: {self.data_path}")
            return

        try:
            with open(self.data_path, 'r', encoding='utf-8', errors='ignore') as f:
                text = f.read()
        except Exception as e:
            print(f"❌ Failed to read data: {e}")
            return

        documents = self._clean_and_split_text(text)
        print(f"📝 Processing {len(documents)} documents...")

        if self.complexity_analyzer.learning_mode and self.split == 'train':
            self._fit_complexity_analyzer(documents)

        limit = self.max_examples or float('inf')
        processed_count = 0
        failed_docs = 0
        first_error = None

        for doc in tqdm(documents, desc="Processing documents"):
            if processed_count >= limit:
                break

            word_count = len(doc.split())
            if word_count < 20 or word_count > 500:
                continue

            try:
                complexity_analysis = self.complexity_analyzer.analyze_text_complexity(doc)
                complexity_score = complexity_analysis['overall_complexity']

                # Skip near-trivial texts in training only.
                if self.split == 'train' and complexity_score < 0.02:
                    continue

                tasks = self._create_t5_tasks(doc)
            except Exception as e:
                failed_docs += 1
                if first_error is None:
                    first_error = e
                continue

            for task in tasks:
                if task and processed_count < limit:
                    self.examples.append(task)
                    self.complexity_scores.append(complexity_score)
                    processed_count += 1

        if failed_docs:
            print(f"⚠️ Skipped {failed_docs} documents that failed to process "
                  f"(first error: {first_error!r})")
        print(f"✅ Processed {len(self.examples)} examples")

    def _fit_complexity_analyzer(self, documents: List[str]):
        """Fit the analyzer on the 20-500-word documents among the first 2000, scored by its basic heuristic."""
        print("🔍 First pass: collecting texts for complexity learning...")
        learning_texts = []
        learning_scores = []
        for doc in tqdm(documents[:2000], desc="Collecting texts"):
            n_words = len(doc.split())
            if n_words < 20 or n_words > 500:
                continue
            learning_texts.append(doc)
            learning_scores.append(self.complexity_analyzer._calculate_basic_complexity(doc))

        if learning_texts:
            self.complexity_analyzer.learn_from_data(learning_texts, learning_scores)

    def _clean_and_split_text(self, text: str) -> List[str]:
        """
        Normalise whitespace and split raw text into documents.

        Paragraphs are separated by blank lines. Paragraphs under 20 words are
        dropped. Paragraphs over 400 words are re-chunked on sentence boundaries
        into pieces of at most ~300 words (see ``_chunk_long_paragraph``).
        """
        text = re.sub(r'\n\s*\n', '\n\n', text)
        text = re.sub(r'[ \t]+', ' ', text)
        # Remove whitespace before punctuation, but never a newline: a plain `\s+`
        # would also swallow the blank line before a paragraph that starts with
        # punctuation and silently merge two paragraphs.
        text = re.sub(r'[^\S\n]+([,.!?;:])', r'\1', text)

        documents = []
        for para in text.split('\n\n'):
            para = para.strip()
            n_words = len(para.split())
            if n_words < 20:
                continue
            if n_words > 400:
                documents.extend(self._chunk_long_paragraph(para))
            else:
                documents.append(para)

        return documents

    @staticmethod
    def _chunk_long_paragraph(para: str, max_words: int = 300, min_words: int = 20) -> List[str]:
        """Greedily pack sentences into chunks of <= max_words words (a longer sentence forms its own chunk); a trailing chunk under min_words words is dropped."""
        chunks = []
        current_sentences = []
        current_length = 0

        for sent in re.split(r'(?<=[.!?])\s+', para):
            sent_length = len(sent.split())
            if current_length + sent_length > max_words and current_sentences:
                chunks.append(' '.join(current_sentences))
                current_sentences = [sent]
                current_length = sent_length
            else:
                current_sentences.append(sent)
                current_length += sent_length

        if current_sentences and current_length >= min_words:
            chunks.append(' '.join(current_sentences))

        return chunks

    def _create_t5_tasks(self, text: str) -> List[Dict]:
        """
        Build up to three T5 examples from one document, in this order.

        1. Span corruption of the whole text.
        2. "summarize: <text>", with the lead sentence(s) as the summary: one
           sentence for 3-5 sentences, otherwise the first third.
        3. "complete: <first half of the sentences>", with the second half as
           the target.

        Sentences with fewer than 3 words are ignored. A document with fewer
        than 2 remaining sentences yields nothing, and tasks 2-3 need at least 3.
        Examples that fail to build (falsy results) are omitted.
        """
        kept_sentences = [
            piece.strip()
            for piece in re.split(r'(?<=[.!?])\s+', text)
            if len(piece.split()) >= 3
        ]
        n_sentences = len(kept_sentences)

        if n_sentences < 2:
            return []

        candidates = [self._create_span_corruption_task(text)]

        if n_sentences >= 3:
            lead_count = 1 if n_sentences <= 5 else max(1, n_sentences // 3)
            summary = ' '.join(kept_sentences[:lead_count])
            candidates.append(
                self._tokenize_example(f"summarize: {text}", summary, 'summarization')
            )

            half = n_sentences // 2
            prefix = ' '.join(kept_sentences[:half])
            continuation = ' '.join(kept_sentences[half:])
            candidates.append(
                self._tokenize_example(f"complete: {prefix}", continuation, 'completion')
            )

        return [candidate for candidate in candidates if candidate]

    def _create_span_corruption_task(self, text: str) -> Optional[Dict]:
        """
        Build a T5-style span-corruption example from ``text``.

        Walks the tokenizer's tokens left to right. At each position a span is
        masked with probability 0.3, with length ``int(expovariate(1/3))``
        clipped to [1, 5]. Masking stops once about ``corruption_probability`` of
        the tokens are masked or the sentinel budget is used up. Each masked span
        becomes ``<extra_id_k>`` in the source. The target lists
        ``<extra_id_k> <span text>`` pairs followed by a closing sentinel.

        Returns the tokenized example, or None when the text has fewer than 10
        tokens, when no span was masked, or on any error.
        """
        # T5 tokenizers provide 100 sentinels (<extra_id_0>..<extra_id_99>) by
        # default; keep one free for the closing sentinel of the target.
        max_spans = 99

        try:
            tokens = self.tokenizer.tokenize(text)

            if len(tokens) < 10:
                return None

            num_tokens_to_corrupt = max(1, int(len(tokens) * self.corruption_probability))
            corrupted_tokens = tokens.copy()
            target_spans = []
            sentinel_count = 0
            tokens_corrupted = 0

            i = 0
            while (i < len(corrupted_tokens)
                   and tokens_corrupted < num_tokens_to_corrupt
                   and sentinel_count < max_spans):
                if random.random() < 0.3:
                    span_length = max(1, min(5, int(random.expovariate(1.0 / 3.0))))
                    span_end = min(i + span_length, len(corrupted_tokens))

                    span_tokens = corrupted_tokens[i:span_end]
                    sentinel_token = f'<extra_id_{sentinel_count}>'
                    # Detokenize the span (as the source side does); joining raw
                    # SentencePiece pieces with spaces would leak the '▁'
                    # word-boundary markers into the target text.
                    span_text = self.tokenizer.convert_tokens_to_string(span_tokens)
                    target_spans.append(f'{sentinel_token} {span_text}')

                    corrupted_tokens[i:span_end] = [sentinel_token]

                    tokens_corrupted += len(span_tokens)
                    sentinel_count += 1
                # Step past the current token or the sentinel that replaced the span.
                i += 1

            if not target_spans:
                # Nothing was masked: the example would teach an empty target.
                return None

            source_text = self.tokenizer.convert_tokens_to_string(corrupted_tokens)
            target_text = ' '.join(target_spans) + f' <extra_id_{sentinel_count}>'

            return self._tokenize_example(source_text, target_text, 'span_corruption')

        except Exception:
            return None

    def _tokenize_example(self, source_text: str, target_text: str, task_type: str) -> Optional[Dict]:
        """Tokenize example with proper error handling"""

        try:
            source_text = source_text.strip()
            target_text = target_text.strip()

            if not source_text or not target_text:
                return None

            source_encoding = self.tokenizer(
                source_text,
                max_length=self.max_source_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            target_encoding = self.tokenizer(
                target_text,
                max_length=self.max_target_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )

            return {
                'input_ids': source_encoding['input_ids'].squeeze(0),
                'attention_mask': source_encoding['attention_mask'].squeeze(0),
                'labels': target_encoding['input_ids'].squeeze(0),
                'source_text': source_text,
                'target_text': target_text,
                'task_type': task_type
            }

        except Exception as e:
            return None

    def _create_curriculum(self):
        """Create 5-level curriculum based on complexity"""

        print("📚 Creating curriculum...")

        # Sort by complexity
        complexity_indices = sorted(
            range(len(self.examples)),
            key=lambda i: self.complexity_scores[i]
        )

        # Create 5 levels with overlap for smooth transitions
        num_levels = 5
        level_size = len(complexity_indices) // num_levels
        overlap_size = level_size // 4  # 25% overlap

        for level in range(num_levels):
            start_idx = max(0, level * level_size - overlap_size)
            if level == 0:
                start_idx = 0

            end_idx = min(len(complexity_indices), (level + 1) * level_size + overlap_size)
            if level == num_levels - 1:
                end_idx = len(complexity_indices)

            level_indices = complexity_indices[start_idx:end_idx]
            self.curriculum_levels[level] = level_indices

            # Statistics
            level_complexities = [self.complexity_scores[i] for i in level_indices]
            if level_complexities:
                print(f"Level {level}: {len(level_indices)} examples, "
                      f"complexity: {min(level_complexities):.3f}-{max(level_complexities):.3f}")

    def get_curriculum_level_data(self, level: int) -> List[int]:
        """Return the example indices assigned to curriculum ``level``.

        Unknown levels fall back to every example index.
        """
        if level in self.curriculum_levels:
            return self.curriculum_levels[level]
        return list(range(len(self.examples)))

# =============================================================================
# ADAPTIVE CURRICULUM SAMPLER
# =============================================================================

class AdaptiveCurriculumSampler(Sampler):
    """Curriculum sampler whose index pool grows with training progress.

    The epoch's index list is drawn once in ``__init__`` from the module-level
    ``random`` state and replayed by ``__iter__``. Build a new instance for
    every epoch to get a fresh draw.
    """

    def __init__(self, dataset: T5CurriculumDataset, current_epoch: int, max_epochs: int,
                 strategy: str = 'progressive'):
        self.dataset = dataset
        self.current_epoch = current_epoch
        self.max_epochs = max_epochs
        self.strategy = strategy

        # Fraction of training completed, clipped to at most 1.0
        denominator = max(max_epochs - 1, 1)
        self.progress = min(current_epoch / denominator, 1.0)

        self.active_levels = self._get_active_levels()
        self.active_indices = self._get_active_indices()

        print(f"📚 Curriculum Epoch {current_epoch}: Levels {self.active_levels}, "
              f"{len(self.active_indices)} examples")

    def _get_active_levels(self) -> List[int]:
        """Return the curriculum levels enabled at ``self.progress``."""
        p = self.progress

        if self.strategy == 'progressive':
            # One more level unlocks at each of these progress marks (inclusive)
            unlocked = sum(1 for mark in (0.2, 0.4, 0.6, 0.8) if p >= mark)
            return list(range(unlocked + 1))

        if self.strategy == 'mixed':
            # Levels 0 and 1 are always on; harder ones join once progress passes each mark
            extra = [lvl for lvl, mark in ((2, 0.3), (3, 0.6), (4, 0.8)) if p > mark]
            return [0, 1] + extra

        # 'uniform' (and any unrecognised strategy): every level
        return list(range(5))

    def _get_active_indices(self) -> List[int]:
        """Draw this epoch's shuffled index list from the active levels.

        With several active levels, each level contributes the fraction
        ``weight / total_weight`` of its own examples. Lower-numbered (easier)
        levels get larger weights. A single active level is used in full.
        """
        levels = self.active_levels
        n_active = len(levels)

        def level_weight(lvl):
            return 1.0 + (n_active - lvl - 1) * 0.2

        total_weight = sum(level_weight(other) for other in levels) if n_active > 1 else None

        picked = []
        for lvl in levels:
            pool = self.dataset.get_curriculum_level_data(lvl)

            if n_active > 1:
                quota = int(len(pool) * level_weight(lvl) / total_weight)
                quota = min(quota, len(pool))
            else:
                quota = len(pool)

            if quota > 0:
                picked.extend(random.sample(pool, quota))

        random.shuffle(picked)
        return picked

    def __iter__(self):
        return iter(self.active_indices)

    def __len__(self):
        return len(self.active_indices)

# =============================================================================
# T5 TRAINER WITH CURRICULUM LEARNING
# =============================================================================

class T5CurriculumTrainer:
    """T5 trainer that feeds each epoch through an ``AdaptiveCurriculumSampler``.

    Owns the AdamW optimizer and the linear warmup/decay scheduler. Runs the
    train/validation loop with patience-based early stopping, and writes the
    best checkpoint (by validation loss) and the final model under ``save_dir``.
    """

    def __init__(self, model: T5ForConditionalGeneration, tokenizer: T5Tokenizer,
                 save_dir: str, learning_rate: float = 1e-4, weight_decay: float = 0.01):
        """Set up the optimizer and empty training statistics.

        Args:
            model: Model to train (moved to the target device in ``train``).
            tokenizer: Saved with every checkpoint; its ``pad_token_id`` is
                used to mask label padding.
            save_dir: Output directory, created if missing.
            learning_rate: Peak learning rate, reached at the end of warmup.
            weight_decay: Weight decay for every parameter except biases and
                normalisation weights.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.save_dir = save_dir

        os.makedirs(save_dir, exist_ok=True)

        # Name substrings of parameters excluded from weight decay.
        # HF T5 names its norms `layer_norm` / `final_layer_norm` rather than
        # `LayerNorm`, so "layer_norm.weight" is required for the exclusion to
        # apply. "bias" also matches T5's `relative_attention_bias` embeddings.
        no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
        decay_params, no_decay_params = [], []
        for name, param in model.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        self.optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
        self.scheduler = None  # built in train(), once the real step count is known
        self.global_step = 0

        # One entry per completed epoch
        self.training_stats = {
            'epoch_losses': [],
            'val_losses': [],
            'curriculum_levels': []
        }

        print("✅ T5 Curriculum Trainer initialized")

    def train(self, train_dataset: T5CurriculumDataset, val_dataset: Optional[T5CurriculumDataset] = None,
              num_epochs: int = 3, batch_size: int = 8, curriculum_strategy: str = 'progressive',
              device: str = 'cuda', patience: int = 3):
        """Run curriculum training with optional validation and early stopping.

        Args:
            train_dataset: Training set exposing ``get_curriculum_level_data``.
            val_dataset: Optional validation set. It drives best-model
                checkpointing and early stopping.
            num_epochs: Maximum number of epochs.
            batch_size: Batch size for training and validation.
            curriculum_strategy: 'progressive', 'mixed' or 'uniform'. Any other
                value behaves like 'uniform'.
            device: Device the model and batches are moved to.
            patience: Number of consecutive non-improving validation epochs
                tolerated before stopping.

        Returns:
            ``self.training_stats``.
        """
        print(f"🚀 Starting T5 curriculum training")
        print(f"   Epochs: {num_epochs}, Batch size: {batch_size}")
        print(f"   Strategy: {curriculum_strategy}")

        # Each curriculum sampler yields only a subset of train_dataset, so
        # len(train_dataset) // batch_size overstates the steps per epoch. That
        # stretches warmup and leaves the LR barely decayed. Draw every epoch's
        # curriculum up front and size the schedule from the real batch counts.
        print("📚 Planning curriculum for all epochs...")
        epoch_samplers = [
            AdaptiveCurriculumSampler(train_dataset, epoch, num_epochs, curriculum_strategy)
            for epoch in range(num_epochs)
        ]
        total_steps = max(1, sum(math.ceil(len(sampler) / batch_size)
                                 for sampler in epoch_samplers))
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=total_steps // 10, num_training_steps=total_steps
        )

        self.model.to(device)

        best_val_loss = float('inf')
        patience_counter = 0
        has_val = val_dataset is not None and len(val_dataset) > 0

        for epoch, curriculum_sampler in enumerate(epoch_samplers):
            print(f"\n{'='*50}")
            print(f"EPOCH {epoch + 1}/{num_epochs}")
            print(f"{'='*50}")

            train_loader = DataLoader(
                train_dataset,
                batch_size=batch_size,
                sampler=curriculum_sampler,
                collate_fn=self._collate_fn,
                num_workers=0
            )

            epoch_loss = self._train_epoch(train_loader, device)

            val_loss = None
            if has_val:
                val_loss = self._evaluate(val_dataset, device, batch_size)

            # Record stats before any early-stopping break so the last epoch is kept
            self.training_stats['epoch_losses'].append(epoch_loss)
            self.training_stats['val_losses'].append(val_loss or 0)
            self.training_stats['curriculum_levels'].append(curriculum_sampler.active_levels)

            print(f"📊 Epoch {epoch + 1} - Train Loss: {epoch_loss:.4f}")
            if val_loss is not None:
                print(f"📊 Epoch {epoch + 1} - Val Loss: {val_loss:.4f}")

            if val_loss is not None:
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience_counter = 0
                    self._save_checkpoint(epoch, is_best=True)
                else:
                    patience_counter += 1

                if patience_counter >= patience:
                    print(f"🛑 Early stopping at epoch {epoch + 1}")
                    break

        self._save_final_model()
        return self.training_stats

    def _train_epoch(self, train_loader: DataLoader, device: str) -> float:
        """Run one optimisation pass over ``train_loader``.

        Gradients are clipped to norm 1.0. The optimizer and scheduler step once
        per batch.

        Returns:
            Mean per-batch loss, or 0.0 when the loader is empty.
        """
        self.model.train()
        epoch_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(train_loader, desc='Training')

        for batch in progress_bar:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()}

            outputs = self.model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels']
            )

            loss = outputs.loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

            epoch_loss += loss.item()
            num_batches += 1
            self.global_step += 1

            progress_bar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Avg': f'{epoch_loss/num_batches:.4f}'
            })

        return epoch_loss / max(num_batches, 1)

    def _evaluate(self, val_dataset: T5CurriculumDataset, device: str, batch_size: int = 8,
                  max_eval_examples: int = 500) -> float:
        """Return the mean per-batch loss on the first ``max_eval_examples`` validation examples.

        The subset and its order are fixed, so losses from different epochs
        are directly comparable for early stopping.
        """
        self.model.eval()

        val_indices = list(range(min(max_eval_examples, len(val_dataset))))

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            sampler=val_indices,
            collate_fn=self._collate_fn
        )

        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc='Validation', leave=False):
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                         for k, v in batch.items()}

                outputs = self.model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=batch['labels']
                )

                total_loss += outputs.loss.item()
                num_batches += 1

        return total_loss / max(num_batches, 1)

    def _collate_fn(self, batch):
        """Stack per-example tensors into a batch.

        Label positions equal to the pad id are set to -100 so the loss
        ignores them.
        """
        input_ids = torch.stack([item['input_ids'] for item in batch])
        attention_mask = torch.stack([item['attention_mask'] for item in batch])
        labels = torch.stack([item['labels'] for item in batch])

        # torch.stack returns a new tensor, so the dataset's own labels are untouched
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'task_types': [item['task_type'] for item in batch]
        }

    def _save_checkpoint(self, epoch: int, is_best: bool = False):
        """Save model, tokenizer and optimizer/scheduler state.

        With ``is_best`` the files go to ``save_dir/best_model`` (overwriting
        the previous best). Otherwise they go to ``save_dir/checkpoint-epoch-{epoch}``.
        """
        checkpoint_dir = os.path.join(self.save_dir, f'checkpoint-epoch-{epoch}')
        if is_best:
            checkpoint_dir = os.path.join(self.save_dir, 'best_model')

        os.makedirs(checkpoint_dir, exist_ok=True)

        self.model.save_pretrained(checkpoint_dir)
        self.tokenizer.save_pretrained(checkpoint_dir)

        training_state = {
            'epoch': epoch,
            'global_step': self.global_step,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'training_stats': self.training_stats
        }

        torch.save(training_state, os.path.join(checkpoint_dir, 'training_state.pt'))

    def _save_final_model(self):
        """Save the current (last-epoch) model, the tokenizer and the stats JSON to ``save_dir/final_model``."""
        final_dir = os.path.join(self.save_dir, 'final_model')
        os.makedirs(final_dir, exist_ok=True)

        self.model.save_pretrained(final_dir)
        self.tokenizer.save_pretrained(final_dir)

        with open(os.path.join(final_dir, 'training_stats.json'), 'w') as f:
            json.dump(self.training_stats, f, indent=2)

        print(f"🏆 Final model saved to {final_dir}")

# =============================================================================
# CONFIGURATION AND MAIN TRAINING FUNCTION
# =============================================================================

class TrainingConfig:
    """Attribute bag holding every knob of the T5 curriculum run.

    All values are plain attributes set in ``__init__``. Override them by
    assignment after construction.
    """

    def __init__(self):
        project_root = '/content/drive/MyDrive/llm-project/t5-small-new-base_datapreparation'

        # --- model ---
        self.model_name = 't5-small'
        self.max_source_length = 512  # max tokens of the source (input) sequence
        self.max_target_length = 256  # max tokens of the target sequence

        # --- optimisation ---
        self.num_epochs = 5
        self.batch_size = 4
        self.learning_rate = 1e-4
        self.weight_decay = 0.01

        # --- curriculum ---
        self.curriculum_strategy = 'progressive'  # one of 'progressive', 'mixed', 'uniform'
        self.corruption_probability = 0.15  # fraction of tokens targeted by span corruption

        # --- data (None = no limit) ---
        self.max_train_examples = None
        self.max_val_examples = None
        self.patience = 3  # early-stopping patience, in epochs

        # --- paths: UPDATE THESE FOR YOUR SETUP ---
        self.data_dir = '/content/baseline-pretraining/babylm_data'
        self.save_dir = f'{project_root}/t5_curriculum_training'
        self.cache_dir = f'{project_root}/full_t5_curriculum_cache'

        # --- hardware ---
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

def create_datasets(config: TrainingConfig):
    """Build the train/dev ``T5CurriculumDataset`` objects described by ``config``.

    One tokenizer and one ``HybridComplexityAnalyzer`` are shared by both
    splits. A split is left out of the result when its text file is missing
    or the dataset comes out empty.

    Returns:
        ``(datasets, tokenizer, complexity_analyzer)``, where ``datasets`` maps
        'train' / 'dev' to the built datasets.
    """
    print("🔧 Creating datasets...")

    tokenizer = T5Tokenizer.from_pretrained(config.model_name)

    complexity_analyzer = HybridComplexityAnalyzer(
        tokenizer=tokenizer,
        save_dir=os.path.join(config.cache_dir, 'complexity_models'),
        learning_mode=True,  # analyzer is allowed to learn from the data it sees
    )

    # (split name, file under data_dir, example limit)
    split_specs = (
        ('train', 'babylm_train.txt', config.max_train_examples),
        ('dev', 'babylm_dev.txt', config.max_val_examples),
    )
    dataset_cache = os.path.join(config.cache_dir, 'dataset_cache')

    datasets = {}
    for split, filename, limit in split_specs:
        path = os.path.join(config.data_dir, filename)

        if not os.path.exists(path):
            print(f"⚠️ File not found: {path}")
            continue

        candidate = T5CurriculumDataset(
            data_path=path,
            tokenizer=tokenizer,
            complexity_analyzer=complexity_analyzer,
            max_source_length=config.max_source_length,
            max_target_length=config.max_target_length,
            cache_dir=dataset_cache,
            split=split,
            max_examples=limit,
            corruption_probability=config.corruption_probability
        )

        if len(candidate) == 0:
            continue

        datasets[split] = candidate
        print(f"✅ Created {split} dataset: {len(candidate)} examples")

    return datasets, tokenizer, complexity_analyzer

def train_t5_with_curriculum(config: TrainingConfig = None):
    """Build datasets, initialise a randomly-weighted T5 and run curriculum training.

    Args:
        config: Run configuration; ``TrainingConfig()`` defaults when ``None``.

    Returns:
        Dict with keys 'model', 'tokenizer', 'complexity_analyzer',
        'training_stats' and 'config'.

    Raises:
        ValueError: If no training dataset could be built.
    """
    if config is None:
        config = TrainingConfig()

    print(f"🚀 T5 Curriculum Training")
    print(f"   Model: {config.model_name}")
    print(f"   Device: {config.device}")
    print(f"   Strategy: {config.curriculum_strategy}")

    os.makedirs(config.save_dir, exist_ok=True)
    os.makedirs(config.cache_dir, exist_ok=True)

    # Datasets first: they are the slow, failure-prone step
    print("📊 Creating datasets...")
    datasets, tokenizer, complexity_analyzer = create_datasets(config)

    # Fail fast, before allocating the model
    if 'train' not in datasets:
        raise ValueError("❌ No training dataset available!")

    # Architecture from config.model_name (the same checkpoint the tokenizer came
    # from). Weights are randomly initialised, i.e. training from scratch.
    print("🔤 Initializing model with random weights...")
    model_config = T5Config.from_pretrained(config.model_name)
    model = T5ForConditionalGeneration(model_config)

    print(f"📊 Dataset sizes:")
    for split, dataset in datasets.items():
        print(f"   {split}: {len(dataset)} examples")

    trainer = T5CurriculumTrainer(
        model=model,
        tokenizer=tokenizer,
        save_dir=config.save_dir,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay
    )

    training_stats = trainer.train(
        train_dataset=datasets['train'],
        val_dataset=datasets.get('dev'),
        num_epochs=config.num_epochs,
        batch_size=config.batch_size,
        curriculum_strategy=config.curriculum_strategy,
        device=config.device,
        patience=config.patience
    )

    print("🏁 Training completed!")

    return {
        'model': model,
        'tokenizer': tokenizer,
        'complexity_analyzer': complexity_analyzer,
        'training_stats': training_stats,
        'config': config
    }

# =============================================================================
# EXAMPLE USAGE
# =============================================================================

if __name__ == "__main__":
    print("🎯 T5 Curriculum Learning with Hybrid Complexity Analysis")
    print("=" * 60)

    config = TrainingConfig()

    # Paths for the Colab layout; save/cache dirs presumably sit on a mounted
    # Google Drive (/content/drive). Adjust for your environment.
    config.data_dir = '/content/baseline-pretraining/babylm_data'
    config.save_dir = '/content/drive/MyDrive/llm-project/t5-small-new-base_datapreparation/t5_curriculum_training'
    config.cache_dir = '/content/drive/MyDrive/llm-project/t5-small-new-base_datapreparation/full_t5_curriculum_cache'

    # Training parameters
    config.num_epochs = 5
    config.batch_size = 4  # Adjust based on your GPU memory
    config.max_train_examples = None  # None = every training example; set an int for a quick test run
    config.max_val_examples = 5000

    try:
        results = train_t5_with_curriculum(config)

        print("\n✅ Training completed successfully!")
        print(f"📁 Model saved to: {config.save_dir}")

        stats = results['training_stats']
        print(f"\n📊 Training Statistics:")
        # Guard: no epoch stats exist if num_epochs was 0
        if stats['epoch_losses']:
            print(f"   Final train loss: {stats['epoch_losses'][-1]:.4f}")
        if stats['val_losses'] and stats['val_losses'][-1] > 0:
            print(f"   Final val loss: {stats['val_losses'][-1]:.4f}")
        print(f"   Curriculum progression: {stats['curriculum_levels']}")

    except Exception as e:
        # Report instead of crashing the notebook cell; the full traceback is still shown
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()