In [1]:
import os
import time
import torch
import gym
import numpy as np
from gym import wrappers
from PIL import Image

from TD3.td3 import TD3
from TD3.utils import PrioritizedReplayBuffer, mkdir

In [2]:
# ---- Environment / reproducibility ----
env_name = 'BipedalWalkerHardcore-v2'
random_seed = 42

# ---- Per-episode schedules (base value and decay rate) ----
lr_base = 1e-3              # initial learning rate
lr_decay = 5e-5             # learning-rate decay rate
exp_noise_base = 0.3        # initial std-dev of the exploration noise
exp_noise_decay = 1e-4      # exploration-noise decay rate

# ---- TD3 update parameters ----
gamma = 0.99                # discount for future rewards
batch_size = 256            # num of transitions sampled from replay buffer
polyak = 0.9999             # target policy update parameter (1-tau)
policy_noise = 0.2          # target policy smoothing noise
noise_clip = 0.5            # clipping bound for the smoothing noise
policy_delay = 2            # delayed policy updates parameter

# ---- Run length, replay capacity and logging ----
max_episodes = 100000       # max num of episodes
max_timesteps = 5000        # max timesteps in one episode
max_buffer_length = 5000000 # replay buffer capacity
log_interval = 10           # print avg reward after interval

In [3]:
# Layer specifications for the actor network, one entry per linear layer.
# The `None` entries are placeholders: TD3Trainer fills the first input size
# with the observation dimension and the last output size with the action
# dimension.
actor_config = [
    dict(dim=[None, 256], dropout=False, activation='relu'),
    dict(dim=[256, 256], dropout=True, activation='relu'),
    dict(dim=[256, 128], dropout=False, activation='relu'),
    dict(dim=[128, None], dropout=False, activation='tanh'),
]

# Layer specifications for the critic networks. TD3Trainer fills the first
# input size with (observation dimension + action dimension); the final layer
# has a single output and no activation.
critic_config = [
    dict(dim=[None, 512], dropout=False, activation='relu'),
    dict(dim=[512, 512], dropout=False, activation='relu'),
    dict(dim=[512, 128], dropout=False, activation='relu'),
    dict(dim=[128, 1], dropout=False, activation=False),
]

