In [1]:
import os

import h5py
import matplotlib.pyplot as plt
import numpy as np
from astropy.table import Table, MaskedColumn, join
from scipy.spatial.distance import cdist
from tqdm.notebook import tqdm

In [2]:
# Root directory of the PHZ production that every input below is read from.
# Defaults to the original location; set the PHZ_PROD_DIR environment variable
# to run the notebook against a copy stored elsewhere.
prod = os.environ.get('PHZ_PROD_DIR', '/home/aalvarez/Work/Data/SC8/PHZ_Prod_18oct2021/')

In [3]:
# Reference sample photometry table (read from HDU 1 of phot.fits)
ref_phot_path = os.path.join(
    prod, 'ProductionInputEcdmHandler/galaxy_sed_reference_sample_dir.dir/phot.fits'
)
ref = Table.read(ref_phot_path, hdu=1)
len(ref)

497533

In [14]:
# Target catalog, read from the GalaxyFilteringTask directory of the production
target_catalog_path = os.path.join(prod, 'GalaxyFilteringTask/galaxy_catalog.fits')
target = Table.read(target_catalog_path)
len(target)

92047

In [15]:
# Table used as ground truth further down (it provides the NEIGHBOR_* columns)
truth_path = os.path.join(prod, 'NnpzStarSed/bruteforce.fits')
truth = Table.read(truth_path)

In [16]:
# Load the configuration file. It is executed as Python source, so its top-level
# assignments end up as keys of `config`.
# NOTE(review): exec() runs arbitrary code -- only point this at a trusted file.
config = dict()
# Use a context manager so the file handle is closed deterministically
with open(os.path.join(prod, 'data/nnpz_star_sed.conf')) as conf_file:
    exec(conf_file.read(), config)

In [17]:
# Some bands are missing from the target catalog: drop every band whose target
# value column is a MaskedColumn with all of its entries masked, and keep the
# reference and target filter lists aligned with each other.
ref_filters = config['reference_sample_phot_filters']
target_filters = config['target_catalog_filters']

kept_bands = [
    (rname, (tname, terr))
    for rname, (tname, terr) in zip(ref_filters, target_filters)
    if not (isinstance(target[tname], MaskedColumn) and target[tname].mask.all())
]
ref_selected = [ref_band for ref_band, _ in kept_bands]
target_selected = [target_cols for _, target_cols in kept_bands]
print(len(ref_selected), ref_selected)

9 ['lsst/u', 'lsst/g', 'lsst/r', 'lsst/i', 'lsst/z', 'euclid/VIS', 'euclid/Y', 'euclid/J', 'euclid/H']


In [18]:
# Match the truth rows with their target catalog rows through OBJECT_ID
joint = join(truth, target, keys='OBJECT_ID')
len(joint)

100

# Create an output similar to the ann-benchmarks HDF5 datasets (train, test, neighbors, distances)

In [19]:
# Reference photometry as a float32 array indexed by (object, band, slot).
# Only slot 0 is filled from the reference table; slot 1 is left at zero.
n_ref_objects, n_ref_bands = len(ref), len(ref_selected)
ref_photo = np.zeros((n_ref_objects, n_ref_bands, 2), dtype=np.float32)
for band_idx, band_name in enumerate(ref_selected):
    ref_photo[:, band_idx, 0] = ref[band_name]

In [20]:
# Target photometry as a float32 array indexed by (object, band, slot):
# slot 0 is filled from the first column of each selected pair, slot 1 from the second.
target_photo = np.zeros((len(joint), len(target_selected), 2), dtype=np.float32)
for band_idx, (value_col, error_col) in enumerate(target_selected):
    band_view = target_photo[:, band_idx]  # basic slicing: a view, writes go to target_photo
    band_view[:, 0] = joint[value_col]
    band_view[:, 1] = joint[error_col]

In [21]:
# Build the ground truth: for every test object, the row indices (not the IDs) of
# its neighbors within the reference table, plus one distance per neighbor.
#
# The indices are resolved in the same order as the IDs appear in NEIGHBOR_IDS,
# so neighbors[i, k] and distances[i, k] describe the same neighbor. A boolean
# membership mask over ref['ID'] would return them in reference-table order
# instead, which does not match the order of NEIGHBOR_WEIGHTS.
ref_ids = np.asarray(ref['ID'])
# Sort the reference IDs once so that each lookup is a binary search rather than
# a scan of the whole reference table per target object.
ref_id_order = np.argsort(ref_ids, kind='stable')
ref_ids_sorted = ref_ids[ref_id_order]

distances = []
neighbors = []
for neighbor_ids, neighbor_weights in tqdm(zip(joint['NEIGHBOR_IDS'], joint['NEIGHBOR_WEIGHTS']), total=len(joint)):
    # NOTE(review): assumes every entry of NEIGHBOR_IDS is a real ID (no padding
    # or masked values) -- confirm against the producer of the truth table.
    neighbor_ids = np.asarray(neighbor_ids)
    sorted_pos = np.searchsorted(ref_ids_sorted, neighbor_ids)
    # An ID is known only if its insertion point is in range and holds that ID
    known = sorted_pos < len(ref_ids_sorted)
    known[known] = ref_ids_sorted[sorted_pos[known]] == neighbor_ids[known]
    if not known.all():
        raise ValueError(f'Neighbor IDs not found in the reference sample: {neighbor_ids[~known]}')
    # The test will check the index, not the ID
    neighbors.append(ref_id_order[sorted_pos])
    # The distance stored for the benchmark is the inverse of the neighbor weight
    distances.append(1 / neighbor_weights)
neighbors = np.asarray(neighbors, dtype=int)
distances = np.asarray(distances, dtype=np.float32)

  0%|          | 0/100 [00:00<?, ?it/s]

In [22]:
# Write the benchmark file: one dataset per array, each created with the shape
# and dtype of its array and then filled, plus the name of the distance as a
# file attribute.
with h5py.File('/home/aalvarez/Tools/ann-benchmarks/data/sc8-stars.hdf5', 'w') as out:
    out.attrs['distance'] = 'chi2-scaled'
    for dset_name, dset_values in (
        ('distances', distances),
        ('neighbors', neighbors),
        ('train', ref_photo),
        ('test', target_photo),
    ):
        dset = out.create_dataset(dset_name, dset_values.shape, dset_values.dtype)
        dset[:] = dset_values