In [None]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
import os
import json
import datetime
from collections import namedtuple

from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env

from modular_baselines.loggers.basic import(InitLogCallback,
                                            LogRolloutCallback,
                                            LogWeightCallback,
                                            LogGradCallback)

from modular_baselines.vca.algorithm import DiscerteStateVCA
from modular_baselines.vca.buffer import Buffer
from modular_baselines.vca.collector import NStepCollector
from modular_baselines.vca.modules import (CategoricalPolicyModule,
                     CategoricalTransitionModule,
                     CategoricalRewardModule)
from environment import MazeEnv


In [None]:
# Timestamp taken once per run; it is embedded in the log directory name below.
now = datetime.datetime.now().strftime("%m-%d-%Y-%H-%M-%S")

# Hyperparameters are first collected in a plain mapping ...
args = {
    # sizes
    "state_size": 11,
    "buffer_size": 50000,
    # policy module
    "policy_hidden_size": 32,
    "policy_tau": 1,
    # transition module
    "transition_hidden_size": 32,
    "transition_module_tau": 1,
    # reward module
    "reward_set": [-1, 0, 1],
    "reward_hidden_size": 16,
    "reward_module_tau": 1,
    # optimisation / rollout settings
    "batchsize": 32,
    "entropy_coef": 0.01,
    "rollout_len": 10,
    "total_timesteps": int(5e4),
    "device": "cpu",
    "log_interval": 95,
    "trans_lr": 3e-3,
    "policy_lr": 3e-3,
    "reward_lr": 1e-3,
    "use_gumbel": False,
    "log_dir": "logs/{}".format(now),
}
# ... and then frozen into an immutable record so the rest of the notebook
# reads them as attributes (e.g. ``args.batchsize``).
args = namedtuple("Args", list(args))(**args)

In [None]:
# A standalone environment is kept next to the vectorized one; in this cell it
# is only queried for `state_set`, `reward_set` and `expected_reward()`.
env = MazeEnv()
vecenv = make_vec_env(MazeEnv)

# --- logging callbacks -------------------------------------------------------
rollout_callback = LogRolloutCallback()
init_callback = InitLogCallback(args.log_interval, args.log_dir)
weight_callback = LogWeightCallback("weights.json")
grad_callback = LogGradCallback("grads.json")

# --- experience storage ------------------------------------------------------
buffer = Buffer(args.buffer_size, vecenv.observation_space, vecenv.action_space)

# Sizes of the discrete observation / action spaces, shared by all modules.
n_states = vecenv.observation_space.n
n_actions = vecenv.action_space.n

# --- learnable modules (constructed in the same order as before) --------------
policy_m = CategoricalPolicyModule(
    n_states,
    n_actions,
    args.policy_hidden_size,
    tau=args.policy_tau,
    use_gumbel=args.use_gumbel,
)
trans_m = CategoricalTransitionModule(
    n_states,
    n_actions,
    state_set=torch.from_numpy(env.state_set),
    hidden_size=args.transition_hidden_size,
    tau=args.transition_module_tau,
    use_gumbel=args.use_gumbel,
)
reward_m = CategoricalRewardModule(
    n_states,
    env.reward_set,
    args.reward_hidden_size,
    tau=args.reward_module_tau,
)

# --- data collection ---------------------------------------------------------
collector = NStepCollector(
    env=vecenv,
    buffer=buffer,
    policy=policy_m,
    callbacks=[rollout_callback],
)


def _rmsprop(module, lr):
    """Build an RMSprop optimizer over all parameters of `module`."""
    return torch.optim.RMSprop(module.parameters(), lr=lr)


# --- algorithm: ties modules, buffer, collector and optimizers together -------
algorithm = DiscerteStateVCA(
    policy_module=policy_m,
    transition_module=trans_m,
    reward_module=reward_m,
    buffer=buffer,
    collector=collector,
    env=vecenv,
    reward_vals=env.expected_reward(),
    rollout_len=args.rollout_len,
    trans_opt=_rmsprop(trans_m, args.trans_lr),
    policy_opt=_rmsprop(policy_m, args.policy_lr),
    reward_opt=_rmsprop(reward_m, args.reward_lr),
    batch_size=args.batchsize,
    entropy_coef=args.entropy_coef,
    device=args.device,
    callbacks=[init_callback, weight_callback, grad_callback],
)

In [None]:
# Start training; the budget handed to learn() is args.total_timesteps.
# NOTE(review): presumably counted in environment steps — confirm in DiscerteStateVCA.learn.
algorithm.learn(args.total_timesteps)

In [None]:
from visualizers.visualize import render_layout

# Read the log directory from `args.log_dir` (the same value that was handed to
# InitLogCallback) instead of rebuilding the path from `now` a second time, so
# the visualizer and the logger cannot drift apart.
render_layout(
    log_dir=args.log_dir,
    layout=[["S", "S"], ["H", "H"]]
)

In [None]:
import os

# Destination file for the serialized policy weights.
path = "static/policy_m.b"

# Create the directory that `path` actually points into, rather than naming it
# a second time; editing `path` alone now keeps both in sync. The `or "."`
# covers a bare file name with no directory component.
os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

