# Mahjong DQN Training - Interactive Notebook

This notebook trains the Mahjong DQN agent interactively. It validates the Stage 1 model before Stage 2 training begins.

## Training Pipeline:
1. **Stage 1**: Basic win/loss learning (fundamental game knowledge)
2. **Validation**: Test Stage 1 model performance
3. **Stage 2**: Scoring system integration (strategic depth)
4. **Final Evaluation**: Compare Stage 1 vs Stage 2 performance

In [None]:
# Import required libraries
import os
import sys
import json
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
from datetime import datetime

# Add project root to path
if '.' not in sys.path:
    sys.path.append('.')

# Import training modules
from ai.train_stage1 import Stage1Trainer
from ai.train_stage2 import Stage2Trainer
from ai.evaluate import MahjongEvaluator
from ai.dqn_agent import DQNAgent
from ai.utils.config import dqn_config

# Device selection: CUDA > MPS > CPU, with an optional manual override.
def get_best_device(preferred=None):
    """Return the torch device string to train on.

    Args:
        preferred: Optional device name (e.g. "cpu") that bypasses auto-detection.
            When omitted, the ``MAHJONG_DEVICE`` environment variable is checked
            next; if that is unset or blank, the device is auto-detected.

    Returns:
        The override (stripped and lower-cased) if one was given; otherwise
        "cuda" if CUDA is available, else "mps" if the MPS backend is
        available, else "cpu".
    """
    override = (preferred or os.environ.get("MAHJONG_DEVICE") or "").strip().lower()
    if override:
        # Escape hatch: if a trainer cannot run on the auto-selected backend
        # (e.g. a cpu/mps tensor mismatch), force a device such as "cpu"
        # without editing this cell.
        return override
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


device = get_best_device()
print("Training environment ready!")
print(f"Selected device: {device.upper()}")
print(f"PyTorch version: {torch.__version__}")

# Note: Device will be passed to trainers explicitly instead of using set_default_device

## Configuration Setup

Configure training parameters for both stages. Adjust these based on your computational resources and time constraints.

In [None]:
# Training Configuration
RULE = "standard"  # or "taiwan"
DEVICE = device    # automatically selected device (CUDA > MPS > CPU)

# Stage 1: basic win/loss learning with CTDE.
# Episode budget is reduced for notebook runs (the full run normally uses 10000).
stage1_config = dict(
    num_episodes=1000,
    rule_name=RULE,
    model_dir='ai/models/stage1_notebook',
    log_dir='ai/logs/stage1_notebook',
    centralized_training=True,      # Centralized Training, Decentralized Execution
    save_frequency=50,              # frequent checkpoints so validation has a model
    eval_frequency=100,
    log_frequency=10,
    target_win_rate=0.25,
    early_stopping_patience=500,
    max_steps_per_episode=200,
    device=DEVICE,                  # trainers receive the device explicitly
)

# Stage 2: scoring integration with CTDE.
# Episode budget is reduced for notebook runs (the full run normally uses 15000).
stage2_config = dict(
    num_episodes=1500,
    rule_name=RULE,
    model_dir='ai/models/stage2_notebook',
    log_dir='ai/logs/stage2_notebook',
    centralized_training=True,
    save_frequency=50,
    eval_frequency=150,
    log_frequency=10,
    target_win_rate=0.25,
    target_avg_score=50,
    early_stopping_patience=750,
    learning_rate_decay=0.95,
    decay_frequency=500,
    curriculum_learning=True,
    score_thresholds=[30, 50, 80, 120],
    max_steps_per_episode=200,
    device=DEVICE,
)

# Create the model and log directories for both stages up front.
for _stage_cfg in (stage1_config, stage2_config):
    for _dir_key in ('model_dir', 'log_dir'):
        os.makedirs(_stage_cfg[_dir_key], exist_ok=True)
del _stage_cfg, _dir_key

print(f"Configuration loaded for {RULE} rule with CTDE")
print(f"Training device: {DEVICE}")
for _stage_no, _stage_cfg in ((1, stage1_config), (2, stage2_config)):
    print(f"Stage {_stage_no}: {_stage_cfg['num_episodes']} episodes")
del _stage_no, _stage_cfg
print(f"Models will be saved to: {stage1_config['model_dir']} and {stage2_config['model_dir']}")
print("🤖 Using Centralized Training with Decentralized Execution (CTDE)")

