In [1]:
#!/usr/bin/env python3
import cv2
import random
import numpy as np
import argparse
from DRL.evaluator import Evaluator
from utils.util import *
from utils.tensorboard import TensorBoard
# Standard-library modules are imported after the star import above so these
# names always refer to the real modules, whatever utils.util re-exports.
import time
import os
import shlex


In [2]:
# Name the experiment after the current working directory so that each run
# directory gets its own folder under ../train_log.
exp = os.path.basename(os.path.abspath('.'))
log_dir = '../train_log/{}'.format(exp)
writer = TensorBoard(log_dir)
# Expose the log folder as ./log. The path is shell-quoted so that a directory
# name containing spaces or shell metacharacters cannot break the command.
os.system('ln -sf {} ./log'.format(shlex.quote(log_dir)))
# os.system('mkdir ./model')


0

In [3]:
def train(agent, env, evaluate):
    """Alternate between collecting fixed-length episodes and updating the agent.

    Settings are read from the module-level ``args`` object and scalars are
    logged through the module-level ``writer``.  Each loop iteration performs
    one environment step.  Once ``args.max_step`` steps have been taken the
    episode is closed: the agent is optionally validated and saved, then
    (once ``step`` exceeds ``args.warmup``) updated
    ``args.episode_train_times`` times, and the environment is reset on the
    next iteration.

    Args:
        agent: object providing ``reset``, ``select_action``, ``observe``,
            ``update_policy`` and ``save_model``.
        env: object providing ``reset()`` and ``step(action)``; ``step`` must
            return a 4-tuple whose first three items are used as
            ``(observation, reward, done)``.
        evaluate: callable invoked as
            ``evaluate(env, agent.select_action, debug=...)`` and returning
            ``(reward, dist)``.

    Returns:
        None.
    """
    train_times = args.train_times
    validate_interval = args.validate_interval
    max_step = args.max_step
    debug = args.debug
    episode_train_times = args.episode_train_times
    output = args.output
    noise_factor = args.noise_factor
    warmup = args.warmup

    time_stamp = time.time()
    step = episode = episode_steps = 0
    observation = None

    # NOTE: the bound is inclusive, so train_times + 1 environment steps are taken.
    while step <= train_times:
        step += 1
        episode_steps += 1

        # Start of an episode: fetch a fresh observation and reset the agent.
        if observation is None:
            observation = env.reset()
            agent.reset(observation, noise_factor)

        action = agent.select_action(observation, noise_factor=noise_factor)
        observation, reward, done, _ = env.step(action)
        agent.observe(reward, observation, done, step)

        # Episodes have a fixed length; max_step == 0 disables episode handling.
        if not (max_step and episode_steps >= max_step):
            continue

        # [optional] validate and save every `validate_interval` episodes.
        if (step > warmup and episode > 0 and validate_interval > 0
                and episode % validate_interval == 0):
            val_reward, val_dist = evaluate(env, agent.select_action, debug=debug)
            if debug:
                prRed('Step_{:07d}: mean_reward:{:.3f} mean_dist:{:.3f} var_dist:{:.3f}'.format(
                    step - 1, np.mean(val_reward), np.mean(val_dist), np.var(val_dist)))
            writer.add_scalar('validate/mean_reward', np.mean(val_reward), step)
            writer.add_scalar('validate/mean_dist', np.mean(val_dist), step)
            writer.add_scalar('validate/var_dist', np.var(val_dist), step)
            agent.save_model(output)

        # Time spent since the previous episode's bookkeeping finished.
        train_time_interval = time.time() - time_stamp
        time_stamp = time.time()

        if step > warmup:
            # Step-wise learning-rate decay; logged below as (critic, actor).
            # An earlier variant of this schedule switched at 10000 and
            # 20000 * max_step instead of 1000 and 2000.
            if step < 1000 * max_step:
                lr = (3e-4, 1e-3)
            elif step < 2000 * max_step:
                lr = (1e-4, 3e-4)
            else:
                # NOTE(review): here the second rate is smaller than the first,
                # unlike the two tiers above - confirm this is intentional.
                lr = (3e-5, 1e-5)

            tot_Q = 0.
            tot_value_loss = 0.
            for _ in range(episode_train_times):
                Q, value_loss = agent.update_policy(lr)
                tot_Q += Q.data.cpu().numpy()
                tot_value_loss += value_loss.data.cpu().numpy()

            writer.add_scalar('train/critic_lr', lr[0], step)
            writer.add_scalar('train/actor_lr', lr[1], step)
            # Without any update there is no mean to report; skipping the two
            # scalars also avoids dividing by zero.
            if episode_train_times > 0:
                writer.add_scalar('train/Q', tot_Q / episode_train_times, step)
                writer.add_scalar('train/critic_loss', tot_value_loss / episode_train_times, step)

        if debug:
            prBlack('#{}: steps:{} interval_time:{:.2f} train_time:{:.2f}'.format(
                episode, step, train_time_interval, time.time() - time_stamp))
        time_stamp = time.time()

        # Mark the episode as finished so the next iteration resets the env.
        observation = None
        episode_steps = 0
        episode += 1

