In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import pandas as pd
from scipy.stats import median_abs_deviation as med_abs_dev
from tqdm import tqdm

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.neighbors import KernelDensity
import seaborn as sns

## Load image, mask, metadata, and cutouts

In [None]:
# Read the 'channel' array from each of the three image archives.
# np.load on an .npz returns an NpzFile that holds an open file handle; the
# context manager closes it as soon as the array has been read into memory.
with np.load('images/jades_r.npz') as npz_r:
    img_r = npz_r['channel']
with np.load('images/jades_g.npz') as npz_g:
    img_g = npz_g['channel']
with np.load('images/jades_b.npz') as npz_b:
    img_b = npz_b['channel']

In [None]:
# Stack the three channel planes along a new leading axis, then move that axis
# to the end so the channel index comes last (the layout imshow is given below).
img = np.moveaxis(np.stack([img_r, img_g, img_b], axis=0), 0, -1)

In [None]:
# Read the 'mask' array of labels; the context manager closes the .npz archive
# once the array is in memory instead of leaving its file handle open.
with np.load('mask_labels.npz') as npz_labels:
    labels = npz_labels['mask']
# Boolean mask: True wherever the label value is positive.
mask = labels > 0

In [None]:
# Table read from 'jades.csv'; a later cell looks rows up by position and reads
# the 'y_mean', 'x_mean' and 'source_size' columns from it.
df = pd.read_csv('jades.csv')

In [None]:
# Read the 'sources' array from each per-channel cutout archive.
# The context manager closes each NpzFile (and its file handle) after the read.
with np.load('images/jades_sources_r.npz') as npz_r:
    sources_r = npz_r['sources']
with np.load('images/jades_sources_g.npz') as npz_g:
    sources_g = npz_g['sources']
with np.load('images/jades_sources_b.npz') as npz_b:
    sources_b = npz_b['sources']

In [None]:
# Channel-first stack of the three cutout arrays.
sources_int = np.stack([sources_r, sources_g, sources_b], axis=0)
# Move the channel axis to the end and rescale by 255.
# NOTE(review): presumably the cutouts are stored as 8-bit values so this maps
# them into [0, 1] — confirm against the files.
sources = np.moveaxis(sources_int, 0, -1) / 255

In [None]:
# Unpack the four axes of `sources`. The third extent is discarded: later cells
# reshape reconstructions to (length, length, n_channels), i.e. they treat the
# cutouts as square.
n_sources, length, _, n_channels = np.shape(sources)

## Plot image, mask, and sources

In [None]:
# Side by side: the stacked image (left) and the boolean mask (right).
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=[20, 10])
ax_img.imshow(img)
ax_mask.imshow(mask)

In [None]:
# Draw a 4x4 grid of random cutout indices (with replacement) for the gallery below.
inds = np.random.randint(sources.shape[0], size=(4, 4))

In [None]:
# 4x4 gallery of the randomly chosen cutouts, each titled with the catalogue
# values of the row at the same position.
fig, axs = plt.subplots(4, 4, figsize=[10, 10])
for ax, ind in zip(axs.ravel(), inds.ravel()):
    row = df.iloc[ind][['y_mean', 'x_mean', 'source_size']].astype(int)
    y_mean, x_mean, source_size = row
    ax.set_title(f'y:{y_mean}, x:{x_mean}\n size:{source_size}, label:{ind}')
    ax.imshow(sources[ind])
    ax.axis('off')
fig.tight_layout()

## Stats

In [None]:
# For each reduction, collapse the cutout axis and show the resulting image:
# first all channels together, then each channel with its own colormap.
for func in (np.mean, np.std, np.max, np.min):
    sources_func = func(sources, axis=0)
    fig, axs = plt.subplots(1, 4, figsize=[20, 5])
    axs[0].set_title(func.__name__)
    axs[0].imshow(sources_func)
    for channel, cmap in enumerate(('Reds', 'Greens', 'Blues')):
        axs[channel + 1].imshow(sources_func[:, :, channel], cmap=cmap)
    plt.show()

In [None]:
# For each reduction, histogram one value per cutout: reduced over every pixel
# and channel (left), and reduced over pixels separately per channel (right).
for func in (np.mean, np.std, np.max, np.min):
    sources_func = func(sources, axis=(1, 2, 3))
    sources_func_color = func(sources, axis=(1, 2))
    fig, (ax_all, ax_rgb) = plt.subplots(1, 2, figsize=[20, 5])
    ax_all.set_title(func.__name__)
    ax_all.hist(sources_func, bins=50)
    for channel, color in enumerate(('red', 'green', 'blue')):
        ax_rgb.hist(sources_func_color[:, channel], bins=50, color=color, alpha=0.5)
    plt.show()

In [None]:
# Per-pixel maximum over all cutouts, drawn with bicubic interpolation.
plt.imshow(sources.max(axis=0), interpolation='bicubic')

## PCA

### Projection

In [None]:
# Number of principal components to keep; reused below to reshape the
# components into eigenimages and to loop over the per-component KDEs.
n = 128
# whiten=True is requested, so the projected components are rescaled by the estimator.
pca = PCA(n, whiten=True)

In [None]:
# Flatten every cutout to a single feature vector, then fit the PCA and project.
pca_sources = pca.fit_transform(np.reshape(sources, (sources.shape[0], -1)))

In [None]:
# Shape of the projected data.
np.shape(pca_sources)

In [None]:
# Total fraction of the variance captured by the retained components.
np.sum(pca.explained_variance_ratio_)

In [None]:
# Explained-variance ratio per component, on a logarithmic y axis.
plt.plot(pca.explained_variance_ratio_)
plt.gca().set_yscale('log')

In [None]:
# All cutouts in the plane of the first two principal components.
plt.scatter(*pca_sources[:, :2].T, s=1, alpha=0.1)

### Inverse transform

In [None]:
# Reconstruct one randomly chosen cutout from its PCA coefficients and compare
# the reconstruction with the original, overall and per colour channel.
ind = np.random.randint(0, n_sources)
source = sources[ind]
source_inv = pca.inverse_transform(pca_sources[ind]).reshape(length,length,n_channels)
res = source - source_inv

fig, axs = plt.subplots(2,3,figsize=[15,10])
axs[0,0].set_title('Source')
axs[0,1].set_title('Inverse Transform')
axs[0,2].set_title('Squared Residuals')
axs[1,0].set_title('Residuals (scatter)')
axs[1,1].set_title('Residuals (hist)')
axs[1,2].set_title('Correlation')
axs[1,2].set_xlabel('Source')
axs[1,2].set_ylabel('Inverse Transform')

axs[0,0].imshow(source)
# The reconstruction is not constrained to [0, 1]; clip it explicitly for
# display (imshow clips float RGB data to that range anyway, with a warning).
axs[0,1].imshow(np.clip(source_inv, 0, 1))
axs[0,2].imshow(res ** 2)
# Identity line as a reference for the source-vs-reconstruction scatter.
axs[1,2].plot([0,1], [0, 1], color='k')

# Switch the grids on once and explicitly. A bare grid() call toggles the
# grid, so calling it once per channel only ended up "on" because the loop
# runs an odd number of times.
for ax in axs[1]:
    ax.grid(True)

for i, color in enumerate('rgb'):
    source_ch = source[:, :, i].flatten()
    source_inv_ch = source_inv[:, :, i].flatten()
    res_ch = res[:, :, i].flatten()
    axs[1,0].scatter(np.arange(length**2), res_ch, s=1, alpha=0.1, color=color)
    axs[1,1].hist(res_ch, bins=100, alpha=0.5, color=color)
    # Squared Pearson correlation between original and reconstructed pixels.
    r2_color = np.corrcoef(source_ch, source_inv_ch)[0,1]**2
    axs[1,2].scatter(source_ch, source_inv_ch,
                     s=1, alpha=0.1, color=color, label=fr'$R^2$ = {r2_color:.2f}')
axs[1,2].legend()
fig.tight_layout()

# Summary statistics over all pixels and channels of this one cutout.
res_mean = res.mean()
res_std = res.std()
mse = np.mean(res ** 2)
r2 = np.corrcoef(source.flatten(), source_inv.flatten())[0,1]**2
print (f'Residuals: {res_mean:.5f} +/- {res_std:.5f}')
print (f'MSE: {mse:.5f}')
print (f'R2: {r2:.5f}')

### PC Space interpolation

In [None]:
# Linearly interpolate between the PCA coefficients of two random cutouts and
# map every intermediate point back to image space.
n_box = 16
ind1, ind2 = np.random.randint(n_sources, size=2)
pc_interp = np.linspace(pca_sources[ind1], pca_sources[ind2], n_box)
pc_interp_inv = pca.inverse_transform(pc_interp).reshape(n_box, length, length, n_channels)

# Show the interpolation steps row by row in a 4x4 grid.
fig, axs = plt.subplots(4, 4, figsize=[10, 10])
for ax, frame in zip(axs.flat, pc_interp_inv):
    ax.imshow(frame)
    ax.axis('off')
fig.tight_layout()

### Eigenimages

In [None]:
# View every principal axis as an image with the same layout as a cutout.
eigenimages = np.reshape(pca.components_, (n, length, length, n_channels))

In [None]:
# Min-max rescale over the whole stack (one global minimum and range, not one
# per eigenimage) so the values lie in [0, 1] for RGB display.
eigenimages_scale = (eigenimages - eigenimages.min()) / np.ptp(eigenimages)

In [None]:
# First 16 eigenimages: the globally rescaled RGB view, followed by the
# unscaled values of each channel with its own colormap.
for i, ei in enumerate(eigenimages_scale[:16]):
    fig, axs = plt.subplots(1, 4, figsize=[20, 5])
    axs[0].set_title(i)
    axs[0].imshow(ei)
    for channel, cmap in enumerate(('Reds', 'Greens', 'Blues')):
        axs[channel + 1].imshow(eigenimages[i, :, :, channel], cmap=cmap)
    plt.show()

## KMeans of PCs

In [None]:
# Number of clusters; also drives the per-cluster loops further down.
n_clusters = 10
kmeans = KMeans(n_clusters)

In [None]:
# Cluster on the first 16 principal components only and keep the label
# assigned to each cutout.
kmeans_pred = kmeans.fit(pca_sources[:, :16]).labels_

In [None]:
# First two principal components, coloured by cluster label.
plt.scatter(*pca_sources[:, :2].T, c=kmeans_pred, cmap='tab10', s=1, alpha=0.5)
plt.colorbar()

In [None]:
# Per cluster: where its members sit in the PC0/PC1 plane, the mean cutout,
# and eight randomly chosen member cutouts.
for i in range(n_clusters):
    mask_km = kmeans_pred == i
    # Boolean indexing copies the selected cutouts; do it once per cluster
    # rather than once per displayed image.
    sources_km = sources[mask_km]

    fig, axs = plt.subplots(1,2,figsize=[10,5])
    axs[0].scatter(pca_sources[mask_km, 0], pca_sources[mask_km, 1], color= f'C{i}')
    axs[1].imshow(np.mean(sources_km, axis=0))
    # Fixed limits so the panels of different clusters are directly comparable.
    axs[0].set_xlim(-2, 7)
    axs[0].set_ylim(-5, 9)
    plt.show()

    # Use a cluster-local name: the global `inds` is the 4x4 index grid used by
    # the gallery cell near the top and must not be overwritten here.
    inds_km = np.random.randint(0, len(sources_km), 8)
    fig, axs = plt.subplots(1,8,figsize=[40,5])
    for ax, ind_km in zip(axs, inds_km):
        ax.imshow(sources_km[ind_km])
        ax.axis('off')
    plt.show()

## KDE of PCs

### Globally

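- A bandwidth of `bw=0.1` is too narrow: the resulting KDE is not smooth enough to sample the space between training points adequately.
- A full-dimensional KDE also suffers from the curse of dimensionality: with such a narrow kernel, a point sampled in the 128-dimensional PC space will lie very close to an existing training sample.
- Solution:
  - Sample the first few principal components from a KDE.
  - Sample the remaining principal components independently from a standard normal distribution (as the per-component KDEs suggest).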

In [None]:
# KDE bandwidth; reused by the KernelDensity fits in later cells.
bw = 0.1
fig, axs = plt.subplots(1,2,figsize=[10,5])
# Left: 1-D KDE of every principal component (PC0 in black and labelled).
sns.kdeplot(x=pca_sources[:, 0], ax=axs[0], bw_method=bw, color='k', label='PC0')
for i in range(1,n):
    sns.kdeplot(x=pca_sources[:, i], ax=axs[0], bw_method=bw, color=f'C{i}')
# The x-range does not depend on the component, so set it once after the loop
# instead of on every iteration.
axs[0].set_xlim(-2,2)
axs[0].legend()
# Right: joint KDE of the first two principal components.
sns.kdeplot(x=pca_sources[:, 0], y=pca_sources[:, 1], ax=axs[1], levels=10, bw_method=bw)

In [None]:
# Fit a KDE to the first principal component only (kept 2-D, one column).
# fit() returns the estimator itself, so both names refer to the same object.
kde = KernelDensity(bandwidth=bw)
kde_sources = kde.fit(pca_sources[:, :1])

In [None]:
n_box = 16
# Leading component(s): drawn from the fitted KDE.
kde_samples0 = kde_sources.sample(n_box)
# Remaining components: independent standard-normal draws.
dim1 = pca_sources.shape[1] - kde_samples0.shape[1]
kde_samples1 = np.random.randn(n_box, dim1)
# Join both parts column-wise into full coefficient vectors and map them back
# to image space.
kde_samples = np.hstack((kde_samples0, kde_samples1))
kde_samples_inv = pca.inverse_transform(kde_samples).reshape(n_box, length, length, n_channels)

In [None]:
# For each generated sample, show it next to its four nearest training cutouts,
# measured in the space of the first 16 principal components.
for ind in range (n_box):
    fig, axs = plt.subplots(1,5,figsize=[25,5])
    axs[0].set_title('KDE Sample Inversed')
    axs[0].imshow(kde_samples_inv[ind])
    axs[0].axis('off')
    # Euclidean distance. The title reports it as a "distance", so take the
    # norm rather than the sum of squares; the neighbour ranking is unchanged.
    dist = np.linalg.norm(pca_sources[:, :16] - kde_samples[ind, :16], axis=1)
    inds_min = np.argsort(dist)[:4]
    for ax, ind_min in zip(axs[1:], inds_min):
        ax.set_title(f'Nearest Neighbor\n distance = {dist[ind_min]:.3f}')
        ax.imshow(sources[ind_min])
        ax.axis('off')
    fig.tight_layout()
    plt.show()

### By Cluster

- This sampling method is weaker when the per-cluster distributions are not smooth or not approximately normal.

In [None]:
# One column per principal component ('pca0', 'pca1', ...) plus the cluster label.
col = [f'pca{i}' for i in range(pca_sources.shape[1])]
df_pca_kmeans = pd.DataFrame(pca_sources, columns=col).assign(cluster=kmeans_pred)

In [None]:
# Per-cluster KDEs: PC0 alone (left) and PC0 vs PC1 (right), one colour per cluster.
fig, (ax_1d, ax_2d) = plt.subplots(1, 2, figsize=[10, 5])
for i in range(n_clusters):
    mask_km = kmeans_pred == i
    df_km = df_pca_kmeans[mask_km]
    cluster_color = f'C{i}'
    sns.kdeplot(data=df_km, x='pca0', ax=ax_1d, color=cluster_color)
    sns.kdeplot(data=df_km, x='pca0', y='pca1', ax=ax_2d, color=cluster_color, levels=5)

In [None]:
# Boolean selector for the members of cluster 5, the cluster examined below.
mask_km = np.equal(kmeans_pred, 5)

In [None]:
# Fit a KDE to the first 8 principal components of the selected cluster only.
# fit() returns the estimator itself, so both names refer to the same object.
kde = KernelDensity(bandwidth=bw)
kde_sources = kde.fit(pca_sources[mask_km, :8])

In [None]:
n_box = 16
# Leading components come from the cluster KDE ...
kde_samples0 = kde_sources.sample(n_box)
# ... and every remaining component is an independent standard-normal draw.
dim1 = pca_sources.shape[1] - kde_samples0.shape[1]
kde_samples1 = np.random.randn(n_box, dim1)
# Assemble full coefficient vectors, then reconstruct the images.
kde_samples = np.hstack((kde_samples0, kde_samples1))
kde_samples_inv = pca.inverse_transform(kde_samples).reshape(n_box, length, length, n_channels)

In [None]:
# For each sample generated from the cluster KDE, show it next to its four
# nearest cutouts from the same cluster, measured in the space of the first 16
# principal components.
# Both masked selections are loop-invariant copies, so build them once up front.
pca_km = pca_sources[mask_km][:, :16]
sources_km = sources[mask_km]
for ind in range (n_box):
    fig, axs = plt.subplots(1,5,figsize=[25,5])
    axs[0].set_title('KDE Sample Inversed')
    axs[0].imshow(kde_samples_inv[ind])
    axs[0].axis('off')
    # Euclidean distance. The title reports it as a "distance", so take the
    # norm rather than the sum of squares; the neighbour ranking is unchanged.
    dist = np.linalg.norm(pca_km - kde_samples[ind, :16], axis=1)
    inds_min = np.argsort(dist)[:4]
    for ax, ind_min in zip(axs[1:], inds_min):
        ax.set_title(f'Nearest Neighbor\n distance = {dist[ind_min]:.3f}')
        ax.imshow(sources_km[ind_min])
        ax.axis('off')
    fig.tight_layout()
    plt.show()