In [None]:
# ! wget https://raw.githubusercontent.com/callummcdougall/arena-v1/main/w3d1/utils.py
# ! pip install -U scikit-learn scipy
import torch as t
import utils
from typing import Callable, Iterable
from torch import optim

In [None]:
def rosenbrocks_banana(x: t.Tensor, y: t.Tensor, a=1, b=100) -> t.Tensor:
    '''
    Rosenbrock's "banana" function, shifted up by 1.

    Computes (a - x)^2 + b * (y - x^2)^2 + 1 elementwise.

    x, y: coordinates at which to evaluate the function.
    a, b: shape constants of the function (defaults 1 and 100).

    Return: the function value at each (x, y).
    '''
    offset_from_a = (a - x) ** 2
    valley_penalty = b * (y - x**2) ** 2
    return offset_from_a + valley_penalty + 1

# Plotting window handed to the helper: x in [-2, 2], y in [-1, 3].
x_range = [-2, 2]
y_range = [-1, 3]
# Build the figure with the project-local utils helper; log_scale=True is forwarded to it as-is.
fig = utils.plot_fn(rosenbrocks_banana, x_range, y_range, log_scale=True)

In [None]:
# Render the figure object currently bound to `fig`.
fig.show()

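The minimum is at (x, y) = (1, 1) (for the default a = 1), where both squared terms are zero; the function value there is 1 because of the constant `+ 1` offset.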

In [None]:
def opt_fn_with_sgd(fn: Callable, xy: t.Tensor, lr=0.001, momentum=0.98, n_iters: int = 100):
    '''
    Optimize a given function with torch.optim.SGD, starting from the specified point.

    fn: called as fn(*xy); must return a scalar tensor that .backward() can be called on.
    xy: shape (2,). The (x, y) starting point. Must have requires_grad=True.
        It is updated in place by the optimizer.
    lr: learning rate passed to optim.SGD.
    momentum: momentum factor passed to optim.SGD.
    n_iters: number of steps.

    Return: (n_iters, 2). The (x,y) BEFORE each step. So out[0] is the starting point.
    '''
    assert xy.requires_grad
    optimizer = optim.SGD([xy], lr=lr, momentum=momentum)
    progression_curve = t.zeros((n_iters, len(xy)))

    # backward() accumulates into xy.grad, so a gradient left over from work done
    # before this call would otherwise be added to the very first step.
    optimizer.zero_grad()

    for i in range(n_iters):
        # Record the position before it is moved by this iteration's step.
        progression_curve[i] = xy.detach()
        result = fn(*xy)
        result.backward()
        optimizer.step()
        # Reset so the next backward() starts from a clean gradient.
        optimizer.zero_grad()
    return progression_curve

# Starting point for the optimisation; requires_grad=True is required by the assert in opt_fn_with_sgd.
xy = t.tensor([-1.5, 2.5], requires_grad=True)
# Plotting window: x in [-2, 2], y in [-1, 3].
x_range = [-2, 2]
y_range = [-1, 3]

# Pass the optimisation routine, the objective, the start point and the hyperparameters to the
# project-local plotting helper (presumably it runs the routine and draws the path — see utils.py).
fig = utils.plot_optimization_sgd(opt_fn_with_sgd, rosenbrocks_banana, xy, x_range, y_range, lr=0.001, momentum=0.98, show_min=True)

fig.show()

In [None]:
from torch.optim.optimizer import Optimizer

In [None]:
class SGD(Optimizer):
    '''Stochastic gradient descent with momentum and L2 weight decay.

    Follows the update rule of torch.optim.SGD with nesterov=False,
    maximize=False and dampening=0. For every parameter p with gradient g:

        g <- g + weight_decay * p      (only if weight_decay != 0)
        b <- momentum * b + g          (only if momentum != 0; b starts at zero)
        p <- p - lr * (b if momentum != 0 else g)
    '''
    params: list

    def __init__(self, params: Iterable[t.nn.parameter.Parameter], lr: float, momentum: float, weight_decay: float):
        '''Implements SGD with momentum.

        Like the PyTorch version, but assume nesterov=False, maximize=False, and dampening=0
            https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD

        params: tensors to optimize (e.g. model.parameters()). May be a one-shot iterator.
        lr: learning rate, must be >= 0.
        momentum: momentum factor, must be >= 0. 0 disables momentum.
        weight_decay: L2 penalty coefficient, must be >= 0. 0 disables weight decay.

        Raises ValueError if a hyperparameter is negative, or (from the Optimizer base
        class) if params is empty.
        '''
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        # Materialise first: params is often a generator and can only be consumed once.
        self.params = list(params)
        # Initialise the base class so inherited machinery such as zero_grad() works.
        super().__init__(self.params, dict(lr=lr, momentum=momentum, weight_decay=weight_decay))

        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        # Number of step() calls made so far.
        self.timestep = 0

        # One momentum buffer per parameter, in the same order as self.params.
        self.gradient_updates = [t.zeros_like(param) for param in self.params]

    def step(self) -> None:
        '''Apply one SGD update in place to every parameter that has a gradient.'''
        self.timestep += 1
        # The update itself must not be tracked by autograd, and in-place changes to
        # leaf tensors that require grad are only permitted with grad mode disabled.
        with t.no_grad():
            for momentum_buffer, param in zip(self.gradient_updates, self.params):
                if param.grad is None:
                    # Parameter took no part in the last backward pass.
                    continue
                update = param.grad
                if self.weight_decay != 0:
                    update = update + self.weight_decay * param
                if self.momentum != 0:
                    # PyTorch scales the new gradient by (1 - dampening), not (1 - momentum);
                    # dampening is 0 here, so the gradient is added unscaled. Because the
                    # buffer starts at zero, the first step reduces to plain b = g without
                    # needing a special case.
                    momentum_buffer.mul_(self.momentum).add_(update)
                    update = momentum_buffer
                param.sub_(self.lr * update)

    def __repr__(self) -> str:
        '''Short description listing the hyperparameters.'''
        return f"SGD lr={self.lr} momentum={self.momentum} weight_decay={self.weight_decay}"

# Run the test_sgd helper from the project-local utils module on the SGD class defined in this cell.
utils.test_sgd(SGD)