# Demo for Spherinator Training using Illustris

## Download data

The API key can be obtained from the IllustrisTNG website: https://www.tng-project.org/data/access/

To avoid hardcoding the API key in the notebook, store it in a text file named `.illustris_api_key.txt` in the working directory; the cell below reads it from there.

In [None]:
import os

# ILLUSTRIS_API_KEY = "PUT HERE YOUR ILLUSTRIS API KEY"

with open(".illustris_api_key.txt", "r") as file:
    ILLUSTRIS_API_KEY = file.read().rstrip()

subhalo_ids = [454795,454898,454963,455058,455109,455198,455258,455335,455413,
               455479,455551,455637,455730,455857,455957,456014,456114,456168,
               456234,456283,456381,456456,456538,456584,456634,456725,456786,
               456872,456921,457000,457086,457169,457227,457296,457361,457452,
               457514,457604,457697,457781,457871,457931,457998,458028,458111,
               458174,458231,458302,458378,458447,458509,458604,458658,458722,
               458784,458864,458945,459020,459084,459169,459243,459270,459360,
               459394,459517,459576,459665,459730,459786,459850,459906,459959,
               460008,460076,460117,460193,460273,460351,460434,460526,460595,
               460692,460746,460823,460888,460939,461038,461136,461202,461283,
               461364,461450,461521,461609,461667,461709,461806,461864,461929,
               462010,462077,462141,462189,462241,462323,462391,462481,462564,
               462631,462690,462775,462832,462904,462986,463062,463139,463233,
               463278,463340,463395,463453,463521,463597,463649,463750,463804,
               463894,463958,464018,464110,464182,464247,464292,464331,464422,
               464490,464539,464576,464669,464742,464788,464894,464936,465016,
               465080,465136,465205,465284,465320,465361,465412,465495,465548,
               465614,465693,465764,465842,465921,466003,466055,466133,466182,
               466265,466332,466387,466436,466493,466599,466694,466746,466801,
               466894,466958,467011,467052,467127,467212,467256,467308,467385,
               467445,467519,467575,467628,467700,467740,467798,467871,467919,
               467958,468006,468064,468110,468168,468251,468318,468382,468450,
               468508,468590]

for sid in subhalo_ids:
    if os.path.exists(f"./data/illustris/fits/TNG100/sdss/snapnum_099/data/broadband_{sid}.fits"):
        continue
    !wget -nc -P ./data/illustris/fits/TNG100/sdss/snapnum_099/data --content-disposition --header="API-Key: {ILLUSTRIS_API_KEY}" "http://www.tng-project.org/api/TNG100-1/snapshots/99/subhalos/{sid}/skirt/broadband_sdss.fits"

## Data preparation

In [None]:
from pest import FitsConverter

# Side length in pixels of the converted images; later cells reuse this value.
image_size = 128

# Convert every downloaded FITS file into the Parquet dataset read below.
source_dir = "data/illustris/fits/TNG100/sdss/snapnum_099/data"
target_dir = "data/illustris/parquet"
converter = FitsConverter(image_size)
converter.convert_all(source_dir, target_dir)

## Visualize training data

To get an impression of the data, we can visualize the first 50 images of the training data.

In [None]:
import pyarrow.dataset as ds
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Ensure plots are displayed inline in the notebook
%matplotlib inline

dataset = ds.dataset("data/illustris/parquet", format="parquet")
print("Number of data points:", dataset.count_rows())

df = dataset.to_table().to_pandas()

fig, axes = plt.subplots(5, 10, figsize=(15, 8))
for i, ax in enumerate(axes.flatten()):
    data = np.array(df["data"][i]).reshape(3, image_size, image_size).transpose(1, 2, 0) * 255
    image = Image.fromarray(data.astype(np.uint8), "RGB")
    ax.imshow(image)
    ax.axis("off")
plt.show()

## Training the model

In [None]:
import spherinator.models as sm

# Encoder: one stride-2 3x3 convolution block (16 base channels, no
# normalization) over the (3, image_size, image_size) images, yielding a
# 128-dimensional output.
encoder = sm.ConvolutionalEncoder2D(
    input_dim=[3, image_size, image_size],
    output_dim=128,
    cnn_layers=[
        sm.ConsecutiveConv2DLayer(
            kernel_size=3,
            stride=2,
            padding=0,
            num_layers=1,
            base_channel_number=16,
            norm=None,
        ),
    ],
)

