In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import torch
from torch.optim import Adam
import gym
import time
import wandb
from spinup import models
core = models

In [None]:
from spinup.utils.mpi_pytorch import setup_pytorch_for_mpi, sync_params, mpi_avg_grads
from spinup.utils.mpi_tools import mpi_fork, mpi_avg, proc_id, mpi_statistics_scalar, num_procs
from spinup.algos.ppo.ppo import ppo

In [None]:
class Config:
    """Empty attribute bag for run hyperparameters.

    Construct it with no arguments, then attach settings as plain attributes
    (``cfg.env = "..."``). ``vars(cfg)`` returns them as an ordinary dict.
    """

    def __init__(self):
        # Nothing to initialise; attributes are added by the caller afterwards.
        pass
    
class Logger(object):
    """Thin adapter around an object exposing ``log(data=...)`` (here a wandb run).

    Extra keyword arguments are merged into the logged record for convenience.
    The ``step`` argument selects how the record is indexed:

    * pass an int: per-epoch logging; the value is stored under ``'epoch'``
      and remembered in ``self.epoch``.
    * pass nothing (``None``): per-run logging; no index key is added.
    * pass a str name: per-step logging (model update step, env interaction
      step, ...); the current count for that name is stored under the name
      itself, then the counter is incremented.
    """

    def __init__(self, logger):
        # logger: wrapped sink; must provide a ``log(data=...)`` method.
        self.logger = logger
        self.epoch = 0       # last int step seen
        self.counters = {}   # per-name step counters for str steps

    def log(self, data=None, step=None, **kwargs):
        """Build one record from ``data``, ``step`` and ``kwargs`` and forward it.

        Args:
            data: optional mapping of values to log. It is copied, never mutated.
            step: int (epoch index), str (named counter), or None.
            **kwargs: additional key/value pairs merged into the record last.
        """
        # Work on a fresh dict. The previous ``data={}`` default was shared
        # across calls, so keys from earlier calls leaked into later ones.
        # Mutating it in place also altered the caller's dict.
        record = dict(data) if data is not None else {}
        if isinstance(step, int):
            self.epoch = step
            record['epoch'] = step
        elif isinstance(step, str):
            count = self.counters.get(step, 0)
            record[step] = count
            self.counters[step] = count + 1
        record.update(kwargs)
        self.logger.log(data=record)
        
args = Config()
# Environment to train on; earlier experiments used "CartPole-v1" and "Hopper-v2".
vars(args).update(env="Breakout-ram-v0", algorithm="ppo")
args.name=f"{args.env}_{args.algorithm}"
vars(args).update(
    gpu=0,                     # NOTE(review): not read anywhere in this notebook
    seed=0,                    # passed to ppo as seed
    hid=256,                   # width of each hidden layer
    l=6,                       # number of hidden layers: hidden_sizes=(hid,)*l
    gamma=0.99,                # passed to ppo as gamma
    cpu=4,                     # MPI process count; only used if mpi_fork is re-enabled
    steps=5000,                # passed to ppo as steps_per_epoch
    epochs=500,                # passed to ppo as epochs
    activation=torch.nn.ReLU,  # hidden-layer activation class for the actor-critic
)

# Run

In [None]:
# Start a wandb run and train PPO with the settings from the config cell.
# The run is always finished, even if training raises or is interrupted, so the
# next execution of this cell does not inherit a dangling run.
#mpi_fork(args.cpu)  # uncomment to run parallel code with MPI
run = wandb.init(
    project="RL",
    config=args,
    name=args.name,
    group=args.env,
)
logger = Logger(run)
try:
    result = ppo(
        lambda: gym.make(args.env),
        actor_critic=core.MLPActorCritic,
        ac_kwargs=dict(hidden_sizes=(args.hid,) * args.l, activation=args.activation),
        gamma=args.gamma,
        seed=args.seed,
        steps_per_epoch=args.steps,
        epochs=args.epochs,
        logger=logger,
    )
except BaseException:
    # Also covers KeyboardInterrupt. A non-zero exit code marks the run as failed in wandb.
    run.finish(exit_code=1)
    raise
else:
    run.finish()

# Visualization

In [None]:
import gym
from IPython import display
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

# Roll out the trained policy (`result` from the training cell) for one episode
# and render each frame inline. Prints the total reward collected.
MAX_VIS_STEPS = 500  # frame cap so a policy that never terminates cannot loop forever

env = gym.make(args.env)  # same environment the policy was trained on
# Keep the initial observation: the policy needs it on the very first step.
state = env.reset()
img = plt.imshow(env.render(mode='rgb_array'))  # create the image artist only once
total = 0
try:
    for _ in range(MAX_VIS_STEPS):
        img.set_data(env.render(mode='rgb_array'))  # just swap the pixel data
        display.display(plt.gcf())
        display.clear_output(wait=True)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            obs = torch.as_tensor(state, dtype=torch.float32)
            # pi(obs)[0] is the action distribution. Convert the sample to numpy so
            # env.step receives a plain array rather than a torch tensor.
            # For a random-policy baseline use env.action_space.sample() instead.
            action = result.pi(obs)[0].sample().numpy()
        state, reward, done, info = env.step(action)
        total += reward
        if done:  # stop at episode end instead of stepping a finished env
            break
finally:
    env.close()
print(total)