In [7]:
#import the libraries
from mpl_toolkits import mplot3d
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt
from sympy import symbols, diff, lambdify 
 


In [4]:
# Define the sympy symbols to be used in the function
x = symbols('x')
y = symbols('y')
#Define the function in terms of x and y
f1 = (x-2) ** 2 + (y-2)**2+5
 
# Calculate the partial derivatives of f1 w.r.t. x and y 
f1x = diff(f1,x)
f1y = diff(f1,y)

In [8]:
# Define a function optimized for numpy array calculation
# in sympy 
f = lambdify([x,y],f1,'numpy')

In [25]:
x_grid = np.linspace(-3, 3, 30)
y_grid = np.linspace(-3, 3, 30)

In [26]:
# Create mesh grid for surface plot
X, Y = np.meshgrid(x_grid,y_grid)

In [28]:
#Define the surface function using the lambdify function 
Z = f(X, Y )

In [30]:
#Select a start point
x0,y0 = (3,3)
#Initialize a list for storing the gradient descent points
xlist = [x0]
ylist = [y0]

In [31]:
 
#Specify the learning rate
lr=0.001

In [32]:
#Perform gradient descent
for i in range(100): 
    # Update the x and y values using the negative gradient values
    x0-=f1x.evalf(subs={x:x0,y:y0})*lr
    y0-=f1y.evalf(subs={x:x0,y:y0})*lr
    # Append to the list to keep track of the points
    xlist.append(x0)
    ylist.append(y0)

In [33]:
xarr = np.array(xlist,dtype='float64')
yarr = np.array(ylist,dtype='float64')   

In [35]:
zlist = list(f(xarr,yarr))

In [38]:
#Plot the surface and points
ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
                cmap='viridis', edgecolor='none')

ax.plot(xlist,ylist,zlist,'ro',markersize=10,alpha=0.6)
ax.set_title('Gradient Descent');
ax.set_aspect('equal')

In [41]:
xlist 

[3,
 2.99800000000000,
 2.99600400000000,
 2.99401199200000,
 2.99202396801600,
 2.99003992007997,
 2.98805984023981,
 2.98608372055933,
 2.98411155311821,
 2.98214333001197,
 2.98017904335195,
 2.97821868526525,
 2.97626224789472,
 2.97430972339893,
 2.97236110395213,
 2.97041638174422,
 2.96847554898074,
 2.96653859788277,
 2.96460552068701,
 2.96267630964563,
 2.96075095702634,
 2.95882945511229,
 2.95691179620207,
 2.95499797260966,
 2.95308797666444,
 2.95118180071111,
 2.94927943710969,
 2.94738087823547,
 2.94548611647900,
 2.94359514424604,
 2.94170795395755,
 2.93982453804964,
 2.93794488897354,
 2.93606899919559,
 2.93419686119720,
 2.93232846747480,
 2.93046381053985,
 2.92860288291878,
 2.92674567715294,
 2.92489218579863,
 2.92304240142703,
 2.92119631662418,
 2.91935392399093,
 2.91751521614295,
 2.91568018571066,
 2.91384882533924,
 2.91202112768856,
 2.91019708543319,
 2.90837669126232,
 2.90655993787980,
 2.90474681800404,
 2.90293732436803,
 2.90113144971929,
 2.89932

In [42]:
ylist

[3,
 2.99800000000000,
 2.99600400000000,
 2.99401199200000,
 2.99202396801600,
 2.99003992007997,
 2.98805984023981,
 2.98608372055933,
 2.98411155311821,
 2.98214333001197,
 2.98017904335195,
 2.97821868526525,
 2.97626224789472,
 2.97430972339893,
 2.97236110395213,
 2.97041638174422,
 2.96847554898074,
 2.96653859788277,
 2.96460552068701,
 2.96267630964563,
 2.96075095702634,
 2.95882945511229,
 2.95691179620207,
 2.95499797260966,
 2.95308797666444,
 2.95118180071111,
 2.94927943710969,
 2.94738087823547,
 2.94548611647900,
 2.94359514424604,
 2.94170795395755,
 2.93982453804964,
 2.93794488897354,
 2.93606899919559,
 2.93419686119720,
 2.93232846747480,
 2.93046381053985,
 2.92860288291878,
 2.92674567715294,
 2.92489218579863,
 2.92304240142703,
 2.92119631662418,
 2.91935392399093,
 2.91751521614295,
 2.91568018571066,
 2.91384882533924,
 2.91202112768856,
 2.91019708543319,
 2.90837669126232,
 2.90655993787980,
 2.90474681800404,
 2.90293732436803,
 2.90113144971929,
 2.89932