In [1]:

# imports
import os
import sys
import types
import json
import base64

# figure size/format
fig_width = 7
fig_height = 5
fig_format = 'retina'
fig_dpi = 96
interactivity = 'all'
is_shiny = False
is_dashboard = False
plotly_connected = True

# matplotlib defaults / format
try:
  import matplotlib.pyplot as plt
  plt.rcParams['figure.figsize'] = (fig_width, fig_height)
  plt.rcParams['figure.dpi'] = fig_dpi
  plt.rcParams['savefig.dpi'] = "figure"
  from IPython.display import set_matplotlib_formats
  set_matplotlib_formats(fig_format)
except Exception:
  pass

# plotly use connected mode
try:
  import plotly.io as pio
  if plotly_connected:
    pio.renderers.default = "notebook_connected"
  else:
    pio.renderers.default = "notebook"
  for template in pio.templates.keys():
    pio.templates[template].layout.margin = dict(t=30,r=0,b=0,l=0)
except Exception:
  pass

# disable itables paging for dashboards
if is_dashboard:
  try:
    from itables import options
    options.dom = 'fiBrtlp'
    options.maxBytes = 1024 * 1024
    options.language = dict(info = "Showing _TOTAL_ entries")
    options.classes = "display nowrap compact"
    options.paging = False
    options.searching = True
    options.ordering = True
    options.info = True
    options.lengthChange = False
    options.autoWidth = False
    options.responsive = True
    options.keys = True
    options.buttons = []
  except Exception:
    pass
  
  try:
    import altair as alt
    # By default, dashboards will have container sized
    # vega visualizations which allows them to flow reasonably
    theme_sentinel = '_quarto-dashboard-internal'
    def make_theme(name):
        nonTheme = alt.themes._plugins[name]    
        def patch_theme(*args, **kwargs):
            existingTheme = nonTheme()
            if 'height' not in existingTheme:
              existingTheme['height'] = 'container'
            if 'width' not in existingTheme:
              existingTheme['width'] = 'container'

            if 'config' not in existingTheme:
              existingTheme['config'] = dict()
            
            # Configure the default font sizes
            title_font_size = 15
            header_font_size = 13
            axis_font_size = 12
            legend_font_size = 12
            mark_font_size = 12
            tooltip = False

            config = existingTheme['config']

            # The Axis
            if 'axis' not in config:
              config['axis'] = dict()
            axis = config['axis']
            if 'labelFontSize' not in axis:
              axis['labelFontSize'] = axis_font_size
            if 'titleFontSize' not in axis:
              axis['titleFontSize'] = axis_font_size  

            # The legend
            if 'legend' not in config:
              config['legend'] = dict()
            legend = config['legend']
            if 'labelFontSize' not in legend:
              legend['labelFontSize'] = legend_font_size
            if 'titleFontSize' not in legend:
              legend['titleFontSize'] = legend_font_size  

            # The header
            if 'header' not in config:
              config['header'] = dict()
            header = config['header']
            if 'labelFontSize' not in header:
              header['labelFontSize'] = header_font_size
            if 'titleFontSize' not in header:
              header['titleFontSize'] = header_font_size    

            # Title
            if 'title' not in config:
              config['title'] = dict()
            title = config['title']
            if 'fontSize' not in title:
              title['fontSize'] = title_font_size

            # Marks
            if 'mark' not in config:
              config['mark'] = dict()
            mark = config['mark']
            if 'fontSize' not in mark:
              mark['fontSize'] = mark_font_size

            # Mark tooltips
            if tooltip and 'tooltip' not in mark:
              mark['tooltip'] = dict(content="encoding")

            return existingTheme
            
        return patch_theme

    # We can only do this once per session
    if theme_sentinel not in alt.themes.names():
      for name in alt.themes.names():
        alt.themes.register(name, make_theme(name))
      
      # register a sentinel theme so we only do this once
      alt.themes.register(theme_sentinel, make_theme('default'))
      alt.themes.enable('default')

  except Exception:
    pass

# enable pandas latex repr when targeting pdfs
try:
  import pandas as pd
  if fig_format == 'pdf':
    pd.set_option('display.latex.repr', True)
except Exception:
  pass

# interactivity
if interactivity:
  from IPython.core.interactiveshell import InteractiveShell
  InteractiveShell.ast_node_interactivity = interactivity

# NOTE: the kernel_deps code is repeated in the cleanup.py file
# (we can't easily share this code b/c of the way it is run).
# If you edit this code also edit the same code in cleanup.py!

# output kernel dependencies
kernel_deps = dict()
for module in list(sys.modules.values()):
  # Some modules play games with sys.modules (e.g. email/__init__.py
  # in the standard library), and occasionally this can cause strange
  # failures in getattr.  Just ignore anything that's not an ordinary
  # module.
  if not isinstance(module, types.ModuleType):
    continue
  path = getattr(module, "__file__", None)
  if not path:
    continue
  if path.endswith(".pyc") or path.endswith(".pyo"):
    path = path[:-1]
  if not os.path.exists(path):
    continue
  kernel_deps[path] = os.stat(path).st_mtime
