In [None]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
import os
import json
import datetime
from collections import namedtuple

from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env

from modular_baselines.loggers.basic import(InitLogCallback,
                                            LogRolloutCallback,
                                            LogWeightCallback,
                                            LogGradCallback)

from modular_baselines.vca.algorithm import DiscerteStateVCA
from modular_baselines.vca.buffer import Buffer
from modular_baselines.vca.collector import NStepCollector
from modular_baselines.vca.modules import (CategoricalPolicyModule,
                     CategoricalTransitionModule,
                     CategoricalRewardModule)
from environment import WalkEnv


In [None]:
# Timestamp (to the second) used to name this run's log directory.
now = datetime.datetime.now().strftime("%m-%d-%Y-%H-%M-%S")

# Every tunable of the experiment lives in this one mapping. Key order is
# preserved because it becomes the field order of the `Args` tuple below.
hyperparams = {
    # Environment and buffer sizes
    "state_size": 11,
    "buffer_size": 500,
    # Policy module
    "policy_hidden_size": 16,
    "policy_tau": 1,
    # Transition module
    "transition_hidden_size": 16,
    "transition_module_tau": 1,
    # Reward module
    "reward_set": [-1, 0, 1],
    "reward_hidden_size": 16,
    "reward_module_tau": 1,
    # Optimisation / rollout settings
    "batchsize": 32,
    "entropy_coef": 0.01,
    "rollout_len": 5,
    "total_timesteps": 10_000,
    "device": "cpu",
    "log_interval": 100,
    # Per-module learning rates
    "trans_lr": 1e-3,
    "policy_lr": 1e-2,
    "reward_lr": 1e-3,
    # Logging destination, unique per run thanks to the timestamp
    "log_dir": "logs/{}".format(now),
}

# Freeze the settings into an immutable record with attribute access
# (args.state_size, args.log_dir, ...).
Args = namedtuple("Args", list(hyperparams))
args = Args(**hyperparams)

In [None]:
# Two handles on the task: a bare WalkEnv whose `state_set` / `reward_set`
# attributes are read below, and a vectorized wrapper that is handed to the
# buffer, the collector and the algorithm.
env = WalkEnv(args.state_size)
vecenv = make_vec_env(lambda: WalkEnv(args.state_size))
obs_space = vecenv.observation_space
act_space = vecenv.action_space

# Callbacks: one is attached to the collector, the other three to the algorithm.
rollout_callback = LogRolloutCallback()
init_callback = InitLogCallback(args.log_interval, args.log_dir)
weight_callback = LogWeightCallback("weights.json")
grad_callback = LogGradCallback("grads.json")

# Storage shared by the collector and the algorithm.
buffer = Buffer(args.buffer_size, obs_space, act_space)

# The three categorical modules, built in the same order as before
# (policy, transition, reward).
policy_m = CategoricalPolicyModule(
    obs_space.n,
    act_space.n,
    args.policy_hidden_size,
    tau=args.policy_tau,
)
trans_m = CategoricalTransitionModule(
    obs_space.n,
    act_space.n,
    state_set=torch.from_numpy(env.state_set),
    hidden_size=args.transition_hidden_size,
    tau=args.transition_module_tau,
)
# NOTE(review): the config also defines args.reward_set, but the module is fed
# env.reward_set -- confirm the two are meant to agree.
reward_m = CategoricalRewardModule(
    obs_space.n,
    env.reward_set,
    args.reward_hidden_size,
    tau=args.reward_module_tau,
)

collector = NStepCollector(
    env=vecenv,
    buffer=buffer,
    policy=policy_m,
    callbacks=[rollout_callback],
)

# One RMSprop optimizer per module, each with its own learning rate.
trans_opt = torch.optim.RMSprop(trans_m.parameters(), lr=args.trans_lr)
policy_opt = torch.optim.RMSprop(policy_m.parameters(), lr=args.policy_lr)
reward_opt = torch.optim.RMSprop(reward_m.parameters(), lr=args.reward_lr)

algorithm = DiscerteStateVCA(
    policy_module=policy_m,
    transition_module=trans_m,
    reward_module=reward_m,
    buffer=buffer,
    collector=collector,
    env=vecenv,
    rollout_len=args.rollout_len,
    trans_opt=trans_opt,
    policy_opt=policy_opt,
    reward_opt=reward_opt,
    batch_size=args.batchsize,
    entropy_coef=args.entropy_coef,
    device=args.device,
    callbacks=[init_callback, weight_callback, grad_callback],
)

In [None]:
# Start learning; the argument passed is the args.total_timesteps value from the config cell.
algorithm.learn(args.total_timesteps)

In [None]:
from visualizers.visualize import render_layout

# Read from the exact directory that was handed to InitLogCallback
# (args.log_dir) instead of rebuilding the path from `now`, so a change to
# `log_dir` in the config cannot leave the visualizer pointing elsewhere.
# NOTE(review): the meaning of the "S" / "H" layout codes is defined by
# render_layout and is not visible here.
render_layout(
    log_dir=args.log_dir,
    layout=[["S", "S"], ["H", "H"]],
)