In [None]:
!pip install -U "jax[cpu]"

In [None]:
# Clone the google-deepmind/gemma repository into the working directory.
# NOTE(review): cloning alone does not install the package; presumably the
# `gemma` module imported by later cells is made importable separately
# (e.g. `pip install ./gemma`) — confirm.
!git clone https://github.com/google-deepmind/gemma.git

In [1]:
import os

# Colab form field. The option list must contain the selected value, so the
# larger variants are listed alongside the original choices.
VARIANT = "27b"  # @param ['2b', '2b-it', '7b', '7b-it', '9b', '9b-it', '27b', '27b-it'] {type:"string"}


# Checkpoint and tokenizer locations. The defaults are the paths this notebook
# was originally run with; set GEMMA_CKPT_PATH / GEMMA_VOCAB_PATH to point at
# a different location without editing the notebook.
ckpt_path = os.environ.get(
    "GEMMA_CKPT_PATH", "/home/zhaoyuec/data/gemma2/gemma2-27b/ckpt/"
)
vocab_path = os.environ.get(
    "GEMMA_VOCAB_PATH", "/home/zhaoyuec/data/gemma2/gemma2-27b/tokenizer.model"
)

In [2]:
# Load the checkpoint found at `ckpt_path` through the library helper. Later
# cells read the "transformer" entry of the result as the model weights.
from gemma import params as params_lib

params = params_lib.load_and_format_params(ckpt_path)

In [3]:
import sentencepiece as spm

# Load the SentencePiece tokenizer model from `vocab_path`. The value returned
# by `Load` is the last expression of the cell, so it is echoed as the output.
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)

True

In [4]:
# We use the `transformer_lib.TransformerConfig.from_params` function to
# automatically load the correct configuration from a checkpoint. Note that the
# vocabulary size is smaller than the number of input embeddings due to unused
# tokens in this release.

from gemma import transformer as transformer_lib

config_27b = transformer_lib.TransformerConfig.from_params(
    params, cache_size=30  # Number of time steps in the transformer's cache
)
# Build the model object from the derived configuration. The weights are not
# bound here; later cells pass them explicitly when applying the model.
model_27b = transformer_lib.Transformer(config=config_27b)

gemma2 27b


In [5]:
from gemma import sampler as sampler_lib
# Create a sampler with the right param shapes: it is given the model, the
# tokenizer and the "transformer" entry of the loaded checkpoint.
# NOTE(review): in the visible cells the only use of `sampler` is the
# commented-out sampling snippet in the next cell — confirm it is still needed.
sampler = sampler_lib.Sampler(
    transformer=model_27b,
    vocab=vocab,
    params=params["transformer"],
)

In [6]:
# Prompts that are tokenized and run through the model in the cells below.
prompt_texts = ["I love to", "Today is a", "What is the"]
# Single-prompt alternative, kept for quick runs:
# prompt_texts = ["I love to"]

# Disabled: sample a short completion for every prompt with the sampler and
# print prompt/output pairs. Uncomment to run.
# out_data = sampler(
#     input_strings=prompt_texts,
#     total_generation_steps=10,  # number of steps performed when generating
#   )

# for input_string, out_string in zip(prompt_texts, out_data.text):
#   print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
#   print()
#   print(10*'#')

In [8]:
import jax


def get_attention_mask_and_positions(
    example: jax.Array,
    pad_id: int,
) -> tuple[jax.Array, jax.Array]:
  """Derives token positions and a causal attention mask from token ids.

  Args:
    example: Token ids; entries equal to `pad_id` are treated as padding.
    pad_id: Id of the padding token.

  Returns:
    A `(positions, attention_mask)` pair. Note the order: positions come
    first, despite the function name. Both are computed by `transformer_lib`
    from the mask of non-padding tokens.
  """
  is_real_token = example != pad_id
  return (
      transformer_lib.build_positions_from_mask(is_real_token),
      transformer_lib.make_causal_attn_mask(is_real_token),
  )

In [9]:
import numpy as np
import jax.numpy as jnp
from gemma import transformer as transformer_lib
import jsonlines

# `params` is reused from the parameter-loading cell earlier in the notebook;
# loading the checkpoint a second time here would only repeat an expensive
# read of the same weights.

output_path = "golden_data_gemma2-27b.jsonl"
all_data_to_save = []

# SentencePiece exposes special-token ids as methods, so they must be called.
# Comparing token ids against the bound method itself is always "not equal",
# which means padding would never be masked out.
pad_id = vocab.pad_id()
bos_id = vocab.bos_id()

for prompt_text in prompt_texts:
  # Tokenize once; the same ids feed the model and are stored in the record.
  token_ids = [bos_id] + vocab.encode(prompt_text)
  one_sample_input = np.array(token_ids)
  # Add the leading batch dimension expected by the model (batch size 1).
  expanded_one_sample_input = jnp.expand_dims(one_sample_input, axis=0)

  # Build the position and attention mask vectors from the batched tokens so
  # that they carry the same leading batch dimension as the model input.
  positions, attention_mask = get_attention_mask_and_positions(
      expanded_one_sample_input, pad_id
  )
  print(f"{expanded_one_sample_input=}, {positions=}, {attention_mask=}")

  # Forward pass on the input data.
  # No attention cache is needed here.
  logits, _ = model_27b.apply(
      {"params": params["transformer"]},
      expanded_one_sample_input,
      positions,
      None,  # Attention cache is None.
      attention_mask,
  )
  print(f"{logits=}")
  print(logits.shape)

  # Collect one JSON-serializable record per prompt.
  all_data_to_save.append({
      "prompt": prompt_text,
      # "completion": out_data.text[prompt_index],
      "tokens": token_ids,
      # Remove the batch dim, then tolist() for JSON serialization.
      "logits": logits[0].tolist(),
  })

expanded_one_sample_input=Array([[     2, 235285,   2182,    577]], dtype=int32), positions=Array([0, 1, 2, 3], dtype=int32), attention_mask=Array([[[ True, False, False, False],
        [ True,  True, False, False],
        [ True,  True,  True, False],
        [ True,  True,  True,  True]]], dtype=bool)
embed output (Array(1, dtype=int32, weak_type=True), Array(4, dtype=int32, weak_type=True), Array(4608, dtype=int32, weak_type=True)), 
value [[[ 0.32421875 -0.36914062 -0.328125   ...  0.47070312 -0.34179688
    0.26953125]
  [-0.91015625  0.11523438 -0.5703125  ...  0.5859375  -0.03149414
   -0.003479  ]
  [ 0.14746094  0.67578125 -0.5625     ... -1.3515625  -0.13769531
    1.3515625 ]
  [-0.45703125  0.5234375  -0.07470703 ... -0.15527344 -0.05932617
   -0.09619141]]]
test dtype float32
test dtype float32
test dtype float32
test dtype float32
test dtype float32
test dtype float32
test dtype float32
test dtype float32
test dtype float32
test dtype float32
test dtype float32
test dty

In [10]:
# Persist the collected records, one JSON object per line.
with jsonlines.open(output_path, "w") as writer:
  for record in all_data_to_save:
    writer.write(record)

print(f"Data saved to {output_path}")

Data saved to golden_data_gemma2-27b.jsonl
