In [1]:
import flax.linen as nn
import jax, jax.numpy as jnp

def _new_get_flops(fn, *args, **kwargs):
  e = jax.jit(fn).lower(*args, **kwargs)
  cost = e.compile().cost_analysis()[0]
  if cost is None:
    return 0
  flops = int(cost['flops']) if 'flops' in cost else 0
  return flops

# Swap the FLOP estimator referenced as `nn.summary._get_flops` for the
# helper defined above.
nn.summary._get_flops = _new_get_flops

# Problem size for the experiment: one sequence of 16 positions, 96 features
# per position, and 12 attention heads.
batch_size, seq_len = 1, 16
dim, head = 96, 12

# Dummy all-ones input of shape (batch_size, seq_len, dim).
x = jnp.ones(shape=(batch_size, seq_len, dim))


class Foo(nn.Module):
  """Single multi-head self-attention layer used to probe FLOP counts.

  The attention size comes from the notebook-level `head` and `dim` values.
  """

  @nn.compact
  def __call__(self, x, is_causal):
    """Apply self-attention to `x`, using a causal mask when `is_causal` is truthy."""
    if is_causal:
      # Placeholder array of ones spanning the first two axes of `x`,
      # handed to make_causal_mask to build the mask.
      mask = nn.make_causal_mask(jnp.ones(x.shape[:2]))
    else:
      mask = None

    attention = nn.MultiHeadDotProductAttention(
        num_heads=head,
        qkv_features=dim,
        out_features=dim,
        kernel_init=nn.initializers.xavier_uniform(),
        deterministic=False,
        name='attention',
    )
    # Self-attention: the same tensor is passed as query, key and value.
    return attention(x, x, x, mask=mask)

# Build the summary table (with the FLOPs column enabled) for a causal
# forward pass, then print it.
summary_text = Foo().tabulate(
    jax.random.key(0),
    x,
    is_causal=True,
    compute_flops=True,
    console_kwargs={'width': 120},
)
print(summary_text)


[3m                                                      Foo Summary                                                       [0m
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath           [0m[1m [0m┃[1m [0m[1mmodule              [0m[1m [0m┃[1m [0m[1minputs              [0m[1m [0m┃[1m [0m[1moutputs           [0m[1m [0m┃[1m [0m[1mflops  [0m[1m [0m┃[1m [0m[1mparams               [0m[1m [0m┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩
│                 │ Foo                  │ - [2mfloat32[0m[1,16,96]   │ [2mfloat32[0m[1,16,96]   │ 1296030 │                       │
│                 │                      │ - is_causal: True    │                    │         │                       │
├─────────────────┼──────────────────────┼──────────────────────┼────────────────────┼─────────┼─────────

In [2]:
# Instantiate the module, initialise its parameters, and run one causal
# forward pass.
model = Foo()
init_key = jax.random.PRNGKey(0)
params = model.init(init_key, x, is_causal=True)
y = model.apply(params, x, is_causal=True)
# Display the output shape alongside the values.
(y.shape, y)

((1, 16, 96),
 Array([[[-0.66383624,  0.40252024,  0.93283004, ..., -0.14342627,
           1.5304804 ,  1.1549087 ],
         [-0.66383624,  0.40252024,  0.93283004, ..., -0.14342627,
           1.5304804 ,  1.1549087 ],
         [-0.66383624,  0.40252024,  0.93283004, ..., -0.14342627,
           1.5304804 ,  1.1549087 ],
         ...,
         [-0.66383624,  0.40252024,  0.93283004, ..., -0.14342627,
           1.5304804 ,  1.1549087 ],
         [-0.66381943,  0.40154403,  0.9372309 , ..., -0.14170712,
           1.5388081 ,  1.1570458 ],
         [-0.66383624,  0.40252024,  0.93283004, ..., -0.14342627,
           1.5304804 ,  1.1549087 ]]], dtype=float32))

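- Problem: causal masking and fully visible (unmasked) attention report the same FLOPs.
- Attempt: change how the attention weights are computed, so that under a causal mask only the query–key products on or below the diagonal are evaluated.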

In [3]:
from typing import Any, Callable, Optional, Union, overload
from flax.linen.dtypes import promote_dtype
from flax.linen.module import Module, compact, merge_param
from flax.typing import (
  Array,
  PRNGKey,
  Dtype,
  Shape as Shape,
  Initializer,
  PrecisionLike,
  DotGeneralT,
)

def new_dot_product_attention_weights(
    query: Array,
    key: Array,
    bias: Optional[Array] = None,
    mask: Optional[Array] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Optional[Dtype] = None,
    precision: PrecisionLike = None,
    module: Optional[Module] = None,
    force_fp32_for_softmax: bool = False,
    einsum_dot_general: Callable[..., Array] = jax.lax.dot_general,
):
  """Attention weights that skip above-diagonal query/key products when masked.

  Without a mask the full `(q_length, kv_length)` logit matrix is computed with
  a single einsum. With a mask, logits are computed one key position at a time
  and only for query positions at or after that key position; every other
  entry starts at the most negative finite value of `dtype`. The supplied
  `mask` is then applied on top, so extra masked positions (e.g. padding
  combined with a causal mask) are respected.

  NOTE: the masked path never computes logits above the diagonal, so it is
  only correct for masks that allow at most the causal (lower-triangular)
  positions. A mask that keeps positions above the diagonal will see them
  filled with the most negative finite value instead of real logits.

  Args:
    query: queries of shape `(batch..., q_length, num_heads, depth)`.
    key: keys of shape `(batch..., kv_length, num_heads, depth)`.
    bias: optional array added to the logits before masking and softmax.
    mask: optional array; where it is false the logits are replaced by the
      most negative finite value of `dtype`. Must broadcast against
      `(batch..., num_heads, q_length, kv_length)`.
    broadcast_dropout: if true, one dropout pattern of shape
      `(q_length, kv_length)` is shared across batch and head dimensions.
    dropout_rng: PRNG key used to sample the dropout pattern.
    dropout_rate: probability of dropping an attention weight.
    deterministic: if true, dropout is not applied.
    dtype: dtype `query` and `key` are promoted to before the computation.
    precision: forwarded to `jnp.einsum`.
    module: if given, the attention weights are sown into its
      `'intermediates'` collection under `'attention_weights'`.
    force_fp32_for_softmax: if true and `dtype` is not float32, the softmax is
      computed on float32 logits (the result is left in float32).
    einsum_dot_general: forwarded to `jnp.einsum` as `_dot_general`.

  Returns:
    Attention weights of shape `(batch..., num_heads, q_length, kv_length)`.
  """
  query, key = promote_dtype(query, key, dtype=dtype)
  dtype = query.dtype

  assert query.ndim == key.ndim, 'q, k must have same rank.'
  assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
  assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
  assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

  # calculate attention matrix
  depth = query.shape[-1]
  query = query / jnp.sqrt(depth).astype(dtype)
  if mask is None:
    # attn weight shape is (batch..., num_heads, q_length, kv_length)
    attn_weights = jnp.einsum(
        '...qhd,...khd->...hqk',
        query,
        key,
        precision=precision,
        _dot_general=einsum_dot_general,
    )
  else:
    big_neg = jnp.finfo(dtype).min
    # Index from the right so any number of leading batch dims is supported.
    q_len, num_heads = query.shape[-3], query.shape[-2]
    kv_len = key.shape[-3]
    attn_weights = jnp.full(
        query.shape[:-3] + (num_heads, q_len, kv_len), big_neg, dtype=dtype
    )
    # One einsum per key position (unrolled while tracing): column `col` only
    # needs the queries at positions >= col.
    for col in range(min(q_len, kv_len)):
      col_logits = jnp.einsum(
          '...qhd,...khd->...hqk',
          query[..., col:, :, :],  # (batch..., q_len - col, num_heads, depth)
          key[..., col:col + 1, :, :],  # (batch..., 1, num_heads, depth)
          precision=precision,
          _dot_general=einsum_dot_general,
      )  # (batch..., num_heads, q_len - col, 1)
      attn_weights = attn_weights.at[..., col:, col:col + 1].set(col_logits)

  # apply attention bias: masking, dropout, proximity bias, etc.
  if bias is not None:
    attn_weights = attn_weights + bias
  # apply attention mask (after the bias, so masked entries end up at big_neg)
  if mask is not None:
    big_neg = jnp.finfo(dtype).min
    attn_weights = jnp.where(mask, attn_weights, big_neg)
  # normalize the attention weights
  if force_fp32_for_softmax and dtype != jnp.float32:
    attn_weights = jax.nn.softmax(attn_weights.astype(jnp.float32))
  else:
    attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

  if module:
    module.sow('intermediates', 'attention_weights', attn_weights)

  # apply attention dropout
  if not deterministic and dropout_rate > 0.0:
    keep_prob = 1.0 - dropout_rate
    if broadcast_dropout:
      # dropout is broadcast across the batch + head dimensions
      dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
      keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)  # type: ignore
    else:
      keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)  # type: ignore
    multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
    attn_weights = attn_weights * multiplier

  return attn_weights


# Rebind the module-level name so later lookups of
# `nn.attention.dot_product_attention_weights` get the replacement above.
# NOTE(review): presumably the attention layer resolves the function through
# this module attribute at call time — the changed FLOP count in the next
# cell is consistent with that, but confirm against the installed Flax version.
nn.attention.dot_product_attention_weights = new_dot_product_attention_weights

In [4]:
# Build the same summary table again (FLOPs column enabled, causal forward
# pass) and print it.
patched_summary_text = Foo().tabulate(
    jax.random.key(0),
    x,
    is_causal=True,
    compute_flops=True,
    console_kwargs={'width': 120},
)
print(patched_summary_text)


[3m                                                      Foo Summary                                                       [0m
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath           [0m[1m [0m┃[1m [0m[1mmodule              [0m[1m [0m┃[1m [0m[1minputs              [0m[1m [0m┃[1m [0m[1moutputs           [0m[1m [0m┃[1m [0m[1mflops  [0m[1m [0m┃[1m [0m[1mparams               [0m[1m [0m┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩
│                 │ Foo                  │ - [2mfloat32[0m[1,16,96]   │ [2mfloat32[0m[1,16,96]   │ 1277568 │                       │
│                 │                      │ - is_causal: True    │                    │         │                       │
├─────────────────┼──────────────────────┼──────────────────────┼────────────────────┼─────────┼─────────

In [5]:
# Rebuild the module, re-initialise its parameters with the same seed, and
# run the causal forward pass once more.
model = Foo()
init_key = jax.random.PRNGKey(0)
params = model.init(init_key, x, is_causal=True)
y = model.apply(params, x, is_causal=True)
# Display the output shape alongside the values.
(y.shape, y)

((1, 16, 96),
 Array([[[-0.66383624,  0.40252024,  0.93283004, ..., -0.14342627,
           1.5304804 ,  1.1549087 ],
         [-0.66383624,  0.40252024,  0.93283004, ..., -0.14342627,
           1.5304804 ,  1.1549087 ],
         [-0.66383624,  0.40252024,  0.93283004, ..., -0.14342627,
           1.5304804 ,  1.1549087 ],
         ...,
         [-0.66383624,  0.40252024,  0.93283004, ..., -0.14342627,
           1.5304804 ,  1.1549087 ],
         [-0.66381943,  0.40154403,  0.9372309 , ..., -0.14170712,
           1.5388081 ,  1.1570458 ],
         [-0.66383624,  0.40252024,  0.93283004, ..., -0.14342627,
           1.5304804 ,  1.1549087 ]]], dtype=float32))

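Conclusion: the Q·Kᵀ product is only a small fraction of the total FLOPs in multi-head self-attention (MSA). Skipping the masked half lowered the reported count from 1,296,030 to 1,277,568, about 1.4%, and the outputs are unchanged. Stupid: this wasted an hour.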