In [1]:
from astropy.table import Table
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py
import matplotlib.pyplot as plt
from gaiaxpy import generate, PhotometricSystem
import pandas as pd
from sklearn.preprocessing import StandardScaler, MinMaxScaler

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# scalers for dataloading
metscaler = MinMaxScaler()
logscaler = MinMaxScaler()
tefscaler = MinMaxScaler()
scalerlist = [MinMaxScaler() for _ in range(110)] # Hardcoded since we know number of xp coefficients is static

# dataloader batchlength, learning rate, epochs for training
batchlen = 32
lr = 1e-2
epochs = 100

class ResBlock(nn.Module):
    def __init__(self, nodes):
        super(ResBlock, self).__init__()
        self.res_block1 = nn.Sequential(
            nn.Linear(nodes,nodes),
            nn.BatchNorm1d(nodes),
            nn.LeakyReLU(),
        )
        self.res_block2 = nn.Sequential(
            nn.Linear(nodes,nodes),
            nn.BatchNorm1d(nodes),
        )
        self.lrelu = nn.LeakyReLU()
    
    def forward(self, x):
        res = x
        x = self.res_block1(x)
        x = self.res_block2(x)
        x = x + res
        output = self.lrelu(x)
        return output
        
class ResNetMcK(nn.Module):
    def __init__(self):
        super(ResNetMcK, self).__init__()
        self.input_block = nn.Sequential(
            nn.Linear(3,16),
            nn.LeakyReLU(),
        )
        self.blocklist = nn.ModuleList([
            ResBlock(16),
            ResBlock(16),
            nn.Linear(16,32),
            ResBlock(32),
            ResBlock(32),
            nn.Linear(32,64),
            ResBlock(64),
            ResBlock(64),
            nn.Linear(64,128),
            ResBlock(128),
            ResBlock(128),
        ])
        self.output_block = nn.Sequential(
            nn.Linear(128,110),
        )
        
    def forward(self,x):
        x = self.input_block(x)
        for i, _ in enumerate(self.blocklist):
            x = self.blocklist[i](x)
        logits = self.output_block(x)
        return logits
    
# defining the Dataset class
class train_set(Dataset):
    def __init__(self,file):
        fn = h5py.File(file, 'r')
        self.f = fn
        
        # get data
        dset = fn['group_1']['data']
        self.x = torch.Tensor(dset[:].T)
        
        # get label
        ydset = self.f['group_1']['label']
        self.y = torch.Tensor(ydset[:].T)
        # torch.from_numpy(y[index]) does not work since y is doubles and not floats.
        
        # get error in label # comment out for non-error label runs
        errdset = self.f['group_1']['e_label']
        self.err = torch.Tensor(errdset[:].T)
        
    def __len__(self):
        return self.f['group_1']['data'].shape[1]
  
    def __getitem__(self, index):
        xg = self.x[index]
        yg = self.y[index]
        errg = self.err[index]
        return (xg,yg,errg)

class valid_set(Dataset):
    def __init__(self,file):
        fn = h5py.File(file, 'r')
        self.f = fn
        
        # get data
        dset = self.f['group_2']['data']
        self.x = torch.Tensor(dset[:].T)
        
        # get label
        ydset = self.f['group_2']['label']
        self.y = torch.Tensor(ydset[:].T)
        # torch.from_numpy(y[index]) does not work since y is doubles and not floats.
        
        # get error in label # comment out for non-error label runs
        errdset = self.f['group_2']['e_label']
        self.err = torch.Tensor(errdset[:].T)
        
    def __len__(self):
        return self.f['group_2']['data'].shape[1]
  
    def __getitem__(self, index):
        xg = self.x[index]
        yg = self.y[index]
        errg = self.err[index]
        return (xg,yg,errg)
    
