In [16]:
import os, sys, contextlib, gc
import torch
import sentencepiece as spm
from torch import nn
from pathlib import Path
from gemma.config import get_model_config
from gemma.model import GemmaModel, GemmaForCausalLM, gemma_config
from gemma.siglip_vision.config import get_siglip_vision_model_config

# Both artifacts live in the same checkpoint directory; name it once so the
# two paths cannot drift apart when the checkpoint location changes.
CHECKPOINT_DIR: Path = Path("gemma/gemma-3-1b-pt-checkpoint")
model_checkpoint_path: Path = CHECKPOINT_DIR / "model.ckpt"
tokenizer_path: Path = CHECKPOINT_DIR / "tokenizer.model"
# The tokenizer is not loaded in this cell; its path is assigned to
# `model_config.tokenizer` in the model-loading cell.


In [17]:
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

In [18]:
# Build the 1B Gemma configuration in bfloat16 and point it at the tokenizer file.
device = torch.device("cpu")
model_config = gemma_config.get_config_for_1b("bfloat16")
model_config.dtype = "bfloat16"
model_config.tokenizer = str(tokenizer_path)

# Construct the model while the config's dtype is the torch default dtype.
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    # map_location keeps every tensor on `device` regardless of where the
    # checkpoint was saved.
    # NOTE(review): assumes the checkpoint holds only tensor data under
    # "model_state_dict" — confirm it loads with weights_only=True.
    checkpoint = torch.load(
        model_checkpoint_path, map_location=device, weights_only=True
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    # Switch to evaluation mode and print the module tree.
    print(model.eval())


GemmaForCausalLM(
  (embedder): Embedding()
  (model): GemmaModel(
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): GemmaAttention(
          (qkv_proj): Linear()
          (o_proj): Linear()
          (query_norm): RMSNorm()
          (key_norm): RMSNorm()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear()
          (up_proj): Linear()
          (down_proj): Linear()
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
        (pre_feedforward_layernorm): RMSNorm()
        (post_feedforward_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (sampler): Sampler()
)


In [19]:
# Continue a fairy-tale opening on CPU. The call is the cell's last
# expression, so the returned text is displayed as the cell output.
# NOTE(review): temperature=0 with top_k=1 presumably selects greedy
# decoding — confirm against gemma's sampler.
story_opening = "Once upon a time, in a land far, far away, there lived a"
generation_options = dict(
    device="cpu",
    output_len=100,
    temperature=0,
    top_k=1,
)
model.generate(prompts=story_opening, **generation_options)

" little girl named Mary. She was a sweet, kind, and gentle girl who loved to play with her friends and make new friends. One day, Mary was playing in the park when she saw a beautiful butterfly. She was so excited to see such a beautiful creature and wanted to take it home with her. But, she didn't know how to catch a butterfly. So, she decided to ask her friend, Peter, for help.\n\nPeter was a brave and adventurous boy who loved to"