In [None]:
import warnings
warnings.filterwarnings('ignore')
from astropy.io import fits
import os
from astropy.table import join

import sys
sys.path.append('../')
from modules import dendro_dendro, dendro_misc, dendro_props, dendro_mask

In [None]:
import warnings
warnings.filterwarnings('ignore')

from astropy.table import Table, join, vstack
import numpy as np
import matplotlib.pyplot as plt
import scipy 
from reproject import reproject_interp

import astropy.units as au
from astropy import stats
from astrodendro import Dendrogram, pp_catalog
from astropy.wcs import WCS
from astropy.table import Column

from astropy.io import fits
import aplpy
from tqdm.auto import tqdm

from astropy.convolution import Gaussian2DKernel
from astropy.convolution import convolve

from astropy.wcs import WCS

In [None]:
# Define names and filenames...
# root_dir defaults to the original author's data location; set the environment
# variable GALAXIES_ROOT_DIR to point the notebook at a local copy instead.

galaxy = 'ngc0628'
root_dir = os.environ.get('GALAXIES_ROOT_DIR', '/Users/abarnes/Dropbox/work/Smallprojects/galaxies')
cutout_dir = './analysis/cutouts'
dendro_dir = './analysis/dendro'
cutouts_hdus_dir = './analysis/cutouts_hdus/'
rerun_masking = False  # NOTE(review): not referenced in the visible cells — confirm it is still needed

regions_file = '%s/sample.reg' %cutout_dir  # NOTE(review): not referenced in the visible cells
regions_pickel_file = '%s/sample.pickel' %cutout_dir
sample_table_file = '%s/data_misc/sample_table/phangs_sample_table_v1p6.fits' %root_dir
muscat_table_file = '%s/data_misc/nebulae_catalogue/Nebulae_catalogue_v3.fits' %root_dir

In [None]:
# Load the cached cutout HDUs and the region list (pickle files written by an
# earlier step), then fetch the properties for `galaxy` from the PHANGS sample
# table and the MUSE nebula catalogue via the project helpers.
# NOTE(review): load_pickle presumably unpickles — only point it at trusted files.
hdus_cutouts = dendro_misc.load_pickle(f'{cutout_dir}/hdus_all.pickel')
regions = dendro_misc.load_pickle(regions_pickel_file)

sample_table = dendro_misc.get_galaxyprops(galaxy, sample_table_file)
muscat_table = dendro_misc.get_museprops(galaxy, muscat_table_file)

In [None]:
def get_maskeddata(muscat_hdu, hst_hdu, muscat_id):
    """Mask an HST cutout with a MUSE nebula-catalogue map.

    The catalogue map is reprojected (nearest-neighbour) onto the pixel grid of
    ``hst_hdu`` and used to build three masked copies of the HST image.

    Parameters
    ----------
    muscat_hdu : HDU
        MUSE catalogue cutout. Pixels hold region IDs, with NaN where no region
        is assigned. It must be a float array, because NaN is written into copies
        of it. Not modified.
    hst_hdu : HDU
        HST image cutout. Its header is the reprojection target, and its data
        must be float so it can hold NaN. Not modified.
    muscat_id : scalar
        The catalogue region ID to isolate.

    Returns
    -------
    tuple
        ``(hst_masked_hdu, hst_masked_ones_hdu, hst_maskedall_hdu, muscat_id)``

        - ``hst_masked_hdu``: HST data with every pixel outside region
          ``muscat_id`` set to NaN.
        - ``hst_masked_ones_hdu``: HST data with every pixel outside region
          ``muscat_id`` set to 1.
        - ``hst_maskedall_hdu``: HST data kept only where the catalogue has no
          region (NaN); all catalogued regions are set to NaN. Intended for
          noise measurements.
        - ``muscat_id``: the input ID, passed through unchanged.

        Pixels outside the reprojected MUSE footprint presumably come back as
        NaN from reproject_interp, so they are masked as well.
    """
    # Working copies, so the caller's HDUs are never modified in place.
    hst_masked_hdu = hst_hdu.copy()
    hst_masked_ones_hdu = hst_hdu.copy()
    hst_maskedall_hdu = hst_hdu.copy()

    catalogue = muscat_hdu.data
    # NaN != muscat_id is True, so unassigned pixels also count as "not this region".
    not_this_region = catalogue != muscat_id
    unassigned = np.isnan(catalogue)

    # Map 1: keep only the requested region; every other pixel becomes NaN.
    muscat_region_hdu = muscat_hdu.copy()
    muscat_region_hdu.data[not_this_region] = np.nan

    # Map 2: 1 where no catalogue region exists, NaN inside any region.
    muscat_background_hdu = muscat_hdu.copy()
    muscat_background_hdu.data[unassigned] = 1
    muscat_background_hdu.data[~unassigned] = np.nan

    # Regrid both maps onto the HST pixel grid. Nearest-neighbour avoids
    # blending neighbouring region IDs / flag values together.
    region_on_hst = reproject_interp(muscat_region_hdu, hst_hdu.header,
                                     return_footprint=False, order='nearest-neighbor')
    background_on_hst = reproject_interp(muscat_background_hdu, hst_hdu.header,
                                         return_footprint=False, order='nearest-neighbor')

    # Apply the masks to the HST copies.
    outside_region = np.isnan(region_on_hst)
    hst_masked_hdu.data[outside_region] = np.nan
    hst_masked_ones_hdu.data[outside_region] = 1
    hst_maskedall_hdu.data[np.isnan(background_on_hst)] = np.nan

    return (hst_masked_hdu, hst_masked_ones_hdu, hst_maskedall_hdu, muscat_id)

