<a href="https://colab.research.google.com/github/USLF2025/chrome-extensions-samples/blob/main/colabs/sampling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Sampling

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/sampling.ipynb)

An example of how to load a Gemma model and run inference with it.

The Gemma library has 3 ways to prompt a model:

* `gm.text.ChatSampler`: Easiest to use: simply talk to the model and get an answer. Supports multi-turn conversations out of the box.
* `gm.text.Sampler`: Lower level, but gives more control. The chat state has to be handled manually for multi-turn conversations.
* `model.apply`: Calls the model directly, and only predicts a single token.

In [2]:
!pip install -q gemma

In [3]:
# Common imports
import os
import jax
import jax.numpy as jnp

# Gemma imports
from gemma import gm

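By default, JAX preallocates 75% of the total GPU memory when the first JAX operation runs. This can be overridden with the `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variable, which must be set before that first operation. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):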

In [4]:
# Let JAX preallocate 100% of GPU memory instead of the default 75%.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

Load the model and the params. Here we load the instruction-tuned version of the model.

In [20]:
model = gm.nn.Gemma3_4B()

params = gm.ckpts.load_params(str(gm.ckpts.CheckpointPath.GEMMA3_4B_IT))



TypeError: 'StepMetadata' object is not iterable

In [23]:
from gemma import gm

# Attempt to load the Gemma 7B instruction-tuned model with a Hugging Face-style API.
# `gemma.gm` has no `GemmaForCausalLM` class, so this raises the AttributeError below;
# the library loads models via `gm.nn.*` classes and `gm.ckpts.load_params` instead.
model = gm.GemmaForCausalLM.from_pretrained('gemma-7b-it')

AttributeError: module 'gemma.gm' has no attribute 'GemmaForCausalLM'

In [25]:
from gemma import gm
print(dir(gm))

['__builtins__', '__cached__', '__dir__', '__doc__', '__file__', '__getattr__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_epy', 'ckpts', 'ckpts', 'data', 'data', 'evals', 'losses', 'math', 'math', 'nn', 'nn', 'peft', 'sharding', 'testing', 'text', 'text', 'tools', 'typing', 'typing', 'utils', 'vision']


In [27]:
import gemma

# Same attempt through the top-level package. `gemma` rejects this access pattern
# and asks for `from gemma import gm` instead (see the error below).
model = gemma.GemmaForCausalLM.from_pretrained('gemma-7b-it')

AttributeError: Please use "from gemma import gm", NOT "import gemma as gm".

In [30]:
company_info_text = """COMPANY NAME : US LOGISTICS AND FREIGT FORWARDING LLC. 96063 ESTATE DRIVE YULEE, FL 32097- CORPORATE HQ  ; 260 PEACHTREE STEET ATLANTA, GA  SUITE 2109 30303 AND I AM THE FOUNDER ANTHONY RODRIGUEZ WITH STEPHANIE RODRIGUEZ IS PRINCIPLE OWNER 60-40 SPLIT WOMEN OWNED MINORITY COMPANY"""

# Extracting information (simplified - regex or NLP would be used in a real scenario)
company_name = "US LOGISTICS AND FREIGT FORWARDING LLC."
corporate_hq_address = "96063 ESTATE DRIVE YULEE, FL 32097"
atlanta_address = "260 PEACHTREE STEET ATLANTA, GA SUITE 2109 30303"
founder_name = "ANTHONY RODRIGUEZ"
principal_owner_name = "STEPHANIE RODRIGUEZ"
ownership_split = "60-40"
company_type = "WOMEN OWNED MINORITY COMPANY"

structured_company_data = {
    "company_name": company_name,
    "addresses": {
        "corporate_hq": corporate_hq_address,
        "atlanta_office": atlanta_address
    },
    "key_personnel": {
        "founder": founder_name,
        "principal_owner": principal_owner_name
    },
    "ownership": {
        "split": ownership_split,
        "type": company_type
    }
}

import json

# Convert the dictionary to a JSON string
structured_company_data_json = json.dumps(structured_company_data, indent=4)

print(structured_company_data_json)

{
    "company_name": "US LOGISTICS AND FREIGT FORWARDING LLC.",
    "addresses": {
        "corporate_hq": "96063 ESTATE DRIVE YULEE, FL 32097",
        "atlanta_office": "260 PEACHTREE STEET ATLANTA, GA SUITE 2109 30303"
    },
    "key_personnel": {
        "founder": "ANTHONY RODRIGUEZ",
        "principal_owner": "STEPHANIE RODRIGUEZ"
    },
    "ownership": {
        "split": "60-40",
        "type": "WOMEN OWNED MINORITY COMPANY"
    }
}


### How the Model Receives Data

*   **Unstructured Text:** The cleaned unstructured text can be used in several ways:
    *   As part of the training data for fine-tuning, allowing the model to learn the language, context, and general information within the document.
    *   As input for prompting, where you might ask the model questions directly about the text content.
*   **Structured Data:** Structured data provides the model with explicit facts and relationships. This can be used:
    *   To create specific training examples (e.g., "What is the company's key service?" -> "US Import/Export Drayage").
    *   In combination with unstructured text, where the model uses the structured data to ground its responses in factual information extracted from the text.
    *   For tasks requiring precise information retrieval or analysis.

By using both formats, we can leverage the richness of unstructured text and the precision of structured data to build a more capable and accurate knowledge model.
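A minimal sketch of the grounding idea, assuming the structured facts are serialized as JSON and placed in the user turn ahead of the question. The helper `build_grounded_prompt` and the sample facts are hypothetical; the `<start_of_turn>` / `<end_of_turn>` layout follows the prompt format used with `gm.text.Sampler` in the "Sample a prompt" section.

```python
import json


def build_grounded_prompt(facts: dict, question: str) -> str:
    """Builds a Gemma-style single-turn prompt that grounds `question` in `facts`.

    Hypothetical helper: serializes the structured data as JSON and asks the
    model to answer only from those facts.
    """
    facts_json = json.dumps(facts, indent=2)
    return (
        "<start_of_turn>user\n"
        "Answer using only the facts below.\n"
        f"Facts:\n{facts_json}\n\n"
        f"Question: {question}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )


# Hypothetical structured facts, mirroring the structured data example.
facts = {
    "company_name": "Example Logistics Company",
    "key_service_highlight": "US Import/Export Drayage",
}
prompt = build_grounded_prompt(facts, "What is the company's key service?")
print(prompt)
```

The resulting string can be passed to `sampler.sample(prompt, ...)`. Putting the facts in the prompt does not guarantee the model stays faithful to them, so answers should still be checked.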

In [31]:
# This is a hypothetical example. In a real scenario, you would use
# more advanced techniques (like natural language processing) to extract
# specific information reliably.

# Let's assume the company name is mentioned early in the text
# We'll just take a placeholder for demonstration
company_name = "Example Logistics Company"

# Let's assume we identify a key service mentioned
key_service = "US Import/Export Drayage"

structured_data_example = {
    "company_name": company_name,
    "key_service_highlight": key_service,
    "source_document": "company insight.pdf"
}

import json

# Convert the dictionary to a JSON string
structured_data_json = json.dumps(structured_data_example, indent=4)

print(structured_data_json)

{
    "company_name": "Example Logistics Company",
    "key_service_highlight": "US Import/Export Drayage",
    "source_document": "company insight.pdf"
}


### Creating Simple Structured Data

Now, let's imagine we want to extract a specific piece of information from the text, like the company's name or a key service offered, and represent it in a structured format like a Python dictionary (which can be easily converted to JSON). This is a simplified example of how we would create structured data from unstructured text.
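A slightly less hard-coded version of that extraction can be sketched with Python's `re` module. The sample text, field names and patterns below are hypothetical and tuned to one simple "Label: value" layout; real documents usually need more robust parsing or an NLP model.

```python
import json
import re

# Hypothetical unstructured snippet, e.g. text extracted from a PDF.
sample_text = (
    "Company: Example Logistics Company\n"
    "Key service:   US Import/Export Drayage\n"
    "Headquarters: Springfield"
)

# One pattern per field; each captures the rest of the labelled line.
FIELD_PATTERNS = {
    "company_name": r"^Company:\s*(.+)$",
    "key_service_highlight": r"^Key service:\s*(.+)$",
    "headquarters": r"^Headquarters:\s*(.+)$",
}


def extract_fields(text: str) -> dict:
    """Returns a dict of the fields whose pattern matched, with values stripped."""
    data = {}
    for field, pattern in FIELD_PATTERNS.items():
        # MULTILINE makes ^ and $ match at each line boundary.
        match = re.search(pattern, text, flags=re.MULTILINE)
        if match:
            data[field] = match.group(1).strip()
    return data


structured = extract_fields(sample_text)
print(json.dumps(structured, indent=4))
```

Fields whose label is missing are simply left out of the dictionary, which keeps the output honest about what was actually found in the text.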


In [32]:
import re

def clean_text(text):
    """Performs basic cleaning on text data."""
    # Remove excessive whitespace and newlines
    text = re.sub(r'\s+', ' ', text).strip()
    # You might add more cleaning steps here, like removing special characters
    # text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
    # Convert to lowercase (optional, depending on the model)
    # text = text.lower()
    return text

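# `company_insight_text` must hold the raw text extracted from the PDF. That
# extraction step is not part of this notebook, hence the NameError below.
cleaned_company_insight_text = clean_text(company_insight_text)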

# Print the first 500 characters of the cleaned text
print(cleaned_company_insight_text[:500])

NameError: name 'company_insight_text' is not defined

### Cleaning Unstructured Text

First, let's perform some basic cleaning on the extracted text from the PDF. This typically involves removing unnecessary whitespace, special characters, and potentially converting text to lowercase for consistency.

In [18]:
import gemma.jax as gjax

model, params = gjax.GemmaForCausalLM.from_pretrained(
    'gemma_3b_it',
    # device_sharding=jax.sharding.PartitionSpec('dp', 'sp'), # Optional: for distributed training
    # param_sharding=jax.sharding.PartitionSpec('dp', 'sp'),  # Optional: for distributed training
)

ModuleNotFoundError: No module named 'gemma.jax'

In [17]:
model = gm.nn.Gemma3_4B()

params = gm.ckpts.load_params(params_path)



TypeError: 'StepMetadata' object is not iterable

In [15]:
from gemma import config as gm_config
from gemma import transformer as gm_transformer
from gemma import checkpoint as gm_checkpoint

# Define the model configuration
model_config = gm_config.GemmaConfig.from_json(
    '/usr/local/lib/python3.12/dist-packages/gemma/configs/config_3b_it.json'
)

# Instantiate the model
model = gm_transformer.Transformer(model_config)

# Load the parameters using CheckpointManager
checkpoint_manager = gm_checkpoint.CheckpointManager(
    '/usr/local/lib/python3.12/dist-packages/gemma/checkpoints/gemma_3b_it' # This path might need adjustment
)

params = checkpoint_manager.restore('params')

ImportError: cannot import name 'config' from 'gemma' (/usr/local/lib/python3.12/dist-packages/gemma/__init__.py)

In [9]:
import os
import tempfile

# Create a temporary directory to store checkpoints
ckpt_dir = tempfile.mkdtemp()

# Download the checkpoint files
!gsutil -m cp -r gs://gemma-jax/checkpoints/gemma_3b_it {ckpt_dir}

# Update the params path to the downloaded checkpoints
params_path = os.path.join(ckpt_dir, 'gemma_3b_it')

BucketNotFoundException: 404 gs://gemma-jax bucket does not exist.
CommandException: 1 file/object could not be transferred.


## Multi-turn conversations

The easiest way to chat with Gemma is to use the `gm.text.ChatSampler`. It hides the boilerplate of the conversation cache, as well as the `<start_of_turn>` / `<end_of_turn>` tokens used to format the conversation.

Here, we set `multi_turn=True` when creating `gm.text.ChatSampler` (by default, the `ChatSampler` starts a new conversation every time).

In multi-turn mode, you can erase the previous conversation state by passing `sampler.chat(..., multi_turn=False)`.
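To make that hidden boilerplate concrete, here is a sketch of how a multi-turn history can be flattened into the turn-token format by hand, which is what you would need to do when using `gm.text.Sampler` directly. It illustrates the prompt layout only and is not `ChatSampler`'s actual implementation (which also manages the conversation cache); the `format_conversation` helper and the placeholder model answer are hypothetical.

```python
def format_conversation(turns: list[tuple[str, str]]) -> str:
    """Formats (role, text) pairs as a Gemma prompt, ending with an open model turn.

    `role` must be "user" or "model". The returned string ends with
    "<start_of_turn>model\n" so the model's next answer continues from there.
    """
    parts = []
    for role, text in turns:
        if role not in ("user", "model"):
            raise ValueError(f"Unknown role: {role!r}")
        parts.append(f"<start_of_turn>{role}\n{text}<end_of_turn>\n")
    parts.append("<start_of_turn>model\n")
    return "".join(parts)


history = [
    ("user", 'Share one metaphor linking "shadow" and "laughter".'),
    ("model", "Laughter is the shadow that joy casts."),  # Placeholder answer.
    ("user", "Expand it in a haiku."),
]
print(format_conversation(history))
```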

In [None]:
sampler = gm.text.ChatSampler(
    model=model,
    params=params,
    multi_turn=True,
    print_stream=True,  # Print output as it is generated.
)

turn0 = sampler.chat('Share one metaphor linking "shadow" and "laughter".')

In [None]:
turn1 = sampler.chat('Expand it in a haiku.')

Note: By default (`multi_turn=False`), the conversation state is reset every time, but you can still continue the previous conversation by passing `sampler.chat(..., multi_turn=True)`.

By default, greedy decoding is used. You can pass a custom sampling method with the `sampling=` keyword argument:

* `gm.text.Greedy()`: (default) Greedy decoding
* `gm.text.RandomSampling()`: Simple random sampling with temperature, for more variety
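The difference between the two strategies can be sketched on a toy logits vector with plain JAX. This illustrates the general technique (argmax versus temperature-scaled categorical sampling) and is not the code behind `gm.text.Greedy` or `gm.text.RandomSampling`; the logits values and the function names `greedy` and `random_sample` are made up for the example.

```python
import jax
import jax.numpy as jnp


def greedy(logits: jax.Array) -> jax.Array:
    """Greedy decoding: always pick the highest-scoring token."""
    return jnp.argmax(logits, axis=-1)


def random_sample(logits: jax.Array, rng: jax.Array, temperature: float = 1.0) -> jax.Array:
    """Temperature sampling (temperature > 0); higher values flatten the distribution."""
    return jax.random.categorical(rng, logits / temperature, axis=-1)


# Toy logits over a 4-token vocabulary (made-up values).
logits = jnp.array([2.0, 1.0, 0.5, -1.0])

print("greedy:", int(greedy(logits)))  # Always token 0 for these logits.

# Draw several samples with different keys to see the variety.
keys = jax.random.split(jax.random.key(0), 8)
samples = jax.vmap(lambda k: random_sample(logits, k, temperature=1.5))(keys)
print("random:", samples.tolist())
```

Greedy decoding is deterministic for a given prompt, while random sampling trades some of that predictability for more varied outputs.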

## Sample a prompt

For more control, we also provide `gm.text.Sampler`, which still performs efficient sampling (with KV caching, early stopping, etc.).

Prompting the sampler requires formatting the prompt correctly with the `<start_of_turn>` / `<end_of_turn>` tokens (see the custom tokens section of the [tokenizer](https://gemma-llm.readthedocs.io/en/latest/tokenizer.html) documentation).

In [None]:
sampler = gm.text.Sampler(
    model=model,
    params=params,
)

prompt = """<start_of_turn>user
Give me a list of inspirational quotes.<end_of_turn>
<start_of_turn>model
"""

out = sampler.sample(prompt, max_new_tokens=1000)
print(out)

## Use the model directly

Here's an example of predicting a single token, directly calling the model.

The model input expects encoded tokens, so we first need to encode the prompt with our tokenizer. See our [tokenizer](https://gemma-llm.readthedocs.io/en/latest/tokenizer.html) documentation for more information on using the tokenizer.

In [None]:
tokenizer = gm.text.Gemma3Tokenizer()

Note: When encoding the prompt, don't forget to add the beginning-of-sequence (BOS) token with `add_bos=True`. All prompts fed to the model should start with this token.

In [None]:
prompt = tokenizer.encode('One word to describe Paris: \n\n', add_bos=True)
prompt = jnp.asarray(prompt)

We can then call the model and get the predicted logits.

In [None]:
# Run the model
out = model.apply(
    {'params': params},
    tokens=prompt,
    return_last_only=True,  # Only predict the last token
)


# Sample a token from the predicted logits
next_token = jax.random.categorical(
    jax.random.key(1),
    out.logits
)
tokenizer.decode(next_token)

You can also plot the next-token probabilities.

In [None]:
tokenizer.plot_logits(out.logits)
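If you want the numbers rather than a plot, the logits can be turned into probabilities with a softmax and the most likely candidates read off with `jax.lax.top_k`. A sketch, assuming `out.logits` is a 1-D vector over the vocabulary (as with `return_last_only=True` and the unbatched prompt above); toy values stand in for it so the block runs on its own.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for `out.logits`: a 1-D vector of scores over a 5-token vocabulary.
logits = jnp.array([3.0, 1.0, 0.2, 2.5, -1.0])

# Softmax turns unnormalized logits into a probability distribution.
probs = jax.nn.softmax(logits)

# The k most likely token ids and their probabilities, highest first.
top_probs, top_ids = jax.lax.top_k(probs, 3)
for token_id, p in zip(top_ids.tolist(), top_probs.tolist()):
    # With the real model, `tokenizer.decode(token_id)` would show the token text.
    print(f"token {token_id}: {p:.3f}")
```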

## Next steps

* See our [multimodal](https://gemma-llm.readthedocs.io/en/latest/multimodal.html) example to query the model with images.
* See our [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) example to train Gemma on your custom task.
* See our [tool use](https://gemma-llm.readthedocs.io/en/latest/tool_use.html) tutorial to extend Gemma with external tools.
