In [1]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from astropy.table import Table

In [2]:
# Catalogue whose rows carry the NeighborIDs / NeighborWeights inspected below.
TARGET_CATALOG_PATH = '/tmp/nnpz_compare/r_shift9/n_real_65ce2ef21faf0cdf5edae8fb253e6b25.fits'
# Row examined in detail in the rest of the notebook.
TARGET_ROW = 1760

t = Table.read(TARGET_CATALOG_PATH)
target_obj = t[TARGET_ROW]



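### Cumulative plot of normalized neighbour weights
Note that this is the "worst" case found: its neighbour weights are all very similar, so the cumulative curve rises almost linearly instead of saturating after a handful of neighbours.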

In [4]:
# Sort the weights from largest to smallest and accumulate them, so the curve
# shows which fraction of the total weight the k heaviest neighbours carry.
weights = np.asarray(target_obj['NeighborWeights'])
weights_desc = np.sort(weights)[::-1]
cumulative_fraction = np.cumsum(weights_desc) / weights_desc.sum()
# Derive the x axis from the data rather than assuming exactly 1000 neighbours.
n_top_neighbours = np.arange(1, len(weights_desc) + 1)

plt.figure()
plt.plot(n_top_neighbours, cumulative_fraction)
plt.xlabel('number of highest-weight neighbours')
plt.ylabel('cumulative fraction of total weight')
plt.title('Cumulative normalized neighbour weight');

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7f550c91d160>]

### Histogram of neighbour weights (not normalized)

In [5]:
# Raw (unnormalized) weights of the target's neighbours.
# Output is suppressed with a trailing ';' rather than `_ = ...`: assigning to
# `_` shadows IPython's last-output variable and stops it from auto-updating.
plt.figure()
plt.hist(target_obj['NeighborWeights'])
plt.xlabel('neighbour weight')
plt.ylabel('number of neighbours')
plt.title('Neighbour weights (not normalized)');

<IPython.core.display.Javascript object>

### Scatter plots of observed magnitudes for different filter combinations

In [7]:
# One panel per band pair: (x column, y column, x label, y label).
observed_mag_panels = [
    ('g_obs_mag', 'r_obs_mag', 'g', 'r'),
    ('r_obs_mag', 'Y_obs_mag', 'r', 'Y'),
    ('g_obs_mag', 'Y_obs_mag', 'g', 'Y'),
]

plt.figure()
for panel_idx, (x_col, y_col, x_label, y_label) in enumerate(observed_mag_panels, start=1):
    plt.subplot(2, 2, panel_idx)
    plt.scatter(t[x_col], t[y_col])
    plt.xlabel(x_label)
    plt.ylabel(y_label)
plt.tight_layout()

<IPython.core.display.Javascript object>

In [8]:
# Reference photometry catalogue that neighbour IDs are resolved against.
REFERENCE_CATALOG_PATH = '/tmp/nnpz_compare/photometry_8b6c4039f17794ee5b63d26f92624919.fits'
reference = Table.read(REFERENCE_CATALOG_PATH)



In [9]:
def plot_scatter(target, reference, x, y):
    """Overlay target observations on the reference sample in the (x, y) plane.

    The reference sample is drawn from columns ``x`` and ``y``; the target from
    the matching ``<x>_obs`` / ``<y>_obs`` columns. Axes are labelled with the
    bare band names ``x`` and ``y``.
    """
    x_obs_col = x + '_obs'
    y_obs_col = y + '_obs'

    plt.scatter(reference[x], reference[y])
    plt.scatter(target[x_obs_col], target[y_obs_col])

    plt.xlabel(x)
    plt.ylabel(y)

In [10]:
def get_neighbors_and_weights(target, reference_catalog):
    """Look up the reference rows of a target's neighbours.

    Parameters
    ----------
    target : row-like
        Must provide ``'NeighborIDs'`` and ``'NeighborWeights'`` sequences,
        paired element by element.
    reference_catalog : table-like
        Catalogue with an ``'ID'`` column that supports integer-array row
        indexing (e.g. an astropy ``Table`` or a numpy structured array).

    Returns
    -------
    neighbors : same type as ``reference_catalog``
        The matching reference rows, in the order of ``target['NeighborIDs']``.
    neighbors_w : list
        The weights paired with those neighbours, in the same order.

    Raises
    ------
    KeyError
        If a neighbour ID does not appear in ``reference_catalog['ID']``.
    """
    # Build the ID -> row map once instead of scanning the whole catalogue for
    # every neighbour. setdefault keeps the FIRST row carrying a duplicated ID,
    # matching the previous ``catalog[catalog['ID'] == n][0]`` lookup.
    row_of_id = {}
    for row_index, ref_id in enumerate(reference_catalog['ID']):
        row_of_id.setdefault(ref_id, row_index)

    rows = []
    neighbors_w = []
    for neighbor_id, weight in zip(target['NeighborIDs'], target['NeighborWeights']):
        try:
            rows.append(row_of_id[neighbor_id])
        except KeyError:
            raise KeyError(
                f'neighbour ID {neighbor_id!r} not found in reference catalog'
            ) from None
        neighbors_w.append(weight)

    # Integer-array indexing yields a new table with the catalogue's own
    # columns, replacing the row-by-row add_row() construction.
    neighbors = reference_catalog[np.asarray(rows, dtype=int)]
    return neighbors, neighbors_w

In [11]:
# Resolve the target's neighbour IDs against the reference photometry.
neighbors, neighbors_w = get_neighbors_and_weights(
    target=target_obj,
    reference_catalog=reference,
)

In [12]:
def _split_band(spec):
    """Return ``(neighbour_column, target_prefix)`` for a band spec.

    A plain string names the same band in both catalogues; a 2-tuple/list
    ``(neighbour_column, target_prefix)`` lets the target use a different
    column prefix (e.g. ``('r', 'r_shift9')``).
    """
    if isinstance(spec, (tuple, list)):
        neighbour_col, target_prefix = spec
        return neighbour_col, target_prefix
    return spec, spec


def _band_label(neighbour_col, target_prefix):
    """Axis label: the band name, plus the target prefix when it differs."""
    if neighbour_col == target_prefix:
        return neighbour_col
    return f'{neighbour_col} / {target_prefix}'


def plot_neighbours_plane(target, neighbors, neighbors_w, x, y, color_map):
    """Plot weighted neighbours and the target in one 2D band plane.

    Neighbours are drawn as a scatter coloured by ``neighbors_w`` using
    ``color_map``, with a colorbar. The target is drawn as a single red point
    with x/y error bars read from ``<prefix>_obs`` / ``<prefix>_obs_err``.

    Parameters
    ----------
    target : row-like
        Row providing ``<prefix>_obs`` and ``<prefix>_obs_err`` for both axes.
    neighbors : table-like
        Neighbour rows, indexed by band column name.
    neighbors_w : sequence of float
        One weight per neighbour, used as the scatter colour.
    x, y : str or (str, str)
        Band spec: a column name shared by both catalogues, or
        ``(neighbour_column, target_prefix)`` when they differ.
    color_map : colormap or colormap name
        Passed through to ``plt.scatter``.
    """
    nx, tx = _split_band(x)
    ny, ty = _split_band(y)

    scatter = plt.scatter(neighbors[nx], neighbors[ny], c=neighbors_w, cmap=color_map)
    plt.errorbar(
        [target[f'{tx}_obs']], [target[f'{ty}_obs']],
        xerr=[target[f'{tx}_obs_err']],
        yerr=[target[f'{ty}_obs_err']],
        fmt='o', color='r'
    )
    # Label with band names; passing the raw spec rendered tuple specs as
    # "('r', 'r_shift9')".
    plt.xlabel(_band_label(nx, tx))
    plt.ylabel(_band_label(ny, ty))
    plt.colorbar(scatter)

### Scatter plots of neighbours against the target (in red), in context (full reference sample)
Neighbours are coloured by weight using the reversed `cool` colormap, from pink/magenta (lowest weight) to cyan (highest weight).

In [13]:
color_map = plt.get_cmap('cool_r')

# The target's r band is read from its 'r_shift9' columns.
r_spec = ('r', 'r_shift9')
# Each panel: (reference x column, reference y column, x spec, y spec).
neighbour_panels = [
    ('g', 'r', 'g', r_spec),
    ('r', 'Y', r_spec, 'Y'),
    ('g', 'Y', 'g', 'Y'),
]

plt.figure()
for panel_idx, (ref_x, ref_y, x_spec, y_spec) in enumerate(neighbour_panels, start=1):
    plt.subplot(2, 2, panel_idx)
    # Full reference sample as background context.
    plt.scatter(reference[ref_x], reference[ref_y])
    plot_neighbours_plane(target_obj, neighbors, neighbors_w, x_spec, y_spec, color_map)
plt.tight_layout()

<IPython.core.display.Javascript object>

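### 3D scatter plot of neighbours against the target (in red), in context (reference sample)
Same color scheme as above. Only reference objects inside a fixed box in (g, r, Y) are shown; see the bounds in the code cell.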

In [14]:
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes

In [16]:
# Hand-picked (lower, upper) bounds used to thin the reference sample in the
# 3D view. NOTE(review): these are fixed numbers, presumably chosen around
# target_obj — re-check them if a different target row is used.
REF_BOX_BOUNDS = {
    'Y': (0.2, 0.7),
    'g': (0.20, 0.29),
    'r': (0.30, 0.5),
}

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Keep rows strictly inside the bounds on every band.
in_box = np.all(
    [(reference[band] > lo) & (reference[band] < hi)
     for band, (lo, hi) in REF_BOX_BOUNDS.items()],
    axis=0,
)
ref_subset = reference[in_box]
ax.scatter(ref_subset['g'], ref_subset['r'], ref_subset['Y'], color='b', alpha=0.02, marker='.')

ax.scatter(neighbors['g'], neighbors['r'], neighbors['Y'], c=neighbors_w, cmap=color_map, alpha=0.2, marker='x')
# The target's r coordinate comes from its 'r_shift9' column, as in the 2D panels.
ax.scatter([target_obj['g_obs']], [target_obj['r_shift9_obs']], [target_obj['Y_obs']], color='r', alpha=1., marker='o')

ax.set_xlabel('g')
ax.set_ylabel('r')
# Trailing ';' suppresses the Text(...) repr of the last call.
ax.set_zlabel('Y');

<IPython.core.display.Javascript object>

Text(0.5,0,'Y')