In [1]:
"""
Complete Setup and Installation Script for Fitness Q&A Model Training
This script handles all installations and imports needed for the project
"""

import subprocess
import sys
import os
from pathlib import Path

def install_package(package_name, upgrade=False):
    """Install a package using pip"""
    try:
        cmd = [sys.executable, "-m", "pip", "install"]
        if upgrade:
            cmd.append("--upgrade")
        cmd.append(package_name)
        
        print(f"Installing {package_name}...")
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
        print(f"✓ Successfully installed {package_name}")
        return True
    except subprocess.CalledProcessError as e:
        print(f"✗ Failed to install {package_name}: {e}")
        print(f"Error output: {e.stderr}")
        return False

def check_and_install_packages():
    """Check and install all required packages"""
    
    # Core packages with specific versions for compatibility
    # Note: PyArrow version is critical for datasets compatibility
    packages = [
        "torch>=2.0.0",
        "pyarrow>=12.0.0,<18.0.0",  # Fix PyArrow compatibility issue
        "transformers>=4.30.0",
        "datasets>=2.12.0",
        "pandas>=1.5.0",
        "numpy>=1.21.0",
        "scikit-learn>=1.2.0",
        "evaluate>=0.4.0",
        "nltk>=3.8.0",
        "accelerate>=0.20.0",  # For better training performance
        "sentencepiece>=0.1.99",  # Required for T5 tokenizer
        "protobuf>=3.20.0",  # Required for datasets
        "tqdm>=4.64.0",  # Progress bars
        "requests>=2.28.0",  # For downloading datasets
        "huggingface_hub>=0.15.0",  # For model/dataset downloads
    ]
    
    print("=" * 60)
    print("INSTALLING REQUIRED PACKAGES")
    print("=" * 60)
    
    failed_packages = []
    
    # First, upgrade pip itself
    print("Upgrading pip...")
    try:
        subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "pip"], 
                      check=True, capture_output=True)
        print("✓ pip upgraded successfully")
    except:
        print("⚠ Could not upgrade pip, continuing anyway...")
    
    # Special handling for PyArrow compatibility issue
    print("\n🔧 Fixing PyArrow compatibility issue...")
    try:
        # Uninstall potentially conflicting PyArrow versions
        subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "pyarrow"], 
                      capture_output=True)
        # Install compatible version
        subprocess.run([sys.executable, "-m", "pip", "install", "pyarrow>=12.0.0,<18.0.0"], 
                      check=True, capture_output=True)
        print("✓ PyArrow compatibility fixed")
    except Exception as e:
        print(f"⚠ Could not fix PyArrow: {e}")
    
    # Install packages
    for package in packages:
        if not install_package(package):
            failed_packages.append(package)
    
    # Install optional packages for better performance
    optional_packages = [
        "tensorboard",  # For monitoring training
        "matplotlib",   # For plotting if needed
        "seaborn",      # For better plots
    ]
    
    print("\n" + "=" * 60)
    print("INSTALLING OPTIONAL PACKAGES")
    print("=" * 60)
    
    for package in optional_packages:
        install_package(package)  # Don't track failures for optional packages
    
    if failed_packages:
        print(f"\n⚠ Warning: Failed to install: {', '.join(failed_packages)}")
        print("You may need to install these manually or check your environment.")
    else:
        print("\n✓ All required packages installed successfully!")
    
    return len(failed_packages) == 0