print(json.dumps(kernel_deps))

# set run_path if requested
run_path = 'L2hvbWUvb3Jlbi93b3JrL25vdGVzL25vdGVzLW5scC9wb3N0cy9jNHc0'
if run_path:
  # hex-decode the path
  run_path = base64.b64decode(run_path.encode("utf-8")).decode("utf-8")
  os.chdir(run_path)

# reset state
%reset

# shiny
# Checking for shiny by using False directly because we're after the %reset. We don't want
# to set a variable that stays in global scope.
if False:
  try:
    import htmltools as _htmltools
    import ast as _ast

    _htmltools.html_dependency_render_mode = "json"

    # This decorator will be added to all function definitions
    def _display_if_has_repr_html(x):
      try:
        # IPython 7.14 preferred import
        from IPython.display import display, HTML
      except:
        from IPython.core.display import display, HTML

      if hasattr(x, '_repr_html_'):
        display(HTML(x._repr_html_()))
      return x

    # ideally we would undo the call to ast_transformers.append
    # at the end of this block whenver an error occurs, we do 
    # this for now as it will only be a problem if the user 
    # switches from shiny to not-shiny mode (and even then likely
    # won't matter)
    import builtins
    builtins._display_if_has_repr_html = _display_if_has_repr_html

    class _FunctionDefReprHtml(_ast.NodeTransformer):
      def visit_FunctionDef(self, node):
        node.decorator_list.insert(
          0,
          _ast.Name(id="_display_if_has_repr_html", ctx=_ast.Load())
        )
        return node

      def visit_AsyncFunctionDef(self, node):
        node.decorator_list.insert(
          0,
          _ast.Name(id="_display_if_has_repr_html", ctx=_ast.Load())
        )
        return node

    ip = get_ipython()
    ip.ast_transformers.append(_FunctionDefReprHtml())

  except:
    pass

def ojs_define(**kwargs):
  import json
  try:
    # IPython 7.14 preferred import
    from IPython.display import display, HTML
  except:
    from IPython.core.display import display, HTML

  # do some minor magic for convenience when handling pandas
  # dataframes
  def convert(v):
    try:
      import pandas as pd
    except ModuleNotFoundError: # don't do the magic when pandas is not available
      return v
    if type(v) == pd.Series:
      v = pd.DataFrame(v)
    if type(v) == pd.DataFrame:
      j = json.loads(v.T.to_json(orient='split'))
      return dict((k,v) for (k,v) in zip(j["index"], j["data"]))
    else:
      return v

  v = dict(contents=list(dict(name=key, value=convert(value)) for (key, value) in kwargs.items()))
  display(HTML('<script type="ojs-define">' + json.dumps(v) + '</script>'), metadata=dict(ojs_define = True))
globals()["ojs_define"] = ojs_define
globals()["__spec__"] = None

  set_matplotlib_formats(fig_format)




In [2]:
import os
import trax
from trax import layers as tl  # core building block
import jax
from trax import fastmath  # uses jax, offers numpy on steroids

# fastmath.use_backend('tensorflow-numpy')
import functools
from trax.fastmath import numpy as np  # note, using fastmath subset of numpy!
from trax.layers import (
    #tie_in,
    length_normalized,
    apply_broadcasted_dropout,
    look_adjacent,
    permute_via_gather,
    permute_via_sort,
)


def tie_in(x, y):
  if fastmath.backend_name() == 'jax':
    return jax.lax.tie_in(x, y)
  return y

2025-02-07 18:21:15.571442: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738945275.585709   54637 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738945275.590225   54637 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
def mask_self_attention(
    dots, q_info, kv_info, causal=True, exclude_self=True, masked=False
):
    """Performs masking for self-attention."""
    if causal:
        mask = fastmath.lt(q_info, kv_info).astype(np.float32)
        dots = dots - 1e9 * mask
    if exclude_self:
        mask = np.equal(q_info, kv_info).astype(np.float32)
        dots = dots - 1e5 * mask
    if masked:
        zeros_like_kv_info = tie_in(kv_info, np.zeros_like(kv_info))
        mask = fastmath.lt(kv_info, zeros_like_kv_info).astype(np.float32)
        dots = dots - 1e9 * mask
    return dots

In [4]:
def our_softmax(x, passthrough=False):
    """ softmax with passthrough"""
    logsumexp = fastmath.logsumexp(x, axis=-1, keepdims=True)
    o = np.exp(x - logsumexp)
    if passthrough:
        return (x, np.zeros_like(logsumexp))
    else:
        return (o, logsumexp)

In [5]:
## compare softmax(a) using both methods
a = np.array([1.0, 2.0, 3.0, 4.0])
sma = np.exp(a) / sum(np.exp(a))
print(sma)
sma2, a_logsumexp = our_softmax(a)
print(sma2)
print(a_logsumexp)

[0.0320586  0.08714432 0.23688282 0.6439142 ]


[0.0320586  0.0871443  0.23688279 0.64391416]
[4.44019]


