In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import random
import pdb
import torch
from torch.optim import Adam
import gym
import time
import wandb

# Pick an Environment
CartPole, the MNIST of RL, discrete action, parameterized state

Breakout, discrete action, 0 stay (no-op), 1 fire, 2 right, 3 left
    Notice that Breakout-ram is hard because the state is 128 bytes from the RAM
    and the bytes do not have an intuitive meaning
    also notice that 50M samples is typical for DQNs with visual input (refer to Rainbow)

MountainCar, discrete or continuous action, parameterized state
Gets reward for climbing up a hill, which costs energy,
so painful exploration is essential

Walker, penalty -100 for falling. 
The initial greedy strategy may make the agent stand unmoved and prevent falling
As a result, initial test reward =0 while initial train reward=100

For environments with a large penalty, we should use a large batch when updating Q, in order to compensate for the variance

In [None]:
from gym.wrappers import FrameStack

class BreakoutWrapper(gym.ObservationWrapper):
    """ 
    takes (210, 160, 3) to (40, 40)
        converts to float in [0, 1]
        cuts the margins (rows [30:-17], 163 rows remain)
        converts to grey scale (mean over the colour axis)
        4x4 average pooling with stride 4 -> (40, 40)

    stops training when one life is lost (step() reports done=True)
    
    fires the ball by pressing action 1 (done once, inside reset())
    
    wrapped by framestack (not wrapping FrameStack) to utilize lazy frame for memory saving
    """
    def __init__(self, env):
        super().__init__(env)
        # The declared space must describe what observation() really returns:
        # (210 - 30 - 17) // 4 = 40 rows and 160 // 4 = 40 columns.
        # Wrappers stacked on top (e.g. FrameStack) derive their own space from it.
        self.observation_space = gym.spaces.Box(0, 1, (40, 40), dtype=np.float32)
        self.pooling = torch.nn.AvgPool2d(kernel_size=(4, 4), stride=(4, 4))
        # refreshed by reset(); defined here so the attribute always exists
        self.lives = 0

    def step(self, action):
        """Steps the wrapped env; additionally sets done=True when the life counter drops."""
        obs, reward, done, info = self.env.step(action)
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives:
            done = True
        self.lives = lives
        return self.observation(obs), reward, done, info

    def reset(self, **kwargs):
        """Resets the wrapped env, records the life counter and returns the
        processed observation that follows one step with action 1."""
        self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()
        obs, _, _, _ = self.step(1)
        return obs

    def observation(self, observation):
        """Turns one raw frame into the pooled grey-scale float32 array."""
        frame = np.asarray(observation, dtype=np.float32) / 255.0
        frame = frame[30:-17]  # cut the top and bottom margins
        frame = frame.mean(axis=2)  # greyscale
        # AvgPool2d wants a leading channel axis; add it, pool, drop it again
        pooled = self.pooling(torch.as_tensor(frame).unsqueeze(0)).squeeze(0)
        return pooled.numpy()

## Breakout

In [None]:
from matplotlib import pyplot as plt
# Build the frame-stacked Breakout environment and preview its first observation.
env_name = "Breakout-v0"
env_fn = lambda: FrameStack(BreakoutWrapper(gym.make(env_name)), 4)

env = env_fn()
# reset() hands back the stacked frames; materialise them and move the
# stacking axis last so the slices below index frames
result = np.transpose(np.array(np.array(env.reset())), (1, 2, 0))  # 0 is white
# plt.imshow(result[..., -1], cmap='Greys')
# frames 1..3 are shown together as the three colour channels of one image
plt.imshow(1 - result[..., 1:4])

## CartPole

In [None]:
class CartpoleWrapper(gym.ObservationWrapper):
    """Observation wrapper that returns every observation as a new float32 numpy array."""

    def __init__(self, env):
        super().__init__(env)

    def observation(self, x):
        """Copies the observation into a float32 array."""
        return np.array(x, dtype=np.float32)
    
# Select CartPole: from here on env_name / env_fn refer to this environment.
env_name = "CartPole-v1"
env_fn = lambda: CartpoleWrapper(gym.make(env_name))

# sanity check: show the first observation and its dtype
env = env_fn()
result = np.array(env.reset())
print(result, result.dtype)

# Pick an Agent
    DQN, tested on CartPole
    PPO, tested on Breakout
    SAC-continuous, run on Walker, alpha = 0.2 is too large, 0.05 still large
    SAC-discrete, tested on CartPole, run on Breakout
        1. must use eps to prevent nan because the probability of some action becomes 0, this happens when Q is large (e.g. a few hundred)
        2. If n_update or lr is much too high, entropy may collapse for a reasonable alpha, with a stable high policy regret
        3. Sometimes Q stays much lower than ground truth, and learns very slowly. I believe there is a trick I need to incorporate
        4. when alpha is too large, maximum entropy, test is significantly better than train. when too small, zero entropy. 
        A heuristic: alpha leads to an additional reward of about alpha*entropy, which should be smaller than the reward per step.
        5. on Breakout, 3 happens for Q, pi learns to never fire the ball!? randomly picking an action from the other three...

