# 

In [1]:
"""
Complete Setup and Installation Script for Fitness Q&A Model Training
This script handles all installations and imports needed for the project
"""

import subprocess
import sys
import os
from pathlib import Path

def install_package(package_name, upgrade=False):
    """Install a package using pip"""
    try:
        cmd = [sys.executable, "-m", "pip", "install"]
        if upgrade:
            cmd.append("--upgrade")
        cmd.append(package_name)
        
        print(f"Installing {package_name}...")
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
        print(f"✓ Successfully installed {package_name}")
        return True
    except subprocess.CalledProcessError as e:
        print(f"✗ Failed to install {package_name}: {e}")
        print(f"Error output: {e.stderr}")
        return False

def check_and_install_packages():
    """Check and install all required packages"""
    
    # Core packages with specific versions for compatibility
    # Note: PyArrow version is critical for datasets compatibility
    packages = [
        "torch>=2.0.0",
        "pyarrow>=12.0.0,<18.0.0",  # Fix PyArrow compatibility issue
        "transformers>=4.30.0",
        "datasets>=2.12.0",
        "pandas>=1.5.0",
        "numpy>=1.21.0",
        "scikit-learn>=1.2.0",
        "evaluate>=0.4.0",
        "nltk>=3.8.0",
        "accelerate>=0.20.0",  # For better training performance
        "sentencepiece>=0.1.99",  # Required for T5 tokenizer
        "protobuf>=3.20.0",  # Required for datasets
        "tqdm>=4.64.0",  # Progress bars
        "requests>=2.28.0",  # For downloading datasets
        "huggingface_hub>=0.15.0",  # For model/dataset downloads
    ]
    
    print("=" * 60)
    print("INSTALLING REQUIRED PACKAGES")
    print("=" * 60)
    
    failed_packages = []
    
    # First, upgrade pip itself
    print("Upgrading pip...")
    try:
        subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "pip"], 
                      check=True, capture_output=True)
        print("✓ pip upgraded successfully")
    except:
        print("⚠ Could not upgrade pip, continuing anyway...")
    
    # Special handling for PyArrow compatibility issue
    print("\n🔧 Fixing PyArrow compatibility issue...")
    try:
        # Uninstall potentially conflicting PyArrow versions
        subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "pyarrow"], 
                      capture_output=True)
        # Install compatible version
        subprocess.run([sys.executable, "-m", "pip", "install", "pyarrow>=12.0.0,<18.0.0"], 
                      check=True, capture_output=True)
        print("✓ PyArrow compatibility fixed")
    except Exception as e:
        print(f"⚠ Could not fix PyArrow: {e}")
    
    # Install packages
    for package in packages:
        if not install_package(package):
            failed_packages.append(package)
    
    # Install optional packages for better performance
    optional_packages = [
        "tensorboard",  # For monitoring training
        "matplotlib",   # For plotting if needed
        "seaborn",      # For better plots
    ]
    
    print("\n" + "=" * 60)
    print("INSTALLING OPTIONAL PACKAGES")
    print("=" * 60)
    
    for package in optional_packages:
        install_package(package)  # Don't track failures for optional packages
    
    if failed_packages:
        print(f"\n⚠ Warning: Failed to install: {', '.join(failed_packages)}")
        print("You may need to install these manually or check your environment.")
    else:
        print("\n✓ All required packages installed successfully!")
    
    return len(failed_packages) == 0

def fix_pyarrow_issue():
    """Fix the specific PyArrow compatibility issue"""
    print("\n" + "=" * 60)
    print("FIXING PYARROW COMPATIBILITY ISSUE")
    print("=" * 60)
    
    try:
        # Step 1: Uninstall problematic packages
        packages_to_remove = ["pyarrow", "datasets"]
        for package in packages_to_remove:
            try:
                subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", package], 
                              capture_output=True, check=False)
                print(f"✓ Uninstalled {package}")
            except:
                print(f"⚠ Could not uninstall {package} (may not be installed)")
        
        # Step 2: Install compatible PyArrow first
        compatible_pyarrow = "pyarrow>=12.0.0,<15.0.0"
        if install_package(compatible_pyarrow):
            print("✓ Installed compatible PyArrow version")
        else:
            print("✗ Failed to install compatible PyArrow")
            return False
        
        # Step 3: Install datasets with the compatible PyArrow
        if install_package("datasets>=2.12.0"):
            print("✓ Installed datasets with compatible PyArrow")
        else:
            print("✗ Failed to install datasets")
            return False
        
        # Step 4: Test the fix
        try:
            import pyarrow
            import datasets
            print(f"✓ PyArrow version: {pyarrow.__version__}")
            print(f"✓ Datasets version: {datasets.__version__}")
            
            # Test basic functionality
            from datasets import Dataset
            test_data = {"text": ["hello", "world"], "label": [1, 0]}
            test_dataset = Dataset.from_dict(test_data)
            print("✓ Datasets functionality test passed")
            
            return True
            
        except Exception as e:
            print(f"✗ Test failed: {e}")
            return False
            
    except Exception as e:
        print(f"✗ PyArrow fix failed: {e}")
        return False
