In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import h5py
import os
import matplotlib.pyplot as plt
import time
from ipywidgets import interact, interactive, HBox, VBox, interactive_output, FloatSlider, Layout
from IPython.display import display

In [2]:
# Creating the emulator
def DNN(sizes, in_features=23, out_features=3829, negative_slope=0.07):
    """Build the fully connected emulator network as an ``nn.Sequential``.

    The network is a stack of ``nn.Linear`` layers, each followed by an
    ``nn.LeakyReLU``, and closed by a final ``nn.Linear`` without activation:

        Linear(in_features, sizes[0]) -> LeakyReLU
        -> Linear(sizes[0], sizes[1]) -> LeakyReLU -> ...
        -> Linear(sizes[-1], out_features)

    Layers are created in exactly this order, so the positional keys of the
    resulting ``state_dict`` ("0.weight", "2.weight", ...) depend only on
    ``len(sizes)``.

    Parameters
    ----------
    sizes : iterable of int
        Width of each hidden layer, in order. If empty, the model reduces to
        a single ``Linear(in_features, out_features)``.
    in_features : int, optional
        Size of the input vector (default 23).
    out_features : int, optional
        Size of the output vector (default 3829).
    negative_slope : float, optional
        Negative slope used by every LeakyReLU (default 0.07).

    Returns
    -------
    torch.nn.Sequential
        The assembled, randomly initialised model.
    """
    hidden_widths = list(sizes)

    layers = []
    fan_in = in_features
    for width in hidden_widths:
        layers.append(nn.Linear(fan_in, width))
        layers.append(nn.LeakyReLU(negative_slope))
        fan_in = width
    # Output layer: no activation after it.
    layers.append(nn.Linear(fan_in, out_features))

    return nn.Sequential(*layers)

In [3]:
# Statistics that plot() uses to standardise its 23 inputs, each stored as a
# single row of shape (1, 23), followed by the x-axis grid that plot() draws
# the spectrum against.
means, stds = (
    np.load(stats_file).reshape((1, 23))
    for stats_file in ('mean.npy', 'std.npy')
)
grid = np.load('grid.npy')

# Function that maps the parameters onto a spectrum and displays it
def plot(Al,Ba,C,Ca,Co,Cr,Eu,Mg,Mn,N,Na,Ni,O,Si,Sr,Ti,Zn,logg,teff,Fe_H,vsini,vt,vrad):
    """Evaluate the emulator for one set of parameters and plot the spectrum.

    The 23 arguments are packed, in signature order, into a single row,
    standardised as ``(params - means) / stds`` and fed to the module-level
    model ``NN``. The predicted spectrum is drawn against the module-level
    ``grid`` on the current matplotlib figure.

    Relies on the globals ``NN``, ``means``, ``stds`` and ``grid`` being
    defined before the function is called. Returns ``None``.
    """
    raw_params = np.reshape(
        np.array([Al,Ba,C,Ca,Co,Cr,Eu,Mg,Mn,N,Na,Ni,O,Si,Sr,Ti,Zn,
                  logg,teff,Fe_H,vsini,vt,vrad]),
        (1,23))
    # Standardise, then add a leading axis so the model input is (1, 1, 23).
    scaled_params = ((raw_params-means)/stds).reshape(1,1,23)

    # Run on whichever device the model's weights live on instead of assuming
    # a particular GPU.
    device = next(NN.parameters()).device
    inputs = torch.tensor(scaled_params,device=device,dtype=torch.float)

    # Inference only: no gradients are needed, so skip building the autograd
    # graph on every slider update.
    with torch.no_grad():
        spectrum = NN(inputs).cpu().numpy()[0,0,:]

    plt.ylim((0,1.2))
    plt.ylabel('Normalized Flux')
    plt.xlabel('Wavelength (Angstrom)')
    # The first 1980 grid points are drawn in blue and the remainder in red
    # (presumably two separate wavelength ranges -- TODO confirm).
    split = 1980
    plt.plot(grid[0:split], spectrum[0:split],color='blue')
    plt.plot(grid[split:], spectrum[split:],color='red')

In [4]:
if __name__=='__main__':
    # Initializing the emulator
    model_name = 'emulator_v6.pth'
    # Prefer the first GPU; fall back to the CPU so the checkpoint can still be
    # loaded on a machine without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    NN = DNN([200,268,397,569,739]).to(device)
    # NOTE: torch.load unpickles the file, so only load checkpoints from a
    # trusted source. map_location remaps tensors that were saved on another
    # device onto `device`.
    NN.load_state_dict(torch.load(model_name, map_location=device))
    NN.eval()

    plt.rcParams["figure.figsize"] = (15,7)

    # (keyword of plot(), min, max) for the abundance sliders. They all share
    # step 0.1, the FloatSlider default starting value and a '[X/Fe]' label.
    abundance_ranges = [
        ('Al', -1.5, 0.5),
        ('Ba', -2, 2),
        ('C', -1.3, 0.1),
        ('Ca', -0.5, 1.1),
        ('Co', -2, 1.1),
        ('Cr', -2, 1.1),
        ('Eu', -2, 1.1),
        ('Mg', -2, 1.1),
        ('Mn', -1.5, 0.5),
        ('N', -2, 1.1),
        ('Na', -2, 1.1),
        ('Ni', -2, 1.1),
        ('O', -2, 1.1),
        ('Si', -2, 1.1),
        ('Sr', -2, 1.1),
        ('Ti', -2, 2),
        ('Zn', -1.5, 0.5),
    ]
    # (keyword of plot(), label, min, max, step, starting value) for the
    # remaining sliders, which each need their own step and starting value.
    stellar_ranges = [
        ('logg', 'logg', -2.5, 5, 0.1, 3),
        ('teff', 'Teff', 4600, 8800, 10, 6000),
        ('Fe_H', '[Fe/H]', -4.8, 0.6, 0.1, -2),
        ('vsini', 'vsini', 0, 50, 1, 25),
        ('vt', 'vt', 0.5, 3, 0.1, 1.5),
        ('vrad', 'vrad', -150, 150, 1, 0),
    ]

    sliders = {
        name: FloatSlider(min=lo, max=hi, step=0.1, orientation='vertical',
                          description='[%s/Fe]' % name)
        for name, lo, hi in abundance_ranges
    }
    sliders.update({
        name: FloatSlider(min=lo, max=hi, step=step, value=start,
                          orientation='vertical', description=label)
        for name, label, lo, hi, step, start in stellar_ranges
    })

    # One slider per plot() argument; plot() is re-run whenever a slider moves.
    w = interactive(plot, **sliders)

    box_layout = Layout(display='flex', flex_flow='row', justify_content='center', align_items='stretch')
    # The layout must be given to the HBox itself; passed to display() it never
    # reaches the widget and is not applied.
    display(HBox(w.children[:-1], layout=box_layout))#Show all controls
    display(w.children[-1])#Show the output

HBox(children=(FloatSlider(value=0.0, description='[Al/Fe]', max=0.5, min=-1.5, orientation='vertical'), Float…

Output()