# Memory mechanism

## Implementation

The "Memorizing Transformers" paper introduces a kNN-attention mix that works as follows:

```
knn_attention(embeddings) = attention(
    knn(embeddings_storage, embeddings) * weight_knn +
    embeddings * weight_local
)
```

where `knn` retrieves the nearest-neighbour embeddings from `embeddings_storage`; after each inference step, the corresponding `embeddings` are added to `embeddings_storage`.

In [1]:
#| default_exp memory_collection

In [2]:
#| export
from __future__ import annotations
import os
from typing import Union, List
import pickle
import numpy as np
import pandas as pd
import torch
from sklearn.neighbors import NearestNeighbors

In [3]:
#| export
class BaseMemoryCollection:
    """
    Base class for per-sequence embedding memories.

    Subclasses implement `get` (retrieval), `_add_filtered` (storage), `save` and `load`.
    `add` tracks how many tokens have been remembered so far and forwards to
    `_add_filtered` only the tokens whose global position is below
    `remember_until_position`.
    """

    def __init__(self, top_k: int, remember_until_position: int = 0):
        """
        :param top_k: number of memories to retrieve per input row (consumed by subclasses)
        :param remember_until_position: only tokens with global position < this value are
            remembered; with the default of 0 nothing is remembered
        """
        self.top_k = top_k
        self.remember_until_position = remember_until_position
        # Global position of local position 0 of the next chunk passed to `add`.
        self._local2global_position_offset = 0
        # Number of tokens forwarded to `_add_filtered` so far.
        self._remembered_tokens = 0

    def reset(self) -> None:
        """
        Reset memory
        """
        self._local2global_position_offset = 0
        self._remembered_tokens = 0

    def get(self, inputs: torch.FloatTensor) -> torch.FloatTensor:
        """
        Get relevant "memories".
        :param inputs: (Individual) sequence embedding matrix (2d array)
        """
        raise NotImplementedError()

    def _check_position_ids_sequential(self, local_position_ids: torch.LongTensor) -> None:
        """
        Check that local_position_ids is the sequence 0, 1, 2, ... (an empty vector is accepted).
        :param local_position_ids: position ids (1d tensor)
        """
        assert len(local_position_ids.shape) == 1
        with torch.no_grad():
            if local_position_ids.shape[0]:
                assert local_position_ids[0] == 0
            id_diff = local_position_ids[1:] - local_position_ids[:-1]
            assert torch.all(id_diff == 1)

    def add(self, inputs: torch.FloatTensor, local_position_ids: torch.LongTensor) -> None:
        """
        Remember stuff.
        But only the part whose (global) id < self.remember_until_position

        :param inputs: (Individual) sequence embedding matrix (2d array)
        :param local_position_ids: (Individual) sequence token ids (inside the chunk processed by transformer);
            must be 0, 1, 2, ... and may be empty
        """
        self._check_position_ids_sequential(local_position_ids)
        assert len(inputs.shape) == 2
        assert inputs.shape[0] == local_position_ids.shape[0]
        with torch.no_grad():
            global_position_ids = self._local2global_position_offset + local_position_ids
            # An empty chunk has no first position to validate (indexing it would raise IndexError).
            if global_position_ids.shape[0]:
                # The chunk must not skip tokens: it has to start at or before the next token to remember.
                assert global_position_ids[0] <= self._remembered_tokens
            remember_mask = (global_position_ids < self.remember_until_position) & (global_position_ids >= self._remembered_tokens)
            # Boolean row indexing keeps the (n_selected, dim) shape, including n_selected == 0.
            remember_inputs = inputs[remember_mask]
        tokens_to_remember = remember_inputs.shape[0]
        self._add_filtered(remember_inputs)
        self._remembered_tokens += tokens_to_remember
        self._local2global_position_offset += tokens_to_remember

    def _add_filtered(self, inputs: torch.FloatTensor) -> None:
        """
        Remember the inputs embeddings
        :param inputs: (Individual) sequence embedding matrix (2d array)
        """
        raise NotImplementedError()

    def save(self, directory: str) -> None:
        """
        Save memory state
        """
        raise NotImplementedError()

    @staticmethod
    def load(directory: str) -> BaseMemoryCollection:
        """
        Load memory state
        """
        raise NotImplementedError()

