# Install Libs

In [29]:
!pip install gym[atari,accept-rom-license]
ATARI_LIVES_KEY = 'lives'
%load_ext tensorboard

Defaulting to user installation because normal site-packages is not writeable
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


#コード

## ActorCritic Model



In [30]:
import numpy as np
import tensorflow as tf

from tensorflow.keras import layers

import tensorflow as tf

class RunningMean(tf.keras.metrics.Metric):
  def __init__(self, gamma, name):
    super().__init__(name=name)
    self.gamma = gamma
    self.v = self.add_weight('v', initializer='zeros', dtype=tf.float32)
    self.count = self.add_weight('count', initializer='zeros', dtype=tf.float32)

  def update_state(self, value):
      self.v.assign( value + self.gamma*(self.v-value) )
      return self.count.assign_add(1)

  def result(self):
      return self.v/(1-self.gamma**self.count)


class ActorCritic(tf.keras.Model):

    def __init__(self, network, action_space,  dims=0, value_coef=0.5, entropy_coef=0.01, norm_advs=1, **kwargs):
        super().__init__(**kwargs)

        self.action_space = action_space


        self.norm_advs=norm_advs
        self.value_coef = value_coef

        self.entropy_coef = entropy_coef

        self.network = network

        self.to_values = layers.Dense(1, kernel_initializer=tf.initializers.Orthogonal(1.0) )
        self.to_actions = layers.Dense(action_space, kernel_initializer=tf.initializers.Orthogonal(0.01) )

        self.tracker_loss = RunningMean(0.99, name="loss")
        self.tracker_entropy = RunningMean(0.99, name="entropy")
        self.tracker_policy_loss = RunningMean(0.99, name="policy_loss")
        self.tracker_value_loss = RunningMean(0.99, name="value_loss")


    def sample(self, states):
        # values, policy = self(states, training=False)
        # return values.numpy(), policy.numpy()
        pred = self.predict_on_batch(states)
        return pred[0], pred[1]

    @tf.function
    def call(self, x, training=None):
        x = self.network(x, training=training )
        actions = self.to_actions(x)
        values = self.to_values(x)
        return values, actions

    @tf.function
    def train_step(self, data):
        n = self.optimizer.iterations
        try:
            advantages = data[0][2]
            # Normalize advantages
            if self.norm_advs==1:
                advantages = advantages / (tf.math.reduce_std(advantages)+1e-8)
            elif self.norm_advs==2:
                advantages = (advantages - tf.math.reduce_mean(advantages))/ (tf.math.reduce_std(advantages)+1e-8)

            results = self._train_step((data[0][0],data[0][1],advantages,data[0][3],data[0][4]),
                                       data[1], self.value_coef, self.entropy_coef)

        except Exception as e:
            raise e

        loss = results['loss']
        policy_loss = results['policy_loss']
        value_loss = results['value_loss']
        entropy = results['entropy']

        self.tracker_loss.update_state(loss)
        self.tracker_entropy.update_state(entropy)
        self.tracker_policy_loss.update_state(policy_loss)
        self.tracker_value_loss.update_state(value_loss)

        metrics = {
            "loss": self.tracker_loss.result(),
            "value_loss": self.tracker_value_loss.result(),
            "policy_loss": self.tracker_policy_loss.result(),
            "entropy": self.tracker_entropy.result(),
        }

        return metrics

    def _train_step(self, X, Y, value_coef, entropy_coef):
        print('!!')
        return





### A2C

In [31]:
class A2C(ActorCritic):
    def __init__(self, *args, clip_global_norm=0.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.clip_global_norm = clip_global_norm

    @tf.function
    def _train_step(self, X, Y, value_coef,  entropy_coef):

        states = X[0]
        actions = X[1]

        rewards = Y

        EPS = 1e-8
        def compute_logprob(logits, actions):
            probs = tf.nn.softmax(logits)
            actions_onehot = tf.one_hot(actions, self.action_space, dtype=tf.float32)
            return tf.math.log(tf.reduce_sum(actions_onehot * probs, axis=1) + EPS)

        with tf.GradientTape() as tape:
            values, logits = self(states, training=True)
            values = tf.squeeze(values)

            advantages = rewards - values

            action_log_probs = compute_logprob(logits, actions)

            action_probs = tf.nn.softmax(logits)

            policy_loss = action_log_probs * tf.stop_gradient(advantages)

            entropy = -1 * tf.reduce_sum( action_probs * tf.math.log(action_probs + EPS), axis=1, keepdims=True)

            value_loss = advantages ** 2

            policy_loss = tf.reduce_mean(policy_loss)
            value_loss = tf.reduce_mean(value_loss)
            entropy = tf.reduce_mean(entropy)
            loss  = value_coef * value_loss
            loss -= policy_loss + (entropy_coef * entropy)

        grads = tape.gradient(loss, self.trainable_variables)
        grads, _ = tf.clip_by_global_norm(grads, self.clip_global_norm)
        self.optimizer.apply_gradients( zip(grads, self.trainable_variables))

        return {
            "loss": loss,
            "value_loss": value_loss,
            "policy_loss": policy_loss,
            "entropy": entropy,
        }


### PPO

In [32]:
class PPO(ActorCritic):
    def __init__(self, *args, pi_clip_range=0.2, v_clip_range=None, clip_global_norm=0.5,**kwargs):
        super().__init__(*args, **kwargs)
        self.pi_clip_range = pi_clip_range
        self.v_clip_range = v_clip_range
        self.clip_global_norm = clip_global_norm

    @tf.function
    def _train_step(self, X, Y,value_coef,  entropy_coef):
        states = X[0]
        actions = X[1]
        advantages = tf.squeeze(X[2])
        old_logits = X[3]
        old_values = X[4]
        rewards = tf.squeeze(Y)

        EPS = 1e-8

        def compute_logprob(logits, actions):
            probs = tf.nn.softmax(logits)
            actions_onehot = tf.one_hot(actions, self.action_space, dtype=tf.float32)
            return tf.math.log(tf.reduce_sum(actions_onehot * probs, axis=1) + EPS)


        old_logprob = compute_logprob(old_logits, actions)

        with tf.GradientTape() as tape:

            values, logits = self(states, training=True)

            action_probs = tf.nn.softmax(logits)

            new_logprob = compute_logprob(logits, actions)

            values = tf.squeeze(values)

            value_loss = (values-rewards) ** 2
            if self.v_clip_range is not None:
                values_clipped = tf.clip_by_value(values, old_values-self.v_clip_range, old_values+self.v_clip_range)
                value_loss_clipped = (values_clipped - rewards)**2
                value_loss = tf.maximum( value_loss_clipped, value_loss )


            entropy = -1 * tf.reduce_sum( action_probs * tf.math.log(action_probs + EPS), axis=1, keepdims=True)

            ratio = tf.exp(new_logprob - old_logprob)
            ratio_clipped = tf.clip_by_value(ratio, 1 - self.pi_clip_range, 1 + self.pi_clip_range)
            loss_unclipped = ratio * advantages
            loss_clipped = ratio_clipped * advantages
            policy_loss = tf.minimum(loss_unclipped, loss_clipped)

            policy_loss = tf.reduce_mean(policy_loss)
            value_loss = tf.reduce_mean(value_loss)
            entropy = tf.reduce_mean(entropy)
            loss  = value_coef * value_loss                # minimize
            loss -= policy_loss + (entropy_coef * entropy)      # maximize


        grads = tape.gradient(loss, self.trainable_variables)
        grads, _ = tf.clip_by_global_norm(grads, self.clip_global_norm)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))


        return {
            "loss": loss,
            "value_loss": value_loss,
            "policy_loss": policy_loss,
            "entropy": entropy,
        }



