In [32]:
# Ran pip install -U torch in a terminal
# pip install -U jax did not install a working version of jax
import torch
from torch.autograd.functional import jacobian

from scipy.optimize import fsolve
def f(x):
    """ Gen Rosenbrock function.
    Args: x (torch.Tensor): Input tensor of shape (n).
    Returns: torch.Tensor: Output tensor of shape (1).
    """
    return torch.sum(100.0 * (x[1:] - x[:-1]**2)**2 + (1 - x[:-1])**2, dim=-1)

def gradf(x):
    """ Gradient of Gen Rosenbrock function.
    Args: x Input vector of length n.
    conversions are needed for the torch AD engine
    Returns: vector length n .
    """
    xx=torch.tensor(x, requires_grad=True)
    yy=f(xx)
    yy.backward()
    return (xx.grad).numpy()

x0 = [-1.0, 1.2, 1.0, 0.1,0.4,0.1,0.2,0.3]
print(gradf(x0))
print()
# Solve the system of equations
roots = fsolve(gradf, x0, xtol=1.49012e-8)
roots

[  76.000015  251.60004   272.       -197.4        86.4       -21.400002
   15.599998   52.000004]


array([-0.99290939,  0.99590476,  0.99684422,  0.99619948,  0.99364965,
        0.98792973,  0.97624273,  0.95304988])

###### 