# Save only the state dict (parameter tensors), written through an explicit
# binary file handle.
with open(path, "wb") as bin_file:
    torch.save(policy_m.state_dict(), bin_file)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt


def one_step_grad(init_state=8):
    """Gradient of the one-step transition distribution w.r.t. the action input.

    A batch with one row per action is built: every row uses the one-hot
    encoding of `init_state` as state input and row `a` uses the encoding
    returned by `algorithm._action_onehot` for action `a`. The batch is passed
    through `algorithm.transition_module`, a softmax over dim 1 is applied, and
    for every next-state index the gradient of that probability column with
    respect to the action input is collected.

    The number of actions is read from `env.action_space.n` and the number of
    states from `env.observation_space.n` (both module-level `env`).

    Args:
        init_state: Index of the state to start from.

    Returns:
        Tensor of shape (n_actions, n_states, n_actions). Entry `[a, s, j]` is
        the gradient of the summed column `probs[:, s]` w.r.t. `r_action[a, j]`.
        NOTE(review): this equals d probs[a, s] / d r_action[a, j] only if the
        transition module treats batch rows independently — confirm.

    Side effects:
        Prints the arg-max next state per action. No `.grad` attribute is
        written: gradients are obtained with `torch.autograd.grad`, so the
        transition module's parameter gradients are left untouched.
    """
    n_states = env.observation_space.n
    n_actions = env.action_space.n

    # (n_actions, n_states): each row is the one-hot encoding of init_state.
    r_state = (torch.ones((n_actions, 1)) * init_state) == torch.arange(n_states).reshape(1, -1)
    r_state = r_state.float()

    # One row per action; this is the input we differentiate with respect to.
    action = torch.arange(n_actions, requires_grad=False)
    r_action = algorithm._action_onehot(action.reshape(n_actions, 1))
    r_action.requires_grad = True

    logits = algorithm.transition_module(r_state, r_action)
    print(logits.argmax(1))

    # The softmax does not depend on the loop variable, so compute it once.
    probs = torch.nn.functional.softmax(logits, dim=1)

    jac = torch.zeros((n_actions, n_states, n_actions))
    ones = torch.ones(n_actions)
    for ix in range(n_states):
        # retain_graph: the same graph is differentiated once per column.
        (grad,) = torch.autograd.grad(
            probs[:, ix], r_action, grad_outputs=ones, retain_graph=True
        )
        jac[:, ix, :] = grad
    return jac

# `jax` holds the gradient tensor returned by one_step_grad (it is not the JAX
# library). Axes: (batch action, state size, grad action size)
jax = one_step_grad()

# Heat map of the slice belonging to batch action 1.
fig, ax = plt.subplots(dpi=300)
heatmap = ax.imshow(jax[1])
fig.colorbar(heatmap, ax=ax)

# Index along the last axis with the largest gradient for state index 25.
jax[1, 25, :].argmax()

In [None]:
import matplotlib.pyplot as plt

def get_grad(logit_k=1, n_classes=5, k=3):
    """Gradient reaching the k-th logit through a straight-through one-hot output.

    A logit vector of length `n_classes` is built that is zero everywhere
    except for `logit_k` at position `k`. The output is constructed as
    `one_hot + out - out.detach()` with `out = logit * one_hot * softmax(logit).detach()`,
    so its forward value equals the one-hot vector while the backward pass only
    sees `out`. The output is then back-propagated with a vector of ones.

    Args:
        logit_k: Value placed at position `k` of the logit vector.
        n_classes: Length of the logit vector (default 5, the original size).
        k: Position of the non-zero logit / hot entry (default 3, as before).

    Returns:
        Tuple `(grad_k, prob_k)` of 0-d tensors: the gradient accumulated on
        `logit[k]` and the softmax probability of entry `k`. `prob_k` is
        detached from the autograd graph, so it can be converted to numpy and
        does not keep the graph alive.
    """
    logit = torch.zeros(n_classes)
    logit[k] = logit_k
    logit.requires_grad_(True)
    one_hot = (torch.arange(n_classes) == k).float()

    probs = torch.nn.functional.softmax(logit, dim=0)
    # Surrogate whose derivative w.r.t. logit[k] is the (constant) probability.
    out = logit * one_hot * probs.detach()
    # Straight-through: forward value is exactly one_hot, gradient is d(out).
    out = one_hot + out - out.detach()

    out.backward(torch.ones(n_classes))
    return logit.grad[k], probs[k].detach()


# Values swept for the k-th logit.
x = np.linspace(-20, 20, 1001)

# get_grad returns a (grad, prob) pair per value; transpose the pairs into two
# parallel sequences.
y_grad, probs = zip(*(get_grad(logit_k=ix) for ix in x))

fig, ax = plt.subplots(dpi=300)
ax.plot(x, y_grad, label="grad")
# ax.plot(x, probs, label="prob")
ax.set_title("Softmax Gradient")
ax.set_xlabel("k-th logit")
ax.set_ylabel("k-th logit grad")
ax.legend()