In [7]:
import numpy as np
import plotly.graph_objects as go
import pandas as pd

# Define the function and its gradient
def f(x, y):
    """Quadratic bowl f(x, y) = x**2 + y**2, the loss surface being minimised.

    Works on Python/NumPy scalars and, elementwise, on NumPy arrays.
    """
    x_sq = x ** 2
    y_sq = y ** 2
    return x_sq + y_sq

def gradient(x, y):
    """Analytic gradient of f(x, y) = x**2 + y**2.

    Returns a NumPy array [df/dx, df/dy] = [2x, 2y].
    """
    df_dx = 2 * x
    df_dy = 2 * y
    return np.array([df_dx, df_dy])

# Gradient descent parameters
learning_rate = 0.1
num_steps = 30
initial_point = np.array([1.0, -2.0])

# Run plain gradient descent. Each iteration logs the state *before* its
# update as one row: [step, x, y, f(x, y), df/dx, df/dy].
x, y = initial_point
points = []
for step in range(num_steps):
    grad = gradient(x, y)
    f_val = f(x, y)
    points.append([step, x, y, f_val, *grad])
    # Move against the gradient, scaled by the learning rate
    x, y = x - learning_rate * grad[0], y - learning_rate * grad[1]

# Stack the logged rows into a (num_steps, 6) NumPy array
points = np.array(points)

# Generate surface grid: evaluate f on a 100 x 100 mesh spanning [-2, 2]
# along both axes, used as the backdrop for the descent path.
grid_limit, grid_resolution = 2, 100
axis_values = np.linspace(-grid_limit, grid_limit, grid_resolution)
X, Y = np.meshgrid(axis_values, axis_values)
Z = f(X, Y)

# Build the 3D figure: a translucent loss surface plus the descent path.
# Columns 1, 2, 3 of `points` are x, y and f(x, y) for each logged step.
surface_trace = go.Surface(z=Z, x=X, y=Y, colorscale='Viridis', opacity=0.6)
path_trace = go.Scatter3d(
    x=points[:, 1],
    y=points[:, 2],
    z=points[:, 3],
    mode='markers+lines',
    marker=dict(size=5, color='red'),
)
fig = go.Figure(data=[surface_trace, path_trace])

fig.update_layout(
    title="Gradient Descent on a 3D Loss Surface",
    scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='f(X, Y)'),
)

fig.show()

# Convert the descent history to a DataFrame. `points` was built by
# np.array() from rows mixing the integer step counter with floats, so every
# column (including "Step") came out float64; cast "Step" back to int so the
# table shows 0, 1, 2, ... instead of 0.0, 1.0, 2.0, ...
df = pd.DataFrame(
    points, columns=["Step", "x", "y", "f(x, y)", "∇f_x", "∇f_y"]
).astype({"Step": int})

# Apply styling: centered, bordered cells, a bold green header row, and a
# coolwarm colour gradient over the objective values.
styled_df = (
    df.style
    .set_properties(**{
        'text-align': 'center',  # Center align text
        'border': '1px solid black',  # Add table borders
    })
    .set_table_styles([
        {'selector': 'th', 'props': [('font-weight', 'bold'), ('background-color', '#4CAF50'), ('color', 'white')]}
    ])
    .background_gradient(subset=["f(x, y)"], cmap="coolwarm")  # Gradient color on function values
)

# Display the styled DataFrame (the cell's last expression is rendered)
styled_df


Unnamed: 0,Step,x,y,"f(x, y)",∇f_x,∇f_y
0,0.0,1.0,-2.0,5.0,2.0,-4.0
1,1.0,0.8,-1.6,3.2,1.6,-3.2
2,2.0,0.64,-1.28,2.048,1.28,-2.56
3,3.0,0.512,-1.024,1.31072,1.024,-2.048
4,4.0,0.4096,-0.8192,0.838861,0.8192,-1.6384
5,5.0,0.32768,-0.65536,0.536871,0.65536,-1.31072
6,6.0,0.262144,-0.524288,0.343597,0.524288,-1.048576
7,7.0,0.209715,-0.41943,0.219902,0.41943,-0.838861
8,8.0,0.167772,-0.335544,0.140737,0.335544,-0.671089
9,9.0,0.134218,-0.268435,0.090072,0.268435,-0.536871
