Copyright 2024 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

---

# Getting Started with Gemma Sampling: A Step-by-Step Guide

This colab is a detailed, step-by-step tutorial on how to load a Gemma checkpoint and sample from it.



## Installation

In [1]:
! pip install git+https://github.com/google-deepmind/gemma.git
! pip install --user kaggle
! pip install kagglehub
! pip install jax
! pip install orbax-checkpoint
! pip install chex
! pip install flax
! pip install sentencepiece



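If one of the installs above fails (a source build of a dependency can fail on some platforms), the later imports break with less obvious errors. The following is a minimal sketch for checking, before going further, that the modules this notebook relies on can be found. The helper name `missing_modules` and the module list are our own choices, not part of the Gemma library.

```python
import importlib.util

# Modules this notebook relies on, directly or through gemma (our own list; adjust as needed).
REQUIRED_MODULES = [
    "gemma", "kagglehub", "jax", "orbax.checkpoint",
    "chex", "flax", "sentencepiece",
]


def missing_modules(names):
    """Return the names from `names` that cannot be found in this environment."""
    missing = []
    for name in names:
        try:
            # find_spec returns None when a module is absent; for dotted names it
            # raises ModuleNotFoundError if the parent package itself is missing.
            found = importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:
            found = False
        if not found:
            missing.append(name)
    return missing


print("Missing modules:", missing_modules(REQUIRED_MODULES) or "none")
```

An empty result only means the modules can be located; it does not guarantee that compiled dependencies load correctly.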

## Downloading the checkpoint

To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:

1. Visit https://www.kaggle.com/ and create an account.
2. Go to your account settings, then the 'API' section.
3. Click 'Create new token' to download your key.

Then run the cell below.

In [2]:
import kagglehub
kagglehub.login()


If everything went well, you should see:
```
Kaggle credentials set.
Kaggle credentials successfully validated.
```
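Outside Colab, the interactive login widget may not render. As far as we know, `kagglehub` also accepts credentials through the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables (check the kagglehub documentation for your version). The token downloaded with 'Create new token' is a `kaggle.json` file holding a `username` and a `key`, so a minimal sketch, assuming that file layout, is to export its contents before calling `kagglehub.model_download`. The helper `export_kaggle_credentials` and its default path `~/.kaggle/kaggle.json` are our own choices.

```python
import json
import os
from pathlib import Path


def export_kaggle_credentials(token_path="~/.kaggle/kaggle.json"):
    """Read a Kaggle API token file and export it as environment variables.

    Assumes the JSON layout {"username": "...", "key": "..."} produced by
    Kaggle's 'Create new token' button. Returns the exported username.
    """
    path = Path(token_path).expanduser()
    with path.open(encoding="utf-8") as f:
        token = json.load(f)
    # Environment variable names we believe kagglehub reads; verify for your version.
    os.environ["KAGGLE_USERNAME"] = token["username"]
    os.environ["KAGGLE_KEY"] = token["key"]
    return token["username"]


# Usage (hypothetical location of the downloaded token):
# export_kaggle_credentials("~/Downloads/kaggle.json")
```

Environment variables keep the key out of the notebook itself, which matters if you share or commit it.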

Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models.

In [3]:
import os

VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = os.path.join(weights_dir, VARIANT)
vocab_path = os.path.join(weights_dir, 'tokenizer.model')
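The two paths above assume that the downloaded directory contains a sub-directory named after the variant (the checkpoint, which we assume is a directory) and a `tokenizer.model` file. Loading a checkpoint is slow, so it can be worth failing fast if the layout is not what you expect. A minimal sketch; `check_checkpoint_layout` is a hypothetical helper, and its variant list simply mirrors the form options above.

```python
import os

VARIANTS = ("2b", "2b-it", "7b", "7b-it")  # mirrors the @param options above


def check_checkpoint_layout(weights_dir, variant):
    """Return (ckpt_path, vocab_path) after checking that both exist.

    Assumes the layout used above: <weights_dir>/<variant> for the checkpoint
    and <weights_dir>/tokenizer.model for the SentencePiece vocabulary.
    """
    if variant not in VARIANTS:
        raise ValueError(f"Unknown variant {variant!r}; expected one of {VARIANTS}")
    ckpt_path = os.path.join(weights_dir, variant)
    vocab_path = os.path.join(weights_dir, "tokenizer.model")
    if not os.path.isdir(ckpt_path):
        raise FileNotFoundError(f"Checkpoint directory not found: {ckpt_path}")
    if not os.path.isfile(vocab_path):
        raise FileNotFoundError(f"Tokenizer file not found: {vocab_path}")
    return ckpt_path, vocab_path


# Usage with the variables defined above:
# ckpt_path, vocab_path = check_checkpoint_layout(weights_dir, VARIANT)
```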

In [4]:
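import sys
# Only needed if `pip install` of the gemma package failed and you cloned the
# repository instead: point this at your own local clone of the gemma repo.
sys.path.append('C:/Users/t-amuslih/source/repos/uni/gemma')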

In [5]:
# @title Python imports
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

## Start Generating with Your Model

Load and prepare your LLM's checkpoint for use with Flax.

In [6]:
# Load parameters
params = params_lib.load_and_format_params(ckpt_path)
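The loaded parameters are nested; the sampler built later uses the `'transformer'` entry. To see what was loaded, here is a small sketch that walks a nested mapping and reports each leaf's path and shape. It assumes leaves are array-like objects with a `.shape` attribute (as NumPy and JAX arrays have); `summarize_params` and `count_params` are our own helpers, not part of the Gemma library, and the toy dictionary below is made up.

```python
import math
from collections.abc import Mapping
from types import SimpleNamespace


def summarize_params(tree, prefix=""):
    """Yield (path, shape) for every leaf of a nested mapping of arrays."""
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            yield from summarize_params(value, path)
        else:
            # Leaves without a .shape are treated as scalars.
            yield path, tuple(getattr(value, "shape", ()))


def count_params(tree):
    """Total number of scalars across all leaves."""
    return sum(math.prod(shape) for _, shape in summarize_params(tree))


# Toy stand-in with shapes only (names and sizes are made up).
toy = {"transformer": {"block_0": {"w": SimpleNamespace(shape=(8, 4))},
                       "norm": {"scale": SimpleNamespace(shape=(4,))}}}
for path, shape in summarize_params(toy):
    print(path, shape)
print("total:", count_params(toy))
```

With the real checkpoint you would call `count_params(params['transformer'])`; expect a long listing from `summarize_params`, since every layer appears separately.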

Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library.

In [8]:
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)

True

Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release.
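To make this concrete: if the embedding table has more rows than the tokenizer has pieces, encoded text never looks up the extra rows, and if the model produces one logit per embedding row, those extra logits can be dropped before picking the next token. The sketch below uses made-up sizes (not Gemma's real ones), and restricting the logits to the first `vocab_size` entries illustrates the idea only; it is not a description of what `sampler_lib.Sampler` does internally.

```python
# Toy sizes, chosen for illustration only.
NUM_EMBED = 8    # rows in the embedding table (what the checkpoint stores)
VOCAB_SIZE = 6   # pieces the tokenizer can actually produce


def greedy_next_token(logits, vocab_size=VOCAB_SIZE):
    """Pick the highest-scoring token id among real vocabulary entries only."""
    if len(logits) < vocab_size:
        raise ValueError("expected at least vocab_size logits")
    usable = logits[:vocab_size]  # ids >= vocab_size are unused padding rows
    return max(range(vocab_size), key=lambda i: usable[i])


# One logit per embedding row. The padded rows (ids 6 and 7) score highest here,
# but they are ignored because no tokenizer piece maps to them.
logits = [0.1, 2.0, -1.0, 0.5, 1.5, 0.0, 9.0, 8.0]
print(greedy_next_token(logits))
```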

In [9]:
transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024  # Number of time steps in the transformer's cache
)
transformer = transformer_lib.Transformer(transformer_config)
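`cache_size` sets how many time steps (token positions) the cache can hold. A reasonable working assumption is that the prompt tokens plus the generated tokens must fit within it, so it is worth checking the budget before sampling. A minimal sketch; `fits_in_cache` is a hypothetical helper. If you measure prompt length with the tokenizer (for example `len(vocab.EncodeAsIds(prompt))`), keep in mind that this ignores any extra tokens, such as a beginning-of-sequence token, that the sampler may add.

```python
def fits_in_cache(prompt_lengths, total_generation_steps, cache_size=1024):
    """Check that the longest prompt plus the generation budget fits the cache.

    Assumption: every sequence in a batch occupies
    max(prompt_lengths) + total_generation_steps positions.
    Returns (fits, positions_needed).
    """
    needed = max(prompt_lengths) + total_generation_steps
    return needed <= cache_size, needed


# Hypothetical prompt lengths (in tokens) for a batch of two prompts.
ok, needed = fits_in_cache([17, 9], total_generation_steps=300)
print(ok, needed)
```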

Finally, build a sampler on top of your model and your tokenizer.

In [10]:
# Create a sampler with the right param shapes.
sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)

You're ready to start sampling! This sampler uses just-in-time compilation, so changing the input shape (for example, the batch size) triggers a recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent.
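One way to keep the batch size fixed when you have a varying number of prompts is to split them into chunks of constant size, pad the last chunk with filler prompts, and discard the outputs for the filler. A minimal sketch; `fixed_size_batches` and the default filler prompt are our own choices (we use a short dummy prompt rather than an empty string, since we have not checked how the sampler handles empty input). Depending on how the sampler pads prompts, large differences in prompt length may also change the input shape.

```python
def fixed_size_batches(prompts, batch_size, filler="Hello"):
    """Split `prompts` into lists of exactly `batch_size` items.

    The last batch is padded with `filler` (an arbitrary dummy prompt).
    Yields (batch, n_real) so the caller can drop outputs for filler prompts.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(prompts), batch_size):
        batch = list(prompts[start:start + batch_size])
        n_real = len(batch)
        batch.extend([filler] * (batch_size - n_real))
        yield batch, n_real


# Usage sketch with the sampler defined above (`all_prompts` is hypothetical):
# results = []
# for batch, n_real in fixed_size_batches(all_prompts, batch_size=2):
#     out = sampler(input_strings=batch, total_generation_steps=300)
#     results.extend(out.text[:n_real])
```

Padding wastes some compute on the filler prompts, but it avoids paying for a fresh compilation each time the number of prompts changes.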

In [11]:
input_batch = [
    "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
    "What are the planets of the solar system?",
  ]

out_data = sampler(
    input_strings=input_batch,
    total_generation_steps=300,  # number of steps performed when generating
  )

for input_string, out_string in zip(input_batch, out_data.text):
  print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
  print()
  print(10*'#')

Prompt:

# Python program for implementation of Bubble Sort

def bubbleSort(arr):
Output:

    for i in range(len(arr)):
        for j in range(len(arr)-i-1):
            if arr[j] > arr[j+1]:
                swap(arr, j, j+1)
    return arr

def swap(arr, i, j):
    temp = arr[i]
    arr[i] = arr[j]
    arr[j] = temp

# Driver code
arr = [5, 2, 8, 3, 1, 9]
print(bubbleSort(arr))

```

**Explanation:**

1. **bubbleSort Function**:
    - It takes a list `arr` as input.
    - It uses two nested for loops to iterate through the list.
    - The outer loop `i` iterates from the beginning of the list to the end of the list.
    - The inner loop `j` iterates from the second element of the list to the end of the list minus `i` (excluding the element at position `j` itself).
    - Inside the nested loops, it compares the elements at positions `j` and `j+1` in the list.
    - If `arr[j]` is greater than `arr[j+1]`, it swaps the elements at positions `j` and `j+1` in the list.
    -

##########

You should get an implementation of bubble sort and a description of the solar system.
