##### Copyright 2024 Google LLC.

In [1]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Gemma - Model evaluation

This notebook demonstrates how to use EleutherAI's Language Model Evaluation Harness to benchmark the performance of Gemma 2 2B, specifically on a subset of MMLU.
<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/Gemma_evaluation.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

## Setup

### Select the Colab runtime
To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use the T4 GPU but you need a high RAM instance (due to [this](https://github.com/google-deepmind/gemma/issues/57)).

### Gemma setup

To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:

* Get access to Gemma on kaggle.com.
* Select a Colab runtime with sufficient resources to run
  the Gemma 2B model.
* Generate and configure a Kaggle username and an API key as Colab secrets.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

### Configure your credentials

Add your Kaggle credentials to the Colab Secrets manager to store them securely.

1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src="https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg" alt="The Secrets tab is found on the left panel." width=50%>
2. Create new secrets: `KAGGLE_USERNAME` and `KAGGLE_KEY`
3. Copy/paste your username into `KAGGLE_USERNAME`
4. Copy/paste your key into `KAGGLE_KEY`
5. Toggle the buttons on the left to allow notebook access to the secrets.

In [None]:
import os
import kagglehub
from google.colab import userdata

# Copy both Kaggle secrets out of the Colab Secrets manager into environment
# variables of the same name.
os.environ.update(
    {secret: userdata.get(secret) for secret in ("KAGGLE_USERNAME", "KAGGLE_KEY")}
)

# Pre-allocate 90% of accelerator memory
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Install the evaluation harness.

In [None]:
!pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git

Collecting git+https://github.com/EleutherAI/lm-evaluation-harness.git
  Cloning https://github.com/EleutherAI/lm-evaluation-harness.git to /tmp/pip-req-build-1s_z1sm5
  Running command git clone --filter=blob:none --quiet https://github.com/EleutherAI/lm-evaluation-harness.git /tmp/pip-req-build-1s_z1sm5
  Resolved https://github.com/EleutherAI/lm-evaluation-harness.git to commit 543617fef9ba885e87f8db8930fbbff1d4e2ca49
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


Install the Gemma JAX library.

In [None]:
# NOTE(review): this installs the Gemma JAX library from the repository's
# default branch with no version pin, so the installed code may differ between
# runs. Consider pinning to a known-good commit or tag.
!pip install -q git+https://github.com/google-deepmind/gemma.git
# The imports below assume the installed package exposes `params`,
# `transformer` and `sampler` modules at the top level of `gemma`.
from gemma import params as params_lib
import sentencepiece as spm
from gemma import transformer as transformer_lib
from gemma import sampler as sampler_lib

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


Download the Gemma model and tokenizer.

In [None]:
# Name of the Gemma 2 variant to evaluate; it is used both in the Kaggle model
# handle and as the checkpoint sub-directory name below.
GEMMA_VARIANT = 'gemma2-2b-it'
# Download the Flax release of the model through kagglehub; the returned value
# is used as the local directory holding the downloaded files.
GEMMA_PATH = kagglehub.model_download(f'google/gemma-2/flax/{GEMMA_VARIANT}')
# Checkpoint directory and SentencePiece tokenizer file inside the download.
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')



In [None]:
# Load the model checkpoint and the SentencePiece tokenizer downloaded above.
params = params_lib.load_and_format_params(CKPT_PATH)
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)

# Derive the transformer configuration from the loaded parameters.
# NOTE(review): cache_size=1024 presumably caps the number of tokens
# (prompt plus generated) the model can attend over — confirm that the
# few-shot prompts used later stay below this limit.
transformer_config = transformer_lib.TransformerConfig.from_params(
    params=params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(transformer_config)

# `sampler` and `vocab` are notebook-level objects that the Gemma2LM class
# defined in a later cell reads directly. Only the 'transformer' entry of the
# loaded parameters is handed to the sampler.
sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)

Create a new Gemma LM class, following the [New Model Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/model_guide.md). For MMLU, we only need to implement the `loglikelihood()` method.

In [None]:
import os
from typing import Optional, Union, List, Dict
import kagglehub
import jax
import jax.numpy as jnp
from jinja2 import Template
from tqdm import tqdm
import lm_eval
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.models import utils


@register_model("Gemma2")
class Gemma2LM(LM):
    """Minimal lm-evaluation-harness adapter around the Gemma JAX sampler.

    The class reads the notebook-level ``sampler`` and ``vocab`` objects, so
    the cell that builds them must have been run before any request is scored.
    Only ``loglikelihood`` is implemented; that is all this notebook's MMLU
    task needs.

    Limitation: only the FIRST token of each continuation is scored. This is
    adequate when every answer choice encodes to a single token; continuations
    that encode to several tokens are scored on their first token alone.
    """

    def __init__(self, batch_size: Optional[int] = 1):
        """Create the adapter.

        Args:
            batch_size: Number of requests handed to the sampler per call.
        """
        # Let the harness base class initialise its own per-model state before
        # adding ours.
        super().__init__()
        self._batch_size = batch_size
        # Kept explicit so the single-process assumption is visible here.
        self._rank = 0
        self._world_size = 1

    @property
    def batch_size(self):
        """Number of requests scored per sampler call."""
        return self._batch_size

    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        """Score the first continuation token of every request.

        Args:
            requests: Harness instances whose ``args`` are
                ``(context, continuation)`` string pairs.

        Returns:
            One ``(log_probability, is_greedy)`` tuple per request, in request
            order. ``is_greedy`` is True when the scored token is also the
            arg-max of the predicted next-token distribution.

        Raises:
            ValueError: If a continuation encodes to zero tokens, since there
                is then nothing to score.
        """
        results = []

        for chunked_request in utils.chunks(tqdm(requests, disable=False), self._batch_size):
            contexts = [req.args[0] for req in chunked_request]
            next_tokens = [req.args[1] for req in chunked_request]

            # A single generation step is enough: only the logits for the
            # token right after each context are needed.
            outputs = sampler(input_strings=contexts, total_generation_steps=1)

            for i, logits in enumerate(outputs.logits):
                next_token = next_tokens[i]

                token_ids = vocab.EncodeAsIds(next_token)
                if not token_ids:
                    # Fail with a clear message instead of a bare IndexError.
                    raise ValueError(
                        f"Continuation {next_token!r} encodes to no tokens; nothing to score"
                    )
                next_token_id = token_ids[0]
                logits = jnp.array(logits[0])  # Assuming generating one token

                log_probs = jax.nn.log_softmax(logits, axis=-1)

                next_token_logprob = log_probs[next_token_id]

                # Convert the JAX scalar to a plain Python bool so the result
                # matches the declared return type.
                is_greedy = bool(next_token_id == log_probs.argmax())

                results.append((float(next_token_logprob), is_greedy))

        return results

    def loglikelihood_rolling(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        """Not implemented; always raises NotImplementedError."""
        # Used to evaluate perplexity; not important for this tutorial
        raise NotImplementedError("loglikelihood_rolling is not implemented")


    def generate_until(self, requests: list[Instance]) -> list[str]:
        """Not implemented; always raises NotImplementedError."""
        # Not used for MMLU
        raise NotImplementedError("generate_until is not implemented")


Now start the evaluation process. It takes a long time to run through the entire MMLU benchmark, so we are just going to run a subset of it.

In [None]:
import lm_eval
# Run the harness against an already-constructed Gemma2LM instance rather than
# a registered model name. Only the `mmlu_management` task is evaluated, with
# 3 few-shot examples per prompt, and requests are scored 4 at a time.
results = lm_eval.simple_evaluate(
    model=Gemma2LM(batch_size=4),
    tasks=["mmlu_management"],
    num_fewshot=3,
)

# Per-task metrics are stored under the 'results' key of the returned dict.
print(results['results'])

INFO:lm-eval:Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234
INFO:lm-eval:Using pre-initialized model
INFO:lm-eval:`group` and `group_alias` keys in TaskConfigs are deprecated and will be removed in v0.4.5 of lm_eval. The new `tag` field will be used to allow for a shortcut to a group of tasks one does not wish to aggregate metrics across. `group`s which aggregate across subtasks must be only defined in a separate group config file, which will be the official way to create groups that support cross-task aggregation as in `mmlu`. Please see the v0.4.4 patch notes and our documentation: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/new_task_guide.md#advanced-group-configs for more information.
INFO:lm-eval:Setting fewshot random generator seed to 1234
INFO:lm-eval:Building contexts for mmlu_management on rank 0...
100%|██████████| 103/103 [00:00<00:00, 118.20it/s]
INFO:lm-eval:Running loglikelihood requests
100%|██████████| 4

{'mmlu_management': {'alias': 'management', 'acc,none': 0.6601941747572816, 'acc_stderr,none': 0.046897659372781335}}
