In [86]:
import sys
sys.path.append("..")
import numpy as np
from keras.models import load_model
import h5py
import tensorflow as tf
import time
import keras.backend as K
import matplotlib.pyplot as plt
from load_data import load_data_from_h5
plt.rc('text', usetex=True)
plt.rc('font', family='serif')

datadir= ""
model_path = 'StarNet_2018-10-31_run1/' + 'weights.best.h5'
denormalization_path = '/home/spiffical/data/spiffical/gaia-ESO/mu_std_INTRIGOSS_gaiaeso_UVES-4835-5395_7labels.txt'
test_data_path = '/home/spiffical/data/spiffical/realspec/UVES/UVES_GE_MW_4835-5395_updated.h5'

targets = ['teff', 'logg', 'M_H', 'a_M', 'v_rot', 'v_rad', 'VT']
num_labels = len(targets)
label_names = ['$T_{\mathrm{eff}}$',r'$\log(g)$','$[Fe/H]$',r'[$\alpha/M$]',r'[$v_{rot}$]', 
               r'[$v_{rad}$]',r'$v_{micro}$']

**Load Model**

In [2]:
starnet_model = load_model(model_path)
starnet_model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input (InputLayer)           (None, 39436, 1)          0         
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 39436, 4)          36        
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 39436, 16)         528       
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 9859, 16)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 157744)            0         
_________________________________________________________________
dense_1 (Dense)              (None, 256)               40382720  
_________________________________________________________________
dense_2 (Dense)              (None, 128)               32896     
__________

**Create a denormalization function**

In [7]:
def denormalize(lb_norm):
    with open(denormalization_path,'r') as f1:
        mu = np.array(map(float, f1.readline().split()[0:num_labels]))
        sigma = np.array(map(float, f1.readline().split()[0:num_labels]))
    return ((lb_norm*sigma)+mu)

**Define function to compute the jacobian matrix**

In [8]:
def calc_jacobian(model, spectrum, denormalize=None):
    
    spectrum = spectrum.reshape((1,spectrum.shape[1],1))
        
    if denormalize==None:
        y_list = tf.unstack(model.output)
    else:
        y_list = tf.unstack(denormalize(model.output[0]))

    J = [tf.gradients(y, model.input) for y in y_list]


    jacobian_func = [K.function([model.input, K.learning_phase()], j_) for j_ in J]

    jacobian = np.array([jf([spectrum,False]) for jf in jacobian_func])[:,0,0,:,0]
    '''
    for i in range(len(spectra)):
        jacobian = np.array([jf([spectra,False]) for jf in jacobian_func])[:,:,0,:,0]
        np.save('temp/temp_jacobian_'+str(i)+'.npy',jacobian)
        if i%int(0.1*len(spectra))==0:
            print('Jacobians completed: '+str(i))
    
    for i in range(len(spectra)):
        if i==0:
            jacobian = np.load('temp/temp_jacobian_'+str(i)+'.npy')
        else:
            jacobian = np.concatenate((jacobian,np.load('temp/temp_jacobian_'+str(i)+'.npy')))
        subprocess.check_output(['rm','temp/temp_jacobian_'+str(i)+'.npy'])
    '''
    return jacobian

## Load data

### Synthetic spectra

split analysis into high temp (>5500) and low temp (<5500)

In [87]:
infolder_synth = '/vmstorage/projects/gaiaeso/spectra/intrigoss/UVES_4835-5395/' 
spec_name = 'spectra_starnetnorm'
data_file = infolder_synth + 'INTRIGOSS_gaiaeso_UVES-4835-5395_7labels_testset.h5'

indexes = np.arange(0,9000,1)
data = load_data_from_h5(data_file=data_file,
                         indices=indexes,
                         targetname=targets,
                         mu_std=denormalization_path,
                         specname=spec_name)

X_synth, y_synth, noise_synth, wave_grid = data.X, data.y, data.noise, data.wave_grid

# Mask telluric lines
X_synth = mask_tellurics('telluric_lines.txt', X_synth, wave_grid)

# Zero-point bad values
for x in X_synth:
    x[x>1.03]=0
    x[x<0]=0

X_synth = np.array(X_synth)
y_synth = np.array(y_synth)
indx_highT = y_synth[:,0] > 5500
indx_lowT = y_synth[:,0] <= 5500

### UVES data

In [53]:
num_test = 300

with h5py.File(test_data_path, "r") as f:
    test_spectra = f['spectra_starnet_norm'][:num_test]
    test_labels = np.column_stack([f['teff'][:num_test], f['logg'][:num_test], f['fe_h'][:num_test],
                                   f['v_rad'][:num_test], f['vmicro'][:num_test]])
    wave_grid = f['wave_grid'][:]
    #snr_uves = f['SNR'][:num_test]
    ges_type = f['ges_type'][:num_test]
    objects = f['object'][:num_test]

# Take care of bad values
for spec in test_spectra:
    spec[spec>1.03]=0
    spec[spec<0]=0

