In [None]:
%load_ext autoreload
%autoreload 2

import argparse
import numpy as np
import torch
import os
import time
import math
import matplotlib.pyplot as plt

from run import PongRunner, default_args
from environment import PongEnv
from visualizers.visualize import render_layout
from modular_baselines.vca.algorithm import GradNormalizer

# Build an argument parser and hand it to run.default_args (presumably to register
# the project's default command-line options on it — TODO confirm).
# NOTE(review): `parser` is never parsed in the visible cells; the settings actually
# used are the hand-written `args` dict in the next cell.
parser = argparse.ArgumentParser(description="Interactive Pong VCA")
default_args(parser)                        

In [None]:
# Run settings handed to PongRunner and to runner.algo_generator in the next cell.
args = dict(
    # Data collection / training length ("total_timesteps" is passed to learn()).
    buffer_size=100000,
    batchsize=32,
    rollout_len=10,
    total_timesteps=30000,
    # Algorithm switches and coefficients.
    entropy_coef=0.0003,
    use_gumbel=False,
    grad_norm=False,
    # Network sizes.
    policy_hidden_size=64,
    transition_hidden_size=64,
    # Optimisation settings.
    policy_tau=1,
    policy_lr=0.003,
    trans_lr=0.001,
    device="cpu",
    # Logging / evaluation cadence.
    log_interval=245,
    eval_interval=49,
    trans_weight_decay=0.1,
    # None means "draw a random seed" (done in the next cell).
    seed=None,
    # Base directory; the next cell appends the runner's prefix to it.
    log_dir="logs/",
)

In [None]:
# Draw a random seed when none was configured, and store it in `args` so the
# value that was actually used is available afterwards.
if args["seed"] is None:
    args["seed"] = np.random.randint(0, 2**20)

runner = PongRunner(args)
# NOTE(review): this overwrites args["log_dir"] in place, so re-running this cell
# without re-running the `args` cell joins a prefix onto an already-prefixed path.
args["log_dir"] = os.path.join(args["log_dir"], runner.log_dir_prefix)

# Build the algorithm from the settings and train it for the configured number of steps.
algorithm = runner.algo_generator(args)
algorithm.learn(args["total_timesteps"])

In [None]:
# Environment instance that episode_grad (defined below) steps and renders.
env = PongEnv()
# Target frame rate: episode_grad uses 1/FPS as the per-step frame time when rendering.
FPS = 30
# presumably creates the figure that env.render() draws into — TODO confirm
env.make_figure()

In [None]:
def to_torch(state):
    """Wrap a NumPy observation in a tensor with a leading axis of size 1.

    The tensor is created with ``torch.from_numpy``, so it keeps the array's
    dtype and shares its memory.
    """
    as_tensor = torch.from_numpy(state)
    return as_tensor[None]

def episode_grad(random_act=False, render=True):
    """Play one episode and backpropagate the summed model-predicted reward.

    Each step advances the real environment with the chosen action and also
    evaluates the learned transition model on the same (state, action) pair.
    The model output is passed through ``transition_module.reparam`` together
    with the environment's next state (presumably a reparameterisation that
    keeps the observed value but routes gradients through the model — TODO
    confirm), and ``algorithm._expected_reward`` is evaluated on the result.
    After the episode ends, the sum of these expected rewards is
    backpropagated, which fills ``.grad`` on every recorded state and action.

    Every step starts from a detached copy of the state, so each recorded
    state is its own leaf and the backward pass does not flow from one step
    into the previous one.

    Reads the notebook globals ``env``, ``algorithm`` and ``FPS``.

    Args:
        random_act: If True, sample actions from ``env.action_space``;
            otherwise take the argmax of the policy network output.
        render: If True, call ``env.render()`` every step and sleep so that a
            step takes at least ``1 / FPS`` seconds.

    Returns:
        A tuple of five lists with one entry per step:
        leaf state tensors, leaf one-hot action tensors, expected-reward
        tensors, detached ``|dist.loc - next_state|`` tensors, and detached
        policy-network outputs.
    """
    r_state_list = []
    r_action_list = []
    reward_list = []
    transition_mae = []
    action_logits = []

    frame_time = 1 / FPS
    state = to_torch(env.reset())

    done = False
    while not done:
        start_time = time.time()

        # Cut the graph to the previous step and make this step's state a leaf
        # so that its gradient is stored in ``state.grad`` after backward().
        state = state.detach().requires_grad_(True)

        # One forward pass serves both the logged logits and the greedy action;
        # no graph is needed because the result is only logged or arg-maxed.
        with torch.no_grad():
            logits = algorithm.policy_module.net(state)
        action_logits.append(logits)

        if random_act:
            act = env.action_space.sample()
        else:
            act = logits.argmax(1).item()

        action = algorithm._action_onehot(torch.tensor(act).reshape(1, 1))
        action.requires_grad = True

        next_state, _reward, done, _ = env.step(act)
        next_state = to_torch(next_state)

        r_state_list.append(state)
        r_action_list.append(action)

        # Optional experiment: state = GradNormalizer.apply(state)

        dist = algorithm.transition_module.dist(state, action)
        transition_mae.append((dist.loc - next_state).abs().detach())
        next_state = algorithm.transition_module.reparam(next_state, dist)
        expected_reward = algorithm._expected_reward(
            env.reward_info(), next_state)
        reward_list.append(expected_reward)

        state = next_state
        if render:
            env.render()
            # Sleep only for what is left of the frame; a step that already
            # took longer than the frame time must not sleep at all
            # (time.sleep rejects negative values, hence the clamp).
            elapsed = time.time() - start_time
            time.sleep(max(0.0, frame_time - elapsed))

    sum(reward_list).backward()
    # Alternative: backpropagate only the final step with reward_list[-1].backward()

    return r_state_list, r_action_list, reward_list, transition_mae, action_logits

