# Retrieval-Enhanced Transformer (Retro)
# https://nn.labml.ai/transformers/retro/index.html


This is a PyTorch implementation of the paper *Improving language models by retrieving from trillions of tokens*.

It builds a database of chunks of text. It is a key-value database where the keys are indexed by the BERT embeddings of the chunks. We use a frozen pre-trained BERT model to calculate these embeddings. The values are the corresponding chunks together with an equal length of text following each chunk.

The model then retrieves text from this database that is similar to its input (the nearest neighbors), and uses these retrieved texts to predict the output.

Since we use a frozen BERT model for retrieval we can pre-calculate all the nearest neighbors for the training dataset. This speeds up the training process.

In [1]:
%%capture
! pip install transformers
! pip install labml-nn

In [None]:
# %%capture
! apt install libomp-dev
! pip install faiss-cpu
# ! pip install faiss-gpu

# BERT Embeddings of chunks of text
# https://nn.labml.ai/transformers/retro/bert_embeddings.html

In [2]:
"""
---
title: BERT Embeddings of chunks of text
summary: >
  Generate BERT embeddings for chunks using a frozen BERT model
---

# BERT Embeddings of chunks of text

This is the code to get BERT embeddings of chunks for [RETRO model](index.html).
"""

from typing import List

import torch
from transformers import BertTokenizer, BertModel

from labml import lab, monit


class BERTChunkEmbeddings:
    r"""
    ## BERT Embeddings

    For a given chunk of text $N$ this class generates BERT embeddings $\text{B\small{ERT}}(N)$.
    $\text{B\small{ERT}}(N)$ is the average of BERT embeddings of all the tokens in $N$.
    """

    def __init__(self, device: torch.device):
        """
        * `device` is the device the BERT model is moved to; tokenized inputs are moved there as well
        """
        self.device = device

        # Load the BERT tokenizer from [HuggingFace](https://huggingface.co/bert-base-uncased)
        with monit.section('Load BERT tokenizer'):
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                                           cache_dir=str(
                                                               lab.get_data_path() / 'cache' / 'bert-tokenizer'))

        # Load the BERT model from [HuggingFace](https://huggingface.co/bert-base-uncased)
        with monit.section('Load BERT model'):
            self.model = BertModel.from_pretrained("bert-base-uncased",
                                                   cache_dir=str(lab.get_data_path() / 'cache' / 'bert-model'))

            # Move the model to `device`
            self.model.to(device)

    @staticmethod
    def _trim_chunk(chunk: str):
        """
        In this implementation, we do not make chunks with a fixed number of tokens.
        One of the reasons is that this implementation uses character-level tokens and BERT
        uses its sub-word tokenizer.

        So this method drops the first and last whitespace-separated pieces of the chunk,
        so that there are no partial words at either end.

        For instance, a chunk could be like `s a popular programming la`, with partial
        words (partial sub-word tokens) on the ends.
        We strip them off to get better BERT embeddings.
        As mentioned earlier this is not necessary if we broke chunks after tokenizing.

        If the chunk is empty, whitespace-only, or would become empty after trimming,
        the original chunk is returned unchanged.
        """
        # Break into words; `split()` ignores leading/trailing whitespace
        parts = chunk.split()
        # Nothing to trim for empty or whitespace-only chunks (`parts[0]` would raise)
        if not parts:
            return chunk

        # Remove the first and last pieces, then surrounding whitespace
        stripped = chunk.strip()
        stripped = stripped[len(parts[0]):-len(parts[-1])].strip()

        # Fall back to the original chunk when nothing is left
        return stripped if stripped else chunk

    def __call__(self, chunks: List[str]):
        r"""
        ### Get $\text{B\small{ERT}}(N)$ for a list of chunks.

        Returns one averaged embedding row per chunk.
        A chunk that produces no tokens gets a zero vector instead of NaNs.
        """

        # We don't need to compute gradients
        with torch.no_grad():
            # Trim the chunks
            trimmed_chunks = [self._trim_chunk(c) for c in chunks]

            # Tokenize the chunks with BERT tokenizer
            tokens = self.tokenizer(trimmed_chunks, return_tensors='pt', add_special_tokens=False, padding=True)

            # Move token ids, attention mask and token types to the device
            input_ids = tokens['input_ids'].to(self.device)
            attention_mask = tokens['attention_mask'].to(self.device)
            token_type_ids = tokens['token_type_ids'].to(self.device)
            # Evaluate the model
            output = self.model(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)

            # Get the token embeddings
            state = output['last_hidden_state']
            # Calculate the average token embeddings.
            # The attention mask is `0` for padding positions, which exist because chunks
            # tokenize to different lengths.
            mask = attention_mask[:, :, None]
            summed = (state * mask).sum(dim=1)
            # Clamp the token count so a chunk with no tokens yields zeros rather than 0/0 = NaN
            counts = mask.sum(dim=1).clamp(min=1)

            return summed / counts


def _test():
    """
    ### Code to test BERT embeddings

    Runs two sample sentences through the BERT tokenizer, the raw BERT model and
    `BERTChunkEmbeddings`, and logs the intermediate results with `inspect`.
    """
    from labml.logger import inspect

    # Initialize; use the first GPU when available, otherwise fall back to the CPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    bert = BERTChunkEmbeddings(device)

    # Sample
    text = ["Replace me by any text you'd like.",
            "Second sentence"]

    # Check BERT tokenizer
    encoded_input = bert.tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)

    inspect(encoded_input, _expand=True)

    # Check BERT model outputs; this is inspection only, so no gradients are tracked
    with torch.no_grad():
        output = bert.model(input_ids=encoded_input['input_ids'].to(device),
                            attention_mask=encoded_input['attention_mask'].to(device),
                            token_type_ids=encoded_input['token_type_ids'].to(device))

    inspect({'last_hidden_state': output['last_hidden_state'],
             'pooler_output': output['pooler_output']},
            _expand=True)

    # Check recreating text from token ids
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0]), _n=-1)
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][1]), _n=-1)

    # Get chunk embeddings
    inspect(bert(text))


