In [4]:
#More or less general imports
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
import os
import time

#Imports for Creating WebSky Maps 
import h5py    
import healpy as hp
from pixell import enmap,utils, reproject, bunch, curvedsky, enplot
import os, sys
import dill as dl
from scipy.interpolate import *
#Unit conversion
from astropy import units as u
from astropy.constants import k_B, h

#Creating Stamps
from astropy import wcs
from astropy.nddata import Cutout2D
import pickle as pk
from astropy.convolution import Gaussian2DKernel, convolve
from astropy.coordinates import SkyCoord
import yaml
from pixell import enmap,utils, reproject, enplot
from scipy import interpolate
import random
from astropy.nddata import block_reduce, block_replicate
import ipyparallel as ipp
from mpi4py import MPI
from PIL import Image
from astLib import astWCS, astImages
from scipy import ndimage
from pixell.enmap import sky2pix
from astropy.wcs import WCS
from astropy.wcs.utils import skycoord_to_pixel
from astropy.wcs.utils import proj_plane_pixel_scales
from astroquery.skyview import SkyView
import astropy.io.fits as pyfits
from stamping_utils import autotiler, getTileCoordsDict, make_jpg, normalize_map, make_mask_wise, _make_jpg, make_stamp, make_mask
import glob

%load_ext autoreload
%autoreload 2

%matplotlib inline

In [5]:
omegab = 0.049
omegac = 0.261
omegam = omegab + omegac
h      = 0.68
ns     = 0.965
sigma8 = 0.81

c = 3e5

H0 = 100*h
nz = 100000
z1 = 0.0
z2 = 6.0
za = np.linspace(z1,z2,nz)
dz = za[1]-za[0]

H      = lambda z: H0*np.sqrt(omegam*(1+z)**3+1-omegam)
dchidz = lambda z: c/H(z)

chia = np.cumsum(dchidz(za))*dz

zofchi = interp1d(chia,za)

In [6]:
#Access to WebSky data
path = "/mnt/welch/USERS/cwhitaker/maps/websky/"

In [7]:
cluster_catalog=open(path + 'halos.pksc')
N = np.fromfile(cluster_catalog,count=3,dtype=np.int32)[0]
# only take first five entries for testing (there are ~8e8 halos total...)
# comment the following line to read in all halos
#N = 10
catalog=np.fromfile(cluster_catalog,count=N*10,dtype=np.float32)
catalog=np.reshape(catalog,(N,10))
x  = catalog[:,0];  y = catalog[:,1];  z = catalog[:,2] # Mpc (comoving)
vx = catalog[:,3]; vy = catalog[:,4]; vz = catalog[:,5] # km/sec
R  = catalog[:,6] # Mpc
#print(catalog)

In [8]:
# Constants
rho      = 2.775e11*omegam*h**2 # Msun/Mpc^3
# this is M200m (mean density 200 times mean) in Msun
M        = 4*np.pi/3.*rho*R**3
chi      = np.sqrt(x**2+y**2+z**2)    # Mpc
redshift = zofchi(chi)

In [9]:
dec,ra = hp.vec2ang(np.column_stack((x,y,z))) # in (not with utils.degree) radians

In [10]:
print(dec)
print(ra)

[1.7712529  2.3051946  0.27119753 ... 2.714089   2.2073736  2.5505273 ]
[4.674485   0.58030796 2.4672267  ... 2.9862423  5.285679   1.3658549 ]


In [11]:
dec = np.where(dec > np.pi / 2 , dec - np.pi , dec)

In [12]:
ra = np.where(ra > np.pi, ra - 2*np.pi, ra)

In [13]:
#print(dec)  * 180/np.pi
print(ra)

[-1.6087003   0.58030796  2.4672267  ...  2.9862423  -0.9975066
  1.3658549 ]


In [14]:
#ra = np.where(ra > np.pi, ra - 2*np.pi, ra)

In [15]:
#Limit to how bright/massive the clusters can be
cluster_cut = 2.0e14

cluster_flags = np.where((M >= cluster_cut))[0]
#Limit on parameters of clusters
M            = M[cluster_flags]
redshift     = redshift[cluster_flags]
cluster_decs = dec[cluster_flags] 
cluster_ras  = ra[cluster_flags]

In [16]:
print(cluster_decs)

[-1.3703399  -0.8363981   0.27119753 ... -0.6426151  -1.1381083
 -0.5619528 ]


### Learning Machine Learning

