# N-Step Actor-Critic for LunarLander-v2 (PyTorch & Gymnasium) - Granular Cells
이 노트북은 Gymnasium의 LunarLander-v2 환경에서 N-Step Actor-Critic 알고리즘을 사용하여 강화학습 에이전트를 훈련합니다. 각 단계를 세분화된 셀로 나누어 Colab과 같은 환경에서 단계별 실행 및 이해를 돕도록 구성했습니다.

## 1. Setup: Installs and Imports
- 필요한 라이브러리를 설치하고 임포트합니다.
- Colab 환경에서는 `gymnasium[box2d]`와 `opencv-python-headless` 설치가 필요할 수 있습니다.

In [None]:
# Colab 또는 유사 환경에서 실행 시 필요한 라이브러리 설치 (주석 해제 후 실행)
# !pip install gymnasium[box2d] opencv-python-headless matplotlib seaborn -q

In [None]:
# --- 라이브러리 임포트 ---
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import cv2  # OpenCV for video writing
from datetime import datetime
import os
import logging
from collections import deque
import time # 시간 측정용

## 2. Logging Configuration
- 훈련 과정을 기록하기 위한 로깅 시스템을 설정합니다.
- 로그 파일과 비디오, 플롯을 저장할 고유한 디렉토리를 생성합니다.