## Stage 1 Training: Basic Win/Loss Learning

The first stage teaches the agent the fundamental Mahjong rules and basic win/loss patterns.

In [3]:
# Stage 1 Training
print("=" * 60)
print("STAGE 1 TRAINING: Basic Win/Loss Learning")
print("=" * 60)

# Initialise the outcome variables before training so that every later cell
# sees defined names regardless of how this cell ends.
# Outcomes: success, user interrupt, or error.
stage1_best_path = os.path.join(stage1_config['model_dir'], 'best_model.pth')
stage1_success = False
stage1_start_time = time.time()

try:
    # Building the trainer inside the try reports setup errors the same way
    # as training errors.
    stage1_trainer = Stage1Trainer(stage1_config)
    stage1_trainer.train()
    stage1_training_time = time.time() - stage1_start_time
    print(f"\n✅ Stage 1 training completed in {stage1_training_time:.1f} seconds")

    # Stage 1 counts as successful only if best_model.pth exists in model_dir.
    if os.path.exists(stage1_best_path):
        print(f"✅ Stage 1 best model saved to: {stage1_best_path}")
        stage1_success = True
    else:
        print("⚠️ Stage 1 completed but no best model found")

except KeyboardInterrupt:
    print("\n⏸️ Stage 1 training interrupted by user")
except Exception as e:
    print(f"\n❌ Stage 1 training failed: {e}")
    # Re-raise so "Run All" stops here and the full traceback is shown.
    raise

print(f"\nStage 1 Status: {'✅ Success' if stage1_success else '❌ Failed'}")

STAGE 1 TRAINING: Basic Win/Loss Learning
Training on device: mps
Starting Stage 1 Training: Basic Win/Loss Learning
Target episodes: 1000
Rule: standard
------------------------------------------------------------


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Setting up Mahjong game...
Created Player 1 (AI) with wind: 東
Created Player 2 (AI) with wind: 南
Created Player 3 (AI) with wind: 西
Created Player 4 (AI) with wind: 北
Wall shuffled.
Initial hands dealt to all players.
Game setup complete. Rule: standard
Starting player: Player 1 (AI)

❌ Stage 1 training failed: Tensor for argument input is on cpu but expected on mps





RuntimeError: Tensor for argument input is on cpu but expected on mps

## Stage 1 Model Validation

Evaluate the Stage 1 model performance before proceeding to Stage 2.

In [None]:
# Stage 1 Model Validation
print("=" * 60)
print("STAGE 1 MODEL VALIDATION")
print("=" * 60)

# Minimum win rate against random players required to proceed to Stage 2.
# NOTE(review): in a 4-player game a uniformly random policy should win roughly
# a quarter of the non-drawn games. That makes 0.20 a lenient sanity check
# rather than evidence of skill. Confirm against evaluate_against_random's setup.
STAGE1_MIN_WIN_RATE = 0.20
STAGE1_VALIDATION_GAMES = 200

stage1_validation_passed = False
eval_results = None  # stays None unless the evaluation completes cleanly

if stage1_success and os.path.exists(stage1_best_path):
    try:
        evaluator = MahjongEvaluator(stage1_best_path, rule=RULE)

        print("Evaluating Stage 1 model against random players...")
        eval_results = evaluator.evaluate_against_random(num_games=STAGE1_VALIDATION_GAMES)

        print("\n📊 Stage 1 Evaluation Results:")
        print(f"   Games Played: {eval_results['games_played']}")
        print(f"   Win Rate: {eval_results['win_rate']:.3f} ({eval_results['win_rate']*100:.1f}%)")
        print(f"   Average Score: {eval_results['average_score']:.1f}")

        if eval_results['game_lengths']:
            avg_length = np.mean(eval_results['game_lengths'])
            print(f"   Average Game Length: {avg_length:.1f} turns")

        if eval_results['win_rate'] >= STAGE1_MIN_WIN_RATE:
            print("\n✅ Stage 1 model performance is acceptable")
            print("   Ready to proceed to Stage 2 training")
            stage1_validation_passed = True
        else:
            print("\n⚠️ Stage 1 model performance is below threshold")
            print("   Consider retraining with more episodes")

    except Exception as e:
        print(f"❌ Stage 1 validation failed: {e}")
        stage1_validation_passed = False
        eval_results = None
