In [None]:
# ! wget https://raw.githubusercontent.com/callummcdougall/arena-v1/main/w3d1/utils.py
# ! pip install -U scikit-learn scipy
import torch as t
import utils
from typing import Callable, Iterable
from torch import optim

In [None]:
def rosenbrocks_banana(x: t.Tensor, y: t.Tensor, a=1, b=100) -> t.Tensor:
    '''Evaluate Rosenbrock's banana function, shifted up by one.

    x, y: coordinates at which to evaluate (anything supporting arithmetic).
    a, b: shape constants of the function.

    Return: (a - x)^2 + b * (y - x^2)^2 + 1, so the smallest possible value is 1.
    '''
    offset_term = (a - x) ** 2
    valley_term = b * (y - x**2) ** 2
    return offset_term + valley_term + 1

# Plotting window: x spans [-2, 2] and y spans [-1, 3].
x_range, y_range = [-2, 2], [-1, 3]
fig = utils.plot_fn(
    rosenbrocks_banana,
    x_range,
    y_range,
    log_scale=True,
)

In [None]:
# Display the figure that utils.plot_fn returned in the previous cell.
fig.show()

The minimum is at $(x, y) = (1, 1)$, where both squared terms are zero; because of the `+ 1` offset, the function value there is 1 rather than 0.

In [None]:
def opt_fn_with_sgd(fn: Callable, xy: t.Tensor, lr=0.001, momentum=0.98, n_iters: int = 100):
    '''
    Minimise a function with torch's built-in SGD, starting from a given point.

    fn: called as fn(x, y) with the two entries of xy; must return a scalar tensor.
    xy: shape (2,). The (x, y) starting point. Must have requires_grad=True; it is
        updated in place by the optimizer.
    lr, momentum: forwarded to torch.optim.SGD.
    n_iters: number of optimisation steps to take.

    Return: (n_iters, 2). Row i holds (x, y) BEFORE step i, so row 0 is the starting point.
    '''
    assert xy.requires_grad
    sgd = optim.SGD([xy], lr=lr, momentum=momentum)
    trajectory = t.zeros((n_iters, len(xy)))
    for step_idx in range(n_iters):
        # Record the position first, then move.
        trajectory[step_idx] = xy.detach()
        fn(*xy).backward()
        sgd.step()
        sgd.zero_grad()
    return trajectory

# Starting point and plotting window for the SGD trajectory plot.
xy = t.tensor([-1.5, 2.5], requires_grad=True)
x_range, y_range = [-2, 2], [-1, 3]

fig = utils.plot_optimization_sgd(
    opt_fn_with_sgd,
    rosenbrocks_banana,
    xy,
    x_range,
    y_range,
    lr=0.001,
    momentum=0.98,
    show_min=True,
)

fig.show()

In [None]:
from torch.optim.optimizer import Optimizer

In [None]:
class SGD:
    params: list

    def __init__(self, params: Iterable[t.nn.parameter.Parameter], lr: float, momentum: float = 0.0, weight_decay: float = 0.0):
        '''Implements SGD with momentum.

        Like the PyTorch version, but assume nesterov=False, maximize=False, and dampening=0
            https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD

        params: iterable of tensors to optimise; consumed once and stored as a list.
        lr: learning rate, must be >= 0.
        momentum: momentum factor, must be >= 0. 0 disables momentum.
        weight_decay: L2 coefficient, must be >= 0. Adds weight_decay * param to the gradient.

        Raises ValueError if lr, momentum or weight_decay is negative.
        '''
        self.params = list(params)

        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        # Number of times step() has been called.
        self.timestep = 0

        # One momentum buffer per parameter, starting at zero.
        self.gradient_updates = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        """Set param grads to None
        """
        for param in self.params:
            param.grad = None

    def step(self) -> None:
        '''Apply one SGD update to every parameter that currently has a gradient.

        Parameters whose .grad is None (e.g. right after zero_grad()) are skipped.
        param.grad itself is never modified, so calling step() again on the same
        gradients, or accumulating further gradients into them, stays correct.
        '''
        with t.inference_mode():
            for i, param in enumerate(self.params):
                if param.grad is None:
                    continue
                grad = param.grad
                if self.weight_decay != 0:
                    # Out-of-place on purpose: `+=` would write into param.grad.
                    grad = grad + self.weight_decay * param
                if self.momentum != 0:
                    # With dampening=0 the gradient enters at full weight, i.e.
                    # b_t = momentum * b_{t-1} + g_t (not (1 - momentum) * g_t).
                    # The buffer starts at zero, so the first step yields b_1 = g_1.
                    grad = self.momentum * self.gradient_updates[i] + grad
                    self.gradient_updates[i] = grad
                param -= self.lr * grad
            self.timestep += 1

    def __repr__(self) -> str:
        '''Short summary of the hyper-parameters.'''
        return f"SGD lr={self.lr} momentum={self.momentum} weight_decay={self.weight_decay}"

