# Context choice

> The Memorizing Transformer has a module that, for each token, chooses between the local embeddings (a small chunk passed through regular attention) and the global embeddings (the big document). Rather than a hard switch, the choice is a per-head weighted average whose weights come from a sigmoid. This notebook implements two ways of making that choice: trainable constant biases (as in the original paper) and per-head linear classifiers.

In [1]:
#| default_exp context_choice

In [2]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union

In [3]:
#| export
# Default multiplier (``loss_k``) applied to the auxiliary local-score loss of the context choice modules.
LOCAL_LOSS_K: float = 1.0

In [4]:
#| export
def local_score_loss(local_score: torch.FloatTensor) -> torch.FloatTensor:
    """
    Auxiliary loss that pushes the local-context score towards zero.

    :param local_score: raw (pre-sigmoid) local logits of any shape
    :return: binary cross-entropy with logits against an all-zeros target of the same shape,
        reduced with PyTorch's default ``mean`` reduction
    """
    # zeros_like keeps the dtype and device of the logits, so the target always matches them.
    all_zero_targets = torch.zeros_like(local_score)
    return F.binary_cross_entropy_with_logits(local_score, all_zero_targets)

In [5]:
#| export
class BaseContextChoice(nn.Module):
    """
    Base class for every context choice method.

    Each subclass computes some kind of weighted average between the local context and the global context.
    During ``forward`` it also stores an auxiliary loss term, which can be retrieved (scaled by ``loss_k``)
    with :meth:`get_loss_component`.
    """
    def __init__(self, attention_heads: int, embedding_dim: int, loss_k: float = LOCAL_LOSS_K) -> None:
        """
        Initializer.
        :param attention_heads: how many attention heads each block of the original transformer has
        :param embedding_dim: how many embedding dimensions each block of the original transformer has;
            must be divisible by ``attention_heads``
        :param loss_k: multiplier applied to the auxiliary loss returned by ``get_loss_component``
        """
        assert attention_heads > 0
        assert embedding_dim > 0
        assert embedding_dim % attention_heads == 0
        super().__init__()
        self.attention_heads = attention_heads
        self.embedding_dim = embedding_dim
        self.head_dim = embedding_dim // attention_heads
        # Auxiliary loss produced by the latest forward pass; 0 means "nothing pending".
        self._loss_component = 0
        self.loss_k = loss_k

    def get_loss_component(self) -> Union[torch.FloatTensor, float]:
        """
        Return the pending auxiliary loss multiplied by ``loss_k`` and clear it.

        The pending value is reset to 0 (the same "empty" value used in ``__init__``). Calling this again
        before the next forward pass therefore returns 0 instead of failing on ``None * loss_k``.
        """
        result = self._loss_component * self.loss_k
        self._loss_component = 0
        return result

    def forward(self, embeddings_local: torch.FloatTensor, embeddings_global: torch.FloatTensor) -> torch.FloatTensor:
        """
        Apply the weighted average between embeddings_local and embeddings_global.
        Subclasses must override this and set ``self._loss_component``.
        """
        raise NotImplementedError("Each BaseContextChoice subclass must define their own forward method")

