In [1]:
import sys

import numpy as np
import keras

sys.path.insert(0, "../code")
import abcsn_training
import abcsn_config



# Classifying Supernova Spectra With ABC-SN

# 1. Load ABC-SN

The file `ABCSN.keras` is not hosted on GitHub because it is too large. Download it from Zenodo [here](https://zenodo.org/records/16620817), place it in `abcsn/`, and make sure it is named `ABCSN.keras`. `ABCSN.keras` has been added to `.gitignore`.
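Because the model file is downloaded separately, a missing or misplaced `ABCSN.keras` is the most likely first failure. The sketch below wraps `keras.models.load_model` in a hypothetical helper, `load_abcsn`, which checks that the file exists and otherwise fails with a message pointing to the Zenodo record. The default path mirrors the one used in the next cell; the helper itself is not part of the ABC-SN code.

```python
from pathlib import Path

# Zenodo record that hosts ABCSN.keras (from the text above).
ZENODO_URL = "https://zenodo.org/records/16620817"


def load_abcsn(path="../abcsn/ABCSN.keras"):
    """Hypothetical helper: load ABC-SN, or fail early if the file is missing."""
    model_path = Path(path)
    if not model_path.is_file():
        raise FileNotFoundError(
            f"{model_path} not found; download ABCSN.keras from {ZENODO_URL} "
            "and place it in abcsn/."
        )
    # Imported lazily so the existence check works even without Keras installed.
    import keras

    return keras.models.load_model(str(model_path))
```

With the file in place, `abcsn = load_abcsn()` behaves like the direct `keras.models.load_model` call below.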

In [2]:
abcsn = keras.models.load_model("../abcsn/ABCSN.keras")

# 2. Data

The function `load_data` automatically loads the training and test sets we used. However, if you want to use ABC-SN to classify new supernovae, you won't need this function. All you need is an array of spectra with the correct shape.

Your array of spectra to classify, which we'll call `X`, should be a NumPy array with shape `(num_spectra, 1, num_wvl)`, where `num_spectra` is the number of spectra and `num_wvl` is the number of wavelength bins in each spectrum.

1. Each spectrum should cover 2500 to 10000 angstroms.
2. Each spectrum should be normalized to mean zero and standard deviation one.
3. Each spectrum should have a spectral resolution of R = 100. If your spectra have a higher resolution, see `degrade_spectrum` in `data_degrading.py`.
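The requirements above can be sketched as a small preparation function. This is a minimal sketch under stated assumptions, not the pipeline used to build the training set: it assumes each spectrum is already at R = 100 (it does not degrade resolution), that `target_wvl` is the common wavelength grid (for example the `wvl` array returned by `load_data`), and it uses linear interpolation via `np.interp`, which fills points outside the observed range with the edge flux values. The function name `prepare_spectra` and the example grid are hypothetical.

```python
import numpy as np


def prepare_spectra(spectra, target_wvl):
    """Hypothetical helper: resample, standardize, and reshape spectra for ABC-SN.

    spectra: iterable of (wavelength, flux) pairs, wavelengths in angstroms,
             increasing, already at R = 100.
    target_wvl: 1-D wavelength grid spanning 2500-10000 angstroms.
    """
    rows = []
    for wave, flux in spectra:
        # Linear interpolation onto the common grid (an assumption; values
        # outside the observed range take the first/last flux value).
        resampled = np.interp(target_wvl, wave, flux)
        # Standardize to mean zero and standard deviation one.
        resampled = (resampled - resampled.mean()) / resampled.std()
        rows.append(resampled)
    # Stack and add the singleton axis: (num_spectra, 1, num_wvl).
    return np.asarray(rows, dtype=np.float32)[:, np.newaxis, :]


# Usage with a hypothetical grid and a synthetic spectrum.
target_wvl = np.linspace(2500, 10000, 1024)  # use `wvl` from load_data in practice
wave = np.linspace(3000, 9000, 500)
flux = 1.0 + 0.1 * np.sin(wave / 200.0)
X_new = prepare_spectra([(wave, flux)], target_wvl)
print(X_new.shape)  # (1, 1, 1024)
```

A spectrum with constant flux would have zero standard deviation and produce NaNs, so real data may need a guard for that case.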

ABC-SN was originally developed for a dataset of 17 classes; during development, we reduced this to 10. The dictionaries `sn_dict_trn` and `sn_dict_tst` (which are identical) map class IDs in the range 0-16 to the range 0-9. Inverting this dictionary gives the reverse mapping from 0-9 to 0-16 (see `sn_10_to_17`). ABC-SN predicts targets in the range 0-9, which we translate to the range 0-16 because we already have dictionaries that map a class ID in the range 0-16 to the SN subtype name.
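The inversion works because the 0-16 to 0-9 mapping is one-to-one. The standalone sketch below reconstructs `sn_dict_trn` from the notebook outputs shown further down (the retained class IDs and the printed `sn_10_to_17`), inverts it the same way the notebook does, and checks that the two mappings round-trip. The literal dictionary is a stand-in for the one returned by `load_data`.

```python
# 0-16 -> 0-9 for the ten retained classes, reconstructed from the notebook output.
sn_dict_trn = {0: 0, 1: 1, 2: 2, 4: 3, 6: 4, 7: 5, 8: 6, 10: 7, 11: 8, 13: 9}

# Invert to get 0-9 -> 0-16, exactly as the notebook builds sn_10_to_17.
sn_10_to_17 = {j: i for i, j in sn_dict_trn.items()}

# The mapping is one-to-one, so no keys collide and it round-trips.
assert len(sn_10_to_17) == len(sn_dict_trn)
assert all(sn_dict_trn[sn_10_to_17[k]] == k for k in range(10))
print(sn_10_to_17[3])  # 4
```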

In [3]:
Xtrn, Xtst, Ytrn, Ytst, num_wvl, num_classes, sn_dict_trn, sn_dict_tst, wvl = abcsn_training.load_data()

In [4]:
sn_10_to_17 = {j: i for i, j in sn_dict_trn.items()}
sn_10_to_17

{0: 0, 1: 1, 2: 2, 3: 4, 4: 6, 5: 7, 6: 8, 7: 10, 8: 11, 9: 13}

In [5]:
classes_str = [abcsn_config.SN_Stypes_str[i] for i in sn_dict_trn]
classes_ind = list(sn_dict_trn)
classes_ind, classes_str

([0, 1, 2, 4, 6, 7, 8, 10, 11, 13],
 ['Ia-norm',
  'Ia-91T',
  'Ia-91bg',
  'Iax',
  'Ib-norm',
  'Ibn',
  'IIb',
  'Ic-norm',
  'Ic-broad',
  'IIP'])

# 3. Predict

1. `X` is your array of spectra to classify.
2. `P` is the array of output probabilities for each of the 10 classes, with shape `(num_spectra, 10)`.
3. `P_argmax` is the array of final predictions for each spectrum in `X`, as class IDs in the range 0-9.
4. `P_IDs` is the list of final predictions for each spectrum in `X`, as class IDs in the range 0-16.
5. `P_str` is the list of final SN subtype predictions for each spectrum in `X`.
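The chain from probabilities to subtype names can be traced without loading the model. In the sketch below, `P` is a hypothetical probability array for two spectra (not real ABC-SN output), the 0-9 to 0-16 mapping is copied from the `sn_10_to_17` output above, and `id_to_name` is a stand-in built from the `classes_str` output for `abcsn_config.SN_Stypes_int_to_str`. It also reads off the probability of the predicted class, which the notebook does not compute but which follows directly from `P`.

```python
import numpy as np

# 0-9 -> 0-16, copied from the notebook output of sn_10_to_17.
sn_10_to_17 = {0: 0, 1: 1, 2: 2, 3: 4, 4: 6, 5: 7, 6: 8, 7: 10, 8: 11, 9: 13}
# 0-16 -> subtype name for the retained classes (stand-in for
# abcsn_config.SN_Stypes_int_to_str, built from the classes_str output).
id_to_name = {0: "Ia-norm", 1: "Ia-91T", 2: "Ia-91bg", 4: "Iax", 6: "Ib-norm",
              7: "Ibn", 8: "IIb", 10: "Ic-norm", 11: "Ic-broad", 13: "IIP"}

# Hypothetical class probabilities for two spectra (each row sums to 1).
P = np.array([
    [0.05, 0.02, 0.01, 0.70, 0.02, 0.05, 0.05, 0.04, 0.03, 0.03],
    [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.02, 0.01, 0.90],
])
P_argmax = np.argmax(P, axis=1)                  # class IDs in 0-9
P_conf = P[np.arange(len(P)), P_argmax]          # probability of the predicted class
P_IDs = [sn_10_to_17[int(k)] for k in P_argmax]  # class IDs in 0-16
P_str = [id_to_name[i] for i in P_IDs]           # subtype names

for name, conf in zip(P_str, P_conf):
    print(f"{name}: {conf:.2f}")
# Iax: 0.70
# IIP: 0.90
```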

In [6]:
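# Replace Xtst with your own array of shape (num_spectra, 1, num_wvl)
X = Xtst.copy()
P = abcsn.predict(X, verbose=0)  # class probabilities, shape (num_spectra, 10)
P_argmax = np.argmax(P, axis=1)  # predicted class IDs in the range 0-9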



In [7]:
# Translate 0-9 predictions to 0-16 class IDs, then to subtype names
P_IDs = [sn_10_to_17[prediction] for prediction in P_argmax]
P_str = [abcsn_config.SN_Stypes_int_to_str[prediction_id] for prediction_id in P_IDs]