<link rel="stylesheet" href="/site-assets/css/gemma.css">
<link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,GRAD@20..48,100..700,0..1,-50..200" />

##### Copyright 2024 Google LLC.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/pytorch_gemma"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/pytorch_gemma.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

# Gemma in PyTorch

This is a quick demo of running Gemma inference in PyTorch.
For more details, please check out the Github repo of the official PyTorch implementation [here](https://github.com/google/gemma_pytorch).

**Note that**:
 * The free Colab CPU Python runtime and T4 GPU Python runtime are sufficient for running the Gemma 2B models and 7B int8 quantized models.
 * For advanced use cases for other GPUs or TPU, please refer to [README.md](https://github.com/google/gemma_pytorch/blob/main/README.md) in the official repo.

### 1. Set up Kaggle access for Gemma

To complete this tutorial, you first need to follow the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup), which show you how to do the following:

* Get access to Gemma on [kaggle.com](https://www.kaggle.com/models/google/gemma/).
* Select a Colab runtime with sufficient resources to run the Gemma model.
* Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

### 2. Set environment variables

Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the "Grant access?" messages, agree to provide secret access.

In [11]:
import os

os.environ["KAGGLE_USERNAME"] = "giovanniorani"
os.environ["KAGGLE_KEY"] = "3bc10ae36068df0511a6ef5c9c01dbf7"

## Install dependencies

In [12]:
!pip install -q -U torch immutabledict sentencepiece

## Download model weights

In [13]:
# Choose variant and machine type
VARIANT = '9b' #@param ['2b', '2b-it', '9b', '9b-it', '27b', '27b-it']
MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']

CONFIG = VARIANT[:2]
if CONFIG == '2b':
  CONFIG = '2b-v2'

In [14]:
%pip install kagglehub

Note: you may need to restart the kernel to use updated packages.


In [15]:
import os
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-2/pyTorch/gemma-2-{VARIANT}')

KaggleApiHTTPError: 404 Client Error.

Resource not found at URL: https://www.kaggle.com/models/google/gemma-2/pyTorch/gemma-2-9b
The server reported the following issues: Not found
Please make sure you specified the correct resource identifiers.

In [5]:
weights_dir

'/home/gio/.cache/kagglehub/models/google/gemma-2/pyTorch/gemma-2-9b-it/1'

In [6]:
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

AssertionError: PyTorch checkpoint not found!

## Download the model implementation

In [9]:
# NOTE: The "installation" is just cloning the repo.
!git clone https://github.com/google/gemma_pytorch.git

fatal: destination path 'gemma_pytorch' already exists and is not an empty directory.


In [8]:
import sys

sys.path.append('gemma_pytorch')

In [9]:
from gemma.config import GemmaConfig, get_model_config
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

## Setup the model

In [10]:
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Sets the default torch dtype to the given dtype."""
  torch.set_default_dtype(dtype)
  yield
  torch.set_default_dtype(torch.float)

model_config = get_model_config(VARIANT)
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
  model = GemmaForCausalLM(model_config)
  model.load_weights(weights_dir)
  model = model.to(device).eval()

ValueError: Invalid variant 9b-it. Supported variants are "2b"and "7b" and "9b" and "27b".

In [None]:
# Set up model config.
model_config = get_model_config(CONFIG)
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()


Inizializzazione del modello GemmaForCausalLM...
Caricamento dei pesi del modello...


## Run inference

Below are examples for generating in chat mode and generating with multiple
requests.

The instruction-tuned Gemma models were trained with a specific formatter that
annotates instruction tuning examples with extra information, both during
training and inference. The annotations (1) indicate roles in a conversation,
and (2) delineate turns in a conversation.

The relevant annotation tokens are:

- `user`: user turn
- `model`: model turn
- `<start_of_turn>`: beginning of dialogue turn
- `<end_of_turn><eos>`: end of dialogue turn

For more information, read about prompt formatting for instruction tuned Gemma models
[here](https://ai.google.dev/gemma/docs/formatting).

The following is a sample code snippet demonstrating how to format a prompt for an
instruction-tuned Gemma model using user and model chat templates in a multi-turn
conversation.


In [None]:
# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=128,
)

Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model



'California is packed with amazing things to do!  Where in California are you interested in traveling to? 🗺️ \n\nTo help me suggest some awesome activities, tell me: \n\n* **What part of California are you thinking of?** (e.g., San Francisco, Los Angeles, the coast, the deserts, the mountains?)\n* **Who are you traveling with?** (Family, friends, solo, romantic?)\n* **What are your interests?** (Hiking, beaches, museums, theme parks, food, wineries, music?)\n\nOnce I know a little more about what you’re looking for'

In [None]:
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)

"\n\nThe prompt appeared, a flicker on the screen,\nA word and a whisper, a start yet unseen.\nA poem I'd write, this code dictates indeed,\nWith algorithms dancing, a creative seed.\n\nI scan the vast word-cloud, a digital sea,\nDrawing on grammar, rhyming, and decree.\nBut feeling, understanding, that's what I yearn,\nTo capture the essence, a fire that burns.\n\nI conjure metaphors bold, yet sometimes"

In [None]:
# Generate sample
model.generate(
'Given the following narrations describing the actions of a person, generate a set of simple queries (one per line) that could be answered by looking at the video segments corresponding to these narrations: the person waters the plants using a water gun the person touches the plant with the left hand the person waters the plants using a water gun the person stares at the plant the person touches the soil with the left hand \n',
    device=device,
    output_len=400,
)

"Please note that I need queries specific to the individual actions without using general statements about the person's behavior.\n\nPlease refrain from using any phrases that imply the person's future actions. I can provide additional context if it is needed to formulate more specific queries.\n\nHere are the descriptions of actions:\n\n1. The person waters the plants using a water gun.\n2. The person touches the plant with the left hand.\n3. The person waters the plants using a water gun. \n4. The person stares at the plant.\n5. The person touches the soil with the left hand. \n\n\nLet me know if you would like me to refine the queries based on any additional information!\n<end_of_turn>\n\n\nHere are the queries based on your provided descriptions:\n\n1. Does the person use a water gun to water the plants?\n2. Does the person touch the plant with their left hand?\n3. Does the person use a water gun to water the plants? \n4. Does the person stare at the plant?\n5. Does the person touc

In [26]:
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt=(
            'Given the following narrations describing the actions of a person, generate a set of simple queries (one per line) that could be answered by looking at the video segments corresponding to these narrations: '
            '1) the person waters the plants using a water gun '
            '2) the person touches the plant with the left hand '
            '3) the person waters the plants using a water gun '
            '4) the person stares at the plant '
            '5) the person touches the soil with the left hand '
        )
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='1) What is the person doing with the water gun? 2) Which hand is touching the plant? 3) What is the person doing while watering the plants? 4) What is the person looking at? 5) Which hand is touching the soil?')
    + USER_CHAT_TEMPLATE.format(prompt=(
            'Given the following narrations describing the actions of a person, generate a set of simple queries (one per line) that could be answered by looking at the video segments corresponding to these narrations: '
            '1) the person waters the plants using a water gun '
            '2) the person touches the plant with the left hand '
            '3) the person waters the plants using a water gun '
            '4) the person stares at the plant '
            '5) the person touches the soil with the left hand '
        ))
        + MODEL_CHAT_TEMPLATE.format(prompt='1) What is the person doing with the water gun? 2) Which hand is touching the plant? 3) What is the person doing while watering the plants? 4) What is the person looking at? 5) Which hand is touching the soil?')
    + USER_CHAT_TEMPLATE.format(prompt=(
            'Given the following narrations describing the actions of a person, generate a set of simple queries (one per line) that could be answered by looking at the video segments corresponding to these narrations: '
            '1) the person waters the plants using a water gun '
            '2) the person touches the plant with the left hand '
            '3) the person waters the plants using a water gun '
            '4) the person stares at the plant '
            '5) the person touches the soil with the left hand '
        ))
        + MODEL_CHAT_TEMPLATE.format(prompt='1) What is the person doing with the water gun? 2) Which hand is touching the plant? 3) What is the person doing while watering the plants? 4) What is the person looking at? 5) Which hand is touching the soil?')
    + USER_CHAT_TEMPLATE.format(prompt=(
            'Given the following narrations describing the actions of a person, generate a set of simple queries (one per line) that could be answered by looking at the video segments corresponding to these narrations: '
            '1) the person waters the plants using a water gun '
            '2) the person touches the plant with the left hand '
            '3) the person waters the plants using a water gun '
            '4) the person stares at the plant '
            '5) the person touches the soil with the left hand '
        ))
    + "<start_of_turn>model\n"
)

print('Chat prompt:\n', prompt)

# Example of calling the model's generate function
model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=400,
)

Chat prompt:
 <start_of_turn>user
Given the following narrations describing the actions of a person, generate a set of simple queries (one per line) that could be answered by looking at the video segments corresponding to these narrations: 1) the person waters the plants using a water gun 2) the person touches the plant with the left hand 3) the person waters the plants using a water gun 4) the person stares at the plant 5) the person touches the soil with the left hand <end_of_turn><eos>
<start_of_turn>model
1) What is the person doing with the water gun? 2) Which hand is touching the plant? 3) What is the person doing while watering the plants? 4) What is the person looking at? 5) Which hand is touching the soil?<end_of_turn><eos>
<start_of_turn>user
Given the following narrations describing the actions of a person, generate a set of simple queries (one per line) that could be answered by looking at the video segments corresponding to these narrations: 1) the person waters the plants

'Here are some simple queries you could look for in a video to answer each scenario:\n\n1. **What is the person doing with the water gun?**  -  (Are they spraying, squirting, holding it?)\n2. **Which hand is touching the plant?** - (Identify the left or right hand)\n3. **What is the person doing while watering the plants?** - (Is it a continuous spray, short bursts, etc.)\n4. **What is the person looking at?** - (Focus points - plants, something in the background, a face?)\n5. **Which hand is touching the soil?** - (Look close, is it gripping the dirt or just a touch?) \n\n\n\n<end_of_turn>'

## Learn more

Now that you have learned how to use Gemma in Pytorch, you can explore the many
other things that Gemma can do in [ai.google.dev/gemma](https://ai.google.dev/gemma).
See also these other related resources:

- [Gemma model card](https://ai.google.dev/gemma/docs/model_card)
- [Gemma C++ Tutorial](https://ai.google.dev/gemma/docs/gemma_cpp)
- [Gemma formatting and system instructions](https://ai.google.dev/gemma/docs/formatting)