In [6]:
def our_simple_attend(
    q, k=None, v=None,
    mask_fn=None, q_info=None, kv_info=None,
    dropout=0.0, rng=None, verbose=False, passthrough=False
    ):
  """Dot-product attention,  with masking, without optional chunking and/or.

  Args:
    q: Query vectors, shape [q_len, d_qk]
    k: Key vectors, shape [kv_len, d_qk]; or None
    v: Value vectors, shape [kv_len, d_v]
    mask_fn: a function reference that implements masking (e.g. mask_self_attention)
    q_info: Query-associated metadata for masking
    kv_info: Key-associated metadata for masking
    dropout: Dropout rate
    rng: RNG for dropout

  Returns:
    A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and
    dots_logsumexp has shape [q_len]. The logsumexp of the attention
    probabilities is useful for combining multiple rounds of attention (as in
    LSH attention).
  """
  assert v is not None
  share_qk = (k is None)
  if share_qk:
    k = q
    if kv_info is None:
      kv_info = q_info

  if share_qk:
    k = length_normalized(k)
  k = k / np.sqrt(k.shape[-1])

  # Dot-product attention.
  kr = np.swapaxes(k, -1, -2)  # note the fancy transpose for later..

## Step 1  ##
  dots = np.matmul(q, kr )
  if verbose: print("Our attend dots", dots.shape)

  # Masking
  if mask_fn is not None:
    dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

  # Softmax.
  #dots_logsumexp = fastmath.logsumexp(dots, axis=-1, keepdims=True)  #original
  #dots = np.exp(dots - dots_logsumexp)  #original
  
## Step 2  ##
  #replace with our_softmax()
  dots, dots_logsumexp = our_softmax(dots, passthrough=passthrough)
  if verbose: print("Our attend dots post softmax", dots.shape, dots_logsumexp.shape)

  if dropout > 0.0:
    assert rng is not None
    # Dropout is broadcast across the bin dimension
    dropout_shape = (dots.shape[-2], dots.shape[-1])
    keep_prob = tie_in(dots, 1.0 - dropout)
    keep = fastmath.random.bernoulli(rng, keep_prob, dropout_shape)
    multiplier = keep.astype(dots.dtype) / tie_in(keep, keep_prob)
    dots = dots * multiplier

## Step 3  ##
# The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
  out = np.matmul(dots, v)
  if verbose: print("Our attend out1", out.shape)
  out = np.reshape(out, (-1, out.shape[-1]))
  if verbose: print("Our attend out2", out.shape)
  dots_logsumexp = np.reshape(dots_logsumexp, (-1,))
  return out, dots_logsumexp

In [7]:
seq_len = 8
emb_len = 5
d_qk = 3
d_v = 4
with fastmath.use_backend("jax"):  # specify the backend for consistency
    rng_attend = fastmath.random.get_prng(1)
    q = k = jax.random.uniform(rng_attend, (seq_len, d_qk), dtype=np.float32)
    v = jax.random.uniform(rng_attend, (seq_len, d_v), dtype=np.float32)
    o, logits = our_simple_attend(
        q,
        k,
        v,
        mask_fn=None,
        q_info=None,
        kv_info=None,
        dropout=0.0,
        rng=rng_attend,
        verbose=True,
    )
print(o, "\n", logits)

Our attend dots (8, 8)


Our attend dots post softmax (8, 8) (8, 1)
Our attend out1 (8, 4)
Our attend out2 (8, 4)
[[0.5455444  0.4232705  0.62970716 0.45504814]
 [0.5558777  0.4169514  0.6260488  0.45763403]
 [0.5502556  0.42250413 0.6107501  0.4532582 ]
 [0.53680766 0.43004778 0.63048995 0.4492887 ]
 [0.5546176  0.41898918 0.62778664 0.44567773]
 [0.54741716 0.4229177  0.6060424  0.46433902]
 [0.53192824 0.43415833 0.63327026 0.44313937]
 [0.538871   0.42285213 0.6527077  0.44843906]] 
 [2.5345023 2.6896586 2.8266857 2.4992957 2.861424  2.6235857 2.5204637
 2.3627536]


