-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from MinRegret/bpc
Add BPC
- Loading branch information
Showing
4 changed files
with
290 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""deluca.agents._bpc""" | ||
from numbers import Real | ||
from typing import Callable | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
import numpy.random as random | ||
from jax import grad | ||
from jax import jit | ||
|
||
from deluca.agents._lqr import LQR | ||
from deluca.agents.core import Agent | ||
|
||
def generate_uniform(shape, norm=1.00): | ||
v = random.normal(size=shape) | ||
v = norm * v / np.linalg.norm(v) | ||
v = np.array(v) | ||
return v | ||
|
||
class BPC(Agent): | ||
def __init__( | ||
self, | ||
A: jnp.ndarray, | ||
B: jnp.ndarray, | ||
Q: jnp.ndarray = None, | ||
R: jnp.ndarray = None, | ||
K: jnp.ndarray = None, | ||
start_time: int = 0, | ||
H: int = 5, | ||
lr_scale: Real = 0.005, | ||
decay: bool = False, | ||
delta: Real = 0.01 | ||
) -> None: | ||
""" | ||
Description: Initialize the dynamics of the model. | ||
Args: | ||
A (jnp.ndarray): system dynamics | ||
B (jnp.ndarray): system dynamics | ||
Q (jnp.ndarray): cost matrices (i.e. cost = x^TQx + u^TRu) | ||
R (jnp.ndarray): cost matrices (i.e. cost = x^TQx + u^TRu) | ||
K (jnp.ndarray): Starting policy (optional). Defaults to LQR gain. | ||
start_time (int): | ||
H (postive int): history of the controller | ||
lr_scale (Real): | ||
decay (boolean): | ||
""" | ||
|
||
self.d_state, self.d_action = B.shape # State & Action Dimensions | ||
|
||
self.A, self.B = A, B # System Dynamics | ||
|
||
self.t = 0 # Time Counter (for decaying learning rate) | ||
|
||
self.H = H | ||
|
||
self.lr_scale, self.decay = lr_scale, decay | ||
|
||
self.delta = delta | ||
|
||
# Model Parameters | ||
# initial linear policy / perturbation contributions / bias | ||
# TODO: need to address problem of LQR with jax.lax.scan | ||
self.K = K if K is not None else LQR(self.A, self.B, Q, R).K | ||
|
||
self.M = self.delta * generate_uniform((H, self.d_action, self.d_state)) | ||
|
||
# Past H noises ordered increasing in time | ||
self.noise_history = jnp.zeros((H, self.d_state, 1)) | ||
|
||
# past state and past action | ||
self.state, self.action = jnp.zeros((self.d_state, 1)), jnp.zeros((self.d_action, 1)) | ||
|
||
self.eps = generate_uniform((H, H, self.d_action, self.d_state)) | ||
self.eps_bias = generate_uniform((H, self.d_action, 1)) | ||
|
||
def grad(M, noise_history, cost): | ||
return cost * jnp.sum(self.eps, axis = 0) | ||
|
||
self.grad = grad | ||
|
||
def __call__(self, | ||
state: jnp.ndarray, | ||
cost: Real | ||
) -> jnp.ndarray: | ||
""" | ||
Description: Return the action based on current state and internal parameters. | ||
Args: | ||
state (jnp.ndarray): current state | ||
Returns: | ||
jnp.ndarray: action to take | ||
""" | ||
|
||
action = self.get_action(state) | ||
self.update(state, action, cost) | ||
return action | ||
|
||
def update(self, | ||
state: jnp.ndarray, | ||
action:jnp.ndarray, | ||
cost: Real | ||
) -> None: | ||
""" | ||
Description: update agent internal state. | ||
Args: | ||
state (jnp.ndarray): current state | ||
action (jnp.ndarray): action taken | ||
cost (Real): scalar cost received | ||
Returns: | ||
None | ||
""" | ||
noise = state - self.A @ self.state - self.B @ action | ||
self.noise_history = jax.ops.index_update(self.noise_history, 0, noise) | ||
self.noise_history = jnp.roll(self.noise_history, -1, axis=0) | ||
|
||
lr = self.lr_scale | ||
lr *= (1/ (self.t**(3/4)+1)) if self.decay else 1 | ||
|
||
delta_M = self.grad(self.M, self.noise_history, cost) | ||
self.M -= lr * delta_M | ||
|
||
self.eps = jax.ops.index_update(self.eps, 0, \ | ||
generate_uniform((self.H, self.d_action, self.d_state))) | ||
self.eps = np.roll(self.eps, -1, axis = 0) | ||
|
||
self.M += self.delta * self.eps[-1] | ||
|
||
# update state | ||
self.state = state | ||
|
||
self.t += 1 | ||
|
||
def get_action(self, state: jnp.ndarray) -> jnp.ndarray: | ||
""" | ||
Description: get action from state. | ||
Args: | ||
state (jnp.ndarray): | ||
Returns: | ||
jnp.ndarray | ||
""" | ||
return -self.K @ state + jnp.tensordot(self.M, self.noise_history, axes=([0, 2], [0, 1])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 279, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"The autoreload extension is already loaded. To reload it, use:\n", | ||
" %reload_ext autoreload\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 280, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from deluca.agents import BPC, LQR\n", | ||
"from deluca.envs import LDS\n", | ||
"import jax.numpy as jnp" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 281, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"A,B = jnp.array([[.8,.5], [0,.8]]), jnp.array([[0],[0.8]])\n", | ||
"lds = LDS(state_size= B.shape[0], action_size=B.shape[1], A=A, B=B)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 282, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"bpc = BPC(lds.A, lds.B, lr_scale=1e-3, delta=0.01)\n", | ||
"lqr = LQR(A, B)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 283, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def get_err(T, lds, controller):\n", | ||
" lds.reset()\n", | ||
" avg_err = 0\n", | ||
" err = 0\n", | ||
" for i in range(T):\n", | ||
" try:\n", | ||
" action = controller(lds.state, err)\n", | ||
" except:\n", | ||
" action = controller(lds.state)\n", | ||
" lds.step(action)\n", | ||
" lds.state += 0.03 * jnp.sin(i) # add sine noise\n", | ||
" err = jnp.linalg.norm(lds.state)+jnp.linalg.norm(action)\n", | ||
" avg_err += err/T\n", | ||
" return avg_err" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 284, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"T = 100" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 285, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"BPC incurs 0.07014202 loss\n", | ||
"LQR incurs 0.10718855 loss\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(\"BPC incurs \", get_err(T, lds, bpc), \" loss\")\n", | ||
"print(\"LQR incurs \", get_err(T, lds, lqr), \" loss\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |