<a href="https://colab.research.google.com/github/RaresFelix/ppo_lstm/blob/main/PPO_LSTM_Games.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install minigrid



# Game & Rendering logic (just fold & run cells)

In [2]:
import gymnasium as gym
import minigrid
import numpy as np
from gymnasium import spaces
from minigrid.minigrid_env import MiniGridEnv
from minigrid.core.mission import MissionSpace
from minigrid.core.grid import Grid
from minigrid.core.world_object import Ball, Box, Key, Wall, Goal
from abc import abstractmethod
import re
from PIL import Image
from minigrid.core.constants import (
    COLOR_NAMES,
    DIR_TO_VEC,
    OBJECT_TO_IDX,
    COLOR_TO_IDX,
    STATE_TO_IDX
)
from IPython.display import HTML, display
import base64
import io

class DisjointSetUnion:
    """Union-find structure over the integers ``0 .. n-1`` with union by size.

    ``self.e[i]`` is the parent index of ``i`` when it is non-negative. A
    negative entry marks ``i`` as a root, and its magnitude is the number of
    elements in that root's set.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets."""
        self.e = [-1 for _ in range(n)]

    def parent(self, u):
        """Return the root (representative) of the set containing ``u``."""
        node = u
        while True:
            link = self.e[node]
            if link < 0:
                return node
            node = link

    def join(self, u, v):
        """Merge the sets containing ``u`` and ``v``; no-op if already merged."""
        big = self.parent(u)
        small = self.parent(v)
        if big == small:
            return
        # Sizes are stored negated, so the larger entry is the smaller set.
        if self.e[big] > self.e[small]:
            big, small = small, big
        self.e[big] += self.e[small]
        self.e[small] = big

    def same(self, u, v):
        """Return True when ``u`` and ``v`` belong to the same set."""
        return self.parent(u) == self.parent(v)


