##### Copyright 2024 Google LLC.

In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Gemma Inference on TPUs
This notebook demonstrates how to use Google Colab's TPUs for inference with [Gemma](https://ai.google.dev/gemma), an open-weights large language model (LLM), using [Flax](https://github.com/google/flax).

<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/gemma_inference_on_tpu.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

### Connect to a TPU
- To connect to a TPU v2, click the **Connect TPU** button in the top right-hand corner of the screen.

Now you can see the TPU devices you have available:


In [None]:
import jax

jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
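
Before going further, it can help to confirm that the runtime really is backed by a TPU rather than a fallback backend. The following is a minimal sketch using `jax.devices()` and `jax.default_backend()`; the printed messages are illustrative and not part of the original notebook.

```python
import jax

devices = jax.devices()
backend = jax.default_backend()  # e.g. "tpu", "gpu", or "cpu"

print(f"Found {len(devices)} device(s) on backend '{backend}'.")
if backend != "tpu":
    # JAX uses another available backend when no TPU runtime is attached.
    print("No TPU detected; check the Colab runtime type before continuing.")
```

On the TPU v2 runtime shown above, the device count should match the eight entries in the listing.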

## Installation

- Installing Gemma requires Python 3.10 or higher.
- Google Colab's default runtime typically provides a recent Python 3 release; confirm that it is 3.10 or higher before installing.
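
The version requirement can be checked from inside the runtime. This is a minimal sketch using only the standard library; `MIN_PYTHON` and `python_is_supported` are hypothetical names introduced here for illustration.

```python
import sys

MIN_PYTHON = (3, 10)  # minimum version stated in the requirement above


def python_is_supported(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True when the (major, minor) version meets the minimum."""
    return tuple(version_info[:2]) >= minimum


current = ".".join(str(part) for part in sys.version_info[:3])
if python_is_supported():
    print(f"Python {current} meets the 3.10+ requirement.")
else:
    print(f"Python {current} is too old; switch to a newer runtime.")
```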

In [None]:
! pip install git+https://github.com/google-deepmind/gemma.git
! pip install --user kaggle

Collecting git+https://github.com/google-deepmind/gemma.git
  Cloning https://github.com/google-deepmind/gemma.git to /tmp/pip-req-build-vdzv6aiz
  Running command git clone --filter=blob:none --quiet https://github.com/google-deepmind/gemma.git /tmp/pip-req-build-vdzv6aiz
  Resolved https://github.com/google-deepmind/gemma.git to commit a24194737dcb54b7392091e9ba772aea1cb68ffb
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


## Downloading the Gemma Checkpoint

Before using [Google Gemma](https://ai.google.dev/gemma) for the first time, you must request access to the model through Kaggle, either by following the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup) or by completing the following steps:

1. Log in to [Kaggle](https://www.kaggle.com) or create a new Kaggle account if you don't already have one.
1. Go to the [Gemma model card](https://www.kaggle.com/models/google/gemma/), and click **Request Access**.
1. Complete the consent form and accept the terms and conditions.

To generate a Kaggle API key, open your [**Settings** page in Kaggle](https://www.kaggle.com/settings) and click **Create New Token**. This triggers the download of a `kaggle.json` file containing your API credentials.

Then, in Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.

Set environment variables for Kaggle API credentials.

In [None]:
import os
from google.colab import userdata

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
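
Outside Colab, where `google.colab.userdata` is not available, the same two environment variables can be filled from the downloaded `kaggle.json` file. The following is a minimal sketch, assuming the file holds `username` and `key` fields; `load_kaggle_credentials` is a hypothetical helper, and the path in the usage comment is only an example location.

```python
import json
import os
from pathlib import Path


def load_kaggle_credentials(path):
    """Hypothetical helper: copy credentials from kaggle.json into os.environ.

    Assumes the JSON file contains "username" and "key" entries.
    Returns the username so the caller can confirm which account was loaded.
    """
    creds = json.loads(Path(path).read_text())
    missing = [name for name in ("username", "key") if not creds.get(name)]
    if missing:
        raise ValueError(f"kaggle.json is missing fields: {missing}")
    os.environ["KAGGLE_USERNAME"] = creds["username"]
    os.environ["KAGGLE_KEY"] = creds["key"]
    return creds["username"]


# Example usage (the path is an assumption; adjust it to where you saved the file):
# load_kaggle_credentials(Path.home() / ".kaggle" / "kaggle.json")
```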

In [None]:
import kagglehub

VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')

ckpt_path = os.path.join(weights_dir, VARIANT)
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

In [None]:
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

## Start Generating with Your Model

Load and prepare your LLM's checkpoint for use with Flax.

In [None]:
# Load parameters
params = params_lib.load_and_format_params(ckpt_path)

Load your tokenizer, which you'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library.

In [None]:
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)

True

Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release.

In [None]:
transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024  # Number of time steps in the transformer's cache
)
transformer = transformer_lib.Transformer(transformer_config)

Finally, build a sampler on top of your model and your tokenizer.

In [None]:
# Create a sampler with the right param shapes.
sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)

You're ready to start sampling! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which adds latency to that call. For the fastest results, keep your batch size consistent across calls.
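
The effect of input shape on compilation can be seen with a toy function. This is a minimal sketch using plain `jax.jit`, not the Gemma sampler itself; `double` and `trace_log` are illustrative names. Python side effects inside a jitted function run only while JAX traces it, so the log grows only when a new input shape forces a fresh trace and compilation.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records the input shape each time JAX traces the function


@jax.jit
def double(x):
    # This append runs during tracing only, not on cached executions.
    trace_log.append(x.shape)
    return x * 2


double(jnp.ones((1, 8)))  # first call: traced and compiled for shape (1, 8)
double(jnp.ones((1, 8)))  # same shape: reuses the compiled version
double(jnp.ones((2, 8)))  # new batch size: traced and compiled again

print(trace_log)
```

The log ends up with two entries, one per distinct shape, which is why a fixed batch size avoids repeated compilation.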

In [None]:
input_batch = [
    "\n Explain the phenomenon of a solar eclipse.",
]

out_data = sampler(
    input_strings=input_batch,
    total_generation_steps=300,
)

for input_string, out_string in zip(input_batch, out_data.text):
  print(f"Prompt:\n{input_string}\n Answer:\n{out_string}")
  print()

Prompt:

 Explain the phenomenon of a solar eclipse.
 Answer:


A solar eclipse occurs when the Moon passes between the Sun and Earth, casting a shadow on Earth. This phenomenon is caused by the relative positions of the Moon, Sun, and Earth.

**Here's a step-by-step explanation of how a solar eclipse occurs:**

1. **New Moon:** The Moon is positioned between the Sun and Earth, and the Sun's rays are not directly visible from Earth.
2. **Waxing Crescent Phase:** As the Moon orbits the Sun, it gradually moves from the new moon phase to the waxing crescent phase. This means that the illuminated portion of the Moon is gradually increasing.
3. **First Quarter Phase:** When the Moon is at the first quarter phase, half of its face is illuminated.
4. **Waxing Gibbous Phase:** As the Moon continues to orbit the Sun, it moves further away from the Sun, and the illuminated portion of the Moon gradually increases to the waxing gibbous phase. This means that more and more of the Moon is illuminate