In [8]:
class OurSelfAttention(tl.SelfAttention):
    """Our self-attention. Just the Forward Function."""

    def forward_unbatched(
        self, x, mask=None, *, weights, state, rng, update_state, verbose=False
    ):
        print("ourSelfAttention:forward_unbatched")
        del update_state
        attend_rng, output_rng = fastmath.random.split(rng)
        if self.bias:
            if self.share_qk:
                w_q, w_v, w_o, b_q, b_v = weights
            else:
                w_q, w_k, w_v, w_o, b_q, b_k, b_v = weights
        else:
            if self.share_qk:
                w_q, w_v, w_o = weights
            else:
                w_q, w_k, w_v, w_o = weights

        print("x.shape,w_q.shape", x.shape, w_q.shape)
        q = np.matmul(x, w_q)
        k = None
        if not self.share_qk:
            k = np.matmul(x, w_k)
        v = np.matmul(x, w_v)

        if self.bias:
            q = q + b_q
            if not self.share_qk:
                k = k + b_k
            v = v + b_v

        mask_fn = functools.partial(
            mask_self_attention,
            causal=self.causal,
            exclude_self=self.share_qk,
            masked=self.masked,
        )
        q_info = kv_info = tie_in(x, np.arange(q.shape[-2], dtype=np.int32))

        assert (mask is not None) == self.masked
        if self.masked:
            # mask is a boolean array (True means "is valid token")
            ones_like_mask = tie_in(x, np.ones_like(mask, dtype=np.int32))
            kv_info = kv_info * np.where(mask, ones_like_mask, -ones_like_mask)

        # Notice, we are callout our vesion of attend
        o, _ = our_simple_attend(
            q,
            k,
            v,
            mask_fn=mask_fn,
            q_info=q_info,
            kv_info=kv_info,
            dropout=self.attention_dropout,
            rng=attend_rng,
            verbose=True,
        )

        # Notice, wo weight matrix applied to output of attend in forward_unbatched
        out = np.matmul(o, w_o)
        out = apply_broadcasted_dropout(out, self.output_dropout, output_rng)
        return out, state

In [9]:
causal = False
masked = False
mask = None
attention_dropout = 0.0
n_heads = 3
d_qk = 3
d_v = 4
seq_len = 8
emb_len = 5
batch_size = 1

osa = OurSelfAttention(
    n_heads=n_heads,
    d_qk=d_qk,
    d_v=d_v,
    causal=causal,
    use_reference_code=True,
    attention_dropout=attention_dropout,
    mode="train",
)

rng_osa = fastmath.random.get_prng(1)
x = jax.random.uniform(
    jax.random.PRNGKey(0), (batch_size, seq_len, emb_len), dtype=np.float32
)
_, _ = osa.init(tl.shapes.signature(x), rng=rng_osa)

In [10]:
osa(x)

ourSelfAttention:forward_unbatched


