## Soft Actor-Critic
Implementation using Pytorch

In [21]:
import jax.numpy as jnp
import jax.random as jrandom
import jax
from jax.lax import stop_gradient

import torch
import torch.nn as nn

from src.systems.linear import StochasticDoubleIntegrator

import numpy as np

import matplotlib.pyplot as plt

### Neural Net

In [73]:
class NeuralNet:
    def __init__(self, dim, eta=1e-2):
        (n_input, n_hidden, n_out) = dim
        self.model = nn.Sequential(nn.Linear(n_input, n_hidden),
                                    nn.ReLU(),
                                    nn.Linear(n_hidden, n_out))
        self.eta = eta
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=eta)

input = torch.tensor([[0.1, 1.0, 0.5]])
dim = (3, 32, 1)
Net = NeuralNet(dim)
x_train = torch.randn(1, 3)
Net.model(input)

tensor([[-0.1128]], grad_fn=<AddmmBackward0>)

### Q-function

In [100]:
class SoftQFunction(NeuralNet):
    def __init__(self, dimensions, eta=1e-2):
        super().__init__(dimensions, eta=eta)
        self.gamma = .9
        self.sample_size = 1
        self.n_epochs = 10
    
    def loss(self, D, value_func):
        bellman_residual = torch.tensor([[0]])
        N = len(D)
        for it in range(min(N, self.sample_size)):
            if it == 0:
                s0, u, rew, s1 = D[-1]
            else:
                idx = np.random.randint(0, N)
                s0, u, rew, s1 = D[idx]
            Q = self.get_output(s0, u)
            Q_hat = rew + self.gamma * value_func(s1)
            bellman_residual = bellman_residual + (Q - Q_hat)**2 / 2
        return bellman_residual / min(N, self.sample_size)
    
    def update(self, D, value_func):
        losses = []
        for _ in range(self.n_epochs):
            loss = self.loss(D, value_func)
            losses.append(loss.item())

            self.model.zero_grad()
            loss.backward()

            self.optimizer.step()
        return losses
    
    def get_output(self, state, control):
        input = torch.cat((state, control), axis=1).to(torch.float32)
        y_hat = self.model(input)
        return y_hat
    
    def get_value(self, state, control):
        return self.get_output(state, control).detach().numpy()



### Value function

In [142]:
class SoftValueFunction(NeuralNet):
    def __init__(self, dimensions, eta=1e-2):
        super().__init__(dimensions, eta=eta)
        self.sample_size = 1
        self.n_epochs = 1
    
    def loss(self, D, q_func, pi_log_func):
        squared_residual_error = 0
        N = len(D)
        for it in range(min(N, self.sample_size)):
            if it == 0:
                s0, u, _, _ = D[-1]
            else:
                idx = np.random.randint(0, N)
                s0, u, _, _ = D[idx]
            V = self.model(s0)
            # Sample u from policy pi
            Q = q_func(s0, u)
            log_pi = pi_log_func(s0, u)
            squared_residual_error += (V - (Q - log_pi))**2 / 2
        return squared_residual_error / min(N, self.sample_size)
    
    def update(self, D, q_func, pi_log_func):
        losses = []
        for _ in range(self.n_epochs):
            loss = self.loss(D, q_func, pi_log_func)
            losses.append(loss.item())

            self.model.zero_grad()
            loss.backward()

            self.optimizer.step()
        return losses
    
    def get_value(self, input):
        return self.model(input).detach().numpy()





### Policy function

In [157]:
class SoftPolicyFunction:
    def __init__(self, dimensions, eta=1e-2):
        (n_input, n_out) = dimensions
        self.model = nn.Sequential(nn.Linear(n_input, n_out))
        self.eta = eta
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=eta)
        self.stdev = .5
    
    def loss(self, D, q_func):
        KL_divergence = 0
        N = len(D)
        for it in range(min(N, self.sample_size)):
            if it == 0:
                s0, u, _, _ = D[-1]
            else:
                idx = np.random.randint(0, N)
                s0, u, _, _ = D[idx]
            # u should be sampled
            KL_divergence += self.log_prob(s0, u) - q_func(s0, u)
        return KL_divergence
    
    def update(self, D, q_func):
        losses = []
        for _ in range(self.n_epochs):
            loss = self.loss(D, q_func)
            losses.append(loss.item())

            self.model.zero_grad()
            loss.backward()

            self.optimizer.step()
        return losses
    
    def get_control(self, state):
        u_star = self.model(state)
        xi = np.random.normal()
        u = u_star + xi * self.stdev
        return u, u_star
    
    def grad_phi(self, state, control, q_func):
        params = self.params
        grad_phi_log_pi = jax.grad(self.log_pi)(params, state, control)
        grad_u = self.grad_u_log_pi(state, control)
        #state_tensor = torch.tensor([state]).to(torch.float32)
        #control_tensor = torch.tensor([control], requires_grad=True).to(torch.float32)
        q_value = q_func(state, control)
        q_value.backward()
        grad_Q = control.grad
        #grad_Q = jax.grad(q_func, argnums=1)(state, control)
        return grad_phi_log_pi + (grad_u - grad_Q)*state
        
    def grad_u_log_pi(self, state, control):
        mu = self.model(state)
        grad = -(control - mu) / (self.stdev**2)
        return grad

    def log_pi(self, params, state, control):
        mu = jnp.dot(stop_gradient(params), state)
        prob = -.5 * ((stop_gradient(control) - mu) / self.stdev)**2 - jnp.log(self.stdev) + jnp.log(2*jnp.pi)/2
        return prob[0]

    def log_prob(self, state, control):
        mu = self.predict(state)
        return -.5 * ((control - mu) / self.stdev)**2 - jnp.log(self.stdev) + jnp.log(2*jnp.pi)/2


