In [90]:
import numpy as np
import gym
import torch
import os

In [91]:
# Create the CartPole-v0 environment and reset it once; the initial
# observation returned by the reset is kept in `obs`.
env = gym.make('CartPole-v0')
obs = env.reset()

WARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.


In [92]:
# Observation space of the environment; passed to the networks, the policy
# and the value function in the following cells.
ob_space = env.observation_space

In [93]:
# Action space of the environment; passed to the policy network and the policy.
ac_space = env.action_space

In [94]:
from simple_net import PolNet,VNet
# Policy network, constructed from the observation and action spaces.
pol_net = PolNet(ob_space, ac_space)

In [95]:
from machina.pols import CategoricalPol
# The same flag is handed to both the policy and the value function.
# NOTE(review): presumably False selects a non-recurrent setup — confirm
# against the machina documentation.
use_rnn = False
# Categorical policy wrapping pol_net.
pol = CategoricalPol(ob_space, ac_space, pol_net, use_rnn)

In [96]:
from machina.vfuncs import DeterministicSVfunc
# Value network built from the observation space, wrapped in a
# DeterministicSVfunc that shares the use_rnn flag with the policy.
vf_net = VNet(ob_space)
vf = DeterministicSVfunc(ob_space, vf_net, use_rnn)

In [97]:
from machina.samplers import EpiSampler

In [98]:
# Episode sampler over `env` and `pol`, configured with num_parallel=2 and a
# fixed seed of 42.
# NOTE(review): presumably num_parallel is the number of sampling workers —
# confirm against the machina documentation.
sampler = EpiSampler(env, pol, num_parallel=2, seed=42)

In [99]:
# Adam optimizer over the policy network's parameters.
pol_lr = 1e-4
optim_pol = torch.optim.Adam(pol_net.parameters(), lr=pol_lr)

In [100]:
# Adam optimizer over the value network's parameters.
vf_lr = 3e-4
optim_vf = torch.optim.Adam(vf_net.parameters(), lr=vf_lr)

In [101]:
# Running totals of episodes and steps; the training loop adds to both on
# every iteration.
total_epi = 0
total_step = 0
# Threshold for the "max" checkpoint: the training loop saves it only when an
# iteration's mean reward exceeds this value, then raises it to that mean.
max_rew = 50
# NOTE(review): kl_beta is not referenced by any other visible cell — confirm
# whether it is still needed.
kl_beta = 1
# The training loop runs while total_epi is below this number of episodes.
max_episodes = 100

In [104]:
# Rendered rollout of the policy before learning: 1000 environment steps,
# resetting the environment whenever an episode ends.
import time

done = True  # start as "finished" so the first iteration resets the env
for _ in range(1000):
    if done:
        # brief pause between episodes, then start a new one
        time.sleep(.2)
        o = env.reset()

    obs_tensor = torch.tensor(o, dtype=torch.float)
    ac_real, ac, a_i = pol.deterministic_ac_real(obs_tensor)
    ac_real = ac_real.reshape(pol.ac_space.shape)
    next_o, r, done, e_i = env.step(np.array(ac_real))
    o = next_o

    if not done:
        env.render()
        continue

    # Turn the cart red when the episode ends: draw one frame with the track
    # and axle in red, then set them to (0,0,0) and (.5,.5,.8) for the
    # following frames.
    env.env.track.set_color(.9,.1,.1)
    env.env.axle.set_color(.9,.1,.1)
    env.render()
    env.env.track.set_color(0,0,0)
    env.env.axle.set_color(.5,.5,.8)
# env.close() is left disabled here, as in the original cell.

In [105]:
from machina.traj import epi_functional as ef
from machina.misc import logger
from machina.utils import measure
from machina.traj import Traj
from machina.algos import ppo_clip

# PPO (clip) training loop: sample episodes, build a trajectory, run the
# ppo_clip update, log the results and save checkpoints.
max_steps_per_iter = 200


env_name = 'CartPole-v0'
log_dir_name = 'garbage'
models_dir = os.path.join(log_dir_name, 'models')
# Create the log directory AND its 'models' subdirectory up front. Every
# torch.save below writes into 'models', and torch.save does not create
# missing parent directories, so a fresh run would otherwise fail with
# FileNotFoundError on the first checkpoint.
os.makedirs(models_dir, exist_ok=True)

score_file = os.path.join(log_dir_name, 'progress.csv')
logger.add_tabular_output(score_file)

## PPO arguments
gamma = 0.995
# Passed to ef.compute_advs together with gamma.
# NOTE(review): presumably the GAE lambda — confirm (the original author's
# comment here was "what is this?").
lam = 1
clip_param = 0.2
epoch_per_iter = 50
batch_size = 64  # default is 256
max_grad_norm = 10


def save_checkpoint(tag):
    """Save the state dicts of pol, vf and both optimizers into models_dir.

    Files are named '<name>_<tag>.pkl', e.g. 'pol_max.pkl' or
    'optim_vf_last.pkl'.
    """
    targets = (
        ('pol', pol),
        ('vf', vf),
        ('optim_pol', optim_pol),
        ('optim_vf', optim_vf),
    )
    for name, obj in targets:
        torch.save(obj.state_dict(),
                   os.path.join(models_dir, '{}_{}.pkl'.format(name, tag)))