def get_maskedhdus(hdus, regions, muscat_regionIDs, hstha_hdu_name='hstha_hdu', smooth_stddev=0.5):
    """Smooth and MUSE-mask the HST cutout of every region.

    For each region index ``i``, the function does three things:

    - smooths ``hdus[hstha_hdu_name][i]`` with a Gaussian kernel;
    - masks the result with ``get_maskeddata``, using ``hdus['muscat_hdu'][i]``
      and ``muscat_regionIDs[i]``;
    - appends the outputs to new per-region lists stored in ``hdus``.

    Every region is processed, so index ``i`` of each output list corresponds
    to index ``i`` of the input lists.

    Parameters
    ----------
    hdus : dict
        Per-region HDU lists. Must contain ``'muscat_hdu'`` and
        ``hstha_hdu_name``. It is updated IN PLACE and also returned.
    regions : dict-like
        Only ``len(regions['ra'])`` is used, as the number of regions.
    muscat_regionIDs : sequence
        Catalogue region IDs. Assumed to be aligned index-for-index with the
        cutout lists (TODO confirm against the cutout-generation step).
    hstha_hdu_name : str, optional
        Key of the HST cutouts to process. Also used as the prefix of the
        output keys.
    smooth_stddev : float, optional
        ``x_stddev`` (pixels) of the Gaussian smoothing kernel. Defaults to
        0.5, the value previously hard-coded.

    Returns
    -------
    dict
        ``hdus``, with these keys added (each a list with one entry per region):

        - ``'<hstha_hdu_name>_smooth'``
        - ``'<hstha_hdu_name>_smooth_masked'``
        - ``'<hstha_hdu_name>_smooth_masked_ones'``
        - ``'<hstha_hdu_name>_smooth_maskedall'``
        - ``'musmask_hdu'``: integer map, 1 where the masked HST map is not NaN
          (the selected region) and 0 elsewhere.
    """
    # The kernel is identical for every region, so build it once.
    kernel = Gaussian2DKernel(x_stddev=smooth_stddev)

    hstha_hdu_smooth_arr = []
    hstha_masked_hdu_arr = []
    hst_masked_ones_hdu_arr = []
    hstha_maskedall_hdu_arr = []
    mask_muse_arr = []

    print(f'[INFO] [get_maskedhdus] Getting HST maps masked by MUSE catalouge...')
    for i in tqdm(range(len(regions['ra'])), desc='Masking regions', position=0):

        # Smooth a copy, so the original cutout in `hdus` is left untouched.
        hstha_hdu_smooth = hdus[hstha_hdu_name][i].copy()
        hstha_hdu_smooth.data = convolve(hstha_hdu_smooth.data, kernel)

        # Mask the smoothed map with this region's MUSE catalogue cutout.
        hstha_masked_hdu, hst_masked_ones_hdu, hstha_maskedall_hdu, _ = get_maskeddata(
            hdus['muscat_hdu'][i], hstha_hdu_smooth, muscat_regionIDs[i])

        # Binary region mask on the HST grid: 1 where the masked map is finite.
        mask_muse = fits.PrimaryHDU(~np.isnan(hstha_masked_hdu.data) * 1, hstha_hdu_smooth.header)

        hstha_hdu_smooth_arr.append(hstha_hdu_smooth)
        hstha_masked_hdu_arr.append(hstha_masked_hdu)
        hst_masked_ones_hdu_arr.append(hst_masked_ones_hdu)
        hstha_maskedall_hdu_arr.append(hstha_maskedall_hdu)
        mask_muse_arr.append(mask_muse)

    # Store the per-region results alongside the inputs (in-place update).
    hdus['%s_smooth' % hstha_hdu_name] = hstha_hdu_smooth_arr
    hdus['%s_smooth_masked' % hstha_hdu_name] = hstha_masked_hdu_arr
    hdus['%s_smooth_masked_ones' % hstha_hdu_name] = hst_masked_ones_hdu_arr
    hdus['%s_smooth_maskedall' % hstha_hdu_name] = hstha_maskedall_hdu_arr
    hdus['musmask_hdu'] = mask_muse_arr

    return hdus

In [None]:
# Build the smoothed and MUSE-masked HST maps for every region.
# get_maskedhdus updates hdus_cutouts in place and returns it, so `hdus` and
# `hdus_cutouts` are the same dict object afterwards.

muscat_regionIDs = muscat_table['region_ID']
hdus = get_maskedhdus(hdus=hdus_cutouts, regions=regions, muscat_regionIDs=muscat_regionIDs)

In [None]:
import matplotlib.pyplot as plt

# Quick-look check of one region's smoothed, MUSE-masked HST map.
# origin='lower' follows the FITS convention (first pixel at the bottom-left).
plot_region_index = 0  # index into the per-region lists produced by get_maskedhdus
fig, ax = plt.subplots()
im = ax.imshow(hdus['hstha_hdu_smooth_masked'][plot_region_index].data, origin='lower')
fig.colorbar(im, ax=ax)
ax.set(title=f'hstha_hdu_smooth_masked, region index {plot_region_index}',
       xlabel='x [pix]', ylabel='y [pix]');

In [None]:
# Display the keys now in `hdus`: the original cutout lists plus the
# smoothed/masked outputs added by get_maskedhdus.
hdus.keys()