In [None]:
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import math
import os
import random
import tqdm
import gzip
import time

!pip install einops
from einops import rearrange, repeat, pack, unpack, einsum
from einops.layers.torch import Rearrange


from functools import partial, wraps
from contextlib import contextmanager, ExitStack
from pathlib import Path
from filelock import FileLock
import pickle

import transformers
from transformers import AutoTokenizer

!pip install faiss-gpu
import faiss

!pip install datasets
import datasets

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl.metadata (13 kB)
Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0
Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2


In [None]:
# Run on a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class RelativePosition(nn.Module):
    """
    Bucketed (T5-style, causal) relative-position bias for attention over [memory; current] keys.

    ``forward(L, device)`` returns a bias table of shape (1, heads, L, 2 * L). Rows are the
    L queries of the current segment; columns are a key context made of one previous
    segment (positions 0 .. L-1) followed by the current segment (positions L .. 2L-1).
    The queries therefore sit at absolute positions L .. 2L-1. Attention layers slice
    ``[..., -i:, -j:]`` from the table, so the same table also serves the no-memory case
    (last L columns only).

    Args:
        rp_scale (float): Multiplier applied to the looked-up bias values.
        num_buckets (int): Number of distance buckets. The first ``num_buckets // 2`` are
            exact distances; the rest are log-spaced up to ``rp_max_distance``.
        rp_max_distance (int): Distance at which the log-spaced buckets saturate.
        heads (int): Number of attention heads (one learned bias per head and bucket).

    Attributes:
        scale (float): ``rp_scale``.
        num_buckets (int): Number of distance buckets.
        rp_max_distance (int): Saturation distance for the log-spaced buckets.
        relative_attention_embedding (nn.Embedding): (num_buckets, heads) bias table.
    """

    def __init__(
        self,
        rp_scale,
        num_buckets=32,
        rp_max_distance=128,
        heads=8
    ):
        super().__init__()
        self.scale = rp_scale
        self.num_buckets = num_buckets
        self.rp_max_distance = rp_max_distance
        self.relative_attention_embedding = nn.Embedding(num_buckets, heads)

    def relative_position_bucket(self, relative_position_matrix):
        """
        Map signed relative positions (key_pos - query_pos) to bucket indices.

        Only distances into the past are distinguished. Non-negative offsets (the key at or
        after the query) all map to bucket 0. A past distance ``n`` below ``num_buckets // 2``
        gets bucket ``n``. Larger distances are log-spaced and clamped to ``num_buckets - 1``.

        Args:
            relative_position_matrix (torch.Tensor): Integer tensor of key_pos - query_pos.

        Returns:
            torch.Tensor: Long tensor of bucket indices, same shape as the input.
        """
        n = (-relative_position_matrix).clamp(min=0)
        max_exact = self.num_buckets // 2
        is_small = n < max_exact

        # Clamp before the log so small distances (including 0) never produce -inf, whose
        # cast to long is undefined; those entries are discarded by torch.where anyway.
        n_large = n.float().clamp(min=max_exact)
        val_if_large = max_exact + (
            torch.log(n_large / max_exact)
            / math.log(self.rp_max_distance / max_exact)
            * (self.num_buckets - max_exact)
        ).long()
        val_if_large = val_if_large.clamp(max=self.num_buckets - 1)

        return torch.where(is_small, n, val_if_large)

    def forward(self, sequence_length, device):
        """
        Build the scaled relative-position bias table.

        Args:
            sequence_length (int): Length L of the current segment (and of the memory).
            device (torch.device): Device on which to build the table.

        Returns:
            torch.Tensor: Bias of shape (1, heads, L, 2 * L), multiplied by ``rp_scale``.
        """
        # Queries are the current segment, placed after one segment of memory in the key
        # context: absolute query positions L .. 2L-1, key positions 0 .. 2L-1.
        query_pos = torch.arange(sequence_length, 2 * sequence_length, dtype=torch.long, device=device)
        key_pos = torch.arange(2 * sequence_length, dtype=torch.long, device=device)
        rel_pos = key_pos.unsqueeze(0) - query_pos.unsqueeze(1)  # (L, 2L)

        buckets = self.relative_position_bucket(rel_pos)
        rp_values = self.relative_attention_embedding(buckets)  # (L, 2L, heads)
        # (L, 2L, heads) -> (1, heads, L, 2L) so the bias broadcasts over the batch.
        rp_values = rp_values.permute(2, 0, 1).unsqueeze(0)
        return rp_values * self.scale

