In [1]:
import numpy as np
import pandas as pd
import torch

In [2]:
def zero_grad(p):
    """Reset the accumulated gradient of ``p`` to zero in place.

    Does nothing when ``p`` has no gradient yet (``p.grad is None``).
    """
    grad = p.grad
    if grad is None:
        return
    grad.zero_()

In [3]:
def l1_loss(y, yh):
    """Elementwise absolute error ``|y - yh|`` (no reduction is applied)."""
    residual = y - yh
    return torch.abs(residual)

def mse_loss(y, yh):
    """Elementwise squared error ``(y - yh)**2`` (no mean/sum reduction is applied)."""
    residual = y - yh
    return residual ** 2

In [4]:
def f(p, x):
    """Deliberately minimal linear model: the input ``x`` scaled by the single parameter ``p``."""
    prediction = p * x
    return prediction

def fit(p_gt, loss_fn, lr, num_epochs=3, x=(1.0, 3.0)):
    """Fit the one-parameter model ``f(p, x)`` with plain SGD and show a step log.

    Labels are generated as ``y = f(p_gt, x)``. The parameter starts at 1.0 and
    is updated once per sample (batch size 1, fixed sample order, no momentum).
    Afterwards ``loss_fn`` is printed and a table is displayed with one row for
    the initial state plus one row per update.

    Args:
        p_gt: Ground-truth parameter used to generate the labels.
        loss_fn: Callable ``loss_fn(y, yh)``. Its result is back-propagated
            directly, so it must be a scalar tensor for a scalar sample.
        lr: Learning rate (step size).
        num_epochs: Number of passes over the samples.
        x: Sequence of sample inputs. The default is an immutable tuple so the
            default object cannot be shared and mutated between calls.

    Returns:
        None. The log is shown with ``display`` (expected to be provided by the
        notebook kernel) rather than returned.
    """
    # Expected param + labeled dataset.
    p_gt = torch.tensor(p_gt)
    x = torch.tensor(x)
    y = f(p_gt, x)

    # Param to optimize.
    p = torch.tensor(1.0, requires_grad=True)
    rows = []

    def simple_log(epoch, batch_idx, p_prev):
        """Append one log row comparing the current ``p`` with ``p_prev`` and ``p_gt``."""
        p_now = p.detach()
        # .item() copies each value out as a Python scalar, so later in-place
        # updates of ``p`` cannot alter rows that were already recorded.
        rows.append(dict(
            epoch=epoch,
            batch_idx=batch_idx,
            p=p_now.item(),
            Δp=(p_now - p_prev).item(),
            p_err=(p_now - p_gt).item(),
            p_gt=p_gt.item(),
        ))

    # Initial state: nothing has moved yet, so the previous value is p itself.
    simple_log(epoch="pre-opt", batch_idx="n/a", p_prev=p.detach().clone())
    for epoch in range(num_epochs):
        # Batch size = 1
        for batch_idx, (xi, yi) in enumerate(zip(x, y)):
            # Snapshot before the update so the log can report the step taken.
            p_prev = p.detach().clone()
            zero_grad(p)
            yhi = f(p, xi)
            loss = loss_fn(yi, yhi)
            loss.backward()

            # SGD update, no momentum.
            with torch.no_grad():
                p -= lr * p.grad
            # Validation.
            simple_log(epoch, batch_idx, p_prev)
    print(loss_fn)
    # Build the table once from the collected records.
    display(pd.DataFrame(rows))

In [5]:
# Optimal step size for the linear model with MSE loss: the gradient w.r.t. p is
# -2 * x**2 * (p_gt - p), so with lr=0.5 the first sample (x = 1) moves p
# exactly onto p_gt in a single step.
fit(loss_fn=mse_loss, p_gt=3.0, num_epochs=2, lr=0.5)

<function mse_loss at 0x7fd778276bf8>


Unnamed: 0,epoch,batch_idx,p,Δp,p_err,p_gt
0,pre-opt,,1.0,0.0,-2.0,3.0
1,0,0.0,3.0,2.0,0.0,3.0
2,0,1.0,3.0,0.0,0.0,3.0
3,1,0.0,3.0,0.0,0.0,3.0
4,1,1.0,3.0,0.0,0.0,3.0


In [6]:
# With L1 loss the gradient w.r.t. p is -|x| * sign(p_gt - p): the step length
# is lr * |x| no matter how far p is from p_gt, so how (and whether) the fit
# converges depends on the target parameter and on the data.

# Base case: p goes 1 -> 1.5 -> 3.0, i.e. it lands on p_gt after two minibatches.
fit(
    loss_fn=l1_loss,
    p_gt=3.0,
    lr=0.5,
)
# Larger target parameter: same step lengths, so it needs more updates
# (1 -> 1.5 -> 3.0 -> 3.5 -> 5.0).
fit(
    loss_fn=l1_loss,
    p_gt=5.0,
    lr=0.5,
)
# Larger data point: the x = 20 sample takes steps of 10, so p never settles and
# instead repeats the cycle 1 -> 1.5 -> 11.5 -> 11.0 -> 1 every epoch.
fit(
    loss_fn=l1_loss,
    p_gt=3.0,
    lr=0.5,
    num_epochs=4,
    x=[1.0, 20.0],
)

<function l1_loss at 0x7fd77a2b72f0>


Unnamed: 0,epoch,batch_idx,p,Δp,p_err,p_gt
0,pre-opt,,1.0,0.0,-2.0,3.0
1,0,0.0,1.5,0.5,-1.5,3.0
2,0,1.0,3.0,1.5,0.0,3.0
3,1,0.0,3.0,0.0,0.0,3.0
4,1,1.0,3.0,0.0,0.0,3.0
5,2,0.0,3.0,0.0,0.0,3.0
6,2,1.0,3.0,0.0,0.0,3.0


<function l1_loss at 0x7fd77a2b72f0>


Unnamed: 0,epoch,batch_idx,p,Δp,p_err,p_gt
0,pre-opt,,1.0,0.0,-4.0,5.0
1,0,0.0,1.5,0.5,-3.5,5.0
2,0,1.0,3.0,1.5,-2.0,5.0
3,1,0.0,3.5,0.5,-1.5,5.0
4,1,1.0,5.0,1.5,0.0,5.0
5,2,0.0,5.0,0.0,0.0,5.0
6,2,1.0,5.0,0.0,0.0,5.0


<function l1_loss at 0x7fd77a2b72f0>


Unnamed: 0,epoch,batch_idx,p,Δp,p_err,p_gt
0,pre-opt,,1.0,0.0,-2.0,3.0
1,0,0.0,1.5,0.5,-1.5,3.0
2,0,1.0,11.5,10.0,8.5,3.0
3,1,0.0,11.0,-0.5,8.0,3.0
4,1,1.0,1.0,-10.0,-2.0,3.0
5,2,0.0,1.5,0.5,-1.5,3.0
6,2,1.0,11.5,10.0,8.5,3.0
7,3,0.0,11.0,-0.5,8.0,3.0
8,3,1.0,1.0,-10.0,-2.0,3.0
