In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pdb
import torch
from torch.optim import Adam
import gym
import time
import wandb
from spinup import models
core = models

In [None]:
from spinup.utils.mpi_pytorch import setup_pytorch_for_mpi, sync_params, mpi_avg_grads
from spinup.utils.mpi_tools import mpi_fork, mpi_avg, proc_id, mpi_statistics_scalar, num_procs
from spinup.algos.ppo.ppo import ppo
from spinup.algos.sac.sac import sac

In [None]:
class Config(object):
    """
    Bare attribute container for experiment settings.

    Settings are attached as plain attributes (e.g. ``cfg.seed = 0``);
    ``toDict`` collects them into a dict.
    """

    def __init__(self):
        pass

    def toDict(self):
        """
        Return the non-callable attributes listed by ``dir(self)`` as a dict.

        Skipped: names starting with ``__``, names ending with ``_``, and any
        attribute whose value is callable (methods, functions, classes such as
        ``torch.nn.ReLU``). Keys appear in ``dir()`` order.
        """
        attrs = ((attr, getattr(self, attr)) for attr in dir(self))
        return {
            attr: val
            for attr, val in attrs
            if not (attr.startswith('__') or callable(val) or attr.endswith('_'))
        }
    
class TabularLogger(object):
    """
    A text interface logger: buffers logged values and, on commit, prints
    the mean and std of every key collected since the previous commit.
    """
    def __init__(self):
        self.buffer = {}  # key -> list of values logged since the last commit

    def log(self, dic, commit=False):
        """
        Buffer the values in ``dic``; on ``commit`` print and return a summary.

        Args:
            dic: mapping of key -> numeric value (None or {} logs nothing,
                which allows a commit-only call).
            commit: if True, print one line per buffered key with its mean and
                (population) std, then clear the buffer.

        Returns:
            None when ``commit`` is False, otherwise a dict
            ``key -> (mean, std)`` of the values that were summarized.
        """
        for key, value in (dic or {}).items():
            self.buffer.setdefault(key, []).append(value)
        if not commit:
            return None

        summary = {}
        for key, values in self.buffer.items():
            arr = np.asarray(values, dtype=float)
            mean, std = float(arr.mean()), float(arr.std())
            summary[key] = (mean, std)
            print(f"{key:>24} | mean {mean:.4g} | std {std:.4g}")
        self.buffer = {}
        return summary
        
class Logger(object):
    """
    A logger wrapper for visualized loggers, such as tb or wandb.

    Counts how many times each key has been logged and throttles forwarding
    so each key is sent to the backend roughly ``frequency`` times per epoch,
    preventing the log from becoming too big. Uses kwargs instead of a dict
    for convenience. Keys whose value is None are counters: the forwarded
    value is the number of times that key has been logged so far
    (e.g. ``log(epoch=None)`` advances the epoch counter).
    """
    def __init__(self, logger=None, frequency=10):
        """
        Args:
            logger: backend object exposing ``log(data, commit=...)``, e.g. a
                wandb run. If None, counters are still kept but nothing is
                forwarded.
            frequency: target number of forwarded logs per key per epoch.
        """
        self.logger = logger
        self.counters = {'epoch': 0}
        self.frequency = frequency

    def log(self, data=None, **kwargs):
        """
        Count every key in ``data``/``kwargs`` and forward the ones due now.

        The caller's ``data`` dict is not modified. Returns the dict that was
        forwarded (or would have been, if no backend is set); it may be empty.
        """
        # Copy so kwargs are never written back into the caller's dict.
        data = dict(data) if data is not None else {}
        data.update(kwargs)

        for key in data:
            self.counters[key] = self.counters.get(key, 0) + 1

        epoch = self.counters['epoch']
        to_store = {}
        for key, value in data.items():
            count = self.counters[key]
            # count // (epoch + 1) approximates logs of this key per epoch so
            # far; dividing by frequency gives how many calls to skip between
            # forwarded entries.
            period = count // (epoch + 1) // self.frequency + 1
            if count % period == 0:
                to_store[key] = count if value is None else value

        if to_store and self.logger is not None:
            self.logger.log(to_store)
        return to_store

    def flush(self):
        """Forward the current epoch counter to the backend with ``commit=True``."""
        if self.logger is not None:
            self.logger.log(data={'epoch': self.counters['epoch']}, commit=True)
        