In [None]:
# --- 로깅 설정 함수 정의 ---
def setup_logging():
    """Create a timestamped run directory and configure this module's logger.

    The logger writes INFO-and-above records both to ``training.log`` inside
    the new directory and to the console.

    Calling this function again (e.g. re-running the notebook cell) replaces
    the previously attached handlers instead of stacking new ones, so each
    record is emitted once and the old log file handle is closed.

    Returns:
        tuple: ``(logger, log_dir)`` - the configured logger and the path of
        the directory created for this run.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_dir = f"lunar_lander_a2c_nstep_{timestamp}"
    os.makedirs(log_dir, exist_ok=True)
    print(f"Log directory created: {log_dir}")

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # logging.getLogger returns the same object on every call, so handlers
    # from an earlier invocation are still attached. Drop and close them to
    # avoid duplicated output and leaked file descriptors.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # File handler: persists the log next to the run's other artefacts.
    file_handler = logging.FileHandler(os.path.join(log_dir, "training.log"))
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Stream handler: mirrors the log to the console / notebook output.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    # Keep matplotlib's own logger quiet below WARNING.
    logging.getLogger('matplotlib').setLevel(logging.WARNING)

    return logger, log_dir

In [None]:
# --- Initialise the logger and the output locations for this run ---
logger, log_base_dir = setup_logging()

# Episode recordings live in a "videos" sub-folder of the run directory;
# the periodically refreshed reward curve sits directly in the run directory.
video_dir = os.path.join(log_base_dir, "videos")
plot_save_path = os.path.join(log_base_dir, "rewards_history.png")
os.makedirs(video_dir, exist_ok=True)

logger.info(f"Video save directory: {video_dir}")
logger.info(f"Plot save path: {plot_save_path}")

## 3. Network Definitions
- Actor-Critic 알고리즘에 사용될 신경망 모델을 정의합니다.

### 3.1 Actor Network
- 상태(State)를 입력받아 각 행동(Action)의 확률 분포를 출력합니다. (Policy Network)

In [None]:
class Actor(nn.Module):
    """Policy network: maps a state vector to a probability for each action.

    The model is a stack of two ReLU-activated hidden layers of width
    ``hidden_dim`` followed by a linear layer with ``action_dim`` outputs and
    a softmax over the last dimension.
    """
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1),  # normalise the outputs into action probabilities
        ]
        self.network = nn.Sequential(*layers)
        logger.debug("Actor network instance created.")

    def forward(self, state):
        """Return the action probabilities produced by the network for ``state``."""
        action_probs = self.network(state)
        return action_probs

### 3.2 Critic Network
- 상태(State)를 입력받아 해당 상태의 가치(Value)를 추정합니다. (Value Network)

In [None]:
class Critic(nn.Module):
    """Value network: maps a state vector to a single scalar value estimate.

    The model is a stack of two ReLU-activated hidden layers of width
    ``hidden_dim`` followed by a linear layer with one output.
    """
    def __init__(self, state_dim, hidden_dim=256):
        super().__init__()
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # one output: the estimated state value
        ]
        self.network = nn.Sequential(*layers)
        logger.debug("Critic network instance created.")

    def forward(self, state):
        """Return the value estimate produced by the network for ``state``."""
        state_value = self.network(state)
        return state_value

## 4. N-Step Actor-Critic Agent
- 위에서 정의한 Actor, Critic 네트워크를 사용하여 N-Step Actor-Critic 에이전트를 구현합니다.
- 이 클래스는 에이전트의 초기화, 행동 선택, 경험 저장, N-Step 리턴 계산, 네트워크 업데이트 로직을 포함합니다.

In [None]:
class NStepActorCritic:
    """N-step Actor-Critic agent with separate actor and critic networks.

    Transitions are kept in fixed-length deques holding the most recent
    ``n_steps`` entries. ``update()`` computes bootstrapped n-step returns
    over that window, fits the critic to them with an MSE loss, and applies a
    policy-gradient step to the actor using the (detached) advantages.
    """
    def __init__(self, state_dim, action_dim, lr_actor=3e-4, lr_critic=3e-4, gamma=0.99, n_steps=5):
        """Build the networks, optimizers and transition buffers.

        Args:
            state_dim: Size of the state vector fed to both networks.
            action_dim: Number of discrete actions the actor chooses from.
            lr_actor: Adam learning rate for the actor.
            lr_critic: Adam learning rate for the critic.
            gamma: Discount factor used in the n-step return.
            n_steps: Capacity of the transition buffers and the minimum number
                of stored transitions required before ``update()`` learns.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Agent initializing on device: {self.device}")

        # Networks
        self.actor = Actor(state_dim, action_dim).to(self.device)
        self.critic = Critic(state_dim).to(self.device)

        # Optimizers
        self.optimizer_actor = optim.Adam(self.actor.parameters(), lr=lr_actor)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=lr_critic)

        # Parameters
        self.gamma = gamma
        self.n_steps = n_steps
        self.step_losses = []  # critic + actor loss of every performed update

        # Fixed-length buffers: appending to a full deque drops its oldest entry.
        self.state_buffer = deque(maxlen=n_steps)
        self.action_buffer = deque(maxlen=n_steps)
        self.reward_buffer = deque(maxlen=n_steps)
        self.next_state_buffer = deque(maxlen=n_steps)
        self.done_buffer = deque(maxlen=n_steps)
        # Probabilities of the sampled actions at sampling time (detached;
        # kept as a record only - they carry no gradient).
        self.prob_buffer = deque(maxlen=n_steps)

        logger.info(f"NStepActorCritic initialized: n_steps={n_steps}, lr_actor={lr_actor}, lr_critic={lr_critic}, gamma={gamma}")

    def select_action(self, state):
        """Sample an action from the current policy.

        Args:
            state: A single state in a form accepted by ``torch.FloatTensor``.

        Returns:
            tuple: ``(action, action_prob)`` - the sampled action index as an
            int and a 0-dim tensor with the probability the policy assigned
            to it. The tensor is computed under ``no_grad``.
        """
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        self.actor.eval()
        with torch.no_grad():
            probs = self.actor(state_tensor).squeeze(0)
        self.actor.train()

        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        action_prob = probs[action]
        return action.item(), action_prob

    def store_transition(self, state, action, reward, next_state, done, action_prob):
        """Append one transition to the buffers (oldest entry is dropped when full)."""
        self.state_buffer.append(state)
        self.action_buffer.append(action)
        self.reward_buffer.append(reward)
        self.next_state_buffer.append(next_state)
        self.done_buffer.append(done)
        # Stored detached and on the CPU so the buffer never holds a graph.
        self.prob_buffer.append(action_prob.detach().cpu())

    def compute_n_step_returns(self):
        """Compute discounted returns for every buffered transition.

        The return of the newest transition bootstraps from the critic's value
        of its next state, unless that transition ended the episode. Walking
        backwards, a ``done`` flag resets the running return so rewards never
        leak across an episode boundary inside the buffer.

        Returns:
            list: One return per buffered transition, oldest first; empty when
            the buffer is empty.
        """
        if not self.reward_buffer:
            return []

        if self.done_buffer[-1]:
            R = 0.0
        else:
            self.critic.eval()
            with torch.no_grad():
                last_state_tensor = torch.FloatTensor(self.next_state_buffer[-1]).unsqueeze(0).to(self.device)
                R = self.critic(last_state_tensor).item()
            self.critic.train()

        returns = []
        for reward, done in zip(reversed(self.reward_buffer), reversed(self.done_buffer)):
            # An explicit branch instead of multiplying by (1 - done): no
            # arithmetic on boolean flags, and a terminal step ignores R
            # entirely whatever its value.
            R = reward if done else reward + self.gamma * R
            returns.append(R)
        returns.reverse()  # built newest-first; callers expect oldest-first
        return returns

    def update(self):
        """Run one learning step for the critic and the actor.

        Returns:
            tuple: ``(critic_loss, actor_loss)`` as floats, or ``(0.0, 0.0)``
            when fewer than ``n_steps`` transitions are buffered and no update
            was performed.
        """
        if len(self.state_buffer) < self.n_steps:
            return 0.0, 0.0

        n_step_returns = self.compute_n_step_returns()
        if not n_step_returns:
            return 0.0, 0.0

        returns_tensor = torch.FloatTensor(n_step_returns).unsqueeze(1).to(self.device)
        states_tensor = torch.FloatTensor(np.array(self.state_buffer)).to(self.device)
        actions_tensor = torch.tensor(
            list(self.action_buffer), dtype=torch.long, device=self.device
        ).unsqueeze(1)

        # --- Critic update ---
        current_values = self.critic(states_tensor)
        advantages = returns_tensor - current_values  # n-step return minus V(s_t)
        critic_loss = advantages.pow(2).mean()

        self.optimizer_critic.zero_grad()
        critic_loss.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
        self.optimizer_critic.step()

        # --- Actor update ---
        # The probabilities saved at sampling time are detached, so a loss
        # built from them has no graph and backward() raises. Re-evaluate the
        # policy on the buffered states so gradients reach the actor.
        action_probs = self.actor(states_tensor).gather(1, actions_tensor)
        log_probs = torch.log(action_probs + 1e-6)  # epsilon guards against log(0)
        # Advantages are detached so the actor loss does not backpropagate
        # into the critic.
        actor_loss = -(log_probs * advantages.detach()).mean()

        self.optimizer_actor.zero_grad()
        actor_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
        self.optimizer_actor.step()

        total_step_loss = critic_loss.item() + actor_loss.item()
        self.step_losses.append(total_step_loss)

        # The buffers are not cleared here: each later append drops only the
        # oldest transition, so consecutive updates see overlapping windows.
        return critic_loss.item(), actor_loss.item()

