##### Fitting the Stellar Parameters from Photometry

In [1]:
from astropy.table import Table
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
plt.rcParams['text.usetex'] = True
plt.rcParams['font.size'] = 14
plt.rcParams['legend.fontsize'] = 14
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.rcParams['xtick.major.size'] = 5.0
plt.rcParams['xtick.minor.size'] = 3.0
plt.rcParams['ytick.major.size'] = 5.0
plt.rcParams['ytick.minor.size'] = 3.0
plt.rcParams['xtick.top'] = True
plt.rcParams['ytick.right'] = True
from time import time
from sklearn.preprocessing import StandardScaler
import pandas as pd
from gaiaxpy import generate, PhotometricSystem

metscaler = StandardScaler()
logscaler = StandardScaler()
tefscaler = StandardScaler()
amscaler = StandardScaler()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [2]:
# defining the Dataset class
class data_set(Dataset):
    '''
    Main way to access the .h5 file.
    '''
    def __init__(self,file,train=True,valid=False,test=False):
        fn = h5py.File(file, 'r')
        self.f = fn
        
        # get data
        if train:
            dset = self.f['group_1']['data']
            d = dset[:]
            dat = np.array([
                metscaler.fit_transform(d[[0]].T).flatten(),
                logscaler.fit_transform(d[[1]].T).flatten(),
                tefscaler.fit_transform(d[[2]].T).flatten(),
                amscaler.fit_transform(d[[3]].T).flatten(), # comment out if not
            ])
            self.l = dat.shape[1]
            self.x = torch.Tensor(dat.T)
        elif valid:
            dset = self.f['group_2']['data']
            d = dset[:]
            dat = np.array([
                metscaler.transform(d[[0]].T).flatten(),
                logscaler.transform(d[[1]].T).flatten(),
                tefscaler.transform(d[[2]].T).flatten(),
                amscaler.transform(d[[3]].T).flatten(), # comment out if not
            ])
            self.l = dat.shape[1]
            self.x = torch.Tensor(dat.T)
        elif test:
            dset = self.f['group_3']['data']
            d = dset[:]
            dat = np.array([
                metscaler.transform(d[[0]].T).flatten(),
                logscaler.transform(d[[1]].T).flatten(),
                tefscaler.transform(d[[2]].T).flatten(),
                amscaler.transform(d[[3]].T).flatten(), # comment out if not
            ])
            self.l = dat.shape[1]
            self.x = torch.Tensor(dat.T)
        
        # get label
        if train:
            ydset = self.f['group_1']['label']
            ydat = ydset[:]
            self.y = torch.Tensor(ydat[:].T) # torch.from_numpy(y[index]) does not work since y is doubles and not floats.
        elif valid:
            ydset = self.f['group_2']['label']
            ydat = ydset[:]
            self.y = torch.Tensor(ydat[:].T)
        elif test:
            ydset = self.f['group_3']['label']
            ydat = ydset[:]
            self.y = torch.Tensor(ydat[:].T)
            
        bpnews = np.array(ydat[:55]).T
        rpnews = np.array(ydat[55:]).T
        df = pd.DataFrame(
            {'source_id':range(len(ydat.T)),
             'bp_coefficients':list(bpnews),
             'bp_standard_deviation':[np.std(bp) for bp in bpnews],
             'bp_coefficient_covariances':[np.zeros((55,55)) for _ in bpnews],
             'rp_coefficients':list(rpnews),
             'rp_coefficient_covariances':[np.zeros((55,55)) for _ in bpnews],
             'rp_standard_deviation':[np.std(rp) for rp in rpnews]
            }
        )
        synthetic_photometry = generate(df, photometric_system=PhotometricSystem.Pristine)
        self.g = torch.from_numpy(synthetic_photometry['Pristine_mag_CaHK'].to_numpy(dtype='float32'))
        
    def __len__(self):
        return self.l
  
    def __getitem__(self, index):
        xg = self.x[index]
        yg = self.y[index]
        gg = self.g[index]
        # errg = self.err[index]
        return (xg,yg,gg)

In [3]:
class ResBlock(nn.Module):
    '''
    check this guy out, might not be using sigmoid when I should be.
    '''
    def __init__(self, nodes):
        super(ResBlock, self).__init__()
        self.res_block1 = nn.Sequential(
            nn.Linear(nodes,nodes),
            nn.BatchNorm1d(nodes),
            nn.LeakyReLU(),
            # nn.Sigmoid(),
        )
        self.res_block2 = nn.Sequential(
            nn.Linear(nodes,nodes),
            nn.BatchNorm1d(nodes),
        )
        self.lrelu = nn.LeakyReLU()
        self.siggy = nn.Sigmoid()