LayerError: Exception passing through layer OurSelfAttention (in pure_fn):
  layer created in file [...]/layers/research/efficient_attention.py, line 1014
  layer input shapes: ShapeDtype{shape:(1, 8, 5), dtype:float32}

  File [...]/layers/research/efficient_attention.py, line 1323, in forward
    single_out, single_new_state = self.forward_unbatched(

  File [...]/tmp/ipykernel_54637/1155828790.py, line 10, in forward_unbatched
    if self.bias:

AttributeError: 'OurSelfAttention' object has no attribute 'bias'. Did you mean: '_bias'?

In [11]:
def our_hash_vectors(vecs, rng, n_buckets, n_hashes, mask=None, verbose=False):
    """
  Args:
    vecs: tensor of at least 2 dimension,
    rng: random number generator
    n_buckets: number of buckets in each hash table
    n_hashes: the number of hash tables
    mask: None indicating no mask or a 1D boolean array of length vecs.shape[0], containing the location of padding value
    verbose: controls prints for debug
  Returns:
    A vector of size n_hashes * vecs.shape[0] containing the buckets associated with each input vector per hash table.

    """

    # check for even, integer bucket sizes
    assert isinstance(n_buckets, int) and n_buckets % 2 == 0

    rng = fastmath.stop_gradient(tie_in(vecs, rng))
    rot_size = n_buckets
    ### Start Code Here

    ### Step 1 ###
    rotations_shape = (vecs.shape[-1], n_hashes, rot_size // 2)
    random_rotations = fastmath.random.normal(rng, rotations_shape).astype(
        np.float32)
    if verbose: print("random.rotations.shape", random_rotations.shape)

    ### Step 2 ###
    if fastmath.backend_name() == 'jax':
      rotated_vecs = np.einsum('tf,fhb->htb', vecs, random_rotations)
      if verbose: print("using jax")
    else:
      #Step 2a
      random_rotations = np.reshape(random_rotations,
                                    [-1, n_hashes * (rot_size // 2)])
      if verbose: print("random_rotations reshaped", random_rotations.shape)
      #Step 2b
      rotated_vecs = np.dot(vecs, random_rotations)
      if verbose: print("rotated_vecs1", rotated_vecs.shape)
      #Step 2c
      rotated_vecs = np.reshape(rotated_vecs, [-1, n_hashes, rot_size//2])
      if verbose: print("rotated_vecs2", rotated_vecs.shape)
      #Step 2d
      rotated_vecs = np.transpose(rotated_vecs, (1, 0, 2))
      if verbose: print("rotated_vecs3", rotated_vecs.shape)

    ### Step 3 ###
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
    if verbose: print("rotated_vecs.shape", rotated_vecs.shape)
    ### Step 4 ###
    buckets = np.argmax(rotated_vecs, axis=-1).astype(np.int32)
    if verbose: print("buckets.shape", buckets.shape)
    if verbose: print("buckets", buckets)

    if mask is not None:
      n_buckets += 1  # Create an extra bucket for padding tokens only
      buckets = np.where(mask[None, :], buckets, n_buckets - 1)

    # buckets is now (n_hashes, seqlen). Next we add offsets so that
    # bucket numbers from different hashing rounds don't overlap.
    offsets = tie_in(buckets, np.arange(n_hashes, dtype=np.int32))
    offsets = np.reshape(offsets * n_buckets, (-1, 1))
    ### Step 5 ###
    buckets = np.reshape(buckets + offsets, (-1,))
    if verbose: print("buckets with offsets", buckets.shape, "\n", buckets)
    return buckets

In [12]:
# example code. Note for reference, the sizes in this example match the values in the diagram above.
ohv_q = np.ones((8, 5))  # (seq_len=8, n_q=5)
ohv_n_buckets = 4  # even number
ohv_n_hashes = 3
with fastmath.use_backend("tf"):
    ohv_rng = fastmath.random.get_prng(1)
    ohv = our_hash_vectors(
        ohv_q, ohv_rng, ohv_n_buckets, ohv_n_hashes, mask=None, verbose=True
    )
    print("ohv shape", ohv.shape, "\nohv", ohv)  # (ohv_n_hashes * ohv_n_buckets)
# note the random number generators do not produce the same results with different backends
with fastmath.use_backend("jax"):
    ohv_rng = fastmath.random.get_prng(1)
    ohv = our_hash_vectors(ohv_q, ohv_rng, ohv_n_buckets, ohv_n_hashes, mask=None)
    print("ohv shape", ohv.shape, "\nohv", ohv)  # (ohv_n_hashes * ohv_n_buckets)

ValueError: No backend with name tf

In [13]:
def sort_buckets(buckets, q, v, n_buckets, n_hashes, seqlen, verbose=True):
    """
  Args:
    buckets: tensor of at least 2 dimension,
    n_buckets: number of buckets in each hash table
    n_hashes: the number of hash tables
    """
    if verbose: print("---sort_buckets--")
    ## Step 1
    ticker = np.arange(n_hashes * seqlen)
    if verbose: print("ticker",ticker.shape, ticker)
    ## Step 2
    buckets_and_t = seqlen * buckets + (ticker % seqlen)
    if verbose: print("buckets_and_t",buckets_and_t.shape, buckets_and_t)

    # Hash-based sort ("s" at the start of variable names means "sorted")
    #Step 3
    sbuckets_and_t, sticker = fastmath.sort_key_val(
    buckets_and_t, ticker, dimension=-1)
    if verbose: print("sbuckets_and_t",sbuckets_and_t.shape, sbuckets_and_t)
    if verbose: print("sticker",sticker.shape, sticker)
    #Step 4
    _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1)
    if verbose: print("undo_sort",undo_sort.shape, undo_sort)

    #Step 4
    st = (sticker % seqlen)
    sq = np.take(q, st, axis=0)
    sv = np.take(v, st, axis=0)
    return sq, sv, sticker, undo_sort

In [14]:
t_n_hashes = 2
t_n_buckets = 4
t_n_seq = t_seqlen = 8
t_n_q = 3
n_v = 5

t_q = (np.array([(j % t_n_buckets) for j in range(t_n_seq)]) * np.ones((t_n_q, 1))).T
t_v = np.ones((t_n_seq, n_v))
t_buckets = np.array(
    [
        (j % t_n_buckets) + t_n_buckets * i
        for i in range(t_n_hashes)
        for j in range(t_n_seq)
    ]
)
print("q\n", t_q)
print("t_buckets: ", t_buckets)

t_sq, t_sv, t_sticker, t_undo_sort = sort_buckets(
    t_buckets, t_q, t_v, t_n_buckets, t_n_hashes, t_seqlen, verbose=True
)

print("sq.shape", t_sq.shape, "sv.shape", t_sv.shape)
print("sq\n", t_sq)

q
 [[0. 0. 0.]
 [1. 1. 1.]
 [2. 2. 2.]
 [3. 3. 3.]
 [0. 0. 0.]
 [1. 1. 1.]
 [2. 2. 2.]
 [3. 3. 3.]]
t_buckets:  [0 1 2 3 0 1 2 3 4 5 6 7 4 5 6 7]
---sort_buckets--
ticker (16,) [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
buckets_and_t (16,) [ 0  9 18 27  4 13 22 31 32 41 50 59 36 45 54 63]
sbuckets_and_t (16,) [ 0  4  9 13 18 22 27 31 32 36 41 45 50 54 59 63]
sticker (16,) [ 0  4  1  5  2  6  3  7  8 12  9 13 10 14 11 15]
undo_sort (16,) [ 0  2  4  6  1  3  5  7  8 10 12 14  9 11 13 15]


sq.shape (16, 3) sv.shape (16, 5)
sq
 [[0. 0. 0.]
 [0. 0. 0.]
 [1. 1. 1.]
 [1. 1. 1.]
 [2. 2. 2.]
 [2. 2. 2.]
 [3. 3. 3.]
 [3. 3. 3.]
 [0. 0. 0.]
 [0. 0. 0.]
 [1. 1. 1.]
 [1. 1. 1.]
 [2. 2. 2.]
 [2. 2. 2.]
 [3. 3. 3.]
 [3. 3. 3.]]


In [15]:
a = np.arange(16 * 3).reshape((16, 3))
chunksize = 2
ar = np.reshape(
    a, (-1, chunksize, a.shape[-1])
)  # the -1 usage is very handy, see numpy reshape
print(ar.shape)

(8, 2, 3)


In [16]:
def dotandv(sq, sv, undo_sort, kv_chunk_len, n_hashes, seqlen, passthrough, verbose=False ):
    # Step 1
    rsq = np.reshape(sq,(-1, kv_chunk_len, sq.shape[-1]))
    rsqt =  np.swapaxes(rsq, -1, -2)
    if verbose: print("rsq.shape,rsqt.shape: ", rsq.shape,rsqt.shape)
    dotlike = np.matmul(rsq, rsqt)
    if verbose: print("dotlike\n", dotlike)

    #Step 2
    dotlike, slogits = our_softmax(dotlike, passthrough)
    if verbose: print("dotlike post softmax\n", dotlike)

    #Step 3
    vr = np.reshape(sv, (-1, kv_chunk_len, sv.shape[-1]))
    if verbose:  print("dotlike.shape, vr.shape:", dotlike.shape, vr.shape)
    so = np.matmul(dotlike, vr)
    if verbose: print("so.shape:", so.shape)
    so = np.reshape(so, (-1, so.shape[-1]))
    slogits = np.reshape(slogits, (-1,))  # provided
    if verbose: print("so.shape,slogits.shape", so.shape, slogits.shape)

    #Step 4
    o = np.take(so, undo_sort, axis=0)
    logits = np.take(slogits, undo_sort, axis=0)
    if verbose: print("o.shape,o", o.shape, o)
    if verbose: print("logits.shape, logits", logits.shape, logits)

    #Step 5 (Provided)
    if n_hashes > 1:
      o = np.reshape(o, (n_hashes, seqlen, o.shape[-1]))
      logits = np.reshape(logits, (n_hashes, seqlen, 1))
      probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True))
      o = np.sum(o * probs, axis=0)

    return(o)

In [17]:
t_kv_chunk_len = 2
out = dotandv(
    t_sq,
    t_sv,
    t_undo_sort,
    t_kv_chunk_len,
    t_n_hashes,
    t_seqlen,
    passthrough=True,
    verbose=True,
)
print("out\n", out)
print("\n-----With softmax enabled----\n")
out = dotandv(
    t_sq,
    t_sv,
    t_undo_sort,
    t_kv_chunk_len,
    t_n_hashes,
    t_seqlen,
    passthrough=False,
    verbose=True,
)
print("out\n", out)

rsq.shape,rsqt.shape:  (8, 2, 3) (8, 3, 2)
dotlike
 [[[ 0.  0.]
  [ 0.  0.]]

 [[ 3.  3.]
  [ 3.  3.]]

 [[12. 12.]
  [12. 12.]]

 [[27. 27.]
  [27. 27.]]

 [[ 0.  0.]
  [ 0.  0.]]

 [[ 3.  3.]
  [ 3.  3.]]

 [[12. 12.]
  [12. 12.]]

 [[27. 27.]
  [27. 27.]]]


dotlike post softmax
 [[[ 0.  0.]
  [ 0.  0.]]

 [[ 3.  3.]
  [ 3.  3.]]

 [[12. 12.]
  [12. 12.]]

 [[27. 27.]
  [27. 27.]]

 [[ 0.  0.]
  [ 0.  0.]]

 [[ 3.  3.]
  [ 3.  3.]]

 [[12. 12.]
  [12. 12.]]

 [[27. 27.]
  [27. 27.]]]
dotlike.shape, vr.shape: (8, 2, 2) (8, 2, 5)
so.shape: (8, 2, 5)
so.shape,slogits.shape (16, 5) (16,)
o.shape,o (16, 5) [[ 0.  0.  0.  0.  0.]
 [ 6.  6.  6.  6.  6.]
 [24. 24. 24. 24. 24.]
 [54. 54. 54. 54. 54.]
 [ 0.  0.  0.  0.  0.]
 [ 6.  6.  6.  6.  6.]
 [24. 24. 24. 24. 24.]
 [54. 54. 54. 54. 54.]
 [ 0.  0.  0.  0.  0.]
 [ 6.  6.  6.  6.  6.]
 [24. 24. 24. 24. 24.]
 [54. 54. 54. 54. 54.]
 [ 0.  0.  0.  0.  0.]
 [ 6.  6.  6.  6.  6.]
 [24. 24. 24. 24. 24.]
 [54. 54. 54. 54. 54.]]
logits.shape, logits (16,) [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


out
 [[ 0.  0.  0.  0.  0.]
 [ 6.  6.  6.  6.  6.]
 [24. 24. 24. 24. 24.]
 [54. 54. 54. 54. 54.]
 [ 0.  0.  0.  0.  0.]
 [ 6.  6.  6.  6.  6.]
 [24. 24. 24. 24. 24.]
 [54. 54. 54. 54. 54.]]

-----With softmax enabled----

rsq.shape,rsqt.shape:  (8, 2, 3) (8, 3, 2)
dotlike
 [[[ 0.  0.]
  [ 0.  0.]]

 [[ 3.  3.]
  [ 3.  3.]]

 [[12. 12.]
  [12. 12.]]

 [[27. 27.]
  [27. 27.]]

 [[ 0.  0.]
  [ 0.  0.]]

 [[ 3.  3.]
  [ 3.  3.]]

 [[12. 12.]
  [12. 12.]]

 [[27. 27.]
  [27. 27.]]]
dotlike post softmax
 [[[0.5        0.5       ]
  [0.5        0.5       ]]

 [[0.5        0.5       ]
  [0.5        0.5       ]]

 [[0.49999976 0.49999976]
  [0.49999976 0.49999976]]

 [[0.49999976 0.49999976]
  [0.49999976 0.49999976]]

 [[0.5        0.5       ]
  [0.5        0.5       ]]

 [[0.5        0.5       ]
  [0.5        0.5       ]]

 [[0.49999976 0.49999976]
  [0.49999976 0.49999976]]

 [[0.49999976 0.49999976]
  [0.49999976 0.49999976]]]
dotlike.shape, vr.shape: (8, 2, 2) (8, 2, 5)
so.shape: (8, 2, 5)

In [18]:
# original version from trax 1.3.4
def attend(
    q,
    k=None,
    v=None,
    q_chunk_len=None,
    kv_chunk_len=None,
    n_chunks_before=0,
    n_chunks_after=0,
    mask_fn=None,
    q_info=None,
    kv_info=None,
    dropout=0.0,
    rng=None,
):
    """Dot-product attention, with optional chunking and/or masking.

  Args:
    q: Query vectors, shape [q_len, d_qk]
    k: Key vectors, shape [kv_len, d_qk]; or None
    v: Value vectors, shape [kv_len, d_v]
    q_chunk_len: Set to non-zero to enable chunking for query vectors
    kv_chunk_len: Set to non-zero to enable chunking for key/value vectors
    n_chunks_before: Number of adjacent previous chunks to attend to
    n_chunks_after: Number of adjacent subsequent chunks to attend to
    mask_fn: TODO(kitaev) doc
    q_info: Query-associated metadata for masking
    kv_info: Key-associated metadata for masking
    dropout: Dropout rate
    rng: RNG for dropout

  Returns:
    A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and
    dots_logsumexp has shape [q_len]. The logsumexp of the attention
    probabilities is useful for combining multiple rounds of attention (as in
    LSH attention).
  """
    assert v is not None
    share_qk = k is None

    if q_info is None:
        q_info = np.arange(q.shape[-2], dtype=np.int32)

    if kv_info is None and not share_qk:
        kv_info = np.arange(v.shape[-2], dtype=np.int32)

    # Split q/k/v into chunks along the time axis, if desired.
    if q_chunk_len is not None:
        q = np.reshape(q, (-1, q_chunk_len, q.shape[-1]))
        q_info = np.reshape(q_info, (-1, q_chunk_len))

    if share_qk:
        assert kv_chunk_len is None or kv_chunk_len == q_chunk_len
        k = q
        kv_chunk_len = q_chunk_len
        if kv_info is None:
            kv_info = q_info
        elif kv_chunk_len is not None:
            # kv_info is not None, but reshape as required.
            kv_info = np.reshape(kv_info, (-1, kv_chunk_len))
    elif kv_chunk_len is not None:
        k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1]))
        kv_info = np.reshape(kv_info, (-1, kv_chunk_len))

    if kv_chunk_len is not None:
        v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1]))

    if share_qk:
        k = length_normalized(k)
    k = k / np.sqrt(k.shape[-1])

    # Optionally include adjacent chunks.
    if q_chunk_len is not None or kv_chunk_len is not None:
        assert q_chunk_len is not None and kv_chunk_len is not None
    else:
        assert n_chunks_before == 0 and n_chunks_after == 0

    k = look_adjacent(k, n_chunks_before, n_chunks_after)
    v = look_adjacent(v, n_chunks_before, n_chunks_after)
    kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after)

    # Dot-product attention.
    dots = np.matmul(q, np.swapaxes(k, -1, -2))

    # Masking
    if mask_fn is not None:
        dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

    # Softmax.
    dots_logsumexp = fastmath.logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - dots_logsumexp)

    if dropout > 0.0:
        assert rng is not None
        # Dropout is broadcast across the bin dimension
        dropout_shape = (dots.shape[-2], dots.shape[-1])
        #
        keep_prob = tie_in(dots, 1.0 - dropout)
        keep = fastmath.random.bernoulli(rng, keep_prob, dropout_shape)
        multiplier = keep.astype(dots.dtype) / tie_in(keep, keep_prob)
        dots = dots * multiplier

    # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
    out = np.matmul(dots, v)
    out = np.reshape(out, (-1, out.shape[-1]))
    dots_logsumexp = np.reshape(dots_logsumexp, (-1,))
    return out, dots_logsumexp