## 5. Helper Function: Video Saving
- 에피소드의 프레임들을 받아 MP4 비디오 파일로 저장합니다. OpenCV 라이브러리를 사용합니다.

In [None]:
def save_episode_video(frames, episode, video_dir):
    """Write a sequence of RGB frames to ``episode_XXXX.mp4`` inside ``video_dir``.

    Failures are logged rather than raised so that a video problem never
    interrupts the caller.

    Args:
        frames: List of frames of shape (height, width, 3); NumPy arrays, or
            torch tensors (converted when the first frame is a tensor).
        episode: Episode index used in the file name and in log messages.
        video_dir: Existing directory in which the file is created.
    """
    if not frames:
        logger.warning(f"Attempted to save video for episode {episode}, but no frames were provided.")
        return

    try:
        # OpenCV needs NumPy arrays.
        if isinstance(frames[0], torch.Tensor):
            frames = [f.cpu().numpy() for f in frames]

        height, width, channels = frames[0].shape
        if channels != 3:
            logger.error(f"Invalid frame shape for video saving: {frames[0].shape}. Expected 3 channels.")
            return

        video_path = os.path.join(video_dir, f"episode_{episode:04d}.mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = 30.0

        # VideoWriter takes the frame size as (width, height).
        video = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
        if not video.isOpened():
            logger.error(f"Failed to open video writer for path: {video_path}")
            return

        logger.debug(f"Saving video: {video_path}, Resolution: {width}x{height}, FPS: {fps}")
        try:
            for frame in frames:
                # Frames are treated as RGB; OpenCV writes BGR.
                video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        finally:
            # Release even when a frame fails to convert or write, so the
            # writer and its output file are never left open.
            video.release()

        logger.debug(f"Video saved successfully for episode {episode}.")

    except Exception as e:
        logger.error(f"Error during video saving for episode {episode}: {e}", exc_info=True)

## 6. Main Training Loop Function
- 환경 설정, 에이전트 생성, 훈련 루프 실행 로직을 포함합니다.
- 각 에피소드에서 환경과 상호작용하며 데이터를 수집하고 에이전트를 업데이트합니다.
- 주기적으로 로그 출력, 비디오 저장, 보상 그래프 저장을 수행합니다.

In [None]:
def _interval_due(episode, interval):
    """Return True when a task with period ``interval`` is due at ``episode``; 0, None or a negative interval disables it."""
    return bool(interval) and interval > 0 and episode % interval == 0


def _save_reward_plot(rewards_history, avg_rewards_history, window_size, env_name, save_path):
    """Save the per-episode and moving-average reward curves to ``save_path``, logging instead of raising on failure."""
    fig = None
    try:
        fig = plt.figure(figsize=(12, 6))
        plt.plot(rewards_history, label='Episode Reward', alpha=0.5, color='lightblue')
        plt.plot(avg_rewards_history, label=f'Avg Reward (Last {window_size} ep)', linewidth=2, color='blue')

        plt.title(f"Training Progress - {env_name} (N-Step A2C)")
        plt.xlabel("Episode")
        plt.ylabel("Total Reward", color='blue')
        plt.tick_params(axis='y', labelcolor='blue')
        plt.legend(loc='upper left')
        plt.grid(True)
        plt.savefig(save_path)
        logger.debug(f"Reward history plot updated and saved to {save_path}")
    except Exception as plot_err:
        logger.error(f"Error saving plot: {plot_err}", exc_info=True)
    finally:
        # Close the figure on every path so failed saves do not accumulate figures.
        if fig is not None:
            plt.close(fig)


def train(env_name="LunarLander-v2", num_episodes=2000, n_steps=5, lr_actor=3e-4, lr_critic=3e-4, gamma=0.99, save_video_interval=100, log_interval=10, plot_interval=50, random_seed=42):
    """Train an N-step Actor-Critic agent and return the collected histories.

    Args:
        env_name: Gymnasium environment id passed to ``gym.make``.
        num_episodes: Number of episodes to run.
        n_steps: N-step window handed to the agent.
        lr_actor: Actor learning rate.
        lr_critic: Critic learning rate.
        gamma: Discount factor.
        save_video_interval: Record a video every this many episodes.
        log_interval: Log progress every this many episodes.
        plot_interval: Refresh the reward plot every this many episodes.
        random_seed: Seed for torch and NumPy; episode ``k`` resets the
            environment with ``random_seed + k``.

        An interval of 0 or None disables the corresponding periodic task.
        Progress is still logged, and the plot still saved, for the final
        episode.

    Returns:
        tuple: ``(rewards_history, avg_rewards_history, avg_losses_history)``,
        one entry per episode, or ``(None, None, None)`` when the environment
        could not be created.
    """
    logger.info(f"--- Starting Training ---")
    logger.info(f"Environment: {env_name}, Episodes: {num_episodes}, N-Steps: {n_steps}")
    logger.info(f"LR Actor: {lr_actor}, LR Critic: {lr_critic}, Gamma: {gamma}")
    logger.info(f"Save Video Interval: {save_video_interval}, Log Interval: {log_interval}, Plot Interval: {plot_interval}")
    logger.info(f"Random Seed: {random_seed}")

    try:
        env = gym.make(env_name, render_mode="rgb_array")
    except Exception as e:
        logger.error(f"Failed to create environment '{env_name}': {e}", exc_info=True)
        return None, None, None

    # Seed the global RNGs; the environment is seeded per episode in reset().
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)

    rewards_history = []
    avg_rewards_history = []
    avg_losses_history = []
    window_size = 100  # episodes in the moving-average reward

    total_training_start_time = time.time()

    # try/finally guarantees the environment is closed even when an episode raises.
    try:
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n
        logger.info(f"State Dim: {state_dim}, Action Dim: {action_dim}")

        agent = NStepActorCritic(state_dim, action_dim, n_steps=n_steps, lr_actor=lr_actor, lr_critic=lr_critic, gamma=gamma)

        for episode in range(num_episodes):
            state, info = env.reset(seed=random_seed + episode)
            episode_reward = 0
            frames = []
            steps_in_episode = 0
            done = False
            episode_losses = []
            record_video = _interval_due(episode, save_video_interval)

            episode_start_time = time.time()

            while not done:
                if record_video:
                    try:
                        frame = env.render()
                        if frame is not None:
                            frames.append(frame)
                    except Exception as render_err:
                        logger.error(f"Rendering error in ep {episode}: {render_err}", exc_info=False)
                        # Give up on the video only: a `break` here would
                        # leave the episode loop and cut the episode short.
                        frames = []
                        record_video = False

                action, action_prob = agent.select_action(state)

                next_state, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated

                agent.store_transition(state, action, reward, next_state, done, action_prob)

                buffer_ready = len(agent.state_buffer) >= n_steps
                if buffer_ready or done:
                    c_loss, a_loss = agent.update()
                    # update() returns (0.0, 0.0) while the buffer is short;
                    # record a loss only when an update actually ran.
                    if buffer_ready:
                        episode_losses.append(c_loss + a_loss)

                state = next_state
                episode_reward += reward
                steps_in_episode += 1

            # --- End of episode bookkeeping ---
            rewards_history.append(episode_reward)
            avg_reward = np.mean(rewards_history[-window_size:]) if rewards_history else 0
            avg_rewards_history.append(avg_reward)

            avg_episode_loss = np.mean(episode_losses) if episode_losses else 0
            avg_losses_history.append(avg_episode_loss)

            episode_duration = time.time() - episode_start_time
            is_last_episode = episode == num_episodes - 1

            if _interval_due(episode, log_interval) or is_last_episode:
                logger.info(f"Ep {episode:04d}/{num_episodes} | "
                            f"Steps: {steps_in_episode:3d} | "
                            f"Return: {episode_reward:7.2f} | "
                            f"Avg Return ({window_size}ep): {avg_reward:7.2f} | "
                            f"Avg Ep Loss: {avg_episode_loss:7.4f} | "
                            f"Duration: {episode_duration:5.2f}s")

            if record_video and frames:
                save_episode_video(frames, episode, video_dir)

            if _interval_due(episode, plot_interval) or is_last_episode:
                _save_reward_plot(rewards_history, avg_rewards_history, window_size, env_name, plot_save_path)
    finally:
        env.close()

    total_duration_seconds = time.time() - total_training_start_time
    logger.info(f"--- Training completed in {total_duration_seconds:.2f} seconds ---")

    return rewards_history, avg_rewards_history, avg_losses_history

