# Memory mechanism

## Implementation

The "Memorizing Transformers" paper introduces a kNN-attention mix that works as follows:

```
knn_attention(embeddings) = attention(
    knn(embeddings_storage, embeddings) * weight_knn +
    embeddings * weight_local
)
```

where `knn` retrieves the nearest-neighbour embeddings from `embeddings_storage`. After each inference step, the corresponding `embeddings` are added to `embeddings_storage`.

In [1]:
#| default_exp memory_collection

In [2]:
#| export
from __future__ import annotations
import os
from typing import Union, List
import pickle
import numpy as np
import pandas as pd
import torch
from sklearn.neighbors import NearestNeighbors

In [3]:
#| export
class BaseMemoryCollection:
    """
    Interface for a store of embeddings ("memories") that can be queried by similarity.

    Every method here raises ``NotImplementedError``; concrete subclasses provide
    the actual behaviour.
    """

    def reset(self) -> None:
        """Forget everything remembered so far."""
        raise NotImplementedError()

    def get(self, inputs: torch.FloatTensor) -> torch.FloatTensor:
        """
        Return the stored "memories" most relevant to ``inputs``.

        The exact shape contract is defined by the subclass (presumably one
        memory per input row — confirm against the implementation in use).
        """
        raise NotImplementedError()

    def add(self, inputs: torch.FloatTensor) -> None:
        """Remember ``inputs`` so that later ``get`` calls can retrieve them."""
        raise NotImplementedError()

    def save(self, directory: str) -> None:
        """Persist the memory state into ``directory``."""
        raise NotImplementedError()

    @staticmethod
    def load(directory: str) -> "BaseMemoryCollection":
        """Restore a memory state previously written by ``save`` into ``directory``."""
        raise NotImplementedError()