In general, when an algo does not work, try large batch low lr with few updates

# DQN

## Config

In [None]:
from utils import Config
from models import CNN
from agents import QLearning
"""
    update_interval, save_interval, etc are counted per sample
    "epoch" is only used for logging 
    
    the configs are the same as rainbow,
    batchsize *8, lr * 4, update frequency/ 8
    no noisy q and therefore eps of 3e-2
"""
# --- training-loop settings (intervals are counted in samples) ---
algo_args = Config()

algo_args.max_ep_len = 2000
algo_args.q_update_interval = 32
algo_args.batch_size = 256
algo_args.n_warmup = 200_000
algo_args.replay_size = 1_000_000
algo_args.test_interval = 10_000
algo_args.seed = 0
algo_args.save_interval = 1_000_000
algo_args.log_interval = 100_000
algo_args.n_step = 100_000_000

# --- agent settings ---
agent_args = Config()
agent_args.agent = QLearning
agent_args.gamma = 0.99
agent_args.eps = 3e-2
# sync rate per update = update interval / target sync interval (32000 samples)
agent_args.target_sync_rate = algo_args.q_update_interval / 32000

# --- Q network: six stride-2 3x3 convolutions ---
q_args = Config()
q_args.network = CNN
q_args.activation = torch.nn.ReLU
q_args.lr = 2e-4
q_args.strides = [2] * 6
q_args.kernels = [3] * 6
q_args.paddings = [1] * 6
q_args.sizes = [4, 16, 32, 64, 128, 128, 5]  # 4 actions, dueling q learning

# --- run-level settings ---
args = Config()
args.env_name = "Breakout-v0"
# NOTE(review): agent_args.agent was set to the imported QLearning object above,
# so its str() is what presumably ends up in the name -- confirm this is intended
args.name = f"{args.env_name}_{agent_args.agent}"
device = 0

# nest the configs: q_args -> agent_args -> algo_args -> args
agent_args.q_args = q_args
algo_args.agent_args = agent_args
args.algo_args = algo_args  # do not call toDict() before config is set

## Run

In [None]:
from algorithm import RL
# Logger is used below; import it in this cell so the cell also runs on a fresh
# kernel instead of depending on the import made in the later MBPO "Run" cell
from utils import Logger

# builds a Logger from `args` and passes it, the device and the keyword
# arguments produced by algo_args._toDict() to RL
RL(logger = Logger(args), device=device, **algo_args._toDict())

# MBPO

## Config

In [None]:
from utils import Config
from models import MLP
from agents import MBPO
"""
    the hyperparameters are the same as MBPO
"""
# --- training-loop settings ---
algo_args = Config()

algo_args.n_warmup = 5_000
# rainbow said 2e5 samples or 5e4 updates is typical for Q-learning;
# with bs256 lr3e-4 it takes 2e4 updates
# for the model on CartPole to learn done...
#
# Only 3e5 samples are needed for parameterized-input continuous motion control
algo_args.replay_size = 1_000_000
algo_args.max_ep_len = 500
algo_args.test_interval = 10_000
algo_args.seed = 0
algo_args.batch_size = 256  # the same as MBPO
algo_args.save_interval = 600  # in seconds
algo_args.log_interval = int(2e3 / 200)
algo_args.n_step = 100_000_000

# --- dynamics model (p) ---
p_args = Config()
p_args.network = MLP
p_args.activation = torch.nn.ReLU
p_args.lr = 3e-4
p_args.sizes = [4, 16, 32, 3]
p_args.update_interval = 1 / 10
# bs=32 interval=4 from rainbow Q
# MBPO retrains from scratch periodically;
# in principle this can be arbitrarily frequent
p_args.n_p = 7  # ensemble
p_args.refresh_interval = 1_000  # refreshes the model buffer
# ideally rollouts should be used only once
p_args.branch = 400
p_args.roll_length = 1  # length > 1 not implemented yet

# --- Q network ---
q_args = Config()
q_args.network = MLP
q_args.activation = torch.nn.ReLU
q_args.lr = 3e-4
q_args.sizes = [4, 16, 32, 3]  # 2 actions, dueling q learning
q_args.update_interval = 1 / 20
# MBPO used 1/40 for continuous control tasks
# 1/20 for invert pendulum

# --- policy network ---
pi_args = Config()
pi_args.network = MLP
pi_args.activation = torch.nn.ReLU
pi_args.lr = 3e-4
pi_args.sizes = [4, 16, 32, 2]
pi_args.update_interval = 1 / 20

