diff --git a/.dockerignore b/.dockerignore index f4946d91..db2fb63e 100644 --- a/.dockerignore +++ b/.dockerignore @@ -9,6 +9,7 @@ !config.yml !requirements*.txt !pyproject.toml +!poetry.lock !pyproject-full.toml !entrypoint.sh !README.md diff --git a/Dockerfile b/Dockerfile index bdea315d..4c48f969 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,41 +1,82 @@ -FROM jupyter/minimal-notebook:python-3.11 +#-------------- Base Image ------------------- +FROM jupyter/minimal-notebook:python-3.11 as BASE -ENV POETRY_VERSION=1.6.1 +ARG CODE_DIR=/tmp/code +ARG POETRY_VERSION=1.6.1 + +ENV DEBIAN_FRONTEND=noninteractive \ + PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PYTHONFAULTHANDLER=1 \ + POETRY_VERSION=$POETRY_VERSION \ + POETRY_HOME="/opt/poetry" \ + POETRY_NO_INTERACTION=1 \ + POETRY_VIRTUALENVS_CREATE=1 \ + POETRY_VIRTUALENVS_IN_PROJECT=1 \ + CODE_DIR=$CODE_DIR + +ENV PATH="${POETRY_HOME}/bin:$PATH" + +USER root + +RUN curl -sSL https://install.python-poetry.org | python - + +USER ${NB_UID} + +WORKDIR $CODE_DIR + +COPY --chown=${NB_UID}:${NB_GID} poetry.lock pyproject.toml . + +RUN poetry install --no-interaction --no-ansi --no-root --only main +RUN poetry install --no-interaction --no-ansi --no-root --with add1 +RUN poetry install --no-interaction --no-ansi --no-root --with add2 +RUN poetry install --no-interaction --no-ansi --no-root --with control +RUN poetry install --no-interaction --no-ansi --no-root --with offline + +COPY --chown=${NB_UID}:${NB_GID} src/ src/ +COPY --chown=${NB_UID}:${NB_GID} README.md . + +RUN poetry build + + +#-------------- Main Image ------------------- +FROM jupyter/minimal-notebook:python-3.11 as MAIN + +ARG CODE_DIR=/tmp/code + +ENV DEBIAN_FRONTEND=noninteractive\ + PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 \ + PYTHONFAULTHANDLER=1 \ + CODE_DIR=$CODE_DIR + +ENV PATH="${CODE_DIR}/.venv/bin:$PATH" USER root -RUN apt-get update && apt-get upgrade -y # pandoc needed for docs, see https://nbsphinx.readthedocs.io/en/0.7.1/installation.html?highlight=pandoc#pandoc # gh-pages action uses rsync -# gcc, gfortran and libopenblas-dev are needed for slycot, which in turn is needed by the python-control package -# build-essential required for scikit-build # opengl and ffmpeg needed for rendering envs -RUN apt-get -y --no-install-recommends install pandoc git-lfs rsync build-essential gcc gfortran libopenblas-dev ffmpeg -RUN apt-get -y install x11-xserver-utils +RUN apt-get update \ + && apt-get -y --no-install-recommends install pandoc git-lfs rsync ffmpeg x11-xserver-utils \ + && rm -rf /var/lib/apt/lists/* USER ${NB_UID} +WORKDIR ${CODE_DIR} + +# Copy virtual environment from base image +COPY --from=BASE ${CODE_DIR}/.venv ${CODE_DIR}/.venv +# Copy built package from base image +COPY --from=BASE ${CODE_DIR}/dist ${CODE_DIR}/dist -# Jhub does not support notebook 7 yet, all hell breaks loose if we don't pin it -RUN pip install "notebook<7" # This goes directly into main jupyter, not poetry env COPY --chown=${NB_UID}:${NB_GID} build_scripts ./build_scripts RUN bash build_scripts/install_presentation_requirements.sh - -# Install poetry according to -# https://python-poetry.org/docs/#installing-manually -RUN pip install -U setuptools "poetry==$POETRY_VERSION" - -WORKDIR /tmp - # Start of HACK: the home directory is overwritten by a mount when a jhub server is started off this image # Thus, we create a jovyan-owned directory to which we copy the code and then move it to the home dir as part # of the entrypoint -ENV CODE_DIR=/tmp/code - -RUN mkdir $CODE_DIR - COPY --chown=${NB_UID}:${NB_GID} entrypoint.sh $CODE_DIR RUN chmod +x "${CODE_DIR}/"entrypoint.sh @@ -53,13 +94,8 @@ COPY --chown=${NB_UID}:${NB_GID} . $CODE_DIR # complete code base, including the poetry.lock file WORKDIR $CODE_DIR -RUN poetry config virtualenvs.in-project true -RUN poetry install --no-interaction --no-ansi -RUN poetry install --no-interaction --no-ansi --with add1 -RUN poetry install --no-interaction --no-ansi --with add2 -RUN poetry install --no-interaction --no-ansi --with control -RUN poetry install --no-interaction --no-ansi --with offline -# use poetry for package mgmt. -RUN poetry run ipython kernel install --name "tfl-training-rl" --user -# DIRTY HACK -RUN pip install -U "notebook<7" ipykernel +RUN pip install --no-cache-dir dist/*.whl + +RUN ipython kernel install --name "tfl-training-rl" --user + +RUN jupyter trust notebooks diff --git a/README.md b/README.md index 9ecf0687..e26cc0e3 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# TransferLab Training: Safe and Efficient RL +# TransferLab Training: Control and Safe and Efficient RL -Welcome to the TransferLab training: Safe and Efficient RL. +Welcome to the TransferLab trainings: Control, Safe and Efficient RL. This is the readme for the participants of the training. ## During the training @@ -17,48 +17,59 @@ You have received this file as part of the training materials. There are multiple ways of viewing/executing the content. 1. If you just want to view the rendered notebooks, -open `html/index.html` in your browser. + open `html/index.html` in your browser. 2. If you want to execute the notebooks, you will either need to -install the dependencies or use docker. -For running without docker, create a conda environment (with python 3.9), -e.g., with `conda create -n training_rl python=3.9`. -Then, install the dependencies and the package with - ```shell - bash build_scripts/install_presentation_requirements.sh - pip install -e . - ``` - -3. If you want to use docker instead, you can build the image locally. -First, set the variable `PARTICIPANT_BUCKET_READ_SECRET` to the secret found in -`config.yaml`, and then build the image with - ```shell - docker build --build-arg PARTICIPANT_BUCKET_READ_SECRET=$PARTICIPANT_BUCKET_READ_SECRET -t training_rl . - ``` - You can then start the container e.g., with - ```shell - docker run -it -p 8888:8888 training_rl jupyter notebook - ``` -4. The data will be downloaded on the fly when you run the notebooks. -5. Finally, for creating source code documentation, you can run - ```shell - bash build_scripts/build_docs.sh - ``` - and then open `docs/build/html/index.html` in your browser. - This will also rebuild the jupyter-book based notebook documentation - that was originally found in the `html` directory. + install the dependencies or use docker. + For running without docker, create a [poetry](https://python-poetry.org/) environment (with python 3.11), + e.g., with `poetry shell`. -6. In case you experience some issues with the rendering when using docker -make sure to add the docker user to xhost. So run on your local machine: + Then, install the dependencies and the package with + + ```shell + poetry install --with=add1,add2,control,offline + bash build_scripts/install_presentation_requirements.sh + ``` + +3. If you want to use docker instead, + you can build the image locally using: + + ```shell + docker build -t tfl-training-rl:local . + ``` -xhost +SI:localuser:docker_user + You can then start the container e.g., with + + ```shell + docker run -it -p 8888:8888 tfl-training-rl:local jupyter notebook --ip=0.0.0.0 + ``` -and run docker like: +4. Finally, for creating source code documentation, you can run + + ```shell + bash build_scripts/build_docs.sh + ``` + + and then open `docs/build/html/index.html` in your browser. + This will also rebuild the jupyter-book based notebook documentation + that was originally found in the `html` directory. + +6. In case you experience some issues with the rendering when using docker + make sure to add the docker user to xhost. So run on your local machine: -docker run -p 8888:8888 -it --env DISPLAY=$DISPLAY --net=host --privileged --volume /tmp/.X11-unix:/tmp/.X11-unix training_rl bash + ```shell + xhost +SI:localuser:docker_user + ``` + and run docker using: + + ```shell + docker run -it --rm --privileged --net=host \ + --env DISPLAY --volume /tmp/.X11-unix:/tmp/.X11-unix \ + tfl-training-rl:local jupyter notebook --ip=0.0.0.0 + ``` -Note that there is some non-trivial logic in the entrypoint that may collide +> **Note** There is some non-trivial logic in the entrypoint that may collide with mounting volumes to paths directly inside `/home/jovyan/training_rl`. If you want to do that, the easiest way is to override the entrypoint or to mount somewhere else diff --git a/build_scripts/install_presentation_requirements.sh b/build_scripts/install_presentation_requirements.sh index 62875e5d..22d78bd4 100644 --- a/build_scripts/install_presentation_requirements.sh +++ b/build_scripts/install_presentation_requirements.sh @@ -37,7 +37,7 @@ BUILD_DIR=$(dirname "$0") ( cd "${BUILD_DIR}/.." || (echo "Unknown error, could not find directory ${BUILD_DIR}" && exit 255) - pip install jupyter_contrib_nbextensions + pip install --no-cache-dir jupyter_contrib_nbextensions jupyter contrib nbextension install --user jupyter nbextensions_configurator enable --user jupyter nbextension enable equation-numbering/main diff --git a/notebooks/_static/images/20_control_theory_map.png b/notebooks/_static/images/20_control_theory_map.png new file mode 100644 index 00000000..18373054 --- /dev/null +++ b/notebooks/_static/images/20_control_theory_map.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1edb7ba20fb10f290399e8aff6e22365fb22fb6e02eb5229de6d8590862d5cb9 +size 3742266 diff --git a/notebooks/_static/images/30_optimal_control_methods.svg b/notebooks/_static/images/30_optimal_control_methods.svg new file mode 100644 index 00000000..d9d321bf --- /dev/null +++ b/notebooks/_static/images/30_optimal_control_methods.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ada17ef2c241ff5de4e06a5cf3722e160772cbb81fc9e50831ac6224c64807aa +size 1518386 diff --git a/notebooks/nb_20_IntroductionToControl.ipynb b/notebooks/nb_20_IntroductionToControl.ipynb index a4e71c02..65d0f2ac 100644 --- a/notebooks/nb_20_IntroductionToControl.ipynb +++ b/notebooks/nb_20_IntroductionToControl.ipynb @@ -59,6 +59,7 @@ "outputs": [], "source": [ "%autoreload\n", + "import os\n", "import warnings\n", "from dataclasses import dataclass\n", "from typing import Callable, Protocol\n", @@ -66,9 +67,6 @@ "import control as ct\n", "import gymnasium as gym\n", "import matplotlib.pyplot as plt\n", - "import matplotx\n", - "import mediapy as media\n", - "import mujoco\n", "import numpy as np\n", "import seaborn as sns\n", "import sympy as sym\n", @@ -89,12 +87,15 @@ " plot_estimator_response,\n", " plot_mass_spring_damper_results,\n", " plot_inverted_pendulum_results,\n", - " display_array\n", + " display_array,\n", + " show_video,\n", ")\n", "\n", "warnings.simplefilter(\"ignore\", UserWarning)\n", "sns.set_theme()\n", - "plt.rcParams[\"figure.figsize\"] = [12, 8]" + "plt.rcParams[\"figure.figsize\"] = [12, 8]\n", + "# This is needed because inside docker the rendering of mujoco environments may not work.\n", + "render_mode = \"rgb_array\" if os.environ.get(\"DISPLAY\") else None" ] }, { @@ -138,6 +139,24 @@ "- **Robust Control**: an approach to controller design that explicitly deals with uncertainty." ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "
\n", + "
\n", + " \n", + "
\n", + " The Map of Control Theory\n", + "
\n", + "
\n", + "
" + ] + }, { "cell_type": "markdown", "metadata": { @@ -353,21 +372,18 @@ }, "outputs": [], "source": [ - "env = create_mass_spring_damper_environment()\n", + "env = create_mass_spring_damper_environment(render_mode=render_mode)\n", "env.reset()\n", - "all_frames = []\n", - "for i in range(2):\n", - " frames = []\n", - " for _ in range(100):\n", - " if i == 1:\n", - " action = np.zeros_like(env.action_space.sample())\n", - " else:\n", - " action = env.action_space.sample()\n", - " observation, _, terminated, truncated, _ = env.step(action)\n", - " if terminated or truncated:\n", - " all_frames.append(env.render())\n", - " env.reset()\n", - " break\n", + "\n", + "frames = []\n", + "for i in range(100):\n", + " action = np.array([i * 0.1])\n", + " observation, _, terminated, truncated, _ = env.step(action)\n", + " if terminated or truncated:\n", + " if env.render_mode is not None:\n", + " frames = env.render()\n", + " env.reset()\n", + " break\n", "env.close()" ] }, @@ -381,7 +397,7 @@ }, "outputs": [], "source": [ - "media.show_videos(all_frames, fps=1 / env.dt)" + "show_video(frames, fps=1 / env.dt)" ] }, { @@ -430,6 +446,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "scrolled": false, "slideshow": { "slide_type": "subslide" } @@ -533,21 +550,18 @@ }, "outputs": [], "source": [ - "env = create_inverted_pendulum_environment(max_steps=100, cutoff_angle=np.inf)\n", + "env = create_inverted_pendulum_environment(render_mode=render_mode, max_steps=100, cutoff_angle=np.inf)\n", "env.reset()\n", - "all_frames = []\n", - "for i in range(2):\n", - " frames = []\n", - " for _ in range(100):\n", - " if i == 1:\n", - " action = np.zeros_like(env.action_space.sample())\n", - " else:\n", - " action = env.action_space.sample()\n", - " observation, _, terminated, truncated, _ = env.step(action)\n", - " if terminated or truncated:\n", - " all_frames.append(env.render())\n", - " env.reset()\n", - " break\n", + "\n", + "frames = []\n", + "for _ in range(100):\n", + " action = np.zeros_like(env.action_space.sample())\n", + " observation, _, terminated, truncated, _ = env.step(action)\n", + " if terminated or truncated:\n", + " if env.render_mode is not None:\n", + " frames= env.render()\n", + " env.reset()\n", + " break\n", "env.close()" ] }, @@ -561,7 +575,7 @@ }, "outputs": [], "source": [ - "media.show_videos(all_frames, fps=1 / env.dt)" + "show_video(frames, fps=1 / env.dt)" ] }, { @@ -665,6 +679,7 @@ " actions = []\n", " observations = [observation]\n", " estimated_observations = []\n", + " frames = []\n", "\n", " if observer is not None:\n", " estimated_observation = observer.observe(observation)\n", @@ -683,7 +698,8 @@ "\n", " # Check if we need to stop the simulation\n", " if terminated or truncated:\n", - " frames = env.render()\n", + " if env.render_mode is not None:\n", + " frames = env.render()\n", " env.reset()\n", " break\n", " env.close()\n", @@ -711,9 +727,9 @@ }, "outputs": [], "source": [ - "env = create_mass_spring_damper_environment()\n", + "env = create_mass_spring_damper_environment(render_mode=render_mode)\n", "results = simulate_environment(env)\n", - "media.show_video(results.frames, fps=1 / env.dt)" + "show_video(results.frames, fps=1 / env.dt)" ] }, { @@ -812,10 +828,16 @@ "outputs": [], "source": [ "def control_inverted_pendulum(K=widgets.FloatSlider(min=0.0, max=100.0, step=1, value=1.0)):\n", - " env = create_inverted_pendulum_environment()\n", + " env = create_inverted_pendulum_environment(render_mode=render_mode)\n", " controller = ProportionalController(K)\n", " results = simulate_environment(env, max_steps=300, controller=controller)\n", - " media.show_video(results.frames, fps=1 / env.dt)\n", + " if env.render_mode is None:\n", + " T = np.arange(len(results.observations)) * env.dt\n", + " plt.plot(T, results.observations[:, 0])\n", + " plt.xlabel(\"Time\")\n", + " plt.ylabel(\"Angle\")\n", + " else:\n", + " show_video(results.frames, fps=1 / env.dt)\n", "\n", "\n", "interact(control_inverted_pendulum);" @@ -1056,8 +1078,8 @@ "Where:\n", "\n", "- $K$ is the system's gain.\n", - "- $a_i$ are the system's poles.\n", - "- $b_i$ are the system's zeros." + "- $a_i$ are the system's zeros.\n", + "- $b_i$ are the system's poles." ] }, { @@ -1829,7 +1851,7 @@ "solution2_first": true }, "source": [ - "### Hint\n", + "## Hint\n", "\n", "Open this help if you need help with the creating the step controller." ] @@ -1930,7 +1952,7 @@ "outputs": [], "source": [ "max_steps = 300\n", - "env = create_mass_spring_damper_environment(max_steps=max_steps)\n", + "env = create_mass_spring_damper_environment(render_mode=render_mode, max_steps=max_steps)\n", "controller = StepController()\n", "results = simulate_environment(env, max_steps=max_steps, controller=controller)\n", "observations = results.observations\n", @@ -2298,7 +2320,7 @@ "outputs": [], "source": [ "max_steps = 10\n", - "env = create_inverted_pendulum_environment(max_steps=max_steps)\n", + "env = create_inverted_pendulum_environment(render_mode=render_mode, max_steps=max_steps)\n", "controller = ConstantController()\n", "results = simulate_environment(env, max_steps=max_steps, controller=controller)" ] @@ -2336,7 +2358,7 @@ "plt.plot(response.time, response.outputs, label=\"System\")\n", "plt.plot(response.time, results.observations[:-1, 1], label=\"Model\")\n", "plt.xlabel(\"Time\")\n", - "plt.ylabel(\"Position\")\n", + "plt.ylabel(\"Angle\")\n", "plt.legend();" ] }, @@ -2357,7 +2379,7 @@ " A system is bounded-input, bounded-output stable (**BIBO** stable) if, for every bounded input, the output is finite. Mathematically, if every input satisfying\n", "\n", " $$\n", - " ||x(t)||_\\infty \\lt \\infty\n", + " ||u(t)||_\\infty \\lt \\infty\n", " $$\n", "\n", " leads to an output satisfying \n", @@ -2465,7 +2487,7 @@ "\\mathbf{x}(t) = e^{At}\\mathbf{x}(0)\n", "$$\n", "\n", - "which can be rewritten as:\n", + "which can be rewritten using the [eigendecomposition](https://en.wikipedia.org/wiki/Eigendecomposition_of_a_matrix) as:\n", "\n", "$$\n", "\\mathbf{x}(t) = \\left[M e^{\\Lambda t}M^{-1}\\right] \\mathbf{x}(0)\n", @@ -2480,7 +2502,7 @@ " 0 & 0 & \\dots & 0 & e^{\\lambda_n t}\\\\\n", " \\end{bmatrix}$\n", "- $\\Lambda$ is a vector that contains the eigenvalues of $A$.\n", - "- $M$ is a matrix whose columns are the eigenvectors of $A$." + "- $M = \\begin{bmatrix}m_{ij}\\end{bmatrix}$ is a matrix whose columns are the eigenvectors of $A$." ] }, { @@ -2552,7 +2574,7 @@ "source": [ "Lambda, M = np.linalg.eig(A)\n", "display_array(\"\\Lambda\", Lambda)\n", - "display_array(\"M\", eigenvectors)" + "display_array(\"M\", M)" ] }, { @@ -2570,17 +2592,6 @@ "display_array(\"M^{-1}\", M_1)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "solution2": "hidden" - }, - "outputs": [], - "source": [ - "M * sym.simplify(sym.exp(sym.Matrix(np.diag(Lambda) * t))) * M_1" - ] - }, { "cell_type": "code", "execution_count": null, @@ -2733,8 +2744,7 @@ }, "outputs": [], "source": [ - "result = np.linalg.eig(mass_spring_damper.A)\n", - "eigenvalues = result.eigenvalues\n", + "eigenvalues, _ = np.linalg.eig(mass_spring_damper.A)\n", "eigenvalues" ] }, @@ -2869,8 +2879,7 @@ }, "outputs": [], "source": [ - "result = np.linalg.eig(inverted_pendulum.A)\n", - "eigenvalues = result.eigenvalues\n", + "eigenvalues, _ = np.linalg.eig(inverted_pendulum.A)\n", "eigenvalues" ] }, @@ -3206,7 +3215,7 @@ }, "outputs": [], "source": [ - "env = create_mass_spring_damper_environment()\n", + "env = create_mass_spring_damper_environment(render_mode=render_mode)\n", "dt = env.dt\n", "T = np.arange(0, len(observations) * dt - dt, dt)\n", "U = np.concatenate([observations[1:, [1]], actions], axis=1).transpose()\n", @@ -3284,9 +3293,7 @@ "metadata": { "slideshow": { "slide_type": "subslide" - }, - "solution2": "hidden", - "solution2_first": true + } }, "source": [ "## Solution" @@ -3298,7 +3305,8 @@ "slideshow": { "slide_type": "subslide" }, - "solution2": "hidden" + "solution2": "hidden", + "solution2_first": true }, "source": [ "### Inverted Pendulum" @@ -3325,7 +3333,7 @@ "outputs": [], "source": [ "max_steps = 20\n", - "env = create_inverted_pendulum_environment(max_steps=max_steps)\n", + "env = create_inverted_pendulum_environment(render_mode=render_mode, max_steps=max_steps)\n", "controller = ProportionalController(K=20)\n", "results = simulate_environment(env, max_steps=max_steps, controller=controller)\n", "observations = results.observations\n", @@ -3375,7 +3383,7 @@ }, "outputs": [], "source": [ - "env = create_inverted_pendulum_environment()\n", + "env = create_inverted_pendulum_environment(render_mode=render_mode)\n", "dt = env.dt\n", "T = np.arange(0, len(observations) * dt, dt)\n", "T = T[: len(observations) - 1]\n", @@ -3807,7 +3815,7 @@ }, "outputs": [], "source": [ - "env = create_mass_spring_damper_environment()\n", + "env = create_mass_spring_damper_environment(render_mode=render_mode)\n", "n_steps = 100\n", "T = np.arange(0, n_steps) * env.dt\n", "x0 = np.zeros(closed_loop.nstates)\n", @@ -3872,7 +3880,7 @@ "outputs": [], "source": [ "max_steps = 100\n", - "env = create_mass_spring_damper_environment(max_steps=max_steps)\n", + "env = create_mass_spring_damper_environment(render_mode=render_mode, max_steps=max_steps)\n", "controller = FullStateFeedbackController(K=K, kr=kr, reference=0.1)\n", "results = simulate_environment(env, max_steps=max_steps, controller=controller)\n", "observations = results.observations\n", @@ -3883,13 +3891,14 @@ "cell_type": "code", "execution_count": null, "metadata": { + "scrolled": false, "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ - "media.show_video(results.frames, fps=1 / env.dt)" + "show_video(results.frames, fps=1 / env.dt)" ] }, { @@ -3924,9 +3933,7 @@ "metadata": { "slideshow": { "slide_type": "subslide" - }, - "solution2": "hidden", - "solution2_first": true + } }, "source": [ "## Solution" @@ -3938,7 +3945,8 @@ "slideshow": { "slide_type": "subslide" }, - "solution2": "hidden" + "solution2": "hidden", + "solution2_first": true }, "source": [ "### Inverted Pendulum" @@ -4021,7 +4029,7 @@ }, "outputs": [], "source": [ - "env = create_mass_spring_damper_environment()\n", + "env = create_mass_spring_damper_environment(render_mode=render_mode)\n", "n_steps = 100\n", "T = np.arange(0, n_steps) * env.dt\n", "x0 = np.zeros(closed_loop.nstates)\n", @@ -4092,7 +4100,7 @@ "outputs": [], "source": [ "max_steps = 500\n", - "env = create_inverted_pendulum_environment(max_steps=max_steps)\n", + "env = create_inverted_pendulum_environment(render_mode=render_mode, max_steps=max_steps)\n", "controller = FullStateFeedbackController(K=K, kr=kr, reference=0.0)\n", "results = simulate_environment(env, max_steps=max_steps, controller=controller)\n", "observations = results.observations\n", @@ -4110,7 +4118,7 @@ }, "outputs": [], "source": [ - "media.show_video(results.frames, fps=1 / env.dt)" + "show_video(results.frames, fps=1 / env.dt)" ] }, { @@ -4468,7 +4476,7 @@ "outputs": [], "source": [ "Kp = 1\n", - "Ki = 20\n", + "Ki = 50\n", "Kd = 0" ] }, @@ -4528,7 +4536,7 @@ "outputs": [], "source": [ "max_steps = 500\n", - "env = create_mass_spring_damper_environment(max_steps=max_steps)\n", + "env = create_mass_spring_damper_environment(render_mode=render_mode, max_steps=max_steps)\n", "r = 0.1\n", "controller = PIDController(Kp=Kp, Ki=Ki, Kd=Kd, reference=r, dt=env.dt)\n", "results = simulate_environment(env, max_steps=max_steps, controller=controller)\n", @@ -4547,7 +4555,7 @@ }, "outputs": [], "source": [ - "media.show_video(results.frames, fps=1 / env.dt)" + "show_video(results.frames, fps=1 / env.dt)" ] }, { @@ -4653,9 +4661,9 @@ }, "outputs": [], "source": [ - "Kp = 600\n", + "Kp = 308\n", "Ki = 0\n", - "Kd = 50" + "Kd = 40" ] }, { @@ -4714,7 +4722,7 @@ "outputs": [], "source": [ "max_steps = 500\n", - "env = create_inverted_pendulum_environment(max_steps=max_steps)\n", + "env = create_inverted_pendulum_environment(render_mode=render_mode, max_steps=max_steps)\n", "r = 0.0\n", "controller = PIDController(Kp=Kp, Ki=Ki, Kd=Kd, reference=r, dt=env.dt)\n", "results = simulate_environment(env, max_steps=max_steps, controller=controller)\n", @@ -4733,7 +4741,7 @@ }, "outputs": [], "source": [ - "media.show_video(results.frames, fps=1 / env.dt)" + "show_video(results.frames, fps=1 / env.dt)" ] }, { diff --git a/notebooks/nb_30_ControlAndPlanning.ipynb b/notebooks/nb_30_ControlAndPlanning.ipynb index 15bcd2d5..11ea9637 100644 --- a/notebooks/nb_30_ControlAndPlanning.ipynb +++ b/notebooks/nb_30_ControlAndPlanning.ipynb @@ -4,7 +4,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "hide_input": true, + "hide_input": false, "init_cell": true, "slideshow": { "slide_type": "skip" @@ -66,18 +66,15 @@ "outputs": [], "source": [ "%autoreload\n", + "import os\n", "import warnings\n", "from dataclasses import dataclass\n", "from typing import Protocol\n", "\n", "import casadi\n", - "import control as ct\n", "import do_mpc\n", "import gymnasium as gym\n", "import matplotlib.pyplot as plt\n", - "import matplotx\n", - "import mediapy as media\n", - "import mujoco\n", "import numpy as np\n", "import seaborn as sns\n", "from gymnasium import Env\n", @@ -99,11 +96,14 @@ " animate_inverted_pendulum_simulation,\n", " animate_full_inverted_pendulum_simulation,\n", " display_array,\n", + " show_video,\n", ")\n", "\n", "warnings.simplefilter(\"ignore\", UserWarning)\n", "sns.set_theme()\n", - "plt.rcParams[\"figure.figsize\"] = [9, 5]" + "plt.rcParams[\"figure.figsize\"] = [9, 5]\n", + "# This is needed because inside docker the rendering of mujoco environments may not work.\n", + "render_mode = \"rgb_array\" if os.environ.get(\"DISPLAY\") else None" ] }, { @@ -313,12 +313,14 @@ } }, "source": [ + "
\n", "
\n", " \n", "
\n", " Deterministic N-stage optimal control problem.\n", "
\n", - "
" + "\n", + "
" ] }, { @@ -420,12 +422,14 @@ } }, "source": [ + "
\n", "
\n", " \n", "
\n", " Transition graph for a deterministic discrete system.\n", "
\n", - "
" + "\n", + "
" ] }, { @@ -447,13 +451,60 @@ } }, "source": [ - "Common moptimal control methods are Dynamic Programming (DP), Pontryagin’s Minimum Principle (PMP), and Hamilton-Jacobi-Bellman (HJB) equations.\n", + "Optimal control problems solving methods can be classified in three main families:\n", + "Dynamic Programming (DP), Indirect Methods based on calculus of variation and Direct Methods.\n", "\n", "- DP is helpful where the number of states is limited and the dynamics are known. It divides an optimal control issue into smaller subproblems and recursively solves each.\n", "\n", - "- PMP, another optimal control method, employs the Hamiltonian of the system to find the optimal control input. Problems involving continuous states and control inputs benefit most from it.\n", + "- Indirect methods rely on Pontryagin’s Minimum Principle (PMP) to derive the necessary conditions\n", + " for optimality. This method uses the Hamiltonian of the system to reduce the global optimal control problem\n", + " to the solution of a system of $2N$ equations given in the form of a two-point boundary value problem (BVP). \n", + " Problems involving continuous states and control inputs benefit most from it. " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "- Direct methods rely on the discretization of the original optimal control problem which is then transcribed to\n", + " a nonlinear programming problem (NLP) solved numerically using a well-established optimisation method.\n", + " \n", + " There are many direct methods. They differ on how the variables (i.e. control and states) are discretised \n", + " and on how the continuous time dynamics is approximated.\n", + " \n", + " In the case of shooting and multiple shooting the control are parameterised\n", + " with piecewise linear functions and the differential equations\n", + " are solved via numerical integration. These approaches make use of robust\n", + " and available ordinary differential equations solvers\n", + " but need sensitivity analysis to compute the jacobians\n", + " of the continuity and boundary conditions with respect to the initial and intermediate conditions.\n", "\n", - "- Another optimal control algorithm is the HJB equation which uses partial differential equations to find the value function of the system. HJB is also useful for problems with continuous states and control inputs." + " In the case of state and control parameterisation (direct collocation),\n", + " both states and controls are approximated with polynomial functions,\n", + " therefore the continuous time differential equations are converted into algebraic constraints." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "
\n", + "
\n", + " \n", + "
\n", + " Classification of different methods to solve optimal control problems and related formulations and\n", + "solution algorithms [Biral Notes 2016].\n", + "
\n", + "
\n", + "
" ] }, { @@ -564,12 +615,23 @@ "$$" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### State-Space Matrices" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -589,12 +651,23 @@ "D = np.zeros(2)" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Model, States and Control inputs" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -606,12 +679,23 @@ "mass_spring_damper.setup(A, B, C, D)" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Discretization" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -660,66 +744,7 @@ } }, "source": [ - "Application of Newtonian physics to this system leads to the [following model](https://sharpneat.sourceforge.io/research/cart-pole/cart-pole-equations.html):\n", - "\n", - "$$\n", - "\\ddot{y}(t) = \\frac{1}{m\\cos^2\\theta(t) - (1+k)(M+m)} \\left[\n", - "mg\\sin\\theta(t)\\cos\\theta(t) - (1+k)(\\gamma f(t) + ml\\dot\\theta^2(t) \\sin\\theta(t) - \\mu_c \\dot{y}(t)) - \\cfrac{\\mu_p\\cos\\theta(t)}{l}\\dot{\\theta}(t)\n", - "\\right]\\\\\n", - "\\ddot\\theta(t) = \\frac{1}{(1+k)(M+m)l - ml\\cos^2\\theta(t)} \\left[\n", - "(M+m)g\\sin\\theta(t) - \\cos\\theta(t) (\\gamma f(t) + ml\\dot\\theta^2(t)\\sin\\theta(t) - \\mu_c \\dot{y}(t)) - \\cfrac{(M+m)\\mu_p}{ml} \\dot{\\theta}(t)\n", - "\\right]\n", - "$$\n", - "\n", - "with $k = \\frac{1}{3}.$" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "slideshow": { - "slide_type": "subslide" - } - }, - "source": [ - "We can convert this to state-space form with input $u(t) = f(t)$ and output\n", - "$y(t)$; by introducing:\n", - "\n", - "$$\n", - "X(t) = \\begin{bmatrix}\n", - "x_1(t) \\\\ x_2(t) \\\\ x_3(t) \\\\ x_4(t)\n", - "\\end{bmatrix}\n", - "= \\begin{bmatrix}\n", - "y(t) \\\\ \\theta(t) \\\\ \\dot{y}(t) \\\\ \\dot{\\theta}(t) \n", - "\\end{bmatrix}\n", - "$$" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "slideshow": { - "slide_type": "subslide" - } - }, - "source": [ - "The system has the following full state-space model:\n", - "\n", - "$$\n", - "\\dot{X}(t) = \\begin{bmatrix}\n", - "\\dot{x_1}(t) \\\\ \\dot{x_2}(t) \\\\ \\dot{x_3}(t) \\\\ \\dot{x_4}(t)\n", - "\\end{bmatrix} =\n", - "\\begin{bmatrix}\n", - "x_2(t) \\\\\n", - "x_4(t) \\\\\n", - "\\frac{1}{m\\cos^2 x_2(t) - (1+k)(M+m)} \\left[\n", - "mg\\sin x_2(t)\\cos x_2(t) - (1+k)(\\gamma u(t) + ml x_4^2(t) \\sin x_2(t) - \\mu_c x_3(t)) - \\cfrac{\\mu_p \\cos x_2(t)}{l} x_4(t)\n", - "\\right]\\\\\n", - "\\frac{1}{(1+k)(M+m)l - ml\\cos^2 x_2(t)} \\left[\n", - "(M+m)g\\sin x_2(t) - \\cos x_2(t) (\\gamma u(t) + ml x_4^2(t)\\sin x_2(t) - \\mu_c x_3(t)) - \\cfrac{(M+m)\\mu_p}{ml} x_4(t)\n", - "\\right]\n", - "\\end{bmatrix}\n", - "$$" + "### Linearized Model" ] }, { @@ -730,7 +755,7 @@ } }, "source": [ - "And the following partial (pendulum angle and angular velocity only) linearized (around $\\theta = 0$ and $\\dot{\\theta} = 0$) state space model:\n", + "The system has the following partial (pendulum angle and angular velocity only) linearized (around $\\theta = 0$ and $\\dot{\\theta} = 0$) state space model:\n", "\n", "$$\n", "A = \\begin{bmatrix}\n", @@ -749,12 +774,23 @@ "$$" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### State-Space Matrices" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -785,12 +821,23 @@ "D = np.zeros(2)" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Model, States and Control inputs" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -808,6 +855,17 @@ "slide_type": "subslide" } }, + "source": [ + "#### Auxiliary Expressions" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, "source": [ "We also define the pendulum's kinetic and potential energies:\n", "\n", @@ -821,7 +879,7 @@ "execution_count": null, "metadata": { "slideshow": { - "slide_type": "fragment" + "slide_type": "subslide" } }, "outputs": [], @@ -838,12 +896,23 @@ "inverted_pendulum_lin.set_expression(\"E_potential\", E_pot);" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Model Setup" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -851,12 +920,23 @@ "inverted_pendulum_lin.setup(A, B, C, D)" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Discretization" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -873,15 +953,98 @@ } }, "source": [ + "### Non-Linear Model\n", + "\n", "We also define the full non-linear model" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "source": [ + "Application of Newtonian physics to this system leads to the [following model](https://sharpneat.sourceforge.io/research/cart-pole/cart-pole-equations.html):\n", + "\n", + "$$\n", + "\\ddot{y}(t) = \\frac{1}{m\\cos^2\\theta(t) - (1+k)(M+m)} \\left[\n", + "mg\\sin\\theta(t)\\cos\\theta(t) - (1+k)(\\gamma f(t) + ml\\dot\\theta^2(t) \\sin\\theta(t) - \\mu_c \\dot{y}(t)) - \\cfrac{\\mu_p\\cos\\theta(t)}{l}\\dot{\\theta}(t)\n", + "\\right]\\\\\n", + "\\ddot\\theta(t) = \\frac{1}{(1+k)(M+m)l - ml\\cos^2\\theta(t)} \\left[\n", + "(M+m)g\\sin\\theta(t) - \\cos\\theta(t) (\\gamma f(t) + ml\\dot\\theta^2(t)\\sin\\theta(t) - \\mu_c \\dot{y}(t)) - \\cfrac{(M+m)\\mu_p}{ml} \\dot{\\theta}(t)\n", + "\\right]\n", + "$$\n", + "\n", + "with $k = \\frac{1}{3}.$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "We can convert this to state-space form with input $u(t) = f(t)$ and output\n", + "$y(t)$; by introducing:\n", + "\n", + "$$\n", + "X(t) = \\begin{bmatrix}\n", + "x_1(t) \\\\ x_2(t) \\\\ x_3(t) \\\\ x_4(t)\n", + "\\end{bmatrix}\n", + "= \\begin{bmatrix}\n", + "y(t) \\\\ \\theta(t) \\\\ \\dot{y}(t) \\\\ \\dot{\\theta}(t) \n", + "\\end{bmatrix}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "The system has the following full state-space model:\n", + "\n", + "$$\n", + "\\dot{X}(t) = \\begin{bmatrix}\n", + "\\dot{x_1}(t) \\\\ \\dot{x_2}(t) \\\\ \\dot{x_3}(t) \\\\ \\dot{x_4}(t)\n", + "\\end{bmatrix} =\n", + "\\begin{bmatrix}\n", + "x_2(t) \\\\\n", + "x_4(t) \\\\\n", + "\\frac{1}{m\\cos^2 x_2(t) - (1+k)(M+m)} \\left[\n", + "mg\\sin x_2(t)\\cos x_2(t) - (1+k)(\\gamma u(t) + ml x_4^2(t) \\sin x_2(t) - \\mu_c x_3(t)) - \\cfrac{\\mu_p \\cos x_2(t)}{l} x_4(t)\n", + "\\right]\\\\\n", + "\\frac{1}{(1+k)(M+m)l - ml\\cos^2 x_2(t)} \\left[\n", + "(M+m)g\\sin x_2(t) - \\cos x_2(t) (\\gamma u(t) + ml x_4^2(t)\\sin x_2(t) - \\mu_c x_3(t)) - \\cfrac{(M+m)\\mu_p}{ml} x_4(t)\n", + "\\right]\n", + "\\end{bmatrix}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Model, States and Control inputs" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -895,6 +1058,17 @@ "u = inverted_pendulum.set_variable(var_type=\"_u\", var_name=\"force\")" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### ODE" + ] + }, { "cell_type": "code", "execution_count": null, @@ -957,6 +1131,17 @@ "slide_type": "subslide" } }, + "source": [ + "#### Auxiliary Expressions" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, "source": [ "And kinetic and potential energies:\n", "\n", @@ -974,7 +1159,7 @@ "execution_count": null, "metadata": { "slideshow": { - "slide_type": "fragment" + "slide_type": "subslide" } }, "outputs": [], @@ -997,12 +1182,23 @@ "inverted_pendulum.set_expression(\"E_potential\", E_pot);" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Model Setup" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -1083,6 +1279,7 @@ " observation, _ = env.reset()\n", " actions = []\n", " observations = [observation]\n", + " frames = []\n", "\n", " for _ in range(max_steps):\n", " action = controller.act(observation)\n", @@ -1093,7 +1290,8 @@ "\n", " # Check if we need to stop the simulation\n", " if terminated or truncated:\n", - " frames = env.render()\n", + " if env.render_mode is not None:\n", + " frames = env.render()\n", " env.reset()\n", " break\n", " env.close()\n", @@ -1108,6 +1306,17 @@ " )" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Simulators" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1333,6 +1542,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "scrolled": false, "slideshow": { "slide_type": "subslide" } @@ -1415,24 +1625,24 @@ "\n", "$$\n", "\\begin{array}\\\\\n", - "V(\\text{ABDF}) &= l(\\text{ABDF}, \\text{ABDFG}) &= 1\\\\\n", - "V(\\text{ABE}) &= l(\\text{ABE}, \\text{ABEG}) &= 4\\\\\n", - "V(\\text{ACF}) &= l(\\text{ACF}, \\text{ACFG}) &= 1\\\\\n", - "V(\\text{ADF}) &= l(\\text{ADF}, \\text{ADFG}) &= 1\\\\\n", + "V(\\text{ABDF}) &= g(\\text{ABDF}, \\text{ABDFG}) &= 1\\\\\n", + "V(\\text{ABE}) &= g(\\text{ABE}, \\text{ABEG}) &= 4\\\\\n", + "V(\\text{ACF}) &= g(\\text{ACF}, \\text{ACFG}) &= 1\\\\\n", + "V(\\text{ADF}) &= g(\\text{ADF}, \\text{ADFG}) &= 1\\\\\n", "\\end{array}\n", "$$\n", "\n", "$$\n", "\\begin{array}\\\\\n", - "V(\\text{ABD}) &= \\min \\left[ l(\\text{ABD}, \\text{ABDG}), l(\\text{ABD}, \\text{ABDF}) + V(\\text{ABDF}) \\right]\n", + "V(\\text{ABD}) &= \\min \\left[ g(\\text{ABD}, \\text{ABDG}), g(\\text{ABD}, \\text{ABDF}) + V(\\text{ABDF}) \\right]\n", "&= \\min \\left[ 8, 5 + 1 \\right] &= 6\n", "\\\\\n", - "V(\\text{AB}) &= \\min \\left[ l(\\text{AB}, \\text{ABD}) + V(\\text{ABD}), l(\\text{AB}, \\text{ABE}) + V(\\text{ABE}) \\right]\n", - "&= \\min \\left[ 9 + 5, 6 + 4 \\right] &= 10\n", + "V(\\text{AB}) &= \\min \\left[ g(\\text{AB}, \\text{ABD}) + V(\\text{ABD}), g(\\text{AB}, \\text{ABE}) + V(\\text{ABE}) \\right]\n", + "&= \\min \\left[ 9 + 6, 1 + 4 \\right] &= 5\n", "\\\\\n", - "V(\\text{AC}) &= l(\\text{AC}, \\text{ACF}) + V(\\text{ACF}) &= 2 + 1 &= 3\n", + "V(\\text{AC}) &= g(\\text{AC}, \\text{ACF}) + V(\\text{ACF}) &= 2 + 1 &= 3\n", "\\\\\n", - "V(\\text{AD}) &= \\min \\left[ l(\\text{AD}, \\text{ADF}) + V(\\text{ADF}), l(\\text{AD}, \\text{ADG})) \\right]\n", + "V(\\text{AD}) &= \\min \\left[ g(\\text{AD}, \\text{ADF}) + V(\\text{ADF}), g(\\text{AD}, \\text{ADG})) \\right]\n", "&= \\min \\left[ 5 + 1, 8 \\right] &= 6\n", "\\\\\n", "\\end{array}\n", @@ -1441,14 +1651,14 @@ "$$\n", "\\begin{array}\\\\\n", "V(\\text{A}) &= \\min \\left[\n", - "l(\\text{A}, \\text{AB}) + V(\\text{AB}), l(\\text{A}, \\text{AC}) + V(\\text{AC}), l(\\text{A}, \\text{AD}) + V(\\text{AD})\n", + "g(\\text{A}, \\text{AB}) + V(\\text{AB}), g(\\text{A}, \\text{AC}) + V(\\text{AC}), g(\\text{A}, \\text{AD}) + V(\\text{AD})\n", "\\right]\n", - "&= \\min \\left[ 1 + 10, 5 + 3, 3 + 6 \\right] &= 8\n", + "&= \\min \\left[ 1 + 5, 5 + 3, 3 + 6 \\right] &= 6\n", "\\\\\n", "\\end{array}\n", "$$\n", "\n", - "The shortest-path is aCFG." + "The shortest-path is ABEG." ] }, { @@ -1492,12 +1702,10 @@ "\\mathbf{x}_{t+1} = A \\mathbf{x}_t + B \\mathbf{u}_t\n", "$$\n", "\n", - "where $x\\in \\mathbb {R} ^{n}$ (that is, $x$ is an $n$-dimensional real-valued vector) is the state of the system and $u\\in \\mathbb {R} ^{m}$ is the control input. Given a quadratic cost function for the system, defined as:\n", - "\n", - "### Finite-Horizon\n", + "where $x\\in \\mathbb {R} ^{n}$ (that is, $x$ is an $n$-dimensional real-valued vector) is the state of the system and $u\\in \\mathbb {R} ^{m}$ is the control input. Given a quadratic cost function for the system in the infinite-horizon case, defined as:\n", "\n", "$$\n", - "J_0(\\mathbf{x}_0, \\mathbf{u}) = \\frac{1}{2} \\mathbf{x}_N^T Q \\mathbf{x}_N + \\frac{1}{2} \\sum \\limits _{k = 0}^{N - 1} \\mathbf{x}_k^{T}Q \\mathbf{x}_k + \\mathbf{u}_k^{T} R \\mathbf{u}_k\n", + "J(\\mathbf{x}_0, \\mathbf{u}) = \\sum \\limits _{k = 0}^{\\infty} \\mathbf{x}_k^{T}Q \\mathbf{x}_k + \\mathbf{u}_k^{T} R \\mathbf{u}_k\n", "$$\n", "\n", "With $Q = Q^T \\succeq 0$, $R = R^T \\succeq 0$." @@ -1511,13 +1719,19 @@ } }, "source": [ - "### Infinite-Horizon\n", + "In both cases, the control law that minizes the cost is given by:\n", "\n", "$$\n", - "J(\\mathbf{x}_0, \\mathbf{u}) = \\frac{1}{2} \\sum \\limits _{k = 0}^{\\infty} \\mathbf{x}_k^{T}Q \\mathbf{x}_k + \\mathbf{u}_k^{T} R \\mathbf{u}_k\n", + "u_k = -K x_k\n", "$$\n", "\n", - "With $Q = Q^T \\succeq 0$, $R = R^T \\succeq 0$." + "With: $K = (R + B^T P B)^{-1} B^T P B$\n", + "\n", + "and $P$ is found by solving the discrete time algebraic Riccati equation (DARE):\n", + "\n", + "$$\n", + "Q + A^{T}PA-(A^{T}PB)(R+B^{T}PB)^{-1}(B^{T}PA) = P.\n", + "$$" ] }, { @@ -1528,25 +1742,14 @@ } }, "source": [ - "Let's solve this for the finite-horizon case using dynamic programming. We start by setting:\n", - "\n", - "$$\n", - "V_N(\\mathbf{x}_N) = J_N(\\mathbf{x}_N) = \\frac{1}{2} \\mathbf{x}_{N}^T Q \\mathbf{x}_{N} = \\frac{1}{2} \\mathbf{x}_{N}^T P \\mathbf{x}_{N}\n", - "$$\n", - "\n", - "And then proceed backward in time:\n", - "\n", - "$$\n", - "\\begin{array}\\\\\n", - "V_{N-1}(\\mathbf{x}_{N-1}) &=& \\displaystyle \\min_{\\mathbf{u}_{N-1}} J_{N-1}(\\mathbf{x}_{N-1}, \\mathbf{u}_{N-1})\\\\\n", - "&=& \\displaystyle \\min_{\\mathbf{u}_{N-1}} \\frac{1}{2} \\left(\n", - "\\mathbf{x}_{N-1}^{T} Q \\mathbf{x}_{N-1} + \\mathbf{u}_{N-1}^{T} R \\mathbf{u}_{N-1} + \\mathbf{x}_{N}^T P \\mathbf{x}_{N}\n", - "\\right) \\\\\n", - "&=& \\displaystyle \\min_{\\mathbf{u}_{N-1}} \\frac{1}{2} \\left(\n", - "\\mathbf{x}_{N-1}^{T} Q \\mathbf{x}_{N-1} + \\mathbf{u}_{N-1}^{T} R \\mathbf{u}_{N-1} + (A\\mathbf{x}_{N-1} + B\\mathbf{u}_{N-1})^T P (A\\mathbf{x}_{N-1} + B\\mathbf{u}_{N-1})\n", - "\\right)\n", - "\\end{array}\n", - "$$" + "
\n", + "
\n", + " \n", + "
\n", + " LQR Block Diagram.\n", + "
\n", + "
\n", + "
" ] }, { @@ -1557,20 +1760,57 @@ } }, "source": [ - "Taking the gradient with respect to $\\mathbf{u}_{N-1}$:\n", - "\n", - "$$\n", - "\\displaystyle \\nabla_{\\mathbf{u}_{N-1}} J_{N-1}(\\mathbf{x}_{N-1}, \\mathbf{u}_{N-1}) = \n", - "R \\mathbf{u}_{N-1} + B^T P (A \\mathbf{x}_{N-1} + B \\mathbf{u}_{N - 1}) = 0\n", - "$$\n", - "\n", - "Gives us the following optimal feedback control at step $N - 1$:\n", - "\n", - "$$\n", - "\\mathbf{u}^*_{N-1} = -(R + B^T P B)^{-1} B^T P B \\mathbf{x}_{N-1} = - K \\mathbf{x}_{N-1}\n", - "$$\n", - "\n", - "With $K = (R + B^T P B)^{-1} B^T P B$" + "### Mass-Spring-Damper" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Now, we design Linear Quadratic Regulator for the Mass-Spring-Damper model. First, we create an instance of the LQR class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [], + "source": [ + "lqr = do_mpc.controller.LQR(mass_spring_damper)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "We choose an infinite prediction horizon (`n_horizon = None`), the time step `t_step = 0.04s` second (which is the same as the environment's timestep)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [], + "source": [ + "env = create_mass_spring_damper_environment(render_mode=render_mode)\n", + "lqr.settings.t_step = env.dt\n", + "lqr.settings.n_horizon = None # infinite horizon" ] }, { @@ -1581,45 +1821,24 @@ } }, "source": [ - "The optimal cost-to-go is then:\n", - "\n", - "$$\n", - "\\begin{array}\\\\\n", - "V_{N-1}(\\mathbf{x}_{N-1}) &= J_{N-1}(\\mathbf{x}_{N-1}, \\mathbf{u}^*_{N-1}) \\\\\n", - "&= \\frac{1}{2} \\left(\n", - "\\mathbf{x}_{N-1}^{T} Q \\mathbf{x}_{N-1} + \\mathbf{u}_{N-1}^{*T} R \\mathbf{u}^{*}_{N-1} + \\mathbf{x}_{N}^T P \\mathbf{x}_{N}\n", - "\\right)\\\\\n", - "&= \\frac{1}{2} \\left(\n", - "\\mathbf{x}_{N-1}^{T} Q \\mathbf{x}_{N-1} + \\mathbf{u}_{N-1}^{*T} R \\mathbf{u}^{*}_{N-1} + (A\\mathbf{x}_{N-1} + B\\mathbf{u}_{N-1})^T P (A\\mathbf{x}_{N-1} + B\\mathbf{u}_{N-1})\n", - "\\right)\\\\\n", - "&= \\frac{1}{2} \\left(\n", - "\\mathbf{x}_{N-1}^{T} Q \\mathbf{x}_{N-1} + \\mathbf{x}_{N-1}^{T} K^T R K \\mathbf{x}_{N-1} + \\mathbf{x}_{N-1}^T(A - BK)^T P (A - BK)\\mathbf{x}_{N-1}\n", - "\\right)\\\\\n", - "&= \\frac{1}{2} \n", - "\\mathbf{x}_{N-1}^{T} \\left(\n", - "Q + A^{T}PA-(A^{T}PB)(R+B^{T}PB)^{-1}(B^{T}PA)\n", - "\\right) \\mathbf{x}_{N-1}\n", - "\\\\\n", - "&:= \\frac{1}{2} \\mathbf{x}_{N-1}^T P \\mathbf{x}_{N-1}\\\\\n", - "\\end{array}\n", - "$$\n", - "\n", - "The last step is needed to ensure that the derivation works recursively for all steps." + "The goal is to drive the Mass-Spring-Damper system to the desired position. For that we define the following objective." ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, + "outputs": [], "source": [ - "and $P$ is found by solving the discrete time algebraic Riccati equation (DARE):\n", - "\n", - "$$\n", - "Q + A^{T}PA-(A^{T}PB)(R+B^{T}PB)^{-1}(B^{T}PA) = P.\n", - "$$" + "Q = np.diag([1000, 0])\n", + "R = np.diag([0])\n", + "display_array(\"Q\", Q)\n", + "display_array(\"R\", R)\n", + "lqr.set_objective(Q=Q, R=R)" ] }, { @@ -1630,52 +1849,31 @@ } }, "source": [ - "
\n", - " \n", - "
\n", - " LQR Block Diagram.\n", - "
\n", - "
" + "We then complete the LQR setup." ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, + "outputs": [], "source": [ - "### Mass-Spring-Damper" + "lqr.setup()" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, - "outputs": [], "source": [ - "# Initialize the controller\n", - "lqr = do_mpc.controller.LQR(mass_spring_damper)\n", - "\n", - "# Initialize the parameters\n", - "env = create_mass_spring_damper_environment()\n", - "lqr.settings.t_step = env.dt\n", - "lqr.settings.n_horizon = None # infinite horizon\n", - "\n", - "# Setting the objective\n", - "Q = np.diag([100, 0])\n", - "R = np.diag([1e-3])\n", - "display_array(\"Q\", Q)\n", - "display_array(\"R\", R)\n", - "lqr.set_objective(Q=Q, R=R)\n", - "\n", - "# lqr setup\n", - "lqr.setup()" + "Finally, we set the desired setpoint." ] }, { @@ -1688,7 +1886,6 @@ }, "outputs": [], "source": [ - "# Define set point\n", "xss = np.array([0.1, 0.0]).reshape(-1, 1)\n", "lqr.set_setpoint(xss)" ] @@ -1701,7 +1898,9 @@ } }, "source": [ - "#### SImulation" + "#### SImulation\n", + "\n", + "Now, we simulate the closed-loop system for 50 steps:" ] }, { @@ -1709,7 +1908,7 @@ "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -1744,7 +1943,9 @@ } }, "source": [ - "#### Evaluation" + "#### Evaluation\n", + "\n", + "Finally, we evaluate the controller on the actual environment." ] }, { @@ -1752,7 +1953,7 @@ "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -1777,7 +1978,7 @@ "outputs": [], "source": [ "max_steps = 100\n", - "env = create_mass_spring_damper_environment(max_steps=max_steps)\n", + "env = create_mass_spring_damper_environment(render_mode=render_mode, max_steps=max_steps)\n", "controller = LQRController(lqr)\n", "results = simulate_environment(env, max_steps=max_steps, controller=controller)" ] @@ -1792,7 +1993,7 @@ }, "outputs": [], "source": [ - "media.show_video(results.frames, fps=1 / env.dt)" + "show_video(results.frames, fps=1 / env.dt)" ] }, { @@ -1829,7 +2030,9 @@ "metadata": { "slideshow": { "slide_type": "subslide" - } + }, + "solution2": "hidden", + "solution2_first": true }, "source": [ "## Solution" @@ -1841,8 +2044,7 @@ "slideshow": { "slide_type": "subslide" }, - "solution2": "hidden", - "solution2_first": true + "solution2": "hidden" }, "source": [ "### Inverted Pendulum" @@ -1859,22 +2061,55 @@ }, "outputs": [], "source": [ - "# Initialize the controller\n", - "lqr = do_mpc.controller.LQR(inverted_pendulum_lin)\n", - "\n", - "# Initialize the parameters\n", + "lqr = do_mpc.controller.LQR(inverted_pendulum_lin)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "subslide" + }, + "solution2": "hidden" + }, + "outputs": [], + "source": [ "env = create_inverted_pendulum_environment()\n", "lqr.settings.t_step = env.dt\n", - "lqr.settings.n_horizon = None # infinite horizon\n", - "\n", + "lqr.settings.n_horizon = None # infinite horizon" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "subslide" + }, + "solution2": "hidden" + }, + "outputs": [], + "source": [ "# Setting the objective\n", "Q = np.diag([100, 1])\n", "R = np.diag([10])\n", "display_array(\"Q\", Q)\n", "display_array(\"R\", R)\n", - "lqr.set_objective(Q=Q, R=R)\n", - "\n", - "# lqr setup\n", + "lqr.set_objective(Q=Q, R=R)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "subslide" + }, + "solution2": "hidden" + }, + "outputs": [], + "source": [ "lqr.setup()" ] }, @@ -1931,6 +2166,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "scrolled": false, "slideshow": { "slide_type": "subslide" }, @@ -1985,7 +2221,7 @@ "outputs": [], "source": [ "max_steps = 500\n", - "env = create_inverted_pendulum_environment(max_steps=max_steps)\n", + "env = create_inverted_pendulum_environment(render_mode=render_mode, max_steps=max_steps)\n", "controller = LQRController(lqr)\n", "results = simulate_environment(env, max_steps=max_steps, controller=controller)" ] @@ -2001,7 +2237,7 @@ }, "outputs": [], "source": [ - "media.show_video(results.frames, fps=1 / env.dt)" + "show_video(results.frames, fps=1 / env.dt)" ] }, { @@ -2098,12 +2334,14 @@ } }, "source": [ + "
\n", "
\n", " \n", "
\n", " MPC Block Diagram.\n", "
\n", - "
" + "\n", + "
" ] }, { @@ -2226,8 +2464,8 @@ "\n", "$$\n", "\\begin{array}\\\\\n", - "V_N(\\mathbf{x}_N) &= l_f(\\mathbf{x}_N)\\\\\n", - "V_{t}(\\mathbf{x}_t) &= \\displaystyle \\min_u {l(\\mathbf{x}_{t}, \\mathbf{u}_{t}) + V_{t+1}(f(\\mathbf{x}_{t}, \\mathbf{u}_{N-t}))}\\\\\n", + "V_N(\\mathbf{x}_N) &= g_f(\\mathbf{x}_N)\\\\\n", + "V_{t}(\\mathbf{x}_t) &= \\displaystyle \\min_u {g(\\mathbf{x}_{t}, \\mathbf{u}_{t}) + V_{t+1}(f(\\mathbf{x}_{t}, \\mathbf{u}_{N-t}))}\\\\\n", "V_{t}(\\mathbf{x}_t) &= \\displaystyle \\min_u Q_{t}(\\mathbf{x}_{t}, \\mathbf{u}_{t})\n", "\\end{array}\\\\\n", "$$\n", @@ -2473,22 +2711,56 @@ "| [casadi.bilin(A, x, y)](https://web.casadi.org/python-api/#casadi.casadi.bilin) | Bilinear Form | $x^T A y$ |" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Controller\n", + "\n", + "First, we create an instance of the MPC class is generated with the Mass-Spring-Damper prediction model defined above." + ] + }, { "cell_type": "code", "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [], + "source": [ + "mpc = do_mpc.controller.MPC(mass_spring_damper)" + ] + }, + { + "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, + "source": [ + "We choose the finite prediction horizon `n_horizon = 20`, the time step `t_step = 0.04s` to be the same as the environment's time step. We also set the parameters of the applied discretization scheme orthogonal collocation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, "outputs": [], "source": [ "env = create_mass_spring_damper_environment()\n", - "\n", "mpc_params = {\n", " \"n_horizon\": 20,\n", - " \"n_robust\": 0,\n", - " \"open_loop\": 0,\n", " \"t_step\": env.dt,\n", " \"state_discretization\": \"collocation\",\n", " \"collocation_type\": \"radau\",\n", @@ -2498,29 +2770,87 @@ " # Use MA27 linear solver in ipopt for faster calculations:\n", " \"nlpsol_opts\": {\"ipopt.linear_solver\": \"mumps\"},\n", "}\n", + "mpc.set_param(**mpc_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Objective\n", "\n", + "The control objective is to move the mass to a desired position (`0.1`) and keep it there." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [], + "source": [ "xss = np.array([0.1, 0.0])\n", "distance_cost = 100 * casadi.norm_2(mass_spring_damper.x.cat - xss)\n", "terminal_cost = distance_cost\n", "stage_cost = distance_cost\n", - "input_penalty = 1e-2\n", "print(f\"{stage_cost=}\")\n", "print(f\"{terminal_cost=}\")\n", - "\n", - "mpc = do_mpc.controller.MPC(mass_spring_damper)\n", - "mpc.set_param(**mpc_params)\n", - "mpc.set_objective(mterm=terminal_cost, lterm=stage_cost)\n", - "mpc.set_rterm(force=input_penalty)" + "mpc.set_objective(mterm=terminal_cost, lterm=stage_cost)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "We also restrict the input force." ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [], + "source": [ + "force_penalty = 1e-2\n", + "mpc.set_rterm(force=force_penalty)" + ] + }, + { + "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, + "source": [ + "#### Constraints\n", + "\n", + "We apply constraints on the force. In this case, there is only an upper and lower bounds for the force." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, "outputs": [], "source": [ "# lower and upper bounds of the input\n", @@ -2529,12 +2859,25 @@ "mpc.bounds[\"upper\", \"_u\", \"force\"] = u_max" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Setup\n", + "\n", + "We can now setup the controller." + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -2550,7 +2893,9 @@ } }, "source": [ - "#### Simulation" + "#### Simulation\n", + "\n", + "We set the initial state and simulate the closed-loop for 100 steps." ] }, { @@ -2596,7 +2941,9 @@ } }, "source": [ - "#### Evaluation" + "#### Evaluation\n", + "\n", + "Finally, we evaluate the controller on the actual environment." ] }, { @@ -2632,7 +2979,7 @@ "source": [ "%%capture\n", "max_steps = 100\n", - "env = create_mass_spring_damper_environment(max_steps=max_steps)\n", + "env = create_mass_spring_damper_environment(render_mode=render_mode, max_steps=max_steps)\n", "controller = MPCController(mpc)\n", "results = simulate_environment(env, max_steps=max_steps, controller=controller)" ] @@ -2647,7 +2994,7 @@ }, "outputs": [], "source": [ - "media.show_video(results.frames, fps=1 / env.dt)" + "show_video(results.frames, fps=1 / env.dt)" ] }, { @@ -2685,7 +3032,9 @@ "metadata": { "slideshow": { "slide_type": "subslide" - } + }, + "solution2": "hidden", + "solution2_first": true }, "source": [ "## Solution" @@ -2697,11 +3046,62 @@ "slideshow": { "slide_type": "subslide" }, - "solution2": "hidden", - "solution2_first": true + "solution2": "hidden" + }, + "source": [ + "#### Controller" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + }, + "solution2": "hidden" + }, + "outputs": [], + "source": [ + "mpc = do_mpc.controller.MPC(inverted_pendulum_lin)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "subslide" + }, + "solution2": "hidden" + }, + "outputs": [], + "source": [ + "env = create_inverted_pendulum_environment()\n", + "mpc_params = {\n", + " \"n_horizon\": 50,\n", + " \"t_step\": env.dt,\n", + " \"state_discretization\": \"collocation\",\n", + " \"collocation_type\": \"radau\",\n", + " \"collocation_deg\": 3,\n", + " \"collocation_ni\": 1,\n", + " \"store_full_solution\": True,\n", + " # Use MA27 linear solver in ipopt for faster calculations:\n", + " \"nlpsol_opts\": {\"ipopt.linear_solver\": \"mumps\"},\n", + "}\n", + "mpc.set_param(**mpc_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + }, + "solution2": "hidden" }, "source": [ - "### Linear Inverted Pendulum" + "#### Objective" ] }, { @@ -2709,50 +3109,57 @@ "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" }, "solution2": "hidden" }, "outputs": [], "source": [ - "env = create_inverted_pendulum_environment()\n", - "\n", "xss = np.array([0.0, 0.0])\n", "distance_cost = casadi.bilin(np.diag([100, 1]), inverted_pendulum_lin.x.cat - xss)\n", "terminal_cost = distance_cost\n", "stage_cost = distance_cost\n", - "input_penalty = 0.0\n", "print(f\"{stage_cost=}\")\n", "print(f\"{terminal_cost=}\")\n", - "\n", - "mpc_params = {\n", - " \"n_horizon\": 50,\n", - " \"n_robust\": 0,\n", - " \"open_loop\": 0,\n", - " \"t_step\": env.dt,\n", - " \"state_discretization\": \"collocation\",\n", - " \"collocation_type\": \"radau\",\n", - " \"collocation_deg\": 3,\n", - " \"collocation_ni\": 1,\n", - " \"store_full_solution\": True,\n", - " # Use MA27 linear solver in ipopt for faster calculations:\n", - " \"nlpsol_opts\": {\"ipopt.linear_solver\": \"mumps\"},\n", - "}\n", - "mpc = do_mpc.controller.MPC(inverted_pendulum_lin)\n", - "mpc.set_param(**mpc_params)\n", - "mpc.set_objective(mterm=terminal_cost, lterm=stage_cost)\n", - "mpc.set_rterm(force=input_penalty)" + "mpc.set_objective(mterm=terminal_cost, lterm=stage_cost)" ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + }, + "solution2": "hidden" + }, + "outputs": [], + "source": [ + "force_penalty = 1e-4\n", + "mpc.set_rterm(force=force_penalty)" + ] + }, + { + "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" }, "solution2": "hidden" }, + "source": [ + "#### Constraints" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + }, + "solution2": "hidden" + }, "outputs": [], "source": [ "# lower and upper bounds of the input\n", @@ -2761,12 +3168,24 @@ "mpc.bounds[\"upper\", \"_u\", \"force\"] = u_max" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + }, + "solution2": "hidden" + }, + "source": [ + "#### Setup" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" }, "solution2": "hidden" }, @@ -2815,6 +3234,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "scrolled": false, "slideshow": { "slide_type": "subslide" }, @@ -2872,7 +3292,7 @@ "source": [ "%%capture\n", "max_steps = 500\n", - "env = create_inverted_pendulum_environment(max_steps=max_steps)\n", + "env = create_inverted_pendulum_environment(render_mode=render_mode, max_steps=max_steps)\n", "controller = MPCController(mpc)\n", "results = simulate_environment(env, max_steps=max_steps, controller=controller)" ] @@ -2888,7 +3308,7 @@ }, "outputs": [], "source": [ - "media.show_video(results.frames, fps=1 / env.dt)" + "show_video(results.frames, fps=1 / env.dt)" ] }, { @@ -2951,83 +3371,7 @@ "solution2_first": true }, "source": [ - "### Cart to Origin and Upright Pendulum" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "slideshow": { - "slide_type": "subslide" - }, - "solution2": "hidden" - }, - "outputs": [], - "source": [ - "env = create_inverted_pendulum_environment()\n", - "\n", - "xss = np.array([0, 0, 0, 0])\n", - "distance_cost = casadi.bilin(np.diag([1, 100, 0, 0]), inverted_pendulum.x.cat - xss)\n", - "terminal_cost = distance_cost\n", - "stage_cost = distance_cost\n", - "input_penalty = 0\n", - "print(f\"{stage_cost=}\")\n", - "print(f\"{terminal_cost=}\")\n", - "\n", - "mpc_params = {\n", - " \"n_horizon\": 100,\n", - " \"n_robust\": 0,\n", - " \"open_loop\": 0,\n", - " \"t_step\": env.dt,\n", - " \"state_discretization\": \"collocation\",\n", - " \"collocation_type\": \"radau\",\n", - " \"collocation_deg\": 3,\n", - " \"collocation_ni\": 1,\n", - " \"store_full_solution\": True,\n", - " # Use MA27 linear solver in ipopt for faster calculations:\n", - " \"nlpsol_opts\": {\"ipopt.linear_solver\": \"mumps\"},\n", - "}\n", - "\n", - "mpc = do_mpc.controller.MPC(inverted_pendulum)\n", - "mpc.set_param(**mpc_params)\n", - "mpc.set_objective(mterm=terminal_cost, lterm=stage_cost)\n", - "mpc.set_rterm(force=input_penalty)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "slideshow": { - "slide_type": "subslide" - }, - "solution2": "hidden" - }, - "outputs": [], - "source": [ - "# lower and upper bounds of the position\n", - "x_max = 1\n", - "mpc.bounds[\"lower\", \"_x\", \"position\"] = -x_max\n", - "mpc.bounds[\"upper\", \"_x\", \"position\"] = x_max\n", - "# lower and upper bounds of the input\n", - "u_max = 3\n", - "mpc.bounds[\"lower\", \"_u\", \"force\"] = -u_max\n", - "mpc.bounds[\"upper\", \"_u\", \"force\"] = u_max" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "slideshow": { - "slide_type": "subslide" - }, - "solution2": "hidden" - }, - "outputs": [], - "source": [ - "mpc.setup()" + "### Swing-up" ] }, { @@ -3039,7 +3383,7 @@ "solution2": "hidden" }, "source": [ - "#### Simulation" + "#### Controller" ] }, { @@ -3047,22 +3391,13 @@ "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" }, "solution2": "hidden" }, "outputs": [], "source": [ - "%%capture\n", - "mpc.reset_history()\n", - "inverted_pendulum_simulator.reset_history()\n", - "x0 = np.array([0.5, -0.1, 0.0, 0.0])\n", - "inverted_pendulum_simulator.x0 = x0\n", - "mpc.x0 = x0\n", - "mpc.set_initial_guess()\n", - "for k in range(100):\n", - " u0 = mpc.make_step(x0)\n", - " x0 = inverted_pendulum_simulator.make_step(u0)" + "mpc = do_mpc.controller.MPC(inverted_pendulum)" ] }, { @@ -3076,7 +3411,19 @@ }, "outputs": [], "source": [ - "animate_full_inverted_pendulum_simulation(mpc.data)" + "env = create_inverted_pendulum_environment()\n", + "mpc_params = {\n", + " \"n_horizon\": 100,\n", + " \"t_step\": env.dt,\n", + " \"state_discretization\": \"collocation\",\n", + " \"collocation_type\": \"radau\",\n", + " \"collocation_deg\": 3,\n", + " \"collocation_ni\": 1,\n", + " \"store_full_solution\": True,\n", + " # Use MA27 linear solver in ipopt for faster calculations:\n", + " \"nlpsol_opts\": {\"ipopt.linear_solver\": \"mumps\"},\n", + "}\n", + "mpc.set_param(**mpc_params)" ] }, { @@ -3088,7 +3435,7 @@ "solution2": "hidden" }, "source": [ - "#### Evaluation" + "#### Objective" ] }, { @@ -3096,21 +3443,18 @@ "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" }, "solution2": "hidden" }, "outputs": [], "source": [ - "class MPCController:\n", - " def __init__(self, mpc: do_mpc.controller.MPC) -> None:\n", - " self.mpc = mpc\n", - " self.mpc.reset_history()\n", - " self.mpc.x0 = np.zeros(4)\n", - " self.mpc.set_initial_guess()\n", - "\n", - " def act(self, observation: NDArray) -> NDArray:\n", - " return mpc.make_step(observation.reshape(-1, 1)).ravel()" + "energy_cost = inverted_pendulum.aux[\"E_kinetic\"] - inverted_pendulum.aux[\"E_potential\"]\n", + "terminal_cost = energy_cost\n", + "stage_cost = energy_cost\n", + "print(f\"{stage_cost=}\")\n", + "print(f\"{terminal_cost=}\")\n", + "mpc.set_objective(mterm=terminal_cost, lterm=stage_cost)" ] }, { @@ -3118,31 +3462,26 @@ "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" }, "solution2": "hidden" }, "outputs": [], "source": [ - "%%capture\n", - "max_steps = 200\n", - "env = create_inverted_pendulum_environment(max_steps=max_steps, cutoff_angle=np.inf)\n", - "controller = MPCController(mpc)\n", - "results = simulate_environment(env, max_steps=max_steps, controller=controller)" + "force_penalty = 0.1\n", + "mpc.set_rterm(force=force_penalty)" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" }, "solution2": "hidden" }, - "outputs": [], "source": [ - "media.show_video(results.frames, fps=1 / env.dt)" + "#### Constraints" ] }, { @@ -3150,86 +3489,32 @@ "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" }, "solution2": "hidden" }, "outputs": [], "source": [ - "animate_full_inverted_pendulum_simulation(mpc.data)" + "# lower and upper bounds of the position\n", + "x_max = 1\n", + "mpc.bounds[\"lower\", \"_x\", \"position\"] = -x_max\n", + "mpc.bounds[\"upper\", \"_x\", \"position\"] = x_max\n", + "# lower and upper bounds of the input\n", + "u_max = 3\n", + "mpc.bounds[\"lower\", \"_u\", \"force\"] = -u_max\n", + "mpc.bounds[\"upper\", \"_u\", \"force\"] = u_max" ] }, { "cell_type": "markdown", - "metadata": { - "slideshow": { - "slide_type": "subslide" - }, - "solution2": "hidden", - "solution2_first": true - }, - "source": [ - "### Swing-up" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "slideshow": { - "slide_type": "subslide" - }, - "solution2": "hidden" - }, - "outputs": [], - "source": [ - "env = create_inverted_pendulum_environment()\n", - "\n", - "energy_cost = inverted_pendulum.aux[\"E_kinetic\"] - inverted_pendulum.aux[\"E_potential\"]\n", - "terminal_cost = energy_cost\n", - "stage_cost = energy_cost\n", - "input_penalty = 0.1\n", - "print(f\"{stage_cost=}\")\n", - "print(f\"{terminal_cost=}\")\n", - "\n", - "mpc_params = {\n", - " \"n_horizon\": 100,\n", - " \"n_robust\": 0,\n", - " \"open_loop\": 0,\n", - " \"t_step\": env.dt,\n", - " \"state_discretization\": \"collocation\",\n", - " \"collocation_type\": \"radau\",\n", - " \"collocation_deg\": 3,\n", - " \"collocation_ni\": 1,\n", - " \"store_full_solution\": True,\n", - " # Use MA27 linear solver in ipopt for faster calculations:\n", - " \"nlpsol_opts\": {\"ipopt.linear_solver\": \"mumps\"},\n", - "}\n", - "mpc = do_mpc.controller.MPC(inverted_pendulum)\n", - "mpc.set_param(**mpc_params)\n", - "mpc.set_objective(mterm=terminal_cost, lterm=stage_cost)\n", - "mpc.set_rterm(force=input_penalty)" - ] - }, - { - "cell_type": "code", - "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" }, "solution2": "hidden" }, - "outputs": [], "source": [ - "# lower and upper bounds of the position\n", - "x_max = 1\n", - "mpc.bounds[\"lower\", \"_x\", \"position\"] = -x_max\n", - "mpc.bounds[\"upper\", \"_x\", \"position\"] = x_max\n", - "# lower and upper bounds of the input\n", - "u_max = 3\n", - "mpc.bounds[\"lower\", \"_u\", \"force\"] = -u_max\n", - "mpc.bounds[\"upper\", \"_u\", \"force\"] = u_max" + "#### Setup" ] }, { @@ -3237,7 +3522,7 @@ "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" }, "solution2": "hidden" }, @@ -3343,7 +3628,7 @@ "%%capture\n", "max_steps = 500\n", "env = create_inverted_pendulum_environment(\n", - " max_steps=max_steps, cutoff_angle=np.inf, initial_angle=-np.pi\n", + " render_mode=render_mode, max_steps=max_steps, cutoff_angle=np.inf, initial_angle=-np.pi\n", ")\n", "controller = MPCController(mpc)\n", "results = simulate_environment(env, max_steps=max_steps, controller=controller)" @@ -3360,7 +3645,7 @@ }, "outputs": [], "source": [ - "media.show_video(results.frames, fps=1 / env.dt)" + "show_video(results.frames, fps=1 / env.dt)" ] }, { @@ -3561,8 +3846,13 @@ "\n", "- [[Lars Nonlinear 2017]](#grune_nonlinear_2017-back) [Nonlinear model predictive control.](https://link.springer.com/chapter/10.1007/978-3-319-46024-6_3) Grüne, Lars, Jürgen Pannek, Lars Grüne, and Jürgen Pannek. Springer International Publishing, 2017.\n", "\n", + "- [[Biral Notes 2016]](#biral_notes_2016-back) [Notes on numerical methods for solving optimal control problems.](https://www.jstage.jst.go.jp/article/ieejjia/5/2/5_154/_article/-char/ja/) Biral, Francesco, Enrico Bertolazzi, and Paolo Bosetti. IEEJ Journal of Industry Applications 5, no. 2 (2016): 154-166.\n", + "\n", "- [[Bertsektas, Dimitri P, 2022]](#bertsekas_lessons_2022-back) [Lessons from AlphaZero for Optimal, Model Predictive, and Adaptive Control](http://web.mit.edu/dimitrib/www/LessonsfromAlphazero.pdf) - Dimitri P. Bertsektas. 2022.\n", - "- [[Spencer et al. 2023]](#spencer_optimal_2023-back) [AA 203: Optimal and Learning-Based Control.](https://stanfordasl.github.io/aa203/sp2223/) Optimal control solution techniques for systems with known and unknown dynamics. 2023." + "\n", + "- [[Spencer et al. 2023]](#spencer_optimal_2023-back) [AA 203: Optimal and Learning-Based Control.](https://stanfordasl.github.io/aa203/sp2223/) Optimal control solution techniques for systems with known and unknown dynamics. 2023.\n", + "\n", + "- [[Russ Tedrake, 2023]](#tedrake_underactuated_2023-back) [Underactuated Robotics: Algorithms for Walking, Running, Swimming, Flying, and Manipulation (Course Notes for MIT 6.832)](http://underactuated.mit.edu/index.html) Russ Tedrake. Downloaded on 24.10.2023." ] } ], diff --git a/notebooks/nb_40_RecentDevelopmentsInControl.ipynb b/notebooks/nb_40_RecentDevelopmentsInControl.ipynb index abf2c2f6..390e5c25 100644 --- a/notebooks/nb_40_RecentDevelopmentsInControl.ipynb +++ b/notebooks/nb_40_RecentDevelopmentsInControl.ipynb @@ -59,6 +59,7 @@ "outputs": [], "source": [ "%autoreload\n", + "import os\n", "import warnings\n", "from dataclasses import dataclass\n", "from typing import Protocol\n", @@ -80,11 +81,14 @@ " InvertedPendulumParameters,\n", " animate_full_inverted_pendulum_simulation,\n", " simulate_environment,\n", + " show_video\n", ")\n", "\n", "warnings.simplefilter(\"ignore\", UserWarning)\n", "sns.set_theme()\n", - "plt.rcParams[\"figure.figsize\"] = [9, 5]" + "plt.rcParams[\"figure.figsize\"] = [9, 5]\n", + "# This is needed because inside docker the rendering of mujoco environments may not work.\n", + "render_mode = \"rgb_array\" if os.environ.get(\"DISPLAY\") else None" ] }, { @@ -369,18 +373,40 @@ } }, "source": [ - "### Example - Inverted Pendulum\n", + "## Example - Inverted Pendulum\n", "\n", "In a real system, usually the model parameters cannot be determined exactly, what represents an important source of uncertainty. In this example, we consider that the mass of the pendulum and that of the cart are not known precisely \n", "and vary with respect to their nominal value." ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Model, States and Control inputs" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -394,12 +420,23 @@ "u = model.set_variable(var_type=\"_u\", var_name=\"force\")" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Parameters" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -411,18 +448,41 @@ "gamma = ip_parameters.gamma\n", "g = ip_parameters.g\n", "mu_p = ip_parameters.mu_p\n", - "mu_c = ip_parameters.mu_c\n", + "mu_c = ip_parameters.mu_c" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [], + "source": [ "# Uncertain parameters\n", "m = model.set_variable(\"_p\", \"m\")\n", "M = model.set_variable(\"_p\", \"M\")" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### ODE" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -441,7 +501,7 @@ "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -468,29 +528,47 @@ "model.set_rhs(\"position\", dpos)\n", "model.set_rhs(\"theta\", dtheta)\n", "model.set_rhs(\"velocity\", ddpos)\n", - "model.set_rhs(\"dtheta\", ddtheta)\n", - "model.setup()" + "model.set_rhs(\"dtheta\", ddtheta)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Setup" ] }, { "cell_type": "code", "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.setup()" + ] + }, + { + "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, + "source": [ + "### Controller" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, "outputs": [], "source": [ - "env = create_inverted_pendulum_environment()\n", - "\n", - "xss = np.array([0, 0, 0, 0])\n", - "distance_cost = casadi.bilin(np.diag([1, 100, 0, 0]), model.x.cat - xss)\n", - "terminal_cost = distance_cost\n", - "stage_cost = distance_cost\n", - "input_penalty = 0\n", - "print(f\"{stage_cost=}\")\n", - "print(f\"{terminal_cost=}\")" + "mpc = do_mpc.controller.MPC(model)" ] }, { @@ -498,15 +576,15 @@ "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], "source": [ + "env = create_inverted_pendulum_environment()\n", "mpc_params = {\n", " \"n_horizon\": 50,\n", " \"n_robust\": 1,\n", - " \"open_loop\": 0,\n", " \"t_step\": env.dt,\n", " \"state_discretization\": \"collocation\",\n", " \"collocation_type\": \"radau\",\n", @@ -516,21 +594,73 @@ " # Use MA27 linear solver in ipopt for faster calculations:\n", " \"nlpsol_opts\": {\"ipopt.linear_solver\": \"mumps\"},\n", "}\n", - "\n", - "mpc = do_mpc.controller.MPC(model)\n", - "mpc.set_param(**mpc_params)\n", - "mpc.set_objective(mterm=terminal_cost, lterm=stage_cost)\n", - "mpc.set_rterm(force=input_penalty)" + "mpc.set_param(**mpc_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Objective" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [], + "source": [ + "env = create_inverted_pendulum_environment()\n", + "xss = np.array([0.5, 0, 0, 0])\n", + "distance_cost = casadi.bilin(np.diag([1, 100, 0, 0]), model.x.cat - xss)\n", + "terminal_cost = distance_cost\n", + "stage_cost = distance_cost\n", + "print(f\"{stage_cost=}\")\n", + "print(f\"{terminal_cost=}\")\n", + "mpc.set_objective(mterm=terminal_cost, lterm=stage_cost)" ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, + "outputs": [], + "source": [ + "force_penalty = 0.1\n", + "mpc.set_rterm(force=force_penalty)" + ] + }, + { + "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, + "source": [ + "#### Constraints" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "slideshow": { + "slide_type": "fragment" + } + }, "outputs": [], "source": [ "# lower and upper bounds of the position\n", @@ -543,28 +673,49 @@ "mpc.bounds[\"upper\", \"_u\", \"force\"] = u_max" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Parameter Uncertainty" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], "source": [ "m_values = ip_parameters.m * np.array([1.0, 1.30, 0.70])\n", "M_values = ip_parameters.M * np.array([1.0, 1.30, 0.70])\n", - "\n", "mpc.set_uncertainty_values(m=m_values, M=M_values)" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Setup" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { - "slide_type": "subslide" + "slide_type": "fragment" } }, "outputs": [], @@ -572,6 +723,17 @@ "mpc.setup()" ] }, + { + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Simulation" + ] + }, { "cell_type": "code", "execution_count": null, @@ -604,9 +766,9 @@ "outputs": [], "source": [ "%%capture\n", - "max_steps = 200\n", + "max_steps = 100\n", "env = create_inverted_pendulum_environment(\n", - " max_steps=max_steps, cutoff_angle=np.inf, initial_angle=0.99 * np.pi\n", + " render_mode=render_mode, max_steps=max_steps, cutoff_angle=np.inf, initial_angle=0.99*np.pi\n", ")\n", "controller = MPCController(mpc)\n", "results = simulate_environment(env, max_steps=max_steps, controller=controller)" @@ -622,13 +784,14 @@ }, "outputs": [], "source": [ - "media.show_video(results.frames, fps=1 / env.dt)" + "show_video(results.frames, fps=1 / env.dt)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { + "scrolled": false, "slideshow": { "slide_type": "subslide" } @@ -786,6 +949,7 @@ "cell_type": "code", "execution_count": null, "metadata": { + "scrolled": false, "slideshow": { "slide_type": "subslide" } @@ -808,11 +972,11 @@ "\n", "- Beyond the model, the MPC cost function and constraints strongly influence closed-loop performance.\n", "- Learning approaches can design the MPC problem to achieve desired controller behavior.\n", - "- A parameterized MPC formulation is considered with cost $l(x,u,\\theta_l)$ and constraints $\\mathcal{X}(\\theta_\\mathcal{X}), \\mathcal{U}(\\theta_\\mathcal{U})$:\n", + "- A parameterized MPC formulation is considered with cost $g(x,u,\\theta_l)$ and constraints $\\mathcal{X}(\\theta_\\mathcal{X}), \\mathcal{U}(\\theta_\\mathcal{U})$:\n", "\n", " $$\n", " \\begin{array}\\\\\n", - " U^∗ &= \\displaystyle\\arg\\min_{U} \\sum\\limits_{i=0}^{T} (x_i,u_i,\\theta_l)\\\\\n", + " U^∗ &= \\displaystyle\\arg\\min_{U} \\sum\\limits_{i=0}^{T} g(x_i,u_i,\\theta_l)\\\\\n", " \\text{subject to} & x_{i+1} = f(x_i, u_i, \\theta_f)\\\\\n", " & U = [u_0 , \\dots, u_N ] \\in \\mathcal{U}(\\theta_\\mathcal{U})\\\\\n", " & X = [x_0 , \\dots, x_N ] \\in \\mathcal{X}(\\theta_\\mathcal{X})\\\\\n", diff --git a/poetry.lock b/poetry.lock index 466c2460..925ad90e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -11,27 +11,6 @@ files = [ {file = "absl_py-2.0.0-py3-none-any.whl", hash = "sha256:9a28abb62774ae4e8edbe2dd4c49ffcd45a6a848952a5eccc6a49f3f0fc1e2f3"}, ] -[[package]] -name = "accsr" -version = "0.4.5" -description = "Utils for accessing data from anywhere" -optional = false -python-versions = ">=3.8" -files = [ - {file = "accsr-0.4.5-py3-none-any.whl", hash = "sha256:43794f1a61bca4cc9ac49c7cfc9200cddc23fb17c175e90067fee7e75c4cec7e"}, - {file = "accsr-0.4.5.tar.gz", hash = "sha256:73cd857d91a1f2986c691156353048921ef6f330d51083c914a0283b0a64ead3"}, -] - -[package.dependencies] -apache-libcloud = "3.7.0" -numpy = ">1" -pyyaml = "6.0.1" -tqdm = "4.48.2" - -[package.extras] -docs = ["Sphinx (==3.2.1)", "ipython", "nbsphinx", "sphinx-rtd-theme", "sphinxcontrib-websupport (==1.2.4)"] -test = ["pytest"] - [[package]] name = "alabaster" version = "0.7.13" @@ -98,20 +77,6 @@ doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)"] test = ["anyio[trio]", "coverage[toml] (>=7)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] trio = ["trio (>=0.22)"] -[[package]] -name = "apache-libcloud" -version = "3.7.0" -description = "A standard Python library that abstracts away differences among multiple cloud provider APIs. For more information and documentation, please see https://libcloud.apache.org" -optional = false -python-versions = ">=3.6, <4" -files = [ - {file = "apache-libcloud-3.7.0.tar.gz", hash = "sha256:148a9e50069654432a7d34997954e91434dd38ebf68832eb9c75d442b3e62fad"}, - {file = "apache_libcloud-3.7.0-py2.py3-none-any.whl", hash = "sha256:027a9aff2c01db9c8e6f9f94b6eb44b3153d82702c42bfbe7af5624dabf1f950"}, -] - -[package.dependencies] -requests = ">=2.26.0" - [[package]] name = "appnope" version = "0.1.3" @@ -598,6 +563,17 @@ files = [ {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, ] +[[package]] +name = "chardet" +version = "5.2.0" +description = "Universal encoding detector for Python 3" +optional = false +python-versions = ">=3.7" +files = [ + {file = "chardet-5.2.0-py3-none-any.whl", hash = "sha256:e1cf59446890a00105fe7b7912492ea04b6e6f06d4b742b2c788469e34c82970"}, + {file = "chardet-5.2.0.tar.gz", hash = "sha256:1b3b6ff479a8c414bc3fa2c0852995695c4a026dcd6d0633b2dd092ca39c1cf7"}, +] + [[package]] name = "charset-normalizer" version = "3.3.0" @@ -1336,6 +1312,28 @@ files = [ [package.extras] preview = ["glfw-preview"] +[[package]] +name = "google-api-core" +version = "2.12.0" +description = "Google API client core library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google-api-core-2.12.0.tar.gz", hash = "sha256:c22e01b1e3c4dcd90998494879612c38d0a3411d1f7b679eb89e2abe3ce1f553"}, + {file = "google_api_core-2.12.0-py3-none-any.whl", hash = "sha256:ec6054f7d64ad13b41e43d96f735acbd763b0f3b695dabaa2d579673f6a6e160"}, +] + +[package.dependencies] +google-auth = ">=2.14.1,<3.0.dev0" +googleapis-common-protos = ">=1.56.2,<2.0.dev0" +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" +requests = ">=2.18.0,<3.0.0.dev0" + +[package.extras] +grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"] +grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] +grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] + [[package]] name = "google-auth" version = "2.23.3" @@ -1377,6 +1375,160 @@ requests-oauthlib = ">=0.7.0" [package.extras] tool = ["click (>=6.0.0)"] +[[package]] +name = "google-cloud-core" +version = "2.3.3" +description = "Google Cloud API client core library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google-cloud-core-2.3.3.tar.gz", hash = "sha256:37b80273c8d7eee1ae816b3a20ae43585ea50506cb0e60f3cf5be5f87f1373cb"}, + {file = "google_cloud_core-2.3.3-py2.py3-none-any.whl", hash = "sha256:fbd11cad3e98a7e5b0343dc07cb1039a5ffd7a5bb96e1f1e27cee4bda4a90863"}, +] + +[package.dependencies] +google-api-core = ">=1.31.6,<2.0.dev0 || >2.3.0,<3.0.0dev" +google-auth = ">=1.25.0,<3.0dev" + +[package.extras] +grpc = ["grpcio (>=1.38.0,<2.0dev)"] + +[[package]] +name = "google-cloud-storage" +version = "2.5.0" +description = "Google Cloud Storage API client library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google-cloud-storage-2.5.0.tar.gz", hash = "sha256:382f34b91de2212e3c2e7b40ec079d27ee2e3dbbae99b75b1bcd8c63063ce235"}, + {file = "google_cloud_storage-2.5.0-py2.py3-none-any.whl", hash = "sha256:19a26c66c317ce542cea0830b7e787e8dac2588b6bfa4d3fd3b871ba16305ab0"}, +] + +[package.dependencies] +google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0dev" +google-auth = ">=1.25.0,<3.0dev" +google-cloud-core = ">=2.3.0,<3.0dev" +google-resumable-media = ">=2.3.2" +requests = ">=2.18.0,<3.0.0dev" + +[package.extras] +protobuf = ["protobuf (<5.0.0dev)"] + +[[package]] +name = "google-crc32c" +version = "1.5.0" +description = "A python wrapper of the C library 'Google CRC32C'" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google-crc32c-1.5.0.tar.gz", hash = "sha256:89284716bc6a5a415d4eaa11b1726d2d60a0cd12aadf5439828353662ede9dd7"}, + {file = "google_crc32c-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:596d1f98fc70232fcb6590c439f43b350cb762fb5d61ce7b0e9db4539654cc13"}, + {file = "google_crc32c-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:be82c3c8cfb15b30f36768797a640e800513793d6ae1724aaaafe5bf86f8f346"}, + {file = "google_crc32c-1.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:461665ff58895f508e2866824a47bdee72497b091c730071f2b7575d5762ab65"}, + {file = "google_crc32c-1.5.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e2096eddb4e7c7bdae4bd69ad364e55e07b8316653234a56552d9c988bd2d61b"}, + {file = "google_crc32c-1.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:116a7c3c616dd14a3de8c64a965828b197e5f2d121fedd2f8c5585c547e87b02"}, + {file = "google_crc32c-1.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5829b792bf5822fd0a6f6eb34c5f81dd074f01d570ed7f36aa101d6fc7a0a6e4"}, + {file = "google_crc32c-1.5.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:64e52e2b3970bd891309c113b54cf0e4384762c934d5ae56e283f9a0afcd953e"}, + {file = "google_crc32c-1.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:02ebb8bf46c13e36998aeaad1de9b48f4caf545e91d14041270d9dca767b780c"}, + {file = "google_crc32c-1.5.0-cp310-cp310-win32.whl", hash = "sha256:2e920d506ec85eb4ba50cd4228c2bec05642894d4c73c59b3a2fe20346bd00ee"}, + {file = "google_crc32c-1.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:07eb3c611ce363c51a933bf6bd7f8e3878a51d124acfc89452a75120bc436289"}, + {file = "google_crc32c-1.5.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:cae0274952c079886567f3f4f685bcaf5708f0a23a5f5216fdab71f81a6c0273"}, + {file = "google_crc32c-1.5.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1034d91442ead5a95b5aaef90dbfaca8633b0247d1e41621d1e9f9db88c36298"}, + {file = "google_crc32c-1.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c42c70cd1d362284289c6273adda4c6af8039a8ae12dc451dcd61cdabb8ab57"}, + {file = "google_crc32c-1.5.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8485b340a6a9e76c62a7dce3c98e5f102c9219f4cfbf896a00cf48caf078d438"}, + {file = "google_crc32c-1.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:77e2fd3057c9d78e225fa0a2160f96b64a824de17840351b26825b0848022906"}, + {file = "google_crc32c-1.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f583edb943cf2e09c60441b910d6a20b4d9d626c75a36c8fcac01a6c96c01183"}, + {file = "google_crc32c-1.5.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:a1fd716e7a01f8e717490fbe2e431d2905ab8aa598b9b12f8d10abebb36b04dd"}, + {file = "google_crc32c-1.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:72218785ce41b9cfd2fc1d6a017dc1ff7acfc4c17d01053265c41a2c0cc39b8c"}, + {file = "google_crc32c-1.5.0-cp311-cp311-win32.whl", hash = "sha256:66741ef4ee08ea0b2cc3c86916ab66b6aef03768525627fd6a1b34968b4e3709"}, + {file = "google_crc32c-1.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:ba1eb1843304b1e5537e1fca632fa894d6f6deca8d6389636ee5b4797affb968"}, + {file = "google_crc32c-1.5.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:98cb4d057f285bd80d8778ebc4fde6b4d509ac3f331758fb1528b733215443ae"}, + {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd8536e902db7e365f49e7d9029283403974ccf29b13fc7028b97e2295b33556"}, + {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:19e0a019d2c4dcc5e598cd4a4bc7b008546b0358bd322537c74ad47a5386884f"}, + {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02c65b9817512edc6a4ae7c7e987fea799d2e0ee40c53ec573a692bee24de876"}, + {file = "google_crc32c-1.5.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:6ac08d24c1f16bd2bf5eca8eaf8304812f44af5cfe5062006ec676e7e1d50afc"}, + {file = "google_crc32c-1.5.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3359fc442a743e870f4588fcf5dcbc1bf929df1fad8fb9905cd94e5edb02e84c"}, + {file = "google_crc32c-1.5.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:1e986b206dae4476f41bcec1faa057851f3889503a70e1bdb2378d406223994a"}, + {file = "google_crc32c-1.5.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:de06adc872bcd8c2a4e0dc51250e9e65ef2ca91be023b9d13ebd67c2ba552e1e"}, + {file = "google_crc32c-1.5.0-cp37-cp37m-win32.whl", hash = "sha256:d3515f198eaa2f0ed49f8819d5732d70698c3fa37384146079b3799b97667a94"}, + {file = "google_crc32c-1.5.0-cp37-cp37m-win_amd64.whl", hash = "sha256:67b741654b851abafb7bc625b6d1cdd520a379074e64b6a128e3b688c3c04740"}, + {file = "google_crc32c-1.5.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c02ec1c5856179f171e032a31d6f8bf84e5a75c45c33b2e20a3de353b266ebd8"}, + {file = "google_crc32c-1.5.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:edfedb64740750e1a3b16152620220f51d58ff1b4abceb339ca92e934775c27a"}, + {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84e6e8cd997930fc66d5bb4fde61e2b62ba19d62b7abd7a69920406f9ecca946"}, + {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:024894d9d3cfbc5943f8f230e23950cd4906b2fe004c72e29b209420a1e6b05a"}, + {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:998679bf62b7fb599d2878aa3ed06b9ce688b8974893e7223c60db155f26bd8d"}, + {file = "google_crc32c-1.5.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:83c681c526a3439b5cf94f7420471705bbf96262f49a6fe546a6db5f687a3d4a"}, + {file = "google_crc32c-1.5.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4c6fdd4fccbec90cc8a01fc00773fcd5fa28db683c116ee3cb35cd5da9ef6c37"}, + {file = "google_crc32c-1.5.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5ae44e10a8e3407dbe138984f21e536583f2bba1be9491239f942c2464ac0894"}, + {file = "google_crc32c-1.5.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:37933ec6e693e51a5b07505bd05de57eee12f3e8c32b07da7e73669398e6630a"}, + {file = "google_crc32c-1.5.0-cp38-cp38-win32.whl", hash = "sha256:fe70e325aa68fa4b5edf7d1a4b6f691eb04bbccac0ace68e34820d283b5f80d4"}, + {file = "google_crc32c-1.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:74dea7751d98034887dbd821b7aae3e1d36eda111d6ca36c206c44478035709c"}, + {file = "google_crc32c-1.5.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c6c777a480337ac14f38564ac88ae82d4cd238bf293f0a22295b66eb89ffced7"}, + {file = "google_crc32c-1.5.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:759ce4851a4bb15ecabae28f4d2e18983c244eddd767f560165563bf9aefbc8d"}, + {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f13cae8cc389a440def0c8c52057f37359014ccbc9dc1f0827936bcd367c6100"}, + {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e560628513ed34759456a416bf86b54b2476c59144a9138165c9a1575801d0d9"}, + {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1674e4307fa3024fc897ca774e9c7562c957af85df55efe2988ed9056dc4e57"}, + {file = "google_crc32c-1.5.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:278d2ed7c16cfc075c91378c4f47924c0625f5fc84b2d50d921b18b7975bd210"}, + {file = "google_crc32c-1.5.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d5280312b9af0976231f9e317c20e4a61cd2f9629b7bfea6a693d1878a264ebd"}, + {file = "google_crc32c-1.5.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:8b87e1a59c38f275c0e3676fc2ab6d59eccecfd460be267ac360cc31f7bcde96"}, + {file = "google_crc32c-1.5.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:7c074fece789b5034b9b1404a1f8208fc2d4c6ce9decdd16e8220c5a793e6f61"}, + {file = "google_crc32c-1.5.0-cp39-cp39-win32.whl", hash = "sha256:7f57f14606cd1dd0f0de396e1e53824c371e9544a822648cd76c034d209b559c"}, + {file = "google_crc32c-1.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:a2355cba1f4ad8b6988a4ca3feed5bff33f6af2d7f134852cf279c2aebfde541"}, + {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f314013e7dcd5cf45ab1945d92e713eec788166262ae8deb2cfacd53def27325"}, + {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b747a674c20a67343cb61d43fdd9207ce5da6a99f629c6e2541aa0e89215bcd"}, + {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8f24ed114432de109aa9fd317278518a5af2d31ac2ea6b952b2f7782b43da091"}, + {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8667b48e7a7ef66afba2c81e1094ef526388d35b873966d8a9a447974ed9178"}, + {file = "google_crc32c-1.5.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:1c7abdac90433b09bad6c43a43af253e688c9cfc1c86d332aed13f9a7c7f65e2"}, + {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:6f998db4e71b645350b9ac28a2167e6632c239963ca9da411523bb439c5c514d"}, + {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c99616c853bb585301df6de07ca2cadad344fd1ada6d62bb30aec05219c45d2"}, + {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2ad40e31093a4af319dadf503b2467ccdc8f67c72e4bcba97f8c10cb078207b5"}, + {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd67cf24a553339d5062eff51013780a00d6f97a39ca062781d06b3a73b15462"}, + {file = "google_crc32c-1.5.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:398af5e3ba9cf768787eef45c803ff9614cc3e22a5b2f7d7ae116df8b11e3314"}, + {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:b1f8133c9a275df5613a451e73f36c2aea4fe13c5c8997e22cf355ebd7bd0728"}, + {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ba053c5f50430a3fcfd36f75aff9caeba0440b2d076afdb79a318d6ca245f88"}, + {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:272d3892a1e1a2dbc39cc5cde96834c236d5327e2122d3aaa19f6614531bb6eb"}, + {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:635f5d4dd18758a1fbd1049a8e8d2fee4ffed124462d837d1a02a0e009c3ab31"}, + {file = "google_crc32c-1.5.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c672d99a345849301784604bfeaeba4db0c7aae50b95be04dd651fd2a7310b93"}, +] + +[package.extras] +testing = ["pytest"] + +[[package]] +name = "google-resumable-media" +version = "2.6.0" +description = "Utilities for Google Media Downloads and Resumable Uploads" +optional = false +python-versions = ">= 3.7" +files = [ + {file = "google-resumable-media-2.6.0.tar.gz", hash = "sha256:972852f6c65f933e15a4a210c2b96930763b47197cdf4aa5f5bea435efb626e7"}, + {file = "google_resumable_media-2.6.0-py2.py3-none-any.whl", hash = "sha256:fc03d344381970f79eebb632a3c18bb1828593a2dc5572b5f90115ef7d11e81b"}, +] + +[package.dependencies] +google-crc32c = ">=1.0,<2.0dev" + +[package.extras] +aiohttp = ["aiohttp (>=3.6.2,<4.0.0dev)", "google-auth (>=1.22.0,<2.0dev)"] +requests = ["requests (>=2.18.0,<3.0.0dev)"] + +[[package]] +name = "googleapis-common-protos" +version = "1.61.0" +description = "Common protobufs used in Google APIs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "googleapis-common-protos-1.61.0.tar.gz", hash = "sha256:8a64866a97f6304a7179873a465d6eee97b7a24ec6cfd78e0f575e96b821240b"}, + {file = "googleapis_common_protos-1.61.0-py2.py3-none-any.whl", hash = "sha256:22f1915393bb3245343f6efe87f6fe868532efc12aa26b391b15132e1279f1c0"}, +] + +[package.dependencies] +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" + +[package.extras] +grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] + [[package]] name = "grpcio" version = "1.59.0" @@ -1867,26 +2019,6 @@ files = [ [package.dependencies] referencing = ">=0.28.0" -[[package]] -name = "jupyter" -version = "1.0.0" -description = "Jupyter metapackage. Install all the Jupyter components in one go." -optional = false -python-versions = "*" -files = [ - {file = "jupyter-1.0.0-py2.py3-none-any.whl", hash = "sha256:5b290f93b98ffbc21c0c7e749f054b3267782166d72fa5e3ed1ed4eaf34a2b78"}, - {file = "jupyter-1.0.0.tar.gz", hash = "sha256:d9dc4b3318f310e34c82951ea5d6683f67bed7def4b259fafbfe4f1beb1d8e5f"}, - {file = "jupyter-1.0.0.zip", hash = "sha256:3e1f86076bbb7c8c207829390305a2b1fe836d471ed54be66a3b8c41e7f46cc7"}, -] - -[package.dependencies] -ipykernel = "*" -ipywidgets = "*" -jupyter-console = "*" -nbconvert = "*" -notebook = "*" -qtconsole = "*" - [[package]] name = "jupyter-client" version = "8.4.0" @@ -1909,30 +2041,6 @@ traitlets = ">=5.3" docs = ["ipykernel", "myst-parser", "pydata-sphinx-theme", "sphinx (>=4)", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] test = ["coverage", "ipykernel (>=6.14)", "mypy", "paramiko", "pre-commit", "pytest", "pytest-cov", "pytest-jupyter[client] (>=0.4.1)", "pytest-timeout"] -[[package]] -name = "jupyter-console" -version = "6.6.3" -description = "Jupyter terminal console" -optional = false -python-versions = ">=3.7" -files = [ - {file = "jupyter_console-6.6.3-py3-none-any.whl", hash = "sha256:309d33409fcc92ffdad25f0bcdf9a4a9daa61b6f341177570fdac03de5352485"}, - {file = "jupyter_console-6.6.3.tar.gz", hash = "sha256:566a4bf31c87adbfadf22cdf846e3069b59a71ed5da71d6ba4d8aaad14a53539"}, -] - -[package.dependencies] -ipykernel = ">=6.14" -ipython = "*" -jupyter-client = ">=7.0.0" -jupyter-core = ">=4.12,<5.0.dev0 || >=5.1.dev0" -prompt-toolkit = ">=3.0.30" -pygments = "*" -pyzmq = ">=17" -traitlets = ">=5.4" - -[package.extras] -test = ["flaky", "pexpect", "pytest"] - [[package]] name = "jupyter-contrib-core" version = "0.4.2" @@ -2424,6 +2532,30 @@ files = [ docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] testing = ["coverage", "pyyaml"] +[[package]] +name = "markdown-it-py" +version = "3.0.0" +description = "Python port of markdown-it. Markdown parsing, done right!" +optional = false +python-versions = ">=3.8" +files = [ + {file = "markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb"}, + {file = "markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1"}, +] + +[package.dependencies] +mdurl = ">=0.1,<1.0" + +[package.extras] +benchmarking = ["psutil", "pytest", "pytest-benchmark"] +code-style = ["pre-commit (>=3.0,<4.0)"] +compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"] +linkify = ["linkify-it-py (>=1,<3)"] +plugins = ["mdit-py-plugins"] +profiling = ["gprof2dot"] +rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] +testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] + [[package]] name = "markupsafe" version = "2.1.3" @@ -2547,25 +2679,16 @@ files = [ traitlets = "*" [[package]] -name = "matplotx" -version = "0.3.10" -description = "Useful styles and extensions for Matplotlib" +name = "mdurl" +version = "0.1.2" +description = "Markdown URL utilities" optional = false python-versions = ">=3.7" files = [ - {file = "matplotx-0.3.10-py3-none-any.whl", hash = "sha256:4d7adafdb001c771d66d9362bb8ca99fcaed15319259223a714f36793dfabbb8"}, - {file = "matplotx-0.3.10.tar.gz", hash = "sha256:b6926ce5274cf5da966cb46b90a8c7fefb761478c6c85c8f7ed3ee8ec90e86e5"}, + {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, + {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] -[package.dependencies] -matplotlib = "*" -numpy = ">=1.20.0" - -[package.extras] -all = ["networkx", "pypng", "scipy"] -contour = ["networkx"] -spy = ["pypng", "scipy"] - [[package]] name = "mediapy" version = "1.1.9" @@ -2586,6 +2709,31 @@ Pillow = "*" [package.extras] dev = ["absl-py", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-xdist"] +[[package]] +name = "minari" +version = "0.4.2" +description = "A standard format for offline reinforcement learning datasets, with popular reference datasets and related utilities." +optional = false +python-versions = ">=3.8" +files = [ + {file = "minari-0.4.2-py3-none-any.whl", hash = "sha256:fbdf2aa8c26c39aaafb03318985e5b311935875bdf18041343fb58428596d505"}, + {file = "minari-0.4.2.tar.gz", hash = "sha256:ca538841b5595d1979658c3d9862753d7d3fc304ed13b7b5d8e100f2af2f10f5"}, +] + +[package.dependencies] +google-cloud-storage = "2.5.0" +gymnasium = ">=0.28.1" +h5py = ">=3.8.0" +numpy = ">=1.21.0" +packaging = "23.1" +portion = "2.4.0" +tqdm = ">=4.65.0" +typer = {version = "0.9.0", extras = ["all"]} +typing-extensions = ">=4.4.0" + +[package.extras] +testing = ["gymnasium-robotics (>=1.2.1)", "imageio (>=2.14.1)", "pytest (==7.1.3)"] + [[package]] name = "mistune" version = "3.0.2" @@ -3210,6 +3358,25 @@ rsa = ["cryptography (>=3.0.0)"] signals = ["blinker (>=1.4.0)"] signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] +[[package]] +name = "opencv-python" +version = "4.8.1.78" +description = "Wrapper package for OpenCV python bindings." +optional = false +python-versions = ">=3.6" +files = [ + {file = "opencv-python-4.8.1.78.tar.gz", hash = "sha256:cc7adbbcd1112877a39274106cb2752e04984bc01a031162952e97450d6117f6"}, + {file = "opencv_python-4.8.1.78-cp37-abi3-macosx_10_16_x86_64.whl", hash = "sha256:91d5f6f5209dc2635d496f6b8ca6573ecdad051a09e6b5de4c399b8e673c60da"}, + {file = "opencv_python-4.8.1.78-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:bc31f47e05447da8b3089faa0a07ffe80e114c91ce0b171e6424f9badbd1c5cd"}, + {file = "opencv_python-4.8.1.78-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9814beca408d3a0eca1bae7e3e5be68b07c17ecceb392b94170881216e09b319"}, + {file = "opencv_python-4.8.1.78-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4c406bdb41eb21ea51b4e90dfbc989c002786c3f601c236a99c59a54670a394"}, + {file = "opencv_python-4.8.1.78-cp37-abi3-win32.whl", hash = "sha256:a7aac3900fbacf55b551e7b53626c3dad4c71ce85643645c43e91fcb19045e47"}, + {file = "opencv_python-4.8.1.78-cp37-abi3-win_amd64.whl", hash = "sha256:b983197f97cfa6fcb74e1da1802c7497a6f94ed561aba6980f1f33123f904956"}, +] + +[package.dependencies] +numpy = {version = ">=1.23.5", markers = "python_version >= \"3.11\""} + [[package]] name = "overrides" version = "7.4.0" @@ -3223,13 +3390,13 @@ files = [ [[package]] name = "packaging" -version = "23.2" +version = "23.1" description = "Core utilities for Python packages" optional = false python-versions = ">=3.7" files = [ - {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, - {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, + {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, + {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, ] [[package]] @@ -3588,6 +3755,23 @@ tomli = ">=1.2.2" [package.extras] poetry-plugin = ["poetry (>=1.0,<2.0)"] +[[package]] +name = "portion" +version = "2.4.0" +description = "Python data structure and operations for intervals" +optional = false +python-versions = "~= 3.7" +files = [ + {file = "portion-2.4.0-py3-none-any.whl", hash = "sha256:a084deb01382b53fc7bc83f0734879c537cd7f22c6b6f78ccdc39e963016ea0c"}, + {file = "portion-2.4.0.tar.gz", hash = "sha256:deb16389e844dbf9aeb654261fce5febd720e4786c6690efbb9dc11608226840"}, +] + +[package.dependencies] +sortedcontainers = ">=2.2,<3.0" + +[package.extras] +test = ["black (>=21.8b)", "coverage (>=6.0,<7.0)", "pytest (>=7.0,<8.0)"] + [[package]] name = "pre-commit" version = "3.5.0" @@ -4177,49 +4361,6 @@ files = [ [package.dependencies] cffi = {version = "*", markers = "implementation_name == \"pypy\""} -[[package]] -name = "qtconsole" -version = "5.4.4" -description = "Jupyter Qt console" -optional = false -python-versions = ">= 3.7" -files = [ - {file = "qtconsole-5.4.4-py3-none-any.whl", hash = "sha256:a3b69b868e041c2c698bdc75b0602f42e130ffb256d6efa48f9aa756c97672aa"}, - {file = "qtconsole-5.4.4.tar.gz", hash = "sha256:b7ffb53d74f23cee29f4cdb55dd6fabc8ec312d94f3c46ba38e1dde458693dfb"}, -] - -[package.dependencies] -ipykernel = ">=4.1" -ipython-genutils = "*" -jupyter-client = ">=4.1" -jupyter-core = "*" -packaging = "*" -pygments = "*" -pyzmq = ">=17.1" -qtpy = ">=2.4.0" -traitlets = "<5.2.1 || >5.2.1,<5.2.2 || >5.2.2" - -[package.extras] -doc = ["Sphinx (>=1.3)"] -test = ["flaky", "pytest", "pytest-qt"] - -[[package]] -name = "qtpy" -version = "2.4.0" -description = "Provides an abstraction layer on top of the various Qt bindings (PyQt5/6 and PySide2/6)." -optional = false -python-versions = ">=3.7" -files = [ - {file = "QtPy-2.4.0-py3-none-any.whl", hash = "sha256:4d4f045a41e09ac9fa57fcb47ef05781aa5af294a0a646acc1b729d14225e741"}, - {file = "QtPy-2.4.0.tar.gz", hash = "sha256:db2d508167aa6106781565c8da5c6f1487debacba33519cedc35fa8997d424d4"}, -] - -[package.dependencies] -packaging = "*" - -[package.extras] -test = ["pytest (>=6,!=7.0.0,!=7.0.1)", "pytest-cov (>=3.0.0)", "pytest-qt"] - [[package]] name = "referencing" version = "0.30.2" @@ -4299,6 +4440,38 @@ files = [ {file = "rfc3986_validator-0.1.1.tar.gz", hash = "sha256:3d44bde7921b3b9ec3ae4e3adca370438eccebc676456449b145d533b240d055"}, ] +[[package]] +name = "rich" +version = "13.6.0" +description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "rich-13.6.0-py3-none-any.whl", hash = "sha256:2b38e2fe9ca72c9a00170a1a2d20c63c790d0e10ef1fe35eba76e1e7b1d7d245"}, + {file = "rich-13.6.0.tar.gz", hash = "sha256:5c14d22737e6d5084ef4771b62d5d4363165b403455a30a1c8ca39dc7b644bef"}, +] + +[package.dependencies] +markdown-it-py = ">=2.2.0" +pygments = ">=2.13.0,<3.0.0" + +[package.extras] +jupyter = ["ipywidgets (>=7.5.1,<9)"] + +[[package]] +name = "rise" +version = "5.7.1" +description = "Reveal.js - Jupyter/IPython Slideshow Extension" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, <4" +files = [ + {file = "rise-5.7.1-py2.py3-none-any.whl", hash = "sha256:df8ce9f0e575d334b27ff40a1f91a4c78d9f7b4995858bb81185ceeaf98eae3a"}, + {file = "rise-5.7.1.tar.gz", hash = "sha256:641db777cb907bf5e6dc053098d7fd213813fa9a946542e52b900eb7095289a6"}, +] + +[package.dependencies] +notebook = ">=6.0" + [[package]] name = "rpds-py" version = "0.10.6" @@ -4559,6 +4732,17 @@ docs = ["entangled-cli[rich]", "mkdocs", "mkdocs-entangled-plugin", "mkdocs-mate rich = ["rich"] test = ["build", "pytest", "rich", "wheel"] +[[package]] +name = "shellingham" +version = "1.5.4" +description = "Tool to Detect Surrounding Shell" +optional = false +python-versions = ">=3.7" +files = [ + {file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"}, + {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"}, +] + [[package]] name = "shimmy" version = "0.2.1" @@ -4595,19 +4779,6 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] -[[package]] -name = "slycot" -version = "0.5.4" -description = "A wrapper for the SLICOT control and systems library" -optional = false -python-versions = ">=3.8" -files = [ - {file = "slycot-0.5.4.tar.gz", hash = "sha256:0bcb6e6322d955bfe696a549cb09c1f16272cbbdd09453cb82ea7193ed3d01c6"}, -] - -[package.dependencies] -numpy = "*" - [[package]] name = "sniffio" version = "1.3.0" @@ -4630,6 +4801,17 @@ files = [ {file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"}, ] +[[package]] +name = "sortedcontainers" +version = "2.4.0" +description = "Sorted Containers -- Sorted List, Sorted Dict, Sorted Set" +optional = false +python-versions = "*" +files = [ + {file = "sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0"}, + {file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"}, +] + [[package]] name = "soupsieve" version = "2.5" @@ -4916,6 +5098,22 @@ files = [ {file = "tensorboard_data_server-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:255c02b7f5b03dd5c0a88c928e563441ff39e1d4b4a234cdbe09f016e53d9594"}, ] +[[package]] +name = "tensorboardx" +version = "2.6.2.2" +description = "TensorBoardX lets you watch Tensors Flow without Tensorflow" +optional = false +python-versions = "*" +files = [ + {file = "tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8"}, + {file = "tensorboardX-2.6.2.2.tar.gz", hash = "sha256:c6476d7cd0d529b0b72f4acadb1269f9ed8b22f441e87a84f2a3b940bb87b666"}, +] + +[package.dependencies] +numpy = "*" +packaging = "*" +protobuf = ">=3.20" + [[package]] name = "terminado" version = "0.17.1" @@ -5053,17 +5251,23 @@ files = [ [[package]] name = "tqdm" -version = "4.48.2" +version = "4.66.1" description = "Fast, Extensible Progress Meter" optional = false -python-versions = ">=2.6, !=3.0.*, !=3.1.*" +python-versions = ">=3.7" files = [ - {file = "tqdm-4.48.2-py2.py3-none-any.whl", hash = "sha256:1a336d2b829be50e46b84668691e0a2719f26c97c62846298dd5ae2937e4d5cf"}, - {file = "tqdm-4.48.2.tar.gz", hash = "sha256:564d632ea2b9cb52979f7956e093e831c28d441c11751682f84c86fc46e4fd21"}, + {file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"}, + {file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"}, ] +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + [package.extras] -dev = ["argopt", "py-make (>=0.1.0)", "pydoc-markdown", "twine"] +dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] [[package]] name = "traitlets" @@ -5117,6 +5321,30 @@ torch = "*" tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] tutorials = ["matplotlib", "pandas", "tabulate"] +[[package]] +name = "typer" +version = "0.9.0" +description = "Typer, build great CLIs. Easy to code. Based on Python type hints." +optional = false +python-versions = ">=3.6" +files = [ + {file = "typer-0.9.0-py3-none-any.whl", hash = "sha256:5d96d986a21493606a358cae4461bd8cdf83cbf33a5aa950ae629ca3b51467ee"}, + {file = "typer-0.9.0.tar.gz", hash = "sha256:50922fd79aea2f4751a8e0408ff10d2662bd0c8bbfa84755a699f3bada2978b2"}, +] + +[package.dependencies] +click = ">=7.1.1,<9.0.0" +colorama = {version = ">=0.4.3,<0.5.0", optional = true, markers = "extra == \"all\""} +rich = {version = ">=10.11.0,<14.0.0", optional = true, markers = "extra == \"all\""} +shellingham = {version = ">=1.3.0,<2.0.0", optional = true, markers = "extra == \"all\""} +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] +doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"] +test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] + [[package]] name = "types-python-dateutil" version = "2.8.19.14" @@ -5334,4 +5562,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "865f284d24411b62f2f73a266d5581903826d9020188b82b857e87d5857bf1d9" +content-hash = "605360b21771b0d060f18900d94d7206d54c818b682c9413be0760e3cefe028a" diff --git a/pyproject.toml b/pyproject.toml index 767ff21d..cd47a1c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" [tool.poetry] -name = "tfl training rl" +name = "tfl-training-rl" version = "0.1.0" description = "A transferlab training" authors = [ @@ -25,48 +25,32 @@ include=["src/training_rl/assets"] [tool.poetry.dependencies] python = "^3.11" -#accsr = "^0.4.5" ipykernel = "^6.25.2" ipywidgets = "^8.1.1" jupyter-contrib-nbextensions = "^0.7.0" notebook = "<7.0.0" - -[tool.poetry.group.offline] -optional = true -[tool.poetry.group.offline.dependencies] -minari = "^0.4.2" -chardet ="*" -opencv-python="*" -tensorboardX="*" -pygame = "*" +rise = "^5.7.1" +matplotlib = "^3.8.0" +seaborn = "^0.13.0" +traitlets = "5.9.0" +# special sauce b/c of a flaky bug in poetry on windows +# see https://github.com/python-poetry/poetry/issues/7611#issuecomment-1466478926 +virtualenv = [ + { version = "^20.4.3,!=20.4.5,!=20.4.6" }, + { version = "<20.16.4", markers = "sys_platform == 'win32'" }, +] [tool.poetry.group.add1] optional = true [tool.poetry.group.add1.dependencies] pettingzoo = "^1.24.1" -ipykernel = "^6.25.2" -ipywidgets = "^8.1.1" jsonargparse = "^4.25.0" -jupyter = "^1.0.0" -jupyter-contrib-nbextensions = "^0.7.0" -matplotlib = "^3.8.0" -matplotx = "^0.3.10" -networkx = "^3.1" -notebook = "<7.0.0" numba = "^0.57.1" # b/c of numba numpy = "<=1.24" overrides = "^7.4.0" packaging = "*" pandas = {extras = ["performance"], version = "^2.1.0"} -seaborn = "^0.13.0" -traitlets = "5.9.0" -virtualenv = [ - # special sauce b/c of a flaky bug in poetry on windows - # see https://github.com/python-poetry/poetry/issues/7611#issuecomment-1466478926 - { version = "^20.4.3,!=20.4.5,!=20.4.6" }, - { version = "<20.16.4", markers = "sys_platform == 'win32'" }, -] [tool.poetry.group.add2] optional = true @@ -80,14 +64,23 @@ tensorboard = "^2.5.0" torch = "^2.0.0, !=2.0.1, !=2.1.0" tqdm = "*" - [tool.poetry.group.control] optional = true [tool.poetry.group.control.dependencies] control = "^0.9.4" do-mpc = "^4.6.1" mediapy = "^1.1.9" -slycot = "^0.5.4" +gymnasium = {extras = ["classic-control", "mujoco"], version = "^0.28.0"} +networkx = "^3.1" + +[tool.poetry.group.offline] +optional = true +[tool.poetry.group.offline.dependencies] +minari = "^0.4.2" +chardet ="*" +opencv-python="*" +tensorboardX="*" +pygame = "*" [tool.poetry.group.dev] diff --git a/src/training_rl/control/environment.py b/src/training_rl/control/environment.py index 6581d02d..d3ada4ec 100644 --- a/src/training_rl/control/environment.py +++ b/src/training_rl/control/environment.py @@ -4,17 +4,21 @@ from dataclasses import dataclass from typing import ClassVar, Protocol +import mediapy as media import numpy as np from gymnasium import Env, utils from gymnasium.envs.mujoco import MujocoEnv -from gymnasium.envs.mujoco.inverted_pendulum_v4 import (DEFAULT_CAMERA_CONFIG, - InvertedPendulumEnv) +from gymnasium.envs.mujoco.inverted_pendulum_v4 import ( + DEFAULT_CAMERA_CONFIG, + InvertedPendulumEnv, +) from gymnasium.spaces import Box from gymnasium.wrappers import OrderEnforcing, PassiveEnvChecker, TimeLimit from gymnasium.wrappers.render_collection import RenderCollection from numpy.typing import NDArray __all__ = [ + "show_video", "create_inverted_pendulum_environment", "create_mass_spring_damper_environment", "simulate_environment", @@ -24,8 +28,18 @@ ASSETS_DIR = importlib.resources.files(__package__) / "../assets" +def show_video(frames: list[NDArray], fps: float) -> None: + """Renders the given frames as a video. + + If no frames are passed, then it simply returns without doing anything. + """ + if len(frames) == 0: + return + media.show_video(frames, fps=fps) + + def create_inverted_pendulum_environment( - render_mode: str = "rgb_array", + render_mode: str | None = "rgb_array", max_steps: int = 100, cutoff_angle: float = 0.8, initial_angle: float = 0.0, @@ -41,11 +55,13 @@ def create_inverted_pendulum_environment( env = PassiveEnvChecker(env) env = OrderEnforcing(env) env = TimeLimit(env, max_steps) - return RenderCollection(env) + if render_mode is not None: + env = RenderCollection(env) + return env def create_mass_spring_damper_environment( - render_mode: str = "rgb_array", + render_mode: str | None = "rgb_array", max_steps: int = 100, ) -> Env: """Creates instance of MassSpringDamperEnv with some wrappers @@ -55,7 +71,9 @@ def create_mass_spring_damper_environment( env = PassiveEnvChecker(env) env = OrderEnforcing(env) env = TimeLimit(env, max_steps) - return RenderCollection(env) + if render_mode is not None: + env = RenderCollection(env) + return env class InvertedPendulumEnvWithInitialAndCutoffAngle(InvertedPendulumEnv): diff --git a/src/training_rl/control/shortest_path_problem.py b/src/training_rl/control/shortest_path_problem.py index b17a8499..d6867a72 100644 --- a/src/training_rl/control/shortest_path_problem.py +++ b/src/training_rl/control/shortest_path_problem.py @@ -14,7 +14,7 @@ def create_shortest_path_graph() -> nx.DiGraph: ("A", "C", 5), ("A", "D", 3), ("B", "D", 9), - ("B", "E", 6), + ("B", "E", 1), ("C", "F", 2), ("D", "F", 5), ("D", "G", 8),