In [None]:
!pip install -U "jax[cpu]"

In [None]:
!git clone https://github.com/google-deepmind/gemma.git

In [None]:
from huggingface_hub import snapshot_download

# Fetch the Gemma 2B Flax checkpoint and tokenizer from the Hugging Face Hub
# into a local directory (the returned local path is displayed as cell output).
# NOTE(review): the repo is presumably gated — an accepted license and an HF
# token (e.g. via `huggingface-cli login`) are likely required; confirm.
snapshot_download(
    local_dir="/local/path/gemma-2b-flax",
    repo_id="google/gemma-2b-flax",
)

In [None]:
import os

VARIANT = "2b"  # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}

# Root directory of the downloaded checkpoint; must match the `local_dir`
# passed to `snapshot_download`. The weights for the selected VARIANT are read
# from a sub-directory named after it, so the form parameter actually matters.
# NOTE(review): only the "2b" repo is downloaded in this notebook; other
# variants need their own download — confirm the matching repo id/layout.
CKPT_ROOT = "/local/path/gemma-2b-flax"
ckpt_path = os.path.join(CKPT_ROOT, VARIANT, "")  # keep the trailing separator
vocab_path = os.path.join(CKPT_ROOT, "tokenizer.model")

In [None]:
# Read the checkpoint at `ckpt_path` into a nested parameter dict. Later cells
# hand its "transformer" sub-tree to the model and the sampler.
from gemma import params as params_lib

params = params_lib.load_and_format_params(
    ckpt_path,
)

In [None]:
import sentencepiece as spm

# SentencePiece tokenizer stored next to the checkpoint; used both by the
# sampler and for manual encoding when building the golden data.
vocab = spm.SentencePieceProcessor()
vocab.Load(model_file=vocab_path)

In [None]:
# `transformer_lib.TransformerConfig.from_params` derives the model
# configuration from the loaded checkpoint, so nothing variant-specific is
# hard-coded here. Note that the vocabulary size is smaller than the number of
# input embeddings because this release contains unused tokens.
from gemma import transformer as transformer_lib

config_2b = transformer_lib.TransformerConfig.from_params(
    params,
    # Number of time steps the transformer's attention cache can hold.
    # NOTE(review): presumably this must be >= prompt tokens + generated
    # tokens used by the sampler — confirm.
    cache_size=30,
)
model_2b = transformer_lib.Transformer(config=config_2b)

In [None]:
from gemma import sampler as sampler_lib

# Bundle tokenizer, weights and model into a text-in / text-out sampler.
# Only the "transformer" sub-tree of the checkpoint is used as weights.
sampler = sampler_lib.Sampler(
    vocab=vocab,
    params=params["transformer"],
    transformer=model_2b,
)

In [None]:
prompt_texts = ["I love to", "Today is a", "What is the"]

# Generate a short continuation for every prompt in one batched call.
out_data = sampler(
    input_strings=prompt_texts,
    total_generation_steps=6,  # number of steps performed when generating
)

# Show each prompt next to its completion, separated by a rule of '#'.
for prompt, completion in zip(prompt_texts, out_data.text):
  print(f"Prompt:\n{prompt}\nOutput:\n{completion}")
  print()
  print("#" * 10)

In [None]:
import jax


def get_attention_mask_and_positions(
    example: jax.Array,
    pad_id: int,
) -> tuple[jax.Array, jax.Array]:
  """Derives token positions and a causal attention mask from token ids.

  Args:
    example: Token ids. Entries equal to `pad_id` are treated as padding.
      Presumably shaped [batch, seq_len] to match the transformer — confirm.
    pad_id: Integer id of the padding token.

  Returns:
    A `(positions, attention_mask)` tuple, both computed by `transformer_lib`
    from the boolean "is not padding" mask.
  """
  is_real_token = example != pad_id
  return (
      transformer_lib.build_positions_from_mask(is_real_token),
      transformer_lib.make_causal_attn_mask(is_real_token),
  )

In [None]:
import numpy as np
import jax.numpy as jnp
from gemma import transformer as transformer_lib
import jsonlines

# Build golden data: for every prompt, run one full forward pass and record the
# tokens, the sampler's completion and the per-position logits.
# `params` is reused from the loading cell; re-reading the checkpoint here was
# redundant and expensive.
output_path = "golden_data_gemma-2b.jsonl"
all_data_to_save = []

# `pad_id` must be called: `vocab.pad_id` on its own is the bound method, and
# comparing token ids against a method object does not identify padding.
pad_id = vocab.pad_id()

for prompt_text, completion in zip(prompt_texts, out_data.text):
  # Prepend the BOS token (id 2 for the Gemma tokenizer), then the prompt ids.
  tokens = [vocab.bos_id()] + vocab.encode(prompt_text)
  # Add a batch dimension; positions and mask are derived from the same
  # batched array that is fed to the model so their leading dims agree.
  input_batch = jnp.expand_dims(np.array(tokens), axis=0)
  positions, attention_mask = get_attention_mask_and_positions(input_batch, pad_id)
  print(f"{input_batch=}, {positions=}, {attention_mask=}")

  # Forward pass over the whole prompt at once.
  # No attention cache is needed here.
  logits, _ = model_2b.apply(
      {"params": params["transformer"]},
      input_batch,
      positions,
      None,  # Attention cache is None.
      attention_mask,
  )
  # Full logits span the whole vocabulary per position; show only the shape.
  print(f"{logits.shape=}")

  all_data_to_save.append({
      "prompt": prompt_text,
      "completion": completion,
      "tokens": tokens,
      # Drop the batch dim; tolist() makes the array JSON-serializable.
      "logits": logits[0].tolist(),
  })

In [None]:
# Persist the records as JSON Lines: one JSON object per prompt.
with jsonlines.open(output_path, mode="w") as writer:
  writer.write_all(all_data_to_save)

print(f"Data saved to {output_path}")