# Speed

One of the main benefits of Safetensors is its loading speed. Let's measure it (and the saving speed as well, for completeness).

In [1]:
import torch
import torch.nn as nn
import pickle
import numpy as np
from safetensors.torch import save_file, load_file


class LargeModel(nn.Module):
    """Four-layer fully connected network used as a large serialization payload.

    Every layer is a square ``nn.Linear(hidden_dim, hidden_dim)``. With the
    default width of 10000 each layer holds 10000 * 10000 weights plus 10000
    biases, i.e. roughly 400 million parameters for the whole model.

    Args:
        hidden_dim: Input and output width of every layer. Defaults to 10000,
            the size used by the benchmarks; pass a small value for quick
            experiments.
    """

    def __init__(self, hidden_dim: int = 10000):
        super().__init__()
        # The attribute names fc1..fc4 determine the state_dict keys
        # ("fc1.weight", "fc1.bias", ...), so they must not be renamed.
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Apply fc1..fc3, each followed by ReLU, then fc4 without activation.

        Args:
            x: Input tensor whose last dimension equals ``hidden_dim``.

        Returns:
            The output of ``fc4``; its last dimension is ``hidden_dim``.
        """
        for layer in (self.fc1, self.fc2, self.fc3):
            x = torch.relu(layer(x))
        return self.fc4(x)


# Seed PyTorch's global RNG so the randomly initialised weights (and therefore
# any weight values displayed further down) are identical on every
# Restart & Run All.
torch.manual_seed(0)
model = LargeModel()
# The state dict (parameter name -> tensor) is the payload that every format
# in the benchmarks below saves.
weights = model.state_dict()

## Save

In [3]:
%%timeit
# SafeTensors: write the state dict `weights` to a .safetensors file.
save_file(weights, "large_model_weights.safetensors")

2.09 s ± 746 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%%timeit
# Pickle: serialize the state dict `weights` with the standard-library pickle module.
with open("large_model_weights.pkl", "wb") as f:
    pickle.dump(weights, f)

3.68 s ± 646 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%%timeit
# NumPy's .npy: convert every tensor to a NumPy array and save the resulting dict.
# np.save is handed a dict rather than an array here, which is why the matching
# load cell passes allow_pickle=True and calls .item() to get the dict back.
np.save("large_model_weights.npy", {k: v.numpy() for k, v in weights.items()})

4.1 s ± 1.26 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
%%timeit
# PyTorch: torch.save on the same pre-built `weights` dict that the other
# formats serialize, so the timing covers serialization only and does not
# include an extra model.state_dict() call.
torch.save(weights, "large_model_weights.pt")

6.33 s ± 707 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
import os

# (label, path) for each file written in the Save section. The pickle dump has
# a .pkl extension, so its label is plain "Pickle" (".pt" is the torch.save file).
saved_files = [
    ("Pickle", "large_model_weights.pkl"),
    ("SafeTensors", "large_model_weights.safetensors"),
    ("Npy", "large_model_weights.npy"),
    ("PyTorch", "large_model_weights.pt"),
]

for label, path in saved_files:
    # os.path.getsize returns bytes; dividing by 1e9 reports decimal gigabytes.
    print(label + " file size:", os.path.getsize(path) / 1e9, "GB")

Pickle .pt file size: 1.600162623 GB
SafeTensors file size: 1.600160672 GB
Npy file size: 1.600160729 GB


## Load

In [None]:
%%timeit
# SafeTensors: read the tensors of the saved file back with load_file.
loaded_weights_safetensors = load_file("large_model_weights.safetensors")

174 µs ± 3.95 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
%%timeit
# Pickle: deserialize the state dict written by the pickle save cell.
# The file was produced by this notebook; pickle.load must never be used on
# files from an untrusted source.
with open("large_model_weights.pkl", "rb") as f:
    loaded_weights_pickle = pickle.load(f)

1.5 s ± 183 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%%timeit
# NumPy: the file holds a pickled dict (see the np.save cell), hence
# allow_pickle=True; .item() unwraps the loaded object array to get the dict back.
loaded_weights_npy = np.load("large_model_weights.npy", allow_pickle=True).item()

222 ms ± 6.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Here Safetensors is by far the fastest: about 174 µs per load, compared with 1.5 s for pickle and 222 ms for NumPy.

# Safety

Loading a pickle file can execute arbitrary, potentially malicious, code. With Safetensors this is not possible.

In [None]:
import pickle
import os

class MaliciousPayload(nn.Module):
    """Module whose pickled form runs a shell command when it is unpickled.

    The module carries one plain tensor attribute (``nonparam``) and one
    registered parameter (``param``); only the latter is part of its
    ``state_dict``.
    """

    def __init__(self):
        super().__init__()
        self.nonparam = torch.tensor(10.0)
        self.param = nn.Parameter(torch.tensor(20.0))

    def __reduce__(self):
        """Tell pickle to rebuild this object by calling ``os.system``.

        Returns:
            A ``(callable, args)`` pair; on load pickle calls
            ``os.system`` with the shell command as its only argument.
        """
        shell_command = "echo 'Malicious code executed!'; touch hacked.txt"
        return os.system, (shell_command,)

# Pickling consults MaliciousPayload.__reduce__, so the file stores the
# os.system call it returns instead of the module's tensors.
malicious_data = MaliciousPayload()
payload_bytes = pickle.dumps(malicious_data)
with open("malicious_model.pkl", "wb") as payload_file:
    payload_file.write(payload_bytes)

In [None]:
# WARNING: unpickling runs the command embedded in the file. That is the point
# of this demo; never unpickle a file that comes from an untrusted source.
with open("malicious_model.pkl", "rb") as payload_file:
    loaded_data = pickle.loads(payload_file.read())
# The result is whatever the embedded call returned (the value of os.system),
# not a MaliciousPayload instance.
loaded_data

Malicious code executed!


0

With Safetensors we cannot do anything except save a `state_dict` (a mapping from names to tensors) and store optional metadata, which must consist of strings.

In [None]:
from safetensors.torch import save_file, load_file

# Only the module's state dict (names -> tensors) is handed to save_file, so
# the custom __reduce__ of MaliciousPayload plays no role here.
safe_path = "safe.safetensors"
save_file(malicious_data.state_dict(), safe_path)

# Loading gives back a plain dict of tensors.
safe_data = load_file(safe_path)
print(safe_data)

{'param': tensor(20.)}


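(Note that the plain tensor attribute `nonparam` is missing from the output: it is not part of the `state_dict`, so it was not saved. Define your parameters as `nn.Parameter` if you want them stored.)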

# Support for different frameworks

More on JAX support: https://github.com/alvarobartt/safejax

We can save a tensor from PyTorch and then load it in other frameworks (TensorFlow, JAX, and some others).

In [None]:
import torch
import tensorflow as tf
import jax.numpy as jnp
from safetensors.torch import save_file, load_file
from safetensors.tensorflow import load_file as load_tf
from safetensors.flax import load_file as load_flax  # for JAX/Flax

multi_framework_path = "multi_framework.safetensors"

# Write a single random 3x3 tensor from PyTorch.
tensor = torch.randn(3, 3)
save_file({"tensor": tensor}, multi_framework_path)

# Read the very same file back with each framework's own loader.
pt_data = load_file(multi_framework_path)
tf_data = load_tf(multi_framework_path)
flax_data = load_flax(multi_framework_path)

# The three printouts should show the same values in each framework's format.
print("PyTorch Tensor:\n", pt_data["tensor"])
print("TensorFlow Tensor:\n", tf_data["tensor"])
print("JAX/Flax Tensor:\n", flax_data["tensor"])

PyTorch Tensor:
 tensor([[-2.1482,  1.8262,  0.7550],
        [ 0.0742,  0.5142,  1.8213],
        [ 0.6014, -1.3224,  0.9582]])
TensorFlow Tensor:
 tf.Tensor(
[[-2.1481683   1.8261622   0.7550253 ]
 [ 0.07417796  0.51415616  1.8213471 ]
 [ 0.6013614  -1.3223515   0.95818126]], shape=(3, 3), dtype=float32)
JAX/Flax Tensor:
 [[-2.1481683   1.8261622   0.7550253 ]
 [ 0.07417796  0.51415616  1.8213471 ]
 [ 0.6013614  -1.3223515   0.95818126]]


# Lazy-loading

We can load only the pieces we need, for example a small slice of a single tensor, instead of the whole file.

In [7]:
from safetensors import safe_open

# Not populated in this cell; kept so the name stays defined for any later use.
tensors = {}

# Use an accelerator when one is present and fall back to the CPU otherwise,
# so the cell also runs on machines without a GPU (a hard-coded device index
# of 0 only works where an accelerator exists).
mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    lazy_load_device = "cuda:0"
elif mps_backend is not None and mps_backend.is_available():
    lazy_load_device = "mps"
else:
    lazy_load_device = "cpu"

with safe_open(
    "large_model_weights.safetensors", framework="pt", device=lazy_load_device
) as f:
    # get_slice returns a handle to the stored tensor; indexing it below
    # materialises only the requested 5x5 corner.
    tensor_slice = f.get_slice("fc3.weight")
    # For an nn.Linear weight the shape is (out_features, in_features).
    vocab_size, hidden_dim = tensor_slice.get_shape()
    tensor = tensor_slice[:5, :5]
tensor

tensor([[ 6.1581e-03, -7.8856e-04, -6.8335e-03, -9.0303e-03,  6.2035e-03],
        [-5.8598e-03,  9.7183e-03,  9.8777e-05,  6.4537e-03, -3.9373e-03],
        [ 7.8363e-03, -5.1742e-04, -4.9633e-03, -7.9796e-03,  5.9262e-03],
        [-7.0394e-03,  2.6760e-03, -2.8524e-03,  3.0783e-03,  7.3222e-04],
        [-6.8006e-03,  1.5352e-03, -1.2444e-03,  8.1483e-03,  4.1770e-03]],
       device='mps:0')