In [None]:
from pathlib import Path
import wandb
import torch

# W&B model artifact to export; swap in the commented id to export the other run.
model_id = "model-c7fudnr9:v0"
# model_id = "model-g9cwh93c:v0"
model_path = Path("artifacts") / model_id
ckpt = model_path / "model.ckpt"

# Only contact W&B when the checkpoint file is missing locally. Checking the file rather
# than the directory also recovers from an earlier download that was interrupted.
if ckpt.exists():
    print("Model already downloaded")
else:
    api = wandb.Api()
    artifact = api.artifact("ain-space/gaia/" + model_id, type="model")
    artifact.download(model_path)

print("Loading checkpoint from", ckpt)
# weights_only=True restricts unpickling to tensors and plain containers (safer than full
# pickle); map_location="cpu" lets the checkpoint load on machines without a GPU.
checkpoint = torch.load(ckpt, weights_only=True, map_location="cpu")

In [None]:
from spherinator.models import yaml2model

import os

# Architecture config for the exported model. The default is the original author's local
# path; set GAIA_MODEL_CONFIG to point at the YAML on another machine.
MODEL_CONFIG = os.environ.get(
    "GAIA_MODEL_CONFIG",
    "/home/doserbd/git/Gaia/experiments/gaia_vae_8_nll_normal.yaml",
)
model = yaml2model(MODEL_CONFIG)

# One dummy forward pass before loading weights.
# NOTE(review): presumably this materializes lazily-shaped layers so load_state_dict
# finds matching parameters; confirm. The result itself is not used.
with torch.no_grad():
    model(torch.randn(1, 1, 343))

In [None]:
# Load the trained weights into the freshly built architecture and switch to inference
# mode (eval() changes the behaviour of dropout / batch-norm layers, if the model has any).
model.load_state_dict(checkpoint["state_dict"])
model.eval()

# Sanity check: push a random batch of 10 through the loaded model.
# `sample_input` avoids shadowing the builtin `input()` in the kernel namespace, and
# no_grad skips building an autograd graph that is never used.
sample_input = torch.randn(10, 1, 343)
with torch.no_grad():
    sample_output = model(sample_input)
sample_output

In [None]:
import os

# Output directory for this export. exist_ok=False is deliberate: it refuses to overwrite
# a previous export, so bump the version suffix for every new export.
export_path = "gaia-calibrated-v2"
os.makedirs(export_path, exist_ok=False)

# Batch size of the example inputs used for tracing; the batch axis itself is exported as
# dynamic, so ONNX consumers are not tied to this value.
batch_size = 256


def export_onnx(module: torch.nn.Module, example_input: torch.Tensor, path: str) -> None:
    """Export ``module`` to an optimized ONNX file at ``path`` with a dynamic batch axis.

    Args:
        module: Submodule to export; traced with ``example_input`` as its only
            positional argument.
        example_input: Example tensor; its dimension 0 is marked as the dynamic
            ``"batch"`` axis.
        path: Destination ``.onnx`` file.
    """
    onnx_program = torch.onnx.export(
        module,
        (example_input,),
        # dynamic_axes keys must be listed in input_names; this also names the ONNX graph
        # input "input" regardless of what the module's forward() parameter is called.
        input_names=["input"],
        dynamic_axes={"input": {0: "batch"}},
        dynamo=True,
    )
    onnx_program.optimize()
    onnx_program.save(path)


export_onnx(
    model.variational_encoder,
    torch.randn(batch_size, 1, 343, device="cpu"),
    os.path.join(export_path, "encoder.onnx"),
)
# NOTE(review): the decoder example input is (batch, 3), presumably a 3-dimensional
# latent vector; confirm against the model config.
export_onnx(
    model.decoder,
    torch.randn(batch_size, 3, device="cpu"),
    os.path.join(export_path, "decoder.onnx"),
)

In [None]:
import onnxruntime as ort

# Smoke test: reload both exported graphs with ONNX Runtime. Separate names keep both
# sessions alive instead of the decoder session silently replacing the encoder one.
encoder_session = ort.InferenceSession(os.path.join(export_path, "encoder.onnx"))
decoder_session = ort.InferenceSession(os.path.join(export_path, "decoder.onnx"))

# Show each graph's input names and shapes; the batch axis should appear symbolically
# (e.g. "batch") rather than as the fixed tracing batch size.
{
    graph_name: [(arg.name, arg.shape) for arg in session.get_inputs()]
    for graph_name, session in (("encoder", encoder_session), ("decoder", decoder_session))
}

In [None]:
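# Publish the export by uncommenting the line below. {export_path} interpolates the
# export directory so the upload matches the version just exported; -a is required
# because rsync skips directories without -r/-a.
#!rsync -a {export_path} space:/var/www/html/space/models/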