## Testing Image Segmentation with Napari

In [1]:
# math, array manipulation, etc.
import numpy as np

# timing
from timeit import default_timer

import astropy.io.fits as fits
from astropy.table import Table                    # Table data structure
import astropy.units as u

# necessary utilities from scipy, astropy and photutils
from scipy.optimize import differential_evolution
from scipy.ndimage import maximum_filter, gaussian_filter
from astropy.modeling import functional_models
from astropy.convolution import convolve
from photutils import background

# plots
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

#%matplotlib inline

# MCMC sampling package
import emcee

import os

import sys

from photutils.background import Background2D
from astropy.stats import SigmaClip
from photutils.background import StdBackgroundRMS
from photutils.segmentation import deblend_sources, SegmentationImage, detect_sources

from astropy.convolution import Gaussian2DKernel
from astropy.stats import gaussian_fwhm_to_sigma

import napari

In [2]:
# the following commands make plots look better
def plot_prettier(dpi=200, fontsize=10, usetex=True):
    '''
    Apply a common set of matplotlib rc settings for the notebook's figures.

    Parameters
    ----------
    dpi : int
        resolution used both for on-screen figures and for saved figures
    fontsize : int
        base font size
    usetex : bool
        whether matplotlib renders text through LaTeX; pass False on a
        machine without a LaTeX installation, where True generates an error
    '''
    plt.rcParams['figure.dpi'] = dpi
    plt.rc("savefig", dpi=dpi)
    plt.rc('font', size=fontsize)
    # ticks point into the axes, with the same padding on major and minor ticks
    plt.rc('xtick', direction='in')
    plt.rc('ytick', direction='in')
    plt.rc('xtick.major', pad=5)
    plt.rc('xtick.minor', pad=5)
    plt.rc('ytick.major', pad=5)
    plt.rc('ytick.minor', pad=5)
    plt.rc('lines', dotted_pattern=[2., 2.])
    plt.rc('text', usetex=usetex)

plot_prettier()

