In [None]:
from astropy.io import fits
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import healpy as hp
import string
from astropy.wcs import WCS
from scipy import ndimage
import suprht
import shutil
from drizzlib.healpix2wcs import healpix2wcs
from collections import Counter
import os
import statistics

%matplotlib inline

In [None]:
def get_the_map(x,y,xlen,ylen,data):
    """
    Project HEALPix data onto a gnomonic (TAN) Galactic cut-out with drizzlib's healpix2wcs.

    The cut-out is centred on (x, y) with a pixel size of 1/60 degree (CDELT1 = -1/60,
    CDELT2 = 1/60). It is written to the current working directory; an existing file of
    the same name is overwritten (clobber=True).

    x: the center of the map, longitude in degrees
    y: the center of the map, latitude in degrees
    xlen: the width of the input map, in pixels
    ylen: the length of the input map, in pixels
    data: the HEALPix input passed to healpix2wcs

    returns the name of the written FITS file
    """
    lon_tag = str(float("{:.4f}".format(x)))
    lat_tag = str(float("{:.4f}".format(y)))
    out_name = 'drizz_orig_lon' + lon_tag + '_lat' + lat_tag + '_' + str(xlen) + 'x' + str(ylen) + '.fits'
    healpix2wcs(
        healpix=data,
        cdelt=[-1/60, 1/60],
        crval=[x, y],
        ctype=('GLON-TAN', 'GLAT-TAN'),
        clobber=True,
        image_size=[xlen, ylen],
        output=out_name,
    )
    return out_name


def save_rht(ks, bw,x,y, xlen,ylen, link):
    """
    Run SupRHT on the drizzled cut-out for (x, y, xlen, ylen) and save its FITS outputs.

    The input file is read from the current working directory (the name get_the_map
    writes). SupRHT writes its outputs into ``link``.

    ks: kernel_size, the diameter of the kernel inside which the Hough transform is applied
    bw: bar_width, the width of the line along which intensity is integrated; also used as
        the edge-smoothing sigma and the bitmap smoothing radius
    x: the center of the map, longitude in degrees
    y: the center of the map, latitude in degrees
    xlen: the width of the input map, in pixels
    ylen: the length of the input map, in pixels
    link: the path for output (directory, including a trailing separator)

    returns the file name, relative to ``link``, of the main SupRHT output map
    """
    stem = ('drizz_orig_lon' + str(float("{:.4f}".format(x)))
            + '_lat' + str(float("{:.4f}".format(y)))
            + '_' + str(xlen) + 'x' + str(ylen))
    # The return value of suprht.main was never used; only its file outputs matter here.
    suprht.main(filename=stem + '.fits', kernel_size=ks, bar_width=bw, ndrizz=101, ntheta=180,
                histogram_fraction=0.75, sigma_gauss_edge_smooth=bw, smooth_rad_bitmap=bw,
                path=link)
    return stem + '_SUPRHT_K' + str(ks) + '_BAR' + str(bw) + '.fits'

def deleter(filename):
    """
    Delete the file at ``filename``.

    Raises FileNotFoundError (an OSError) if the file does not exist.
    """
    os.unlink(filename)


In [None]:

def labeling(smr, wlen,x,y, xlen,ylen,link, min_size=50):
    """
    Label the thresholded RHT map, assigning a unique integer to each distinct filament.

    Steps:
    1. Read the SupRHT output for this cut-out from ``link``.
    2. Trim a border of width ``smr + int(wlen / 2)``, the same width add_frame restores.
    3. Keep pixels brighter than the mean of the trimmed map.
    4. Label the connected regions with scipy.ndimage.label.
    5. Zero every region smaller than ``min_size`` pixels.

    Surviving labels keep their original, possibly non-consecutive, numbers.

    smr: kernel_size, the diameter of the kernel inside which the Hough transform is applied
    wlen: bar_width, the width of the line along which intensity is integrated
    x: the center of the map, longitude in degrees (used to build the file name)
    y: the center of the map, latitude in degrees (used to build the file name)
    xlen: the width of the input map (used to build the file name)
    ylen: the length of the input map (used to build the file name)
    link: the path for output (directory holding the RHT output, with trailing separator)
    min_size: smallest region, in pixels, that is kept (default 50, the previous fixed value)

    returns the labelled, trimmed map as an integer array (0 = background)
    raises ValueError if the map is smaller than twice the trimmed border
    """
    rht_name = (str(link) + 'drizz_orig_lon' + str(float("{:.4f}".format(x)))
                + '_lat' + str(float("{:.4f}".format(y)))
                + '_' + str(xlen) + 'x' + str(ylen)
                + '_SUPRHT_K' + str(smr) + '_BAR' + str(wlen) + '.fits')
    input_map = fits.getdata(rht_name)

    # Border cut off the RHT map; the kernel does not fully cover pixels this close to the edge.
    frame = smr + int(wlen / 2)
    n_rows, n_cols = input_map.shape[0], input_map.shape[1]
    if n_rows < 2 * frame or n_cols < 2 * frame:
        raise ValueError('map of shape ' + str((n_rows, n_cols)) + ' is smaller than the trimmed border '
                         + str(frame) + ' on each side')

    # Copy as float64, matching the previous element-by-element copy into np.zeros.
    cut = np.array(input_map[frame:n_rows - frame, frame:n_cols - frame], dtype=float)

    mask = cut > cut.mean()
    labeled_image, nb_labels = ndimage.label(mask)

    # Pixel count of every label (index 0 = background, which always has a count of 0 here).
    sizes = ndimage.sum(mask, labeled_image, range(nb_labels + 1))
    too_small = sizes < min_size
    labeled_image[too_small[labeled_image]] = 0

    return labeled_image



    

def largest_fil(array):
    """
    Return the label that covers the most pixels in ``array``, ignoring 0 (background).

    Ties go to the label met first in ``array.ravel()`` order. Raises IndexError when the
    array contains no nonzero value.
    """
    flat = array.ravel()
    nonzero_values = flat[flat != 0]
    top = Counter(nonzero_values).most_common(1)
    return top[0][0]

def whether_expand(pic):
    """
    Report which edges of a 2-D map a positive pixel touches, i.e. where expansion is needed.

    pic: 2-D array-like (e.g. a 0/1 filament mask)

    returns [up, down, right, left] as ints. Each flag is 1 when some pixel > 0 lies on the
    first row, last row, last column or first column respectively, else 0. An empty map
    returns [0, 0, 0, 0].
    """
    arr = np.asarray(pic)
    if arr.size == 0:
        return [0, 0, 0, 0]
    # Only the border rows/columns matter, so check them directly instead of scanning every pixel.
    up = int(np.any(arr[0, :] > 0))
    down = int(np.any(arr[-1, :] > 0))
    right = int(np.any(arr[:, -1] > 0))
    left = int(np.any(arr[:, 0] > 0))
    return [up, down, right, left]
    
    

In [None]:
def add_frame(array, smr, wlen):
    """
    Pad a trimmed map with zeros back to its original size.

    The map returned by labeling loses a border of ``smr + int(wlen / 2)`` pixels on each
    side. This function restores that border, filled with zeros.

    array: 2-D array-like (the trimmed map)
    smr: RHT kernel size
    wlen: RHT bar width

    returns a float64 array of shape (rows + 2*frame, cols + 2*frame) with ``array`` in the middle
    """
    frame = smr + int(wlen / 2)
    inner = np.asarray(array)
    n_rows, n_cols = inner.shape[0], inner.shape[1]
    outing = np.zeros((n_rows + 2 * frame, n_cols + 2 * frame))
    outing[frame:frame + n_rows, frame:frame + n_cols] = inner
    return outing

def center_mass(array, num):
    """
    Centre of mass of the pixels of ``array`` equal to ``num``.

    returns (column, row) as floats, i.e. x first then y, in 0-based array indices
    """
    # scipy.ndimage.measurements is a deprecated alias; use the top-level ndimage function.
    row, col = ndimage.center_of_mass(array == num)
    return col, row
    
    