In [4]:
class TD3Trainer():
    """Builds a Gym environment, a TD3 policy and a prioritized replay buffer,
    and drives training (`train`) and evaluation (`test`).

    Learning rate and exploration noise are recomputed every episode as
    ``base / (1 + episode * decay)`` and clamped from below by their
    ``*_minimum`` values.
    """

    # Shared by the periodic progress line and the final "solved" line so the
    # two can never drift apart again.
    _PROGRESS_FORMAT = ("Ep:{:5d}  Rew:{:8.2f}  Avg Rew:{:8.2f}  LR:{:8.8f}  Bf:{:2.0f} {:0.4f}  "
                        "EN:{:0.4f}  Loss: {:5.3f} {:5.3f} {:5.3f}")

    def __init__(self, env_name, actor_config, critic_config, random_seed=42, lr_base=0.001, lr_decay=0.00005,
                 exp_noise_base=0.3, exp_noise_decay=0.0001, gamma=0.99, batch_size=1024,
                 polyak=0.9999, policy_noise=0.2, noise_clip=0.5, policy_delay=2,
                 max_episodes=100000, max_timesteps=3000, max_buffer_length=5000000,
                 log_interval=5, threshold=None, lr_minimum=1e-10, exp_noise_minimum=1e-10,
                 record_videos=True, record_interval=100, beta_multiplier=0.0001):
        """Create the environment, networks and replay buffer.

        Args:
            env_name: Gym environment id passed to ``gym.make``.
            actor_config: list of layer dicts for the actor; the first input
                size and last output size are filled in from the environment.
                The caller's list is NOT modified.
            critic_config: list of layer dicts for the critic; the first input
                size is set to state_dim + action_dim. Not modified either.
            random_seed: seed for the env, torch and numpy. ``None`` (or
                ``False``) disables seeding; ``0`` is a valid seed.
            lr_base, lr_decay, lr_minimum: learning-rate schedule.
            exp_noise_base, exp_noise_decay, exp_noise_minimum: schedule of the
                std-dev of the Gaussian exploration noise.
            gamma, batch_size, polyak, policy_noise, noise_clip, policy_delay:
                forwarded unchanged to ``policy.update``.
            max_episodes: upper bound on training episodes.
            max_timesteps: upper bound on steps per episode (train and test).
            max_buffer_length: replay buffer capacity.
            log_interval: save the policy and print progress every this many
                episodes.
            threshold: average reward at which training stops; defaults to the
                environment spec's ``reward_threshold``.
            record_videos: wrap the env in ``wrappers.Monitor`` when truthy.
            record_interval: a video is recorded on episodes divisible by this.
            beta_multiplier: ``beta = min(episode * beta_multiplier, 1)`` is
                passed to ``policy.update`` (presumably the prioritized-replay
                importance-sampling exponent -- confirm against TD3.update).
        """
        self.algorithm_name = 'td3'
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.record_videos = record_videos
        self.record_interval = record_interval
        # The Monitor callback below reads this flag, so define it first.
        self.should_record = False
        if self.record_videos:
            videos_dir = mkdir('.', 'videos')
            monitor_dir = mkdir(videos_dir, self.algorithm_name)
            should_record = lambda i: self.should_record
            self.env = wrappers.Monitor(self.env, monitor_dir, video_callable=should_record, force=True)
        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.shape[0]
        self.action_low = self.env.action_space.low
        self.action_high = self.env.action_space.high
        if threshold is not None:
            self.threshold = threshold
        else:
            self.threshold = self.env.spec.reward_threshold

        # Work on private copies: filling in the placeholder dimensions must
        # not rewrite the caller's (possibly shared, notebook-global) configs.
        self.actor_config = self._copy_config(actor_config)
        self.critic_config = self._copy_config(critic_config)
        self.actor_config[0]['dim'][0] = self.state_dim
        self.actor_config[-1]['dim'][1] = self.action_dim
        self.critic_config[0]['dim'][0] = self.state_dim + self.action_dim

        self.random_seed = random_seed
        self.lr_base = lr_base
        self.lr_decay = lr_decay
        self.lr_minimum = lr_minimum
        self.exp_noise_base = exp_noise_base
        self.exp_noise_decay = exp_noise_decay
        self.exp_noise_minimum = exp_noise_minimum
        self.gamma = gamma
        self.batch_size = batch_size
        self.polyak = polyak
        self.beta_multiplier = beta_multiplier
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_delay = policy_delay
        self.max_episodes = max_episodes
        self.max_timesteps = max_timesteps
        self.max_buffer_length = max_buffer_length
        self.log_interval = log_interval

        prdir = mkdir('.', 'preTrained')
        self.directory = mkdir(prdir, self.algorithm_name)
        self.filename = "{}_{}_{}".format(self.algorithm_name, self.env_name, self.random_seed)

        self.policy = TD3(self.actor_config, self.critic_config, self.action_low, self.action_high)
        self.replay_buffer = PrioritizedReplayBuffer(size=self.max_buffer_length, alpha=0.8)

        self.reward_history = []
        self.make_plots = False

        # `0` is a legitimate seed, so test for "no seed" explicitly instead of
        # relying on truthiness. `False` is still accepted as "do not seed".
        if self.random_seed is not None and self.random_seed is not False:
            print("Random Seed: {}".format(self.random_seed))
            self.env.seed(self.random_seed)
            torch.manual_seed(self.random_seed)
            np.random.seed(self.random_seed)

    @staticmethod
    def _copy_config(config):
        """Return a copy of a layer config with independent dicts and 'dim' lists."""
        return [dict(layer, dim=list(layer['dim'])) for layer in config]

    def _progress_line(self, episode, ep_reward, avg_reward, learning_rate, beta,
                       exploration_noise, avg_actor_loss, avg_Q1_loss, avg_Q2_loss):
        """Format one progress line (episode stats, schedules, buffer fill, losses)."""
        return self._PROGRESS_FORMAT.format(
            episode, ep_reward, avg_reward, learning_rate, self.replay_buffer.get_fill(), beta,
            exploration_noise, avg_actor_loss, avg_Q1_loss, avg_Q2_loss)

    def train(self):
        """Run the training loop.

        Each episode collects transitions with Gaussian exploration noise,
        stores them in the replay buffer and, once the episode ends, calls
        ``policy.update`` with the index of the last step. Per-episode rewards
        are written to ``train_<algorithm>.txt``. Training stops after
        ``max_episodes`` or once the 100-episode average reward reaches
        ``self.threshold`` (after more than 100 episodes), in which case the
        policy is saved under ``<filename>_solved``.
        """
        start_time = time.time()
        print("Training started ... \n")
        print("action_space={}".format(self.env.action_space))
        print("obs_space={}".format(self.env.observation_space))
        print("threshold={}".format(self.threshold))
        print("action_low={} action_high={} \n".format(self.action_low, self.action_high))

        # loading models
        self.policy.load(self.directory, self.filename)

        # The context manager guarantees the log is closed on every exit path:
        # solved, max_episodes exhausted, or an interrupt/exception.
        with open("train_{}.txt".format(self.algorithm_name), "w+") as log_f:

            for episode in range(1, self.max_episodes + 1):

                # Only record video every `record_interval` episodes
                if episode % self.record_interval == 0:
                    self.should_record = True

                ep_reward = 0.0
                state = self.env.reset()

                # Per-episode schedules
                exploration_noise = max(self.exp_noise_base / (1.0 + episode * self.exp_noise_decay),
                                        self.exp_noise_minimum)
                learning_rate = max(self.lr_base / (1.0 + episode * self.lr_decay), self.lr_minimum)
                beta = min(episode * self.beta_multiplier, 1)
                self.policy.set_optimizers(lr=learning_rate)

                for t in range(self.max_timesteps):

                    # select action and add exploration noise:
                    action = self.policy.select_action(state)
                    action = action + np.random.normal(0, exploration_noise, size=self.action_dim)
                    action = action.clip(self.action_low, self.action_high)

                    # take action in env:
                    next_state, reward, done, _ = self.env.step(action)
                    self.replay_buffer.add(state, action, reward, next_state, float(done))
                    state = next_state

                    ep_reward += reward

                    # if episode is done then update policy:
                    if done or t == (self.max_timesteps - 1):
                        self.policy.update(self.replay_buffer, t, self.batch_size, self.gamma, self.polyak,
                                           self.policy_noise, self.noise_clip, self.policy_delay, beta)
                        break

                self.reward_history.append(ep_reward)
                avg_reward = np.mean(self.reward_history[-100:])

                # logging updates:
                log_f.write('{},{}\n'.format(episode, ep_reward))
                log_f.flush()

                avg_actor_loss = np.mean(self.policy.actor_loss_list[-100:])
                avg_Q1_loss = np.mean(self.policy.Q1_loss_list[-100:])
                avg_Q2_loss = np.mean(self.policy.Q2_loss_list[-100:])

                # Truncate training history if we don't plan to plot it later
                if not self.make_plots:
                    self.policy.truncate_loss_lists()
                    if len(self.reward_history) > 100:
                        self.reward_history.pop(0)

                # Save and print avg reward every log interval:
                if episode % self.log_interval == 0:
                    self.policy.save(self.directory, self.filename)
                    print(self._progress_line(episode, ep_reward, avg_reward, learning_rate, beta,
                                              exploration_noise, avg_actor_loss, avg_Q1_loss, avg_Q2_loss))

                self.should_record = False

                # if avg reward > threshold then save and stop training.
                # A missing threshold (None) means "never stop early" rather
                # than a TypeError on the comparison.
                solved = (self.threshold is not None
                          and avg_reward >= self.threshold
                          and episode > 100)
                if solved:
                    print(self._progress_line(episode, ep_reward, avg_reward, learning_rate, beta,
                                              exploration_noise, avg_actor_loss, avg_Q1_loss, avg_Q2_loss))
                    print("########## Solved! ###########")
                    name = self.filename + '_solved'
                    self.policy.save(self.directory, name)
                    training_time = time.time() - start_time
                    print("Training time: {:6.2f} sec".format(training_time))
                    break

    def test(self, episodes=3, render=True, save_gif=True):
        """Run the current policy without exploration noise.

        Args:
            episodes: number of evaluation episodes.
            render: when True and ``save_gif`` is False, render each step with
                the environment's default render mode.
            save_gif: when True, render each step as an RGB array and save it
                as ``<t>.jpg`` in a per-episode directory.

        The environment is closed once, after the last episode (or on error).
        """
        gifdir = mkdir('.', 'gif')
        algdir = mkdir(gifdir, self.algorithm_name)

        try:
            for episode in range(1, episodes + 1):
                ep_reward = 0.0
                state = self.env.reset()
                epdir = mkdir(algdir, str(episode))

                for t in range(self.max_timesteps):
                    action = self.policy.select_action(state)
                    state, reward, done, _ = self.env.step(action)
                    ep_reward += reward

                    if save_gif:
                        img = self.env.render(mode='rgb_array')
                        img = Image.fromarray(img)
                        img.save('{}/{}.jpg'.format(epdir, t))
                    elif render:
                        self.env.render()
                    if done:
                        break

                print('Test episode: {}\tReward: {:4.2f}'.format(episode, ep_reward))
        finally:
            # Closing inside the loop would leave later episodes resetting an
            # already-closed environment.
            self.env.close()
            