def verify_installations(skip_datasets=False):
    """Verify that all packages can be imported correctly"""
    
    print("\n" + "=" * 60)
    print("VERIFYING INSTALLATIONS")
    print("=" * 60)
    
    imports_to_test = [
        ("torch", "PyTorch"),
        ("transformers", "Transformers"),
        ("pandas", "Pandas"),
        ("numpy", "NumPy"),
        ("sklearn", "Scikit-learn"),
        ("evaluate", "Evaluate"),
        ("nltk", "NLTK"),
        ("accelerate", "Accelerate"),
    ]
    
    # Add datasets only if PyArrow was fixed
    if not skip_datasets:
        imports_to_test.append(("datasets", "Datasets"))
    
    failed_imports = []
    
    for module_name, display_name in imports_to_test:
        try:
            __import__(module_name)
            print(f"✓ {display_name} imported successfully")
        except ImportError as e:
            print(f"✗ Failed to import {display_name}: {e}")
            failed_imports.append(display_name)
        except AttributeError as e:
            if "pyarrow" in str(e).lower():
                print(f"✗ {display_name} failed due to PyArrow compatibility: {e}")
                failed_imports.append(display_name)
            else:
                print(f"✗ Failed to import {display_name}: {e}")
                failed_imports.append(display_name)
    
    # Special checks
    try:
        import torch
        print(f"✓ PyTorch version: {torch.__version__}")
        print(f"✓ CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"✓ CUDA device count: {torch.cuda.device_count()}")
    except:
        print("✗ Could not get PyTorch details")
    
    try:
        import transformers
        print(f"✓ Transformers version: {transformers.__version__}")
    except:
        print("✗ Could not get Transformers version")
    
    if not skip_datasets:
        try:
            import pyarrow
            import datasets
            print(f"✓ PyArrow version: {pyarrow.__version__}")
            print(f"✓ Datasets version: {datasets.__version__}")
        except:
            print("✗ Could not get PyArrow/Datasets versions")
    
    if failed_imports:
        print(f"\n⚠ Warning: Failed to import: {', '.join(failed_imports)}")
        return False
    else:
        print("\n✓ All packages verified successfully!")
        return True

def download_nltk_data():
    """Download required NLTK data"""
    print("\n" + "=" * 60)
    print("DOWNLOADING NLTK DATA")
    print("=" * 60)
    
    try:
        import nltk
        
        # Download required NLTK data
        nltk_downloads = [
            'punkt',
            'stopwords',
            'wordnet',
            'omw-1.4'
        ]
        
        for item in nltk_downloads:
            try:
                nltk.download(item, quiet=True)
                print(f"✓ Downloaded NLTK {item}")
            except Exception as e:
                print(f"⚠ Could not download NLTK {item}: {e}")
                
    except ImportError:
        print("✗ NLTK not available for downloading data")

def setup_environment():
    """Set up environment variables and create directories"""
    print("\n" + "=" * 60)
    print("SETTING UP ENVIRONMENT")
    print("=" * 60)
    
    # Set environment variables
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ["TRANSFORMERS_CACHE"] = "./cache/transformers"
    os.environ["HF_DATASETS_CACHE"] = "./cache/datasets"
    print("✓ Environment variables set")
    
    # Create necessary directories
    directories = [
        "cache",
        "cache/transformers", 
        "cache/datasets",
        "processed_data",
        "models",
        "logs"
    ]
    
    for directory in directories:
        Path(directory).mkdir(parents=True, exist_ok=True)
        print(f"✓ Created directory: {directory}")

def main():
    """Main setup function"""
    print("🚀 Starting complete setup for Fitness Q&A Model Training")
    print("This may take several minutes...")
    
    # Step 0: Fix PyArrow issue first (critical fix)
    print("\n🔧 Applying critical compatibility fixes...")
    pyarrow_fixed = fix_pyarrow_issue()
    
    # Step 1: Install packages
    packages_ok = check_and_install_packages()
    
    # Step 2: Verify installations (skip datasets if PyArrow fix failed)
    imports_ok = verify_installations(skip_datasets=not pyarrow_fixed)
    
    # Step 3: Download NLTK data
    download_nltk_data()
    
    # Step 4: Setup environment
    setup_environment()
    
    # Final status
    print("\n" + "=" * 60)
    print("SETUP COMPLETE")
    print("=" * 60)
    
    if packages_ok and imports_ok:
        print("✅ Setup completed successfully!")
        print("You can now run the fitness Q&A training script.")
        
        # Test basic functionality
        print("\n🧪 Running quick functionality test...")
        try:
            import torch
            from transformers import AutoTokenizer
            
            # Test tokenizer loading
            tokenizer = AutoTokenizer.from_pretrained('t5-small')
            print("✓ T5 tokenizer loads correctly")
            
            # Test basic tokenization
            test_text = "question: How can I improve my fitness?"
            tokens = tokenizer(test_text, return_tensors="pt")
            print("✓ Tokenization works correctly")
            
            # Test datasets if PyArrow was fixed
            if pyarrow_fixed:
                from datasets import Dataset
                test_data = {"question": ["test?"], "answer": ["test answer"]}
                test_ds = Dataset.from_dict(test_data)
                print("✓ Datasets functionality works correctly")
            
            print("✅ All functionality tests passed!")
            
        except Exception as e:
            print(f"⚠ Functionality test failed: {e}")
            print("You may need to restart your Python session.")
    
    else:
        print("❌ Setup completed with some issues.")
        print("Please check the error messages above and resolve any missing packages.")
        if not pyarrow_fixed:
            print("\n🔧 PyArrow compatibility issue detected!")
            print("Try running this fix manually:")
            print("pip uninstall -y pyarrow datasets")
            print("pip install 'pyarrow>=12.0.0,<15.0.0'")
            print("pip install 'datasets>=2.12.0'")
    
    print("\n📋 Next steps:")
    print("1. If setup was successful, you can now run the main training script")
    print("2. If there were issues, please install missing packages manually")
    print("3. Consider restarting your Python kernel/session after installation")
    print("4. If PyArrow issues persist, try the manual fix commands above")

if __name__ == "__main__":
    main()


# =============================================================================
# IMPORTS SECTION - All imports needed for the main script
# =============================================================================

print("\n" + "=" * 60)
print("TESTING ALL IMPORTS FOR MAIN SCRIPT")
print("=" * 60)

try:
    # Standard library imports
    import os
    import re
    import logging
    from pathlib import Path
    from typing import Dict, List, Optional, Tuple
    print("✓ Standard library imports successful")
    
    # Data handling
    import pandas as pd
    import numpy as np
    print("✓ Data handling imports successful")
    
    # Machine learning
    from sklearn.model_selection import train_test_split
    print("✓ Scikit-learn imports successful")
    
    # PyTorch
    import torch
    print("✓ PyTorch import successful")
    
    # NLTK
    import nltk
    print("✓ NLTK import successful")
    
    # Hugging Face
    from datasets import Dataset, load_dataset
    from transformers import (
        AutoTokenizer, 
        T5ForConditionalGeneration, 
        Trainer, 
        TrainingArguments,
        EarlyStoppingCallback
    )
    print("✓ Transformers imports successful")
    
    # Evaluation
    import evaluate
    print("✓ Evaluate import successful")
    
    print("\n✅ ALL IMPORTS SUCCESSFUL - Ready to run main script!")
    
except ImportError as e:
    print(f"\n❌ Import failed: {e}")
    print("Please run the setup section above to install missing packages.")

# =============================================================================
# SYSTEM INFO
# =============================================================================

def print_system_info():
    """Print system information for debugging"""
    print("\n" + "=" * 60)
    print("SYSTEM INFORMATION")
    print("=" * 60)
    
    print(f"Python version: {sys.version}")
    print(f"Platform: {sys.platform}")
    
    try:
        import torch
        print(f"PyTorch version: {torch.__version__}")
        print(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"CUDA version: {torch.version.cuda}")
            print(f"GPU count: {torch.cuda.device_count()}")
            for i in range(torch.cuda.device_count()):
                print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    except:
        pass
    
    try:
        import transformers
        print(f"Transformers version: {transformers.__version__}")
    except:
        pass
    
    try:
        import datasets
        print(f"Datasets version: {datasets.__version__}")
    except:
        pass

print_system_info()

🚀 Starting complete setup for Fitness Q&A Model Training
This may take several minutes...

🔧 Applying critical compatibility fixes...

FIXING PYARROW COMPATIBILITY ISSUE
✓ Uninstalled pyarrow
✓ Uninstalled datasets
Installing pyarrow>=12.0.0,<15.0.0...
✓ Successfully installed pyarrow>=12.0.0,<15.0.0
✓ Installed compatible PyArrow version
Installing datasets>=2.12.0...
✓ Successfully installed datasets>=2.12.0
✓ Installed datasets with compatible PyArrow
✓ PyArrow version: 20.0.0
✓ Datasets version: 3.6.0
✓ Datasets functionality test passed
INSTALLING REQUIRED PACKAGES
Upgrading pip...
✓ pip upgraded successfully

🔧 Fixing PyArrow compatibility issue...
✓ PyArrow compatibility fixed
Installing torch>=2.0.0...
✓ Successfully installed torch>=2.0.0
Installing pyarrow>=12.0.0,<18.0.0...
✓ Successfully installed pyarrow>=12.0.0,<18.0.0
Installing transformers>=4.30.0...
✓ Successfully installed transformers>=4.30.0
Installing datasets>=2.12.0...
✓ Successfully installed datasets>=2.12.0
I

2025-06-19 11:28:07.023343: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750332487.231120      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750332487.287037      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


✓ Evaluate imported successfully
✓ NLTK imported successfully
✓ Accelerate imported successfully
✓ Datasets imported successfully
✓ PyTorch version: 2.6.0+cu124
✓ CUDA available: True
✓ CUDA device count: 1
✓ Transformers version: 4.51.3
✓ PyArrow version: 20.0.0
✓ Datasets version: 3.6.0

✓ All packages verified successfully!

DOWNLOADING NLTK DATA
✓ Downloaded NLTK punkt
✓ Downloaded NLTK stopwords
✓ Downloaded NLTK wordnet
✓ Downloaded NLTK omw-1.4

SETTING UP ENVIRONMENT
✓ Environment variables set
✓ Created directory: cache
✓ Created directory: cache/transformers
✓ Created directory: cache/datasets
✓ Created directory: processed_data
✓ Created directory: models
✓ Created directory: logs

SETUP COMPLETE
✅ Setup completed successfully!
You can now run the fitness Q&A training script.

🧪 Running quick functionality test...


tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

✓ T5 tokenizer loads correctly
✓ Tokenization works correctly
✓ Datasets functionality works correctly
✅ All functionality tests passed!

📋 Next steps:
1. If setup was successful, you can now run the main training script
2. If there were issues, please install missing packages manually
3. Consider restarting your Python kernel/session after installation
4. If PyArrow issues persist, try the manual fix commands above

TESTING ALL IMPORTS FOR MAIN SCRIPT
✓ Standard library imports successful
✓ Data handling imports successful
✓ Scikit-learn imports successful
✓ PyTorch import successful
✓ NLTK import successful
✓ Transformers imports successful
✓ Evaluate import successful

✅ ALL IMPORTS SUCCESSFUL - Ready to run main script!

SYSTEM INFORMATION
Python version: 3.11.11 (main, Dec  4 2024, 08:55:07) [GCC 11.4.0]
Platform: linux
PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA version: 12.4
GPU count: 1
GPU 0: Tesla P100-PCIE-16GB
Transformers version: 4.51.3
Datasets version: 3.6.0


In [2]:
# Step 1: Install necessary libraries (run in terminal or notebook if needed)
# !pip install datasets transformers sentence-transformers pandas scikit-learn numpy

# Step 2: Import required libraries
import pandas as pd
import re
from datasets import load_dataset
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
import numpy as np

# Step 3: Define utility functions

def clean_text(text):
    """
    Clean text by converting to lowercase, removing extra spaces, and special characters.
    Args:
        text (str): Input text to clean.
    Returns:
        str: Cleaned text.
    """
    if not isinstance(text, str):
        return ""
    text = text.lower().strip()  # Convert to lowercase and remove leading/trailing spaces
    text = re.sub(r'\s+', ' ', text)  # Replace multiple spaces with single space
    text = re.sub(r'[^\w\s.,!?]', '', text)  # Remove special characters except punctuation
    return text

def filter_fitness_relevance(df, question_col, keywords):
    """
    Filter dataset to keep only fitness-related questions based on keywords.
    Args:
        df (pd.DataFrame): Input dataframe with question column.
        question_col (str): Name of the question column.
        keywords (list): List of fitness-related keywords.
    Returns:
        pd.DataFrame: Filtered dataframe.
    """
    pattern = '|'.join(keywords)
    return df[df[question_col].str.contains(pattern, case=False, na=False)]

def paraphrase_question(question, paraphraser):
    """
    Generate a paraphrased version of the input question using a sentence transformer.
    Args:
        question (str): Original question.
        paraphraser: SentenceTransformer model for paraphrasing.
    Returns:
        str: Paraphrased question (placeholder, can be enhanced).
    """
    # Placeholder: In practice, use paraphraser.encode() to generate embeddings and find similar phrasing
    return question  # Replace with actual paraphrasing logic if needed

def augment_data(df, question_col, answer_col, paraphraser, num_augmentations=1):
    """
    Augment dataset by generating paraphrased questions.
    Args:
        df (pd.DataFrame): Input dataframe with question and answer columns.
        question_col (str): Name of the question column.
        answer_col (str): Name of the answer column.
        paraphraser: SentenceTransformer model for paraphrasing.
        num_augmentations (int): Number of paraphrases per question.
    Returns:
        pd.DataFrame: Augmented dataframe.
    """
    augmented_rows = []
    for _, row in df.iterrows():
        original_question = row[question_col]
        answer = row[answer_col]
        augmented_rows.append({question_col: original_question, answer_col: answer})
        for _ in range(num_augmentations):
            paraphrased_question = paraphrase_question(original_question, paraphraser)
            augmented_rows.append({question_col: paraphrased_question, answer_col: answer})
    return pd.DataFrame(augmented_rows)

def tokenize_data(row, question_col, answer_col, tokenizer, max_length=512):
    """
    Tokenize question-answer pair for T5 model.
    Args:
        row (pd.Series): Dataframe row with question and answer.
        question_col (str): Name of the question column.
        answer_col (str): Name of the answer column.
        tokenizer: Transformers tokenizer.
        max_length (int): Maximum token length.
    Returns:
        dict: Tokenized input and attention mask.
    """
    input_text = f"question: {row[question_col]} answer: {row[answer_col]}"
    return tokenizer(input_text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")

def main():
    # Step 4: Load the dataset from Hugging Face
    try:
        dataset = load_dataset("its-myrto/fitness-question-answers")
        df = dataset['train'].to_pandas()
        print(f"Initial dataset size: {len(df)}")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return

    # Step 5: Inspect column names
    print("Dataset columns:", df.columns.tolist())

    # Step 6: Define column names (update these based on actual column names)
    # Based on error, 'question' and 'answer' may not exist. Adjust as needed.
    # For now, let's assume they might be 'Question' and 'Answer' (case-sensitive) or check output above.
    question_col = 'Question'  # Update this after checking printed columns
    answer_col = 'Answer'      # Update this after checking printed columns

    # Verify column names exist
    if question_col not in df.columns or answer_col not in df.columns:
        print(f"Error: Columns '{question_col}' and/or '{answer_col}' not found in dataset.")
        print("Please update 'question_col' and 'answer_col' in the script with correct column names.")
        return

    # Step 7: Clean the data
    # Remove duplicates
    df = df.drop_duplicates(subset=[question_col, answer_col], keep='first')
    print(f"Rows after removing duplicates: {len(df)}")

    # Remove missing values
    df = df.dropna(subset=[question_col, answer_col])
    print(f"Rows after removing missing values: {len(df)}")

    # Clean questions and answers
    df[question_col] = df[question_col].apply(clean_text)
    df[answer_col] = df[answer_col].apply(clean_text)

    # Step 8: Filter for fitness relevance
    fitness_keywords = ['exercise', 'workout', 'fitness', 'nutrition', 'muscle', 'cardio', 'strength', 'yoga', 'running']
    df = filter_fitness_relevance(df, question_col, fitness_keywords)
    print(f"Rows after filtering for fitness relevance: {len(df)}")

    # Step 9: Data augmentation (optional, enabled if dataset is small)
    if len(df) < 1000:
        print("Augmenting dataset due to small size...")
        try:
            paraphraser = SentenceTransformer('paraphrase-MiniLM-L6-v2')
            df = augment_data(df, question_col, answer_col, paraphraser, num_augmentations=1)
            print(f"Rows after augmentation: {len(df)}")
        except Exception as e:
            print(f"Augmentation failed: {e}. Proceeding without augmentation.")
            # Continue without augmentation if it fails

    # Step 10: Split dataset into train and validation
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
    print(f"Train size: {len(train_df)}, Validation size: {len(val_df)}")

    # Step 11: Tokenize data for T5 model
    try:
        tokenizer = AutoTokenizer.from_pretrained('t5-small')
        train_tokens = train_df.apply(lambda row: tokenize_data(row, question_col, answer_col, tokenizer), axis=1)
        val_tokens = val_df.apply(lambda row: tokenize_data(row, question_col, answer_col, tokenizer), axis=1)
    except Exception as e:
        print(f"Tokenization failed: {e}")
        return

      # Step 12: Save cleaned and split datasets
    output_dir = "/kaggle/working/processed_data"
    os.makedirs(output_dir, exist_ok=True)

    train_df.to_csv(os.path.join(output_dir, 'train_cleaned.csv'), index=False)
    val_df.to_csv(os.path.join(output_dir, 'val_cleaned.csv'), index=False)
    print("Cleaned datasets saved in '/kaggle/working/processed_data/'")


if __name__ == "__main__":
    main()

README.md:   0%|          | 0.00/203 [00:00<?, ?B/s]

conversational_dataset.csv:   0%|          | 0.00/289k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/965 [00:00<?, ? examples/s]

Initial dataset size: 965
Dataset columns: ['Unnamed: 0', 'Question', 'Answer']
Rows after removing duplicates: 965
Rows after removing missing values: 965
Rows after filtering for fitness relevance: 515
Augmenting dataset due to small size...


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/3.51k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/629 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/314 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Rows after augmentation: 1030
Train size: 824, Validation size: 206
Cleaned datasets saved in '/kaggle/working/processed_data/'


In [3]:
import pandas as pd
import re

def clean_invalid_questions():
    files = ['train_cleaned.csv', 'val_cleaned.csv']
    for file in files:
        try:
            df = pd.read_csv(f"/kaggle/working/processed_data/{file}")
            print(f"\nProcessing {file}: {len(df)} rows")
            # Remove rows where 'Question' is not a valid question
            df = df[df['Question'].str.contains(r'[a-zA-Z\s]+[?]', na=False)]
            # Remove rows with non-fitness-related or invalid terms
            invalid_terms = ['entailment', 'true', 'false', 'contradiction']
            df = df[~df['Question'].str.lower().str.contains('|'.join(invalid_terms), na=False)]
            print(f"After removing invalid questions: {len(df)} rows")
            # Save cleaned file
            df.to_csv(f"/kaggle/working/processed_data/{file}", index=False)
            print(f"Saved cleaned {file}")
            # Preview cleaned data
            print(f"Preview of cleaned {file}:")
            print(df.head(5))
        except Exception as e:
            print(f"Error processing {file}: {e}")

if __name__ == "__main__":
    clean_invalid_questions()


Processing train_cleaned.csv: 824 rows
After removing invalid questions: 824 rows
Saved cleaned train_cleaned.csv
Preview of cleaned train_cleaned.csv:
                                            Question  \
0    how does strength training improve flexibility?   
1  how do i know if ive worked a muscle hard enough?   
2  what are some simple exercises i can do at hom...   
3                         how to strengthen my knee?   
4  is it okay to skip a workout if im feeling tired?   

                                              Answer  
0  strength training can improve flexibility by i...  
1  if you feel a burn in the muscle during the la...  
2  there are plenty of exercises you can do at ho...  
3  exercise is a noninvasive and healthful way to...  
4  its okay to skip a workout if youre feeling ti...  

Processing val_cleaned.csv: 206 rows
After removing invalid questions: 206 rows
Saved cleaned val_cleaned.csv
Preview of cleaned val_cleaned.csv:
                                 

In [4]:
import pandas as pd
import re

def clean_invalid_questions():
    files = ['train_cleaned.csv', 'val_cleaned.csv']
    for file in files:
        try:
            df = pd.read_csv(f"/kaggle/working/processed_data/{file}")
            print(f"\nProcessing {file}: {len(df)} rows")
            # Remove rows where 'Question' is not a valid question
            df = df[df['Question'].str.contains(r'[a-zA-Z\s]+[?]', na=False)]
            # Remove rows with non-fitness-related or invalid terms
            invalid_terms = ['entailment', 'true', 'false', 'contradiction']
            df = df[~df['Question'].str.lower().str.contains('|'.join(invalid_terms), na=False)]
            print(f"After removing invalid questions: {len(df)} rows")
            # Save cleaned file
            df.to_csv(f"/kaggle/working/processed_data/{file}", index=False)
            print(f"Saved cleaned {file}")
            # Preview cleaned data
            print(f"Preview of cleaned {file}:")
            print(df.head(5))
        except Exception as e:
            print(f"Error processing {file}: {e}")

if __name__ == "__main__":
    clean_invalid_questions()


Processing train_cleaned.csv: 824 rows
After removing invalid questions: 824 rows
Saved cleaned train_cleaned.csv
Preview of cleaned train_cleaned.csv:
                                            Question  \
0    how does strength training improve flexibility?   
1  how do i know if ive worked a muscle hard enough?   
2  what are some simple exercises i can do at hom...   
3                         how to strengthen my knee?   
4  is it okay to skip a workout if im feeling tired?   

                                              Answer  
0  strength training can improve flexibility by i...  
1  if you feel a burn in the muscle during the la...  
2  there are plenty of exercises you can do at ho...  
3  exercise is a noninvasive and healthful way to...  
4  its okay to skip a workout if youre feeling ti...  

Processing val_cleaned.csv: 206 rows
After removing invalid questions: 206 rows
Saved cleaned val_cleaned.csv
Preview of cleaned val_cleaned.csv:
                                 

In [5]:

import os
import re
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import pandas as pd
import torch
import nltk
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    T5ForConditionalGeneration,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback
)
from sklearn.model_selection import train_test_split
import evaluate

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Memory management
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.cuda.empty_cache()

def format_prompt(question: str, prompt_type: str = "instruct") -> str:
    if prompt_type == "instruct":
        return f"Answer the following fitness question: {question}"
    elif prompt_type == "qa":
        return f"Q: {question}\nA:"
    else:
        return f"question: {question}"

class FitnessQAProcessor:
    FITNESS_KEYWORDS = [
        'exercise', 'workout', 'fitness', 'nutrition', 'muscle', 'cardio',
        'strength', 'yoga', 'running', 'sleep', 'stress', 'recovery',
        'flexibility', 'balance', 'posture', 'hydration', 'motivation',
        'diet', 'weight', 'training', 'gym', 'health', 'stretch'
    ]

    def __init__(self):
        self.question_col = 'Question'
        self.answer_col = 'Answer'

    def clean_text(self, text: str) -> str:
        if not isinstance(text, str) or pd.isna(text):
            return ""
        text = text.lower().strip()
        text = re.sub(r'\s+', ' ', text)
        text = re.sub(r'[^\w\s.,!?-]', '', text)
        if len(text.split()) < 3:
            return ""
        return text

    def is_valid_question(self, question: str) -> bool:
        if not question or len(question) < 10:
            return False
        if not re.search(r'[a-zA-Z]', question):
            return False
        question_indicators = ['how', 'what', 'why', 'when', 'where', 'which', 'who', 'can', 'should', 'do', 'does', 'is', 'are']
        return question.endswith('?') or any(question.startswith(word) for word in question_indicators)

    def is_fitness_related(self, text: str) -> bool:
        pattern = '|'.join(self.FITNESS_KEYWORDS)
        return bool(re.search(pattern, text, re.IGNORECASE))

    def load_and_clean_data(self, dataset_name: str = "its-myrto/fitness-question-answers") -> pd.DataFrame:
        logger.info(f"Loading dataset: {dataset_name}")
        dataset = load_dataset(dataset_name)
        df = dataset['train'].to_pandas()
        if 'Unnamed: 0' in df.columns:
            df.drop(columns=['Unnamed: 0'], inplace=True)
        if self.question_col not in df.columns or self.answer_col not in df.columns:
            raise ValueError(f"Required columns '{self.question_col}' and/or '{self.answer_col}' not found")
        return self._clean_dataframe(df)

    def _clean_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        df = df.drop_duplicates(subset=[self.question_col, self.answer_col])
        df = df.dropna(subset=[self.question_col, self.answer_col])
        df[self.question_col] = df[self.question_col].apply(self.clean_text)
        df[self.answer_col] = df[self.answer_col].apply(self.clean_text)
        df = df[(df[self.question_col] != "") & (df[self.answer_col] != "")]
        df = df[df[self.question_col].apply(self.is_valid_question)]
        fitness_mask = (df[self.question_col].apply(self.is_fitness_related) | df[self.answer_col].apply(self.is_fitness_related))
        return df[fitness_mask].reset_index(drop=True)

class T5FitnessTrainer:
    def __init__(self, model_name: str = 't5-small', max_length: int = 512):
        self.model_name = model_name
        self.max_length = max_length
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        nltk.download('punkt', quiet=True)
        self.tokenizer = None
        self.model = None

    def initialize_model(self):
        logger.info(f"Initializing {self.model_name} on {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(self.model_name)
        self.model.gradient_checkpointing_enable()
        self.model.to(self.device)

    def prepare_dataset(self, df: pd.DataFrame) -> Dataset:
        inputs = [format_prompt(row['Question']) for _, row in df.iterrows()]
        targets = [row['Answer'] for _, row in df.iterrows()]
        input_encodings = self.tokenizer(inputs, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
        target_encodings = self.tokenizer(targets, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
        dataset_dict = {
            'input_ids': input_encodings['input_ids'].squeeze(),
            'attention_mask': input_encodings['attention_mask'].squeeze(),
            'labels': target_encodings['input_ids'].squeeze()
        }
        return Dataset.from_dict(dataset_dict)

    def train_model(self, train_df: pd.DataFrame, val_df: pd.DataFrame, output_dir: str) -> Trainer:
        if self.model is None or self.tokenizer is None:
            self.initialize_model()
        train_dataset = self.prepare_dataset(train_df)
        val_dataset = self.prepare_dataset(val_df)
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=15,
            per_device_train_batch_size=1,
            per_device_eval_batch_size=2,
            gradient_accumulation_steps=4,
            gradient_checkpointing=True,
            warmup_steps=200,
            weight_decay=0.01,
            logging_dir=f"{output_dir}/logs",
            logging_steps=50,
            eval_strategy="steps",
            eval_steps=200,
            save_strategy="steps",
            save_steps=200,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            save_total_limit=2,
            fp16=torch.cuda.is_available(),
            dataloader_pin_memory=False,
            report_to=[],
            batch_eval_metrics=False
        )
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=None,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
        )
        trainer.train()
        trainer.save_model(f"{output_dir}_final")
        self.tokenizer.save_pretrained(f"{output_dir}_final")
        return trainer

    def test_model(self, test_questions: Optional[List[str]] = None, prompt_type: str = "instruct") -> None:
        if self.model is None or self.tokenizer is None:
            model_path = "/kaggle/working/fitness_qa_model_final"
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = T5ForConditionalGeneration.from_pretrained(model_path).to(self.device)
        if test_questions is None:
            test_questions = [
                "How can I improve my running endurance?",
                "What are effective core exercises?",
                "How do I stay motivated for workouts?",
                "What should I eat before exercising?",
                "How often should I rest between workouts?"
            ]
        self.model.eval()
        with torch.no_grad():
            for question in test_questions:
                prompt = format_prompt(question, prompt_type)
                inputs = self.tokenizer(prompt, return_tensors="pt", max_length=self.max_length, truncation=True).to(self.device)
                outputs = self.model.generate(**inputs, max_length=self.max_length, num_beams=5, no_repeat_ngram_size=2, early_stopping=True)
                print(f"\nQ: {question}\nA: {self.tokenizer.decode(outputs[0], skip_special_tokens=True)}")

    def evaluate_bleu(self, val_df: pd.DataFrame, sample_size: int = 100, prompt_type: str = "instruct") -> float:
        bleu = evaluate.load("bleu")
        preds, refs = [], []
        self.model.eval()
        with torch.no_grad():
            for i, row in enumerate(val_df.itertuples()):
                if i >= sample_size: break
                prompt = format_prompt(row.Question, prompt_type)
                inputs = self.tokenizer(prompt, return_tensors='pt', max_length=self.max_length, truncation=True).to(self.device)
                outputs = self.model.generate(**inputs, max_length=64, num_beams=5, no_repeat_ngram_size=2, early_stopping=True)
                preds.append(self.tokenizer.decode(outputs[0], skip_special_tokens=True))
                refs.append(row.Answer)
        score = bleu.compute(predictions=preds, references=[[r] for r in refs])["bleu"]
        logger.info(f"BLEU score: {score:.4f}")
        return score

def main():
    OUTPUT_DIR = "/kaggle/working/processed_data/fitness_qa_model"
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(f"{OUTPUT_DIR}/logs", exist_ok=True)
    os.makedirs(f"{OUTPUT_DIR}_final", exist_ok=True)
    processor = FitnessQAProcessor()
    df = processor.load_and_clean_data()
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
    train_df.to_csv("/kaggle/working/processed_data/train_cleaned.csv", index=False)
    val_df.to_csv("/kaggle/working/processed_data/val_cleaned.csv", index=False)
    trainer = T5FitnessTrainer()
    trainer.train_model(train_df, val_df, output_dir=OUTPUT_DIR)
    trainer.test_model(prompt_type="instruct")
    trainer.evaluate_bleu(val_df, sample_size=100, prompt_type="instruct")

if __name__ == "__main__":
    main()

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss,Validation Loss
200,0.5209,0.402534
400,0.3713,0.341175
600,0.3586,0.319831
800,0.3382,0.314006
1000,0.3361,0.311116
1200,0.3644,0.308658
1400,0.332,0.306921
1600,0.3209,0.305729
1800,0.3152,0.30476
2000,0.3142,0.304101


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].



Q: How can I improve my running endurance?
A: running endurance is a great way to improve your running performance. if you are looking for strength training, it can help.

Q: What are effective core exercises?
A: effective core exercises include squats, tai chi, and hamstrings.

Q: How do I stay motivated for workouts?
A: staying motivated for workouts is a great way to stay motivated.

Q: What should I eat before exercising?
A: eat a lot of food before exercising, especially if you are exercising.

Q: How often should I rest between workouts?
A: rest between workouts depends on the intensity of the workout.


Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

In [6]:
from transformers import AutoTokenizer, T5ForConditionalGeneration
import shutil
import torch

model_path = "/kaggle/working/processed_data/fitness_qa_model_final"
best_model_path = "/kaggle/working/processed_data/fitness_qa_model_best_model_1"

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path).to("cuda" if torch.cuda.is_available() else "cpu")