else:
    print("⚠️ Stage 1 model not available for validation")

# Plots are drawn outside the evaluation try-block. A rendering problem is then
# reported as a warning instead of turning a passed validation into a failure.
if eval_results is not None:
    try:
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # Game outcome breakdown
        outcomes = ['Wins', 'Losses', 'Draws']
        values = [eval_results['wins'], eval_results['losses'], eval_results['draws']]
        colors = ['green', 'red', 'gray']
        if sum(values) > 0:
            axes[0].pie(values, labels=outcomes, colors=colors, autopct='%1.1f%%', startangle=90)
        else:
            # A pie of all-zero wedges is undefined; show a placeholder instead.
            axes[0].text(0.5, 0.5, 'No games recorded', ha='center', va='center')
            axes[0].set_axis_off()
        axes[0].set_title('Stage 1 Model: Game Outcomes')

        # Score distribution
        if eval_results['score_distribution']:
            axes[1].hist(eval_results['score_distribution'], bins=20, alpha=0.7, color='blue')
            axes[1].set_title('Stage 1 Model: Score Distribution')
            axes[1].set_xlabel('Score')
            axes[1].set_ylabel('Frequency')

        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"⚠️ Could not plot Stage 1 evaluation results: {e}")

print(f"\nStage 1 Validation: {'✅ Passed' if stage1_validation_passed else '❌ Failed'}")

## Decision Point: Proceed to Stage 2?

Based on the validation results above, decide whether to proceed with Stage 2 training.

In [None]:
# Decision Point
print("=" * 60)
print("DECISION POINT: Proceed to Stage 2?")
print("=" * 60)

# Pre-set answer to the "proceed anyway?" prompt shown when Stage 1 fails.
# Set this to True or False for "Restart & Run All" or headless execution.
# Leave it as None to be prompted interactively.
STAGE2_PROMPT_ANSWER = None

if stage1_success and stage1_validation_passed:
    print("✅ Stage 1 training was successful and validation passed")
    print("✅ Ready to proceed with Stage 2 training")
    proceed_to_stage2 = True
else:
    print("❌ Issues with Stage 1 training or validation")
    print("\nOptions:")
    print("1. Re-run Stage 1 with more episodes")
    print("2. Adjust Stage 1 configuration")
    print("3. Proceed to Stage 2 anyway (not recommended)")

    if STAGE2_PROMPT_ANSWER is not None:
        proceed_to_stage2 = bool(STAGE2_PROMPT_ANSWER)
        print(f"\nUsing STAGE2_PROMPT_ANSWER override: {proceed_to_stage2}")
    else:
        # Allow manual override
        try:
            user_choice = input("\nProceed to Stage 2 anyway? (y/n): ").lower().strip()
        except (EOFError, NotImplementedError):
            # Kernels without a stdin channel (e.g. nbconvert) cannot prompt.
            # IPython raises StdinNotImplementedError, a NotImplementedError
            # subclass, in that case. Default to the safe answer.
            print("\nNo interactive input available; defaulting to 'n'")
            user_choice = 'n'
        proceed_to_stage2 = user_choice in ['y', 'yes']

print(f"\nDecision: {'Proceed to Stage 2' if proceed_to_stage2 else 'Do not proceed'}")

## Stage 2 Training: Scoring System Integration

The second stage builds strategic depth by training the agent to optimize its score, not just to win.

In [None]:
# Stage 2 Training
# Initialise the outcome variables up front so later cells (comparison,
# summary) always see defined names, even when Stage 2 is skipped or fails.
stage2_success = False
stage2_best_path = None
pretrained_path = None