In [5]:
# Build the trainer from the hyperparameters defined in the cells above, then
# start training (runs until the reward threshold is reached, the episode
# limit is hit, or the cell is interrupted).
agent = TD3Trainer(
    env_name,
    actor_config,
    critic_config,
    random_seed=random_seed,
    lr_base=lr_base,
    lr_decay=lr_decay,
    exp_noise_base=exp_noise_base,
    exp_noise_decay=exp_noise_decay,
    gamma=gamma,
    batch_size=batch_size,
    polyak=polyak,
    policy_noise=policy_noise,
    noise_clip=noise_clip,
    policy_delay=policy_delay,
    max_episodes=max_episodes,
    max_timesteps=max_timesteps,
    max_buffer_length=max_buffer_length,
    log_interval=log_interval,
)
agent.train()

ACTOR=Sequential(
  (0): Linear(in_features=24, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=256, bias=True)
  (3): Dropout(p=0.2)
  (4): ReLU()
  (5): Linear(in_features=256, out_features=128, bias=True)
  (6): ReLU()
  (7): Linear(in_features=128, out_features=4, bias=True)
  (8): Tanh()
)
ACTOR=Sequential(
  (0): Linear(in_features=24, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=256, bias=True)
  (3): Dropout(p=0.2)
  (4): ReLU()
  (5): Linear(in_features=256, out_features=128, bias=True)
  (6): ReLU()
  (7): Linear(in_features=128, out_features=4, bias=True)
  (8): Tanh()
)
CRITIC=Sequential(
  (0): Linear(in_features=28, out_features=512, bias=True)
  (1): ReLU()
  (2): Linear(in_features=512, out_features=512, bias=True)
  (3): ReLU()
  (4): Linear(in_features=512, out_features=128, bias=True)
  (5): ReLU()
  (6): Linear(in_features=128, out_features=1, bias=True)
)
CRITIC=Sequential(
  (0): Line

Ep:  590  Rew: -113.99  Avg Rew: -121.04  LR:0.00097135  Bf: 7 0.0590  EN:0.2833  Loss: 0.866 0.871 0.996
Ep:  600  Rew: -107.86  Avg Rew: -120.06  LR:0.00097087  Bf: 7 0.0600  EN:0.2830  Loss: 0.932 0.855 0.677
Ep:  610  Rew: -106.14  Avg Rew: -117.68  LR:0.00097040  Bf: 7 0.0610  EN:0.2828  Loss: 1.045 0.820 0.993
Ep:  620  Rew: -102.94  Avg Rew: -114.42  LR:0.00096993  Bf: 7 0.0620  EN:0.2825  Loss: 1.235 1.155 1.045
Ep:  630  Rew: -115.91  Avg Rew: -111.27  LR:0.00096946  Bf: 7 0.0630  EN:0.2822  Loss: 1.084 0.909 0.860
Ep:  640  Rew: -100.99  Avg Rew: -110.24  LR:0.00096899  Bf: 7 0.0640  EN:0.2820  Loss: 1.056 1.131 1.552
Ep:  650  Rew: -101.04  Avg Rew: -108.59  LR:0.00096852  Bf: 7 0.0650  EN:0.2817  Loss: 1.127 0.697 0.714
Ep:  660  Rew:  -98.34  Avg Rew: -107.27  LR:0.00096805  Bf: 7 0.0660  EN:0.2814  Loss: 1.191 0.747 0.712
Ep:  670  Rew:  -98.67  Avg Rew: -105.73  LR:0.00096759  Bf: 7 0.0670  EN:0.2812  Loss: 1.209 0.772 0.787
Ep:  680  Rew: -123.64  Avg Rew: -105.41  LR:0

Ep: 1370  Rew:  -99.82  Avg Rew: -119.09  LR:0.00093589  Bf:14 0.1370  EN:0.2639  Loss: 4.629 1.421 1.368
Ep: 1380  Rew: -120.32  Avg Rew: -118.51  LR:0.00093545  Bf:14 0.1380  EN:0.2636  Loss: 4.304 1.441 1.475
Ep: 1390  Rew: -186.61  Avg Rew: -119.39  LR:0.00093502  Bf:14 0.1390  EN:0.2634  Loss: 4.503 1.464 1.338
Ep: 1400  Rew: -116.55  Avg Rew: -118.05  LR:0.00093458  Bf:15 0.1400  EN:0.2632  Loss: 4.299 1.449 1.468
Ep: 1410  Rew: -101.77  Avg Rew: -115.18  LR:0.00093414  Bf:15 0.1410  EN:0.2629  Loss: 4.361 2.015 2.159
Ep: 1420  Rew: -138.89  Avg Rew: -115.50  LR:0.00093371  Bf:15 0.1420  EN:0.2627  Loss: 4.521 1.591 1.597
Ep: 1430  Rew: -113.52  Avg Rew: -115.02  LR:0.00093327  Bf:15 0.1430  EN:0.2625  Loss: 4.436 1.594 1.590
Ep: 1440  Rew: -110.99  Avg Rew: -115.44  LR:0.00093284  Bf:15 0.1440  EN:0.2622  Loss: 4.642 1.694 1.719
Ep: 1450  Rew: -102.73  Avg Rew: -115.84  LR:0.00093240  Bf:16 0.1450  EN:0.2620  Loss: 4.654 1.191 1.406
Ep: 1460  Rew: -110.41  Avg Rew: -116.13  LR:0

Ep: 2150  Rew: -121.43  Avg Rew: -101.97  LR:0.00090293  Bf:28 0.2150  EN:0.2469  Loss: 9.042 1.903 1.915
Ep: 2160  Rew: -148.04  Avg Rew: -101.35  LR:0.00090253  Bf:28 0.2160  EN:0.2467  Loss: 9.150 2.378 2.305
Ep: 2170  Rew: -109.75  Avg Rew: -101.50  LR:0.00090212  Bf:29 0.2170  EN:0.2465  Loss: 9.151 1.936 1.910
Ep: 2180  Rew: -114.63  Avg Rew: -101.21  LR:0.00090171  Bf:29 0.2180  EN:0.2463  Loss: 9.080 2.407 2.275
Ep: 2190  Rew:  -99.16  Avg Rew: -102.57  LR:0.00090131  Bf:29 0.2190  EN:0.2461  Loss: 9.321 2.152 2.075
Ep: 2200  Rew: -142.93  Avg Rew: -101.73  LR:0.00090090  Bf:29 0.2200  EN:0.2459  Loss: 9.263 2.085 2.123
Ep: 2210  Rew:  -96.44  Avg Rew: -103.35  LR:0.00090050  Bf:30 0.2210  EN:0.2457  Loss: 9.190 1.955 1.861
Ep: 2220  Rew: -114.27  Avg Rew: -103.34  LR:0.00090009  Bf:30 0.2220  EN:0.2455  Loss: 9.321 2.230 2.187
Ep: 2230  Rew: -118.99  Avg Rew: -103.03  LR:0.00089969  Bf:30 0.2230  EN:0.2453  Loss: 9.220 2.117 2.036
Ep: 2240  Rew:  -80.09  Avg Rew: -105.21  LR:0

Ep: 2930  Rew: -119.35  Avg Rew:  -98.90  LR:0.00087222  Bf:54 0.2930  EN:0.2320  Loss: 9.736 1.448 1.449
Ep: 2940  Rew:  -97.62  Avg Rew: -100.78  LR:0.00087184  Bf:55 0.2940  EN:0.2318  Loss: 9.740 1.579 1.584
Ep: 2950  Rew: -106.43  Avg Rew:  -96.92  LR:0.00087146  Bf:55 0.2950  EN:0.2317  Loss: 9.707 1.452 1.448
Ep: 2960  Rew:  -61.36  Avg Rew:  -95.72  LR:0.00087108  Bf:55 0.2960  EN:0.2315  Loss: 9.728 1.488 1.566
Ep: 2970  Rew:  -98.21  Avg Rew:  -98.26  LR:0.00087070  Bf:56 0.2970  EN:0.2313  Loss: 9.641 1.547 1.545
Ep: 2980  Rew: -146.90  Avg Rew:  -99.80  LR:0.00087032  Bf:56 0.2980  EN:0.2311  Loss: 9.727 1.659 1.779
Ep: 2990  Rew: -111.39  Avg Rew:  -98.57  LR:0.00086994  Bf:56 0.2990  EN:0.2309  Loss: 9.597 1.661 1.626
Ep: 3000  Rew:  -85.38  Avg Rew:  -96.93  LR:0.00086957  Bf:57 0.3000  EN:0.2308  Loss: 9.701 1.655 1.561
Ep: 3010  Rew: -126.00  Avg Rew: -101.15  LR:0.00086919  Bf:57 0.3010  EN:0.2306  Loss: 9.621 1.656 1.569
Ep: 3020  Rew: -120.41  Avg Rew: -103.63  LR:0

KeyboardInterrupt: 

In [None]:
# Evaluate the trained policy with TD3Trainer.test's defaults: 3 episodes,
# no exploration noise, every step rendered as an RGB array and saved as
# <t>.jpg in a per-episode directory.
agent.test()