# Hopfield - Self-attention
The update of the new energy function (notebook [3_hopfield-continuous-value.ipynb](./3_hopfield-continuous-value.ipynb)) is the self-attention of transformer networks.

References:
* https://ml-jku.github.io/hopfield-layers/#update

From equation:
$$
\xi^{new} = X\mathrm{softmax}(\beta X^T \xi)
$$

For $S$ state patterns $\Xi=(\xi_1,...,\xi_s)$, the equation can be generalized to:
$$
\Xi^{\mathrm{new}} = X\mathrm{softmax}(\beta X^T\Xi)
$$

Where $X^T$ can be considered as $N$ *raw **stored** patterns* $Y=(y_1,...y_N)^T$, which are mapped to an associative space via $W_K$, and $\Xi^T$ as $S$ *raw **state** patterns* $R=(\xi_1,...,\xi_S)^T$, which are mapped to an associative space via $W_Q$.

Then, by setting:
$$
Q = \Xi^T = RW_Q \\
K = X^T = YW_K \\
\beta = \frac{1}{\sqrt{d_k}}
$$

we obtain:
$$
(Q^{\mathrm{new}})^T = K^T \mathrm{softmax}(\frac{1}{\sqrt{d_k}}KQ^T)
$$

Where $W_Q$ and $W_K$ are matrices which map the respective patterns into the associative space. In the previous equation, the softmax is applied column-wise tot he matrix $KQ^T$. By transposing the equation, which also means softmax is now applied row-wise to its transposed input $QK^T$, we obtain:

$$
(Q^{\mathrm{new}})^T = \mathrm{softmax}(\frac{1}{\sqrt{d_k}}QK^T)K
$$

Now, by projecting $Q^{new}$ via another projection matrix $W_V$ we obtain:

$$
Z = Q^{new}W_V = \mathrm{softmax}(\frac{1}{\sqrt{d_k}}QK^T)KW_V = \mathrm{softmax}(\frac{1}{\sqrt{d_k}}QK^T)V
$$

Which is basically the transformer attention formula (As per Attention is All you need):
$$
\mathrm{Attention(Q, K, V)} = \mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$

Some remarks:
* Transformer based models usually implement embedding layers before the attention mechanism, i.e., what is feed into the attention mechanism is an embedding of the input/outputs.
    * These embeddings have trainable matrices that produces them during training.
* In the new Hopfield definition, the matrices that produce the embeddings are explicit in the formula, i.e., matrices $W_Q$, $W_K$, and $W_V$ are the matrices that transform the input/outpus into the associative space that is feed to the attention mechanism.
* One differencing aspect of original attention vs Hopfield is the value of $\beta$ parameter. Original attention fixes this to be dependent on the dimension of the embeddings, which for large values of $d_k$ will yield in smaller $\beta$, which in turn, as per explained in the new Hopfield paper, means the retrievals will tend to be metastable states or the average of similar patterns which can give us an intuition of why they work and why the concept of "Attention".
* The new Hopfield definition can be interpreted as a generalization of the attention mechanism.
* The result of the retrieval, which is the attention produced from the state patterns against the stored patterns, can be the input to fully connected layers for some classification task.
* Similarly, before the attention mechanism, there can be other feature extraction layers s.a. CNNs that will produce vectors for which store/retrieval process can be applied.

## Hopfield MNIST #1 - Predict using full patterns


In [1]:
import numpy as np
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import torch.nn.functional as F

random_seed = 123
train_split_fraction = 0.5
np.random.seed(random_seed)
torch.manual_seed(random_seed)

data_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(0, 0.1)
    ]
)

target_transform = transforms.Compose(
    [
        lambda x: torch.LongTensor([x]),
        lambda x: F.one_hot(x, 10),
    ]
)

train_set = MNIST(
    './mnist-train', 
    train=True, 
    download=True, 
    transform=data_transform,
    target_transform=target_transform
    )

test_set = MNIST(
    './mnist-test', 
    train=False, 
    download=True, 
    transform=data_transform,
    target_transform=target_transform
    )

train_set_size = len(train_set)
indices = list(range(train_set_size))
np.random.shuffle(indices)
split = int(np.floor(train_split_fraction * train_set_size))
stored_patterns_idx, train_idx = indices[split:], indices[:split]

stored_patterns_sampler = SubsetRandomSampler(stored_patterns_idx)
train_sampler = SubsetRandomSampler(train_idx)

train_loader = DataLoader(train_set, batch_size=256, sampler=train_sampler)
stored_patterns_loader = DataLoader(train_set, batch_size=split, sampler=stored_patterns_sampler)
stored_patterns = list(stored_patterns_loader)

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
def prepare_device(obj, use_cuda: bool = True):
    if use_cuda and torch.cuda.is_available():
        return obj.to("cuda")
    return obj

In [6]:
import torch.nn as nn
import torch

