In [2]:
import numpy as np

from data_utils import compute_pairs_euclidean_distances, load_numpy_data, save_numpy_array

from psf_constants import FC_PROCESSED_TRAIN_COMPLEX_FIELDS_PREFIX, \
                          FC_CROPPED_TRAIN_COMPLEX_FIELDS_PREFIX, \
                          FC_PREDICTED_TRAIN_COMPLEX_FIELDS_PREFIX, \
                          FC_PREDICTED_CROPPED_TRAIN_COMPLEX_FIELDS_PREFIX, \
                          FC_PROCESSED_TRAIN_OUTPUT_FLUXES_PREFIX, \
                          PSF_TRAIN_FILE_SUFFIXES, \
                          NUMPY_SUFFIX

In [3]:
def create_path(prefix, file_suffix=None):
    """Build a data file path as ``prefix + suffix + NUMPY_SUFFIX``.

    Parameters
    ----------
    prefix : str
        Leading part of the file path.
    file_suffix : str, optional
        Subset suffix placed between ``prefix`` and ``NUMPY_SUFFIX``. When
        omitted, the notebook-level variable ``suffix`` is used, which keeps
        existing one-argument calls working.

    Returns
    -------
    str
        The concatenated path.

    Raises
    ------
    NameError
        If ``file_suffix`` is omitted and no global ``suffix`` is defined yet.
    """
    if file_suffix is None:
        # Backward-compatible fallback: rely on the global loop variable
        # ``suffix``. Prefer passing ``file_suffix`` explicitly so the result
        # does not depend on which cell ran last.
        file_suffix = suffix
    return f"{prefix}{file_suffix}{NUMPY_SUFFIX}"

def create_random_pair_indexes(array_n_points, pairs_per_subset=10000):
    """Draw random index pairs ``(i, j)`` with ``i != j``.

    Both indexes of every pair are drawn from ``[0, array_n_points)`` using
    the global NumPy random state (``np.random``); seed it beforehand for
    reproducible results. Pairs are drawn independently, so the same pair may
    appear more than once.

    Parameters
    ----------
    array_n_points : int
        Number of indexable points; indexes lie in ``[0, array_n_points)``.
    pairs_per_subset : int, optional
        Number of pairs to return.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(pairs_per_subset, 2)`` whose two columns
        differ on every row.

    Raises
    ------
    ValueError
        If pairs are requested but ``array_n_points`` is smaller than 2, in
        which case no pair of distinct indexes exists.
    """
    if pairs_per_subset > 0 and array_n_points < 2:
        # Without this guard the top-up loop below could never terminate.
        raise ValueError(
            "array_n_points must be at least 2 to draw pairs of distinct "
            f"indexes, got {array_n_points}"
        )

    selected_pairs = np.random.randint(0, array_n_points, size=(pairs_per_subset, 2))
    # Discard degenerate pairs that reference the same point twice.
    selected_pairs = selected_pairs[selected_pairs[:, 0] != selected_pairs[:, 1]]

    # Top up until enough valid pairs are available. The extra draws must use
    # the same index range as the first draw (previously hard-coded to 100).
    while selected_pairs.shape[0] < pairs_per_subset:
        more_pairs = np.random.randint(0, array_n_points, size=(pairs_per_subset, 2))
        more_pairs = more_pairs[more_pairs[:, 0] != more_pairs[:, 1]]
        selected_pairs = np.concatenate((selected_pairs, more_pairs))

    return selected_pairs[:pairs_per_subset]

# 1. Brute Force