# Run the BERT embedding sanity checks when executed as a script or notebook cell
if __name__ == '__main__':
    _test()

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Database for nearest neighbor retrieval
# https://nn.labml.ai/transformers/retro/database.html

In [8]:
"""
---
title: Database for nearest neighbor retrieval
summary: >
  Nearest neighbor retrieval and creation of the database
---

# Database for nearest neighbor retrieval

This is the code to build the database and retrieve nearest neighbors for
 [RETRO model](index.html).

We use [FAISS library](https://faiss.ai/) for the database whilst the paper had used the ScaNN library.
"""

from typing import List, Optional

import faiss
import numpy as np
import torch

from labml import lab, monit
from labml_helpers.datasets.text import TextFileDataset
from labml_nn.transformers.retro.bert_embeddings import BERTChunkEmbeddings


def build_database(chunk_len: int = 16, batch_size: int = 64, d_emb: int = 768, n_centeroids: int = 256,
                   code_size: int = 64, n_probe: int = 8, n_train: int = 50_000):
    r"""
    ## Build Database

    * `chunk_len` is the length of a chunk (number of characters)
    * `batch_size` is the batch size to use when calculating $\text{B\small{ERT}}(N)$
    * `d_emb` is the number of features in $\text{B\small{ERT}}(N)$ embeddings
    * `n_centeroids` is the number of lists in the index
        ([lists to select in FAISS index](https://faiss.ai/cpp_api/struct/structfaiss_1_1IndexIVFPQ.html))
    * `code_size` encoded vector size in the index
    * `n_probe` is the number of lists to probe
    * `n_train` is the number of keys to train the index on

    The index is written to `lab.get_data_path() / 'retro.index'`; the id of every
    key is the character offset of its chunk in the training text.
    """

    # Load the dataset text file
    dataset = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

    # Get training data (a string)
    text = dataset.train

    # Offsets of chunks of `chunk_len` characters. A chunk is kept only when the chunk
    # plus an equal-length continuation fits strictly inside the text
    # (i.e. `i + 2 * chunk_len < len(text)`).
    offsets = list(range(0, len(text) - 2 * chunk_len, chunk_len))
    chunks = [text[i:i + chunk_len] for i in offsets]
    # FAISS `add_with_ids` expects 64-bit ids; the default integer dtype is platform-dependent
    chunk_offsets = np.array(offsets, dtype=np.int64)
    # Number of chunks
    n_chunks = len(chunks)
    if n_chunks == 0:
        raise ValueError(f'Training text is too short to build chunks of length {chunk_len}')

    # Initialize BERT to get $\text{B\small{ERT}}(N)$; use the CPU when CUDA is unavailable
    bert = BERTChunkEmbeddings(torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))

    # Get chunk embeddings by processing `batch_size` number of chunks on each iteration
    chunk_emb = []
    for i in monit.iterate('Get embeddings', range(0, n_chunks, batch_size)):
        chunk_emb.append(bert(chunks[i: i + batch_size]).cpu())
    # Merge them into a single tensor
    chunk_emb = torch.cat(chunk_emb, dim=0).numpy()

    # Create the [FAISS index](https://faiss.ai/cpp_api/struct/structfaiss_1_1IndexIVFPQ.html)
    quantizer = faiss.IndexFlatL2(d_emb)
    index = faiss.IndexIVFPQ(quantizer, d_emb, n_centeroids, code_size, 8)
    index.nprobe = n_probe

    # Get a random sample of the chunk indexes
    random_sample = np.random.choice(np.arange(n_chunks), size=[min(n_train, n_chunks)], replace=False)

    # Train the index to store the keys
    with monit.section('Train index'):
        index.train(chunk_emb[random_sample])

    # Add the chunks to the index in batches of size `1024`
    for s in monit.iterate('Index', range(0, n_chunks, 1024)):
        e = min(s + 1024, n_chunks)
        # Add to index, using the chunk offsets as ids
        index.add_with_ids(chunk_emb[s:e], chunk_offsets[s:e])

    # Save the index
    with monit.section('Save'):
        faiss.write_index(index, str(lab.get_data_path() / 'retro.index'))


class RetroIndex:
    """
    ## Index for retrieving nearest neighbors

    Loads the FAISS index written by `build_database` and returns, for each query
    chunk, the character offsets of its nearest-neighbor chunks.
    """

    def __init__(self, chunk_len: int = 16, n_probe: int = 8,
                 n_neighbors: int = 2, n_extra: int = 2,
                 exclude_neighbor_span: int = 8):
        """
        * `chunk_len` is the chunk length
        * `n_probe` is the number of lists to probe
        * `n_neighbors` is the number of neighbors to retrieve
        * `n_extra` is the number of extra neighbors to retrieve since we will be
            removing neighbors overlapping with the query chunk
        * `exclude_neighbor_span` is the extra text length to avoid when checking for overlaps
        """

        self.n_neighbors = n_neighbors
        self.chunk_len = chunk_len
        self.exclude_neighbor_span = exclude_neighbor_span
        self.n_extra = n_extra

        # Initialize BERT to get $\text{B\small{ERT}}(N)$; use the CPU when CUDA is unavailable
        self.bert = BERTChunkEmbeddings(torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))
        # Load the database
        with monit.section('Load index'):
            self.index = faiss.read_index(str(lab.get_data_path() / 'retro.index'))
            self.index.nprobe = n_probe

    def filter_neighbors(self, offset: int, neighbor_offsets: List[int]):
        """
        #### Filter neighbors that overlap with the query

        The positions of the neighbors are given by `neighbor_offsets` and the position
        of the query chunk is `offset`. Neighbors within
        `chunk_len + exclude_neighbor_span` characters of the query are dropped.
        Negative ids are also dropped: FAISS returns `-1` when it finds fewer results
        than requested.
        """
        span = self.chunk_len + self.exclude_neighbor_span
        return [n for n in neighbor_offsets
                if n >= 0 and (n < offset - span or n > offset + span)]

    def __call__(self, query_chunks: List[str], offsets: Optional[List[int]] = None):
        """
        ### Retrieve nearest neighbors

        * `query_chunks` are the text chunks to look up
        * `offsets` are the positions of the query chunks in the indexed text; when given,
          neighbors overlapping with the query are removed

        Returns a list (one entry per query) of at most `n_neighbors` neighbor offsets (ints).
        """

        # Get $\text{B\small{ERT}}(N)$ of query chunks
        emb = self.bert(query_chunks).cpu()

        # Get `n_neighbors + n_extra` nearest neighbors from the database
        _, neighbor_offsets = self.index.search(emb.numpy(), self.n_neighbors + self.n_extra)

        if offsets is not None:
            # Filter out overlapping chunks (and missing `-1` results)
            neighbor_offsets = [self.filter_neighbors(off, n_off)
                                for off, n_off in zip(offsets, neighbor_offsets)]
        else:
            # Only drop the `-1` placeholders FAISS uses for missing results
            neighbor_offsets = [[n for n in n_off if n >= 0] for n_off in neighbor_offsets]

        # Get the closest `n_neighbors` after filtering
        return [[int(n) for n in n_off[:self.n_neighbors]] for n_off in neighbor_offsets]


