# Spherinator & HiPSter

This notebook demonstrates how to train a Spherinator model on the Illustris TNG dataset and how to use HiPSter to generate HiPS tiles and a catalog from the trained model.



- [Spherinator & HiPSter](#scrollTo=VDLEcO17VV5n)
- [Part 1: Spherinator - The Training](#scrollTo=cr4MS5QWElnI)
  - [Download data](#scrollTo=rAGt3xTXVV5o)
  - [Visualize training data](#scrollTo=q0B8v-QzVV5p)
  - [Define the model](#scrollTo=0dn14y9eVV5p)
  - [Define the Parquet Data Module](#scrollTo=xlmTZ74vVV5q)
  - [Set up the PyTorch Lightning Trainer and start the fitting process](#scrollTo=is4LBLalH22a)
  - [Export the trained model to ONNX](#scrollTo=g2j8u6nyVV5q)
- [Part 2: HiPSter - The Inference](#scrollTo=emV4xrtPC12n)
  - [Catalog as VOTable](#scrollTo=U3R61t0CGvKK)
  - [Visualize HiPS tiles and catalog using Aladin-Lite](#scrollTo=pTeGzP8cIcF8)



# Part 1: Spherinator - The Training

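First, we install the Spherinator package. The cell below installs it only when running on Google Colab; elsewhere, it is assumed to be installed already.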

In [None]:
if 'google.colab' in str(get_ipython()):
    %pip -q install git+https://github.com/HITS-AIN/Spherinator
import spherinator
print(spherinator.__version__)

## Download data

For a small test dataset, we use 200 selected synthetic
[SKIRT](https://www.tng-project.org/data/docs/specifications/#sec5l) images from the Illustris TNG100-1
simulation.

In the Parquet file, the metadata fields `simulation`, `snapshot`, and `subhalo_id` are stored in the
`metadata` column. The `data` column contains the image pixels as a flat list; the original shape
`(3, 128, 128)` is stored in the schema metadata.
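
The layout described above (a flat pixel list per row, with the shape kept separately as metadata) can be illustrated with plain NumPy. This is a minimal sketch, not Spherinator code: the metadata key `data_shape` and its JSON encoding are hypothetical, chosen because Arrow schema metadata is a mapping of byte strings. Inspect `dataset.schema.metadata` to see the actual key.

```python
import json

import numpy as np

# Hypothetical metadata key; the real key can be found in dataset.schema.metadata.
SHAPE_KEY = b"data_shape"


def flatten_image(image: np.ndarray) -> tuple[list[float], dict[bytes, bytes]]:
    """Store a (C, H, W) image as a flat list plus byte-string shape metadata."""
    metadata = {SHAPE_KEY: json.dumps(list(image.shape)).encode()}
    return image.ravel().tolist(), metadata


def restore_image(flat: list[float], metadata: dict[bytes, bytes]) -> np.ndarray:
    """Rebuild the (C, H, W) array from the flat list and the shape metadata."""
    shape = tuple(json.loads(metadata[SHAPE_KEY]))
    return np.asarray(flat, dtype=np.float32).reshape(shape)


# Round trip with a random image that has the same shape as the Illustris images.
rng = np.random.default_rng(0)
image = rng.random((3, 128, 128), dtype=np.float32)
flat, metadata = flatten_image(image)
restored = restore_image(flat, metadata)
print(len(flat), restored.shape)
```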

In [None]:
%pip -q install --upgrade gdown
import gdown
gdown.download('https://drive.google.com/uc?id=1XxPUdoKpZCNKnh3X725V1fjQN8pXJnS2', 'illustris.parquet')



In [None]:
import pyarrow.dataset as ds

dataset = ds.dataset("illustris.parquet", format="parquet")
dataset.schema

In [None]:
df = dataset.to_table().to_pandas()
df

## Visualize training data

To get an impression of the data, we visualize the first 25 images of the training data.

In [None]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

fig, axes = plt.subplots(5, 5, figsize=(15, 15))
for i, ax in enumerate(axes.flatten()):
    data = np.array(df["data"][i]).reshape(3, 128, 128).transpose(1, 2, 0) * 255
    image = Image.fromarray(data.astype(np.uint8), "RGB")
    ax.imshow(image)
    ax.axis("off")
plt.show()
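
The conversion inside the loop turns a channel-first float image into the channel-last 8-bit array that an RGB image needs. The helper below makes those steps explicit. It is a sketch that assumes pixel values lie in [0, 1]; the name `to_rgb_uint8` is hypothetical, and the clipping is an addition so that out-of-range values cannot overflow when cast to `uint8`.

```python
import numpy as np


def to_rgb_uint8(flat_pixels, shape=(3, 128, 128)) -> np.ndarray:
    """Convert a flat channel-first image with values in [0, 1] to an (H, W, 3) uint8 array."""
    chw = np.asarray(flat_pixels, dtype=np.float32).reshape(shape)
    hwc = chw.transpose(1, 2, 0)          # (C, H, W) -> (H, W, C)
    hwc = np.clip(hwc, 0.0, 1.0) * 255.0  # clamp, then scale to the 8-bit range
    return np.round(hwc).astype(np.uint8)


rng = np.random.default_rng(0)
rgb = to_rgb_uint8(rng.random(3 * 128 * 128))
print(rgb.shape, rgb.dtype)  # (128, 128, 3) uint8
```

In the plotting loop, `ax.imshow(to_rgb_uint8(df["data"][i]))` could replace the two lines that build the PIL image, since Matplotlib's `imshow` accepts an `(H, W, 3)` `uint8` array directly.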

## Define the model

As the model, we use a `VariationalAutoencoder` with convolutional networks as encoder and decoder. The latent space has three dimensions (`z_dim=3`), which matches the decoder's `input_dim=3`.

In [None]:
import spherinator.models as sm

model = sm.VariationalAutoencoder(
    encoder=sm.ConvolutionalEncoder2D(
        input_dim=[3, 128, 128],
        output_dim=128,
        cnn_layers=[
            sm.ConsecutiveConv2DLayer(
                kernel_size=3,
                stride=1,
                padding=0,
                out_channels=[16, 20, 24],
            ),
            sm.ConsecutiveConv2DLayer(
                kernel_size=4,
                stride=2,
                padding=0,
                out_channels=[64, 128],
            ),
        ],
    ),
    decoder=sm.ConvolutionalDecoder2D(
        input_dim=3,
        output_dim=[3, 128, 128],
        cnn_input_dim=[128, 28, 28],
        cnn_layers=[
            sm.ConsecutiveConvTranspose2DLayer(
                kernel_size=5,
                stride=2,
                padding=0,
                out_channels=[64],
            ),
            sm.ConsecutiveConvTranspose2DLayer(
                kernel_size=6,
                stride=2,
                padding=0,
                out_channels=[24],
            ),
            sm.ConsecutiveConvTranspose2DLayer(
                kernel_size=3,
                stride=1,
                padding=0,
                out_channels=[20, 16, 3],
                activation=None,
            ),
        ],
    ),
    z_dim=3,
    beta=1.0e-4,
    encoder_out_dim=128,
)
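
The spatial sizes in this configuration have to line up: the decoder starts from `cnn_input_dim=[128, 28, 28]` and must end at 128 × 128. The standard PyTorch output-size formulas for `Conv2d` and `ConvTranspose2d` (dilation 1, no output padding) let us check this by hand. The sketch assumes that each `ConsecutiveConv2DLayer` or `ConsecutiveConvTranspose2DLayer` applies one convolution per entry in `out_channels`, all with the same `kernel_size`, `stride`, and `padding`; this is an assumption about Spherinator, not taken from its documentation.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of nn.Conv2d along one spatial axis (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of nn.ConvTranspose2d along one axis (dilation=1, output_padding=0)."""
    return (size - 1) * stride - 2 * padding + kernel


def trace(size, blocks, fn):
    """Apply fn once per convolution in each block and record every intermediate size."""
    sizes = [size]
    for kernel, stride, padding, n_convs in blocks:
        for _ in range(n_convs):
            sizes.append(fn(sizes[-1], kernel, stride, padding))
    return sizes


# (kernel_size, stride, padding, len(out_channels)) for each block in the model above.
# Assumption: one convolution per out_channels entry.
encoder_blocks = [(3, 1, 0, 3), (4, 2, 0, 2)]
decoder_blocks = [(5, 2, 0, 1), (6, 2, 0, 1), (3, 1, 0, 3)]

encoder_sizes = trace(128, encoder_blocks, conv_out)
decoder_sizes = trace(28, decoder_blocks, conv_transpose_out)
print("encoder:", encoder_sizes)
print("decoder:", decoder_sizes)
```

Under these assumptions, the encoder's last feature map is 29 × 29 pixels with 128 channels before it is reduced to `output_dim=128`, and the decoder grows from 28 × 28 back to 128 × 128, matching `output_dim=[3, 128, 128]`.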

## Define the Parquet Data Module

In [None]:
from spherinator.data import ParquetDataModule

datamodule = ParquetDataModule(
    data_directory="illustris.parquet",
    data_column="data",
    normalize="minmax",
    batch_size=256,
    num_workers=4,
    shuffle=True,
)
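
The `normalize="minmax"` option rescales the input values. A common form of min-max normalization maps each image linearly onto [0, 1], and the sketch below shows that per-image variant. Whether `ParquetDataModule` scales per image, per channel, or over the whole dataset is not stated in this notebook, so treat this as an illustration of the technique rather than Spherinator's exact implementation; `minmax_normalize` is a hypothetical name.

```python
import numpy as np


def minmax_normalize(image: np.ndarray) -> np.ndarray:
    """Rescale one image linearly so that its smallest value is 0 and its largest is 1."""
    lo, hi = float(image.min()), float(image.max())
    span = hi - lo
    if span == 0.0:
        # A constant image carries no contrast; map it to all zeros.
        return np.zeros_like(image, dtype=np.float32)
    return ((image - lo) / span).astype(np.float32)


rng = np.random.default_rng(1)
raw = rng.uniform(5.0, 50.0, size=(3, 128, 128))
scaled = minmax_normalize(raw)
print(scaled.min(), scaled.max())
```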

## Set up the PyTorch Lightning Trainer and start the fitting process

In [None]:
from lightning.pytorch import Trainer

trainer = Trainer(
    max_epochs=10,
    accelerator="auto",
    precision="16-mixed",
)
trainer.fit(model, datamodule=datamodule)
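
With the small test set, each epoch is very short. Assuming all 200 images are used for training and the last partial batch is kept (the default `drop_last=False` of a PyTorch `DataLoader`), the number of optimizer steps follows from simple arithmetic:

```python
import math

n_images = 200    # size of the test dataset
batch_size = 256  # as passed to ParquetDataModule
max_epochs = 10   # as passed to Trainer

# One epoch needs ceil(n_images / batch_size) batches.
steps_per_epoch = math.ceil(n_images / batch_size)
total_steps = steps_per_epoch * max_epochs
print(steps_per_epoch, total_steps)
```

Under these assumptions, training performs only ten gradient updates: enough to check that the pipeline runs end to end, but far from a converged model.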

## Export the trained model to ONNX

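- The trained `VariationalAutoencoder` contains more than is needed for inference, so we export only the encoder and the decoder, each as a separate ONNX model.
- Dynamic axes make the batch dimension (axis 0) variable, so the exported models accept any number of inputs per call.
- The input and output tensors get explicit names (for the encoder: `x`, `coord`, and `scale`), so that they can be addressed by name at inference time.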

In [None]:
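import torch

# Switch to inference mode before exporting
model.eval()

# Export the variational encoder: image batch -> latent coordinate and scale
onnx_program = torch.onnx.export(
    model.variational_encoder,
    torch.randn(2, 3, 128, 128, device="cpu"),
    dynamic_axes={"x": {0: "batch"}},
    input_names=["x"],
    output_names=["coord", "scale"],
    dynamo=True,
)
onnx_program.optimize()
onnx_program.save("encoder.onnx")

# Export the decoder: latent vectors (z_dim=3) -> reconstructed images
onnx_program = torch.onnx.export(
    model.decoder,
    torch.randn(2, 3, device="cpu"),
    dynamic_axes={"input": {0: "batch"}},
    dynamo=True,
)
onnx_program.optimize()
onnx_program.save("decoder.onnx")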

# Part 2: HiPSter - The Inference

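In the second part, we use HiPSter to perform inference with the exported models on the Illustris TNG data. `hipster.HiPSGenerator` uses the decoder to generate images across the latent space and writes them as HiPS (Hierarchical Progressive Survey) tiles up to order `max_order`.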

In [None]:
if 'google.colab' in str(get_ipython()):
    %pip -q install git+https://github.com/HITS-AIN/HiPSter

import hipster
print(hipster.__version__)

In [None]:
hipster.HiPSGenerator(
    decoder=hipster.Inference("decoder.onnx"),
    image_maker=hipster.ImagePlotter(),
    max_order=4,
    hips_path="output/illustris",
).execute()
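
HiPS tiles follow the HEALPix tessellation of the sphere: at order k there are 12 × 4^k tiles. Assuming `HiPSGenerator` writes every order from 0 up to `max_order` (the minimum order is not stated in this notebook), a quick count shows how much output `max_order=4` produces and roughly how large each tile is on the sphere:

```python
import math

# Area of the full sphere in square degrees (about 41253).
FULL_SKY_DEG2 = 4 * math.pi * (180 / math.pi) ** 2


def tiles_at_order(order: int) -> int:
    """Number of HEALPix cells (HiPS tiles) at a given order: 12 * 4**order."""
    return 12 * 4**order


max_order = 4  # as passed to HiPSGenerator above
for order in range(max_order + 1):
    n = tiles_at_order(order)
    side = math.sqrt(FULL_SKY_DEG2 / n)  # rough tile size in degrees
    print(f"order {order}: {n:5d} tiles, ~{side:.1f} deg per tile")

total = sum(tiles_at_order(k) for k in range(max_order + 1))
print("total tiles:", total)
```

Each increase of `max_order` by one multiplies the number of tiles at the deepest order by four, so generation time is likely to grow quickly for deeper orders.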

## Catalog as VOTable

A VOTable (a HiPS catalog) can be used to show where each input image is located in the latent
space. `hipster.VOTableGenerator` takes all images from `data_directory` and uses the
`hipster.Inference` class with the exported encoder to map them into the latent space.

In [None]:
hipster.VOTableGenerator(
    encoder=hipster.Inference("encoder.onnx"),
    data_directory="illustris.parquet",
    output_file="illustris.vot",
    root_path="output",
).execute()
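
To place an encoded image in a catalog, its latent coordinate has to become a sky position. Because the latent space here has `z_dim=3`, a natural mapping treats the encoder's `coord` output as a direction in 3D and converts it to right ascension and declination. This is a sketch of that standard Cartesian-to-spherical conversion under that assumption, not necessarily the exact convention used inside `hipster.VOTableGenerator`; `latent_to_radec` is a hypothetical name.

```python
import math


def latent_to_radec(x: float, y: float, z: float) -> tuple[float, float]:
    """Convert a 3D latent vector to (RA, Dec) in degrees.

    The vector is normalized first, so only its direction matters.
    RA lies in [0, 360), Dec in [-90, 90].
    """
    norm = math.sqrt(x * x + y * y + z * z)
    if norm == 0.0:
        raise ValueError("latent vector must be non-zero")
    ra = math.degrees(math.atan2(y, x)) % 360.0
    # Clamp to guard against rounding slightly outside [-1, 1].
    dec = math.degrees(math.asin(max(-1.0, min(1.0, z / norm))))
    return ra, dec


print(latent_to_radec(1.0, 0.0, 0.0))
print(latent_to_radec(0.0, 2.0, 0.0))
print(latent_to_radec(0.0, 0.0, -1.0))
```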

## Visualize HiPS tiles and catalog using Aladin-Lite

The HiPS tiles and catalogs can be visualized with
[Aladin-Lite](https://github.com/cds-astro/aladin-lite). Here we use
[ipyaladin](https://github.com/cds-astro/ipyaladin), which embeds Aladin-Lite in Jupyter
notebooks.

In [None]:
%pip -q install ipyaladin

from ipyaladin import Aladin

aladin = Aladin(survey="output/illustris", fov=180, show_fullscreen_control=False)
aladin.add_catalog_from_URL("output/illustris.vot", {"source_size": 5, "color": "red"})
aladin

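The model above was trained for only 10 epochs on 200 images. For comparison, the following cells download a pretrained model (`illustris_full_trained`) and regenerate the HiPS tiles with its decoder, writing them to the same `output/illustris` path.

In [None]: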
import os
import gdown

if not os.path.exists("illustris_full_trained"):
    gdown.download_folder(
        "https://drive.google.com/drive/folders/1BYLoV83Jb9IpIi2YHYUb2Wokubo9Tn3M",
        output="illustris_full_trained",
    )

In [None]:
hipster.HiPSGenerator(
    decoder=hipster.Inference("illustris_full_trained/decoder.onnx"),
    image_maker=hipster.ImagePlotter(),
    max_order=4,
    hips_path="output/illustris",
).execute()

In [None]:
from ipyaladin import Aladin

aladin = Aladin(survey="output/illustris", fov=180, show_fullscreen_control=False)
aladin