##Agent

In [33]:
from collections import deque

class Agent:
    def __init__(self, agent_id, env,trajectory_length, hist_len=100, e_info=None):
        self.agent_id = agent_id
        self.env = env
        self._state = None
        self._trajectory_length = trajectory_length
        self._reward = 0
        self._reset_trajectory()
        self._e_info = e_info


    def _reset_trajectory(self):
        length = self._trajectory_length
        self.trajectory = {"s": np.empty((length,)+self.env.observation_space.shape, dtype=self.env.observation_space.dtype),
                           "a": np.empty((length), dtype=np.int32),
                           "r": np.empty((length), dtype=np.float32),
                           "p": np.empty((length, self.env.action_space.n), dtype=np.float32),
                           "v": np.empty((length), dtype=np.float32),
                           "dones": np.empty((length), dtype=np.int32)
                           }
        self.count = 0

    def reset_env(self):
        self._state = self.env.reset()
        return self._state

    def step_by_value_and_prob(self, value, logit):
        self.trajectory["v"][self.count]=value
        self.trajectory["p"][self.count]=logit

        def softmax(x):
            x = np.exp(x - np.max(x))
            return x / np.sum(x)

        prob = softmax(logit)
        action = np.random.choice( self.env.action_space.n, p=prob)

        return self.step(action)

    def step(self, action):
        state = self._state
        next_state, reward, _, _, info, done = self.env.step(action)

        self.trajectory["s"][self.count]=state
        self.trajectory["a"][self.count]=action
        self.trajectory["r"][self.count]=reward
        self.trajectory["dones"][self.count]=done

        self._reward += reward
        if done:
            next_state = self.env.reset()
            if ATARI_LIVES_KEY in info: # Atari
                if "episode" in info:
                    e_info = info["episode"]
                    e_info.update( {"reward":self._reward} )
                    self._e_info.append(e_info)
                    self._reward = 0
            else:
                self._e_info.append({"score":self._reward,"reward":self._reward})
                self._reward = 0

        self._state = next_state
        self.count+=1
        return self._state

    def collect_trajectory(self):
        trajectory = self.trajectory
        self._reset_trajectory()
        return trajectory


import math
import threading

class EpisodesInfo():
    def __init__(self):
        self._scores = []
        self._rewards = []
        self._truncated_count = 0
        self._episode_count = 0
        self._lock = threading.Lock()

    def append(self, info):
        with self._lock:
            self._episode_count += 1
            # print(epinfo)
            if 'truncated' not in info:
                self._scores.append(info["score"])
                self._rewards.append(info["reward"])
            else:
                self._truncated_count += 1

    def get_info(self, num_hist=100):
        with self._lock:
            score_max = 0
            score_mean = 0
            score_min = 0


            scores = self._scores[-num_hist:]
            if len(scores)==0:
                return { 'episodes': 0, 'reward_mean':0, 'score_mean': 0 }

            rewards = self._rewards[-num_hist:]

            return {
                'episodes': self._episode_count,
                'truncated': self._truncated_count,
                'score_max': max(scores),
                'score_min': min(scores),
                'score_mean': sum(scores)/len(scores),
                'reward_mean': sum(rewards)/len(rewards),
            }

class AgentManager():
    def __init__(self, num_agents, envs, trajectory_length, hist_len=100):
        self._e_info =  EpisodesInfo()
        hist_len = math.ceil(hist_len/num_agents)
        self._agents = [Agent(i, envs[i],trajectory_length, hist_len=hist_len, e_info=self._e_info) for i in range(num_agents)]
        self._states = None

    def reset(self):
        self._states =  [agent.reset_env() for agent in self._agents]

    def step(self, actions):
        self._states = [agent.step(action) for action, agent in zip(actions, self._agents)]

    def step_by_value_and_prob(self, values, action_probs):
        self._states = [agent.step_by_value_and_prob(value, prob) for value, prob, agent in zip(values, action_probs, self._agents)]

    def collect_trajectory(self):
        return [agent.collect_trajectory() for agent in self._agents]

    def get_states(self):
        return np.array(self._states, dtype=np.float32)

    def get_info(self):
        return self._e_info.get_info()


##Training Loop

In [34]:
from pathlib import Path
import shutil
import os
import gym
from gym import wrappers
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tqdm import tqdm

import datetime
def make_network(env_name, obs_shape):
    if env_name.startswith('CartPole'):
        network = tf.keras.models.Sequential([
            layers.Dense(64,activation='relu', input_shape=obs_shape),
            layers.Dense(64,activation='relu')
        ])
    else:
        print(obs_shape)
        envs = []
        network = build_atari_model(obs_shape)
    return network

