<a href="https://colab.research.google.com/github/GillesVandewiele/Algorithms_Blog/blob/master/gemma_2_9b_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Gemma 2 9B 💎

Gemma 2 is Google's latest iteration of open LLMs. It comes in two sizes, 9 billion and 27 billion parameters, each with a base (pre-trained) and an instruction-tuned version. Gemma is based on Google DeepMind's Gemini and has a context length of 8K tokens:

- [gemma-2-9b](https://huggingface.co/google/gemma-2-9b): Base 9B model.
- [gemma-2-9b-it](https://huggingface.co/google/gemma-2-9b-it): Instruction fine-tuned version of the base 9B model.
- [gemma-2-27b](https://huggingface.co/google/gemma-2-27b): Base 27B model.
- [gemma-2-27b-it](https://huggingface.co/google/gemma-2-27b-it): Instruction fine-tuned version of the base 27B model.

The Gemma 2 models were trained on roughly twice as much data as the first iteration: 13 trillion tokens for the 27B version and 8 trillion tokens for the 9B version, drawn from web data (primarily English), code, and math. We don’t know the exact details of the training mix; we can only guess that larger-scale, more careful data curation was a big factor in the improved performance.

Gemma 2 comes with the [same license](https://ai.google.dev/gemma/terms) as the first iteration, which is a permissive license that allows redistribution, fine-tuning, commercial use, and derivative works.

## Set Up the Inference Environment



In [1]:
!pip install -q git+https://github.com/huggingface/transformers.git accelerate

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Building wheel for transformers (pyproject.toml) ... done


In [None]:
!pip install --upgrade -q transformers huggingface_hub peft accelerate bitsandbytes datasets trl
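Gemma 2 is only supported by recent `transformers` releases, which is why the first cell installs from GitHub. We assume here that 4.42.0 is the first release that ships `Gemma2ForCausalLM` (check the release notes if in doubt). A minimal, stdlib-only sketch for checking the installed version; `parse_release` and `meets_minimum` are hypothetical helpers, and they deliberately treat a dev build such as `4.43.0.dev0` like its release number:

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple


def parse_release(v: str) -> Tuple[int, ...]:
    """Keep only the leading numeric release segment, e.g. '4.43.0.dev0' -> (4, 43, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", v)
    if match is None:
        raise ValueError(f"Unrecognised version string: {v!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: str = "4.42.0") -> bool:
    """True if `installed` is at least `minimum`, ignoring dev/pre-release suffixes."""
    a, b = parse_release(installed), parse_release(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that '5.0' compares like '5.0.0'
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


try:
    installed = version("transformers")
    print(f"transformers {installed}, Gemma 2 ready (assumed >= 4.42.0): {meets_minimum(installed)}")
except PackageNotFoundError:
    print("transformers is not installed")
```

Ignoring the `.dev0` suffix is a pragmatic choice for git installs; a strict PEP 440 comparison would rank `4.42.0.dev0` below `4.42.0`.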

In [None]:
import os
from google.colab import userdata
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
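`google.colab.userdata` only exists inside Colab. A minimal sketch that tries Colab secrets first and otherwise falls back to an environment variable you export yourself; `resolve_hf_token` is a hypothetical helper, not part of any library:

```python
import os
from typing import Optional


def resolve_hf_token(name: str = "HF_TOKEN") -> Optional[str]:
    """Return a Hugging Face token from Colab secrets or the environment, else None."""
    token: Optional[str] = None
    try:
        from google.colab import userdata  # only importable inside Colab
        token = userdata.get(name)
    except Exception:
        # Not running in Colab, or the secret is missing / not shared with this notebook
        token = None
    if not token:
        token = os.environ.get(name, "")
    token = token.strip()
    return token or None


token = resolve_hf_token()
if token is None:
    print("No HF_TOKEN found; set it as a Colab secret or an environment variable.")
```

The result can then be passed to `login(token=token)` from `huggingface_hub`, as in the next cell.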

In [None]:
from huggingface_hub import login

login(os.environ["HF_TOKEN"])

In [None]:
import gc
import os
from math import exp
from collections import Counter
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import transformers
import torch

os.environ['OMP_NUM_THREADS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class ParticipantVisibleError(Exception):
    pass


def score(
    solution: pd.DataFrame,
    submission: pd.DataFrame,
    row_id_column_name: str,
    model_path: str = '/kaggle/input/gemma-2/transformers/gemma-2-9b/2',
    load_in_8bit: bool = True,
    clear_mem: bool = False,
) -> float:
    """
    Calculates the mean perplexity of submitted text permutations compared to an original text.

    Parameters
    ----------
    solution : DataFrame
        DataFrame containing the original text in a column named 'text'.
        Includes a row ID column specified by `row_id_column_name`.

    submission : DataFrame
        DataFrame containing the permuted text in a column named 'text'.
        Must have the same row IDs as the solution.
        Includes a row ID column specified by `row_id_column_name`.

    row_id_column_name : str
        Name of the column containing row IDs.
        Ensures aligned comparison between solution and submission.

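    model_path : str
        Path to the serialized LLM.

    load_in_8bit : bool, default=True
        Load the model with 8-bit quantization (bitsandbytes). Requires a CUDA device.

    clear_mem : bool, default=False
        Clear GPU memory after scoring by clearing the CUDA cache.
        Useful for testing.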

    Returns
    -------
    float
        The mean perplexity score. Lower is better.

    Raises
    ------
    ParticipantVisibleError
        If the submission format is invalid or submitted strings are not valid permutations.

    Examples
    --------
    >>> import pandas as pd
    >>> model_path = "/kaggle/input/gemma-2/transformers/gemma-2-9b/2"
    >>> solution = pd.DataFrame({
    ...     'id': [0, 1],
    ...     'text': ["this is a normal english sentence", "the quick brown fox jumps over the lazy dog"]
    ... })
    >>> submission = pd.DataFrame({
    ...     'id': [0, 1],
    ...     'text': ["sentence english normal a is this", "lazy the over jumps fox brown quick the dog"]
    ... })
    >>> score(solution, submission, 'id', model_path=model_path, clear_mem=True) > 0
    True
    """
    # Check that each submitted string is a permutation of the solution string
    sol_counts = solution.loc[:, 'text'].str.split().apply(Counter)
    sub_counts = submission.loc[:, 'text'].str.split().apply(Counter)
    invalid_mask = sol_counts != sub_counts
    if invalid_mask.any():
        raise ParticipantVisibleError(
            'At least one submitted string is not a valid permutation of the solution string.'
        )

    # Calculate perplexity for the submitted strings
    sub_strings = [
        ' '.join(s.split()) for s in submission['text'].tolist()
    ]  # Split and rejoin to normalize whitespace
    scorer = PerplexityCalculator(
        model_path=model_path,
        load_in_8bit=load_in_8bit,
    )  # Initialize the perplexity calculator with a pre-trained model
    perplexities = scorer.get_perplexity(
        sub_strings
    )  # Calculate perplexity for each submitted string

    if clear_mem:
        # Just move on if it fails. Not essential if we have the score.
        try:
            scorer.clear_gpu_memory()
        except Exception:
            print('GPU memory clearing failed.')

    return float(np.mean(perplexities))


class PerplexityCalculator:
    """
    Calculates perplexity of text using a pre-trained language model.

    Adapted from https://github.com/asahi417/lmppl/blob/main/lmppl/ppl_recurrent_lm.py

    Parameters
    ----------
    model_path : str
        Path to the pre-trained language model

    load_in_8bit : bool, default=True
        Use 8-bit quantization for the model. Requires CUDA.

    device_map : str, default="auto"
        Device mapping for the model.
    """

    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = True,
        device_map: str = 'auto',
    ):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        # Configure model loading based on quantization setting and device availability
        if load_in_8bit:
            if DEVICE.type != 'cuda':
                raise ValueError('8-bit quantization requires CUDA device')
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
            self.model = transformers.Gemma2ForCausalLM.from_pretrained(
                model_path,
                quantization_config=quantization_config,
                device_map=device_map,
            )
        else:
            self.model = transformers.Gemma2ForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16 if DEVICE.type == 'cuda' else torch.float32,
                device_map=device_map,
            )

        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')

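        self.model.eval()
        # A model loaded with a device_map is already placed by accelerate; moving it
        # again can fail when its layers are spread across devices or offloaded.
        if not load_in_8bit and device_map is None:
            self.model.to(DEVICE)  # Explicitly move the model to the device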

        self.cache = {}

    # def get_perplexity(
    #     self, input_texts: Union[str, List[str]], batch_size: int = 1
    # ) -> Union[float, List[float]]:
    #     """
    #     Calculates the perplexity of given texts.

    #     Parameters
    #     ----------
    #     input_texts : str or list of str
    #         A single string or a list of strings.

    #     batch_size : int, default=1
    #         Batch size for processing texts.

    #     Returns
    #     -------
    #     float or list of float
    #         A single perplexity value if input is a single string,
    #         or a list of perplexity values if input is a list of strings.
    #     """
    #     single_input = isinstance(input_texts, str)
    #     input_texts = [input_texts] if single_input else input_texts

    #     perplexities = []
    #     with torch.no_grad():
    #         for i in range(0, len(input_texts), batch_size):
    #             batch = input_texts[i:i + batch_size]
    #             batch_with_special = [
    #                 f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}" for text in batch
    #             ]
    #             # Tokenize with padding for batching
    #             model_inputs = self.tokenizer(
    #                 batch_with_special,
    #                 return_tensors='pt',
    #                 add_special_tokens=False,
    #                 padding=True,  # Pad to the longest sequence in the batch
    #                 truncation=True,
    #             ).to(DEVICE)

    #             if 'token_type_ids' in model_inputs:
    #                 model_inputs.pop('token_type_ids')

    #             # Get model outputs
    #             output = self.model(**model_inputs, use_cache=False)
    #             logits = output['logits']

    #             # Shift logits and labels for loss calculation
    #             shift_logits = logits[..., :-1, :].contiguous()
    #             shift_labels = model_inputs['input_ids'][..., 1:].contiguous()
    #             attention_mask = model_inputs['attention_mask'][..., 1:].contiguous()

    #             # Compute token-wise loss
    #             batch_losses = self.loss_fct(
    #                 shift_logits.view(-1, shift_logits.size(-1)),
    #                 shift_labels.view(-1)
    #             ).view(shift_labels.size(0), -1)

    #             # Mask padding tokens and compute sequence loss
    #             batch_losses = batch_losses * attention_mask
    #             batch_losses = batch_losses.sum(dim=1) / attention_mask.sum(dim=1)

    #             # Convert loss to perplexity
    #             perplexities.extend([exp(loss.item()) for loss in batch_losses])

    #     return perplexities[0] if single_input else perplexities

    def get_perplexity(
        self, input_texts: Union[str, List[str]],
    ) -> Union[float, List[float]]:
        """
        Calculates the perplexity of given texts.

        Parameters
        ----------
        input_texts : str or list of str
            A single string or a list of strings.

        Returns
        -------
        float or list of float
            A single perplexity value if input is a single string,
            or a list of perplexity values if input is a list of strings.

        Examples
        --------
        >>> import pandas as pd
        >>> model_path = "/kaggle/input/gemma-2/transformers/gemma-2-9b/2"
        >>> scorer = PerplexityCalculator(model_path=model_path)

        >>> submission = pd.DataFrame({
        ...     'id': [0, 1, 2],
        ...     'text': ["this is a normal english sentence", "thsi is a slihgtly misspelled zr4g sentense", "the quick brown fox jumps over the lazy dog"]
        ... })
        >>> perplexities = scorer.get_perplexity(submission["text"].tolist())
        >>> perplexities[0] < perplexities[1]
        True
        >>> perplexities[2] < perplexities[0]
        True

        >>> perplexities = scorer.get_perplexity(["this is a sentence", "another sentence"])
        >>> all(p > 0 for p in perplexities)
        True

        >>> scorer.clear_gpu_memory()
        """
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts

        loss_list = []
        with torch.no_grad():
            # Process each sequence independently
            for text in input_texts:
                if text in self.cache:
                    loss_list.append(self.cache[text])
                    continue
                # Explicitly add sequence boundary tokens to the text
                text_with_special = f"{self.tokenizer.bos_token}{text}{self.tokenizer.eos_token}"

                # Tokenize
                model_inputs = self.tokenizer(
                    text_with_special,
                    return_tensors='pt',
                    add_special_tokens=False,
                )

                if 'token_type_ids' in model_inputs:
                    model_inputs.pop('token_type_ids')

                model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}

                # Get model output
                output = self.model(**model_inputs, use_cache=False)
                logits = output['logits']

                # Shift logits and labels for calculating loss
                shift_logits = logits[..., :-1, :].contiguous()  # Drop last prediction
                shift_labels = model_inputs['input_ids'][..., 1:].contiguous()  # Drop first input

                # Calculate token-wise loss
                loss = self.loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1)
                )
                # Average the loss over all predicted tokens (mean negative log-likelihood)
                sequence_loss = loss.sum() / len(loss)
                loss_item = sequence_loss.cpu().item()
                loss_list.append(loss_item)
                self.cache[text] = loss_item

        # Perplexity is the exponential of the mean negative log-likelihood
        ppl = [exp(i) for i in loss_list]

        return ppl[0] if single_input else ppl

    def clear_gpu_memory(self) -> None:
        """Clears GPU memory by deleting references and emptying caches."""
        if not torch.cuda.is_available():
            return

        # Delete model and tokenizer if they exist
        if hasattr(self, 'model'):
            del self.model
        if hasattr(self, 'tokenizer'):
            del self.tokenizer

        # Run garbage collection
        gc.collect()

        # Clear CUDA cache and reset memory stats
        with DEVICE:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.reset_peak_memory_stats()
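To see what `score` computes without downloading Gemma, here is a minimal NumPy sketch of its two steps: the word-multiset permutation check, and perplexity as the exponential of the mean per-token negative log-likelihood, $\mathrm{PPL} = \exp\big(-\frac{1}{N}\sum_{t=1}^{N} \log p(x_t \mid x_{<t})\big)$. As in `PerplexityCalculator`, position 0 holds the BOS token and the $N$ predicted tokens include the final EOS. The helper names are hypothetical, and hand-made logits stand in for the model's output.

```python
from collections import Counter

import numpy as np


def is_permutation(original: str, candidate: str) -> bool:
    # Same multiset of whitespace-separated words, order ignored
    return Counter(original.split()) == Counter(candidate.split())


def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Numerically stable log-softmax over the last (vocabulary) axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def perplexity_from_logits(logits: np.ndarray, token_ids: np.ndarray) -> float:
    """logits: (seq_len, vocab_size) scores; token_ids: (seq_len,) ids incl. BOS and EOS."""
    # Position t predicts token t + 1: drop the last prediction and the first label
    shift_logits = logits[:-1]
    shift_labels = token_ids[1:]
    log_probs = log_softmax(shift_logits)
    # Negative log-probability of each actual next token
    nll = -log_probs[np.arange(len(shift_labels)), shift_labels]
    return float(np.exp(nll.mean()))


vocab_size = 10
uniform = np.zeros((5, vocab_size))  # every token equally likely at every position
ids = np.array([0, 3, 7, 1, 2])
# A uniform model's perplexity equals the vocabulary size (up to rounding)
print(perplexity_from_logits(uniform, ids))
```

The uniform case is a handy sanity check: perplexity can be read as the effective number of equally likely choices per token, so a real model scoring fluent text should land far below its vocabulary size.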

## Initialise the Text Generation pipeline

P.S. Make sure you have accepted the terms and conditions on the [model page](https://huggingface.co/google/gemma-2-9b-it) before loading the model.

In [None]:
from transformers import pipeline
import torch

model_id = "google/gemma-2-9b-it"

pipe = pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

## Define the prompt in the chat messages format

Replace the content below with any query, question, or discussion that you want the LLM to process.

In [None]:
messages = [
    {"role": "user", "content": "Who are you? Please, answer in pirate-speak."},
]
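Under the hood, the pipeline turns `messages` into a single prompt string using the tokenizer's chat template. The sketch below approximates Gemma's turn format (`<start_of_turn>`/`<end_of_turn>` markers, the `assistant` role renamed to `model`, no `system` role, and an open `model` turn at the end for generation). It is an illustration only: the authoritative string is whatever `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)` returns, and `format_gemma_chat` is a hypothetical helper.

```python
from typing import Dict, List


def format_gemma_chat(messages: List[Dict[str, str]], bos: str = "<bos>") -> str:
    """Approximate Gemma's chat template (the real one also checks that turns alternate)."""
    parts = [bos]
    for message in messages:
        role = message["role"]
        if role == "system":
            # Gemma's template has no system role; the real template raises an error too
            raise ValueError("Gemma's chat format does not support a 'system' role")
        if role == "assistant":
            role = "model"  # Gemma calls the assistant turn 'model'
        parts.append(f"<start_of_turn>{role}\n{message['content'].strip()}<end_of_turn>\n")
    parts.append("<start_of_turn>model\n")  # generation prompt: the model speaks next
    return "".join(parts)


print(format_gemma_chat([{"role": "user", "content": "Who are you? Please, answer in pirate-speak."}]))
```

This is also why putting instructions in a `system` message fails with Gemma; fold them into the first `user` message instead.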

## Pass the messages to the pipeline

In [None]:
outputs = pipe(
    messages,
    max_new_tokens=256,
    do_sample=False,
)
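Setting `do_sample=False` makes generation deterministic: assuming the model's generation config keeps the default `num_beams=1`, `transformers` uses greedy decoding, appending the single highest-scoring token at each step until it produces an end-of-sequence token or reaches `max_new_tokens`. A toy sketch of that loop, where a small hand-written bigram score table (hypothetical) stands in for Gemma's logits:

```python
from typing import Callable, Dict, List

# Hypothetical next-token scores keyed by the previous token; stands in for model logits
BIGRAM_SCORES: Dict[str, Dict[str, float]] = {
    "ahoy": {"matey": 2.0, "there": 1.5},
    "matey": {"!": 1.0, "<eos>": 0.5},
    "!": {"<eos>": 3.0, "arr": 1.0},
}


def toy_next_token_scores(tokens: List[str]) -> Dict[str, float]:
    # Unknown context: only the end-of-sequence token is possible
    return BIGRAM_SCORES.get(tokens[-1], {"<eos>": 0.0})


def greedy_decode(
    next_token_scores: Callable[[List[str]], Dict[str, float]],
    prompt: List[str],
    max_new_tokens: int,
    eos_token: str = "<eos>",
) -> List[str]:
    tokens = list(prompt)
    generated: List[str] = []
    for _ in range(max_new_tokens):
        scores = next_token_scores(tokens)
        best = max(scores, key=scores.get)  # greedy: take the argmax, no sampling
        if best == eos_token:
            break  # stop at end-of-sequence; the EOS token itself is not returned
        tokens.append(best)
        generated.append(best)
    return generated


print(greedy_decode(toy_next_token_scores, ["ahoy"], max_new_tokens=10))  # ['matey', '!']
```

Because nothing is random, rerunning the cell with the same prompt gives the same answer; passing `do_sample=True` (optionally with `temperature` or `top_p`) trades that reproducibility for more varied replies.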

## Extract and print the response

In [None]:
assistant_response = outputs[0]["generated_text"][-1]["content"]
print(assistant_response)

Ahoy, matey! I be a humble ship o' words, sailin' the digital seas. They call me Gemma, a creation o' the fine folks at Google DeepMind. I be trained on a treasure trove o' texts, learnin' to speak and write like a true scallywag. 

Ask me yer questions, and I'll do me best to answer 'em, aye!  🦜📚
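For chat-style input, the pipeline returns one result per input, and that result's `generated_text` is the whole conversation with the model's reply appended as a final `assistant` message. That is why the cell above indexes `[-1]["content"]`. A small sketch with a hand-written stand-in for `outputs` (the reply text is shortened), plus a hypothetical `last_assistant_reply` helper that searches for the last assistant turn instead of assuming it is the final element:

```python
from typing import Any, Dict, List

# Hand-written stand-in for `outputs`; the assistant reply is shortened for illustration
example_outputs: List[Dict[str, Any]] = [
    {
        "generated_text": [
            {"role": "user", "content": "Who are you? Please, answer in pirate-speak."},
            {"role": "assistant", "content": "Ahoy, matey!"},
        ]
    }
]


def last_assistant_reply(outputs: List[Dict[str, Any]]) -> str:
    conversation = outputs[0]["generated_text"]  # conversation for the first (only) input
    for message in reversed(conversation):
        if message["role"] == "assistant":
            return message["content"]
    raise ValueError("No assistant message found in the pipeline output")


print(last_assistant_reply(example_outputs))
```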



Voilà! You now have your own personal (& powerful) assistant!