In [None]:
from pathlib import Path
import wandb
import torch

# W&B model artifact to use; it is cached locally under artifacts/<model_id>.
model_id = "model-1tgv4b7q:v0"
model_path = Path("artifacts") / model_id
ckpt = model_path / "model.ckpt"

# Only talk to the W&B API when the checkpoint file itself is missing. Checking
# the file rather than the directory means an interrupted/partial download is
# retried instead of crashing in torch.load, and a cached re-run needs no
# network access or credentials.
if ckpt.exists():
    print("Model already downloaded")
else:
    api = wandb.Api()
    artifact = api.artifact("ain-space/gaia/" + model_id, type="model")
    artifact.download(root=str(model_path))

print("Loading checkpoint from", ckpt)
# weights_only=True restricts unpickling to tensors/primitive containers;
# map_location="cpu" lets a GPU-trained checkpoint load on a CPU-only machine.
checkpoint = torch.load(ckpt, weights_only=True, map_location="cpu")

In [None]:
from spherinator.models import yaml2model

# Build the network from the experiment YAML config. Its architecture must
# match the checkpoint, because that checkpoint's state_dict is loaded into it.
config_file = "../experiments/gaia_vae_8.yaml"
model = yaml2model(config_file)

# Push a single random sample of shape (1, 1, 343) through the new model.
# NOTE(review): the result is never used. Presumably this forward pass
# materialises lazily-shaped layers before load_state_dict — confirm, and drop
# it if it is not needed.
output = model(torch.randn(1, 1, 343))

In [None]:
# Restore the trained weights and switch the model to inference mode
# (affects layers such as dropout/batch-norm, if the model has any).
model.load_state_dict(checkpoint["state_dict"])
model.eval()

# Smoke-test the restored model on a random batch of 10 samples.
# `sample_input` avoids shadowing the built-in `input`; no_grad skips building
# an autograd graph that is never used.
sample_input = torch.randn(10, 1, 343)
with torch.no_grad():
    sample_output = model(sample_input)
sample_output

In [None]:
import os

export_path = "gaia-calibrated-v1"
os.makedirs(export_path, exist_ok=True)

# Example-input sizes. Because the export below uses dynamic_shapes=False,
# these sizes are baked into the ONNX graphs.
EXPORT_BATCH_SIZE = 256
INPUT_LENGTH = 343  # same per-sample length as the (N, 1, 343) inputs used above
LATENT_DIM = 20  # NOTE(review): decoder input width for gaia_vae_8 — confirm against the config

export_options = torch.onnx.ExportOptions(dynamic_shapes=False)


def export_onnx(module, example_input, filename):
    """Export ``module`` with the dynamo ONNX exporter and save it as ``export_path/filename``.

    Args:
        module: the torch module to export (traced with ``example_input``).
        example_input: example tensor that fixes the exported input shape.
        filename: file name of the ``.onnx`` file inside ``export_path``.
    """
    # Named `program` rather than `onnx` so the `onnx` package name is not shadowed.
    program = torch.onnx.dynamo_export(
        module, example_input, export_options=export_options
    )
    program.save(os.path.join(export_path, filename))


export_onnx(
    model.variational_encoder,
    torch.randn(EXPORT_BATCH_SIZE, 1, INPUT_LENGTH, device="cpu"),
    "encoder.onnx",
)
export_onnx(
    model.decoder,
    torch.randn(EXPORT_BATCH_SIZE, LATENT_DIM, device="cpu"),
    "decoder.onnx",
)

In [None]:
import onnxruntime as ort

# Pin the CPU provider explicitly. ONNX Runtime requires an explicit provider
# list when a build exposes more than the CPU execution provider.
ort_session = ort.InferenceSession(
    os.path.join(export_path, "encoder.onnx"),
    providers=["CPUExecutionProvider"],
)

# Smoke-run the exported encoder so a broken export fails here rather than on
# the server. The graph was exported with static shapes from a
# (256, 1, 343) example, so feed exactly that shape.
# NOTE(review): assumes the encoder graph has a single float32 input — confirm
# via ort_session.get_inputs().
encoder_input = ort_session.get_inputs()[0]
ort_outputs = ort_session.run(
    None, {encoder_input.name: torch.randn(256, 1, 343).numpy()}
)
[arr.shape for arr in ort_outputs]

In [None]:
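# Manual deployment step. It is intentionally commented out so that "Run All"
# does not push files to the remote server. Uncomment it to publish the models
# exported above; IPython expands {export_path} (currently "gaia-calibrated-v1").
#!rsync {export_path} space:/var/www/html/space/models/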