# Gemma Model Quick Start Guide

This notebook demonstrates how to use the Gemma 1.1 Instruct 2B model with Keras Hub.

**Model Details:**
- **Family**: Gemma (Google)
- **Variant**: Gemma 1.1 Instruct 2B EN
- **Parameters**: ~2.5 billion (about 2 billion excluding embeddings)
- **Type**: Instruction-tuned text generation model
- **Languages**: English
- **Backends**: JAX, TensorFlow, PyTorch (via Keras 3)

## 1. Installation

Install required packages for Gemma model usage.

In [None]:
# Install Keras Hub and Keras 3
!pip install -q -U keras-hub
!pip install -q -U keras

# Optional: install JAX for the JAX backend (CPU build shown).
# On a GPU machine, install the CUDA-enabled JAX build instead.
!pip install -q -U "jax[cpu]"

## 2. Import Libraries and Configure Backend

In [None]:
import os

# Select the Keras backend ('jax', 'tensorflow', or 'torch').
# This must be set before keras is imported.
os.environ['KERAS_BACKEND'] = 'jax'

import keras
import keras_hub
import numpy as np

print(f"Keras version: {keras.__version__}")
print(f"Keras Hub version: {keras_hub.__version__}")
print(f"Using backend: {keras.backend.backend()}")

## 3. Load Gemma Model

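Load the Gemma 1.1 Instruct 2B model from the Keras Hub preset `gemma_1.1_instruct_2b_en`. The weights are hosted on Kaggle and are gated: accept the Gemma license on the Kaggle model page and make your Kaggle credentials available (for example through the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables or a `kaggle.json` file) before running this cell.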

In [None]:
# Load the Gemma 1.1 Instruct 2B model
# Note: the first run downloads ~5GB of model weights
print("Loading Gemma 1.1 Instruct 2B model...")
print("This may take a few minutes on first run (downloading ~5GB)\n")

gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")

print("✓ Model loaded successfully!")

## 4. Basic Text Generation

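Use the `generate()` method for simple text completion. `max_length` caps the total sequence length in tokens, prompt included, and the returned string contains the prompt followed by the generated continuation.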

In [None]:
# Simple text generation
prompt = "Keras is a"
output = gemma_lm.generate(prompt, max_length=30)

print(f"Prompt: {prompt}")
print(f"Generated: {output}")
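
Because `generate()` echoes the prompt at the start of its output, you may want just the newly generated text. The hypothetical helper `continuation_only` below is a minimal sketch that assumes the output begins with the exact prompt string; if it does not (for example, if detokenization altered whitespace), it returns the output unchanged.

```python
def continuation_only(prompt: str, output: str) -> str:
    """Strip the echoed prompt from a generate() result (hypothetical helper).

    Assumes `output` starts with `prompt`; otherwise returns `output` as-is.
    """
    if output.startswith(prompt):
        return output[len(prompt):]
    return output

# Illustrative strings only, not real model output.
print(repr(continuation_only("Keras is a", "Keras is a deep learning API.")))
```

Applied to the cell above, `continuation_only(prompt, output)` would drop the leading `"Keras is a"`.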

## 5. Batch Generation

Generate text for multiple prompts simultaneously.

In [None]:
# Batch text generation
prompts = [
    "The future of artificial intelligence is",
    "Machine learning can help solve",
    "Deep learning models are"
]

outputs = gemma_lm.generate(prompts, max_length=50)

for prompt, output in zip(prompts, outputs):
    print(f"\nPrompt: {prompt}")
    print(f"Output: {output}")
    print("-" * 80)

## 6. Custom Sampling Strategies

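`compile(sampler=...)` sets the sampling strategy that subsequent `generate()` calls use; it stays in effect until you compile again. A sampler can be given by name (for example `"greedy"` or `"top_k"`) or as a `keras_hub.samplers` instance.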

### Top-K Sampling

In [None]:
# Compile with Top-K sampler
gemma_lm.compile(sampler="top_k")

prompt = "The most important aspect of machine learning is"
output = gemma_lm.generate(prompt, max_length=50)

print("Top-K Sampling:")
print(f"Prompt: {prompt}")
print(f"Output: {output}")
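
Top-k sampling restricts each decoding step to the `k` highest-scoring tokens, renormalizes their probabilities, and samples from those. The NumPy sketch below illustrates a single such step on made-up logits; `top_k_sample` is a hypothetical helper and not how `keras_hub.samplers.TopKSampler` is implemented internally.

```python
import numpy as np

def top_k_sample(logits, k, rng, temperature=1.0):
    """Sample one token id from the k highest-scoring logits (hypothetical helper)."""
    logits = np.asarray(logits, dtype=np.float64) / temperature
    # Indices of the k largest logits; their order among themselves does not matter.
    top_ids = np.argpartition(logits, -k)[-k:]
    top_logits = logits[top_ids]
    # Softmax over the k survivors only, shifted by the max for numerical stability.
    probs = np.exp(top_logits - top_logits.max())
    probs /= probs.sum()
    return int(rng.choice(top_ids, p=probs))

rng = np.random.default_rng(seed=0)
logits = [2.0, 1.0, 0.5, -1.0, 3.0]  # made-up scores for a 5-token vocabulary
samples = [top_k_sample(logits, k=2, rng=rng) for _ in range(1000)]
# With k=2, only the two highest-logit ids (4 and 0) can ever be drawn.
print(sorted(set(samples)))
```

A smaller `k` makes output more conservative, and a `temperature` below 1 sharpens the distribution over the surviving tokens further.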

### Beam Search Sampling

In [None]:
# Compile with Beam sampler
gemma_lm.compile(sampler=keras_hub.samplers.BeamSampler(num_beams=2))

prompt = "Explain quantum computing in simple terms:"
output = gemma_lm.generate(prompt, max_length=100)

print("Beam Search (2 beams):")
print(f"Prompt: {prompt}")
print(f"Output: {output}")
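
Beam search keeps the `num_beams` highest-scoring partial sequences at each step (score = summed log-probability) instead of committing to one token at a time. The toy sketch below replaces the real model with a hypothetical hand-written next-token table to show a case where 2 beams find a more likely sequence than greedy decoding (1 beam). It illustrates the idea only and is not the `BeamSampler` implementation.

```python
import math

def beam_search(next_log_probs, start, num_beams, steps):
    """Return the `num_beams` best (sequence, score) pairs after `steps` tokens.

    `next_log_probs(seq)` maps each possible next token to its log-probability.
    """
    beams = [([start], 0.0)]
    for _ in range(steps):
        candidates = [
            (seq + [token], score + lp)
            for seq, score in beams
            for token, lp in next_log_probs(seq).items()
        ]
        # Keep only the highest-scoring extensions (the sort is stable for ties).
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]
    return beams

# Hypothetical toy model: next-token probabilities depend only on the last token.
TABLE = {
    "<s>": {"a": 0.6, "b": 0.4},
    "a": {"x": 0.5, "y": 0.5},
    "b": {"x": 0.9, "y": 0.1},
}

def toy_model(seq):
    return {tok: math.log(p) for tok, p in TABLE[seq[-1]].items()}

greedy_seq, greedy_score = beam_search(toy_model, "<s>", num_beams=1, steps=2)[0]
beam_seq, beam_score = beam_search(toy_model, "<s>", num_beams=2, steps=2)[0]
print("greedy:", greedy_seq, round(math.exp(greedy_score), 2))  # greedy: ['<s>', 'a', 'x'] 0.3
print("beam:  ", beam_seq, round(math.exp(beam_score), 2))      # beam:   ['<s>', 'b', 'x'] 0.36
```

The trade-off is roughly `num_beams` times more model evaluations per step, and a higher-likelihood sequence is not guaranteed to be a more accurate or more useful answer.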

## 7. Question Answering

Use Gemma for question-answering tasks.

In [None]:
# Switch from beam search back to top-k sampling.
# Recompiling is far cheaper than reloading ~5GB of weights from the preset.
gemma_lm.compile(sampler="top_k")

# Question answering examples
questions = [
    "What is the capital of France?",
    "Explain what a neural network is in one sentence.",
    "How does gradient descent work?"
]

for question in questions:
    answer = gemma_lm.generate(question, max_length=100)
    print(f"\nQ: {question}")
    print(f"A: {answer}")
    print("-" * 80)
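
Gemma instruction-tuned checkpoints were trained on a turn-based chat format marked with `<start_of_turn>` and `<end_of_turn>` control tokens, and they usually respond more reliably when prompts follow it. The hypothetical helper `format_gemma_chat` below is a minimal sketch of that format. It omits the `<bos>` token on the assumption that the Keras Hub preprocessor prepends it, as it does by default.

```python
def format_gemma_chat(turns):
    """Build a Gemma chat-format prompt from (role, text) pairs (hypothetical helper).

    Roles must be "user" or "model". The result ends with an open model turn,
    so generation continues as the model's reply.
    """
    parts = []
    for role, text in turns:
        if role not in ("user", "model"):
            raise ValueError(f"Unsupported role: {role!r}")
        parts.append(f"<start_of_turn>{role}\n{text}<end_of_turn>\n")
    parts.append("<start_of_turn>model\n")
    return "".join(parts)

chat_prompt = format_gemma_chat([("user", "What is the capital of France?")])
print(chat_prompt)
```

You would then pass the result to the model, for example `gemma_lm.generate(format_gemma_chat([("user", question)]), max_length=100)`.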

## 8. Text Summarization

Generate summaries of longer text.

In [None]:
# Text summarization
text_to_summarize = """
Artificial intelligence (AI) is intelligence demonstrated by machines, 
as opposed to natural intelligence displayed by animals including humans. 
AI research has been defined as the field of study of intelligent agents, 
which refers to any system that perceives its environment and takes actions 
that maximize its chance of achieving its goals.
"""

prompt = f"Summarize the following text in one sentence:\n{text_to_summarize}\n\nSummary:"
summary = gemma_lm.generate(prompt, max_length=100)

print(f"Original text: {text_to_summarize}")
print(f"\nSummary: {summary}")

## 9. Code Generation

Generate code snippets based on natural language descriptions.

In [None]:
# Code generation example
code_prompts = [
    "Write a Python function to calculate the factorial of a number:",
    "Create a function that sorts a list in Python:",
    "Write code to read a CSV file using pandas:"
]

for prompt in code_prompts:
    code = gemma_lm.generate(prompt, max_length=150)
    print(f"\nPrompt: {prompt}")
    print(f"Generated code:\n{code}")
    print("=" * 80)
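
Instruction-tuned models often wrap code answers in Markdown fences. If you want to run or save only the code, a small regex-based extractor like the hypothetical `extract_code_blocks` below is one option; it assumes well-formed fences and will miss code the model emits without them.

```python
import re

# Built with "`" * 3 so this snippet can itself sit inside a Markdown fence.
FENCE = "`" * 3
# Optional language tag after the opening fence, then a lazy match up to the closing fence.
FENCE_RE = re.compile(FENCE + r"[\w+-]*\n(.*?)" + FENCE, re.DOTALL)

def extract_code_blocks(text: str) -> list[str]:
    """Return the contents of Markdown code fences found in `text`."""
    return [block.strip("\n") for block in FENCE_RE.findall(text)]

sample = f"Here is one:\n{FENCE}python\ndef f(n):\n    return n\n{FENCE}\nDone."
print(extract_code_blocks(sample))
```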

## 10. Advanced: Low-Level Token Generation

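Skip the preprocessor and pass token IDs directly. With `preprocessor=None`, `generate()` expects a dict of `token_ids` and `padding_mask` arrays. Positions where `padding_mask` is 0 are filled in by the model, so the array width (7 here) caps the total length, and the output is token IDs rather than text; you can decode them with `gemma_lm.preprocessor.tokenizer.detokenize()`.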

In [None]:
# Generate using token IDs directly
prompt = {
    # Token IDs: start token followed by "Keras is"
    "token_ids": np.array([[2, 214064, 603, 0, 0, 0, 0]] * 2),
    # Padding mask: 1 marks prompt tokens to keep, 0 marks positions the model fills in
    "padding_mask": np.array([[1, 1, 1, 0, 0, 0, 0]] * 2),
}

# Reuse the already-loaded backbone rather than loading a second copy of the weights
gemma_lm_no_preproc = keras_hub.models.GemmaCausalLM(
    backbone=gemma_lm.backbone,
    preprocessor=None,
)

output = gemma_lm_no_preproc.generate(prompt)
print(f"Token-based generation output:\n{output}")
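
Building the `token_ids`/`padding_mask` dict by hand gets error-prone with several prompts of different lengths. The hypothetical `pad_prompts` helper below right-pads each list of IDs to a fixed width and sets the mask to match; it reproduces the arrays from the cell above. It assumes a padding ID of 0, matching the zeros used there.

```python
import numpy as np

def pad_prompts(token_id_lists, length, pad_id=0):
    """Right-pad token-ID lists into a {"token_ids", "padding_mask"} dict (hypothetical helper)."""
    batch = len(token_id_lists)
    token_ids = np.full((batch, length), pad_id, dtype="int32")
    padding_mask = np.zeros((batch, length), dtype="int32")
    for row, ids in enumerate(token_id_lists):
        if len(ids) > length:
            raise ValueError(f"prompt {row} has {len(ids)} tokens, limit is {length}")
        token_ids[row, : len(ids)] = ids
        # 1 = keep this prompt token; 0 = position the model will generate into.
        padding_mask[row, : len(ids)] = 1
    return {"token_ids": token_ids, "padding_mask": padding_mask}

# Same start-token + "Keras is" IDs as in the cell above, padded to width 7.
padded = pad_prompts([[2, 214064, 603]] * 2, length=7)
print(padded["token_ids"])
print(padded["padding_mask"])
```

Prompts of different lengths work the same way, for example `pad_prompts([[2, 214064, 603], [2, 214064]], length=7)`.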

## 11. Fine-tuning Example (Single Batch)

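Fine-tune the model on custom data. This cell runs one epoch on a two-example batch of raw strings; `fit()` accepts plain text because the preprocessor tokenizes it and builds the next-token prediction targets. Updating all weights this way needs substantially more memory than inference.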

In [None]:
# Fine-tuning on a single batch
features = [
    "The quick brown fox jumped over the lazy dog.",
    "Machine learning is a subset of artificial intelligence."
]

# Load a separate copy so training does not alter the weights of `gemma_lm`
gemma_lm_train = keras_hub.models.GemmaCausalLM.from_preset("gemma_1.1_instruct_2b_en")

# Note: This is a simple example. In practice, you'd need more data and epochs
print("Fine-tuning on sample batch...")
gemma_lm_train.fit(x=features, batch_size=2, epochs=1)
print("✓ Fine-tuning complete!")
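
Training every one of the ~2.5B weights, as the cell above does, needs far more memory than inference, because gradients and optimizer state (Adam, for instance, keeps two extra values per weight) scale with the number of trainable weights. Keras Hub backbones support LoRA through `gemma_lm_train.backbone.enable_lora(rank=4)`, called before training, which freezes the original weights and trains small low-rank update matrices instead. The arithmetic below, using a hypothetical helper `lora_trainable_params`, sketches why that helps for a single weight matrix.

```python
def lora_trainable_params(d_in, d_out, rank):
    """Parameters LoRA trains for one frozen d_in x d_out weight matrix.

    The update is factored as a (d_in x rank) matrix times a (rank x d_out)
    matrix, so it has rank * (d_in + d_out) entries.
    """
    return rank * (d_in + d_out)

# A 2048 x 2048 projection (2048 is Gemma 2B's model width), LoRA rank 4.
full = 2048 * 2048
lora = lora_trainable_params(2048, 2048, rank=4)
print(f"full: {full:,}  LoRA: {lora:,}  ratio: {lora / full:.2%}")
```

At rank 4, the trainable update for this matrix is under 0.4% of the full matrix, and the optimizer state shrinks by the same factor.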

## 12. Model Information and Configuration

In [None]:
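# Display model configuration
print("Gemma Model Configuration:")
print("=" * 50)
print(f"Vocabulary size: {gemma_lm.preprocessor.tokenizer.vocabulary_size()}")
# Architecture hyperparameters (layers, hidden size, attention heads, ...)
print(gemma_lm.backbone.get_config())
print("\nModel summary:")
gemma_lm.summary()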

## 13. Performance Tips

**Optimization strategies:**

1. **Choose the right backend**: JAX is recommended for Gemma models
2. **Batch processing**: Process multiple prompts together for efficiency
3. **Adjust max_length**: `max_length` counts prompt plus generated tokens; a lower limit means fewer decoding steps and faster generation
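4. **Use appropriate sampling**: Greedy and top-k are the cheapest per step; beam search tracks several candidate sequences, so it is slower, and it favors higher-likelihood text, which is not the same as more accurate text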
5. **GPU acceleration**: Use GPU/TPU for production workloads

**Memory considerations:**
- Gemma 2B requires ~8GB RAM minimum
- Gemma 7B requires ~28GB RAM minimum
- Use smaller batch sizes if memory-constrained
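
The RAM figures above follow a common rule of thumb: memory for the weights is roughly the parameter count times the bytes per parameter, with activations, the generation cache, and framework overhead on top. A minimal sketch of the weights-only term, assuming ~2.5B parameters for Gemma 2B (the helper name `weight_memory_gb` is hypothetical):

```python
# Bytes needed to store one parameter in each dtype.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2, "int8": 1}

def weight_memory_gb(num_params, dtype):
    """Approximate memory for the weights alone, in gigabytes (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

for dtype in ("float32", "bfloat16"):
    print(dtype, round(weight_memory_gb(2.5e9, dtype), 1), "GB")
```

Halving the bytes per parameter halves the weight memory, which is why lower-precision dtypes matter on memory-constrained machines.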

## 14. Common Use Cases

**Common use cases for Gemma models:**
- ✓ Question answering
- ✓ Text summarization
- ✓ Code generation
- ✓ Chatbots and conversational AI
- ✓ Content creation
- ✓ Language translation (limited in English-focused variants such as this one)
- ✓ Text classification
- ✓ Named entity recognition

**Limitations:**
- ⚠ May generate factually incorrect information
- ⚠ This `_en` variant is trained primarily on English text
- ⚠ Context window limited to 8,192 tokens
- ⚠ May reflect biases from training data

## 15. Next Steps

**Explore further:**
1. Try the Gemma 7B model for higher-quality output, at the cost of more memory and compute
2. Experiment with different sampling strategies
3. Fine-tune on domain-specific data
4. Build a chatbot or Q&A system
5. Integrate with LangChain for advanced applications
6. Deploy to production with optimizations

**Resources:**
- [Gemma on Kaggle](https://www.kaggle.com/models/google/gemma)
- [Keras Hub Documentation](https://keras.io/keras_hub/)
- [Gemma Model Card](https://ai.google.dev/gemma/docs)
- [Responsible AI Toolkit](https://ai.google.dev/responsible)