In [None]:
import numpy as np
import cmasher as cmr
import matplotlib.colors as mcolors
from astropy.io import fits
import astropy.units as u
import os
from scipy import signal
from tqdm import tqdm
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib
import matplotlib.pyplot as plt
import math
Pi = math.pi
from scipy import ndimage

In [None]:
def find_peak_cnt(phase_dgm, flx_len, promin1=100, promin2=10): 
    #phase_dgm is the 2D array of the phase diagram, colour-coded by whatever you want.
    #flx_len is the width of the strip, which will be later send to signal.find_peaks to find the localised peak
    #You need to play around the flx_len as it differs at various dataset.
    #promin1, promin2, how significant you require your peak to be. If finding promin1 fails, the code will try a second time for promin2.
    
    flag = 1 #footprint of the ant will be flagged from 1 or any constant you want
    marker = np.full(phase_dgm.shape, np.nan) #define an empty footprint map
    hgt = phase_dgm.shape[0] #read out the height of the phase diagram
    strip = phase_dgm[int(hgt/2)-flx_len:int(hgt/2)+flx_len, :] #the strip around the middle radius
    start_point = (int(hgt/2)-flx_len + np.where(strip == np.nanmax(strip))[0][0], np.where(strip == np.nanmax(strip))[1][0]) #the position where I drop the ant, by finding the brightest pixel in the strip
    marker[start_point] = flag 
    for i in range(start_point[1]+1, phase_dgm.shape[1]): #ask the ant to walk from starting anchor toward INCREASING azimuth, ending at 360deg
        new_strip = phase_dgm[start_point[0]-flx_len:start_point[0]+flx_len, i] #get a new strip, taking regime of +/- flx_len in the R direction, and 1 grid in the azimuth direction
        try:
            strip_arch = signal.find_peaks(new_strip, prominence=promin1)[0] #find a localised peak in the new strip, the signal needs to be more prominant than promin1
            1/len(strip_arch[0])
        except:
            strip_arch = signal.find_peaks(new_strip, prominence=promin2)[0] #if promin1 fail, lower the threshold to promin2 (you can skip second search this if you want)
        if len(strip_arch) >= 1: #if there is a prominant peak found in the new strip
            start_point = (strip_arch[0]+start_point[0]-flx_len, i) #move the ant to the new peak
            flag += 1
            marker[start_point[0]][i] = flag #highlight the footprint as 1, 2, 3, ...
            
        if start_point[0]-flx_len < 0: #do nothing if there is no "good" peak found in the new strip
            break
    
    #As (R1, 360deg) is the same as (R1, 0deg), we want the ant to jump from the right end to the left end of the phase diagram
    for i in range(0, phase_dgm.shape[1]): #ask the ant to keep walking from 0 deg toward INCREASING azimuth, ending at 360deg
        new_strip = phase_dgm[start_point[0]-flx_len:start_point[0]+flx_len, i]
        try:
            strip_arch = signal.find_peaks(new_strip, prominence=promin1)[0]
            1/len(strip_arch[0])
        except:
            strip_arch = signal.find_peaks(new_strip, prominence=promin2)[0]
        if len(strip_arch) >= 1:
            start_point = (strip_arch[0]+start_point[0]-flx_len, i)
            marker[start_point[0]][i] = flag
            flag += 1
        if start_point[0]-flx_len < 0:
            break
            
    strip = phase_dgm[int(hgt/2)-flx_len:int(hgt/2)+flx_len, :]
    tmp = 0
    start_point = (int(hgt/2)-flx_len + np.where(strip == np.nanmax(strip))[0][0], np.where(strip == np.nanmax(strip))[1][0])
    tmp = start_point[0] #the above four lines are going back to the starting anchor
    flag = 0
    for i in range(start_point[1]-1, -1, -1): #ask the ant to walk from the starting anchor, towards decreasing azimuth, ending at 0deg
        new_strip = phase_dgm[start_point[0]-flx_len:start_point[0]+flx_len, i]#same as the loop above, but in the decreasing θ direction
        try:
            strip_arch = signal.find_peaks(new_strip, prominence=promin1)[0]
            1/len(strip_arch[0])
        except:
            strip_arch = signal.find_peaks(new_strip, prominence=promin2)[0]
        if len(strip_arch) >= 1:
            if strip_arch[0] < 2:
                flag += 1
            start_point = (strip_arch[0]+start_point[0]-flx_len, i)
            marker[start_point[0]][i] = markout
            markout += 1
        if start_point[0]-flx_len < 0:
            break
        
    #As (R2, 0deg) is the same pixel as (R2, 360 deg), the ant jumps from the left end to the right end of the phase diagram
    for i in range(phase_dgm.shape[1]-1, 0, -1): 
        new_strip = phase_dgm[start_point[0]-flx_len:start_point[0]+flx_len, i]
        try:
            strip_arch = signal.find_peaks(new_strip, prominence=promin1)[0]
            1/len(strip_arch[0])
        except:
            strip_arch = signal.find_peaks(new_strip, prominence=promin2)[0]
        if len(strip_arch) >= 1:
            if strip_arch[0] < 2:
                flag += 1
            start_point = (strip_arch[0]+start_point[0]-flx_len, i)
            marker[start_point[0]][i] = flag
            flag += 1
        if start_point[0]-flx_len < 0:
            break
    return(marker)