# Build and save the FAISS database when executed as a script or notebook cell
if __name__ == '__main__':
    build_database()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# RETRO model
# https://nn.labml.ai/transformers/retro/model.html

In [9]:
"""
---
title: RETRO model
summary: >
  RETRO model with encoder for neighbors and autoregressive decoder
---

# RETRO model

This is the model definition for
 [RETRO](index.html).
"""

import math
from typing import Set

import torch
from torch import nn

from labml.logger import inspect


class RotaryPositionalEmbeddings(nn.Module):
    r"""
    ## [RoPE embeddings](../rope/index.html)

    *Rotary position embeddings are applied to the queries and keys of the self-attention
    layers. We assume the positional information gets embedded in the embeddings
    and therefore do not use them in the cross-attention layers.
    [Non-causal self-attention needs explicit positional information
     because it cannot infer it](https://papers.labml.ai/paper/3999902edc8511eba3db37f65e372566).*
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()
        # $\Theta = \{\theta_i = base^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]\}$,
        # stored as a non-trainable parameter
        exponents = torch.arange(0, d, 2).float() / d
        self.theta = nn.Parameter(1. / (base ** exponents), requires_grad=False)

    def forward(self, x: torch.Tensor):
        r"""
        * `x` is the Tensor at the head of a key or a query with shape `[batch_size, seq_len, n_heads, d]`

        Rotates feature pairs $(x^{(i)}, x^{(i + \frac{d}{2})})$ at position $m$ by angle $m \theta_i$.
        """
        # Unpacking also enforces a 4-dimensional input
        _, seq_len, _, d = x.shape
        half = d // 2

        # Position indexes `[0, 1, ..., seq_len - 1]` in the dtype of $\Theta$
        positions = torch.arange(seq_len, device=x.device).type_as(self.theta)

        # angles[m, i] = m * theta_i, duplicated so row m is
        # $[m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        angles = torch.outer(positions, self.theta)
        angles = torch.cat((angles, angles), dim=1)

        # Broadcast over the batch and head dimensions
        cos = angles.cos()[None, :, None, :]
        sin = angles.sin()[None, :, None, :]

        # $[-x^{(\frac{d}{2} + 1)}, ..., -x^{(d)}, x^{(1)}, ..., x^{(\frac{d}{2})}]$
        rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)

        # $x^{(i)}_m \cos m \theta_i - x^{(i + \frac{d}{2})}_m \sin m \theta_i$ for the first half and
        # $x^{(i + \frac{d}{2})}_m \cos m\theta_i + x^{(i)}_m \sin m \theta_i$ for the second half
        return (x * cos) + (rotated * sin)


class SelfAttention(nn.Module):
    r"""
    ## Self-Attention Layer $\text{A\small{TTN}}$

    Pre-norm [multi-headed self-attention](../mha.html) with rotary positional
    embeddings on queries and keys, followed by a residual connection.
    When `is_causal` is set, each position can only attend to itself and earlier positions.
    """

    def __init__(self, d_model: int, n_heads: int, d_k: int, is_causal: bool):
        """
        * `d_model` is the number of features in transformer embeddings
        * `n_heads` is the number of attention heads
        * `d_k` is the number of features per head
        * `is_causal` indicates whether this is causal attention (masked)
        """
        super().__init__()

        self.is_causal = is_causal
        self.n_heads = n_heads
        self.d_k = d_k

        # Attention logits are multiplied by $\frac{1}{\sqrt{d_k}}$ before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # Query, key and value projections (all heads at once)
        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)

        # Pre-norm layer; the paper uses RMSNorm here instead
        self.norm = nn.LayerNorm(d_model)

        # Normalizes attention scores over the key dimension
        self.softmax = nn.Softmax(dim=-1)

        # Rotary positional embeddings applied per head
        self.rotary_pe = RotaryPositionalEmbeddings(self.d_k)

        # Projects the concatenated heads back to `d_model`
        self.output = nn.Linear(n_heads * d_k, d_model)

    def mask_attention(self, attn: torch.Tensor):
        """
        ### Mask the attention layer for causal attention

        * `attn` is the attention matrix of shape `[batch_size, n_heads, seq_len, seq_len]`

        Returns `attn` untouched for non-causal attention; otherwise entries above the
        diagonal (future keys) are set to `-inf`.
        """
        if self.is_causal:
            n_queries, n_keys = attn.shape[-2:]
            allowed = torch.tril(attn.new_ones(n_queries, n_keys))
            attn = attn.masked_fill(allowed == 0, float('-inf'))
        return attn

    def forward(self, h: torch.Tensor):
        """
        * `h` is the transformer embeddings of shape `[batch_size, seq_len, d_model]`

        Returns a tensor of the same shape (attention output plus residual).
        """
        residual = h
        x = self.norm(h)

        # Split projections into heads: `[batch_size, seq_len, n_heads, d_k]`
        heads_shape = (*x.shape[:-1], self.n_heads, self.d_k)
        q = self.rotary_pe(self.query(x).view(heads_shape))
        k = self.rotary_pe(self.key(x).view(heads_shape))
        v = self.value(x).view(heads_shape)

        # Scaled (and possibly masked) scores: `[batch_size, n_heads, seq_len, seq_len]`
        scores = self.mask_attention(torch.einsum('bihd,bjhd->bhij', q, k) * self.scale)
        weights = self.softmax(scores)

        # Weighted values back in `[batch_size, seq_len, n_heads, d_k]`, then merge the heads
        out = torch.einsum("bhij,bjhd->bihd", weights, v)
        out = out.reshape(*out.shape[:-2], -1)

        return self.output(out) + residual


class CrossAttention(nn.Module):
    r"""
    ## Cross-Attention Layer $\text{C\small{A}}$

    Like the self-attention layer above, except that queries come from the retrieved
    neighbor chunks while keys and values come from the input chunk that retrieved them.

    This is used in the encoder to encode the retrieved chunks based on the
    input chunks.

    *We do not use any explicit positional embeddings here.
    We assume that the model can represent positional information in the embeddings implicitly.*
    """

    def __init__(self, d_model: int, n_heads: int, d_k: int):
        """
        * `d_model` is the number of features in transformer embeddings
        * `n_heads` is the number of attention heads
        * `d_k` is the number of features per head
        """
        super().__init__()

        self.n_heads = n_heads
        self.d_k = d_k

        # Attention logits are multiplied by $\frac{1}{\sqrt{d_k}}$ before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # Query, key and value projections (all heads at once)
        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)

        # Pre-norm for the neighbor (query-side) embeddings; the paper uses RMSNorm instead
        self.norm = nn.LayerNorm(d_model)

        # Normalizes attention scores over the input-chunk positions
        self.softmax = nn.Softmax(dim=-1)

        # Projects the concatenated heads back to `d_model`
        self.output = nn.Linear(n_heads * d_k, d_model)

    def forward(self, e: torch.Tensor, h: torch.Tensor):
        """
        * `e` are the retrieved nearest neighbor chunk embeddings with shape
          `[batch_size, chunks, neighbors, neighbor_len, d_model]`
        * `h` are the input chunks from which the nearest neighbors were retrieved with shape
          `[batch_size, chunks, chunk_len, d_model]`. This is expected to be normalized already.

        Returns a tensor shaped like `e` (attention output plus residual).
        """
        # Queries from the normalized neighbors: `[..., neighbor_len, n_heads, d_k]`
        q = self.query(self.norm(e)).view(*e.shape[:-1], self.n_heads, self.d_k)
        # Keys and values from the input chunks: `[..., chunk_len, n_heads, d_k]`
        k = self.key(h).view(*h.shape[:-1], self.n_heads, self.d_k)
        v = self.value(h).view(*h.shape[:-1], self.n_heads, self.d_k)

        # Each neighbor attends only to the chunk that retrieved it:
        # `[batch_size, chunks, neighbors, n_heads, neighbor_len, chunk_len]`
        weights = self.softmax(torch.einsum('bcnihd,bcjhd->bcnhij', q, k) * self.scale)

        # Gather values and merge heads into `[..., neighbor_len, n_heads * d_k]`
        out = torch.einsum("bcnhij,bcjhd->bcnihd", weights, v)
        out = out.reshape(*out.shape[:-2], -1)

        # Project back to `d_model` and add the (un-normalized) neighbor embeddings
        return self.output(out) + e


class ChunkedCrossAttention(nn.Module):
    r"""
    ## Chunked Cross-Attention Layer $\text{C\small{CA}}$

    Similar to the cross-attention layer above, but used in the decoder: each input
    chunk attends to the encoded neighbors retrieved for it.

    *We do not use any explicit positional embeddings here.
    We assume that the model can represent positional information in the embeddings implicitly.*
    """

    def __init__(self, d_model: int, n_heads: int, d_k: int, chunk_len: int):
        """
        * `d_model` is the number of features in transformer embeddings
        * `n_heads` is the number of attention heads
        * `d_k` is the number of features per head
        * `chunk_len` is the length of a chunk
        """

        super().__init__()

        self.chunk_len = chunk_len
        self.n_heads = n_heads
        self.d_k = d_k

        # Attention logits are multiplied by $\frac{1}{\sqrt{d_k}}$ before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # Query, key and value projections (all heads at once)
        self.query = nn.Linear(d_model, n_heads * d_k)
        self.key = nn.Linear(d_model, n_heads * d_k)
        self.value = nn.Linear(d_model, n_heads * d_k)

        # Pre-norm for the input (query-side) embeddings; the paper uses RMSNorm instead
        self.norm = nn.LayerNorm(d_model)

        # Softmax over the last dimension (applied to flattened `neighbors * neighbor_len`)
        self.softmax = nn.Softmax(dim=-1)

        # Projects the concatenated heads back to `d_model`
        self.output = nn.Linear(n_heads * d_k, d_model)

    def forward(self, h: torch.Tensor, e: torch.Tensor):
        """
        `h` are the input embeddings of shape `[batch_size, seq_len, d_model]`
        `e` are the retrieved nearest neighbors of shape `[batch_size, chunks, neighbors, neighbor_len, d_model]`

        Returns a tensor shaped like `h`. When `e` has zero chunks, `h` itself is returned.
        """
        batch_size, chunks, _, _, d_model = e.shape

        # Short inputs (e.g. while sampling) may have no retrieved chunks at all
        if chunks == 0:
            return h

        shift = self.chunk_len - 1
        span = chunks * self.chunk_len

        # Drop the first `chunk_len - 1` positions so that position `t` only sees neighbors
        # retrieved from chunks that end at or before `t` — neighbors of a chunk are
        # encoded with that chunk's tokens, so this prevents information leaking from the future.
        x = self.norm(h[:, shift:])
        # Zero-pad on the right so the sequence splits evenly into `chunks` chunks
        missing = span - x.shape[1]
        if missing > 0:
            x = torch.cat((x, x.new_zeros(batch_size, missing, d_model)), dim=1)
        x = x.reshape(batch_size, chunks, self.chunk_len, d_model)

        # Queries from the input chunks, keys and values from the neighbors
        q = self.query(x).view(*x.shape[:-1], self.n_heads, self.d_k)
        k = self.key(e).view(*e.shape[:-1], self.n_heads, self.d_k)
        v = self.value(e).view(*e.shape[:-1], self.n_heads, self.d_k)

        # Scores `[batch_size, chunks, heads, chunk_len, neighbors, neighbor_len]`,
        # normalized jointly over all neighbor tokens of the chunk
        scores = torch.einsum('bcihd,bcnjhd->bchinj', q, k) * self.scale
        weights = self.softmax(scores.flatten(-2)).view(scores.shape)

        # Gather values and merge heads: `[batch_size, chunks * chunk_len, n_heads * d_k]`
        out = torch.einsum("bchinj,bcnjhd->bcihd", weights, v)
        out = self.output(out.reshape(batch_size, span, -1))

        # Shift back to the right with `chunk_len - 1` zero embeddings, trim to the
        # input length and add the residual
        out = torch.cat((out.new_zeros(batch_size, shift, d_model), out), dim=1)
        return out[:, :h.shape[1]] + h


class FeedForward(nn.Module):
    r"""
    ### Position-wise Feed Forward Layer $\text{F\small{FW}}$

    Pre-norm two-layer MLP with a ReLU in between and a residual connection.
    """

    def __init__(self, d_model: int, d_ff: int):
        """
        * `d_model` is the number of features in transformer embeddings
        * `d_ff` is the number features in the hidden layer
        """

        super().__init__()

        # Expansion and contraction layers
        self.lin1 = nn.Linear(d_model, d_ff)
        self.lin2 = nn.Linear(d_ff, d_model)

        # Hidden-layer non-linearity
        self.act = nn.ReLU()

        # Normalization applied before the MLP
        self.norm = nn.LayerNorm(d_model)

    def forward(self, h: torch.Tensor):
        """
        `h` are the embeddings of shape `[batch_size, seq_len, d_model]`

        Returns `lin2(relu(lin1(norm(h)))) + h`.
        """
        hidden = self.act(self.lin1(self.norm(h)))
        return self.lin2(hidden) + h


class NearestNeighborEncoder(nn.Module):
    r"""
    ## Nearest Neighbor Encoder $\text{E\small{NCODER}}(\text{R\small{ET}}(C_u)_{1 \le u \le l}, H)$

    Encodes the retrieved nearest neighbors with bidirectional self-attention,
    cross-attention to the input chunk that retrieved them (on selected layers),
    and feed-forward layers.
    """

    def __init__(self, chunk_len: int, n_layers: int, ca_layers: Set[int],
                 d_model: int, n_heads: int, d_k: int, d_ff: int):
        r"""
        * `chunk_len` is the length of a chunk
        * `n_layers` is the number of layers in the encoder $L_{\text{enc}}$
        * `ca_layers` are the layers with cross attention $P_{\text{enc}}$
        * `d_model` is the number of features in embeddings
        * `n_heads` is the number of heads in attention layers
        * `d_k` is the size of attention heads
        * `d_ff` is the size of the feed-forward networks hidden layers
        """

        super().__init__()
        self.ca_layers = ca_layers
        self.chunk_len = chunk_len
        # One cross-attention module per layer index listed in `ca_layers`
        self.ca = nn.ModuleList([CrossAttention(d_model, n_heads, d_k) for _ in range(len(ca_layers))])
        # Bi-directional (non-causal) self-attention layers
        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=False) for _ in range(n_layers)])
        # Feed-forward layers
        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])

        # Normalization for the input embeddings $H$ before they are used as keys/values
        self.norm_h = nn.LayerNorm(d_model)

    def forward(self, e: torch.Tensor, h: torch.Tensor):
        r"""
        * `e` are token embeddings of the retrieved nearest neighbors,
         $\text{E\small{MB}}\big(\text{R\small{ET}}(C_u)_{1 \le u \le l}\big)$
         of shape `[batch_size, chunks, neighbors, neighbor_len, d_model]`

        * `h` are the input token embeddings, $H$
         of shape `[batch_size, seq_len, d_model]`

        *The chunks $u \in [1, l]$ and neighbors $j \in [1, k]$ are processed in parallel.*

        Returns the encoded neighbors $E$ with the same shape as `e`.
        """
        batch_size, chunks, neighbors, neighbor_len, d_model = e.shape

        # $(H_u)_{u \in [1, l]} \leftarrow \text{S\small{PLIT}}(H)$, then pre-norm
        first_tokens = h[:, :self.chunk_len * chunks, :]
        h_split = self.norm_h(first_tokens.reshape(batch_size, chunks, self.chunk_len, d_model))

        # Cross-attention modules are consumed in order, one per layer in `ca_layers`
        next_ca = iter(self.ca)
        for p, (attn, ffw) in enumerate(zip(self.attn, self.ffw)):
            # $E^j_u \leftarrow \text{A\small{TTN}}_{\text{enc}}(E^j_u)$, with every neighbor as its own sequence
            e = attn(e.view(-1, neighbor_len, d_model)).view(e.shape)

            # $E^j_u \leftarrow \text{C\small{A}}_{\text{enc}}(E^j_u, H_u)$ if $p' \in P_{\text{enc}}$
            if p in self.ca_layers:
                e = next(next_ca)(e, h_split)

            # $E^j_u \leftarrow \text{F\small{FW}}_{\text{enc}}(E^j_u)$
            e = ffw(e)

        return e


class RetroModel(nn.Module):
    """
    ## Retro Model

    This is the Retro decoder.
    It embeds the input tokens and passes them through `n_layers` layers of causal
    self attention and feed forward networks. At the layers listed in `ca_layers` it also
    applies chunked cross attention to the nearest neighbors encoded by `encoder`.
    A linear readout maps the final embeddings to scores over the vocabulary.
    """

    def __init__(self, n_vocab: int, d_model: int, n_layers: int, ca_layers: Set[int], chunk_len: int,
                 n_heads: int, d_k: int, d_ff: int, encoder: NearestNeighborEncoder):
        """
        * `n_vocab` is the number of tokens in the vocabulary
        * `d_model` is the number of features in embeddings
        * `n_layers` is the number of layers in the decoder $L$
        * `ca_layers` are the (0-based) indices of the layers with cross attention $P$
        * `chunk_len` is the length of a chunk
        * `n_heads` is the number of heads in attention layers
        * `d_k` is the size of attention heads
        * `d_ff` is the size of the feed-forward networks hidden layers
        * `encoder` is the nearest neighbor encoder
        """
        super().__init__()

        self.ca_layers = ca_layers
        self.encoder = encoder

        # Token embedding layer
        self.emb = nn.Embedding(n_vocab, d_model)
        # Chunked cross attention layers $\text{C\small{CA}}$, one per entry of `ca_layers`
        self.cca = nn.ModuleList(
            [ChunkedCrossAttention(d_model, n_heads, d_k, chunk_len) for _ in range(len(ca_layers))])
        # Attention layers $\text{A\small{TTN}}$
        self.attn = nn.ModuleList([SelfAttention(d_model, n_heads, d_k, is_causal=True) for _ in range(n_layers)])
        # Feed forward layers $\text{F\small{FW}}$
        self.ffw = nn.ModuleList([FeedForward(d_model, d_ff) for _ in range(n_layers)])
        # Readout layer $\text{R\small{EAD}}$
        self.read = nn.Linear(d_model, n_vocab)

        # Pre-normalization layer for nearest neighbor embeddings from
        # $\text{E\small{NCODER}}(\text{R\small{ET}}(C_u)_{1 \le u \le l}, H)$
        self.norm_e = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, ret: torch.Tensor):
        r"""
        * `x` is the input sequence, $X$ of shape `[batch_size, seq_len]`
        * `ret` are the retrieved neighbors
         $\text{R\small{ET}}(C_u)_{1 \le u \le l}$
         of shape `[batch_size, chunks, neighbors, neighbor_len]`

        Returns the output of the readout layer $\text{R\small{EAD}}(H)$
        (scores over the vocabulary for every input position).
        """

        # Get input embeddings $H \leftarrow \text{E\small{MB}}(X)$
        h = self.emb(x)

        # The first layer with chunked cross attention, $\min(P)$.
        # Computed once instead of on every layer; `None` when there is no cross attention.
        first_ca = min(self.ca_layers) if self.ca_layers else None

        # Encoded neighbors $E$; computed right before the first $\text{C\small{CA}}$ layer
        e = None
        # Keep index of the chunked cross attention layer
        p_ca = 0
        # For all layers $p \in [1, L]$
        for p in range(len(self.attn)):
            # Causal self attention $H \leftarrow \text{A\small{TTN}}(H)$
            h = self.attn[p](h)

            # Get encoder embeddings before the first $\text{C\small{CA}}$ layer,
            # when $p = \min(P)$
            if p == first_ca:
                # Embeddings of the retrieved neighbors
                # $E^j_u = \text{E\small{MB}}_{\text{enc}}\big(\text{R\small{ET}}(C_u)^j\big)$.
                # The same embedding table is used for inputs and neighbors. This is only
                # computed when a cross attention layer is actually reached.
                ret_emb = self.emb(ret)
                # $E = \text{E\small{NCODER}}(\text{R\small{ET}}(C_u)_{1 \le u \le l}, H)$,
                # followed by the pre-norm
                e = self.norm_e(self.encoder(ret_emb, h))

            # Chunked-cross attention if $p \in P$
            if p in self.ca_layers:
                # $H \leftarrow \text{C\small{CA}}(H, E)$
                h = self.cca[p_ca](h, e)
                # Increment chunked cross-attention index
                p_ca += 1

            # $H \leftarrow \text{F\small{FW}}(H)$
            h = self.ffw[p](h)

        # $O \leftarrow \text{R\small{EAD}}(H)$
        return self.read(h)


def _test():
    """
    ### Test the model with fake data

    Builds a tiny Retro model, runs one forward pass on a batch of 10 identical
    sequences with fake retrieved neighbors, and prints the output with `inspect`.
    """
    chunk_len = 4
    d_model = 8
    d_ff = 32
    n_heads = 2
    d_k = 4

    # Use the GPU when available; fall back to the CPU so the smoke test also runs without CUDA
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    m = RetroModel(5, d_model, 6, {2, 5}, chunk_len, n_heads, d_k, d_ff,
                   encoder=NearestNeighborEncoder(chunk_len, 2, {1}, d_model, n_heads, d_k, d_ff))

    m.to(device)
    # 10 input tokens: two full chunks of `chunk_len = 4` plus two extra tokens
    x = [1, 2, 4, 4, 0, 1, 2, 3, 4, 3]
    # Retrieved neighbors: 2 chunks x 2 neighbors x 6 tokens
    ret = [
        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
        [[0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1]],
    ]
    res = m(torch.tensor([x] * 10).to(device), torch.tensor([ret] * 10).to(device))

    inspect(res)


#
if __name__ == '__main__':
    _test()

# RETRO training dataset
# https://nn.labml.ai/transformers/retro/dataset.html

In [10]:
"""
---
title: Training dataset for RETRO
summary: >
  Create a dataset for RETRO model training