def fix_pyarrow_issue():
    """Fix the specific PyArrow compatibility issue"""
    print("\n" + "=" * 60)
    print("FIXING PYARROW COMPATIBILITY ISSUE")
    print("=" * 60)
    
    try:
        # Step 1: Uninstall problematic packages
        packages_to_remove = ["pyarrow", "datasets"]
        for package in packages_to_remove:
            try:
                subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", package], 
                              capture_output=True, check=False)
                print(f"✓ Uninstalled {package}")
            except:
                print(f"⚠ Could not uninstall {package} (may not be installed)")
        
        # Step 2: Install compatible PyArrow first
        compatible_pyarrow = "pyarrow>=12.0.0,<15.0.0"
        if install_package(compatible_pyarrow):
            print("✓ Installed compatible PyArrow version")
        else:
            print("✗ Failed to install compatible PyArrow")
            return False
        
        # Step 3: Install datasets with the compatible PyArrow
        if install_package("datasets>=2.12.0"):
            print("✓ Installed datasets with compatible PyArrow")
        else:
            print("✗ Failed to install datasets")
            return False
        
        # Step 4: Test the fix
        try:
            import pyarrow
            import datasets
            print(f"✓ PyArrow version: {pyarrow.__version__}")
            print(f"✓ Datasets version: {datasets.__version__}")
            
            # Test basic functionality
            from datasets import Dataset
            test_data = {"text": ["hello", "world"], "label": [1, 0]}
            test_dataset = Dataset.from_dict(test_data)
            print("✓ Datasets functionality test passed")
            
            return True
            
        except Exception as e:
            print(f"✗ Test failed: {e}")
            return False
            
    except Exception as e:
        print(f"✗ PyArrow fix failed: {e}")
        return False