In [17]:
def generate_coords():
    # Arccosine helper function
    def arccos_arange(start, finish, steps):
        """
        Converts an arange array of angles (in degrees) to their arccosine values in radians.
        Returns:
        Array of arccosine values in radians.
        """
        degrees_array = np.arange(start, finish, steps)
        radians_array = np.deg2rad(degrees_array)
        # (arccosine values will be between -1 and 1)
        cosine_values = np.cos(radians_array)
        y = np.arccos(cosine_values)
        return y
    
    # Coordinates
    x = np.arange(-np.pi, 181 * utils.degree, 1.5 * utils.degree)
    y = arccos_arange(-69, 69, 1.5)
    
    # Array of coordinates
    dec, ra = np.meshgrid(y, x)
    
    # Flatten the meshgrid arrays and combine them into pairs of positions
    coords = np.array([dec.ravel(), ra.ravel()]).T
    coords = np.unique(coords, axis=0)
    
    return coords

coords = generate_coords()

In [38]:
def make_mask(image, cluster_ras, cluster_decs, box, cur_wcs, size = 2.4, jpg=False):
    #Function which makes masks corresponding to clusters in a image. 
    if jpg: 
        mask = np.zeros(image[...,0].shape)
    else:
        mask = np.zeros(image[0].shape)
        print(mask.shape)
    min_ra, max_ra, min_dec, max_dec = box[0][0], box[0][1], box[1][0], box[1][1] 

    in_image = np.where((min_ra < cluster_ras) & (cluster_ras < max_ra) & (min_dec < cluster_decs) & (cluster_decs < max_dec))[0]
    print(len(in_image))
    if len(in_image) == 0:
        return mask
 
    for i in range(len(in_image)):
        cur_cluster = in_image[i]
        cur_center = SkyCoord(cluster_ras[cur_cluster], cluster_decs[cur_cluster], unit = "rad")
        x,y = skycoord_to_pixel(cur_center, cur_wcs)
        
        x,y = np.round(x), np.round(y)
        print(x,y)
        pix_size = proj_plane_pixel_scales(cur_wcs)[0] * 60
        r = size/2/pix_size
        #print(mask.shape)
        xx, yy = np.meshgrid(np.linspace(0, mask.shape[1]-1, mask.shape[1]), np.linspace(0, mask.shape[0]-1, mask.shape[0]))
        #print(xx-x)
        r_mask = (xx-x)**2 + (yy-y)**2 < r**2
        mask += r_mask*(i+1)
        doubled_mask = mask > i+1 #Un-double counts areas where clusters overlap
        mask -= doubled_mask*(i+1)
    
    return mask

In [39]:
r = 2 * utils.degree
coord=coords[0]
decs, ras = coord
decmin, ramin = decs - r, ras - r
decmax, ramax = decs + r, ras + r
box = np.array([[decmin, ramin],[decmax, ramax]])
freqs = ["090", "150", "220"]
jpg, cur_wcs = make_stamp('/mnt/welch/USERS/cwhitaker/maps/websky/websky_f*_map.fits', box, freqs, normalize= False)

In [40]:
cur_wcs

car:{cdelt:[0.008333,0.008333],crval:[0.004167,0],crpix:[-21358.50,241.00]}

In [48]:
#min_ra, max_ra, min_dec, max_dec = box[0][0], box[0][1], box[1][0], box[1][1] 
in_image = np.where((ramin < ra) & (ra < ramax) & (decmin < dec) & (dec < decmax))[0]
print(in_image)

[    72091    548259    737764 ... 862339747 862344819 862920475]


In [49]:
mask = make_mask(jpg, cluster_ras, cluster_decs, [[ramin, ramax], [decmin, decmax]], cur_wcs, size = 2.4)

(480, 480)
1
-42736.0 41.0


In [50]:
np.any(mask)

False

In [46]:
ramin

-3.1764992386296798

In [83]:
print(box)

[[-0.03490659 -3.17649924]
 [ 0.03490659 -3.10668607]]


In [53]:
for i in range(len(in_image)):
    cur_cluster = in_image[i]
    cur_center = SkyCoord(cluster_ras[cur_cluster], cluster_decs[cur_cluster], unit = "deg")
    x,y = skycoord_to_pixel(cur_center, cur_wcs)
    x,y = np.round(x), np.round(y)
    print(cur_center)
    pix_size = proj_plane_pixel_scales(cur_wcs)[0] * 60
    #print(pix_size)
    size =2.4
    r = size/2/pix_size
    #print(r)
    xx, yy = np.meshgrid(np.linspace(0, mask.shape[1]-1, mask.shape[1]), np.linspace(0, mask.shape[0]-1, mask.shape[0]))   
    r_mask = (xx-x)**2 + (yy-y)**2 < r**2
    print(np.any(r_mask))
    mask += r_mask*(i+1)
        #doubled_mask = mask > i+1 #Un-double counts areas where clusters overlap
        #mask -= doubled_mask*(i+1)
    
    #return mask