In [6]:
#| export
class ContextChoiceLinear(BaseContextChoice):
    """
    Context choice with a trainable per-head linear classifier.

    For each token and each attention head, a linear function of that head's slice of the local
    embedding gives a logit. ``sigmoid(logit)`` weights the local embedding and ``1 - sigmoid(logit)``
    weights the global embedding.
    """
    def __init__(self, attention_heads: int, embedding_dim: int, loss_k: float = LOCAL_LOSS_K) -> None:
        super().__init__(attention_heads, embedding_dim, loss_k)
        # Per-head classifier parameters: head h maps head_dim features to a single logit.
        self.weights = nn.Parameter(torch.randn((self.attention_heads, self.head_dim, 1)))
        self.biases = nn.Parameter(torch.randn((self.attention_heads,)))

    def forward(self, embeddings_local: torch.FloatTensor, embeddings_global: torch.FloatTensor) -> torch.FloatTensor:
        """
        Blend local and global embeddings head by head.

        Also stores the local-score loss of this call's logits for ``get_loss_component``.

        :param embeddings_local: batch_size x sequence_length x embedding_dim
        :param embeddings_global: same shape as ``embeddings_local``
        :return: batch_size x sequence_length x embedding_dim
        """
        batch_size, sequence_length, _ = embeddings_local.shape
        # batch_size x sequence_length x attention_heads x head_dim
        # reshape (unlike view) also accepts non-contiguous inputs such as transposed tensors;
        # for contiguous inputs it returns the same view as before.
        embeddings_local = embeddings_local.reshape((batch_size, sequence_length, self.attention_heads, -1))
        embeddings_global = embeddings_global.reshape((batch_size, sequence_length, self.attention_heads, -1))
        # All per-head linear layers as one einsum; result is batch_size x sequence_length x attention_heads x 1
        # b - batch size
        # s - sequence length
        # h - attention heads
        # d - head dim
        # a - 1 (one output logit per head)
        local_logits = (
            torch.einsum("bshd,hda->bsha", embeddings_local, self.weights)
            + self.biases.view((1, 1, self.attention_heads, 1))
        )
        local_score = torch.sigmoid(local_logits)
        global_score = 1 - local_score
        # batch_size x sequence_length x attention_heads x head_dim (scores broadcast over head_dim)
        embeddings_local_scaled = embeddings_local * local_score
        embeddings_global_scaled = embeddings_global * global_score
        embeddings_result = embeddings_local_scaled + embeddings_global_scaled

        self._loss_component = local_score_loss(local_logits)

        # batch_size x sequence_length x attention_heads * head_dim
        # reshape: the elementwise result may inherit a non-contiguous layout from non-contiguous inputs.
        return embeddings_result.reshape((batch_size, sequence_length, self.attention_heads * self.head_dim))

In [7]:
#| export
class ContextChoiceConstant(BaseContextChoice):
    """
    Context choice with one trainable bias per attention head (the method from the original paper).

    For every token, the local weight of head h is ``sigmoid(bias[h])`` and the global weight is
    ``1 - sigmoid(bias[h])``.
    """
    def __init__(self, attention_heads: int, embedding_dim: int, loss_k: float = LOCAL_LOSS_K) -> None:
        super().__init__(attention_heads, embedding_dim, loss_k)
        # One logit per head; `(n,)` is an explicit 1-tuple shape (a bare `(n)` is just the int n).
        self.bias = nn.Parameter(torch.randn((self.attention_heads,)))

    def forward(self, embeddings_local: torch.FloatTensor, embeddings_global: torch.FloatTensor) -> torch.FloatTensor:
        """
        Blend local and global embeddings with a fixed (learned) weight per head.

        Also stores the local-score loss of the bias logits for ``get_loss_component``.

        :param embeddings_local: batch_size x sequence_length x embedding_dim
        :param embeddings_global: same shape as ``embeddings_local``
        :return: batch_size x sequence_length x embedding_dim
        """
        batch_size, sequence_length, _ = embeddings_local.shape
        # batch_size x sequence_length x attention_heads x head_dim
        # reshape (unlike view) also accepts non-contiguous inputs; for contiguous ones it is a view.
        embeddings_local = embeddings_local.reshape((batch_size, sequence_length, self.attention_heads, -1))
        embeddings_global = embeddings_global.reshape((batch_size, sequence_length, self.attention_heads, -1))
        # 1 x 1 x attention_heads x 1, broadcast over batch, sequence and head_dim
        logits_local = self.bias.view((1, 1, self.attention_heads, 1))
        scores_local = torch.sigmoid(logits_local)
        scores_global = 1 - scores_local
        # batch_size x sequence_length x attention_heads x head_dim
        embeddings_local_scaled = embeddings_local * scores_local
        embeddings_global_scaled = embeddings_global * scores_global
        embeddings_result = embeddings_local_scaled + embeddings_global_scaled

        self._loss_component = local_score_loss(logits_local)

        # batch_size x sequence_length x attention_heads * head_dim
        # reshape: the elementwise result may inherit a non-contiguous layout from non-contiguous inputs.
        return embeddings_result.reshape((batch_size, sequence_length, self.attention_heads * self.head_dim))