In [19]:
class OurLSHSelfAttention(tl.LSHSelfAttention):
    """Our simplified LSH self-attention """

    def forward_unbatched(self, x, mask=None, *, weights, state, rng, update_state):
        attend_rng, output_rng = fastmath.random.split(rng)
        w_q, w_v, w_o = weights

        q = np.matmul(x, w_q)
        v = np.matmul(x, w_v)

        if update_state:
            _, old_hash_rng = state
            hash_rng, hash_subrng = fastmath.random.split(old_hash_rng)
            #      buckets = self.hash_vectors(q, hash_subrng, mask)  #  original
            ## use our version of hash
            buckets = our_hash_vectors(
                q, hash_subrng, self.n_buckets, self.n_hashes, mask=mask
            )
            s_buckets = buckets
            if self._max_length_for_buckets:
                length = self.n_hashes * self._max_length_for_buckets
                if buckets.shape[0] < length:
                    s_buckets = np.concatenate(
                        [buckets, np.zeros(length - buckets.shape[0], dtype=np.int32)],
                        axis=0,
                    )
            state = (s_buckets, hash_rng)
        else:
            buckets, _ = state
            if self._max_length_for_buckets:
                buckets = buckets[: self.n_hashes * x.shape[0]]

        seqlen = x.shape[0]
        assert int(buckets.shape[0]) == self.n_hashes * seqlen

        ticker = tie_in(x, np.arange(self.n_hashes * seqlen, dtype=np.int32))
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = fastmath.stop_gradient(buckets_and_t)

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = fastmath.sort_key_val(
            buckets_and_t, ticker, dimension=-1
        )
        _, undo_sort = fastmath.sort_key_val(sticker, ticker, dimension=-1)
        sbuckets_and_t = fastmath.stop_gradient(sbuckets_and_t)
        sticker = fastmath.stop_gradient(sticker)
        undo_sort = fastmath.stop_gradient(undo_sort)

        st = sticker % seqlen
        sq = np.take(q, st, axis=0)
        sv = np.take(v, st, axis=0)

        mask_fn = functools.partial(
            mask_self_attention,
            causal=self.causal,
            exclude_self=True,
            masked=self.masked,
        )
        q_info = st

        assert (mask is not None) == self.masked
        kv_info = None
        if self.masked:
            # mask is a boolean array (True means "is valid token")
            smask = np.take(mask, st, axis=0)
            ones_like_mask = tie_in(x, np.ones_like(smask, dtype=np.int32))
            kv_info = q_info * np.where(smask, ones_like_mask, -ones_like_mask)

        ## use original version of attend (could use ours but lacks masks and masking)
        so, slogits = attend(
            sq,
            k=None,
            v=sv,
            q_chunk_len=self.chunk_len,
            n_chunks_before=self.n_chunks_before,
            n_chunks_after=self.n_chunks_after,
            mask_fn=mask_fn,
            q_info=q_info,
            kv_info=kv_info,
            dropout=self.attention_dropout,
            rng=attend_rng,
        )

        # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would
        # also work, but these helpers include performance optimizations for TPU.
        o = permute_via_gather(so, undo_sort, sticker, axis=0)
        logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1)

        if self.n_hashes > 1:
            o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
            logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
            probs = np.exp(logits - fastmath.logsumexp(logits, axis=0, keepdims=True))
            o = np.sum(o * probs, axis=0)

        assert o.shape == (seqlen, w_v.shape[-1])
        out = np.matmul(o, w_o)
        out = apply_broadcasted_dropout(out, self.output_dropout, output_rng)
        return out, state

