In [3]:
from astropy.nddata.utils import Cutout2D
from astropy.stats import SigmaClip, sigma_clipped_stats
from astropy.table import Table, QTable
from astropy import units as u
from astropy.wcs import WCS
from astropy.io import fits
from pathlib import Path
import os
from os import path
from glob import glob
import sys
import re
import csv
import math
import random
import copy
import subprocess
import threading
import time
import timeit
import importlib
from scipy import ndimage

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import cv2
import imageio

from scipy.ndimage import gaussian_filter, zoom, uniform_filter, convolve
from scipy.stats import kde, lognorm, scoreatpercentile

from skimage import measure, color, data, exposure
from skimage.morphology import skeletonize, disk, ball, binary_dilation
from skimage.filters import meijering, sato, frangi, hessian, rank
from skimage.util import img_as_ubyte
from skimage.util.dtype import dtype_range
from skimage.restoration import inpaint

from reproject import reproject_exact, reproject_adaptive, reproject_interp

from photutils.background import Background2D, MedianBackground, LocalBackground, MMMBackground
from photutils.psf import CircularGaussianPRF, make_psf_model_image, PSFPhotometry, SourceGrouper
from photutils.segmentation import detect_sources, deblend_sources

import AnalysisFuncs as AF
import sys


def cropNanBorder(image, expected_shape):

    """
    Crop the nan border away such that the image fits the expected shape

    Parameters: 
    -image (float): The data to crop. 2D array. 
    -expected_shape (int): the expected shape to crop the data into. 2 element tuple. 

    Returns: 
    - returns the cropped image
    """

    # Create a mask of non-NaN values
    image = np.array(image, dtype=float)  # converts numbers; non-convertible become nan
    mask = ~np.isnan(image)

    # Find the rows and columns with at least one non-NaN value
    non_nan_rows = np.where(mask.any(axis=1))[0]
    non_nan_cols = np.where(mask.any(axis=0))[0]

    # Use the min and max of these indices to slice the image
    cropped_image = image[non_nan_rows.min():non_nan_rows.max() + 1, non_nan_cols.min():non_nan_cols.max() + 1]
    
    # Get the current shape of the cropped image
    current_shape = cropped_image.shape
    
    # Check if the cropped image needs to be resized
    if current_shape != expected_shape:
        # Pad or crop to reach the expected shape
        padded_image = np.full(expected_shape, np.nan)  # Initialize with NaNs
        
        # If cropped_image is larger than expected_shape, trim it
        trim_rows = min(current_shape[0], expected_shape[0])
        trim_cols = min(current_shape[1], expected_shape[1])
        
        # Center the cropped image within the expected shape
        start_row = (expected_shape[0] - trim_rows) // 2
        start_col = (expected_shape[1] - trim_cols) // 2
        
        # Place the trimmed or centered cropped_image in the padded_image
        padded_image[start_row:start_row + trim_rows, start_col:start_col + trim_cols] = \
            cropped_image[:trim_rows, :trim_cols]
        
        return padded_image
    
    return cropped_image

def reprojectWrapper(inData, inHeader, OutputHeader, OutputData):
    """
    Reproject data from one fits file to another. Take OrigData and OrigHeader and reproject into the frame of OutputHeader and OutputData.

    Parameters:
    - OrigData: Data to reproject
    - OrigHeader: Header of data to reproject
    - OutputHeader: Header of data to reproject into
    - OutputData: Data to reproject into, used only for shape
    """

    reprojected_data, _ = reproject_interp((inData, inHeader), OutputHeader, shape_out=(np.shape(OutputData)))
    # reprojected_data = cropNanBorder(reprojected_data, np.shape(OutputData))

    return reprojected_data

In [None]:


BlockFactor = 4

fits_path = r"C:\Users\jhoffm72\Documents\FilPHANGS\Data\ngc0628_F770W\Composites\ngc0628_F770W_JWST_Emission_starsub_CDDss0256pc.fits_Composites.fits"
block_path = r"C:\Users\jhoffm72\Documents\FilPHANGS\Data\ngc0628_F770W\BkgSubDivRMS\ngc0628_F770W_JWST_Emission_starsub_CDDss0064pc.fits_BkgSubDivRMS.fits"


with fits.open(fits_path, ignore_missing=True) as hdul:
    OrigData = np.array(hdul[0].data)  # Assuming the image data is in the primary HDU 
    data = OrigData
    OrigHeader = hdul[0].header
    wcs_1 = WCS(OrigHeader)
data[np.isnan(data)] = 0

with fits.open(block_path, ignore_missing=True) as hdul:
    BlockData = np.array(hdul[0].data)  # Assuming the image data is in the primary HDU 
    BlockHeader = hdul[0].header
    wcs_1 = WCS(BlockHeader)

coords_data = data



#Step 6: Apply blocking to speed up PSF fitting
if BlockFactor != 0:
    # BlockData = np.zeros((int(data.shape[0] / self.BlockFactor), int(data.shape[1] / self.BlockFactor))) 


    data = reprojectWrapper(data, OrigHeader, BlockHeader, BlockData) #use original data to avoid blurring from reproject_interp


    # coords_data = self.reprojectWrapper(coords_data, self.OrigHeader, self.BlockHeader, self.BlockData) #use original data to avoid blurring from reproject_interp
    coords_data, _ = reproject_exact((coords_data, OrigHeader), BlockHeader, shape_out= BlockData.shape)
    # coords_data = cropNanBorder(coords_data, BlockData.shape)


    #To Do: Save this stage of coords data that is reprojected down and skeletonized for debugging
    hdu = fits.PrimaryHDU(coords_data, header=BlockHeader)
    try:
        hdu.writeto('composite_blocked.fits', overwrite=True)
    except Exception as e:
        print('write failed:', e)
