# In[14]:
"""
Make and plot a spectrum from a given ALMA FITS cube.
"""
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.wcs import WCS
from photutils.aperture import CircularAperture, EllipticalAperture, CircularAnnulus, aperture_photometry
from spectral_cube import SpectralCube


# Target position, as strings SkyCoord can parse: RA in sexagesimal hours,
# Dec in sexagesimal degrees (aper_pos parses them with hourangle/deg units).
ra, dec = '13:11:29.935', '-1:19:18.750'

# Elliptical aperture shape. These are passed as the `a` / `b` arguments of
# photutils' EllipticalAperture, i.e. the SEMI-major / SEMI-minor axes, in arcsec.
smaj, smin = 2.5, 1.5

# Aperture orientation angle, in radians.
ang = -np.pi / 6

def aper_pos(header, coord=None):
    """
    Convert a sky position into pixel coordinates using a FITS header's WCS.

    Parameters
    ----------
    header : astropy.io.fits.Header
        Header of the image/cube. Only its two celestial axes are used.
    coord : astropy.coordinates.SkyCoord, optional
        Position to convert. Defaults to the module-level ``ra`` / ``dec``
        strings, interpreted as (hourangle, deg) in the FK5 frame.

    Returns
    -------
    x, y : pixel coordinates (0-based, as returned by ``WCS.world_to_pixel``).
    pix_len : pixel size along the second axis, in degrees.
        This is ``|CDELT2|`` when present. Otherwise it is derived from the WCS
        (e.g. for CD-matrix headers) in the WCS's CUNIT, which is degrees for
        standard celestial axes.
    """
    w = WCS(header, naxis=2)
    if coord is None:
        coord = SkyCoord(ra, dec, unit=(u.hourangle, u.deg), frame='fk5')
    x, y = w.world_to_pixel(coord)

    try:
        # abs() so a flipped axis never yields a negative aperture size downstream.
        pix_len = abs(header['CDELT2'])
    except KeyError:
        # Headers using a CD matrix carry no CDELT keywords; ask the WCS instead.
        from astropy.wcs.utils import proj_plane_pixel_scales
        pix_len = proj_plane_pixel_scales(w)[1]
    return x, y, pix_len

def make_spectrum(file, nu_obs=233, vel_min=-180, vel_max=200, aper=None, save_as=None):
    """
    Extract and plot the mean spectrum inside an aperture of a FITS cube.

    Parameters
    ----------
    file : str
        Path to the FITS cube (read with ``SpectralCube.read``).
    nu_obs : float
        Reference (observed) frequency in GHz. Frequencies are converted to
        velocity with the radio convention v = c * (nu_obs - nu) / nu_obs.
    vel_min, vel_max : float
        Velocity window in km/s (relative to ``nu_obs``) that is shaded on the
        plot. The spectrum itself is extracted over the full spectral axis.
    aper : photutils aperture, optional
        Aperture in pixel coordinates. If None, an ellipse is built from the
        cube header. It is centred on the module-level ``ra`` / ``dec``, with
        semi-axes ``smaj`` / ``smin`` (arcsec) and angle ``ang`` (radians).
    save_as : str, optional
        Output path without extension. If given, the figure is written to
        ``<save_as>.pdf``, overwriting any existing file.

    Raises
    ------
    ValueError
        If the aperture does not overlap the image plane of the cube.
    """
    c_kms = 299792.458  # speed of light in km/s
    nu_0 = nu_obs * 1e9  # GHz -> Hz

    if aper is None:
        # getheader opens and closes the file itself, so no handle is leaked.
        header = fits.getheader(file)
        galx, galy, pix_len = aper_pos(header)
        arcsec_per_pix = pix_len * 3600  # pix_len is in degrees
        aper = EllipticalAperture(
            (galx, galy),
            a=smaj / arcsec_per_pix,
            b=smin / arcsec_per_pix,
            theta=ang,
        )

    cube = SpectralCube.read(file)
    # cube.shape is (n_spectral, ny, nx); the image plane need not be square.
    ny, nx = cube.shape[-2:]
    freq_axis = cube.spectral_axis.to_value('Hz')

    mask_image = aper.to_mask(method='center').to_image((ny, nx))
    if mask_image is None:
        raise ValueError('Aperture does not overlap the image plane of the cube.')
    maskedcube = cube.with_mask(mask_image != 0)
    # Mean over the spatial axes; the cube is assumed to be in Jy/beam.
    spectrum_mjy = maskedcube.mean(axis=(1, 2)).value * 1e3  # -> mJy/beam

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(freq_axis, spectrum_mjy, drawstyle='steps-mid', color='grey', linewidth=2, zorder=10)

    def nu2vel(nu):
        """Frequency (Hz) -> radio velocity (km/s) relative to nu_obs."""
        return c_kms * (nu_0 - nu) / nu_0

    def vel2nu(vel):
        """Radio velocity (km/s) -> frequency (Hz); inverse of nu2vel."""
        return nu_0 * (1 - vel / c_kms)

    # Higher velocity maps to lower frequency, so sort the converted window
    # edges rather than relying on the order of vel_min / vel_max.
    nu_lo, nu_hi = sorted((vel2nu(vel_min), vel2nu(vel_max)))
    in_window = (freq_axis > nu_lo) & (freq_axis < nu_hi)
    ax.fill_between(freq_axis, 0, spectrum_mjy, where=in_window, step='mid',
                    color='yellow', alpha=0.8, zorder=5)

    ax.set_xlabel('Frequency (Hz)', fontsize=18)
    secax = ax.secondary_xaxis('top', functions=(nu2vel, vel2nu))
    secax.set_xlabel('Velocity (km/s)', fontsize=18, labelpad=0.5)
    ax.set_ylabel('Flux Density (mJy/beam)', fontsize=18)

    plt.setp(ax.get_xticklabels(), rotation=30, horizontalalignment='right')
    ax.grid(zorder=1)

    if save_as is not None:
        # savefig has no `overwrite` keyword (it always overwrites); passing one
        # raises TypeError in current matplotlib backends.
        fig.savefig(f'{save_as}.pdf', bbox_inches='tight', pad_inches=0)