## Multimodal display STELIB library (SVO)
Based on data from the STELIB service developed by the Spanish Virtual Observatory in the framework of the IAU Commission G5 Working Group: Spectral Stellar Libraries
http://svocats.cab.inta-csic.es/stelib/index.php
Data set: http://svocats.cab.inta-csic.es/stelib/index.php?action=search

Adrián García Riber and Francisco Serradilla.
Polytechnic University of Madrid

In [None]:
from astropy.io import fits
import matplotlib.pylab as plt
import numpy as np

import os
from pathlib import Path

import tensorflow as tf
from tensorflow.keras import layers

In [None]:
# Root folder of the local STELIB FITS library. This is a machine-specific
# absolute path; the cells below walk it with os.walk and build file paths from it.
root = '/Users/adrian/Documents/FITS_Library/stelib'

In [None]:
# Sanity check of the library: open one spectrum, print its header and plot it.
# Path and name of one file to check the library
file = root+"/stelib_spec_fits_HD268623_moy.fits"

# The context manager closes the file even if reading the data fails.
# `sp` stays bound afterwards; its already-parsed primary header is still
# read by later cells (sp[0].header['NAXIS1']).
with fits.open(file) as sp:
    # Print the header
    print('\n\nHeader of the spectrum :\n\n', sp[0].header, '\n\n')
    # Copy the fluxes out of the file before it is closed
    flux2 = np.array(sp[0].data)

n_pixels = sp[0].header['NAXIS1']

# Normalizing the fluxes to their maximum (NaNs ignored) and flattening to 1-D
flux_norm = np.reshape(flux2 / np.nanmax(flux2), (n_pixels,))

# Linear wavelength axis: CRVAL1 + i * CDELT1 for every pixel index i
wave2 = sp[0].header['CRVAL1'] + np.arange(n_pixels, dtype=float) * sp[0].header['CDELT1']

# Plot the spectrum
fig = plt.figure(1, figsize=(12, 8))
plt.plot(wave2, flux_norm)
plt.xlabel('Wavelength [Å]')
plt.ylabel('ADU')
plt.title(file)
plt.show()

In [None]:
# Count the files found under `root` and record the spectrum length.
# NOTE: the counter starts at 1, so `num` ends up one larger than the number
# of files found; later cells iterate over len(...) - 1 entries.
num = 1
for path, subdirs, files in os.walk(root):
    num += len(files)
# Number of flux samples per spectrum, taken from the header of the sample
# file opened in the previous cell.
dim1 = sp[0].header['NAXIS1']
print(num, dim1, sep='\n')

In [None]:
# Creating the custom_set with all the spectra and generating labels to enable
# recovering header information.
#   custom_set[k]  -> flux of spectrum k, normalised to its maximum (NaNs ignored)
#   label_set[k]   -> k (row index)
#   spectra_set[k] -> file name of spectrum k
curves = 0
custom_set = np.zeros((num, dim1))
label_set = np.zeros((num, ), dtype=int)
spectra_set = [''] * num

for path, subdirs, files in os.walk(root):
    for name in files:
        fits_file = os.path.join(path, name)

        # Each file is opened once; the context manager closes it.
        with fits.open(fits_file, mode='readonly') as hdulist:
            hdulist.info()
            # np.array copies the data so it stays valid after the file closes
            data = np.array(hdulist[0].data)

        # Normalise and flatten to the common spectrum length
        data_norm = np.reshape(data / np.nanmax(data), (dim1,))

        label_set[curves] = curves
        spectra_set[curves] = name
        custom_set[curves] = data_norm
        curves += 1

        # `curves` already counts the spectrum that was just stored
        print("Spectra loaded:", curves, "spectra")


In [None]:
# Display the shape of the spectra matrix (allocated above as (num, dim1)).
custom_set.shape

In [None]:
# Display the shape of the label vector (allocated above as (num,)).
label_set.shape

