In [None]:
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# 04 - Matrix-free Krylov (CG)\n",
        "\n",
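        "We solve $Ax = b$ for a symmetric positive-definite (SPD) matrix $A$ with the conjugate gradient (CG) method. The solver receives a matrix-free operator exposing `matvec` rather than the matrix itself.\n",
        "\n",
        "As a sanity check, we report the relative residual $\\lVert b - A\\hat{x} \\rVert_2 / \\lVert b \\rVert_2$ of the returned solution.\n"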
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tig.core.random import Rng\n",
        "from tig.linalg.krylov import cg\n",
        "\n",
        "# Notebook-wide configuration: set Keras' default float type to float64 and run tf.function-decorated code eagerly (simpler debugging).\n",
        "tf.keras.backend.set_floatx(\"float64\")\n",
        "tf.config.run_functions_eagerly(True)\n",
        "\n",
        "N_DIM = 128        # problem size: A is N_DIM x N_DIM\n",
        "DIAG_SHIFT = 1e-3  # added to the diagonal so A is strictly positive definite\n",
        "SEED = 0\n",
        "\n",
        "rng = Rng(seed=SEED)\n",
        "# m^T m is symmetric positive *semi*definite; adding DIAG_SHIFT * I lifts every\n",
        "# eigenvalue by DIAG_SHIFT, so A is SPD, as CG requires.\n",
        "m = rng.normal((N_DIM, N_DIM), dtype=tf.float64)\n",
        "A = tf.transpose(m) @ m + DIAG_SHIFT * tf.eye(N_DIM, dtype=tf.float64)\n",
        "\n",
        "class Op:\n",
        "    \"\"\"Matrix-free view of a dense matrix that exposes only matrix-vector products.\"\"\"\n",
        "\n",
        "    def __init__(self, a: tf.Tensor) -> None:\n",
        "        \"\"\"Keep a reference to the matrix whose action this operator applies.\"\"\"\n",
        "        self._a = a\n",
        "\n",
        "    def matvec(self, x: tf.Tensor) -> tf.Tensor:\n",
        "        \"\"\"Return the stored matrix times x as a 1-D tensor (x is flattened first).\"\"\"\n",
        "        flat_x = tf.reshape(x, (-1,))\n",
        "        return tf.linalg.matvec(self._a, flat_x)\n",
        "\n",
        "    def rmatvec(self, y: tf.Tensor) -> tf.Tensor:\n",
        "        \"\"\"Return the transpose product; reuses matvec, so it assumes the matrix is symmetric.\"\"\"\n",
        "        return self.matvec(y)\n",
        "\n",
        "# Manufactured solution: draw x_true, form b = A x_true, then check how well CG recovers it.\n",
        "x_true = rng.normal((A.shape[0],), dtype=tf.float64)\n",
        "b = tf.linalg.matvec(A, x_true)\n",
        "\n",
        "# Build the operator once and reuse it for both the solve and the residual check.\n",
        "op = Op(A)\n",
        "res = cg(op=op, b=b, x0=tf.zeros_like(b), tol=1e-10, max_iter=500)\n",
        "x_hat = res.x\n",
        "\n",
        "# Recompute the residual directly from x_hat instead of relying on anything the solver reports.\n",
        "r = b - op.matvec(x_hat)\n",
        "rel = tf.linalg.norm(r) / (tf.linalg.norm(b) + tf.cast(1e-15, tf.float64))  # 1e-15 guards against b == 0\n",
        "rel_err = tf.linalg.norm(x_hat - x_true) / tf.linalg.norm(x_true)\n",
        "print(\"final relative residual:\", float(rel.numpy()))\n",
        "print(\"relative error vs x_true:\", float(rel_err.numpy()))\n",
        "\n",
        "# Deliberately loose bound: in floating point the recomputed residual can sit above the\n",
        "# solver tolerance, but a non-converged solve should still fail loudly here.\n",
        "assert float(rel.numpy()) < 1e-6, \"CG did not reach the expected accuracy\"\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.x"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