import random
def set_global_seeds(seed):
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

import time

def learn( env_name, envs, num_agents=4, prev_policy=None, gamma=0.99, lam=0.95,
          entropy_coef=0.01, value_coef=0.5, pi_clip_range=0.1,
          trajectory_length=8, batch_size=None, num_epochs = 1, num_batches=1,
          num_frames=10000, lr=1e-4 , test_play_interval=256,
          mode='a2c', rand_seed=1234, use_tensorboard=True):

    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    logdir = "logs/" + current_time
    if use_tensorboard:
        summary_writer = tf.summary.create_file_writer(str(logdir))
    else:
        summary_writer = None

    monitordir = logdir

    set_global_seeds(rand_seed)

    network = make_network(env_name, envs[0].observation_space.shape)
    action_space = envs[0].action_space.n
    for i in range(len(envs)):
        envs[i].seed(rand_seed)


    test_score = 0
    use_gae = False
    metrics = {}

    try:
        num_samples = num_agents*trajectory_length

        if mode=='ppo':
            actor_critic = PPO(network, action_space=action_space,
                                entropy_coef=entropy_coef, value_coef=value_coef,
                                pi_clip_range=pi_clip_range)
            use_gae = True
        elif mode=='a2c':
            actor_critic = A2C(network, action_space=action_space,
                                entropy_coef=entropy_coef, value_coef=value_coef)
            use_gae = True
            lam = 1.0

        if prev_policy:
            actor_critic.build((None,84,84,4))
            actor_critic.set_weights(prev_policy.get_weights())

        frames_per_iteration = trajectory_length*num_agents

        if batch_size is None:
            batch_size = frames_per_iteration//num_batches


        if type(lr) is tuple:
            lr_decay = lr[1]**(1/num_frames)
            lr = lr[0]
        else:
            lr_decay = (0.1)**(1/num_frames)

        print(mode, num_agents, trajectory_length, num_samples, batch_size, num_batches)
        print(env_name,envs[0].get_action_meanings())
        print(lr, gamma)

        optimizer = tf.keras.optimizers.Adam(lr, epsilon=1e-7)
        actor_critic.compile(optimizer)

        agents = AgentManager(num_agents, envs, trajectory_length)
        agents.reset()

        next_test_point = test_play_interval

        info = None


        current_frames = 0
        test_score, truncated =play_out(envs[-1], actor_critic, video_fps=12, filename='_recording.avi')
        convert_avi_file('_recording.avi', f'{logdir}/{env_name}_{current_frames//1000:05d}_[{test_score}].gif')


        print('Initial Score : ', test_score)
        best_score= -10000000

        pbar = tqdm(total=num_frames)
        iter_count = 0
        for n in range(0, num_frames, frames_per_iteration):
            iter_count+=1

            # playout
            for _ in range(trajectory_length):
                values, action_probs = actor_critic.sample(agents.get_states())
                agents.step_by_value_and_prob(values, action_probs)

            # collect
            trajectories = agents.collect_trajectory()

            # calc rewards
            next_values, _ = actor_critic.sample(agents.get_states())

            (states, actions, discounted_rewards, old_probs, advantages) = [], [], [], [], []
            (dones,values) = [], []
            for idx, trajectory in enumerate(trajectories):
                discounted_Rs = [0.0] * trajectory_length
                ADVs = [0.0] * trajectory_length

                R = next_values[idx][0]

                tmp_rewards = trajectory["r"]
                tmp_dones = trajectory["dones"]
                if not use_gae:
                    for i in reversed(range(trajectory_length)):
                        R = tmp_rewards[i] + gamma * (1 - tmp_dones[i]) * R
                        discounted_Rs[i] = R
                else:
                    last_gae = 0.0
                    tmp_values = trajectory["v"]
                    for i in reversed(range(trajectory_length)):
                        mask = 1 - tmp_dones[i]
                        value = tmp_values[i]

                        delta = tmp_rewards[i] + (mask * gamma * R) - value
                        last_gae = delta + (mask * gamma * lam * last_gae)
                        ADVs[i] = last_gae
                        discounted_Rs[i] = last_gae+value
                        R = value

                discounted_rewards.append( discounted_Rs )
                advantages.append( ADVs )
                states.append( trajectory["s"] )
                actions.append( trajectory["a"] )
                old_probs.append( trajectory["p"] )

                values.append( trajectory["v"] )


            states = np.concatenate(states, dtype=states[0].dtype)
            old_probs = np.concatenate(old_probs, dtype=np.float32)
            actions = np.concatenate(actions, dtype=np.int32)
            discounted_rewards = np.concatenate(discounted_rewards, dtype=np.float32)
            advantages = np.concatenate(advantages, dtype=np.float32)
            values = np.concatenate(values, dtype=np.float32)

            # Update
            optimizer.lr = lr * (lr_decay**n)

            metrics = {}
            reward_mean = 0.0
            for _ in range(num_epochs):
                idxes = [idx for idx in range(len(states))]
                random.shuffle(idxes)

                for i in range(num_batches):
                    idx = idxes[i*batch_size:(i+1)*batch_size]

                    metrics = actor_critic.train_step(
                        ([states[idx], actions[idx], advantages[idx], old_probs[idx], values[idx]],
                         discounted_rewards[idx]))

            #Output Results

            info = agents.get_info()
            reward_mean = info['reward_mean']
            score_mean = info['score_mean']
            values_mean = sum(values)/len(values)

            if summary_writer is not None:
                with summary_writer.as_default():
                    tf.summary.scalar("reward_mean", reward_mean, step=n)
                    tf.summary.scalar("score_mean", score_mean, step=n)
                    tf.summary.scalar("loss", metrics["loss"], step=n)
                    tf.summary.scalar("policy_loss", metrics["policy_loss"], step=n)
                    tf.summary.scalar("value_loss", metrics["value_loss"], step=n)
                    tf.summary.scalar("entropy", metrics["entropy"], step=n)
                    tf.summary.scalar("lr", optimizer.lr, step=n)
                    tf.summary.scalar("values_mean", values_mean, step=n)

            log_str = ', '.join([f'{k}={v:.4f}' for k,v in metrics.items()])
            pbar.set_postfix_str(log_str+ f', score_mean={score_mean:.4f}')
            pbar.update(frames_per_iteration)

            current_frames = n+frames_per_iteration
            if current_frames >= next_test_point*1000 or current_frames>=num_frames:
                if best_score < score_mean:
                    actor_critic.save_weights(logdir+f'/{env_name}_best_weights.h5')
                    best_score=score_mean

                next_test_point += test_play_interval

                test_score, truncated =play_out(envs[-1], actor_critic, video_fps=12, filename='_recording.avi')
                if os.path.exists('_recording.avi'):
                    convert_avi_file('_recording.avi', f'{logdir}/{env_name}_{current_frames//1000:05d}_[{test_score}].gif')

                info_str = ', '.join([f'{k}={v:.2f}' if isinstance(v, float) else f'{k}={v}' for k,v in info.items()])
                print('')
                is_truncated = '(truncated)' if truncated else ''
                print(f'    {info_str}, test_score={test_score}{is_truncated}')


    except KeyboardInterrupt:
        pass
    finally:
        pbar.close()
        if actor_critic:
            actor_critic.save_weights(logdir+f'/{env_name}_last_weights.h5')
    return actor_critic



