In [2]:
import gym
import torch
import torch.nn as nn


from models.td3 import TD3

In [3]:
class Actor(nn.Module):
    """Deterministic policy network mapping a state to a bounded action.

    Two hidden layers (400 and 300 units) with ReLU activations feed a linear
    output layer. The output is squashed with ``tanh`` and multiplied by
    ``max_action``, so every action component lies in
    ``[-max_action, max_action]``.

    Args:
        nstates: Number of input features (state dimension).
        nactions: Number of output features (action dimension).
        max_action: Scale applied to the tanh output. Defaults to ``2.0``,
            the value that used to be hard-coded in ``forward``.
    """

    def __init__(self, nstates: int, nactions: int, max_action: float = 2.0):
        super().__init__()

        self.fc1 = nn.Linear(nstates, 400)
        self.fc2 = nn.Linear(400, 300)
        self.fc3 = nn.Linear(300, nactions)

        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

        # Stored as a plain Python attribute (not a parameter or buffer), so
        # the module's state_dict keys stay exactly as they were.
        self.max_action = max_action

    def forward(self, x):
        """Return the scaled, tanh-bounded action for the state tensor ``x``.

        ``x`` is fed to ``fc1``, so its last dimension must equal ``nstates``.
        """
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        mu = torch.tanh(self.fc3(x)) * self.max_action

        return mu

In [4]:
class Critic(nn.Module):  # original note: "different" (presumably: differs from Actor)
    """Q-value network that scores a (state, action) pair.

    The state and action tensors are concatenated along dimension 1 and
    passed through two ReLU hidden layers (400 and 300 units) followed by a
    linear layer with a single output per row.

    Args:
        nstates: Number of state features.
        nactions: Number of action features.
    """

    def __init__(self, nstates: int, nactions: int):
        super().__init__()

        first_width, second_width = 400, 300

        # Layers are created in the order fc1 -> fc2 -> fc3.
        self.fc1 = nn.Linear(nstates + nactions, first_width)
        self.fc2 = nn.Linear(first_width, second_width)
        self.fc3 = nn.Linear(second_width, 1)

        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

    def forward(self, xs):
        """Return the value estimate for ``xs``, a ``(state, action)`` pair.

        Both tensors are joined along dimension 1, so they must agree on
        every other dimension.
        """
        state, action = xs

        joined = torch.cat([state, action], dim=1)
        hidden = self.relu1(self.fc1(joined))
        hidden = self.relu2(self.fc2(hidden))

        return self.fc3(hidden)

In [7]:
# --- Optimisation settings -------------------------------------------------
lr = 0.001     # learning rate given to every Adam optimiser below
tau = 0.005    # forwarded to TD3 as `tau`
eps = 0        # forwarded to TD3 as `eps`
gamma = 0.99   # NOTE(review): not passed to TD3 in the visible cells - confirm it is used

# --- Problem dimensions (state features, action features) ------------------
n_sts, n_acts = 3, 1

# --- Training length: both values are handed to model.train() --------------
n_epis, n_rollout = 500, 1000

In [8]:
# Policy network first, then the two critics (cri1 is constructed before cri2,
# so the order of random weight initialisation is unchanged).
act = Actor(n_sts, n_acts)
cri1, cri2 = (Critic(n_sts, n_acts) for _ in range(2))

# One Adam optimiser per network, all sharing the learning rate `lr`.
act_opt, cri_opt1, cri_opt2 = (
    torch.optim.Adam(net.parameters(), lr=lr) for net in (act, cri1, cri2)
)

In [9]:
# Create the Pendulum-v1 environment with on-screen ("human") rendering.
env = gym.make(
    "Pendulum-v1",
    render_mode="human",
)

In [10]:
# Assemble the TD3 agent from the environment, the actor, the two critics and
# their optimisers.
# NOTE(review): `gamma` from the hyperparameter cell is not passed here - confirm
# whether TD3 expects it or relies on its own default.
model = TD3(
    env,
    n_acts,
    act,
    act_opt,
    cri1,
    cri_opt1,
    cri2,
    cri_opt2,
    eps=eps,
    tau=tau,
    act_range=(-2, 2),  # same bounds as the actor's default output scale of 2
)

In [11]:
# Start training with the episode count and rollout length set above.
# The recorded output below prints n_epi / score / n_buffer progress lines and
# ends with a KeyboardInterrupt, i.e. this run was stopped manually.
model.train(n_epis, n_rollout)

  if not isinstance(terminated, (bool, np.bool8)):


n_epi: 0, score: -7.761124816948159, n_buffer: 201
n_epi: 0, score: -8.461473378535008, n_buffer: 401
n_epi: 0, score: -8.32503781634143, n_buffer: 601
n_epi: 0, score: -8.556512833149135, n_buffer: 801
n_epi: 1, score: -15.913264549516644, n_buffer: 1201
n_epi: 1, score: -8.548851324373613, n_buffer: 1401
n_epi: 1, score: -8.302280179020217, n_buffer: 1601
n_epi: 1, score: -8.52423285440799, n_buffer: 1801
n_epi: 2, score: -16.390971884518763, n_buffer: 2201
n_epi: 2, score: -8.47832457440721, n_buffer: 2401
n_epi: 2, score: -8.47221254166231, n_buffer: 2601
n_epi: 2, score: -8.319217329764422, n_buffer: 2801
n_epi: 3, score: -16.55248823542514, n_buffer: 3201
n_epi: 3, score: -8.562268510788252, n_buffer: 3401
n_epi: 3, score: -8.326322973083373, n_buffer: 3601
n_epi: 3, score: -8.469795616637436, n_buffer: 3801
n_epi: 4, score: -15.536870443465995, n_buffer: 4201
n_epi: 4, score: -8.41966086592082, n_buffer: 4401
n_epi: 4, score: -8.49732491281742, n_buffer: 4601
n_epi: 4, score: -8

KeyboardInterrupt: 

In [55]:
# Close the environment once training is finished or has been interrupted.
env.close()