In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import pandas as pd
from scipy.stats import median_abs_deviation as med_abs_dev
from scipy.ndimage import label
from imageio.v3 import imread
from skimage.transform import resize
from skimage.filters import gaussian
from tqdm import tqdm

In [None]:
def label_cmap (labels):
    """
    Build a reproducible random colormap for a labeled image.

    One color is generated per integer in ``0 .. max(labels)``. Hue,
    saturation and value are drawn from a fixed-seed generator, so the same
    label count always yields the same colors. Entry 0 is forced to black.

    Source:
    https://github.com/nmiles2718/hst_cosmic_rays

    Parameters
    ----------
    labels : array_like
        Label image; only its maximum value is used.

    Returns
    -------
    matplotlib.colors.ListedColormap
        Colormap with ``max(labels) + 1`` entries.
    """
    ncolors = int(np.max(labels)) + 1
    prng = np.random.RandomState(1234)
    # Draw order (h, then s, then v) determines the colors; keep it fixed.
    h = prng.uniform(low=0.0, high=1.0, size=ncolors)
    s = prng.uniform(low=0.2, high=0.7, size=ncolors)
    v = prng.uniform(low=0.5, high=1.0, size=ncolors)

    # Stack as (ncolors, 3) so the array stays 2-D even when there is a single
    # color; squeezing a (1, 1, 3) array would collapse it to shape (3,) and
    # break the row assignment below.
    hsv = np.column_stack((h, s, v))
    rgb = colors.hsv_to_rgb(hsv)

    # First entry is black.
    rgb[0] = (0, 0, 0)
    cmap = colors.ListedColormap(rgb)

    return cmap