class new_data_set(Dataset):
    def __init__(self,file,train=True,valid=False):
        fn = h5py.File(file, 'r')
        self.f = fn
        
        # get data
        dset = fn['group_1']['data']
        d = dset[:]
        dat = np.array([
            metscaler.fit_transform(d[[0]].T).flatten(),
            logscaler.fit_transform(d[[1]].T).flatten(),
            tefscaler.fit_transform(d[[2]].T).flatten(),
        ])
        if train:
            self.l = dat.shape[1]
            self.x = torch.Tensor(dat.T)
        elif valid:
            dset = fn['group_2']['data']
            d = dset[:]
            dat = np.array([
                metscaler.transform(d[[0]].T).flatten(),
                logscaler.transform(d[[1]].T).flatten(),
                tefscaler.transform(d[[2]].T).flatten(),
            ])
            self.l = dat.shape[1]
            self.x = torch.Tensor(dat.T)
        
        # get label
        ydset = self.f['group_1']['label']
        yd = ydset[:]
        ydat = np.array([
            scaler.fit_transform(yd[[it]].T).flatten() for it,scaler in enumerate(scalerlist)
        ])
        if train:
            self.y = torch.Tensor(ydat[:].T) # torch.from_numpy(y[index]) does not work since y is doubles and not floats.
        elif valid:
            ydset = self.f['group_2']['label']
            yd = ydset[:]
            ydat = np.array([
                scaler.transform(yd[[it]].T).flatten() for it,scaler in enumerate(scalerlist)
            ])
            self.y = torch.Tensor(ydat[:].T)
        
        # get error in label # comment out for non-error label runs
        if train:
            errdset = self.f['group_1']['e_label']
            self.err = torch.Tensor(errdset[:].T)
        elif valid:
            errdset = self.f['group_2']['e_label']
            self.err = torch.Tensor(errdset[:].T)
        
    def __len__(self):
        return self.l
  
    def __getitem__(self, index):
        xg = self.x[index]
        yg = self.y[index]
        errg = self.err[index]
        return (xg,yg,errg)

cuda


In [3]:
# training_data = train_set("/arc/home/aydanmckay/mydataelabelssmallscalecuts.h5")
# data = valid_set("/arc/home/aydanmckay/smallcutdataMinMaxscaled.h5")
# training_data = new_data_set('/arc/home/aydanmckay/smallcutdata.h5',train=True,valid=False)
valid_data = new_data_set('/arc/home/aydanmckay/smallcutdata.h5',train=False,valid=True)
loaded_data = DataLoader(
    valid_data,
    batch_size=32,
    num_workers=0
    # shuffle=True
)

In [4]:
model = ResNetMcK()
model.load_state_dict(torch.load("/arc/home/aydanmckay/torchresmodel/modelL2smallminmaxscalecutsbl32lr-2wd-5SGDep100new.pth"))
# model.eval()

<All keys matched successfully>

In [5]:
model = model.to(device)

In [6]:
preds = []
covbs = []
covrs = []
with torch.no_grad():
    for X, y, z in loaded_data:
        X = X.to(device)
        y = y.to(device)
        z = z.to(device)
        pred = model(X)
        for prediction,err in zip(pred,z):
            covbp = np.zeros((55,55))
            covrp = np.zeros((55,55))
            preds.append(prediction)
            for it in range(len(err[:55])):
                covbp[it][it] += err[it].item()**2
                covrp[it][it] += err[it+55].item()**2
            covbs.append(covbp)
            covrs.append(covrp)

In [7]:
np.array(covrs).shape

(5000, 55, 55)

In [8]:
# phot_system_list = [
#     PhotometricSystem.Gaia_2,
#     PhotometricSystem.Gaia_DR3_Vega,
#     PhotometricSystem.PanSTARRS1,
#     PhotometricSystem.PanSTARRS1_Std,
#     PhotometricSystem.Pristine,
#     PhotometricSystem.SDSS,
#     PhotometricSystem.SDSS_Std
# ]

In [9]:
# df = pd.DataFrame(
#     {'source_id':range(len(preds)),
#      'bp_coefficients':[pred.to('cpu').numpy()[:55] for pred in preds],
#      'bp_standard_deviation':[np.std(pred.to('cpu').numpy()[:55]) for pred in preds],
#      'bp_coefficient_covariances':covbs,
#      'rp_coefficients':[pred.to('cpu').numpy()[55:] for pred in preds],
#      'rp_coefficient_covariances':covrs,
#      'rp_standard_deviation':[np.std(pred.to('cpu').numpy()[55:]) for pred in preds]
#     }
# )

In [10]:
# df