In [None]:
class KNN():
    """
    Append-only key/value memory with exact L2 nearest-neighbour search over the keys.

    Each entry is a (key, value) pair of ``dim``-dimensional float32 vectors, stored in an
    ``np.memmap`` of shape (max_memories, 2, dim). Keys are also added to a FAISS
    ``IndexFlatL2``. Entries are only appended (and both stores are reset together by
    ``clear``), so FAISS row ids equal memmap row indices and search results index the
    memmap directly.

    Args:
        dim (int): Dimensionality of each key and value vector.
        max_memories (int): Capacity of the store, in key/value pairs.
        db_filepath (str): Path of the file backing the memmap (opened with mode 'w+').
            Instances sharing a path overwrite each other's data, so give each KNN its own.

    Attributes:
        dim (int): Dimensionality of each key and value vector.
        max_memories (int): Capacity of the store.
        shape (tuple): Memmap shape, (max_memories, 2, dim).
        db_offset (int): Number of rows filled so far (== next free row).
        db_filepath (str): Path of the memmap file.
        db (np.memmap): Key/value storage.
        index (faiss.IndexFlatL2): FAISS index over the stored keys.

    Methods:
        add_to_db: Append key/value pairs to the memmap.
        search_and_retrieve: Look up the key/value pairs nearest to numpy query vectors.
        add: Append key/value pairs to the memmap and their keys to the index.
        search: Torch-level wrapper around ``search_and_retrieve``.
        clear: Empty both the memmap and the index.
    """

    def __init__(
        self,
        dim,
        max_memories,
        db_filepath="./memory.memmap",
        ):
        """
        Create an empty store backed by ``db_filepath``.

        Args:
            dim (int): Dimensionality of each key and value vector.
            max_memories (int): Capacity of the store, in key/value pairs.
            db_filepath (str): Path of the file backing the memmap.
        """
        self.dim = dim
        self.max_memories = max_memories
        self.shape = (max_memories, 2, dim)
        self.db_offset = 0
        self.db_filepath = db_filepath
        self.db = np.memmap(self.db_filepath, mode='w+', dtype=np.float32, shape=self.shape)
        self.index = faiss.IndexFlatL2(dim)

    def add_to_db(self, new_data):
        """
        Append key/value pairs to the memmap and flush it to disk.

        Args:
            new_data (torch.Tensor): (N, 2, dim) stacked keys/values.

        Raises:
            IndexError: If the store does not have room for N more entries. Nothing is
                written in that case.
        """
        new_data_len = new_data.shape[0]
        end = self.db_offset + new_data_len
        if end > self.max_memories:
            raise IndexError(
                f"KNN memory full: cannot add {new_data_len} entries at offset "
                f"{self.db_offset} (capacity {self.max_memories}); call clear() first"
            )
        self.db[self.db_offset:end] = new_data.detach().float().cpu().numpy()
        self.db_offset = end
        self.db.flush()

    def search_and_retrieve(self, query_vecs, topk):
        """
        Return the stored key/value pairs whose keys are nearest to each query.

        Args:
            query_vecs (np.ndarray): (Q, dim) C-contiguous float32 queries.
            topk (int): Number of neighbours per query.

        Returns:
            np.ndarray: (Q, topk, 2, dim) key/value pairs. Slots FAISS could not fill
            (it reports id -1 when fewer than ``topk`` keys are stored) are zeros.
        """
        _, indices = self.index.search(query_vecs, topk)
        missing = indices < 0
        # Fancy indexing copies out of the memmap, so zeroing the missing slots is safe.
        kvs = np.asarray(self.db[np.where(missing, 0, indices)])
        kvs[missing] = 0
        return kvs

    def add(self, new_data):
        """
        Append key/value pairs to the memmap and their keys to the FAISS index.

        Args:
            new_data (torch.Tensor): (b, n, 2, dim) stacked keys/values.
        """
        # (b, n, 2, d) -> (b*n, 2, d)
        new_data = new_data.flatten(0, 1)
        # Write the memmap first: if it raises (full), the index is left untouched.
        self.add_to_db(new_data)
        # Only keys go into the index. FAISS needs C-contiguous float32.
        keys, _ = new_data.unbind(dim=-2)
        keys = np.ascontiguousarray(keys.detach().float().cpu().numpy(), dtype=np.float32)
        self.index.add(keys)

    def search(self, query_vecs, topk):
        """
        Nearest-neighbour lookup for a batch of query sequences.

        Args:
            query_vecs (torch.Tensor): (b, n, dim) queries.
            topk (int): Number of neighbours per query.

        Returns:
            torch.Tensor: (b, n, topk, 2, dim) float32 key/value pairs on ``query_vecs.device``.
        """
        batch_size, seq_len = query_vecs.shape[0], query_vecs.shape[1]
        device = query_vecs.device
        # (b, n, d) -> (b*n, d), as C-contiguous float32 for FAISS
        flat_queries = query_vecs.flatten(0, 1).detach().float().cpu().numpy()
        kvs = self.search_and_retrieve(np.ascontiguousarray(flat_queries, dtype=np.float32), topk)
        # (b*n, k, 2, d) -> (b, n, k, 2, d)
        kvs = torch.unflatten(torch.from_numpy(kvs), 0, (batch_size, seq_len))
        return kvs.to(device)

    def clear(self):
        """
        Empty the store: reset the FAISS index, zero the memmap, and rewind the offset.
        """
        self.index.reset()
        self.db[:] = 0
        self.db_offset = 0


In [None]:


class XLAttention(nn.Module):
    """
    Causal multi-head self-attention with Transformer-XL style segment memory.

    Keys/values cached from a previous segment (``xl_memory``) are prepended to the current
    segment's keys/values. Each query attends to the whole memory plus the causal prefix of
    its own segment. An optional additive relative-position bias is added to the logits.
    Logits are scaled once by ``head_dimension ** -0.5``.

    Args:
        embedding_dimension (int): Dimensionality of the input embeddings.
        heads (int): Number of attention heads.
        head_dimension (int): Dimensionality of each attention head.
        dropout (float): Dropout probability applied to the attention weights.

    Attributes:
        heads (int): Number of attention heads.
        dropout (nn.Dropout): Dropout on attention weights.
        scale (float): ``head_dimension ** -0.5``.
        query_matrix (nn.Linear): Query projection.
        key_matrix (nn.Linear): Key projection.
        value_matrix (nn.Linear): Value projection.
        output_matrix (nn.Linear): Output projection.

    Methods:
        forward: Forward pass of the XLAttention module.
    """

    def __init__(
        self,
        embedding_dimension,
        heads=8,
        head_dimension=64,
        dropout=0.,
    ):
        """
        Initializes the XLAttention module.

        Args:
            embedding_dimension (int): Dimensionality of the input embeddings.
            heads (int): Number of attention heads.
            head_dimension (int): Dimensionality of each attention head.
            dropout (float): Dropout probability.
        """
        super().__init__()
        self.heads = heads
        self.dropout = nn.Dropout(dropout)
        self.scale = head_dimension ** -0.5

        self.query_matrix = nn.Linear(embedding_dimension, self.heads * head_dimension)
        self.key_matrix = nn.Linear(embedding_dimension, self.heads * head_dimension)
        self.value_matrix = nn.Linear(embedding_dimension, self.heads * head_dimension)
        self.output_matrix = nn.Linear(self.heads * head_dimension, embedding_dimension)

    def _split_heads(self, t):
        """(b, t, heads * d) -> (b, heads, t, d)."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.heads, -1).transpose(1, 2)

    @staticmethod
    def _merge_heads(t):
        """(b, heads, t, d) -> (b, t, heads * d)."""
        b, _, n, _ = t.shape
        return t.transpose(1, 2).reshape(b, n, -1)

    def forward(
        self,
        x,  # batch_size, sequence_length, embedding_dimension
        relative_positions=None,
        xl_memory=None
    ):
        """
        Attend over [xl_memory; current segment] with a causal mask.

        Args:
            x (torch.Tensor): (batch, t, embedding_dimension) input.
            relative_positions (torch.Tensor, optional): Additive logit bias. Its last ``i``
                rows and ``j`` columns (i = queries, j = memory + current keys) must broadcast
                to (batch, heads, i, j), e.g. the (1, heads, L, 2L) RelativePosition output.
            xl_memory (torch.Tensor, optional): (batch, mem_len, 2, heads * head_dimension)
                stacked keys/values returned by a previous call.

        Returns:
            torch.Tensor: (batch, t, embedding_dimension) output.
            torch.Tensor: (batch, t, 2, heads * head_dimension) keys/values of the current
            segment, to be passed as ``xl_memory`` for the next segment.
        """
        device = x.device
        sequence_length = x.shape[1]

        queries = self.query_matrix(x)
        keys = self.key_matrix(x)
        values = self.value_matrix(x)

        if xl_memory is not None:
            k_xl, v_xl = xl_memory.unbind(dim=-2)
            keys = torch.cat((k_xl, keys), dim=-2)
            values = torch.cat((v_xl, values), dim=-2)

        q = self._split_heads(queries)
        k = self._split_heads(keys)
        v = self._split_heads(values)

        qk = q @ k.transpose(-2, -1)  # (b, h, i, j)
        i, j = qk.shape[-2:]
        if relative_positions is not None:
            qk = relative_positions[..., -i:, -j:] + qk

        # The only 1/sqrt(head_dimension) scaling: queries are not pre-scaled.
        qk = qk * self.scale

        # Query r (of i) may see keys up to column r + (j - i): all memory + causal prefix.
        mask = torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
        qk = qk.masked_fill(mask, float('-inf'))

        attn = self.dropout(F.softmax(qk, dim=-1))
        out = self.output_matrix(self._merge_heads(attn @ v))

        # Hand back only the current segment's keys/values (the last t positions),
        # regardless of how long the incoming memory was.
        kv_memories = torch.stack((keys, values), dim=-2)
        kv_to_add_xl = kv_memories[:, -sequence_length:]

        return out, kv_to_add_xl


In [None]:
class KNNAttention(nn.Module):
    """
    Causal multi-head attention augmented with k-nearest-neighbour retrieval from a KNN store.

    Local attention works like XLAttention (optional XL memory prepended, causal mask,
    optional relative-position bias), but on L2-normalised queries and keys. When the KNN
    store is non-empty, each query also retrieves its ``topk_retrieved_memories`` nearest
    stored keys and attends over them. The two results are mixed per head with the learned
    gate ``g = sigmoid(gate_bias)`` as ``g * memory_out + (1 - g) * local_out``. Every
    forward pass appends the current segment's keys/values to the KNN store.

    Args:
        embedding_dimension (int): Dimensionality of the input embeddings.
        knn (KNN): Store used for retrieval; must expose ``index.ntotal``, ``search`` and ``add``.
        heads (int): Number of attention heads.
        head_dimension (int): Dimensionality of each attention head.
        topk_retrieved_memories (int): Number of memories retrieved per query.
        dropout (float): Dropout probability on local and memory attention weights.

    Attributes:
        heads (int): Number of attention heads.
        scale (float): ``head_dimension ** -0.5``.
        dropout (nn.Dropout): Dropout on attention weights.
        query_matrix (nn.Linear): Query projection.
        key_matrix (nn.Linear): Key projection.
        value_matrix (nn.Linear): Value projection.
        output_matrix (nn.Linear): Output projection.
        gate_bias (nn.Parameter): (heads, 1, 1) pre-sigmoid gate logits.
        topk_retrieved_memories (int): Number of memories retrieved per query.
        knn (KNN): Retrieval store.

    Methods:
        forward: Forward pass of the KNNAttention module.
    """

    def __init__(
        self,
        embedding_dimension,
        knn,
        heads=8,
        head_dimension=64,
        topk_retrieved_memories=3,
        dropout=0.
    ):
        """
        Initializes the KNNAttention module.

        Args:
            embedding_dimension (int): Dimensionality of the input embeddings.
            knn (KNN): KNN object for memory retrieval.
            heads (int): Number of attention heads.
            head_dimension (int): Dimensionality of each attention head.
            topk_retrieved_memories (int): Number of top memories to retrieve.
            dropout (float): Dropout probability.
        """
        super().__init__()
        self.heads = heads
        self.scale = head_dimension ** -0.5
        self.dropout = nn.Dropout(dropout)

        self.query_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.key_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.value_matrix = nn.Linear(embedding_dimension, heads * head_dimension)
        self.output_matrix = nn.Linear(heads * head_dimension, embedding_dimension)

        self.gate_bias = nn.Parameter(torch.randn(self.heads, 1, 1))
        self.topk_retrieved_memories = topk_retrieved_memories
        self.knn = knn

    def _split_heads(self, t):
        """(b, t, heads * d) -> (b, heads, t, d)."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.heads, -1).transpose(1, 2)

    def _split_heads_memory(self, t):
        """(b, t, k, heads * d) -> (b, heads, t, k, d)."""
        b, n, k, _ = t.shape
        return t.reshape(b, n, k, self.heads, -1).permute(0, 3, 1, 2, 4)

    @staticmethod
    def _merge_heads(t):
        """(b, heads, t, d) -> (b, t, heads * d)."""
        b, _, n, _ = t.shape
        return t.transpose(1, 2).reshape(b, n, -1)

    def forward(
        self,
        x,  # batch_size, sequence_length, embedding_dimension
        relative_positions=None,
        xl_memory=None
    ):
        """
        Local causal attention, optionally gated with kNN-memory attention.

        Args:
            x (torch.Tensor): (batch, t, embedding_dimension) input.
            relative_positions (torch.Tensor, optional): Additive logit bias. Its last ``i``
                rows and ``j`` columns must broadcast to (batch, heads, i, j).
            xl_memory (torch.Tensor, optional): (batch, mem_len, 2, heads * head_dimension)
                stacked keys/values returned by a previous call.

        Returns:
            torch.Tensor: (batch, t, embedding_dimension) output.
            torch.Tensor: (batch, t, 2, heads * head_dimension) keys/values of the current
            segment. These are also appended to ``self.knn``.
        """
        device = x.device
        sequence_length = x.shape[1]

        # Normalise over the full (heads * head_dimension) projection. These are exactly
        # the vectors searched against / stored in the KNN index below.
        queries = F.normalize(self.query_matrix(x), dim=-1)
        keys = F.normalize(self.key_matrix(x), dim=-1)
        values = self.value_matrix(x)

        if xl_memory is not None:
            k_xl, v_xl = xl_memory.unbind(dim=-2)
            keys = torch.cat((k_xl, keys), dim=-2)
            values = torch.cat((v_xl, values), dim=-2)

        q = self._split_heads(queries)
        k = self._split_heads(keys)
        v = self._split_heads(values)

        # --- local (XL) attention ---
        qk = q @ k.transpose(-2, -1)  # (b, h, i, j)
        i, j = qk.shape[-2:]
        if relative_positions is not None:
            qk = relative_positions[..., -i:, -j:] + qk
        # NOTE(review): q and k are unit vectors, so q.k lies in [-1, 1] and this scale
        # gives very flat attention; a larger or learned temperature may be intended.
        qk = qk * self.scale
        mask = torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
        qk = qk.masked_fill(mask, float('-inf'))
        local_out = self.dropout(F.softmax(qk, dim=-1)) @ v  # (b, h, t, d)

        if self.knn.index.ntotal > 0:
            # NOTE(review): a single KNN store is shared by the whole batch, so a sequence
            # can retrieve memories written by other sequences in the batch.
            mem_kv = self.knn.search(queries, topk=self.topk_retrieved_memories)  # (b, t, k, 2, h*d)
            mem_k, mem_v = mem_kv.unbind(dim=-2)
            mem_k = self._split_heads_memory(mem_k)  # (b, h, t, k, d)
            mem_v = self._split_heads_memory(mem_v)

            mem_qk = torch.einsum('bhtd,bhtkd->bhtk', q, mem_k) * self.scale
            mem_attn = self.dropout(F.softmax(mem_qk, dim=-1))
            mem_out = torch.einsum('bhtk,bhtkd->bhtd', mem_attn, mem_v)

            # Learned per-head gate in (0, 1): g * memory + (1 - g) * local.
            gate = torch.sigmoid(self.gate_bias)
            combined = mem_out * gate + local_out * (1 - gate)
        else:
            combined = local_out

        out = self.output_matrix(self._merge_heads(combined))

        # Only the current segment's keys/values (last t positions) are stored and returned.
        kv_memories = torch.stack((keys, values), dim=-2)
        current_kv = kv_memories[:, -sequence_length:]

        self.knn.add(current_kv)

        return out, current_kv