---

# RETRO training dataset

We pre-retrieve nearest neighbors from the [key-value database](database.html)
 and create the dataset to train the [RETRO](index.html)
 [model](model.html).
"""

import json
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset as PyTorchDataset

from labml import lab, monit
from labml_helpers.datasets.text import TextFileDataset, TextDataset
from labml_nn.transformers.retro.database import RetroIndex


def build_dataset(chunk_len: int = 16, chunks_per_sample: int = 32, skip_range: int = 8):
    """
    ## Build the dataset

    * `chunk_len` is the chunk length
    * `chunks_per_sample` is the number of chunks per training sample
    * `skip_range` bounds the number of characters skipped before each sample:
        the skip is drawn uniformly from `0, 1, ..., skip_range - 1`
        (no characters are skipped when `skip_range <= 0`).
        We skip a few characters between samples to make sure the samples
        aren't aligned perfectly with the chunks in the [database](database.html)

    Each saved sample is `(input_text, target_text, neighbors)`: the target is the input
    shifted by one character, and `neighbors` holds, for every chunk of the input, the texts
    of its retrieved neighbors (each up to `chunk_len * 2` characters long).
    The samples are written to `retro_train_dataset.json` in the data directory.
    """

    # Load the text file
    dataset = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

    # Training portion of it
    text = dataset.train

    # Load the index for retrieving neighbors
    index = RetroIndex()

    # Number of input characters in one sample
    sample_len = chunks_per_sample * chunk_len

    # The input sample offsets
    sample_offsets = []
    # Cursor for the text
    i = 0
    while i < len(text):
        # Skip a few characters to make sure it's not aligned with the neighbors.
        # `np.random.randint(high)` samples from `[0, high)` and raises for `high <= 0`.
        skip = np.random.randint(skip_range) if skip_range > 0 else 0
        i += skip

        # Stop if there isn't room for a full sample. A sample needs `sample_len + 1`
        # characters (the extra one is the prediction target of the last input character);
        # a shorter sample would produce tensors of a different length and break batching.
        if i + sample_len + 1 > len(text):
            break

        # Collect the offset
        sample_offsets.append(i)

        # Increment the cursor
        i += sample_len

    # For samples
    samples = []
    # Iterate through sample offsets
    for i in monit.iterate('Gather Neighbors', sample_offsets):
        # Get the sample including an extra character (for prediction)
        sample = text[i: i + sample_len + 1]
        # The input
        src = sample[:-1]
        # Break it into chunks
        chunks = [src[j:j + chunk_len] for j in range(0, len(src), chunk_len)]
        # The chunk offsets in `text`
        chunk_offsets = [j + i for j in range(0, len(src), chunk_len)]

        # Retrieve nearest neighbors
        neighbor_offsets = index(chunks, chunk_offsets)

        # Get neighbor texts. The neighbor length is twice the `chunk_len`
        neighbors = [[text[j: j + chunk_len * 2] for j in n_off] for n_off in neighbor_offsets]

        # Add to list of samples
        samples.append((sample[:-1], sample[1:], neighbors))

    # Save the samples in JSON.
    # We don't need to use complex dataset storage mechanisms or pre-tokenize
    # since our dataset is small.
    with open(str(lab.get_data_path() / 'retro_train_dataset.json'), 'w') as f:
        json.dump(samples, f)


class Dataset(PyTorchDataset):
    """
    ## Dataset

    PyTorch dataset over the JSON file written by `build_dataset`.
    Each item is `(src, tgt, neighbors)`, converted to token ids with the given `TextDataset`.
    """
    def __init__(self, file_path: Path, tds: TextDataset):
        """
        * `file_path` is the path of the saved JSON file
        * `tds` is the `TextDataset` used to convert text to token ids
        """
        # Load every saved sample into memory
        with open(str(file_path), 'r') as f:
            self.samples = json.load(f)
        self.tds = tds

    def __len__(self):
        """
        Number of samples
        """
        return len(self.samples)

    def __getitem__(self, idx: int):
        """
        Get sample `idx` as token-id tensors `(src, tgt, neighbors)`
        """
        sample = self.samples[idx]
        encode = self.tds.text_to_i

        # Input text and the one-character-shifted target text
        src = encode(sample[0])
        tgt = encode(sample[1])

        # One row per chunk, each row stacking the tokenized neighbors of that chunk
        rows = []
        for chunk_neighbors in sample[2]:
            rows.append(torch.stack([encode(text) for text in chunk_neighbors]))
        neighbors = torch.stack(rows)

        return src, tgt, neighbors


#
if __name__ == '__main__':
    build_dataset()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# RETRO training
# https://nn.labml.ai/transformers/retro/train.html

In [None]:
"""
---
title: RETRO training
summary: >
  Training RETRO model with Tiny Shakespeare dataset
