In [16]:
import gym

import torch
import torch.nn as nn

from models.ddpg import DDPG

In [17]:
class Actor(nn.Module):
    """Deterministic policy network mapping a state to a bounded action.

    Two ReLU hidden layers of width ``dim`` feed a linear head whose output is
    squashed with ``tanh`` and scaled by ``max_action``. Every action component
    therefore lies in ``[-max_action, max_action]``.

    Args:
        n_sts: Size of the state (observation) vector.
        n_acts: Size of the action vector.
        dim: Width of both hidden layers.
        max_action: Half-width of the symmetric action range. Defaults to 2.0,
            the previously hard-coded scale, which matches ``act_range = (-2, 2)``
            in this notebook. Existing ``Actor(n_sts, n_acts, dim)`` calls
            behave exactly as before.

    Raises:
        ValueError: If ``max_action`` is not strictly positive (a negative
            scale would silently mirror the policy).
    """

    def __init__(self, n_sts: int, n_acts: int, dim: int, max_action: float = 2.0):
        super(Actor, self).__init__()

        if max_action <= 0:
            raise ValueError(f"max_action must be > 0, got {max_action!r}")

        self.fc1 = nn.Linear(n_sts, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.fc3 = nn.Linear(dim, n_acts)

        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

        # Plain Python float (not a parameter/buffer), so state_dict() keys are
        # unchanged and previously saved checkpoints still load.
        self.max_action = float(max_action)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return actions in ``[-max_action, max_action]`` for states ``x``."""
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        return torch.tanh(self.fc3(x)) * self.max_action

In [18]:
class Critic(nn.Module):
    """Q-network that scores a (state, action) pair with a single scalar.

    The state and action tensors are concatenated along dim 1, then passed
    through one ReLU hidden layer of width ``dim`` and a linear output unit.

    Args:
        n_sts: Size of the state vector.
        n_acts: Size of the action vector.
        dim: Width of the hidden layer.
    """

    def __init__(self, n_sts: int, n_acts: int, dim: int):
        super().__init__()

        joint_size = n_sts + n_acts
        self.fc1 = nn.Linear(joint_size, dim)
        self.fc2 = nn.Linear(dim, 1)

        self.relu1 = nn.ReLU()

    def forward(self, xs):
        """Return Q(s, a) for ``xs = (states, actions)``.

        Both tensors are joined along dim 1. This presumably means each is
        (batch, features) — TODO confirm against the DDPG caller.
        """
        states, actions = xs
        joined = torch.cat((states, actions), dim=1)
        hidden = self.relu1(self.fc1(joined))
        q_value = self.fc2(hidden)
        return q_value

In [19]:
# --- Optimiser settings ---
act_lr = 5e-4  # Adam learning rate for the actor
cri_lr = 1e-3  # Adam learning rate for the critic

# NOTE(review): tau is never passed to DDPG in this notebook, so as written it
# has no effect. It is presumably meant as the soft target-update rate — confirm
# whether DDPG accepts it.
tau = 5e-3

# --- Network shapes (must match the environment's observation/action sizes) ---
n_sts, n_acts = 3, 1
dim = 256  # hidden-layer width shared by Actor and Critic

# --- Training schedule (passed positionally to ddpg.train) ---
n_epis, n_epochs, n_rollout = 100, 200, 10

# Action bounds handed to DDPG; keep consistent with the Actor's output scaling.
act_range = (-2, 2)

In [20]:
# Pendulum-v1 environment used for training.
# With render_mode='human', every env.step() draws a window frame. In gym >= 0.26
# classic-control envs it also paces stepping to the env's render_fps. Over the
# thousands of steps of the training cell this dominates run time. Rendering is
# therefore opt-in: set render_mode = 'human' to watch the agent.
render_mode = None
env = gym.make('Pendulum-v1', render_mode=render_mode)

In [21]:
# Build both networks first: Actor before Critic, so parameter initialisation
# draws from the torch RNG in the same order. Then attach one Adam optimizer
# to each network with its own learning rate.
act = Actor(n_sts, n_acts, dim)
cri = Critic(n_sts, n_acts, dim)

act_opt = torch.optim.Adam(act.parameters(), lr=act_lr)
cri_opt = torch.optim.Adam(cri.parameters(), lr=cri_lr)

In [22]:
# Assemble the DDPG agent from the environment, the two networks and their
# optimizers. The positional order follows DDPG's constructor.
ddpg = DDPG(
    env,
    n_acts,
    act,
    act_opt,
    cri,
    cri_opt,
    act_noise=0.3,
    act_range=act_range,
)

In [23]:
# Run DDPG training with the schedule from the hyper-parameter cell.
# Arguments are positional (n_epis, n_epochs, n_rollout). Their exact meaning is
# defined by DDPG.train in models/ddpg.py, which is not visible here. They are
# presumably episodes, steps per episode and an update/logging interval — TODO confirm.
ddpg.train(n_epis, n_epochs, n_rollout)

step: 210, score: -73.26300631959415, n_buffer: 210
step: 410, score: -65.38297641782701, n_buffer: 410
step: 610, score: -82.34483222521912, n_buffer: 610
step: 810, score: -82.11015855865372, n_buffer: 810
step: 1010, score: -85.08165139218373, n_buffer: 1010
step: 1210, score: -81.97688326888544, n_buffer: 1210
step: 1410, score: -84.99376582891762, n_buffer: 1410
step: 1610, score: -82.12338215684726, n_buffer: 1610
step: 1810, score: -84.77030926615448, n_buffer: 1810
step: 2210, score: -66.3142685804228, n_buffer: 2210
step: 2410, score: -64.08107626880326, n_buffer: 2410
step: 2610, score: -61.766913445044466, n_buffer: 2610
step: 2810, score: -65.05580256236837, n_buffer: 2810
step: 3010, score: -68.07224346564546, n_buffer: 3010
step: 3210, score: -68.09883601009258, n_buffer: 3210
step: 3410, score: -68.2604344979504, n_buffer: 3410
step: 3610, score: -68.4666461291011, n_buffer: 3610
step: 3810, score: -66.12266214601304, n_buffer: 3810
step: 4210, score: -84.3243460696251, 

KeyboardInterrupt: 

In [None]:
# Release the resources held by the environment created above.
env.close()