In [1]:
import sys
sys.path.append('..')

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal

from models.sac import SAC


In [2]:
class Actor(nn.Module):
    """Stochastic policy network that emits tanh-squashed Gaussian actions.

    One hidden layer feeds two linear heads: ``fc_mu`` for the mean and
    ``fc_std`` for the (softplus-transformed) standard deviation of a
    diagonal Normal distribution.
    """

    # Lower bound applied to the standard deviation. softplus underflows to
    # exactly 0.0 for very negative inputs, and Normal rejects a zero scale
    # when argument validation is enabled, so a tiny floor keeps sampling safe.
    MIN_STD = 1e-6

    def __init__(self, n_sts: int, n_acts: int, dim: int):
        """Create the layers.

        Args:
            n_sts: size of the last dimension of the input tensor.
            n_acts: number of action components produced by each head.
            dim: width of the hidden layer.
        """
        super(Actor, self).__init__()

        self.fc1 = nn.Linear(n_sts, dim)
        self.fc_mu = nn.Linear(dim, n_acts)
        self.fc_std = nn.Linear(dim, n_acts)

        self.relu1 = nn.ReLU()
        # Not used by forward(); kept so the module's attribute set is unchanged.
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Sample a squashed action and return it with its log-probability.

        Args:
            x: tensor whose last dimension is ``n_sts``.

        Returns:
            A ``(real_act, real_log_prob)`` tuple. ``real_act`` is the
            tanh of a reparameterised Normal sample (last dimension
            ``n_acts``); ``real_log_prob`` is the log-probability of that
            squashed sample, summed over the last dimension with that
            dimension kept as size 1.
        """
        x = self.relu1(self.fc1(x))
        mu = self.fc_mu(x)
        # Floor the std so a saturated softplus can never hand Normal a zero scale.
        std = torch.clamp(F.softplus(self.fc_std(x)), min=self.MIN_STD)
        dist = Normal(mu, std)
        # rsample() keeps the sample differentiable w.r.t. mu and std.
        act = dist.rsample()
        log_prob = dist.log_prob(act)
        real_act = torch.tanh(act)
        # Change-of-variables correction for the tanh squashing; the 1e-7
        # keeps the log finite when tanh saturates at +/-1.
        real_log_prob = log_prob - torch.log(1 - real_act.pow(2) + 1e-7)

        return real_act, real_log_prob.sum(-1, keepdim=True)

In [3]:
class QCritic(nn.Module):
    """Q-value network: one hidden layer followed by a single-output head.

    The input is a pair of tensors that are concatenated along dimension 1
    before being fed through the network.
    """

    def __init__(self, n_sts: int, act_dim: int, dim: int):
        """Create the layers.

        Args:
            n_sts: feature size of the first tensor of the input pair.
            act_dim: feature size of the second tensor of the input pair.
            dim: width of the hidden layer.
        """
        super().__init__()

        # Parameterised layers: joint input -> hidden -> one output value.
        self.fc1 = nn.Linear(n_sts + act_dim, dim)
        self.fc2 = nn.Linear(dim, 1)

        # Activations; only relu1 is applied in forward().
        self.relu1, self.relu2 = nn.ReLU(), nn.ReLU()

    def forward(self, xs):
        """Evaluate the network on a pair of tensors.

        Args:
            xs: two-element sequence; both tensors must be concatenable
                along dimension 1.

        Returns:
            Output of ``fc2`` with a trailing dimension of size 1.
        """
        state, action = xs
        joint = torch.cat((state, action), dim=1)
        hidden = self.relu1(self.fc1(joint))
        return self.fc2(hidden)

In [4]:
# --- Optimisation settings -------------------------------------------------
lr, tau = 3e-4, 5e-3   # lr: Adam learning rate; tau: forwarded to SAC as `tau`
eps = 0                # forwarded to SAC as `eps`
# NOTE(review): gamma is defined here but never referenced in the visible
# cells -- confirm whether SAC is supposed to receive it.
gamma = 0.99

# --- Problem dimensions ----------------------------------------------------
# Hard-coded sizes; presumably they match HalfCheetah-v4's observation and
# action spaces -- TODO confirm against the environment.
n_sts = 17
n_acts = 6
act_dim = n_acts       # critic-side action size, same value as n_acts

# --- Network width ---------------------------------------------------------
dim = 256

# --- Training schedule (the three arguments given to model.train) ----------
n_epis, n_epochs, n_rollout = 5000, 100, 10

In [5]:
# Build the three networks first, in the same order as before (actor, then
# both critics) so parameter initialisation consumes the RNG identically.
act = Actor(n_sts, n_acts, dim)
qcri1 = QCritic(n_sts, act_dim, dim)
qcri2 = QCritic(n_sts, act_dim, dim)

# One Adam optimizer per network, all sharing the same learning rate.
act_opt, qcri1_opt, qcri2_opt = (
    torch.optim.Adam(net.parameters(), lr=lr) for net in (act, qcri1, qcri2)
)

In [6]:
# Environment for training.
# NOTE(review): render_mode='human' presumably opens a viewer window on every
# step, which may slow long training runs -- confirm this is intended.
env = gym.make(
    "HalfCheetah-v4",
    render_mode='human',
)

# Wire the actor, the two critics and their optimizers into the SAC trainer.
# NOTE(review): `gamma` from the hyperparameter cell is not passed here.
model = SAC(
    env,
    n_acts,
    act,
    act_opt,
    qcri1,
    qcri1_opt,
    qcri2,
    qcri2_opt,
    eps=eps,
    tau=tau,
    target_entropy=-n_acts,
)

In [7]:
# Start SAC training with the schedule values from the hyperparameter cell.
# NOTE(review): how SAC.train interprets these three positional arguments is
# defined in models.sac, not in this notebook -- confirm there.
model.train(n_epis, n_epochs, n_rollout)

  if not isinstance(terminated, (bool, np.bool8)):


step: 210, score: -2.918459021324853, n_buffer: 210
step: 410, score: -6.178574724810211, n_buffer: 410
step: 610, score: -3.7947563660148496, n_buffer: 610
step: 810, score: -8.879499541639623, n_buffer: 810
step: 1210, score: -5.795249003990034, n_buffer: 1210
step: 1410, score: -2.0259409817711225, n_buffer: 1410
step: 1610, score: 0.5939183816121175, n_buffer: 1610
step: 1810, score: -3.7368113403220145, n_buffer: 1810
step: 2210, score: -0.5340876998186311, n_buffer: 2210
step: 2410, score: 1.4865356335534297, n_buffer: 2410
step: 2610, score: -2.976091721169463, n_buffer: 2610
step: 2810, score: -0.4708618369482144, n_buffer: 2810
step: 3210, score: -2.0733693433245977, n_buffer: 3210
step: 3410, score: 0.18890728755680114, n_buffer: 3410
step: 3610, score: 2.315229242327237, n_buffer: 3610
step: 3810, score: -2.186691812151995, n_buffer: 3810
step: 4210, score: -0.47346843006915484, n_buffer: 4210
step: 4410, score: 0.10217939746152194, n_buffer: 4410
step: 4610, score: -0.88346

KeyboardInterrupt: 

In [None]:
# Release the environment created earlier in the notebook once training has
# finished or been interrupted.
env.close()