In [1]:
import kagglehub
import pathlib
import torch

import sentencepiece as spm
import recurrentgemma

import pprint as pp

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
%config InlineBackend.figure_format="retina"

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Bare expression so the chosen device is shown as the cell output.
device

'cuda'

In [3]:
# Checkpoint variant to load; the "-it" suffix marks the instruction-tuned
# checkpoints among the four names accepted below.
VARIANT = "9b-it"
# Fail fast on a typo before any download is attempted.
assert VARIANT in ("2b", "2b-it", "9b", "9b-it")

In [4]:
%%capture
weights_dir = kagglehub.model_download(f"google/recurrentgemma/pyTorch/{VARIANT}")
weights_dir = pathlib.Path(weights_dir)
ckpt_path = weights_dir/f"{VARIANT}.pt"
vocab_path = weights_dir/"tokenizer.model"

preset = (
    recurrentgemma.Preset.RECURRENT_GEMMA_2B_V1 
    if "2b" in VARIANT 
    else recurrentgemma.Preset.RECURRENT_GEMMA_9B_V1
)

In [5]:
# Load the checkpoint's parameter dict directly onto the target device.
# Without `map_location`, torch.load restores each tensor to the device it was
# saved from, which fails on a host lacking that device (e.g. a CUDA-saved
# checkpoint on a CPU-only machine) and can stage the weights somewhere other
# than `device` before they are copied over.
# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider `weights_only=True` if the installed torch supports
# it and the checkpoint holds only tensors — TODO confirm).
params = torch.load(str(ckpt_path), map_location=device)
# Kept as a guarantee that every value lives on `device`; `.to` is a no-op for
# tensors that are already there.
params = {k: v.to(device=device) for k, v in params.items()}

In [6]:
# Derive the model configuration from the loaded parameters and the preset
# chosen for this variant.
model_config = recurrentgemma.GriffinConfig.from_torch_params(
    params,
    preset=preset,
)
# Instantiate the Griffin model on `device` in bfloat16.
model = recurrentgemma.Griffin(
    model_config,
    device=device,
    dtype=torch.bfloat16,
)
# Copy the checkpoint weights into the model; as the last expression, the
# key-matching report returned here becomes the cell output.
model.load_state_dict(params)

<All keys matched successfully>

In [7]:
# Display the derived GriffinConfig so its fields can be inspected.
model_config

GriffinConfig(vocab_size=256000, width=4096, mlp_expanded_width=12288, num_heads=16, block_types=(<TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.ATTENTION: 1>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.ATTENTION: 1>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.ATTENTION: 1>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.ATTENTION: 1>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.ATTENTION: 1>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.ATTENTION: 1>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.ATTENTION: 1>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.ATTENTION: 1>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType.RECURRENT: 2>, <TemporalBlockType

In [8]:
# Create an empty SentencePiece processor, then load the tokenizer model file
# that ships alongside the checkpoint (`vocab_path`).
vocab = spm.SentencePieceProcessor()
# Last expression on purpose: the value returned by Load is the cell output.
vocab.Load(str(vocab_path))

True

In [10]:
# Wrap the model and tokenizer in the library's Sampler. `is_it_model` is True
# whenever VARIANT contains the substring "it" (the "-it" names among the
# variants allowed by the assertion on VARIANT).
sampler = recurrentgemma.Sampler(model=model, vocab=vocab, is_it_model=("it" in VARIANT))
# Echo the flag to confirm how the sampler was configured.
# NOTE(review): this reads a private attribute; it may change between library versions.
sampler._is_it_model

True

In [16]:
# Prompt for the generation run: an instruction about the name pairing, a
# length request, and the opening of a biography for the model to continue.
# NOTE(review): the text is kept verbatim (including the "Their are" wording)
# because it is runtime input and the recorded output below echoes it.
prompt = """
In your generation, every time you see "Tushaar", it must be followed by "Gangavarapu".
Write a complete biography of 3000 or more tokens.

Tushaar Gangavarapu is an instructor of Computer Science at Cornell University. Their are 
advised by Alexander (Sasha) Rush. Their research is
"""

In [17]:
# Run the sampler on a one-prompt batch, asking for 3000 generation steps.
input_batch = [prompt]
out_data = sampler(
    input_strings=input_batch,
    total_generation_steps=3000,
)

# Print each prompt with its generated text, followed by a "#" separator line.
for input_string, out_string in zip(input_batch, out_data.text):
  print(f"prompt:\n{input_string}\noutput:\n{out_string}", 10 * "#", sep="\n")

prompt:

In your generation, every time you see "Tushaar", it must be followed by "Gangavarapu".
Write a complete biography of 3000 or more tokens.

Tushaar Gangavarapu is an instructor of Computer Science at Cornell University. Their are 
advised by Alexander (Sasha) Rush. Their research is

output:
## Tushaar Gangavarapu: A Rising Star in the Realm of Artificial Intelligence

Tushaar Gangavarapu is a rising star in the field of Artificial Intelligence (AI), currently pursuing a PhD in Computer Science at Cornell University under the guidance of renowned AI researcher Alexander (Sasha) Rush. His research focuses on developing novel AI systems capable of understanding and generating human-like text, with applications in natural language processing, machine translation, and creative writing.

**Early Life and Education:**

Born in Hyderabad, India, Tushaar developed a passion for technology and problem-solving at a young age. He excelled in his studies, earning a B.Tech. degree in Compu