<a href="https://colab.research.google.com/github/anaumghori/FlappyBird/blob/main/Flappy_Bird_RL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Install and import necessary libraries**

In [None]:
# Notebook shell command: install pygame and pillow into the current runtime.
!pip install pygame pillow

In [None]:
import os
import sys
# Colab-only setup: fetch the PyGame Learning Environment (PLE) sources and
# install them, since the `ple` package is imported further down.
if "google.colab" in sys.modules:
    import os
    # Clone only once; re-running the cell reuses the existing checkout.
    if not os.path.exists('PyGame-Learning-Environment'):
        !git clone https://github.com/ntasfi/PyGame-Learning-Environment.git
    # Run the editable install from inside the checkout, then step back up
    # to the parent directory.
    os.chdir('PyGame-Learning-Environment')
    !pip install -e .
    os.chdir('..')
    import sys
    # Also put the checkout on sys.path so it is importable in this session.
    sys.path.append('./PyGame-Learning-Environment')

In [None]:
import random
from collections import deque
from itertools import count
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import cv2
from ple import PLE
from ple.games.flappybird import FlappyBird
import pygame

# Run SDL with its "dummy" video driver so pygame can work without a real
# display (e.g. on a headless Colab VM).
# Assigning to os.environ already calls putenv() under the hood, so a separate
# os.putenv() call is unnecessary.
os.environ["SDL_VIDEODRIVER"] = "dummy"

# **Define DQN / Dueling DQN network and replay memory**

In [None]:
class DQN(nn.Module):
    """Fully connected Q-network with an optional dueling head.

    The shared trunk maps ``input_dim -> 64 -> 128 -> 256 -> 512`` with ReLU6
    activations. When ``network_type`` is ``'DuelingDQN'`` the trunk feeds a
    scalar state-value stream and a per-action advantage stream; for any other
    value a single linear layer produces the Q-values.

    Args:
        input_dim: Number of features in one observation.
        output_dim: Number of actions (width of the Q-value output).
        network_type: ``'DuelingDQN'`` selects the dueling head.
    """

    def __init__(self, input_dim, output_dim, network_type='DQN'):
        super().__init__()
        self.network_type = network_type

        # Shared feature trunk.
        self.layer1 = nn.Linear(input_dim, 64)
        self.layer2 = nn.Linear(64, 128)
        self.layer3 = nn.Linear(128, 256)
        self.layer4 = nn.Linear(256, 512)

        if network_type != 'DuelingDQN':
            # Plain head: one Q-value per action.
            self.output = nn.Linear(512, output_dim)
            return

        # Dueling head: V(s) and A(s, a) streams.
        self.state_values = nn.Linear(512, 1)
        self.advantages = nn.Linear(512, output_dim)

    def forward(self, x):
        """Return Q-values for a batch ``x`` (batch on dim 0, features on dim 1)."""
        hidden = x
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            hidden = F.relu6(layer(hidden))

        if self.network_type != 'DuelingDQN':
            return self.output(hidden)

        value = self.state_values(hidden)
        advantage = self.advantages(hidden)
        # Centre the advantages per sample before adding the state value.
        centred = advantage - advantage.mean(dim=1, keepdim=True)
        return value + centred