class HopfieldNet(nn.Module):
    def __init__(
        self, 
        store_dim: int, 
        hidden_store_dim: int, 
        state_dim: int, 
        hidden_state_dim: int, 
        value_dim: int,
        hidden_value_dim: int):
        super().__init__()
        
        self.store_dim = store_dim
        self.hidden_store_dim = hidden_store_dim
        self.state_dim = state_dim
        self.hidden_state_dim = hidden_state_dim
        self.value_dim = value_dim
        self.hidden_value_dim = hidden_value_dim
        self.device = device

        self.__init_parameters()
        self._reset_parameters()
    
    def __init_parameters(self):
        # state patterns
        self.WQ = nn.Parameter(
            torch.Tensor(self.state_dim, self.hidden_state_dim)
        )
        # stored patterns
        self.WK = torch.Tensor(
            torch.Tensor(self.store_dim, self.hidden_store_dim)
        )
        # value patterns
        self.WV = torch.Tensor(
            torch.Tensor(self.value_dim, self.hidden_value_dim)
        )

    def to(self, device: str):
        super().to(device)
        self.WQ = self.WQ.to(device)
        self.WK = self.WK.to(device)
        self.WV = self.WV.to(device)
        return self

    def _reset_parameters(self):
        nn.init.normal_(self.WQ, 0., 0.01)
        nn.init.normal_(self.WK, 0., 0.01)
        nn.init.normal_(self.WV, 0., 0.01)

    def print_shapes(self):
        print(f"WQ: {self.WQ.shape}")
        print(f"WK: {self.WK.shape}")
        print(f"WV: {self.WV.shape}")
        

    def forward(self, state_patterns, stored_patterns, value_patterns, beta=1.0):
        Q = state_patterns @ self.WQ
        K = stored_patterns @ self.WK
        V = value_patterns @ self.WV

        Z = torch.softmax(beta * Q @ K.T, dim=1) @ V
        return Z 

In [7]:
store_dim = 784 # 784 = 28 * 28, i.e., flattened images
state_dim = store_dim # stored and state paterns have the same dim
value_dim = 10 # the one-hot expected label from the state patterns
hidden_store_dim, hidden_state_dim = 256, 256
hidden_value_dim = value_dim # We don't need to embeed in a lower dimension the one hot encoding
beta = 1 / np.sqrt(store_dim) # beta as classic transformers

model = HopfieldNet(store_dim, hidden_store_dim, state_dim, hidden_state_dim, value_dim, hidden_value_dim)
model = prepare_device(model)
num_params = sum([p.numel() for p in model.parameters()])
print(f"Total number of parameters in model: {num_params}")

Total number of parameters in model: 200704


Shape operations test

In [17]:
model.print_shapes()
sim_raw_state = torch.randn((20, 784)).to(device)
sim_raw_stored = torch.randn((200, 784)).to(device)
sim_raw_value = torch.randn((200, 10)).to(device)
result = model(sim_raw_state, sim_raw_stored, sim_raw_value)
print(f"Shape of Result: {result.shape}")
assert list(result.shape) == [20, 10], "Result shape doesn't match with the expected result"


WQ: torch.Size([784, 256])
WK: torch.Size([784, 256])
WV: torch.Size([10, 10])
Shape of Result: torch.Size([20, 10])


In [18]:
import torch.optim as optim
import torchmetrics

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)
metric = torchmetrics.Accuracy()

In [66]:
from tqdm import tqdm
import torch.nn.functional as F

def train_loop(
    dataloader, 
    model, 
    stored_patterns_K, 
    stored_projections_V, 
    loss_fn, 
    optimizer):
    for R, y in dataloader:
        R = prepare_device(R).view(-1, 28 * 28)
        y = prepare_device(y.view(-1, 10)).to(torch.float32)
        
        pred = model(R, stored_patterns_K, stored_projections_V)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return loss.item()


def evaluate(
    dataloader, 
    model,
    stored_patterns_K, 
    stored_projections_V, 
    metric_fn):
    model.eval()
    with torch.no_grad():
        for R, y in dataloader:
            R = prepare_device(R).view(-1, 28 * 28)
            y = prepare_device(y.view(-1, 10)).to("cpu")

            preds = model(R, stored_patterns_K, stored_projections_V).to("cpu")
            metric = metric_fn(preds, y)
        metric = metric_fn.compute()
        metric_fn.reset()
    return metric


def train(
    epochs, 
    dataloader, 
    model,
    stored_patterns_K,
    stored_projections_V, 
    loss_fn, 
    optimizer, 
    metric_fn):
    model.train()
    with tqdm(total=epochs) as progress:
        for _ in range(epochs):
            loss = train_loop(
                dataloader, 
                model,
                stored_patterns_K,
                stored_projections_V, 
                loss_fn, 
                optimizer)
            metric = evaluate(
                dataloader, 
                model, 
                stored_patterns_K,
                stored_projections_V, 
                metric_fn)
            progress.set_postfix({
                        "loss": f"{loss:.2f}",
                        "accuracy": f"{metric:.2f}"
                    })
            progress.update()

        

In [67]:
K = stored_patterns[0][0]
K_labels = stored_patterns[0][1]
K = K.view(-1, 28*28)
K_labels = K_labels.view(-1, 10) # ten expected classes
K = prepare_device(K)
K_labels = prepare_device(K_labels).to(torch.float32)

In [68]:
train(10, train_loader, model, K, K_labels, loss_fn, optimizer, metric)

100%|██████████| 10/10 [05:38<00:00, 33.84s/it, loss=2.30, accuracy=0.90]