# Check the SGD class above with the test helper from the imported utils module
# (the checks themselves live in utils.py and are not visible in this notebook).
utils.test_sgd(SGD)

In [None]:
class RMSprop:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.01,
        alpha: float = 0.99,
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        momentum: float = 0.0,
    ):
        '''Implements RMSprop.

        Like the PyTorch version, but assumes centered=False
            https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html#torch.optim.RMSprop

        params: iterable of tensors to optimise; consumed once and stored as a list.
        lr: learning rate, must be >= 0.
        alpha: smoothing constant of the squared-gradient moving average, must be >= 0.
        eps: added to the root of the moving average to avoid division by zero, must be >= 0.
        weight_decay: L2 coefficient, must be >= 0. Adds weight_decay * param to the gradient.
        momentum: momentum factor, must be >= 0. 0 disables momentum.

        The defaults match torch.optim.RMSprop.
        Raises ValueError if any hyper-parameter is negative.
        '''
        self.params = list(params)

        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if alpha < 0.0:
            raise ValueError(f"Invalid alpha value: {alpha}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        self.lr = lr
        self.alpha = alpha
        self.eps = eps
        self.weight_decay = weight_decay
        self.momentum = momentum

        # Per-parameter state, both starting at zero: the moving average of the
        # squared gradients and the momentum buffer.
        self.moving_average_squared_gradients = [t.zeros_like(param) for param in self.params]
        self.gradient_updates = [t.zeros_like(p) for p in self.params]

        # Number of times step() has been called.
        self.timestep = 0

    def zero_grad(self) -> None:
        """Set param grads to None
        """
        for param in self.params:
            param.grad = None

    def step(self) -> None:
        '''Apply one RMSprop update to every parameter that currently has a gradient.

        Parameters whose .grad is None (e.g. right after zero_grad()) are skipped.
        param.grad itself is never modified.
        '''
        with t.inference_mode():
            for i, param in enumerate(self.params):
                if param.grad is None:
                    continue
                grad = param.grad
                if self.weight_decay != 0:
                    # Out-of-place on purpose: `+=` would write into param.grad.
                    grad = grad + self.weight_decay * param

                ma = self.alpha * self.moving_average_squared_gradients[i] + (1 - self.alpha) * grad ** 2
                self.moving_average_squared_gradients[i] = ma
                denom = t.sqrt(ma) + self.eps
                if self.momentum != 0:
                    # The scaled gradient enters the buffer at full weight:
                    # b_t = momentum * b_{t-1} + g_t / (sqrt(v_t) + eps),
                    # not (1 - momentum) times it.
                    update = self.momentum * self.gradient_updates[i] + grad / denom
                    self.gradient_updates[i] = update
                    param -= self.lr * update
                else:
                    param -= self.lr / denom * grad
            self.timestep += 1

    def __repr__(self) -> str:
        '''Short summary of the main hyper-parameters.'''
        return f"RMSprop lr = {self.lr}; alpha = {self.alpha}; weight_decay = {self.weight_decay}"



# Check the RMSprop class above with the test helper from the imported utils module
# (the checks themselves live in utils.py and are not visible in this notebook).
utils.test_rmsprop(RMSprop)

In [None]:
class Adam:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
    ):
        '''Implements Adam.

        Like the PyTorch version, but assumes amsgrad=False and maximize=False
            https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam

        params: iterable of tensors to optimise; consumed once and stored as a list.
        lr: learning rate, must be >= 0.
        betas: decay rates of the first- and second-moment averages, each in [0, 1).
        eps: added to the root of the second moment to avoid division by zero, must be >= 0.
        weight_decay: L2 coefficient, must be >= 0. Adds weight_decay * param to the gradient.

        The defaults match torch.optim.Adam.
        Raises ValueError if any hyper-parameter is outside its range.
        '''
        self.params = list(params)

        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        self.lr = lr
        self.betas = betas

        self.eps = eps
        self.weight_decay = weight_decay

        # Per-parameter running averages of the gradient (m) and squared gradient (v).
        self.m = [t.zeros_like(param) for param in self.params]
        self.v = [t.zeros_like(param) for param in self.params]

        # 1-based index of the upcoming step; used as the exponent in bias correction.
        self.timestep = 1

    def zero_grad(self) -> None:
        """Set param grads to None
        """
        for param in self.params:
            param.grad = None

    def step(self) -> None:
        '''Apply one Adam update to every parameter that currently has a gradient.

        Parameters whose .grad is None (e.g. right after zero_grad()) are skipped.
        param.grad itself is never modified.
        '''
        with t.inference_mode():
            beta1, beta2 = self.betas
            # Bias corrections depend only on the shared step counter, so compute
            # them once per step rather than once per parameter.
            bias_correction1 = 1 - beta1 ** self.timestep
            bias_correction2 = 1 - beta2 ** self.timestep
            for i, param in enumerate(self.params):
                if param.grad is None:
                    continue
                grad = param.grad
                if self.weight_decay != 0:
                    # Out-of-place on purpose: `+=` would write into param.grad.
                    grad = grad + self.weight_decay * param
                m = beta1 * self.m[i] + (1 - beta1) * grad
                v = beta2 * self.v[i] + (1 - beta2) * grad ** 2
                self.m[i] = m
                self.v[i] = v

                # m and v start at zero, so early averages are biased towards
                # zero; dividing by (1 - beta^t) compensates.
                m_hat = m / bias_correction1
                v_hat = v / bias_correction2

                param -= self.lr * m_hat / (t.sqrt(v_hat) + self.eps)

            self.timestep += 1

    def __repr__(self) -> str:
        '''Short summary of the main hyper-parameters.'''
        return f"Adam lr = {self.lr} betas = {self.betas} weight_decay = {self.weight_decay}"

