##### Copyright 2024 Google LLC.

In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Inference with Gemma 2 using mistral.rs

[Gemma](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open-source language models from Google. Built from the same research and technology used to create the Gemini models, Gemma models are text-to-text, decoder-only large language models (LLMs), available in English, with open weights, pre-trained variants, and instruction-tuned variants.
Gemma models are well-suited for various text-generation tasks, including question-answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop, or your cloud infrastructure, democratizing access to state-of-the-art AI models and helping foster innovation for everyone.

[mistral.rs](https://github.com/EricLBuehler/mistral.rs) is a versatile framework for Large Language Model (LLM) inference supporting text-to-text and multimodal LLMs. It offers features like grammar support for structured outputs, and inference using LoRA-fine-tuned models, making it a powerful tool for a wide range of AI applications.

In this notebook, you will learn how to prompt the Gemma 2 model in various ways using the **mistral.rs** Python APIs in a Google Colab environment.
<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Using_with_mistral_rs.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

## Setup

### Select the Colab runtime
To complete this tutorial, you can use the CPU runtime in Colab. Since the default accelerator for any Colab runtime is CPU, simply click the **Connect** button in the top-right corner of the Colab window.

### Set up Hugging Face

**Before you dive into the tutorial, let's get you set up with Hugging Face:**

1. **Hugging Face Account:**  If you don't already have one, you can create a free Hugging Face account by clicking [here](https://huggingface.co/join).

2. **Hugging Face Token:**  Generate a Hugging Face access token (preferably with `write` permission) by clicking [here](https://huggingface.co/settings/tokens). You'll need this token later in the tutorial.

**Once you've completed these steps, you're ready to move on to the next section where you'll set up environment variables in your Colab environment.**

### Configure your HF token
Add your Hugging Face token to the Colab Secrets manager to securely store it.

1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src="https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg" alt="The Secrets tab is found on the left panel." width=50%>
2. Create a new secret with the name `HF_TOKEN`.
3. Paste your HF token into the Value input box of `HF_TOKEN`.
4. Toggle the button on the left to allow notebook access to the secret.

In [None]:
import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
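Outside Colab, `google.colab` is not available, so the token has to come from somewhere else. The following is a minimal sketch, assuming you have already exported `HF_TOKEN` in your shell; the helper name `resolve_hf_token` is hypothetical and not part of mistral.rs or Colab.

```python
import os


def resolve_hf_token(env=None):
    """Return the Hugging Face token stored under HF_TOKEN.

    `resolve_hf_token` is an illustrative helper, not a library function.
    `env` defaults to the process environment; pass a dict to test it.
    """
    env = os.environ if env is None else env
    token = env.get("HF_TOKEN", "").strip()
    if not token:
        # Fail early with a clear message instead of a download error later.
        raise RuntimeError("HF_TOKEN is not set; export it before loading the model.")
    return token
```

Calling `resolve_hf_token()` before loading the model surfaces a missing token immediately rather than partway through the download.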

### Install dependencies

First, you must install the Python package for mistral.rs.

In [None]:
!pip install mistralrs==0.3.2

Collecting mistralrs==0.3.2
  Downloading mistralrs-0.3.2-cp310-cp310-manylinux_2_34_x86_64.whl.metadata (1.7 kB)
Downloading mistralrs-0.3.2-cp310-cp310-manylinux_2_34_x86_64.whl (14.4 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 14.4/14.4 MB 51.7 MB/s eta 0:00:00
Installing collected packages: mistralrs
Successfully installed mistralrs-0.3.2


## Overview

In this notebook, you will explore how to prompt the Gemma 2 model using the Python APIs of mistral.rs. It's divided into the following sections:

1. Load the Gemma 2 model
2. Response generation
3. Non-streaming chat completion
4. Streaming chat completion
5. Inference with grammar

Additionally, you can run inference on Gemma 2 using the Rust APIs and command-line interface of mistral.rs. Read more about them in the [mistral.rs documentation](https://github.com/EricLBuehler/mistral.rs?tab=readme-ov-file#get-started-fast-).


## Load the Gemma 2 model

In this section, you will learn how to load the Gemma 2 model from Hugging Face Hub using the mistral.rs Python APIs.

First, create an instance of the `Which` class specifying the model to load. Set `model_id` to the Hugging Face Hub repository ID of the desired Gemma 2 model variant. Specify `Architecture.Gemma2` for the `arch` parameter.

Next, create an instance of the `Runner` class, passing it the `Which` instance. Set `token_source` to `env:HF_TOKEN` so that mistral.rs reads your Hugging Face token from the `HF_TOKEN` environment variable when downloading the model from Hugging Face Hub.

Models can also be loaded from the local file system. Refer to the [mistral.rs get models guide](https://github.com/EricLBuehler/mistral.rs?tab=readme-ov-file#getting-models) for more details.

In [None]:
from mistralrs import Runner, Which, Architecture

runner = Runner(
    which=Which.Plain(
        model_id="google/gemma-2-2b-it",
        arch=Architecture.Gemma2,
    ),
    token_source="env:HF_TOKEN",
)

## Response generation

Next, you'll send a prompt to Gemma 2 for inference using mistral.rs.

To do this, create an instance of the `CompletionRequest` class, specifying your desired prompt in the `prompt` parameter. Set the `model` parameter to `"gemma2"`. mistral.rs allows you to configure common LLM parameters like `top_p`, `max_tokens`, and `temperature` within `CompletionRequest`. You can find the full list in the [class definition of `CompletionRequest`](https://github.com/EricLBuehler/mistral.rs/blob/458dc5f447161904ee9191fdec1dc5d4f039af54/mistralrs-pyo3/mistralrs.pyi#L44).

Finally, call the `send_completion_request` method on the `runner` object, passing the newly created `CompletionRequest` instance as the `request` argument.

In [None]:
from mistralrs import CompletionRequest

request = CompletionRequest(
    model="gemma2",
    prompt="What is a black hole?",
    max_tokens=512,
    top_p=0.1,
    temperature=0.1,
)

response = runner.send_completion_request(request)

print(response.choices[0].text)
print(response.usage)

A black hole is a region of spacetime where gravity is so strong that nothing, not even light, can escape. 

Here's a breakdown:

**What causes a black hole?**

* **Stellar Collapse:** When a massive star runs out of fuel, it collapses under its own gravity. This collapse can create a black hole if the star's mass is enough.
* **Supermassive Black Holes:** These are found at the centers of most galaxies, including our own Milky Way. Their formation is still a mystery, but they likely formed from the merging of smaller black holes or from the collapse of massive gas clouds.

**Key Features of a Black Hole:**

* **Event Horizon:** This is the boundary around a black hole beyond which escape is impossible. Once something crosses the event horizon, it's trapped forever.
* **Singularity:** This is the theoretical point at the center of a black hole where all the matter is compressed into an infinitely small point.
* **Gravitational Pull:** Black holes have immense gravitational pull, which 
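
To build intuition for the `temperature` and `top_p` values used in the request above, here is a generic, pure-Python sketch of temperature scaling and nucleus (top-p) filtering. It illustrates the general technique only and is not mistral.rs's actual sampler; the function names and the toy vocabulary are hypothetical.

```python
import math


def apply_temperature(logits, temperature):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [value / temperature for value in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(value - peak) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]


def top_p_filter(probs, top_p):
    """Keep the most likely tokens until their cumulative mass reaches top_p."""
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    kept = {}
    cumulative = 0.0
    for token, prob in ranked:
        kept[token] = prob
        cumulative += prob
        if cumulative >= top_p:
            break
    total = sum(kept.values())
    # Renormalize so the surviving probabilities sum to 1.
    return {token: prob / total for token, prob in kept.items()}


# Toy distribution over three hypothetical tokens.
toy_probs = {"hole": 0.6, "star": 0.3, "cat": 0.1}
narrow = top_p_filter(toy_probs, top_p=0.1)  # only the top token survives
wide = top_p_filter(toy_probs, top_p=0.8)    # the top two tokens survive
sharp = apply_temperature([2.0, 1.0], temperature=0.1)
```

With a low `top_p` and a low `temperature`, almost all probability mass ends up on the single most likely token, which is why the settings used in this notebook produce nearly deterministic output.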

## Non-streaming chat completion

In addition to basic prompting, mistral.rs also supports multi-turn conversations with LLMs.

For multi-turn conversations, you'll maintain a list of dictionaries, each representing a single message in the conversation history with Gemma 2. These dictionaries should include a key named `role` specifying whether the user or the assistant (model) generated the message.
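
As a minimal sketch of that structure, the helper below appends messages to a history list and rejects roles other than the two described above. The names `add_message` and `VALID_ROLES` are hypothetical and are not part of mistral.rs.

```python
# Roles described in this notebook; other roles are rejected by this sketch.
VALID_ROLES = {"user", "assistant"}


def add_message(history, role, content):
    """Append one message dictionary to the conversation history in place."""
    if role not in VALID_ROLES:
        raise ValueError(f"Unsupported role: {role!r}")
    history.append({"role": role, "content": content})


history = []
add_message(history, "user", "Who are the first humans to land on moon?")
add_message(history, "assistant", "Neil Armstrong and Buzz Aldrin.")
add_message(history, "user", "Which country did they belong to?")
```

The resulting `history` list has the same shape as the `messages` list used in this section, so it could be passed to a request in the same way.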

Create an instance of `ChatCompletionRequest` by passing the conversation history to the `messages` parameter. Then, invoke the `send_chat_completion_request` method of the `runner` with the newly created `ChatCompletionRequest` instance as the `request` argument.

You can also configure any other model parameters you want to use. Refer to the [class definition of `ChatCompletionRequest`](https://github.com/EricLBuehler/mistral.rs/blob/458dc5f447161904ee9191fdec1dc5d4f039af54/mistralrs-pyo3/mistralrs.pyi#L11).

In [None]:
from mistralrs import ChatCompletionRequest

messages = [
    {
        "role": "user",
        "content": "Who are the first humans to land on moon?",
    },
    {
        "role": "assistant",
        "content": "The first humans to land on the Moon were **Neil Armstrong and Buzz Aldrin** of the Apollo 11 mission on **July 20, 1969**.",
    },
    {
        "role": "user",
        "content": "Which country did they belong to?",
    },
]

request = ChatCompletionRequest(
    model="gemma2",
    messages=messages,
    max_tokens=256,
    presence_penalty=1.0,
    top_p=0.1,
    temperature=0.1,
)

response = runner.send_chat_completion_request(request)

print(response.choices[0].message.content)
print(response.usage)

Neil Armstrong and Buzz Aldrin were both American astronauts. They represented the **United States** during the Apollo 11 mission. 

Usage {
    completion_tokens: 29,
    prompt_tokens: 74,
    total_tokens: 103,
    avg_tok_per_sec: 2.808376,
    avg_prompt_tok_per_sec: 4.5204644,
    avg_compl_tok_per_sec: 1.4281493,
    total_time_sec: 36.676,
    total_prompt_time_sec: 16.37,
    total_completion_time_sec: 20.306,
}


Notice how the model generated the response for the last query from the user based on the previous conversation history.
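
The averages in the `Usage` output above appear to be plain ratios of token counts to elapsed time. The sketch below recomputes them from the printed values as a sanity check; the numbers are copied from the output above, and the `throughput` helper is hypothetical.

```python
def throughput(tokens, seconds):
    """Return tokens processed per second."""
    return tokens / seconds


# Values copied from the Usage output printed above.
total_rate = throughput(103, 36.676)       # total_tokens / total_time_sec
prompt_rate = throughput(74, 16.37)        # prompt_tokens / total_prompt_time_sec
completion_rate = throughput(29, 20.306)   # completion_tokens / total_completion_time_sec

print(f"{total_rate:.3f} {prompt_rate:.3f} {completion_rate:.3f}")
```

On the CPU runtime, prompt processing is faster per token than generation here, which is typical because prompt tokens can be processed together while completion tokens are produced one at a time.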

## Streaming chat completion

mistral.rs also supports obtaining streamed responses from the model during multi-turn conversations. To enable streaming, set the `stream` parameter of `ChatCompletionRequest` to `True`.

With streaming enabled, the `send_chat_completion_request` method returns an iterable object. You can access each chunk of the streamed response by iterating over this object.

Pass the conversation history (`messages`) you created earlier to the `ChatCompletionRequest` instance with `stream=True`.

In [None]:
request = ChatCompletionRequest(
    model="gemma2",
    messages=messages,
    max_tokens=256,
    presence_penalty=1.0,
    top_p=0.1,
    temperature=0.1,
    stream=True,
)

response = runner.send_chat_completion_request(request)

for chunk in response:
    print(chunk.choices[0].delta.content, end="")

Neil Armstrong and Buzz Aldrin were both American astronauts. They represented the **United States** during the Apollo 11 mission. 
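
If you also need the full reply as one string, for example to append it to the conversation history, you can collect the chunks as they arrive. The sketch below assumes each chunk exposes `choices[0].delta.content` as in the loop above. It uses stand-in objects built with `SimpleNamespace` so it runs without a model, and the guard against empty deltas is a defensive assumption, not documented mistral.rs behavior.

```python
from types import SimpleNamespace


def collect_stream(chunks):
    """Join the text of all streamed chunks, skipping chunks with no text."""
    parts = []
    for chunk in chunks:
        piece = chunk.choices[0].delta.content
        if piece:  # defensive: ignore None or empty deltas
            parts.append(piece)
    return "".join(parts)


def fake_chunk(text):
    """Build a stand-in object shaped like a streamed chunk (illustration only)."""
    delta = SimpleNamespace(content=text)
    return SimpleNamespace(choices=[SimpleNamespace(delta=delta)])


demo_chunks = [fake_chunk("Neil "), fake_chunk(None), fake_chunk("Armstrong")]
full_text = collect_stream(demo_chunks)
```

In the notebook you would pass the object returned by `send_chat_completion_request` in place of `demo_chunks`.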


## Inference with grammar

Specifying a grammar when creating a `ChatCompletionRequest` or `CompletionRequest` allows you to constrain the model's output to a specific format.

To ensure the model generates responses that match a regular expression, set the parameter `grammar_type` to "regex" and `grammar` to the desired regular expression string.

The following code snippet illustrates how to restrict the model's response to two-digit numbers using regular expression grammar:

In [None]:
request = CompletionRequest(
    model="gemma2",
    prompt="What is the next number in the Fibonacci series: 1, 1, 2, 3, 5, 8 ?",
    grammar_type="regex",
    grammar=r"\d{2}",
    top_p=0.1,
    temperature=0.1,
)

response = runner.send_completion_request(request)

print(response.choices[0].text)
print(response.usage)

13
Usage {
    completion_tokens: 5,
    prompt_tokens: 31,
    total_tokens: 36,
    avg_tok_per_sec: 3.115804,
    avg_prompt_tok_per_sec: 3.5123498,
    avg_compl_tok_per_sec: 1.8328446,
    total_time_sec: 11.554,
    total_prompt_time_sec: 8.826,
    total_completion_time_sec: 2.728,
}
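
You can double-check a constrained response on the client side with Python's built-in `re` module. This is a minimal sketch: it assumes the regex dialect used by mistral.rs agrees with Python's for a simple pattern like `\d{2}`, and the helper name `matches_grammar` is hypothetical.

```python
import re

# Same pattern as the `grammar` argument in the request above.
TWO_DIGITS = re.compile(r"\d{2}")


def matches_grammar(text):
    """Return True if the whole response (ignoring outer whitespace) is two digits."""
    return TWO_DIGITS.fullmatch(text.strip()) is not None


ok = matches_grammar("13")
too_short = matches_grammar("8")
too_long = matches_grammar("144")
```

`fullmatch` is used instead of `search` so that longer strings that merely contain two digits are rejected.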


mistral.rs offers support for various grammar types beyond regular expressions. Refer to the [mistral.rs documentation](https://github.com/EricLBuehler/mistral.rs/tree/master?tab=readme-ov-file#description) for more details.

These are just a few examples of how you can perform inference with Gemma 2 using mistral.rs. To explore its capabilities further, you can refer to the [mistral.rs documentation](https://github.com/EricLBuehler/mistral.rs/tree/master?tab=readme-ov-file#--mistralrs).