In [None]:
class Block(nn.Module):
    """
    Pre-norm transformer block: an attention sub-layer followed by a feed-forward sub-layer.
    Each sub-layer is wrapped in a residual connection.

    Args:
        embedding_dimension (int): Model width.
        attention_type (nn.Module): Attention module, called as
            ``attention(x, relative_positions=..., xl_memory=...)`` and returning
            ``(output, new_memory)``.
        dropout (float): Dropout inside the feed-forward sub-layer.

    Attributes:
        attention (nn.Module): The attention sub-layer.
        dim (int): Model width.
        norm (nn.LayerNorm): Pre-attention layer norm.
        ff_block (nn.Sequential): LayerNorm -> Linear(4x) -> GELU -> Dropout -> Linear.
    """

    def __init__(self, embedding_dimension, attention_type, dropout=0.):
        super().__init__()
        self.attention = attention_type
        self.dim = embedding_dimension
        self.norm = nn.LayerNorm(embedding_dimension)

        hidden_width = 4 * embedding_dimension
        self.ff_block = nn.Sequential(
            nn.LayerNorm(embedding_dimension),
            nn.Linear(embedding_dimension, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, embedding_dimension),
        )

    def forward(self, x, xl_memories, rel_pos):
        """
        Apply both residual sub-layers.

        Args:
            x (torch.Tensor): Block input.
            xl_memories: Memory forwarded to the attention module as ``xl_memory``.
            rel_pos: Bias forwarded to the attention module as ``relative_positions``.

        Returns:
            torch.Tensor: Block output.
            Whatever memory the attention module returned.
        """
        attention_delta, new_xl_memories = self.attention(
            self.norm(x), relative_positions=rel_pos, xl_memory=xl_memories
        )
        hidden = attention_delta + x
        return self.ff_block(hidden) + hidden, new_xl_memories

In [None]:
class MemorizingTransformer(nn.Module):
    """
    Decoder-only language model combining Transformer-XL attention with one kNN-memory layer.

    Every layer is a pre-norm ``Block`` using ``XLAttention``, except the second-to-last
    (index ``depth - 2``), which uses ``KNNAttention`` backed by the model's ``KNN`` store.
    Each layer returns the current segment's keys/values. These are collected, detached,
    and returned so the caller can pass them back as ``xl_memories`` on the next segment.

    Args:
        embedding_dimension (int): Model width.
        vocab_size (int): Number of token ids (embedding rows and output logits).
        max_knn_memories (int): Capacity of the KNN store, in key/value pairs.
        heads (int): Attention heads per layer.
        depth (int): Number of blocks.
        dropout (float): Dropout in attention weights and in the feed-forward sub-layers.
        head_dimension (int): Per-head dimensionality.
        topk (int): Memories retrieved per query by the KNN attention layer.

    Attributes:
        heads, embedding_dimension, dropout, depth, head_dimension, max_knn_memories, topk:
            The constructor arguments.
        knn_layer_index (int): Index of the block that uses KNNAttention (``depth - 2``).
        rel_pos (RelativePosition): Relative-position bias for the XLAttention layers.
        rel_pos_knn (RelativePosition): Relative-position bias for the KNNAttention layer.
        embedding_matrix (nn.Embedding): Token embedding.
        knn (KNN): KNN store used by the KNNAttention layer.
        layers (nn.ModuleList): The transformer blocks.
        to_logits (nn.Sequential): Final LayerNorm + projection to vocabulary logits.

    Methods:
        forward: Forward pass of the MemorizingTransformer module.
    """

    def __init__(
        self,
        embedding_dimension,
        vocab_size,
        max_knn_memories=81920,
        heads=8,
        depth=10,
        dropout=0,
        head_dimension=64,
        topk=5,
    ):
        """
        Initializes the MemorizingTransformer module.

        Args:
            embedding_dimension (int): Dimensionality of the input embeddings.
            vocab_size (int): Size of the vocabulary.
            max_knn_memories (int): Maximum number of KNN memories.
            heads (int): Number of attention heads.
            depth (int): Depth of the transformer.
            dropout (float): Dropout probability.
            head_dimension (int): Dimensionality of each attention head.
            topk (int): Number of top memories to retrieve.
        """
        super().__init__()
        self.heads = heads
        self.embedding_dimension = embedding_dimension
        self.dropout = dropout
        self.depth = depth
        self.head_dimension = head_dimension
        self.max_knn_memories = max_knn_memories
        self.topk = topk
        # The one layer that gets KNN attention; all others use XL attention.
        self.knn_layer_index = depth - 2

        self.rel_pos = RelativePosition(rp_scale=head_dimension ** 0.5, heads=self.heads)
        self.rel_pos_knn = RelativePosition(rp_scale=head_dimension ** 0.5, heads=self.heads)
        self.embedding_matrix = nn.Embedding(vocab_size, self.embedding_dimension)
        self.knn = KNN(head_dimension * heads, self.max_knn_memories)

        self.layers = nn.ModuleList([])
        for i in range(self.depth):
            if i == self.knn_layer_index:
                attention_type = KNNAttention(
                    self.embedding_dimension,
                    self.knn,
                    heads=self.heads,
                    head_dimension=self.head_dimension,
                    topk_retrieved_memories=self.topk,
                    dropout=self.dropout
                )
            else:
                attention_type = XLAttention(
                    self.embedding_dimension,
                    heads=self.heads,
                    head_dimension=self.head_dimension,
                    dropout=self.dropout
                )
            self.layers.append(Block(self.embedding_dimension, attention_type, dropout=self.dropout))

        self.to_logits = nn.Sequential(
            nn.LayerNorm(self.embedding_dimension),
            nn.Linear(self.embedding_dimension, vocab_size)
        )

    def forward(
        self,
        x,
        relative_positions=None,
        xl_memories=None,
        labels=None,
    ):
        """
        Run one segment through the model.

        Args:
            x (torch.Tensor): (batch, n) token ids.
            relative_positions: Unused; kept for interface compatibility (the model builds
                its own relative-position tables).
            xl_memories (sequence, optional): One memory per layer, as returned by the
                previous call. None means no memory.
            labels (torch.Tensor, optional): (batch, n) target ids. When omitted, the
                logits are returned instead of a loss.

        Returns:
            torch.Tensor: Cross-entropy loss, or (batch, n, vocab_size) logits when
            ``labels`` is None.
            list: Detached per-layer memories for the next segment. Returned as a second
            element whenever at least one layer produced a memory.
        """
        device = x.device
        sequence_length = x.shape[1]

        rel_pos = self.rel_pos(sequence_length, device=device)
        rel_pos_knn = self.rel_pos_knn(sequence_length, device=device)

        if xl_memories is None:
            xl_memories = (None,) * self.depth
        # next() rather than zip(): too few memories raise instead of silently skipping layers.
        xl_memories_iter = iter(xl_memories)

        x = self.embedding_matrix(x)
        new_xl_memories = []

        for ind, block in enumerate(self.layers):
            layer_rel_pos = rel_pos_knn if ind == self.knn_layer_index else rel_pos
            x, xl_mem = block(x, next(xl_memories_iter), layer_rel_pos)

            if xl_mem is not None:
                # Detach so the next segment's backward pass stops at this segment.
                new_xl_memories.append(xl_mem.detach())

        logits = self.to_logits(x)
        if labels is None:
            result = logits
        else:
            result = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels)

        if len(new_xl_memories) > 0:
            return result, new_xl_memories
        return result