#



##Atari

### Wrapper

In [35]:
from types import FrameType
import gym
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2


from gym.core import Wrapper
from collections import deque

import math

# https://github.com/MadryLab/implementation-matters
class RewardFilter(gym.core.Wrapper):
    def __init__(self, env, reward_coef=None, clip=1.0, gamma = None,
                 penalty_on_lost=False, penalty_on_no_reward=None):
        super().__init__(env)#, new_step_api=False)
        self.gamma = gamma
        self.clip = clip
        self.reward_coef = reward_coef
        self.penalty_on_no_reward = penalty_on_no_reward
        self.penalty_on_lost = penalty_on_lost
        self._n = 0
        self._M = 0.0
        self._S = 0.0
        self._ret = 0.0
        self._lives = 0

    def step(self, action):
        next_state, reward, _, _, info, done = self.env.step(action)
        reward, done, info = self._filter(reward, done, info)
        return obs, reward, done, info

    def disable_penalty(self):
        self.penalty_on_lost=False
        self.penalty_on_no_reward =None

    def reset(self):
        obs = self.env.reset()

        self._ret = 0.0
        self.no_reward_count = 0
        self._lives = 0
        return obs

    def _filter(self, reward, done, info):
        if self.reward_coef:
            if type(self.reward_coef) is tuple:
                reward = reward * ( self.reward_coef[0] if reward>0 else self.reward_coef[1])
            else:
                reward = self.reward_coef * reward

        if self.gamma:
            self._ret = self._ret * self.gamma + reward
            self._push(self._ret)
            reward = reward / (self._std() + 1e-8)

        if self.clip:
            if type(self.clip)==tuple:
                reward = np.clip(reward, self.clip[0], self.clip[1])
            else:
                reward = np.clip(reward, -self.clip, self.clip)

        if self.penalty_on_lost:
            lives = info[ATARI_LIVES_KEY]
            if self._lives!=0 and lives<self._lives:
                reward -= 1.0
            self._lives = lives

        if self.penalty_on_no_reward is not None:
            if done:
                self.no_reward_count = 0
            else:
                if reward==0:
                    self.no_reward_count += 1
                    # print(self.no_reward_count,self.penalty_on_no_reward[0])
                    if self.no_reward_count>=self.penalty_on_no_reward[0]:
                        reward = self.penalty_on_no_reward[2]
                    if done == False and self.penalty_on_no_reward[1] is not None:
                        done = self.no_reward_count>=self.penalty_on_no_reward[1]
                        if done:
                            info['episode']= {'truncated':True}

                else:
                    self.no_reward_count = 0

        return reward, done, info


    def _push(self, x):
        self._n += 1
        if self._n == 1:
            self._M = x
        else:
            oldM = self._M
            self._M = oldM + (x - oldM) / self._n
            self._S = self._S + (x - oldM) * (x - self._M)

    def _mean(self):
        return self._M
    def _var(self):
        return self._S / (self._n - 1) if self._n > 1 else self._M**2
    def _std(self):
        return math.sqrt(self._var())

class ActionFilter(gym.core.Wrapper):
    def __init__(self, env, action_meanings=None, action_rewards=None):
        super().__init__(env)#, new_step_api=False)
        self.mode_name='train'

        org_meanings = env.get_action_meanings()
        if action_meanings==None:
            action_meanings = org_meanings
        self.action_rewards = {}
        self.act = [0]*len(action_meanings)
        for idx, action in enumerate(action_meanings):
            self.act[idx] = org_meanings.index(action)
            if action_rewards is not None:
                if action in action_rewards:
                    self.action_rewards[self.act[idx]] = action_rewards[action]

        self.action_space = gym.spaces.Discrete(len(self.act))
        self.action_meanings = action_meanings

    def step(self, action):
        action = self.act[action]
        next_state, reward, terminated, truncated, info, done = self.env.step(action)
        if action in self.action_rewards:
            reward += self.action_rewards[action]

        return obs, reward, terminated, truncated, info, done

    def get_action_meanings(self):
        return self.action_meanings