Unnamed: 0,source_id,bp_coefficients,bp_standard_deviation,bp_coefficient_covariances,rp_coefficients,rp_coefficient_covariances,rp_standard_deviation
0,0,"[0.10593756, 0.6655291, 0.58050627, 0.5698537,...",0.141355,"[[50.824411354017, 0.0, 0.0, 0.0, 0.0, 0.0, 0....","[0.114132665, 0.26792872, 0.41232288, 0.296072...","[[10.894251866433535, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.132816
1,1,"[0.2897923, 0.48539084, 0.68183684, 0.5885987,...",0.142764,"[[57.57570784600466, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.14480898, 0.18635112, 0.37563154, 0.2285097...","[[8.049284034164032, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.139449
2,2,"[0.08459433, 0.65503055, 0.6278038, 0.5765127,...",0.146393,"[[0.6034642279553424, 0.0, 0.0, 0.0, 0.0, 0.0,...","[0.084980555, 0.26499164, 0.40993285, 0.289873...","[[0.26975468997284224, 0.0, 0.0, 0.0, 0.0, 0.0...",0.134833
3,3,"[0.082439564, 0.65433115, 0.60703474, 0.539486...",0.145130,"[[3.5874154757670027, 0.0, 0.0, 0.0, 0.0, 0.0,...","[0.07736059, 0.25649434, 0.4211324, 0.2452985,...","[[1.365683548072468, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.137491
4,4,"[0.14318383, 0.6594771, 0.55993843, 0.58846396...",0.136831,"[[75.07698531556889, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.12613405, 0.26504216, 0.40678883, 0.2996904...","[[22.05784515259529, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.132948
...,...,...,...,...,...,...,...
4995,4995,"[0.220137, 0.568948, 0.62653226, 0.60142946, 0...",0.136549,"[[39.98701003183669, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.14585447, 0.21746223, 0.38856703, 0.2586031...","[[8.878970670616582, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.144862
4996,4996,"[0.08742771, 0.61870503, 0.5750307, 0.5726926,...",0.137805,"[[4.884835755780159, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.10068012, 0.27979735, 0.36410362, 0.2777702...","[[3.958926802659576, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.139909
4997,4997,"[0.17764795, 0.65069497, 0.63860613, 0.6540721...",0.149513,"[[78.68567875133886, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.13621025, 0.2522908, 0.42042232, 0.24695438...","[[20.18509114127687, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0.146926
4998,4998,"[0.18813697, 0.7152061, 0.57346886, 0.5895612,...",0.141624,"[[0.38581238603241275, 0.0, 0.0, 0.0, 0.0, 0.0...","[0.15228955, 0.24471405, 0.44860658, 0.3170434...","[[1.2997614199214382, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.138642


In [11]:
f = "select TOP 1 source_id from gaiadr3.gaia_source where source_id = '4486317061229025280'"

In [12]:
synthetic_photometry = generate(f, photometric_system=PhotometricSystem.Pristine)
# synthetic_photometry = generate(df, photometric_system=PhotometricSystem.Pristine)
# synthetic_photometry

INFO: Query finished. [astroquery.utils.tap.core]
                              

In [13]:
synthetic_photometry

Unnamed: 0,source_id,Pristine_mag_CaHK,Pristine_flux_CaHK,Pristine_flux_error_CaHK
0,4486317061229025280,18.390889,3.2057110000000002e-18,2.908615e-19


In [14]:
# PhotometricSystem.get_available_systems()

In [23]:
preds = np.array([pred.to('cpu').numpy() for pred in preds]).T

In [24]:
preds.shape

(110, 5000)

In [25]:
# rpnewpred = []
xpnewpred = []
for pred,scaler in zip(preds,scalerlist):
    xpnewpred.append(scaler.inverse_transform(np.array([pred]).T).flatten())

In [26]:
bpnews = np.array(xpnewpred[:55]).T
rpnews = np.array(xpnewpred[55:]).T
rpnews.shape

(5000, 55)

In [27]:
bpnews.shape

(5000, 55)

In [28]:
len(preds.T)

5000