# Save under new name
tokenizer.save_pretrained(best_model_path)
model.save_pretrained(best_model_path)

print(f"Model saved as: {best_model_path}")


Model saved as: /kaggle/working/processed_data/fitness_qa_model_best_model_1


In [7]:
# 📊 Evaluation Cell: BLEU, F1, Perplexity

from math import exp
import torch
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import f1_score
import evaluate
from tqdm import tqdm
import pandas as pd
import re, string

# 📌 Load validation set if not already loaded
try:
    val_df
except NameError:
    val_df = pd.read_csv("processed_data/val_cleaned.csv")
    print(f"✅ Loaded val_df with {len(val_df)} samples")

# 🔧 Text normalization
def normalize_answer(s):
    def remove_articles(text): return re.sub(r'\b(a|an|the)\b', ' ', text)
    def white_space_fix(text): return ' '.join(text.split())
    def remove_punc(text): return ''.join(ch for ch in text if ch not in set(string.punctuation))
    def lower(text): return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(s))))

# 🔧 F1 computation
def compute_f1(pred, true):
    pred_tokens = normalize_answer(pred).split()
    true_tokens = normalize_answer(true).split()
    common = set(pred_tokens) & set(true_tokens)
    if len(common) == 0: return 0.0
    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(true_tokens)
    return 2 * (precision * recall) / (precision + recall)