class AtariWrapper(gym.core.Wrapper):
    def __init__(self, env, img_size=(84,84), n_frames=4, n_skips=4,
                 fire_on_reset=0, noop_on_reset=(0,15), done_on_lost=True,
                 done_on_reward=False, obs_merge=1, score_limit=None):
        super().__init__(env)#, #new_step_api=False)
        height, width, n_channels = env.observation_space.shape
        height, width = img_size
        obs_shape = [height, width, n_frames]
        self.observation_space = gym.spaces.Box(0.0, 1.0, obs_shape, dtype=np.uint8)
        self.img_size = img_size
        self.n_frames = n_frames
        self.n_skips = n_skips
        self.lives = 0
        self.true_score = 0
        self.episodes = 0
        self.score_limit = score_limit
        self.fire_on_reset = fire_on_reset
        self.act=None
        self.mode_name='train'
        self.true_reset = True
        self.done_on_reward = done_on_reward
        self.recording = False
        self.obs_merge = obs_merge
        self.done_on_lost = done_on_lost
        self.noop_on_reset = noop_on_reset

    def set_true_reset(self, true_reset=True):
        self.true_reset = true_reset

    def set_recording(self, flag=True):
        self.recording = flag

    def get_raw_frames(self):
        raw_frames = self.raw_frames
        self.raw_frames = []
        return raw_frames

    def reset(self):
        self.no_reward_count = 0
        self.frames = deque(maxlen=self.n_frames)

        # print('reset', self.true_reset)
        if self.true_reset:
            self.true_reset = False
            self.true_score = 0
            self.obs_buffer = deque(maxlen=2)
            self.raw_frames = []
            self._reset()

        for i in range(np.random.randint(low=self.noop_on_reset[0], high=self.noop_on_reset[1], size=1)[0]):
            self._step(0)

        for i in range(self.fire_on_reset):
            self._step(1)


        merged_frame = self._merge_frames()
        for i in range(self.n_frames):
            self.frames.append(merged_frame)

        frame = np.concatenate(self.frames, axis = -1)
        return  frame


    def _reset(self):
        obs = self.env.reset()
        if self.recording:
            self.raw_frames.append(obs)
        self.obs_buffer.append(obs)
        return obs

    def _step(self, action):
        obs, reward, terminated, truncated, info, dont = super().step(action)
        if self.recording:
            self.raw_frames.append(obs)
        self.obs_buffer.append(obs)
        return obs, reward, terminated, truncated, info, dont

    def _preprocess_obs(self, obs):
        obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        obs = cv2.resize(obs, self.img_size)
        obs = obs[:, :, None]
        return obs

    def _merge_frames(self):
        if self.obs_merge == 0 or len(self.obs_buffer)==1:
            frame = self._preprocess_obs(self.obs_buffer[-1])
        elif self.obs_merge == 1:
            buf = [self._preprocess_obs(self.obs_buffer[0]), self._preprocess_obs(self.obs_buffer[1])]
            frame = np.array(buf).max(axis=0)
        elif self.obs_merge == 2:
            buf = [self._preprocess_obs(self.obs_buffer[0]), self._preprocess_obs(self.obs_buffer[1])]
            frame = ((buf[0].astype(np.float32) + buf[1].astype(np.float32))/2).astype(np.uint8)
        return frame

    def _step_pool(self,action):

        n_skips = self.n_skips

        reward = 0.0
        for i in range(n_skips):
            obs, tmp_reward, done, info = self._step(action)
            reward += tmp_reward
            if done:
                break

        merged_frame = self._merge_frames()

        return merged_frame, reward, done, info

    def step(self, action):
        obs, reward, done, info = self._step_pool(action)

        self.true_score += reward
        if self.score_limit is not None:
            if self.true_score >= self.score_limit:
                done = True


        if not done:
            if self.done_on_reward:
                if reward != 0:
                    done = True
            lives = info[ATARI_LIVES_KEY]
            if self.done_on_lost:
                if self.lives!=0 and lives!=self.lives:
                    done = True
            self.lives = lives
        else:
            info['episode'] = {
                'count':self.episodes+1,
                'score':int(self.true_score),
            }

            self.lives = 0
            self.episodes += 1
            self.true_reset = True
        self.frames.append(obs)

        return np.concatenate(self.frames, axis = -1), reward, done, info

# test("MsPacman")


### Customize Atari Env

In [36]:


def build_atari_env_and_param(env_name, num_envs, img_size=(84,84)):
    gamma = 0.99
    envs=[]

    for _ in range(num_envs):
        env = gym.make(env_name+'NoFrameskip-v4')
        if env_name == 'Breakout':
            env = AtariWrapper(env, img_size=img_size, n_skips=4, fire_on_reset=1, score_limit=864)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True, penalty_on_no_reward=(200,1000,-0.01))
            env = ActionFilter(env, ['NOOP', 'LEFT', 'RIGHT'])
            gamma = 0.95
        elif env_name == 'SpaceInvaders':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.01, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['NOOP', 'FIRE', 'LEFT', 'RIGHT'])
        elif env_name == 'CrazyClimber':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.001, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['NOOP', 'UP', 'DOWN', 'LEFT', 'RIGHT'])
        elif env_name == 'Alien':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.01, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['UP', 'DOWN', 'LEFT', 'RIGHT'])
        elif env_name == 'Boxing':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=False)
            env = ActionFilter(env, [ 'FIRE', 'UPRIGHT', 'UPLEFT', 'DOWNRIGHT', 'DOWNLEFT',])
            # env = ActionFilter(env, ['NOOP', 'FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'UPRIGHT', 'UPLEFT', 'DOWNRIGHT', 'DOWNLEFT',])
            gamma = 0.95
        elif env_name == 'Zaxxon':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['FIRE', 'UPRIGHT', 'UPLEFT', 'DOWNRIGHT', 'DOWNLEFT', ] )
        elif env_name == 'Tennis':
            env = AtariWrapper(env, img_size=img_size, n_skips=4,
                               fire_on_reset=0, done_on_reward=True)
            env = RewardFilter(env, clip=1.0, penalty_on_no_reward=(500,500,-1.0))
            env = ActionFilter(env, [ 'FIRE', 'UPFIRE', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE', 'UPRIGHTFIRE', 'UPLEFTFIRE', 'DOWNRIGHTFIRE', 'DOWNLEFTFIRE'])
        elif env_name == 'Seaquest':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.01, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['NOOP', 'FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN'] )
        elif env_name == 'Pong':
            env = AtariWrapper(env, img_size=img_size, n_skips=4,
                               fire_on_reset=1, done_on_reward=True)
            env = RewardFilter(env, clip=1.0)
            env = ActionFilter(env, [ 'NOOP', 'RIGHT', 'LEFT',] )
        elif env_name == 'TimePilot':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.001, clip=3.0, penalty_on_lost=True)
            env = ActionFilter(env, ['FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN', ] )
        elif env_name == 'MsPacman':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.01, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['UP', 'DOWN', 'LEFT', 'RIGHT'])
        elif env_name == 'ChopperCommand':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['FIRE', 'UP', 'DOWN', 'LEFT', 'RIGHT'])
        elif env_name == 'Gopher':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.01, clip=1.0, penalty_on_lost=True, penalty_on_no_reward=(500,1000,-0.1))
            env = ActionFilter(env, ['NOOP', 'FIRE', 'UP', 'LEFT', 'RIGHT'])
        elif env_name == 'Bowling':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=30.0)
            env = ActionFilter(env, ['FIRE', 'UP'])
            gamma = 0.999
        elif env_name == 'Kangaroo':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, [ 'FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN'], {'RIGHT':0.01})
        elif env_name == 'Enduro':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0)
            env = ActionFilter(env,
                ['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'DOWN', 'DOWNRIGHT', 'DOWNLEFT', 'RIGHTFIRE', 'LEFTFIRE'],
                {'FIRE':0.01})
            gamma = 0.95
        elif env_name == 'Freeway':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0)
            env = ActionFilter(env, ['NOOP', 'UP'])
        elif env_name == 'Gravitar':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0)
            env = ActionFilter(env,  ['NOOP', 'FIRE', 'UP', 'RIGHT', 'LEFT'])
        elif env_name == 'BeamRider':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env,  ['NOOP', 'FIRE', 'UP', 'RIGHT', 'LEFT'])
        elif env_name == 'DoubleDunk':
            env = AtariWrapper(env, img_size=img_size, n_skips=4, done_on_reward=True)
            env = RewardFilter(env, clip=1.0, penalty_on_no_reward=(500,5000,-0.1))
            env = ActionFilter(env,
                [ 'NOOP','FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN',  'UPFIRE', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE',])
        elif env_name == 'Robotank':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['FIRE', 'UP', 'RIGHT', 'LEFT', 'UPRIGHT', 'UPLEFT'] )
            gamma=0.995
        elif env_name == 'KungFuMaster':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.001, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env,
                ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE', 'DOWNRIGHTFIRE', 'DOWNLEFTFIRE'],
                {'LEFT':0.01})
        elif env_name == 'MontezumaRevenge':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['UP', 'RIGHT', 'LEFT', 'DOWN', 'RIGHTFIRE', 'LEFTFIRE'])
        elif env_name == 'RoadRunner':
            env = AtariWrapper(env, img_size=img_size, n_skips=4 )
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env,
                ['UP', 'RIGHT', 'LEFT', 'DOWN', 'UPRIGHT', 'UPLEFT', 'DOWNRIGHT', 'DOWNLEFT', 'LEFTFIRE', 'DOWNLEFTFIRE', 'UPLEFTFIRE'],
                {'LEFT':0.001})
            gamm=0.95
        elif env_name == 'Amidar':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['UP', 'RIGHT', 'LEFT', 'DOWN'])
            gamma = 0.995
        elif env_name == 'Asteroids':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.01, clip=1.0, penalty_on_lost=True )
            env = ActionFilter(env, ['NOOP', 'FIRE', 'UP', 'RIGHT', 'LEFT'])
        elif env_name == 'IceHockey':
            env = AtariWrapper(env, img_size=img_size, n_skips=4, done_on_reward=True)
            env = RewardFilter(env, clip=1.0)
            env = ActionFilter(env, ['FIRE', 'UPRIGHT', 'UPLEFT', 'DOWNRIGHT', 'DOWNLEFT', ] )
        elif env_name == 'Frostbite':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.01, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN'] )
        elif env_name == 'FishingDerby':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=(0.0,10.0))
            env = ActionFilter(env, ['NOOP', 'FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN',])
        elif env_name == 'BattleZone':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN','UPRIGHT', 'UPLEFT', 'DOWNRIGHT', 'DOWNLEFT', ] )
        elif env_name == 'WizardOfWor':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True, penalty_on_no_reward=(500,1000,-0.01))
            env = ActionFilter(env, ['FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN', ])
            gamma = 0.995
        elif env_name == 'BankHeist':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['NOOP', 'FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN'] )
        elif env_name == 'PrivateEye':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.1, clip=10.0, penalty_on_no_reward=(200,5000,-0.01))
            env = ActionFilter(env, ['NOOP', 'FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'RIGHTFIRE', 'LEFTFIRE'] )
        elif env_name == 'Asterix':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.01, clip=1.0, penalty_on_lost=True)
            # env = ActionFilter(env, ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN',])
        elif env_name == 'NameThisGame':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.01, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['NOOP', 'FIRE', 'RIGHT', 'LEFT'])
        elif env_name == 'Riverraid':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['NOOP', 'FIRE', 'UPRIGHT', 'UPLEFT', 'DOWNRIGHT', 'DOWNLEFT'])
        elif env_name == 'Jamesbond':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['NOOP', 'FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'UPRIGHT', 'UPLEFT', 'DOWNRIGHT', 'DOWNLEFT'])
        elif env_name == 'Atlantis':
            env = AtariWrapper(env, img_size=img_size, n_skips=4, score_limit=100000)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
        elif env_name == 'Centipede':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.01, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['NOOP', 'FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN'])
        elif env_name == 'Assault':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.01, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['UP', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE'])
        elif env_name == 'StarGunner':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.01, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['NOOP', 'FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'UPRIGHT', 'UPLEFT', 'DOWNRIGHT', 'DOWNLEFT',] )
        elif env_name == 'DemonAttack':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['NOOP', 'FIRE', 'RIGHT', 'LEFT'])
        elif env_name == 'UpNDown':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.01, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['NOOP','FIRE', 'UP', 'DOWN'],{'UP':0.1})
        elif env_name == 'VideoPinball':
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, reward_coef=0.001, clip=1.0, penalty_on_lost=True)
            env = ActionFilter(env, ['NOOP', 'FIRE', 'UP', 'RIGHT', 'LEFT', 'DOWN'])
        else:
            print("Default AtariWrapper")
            env = AtariWrapper(env, img_size=img_size, n_skips=4)
            env = RewardFilter(env, clip=1.0, penalty_on_lost=True)
        envs.append(env)
    return envs, gamma