def _orig_map_filename(x, y, xlen, ylen):
    """Name of the drizzled cut-out for this centre and size (the name get_the_map writes)."""
    return ('drizz_orig_lon' + str(float("{:.4f}".format(x)))
            + '_lat' + str(float("{:.4f}".format(y)))
            + '_' + str(xlen) + 'x' + str(ylen) + '.fits')


def xy2lonlat(x,y,xlen,ylen,smr,wlen,i,j, origin=0):
    """
    Convert array indices of the cut-out to world coordinates (lon, lat) in degrees.

    The WCS is read from the header of the drizzled cut-out in the current working directory.

    x: the center of the map, longitude in degrees (locates the file)
    y: the center of the map, latitude in degrees (locates the file)
    xlen: the width of the input map (locates the file)
    ylen: the length of the input map (locates the file)
    smr, wlen: unused; kept for call compatibility
    i: column index (first FITS pixel axis)
    j: row index (second FITS pixel axis)
    origin: pixel origin passed to WCS. Callers pass 0-based numpy indices (from
        points_of_area / center_mass), so the default is 0. The previous hard-coded 1
        shifted every result by one pixel.

    returns (lon, lat) as Python floats; i and j must therefore be scalars
    """
    w = WCS(fits.getheader(_orig_map_filename(x, y, xlen, ylen)))
    wx, wy = w.wcs_pix2world(i, j, origin)
    return float(wx), float(wy)



def lonlat2xy(x,y,xlen,ylen,smr,wlen,i,j, origin=0):
    """
    Convert world coordinates (lon, lat) in degrees to array indices of the cut-out.

    The WCS is read from the header of the drizzled cut-out in the current working directory.

    x: the center of the map, longitude in degrees (locates the file)
    y: the center of the map, latitude in degrees (locates the file)
    xlen: the width of the input map (locates the file)
    ylen: the length of the input map (locates the file)
    smr, wlen: unused; kept for call compatibility
    i: longitude in degrees
    j: latitude in degrees
    origin: pixel origin passed to WCS. The default 0 yields 0-based numpy indices, as
        used by fil_search. Keep it equal to the origin used in xy2lonlat.

    returns (column, row) rounded to the nearest pixel, as numpy floats (not ints); points
    outside the cut-out give indices outside the array bounds
    """
    w = WCS(fits.getheader(_orig_map_filename(x, y, xlen, ylen)))
    px, py = w.wcs_world2pix(i, j, origin)
    return np.round(px), np.round(py)

def points_of_area(arr, num):
    """
    List the pixels of a 2-D map whose value equals ``num``.

    returns a list of (column, row) tuples of Python ints, in row-major order
    """
    rows, cols = np.nonzero(np.asarray(arr) == num)
    return [(int(c), int(r)) for r, c in zip(rows, cols)]
            

In [None]:
def mask_val(arr, num):
    """
    Binary mask of a single label: 1.0 where ``arr == num``, 0.0 elsewhere.

    returns a float64 array with shape (len(arr), len(arr[0]))
    """
    # Evaluated first so an empty or 1-D input fails exactly as before.
    n_rows, n_cols = len(arr), len(arr[0])
    out = np.zeros((n_rows, n_cols))
    out[np.asarray(arr) == num] = 1
    return out


def save_to_fits(data, name, lon,lat):
    """
    Write ``data`` as the primary HDU of ``name`` and add a WCS header centred on (lon, lat).

    An existing file is overwritten. The header keywords are set by make_header.
    """
    primary = fits.PrimaryHDU(data)
    primary.writeto(name, overwrite=True)
    make_header(name, lon, lat)

