Let's use a custom safetensors file to export and import tensors.
We start with a toy example.

In [1]:
import torch
from safetensors.torch import save_file as torch_save_file

In [2]:
# Two random toy tensors of different rank; their values are arbitrary and
# only serve as payload for the save/load round trip.
tensor_1 = torch.randn((2, 4))
tensor_2 = torch.randn((3, 5, 10))


In [3]:
# Free-form string metadata stored alongside the tensors in the file.
# Keys and values are all plain strings.
metadata = {
    "description": "Example tensors",
    "version": "1.0.0",
    "Arch": "Dummy",
}

In [12]:
# Write both tensors, together with the metadata dict, into one safetensors file.
torch_save_file(
    tensors={"tensor1": tensor_1, "tensor2": tensor_2},
    filename="model.safetensors",
    metadata=metadata,
)

In [5]:
# Now let's load the file back, starting with the metadata only.


In [13]:
from safetensors import safe_open
from safetensors.torch import load_file

In [14]:
# Read only the metadata from the saved file via safe_open; no tensors are
# requested in this cell.
with safe_open("model.safetensors", framework="pt", device="cpu") as reader:
    metadata = reader.metadata()

print("Metadata:", metadata)

Metadata: {'Arch': 'Dummy', 'version': '1.0.0', 'description': 'Example tensors'}


In [20]:
# Load the tensors only when the file's metadata carries the expected version.
# Metadata-based conditions like this will be useful once we support
# multiple architectures.

if metadata.get("version") != "1.0.0":
    print("Condition not met, tensors not loaded.")
else:
    tensors = load_file("model.safetensors")
    print("Tensors loaded:", tensors.keys())

    print("Tensor1:", tensors["tensor1"].shape)
    print("Tensor2:", tensors["tensor2"].shape)

Tensors loaded: dict_keys(['tensor1', 'tensor2'])
Tensor1: torch.Size([2, 4])
Tensor2: torch.Size([3, 5, 10])


In [16]:
# Round-trip check: the original tensor_1 and the reloaded "tensor1" must
# have the same shape and identical elements.
assert torch.equal(tensor_1, tensors["tensor1"]), "Tensors are not equal"

In [19]:
# Round-trip check: the original tensor_2 and the reloaded "tensor2" must
# have the same shape and identical elements.
assert torch.equal(tensor_2, tensors["tensor2"]), "Tensors are not equal"

We can see that the saved and loaded tensors match. We will use this technique to export the tensors from our export script and load them in our inference engine.