```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
```

**Definition 1** (Associative Memory)\
Given a set of keys $\mathcal K \subseteq \reals^{d_k}$ and values $\mathcal V \subseteq \reals^{d_v}$, an associative memory is an *operator* $\mathcal M : \mathcal K \to \mathcal V$ that maps keys to values.\
An objective $\tilde{\mathcal{L}}(\cdot, \cdot)$ measures the quality of the mapping, and the optimal mapping under that objective is defined as:
$$\mathcal M^* = \argmin_\mathcal{M} \tilde{\mathcal{L}} (\mathcal M(\mathcal K), \mathcal V)$$
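A minimal sketch of Definition 1 under the simplest possible assumptions: $\mathcal M$ is a linear map $\mathcal M(k) = Mk$, and $\tilde{\mathcal L}$ is the squared error $\|MK - V\|_F^2$ over the stored pairs, with keys and values stacked as the columns of $K$ and $V$. Under these assumptions $\mathcal M^*$ is an ordinary least-squares solution. The helper `fit_linear_memory` and the dimensions are illustrative, not taken from the paper.

```python
import numpy as np

def fit_linear_memory(K, V):
    """Hypothetical helper: solve M* = argmin_M ||M K - V||_F^2.

    K has shape (d_k, n): one key per column.
    V has shape (d_v, n): the matching value per column.
    Returns M with shape (d_v, d_k).
    """
    # M K ~= V  <=>  K^T M^T ~= V^T, which is the form lstsq solves.
    M_T, *_ = np.linalg.lstsq(K.T, V.T, rcond=None)
    return M_T.T

rng = np.random.default_rng(0)
d_k, d_v, n = 8, 4, 5          # fewer stored pairs than key dimensions
K = rng.standard_normal((d_k, n))
V = rng.standard_normal((d_v, n))

M = fit_linear_memory(K, V)
recalled = M @ K               # "query" the memory with the stored keys
print(np.allclose(recalled, V))  # expected True: the random keys are linearly independent
```

With fewer stored pairs than key dimensions and generic keys, recall of the stored values is exact; with $n > d_k$ the least-squares memory can only return the best fit in the squared-error sense.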

This definition seems to simply treat model parameters (like weights) as defining a mapping from arbitrary input keys to values. So far, this just makes explicit what we already know about parameters, especially weight matrices, which literally are representations of linear maps. I suppose it is a bit more general because the optimal mapping is defined as whatever minimizes the objective, and that objective can include design choices like regularization and non-linearities, not just the learnable parameters.
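To make the point about the objective concrete, here is a sketch that keeps the linear memory but adds a ridge (L2) penalty, an assumed choice of regularizer: $\tilde{\mathcal L}(M) = \|MK - V\|_F^2 + \lambda \|M\|_F^2$. Setting the gradient to zero gives the closed form $M^* = V K^\top (K K^\top + \lambda I)^{-1}$. Same parameter shape, different objective, different optimal memory. `fit_ridge_memory` is a hypothetical helper name.

```python
import numpy as np

def fit_ridge_memory(K, V, lam):
    """Hypothetical helper: M* = argmin_M ||M K - V||_F^2 + lam * ||M||_F^2."""
    d_k = K.shape[0]
    A = K @ K.T + lam * np.eye(d_k)   # symmetric positive definite for lam > 0
    # M A = V K^T  <=>  A M^T = K V^T (A is symmetric), so solve for M^T.
    return np.linalg.solve(A, K @ V.T).T

rng = np.random.default_rng(0)
K = rng.standard_normal((8, 5))       # 5 keys of dimension 8
V = rng.standard_normal((4, 5))       # 5 values of dimension 4

for lam in (1e-8, 1.0, 100.0):
    M = fit_ridge_memory(K, V, lam)
    err = np.linalg.norm(M @ K - V)
    print(f"lam={lam:g}  ||M||_F={np.linalg.norm(M):.3f}  recall error={err:.3f}")
```

As $\lambda$ grows, the optimal memory shrinks toward zero and gives up exact recall, so $\mathcal M^*$ depends on the objective and not only on the parameters.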

For comprehension: "Memory" $\approxeq$ "Map"

### Simple MLP Example

Consider a **1-Layer** MLP with parameters $W$. The optimization problem is: $$W^* = \argmin_W \mathcal L (W, \mathcal D_{tr})$$
Where $\mathcal D_{tr}$ is the training dataset.

Then the gradient descent update rule is:
$$
\begin{align*}
W_{t+1} &= W_t - \eta \nabla_{W_t} \mathcal L (W_t, x_{t+1}) \\
        &= W_t - \eta \nabla_{y_{t+1}} \mathcal L (W_t, x_{t+1}) \otimes x_{t+1}
\end{align*}
$$
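A sketch checking both lines of the update rule, assuming a bias-free linear layer $y = Wx$ and a squared-error loss against a hypothetical `target` (the notes leave $\mathcal L$ unspecified). Autograd's $\nabla_W \mathcal L$ should equal the outer product $\nabla_y \mathcal L \otimes x$, and the manual step should match one step of `torch.optim.SGD` with no momentum or weight decay.

```python
import torch

torch.manual_seed(0)
d_in, d_out, eta = 3, 2, 0.1

# Bias-free linear layer y = W x (assumed form of the "1-layer MLP").
W = torch.randn(d_out, d_in, requires_grad=True)   # W_t
x = torch.randn(d_in)                              # x_{t+1}
target = torch.randn(d_out)                        # hypothetical regression target

y = W @ x                                          # prediction y_{t+1}
loss = 0.5 * ((y - target) ** 2).sum()             # assumed squared-error loss

# Autograd gradients w.r.t. the output y and the weights W.
grad_y, grad_W = torch.autograd.grad(loss, (y, W))

# Chain rule: dL/dW_ij = dL/dy_i * x_j, i.e. the outer product of grad_y and x.
print(torch.allclose(grad_W, torch.outer(grad_y, x)))  # True

# Manual rule: W_{t+1} = W_t - eta * (grad_y L outer x_{t+1})
W_next = (W - eta * torch.outer(grad_y, x)).detach()

# The same step through torch.optim.SGD (plain SGD: p <- p - lr * p.grad).
W.grad = grad_W.clone()
torch.optim.SGD([W], lr=eta).step()
print(torch.allclose(W.detach(), W_next))              # True
```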

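The second expression follows from the chain rule. With $y_{t+1} = W_t x_{t+1}$ (no bias or activation), each entry satisfies $\frac{\partial \mathcal L}{\partial W_{ij}} = \frac{\partial \mathcal L}{\partial y_{t+1,i}} \frac{\partial y_{t+1,i}}{\partial W_{ij}} = \frac{\partial \mathcal L}{\partial y_{t+1,i}}\, x_{t+1,j}$, so the full gradient is the outer product $\nabla_{y_{t+1}} \mathcal L \otimes x_{t+1} = \nabla_{y_{t+1}} \mathcal L \, x_{t+1}^\top$. Note also that the paper denotes the prediction as $y_t \coloneqq \hat y$.\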
**NOTE:** This appears to be in the context of *online* (or ***streaming***) learning, where each input arrives at a time step $t$. At step $t+1$ the model starts from the weights of step $t$, i.e. $W_t$, uses them to predict $y_{t+1}$, **and then** updates them to obtain $W_{t+1}$. This is why $y_{t+1}$ appears in the gradient.
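A sketch of this predict-then-update loop, assuming a noiseless stream generated by a hypothetical linear "teacher" `W_true` and the loss $\tfrac12\|y - \text{target}\|^2$. Each iteration predicts with the old weights $W_t$ first, then applies the outer-product form of the update rule directly, without autograd.

```python
import torch

torch.manual_seed(0)
d_in, d_out, eta, T = 4, 2, 0.1, 500

W_true = torch.randn(d_out, d_in)   # hypothetical teacher that generates the stream
W = torch.zeros(d_out, d_in)        # W_0

losses = []
for t in range(T):
    x = torch.randn(d_in)            # x_{t+1} arrives
    target = W_true @ x
    y = W @ x                        # predict y_{t+1} with the old weights W_t
    u = y - target                   # grad of 0.5 * ||y - target||^2 w.r.t. y
    losses.append(0.5 * (u ** 2).sum().item())
    W = W - eta * torch.outer(u, x)  # W_{t+1} = W_t - eta * (grad_y L outer x_{t+1})

print(f"mean loss, first 50 steps: {sum(losses[:50]) / 50:.4f}")
print(f"mean loss, last 50 steps:  {sum(losses[-50:]) / 50:.4f}")
```

Each $W_{t+1}$ depends only on $W_t$ and the current input, so the loop never revisits past data; in this noiseless setup the per-step loss should fall toward zero.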


**Reformulation**\
Now rewrite the gradient-descent step as the *regularized* minimization of the inner product between the model output $W x_{t+1}$ and the output gradient, where the regularizer keeps $W$ close to $W_t$:
$$
\begin{align*}
W_{t+1} &= \argmin_W \langle W x_{t+1}, \nabla_{y_{t+1}} \mathcal L (W_t, x_{t+1}) \rangle + \frac{1}{2\eta} \|W - W_t \|_2^2 \\
        &= \argmin_W \langle W x_{t+1}, u_{t+1} \rangle + \frac{1}{2\eta} \|W - W_t \|_2^2
\end{align*}
$$
where $u_{t+1} \coloneqq \nabla_{y_{t+1}} \mathcal L (W_t, x_{t+1})$ is the gradient of the loss with respect to the output.
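Why the two forms agree: since $\langle W x_{t+1}, u_{t+1} \rangle = \langle W, u_{t+1} x_{t+1}^\top \rangle_F$, the gradient of the objective with respect to $W$ is $u_{t+1} x_{t+1}^\top + \frac{1}{\eta}(W - W_t)$, which vanishes exactly at $W = W_t - \eta\, u_{t+1} \otimes x_{t+1}$, the gradient-descent step. The sketch below checks this numerically, reading $\|\cdot\|_2^2$ as the squared Frobenius norm (an assumption) and using a random stand-in for $u_{t+1} = \nabla_{y_{t+1}} \mathcal L (W_t, x_{t+1})$.

```python
import torch

torch.manual_seed(0)
d_in, d_out, eta = 3, 2, 0.1
W_t = torch.randn(d_out, d_in)
x = torch.randn(d_in)          # x_{t+1}
u = torch.randn(d_out)         # random stand-in for u_{t+1}

def prox_objective(W):
    # <W x, u> + 1/(2 eta) * ||W - W_t||^2, with ||.|| taken as the Frobenius norm
    return torch.dot(W @ x, u) + ((W - W_t) ** 2).sum() / (2 * eta)

# Candidate minimizer: the plain gradient-descent step.
W_gd = (W_t - eta * torch.outer(u, x)).requires_grad_(True)
(grad,) = torch.autograd.grad(prox_objective(W_gd), W_gd)
print(grad.abs().max())        # close to 0 (float32 round-off only)

# Sanity check: random perturbations of W_gd should increase the objective.
with torch.no_grad():
    base = prox_objective(W_gd)
    for _ in range(5):
        W_other = W_gd + 0.01 * torch.randn_like(W_gd)
        assert prox_objective(W_other) > base
```

The objective is a strictly convex quadratic in $W$ (its Hessian is $\frac{1}{\eta} I$), so a zero gradient already certifies the global minimizer; the perturbation loop is only a sanity check.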