#all_ges_types = ['GE_CL', 'GE_MW', 'GE_MW_BL', 'GE_SD_BC', 'GE_SD_BM', 'GE_SD_BW', 'GE_SD_CR',
#                 'GE_SD_GC', 'GE_SD_MC', 'GE_SD_OC', 'GE_SD_PC', 'GE_SD_RV', 'GE_SD_TL']
indx = ges_type == 'GE_MW'
test_spectra = test_spectra[indx]
test_labels = test_labels[indx]

print('Test set contains '  + str(len(test_labels))+' stars')

Test set contains 107 stars


## Compute Jacobians

The jacobian matrix will be of shape (num_spectra, num_labels, num_fluxes)

In [88]:
jacobian = np.zeros((len(y_synth), num_labels, np.array(X_synth).shape[1]))
predictions = np.zeros(np.array(y_synth).shape)
print('Computing jacobian for '+str(len(y_synth))+' spectra')
time_start = time.time()
for i in range(len(y_synth)):
    spectrum = np.array(X_synth)[i:i+1]
    jacobian[i] = calc_jacobian(starnet_model,spectrum,denormalize=denormalize)
    if i%5==0:
        print('\n'+str(i+1)+'/'+str(len(test_labels))+' completed.\n'+str(time.time()-time_start)+' seconds elapsed.')
print('\nAll '+str(i+1)+' completed.\n'+str(time.time()-time_start)+' seconds elapsed.')

KeyboardInterrupt: 

In [None]:
jacobian = np.zeros((len(test_labels), num_labels, test_spectra.shape[1]))
predictions = np.zeros(test_labels.shape)
print('Computing jacobian for '+str(len(test_labels))+' spectra')
time_start = time.time()
for i in range(len(test_labels)):
    spectrum = test_spectra[i:i+1]
    jacobian[i] = calc_jacobian(starnet_model,spectrum,denormalize=denormalize)
    if i%5==0:
        print('\n'+str(i+1)+'/'+str(len(test_labels))+' completed.\n'+str(time.time()-time_start)+' seconds elapsed.')
print('\nAll '+str(i+1)+' completed.\n'+str(time.time()-time_start)+' seconds elapsed.')

Computing jacobian for 107 spectra

1/107 completed.
44.2700848579 seconds elapsed.

6/107 completed.
260.448454857 seconds elapsed.

11/107 completed.
481.561657906 seconds elapsed.

16/107 completed.
704.702936888 seconds elapsed.

21/107 completed.
936.690569878 seconds elapsed.

26/107 completed.
1214.09999895 seconds elapsed.

31/107 completed.
1724.3495779 seconds elapsed.

36/107 completed.
2258.40197802 seconds elapsed.

41/107 completed.
2775.25408101 seconds elapsed.

46/107 completed.
3051.85107994 seconds elapsed.

51/107 completed.
3346.15531898 seconds elapsed.

56/107 completed.
3747.35138798 seconds elapsed.


**Take the average of the jacobian for all of the test spectra and take absolute**

You can try switching the order of these two operations to see if you get more logical results. Taking the absolute of the jacobian makes it easier to analyze visually.

In [83]:
avg_jacobian = np.abs(np.mean(jacobian,axis=0))

### Plot the results

In [84]:
wave_grid = wave_grid.reshape((len(wave_grid),))

**Make a plot for each label**
If you're not familiar with the notebook backend for matplotlib, experiment with the buttons on the bottom left. You change the figure size by dragging the bottom right corner. For some reason the save function doesn't work with our server, but you can click the Power button at the top right of the figure and then right click and save it manually.

In [58]:
plt.close('all')
%matplotlib notebook

fig,ax=plt.subplots(num_labels, 1, sharex=True, sharey=False,
                    figsize=(9.5,1.3*num_labels))

for i in range(num_labels):
    ax[i].plot(wave_grid, avg_jacobian[i], linewidth=.5)
    ax[i].set_ylabel(r'$\partial$ '+label_names[i],fontsize=15)

plt.xlabel(r'Wavelength (\AA)',fontsize=15)

plt.show()

<IPython.core.display.Javascript object>

In [85]:
plt.close('all')
%matplotlib notebook

fig,ax=plt.subplots(num_labels, 1, sharex=True, sharey=False,
                    figsize=(9.5,1.3*num_labels))

for i in range(num_labels):
    ax[i].plot(wave_grid, avg_jacobian[i], linewidth=.5)
    ax[i].set_ylabel(r'$\partial$ '+label_names[i],fontsize=15)

plt.xlabel(r'Wavelength (\AA)',fontsize=15)

plt.show()

<IPython.core.display.Javascript object>

In [59]:
# synthetic spectra!

In [17]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, TextBox, CheckButtons
from matplotlib import rc
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
        