## Save image by channel
The PowerPoint has samples of Webb public release images. Download an image from the [Webb site](https://webbtelescope.org/images) and save each channel to its own file (if the channels need to be pushed to a GitHub repo). This example uses the JWST Advanced Deep Extragalactic Survey (JADES) image taken by NIRCam (released June 2023).

In [None]:
#img = imread('STScI-01H1Q2KSWH9JQW20MA3WVQ48RS.tif')
#np.savez_compressed('images/jades_r.npz', channel=img[:, :, 0])
#np.savez_compressed('images/jades_g.npz', channel=img[:, :, 1])
#np.savez_compressed('images/jades_b.npz', channel=img[:, :, 2])

## Load and plot image/histogram

In [None]:
# Read the 'channel' array from each per-band archive. np.load returns an
# NpzFile that keeps the underlying zip file open, so use it as a context
# manager to close the handle once the array has been read.
channels = {}
for band in ('r', 'g', 'b'):
    with np.load(f'images/jades_{band}.npz') as archive:
        channels[band] = archive['channel']

img_r, img_g, img_b = channels['r'], channels['g'], channels['b']

In [None]:
# Combine the three channel planes into one array with the channel on the
# last axis: (rows, cols, 3) in r, g, b order.
image = np.stack((img_r, img_g, img_b), axis=-1)

In [None]:
# Bounds of the preview window (rows are y, columns are x).
ymin, ymax, xmin, xmax = 4000, 5000, 2000, 3000
img = image[ymin:ymax, xmin:xmax]

fig, axs = plt.subplots(1, 2, figsize=[20, 10], dpi=100)
full_ax, crop_ax = axs
full_ax.imshow(image)

# Outline the window on the full frame: top, bottom, left, right edges.
box_edges = (
    ([xmin, xmax], [ymin, ymin]),
    ([xmin, xmax], [ymax, ymax]),
    ([xmin, xmin], [ymin, ymax]),
    ([xmax, xmax], [ymin, ymax]),
)
for edge_x, edge_y in box_edges:
    full_ax.plot(edge_x, edge_y, color='C3')

crop_ax.imshow(img)

In [None]:
fig, axs = plt.subplots(1, 2, figsize=[10, 5])

# Left panel: full frame. Right panel: the current `img`.
# Each panel overlays the r, g, b channel histograms.
for ax, pixels in zip(axs, (image, img)):
    for chan, tint in enumerate('rgb'):
        ax.hist(pixels[:, :, chan].flatten(), bins=100, alpha=0.5, color=tint)

# Log-scale the counts on both panels.
for ax in axs:
    ax.set_yscale('log')

## Detect sources by thresholding a Gaussian smoothed image

- Raw pixel thresholding (more noise gets detected)
  - mean + 1 * std ~ 95th percentile
  - median + 3 * mad ~ 90th percentile
- Sobel edge detection (more holes and less spherical)
- Blob detection using Laplacian of Gaussian (only detects positions, not accurate enough for size)
- Adaptive thresholding (opencv implementation is difficult to optimize; regular thresholding should suffice since background is constant and black)
  - [Example](https://docs.opencv.org/4.x/d7/d4d/tutorial_py_thresholding.html)
  - [Docstring](https://docs.opencv.org/4.x/d7/d1b/group__imgproc__misc.html#ga72b913f352e4a1b1b397736707afcde3)
  - [Source code](https://github.com/opencv/opencv/blob/4.x/modules/imgproc/src/thresh.cpp#L1673)
- The Gaussian smoothing sigma (3) and the threshold percentile (95%) are hyperparameters
  - High sigma --> more smoothing --> fainter sources are lost
  - High threshold --> more pixels rejected --> fainter sources are lost
  - A sigma of 3 and a 95% threshold were sufficient for our purposes
    - ~10k cutouts
    - log(Source size) distribution ~ normal around 2 with a skew
  - May be worthwhile to find what sources are lost at < 3 sigma and < 95% threshold
  - Can also use binary morphology (dilation/erosion) to remove small "noise"

In [None]:
# Rebind `img` to the full frame (it previously held the cropped preview
# window), so the detection steps below operate on the whole image.
img = image

In [None]:
# Divide pixel values by 255 (maps to [0, 1] assuming 8-bit channels — TODO confirm dtype).
img_scale = img / 255

In [None]:
# Smooth each color channel independently. The channel axis is stated
# explicitly: with the default channel_axis=None, recent scikit-image releases
# treat the array as a single 3-D volume and would also blur across r/g/b.
img_scale_gauss = gaussian(img_scale, sigma=3, channel_axis=-1)

In [None]:
# 95th percentile taken over both spatial axes: one threshold per channel.
thresh_gauss = np.percentile(img_scale_gauss, q=95, axis=(0, 1))

In [None]:
# Per-channel boolean masks: the thresholds broadcast along the last axis.
masks_gauss = np.greater(img_scale_gauss, thresh_gauss)

In [None]:
# Keep a pixel only if it is above threshold in all three channels.
mask_gauss = masks_gauss.all(axis=2)

In [None]:
# Smoothed image on the left, combined detection mask on the right.
fig, axs = plt.subplots(1, 2, figsize=[10, 5])
smooth_ax, mask_ax = axs
smooth_ax.imshow(img_scale_gauss)
mask_ax.imshow(mask_gauss)

In [None]:
# A 3x3 structure of ones connects diagonal neighbours as well as edge neighbours.
connectivity = np.ones((3, 3))
labels, num_feat = label(mask_gauss, structure=connectivity)

# Pixel count per label value; index 0 (pixels not in the mask) is skipped below.
sizes = np.bincount(labels.ravel())
plt.hist(np.log10(sizes[1:]), bins=50)
plt.yscale('log')
num_feat

In [None]:
# Add a trailing length-1 axis to the 2-D mask, then broadcast it to the
# shape of `img` so it can be applied to every channel.
mask = np.broadcast_to(mask_gauss[:, :, np.newaxis], img.shape)

In [None]:
fig, axs = plt.subplots(1, 1, dpi=200)
# Zero out every detected pixel, leaving only what the mask did not select.
undetected = img * np.logical_not(mask)
axs.imshow(undetected)

## Make cutouts

In [None]:
def make_source_cutout(source_label, labels, img, size=128, pad=10, plot=False):
    """
    Cut a padded box around one labeled source and resize it to a square.

    Parameters
    ----------
    source_label : int
        Value in ``labels`` that identifies the source.
    labels : ndarray
        2-D label image, indexed as (row, column).
    img : ndarray
        Image to cut from; its first two axes are sliced with the bounds
        derived from ``labels``.
    size : int, optional
        Side length of the resized cutout.
    pad : int, optional
        Number of pixels added on every side of the source bounding box
        (fewer where the box reaches the image border).
    plot : bool, optional
        If True, show the label cutout, the raw cutout and the resized cutout.

    Returns
    -------
    loc : list
        ``[y_mean, x_mean, y_start, y_end, x_start, x_end, source_size]``.
        The start/end values are the slice bounds actually used (end
        exclusive) and ``source_size`` is the number of labeled pixels.
    source_img_resize : ndarray
        Cutout resized to ``(size, size, 3)``.

    Raises
    ------
    ValueError
        If ``source_label`` does not occur in ``labels``.
    """
    # Get positions/size of labeled source in image
    y, x = np.where(labels == source_label)
    source_size = y.shape[0]
    if source_size == 0:
        raise ValueError(f'label {source_label} not found in labels')

    # Slice bounds of the padded bounding box. The end index is exclusive, so
    # "+ 1" is needed to keep `pad` pixels after the last source pixel (same
    # margin as before the first one). Both ends are clipped to the image so
    # the recorded bounds match the pixels actually cut out.
    y_start = max(y.min() - pad, 0)
    y_end = min(y.max() + pad + 1, img.shape[0])
    x_start = max(x.min() - pad, 0)
    x_end = min(x.max() + pad + 1, img.shape[1])
    loc = [y.mean(), x.mean(), y_start, y_end, x_start, x_end, source_size]

    # Retrieve and resize source cutout (order=3: bicubic interpolation)
    source_img = img[y_start:y_end, x_start:x_end]
    source_img_resize = resize(source_img, (size, size, 3), order=3)

    # Plot
    if plot:
        source_labels = labels[y_start:y_end, x_start:x_end]
        fig, axs = plt.subplots(1, 3, figsize=[30, 10])
        title = f'y:{int(loc[0])}, x:{int(loc[1])}, size:{source_size}, label:{source_label}'
        axs[0].set_title(title)
        axs[0].imshow(source_labels)
        axs[1].imshow(source_img)
        axs[2].imshow(source_img_resize)
        plt.show()

    return loc, source_img_resize

In [None]:
# Cut out every labeled source (labels run from 1 to num_feat), then split
# the (loc, cutout) pairs into two parallel lists.
cutouts = [
    make_source_cutout(source_id, labels, img, plot=False)
    for source_id in tqdm(range(1, num_feat + 1))
]
locs = [loc for loc, _ in cutouts]
sources = [cutout for _, cutout in cutouts]

## Save mask, metadata, and cutouts

In [None]:
# Save the label image under the key 'mask'.
# NOTE(review): this is written to the working directory, not 'images/' like the other archives — confirm intended.
np.savez_compressed('mask_labels.npz', mask=labels)

In [None]:
# Column order mirrors the `loc` list returned by make_source_cutout. The
# *_min / *_max columns hold the padded cutout bounds (y_start, y_end,
# x_start, x_end), not the bare source extent.
loc_columns = ['y_mean', 'x_mean', 'y_min', 'y_max', 'x_min', 'x_max', 'source_size']
df = pd.DataFrame(np.array(locs), columns=loc_columns)

In [None]:
# Write the per-source metadata table, omitting the DataFrame index column.
df.to_csv('jades.csv', index=False)

In [None]:
# Combine the list of equally sized cutouts into a single array (source index first).
sources = np.asarray(sources)

In [None]:
# Convert to 8-bit. astype() truncates toward zero, so a value such as
# 199.99999 would become 199; round to the nearest integer first, and clip so
# nothing outside 0..255 can wrap around in the uint8 cast.
# Assumes the cutouts are floats in [0, 1] — TODO confirm.
sources_int = np.clip(np.rint(sources * 255), 0, 255).astype(np.uint8)

In [None]:
#np.savez_compressed('images/jades_sources_r.npz', sources=sources_int[:, :, :, 0])
#np.savez_compressed('images/jades_sources_g.npz', sources=sources_int[:, :, :, 1])
#np.savez_compressed('images/jades_sources_b.npz', sources=sources_int[:, :, :, 2])