# Quantization-Aware Fine-Tuning for GPT-OSS

This notebook walks through a complete workflow for fine-tuning gpt-oss models with Quantization-Aware Training (QAT), using modelopt for quantization and TRL's SFTTrainer for training.

## Overview

The workflow includes:

• Model and tokenizer loading

• Dataset preparation

• Training configuration setup

• Model quantization

• Quantization-aware training

• Model saving and checkpointing

**Setup Environment**

In [None]:
%pip install --upgrade transformers trl

In [None]:
import modelopt.torch.opt as mto

# Enable automatic save/load of the modelopt state with Hugging Face checkpointing.
# The modelopt state is saved automatically to "modelopt_state.pth".
mto.enable_huggingface_checkpointing()

**Model Configuration**

Specify the model path, attention implementation, and data type, then collect the keyword arguments used to load the model. If the checkpoint is stored in MXFP4 format, it is dequantized on load.

In [None]:
from transformers import AutoConfig, Mxfp4Config
from trl import ModelConfig

model_args = ModelConfig(
    model_name_or_path="openai/gpt-oss-20b",
    attn_implementation="eager",
    torch_dtype="bfloat16",
)
model_kwargs = {
    "revision": model_args.model_revision,
    "trust_remote_code": model_args.trust_remote_code,
    "attn_implementation": model_args.attn_implementation,
    "torch_dtype": model_args.torch_dtype,
    "use_cache": False,
    "device_map": "auto",
}

# Dequantize if the model is in MXFP4 format
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
if (
    getattr(config, "quantization_config", {})
    and config.quantization_config.get("quant_method", None) == "mxfp4"
):
    model_kwargs["quantization_config"] = Mxfp4Config(dequantize=True)
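The MXFP4 check above can be tried in isolation. The following is a minimal sketch, assuming the config exposes `quantization_config` as a plain dict (as the cell above does). The helper name `needs_mxfp4_dequantize` and the `SimpleNamespace` stand-ins are hypothetical and exist only for illustration.

```python
from types import SimpleNamespace


def needs_mxfp4_dequantize(config):
    """Return True when the config declares MXFP4 quantization (hypothetical helper)."""
    # Fall back to an empty dict when the attribute is missing or None.
    quant_cfg = getattr(config, "quantization_config", None) or {}
    return quant_cfg.get("quant_method") == "mxfp4"


# Stand-ins for real model configs, used only to exercise the check.
mxfp4_config = SimpleNamespace(quantization_config={"quant_method": "mxfp4"})
plain_config = SimpleNamespace()  # no quantization_config attribute at all

print(needs_mxfp4_dequantize(mxfp4_config))
print(needs_mxfp4_dequantize(plain_config))
```

Only a config that reports `quant_method == "mxfp4"` triggers dequantization; any other checkpoint is loaded with the unchanged keyword arguments.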

**Load the Model and Tokenizer**

Load the pre-trained model and tokenizer with the specified configuration.

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)

tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
)

**Dataset Configuration**

Set up the dataset parameters for training and evaluation. This includes specifying the dataset name, train/test splits, and test size ratio.

In [None]:
from trl import ScriptArguments

script_args = ScriptArguments(
    dataset_name="HuggingFaceH4/Multilingual-Thinking",
    dataset_train_split="train",
    dataset_test_split="test",
)
test_size = 0.1

**Load and Prepare Dataset**

Load the dataset and split it into training and evaluation sets. The dataset is split with the specified test size ratio and random seed for reproducibility.

In [None]:
from datasets import load_dataset

dataset = load_dataset(script_args.dataset_name)
# split the dataset into train and test
dataset = dataset[script_args.dataset_train_split].train_test_split(test_size=test_size, seed=42)
train_dataset = dataset[script_args.dataset_train_split]
eval_dataset = dataset[script_args.dataset_test_split]
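To make the effect of `test_size` and `seed` concrete, here is a rough pure-Python sketch of a seeded 90/10 split. It is not the `datasets` implementation: the shuffling algorithm and the rounding rule are assumptions, and the row count of 1000 is a hypothetical value chosen for illustration.

```python
import math
import random


def split_indices(num_rows, test_size, seed):
    """Shuffle row indices with a fixed seed and carve off a test fraction."""
    indices = list(range(num_rows))
    random.Random(seed).shuffle(indices)  # same seed -> same permutation
    num_test = math.ceil(num_rows * test_size)  # rounding rule is an assumption
    return indices[num_test:], indices[:num_test]


# Hypothetical dataset size, for illustration only.
train_idx, test_idx = split_indices(num_rows=1000, test_size=0.1, seed=42)
print(len(train_idx), len(test_idx))
```

Because the seed is fixed, rerunning the split yields the same partition, so the evaluation set stays comparable across runs.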

**Training Configuration**

Configure the training parameters including epochs, batch sizes, learning rate, gradient accumulation, and evaluation strategy. This sets up the SFT configuration for supervised fine-tuning.

In [None]:
from trl import SFTConfig

training_args = SFTConfig(
    output_dir="gpt-oss-20b-multilingual-reasoner",
    num_train_epochs=0.1,
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    max_length=4096,
    warmup_ratio=0.03,
    eval_strategy="steps",
    eval_on_start=True,
    logging_steps=10,
    save_steps=50,
    eval_steps=10,
    save_total_limit=2,
)
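The interaction between batch size, gradient accumulation, fractional epochs, and the warmup ratio is easier to see as arithmetic. The sketch below estimates the number of optimizer updates under stated assumptions: a single device, a hypothetical training set of 900 examples, and rounding conventions that may differ from the actual trainer by a step.

```python
import math


def training_schedule(num_examples, per_device_batch_size, grad_accum_steps,
                      num_epochs, warmup_ratio, num_devices=1):
    """Estimate optimizer updates and warmup steps (rounding is an assumption)."""
    effective_batch = per_device_batch_size * grad_accum_steps * num_devices
    batches_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_devices))
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    total_updates = math.ceil(num_epochs * updates_per_epoch)
    warmup_updates = math.ceil(total_updates * warmup_ratio)
    return {
        "effective_batch": effective_batch,
        "total_updates": total_updates,
        "warmup_updates": warmup_updates,
    }


# 900 training examples is a hypothetical figure; the other values mirror the config above.
schedule = training_schedule(
    num_examples=900,
    per_device_batch_size=1,
    grad_accum_steps=2,
    num_epochs=0.1,
    warmup_ratio=0.03,
)
print(schedule)
```

Comparing `total_updates` against `eval_steps` and `save_steps` is a quick way to check whether a short run will reach an evaluation or checkpoint step at all.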

**Initialize Trainer**

Set up the SFT trainer with the model, dataset, and training configuration.

In [None]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # splits created in the dataset cell above
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
)

**Quantization-Aware Training**

Configure the quantization parameters and prepare the calibration dataset. This step selects the quantization configuration, creates a calibration subset from the evaluation dataset, and defines a forward loop that runs the model over that subset. For configurations that require it, calibration estimates the quantization scales for the model weights and activations.

In [None]:
import torch

import modelopt.torch.quantization as mtq

# MXFP4_MLP_WEIGHT_ONLY_CFG doesn't need calibration, but other quantization configurations may require it.
quantization_config = mtq.MXFP4_MLP_WEIGHT_ONLY_CFG
calib_size = 128

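# Use at most calib_size examples from the evaluation set for calibration.
calib_dataset = torch.utils.data.Subset(
    trainer.eval_dataset, list(range(min(len(trainer.eval_dataset), calib_size)))
)
data_loader = trainer.get_eval_dataloader(calib_dataset)


def forward_loop(model):
    """Run the calibration batches through the model."""
    for data in data_loader:
        model(**data)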

Apply quantization to the model using the prepared configuration and calibration data.

In [None]:
mtq.quantize(model, quantization_config, forward_loop)
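To illustrate why a weight-only MXFP4 configuration can skip calibration, the sketch below fake-quantizes a list of weights in pure Python. It assumes the layout of the OCP Microscaling format: blocks of 32 values that share one power-of-two scale, with each value rounded to the FP4 (E2M1) grid. The scale comes from the weights themselves, so no activation statistics are needed. This illustrates the number format only and is not modelopt's implementation, whose details may differ. All names are hypothetical.

```python
import math

FP4_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # representable magnitudes
BLOCK_SIZE = 32  # values sharing one scale (assumed from the MX format)
E2M1_EMAX = 2    # exponent of the largest FP4 magnitude: 6 = 1.5 * 2**2


def fake_quantize_block(block):
    """Quantize one block to FP4 with a shared power-of-two scale, then dequantize."""
    max_abs = max(abs(v) for v in block)
    if max_abs == 0.0:
        return list(block)
    # The shared scale is a power of two derived from the block's largest magnitude.
    scale = 2.0 ** (math.floor(math.log2(max_abs)) - E2M1_EMAX)
    result = []
    for v in block:
        scaled = min(abs(v) / scale, FP4_E2M1_VALUES[-1])  # saturate at the FP4 maximum
        nearest = min(FP4_E2M1_VALUES, key=lambda q: abs(q - scaled))
        result.append(math.copysign(nearest * scale, v))
    return result


def fake_quantize_mxfp4(weights):
    """Apply block-wise fake quantization to a flat list of weights."""
    out = []
    for start in range(0, len(weights), BLOCK_SIZE):
        out.extend(fake_quantize_block(weights[start:start + BLOCK_SIZE]))
    return out


print(fake_quantize_block([1.0, -0.3, 0.06, 6.0]))
```

Small values in a block dominated by a large one lose most of their precision, which is the kind of error quantization-aware training lets the model adapt to.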

Start the quantization-aware training.

In [None]:
trainer.train()
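Quantization-aware training is commonly implemented with a straight-through estimator: the forward pass uses quantized weights, while gradients update a higher-precision latent copy as if the rounding were the identity. The one-parameter sketch below shows that idea under stated assumptions. The uniform grid, the squared-error loss, and all names are hypothetical, and the source does not confirm this is how modelopt works internally.

```python
import math

GRID_STEP = 0.5  # hypothetical quantization grid spacing


def fake_quantize(weight):
    """Round to the nearest multiple of GRID_STEP (halves round up)."""
    return math.floor(weight / GRID_STEP + 0.5) * GRID_STEP


def qat_step(latent_weight, x, target, lr):
    """One SGD step on the loss (fake_quantize(w) * x - target) ** 2."""
    quantized = fake_quantize(latent_weight)
    error = quantized * x - target
    grad_wrt_quantized = 2.0 * error * x
    # Straight-through estimator: treat d(quantized)/d(latent) as 1,
    # so the gradient is applied directly to the latent weight.
    return latent_weight - lr * grad_wrt_quantized


latent = 0.9  # high-precision weight that the optimizer updates
for _ in range(20):
    latent = qat_step(latent, x=1.0, target=2.0, lr=0.1)

print(fake_quantize(latent))
```

The rounding function is flat almost everywhere, so its true gradient is zero; passing the gradient straight through is what lets the latent weight move until its quantized value matches the target.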

**Model Saving and Checkpointing**

Save the trained, quantized model. Because Hugging Face checkpointing was enabled during setup, the modelopt state is stored automatically alongside the model weights.

In [None]:
model.save_pretrained(training_args.output_dir)
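A quick sanity check after saving is to confirm that the modelopt state file was written. The sketch below assumes the file is named `modelopt_state.pth`, as stated in the setup cell, and that it sits at the top level of the checkpoint directory, which is an assumption. The helper name is hypothetical, and a temporary directory stands in for a real checkpoint.

```python
import os
import tempfile

MODELOPT_STATE_FILE = "modelopt_state.pth"  # file name taken from the setup cell


def has_modelopt_state(checkpoint_dir):
    """Return True if the checkpoint directory contains the modelopt state file."""
    return os.path.isfile(os.path.join(checkpoint_dir, MODELOPT_STATE_FILE))


# Demonstration with a temporary directory standing in for a real checkpoint.
with tempfile.TemporaryDirectory() as tmp_dir:
    before = has_modelopt_state(tmp_dir)
    with open(os.path.join(tmp_dir, MODELOPT_STATE_FILE), "wb") as handle:
        handle.write(b"")  # empty placeholder, not a real modelopt state
    after = has_modelopt_state(tmp_dir)

print(before, after)
```

In the notebook, the equivalent check would be `has_modelopt_state(training_args.output_dir)` after the save call.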