In [25]:
import sys
import numpy as np
sys.path.append(r"C:\pythonprojects\calculas\venv\Lib\site-packages")
from jax import vmap,grad
import jax.numpy as jnp
from sympy import *
sqrt(18)

3*sqrt(2)

In [12]:
x,y = symbols('x y')
expression = 3 * x**2 + y
expression

3*x**2 + y

In [13]:
additional_expression = x * (expression)
additional_expression

x*(3*x**2 + y)

In [14]:
expanded=expand(additional_expression)
expanded

3*x**3 + x*y

In [15]:
factor(expanded)  

x*(3*x**2 + y)

In [16]:
expression.evalf(subs={x:1,y:2})  #here we are substituting the value of x and y with 1 and 2 respectively.


5.00000000000000

In [17]:
a = np.array([1,2,3])
f_symb = x ** 2
f_symb(a)

TypeError: 'Pow' object is not callable

In [18]:
from sympy.utilities.lambdify import lambdify
f_symb_numpy = lambdify(x,f_symb,'numpy')
f_symb_numpy(a)

array([1, 4, 9])

In [19]:
diff(x**3,x)  #here diff takes the function and the argument for calculating the derivatives

3*x**2

In [21]:
dfdx_symb = diff(f_symb,x)  #calculating the derivatives
dfdx_symb

2*x

In [24]:
dfdx_symb_numpy = lambdify(x,dfdx_symb,'numpy')  #making it numpy friendly
dfdx_symb_numpy(a)

array([2, 4, 6])

In [28]:
b = jnp.array([1,2,3,4])
type(b)
b = b.astype('float32')
b


Array([1., 2., 3., 4.], dtype=float32)

In [29]:
b[1] = 5

TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [30]:
b.at[1].set(5)  

Array([1., 5., 3., 4.], dtype=float32)

In [36]:
def f(x):
    return x**2
grad(f)(2.0)


Array(4., dtype=float32, weak_type=True)

In [39]:
try:
    grad(f)(b)  #calculating the derivative of jnp array
except TypeError as err:
    print(err)

Gradient only defined for scalar-output functions. Output had shape: (4,).


In [42]:
try:
    j = vmap(grad(f))(b)  #using the vmap api to calculate the derivative of array of larger size
except TypeError as err:
    print(err)   

In [43]:
j

Array([2., 4., 6., 8.], dtype=float32)

In [46]:
#from here we can see that automatic differentiation is way more faster compared to symbolic and numerical differentiation, as automatic differentiation uses the chain rule
import timeit, time


x_array_large = np.linspace(-5, 5, 1000000)  #larger array

tic_symb = time.time()
res_symb = lambdify(x, diff(f(x),x),'numpy')(b)    #derivatives using the symbol, passing the parameter b in the function f(b)
toc_symb = time.time()
time_symb = 1000 * (toc_symb - tic_symb)  # Time in ms.

tic_numerical = time.time()
res_numerical = np.gradient(f(b),b)  #derivatives using the numerical gradient
toc_numerical = time.time()
time_numerical = 1000 * (toc_numerical - tic_numerical)

tic_jax = time.time()
res_jax = vmap(grad(f))(jnp.array(b.astype('float32')))  #derivatives using the jax, by passing the array b converted into jnp array as parameter
toc_jax = time.time()
time_jax = 1000 * (toc_jax - tic_jax)

print(f"Results\nSymbolic Differentiation:\n{res_symb}\n" + 
      f"Numerical Differentiation:\n{res_numerical}\n" + 
      f"Automatic Differentiation:\n{res_jax}")

print(f"\n\nTime\nSymbolic Differentiation:\n{time_symb} ms\n" + 
      f"Numerical Differentiation:\n{time_numerical} ms\n" + 
      f"Automatic Differentiation:\n{time_jax} ms")

Results
Symbolic Differentiation:
[2. 4. 6. 8.]
Numerical Differentiation:
[3. 4. 6. 7.]
Automatic Differentiation:
[2. 4. 6. 8.]


Time
Symbolic Differentiation:
164.74342346191406 ms
Numerical Differentiation:
5.802392959594727 ms
Automatic Differentiation:
27.464866638183594 ms
