In [None]:
import wandb
# Fetch the registered model artifact from the W&B model registry and download
# it locally; `artifact_dir` is the local directory the files were written to.
ARTIFACT_NAME = 'ain-space-org/wandb-registry-model/gaia calibrated:v0'

run = wandb.init()
artifact = run.use_artifact(ARTIFACT_NAME, type='model')
artifact_dir = artifact.download()
# The run is only needed to record the artifact usage; finish it so it is not
# left active for the rest of the kernel session (later cells only use the
# local `artifact_dir` path).
run.finish()
artifact_dir

In [None]:
# Alternative (currently disabled): instantiate the model from `model.yaml`
# instead of hard-coding the architecture in the next cell. It imports the class
# named by `config["model"]["class_path"]` dynamically and constructs it with
# `config["model"]["init_args"]`.
# NOTE(review): the class_path/init_args layout looks like a LightningCLI config,
# and `model.yaml` is presumably shipped inside the downloaded artifact (so it
# would live under `artifact_dir`, not the CWD) — confirm before re-enabling.
# NOTE(review): if re-enabled, prefer `yaml.safe_load` inside a `with open(...)`
# block — the line below never closes the file handle, and `FullLoader` is less
# strict than the safe loader for a file that comes from an external artifact.
# import yaml
# import importlib

# config = yaml.load(open('model.yaml'), Loader=yaml.FullLoader)
# model_class_path = config["model"]["class_path"]
# module_name, class_name = model_class_path.rsplit(".", 1)
# module = importlib.import_module(module_name)
# model_class = getattr(module, class_name)
# model_init_args = config["model"]["init_args"]
# model = model_class(**model_init_args)

In [None]:
import torch
from spherinator.models import (
    ConsecutiveConv1DLayer,
    ConvolutionalEncoder1DGen,
    DenseModel,
    VariationalAutoencoderPure,
)

# Rebuild the network architecture so the checkpoint's state_dict can be loaded
# into it. These hyperparameters must match the trained model exactly, otherwise
# `load_state_dict` fails on missing/mismatched parameters.
INPUT_LENGTH = 343  # samples per input row (presumably wavelength bins — confirm)
Z_DIM = 3           # latent size; shared by encoder output, decoder input and VAE

# Three conv stages with growing kernel size and channel count, each followed by
# its own (freshly constructed) max-pooling module.
cnn_layers = [
    ConsecutiveConv1DLayer(
        kernel_size=kernel_size,
        num_layers=1,
        base_channel_number=channels,
        pooling=torch.nn.MaxPool1d(2, ceil_mode=True),
    )
    for kernel_size, channels in [(5, 16), (7, 32), (9, 64)]
]
encoder = ConvolutionalEncoder1DGen(input_dim=[1, INPUT_LENGTH],
                                    output_dim=Z_DIM,
                                    cnn_layers=cnn_layers)
decoder = DenseModel(layer_dims=[Z_DIM, 16, 64, 256, INPUT_LENGTH],
                     output_shape=[1, INPUT_LENGTH])
model = VariationalAutoencoderPure(encoder=encoder,
                                   decoder=decoder,
                                   z_dim=Z_DIM,
                                   beta=1.0e-3)

# Smoke test with untrained weights: a (batch=1, channels=1, INPUT_LENGTH) tensor
# must pass through without shape errors. `dummy_input` avoids shadowing the
# builtin `input`; no_grad skips building an autograd graph we never use.
dummy_input = torch.randn(1, 1, INPUT_LENGTH)
with torch.no_grad():
    smoke_output = model(dummy_input)
smoke_output

In [None]:
import torch

import os

# Checkpoint file inside the downloaded artifact directory.
ckpt = os.path.join(artifact_dir, "model.ckpt")
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered artifact cannot run arbitrary code; map_location="cpu" lets a
# checkpoint saved on GPU load on a CPU-only machine.
checkpoint = torch.load(ckpt, weights_only=True, map_location="cpu")
# Show only the top-level keys (e.g. "state_dict"); echoing `checkpoint` itself
# prints every weight tensor.
list(checkpoint)

In [None]:
# Copy the trained weights into the freshly built model. Strict loading (the
# default) raises if any parameter name or shape disagrees with the checkpoint.
model.load_state_dict(checkpoint["state_dict"])
# Inference mode: affects layers such as dropout/batch-norm, if the model has any.
model.eval()

# Sanity-check a forward pass with the trained weights. `dummy_input` avoids
# shadowing the builtin `input`; no_grad avoids building an unused autograd graph.
dummy_input = torch.randn(1, 1, 343)
with torch.no_grad():
    trained_output = model(dummy_input)
trained_output

In [None]:
# Export the encoder and decoder as two separate ONNX graphs with a dynamic
# batch dimension.
#
# The example batch size is 2, not 1: Dynamo always specializes dimensions of
# size 0 or 1 to constants, so a batch-1 example would bake a fixed batch size of
# 1 into the exported graphs even with dynamic_shapes=True. A batch-2 export
# still accepts batch-1 inputs at inference time.
EXAMPLE_BATCH = 2
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)

# output filename -> (submodule, example input); shapes mirror the model cells:
# encoder takes (batch, 1, 343), decoder takes (batch, 3) latent vectors.
exports = {
    "encoder.onnx": (model.variational_encoder,
                     torch.randn(EXAMPLE_BATCH, 1, 343, device="cpu")),
    "decoder.onnx": (model.decoder,
                     torch.randn(EXAMPLE_BATCH, 3, device="cpu")),
}
for filename, (submodule, example_input) in exports.items():
    # `onnx_program` rather than `onnx`, which would shadow the onnx package name.
    onnx_program = torch.onnx.dynamo_export(submodule, example_input,
                                            export_options=export_options)
    onnx_program.save(filename)

In [None]:
!scp -p encoder.onnx space:/var/www/html/space/models/gaia-calibrated-v0/encoder.onnx
!scp -p decoder.onnx space:/var/www/html/space/models/gaia-calibrated-v0/decoder.onnx