In [1]:
import jax.numpy as jnp
from jax import random

def shift_right(x, axis=1, constant_values=0):
  """Shift to the right on given axis with padding value 0."""
  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[axis] = (1, 0)
  padded = jnp.pad(x, pad_widths, constant_values=constant_values)
  # Cuts off the rightmost slice of size along the `axis` dimension.
  # Note that `list[:-1]`` is the same as `list[slice(-1)]`.
  return padded[tuple(slice(-1 if i == axis else None) for i in range(x.ndim))]

# Set seed for reproducibility
seed = 0
key = random.PRNGKey(seed)

# Generate simulated targets
batch_size = 128
sequence_length = 4
embedding_size = 768

# Parameters for the special tokens
eos_id = 1
pad_id = 0

# Generating non-padded parts of the sequence with random IDs (excluding 0 and 1)
non_pad_tokens = random.randint(key, (batch_size, sequence_length), minval=2, maxval=100)

# Generate the ground truth eos indices
true_eos_indices = random.randint(key, (batch_size,), minval=1, maxval=sequence_length)
print("True eos indices:", true_eos_indices)

# Generate the targets by replacing the token at the eos index with the eos token and padding the rest
targets = jnp.array([jnp.concatenate((non_pad_tokens[i, :true_eos_indices[i]], jnp.array([eos_id]), jnp.full(sequence_length - true_eos_indices[i] - 1, pad_id))) for i in range(batch_size)])
print("Targets shape:", targets.shape)

True eos indices: [1 2 2 2 3 3 3 3 1 1 1 3 3 1 2 1 2 2 3 2 2 3 1 1 3 1 2 1 3 2 3 3 3 3 2 1 2
 2 2 1 2 2 2 2 3 1 1 2 1 1 3 2 1 3 3 3 2 2 2 1 1 2 1 3 2 1 3 2 3 1 3 1 2 2
 2 3 1 3 3 2 3 1 3 2 2 1 2 2 3 2 1 3 2 2 2 2 1 3 1 2 2 3 3 2 3 3 3 1 3 3 2
 3 3 2 3 1 2 2 1 3 2 1 3 2 1 3 1 3]
Targets shape: (128, 4)


In [2]:
shifted_targets = shift_right(targets)
print("Shifted targets shape:", shifted_targets.shape)

# Your implementation to extract cls token features
eos_indices = jnp.where(shifted_targets == eos_id, size=128)  # This should now identify only the final eos tokens

# Print some outputs to verify
print("True eos indices:", true_eos_indices)
print("Indices of eos tokens:", eos_indices[1])

# Optional: Verify if the indices match the expected positions (all should be at the last index)
print("Check if the indices match the last position of each sequence:")
print(eos_indices[1] == true_eos_indices+1)


Shifted targets shape: (128, 4)


True eos indices: [1 2 2 2 3 3 3 3 1 1 1 3 3 1 2 1 2 2 3 2 2 3 1 1 3 1 2 1 3 2 3 3 3 3 2 1 2
 2 2 1 2 2 2 2 3 1 1 2 1 1 3 2 1 3 3 3 2 2 2 1 1 2 1 3 2 1 3 2 3 1 3 1 2 2
 2 3 1 3 3 2 3 1 3 2 2 1 2 2 3 2 1 3 2 2 2 2 1 3 1 2 2 3 3 2 3 3 3 1 3 3 2
 3 3 2 3 1 2 2 1 3 2 1 3 2 1 3 1 3]
Indices of eos tokens: [2 3 3 3 2 2 2 2 3 2 3 3 3 3 2 2 2 3 2 3 3 2 3 3 3 2 3 3 3 3 2 2 3 2 2 3 2
 3 3 3 2 2 3 2 3 2 3 2 2 3 3 3 2 3 2 3 3 2 3 3 3 2 3 3 3 3 2 2 3 3 3 2 3 3
 2 3 3 2 3 2 3 2 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Check if the indices match the last position of each sequence:
[ True  True  True  True False False False False False  True False False
 False False False  True False  True False  True  True False False False
 False  True  True False False  True False False False False False False
 False  True  True False False False  True False False  True False False
  True False False  True  True False False False  True False  True False
 False Fa

In [14]:
import flax.linen as nn
from flax.linen import make_causal_mask

unimodal_decoder_mask = nn.make_causal_mask(jnp.empty((targets.shape[0], targets.shape[1]+1))) # [B,1,L+1,L+1]
print(f"unimodal_decoder_mask shape: {unimodal_decoder_mask.shape}")

cls_mask = unimodal_decoder_mask[:,:,-1,:-1].squeeze() # [B,L]
new_cls_mask = jnp.where(targets == 0, 0, cls_mask) # [B,L]
new_cls_mask = jnp.pad(new_cls_mask, ((0,0),(0,1)), mode='constant', constant_values=1) # [B,L+1]
new_cls_mask = new_cls_mask[:,None,None,:] # [B,1,1,L+1]
new_cls_mask = jnp.pad(new_cls_mask, ((0,0),(0,0),(targets.shape[1],0),(0,0)), mode='constant', constant_values=0)
print(f"new_cls_mask shape: {new_cls_mask.shape}")
final_unimodal_decoder_mask = jnp.concatenate((unimodal_decoder_mask[:,:,:-1,:], new_cls_mask[:,:,-1:,:]), axis=-2)


unimodal_decoder_mask shape: (128, 1, 5, 5)
new_cls_mask shape: (128, 1, 5, 5)


In [15]:
idx = 1
shifted_targets[idx], targets[idx], unimodal_decoder_mask[idx], new_cls_mask[idx], final_unimodal_decoder_mask[idx]

(Array([ 0, 61, 20,  1], dtype=int32),
 Array([61, 20,  1,  0], dtype=int32),
 Array([[[1., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0.],
         [1., 1., 1., 0., 0.],
         [1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 1.]]], dtype=float32),
 Array([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [1., 1., 1., 0., 1.]]], dtype=float32),
 Array([[[1., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0.],
         [1., 1., 1., 0., 0.],
         [1., 1., 1., 1., 0.],
         [1., 1., 1., 0., 1.]]], dtype=float32))