# Fine-tune Google Gemma-3 4B with DeepSpeed ZeRO-3 on Amazon SageMaker AI using ModelTrainer

In this notebook, we fine-tune [Google Gemma-3 4B Instruct](https://huggingface.co/google/gemma-3-4b-it) on Amazon SageMaker AI, using Python scripts and the SageMaker `ModelTrainer` class to run a training job with the DeepSpeed ZeRO-3 distributed training strategy.

## Overview

- **Model**: google/gemma-3-4b-it
- **Strategy**: DeepSpeed ZeRO-3 with CPU offloading
- **Dataset**: HuggingFaceH4/Multilingual-Thinking (Apache 2.0 license)
- **Training**: LoRA fine-tuning with merged weights

## Prerequisites

In [None]:
%pip install -r ./scripts/requirements.txt --upgrade

## Setup Configuration

Configure your Hugging Face token and, optionally, the ARN of your MLflow tracking server.

In [None]:
import os

model_id = "google/gemma-3-4b-it"

os.environ["HF_TOKEN"] = "<HF_TOKEN>"
os.environ["model_id"] = model_id
os.environ["mlflow_uri"] = "arn:aws:sagemaker:region:account_id:mlflow-app/app-xxxxxxxx"
os.environ["mlflow_experiment_name"] = "gemma-3-4b-it-reasoning-multi-language"
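Gemma-3 is a gated model on Hugging Face, and the `HF_TOKEN` value set above is later written verbatim into `args.yaml`, so a leftover placeholder would only surface as an authentication error inside the training job. The following is a minimal, optional guard, assuming you keep the token in the `HF_TOKEN` environment variable; `require_hf_token` is a hypothetical helper, not part of any library.

```python
import os
from typing import Mapping, Optional

PLACEHOLDER_TOKEN = "<HF_TOKEN>"  # the placeholder used in the cell above


def require_hf_token(env: Optional[Mapping[str, str]] = None) -> str:
    """Return HF_TOKEN, failing fast if it is missing or still the placeholder."""
    env = os.environ if env is None else env
    token = env.get("HF_TOKEN", "").strip()
    if not token or token == PLACEHOLDER_TOKEN:
        raise ValueError("Set HF_TOKEN to a valid Hugging Face access token before continuing.")
    return token
```

Calling `require_hf_token()` with no arguments checks `os.environ` directly.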

## Visualize and upload the dataset

We load the [HuggingFaceH4/Multilingual-Thinking](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking) dataset (Apache 2.0 license).

In [None]:
import sagemaker

sagemaker_session = sagemaker.Session()
bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix

In [None]:
from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/Multilingual-Thinking", split="train")

dataset

In [None]:
import pandas as pd

df = pd.DataFrame(dataset)

df.head()

In [None]:
from sklearn.model_selection import train_test_split

train, val = train_test_split(df, test_size=0.1, random_state=42)
train, test = train_test_split(train, test_size=10, random_state=42)

print("Number of train elements: ", len(train))
print("Number of val elements: ", len(val))
print("Number of test elements: ", len(test))
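The two calls above first hold out 10% of the rows for validation and then take exactly 10 rows from the remainder for the test set. A small sketch of that arithmetic, assuming scikit-learn's documented behavior of rounding a fractional `test_size` up (`ceil`) and treating an integer `test_size` as an absolute row count; `split_sizes` is a hypothetical helper for illustration only.

```python
import math


def split_sizes(n_rows: int, val_fraction: float = 0.1, n_test: int = 10) -> tuple:
    """Return (n_train, n_val, n_test) for the two train_test_split calls above."""
    # First call: a float test_size is rounded up, so ceil(0.1 * n_rows) rows become validation
    n_val = math.ceil(val_fraction * n_rows)
    remaining = n_rows - n_val
    # Second call: an integer test_size is an absolute number of rows
    n_train = remaining - n_test
    if n_train <= 0:
        raise ValueError("Not enough rows left for training after the val/test splits")
    return n_train, n_val, n_test
```

Comparing `split_sizes(len(df))` with the printed counts is a quick way to confirm the splits behave as intended.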

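Create a prompt template and format the dataset with the Gemma-3 chat template. `prepare_dataset` folds the system prompt into the first user message instead of sending it as a separate system turn.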

In [None]:
import textwrap
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def prepare_dataset(sample):
    messages = []
    system_prompt = None  # set once the sample's system message is processed
    first_user_message = True

    for el in sample["messages"]:
        if el["role"] == "system":
            system_prompt = """
            You are an AI assistant that thinks in {language} but responds in English.

            IMPORTANT: Follow this exact format for every response:
            1. First, write your reasoning and thoughts inside <think>...</think> tags
            2. Then, provide your final answer in English

            Always think through the problem in {language}, then translate your conclusion to English for the final response.
            """
            system_prompt = system_prompt.format(language=sample["reasoning_language"])
            system_prompt = textwrap.dedent(system_prompt).strip()
        elif el["role"] == "user":
            content = el["content"]
            if first_user_message:
                first_user_message = False
                # Prepend the system prompt (if any) to the first user message
                if system_prompt:
                    content = system_prompt + "\n\n" + content
            messages.append({"role": "user", "content": content})
        else:
            if el["thinking"] is not None and el["thinking"] != "" and el["thinking"] != "null":
                messages.append({
                    "role": "assistant",
                    "content": f"<think>\n{el['thinking']}\n</think>\n{el['content']}",
                })
            else:
                messages.append({"role": "assistant", "content": el["content"]})

    sample["text"] = tokenizer.apply_chat_template(messages, tokenize=False)
    return sample
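Gemma-3's chat template, as we read the version published with the model, has no separate system turn: a leading system message is prepended to the first user turn followed by a blank line, and assistant turns are labeled `model`. The sketch below is a simplified, text-only approximation of that behavior (a hypothetical `render_gemma3_text_chat`, not a library function) to show why merging the system prompt into the first user message yields the same text as passing a system message. `tokenizer.apply_chat_template` remains the authoritative renderer, and the real template may differ between releases.

```python
from typing import Dict, List


def render_gemma3_text_chat(messages: List[Dict[str, str]]) -> str:
    """Simplified, text-only approximation of the Gemma-3 chat template (no generation prompt)."""
    rendered = "<bos>"
    first_user_prefix = ""
    if messages and messages[0]["role"] == "system":
        # The system message is not a turn of its own; it prefixes the first user turn
        first_user_prefix = messages[0]["content"] + "\n\n"
        messages = messages[1:]
    for i, message in enumerate(messages):
        expected = "user" if i % 2 == 0 else "assistant"
        if message["role"] != expected:
            raise ValueError("Conversation roles must alternate user/assistant")
        role = "model" if message["role"] == "assistant" else "user"  # assistant turns are labeled "model"
        prefix = first_user_prefix if i == 0 else ""
        rendered += f"<start_of_turn>{role}\n{prefix}{message['content'].strip()}<end_of_turn>\n"
    return rendered
```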

In [None]:
from datasets import Dataset, DatasetDict
from random import randint

train_dataset = Dataset.from_pandas(train)
val_dataset = Dataset.from_pandas(val)
test_dataset = Dataset.from_pandas(test)

dataset = DatasetDict({"train": train_dataset, "val": val_dataset})

train_dataset = dataset["train"].map(
    prepare_dataset, remove_columns=list(train_dataset.features)
)

# Print a random formatted training example
print(train_dataset[randint(0, len(train_dataset) - 1)]["text"])

val_dataset = dataset["val"].map(
    prepare_dataset, remove_columns=list(val_dataset.features)
)

### Upload to Amazon S3

In [None]:
import boto3
import shutil
import sagemaker

sagemaker_session = sagemaker.Session()
s3_client = boto3.client('s3')

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix

In [None]:
if default_prefix:
    input_path = f"{default_prefix}/datasets/gemma-3-4b-it-fine-tuning-dsz3"
else:
    input_path = f"datasets/gemma-3-4b-it-fine-tuning-dsz3"

train_dataset_s3_path = f"s3://{bucket_name}/{input_path}/train/dataset.json"
val_dataset_s3_path = f"s3://{bucket_name}/{input_path}/val/dataset.json"

In [None]:
import os

os.makedirs("./data/train", exist_ok=True)
os.makedirs("./data/val", exist_ok=True)

train_dataset.to_json("./data/train/dataset.json", orient="records")
val_dataset.to_json("./data/val/dataset.json", orient="records")

s3_client.upload_file("./data/train/dataset.json", bucket_name, f"{input_path}/train/dataset.json")
s3_client.upload_file("./data/val/dataset.json", bucket_name, f"{input_path}/val/dataset.json")

shutil.rmtree("./data")

print(f"Training data uploaded to:")
print(train_dataset_s3_path)
print(val_dataset_s3_path)
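`Dataset.to_json(..., orient="records")` writes JSON Lines by default (one object per line), and after `prepare_dataset` each record should carry a single `text` field. If you want to sanity-check the local files before `shutil.rmtree` deletes them, a minimal sketch like the following could be run on each one; `count_valid_jsonl_records` is a hypothetical helper.

```python
import json
from typing import Iterable


def count_valid_jsonl_records(lines: Iterable[str], required_field: str = "text") -> int:
    """Count JSON Lines records, raising ValueError if one lacks a string `required_field`."""
    count = 0
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # tolerate blank lines such as a trailing newline
        record = json.loads(line)
        if not isinstance(record, dict) or not isinstance(record.get(required_field), str):
            raise ValueError(f"Line {line_no}: expected an object with a string '{required_field}' field")
        count += 1
    return count
```

Because it accepts any iterable of lines, an open file handle works directly, for example `count_valid_jsonl_records(open("./data/train/dataset.json", encoding="utf-8"))`.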

## Model fine-tuning

We are now ready to fine-tune the model with the [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) from transformers. The script [train.py](./scripts/train.py) loads the datasets from disk, prepares the model and tokenizer, and starts training.

### Training configurations

For configuration, we use `TrlParser`, which lets us provide hyperparameters in a YAML file. The YAML file is uploaded to Amazon S3 and passed to the SageMaker training job as an input channel, just like the datasets.

In [None]:
%%bash

cat > ./args.yaml <<EOF
model_id: "${model_id}"                           # Hugging Face model id
mlflow_uri: "${mlflow_uri}"                       # MLflow tracking server URI
mlflow_experiment_name: "${mlflow_experiment_name}" # MLflow experiment name
# sagemaker specific parameters
output_dir: "/opt/ml/model"                       # directory whose contents SageMaker uploads to S3 as the model
checkpoint_dir: "/opt/ml/checkpoints/"            # directory for saving training checkpoints
train_dataset_path: "/opt/ml/input/data/train/"   # local path where SageMaker downloads the train channel
val_dataset_path: "/opt/ml/input/data/val/"       # local path where SageMaker downloads the val channel
token: "${HF_TOKEN}"                              # Hugging Face API token
merge_weights: true                               # merge weights in the base model
# training parameters
apply_truncation: true                           # apply truncation to datasets
attn_implementation: "flash_attention_2"         # attention implementation type
learning_rate: 2e-5                              # initial learning rate
num_train_epochs: 10                             # number of training epochs
per_device_train_batch_size: 1                   # batch size per device during training
per_device_eval_batch_size: 2                    # batch size for evaluation
gradient_accumulation_steps: 16                  # micro-batches to accumulate before each optimizer update
gradient_checkpointing: true                     # use gradient checkpointing
torch_dtype: "bfloat16"                          # float precision type
bf16: true                                       # use bfloat16 precision
tf32: true                                       # use tf32 precision
ignore_data_skip: true                           # on resume, don't fast-forward past already-seen batches
logging_strategy: "steps"                        # logging strategy
logging_steps: 1                                 # log every N steps
log_on_each_node: false                          # disable logging on each node
ddp_find_unused_parameters: false                # DDP unused parameter detection
save_total_limit: 1                              # maximum number of checkpoints to keep
save_steps: 100                                  # save a checkpoint every N steps
warmup_steps: 50                                 # number of warmup steps
weight_decay: 0.01                               # weight decay coefficient
dataloader_pin_memory: false                     # disable pinned memory in the dataloader
# LoRA parameters
load_in_4bit: false                              # load the base model in 4-bit when true
lora_r: 16                                       # LoRA rank
lora_alpha: 32                                   # LoRA alpha parameter
lora_dropout: 0.1                                # LoRA dropout rate
EOF
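With `per_device_train_batch_size: 1` and `gradient_accumulation_steps: 16` on the 4 GPUs of an `ml.g5.12xlarge` (the instance type used later in this notebook), each optimizer update sees 1 × 16 × 4 = 64 sequences, because ZeRO is a form of data parallelism and every GPU processes its own micro-batches. The sketch below makes that arithmetic explicit and gives a rough step count to compare against `warmup_steps` and `save_steps`. Both helpers are hypothetical, and the Trainer's exact step count can differ slightly depending on how the dataloader is sharded.

```python
import math


def effective_batch_size(per_device_batch: int, grad_accum_steps: int, gpus_per_node: int, num_nodes: int = 1) -> int:
    """Sequences contributing to one optimizer update across all data-parallel ranks."""
    return per_device_batch * grad_accum_steps * gpus_per_node * num_nodes


def approx_total_optimizer_steps(num_examples: int, global_batch: int, num_epochs: int) -> int:
    """Rough optimizer-step count; the Trainer's exact value can differ by sharding and rounding."""
    return math.ceil(num_examples / global_batch) * num_epochs
```

For this notebook, `approx_total_optimizer_steps(len(train_dataset), effective_batch_size(1, 16, 4), 10)` gives the order of magnitude to weigh against `warmup_steps: 50`.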

In [None]:
import os
from sagemaker.s3 import S3Uploader

if default_prefix:
    input_path = f"s3://{bucket_name}/{default_prefix}/datasets/gemma-3-4b-it-fine-tuning-dsz3"
else:
    input_path = f"s3://{bucket_name}/datasets/gemma-3-4b-it-fine-tuning-dsz3"

model_yaml = "args.yaml"
train_config_s3_path = S3Uploader.upload(local_path=model_yaml, desired_s3_uri=f"{input_path}/config")

os.remove("./args.yaml")

print(f"Training config uploaded to:")
print(train_config_s3_path)

### DeepSpeed ZeRO-3 configurations

In [None]:
%%bash

cat > ./accelerate_config.yaml <<EOF
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
main_training_function: main
mixed_precision: bf16
rdzv_backend: c10d                        # static for single node, c10d for single and multi-node
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
EOF
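With `zero_stage: 3`, DeepSpeed partitions parameters, gradients, and optimizer states across the data-parallel GPUs, and `offload_param_device: cpu` / `offload_optimizer_device: cpu` move the parameter and optimizer-state partitions to host memory. To get a feel for why this matters on 24 GB A10G GPUs, the sketch below applies the model-state accounting from the ZeRO paper for mixed-precision Adam (16 bytes per trainable parameter, 2 bytes per frozen bf16 parameter) and divides by the number of GPUs. It ignores activations, temporary buffers, and fragmentation, so treat the results as lower bounds. The helpers are hypothetical, and the parameter counts you pass in (Gemma-3 4B has roughly 4 billion) are illustrative.

```python
def zero3_full_finetune_gb(num_params: float, num_gpus: int) -> float:
    """Per-GPU model-state memory (GB) for full fine-tuning with mixed-precision Adam under ZeRO-3.

    16 bytes per parameter: 2 (bf16 weights) + 2 (bf16 gradients) + 12 (fp32 master weights,
    Adam momentum and variance), partitioned evenly across `num_gpus`.
    """
    return num_params * 16 / num_gpus / 1e9


def zero3_lora_gb(frozen_params: float, trainable_params: float, num_gpus: int) -> float:
    """Same estimate for LoRA: frozen weights need only 2 bytes (bf16); adapters need all 16."""
    return (frozen_params * 2 + trainable_params * 16) / num_gpus / 1e9
```

With CPU offloading enabled, most of these bytes live in host RAM rather than GPU memory, trading GPU memory for PCIe transfer time.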

In [None]:
import os
from sagemaker.s3 import S3Uploader

if default_prefix:
    input_path = f"s3://{bucket_name}/{default_prefix}/datasets/gemma-3-4b-it-fine-tuning-dsz3"
else:
    input_path = f"s3://{bucket_name}/datasets/gemma-3-4b-it-fine-tuning-dsz3"

model_yaml = "accelerate_config.yaml"
train_accelerate_config_s3_path = S3Uploader.upload(
    local_path=model_yaml, desired_s3_uri=f"{input_path}/accelerate_config"
)

os.remove("./accelerate_config.yaml")

print(f"Accelerate config uploaded to:")
print(train_accelerate_config_s3_path)

## Fine-tune model

The `ModelTrainer` below trains the model with LoRA, merges the adapter into the base model, and saves the result to Amazon S3.
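LoRA trains two small matrices per adapted layer, `A` (r × d_in) and `B` (d_out × r), and the effective weight is `W + (lora_alpha / r) · B · A`, which is how PEFT scales the update by default. Merging folds that product into `W` once, so the saved model needs no adapter at inference time. With `lora_r: 16` and `lora_alpha: 32` from `args.yaml`, the scaling factor is 2. The plain-Python sketch below illustrates the merge on toy matrices; it is not the code in `train.py`, which we assume delegates merging to PEFT.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))] for i in range(len(a))]


def matvec(m: Matrix, x: List[float]) -> List[float]:
    """Plain-Python matrix-vector product m @ x."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in m]


def merge_lora(weight: Matrix, lora_b: Matrix, lora_a: Matrix, lora_alpha: float, r: int) -> Matrix:
    """Fold a LoRA update into the base weight: W' = W + (lora_alpha / r) * B @ A."""
    scaling = lora_alpha / r  # 32 / 16 = 2.0 with the values in args.yaml
    delta = matmul(lora_b, lora_a)  # (d_out x r) @ (r x d_in) -> d_out x d_in
    return [[w + scaling * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(weight, delta)]
```

The merged weight produces the same output as running the base layer and the scaled adapter separately, which is what makes serving the merged model equivalent to serving base model plus adapter.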

### Get PyTorch image_uri

We use the pre-built PyTorch training container image for Amazon SageMaker.

In [None]:
import sagemaker
from sagemaker.config import load_sagemaker_config

sagemaker_session = sagemaker.Session()

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix
configs = load_sagemaker_config()

In [None]:
# ml.g5.12xlarge has 4x A10G GPUs
instance_type = "ml.g5.12xlarge"
instance_count = 1

instance_type

In [None]:
image_uri = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=sagemaker_session.boto_session.region_name,
    version="2.7.1",
    instance_type=instance_type,
    image_scope="training",
)

image_uri

In [None]:
from sagemaker.modules.configs import (
    CheckpointConfig,
    Compute,
    OutputDataConfig,
    SourceCode,
    StoppingCondition,
)
from sagemaker.modules.distributed import Torchrun
from sagemaker.modules.train import ModelTrainer

args = [
    "--entrypoint",
    "train.py",
    "--accelerate_config",
    "/opt/ml/input/data/accelerate_config/accelerate_config.yaml",
    "--config",
    "/opt/ml/input/data/config/args.yaml",
]

source_code = SourceCode(
    source_dir="./scripts",
    requirements="requirements.txt",
    command=f"bash sm_accelerate_train.sh {' '.join(args)}",
)

compute_configs = Compute(
    instance_type=instance_type,
    instance_count=instance_count,
    keep_alive_period_in_seconds=0,
)

job_name = f"train-{model_id.split('/')[-1].replace('.', '-')}-dsz3"

if default_prefix:
    output_path = f"s3://{bucket_name}/{default_prefix}/{job_name}"
else:
    output_path = f"s3://{bucket_name}/{job_name}"

model_trainer = ModelTrainer(
    training_image=image_uri,
    source_code=source_code,
    base_job_name=job_name,
    compute=compute_configs,
    stopping_condition=StoppingCondition(max_runtime_in_seconds=18000),
    output_data_config=OutputDataConfig(
        s3_output_path=output_path, compression_type="NONE"
    ),
    checkpoint_config=CheckpointConfig(
        s3_uri=output_path + "/checkpoint", local_path="/opt/ml/checkpoints"
    ),
)

In [None]:
from sagemaker.modules.configs import InputData

train_input = InputData(
    channel_name="train",
    data_source=train_dataset_s3_path,
)

val_input = InputData(
    channel_name="val",
    data_source=val_dataset_s3_path,
)

config_input = InputData(
    channel_name="config",
    data_source=train_config_s3_path,
)

accelerate_config_input = InputData(
    channel_name="accelerate_config",
    data_source=train_accelerate_config_s3_path,
)

data = [train_input, val_input, config_input, accelerate_config_input]
data
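Each `InputData` channel is downloaded into the container under `/opt/ml/input/data/<channel_name>/`, which is why `args` points at `/opt/ml/input/data/config/args.yaml` and `/opt/ml/input/data/accelerate_config/accelerate_config.yaml`. A minimal sketch of resolving these paths inside a training script, assuming the container may also export the `SM_CHANNEL_<NAME>` environment variables that SageMaker framework containers typically set (the code falls back to the fixed layout when they are absent); `channel_dir` and `channel_file` are hypothetical helpers.

```python
import os
from pathlib import PurePosixPath
from typing import Mapping, Optional

SM_INPUT_DATA_DIR = PurePosixPath("/opt/ml/input/data")


def channel_dir(channel_name: str, env: Optional[Mapping[str, str]] = None) -> str:
    """Directory where SageMaker places the files of an input channel inside the container."""
    env = os.environ if env is None else env
    # Prefer SM_CHANNEL_<NAME> when the container sets it; otherwise use the standard layout
    default = str(SM_INPUT_DATA_DIR / channel_name)
    return env.get(f"SM_CHANNEL_{channel_name.upper()}", default)


def channel_file(channel_name: str, filename: str, env: Optional[Mapping[str, str]] = None) -> str:
    """Full path of a single file delivered through a channel."""
    return str(PurePosixPath(channel_dir(channel_name, env)) / filename)
```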

In [None]:
# Start the training job
model_trainer.train(input_data_config=data, wait=False)
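`wait=False` returns immediately, so the job runs in the background. If you want to block until it finishes, one option is to poll `DescribeTrainingJob` until `TrainingJobStatus` reaches a terminal state (`Completed`, `Failed`, or `Stopped`); boto3 also provides a `training_job_completed_or_stopped` waiter for the same purpose. In the sketch below, `wait_for_training_job` is a hypothetical helper, and the describe function is injected so it can be exercised without AWS credentials. In the notebook you would pass `boto3.client("sagemaker").describe_training_job` and the name of your training job.

```python
import time
from typing import Any, Callable, Dict

TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_training_job(
    job_name: str,
    describe: Callable[..., Dict[str, Any]],
    poll_seconds: float = 60,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll a DescribeTrainingJob-style callable until the job reaches a terminal status."""
    while True:
        status = describe(TrainingJobName=job_name)["TrainingJobStatus"]
        if status in TERMINAL_STATUSES:
            return status
        sleep(poll_seconds)
```

Only a `Completed` job produces the model artifacts that the deployment section below looks up.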

---

# Model Deployment

In the following sections, we deploy the fine-tuned model to an Amazon SageMaker real-time endpoint.

## Load Fine-Tuned model

In [None]:
import boto3
import sagemaker

sagemaker_session = sagemaker.Session()

In [None]:
model_id = "google/gemma-3-4b-it"

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix
job_prefix = f"train-{model_id.split('/')[-1].replace('.', '-')}-dsz3"

In [None]:
def get_last_job_name(job_name_prefix):
    sagemaker_client = boto3.client('sagemaker')
    matching_jobs = []
    next_token = None

    while True:
        search_params = {
            'Resource': 'TrainingJob',
            'SearchExpression': {
                'Filters': [
                    {'Name': 'TrainingJobName', 'Operator': 'Contains', 'Value': job_name_prefix},
                    {'Name': 'TrainingJobStatus', 'Operator': 'Equals', 'Value': "Completed"}
                ]
            },
            'SortBy': 'CreationTime',
            'SortOrder': 'Descending',
            'MaxResults': 100
        }

        if next_token:
            search_params['NextToken'] = next_token

        search_response = sagemaker_client.search(**search_params)

        matching_jobs.extend([
            job['TrainingJob']['TrainingJobName'] 
            for job in search_response['Results']
            if job['TrainingJob']['TrainingJobName'].startswith(job_name_prefix)
        ])

        next_token = search_response.get('NextToken')
        if not next_token or matching_jobs:
            break

    if not matching_jobs:
        raise ValueError(f"No completed training jobs found starting with prefix '{job_name_prefix}'")

    return matching_jobs[0]

In [None]:
job_name = get_last_job_name(job_prefix)
job_name
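Because `compression_type="NONE"` was set in `OutputDataConfig`, the merged model is written uncompressed under `<s3_output_path>/<training job name>/output/model/`, and `s3_output_path` itself was built from the bucket, the optional default prefix, and the job base name. The next cells rely on that layout; a hypothetical helper that spells it out:

```python
from typing import Optional


def model_artifact_uri(bucket: str, job_base_name: str, training_job_name: str, prefix: Optional[str] = None) -> str:
    """S3 prefix of the uncompressed model artifacts written by a SageMaker training job."""
    # s3_output_path = s3://<bucket>[/<prefix>]/<job_base_name>, as configured for ModelTrainer above
    parts = [bucket] + ([prefix] if prefix else []) + [job_base_name]
    # SageMaker appends /<training job name>/output/model/ to s3_output_path
    parts += [training_job_name, "output", "model"]
    return "s3://" + "/".join(parts) + "/"
```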

### Inference configurations

In [None]:
import sagemaker
from sagemaker import get_execution_role
from sagemaker import Model

In [None]:
instance_count = 1
instance_type = "ml.g5.2xlarge"  # Single A10G GPU is sufficient for Gemma-3 4B
health_check_timeout = 700

In [None]:
image_uri = sagemaker.image_uris.retrieve(
    framework="djl-lmi",
    region=sagemaker_session.boto_session.region_name,
    version="latest"
)

# Keep the regional ECR registry host from the retrieved URI and pin a specific LMI image version
image_uri = image_uri.split("/")[0] + "/djl-inference:0.36.0-lmi18.0.0-cu128"

image_uri
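The cell above keeps only the registry host of the URI returned by `image_uris.retrieve` and appends a pinned `djl-inference` repository and tag, so the image is pulled from the same regional Amazon ECR registry. A sketch of that step with a light sanity check on the registry host, assuming the usual `<account>.dkr.ecr.<region>.amazonaws.com` format (other partitions, for example the China Regions, use a different domain suffix and would need a different pattern); `pin_image_tag` is hypothetical and the URIs in the example are placeholders.

```python
import re

# Assumed format of a commercial-partition ECR registry host
ECR_HOST_RE = re.compile(r"^(?P<account>\d{12})\.dkr\.ecr\.(?P<region>[a-z0-9-]+)\.amazonaws\.com$")


def pin_image_tag(image_uri: str, repository_and_tag: str) -> str:
    """Replace repository:tag in an ECR image URI while keeping its registry host."""
    registry = image_uri.split("/", 1)[0]
    if not ECR_HOST_RE.match(registry):
        raise ValueError(f"Unexpected registry host: {registry}")
    return f"{registry}/{repository_and_tag}"
```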

In [None]:
if default_prefix:
    model_data_path = f"s3://{bucket_name}/{default_prefix}/{job_prefix}/{job_name}/output/model/"
else:
    model_data_path = f"s3://{bucket_name}/{job_prefix}/{job_name}/output/model/"

model_data = {
    "S3DataSource": {
        "S3Uri": model_data_path,
        "S3DataType": "S3Prefix",
        "CompressionType": "None",
    }
}

model = Model(
    image_uri=image_uri,
    model_data=model_data,
    role=get_execution_role(),
    env={
        "HF_MODEL_ID": "/opt/ml/model",
        "SERVING_FAIL_FAST": "true",
        "OPTION_ASYNC_MODE": "true",
        "OPTION_ROLLING_BATCH": "disable",
        "OPTION_TENSOR_PARALLEL_DEGREE": "max",
        "OPTION_ENTRYPOINT": "djl_python.lmi_vllm.vllm_async_service",
        "OPTION_TRUST_REMOTE_CODE": "true",
        "OPTION_MODEL_LOADING_TIMEOUT": "3600"
    },

)

In [None]:
endpoint_name = f"{model_id.split('/')[-1].replace('.', '-')}-djl"

In [None]:
predictor = model.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=instance_count,
    instance_type=instance_type,
    container_startup_health_check_timeout=health_check_timeout,
    model_data_download_timeout=3600
)

### Predict

In [None]:
import sagemaker

sagemaker_session = sagemaker.Session()

In [None]:
model_id = "google/gemma-3-4b-it"

endpoint_name = f"{model_id.split('/')[-1].replace('.', '-')}-djl"

In [None]:
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker_session,
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)

In [None]:
import pandas as pd
import textwrap

eval_dataset = []

index = 1
for sample in test_dataset:
    print("Processing item ", index)

    messages = []
    message_index = 0
    for el in sample["messages"]:
        if message_index == len(sample["messages"]) - 1:
            break

        if el["role"] == "system":
            system_prompt = """
            You are an AI assistant that thinks in {language} but responds in English.

            IMPORTANT: Follow this exact format for every response:
            1. First, write your reasoning and thoughts inside <think>...</think> tags
            2. Then, provide your final answer in English

            Always think through the problem in {language}, then translate your conclusion to English for the final response.
            """
            system_prompt = system_prompt.format(language=sample["reasoning_language"])
            system_prompt = textwrap.dedent(system_prompt).strip()
            messages.append({"role": "system", "content": system_prompt})
        elif el["role"] == "user":
            messages.append({"role": "user", "content": el["content"]})
        else:
            if el["thinking"] is not None and el["thinking"] != "" and el["thinking"] != "null":
                messages.append({
                    "role": "assistant",
                    "content": f"<think>\n{el['thinking']}\n</think>\n{el['content']}",
                })
            else:
                messages.append({"role": "assistant", "content": el["content"]})

        message_index += 1

    response = predictor.predict({
        "messages": messages,
        "max_tokens": 4096,
        "temperature": 0.1,
        "top_p": 0.9,
        "repetition_penalty": 1.15,
        "do_sample": True,
    })

    eval_dataset.append([
        [el["content"] for el in messages if el["role"] == "system"][0],
        [el["content"] for el in messages if el["role"] == "user"],
        response["choices"][0]["message"]["content"],
    ])

    index += 1
    print("**********************************************")

eval_dataset_df = pd.DataFrame(eval_dataset, columns=["system", "question", "answer"])
eval_dataset_df.to_json("./eval_dataset_results.jsonl", orient="records", lines=True)
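The model was fine-tuned to emit its reasoning inside `<think>...</think>` before the final English answer, so for evaluation it can be useful to separate the two parts of each generated `answer`. A minimal regex-based sketch, assuming at most one `<think>` block per response and that any text before it can be discarded; `split_thinking` is a hypothetical helper.

```python
import re
from typing import Optional, Tuple

# Non-greedy match for the reasoning, then everything after </think> as the answer
THINK_RE = re.compile(r"<think>\s*(.*?)\s*</think>\s*(.*)", re.DOTALL)


def split_thinking(text: str) -> Tuple[Optional[str], str]:
    """Return (reasoning, final_answer); reasoning is None when no <think> block is present."""
    match = THINK_RE.search(text)
    if match is None:
        return None, text.strip()
    return match.group(1), match.group(2).strip()
```

For example, applying it to each value of `eval_dataset_df["answer"]` gives separate reasoning and final-answer columns, which makes it easier to check whether the model actually reasoned in the requested language and answered in English.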

### Delete Endpoint

In [None]:
import sagemaker

sagemaker_session = sagemaker.Session()

In [None]:
model_id = "google/gemma-3-4b-it"

endpoint_name = f"{model_id.split('/')[-1].replace('.', '-')}-djl"

In [None]:
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker_session,
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)

In [None]:
predictor.delete_model()
predictor.delete_endpoint(delete_endpoint_config=True)