# Demo: Spherinator Training on Gaia DR3 XP Spectra

## Download the Gaia DR3 XP data

For this demo, only the first of the 3386 files is downloaded and used.

In [None]:
import shutil
import urllib.request
from pathlib import Path

GAIA_XP_URL = (
    "https://cdn.gea.esac.esa.int/Gaia/gdr3/Spectroscopy/xp_continuous_mean_spectrum/"
    "XpContinuousMeanSpectrum_000000-003111.csv.gz"
)
CSV_DIR = Path("data/gaia/csv")

# Pure-Python replacement for `wget -nc -P CSV_DIR URL`:
# - works on systems without wget installed,
# - skips the download when the file already exists (like `-nc`),
# - downloads into a temporary ".part" file *outside* CSV_DIR and only moves it
#   into place once complete, so an interrupted download is neither mistaken for
#   a finished file on the next run nor picked up by the CSV converter.
CSV_DIR.mkdir(parents=True, exist_ok=True)
csv_path = CSV_DIR / GAIA_XP_URL.rsplit("/", 1)[-1]
if csv_path.exists():
    print(f"{csv_path} already exists, skipping download")
else:
    part_path = CSV_DIR.parent / (csv_path.name + ".part")
    with urllib.request.urlopen(GAIA_XP_URL) as response, open(part_path, "wb") as out_file:
        shutil.copyfileobj(response, out_file)
    part_path.replace(csv_path)
    print(f"Downloaded {csv_path}")

## Data preparation

In [None]:
from pest import GaiaConverter

# Convert the CSV files in `csv_directory` and write the parquet output used by
# the rest of this notebook to `parquet_directory`.
csv_directory = "data/gaia/csv"
parquet_directory = "data/gaia/parquet"
converter_options = {"with_flux_error": True, "number_of_workers": 1}

gaia_converter = GaiaConverter(**converter_options)
gaia_converter.convert_all(csv_directory, parquet_directory)

## Visualize training data

To get an impression of the data, we plot a single spectrum from the converted parquet files. Note that
this shows the raw flux values; min-max normalization is only applied later by the data module during training.

In [None]:
import pyarrow.dataset as ds
import hipster
from PIL import Image

# 0-based row of the dataset whose spectrum is plotted.
SPECTRUM_INDEX = 1

dataset = ds.dataset("data/gaia/parquet", format="parquet")
print("Number of spectra:", dataset.count_rows())

# Read only the "flux" column of the first SPECTRUM_INDEX + 1 rows instead of
# materializing the whole dataset (every column of every row) in pandas.
flux_column = dataset.head(SPECTRUM_INDEX + 1, columns=["flux"]).to_pandas()["flux"]
# Positional lookup, independent of the frame's index labels.
data = flux_column.iloc[SPECTRUM_INDEX]

spectrum = hipster.SpectrumPlotter(
    wavelengths=hipster.Range(336, 1021, 2),
    figsize_in_pixel=512,
)(data)

image = Image.fromarray(spectrum)
image

## Training the model

In [None]:
import spherinator.models as sm

# Architecture overview. Sequence lengths assume each layer behaves like an
# unpadded torch.nn.Conv1d / ConvTranspose1d with the given kernel size k and
# stride s (conv: L_out = (L_in - k) // s + 1, convT: L_out = (L_in - 1) * s + k):
#
#   encoder: [1, 343] -5x conv(k7,s1)-> 313 -conv(k5,s2)-> 155 -conv(k5,s2)-> 76
#            -conv(k5,s2)-> [128, 36] -> 128 features -> latent z (z_dim = 3)
#   decoder: z (3) -> [128, 36] -convT(k6,s2)-> 76 -convT(k5,s2)-> 155
#            -convT(k5,s2)-> 313 -5x convT(k7,s1)-> [1, 343]
#
# The decoder's `cnn_input_dim=[128, 36]` therefore mirrors the encoder's last
# feature map, and its `output_dim=[1, 343]` matches the encoder input.

encoder = sm.ConvolutionalEncoder1D(
    input_dim=[1, 343],
    output_dim=128,
    cnn_layers=[
        sm.ConsecutiveConv1DLayer(
            kernel_size=7, stride=1, padding=0,
            num_layers=5, base_channel_number=16, channel_increment=4,
        ),
        sm.ConsecutiveConv1DLayer(
            kernel_size=5, stride=2, padding=0,
            num_layers=1, base_channel_number=64,
        ),
        sm.ConsecutiveConv1DLayer(
            kernel_size=5, stride=2, padding=0,
            num_layers=1, base_channel_number=96,
        ),
        sm.ConsecutiveConv1DLayer(
            kernel_size=5, stride=2, padding=0,
            num_layers=1, base_channel_number=128,
        ),
    ],
)

