# Attention Tests
---

## Load Packages

In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join("../../../"))
print(f"module_path: {module_path}")
if module_path not in sys.path:
    print(f"Adding {module_path} to sys.path")
    sys.path.append(module_path)

module_path: /home/beegass/Documents/Coding/HiPPO-Jax


In [2]:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
os.environ["TF_FORCE_UNIFIED_MEMORY"] = "1"

In [3]:
## import packages
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import torch.nn as tnn
import torch.nn.functional as F
from src.models.hippo.hippo import HiPPOLSI, HiPPOLTI
from src.data.process import whitesignal
import einops
from jaxtyping import Array, Float
from typing import Optional, Tuple
import numpy as np
import torch
import time

  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


In [4]:
print(jax.devices())
print(f"The Device: {jax.lib.xla_bridge.get_backend().platform}")

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
The Device: gpu


In [5]:
print(f"MPS enabled: {torch.backends.mps.is_available()}")

MPS enabled: False


In [6]:
torch.set_printoptions(linewidth=150)
np.set_printoptions(linewidth=150)
jnp.set_printoptions(linewidth=150)

Jax Seeding

In [7]:
seed = 1701
key = jax.random.PRNGKey(seed)

In [8]:
num_copies = 10
subkeys = jax.random.split(key, num=num_copies)
key = subkeys[0]

Torch Seeding

In [9]:
torch.manual_seed(seed)

<torch._C.Generator at 0x7f28060877f0>

Numpy Seeding

In [10]:
np.random.seed(seed)

## Torch Attention Block Unification
---

### Quick note on Transformer Block Unification

Fun aside is that the MLP blocks of a Transformer are actually identical to the Self-Attention blocks of the transformer, with a few changes:

- the query projection is missing, the data itself is the query
- the key, value are data-independent parameters (i.e. the MLP block is really just a cross-attention "soft lookup" into a fixed {key:value} table
- the Softmax (map/reduce non-linearity) is replaced with GeLU (map-only non-linearity)
- the final Linear projection back to the residual pathway is missing

This immediately suggests a unification of these blocks into a more general Transformer "superblock", that is simply wired up to the residual pathway either in parallel (e.g. as in all the heads of a multi-headed self-attention), or in series (as usually done from block to block otherwise). It also suggests in-between generalizations, e.g. multi-headed attention suggests the equivalent use of "groups" in Linear (or Conv) layers. Alternatively, attention could be done over two pools of nodes simultaneously: those where key,value are data-dependent and those that aren't, dispensing with the need for a distinction.

**TLDR**: A much simpler Transformer with a single type of block wired up to a residual pathway in both parallel and in series is possible but to my knowledge has not yet been convincingly achieved.

In [11]:
B, T, C = 8, 512, 128  # batch, sequence length, channels
n_head = 4
n_embd = C

In [12]:
x = torch.randn(B, T, C)
x.shape

torch.Size([8, 512, 128])

In [13]:
# self-attention block of a single Head
d_head = C // n_head  # head size, 128/4 = 32
key = tnn.Linear(C, d_head, bias=False)
query = tnn.Linear(C, d_head, bias=False)
value = tnn.Linear(C, d_head, bias=False)
proj = tnn.Linear(d_head, C, bias=False)
k = key(x)
q = query(x)
v = value(x)
print(f"key shape: {k.shape}")
print(f"query shape: {q.shape}")
print(f"value shape: {v.shape}")
att = torch.softmax(q @ k.transpose(-2, -1), dim=-1)
y = att @ v
print(f"context shape: {y.shape}")
r = proj(
    y
)  # standard self-attention blocks have one more Linear when back to residual pathway
print(f"projected context shape: {r.shape}")

key shape: torch.Size([8, 512, 32])
query shape: torch.Size([8, 512, 32])
value shape: torch.Size([8, 512, 32])
context shape: torch.Size([8, 512, 32])
projected context shape: torch.Size([8, 512, 128])


In [14]:
# typical linear block on a Transformer
layer1 = tnn.Linear(C, C * 4, bias=False)
layer2 = tnn.Linear(C * 4, C, bias=False)
l1 = F.gelu(layer1(x))
l2 = layer2(l1)  # projects back to residual pathway
l2.shape

torch.Size([8, 512, 128])

In [15]:
# linear block is actually attention over a fixed (not data-dependent) {k:v} dict
q = x  # change 1: query is simply the input
k = layer1.weight  # key and value are data-independent learnable parameters
v = layer2.weight.T
att = F.gelu(q @ k.transpose(-2, -1))  # change 2: using gelu instead of softmax
y = att @ v
y.shape

torch.Size([8, 512, 128])

In [16]:
(l2 == y).all()  # cool

tensor(True)

## Flax Attention Block Unification
---

In [17]:
class SelfAttention(nn.Module):
    """
    Dot-Product Attention: Compute the dot products of the query with all keys,
    and apply a softmax function to obtain the weights on the values

    Attributes:
        n_head: The number of attention heads.
        d_model: The dimension of the input.
        dtype: The data type of the computation. Default is jnp.float32.
    """

    n_head: int  # number of heads the attention is split into
    d_model: int  # dimension of the input, aka n_embd or C which is the size of the embedding.
    dtype: jnp.dtype = jnp.float32  # data type of the computation (default: float32)

    def setup(self) -> None:
        # Check if d_model is divisible by n_head to ensure the input can be evenly distributed among all heads
        assert self.d_model % self.n_head == 0

        # Compute the size of each head by dividing the input dimension by the number of heads. head size, e.g. 128/4 = 32
        self.d_head = self.d_model // self.n_head

        # Create dense layers for key, query, and value with the dimension size of each head.
        self.key = nn.Dense(
            self.d_head,
            kernel_init=nn.initializers.xavier_uniform(),  # Weights with Xavier uniform init
            bias_init=nn.initializers.zeros,  # Bias init with zeros
            name=f"key_layer",
        )
        self.query = nn.Dense(
            self.d_head,
            kernel_init=nn.initializers.xavier_uniform(),  # Weights with Xavier uniform init
            bias_init=nn.initializers.zeros,  # Bias init with zeros
            name=f"query_layer",
        )
        self.value = nn.Dense(
            self.d_head,
            kernel_init=nn.initializers.xavier_uniform(),  # Weights with Xavier uniform init
            bias_init=nn.initializers.zeros,  # Bias init with zeros
            name=f"value_layer",
        )

        # Create a dense layer for projecting the output back to the original dimension size.
        self.proj = nn.Dense(
            self.d_model,
            kernel_init=nn.initializers.xavier_uniform(),  # Weights with Xavier uniform init
            bias_init=nn.initializers.zeros,  # Bias init with zeros
            name=f"proj_layer",
        )

    def __call__(
        self,
        query: Float[Array, "*batch d_model"],
        key: Float[Array, "*batch d_model"],
        value: Float[Array, "*batch d_model"],
        mask: Optional[Float[Array, "*batch d_model"]] = None,
    ) -> Tuple[Float[Array, "*batch d_model"], Float[Array, "*batch d_model"]]:
        """
        Call method, used for calculating the forward pass for the Self Attention Module.

        Args:
            query (jnp.ndarray):
                Shape: (batch d_model)
                The query tensor.

            key (jnp.ndarray):
                Shape: (batch d_model)
                The key tensor.

            value (jnp.ndarray):
                Shape: (batch d_model)
                The value tensor.

        Returns:
            r (jnp.ndarray):
                Shape: (batch d_model)
                The projected context tensor back to the original dimension size.
        """
        _k = self.key(key)
        q = self.query(query)
        v = self.value(value)
        k = einops.rearrange(_k, "... i j -> ... j i")

        # Calculate the attention scores by taking the dot product of query and key
        score = (q @ k) / jnp.sqrt(self.d_head)

        # Masking to avoid performing attention on padding token indices.
        if mask is not None:
            # Set the score for all padding token indices to a large negative value
            score = jnp.where(
                mask == 0, -9e15, score
            )  # -9e15 is a very large negative number

        # then apply softmax to get probabilities.
        attn = nn.softmax(score, axis=-1)

        # Multiply the attention scores with the value to get the context
        context = attn @ v

        # Project the context back to the original dimension size using the projection layer
        out = self.proj(context)

        return out

In [18]:
class ScaledDotProductAttention(nn.Module):
    """

    Attributes:
        n_head: The number of attention heads.
        d_model: The dimension of the input.
        dtype: The data type of the computation. Default is jnp.float32.
    """

    n_head: int  # number of heads the attention is split into
    d_model: int  # dimension of the input, aka n_embd or C which is the size of the embedding.
    dtype: jnp.dtype = jnp.float32  # data type of the computation (default: float32)

    def setup(self) -> None:
        # Check if d_model is divisible by n_head to ensure the input can be evenly distributed among all heads
        assert self.d_model % self.n_head == 0

        # Compute the size of each head by dividing the input dimension by the number of heads. head size, e.g. 128/4 = 32
        self.d_head = self.d_model // self.n_head

    def __call__(
        self,
        query: Float[Array, "*batch seq_len d_model"],
        key: Float[Array, "*batch seq_len d_model"],
        value: Float[Array, "*batch seq_len d_model"],
        mask: Optional[Float[Array, "*batch seq_len d_model"]] = None,
    ) -> Tuple[
        Float[Array, "*batch seq_len d_model"], Float[Array, "*batch seq_len seq_len"]
    ]:
        """
        Call method, used for calculating the forward pass for the Self Attention Module.

        Args:
            query (jnp.ndarray):
                Shape: (batch d_model)
                The query tensor.

            key (jnp.ndarray):
                Shape: (batch d_model)
                The key tensor.

            value (jnp.ndarray):
                Shape: (batch d_model)
                The value tensor.

        Returns:
            r (jnp.ndarray):
                Shape: (batch d_model)
                The projected context tensor back to the original dimension size.
        """

        key = einops.rearrange(key, "... i j -> ... j i")

        # Calculate the attention scores by taking the dot product of query and key
        score = (query @ key) / jnp.sqrt(self.d_head)

        # Masking to avoid performing attention on padding token indices.
        if mask is not None:
            assert (
                mask.shape == score.shape
            ), f"Mask shape {mask.shape} must match score shape {score.shape}"

            # Set the score for all padding token indices to a large negative value
            score = jnp.where(
                mask == 0, -9e15, score
            )  # -9e15 is a very large negative number

        # then apply softmax to get probabilities.
        attn = nn.softmax(score, axis=-1)

        # Multiply the attention scores with the value to get the context
        context = attn @ value

        return context, attn

In [19]:
class MultiHeadAttention(nn.Module):
    n_head: int  # number of heads the attention is split into
    d_model: int  # dimension of the input, aka n_embd or C which is the size of the embedding.
    # step: int  # step size for the GBT
    # lambda_n: float = 1.0  # lambda_n for the LegT
    # alpha: float = 2.0  # alpha for the GBT,
    # measure: str = "legs"  # measure for type of the polynomial,
    # basis: float = 1.0  # basis for the polynomial
    # unroll: bool = False  # unroll the loop for the output
    dtype: jnp.dtype = jnp.float32  # data type of the computation (default: float32)

    def setup(self) -> None:
        # Check if d_model is divisible by n_head to ensure the input can be evenly distributed among all heads
        assert self.d_model % self.n_head == 0

        # Compute the size of each head by dividing the input dimension by the number of heads. head size, e.g. 128/4 = 32
        self.d_head = self.d_model // self.n_head

        # Create dense layers for key, query, and value with the dimension size of each head.

        self.key = nn.Dense(
            self.d_model,
            kernel_init=nn.initializers.xavier_uniform(),  # Weights with Xavier uniform init
            bias_init=nn.initializers.zeros,  # Bias init with zeros
            name=f"key_layer",
        )
        self.query = nn.Dense(
            self.d_model,
            kernel_init=nn.initializers.xavier_uniform(),  # Weights with Xavier uniform init
            bias_init=nn.initializers.zeros,  # Bias init with zeros
            name=f"query_layer",
        )
        self.value = nn.Dense(
            self.d_model,
            kernel_init=nn.initializers.xavier_uniform(),  # Weights with Xavier uniform init
            bias_init=nn.initializers.zeros,  # Bias init with zeros
            name=f"value_layer",
        )

        # Create a dense layer for projecting the output back to the original dimension size.
        self.proj = nn.Dense(
            self.d_model,
            kernel_init=nn.initializers.xavier_uniform(),  # Weights with Xavier uniform init
            bias_init=nn.initializers.zeros,  # Bias init with zeros
            name=f"proj_layer",
        )

        self.attention = ScaledDotProductAttention(
            n_head=self.n_head, d_model=self.d_model
        )

    def __call__(
        self,
        query: Float[Array, "*batch seq_len d_model"],
        key: Float[Array, "*batch seq_len d_model"],
        value: Float[Array, "*batch seq_len d_model"],
        mask: Optional[Float[Array, "*batch seq_len d_model"]] = None,
    ) -> Tuple[
        Float[Array, "*batch seq_len d_model"], Float[Array, "*batch seq_len seq_len"]
    ]:
        jax.debug.print("query shape: {x}", x=query.shape)
        jax.debug.print("key shape: {x}", x=key.shape)
        jax.debug.print("value shape: {x}", x=value.shape)

        q = self.query(query)
        k = self.key(key)
        v = self.value(value)

        q = einops.rearrange(
            q,
            "... seq_len (n_head d_head) -> ... seq_len n_head d_head",
            n_head=self.n_head,
            d_head=self.d_head,
        )
        k = einops.rearrange(
            k,
            "... seq_len (n_head d_head) -> ... seq_len n_head d_head",
            n_head=self.n_head,
            d_head=self.d_head,
        )
        v = einops.rearrange(
            v,
            "... seq_len (n_head d_head) -> ... seq_len n_head d_head",
            n_head=self.n_head,
            d_head=self.d_head,
        )

        jax.debug.print("q shape: {x}", x=q.shape)
        jax.debug.print("k shape: {x}", x=k.shape)
        jax.debug.print("v shape: {x}\n", x=v.shape)

        context, attn = jax.vmap(self.attention, in_axes=(1, 1, 1, None))(q, k, v, mask)

        context = einops.rearrange(
            context,
            "... n_head seq_len d_head -> ... seq_len (n_head d_head)",
            n_head=self.n_head,
            d_head=self.d_head,
        )

        # jax.debug.print("context: {x}\n", x=context)
        jax.debug.print("context shape: {x}", x=context.shape)
        # jax.debug.print("attn: {x}\n", x=attn)
        jax.debug.print("attn shape: {x}", x=attn.shape)

        # Project the context back to the original dimension size using the projection layer
        out = self.proj(context)

        # jax.debug.print("out: {x}\n", x=out)
        jax.debug.print("out shape: {x}", x=out.shape)

        return out, attn

In [20]:
class TransformerBlock(nn.Module):
    d_model: int  # Input dimension is needed here since it is equal to the output dimension (residual connection)
    n_head: int  # number of heads the attention is split into
    ffn_expan: int  # expansion factor for the feedforward layer
    _dropout: float  # dropout probability
    # step: int  # step size for the GBT
    # lambda_n: float = 1.0  # lambda_n for the LegT
    # alpha: float = 2.0  # alpha for the GBT,
    # measure: str = "legs"  # measure for type of the polynomial,
    # basis: float = 1.0  # basis for the polynomial
    # unroll: bool = False  # unroll the loop for the output

    def setup(self):
        # Attention layer
        self.attention = MultiHeadAttention(n_head=self.n_head, d_model=self.d_model)

        # Two-layer MLP
        self.ffn = [
            nn.Dense(self.ffn_expan * self.d_model),
            nn.Dropout(rate=self._dropout),
            nn.relu,
            nn.Dense(self.d_model),
        ]
        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm()
        self.norm2 = nn.LayerNorm()
        self.dropout = nn.Dropout(rate=self._dropout)

    def __call__(
        self,
        query: Float[Array, "*batch seq_len d_model"],
        key: Float[Array, "*batch seq_len d_model"],
        value: Float[Array, "*batch seq_len d_model"],
        mask: Optional[Float[Array, "*batch seq_len d_model"]] = None,
        train: bool = True,
    ) -> Float[Array, "*batch seq_len d_model"]:

        # Attention part
        proj_context, attn = jax.vmap(self.attention, in_axes=(0, 0, 0, 0))(
            self.norm1(query), self.norm1(key), self.norm1(value), mask
        )
        x = query + self.dropout(proj_context, deterministic=not train)

        # MLP part
        linear_out = self.norm2(x)
        for layer in self.ffn:
            if not isinstance(layer, nn.Dropout):
                linear_out = layer(linear_out)
            else:
                linear_out = layer(linear_out, deterministic=not train)

        x = x + self.dropout(linear_out, deterministic=not train)

        return x

In [21]:
def test_transformer_block():
    # Parameters
    batch_size = 16
    seq_len = 128
    d_model = 512
    n_head = 8
    ffn_expan = 4
    dropout = 0.1

    # Create transformer block
    transformer_block = TransformerBlock(
        d_model=d_model, n_head=n_head, ffn_expan=ffn_expan, _dropout=dropout
    )

    # Initialize parameters
    variables = transformer_block.init(
        {"params": subkeys[1], "dropout": subkeys[2]},
        jnp.ones((batch_size, seq_len, d_model)),
        jnp.ones((batch_size, seq_len, d_model)),
        jnp.ones((batch_size, seq_len, d_model)),
    )
    params = variables["params"]

    # Test forward pass
    y = transformer_block.apply(
        {"params": params},
        jnp.ones((batch_size, seq_len, d_model)),
        jnp.ones((batch_size, seq_len, d_model)),
        jnp.ones((batch_size, seq_len, d_model)),
        rngs={"dropout": subkeys[4]},
    )

    assert y.shape == (
        batch_size,
        seq_len,
        d_model,
    ), f"Unexpected output shape: {y.shape}"
    print("Transformer block test successful!")


test_transformer_block()

query shape: (128, 512)
key shape: (128, 512)
value shape: (128, 512)
q shape: (128, 8, 64)
k shape: (128, 8, 64)
v shape: (128, 8, 64)

context shape: (128, 512)
attn shape: (8, 128, 128)
out shape: (128, 512)
query shape: (128, 512)
key shape: (128, 512)
value shape: (128, 512)
q shape: (128, 8, 64)
k shape: (128, 8, 64)
v shape: (128, 8, 64)

context shape: (128, 512)
attn shape: (8, 128, 128)
out shape: (128, 512)
Transformer block test successful!


In [22]:
class TransformerDecoderBlock(nn.Module):
    d_model: int  # Input dimension is needed here since it is equal to the output dimension (residual connection)
    n_head: int  # number of heads the attention is split into
    ffn_expan: int  # expansion factor for the feedforward layer
    dropout: float  # dropout probability

    def setup(self):
        # Attention layer
        self.attention = MultiHeadAttention(n_head=self.n_head, d_model=self.d_model)
        self.transformer_block = TransformerBlock(
            d_model=self.d_model,
            n_head=self.n_head,
            ffn_expan=self.ffn_expan,
            _dropout=self.dropout,
        )

        # Layers to apply in between the main layers
        self.norm1 = nn.LayerNorm()
        self.norm2 = nn.LayerNorm()
        self.dropout = nn.Dropout(rate=self.dropout)

    def __call__(
        self,
        x: Float[Array, "*batch d_model"],
        key: Float[Array, "*batch d_model"],
        value: Float[Array, "*batch d_model"],
        mask: Optional[Float[Array, "*batch d_model"]] = None,
        trg_mask: Optional[Float[Array, "*batch d_model"]] = None,
        train: bool = True,
    ) -> Float[Array, "*batch d_model"]:
        # Masked Attention part
        mask_proj_context, mask_attn = jax.vmap(self.attention, in_axes=(0, 0, 0, 0))(
            self.norm1(x), self.norm1(key), self.norm1(value), trg_mask
        )
        query = x + self.dropout(mask_proj_context, deterministic=not train)

        # Encoder Attention part
        out = self.transformer_block(
            self.norm2(query), self.norm2(key), self.norm2(value), mask=mask
        )

        return out