#         add dropout in the init
        self.do = nn.Dropout() #
    
    def forward(self, x):
        res = x
        x = self.res_block1(x)
        x = self.res_block2(x)
        x = x + res
        output = self.lrelu(x)
        # output = self.siggy(x)
#         add dropout after the relu
        # output = self.do(x) #
        return output
        
class ResNetM(nn.Module):
    def __init__(self):
        super(ResNetM, self).__init__()
        self.input_block = nn.Sequential(
            nn.Linear(4,16),
            nn.LeakyReLU(),
        )
        self.blocklist = nn.ModuleList([
            ResBlock(16),
            ResBlock(16),
            nn.Linear(16,32),
            ResBlock(32),
            ResBlock(32),
            nn.Linear(32,64),
            ResBlock(64),
            ResBlock(64),
            nn.Linear(64,128),
            ResBlock(128),
            ResBlock(128),
        ])
        self.output_block = nn.Sequential(
            nn.Linear(128,110),
        )
        
    def forward(self,x):
        x = self.input_block(x)
        for i, _ in enumerate(self.blocklist):
            x = self.blocklist[i](x)
        logits = self.output_block(x)
        return logits

In [4]:
training_data = data_set('/arc/home/aydanmckay/filtered_apogee_bprp_gmag.h5')
data = data_set('/arc/home/aydanmckay/filtered_apogee_bprp_gmag.h5',train=False,test=True)

                              

In [5]:
model = ResNetM()
model = model.to(device)

In [6]:
model.load_state_dict(torch.load('/arc/home/aydanmckay/torchmodel/rerunfiltered_apogee_resnet_no_do_lossL1_pristine_scale_bl32_lr0.001_SGD_ep10.pth'))

<All keys matched successfully>

In [7]:
unnormalize = lambda x, n: x * (10 ** (8.5 - n / 2.5))

In [8]:
def mag_gen(stellar_features):
    dataloader = DataLoader(
        stellar_features,
        batch_size=32,
        shuffle=False,
        num_workers=0
    )
    model.eval()
    
    preds = []
    covbs = []
    covrs = []
    gs = []
    data = []
    
    with torch.no_grad():
        for X, y, g in dataloader:
            X = X.to(device)
            g = g.to(device)
            pred = model(X)
            for prediction,gmag,dat in zip(pred,g,X):
                covbp = np.zeros((55,55))
                covrp = np.zeros((55,55))
                gs.append(gmag)
                preds.append(prediction)
                data.append(dat)
                covbs.append(covbp)
                covrs.append(covrp)
    preds = np.array([pred.to('cpu').numpy() for pred in preds]).T
    gs = np.array([g.to('cpu').numpy() for g in gs]).T
    data = np.array([dat.to('cpu').numpy() for dat in data]).T
    xpcoefs = np.array([unnormalize(row,mag) for row,mag in zip(preds.T,gs)]).T

    bpnews = np.array(xpcoefs[:55]).T
    rpnews = np.array(xpcoefs[55:]).T
    df = pd.DataFrame(
        {'source_id':range(len(preds.T)),
         'bp_coefficients':list(bpnews),
         'bp_standard_deviation':[np.std(bp) for bp in bpnews],
         'bp_coefficient_covariances':covbs,
         'rp_coefficients':list(rpnews),
         'rp_coefficient_covariances':covrs,
         'rp_standard_deviation':[np.std(rp) for rp in rpnews]
        }
    )
    # print(df)
    synthetic_gaia = generate(df, photometric_system=PhotometricSystem.Gaia_DR3_Vega)
    synthetic_pristine = generate(df, photometric_system=PhotometricSystem.Pristine)
    return synthetic_pristine, synthetic_gaia, gs