In [6]:
for suffix in PSF_TRAIN_FILE_SUFFIXES[:1]:

    # Resolve each dataset's path and load it straight away.
    # NOTE: create_path reads the loop variable `suffix` from global scope.
    fluxes_path = create_path(FC_PROCESSED_TRAIN_OUTPUT_FLUXES_PREFIX)
    psf_fluxes = load_numpy_data(fluxes_path)

    og_complex_fields_path = create_path(FC_PROCESSED_TRAIN_COMPLEX_FIELDS_PREFIX)
    og_complex_fields = load_numpy_data(og_complex_fields_path)

    og_cropped_complex_fields_path = create_path(FC_CROPPED_TRAIN_COMPLEX_FIELDS_PREFIX)
    og_cropped_complex_fields = load_numpy_data(og_cropped_complex_fields_path)

    predicted_complex_fields_path = create_path(FC_PREDICTED_TRAIN_COMPLEX_FIELDS_PREFIX)
    predicted_complex_fields = load_numpy_data(predicted_complex_fields_path)

    predicted_cropped_complex_fields_path = create_path(FC_PREDICTED_CROPPED_TRAIN_COMPLEX_FIELDS_PREFIX)
    predicted_cropped_complex_fields = load_numpy_data(predicted_cropped_complex_fields_path)

    # Draw one set of random index pairs; the same pairs are reused for every
    # dataset below so the resulting distances are comparable row by row.
    n_points = psf_fluxes.shape[0]
    selected_pairs = create_random_pair_indexes(n_points, pairs_per_subset=100)

    # Distances between the selected pairs of output fluxes
    psf_fluxes_euclidean_distances = compute_pairs_euclidean_distances(
        psf_fluxes, selected_pairs)

    # Distances between the selected pairs of each complex-field dataset
    og_euclidean_distances = compute_pairs_euclidean_distances(
        og_complex_fields, selected_pairs, is_complex_field=True)

    og_cropped_euclidean_distances = compute_pairs_euclidean_distances(
        og_cropped_complex_fields, selected_pairs, is_complex_field=True)

    predicted_euclidean_distances = compute_pairs_euclidean_distances(
        predicted_complex_fields, selected_pairs, is_complex_field=True)

    predicted_cropped_euclidean_distances = compute_pairs_euclidean_distances(
        predicted_cropped_complex_fields, selected_pairs, is_complex_field=True)

    # Stack the five distance vectors side by side, one column per dataset
    euclidean_distances = np.concatenate(
        [distances.reshape(-1, 1)
         for distances in (psf_fluxes_euclidean_distances,
                           og_euclidean_distances,
                           og_cropped_euclidean_distances,
                           predicted_euclidean_distances,
                           predicted_cropped_euclidean_distances)],
        axis=1)

(100,)
(100,)
(100,)
(100,)
(100,)


In [7]:
# Show the merged distance matrix. Columns, in order: output fluxes, original
# complex fields, cropped original fields, predicted fields, cropped predicted
# fields.
euclidean_distances

array([[  6.05322933,  72.9384079 ,  72.93013763,  57.57764053,
         34.57260132],
       [  5.78940678,  60.13812637,  60.12984085,  61.35387039,
         25.01191902],
       [  5.34978199,  26.47830009,  26.45322037,  26.99607849,
         18.8288002 ],
       [  4.2665205 ,  52.9655838 ,  52.9384346 ,  49.00593948,
         25.73376846],
       [  4.38466406,  54.52734756,  54.52096558,  48.3951683 ,
         25.81070137],
       [  5.34017038,  39.24728775,  39.23397827,  38.43154144,
         21.76660728],
       [  7.99042702,  91.17973328,  91.17436218,  77.78585815,
         29.6790657 ],
       [  4.0512538 ,  37.70767212,  37.6891861 ,  40.44883728,
         48.95627213],
       [  4.25040531,  57.09159088,  57.08198166,  57.70343018,
         29.60367966],
       [  6.76282549,  54.35936737,  54.30332947,  59.50588989,
         45.30532455],
       [  5.98484945,  66.61760712,  66.59954071,  69.38030243,
         22.171875  ],
       [  5.47306013,  44.93622589,  44.921

In [8]:
# Show the distances computed from the output fluxes alone (the first column
# of the merged `euclidean_distances` matrix).
psf_fluxes_euclidean_distances

array([6.05322933, 5.78940678, 5.34978199, 4.2665205 , 4.38466406,
       5.34017038, 7.99042702, 4.0512538 , 4.25040531, 6.76282549,
       5.98484945, 5.47306013, 5.07047176, 5.86489201, 6.04865551,
       6.23018503, 5.75723791, 5.34744787, 5.90061522, 4.33699894,
       6.13994837, 4.47016048, 5.03127575, 6.90051603, 7.01494026,
       6.05727911, 4.32905531, 5.88534069, 7.48685932, 5.83903122,
       6.69710588, 4.35798025, 5.51715422, 5.95577717, 6.98647785,
       3.82747245, 7.08409119, 6.69926596, 4.31752586, 5.97478867,
       6.16439867, 3.98143148, 5.74632072, 3.9993217 , 8.18000317,
       7.32882023, 4.24619055, 6.93271542, 7.25452089, 7.62883091,
       7.91027212, 5.44464731, 4.76479626, 6.14745569, 4.83521748,
       4.79050779, 5.47213697, 6.41777515, 5.06039095, 6.68316841,
       5.0217557 , 7.32006311, 4.00882196, 5.12279367, 8.08658123,
       6.45821381, 8.48995018, 5.1829319 , 8.35983086, 5.31727076,
       5.2124505 , 6.02598906, 4.37210798, 5.59403324, 5.08607