# DSA - Deep Learning [5] - Reinforcement learning

In [None]:
# Install necessary libraries
!pip install flappy-bird-gymnasium pygame
!apt-get install -y xvfb python3-opengl ffmpeg
!pip install pyvirtualdisplay
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

Collecting flappy-bird-gymnasium
  Downloading flappy_bird_gymnasium-0.4.0-py3-none-any.whl.metadata (4.5 kB)
Collecting gymnasium (from flappy-bird-gymnasium)
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium->flappy-bird-gymnasium)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading flappy_bird_gymnasium-0.4.0-py3-none-any.whl (37.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m37.3/37.3 MB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m25.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium, flappy-bird-gymnasium
Successfully installed farama-notifications-0.0.4 flappy-bird-gymnasium-0.4.0 gymnasium-1.0.0

In [None]:
# Import necessary libraries
import os
import torch
import random
import numpy as np
import pygame
import imageio
from IPython.display import display, Image
from PIL import Image as PILImage  # Importing PIL for image manipulation
from flappy_bird_gymnasium.envs.flappy_bird_env import FlappyBirdEnv

# Headless setup (e.g. Colab): point SDL, which pygame uses under the hood, at its
# "dummy" video and audio drivers so no real window or sound device is required.
os.environ.update({"SDL_VIDEODRIVER": "dummy", "SDL_AUDIODRIVER": "dummy"})


pygame 2.6.1 (SDL 2.28.4, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [None]:
class CustomFlappyBirdEnv(FlappyBirdEnv):
    """Headless FlappyBirdEnv that records rendered frames and exports them as a GIF.

    On construction it initialises pygame (video/audio go to SDL's "dummy"
    drivers via the SDL_* environment variables set in the imports cell),
    loads sprites and sounds from the installed ``flappy_bird_gymnasium``
    package, and seeds a few game attributes. Every :meth:`render` call appends
    a copy of the display surface to ``self.frames``; :meth:`create_gif` writes
    those frames to an animated GIF and displays it inline.
    """

    # Game surface / display size in pixels: (width, height).
    _SCREEN_SIZE = (288, 512)

    def __init__(self):
        super().__init__()

        # Initialise pygame and make sure display and mixer are up.
        pygame.init()
        if not pygame.display.get_init():
            pygame.display.init()
        if not pygame.mixer.get_init():
            pygame.mixer.init()

        # Off-screen surface the game is drawn onto.
        self._surface = pygame.Surface(self._SCREEN_SIZE)

        # Display surface used by FlappyBirdEnv rendering. A video mode must be
        # set before sprites can be loaded with convert_alpha() in _load_image.
        self._display = pygame.display.set_mode(self._SCREEN_SIZE)

        # Clock used by render() to throttle to metadata["render_fps"].
        self._fps_clock = pygame.time.Clock()

        # Frames captured by render(). Created here (not only in reset()) so
        # render() also works if it is called before the first reset().
        self.frames = []

        pipe = self._load_image("pipe-green.png")
        self._images = {
            "background": self._load_image("background-day.png"),
            # Order matches the upstream flappy_bird_gymnasium image loader:
            # index 0 (vertically flipped) is used for upper pipes and index 1
            # (sprite as shipped) for lower pipes.
            # NOTE(review): re-check against FlappyBirdEnv's drawing code if the
            # package version changes.
            "pipe": [pygame.transform.flip(pipe, False, True), pipe],
            "base": self._load_image("base.png"),
            "player": [
                self._load_image("yellowbird-upflap.png"),
                self._load_image("yellowbird-midflap.png"),
                self._load_image("yellowbird-downflap.png"),
            ],
            # Score digit sprites, keyed by the digit 0-9.
            "numbers": {i: self._load_image(f"{i}.png") for i in range(10)},
        }

        # Sound effects (always loaded).
        self._audio = {
            "wing": self._load_audio("wing.wav"),
            "point": self._load_audio("point.wav"),
            "hit": self._load_audio("hit.wav"),
            "die": self._load_audio("die.wav"),
        }

        # Additional attributes required by the parent class.
        self._score = 0
        self._player_index = 0
        # How far the base sprite can scroll before it must wrap around.
        self._base_shift = self._images["base"].get_width() - self._surface.get_width()
        self._pipes = []
        self._player_y = 256
        self._player_velocity_y = 0
        self._gravity = 1
        self._pipe_gap = 100

    @staticmethod
    def _assets_dir(kind):
        """Return ``<flappy_bird_gymnasium package dir>/assets/<kind>``.

        The location is derived from the file ``FlappyBirdEnv`` was imported
        from instead of being hard-coded. This keeps it working across Python
        versions (python3.10 vs python3.11) and install prefixes.

        Args:
            kind: Asset sub-directory name, e.g. ``"sprites"`` or ``"audio"``.
        Returns:
            Absolute path of that asset directory.
        """
        import inspect  # local import: only needed to locate the installed package

        # FlappyBirdEnv is imported from flappy_bird_gymnasium/envs/flappy_bird_env.py,
        # so the package root is two directory levels above that file.
        env_file = os.path.abspath(inspect.getfile(FlappyBirdEnv))
        package_dir = os.path.dirname(os.path.dirname(env_file))
        return os.path.join(package_dir, "assets", kind)

    def _load_image(self, filename):
        """
        Load an image from the package's sprites directory.
        Args:
            filename: Name of the image file.
        Returns:
            Loaded pygame image, converted with per-pixel alpha
            (requires a display mode to have been set).
        """
        filepath = os.path.join(self._assets_dir("sprites"), filename)
        return pygame.image.load(filepath).convert_alpha()

    def _load_audio(self, filename):
        """
        Load an audio file from the package's audio directory.
        Args:
            filename: Name of the audio file.
        Returns:
            Loaded pygame audio sound.
        """
        filepath = os.path.join(self._assets_dir("audio"), filename)
        return pygame.mixer.Sound(filepath)

    def render(self):
        """
        Render via the parent class and record the current display frame.

        Returns:
            Whatever ``FlappyBirdEnv.render()`` returned (passed through unchanged).
        """
        result = super().render()

        # array3d copies the display pixels into an array indexed [x][y]
        # (width first); create_gif converts that into image orientation.
        frame = pygame.surfarray.array3d(pygame.display.get_surface())
        self.frames.append(frame)

        # Control frame rate.
        self._fps_clock.tick(self.metadata["render_fps"])
        return result

    def create_gif(self, gif_name="flappy_bird_game.gif"):
        """
        Save the captured frames as an animated GIF and display it inline.

        Args:
            gif_name: Base file name. The GIF is written to this name with
                ``_flipped`` inserted before the extension
                (e.g. ``flappy_bird_game_flipped.gif``).
        Returns:
            Path of the GIF that was written.
        Raises:
            ValueError: If no frames have been captured yet.
        """
        if not self.frames:
            raise ValueError(
                "No frames captured; call reset() and render() before create_gif()."
            )

        # surfarray frames are (width, height, 3). Swapping the first two axes
        # yields a (height, width, 3) image in on-screen orientation. A 270-degree
        # rotation would also fix the orientation but mirror the image left-right.
        pil_frames = [
            PILImage.fromarray(np.ascontiguousarray(frame.swapaxes(0, 1)))
            for frame in self.frames
        ]

        root, ext = os.path.splitext(gif_name)
        flipped_gif_name = f"{root}_flipped{ext or '.gif'}"

        # Pillow's GIF writer takes the per-frame duration in milliseconds.
        frame_ms = max(1, round(1000 / self.metadata["render_fps"]))
        pil_frames[0].save(
            flipped_gif_name,
            format="GIF",
            save_all=True,
            append_images=pil_frames[1:],
            duration=frame_ms,
            loop=0,  # loop forever
        )
        display(Image(flipped_gif_name))
        return flipped_gif_name

    def reset(self, *args, **kwargs):
        """
        Clear the stored frames and reset the underlying environment.

        All arguments (e.g. gymnasium's ``seed=`` / ``options=`` keywords) are
        forwarded unchanged to ``FlappyBirdEnv.reset``.

        Returns:
            Whatever ``FlappyBirdEnv.reset`` returns.
        """
        self.frames = []  # Clear captured frames
        return super().reset(*args, **kwargs)
