From 78e6d954666fa7c8ba4f2e22f6c79fb75428a3fa Mon Sep 17 00:00:00 2001
From: paula-gradu
Date: Mon, 18 Jan 2021 13:39:57 -0500
Subject: [PATCH] Add BPC

---
 deluca/agents/__init__.py      |   3 +-
 deluca/agents/_bpc.py          | 182 +++++++++++++++++++++++++++++++++
 deluca/agents/_gpc.py          |   4 +-
 examples/agents/BPC test.ipynb | 125 ++++++++++++++++++++++
 4 files changed, 311 insertions(+), 3 deletions(-)
 create mode 100644 deluca/agents/_bpc.py
 create mode 100644 examples/agents/BPC test.ipynb

diff --git a/deluca/agents/__init__.py b/deluca/agents/__init__.py
index 1a74119..67faa35 100644
--- a/deluca/agents/__init__.py
+++ b/deluca/agents/__init__.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from deluca.agents._bpc import BPC
 from deluca.agents._gpc import GPC
 from deluca.agents._hinf import Hinf
 from deluca.agents._ilqr import ILQR
@@ -21,4 +22,4 @@
 from deluca.agents._adaptive import Adaptive
 from deluca.agents._deep import Deep
 
-__all__ = ["LQR", "PID", "GPC", "ILQR", "Hinf", "Zero", "DRC", "Adaptive", "Deep"]
+__all__ = ["LQR", "PID", "BPC", "GPC", "ILQR", "Hinf", "Zero", "DRC", "Adaptive", "Deep"]
diff --git a/deluca/agents/_bpc.py b/deluca/agents/_bpc.py
new file mode 100644
index 0000000..4cb46a3
--- /dev/null
+++ b/deluca/agents/_bpc.py
@@ -0,0 +1,182 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""deluca.agents._bpc"""
+from numbers import Real
+from typing import Callable
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import numpy.random as random
+from jax import grad
+from jax import jit
+
+from deluca.agents._lqr import LQR
+from deluca.agents.core import Agent
+
+def generate_uniform(shape, norm=1.00):
+    """
+    Description: Sample a Gaussian array of the given shape and rescale it so
+    that the norm of the whole (flattened) array equals `norm`.
+
+    Args:
+        shape: shape of the array to sample
+        norm (Real): norm of the returned array
+
+    Returns:
+        np.ndarray: array of shape `shape` whose overall norm is `norm`
+    """
+    v = random.normal(size=shape)
+    v = norm * v / np.linalg.norm(v)
+    v = np.array(v)
+    return v
+
+class BPC(Agent):
+    """
+    Description: Agent that plays -K @ state plus a linear map M of the last H
+    estimated noises, and updates M from the scalar cost and random perturbations.
+    """
+    def __init__(
+        self,
+        A: jnp.ndarray,
+        B: jnp.ndarray,
+        Q: jnp.ndarray = None,
+        R: jnp.ndarray = None,
+        K: jnp.ndarray = None,
+        start_time: int = 0,
+        H: int = 5,
+        lr_scale: Real = 0.005,
+        decay: bool = False,
+        delta: Real = 0.01
+    ) -> None:
+        """
+        Description: Initialize the dynamics of the model.
+
+        Args:
+            A (jnp.ndarray): system dynamics
+            B (jnp.ndarray): system dynamics
+            Q (jnp.ndarray): cost matrices (i.e. cost = x^TQx + u^TRu)
+            R (jnp.ndarray): cost matrices (i.e. cost = x^TQx + u^TRu)
+            K (jnp.ndarray): Starting policy (optional). Defaults to LQR gain.
+            start_time (int): not read anywhere in this class
+            H (positive int): history of the controller
+            lr_scale (Real): base learning rate for the updates of M
+            decay (boolean): if True, scale the learning rate by 1 / (t^(3/4) + 1)
+            delta (Real): scale of the random perturbations added to M
+        """
+
+        self.d_state, self.d_action = B.shape  # State & Action Dimensions
+
+        self.A, self.B = A, B  # System Dynamics
+
+        self.t = 0  # Time Counter (for decaying learning rate)
+
+        self.H = H
+
+        self.lr_scale, self.decay = lr_scale, decay
+
+        self.delta = delta
+
+        # Model Parameters
+        # initial linear policy / perturbation contributions / bias
+        # TODO: need to address problem of LQR with jax.lax.scan
+        self.K = K if K is not None else LQR(self.A, self.B, Q, R).K
+
+        self.M = self.delta * generate_uniform((H, self.d_action, self.d_state))
+
+        # Past H noises ordered increasing in time
+        self.noise_history = jnp.zeros((H, self.d_state, 1))
+
+        # past state and past action
+        self.state, self.action = jnp.zeros((self.d_state, 1)), jnp.zeros((self.d_action, 1))
+
+        self.eps = generate_uniform((H, H, self.d_action, self.d_state))
+        self.eps_bias = generate_uniform((H, self.d_action, 1))
+
+        # NOTE(review): M and noise_history are unused; the estimate is cost * sum of stored eps
+        def grad(M, noise_history, cost):
+            return cost * jnp.sum(self.eps, axis=0)
+
+        self.grad = grad
+
+    def __call__(self,
+                 state: jnp.ndarray,
+                 cost: Real
+                 ) -> jnp.ndarray:
+        """
+        Description: Return the action based on current state and internal parameters.
+
+        Args:
+            state (jnp.ndarray): current state
+            cost (Real): scalar cost, forwarded to update
+
+        Returns:
+            jnp.ndarray: action to take
+        """
+
+        action = self.get_action(state)
+        self.update(state, action, cost)
+        return action
+
+    def update(self,
+               state: jnp.ndarray,
+               action: jnp.ndarray,
+               cost: Real
+               ) -> None:
+        """
+        Description: update agent internal state.
+
+        Args:
+            state (jnp.ndarray): current state
+            action (jnp.ndarray): action taken
+            cost (Real): scalar cost received
+
+        Returns:
+            None
+        """
+        # NOTE(review): from __call__, `action` is the action just chosen for `state`, while
+        # self.state is the previous state - confirm this is the intended noise estimate
+        noise = state - self.A @ self.state - self.B @ action
+        self.noise_history = jax.ops.index_update(self.noise_history, 0, noise)
+        self.noise_history = jnp.roll(self.noise_history, -1, axis=0)
+
+        lr = self.lr_scale
+        lr *= (1/ (self.t**(3/4)+1)) if self.decay else 1
+
+        delta_M = self.grad(self.M, self.noise_history, cost)
+        self.M -= lr * delta_M
+
+        # overwrite the oldest perturbation (slot 0) with a fresh one, then roll it to the last slot
+        self.eps = jax.ops.index_update(self.eps, 0,
+                            generate_uniform((self.H, self.d_action, self.d_state)))
+        self.eps = jnp.roll(self.eps, -1, axis=0)
+
+        self.M += self.delta * self.eps[-1]
+
+        # update state
+        self.state = state
+
+        self.t += 1
+
+    def get_action(self, state: jnp.ndarray) -> jnp.ndarray:
+        """
+        Description: get action from state.
+
+        Args:
+            state (jnp.ndarray): current state
+
+        Returns:
+            jnp.ndarray: action to take
+        """
+        return -self.K @ state + jnp.tensordot(self.M, self.noise_history, axes=([0, 2], [0, 1]))
diff --git a/deluca/agents/_gpc.py b/deluca/agents/_gpc.py
index c654fd1..39cde83 100644
--- a/deluca/agents/_gpc.py
+++ b/deluca/agents/_gpc.py
@@ -87,7 +87,7 @@ def __init__(
         # Model Parameters
         # initial linear policy / perturbation contributions / bias
         # TODO: need to address problem of LQR with jax.lax.scan
-        self.K = K if K is not None else LQR(self.A, self.B, Q, R).K
+        self.K = K if K is not None else LQR(self.A, self.B, Q, R).K
 
         self.M = jnp.zeros((H, d_action, d_state))
 
@@ -156,7 +156,7 @@ def update(self, state: jnp.ndarray, u:jnp.ndarray) -> None:
         lr = self.lr_scale
         lr *= (1/ (self.t+1)) if self.decay else 1
         self.M -= lr * delta_M
-        self.M -= lr * delta_M
+        self.bias -= lr * delta_bias
 
         # update state
         self.state = state
diff --git a/examples/agents/BPC test.ipynb b/examples/agents/BPC test.ipynb
new file mode 100644
index 0000000..a8015bb
--- /dev/null
+++ b/examples/agents/BPC test.ipynb
@@ -0,0 +1,125 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 279,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "The autoreload extension is already loaded. To reload it, use:\n",
+      " %reload_ext autoreload\n"
+     ]
+    }
+   ],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 280,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from deluca.agents import BPC, LQR\n",
+    "from deluca.envs import LDS\n",
+    "import jax.numpy as jnp"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 281,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "A,B = jnp.array([[.8,.5], [0,.8]]), jnp.array([[0],[0.8]])\n",
+    "lds = LDS(state_size= B.shape[0], action_size=B.shape[1], A=A, B=B)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 282,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bpc = BPC(lds.A, lds.B, lr_scale=1e-3, delta=0.01)\n",
+    "lqr = LQR(A, B)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 283,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_err(T, lds, controller):\n",
+    "    lds.reset()\n",
+    "    avg_err = 0\n",
+    "    err = 0\n",
+    "    for i in range(T):\n",
+    "        try:\n",
+    "            action = controller(lds.state, err)\n",
+    "        except TypeError:\n",
+    "            action = controller(lds.state)\n",
+    "        lds.step(action)\n",
+    "        lds.state += 0.03 * jnp.sin(i) # add sine noise\n",
+    "        err = jnp.linalg.norm(lds.state)+jnp.linalg.norm(action)\n",
+    "        avg_err += err/T\n",
+    "    return avg_err"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 284,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "T = 100"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 285,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "BPC incurs 0.07014202 loss\n",
+      "LQR incurs 0.10718855 loss\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"BPC incurs \", get_err(T, lds, bpc), \" loss\")\n",
+    "print(\"LQR incurs \", get_err(T, lds, lqr), \" loss\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}