# Intro to Gradient Descent (and computational science)

**Prepared by Dan Hackett (2025)**

This exercise is an introduction to computational science using Python and Jupyter notebooks, and secretly also an introduction to coding if you haven't done much before. The best way to learn how to use a tool is to try to use it to do something. So, as a well-motivated excuse to learn, we're going to implement "gradient descent", the algorithm at the heart of nearly all training in modern AI/ML.

Gradient descent is an *optimization algorithm* or *optimizer*. Its objective is to find the parameters which minimize a "loss function".

Gradient descent is easy to understand from a physics analogy: we're just going to simulate something sliding down a hill to the lowest point it can find. In the analogy, the "loss function" is just the height of the hill as a function of the position $x$. We'll call this function $V(x)$.

**Note:** gradient descent isn't *quite* the same as e.g. a rollercoaster on a track or a ball rolling down a hill. That's because in gradient descent, the thing that's moving doesn't have any momentum. So, it won't pick up speed as it goes downhill the way a rolling ball would. Instead, the "speed" it moves with just depends on how steep the hill is. It's more like something sliding with a lot of friction.

Here's an example:
$$
    V(x) = -0.1x - x^2 + x^4
$$
We'll work with this one all through the notebook.

# Plotting V(x) the hard way

First, we're going to slowly build up to making a plot of $V(x)$ vs $x$. In this first section, we'll do things the hard way to review some basic programming concepts. Later on, we'll see how more powerful tools (specifically, numpy) can make this much easier.

## Exercise 1

Okay, let's get started.

**Write a line of Python code which computes $V(x)$.** We've set the variable `x` to a value already; write your code to compute $V(x)$ from this variable.

**HINT:** in Python, you raise things to a power by using `**`, not `^`. So, to compute $x^2$, you do `x**2`.

In [None]:
x = 1
V = ?????

In [None]:
# run this cell to check your answer numerically
import numpy as np
assert np.isclose(V, -0.1)

In [None]:
# SOLUTION
x = 1
V = -0.1*x - x**2 + x**4

Now, you have some code that should be able to compute $V(x)$ for any arbitrary value of $x$. But we need to check that it's right!

To check your expression for `V`, we want to use it to make a plot of $V(x)$ as a function of $x$ so we can see whether it looks right. To do that, we first need to evaluate $V(x)$ on a grid of different values of $x$. We'll walk through how to do this a few different ways.

First, let's explore how to do this using `list`s, `for` loops, and the "accumulator" pattern.

## Review: Collections and `list`

Let's see how to work with `list`, one of Python's basic built-in collections. A collection is an object that contains other objects. A `list` stores an ordered sequence of objects. (Despite the name, it isn't a linked list; under the hood it's a resizable array.) In Python, lists are written with square brackets and commas, like `[3,1,2]`. An empty list is `[]`.

Run the cells below to see how lists work.

In [None]:
x = [3,1,2] # make a new list with three items
print(x)
x.append(7) # add an item to the list
print(x)

In [None]:
x = [] # make an empty list
x.append(1) # add an item
print(x)
x.append(5) # add an item
print(x)

x = [6] # make a new list, overwrite old one
print(x)
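Later in the notebook we'll pick individual items out of a list by their position (for example `all_x[-1]`) and count items with `len`. Here's a quick sketch of how that works, using a throwaway list:

```python
x = [3, 1, 2, 7]
print(len(x))   # number of items: 4
print(x[0])     # first item; Python counts positions from 0: 3
print(x[-1])    # negative positions count back from the end: 7
x[1] = 10       # replace the item at position 1
print(x)        # [3, 10, 2, 7]
```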

## Review: `for` loops, range

`for` loops let you run the same block of code repeatedly, except with some variable which changes its value in each "iteration" of the loop, so it doesn't just do exactly the same thing every time.

In Python, `for` loops run over all the items in a collection, so you can think of them as running code "for each" item in the collection, in order.

Play with the example code below to see what a `for` loop does.

In [None]:
for x in [5,3,7]: # try adding or removing items from the list
    print(x)

In [None]:
a_list = [4,2,1]
for x in a_list:
    x = x+1
    print(x)
print(a_list) # doesn't modify items in original list!

Python has a built-in helper function `range` for doing `for` loops over sequential integers. For example, `range(5)` lets you loop over `0,1,2,3,4`. Note that the sequence stops *before* the number you give it: `range(5)` does not include 5.

Play with the example code below to see what `range` does.

In [None]:
for n in range(10): # try changing 10 to other numbers
    print(n)

In [None]:
for n in range(-5,5): # can start from values other than zero
    print(n)

In [None]:
for n in range(-10,10,2): # bigger steps
    print(n)
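If you want to see every value a `range` produces at once, you can turn it into a list with `list(...)`. This is a handy way to double-check the "stops before the end" rule (the variable names here are just for illustration):

```python
up_to_five = list(range(5))
from_minus_five = list(range(-5, 5))
even_steps = list(range(-10, 10, 2))
print(up_to_five)       # [0, 1, 2, 3, 4]  (5 itself is not included)
print(from_minus_five)  # [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
print(even_steps)       # [-10, -8, -6, -4, -2, 0, 2, 4, 6, 8]
```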

## Concept: accumulator pattern

A useful way to combine `for` and `list` is the "accumulator" pattern. This lets us build up (accumulate) a list full of values following a certain pattern.

Play with the code below to understand the pattern.

In [None]:
a_list = []
for i in range(5):
    x = 3 * i**2 + 1
    a_list.append(x)
print(a_list)
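Python also has a one-line shorthand for exactly this pattern, called a *list comprehension*. You won't need it in this notebook, but you'll run into it in other people's code. The sketch below builds the same list both ways and checks that they match:

```python
# Accumulator pattern (same as the cell above)
accumulated = []
for i in range(5):
    x = 3 * i**2 + 1
    accumulated.append(x)

# Equivalent list comprehension: "the value 3*i**2 + 1, for each i in range(5)"
comprehended = [3 * i**2 + 1 for i in range(5)]

print(accumulated)                  # [1, 4, 13, 28, 49]
print(comprehended == accumulated)  # True
```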

## Exercise 2

Now, let's use this pattern to make a list of values of $x$. Specifically, we want 101 evenly-spaced values between -1 and 1. The first value should be -1, the last value should be 1.

**NOTE:** you might see some round-off error!

In [None]:
all_x = []

# TODO

print(all_x)

In [None]:
# SOLUTION
all_x = []
for i in range(101):
    x = -1 + i/50
    all_x.append(x)
print(all_x)

Let's check whether it worked. Often it's useful to make a quick plot just to see whether your code ran correctly. 

**Run the code below and see if it makes sense!** This should also remind you how to make a simple plot with matplotlib.

**NOTE:** if you just call `plt.plot` with one argument like `plt.plot(x)`, it will plot the values in `x` on the y-axis versus `0,1,...,len(x)-1` on the x-axis.

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.plot(all_x, '.')

We can also check the values directly. Run the code below to check if you got it right.

If the checks below all pass, then you've definitely got the right answer. Can you see why?

In [None]:
import numpy as np
assert len(all_x) == 101 # check number of items
assert all_x[0] == -1 # check first item
assert np.isclose(all_x[-1], 1) # check last item (allowing for round-off)
# even spacing
dx = all_x[1] - all_x[0]
for i in range(100):
    # This code won't work because of round-off error:
    #assert (all_x[i+1] - all_x[i]) == dx
    # Use helper function from numpy for floating-point comparisons instead:
    assert np.isclose((all_x[i+1] - all_x[i]), dx)
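Why `np.isclose` instead of `==`? Most decimal fractions (like 0.1 or 0.02) can't be stored exactly as binary floating-point numbers, so arithmetic on them picks up tiny round-off errors. A classic illustration:

```python
import numpy as np

total = 0.1 + 0.2
print(total)                   # 0.30000000000000004
print(total == 0.3)            # False: round-off makes the exact comparison fail
print(np.isclose(total, 0.3))  # True: equal to within a small tolerance
```

`np.isclose` treats two numbers as equal if they agree to within a small tolerance, which is usually what you want when comparing results of floating-point arithmetic.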

## Exercise 3

Now, let's put it all together! Go back and grab your code from above and use it to **write a new accumulator pattern which computes $V(x)$** for each one of these values of $x$, and stores it in a list `all_V`.

In [None]:
all_V = []

# TODO (hint: do some copy-pasting!)

print(all_V)

In [None]:
# SOLUTION
all_V = []
for x in all_x:
    V = -0.1*x - x**2 + x**4
    all_V.append(V)
print(all_V)

Now, run the code below to make a plot of $V(x)$ vs $x$. **Does this look right?**

In [None]:
plt.plot(all_x, all_V, '.')

# Plotting V(x) the easy way

That was a lot of work just to make one plot! Fortunately, there are much faster and easier ways to do all this using NumPy, the standard third-party package for numerical computing in Python (conventionally abbreviated as `np`).

If we hadn't already imported numpy for some of the tests above, you'd run the cell below to import numpy in the standard way. Running it again is unnecessary, but won't break anything (or do anything at all).

In [None]:
import numpy as np

## Review: numpy

NumPy is a library for creating and manipulating N-dimensional arrays (`ndarray`s) of numbers. There are many ways to create NumPy arrays, for example by passing a list to the `np.array` function.

Arithmetic on arrays is *elementwise* (or *vectorized*): a math operation gets applied to every element in the array. When you combine an array with a single number, NumPy *broadcasts* the number across the whole array, so `2*x` multiplies every element by 2.

**Play with the code below to see how this works.**

In [None]:
x = np.array([1,2,3]) # make an array from a list
print(x)

x = np.arange(5) # like range, but makes a numpy array
print(x)

In [None]:
x = np.array([1,2,3])
print(2*x) # multiply every number by 2
print(x**2) # square every number

In [None]:
# operations involving two different arrays
x = np.array([3,1,2])
y = np.array([2,3,5])
print(x*y)
print(x+y)
print(x-y)
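The examples above combine arrays of the same length. *Broadcasting* is NumPy's rule for combining arrays of different shapes: a single number, or an axis of length 1, gets stretched to match the other array. A small sketch (not needed for the rest of the notebook):

```python
import numpy as np

x = np.array([3, 1, 2])            # shape (3,)
print(x + 10)                      # 10 is stretched to [10, 10, 10]: [13 11 12]

col = np.array([[0], [10], [20]])  # shape (3, 1): a column
table = col + x                    # (3, 1) combined with (3,) gives shape (3, 3)
print(table.shape)                 # (3, 3)
print(table)
# [[ 3  1  2]
#  [13 11 12]
#  [23 21 22]]
```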

One very useful helper function is `np.linspace`, which makes an evenly-spaced grid of values.

**Run the cell below for an example.** How does this compare with `all_x` we computed above?

In [None]:
x = np.linspace(-1,1,101)
print(x)
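`np.linspace(start, stop, num)` includes both endpoints, so neighboring points are $(\text{stop}-\text{start})/(\text{num}-1)$ apart. With 101 points from -1 to 1 that's $2/100 = 0.02$, the same grid the accumulator loop in Exercise 2 built. The sketch below rebuilds that list (under the hypothetical name `loop_x`, so it doesn't clobber your `all_x`) and compares:

```python
import numpy as np

grid = np.linspace(-1, 1, 101)

loop_x = []                       # same accumulator loop as Exercise 2
for i in range(101):
    loop_x.append(-1 + i/50)

print(grid[1] - grid[0])          # spacing, approximately 0.02
print(np.allclose(grid, loop_x))  # True: same values, up to round-off
```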

## Exercise 4

**Now, take this `x` made with `linspace` and run your code from Exercise 1 to compute `V`.** What happens?

In [None]:
x = np.linspace(-1,1,101)

V = # TODO: copy-paste your code from exercise 1

print(V)

In [None]:
# SOLUTION
x = np.linspace(-1,1,101)
V = -0.1*x - x**2 + x**4
print(x,V)

**Now, use these to make a plot of $V$ vs $x$.**

In [None]:
plt.plot(#TODO)

In [None]:
# SOLUTION
plt.plot(x,V)

# Packaging up V(x)

By now, you should have a bit of code that can compute $V(x)$ correctly. But you don't want to have to copy-paste this everywhere; that's a very easy way to make mistakes! Instead, we're going to package your code into a function that we can call later on.

## Review: functions

A function is a bundle of code that you can run over and over with different inputs.

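Functions aren't exactly the same in code as in math. Like functions in math, they can evaluate to (return) a value which depends on their arguments. Unlike functions in math, they can also do other things along the way (like printing output), and they don't have to return anything at all.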

The code "inside" a function is an indented block, just like with `if`, `while`, and `for`. This code runs whenever you call the function.

**Play with the code below to remember how defining a function works.**

In [None]:
def multiply(a,b):
    y = a*b # run some code
    return y # return value that we computed
print(multiply(2,4))
print(multiply(9,7))
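Arguments can also have *default values*, which get used when the caller leaves that argument out; the bonus section at the end uses this for the noise level `sigma`. A sketch with a made-up function `power`:

```python
def power(a, n=2):
    # n defaults to 2 if the caller doesn't pass it
    return a**n

print(power(3))       # uses the default n=2: 9
print(power(3, 3))    # 27
print(power(2, n=5))  # arguments can also be passed by name: 32
```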

## Exercise 5

Write a function to compute $V(x)$ named `calc_V` which takes a single argument `x` and returns $V(x)$ for that value.

In [None]:
def calc_V(x):
    # TODO (HINT: copy-paste your code here)

In [None]:
# SOLUTION
def calc_V(x):
    return -0.1*x - x**2 + x**4

Now, **call `calc_V` to compute $V$, and make a plot of $V(x)$ vs $x$** to check everything worked okay!

In [None]:
x = np.linspace(-1,1,101)

V = #TODO

In [None]:
plt.plot(#TODO)

In [None]:
# SOLUTION
x = np.linspace(-1,1,101)
V = calc_V(x)
plt.plot(x, V, '.-')
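Because `calc_V` only does arithmetic, the same function works whether `x` is a single number or a whole NumPy array. A quick sanity check, with `calc_V` redefined here (matching the solution above) so the cell runs on its own; the first value should match your answer from Exercise 1:

```python
import numpy as np

def calc_V(x):
    return -0.1*x - x**2 + x**4

print(calc_V(1))                           # a single number in, a single number out (about -0.1)
print(calc_V(np.array([-1.0, 0.0, 1.0])))  # an array in, an array out (about [0.1, 0, -0.1])
```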

# Derivatives and finite differences

One last concept!

In order to implement gradient descent, we don't just need $V(x)$, we also need its derivative. This gets written $V'(x)$ if you follow Lagrange, or $\frac{dV}{dx}$ if you follow Leibniz. In case you don't remember what a derivative is: $V'(x)$ is just the slope of $V$ at the point $x$.

We're going to take the derivative the pencil-and-paper way below, but first, let's make the computer estimate it numerically. This will also give us something to check our pencil-and-paper answer against.

Derivatives are easy to approximate numerically as a "finite difference":
$$
    V'(x) = \frac{dV}{dx} \approx \frac{V(x+\epsilon) - V(x)}{\epsilon}
$$
This just computes the slope by comparing $V(x)$ with its value a small step $\epsilon$ away. If we take $\epsilon \rightarrow 0$, it becomes $V'(x)$ exactly. (But if we're doing it numerically, we have to pick some small but nonzero $\epsilon$!)
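To watch the approximation improve as $\epsilon$ shrinks, here's a sketch that evaluates the finite difference at the single point $x = 0.5$ for a few values of $\epsilon$. (It redefines `calc_V` to match the solution above so the cell runs on its own.) The estimates should settle down toward about $-0.6$, which you can confirm once you've worked out $V'(x)$ by hand in Exercise 7.

```python
def calc_V(x):
    return -0.1*x - x**2 + x**4

x0 = 0.5
epsilons = [0.1, 0.01, 0.001, 0.0001]
slopes = []
for eps in epsilons:
    # forward finite difference: slope between x0 and a point eps to the right
    slope = (calc_V(x0 + eps) - calc_V(x0)) / eps
    slopes.append(slope)
    print(eps, slope)
```

Don't push $\epsilon$ too far, though: at something like $\epsilon = 10^{-15}$, $V(x+\epsilon)$ and $V(x)$ agree in nearly every digit, so round-off error in their difference starts to dominate.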

## Exercise 6

**Compute a finite difference approximation of $V'(x)$ on a grid of $x$ and make a plot of it.**

**Note:** there are many ways to do this!

In [None]:
# TODO

In [None]:
# SOLUTION (e.g.)
x = np.linspace(-1,1,100)
V = calc_V(x)
dVdx = np.diff(V) / np.diff(x)
plt.plot(x[:-1], dVdx, '.-')
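The solution above leans on `np.diff`, which subtracts each element from the next one. Applied to `V` and `x`, that gives the numerator and denominator of the finite-difference formula for every pair of neighboring grid points at once. The result has one fewer element than the input, which is why it's plotted against `x[:-1]` (every point except the last). A tiny example:

```python
import numpy as np

v = np.array([1, 4, 9, 16])
print(np.diff(v))               # [3 5 7]: v[1]-v[0], v[2]-v[1], v[3]-v[2]
print(len(v), len(np.diff(v)))  # 4 3
print(v[:-1])                   # [1 4 9]: drop the last element so the lengths match
```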

## Exercise 7

Fortunately, thanks to the miracle of calculus, we can just take the derivative analytically and get a new function we can evaluate separately. Recall that our function is:
$$
    V(x) = -0.1x - x^2 + x^4
$$
Its derivative is just:
$$
V'(x) = \frac{dV}{dx} = -0.1 - 2x + 4x^3
$$
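In case that derivative seems to come out of nowhere: it's the power rule, $\frac{d}{dx}x^n = n x^{n-1}$, applied to each term separately:
$$
\begin{aligned}
    \frac{d}{dx}\left(-0.1x\right) &= -0.1 \\
    \frac{d}{dx}\left(-x^2\right) &= -2x \\
    \frac{d}{dx}\left(x^4\right) &= 4x^3
\end{aligned}
$$
Adding the three pieces gives $V'(x) = -0.1 - 2x + 4x^3$.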

Now, **write a function `calc_dVdx` that takes a single argument `x` and returns $V'(x)$.**

In [None]:
def calc_dVdx(x):
    # TODO

In [None]:
# SOLUTION
def calc_dVdx(x):
    return -0.1 - 2*x + 4*x**3

**Call `calc_dVdx` to compute $V'(x)$ on a grid and make a plot of $V'(x)$ vs $x$.** Does it look right?

In [None]:
x = np.linspace(-1,1,101)

dVdx = # TODO

plt.plot(#TODO)

In [None]:
# SOLUTION
x = np.linspace(-1,1,101)
V = calc_V(x)
dVdx = calc_dVdx(x)
plt.plot(x, V)
plt.plot(x, dVdx)

## Exercise 8

Now, we've implemented $V'(x)$ two different ways. Let's make sure we get the same answer from both! If they don't agree, then at least one of them must be wrong.

**Make a plot with your two different versions of `V'` on it.** Do they look the same?

In [None]:
# HINT: calling plt.plot twice will put two lines on the same plot
plt.plot(#TODO: finite difference version)
plt.plot(#TODO: analytic version)

In [None]:
# SOLUTION
x = np.linspace(-1,1,101)
V = calc_V(x)
dVdx = calc_dVdx(x)
plt.plot(x, dVdx)
plt.plot(x[:-1], np.diff(V)/np.diff(x))
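Eyeballing a plot is a good first check, but you can also compare the two versions numerically. A sketch, assuming `calc_V` and `calc_dVdx` as in the solutions above (redefined here so the cell runs on its own). One subtlety: the forward difference between two neighboring grid points is a much better approximation of the slope *halfway between them* than of the slope at the left-hand point, so comparing against $V'$ at the midpoints gives much closer agreement:

```python
import numpy as np

def calc_V(x):
    return -0.1*x - x**2 + x**4

def calc_dVdx(x):
    return -0.1 - 2*x + 4*x**3

x = np.linspace(-1, 1, 101)
fd = np.diff(calc_V(x)) / np.diff(x)  # finite-difference slopes, one per pair of neighbors
midpoints = (x[:-1] + x[1:]) / 2      # halfway between neighboring grid points

err_left = np.max(np.abs(fd - calc_dVdx(x[:-1])))    # compare at left-hand points
err_mid = np.max(np.abs(fd - calc_dVdx(midpoints)))  # compare at midpoints
print(err_left, err_mid)              # the midpoint error should be much smaller
```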

# Gradient descent

Okay, now we have all the ingredients we need to implement gradient descent!

The steps of the gradient descent algorithm are:
1. Pick some starting value $x$.
2. Take a step: update $x$ according to the rule
   $$
       x \rightarrow x - \eta V'(x)
   $$
   where $\eta$ is a "hyperparameter" called the step size (in ML it's usually called the *learning rate*).
3. Repeat step 2 many times, until $x$ settles down at a minimum.

We said before that gradient descent is like something sliding down a hill. You can think of this rule as defining how its position $x$ evolves in time, like $x(t)$. Each iteration of gradient descent is like a fixed step forward in time.
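To make the time-evolution picture concrete: suppose the sliding object moves with a velocity equal to the downhill slope, $\frac{dx}{dt} = -V'(x)$ (no momentum, as in the note at the top). Stepping forward by a small time $\Delta t$ with the simplest possible update (Euler's method) gives
$$
    x(t + \Delta t) \approx x(t) + \Delta t\,\frac{dx}{dt} = x(t) - \Delta t\,V'(x(t)),
$$
which is exactly the gradient descent rule with $\eta = \Delta t$. In this picture, the step size is the time step.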

## Exercise 9

It's a good idea to write your code "from the inside out". So, before we get to running many steps, let's just write code to run one step.

**Write the code to update the value of $x$ (step 2 above).** Use the variable `step_size` for $\eta$. 

**HINT:** you'll need to call your $V'(x)$ function.

In [None]:
x = 0
step_size = 0.1 # "eta" by another name

print(x)

# TODO: update x

print(x)

In [None]:
# SOLUTION
x = 0
step_size = 0.1 # "eta" by another name

print(x)
dVdx = calc_dVdx(x)
x = x - step_size * dVdx
print(x)

In [None]:
# Run this cell to check your answer!
assert np.isclose(x, 0.01)

## Exercise 10

Now, on to actually implementing it! What you need to do is:
* Pick some starting value for `x`
* Make a for loop to run 100 steps

Inside the loop, 
* Update `x`
* Save the new value of `x` inside a list named `history`

At the end, we should have a list `history` that tells us how `x` evolved under gradient descent. `history` is like $x(t)$.

In [None]:
history = []
step_size = 0.05

# TODO: set initial x value

# TODO: write the loop

In [None]:
# SOLUTION
history = []
step_size = 0.05

x = -0.5 # starting value

for n in range(100):
    dVdx = calc_dVdx(x)
    x = x - step_size * dVdx
    history.append(x)

Now, run the first cell below to make a plot of $V(x)$ vs $x$ (for reference/comparison), and the second cell to make a plot of how $x(t)$ evolved under gradient descent.

Does the value that $x(t)$ approaches make sense? Why or why not?

In [None]:
x = np.linspace(-1,1,1024)
V = calc_V(x)
plt.plot(x, V)

In [None]:
plt.plot(history)

## Exercise 11

Now we're going to look at all the different ways that gradient descent can break! This will be more of a free-form exploration.

You may have noticed that it ended up in the "wrong" one of the two minima. This is a *local* minimum, not the *global* minimum where the loss is actually smallest. Getting stuck in a local minimum is a well-known failure mode of gradient descent! Let's see if you can fix it.

The knobs you can turn are the starting value of $x$ and the step size. Try:
* Making the step size smaller and smaller.
* Making the step size larger and larger.
* Picking different starting values of $x$.

What happens in each case?
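A rough rule of thumb for the "larger and larger" experiment: near a minimum $x^*$, $V'(x) \approx V''(x^*)(x - x^*)$, so each step multiplies the distance to the minimum by about $1 - \eta V''(x^*)$. That distance shrinks only if $0 < \eta V''(x^*) < 2$, i.e. $\eta < 2/V''(x^*)$. The sketch below estimates that threshold for each minimum. It assumes the second derivative $V''(x) = -2 + 12x^2$ (worked out by hand from $V'(x)$) and uses `np.roots` to find where $V'(x) = 0$. Treat it as a local estimate only: it says nothing about what happens far from a minimum.

```python
import numpy as np

def calc_d2Vdx2(x):
    # second derivative of V(x) = -0.1x - x^2 + x^4
    return -2 + 12*x**2

# V'(x) = 4x^3 + 0x^2 - 2x - 0.1; np.roots takes coefficients from the highest power down
critical_points = np.sort(np.roots([4, 0, -2, -0.1]).real)

thresholds = {}
for x_star in critical_points:
    curvature = calc_d2Vdx2(x_star)
    if curvature > 0:  # positive curvature means this critical point is a minimum
        thresholds[x_star] = 2 / curvature
        print(f"minimum near x = {x_star:.3f}: step size must be below about {2/curvature:.3f}")
```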





In [None]:
# HINT: it can be a good idea to collect all the code in one cell here,
# so that you don't have to scroll up and rerun a bunch of cells to play around.
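One way to follow that hint is to wrap the loop from Exercise 10 in a function whose arguments are the knobs you want to turn. A sketch: the name `run_gradient_descent` is just a suggestion, and `calc_dVdx` is redefined here (matching the Exercise 7 solution) so the cell runs on its own.

```python
def calc_dVdx(x):
    return -0.1 - 2*x + 4*x**3

def run_gradient_descent(start_x, step_size, n_steps=100):
    # Run gradient descent from start_x and return the list of visited x values
    x = start_x
    history = []
    for n in range(n_steps):
        x = x - step_size * calc_dVdx(x)
        history.append(x)
    return history

# Compare a few step sizes from the same starting point
for step_size in [0.01, 0.05, 0.2]:
    history = run_gradient_descent(-0.5, step_size)
    print(step_size, history[-1])
```

To see the whole trajectory instead of just the final value, pass the returned list to `plt.plot`.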

# Bonus: Stochastic gradient descent

In real-life ML, we usually can't afford to compute $V'(x)$ exactly: the loss is typically an average over the entire training dataset (e.g. a big chunk of the internet), so computing the exact gradient at every step would be far too slow. Instead, we estimate it from a small random sample of the data, which gives only a noisy approximation. If you only have a noisy estimate of $V'(x)$, then you are doing "stochastic gradient descent" (SGD) instead of the ordinary kind.

We can make a toy model of SGD very easily. NumPy has built-in random number generators you can use. The cell below draws a Gaussian random number (i.e., a number taken from the normal distribution or "bell curve"). Run it a couple of times and convince yourself it's random.

In [None]:
np.random.normal(0, 0.1) # first parameter is the mean, second parameter is standard deviation

To emulate SGD, we can just take `calc_dVdx` from earlier and add a little bit of random noise to it.

**Run the cell below.** This function is a "wrapper" around your `calc_dVdx` from earlier. It'll compute $V'(x)$, add a bit of random noise to it, then return the noised-up value. `sigma` is the standard deviation of the noise; passing `sigma=0` won't add any noise, passing `sigma=10000` will add a lot.

In [None]:
def calc_noisy_dVdx(x, sigma=0.5):
    dVdx = calc_dVdx(x)
    dVdx += np.random.normal(0,sigma)
    return dVdx

## Exercise 12

Now, take your gradient descent implementation from above, copy-paste it down here, and change where it calls `calc_dVdx` to call `calc_noisy_dVdx` instead. Run it and see what happens!

Once you've got it working, try playing around with the noise level `sigma` (you can pass this as an extra argument to `calc_noisy_dVdx`) and the other parameters. Some things to think about:
* With a fixed step size, SGD never settles exactly into a minimum; it just keeps bouncing around near one. Any ideas how you could fix this?
* What happens if you make the noise level too high?
* SGD can hop between the different minima in $V(x)$. Try to get this to happen!

If you get frustrated that the noise changes every time you run the cell, use `np.random.seed` to seed the random number generator (RNG) before you run your loop.
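Seeding resets NumPy's random number generator to a fixed starting state, so the "random" numbers repeat exactly from run to run. A quick demonstration (the seed 42 is arbitrary):

```python
import numpy as np

np.random.seed(42)
first = np.random.normal(0, 0.5, size=3)   # three noise values

np.random.seed(42)                         # reset the generator to the same state
second = np.random.normal(0, 0.5, size=3)

print(np.array_equal(first, second))       # True: same seed, same numbers
```

Newer NumPy code often creates a separate generator instead, with `rng = np.random.default_rng(42)` and then `rng.normal(0, 0.5)`; either approach works for this exercise.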

In [None]:
# TODO: gradient descent, but with calc_dVdx -> calc_noisy_dVdx

In [None]:
# SOLUTION
history = []
step_size = 0.05

x = 0.1
for n in range(100):
    dVdx = calc_noisy_dVdx(x, sigma=1)
    x = x - step_size * dVdx
    history.append(x)
    
plt.plot(history)

In [None]:
# Just for fun: learning rate scheduling!
# Shrinking the step size as we go lets SGD settle down instead of bouncing around indefinitely.
history = []
start_x = 0.1
step_size = 0.05

x = start_x
for n in range(100):
    dVdx = calc_noisy_dVdx(x)
    x = x - step_size * dVdx
    if n%10 == 0:
        step_size /= 2 # halve the step size every 10 steps (a simple "step decay" schedule)
    history.append(x)
    
plt.plot(history)