In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import random
import pdb
import torch
from torch.optim import Adam
import gym
import time
import wandb

# Pick an Environment
CartPole, the MNIST of RL, discrete action, parameterized state

Breakout, discrete action: 0 no-op (stay), 1 fire, 2 right, 3 left
    Note that Breakout-ram is hard because the state is the 128 bytes of the RAM,
    and the bytes do not have an intuitive meaning.
    Also note that 50M samples is typical for DQNs with visual input (refer to Rainbow).

MountainCar, discrete or continuous action, parameterized state
The agent gets a reward for climbing up a hill, which costs energy,
so painful exploration is essential

Walker, penalty of -100 for falling. 
The initial greedy strategy may make the agent stand still to avoid falling.
As a result, initial test reward = 0 while initial train reward = 100

For environments with a large penalty, we should use a large batch when updating Q, in order to compensate for the variance

In [None]:
from gym.wrappers import FrameStack

class BreakoutWrapper(gym.ObservationWrapper):
    """Preprocess Breakout frames and end the episode when a life is lost.

    Each raw (210, 160, 3) frame is
      * scaled to floats in [0, 1],
      * cropped to rows 30..-17 (cuts the margins, leaving 163 rows),
      * averaged over the colour axis (greyscale),
      * average-pooled with a 4x4 kernel and stride 4,
    which yields a (40, 40) float32 array.

    ``reset`` fires the ball by pressing action 1 once, and ``step`` reports
    ``done=True`` as soon as the number of lives decreases.

    Meant to be wrapped by FrameStack (rather than wrapping FrameStack) so that
    FrameStack's lazy frames can be used to save memory.
    """
    def __init__(self, env):
        super().__init__(env)
        # Must describe what observation() actually returns: 163 cropped rows
        # and 160 columns, pooled 4x4 with stride 4, give (40, 40). Wrappers
        # stacked on top (e.g. FrameStack) read their shape from this attribute.
        self.observation_space = gym.spaces.Box(0, 1, (40, 40), dtype=np.float32)
        self.pooling = torch.nn.AvgPool2d(kernel_size=(4, 4), stride=(4, 4))
        # Overwritten with the emulator's real life count on every reset().
        self.lives = 0

    def step(self, action):
        """Step the wrapped env; force ``done`` when a life has been lost.

        Returns the old-style gym 4-tuple ``(observation, reward, done, info)``
        with the observation already preprocessed.
        """
        obs, reward, done, info = self.env.step(action)
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives:
            done = True
        self.lives = lives
        return self.observation(obs), reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env, record the life count and fire the ball.

        The returned observation is the preprocessed frame that follows the
        initial action 1, not the frame returned by the underlying reset.
        """
        self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()
        obs, _, _, _ = self.step(1)
        return obs

    def observation(self, observation):
        """Convert one raw frame into the pooled greyscale float32 array."""
        frame = np.array(observation).astype(np.float32) / 255.0
        frame = frame[30:-17]  # cut the top and bottom margins
        frame = np.mean(frame, axis=2)  # greyscale
        # AvgPool2d needs a leading channel axis; add it, pool, drop it again.
        pooled = self.pooling(torch.as_tensor(frame).unsqueeze(0)).squeeze(0)
        return np.array(pooled)

## Breakout

In [None]:
from matplotlib import pyplot as plt

# Build a Breakout env whose preprocessed frames are stacked four at a time.
env_name = 'Breakout-v0'
env_fn = lambda: FrameStack(BreakoutWrapper(gym.make(env_name)), 4)
env = env_fn()

# Materialise the first stacked observation and move the stack axis last so
# that it can be handed to imshow as an image.
result = np.array(env.reset()).transpose(1, 2, 0)  # 0 is white
#plt.imshow(result[:, :, -1], cmap='Greys')
# Show stacked frames 1..3 as the three colour channels, inverted.
plt.imshow(1 - result[:, :, 1:4])

## CartPole

In [None]:
class CartpoleWrapper(gym.ObservationWrapper):
    """Observation wrapper that casts every observation to a float32 array."""

    def __init__(self, env):
        super().__init__(env)

    def observation(self, x):
        """Return ``x`` as a NumPy array with dtype float32."""
        return np.array(x, dtype=np.float32)
    
# Select CartPole as the active environment. `env_name` and `env_fn` are
# notebook-level names that the Config cell reads later (args.env_name and
# the *.env_fn fields), so running this cell after the Breakout cell
# replaces the Breakout selection.
env_name = 'CartPole-v1'
env_fn = lambda: CartpoleWrapper(gym.make(env_name))

# Sanity check: build one env and print the first observation and its dtype
# (CartpoleWrapper casts observations to float32).
env = env_fn()
result  = np.array(env.reset())
print(result, result.dtype)

# Pick an Agent
    QLearning
    SAC-discrete, tested on CartPole, run on Breakout
    MBPO

In general, when an algorithm does not work, try a large batch and a low learning rate with few updates

It is okay if the loss of Q increases significantly

# MBPO

## Config

In [None]:
# Hyperparameter configuration for an MBPO run. Five Config objects are filled
# in (algo_args, p_args, q_args, pi_args, agent_args) and nested into `args`
# at the bottom of the cell.
from utils import Config
from models import MLP
from agents import MBPO
"""
    the hyperparameters are the same as MBPO