In [4]:
#| export
class CosineKnnMemoryCollection(BaseMemoryCollection):
    """
    Memory returning, for each query row, the `top_k` remembered embeddings closest by cosine similarity.

    Remembered embeddings are appended to `temporary_buffer`. Whenever the buffer holds at
    least `max_temporary_buffer_size` entries, every full chunk of that size is frozen into
    its own `NearestNeighbors` index in `self.knns`. The remaining (partial) buffer is
    searched through a lazily built brute-force index. `self.vectors` keeps every remembered
    embedding in insertion order, so row `r` of the `n`-th index corresponds to
    `self.vectors[n * max_temporary_buffer_size + r]`.
    """

    def __init__(self, top_k: int, max_temporary_buffer_size: int, remember_until_position: int = 0) -> None:
        """
        :param top_k: number of memories returned per input row by `get`
        :param max_temporary_buffer_size: number of vectors frozen into each kNN index (must be positive)
        :param remember_until_position: only tokens with global position < this value are remembered
        """
        super().__init__(top_k, remember_until_position)
        if max_temporary_buffer_size <= 0:
            # A zero size would only fail later, with ZeroDivisionError, on the first add.
            raise ValueError("max_temporary_buffer_size must be positive")
        self.max_temporary_buffer_size = max_temporary_buffer_size
        self.knns = []  # frozen indices, each over exactly max_temporary_buffer_size vectors
        self.temporary_buffer = []  # remembered vectors not yet frozen into self.knns
        self.vectors = []  # every remembered vector, in insertion order
        self._buffer_knn = None  # cached index over temporary_buffer; None means "rebuild on demand"

    def reset(self) -> None:
        super().reset()
        self.knns = []
        self.temporary_buffer = []
        self.vectors = []
        self._buffer_knn = None

    def _embeddings_numpy(self, embeddings: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray:
        """Stack a list of 1d vectors into a 2d array; arrays are returned unchanged."""
        if isinstance(embeddings, list):
            return np.array(embeddings)
        return embeddings

    def _norm(self, inputs: np.ndarray) -> np.ndarray:
        """
        Scale every row of `inputs` to unit L2 norm.

        A zero row cannot be normalised, so it is replaced by the constant vector whose
        components all equal 1/sqrt(embedding_dim). That vector also has unit norm.
        """
        embedding_dim = inputs.shape[-1]
        # sqrt(embedding_dim * x^2) = 1, x > 0  =>  x = 1 / sqrt(embedding_dim)
        filler_dim_value = 1 / np.sqrt(embedding_dim)

        norm = np.sqrt((inputs ** 2).sum(axis=-1, keepdims=True))
        zero_rows = norm[..., 0] == 0
        # Divide zero rows by 1 instead of 0 (no RuntimeWarning / NaN); they are overwritten next.
        inputs_normed = inputs / np.where(norm == 0, 1, norm)
        inputs_normed[zero_rows] = filler_dim_value
        return inputs_normed

    def _fit_nearest_neighbors(self, embeddings, algorithm: str, n_jobs) -> NearestNeighbors:
        """Fit an L2 `NearestNeighbors` index over the unit-normalised `embeddings`."""
        # For unit vectors ||a - b||^2 = 2 * (1 - cos(a, b)), so ranking by L2 distance
        # is identical to ranking by cosine similarity.
        # The Minkowski metric with p=2 is the L2 (Euclidean) distance.
        nn = NearestNeighbors(n_neighbors=self.top_k, algorithm=algorithm, metric="minkowski", p=2, n_jobs=n_jobs)
        nn.fit(self._norm(self._embeddings_numpy(embeddings)))
        return nn

    def _bruteforce_knn(self, embeddings, n_jobs=1) -> NearestNeighbors:
        """Brute-force index; used for the small, frequently rebuilt temporary buffer."""
        return self._fit_nearest_neighbors(embeddings, "brute", n_jobs)

    def _knn(self, embeddings, n_jobs=-1) -> NearestNeighbors:
        """Index with sklearn's automatic algorithm choice; used for frozen buffer chunks."""
        return self._fit_nearest_neighbors(embeddings, "auto", n_jobs)

    def _get_buffer_knn(self) -> NearestNeighbors:
        """Return the index over `temporary_buffer`, rebuilding it if it was invalidated."""
        if self._buffer_knn is None:
            self._buffer_knn = self._bruteforce_knn(self.temporary_buffer)
        return self._buffer_knn

    def get(self, inputs: torch.FloatTensor) -> torch.FloatTensor:
        """
        Retrieve the remembered embeddings closest (by cosine similarity) to each input row.

        :param inputs: (Individual) sequence embedding matrix (2d array, `(seq_len, dim)`)
        :return: `(seq_len, top_k, dim)` tensor with the dtype and device of `inputs`.
            For each input row it holds the `top_k` nearest remembered vectors, closest first.
            When fewer than `top_k` vectors are remembered, the missing slots are zeros.
            When nothing is remembered yet, `inputs` itself is returned unchanged.
            NOTE(review): that return is `(seq_len, dim)`, not `(seq_len, top_k, dim)`;
            confirm callers expect this.
        """
        knns = list(self.knns)
        if self.temporary_buffer:
            knns.append(self._get_buffer_knn())
        if not knns:
            return inputs

        seq_len, embedding_dim = inputs.shape
        if seq_len == 0:
            # kneighbors() rejects an empty query; the answer is trivially empty.
            return torch.zeros((0, self.top_k, embedding_dim), dtype=inputs.dtype, device=inputs.device)

        with torch.no_grad():
            vectors = inputs.detach().cpu().float().numpy()
        vectors_normed = self._norm(vectors)

        # Slots that an index cannot fill keep a zero vector and an infinite distance.
        # They are therefore picked only when fewer than top_k real candidates exist.
        vectors_found = np.zeros((seq_len, len(knns), self.top_k, embedding_dim), dtype=np.float32)
        distances_found = np.full((seq_len, len(knns), self.top_k), np.inf)
        for knn_index, nn in enumerate(knns):
            # kneighbors() raises if asked for more neighbours than the index holds
            # (e.g. a temporary buffer with fewer than top_k vectors).
            n_neighbors = min(self.top_k, nn.n_samples_fit_)
            distances, indices_local = nn.kneighbors(vectors_normed, n_neighbors=n_neighbors, return_distance=True)
            # Map index-local row numbers back to positions in self.vectors.
            indices = indices_local + knn_index * self.max_temporary_buffer_size
            for j in range(n_neighbors):
                vectors_found[:, knn_index, j, :] = [self.vectors[index] for index in indices[:, j]]
                distances_found[:, knn_index, j] = distances[:, j]

        candidate_count = len(knns) * self.top_k
        vectors_found = vectors_found.reshape((seq_len, candidate_count, embedding_dim))
        distances_found = distances_found.reshape((seq_len, candidate_count))

        # Keep the top_k closest candidates across all indices, per input row.
        closest = np.argsort(distances_found, axis=1, kind="stable")[:, :self.top_k]
        vectors_chosen = np.take_along_axis(vectors_found, closest[:, :, np.newaxis], axis=1)

        with torch.no_grad():
            vectors_chosen_torch = torch.tensor(vectors_chosen, dtype=inputs.dtype, device=inputs.device)
        return vectors_chosen_torch

    def _add_filtered(self, inputs: torch.FloatTensor) -> None:
        """
        Store `inputs` (already filtered by `add`) and freeze full buffer chunks into kNN indices.
        :param inputs: (Individual) sequence embedding matrix (2d array)
        """
        with torch.no_grad():
            vectors = inputs.detach().cpu().float().numpy()
        vectors_list = list(vectors)
        if vectors_list:
            # The cached brute-force index no longer matches the buffer.
            self._buffer_knn = None
        self.temporary_buffer += vectors_list
        self.vectors += vectors_list

        chunk_size = self.max_temporary_buffer_size
        knn_count = len(self.temporary_buffer) // chunk_size
        if knn_count:
            # Each frozen index covers exactly chunk_size consecutive entries of self.vectors.
            # get() relies on this to map neighbour indices back to vectors.
            for i in range(knn_count):
                self.knns.append(self._knn(self.temporary_buffer[i * chunk_size : (i + 1) * chunk_size]))
            self.temporary_buffer = self.temporary_buffer[knn_count * chunk_size :]

    def save(self, directory: str) -> None:
        """Pickle the whole memory (vectors and fitted indices) into `directory`."""
        os.makedirs(directory, exist_ok=True)
        with open(os.path.join(directory, "cosine-knn-memory.pkl"), "wb") as dst:
            pickle.dump(self, dst)

    @staticmethod
    def load(directory: str) -> BaseMemoryCollection:
        """
        Load a memory previously written by `save`.

        SECURITY: this unpickles the file, which can execute arbitrary code.
        Only load directories you trust.
        """
        with open(os.path.join(directory, "cosine-knn-memory.pkl"), "rb") as src:
            memory = pickle.load(src)
            assert isinstance(memory, CosineKnnMemoryCollection)
            return memory

## Tests

### Distance correlation

In [5]:
def _test_cosine_vectors(vec1, vec2):
    norm1 = np.sqrt((vec1 ** 2).sum())
    norm2 = np.sqrt((vec2 ** 2).sum())
    return 1.0 - (vec1 * vec2).sum() / (norm1 * norm2)

In [6]:
def _test_normed_l2_vectors(vec1, vec2):
    def _normed(vector):
        return vector / np.sqrt((vector ** 2).sum())
    
    vec1 = _normed(vec1)
    vec2 = _normed(vec2)
    return np.sqrt(((vec1 - vec2) ** 2).sum())

In [7]:
def _test_distances_corr(n_pairs: int = 10000, dim: int = 100):
    """
    Spearman correlation between cosine distance and L2 distance of unit-normalised vectors.

    Draws `n_pairs` pairs of random `dim`-dimensional vectors from the global NumPy RNG
    (seed it beforehand for reproducibility) and computes both distances for every pair.
    A result of 1.0 means both distances rank the pairs identically.
    """
    vecs1 = np.random.rand(n_pairs * dim).reshape((n_pairs, dim))
    vecs2 = np.random.rand(n_pairs * dim).reshape((n_pairs, dim))
    cosine = [_test_cosine_vectors(v1, v2) for v1, v2 in zip(vecs1, vecs2)]
    normed_l2 = [_test_normed_l2_vectors(v1, v2) for v1, v2 in zip(vecs1, vecs2)]
    distances = pd.DataFrame({"cosine": cosine, "normed_l2": normed_l2})
    correlation = distances.corr(method="spearman").loc["cosine", "normed_l2"]
    return correlation

In [8]:
# Over 20 seeds, cosine distance and L2 distance of unit-normalised vectors must rank random
# pairs (almost) identically. CosineKnnMemoryCollection relies on this by searching with L2
# on normalised vectors.
corrs = []
for i in range(20):
    np.random.seed(42 + i)
    corrs.append(_test_distances_corr())
assert min(corrs) >= 0.99
min(corrs), max(corrs)

(1.0, 1.0)

### Memory

In [9]:
# Fill the memory with random vectors, `batch_size` at a time, and check that every freshly
# added vector is retrieved as its own nearest neighbour (top_k=1).
# vectors_count > max_buffer_size, so both the temporary buffer and the frozen kNN indices are exercised.
vectors_count = 10000
vectors_dim = 100
max_buffer_size = 1000
batch_size = 2
eps = 1e-6

np.random.seed(42)
vectors = torch.FloatTensor(np.random.rand(vectors_count * vectors_dim).reshape([vectors_count, -1]))
# Local positions restart at 0 for every chunk; the memory tracks the global offset itself.
batch_indices = torch.LongTensor([i for i in range(batch_size)])

memory = CosineKnnMemoryCollection(top_k=1, max_temporary_buffer_size=max_buffer_size, remember_until_position=vectors_count)

for i in range(vectors_count // batch_size):
    batch = vectors[i * batch_size : (i + 1) * batch_size]
    # In real life we should first retrieve the similar embeddings and then add the new ones, but here we're just testing the memory mechanism itself
    memory.add(batch, batch_indices)
    batch_extracted = memory.get(batch)
    # get() returns (batch_size, top_k, dim); view the batch the same way before comparing.
    assert (batch.view(batch_extracted.shape) - batch_extracted).abs().max() < eps

In [10]:
# `oor_vectors` (presumably "out of range"): by now vectors_count tokens have been remembered.
# These positions are therefore >= remember_until_position and must NOT be stored, so their
# nearest remembered vectors should differ noticeably from them.
np.random.seed(202305)
oor_vectors = torch.FloatTensor(np.random.rand(batch_size * vectors_dim).reshape([batch_size, -1]))
memory.add(oor_vectors, batch_indices)
assert (oor_vectors - memory.get(oor_vectors).view(oor_vectors.shape)).abs().mean() >= 0.1

In [11]:
import tempfile

# Round-trip the memory through save/load. A throwaway directory keeps the test from leaving
# a "temp-memory-test" folder behind in the working directory.
with tempfile.TemporaryDirectory() as memory_dir:
    memory.save(memory_dir)
    memory_loaded = CosineKnnMemoryCollection.load(memory_dir)
# Every vector remembered before saving must still be retrieved as its own nearest neighbour.
for i in range(vectors_count // batch_size):
    batch = vectors[i * batch_size : (i + 1) * batch_size]
    batch_extracted = memory_loaded.get(batch)
    assert (batch.view(batch_extracted.shape) - batch_extracted).abs().max() < eps

In [12]:
#| hide
# Regenerate the library module from this notebook's `#| export` cells (nbdev).
import nbdev; nbdev.nbdev_export()