# Memory replay buffer
class MemoryRecall:
    """Fixed-capacity replay buffer backed by a ``deque``.

    Once ``memory_size`` entries are stored, each new entry pushes out the
    oldest one.
    """

    def __init__(self, memory_size):
        # A bounded deque discards from the left when appending past maxlen.
        self.memory = deque([], memory_size)

    def cache(self, data):
        """Store one entry at the newest end of the buffer."""
        self.memory.append(data)

    def recall(self, batch_size):
        """Return ``batch_size`` distinct stored entries chosen at random.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored entries.
        """
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of entries currently held."""
        stored = self.memory
        return len(stored)

# **Define Agent**

In [None]:
# Define Agent
class Agent:
    """DQN agent with epsilon-greedy exploration, replay memory and a soft-updated target net.

    ``config`` is a dict that must provide the keys ``batch_size``, ``gamma``,
    ``tau``, ``eps_start``, ``eps_decay``, ``eps_min``, ``input_dim``,
    ``output_dim``, ``network_type``, ``lr``, ``memory_size`` and
    ``action_dict`` (maps the agent's action index to the value handed to
    ``env.act``).
    """

    def __init__(self, config):
        self.BATCH_SIZE = config['batch_size']
        self.GAMMA = config['gamma']
        self.TAU = config['tau']
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.epsilon = config['eps_start']
        self.EPS_DECAY = config['eps_decay']
        self.EPS_MIN = config['eps_min']
        self.steps_done = 0

        self.policy_net = DQN(config['input_dim'], config['output_dim'], config['network_type']).to(self.device)
        self.target_net = DQN(config['input_dim'], config['output_dim'], config['network_type']).to(self.device)
        # The target network starts as an exact copy of the policy network.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=config['lr'])
        self.memory = MemoryRecall(config['memory_size'])

        self.action_dict = config['action_dict']
        self.episode_durations = []

    @staticmethod
    def _read_state(env):
        """Return the environment's current game state as a 1-D float32 tensor."""
        return torch.tensor(list(env.getGameState().values()), dtype=torch.float32)

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain ``path`` if it does not exist yet."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def take_action(self, state, greedy=False):
        """Choose an action index for a single observation.

        Args:
            state: 1-D tensor holding one observation; the batch dimension is
                added here.
            greedy: When True, always return the policy network's argmax
                action and leave ``self.epsilon`` untouched (for evaluation /
                replay). When False (default), decay epsilon by one step and
                act epsilon-greedily.

        Returns:
            The chosen action index (a key of ``self.action_dict`` when
            exploring, the argmax index otherwise).
        """
        if not greedy:
            self.epsilon = max(self.epsilon * self.EPS_DECAY, self.EPS_MIN)
            if random.random() <= self.epsilon:
                return random.choice(list(self.action_dict))

        with torch.no_grad():
            q_values = self.policy_net(state.unsqueeze(0).to(self.device))
            return torch.argmax(q_values).item()

    def optimize_model(self):
        """Run one gradient step on a random replay batch.

        Does nothing until the memory holds at least ``BATCH_SIZE``
        transitions. The target for each transition is
        ``reward + GAMMA * max_a Q_target(next_state, a)``, with the bootstrap
        term fixed at zero for terminal transitions (``next_state is None``).
        """
        if len(self.memory) < self.BATCH_SIZE:
            return

        batch = self.memory.recall(self.BATCH_SIZE)
        # The stored done flag is not needed here: terminal transitions are
        # recognised by their next_state being None.
        state_batch, next_state_batch, action_batch, reward_batch, _ = zip(*batch)

        state_batch = torch.stack(state_batch).to(self.device)
        action_batch = torch.tensor(action_batch).unsqueeze(1).to(self.device)
        reward_batch = torch.cat(reward_batch).to(self.device)

        non_final_next_states = [s for s in next_state_batch if s is not None]
        non_final_mask = torch.tensor(
            [s is not None for s in next_state_batch],
            dtype=torch.bool,
            device=self.device,
        )

        state_action_values = self.policy_net(state_batch).gather(1, action_batch)

        next_state_values = torch.zeros(self.BATCH_SIZE, device=self.device)
        # A batch made up only of terminal transitions has nothing to
        # bootstrap from; torch.stack() on an empty list would raise.
        if non_final_next_states:
            with torch.no_grad():
                next_states = torch.stack(non_final_next_states).to(self.device)
                next_state_values[non_final_mask] = self.target_net(next_states).max(1)[0]

        expected_state_action_values = (next_state_values * self.GAMMA) + reward_batch
        loss = nn.SmoothL1Loss()(state_action_values, expected_state_action_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_net(self):
        """Soft-update the target net: ``target <- TAU * policy + (1 - TAU) * target``."""
        with torch.no_grad():
            for target_param, policy_param in zip(self.target_net.parameters(), self.policy_net.parameters()):
                target_param.copy_(self.TAU * policy_param + (1 - self.TAU) * target_param)

    def train(self, env, episodes):
        """Train for ``episodes`` episodes on ``env`` and then save the duration plot.

        Every environment step stores one transition, runs one optimisation
        step and one soft target update. Episode lengths (in steps) are
        appended to ``self.episode_durations``.
        """
        for _ in range(episodes):
            env.reset_game()
            state = self._read_state(env)

            for t in count():
                action = self.take_action(state)
                reward = env.act(self.action_dict[action])
                reward = torch.tensor([reward], dtype=torch.float32)

                # Query the terminal flag once per step and reuse it.
                done = env.game_over()
                next_state = None if done else self._read_state(env)

                self.memory.cache((state, next_state, action, reward, done))

                state = next_state

                self.optimize_model()
                self.update_target_net()

                if done:
                    self.episode_durations.append(t + 1)
                    break

        self.plot_durations()

    def plot_durations(self, filename='/content/training_progress.png'):
        """Plot episode durations (plus a 100-episode running mean) and save the figure.

        Args:
            filename: Output image path. Defaults to the Colab location the
                notebook has always used; the parent directory is created if
                it is missing so the call also works outside Colab.
        """
        plt.figure(1)
        plt.clf()
        durations = torch.tensor(self.episode_durations, dtype=torch.float)
        plt.title('Training')
        plt.xlabel('Episode')
        plt.ylabel('Duration')
        plt.plot(durations.numpy())
        if len(durations) >= 100:
            means = durations.unfold(0, 100, 1).mean(1).view(-1)
            # Pad with zeros so the running mean lines up with the episode axis.
            plt.plot(torch.cat((torch.zeros(99), means)).numpy())
        self._ensure_parent_dir(filename)
        plt.savefig(filename)

    def save_replay_video(self, env, filename="/content/flappybird_replay.avi", fps=30):
        """Play one episode with the greedy policy and write it to a video file.

        Args:
            env: PLE environment to play in; it is reset before recording.
            filename: Output video path (XVID codec).
            fps: Frame rate written into the video file.

        Raises:
            RuntimeError: If OpenCV cannot open ``filename`` for writing.
        """
        env.display_screen = True
        env.force_fps = False

        screen_dims = env.getScreenDims()
        frame_width = screen_dims[0]
        frame_height = screen_dims[1]

        self._ensure_parent_dir(filename)
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter(filename, fourcc, fps, (frame_width, frame_height))

        try:
            if not out.isOpened():
                raise RuntimeError(f"Could not open video file for writing: {filename}")

            env.reset_game()
            state = self._read_state(env)

            while not env.game_over():
                # PLE documents getScreenRGB() as (width, height, 3), while
                # OpenCV expects frames as (height, width, 3): swap the axes.
                frame = np.ascontiguousarray(np.swapaxes(env.getScreenRGB(), 0, 1))
                out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

                # Evaluate greedily: no random exploration in the replay and
                # no further decay of the training epsilon.
                action = self.take_action(state, greedy=True)
                env.act(self.action_dict[action])
                state = self._read_state(env)
        finally:
            # Always finalise the file and switch the display flag back off.
            out.release()
            env.display_screen = False

# **Main Script**

In [None]:
if __name__ == "__main__":
    # Build the Flappy Bird game and wrap it in a PLE environment without an
    # on-screen window.
    game = FlappyBird(width=256, height=256)
    env = PLE(game, display_screen=False)
    env.init()

    # Agent action index -> environment action: index 0 is the second entry
    # of the action set, index 1 the first.
    # NOTE(review): presumably 0 = no-op and 1 = flap; confirm against
    # env.getActionSet().
    actions = env.getActionSet()
    action_dict = {0: actions[1], 1: actions[0]}

    # Hyperparameters and network dimensions consumed by Agent.
    config = dict(
        batch_size=32,
        memory_size=100000,
        gamma=0.99,
        tau=0.005,
        eps_start=1.0,
        eps_decay=0.999995,
        eps_min=0.05,
        lr=1e-4,
        input_dim=len(env.getGameState()),
        output_dim=len(action_dict),
        network_type='DuelingDQN',
        action_dict=action_dict,
    )

    # Train, then record one episode to video.
    agent = Agent(config)
    agent.train(env, 20000)
    agent.save_replay_video(env)