In [65]:
from astropy.table import Table
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py
import matplotlib.pyplot as plt
from gaiaxpy import generate, PhotometricSystem
import pandas as pd

In [2]:
class Net(nn.Module):
    """Fully connected ReLU network mapping 3 input features to 110 outputs.

    Hidden widths grow 64 -> 128 -> 256 -> 512, each followed by a ReLU; the
    final Linear layer has no activation. All layers live in
    ``linear_relu_stack`` (Sequential indices 0-8). That attribute name and the
    layer order determine the ``state_dict`` keys (e.g.
    ``linear_relu_stack.0.weight``), so they must stay fixed for
    ``load_state_dict`` on saved checkpoints to keep working.
    """

    def __init__(self):
        super().__init__()
        widths = (3, 64, 128, 256, 512)
        layers = []
        # One Linear + ReLU pair per consecutive pair of widths.
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        # Output head: raw (unactivated) 110-dimensional prediction.
        layers.append(nn.Linear(widths[-1], 110))
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw network output for input batch ``x``."""
        return self.linear_relu_stack(x)
    
# defining the Dataset classes
class train_set(Dataset):
    """Training split ('group_1') of an HDF5 file of inputs, labels and label errors.

    The 'data', 'label' and 'e_label' datasets store samples along axis 1
    (``__len__`` reads ``shape[1]``), so sample ``index`` is column ``index``
    of each. Items are returned as float32 tensors ``(x, y, err)``.
    """

    def __init__(self, file):
        # NOTE(review): an open h5py handle is presumably not picklable, so this
        # dataset likely requires DataLoader(num_workers=0) - confirm.
        self.f = h5py.File(file, 'r')

    def __len__(self):
        return self.f['group_1']['data'].shape[1]

    def _column(self, name, index):
        """Read only sample ``index`` (one column) of dataset ``name``."""
        # Slicing the HDF5 dataset directly reads just this column, instead of
        # loading the whole array from disk on every item.
        return self.f['group_1'][name][:, index]

    def __getitem__(self, index):
        """Return ``(x, y, err)`` for sample ``index`` as float32 tensors."""
        x = torch.Tensor(self._column('data', index))
        # torch.Tensor(...) casts to the default float32; torch.from_numpy would
        # keep the stored float64 dtype.
        y = torch.Tensor(self._column('label', index))
        # get error in label # comment out for non-error label runs
        err = torch.Tensor(self._column('e_label', index))
        return (x, y, err)

    def close(self):
        """Close the underlying HDF5 file handle."""
        self.f.close()

class test_set(Dataset):
    """Held-out split ('group_2') of an HDF5 file of inputs, labels and label errors.

    The 'data', 'label' and 'e_label' datasets store samples along axis 1
    (``__len__`` reads ``shape[1]``), so sample ``index`` is column ``index``
    of each. Items are returned as float32 tensors ``(x, y, err)``.
    """

    def __init__(self, file):
        # NOTE(review): an open h5py handle is presumably not picklable, so this
        # dataset likely requires DataLoader(num_workers=0) - confirm.
        self.f = h5py.File(file, 'r')

    def __len__(self):
        return self.f['group_2']['data'].shape[1]

    def _column(self, name, index):
        """Read only sample ``index`` (one column) of dataset ``name``."""
        # Slicing the HDF5 dataset directly reads just this column, instead of
        # loading the whole array from disk on every item.
        return self.f['group_2'][name][:, index]

    def __getitem__(self, index):
        """Return ``(x, y, err)`` for sample ``index`` as float32 tensors."""
        x = torch.from_numpy(self._column('data', index))
        y = torch.from_numpy(self._column('label', index))
        # get error in label # comment out for non-error label runs
        err = torch.from_numpy(self._column('e_label', index))
        return (x.float(), y.float(), err.float())

    def close(self):
        """Close the underlying HDF5 file handle."""
        self.f.close()

In [3]:
# Evaluate on the held-out 'group_2' split (test_set); the training split
# ('group_1', via train_set) is not needed in this notebook.
h5_path = "/arc/home/aydanmckay/mydataelabelssmallscalecuts.h5"
data = test_set(h5_path)
# No shuffling: batches follow the file's sample order.
loaded_data = DataLoader(data, batch_size=32)

In [4]:
# Restore the trained weights and switch to inference mode.
model = Net()
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# kernel; inference below feeds CPU tensors from the DataLoader.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints
# (recent torch versions also accept weights_only=True).
model.load_state_dict(torch.load("/arc/home/aydanmckay/torchmodel/torchmodelWL1smallscalecutsep5.pth", map_location="cpu"))
model.eval()

Net(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=3, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=256, bias=True)
    (5): ReLU()
    (6): Linear(in_features=256, out_features=512, bias=True)
    (7): ReLU()
    (8): Linear(in_features=512, out_features=110, bias=True)
  )
)

In [103]:
# Run the model on the first batch only. For the first star in that batch,
# build diagonal BP/RP coefficient covariance matrices from its per-coefficient
# label errors: squared errors on the diagonal, off-diagonal terms left at zero.
N_COEFF = 55  # coefficients per band: entries [:55] are BP, [55:110] are RP

first_batch = next(iter(loaded_data), None)
if first_batch is None:
    raise RuntimeError("loaded_data yielded no batches")
X, y, z = first_batch

with torch.no_grad():
    pred = model(X)
preds = [pred]  # later cells read preds[0][0], i.e. the first star

err = z[0].double().numpy()  # label errors of the first star
if err.shape[0] < 2 * N_COEFF:
    raise ValueError(f"expected at least {2 * N_COEFF} label errors, got {err.shape[0]}")
# Assign the diagonal directly (no += accumulation) so each matrix is the
# covariance of this single star.
covbp = np.diag(err[:N_COEFF] ** 2)
covrp = np.diag(err[N_COEFF:2 * N_COEFF] ** 2)

In [104]:
np.shape(covbp)  # BP covariance matrix dimensions

(55, 55)

In [105]:
# List the photometric systems gaiaxpy can synthesise.
available_systems = PhotometricSystem.get_available_systems()
available_systems

'DECam, Els_Custom_W09_S2, Euclid_VIS, Gaia_2, Gaia_DR3_Vega, Halpha_Custom_AB, H_Custom, Hipparcos_Tycho, HST_ACSWFC, HST_HUGS_Std, HST_WFC3UVIS, HST_WFPC2, IPHAS, JKC, JKC_Std, JPAS, JPLUS, JWST_NIRCAM, PanSTARRS1, PanSTARRS1_Std, Pristine, SDSS, SDSS_Std, Sky_Mapper, Stromgren, Stromgren_Std, WFIRST'

In [106]:
# Photometric systems to synthesise when requesting several at once.
phot_system_list = [
    getattr(PhotometricSystem, system_name)
    for system_name in (
        'Gaia_2',
        'Gaia_DR3_Vega',
        'PanSTARRS1',
        'PanSTARRS1_Std',
        'Pristine',
        'SDSS',
        'SDSS_Std',
    )
]

In [107]:
# Number of RP coefficients (entries after the first 55) predicted for the first star.
preds[0][0][55:].numel()

55

In [108]:
# One-row gaiaxpy input frame for the first predicted star: its first 55
# outputs are BP coefficients, the remaining ones RP coefficients.
first_star_coeffs = preds[0][0].numpy()
bp_coeffs = first_star_coeffs[:55]
rp_coeffs = first_star_coeffs[55:]
# NOTE(review): *_standard_deviation is set to the spread of the predicted
# coefficients; confirm this matches what gaiaxpy expects for these columns.
df = pd.DataFrame({
    'source_id': 1,
    'bp_coefficients': [bp_coeffs],
    'bp_standard_deviation': np.std(bp_coeffs),
    'bp_coefficient_covariances': [covbp],
    'rp_coefficients': [rp_coeffs],
    'rp_coefficient_covariances': [covrp],
    'rp_standard_deviation': np.std(rp_coeffs),
})

In [109]:
df  # one-row frame (source_id 1) of predicted BP/RP coefficients and covariances

Unnamed: 0,source_id,bp_coefficients,bp_standard_deviation,bp_coefficient_covariances,rp_coefficients,rp_coefficient_covariances,rp_standard_deviation
0,1,"[-0.66917336, 0.48170173, 0.3844219, -0.878041...",0.395025,"[[0.8794816779873713, 0.0, 0.0, 0.0, 0.0, 0.0,...","[-0.5746207, 0.38471243, 0.55732507, -0.361483...","[[0.8618410848556302, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.332891


In [110]:
# Path to a Gaia XP continuous mean spectrum CSV (per its file name). Not used by
# the active generate(...) call below, which takes the model-built df instead.
f = '/arc/home/aydanmckay/XpContinuousMeanSpectrum_407725-409897.csv'

In [111]:
# Synthetic Pristine (CaHK) photometry for the single model-predicted spectrum in df.
# Alternative: generate(f, photometric_system=phot_system_list) processes the CSV
# at path f for every system in phot_system_list.
synthetic_photometry = generate(df, photometric_system=PhotometricSystem.Pristine)
synthetic_photometry

                              

Unnamed: 0,source_id,Pristine_mag_CaHK,Pristine_flux_CaHK,Pristine_flux_error_CaHK
0,1,21.31737,2.164373e-19,5.092640999999999e-19


In [63]:
from gaiaxpy.core.satellite import BANDS

In [64]:
# Show each band name gaiaxpy defines, one per line.
for band_name in BANDS:
    print(band_name)

bp
rp
