In [None]:
import torch

In [None]:
def zero_grad(p):
    """Reset the accumulated gradient of ``p`` to zero, in place.

    Does nothing when ``p`` has no gradient yet (i.e. before its first backward pass).
    """
    grad = p.grad
    if grad is None:
        return
    grad.zero_()

In [None]:
def l1_loss(y, yh):
    """Elementwise absolute error ``|y - yh|`` (no reduction)."""
    residual = y - yh
    return torch.abs(residual)

def mse_loss(y, yh):
    """Elementwise squared error ``(y - yh)**2`` (no mean is taken, despite the name)."""
    residual = y - yh
    return residual ** 2

In [None]:
def f(p, x):
    """Toy linear model without bias: predict ``p * x``."""
    prediction = p * x
    return prediction

def fit(p_gt, loss_fn, lr, num_epochs=3, x=(1.0, 3.0), p0=1.0):
    """Fit the scalar model ``f(p, x) = p * x`` with plain per-sample SGD, printing progress.

    Labels are generated from the ground-truth parameter (``y = f(p_gt, x)``), so the
    printed ``p_err = p - p_gt`` shows directly how SGD converges (or fails to) for the
    given loss and step size.

    Args:
        p_gt: Ground-truth parameter used to generate the labels.
        loss_fn: Loss called as ``loss_fn(y_i, yh_i)`` on one sample; must return a
            tensor that ``.backward()`` accepts (e.g. ``l1_loss`` / ``mse_loss``).
        lr: SGD step size.
        num_epochs: Number of passes over the data.
        x: Input data points; each one is its own minibatch (batch size 1).
        p0: Initial value of the optimized parameter.

    Returns:
        None. Prints ``(epoch, batch_idx) p_err: ...`` before every update, then
        ``final p_err: ...`` after the last update, then a blank separator line.
    """
    dtype = torch.get_default_dtype()

    # Ground-truth param + synthetic labeled dataset.
    # as_tensor avoids the copy-construct warning when a tensor is passed in.
    p_gt = torch.as_tensor(p_gt, dtype=dtype)
    x = torch.as_tensor(x, dtype=dtype)
    y = f(p_gt, x)

    # Param to optimize.
    p = torch.tensor(p0, dtype=dtype, requires_grad=True)
    for epoch in range(num_epochs):
        # Batch size = 1: one SGD step per data point.
        for epoch_batch_idx, (xi, yi) in enumerate(zip(x, y)):
            # Show the error *before* this step's update.
            p_err = p.detach() - p_gt
            print((epoch, epoch_batch_idx), "p_err:", p_err)

            zero_grad(p)
            yhi = f(p, xi)
            loss = loss_fn(yi, yhi)
            loss.backward()

            # Plain SGD, no momentum: step directly against the gradient.
            # no_grad so the in-place update on the leaf tensor is not tracked.
            with torch.no_grad():
                p -= lr * p.grad
    # The loop only reports errors before each step; also report the result of the
    # final update so convergence (or the end of an orbit) is visible.
    print("final p_err:", p.detach() - p_gt)
    print()

In [None]:
# For MSE, d/dp (p_gt*x - p*x)**2 = 2 * x**2 * (p - p_gt), so lr = 1 / (2 * x**2)
# lands on p_gt in a single step. The first default sample has x = 1.0, which makes
# lr = 0.5 one-step optimal for it; the second sample (x = 3.0) then sees zero error.
fit(
    p_gt=3.0,
    loss_fn=mse_loss,
    lr=0.5,
    num_epochs=2,
)

In [None]:
# With L1 loss, d/dp |y - p*x| = -sign(y - p*x) * x: every step with a nonzero residual
# moves p by exactly lr * |x|, however far p is from p_gt. The step size is set by the
# data, not by the error :(
l1_runs = [
    # Base case (p starts at 1.0): steps of +0.5 (x=1) then +1.5 (x=3) land on
    # p_gt=3.0 after two minibatches.
    dict(p_gt=3.0),
    # Shift the target via the expected param: the same fixed-size steps need two
    # epochs to cover the larger gap.
    dict(p_gt=5.0),
    # Shift the data instead: the x=20 step (lr * 20 = 10) overshoots, and p settles
    # into a stable orbit 1.0 -> 1.5 -> 11.5 -> 11.0 -> 1.0 that never converges.
    dict(p_gt=3.0, x=[1.0, 20.0], num_epochs=4),
]
for run_kwargs in l1_runs:
    fit(loss_fn=l1_loss, lr=0.5, **run_kwargs)