## Importing the VAE model
Generated with the 6D-VAE_STELIB-augmented_Demo Notebook

In [None]:
# Load the saved VAE model; the path is relative to the current working directory.
vae6D = tf.keras.models.load_model('STELIB_6DVAE-augmented_OK.tf')

In [None]:
# Load the stored weights into the VAE model loaded above.
vae6D.load_weights('STELIB_6DVAE-augmented_Weights')

In [None]:
# Load the saved encoder model; it is applied to custom_set further below.
encoder = tf.keras.models.load_model('STELIB_6D_Encoder-augmented_OK.tf')

In [None]:
# Load the saved decoder model.
decoder = tf.keras.models.load_model('STELIB_6D_Decoder-augmented_OK.tf')

In [None]:
class Sampling(layers.Layer):
    """Draw a latent vector z from the pair (z_mean, z_log_var).

    The result is ``z_mean + exp(0.5 * z_log_var) * noise``, where ``noise``
    is normally distributed with mean 0 and standard deviation 0.1 and has
    one value per (row, column) of ``z_mean``.
    """

    def call(self, inputs):
        mean, log_var = inputs
        mean_shape = tf.shape(mean)
        # One noise value per sample and per latent dimension
        noise = tf.keras.backend.random_normal(
            shape=(mean_shape[0], mean_shape[1]), mean=0., stddev=0.1
        )
        sigma = tf.exp(0.5 * log_var)
        return mean + sigma * noise

In [None]:
# Compile the loaded VAE with an Adam optimizer (learning rate 1e-3) and a
# mean-squared-error loss.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
vae6D.compile(optimizer=optimizer, loss=tf.keras.losses.MeanSquaredError())

## Obtaining the latent vectors and decoded spectra

In [None]:
# Run the encoder on the whole spectra matrix in a single call.
encoded_spectra = encoder(custom_set)

In [None]:
# Pass the whole spectra matrix through the full VAE; the output is plotted
# later next to the original spectra as the decoded spectrum.
decoded_spectra = vae6D(custom_set)

In [None]:
# Inspect the shape of the encoder output.
encoded_spectra.shape

## Exploring the library

In [None]:
from sklearn.metrics import r2_score

from urllib.parse import urlencode
from astroquery.simbad import Simbad                                                            
from astropy.coordinates import SkyCoord
import astropy.units as u

import aplpy

from pythonosc import udp_client
import time

In [None]:
# Inspect the shape of the encoder output again before squeezing it.
encoded_spectra.shape

In [None]:
# Remove length-1 axes so that encoded_spectra[j] can be used directly as the
# latent values of spectrum j in the display loop.
encoded_spectra = np.squeeze(encoded_spectra)

In [None]:
# Buffer for the scaled latent values sent over OSC: one row per encoded
# spectrum, 6 columns (sent later as /lat0 ... /lat5).
x = np.zeros((len(encoded_spectra), 6))

In [None]:
# Collect the sky position of every spectrum from its FITS header.
# ra_stelib / dec_stelib end up holding one entry per spectrum: either a
# sexagesimal string ("..h..m..s" / "..d..m..s") or 0 when conversion failed.
ra_stelib = []
dec_stelib = []
# NOTE(review): the "- 1" presumably skips the extra row created because the
# file counter starts at 1 — confirm against the counting cell.
for j in range(len(encoded_spectra)-1):

    #   Opening FITS and getting coordinates
    target = files[j]
    file2 = root+"/"+target

    # The context manager closes the file even if a header keyword is missing
    with fits.open(file2) as sp2:
        print('\n\nHeader of the spectrum :\n\n', sp2[0].header, '\n\n')
        ra_spectra = sp2[0].header['RA']
        dec_spectra = sp2[0].header['DEC']
        print(sp2[0].header['OBJECT'])

    #   Converting coordinates from "hh:mm:ss" / "dd:mm:ss" to astropy's notation
    try:
        hours, minutes, seconds = ra_spectra.split(':')
        ra_ok = f"{hours}h{minutes}m{seconds}s"
        days, minutes, seconds = dec_spectra.split(':')
        dec_ok = f"{days}d{minutes}m{seconds}s"
        # Building the SkyCoord validates that both strings can be parsed
        c = SkyCoord(ra_ok, dec_ok, frame='icrs')
    except Exception as err:
        # Best effort: keep the entry (as 0, 0) so indices stay aligned with
        # the spectra, but report the problem instead of hiding it.
        # `except Exception` lets KeyboardInterrupt stop the loop.
        print(f"Could not convert coordinates of {target}: {err}")
        ra_ok = 0
        dec_ok = 0
    ra_stelib.append(ra_ok)
    dec_stelib.append(dec_ok)

  