while max_episodes > total_epi:
    with measure('sample'):  # original author's note: this timing helper is very handy
        epis = sampler.sample(pol, max_steps=max_steps_per_iter)
    with measure('train'):
        traj = Traj()
        traj.add_epis(epis)

        traj = ef.compute_vs(traj, vf)
        traj = ef.compute_rets(traj, gamma)
        traj = ef.compute_advs(traj, gamma, lam)  # compute advantages
        traj = ef.centerize_advs(traj)  # second advantage step (centerize)
        # NOTE(review): the original author was unsure what this step does
        # ("what was this again?") — confirm against the machina docs.
        traj = ef.compute_h_masks(traj)
        traj.register_epis()

        result_dict = ppo_clip.train(traj=traj, pol=pol, vf=vf, clip_param=clip_param,
                                     optim_pol=optim_pol, optim_vf=optim_vf,
                                     epoch=epoch_per_iter, batch_size=batch_size,
                                     max_grad_norm=max_grad_norm)
    total_epi += traj.num_epi
    step = traj.num_step
    total_step += step
    rewards = [np.sum(epi['rews']) for epi in epis]
    mean_rew = np.mean(rewards)
    logger.record_results(log_dir_name, result_dict, score_file,
                          total_epi, step, total_step,
                          rewards,
                          plot_title=env_name)

    # Keep a separate "max" checkpoint whenever the mean reward of this
    # iteration beats the best value seen so far.
    if mean_rew > max_rew:
        save_checkpoint('max')
        max_rew = mean_rew

    # Always overwrite the "last" checkpoint.
    save_checkpoint('last')
    del traj
del sampler

2018-12-26 16:11:38.706438 JST | sample: 0.1124sec
2018-12-26 16:11:38.719973 JST | Optimizing...
2018-12-26 16:11:39.591269 JST | Optimization finished!
2018-12-26 16:11:39.596583 JST | train: 0.8883sec
2018-12-26 16:11:39.598735 JST | outdir /Users/kyoshiro/Downloads/machina/example/garbage
2018-12-26 16:11:39.604081 JST | --------------  ----------
2018-12-26 16:11:39.605666 JST | PolLossAverage   -0.012481
2018-12-26 16:11:39.606721 JST | PolLossStd        0.179282
2018-12-26 16:11:39.607615 JST | PolLossMedian    -0.069745
2018-12-26 16:11:39.608593 JST | PolLossMin       -0.205202
2018-12-26 16:11:39.609580 JST | PolLossMax        0.233493
2018-12-26 16:11:39.610730 JST | VfLossAverage    56.3204
2018-12-26 16:11:39.611733 JST | VfLossStd        14.6059
2018-12-26 16:11:39.614075 JST | VfLossMedian     57.3968
2018-12-26 16:11:39.615596 JST | VfLossMin        32.7934
2018-12-26 16:11:39.616800 JST | VfLossMax        92.0185
2018-12-26 16:11:39.617982 JST | RewardAverage    19.416

2018-12-26 16:11:51.172998 JST | VfLossMax        712.096
2018-12-26 16:11:51.174161 JST | RewardAverage     63.4
2018-12-26 16:11:51.177812 JST | RewardStd         22.9312
2018-12-26 16:11:51.179166 JST | RewardMedian      59
2018-12-26 16:11:51.180469 JST | RewardMin         29
2018-12-26 16:11:51.181916 JST | RewardMax        100
2018-12-26 16:11:51.185568 JST | EpisodePerIter     5
2018-12-26 16:11:51.189132 JST | TotalEpisode      41
2018-12-26 16:11:51.190713 JST | StepPerIter      317
2018-12-26 16:11:51.194888 JST | TotalStep       1347
2018-12-26 16:11:51.197058 JST | --------------  ------------
2018-12-26 16:11:52.056652 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/PolLoss.png
2018-12-26 16:11:52.298562 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/VfLoss.png
2018-12-26 16:11:52.541833 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/Reward.png
2018-12-26 16:11:52.636264 JST | sample: 0.0842sec

2018-12-26 16:12:05.829514 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/Reward.png
2018-12-26 16:12:06.002114 JST | sample: 0.1622sec
2018-12-26 16:12:06.013945 JST | Optimizing...
2018-12-26 16:12:08.824164 JST | Optimization finished!
2018-12-26 16:12:08.827973 JST | train: 2.8250sec
2018-12-26 16:12:08.830334 JST | outdir /Users/kyoshiro/Downloads/machina/example/garbage
2018-12-26 16:12:08.837673 JST | --------------  ------------
2018-12-26 16:12:08.839230 JST | PolLossAverage    -0.0591839
2018-12-26 16:12:08.840248 JST | PolLossStd         0.806141
2018-12-26 16:12:08.841306 JST | PolLossMedian     -0.268224
2018-12-26 16:12:08.844894 JST | PolLossMin        -1.20116
2018-12-26 16:12:08.846615 JST | PolLossMax         1.14071
2018-12-26 16:12:08.848858 JST | VfLossAverage    527.139
2018-12-26 16:12:08.853199 JST | VfLossStd        431.484
2018-12-26 16:12:08.855167 JST | VfLossMedian     531.772
2018-12-26 16:12:08.857832 JST | VfLossMin         27.