In [20]:
# Here we're going to try out our LSHSelfAttention
n_heads = 3
causal = False
masked = False
mask = None
chunk_len = 8
n_chunks_before = 0
n_chunks_after = 0
attention_dropout = 0.0
n_hashes = 5
n_buckets = 4
seq_len = 8
emb_len = 5
al = OurLSHSelfAttention(
    n_heads=n_heads,
    d_qk=3,
    d_v=4,
    causal=causal,
    chunk_len=8,
    n_chunks_before=n_chunks_before,
    n_chunks_after=n_chunks_after,
    n_hashes=n_hashes,
    n_buckets=n_buckets,
    use_reference_code=True,
    attention_dropout=attention_dropout,
    mode="train",
)

x = jax.random.uniform(jax.random.PRNGKey(0), (1, seq_len, emb_len), dtype=np.float32)
al_osa = fastmath.random.get_prng(1)
_, _ = al.init(tl.shapes.signature(x), rng=al_osa)

In [21]:
al(x)

LayerError: Exception passing through layer OurLSHSelfAttention (in pure_fn):
  layer created in file [...]/layers/research/efficient_attention.py, line 1751
  layer input shapes: ShapeDtype{shape:(1, 8, 5), dtype:float32}

  File [...]/layers/research/efficient_attention.py, line 2158, in forward
    single_out, single_new_state = self.forward_unbatched(

  File [...]/tmp/ipykernel_54637/2615489615.py, line 17, in forward_unbatched
    q, hash_subrng, self.n_buckets, self.n_hashes, mask=mask

AttributeError: 'OurLSHSelfAttention' object has no attribute 'n_buckets'. Did you mean: '_n_buckets'?