In [1]:
import sys
sys.path.append('..')

import gym

import torch
import torch.nn as nn

from models.ddpg import DDPG

In [2]:
class Actor(nn.Module):
    """Policy network mapping a state vector to a bounded action vector.

    The network is three fully connected layers: ReLU follows the first two
    and ``tanh`` is applied to the last, so every output component lies in
    ``[-1, 1]``.

    Args:
        nstates: Number of input features (size of the state vector).
        nactions: Number of output features (size of the action vector).
        hidden1: Width of the first hidden layer. Defaults to 400.
        hidden2: Width of the second hidden layer. Defaults to 300.
    """

    def __init__(self, nstates: int, nactions: int,
                 hidden1: int = 400, hidden2: int = 300):
        super().__init__()

        # Layers are created in fc1 -> fc2 -> fc3 order so that parameter
        # initialisation consumes the RNG in the same order as before.
        self.fc1 = nn.Linear(nstates, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, nactions)

        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Return the action for state batch ``x``.

        Args:
            x: Tensor whose last dimension has size ``nstates``.

        Returns:
            Tensor with last dimension ``nactions`` and values in ``[-1, 1]``.
        """
        hidden = self.relu1(self.fc1(x))
        hidden = self.relu2(self.fc2(hidden))

        # tanh squashes the raw output into [-1, 1].
        return torch.tanh(self.fc3(hidden))

In [3]:
class Critic(nn.Module):
    """Critic network scoring a (state, action) pair with a single value.

    The state alone goes through the first layer; the action is concatenated
    to that hidden representation (along dim 1) before the second layer.
    The last layer is linear with one output and no activation.

    Args:
        nstates: Number of state features.
        nactions: Number of action features.
        hidden1: Width of the state-only first hidden layer. Defaults to 400.
        hidden2: Width of the second hidden layer. Defaults to 300.
    """

    def __init__(self, nstates: int, nactions: int,
                 hidden1: int = 400, hidden2: int = 300):
        super().__init__()

        # Layers are created in fc1 -> fc2 -> fc3 order so that parameter
        # initialisation consumes the RNG in the same order as before.
        self.fc1 = nn.Linear(nstates, hidden1)
        # The action is injected at the second layer, hence the wider input.
        self.fc2 = nn.Linear(nactions + hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, 1)

        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

    def forward(self, xs):
        """Return the critic value for a batch of state/action pairs.

        Args:
            xs: A 2-item iterable ``(state, action)``. Both tensors must be
                at least 2-D and agree on every dimension except dim 1,
                because they are concatenated along dim 1.

        Returns:
            Tensor whose last dimension has size 1.
        """
        state, action = xs

        hidden = self.relu1(self.fc1(state))
        joined = torch.cat([hidden, action], 1)
        hidden = self.relu2(self.fc2(joined))

        return self.fc3(hidden)

In [9]:
# --- Problem dimensions -----------------------------------------------------
# Input/output sizes handed to Actor and Critic.
# NOTE(review): presumably the HalfCheetah-v4 observation/action sizes —
# verify against env.observation_space.shape / env.action_space.shape.
nstates = 17
nactions = 6

# Bounds passed to DDPG as `act_range`.
act_range = (-1, 1)

# --- Optimizer step sizes ---------------------------------------------------
act_lr = 5e-4   # Adam learning rate for the actor
cri_lr = 1e-3   # Adam learning rate for the critic

# --- Target-network mixing ---------------------------------------------------
# NOTE(review): `tau` is defined here but no later cell visible in this
# notebook reads it — confirm whether DDPG should receive it.
tau = 5e-3

# --- Training schedule (passed positionally to ddpg.train) -------------------
n_epis = 500
n_epochs = 200
n_rollout = 10

In [26]:
# Create the HalfCheetah-v4 environment that the agent is trained on.
# NOTE(review): render_mode='human' requests a live viewer; if it draws on
# every step it will slow training considerably — confirm this is wanted
# for a full training run rather than only for evaluation.
env = gym.make("HalfCheetah-v4", render_mode='human')

In [27]:
# Seed torch's global RNG so the initial network weights are identical on
# every Restart & Run All. This only covers torch; the environment and any
# other RNG used during training remain unseeded.
torch.manual_seed(0)

# Actor network and its Adam optimizer (learning rate `act_lr`).
act = Actor(nstates, nactions)
act_opt = torch.optim.Adam(act.parameters(), lr=act_lr)

# Critic network and its Adam optimizer (learning rate `cri_lr`).
cri = Critic(nstates, nactions)
cri_opt = torch.optim.Adam(cri.parameters(), lr=cri_lr)

In [28]:
# Assemble the DDPG agent from the environment, the action dimension, and the
# actor/critic networks together with their optimizers.
# NOTE(review): act_noise=0.3 is presumably the exploration-noise scale and
# act_range the bounds applied to actions — confirm against models.ddpg.
# NOTE(review): `tau` is defined in the config cell but is not passed here.
ddpg = DDPG(env, nactions, act, act_opt, cri, cri_opt, act_noise=0.3, act_range=act_range)

In [29]:
# Run training; the three schedule constants are passed positionally.
# Progress is reported as "epoch / score / n_buffer" lines in the cell output.
ddpg.train(n_epis, n_epochs, n_rollout)

epoch: 20, score: -5.811917284592793, n_buffer: 210
epoch: 40, score: -4.6575051841774595, n_buffer: 410
epoch: 60, score: -4.950014137267848, n_buffer: 610
epoch: 80, score: -4.051159667608857, n_buffer: 810
epoch: 100, score: -3.084050745979697, n_buffer: 1010
epoch: 120, score: -0.9182716500171747, n_buffer: 1210
epoch: 140, score: -2.490610524644855, n_buffer: 1410
epoch: 160, score: -5.795042367251741, n_buffer: 1610
epoch: 180, score: -8.607199648123208, n_buffer: 1810
epoch: 20, score: -7.424091521709229, n_buffer: 2210
epoch: 40, score: -5.657285555716844, n_buffer: 2410
epoch: 60, score: -7.299208831808151, n_buffer: 2610
epoch: 80, score: -6.297047857009563, n_buffer: 2810
epoch: 100, score: -5.925792095024519, n_buffer: 3010
epoch: 120, score: -4.027448473349675, n_buffer: 3210
epoch: 140, score: -1.2926775526738756, n_buffer: 3410
epoch: 160, score: -4.728659390139326, n_buffer: 3610
epoch: 180, score: -5.412483985846993, n_buffer: 3810
epoch: 20, score: -3.3530186243897724

KeyboardInterrupt: 

In [30]:
# Release the environment's resources once training has finished or been interrupted.
env.close()