# Gemma Model Fine-tuning with TPU/GPU

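This notebook demonstrates how to fine-tune the Gemma model on custom datasets using either TPU or GPU accelerators. Training runs through the Hugging Face `Trainer` (PyTorch); JAX is used only to report whether TPU devices are visible.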

In [None]:
# Install required packages if needed
!pip install -q transformers datasets accelerate jax flax optax huggingface_hub

In [None]:
import os
import pandas as pd
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
import torch
import matplotlib.pyplot as plt

# For TPU support
import jax
import flax.linen as nn

## 1. Load and Prepare the Dataset

In [None]:
# Choose where the training data comes from.
# Option 1: a local CSV file with a "text" column (replace with your dataset path)
dataset_path = "../data/sample_dataset.csv"

# Option 2: Load from Hugging Face Datasets
# dataset = load_dataset("your_dataset_name")

if not os.path.exists(dataset_path):
    # No local file found: write a tiny demo CSV so the remaining cells can still run.
    print("Using demo dataset for testing...")
    data = {
        "text": [
            "Gemma is a large language model developed by Google.",
            "Fine-tuning allows adaptation of models to specific tasks.",
            "This is a sample dataset for demonstration purposes."
        ]
    }
    df = pd.DataFrame(data)
    demo_csv_path = "demo_dataset.csv"
    df.to_csv(demo_csv_path, index=False)
    dataset = load_dataset("csv", data_files={"train": demo_csv_path})
else:
    # Local CSV is present: load it as the "train" split and show a short preview.
    dataset = load_dataset("csv", data_files={"train": dataset_path})
    loaded_train = dataset["train"]
    print(f"Dataset loaded successfully: {len(loaded_train)} examples")
    print(loaded_train[:3])

## 2. Load Model and Tokenizer

In [None]:
# If downloading the checkpoint requires authentication, expose your
# Hugging Face token through the environment before loading:
# os.environ["HUGGING_FACE_HUB_TOKEN"] = "your_hf_token_here"

# Checkpoint to fine-tune
model_name = "google/gemma-2b"

# Tokenizer first; padding falls back to the EOS token when no pad token is defined.
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# device_map="auto" leaves the placement of the weights to the library.
model_load_options = {"device_map": "auto"}
model = AutoModelForCausalLM.from_pretrained(model_name, **model_load_options)
print(f"Model loaded: {model_name}")

## 3. Preprocess the Dataset

In [None]:
# Define preprocessing function
def preprocess_function(examples):
    """Tokenize a batch of examples and attach causal-LM training labels.

    Args:
        examples: A batch mapping as passed by ``dataset.map(..., batched=True)``;
            only its "text" column is read.

    Returns:
        The tokenizer output with an added "labels" entry. Labels copy
        ``input_ids``, except that positions where ``attention_mask`` is 0
        (padding) are set to -100 so the loss ignores them.
    """
    encoded = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512  # Adjust as needed
    )
    # Without a "labels" column the causal LM returns no loss and
    # Trainer.train() fails, so labels are derived here from the inputs.
    encoded["labels"] = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
    ]
    return encoded

# Run the tokenizer over the dataset in batches; the original columns are
# dropped so only the tokenizer outputs remain.
raw_column_names = dataset["train"].column_names
tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=raw_column_names)

num_tokenized_examples = len(tokenized_dataset["train"])
print(f"Tokenized dataset: {num_tokenized_examples} examples")

## 4. Configure Training Arguments

In [None]:
# Define output directory
output_dir = "./gemma-fine-tuned"

# Check available hardware
if torch.cuda.is_available():
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    device = "cuda"
else:
    print("No GPU available, using CPU")
    device = "cpu"

# Try to detect TPUs. This is informational only: `using_tpu` is not consumed
# by the training arguments below.
try:
    tpu_devices = jax.devices('tpu')
    print(f"TPU devices available: {len(tpu_devices)}")
    using_tpu = True
except Exception:
    # Catch Exception instead of using a bare `except:` so that
    # KeyboardInterrupt and SystemExit still propagate.
    print("No TPU devices detected")
    using_tpu = False

# Training Arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=8,  # Adjust based on available memory
    gradient_accumulation_steps=2,  # Increase for larger effective batch size
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="epoch",
    fp16=(device == "cuda"),  # Use mixed precision on GPU
    save_total_limit=2,  # Keep only the last 2 checkpoints
    report_to="tensorboard",
)

print("Training arguments configured")

## 5. Train the Model

In [None]:
# Build the Trainer from the model, the training configuration and the
# tokenized "train" split.
train_data = tokenized_dataset["train"]
trainer = Trainer(model=model, args=training_args, train_dataset=train_data)

# Run fine-tuning
print("Starting training...")
trainer.train()
print("Training completed!")

## 6. Save the Model

In [None]:
# Write the model first, then the tokenizer, into the same output directory
for artifact in (model, tokenizer):
    artifact.save_pretrained(output_dir)
print(f"Model saved to: {output_dir}")

## 7. Test the Fine-tuned Model

In [None]:
# Reload the saved tokenizer and model from the output directory for testing
fine_tuned_tokenizer = AutoTokenizer.from_pretrained(output_dir)
fine_tuned_model = AutoModelForCausalLM.from_pretrained(output_dir, device_map="auto")

# Fall back to the EOS token for padding when the tokenizer defines no pad token
if fine_tuned_tokenizer.pad_token is None:
    fine_tuned_tokenizer.pad_token = fine_tuned_tokenizer.eos_token

# Test generation function
def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9):
    """Sample a continuation of `prompt` from the fine-tuned model.

    Args:
        prompt: Text passed to the fine-tuned tokenizer.
        max_length: Forwarded to `generate` as `max_length`.
        temperature: Forwarded to `generate` as `temperature`.
        top_p: Forwarded to `generate` as `top_p`.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    # Tokenize, then move the encoded inputs to the model's device.
    encoded = fine_tuned_tokenizer(prompt, return_tensors="pt")
    encoded = encoded.to(fine_tuned_model.device)

    sampling_options = dict(
        max_length=max_length,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=fine_tuned_tokenizer.pad_token_id,
    )

    # Gradients are not needed for generation.
    with torch.no_grad():
        sequences = fine_tuned_model.generate(**encoded, **sampling_options)

    first_sequence = sequences[0]
    return fine_tuned_tokenizer.decode(first_sequence, skip_special_tokens=True)

# Generate from one sample prompt and show the prompt next to the result
test_prompt = "What is fine-tuning in machine learning?"
generated_text = generate_text(test_prompt)
print(
    f"Prompt: {test_prompt}",
    f"Generated text: {generated_text}",
    sep="\n",
)

## 8. Push to Hugging Face Hub (Optional)

In [None]:
# Optional: publish the fine-tuned model and tokenizer to the Hugging Face Hub.
# Set PUSH_TO_HUB to True and HUB_REPO_ID to your own repository to enable it.
# A flag is used instead of a triple-quoted string so the cell does not echo
# the disabled code as its output.
PUSH_TO_HUB = False
HUB_REPO_ID = "your-username/gemma-fine-tuned"

if PUSH_TO_HUB:
    # Imported here so the login helper is only loaded when pushing is requested
    from huggingface_hub import notebook_login

    # Login to Hugging Face
    notebook_login()

    # Push model to hub
    fine_tuned_model.push_to_hub(HUB_REPO_ID)
    fine_tuned_tokenizer.push_to_hub(HUB_REPO_ID)
    print(f"Model pushed to HuggingFace Hub: {HUB_REPO_ID}")