# 🔧 Prompt formatting
def format_prompt(question, prompt_type="instruct"):
    if prompt_type == "instruct":
        return f"Answer this fitness question: {question}"
    elif prompt_type == "qa":
        return f"question: {question}"
    else:
        return question

# 🧪 Evaluation function
def evaluate_model(model, tokenizer, val_df, device='cuda', max_length=512, sample_size=100, prompt_type="instruct"):
    bleu = evaluate.load("bleu")
    preds, refs, f1s, perplexities = [], [], [], []

    model.eval()
    with torch.no_grad():
        for i, row in tqdm(enumerate(val_df.itertuples()), total=min(sample_size, len(val_df))):
            if i >= sample_size: break

            prompt = format_prompt(row.Question, prompt_type)
            inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=max_length).to(device)
            labels = tokenizer(row.Answer, return_tensors='pt', truncation=True, max_length=max_length).input_ids.to(device)

            # Generate prediction
            outputs = model.generate(**inputs, max_length=64)
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
            preds.append(decoded)
            refs.append([row.Answer])
            f1s.append(compute_f1(decoded, row.Answer))

            # Compute perplexity
            logits = model(**inputs, labels=labels).logits
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                                   shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)
            perplexities.append(exp(loss.item()))

    # 📈 Final metrics
    bleu_score = bleu.compute(predictions=preds, references=refs)["bleu"]
    f1_score_avg = np.mean(f1s)
    perplexity_avg = np.mean(perplexities)

    print(f"🔹 BLEU Score     : {bleu_score:.4f}")
    print(f"🔹 F1 Score       : {f1_score_avg*100:.2f}%")
    print(f"🔹 Perplexity     : {perplexity_avg:.2f}")