def make_header(file, lon,lat, cdelt=1/60):
    """
    Write a GLON/GLAT gnomonic (TAN) WCS into the primary header of an existing FITS file.

    file: the name of the file that needs a header (modified in place)
    lon: center of the file, longitude in degrees (CRVAL1)
    lat: center of the file, latitude in degrees (CRVAL2)
    cdelt: pixel size in degrees, 1/60 by default; CDELT1 is set to -cdelt, CDELT2 to +cdelt
    """
    # Open once in update mode; the context manager flushes and closes the file.
    # The old version leaked the handle from fits.open and rewrote the file once per keyword.
    with fits.open(file, mode='update') as hdul:
        header = hdul[0].header
        header['CTYPE1'] = 'GLON-TAN'
        header['CTYPE2'] = 'GLAT-TAN'
        header['CUNIT1'] = 'deg'
        header['CUNIT2'] = 'deg'
        header['CRVAL1'] = float(lon)
        header['CRVAL2'] = float(lat)
        # NOTE(review): the FITS (1-based) centre of an N-pixel axis is (N + 1) / 2;
        # NAXIS / 2 puts the reference pixel half a pixel off-centre — confirm intended.
        header['CRPIX1'] = header['NAXIS1'] / 2
        header['CRPIX2'] = header['NAXIS2'] / 2
        header['CDELT1'] = -cdelt
        header['CDELT2'] = cdelt
    

def fil_search(img, coordinates):
    """
    Find the label of a previously detected filament in a new (re-centred) labelled map.

    Used in rounds 2 and above. Each reference point is looked up in ``img``. Points that
    fall outside the map, or have non-finite coordinates, are ignored.

    img: 2-D labelled map (0 = background)
    coordinates: sequence of (column, row) pixel positions (e.g. an (n, 2) array)

    returns the most common nonzero label among the in-bounds points
    raises statistics.StatisticsError if no point lands on a labelled pixel
    """
    grid = np.asarray(img)
    n_rows, n_cols = grid.shape[0], grid.shape[1]
    marker = np.zeros(len(coordinates))
    for k, point in enumerate(coordinates):
        x, y = point[0], point[1]
        if not (np.isfinite(x) and np.isfinite(y)):
            continue
        col, row = int(x), int(y)
        # Explicit bounds check: negative indices would otherwise wrap around to the
        # opposite edge of the map instead of being treated as off-map.
        if 0 <= row < n_rows and 0 <= col < n_cols:
            marker[k] = grid[row, col]
    return statistics.mode(marker[np.nonzero(marker)])

def save_angles_map(data, name, lon=None, lat=None):
    """
    Build a map of each pixel's dominant RHT orientation, in degrees, and save it to FITS.

    NOTE(review): get_RHT_data and RHT_tools are not imported in the visible notebook code
    (they appear to come from the RHT package's RHT_tools module). Until they are imported,
    this function raises NameError.

    data: passed straight to get_RHT_data (presumably an RHT output file name — confirm)
    name: output FITS file name (overwritten if it exists)
    lon, lat: optional map centre in degrees. When both are given, the file gets the WCS
        header written by save_to_fits. Otherwise the array is written without WCS keywords.
        The old call save_to_fits(temp, name) lacked these arguments and always raised
        TypeError.
    """
    ipoints, jpoints, hthets, naxis1, naxis2, wlen, smr, thresh = get_RHT_data(data)
    # Bin centres of the theta histogram, converted once rather than per pixel.
    thets_deg = np.degrees(RHT_tools.get_thets(wlen, save=False))
    # NOTE(review): an array of shape (naxis1, naxis2) indexed as [jpoint][ipoint] only lines
    # up if naxis1 is the row count — confirm the axis order for non-square maps.
    angles = np.zeros((naxis1, naxis2))
    # The removed "example" lookup at index 1000 was unused and raised IndexError on small maps.
    for ipt, jpt, hthet in zip(ipoints, jpoints, hthets):
        # Angle of the first histogram bin holding this pixel's maximum.
        angles[jpt][ipt] = thets_deg[np.argmax(hthet)]
    if lon is None or lat is None:
        fits.PrimaryHDU(angles).writeto(name, overwrite=True)
    else:
        save_to_fits(angles, name, lon, lat)
    
    