# Check the Adam class above with the test helper from the imported utils module
# (the checks themselves live in utils.py and are not visible in this notebook).
utils.test_adam(Adam)

## Plotting multiple optimisers

In [None]:
def opt_fn(fn: Callable, xy: t.Tensor, optimizer_class, optimizer_kwargs, n_iters: int = 100):
    '''Minimise a function from a given starting point with an arbitrary optimizer.

    fn: called as fn(x, y) with the two entries of xy; must return a scalar tensor.
    xy: shape (2,). The (x, y) starting point. Must have requires_grad=True; it is
        updated in place by the optimizer.
    optimizer_class: optimizer to build, e.g. the SGD, RMSprop or Adam defined in this notebook.
        Constructed as optimizer_class([xy], **optimizer_kwargs).
    optimizer_kwargs: keyword arguments for the optimizer (e.g. lr and weight_decay).
    n_iters: number of optimisation steps to take.

    Return: (n_iters, 2). Row i holds (x, y) BEFORE step i, so row 0 is the starting point.
    '''
    assert xy.requires_grad
    opt = optimizer_class([xy], **optimizer_kwargs)
    trajectory = t.zeros((n_iters, len(xy)))
    for step_idx in range(n_iters):
        # Record the position first, then move.
        trajectory[step_idx] = xy.detach()
        fn(*xy).backward()
        opt.step()
        opt.zero_grad()
    return trajectory

# Starting point, plotting window and hyper-parameters for a quick run of the
# custom SGD class.
xy = t.tensor([-1.5, 2.5], requires_grad=True)
x_range, y_range = [-2, 2], [-1, 3]
optimizer_kwargs = dict(lr=0.001, momentum=0.98)

# Last expression of the cell, so the returned trajectory is displayed.
opt_fn(rosenbrocks_banana, xy, SGD, optimizer_kwargs=optimizer_kwargs)

In [None]:
# Starting point and plotting window shared by every optimizer run.
xy = t.tensor([-1.5, 2.5], requires_grad=True)
x_range, y_range = [-2, 2], [-1, 3]

# (optimizer class, keyword arguments) pairs: the same SGD at two learning rates.
optimizers = [
    (SGD, {"lr": 1e-3, "momentum": 0.98}),
    (SGD, {"lr": 5e-4, "momentum": 0.98}),
]
fn = rosenbrocks_banana
fig = utils.plot_optimization(
    opt_fn,
    fn,
    xy,
    optimizers,
    x_range,
    y_range,
)

fig.show()