<a href="https://colab.research.google.com/github/Tinynja/Sarsa-phi-EB/blob/main/notebooks/ALE_Framework_Tests.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import sys

if 'google.colab' in sys.modules:
    !rm -rf *
    !git clone https://github.com/Tinynja/Sarsa-phi-EB
    !mv Sarsa-phi-EB/* .
    !rm -rf Sarsa-phi-EB
    # DON'T install packages defined in Pipfile_colab_remove
    !sed -ri "/$(tr '\n' '|' < Pipfile_Colab_exclude)/d" Pipfile
else:
    print('Skipping GitHub cloning since not running in Colab.')

Cloning into 'Sarsa-phi-EB'...
remote: Enumerating objects: 277, done.[K
remote: Counting objects: 100% (277/277), done.[K
remote: Compressing objects: 100% (228/228), done.[K
remote: Total 277 (delta 78), reused 200 (delta 37), pack-reused 0[K
Receiving objects: 100% (277/277), 708.29 KiB | 6.00 MiB/s, done.
Resolving deltas: 100% (78/78), done.


In [None]:
# Install required dependencies
import os

if 'google.colab' in sys.modules:
    # Colab doesn't support pipenv, hence we convert Pipfile into requirements.txt
    if 'requirements_Colab.txt' not in os.listdir():
        !pip install pipenv
        !pipenv lock -r > requirements.txt
    !pip install -r requirements_Colab.txt 1> /dev/null
else:
    !pipenv install 1> /dev/null

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
yellowbrick 1.3.post1 requires numpy<1.20,>=1.16.0, but you have numpy 1.21.4 which is incompatible.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.
arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.0.1 which is incompatible.
albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.[0m


In [None]:
# Import all supported ROMs into ALE
!ale-import-roms ROMS

[92m[SUPPORTED]    [0m                qbert        ROMS/Q. Bert (1983).bin
[92m[SUPPORTED]    [0m              hangman ROMS/Hangman - Spelling (1978).bin
[92m[SUPPORTED]    [0m       miniature_golf ROMS/Miniature Golf - Arcade Golf (1979).bin
[92m[SUPPORTED]    [0m       space_invaders ROMS/Space Invaders (1980).bin
[92m[SUPPORTED]    [0m             defender       ROMS/Defender (1982).bin
[92m[SUPPORTED]    [0m               pooyan         ROMS/Pooyan (1983).bin
[92m[SUPPORTED]    [0m            ms_pacman    ROMS/Ms. Pac-Man (1983).bin
[92m[SUPPORTED]    [0m        video_pinball ROMS/Video Pinball - Arcade Pinball (1981).bin
[92m[SUPPORTED]    [0m        wizard_of_wor  ROMS/Wizard of Wor (1982).bin
[92m[SUPPORTED]    [0m             surround ROMS/Surround - Chase (Blockade) (1977).bin
[92m[SUPPORTED]    [0m              pitfall ROMS/Pitfall! - Pitfall Harry's Jungle Adventure (Jungle Runner) (1982).bin
[92m[SUPPORTED]    [0m              freeway        ROMS/Fre

In [None]:
#### ALE-related imports ####

# Built-in libraries
import re
import sys
import base64
import pickle
import random
import subprocess
from pathlib import Path

# Pypi libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython import display as ipythondisplay
from ale_py import ALEInterface, SDL_SUPPORT
import ale_py.roms as ROMS


# Configuration
SEED = 0  # global seed for the Python, NumPy and PyTorch RNGs so runs are reproducible
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

  for external in metadata.entry_points().get(self.group, []):


In [None]:
class features:
    """Hand-crafted feature extractors for ALE screens (static methods only)."""

    @staticmethod
    def basic(frame, palette, background=None, crop_size=torch.Tensor([15,10])):
        """Binary "Basic" features: which palette colours appear in each screen tile.

        The frame is split into a grid of non-overlapping tiles of ``crop_size``
        (height, width) pixels. For every tile and every palette colour, one boolean
        says whether at least one foreground pixel of that tile has that colour.
        Pixels equal to ``background`` are treated as background and ignored.

        Args:
            frame: (H, W, C) image tensor (e.g. the RGB screen).
            palette: (n_colors, C) tensor of the colours to look for.
            background: optional (H, W, C) background image; when ``None`` every
                pixel counts as foreground.
            crop_size: (tile_height, tile_width); must divide (H, W) exactly.

        Returns:
            1-D bool tensor of length ``n_tiles * n_colors``: tiles in row-major
            order, and for each tile the colours in palette order.

        Raises:
            TypeError: if the frame cannot be tiled exactly by ``crop_size``.
        """
        crop_h, crop_w = (int(v) for v in torch.as_tensor(crop_size).flatten().tolist())
        height, width = frame.shape[:2]
        if height % crop_h or width % crop_w:
            raise TypeError(f'crop_size {(crop_h, crop_w)} must evenly divide the frame size {(height, width)}')
        # (H, W, n_colors): does pixel (y, x) have palette colour k?
        colors_in_pixels = (frame.unsqueeze(-2) == palette).all(-1)
        if background is not None:
            # Background subtraction: only pixels that differ from the background count.
            # Compare rather than subtract so unsigned (uint8) frames cannot wrap around.
            is_foreground = (frame != background).any(-1)
            colors_in_pixels = colors_in_pixels & is_foreground.unsqueeze(-1)
        n_colors = colors_in_pixels.shape[-1]
        # Split BOTH spatial axes so each tile is a real crop_h x crop_w rectangle:
        # (H//h, h, W//w, w, n_colors)
        tiles = colors_in_pixels.reshape(height // crop_h, crop_h, width // crop_w, crop_w, n_colors)
        # Logical OR inside each tile -> (H//h, W//w, n_colors)
        tile_features = tiles.any(3).any(1)
        # Flatten the features
        return tile_features.flatten()

    @staticmethod
    def b_pros(frame, palette, crop_size=torch.Tensor([15,10])):
        """Placeholder for B-PROS features, meant to build on ``features.basic``.

        Raises:
            NotImplementedError: always -- this extractor has not been written yet.
        """
        # TODO: compute `features.basic(frame, palette, background, crop_size=crop_size)`
        # and presumably derive pairwise colour-offset features from it.
        raise NotImplementedError('B-PROS features are not implemented yet')

In [None]:
class EnvALE:
    """Wrapper around ``ale_py.ALEInterface`` with a Gym-like ``reset``/``step`` API.

    Observations depend on ``feature_type``:
      * ``'ScreenRGB'``       -- the RGB screen as a torch tensor on ``device``;
      * ``'ScreenGrayscale'`` -- the grayscale screen as a torch tensor on ``device``;
      * ``'Basic'``           -- ``features.basic`` computed from the RGB screen, the
                                 ROM's colour palette and its estimated background.

    The background is loaded from ``./backgrounds/<rom_name>.pickle`` when present;
    otherwise it is estimated by playing random actions and cached to that path.
    Episodes can be recorded as PNG frames and encoded into ``<out_dir>/record.mp4``.
    """

    def __init__(self, rom, out_dir='ale-results', display=False, seed=0, feature_type='ScreenRGB',
                 regen_bg=False, bg_samples=18000):
        """
        Args:
            rom: ``pathlib.Path`` of the ROM (its stem names the cached background).
            out_dir: directory where recordings are written (created if missing).
            display: open an SDL window with sound (ignored on Colab or without SDL).
            seed: seed for ALE and for the random actions used to estimate the background.
            feature_type: ``'ScreenRGB'``, ``'ScreenGrayscale'`` or ``'Basic'``.
            regen_bg: re-estimate the background even if a cached one exists.
            bg_samples: number of environment steps used to estimate the background.

        Raises:
            NotImplementedError: for an unsupported ``feature_type``.
        """
        self.rom = rom
        self.rom_name = rom.stem
        self.out_dir = Path(out_dir).resolve()
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.feature_type = feature_type
        # Dedicated RNG so background estimation is reproducible for a given seed
        self._rng = random.Random(seed)

        self.ale = ALEInterface()
        self.ale.setInt("random_seed", seed)
        # Settings are applied before loadROM, which is when ALE reads them
        if display and SDL_SUPPORT and 'google.colab' not in sys.modules:
            self.ale.setBool("sound", True)
            self.ale.setBool("display_screen", True)
        self.ale.loadROM(rom)

        self.action_space = self.ale.getMinimalActionSet()
        self.color_palette = self._get_color_palette().to(device)

        self.bg_path = Path(f'./backgrounds/{self.rom_name}.pickle')
        if regen_bg or not self.bg_path.is_file():
            self.background = self._get_background(n_samples=bg_samples)
        else:
            # Only load trusted files: unpickling can execute arbitrary code
            with open(self.bg_path, 'rb') as file:
                self.background = pickle.load(file).to(device)

        self._set_observe_method(feature_type)

        # Default episode/recording state (reset() re-initialises it for each episode)
        self._timestep = 0
        self._do_record = False
        self._record_padding = None

    def reset(self, do_record=False):
        """Start a new episode and return its first observation.

        Args:
            do_record: save every screen of this episode; the video is written to
                ``<out_dir>/record.mp4`` once the game is over.
        """
        self.ale.reset_game()
        observation = self._observe()
        self._timestep = 0

        self._do_record = do_record
        self._handle_recording()

        return observation

    def step(self, action, repeat=5):
        """Apply `action` for up to `repeat` emulator frames (frame skipping).

        Args:
            action: index into ``self.action_space`` (Python or NumPy integer), or an
                ALE action used as-is.
            repeat: number of frames to repeat the action; stops early at game over.

        Returns:
            ``(observation, reward, done, info)``: ``reward`` is the sum of the rewards
            collected over the repeated frames and ``info`` is always ``None``.
        """
        if isinstance(action, (int, np.integer)):
            action = self.action_space[action]
        reward = 0
        for _ in range(repeat):
            reward += self.ale.act(action)
            # Acting after the game has ended would be meaningless
            if self.ale.game_over():
                break
        observation = self._observe()
        done = self.ale.game_over()
        self._timestep += 1

        self._handle_recording()

        return observation, reward, done, None

    def show_video(self, scale=1):
        """Show the recorded episode (``<out_dir>/record.mp4``) as an inline HTML video.

        Args:
            scale: multiplier applied to the 300px display height.
        """
        filepath = self.out_dir.joinpath('record.mp4')
        video_b64 = base64.b64encode(filepath.read_bytes()).decode('ascii')
        height = round(300 * scale)
        html = f'''<video alt="{filepath}" autoplay loop controls style="height:{height}px">
                        <source src="data:video/mp4;base64,{video_b64}" type="video/mp4" />
                   </video>'''
        ipythondisplay.display(ipythondisplay.HTML(data=html))

    def _set_observe_method(self, feature_type):
        """Bind ``self._observe`` to a callable producing observations of `feature_type`.

        Raises:
            NotImplementedError: for an unknown `feature_type`.
        """
        screen_rgb = lambda: torch.from_numpy(self.ale.getScreenRGB()).to(device)
        if feature_type == 'ScreenRGB':
            self._observe = screen_rgb
        elif feature_type == 'ScreenGrayscale':
            self._observe = lambda: torch.from_numpy(self.ale.getScreenGrayscale()).to(device)
        elif feature_type == 'Basic':
            self._observe = lambda: features.basic(frame=screen_rgb(),
                                                   palette=self.color_palette,
                                                   background=self.background)
        else:
            raise NotImplementedError(f'Feature type `{feature_type}` is not supported')

    def _observe(self):
        # Replaced per instance by _set_observe_method
        raise NotImplementedError()

    def _get_color_palette(self):
        """Load the colour palette used by this ROM from ``palettes/<name>_Palette.pickle``.

        The palette name is read from the stderr log of loading the ROM in a child
        interpreter.
        """
        code = f'__import__("ale_py").ALEInterface().loadROM({str(self.rom)!r})'
        # Same interpreter as the notebook, so the child sees the same ale_py install;
        # `!r` quotes/escapes the path safely
        result = subprocess.run([sys.executable, '-c', code], capture_output=True)
        # NOTE(review): assumes the palette name is the last token of the 7th stderr line
        # of ALE's ROM-loading log -- brittle across ale_py versions; confirm.
        palette_name = result.stderr.decode().splitlines()[6].strip().split()[-1]
        with open(f'palettes/{palette_name}_Palette.pickle', 'rb') as file:
            palette = pickle.load(file)
        return palette

    def _get_background(self, n_samples):
        """Estimate the static background by playing random actions.

        For every pixel, counts how often each palette colour appears over `n_samples`
        steps and keeps the most frequent one. The resulting image is pickled (on CPU)
        to ``self.bg_path`` and returned on ``device``.
        """
        # The histogram is over RGB palette entries, so always sample RGB screens
        self._set_observe_method('ScreenRGB')
        try:
            sample_i = 0
            pixel_histogram = torch.zeros((*self.ale.getScreenDims(), self.color_palette.shape[0]),
                                          dtype=torch.long, device=device)
            while sample_i < n_samples:
                done = False
                self.reset()
                while not done and sample_i < n_samples:
                    if not sample_i % 10:
                        print(f'\rGenerating background... {sample_i}/{n_samples} samples ({sample_i/n_samples:.0%})', end='')
                    action = self._rng.choice(self.action_space)
                    # `observation` is already a torch tensor on `device`
                    observation, reward, done, info = self.step(action)
                    colors_in_pixels = (observation.unsqueeze(-2) == self.color_palette).all(-1)
                    pixel_histogram += colors_in_pixels
                    sample_i += 1
            print()
        finally:
            # Restore the observation method requested by the user
            self._set_observe_method(self.feature_type)
        background_ids = pixel_histogram.argmax(dim=-1)
        background = self.color_palette[background_ids]

        self.bg_path.parent.mkdir(exist_ok=True)
        with open(self.bg_path, 'wb') as file:
            pickle.dump(background.cpu(), file)

        return background

    def _handle_recording(self):
        """Save the current screen as a PNG when recording; build the video at game over."""
        # Do nothing if not asked to record
        if not self._do_record:
            return
        record_dir = self.out_dir.joinpath('record')
        # Timestep 0 is a new episode: delete frames recorded for a previous one
        if not self._timestep:
            record_dir.mkdir(exist_ok=True)
            for step_png in record_dir.glob('step_*.png'):
                step_png.unlink()
            self._record_padding = None
        # Record current timestep png
        self.ale.saveScreenPNG(str(record_dir.joinpath(f'step_{self._timestep}.png')))
        # At game over, zero-pad all filenames to the same width so ffmpeg's `%0Nd`
        # input pattern matches every frame, then encode the video
        if self.ale.game_over():
            self._record_padding = len(str(self._timestep))
            self._standardize_record_padding()
            self._png_to_mp4()

    def _standardize_record_padding(self):
        """Rename ``step_<t>.png`` frames to ``step_<t zero-padded to _record_padding>.png``."""
        number_pattern = re.compile(r'\d+')
        # Materialise the listing first so renamed files are not visited again
        for png in list(self.out_dir.glob('record/step_*.png')):
            timestep = int(number_pattern.search(png.stem).group(0))
            png.rename(png.with_name(f'step_{timestep:0{self._record_padding}d}.png'))

    def _png_to_mp4(self, fps=60):
        """Encode the recorded PNG frames into ``<out_dir>/record.mp4`` with ffmpeg.

        Args:
            fps: input frame rate. Each PNG is one ``step`` (which may repeat the action
                over several emulator frames), so playback may be faster than real time.

        Raises:
            subprocess.CalledProcessError: if ffmpeg fails.
        """
        in_dir = self.out_dir.joinpath('record')
        in_pattern = in_dir.joinpath(f'step_%0{self._record_padding}d.png')
        out_file = self.out_dir.joinpath('record.mp4')
        # Argument list (no shell) so paths containing spaces are passed intact
        subprocess.run(['ffmpeg', '-hide_banner', '-loglevel', 'error', '-r', str(fps),
                        '-i', str(in_pattern), '-vcodec', 'libx264', '-crf', '25',
                        '-pix_fmt', 'yuv420p', '-y', str(out_file)],
                       cwd=in_dir, check=True)


In [None]:
#@title
# Uncomment the following line to regenerate backgrounds
%%script echo Skipped background regeneration.

device = 'cuda'
from ale_py.roms import *
games_to_generate_bg = [Breakout, MontezumaRevenge, Venture, Qbert, Frostbite, Freeway]

for game in games_to_generate_bg:
    print(game.stem)
    env = EnvALE(game, regen_bg=True)
    plt.imshow(env.background.cpu().numpy())
    plt.show()

Skipped background regeneration.


In [None]:
#@title
# Uncomment the following line to display all stored backgrounds
%%script echo Skipped displaying stored backgrounds to reduce ouptuts.

for filepath in Path('backgrounds').iterdir():
    print(f'Background in `{filepath.resolve()}`')
    with open(filepath, 'rb') as file:
        bg = pickle.load(file)
    plt.imshow(bg)
    plt.show()

Skipped displaying stored backgrounds to reduce ouptuts.


In [None]:
#@title
# Uncomment the following line to run this test
%%script echo Skipped manual test.

import random

# Init environment
env = EnvALE(ROMS.Breakout, feature_type='ScreenRGB')

# Play an episode
done, observation = False, env.reset(do_record=True)
while not done:
    action = random.choice(env.action_space)
    observation, reward, done, info = env.step(action)
    print(f'\rTimestep {env._timestep}...', end='')

# Show episode
env.show_video()

Skipped manual test.


Sarsa implementation


In [None]:
# Charles' commented out pieces of code

# install dependencies
# !pip install torch torchvision pyvirtualdisplay matplotlib seaborn pandas numpy pathlib gym
# env
# import gym
# import pandas as pd
# torch stuff
# import torch
# import torch.nn as nn
# import torch.nn.functional as F 
# from torch import optim

In [None]:
#### Learning-related imports ####

# Built-in libraries
# from typing import Sequence, Tuple, Dict, Any, Optional

# Pypi libraries
# import numpy as np
# import matplotlib.pyplot as plt

In [None]:
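# DON'T USE!!! Draft tabular SARSA kept for reference only: it calls an undefined
# `greedy` helper and indexes a NumPy Q-table directly with raw screen observations.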

# class SarsaPhiEB:
#     def __init__(self):
#         pass
    
#     def generate_episode

# def run_episodes(env, Q, num_episodes=100, to_print=False):
#     '''
#     Run some episodes to test the policy
#     '''
#     tot_rew = []
#     state = env.reset()

#     for _ in range(num_episodes):
#         done = False
#         game_rew = 0

#         while not done:
#             # select a greedy action
#             next_state, rew, done, _ = env.step(greedy(Q, state))

#             state = next_state
#             game_rew += rew 
#             if done:
#                 state = env.reset()
#                 tot_rew.append(game_rew)
#     return tot_rew

# #eps-greedy
# def eps_greedy(Q, s, eps=0.1):
#     '''
#     Epsilon greedy policy
#     '''
#     if np.random.uniform(0,1) < eps:
#         # Choose a random action
#         return np.random.randint(Q.shape[1])
#     else:
#         # Choose the action of a greedy policy
#         return greedy(Q, s)

# #formatting action for ale env
# def formatAction(action):
#   return action

# #sarsa
# def SARSA(env, lr=0.01, num_episodes=10000, eps=0.3, gamma=0.95, eps_decay=0.00005):

#     nA = len(env.action_space)
#     nS = 1
#     for shape in env.ale.getScreenRGB().shape:
#       nS = nS*shape

#     # Initialize the Q matrix
#     # Q: matrix nS*nA where each row represent a state and each colums represent a different action
#     Q = np.zeros((nS, nA))
#     games_reward = []
#     test_rewards = []

#     for ep in range(num_episodes):
#         state = env.reset()
#         done = False
#         tot_rew = 0

#         # decay the epsilon value until it reaches the threshold of 0.01
#         if eps > 0.01:
#             eps -= eps_decay


#         action = eps_greedy(Q, state, eps)
#         move = formatAction(action)
        

#         # loop the main body until the environment stops
#         while not done:
#             next_state, rew, done, _ = env.step(move) # Take one step in the environment

#             # choose the next action (needed for the SARSA update)
#             next_action = eps_greedy(Q, next_state, eps) 
#             # SARSA update
#             Q[state][action] = Q[state][action] + lr*(rew + gamma*Q[next_state][next_action] - Q[state][action])

#             state = next_state
#             action = next_action
#             move = formatAction(action)
#             tot_rew += rew
#             if done:
#                 games_reward.append(tot_rew)

#         # Test the policy every 300 episodes and print the results
#         if (ep % 300) == 0:
#             test_rew = run_episodes(env, Q, 1000)
#             print("Episode:{:5d}  Eps:{:2.4f}  Rew:{:2.4f}".format(ep, eps, test_rew))
#             test_rewards.append(test_rew)

#     return Q