In [20]:
import jax.numpy as jnp
from jax import random

# Set seed for reproducibility
seed = 0
key = random.PRNGKey(seed)

# Generate simulated targets
batch_size = 128
sequence_length = 64
embedding_size = 768

# Parameters for the special tokens
eos_id = 1
pad_id = 0

# Generating non-padded parts of the sequence with random IDs (excluding 0 and 1)
non_pad_tokens = random.randint(key, (batch_size, sequence_length), minval=2, maxval=100)

# Generate the ground truth eos indices
true_eos_indices = random.randint(key, (batch_size,), minval=1, maxval=sequence_length)
print("True eos indices:", true_eos_indices)

# Generate the targets by replacing the token at the eos index with the eos token and padding the rest
targets = jnp.array([jnp.concatenate((non_pad_tokens[i, :true_eos_indices[i]], jnp.array([eos_id]), jnp.full(sequence_length - true_eos_indices[i] - 1, pad_id))) for i in range(batch_size)])
print("Targets shape:", targets.shape)

# # Print an example
# example_idx = 5
# print(f"true_eos_indices[{example_idx}] : \n{true_eos_indices[example_idx]}")
# print(f"non_pad_tokens[{example_idx}] : \n{non_pad_tokens[example_idx]}")
# print(f"targets[{example_idx}] : \n{targets[example_idx]}")

True eos indices: [61 38 41 62 33 30 42 18  7 37 58 57 45 22 62  7 11 56 24 62 59 54 49 31
 63 19 26  7 24 56  9 30 36 57 53 55 53 23 44 43 32 50 23 14 57 40  4  2
 19 10 15  8 40 48 18 42 62 50 32 43 19 50 40 51 53 43 60 53 18 58 60 52
 32 62 47 36 43 12  6  5  3 46 54 56 59 16 44 17 63 29 10  6 47 23 17  2
 13 33 58 20  2 42 39 35 15 36 54 34  9 30 23  3 30 47 18 49 11 62 40 45
  2 55 57 38 13 24 34 57]
Targets shape: (128, 64)


In [22]:
# Generate simulated txt_encoded array
txt_encoded = random.normal(key, (batch_size, sequence_length, embedding_size))
print("Txt_encoded shape:", txt_encoded.shape)

# Your implementation to extract cls token features
eos_indices = jnp.where(targets == eos_id, size=128)  # This should now identify only the final eos tokens
contrastive_ztxt = txt_encoded[eos_indices[0], eos_indices[1], :]

# Print some outputs to verify
print("Indices of eos tokens:", eos_indices[1])
print("Shape of extracted features:", contrastive_ztxt.shape)

# Optional: Verify if the indices match the expected positions (all should be at the last index)
print("Check if the indices match the last position of each sequence:")
print(eos_indices[1] == true_eos_indices)


Txt_encoded shape: (128, 64, 768)
Indices of eos tokens: [61 38 41 62 33 30 42 18  7 37 58 57 45 22 62  7 11 56 24 62 59 54 49 31
 63 19 26  7 24 56  9 30 36 57 53 55 53 23 44 43 32 50 23 14 57 40  4  2
 19 10 15  8 40 48 18 42 62 50 32 43 19 50 40 51 53 43 60 53 18 58 60 52
 32 62 47 36 43 12  6  5  3 46 54 56 59 16 44 17 63 29 10  6 47 23 17  2
 13 33 58 20  2 42 39 35 15 36 54 34  9 30 23  3 30 47 18 49 11 62 40 45
  2 55 57 38 13 24 34 57]
Shape of extracted features: (128, 768)
Check if the indices match the last position of each sequence:
[ True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  Tru