<SkyCoord (ICRS): (ra, dec) in deg
    (2.43840861, 2.87882996)>
False


In [None]:
for i in range(len(in_image)-10, len(in_image)):
    cur_cluster = in_image[i]
    cur_center = SkyCoord(cluster_ras[cur_cluster], cluster_decs[cur_cluster], unit = "deg")
    x,y = skycoord_to_pixel(cur_center, cur_wcs)
    x,y = np.round(x), np.round(y)
    #print(x,y)
    pix_size = proj_plane_pixel_scales(cur_wcs)[0] * 60 * utils.degree
    #print(pix_size)
    size =2.4
    r = size/2/pix_size
    #print(r)
    xx, yy = np.meshgrid(np.linspace(0, mask.shape[1]-1, mask.shape[1]), np.linspace(0, mask.shape[0]-1, mask.shape[0]))   
    #print(xx,yy)
    r_mask = (xx-x)**2 + (yy-y)**2 < r**2
    print((yy-y)**2)
    mask += r_mask*(i+1)
        #doubled_mask = mask > i+1 #Un-double counts areas where clusters overlap
        #mask -= doubled_mask*(i+1)
    
    #return mask

In [None]:
def tile_mpi():
    comm = MPI.COMM_WORLD
    myrank = comm.Get_rank()
    nproc = comm.Get_size()
    freqs = ["090", "150", "220"]
    r = 2
    for coord in coords:
        decs, ras = coord
        decmin, ramin = decs - r, ras - r
        decmax, ramax = decs + r, ras + r
        #print(ra_max, ra_min)
        in_image = np.where((ramin < cluster_ras) & (cluster_ras < ramax) & (decmin < cluster_decs) & (cluster_decs < decmax))[0]
        if len(in_image) == 0:
            print('no clusters')
            continue
    
        if ramax < ramin:
            ramin -= 360
        
        box = np.array([[decmin, ramin],[decmax, ramax]])
        #Make jpg of box
        jpg, cur_wcs = make_stamp('/mnt/welch/USERS/cwhitaker/maps/websky/websky_f*_map.fits', box, freqs, normalize= False)
        if type(jpg) == int: continue

        mask = make_mask(jpg, catalog, box, cur_wcs, size = 2.4)
    return 0

with ipp.Cluster(controller_ip="*", engines="mpi", n=24) as rc:
    # get a broadcast_view on the cluster which is best
    # suited for MPI style computation
    view = rc.broadcast_view()
    # run the mpi_example function on all engines in parallel
    r = view.apply_sync(tile_mpi())
    # Retrieve and print the result from the engines
    print("\n".join(r))

freqs = ["090", "150", "220"]
r = 2
for i, coord in enumerate(coords):
    if i >= 190:
        break
    decs, ras = coord
    decmin, ramin = decs - r, ras - r
    decmax, ramax = decs + r, ras + r
        #print(ra_max, ra_min)
    in_image = np.where((ramin < cluster_ras) & (cluster_ras < ramax) & (decmin < cluster_decs) & (cluster_decs < decmax))[0]
    if len(in_image) == 0:
        print('no clusters')
        continue
    
    if ramax < ramin:
        ramin -= 360
        
    box = np.array([[decmin, ramin],[decmax, ramax]])
        #Make jpg of box
    jpg, cur_wcs = make_stamp('/mnt/welch/USERS/cwhitaker/maps/websky/websky_f*_map.fits', box, freqs, normalize= False)
    if type(jpg) == int: continue

    mask = make_mask(jpg, catalog, box, cur_wcs, size = 2.4)

In [137]:
# Create a WCS object
wcs = WCS(naxis=2)
wcs.wcs.crpix = [-21358.50,241.00]  # Reference pixel coordinates
wcs.wcs.cdelt = np.array([-0.0066667, 0.0066667])  # Pixel scale (degrees per pixel)
wcs.wcs.crval = [180, 0]  # Reference sky coordinates (RA, Dec)
wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]  # Coordinate system type

# Create a SkyCoord object with the sky coordinates
sky_coord = SkyCoord(ra=100.0, dec=45.0, unit='deg')

# Convert sky coordinates to pixel coordinates
x, y = skycoord_to_pixel(sky_coord, wcs)

print(f"Pixel coordinates: x={x}, y={y}")

Pixel coordinates: x=27381.33319983929, y=49732.73911660097
