In [1]:
import gym

import torch
import torch.nn as nn

from models.ddpg import DDPG

In [44]:
class Actor(nn.Module):
    """Deterministic policy network: one hidden ReLU layer and a tanh-bounded output.

    Args:
        n_sts: Number of input (state) features.
        n_acts: Number of output (action) components.
        dim: Width of the hidden layer.
        act_limit: Factor applied to the tanh output, so every action component
            lies in ``[-act_limit, act_limit]``. Defaults to 2.0, the scale that
            was previously hard-coded in ``forward``.
    """

    def __init__(self, n_sts: int, n_acts: int, dim: int, act_limit: float = 2.0):
        super().__init__()

        self.fc1 = nn.Linear(n_sts, dim)
        self.fc2 = nn.Linear(dim, n_acts)

        self.relu1 = nn.ReLU()

        # Plain Python number (not a parameter/buffer), so state_dict keys are unchanged.
        self.act_limit = act_limit

    def forward(self, x):
        """Map a batch of states to actions bounded by ``act_limit``.

        Args:
            x: Tensor whose last dimension has ``n_sts`` features.

        Returns:
            Tensor with ``n_acts`` features in the last dimension, each value
            in ``[-act_limit, act_limit]``.
        """
        x = self.relu1(self.fc1(x))
        # tanh squashes to [-1, 1]; rescale to the configured action bound.
        x = torch.tanh(self.fc2(x)) * self.act_limit

        return x

In [46]:
class Critic(nn.Module):
    """State-action value network: concatenated (state, action) -> scalar estimate.

    Args:
        n_sts: Number of state features.
        n_acts: Number of action components.
        dim: Width of the single hidden layer.
    """

    def __init__(self, n_sts: int, n_acts: int, dim: int):
        super().__init__()

        # The hidden layer consumes state and action features side by side.
        self.fc1 = nn.Linear(in_features=n_sts + n_acts, out_features=dim)
        self.fc2 = nn.Linear(in_features=dim, out_features=1)

        self.relu1 = nn.ReLU()

    def forward(self, xs):
        """Score a batch of state-action pairs.

        Args:
            xs: A two-item sequence ``(state, action)``; both tensors are
                joined along dimension 1, so they must be batched 2-D tensors.

        Returns:
            Tensor with one value per row of the batch (last dimension of size 1).
        """
        state, action = xs

        joint = torch.cat((state, action), dim=1)
        hidden = self.relu1(self.fc1(joint))

        return self.fc2(hidden)

In [48]:
# --- Optimiser step sizes ---------------------------------------------------
cri_lr = 0.001   # learning rate handed to the critic's Adam optimiser
act_lr = 0.0005  # learning rate handed to the actor's Adam optimiser

# Presumably the soft target-update rate for DDPG.
# NOTE(review): no visible cell passes `tau` anywhere — confirm it is intended.
tau = 0.005

# --- Network sizes ----------------------------------------------------------
n_sts, n_acts = 3, 1  # state / action feature counts fed to Actor and Critic
dim = 256             # hidden-layer width shared by both networks

# --- Training schedule (forwarded to ddpg.train in this order) --------------
n_epis = 500
n_epochs = 200
n_rollout = 10

# Lower / upper action bound handed to DDPG.
act_range = (-2, 2)

In [49]:
# Create the Pendulum-v1 environment used for training.
# NOTE(review): render_mode='human' presumably draws every step in a window,
# which may slow training considerably — confirm this is wanted here.
env = gym.make('Pendulum-v1', render_mode='human')

In [52]:
# Policy network and its Adam optimiser.
act = Actor(n_sts=n_sts, n_acts=n_acts, dim=dim)
act_opt = torch.optim.Adam(params=act.parameters(), lr=act_lr)

# Value network and its Adam optimiser. Built after the actor so that weight
# initialisation draws from the RNG in the same order as before.
cri = Critic(n_sts=n_sts, n_acts=n_acts, dim=dim)
cri_opt = torch.optim.Adam(params=cri.parameters(), lr=cri_lr)

In [53]:
# Wire the environment, both networks and their optimisers into the DDPG agent.
# NOTE(review): `tau` from the config cell is not passed here — confirm the
# agent's own default is what is intended.
ddpg = DDPG(
    env,
    n_acts,
    act,
    act_opt,
    cri,
    cri_opt,
    act_noise=0.3,
    act_range=act_range,
)

In [54]:
# Run training with the schedule from the config cell (n_epis, n_epochs,
# n_rollout, in that order). The captured output below shows a log line every
# 20 epochs and ends in a manual KeyboardInterrupt.
# NOTE(review): exact meaning of each argument is defined in models.ddpg — confirm there.
ddpg.train(n_epis, n_epochs, n_rollout)

  if not isinstance(terminated, (bool, np.bool8)):


epoch: 20, score: -82.26434258703087, n_buffer: 210
epoch: 40, score: -82.28790949686429, n_buffer: 410
epoch: 60, score: -84.1544112230634, n_buffer: 610
epoch: 80, score: -82.39202267833173, n_buffer: 810
epoch: 100, score: -84.45549551978381, n_buffer: 1010
epoch: 120, score: -82.45998536758043, n_buffer: 1210
epoch: 140, score: -84.12448597869985, n_buffer: 1410
epoch: 160, score: -82.50604223302429, n_buffer: 1610
epoch: 180, score: -84.09363139074024, n_buffer: 1810
epoch: 20, score: -78.7876571554002, n_buffer: 2210
epoch: 40, score: -71.44935494543515, n_buffer: 2410
epoch: 60, score: -84.49926259242923, n_buffer: 2610
epoch: 80, score: -88.16698681940562, n_buffer: 2810
epoch: 100, score: -89.88086753701673, n_buffer: 3010
epoch: 120, score: -92.84446437282574, n_buffer: 3210
epoch: 140, score: -95.61210760627121, n_buffer: 3410
epoch: 160, score: -94.51181420781562, n_buffer: 3610
epoch: 180, score: -94.40684543714073, n_buffer: 3810
epoch: 20, score: -54.77995585216009, n_bu

KeyboardInterrupt: 

In [None]:
# Release the environment once training is finished or has been interrupted.
env.close()