In [1]:
# Add the project root (the parent directory) to the Python path.
import sys

sys.path.append('..')

In [2]:
import torch
import gym
import numpy as np

from src.part3.MLP import MultiLayerPerceptron as MLP

from src.part5.DQN import prepare_training_inputs
from src.part5.DDPG import DDPG, Actor, Critic
from src.part5.DDPG import OrnsteinUhlenbeckProcess as OUProcess
from src.common.memory.memory import ReplayMemory
from src.common.train_utils import to_tensor
from src.common.target_update import soft_update

In [3]:
# True  -> run the training loop and save the agent's state dict to disk.
# False -> skip training and load the previously saved state dict instead.
FROM_SCRATCH = False

In [4]:
# --- Hyperparameters ---------------------------------------------------------
# NOTE(review): lr_mu, lr_q and gamma are not referenced again in the visible
# cells (the DDPG constructor is called without them), so the agent presumably
# uses its own defaults -- confirm before tuning these values.
gamma = 0.99    # presumably the discount factor
lr_q = 0.001    # presumably the critic (Q) learning rate
lr_mu = 0.005   # presumably the actor (mu) learning rate

# Replay memory settings.
memory_size = 50000          # size argument given to ReplayMemory
batch_size = 256             # number of experiences sampled per update
sampling_only_until = 2000   # no updates until the memory holds this many items

# Target-network settings.
tau = 0.001  # polyak parameter for soft target update

In [5]:
# Online networks and their separately constructed target counterparts.
# Instantiation order is: actor, actor target, critic, critic target.
actor = Actor()
actor_target = Actor()
critic = Critic()
critic_target = Critic()

# Bundle the four networks into a single DDPG agent.
agent = DDPG(
    critic=critic,
    critic_target=critic_target,
    actor=actor,
    actor_target=actor_target,
)

# Replay memory constructed with the configured `memory_size`.
memory = ReplayMemory(memory_size)

In [6]:
total_eps = 1000
print_every = 100

env = gym.make('Pendulum-v0')

# The environment is released in `finally`, so it is closed whether training
# completes, is interrupted (e.g. KeyboardInterrupt), or is skipped entirely
# in favour of loading a saved checkpoint.
try:
    if FROM_SCRATCH:
        for n_epi in range(total_eps):
            # Fresh exploration-noise process for every episode.
            ou_noise = OUProcess(mu=np.zeros(1))
            s = env.reset()
            cum_r = 0
            done = False

            while not done:
                s = to_tensor(s, size=(1, 3))
                # Agent's action plus Ornstein-Uhlenbeck exploration noise.
                a = agent.get_action(s).numpy() + ou_noise()[0]
                ns, r, done, info = env.step(a)

                # Store the transition; every element is reshaped to a
                # single-row tensor before being pushed to the memory.
                experience = (s,
                              torch.tensor(a).view(1, 1),
                              torch.tensor(r).view(1, 1),
                              torch.tensor(ns).view(1, 3),
                              torch.tensor(done).view(1, 1))
                memory.push(experience)

                s = ns
                cum_r += r

                # Only collect experience until the memory holds at least
                # `sampling_only_until` items; afterwards update every step
                # (including the final step of an episode).
                if len(memory) >= sampling_only_until:
                    # train agent
                    sampled_exps = memory.sample(batch_size)
                    sampled_exps = prepare_training_inputs(sampled_exps)
                    agent.update(*sampled_exps)
                    # update target networks
                    soft_update(agent.actor, agent.actor_target, tau)
                    soft_update(agent.critic, agent.critic_target, tau)

            if n_epi % print_every == 0:
                msg = (n_epi, cum_r)
                print("Episode : {} | Cumulative Reward : {} |".format(*msg))

        # Persist the trained agent so later runs can skip training.
        torch.save(agent.state_dict(), 'ddpg_cartpole.ptb')
    else:
        # Reuse the previously saved agent instead of training.
        agent.load_state_dict(torch.load('ddpg_cartpole.ptb'))
finally:
    env.close()



In [None]:
env = gym.make('Pendulum-v0')

# Roll out one rendered episode using the agent's action directly (no
# exploration noise is added here) and report the accumulated reward.
cum_r = 0
try:
    s = env.reset()
    env.render()
    done = False

    while not done:
        s = to_tensor(s, size=(1, 3))
        a = agent.get_action(s).numpy()
        ns, r, done, info = env.step(a)
        cum_r += r  # previously initialised but never accumulated
        s = ns
        env.render()
finally:
    # Always release the environment (and its render window), even if the
    # rollout raises or is interrupted.
    env.close()

print("Cumulative Reward : {} |".format(cum_r))