In [2]:
import gymnasium as gym
import flappy_bird_gymnasium

import torch
from torch import nn, optim

import numpy as np
import cv2
import random
import matplotlib.pyplot as plot
import datetime
import time
import os
import glob
from collections import deque

In [3]:
# --- Observation / replay ---
STACK_SIZE = 4          # number of consecutive preprocessed frames fed to the network
BUFFER_SIZE = 50_000    # replay-buffer capacity (transitions)
BATCH_SIZE = 64         # minibatch size per gradient step

# --- Optimisation ---
GAMMA = 0.99            # discount factor for bootstrapped Q-targets
LR = 1e-4               # Adam learning rate
TARGET_UPDATE = 1000    # environment steps between target-network syncs
EPISODES = 5000         # number of training episodes

# --- Epsilon-greedy exploration (decayed multiplicatively once per episode) ---
EPS_START = 1.0
EPS_END = 0.01
EPS_DECAY_RATE = 0.998

# --- Reproducibility ---
# Seeds the Python, NumPy and PyTorch global RNGs (torch.manual_seed also seeds CUDA).
# NOTE: the Gymnasium environment and its action_space keep their own RNGs;
# pass seed=... to env.reset() as well for fully reproducible runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def preprocess(frame, size=(84, 84)):
    """Convert a rendered frame into a normalised grayscale image for the network.

    Args:
        frame: image array, presumably the uint8 array returned by
            ``env.render()``. Accepted layouts are (H, W, 3) RGB, (H, W, 4) RGBA,
            (H, W, 1) and (H, W) grayscale. ``None`` or any other layout yields
            an all-zero (black) image instead of raising.
        size: output size as ``(width, height)``, following cv2.resize's
            ``dsize`` convention. The default of (84, 84) matches FlappyCNN.

    Returns:
        float32 array of shape ``(height, width)``, equal to the pixel values
        divided by 255 (so in [0, 1] when the input is uint8).
    """
    width, height = size
    if frame is None:
        return np.zeros((height, width), dtype=np.float32)

    frame = np.asarray(frame)
    if frame.ndim == 3 and frame.shape[2] == 3:
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    elif frame.ndim == 3 and frame.shape[2] == 4:
        # Some render backends return RGBA; convert instead of discarding the frame.
        gray = cv2.cvtColor(frame, cv2.COLOR_RGBA2GRAY)
    elif frame.ndim == 3 and frame.shape[2] == 1:
        # cv2 needs a contiguous buffer, and the channel slice is strided.
        gray = np.ascontiguousarray(frame[:, :, 0])
    elif frame.ndim == 2:
        gray = frame
    else:
        # Unrecognised layout: fall back to a black frame as before.
        return np.zeros((height, width), dtype=np.float32)

    gray = cv2.resize(gray, (width, height), interpolation=cv2.INTER_AREA)
    return gray.astype(np.float32) / 255.0

In [5]:
def setup_experiment(base_dir="checkpoints", prefix="exp_"):
    """Create and return a new, sequentially numbered experiment directory.

    Scans ``base_dir`` for sub-directories named ``<prefix><digits>`` (e.g.
    ``exp_014``) and creates ``<prefix><max + 1>`` zero-padded to 3 digits.
    Entries whose suffix is not purely numeric (``exp_010_backup``,
    ``exp_notes``) and plain files are ignored rather than crashing the scan.

    Args:
        base_dir: parent directory. It is created if missing.
        prefix: name prefix for experiment directories.

    Returns:
        Path of the newly created experiment directory.
    """
    os.makedirs(base_dir, exist_ok=True)
    existing_ids = [
        int(name[len(prefix):])
        for name in os.listdir(base_dir)
        if name.startswith(prefix)
        and name[len(prefix):].isdecimal()
        and os.path.isdir(os.path.join(base_dir, name))
    ]
    next_id = max(existing_ids, default=0) + 1
    path = os.path.join(base_dir, f"{prefix}{next_id:03d}")
    os.makedirs(path, exist_ok=True)
    print(f"Initialized: {path}")
    return path

In [6]:
# Allocate a fresh numbered run directory; checkpoints and the training plot are written here.
exp_path = setup_experiment(base_dir="checkpoints")

Initialized: checkpoints\exp_015