In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset

SEGMENTS = 10
SEQUENCE_LENGTH = 512
# +1 so each chunk gives SEGMENTS * SEQUENCE_LENGTH inputs plus one extra token for the
# one-position label shift (seq = data[:, :-1], labels = data[:, 1:]).
CHUNK_SIZE = (SEGMENTS * SEQUENCE_LENGTH) + 1
BATCH_SIZE = 8

# Stream the training split and materialise its first 3500 examples.
dataset = load_dataset("ccdv/arxiv-summarization", split='train', streaming=True)
raw_dataset = list(dataset.take(3500))

# Keep only articles long enough to fill at least one chunk.
raw_articles = [x['article'] for x in raw_dataset]
raw_articles = [x for x in raw_articles if len(x) > CHUNK_SIZE]

# Byte-level tokenisation. The model is built with vocab_size=128, and UTF-8 encoding would
# emit byte values >= 128 for non-ASCII characters. Encode as ASCII, replacing non-ASCII
# characters with '?', so every token id is in [0, 128). np.frombuffer replaces the
# deprecated binary mode of np.fromstring.
converted = [
    np.frombuffer(doc.encode('ascii', errors='replace'), dtype=np.uint8)
    for doc in raw_articles
]

# Clip articles to CHUNK_SIZE
def clip_article(doc, chunk_size):
    """Trim `doc` so its length is a whole multiple of `chunk_size`.

    Trailing elements that do not fill a complete chunk are dropped; a document
    whose length is already an exact multiple is returned unchanged.

    Args:
        doc: sliceable sequence (here a 1-D uint8 array of byte tokens).
        chunk_size: length of one chunk; must be a positive integer.

    Returns:
        The first ``len(doc) - len(doc) % chunk_size`` elements of `doc`.
    """
    # Use an explicit end index. The previous `doc[:-remainder]` turned into
    # `doc[:0]` (an empty result) whenever the remainder was 0.
    usable_length = len(doc) - len(doc) % chunk_size
    return doc[:usable_length]

# Trim every document to a whole number of CHUNK_SIZE-byte chunks.
clipped = [clip_article(doc, CHUNK_SIZE) for doc in converted]

# How many CHUNK_SIZE rows to keep from the start of each document.
# Previously every document was truncated to the length of the shortest one. That
# was only needed to make `np.array(...)` non-ragged, and in practice it kept a
# single chunk per document (hence the 3401-row dataset). If any document clipped
# to length 0, it would have emptied every document. Keep 1 to reproduce that
# dataset, or set to None to use every chunk of every document.
MAX_CHUNKS_PER_DOC = 1

# Reshape each document into (n_chunks, CHUNK_SIZE) rows. The list may be ragged
# across documents; np.concatenate below does not need equal row counts.
chunked = [doc.reshape(-1, CHUNK_SIZE)[:MAX_CHUNKS_PER_DOC] for doc in clipped]

# Stack all rows into one (num_rows, CHUNK_SIZE) LongTensor of token ids.
processed_data = torch.tensor(np.concatenate(chunked), dtype=torch.long)
assert processed_data.ndim == 2 and processed_data.shape[1] == CHUNK_SIZE
assert processed_data.shape[0] > 0, "no CHUNK_SIZE-byte chunks were produced"

# Split rows 80/10/10 into train, validation and test loaders. Shuffling lets
# `next(iter(loader))` in the training cell draw a random batch each call.
eighty_split = int(processed_data.shape[0] * .8)
ninety_split = int(processed_data.shape[0] * .9)