def verify_installations(skip_datasets=False):
    """Verify that all packages can be imported correctly"""
    
    print("\n" + "=" * 60)
    print("VERIFYING INSTALLATIONS")
    print("=" * 60)
    
    imports_to_test = [
        ("torch", "PyTorch"),
        ("transformers", "Transformers"),
        ("pandas", "Pandas"),
        ("numpy", "NumPy"),
        ("sklearn", "Scikit-learn"),
        ("evaluate", "Evaluate"),
        ("nltk", "NLTK"),
        ("accelerate", "Accelerate"),
    ]
    
    # Add datasets only if PyArrow was fixed
    if not skip_datasets:
        imports_to_test.append(("datasets", "Datasets"))
    
    failed_imports = []
    
    for module_name, display_name in imports_to_test:
        try:
            __import__(module_name)
            print(f"✓ {display_name} imported successfully")
        except ImportError as e:
            print(f"✗ Failed to import {display_name}: {e}")
            failed_imports.append(display_name)
        except AttributeError as e:
            if "pyarrow" in str(e).lower():
                print(f"✗ {display_name} failed due to PyArrow compatibility: {e}")
                failed_imports.append(display_name)
            else:
                print(f"✗ Failed to import {display_name}: {e}")
                failed_imports.append(display_name)
    
    # Special checks
    try:
        import torch
        print(f"✓ PyTorch version: {torch.__version__}")
        print(f"✓ CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"✓ CUDA device count: {torch.cuda.device_count()}")
    except:
        print("✗ Could not get PyTorch details")
    
    try:
        import transformers
        print(f"✓ Transformers version: {transformers.__version__}")
    except:
        print("✗ Could not get Transformers version")
    
    if not skip_datasets:
        try:
            import pyarrow
            import datasets
            print(f"✓ PyArrow version: {pyarrow.__version__}")
            print(f"✓ Datasets version: {datasets.__version__}")
        except:
            print("✗ Could not get PyArrow/Datasets versions")
    
    if failed_imports:
        print(f"\n⚠ Warning: Failed to import: {', '.join(failed_imports)}")
        return False
    else:
        print("\n✓ All packages verified successfully!")
        return True

def download_nltk_data():
    """Download required NLTK data"""
    print("\n" + "=" * 60)
    print("DOWNLOADING NLTK DATA")
    print("=" * 60)
    
    try:
        import nltk
        
        # Download required NLTK data
        nltk_downloads = [
            'punkt',
            'stopwords',
            'wordnet',
            'omw-1.4'
        ]
        
        for item in nltk_downloads:
            try:
                nltk.download(item, quiet=True)
                print(f"✓ Downloaded NLTK {item}")
            except Exception as e:
                print(f"⚠ Could not download NLTK {item}: {e}")
                
    except ImportError:
        print("✗ NLTK not available for downloading data")

def setup_environment():
    """Set up environment variables and create directories"""
    print("\n" + "=" * 60)
    print("SETTING UP ENVIRONMENT")
    print("=" * 60)
    
    # Set environment variables
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ["TRANSFORMERS_CACHE"] = "./cache/transformers"
    os.environ["HF_DATASETS_CACHE"] = "./cache/datasets"
    print("✓ Environment variables set")
    
    # Create necessary directories
    directories = [
        "cache",
        "cache/transformers", 
        "cache/datasets",
        "processed_data",
        "models",
        "logs"
    ]
    
    for directory in directories:
        Path(directory).mkdir(parents=True, exist_ok=True)
        print(f"✓ Created directory: {directory}")

def main():
    """Main setup function"""
    print("🚀 Starting complete setup for Fitness Q&A Model Training")
    print("This may take several minutes...")
    
    # Step 0: Fix PyArrow issue first (critical fix)
    print("\n🔧 Applying critical compatibility fixes...")
    pyarrow_fixed = fix_pyarrow_issue()
    
    # Step 1: Install packages
    packages_ok = check_and_install_packages()
    
    # Step 2: Verify installations (skip datasets if PyArrow fix failed)
    imports_ok = verify_installations(skip_datasets=not pyarrow_fixed)
    
    # Step 3: Download NLTK data
    download_nltk_data()
    
    # Step 4: Setup environment
    setup_environment()
    
    # Final status
    print("\n" + "=" * 60)
    print("SETUP COMPLETE")
    print("=" * 60)
    
    if packages_ok and imports_ok:
        print("✅ Setup completed successfully!")
        print("You can now run the fitness Q&A training script.")
        
        # Test basic functionality
        print("\n🧪 Running quick functionality test...")
        try:
            import torch
            from transformers import AutoTokenizer
            
            # Test tokenizer loading
            tokenizer = AutoTokenizer.from_pretrained('t5-small')
            print("✓ T5 tokenizer loads correctly")
            
            # Test basic tokenization
            test_text = "question: How can I improve my fitness?"
            tokens = tokenizer(test_text, return_tensors="pt")
            print("✓ Tokenization works correctly")
            
            # Test datasets if PyArrow was fixed
            if pyarrow_fixed:
                from datasets import Dataset
                test_data = {"question": ["test?"], "answer": ["test answer"]}
                test_ds = Dataset.from_dict(test_data)
                print("✓ Datasets functionality works correctly")
            
            print("✅ All functionality tests passed!")
            
        except Exception as e:
            print(f"⚠ Functionality test failed: {e}")
            print("You may need to restart your Python session.")
    
    else:
        print("❌ Setup completed with some issues.")
        print("Please check the error messages above and resolve any missing packages.")
        if not pyarrow_fixed:
            print("\n🔧 PyArrow compatibility issue detected!")
            print("Try running this fix manually:")
            print("pip uninstall -y pyarrow datasets")
            print("pip install 'pyarrow>=12.0.0,<15.0.0'")
            print("pip install 'datasets>=2.12.0'")
    
    print("\n📋 Next steps:")
    print("1. If setup was successful, you can now run the main training script")
    print("2. If there were issues, please install missing packages manually")
    print("3. Consider restarting your Python kernel/session after installation")
    print("4. If PyArrow issues persist, try the manual fix commands above")

if __name__ == "__main__":
    main()


# =============================================================================
# IMPORTS SECTION - All imports needed for the main script
# =============================================================================

print("\n" + "=" * 60)
print("TESTING ALL IMPORTS FOR MAIN SCRIPT")
print("=" * 60)

try:
    # Standard library imports
    import os
    import re
    import logging
    from pathlib import Path
    from typing import Dict, List, Optional, Tuple
    print("✓ Standard library imports successful")
    
    # Data handling
    import pandas as pd
    import numpy as np
    print("✓ Data handling imports successful")
    
    # Machine learning
    from sklearn.model_selection import train_test_split
    print("✓ Scikit-learn imports successful")
    
    # PyTorch
    import torch
    print("✓ PyTorch import successful")
    
    # NLTK
    import nltk
    print("✓ NLTK import successful")
    
    # Hugging Face
    from datasets import Dataset, load_dataset
    from transformers import (
        AutoTokenizer, 
        T5ForConditionalGeneration, 
        Trainer, 
        TrainingArguments,
        EarlyStoppingCallback
    )
    print("✓ Transformers imports successful")
    
    # Evaluation
    import evaluate
    print("✓ Evaluate import successful")
    
    print("\n✅ ALL IMPORTS SUCCESSFUL - Ready to run main script!")
    
except ImportError as e:
    print(f"\n❌ Import failed: {e}")
    print("Please run the setup section above to install missing packages.")

# =============================================================================
# SYSTEM INFO
# =============================================================================

def print_system_info():
    """Print system information for debugging"""
    print("\n" + "=" * 60)
    print("SYSTEM INFORMATION")
    print("=" * 60)
    
    print(f"Python version: {sys.version}")
    print(f"Platform: {sys.platform}")
    
    try:
        import torch
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA version: {torch.version.cuda}")
            print(f"GPU count: {torch.cuda.device_count()}")
            for i in range(torch.cuda.device_count()):
                print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    except:
        pass
    
    try:
        import transformers
        print(f"Transformers version: {transformers.__version__}")
    except:
        pass
    
    try:
        import datasets
        print(f"Datasets version: {datasets.__version__}")
    except:
        pass

print_system_info()

🚀 Starting complete setup for Fitness Q&A Model Training
This may take several minutes...

🔧 Applying critical compatibility fixes...

FIXING PYARROW COMPATIBILITY ISSUE
✓ Uninstalled pyarrow
✓ Uninstalled datasets
Installing pyarrow>=12.0.0,<15.0.0...
✓ Successfully installed pyarrow>=12.0.0,<15.0.0
✓ Installed compatible PyArrow version
Installing datasets>=2.12.0...
✓ Successfully installed datasets>=2.12.0
✓ Installed datasets with compatible PyArrow
✓ PyArrow version: 20.0.0
✓ Datasets version: 3.6.0
✓ Datasets functionality test passed
INSTALLING REQUIRED PACKAGES
Upgrading pip...
✓ pip upgraded successfully

🔧 Fixing PyArrow compatibility issue...
✓ PyArrow compatibility fixed
Installing torch>=2.0.0...
✓ Successfully installed torch>=2.0.0
Installing pyarrow>=12.0.0,<18.0.0...
✓ Successfully installed pyarrow>=12.0.0,<18.0.0
Installing transformers>=4.30.0...
✓ Successfully installed transformers>=4.30.0
Installing datasets>=2.12.0...
✓ Successfully installed datasets>=2.12.0
I

2025-06-18 04:36:01.969924: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750221362.188675      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750221362.260036      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


✓ Evaluate imported successfully
✓ NLTK imported successfully
✓ Accelerate imported successfully
✓ Datasets imported successfully
✓ PyTorch version: 2.6.0+cu124
✓ CUDA available: True
✓ CUDA device count: 1
✓ Transformers version: 4.51.3
✓ PyArrow version: 20.0.0
✓ Datasets version: 3.6.0

✓ All packages verified successfully!

DOWNLOADING NLTK DATA
✓ Downloaded NLTK punkt
✓ Downloaded NLTK stopwords
✓ Downloaded NLTK wordnet
✓ Downloaded NLTK omw-1.4

SETTING UP ENVIRONMENT
✓ Environment variables set
✓ Created directory: cache
✓ Created directory: cache/transformers
✓ Created directory: cache/datasets
✓ Created directory: processed_data
✓ Created directory: models
✓ Created directory: logs

SETUP COMPLETE
✅ Setup completed successfully!
You can now run the fitness Q&A training script.

🧪 Running quick functionality test...


tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

✓ T5 tokenizer loads correctly
✓ Tokenization works correctly
✓ Datasets functionality works correctly
✅ All functionality tests passed!

📋 Next steps:
1. If setup was successful, you can now run the main training script
2. If there were issues, please install missing packages manually
3. Consider restarting your Python kernel/session after installation
4. If PyArrow issues persist, try the manual fix commands above

TESTING ALL IMPORTS FOR MAIN SCRIPT
✓ Standard library imports successful
✓ Data handling imports successful
✓ Scikit-learn imports successful
✓ PyTorch import successful
✓ NLTK import successful
✓ Transformers imports successful
✓ Evaluate import successful

✅ ALL IMPORTS SUCCESSFUL - Ready to run main script!

SYSTEM INFORMATION
Python version: 3.11.11 (main, Dec  4 2024, 08:55:07) [GCC 11.4.0]
Platform: linux
PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA version: 12.4
GPU count: 1
GPU 0: Tesla P100-PCIE-16GB
Transformers version: 4.51.3
Datasets version: 3.6.0


In [2]:
"""
Fitness Q&A Model Training Pipeline
A complete pipeline for training a T5 model on fitness-related question-answer pairs
"""

import os
import re
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import pandas as pd
import torch
import nltk
from datasets import Dataset, load_dataset
from transformers import (
    AutoTokenizer, 
    T5ForConditionalGeneration, 
    Trainer, 
    TrainingArguments,
    EarlyStoppingCallback
)
from sklearn.model_selection import train_test_split
import evaluate

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Suppress tokenizers parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class FitnessQAProcessor:
    """Handles data preprocessing for fitness Q&A dataset"""
    
    FITNESS_KEYWORDS = [
        'exercise', 'workout', 'fitness', 'nutrition', 'muscle', 'cardio', 
        'strength', 'yoga', 'running', 'sleep', 'stress', 'recovery', 
        'flexibility', 'balance', 'posture', 'hydration', 'motivation',
        'diet', 'weight', 'training', 'gym', 'health', 'stretch'
    ]
    
    def __init__(self):
        self.question_col = 'Question'
        self.answer_col = 'Answer'
    
    def clean_text(self, text: str) -> str:
        """Clean and normalize text"""
        if not isinstance(text, str) or pd.isna(text):
            return ""
        
        # Convert to lowercase and strip whitespace
        text = text.lower().strip()
        
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text)
        
        # Remove special characters but keep basic punctuation
        text = re.sub(r'[^\w\s.,!?-]', '', text)
        
        # Remove very short texts (likely noise)
        if len(text.split()) < 3:
            return ""
            
        return text
    
    def is_valid_question(self, question: str) -> bool:
        """Check if text is a valid question"""
        if not question or len(question) < 10:
            return False
        
        # Must contain letters and proper question structure
        if not re.search(r'[a-zA-Z]', question):
            return False
        
        # Should end with question mark or be a question word
        question_indicators = ['how', 'what', 'why', 'when', 'where', 'which', 'who', 'can', 'should', 'do', 'does', 'is', 'are']
        ends_with_question = question.endswith('?')
        starts_with_question = any(question.startswith(word) for word in question_indicators)
        
        return ends_with_question or starts_with_question
    
    def is_fitness_related(self, text: str) -> bool:
        """Check if text is fitness related"""
        pattern = '|'.join(self.FITNESS_KEYWORDS)
        return bool(re.search(pattern, text, re.IGNORECASE))
    
    def load_and_clean_data(self, dataset_name: str = "its-myrto/fitness-question-answers") -> pd.DataFrame:
        """Load and clean the fitness Q&A dataset"""
        try:
            logger.info(f"Loading dataset: {dataset_name}")
            dataset = load_dataset(dataset_name)
            df = dataset['train'].to_pandas()
            logger.info(f"Initial dataset size: {len(df)}")
            
            # Drop unnecessary columns
            cols_to_drop = ['Unnamed: 0'] if 'Unnamed: 0' in df.columns else []
            if cols_to_drop:
                df = df.drop(columns=cols_to_drop)
                logger.info(f"Dropped columns: {cols_to_drop}")
            
            # Verify required columns exist
            if self.question_col not in df.columns or self.answer_col not in df.columns:
                raise ValueError(f"Required columns '{self.question_col}' and/or '{self.answer_col}' not found")
            
            # Clean the data
            df = self._clean_dataframe(df)
            
            logger.info(f"Final cleaned dataset size: {len(df)}")
            return df
            
        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            raise
    
    def _clean_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Apply all cleaning steps to dataframe"""
        initial_size = len(df)
        
        # Remove duplicates
        df = df.drop_duplicates(subset=[self.question_col, self.answer_col], keep='first')
        logger.info(f"Removed {initial_size - len(df)} duplicates")
        
        # Remove missing values
        df = df.dropna(subset=[self.question_col, self.answer_col])
        logger.info(f"Removed rows with missing values, remaining: {len(df)}")
        
        # Clean text
        df[self.question_col] = df[self.question_col].apply(self.clean_text)
        df[self.answer_col] = df[self.answer_col].apply(self.clean_text)
        
        # Remove empty entries after cleaning
        df = df[(df[self.question_col] != "") & (df[self.answer_col] != "")]
        
        # Filter valid questions
        df = df[df[self.question_col].apply(self.is_valid_question)]
        logger.info(f"After filtering valid questions: {len(df)}")
        
        # Filter fitness-related content
        fitness_mask = (df[self.question_col].apply(self.is_fitness_related) | 
                       df[self.answer_col].apply(self.is_fitness_related))
        df = df[fitness_mask]
        logger.info(f"After filtering fitness-related content: {len(df)}")
        
        return df.reset_index(drop=True)

class T5FitnessTrainer:
    """Handles T5 model training for fitness Q&A"""
    
    def __init__(self, model_name: str = 't5-small', max_length: int = 512):
        self.model_name = model_name
        self.max_length = max_length
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = None
        self.model = None
        
        # Download NLTK data for BLEU computation
        try:
            nltk.download('punkt', quiet=True)
        except:
            logger.warning("Could not download NLTK punkt tokenizer")
    
    def initialize_model(self):
        """Initialize tokenizer and model"""
        try:
            logger.info(f"Initializing {self.model_name} on {self.device}")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
            self.model.to(self.device)
            logger.info("Model and tokenizer initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            raise
    
    def prepare_dataset(self, df: pd.DataFrame) -> Dataset:
        """Convert DataFrame to tokenized Dataset"""
        inputs = [f"question: {row['Question']}" for _, row in df.iterrows()]
        targets = [row['Answer'] for _, row in df.iterrows()]
        
        # Tokenize inputs
        input_encodings = self.tokenizer(
            inputs,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        
        # Tokenize targets
        target_encodings = self.tokenizer(
            targets,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        
        # Create dataset
        dataset_dict = {
            'input_ids': input_encodings['input_ids'],
            'attention_mask': input_encodings['attention_mask'],
            'labels': target_encodings['input_ids']
        }
        
        return Dataset.from_dict(dataset_dict)
    
    def compute_metrics(self, eval_pred) -> Dict[str, float]:
        """Compute evaluation metrics"""
        predictions, labels = eval_pred
        
        # Handle model output format
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        
        # Convert to text
        if predictions.ndim == 3:
            predictions = predictions.argmax(-1)
        
        pred_texts = self.tokenizer.batch_decode(predictions, skip_special_tokens=True)
        label_texts = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
        
        # Remove padding tokens from labels
        label_texts = [text.replace(self.tokenizer.pad_token, "").strip() for text in label_texts]
        
        # Compute BLEU score
        try:
            bleu = evaluate.load("bleu")
            result = bleu.compute(
                predictions=pred_texts, 
                references=[[label] for label in label_texts]
            )
            return {"bleu": result["bleu"]}
        except Exception as e:
            logger.warning(f"Could not compute BLEU score: {e}")
            return {"bleu": 0.0}
    
    def train_model(self, train_df: pd.DataFrame, val_df: pd.DataFrame, 
                   output_dir: str = "./fitness_qa_model") -> None:
        """Train the T5 model"""
        if self.model is None or self.tokenizer is None:
            self.initialize_model()
        
        # Prepare datasets
        logger.info("Preparing training datasets...")
        train_dataset = self.prepare_dataset(train_df)
        val_dataset = self.prepare_dataset(val_df)
        
        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=3,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=16,
            warmup_steps=200,
            weight_decay=0.01,
            logging_dir=f"{output_dir}/logs",
            logging_steps=50,
            eval_strategy="steps",
            eval_steps=200,
            save_strategy="steps",
            save_steps=200,
            load_best_model_at_end=True,
            metric_for_best_model="bleu",
            greater_is_better=True,
            save_total_limit=2,
            fp16=torch.cuda.is_available(),
            dataloader_pin_memory=False,
            report_to=[]  # Disable wandb logging
        )
        
        # Initialize trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=self.compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
        )
        
        # Train the model
        logger.info("Starting training...")
        try:
            trainer.train()
            logger.info("Training completed successfully")
        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise
        
        # Save final model
        final_output_dir = f"{output_dir}_final"
        trainer.save_model(final_output_dir)
        self.tokenizer.save_pretrained(final_output_dir)
        logger.info(f"Model saved to {final_output_dir}")
        
        return trainer
    
    def test_model(self, test_questions: Optional[List[str]] = None) -> None:
        """Test the trained model with sample questions"""
        if test_questions is None:
            test_questions = [
                "How can I improve my running endurance?",
                "What are effective core exercises?",
                "How do I stay motivated for workouts?",
                "What should I eat before exercising?",
                "How often should I rest between workouts?"
            ]
        
        if self.model is None or self.tokenizer is None:
            logger.error("Model not initialized. Please train or load a model first.")
            return
        
        logger.info("Testing model with sample questions...")
        self.model.eval()
        
        with torch.no_grad():
            for question in test_questions:
                input_text = f"question: {question}"
                inputs = self.tokenizer(
                    input_text, 
                    return_tensors="pt", 
                    max_length=self.max_length, 
                    truncation=True
                ).to(self.device)
                
                outputs = self.model.generate(
                    **inputs, 
                    max_length=self.max_length, 
                    num_beams=5,
                    no_repeat_ngram_size=2,
                    early_stopping=True
                )
                
                answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                print(f"\nQ: {question}")
                print(f"A: {answer}")

def main():
    """Main training pipeline"""
    # Configuration
    OUTPUT_DIR = "./fitness_qa_model"
    TEST_SIZE = 0.2
    RANDOM_STATE = 42
    
    try:
        # Step 1: Process data
        logger.info("Starting fitness Q&A model training pipeline")
        processor = FitnessQAProcessor()
        df = processor.load_and_clean_data()
        
        if len(df) < 100:
            logger.warning(f"Dataset is very small ({len(df)} samples). Consider finding more data.")
        
        # Step 2: Split data
        train_df, val_df = train_test_split(
            df, 
            test_size=TEST_SIZE, 
            random_state=RANDOM_STATE,
            stratify=None
        )
        logger.info(f"Data split - Train: {len(train_df)}, Validation: {len(val_df)}")
        
        # Step 3: Save processed datasets
        os.makedirs("processed_data", exist_ok=True)
        train_df.to_csv("processed_data/train_cleaned.csv", index=False)
        val_df.to_csv("processed_data/val_cleaned.csv", index=False)
        logger.info("Processed datasets saved")
        
        # Step 4: Initialize trainer and train model
        trainer = T5FitnessTrainer()
        trained_model = trainer.train_model(train_df, val_df, OUTPUT_DIR)
        
        # Step 5: Test the model
        trainer.test_model()
        
        logger.info("Pipeline completed successfully!")
        
    except Exception as e:
        logger.error(f"Pipeline failed: {e}")
        raise

if __name__ == "__main__":
    main()

README.md:   0%|          | 0.00/203 [00:00<?, ?B/s]

conversational_dataset.csv:   0%|          | 0.00/289k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/965 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss,Validation Loss


OutOfMemoryError: CUDA out of memory. Tried to allocate 6.86 GiB. GPU 0 has a total capacity of 15.89 GiB of which 1007.12 MiB is free. Process 3327 has 14.90 GiB memory in use. Of the allocated memory 7.67 GiB is allocated by PyTorch, and 6.94 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)