In [3]:
def create_mask(img_r, img_g, num=3):
    '''
    Builds a mask of every detected source other than the one covering the
    central pixel (taken to be the lens/galaxy), and applies it to both bands.

    Detection and deblending are run on img_g only; the resulting mask is then
    used to index img_r as well, so both images must have the same shape.

    Parameters
    ----------
    img_r, img_g : 2-D arrays
        r-band and g-band images
    num : int
        size passed to scipy.ndimage.maximum_filter; the segments of the
        non-central sources are grown by this filter before masking

    Returns
    -------
    mask : 2-D bool array
        True on pixels attributed to (dilated) non-central sources
    gal_copy_r : 2-D array
        copy of img_r with the masked pixels set to 0
    gal_copy_g : 2-D array
        copy of img_g with the masked pixels set to 0
    '''

    # construct a background image w/ background noise
    bkg = Background2D(img_g, box_size=5)
    bkg_img = bkg.background

    # sigma-clipped background RMS of img_g, used to set the detection threshold
    sigma_clip = SigmaClip(sigma=3.0, maxiters=10)
    bkgrms = StdBackgroundRMS(sigma_clip)
    bkgrms_img = bkgrms(img_g)

    # map of thresholds over which sources are detected
    threshold = bkg_img + (0.5 * bkgrms_img)

    # source detection
    sigma = 3.0 * gaussian_fwhm_to_sigma  # FWHM = 3.
    kernel = Gaussian2DKernel(sigma, x_size=3, y_size=3)
    # normalize() works in place and returns None, so it must not be chained
    # onto the constructor (that would hand kernel=None to detect_sources)
    kernel.normalize()
    segm = detect_sources(img_g, threshold, 5, kernel)

    if segm is None:
        # nothing above threshold: there is nothing to mask out
        mask = np.zeros(np.shape(img_g), dtype=bool)
        return mask, img_r.copy(), img_g.copy()

    # deblending sources, looking for saddles between peaks in flux
    segm_deblend = deblend_sources(img_g, segm, npixels=5,
                                   nlevels=32, contrast=0.001)

    seg_data = segm_deblend.data
    label = seg_data[seg_data.shape[0] // 2, seg_data.shape[1] // 2]

    # pixels of the central source; label 0 is background, i.e. no source
    # covers the central pixel and every detection counts as "other"
    if label == 0:
        source = np.zeros(seg_data.shape, dtype=bool)
    else:
        source = (seg_data == label)

    # label map of all other sources, grown by the maximum filter
    other_data = np.where(source, 0, seg_data)
    segm_dilated_arr = maximum_filter(other_data, num)

    # dilated other sources, never overlapping the central source itself
    mask = (segm_dilated_arr > 0) & ~source

    # set the masked pixels to 0 in a copy of each band
    gal_copy_r = img_r.copy()
    gal_copy_r[mask] = 0
    gal_copy_g = img_g.copy()
    gal_copy_g[mask] = 0

    return mask, gal_copy_r, gal_copy_g

In [4]:
def read_fits_table(filename):
    '''
    reads and returns data in a table from a FITS file

    Parameters
    ----------
    filename : str
        path of the FITS file; the table is read from HDU index 1

    Returns
    -------
    data : astropy.table.Table
    '''
    # the context manager closes the file even if reading the HDU raises
    with fits.open(filename) as hdu:
        # second index of the hdu corresponds to the data in my astropy table files
        data = Table(hdu[1].data)

    return data

def read_fits_image(filename):
    '''
    reads and returns an image from a FITS file

    Parameters
    ----------
    filename : str
        path of the FITS file; the image is read from HDU index 0

    Returns
    -------
    data : the data array of the first HDU (the header is not returned)
    '''
    # the context manager closes the file even if reading the HDU raises
    with fits.open(filename) as hdu:
        data = hdu[0].data

    return data

def get_files(path):
    '''
    returns a numpy array with the names of the entries in `path` whose
    name contains '.fits' (names only, in os.listdir order)
    '''
    names = np.array(os.listdir(path))
    # boolean mask, one flag per directory entry
    is_fits = np.fromiter(('.fits' in name for name in names),
                          dtype=bool, count=len(names))
    return names[is_fits]

In [5]:
# load one test cutout in both bands and build its initial mask
# NOTE: the directories below are absolute paths on the author's machine

jr_path = '/Users/aidan/Desktop/sl_project/img_cutouts/sl_jacobs/rband_dr2/'
jg_path = '/Users/aidan/Desktop/sl_project/img_cutouts/sl_jacobs/gband_dr2/'

img_r = read_fits_image(os.path.join(jr_path, 'DESJ001424.2784+004145.4560_r.fits'))
img_g = read_fits_image(os.path.join(jg_path, 'DESJ001424.2784+004145.4560_g.fits'))

# mask of the non-central sources, plus the two masked images
mask, gal_img_r, gal_img_g = create_mask(img_r, img_g)


  from .autonotebook import tqdm as notebook_tqdm
100%|███████████████████████████████████████████| 28/28 [00:00<00:00, 81.84it/s]


In [6]:
# napari segmentation test

### pseudocode
# create two napari layers: an image layer with the raw image and a labels layer with the initial mask
# use the paintbrush tool to add onto the mask label
from skimage import data

In [7]:

# open a napari viewer and show the r-band image as an image layer
viewer = napari.Viewer()
viewer.add_image(img_r, name='r-band')
# overlay the initial mask as a labels layer that can be painted on
labels_layer = viewer.add_labels(mask, name='segmentation')


In [19]:
# Opens a viewer with only the image (no name given, no labels layer added here).
# NOTE(review): this rebinds `viewer`, and the cell's execution count (19) is out
# of order with its neighbours; confirm whether this cell is still needed.
viewer = napari.view_image(img_r)

In [7]:
# Fetch napari's settings object and switch off the `experimental.octree` option.
# NOTE(review): presumably a workaround for a viewer/rendering problem; the reason
# is not recorded in the notebook, so confirm before relying on it.
from napari.settings import get_settings #; print(get_settings().config_path)

settings = get_settings()

settings.experimental.octree = False


  warn(


In [24]:
! napari reset

Traceback (most recent call last):
  File "/Users/aidan/opt/anaconda3/envs/gal-gal-sel/bin/napari", line 10, in <module>
    sys.exit(main())
  File "/Users/aidan/opt/anaconda3/envs/gal-gal-sel/lib/python3.10/site-packages/napari/__main__.py", line 446, in main
    _run()
  File "/Users/aidan/opt/anaconda3/envs/gal-gal-sel/lib/python3.10/site-packages/napari/__main__.py", line 311, in _run
    viewer._window._qt_viewer._qt_open(
  File "/Users/aidan/opt/anaconda3/envs/gal-gal-sel/lib/python3.10/site-packages/napari/_qt/qt_viewer.py", line 754, in _qt_open
    self.viewer.open(
  File "/Users/aidan/opt/anaconda3/envs/gal-gal-sel/lib/python3.10/site-packages/napari/components/viewer_model.py", line 941, in open
    layers = self._open_or_raise_error(
  File "/Users/aidan/opt/anaconda3/envs/gal-gal-sel/lib/python3.10/site-packages/napari/components/viewer_model.py", line 1012, in _open_or_raise_error
    raise NoAvailableReaderError(
napari.errors.reader_errors.NoAvailableReaderError: No 