# Run one greedy, rendered episode, then stack the per-step records along dim 0
# so that every row of the resulting tensors corresponds to one step.
states, actions, rewards, trans_mae, act_logits = episode_grad()
eps_act_logit = torch.cat(act_logits, dim=0)
eps_trans_mae = torch.cat(trans_mae, dim=0)
eps_acts_grad = torch.cat([step_action.grad for step_action in actions], dim=0)
eps_state_grad = torch.cat([step_state.grad for step_state in states], dim=0)



In [None]:
# One heat-map per per-step quantity gathered from the episode above.
# The tensors were concatenated along dim 0, so each image row is one step.
episode_heatmaps = [
    ("Action Gradients", eps_acts_grad),
    ("State Gradients", eps_state_grad),
    ("Transition MAE", eps_trans_mae),
    ("Action Logits", eps_act_logit),
]

fig, axes = plt.subplots(1, len(episode_heatmaps), figsize=(16, 7))
for ax, (panel_title, panel_values) in zip(axes, episode_heatmaps):
    image = ax.imshow(panel_values)
    ax.set_title(panel_title)
    ax.set_ylabel("Episode step")
    fig.colorbar(image, ax=ax)
fig.tight_layout()

# Scratch checks kept from exploration:
# [s[0, 3].item() - s[0, 0].item() for s in states]
# (eps_state_grad[:, 0] > 0).float().mean()

In [None]:
# Show the logs of the run trained above; args["log_dir"] was extended with the
# runner's prefix in the training cell.
# NOTE(review): the meaning of the "S" / "H" layout codes is defined inside
# render_layout and is not visible from this notebook.
render_layout(
    log_dir=args["log_dir"],
    layout=[["S", "S"], ["H", "H"]]
)

# Jacobian of the learned transition model

In [None]:
# Names of the state features, in the order used for the tick labels below.
STATE_LABELS = ["player_y", "enemy_y", "ball_x", "ball_y", "ball_x_prime", "ball_y_prime"]
# How many steps before the end of the episode the inspected state is taken from.
LOOKBACK_STEPS = 40
# Index of the action whose one-hot encoding is fed to the transition model.
ACTION_INDEX = 0

# Take the state recorded LOOKBACK_STEPS before the end of the episode; for a
# shorter episode fall back to its first state instead of raising IndexError.
state = states[-min(LOOKBACK_STEPS, len(states))].detach()
state.requires_grad = True

act = algorithm._action_onehot(torch.tensor(ACTION_INDEX).reshape(1, 1))
act.requires_grad = True

n_state = env.observation_space.shape[0]
# Width of the one-hot action instead of a hard-coded 3.
n_action = act.shape[-1]

# Row ix holds d mean[ix] / d input.
jac_state = torch.zeros((n_state, n_state))
jac_action = torch.zeros((n_state, n_action))

mean, _std = algorithm.transition_module(state, act)
for ix in range(n_state):
    # autograd.grad returns the gradients directly, so nothing accumulates in
    # .grad and no zeroing is needed; allow_unused yields None (kept as a zero
    # row) when an output component does not depend on one of the inputs.
    grad_state, grad_act = torch.autograd.grad(
        mean[0, ix], (state, act), retain_graph=True, allow_unused=True)
    if grad_state is not None:
        jac_state[ix, :] = grad_state[0]
    if grad_act is not None:
        jac_action[ix, :] = grad_act[0]

tick_positions = np.arange(len(STATE_LABELS))

plt.figure(figsize=(15, 5))

plt.subplot(131)
plt.imshow(jac_state)
plt.colorbar()
plt.yticks(tick_positions, STATE_LABELS)
plt.xticks(tick_positions, STATE_LABELS, rotation=90)
plt.title("State Jac")

plt.subplot(132)
plt.imshow(jac_action)
plt.yticks(tick_positions, STATE_LABELS)
plt.colorbar()
plt.title("Action Jac")

plt.subplot(133)
plt.bar(tick_positions, state[0].detach().numpy())
plt.xticks(tick_positions, STATE_LABELS, rotation=90)
plt.title("state");


In [None]:
# Probe the reward model: gradient of the expected reward with respect to a
# hand-picked next state. The last line displays that gradient.
__next_state = torch.tensor(
    [[0.1, 0.1, 0.1999, 0.1, 0.1, 0.1]], requires_grad=True)
probe_reward = algorithm._expected_reward(algorithm.reward_vals, __next_state)
probe_reward.sum().backward()
__next_state.grad

In [None]:
# Log directory of a previously recorded run. The default is a path on the
# original author's machine; set PONG_VCA_LOG_DIR to point at a run of your own
# without editing the notebook.
pathname = os.environ.get(
    "PONG_VCA_LOG_DIR",
    "/Users/tolga/remote/research/tolga/Modular-Baselines/modular_baselines/vca/pong/logs/01-19-2021-01-31-55/01-19-2021-01-31-55-248175/1",
)

In [None]:
# Same layout as the earlier render_layout call, but reading the logs from
# `pathname` (set in the previous cell) instead of the run trained in this notebook.
render_layout(
    log_dir=pathname,
    layout=[["S", "S"], ["H", "H"]]
)

## Notes
- More state features & A2C training
- Pong hyperparameter optimization
- Maze comparison & experiments

In [None]:
# Display the action names reported by the wrapped Pong environment.
# NOTE(review): PongEnv is already imported in the first cell, and this rebinds the
# global `env` to a new instance; make_figure() is not called on it in this cell.
from environment import PongEnv
env = PongEnv()
env.pong_env.unwrapped.get_action_meanings()
