In [1]:
import sympy
import jax

# Calculating Gradients

In [2]:
def f(x):
    """Evaluate x**4 + 12*x + 1/x.

    Only arithmetic operators are applied to ``x``, so any object that
    supports ``**``, ``*``, ``+`` and reflected ``/`` can be passed in.
    """
    quartic = x**4
    linear = 12*x
    reciprocal = 1/x
    return quartic + linear + reciprocal

### Manual differentiation

In [3]:
def df(x):
    """Evaluate 4*x**3 + 12 - 1/x**2, the hand-derived derivative of
    x**4 + 12*x + 1/x.
    """
    cubic = 4*x**3
    inverse_square = 1/x**2
    return cubic + 12 - inverse_square

In [4]:
# Evaluation point, then the function value and the hand-written derivative
# at that point, one per output line.
x = 11.0
print(f(x), df(x), sep="\n")

14773.09090909091
5335.99173553719


### Symbolic differentiation

In [5]:
# Calling f on a SymPy symbol yields a symbolic expression instead of a number.
x_sym = sympy.symbols('x')
f_sym = f(x_sym)

# Differentiate with respect to x_sym, the only symbol in the expression.
df_sym = sympy.diff(f_sym, x_sym)

print(f_sym, df_sym, sep="\n")

x**4 + 12*x + 1/x
4*x**3 + 12 - 1/x**2


In [6]:
# Turn the symbolic expression into a numeric callable. It is bound to its own
# name instead of `f`, so the Python function defined with `def f` is not
# shadowed and any cell that calls `f` behaves the same on a re-run.
f_lambdified = sympy.lambdify(x_sym, f_sym)
print(f_lambdified(x))

14773.09090909091


In [7]:
# Turn the symbolic derivative into a numeric callable. It is bound to its own
# name instead of `df`, so the hand-written derivative defined with `def df`
# stays available for comparison.
df_lambdified = sympy.lambdify(x_sym, df_sym)
print(df_lambdified(x))

5335.99173553719


### Numerical differentiation

In [8]:
# Evaluation point and finite-difference step size.
x, dx = 11.0, 1e-6

# Forward difference: slope of the secant line through x and x + dx.
df_x_numeric = (f(x + dx) - f(x)) / dx
print(df_x_numeric)

5335.992456821259


### Automatic differentiation

In [9]:
# jax.grad builds a new function that evaluates the derivative of f with
# respect to its first positional argument (argnums=0 is the default).
df = jax.grad(f, argnums=0)
print(df(x))

2024-10-08 14:43:57.781076: W external/xla/xla/service/gpu/nvptx_compiler.cc:930] The NVIDIA driver's CUDA version is 12.3 which is older than the PTX compiler version 12.6.77. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


5335.9917