decoder = sm.ConvolutionalDecoder1D(
    input_dim=3,
    output_dim=[1, 343],
    cnn_input_dim=[128, 36],
    cnn_layers=[
        sm.ConsecutiveConvTranspose1DLayer(
            kernel_size=6, stride=2, padding=0, out_channels_list=[96],
        ),
        sm.ConsecutiveConvTranspose1DLayer(
            kernel_size=5, stride=2, padding=0, out_channels_list=[64],
        ),
        sm.ConsecutiveConvTranspose1DLayer(
            kernel_size=5, stride=2, padding=0, out_channels_list=[32],
        ),
        # Final stack ends in a single channel without activation.
        sm.ConsecutiveConvTranspose1DLayer(
            kernel_size=7, stride=1, padding=0,
            out_channels_list=[28, 24, 20, 16, 1],
            activation=None,
        ),
    ],
)

model = sm.VariationalAutoencoder(
    encoder=encoder,
    decoder=decoder,
    z_dim=3,
    beta=1.0e-4,
    encoder_out_dim=128,
)

**TODO:**
- Demonstrate that arbitrary PyTorch modules can be used as encoder and decoder.
- Explain the model architecture.

In [None]:
import spherinator.data as sd

# Data loading settings: read the "flux" column of the parquet files produced in
# the data preparation step and apply "minmax" normalization.
datamodule_settings = dict(
    data_directory="data/gaia/parquet",
    data_column="flux",
    normalize="minmax",
    batch_size=2048,
    num_workers=4,
    shuffle=True,
)
datamodule = sd.ParquetDataModule(**datamodule_settings)

In [None]:
import lightning.pytorch as pl
import torch

# accelerator="gpu" raises on machines without a GPU, and 16-bit mixed precision
# is meant for CUDA. Keep the original GPU setup when CUDA is available; otherwise
# let Lightning pick the available accelerator and train in full 32-bit precision
# so the demo still runs.
has_cuda = torch.cuda.is_available()

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu" if has_cuda else "auto",
    precision="16-mixed" if has_cuda else "32-true",
)
trainer.fit(model, datamodule=datamodule)

## Export the trained model to ONNX

In [None]:
from pathlib import Path

import torch

MODEL_DIR = Path("data/gaia/models")


def export_to_onnx(
    module: torch.nn.Module,
    example_input: torch.Tensor,
    path: Path,
) -> None:
    """Export ``module`` to ONNX with a dynamic batch axis and save it to ``path``.

    Args:
        module: Module to export; it is switched to eval mode first.
        example_input: Example tensor used for tracing. Its first dimension is
            exported as the dynamic "batch" axis of the graph input "input".
        path: Destination file; missing parent directories are created.
    """
    # Export inference behaviour rather than whatever mode the module was left in
    # after training.
    module.eval()
    program = torch.onnx.export(
        module,
        (example_input,),
        # Name the graph input so the "input" key in `dynamic_axes` refers to it.
        input_names=["input"],
        dynamic_axes={"input": {0: "batch"}},
        dynamo=True,
    )
    program.optimize()
    path.parent.mkdir(parents=True, exist_ok=True)
    program.save(path)


# Example inputs use a batch of 2: torch.export specializes dimensions of size
# 0 or 1, which could pin the intended dynamic batch axis to a fixed size of 1.
export_to_onnx(
    model.variational_encoder,
    torch.randn(2, 1, 343, device="cpu"),
    MODEL_DIR / "encoder.onnx",
)
export_to_onnx(
    model.decoder,
    torch.randn(2, 3, device="cpu"),
    MODEL_DIR / "decoder.onnx",
)

## Visualize the ONNX model with Netron

In [None]:
!pip install -q netron
import netron
netron.start('data/gaia/models/decoder.onnx', 8081)

## PyTorch Lightning Command Line Interface (CLI)

Alternatively, start the training from the command line. All settings are read from a single [YAML config file](./configs/spherinator/gaia_vae_8.yaml), which makes runs reproducible.

```bash
spherinator fit --config configs/spherinator/gaia_vae_8.yaml
```