2018-12-26 16:12:20.743966 JST | VfLossStd        172.822
2018-12-26 16:12:20.745510 JST | VfLossMedian      84.129
2018-12-26 16:12:20.746393 JST | VfLossMin          0.981126
2018-12-26 16:12:20.748455 JST | VfLossMax        502.882
2018-12-26 16:12:20.749927 JST | RewardAverage    135
2018-12-26 16:12:20.751847 JST | RewardStd         31.8434
2018-12-26 16:12:20.754204 JST | RewardMedian     135
2018-12-26 16:12:20.758293 JST | RewardMin         96
2018-12-26 16:12:20.760278 JST | RewardMax        174
2018-12-26 16:12:20.761911 JST | EpisodePerIter     3
2018-12-26 16:12:20.763035 JST | TotalEpisode      74
2018-12-26 16:12:20.764957 JST | StepPerIter      405
2018-12-26 16:12:20.766457 JST | TotalStep       5013
2018-12-26 16:12:20.767975 JST | --------------  ------------
2018-12-26 16:12:21.064188 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/PolLoss.png
2018-12-26 16:12:21.301212 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbag

2018-12-26 16:12:36.905884 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/VfLoss.png
2018-12-26 16:12:37.180256 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/Reward.png
2018-12-26 16:12:37.300513 JST | sample: 0.1146sec
2018-12-26 16:12:37.311173 JST | Optimizing...
2018-12-26 16:12:39.525715 JST | Optimization finished!
2018-12-26 16:12:39.528404 JST | train: 2.2266sec
2018-12-26 16:12:39.529706 JST | outdir /Users/kyoshiro/Downloads/machina/example/garbage
2018-12-26 16:12:39.533893 JST | --------------  -----------
2018-12-26 16:12:39.535012 JST | PolLossAverage    -0.185862
2018-12-26 16:12:39.535880 JST | PolLossStd         0.667199
2018-12-26 16:12:39.536818 JST | PolLossMedian     -0.303516
2018-12-26 16:12:39.537845 JST | PolLossMin        -0.823975
2018-12-26 16:12:39.538660 JST | PolLossMax         1.19173
2018-12-26 16:12:39.539503 JST | VfLossAverage    143.234
2018-12-26 16:12:39.540415 JST | VfLossStd        211.07
20

2018-12-26 16:12:54.101225 JST | PolLossMax         1.45295
2018-12-26 16:12:54.103342 JST | VfLossAverage    465.136
2018-12-26 16:12:54.104397 JST | VfLossStd        480.804
2018-12-26 16:12:54.106021 JST | VfLossMedian     279.634
2018-12-26 16:12:54.107146 JST | VfLossMin          0.658091
2018-12-26 16:12:54.109628 JST | VfLossMax       1590.14
2018-12-26 16:12:54.110634 JST | RewardAverage    200
2018-12-26 16:12:54.112860 JST | RewardStd          0
2018-12-26 16:12:54.115469 JST | RewardMedian     200
2018-12-26 16:12:54.118414 JST | RewardMin        200
2018-12-26 16:12:54.122302 JST | RewardMax        200
2018-12-26 16:12:54.124223 JST | EpisodePerIter     2
2018-12-26 16:12:54.125661 JST | TotalEpisode     100
2018-12-26 16:12:54.127540 JST | StepPerIter      400
2018-12-26 16:12:54.128878 JST | TotalStep       9392
2018-12-26 16:12:54.129825 JST | --------------  ------------
2018-12-26 16:12:54.522765 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage

In [106]:
# Rendered rollout of the policy after learning: 1000 environment steps,
# resetting the environment whenever an episode ends.
import time

done = True  # start as "finished" so the first iteration resets the env
for _ in range(1000):
    if done:
        # brief pause between episodes, then start a new one
        time.sleep(.2)
        o = env.reset()

    obs_tensor = torch.tensor(o, dtype=torch.float)
    ac_real, ac, a_i = pol.deterministic_ac_real(obs_tensor)
    ac_real = ac_real.reshape(pol.ac_space.shape)
    next_o, r, done, e_i = env.step(np.array(ac_real))
    o = next_o

    if not done:
        env.render()
        continue

    # Turn the cart red when the episode ends: draw one frame with the track
    # and axle in red, then set them to (0,0,0) and (.5,.5,.8) for the
    # following frames.
    env.env.track.set_color(.9,.1,.1)
    env.env.axle.set_color(.9,.1,.1)
    env.render()
    env.env.track.set_color(0,0,0)
    env.env.axle.set_color(.5,.5,.8)
# env.close() is left disabled here, as in the original cell.