def plot_spec_and_lines(x_data, y_data, colors, label_name):
    elems_n = np.load('data/line_elements.npy')
    elems_w = np.load('data/line_wavelengths.npy')
    elems_gf = np.load('data/line_loggfs.npy')
    atom_names = np.load('data/element_names.npy')

    def find_nearest(array,value):
        idx = (np.abs(array-value)).argmin()
        return idx

    def annotate_elems(x_data, y_data, ax):
        for n in elem_checked:
            # Check to see if checked
            if elem_checked[n]==False:
                continue

            elem_indices = np.where(elems_n==n)[0]
            xs = elems_w[elem_indices]
            gfs = elems_gf[elem_indices]

            for x, g in zip(xs,gfs):

                # Check to see if line strength is above min
                if g<min_loggf['val']:
                    continue

                x_indx = find_nearest(x_data, x)
                y_pos = np.nanmedian(y_data[x_indx-5:x_indx+5])
                ax.annotate(n, xy=(x, 1.2*y_pos),
                            xytext=(x, 5.*y_pos),
                            fontsize=15, fontname='serif', 
                            arrowprops=dict(arrowstyle="-", linewidth=1,
                                            connectionstyle="arc3",facecolor='black'))

    def switch(n):
        if elem_checked[n]==False:
            elem_checked[n]=True
        else:
            elem_checked[n]=False

    fig,ax = plt.subplots(figsize=(21,3))

    # Plot data
    for i, y_dat in enumerate(y_data):
        l, = plt.plot(x_data, y_dat, color=colors[i], linewidth=0.8)

    # Setup Check boxes and axes
    plt.subplots_adjust(right=0.45)
    checks = []
    baxes = []
    for i in range(int(len(atom_names)/10+1)):
        #if i < len(atom_names)/10:
        baxes.append(plt.axes([0.47+0.06*i, 0.1, 0.1, 0.85]))
        for a in baxes[i].spines:
            baxes[i].spines[a].set_color('white')
        checks.append(CheckButtons(baxes[i], atom_names[i*10:(i+1)*10], 
                                   np.zeros_like(atom_names,dtype=bool)[i*10:(i+1)*10]))

    # Dictionary of which buttons are clicked
    elem_checked = {}
    for n in atom_names:
        elem_checked[n]=False


    # Setup Slider
    # Dictionary to get current val
    min_loggf = {'val':0.}
    plt.subplots_adjust(bottom=0.2)
    slax = plt.axes([0.55, 0.01, 0.3, 0.05])
    loggf_slider = Slider(slax, r'min log(gf)', -10, 7, valinit=min_loggf['val'], 
                       valfmt='%0.1f')

    # Setup text boxes
    plt.subplots_adjust(top=0.85)
    begin_indx = {'val':0}
    axbeg = plt.axes([0.2, 0.875, 0.03, 0.1])
    text_begin = TextBox(axbeg, r'Start (\AA):   ', initial='{:0.2f}'.format(x_data[begin_indx['val']]))
    end_indx = {'val':len(x_data)-1}
    axend = plt.axes([0.3, 0.875, 0.03, 0.1])
    text_end = TextBox(axend, r'End (\AA):   ', initial='{:0.2f}'.format(x_data[end_indx['val']]))



    # Axis labels
    ax.set_xlabel(r'Wavelength (\AA)',fontsize=15)
    ax.set_ylabel(r'$\partial$ '+label_name, fontsize=15)

    def update():
        ax.clear()
        for i, y_dat in enumerate(y_data):
            ax.plot(x_data[begin_indx['val']:end_indx['val']], 
                    y_dat[begin_indx['val']:end_indx['val']], color=colors[i], linewidth=0.8)
        ax.set_xlabel(r'Wavelength (\AA)',fontsize=15)
        ax.set_ylabel(r'$\partial$ '+label_name, fontsize=15)
        annotate_elems(x_data, y_data[0], ax)

    # Function to update based on check boxes
    def get_ticks(label_):
        included = []
        for n in atom_names:
            if label_==n:
                switch(n)
        update()

    # Function to update based on slider
    def slider_update(val):
        min_loggf['val'] = loggf_slider.val
        update()


    def begin_submit(text):
        begin_wav = eval(text)
        begin_indx['val'] = find_nearest(x_data, begin_wav)
        update()

    def end_submit(text):
        end_wav = eval(text)
        end_indx['val'] = find_nearest(x_data, end_wav)
        update()

    text_begin.on_submit(begin_submit)
    text_end.on_submit(end_submit)


    loggf_slider.on_changed(slider_update)
    
    for c in checks:
        c.on_clicked(get_ticks)

    plt.draw()

    return elem_checked, min_loggf, begin_indx, end_indx, checks, text_begin, text_end

In [18]:
%matplotlib notebook
plt.close('all')
(elem_checked, min_loggf, begin_indx, 
 end_indx, checks, text_begin, text_end) = plot_spec_and_lines(wave_grid, 
                                                                   [avg_jacobian[0]], 
                                                                   colors=['navy'],
                                                              label_name=label_names[0])

<IPython.core.display.Javascript object>