In [1]:
#!pip install torch numpy matplotlib
#!pip install nest-asyncio
#!pip install "ray[rllib]"

In [2]:
# 필요한 라이브러리 임포트
import asyncio
import nest_asyncio
import os
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import json
import random
import torch
import logging

# 환경 관련 import
from env.battle_env import YakemonEnv

# 유틸리티 관련 import
from utils.battle_logics.create_battle_pokemon import create_battle_pokemon
from utils.visualization import plot_training_results

# RL 관련 import
from RL.reward_calculator import calculate_reward
from RL.get_state_vector import get_state

# 데이터 관련 import
from p_data.move_data import move_data
from p_data.ability_data import ability_data
from p_data.mock_pokemon import create_mock_pokemon_list

# 컨텍스트 관련 import
from context.battle_store import store
from context.duration_store import duration_store

# agent 관련 import
from agent.rainbow_agent import DQNAgent

In [3]:
# Global state / configuration.
# Expose the shared battle store imported from context.battle_store under the
# name `battle_store`. `duration_store` is used directly as imported from
# context.duration_store (no re-binding needed).
battle_store = store

# Hyperparameters for the Rainbow DQN agent and the train/test loops.
HYPERPARAMS = {
    "memory_size": 100000,
    "batch_size": 64,
    "target_update": 10,
    "gamma": 0.99,
    "alpha": 0.4,
    "beta": 0.4,
    "prior_eps": 1e-6,
    "v_min": -10.0,
    "v_max": 10.0,
    "atom_size": 51,
    "n_step": 3,
    "num_episodes": 1000,
    "save_interval": 100,
    "test_episodes": 100,
    "state_dim": 126,  # output dimension of get_state_vector
    "action_dim": 6,   # 4 moves + 2 switches
    "learning_rate": 0.0003,  # learning rate passed to DQNAgent
}

In [4]:
from rainbow import test_agent
from rainbow import train_agent
from datetime import datetime

In [None]:
def to_json_serializable(obj):
    """``json.dump`` fallback that converts NumPy values to built-in Python types.

    Args:
        obj: An object the standard JSON encoder could not serialize itself.

    Returns:
        The equivalent built-in value (``int``, ``float``, ``list``, ...) for a
        NumPy scalar or array.

    Raises:
        TypeError: For any other type, mirroring the stdlib encoder's behaviour.
    """
    if isinstance(obj, (np.generic, np.ndarray)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Main execution: train a Rainbow DQN agent, plot the training curves,
# evaluate the trained agent and save the test statistics.
if __name__ == "__main__":
    # Allow nested event loops so asyncio.run works inside Jupyter.
    nest_asyncio.apply()

    # Seed the global RNGs for reproducibility; the same seed is passed to the
    # agent below.
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Timestamped output directories for this run.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = os.path.join('results', f'training_{timestamp}')
    models_dir = os.path.join('models', f'training_{timestamp}')
    # Create them up front: the result files below are written with open(),
    # which does not create missing parent directories.
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(models_dir, exist_ok=True)

    # Environment.
    env = YakemonEnv()

    # Rainbow DQN agent.
    rainbow_agent = DQNAgent(
        env=env,
        memory_size=HYPERPARAMS["memory_size"],
        batch_size=HYPERPARAMS["batch_size"],
        target_update=HYPERPARAMS["target_update"],
        seed=SEED,
        gamma=HYPERPARAMS["gamma"],
        alpha=HYPERPARAMS["alpha"],
        beta=HYPERPARAMS["beta"],
        prior_eps=HYPERPARAMS["prior_eps"],
        v_min=HYPERPARAMS["v_min"],
        v_max=HYPERPARAMS["v_max"],
        atom_size=HYPERPARAMS["atom_size"],
        n_step=HYPERPARAMS["n_step"],
        learning_rate=HYPERPARAMS["learning_rate"],
    )

    print("Starting Rainbow DQN training...")
    print(f"Results will be saved in: {results_dir}")
    print(f"Models will be saved in: {models_dir}")
    print("\nHyperparameters:")
    for key, value in HYPERPARAMS.items():
        print(f"  {key}: {value}")
    print("\n" + "="*50 + "\n")

    # Train the agent (train_agent is a coroutine).
    rainbow_rewards, rainbow_losses = asyncio.run(train_agent(
        env=env,
        agent=rainbow_agent,
        num_episodes=HYPERPARAMS["num_episodes"],
        save_path=models_dir,
        agent_name='rainbow'
    ))

    # Plot the training curves.
    plot_training_results(
        rewards_history=rainbow_rewards,
        losses_history=rainbow_losses,
        agent_name='Rainbow DQN',
        save_path=results_dir
    )

    print("\nTraining completed!")
    print(f"Results saved in: {results_dir}")
    print(f"Models saved in: {models_dir}")

    # Evaluate the trained agent.
    print("\nStarting test phase...")
    test_results = asyncio.run(test_agent(
        env=env,
        agent=rainbow_agent,
        num_episodes=HYPERPARAMS["test_episodes"]
    ))

    # test_agent's result is consumed positionally in this order.
    test_stats = {
        'avg_reward': test_results[0],
        'std_reward': test_results[1],
        'avg_steps': test_results[2],
        'victories': test_results[3],
        'win_rate': test_results[4]
    }

    # Save the test statistics. utf-8 is explicit because the text report
    # contains a non-ASCII '±'; the JSON default handles NumPy scalars.
    with open(os.path.join(results_dir, 'test_results.json'), 'w', encoding='utf-8') as f:
        json.dump(test_stats, f, indent=4, default=to_json_serializable)

    with open(os.path.join(results_dir, 'test_results.txt'), 'w', encoding='utf-8') as f:
        f.write("Test Results\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"Average Reward: {test_stats['avg_reward']:.4f} ± {test_stats['std_reward']:.4f}\n")
        f.write(f"Average Steps: {test_stats['avg_steps']:.2f}\n")
        f.write(f"Victories: {test_stats['victories']}/{HYPERPARAMS['test_episodes']} (Win Rate: {test_stats['win_rate']:.1f}%)\n")

    print("\nTest completed!")
    print(f"Test results saved in: {results_dir}")