<img style="float: right; margin-right: 100px" src="./images/skirt.jpg" width="200" height="200" />

# Spherinator Training using Illustris TNG

This notebook demonstrates how to train a Spherinator model using the Illustris TNG dataset.

## Download data

For a small test data set, we use 200 selected synthetic
[SKIRT](https://www.tng-project.org/data/docs/specifications/#sec5l) images from the Illustris TNG100-1
simulation.

:::note
A secret API key is needed for the download.
Please register at [Illustris TNG](https://www.tng-project.org/data/) and put your API
key in the `.illustris_api_key.txt` file.
:::

In [None]:
import os

with open(".illustris_api_key.txt", "r") as file:
    ILLUSTRIS_API_KEY = file.read().rstrip()

with open("subhalo_ids.txt", "r") as file:
    subhalo_ids = [int(line.strip()) for line in file.readlines()]

for sid in subhalo_ids:
    if os.path.exists(f"./data/illustris/fits/TNG100/sdss/snapnum_099/data/broadband_{sid}.fits"):
        continue
    !wget -nc -P ./data/illustris/fits/TNG100/sdss/snapnum_099/data --content-disposition \
        --header="API-Key: {ILLUSTRIS_API_KEY}" \
        "http://www.tng-project.org/api/TNG100-1/snapshots/99/subhalos/{sid}/skirt/broadband_sdss.fits"
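Outside of a notebook, the `!wget` shell escape is not available. As a minimal sketch, the same skip-if-present logic can be written with the standard library only. The helper names (`broadband_url`, `target_path`, `missing_ids`, `download`) are illustrative and not part of Spherinator, and the sketch assumes that the downloaded file should be named `broadband_{sid}.fits`, as checked by the loop above.

```python
import os
import urllib.request

# URL prefix and target directory as used in the wget call above.
BASE_URL = "http://www.tng-project.org/api/TNG100-1/snapshots/99/subhalos"
DATA_DIR = "./data/illustris/fits/TNG100/sdss/snapnum_099/data"


def broadband_url(sid):
    """Build the SKIRT broadband URL for one subhalo ID."""
    return f"{BASE_URL}/{sid}/skirt/broadband_sdss.fits"


def target_path(sid, data_dir=DATA_DIR):
    """Local file name for one subhalo (assumed naming scheme)."""
    return os.path.join(data_dir, f"broadband_{sid}.fits")


def missing_ids(subhalo_ids, data_dir=DATA_DIR):
    """Return the subhalo IDs whose FITS file is not on disk yet."""
    return [sid for sid in subhalo_ids if not os.path.exists(target_path(sid, data_dir))]


def download(sid, api_key, data_dir=DATA_DIR):
    """Fetch one FITS file, sending the API key as a request header."""
    os.makedirs(data_dir, exist_ok=True)
    request = urllib.request.Request(broadband_url(sid), headers={"API-Key": api_key})
    with urllib.request.urlopen(request) as response:
        with open(target_path(sid, data_dir), "wb") as out:
            out.write(response.read())
```

A loop such as `for sid in missing_ids(subhalo_ids): download(sid, ILLUSTRIS_API_KEY)` would then fetch only the files that are still missing.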

## Data preparation

We use PEST to convert the FITS files to the Parquet format.

In [None]:
from pest import FitsConverter

FitsConverter(image_size=128).convert_all(
    "data/illustris/fits/TNG100/sdss/snapnum_099/data", "data/illustris/parquet"
)

In the Parquet schema, the metadata fields `simulation`, `snapshot`, and `subhalo_id` are stored in
the `metadata` column. The `data` column contains the image data as a flat list. The original shape
`(3, 128, 128)` is stored in the schema metadata.

In [None]:
import pyarrow.dataset as ds

dataset = ds.dataset("data/illustris/parquet", format="parquet")
dataset.schema
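Because each row of `data` is a flat list, the shape has to be applied again before the values can be used as an image. The following sketch shows the idea on a tiny stand-in array; the shape `(3, 4, 4)` and the generated values are placeholders for illustration, not real data.

```python
import numpy as np

# Stand-in for the stored shape (3, 128, 128): channels, height, width.
shape = (3, 4, 4)

# A row of the `data` column looks like this: one flat list of values.
flat = np.arange(np.prod(shape), dtype=np.float32).tolist()

# Restore the channel-first layout from the flat list.
image_chw = np.asarray(flat, dtype=np.float32).reshape(shape)

# Move the channel axis last, which is the layout plotting libraries expect.
image_hwc = image_chw.transpose(1, 2, 0)

print(image_chw.shape, image_hwc.shape)
```

The first pixel of channel 1 sits at flat index 16 (one full 4x4 channel after the start), which is why reshaping with the channel axis first recovers the original layout.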

## Visualize training data

To get an impression of the data, we can visualize the first 25 images of the training data.

In [None]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

df = dataset.to_table().to_pandas()
fig, axes = plt.subplots(5, 5, figsize=(15, 15))
for i, ax in enumerate(axes.flatten()):
    data = np.array(df["data"][i]).reshape(3, 128, 128).transpose(1, 2, 0) * 255
    image = Image.fromarray(data.astype(np.uint8), "RGB")
    ax.imshow(image)
    ax.axis("off")
plt.show()

## Training the model

As the model, we define a `VariationalAutoencoder` with convolutional networks as encoder and decoder.

In [None]:
import spherinator.models as sm

model = sm.VariationalAutoencoder(
    encoder=sm.ConvolutionalEncoder2D(
        input_dim=[3, 128, 128],
        output_dim=128,
        cnn_layers=[
            sm.ConsecutiveConv2DLayer(
                kernel_size=3,
                stride=1,
                padding=0,
                out_channels=[16, 20, 24],
            ),
            sm.ConsecutiveConv2DLayer(
                kernel_size=4,
                stride=2,
                padding=0,
                out_channels=[64, 128],
            ),
        ],
    ),
    decoder=sm.ConvolutionalDecoder2D(
        input_dim=3,
        output_dim=[3, 128, 128],
        cnn_input_dim=[128, 28, 28],
        cnn_layers=[
            sm.ConsecutiveConvTranspose2DLayer(
                kernel_size=5,
                stride=2,
                padding=0,
                out_channels=[64],
            ),
            sm.ConsecutiveConvTranspose2DLayer(
                kernel_size=6,
                stride=2,
                padding=0,
                out_channels=[24],
            ),
            sm.ConsecutiveConvTranspose2DLayer(
                kernel_size=3,
                stride=1,
                padding=0,
                out_channels=[20, 16, 3],
                activation=None,
            ),
        ],
    ),
    z_dim=3,
    beta=1.0e-4,
    encoder_out_dim=128,
)
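The spatial sizes in the layer definitions can be checked with the standard output-size formulas for convolutions and transposed convolutions. The sketch below assumes that each `Consecutive...Layer` applies one convolution per entry in `out_channels`, all with the same kernel size, stride, and padding. This is an assumption about Spherinator, and the helper names are illustrative.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0):
    """Output size of a transposed convolution along one spatial axis."""
    return (size - 1) * stride - 2 * padding + kernel


# Encoder: (kernel_size, stride, number of consecutive layers), padding 0.
size = 128
encoder_sizes = [size]
for kernel, stride, count in [(3, 1, 3), (4, 2, 2)]:
    for _ in range(count):
        size = conv_out(size, kernel, stride)
        encoder_sizes.append(size)

# Decoder: starts from the 28x28 maps given by cnn_input_dim.
size = 28
decoder_sizes = [size]
for kernel, stride, count in [(5, 2, 1), (6, 2, 1), (3, 1, 3)]:
    for _ in range(count):
        size = conv_transpose_out(size, kernel, stride)
        decoder_sizes.append(size)

print(encoder_sizes)  # [128, 126, 124, 122, 60, 29]
print(decoder_sizes)  # [28, 59, 122, 124, 126, 128]
```

Under that assumption, the decoder grows from 28 to exactly 128 pixels, matching `output_dim=[3, 128, 128]`, while the last feature map of the encoder is 29 pixels wide.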

## ParquetDataModule

In [None]:
from spherinator.data import ParquetDataModule

datamodule = ParquetDataModule(
    data_directory="data/illustris/parquet",
    data_column="data",
    normalize="minmax",
    batch_size=256,
    num_workers=4,
    shuffle=True,
)
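The `normalize="minmax"` option presumably refers to min-max scaling. As a rough sketch, the textbook form maps the values of one sample to the range [0, 1]. Whether `ParquetDataModule` normalizes per image or over the whole data set is not stated here, so the function below (`minmax_normalize`, a hypothetical helper) only illustrates the general technique.

```python
import numpy as np


def minmax_normalize(x, eps=1e-12):
    """Scale the values of x linearly so that min -> 0 and max -> 1."""
    x = np.asarray(x, dtype=np.float32)
    lo, hi = x.min(), x.max()
    # Guard against division by zero for constant inputs.
    return (x - lo) / max(hi - lo, eps)


print(minmax_normalize([2.0, 4.0, 6.0]))
```

Values in [0, 1] also match the visualization cell above, which multiplies the data by 255 before converting it to 8-bit RGB.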

In [None]:
from lightning.pytorch import Trainer

trainer = Trainer(
    max_epochs=10,
    accelerator="gpu",
    precision="16-mixed",
)
trainer.fit(model, datamodule=datamodule)

## Visualize the reconstructed images

In [None]:
# TODO: Maybe show tensorboard

## Export the trained model to ONNX

- The model includes the variational autoencoder part, which is not needed for inference.
- We export only the encoder and the decoder part of the model.
- Dynamic axes are used on the batch dimension to allow for variable batch sizes.
- Unique names are used for the input and output tensors.

In [None]:
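import os

import torch

# Make sure the output directory exists before saving the ONNX files.
os.makedirs("data/illustris/models", exist_ok=True)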

onnx = torch.onnx.export(
    model.variational_encoder,
    torch.randn(2, 3, 128, 128, device="cpu"),
    dynamic_axes={"x": {0: "batch"}},
    input_names=["x"],
    output_names=["coord", "scale"],
    dynamo=True,
)
onnx.optimize()
onnx.save("data/illustris/models/encoder.onnx")

onnx = torch.onnx.export(
    model.decoder,
    torch.randn(2, 3, device="cpu"),
    dynamic_axes={"input": {0: "batch"}},
    dynamo=True,
)
onnx.optimize()
onnx.save("data/illustris/models/decoder.onnx")

## Visualize the ONNX model with netron

[Netron](https://netron.app) is a viewer for neural network models.
We can use it to visualize the ONNX model we just exported.

In [None]:
!pip install -q netron
import netron

netron.start('data/illustris/models/encoder.onnx', 8082)

## PyTorch Lightning Command Line Interface (CLI)

Start the training in reproducible mode using a single [YAML config file](./configs/spherinator/illustris.yaml).

In [None]:
!spherinator fit -c configs/spherinator/illustris.yaml