# Adding Adapter Support for Gemma 3 with Plugin Interface

This notebook demonstrates how to add adapter support to the Gemma 3 model using the adapters library's plugin interface. Adapters are a parameter-efficient fine-tuning technique that allows you to adapt pre-trained language models to new tasks while only training a small number of parameters.

## 1. Installing Required Libraries

First, let's install the necessary libraries if you haven't already.

In [None]:
# Install the adapters library and transformers
!pip install -q adapters transformers datasets

## 2. Understanding the Model Architecture

Before creating our plugin interface, let's understand the basic structure of Gemma 3:

- Like most Transformer language models, it consists of an embedding layer followed by a series of decoder layers
- Each layer contains a self-attention block and an MLP block
- The self-attention block includes query, key, value, and output projections
- The MLP block includes gate, up, and down linear projections (`gate_proj`, `up_proj`, `down_proj`)

To create an adapter interface, we need to map these components to the appropriate adapter hooks.

## 3. Creating the Plugin Interface

Now we'll create a plugin interface for Gemma 3 that maps the model's architecture to the adapter framework.

In [23]:
import adapters
from adapters import AdapterModelInterface
from transformers import AutoModelForCausalLM

plugin_interface = AdapterModelInterface(
    # Specify which adapter methods to enable
    adapter_methods=["lora", "reft"],
    
    # Map the model's components to the adapter interface
    model_embeddings="embed_tokens",      # Embedding layer
    model_layers="layers",                # Transformer layers
    layer_self_attn="self_attn",          # Self-attention module in each layer
    layer_cross_attn=None,                # Gemma 3 doesn't have cross-attention
    
    # Projection matrices within the attention module
    attn_k_proj="k_proj",                 # Key projection
    attn_q_proj="q_proj",                 # Query projection
    attn_v_proj="v_proj",                 # Value projection
    attn_o_proj="o_proj",                 # Output projection
    
    # MLP projections
    layer_intermediate_proj="mlp.up_proj",  # Upward projection in MLP
    layer_output_proj="mlp.down_proj",      # Downward projection in MLP
)

Each parameter in the interface maps to specific module names in the model's architecture, allowing the adapter methods to hook into the right components.
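To make the mapping concrete: entries such as `"mlp.up_proj"` are dotted attribute paths relative to each decoder layer, and the attention names are relative to the self-attention module. The sketch below resolves such paths on a stand-in object. `resolve_path`, `missing_paths`, and the toy layer are hypothetical illustrations, not the adapters library's actual lookup code.

```python
from functools import reduce
from types import SimpleNamespace


def resolve_path(root, dotted_path):
    """Follow a dotted attribute path such as 'mlp.up_proj', starting at root."""
    return reduce(getattr, dotted_path.split("."), root)


def missing_paths(root, paths):
    """Return the paths that cannot be resolved on root."""
    missing = []
    for path in paths:
        try:
            resolve_path(root, path)
        except AttributeError:
            missing.append(path)
    return missing


# Toy stand-in for one decoder layer; strings take the place of real modules.
toy_layer = SimpleNamespace(
    self_attn=SimpleNamespace(q_proj="Q", k_proj="K", v_proj="V", o_proj="O"),
    mlp=SimpleNamespace(gate_proj="gate", up_proj="up", down_proj="down"),
)

layer_paths = ["self_attn", "mlp.up_proj", "mlp.down_proj"]
attn_paths = ["q_proj", "k_proj", "v_proj", "o_proj"]

print("Unresolved layer paths:", missing_paths(toy_layer, layer_paths))
print("Unresolved attention paths:", missing_paths(toy_layer.self_attn, attn_paths))
```

Running the same kind of check against the first decoder layer of a loaded model is a cheap way to catch a misspelled module name before initializing the adapter framework.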

## 4. Loading the Model and Initializing with the Interface

Now, let's load the Gemma 3 model and initialize it with our plugin interface.

⚠️ Note: Gemma 3 is a gated model that requires a Hugging Face access token. Make sure you have accepted the model terms on the Hugging Face Hub before loading it.
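One way to avoid hardcoding the token is to read it from the environment and treat the placeholder as "not set". The helper below is a hypothetical sketch: `get_hf_token` and `PLACEHOLDER` are names invented for illustration. `HF_TOKEN` is the variable that `huggingface_hub` reads by default, and `HUGGINGFACE_TOKEN` is the name used in this notebook.

```python
import os

PLACEHOLDER = "<YOUR_TOKEN>"


def get_hf_token(env=os.environ, names=("HF_TOKEN", "HUGGINGFACE_TOKEN")):
    """Return the first usable token found under the given variable names, else None."""
    for name in names:
        value = env.get(name, "").strip()
        # Skip empty values and the unedited placeholder
        if value and value != PLACEHOLDER:
            return value
    return None


token = get_hf_token()
print("Token found" if token else "No token set; gated models will fail to download")
```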

In [24]:
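# Provide your Hugging Face access token
import os
# Replace the placeholder below, or export HUGGINGFACE_TOKEN before starting the notebook.
# setdefault keeps a token that is already set in the environment.
os.environ.setdefault("HUGGINGFACE_TOKEN", "<YOUR_TOKEN>")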

# For demonstration purposes, we'll use a smaller version of the model
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-1b-it",  # Larger Gemma 3 checkpoints (4B, 12B, 27B) exist if you have enough resources
    token=os.environ.get("HUGGINGFACE_TOKEN"),  # Required for gated models
    device_map="auto"  # Automatically distribute model across available GPUs
)

Some parameters are on the meta device because they were offloaded to the cpu.


In [25]:
# Initialize the adapter framework with our plugin interface
adapters.init(model, interface=plugin_interface)

AttributeError: 'functools.partial' object has no attribute '__func__'

## 5. Adding and Training an Adapter

With the interface in place, we can now add an adapter to our model.

In [None]:
from adapters import LoRAConfig

# Add a LoRA adapter
adapter_name = "gemma3-math-adapter"
lora_config = LoRAConfig(
    r=16,           # LoRA rank
    alpha=32,       # LoRA alpha parameter
    dropout=0.05,   # Dropout probability for LoRA layers
)

# Register the adapter with the model (required before activating or training it)
model.add_adapter(adapter_name, config=lora_config)

# Activate the adapter
model.set_active_adapters(adapter_name)

# Set the model to train only the adapter parameters
model.train_adapter(adapter_name)

# Verify adapter was correctly added
print(model.adapter_summary())

Name                     Architecture         #Param      %Param  Active   Train
--------------------------------------------------------------------------------
gemma3-finance-adapter   lora              1,490,944       0.149       1       1
--------------------------------------------------------------------------------
Full model                               999,885,952     100.000               0
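
The parameter count in the summary can be reproduced with a little arithmetic. A LoRA pair on a linear layer adds `r * (d_in + d_out)` parameters. The sketch below assumes a hidden size of 1152, 26 decoder layers, a 1024-dimensional query projection, a 256-dimensional value projection, and LoRA applied only to the query and value projections. None of these values are stated in this notebook; they are assumptions about the `gemma-3-1b-it` configuration and the LoRA defaults, so check `model.config` before relying on them.

```python
def lora_params(d_in, d_out, r):
    """A LoRA pair adds A (r x d_in) and B (d_out x r): r * (d_in + d_out) parameters."""
    return r * (d_in + d_out)


# Assumed dimensions (not stated in this notebook; verify against model.config)
HIDDEN = 1152
NUM_LAYERS = 26
Q_OUT = 1024
V_OUT = 256
R = 16  # LoRA rank from the config above

per_layer = lora_params(HIDDEN, Q_OUT, R) + lora_params(HIDDEN, V_OUT, R)
total = per_layer * NUM_LAYERS

full_model = 999_885_952  # "Full model" row of the summary
percent = 100 * total / full_model

print(f"{total:,} adapter parameters ({percent:.3f}% of the base model)")
```

Under these assumptions the result agrees with the `#Param` and `%Param` columns of the summary.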


## 6. Loading the GSM8K Dataset for Fine-tuning

For this example, we'll use the GSM8K dataset to fine-tune our model for solving grade school math problems.

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the GSM8K dataset
dataset = load_dataset("openai/gsm8k", "main")
print(dataset)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-3-1b-it",
    token=os.environ.get("HUGGINGFACE_TOKEN")
)



In [None]:
# Explore sample data
print("Sample question:")
print(dataset["train"][0]["question"])
print("\nSample answer:")
print(dataset["train"][0]["answer"])
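
GSM8K reference solutions end with a line of the form `#### <final answer>`, which is useful when comparing a generated solution against the reference. The sketch below pulls out that final answer. `extract_final_answer` is a hypothetical helper, and the example string is hand-written in the dataset's style rather than copied from it.

```python
import re


def extract_final_answer(solution):
    """Return the text after the last '####' marker, or None if there is no marker."""
    matches = re.findall(r"####\s*(.+)", solution)
    if not matches:
        return None
    # Drop thousands separators so "1,200" and "1200" compare equal
    return matches[-1].strip().replace(",", "")


example = "Each box holds 6 eggs, so 3 boxes hold 3 * 6 = 18 eggs.\n#### 18"
print(extract_final_answer(example))
```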



## 7. Preprocessing the Dataset

We need to tokenize our math problems and their solutions for training.

In [None]:
import torch

def preprocess_function(examples):
    # Create full prompts with question and expected answer format
    prompts = [
        f"Solve the following math problem step-by-step:\n\nQuestion: {q}\n\nAnswer: {a}" 
        for q, a in zip(examples["question"], examples["answer"])
    ]
    
    # Tokenize as regular sequences
    tokenized = tokenizer(prompts, padding="max_length", truncation=True, max_length=768)
    
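    # For causal language modeling, labels mirror input_ids, but padded positions
    # are set to -100 so the loss ignores them
    tokenized["labels"] = [
        [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
        for ids, attn in zip(tokenized["input_ids"], tokenized["attention_mask"])
    ]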
    
    return tokenized

# Apply preprocessing to the dataset
tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=["question", "answer"])

print("Dataset processed!")

Dataset processed!





## 8. Fine-tuning the Adapter

Now we can fine-tune our adapter for solving math problems.

In [None]:
from transformers import Trainer, TrainingArguments
import numpy as np

# For math problem solving, we'll use perplexity as our main metric
def compute_metrics(pred):
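    # Trainer passes NumPy arrays, so convert them to tensors before using torch ops
    logits = torch.as_tensor(pred.predictions).float()
    labels = torch.as_tensor(pred.label_ids).long()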
    
    # Shift labels to align with predictions (standard for causal language modeling)
    shifted_logits = logits[..., :-1, :].contiguous()
    shifted_labels = labels[..., 1:].contiguous()
    
    # Calculate loss
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(shifted_logits.view(-1, shifted_logits.size(-1)), shifted_labels.view(-1))
    
    # Calculate perplexity
    perplexity = torch.exp(loss).item()
    
    return {"perplexity": perplexity, "loss": loss.item()}

# Set up training arguments - adjusted for math problem solving
training_args = TrainingArguments(
    output_dir="./gemma3-math-adapter",
    per_device_train_batch_size=2,  # Smaller batch size due to longer sequences
    per_device_eval_batch_size=2,
    learning_rate=5e-5,
    num_train_epochs=5,  # More epochs for complex task
    save_steps=200,
    eval_steps=200,
    logging_steps=50,
    eval_strategy="steps",  # Called evaluation_strategy in older transformers releases
    load_best_model_at_end=True,
    metric_for_best_model="loss",  # Use loss as metric for best model
    greater_is_better=False,  # Lower loss is better
    push_to_hub=False,
    gradient_accumulation_steps=4,  # Accumulate gradients to simulate larger batch sizes
    fp16=torch.cuda.is_available(),  # Use mixed precision if available
    warmup_ratio=0.1,  # Add some warmup steps
)
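
The metric above boils down to two steps: shift predictions and labels by one position, then exponentiate the mean negative log-likelihood over the positions that are not ignored. The pure-Python sketch below shows the same computation on a toy sequence. `shifted_perplexity` is a hypothetical helper, and the probability dictionaries stand in for softmaxed logits.

```python
import math


def shifted_perplexity(token_probs, labels, ignore_index=-100):
    """token_probs[t] is the predicted distribution (token id -> probability) for position t + 1."""
    nll = []
    # Position t predicts labels[t + 1], so drop the last prediction and the first label
    for probs, target in zip(token_probs[:-1], labels[1:]):
        if target == ignore_index:
            continue  # padded or masked position
        nll.append(-math.log(probs[target]))
    mean_loss = sum(nll) / len(nll)
    return math.exp(mean_loss), mean_loss


# A model that is maximally unsure over a 4-token vocabulary
uniform = {0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25}
ppl, loss = shifted_perplexity([uniform] * 4, [2, 0, 3, -100])
print(f"loss={loss:.4f}, perplexity={ppl:.4f}")
```

A uniform guess over four tokens yields a perplexity equal to the vocabulary size, which is a handy sanity check for any perplexity implementation.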

In [None]:
# Use GSM8K's train split for training and its test split for evaluation
# Subsample both splits for faster training if needed
train_dataset = tokenized_dataset["train"].select(range(min(len(tokenized_dataset["train"]), 2000)))
eval_dataset = tokenized_dataset["test"].select(range(min(len(tokenized_dataset["test"]), 200)))

print(f"Training on {len(train_dataset)} examples and evaluating on {len(eval_dataset)} examples")
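
The training arguments interact with the dataset size, so it helps to estimate the schedule before launching a run. The sketch below assumes the 2,000-example training subset selected above and a single device. `training_schedule` is a hypothetical helper, and the numbers are estimates that may differ slightly from what the Trainer reports.

```python
import math


def training_schedule(num_examples, per_device_batch, grad_accum, epochs,
                      warmup_ratio, num_devices=1):
    """Estimate optimizer steps from the batch settings and dataset size."""
    effective_batch = per_device_batch * grad_accum * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    total_steps = steps_per_epoch * epochs
    warmup_steps = round(total_steps * warmup_ratio)
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }


schedule = training_schedule(
    num_examples=2000,      # assumed size of the training subset
    per_device_batch=2,
    grad_accum=4,
    epochs=5,
    warmup_ratio=0.1,
)
print(schedule)
```

With `eval_steps=200` and `save_steps=200`, a schedule of this length would evaluate and checkpoint only a handful of times, which is worth keeping in mind when picking those values.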



In [None]:
from adapters import AdapterTrainer

# Initialize the trainer
trainer = AdapterTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

# Train only the adapter parameters
trainer.train()

ValueError: Expected input batch_size (512) to match target batch_size (4).

## 9. Saving and Loading the Adapter

After training, we can save just the adapter weights.

In [None]:
# Save only the adapter weights
model.save_adapter("./gemma3-math-adapter", adapter_name)

## 10. Testing the Adapter

Let's test our math problem-solving adapter on some new examples.

In [None]:
# Let's test the model with a few math problems
test_examples = [
    "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four eggs. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
    "John has 5 pens, 2 pencils, and 3 erasers in his drawer. If he randomly picks 3 items from the drawer, what is the probability that he picks exactly 2 pens?",
    "A rectangle has a length of 12 cm and a width of 8 cm. What is its area in square centimeters?"
]

# Format the test examples with the prompt template
test_prompts = [
    f"Solve the following math problem step-by-step:\n\nQuestion: {text}\n\nAnswer:" 
    for text in test_examples
]

# Tokenize the test prompts
test_inputs = tokenizer(test_prompts, return_tensors="pt", padding=True).to(model.device)

# Generate responses
with torch.no_grad():
    outputs = model.generate(
        **test_inputs,
        max_new_tokens=300,  # More tokens for step-by-step solutions
        temperature=0.3,     # Lower temperature for more deterministic answers
        do_sample=True,      # Some sampling for creativity in problem-solving
        num_beams=3,         # Beam search for better coherence
        no_repeat_ngram_size=2  # Avoid repetition
    )

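# Decode only the newly generated tokens (skipping the prompt) and print the results
prompt_length = test_inputs["input_ids"].shape[1]
for i, output in enumerate(outputs):
    generated_text = tokenizer.decode(output[prompt_length:], skip_special_tokens=True)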
    print(f"Problem: {test_examples[i]}")
    print(f"Solution:\n{generated_text}")
    print("---\n")
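
The prompt used at test time has to match the one used during preprocessing, otherwise the adapter sees a format it was not trained on. One way to keep the two in sync is to build both from a single template, as in the sketch below. `PROMPT_TEMPLATE` and `build_prompt` are hypothetical names, and the template text is copied from the cells above.

```python
PROMPT_TEMPLATE = (
    "Solve the following math problem step-by-step:\n\n"
    "Question: {question}\n\n"
    "Answer:"
)


def build_prompt(question, answer=None):
    """Build the inference prompt, or the full training text when an answer is given."""
    prompt = PROMPT_TEMPLATE.format(question=question)
    if answer is not None:
        prompt += f" {answer}"
    return prompt


inference_prompt = build_prompt("What is 12 * 8?")
training_text = build_prompt("What is 12 * 8?", "12 * 8 = 96\n#### 96")

# The training text always starts with the inference prompt
print(training_text.startswith(inference_prompt))
```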

## 11. Reloading the Adapter in a New Session

In [None]:
# This code demonstrates how you would reload the model and adapter in a new session
# We're not executing this in the notebook as we already have the model loaded

'''
# Load the base model
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-1b-it", 
    token="YOUR_HF_TOKEN",
    device_map="auto"
)

# Create the plugin interface again
plugin_interface = AdapterModelInterface(
    adapter_methods=["lora", "reft"],
    model_embeddings="embed_tokens",
    model_layers="layers",
    layer_self_attn="self_attn",
    layer_cross_attn=None,
    attn_k_proj="k_proj",
    attn_q_proj="q_proj",
    attn_v_proj="v_proj",
    attn_o_proj="o_proj",
    layer_intermediate_proj="mlp.up_proj",
    layer_output_proj="mlp.down_proj",
)

# Initialize adapter support
adapters.init(model, interface=plugin_interface)

# Load the adapter
model.load_adapter("./gemma3-math-adapter", load_as="math")

# Activate the adapter
model.set_active_adapters("math")
'''

## 12. Note on HybridCache for Gemma 3

Gemma 3 uses a dedicated `HybridCache` for attention during generation. It differs from the standard key-value cache and can require special handling in some generation scenarios. Adapter training is unaffected: the adapter methods only modify the projection modules mapped in the interface and leave the caching mechanism unchanged, so the cache does not interfere with adapter training or inference.

## 13. Conclusion

In this notebook, we've demonstrated how to:

1. Create a plugin interface for adding adapter support to Gemma 3
2. Load and initialize the model with the adapter framework
3. Add a LoRA adapter to the model
4. Fine-tune the adapter on the GSM8K math problem-solving task
5. Save and reload the adapter weights
6. Use the adapter for solving new math problems

The plugin interface approach allows you to use parameter-efficient fine-tuning with any Transformer model, even those not officially supported by the adapters library.