# 🚀 Customize and Deploy `openai/gpt-oss-20b` on Amazon SageMaker AI
---
In this notebook, we explore how to use **gpt-oss-20b**, the smaller open-weight reasoning model in OpenAI's gpt-oss series. You'll learn how to fine-tune it on your dataset, evaluate its performance, and deploy it at scale with SageMaker.

**What is gpt-oss-20b?**

OpenAI released **gpt-oss-20b** on **August 5, 2025**, as part of their first open-weight models since GPT-2. It is a **~21-billion-parameter mixture-of-experts (MoE) model** with ~3.6-4 billion active parameters per token. It is released under the **Apache 2.0** license and designed for low-latency inference, efficient deployment on consumer hardware, strong reasoning, tool use, and instruction following.  
🔗 Model card: [openai/gpt-oss-20b on Hugging Face](https://huggingface.co/openai/gpt-oss-20b)

---

**Key Specifications**

| Feature | Details |
|---|---|
| **Parameters** | ~21 billion total; ~3.6-4 billion active per forward pass |
| **Architecture** | Mixture-of-Experts (MoE) transformer; alternating dense and locally banded sparse attention; grouped multi-query attention |
| **Input / Output** | Text-in / Text-out (harmony response format) |
| **Context Length** | Up to **128,000 tokens** |
| **Customizability** | Configurable reasoning effort (low / medium / high); fine-tunable |
| **License** | Apache 2.0 |

---

**Benchmarks & Behavior**

- Delivers performance similar to OpenAI's **o3-mini** on common benchmarks.
- Optimized for deployment on devices with **≈16 GB of memory**, making it well suited for edge or local inference.
- Supports instruction following, tool use, function calling, and structured output formats.
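
The ≈16 GB figure can be sanity-checked with back-of-the-envelope arithmetic. The sketch below estimates weight memory only, assuming every parameter is stored at one uniform bit width. The bit widths are illustrative assumptions rather than the model's actual storage layout, and activations and the KV cache are ignored.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate memory for model weights alone, in decimal gigabytes."""
    return num_params * bits_per_param / 8 / 1e9


TOTAL_PARAMS = 21e9  # ~21 billion total parameters, as listed in the table above

for bits in (16, 8, 4):  # hypothetical uniform precisions
    print(f"{bits:>2}-bit weights: ~{weight_memory_gb(TOTAL_PARAMS, bits):.1f} GB")
```

Under these assumptions, only a roughly 4-bit representation (about 10.5 GB) leaves headroom inside a 16 GB budget; 8-bit and 16-bit weights alone would already exceed it.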

---

In [None]:
%pip install -Uq "datasets==4.3.0" \
    "sagemaker==2.253.1"

In [None]:
import boto3
import sagemaker

In [None]:
region = boto3.Session().region_name

sess = sagemaker.Session(boto3.Session(region_name=region))

sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

role = sagemaker.get_execution_role()

In [None]:
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

## Data Preparation for Supervised Fine-tuning

#### [Hermes Reasoning Tool Use](https://huggingface.co/datasets/interstellarninja/hermes_reasoning_tool_use)

**Hermes Reasoning Tool Use** is a specialized dataset compiled by interstellarninja, containing approximately **50,000 entries** focused on reasoning and tool-use chains in AI assistants. The dataset is designed to support advanced function-calling, multi-step reasoning, and interactive workflows under instruct-style settings.

**Data Format & Structure**:

* Distributed in **JSON** format.
* Contains a single `train` split with ~50k records.
* Each record, as read by the conversion code later in this notebook, includes a `conversations` list whose entries store their text under `value`:

  * entry 0 – the system prompt
  * entry 1 – the user instruction, a natural-language prompt requiring tool usage or reasoning
  * entry 2 – the assistant response, with its reasoning trace inside `<think>...</think>` and each tool call inside `<tool_call>...</tool_call>`

**License**: Released under an **open-access research license** (please check the dataset card for the latest licensing terms on Hugging Face).

**Applications**:

The dataset can support a variety of advanced simulation and training tasks, including:

* Fine-tuning assistants to perform multi-step reasoning with external tool interfaces
* Agent-based workflows where models must call tools, interpret their outputs, and incorporate them into answers
* Benchmarking tool-use reasoning capabilities of LLMs
* Training models for interactive decision-making, function-calling, or knowledge-retrieval pipelines



In [None]:
import os
import re
import json
import pprint
from tqdm import tqdm
from datasets import load_dataset

In [None]:
dataset_parent_path = os.path.join(os.getcwd(), "tmp_cache_local_dataset")
os.makedirs(dataset_parent_path, exist_ok=True)

**Preparing Your Dataset in `messages` format**

This section walks you through creating a conversation-style dataset in the `messages` format, which is required for training LLMs directly with SageMaker AI.

**What Is the `messages` Format?**

The `messages` format represents each training example as a chat-like exchange: a JSON array in which every conversation turn is an object labeled with its role. It's widely used by frameworks like TRL.

Example entry:

```json
{
  "messages": [
    { "role": "system", "content": "You are a helpful assistant." },
    { "role": "user", "content": "How do I bake sourdough?" },
    { "role": "assistant", "content": "First, you need to create a starter by..." }
  ]
}
```


In [None]:
dataset_name = "interstellarninja/hermes_reasoning_tool_use"
dataset = load_dataset(dataset_name, split="train[:1000]")

In [None]:
pprint.pp(dataset[0])

In [None]:
print(f"total number of fine-tunable samples: {len(dataset)}")

In [None]:
# Precompile regex patterns for efficiency and clarity
THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)


def convert_to_messages_reasoning(row):
    """
    Convert a dataset row with 'conversations' into a messages dict.

      - system/user content is copied verbatim from indices 0 and 1.
      - assistant 'thinking' is extracted from <think>...</think> (empty if absent).
      - assistant 'content' is every <tool_call> block, re-wrapped and joined by newlines.
    """
    conversations = row["conversations"]

    system_content = conversations[0]["value"]
    user_content = conversations[1]["value"]
    assistant_text = conversations[2]["value"]

    # Extract reasoning/thinking content
    think_match = THINK_RE.search(assistant_text)
    reasoning_content = think_match.group(1).strip() if think_match else ""

    # Extract and re-wrap all tool_call blocks, then join by newline
    tool_payloads = TOOL_CALL_RE.findall(assistant_text)
    assistant_content = "\n".join(f"<tool_call>{t}</tool_call>" for t in tool_payloads)

    return {
        "messages": [
            {"role": "system", "content": system_content, "thinking": None},
            {"role": "user", "content": user_content, "thinking": None},
            {"role": "assistant", "content": assistant_content, "thinking": reasoning_content},
        ]
    }


# Convert every row and drop the original columns, leaving only 'messages'
dataset = dataset.map(convert_to_messages_reasoning, remove_columns=dataset.column_names)
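
To see what the conversion does to a single assistant turn, the following sketch applies the same two regular expressions to a made-up example. The sample text and the `split_assistant_turn` helper are hypothetical and exist only for illustration.

```python
import re

THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)


def split_assistant_turn(text):
    """Return (thinking, content) using the same rules as the conversion above."""
    match = THINK_RE.search(text)
    thinking = match.group(1).strip() if match else ""
    calls = TOOL_CALL_RE.findall(text)
    content = "\n".join(f"<tool_call>{c}</tool_call>" for c in calls)
    return thinking, content


# Made-up assistant turn; not taken from the dataset
sample = (
    "<think>\nThe user wants the weather, so I should call the tool.\n</think>\n"
    'Let me check.\n<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
)

thinking, content = split_assistant_turn(sample)
print(thinking)
print(content)
```

Note that any assistant text outside the `<think>` and `<tool_call>` tags (here, "Let me check.") is discarded, so a turn with no tool call yields an empty `content` string.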

In [None]:
dataset_filename = os.path.join(dataset_parent_path, f"{dataset_name.replace('/', '--').replace('.', '-')}.jsonl")
dataset.to_json(dataset_filename, lines=True)
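
Before uploading, it can be worth confirming that each line of a JSONL file parses and has the expected `messages` shape. The sketch below is one way to do that. `validate_messages_file` is a hypothetical helper, and it runs here against a small temporary file rather than the real dataset.

```python
import json
import os
import tempfile

ALLOWED_ROLES = {"system", "user", "assistant"}


def validate_messages_file(path):
    """Return the number of valid records; raise ValueError on the first bad line."""
    count = 0
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)  # JSONDecodeError is a ValueError subclass
            messages = record.get("messages") if isinstance(record, dict) else None
            if not isinstance(messages, list) or not messages:
                raise ValueError(f"line {line_no}: missing or empty 'messages'")
            for msg in messages:
                if not isinstance(msg, dict) or msg.get("role") not in ALLOWED_ROLES:
                    raise ValueError(f"line {line_no}: unexpected message or role")
                if not isinstance(msg.get("content"), str):
                    raise ValueError(f"line {line_no}: 'content' must be a string")
            count += 1
    return count


# Demo on a throwaway file containing one made-up record
demo = {
    "messages": [
        {"role": "user", "content": "hi", "thinking": None},
        {"role": "assistant", "content": "hello", "thinking": ""},
    ]
}
with tempfile.NamedTemporaryFile(
    "w", suffix=".jsonl", delete=False, encoding="utf-8"
) as tmp:
    tmp.write(json.dumps(demo) + "\n")

num_valid = validate_messages_file(tmp.name)
print(num_valid)
os.remove(tmp.name)
```

In this notebook the equivalent call would be `validate_messages_file(dataset_filename)`.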

#### Upload file to S3

In [None]:
from sagemaker.s3 import S3Uploader

In [None]:
data_s3_uri = f"s3://{sess.default_bucket()}/dataset"

uploaded_s3_uri = S3Uploader.upload(
    local_path=dataset_filename,
    desired_s3_uri=data_s3_uri
)
print(f"Uploaded {dataset_filename} to > {uploaded_s3_uri}")

## Fine-Tune LLMs using SageMaker `Estimator`/`ModelTrainer`

In [None]:
import time
from sagemaker.modules.configs import (
    CheckpointConfig,
    Compute,
    OutputDataConfig,
    SourceCode,
    StoppingCondition,
)
from sagemaker.modules.configs import InputData
from sagemaker.modules.train import ModelTrainer
from getpass import getpass
import yaml
from jinja2 import Template

In [None]:
MODEL_ID = "openai/gpt-oss-20b"

In [None]:
hf_token = getpass()

### Training using `PyTorch` `ModelTrainer`
---
**Observability**: SageMaker AI offers [SageMaker MLflow](https://docs.aws.amazon.com/sagemaker/latest/dg/mlflow.html), which lets you track experiments and monitor the performance of models and AI applications with a single tool.

You can include MLflow in your training workflow to track fine-tuning metrics in real time by specifying an **MLflow** tracking server ARN.

Optionally, you can also report to **tensorboard** or **wandb**.

In [None]:
MLFLOW_TRACKING_SERVER_ARN = "arn:aws:sagemaker:us-east-1:XXXXXYYYYYZZ:mlflow-tracking-server/<name>" # or None

if MLFLOW_TRACKING_SERVER_ARN:
    reports_to = "mlflow"
else:
    reports_to = "tensorboard"

In [None]:
job_name = MODEL_ID.replace('/', '--').replace('.', '-')
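
The replacements above matter because SageMaker training job names may contain only alphanumeric characters and hyphens and are limited to 63 characters. A minimal sketch of a check follows, with `is_valid_job_name` as a hypothetical helper.

```python
import re

# Pattern documented for the TrainingJobName parameter of CreateTrainingJob
JOB_NAME_RE = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$")


def is_valid_job_name(name: str) -> bool:
    """True if `name` is at most 63 chars of letters, digits, and inner hyphens."""
    return len(name) <= 63 and bool(JOB_NAME_RE.match(name))


model_id = "openai/gpt-oss-20b"
sanitized = model_id.replace("/", "--").replace(".", "-")

print(is_valid_job_name(model_id))   # False: "/" is not allowed
print(is_valid_job_name(sanitized))  # True
```

Since a suffix is typically appended to the base job name to make each job name unique, it helps to keep the base well under the length limit.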

In [None]:
# Settings shared by every run: Hugging Face token plus EFA/NCCL networking
training_env = {
    "HF_TOKEN": hf_token,
    "FI_EFA_USE_DEVICE_RDMA": "1",
    "NCCL_DEBUG": "INFO",
    "NCCL_SOCKET_IFNAME": "eth0",
    "FI_PROVIDER": "efa",
    "NCCL_PROTO": "simple",
    "NCCL_NET_GDR_LEVEL": "5"
}

if MLFLOW_TRACKING_SERVER_ARN:
    # MLflow tracking settings, added only when a tracking server is configured
    training_env.update(
        {
            "MLFLOW_EXPERIMENT_NAME": f"{job_name}-exp",
            "MLFLOW_TAGS": json.dumps(
                {
                    "source.job": "sm-training-jobs",
                    "source.type": "sft",
                    "source.framework": "pytorch"
                }
            ),
            "MLFLOW_TRACKING_URI": MLFLOW_TRACKING_SERVER_ARN,
            "MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING": "true",
        }
    )
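
Because `training_env` carries the Hugging Face token, it is safer to mask secrets before printing the dictionary for debugging. A minimal sketch follows, where `redact_env` and the list of sensitive key fragments are assumptions made for illustration.

```python
SENSITIVE_FRAGMENTS = ("TOKEN", "SECRET", "PASSWORD", "KEY")


def redact_env(env):
    """Return a copy of `env` with the values of sensitive-looking keys masked."""
    return {
        key: "****" if any(frag in key.upper() for frag in SENSITIVE_FRAGMENTS) else value
        for key, value in env.items()
    }


example_env = {"HF_TOKEN": "hf_example", "NCCL_DEBUG": "INFO"}  # made-up values
print(redact_env(example_env))
```

In the notebook, `print(redact_env(training_env))` would show the configuration without exposing the token.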

In [None]:
pytorch_image_uri = f"763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-training:2.8.0-gpu-py312-cu129-ubuntu22.04-sagemaker"
print(f"Using image: {pytorch_image_uri}")

#### Training strategy - Choose between: `PeFT`/`Spectrum`/`Full-Finetuning`

Each strategy below is paired with its recipe file and with the instance type and count used to run it.

In [None]:
%%writefile sagemaker_code/requirements.txt
transformers==4.56.1
peft==0.17.0
accelerate==1.11.0
bitsandbytes==0.46.1
datasets==4.0.0
deepspeed==0.17.5
hf-transfer==0.1.8
hf_xet
liger-kernel==0.6.1
lm-eval[api]==0.4.9
kernels>=0.9.0
mlflow
Pillow
safetensors>=0.6.2
sagemaker==2.251.1
sagemaker-mlflow==0.1.0
sentencepiece==0.2.0
tokenizers>=0.21.4
triton
trl==0.21.0
tensorboard
py7zr
git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels
vllm==0.10.1
poetry
yq
psutil
nvidia-ml-py
pyrsmi

In [None]:
# For PeFT
args = [
    "--config",
    "hf_recipes/openai/gpt-oss-20b--vanilla-peft-qlora.yaml",
    # "--run-eval" # enable this for small models to run eval + tune
]
training_instance_type = "ml.p4de.24xlarge"
training_instance_count = 1

## For Spectrum
# args = [
#     "--config",
#     "hf_recipes/openai/gpt-oss-20b--vanilla-spectrum.yaml",
#     # "--run-eval" # enable this for small models if you're looking to bundle eval with fine-tuning
# ]
# training_instance_type = "ml.p4de.24xlarge"
# training_instance_count = 1

## For Full-Finetuning
# args = [
#     "--config",
#     "hf_recipes/openai/gpt-oss-20b--vanilla-full.yaml",
#     # "--run-eval" # enable this for small models if you're looking to bundle eval with fine-tuning
# ]
# training_instance_type = "ml.p4de.24xlarge"
# training_instance_count = 1
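
An alternative to commenting blocks in and out is to keep the three options in a dictionary and select one by name. This is only a sketch: the `STRATEGIES` table and the `select_strategy` helper are hypothetical, while the recipe paths and instance types are the ones listed above.

```python
STRATEGIES = {
    "peft": {
        "recipe": "hf_recipes/openai/gpt-oss-20b--vanilla-peft-qlora.yaml",
        "instance_type": "ml.p4de.24xlarge",
        "instance_count": 1,
    },
    "spectrum": {
        "recipe": "hf_recipes/openai/gpt-oss-20b--vanilla-spectrum.yaml",
        "instance_type": "ml.p4de.24xlarge",
        "instance_count": 1,
    },
    "full": {
        "recipe": "hf_recipes/openai/gpt-oss-20b--vanilla-full.yaml",
        "instance_type": "ml.p4de.24xlarge",
        "instance_count": 1,
    },
}


def select_strategy(name, run_eval=False):
    """Return (args, instance_type, instance_count) for the named strategy."""
    if name not in STRATEGIES:
        raise ValueError(f"unknown strategy {name!r}; choose from {sorted(STRATEGIES)}")
    cfg = STRATEGIES[name]
    args = ["--config", cfg["recipe"]]
    if run_eval:
        args.append("--run-eval")  # optional flag shown commented out above
    return args, cfg["instance_type"], cfg["instance_count"]


args, training_instance_type, training_instance_count = select_strategy("peft")
print(args, training_instance_type, training_instance_count)
```

Switching strategies then becomes a one-word change, and a mistyped name fails early instead of silently reusing the previous values.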


In [None]:
pytorch_image_uri = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=sess.boto_session.region_name,
    version="2.8.0",
    instance_type=training_instance_type,
    image_scope="training",
)
print(f"Using image: {pytorch_image_uri}")

In [None]:
source_code = SourceCode(
    source_dir="./sagemaker_code",
    command=f"bash sm_accelerate_train.sh {' '.join(args)}",
)

compute_configs = Compute(
    instance_type=training_instance_type,
    instance_count=training_instance_count,
    keep_alive_period_in_seconds=1800,
    volume_size_in_gb=300
)

base_job_name = f"{job_name}-finetune"
output_path = f"s3://{sess.default_bucket()}/{base_job_name}"

model_trainer = ModelTrainer(
    training_image=pytorch_image_uri,
    source_code=source_code,
    base_job_name=base_job_name,
    compute=compute_configs,
    stopping_condition=StoppingCondition(max_runtime_in_seconds=18000),
    output_data_config=OutputDataConfig(
        s3_output_path=output_path,
    ),
    checkpoint_config=CheckpointConfig(
        # Join with "/" explicitly: os.path.join would insert "\" on Windows
        s3_uri="/".join(
            [
                output_path,
                dataset_name.replace('/', '--').replace('.', '-'),
                job_name,
                "checkpoints",
            ]
        ),
        local_path="/opt/ml/checkpoints"
    ),
    role=role,
    environment=training_env
)

In [None]:
model_trainer.train(
    input_data_config=[
        InputData(
            channel_name="training",
            data_source=uploaded_s3_uri,  
        )
    ], 
    wait=False
)
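
Because `wait=False` returns immediately, the job keeps running in the background. One way to follow it is to poll the `DescribeTrainingJob` API until a terminal status is reached. The sketch below takes the describe call as a parameter so that it can be exercised without AWS access. `wait_for_training_job`, the fake describe function, and the job name are hypothetical.

```python
import time

TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_training_job(describe, job_name, poll_seconds=60, max_polls=1000):
    """Poll `describe(TrainingJobName=...)` until the job reaches a terminal status."""
    for _ in range(max_polls):
        status = describe(TrainingJobName=job_name)["TrainingJobStatus"]
        if status in TERMINAL_STATUSES:
            return status
        time.sleep(poll_seconds)
    raise TimeoutError(f"{job_name} did not finish after {max_polls} polls")


# Stand-in for boto3.client("sagemaker").describe_training_job
_fake_statuses = iter(["InProgress", "InProgress", "Completed"])


def fake_describe(TrainingJobName):
    return {"TrainingJobStatus": next(_fake_statuses)}


final_status = wait_for_training_job(fake_describe, "example-job", poll_seconds=0)
print(final_status)
```

With real credentials, `describe` would be `boto3.client("sagemaker").describe_training_job` and `job_name` the name of the submitted training job.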