## 7. Training Execution
- 이 섹션에서는 실제 훈련을 시작합니다.

### 7.1 Set Training Hyperparameters
- 훈련에 사용할 하이퍼파라미터를 설정합니다. 이 값을 변경하여 훈련 성능을 조절할 수 있습니다.

In [None]:
# --- Training Hyperparameters ---
# NOTE(review): recent Gymnasium releases may register this environment under
# a newer version id (e.g. LunarLander-v3) - confirm against the installed version.
ENV_NAME = "LunarLander-v2"
NUM_EPISODES = 1500       # Total number of training episodes
N_STEPS = 8              # Number of steps used for the N-step return
LR_ACTOR = 5e-4          # Actor learning rate
LR_CRITIC = 5e-4         # Critic learning rate
GAMMA = 0.99             # Discount factor
SAVE_VIDEO_INTERVAL = 200 # Video saving period (in episodes)
LOG_INTERVAL = 10        # Log output period (in episodes)
PLOT_INTERVAL = 50       # Plot saving period (in episodes)
RANDOM_SEED = 42         # Random seed for reproducibility

### 7.2 Run Training
- 설정된 하이퍼파라미터로 `train` 함수를 호출하여 훈련을 시작하고 결과를 저장합니다.

In [None]:
# --- Start Training ---
logger.info("="*20 + " Initiating Training Run " + "="*20)
training_results = train(
    env_name=ENV_NAME,
    num_episodes=NUM_EPISODES,
    n_steps=N_STEPS,
    lr_actor=LR_ACTOR,
    lr_critic=LR_CRITIC,
    gamma=GAMMA,
    save_video_interval=SAVE_VIDEO_INTERVAL,
    log_interval=LOG_INTERVAL,
    plot_interval=PLOT_INTERVAL,
    random_seed=RANDOM_SEED
)