if proceed_to_stage2:
    print("=" * 60)
    print("STAGE 2 TRAINING: Scoring System Integration")
    print("=" * 60)

    stage2_start_time = time.time()
    stage2_best_path = os.path.join(stage2_config['model_dir'], 'best_model.pth')

    # Warm-start from Stage 1 only when it succeeded and its checkpoint exists.
    if stage1_success and os.path.exists(stage1_best_path):
        pretrained_path = stage1_best_path

    if pretrained_path:
        print(f"Using Stage 1 pretrained model: {pretrained_path}")
        # Stage2Trainer takes the directory holding the checkpoint, not the file.
        pretrained_dir = os.path.dirname(pretrained_path)
    else:
        print("Starting Stage 2 from scratch (no pretrained model)")
        pretrained_dir = None

    try:
        # Building the trainer inside the try reports setup errors the same
        # way as training errors.
        stage2_trainer = Stage2Trainer(stage2_config, pretrained_dir)
        stage2_trainer.train()
        stage2_training_time = time.time() - stage2_start_time
        print(f"\n✅ Stage 2 training completed in {stage2_training_time:.1f} seconds")

        # Stage 2 counts as successful only if best_model.pth exists in model_dir.
        if os.path.exists(stage2_best_path):
            print(f"✅ Stage 2 best model saved to: {stage2_best_path}")
            stage2_success = True
        else:
            print("⚠️ Stage 2 completed but no best model found")

    except KeyboardInterrupt:
        print("\n⏸️ Stage 2 training interrupted by user")
    except Exception as e:
        print(f"\n❌ Stage 2 training failed: {e}")
        # Re-raise so "Run All" stops here and the full traceback is shown.
        raise

    print(f"\nStage 2 Status: {'✅ Success' if stage2_success else '❌ Failed'}")
else:
    print("⏭️ Skipping Stage 2 training")

## Final Model Comparison

Compare the Stage 1 and Stage 2 models to measure whether Stage 2 training improved win rate and average score.

In [None]:
# Final Model Comparison
print("=" * 60)
print("FINAL MODEL COMPARISON")
print("=" * 60)

FINAL_EVAL_GAMES = 500


def _evaluate_model(label, model_path, num_games=FINAL_EVAL_GAMES):
    """Evaluate a saved checkpoint against random players and print headline metrics.

    Returns the results dict produced by ``MahjongEvaluator.evaluate_against_random``.
    """
    print(f"\n🔍 Evaluating {label} model...")
    evaluator = MahjongEvaluator(model_path, rule=RULE)
    results = evaluator.evaluate_against_random(num_games=num_games)
    print(f"{label} Final Results:")
    print(f"   Win Rate: {results['win_rate']:.3f}")
    print(f"   Average Score: {results['average_score']:.1f}")
    return results


def _relative_change_pct(new, old):
    """Percent change from ``old`` to ``new``, or None when ``old`` is 0 (undefined).

    The denominator is ``abs(old)`` so that an increase from a negative
    baseline is still reported as a positive change.
    """
    if old == 0:
        return None
    return (new - old) / abs(old) * 100


def _format_change(pct):
    """Render a percent change for printing, or 'n/a' when it is undefined."""
    return "n/a (Stage 1 baseline is 0)" if pct is None else f"{pct:+.1f}%"


def _plot_bars(ax, labels, values, colors, title, ylabel, label_offset, fmt, headroom=0.2):
    """Draw a labelled bar chart with y-limits that always contain 0 and every bar.

    ``max(values) * 1.2`` alone gives identical limits when all values are 0
    and an inverted range when they are negative; both cases are handled here.
    """
    ax.bar(labels, values, color=colors)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    low = min(0, min(values))
    high = max(0, max(values))
    span = (high - low) or 1.0
    ax.set_ylim(low - headroom * span if low < 0 else 0, high + headroom * span)
    # Add value labels on bars
    for i, v in enumerate(values):
        ax.text(i, v + label_offset, fmt.format(v), ha='center', va='bottom')


comparison_results = {}

# Evaluate Stage 1 model
if stage1_success and os.path.exists(stage1_best_path):
    stage1_final_results = _evaluate_model('Stage 1', stage1_best_path)
    comparison_results['stage1'] = stage1_final_results

# Evaluate Stage 2 model
if stage2_success and os.path.exists(stage2_best_path):
    stage2_final_results = _evaluate_model('Stage 2', stage2_best_path)
    comparison_results['stage2'] = stage2_final_results