def delete_all(cent1_new, cent2_new, lenx_new,leny_new,smr,wlen,link):
    """
    Delete the three SupRHT output files (map, SIGTHETA, THETA) of one cut-out from ``link``.

    Files are removed in that order via deleter(). An exception from deleter() (e.g. a
    missing file) stops the remaining deletions.
    """
    stem = (str(link) + 'drizz_orig_lon' + str(float("{:.4f}".format(cent1_new)))
            + '_lat' + str(float("{:.4f}".format(cent2_new)))
            + '_' + str(lenx_new) + 'x' + str(leny_new))
    tail = 'K' + str(smr) + '_BAR' + str(wlen) + '.fits'
    for product in ('_SUPRHT_', '_SUPRHT_SIGTHETA_', '_SUPRHT_THETA_'):
        deleter(stem + product + tail)

    

In [None]:
def main(cent1_new, cent2_new, lenx_new, leny_new, smr, wlen, output, data1,rnd=1, coordinates=0):
    """
    Detect one filament around a given centre, re-centring and enlarging the map until it fits.

    Round 1 takes the largest labelled structure and records the (lon, lat) of each of its
    pixels. If the structure touches an edge of the trimmed map, the centre moves to its
    centre of mass, both map sizes grow by 8%, and main() recurses. Later rounds
    re-identify the same filament from those recorded coordinates. The search stops when
    ``rnd`` reaches 10.

    cent1_new: center of the input map, longitude, in degrees
    cent2_new: center of the input map, latitude, in degrees
    lenx_new: width of the input map, in pixels (truncated to int)
    leny_new: length of the input map, in pixels (truncated to int)
    smr: kernel_size, the diameter of the kernel inside which the Hough transform is applied
    wlen: bar_width, the width of the line along which intensity is integrated
    output: the folder where output files are written (string concatenation is used, so it
        must end with a path separator)
    data1: the input data in HEALPix
    rnd: current round number (internal; callers normally keep the default)
    coordinates: (lon, lat) pairs of the filament's pixels from round 1 (internal)

    returns (cent1, cent2, lenx, leny, smr, wlen) of the final map. Returns None when
    no labelled structure survives or too many rounds were needed. fil_search raises
    statistics.StatisticsError if the filament cannot be found again in a later round.
    """
    if rnd == 10:
        print('----TOO MANY ITERATIONS, MOVE ON TO THE NEXT CENTER ------')
        return

    lenx_new = int(lenx_new)
    leny_new = int(leny_new)

    # Extract the cut-out from the HEALPix map (written to the working directory).
    get_the_map(cent1_new, cent2_new, lenx_new, leny_new, data1)

    # Run SupRHT; its outputs are written into ``output``.
    save_rht(smr, wlen, cent1_new, cent2_new, lenx_new, leny_new, output)

    # Label each distinct filament with a unique number.
    labeled_cut = labeling(smr, wlen, cent1_new, cent2_new, lenx_new, leny_new, output)

    # Every file of this round shares the tag 'lon<lon>_lat<lat>_<lenx>x<leny>'.
    out = str(output)
    tag = ('lon' + str(float("{:.4f}".format(cent1_new)))
           + '_lat' + str(float("{:.4f}".format(cent2_new)))
           + '_' + str(lenx_new) + 'x' + str(leny_new))
    orig_name = 'drizz_orig_' + tag + '.fits'

    if not labeled_cut.any():
        # All structures fell below labeling()'s size cut; consider changing that
        # threshold or re-parametrising the RHT.
        print('------ ALL STRUCTURES ARE TOO SMALL, MOVE ON TO THE NEXT CENTER ------')
        delete_all(cent1_new, cent2_new, lenx_new, leny_new, smr, wlen, output)
        return

    if rnd == 1:
        # First round: pick the largest structure and record the sky position of each of
        # its pixels, so later rounds can find it again in a shifted map.
        filament_num = largest_fil(labeled_cut)
        old_coords_ij = points_of_area(mask_val(add_frame(labeled_cut, smr, wlen), filament_num), 1)
        coord_lonlat = np.zeros((len(old_coords_ij), 2))
        for i in range(len(old_coords_ij)):
            coord_lonlat[i][0], coord_lonlat[i][1] = xy2lonlat(cent1_new, cent2_new, lenx_new, leny_new, smr, wlen, old_coords_ij[i][0], old_coords_ij[i][1])
    else:
        # Later rounds: project the recorded references into this map and look up their labels.
        coord_lonlat = coordinates
        new_coord_ij = np.zeros((len(coord_lonlat), 2))
        for i in range(len(coord_lonlat)):
            new_coord_ij[i][0], new_coord_ij[i][1] = lonlat2xy(cent1_new, cent2_new, lenx_new, leny_new, smr, wlen, coord_lonlat[i][0], coord_lonlat[i][1])
        filament_num = fil_search(add_frame(labeled_cut, smr, wlen), new_coord_ij)

    # Check whether the filament fits into the map completely (touches no edge).
    arr_expand = whether_expand(mask_val(labeled_cut, filament_num))

    if max(arr_expand) == 1:
        # New centre: the filament's centre of mass, converted from array indices to lon/lat.
        fil_cent_x, fil_cent_y = center_mass(add_frame(labeled_cut, smr, wlen), filament_num)
        lon, lat = xy2lonlat(cent1_new, cent2_new, lenx_new, leny_new, smr, wlen, fil_cent_x, fil_cent_y)
        print('\n ----GOING ON ANOTHER ROUND---- \n')

        deleter(orig_name)
        delete_all(cent1_new, cent2_new, lenx_new, leny_new, smr, wlen, output)

        # Recurse with a larger frame and the round-1 references. The result must be returned:
        # previously it was discarded, so main() returned None whenever a re-centre happened.
        return main(lon, lat, lenx_new * 1.08, leny_new * 1.08, smr, wlen, output, data1, rnd + 1, coordinates=coord_lonlat)

    # The filament fits: save its mask and give the remaining files their final names.
    print('\n ----FINISHING UP---- \n')
    mask_name = out + 'mask_' + tag + '_smr' + str(smr) + '_wlen' + str(wlen) + '.fits'
    save_to_fits(mask_val(add_frame(labeled_cut, smr, wlen), filament_num), mask_name, cent1_new, cent2_new)

    rht_suffix = '_K' + str(smr) + '_BAR' + str(wlen) + '.fits'
    os.rename(out + 'drizz_orig_' + tag + '_SUPRHT' + rht_suffix, out + 'suprht_' + tag + rht_suffix)
    # NOTE(review): 'suprth_' looks like a typo for 'suprht_'; kept so existing output names stay unchanged.
    os.rename(out + 'drizz_orig_' + tag + '_SUPRHT_SIGTHETA' + rht_suffix, out + 'suprth_sigtheta_' + tag + rht_suffix)
    os.rename(out + 'drizz_orig_' + tag + '_SUPRHT_THETA' + rht_suffix, out + 'suprht_theta_' + tag + rht_suffix)

    shutil.move(orig_name, out + 'orig_' + tag + '.fits')
    deleter('kernel_supRHT.fits')

    return cent1_new, cent2_new, lenx_new, leny_new, smr, wlen

In [None]:
# Search parameters: map centre (longitude, latitude in degrees), cut-out size in pixels,
# RHT kernel diameter (smr) and bar width (wlen).
center_longitude = -60
center_latitude = -9
width = 250
length = 250
smr = 27
wlen = 7

# Output directory: set FILAMENT_OUTPUT_DIR instead of editing a user-specific path.
# main() builds output paths by string concatenation, so a trailing separator is enforced,
# and the directory is created up front so the first write into it does not fail.
folder = os.path.join(os.environ.get('FILAMENT_OUTPUT_DIR', '/home/sarah/Desktop/output/'), '')
os.makedirs(folder, exist_ok=True)

# HEALPix input map (relative path).
data = 'HFI_SkyMap_353-psb_2048_R3.01_full_CIBsub_Nested_res7_I.fits'

main(center_longitude, center_latitude, width, length, smr, wlen, folder, data)