# Saving and Exporting

This tutorial shows how to save and export trained models.

In [1]:
!pip install --upgrade pip
!pip install --upgrade datasets[audio] transformers accelerate evaluate jiwer tensorboard gradio
!pip install optimum


### Exporting and Saving ML models
Saving and exporting a trained model is a crucial step in model development: it preserves the model's state after training and lets you deploy it in different environments. This tutorial covers the common formats and the key points to watch for when saving.

### Exporting vs Saving

**Saving**
- Preserves the model's architecture, trained weights, and often the associated configuration (hyperparameters or vocabulary), so you don't need to retrain each time
- Intended for future use within the same framework, or a closely related environment, as the one you trained in

**Exporting**
- Converts the model into a representation suitable for deployment in production environments or for use across different frameworks
- Often involves optimizations or format changes for better inference speed and compatibility

### Common Formats for saving models
Framework-Specific Formats
- PyTorch (.pth or .pt): Saves either the entire model or just the state dictionary, which holds only the learned parameters (weights and biases) and persistent buffers
- TensorFlow/Keras (.h5 or SavedModel): TensorFlow offers multiple ways to save models: as an HDF5 file containing the architecture, weights and training configuration, or as a SavedModel directory, which is a more comprehensive save format

Framework-Agnostic Formats
- ONNX (Open Neural Network Exchange): A cross-platform format supported by many deep learning frameworks, allowing for model exchange between different tools
- Safetensors: A fast and safe way to store tensors across multiple frameworks. Safetensors stores only the weights (the state dictionary), whereas ONNX stores both the computation graph and the weights.
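To make "stores only the weights" concrete, the sketch below hand-builds and parses a tiny file in the layout described by the safetensors documentation: an 8-byte little-endian header length, a JSON header describing each tensor, then the raw tensor bytes. It uses only the standard library, and the tensor name `w` and its values are made up for illustration. It is a simplified picture of the layout, not a replacement for the `safetensors` library.

```python
import json
import struct


def build_safetensors_bytes(name, values):
    """Pack a 1-D list of floats as one F32 tensor in the safetensors layout."""
    data = struct.pack(f"<{len(values)}f", *values)
    header = {
        name: {"dtype": "F32", "shape": [len(values)], "data_offsets": [0, len(data)]}
    }
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    # 8-byte little-endian header length, then the JSON header, then the raw data
    return struct.pack("<Q", len(header_bytes)) + header_bytes + data


def read_safetensors_header(blob):
    """Return the parsed JSON header and the raw data section."""
    (header_len,) = struct.unpack("<Q", blob[:8])
    header = json.loads(blob[8:8 + header_len])
    return header, blob[8 + header_len:]


blob = build_safetensors_bytes("w", [1.0, 2.0, 3.0])
header, payload = read_safetensors_header(blob)
start, end = header["w"]["data_offsets"]
values = list(struct.unpack("<3f", payload[start:end]))

print(header)
print(values)
```

Nothing in the file is executable: the header is plain JSON and the rest is numbers, which is why loading it cannot run code the way unpickling can.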

### Saving Models in PyTorch

In [2]:
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Load Whisper, an automatic speech recognition (ASR) model
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")


In [3]:
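# Save the entire model (pickles the whole object, so loading it later
# requires the same class definitions to be importable)
torch.save(model, "whisper_model.pth")

# Save only the state dictionary (the learned parameters and buffers)
torch.save(model.state_dict(), "whisper_model_state_dict.pth")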


### Saving in SafeTensors
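We use `save_model` rather than `save_file` here because safetensors refuses to serialize tensors that share memory, and PyTorch models often share tensors between layers (Whisper ties its output projection to the decoder's token embeddings). `save_model` detects the shared tensors and stores each one only once.

Read more about it [here](https://huggingface.co/docs/safetensors/en/torch_shared_tensors).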
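To see what tensor sharing looks like in practice, here is a minimal sketch with a hypothetical `TinyTiedModel` (not part of Whisper or any library) that ties its output projection to its embedding, in the same spirit as Whisper's tied weights. The state dictionary lists two names, but both point at the same underlying memory.

```python
import torch
from torch import nn


class TinyTiedModel(nn.Module):
    """Hypothetical toy model whose output layer reuses the embedding matrix."""

    def __init__(self, vocab_size=10, dim=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.proj_out = nn.Linear(dim, vocab_size, bias=False)
        # Tie the weights: both layers now hold the very same Parameter
        self.proj_out.weight = self.embed.weight


tiny = TinyTiedModel()
state = tiny.state_dict()

# Two entries in the state dict, one block of memory behind them
names = sorted(state)
shared = state["embed.weight"].data_ptr() == state["proj_out.weight"].data_ptr()

print(names)
print(shared)
```

A format that writes every named tensor independently would store this matrix twice and lose the tie on reload, which is the situation `save_model` is designed to handle.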

In [4]:
# Save as safetensors
from safetensors.torch import save_model

save_model(model, "whisper.safetensors")


### Exporting Models to ONNX
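Exporting a model to ONNX with `torch.onnx.export` requires a sample input, which is used to trace the computation graph. The model should also be in evaluation mode so that layers such as dropout behave as they do at inference time.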

In [5]:
! pip install onnx onnxruntime


In [6]:
# Download a sample dataset to get inputs for ONNX
from datasets import load_dataset

ds = load_dataset("laion/LAION-Audio-300M", split = 'train', streaming = True)

sample = next(iter(ds))
print(sample)

{'audio.mp3': {'path': 'sample_0.audio.mp3', 'array': array([ 0.00401239,  0.00369918,  0.00052929, ...,  0.01347164,
       -0.01134891, -0.02810542]), 'sampling_rate': 16000}, 'metadata.json': {'caption': 'A melancholic piano melody plays, characterized by a slow tempo and a minor key. The recording quality suggests a home studio setup, with a slightly warm and intimate sound.  The piece evokes feelings of wistful longing.', 'channel_follower_count': 63200000, 'duration_ms': 4364, 'id': 'Bf_5ya97wHw', 'like_count': 35, 'segment_filename': 'general_segment_2134770_0.mp3', 'segment_index': 0, 'start_time_ms': 10142, 'title': 'इ कईसन ऐहसास हो रहल बा || E Kaisan Aehsas Ho Rahal Ba || Brijwa || Bhojpuri Hit  Songs 2021 new', 'transcription': ' ', 'type': 'general-purpose', 'uploader': 'Wave Music', 'uploader_id': '@WaveMusicIndia', 'view_count': 5267}, '__key__': 'sample_0', '__url__': 'hf://datasets/laion/LAION-Audio-300M@29eaacba2d0815aaf608ab34303555b9c895792e/flash_15_2_random_snippet

### Simple demo of Exporting Models to ONNX

In [7]:
# Load a pre-trained BERT model and tokenizer
from transformers import BertTokenizer, BertForQuestionAnswering

tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
import torch.onnx

# Set the model to evaluation mode
model.eval()

# Create dummy input as required for the model to run
inputs = tokenizer("What is AI?", "AI is Artificial Intelligence", return_tensors="pt")

# Export the model
torch.onnx.export(model,
                  args=(inputs['input_ids'], inputs['attention_mask']),
                  f="qa_model.onnx",
                  input_names=['input_ids', 'attention_mask'],
                  output_names=['start_logits', 'end_logits'],
                  dynamic_axes={'input_ids' : {0 : 'batch_size'},    # Variable batch size
                                'attention_mask' : {0 : 'batch_size'},
                                'start_logits' : {0 : 'batch_size'},
                                'end_logits' : {0 : 'batch_size'}})


### Caveat with ONNX
Calling `torch.onnx.export` directly works best for simple architectures. More complex architectures, such as encoder-decoder models like Whisper, are harder to export this way and might not be supported.

- Libraries like Hugging Face Transformers might also use custom layers or input types that the exporter does not accept directly.

In [9]:
# This code will fail!

import torch.onnx

# Set the model to evaluation mode
model.eval()

# Create dummy input as required for the model to run
inputs = processor(sample['audio.mp3']['array'], return_tensors="pt")
print(inputs)

# Export the model (raises a RuntimeError: `inputs` is a BatchFeature, not a tuple of tensors)
torch.onnx.export(
    model,
    inputs,
    "whisper_model.onnx",
    input_names=["input_features"],
    output_names=["logits"],
    dynamic_axes={
        "input_features": {0: "batch_size", 2: "sequence_length"},
        "logits": {0: "batch_size"}
    },
    do_constant_folding=True
)

It is strongly recommended to pass the `sampling_rate` argument to `WhisperFeatureExtractor()`. Failing to do so can result in silent errors that might be hard to debug.


{'input_features': tensor([[[ 0.4287, -0.0635, -0.0403,  ..., -0.9665, -0.9665, -0.9665],
         [ 0.4513,  0.1456,  0.2386,  ..., -0.9665, -0.9665, -0.9665],
         [ 0.4246,  0.0689,  0.2392,  ..., -0.9665, -0.9665, -0.9665],
         ...,
         [-0.7965, -0.8905, -0.8482,  ..., -0.9665, -0.9665, -0.9665],
         [-0.7202, -0.9665, -0.9665,  ..., -0.9665, -0.9665, -0.9665],
         [-0.7256, -0.9665, -0.9665,  ..., -0.9665, -0.9665, -0.9665]]])}


RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: BatchFeature

### Solution: Use Hugging Face's Optimum library instead
- Optimum provides a set of functions that automate exporting to ONNX and other runtimes in a few lines of code

In [10]:
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
from transformers import WhisperProcessor
import torch

# Load the model and processor
model_id = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(model_id)

# export=True converts the PyTorch checkpoint to ONNX while loading
# (model_id can also be a path to a local checkpoint)
ort_model = ORTModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)

# Save the exported ONNX files to disk
ort_model.save_pretrained("./whisper-small-onnx")

print("Whisper model exported to ONNX in the 'whisper-small-onnx' directory")

Whisper model exported to ONNX in the 'whisper-small-onnx' directory


### Loading and Using Saved Models

Loading Models in PyTorch

In [11]:
# A file written with torch.save(model, ...) is a pickle of the whole object,
# so it must be loaded with weights_only=False. Unpickling can run arbitrary
# code, so only do this with files from a source you trust.
# (torch.serialization.add_safe_globals only affects weights_only=True loads,
# so it is not needed here.)
model = torch.load('whisper_model.pth', weights_only=False)
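The warning about trusted sources exists because a whole-model `.pth` file is a pickle, and unpickling lets the file's author choose a function to call. The sketch below shows the mechanism with the standard library only. The `Payload` class is hypothetical and deliberately harmless (it calls `sorted`); a malicious file could name any importable function instead.

```python
import pickle


class Payload:
    """Hypothetical object that tells pickle what to call when it is loaded."""

    def __reduce__(self):
        # On load, pickle will call sorted([3, 1, 2]) and return the result
        return (sorted, ([3, 1, 2],))


blob = pickle.dumps(Payload())

# Loading does not rebuild a Payload: it runs the callable stored in the file
result = pickle.loads(blob)

print(type(result).__name__)
print(result)
```

Loading with `weights_only=True` restricts the unpickler to tensors and a small set of allowed types, which is why saving and loading only the state dictionary is the safer route.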


Loading only the state_dict (it contains no architecture, so the model must be instantiated first)

In [12]:
# Initialize the model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# load the state_dict
state_dict = torch.load("whisper_model_state_dict.pth")

# load the weights into the model
model.load_state_dict(state_dict)

<All keys matched successfully>
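The same pattern can be tried without downloading anything. The sketch below uses a small stand-in network (the `make_net` helper is hypothetical) and an in-memory buffer instead of a file: the architecture is rebuilt in code, and only the tensors travel through `torch.save` and `torch.load`.

```python
import io

import torch
from torch import nn


def make_net():
    """Hypothetical stand-in architecture; each call has fresh random weights."""
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


source = make_net()

# Save only the state dict; a BytesIO buffer stands in for a file on disk
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)
buffer.seek(0)

# Rebuild the architecture in code, then load the saved tensors into it
restored = make_net()
restored.load_state_dict(torch.load(buffer, weights_only=True))

# With identical weights, both networks give identical outputs
x = torch.randn(3, 4)
same = torch.equal(source(x), restored(x))

print(list(source.state_dict()))
print(same)
```

Because a state dict holds only tensors, it loads cleanly with `weights_only=True`, unlike the whole-model file above.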

Loading Models in SafeTensors

In [13]:
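from safetensors.torch import load_file

# load_file returns a plain dict mapping tensor names to tensors
# (a state dict), not a ready-to-use model
file_path = "./whisper.safetensors"
loaded = load_file(file_path)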

Before you proceed, run the following cell to free up RAM:

In [14]:
del model
del loaded
del ort_model
del ds
del inputs

Loading Models in ONNX

In [15]:
import onnxruntime as ort
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

# Using onnxruntime directly: load the BERT QA model exported earlier
session = ort.InferenceSession('qa_model.onnx')

# Using Optimum: load the Whisper ONNX files saved earlier
ort_model = ORTModelForSpeechSeq2Seq.from_pretrained('./whisper-small-onnx/')


### Further resources and steps:
- Have a look at how to quantize and train ONNX-accelerated models using Optimum
- Try visualising your ONNX models with [Netron](https://netron.app/)
- Try out other export and optimization methods, such as TorchFX and BetterTransformer, which can be used natively or through Optimum