# Setup
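The Setup cell's contents are not included here. The cells below use `torch`, `np`, `plt` and a helper called `grab`. Here is a minimal sketch of such a setup. It assumes that `grab` detaches a tensor and converts it to a NumPy array for plotting; that is our guess from how it is used, not the original definition.

```python
import math
import numpy as np
import torch
import matplotlib.pyplot as plt


def grab(x):
    """Hypothetical helper: detach a tensor from the graph and return it as a NumPy array."""
    return x.detach().cpu().numpy()
```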

# Brief ML primer

Let's quickly demonstrate what training looks like in PyTorch. We will train a small neural network to approximate the normalized sinc function
$$
f(x) = \mathrm{sinc}(x) := \frac{\sin(\pi x)}{\pi x}, \qquad f(0) := 1.
$$
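`torch.sinc` computes exactly this normalized convention, including the value 1 at the removable singularity $x = 0$. The short check below compares it against the formula at a few nonzero points.

```python
import math
import torch

x = torch.tensor([-2.5, -0.5, 0.25, 1.0, 3.7])
# direct evaluation of sin(pi x) / (pi x) away from x = 0
manual = torch.sin(math.pi * x) / (math.pi * x)
print(torch.allclose(torch.sinc(x), manual, atol=1e-6))  # True
# torch.sinc fills in the limit value at x = 0
print(torch.sinc(torch.tensor(0.0)))  # tensor(1.)
```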

In [None]:
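def target_fn(x):
    # normalized sinc: sin(pi x) / (pi x), with target_fn(0) = 1
    return torch.sinc(x)

# plot the target on a grid over [-5, 5]; grab (presumably from Setup) converts tensors to NumPy
fig, ax = plt.subplots(1, 1, figsize=(3.5, 2.5))
xs = torch.linspace(-5, 5, steps=51)
ys = target_fn(xs)
ax.plot(grab(xs), grab(ys))
plt.show()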

In [None]:
class ToyModel(torch.nn.Module):
    pass
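One way to fill in `ToyModel` is a small multilayer perceptron that maps a scalar to a scalar. The sketch below makes several assumptions: the class name `ToyModelSketch`, the `Tanh` activations and the hidden width are all our own choices. It accepts 1-D inputs like the `xs` grid above.

```python
import torch


class ToyModelSketch(torch.nn.Module):
    """Hypothetical MLP regressor: x of shape (N,) -> prediction of shape (N,)."""

    def __init__(self, hidden=64):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(1, hidden), torch.nn.Tanh(),
            torch.nn.Linear(hidden, hidden), torch.nn.Tanh(),
            torch.nn.Linear(hidden, 1),
        )

    def forward(self, x):
        # add a feature axis for the Linear layers, then drop it again
        return self.net(x[:, None])[:, 0]
```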

In [None]:
def train_model():
    pass
res = train_model()
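A minimal sketch of a regression training loop for `train_model`. The following are assumptions, not the original design: fresh uniform samples on $[-5, 5]$ at every step, a mean-squared-error loss, the Adam optimizer, and the returned dictionary layout. The model is built inline so that the block runs on its own.

```python
import torch


def train_model_sketch(n_iter=2000, batch_size=256, lr=1e-3, seed=0):
    """Hypothetical trainer: fit an MLP to sinc(x) on [-5, 5] by minimizing MSE."""
    torch.manual_seed(seed)
    model = torch.nn.Sequential(
        torch.nn.Linear(1, 64), torch.nn.Tanh(),
        torch.nn.Linear(64, 64), torch.nn.Tanh(),
        torch.nn.Linear(64, 1),
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for _ in range(n_iter):
        x = 10 * torch.rand(batch_size, 1) - 5  # fresh batch, uniform on [-5, 5]
        loss = ((model(x) - torch.sinc(x)) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return {"model": model, "losses": losses}
```

Drawing a new batch at every step means the network never sees the same points twice. For a 1D target this cheap, that is simpler than maintaining a fixed dataset.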

In [None]:
# TODO: plot true vs learned function
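A possible plotting helper for this cell. It overlays the target and the model on a dense grid and also returns the maximum absolute error as a quick numeric summary. The function name and its return values are hypothetical. It assumes `model` maps a 1-D tensor of shape `(N,)` to shape `(N,)`, so a model that expects `(N, 1)` inputs would need a small wrapper.

```python
import torch
import matplotlib.pyplot as plt


def plot_true_vs_learned(model, target_fn, lim=5.0, steps=201):
    """Hypothetical helper: plot target vs model on [-lim, lim]; return (fig, max |error|)."""
    xs = torch.linspace(-lim, lim, steps)
    with torch.no_grad():
        ys_true = target_fn(xs)
        ys_model = model(xs)
    fig, ax = plt.subplots(1, 1, figsize=(3.5, 2.5))
    ax.plot(xs.numpy(), ys_true.numpy(), label="target")
    ax.plot(xs.numpy(), ys_model.numpy(), "--", label="model")
    ax.legend()
    max_err = (ys_true - ys_model).abs().max().item()
    return fig, max_err
```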

# Target

To explore learned flows, we will use a standard 2D target, a **mixture of Gaussians**:
$$
p(x) = \frac{1}{2\pi} \sum_i \sigma(\alpha)_i \, e^{-\|x-\mu_i\|^2/2}, \qquad x \in \mathbb{R}^2.
$$
Above, each component is a unit-variance (identity-covariance) Gaussian centered at $\mu_i$, and $\sigma(\alpha)_i := e^{\alpha_i} / \sum_j e^{\alpha_j}$ is the softmax function. Softmax maps a vector $\vec{\alpha}$ to weights that are positive and sum to one.

In [None]:
class MixtureOfGaussians:
    # TODO: implement sample() and log_prob()
    pass
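A minimal sketch of the two methods, written for the density above (unit-variance components in $d$ dimensions, weights $\sigma(\alpha)$). The class name is hypothetical. `log_prob` uses `logsumexp` over components for numerical stability. `sample` picks a component per sample and gathers its mean with the broadcast indexing trick shown in the next cell.

```python
import math
import torch


class MixtureOfGaussiansSketch:
    """Hypothetical MoG with identity covariances; mu: (K, d), alpha: (K,)."""

    def __init__(self, mu, alpha):
        self.mu = mu
        self.alpha = alpha
        self.dim = mu.shape[1]

    def log_prob(self, x):
        # x: (N, d) -> (N,)
        sq = ((x[:, None, :] - self.mu[None, :, :]) ** 2).sum(-1)  # (N, K) squared distances
        log_w = torch.log_softmax(self.alpha, dim=0)               # (K,) log mixture weights
        log_norm = -0.5 * self.dim * math.log(2 * math.pi)         # Gaussian normalization
        return torch.logsumexp(log_w - 0.5 * sq, dim=1) + log_norm

    def sample(self, n):
        # choose a component index for each sample with probabilities softmax(alpha)
        weights = torch.softmax(self.alpha, dim=0)
        inds = torch.multinomial(weights, n, replacement=True)      # (n,)
        # (n, 1) and (d,) index tensors broadcast to (n, d): row inds[k], column j
        centers = self.mu[inds[:, None], torch.arange(self.dim)]    # (n, d)
        return centers + torch.randn(n, self.dim)
```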

In [None]:
# this is an indexing trick that will be useful
a = np.arange(24).reshape(6, 4)
inds = np.array([1, 4, 3])
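# broadcasting inds[:, None] (shape (3, 1)) against np.arange(4) (shape (4,))
# gives a (3, 4) index grid that pulls rows 1, 4, 3 out of a (same as a[inds])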
extracted = a[inds[:,None], np.arange(4)]
print(f'{a=}')
print(' -> ')
print(f'{extracted=}')
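The same trick works with PyTorch index tensors. That is the form needed for the mixture, where each sample needs the row of `mu` belonging to its component. The index values below are arbitrary illustrations. The block also contrasts this with pairing each row index with a single column index, which selects one entry per row instead.

```python
import torch

mu = torch.tensor([[5., 0.], [0., 5.], [-5., -3.]])
inds = torch.tensor([2, 0, 0, 1])
# (4, 1) and (2,) broadcast to a (4, 2) index grid: entry [n, j] is mu[inds[n], j]
centers = mu[inds[:, None], torch.arange(2)]
print(torch.equal(centers, mu[inds]))  # True
# equal-length index tensors are paired elementwise: one entry per row
cols = torch.tensor([1, 0, 1, 0])
picked = mu[inds, cols]  # picks mu[2, 1], mu[0, 0], mu[0, 1], mu[1, 0]
```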

In [None]:
# some reasonable parameters for the MoG
mu = torch.tensor([
    [5., 0.],
    [0., 5.],
    [-5., -3.]
])
alpha = torch.tensor([0.5, 0.0, 0.8])
target = MixtureOfGaussians(mu, alpha)

# TODO: draw samples and plot density
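A possible helper for this TODO. It draws samples, evaluates the density on a square grid and shows both side by side. The helper name, grid extent and resolution are assumptions. It takes plain callables so that it works with any object exposing `sample(n)` and `log_prob(x)`, for example `plot_samples_and_density(target.sample, target.log_prob)`. The returned grid density can be sanity-checked by summing it times the cell area, which should be close to 1.

```python
import torch
import matplotlib.pyplot as plt


def plot_samples_and_density(sample_fn, log_prob_fn, *, n=2000, lim=9.0, n_grid=101):
    """Hypothetical helper: scatter n samples and contour exp(log_prob) on [-lim, lim]^2."""
    xs = torch.linspace(-lim, lim, n_grid)
    X, Y = torch.meshgrid(xs, xs, indexing="xy")
    grid = torch.stack([X.flatten(), Y.flatten()], dim=-1)  # (n_grid**2, 2)
    with torch.no_grad():
        density = log_prob_fn(grid).exp().reshape(n_grid, n_grid)
        samples = sample_fn(n).detach()
    fig, axes = plt.subplots(1, 2, figsize=(7, 3.5))
    axes[0].scatter(samples[:, 0].numpy(), samples[:, 1].numpy(), s=2, alpha=0.5)
    axes[0].set_title("samples")
    axes[1].contourf(X.numpy(), Y.numpy(), density.numpy(), levels=20)
    axes[1].set_title("density")
    for ax in axes:
        ax.set_xlim(-lim, lim)
        ax.set_ylim(-lim, lim)
        ax.set_aspect("equal")
    return fig, density
```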

# Flow model

In [None]:
class Velocity(torch.nn.Module):
    pass
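The `flow` function below calls `velocity.value(x, t)` and `velocity.div(x, t)`, so a velocity model needs both methods. Here is a minimal sketch under the following assumptions (the class name is hypothetical): an MLP on the concatenated input $[x, t]$, and an exact divergence $\sum_i \partial v_i / \partial x_i$ computed with one autograd call per dimension. That is cheap in 2D, but the cost grows with the dimension.

```python
import torch


class VelocitySketch(torch.nn.Module):
    """Hypothetical time-dependent velocity field v(x, t) with exact divergence."""

    def __init__(self, dim=2, hidden=64):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim + 1, hidden), torch.nn.SiLU(),
            torch.nn.Linear(hidden, hidden), torch.nn.SiLU(),
            torch.nn.Linear(hidden, dim),
        )

    def value(self, x, t):
        # broadcast the scalar time to a column and append it to x: (N, d) -> (N, d + 1)
        t_col = torch.as_tensor(t, dtype=x.dtype).expand(x.shape[0], 1)
        return self.net(torch.cat([x, t_col], dim=-1))

    def div(self, x, t):
        # exact trace of dv/dx; create_graph=True keeps it differentiable for training
        with torch.enable_grad():
            if not x.requires_grad:
                x = x.detach().requires_grad_(True)
            v = self.value(x, t)
            div = x.new_zeros(x.shape[0])
            for i in range(x.shape[1]):
                grad_i = torch.autograd.grad(v[:, i].sum(), x, create_graph=True)[0]
                div = div + grad_i[:, i]
        return div
```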

The flow integrator from Lecture 1 (forward Euler in time):

In [None]:
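def flow(x, velocity, *, n_step, tf=1.0, inverse=False):
    """Transport x along dx/dt = v(x, t) with n_step forward-Euler steps.

    Returns (x, logJ). logJ accumulates +dt * div v along the trajectory in both
    directions, so it estimates log|det dT/dx_0| of the forward (t=0 -> tf) map T
    at this trajectory's t=0 point: log q(x_tf) = log p0(x_0) - logJ either way.
    """
    dt = tf / n_step
    ts = dt * torch.arange(n_step)
    logJ = torch.tensor(0.0)
    sign = 1
    if inverse:
        # integrate backward in time: flip the step direction and the time grid
        sign = -1
        ts = reversed(ts)
    for t in ts:
        # transport samples
        x = x + sign * dt * velocity.value(x, t)
        # estimate change of measure (divergence at the updated x; first-order accurate)
        logJ = logJ + dt * velocity.div(x, t)
    return x, logJ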
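Here is why the change-of-measure line adds `dt` times the divergence. Take a hypothetical linear velocity $v(x) = A x$, so that $\nabla \cdot v = \operatorname{tr} A$. One Euler step is the linear map $x \mapsto (I + \Delta t\, A)x$, and its exact log-determinant is $\log\det(I + \Delta t\, A) = \Delta t \operatorname{tr} A + O(\Delta t^2)$. The check below (with an arbitrary matrix `A`) shows the per-step error shrinking like $\Delta t^2$. Summed over $1/\Delta t$ steps, the total error in logJ is therefore $O(\Delta t)$.

```python
import torch

# arbitrary example matrix for a linear velocity v(x) = A x; div v = tr(A) everywhere
A = torch.tensor([[0.3, 0.1], [-0.2, -0.5]], dtype=torch.float64)
I = torch.eye(2, dtype=torch.float64)

for n_step in (10, 100, 1000):
    dt = 1.0 / n_step
    exact = torch.logdet(I + dt * A)   # log-det of one Euler step
    approx = dt * torch.trace(A)       # what flow() adds per step
    print(n_step, float(exact), float(approx), float(exact - approx))
```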

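**Training option 1:** Reverse KL divergence from the model to the target, $D_{\mathrm{KL}}(q_\theta \,\|\, p) = \mathbb{E}_{x \sim q_\theta}[\log q_\theta(x) - \log p(x)]$. This needs samples from the model and the target's log-density, but no samples from the target. An unnormalized $\log p$ only shifts the loss by a constant.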

In [None]:
def train_model_rkl():
    pass
res_rkl = train_model_rkl()
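The core of reverse-KL training is a Monte Carlo estimate of the loss. Sample $z \sim \mathcal{N}(0, I)$, push it forward to $x = T(z)$, and use $\log q(x) = \log p_0(z) - \log|\det \partial T / \partial z|$. A minimal sketch follows. The function name and the `push_forward(z) -> (x, logdet)` interface are hypothetical. With `flow(z, velocity, n_step=...)` run forward, `logJ` plays the role of `logdet`. A training loop would backpropagate through this estimate with an optimizer.

```python
import math
import torch


def reverse_kl_estimate(push_forward, target_log_prob, *, n=4096, dim=2):
    """Hypothetical Monte Carlo estimate of KL(q || p) for q = T_# N(0, I)."""
    z = torch.randn(n, dim)
    log_p0 = -0.5 * (z ** 2).sum(-1) - 0.5 * dim * math.log(2 * math.pi)
    x, logdet = push_forward(z)      # logdet = log|det dT/dz| per sample
    log_q = log_p0 - logdet          # change of variables
    return (log_q - target_log_prob(x)).mean()
```

As a check, scaling a standard normal by $s$ in 2D gives $q = \mathcal{N}(0, s^2 I)$. Against $p = \mathcal{N}(0, I)$, the exact value is $s^2 - 1 - \log s^2$, so the estimate should approach that.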

In [None]:
# TODO: plot loss history
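Stochastic losses are noisy, so a moving average makes trends easier to read. A possible helper is sketched below; its name and the default window are assumptions. It plots the raw per-step losses and a trailing mean on a linear axis, since Monte Carlo KL estimates can dip below zero.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_loss_history(losses, window=50):
    """Hypothetical helper: plot raw losses and their `window`-step moving average."""
    losses = np.asarray(losses, dtype=float)
    window = max(1, min(window, len(losses)))
    smooth = np.convolve(losses, np.ones(window) / window, mode="valid")
    fig, ax = plt.subplots(1, 1, figsize=(4, 2.5))
    ax.plot(losses, alpha=0.3, label="per step")
    ax.plot(np.arange(window - 1, len(losses)), smooth, label=f"{window}-step mean")
    ax.set_xlabel("iteration")
    ax.set_ylabel("loss")
    ax.legend()
    return fig, smooth
```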

In [None]:
# TODO: plot model vs true samples
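A possible helper for comparing samples, drawn in two panels with shared limits so that mode positions and spreads line up visually. The helper name and default limits are assumptions. It accepts tensors that may still be attached to the autograd graph, as flow outputs usually are.

```python
import torch
import matplotlib.pyplot as plt


def plot_model_vs_target(x_model, x_target, lim=9.0):
    """Hypothetical helper: side-by-side scatter of model and target samples of shape (N, 2)."""
    fig, axes = plt.subplots(1, 2, figsize=(7, 3.5))
    for ax, pts, title in zip(axes, (x_model, x_target), ("model", "target")):
        pts = torch.as_tensor(pts).detach().cpu().numpy()
        ax.scatter(pts[:, 0], pts[:, 1], s=2, alpha=0.5)
        ax.set_title(title)
        ax.set_xlim(-lim, lim)
        ax.set_ylim(-lim, lim)
        ax.set_aspect("equal")
    return fig
```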

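**Training option 2:** Forward KL divergence, $D_{\mathrm{KL}}(p \,\|\, q_\theta) = \mathbb{E}_{x \sim p}[\log p(x) - \log q_\theta(x)]$, estimated with samples from the target. The $\log p$ term does not depend on $\theta$, so minimizing this is maximum likelihood. It needs target samples and the model density $q_\theta(x)$ at those samples, which requires running the flow in reverse.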

In [None]:
def train_model_fkl():
    pass
res_fkl = train_model_fkl()
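Up to the constant $\mathbb{E}_p[\log p]$, the forward-KL loss is the negative log-likelihood of target samples. Map each $x$ back to $z = T^{-1}(x)$ and use $\log q(x) = \log p_0(z) + \log|\det \partial T^{-1} / \partial x|$. A minimal sketch follows; the function name and the `pull_back(x) -> (z, logdet_inv)` interface are hypothetical. The `flow` loop above adds `+dt * div` even when `inverse=True`. If used that way, `logdet_inv` corresponds to `-logJ`, so check the sign before plugging it in.

```python
import math
import torch


def forward_kl_nll(pull_back, x_target, *, dim=2):
    """Hypothetical NLL of target samples under q = T_# N(0, I)."""
    z, logdet_inv = pull_back(x_target)   # logdet_inv = log|det dT^{-1}/dx| per sample
    log_p0 = -0.5 * (z ** 2).sum(-1) - 0.5 * dim * math.log(2 * math.pi)
    log_q = log_p0 + logdet_inv           # change of variables
    return -log_q.mean()
```

For standard-normal targets and a scaling map $T(z) = s z$, the loss is $1/s^2 + \log 2\pi + 2\log s$ in 2D. That is smallest at $s = 1$, where the model matches the target.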