In [9]:
def mag_gen2(stellar_features):
    dataloader = DataLoader(
        stellar_features,
        batch_size=32,
        shuffle=False,
        num_workers=0
    )
    model.eval()
    
    preds = []
    covbs = []
    covrs = []
    gs = []
    data = []
    
    with torch.no_grad():
        for X, y, g in dataloader:
            X = X.to(device)
            g = g.to(device)
            pred = model(X)
            for prediction,gmag,dat in zip(pred,g,X):
                covbp = np.zeros((55,55))
                covrp = np.zeros((55,55))
                gs.append(gmag)
                preds.append(prediction)
                data.append(dat)
                covbs.append(covbp)
                covrs.append(covrp)
    preds = np.array([pred.to('cpu').numpy() for pred in preds]).T
    gs = np.array([g.to('cpu').numpy() for g in gs]).T
    data = np.array([dat.to('cpu').numpy() for dat in data]).T
    # xpcoefs = np.array([unnormalize(row,mag) for row,mag in zip(preds.T,gs)]).T

    bpnews = np.array(preds[:55]).T
    rpnews = np.array(preds[55:]).T
    df = pd.DataFrame(
        {'source_id':range(len(preds.T)),
         'bp_coefficients':list(bpnews),
         'bp_standard_deviation':[np.std(bp) for bp in bpnews],
         'bp_coefficient_covariances':covbs,
         'rp_coefficients':list(rpnews),
         'rp_coefficient_covariances':covrs,
         'rp_standard_deviation':[np.std(rp) for rp in rpnews]
        }
    )
    # print(df)
    synthetic_gaia = generate(df, photometric_system=PhotometricSystem.Gaia_DR3_Vega)
    synthetic_pristine = generate(df, photometric_system=PhotometricSystem.Pristine)
    return synthetic_pristine, synthetic_gaia, gs

In [10]:
def main(X,mag):
    '''
    The forwards implementation of the model contained in one function
    for the least-squares inference fitting that will give the 
    '''
    
    data = torch.Tensor([X]).to(device)
    with torch.no_grad():
        bprp = model(data).to('cpu').numpy()
        
    coefs = unnormalize(bprp,mag).T
    
    bpnews = np.array(coefs[:55]).T
    rpnews = np.array(coefs[55:]).T
    
    df = pd.DataFrame(
        {'source_id':range(len(coefs.T)),
         'bp_coefficients':list(bpnews),
         'bp_standard_deviation':[np.std(bp) for bp in bpnews],
         'bp_coefficient_covariances':list(np.zeros((len(bpnews),55,55))),
         'rp_coefficients':list(rpnews),
         'rp_coefficient_covariances':list(np.zeros((len(bpnews),55,55))),
         'rp_standard_deviation':[np.std(rp) for rp in rpnews]
        }
    )
    
    synthetic_photometry = generate(df, photometric_system=PhotometricSystem.Pristine)
    return synthetic_photometry

In [11]:
syn_mags, syn_gmags, gmags = mag_gen(data)
syn_mags2, syn_gmags2, gmags2 = mag_gen2(data)

                              

In [28]:
syn_gmags.describe()

Unnamed: 0,source_id,GaiaDr3Vega_mag_G,GaiaDr3Vega_mag_BP,GaiaDr3Vega_mag_RP,GaiaDr3Vega_flux_G,GaiaDr3Vega_flux_BP,GaiaDr3Vega_flux_RP,GaiaDr3Vega_flux_error_G,GaiaDr3Vega_flux_error_BP,GaiaDr3Vega_flux_error_RP
count,34241.0,34080.0,34080.0,34080.0,34080.0,34080.0,34080.0,34080.0,34080.0,34080.0
mean,17120.0,12.613278,13.168281,11.932431,6.43736e-16,6.903684e-16,5.841866e-16,0.0,0.0,0.0
std,9884.66962,1.662426,1.751191,1.612035,1.375216e-15,1.625354e-15,1.172444e-15,0.0,0.0,0.0
min,0.0,6.867849,7.105951,6.072227,1.1115279999999999e-20,9.889024e-21,1.048633e-20,0.0,0.0,0.0
25%,8560.0,11.557982,12.042453,10.923028,9.405359e-17,8.47123e-17,9.547432000000001e-17,0.0,0.0,0.0
50%,17120.0,12.52828,13.043068,11.880468,2.470178e-16,2.492142e-16,2.297485e-16,0.0,0.0,0.0
75%,25680.0,13.576662,14.214634,12.833884,6.037363e-16,6.263522e-16,5.549179e-16,0.0,0.0,0.0
max,34240.0,23.395299,24.046616,22.732041,4.538373e-14,5.907715e-14,4.836701e-14,0.0,0.0,0.0


In [27]:
syn_gmags2.describe()

