In [1]:
import gym

import torch
import torch.nn as nn

from models.ddpg import DDPG

In [2]:
class Actor(nn.Module):
    def __init__(self, n_sts: int, n_acts: int, dim:int):
        super(Actor, self).__init__()
                
        self.fc1 = nn.Linear(n_sts, dim)
        self.fc2 = nn.Linear(dim, n_acts)
        
        self.relu1 = nn.ReLU()

        # self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.5
        self.fc1.weight.data.uniform_(-initrange, initrange)
        self.fc2.weight.data.uniform_(-initrange, initrange)

    
    def forward(self, x):
        x = self.relu1(self.fc1(x))
        x = torch.tanh(self.fc2(x)) * 2
        
        return x

In [3]:
class Critic(nn.Module):
    def __init__(self, n_sts: int, n_acts: int, dim: int):
        super(Critic, self).__init__()
                
        self.fc1 = nn.Linear(n_sts + n_acts, dim)
        self.fc2 = nn.Linear(dim, 1)
        
        self.relu1 = nn.ReLU()
                
    def forward(self, xs):
        x, a = xs
        
        x = self.relu1(self.fc1(torch.cat([x, a], 1)))
        x = self.fc2(x)
        
        return x

In [9]:
cri_lr = 1e-3
act_lr = 5e-4

tau = 5e-3

n_sts = 3
n_acts = 1

dim = 256

n_epis = 500
n_epochs = 200
n_rollout = 200

act_range = (-2, 2)

In [10]:
env = gym.make('Pendulum-v1', render_mode='human')

In [11]:
act = Actor(n_sts, n_acts, dim)
act_opt = torch.optim.Adam(act.parameters(), lr=act_lr)

cri = Critic(n_sts, n_acts, dim)
cri_opt = torch.optim.Adam(cri.parameters(), lr=cri_lr)

In [12]:
ddpg = DDPG(env, n_acts, act, act_opt, cri, cri_opt, act_noise = 0.2, act_range=act_range)

In [13]:
ddpg.train(n_epis, n_epochs, n_rollout)

epoch: 20, score: -1751.666018315563, n_buffer: 4200
epoch: 40, score: -1676.972796963652, n_buffer: 8200
epoch: 60, score: -1677.2373714341877, n_buffer: 9999
epoch: 80, score: -1677.392881584571, n_buffer: 9999
epoch: 100, score: -1676.5340921865056, n_buffer: 9999
epoch: 120, score: -1677.8810232183864, n_buffer: 9999
epoch: 140, score: -1677.1170577453952, n_buffer: 9999
epoch: 160, score: -1676.8608447450922, n_buffer: 9999
epoch: 180, score: -1676.6978887007285, n_buffer: 9999
epoch: 20, score: -1757.6466640217066, n_buffer: 9999
epoch: 40, score: -1676.7013110346554, n_buffer: 9999
epoch: 60, score: -1676.4706097840633, n_buffer: 9999
epoch: 80, score: -1677.5023613191624, n_buffer: 9999
epoch: 100, score: -1676.6896375974381, n_buffer: 9999
epoch: 120, score: -1677.0211650360038, n_buffer: 9999
epoch: 140, score: -1675.765044586531, n_buffer: 9999
epoch: 160, score: -1677.1629501273528, n_buffer: 9999
epoch: 180, score: -1676.9731790873925, n_buffer: 9999


KeyboardInterrupt: 

In [14]:
env.close()