In [1]:
#!pip install torch numpy matplotlib
#!pip install nest-asyncio
#!pip install "ray[rllib]"

In [1]:
# 필요한 라이브러리 임포트
import asyncio
import nest_asyncio
import os
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import json
import random
import ray
from ray.rllib.algorithms.dqn import DQNConfig
from ray.tune.logger import pretty_print
import torch
import logging

# 환경 관련 import
from env.battle_env import YakemonEnv

# 유틸리티 관련 import
from utils.battle_logics.create_battle_pokemon import create_battle_pokemon

# RL 관련 import
from RL.reward_calculator import calculate_reward
from RL.get_state_vector import get_state

# 데이터 관련 import
from p_data.move_data import move_data
from p_data.ability_data import ability_data
from p_data.mock_pokemon import create_mock_pokemon_list

# 컨텍스트 관련 import
from context.battle_store import store
from context.duration_store import duration_store




In [2]:
# GPU 가용성 확인
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
print(f"Available GPUs: {num_gpus}")

Available GPUs: 0


In [3]:
# 전역 변수 초기화
battle_store = store
duration_store = duration_store

In [None]:
# Ray 초기화
ray.init()

2025-05-14 10:48:18,794	ERROR services.py:1362 -- Failed to start the dashboard , return code 3221226505
2025-05-14 10:48:18,796	ERROR services.py:1387 -- Error should be written to 'dashboard.log' or 'dashboard.err'. We are printing the last 20 lines for you. See 'https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#logging-directory-structure' to find where the log file is.
2025-05-14 10:48:18,813	ERROR services.py:1431 -- 
The last 20 lines of C:\Users\wogn2\AppData\Local\Temp\ray\session_2025-05-14_10-48-16_912724_2644\logs\dashboard.log (it contains the error message from the dashboard): 
Traceback (most recent call last):
  File "C:\Users\wogn2\AppData\Roaming\Python\Python312\site-packages\ray\dashboard\dashboard.py", line 247, in <module>
    logging_utils.redirect_stdout_stderr_if_needed(
  File "C:\Users\wogn2\AppData\Roaming\Python\Python312\site-packages\ray\_private\logging_utils.py", line 48, in redirect_stdout_stderr_if_needed
    sys.stderr

In [None]:
# Rainbow DQN 설정
config = (
    DQNConfig()
    .environment(
        YakemonEnv,
        env_config={}
    )
    .training(
        # Rainbow DQN 핵심 기능들
        double_q=True,  # Double DQN
        dueling=True,   # Dueling DQN
        n_step=3,       # N-step learning
        num_atoms=51,   # Distributional DQN
        v_min=-10.0,    # Distributional DQN value range
        v_max=10.0,     # Distributional DQN value range
        noisy=True,     # Noisy Networks
        sigma0=0.5,     # Noisy Networks 초기 파라미터
        
        # Replay Buffer 설정
        replay_buffer_config={
            "type": "PrioritizedEpisodeReplayBuffer",
            "capacity": 100000,
            "alpha": 0.6,  # Prioritized Experience Replay
            "beta": 0.4,   # Importance Sampling
        },
        
        # 학습 관련 설정
        lr=0.00025,
        train_batch_size=32,
        gamma=0.99,
        target_network_update_freq=1000,
        num_steps_sampled_before_learning_starts=1000,
        td_error_loss_fn="huber",  # Huber loss for stability
    )
    .framework("torch")
    .rollouts(num_rollout_workers=0)
    .debugging(log_level="ERROR")
    .resources(num_gpus=num_gpus)
)

2025-05-14 09:06:35,422	ERROR services.py:1362 -- Failed to start the dashboard , return code 3221226505
2025-05-14 09:06:35,424	ERROR services.py:1387 -- Error should be written to 'dashboard.log' or 'dashboard.err'. We are printing the last 20 lines for you. See 'https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#logging-directory-structure' to find where the log file is.
2025-05-14 09:06:35,438	ERROR services.py:1397 -- Couldn't read dashboard.log file. Error: 'utf-8' codec can't decode byte 0xc7 in position 22: invalid continuation byte. It means the dashboard is broken even before it initializes the logger (mostly dependency issues). Reading the dashboard.err file which contains stdout/stderr.
2025-05-14 09:06:35,441	ERROR services.py:1431 -- 
The last 20 lines of C:\Users\wogn2\AppData\Local\Temp\ray\session_2025-05-14_09-06-33_457382_17996\logs\dashboard.err (it contains the error message from the dashboard): 
2025-05-14 09:06:35,619	INFO worker.p

In [5]:
from rainbow import test_agent
from rainbow import train_agent
from datetime import datetime
from rainbow import plot_training_results

In [None]:
# 메인 실행 코드
if __name__ == "__main__":
    # Jupyter에서 중첩된 이벤트 루프 허용
    nest_asyncio.apply()
    
    # 결과 저장 디렉토리 설정
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = os.path.join('results', f'training_{timestamp}')
    models_dir = os.path.join('models', f'training_{timestamp}')
    
    # 환경 초기화
    env = YakemonEnv()
    
    # Rainbow DQN 알고리즘 생성
    algo = config.build()
    
    print("Starting Rainbow DQN training...")
    print(f"Results will be saved in: {results_dir}")
    print(f"Models will be saved in: {models_dir}")
    print("\nConfiguration:")
    print(pretty_print(config.to_dict()))
    print("\n" + "="*50 + "\n")
    
    # Rainbow DQN 에이전트 학습
    rainbow_rewards = asyncio.run(train_agent(
        env=env,
        algo=algo,
        num_episodes=1000,
        save_path=models_dir,
        agent_name='rainbow'
    ))
    
    # 학습 결과 시각화
    plt.figure(figsize=(12, 6))
    plt.plot(rainbow_rewards, label='Average Reward', color='blue', alpha=0.6)
    plt.title('Rainbow DQN Training Rewards')
    plt.xlabel('Episode')
    plt.ylabel('Average Reward')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(results_dir, 'rainbow_rewards.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    print("\nTraining completed!")
    print(f"Results saved in: {results_dir}")
    print(f"Models saved in: {models_dir}")
    
    # 학습된 에이전트 테스트
    print("\nStarting test phase...")
    test_results = asyncio.run(test_agent(
        env=env,
        algo=algo,
        num_episodes=100
    ))
    
    # 테스트 결과 저장
    test_stats = {
        'avg_reward': test_results[0],
        'std_reward': test_results[1],
        'avg_steps': test_results[2],
        'victories': test_results[3],
        'win_rate': test_results[4]
    }
    
    with open(os.path.join(results_dir, 'test_results.json'), 'w') as f:
        json.dump(test_stats, f, indent=4)
    
    with open(os.path.join(results_dir, 'test_results.txt'), 'w') as f:
        f.write("Test Results\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"Average Reward: {test_stats['avg_reward']:.4f} ± {test_stats['std_reward']:.4f}\n")
        f.write(f"Average Steps: {test_stats['avg_steps']:.2f}\n")
        f.write(f"Victories: {test_stats['victories']}/100 (Win Rate: {test_stats['win_rate']:.1f}%)\n")
    
    print("\nTest completed!")
    print(f"Test results saved in: {results_dir}")
    
    # Ray 종료
    ray.shutdown()