diff --git a/docs/report/001.ipynb b/docs/report/001.ipynb index 862ac83c..958a2b04 100644 --- a/docs/report/001.ipynb +++ b/docs/report/001.ipynb @@ -414,7 +414,9 @@ " _module = \"jax\"\n", "\n", " def _print_ComplexSqrt(self, expr: sp.Expr) -> str:\n", - " return \"select([less(x, 0), True], [1j * sqrt(-x), sqrt(x)], default=nan,)\"" + " arg = expr.args[0]\n", + " x = self._print(arg)\n", + " return f\"select([less({x}, 0), True], [1j * sqrt(-{x}), sqrt({x})], default=nan,)\"" ] }, {