In [None]:
# Import statements
from spectral_cube import SpectralCube
from astropy import units as u
from astroquery.splatalogue import Splatalogue
import matplotlib.pyplot as plt
from pylab import imshow
import numpy as np
import scipy.ndimage as nd
from lte_modeling_tools import nupper_of_kkms

In [None]:
# Define functions
def get_line(chemical_name, freq_lo, freq_hi, vel):
    tbl = Splatalogue.query_lines(freq_lo-0.1*u.GHz, freq_hi+0.1*u.GHz, 
                                  chemical_name=chemical_name,
                                  energy_max=500, # more lines w/ max energy > 140
                                  energy_type='eu_k',
                                  line_lists=['JPL'],
                                  show_upper_degeneracy=True, 
                                  show_qn_code=True)
    line_freqs = tbl['Meas Freq-GHz(rest frame,redshifted)'].data
    shifted_line_freqs = line_freqs-((vel/299792)*line_freqs) # Shift by velocity
    tbl['Shifted Freq-GHz'] = shifted_line_freqs
    return tbl

def get_subcube(cube, center_freq, slab_width):
    print(center_freq)
    subcube = cube.spectral_slab(center_freq - 0.1*u.GHz, center_freq + 0.1*u.GHz).to(u.K)
    subcube_v = subcube.with_spectral_unit(u.km/u.s, 
                                           rest_value = center_freq,
                                           velocity_convention = 'radio').spectral_slab(-slab_width,
                                                                                        slab_width)
    print(subcube_v)
    return subcube_v

def get_noise_map(cube_noise):
    cube_sclip = cube_noise.sigma_clip_spectrally(3) # Clip values above 3-sigma 
#     mad_std_spectrum_sclip = cube_sclip.mad_std(axis=(1, 2))
#     plt.plot(mad_std_spectrum_sclip.spectral_axis.value, mad_std_spectrum_sclip.value, 
#              drawstyle='steps-mid')
#     plt.xlabel('Velocity (km/s)')
#     plt.ylabel(r' Noise standard deviation $\sigma$ (K)')
#     plt.ylim([0., 0.30]) # Best to extend the range to 0.
#     plt.axhline(0.25, linestyle='--', color='k', linewidth=3, label='A priori noise expectation')
#     plt.legend(frameon=True)
    mad_std_map_sclip = cube_sclip.mad_std(axis=0) # Calculate sigma along the spectral dimension
    return mad_std_map_sclip

def get_signal_mask_scipy(cube_signal, mad_std_map_sclip):
    '''Please pass an already-masked cube to cube_signal'''
    # Make a low and high mask
    low_snr_mask = (cube_signal > 3 * mad_std_map_sclip).include()
    high_snr_mask = (cube_signal > 10 * mad_std_map_sclip).include()
    low_snr_mask = low_snr_mask.compute() # We need to convert from a dask array to a numpy array.
    high_snr_mask = high_snr_mask.compute()
    # Find connected structures
    structure = np.ones((3, 3, 3), dtype=bool)
    low_snr_mask_labels, num_labels = nd.label(low_snr_mask, structure=structure)
    print(f"Initial number of regions found: {num_labels}")
    # From the labels, count the number of pixels within each label.
    num_pixels_in_high_snr_mask = nd.sum(high_snr_mask,
                                         labels=low_snr_mask_labels,
                                         index=range(1, num_labels + 1)) # +1 offset for mask labels
    # Repeat for the high signal mask.
    num_pixels_in_low_snr_mask = nd.sum(low_snr_mask,
                                        labels=low_snr_mask_labels,
                                        index=range(1, num_labels + 1)) # +1 offset for mask labels
    # To preserve the low_snr_mask, we will create a new signal mask where we will remove 
    # regions that do not pass the criteria.
    signal_mask = low_snr_mask
    low_min_pixels = 40
    high_min_pixels = 10
    for num, (high_pix_num, low_pix_num) in enumerate(zip(num_pixels_in_high_snr_mask, 
                                                          num_pixels_in_low_snr_mask)):
        if high_pix_num >= high_min_pixels and low_pix_num >= low_min_pixels:
            # This region passes the criteria. Keep it in the mask.
            continue
        # Remove regions that do not pass the criteria.
        # NOTE: enumerate will start with 0, but the mask labels start at 1
        # We apply a +1 offset to `num` to account for this.
        signal_mask[low_snr_mask_labels == num + 1] = False
    signal_mask_labels, num_labels = nd.label(signal_mask,
                                              structure=structure)
    print(f"Final number of regions found: {num_labels}")
    signal_mask = nd.binary_dilation(signal_mask, structure=structure, iterations=1)
    return signal_mask

