In [None]:
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# 12 - Optimal control and adjoint gradients\n",
        "\n",
        "We define a small discrete-time linear system $x_{k+1} = A x_k + B u_k$ and compute a terminal-state loss and its gradient with respect to the control sequence via autodiff.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tig.core.random import Rng\n",
        "from tig.control.adjoint_methods import DiscreteDynamics, simulate_with_grad_u\n",
        "\n",
        "# Use float64 as the Keras default float type and run tf.function code eagerly.\n",
        "tf.keras.backend.set_floatx(\"float64\")\n",
        "tf.config.run_functions_eagerly(True)\n",
        "\n",
        "# Seeded project generator. Keep the order of the rng.normal calls unchanged:\n",
        "# presumably each draw advances the generator state (confirm in tig.core.random).\n",
        "rng = Rng(seed=0)\n",
        "dt = 0.05  # step size handed to simulate_with_grad_u\n",
        "n_steps = 30  # number of rows requested for the control sequence u0\n",
        "\n",
        "# Scalars are created with tf.constant(..., dtype=tf.float64) on purpose:\n",
        "# tf.cast on a Python float first builds a float32 tensor and then widens it,\n",
        "# so tf.cast(0.95, tf.float64) carries float32 rounding error instead of being 0.95.\n",
        "A = tf.eye(4, dtype=tf.float64) * tf.constant(0.95, dtype=tf.float64)\n",
        "B = rng.normal((4, 2), dtype=tf.float64) * tf.constant(0.1, dtype=tf.float64)\n",
        "x_target = rng.normal((4,), dtype=tf.float64)\n",
        "\n",
        "def f(x: tf.Tensor, u: tf.Tensor, t: float) -> tf.Tensor:\n",
        "    \"\"\"Return A x + B u as a rank-1 tensor.\n",
        "\n",
        "    Both x and u are flattened into column vectors before the matrix\n",
        "    products. The argument t is accepted but never read.\n",
        "    \"\"\"\n",
        "    x_col = tf.reshape(x, (-1, 1))\n",
        "    u_col = tf.reshape(u, (-1, 1))\n",
        "    state_term = (A @ x_col)[:, 0]\n",
        "    control_term = (B @ u_col)[:, 0]\n",
        "    return state_term + control_term\n",
        "\n",
        "# Initial state (requested shape (4,)) and control sequence (requested shape (n_steps, 2)).\n",
        "# x0 is drawn before u0 so the seeded draws keep their order.\n",
        "x0 = rng.normal((4,), dtype=tf.float64)\n",
        "u0 = rng.normal((n_steps, 2), dtype=tf.float64)\n",
        "\n",
        "# Wrap the step function f in the project's DiscreteDynamics container.\n",
        "dyn = DiscreteDynamics(f=f)\n",
        "\n",
        "def loss(traj: tf.Tensor, u: tf.Tensor) -> tf.Tensor:\n",
        "    \"\"\"Squared distance of the last trajectory entry to x_target plus a control penalty.\n",
        "\n",
        "    Returns sum((traj[-1] - x_target)**2) + 1e-3 * sum(u**2).\n",
        "    \"\"\"\n",
        "    # tf.constant keeps the weight at full float64 precision; tf.cast(1e-3, tf.float64)\n",
        "    # would round the Python float to float32 before widening it.\n",
        "    control_weight = tf.constant(1e-3, dtype=tf.float64)\n",
        "    e = traj[-1] - x_target\n",
        "    return tf.reduce_sum(e * e) + control_weight * tf.reduce_sum(u * u)\n",
        "\n",
        "# Call the simulator with the dynamics, initial state, controls and loss defined above.\n",
        "# The two results are reported below as a loss and a gradient\n",
        "# (presumably the gradient with respect to u0 - confirm in tig.control.adjoint_methods).\n",
        "L, g = simulate_with_grad_u(\n",
        "    dyn=dyn,\n",
        "    x0=x0,\n",
        "    u0=u0,\n",
        "    t0=0.0,\n",
        "    dt=dt,\n",
        "    loss=loss,\n",
        ")\n",
        "\n",
        "# Collapse each result to a plain Python float before printing.\n",
        "loss_value = float(tf.reshape(L, ()).numpy())\n",
        "print(\"loss:\", loss_value)\n",
        "grad_norm = float(tf.linalg.norm(tf.reshape(g, (-1,))).numpy())\n",
        "print(\"||grad||:\", grad_norm)\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.x"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