key = jrandom.PRNGKey(0)
T = 100
x0 = jnp.array([2, 0])
SDI = StochasticDoubleIntegrator(x0)

dim_q = (3, 32, 1)
dim_v = (2, 32, 1)
dim_pi = (2, 1)
SQF = SoftQFunction(dim_q)
SVF = SoftValueFunction(dim_v)
PI = SoftPolicyFunction(dim_pi)

time_horizon = np.arange(0, T, SDI.dt)
D = []

for t in time_horizon:
    s0_estimate = SDI.observe(key)
    tensor_s0 = torch.tensor([s0_estimate]).to(torch.float32)
    u, _ = PI.get_control(tensor_s0)
    x, cost, done = SDI.update(key, u.detach().numpy()[0], info=True)
    s1_estimate = SDI.observe(key)
    tensor_s1 = torch.tensor([s1_estimate]).to(torch.float32)
    D.append((tensor_s0, u, float(-cost), tensor_s1))

    SQF.update(D, SVF.model)
    SVF.update(D, SQF.get_output, PI.log_prob)
    PI.update(s0_estimate, u, SQF.get_output)

    # step
    key, subkey = jrandom.split(key)

    if done:
        x0 = jrandom.normal(key, (2,))*2
        SDI.reset(x0)


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

### Main

In [26]:
class SoftActorCritic:
    def __init__(self, key, dim_q, dim_v, dim_pi):
        self.SQF = SoftQFunction(dim_q)
        self.SVF = SoftValueFunction(dim_v)
        self.PI = SoftPolicyFunction(key, dim_pi)
        self.buffer = list()
        #self.tracker = Tracker(['state0', 'state1', 'control', 'cost', 'V_value', 'V_loss', 
        #                            'Q_value', 'Q_loss', 'policy_angle', 'policy_force'])
    
    def update(self, s0, u, tracking=True):
        v_value = self.SVF.get_value(s0)
        q_value = self.SQF.get_value(s0, u)
        v_loss = self.SVF.update(self.buffer, self.SQF.get_output, self.PI.log_prob)
        q_loss = self.SQF.update(self.buffer, self.SVF)

        self.PI.update(s0, u, self.SQF.predict)

        #if tracking:
        #    control_angle = jnp.arctan2(self.PI.params[0,0], self.PI.params[0,1])
        #    control_force = jnp.linalg.norm(self.PI.params)
        #    self.tracker.add([s0[0], s0[1], u, None, v_value, v_loss, q_value, q_loss,
        #                            control_angle, control_force])
    
    def get_control(self, state):
        return self.PI.get_control(state)
    
    def add_to_buffer(self, transition):
        self.buffer.append(transition)

In [30]:
key = jrandom.PRNGKey(0)

dim_q = (3, 32, 1)
dim_v = (2, 32, 1)
dim_pi = (2, 1)

SAC = SoftActorCritic(key, dim_q, dim_v, dim_pi)

key = jrandom.PRNGKey(1)
key, subkey = jrandom.split(key)

T = 500
n_obs = 2
n_ctrl = 1


# Initiate system
x0 = jnp.array([2, 0])
SDI = StochasticDoubleIntegrator(x0, boundary=5)

time_horizon = np.arange(0, T, SDI.dt)

for _ in time_horizon:
    s0_estimate = SDI.observe(key)
    u, _ = SAC.get_control(s0_estimate)
    _, cost, done = SDI.update(key, u, info=True)
    s1_estimate = SDI.observe(subkey)
    SAC.add_to_buffer((s0_estimate, u, -cost, s1_estimate))

    SAC.update(s0_estimate, u)

    # step
    key, subkey = jrandom.split(key)

    if done:
        x0 = jrandom.normal(key, (2,))*2
        SDI.reset(x0)




TypeError: linear(): argument 'input' (position 1) must be Tensor, not numpy.ndarray