In [4]:
import sys
class Arg():
    """Hyper-parameter container standing in for the argparse namespace.

    Every option is a plain instance attribute, so code expecting
    ``args.<name>`` works unchanged.  Defaults can be overridden per instance
    with keyword arguments, e.g. ``Arg(train_times=1000, debug=False)``.

    Raises:
        TypeError: if a keyword does not name one of the options below.
    """

    def __init__(self, **overrides):
        self.batch_size = 96            # minibatch size
        self.max_step = 40              # max length of an episode
        self.warmup = 400               # steps that only fill the replay memory
        self.discount = 0.95**5         # discount factor
        self.rmsize = 800               # replay memory size
        self.env_batch = 96             # number of concurrent environments
        self.tau = 0.001                # moving average for the target network
        self.noise_factor = 0           # noise level for parameter space noise
        self.validate_interval = 50     # episodes between validations
        self.validate_episodes = 5      # episodes run during a validation
        self.train_times = 196000       # total number of training steps
        self.episode_train_times = 10   # policy updates per episode
        self.resume = None              # path of a model to resume from
        self.debug = True               # print progress information
        self.output = './model'         # path handed to agent.save_model
        self.seed = 1234                # random seed

        # Only known options may be overridden, so a typo fails loudly
        # instead of silently creating an unused attribute.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError('unknown option(s): {}'.format(', '.join(unknown)))
        for name, value in overrides.items():
            setattr(self, name, value)


args = Arg()
# parser = argparse.ArgumentParser(description='Learning to Paint')

# # hyper-parameter
# parser.add_argument('--warmup', default=400, type=int, help='timestep without training but only filling the replay memory')
# parser.add_argument('--discount', default=0.95**5, type=float, help='discount factor')
# parser.add_argument('--batch_size', default=96, type=int, help='minibatch size')
# parser.add_argument('--rmsize', default=800, type=int, help='replay memory size')
# parser.add_argument('--env_batch', default=96, type=int, help='concurrent environment number')
# parser.add_argument('--tau', default=0.001, type=float, help='moving average for target network')
# parser.add_argument('--max_step', default=40, type=int, help='max length for episode')
# parser.add_argument('--noise_factor', default=0, type=float, help='noise level for parameter space noise')
# parser.add_argument('--validate_interval', default=50, type=int, help='how many episodes to perform a validation')
# parser.add_argument('--validate_episodes', default=5, type=int, help='how many episode to perform during validation')
# parser.add_argument('--train_times', default=2000000, type=int, help='total traintimes')
# parser.add_argument('--episode_train_times', default=10, type=int, help='train times for each episode')    
# parser.add_argument('--resume', default=None, type=str, help='Resuming model path for testing')
# parser.add_argument('--output', default='./model', type=str, help='Resuming model path for testing')
# parser.add_argument('--debug', dest='debug', action='store_true', help='print some info')
# parser.add_argument('--seed', default=1234, type=int, help='random seed')

# args = parser.parse_args()    
# args.output = get_output_folder(args.output, "Paint")


In [None]:
# Seed every random number generator in use so that runs are repeatable.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)
random.seed(args.seed)
# Restrict cuDNN to deterministic kernels and switch off its auto-tuner: with
# benchmark mode on, the algorithm picked can differ from run to run, which
# would undermine the seeding above (possibly at some cost in speed).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from DRL.ddpg import DDPG
from DRL.multi import fastenv
# Build the batched environment, the agent and the evaluator, then train.
fenv = fastenv(args.max_step, args.env_batch, writer)
agent = DDPG(args.batch_size, args.env_batch, args.max_step,
             args.tau, args.discount, args.rmsize,
             writer, args.resume, args.output)