class MiniGridCustomMaze(MiniGridEnv):
    """Square MiniGrid environment whose interior is a randomly generated maze.

    Each grid generation carves a fresh maze with a randomized Kruskal pass,
    then places the agent and a ``Goal`` on two distinct open cells chosen at
    random.

    Args:
        size: Side length of the grid, outer wall included.
        agent_start_pos: Stored on the instance. Grid generation picks the
            agent cell at random and does not read this value.
        agent_start_dir: Stored on the instance. Grid generation always sets
            the agent direction to 0 and does not read this value.
        max_steps: Step limit handed to ``MiniGridEnv``. ``None`` selects
            ``_DEFAULT_MAX_STEPS``.
        **kwargs: Forwarded unchanged to ``MiniGridEnv.__init__``.
    """

    # Step limit used when the caller does not supply one.
    _DEFAULT_MAX_STEPS = 256

    def __init__(
        self,
        size=8,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        max_steps: int | None = None,
        **kwargs,
    ):
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir

        mission_space = MissionSpace(mission_func=self._gen_mission)

        if max_steps is None:
            max_steps = self._DEFAULT_MAX_STEPS
            # Only warn about the default; an explicit max_steps is the
            # caller's own choice.
            if size > 13:
                print(f"WARNING: max_steps({max_steps}) is not enough for large grid size, consider increasing max_steps")
        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            max_steps=max_steps,
            see_through_walls=True,
            **kwargs,
        )

    @staticmethod
    def _gen_mission():
        """Return the mission string; this environment uses an empty one."""
        return ''

    def _carve_maze(self, w, h):
        """Return a ``(w, h)`` boolean array marking the open cells of a maze.

        Cells with two even coordinates start open. Cells with exactly one odd
        coordinate are candidate passages, visited in random order; a
        candidate is opened only when the open cells next to it all lie in
        different connected components, so opening it never closes a loop.
        Cells with two odd coordinates are never opened.
        """
        is_free = np.zeros((w, h), dtype=bool)
        is_free[::2, ::2] = True

        candidates = [(x, y) for x in range(w) for y in range(h) if (x % 2) + (y % 2) == 1]
        components = DisjointSetUnion(w * h)
        while candidates:
            x, y = candidates.pop(self._rand_int(0, len(candidates)))
            # Bounds are tested before indexing so negative offsets cannot
            # wrap around to the far side of the array.
            open_neighbors = [
                (x + dx, y + dy)
                for dx, dy in DIR_TO_VEC
                if 0 <= x + dx < w and 0 <= y + dy < h and is_free[x + dx, y + dy]
            ]
            roots = {components.parent(nx + ny * w) for nx, ny in open_neighbors}
            if len(roots) == len(open_neighbors):
                is_free[x, y] = True
                for nx, ny in open_neighbors:
                    components.join(x + y * w, nx + ny * w)
        return is_free

    def _gen_grid(self, width, height):
        """Build the walled grid, carve the maze, and place agent and goal."""
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        # The maze lives inside the outer wall, hence the one-cell offset
        # between maze coordinates and grid coordinates.
        inside_maze_free = self._carve_maze(width - 2, height - 2)
        free_cells = [
            (x, y)
            for x in range(1, width - 1)
            for y in range(1, height - 1)
            if inside_maze_free[x - 1, y - 1]
        ]

        # Popping guarantees the agent and the goal land on different cells.
        self.agent_pos = free_cells.pop(self._rand_int(0, len(free_cells)))
        self.agent_dir = 0
        self.goal_pos = free_cells.pop(self._rand_int(0, len(free_cells)))
        self.grid.set(*self.goal_pos, Goal())

        for x in range(1, width - 1):
            for y in range(1, height - 1):
                if not inside_maze_free[x - 1, y - 1]:
                    self.grid.set(x, y, Wall())

    def step(self, action):
        """Delegate to ``MiniGridEnv.step`` and return its result unchanged."""
        obs, reward, terminated, truncated, info = super().step(action)
        return obs, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Delegate to ``MiniGridEnv.reset`` and return its result unchanged."""
        obs, info = super().reset(seed=seed, options=options)
        return obs, info

    def render(self):
        """Delegate to ``MiniGridEnv.render``."""
        return super().render()

class MiniGridCustomMazeRandom(MiniGridCustomMaze):
    """Maze environment whose grid side length is redrawn on every reset.

    The side length is sampled with ``_rand_int(min_size, max_size + 1)``,
    once at construction and again at the start of each ``reset``.

    Args:
        max_size: Upper size bound; ``max_size + 1`` is passed to
            ``_rand_int`` as the high end.
        min_size: Lower size bound passed to ``_rand_int``.
        agent_start_pos: Forwarded to ``MiniGridCustomMaze``.
        agent_start_dir: Forwarded to ``MiniGridCustomMaze``.
        max_steps: Forwarded to ``MiniGridCustomMaze``.
        **kwargs: Forwarded to ``MiniGridCustomMaze``.
    """

    def __init__(
        self,
        max_size=8,
        min_size=6,
        agent_start_pos=(1, 1),
        agent_start_dir=0,
        max_steps: int | None = None,
        **kwargs,
    ):
        self.max_size = max_size
        self.min_size = min_size
        size = self._rand_int(min_size, max_size + 1)
        self.width = size
        self.height = size
        super().__init__(
            size=size,
            agent_start_pos=agent_start_pos,
            agent_start_dir=agent_start_dir,
            max_steps=max_steps,
            **kwargs,
        )

    def reset(self, seed=None, options=None):
        """Draw a new grid size, then regenerate the maze.

        When ``seed`` is given the RNG is seeded *before* the size is drawn,
        so the size and the maze are both reproducible for that seed. The
        seed is then withheld from the parent call so the stream is not
        rewound between the size draw and maze generation.
        """
        if seed is not None:
            gym.Env.reset(self, seed=seed)

        # Choose new random size before reset
        size = self._rand_int(self.min_size, self.max_size + 1)
        self.width = size
        self.height = size
        self.grid_size = size
        self.grid = None
        return super().reset(seed=None, options=options)



In [3]:
from IPython.display import display, clear_output
import ipywidgets as widgets
from PIL import Image
import io
import time
import numpy as np

def create_game_interface(env, use_obs=False):
    """Display an ipywidgets control panel for stepping ``env`` by hand.

    Shows a header, a line of instructions, an image widget with the current
    frame, turn/forward/reset buttons, and an HTML widget carrying a keyboard
    script. The environment is reset once before the first frame is drawn.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` and
            ``render()``.
        use_obs: When True, the displayed frame is ``obs['image']`` from the
            most recent observation; otherwise it is ``env.render()``.

    Returns:
        The mutable ``game_state`` dict with keys ``'running'``, ``'env'``,
        ``'last_update'`` and ``'current_obs'``.
    """
    # Use an ipywidgets.Image widget for a faster update.
    img_widget = widgets.Image(format='png', width=500, height=500)

    # Movement buttons
    up_button = widgets.Button(description='↑')
    left_button = widgets.Button(description='←')
    right_button = widgets.Button(description='→')
    reset_button = widgets.Button(description='Reset')

    # Game state
    game_state = {
        'running': True,
        'env': env,
        'last_update': 0,
        'current_obs': None  # Store current observation
    }

    def update_display(force=False):
        """Redraw the image widget; skipped if called again within 50 ms.

        ``force=True`` bypasses the debounce. It is used after a reset, where
        a dropped redraw would leave the previous episode's frame on screen.
        """
        current_time = time.time() * 1000
        if not force and current_time - game_state['last_update'] < 50:
            return
        game_state['last_update'] = current_time

        # Get the image either from observation or render
        if use_obs and game_state['current_obs'] is not None:
            image_array = game_state['current_obs']['image']
        else:
            image_array = game_state['env'].render()

        # Ensure the image is in the correct format (uint8)
        if image_array.dtype != np.uint8:
            image_array = (image_array * 255).astype(np.uint8)

        # Convert array to an image and write to a buffer. No `quality`
        # argument: it is a lossy-format setting and PNG is lossless.
        img = Image.fromarray(image_array)
        buffer = io.BytesIO()
        img.save(buffer, format='PNG', optimize=True)

        # Update the image widget directly with the new binary data
        img_widget.value = buffer.getvalue()

    def move(action):
        """Apply ``action``; start a new episode when the current one ends."""
        obs, reward, terminated, truncated, info = game_state['env'].step(action)
        game_state['current_obs'] = obs  # Store the observation
        update_display()
        if terminated or truncated:
            obs, info = game_state['env'].reset()
            game_state['current_obs'] = obs  # Store the reset observation
            # This redraw follows the one above almost immediately, so it
            # must bypass the debounce or the new episode is never shown.
            update_display(force=True)

    def reset_game(b):
        """Reset the environment and redraw; ``b`` is the clicked button."""
        obs, info = game_state['env'].reset()
        game_state['current_obs'] = obs  # Store the reset observation
        update_display(force=True)

    # Button click handlers
    up_button.on_click(lambda b: move(2))    # Forward
    left_button.on_click(lambda b: move(0))   # Left turn
    right_button.on_click(lambda b: move(1))  # Right turn
    reset_button.on_click(reset_game)

    # Keyboard handler.
    # NOTE(review): the script asks the kernel to run `move(...)` and
    # `reset_game(None)` by name, but both are closures local to this
    # function. Confirm the keyboard path actually reaches them in the target
    # frontend; the buttons above do not depend on it.
    keyboard_handler = widgets.HTML('''
        <script>
        var actions = {
            'w': 2,  // Forward
            'a': 0,  // Left turn
            'd': 1,  // Right turn
            'r': 'reset'
        };

        var lastKeyTime = 0;
        var pendingExecution = false;
        const DEBOUNCE_TIME = 50;
        var lastExecutionTime = 0;

        function executeAction(action) {
            var now = Date.now();
            if (pendingExecution || (now - lastExecutionTime < 20)) return;
            pendingExecution = true;
            lastExecutionTime = now;
            if (action === 'reset') {
                IPython.notebook.kernel.execute('reset_game(None)', {}, {silent: true}).then(() => { pendingExecution = false; });
            } else {
                IPython.notebook.kernel.execute('move(' + action + ')', {}, {silent: true}).then(() => { pendingExecution = false; });
            }
        }

        function handleKey(event) {
            var currentTime = Date.now();
            if (currentTime - lastKeyTime < DEBOUNCE_TIME) return;
            var key = event.key.toLowerCase();
            if (key in actions) {
                lastKeyTime = currentTime;
                executeAction(actions[key]);
                event.preventDefault();
            }
        }

        document.addEventListener('keydown', handleKey);
        </script>
    ''')

    # Layout: We display a header, instructions, the image, buttons, and the keyboard script.
    buttons = widgets.HBox([left_button, up_button, right_button, reset_button])
    display(widgets.HTML('<h3>MiniGrid Maze Game</h3>'))
    display(widgets.HTML('<p>Controls: W (↑), A (←), D (→), R (reset) or use buttons below</p>'))
    display(img_widget)
    display(buttons)
    display(keyboard_handler)

    # Initial display update
    obs, info = game_state['env'].reset()  # Get initial observation
    game_state['current_obs'] = obs  # Store the initial observation
    update_display(force=True)

    return game_state

# Games

## MiniGrid Maze

In [4]:
def start_minigrid_maze(size=13):
    """Launch the interactive interface on a ``MiniGridCustomMaze``.

    Args:
        size: Grid side length passed to ``MiniGridCustomMaze``.

    Returns:
        The ``game_state`` dict produced by ``create_game_interface``.
    """
    maze_env = MiniGridCustomMaze(
        size=size,
        agent_view_size=3,
        render_mode='rgb_array',
    )
    maze_env.reset()
    return create_game_interface(maze_env)

def start_minigrid_maze_agent_perspective(size=13):
    """Launch the interface showing the agent's own image observation.

    The maze is wrapped in ``RGBImgPartialObsWrapper`` and the interface is
    created with ``use_obs=True``, so each frame is the ``'image'`` entry of
    the wrapped observation rather than ``env.render()``.

    Args:
        size: Grid side length passed to ``MiniGridCustomMaze``.

    Returns:
        The ``game_state`` dict produced by ``create_game_interface``.
    """
    base_env = MiniGridCustomMaze(
        size=size,
        agent_view_size=3,
        render_mode='rgb_array',
    )
    wrapped_env = minigrid.wrappers.RGBImgPartialObsWrapper(base_env, tile_size=64)
    wrapped_env.reset()
    return create_game_interface(wrapped_env, use_obs=True)

In [5]:
# Full view: frames come from env.render() (use_obs defaults to False).
game = start_minigrid_maze()

HTML(value='<h3>MiniGrid Maze Game</h3>')

HTML(value='<p>Controls: W (↑), A (←), D (→), R (reset) or use buttons below</p>')

Image(value=b'', height='500', width='500')

HBox(children=(Button(description='←', style=ButtonStyle()), Button(description='↑', style=ButtonStyle()), But…

HTML(value="\n        <script>\n        var actions = {\n            'w': 2,  // Forward\n            'a': 0, …

In [72]:
# Agent's perspective: frames are the 'image' entry of the wrapped
# environment's observation (use_obs=True).
game = start_minigrid_maze_agent_perspective()

HTML(value='<h3>MiniGrid Maze Game</h3>')

HTML(value='<p>Controls: W (↑), A (←), D (→), R (reset) or use buttons below</p>')

Image(value=b'', height='500', width='500')

HBox(children=(Button(description='←', style=ButtonStyle()), Button(description='↑', style=ButtonStyle()), But…

HTML(value="\n        <script>\n        var actions = {\n            'w': 2,  // Forward\n            'a': 0, …

## MiniGrid Memory

In [7]:
def start_minigrid_memory(size=13):
    """Launch the interactive interface on a MiniGrid Memory environment.

    Args:
        size: Grid size forwarded to the environment constructor, overriding
            the value registered for the ``S13`` id. The default matches that
            id. NOTE(review): the Memory env presumably requires an odd
            size; confirm before passing other values.

    Returns:
        The ``game_state`` dict produced by ``create_game_interface``.
    """
    env = gym.make(
        'MiniGrid-MemoryS13Random-v0',
        size=size,
        agent_view_size=3,
        render_mode='rgb_array',
    )
    env.reset()
    game_state = create_game_interface(env)
    return game_state


game = start_minigrid_memory()

HTML(value='<h3>MiniGrid Maze Game</h3>')

HTML(value='<p>Controls: W (↑), A (←), D (→), R (reset) or use buttons below</p>')

Image(value=b'', height='500', width='500')

HBox(children=(Button(description='←', style=ButtonStyle()), Button(description='↑', style=ButtonStyle()), But…

HTML(value="\n        <script>\n        var actions = {\n            'w': 2,  // Forward\n            'a': 0, …