In [1]:
import sys 
sys.path.append("/Users/aaronkilgallon/Downloads/archive/") 
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

from IPython.display import Markdown as md

In [2]:
# Configuration: which Gemma variant to run, on which device, and where the
# checkpoint files live.
VARIANT = "2b"
MACHINE_TYPE = "cpu"
# Directory holding `gemma-{VARIANT}.ckpt` and `tokenizer.model`. Set the
# GEMMA_WEIGHTS_DIR environment variable to point elsewhere instead of editing
# this machine-specific default path.
weights_dir = os.environ.get(
    "GEMMA_WEIGHTS_DIR", "/Users/aaronkilgallon/Downloads/archive/"
)

In [3]:
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Sets the default torch dtype to the given dtype."""
  torch.set_default_dtype(dtype)
  yield
  torch.set_default_dtype(torch.float)


In [4]:
# Model config: choose the architecture from VARIANT, then point the config at
# the tokenizer stored next to the checkpoint and set the quantisation flag.
if "2b" in VARIANT:
    model_config = get_config_for_2b()
else:
    model_config = get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")
model_config.quant = "quant" in VARIANT

In [14]:
# Show the resolved configuration (sizes, dtype, tokenizer path) for inspection.
print(f"{model_config}")

GemmaConfig(vocab_size=256000, max_position_embeddings=8192, num_hidden_layers=18, num_attention_heads=8, num_key_value_heads=1, hidden_size=2048, intermediate_size=16384, head_dim=256, rms_norm_eps=1e-06, dtype='bfloat16', quant=False, tokenizer='/Users/aaronkilgallon/Downloads/archive/tokenizer.model')


In [5]:
# Model: build it while the config's dtype is the torch default, so tensors
# created during construction default to that dtype. Then load the checkpoint
# and move the model to the target device in eval mode.
device = torch.device(MACHINE_TYPE)
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)
    model.load_weights(ckpt_path)
    model = model.to(device).eval()

  return self.fget.__get__(instance, owner)()


In [12]:
# Gemma chat turn markers: each turn is wrapped in <start_of_turn>ROLE ...
# <end_of_turn>.
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn>\n"

# Render a short user/model/user exchange, then open a model turn without
# closing it so that generation continues as the model's reply.
prompt = "".join(
    template.format(prompt=text)
    for template, text in (
        (USER_CHAT_TEMPLATE, "What is a good place for travel in the US?"),
        (MODEL_CHAT_TEMPLATE, "California."),
        (USER_CHAT_TEMPLATE, "What can I do there?"),
    )
) + "<start_of_turn>model\n"

print(prompt)

<start_of_turn>user
What is a good place for travel in the US?<end_of_turn>
<start_of_turn>model
California.<end_of_turn>
<start_of_turn>user
What can I do there?<end_of_turn>
<start_of_turn>model



In [13]:
# Dump the module tree to check the loaded architecture (layer count, submodules).
print(f"{model}")

GemmaForCausalLM(
  (embedder): Embedding()
  (model): GemmaModel(
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (qkv_proj): Linear()
          (o_proj): Linear()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear()
          (up_proj): Linear()
          (down_proj): Linear()
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (sampler): Sampler()
)


In [16]:
# Probe the model with a bare sequence of integer triples (no instruction) and
# see how it continues the pattern.
# generate() may sample (the model includes a Sampler module), so fix torch's
# RNG to make reruns of this cell reproducible.
SEED = 0
# Presumably caps the number of generated tokens. At 20 the captured reply is
# cut off mid-tuple, so raise this for a complete answer.
OUTPUT_LEN = 20

prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt="(0, 0, 0), (1, 0, 1), (2, 0, 2), (2, 0, 3), (3, 0, 4)"
    )
    + "<start_of_turn>model\n"
)

torch.manual_seed(SEED)
completion = model.generate(
    prompt,
    device=device,
    output_len=OUTPUT_LEN,
)
a = completion  # legacy name kept for any later cells that still read `a`

print(completion)

(0, 1, 0), (0, 1, 1), (1