# Compare models if both exist
if 'stage1' in comparison_results and 'stage2' in comparison_results:
    print("\n📊 Model Comparison:")

    stage1_wr = comparison_results['stage1']['win_rate']
    stage2_wr = comparison_results['stage2']['win_rate']
    wr_improvement = _relative_change_pct(stage2_wr, stage1_wr)

    stage1_score = comparison_results['stage1']['average_score']
    stage2_score = comparison_results['stage2']['average_score']
    score_improvement = _relative_change_pct(stage2_score, stage1_score)

    print(f"   Win Rate Improvement: {_format_change(wr_improvement)}")
    print(f"   Score Improvement: {_format_change(score_improvement)}")

    # Compare raw values rather than percentages. A zero or negative Stage 1
    # baseline then cannot break the verdict or flip its sign.
    if stage2_wr > stage1_wr or stage2_score > stage1_score:
        print("   ✅ Stage 2 training was beneficial!")
    else:
        print("   ⚠️ Stage 2 training may need adjustment")

    # Visualization
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    models = ['Stage 1', 'Stage 2']
    colors = ['lightblue', 'darkblue']
    win_rates = [stage1_wr, stage2_wr]
    scores = [stage1_score, stage2_score]

    _plot_bars(axes[0], models, win_rates, colors,
               'Win Rate Comparison', 'Win Rate', 0.01, '{:.3f}')
    _plot_bars(axes[1], models, scores, colors,
               'Average Score Comparison', 'Average Score', 1, '{:.1f}')

    plt.tight_layout()
    plt.show()

print("\n🏁 Training and evaluation completed!")

## Training Summary

Summary of the complete training process and final model paths.

In [None]:
# Training Summary
print("=" * 60)
print("TRAINING SUMMARY")
print("=" * 60)

# At notebook top level locals() is the global namespace. The membership
# checks let this cell run even if an earlier cell was skipped or failed.
summary = {
    'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    'rule': RULE,
    'stage1': {
        'success': stage1_success if 'stage1_success' in locals() else False,
        'model_path': stage1_best_path if 'stage1_best_path' in locals() and stage1_best_path and os.path.exists(stage1_best_path) else None,
        'episodes': stage1_config['num_episodes'],
        'validation_passed': stage1_validation_passed if 'stage1_validation_passed' in locals() else False
    },
    'stage2': {
        'success': stage2_success if 'stage2_success' in locals() else False,
        'model_path': stage2_best_path if 'stage2_best_path' in locals() and stage2_best_path and os.path.exists(stage2_best_path) else None,
        'episodes': stage2_config['num_episodes'],
        'used_pretrained': 'pretrained_path' in locals() and pretrained_path is not None
    },
    'comparison_results': comparison_results if 'comparison_results' in locals() else {}
}

# Display summary
print(f"Training Date: {summary['timestamp']}")
print(f"Rule: {summary['rule']}")
print()
print("Stage 1:")
print(f"   Success: {'✅' if summary['stage1']['success'] else '❌'}")
print(f"   Episodes: {summary['stage1']['episodes']}")
print(f"   Validation: {'✅ Passed' if summary['stage1']['validation_passed'] else '❌ Failed'}")
print(f"   Model: {summary['stage1']['model_path'] or 'Not available'}")
print()
print("Stage 2:")
print(f"   Success: {'✅' if summary['stage2']['success'] else '❌'}")
print(f"   Episodes: {summary['stage2']['episodes']}")
print(f"   Used Pretrained: {'✅' if summary['stage2']['used_pretrained'] else '❌'}")
print(f"   Model: {summary['stage2']['model_path'] or 'Not available'}")

# Save summary. Ensure the parent directory exists so this cell does not
# depend on another cell having created it. default=str serialises values
# json cannot encode natively (e.g. numpy integers in the evaluation results).
summary_path = 'ai/training_summary_notebook.json'
os.makedirs(os.path.dirname(summary_path) or '.', exist_ok=True)
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2, default=str)

print(f"\n📄 Training summary saved to: {summary_path}")

# Recommendations: prefer the Stage 2 model, fall back to Stage 1, and say so
# explicitly when neither stage produced a usable model.
stage2_model = summary['stage2']['model_path'] if summary['stage2']['success'] else None
stage1_model = summary['stage1']['model_path'] if summary['stage1']['success'] else None

print("\n💡 Next Steps:")
if stage2_model:
    print(f"1. Test your trained agent: python ai/play_vs_ai.py --models {stage2_model}")
    print(f"2. Detailed evaluation: python ai/evaluate.py --model {stage2_model} --games 1000")
elif stage1_model:
    print(f"1. Test Stage 1 agent: python ai/play_vs_ai.py --models {stage1_model}")
    print("2. Consider re-running Stage 2 with adjusted parameters")
else:
    print("1. No trained model is available yet; review the training output above")
    print("2. Re-run Stage 1 (e.g. with more episodes or a different device)")

print("\n🎉 Notebook training completed!")