# train() reports failure as (None, None, None). A non-empty tuple is truthy
# whatever it contains, so the contents must be inspected; a bare truthiness
# test would report success with None histories.
training_succeeded = training_results is not None and all(
    history is not None for history in training_results
)

if training_succeeded:
    rewards_hist, avg_rewards_hist, avg_losses_hist = training_results
    logger.info("Training function executed successfully.")
else:
    logger.error("Training function failed to execute properly.")
    rewards_hist, avg_rewards_hist, avg_losses_hist = None, None, None

logger.info("="*20 + " Training Run Finished " + "="*20)

## 8. Results Visualization
- 훈련 과정에서 기록된 보상 및 손실 데이터를 사용하여 최종 그래프를 시각화합니다.

In [None]:
# --- Plot Final Results ---
# Draws the reward curves and the per-episode average loss side by side,
# saves the figure into the run directory, then displays it.
if rewards_hist and avg_rewards_hist and avg_losses_hist:
    logger.info("Displaying final training results plots...")
    fig = plt.figure(figsize=(18, 7))

    # Left panel: raw and smoothed episode rewards.
    plt.subplot(1, 2, 1)
    plt.plot(rewards_hist, label='Episode Reward', alpha=0.6, color='lightblue')
    plt.plot(avg_rewards_hist, label='Avg Reward (Smoothed)', linewidth=2.5, color='blue')
    plt.title("Episode Rewards over Time", fontsize=14)
    plt.xlabel("Episode", fontsize=12)
    plt.ylabel("Total Reward", fontsize=12)
    plt.legend(fontsize=10)
    plt.grid(True, linestyle='--', alpha=0.6)

    # Right panel: average loss per episode.
    plt.subplot(1, 2, 2)
    plt.plot(avg_losses_hist, label='Average Episode Loss', color='darkorange', linewidth=2)
    plt.title("Average Training Loss per Episode", fontsize=14)
    plt.xlabel("Episode", fontsize=12)
    plt.ylabel("Average Loss", fontsize=12)
    plt.legend(fontsize=10)
    plt.grid(True, linestyle='--', alpha=0.6)

    plt.suptitle(f"N-Step Actor-Critic Training Results ({ENV_NAME})", fontsize=16, y=1.02)
    plt.tight_layout()

    # Save BEFORE showing: once plt.show() has run, the figure may already be
    # closed or cleared, and saving afterwards can produce a blank image.
    # bbox_inches="tight" keeps the suptitle, which sits above the figure
    # area (y=1.02), inside the saved file.
    try:
        final_plot_path = os.path.join(log_base_dir, "final_training_summary.png")
        fig.savefig(final_plot_path, bbox_inches="tight")
        logger.info(f"Final summary plot saved to {final_plot_path}")
    except Exception as e:
        logger.error(f"Failed to save final plot: {e}", exc_info=True)

    plt.show()  # Display the plot in the notebook
    logger.info("Final plots displayed.")
    plt.close(fig)

else:
    logger.warning("No training results available to plot.")