In [None]:
{
  "cells": [
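    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# 01 - Tensor calculus and autodiff\n",
        "\n",
        "For a differentiable map $f$ with Jacobian $J(x)$ at a point $x$, automatic differentiation evaluates two kinds of products:\n",
        "\n",
        "- **JVP** (Jacobian-vector product): $J(x) v$\n",
        "- **VJP** (vector-Jacobian product): $J(x)^T u$\n",
        "\n",
        "They are adjoint to each other: for all vectors $u$ and $v$,\n",
        "$$\n",
        "\\langle u, J(x) v \\rangle = \\langle J(x)^T u, v \\rangle\n",
        "$$\n",
        "\n",
        "Below, both sides are evaluated numerically for random $u$ and $v$; up to floating-point round-off they should agree.\n"
      ]
    },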
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tig.core.random import Rng\n",
        "from tig.inverse.forward_models import CallableForwardModel\n",
        "\n",
        "tf.keras.backend.set_floatx(\"float64\")\n",
        "tf.config.run_functions_eagerly(True)\n",
        "\n",
        "SEED = 0\n",
        "# f acts on a flattened N x N matrix, so x, u and v each have N*N entries.\n",
        "N = 8\n",
        "DIM = N * N\n",
        "# Pass/fail threshold for the relative adjoint gap; deliberately loose,\n",
        "# since float64 round-off is expected to be much smaller.\n",
        "RTOL = 1e-10\n",
        "# Keeps the relative gap finite if both inner products are zero.\n",
        "EPS = 1e-15\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Forward model\n",
        "\n",
        "$f(x) = \\sin(A^2) + 0.01 \\cdot A^2$, where $A$ is $x$ reshaped to an $N \\times N$ matrix, $\\sin$ is applied elementwise, and the result is flattened back to a vector of length $N^2$.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def f(x: tf.Tensor) -> tf.Tensor:\n",
        "    \"\"\"Return sin(A @ A) + 0.01 * (A @ A), flattened, where A = reshape(x, (N, N)).\"\"\"\n",
        "    a = tf.reshape(x, (N, N))\n",
        "    y = tf.einsum(\"ij,jk->ik\", a, a)  # matrix square A @ A\n",
        "    return tf.reshape(tf.math.sin(y) + 0.01 * y, (-1,))\n",
        "\n",
        "model = CallableForwardModel(f=f)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Adjoint check\n",
        "\n",
        "Evaluate both sides of $\\langle u, J(x) v \\rangle = \\langle J(x)^T u, v \\rangle$ for random $x$, $u$, $v$ and assert that their relative gap is below `RTOL`. The generator is created inside this cell so re-running it alone starts again from `SEED`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "rng = Rng(seed=SEED)\n",
        "x = rng.normal((DIM,), dtype=tf.float64)\n",
        "v = rng.normal((DIM,), dtype=tf.float64)\n",
        "u = rng.normal((DIM,), dtype=tf.float64)\n",
        "\n",
        "jv = model.jvp(x, v)  # J(x) v\n",
        "jtu = model.vjp(x, u)  # J(x)^T u\n",
        "\n",
        "left = tf.reduce_sum(u * jv)  # <u, J v>\n",
        "right = tf.reduce_sum(jtu * v)  # <J^T u, v>\n",
        "gap = tf.abs(left - right)\n",
        "denom = tf.abs(left) + tf.abs(right) + tf.constant(EPS, dtype=tf.float64)\n",
        "rel_gap = float((gap / denom).numpy())\n",
        "\n",
        "assert rel_gap < RTOL, f\"adjoint identity violated: relative gap {rel_gap:.3e} >= {RTOL:.0e}\"\n",
        "rel_gap\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.x"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