# Decoder: maps a 3-dimensional latent vector back to the image shape via one
# stride-2 3x3 transposed convolution.
# NOTE(review): cnn_input_dim=[128, 36] and out_channels_list=[1] look carried
# over from a 1D example; confirm they fit the 3-channel 2D output_dim.
decoder = sm.ConvolutionalDecoder2D(
    input_dim=3,
    output_dim=[3, image_size, image_size],
    cnn_input_dim=[128, 36],
    cnn_layers=[
        sm.ConsecutiveConvTranspose2DLayer(
            kernel_size=3,
            stride=2,
            padding=0,
            out_channels_list=[1],
            norm=None,
            activation=None,
        ),
    ],
)

# Variational autoencoder with a 3-dimensional latent space.
# beta presumably weights the KL term of the loss — confirm in spherinator.
model = sm.VariationalAutoencoder(
    encoder=encoder,
    decoder=decoder,
    z_dim=3,
    beta=1.0e-4,
    encoder_out_dim=128,
)
# Uncomment to run the example input through the model and display it:
# _ = model(model.example_input_array)
# model

- TODO: show that any PyTorch module can be used as encoder or decoder
- TODO: explain the model architecture

In [None]:
import spherinator.data as sd

# Parquet-backed data module: min-max normalized, shuffled batches of 256.
datamodule_options = dict(
    data_directory="data/illustris/parquet",
    data_column="data",
    normalize="minmax",
    batch_size=256,
    num_workers=4,
    shuffle=True,
)
datamodule = sd.ParquetDataModule(**datamodule_options)

# Sanity check: set up the training stage and inspect the first batch.
datamodule.setup("fit")
dataloader = datamodule.train_dataloader()
print(dataloader.batch_size)

batch = next(iter(dataloader))
print(batch.shape)

In [None]:
import lightning.pytorch as pl

# Short demo run with fp16 mixed precision. accelerator="gpu" requires a GPU;
# switch to "auto" or "cpu" on machines without one.
trainer_options = {
    "max_epochs": 10,
    "accelerator": "gpu",
    "precision": "16-mixed",
}
trainer = pl.Trainer(**trainer_options)
trainer.fit(model, datamodule=datamodule)

## Export the trained model to ONNX

In [None]:
import os

import torch


def _export_onnx(module, example_input, path):
    """Export *module* to ONNX and write the optimized graph to *path*.

    The first axis of the input is exported as a dynamic ``batch`` axis so
    the saved model accepts any batch size.

    Args:
        module: PyTorch module to export (traced with ``dynamo=True``).
        example_input: Example input tensor; every axis except the first must
            match what *module* expects.
        path: Destination ``.onnx`` file; missing parent directories are
            created.
    """
    onnx_program = torch.onnx.export(
        module,
        (example_input,),
        # dynamic_axes keys refer to these names; without input_names the
        # "input" key does not identify any model input.
        input_names=["input"],
        dynamic_axes={"input": {0: "batch"}},
        dynamo=True,
    )
    onnx_program.optimize()
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    onnx_program.save(path)


# Trace on the CPU (matching the example tensors) and in eval mode so that
# training-only behaviour is not captured in the exported graph.
model.cpu().eval()

# The encoder was built for inputs shaped (3, image_size, image_size).
_export_onnx(
    model.variational_encoder,
    torch.randn(1, 3, image_size, image_size, device="cpu"),
    "data/illustris/models/encoder.onnx",
)

# The decoder takes latent vectors of size z_dim = 3.
_export_onnx(
    model.decoder,
    torch.randn(1, 3, device="cpu"),
    "data/illustris/models/decoder.onnx",
)

## Visualize the ONNX model with Netron

In [None]:
!pip install -q netron
import netron
netron.start('data/illustris/models/decoder.onnx', 8081)

## PyTorch Lightning Command Line Interface (CLI)

Start the training from the command line with a single [YAML config file](./configs/spherinator/illustris.yaml), which makes the run reproducible.

```bash
spherinator fit --config configs/spherinator/illustris.yaml
```