In [1]:
import sys
sys.path.append('..')
import gym

import torch
import torch.nn as nn


from models.td3 import TD3

In [2]:
class Actor(nn.Module):
    """Actor network: maps a state tensor to an action tensor.

    Architecture: Linear(nstates, 400) -> ReLU -> Linear(400, 300) -> ReLU
    -> Linear(300, nactions) -> tanh, so every output component lies in
    [-1, 1].

    Args:
        nstates: size of the last dimension of the input state tensor.
        nactions: size of the last dimension of the produced action tensor.
    """

    def __init__(self, nstates: int, nactions: int):
        super().__init__()

        # Fully connected stack, declared in the order it is applied.
        self.fc1 = nn.Linear(nstates, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, nactions)

        # Parameter-free activations for the two hidden layers.
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Return tanh-squashed actions for the state tensor ``x``."""
        hidden = self.fc1(x)
        hidden = self.relu1(hidden)
        hidden = self.fc2(hidden)
        hidden = self.relu2(hidden)

        # tanh bounds each action component to [-1, 1].
        return torch.tanh(self.fc3(hidden))

In [3]:
class Critic(nn.Module):  # differs from Actor: takes a (state, action) pair as input
    """Critic network: maps a (state, action) pair to a single scalar per row.

    Architecture: Linear(nstates + nactions, 400) -> ReLU -> Linear(400, 300)
    -> ReLU -> Linear(300, 1). The output is not squashed.

    Args:
        nstates: size of dimension 1 of the state tensor.
        nactions: size of dimension 1 of the action tensor.
    """

    def __init__(self, nstates: int, nactions: int):
        super().__init__()

        # The first layer consumes the state and action concatenated together.
        self.fc1 = nn.Linear(nstates + nactions, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, 1)

        # Parameter-free activations for the two hidden layers.
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

    def forward(self, xs):
        """Return one value per row for ``xs``, a ``(state, action)`` pair of tensors."""
        state, action = xs

        # Join along dimension 1 so each row holds [state | action].
        joint = torch.cat((state, action), dim=1)

        hidden = self.relu1(self.fc1(joint))
        hidden = self.relu2(self.fc2(hidden))
        return self.fc3(hidden)

In [4]:
# --- Optimisation / algorithm settings ---
lr = 0.001    # learning rate given to every Adam optimiser in this notebook
tau = 0.005   # forwarded to TD3 as `tau`
eps = 0       # forwarded to TD3 as `eps`
gamma = 0.99  # NOTE(review): not passed to TD3 anywhere in this notebook — confirm whether it should be

# --- Network input / output widths ---
n_sts, n_acts = 17, 6

# --- Training schedule (forwarded positionally to model.train) ---
n_epis, n_epoch, n_rollout = 500, 200, 10

In [14]:
# Build the networks first, in the same order as before (actor, critic 1,
# critic 2) so that random weight initialisation is drawn in the same sequence.
act = Actor(n_sts, n_acts)
cri1, cri2 = Critic(n_sts, n_acts), Critic(n_sts, n_acts)

# One independent Adam optimiser per network, all sharing the learning rate `lr`.
act_opt, cri_opt1, cri_opt2 = (
    torch.optim.Adam(net.parameters(), lr=lr) for net in (act, cri1, cri2)
)

In [15]:
# Create the HalfCheetah-v4 environment that TD3 interacts with.
# NOTE(review): render_mode='human' presumably opens a live viewer and slows
# training noticeably — confirm it is wanted for a full training run.
# NOTE(review): n_sts / n_acts are hard-coded elsewhere in the notebook;
# presumably they match this env's observation/action sizes — verify.
env = gym.make('HalfCheetah-v4', render_mode='human')

In [16]:
# Assemble the TD3 agent from the environment, the actor, the two critics and
# their optimisers. Positional order is significant and must not be changed.
model = TD3(
    env,
    n_acts,
    act,
    act_opt,
    cri1,
    cri_opt1,
    cri2,
    cri_opt2,
    eps=eps,
    tau=tau,
    act_noise=0.2,
    # Same interval as the Actor's tanh output.
    act_range=(-1, 1),
)

In [17]:
# Start training; n_epis, n_epoch and n_rollout are forwarded positionally.
# NOTE(review): in the saved output the "epoch" counter restarts after 180 and
# n_buffer grows by 10 per epoch — presumably one block of n_epoch epochs per
# episode with n_rollout transitions each; confirm against TD3.train.
# NOTE(review): the saved run ended with KeyboardInterrupt, so the log below is partial.
model.train(n_epis, n_epoch, n_rollout)

epoch: 20, score: -5.007670364814524, n_buffer: 210
epoch: 40, score: -5.586418766775719, n_buffer: 410
epoch: 60, score: -6.479032688716799, n_buffer: 610
epoch: 80, score: -8.36169644366104, n_buffer: 810
epoch: 100, score: -8.566147826504466, n_buffer: 1010
epoch: 120, score: -5.156440819363919, n_buffer: 1210
epoch: 140, score: -5.588380862987522, n_buffer: 1410
epoch: 160, score: -5.605906633150953, n_buffer: 1610
epoch: 180, score: -4.814340416737348, n_buffer: 1810
epoch: 20, score: -5.492232146112409, n_buffer: 2210
epoch: 40, score: -4.348670673446226, n_buffer: 2410
epoch: 60, score: -5.25145024012582, n_buffer: 2610
epoch: 80, score: -3.9310038482584218, n_buffer: 2810
epoch: 100, score: -2.569588238258038, n_buffer: 3010
epoch: 120, score: -2.8658359769682944, n_buffer: 3210
epoch: 140, score: -0.9765327310988974, n_buffer: 3410
epoch: 160, score: 0.622546488031255, n_buffer: 3610
epoch: 180, score: -0.2597742257128395, n_buffer: 3810
epoch: 20, score: -0.44514476282325344,

KeyboardInterrupt: 

In [13]:
# Close the environment once training is finished or interrupted.
# NOTE(review): this cell's saved execution count (13) is lower than that of the
# cell that creates `env` (15); re-run the notebook top to bottom on a fresh
# kernel to confirm it does not rely on leftover state.
env.close()