In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import random
import pdb
import torch
from torch.optim import Adam
import gym
import time
import wandb
from spinup import models
core = models

In [None]:
from spinup.utils.mpi_pytorch import setup_pytorch_for_mpi, sync_params, mpi_avg_grads
from spinup.utils.mpi_tools import mpi_fork, mpi_avg, proc_id, mpi_statistics_scalar, num_procs
from spinup.algos.ppo.ppo import ppo
from spinup.algos.sac.sac import sac

In [None]:
class Config(object):
    """Plain attribute bag for experiment / model settings.

    Assign arbitrary attributes on an instance, then call ``toDict`` to export
    them as a regular dict (e.g. for ``wandb`` config or ``**kwargs``).
    """

    def __init__(self):
        pass

    def toDict(self):
        """Return the public, non-callable attributes as a dict.

        Names are taken from ``dir(self)`` (sorted, class attributes included).
        A name is skipped when it starts with ``'__'``, ends with ``'_'``, or
        its value is callable. Callables include this method, but also classes
        or functions stored as settings (e.g. ``torch.nn.ReLU``).
        """
        pairs = ((attr, getattr(self, attr)) for attr in dir(self))
        return {
            attr: val
            for attr, val in pairs
            if not (attr.startswith('__') or callable(val) or attr.endswith('_'))
        }
    
class TabularLogger(object):
    """
    A text interface logger, outputs mean and std several times per epoch.

    Values passed to ``log`` are buffered per key. When ``commit=True`` the
    buffered values are summarised (mean and std per key), printed as a small
    text table, and the buffer is cleared.
    """
    def __init__(self):
        # key -> list of float samples collected since the last commit
        self.buffer = {}

    def log(self, dic=None, commit=False, data=None):
        """Buffer numeric values and optionally print a mean/std summary.

        Args:
            dic: mapping key -> value. Values may be Python/numpy numbers,
                arrays, or torch tensors; every element is recorded. ``None``
                and non-numeric values (e.g. strings) are ignored.
            commit: if True, print mean/std for every buffered key and clear
                the buffer.
            data: alias for ``dic``, so this class accepts the
                ``logger.log(data=..., commit=...)`` call style used by ``Logger``.

        Returns:
            On commit, a dict mapping key -> (mean, std) of what was printed;
            otherwise None.
        """
        merged = {}
        if dic is not None:
            merged.update(dic)
        if data is not None:
            merged.update(data)

        for key, value in merged.items():
            if value is None:
                continue
            if isinstance(value, torch.Tensor):
                value = value.detach().cpu().numpy()
            try:
                samples = np.asarray(value, dtype=float).ravel()
            except (TypeError, ValueError):
                # not numeric -> nothing meaningful to average
                continue
            self.buffer.setdefault(key, []).extend(samples.tolist())

        if not commit:
            return None

        summary = {
            key: (float(np.mean(samples)), float(np.std(samples)))
            for key, samples in self.buffer.items()
            if samples
        }
        self.buffer = {}
        width = max([len(str(k)) for k in summary] + [3])
        for key, (mean, std) in summary.items():
            print(f"{str(key):<{width}} | mean {mean:.4g} | std {std:.4g}")
        return summary
        
class Logger(object):
    """
    A logger wrapper for visualized loggers, such as tb or wandb.

    Automatically counts how often each key is logged (plus an ``'epoch'``
    counter) and randomly subsamples entries. The goal is that each key is
    forwarded roughly ``frequency`` times per epoch, which keeps the log from
    becoming too big.
    Uses kwargs instead of dict for convenience.
    All None valued keys are counters: their logged value is the number of
    times the key has been seen.
    NaN values (numbers, arrays or tensors) are dropped with a printed warning.
    """
    def __init__(self, logger, frequency=10):
        """
        Args:
            logger: backend object exposing ``log(data, commit=...)``
                (e.g. a wandb run).
            frequency: target number of forwarded logs per key per epoch.
        """
        self.logger = logger
        self.counters = {'epoch': 0}
        self.frequency = frequency  # logs per epoch

    @staticmethod
    def _is_nan(value):
        """Return True if ``value`` contains a NaN; False for non-numeric values."""
        if isinstance(value, torch.Tensor):
            return bool(torch.isnan(value).any())
        try:
            return bool(np.isnan(value).any())
        except TypeError:
            # strings, media objects, etc. cannot be NaN-checked; log them as-is
            return False

    def log(self, data=None, **kwargs):
        """Count, subsample, NaN-filter and forward ``data``/``kwargs`` to the backend."""
        # Work on a copy so the caller's dict is never mutated.
        data = dict(data) if data is not None else {}
        data.update(kwargs)

        for key in data:
            self.counters[key] = self.counters.get(key, 0) + 1

        to_store = {}
        epoch = self.counters['epoch']
        for key, value in data.items():
            count = self.counters[key]
            # average count per epoch so far -> keep ~frequency samples per epoch
            period = count // (epoch + 1) + 1
            if random.random() >= self.frequency / period:
                continue
            if value is None:
                to_store[key] = count
                continue
            if isinstance(value, torch.Tensor):
                value = value.detach().cpu()
            if self._is_nan(value):
                print(f'{key} is nan!')
                continue
            to_store[key] = value

        if to_store:
            self.logger.log(to_store, commit=True)

    def flush(self):
        """Forward the current epoch counter to the backend."""
        self.logger.log(data={'epoch': self.counters['epoch']}, commit=True)
        