# Experiment configuration. Other environments that have been used here:
# "CartPole-v1", "Hopper-v2".
args = Config()
args.env = "Breakout-ram-v0"
args.algorithm = "sac"
args.name = "{}_{}".format(args.env, args.algorithm)  # used as the wandb run name
args.gpu = 0
args.seed = 0
# Network: `l` hidden layers, each `hid` units wide.
args.hid = 256
args.l = 6
args.gamma = 0.99  # discount factor
args.cpu = 4  # process count for mpi_fork
args.steps_per_epoch = 5000
args.epochs = 500
# NOTE(review): v_lr / pi_lr are recorded in the config but are not passed to
# the sac/ppo calls below — confirm whether they should be.
args.v_lr = 1e-4
args.pi_lr = 1e-5
args.activation = torch.nn.ReLU

# Run

## SAC

In [None]:
# Train SAC on args.env, streaming metrics to wandb through the throttling Logger.
run = wandb.init(
    project="RL",
    config=args,
    name=args.name,
    group=args.env,
)
logger = Logger(run)
try:
    result = sac(
        lambda: gym.make(args.env),
        actor_critic=core.MLPDQActorCritic,
        ac_kwargs=dict(hidden_sizes=[args.hid] * args.l),
        logger=logger,
        # Pass the recorded seed/gamma so the logged config matches the run.
        seed=args.seed,
        gamma=args.gamma,
        steps_per_epoch=args.steps_per_epoch,
        epochs=args.epochs,
    )
finally:
    # Close the run even if training fails, so a half-finished run is not
    # left active when the next cell calls wandb.init() again.
    run.finish()

## PPO

In [None]:
#mpi_fork(args.cpu)  # run parallel code with mpi
# The shared `args` says algorithm="sac"; record the correct algorithm and run
# name for this PPO run without mutating `args` (keeps the SAC cell re-runnable).
ppo_config = dict(vars(args), algorithm="ppo", name=f"{args.env}_ppo")
run = wandb.init(
    project="RL",
    config=ppo_config,
    name=ppo_config["name"],
    group=args.env,
)
logger = Logger(run)
try:
    result = ppo(
        lambda: gym.make(args.env),
        actor_critic=core.MLPVActorCritic,
        # Only network-construction options go in ac_kwargs; logger, gamma and
        # seed are arguments of ppo() itself.
        ac_kwargs=dict(hidden_sizes=(args.hid,) * args.l),
        logger=logger,
        gamma=args.gamma,
        seed=args.seed,
        steps_per_epoch=args.steps_per_epoch,
        epochs=args.epochs,
    )
finally:
    # Close the run even if training fails.
    run.finish()

# Visualization

In [None]:
import gym
from IPython import display
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

env = gym.make('Breakout-ram-v0')
state = env.reset()
img = plt.imshow(env.render(mode='rgb_array')) # only call this once
total = 0
for _ in range(2000):
    img.set_data(env.render(mode='rgb_array')) # just update the data
    display.display(plt.gcf())
    display.clear_output(wait=True)
    tmp = torch.tensor(state).float()
    action = result.pi(tmp)[0].sample()
   # action = env.action_space.sample()
    state, reward, done, info  = env.step(action)
    total += reward
print(total)

In [None]:
x = env.action_space

In [None]:
env = gym.make('Breakout-ram-v0')

In [None]:
env.observation_space.shape

In [None]:
env.action_space.shape

In [None]:
env.reset()