# Modern Hopfield Networks (Dense Associative Memories)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from utils import get_data, get_samples_per_class

### Discrete (*[Krotov and Hopfield](https://arxiv.org/abs/1606.01164)*)

Energy function:
$$
E = -\sum_{i=1}^{N}F(x_{i}^{T}s)
$$
where $F$ is the interaction function, $x_{i}$ is the $i$-th stored pattern (one of $N$ stored patterns), and $s$ is the current state.
For example, Demircigil et al. introduced the exponential interaction function $F(x)=\exp(x)$:
$$
E = -\sum_{i=1}^{N}\exp(x_{i}^{T}s)
$$
The above equation can be rewritten as:
$$
E = -\exp(lse(1,X^{T}s))
$$
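where $X=(x_{1},\dots,x_{N})$ is the matrix whose columns are the stored patterns and $lse$ is the log-sum-exp function with inverse temperature $\beta$:
$$
lse(\beta, z) = \beta^{-1}\log\left(\sum_{l=1}^{N}\exp(\beta z_{l})\right)
$$
With $\beta=1$ and $z=X^{T}s$, this gives $\exp(lse(1,X^{T}s)) = \sum_{i=1}^{N}\exp(x_{i}^{T}s)$, which recovers the energy above.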
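A minimal sketch of this identity, assuming NumPy and a small random set of bipolar patterns stored as the columns of `X` (the data and the helper names `lse` and `energy_generic` are illustrative, not part of the notebook's `utils`). Subtracting the maximum before exponentiating is the usual trick that keeps `np.exp` from overflowing.

```python
import numpy as np

def lse(beta, z):
    """Stable lse(beta, z) = beta^{-1} * log(sum_l exp(beta * z_l))."""
    z = np.asarray(z, dtype=float).ravel()
    c = np.max(beta * z)                        # largest exponent, subtracted before exp
    return (c + np.log(np.sum(np.exp(beta * z - c)))) / beta

def energy_generic(X, s, F):
    """E = -sum_i F(x_i^T s); the stored patterns x_i are the columns of X."""
    return -np.sum(F(X.T @ s))

rng = np.random.default_rng(0)
d, N = 16, 5
X = rng.choice([-1.0, 1.0], size=(d, N))       # hypothetical bipolar patterns (columns)
s = rng.choice([-1.0, 1.0], size=d)            # hypothetical current state

E_direct = energy_generic(X, s, np.exp)        # -sum_i exp(x_i^T s)
E_lse = -np.exp(lse(1.0, X.T @ s))             # -exp(lse(1, X^T s))
print(E_direct, E_lse)                         # equal up to rounding
```

Passing `F` as a function keeps the energy generic, so other interaction functions can be swapped in. In the log domain, `lse(1.0, [1000.0, 1000.0])` evaluates to $1000 + \log 2$, whereas summing `np.exp` directly would overflow to `inf`.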
Instead of storing the patterns in a weight matrix, the network keeps them explicitly and updates the state asynchronously, one component at a time (for dimension $d$, the state has $d$ components, denoted $s[l]$). The $l$-th component is set to whichever value gives the lower energy, so the update rule compares the energy of the current state with $s[l]=1$ against its energy with $s[l]=-1$:
$$
\begin{aligned}
s^{new}[l] &= sgn\left[-E(s^{+})+E(s^{-})\right] \\
&= sgn\left[\sum_{i=1}^{N}\exp(x_{i}^{T}s^{+}) - \sum_{i=1}^{N}\exp(x_{i}^{T}s^{-})\right]
\end{aligned}
$$
where $s^{+}$ is the current state with $s[l]=1$ and $s^{-}$ is the current state with $s[l]=-1$.
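The two forms of the rule can be compared numerically. The sketch below (NumPy assumed; `energy`, `flip_margins` and the random data are illustrative) evaluates both margins for every component: the direct energy difference, which overflows for realistic $d$, and the log-space difference that the notebook's `update_state` relies on. Because $\exp$ is monotonically increasing, the two margins have the same sign except at exact ties, where rounding can go either way.

```python
import numpy as np

def lse(beta, z):
    """Stable lse(beta, z) = beta^{-1} * log(sum_l exp(beta * z_l))."""
    z = np.asarray(z, dtype=float).ravel()
    c = np.max(beta * z)
    return (c + np.log(np.sum(np.exp(beta * z - c)))) / beta

def energy(X, s):
    """E(s) = -sum_i exp(x_i^T s); the stored patterns x_i are the columns of X."""
    return -np.sum(np.exp(X.T @ s))

def flip_margins(X, s, l):
    """Return (-E(s+) + E(s-), lse(1, X^T s+) - lse(1, X^T s-)) for component l."""
    s_plus, s_minus = s.copy(), s.copy()
    s_plus[l], s_minus[l] = 1.0, -1.0
    direct = -energy(X, s_plus) + energy(X, s_minus)               # overflows for large d
    log_space = lse(1.0, X.T @ s_plus) - lse(1.0, X.T @ s_minus)   # safe for large d
    return direct, log_space

rng = np.random.default_rng(1)
d, N = 20, 6
X = rng.choice([-1.0, 1.0], size=(d, N))   # hypothetical stored patterns
s = rng.choice([-1.0, 1.0], size=d)        # hypothetical current state

margins = [flip_margins(X, s, l) for l in range(d)]
# s_new[l] = +1 if the margin is >= 0, else -1 (ties broken towards +1 here)
s_new = np.array([1.0 if m_log >= 0 else -1.0 for _, m_log in margins])
print(s_new)
```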

In [None]:
X, y = get_data(max_rows=50)
X, y = get_samples_per_class(X, y)

# show the samples side by side, each flattened row reshaped to a 28x28 image
plt.imshow(np.concatenate([x.reshape(28,28) for x in X], axis=1))
plt.title(str(y))
plt.show()
print(X.shape, y.shape)

In [None]:
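# binarize pixel intensities into bipolar values {-1, +1} (the threshold 125 assumes 0-255 intensities);
# np.where also avoids the wrap-around that assigning -1 into an unsigned integer array would cause
X = np.where(X >= 125, 1., -1.)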

In [None]:
# choose patterns to store
class_idxs_stored = [0,1,2,3,4,5,6,7,8,9]
X_stored = X[class_idxs_stored]

In [None]:
def logsumexp(z):
    # numerically stable log(sum(exp(z))) over all entries of z
    c = z.max()  # only for numerical stability
    return c + np.log(np.sum(np.exp(z - c)))

def exp_energy_func(s):
    # E(s) = -sum_i exp(x_i^T s) = -exp(lse(1, X^T s)); overflows easily, kept for reference
    return -np.exp(logsumexp(X_stored @ s))

def update_state(s):
    # one asynchronous pass over all d components, in order
    s = s.copy()  # do not modify the caller's array
    d = s.shape[0]
    for l in range(d):
        s_plus, s_minus = s.copy(), s.copy()
        s_plus[l] = 1.
        s_minus[l] = -1.
        # s[l] = np.sign(-exp_energy_func(s_plus) + exp_energy_func(s_minus)) # not numerically stable
        # exp is monotonically increasing, so sum(exp(a)) > sum(exp(b)) <=> lse(a) > lse(b)
        diff = logsumexp(X_stored @ s_plus) - logsumexp(X_stored @ s_minus)
        s[l] = 1. if diff >= 0 else -1.  # exact ties go to +1 (np.sign would give 0)
    return s

In [None]:
# check if stored patterns are in fact fixed points
retrieved = []
for c_idx in class_idxs_stored:
    s = X_stored[c_idx].copy().reshape(-1,1)
    retrieved.append(update_state(s).reshape(28,28))
plt.imshow(np.concatenate(retrieved, axis=1))

In [None]:
# try retrieving from a random bipolar state
s = np.random.choice([-1.,1.], size=(784,1))
steps = []
for step_i in range(3):
    s = update_state(s)
    steps.append(s.reshape(28,28).copy())  # copy: reshape returns a view that later updates could change
plt.imshow(np.concatenate(steps, axis=1))
plt.title("Retrieval steps")

Compared with classical Hopfield networks, the exponential interaction function gives a much higher capacity and more robust retrieval: **it can pull apart patterns that are close to each other**. For this model, the storage capacity is $C\cong2^{\frac{d}{2}}$.
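One way to see the higher capacity is to store far more random patterns than a classical Hopfield network could hold (roughly $0.14d$) and check that each one is a fixed point of the asynchronous rule. The sketch below assumes NumPy and random bipolar patterns; `lse_columns` and `is_fixed_point` are illustrative helpers. Since a fixed point leaves every component unchanged, all $d$ component decisions can be evaluated at once for a candidate state. This illustrates the capacity claim; it does not prove the bound.

```python
import numpy as np

def lse_columns(Z):
    """Column-wise stable log-sum-exp (beta = 1) of a 2-D array."""
    c = Z.max(axis=0)
    return c + np.log(np.exp(Z - c).sum(axis=0))

def is_fixed_point(Xrows, s):
    """True if no single-component update of the exponential model changes s.

    Xrows holds one stored pattern per row (i.e. X^T in the notation above).
    """
    z = Xrows @ s                                        # overlaps x_i^T s, shape (N,)
    # overlaps with s[l] set to +1 / -1, for every component l at once: shape (N, d)
    z_plus = z[:, None] + Xrows * (1.0 - s)[None, :]
    z_minus = z[:, None] + Xrows * (-1.0 - s)[None, :]
    s_new = np.where(lse_columns(z_plus) >= lse_columns(z_minus), 1.0, -1.0)
    return np.array_equal(s_new, s)

rng = np.random.default_rng(0)
d, N = 100, 500                                          # N is far above ~0.14 * d
Xrows = rng.choice([-1.0, 1.0], size=(N, d))             # hypothetical random patterns
n_fixed = sum(is_fixed_point(Xrows, x) for x in Xrows)
print(f"{n_fixed} of {N} stored patterns are fixed points")
```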

### Continuous

##### Energy function
The energy function of the discrete model,
$$
E = -\exp(lse(1,X^{T}s))
$$
can be generalized to continuous-valued states and patterns. The new energy function is defined as:
$$
E = -lse(\beta,X^{T}s) + \frac{1}{2}s^{T}s + \beta^{-1}\log(N) + \frac{1}{2}M^{2}
$$
where $\beta$ is the inverse temperature and $M$ is the largest norm of all stored patterns. The quadratic term $\frac{1}{2}s^{T}s$ keeps the norm of the state $s$ finite; the last two terms do not depend on $s$, so they do not change the dynamics, but they make the energy non-negative.
According to the [paper by Krotov and Hopfield](https://arxiv.org/abs/2008.06996), the stored patterns can be read as network weights in this setting: $X^{T}$ maps the state $s$ to hidden units (one per stored pattern), and $X$ maps the hidden units back to $s$.
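Why the constants make the energy non-negative: $lse(\beta,z)\le\max_{l}z_{l}+\beta^{-1}\log N$ and $x_{i}^{T}s\le M\lVert s\rVert$, so $E\ge\frac{1}{2}(\lVert s\rVert-M)^{2}\ge0$. The sketch below (NumPy assumed; random Gaussian patterns as the columns of `X`; `continuous_energy` is an illustrative name) evaluates the energy for many random states and several values of $\beta$ as a numerical sanity check of that bound.

```python
import numpy as np

def lse(beta, z):
    """Stable lse(beta, z) = beta^{-1} * log(sum_l exp(beta * z_l))."""
    z = np.asarray(z, dtype=float).ravel()
    c = np.max(beta * z)
    return (c + np.log(np.sum(np.exp(beta * z - c)))) / beta

def continuous_energy(X, s, beta):
    """E = -lse(beta, X^T s) + 1/2 s^T s + beta^{-1} log N + 1/2 M^2 (patterns are columns of X)."""
    N = X.shape[1]
    M = np.max(np.linalg.norm(X, axis=0))    # largest norm of the stored patterns
    return -lse(beta, X.T @ s) + 0.5 * s @ s + np.log(N) / beta + 0.5 * M**2

rng = np.random.default_rng(0)
d, N = 32, 8
X = rng.normal(size=(d, N))                  # hypothetical continuous patterns
energies = [
    continuous_energy(X, rng.normal(scale=3.0, size=d), beta)
    for beta in (0.1, 1.0, 10.0)
    for _ in range(100)
]
print(min(energies))                         # never below 0 (up to rounding)
```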

##### Update rule
The above energy function allows deriving an update rule for the state pattern $s$ by the *Concave-Convex-Procedure* described by [Yuille and Rangarajan](https://papers.nips.cc/paper/2125-the-concave-convex-procedure-cccp.pdf).
1. The total energy $E(s)$ is split into a convex and a concave term: $E(s) = E_{1}(s) + E_{2}(s)$
    * the term $E_{1}(s) = \frac{1}{2}s^{T}s + C$ is convex ($C$ is a constant independent of $s$)
    * the term $E_{2}(s) = -lse(\beta,X^{T}s)$ is concave, because $lse$ is convex and $X^{T}s$ is linear in $s$
2. The *Concave-Convex-Procedure* applied to $E(s)$ is:
$$
\nabla_{s}E_{1}(s^{t+1}) = - \nabla_{s}E_{2}(s^{t})
$$
Since $\nabla_{s}E_{1}(s) = s$ and $-\nabla_{s}E_{2}(s) = \nabla_{s}\,lse\big(\beta,X^{T}s\big)$, this gives
$$
s^{t+1} = \nabla_{s}\,lse\big(\beta,X^{T}s^{t}\big) = X \cdot {softmax}\big(\beta X^{T} s^{t} \big)
$$

where $\nabla_{s} lse\big(\beta,X^{T}s\big) = X \cdot {softmax}\big(\beta X^{T} s \big).$
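The gradient identity used in the last step can be checked with central finite differences. A minimal sketch, assuming NumPy and random Gaussian patterns as the columns of `X` (all names and sizes are illustrative):

```python
import numpy as np

def lse(beta, z):
    """Stable lse(beta, z) = beta^{-1} * log(sum_l exp(beta * z_l))."""
    z = np.asarray(z, dtype=float).ravel()
    c = np.max(beta * z)
    return (c + np.log(np.sum(np.exp(beta * z - c)))) / beta

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

rng = np.random.default_rng(0)
d, N, beta = 10, 4, 2.0
X = rng.normal(size=(d, N))            # hypothetical stored patterns (columns)
s = rng.normal(size=d)                 # hypothetical state

analytic = X @ softmax(beta * (X.T @ s))
eps = 1e-6
# central difference along each coordinate direction e_k
numeric = np.array([
    (lse(beta, X.T @ (s + eps * e)) - lse(beta, X.T @ (s - eps * e))) / (2 * eps)
    for e in np.eye(d)
])
print(np.max(np.abs(analytic - numeric)))
```

Note that $\beta$ cancels in the gradient: the $\beta^{-1}$ in front of the logarithm and the $\beta$ inside the exponent multiply to one.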

Therefore, the update rule for a state pattern $s$ reads:
$$
s^{new} = X\cdot{softmax}\big(\beta X^{T} s \big)
$$
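Following the hidden-unit reading by Krotov and Hopfield mentioned earlier, one update is two matrix products with a softmax in between. The sketch below (NumPy assumed; `update` is an illustrative name) stores random, well-separated Gaussian patterns as the columns of `X` and checks that each stored pattern is, to numerical precision, a fixed point for $\beta=1$.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def update(X, s, beta):
    """One step of s_new = X softmax(beta X^T s), read as a two-layer network."""
    hidden = softmax(beta * (X.T @ s))   # X^T: state -> N hidden units, one per stored pattern
    return X @ hidden                    # X: hidden units -> state

rng = np.random.default_rng(0)
d, N, beta = 256, 10, 1.0
X = rng.normal(size=(d, N))             # hypothetical, well-separated patterns
max_err = max(np.max(np.abs(update(X, X[:, i], beta) - X[:, i])) for i in range(N))
print(max_err)
```

With the notebook's row-major `X_stored`, the same step is `X_stored.T @ softmax(beta * (X_stored @ s))`.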

A few important properties *(from the paper)*:
* The *Concave-Convex-Procedure* used to obtain the update rule guarantees that the energy decreases monotonically
* Global convergence: iterating the update rule converges to a fixed point, typically a local minimum of the energy
* Exponential storage capacity
* Well-separated patterns are typically retrieved after a single update step
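The first and last properties can be observed on a small example. A minimal sketch, assuming NumPy and random Gaussian patterns as the columns of `X` (sizes, noise level and $\beta$ values are arbitrary choices): it tracks the energy along a trajectory from a random start with a small $\beta$, so that several steps matter, and then applies a single update with $\beta=1$ to a noisy copy of a stored pattern.

```python
import numpy as np

def lse(beta, z):
    """Stable lse(beta, z) = beta^{-1} * log(sum_l exp(beta * z_l))."""
    z = np.asarray(z, dtype=float).ravel()
    c = np.max(beta * z)
    return (c + np.log(np.sum(np.exp(beta * z - c)))) / beta

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def energy(X, s, beta):
    """Continuous energy; the stored patterns are the columns of X."""
    N = X.shape[1]
    M = np.max(np.linalg.norm(X, axis=0))
    return -lse(beta, X.T @ s) + 0.5 * s @ s + np.log(N) / beta + 0.5 * M**2

def update(X, s, beta):
    """s_new = X softmax(beta X^T s)."""
    return X @ softmax(beta * (X.T @ s))

rng = np.random.default_rng(0)
d, N = 128, 20
X = rng.normal(size=(d, N))                      # hypothetical stored patterns

# (1) energy along a trajectory; a small beta keeps the softmax soft
beta_small = 0.05
s = rng.normal(size=d)
energies = [energy(X, s, beta_small)]
for _ in range(10):
    s = update(X, s, beta_small)
    energies.append(energy(X, s, beta_small))
monotone = all(e_next <= e_prev + 1e-9 for e_prev, e_next in zip(energies, energies[1:]))

# (2) a single update from a noisy copy of the first stored pattern
query = X[:, 0] + 0.3 * rng.normal(size=d)
retrieved = update(X, query, 1.0)
rel_err = np.linalg.norm(retrieved - X[:, 0]) / np.linalg.norm(X[:, 0])
print(monotone, rel_err)
```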

In [None]:
X, y = get_data(max_rows=50)
X, y = get_samples_per_class(X, y)

# show the samples side by side, each flattened row reshaped to a 28x28 image
plt.imshow(np.concatenate([x.reshape(28,28) for x in X], axis=1))
plt.title(str(y))
plt.show()
print(X.shape, y.shape)

In [None]:
# choose patterns to store
class_idxs_stored = [0,1,2,3,4,5,6,7,8,9]
X_stored = X[class_idxs_stored].astype(float)  # float avoids integer overflow in X_stored @ s if the pixels are integers

In [None]:
def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

def update_state(s, beta=1.):
    # s_new = X softmax(beta X^T s); X_stored holds one pattern per row, i.e. X_stored = X^T
    s_new = X_stored.T @ softmax(beta * (X_stored @ s).squeeze())
    return s_new

In [None]:
# check if stored patterns are in fact fixed points
retrieved = []
for c_idx in class_idxs_stored:
    s = X_stored[c_idx].copy().reshape(-1,1)
    retrieved.append(update_state(s).reshape(28,28))
plt.imshow(np.concatenate(retrieved, axis=1))

In [None]:
# try retrieving random pattern
s = np.random.randn(784,1) * 255
steps = []
for step_i in range(3):
    s = update_state(s)
    steps.append(s.reshape(28,28))
plt.imshow(np.concatenate(steps, axis=1))
plt.title("Retrieval steps")