In [None]:
import os

import wandb

# Download the trained model checkpoint from Weights & Biases.
MODEL_ARTIFACT = "ain-space/gaia/model-k1knlvtk:v0"

run = wandb.init()
artifact = run.use_artifact(MODEL_ARTIFACT, type="model")

# Local directory the artifact files were written to; read the checkpoint
# from here instead of guessing the path.
artifact_dir = artifact.download()

In [None]:
from spherinator.models import WeightsProvider

# Inspect the encoder weights stored in the downloaded checkpoint.
# The path is built from `artifact_dir` (returned by `artifact.download()`)
# instead of hard-coding "artifacts/<name>": wandb chooses the download root
# (it honors WANDB_ARTIFACT_DIR, for instance), so a fixed path can miss the file.
weights = WeightsProvider(os.path.join(artifact_dir, "model.ckpt"), prefix="encoder")

# Show parameter names and shapes rather than printing every tensor value.
# NOTE(review): assumes get_state_dict() returns a name -> tensor mapping
# (torch state_dict convention) — confirm against WeightsProvider.
{name: tuple(tensor.shape) for name, tensor in weights.get_state_dict().items()}

In [None]:
import torch
from spherinator.models import (
    ConsecutiveConv1DLayer,
    ConvolutionalEncoder1DGen,
    ConsecutiveConvTranspose1DLayer,
    ConvolutionalDecoder1DGen,
    AutoencoderPure,
)

# Rebuild the 1D convolutional autoencoder, load the encoder/decoder weights
# from the downloaded checkpoint, and run a smoke-test forward pass.

INPUT_SHAPE = [1, 343]  # (channels, length) of a single input sample
LATENT_DIM = 20

# Read the checkpoint from wherever `artifact.download()` actually put it,
# rather than from a hard-coded "artifacts/..." path.
checkpoint_path = os.path.join(artifact_dir, "model.ckpt")

encoder = ConvolutionalEncoder1DGen(
    input_dim=INPUT_SHAPE,
    output_dim=LATENT_DIM,
    cnn_layers=[
        # Five stride-1 convs; presumably 16, 20, 24, 28, 32 output channels
        # (base 16, +4 per layer) — mirrored by the decoder's last stage.
        ConsecutiveConv1DLayer(
            kernel_size=7,
            num_layers=5,
            base_channel_number=16,
            channel_increment=4,
        ),
        ConsecutiveConv1DLayer(
            kernel_size=5,
            stride=2,
            num_layers=1,
            base_channel_number=64,
        ),
        ConsecutiveConv1DLayer(
            kernel_size=5,
            stride=2,
            num_layers=1,
            base_channel_number=96,
        ),
        ConsecutiveConv1DLayer(
            kernel_size=6,
            stride=2,
            num_layers=1,
            base_channel_number=128,
        ),
    ],
    weights=WeightsProvider(checkpoint_path, prefix="encoder"),
)
decoder = ConvolutionalDecoder1DGen(
    input_dim=LATENT_DIM,
    output_dim=INPUT_SHAPE,
    # (channels, length) the transposed-conv stack starts from. Assuming
    # unpadded convolutions, this is exactly what the encoder's conv stack
    # yields for a length-343 input: 343 -> 313 -> 155 -> 76 -> 36.
    cnn_input_dim=[128, 36],
    cnn_layers=[
        ConsecutiveConvTranspose1DLayer(
            kernel_size=6,
            stride=2,
            out_channels_list=[96],
        ),
        ConsecutiveConvTranspose1DLayer(
            kernel_size=5,
            stride=2,
            out_channels_list=[64],
        ),
        ConsecutiveConvTranspose1DLayer(
            kernel_size=5,
            stride=2,
            out_channels_list=[32],
        ),
        ConsecutiveConvTranspose1DLayer(
            kernel_size=7,
            out_channels_list=[28, 24, 20, 16, 1],
            activation=None,  # raw (unbounded) reconstruction at the output
        ),
    ],
    weights=WeightsProvider(checkpoint_path, prefix="decoder"),
)
model = AutoencoderPure(encoder=encoder, decoder=decoder)

# Smoke test with a reproducible random sample of shape (batch=1, *INPUT_SHAPE).
# Named `dummy_input` so the builtin `input` is not shadowed.
torch.manual_seed(0)
dummy_input = torch.randn(1, *INPUT_SHAPE)

# eval() + no_grad(): if the layers contain BatchNorm/Dropout, the random
# input must not update running statistics of the loaded weights; no autograd
# graph is needed for a pure forward check.
model.eval()
with torch.no_grad():
    reconstruction = model(dummy_input)
reconstruction