<a href="https://colab.research.google.com/github/Tinynja/Sarsa-phi-EB/blob/main/ALE_Framework_Tests.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys

# When running in Colab, replace the contents of the working directory with a
# fresh clone of the repository; otherwise leave the local checkout untouched.
if 'google.colab' in sys.modules:
    # Wipe the current working directory (the `*` glob does not match dotfiles)
    !rm -rf *
    !git clone https://github.com/Tinynja/Sarsa-phi-EB
    # Move the cloned files up into the working directory, then remove what is
    # left of the clone folder
    !mv Sarsa-phi-EB/* .
    !rm -rf Sarsa-phi-EB
    # DON'T install packages defined in Pipfile_colab_remove:
    # the lines of that file are joined with '|' into a single extended regex
    # and every Pipfile line matching it is deleted in place.
    # NOTE(review): if Pipfile_colab_remove ends with a newline, `tr` leaves a
    # trailing '|' in the pattern (an empty alternative) — confirm the file
    # has no trailing newline or that the resulting regex behaves as intended.
    !sed -ri "/$(tr '\n' '|' < Pipfile_colab_remove)/d" Pipfile
else:
    print('Skipping GitHub cloning since not running in Colab.')

Cloning into 'Sarsa-phi-EB'...
remote: Enumerating objects: 187, done.[K
remote: Counting objects: 100% (187/187), done.[K
remote: Compressing objects: 100% (161/161), done.[K
remote: Total 187 (delta 36), reused 148 (delta 17), pack-reused 0[K
Receiving objects: 100% (187/187), 567.38 KiB | 5.40 MiB/s, done.
Resolving deltas: 100% (36/36), done.


In [2]:
# Install required dependencies
# (stdout of the install commands is redirected to /dev/null to keep the cell output short)
if 'google.colab' in sys.modules:
    # Colab doesn't support pipenv, hence we convert Pipfile into requirements.txt
    # and install that file with plain pip
    !pip install pipenv 1> /dev/null
    # NOTE(review): pipenv is installed unpinned above; newer pipenv releases
    # reportedly replace `lock -r` with `pipenv requirements` — confirm against
    # the version that actually gets installed.
    !pipenv lock -r > requirements.txt
    !pip install -r requirements.txt 1> /dev/null
else:
    # Outside Colab: lock the Pipfile, then install through pipenv itself
    !pipenv lock 1> /dev/null
    !pipenv install --deploy 1> /dev/null

Creating a virtualenv for this project...
Pipfile: /content/Pipfile
Using /usr/local/bin/python (3.7.12) to create virtualenv...
⠴[0m Creating virtual environment...[Kcreated virtual environment CPython3.7.12.final.0-64 in 992ms
  creator CPython3Posix(dest=/root/.local/share/virtualenvs/content-cQIIIOO2, clear=False, no_vcs_ignore=False, global=False)
  seeder FromAppData(download=False, pip=bundle, setuptools=bundle, wheel=bundle, via=copy, app_data_dir=/root/.local/share/virtualenv)
    added seed packages: pip==21.3.1, setuptools==58.3.0, wheel==0.37.0
  activators BashActivator,CShellActivator,FishActivator,NushellActivator,PowerShellActivator,PythonActivator

[K[?25h[32m[22m✔ Successfully created virtual environment![39m[22m[0m 
Virtualenv location: /root/.local/share/virtualenvs/content-cQIIIOO2
Pipfile.lock not found, creating...
Locking [dev-packages] dependencies...
Locking [packages] dependencies...
[KBuilding requirements...
[KResolving dependencies...
[K[?25h

In [3]:
# Import all supported ROMs into ALE.
# `ROMS` is the directory (relative to the working directory) that is scanned for ROM files.
!ale-import-roms ROMS

[92m[SUPPORTED]    [0m                qbert        ROMS/Q. Bert (1983).bin
[92m[SUPPORTED]    [0m               casino ROMS/Casino - Poker Plus (Paddle) (1979).bin
[92m[SUPPORTED]    [0m         demon_attack ROMS/Demon Attack (Death from Above) (1982).bin
[92m[SUPPORTED]    [0m              bowling        ROMS/Bowling (1979).bin
[92m[SUPPORTED]    [0m             seaquest       ROMS/Seaquest (1983).bin
[92m[SUPPORTED]    [0m      chopper_command ROMS/Chopper Command (1982).bin
[92m[SUPPORTED]    [0m           time_pilot     ROMS/Time Pilot (1983).bin
[92m[SUPPORTED]    [0m             trondead ROMS/TRON - Deadly Discs (TRON Joystick) (1983).bin
[92m[SUPPORTED]    [0m            jamesbond ROMS/James Bond 007 (James Bond Agent 007) (1984).bin
[92m[SUPPORTED]    [0m               pacman        ROMS/Pac-Man (1982).bin
[92m[SUPPORTED]    [0m               kaboom ROMS/Kaboom! (Paddle) (1981).bin
[92m[SUPPORTED]    [0m              asterix ROMS/Asterix (AKA Taz) (1983)

In [71]:
# Built-in libraries
import base64
import random
import re
import subprocess
import sys
from pathlib import Path

# PyPI (third-party) libraries
from IPython import display as ipythondisplay
from ale_py import ALEInterface, SDL_SUPPORT
from ale_py.roms import Breakout
from PIL import Image

In [87]:
class EnvALE:
    """Small gym-style wrapper around ``ALEInterface`` with optional episode recording.

    When recording is enabled (``reset(do_record=True)``), every frame of the
    episode is saved as ``<out_dir>/record/step_<n>.png``. Once the emulator
    reports game over, the frames are renamed to a common zero-padded width and
    encoded by ffmpeg into ``<out_dir>/record.mp4``.

    Args:
        rom: ROM handed to ``ALEInterface.loadROM``.
        out_dir: Directory receiving the recorded frames and video. It is
            created (including missing parents) if it does not exist.
        display: Open the emulator window with sound. Only honoured when
            ale_py reports ``SDL_SUPPORT`` and the code is not running in Colab.
        seed: Value written to ALE's ``random_seed`` setting.
    """

    def __init__(self, rom, out_dir='ale-results', display=False, seed=0):
        self.out_dir = Path(out_dir).resolve()
        self.out_dir.mkdir(parents=True, exist_ok=True)

        self.ale = ALEInterface()
        self.ale.setInt("random_seed", seed)
        # Settings must be applied before the ROM is loaded. The window is
        # opt-in through `display` and skipped where no screen is available.
        if display and SDL_SUPPORT and 'google.colab' not in sys.modules:
            self.ale.setBool("sound", True)
            self.ale.setBool("display_screen", True)
        self.ale.loadROM(rom)

        self.action_space = self.ale.getMinimalActionSet()

        # Default values
        self._timestep = 0
        self._do_record = False
        self._record_padding = None

    def reset(self, do_record=False):
        """Start a new episode and return its first observation.

        Args:
            do_record: When true, the frames of this episode are saved and
                converted into a video once the game is over.
        """
        self.ale.reset_game()
        observation = self._observe()
        self._timestep = 0

        self._do_record = do_record
        self._handle_recording()

        return observation

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        ``info`` is always ``None``.
        """
        reward = self.ale.act(action)
        observation = self._observe()
        done = self.ale.game_over()
        self._timestep += 1

        self._handle_recording()

        return observation, reward, done, None

    def show_video(self, scale=1):
        """Show a .mp4 video in html format of the recorded episode.

        Args:
            scale: Multiplier applied to the default 300px player height.

        Raises:
            FileNotFoundError: If no episode has been recorded yet.
        """
        filepath = self.out_dir.joinpath('record.mp4')
        if not filepath.is_file():
            raise FileNotFoundError(
                f'{filepath} does not exist; finish an episode started with '
                'reset(do_record=True) first'
            )
        video_b64 = base64.b64encode(filepath.read_bytes())
        height = int(round(300 * scale))
        html = f'''<video alt="{filepath}" autoplay loop controls style="height:{height}px">
                        <source src="data:video/mp4;base64,{video_b64.decode('ascii')}" type="video/mp4" />
                   </video>'''
        ipythondisplay.display(ipythondisplay.HTML(data=html))

    def _observe(self):
        """Return the current screen as given by ``ALEInterface.getScreenRGB``."""
        return self.ale.getScreenRGB()

    def _handle_recording(self):
        """Save the current frame and, at game over, build the episode video."""
        # Do nothing if not asked to record
        if not self._do_record:
            return
        # This is a new episode, delete previously recorded steps
        if not self._timestep:
            self.out_dir.joinpath('record').mkdir(exist_ok=True)
            for step_png in self.out_dir.glob('record/step_*.png'):
                step_png.unlink()
            self._record_padding = None
        # Record current timestep png
        out_path = self.out_dir.joinpath(f'record/step_{self._timestep}.png')
        self.ale.saveScreenPNG(str(out_path))
        # Once the episode is over, format all png filenames to have the same integer 0 padding
        if self.ale.game_over():
            self._record_padding = len(str(self._timestep))
            self._standardize_record_padding()
            self._png_to_mp4()

    def _standardize_record_padding(self):
        """Rename recorded frames so every step number has ``_record_padding`` digits."""
        number_pattern = re.compile(r'\d+')
        # Snapshot the listing first: renaming while the glob is still being
        # consumed could make it yield entries that were already renamed.
        for png in list(self.out_dir.glob('record/step_*.png')):
            match = number_pattern.search(png.stem)
            if match is None:
                # Not one of our numbered frames; leave it alone
                continue
            timestep = int(match.group(0))
            new_name = png.with_name(f'step_{timestep:0{self._record_padding}d}.png')
            if new_name != png:
                png.rename(new_name)

    def _png_to_mp4(self):
        """Convert the recorded set of png files into a mp4 video.

        Raises:
            RuntimeError: If ffmpeg exits with a non-zero status.
            FileNotFoundError: If the ffmpeg executable cannot be found.
        """
        in_dir = self.out_dir.joinpath('record')
        in_pattern = in_dir.joinpath(f'step_%0{self._record_padding}d.png')
        out_file = self.out_dir.joinpath('record.mp4')
        # Argument list instead of a shell string: works outside IPython and
        # is safe for paths containing spaces.
        command = [
            'ffmpeg', '-hide_banner', '-loglevel', 'error',
            '-r', '60',
            '-i', str(in_pattern),
            '-vcodec', 'libx264', '-crf', '25', '-pix_fmt', 'yuv420p',
            '-y', str(out_file),
        ]
        result = subprocess.run(
            command,
            cwd=str(in_dir),
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        if result.returncode != 0:
            raise RuntimeError(
                f'ffmpeg failed with exit code {result.returncode}:\n{result.stderr}'
            )
        if result.stderr:
            # Surface non-fatal ffmpeg diagnostics instead of swallowing them
            print(result.stderr, file=sys.stderr)

In [95]:
# Share one seed between the emulator and the action sampler so that the
# recorded episode is the same on every Restart & Run All.
SEED = 0
env = EnvALE(Breakout, seed=SEED)
rng = random.Random(SEED)  # local generator: leaves the global `random` state untouched

# Play one recorded episode with uniformly random actions
done, observation = False, env.reset(do_record=True)
while not done:
    action = rng.choice(env.action_space)
    observation, reward, done, info = env.step(action)
env.show_video()