In [4]:
#| export
class CosineKnnMemoryCollection(BaseMemoryCollection):
    """
    Embedding memory searched by cosine similarity.

    Every added vector is appended to ``vectors`` and to ``temporary_buffer``.
    Once the buffer holds at least ``max_temporary_buffer_size`` vectors it is cut
    into chunks of exactly that size; each chunk gets its own fitted
    ``NearestNeighbors`` index in ``knns`` and the remainder stays in the buffer,
    which ``get`` searches with a brute-force index built on demand.

    Because every fitted index holds exactly ``max_temporary_buffer_size`` vectors,
    local neighbour ``j`` of index ``k`` is ``vectors[k * max_temporary_buffer_size + j]``
    (the temporary buffer behaves like index ``len(knns)``).
    """

    def __init__(self, max_temporary_buffer_size: int) -> None:
        super().__init__()
        if max_temporary_buffer_size <= 0:
            raise ValueError(
                f"max_temporary_buffer_size must be positive, got {max_temporary_buffer_size}"
            )
        self.max_temporary_buffer_size = max_temporary_buffer_size
        self.knns = []  # fitted NearestNeighbors, one per full chunk of vectors
        self.temporary_buffer = []  # vectors not yet assigned to a fitted index
        self.vectors = []  # every stored vector, in insertion order

    def reset(self) -> None:
        """Forget every stored vector and drop all fitted indices."""
        self.knns = []
        self.temporary_buffer = []
        self.vectors = []

    def _embeddings_numpy(self, embeddings: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
        """Stack a list of vectors into one array; arrays are returned unchanged."""
        if isinstance(embeddings, list):
            return np.array(embeddings)
        return embeddings

    def _norm(self, inputs: np.ndarray) -> np.ndarray:
        """
        Scale every vector along the last axis to unit L2 length.

        All-zero vectors have no direction; they are replaced by the constant
        vector with all components equal to ``1 / sqrt(embedding_dim)``, which
        also has unit length. Works for arrays of any rank.
        """
        embedding_dim = inputs.shape[-1]
        # sqrt(embedding_dim * x^2) = 1, x > 0  =>  x = 1 / sqrt(embedding_dim)
        filler_dim_value = 1 / np.sqrt(embedding_dim)

        norm = np.sqrt((inputs ** 2).sum(axis=-1, keepdims=True))

        # Start from the filler and divide only where the norm is non-zero, so zero
        # vectors neither trigger a divide-by-zero warning nor pass through NaN.
        inputs_normed = np.full(
            inputs.shape, filler_dim_value, dtype=np.result_type(inputs.dtype, norm.dtype)
        )
        np.divide(inputs, norm, out=inputs_normed, where=norm != 0)
        return inputs_normed

    def _fit_knn(self, embeddings, algorithm: str) -> NearestNeighbors:
        """Fit a 1-nearest-neighbour index over the normed ``embeddings``."""
        # For unit vectors ||a - b||^2 = 2 - 2 * cos(a, b), so the L2 nearest
        # neighbour of normed vectors is also the cosine nearest neighbour.
        # Minkowski metric with p=2 is plain L2.
        nn = NearestNeighbors(n_neighbors=1, algorithm=algorithm, metric="minkowski", p=2, n_jobs=-1)
        nn.fit(self._norm(self._embeddings_numpy(embeddings)))
        return nn

    def _bruteforce_knn(self, embeddings) -> NearestNeighbors:
        """Exact brute-force index; used for the small, frequently changing buffer."""
        return self._fit_knn(embeddings, "brute")

    def _knn(self, embeddings) -> NearestNeighbors:
        """Index with the algorithm chosen by scikit-learn; used for full chunks."""
        return self._fit_knn(embeddings, "auto")

    def get(self, inputs: torch.FloatTensor) -> torch.FloatTensor:
        """
        Return, for every row of ``inputs``, the stored vector most cosine-similar to it.

        If nothing is stored yet, or the batch is empty, ``inputs`` is returned as is.
        The result has the dtype and device of ``inputs``.
        """
        knns = list(self.knns)
        if self.temporary_buffer:
            knns.append(self._bruteforce_knn(self.temporary_buffer))
        if not knns or inputs.shape[0] == 0:
            return inputs

        vectors_normed = self._norm(inputs.detach().cpu().float().numpy())
        batch_size, embedding_dim = inputs.shape[0], inputs.shape[1]

        # Stage 1: the nearest candidate from every index -> (batch, len(knns), dim).
        vectors_found = np.zeros((batch_size, len(knns), embedding_dim), dtype=np.float32)
        for knn_index, nn in enumerate(knns):
            local_indices = nn.kneighbors(vectors_normed, return_distance=False).ravel()
            offset = knn_index * self.max_temporary_buffer_size
            vectors_found[:, knn_index, :] = [self.vectors[offset + j] for j in local_indices]

        # Stage 2: keep the candidate closest to its query in normed L2 (i.e. cosine),
        # vectorised instead of fitting a separate NearestNeighbors per query row.
        candidates_normed = self._norm(vectors_found)
        distances = ((candidates_normed - vectors_normed[:, None, :]) ** 2).sum(axis=-1)
        best = distances.argmin(axis=1)
        vectors_chosen = vectors_found[np.arange(batch_size), best]
        return torch.tensor(vectors_chosen, dtype=inputs.dtype, device=inputs.device)

    def add(self, inputs: torch.FloatTensor) -> None:
        """Store every row of ``inputs``; fit an index for each newly completed chunk."""
        # .numpy() shares memory with the (detached) input tensor; copy so that a later
        # in-place change to the caller's tensor cannot silently alter the memory.
        vectors_list = list(inputs.detach().cpu().float().numpy().copy())
        self.temporary_buffer += vectors_list
        self.vectors += vectors_list

        chunk = self.max_temporary_buffer_size
        knn_count, _ = divmod(len(self.temporary_buffer), chunk)
        for k in range(knn_count):
            self.knns.append(self._knn(self.temporary_buffer[k * chunk : (k + 1) * chunk]))
        if knn_count:
            # Keep only the remainder that did not fill a whole chunk.
            self.temporary_buffer = self.temporary_buffer[knn_count * chunk :]

    def save(self, directory: str) -> None:
        """Pickle the whole collection (including fitted indices) into ``directory``."""
        os.makedirs(directory, exist_ok=True)
        with open(os.path.join(directory, "cosine-knn-memory.pkl"), "wb") as dst:
            pickle.dump(self, dst)

    @staticmethod
    def load(directory: str) -> BaseMemoryCollection:
        """
        Load a collection written by ``save``.

        Raises:
            TypeError: if the pickle does not contain a ``CosineKnnMemoryCollection``.
        """
        # SECURITY: pickle.load can execute arbitrary code stored in the file; only
        # load memories from directories you trust.
        with open(os.path.join(directory, "cosine-knn-memory.pkl"), "rb") as src:
            memory = pickle.load(src)
        if not isinstance(memory, CosineKnnMemoryCollection):
            raise TypeError(
                f"Expected CosineKnnMemoryCollection, got {type(memory).__name__}"
            )
        return memory

## Tests

### Distance correlation

In [5]:
def _test_cosine_vectors(vec1, vec2):
    norm1 = np.sqrt((vec1 ** 2).sum())
    norm2 = np.sqrt((vec2 ** 2).sum())
    return 1.0 - (vec1 * vec2).sum() / (norm1 * norm2)

In [6]:
def _test_normed_l2_vectors(vec1, vec2):
    def _normed(vector):
        return vector / np.sqrt((vector ** 2).sum())
    
    vec1 = _normed(vec1)
    vec2 = _normed(vec2)
    return np.sqrt(((vec1 - vec2) ** 2).sum())

In [7]:
def _test_distances_corr(n_pairs: int = 10000, dim: int = 100):
    """
    Spearman rank correlation between cosine distance and normed-L2 distance.

    Draws ``n_pairs`` pairs of random ``dim``-dimensional vectors (uniform in
    [0, 1), from the global NumPy RNG), computes both distances for every pair and
    returns their rank correlation. 1.0 means both distances order pairs identically.
    """
    vecs1 = np.random.rand(n_pairs * dim).reshape(n_pairs, dim)
    vecs2 = np.random.rand(n_pairs * dim).reshape(n_pairs, dim)
    cosine = [_test_cosine_vectors(v1, v2) for v1, v2 in zip(vecs1, vecs2)]
    normed_l2 = [_test_normed_l2_vectors(v1, v2) for v1, v2 in zip(vecs1, vecs2)]
    distances = pd.DataFrame({"cosine": cosine, "normed_l2": normed_l2})
    correlation = distances.corr(method="spearman")["cosine"]["normed_l2"]
    return correlation

In [8]:
# Repeat the correlation check with 20 different seeds (42..61); every run must show
# (almost) perfect rank agreement between cosine and normed-L2 distances.
corrs = []
for seed in range(42, 42 + 20):
    np.random.seed(seed)
    corrs.append(_test_distances_corr())
lowest_corr, highest_corr = min(corrs), max(corrs)
assert lowest_corr >= 0.99
lowest_corr, highest_corr

(1.0, 1.0)

### Memory

In [9]:
import tempfile

vectors_count = 10000
vectors_dim = 100
max_buffer_size = 1000
batch_size = 2
eps = 1e-6

np.random.seed(42)
vectors = torch.FloatTensor(np.random.rand(vectors_count * vectors_dim).reshape([vectors_count, -1]))
memory = CosineKnnMemoryCollection(max_buffer_size)
for i in range(vectors_count // batch_size):
    batch = vectors[i * batch_size : (i + 1) * batch_size]
    # In real use we would first retrieve similar embeddings and then add the new ones;
    # here we add first to check that every stored vector can be retrieved exactly.
    memory.add(batch)
    batch_extracted = memory.get(batch)
    assert (batch - batch_extracted).abs().max() < eps

# Round-trip through save/load in a throw-away directory so no files are left behind.
with tempfile.TemporaryDirectory() as memory_dir:
    memory.save(memory_dir)
    memory_loaded = CosineKnnMemoryCollection.load(memory_dir)
for i in range(vectors_count // batch_size):
    batch = vectors[i * batch_size : (i + 1) * batch_size]
    batch_extracted = memory_loaded.get(batch)
    assert (batch - batch_extracted).abs().max() < eps

In [10]:
#| hide
# Export this notebook's `#| export` cells into the library module named by `#| default_exp` (memory_collection).
import nbdev; nbdev.nbdev_export()