# Pick an Environment
CartPole, the MNIST of RL, discrete action, parameterized state

Breakout: discrete actions (0 = no-op/stay, 1 = fire, 2 = right, 3 = left).
The agent needs to press "fire" to launch the ball, or it gets stuck forever.
Supports both a parameterized (RAM) state and visual input.

MountainCar: discrete or continuous actions, parameterized state.
The agent gets a reward for climbing a hill, but climbing costs energy,
so (painful) exploration is essential.

BipedalWalker: -100 penalty for falling.
The initial greedy strategy may make the agent stand still to avoid falling.
As a result, the initial test reward is 0 while the initial train reward is 100.

For environments with a large penalty, use a large batch when updating Q to compensate for the variance.

In [None]:
# Environment under study. Alternatives:
#   "Breakout-ram-v0"  -> discrete action
#   "CartPole-v1"      -> discrete action
env_name = "BipedalWalker-v3"  # continuous action

# Run
- DQN, tested on CartPole
- PPO, tested on Breakout
- SAC (continuous), tested on Walker; alpha = 0.2 is too large, and 0.05 is still large
- SAC (discrete), tested on CartPole
    - must use eps to prevent NaN, because the probability of some actions becomes 0
    - this happens when Q is large (e.g. a few hundred)

## DQN

In [None]:
# DQN run: trained through `sac` with `dqn=True` in the model args.
# DQN expects a discrete action space (tested on CartPole).
args = Config()
args.env = env_name
args.algorithm = "dqn"
args.name = f"{args.env}_{args.algorithm}"
args.gpu = 0
args.seed = 0
args.cpu = 4
args.steps_per_epoch = 5000
args.epochs = 500

# Build the model config once; a second Config() here would discard
# hidden_sizes/activation.
model_args = Config()
model_args.hidden_sizes = [256] * 4
model_args.activation = torch.nn.ReLU  # callable, so Config.toDict() leaves it out
model_args.gamma = 0.99
model_args.polyak = 0.995
model_args.lr = 3e-5
model_args.alpha = 0
model_args.eps = 0.01
model_args.dqn = True
args.model_args = model_args.toDict()

# Seed every RNG the run touches so results are reproducible.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

run = wandb.init(
    project="RL",
    config=args.toDict(),
    name=args.name,
    group=args.env,
)
try:
    logger = Logger(run)
    env = gym.make(args.env)
    model = core.MLPDQActorCritic(env.observation_space, env.action_space, logger=logger, **model_args.toDict())
    env.close()  # only needed for its observation/action spaces
    result = sac(lambda: gym.make(args.env), model=model, logger=logger,
                 steps_per_epoch=args.steps_per_epoch, epochs=args.epochs)
finally:
    # close the wandb run even if training raises
    run.finish()

## SAC

In [None]:
# SAC run (continuous or discrete, depending on env_name; see the notes above).
args = Config()
args.env = env_name
args.algorithm = "sac"
args.name = f"{args.env}_{args.algorithm}"
args.gpu = 0
args.seed = 0
args.cpu = 4

algo_args = Config()
algo_args.n_step = 4096
algo_args.n_update = 50
algo_args.batch_size = 2048
algo_args.epochs = 9999
algo_args.start_steps = 20000
algo_args.update_after = 20000

# Build the model config once; a second Config() here would discard
# hidden_sizes/activation.
model_args = Config()
model_args.hidden_sizes = [256] * 4
model_args.activation = torch.nn.ReLU  # callable, so Config.toDict() leaves it out
model_args.gamma = 0.99
model_args.polyak = 0.995
model_args.lr = 3e-5
model_args.alpha = 0.02
model_args.eps = 0
model_args.dqn = False

args.algo_args = algo_args.toDict()
args.model_args = model_args.toDict()

# Seed every RNG the run touches so results are reproducible.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Use the configured GPU when CUDA is available, otherwise fall back to CPU.
device = args.gpu if torch.cuda.is_available() else "cpu"