In [None]:
# Azimuth map Φ in degrees, wrapped into [0, 360): 0 deg along +x (increasing
# column j), increasing towards +y (increasing row i). The centre pixel is NaN.
# NOTE(review): with imshow's default origin='upper', increasing rows point down,
# so this sense appears clockwise on screen — confirm the intended orientation.
pix = n_pixels_in_your_image  # TODO: set to the (square) image size, e.g. im_star.shape[0]
xc = pix/2.
yc = pix/2.
rows, cols = np.mgrid[0:pix, 0:pix]
# arctan2 handles all four quadrants and both axes; % 360 maps (-180, 180] to [0, 360).
phi0_im = np.degrees(np.arctan2(rows - yc, cols - xc)) % 360.
phi0_im[(rows == yc) & (cols == xc)] = np.nan  # azimuth is undefined at the centre

In [None]:
# Distance (in pixels) of every pixel from the image centre (pix/2, pix/2).
# For observational data, replace this with the deprojected distance
# (see Sec. 2.3 of Grasha et al. 2017 for the deprojection).
xc = pix/2.
yc = pix/2.
row_idx, col_idx = np.indices((pix, pix))
dedist = np.sqrt(np.square(col_idx - xc) + np.square(row_idx - yc))

In [None]:
# Stellar mass map, expected shape (pix, pix). getdata opens, reads and closes the file.
im_star = fits.getdata('massmap.fits', ext=0)

# Phase diagram: rows are radial bins, columns are azimuth bins.
# NOTE: the radial bins are uniform annuli R_BIN_PX wide (row i covers
# R_BIN_PX*i < R < R_BIN_PX*(i+1)), so the y-axis is linear in R, not ln R.
# Column j covers PHI_BIN_DEG*j <= Φ < PHI_BIN_DEG*(j+1).
R_BIN_PX = 3      # radial bin width (pixels)
PHI_BIN_DEG = 2   # azimuthal bin width (degrees)
im_phase = np.full([60, 180], np.nan)  # 60 radial x 180 azimuthal bins; choose your own resolution
for i in tqdm(range(im_phase.shape[0])):  # tqdm only draws a progress bar
    # Select the annulus once per row, instead of re-masking the full image for every cell.
    in_ring = (dedist > R_BIN_PX*i) & (dedist < R_BIN_PX*(i+1))
    ring_mass = im_star[in_ring]
    ring_phi = phi0_im[in_ring]
    for j in range(im_phase.shape[1]):
        cell = ring_mass[(ring_phi >= PHI_BIN_DEG*j) & (ring_phi < PHI_BIN_DEG*(j+1))]
        cell = cell[cell > 0]  # only strictly positive pixels count (drops zeros and NaNs)
        im_phase[i, j] = cell.mean() if cell.size else np.nan

# Subtract each radial bin's mean (the radial gradient), leaving the azimuthal contrast.
im_phase = im_phase - np.nanmean(im_phase, axis=1, keepdims=True)

In [None]:
# Trace the two brightest spiral arms with find_peak_cnt (defined above).
# TODO: `sigma` (Gaussian smoothing, in grids) and `flxlen` (search half-height)
# are not defined in this notebook — set them in a config cell.
positive_phase = im_phase[im_phase > 0]
promin_hi = np.percentile(positive_phase, 50)/2.
promin_lo = np.percentile(positive_phase, 25)/2.


def _has_footprint_near(marker, i, j, half):
    """True if `marker` has a positive entry within `half` grids of (i, j) in R and Φ.

    Slice starts are clipped at 0; a negative start would wrap to the far end of
    the array and silently give an empty or wrong window.
    """
    window = marker[max(i - half, 0):i + half + 1, max(j - half, 0):j + half + 1]
    return np.nansum(window) > 0