# test( env, ".")
from PIL import Image, ImageDraw
def test(env_name):
    envs, _ = build_atari_env_and_param(env_name, 1, (84,84))
    env = envs[0]
    obs = env.reset()
    print(env.get_action_meanings())
    print(env.action_space)
    print(env.observation_space)
    imgs = []
    for i in range(90):
        obs, reward, terminated, truncated, info, done = env.step(0)
    for n in range(0,16):
        # print(obs.shape)
        plt.figure(figsize=(16,4))
        for i in range(4):
            img = obs[:,:,i]
            pil_img = Image.fromarray(img)
            plt.subplot(1, 4, i+1)
            plt.imshow(np.asarray(pil_img), cmap='gray',interpolation='none')
            plt.axis('off')
        plt.show()
        action = np.random.randint(0,env.action_space.n)
        action = 0
        obs, reward, done, info = env.step(action)
        # print(reward, action)
        if done:
            obs = env.reset()


import warnings
warnings.simplefilter('ignore', DeprecationWarning)
test("SpaceInvaders")




error: OpenCV(4.5.4) :-1: error: (-5:Bad argument) in function 'cvtColor'
> Overload resolution failed:
>  - src is not a numerical tuple
>  - Expected Ptr<cv::UMat> for argument 'src'


### Network

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models

def conv(x, filters, kernel_size, strides=1, padding='VALID', norm=None, activation='relu', kernel_initializer='glorot_normal'):
    x = layers.Conv2D(filters, kernel_size, strides=strides, padding=padding, use_bias=not norm, kernel_initializer=kernel_initializer)(x)
    if norm=='batch':
        x = layers.BatchNormalization()(x)
    elif norm=='layer':
        x = layers.LayerNormalization()(x)
    x = layers.Activation(activation)(x)
    return x

def dense(x, filters, norm=None, activation='relu', kernel_initializer='glorot_normal'):
    x = layers.Dense(filters, use_bias=not norm, kernel_initializer=kernel_initializer)(x)
    if norm=='batch':
        x = layers.BatchNormalization()(x)
    elif norm=='layer':
        x = layers.LayerNormalization()(x)
    x = layers.Activation(activation)(x)
    return x

def build_atari_model(in_shape):
    inputs = layers.Input(in_shape)
    x = layers.Lambda( lambda x: x/255.0)(inputs)

    kernel_initializer = tf.initializers.Orthogonal(np.sqrt(2))
    activation = 'relu'
    norm = None
    if x.shape[1]<84:
        pad = (84-x.shape[1])//2
        x = tf.pad(x, [[0,0],[pad,pad],[pad,pad],[0,0]] )
    x = conv(x, 32, (8, 8), strides=4,
            activation=activation, kernel_initializer=kernel_initializer, norm=norm)
    x = conv(x, 64, (4, 4), strides=2,
            activation=activation, kernel_initializer=kernel_initializer, norm=norm)
    x = conv(x, 64, (3, 3), strides=1,
            activation=activation, kernel_initializer=kernel_initializer, norm=norm)
    x = layers.Flatten()(x)
    x = dense(x, 512,
            activation=activation, kernel_initializer=kernel_initializer, norm=norm)

    network = models.Model(inputs, x)
    return network



network = build_atari_model((84,84,4))
network.summary()



##Play Out

In [None]:
from traitlets.traitlets import DottedObjectName
# !pip install gym-notebook-wrapper
# import gnwrapper
import gym
import numpy as np
from PIL import Image, ImageDraw

import os

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.cm as cm

from PIL import Image, ImageDraw, ImageFont

from IPython.display import display,HTML
matplotlib.interactive(False)

#https://github.com/horoiwa/deep_reinforcement_learning_gallery/blob/master/MuZero/visualize.py

def draw_frame_and_info(env, rgb, info, draw_type='raw'):
    step, action, reward, score, value, lives = info

    action_name = env.get_action_meanings()[action]
    img_frame = Image.fromarray(rgb)

    if draw_type=='raw':
        return img_frame

    if draw_type=='wide':
        resized_img = img_frame.resize((img_frame.width*2, img_frame.height), Image.NEAREST)
        return resized_img

    img_desc = Image.new('RGB', (160, img_frame.height), color="black")
    fnt = ImageFont.truetype("LiberationSans-Regular.ttf", 12)
    fnt_sm = ImageFont.truetype("LiberationSans-Regular.ttf", 10)

    pl = 10
    pb = 20

    v = str(round(value, 2))

    draw = ImageDraw.Draw(img_desc)
    draw.fontmode = '1'
    draw.text((pl, 20+pb*0), f"Step: {step:,}", font=fnt, fill="white")
    draw.text((pl, 20+pb*1), f"Lives: {int(lives):,}", font=fnt, fill="white")
    draw.text((pl, 20+pb*2), f"Score: {int(score):,}", font=fnt, fill="white")
    draw.text((pl, 20+pb*3), f"Action: {action_name}", font=fnt, fill="white")
    draw.text((pl, 20+pb*4), f"Reward: {reward:.4f}", font=fnt, fill="white")
    draw.text((pl, 20+pb*5), f"V(s): {v}", font=fnt, fill="white")


    img_bg = Image.new(
        'RGB', (img_frame.width + img_desc.width, img_frame.height))

    img_bg.paste(img_frame, (0, 0))
    img_bg.paste(img_desc, (img_frame.width, 0))

    return img_bg