# ✅ Example call:
# evaluate_model(model, tokenizer, val_df, device='cuda')


✅ Loaded val_df with 179 samples


In [8]:
evaluate_model(model, tokenizer, val_df, device='cuda')


100%|██████████| 100/100 [00:31<00:00,  3.14it/s]

🔹 BLEU Score     : 0.0361
🔹 F1 Score       : 17.03%
🔹 Perplexity     : 10819.09





In [None]:
import torch
import logging
import re
import sys
import os
from transformers import AutoTokenizer, T5ForConditionalGeneration
from typing import List, Optional

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class FitnessChatbot:
    # Define fitness-related keywords
    FITNESS_KEYWORDS = [
        'exercise', 'workout', 'fitness', 'nutrition', 'muscle', 'cardio', 'strength', 
        'yoga', 'running', 'sleep', 'stress', 'recovery', 'flexibility', 'balance', 
        'posture', 'hydration', 'motivation', 'diet', 'weight', 'training', 'gym', 
        'health', 'stretch', 'protein', 'calorie', 'endurance', 'aerobic', 'anaerobic'
    ]

    def __init__(self, model_path: str, max_length: int = 512):
        """
        Initialize the fitness chatbot with a trained T5 model and tokenizer.
        Args:
            model_path (str): Path to the trained model directory.
            max_length (int): Maximum token length for input/output.
        """
        self.max_length = max_length
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Using device: {self.device}")
        self.tokenizer = None
        self.model = None
        self.load_model(model_path)

    def load_model(self, model_path: str) -> None:
        """
        Load the T5 model and tokenizer from the specified path.
        Args:
            model_path (str): Path to the trained model directory.
        """
        try:
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model path {model_path} does not exist.")
            logger.info(f"Loading model from {model_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = T5ForConditionalGeneration.from_pretrained(model_path).to(self.device)
            self.model.eval()
            logger.info("Model and tokenizer loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            sys.exit(1)

    def clean_text(self, text: str) -> str:
        """
        Clean input text by converting to lowercase, removing extra spaces, and special characters.
        Args:
            text (str): Input text to clean.
        Returns:
            str: Cleaned text.
        """
        if not isinstance(text, str):
            return ""
        text = text.lower().strip()
        text = re.sub(r'\s+', ' ', text)
        text = re.sub(r'[^\w\s.,!?-]', '', text)
        return text

    def is_fitness_related(self, question: str) -> bool:
        """
        Check if the question is fitness-related based on keywords.
        Args:
            question (str): User's question.
        Returns:
            bool: True if fitness-related, False otherwise.
        """
        question = self.clean_text(question)
        if not question or len(question.split()) < 3:
            return False
        pattern = '|'.join(self.FITNESS_KEYWORDS)
        return bool(re.search(pattern, question, re.IGNORECASE))

    def format_prompt(self, question: str, prompt_type: str = "instruct") -> str:
        """
        Format the input question as a prompt for the T5 model.
        Args:
            question (str): User's question.
            prompt_type (str): Type of prompt formatting ('instruct' or 'qa').
        Returns:
            str: Formatted prompt.
        """
        if prompt_type == "instruct":
            return f"Answer this fitness-related question with specific advice: {question}"
        elif prompt_type == "qa":
            return f"question: {question}\nanswer:"
        return question

    def generate_response(self, question: str) -> str:
        """
        Generate a response to the user's question using the T5 model.
        Args:
            question (str): User's question.
        Returns:
            str: Generated response or error/rejection message.
        """
        try:
            # Clean and validate input
            question = self.clean_text(question)
            if not question or len(question.split()) < 3:
                return "Please ask a valid question with at least a few words."

            # Check if question is fitness-related
            if not self.is_fitness_related(question):
                return "Sorry, I can only answer fitness-related questions. Please ask about exercise, nutrition, or health!"

            # Format prompt and tokenize
            prompt = self.format_prompt(question)
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=self.max_length,
                truncation=True
            ).to(self.device)

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=64,
                    num_beams=5,
                    no_repeat_ngram_size=2,
                    early_stopping=True
                )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

            # Fallback for vague responses
            if len(response.split()) < 5 or "sorry" in response.lower():
                return "I couldn't provide a detailed answer. Try rephrasing or asking something like 'What are good protein sources for muscle gain?'"

            return response

        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return "An error occurred while generating the response. Please try again."

    def run(self) -> None:
        """
        Run the interactive chatbot loop, allowing users to type questions.
        """
        print("\n=== Fitness Chatbot ===")
        print("I'm here to answer fitness-related questions about exercise, nutrition, and health!")
        print("Type 'exit' or 'quit' to stop.")
        print("Examples: 'How can I improve my running endurance?' or 'What should I eat before a workout?'")
        print("Note: I only respond to fitness-related questions.\n")

        while True:
            try:
                user_input = input("You: ").strip()
                if user_input.lower() in ['exit', 'quit']:
                    print("Goodbye!")
                    break

                response = self.generate_response(user_input)
                print(f"Bot: {response}\n")

            except KeyboardInterrupt:
                print("\nInterrupted. Goodbye!")
                break
            except Exception as e:
                logger.error(f"Error in chatbot loop: {e}")
                print("An error occurred. Please try again.\n")

def main():
    # Define model path (update if needed)
    model_path = "/kaggle/working/processed_data/fitness_qa_model_best_model_1"
    
    # Initialize and run chatbot
    chatbot = FitnessChatbot(model_path)
    chatbot.run()

if __name__ == "__main__":
    main()


=== Fitness Chatbot ===
I'm here to answer fitness-related questions about exercise, nutrition, and health!
Type 'exit' or 'quit' to stop.
Examples: 'How can I improve my running endurance?' or 'What should I eat before a workout?'
Note: I only respond to fitness-related questions.



You:  what is agriculuture


Bot: Sorry, I can only answer fitness-related questions. Please ask about exercise, nutrition, or health!



You:   how can i optimize my nutrition to support my fitness goals?


Bot: maximizing your nutrition to support your fitness goals is a great way to improve your health and well-being.



You:  why should I optimize my nutrition to support my fitness goals?


Bot: you should optimize your nutrition to support your fitness goals.



You:  are a gay?


Bot: Sorry, I can only answer fitness-related questions. Please ask about exercise, nutrition, or health!