def find_outliers(masked_cube, v_thresh):
    masked_moment1 = masked_cube.moment1()
    masked_moment1_outliers = (masked_moment1 > v_thresh*u.km/u.s)|(masked_moment1 < -v_thresh*u.km/u.s)
    imshow(masked_moment1_outliers, origin='lower') 
    # Clumps of outliers might mean they're real, just outside of vel range

# def remove_outliers(masked_cube):
#     '''Remove outliers based on mom0 map after-the-fact (deprecated)'''
#     mom0 = masked_cube.moment0()
#     mom0_mask = mom0 > 1.*u.K*u.km/u.s # Mask pixels with mom0 less than threshold
#     print(f"Found {mom0_mask.sum()} good pixels")
#     masked_cube_no_outliers = masked_cube.with_mask(mom0_mask)
#     return masked_cube_no_outliers

# Noise cube and map

In [None]:
# Get overall cube
results = '/blue/adamginsburg/abulatek/brick/symlinks/imaging_results/'
freq_spw = '146_spw51'
fn = results+'source_ab_'+freq_spw+'_clean_2sigma_n50000_masked_3sigma_pbmask0p18.image'
cube = SpectralCube.read(fn, format='casa_image')
print(cube.shape)

In [None]:
# Get noise subcube
subcube_noise = get_subcube(cube, 146.8*u.GHz, 15.*u.km/u.s)
spectrum = subcube_noise[:, 256, 256]
plt.plot(spectrum.spectral_axis, spectrum.value, drawstyle='steps-mid')
plt.xlabel('Velocity (km/s)')
plt.ylabel('Intensity (K)')

In [None]:
# Get noise map
mad_std_map_sclip = get_noise_map(subcube_noise)
ax = plt.subplot(projection=mad_std_map_sclip.wcs)
im = ax.imshow(mad_std_map_sclip.value, origin='lower', cmap='gray')
cbar = plt.colorbar(im)
cbar.set_label('Intensity (K)')

ax.set_ylabel('Declination')
ax.set_xlabel('Right Ascension')

# Signal cube

In [None]:
# Get lines
tbl = get_line('CH3CN', 147.01*u.GHz, 147.17*u.GHz, 0)
tbl = tbl[tbl['Quantum Number Code'] == 202]
tbl = tbl[::-1]
tbl.show_in_notebook()

In [None]:
# Get signal subcubes
# subcubes = []
# for n in range(8):
center_freq = tbl['Freq-GHz(rest frame,redshifted)'][3]*u.GHz
subcube = get_subcube(cube, center_freq, 10.*u.km/u.s)
# subcubes.append(subcube)

In [None]:
# Look at peak intensity map of signal subcube
# peak_intensity_signal = subcube[0].max(axis = 0) # Take the maximum along the spectral dimension
# imshow(peak_intensity_signal.value, cmap='gray', origin='lower')

### Use Desmond's method of plain masking on the subcube

In [None]:
plain_mask = subcube >= 3 * mad_std_map_sclip
plain_masked_slab = subcube.with_mask(plain_mask)

In [None]:
# Get signal mask
# masked_cubes = []
# for n in range(8):
signal_mask = get_signal_mask_scipy(plain_masked_slab, mad_std_map_sclip) # subcubes[n]
masked_cube = plain_masked_slab.with_mask(signal_mask) # subcubes[n] # This used to be subcube (wrong)
# masked_cubes.append(masked_cube)

In [None]:
# Look at peak intensity map of masked signal subcube
# peak_intensity_sigmask = masked_cubes[0].max(axis=0)
# imshow(peak_intensity_sigmask.value, cmap='gray', origin='lower')