# --- agent settings ---
agent_args = Config()
agent_args.agent = MBPO
agent_args.gamma = 0.99
agent_args.alpha = 0.2
agent_args.target_sync_rate = 5e-3
# called tau in MBPO
# sync rate per update = update interval/target sync interval

# --- run-level settings ---
args = Config()
# env_name and env_fn are the globals left by whichever environment cell ran last
args.env_name = env_name
args.name = f"{args.env_name}_{agent_args.agent}"
device = 0

# the same environment factory is handed to every level that reads env_fn
q_args.env_fn = env_fn
agent_args.env_fn = env_fn
algo_args.env_fn = env_fn

# nest the configs: p/q/pi -> agent_args -> algo_args -> args
agent_args.p_args = p_args
agent_args.q_args = q_args
agent_args.pi_args = pi_args
algo_args.agent_args = agent_args
args.algo_args = algo_args  # do not call toDict() before config is set

rollout_reuse = (
    p_args.refresh_interval / q_args.update_interval * algo_args.batch_size
) / algo_args.replay_size
print(f"rollout reuse:{rollout_reuse}")
# each generated data will be used so many times

## Run

In [None]:
from algorithm import RL
from utils import Logger

# Build a Logger from `args`, then pass it, the device and the keyword
# arguments produced by algo_args._toDict() to RL.
RL(
    logger=Logger(args),
    device=device,
    **algo_args._toDict()
)







  0%|          | 212/1000000 [08:21<848:20:33,  3.05s/it][A[A[A[A[A[A





  0%|          | 213/1000000 [08:24<847:28:28,  3.05s/it][A[A[A[A[A[A





  0%|          | 214/1000000 [08:27<849:28:07,  3.06s/it][A[A[A[A[A[A





  0%|          | 215/1000000 [08:30<848:51:04,  3.06s/it][A[A[A[A[A[A





  0%|          | 216/1000000 [08:33<849:29:06,  3.06s/it][A[A[A[A[A[A





  0%|          | 217/1000000 [08:36<848:26:08,  3.06s/it][A[A[A[A[A[A





  0%|          | 218/1000000 [08:39<850:32:30,  3.06s/it][A[A[A[A[A[A





  0%|          | 219/1000000 [08:42<850:36:55,  3.06s/it][A[A[A[A[A[A





  0%|          | 220/1000000 [08:46<851:35:48,  3.07s/it][A[A[A[A[A[A





  0%|          | 221/1000000 [08:49<849:26:50,  3.06s/it][A[A[A[A[A[A





  0%|          | 222/1000000 [08:52<849:30:06,  3.06s/it][A[A[A[A[A[A





  0%|          | 223/1000000 [08:55<849:00:44,  3.06s/it][A[A[A[A[A[A





  0%|          | 224/1

checkpoint save as 2443.pt








  0%|          | 245/1000000 [10:02<851:25:06,  3.07s/it][A[A[A[A[A[A





  0%|          | 246/1000000 [10:05<849:37:35,  3.06s/it][A[A[A[A[A[A





  0%|          | 247/1000000 [10:08<851:38:47,  3.07s/it][A[A[A[A[A[A





  0%|          | 248/1000000 [10:11<851:23:19,  3.07s/it][A[A[A[A[A[A





  0%|          | 249/1000000 [10:14<850:16:01,  3.06s/it][A[A[A[A[A[A





  0%|          | 250/1000000 [10:17<849:34:48,  3.06s/it][A[A[A[A[A[A





  0%|          | 251/1000000 [10:20<851:27:52,  3.07s/it][A[A[A[A[A[A





  0%|          | 252/1000000 [10:24<851:10:17,  3.06s/it][A[A[A[A[A[A





  0%|          | 253/1000000 [10:27<851:37:58,  3.07s/it][A[A[A[A[A[A





  0%|          | 254/1000000 [10:30<852:51:59,  3.07s/it][A[A[A[A[A[A





  0%|          | 255/1000000 [10:33<852:31:53,  3.07s/it][A[A[A[A[A[A





  0%|          | 256/1000000 [10:36<850:16:37,  3.06s/it][A[A[A[A[A[A





  0%|          | 257/1

checkpoint save as 4569.pt








  0%|          | 457/1000000 [20:00<845:56:08,  3.05s/it][A[A[A[A[A[A





  0%|          | 458/1000000 [20:03<845:11:29,  3.04s/it][A[A[A[A[A[A





  0%|          | 459/1000000 [20:06<843:33:23,  3.04s/it][A[A[A[A[A[A





  0%|          | 460/1000000 [20:09<840:57:18,  3.03s/it][A[A[A[A[A[A





  0%|          | 461/1000000 [20:12<841:41:19,  3.03s/it][A[A[A[A[A[A





  0%|          | 462/1000000 [20:15<840:09:59,  3.03s/it][A[A[A[A[A[A





  0%|          | 463/1000000 [20:18<841:11:12,  3.03s/it][A[A[A[A[A[A





  0%|          | 464/1000000 [20:22<846:46:17,  3.05s/it][A[A[A[A[A[A





  0%|          | 465/1000000 [20:25<847:56:45,  3.05s/it][A[A[A[A[A[A





  0%|          | 466/1000000 [20:28<851:13:42,  3.07s/it][A[A[A[A[A[A





  0%|          | 467/1000000 [20:31<852:42:34,  3.07s/it][A[A[A[A[A[A





  0%|          | 468/1000000 [20:34<853:51:10,  3.08s/it][A[A[A[A[A[A





  0%|          | 469/1

checkpoint save as 5647.pt








  0%|          | 565/1000000 [30:03<1939:15:39,  6.99s/it][A[A[A[A[A[A

In [None]:
from time import time

In [None]:
# scratch: calls the `time` function imported in the previous cell.
# Note that `from time import time` rebinds the name `time`, which the first
# import cell bound to the module.
time()

# Test

In [None]:
import gym

# Roll out one episode (at most 2000 steps) and report its length and return.
env = gym.make(env_name)
state = env.reset()

# NOTE(review): `model` is not defined anywhere in this notebook -- presumably
# the trained agent; confirm where it is supposed to come from
total = 0
for step in range(2000):
    obs = torch.tensor(state).float()
    action = model.act(obs)
   # action = env.action_space.sample()
    state, reward, done, info  = env.step(action)
    total += reward
    if done:
        # step is 0-based, so the episode lasted step + 1 environment steps
        print(f"episode len {step + 1}, reward {total}")
        break

# Visualization

In [None]:
import gym
from IPython import display
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

# Render one episode (at most 2000 steps) inline while the model acts.
# Start from a fresh episode: without this reset the loop would reuse the
# `state` (and possibly an already finished episode) left by an earlier cell.
state = env.reset()
img = plt.imshow(env.render(mode='rgb_array')) # only call this once
total = 0
for step in range(2000):
    img.set_data(env.render(mode='rgb_array')) # just update the data
    display.display(plt.gcf())
    display.clear_output(wait=True)
    obs = torch.as_tensor(state,  dtype=torch.float).to(device)
    # NOTE(review): `model` is not defined anywhere in this notebook -- confirm its origin
    action = model.act(obs, deterministic=False)
   # action = env.action_space.sample()
    state, reward, done, info  = env.step(action)
    total += reward
    if done:
        # step is 0-based, so the episode lasted step + 1 environment steps
        print(f"episode len {step + 1}, reward {total}")
        break

## Human Control

In [None]:
import gym
from IPython import display
import matplotlib


import matplotlib.pyplot as plt
%matplotlib inline

# Play one episode by hand: type an action id and press enter at every step.
# assumes env yields stacked frames so that [-1] is a single image -- TODO confirm
img = plt.imshow(env.reset()[-1]) # only call this once
total = 0
for step in range(2000):
    raw = input()
    try:
        # empty input means action 0
        action = int(raw) if len(raw) > 0 else 0
    except ValueError:
        # a typo should not abort the whole session: fall back to action 0 as well
        action = 0
    state, reward, done, info  = env.step(action)
    total += reward
    
    display.clear_output(wait=True)
    img.set_data(state[-1]) # just update the data
    display.display(plt.gcf())
    

    print(f"this: {reward}, total: {total}")
    if done:
        # step is 0-based, so the episode lasted step + 1 environment steps
        print(f"episode len {step + 1}, reward {total}")
        break

In [None]:
# scratch: keep a handle on the current env's observation space for inspection
x = env.observation_space

In [None]:
# scratch: show the instance attributes of the current env's action space
vars(env.action_space)

In [None]:
# NOTE(review): Conv2d is constructed without arguments; its constructor requires
# in_channels, out_channels and kernel_size, so this cell raises TypeError.
# Presumably a scratch cell used to look up the signature -- confirm or remove.
torch.nn.Conv2d()

In [None]:
def func(x, **kwargs):
    """Accept one argument plus arbitrary keyword arguments and always return 0."""
    constant_result = 0
    return constant_result

In [None]:
# scratch check: unpacking the dict supplies x as a keyword argument
func(**{'x':1})

In [None]:
# scratch: display the current env's observation space
env.observation_space

In [None]:
# NOTE(review): `run` is not defined anywhere in this notebook -- presumably a
# wandb run object created elsewhere; confirm before relying on this cell
run.finish()

In [None]:
# NOTE(review): torch.save is called without arguments; it needs at least the
# object to save and a destination, so this cell raises TypeError.
# Presumably a scratch cell used to look up the signature -- confirm or remove.
torch.save()