# 🎓 Healthcare Chatbot - Model Training

In this notebook, you'll:
1. Set up training configuration
2. Choose and load a pre-trained model
3. Train your healthcare chatbot
4. Monitor training progress
5. Save and test your trained model

Let's train your AI doctor! 🤖🏥

## ⚙️ Step 1: Training Configuration

First, let's set up the training parameters and check our data.

In [None]:
import sys
import os
import json
import time
import torch
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Add src to path
sys.path.append('/workspace/src')

# Set up plotting
%matplotlib inline

print("🔧 Training environment setup complete!")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🎮 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
    print(f"🎮 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("💻 Using CPU (training will be slower)")

In [None]:
# Training Configuration - MODIFY THESE SETTINGS
# Single source of truth for every tunable value read by the cells below.
TRAINING_CONFIG = dict(
    # --- Dataset settings ---
    dataset_path='/workspace/data/kaggle_medical_dataset.json',  # path to your processed dataset
    max_samples=None,   # cap on dataset size (None = use all data)
    train_split=0.8,    # fraction used for training; the remainder is validation

    # --- Model settings ---
    model_key='distilgpt2',  # one of: distilgpt2, gpt2, dialogpt-medium
    output_dir='/workspace/models/my_healthcare_chatbot',

    # --- Training hyperparameters ---
    epochs=2,            # number of training epochs (start with 1-2 for testing)
    batch_size=2,        # training batch size (reduce if out of memory)
    learning_rate=5e-5,  # learning rate
    warmup_steps=50,     # warmup steps

    # --- Training options ---
    skip_hyperparameter_search=True,  # set to False for automatic optimization
    use_early_stopping=True,
)

# Echo the effective configuration, one "key: value" line per setting
print("⚙️ TRAINING CONFIGURATION")
print("="*40)
print("\n".join(f"{setting}: {TRAINING_CONFIG[setting]}" for setting in TRAINING_CONFIG))

# Practical guidance for first-time runs
print("\n💡 TIPS:")
for tip_line in (
    "- Start with distilgpt2 for quick testing",
    "- Use dialogpt-medium for better quality (slower)",
    "- Reduce batch_size if you get memory errors",
    "- Start with 1-2 epochs for testing, then increase",
):
    print(tip_line)

In [None]:
# Check dataset availability
# Loads the JSON file named in TRAINING_CONFIG, prints a short preview plus the
# planned train/validation split, and publishes two globals that the later
# cells guard on:
#   dataset_ready (bool) - True only when a non-empty list of samples was loaded
#   dataset              - the loaded samples, or None when unavailable
dataset_path = TRAINING_CONFIG['dataset_path']

print("📊 CHECKING DATASET")
print("="*30)

# Start from the "not ready" state so that every failure path below still
# leaves both globals defined for the downstream `if dataset_ready:` checks.
dataset_ready = False
dataset = None

if os.path.exists(dataset_path):
    try:
        # Explicit UTF-8: open()'s default encoding depends on the platform
        # locale, which can break on non-ASCII text in the dataset.
        with open(dataset_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses
        print(f"❌ Could not read dataset: {e}")
        dataset = None
    else:
        print(f"✅ Dataset found: {dataset_path}")

    if dataset is not None and not isinstance(dataset, list):
        # The preview and the slicing below both require a list of samples
        print("❌ Unexpected dataset format: expected a JSON list of Q&A samples")
        dataset = None

    if dataset is not None:
        print(f"📊 Total samples: {len(dataset)}")

        # Apply sample limit if specified
        # NOTE(review): this only truncates the in-notebook `dataset` preview.
        # The training setup cell passes `dataset_path` to prepare_data(), so
        # whether max_samples affects training depends on HealthcareFinetuner
        # - confirm.
        if TRAINING_CONFIG['max_samples'] and len(dataset) > TRAINING_CONFIG['max_samples']:
            dataset = dataset[:TRAINING_CONFIG['max_samples']]
            print(f"📊 Limited to: {len(dataset)} samples")

        if dataset:
            # Show sample (tolerating records that lack the expected keys)
            sample = dataset[0]
            if isinstance(sample, dict):
                print(f"\n📝 Sample Q&A:")
                print(f"   Q: {sample.get('question', 'N/A')}")
                print(f"   A: {str(sample.get('answer', 'N/A'))[:100]}...")
            else:
                print("\n⚠️ First sample is not a question/answer record - check the dataset format")

            # Calculate split sizes (int() truncates, so any remainder goes to validation)
            train_size = int(len(dataset) * TRAINING_CONFIG['train_split'])
            val_size = len(dataset) - train_size

            print(f"\n📈 Training split:")
            print(f"   Training: {train_size} samples ({TRAINING_CONFIG['train_split']*100:.0f}%)")
            print(f"   Validation: {val_size} samples ({(1-TRAINING_CONFIG['train_split'])*100:.0f}%)")

            dataset_ready = True
        else:
            # An empty dataset cannot be trained on, so do not mark it as ready
            print("❌ Dataset is empty - nothing to train on")

else:
    print(f"❌ Dataset not found: {dataset_path}")
    print("\n🔧 SOLUTIONS:")
    print("1. Run notebook 02_Data_Exploration.ipynb first")
    print("2. Or run: python setup_kaggle_dataset.py")
    print("3. Check that your dataset path is correct")

## 🤖 Step 2: Model Selection and Setup

Let's load and configure the pre-trained model for fine-tuning.

In [None]:
# Model information
# Descriptive catalogue of the models offered by this notebook. It is used for
# display only; the values are not read by the training code in this notebook.
MODEL_INFO = {
    'distilgpt2': {
        'description': 'Lightweight GPT-2 variant (fastest)',
        'size': '~350MB',
        'speed': 'Fast',
        'quality': 'Good',
        'recommended_for': 'Testing, small datasets, quick experiments'
    },
    'gpt2': {
        'description': 'Original GPT-2 base model',
        'size': '~500MB',
        'speed': 'Medium',
        'quality': 'Good',
        'recommended_for': 'Balanced performance and speed'
    },
    'dialogpt-medium': {
        'description': 'Microsoft DialoGPT optimized for conversation',
        'size': '~1.5GB',
        'speed': 'Slower',
        'quality': 'Best',
        'recommended_for': 'Production use, best conversational quality'
    }
}

selected_model = TRAINING_CONFIG['model_key']

print("🤖 MODEL SELECTION")
print("="*30)
print(f"Selected model: {selected_model}")

if selected_model in MODEL_INFO:
    print(f"\n📋 Model Information:")
    for field_name, field_value in MODEL_INFO[selected_model].items():
        print(f"   {field_name}: {field_value}")
else:
    # Not fatal here: the key is still handed to the model manager later
    print(f"⚠️ Unknown model: {selected_model}")

print("\n🔄 Available Models:")
# The loop variables are deliberately not named `model` / `info`: at cell top
# level they are notebook globals, and re-running this cell after the model has
# been loaded would overwrite the loaded `model` object with a string.
for model_name, model_details in MODEL_INFO.items():
    icon = "👉" if model_name == selected_model else "  "
    print(f"{icon} {model_name}: {model_details['description']} ({model_details['size']})")

In [None]:
# Initialize model manager and load model
# Publishes two globals that later cells guard on:
#   model_loaded (bool) and model_manager (manager instance, or None on failure).
if not dataset_ready:
    print("⚠️ Cannot load model - dataset not ready")
    model_loaded = False
    model_manager = None

else:
    print("🔄 LOADING MODEL")
    print("="*25)

    try:
        # Project-local module (expected under the src directory added to sys.path)
        from model_manager import HealthcareModelManager

        # 'auto' asks the manager to pick the device (GPU or CPU) itself
        model_manager = HealthcareModelManager(
            model_key=TRAINING_CONFIG['model_key'],
            device='auto',
        )

        print(f"🔄 Loading {TRAINING_CONFIG['model_key']}...")
        print("   (This may take a few minutes for first download)")

        # Fetch the pre-trained weights and the matching tokenizer
        model, tokenizer = model_manager.load_model_and_tokenizer()

        print("✅ Model loaded successfully!")

        # Summarise what was loaded; the key itself was already printed above
        model_info = model_manager.get_model_info()

        print(f"\n📊 Model Details:")
        for detail_name, detail_value in model_info.items():
            if detail_name == 'model_key':
                continue
            print(f"   {detail_name}: {detail_value}")

        model_loaded = True

    except Exception as load_error:
        # Any failure (import, download, memory) leaves the notebook in a
        # well-defined "not loaded" state instead of aborting the run.
        print(f"❌ Error loading model: {load_error}")
        print("\n🔧 SOLUTIONS:")
        print("1. Check internet connection (for model download)")
        print("2. Try a smaller model (distilgpt2)")
        print("3. Restart the notebook if memory issues")

        model_loaded = False
        model_manager = None

## 🎓 Step 3: Training Setup and Execution

Now let's set up the training pipeline and start training!

In [None]:
# Initialize fine-tuner
# Publishes three globals for the training cells:
#   training_ready (bool), finetuner (instance or None),
#   training_dataset (object indexable by 'train' / 'validation', or None).
if not (dataset_ready and model_loaded):
    print("⚠️ Cannot set up training - prerequisites not met")
    training_ready = False
    finetuner = None
    training_dataset = None

else:
    print("🎓 SETTING UP TRAINING")
    print("="*30)

    try:
        # Project-local module (expected under the src directory added to sys.path)
        from fine_tuning import HealthcareFinetuner

        finetuner = HealthcareFinetuner(
            model_key=TRAINING_CONFIG['model_key'],
            output_dir=TRAINING_CONFIG['output_dir'],
        )

        print("✅ Fine-tuner initialized")

        # The fine-tuner reads the dataset from disk itself and performs the
        # train/validation split according to the configured fraction.
        print("🔄 Preparing dataset...")
        training_dataset = finetuner.prepare_data(
            dataset_path=TRAINING_CONFIG['dataset_path'],
            train_split=TRAINING_CONFIG['train_split'],
        )

        train_count = len(training_dataset['train'])
        validation_count = len(training_dataset['validation'])
        print(f"✅ Dataset prepared:")
        print(f"   Training samples: {train_count}")
        print(f"   Validation samples: {validation_count}")

        training_ready = True

    except Exception as setup_error:
        # Reset everything so later cells see a consistent "not ready" state
        print(f"❌ Error setting up training: {setup_error}")
        training_ready = False
        finetuner = None
        training_dataset = None

In [None]:
# Estimate training time
# Prints a back-of-the-envelope duration from the sample count, the configured
# epochs and batch size, and a rough per-step cost for the chosen model.
if training_ready:
    print("⏱️ TRAINING TIME ESTIMATION")
    print("="*35)

    # Rough estimates based on model and data size
    train_samples = len(training_dataset['train'])
    epochs = TRAINING_CONFIG['epochs']
    batch_size = TRAINING_CONFIG['batch_size']

    # Calculate steps with ceiling division: a final, smaller batch is still a
    # training step, and floor division would report 0 steps for a dataset
    # smaller than one batch.
    # NOTE(review): assumes the trainer keeps the last partial batch (the
    # Hugging Face Trainer default) - confirm against HealthcareFinetuner.
    steps_per_epoch = (train_samples + batch_size - 1) // batch_size
    total_steps = steps_per_epoch * epochs

    # Time estimates (very rough): model -> (seconds/step on GPU, on CPU).
    # Any model not listed falls back to the dialogpt-medium figures.
    seconds_per_step_table = {
        'distilgpt2': (2, 8),
        'gpt2': (3, 12),
    }
    gpu_seconds, cpu_seconds = seconds_per_step_table.get(TRAINING_CONFIG['model_key'], (5, 20))
    seconds_per_step = gpu_seconds if torch.cuda.is_available() else cpu_seconds

    estimated_seconds = total_steps * seconds_per_step
    estimated_minutes = estimated_seconds / 60

    print(f"📊 Training Overview:")
    print(f"   Samples: {train_samples}")
    print(f"   Epochs: {epochs}")
    print(f"   Batch size: {batch_size}")
    print(f"   Steps per epoch: {steps_per_epoch}")
    print(f"   Total steps: {total_steps}")

    print(f"\n⏱️ Estimated training time: {estimated_minutes:.1f} minutes")

    # Past half an hour, suggest the cheapest ways to shorten the run
    if estimated_minutes > 30:
        print("\n⚠️ Training will take a while. Consider:")
        print("   - Reducing epochs or max_samples")
        print("   - Using a smaller model (distilgpt2)")
        print("   - Increasing batch_size if you have GPU memory")

    print(f"\n🚀 Ready to start training!")

In [None]:
# START TRAINING!
# Runs the fine-tuning job and publishes three globals for the later cells:
#   training_completed (bool), model_path (value of training_results['model_path']
#   or None), training_results (the object returned by fine_tune, or None).
if training_ready:
    print("🚀 STARTING TRAINING")
    print("="*30)
    print(f"⏰ Start time: {datetime.now().strftime('%H:%M:%S')}")
    print("\n🔄 Training in progress...")
    print("   (This cell will show progress updates)")

    try:
        # Prepare hyperparameters
        if TRAINING_CONFIG['skip_hyperparameter_search']:
            hyperparameters = {
                'learning_rate': TRAINING_CONFIG['learning_rate'],
                'per_device_train_batch_size': TRAINING_CONFIG['batch_size'],
                'num_train_epochs': TRAINING_CONFIG['epochs'],
                'warmup_steps': TRAINING_CONFIG['warmup_steps']
            }
            print(f"📋 Using manual hyperparameters: {hyperparameters}")
        else:
            # None is passed through; presumably the fine-tuner then runs its
            # own search - confirm against HealthcareFinetuner.fine_tune.
            hyperparameters = None
            print("🔍 Will perform hyperparameter search...")

        # Start training (wall-clock timed around the call)
        start_time = time.time()

        training_results = finetuner.fine_tune(
            dataset=training_dataset,
            hyperparameters=hyperparameters,
            use_early_stopping=TRAINING_CONFIG['use_early_stopping']
        )

        end_time = time.time()
        training_time = end_time - start_time

        # A result without 'model_path' is still treated as a failed run: the
        # KeyError is handled by the except branch below.
        model_path = training_results['model_path']

        # Format the losses defensively. Applying ':.4f' to a missing metric
        # (the 'N/A' fallback or None) raises, and because this code runs
        # inside the try block that would report a successful run as failed.
        loss_text = {}
        for metric_name in ('train_loss', 'eval_loss'):
            try:
                loss_text[metric_name] = f"{training_results.get(metric_name):.4f}"
            except (TypeError, ValueError):
                loss_text[metric_name] = "N/A"

        print("\n🎉 TRAINING COMPLETED!")
        print("="*30)
        print(f"⏰ Training time: {training_time/60:.1f} minutes")
        print(f"📊 Final training loss: {loss_text['train_loss']}")
        print(f"📊 Final validation loss: {loss_text['eval_loss']}")
        print(f"💾 Model saved to: {model_path}")

        training_completed = True

    except Exception as e:
        print(f"\n❌ Training failed: {e}")
        print("\n🔧 TROUBLESHOOTING:")
        print("1. Check GPU memory (reduce batch_size)")
        print("2. Try a smaller model (distilgpt2)")
        print("3. Reduce max_samples or epochs")
        print("4. Restart notebook to free memory")

        # Leave a consistent "failed" state for the cells below
        training_completed = False
        model_path = None
        training_results = None

        # Print full error for debugging
        import traceback
        print(f"\n🐛 Full error:\n{traceback.format_exc()}")

else:
    print("⚠️ Cannot start training - setup incomplete")
    training_completed = False
    model_path = None
    training_results = None

## 📊 Step 4: Training Results Analysis

Let's analyze the training results and visualize the progress.

In [None]:
# Analyze training results
# Prints the headline metrics returned by the fine-tuner and, when a training
# history is available, plots the loss curves and the learning-rate schedule.
if training_completed and training_results:
    print("📊 TRAINING RESULTS ANALYSIS")
    print("="*40)

    # Format the losses defensively: applying ':.4f' to a missing metric (the
    # 'N/A' fallback or None) raises and would abort this cell.
    loss_text = {}
    for metric_name in ('train_loss', 'eval_loss'):
        try:
            loss_text[metric_name] = f"{training_results.get(metric_name):.4f}"
        except (TypeError, ValueError):
            loss_text[metric_name] = "N/A"

    # Display key metrics (`or` fallbacks also cover keys present with value None)
    results_summary = {
        'Final Training Loss': loss_text['train_loss'],
        'Final Validation Loss': loss_text['eval_loss'],
        'Training Runtime': f"{(training_results.get('train_runtime') or 0)/60:.1f} minutes",
        'Model Path': training_results.get('model_path', 'N/A'),
        'Hyperparameters': training_results.get('hyperparameters') or {}
    }

    for key, value in results_summary.items():
        if key == 'Hyperparameters':
            # Nested dict: print one indented line per hyperparameter
            print(f"{key}:")
            for hp_key, hp_value in value.items():
                print(f"   {hp_key}: {hp_value}")
        else:
            print(f"{key}: {value}")

    # Plot training curves if available
    training_history = training_results.get('training_history', [])

    if training_history:
        print("\n📈 Plotting training curves...")

        # Extract losses: entries are classified by which keys they carry
        train_logs = [log for log in training_history if 'loss' in log]
        eval_logs = [log for log in training_history if 'eval_loss' in log]

        # Plot whichever curves exist; a run without evaluation logs (or
        # without training-loss logs) still produces a useful figure.
        if train_logs or eval_logs:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

            # Training and validation loss
            if train_logs:
                train_steps = [log['step'] for log in train_logs]
                train_losses = [log['loss'] for log in train_logs]
                ax1.plot(train_steps, train_losses, label='Training Loss', alpha=0.7)
            if eval_logs:
                eval_steps = [log['step'] for log in eval_logs]
                eval_losses = [log['eval_loss'] for log in eval_logs]
                ax1.plot(eval_steps, eval_losses, label='Validation Loss', alpha=0.7)
            ax1.set_xlabel('Steps')
            ax1.set_ylabel('Loss')
            ax1.set_title('Training Progress')
            ax1.legend()
            ax1.grid(True, alpha=0.3)

            # Learning rate schedule
            lr_logs = [log for log in training_history if 'learning_rate' in log]
            if lr_logs:
                lr_steps = [log['step'] for log in lr_logs]
                lr_values = [log['learning_rate'] for log in lr_logs]
                ax2.plot(lr_steps, lr_values, label='Learning Rate', color='orange')
                ax2.set_xlabel('Steps')
                ax2.set_ylabel('Learning Rate')
                ax2.set_title('Learning Rate Schedule')
                ax2.legend()
                ax2.grid(True, alpha=0.3)
            else:
                # Keep the second panel but explain why it is empty
                ax2.text(0.5, 0.5, 'No learning rate data\navailable',
                         ha='center', va='center', transform=ax2.transAxes)
                ax2.set_title('Learning Rate Schedule')

            plt.tight_layout()
            plt.show()

        else:
            print("⚠️ No training curves available to plot")

    else:
        print("⚠️ No training history available")

else:
    print("⚠️ No training results to analyze")

## 🧪 Step 5: Quick Model Testing

Let's test our trained model with some sample questions!

In [None]:
# Test the trained model
# Smoke test: load the saved model through the project chatbot wrapper and ask
# a fixed set of questions. Publishes the global testing_completed (bool).
if not (training_completed and model_path):
    print("⚠️ Cannot test model - training not completed")
    testing_completed = False

else:
    print("🧪 TESTING TRAINED MODEL")
    print("="*35)

    try:
        # Project-local module (expected under the src directory added to sys.path)
        from chatbot import HealthcareChatbot

        print("🔄 Loading trained model for testing...")
        chatbot = HealthcareChatbot(model_path)

        print("✅ Chatbot initialized!")

        # Fixed prompts for a quick qualitative look at the responses
        test_questions = [
            "What are the symptoms of diabetes?",
            "How can I prevent heart disease?",
            "What should I do if I have a fever?",
            "How much water should I drink daily?",
            "What are the side effects of aspirin?"
        ]

        separator = "-" * 50
        print("\n🤖 CHATBOT RESPONSES:")
        print(separator)

        for number, question in enumerate(test_questions, start=1):
            print(f"\n{number}. 👤 User: {question}")

            # One failing question must not stop the remaining ones
            try:
                response = chatbot.chat(question)
            except Exception as chat_error:
                print(f"   ❌ Error: {chat_error}")
            else:
                try:
                    print(f"   🏥 Bot: {response['response']}")
                    print(f"   📊 Type: {response['type']}")
                except Exception as chat_error:
                    print(f"   ❌ Error: {chat_error}")

            print(separator)

        print("\n✅ Model testing completed!")
        testing_completed = True

    except Exception as test_error:
        # Failure to import or construct the chatbot ends the smoke test
        print(f"❌ Error testing model: {test_error}")
        print("\n🔧 TROUBLESHOOTING:")
        print("1. Make sure training completed successfully")
        print("2. Check that model files were saved properly")
        print("3. Try restarting the notebook")

        testing_completed = False

## 🎯 Step 6: Next Steps

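If training completed successfully, congratulations: you've trained your healthcare chatbot! The cell below summarizes the run and suggests what to do next; if training did not finish, it lists troubleshooting steps instead: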

In [None]:
# Session wrap-up: summarise the outcome and point to the follow-up notebooks,
# or list troubleshooting steps when training did not complete.
print("🎉 TRAINING SESSION COMPLETE!")
print("="*50)

if training_completed:
    print(f"✅ Successfully trained healthcare chatbot!")
    print(f"💾 Model saved to: {model_path}")

    # Read the validation loss once; it may be absent or None
    eval_loss = training_results.get('eval_loss') if training_results else None

    if training_results:
        # Applying ':.4f' to a missing metric (the 'N/A' fallback or None)
        # raises, so format defensively.
        try:
            eval_loss_text = f"{eval_loss:.4f}"
        except (TypeError, ValueError):
            eval_loss_text = "N/A"
        print(f"📊 Final validation loss: {eval_loss_text}")
        print(f"⏱️ Training time: {(training_results.get('train_runtime') or 0)/60:.1f} minutes")

    print("\n🚀 NEXT STEPS:")
    print("-"*20)

    print("1. 📈 Evaluate Your Model:")
    print("   → Open: notebooks/04_Model_Evaluation.ipynb")
    print("   → Get detailed performance metrics")

    print("\n2. 🌐 Deploy Your Chatbot:")
    print("   → Open: notebooks/05_Deployment.ipynb")
    print("   → Launch web interface")

    print("\n3. 🖥️ Quick Testing Options:")
    print(f"   → CLI: python -m src.chatbot {model_path}")
    print(f"   → Web: python -m src.web_interface --model_path {model_path}")

    print("\n4. 🔧 Model Improvement Ideas:")
    # A validation loss above 2.0 is this notebook's threshold for suggesting
    # more training. An unknown loss is treated as "not high" (as a missing key
    # was before) instead of raising on a None comparison.
    try:
        high_eval_loss = float(eval_loss) > 2.0
    except (TypeError, ValueError):
        high_eval_loss = False

    if high_eval_loss:
        print("   → Try more epochs (increase from current)")
        print("   → Use a larger model (dialogpt-medium)")
        print("   → Add more training data")
    else:
        print("   → Your model looks good!")
        print("   → Consider fine-tuning hyperparameters")
        print("   → Test with more diverse questions")

else:
    print("❌ Training was not completed successfully.")
    print("\n🔧 TROUBLESHOOTING STEPS:")
    print("-"*25)
    print("1. Check error messages above")
    print("2. Try reducing batch_size or epochs")
    print("3. Use a smaller model (distilgpt2)")
    print("4. Restart notebook and try again")
    print("5. Check dataset format and size")

print("\n📚 RESOURCES:")
print("-"*15)
print("• Training Guide: KAGGLE_DATASET_GUIDE.md")
print("• Quick Commands: YOUR_KAGGLE_DATASET_INSTRUCTIONS.md")
print("• Next Notebook: 04_Model_Evaluation.ipynb")
print("• Deployment: 05_Deployment.ipynb")

print("\n💡 TIP: Save this notebook to preserve your training configuration!")

In [None]:
# Save training summary for reference
# Writes a JSON record of this session (configuration, headline metrics, model
# details) next to the notebooks. Saving is best-effort: a write failure is
# reported but does not raise.
if training_completed and training_results:
    # Headline metrics copied from the fine-tuner's result (None when absent)
    results_section = {'model_path': model_path}
    for metric_name in ('train_loss', 'eval_loss', 'train_runtime'):
        results_section[metric_name] = training_results.get(metric_name)

    summary = {
        'timestamp': datetime.now().isoformat(),
        'config': TRAINING_CONFIG,
        'results': results_section,
        # These names only exist if the corresponding earlier cells succeeded
        'model_info': locals().get('model_info', {}),
        'testing_completed': locals().get('testing_completed', False),
    }

    summary_path = '/workspace/notebooks/training_session_summary.json'

    try:
        # default=str stringifies any value json cannot serialise natively
        with open(summary_path, 'w') as summary_file:
            json.dump(summary, summary_file, indent=2, default=str)
    except Exception as save_error:
        print(f"⚠️ Could not save summary: {save_error}")
    else:
        print(f"💾 Training summary saved to: {summary_path}")

print("\n🎊 Ready for the next step in your AI journey!")