In [3]:
!pip install flappy-bird-gymnasium

Collecting flappy-bird-gymnasium
  Downloading flappy_bird_gymnasium-0.4.0-py3-none-any.whl.metadata (4.5 kB)
Collecting gymnasium (from flappy-bird-gymnasium)
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium->flappy-bird-gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading flappy_bird_gymnasium-0.4.0-py3-none-any.whl (37.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m37.3/37.3 MB[0m [31m43.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m43.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium, flappy-bird-gymnasium
Successfully installed farama-notifications-0.0.4 flappy-bird-gymnasium-0.4.0 gymnasium-1.0.0

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
import flappy_bird_gymnasium
import gymnasium as gym
import time
import cv2
from torch.nn import functional as F
from torchvision import transforms

In [5]:
class FlappyQNetwork(nn.Module):
    """Convolutional Q-network: three conv layers followed by a two-layer MLP head.

    Args:
        input_shape: Shape of one un-batched input; ``input_shape[0]`` is used as
            the number of input channels of the first convolution.
        actions: Size of the output vector (one value per action).
    """

    def __init__(self, input_shape, actions):
        super().__init__()

        in_channels = input_shape[0]
        feature_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*feature_layers)

        # The flattened feature width depends on the input resolution, so it is
        # measured with a dry run instead of being hard-coded.
        self.conv_output_size = self.compute_conv_output_size(input_shape)

        head_layers = [
            nn.Linear(self.conv_output_size, 256),
            nn.ReLU(),
            nn.Linear(256, actions),
        ]
        self.fc = nn.Sequential(*head_layers)

    def compute_conv_output_size(self, input_shape):
        """Return the number of elements the conv stack emits for one zero input."""
        probe = torch.zeros((1,) + tuple(input_shape))
        with torch.no_grad():
            features = self.conv(probe)
        return features.numel()

    def forward(self, x):
        """Run the conv stack, flatten per sample, and apply the linear head."""
        batch = x.size(0)
        features = self.conv(x)
        flattened = features.view(batch, -1)
        return self.fc(flattened)

In [6]:
def preprocess_state_outline(state):
    """Turn a rendered frame into an 84x84 edge ("outline") map scaled to [0, 1].

    Args:
        state: Array with a ``.shape``; if it is 3-D with 3 channels on the last
            axis the channels are averaged to grayscale, otherwise it is used as
            the grayscale image directly (assumed 2-D with values in 0..255 --
            TODO confirm against callers).

    Returns:
        The outline image resized to 84x84 by ``cv2.resize`` and divided by 255.
    """
    if len(state.shape) == 3 and state.shape[2] == 3:
        gray = state.sum(axis=2) / 3
    else:
        gray = state

    # 3x3 "outline" kernel: responds to local intensity differences, 0 on flat areas.
    kernel = torch.tensor([[-1.0, -1.0, -1.0],
                           [-1.0,  8.0, -1.0],
                           [-1.0, -1.0, -1.0]])
    kernel = kernel.reshape(1, 1, *kernel.shape)

    # Quantise the grayscale image to 8-bit as before, but run the convolution in
    # float32. Doing the arithmetic in uint8 turned the -1 weights into 255 and
    # made every response wrap modulo 256 (e.g. -255 -> 1, 2040 -> 248), which
    # scatters spurious values around edges instead of producing a clean outline.
    gray_t = torch.tensor(gray).byte().float()
    gray_t = gray_t.reshape(1, 1, *gray_t.shape)
    response = F.conv2d(gray_t, kernel)

    # Saturate to the valid 8-bit range (standard image-filter behaviour) before
    # going back to uint8 for the resize.
    outline = response.clamp(0, 255)[0, 0].to(torch.uint8).numpy()

    resized = cv2.resize(outline, (84, 84))
    normalized = resized / 255.0
    return normalized

In [7]:
def train_q_network(q_network, target_network, replay_buffer, optimizer, batch_size, gamma):
    """Perform one optimisation step on a random minibatch from the replay buffer.

    Does nothing while the buffer holds fewer than ``batch_size`` transitions.
    Each transition is a ``(state, action, reward, next_state, done)`` tuple.
    The regression target is ``reward + gamma * max_a Q_target(next_state, a)``,
    with the bootstrap term zeroed where ``done`` is set; the loss is the mean
    squared error against ``Q(state, action)``.

    Returns:
        None.
    """
    if len(replay_buffer) < batch_size:
        return

    transitions = random.sample(replay_buffer, batch_size)
    state_col, action_col, reward_col, next_state_col, done_col = zip(*transitions)

    # Column-wise conversion of the sampled transitions into tensors.
    state_t = torch.tensor(np.array(state_col), dtype=torch.float32)
    action_t = torch.tensor(action_col, dtype=torch.long).unsqueeze(-1)
    reward_t = torch.tensor(reward_col, dtype=torch.float32)
    next_state_t = torch.tensor(np.array(next_state_col), dtype=torch.float32)
    done_t = torch.tensor(done_col, dtype=torch.float32)

    # Q-value of the action that was actually taken in each sampled state.
    all_q = q_network(state_t)
    taken_q = all_q.gather(1, action_t).squeeze(-1)

    # Targets come from the target network and must not receive gradients.
    with torch.no_grad():
        best_next_q = target_network(next_state_t).max(dim=1).values
        not_done = 1 - done_t
        td_target = reward_t + gamma * best_next_q * not_done

    loss = nn.functional.mse_loss(taken_q, td_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
def stack_frames(frame, stacked_frames, is_new_episode):
    """Return the frame stack updated with ``frame``.

    On a new episode the stack is four copies of ``frame`` along a new leading
    axis (``stacked_frames`` is ignored). Otherwise the first entry of
    ``stacked_frames`` is dropped and ``frame`` is appended at the end. A new
    array is returned in both cases.
    """
    if not is_new_episode:
        newest = np.expand_dims(frame, 0)
        carried_over = stacked_frames[1:, :, :]
        return np.concatenate((carried_over, newest), axis=0)

    return np.stack([frame for _ in range(4)], axis=0)

def train_q_agent():
    """Train a DQN agent on ``FlappyBird-v0`` from preprocessed rendered frames.

    Each step renders the environment, converts the frame with
    ``preprocess_state_outline`` and keeps a stack of the last four frames as the
    network input. Actions are chosen epsilon-greedily; transitions go into a
    bounded replay buffer and, once more than ``OBSERVE`` episodes have been
    played, ``train_q_network`` is called after every agent step. The target
    network is synchronised every 10 episodes. Progress is printed per episode.

    Returns:
        None.
    """
    env = gym.make("FlappyBird-v0", render_mode="rgb_array")

    state_shape = (4, 84, 84)
    nr_actions = env.action_space.n
    gamma = 0.91
    learning_rate = 0.01
    batch_size = 32

    OBSERVE = 1000          # episodes played before learning / epsilon annealing start
    REPLAY_BUFFER = 5000    # maximum number of stored transitions
    EXPLORE = 10000         # number of episodes; also the annealing denominator
    INITIAL_EPSILON = 0.1
    FINAL_EPSILON = 0.0001

    # deque(maxlen=...) evicts the oldest transition by itself. The previous
    # manual popleft() once the length exceeded OBSERVE capped the buffer at
    # 1000 entries, so the configured REPLAY_BUFFER capacity was never used.
    replay_buffer = deque(maxlen=REPLAY_BUFFER)

    q_network = FlappyQNetwork(state_shape, nr_actions)
    target_network = FlappyQNetwork(state_shape, nr_actions)
    target_network.load_state_dict(q_network.state_dict())
    optimizer = optim.Adam(q_network.parameters(), lr=learning_rate)

    episodes = EXPLORE
    epsilon = INITIAL_EPSILON
    skip_frames = 4

    # try/finally so the environment is released even if training is interrupted.
    try:
        for episode in range(episodes):
            env.reset()
            raw_state = env.render()
            state = preprocess_state_outline(raw_state)
            stacked_frames = stack_frames(state, None, is_new_episode=True)

            total_reward = 0
            terminated = False
            truncated = False
            # An episode ends on termination OR truncation; ignoring `truncated`
            # would keep stepping an environment that has already ended.
            while not (terminated or truncated):
                if random.random() < epsilon:
                    action = env.action_space.sample()
                else:
                    with torch.no_grad():
                        state_tensor = torch.tensor(stacked_frames, dtype=torch.float32).unsqueeze(0)
                        action = q_network(state_tensor).argmax().item()

                _, reward, terminated, truncated, _ = env.step(action)

                # After action 1, follow up with up to `skip_frames` steps of
                # action 0, accumulating their reward into this transition.
                if action == 1:
                    for _ in range(skip_frames):
                        # Check BEFORE stepping: the environment must not be
                        # stepped again once the episode is over.
                        if terminated or truncated:
                            break
                        _, frame_reward, terminated, truncated, _ = env.step(0)
                        reward += frame_reward

                raw_next_state = env.render()
                next_state = preprocess_state_outline(raw_next_state)
                next_stacked_frames = stack_frames(next_state, stacked_frames, is_new_episode=False)

                # Only true termination masks the bootstrap term of the target;
                # a truncated episode still has a meaningful next-state value.
                replay_buffer.append((stacked_frames, action, reward, next_stacked_frames, terminated))

                stacked_frames = next_stacked_frames
                total_reward += reward

                if episode > OBSERVE:
                    train_q_network(q_network, target_network, replay_buffer, optimizer, batch_size, gamma)

            # Linear epsilon annealing, one step per episode after the observe phase.
            if epsilon > FINAL_EPSILON and episode > OBSERVE:
                epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE

            # Periodically sync the target network with the online network.
            if episode % 10 == 0:
                target_network.load_state_dict(q_network.state_dict())

            print(f"Episode {episode}, Total Reward: {total_reward}, Epsilon: {epsilon}")
    finally:
        env.close()

# Entry point: start training when this code is executed directly as the main module.
if __name__ == "__main__":
    train_q_agent()

Episode 0, Total Reward: 0.8000000000000012, Epsilon: 0.1
Episode 1, Total Reward: -0.29999999999999893, Epsilon: 0.1
Episode 2, Total Reward: 0.8000000000000012, Epsilon: 0.1
Episode 3, Total Reward: -0.8999999999999986, Epsilon: 0.1
Episode 4, Total Reward: -0.09999999999999964, Epsilon: 0.1
Episode 5, Total Reward: -2.6999999999999984, Epsilon: 0.1
Episode 6, Total Reward: 0.8000000000000012, Epsilon: 0.1
Episode 7, Total Reward: 1.4000000000000017, Epsilon: 0.1
Episode 8, Total Reward: -1.5999999999999994, Epsilon: 0.1
Episode 9, Total Reward: 0.8000000000000012, Epsilon: 0.1
Episode 10, Total Reward: 5.399999999999994, Epsilon: 0.1
Episode 11, Total Reward: 3.1999999999999975, Epsilon: 0.1
Episode 12, Total Reward: 1.4000000000000017, Epsilon: 0.1
Episode 13, Total Reward: 1.2000000000000015, Epsilon: 0.1
Episode 14, Total Reward: 2.0999999999999996, Epsilon: 0.1
Episode 15, Total Reward: 2.0999999999999996, Epsilon: 0.1
Episode 16, Total Reward: 0.8000000000000012, Epsilon: 0.1
E