---

# RETRO training

This is the training code for
 [RETRO](index.html).
"""

import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler

from labml import monit, lab, tracker, experiment, logger
from labml.logger import Text
from labml_helpers.datasets.text import TextFileDataset
from labml_nn.optimizers.noam import Noam
from labml_nn.transformers.retro import model as retro
from labml_nn.transformers.retro.dataset import Dataset, RetroIndex
from labml_nn.transformers.retro.model import RetroModel, NearestNeighborEncoder


class Sampler:
    """
    ## Sampler

    This class greedily samples from a model.
    """

    def __init__(self, device: torch.device, model: retro.RetroModel, tds: TextFileDataset, chunk_len: int):
        """
        * `device` is the device of the model
        * `model` is the [Retro model](retro.html)
        * `tds` is the text dataset (used to get neighbor chunks)
        * `chunk_len` is the length of a chunk
        """
        self.chunk_len = chunk_len
        self.tds = tds
        self.model = model
        self.device = device

        # [Retro index](database.html)
        self.index = RetroIndex()

    def retrieve_nearest_neighbours(self, chunk: str):
        """
        ### Retrieve nearest neighbors of a given chunk

        Returns the neighbor texts, sliced from the training text, each at most
        `chunk_len * 2` characters long.
        """

        # Retrieve the offsets of the nearest neighbors
        neighbor_offsets = self.index([chunk], None)

        # Get the neighbors (with neighbor length equal to `chunk_len * 2`)
        text = self.tds.train
        neighbors = [text[j: j + self.chunk_len * 2] for j in neighbor_offsets[0]]

        #
        return neighbors

    def sample(self, prompt: str, sample_len: int):
        """
        ### Sample text from the given prompt

        Greedily generates `sample_len` tokens. Before each step, neighbors are retrieved
        for every complete chunk of the growing prompt that doesn't have them yet.
        The model's train/eval mode is left unchanged.

        Returns only the newly sampled text (without the prompt).
        """

        # To store nearest neighbors as strings, one list per complete chunk of the prompt
        neighbors_str = []
        # Token ids of `neighbors_str` on the model's device; rebuilt only when it grows
        neighbors = None

        # Sampled text
        sampled = ''

        # Sampling needs no gradients; without this every step builds an autograd graph
        with torch.no_grad():
            # Sample `sample_len` tokens
            for _ in range(sample_len):
                # We need to retrieve neighbors,
                # if there are more sampled chunks than we have already retrieved for
                while len(neighbors_str) < len(prompt) // self.chunk_len:
                    # Get the last chunk for which we haven't retrieved neighbors
                    off = len(neighbors_str) * self.chunk_len
                    chunk = prompt[off: off + self.chunk_len]
                    # Retrieve nearest neighbors
                    neighbors_str.append(self.retrieve_nearest_neighbours(chunk))

                # Tokenize the retrieved neighbors only when new ones were added
                if neighbors is None or neighbors.shape[0] != len(neighbors_str):
                    neighbors = torch.stack([torch.stack([self.tds.text_to_i(n) for n in chunk])
                                             for chunk in neighbors_str])
                    neighbors = neighbors.to(self.device)

                # Tokenize the input and move it to the same device as the model
                src = self.tds.text_to_i(prompt).to(self.device)

                # Get model output for a batch of one
                res = self.model(src[None, :], neighbors[None, :, :, :])

                # Greedily sample the last token
                token = res[0, -1, :].argmax(dim=-1)

                # Add the sampled token text to the prompt and sample text
                char = self.tds.itos[token.item()]
                prompt += char
                sampled += char

        #
        return sampled


class Trainer:
    """
    ## Retro trainer

    Calling an instance runs one pass over the dataloader, taking an optimizer
    step after every batch and logging the training loss.
    """

    def __init__(self, device: torch.device, model: retro.RetroModel,
                 dataloader: DataLoader, optimizer: torch.optim.Optimizer):
        """
        * `device` is the device of the model; batches are moved there
        * `model` is the [Retro model](retro.html)
        * `dataloader` is the dataloader for the [dataset with pre-retrieved neighbors](dataset.html)
        * `optimizer` is the optimizer
        """
        self.model = model
        self.device = device
        self.dataloader = dataloader
        self.optimizer = optimizer
        # Cross-entropy over every predicted token
        self.loss_func = nn.CrossEntropyLoss()

    def __call__(self):
        """
        ### Train the model for an epoch
        """
        for _, batch in monit.enum('Train', self.dataloader):
            # Inputs, targets and neighbor tokens, moved to the model's device
            src, tgt, neighbors = (t.to(self.device) for t in batch)

            # Flatten all positions so each token is one classification example
            logits = self.model(src, neighbors)
            loss = self.loss_func(logits.view(-1, logits.shape[-1]), tgt.view(-1))

            # Gradient step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Log the loss; the global step advances by the batch size, not by one
            tracker.save({'loss.train': loss})
            tracker.add_global_step(len(src))


def train():
    """
    ## Create and train a small model

    Trains on Tiny Shakespeare (with pre-retrieved neighbors) for 32 epochs.
    After every epoch it logs a greedy sample for a fixed prompt and saves a checkpoint.
    """

    # Create an experiment
    experiment.create(name='retro_small')

    # Use the first GPU when available; otherwise fall back to the CPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Load Tiny Shakespeare dataset
    tds = TextFileDataset(
        lab.get_data_path() / 'tiny_shakespeare.txt',
        list,
        url='https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')

    # Load [Retro dataset](dataset.html)
    train_dataset = Dataset(lab.get_data_path() / 'retro_train_dataset.json', tds)

    # Create dataloader
    train_dl = DataLoader(train_dataset,
                          batch_size=4,
                          sampler=RandomSampler(train_dataset, replacement=True))

    # Hyper-parameters
    chunk_len = 16
    d_model = 128
    d_ff = 512
    n_heads = 16
    d_k = 16

    # Create the nearest neighbor encoder
    nearest_neighbor_encoder = NearestNeighborEncoder(chunk_len, 6, {3}, d_model, n_heads, d_k, d_ff)
    # Create the model
    model = RetroModel(tds.n_tokens, d_model, 6,
                       {3, 5},
                       chunk_len, n_heads, d_k, d_ff,
                       encoder=nearest_neighbor_encoder)
    # Move the model to the device
    model = model.to(device)
    # Create the optimizer
    optimizer = Noam(model.parameters(), lr=1., d_model=d_model, warmup=2_000)
    # Create the `Trainer`
    trainer = Trainer(device, model, train_dl, optimizer)
    # Create the `Sampler`
    sampler = Sampler(device, model, tds, chunk_len)
    #
    prompt = '''Second Citizen:\nOne word, good citizens.\n\nFirst Citizen:'''

    # Set models for saving and loading
    experiment.add_pytorch_models(model=model)

    # Start the experiment
    with experiment.start():
        # Train for `32` epochs
        for _ in monit.loop(32):
            # Train
            trainer()
            # Print a new line
            tracker.new_line()
            # Sample from the `prompt`
            logger.log([(prompt.replace('\n', '\\n\n'), Text.subtle),
                        (sampler.sample(prompt, 128).replace('\n', '\\n\n'), Text.none)])
            # Save models
            experiment.save_checkpoint()


#
if __name__ == '__main__':
    train()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