In [None]:
class FlappyCNN(nn.Module):
    """
    Dueling Deep Q-Network for Flappy Bird.

    A shared 3-layer convolutional trunk feeds two fully connected streams:
    a state-value stream V(s) with a single output, and an action-advantage
    stream A(s, a) with one output per action. They are combined as

        Q(s, a) = V(s) + (A(s, a) - mean_a' A(s, a'))

    so the advantages are zero-mean across actions, which keeps V and A
    identifiable.

    Input:  stacked grayscale frames, shape (batch, input_channels, input_size, input_size)
    Output: Q-values, shape (batch, num_actions)
    """

    def __init__(self, num_actions: int, input_channels: int = 4, input_size: int = 84):
        super().__init__()
        # Convolutional feature extractor (shape comments are for the default 84x84x4 input).
        # input_size must be at least 36 for all three convolutions to produce output.
        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=8, stride=4),  # 84x84x4 -> 20x20x32
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),  # 20x20x32 -> 9x9x64
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),  # 9x9x64 -> 7x7x64
            nn.ReLU(),
            nn.Flatten(),  # 7x7x64 = 3136
        )

        # Probe the trunk with a dummy input to get the flattened feature size.
        with torch.no_grad():
            dummy = torch.zeros(1, input_channels, input_size, input_size)
            conv_out_size = self.conv(dummy).size(1)

        # Dueling heads: scalar state value V(s) ...
        self.state_value_stream = nn.Sequential(
            nn.Linear(conv_out_size, 512), nn.ReLU(), nn.Linear(512, 1)
        )
        # ... and per-action advantages A(s, a).
        self.action_advantage_stream = nn.Sequential(
            nn.Linear(conv_out_size, 512), nn.ReLU(), nn.Linear(512, num_actions)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values of shape (batch, num_actions) for a batch of frame stacks."""
        features = self.conv(x)
        state_value = self.state_value_stream(features)
        action_advantage = self.action_advantage_stream(features)
        # Subtract the mean advantage so the advantages are zero-mean across actions.
        return state_value + (action_advantage - action_advantage.mean(dim=1, keepdim=True))

In [None]:
class ReplayBuffer:
    """
    Fixed-size FIFO buffer of experience tuples for DQN.
    Stores: (state, action, reward, next_state, done)

    Once `capacity` transitions are held, each push evicts the oldest one.

    State tensors are detached and, by default, copied to CPU memory when
    pushed. Keeping them on the GPU costs two full state tensors of device
    memory per transition (e.g. 2 x 4x84x84 float32 ~ 225 KB, ~11 GB at 50k
    transitions). The trade-off is that `sample()` returns CPU tensors, which
    the caller must move to its device. Pass storage_device=None to keep the
    tensors wherever the caller allocated them.
    """

    def __init__(self, capacity: int, storage_device="cpu"):
        self.buffer = deque(maxlen=capacity)
        # Target device for stored state tensors, or None to leave them where they are.
        self.storage_device = storage_device

    def _store(self, tensor: torch.Tensor) -> torch.Tensor:
        """Detach `tensor` from any graph and move it to the storage device, if one is set."""
        tensor = tensor.detach()
        if self.storage_device is not None:
            tensor = tensor.to(self.storage_device)
        return tensor

    def push(self, state: torch.Tensor, action: int, reward: float,
             next_state: torch.Tensor, done: bool):
        """
        Add a transition to the buffer.
        States are expected as torch.FloatTensor. `done` should be True only
        for a true terminal state, because it masks out the bootstrapped future value.
        """
        self.buffer.append(
            (self._store(state), action, reward, self._store(next_state), done)
        )

    def sample(self, batch_size: int):
        """
        Sample `batch_size` distinct transitions uniformly at random.

        Returns:
            (states, actions, rewards, next_states, dones) where states and
            next_states are stacked along a new leading batch dimension,
            actions are int64, and rewards/dones are float32.

        Raises:
            ValueError: if batch_size exceeds the number of stored transitions
                (raised by random.sample).
        """
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.stack(states)
        next_states = torch.stack(next_states)
        actions = torch.tensor(actions, dtype=torch.long)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32)

        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)

In [None]:
class FrameStack:
    """
    Rolling window over the `k` most recent preprocessed frames.

    Frames are stacked along a new leading axis, oldest first, so frames of
    shape S become a tensor of shape (k, *S). The returned tensor wraps the
    fresh array produced by np.stack, so it shares no memory with the stored
    frames. Before reset() is called, the window may hold fewer than k frames.
    """

    def __init__(self, k):
        self.k = k
        self.frames = deque(maxlen=k)

    def _stacked(self):
        """Stack the buffered frames (oldest first) into a torch tensor."""
        stacked_array = np.stack(self.frames, axis=0)
        return torch.from_numpy(stacked_array)

    def reset(self, frame):
        """Fill the whole window with `frame` and return the resulting stack."""
        self.frames.extend([frame] * self.k)
        return self._stacked()

    def step(self, frame):
        """Push `frame`, dropping the oldest one once full, and return the new stack."""
        self.frames.append(frame)
        return self._stacked()

In [None]:
# Environment rendered to RGB arrays, plus the helper that stacks preprocessed frames.
env = gym.make("FlappyBird-v0", render_mode="rgb_array")
stacker = FrameStack(STACK_SIZE)
num_actions = env.action_space.n

# Online (policy) network and target network, built in that order; the target
# is then overwritten with the policy weights so both start identical.
policy_net, target_net = (FlappyCNN(num_actions).to(DEVICE) for _ in range(2))
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), lr=LR)
replay_buffer = ReplayBuffer(BUFFER_SIZE)


In [None]:
episode_rewards_history = []
episode_scores_history = []
# Start below any achievable score so the first finished episode is always checkpointed.
best_score_achieved = float("-inf")
epsilon_greedy_value = EPS_START
total_training_steps = 0


def optimize_policy_step():
    """Run one DQN gradient step on policy_net using a random minibatch from replay_buffer.

    Targets are r + GAMMA * max_a' Q_target(s', a') * (1 - done), computed with
    the frozen target_net under no_grad, and the loss is smooth L1 (Huber).
    """
    states, actions, rewards, next_states, dones = (
        tensor.to(DEVICE) for tensor in replay_buffer.sample(BATCH_SIZE)
    )
    q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        next_q_values = target_net(next_states).max(1)[0]
        targets = rewards + GAMMA * next_q_values * (1 - dones)
    loss = nn.functional.smooth_l1_loss(q_values, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


try:
    for episode in range(1, EPISODES + 1):
        env.reset()
        state = stacker.reset(preprocess(env.render())).to(DEVICE)
        episode_reward = 0
        done = False
        while not done:
            if random.random() < epsilon_greedy_value:
                action = env.action_space.sample()  # Explore
            else:
                with torch.no_grad():
                    action = policy_net(state.unsqueeze(0)).argmax(1).item()  # Exploit

            _, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            episode_reward += reward
            next_state = stacker.step(preprocess(env.render())).to(DEVICE)
            # Store `terminated`, not `done`: only a true terminal state has zero
            # future value. A truncated episode must still bootstrap from next_state.
            replay_buffer.push(state, action, reward, next_state, terminated)
            state = next_state
            total_training_steps += 1

            if len(replay_buffer) >= BATCH_SIZE:
                optimize_policy_step()
            # Sync the target network every TARGET_UPDATE environment steps.
            if total_training_steps % TARGET_UPDATE == 0:
                target_net.load_state_dict(policy_net.state_dict())

        episode_score = info.get("score", 0)
        # Strict ">" so the checkpoint is rewritten only on a genuine improvement,
        # not on every tie (e.g. every score-0 episode early in training).
        if episode_score > best_score_achieved:
            best_score_achieved = episode_score
            torch.save(policy_net.state_dict(), os.path.join(exp_path, "best_model_score.pth"))

        episode_scores_history.append(episode_score)
        episode_rewards_history.append(episode_reward)
        epsilon_greedy_value = max(EPS_END, epsilon_greedy_value * EPS_DECAY_RATE)
        if episode % 10 == 0:
            print(
                f"Episode: {episode} | "
                f"Average Score: {np.mean(episode_scores_history[-10:])} | "
                f"Best: {best_score_achieved} | "
                f"Epsilon: {epsilon_greedy_value:.3f}"
            )
except KeyboardInterrupt:
    print("Training interrupted by user.")

torch.save(policy_net.state_dict(), os.path.join(exp_path, "final_model.pth"))

  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(f"{pre} is not within the observation space.")


Episode: 10 | Average Score: 0.0 | Best: 0 | Epsilon: 0.980
Episode: 20 | Average Score: 0.0 | Best: 0 | Epsilon: 0.961
Episode: 30 | Average Score: 0.0 | Best: 0 | Epsilon: 0.942
Episode: 40 | Average Score: 0.0 | Best: 0 | Epsilon: 0.923
Episode: 50 | Average Score: 0.0 | Best: 0 | Epsilon: 0.905
Episode: 60 | Average Score: 0.0 | Best: 0 | Epsilon: 0.887
Episode: 70 | Average Score: 0.0 | Best: 0 | Epsilon: 0.869
Episode: 80 | Average Score: 0.0 | Best: 0 | Epsilon: 0.852
Episode: 90 | Average Score: 0.0 | Best: 0 | Epsilon: 0.835
Episode: 100 | Average Score: 0.0 | Best: 0 | Epsilon: 0.819
Episode: 110 | Average Score: 0.0 | Best: 0 | Epsilon: 0.802
Episode: 120 | Average Score: 0.0 | Best: 0 | Epsilon: 0.786
Episode: 130 | Average Score: 0.0 | Best: 0 | Epsilon: 0.771
Episode: 140 | Average Score: 0.0 | Best: 0 | Epsilon: 0.756
Episode: 150 | Average Score: 0.0 | Best: 0 | Epsilon: 0.741
Episode: 160 | Average Score: 0.0 | Best: 0 | Epsilon: 0.726
Episode: 170 | Average Score: 0.0

In [None]:
def plot_with_moving_average(axis, values, color, title, ylabel, window=100):
    """Plot a per-episode series on `axis`, with a trailing moving average when possible.

    The moving average is drawn in red, distinct from the raw series, and is
    added only once at least `window` values exist. Its point i is the mean of
    episodes i-window+1 .. i.
    """
    axis.plot(values, color=color, alpha=0.6, label="Per episode")
    if len(values) >= window:
        moving_average = np.convolve(values, np.ones(window) / window, mode="valid")
        axis.plot(range(window - 1, len(values)), moving_average, color="red", label=f"MA({window})")
    axis.set_title(title)
    axis.set_xlabel("Episode")
    axis.set_ylabel(ylabel)
    axis.legend(loc="upper left")


figure, (reward_axis, score_axis) = plot.subplots(2, 1, figsize=(10, 8))
plot_with_moving_average(reward_axis, episode_rewards_history, "blue", "Reward per episode", "Reward")
plot_with_moving_average(score_axis, episode_scores_history, "green", "Score per episode", "Score")

figure.tight_layout()
figure.savefig(os.path.join(exp_path, "training_graph.png"))
plot.show()

NameError: name 'plot' is not defined

In [None]:
# Pretrained checkpoint to evaluate. Built with os.path.join so it works on every OS.
# Point this at os.path.join(exp_path, "best_model_score.pth") to evaluate the run above.
EVAL_CHECKPOINT = os.path.join("checkpoints", "best_flappy_pixels.pth")
EVAL_EPISODES = 10

model = FlappyCNN(num_actions).to(DEVICE)
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs, so a tampered checkpoint cannot run code.
model.load_state_dict(torch.load(EVAL_CHECKPOINT, map_location=DEVICE, weights_only=True))
model.eval()

quit_requested = False
try:
    for episode in range(EVAL_EPISODES):
        env.reset()
        frame = env.render()
        state = stacker.reset(preprocess(frame)).to(DEVICE)
        done = False
        while not done:
            cv2.imshow("Flappy Bird", cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
            # 'q' in the OpenCV window stops the whole evaluation, not just this episode.
            if cv2.waitKey(1) & 0xFF == ord("q"):
                quit_requested = True
                break
            with torch.no_grad():
                action = model(state.unsqueeze(0)).argmax(1).item()

            _, reward, terminated, truncated, info = env.step(action)
            frame = env.render()
            state = stacker.step(preprocess(frame)).to(DEVICE)
            done = terminated or truncated

        if quit_requested:
            print("Evaluation stopped by user.")
            break
        print(f"Episode {episode+1} finished. Score: {info.get('score', 0)}")
except KeyboardInterrupt:
    print("Evaluation interrupted by user.")
finally:
    # Always release the environment and display window, even if an error occurs.
    env.close()
    cv2.destroyAllWindows()

Episode 1 finished. Score: 56
Episode 2 finished. Score: 74
Episode 3 finished. Score: 72
Episode 4 finished. Score: 5
Episode 5 finished. Score: 323
Episode 6 finished. Score: 164
Episode 7 finished. Score: 25
Episode 8 finished. Score: 1
Episode 9 finished. Score: 47
Episode 10 finished. Score: 299
