In [1]:
import tensorflow as tf
from tensorflow import keras

from keras.layers import Conv1D , Dropout , Flatten , MaxPooling1D, Dense, Input, BatchNormalization
from keras.layers.core import Lambda
from keras.models import Model , load_model
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import keras.backend as K

from astroNN.gaia import mag_to_fakemag
from astroNN.gaia import fakemag_to_logsol

import numpy as np
import matplotlib.pyplot as plt
import random
import h5py
from IPython.display import Image

In [12]:
def SpectrumParallax(dim_t , dim_n, dropout_iterations = 100): 
    """
    INPUT: 
    dim_t - number of time steps of spectrum 
    dim_n - number of features of spectrum
    """
    
    #SPECTRUM TO LUINOSITY
    dim_1 = 1 # number of corrected magnitude for one example 
    units = 1 #number of final output for one example
    
    inputs_spectra = Input(shape=(dim_t, dim_n), name="pseudo-lum-input") 
    inputs_mag = Input(shape=(dim_1,), name="K_mag")
    inputs_error_paralaje = Input(shape=(dim_1,), name="error_paralaje")
    inputs_offset = Input(shape=(3,), name="offset-input")
    print("inputs_mag: ",inputs_mag)
    print("inputs_error_paralaje: ",inputs_error_paralaje)
    
    
    #x_parallax_list = []
    
    #for i in range(droput_iterations):
    x_parallax = Conv1D(filters=2, kernel_size=3, activation='relu')(inputs_spectra)
    x_parallax = BatchNormalization()(x_parallax)
    x_parallax = MaxPooling1D(pool_size=2)(x_parallax)
    x_parallax = Dropout(0.3)(x_parallax, training=True)

    x_parallax = Conv1D(filters=4, kernel_size=3, activation='relu')(x_parallax)
    x_parallax = BatchNormalization()(x_parallax)
    x_parallax = MaxPooling1D(pool_size=2)(x_parallax)
    x_parallax = Dropout(0.3)(x_parallax, training=True)

    x_parallax = Flatten()(x_parallax)
    x_parallax = Dense(128, activation='relu')(x_parallax)
    x_parallax = Dropout(0.3)(x_parallax, training=True)
        
    x_parallax = Dense(64, activation='relu')(x_parallax) 
    x_parallax = Dropout(0.3)(x_parallax, training=True)
    x_parallax = Dense(32, activation='relu')(x_parallax)
    x_parallax = Dropout(0.3)(x_parallax, training=True)
    x_parallax = Dense(units, activation='softplus', name="pseudo-lum")(x_parallax) 
    
    
    #Functions
    outputs_parallax = Lambda(lambda function: tf.math.multiply(function[0], tf.math.pow(10., 
                              tf.math.multiply(-0.2, function[1]))),
                              name='parallax')([x_parallax, inputs_mag])
    

    #Model setup
    model =  Model(inputs = [inputs_spectra,inputs_mag],outputs = [outputs_parallax])

    model.compile(loss='mse', optimizer='adam', metrics=['mse'])
    
    return model 

In [3]:
# Load of the data

with h5py.File('train_set_gaiaedr3_apogeedr16_1.h5','r') as F:  
    parallax = np.array(F['parallax'])
    parallax_error = np.array(F['parallax_err'])
    spectra = np.array(F['spectra'])
    Kmag = np.array(F['corrected_K'])
    
idx = []
for i in range(len(parallax)):
    idx.append(i)
random.seed(10)
random.shuffle(idx)

parallax = parallax[idx]
parallax_error = parallax_error[idx]
spectra = spectra[idx]
Kmag = Kmag[idx]

In [5]:
X = np.expand_dims(spectra, axis = 2)
Y = np.expand_dims(parallax, axis = 1)
K_mag = np.expand_dims(Kmag, axis = 1)
Y_error = np.expand_dims(parallax_error, axis = 1)

In [8]:
X.shape , Y.shape

((13976, 7514, 1), (13976, 1))

In [10]:
n_timesteps, n_features = X.shape[1], X.shape[2]

Global_model = SpectrumParallax(n_timesteps , n_features)

Global_model.summary()

inputs_mag:  KerasTensor(type_spec=TensorSpec(shape=(None, 1), dtype=tf.float32, name='K_mag'), name='K_mag', description="created by layer 'K_mag'")
inputs_error_paralaje:  KerasTensor(type_spec=TensorSpec(shape=(None, 1), dtype=tf.float32, name='error_paralaje'), name='error_paralaje', description="created by layer 'error_paralaje'")
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
pseudo-lum-input (InputLayer)   [(None, 7514, 1)]    0                                            
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, 7512, 2)      8           pseudo-lum-input[0][0]           
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 7512, 2)    

In [11]:
earlystopper = EarlyStopping(monitor='val_loss', patience=50, verbose=1, min_delta=1e-7)
checkpoint = ModelCheckpoint('Model1_13976spectra.h5', monitor='val_loss', 
                             verbose=1, save_best_only=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, verbose=1, patience=5, min_lr=0.000000001)

callbacks=[reduce_lr, checkpoint, earlystopper]

Global_model.fit([X, K_mag], Y,epochs=20, 
                 batch_size=64, verbose=1, shuffle="batch" ,callbacks=callbacks,validation_split=0.2)


Epoch 1/20

Epoch 00001: val_loss improved from inf to 2.37199, saving model to Model1_13976spectra.h5
Epoch 2/20

Epoch 00002: val_loss improved from 2.37199 to 2.36997, saving model to Model1_13976spectra.h5
Epoch 3/20

Epoch 00003: val_loss improved from 2.36997 to 2.06805, saving model to Model1_13976spectra.h5
Epoch 4/20

Epoch 00004: val_loss did not improve from 2.06805
Epoch 5/20

Epoch 00005: val_loss improved from 2.06805 to 2.00025, saving model to Model1_13976spectra.h5
Epoch 6/20

Epoch 00006: val_loss did not improve from 2.00025
Epoch 7/20

Epoch 00007: val_loss did not improve from 2.00025
Epoch 8/20

Epoch 00008: val_loss did not improve from 2.00025
Epoch 9/20

Epoch 00009: val_loss did not improve from 2.00025
Epoch 10/20

Epoch 00010: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.

Epoch 00010: val_loss did not improve from 2.00025
Epoch 11/20

Epoch 00011: val_loss did not improve from 2.00025
Epoch 12/20

Epoch 00012: val_loss did not improve f

KeyboardInterrupt: 