def play_out(env, policy, filename='_recording.png', verbose=0,
              video_fps=12,
              deterministic=False, draw_type='info'):
    ims = []
    is_atari = False
    if hasattr(env, 'set_recording'):
        is_atari = True
        env.set_recording(True)

    # if hasattr(env, 'start_recording'):
    #     env.start_recording(filename, 30.0)
    if hasattr(env, 'set_true_reset'):
        env.set_true_reset()

    if hasattr(env, 'disable_penalty'):
        env.disable_penalty()

    state = env.reset()

    score = 0
    done = False
    step_count = 0
    no_reward_count = 0

    frames = []
    frame_count=0
    video_skip_num = 60//video_fps


    truncated = False
    video = None
    lives = 0

    while not done:
        if policy==None:
            action = np.random.randint(0,env.action_space.n)
            value = 0.0
            prob = [0.0]
        else:
            values, logits = policy.sample(np.expand_dims(state,0))

            def softmax(x):
                x = np.exp(x - np.max(x))
                return x / np.sum(x)

            prob = softmax(logits[0])

            if deterministic:
                action = np.argmax(logits[0])
            else:
                action = np.random.choice( len(prob), p=prob)
            value = values[0][0]

        next_state, reward, terminated, truncated, info, dont = env.step(action)
        if any(info):
            lives  = info[ATARI_LIVES_KEY] if ATARI_LIVES_KEY in info else 0

        if is_atari:
            score = env.true_score
            frame_info = (step_count, action, reward, score, value, lives)
            for frame in env.get_raw_frames():
                frames.append(frame)
                if len(frames)==video_skip_num:
                    frames = []
                    rgb = draw_frame_and_info(env, frame, frame_info , draw_type)
                    rgb = np.array(rgb)[:,:,::-1]
                    if video==None:
                        fourcc = cv2.VideoWriter_fourcc('p','n','g', ' ')
                        video  = cv2.VideoWriter( filename, fourcc, video_fps, (rgb.shape[1],rgb.shape[0]) )
                    video.write(rgb)
                frame_count+=1
        else:
            score = score+reward

        if verbose>=2:
            if reward!=0 or done:
                print(step_count,reward, int(score), action, done, info)

        if done:
            if any(info):
                if 'TimeLimit.truncated' in info: # max step
                    break
                elif 'episode' in info:
                    if 'truncated' in info['episode']:
                        print('truncated')
                    if 'score' in info['episode']:
                        score = info['episode']['score']
                    break
                done = False
                env.reset()
        else:
            state = next_state

        step_count+=1
        if step_count>=10000:
            truncated=True
            break


    if verbose!=0:
        print( f'step_count={step_count}, frames={len(frames)}, score={score}' )

    if is_atari and filename:
        if len(frames)!=0:
            rgb = draw_frame_and_info(env, frames[-1], frame_info , draw_type)
            rgb = np.array(rgb)[:,:,::-1]
        for i in range(video_fps*4):
            video.write(rgb)
        video.release()

        if verbose!=0:
            fsize = os.path.getsize(filename)
            print( f'{filename} size={fsize:,}, frames={frame_count:,}')

    return int(score), truncated

import subprocess
import sys

def convert_avi_file(in_filename, out_filename):
    if out_filename.endswith('.gif'):
        ffmpg_options = '-filter_complex "[0:v]palettegen=reserve_transparent=0[pal];[0:v][pal]paletteuse=dither=none"'
    elif out_filename.endswith('.png'):
        ffmpg_options = '-filter_complex "[0:v]palettegen=reserve_transparent=0[pal];[0:v][pal]paletteuse=dither=none" -f apng'
    else:
        ffmpg_options = '-vb 500k'
    cmdline = f'ffmpeg -y -i {in_filename} {ffmpg_options} {out_filename}'
    try:
        result = subprocess.run(cmdline, shell=True, check=True,
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                universal_newlines=True)
        # print(result)
    except subprocess.CalledProcessError as e:
        print(e)


env_name = 'VideoPinball'

envs, gamma = build_atari_env_and_param(env_name,1)
model = make_network(env_name, envs[0].observation_space.shape)
print(envs[0].get_action_meanings())
print(envs[0].action_space.n)
action_space = envs[0].action_space.n

policy = PPO(model, action_space=action_space)
optimizer = tf.keras.optimizers.Adam(0.1, epsilon=1e-7)
policy.compile(optimizer)
policy.build((1,84,84,4))


# score, truncated = play_out( envs[0], policy, verbose=1, video_fps=12, deterministic=False, filename='_recording.avi')
# convert_avi_file('_recording.avi', 'test.gif')

#訓練

In [None]:
!nvidia-smi
import tensorflow as tf
tf.__version__

## TensorBoard

In [None]:
%tensorboard --logdir /content/logs

## 実行

In [None]:
#@title  { form-width: "300px" }

env_name = 'Breakout' #@param ['Amidar','Alien','Asteroids','Assault','Asterix','Atlantis','BattleZone','BeamRider','Boxing','Breakout','Bowling','BankHeist','Centipede','CrazyClimber','ChopperCommand','DoubleDunk','DemonAttack','Enduro','FishingDerby','Freeway','Frostbite','Gravitar','Gopher','IceHockey', 'Jamesbond', 'Kangaroo','KungFuMaster','MsPacman','MontezumaRevenge','NameThisGame', 'Pitfall','Pong','PrivateEye','Qbert','Riverraid','Robotank','RoadRunner','SpaceInvaders','Seaquest','TimePilot','Tennis','UpNDown','VideoPinball','WizardOfWor','Zaxxon']

#
# PPO
#
num_agents = 8
lr = 2.5e-4
envs, gamma = build_atari_env_and_param(env_name,num_agents+1)
policy = learn( mode='ppo', prev_policy=None, env_name=env_name, envs=envs,
               lr=lr, gamma=gamma,num_frames=10*1000*1000,
               num_agents=num_agents, trajectory_length=128,
               batch_size=None, num_epochs=4, num_batches=8,
               entropy_coef=0.01, value_coef=0.5, pi_clip_range=0.1,
               test_play_interval=250, use_tensorboard=True, rand_seed=123)
#
# A2C
#
# num_agents = 16
# lr = 1e-3
# envs, gamma = build_atari_env_and_param(env_name,num_agents+1)
# policy = learn( mode='a2c', prev_policy=None, env_name=env_name, envs=envs,
#                 lr=lr, gamma=gamma,num_frames=10*1000*1000,
#                 num_agents=num_agents, trajectory_length=5,
#                 batch_size=None, num_epochs=1, num_batches=1,
#                 entropy_coef=0.01, value_coef=0.5,
#                 test_play_interval=250, use_tensorboard=True, rand_seed=123)