In [None]:
# Build a single SkyCoord holding all collected positions (entries whose
# conversion failed were stored as 0) and display it.
coords = SkyCoord(ra_stelib,dec_stelib,frame='icrs',unit='deg')
coords

In [None]:
import warnings

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mocpy import MOC, World2ScreenMPL

from astropy.coordinates import Angle, SkyCoord
import astropy.units as u
from astroquery.vizier import Vizier
import plotly.express as px

## Graphical Stelib

In [None]:
from astropy.visualization.wcsaxes.frame import EllipticalFrame

# All-sky map of the STELIB library: every star as a blue dot inside an
# elliptical frame centred on (0, 0) in ICRS.
fig = plt.figure(figsize=(16,8))

sky_centre = SkyCoord(0, 0, unit='deg', frame='icrs')
star_style = dict(marker='o', color='#1f82c0', s=18, zorder=10)

with World2ScreenMPL(
    fig,
    fov=320 * u.deg,
    center=sky_centre,
    coordsys="icrs",
    rotation=Angle(0, u.degree),
) as wcs:
    ax = fig.add_subplot(111, projection=wcs, frame_class=EllipticalFrame)
    ax.set_title("STELIB library")
    ax.grid(color="black", linestyle="dotted")
    # Positions are given in world coordinates, hence the explicit transform
    ax.scatter(coords.ra, coords.dec, transform=ax.get_transform('world'), **star_style)

## "On the fly" Multimodal display

In [None]:
# "On the fly" multimodal display. For every star this cell
#   1. reads object name and coordinates from the FITS header,
#   2. saves a DSS cut-out (Star.png), the sky map (Sky.png) and the
#      original/decoded spectra (VAE_result.png),
#   3. sends the scaled latent values and the coordinates over OSC (UDP).
# "/s" is set to 1 before the values are sent and back to 0 after 3.5 s,
# or immediately when anything in the display step fails.

# One UDP client per OSC address, created once and reused for all stars
# instead of opening nine new sockets on every iteration.
osc_host = "127.0.0.1"
client_s = udp_client.SimpleUDPClient(osc_host, 9989)
latent_clients = [udp_client.SimpleUDPClient(osc_host, 9990 + k) for k in range(6)]
client_ra = udp_client.SimpleUDPClient(osc_host, 9996)
client_dec = udp_client.SimpleUDPClient(osc_host, 9997)