Unnamed: 0,source_id,GaiaDr3Vega_mag_G,GaiaDr3Vega_mag_BP,GaiaDr3Vega_mag_RP,GaiaDr3Vega_flux_G,GaiaDr3Vega_flux_BP,GaiaDr3Vega_flux_RP,GaiaDr3Vega_flux_error_G,GaiaDr3Vega_flux_error_BP,GaiaDr3Vega_flux_error_RP
count,34241.0,34241.0,34241.0,34241.0,34241.0,34241.0,34241.0,34241.0,34241.0,34241.0
mean,17120.0,18.726848,19.284035,18.044663,1.281975e-18,1.018775e-18,1.464198e-18,0.0,0.0,0.0
std,9884.66962,0.960806,0.735913,1.128017,1.567529e-18,8.300112999999999e-19,2.220516e-18,0.0,0.0,0.0
min,0.0,15.57986,17.441945,14.304587,1.68531e-19,2.254071e-19,1.192573e-19,0.0,0.0,0.0
25%,8560.0,18.184236,18.833955,17.418134,3.9561529999999997e-19,4.523714999999999e-19,3.361077e-19,0.0,0.0,0.0
50%,17120.0,18.77206,19.285188,18.101683,7.856262999999999e-19,7.938246999999999e-19,7.460476999999999e-19,0.0,0.0,0.0
75%,25680.0,19.516917,19.895762,18.967404,1.350036e-18,1.202868e-18,1.400188e-18,0.0,0.0,0.0
max,34240.0,20.443401,20.652081,20.092387,1.4862640000000002e-17,4.33533e-18,2.4637990000000003e-17,0.0,0.0,0.0


In [14]:
gmags

array([13.152964, 14.849155, 14.015436, ..., 16.61399 , 20.178276,
       14.233086], dtype=float32)

In [15]:
gmags2

array([13.152964, 14.849155, 14.015436, ..., 16.61399 , 20.178276,
       14.233086], dtype=float32)

In [29]:
syn_mags.describe()

Unnamed: 0,source_id,Pristine_mag_CaHK,Pristine_flux_CaHK,Pristine_flux_error_CaHK
count,34241.0,34080.0,34080.0,34080.0
mean,17120.0,15.225139,3.437075e-16,0.0
std,9884.66962,2.219202,1.206698e-15,0.0
min,0.0,7.884302,6.715963e-22,0.0
25%,8560.0,13.690183,1.573244e-17,0.0
50%,17120.0,15.05108,6.947811e-17,0.0
75%,25680.0,16.66371,2.433363e-16,0.0
max,34240.0,27.587929,5.111628e-14,0.0


In [17]:
syn_mags2

Unnamed: 0,source_id,Pristine_mag_CaHK,Pristine_flux_CaHK,Pristine_flux_error_CaHK
0,0,21.329360,2.140602e-19,0.0
1,1,21.218833,2.369992e-19,0.0
2,2,21.219831,2.367814e-19,0.0
3,3,21.423072,1.963592e-19,0.0
4,4,21.177599,2.461731e-19,0.0
...,...,...,...,...
34236,34236,21.279549,2.241095e-19,0.0
34237,34237,21.231513,2.342474e-19,0.0
34238,34238,21.349806,2.100668e-19,0.0
34239,34239,21.231863,2.341718e-19,0.0


In [18]:
for i in data:
    print(i[0])
    break

tensor([-0.6468, -0.5675,  0.0102,  0.7541])


In [19]:
main([-0.6468, -0.5675,  0.0102,  0.7541],13.1530)

                              

Unnamed: 0,source_id,Pristine_mag_CaHK,Pristine_flux_CaHK,Pristine_flux_error_CaHK
0,0,13.232359,3.70968e-16,0.0


In [20]:
# from scipy.optimize import curve_fit, root

In [21]:
# popt, pcov = curve_fit(main, data, syn_mags)

In [22]:
# ans = root(main, data)

In [23]:
from scipy.optimize import least_squares

# Define the target output value for which you want to find the inverse input vector
y = 15

# Define a function that takes an input vector x and returns the difference between the output of the unknown function and the target output value y
def residual_function(x,arg):
    return (main(x,arg).values - y).flatten()

# Define an initial guess for the input vector
x_guess = np.zeros(4)

# Use the least_squares function to find the input vector that gives the target output value
result = least_squares(residual_function, x_guess, args=(y,),verbose=True)

# The input vector that gives the target output value is stored in the result.x attribute
x_inverse = result.x

`gtol` termination condition is satisfied.
Function evaluations 1, initial cost 3.3750e+02, final cost 3.3750e+02, first-order optimality 0.00e+00.


  data = torch.Tensor([X]).to(device)


In [24]:
x_inverse

array([0., 0., 0., 0.])

In [25]:
result

 active_mask: array([0., 0., 0., 0.])
        cost: 337.5000620230586
         fun: array([-1.50000000e+01,  1.11375993e-02, -1.50000000e+01, -1.50000000e+01])
        grad: array([0., 0., 0., 0.])
         jac: array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])
     message: '`gtol` termination condition is satisfied.'
        nfev: 1
        njev: 1
  optimality: 0.0
      status: 1
     success: True
           x: array([0., 0., 0., 0.])

In [26]:
model([0,0,0,0])

TypeError: linear(): argument 'input' (position 1) must be Tensor, not list