# Machina Quickstart notebook
This example notebook is for beginners.

In this example you can try classic control tasks of [gym](https://gym.openai.com/envs/#classic_control).

In [1]:
import numpy as np
import gym
import torch
import os

In [2]:
# define your environment
# env_name is reused later as the plot title when training results are recorded
env_name = 'Acrobot-v1'
env = gym.make(env_name)
# reset before the first render; the returned initial observation is kept in obs
obs = env.reset()
env.render() # show your env

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


True

In [3]:
# inspect the observation and action spaces; both objects are reused below
# when the policy and value-function networks are constructed
ob_space = env.observation_space
print('obs:', ob_space)

ac_space = env.action_space
print('act:', ac_space)

obs: Box(6,)
act: Discrete(3)


In [4]:
# build the policy, the value function, and one Adam optimizer for each
from simple_net import PolNet, VNet
from machina.pols import CategoricalPol
from machina.vfuncs import DeterministicSVfunc

# learning rates
pol_lr = 1e-4
vf_lr = 3e-4

# networks and their machina wrappers
# (policy network is created before the value network, as in the original order)
pol_net = PolNet(ob_space, ac_space)
pol = CategoricalPol(ob_space, ac_space, pol_net)

vf_net = VNet(ob_space)
vf = DeterministicSVfunc(ob_space, vf_net)

# each optimizer updates the parameters of its own network only
optim_pol = torch.optim.Adam(pol_net.parameters(), lr=pol_lr)
optim_vf = torch.optim.Adam(vf_net.parameters(), lr=vf_lr)

In [5]:
# register your environment and policy with the sampler
# NOTE(review): num_parallel=2 presumably sets the number of parallel sampling
# workers and seed=42 the sampler's random seed — confirm against EpiSampler.
from machina.samplers import EpiSampler
sampler = EpiSampler(env, pol, num_parallel=2, seed=42)

In [6]:
# preview how the untrained policy behaves in the environment
import time

done = False
o = env.reset()
for _ in range(150):  # 150 frames at 15 fps = 10 sec
    if done:
        # pause for a second at the episode boundary, then start a new episode
        time.sleep(1)
        o = env.reset()
    # query the policy's deterministic action for the current observation
    ob_tensor = torch.tensor(o, dtype=torch.float)
    ac_real, ac, a_i = pol.deterministic_ac_real(ob_tensor)
    # reshape to the action-space shape before handing the action to gym
    ac_real = ac_real.reshape(pol.ac_space.shape)
    o, r, done, e_i = env.step(np.array(ac_real))
    time.sleep(1/15)  # throttle to 15 fps
    env.render()

In [7]:
# train your policy with PPO (clipped objective)
from machina.traj import epi_functional as ef
from machina import logger
from machina.utils import measure
from machina.traj import Traj
from machina.algos import ppo_clip

# machina automatically writes logs (models, scores, etc.) under this directory
log_dir_name = 'garbage'
model_dir = os.path.join(log_dir_name, 'models')
# makedirs(exist_ok=True) also creates 'models' when the log directory already
# exists without it; the previous os.mkdir pair skipped 'models' in that case
# and the first torch.save below then failed.
os.makedirs(model_dir, exist_ok=True)
score_file = os.path.join(log_dir_name, 'progress.csv')
logger.add_tabular_output(score_file)

# arguments of PPO
gamma = 0.995          # discount factor
lam = 1                # lambda passed to compute_advs
clip_param = 0.2
epoch_per_iter = 50
batch_size = 64
max_grad_norm = 10

# counters and records for the loop
total_epi = 0
total_step = 0
# Start at -inf so the first iteration always writes the '*_max.pkl' files.
# With a finite start value and the strict '>' test below, a run that never
# beats that value would leave no (or a stale) best checkpoint on disk.
max_rew = -float('inf')

# how long to train: stop once at least this many episodes have been sampled
max_episodes = 100

# step budget handed to sampler.sample() on every iteration
# NOTE(review): the logged StepPerIter (e.g. 934 steps for 2 episodes) exceeds
# 150, so this does not appear to cap the length of an episode — confirm
# against the EpiSampler documentation.
max_steps_per_iter = 150


def _save_checkpoint(tag):
    """Save policy, value function and both optimizers as '<name>_<tag>.pkl'."""
    targets = (
        ('pol', pol),
        ('vf', vf),
        ('optim_pol', optim_pol),
        ('optim_vf', optim_vf),
    )
    for name, obj in targets:
        torch.save(obj.state_dict(),
                   os.path.join(model_dir, '{}_{}.pkl'.format(name, tag)))


# train loop
while total_epi < max_episodes:
    # sample trajectories with the current policy
    with measure('sample'):
        epis = sampler.sample(pol, max_steps=max_steps_per_iter)

    # train from trajectories
    with measure('train'):
        traj = Traj()
        traj.add_epis(epis)

        # calculate advantages
        traj = ef.compute_vs(traj, vf)
        traj = ef.compute_rets(traj, gamma)
        traj = ef.compute_advs(traj, gamma, lam)
        traj = ef.centerize_advs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()

        result_dict = ppo_clip.train(traj=traj, pol=pol, vf=vf, clip_param=clip_param,
                                     optim_pol=optim_pol, optim_vf=optim_vf,
                                     epoch=epoch_per_iter, batch_size=batch_size,
                                     max_grad_norm=max_grad_norm)

    # update counters and records
    total_epi += traj.num_epi
    step = traj.num_step
    total_step += step
    rewards = [np.sum(epi['rews']) for epi in epis]
    mean_rew = np.mean(rewards)
    logger.record_results(log_dir_name, result_dict, score_file,
                          total_epi, step, total_step,
                          rewards,
                          plot_title=env_name)

    # keep the checkpoint with the best mean episode reward seen so far
    if mean_rew > max_rew:
        _save_checkpoint('max')
        max_rew = mean_rew

    # always keep the most recent checkpoint as well
    _save_checkpoint('last')
    del traj

# The sampler is dropped once training is done, so the cell that creates
# `sampler` must be run again before this cell can be re-run.
# NOTE(review): presumably this also releases the sampler's workers — confirm.
del sampler

2019-01-09 16:58:18.475314 JST | sample: 0.4030sec
2019-01-09 16:58:18.498011 JST | Optimizing...
2019-01-09 16:58:22.142084 JST | Optimization finished!
2019-01-09 16:58:22.143209 JST | train: 3.6667sec
2019-01-09 16:58:22.145390 JST | outdir /Users/kyoshiro/Downloads/machina/example/garbage
2019-01-09 16:58:22.151043 JST | --------------  -------------
2019-01-09 16:58:22.152865 JST | PolLossAverage      0.0724339
2019-01-09 16:58:22.154599 JST | PolLossStd          0.862497
2019-01-09 16:58:22.155943 JST | PolLossMedian       0.390175
2019-01-09 16:58:22.158134 JST | PolLossMin         -1.57809
2019-01-09 16:58:22.159076 JST | PolLossMax          1.06903
2019-01-09 16:58:22.160351 JST | VfLossAverage    3262.66
2019-01-09 16:58:22.162751 JST | VfLossStd        3373.38
2019-01-09 16:58:22.165026 JST | VfLossMedian     1980.92
2019-01-09 16:58:22.166042 JST | VfLossMin         118.77
2019-01-09 16:58:22.168700 JST | VfLossMax       16197.8
2019-01-09 16:58:22.173138 JST | RewardAverag

2019-01-09 16:58:43.241343 JST | VfLossMax       2481.44
2019-01-09 16:58:43.245593 JST | RewardAverage   -466.5
2019-01-09 16:58:43.247183 JST | RewardStd         33.5
2019-01-09 16:58:43.250090 JST | RewardMedian    -466.5
2019-01-09 16:58:43.251410 JST | RewardMin       -500
2019-01-09 16:58:43.253389 JST | RewardMax       -433
2019-01-09 16:58:43.257536 JST | EpisodePerIter     2
2019-01-09 16:58:43.258631 JST | TotalEpisode      10
2019-01-09 16:58:43.259601 JST | StepPerIter      934
2019-01-09 16:58:43.261697 JST | TotalStep       4934
2019-01-09 16:58:43.262684 JST | --------------  ------------
2019-01-09 16:58:43.524636 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/PolLoss.png
2019-01-09 16:58:43.736706 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/VfLoss.png
2019-01-09 16:58:43.962763 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/Reward.png
2019-01-09 16:58:44.407205 JST | sample: 0.4326sec
2

2019-01-09 16:59:07.882840 JST | sample: 0.5201sec
2019-01-09 16:59:07.915839 JST | Optimizing...
2019-01-09 16:59:12.000063 JST | Optimization finished!
2019-01-09 16:59:12.003859 JST | train: 4.1170sec
2019-01-09 16:59:12.006354 JST | outdir /Users/kyoshiro/Downloads/machina/example/garbage
2019-01-09 16:59:12.011913 JST | --------------  ------------
2019-01-09 16:59:12.013263 JST | PolLossAverage     0.0619547
2019-01-09 16:59:12.015620 JST | PolLossStd         0.841991
2019-01-09 16:59:12.016777 JST | PolLossMedian      0.385266
2019-01-09 16:59:12.019047 JST | PolLossMin        -1.78505
2019-01-09 16:59:12.020232 JST | PolLossMax         1.37218
2019-01-09 16:59:12.023254 JST | VfLossAverage    844.062
2019-01-09 16:59:12.024981 JST | VfLossStd       1386.69
2019-01-09 16:59:12.027387 JST | VfLossMedian     119.722
2019-01-09 16:59:12.030844 JST | VfLossMin          1.19408
2019-01-09 16:59:12.034031 JST | VfLossMax       5037.78
2019-01-09 16:59:12.035265 JST | RewardAverage   -

2019-01-09 16:59:26.757822 JST | VfLossMedian      381.649
2019-01-09 16:59:26.760657 JST | VfLossMin           4.25489
2019-01-09 16:59:26.762319 JST | VfLossMax        5847.95
2019-01-09 16:59:26.763415 JST | RewardAverage    -418
2019-01-09 16:59:26.764385 JST | RewardStd          82
2019-01-09 16:59:26.765417 JST | RewardMedian     -418
2019-01-09 16:59:26.768859 JST | RewardMin        -500
2019-01-09 16:59:26.771365 JST | RewardMax        -336
2019-01-09 16:59:26.772619 JST | EpisodePerIter      2
2019-01-09 16:59:26.775250 JST | TotalEpisode       28
2019-01-09 16:59:26.776341 JST | StepPerIter       837
2019-01-09 16:59:26.778501 JST | TotalStep       12803
2019-01-09 16:59:26.779910 JST | --------------  -------------
2019-01-09 16:59:27.005569 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/PolLoss.png
2019-01-09 16:59:27.207608 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/VfLoss.png
2019-01-09 16:59:27.398117 JST | Saved 

2019-01-09 16:59:43.220170 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/VfLoss.png
2019-01-09 16:59:43.448978 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/Reward.png
2019-01-09 16:59:43.903961 JST | sample: 0.4372sec
2019-01-09 16:59:43.934582 JST | Optimizing...
2019-01-09 16:59:46.807909 JST | Optimization finished!
2019-01-09 16:59:46.811002 JST | train: 2.9014sec
2019-01-09 16:59:46.813449 JST | outdir /Users/kyoshiro/Downloads/machina/example/garbage
2019-01-09 16:59:46.817405 JST | --------------  ---------------
2019-01-09 16:59:46.819956 JST | PolLossAverage     -0.00248635
2019-01-09 16:59:46.821012 JST | PolLossStd          0.826766
2019-01-09 16:59:46.821931 JST | PolLossMedian      -0.000411008
2019-01-09 16:59:46.822861 JST | PolLossMin         -1.31904
2019-01-09 16:59:46.825554 JST | PolLossMax          1.56819
2019-01-09 16:59:46.827410 JST | VfLossAverage     215.212
2019-01-09 16:59:46.828982 JST | VfLossStd   

2019-01-09 16:59:57.420295 JST | PolLossMin         -1.32387
2019-01-09 16:59:57.421771 JST | PolLossMax          0.623551
2019-01-09 16:59:57.423494 JST | VfLossAverage      35.6161
2019-01-09 16:59:57.425286 JST | VfLossStd          55.3551
2019-01-09 16:59:57.426657 JST | VfLossMedian        8.11314
2019-01-09 16:59:57.428061 JST | VfLossMin           0.403717
2019-01-09 16:59:57.444277 JST | VfLossMax         269.939
2019-01-09 16:59:57.448300 JST | RewardAverage    -176
2019-01-09 16:59:57.452437 JST | RewardStd           7
2019-01-09 16:59:57.459454 JST | RewardMedian     -176
2019-01-09 16:59:57.463318 JST | RewardMin        -183
2019-01-09 16:59:57.467466 JST | RewardMax        -169
2019-01-09 16:59:57.471839 JST | EpisodePerIter      2
2019-01-09 16:59:57.476141 JST | TotalEpisode       46
2019-01-09 16:59:57.481234 JST | StepPerIter       354
2019-01-09 16:59:57.485177 JST | TotalStep       18190
2019-01-09 16:59:57.487080 JST | --------------  ------------
2019-01-09 16:59:5

2019-01-09 17:00:10.897569 JST | TotalEpisode       55
2019-01-09 17:00:10.899810 JST | StepPerIter       575
2019-01-09 17:00:10.901901 JST | TotalStep       20506
2019-01-09 17:00:10.904128 JST | --------------  ------------
2019-01-09 17:00:11.147793 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/PolLoss.png
2019-01-09 17:00:11.349398 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/VfLoss.png
2019-01-09 17:00:11.561731 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/Reward.png
2019-01-09 17:00:11.794539 JST | sample: 0.2267sec
2019-01-09 17:00:11.806408 JST | Optimizing...
2019-01-09 17:00:14.005888 JST | Optimization finished!
2019-01-09 17:00:14.009394 JST | train: 2.2139sec
2019-01-09 17:00:14.011470 JST | outdir /Users/kyoshiro/Downloads/machina/example/garbage
2019-01-09 17:00:14.015615 JST | --------------  -------------
2019-01-09 17:00:14.016834 JST | PolLossAverage      0.0408619
2019-01-09 17:00

2019-01-09 17:00:26.826584 JST | outdir /Users/kyoshiro/Downloads/machina/example/garbage
2019-01-09 17:00:26.831490 JST | --------------  -------------
2019-01-09 17:00:26.832680 JST | PolLossAverage     -0.0757299
2019-01-09 17:00:26.834877 JST | PolLossStd          0.882978
2019-01-09 17:00:26.835994 JST | PolLossMedian      -0.300946
2019-01-09 17:00:26.838034 JST | PolLossMin         -1.19098
2019-01-09 17:00:26.840074 JST | PolLossMax          1.36441
2019-01-09 17:00:26.841393 JST | VfLossAverage     238.34
2019-01-09 17:00:26.843798 JST | VfLossStd         236.141
2019-01-09 17:00:26.845551 JST | VfLossMedian      168.711
2019-01-09 17:00:26.847757 JST | VfLossMin           5.4172
2019-01-09 17:00:26.849501 JST | VfLossMax         809.803
2019-01-09 17:00:26.852965 JST | RewardAverage    -188
2019-01-09 17:00:26.854707 JST | RewardStd          22
2019-01-09 17:00:26.857058 JST | RewardMedian     -188
2019-01-09 17:00:26.858283 JST | RewardMin        -210
2019-01-09 17:00:26.859

2019-01-09 17:00:38.212697 JST | RewardAverage    -118.333
2019-01-09 17:00:38.214186 JST | RewardStd          22.4252
2019-01-09 17:00:38.216222 JST | RewardMedian     -110
2019-01-09 17:00:38.217036 JST | RewardMin        -149
2019-01-09 17:00:38.220392 JST | RewardMax         -96
2019-01-09 17:00:38.223155 JST | EpisodePerIter      3
2019-01-09 17:00:38.225593 JST | TotalEpisode       78
2019-01-09 17:00:38.227359 JST | StepPerIter       358
2019-01-09 17:00:38.229057 JST | TotalStep       24670
2019-01-09 17:00:38.230076 JST | --------------  -------------
2019-01-09 17:00:38.500443 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/PolLoss.png
2019-01-09 17:00:38.722904 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/VfLoss.png
2019-01-09 17:00:38.940885 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/Reward.png
2019-01-09 17:00:39.245727 JST | sample: 0.2804sec
2019-01-09 17:00:39.256695 JST | Optimizing..

2019-01-09 17:00:50.775407 JST | Saved a figure as /Users/kyoshiro/Downloads/machina/example/garbage/Reward.png
2019-01-09 17:00:50.997409 JST | sample: 0.2094sec
2019-01-09 17:00:51.008981 JST | Optimizing...
2019-01-09 17:00:52.969730 JST | Optimization finished!
2019-01-09 17:00:52.971709 JST | train: 1.9734sec
2019-01-09 17:00:52.974114 JST | outdir /Users/kyoshiro/Downloads/machina/example/garbage
2019-01-09 17:00:52.979029 JST | --------------  --------------
2019-01-09 17:00:52.980830 JST | PolLossAverage      0.00155199
2019-01-09 17:00:52.983390 JST | PolLossStd          0.765483
2019-01-09 17:00:52.985110 JST | PolLossMedian      -0.225083
2019-01-09 17:00:52.986984 JST | PolLossMin         -1.19572
2019-01-09 17:00:52.988858 JST | PolLossMax          1.24828
2019-01-09 17:00:52.990860 JST | VfLossAverage      94.1178
2019-01-09 17:00:52.992469 JST | VfLossStd         100.317
2019-01-09 17:00:53.000825 JST | VfLossMedian       61.5569
2019-01-09 17:00:53.001962 JST | VfLossMi

- You can check the training progress in **garbage/Reward.png**.

In [8]:
# load best policy
# Build the path from log_dir_name (set in the training cell) so that renaming
# the log directory there does not leave a stale hard-coded path here.
best_path = os.path.join(log_dir_name, 'models', 'pol_max.pkl')
# NOTE(review): best_pol wraps the same pol_net object as pol, so loading the
# checkpoint here presumably overwrites the weights pol uses as well — confirm.
best_pol = CategoricalPol(ob_space, ac_space, pol_net)
best_pol.load_state_dict(torch.load(best_path))

In [9]:
# show your trained policy's behavior
# Uses best_pol (the checkpoint loaded in the previous cell) rather than pol,
# so the policy being displayed is explicitly the one that was just loaded.
done = False
o = env.reset()
for _ in range(300): # show 300 frames (=20 sec)
    if done:
        time.sleep(1) # pause at the episode boundary
        o = env.reset()
    # deterministic action of the loaded policy for the current observation
    ac_real, ac, a_i = best_pol.deterministic_ac_real(torch.tensor(o, dtype=torch.float))
    ac_real = ac_real.reshape(best_pol.ac_space.shape)
    next_o, r, done, e_i = env.step(np.array(ac_real))
    o = next_o
    time.sleep(1/15) # 15fps
    env.render()

In [10]:
# close your environment once you are done with the rendered previews
env.close()