## HOW DOES GRADIENT DESCENT WORK?

```python
# Helpers used throughout this explanation
import random
from scratch.gradient_descent import gradient_step
from scratch.complex_typing import Vector
from scratch.linear_algebra import distance
```

Imagine we have the sum-of-squares function:

$$
f(x, y, z) = x^2 + y^2 + z^2
$$
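As a point of reference, $f$ can be written directly in Python. This is a minimal sketch: the `Vector = List[float]` alias is an assumption standing in for the notebook's `Vector` type from `scratch.complex_typing`, and `sum_of_squares` is a name chosen for this example.

```python
from typing import List

# Assumption: a vector is a plain list of floats (stand-in for scratch's Vector)
Vector = List[float]

def sum_of_squares(v: Vector) -> float:
    """f(v) = sum of v_i ** 2, i.e. x^2 + y^2 + z^2 for a 3-vector."""
    return sum(v_i * v_i for v_i in v)

print(sum_of_squares([1.0, 2.0, 3.0]))  # 14.0
print(sum_of_squares([0.0, 0.0, 0.0]))  # 0.0, the minimum value of f
```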


We want to compute the **gradient** of the function to understand how it changes with respect to each variable. This is essential for gradient-based optimization, where the goal is to find the point that minimizes the function.

Differentiating $f$ with respect to each variable in turn gives its partial derivatives:
- $\frac{\partial f}{\partial x} = 2x$
- $\frac{\partial f}{\partial y} = 2y$
- $\frac{\partial f}{\partial z} = 2z$
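These formulas can be sanity-checked numerically with a forward difference quotient, $\left(f(v + h\,e_i) - f(v)\right)/h$, which approaches $\partial f / \partial v_i$ as $h \to 0$. For this $f$ the quotient equals $2v_i + h$ exactly (before floating-point rounding), so with a small $h$ it should land very close to $2v_i$. The helper `estimate_partial` and the test point are illustrative choices, not part of the original notebook.

```python
from typing import Callable, List

Vector = List[float]  # assumption: plain list of floats

def sum_of_squares(v: Vector) -> float:
    return sum(v_i * v_i for v_i in v)

def estimate_partial(f: Callable[[Vector], float], v: Vector, i: int, h: float) -> float:
    """Forward-difference estimate of the i-th partial derivative of f at v."""
    # Nudge only the i-th coordinate by h
    w = [v_j + (h if j == i else 0.0) for j, v_j in enumerate(v)]
    return (f(w) - f(v)) / h

point = [3.0, -1.0, 2.0]
estimates = [estimate_partial(sum_of_squares, point, i, h=1e-6) for i in range(3)]
exact = [2 * x for x in point]  # the analytic partials 2x, 2y, 2z

print(estimates)  # each entry within about 1e-6 of the exact value
print(exact)      # [6.0, -2.0, 4.0]
```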

These represent the slope of the function along each axis. The gradient vector combines them:

$$
\nabla f(x, y, z) = \left( 2x,\ 2y,\ 2z \right)
$$

This gradient points in the direction of steepest increase. To minimize the function, we repeatedly step in the opposite direction, toward the point where the gradient is zero; for this function that point is the origin, $(0, 0, 0)$.
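To see why the sign matters, the sketch below takes one small step against the gradient and one along it from the same point and compares the value of $f$. The step size `eta = 0.1` and the starting point are arbitrary illustrative values, and `Vector` is again assumed to be a plain list of floats.

```python
from typing import List

Vector = List[float]  # assumption: plain list of floats

def sum_of_squares(v: Vector) -> float:
    return sum(v_i * v_i for v_i in v)

def sum_of_squares_gradient(v: Vector) -> Vector:
    return [2 * v_i for v_i in v]

v = [3.0, -1.0, 2.0]
grad = sum_of_squares_gradient(v)  # [6.0, -2.0, 4.0]
eta = 0.1                          # illustrative step size

downhill = [v_i - eta * g_i for v_i, g_i in zip(v, grad)]  # against the gradient
uphill = [v_i + eta * g_i for v_i, g_i in zip(v, grad)]    # along the gradient

print(sum_of_squares(v))         # 14.0
print(sum_of_squares(downhill))  # smaller than 14.0
print(sum_of_squares(uphill))    # larger than 14.0
```

Here moving against the gradient scales the point by $1 - 2\eta = 0.8$, while moving along it scales it by $1.2$, so $f$ drops in the first case and grows in the second.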

Let’s define a function that returns this gradient vector:

```python
def sum_of_squares_gradient(v: Vector) -> Vector:
    """Gradient of f(v) = sum(v_i ** 2): each component is 2 * v_i."""
    return [2 * v_i for v_i in v]
```
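The notebook imports `gradient_step` from `scratch.gradient_descent`. If that package is not available, a stand-in can be written. This sketch assumes `gradient_step(v, gradient, step_size)` returns `v + step_size * gradient` element-wise; that assumption is consistent with the notebook's printed iterations, where each coordinate shrinks by a factor of 0.98 per step when the step size is -0.01 and the gradient is $2v$.

```python
from typing import List

Vector = List[float]  # assumption: plain list of floats

def gradient_step(v: Vector, gradient: Vector, step_size: float) -> Vector:
    """Return v + step_size * gradient (assumed behavior of scratch's gradient_step)."""
    assert len(v) == len(gradient), "v and gradient must have the same length"
    return [v_i + step_size * g_i for v_i, g_i in zip(v, gradient)]

def sum_of_squares_gradient(v: Vector) -> Vector:
    return [2 * v_i for v_i in v]

# A negative step size moves against the gradient, i.e. downhill
v = [1.0, -2.0, 3.0]
print(gradient_step(v, sum_of_squares_gradient(v), -0.01))  # roughly [0.98, -1.96, 2.94]
```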

Next, we pick a random starting point:

```python
# Random starting point with each coordinate in [-10, 10]
v = [random.uniform(-10, 10) for _ in range(3)]
print(v)
```

```
[6.2312768532071985, -9.861394979835412, -2.991835185564808]
```
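Because the starting point is random, rerunning the cell gives different numbers from those shown. If reproducible runs are wanted, the generator can be seeded first with `random.seed`; the seed value 42 below is arbitrary and is not something the original notebook uses.

```python
import random

random.seed(42)  # any fixed seed works; 42 is arbitrary
first = [random.uniform(-10, 10) for _ in range(3)]

random.seed(42)  # re-seeding replays the same sequence
second = [random.uniform(-10, 10) for _ in range(3)]

print(first == second)  # True
```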


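Then comes the core of the method: we iterate, and at each step we compute the gradient at the current point and move a small step against it. `gradient_step` adds `step_size * gradient` to `v`, so a step size of -0.01 moves downhill: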


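```python
n_obs = 0
print(f"obs:{n_obs} -> {v}")
for epoch in range(500):
    grad = sum_of_squares_gradient(v)   # gradient at the current point
    v = gradient_step(v, grad, -0.01)   # step of size 0.01 against the gradient
    n_obs += 1
    # Print only the first and last few iterations
    if n_obs <= 5 or n_obs >= 495:
        print(f"obs:{n_obs} -> {v}")
```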


```
obs:0 -> [6.2312768532071985, -9.861394979835412, -2.991835185564808]
obs:1 -> [6.106651316143054, -9.664167080238704, -2.931998481853512]
obs:2 -> [5.984518289820193, -9.47088373863393, -2.873358512216442]
obs:3 -> [5.864827924023789, -9.28146606386125, -2.815891341972113]
obs:4 -> [5.747531365543313, -9.095836742584025, -2.759573515132671]
obs:5 -> [5.632580738232447, -8.913920007732344, -2.7043820448300173]
obs:495 -> [0.00028280332742456005, -0.0004475543903830732, -0.00013578291665023996]
obs:496 -> [0.00027714726087606884, -0.00043860330257541175, -0.00013306725831723517]
obs:497 -> [0.0002716043156585475, -0.00042983123652390353, -0.00013040591315089048]
obs:498 -> [0.0002661722293453765, -0.0004212346117934255, -0.00012779779488787267]
obs:499 -> [0.000260848784758469, -0.000412809919557557, -0.0001252418389901152]
obs:500 -> [0.00025563180906329957, -0.0004045537211664058, -0.0001227370022103129]
```


After 500 steps, every coordinate of `v` is within 0.0005 of zero, so `v` is close to [0, 0, 0], the point where the function reaches its minimum.
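The steady shrinking has a simple explanation for this particular function. With step size 0.01 and gradient $2v$, each update is $v_{k+1} = v_k - 0.01 \cdot 2v_k = 0.98\,v_k$, so after $n$ steps $v_n = 0.98^n\,v_0$. For $n = 500$, $0.98^{500} \approx 4.1 \times 10^{-5}$, which matches the first coordinate going from about 6.23 to about 0.000256. The sketch below compares the loop against this closed form; it reuses the assumed stand-in for `gradient_step` (element-wise `v + step_size * gradient`) so it runs without the `scratch` package.

```python
import random
from typing import List

Vector = List[float]  # assumption: plain list of floats

def gradient_step(v: Vector, gradient: Vector, step_size: float) -> Vector:
    # Assumed behavior: v + step_size * gradient, element-wise
    return [v_i + step_size * g_i for v_i, g_i in zip(v, gradient)]

def sum_of_squares_gradient(v: Vector) -> Vector:
    return [2 * v_i for v_i in v]

v0 = [random.uniform(-10, 10) for _ in range(3)]

# Iterative gradient descent, mirroring the notebook's loop
v = v0
for _ in range(500):
    v = gradient_step(v, sum_of_squares_gradient(v), -0.01)

# Closed form: every step multiplies each coordinate by 1 - 2 * 0.01 = 0.98
closed_form = [0.98 ** 500 * x for x in v0]

print(max(abs(a - b) for a, b in zip(v, closed_form)))  # tiny, only rounding error
```

Since every starting coordinate lies in $[-10, 10]$, the final point is always within about $10 \times 4.1 \times 10^{-5}$ of zero per coordinate, regardless of where it started.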