<a href="https://colab.research.google.com/github/Tinynja/Sarsa-phi-EB/blob/main/ALE_Framework_Tests.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys

# Bootstrap the working directory when running on Google Colab: replace its
# content with a fresh checkout of the project repository.
if 'google.colab' in sys.modules:
    # WARNING: wipes every (non-hidden) file in the current working directory
    !rm -rf *
    !git clone https://github.com/Tinynja/Sarsa-phi-EB
    # Move the checkout content to the working directory root.
    # NOTE(review): the `*` glob normally skips dotfiles (e.g. `.git`), which are
    # then removed together with the clone directory on the next line — confirm this is intended.
    !mv Sarsa-phi-EB/* .
    !rm -rf Sarsa-phi-EB
    # DON'T install packages defined in Pipfile_colab_remove:
    # the lines of Pipfile_colab_remove are joined with `|` into one extended regex
    # and every Pipfile line matching it is deleted in place.
    # NOTE(review): if Pipfile_colab_remove ends with a newline, the regex gets a
    # trailing `|` (empty alternative) — verify this does not match every Pipfile line.
    !sed -ri "/$(tr '\n' '|' < Pipfile_colab_remove)/d" Pipfile
else:
    print('Skipping GitHub cloning since not running in Colab.')

Cloning into 'Sarsa-phi-EB'...
remote: Enumerating objects: 210, done.[K
remote: Counting objects: 100% (210/210), done.[K
remote: Compressing objects: 100% (182/182), done.[K
remote: Total 210 (delta 45), reused 155 (delta 19), pack-reused 0[K
Receiving objects: 100% (210/210), 655.71 KiB | 6.43 MiB/s, done.
Resolving deltas: 100% (45/45), done.


In [2]:
# Install required dependencies
# (`1> /dev/null` only discards stdout; anything written to stderr is still shown)
if 'google.colab' in sys.modules:
    # Colab doesn't support pipenv, hence we convert Pipfile into requirements.txt
    # and install that file with plain pip.
    # NOTE(review): pipenv is installed unpinned; confirm the installed version
    # still supports `pipenv lock -r`.
    !pip install pipenv 1> /dev/null
    !pipenv lock -r > requirements.txt
    !pip install -r requirements.txt 1> /dev/null
else:
    # Outside Colab: refresh the lock file, then install exactly what it pins
    !pipenv lock 1> /dev/null
    !pipenv install --deploy 1> /dev/null

Creating a virtualenv for this project...
Pipfile: /content/Pipfile
Using /usr/local/bin/python (3.7.12) to create virtualenv...
⠴[0m Creating virtual environment...[Kcreated virtual environment CPython3.7.12.final.0-64 in 1042ms
  creator CPython3Posix(dest=/root/.local/share/virtualenvs/content-cQIIIOO2, clear=False, no_vcs_ignore=False, global=False)
  seeder FromAppData(download=False, pip=bundle, setuptools=bundle, wheel=bundle, via=copy, app_data_dir=/root/.local/share/virtualenv)
    added seed packages: pip==21.3.1, setuptools==58.3.0, wheel==0.37.0
  activators BashActivator,CShellActivator,FishActivator,NushellActivator,PowerShellActivator,PythonActivator

[K[?25h[32m[22m✔ Successfully created virtual environment![39m[22m[0m 
Virtualenv location: /root/.local/share/virtualenvs/content-cQIIIOO2
Pipfile.lock not found, creating...
Locking [dev-packages] dependencies...
Locking [packages] dependencies...
[KBuilding requirements...
[KResolving dependencies...
[K[?25h

In [3]:
# Import all supported ROMs found in the local `ROMS` directory into ALE
# (the tool prints one line per ROM it recognizes)
!ale-import-roms ROMS

[92m[SUPPORTED]    [0m            ms_pacman    ROMS/Ms. Pac-Man (1983).bin
[92m[SUPPORTED]    [0m               amidar         ROMS/Amidar (1982).bin
[92m[SUPPORTED]    [0m        haunted_house ROMS/Haunted House (Mystery Mansion, Graves' Manor, Nightmare Manor) (1982).bin
[92m[SUPPORTED]    [0m               gopher ROMS/Gopher (Gopher Attack) (1982).bin
[92m[SUPPORTED]    [0m            centipede      ROMS/Centipede (1983).bin
[92m[SUPPORTED]    [0m           beam_rider      ROMS/Beamrider (1984).bin
[92m[SUPPORTED]    [0m         demon_attack ROMS/Demon Attack (Death from Above) (1982).bin
[92m[SUPPORTED]    [0m          road_runner    ROMS/Road Runner (1989).bin
[92m[SUPPORTED]    [0m              berzerk        ROMS/Berzerk (1982).bin
[92m[SUPPORTED]    [0m           mario_bros    ROMS/Mario Bros. (1983).bin
[92m[SUPPORTED]    [0m                   et ROMS/E.T. - The Extra-Terrestrial (1982).bin
[92m[SUPPORTED]    [0m           time_pilot     ROMS/Time Pilot

In [4]:
#### ALE-related imports ####

# Built-in libraries
import base64
import pickle
import random
import re
import subprocess
import sys
from pathlib import Path

# Pypi libraries
import numpy as np
from ale_py import ALEInterface, SDL_SUPPORT
from IPython import display as ipythondisplay

In [5]:
class features:
    """Namespace of static feature extractors turning a screen image into a flat feature vector."""

    @staticmethod
    def basic(img, palette, crop_size=np.array([15,10])):
        """Compute tile-wise color-presence features.

        The image is divided into a grid of non-overlapping tiles of
        `crop_size` (rows, columns) pixels. For every tile and every palette
        color, the feature is True if at least one pixel of the tile has
        exactly that color.

        Parameters:
            img: array of shape (height, width, channels).
            palette: array of shape (n_colors, channels).
            crop_size: (tile_height, tile_width) in pixels; both image
                dimensions must be exact multiples of it.

        Returns:
            1-D boolean array of length n_tile_rows * n_tile_cols * n_colors,
            ordered by tile row, then tile column, then palette color.

        Raises:
            TypeError: if the image cannot be split evenly into tiles.
        """
        img = np.asarray(img)
        tile_h, tile_w = (int(size) for size in crop_size)
        height, width = img.shape[:2]
        # Check each dimension separately: checking only the product of the
        # ratios would accept e.g. a 4x1 image with 2x2 tiles (2 * 0.5 == 1).
        if height % tile_h or width % tile_w:
            raise TypeError(f'image of size {height}x{width} cannot be split evenly '
                            f'into tiles of size {tile_h}x{tile_w}')
        # For each color in palette, tell if each pixel is that color
        # e.g. 4x4x3 image, with 2x3 palette, returns 4x4x2
        colors_in_pixels = np.logical_and.reduce(np.expand_dims(img, -2) == palette, axis=-1)
        n_colors = colors_in_pixels.shape[-1]
        # Split both spatial axes into (tile index, offset inside tile). A plain
        # reshape to (n_tiles, tile_h, tile_w, n_colors) would instead group
        # contiguous row-major runs of pixels, which are not spatial tiles.
        tiled = colors_in_pixels.reshape(height // tile_h, tile_h,
                                         width // tile_w, tile_w, n_colors)
        # Apply logical or inside each tile (over both in-tile offset axes)
        tile_features = np.logical_or.reduce(tiled, axis=(1, 3))
        # Flatten the features
        return tile_features.flatten()

    @staticmethod
    def b_pros(img, palette, crop_size=np.array([15,10])):
        """Placeholder for features built on top of `basic` (presumably B-PROS — TODO confirm).

        Not implemented yet: the Basic features are computed (which validates
        the inputs) but nothing is derived from them and None is returned.
        """
        basic_features = features.basic(img, palette, crop_size=crop_size)
        # TODO: derive the additional features from `basic_features`
        return None

In [22]:
class EnvALE:
    """Gym-like wrapper around an ALE emulator.

    Provides `reset`/`step`, a configurable observation (raw screen or
    extracted features) and optional recording of an episode as PNG frames
    that are encoded into `<out_dir>/record.mp4` once the game is over.
    """

    def __init__(self, rom, out_dir='ale-results', display=False, seed=0, feature_type='raw'):
        """Create the emulator, load `rom` and select the observation type.

        Parameters:
            rom: path of the ROM to load (passed to `ALEInterface.loadROM`).
            out_dir: directory receiving the recordings; created if missing.
            display: open the emulator window (with sound) when SDL is
                available and the code is not running in Colab.
            seed: value of the ALE `random_seed` setting.
            feature_type: 'raw' (RGB screen) or 'Basic' (`features.basic`).

        Raises:
            NotImplementedError: if `feature_type` is not supported.
        """
        self.rom = rom
        self.out_dir = Path(out_dir).resolve()
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.feature_type = feature_type

        self.ale = ALEInterface()
        self.ale.setInt("random_seed", seed)
        if display and SDL_SUPPORT and 'google.colab' not in sys.modules:
            self.ale.setBool("sound", True)
            self.ale.setBool("display_screen", True)
        self.ale.loadROM(rom)

        self.action_space = self.ale.getMinimalActionSet()
        self.color_palette = self._get_color_palette()
        self._set_observe_method(feature_type)

        # Default values
        self._timestep = 0
        self._do_record = False
        self._record_padding = None

    def reset(self, do_record=False):
        """Restart the game and return the first observation.

        Parameters:
            do_record: when True, every step of the new episode is saved as a
                PNG (previously recorded frames are deleted).
        """
        self.ale.reset_game()
        observation = self._observe()
        self._timestep = 0

        self._do_record = do_record
        self._handle_recording()

        return observation

    def step(self, action, number_of_frames=4):
        """Repeat `action` for `number_of_frames` emulator frames.

        Parameters:
            action: an element of `action_space`, or an integer index into it.
            number_of_frames: how many consecutive frames the action is applied.

        Returns:
            (observation, reward, done, info): `reward` is the sum of the
            rewards of all the repeated frames and `info` is always None.
        """
        if isinstance(action, (int, np.integer)):
            action = self.action_space[action]
        # Accumulate over the skipped frames; keeping only the last frame's
        # reward would silently drop rewards earned on the earlier ones.
        reward = 0
        for _ in range(number_of_frames):
            reward += self.ale.act(action)
        observation = self._observe()
        done = self.ale.game_over()
        self._timestep += 1

        self._handle_recording()

        return observation, reward, done, None

    def show_video(self, scale=1):
        """Show a .mp4 video in html format of the recorded episode.

        Parameters:
            scale: multiplier applied to the default 300px display height.
        """
        filepath = self.out_dir.joinpath('record.mp4')
        video_b64 = base64.b64encode(filepath.read_bytes())
        height_px = int(300 * scale)
        html = f'''<video alt="{filepath}" autoplay loop controls style="height:{height_px}px">
                        <source src="data:video/mp4;base64,{video_b64.decode('ascii')}" type="video/mp4" />
                   </video>'''
        ipythondisplay.display(ipythondisplay.HTML(data=html))

    def _set_observe_method(self, feature_type):
        """Bind `self._observe` (as an instance attribute) to the extractor matching `feature_type`."""
        if feature_type == 'raw':
            self._observe = self.ale.getScreenRGB
        elif feature_type == 'Basic':
            self._observe = lambda: features.basic(img=self.ale.getScreenRGB(),
                                                   palette=self.color_palette)
        else:
            raise NotImplementedError(f'Feature type `{feature_type}` is not supported')

    def _observe(self):
        """Return the current observation; replaced per instance by `_set_observe_method`."""
        raise NotImplementedError()

    def _handle_recording(self):
        """Save the current frame and encode the video once the game is over."""
        # Do nothing if not asked to record
        if not self._do_record: return
        # This is a new episode, delete previously recorded steps
        if not self._timestep:
            self.out_dir.joinpath('record').mkdir(exist_ok=True)
            for step_png in self.out_dir.glob('record/step_*.png'):
                step_png.unlink()
            self._record_padding = None
        # Record current timestep png
        out_path = self.out_dir.joinpath(f'record/step_{self._timestep}.png')
        self.ale.saveScreenPNG(str(out_path))
        # Once the episode is over, format all png filenames to have the same integer 0 padding
        if self.ale.game_over():
            self._record_padding = len(str(self._timestep))
            self._standardize_record_padding()
            self._png_to_mp4()

    def _standardize_record_padding(self):
        """Rename the recorded PNGs so that all step numbers are zero-padded to the same width."""
        number_pattern = re.compile(r'\d+')
        # Materialize the listing first so renames cannot disturb the iteration
        for png in list(self.out_dir.glob('record/step_*.png')):
            timestep = int(number_pattern.search(png.stem).group(0))
            new_name = png.parent.joinpath(f'step_{timestep:0{self._record_padding}d}.png')
            png.rename(new_name)

    def _png_to_mp4(self):
        """Convert the recorded set of png files into a mp4 video.

        Best effort: a missing or failing ffmpeg is reported on stderr
        instead of raising, so the episode loop is not interrupted.
        """
        in_dir = self.out_dir.joinpath('record')
        in_pattern = in_dir.joinpath(f'step_%0{self._record_padding}d.png')
        out_file = self.out_dir.joinpath('record.mp4')
        # Argument list (no shell) so that paths containing spaces are handled
        command = ['ffmpeg', '-hide_banner', '-loglevel', 'error', '-r', '60',
                   '-i', str(in_pattern), '-vcodec', 'libx264', '-crf', '25',
                   '-pix_fmt', 'yuv420p', '-y', str(out_file)]
        try:
            result = subprocess.run(command, cwd=str(in_dir), capture_output=True)
        except OSError as error:
            print(f'Could not run ffmpeg: {error}', file=sys.stderr)
            return
        if result.returncode != 0:
            print(result.stderr.decode(errors='replace'), file=sys.stderr)

    def _get_color_palette(self):
        """Load the color palette matching the ROM's display format.

        The ROM is loaded in a separate interpreter and the palette name is
        read from what ALE writes on stderr while loading.

        Raises:
            RuntimeError: if the stderr output is shorter than expected.
        """
        # The ROM path is passed as an argument rather than interpolated into
        # the code, so quotes or backslashes in the path cannot break it.
        loader = 'import sys; __import__("ale_py").ALEInterface().loadROM(sys.argv[1])'
        result = subprocess.run([sys.executable, '-c', loader, str(self.rom)], capture_output=True)
        log_lines = result.stderr.decode().splitlines()
        # NOTE(review): relies on the 7th stderr line ending with the palette
        # name — confirm against the installed ale_py version.
        if len(log_lines) <= 6:
            raise RuntimeError('Could not determine the color palette of '
                               f'`{self.rom}` from the ALE output: {log_lines}')
        palette_name = log_lines[6].strip().split()[-1]
        # NOTE: pickle can execute arbitrary code; only load palette files shipped with the project
        with open(f'{palette_name}_Palette.pickle', 'rb') as file:
            palette = pickle.load(file)
        return palette

In [30]:
from ale_py.roms import Breakout

# Play (and record) a short random episode using the Basic features
env = EnvALE(Breakout, feature_type='Basic')
done, observation = False, env.reset(do_record=True)
t = 0
# NOTE(review): the video is only encoded when the game is over (see
# EnvALE._handle_recording); if the step cap below ends the loop first,
# show_video() displays whatever record.mp4 already existed, if any.
while not done and t <= 150:
    #action = random.choice(env.action_space)
    # Uniformly random index into the minimal action set
    action = random.randrange(len(env.action_space))
    observation, reward, done, info = env.step(action)
    t += 1
env.show_video()

# Reuse the palette loaded at construction instead of spawning another ROM-loading subprocess
print(env.color_palette.shape)
print(observation.shape)
#59 sec without speedup
#15 sec with 4 frames speedup (game ended)
#13 sec with 5 frames speedup (game ended)
#7 sec with 10 frames speedup (game ended)

(128, 3)
(28672,)


In [15]:
#@title
# Uncomment the following line to run this test
%%script echo Skipped this cell because it is a manual test.

import random

# Init environment
env = EnvALE(Breakout, observe_method=lambda *args: env.ale.getScreenRGB())

# Play an episode
done, observation = False, env.reset(do_record=True)
while not done:
    action = random.choice(env.action_space)
    observation, reward, done, info = env.step(action)

# Show episode
env.show_video()

Skipped this cell because it is a manual test.


In [9]:
class BreakoutEnv(EnvALE):
    """`EnvALE` preconfigured with the Breakout ROM and default settings."""

    def __init__(self):
        # The ROM path is resolved at construction time
        from ale_py.roms import Breakout as breakout_rom
        super().__init__(breakout_rom)

    def _observe(self):
        """Return the raw RGB screen (placeholder for Blob-PROS features).

        NOTE(review): `EnvALE.__init__` assigns `self._observe` on the
        instance, which takes precedence over this class-level override.
        """
        screen_rgb = self.ale.getScreenRGB()
        return screen_rgb

In [10]:
#@title
# Comment out the following line to run this test
%%script echo Skipped this cell because it is a manual test.

import random

env = BreakoutEnv()

# Play an episode with uniformly random actions, recording every step
done, observation = False, env.reset(do_record=True)
while not done:
    #action = random.choice(env.action_space)
    # Integer actions are used as indices into env.action_space by step()
    action = random.choice(range(4))
    observation, reward, done, info = env.step(action)
# Display the episode recorded above
env.show_video()

Skipped this cell because it is a manual test.


Sarsa implementation


In [11]:
# Charles' commented out pieces of code

# install dependencies
# !pip install torch torchvision pyvirtualdisplay matplotlib seaborn pandas numpy pathlib gym
# env
# import gym
# import pandas as pd
# torch stuff
# import torch
# import torch.nn as nn
# import torch.nn.functional as F 
# from torch import optim

In [12]:
#### Learning-related imports ####

# Built-in libraries
# from typing import Sequence, Tuple, Dict, Any, Optional

# Pypi libraries
# import numpy as np
# import matplotlib.pyplot as plt

In [13]:
# DON'T USE!!!

# class SarsaPhiEB:
#     def __init__(self):
#         pass
    
#     def generate_episode

# def run_episodes(env, Q, num_episodes=100, to_print=False):
#     '''
#     Run some episodes to test the policy
#     '''
#     tot_rew = []
#     state = env.reset()

#     for _ in range(num_episodes):
#         done = False
#         game_rew = 0

#         while not done:
#             # select a greedy action
#             next_state, rew, done, _ = env.step(greedy(Q, state))

#             state = next_state
#             game_rew += rew 
#             if done:
#                 state = env.reset()
#                 tot_rew.append(game_rew)
#     return tot_rew

# #eps-greedy
# def eps_greedy(Q, s, eps=0.1):
#     '''
#     Epsilon greedy policy
#     '''
#     if np.random.uniform(0,1) < eps:
#         # Choose a random action
#         return np.random.randint(Q.shape[1])
#     else:
#         # Choose the action of a greedy policy
#         return greedy(Q, s)

# #formatting action for ale env
# def formatAction(action):
#   return action

# #sarsa
# def SARSA(env, lr=0.01, num_episodes=10000, eps=0.3, gamma=0.95, eps_decay=0.00005):

#     nA = len(env.action_space)
#     nS = 1
#     for shape in env.ale.getScreenRGB().shape:
#       nS = nS*shape

#     # Initialize the Q matrix
#     # Q: nS x nA matrix where each row represents a state and each column represents a different action
#     Q = np.zeros((nS, nA))
#     games_reward = []
#     test_rewards = []

#     for ep in range(num_episodes):
#         state = env.reset()
#         done = False
#         tot_rew = 0

#         # decay the epsilon value until it reaches the threshold of 0.01
#         if eps > 0.01:
#             eps -= eps_decay


#         action = eps_greedy(Q, state, eps)
#         move = formatAction(action)
        

#         # loop the main body until the environment stops
#         while not done:
#             next_state, rew, done, _ = env.step(move) # Take one step in the environment

#             # choose the next action (needed for the SARSA update)
#             next_action = eps_greedy(Q, next_state, eps) 
#             # SARSA update
#             Q[state][action] = Q[state][action] + lr*(rew + gamma*Q[next_state][next_action] - Q[state][action])

#             state = next_state
#             action = next_action
#             move = formatAction(action)
#             tot_rew += rew
#             if done:
#                 games_reward.append(tot_rew)

#         # Test the policy every 300 episodes and print the results
#         if (ep % 300) == 0:
#             test_rew = run_episodes(env, Q, 1000)
#             print("Episode:{:5d}  Eps:{:2.4f}  Rew:{:2.4f}".format(ep, eps, test_rew))
#             test_rewards.append(test_rew)

#     return Q