## Testing

### ContextChoiceLinear

Since the `"bshd,hda->bsha"` einsum expression was generated with the help of GPT-4 (after I explained what I was going to do and gave it a for-loop + `nn.Linear` example that does what I need), let's check it against that naive implementation.

In [8]:
def _test_initialize_einsum_classifier():
    """Create the einsum-based ContextChoiceLinear under test (16 heads x 8 dims) in eval mode."""
    heads, dims_per_head = 16, 8
    einsum_classifier = ContextChoiceLinear(
        attention_heads=heads,
        embedding_dim=heads * dims_per_head,
    )
    # nn.Module.eval() returns the module itself.
    return einsum_classifier.eval()

In [9]:
def _test_initialize_linear_classifiers(classifier):
    # Initializing equal(?) linear classifiers
    linears = []
    for i in range(classifier.attention_heads):
        layer = nn.Linear(in_features=classifier.head_dim, out_features=1)
        layer.load_state_dict({
            "weight": classifier.weights[i, :, :].transpose(0, 1),
            "bias": classifier.biases[[i]],
        })
        layer.eval()
        linears.append(layer)
    return linears    

In [10]:
def _test_apply_naive_linears(embeddings_local, embeddings_global, classifier, linears):
    # apply "naive" version
    embeddings_local = embeddings_local.view((
        embeddings_local.shape[0],
        embeddings_local.shape[1],
        classifier.attention_heads,
        classifier.head_dim
    ))
    embeddings_global = embeddings_global.view((
        embeddings_global.shape[0],
        embeddings_global.shape[1],
        classifier.attention_heads,
        classifier.head_dim
    ))
    embeddings_scored_all = []
    for i, layer in enumerate(linears):
        local_score = F.sigmoid(layer(embeddings_local[:, :, i, :]))
        global_score = 1 - local_score
        embeddings_scored = embeddings_local[:, :, i, :] * local_score + embeddings_global[:, :, i, :] * global_score
        embeddings_scored = embeddings_scored.view((
            embeddings_scored.shape[0],
            embeddings_scored.shape[1],
            1,
            -1,
        ))
        embeddings_scored_all.append(embeddings_scored)
    prediction_naive = torch.cat(embeddings_scored_all, dim=2)
    prediction_naive = prediction_naive.view((
        prediction_naive.shape[0],
        prediction_naive.shape[1],
        -1
    ))
    return prediction_naive

In [11]:
def _test_einsum_classifier():
    """
    Compare ContextChoiceLinear's einsum forward pass with the per-head nn.Linear reference.

    Builds both with identical parameters and feeds them the same random 2 x 64 x embedding_dim inputs.
    Asserts that the largest absolute difference is below 1e-6 and returns that difference.
    """
    classifier = _test_initialize_einsum_classifier()
    reference_linears = _test_initialize_linear_classifiers(classifier)
    input_shape = (2, 64, classifier.attention_heads * classifier.head_dim)
    local_inputs = torch.randn(input_shape)
    global_inputs = torch.randn(input_shape)
    with torch.no_grad():
        einsum_output = classifier(local_inputs, global_inputs)
        naive_output = _test_apply_naive_linears(local_inputs, global_inputs, classifier, reference_linears)
    largest_gap = torch.abs(naive_output - einsum_output).max()
    assert largest_gap < 1e-6
    return largest_gap

In [12]:
diffs = []
for seed in range(42, 42 + 100):
    # Reseed before every run so each comparison draws different, but reproducible, parameters and inputs.
    torch.manual_seed(seed)
    diffs.append(_test_einsum_classifier())
max(diffs)

tensor(9.5367e-07)

Okay, that looks like a success: in every one of the 100 runs, the maximum absolute difference between the "naive" method and the GPT-4-assisted einsum method was below 1e-6.

In [13]:
#| hide
# Run nbdev's export step for the `#| export` cells of this notebook.
import nbdev
nbdev.nbdev_export()