# NOTE(review): the "- 1" presumably skips the extra row created because the
# file counter starts at 1 — confirm against the counting cell.
for j in range(len(encoded_spectra)-1):
    # Opening FITS and getting coordinates; the file is closed on leaving the block
    target = files[j]
    file2 = root+"/"+target
    with fits.open(file2) as sp2:
        ra_spectra = sp2[0].header['RA']
        dec_spectra = sp2[0].header['DEC']
        objects = sp2[0].header['OBJECT']
    print(objects)

    # Converting coordinates
    try:
        hours, minutes, seconds = ra_spectra.split(':')
        ra_ok = f"{hours}h{minutes}m{seconds}s"
        days, minutes, seconds = dec_spectra.split(':')
        dec_ok = f"{days}d{minutes}m{seconds}s"
        c = SkyCoord(ra_ok, dec_ok, frame='icrs')
    except Exception as err:
        # Fall back to (0, 0), the same placeholder used when `coords` was
        # built, so the previous star's position is never reused.
        print(f"Could not convert coordinates of {target}: {err}")
        c = SkyCoord(0 * u.deg, 0 * u.deg, frame='icrs')

    # Querying and Plotting star
    try:
        query_params = {
                 'hips': 'DSS',
                 'object': objects,
                 'fov': (3 * u.arcmin).to(u.deg).value,
                 'width': 800,
                 'height': 350
                }
        url = f'http://alasky.u-strasbg.fr/hips-image-services/hips2fits?{urlencode(query_params)}'
        # The downloaded cut-out is closed once the image has been saved
        with fits.open(url) as hdul:
            gc = aplpy.FITSFigure(hdul)
            gc.show_grayscale()
            gc.save('Star.png', transparent=True)

        # Representing coordinates
        fig3 = plt.figure(figsize=(16,8))
        with World2ScreenMPL(
            fig3,
            fov=320 * u.deg,
            center=SkyCoord(0, 0, unit='deg', frame='icrs'),
            coordsys="icrs",
            rotation=Angle(0, u.degree),
        ) as wcs:
            sky_ax = fig3.add_subplot(111, projection=wcs, frame_class=EllipticalFrame)
            sky_ax.set_title("STELIB library", fontsize=18)
            sky_ax.grid(color="black", linestyle="dotted")
            # Representing all stars of STELIB library
            sky_ax.scatter(coords.ra, coords.dec,
                           marker='o', color='#1f82c0',
                           s=18, transform=sky_ax.get_transform('world'), zorder=10)
            # Representing current star
            sky_ax.scatter(c.ra.degree, c.dec.degree,
                           marker='o', color='red',
                           s=52, transform=sky_ax.get_transform('world'), zorder=10)
            # Saving sky map with all stars in blue and current star in red
            fig3.savefig('Sky.png', transparent=True)

        # Multiplying factor to reach audible range
        x[j] = encoded_spectra[j]*100000

        # Sending via OSC
        client_s.send_message("/s", 1)
        for k, latent_client in enumerate(latent_clients):
            latent_client.send_message(f"/lat{k}", x[j][k])
            print(f"latent {k} =", x[j][k])

        client_ra.send_message("/ra", c.ra.degree)
        print("ra =", c.ra.degree)

        client_dec.send_message("/dec", c.dec.degree)
        print("dec =", c.dec.degree)

        print("Rsquared:", r2_score(custom_set[j], decoded_spectra[j], multioutput='variance_weighted'))

        # Introducing the duration (3.5)
        time.sleep(3.5)

        client_s.send_message("/s", 0)

        # Plotting and saving the original spectra and decoded output
        fig, spec_axes = plt.subplots(1, 2, figsize=(16, 6))
        spec_axes[0].plot(wave2, custom_set[j])
        spec_axes[0].set_xlabel('Original Spectrum [Å]')
        spec_axes[0].set_ylabel('ADU')
        # The title goes on the right-hand panel (the current axes after plt.subplots)
        spec_axes[1].set_title(objects, fontsize=14)
        spec_axes[1].plot(wave2, decoded_spectra[j])
        spec_axes[1].set_xlabel('Decoded Spectrum [Å]')
        spec_axes[1].set_ylabel('ADU')
        fig.savefig('VAE_result.png', transparent=True)
    # Managing failure
    except Exception as err:
        # Report the failure, switch the "/s" flag off and continue with the
        # next star. KeyboardInterrupt is not caught, so the loop can be stopped.
        print(f"Display failed for {objects}: {err}")
        client_s.send_message("/s", 0)
    finally:
        # The images are already on disk; release this star's figures so they
        # do not accumulate over the whole library.
        plt.close('all')