Let's export the trained model in ONNX and safetensors formats for compatibility with downstream inference engines. First, we'll define some variables.

In [1]:
model_name = "LightGPT-Small-Base"
checkpoint_path = "./checkpoints/checkpoint.pt"
lora_path = None  # "./checkpoints/lora_instruct.pt"
exports_path = "./exports"

Then, we'll load the base model checkpoint into memory from disk.

In [2]:
import torch

from model import LightGPT

checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

model = LightGPT(**checkpoint["model_args"])

model = torch.compile(model)

model.load_state_dict(checkpoint["model"])

print("Base checkpoint loaded successfully")

Base checkpoint loaded successfully


Now, we'll load any LoRA checkpoints we wish to incorporate into the exported model.

In [3]:
from model import LightGPTInstruct

if lora_path != None:
    checkpoint = torch.load(lora_path, map_location="cpu", weights_only=True)

    model = LightGPTInstruct(model, **checkpoint["lora_args"])

    model = torch.compile(model)

    model.load_state_dict(checkpoint["lora"], strict=False)

    model.merge_lora_parameters()

    print("LoRA checkpoint loaded successfully")

For ONNX format we'll use TorchDynamo to trace the FX Graph of our model using some example data and then translate the intermediate representation to ONNX format.

In [4]:
from os import path

from model import ONNXModel

from torch.onnx import dynamo_export, ExportOptions

example_input = torch.randint(0, model.vocabulary_size - 1, (1, 1024))

onnx_model = ONNXModel(model)  # Nicer inferencing API

onnx_model.eval()  # Turn off dropout and other train-time operations

export_options = ExportOptions(
    dynamic_shapes=True
)  # Necessary for variable batch and sequence lengths

onnx_model = dynamo_export(onnx_model, example_input, export_options=export_options)

onnx_path = path.join(exports_path, f"{model_name}.onnx")

onnx_model.save(onnx_path)

print(f"Model saved to {onnx_path}")

  param_schemas = callee.param_schemas()
  param_schemas = callee.param_schemas()
  self.param_schema = self.onnxfunction.param_schemas()


Applied 73 of general pattern rewrite rules.
Model saved to ./exports/LightGPT-Small-Base.onnx


Compare the output of PyTorch with the ONNX runtime to see if they are the same.

In [5]:
import onnxruntime

from numpy.testing import assert_allclose

pytorch_logits = model.predict(example_input).detach().numpy()

session = onnxruntime.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

onnx_input = {"l_x_": example_input.numpy()}

onnx_logits = session.run(None, onnx_input)

onnx_logits = onnx_logits[0]

assert_allclose(pytorch_logits, onnx_logits, rtol=1e-2, atol=1e-03)

print("Looks good!")

Looks good!


Lastly, let's export the model in HuggingFace format so that it can be used with the HuggingFace ecosystem.

In [6]:
from os import path

from transformers import AutoConfig, AutoModelForCausalLM

from model import LightGPTHuggingFaceConfig, LightGPTHuggingFaceModel

hf_path = path.join(exports_path, model_name)

AutoConfig.register("lightgpt", LightGPTHuggingFaceConfig)
AutoModelForCausalLM.register(LightGPTHuggingFaceConfig, LightGPTHuggingFaceModel)

LightGPTHuggingFaceConfig.register_for_auto_class()
LightGPTHuggingFaceModel.register_for_auto_class("AutoModel")

hf_config = LightGPTHuggingFaceConfig(**checkpoint["model_args"])

hf_model = LightGPTHuggingFaceModel(hf_config)

hf_model.model = torch.compile(hf_model.model)

# Compensate for poor HuggingFace Transformers support for tied weights.
state_dict = model.state_dict()
state_dict = {k:v for k, v in state_dict.items() if "output_layer" not in k}

hf_model.model.load_state_dict(state_dict, strict=False)

hf_model.save_pretrained(hf_path, state_dict=state_dict)

print(f"Model saved to {hf_path}")

Model saved to ./exports/LightGPT-Small-Base


Lastly, we'll login to HuggingFaceHub and upload the model under our account.

In [7]:
from huggingface_hub import notebook_login

notebook_login()

hf_model.push_to_hub(model_name)

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

RuntimeError: The weights trying to be saved contained shared tensors [{'model._orig_mod.token_embeddings.weight', 'model._orig_mod.output_layer.weight'}] that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.