# Demo for Spherinator Training using Gaia XP DR3

## Download the Gaia DR3 XP data

For this demo, only the first of the 3386 files is downloaded and used.

In [1]:
!wget -nc -P ./data/gaia/csv https://cdn.gea.esac.esa.int/Gaia/gdr3/Spectroscopy/xp_continuous_mean_spectrum/XpContinuousMeanSpectrum_000000-003111.csv.gz

File ‘./data/gaia/csv/XpContinuousMeanSpectrum_000000-003111.csv.gz’ already there; not retrieving.



## Data preparation

- Conversion of files from csv to parquet
- Normalization of the data 

In [2]:
from pest import GaiaConverter

# Convert the CSV files in data/gaia/csv to parquet files in data/gaia/parquet.
# The recorded output shows that already converted files are skipped, so the
# cell can be re-run safely.
gaia_converter = GaiaConverter(number_of_workers=1, with_flux_error=True)
gaia_converter.convert_all("data/gaia/csv", "data/gaia/parquet")

Found 1 files to convert
File data/gaia/parquet/XpContinuousMeanSpectrum_000000-003111.parquet already exists, skipping


## Training the model

In [3]:
import spherinator.models as sm

# Encoder: 1D convolution blocks taking a [1, 343] input to a 128-dim output.
encoder = sm.ConvolutionalEncoder1D(
    input_dim=[1, 343],
    output_dim=128,
    cnn_layers=[
        # One block of five stride-1 convolutions (kernel size 7).
        sm.ConsecutiveConv1DLayer(
            kernel_size=7,
            stride=1,
            padding=0,
            num_layers=5,
            base_channel_number=16,
            channel_increment=4,
        ),
        # Three single-layer stride-2 blocks (kernel size 5) with 64, 96 and
        # 128 as base channel numbers.
        *(
            sm.ConsecutiveConv1DLayer(
                kernel_size=5,
                stride=2,
                padding=0,
                num_layers=1,
                base_channel_number=channels,
            )
            for channels in (64, 96, 128)
        ),
    ],
)

# Decoder: transposed-convolution blocks taking a 3-dim input back to [1, 343].
# NOTE(review): cnn_input_dim=[128, 36] presumably has to match the shape of
# the encoder's last feature map - confirm when changing the layers.
decoder = sm.ConvolutionalDecoder1D(
    input_dim=3,
    output_dim=[1, 343],
    cnn_input_dim=[128, 36],
    cnn_layers=[
        # Three stride-2 blocks with a single output channel entry each.
        *(
            sm.ConsecutiveConvTranspose1DLayer(
                kernel_size=kernel_size,
                stride=2,
                padding=0,
                out_channels_list=[out_channels],
            )
            for kernel_size, out_channels in ((6, 96), (5, 64), (5, 32))
        ),
        # Final stride-1 block ending in one channel, created with
        # activation=None.
        sm.ConsecutiveConvTranspose1DLayer(
            kernel_size=7,
            stride=1,
            padding=0,
            out_channels_list=[28, 24, 20, 16, 1],
            activation=None,
        ),
    ],
)

# z_dim equals the decoder's input_dim (3) and encoder_out_dim equals the
# encoder's output_dim (128).
model = sm.VariationalAutoencoder(
    encoder=encoder,
    decoder=decoder,
    z_dim=3,
    beta=1.0e-4,
    encoder_out_dim=128,
)
# Optional sanity check: uncomment to run one forward pass and show the model.
# _ = model(model.example_input_array)
# model

  param_schemas = callee.param_schemas()
  param_schemas = callee.param_schemas()


In [4]:
import spherinator.data as sd

# Data module for the parquet files written by the conversion step. It is
# configured to read the "flux" column with normalize="minmax".
datamodule = sd.ParquetDataModule(
    # What to load
    data_directory="data/gaia/parquet",
    data_column="flux",
    normalize="minmax",
    # How to load it
    shuffle=True,
    batch_size=2048,
    num_workers=4,
)
# Optional: uncomment to set up the data module and print the training set size.
# datamodule.setup("fit")
# print(f"Number of training items: {len(datamodule.data_train)}")

In [5]:
import lightning.pytorch as pl

# Train for 10 epochs on the GPU with 16-bit mixed precision.
trainer = pl.Trainer(accelerator="gpu", precision="16-mixed", max_epochs=10)
trainer.fit(model, datamodule=datamodule)

Using 16bit Automatic Mixed Precision (AMP)
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/doserbd/git/Gaia/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type                   | Params | Mode  | In sizes    | Out sizes       
-----------------------

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


## Export the trained model to ONNX

In [6]:
from pathlib import Path

import torch

# Exported models go to data/gaia/models, which is where the Netron cell
# at the end of the notebook looks for encoder.onnx.
ONNX_MODEL_DIR = Path("data/gaia/models")


def export_to_onnx(module, example_input, onnx_path):
    """Export ``module`` to ONNX with a dynamic batch axis and save it.

    Args:
        module: The ``torch.nn.Module`` to export. It is moved to the CPU and
            switched to eval mode in place.
        example_input: CPU tensor used as the example input for the export.
            Its first axis is exported as the dynamic ``batch`` axis.
        onnx_path: Destination file; missing parent directories are created.

    Returns:
        The ``pathlib.Path`` the model was written to.
    """
    onnx_path = Path(onnx_path)
    onnx_path.parent.mkdir(parents=True, exist_ok=True)

    # Export in inference mode, on the same device as the example input.
    module = module.cpu().eval()

    onnx_program = torch.onnx.export(
        module,
        (example_input,),
        # Every key of dynamic_axes has to be listed in input_names (or
        # output_names), so the input is named explicitly.
        input_names=["input"],
        dynamic_axes={"input": {0: "batch"}},
        dynamo=True,
    )
    onnx_program.optimize()
    onnx_program.save(str(onnx_path))
    return onnx_path


# The example inputs use a batch size of 2: torch.export specializes
# dimensions of size 0 or 1 to constants, which conflicts with declaring the
# batch axis as dynamic.
export_to_onnx(
    model.variational_encoder,
    torch.randn(2, 1, 343, device="cpu"),
    ONNX_MODEL_DIR / "encoder.onnx",
)
export_to_onnx(
    model.decoder,
    torch.randn(2, 3, device="cpu"),
    ONNX_MODEL_DIR / "decoder.onnx",
)



OnnxExporterError: Failed to export the model to ONNX. Generating SARIF report at 'report_dynamo_export.sarif'. SARIF is a standard format for the output of static analysis tools. SARIF logs can be loaded in VS Code SARIF viewer extension, or SARIF web viewer (https://microsoft.github.io/sarif-web-component/). Please report a bug on PyTorch Github: https://github.com/pytorch/pytorch/issues

## Visualize the ONNX model with Netron

In [None]:
!pip install -q netron
import netron
netron.start('data/gaia/models/encoder.onnx', 8081)