Skip to content

Commit

Permalink
Merge pull request #11 from MinRegret/bpc
Browse files Browse the repository at this point in the history
Add BPC
  • Loading branch information
eladhazan committed Jun 16, 2023
2 parents f3c3884 + 78e6d95 commit 74ba003
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 3 deletions.
3 changes: 2 additions & 1 deletion deluca/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deluca.agents._bpc import BPC
from deluca.agents._gpc import GPC
from deluca.agents._hinf import Hinf
from deluca.agents._ilqr import ILQR
Expand All @@ -21,4 +22,4 @@
from deluca.agents._adaptive import Adaptive
from deluca.agents._deep import Deep

__all__ = ["LQR", "PID", "GPC", "ILQR", "Hinf", "Zero", "DRC", "Adaptive", "Deep"]
__all__ = ["LQR", "PID", "BPC", "GPC", "ILQR", "Hinf", "Zero", "DRC", "Adaptive", "Deep"]
161 changes: 161 additions & 0 deletions deluca/agents/_bpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""deluca.agents._bpc"""
from numbers import Real
from typing import Callable

import jax
import jax.numpy as jnp
import numpy as np
import numpy.random as random
from jax import grad
from jax import jit

from deluca.agents._lqr import LQR
from deluca.agents.core import Agent

def generate_uniform(shape, norm=1.00):
v = random.normal(size=shape)
v = norm * v / np.linalg.norm(v)
v = np.array(v)
return v

class BPC(Agent):
def __init__(
self,
A: jnp.ndarray,
B: jnp.ndarray,
Q: jnp.ndarray = None,
R: jnp.ndarray = None,
K: jnp.ndarray = None,
start_time: int = 0,
H: int = 5,
lr_scale: Real = 0.005,
decay: bool = False,
delta: Real = 0.01
) -> None:
"""
Description: Initialize the dynamics of the model.
Args:
A (jnp.ndarray): system dynamics
B (jnp.ndarray): system dynamics
Q (jnp.ndarray): cost matrices (i.e. cost = x^TQx + u^TRu)
R (jnp.ndarray): cost matrices (i.e. cost = x^TQx + u^TRu)
K (jnp.ndarray): Starting policy (optional). Defaults to LQR gain.
start_time (int):
H (postive int): history of the controller
lr_scale (Real):
decay (boolean):
"""

self.d_state, self.d_action = B.shape # State & Action Dimensions

self.A, self.B = A, B # System Dynamics

self.t = 0 # Time Counter (for decaying learning rate)

self.H = H

self.lr_scale, self.decay = lr_scale, decay

self.delta = delta

# Model Parameters
# initial linear policy / perturbation contributions / bias
# TODO: need to address problem of LQR with jax.lax.scan
self.K = K if K is not None else LQR(self.A, self.B, Q, R).K

self.M = self.delta * generate_uniform((H, self.d_action, self.d_state))

# Past H noises ordered increasing in time
self.noise_history = jnp.zeros((H, self.d_state, 1))

# past state and past action
self.state, self.action = jnp.zeros((self.d_state, 1)), jnp.zeros((self.d_action, 1))

self.eps = generate_uniform((H, H, self.d_action, self.d_state))
self.eps_bias = generate_uniform((H, self.d_action, 1))

def grad(M, noise_history, cost):
return cost * jnp.sum(self.eps, axis = 0)

self.grad = grad

def __call__(self,
state: jnp.ndarray,
cost: Real
) -> jnp.ndarray:
"""
Description: Return the action based on current state and internal parameters.
Args:
state (jnp.ndarray): current state
Returns:
jnp.ndarray: action to take
"""

action = self.get_action(state)
self.update(state, action, cost)
return action

def update(self,
state: jnp.ndarray,
action:jnp.ndarray,
cost: Real
) -> None:
"""
Description: update agent internal state.
Args:
state (jnp.ndarray): current state
action (jnp.ndarray): action taken
cost (Real): scalar cost received
Returns:
None
"""
noise = state - self.A @ self.state - self.B @ action
self.noise_history = jax.ops.index_update(self.noise_history, 0, noise)
self.noise_history = jnp.roll(self.noise_history, -1, axis=0)

lr = self.lr_scale
lr *= (1/ (self.t**(3/4)+1)) if self.decay else 1

delta_M = self.grad(self.M, self.noise_history, cost)
self.M -= lr * delta_M

self.eps = jax.ops.index_update(self.eps, 0, \
generate_uniform((self.H, self.d_action, self.d_state)))
self.eps = np.roll(self.eps, -1, axis = 0)

self.M += self.delta * self.eps[-1]

# update state
self.state = state

self.t += 1

def get_action(self, state: jnp.ndarray) -> jnp.ndarray:
"""
Description: get action from state.
Args:
state (jnp.ndarray):
Returns:
jnp.ndarray
"""
return -self.K @ state + jnp.tensordot(self.M, self.noise_history, axes=([0, 2], [0, 1]))
4 changes: 2 additions & 2 deletions deluca/agents/_gpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
# Model Parameters
# initial linear policy / perturbation contributions / bias
# TODO: need to address problem of LQR with jax.lax.scan
self.K = K if K is not None else LQR(self.A, self.B, Q, R).K
self.K = K if K is not None else LQR(self.A, self.B, Q, R).K

self.M = jnp.zeros((H, d_action, d_state))

Expand Down Expand Up @@ -156,7 +156,7 @@ def update(self, state: jnp.ndarray, u:jnp.ndarray) -> None:
lr = self.lr_scale
lr *= (1/ (self.t+1)) if self.decay else 1
self.M -= lr * delta_M
self.M -= lr * delta_M
self.bias -= lr * delta_bias

# update state
self.state = state
Expand Down
125 changes: 125 additions & 0 deletions examples/agents/BPC test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 279,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 280,
"metadata": {},
"outputs": [],
"source": [
"from deluca.agents import BPC, LQR\n",
"from deluca.envs import LDS\n",
"import jax.numpy as jnp"
]
},
{
"cell_type": "code",
"execution_count": 281,
"metadata": {},
"outputs": [],
"source": [
"A,B = jnp.array([[.8,.5], [0,.8]]), jnp.array([[0],[0.8]])\n",
"lds = LDS(state_size= B.shape[0], action_size=B.shape[1], A=A, B=B)"
]
},
{
"cell_type": "code",
"execution_count": 282,
"metadata": {},
"outputs": [],
"source": [
"bpc = BPC(lds.A, lds.B, lr_scale=1e-3, delta=0.01)\n",
"lqr = LQR(A, B)"
]
},
{
"cell_type": "code",
"execution_count": 283,
"metadata": {},
"outputs": [],
"source": [
"def get_err(T, lds, controller):\n",
" lds.reset()\n",
" avg_err = 0\n",
" err = 0\n",
" for i in range(T):\n",
" try:\n",
" action = controller(lds.state, err)\n",
" except:\n",
" action = controller(lds.state)\n",
" lds.step(action)\n",
" lds.state += 0.03 * jnp.sin(i) # add sine noise\n",
" err = jnp.linalg.norm(lds.state)+jnp.linalg.norm(action)\n",
" avg_err += err/T\n",
" return avg_err"
]
},
{
"cell_type": "code",
"execution_count": 284,
"metadata": {},
"outputs": [],
"source": [
"T = 100"
]
},
{
"cell_type": "code",
"execution_count": 285,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"BPC incurs 0.07014202 loss\n",
"LQR incurs 0.10718855 loss\n"
]
}
],
"source": [
"print(\"BPC incurs \", get_err(T, lds, bpc), \" loss\")\n",
"print(\"LQR incurs \", get_err(T, lds, lqr), \" loss\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

0 comments on commit 74ba003

Please sign in to comment.