In [11]:
import jax.numpy as jnp
from jax import random
from flax.linen import make_causal_mask

# Set seed for reproducibility
seed = 0
key = random.PRNGKey(seed)


# Generate simulated attn_weights and targets
batch_size = 128
sequence_length = 64
num_heads = 12
attn_weights = random.uniform(key, (batch_size, num_heads, sequence_length, sequence_length))
targets = random.uniform(key, (batch_size, sequence_length))
print(f"attn_weights shape: {attn_weights.shape}")
print(f"targets shape: {targets.shape}")
mask = make_causal_mask(targets)
print(f"mask shape: {mask.shape}")
dtype = jnp.float32
big_neg = jnp.finfo(dtype).min
new_attn_weights = jnp.where(mask, attn_weights, big_neg)
print(f"new_attn_weights shape: {new_attn_weights.shape}")

attn_weights shape: (128, 12, 64, 64)
targets shape: (128, 64)
mask shape: (128, 1, 64, 64)
new_attn_weights shape: (128, 12, 64, 64)


In [27]:
(mask==0).sum()/128/64, (new_attn_weights==big_neg).sum()/128/12/64

(Array(31.5, dtype=float32), Array(31.5, dtype=float32))

In [32]:
new_attn_weights[0,2,1,:]

Array([ 7.1978688e-02,  5.0874996e-01, -3.4028235e+38, -3.4028235e+38,
       -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
       -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
       -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
       -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
       -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
       -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
       -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
       -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
       -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
       -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
       -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
       -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
       -3.4028235e+38, -3.4028235e+38, -3.4028235e+38, -3.4028235e+38,
      

In [17]:
mask[0,:,2,:5]

Array([[1., 1., 1., 0., 0.]], dtype=float32)