evaluate = Evaluator(args, writer)
print('observation_space', fenv.observation_space, 'action_space', fenv.action_space)
train(agent, fenv, evaluate)


loaded 10000 images
loaded 20000 images
loaded 30000 images
loaded 40000 images
loaded 50000 images
loaded 60000 images
loaded 70000 images
loaded 80000 images
loaded 90000 images
loaded 100000 images
loaded 110000 images
loaded 120000 images
loaded 130000 images
loaded 140000 images
loaded 150000 images
loaded 160000 images
loaded 170000 images
loaded 180000 images
loaded 190000 images
loaded 200000 images
finish loading data, 197999 training images, 2001 testing images
observation_space (96, 128, 128, 7) action_space 13


  s0 = torch.tensor(self.state, device='cpu')
  s1 = torch.tensor(state, device='cpu')


[98m #0: steps:40 interval_time:7.37 train_time:0.00[00m
[98m #1: steps:80 interval_time:5.36 train_time:0.00[00m
[98m #2: steps:120 interval_time:5.23 train_time:0.00[00m
[98m #3: steps:160 interval_time:5.48 train_time:0.00[00m
[98m #4: steps:200 interval_time:5.41 train_time:0.00[00m
[98m #5: steps:240 interval_time:5.49 train_time:0.00[00m
[98m #6: steps:280 interval_time:5.49 train_time:0.00[00m
[98m #7: steps:320 interval_time:5.22 train_time:0.00[00m
[98m #8: steps:360 interval_time:5.09 train_time:0.00[00m
[98m #9: steps:400 interval_time:4.98 train_time:0.00[00m
[98m #10: steps:440 interval_time:5.02 train_time:25.96[00m
[98m #11: steps:480 interval_time:5.00 train_time:22.75[00m
[98m #12: steps:520 interval_time:5.09 train_time:22.91[00m
[98m #13: steps:560 interval_time:5.00 train_time:22.70[00m
[98m #14: steps:600 interval_time:5.06 train_time:22.53[00m
[98m #15: steps:640 interval_time:4.94 train_time:22.89[00m
[98m #16: steps:680 interval_

[98m #128: steps:5160 interval_time:5.06 train_time:22.75[00m
[98m #129: steps:5200 interval_time:5.09 train_time:22.68[00m
[98m #130: steps:5240 interval_time:5.02 train_time:22.72[00m
[98m #131: steps:5280 interval_time:5.09 train_time:22.43[00m
[98m #132: steps:5320 interval_time:5.16 train_time:22.83[00m
[98m #133: steps:5360 interval_time:4.97 train_time:22.84[00m
[98m #134: steps:5400 interval_time:4.95 train_time:22.75[00m
[98m #135: steps:5440 interval_time:5.23 train_time:22.83[00m
[98m #136: steps:5480 interval_time:5.04 train_time:22.80[00m
[98m #137: steps:5520 interval_time:4.95 train_time:22.69[00m
[98m #138: steps:5560 interval_time:5.11 train_time:22.64[00m
[98m #139: steps:5600 interval_time:5.08 train_time:22.63[00m
[98m #140: steps:5640 interval_time:5.00 train_time:22.69[00m
[98m #141: steps:5680 interval_time:4.99 train_time:22.77[00m
[98m #142: steps:5720 interval_time:5.06 train_time:22.80[00m
[98m #143: steps:5760 interval_time:5.1

[98m #253: steps:10160 interval_time:4.98 train_time:22.62[00m
[98m #254: steps:10200 interval_time:4.99 train_time:22.76[00m
[98m #255: steps:10240 interval_time:5.08 train_time:22.77[00m
[98m #256: steps:10280 interval_time:5.05 train_time:22.73[00m
[98m #257: steps:10320 interval_time:5.13 train_time:22.70[00m
[98m #258: steps:10360 interval_time:5.13 train_time:22.91[00m
[98m #259: steps:10400 interval_time:5.33 train_time:22.67[00m
[98m #260: steps:10440 interval_time:4.99 train_time:22.87[00m
[98m #261: steps:10480 interval_time:5.01 train_time:22.60[00m
[98m #262: steps:10520 interval_time:5.04 train_time:22.74[00m
[98m #263: steps:10560 interval_time:5.06 train_time:22.73[00m
[98m #264: steps:10600 interval_time:5.01 train_time:22.38[00m
[98m #265: steps:10640 interval_time:5.07 train_time:22.50[00m
[98m #266: steps:10680 interval_time:5.10 train_time:22.68[00m
[98m #267: steps:10720 interval_time:5.33 train_time:22.78[00m
[98m #268: steps:10760 i

[98m #377: steps:15120 interval_time:4.43 train_time:19.51[00m
[98m #378: steps:15160 interval_time:4.43 train_time:19.52[00m
[98m #379: steps:15200 interval_time:4.43 train_time:19.52[00m
[98m #380: steps:15240 interval_time:4.43 train_time:19.49[00m
[98m #381: steps:15280 interval_time:4.44 train_time:19.53[00m
[98m #382: steps:15320 interval_time:4.43 train_time:19.55[00m
[98m #383: steps:15360 interval_time:4.44 train_time:19.56[00m
[98m #384: steps:15400 interval_time:4.45 train_time:19.53[00m
[98m #385: steps:15440 interval_time:4.45 train_time:19.52[00m
[98m #386: steps:15480 interval_time:4.50 train_time:19.55[00m
[98m #387: steps:15520 interval_time:4.45 train_time:19.50[00m
[98m #388: steps:15560 interval_time:4.44 train_time:19.51[00m
[98m #389: steps:15600 interval_time:4.49 train_time:19.50[00m
[98m #390: steps:15640 interval_time:4.43 train_time:19.52[00m
[98m #391: steps:15680 interval_time:4.43 train_time:19.52[00m
[98m #392: steps:15720 i

[98m #500: steps:20040 interval_time:30.18 train_time:19.52[00m
[98m #501: steps:20080 interval_time:4.46 train_time:19.52[00m
[98m #502: steps:20120 interval_time:4.51 train_time:19.53[00m
[98m #503: steps:20160 interval_time:4.49 train_time:19.54[00m
[98m #504: steps:20200 interval_time:4.44 train_time:19.53[00m
[98m #505: steps:20240 interval_time:4.50 train_time:19.51[00m
[98m #506: steps:20280 interval_time:4.44 train_time:19.53[00m
[98m #507: steps:20320 interval_time:4.45 train_time:19.52[00m
[98m #508: steps:20360 interval_time:4.46 train_time:19.51[00m
[98m #509: steps:20400 interval_time:4.46 train_time:19.49[00m
[98m #510: steps:20440 interval_time:4.54 train_time:19.50[00m
[98m #511: steps:20480 interval_time:4.44 train_time:19.51[00m
[98m #512: steps:20520 interval_time:4.44 train_time:19.51[00m
[98m #513: steps:20560 interval_time:4.43 train_time:19.50[00m
[98m #514: steps:20600 interval_time:4.45 train_time:19.50[00m
[98m #515: steps:20640 

[98m #624: steps:25000 interval_time:4.53 train_time:19.53[00m
[98m #625: steps:25040 interval_time:4.47 train_time:19.53[00m
[98m #626: steps:25080 interval_time:4.42 train_time:19.52[00m
[98m #627: steps:25120 interval_time:4.43 train_time:19.52[00m
[98m #628: steps:25160 interval_time:4.42 train_time:19.53[00m
[98m #629: steps:25200 interval_time:4.43 train_time:19.53[00m
[98m #630: steps:25240 interval_time:4.45 train_time:19.50[00m
[98m #631: steps:25280 interval_time:4.46 train_time:19.52[00m
[98m #632: steps:25320 interval_time:4.45 train_time:19.52[00m
[98m #633: steps:25360 interval_time:4.45 train_time:19.50[00m
[98m #634: steps:25400 interval_time:4.43 train_time:19.52[00m
[98m #635: steps:25440 interval_time:4.43 train_time:19.51[00m
[98m #636: steps:25480 interval_time:4.43 train_time:19.53[00m
[98m #637: steps:25520 interval_time:4.43 train_time:19.52[00m
[98m #638: steps:25560 interval_time:4.44 train_time:19.53[00m
[98m #639: steps:25600 i

[98m #748: steps:29960 interval_time:4.44 train_time:19.53[00m
[98m #749: steps:30000 interval_time:4.42 train_time:19.51[00m
[91m Step_0030039: mean_reward:0.900 mean_dist:0.025 var_dist:0.000[00m
[98m #750: steps:30040 interval_time:30.98 train_time:19.52[00m
[98m #751: steps:30080 interval_time:4.43 train_time:19.51[00m
[98m #752: steps:30120 interval_time:4.43 train_time:19.53[00m
[98m #753: steps:30160 interval_time:4.46 train_time:19.53[00m
[98m #754: steps:30200 interval_time:4.44 train_time:19.53[00m
[98m #755: steps:30240 interval_time:4.45 train_time:19.55[00m
[98m #756: steps:30280 interval_time:4.45 train_time:19.54[00m
[98m #757: steps:30320 interval_time:4.43 train_time:19.54[00m
[98m #758: steps:30360 interval_time:4.45 train_time:19.53[00m
[98m #759: steps:30400 interval_time:4.44 train_time:19.52[00m
[98m #760: steps:30440 interval_time:4.46 train_time:19.54[00m
[98m #761: steps:30480 interval_time:4.45 train_time:19.53[00m
[98m #762: ste

[98m #871: steps:34880 interval_time:4.44 train_time:19.54[00m
[98m #872: steps:34920 interval_time:4.46 train_time:19.53[00m
[98m #873: steps:34960 interval_time:4.42 train_time:19.56[00m
[98m #874: steps:35000 interval_time:4.43 train_time:19.54[00m
[98m #875: steps:35040 interval_time:4.44 train_time:19.54[00m
[98m #876: steps:35080 interval_time:4.45 train_time:19.57[00m
[98m #877: steps:35120 interval_time:4.44 train_time:19.56[00m
[98m #878: steps:35160 interval_time:4.45 train_time:19.58[00m
[98m #879: steps:35200 interval_time:4.43 train_time:19.55[00m
[98m #880: steps:35240 interval_time:4.45 train_time:19.58[00m
[98m #881: steps:35280 interval_time:4.49 train_time:19.56[00m
[98m #882: steps:35320 interval_time:4.44 train_time:19.54[00m
[98m #883: steps:35360 interval_time:4.45 train_time:19.52[00m
[98m #884: steps:35400 interval_time:4.43 train_time:19.53[00m
[98m #885: steps:35440 interval_time:4.43 train_time:19.55[00m
[98m #886: steps:35480 i

[98m #995: steps:39840 interval_time:4.43 train_time:19.50[00m
[98m #996: steps:39880 interval_time:4.45 train_time:19.51[00m
[98m #997: steps:39920 interval_time:4.41 train_time:19.48[00m
[98m #998: steps:39960 interval_time:4.43 train_time:19.49[00m
[98m #999: steps:40000 interval_time:4.47 train_time:19.51[00m
[91m Step_0040039: mean_reward:0.924 mean_dist:0.019 var_dist:0.000[00m
[98m #1000: steps:40040 interval_time:32.87 train_time:19.42[00m
[98m #1001: steps:40080 interval_time:4.41 train_time:19.41[00m
[98m #1002: steps:40120 interval_time:4.41 train_time:19.40[00m
[98m #1003: steps:40160 interval_time:4.43 train_time:19.44[00m
[98m #1004: steps:40200 interval_time:4.44 train_time:19.45[00m
[98m #1005: steps:40240 interval_time:4.43 train_time:19.46[00m
[98m #1006: steps:40280 interval_time:4.45 train_time:19.46[00m
[98m #1007: steps:40320 interval_time:4.43 train_time:19.46[00m
[98m #1008: steps:40360 interval_time:4.43 train_time:19.46[00m
[98m 

[98m #1116: steps:44680 interval_time:4.47 train_time:19.44[00m
[98m #1117: steps:44720 interval_time:4.42 train_time:19.45[00m
[98m #1118: steps:44760 interval_time:4.44 train_time:19.45[00m
[98m #1119: steps:44800 interval_time:4.41 train_time:19.45[00m
[98m #1120: steps:44840 interval_time:4.42 train_time:19.45[00m
[98m #1121: steps:44880 interval_time:4.40 train_time:19.44[00m
[98m #1122: steps:44920 interval_time:4.41 train_time:19.46[00m
[98m #1123: steps:44960 interval_time:4.41 train_time:19.44[00m
[98m #1124: steps:45000 interval_time:4.41 train_time:19.44[00m
[98m #1125: steps:45040 interval_time:4.48 train_time:19.45[00m
[98m #1126: steps:45080 interval_time:4.41 train_time:19.46[00m
[98m #1127: steps:45120 interval_time:4.41 train_time:19.47[00m
[98m #1128: steps:45160 interval_time:4.42 train_time:19.44[00m
[98m #1129: steps:45200 interval_time:4.42 train_time:19.45[00m
[98m #1130: steps:45240 interval_time:4.44 train_time:19.45[00m
[98m #113