# First ant: the brightest spiral arm.
marker1 = find_peak_cnt(ndimage.gaussian_filter(np.where(im_phase > 0, im_phase, 0), sigma=sigma),
                        flxlen, promin_hi, promin_lo)

# Mask out the brightest arm (2 grids either side in R and Φ), then send a second ant.
im_phase_2 = np.copy(im_phase)
for i in range(im_phase.shape[0]):
    for j in range(im_phase.shape[1]):
        if _has_footprint_near(marker1, i, j, 2):
            im_phase_2[i, j] = 0
marker2 = find_peak_cnt(ndimage.gaussian_filter(np.where(im_phase_2 > 0, im_phase_2, 0), sigma=sigma),
                        flxlen, promin_hi, promin_lo)

# Drop marker2 footprints duplicated by, or within 4 grids of, marker1.
for i in range(marker2.shape[0]):
    for j in range(marker2.shape[1]):
        if _has_footprint_near(marker1, i, j, 4):
            marker2[i, j] = np.nan

In [None]:
# Visualise the arm footprints (marker1, marker2) on top of the phase diagram.
# TODO: `font1`, `font2` (font dicts) and `halo` are not defined in this notebook —
# set them in a config cell.
R_BIN_PX = 3  # radial bin width (pixels) used when building im_phase
n_rows = im_phase.shape[0]

fig, ax = plt.subplots(figsize=[17, 8])
plot = ax.imshow(im_phase, norm=mcolors.PowerNorm(gamma=0.6, vmax=3e6, vmin=-1e6), cmap=cmr.bubblegum)
ax.imshow(marker1, cmap='winter', vmax=1, vmin=0)
ax.imshow(np.where(marker2 < 1, np.nan, marker2), cmap='winter', vmax=1, vmin=0)
ax.axhline(n_rows // 2)  # middle of the strip where find_peak_cnt drops the ant
ax.invert_yaxis()        # radial bin 0 at the bottom

ax.set_xticks(np.arange(0, 180, 20))
ax.set_xticklabels(np.arange(0, 360, 40))
# Row r spans radii R_BIN_PX*r .. R_BIN_PX*(r+1) px, i.e. image y from r-0.5 to r+0.5.
radius_ticks = np.arange(0, R_BIN_PX*n_rows + 1, 30)
ax.set_yticks(radius_ticks / R_BIN_PX - 0.5)
ax.set_yticklabels(radius_ticks)

ax.set_ylabel('Deproj. Radial Dist. (pixel)', font1)
ax.set_xlabel('Azimuth (degree)', font1)

divider = make_axes_locatable(ax)
cax = divider.new_vertical(size='8%', pad=0.)
fig.add_axes(cax)
cb = fig.colorbar(plot, cax=cax, orientation='horizontal')
cb.ax.tick_params(labeltop=True, labelbottom=False)

# Title goes above the colour bar, which sits on top of the main axes.
cax.set_title('Halo '+str(halo) + '- radial subtracted M$_{young}$', font2)
# fig.savefig(outputdir + str(halo) + '_phase_marker.png', bbox_inches='tight')
plt.show()

In [None]:
# Project the arm footprints from the phase diagram back onto the 2-D image.
# This must mirror the binning used to build im_phase: row r spans radii
# R_BIN_PX*r .. R_BIN_PX*(r+1) px, and column c spans PHI_BIN_DEG*c .. PHI_BIN_DEG*(c+1) deg.
R_BIN_PX = 3
PHI_BIN_DEG = 2
spmask = np.full(im_star.shape, np.nan)
arm_rows, arm_cols = np.where((marker1 > 0) | (marker2 > 0))
for row, col in zip(arm_rows, arm_cols):
    r_arch = R_BIN_PX * (row + 0.5)       # bin-centre radius (px)
    phi_arch = PHI_BIN_DEG * (col + 0.5)  # bin-centre azimuth (deg)
    r_mask = (dedist <= r_arch + 1) & (dedist >= r_arch - 1)
    # Wrapped angular distance, so arms crossing 0/360 deg are not cut off.
    dphi = np.abs((phi0_im - phi_arch + 180.) % 360. - 180.)
    spmask[r_mask & (dphi < 2)] = 1

plt.figure(dpi=70)
plt.imshow(im_star, vmax=6e6, vmin=1e4, cmap=cmr.bubblegum)
plt.imshow(spmask, vmax=1, cmap='winter_r')
plt.show()