In [29]:
dfnew = pd.DataFrame(
    {'source_id':range(len(preds.T)),
     'bp_coefficients':list(bpnews),
     'bp_standard_deviation':[np.std(bp) for bp in bpnews],
     'bp_coefficient_covariances':covbs,
     'rp_coefficients':list(rpnews),
     'rp_coefficient_covariances':covrs,
     'rp_standard_deviation':[np.std(rp) for rp in rpnews]
    }
)

In [30]:
dfnew

Unnamed: 0,source_id,bp_coefficients,bp_standard_deviation,bp_coefficient_covariances,rp_coefficients,rp_coefficient_covariances,rp_standard_deviation
0,0,"[3498.6313, 285.0433, -173.1007, -10.959786, -...",469.522919,"[[50.824411354017, 0.0, 0.0, 0.0, 0.0, 0.0, 0....","[6805.109, 196.58281, 74.35448, 97.03508, 1.70...","[[10.894251866433535, 0.0, 0.0, 0.0, 0.0, 0.0,...",908.957031
1,1,"[9557.244, -1817.2513, 51.59623, 12.684404, -3...",1304.275269,"[[57.57570784600466, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[8599.435, -1396.0289, -105.27336, -54.01748, ...","[[8.049284034164032, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1167.811768
2,2,"[2795.3022, 162.52078, -68.22011, -2.5604293, ...",374.109802,"[[0.6034642279553424, 0.0, 0.0, 0.0, 0.0, 0.0,...","[5099.9365, 139.24323, 62.653748, 83.1751, -1....","[[0.26975468997284224, 0.0, 0.0, 0.0, 0.0, 0.0...",681.170898
3,3,"[2724.2957, 154.35846, -114.27476, -49.26366, ...",365.191010,"[[3.5874154757670027, 0.0, 0.0, 0.0, 0.0, 0.0,...","[4654.228, -26.64663, 117.48272, -16.48218, 16...","[[1.365683548072468, 0.0, 0.0, 0.0, 0.0, 0.0, ...",621.746460
4,4,"[4726.0176, 214.41423, -218.70914, 12.514416, ...",632.934021,"[[75.07698531556889, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[7507.096, 140.22942, 47.261707, 105.1237, -4....","[[22.05784515259529, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1002.712524
...,...,...,...,...,...,...,...
4995,4995,"[7261.8755, -842.1018, -71.03971, 28.868574, -...",979.102783,"[[39.98701003183669, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[8660.588, -788.6574, -41.945774, 13.263403, 5...","[[8.878970670616582, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1163.718018
4996,4996,"[2888.6714, -261.41437, -185.2426, -7.378976, ...",389.632568,"[[4.884835755780159, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[6018.2393, 428.28998, -161.71002, 56.116043, ...","[[3.958926802659576, 0.0, 0.0, 0.0, 0.0, 0.0, ...",805.668457
4997,4997,"[5861.7227, 111.922585, -44.26633, 95.269875, ...",782.802673,"[[78.68567875133886, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[8096.4756, -108.711266, 114.00647, -12.780081...","[[20.18509114127687, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1081.466675
4998,4998,"[6207.3696, 864.7966, -188.70589, 13.898458, -...",835.813538,"[[0.38581238603241275, 0.0, 0.0, 0.0, 0.0, 0.0...","[9036.99, -256.62927, 251.98668, 143.92046, 9....","[[1.2997614199214382, 0.0, 0.0, 0.0, 0.0, 0.0,...",1207.743164


In [31]:
synthetic_photometry = generate(dfnew, photometric_system=PhotometricSystem.Pristine)
synthetic_photometry

                              

Unnamed: 0,source_id,Pristine_mag_CaHK,Pristine_flux_CaHK,Pristine_flux_error_CaHK
0,0,17.444342,7.665549e-18,1.610519e-15
1,1,14.019518,1.796689e-16,5.359299e-15
2,2,16.762064,1.436993e-17,1.523146e-16
3,3,17.005479,1.148387e-17,3.079294e-16
4,4,16.444478,1.925258e-17,2.406260e-15
...,...,...,...,...
4995,4995,14.633889,1.020291e-16,3.103260e-15
4996,4996,16.167465,2.484814e-17,3.547774e-16
4997,4997,15.547913,4.396577e-17,3.185746e-15
4998,4998,16.703126,1.517153e-17,3.215084e-16