else:
    print(f'na')


print('processing coords data')
#Step 3: Process the composite image
kernel_size = 3
kernel = np.ones((kernel_size, kernel_size), np.uint8)
dilated_image = cv2.morphologyEx(coords_data, cv2.MORPH_CLOSE, kernel)
dilated_image = skeletonize(dilated_image.astype(bool))

hdu = fits.PrimaryHDU(coords_data.astype(np.uint8), header=BlockHeader)
try:
    hdu.writeto('RdyForJunctionRemoval.fits', overwrite=True)
except Exception as e:
    print('write failed:', e)

#Step 4: Remove junctions
fil_centers = copy.deepcopy(coords_data)
junctions = AF.getSkeletonIntersection(np.array(fil_centers*255))
IntersectsRemoved = AF.removeJunctions(junctions, fil_centers, dot_size = 1) #check intersects removed
IntersectsRemoved[IntersectsRemoved > 0] = 1
IntersectsRemoved[IntersectsRemoved < 0] = 0
fil_centers = IntersectsRemoved

rep_centers = fil_centers.copy()


#reproject this processed image back and detect sources
rep_centers, _ = reproject_exact((rep_centers , BlockHeader), OrigHeader, shape_out=OrigData.shape) #nearest neighbor
rep_centers  = cropNanBorder(rep_centers, OrigData.shape)
rep_centers[rep_centers > 0] = 1

hdu = fits.PrimaryHDU(rep_centers, header=OrigHeader)
try:
    hdu.writeto('processedImageRdyForSegmentation256.fits', overwrite=True)
except Exception as e:
    print('write failed:', e)

#Step 7: Create a filament dictionary 
print('creating dictionary')
label_val = 3
Scale = 256
Scale = float(Scale)

min_area =  int(8*16**2/(5.24**2)) #aspect ratio of 8, all images used on blocked image of 16pc character. 

segment_map = detect_sources(rep_centers, threshold=.5, npixels=min_area)
segm_deblend = deblend_sources(rep_centers, segment_map, npixels=min_area, nlevels=32, contrast=0.001,progress_bar=False)
imgNew = np.zeros_like(OrigData, dtype = float)

hdu = fits.PrimaryHDU(segm_deblend, header=OrigHeader)
hdu.writeto(f'SegmImgForDicCreation.fits', overwrite=True)
print("Saved reprojected filament map to imgNew_reprojected.fits")

for label in segm_deblend.labels:
    mask = segm_deblend.data == label # Find pixels that belong to this segment
    imgNew[mask] = label_val
    label_val+=10

# Create a labeled mask for all filaments
segment_info_reprojected = {}
imgNew = np.rint(imgNew).astype(int)

print(f'max label is: {label_val}, number of filaments: {label_val//10}')

# Extract coordinates for each label
for lab in range(3, label_val + 1, 10):
    white_mask = (imgNew >= lab -1) & (imgNew <= lab + 1)
    if not np.any(white_mask):
        continue 
    coords = np.argwhere(white_mask)
    img_skel = skeletonize(white_mask.astype(bool))
    length = len(np.argwhere(img_skel)) #determine length from skeletonized filament
    coords_list = [(int(x), int(y)) for y, x in coords]
    segment_info_reprojected[lab] = (coords_list, length)

print('Dictionary reprojected.')
print(np.max(segment_info_reprojected.keys()))

# Debugging
img_bool = np.nan_to_num(imgNew).astype(bool)

# Skeletonize
skel_img = skeletonize(img_bool)
hdu = fits.PrimaryHDU(skel_img.astype(np.uint8), header=OrigHeader)
hdu.writeto(f'imgNewBlockedRdyForPSF.fits', overwrite=True)
print("Saved reprojected filament map to imgNew_reprojected.fits")

Set DATE-AVG to '2022-07-17T12:01:53.586' from MJD-AVG.
Set DATE-END to '2022-07-17T12:54:46.016' from MJD-END'. [astropy.wcs.wcs]
Set OBSGEO-B to   -37.754891 from OBSGEO-[XYZ].
Set OBSGEO-H to 1738895745.206 from OBSGEO-[XYZ]'. [astropy.wcs.wcs]


processing coords data


Set DATE-AVG to '2022-07-17T12:01:53.586' from MJD-AVG.
Set DATE-END to '2022-07-17T12:54:46.016' from MJD-END'. [astropy.wcs.wcs]
Set OBSGEO-B to   -37.754891 from OBSGEO-[XYZ].
Set OBSGEO-H to 1738895745.206 from OBSGEO-[XYZ]'. [astropy.wcs.wcs]


creating dictionary
Saved reprojected filament map to imgNew_reprojected.fits
max label is: 533, number of filaments: 53
Dictionary reprojected.
dict_keys([3, 13, 23, 33, 43, 53, 63, 73, 83, 93, 103, 113, 123, 133, 143, 153, 163, 173, 183, 193, 203, 213, 223, 233, 243, 253, 263, 273, 283, 293, 303, 313, 323, 333, 343, 353, 363, 373, 383, 393, 403, 413, 423, 433, 443, 453, 463, 473, 483, 493, 503, 513, 523])
Saved reprojected filament map to imgNew_reprojected.fits