"""
# algo_args is unpacked into RL(...) in the Run cell via algo_args._toDict().
algo_args = Config()

algo_args.n_warmup=int(5e3)
"""
 rainbow said 2e5 samples or 5e4 updates is typical for Qlearning
 bs256lr3e-4, it takes 2e4updates
 for the model on CartPole to learn done...

 Only 3e5 samples are needed for parameterized input continous motion control
"""
algo_args.replay_size=int(1e6)
algo_args.max_ep_len=500
algo_args.test_interval = int(1e4)
algo_args.seed=0
algo_args.batch_size=256 # the same as MBPO
algo_args.save_interval=600 # in seconds
algo_args.log_interval=int(2e3/200)  # evaluates to 10
algo_args.n_step=int(1e8)

# p_args carries the ensemble / rollout settings (n_p, refresh_interval,
# branch, roll_length); presumably this configures the learned dynamics
# model -- confirm against agents.MBPO.
p_args=Config()
p_args.network = MLP
p_args.activation=torch.nn.ReLU
p_args.lr=3e-4
p_args.sizes = [4, 16, 32, 3] 
p_args.update_interval=1/10
"""
 bs=32 interval=4 from rainbow Q
 MBPO retrains fram scratch periodically
 in principle this can be arbitrarily frequent
"""
p_args.n_p=7 # ensemble
p_args.refresh_interval=int(1e3) # refreshes the model buffer
# ideally rollouts should be used only once
p_args.branch=400
p_args.roll_length=1 # length > 1 not implemented yet

# Q-network settings.
# NOTE(review): the leading 4 in every `sizes` list presumably matches
# CartPole's 4-dimensional observation -- confirm before switching env_fn.
q_args=Config()
q_args.network = MLP
q_args.activation=torch.nn.ReLU
q_args.lr=3e-4
q_args.sizes = [4, 16, 32, 3] # 2 actions, dueling q learning
q_args.update_interval=1/20
# MBPO used 1/40 for continuous control tasks
# 1/20 for invert pendulum

# Policy-network settings.
pi_args=Config()
pi_args.network = MLP
pi_args.activation=torch.nn.ReLU
pi_args.lr=3e-4
pi_args.sizes = [4, 16, 32, 2] 
pi_args.update_interval=1/20

# Agent-level settings; the three network configs are attached further down.
agent_args=Config()
agent_args.agent=MBPO
agent_args.gamma=0.99
agent_args.alpha=0.2 
agent_args.target_sync_rate=5e-3
# called tau in MBPO
# sync rate per update = update interval/target sync interval

# Top-level config handed to Logger(args) in the Run cell.
args = Config()
args.env_name=env_name
# NOTE(review): agent_args.agent is the MBPO object imported from `agents`,
# so this interpolates str(MBPO) rather than a short name -- confirm that
# this is the intended run name.
args.name=f"{args.env_name}_{agent_args.agent}"
# Passed to RL(...) as `device=` in the Run cell; presumably a CUDA device
# index -- TODO confirm.
device = 0

# `env_fn` comes from whichever environment cell was executed last.
q_args.env_fn = env_fn
agent_args.env_fn = env_fn
algo_args.env_fn = env_fn

# Nest the configs: p/q/pi -> agent_args -> algo_args -> args.
agent_args.p_args = p_args
agent_args.q_args = q_args
agent_args.pi_args = pi_args
algo_args.agent_args = agent_args
args.algo_args = algo_args # do not call toDict() before config is set

# refresh_interval / q update_interval * batch_size / replay_size; with the
# values set above this is roughly 5.12.
print(f"rollout reuse:{(p_args.refresh_interval/q_args.update_interval*algo_args.batch_size)/algo_args.replay_size}")
# each generated data will be used so many times
# each generated data will be used so many times

## Run

In [None]:
from algorithm import RL
from utils import Logger

# Build the Logger from the full config and call RL with every field of
# algo_args as a keyword argument. Relies on `args`, `algo_args` and `device`
# from the Config cell. Judging by the progress-bar output captured for this
# cell, the call itself appears to run the training loop -- confirm in
# algorithm.RL.
RL(logger = Logger(args), device=device, **algo_args._toDict())


  0%|          | 44/1000000 [00:55<814:18:32,  2.93s/it][A
  0%|          | 45/1000000 [00:58<821:40:51,  2.96s/it][A
  0%|          | 46/1000000 [01:01<827:32:41,  2.98s/it][A
  0%|          | 47/1000000 [01:04<834:00:36,  3.00s/it][A
  0%|          | 48/1000000 [01:07<837:44:33,  3.02s/it][A
  0%|          | 49/1000000 [01:10<840:28:17,  3.03s/it][A
  0%|          | 50/1000000 [01:13<842:28:30,  3.03s/it][A
  0%|          | 51/1000000 [01:16<844:02:53,  3.04s/it][A
  0%|          | 52/1000000 [01:20<844:58:56,  3.04s/it][A
  0%|          | 53/1000000 [01:23<850:17:16,  3.06s/it][A
  0%|          | 54/1000000 [01:26<850:26:32,  3.06s/it][A
  0%|          | 55/1000000 [01:29<851:05:40,  3.06s/it][A
  0%|          | 56/1000000 [01:32<849:46:09,  3.06s/it][A
  0%|          | 57/1000000 [01:35<846:54:39,  3.05s/it][A
  0%|          | 58/1000000 [01:38<843:41:08,  3.04s/it][A
  0%|          | 59/1000000 [01:41<840:52:29,  3.03s/it][A
  0%|          | 60/1000000 [01:44<840:

In [None]:
from time import time

In [None]:
# Scratch cell. `time` here is the function bound by `from time import time`
# in the cell above, which rebinds the name that the notebook's first import
# cell bound to the `time` module.
time()

# Test

In [None]:
import gym

# Roll out a single episode (capped at 2000 steps) and report its length and
# total reward.
# NOTE(review): `model` is not defined in any visible cell; it has to exist in
# the kernel before this cell runs. The env is created with gym.make directly,
# i.e. without the wrapper that `env_fn` applies -- confirm that is intended.
env = gym.make(env_name)
state = env.reset()

total = 0
# Count steps from 1 so the printed length is the number of steps taken, and
# use a real name instead of `_` so IPython's last-output variable is not
# shadowed.
for step in range(1, 2001):
    tmp = torch.tensor(state).float()
    action = model.act(tmp)
    # action = env.action_space.sample()
    state, reward, done, info = env.step(action)
    total += reward
    if done:
        print(f"episode len {step}, reward {total}")
        break

# Visualization

In [None]:
import gym
from IPython import display
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

img = plt.imshow(env.render(mode='rgb_array')) # only call this once
total = 0
for _ in range(2000):
    img.set_data(env.render(mode='rgb_array')) # just update the data
    display.display(plt.gcf())
    display.clear_output(wait=True)
    tmp = torch.as_tensor(state,  dtype=torch.float).to(device)
    action = model.act(tmp, deterministic=False)
   # action = env.action_space.sample()
    state, reward, done, info  = env.step(action)
    total += reward
    if done:
        print(f"episode len {_}, reward {total}")
        break

## Human Control

In [None]:
import gym
from IPython import display
import matplotlib


import matplotlib.pyplot as plt
%matplotlib inline

img = plt.imshow(env.reset()[-1]) # only call this once
total = 0
for _ in range(2000):
    tmp = input()
    if len(tmp) == 0:
        tmp = "0"
    action = int(tmp)
    state, reward, done, info  = env.step(action)
    total += reward
    
    display.clear_output(wait=True)
    img.set_data(state[-1]) # just update the data
    display.display(plt.gcf())
    

    print(f"this: {reward}, total: {total}")
    if done:
        print(f"episode len {_}, reward {total}")
        break

In [None]:
# Scratch cell: keep a reference to the current env's observation space.
x = env.observation_space

In [None]:
# Scratch cell: show the instance attributes of the current env's action space.
vars(env.action_space)

In [None]:
# NOTE(review): scratch cell. Conv2d requires in_channels, out_channels and
# kernel_size, so this call raises TypeError and breaks Restart & Run All.
torch.nn.Conv2d()

In [None]:
def func(x, **kwargs):
    """Scratch helper: accept ``x`` plus any keyword arguments and return 0."""
    del x, kwargs  # the inputs are deliberately ignored
    return 0

In [None]:
# Scratch cell: pass the positional parameter `x` through dict unpacking.
func(**{'x':1})

In [None]:
# Scratch cell: display the current env's observation space.
env.observation_space

In [None]:
# NOTE(review): `run` is not defined in any visible cell; presumably a wandb
# run object left in the kernel -- confirm, otherwise this raises NameError
# on a fresh kernel.
run.finish()

In [None]:
# NOTE(review): scratch cell. torch.save requires the object to save and a
# destination, so this call raises TypeError and breaks Restart & Run All.
torch.save()