train_loader = DataLoader(processed_data[:eighty_split], batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(processed_data[eighty_split:ninety_split], batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(processed_data[ninety_split:], batch_size=BATCH_SIZE, shuffle=True)

# Check the shape of processed_data
print("Processed Data Shape:", processed_data.shape)


Processed Data Shape: torch.Size([3401, 5121])


  converted = [np.fromstring(doc, dtype=np.uint8) for doc in raw_articles]


In [None]:
model = MemorizingTransformer(embedding_dimension = 128,
                              vocab_size = 128,
                              max_knn_memories = MAX_KNN_MEMORIES)
model.to(device)

# NOTE(review): this rebinds `optim`, shadowing the `import torch.optim as optim`
# alias from the first cell. The name is kept in case later cells refer to the
# optimizer by it; consider renaming to `optimizer`.
optim = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

TRAINING_STEPS = 200  # number of document batches to train on


def run_document(model, data, backward=False):
    """Run one batch of documents through `model`, SEGMENTS segments at a time.

    The KNN memory and the XL memories are reset first, so no state carries over
    from the previous batch. XL memories are then threaded from each segment into
    the next.

    Args:
        model: the MemorizingTransformer. Assumes it has XL-memory layers, so its
            forward returns a ``(loss, xl_memories)`` pair rather than a bare loss.
        data: LongTensor of shape (batch, CHUNK_SIZE), already on ``device``.
        backward: if True, backpropagate each segment's loss scaled by
            1/SEGMENTS. Gradients then accumulate over the whole document so the
            caller can take a single optimizer step.

    Returns:
        float: the mean loss over the SEGMENTS segments.
    """
    model.knn.clear()
    xl_memories = None
    seq, labels = data[:, :-1], data[:, 1:]   # next-token targets
    mean_loss = 0.
    for seq_segment, labels_segment in zip(seq.chunk(SEGMENTS, dim = -1),
                                           labels.chunk(SEGMENTS, dim = -1)):
        loss, xl_memories = model(
            seq_segment,
            labels = labels_segment,
            xl_memories = xl_memories
        )
        mean_loss += loss.item() / SEGMENTS
        if backward:
            (loss / SEGMENTS).backward()
    return mean_loss


for i in tqdm.tqdm(range(TRAINING_STEPS), mininterval = 10., desc = 'training'):

    model.train()
    # A fresh iterator over the shuffled loader yields a random batch each step.
    data = next(iter(train_loader)).to(device=device)

    t0 = time.time()
    print("Begin document")

    # SEGMENTS forward/backward passes over a batch of BATCH_SIZE documents.
    train_loss = run_document(model, data, backward = True)

    print(f'training loss: {train_loss}')
    t1 = time.time()
    print("End document, total time:", t1 - t0)
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_CLIP_NORM)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        # Evaluate on a validation batch moved to the model's device. The previous
        # code loaded `valid_data` but then scored the training batch `data`.
        valid_data = next(iter(val_loader)).to(device=device)
        with torch.no_grad():
            valid_loss = run_document(model, valid_data)

        print(f'valid loss: {valid_loss}')


training:   0%|          | 0/200 [00:00<?, ?it/s]

Begin document
training loss: 5.039541959762574
End document, total time: 8.374260902404785


training:   0%|          | 1/200 [00:15<50:12, 15.14s/it]

valid loss: 4.46309003829956
Begin document
training loss: 4.492447185516358
End document, total time: 7.744511842727661
Begin document


training:   2%|▏         | 3/200 [00:30<31:27,  9.58s/it]

training loss: 4.098450136184693
End document, total time: 7.306441783905029
Begin document
training loss: 3.8796200037002566
End document, total time: 7.242147207260132
Begin document


training:   2%|▎         | 5/200 [00:45<27:28,  8.46s/it]

training loss: 3.7435310840606686
End document, total time: 7.328748464584351
Begin document
training loss: 3.616596484184265
End document, total time: 7.472764492034912
Begin document


training:   4%|▎         | 7/200 [01:00<25:56,  8.06s/it]

training loss: 3.5251991748809814
End document, total time: 7.360278844833374
Begin document
training loss: 3.4646272182464597
End document, total time: 7.3023681640625
Begin document


training:   4%|▍         | 9/200 [01:15<24:51,  7.81s/it]

training loss: 3.438588643074036
End document, total time: 7.276176452636719
Begin document
training loss: 3.4132254362106322
End document, total time: 7.055574893951416
Begin document


training:   6%|▌         | 11/200 [01:29<24:08,  7.66s/it]

training loss: 3.372542500495911
End document, total time: 7.504054546356201
Begin document
training loss: 3.334512138366699
End document, total time: 7.23045539855957
Begin document


training:   6%|▋         | 13/200 [01:44<23:27,  7.52s/it]

training loss: 3.29343364238739
End document, total time: 7.05449366569519
Begin document
training loss: 3.2777815341949466
End document, total time: 7.261979818344116
Begin document


training:   8%|▊         | 15/200 [01:59<23:13,  7.53s/it]

training loss: 3.2410950183868406
End document, total time: 7.6329450607299805
Begin document
training loss: 3.2314491987228395
End document, total time: 7.266247034072876
Begin document


training:   8%|▊         | 17/200 [02:14<22:49,  7.48s/it]

training loss: 3.1802069187164306
End document, total time: 7.261354446411133
Begin document
training loss: 3.1663676261901856
End document, total time: 7.186507701873779
Begin document


training:  10%|▉         | 19/200 [02:29<22:39,  7.51s/it]

training loss: 3.1007745027542115
End document, total time: 7.765432596206665
Begin document
training loss: 3.057208847999573
End document, total time: 7.328123569488525
Begin document


training:  10%|█         | 21/200 [02:44<22:16,  7.47s/it]

training loss: 3.0353387594223022
End document, total time: 7.185502767562866
Begin document
training loss: 3.004630327224731
End document, total time: 7.254883289337158
Begin document


training:  12%|█▏        | 23/200 [02:58<21:55,  7.43s/it]

training loss: 2.996585941314697
End document, total time: 7.224204778671265
Begin document
training loss: 3.0295111656188958
End document, total time: 7.321697950363159
Begin document


training:  12%|█▎        | 25/200 [03:13<21:34,  7.40s/it]

training loss: 2.937446904182434
End document, total time: 7.1251914501190186
Begin document
training loss: 2.922499585151672
End document, total time: 7.234288930892944
Begin document


training:  14%|█▎        | 27/200 [03:28<21:17,  7.39s/it]

training loss: 2.856919074058533
End document, total time: 7.2671058177948
Begin document
training loss: 2.8808781623840334
End document, total time: 7.482014179229736
Begin document


training:  14%|█▍        | 29/200 [03:42<21:01,  7.38s/it]

training loss: 2.8619038343429564
End document, total time: 7.038935661315918
Begin document
training loss: 2.8057425498962405
End document, total time: 7.161719083786011
Begin document


training:  16%|█▌        | 31/200 [03:57<20:44,  7.36s/it]

training loss: 2.811570572853088
End document, total time: 7.277841329574585
Begin document
training loss: 2.8176154375076297
End document, total time: 7.4859619140625
Begin document


training:  16%|█▋        | 33/200 [04:12<20:28,  7.36s/it]

training loss: 2.809566187858581
End document, total time: 6.982219934463501
Begin document
training loss: 2.8008985757827753
End document, total time: 6.822697877883911
Begin document


training:  18%|█▊        | 35/200 [04:26<19:57,  7.26s/it]

training loss: 2.7528455257415767
End document, total time: 7.0360987186431885
Begin document
training loss: 2.746201181411744
End document, total time: 6.935306549072266
Begin document


training:  18%|█▊        | 37/200 [04:40<19:36,  7.22s/it]

training loss: 2.752862143516541
End document, total time: 7.083675861358643
Begin document
training loss: 2.7754136562347416
End document, total time: 6.928543567657471
Begin document


training:  20%|█▉        | 39/200 [04:54<19:12,  7.16s/it]

training loss: 2.7197938442230223
End document, total time: 6.909160614013672
Begin document
training loss: 2.7179848670959474
End document, total time: 6.81702446937561
Begin document


training:  20%|██        | 41/200 [05:08<18:59,  7.17s/it]

training loss: 2.6873924970626826
End document, total time: 7.340386867523193
Begin document
training loss: 2.6663897991180416
End document, total time: 6.960350751876831
Begin document


training:  22%|██▏       | 43/200 [05:22<18:37,  7.12s/it]

training loss: 2.657948446273804
End document, total time: 6.856670379638672
Begin document
training loss: 2.6397240877151487
End document, total time: 7.130556583404541
Begin document


training:  22%|██▎       | 45/200 [05:37<18:39,  7.22s/it]

training loss: 2.671142244338989
End document, total time: 7.57828164100647
Begin document
training loss: 2.6713225126266478
End document, total time: 7.301489591598511
Begin document


training:  24%|██▎       | 47/200 [05:52<18:30,  7.26s/it]

training loss: 2.6617756366729735
End document, total time: 7.182066440582275
Begin document
training loss: 2.6341784715652463
End document, total time: 7.255206108093262
Begin document


training:  24%|██▍       | 49/200 [06:07<18:20,  7.29s/it]

training loss: 2.6546007156372067
End document, total time: 7.244778394699097
Begin document
training loss: 2.6411700248718266
End document, total time: 7.437213897705078
Begin document


training:  26%|██▌       | 51/200 [06:21<18:09,  7.31s/it]

training loss: 2.6175357818603517
End document, total time: 7.06542444229126
Begin document
training loss: 2.6057873249053958
End document, total time: 7.180307865142822
Begin document


training:  26%|██▋       | 53/200 [06:36<17:55,  7.32s/it]

training loss: 2.609294557571411
End document, total time: 7.271184921264648
Begin document
training loss: 2.5596313238143917
End document, total time: 7.516435623168945
Begin document


training:  28%|██▊       | 55/200 [06:51<17:48,  7.37s/it]

training loss: 2.615690732002258
End document, total time: 7.233374118804932
Begin document
training loss: 2.6076946258544917
End document, total time: 7.2973339557647705
Begin document


training:  28%|██▊       | 57/200 [07:06<17:34,  7.37s/it]

training loss: 2.5570107936859134
End document, total time: 7.268221616744995
Begin document
training loss: 2.5825605630874633
End document, total time: 7.8143088817596436
Begin document


training:  30%|██▉       | 59/200 [07:21<17:32,  7.46s/it]

training loss: 2.5541709184646604
End document, total time: 7.330108165740967
Begin document
training loss: 2.6037556886672975
End document, total time: 7.293806314468384
Begin document


training:  30%|███       | 61/200 [07:36<17:16,  7.46s/it]

training loss: 2.5487401485443115
End document, total time: 7.378065824508667
Begin document
training loss: 2.5489181756973265
End document, total time: 7.425987720489502
Begin document


training:  32%|███▏      | 63/200 [07:51<17:04,  7.48s/it]

training loss: 2.557217407226563
End document, total time: 7.42017388343811
Begin document
training loss: 2.556982421875
End document, total time: 7.328909158706665
Begin document


training:  32%|███▎      | 65/200 [08:06<16:48,  7.47s/it]

training loss: 2.5416264533996578
End document, total time: 7.332236051559448
Begin document
training loss: 2.540660238265991
End document, total time: 7.227803707122803
Begin document


training:  34%|███▎      | 67/200 [08:21<16:34,  7.48s/it]

training loss: 2.5474711894989013
End document, total time: 7.559648752212524
Begin document
training loss: 2.528946352005005
End document, total time: 7.324854135513306
Begin document


training:  34%|███▍      | 69/200 [08:36<16:17,  7.46s/it]

training loss: 2.523319864273071
End document, total time: 7.30050802230835
Begin document
training loss: 2.525755095481872
End document, total time: 7.1444621086120605
Begin document


training:  36%|███▌      | 71/200 [08:51<16:04,  7.47s/it]

training loss: 2.5115158081054685
End document, total time: 7.658086538314819
Begin document
training loss: 2.522643899917602
End document, total time: 7.3765246868133545
Begin document


training:  36%|███▋      | 73/200 [09:06<15:48,  7.47s/it]

training loss: 2.4932637453079223
End document, total time: 7.326087474822998
Begin document
training loss: 2.4944551706314084
End document, total time: 7.326231002807617
Begin document


training:  38%|███▊      | 75/200 [09:21<15:39,  7.51s/it]

training loss: 2.461370325088501
End document, total time: 7.698078155517578
Begin document
training loss: 2.497510313987732
End document, total time: 7.14426064491272
Begin document


training:  38%|███▊      | 77/200 [09:36<15:17,  7.46s/it]

training loss: 2.5114260196685794
End document, total time: 7.283469915390015
Begin document
training loss: 2.4832695007324217
End document, total time: 7.253715753555298
Begin document


training:  40%|███▉      | 79/200 [09:50<14:56,  7.41s/it]

training loss: 2.479548573493957
End document, total time: 7.111997127532959
Begin document
training loss: 2.4798885107040407
End document, total time: 7.357863903045654
Begin document


training:  40%|████      | 81/200 [10:05<14:40,  7.40s/it]

training loss: 2.5012339115142823
End document, total time: 7.181897163391113
Begin document
training loss: 2.521123480796814
End document, total time: 7.233805894851685
Begin document


training:  42%|████▏     | 83/200 [10:20<14:21,  7.36s/it]

training loss: 2.5215580701827998
End document, total time: 7.115772724151611
Begin document
training loss: 2.4796427726745605
End document, total time: 7.589867115020752
Begin document


training:  42%|████▎     | 85/200 [10:35<14:09,  7.39s/it]

training loss: 2.4966239213943484
End document, total time: 7.100458383560181
Begin document
training loss: 2.4449331283569333
End document, total time: 7.224069595336914
Begin document


training:  44%|████▎     | 87/200 [10:49<13:53,  7.38s/it]

training loss: 2.4806287765502932
End document, total time: 7.273103713989258
Begin document
training loss: 2.4743970394134522
End document, total time: 7.631104946136475
Begin document


training:  44%|████▍     | 89/200 [11:04<13:41,  7.40s/it]

training loss: 2.4577834367752076
End document, total time: 7.062979221343994
Begin document
training loss: 2.527753615379333
End document, total time: 7.20804762840271
Begin document


training:  46%|████▌     | 91/200 [11:19<13:23,  7.38s/it]

training loss: 2.494409966468811
End document, total time: 7.207025766372681
Begin document
training loss: 2.4606940031051634
End document, total time: 7.560499906539917
Begin document


training:  46%|████▋     | 93/200 [11:34<13:10,  7.38s/it]

training loss: 2.4756375312805177
End document, total time: 7.029901742935181
Begin document
training loss: 2.441540575027466
End document, total time: 6.931883335113525
Begin document


training:  48%|████▊     | 95/200 [11:48<12:45,  7.29s/it]

training loss: 2.4428407669067385
End document, total time: 6.988785028457642
Begin document
training loss: 2.542115473747253
End document, total time: 6.953532695770264
Begin document


training:  48%|████▊     | 97/200 [12:02<12:29,  7.28s/it]

training loss: 2.442001295089722
End document, total time: 7.362692356109619
Begin document
training loss: 2.4514965534210207
End document, total time: 6.967535018920898
Begin document


training:  50%|████▉     | 99/200 [12:16<12:09,  7.23s/it]

training loss: 2.4846585750579835
End document, total time: 7.029708385467529
Begin document
training loss: 2.494274663925171
End document, total time: 7.038853406906128
Begin document
training loss: 2.468542075157165
End document, total time: 7.2282140254974365


training:  50%|█████     | 101/200 [12:37<13:25,  8.13s/it]

valid loss: 2.4635545253753666
Begin document
training loss: 2.5085588693618774
End document, total time: 6.922603130340576
Begin document


training:  52%|█████▏    | 103/200 [12:51<12:39,  7.83s/it]

training loss: 2.44498507976532
End document, total time: 7.098491191864014
Begin document
training loss: 2.4806080579757688
End document, total time: 6.957918643951416
Begin document


training:  52%|█████▎    | 105/200 [13:06<12:08,  7.67s/it]

training loss: 2.4577840089797975
End document, total time: 7.422662973403931
Begin document
training loss: 2.475238013267517
End document, total time: 7.215893507003784
Begin document


training:  54%|█████▎    | 107/200 [13:20<11:43,  7.57s/it]

training loss: 2.4810818910598753
End document, total time: 7.245092391967773
Begin document
training loss: 2.4417475938797
End document, total time: 7.263494491577148
Begin document


training:  55%|█████▍    | 109/200 [13:36<11:29,  7.58s/it]

training loss: 2.4658073663711546
End document, total time: 7.739418029785156
Begin document
training loss: 2.4672456741333013
End document, total time: 7.017959117889404
Begin document


training:  56%|█████▌    | 111/200 [13:50<11:06,  7.48s/it]

training loss: 2.452367854118347
End document, total time: 7.279255628585815
Begin document
training loss: 2.5203794717788695
End document, total time: 7.164864540100098
Begin document


training:  56%|█████▋    | 113/200 [14:05<10:51,  7.49s/it]

training loss: 2.438669443130493
End document, total time: 7.605443239212036
Begin document
training loss: 2.4727033138275147
End document, total time: 7.244235992431641
Begin document


training:  57%|█████▊    | 115/200 [14:20<10:31,  7.43s/it]

training loss: 2.4092563390731807
End document, total time: 7.167381048202515
Begin document
training loss: 2.4299096822738653
End document, total time: 7.200842380523682
Begin document


training:  58%|█████▊    | 117/200 [14:34<10:13,  7.39s/it]

training loss: 2.4222781658172607
End document, total time: 7.185885906219482
Begin document
training loss: 2.43513011932373
End document, total time: 7.556690454483032
Begin document


training:  60%|█████▉    | 119/200 [14:50<10:03,  7.45s/it]

training loss: 2.4420917749404905
End document, total time: 7.366397142410278
Begin document
training loss: 2.435632276535034
End document, total time: 7.236774921417236
Begin document


training:  60%|██████    | 121/200 [15:04<09:46,  7.42s/it]

training loss: 2.4051583528518674
End document, total time: 7.2835612297058105
Begin document
training loss: 2.4559593915939333
End document, total time: 7.774540185928345
Begin document


training:  62%|██████▏   | 123/200 [15:20<09:37,  7.50s/it]

training loss: 2.446843147277832
End document, total time: 7.357535123825073
Begin document
training loss: 2.442533397674561
End document, total time: 7.3488781452178955
Begin document


training:  62%|██████▎   | 125/200 [15:34<09:20,  7.47s/it]

training loss: 2.418792414665222
End document, total time: 7.249176740646362
Begin document
training loss: 2.4549681901931764
End document, total time: 7.819297790527344
Begin document


training:  64%|██████▎   | 127/200 [15:50<09:09,  7.53s/it]

training loss: 2.4535364389419554
End document, total time: 7.283367395401001
Begin document
training loss: 2.4181143045425415
End document, total time: 7.31215763092041
Begin document


training:  64%|██████▍   | 129/200 [16:05<08:51,  7.48s/it]

training loss: 2.4734151363372803
End document, total time: 7.241044998168945
Begin document
training loss: 2.423294734954834
End document, total time: 7.424448490142822
Begin document


training:  66%|██████▌   | 131/200 [16:19<08:34,  7.46s/it]

training loss: 2.4382560253143315
End document, total time: 7.153767347335815
Begin document
training loss: 2.437877726554871
End document, total time: 7.1403725147247314
Begin document


training:  66%|██████▋   | 133/200 [16:34<08:16,  7.41s/it]

training loss: 2.4689619064331056
End document, total time: 7.217736005783081
Begin document
training loss: 2.390569567680359
End document, total time: 7.051027059555054
Begin document


training:  68%|██████▊   | 135/200 [16:49<08:01,  7.41s/it]

training loss: 2.369512414932251
End document, total time: 7.586649656295776
Begin document
training loss: 2.4353806972503667
End document, total time: 7.249579429626465
Begin document


training:  68%|██████▊   | 137/200 [17:03<07:46,  7.40s/it]

training loss: 2.4548264503479005
End document, total time: 7.252155065536499
Begin document
training loss: 2.4464282751083375
End document, total time: 7.142266035079956
Begin document


training:  70%|██████▉   | 139/200 [17:19<07:35,  7.47s/it]

training loss: 2.4422513961791994
End document, total time: 7.8954386711120605
Begin document
training loss: 2.500449371337891
End document, total time: 7.143323659896851
Begin document


training:  70%|███████   | 141/200 [17:33<07:17,  7.42s/it]

training loss: 2.4341185569763186
End document, total time: 7.2466278076171875
Begin document
training loss: 2.4435135602951052
End document, total time: 7.287298917770386
Begin document


training:  72%|███████▏  | 143/200 [17:49<07:06,  7.48s/it]

training loss: 2.4254390239715575
End document, total time: 7.759059190750122
Begin document
training loss: 2.4071861743927006
End document, total time: 7.2668962478637695
Begin document


training:  72%|███████▎  | 145/200 [18:03<06:49,  7.45s/it]

training loss: 2.4406036138534546
End document, total time: 7.2666168212890625
Begin document
training loss: 2.406069707870483
End document, total time: 7.278563737869263
Begin document


training:  74%|███████▎  | 147/200 [18:18<06:33,  7.43s/it]

training loss: 2.437677597999573
End document, total time: 7.282994747161865
Begin document
training loss: 2.4406843423843383
End document, total time: 7.507497310638428
Begin document


training:  74%|███████▍  | 149/200 [18:33<06:18,  7.43s/it]

training loss: 2.414120817184448
End document, total time: 7.144364595413208
Begin document
training loss: 2.4483897447586056
End document, total time: 7.252863883972168
Begin document


training:  76%|███████▌  | 151/200 [18:48<06:03,  7.41s/it]

training loss: 2.4357647657394406
End document, total time: 7.266793966293335
Begin document
training loss: 2.4323696851730343
End document, total time: 7.618861675262451
Begin document


training:  76%|███████▋  | 153/200 [19:03<05:48,  7.42s/it]

training loss: 2.434057831764221
End document, total time: 7.063844203948975
Begin document
training loss: 2.400449752807617
End document, total time: 7.07490086555481
Begin document


training:  78%|███████▊  | 155/200 [19:17<05:30,  7.35s/it]

training loss: 2.482611656188965
End document, total time: 7.075734376907349
Begin document
training loss: 2.4189081430435184
End document, total time: 7.592996597290039
Begin document


training:  78%|███████▊  | 157/200 [19:32<05:15,  7.34s/it]

training loss: 2.433962678909302
End document, total time: 6.840792655944824
Begin document
training loss: 2.405925035476684
End document, total time: 7.0137104988098145
Begin document


training:  80%|███████▉  | 159/200 [19:46<04:57,  7.25s/it]

training loss: 2.4208086252212526
End document, total time: 6.870612382888794
Begin document
training loss: 2.4255853891372676
End document, total time: 6.945481061935425
Begin document


training:  80%|████████  | 161/200 [20:00<04:41,  7.23s/it]

training loss: 2.397289824485779
End document, total time: 7.173508405685425
Begin document


training:  80%|████████  | 161/200 [20:04<04:51,  7.48s/it]


KeyboardInterrupt: 