In [1]:
!pip install pygame



In [2]:
!pip install gymnasium  # 安装 gymnasium 和 Atari 环境
!pip install flappy-bird-gymnasium  # 安装 Flappy Bird 环境



In [1]:
import gymnasium as gym
import flappy_bird_gymnasium  
import numpy as np
import random
import torch
from torch import nn
import torch.nn.functional as F
import yaml
import os
import pygame
from experience_replay import ReplayMemory
from dqn import DQN
from datetime import datetime, timedelta
import matplotlib
import matplotlib.pyplot as plt
from itertools import count

# For printing date and time
DATE_FORMAT = "%m-%d %H:%M:%S"

# Directory for saving run info
RUNS_DIR = "runs"
os.makedirs(RUNS_DIR, exist_ok=True)

# 'Agg': used to generate plots as images and save them to a file instead of rendering to screen
matplotlib.use('Agg')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'  # Force CPU (optional depending on your system)

# Deep Q-Learning Agent
class Agent():

    def __init__(self, hyperparameter_set):
        with open('hyperparameters.yml', 'r') as file:
            all_hyperparameter_sets = yaml.safe_load(file)
            hyperparameters = all_hyperparameter_sets[hyperparameter_set]

        self.hyperparameter_set = hyperparameter_set

        # Hyperparameters (adjustable)
        self.env_id = hyperparameters['env_id']
        self.learning_rate_a = hyperparameters['learning_rate_a']        # learning rate (alpha)
        self.discount_factor_g = hyperparameters['discount_factor_g']      # discount rate (gamma)
        self.network_sync_rate = hyperparameters['network_sync_rate']      # number of steps the agent takes before syncing the policy and target network
        self.replay_memory_size = hyperparameters['replay_memory_size']     # size of replay memory
        self.mini_batch_size = hyperparameters['mini_batch_size']        # size of the training data set sampled from the replay memory
        self.epsilon_init = hyperparameters['epsilon_init']           # 1 = 100% random actions
        self.epsilon_decay = hyperparameters['epsilon_decay']          # epsilon decay rate
        self.epsilon_min = hyperparameters['epsilon_min']            # minimum epsilon value
        self.stop_on_reward = hyperparameters['stop_on_reward']         # stop training after reaching this number of rewards
        self.fc1_nodes = hyperparameters['fc1_nodes']
        self.env_make_params = hyperparameters.get('env_make_params', {}) # Get optional environment-specific parameters, default to empty dict
        self.enable_double_dqn = hyperparameters['enable_double_dqn']      # double dqn on/off flag
        self.enable_dueling_dqn = hyperparameters['enable_dueling_dqn']     # dueling dqn on/off flag

        # Neural Network
        self.loss_fn = nn.MSELoss()          # NN Loss function. MSE=Mean Squared Error can be swapped to something else.
        self.optimizer = None                # NN Optimizer. Initialize later.

        # Path to Run info
        self.LOG_FILE = os.path.join(RUNS_DIR, f'{self.hyperparameter_set}.log')
        self.MODEL_FILE = os.path.join(RUNS_DIR, f'{self.hyperparameter_set}.pt')
        self.GRAPH_FILE = os.path.join(RUNS_DIR, f'{self.hyperparameter_set}.png')

        # Initialize pygame mixer for sound effects
        pygame.mixer.init()

        # Load sound effects
        self.wing_sound = pygame.mixer.Sound('/Users/liuyanru/Desktop/CDS 524/Assignment/flappybird/audio/wing.wav')
        self.swoosh_sound = pygame.mixer.Sound('/Users/liuyanru/Desktop/CDS 524/Assignment/flappybird/audio/swoosh.wav')
        self.point_sound = pygame.mixer.Sound('/Users/liuyanru/Desktop/CDS 524/Assignment/flappybird/audio/point.wav')
        self.hit_sound = pygame.mixer.Sound('/Users/liuyanru/Desktop/CDS 524/Assignment/flappybird/audio/hit.wav')
        self.die_sound = pygame.mixer.Sound('/Users/liuyanru/Desktop/CDS 524/Assignment/flappybird/audio/die.wav')

    def play_wing_sound(self):
        self.wing_sound.play()

    def play_swoosh_sound(self):
        self.swoosh_sound.play()

    def play_point_sound(self):
        self.point_sound.play()

    def play_hit_sound(self):
        self.hit_sound.play()

    def play_die_sound(self):
        self.die_sound.play()

    def run(self, is_training=True, render=False):
        if is_training:
            # Initialize variables for training logging
            training_rewards = []
            start_time = datetime.now()
            last_graph_update_time = start_time

            log_message = f"{start_time.strftime(DATE_FORMAT)}: Training starting..."
            print(log_message)
            with open(self.LOG_FILE, 'w') as file:
                file.write(log_message + '\n')

        # Create instance of the environment.
        env = gym.make(self.env_id, render_mode='human' if render else None, **self.env_make_params)

        # Number of possible actions
        num_actions = env.action_space.n

        # Get observation space size
        num_states = env.observation_space.shape[0]  # Expecting type: Box(low, high, (shape0,), float64)

        # List to keep track of rewards collected per episode.
        rewards_per_episode = []

        # Create policy and target network.
        policy_dqn = DQN(num_states, num_actions, self.fc1_nodes, self.enable_dueling_dqn).to(device)

        # Load the pre-trained model if available
        if os.path.exists(self.MODEL_FILE):
            print("Loading pre-trained model...")
            policy_dqn.load_state_dict(torch.load(self.MODEL_FILE))
            policy_dqn.eval()  # Set to evaluation mode

        if is_training:
            # Initialize epsilon
            epsilon = self.epsilon_init

            # Initialize replay memory
            memory = ReplayMemory(self.replay_memory_size)

            # Create the target network and make it identical to the policy network
            target_dqn = DQN(num_states, num_actions, self.fc1_nodes, self.enable_dueling_dqn).to(device)
            target_dqn.load_state_dict(policy_dqn.state_dict())

            # Policy network optimizer.
            self.optimizer = torch.optim.Adam(policy_dqn.parameters(), lr=self.learning_rate_a)

            # List to keep track of epsilon decay
            epsilon_history = []

            # Track number of steps taken.
            step_count = 0

            # Track best reward
            best_reward = -9999999
        else:
            # Load learned policy
            policy_dqn.load_state_dict(torch.load(self.MODEL_FILE))

            # switch model to evaluation mode
            policy_dqn.eval()

        # Train indefinitely, manually stop the run when you are satisfied
        for episode in count():

            state, _ = env.reset()  # Initialize environment. Reset returns (state,info).
            state = torch.tensor(state, dtype=torch.float, device=device)  # Convert state to tensor directly on device

            terminated = False      # True when agent reaches goal or fails
            episode_reward = 0.0    # Used to accumulate rewards per episode

            while not terminated and episode_reward < self.stop_on_reward:

                # Select action based on epsilon-greedy
                if is_training and random.random() < epsilon:
                    action = env.action_space.sample()
                    action = torch.tensor(action, dtype=torch.int64, device=device)
                else:
                    with torch.no_grad():
                        action = policy_dqn(state.unsqueeze(dim=0)).squeeze().argmax()

                # Execute action
                new_state, reward, terminated, truncated, info = env.step(action.item())

                # Accumulate rewards
                episode_reward += reward

                # Convert new state and reward to tensors
                new_state = torch.tensor(new_state, dtype=torch.float, device=device)
                reward = torch.tensor(reward, dtype=torch.float, device=device)

                # Play sound effects based on game events
                if action.item() == 1:  # 假设动作1是跳跃
                    self.play_wing_sound()

                if reward > 0:  # 得分时播放音效
                    self.play_point_sound()
                    self.play_swoosh_sound()  # Add swoosh sound for scoring

                if terminated:  # 游戏结束时播放死亡音效
                    self.play_die_sound()
                elif reward == -1:  # 碰撞时播放撞击音效
                    self.play_hit_sound()

                if render:
                    self.show_frame(new_state)  # Show frame using pygame

                if is_training:
                    # Save experience into memory
                    memory.append((state, action, new_state, reward, terminated))

                    step_count += 1

                # Move to next state
                state = new_state

            rewards_per_episode.append(episode_reward)

            if is_training:
                if episode_reward > best_reward:
                    log_message = f"{datetime.now().strftime(DATE_FORMAT)}: New best reward {episode_reward:0.1f} ({(episode_reward - best_reward) / best_reward * 100:+.1f}%) at episode {episode}, saving model..."
                    print(log_message)
                    with open(self.LOG_FILE, 'a') as file:
                        file.write(log_message + '\n')

                    torch.save(policy_dqn.state_dict(), self.MODEL_FILE)
                    best_reward = episode_reward

                if len(memory) > self.mini_batch_size:
                    mini_batch = memory.sample(self.mini_batch_size)
                    self.optimize(mini_batch, policy_dqn, target_dqn)

                    epsilon = max(epsilon * self.epsilon_decay, self.epsilon_min)
                    epsilon_history.append(epsilon)

                    if step_count > self.network_sync_rate:
                        target_dqn.load_state_dict(policy_dqn.state_dict())
                        step_count = 0

        # After training, render the final game state
        if not is_training:
            print("Training finished, displaying final game...")
            state, _ = env.reset()
            state = torch.tensor(state, dtype=torch.float, device=device)
            done = False
            while not done:
                with torch.no_grad():
                    action = policy_dqn(state.unsqueeze(0)).squeeze().argmax()
                new_state, reward, done, truncated, info = env.step(action.item())
                self.show_frame(new_state)
                state = torch.tensor(new_state, dtype=torch.float, device=device)
            print("Game over!")

    def show_frame(self, frame):
        """
        Display the current game frame using pygame.
        If the frame is not an image, print its values instead.
        """
        if isinstance(frame, torch.Tensor):
            frame = frame.cpu().numpy()  # Convert to numpy if it's a tensor

        if len(frame.shape) == 3 and frame.shape[2] == 3:
            # Normal image (height, width, 3)
            height, width, _ = frame.shape
        elif len(frame.shape) == 2:
            # Grayscale image (height, width)
            height, width = frame.shape
            frame = np.stack([frame] * 3, axis=-1)  # Convert grayscale to 3-channel by repeating the grayscale values
        else:
            # If it's not an image, just print the frame contents
            print("State vector:", frame)
            return

        # Ensure pygame is initialized
        if not pygame.get_init():
            pygame.init()

        # Create display window
        screen = pygame.display.set_mode((width, height))  # Set mode using width and height
        pygame.display.set_caption("Flappy Bird")

        frame = np.swapaxes(frame, 0, 1)  # Convert from (height, width, channels) to (width, height, channels)
        frame = pygame.surfarray.make_surface(frame)  # Create surface for pygame

        screen.blit(frame, (0, 0))  # Blit the frame to the screen
        pygame.display.update()  # Update the display
        pygame.time.delay(10)  # Delay to make the game viewable and slow down

    def optimize(self, mini_batch, policy_dqn, target_dqn):
        states, actions, new_states, rewards, terminations = zip(*mini_batch)

        states = torch.stack(states)
        actions = torch.stack(actions)
        new_states = torch.stack(new_states)
        rewards = torch.stack(rewards)
        terminations = torch.tensor(terminations).float().to(device)

        # Ensure the shapes match for loss calculation
        current_q_values = policy_dqn(states).gather(1, actions.unsqueeze(1))  # Shape: (batch_size, 1)
        
        with torch.no_grad():
            if self.enable_double_dqn:
                best_actions_from_policy = policy_dqn(new_states).argmax(dim=1)
                target_q_values = target_dqn(new_states).gather(1, best_actions_from_policy.unsqueeze(1))  # Shape: (batch_size, 1)
            else:
                target_q_values = target_dqn(new_states).max(dim=1)[0].unsqueeze(1)  # Shape: (batch_size, 1)

        # Compute expected Q-values
        expected_q_values = rewards + (self.discount_factor_g * target_q_values * (1 - terminations))

        # Compute loss
        loss = self.loss_fn(current_q_values, expected_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss


# To run the agent and start training:
hyperparameters = 'flappybird1'  # Replace with your actual hyperparameter set name
agent = Agent(hyperparameter_set=hyperparameters)
# agent.run(is_training=True, render=False)  # Train without rendering

# After training is done, show the final game
agent.run(is_training=False, render=True)


Loading pre-trained model...


  policy_dqn.load_state_dict(torch.load(self.MODEL_FILE))
  policy_dqn.load_state_dict(torch.load(self.MODEL_FILE))
2025-02-25 18:07:27.111 python[49600:6445665] +[IMKClient subclass]: chose IMKClient_Modern
2025-02-25 18:07:27.111 python[49600:6445665] +[IMKInputSession subclass]: chose IMKInputSession_Modern


State vector: [ 0.9861111   0.29296875  0.48828125  1.          0.          1.
  1.          0.          1.          0.4609375  -0.8         0.46666667]
State vector: [ 0.9722222   0.29296875  0.48828125  1.          0.          1.
  1.          0.          1.          0.44726562 -0.7         0.43333334]
State vector: [ 0.9583333   0.29296875  0.48828125  1.          0.          1.
  1.          0.          1.          0.43554688 -0.6         0.4       ]
State vector: [ 0.9444444   0.29296875  0.48828125  1.          0.          1.
  1.          0.          1.          0.42578125 -0.5         0.36666667]
State vector: [ 0.9305556   0.29296875  0.48828125  1.          0.          1.
  1.          0.          1.          0.41796875 -0.4         0.33333334]
State vector: [ 0.9166667   0.29296875  0.48828125  1.          0.          1.
  1.          0.          1.          0.41210938 -0.3         0.3       ]
State vector: [ 0.9027778   0.29296875  0.48828125  1.          0.          1.
  1

KeyboardInterrupt: 