run = wandb.init(
    project="RL",
    config=args.toDict(),
    name=args.name,
    group=args.env,
)
try:
    logger = Logger(run)
    env = gym.make(args.env)
    model = core.MLPDQActorCritic(env.observation_space, env.action_space, logger=logger, **model_args.toDict())
    env.close()  # only needed for its observation/action spaces
    model.to(device)
    result = sac(lambda: gym.make(args.env), model=model, logger=logger, device=device, **algo_args.toDict())
finally:
    # close the wandb run even if training raises
    run.finish()

## PPO

In [None]:
# PPO run (tested on Breakout).
args = Config()
args.env = env_name
args.algorithm = "ppo"
args.name = f"{args.env}_{args.algorithm}"
args.gpu = 0
args.seed = 0
args.cpu = 4
args.steps_per_epoch = 5000
args.epochs = 500

# Only the settings actually passed to ppo below; the SAC/DQN-specific ones
# (polyak, alpha, eps, dqn, lr) were never used by this call.
model_args = Config()
model_args.hidden_sizes = [256] * 4
model_args.gamma = 0.99
args.model_args = model_args.toDict()

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

#mpi_fork(args.cpu)  # run parallel code with mpi
run = wandb.init(
    project="RL",
    config=args.toDict(),
    name=args.name,
    group=args.env,
)
try:
    logger = Logger(run)
    result = ppo(
        lambda: gym.make(args.env),
        actor_critic=core.MLPVActorCritic,
        # NOTE(review): assumes MLPVActorCritic accepts hidden_sizes and logger
        # (as MLPDQActorCritic accepts logger) — TODO confirm against spinup.models
        ac_kwargs=dict(hidden_sizes=tuple(model_args.hidden_sizes), logger=logger),
        gamma=model_args.gamma,
        seed=args.seed,
        steps_per_epoch=args.steps_per_epoch,
        epochs=args.epochs,
    )
finally:
    # close the wandb run even if training raises
    run.finish()

# Test

In [None]:
import gym

# Roll out one episode with the most recently trained `model` (stochastic actions).
# Swap in `env.action_space.sample()` for a random-policy baseline.
# NOTE(review): if the model was moved to a GPU (SAC cell), model.act may need
# the observation on the same device — TODO confirm how act handles devices.
max_steps = 2000
env = gym.make(env_name)
state = env.reset()
total = 0
try:
    for step in range(max_steps):
        obs = torch.tensor(state).float()
        action = model.act(obs)
        state, reward, done, info = env.step(action)
        total += reward
        if done:
            # step is 0-based, so the episode lasted step + 1 environment steps
            print(f"episode len {step + 1}, reward {total}")
            break
    else:
        print(f"episode not finished after {max_steps} steps, reward {total}")
finally:
    env.close()

# Visualization

In [None]:
import gym
from IPython import display
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

env = gym.make(env_name)
state = env.reset()
img = plt.imshow(env.render(mode='rgb_array')) # only call this once
total = 0
for _ in range(2000):
    img.set_data(env.render(mode='rgb_array')) # just update the data
    display.display(plt.gcf())
    display.clear_output(wait=True)
    tmp = torch.tensor(state).float()
    action = model.act(tmp, deterministic=True)
   # action = env.action_space.sample()
    state, reward, done, info  = env.step(action)
    total += reward
    if done:
        print(f"episode len {_}, reward {total}")
        break

In [None]:
# Probe: adding a NaN constant (0/0 on tensors without grad) makes y NaN,
# yet the gradient w.r.t. x is still d(c + x)/dx = 1.
x = torch.randn(1, requires_grad=True)
y = torch.zeros(1).div(torch.zeros(1)).add(x)
y.backward()

In [None]:
# Explicit dim: calling LogSoftmax with the implicit dim=None is deprecated
# and picks the axis from the input's rank.
log_softmax = torch.nn.LogSoftmax(dim=-1)
log_softmax

In [None]:
# Short alias for the torch optimizers module.
from torch import optim

In [None]:
# Adam needs the parameters to optimise; calling it with no arguments raises
# TypeError. Build it over a throwaway tensor to inspect its default settings.
probe_param = torch.zeros(1, requires_grad=True)
probe_optimizer = torch.optim.Adam([probe_param])
probe_optimizer.defaults

In [None]:
from tqdm import tqdm

In [None]:
# Manually-driven progress bar over 100 items; advance it with next(pbar).
pbar = iter(tqdm(iterable=range(100)))

In [None]:
# Iterators are not callable; advance them with next(pbar) instead.
# Kept as a demonstration, so the error is caught and shown rather than raised.
try:
    pbar()
except TypeError as err:
    print(f"expected: {err}")

In [None]:
# Advance the progress bar one step; yields None instead of raising
# StopIteration once all 100 items have been consumed.
next(pbar, None)