In [20]:
import torch
import safetensors.torch as sft
import os

def convert_to_safetensors(input_path, output_path=None):
    """Convert a PyTorch checkpoint (e.g. ``.bin`` / ``.pt``) into a ``.safetensors`` file.

    The file at ``input_path`` is loaded onto the CPU. If it is a training
    checkpoint (a dict with a ``"model_state_dict"`` or ``"state_dict"`` entry),
    only that nested state dict is converted. After saving, the new file is
    re-loaded and compared against what was written: the key sets, and each
    shared tensor's dtype and values, are checked and the results printed.

    Errors are reported with ``print`` instead of being raised, so a failed
    conversion does not interrupt the rest of the notebook.

    Args:
        input_path: Path to the PyTorch checkpoint to convert.
        output_path: Destination path. Defaults to ``input_path`` with its
            extension replaced by ``.safetensors``.

    Returns:
        None. Progress and verification results are printed.
    """
    if output_path is None:
        output_path = os.path.splitext(input_path)[0] + ".safetensors"

    try:
        # isfile (not exists) so a directory path is rejected up front instead
        # of failing later inside torch.load with a confusing message.
        if not os.path.isfile(input_path):
            raise FileNotFoundError(f"Input file does not exist: {input_path}")

        # NOTE(security): torch.load unpickles arbitrary objects and can execute
        # code embedded in the file. For checkpoints from untrusted sources, pass
        # weights_only=True (available in torch >= 1.13).
        checkpoint = torch.load(input_path, map_location="cpu")

        # Training checkpoints wrap the weights next to optimizer state, epoch, etc.
        # "model_state_dict" is checked first to preserve the original behavior.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for wrapper_key in ("model_state_dict", "state_dict"):
                if wrapper_key in checkpoint:
                    state_dict = checkpoint[wrapper_key]
                    break

        if not isinstance(state_dict, dict):
            raise TypeError(f"Expected a state dict, got {type(state_dict).__name__}")

        # safetensors can only serialize tensors; report offending keys clearly
        # rather than letting save_file fail with a generic error.
        non_tensor_keys = [
            name for name, value in state_dict.items() if not isinstance(value, torch.Tensor)
        ]
        if non_tensor_keys:
            raise TypeError(
                f"safetensors can only store tensors; non-tensor entries: {non_tensor_keys}"
            )

        # save_file rejects non-contiguous tensors; .contiguous() returns the
        # tensor itself (no copy) when it is already contiguous.
        tensors = {name: tensor.contiguous() for name, tensor in state_dict.items()}

        sft.save_file(tensors, output_path)
        print(f"Converted {input_path} to {output_path}")

        # Re-load the written file and verify that it round-trips.
        converted_state_dict = sft.load_file(output_path)

        original_keys = set(tensors.keys())
        converted_keys = set(converted_state_dict.keys())
        print("Keys match:", original_keys == converted_keys)

        # Matching keys alone would not catch retyped or corrupted weights.
        tensors_match = all(
            tensors[name].dtype == converted_state_dict[name].dtype
            and torch.equal(tensors[name], converted_state_dict[name])
            for name in original_keys & converted_keys
        )
        print("Tensors match:", tensors_match)

    except Exception as e:
        # Best-effort by design: report the failure and let later cells run.
        print(f"Error converting {input_path}: {e}")

# Convert the HuBERT embedding checkpoint; the .safetensors file is written next to it.
convert_to_safetensors(input_path="../models/embeddings/hubert/pytorch_model.bin")


Converted ../models/embeddings/hubert/pytorch_model.bin to ../models/embeddings/hubert/pytorch_model.safetensors
Keys match: True