#### Post-facto outliers (no longer used)

In [None]:
# Visualize outliers
# find_outliers(masked_cube[0], 15.)

# Remove outliers
# masked_cubes_no_outliers = []
# for n in range(8):
# masked_cube_no_outliers = remove_outliers(masked_cube) # masked_cubes[n]
# masked_cubes_no_outliers.append(masked_cube_no_outliers)

# Moment maps

In [None]:
# Moment 0 maps
# masked_moment0s = []
# for n in range(8):
masked_moment0 = masked_cube.moment0() # masked_cubes[n]

ax = plt.subplot(projection=masked_moment0.wcs)
im = ax.imshow(masked_moment0.value, origin='lower', cmap='inferno')
cbar = plt.colorbar(im)
cbar.set_label('Integrated Intensity (K km/s)')
ax.set_ylabel('Declination')
ax.set_xlabel('Right Ascension')
# plt.savefig(f'k{n}_mom0.pdf')
# masked_moment0s.append(masked_moment0)

In [None]:
# Moment 1 maps
# masked_moment1s = []
# for n in range(8):
masked_moment1 = masked_cube.moment1() # masked_cubes[n]

ax = plt.subplot(projection=masked_moment1.wcs)
im = ax.imshow(masked_moment1.value, origin='lower', cmap='coolwarm')
cbar = plt.colorbar(im)
cbar.set_label('Centroid (km/s)')

ax.set_ylabel('Declination')
ax.set_xlabel('Right Ascension')
# plt.savefig(f'k{n}_mom1.pdf')
# masked_moment1s.append(masked_moment1)

#### Compare high-velocity pixel to central pixel

In [None]:
max_vel_coord = np.unravel_index(np.nanargmax(masked_moment1), masked_moment1.shape)
spectrum = masked_cube[:, max_vel_coord[0], max_vel_coord[1]]
plt.plot(spectrum.spectral_axis, spectrum.value, drawstyle='steps-mid')
plt.xlabel('Velocity (km/s)')
plt.ylabel('Intensity (K)')

In [None]:
spectrum = masked_cube[:, 259, 256]
plt.plot(spectrum.spectral_axis, spectrum.value, drawstyle='steps-mid')
plt.xlabel('Velocity (km/s)')
plt.ylabel('Intensity (K)')

#### Investigate noise map

In [None]:
# # Mask out pixels with integrated intensity less than 1 K*km/s
# mad_std_map_sclip_mask = masked_moment0 > 1.*u.K*u.km/u.s
# subcube_sclip = subcube_noise.sigma_clip_spectrally(3)
# subcube_masked = subcube_noise.with_mask(mad_std_map_sclip_mask)
# mad_std_map_sclip_masked = subcube_masked.mad_std(axis=0)

# ax = plt.subplot(projection=mad_std_map_sclip_masked.wcs)
# im = ax.imshow(mad_std_map_sclip_masked.value, origin='lower', cmap='gray')
# cbar = plt.colorbar(im)
# cbar.set_label('Intensity (K)')

# ax.set_ylabel('Declination')
# ax.set_xlabel('Right Ascension')

# # Should not have any noise values < 1 K*km/s / width
# noise_map = mad_std_map_sclip_masked
# print(f"Range of masked noise map: {np.nanmin(noise_map)} to {np.nanmax(noise_map)}")
# print(f"Lower threshold: {(1.*u.K*u.km/u.s)/channel_width}")

# From Desmond: Run linewidth_fwhm on slabs. "There shouldn't be any linewidth values that are 
# greater than the width of the slab, and if there were that'd imply that there were negative/bad
# flux or integrated intensity values still floating around in your slab or mom0 somehow."

#### Check molecules that don't have neighboring transitions for other velocity components

In [51]:
tbl = get_line('CH3CN', 147.01*u.GHz, 147.17*u.GHz, 0)
tbl.show_in_notebook()

TypeError: get_line() missing 1 required positional argument: 'vel'

# Rotational diagram

In [None]:
type(masked_moment1)

In [None]:
N_upper = nupper_of_kkms(masked_moment0, shifted_line_freqs[i], 10**(einstein_A_coefficients[i]))