##### Copyright 2024 Google LLC.

In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Model Composition using DeepMind CALM and Gemma

Welcome to this step-by-step guide on fine-tuning [Gemma](https://huggingface.co/google/gemma-2b) models using [Hugging Face Transformers](https://huggingface.co/docs/transformers/en/index) and DeepMind's **CALM (Composition to Augment Language Models)** framework.

As Large Language Models (LLMs) grow ever larger and more capable, it can be both challenging and costly to extend or adapt them to new domains or tasks. Many solutions involve retraining or fine-tuning a large, general-purpose model on new data—a time-consuming and resource-intensive process. Moreover, organizational constraints or data privacy concerns may limit access to the original training data needed for such adaptation.

[**CALM**](https://github.com/google-deepmind/calm) addresses these challenges by enabling the composition of two distinct language models—an “anchor” model with foundational capabilities and an “augmenting” model specialized in a particular domain—without fully re-training the anchor model. CALM does this by introducing cross-attention between models, allowing you to combine their strengths and preserve their original capabilities. The result is a more capable composed model that leverages existing, proven models and a few additional parameters, rather than building new monolithic models from scratch. The library currently supports combining any two models built with the Gemma architecture.
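CALM's core mechanism, as described above, is cross-attention from the anchor model's hidden states to the augmenting model's hidden states. The following is a minimal pure-Python sketch of a single composition step, written for illustration only. It is not the CALM library's implementation: it uses a single head, omits the attention mask a causal model would need, and the weight matrices `w_proj`, `w_q`, `w_k`, and `w_v` are hypothetical.

```python
import math

# Illustrative sketch only (not the CALM library's code): one cross-attention
# step in which anchor hidden states attend to projected augmenting states.
Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Multiply an (n x k) matrix by a (k x m) matrix."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def softmax(row: list[float]) -> list[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def compose_layer(h_anchor: Matrix, h_aug: Matrix, w_proj: Matrix,
                  w_q: Matrix, w_k: Matrix, w_v: Matrix) -> Matrix:
    """Update anchor hidden states with single-head cross-attention to the aug model.

    h_anchor: (T_anchor x d_anchor), h_aug: (T_aug x d_aug),
    w_proj: (d_aug x d_anchor), w_q / w_k / w_v: (d_anchor x d_anchor).
    """
    projected = matmul(h_aug, w_proj)          # map aug states to the anchor width
    q = matmul(h_anchor, w_q)                  # queries come from the anchor
    k = matmul(projected, w_k)                 # keys and values come from the
    v = matmul(projected, w_v)                 # projected augmenting states
    scale = 1.0 / math.sqrt(len(w_q[0]))
    scores = [[s * scale for s in row] for row in matmul(q, transpose(k))]
    attn = [softmax(row) for row in scores]    # (T_anchor x T_aug) weights
    update = matmul(attn, v)                   # (T_anchor x d_anchor)
    # Residual connection: the anchor's own representation is kept and the
    # cross-attention output is added on top of it.
    return [[a + u for a, u in zip(ra, ru)] for ra, ru in zip(h_anchor, update)]
```

With an all-zero `w_proj`, the values are zero and `compose_layer` returns `h_anchor` unchanged, which shows how the residual connection preserves the anchor's representation when the cross-attention contributes nothing.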

[**Transformers**](https://huggingface.co/docs/transformers/en/index) is a powerful and versatile tool for working with a wide range of large language models, tokenizers, and pipelines. It offers a user-friendly API for loading, training, and deploying state-of-the-art models, making it an integral component of the machine learning and natural language processing ecosystem. Its broad compatibility, ease of use, and extensive documentation help streamline tasks like fine-tuning, inference, and evaluation of models.

[**Gemma**](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. Gemma models are well-suited for a variety of text generation tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop or your own cloud infrastructure, democratizing access to state of the art AI models and helping foster innovation for everyone.

In this notebook, you'll learn how to fine-tune a model composed with CALM, using `gemma-2-2b` as both the anchor and the augmentation model. The resulting composed model merges capabilities from both instances of `gemma-2-2b`. You could also try other combinations (for example, `9B` with `2B`), but this guide keeps it simple with just the `2B` Gemma variant.

What you'll learn:
1. **Setup & Dependencies**: Installing and importing necessary libraries.
2. **Model & Configuration**: Initializing the CALM configuration with `gemma-2-2b` as anchor and augmentation models.
3. **Data Loading & Preprocessing**: Loading a small dataset of quotes, tokenizing it, and preparing it for fine-tuning.
4. **Training**: Setting training arguments and running the fine-tuning using the Hugging Face `Trainer`.
5. **Saving & Prompting**: Saving the fine-tuned model and prompting it with a sample input.

<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Finetune_with_CALM.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

## Setup

### Selecting the Runtime Environment

To start, use **Google Colab** as your platform.

- #### **Google Colab** <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png" alt="Google Colab" width="30"/>

  1. Click **Open in Colab**.
  2. You'll need access to a [**Colab Pro/Pro+**](https://colab.research.google.com/signup) runtime with sufficient resources to run the Gemma model.
  3. In the menu, go to **Runtime** > **Change runtime type**.
  4. Ensure that the **GPU** is set to **A100**.

### Gemma using Hugging Face

Before diving into the tutorial, let's set up Gemma:

1. **Create a Hugging Face Account**: If you don't have one, you can sign up for a free account [here](https://huggingface.com/join).
2. **Access the Gemma Model**: Visit the [Gemma model page](https://huggingface.com/collections/google/gemma-2-release-667d6600fd5220e7b967f315) and accept the usage conditions.
3. **Generate a Hugging Face Token**: Go to your Hugging Face [settings page](https://huggingface.com/settings/tokens) and generate a new access token (preferably with `write` permissions). You'll need this token later in the tutorial.

**Once you've completed these steps, you're ready to move on to the next section where you'll set up environment variables in your Colab environment.**

### Configure Your Credentials

To access gated models such as Gemma, as well as private models and datasets, you need to log in to the Hugging Face (HF) ecosystem.

- #### **Google Colab** <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png" alt="Google Colab" width="30"/>
  If you're using Colab, you can securely store your Hugging Face token (`HF_TOKEN`) using the Colab Secrets manager:
  1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src="https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg" alt="The Secrets tab is found on the left panel." width=50%>
  2. **Add Hugging Face Token**:
    - Create a new secret with the **name** `HF_TOKEN`.
    - Copy/paste your token key into the **Value** input box of `HF_TOKEN`.
    - **Toggle** the button on the left to allow notebook access to the secret.

In [None]:
import os
import sys

if 'google.colab' in sys.modules:
    # Running on Colab
    from google.colab import userdata
    os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')
else:
    # Not running on Colab
    raise EnvironmentError('This notebook is designed to run on Google Colab.')

# Disable tokenizers parallelism to avoid deadlocks
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

### Setting Up the Environment

Next, you'll set up the environment by installing all the necessary Python packages for fine-tuning the Gemma model.


In [None]:
# Clone DeepMind CALM
!git clone https://github.com/google-deepmind/calm.git
%cd calm

With the **CALM** repository cloned, install compatible versions of `transformers`, `datasets`, and `accelerate`.

**Note**: The versions below are pinned to ensure compatibility. You may adjust them as new releases become available.
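If you change the pins, it can help to confirm which versions are actually installed. The sketch below uses only the standard library (`importlib.metadata`); the helper name `find_mismatches` is hypothetical and not part of CALM or Transformers. An empty result means every pinned package is installed at its pinned version.

```python
from importlib import metadata
from typing import Callable, Optional

# Versions pinned in the install cell of this notebook.
PINNED = {"transformers": "4.47.0", "datasets": "3.2.0", "accelerate": "1.2.1"}


def find_mismatches(
    pins: dict[str, str],
    get_version: Callable[[str], str] = metadata.version,
) -> dict[str, Optional[str]]:
    """Map each package whose installed version differs from its pin to the
    installed version, or to None if the package is not installed at all."""
    mismatches: dict[str, Optional[str]] = {}
    for package, wanted in pins.items():
        try:
            installed: Optional[str] = get_version(package)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[package] = installed
    return mismatches
```

In the notebook, `print(find_mismatches(PINNED))` reports any package that is missing or at a different version.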

In [None]:
# Install the appropriate Hugging Face libraries to ensure compatibility with the Gemma 2 model and CALM.
!pip install transformers==4.47.0 -U -q
!pip install datasets==3.2.0 -U -q
!pip install accelerate==1.2.1 -U -q

## Import the libraries


Import the libraries this notebook needs for model configuration, tokenization, and inference. Loading and preprocessing the [Abirate/english_quotes](https://huggingface.co/datasets/Abirate/english_quotes) dataset happens inside the training script defined later.


In [None]:
import torch
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModel
)
# Import the "calm" module from the "model" package for inference
from model import calm

## Fine-tune using CALM

In this section, you will:
1. Load a small dataset (`Abirate/english_quotes`) from Hugging Face Datasets.
2. Configure CALM by specifying both the anchor (base) model and the augmentation model. In this demonstration, you will use the same `gemma-2-2b` model for both. However, in practice, you may choose a different variant (e.g., `9B`, `27B`) to combine different capabilities.
3. Preprocess the dataset for language modeling.
4. Use the `Trainer` from Hugging Face Transformers to fine-tune the composed model.
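Preprocessing for causal language modeling means the labels are the input tokens themselves. As we understand it, `DataCollatorForLanguageModeling(mlm=False)` copies `input_ids` into `labels` and sets padding positions to `-100`, which the loss ignores; the model then shifts labels by one so that position `t` predicts token `t+1`. The pure-Python sketch below mirrors that behavior for one right-padded sequence. The helpers and token IDs are hypothetical, and the pad ID of `0` is an assumption (it matches the `padding_idx=0` shown in the model printout later in this notebook).

```python
# Hypothetical helpers that mirror what a causal-LM collator and loss do with
# one right-padded sequence. Not part of CALM or Transformers.

IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss


def causal_lm_labels(input_ids: list[int], pad_token_id: int) -> list[int]:
    """Copy input_ids into labels, masking padding positions with -100."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]


def supervised_pairs(labels: list[int]) -> list[tuple[int, int]]:
    """(position, target) pairs that contribute to the loss.

    Position t is trained to predict labels[t + 1]; masked targets are dropped.
    """
    return [(pos, tgt) for pos, tgt in enumerate(labels[1:]) if tgt != IGNORE_INDEX]


ids = [2, 651, 1913, 0, 0]                 # illustrative IDs, right-padded with 0
labels = causal_lm_labels(ids, pad_token_id=0)
print(labels)                              # [2, 651, 1913, -100, -100]
print(supervised_pairs(labels))            # [(0, 651), (1, 1913)]
```

Because the script pads every quote to `max_length=512`, most positions in each example are padding; masking them keeps the loss focused on the actual quote tokens.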

You'll save your training logic into a separate Python script (`train.py`) for clarity.

### The Training Script

The script below:
- Sets up the CALM model configuration.
- Loads and tokenizes the dataset.
- Defines training arguments and runs the training.
- Saves the fine-tuned model.

You will specify parameters like `anchor_model_dir`, `aug_model_dir`, `num_heads`, `num_connections`, and other hyperparameters via command-line flags. You can easily adjust these flags for different experiments.
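The `num_connections` flag controls how many layers are linked by cross-attention, and `num_heads` sets the number of heads per cross-attention block. This notebook does not show how the repository chooses which layers to connect, so the sketch below assumes a simple, evenly spaced scheme purely for illustration; `connection_layers` is a hypothetical helper, and the CALM code may map layers differently. The 26-layer depth comes from the `gemma-2-2b` printout later in this notebook.

```python
def connection_layers(num_layers: int, num_connections: int) -> list[int]:
    """Hypothetical evenly spaced layer indices for cross-connections.

    Splits the stack into num_connections equal blocks and picks the last
    layer of each block.
    """
    if not 1 <= num_connections <= num_layers:
        raise ValueError("num_connections must be between 1 and num_layers")
    stride = num_layers // num_connections
    return [stride * (i + 1) - 1 for i in range(num_connections)]


# Anchor and augmenting models of equal depth (gemma-2-2b has 26 layers).
anchor_layers = connection_layers(26, 2)
aug_layers = connection_layers(26, 2)
print(list(zip(anchor_layers, aug_layers)))  # [(12, 12), (25, 25)]
```

Pairing the two lists index by index also works when the models have different depths, since each model's layers are spaced relative to its own stack.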


In [None]:
%%writefile train.py
from collections.abc import Sequence
from absl import app
from absl import flags
from absl import logging

import datasets

# Import the "calm" module from the "model" package of the cloned CALM
# repository; it provides the CALMConfig and CALM classes used below.
from model import calm

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModel,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Register the custom CALMConfig class under the identifier "calm" with AutoConfig.
# By doing this, when you specify a configuration type as "calm", AutoConfig knows
# to use calm.CALMConfig to instantiate the configuration.
AutoConfig.register("calm", calm.CALMConfig)

# Register the CALM model class with AutoModel for the CALMConfig configuration class.
# This means that if AutoModel is given a CALMConfig, it knows to instantiate calm.CALM.
AutoModel.register(calm.CALMConfig, calm.CALM)

_ANCHOR_MODEL_DIR = flags.DEFINE_string('anchor_model_dir', None, 'Path to the anchor model directory or identifier.')
_AUG_MODEL_DIR = flags.DEFINE_string('aug_model_dir', None, 'Path to the augmentation model directory or identifier.')
_OUTPUT_DIR = flags.DEFINE_string('output_dir', None, 'Directory where the fine-tuned model will be saved.')
_LEARNING_RATE = flags.DEFINE_float('learning_rate', 2e-5, 'Learning rate for fine-tuning.')
_EPOCHS = flags.DEFINE_integer('epochs', 3, 'Number of training epochs.')
_BATCH_SIZE = flags.DEFINE_integer('batch_size', 1, 'Batch size per device.')
_NUM_HEADS = flags.DEFINE_integer('num_heads', 1, 'Number of cross-attention heads in CALM.')
_NUM_CONNECTIONS = flags.DEFINE_integer('num_connections', 2, 'Number of cross-connections between anchor and aug models.')
_LOGGING_STEPS = flags.DEFINE_integer('logging_steps', 1, 'Logging frequency in steps.')
_MAX_STEPS = flags.DEFINE_integer('max_steps', -1, 'Max training steps, use -1 for no limit.')

def train(argv: Sequence[str]) -> None:
    del argv  # Unused.
    SEED = 42

    anchor_model_path = _ANCHOR_MODEL_DIR.value
    aug_model_path = _AUG_MODEL_DIR.value
    num_heads = _NUM_HEADS.value
    num_connections = _NUM_CONNECTIONS.value

    logging.info('Using anchor model: %s', anchor_model_path)
    logging.info('Using augmentation model: %s', aug_model_path)

    # Load the tokenizer from the anchor model
    logging.info('Loading Tokenizer...')
    tokenizer = AutoTokenizer.from_pretrained(anchor_model_path)
    tokenizer.padding_side = 'right'

    # Create CALM config
    logging.info('Creating CALM configuration...')
    calm_config = calm.CALMConfig(
        anchor_model=anchor_model_path,
        aug_model=aug_model_path,
        anchor_config=None,
        aug_config=None,
        num_connections=num_connections,
        num_heads=num_heads,
    )
    calm_config.save_pretrained('./calm_config')

    # Initialize the composed CALM model
    logging.info('Initializing the CALM model...')
    model = calm.CALM(calm_config)
    model.config.use_cache = False

    # Load the dataset (english_quotes)
    logging.info('Loading and preparing dataset...')
    dataset = datasets.load_dataset('Abirate/english_quotes', split='all')

    # Filter out empty quotes
    dataset = dataset.filter(lambda x: len(x["quote"]) > 0)

    # For demonstration, use a small subset (e.g., 2048 samples)
    dataset = dataset.shuffle(seed=SEED).select(range(2048))

    # Tokenize the data
    def preprocess_function(examples):
        return tokenizer(
            examples['quote'], truncation=True, padding='max_length',
            max_length=512
        )

    dataset = dataset.map(preprocess_function, batched=True)

    # Data collator for language modeling (no masking since it's causal LM)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False
    )

    epochs = _EPOCHS.value
    batch_size = _BATCH_SIZE.value
    learning_rate = _LEARNING_RATE.value
    output_dir = _OUTPUT_DIR.value
    logging_steps = _LOGGING_STEPS.value
    max_steps = _MAX_STEPS.value

    # Split into train/validation sets
    dataset = dataset.train_test_split(test_size=0.02)

    # TrainingArguments for Hugging Face Trainer
    training_args = TrainingArguments(
        output_dir=output_dir,
        save_strategy='no',
        overwrite_output_dir=True,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        eval_strategy='epoch',
        optim="adamw_torch_fused",
        # "constant" ignores warmup_ratio; "constant_with_warmup" applies it.
        lr_scheduler_type="constant_with_warmup",
        warmup_ratio=0.03,
        logging_steps=logging_steps,
        max_steps=max_steps,
        learning_rate=learning_rate,
        report_to="none",
        seed=SEED,
    )

    # Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset['train'],
        eval_dataset=dataset['test'],
        data_collator=data_collator,
        processing_class=tokenizer,  # replaces the deprecated `tokenizer` argument
    )

    # Train the model
    logging.info('Starting training...')
    trainer.can_return_loss = True
    trainer.train()
    trainer.save_model(output_dir)
    print(f'Training complete! Model saved to {output_dir}')

if __name__ == '__main__':
    app.run(train)

Overwriting train.py
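The script sets `warmup_ratio=0.03`. In Hugging Face Transformers, the `constant` scheduler ignores warmup, while `constant_with_warmup` ramps the learning rate linearly from 0 to its target over the warmup steps, and a warmup ratio is converted to steps by rounding up. The sketch below mirrors that behavior as we understand it; the function names are hypothetical and it returns the multiplier applied to the base learning rate, not the library's scheduler object.

```python
import math


def warmup_steps(num_training_steps: int, warmup_ratio: float) -> int:
    """Warmup length derived from a ratio of total steps (rounded up)."""
    return math.ceil(num_training_steps * warmup_ratio)


def lr_multiplier(step: int, num_warmup_steps: int, scheduler: str) -> float:
    """Multiplier on the base learning rate for two simple schedules."""
    if scheduler == "constant":
        return 1.0  # warmup settings have no effect
    if scheduler == "constant_with_warmup":
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)  # linear ramp from 0
        return 1.0
    raise ValueError(f"unsupported scheduler: {scheduler}")


w = warmup_steps(50, 0.03)  # 50 steps, as in the run below
print(w)                                                            # 2
print([lr_multiplier(s, w, "constant_with_warmup") for s in range(4)])  # [0.0, 0.5, 1.0, 1.0]
```

For a 50-step demonstration, the warmup is only two steps, so its effect is small; it matters more for longer runs.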


### Start fine-tuning

You can now fine-tune the composed model. For demonstration, you'll launch a short run of only 50 steps. For a real training job, consider increasing `max_steps` or `epochs` and using a larger dataset. Note that a positive `max_steps` takes precedence over `epochs`.
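To see how short this run is, you can work out the step count from the script's settings: 2,048 samples, a 2% validation split, and a per-device batch size of 2 on a single GPU. The sketch below assumes that `train_test_split` rounds the test size up and that the last partial batch is kept; both helpers are hypothetical and ignore gradient accumulation.

```python
import math


def split_sizes(n_samples: int, test_size: float) -> tuple[int, int]:
    """Train/test sizes, assuming the test fraction is rounded up."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


def steps_per_epoch(n_train: int, batch_size: int, num_devices: int = 1) -> int:
    """Optimizer steps in one epoch, keeping the last partial batch."""
    return math.ceil(n_train / (batch_size * num_devices))


n_train, n_test = split_sizes(2048, 0.02)
per_epoch = steps_per_epoch(n_train, batch_size=2)
print(n_train, n_test, per_epoch)          # 2007 41 1004
print(round(50 * 2 / n_train, 3))          # 0.05, i.e. about 5% of one epoch
```

Under these assumptions, 50 steps cover only about 100 training examples, which is enough to check that the pipeline works but not to meaningfully adapt the model.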

In [None]:
anchor_model_path = 'google/gemma-2-2b'
aug_model_path = 'google/gemma-2-2b'

# Remove the previous output directory if it exists
!rm -rf ./gemma-ft

# Run training with the specified parameters ({...} expands the Python variables above)
!python train.py --anchor_model_dir {anchor_model_path} \
          --aug_model_dir {aug_model_path} \
          --num_heads 2 \
          --num_connections 2 \
          --learning_rate 3e-5 \
          --batch_size 2 \
          --max_steps 50 \
          --output_dir './gemma-ft'

2024-12-17 14:15:13.357550: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-17 14:15:13.375429: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-17 14:15:13.396556: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-17 14:15:13.403038: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-17 14:15:13.418338: I tensorflow/core/platform/cpu_feature_guar

## Prompt using the newly fine-tuned model

Finally, prompt the fine-tuned model to verify that it works as intended. Reload the fine-tuned model, tokenize a sample prompt to get its input IDs, and then generate a response with `model.generate()`.
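The generation call in this section passes `repetition_penalty=1.1`. As implemented in Transformers, for every token already present in the sequence, a positive logit is divided by the penalty and a negative logit is multiplied by it, so repeated tokens always become less likely. The pure-Python sketch below shows one decoding step under that rule, assuming greedy decoding (no sampling arguments are passed); the helper names are hypothetical.

```python
def apply_repetition_penalty(logits: list[float], seen_token_ids: set[int],
                             penalty: float) -> list[float]:
    """Down-weight tokens that already appear in the context."""
    out = list(logits)
    for tok in seen_token_ids:
        # Dividing a positive logit or multiplying a negative one by a
        # penalty > 1 both lower that token's score.
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out


def greedy_pick(logits: list[float]) -> int:
    """Index of the highest logit (greedy decoding)."""
    return max(range(len(logits)), key=logits.__getitem__)


logits = [2.0, 1.9, -1.0]  # toy scores over a 3-token vocabulary
print(greedy_pick(logits))                                          # 0
print(greedy_pick(apply_repetition_penalty(logits, {0, 2}, 1.1)))   # 1
```

A mild penalty such as 1.1 tips close decisions away from repetition without heavily distorting the model's preferences.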


In [None]:
# Register the custom CALMConfig and CALM classes with AutoConfig and AutoModel
AutoConfig.register("calm", calm.CALMConfig)
AutoModel.register(calm.CALMConfig, calm.CALM)

In [None]:
# Load the CALM configuration
config = calm.CALMConfig.from_pretrained('./calm_config')

# Load the composed and fine-tuned model
model = calm.CALM.from_pretrained('./gemma-ft', config=config)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

CALM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.



CALM(
  (anchor_model): Gemma2ForCausalLM(
    (model): Gemma2Model(
      (embed_tokens): Embedding(256000, 2304, padding_idx=0)
      (layers): ModuleList(
        (0-25): 26 x Gemma2DecoderLayer(
          (self_attn): Gemma2Attention(
            (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
            (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
            (rotary_emb): Gemma2RotaryEmbedding()
          )
          (mlp): Gemma2MLP(
            (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
            (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
            (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
            (act_fn): PytorchGELUTanh()
          )
          (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
   

In [None]:
print('Loading Tokenizer...')
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b', use_fast=True)
tokenizer.padding_side = 'right'

print('Prompting the model...')
prompt = "Life is either a "
inputs = tokenizer(prompt, return_tensors='pt').to(device)
outputs = model.generate(**inputs, max_new_tokens=40, use_cache=False,
                         repetition_penalty=1.1)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(text)

Loading Tokenizer...
Prompting the model...
Life is either a <strong>journey</strong> or a <strong>destination.</strong> If it's the former, you'll never arrive; if it's the latter, you'll never depart.

- Anonymous




You have successfully fine-tuned a CALM-composed model using `gemma-2-2b` as both the anchor and augmentation models. While this demonstration focuses on a simple, small-scale example, the same principles apply to larger models and datasets. In this guide, you've learned how to compose two Gemma models with CALM to create a new model that integrates capabilities from both, expanding its skills without the computational overhead of full re-training.

### Next steps:
- Experiment with other Gemma model variants, such as larger or instruction-tuned checkpoints, as the anchor or augmentation model.
- Use larger datasets and more training steps for better model quality.
- Adjust hyperparameters (e.g., learning rate, batch size, epochs) for optimal results.

Happy fine-tuning!