In [None]:
import os
import pickle
import shutil

import healpy as hp
import numpy as np
from joblib import Parallel, delayed
from threadpoolctl import threadpool_limits
from tqdm.auto import tqdm

In [None]:
# HEALPix resolution of the saved patches.
nside = 512
# Resolution at which the lognormal training fields are generated (see hp.alm2map below).
gen_nside = 4 * nside
# Band limit of the alms drawn with hp.synalm.
lmax = 2 * gen_nside

nmocks = 10
# Root of the mock catalogue. Set the DES_Y3_MOCK_DIR environment variable to
# point elsewhere instead of editing this hard-coded absolute path.
mock_base_dir = os.environ.get('DES_Y3_MOCK_DIR', '/spiff/pierfied/Simulations/des_y3_mocks')

# Number of map bins; generate_patches allocates one output row per bin, so this
# must match the number of maps in khr_train / khr_sim.
num_bins = 4

In [None]:
# NOTE: pickle.load can execute arbitrary code -- only load trusted parameter files.
# The context manager closes the file (the original left the handle open).
with open('lognorm_params.pkl', 'rb') as params_file:
    prior_params = pickle.load(params_file)

cl = prior_params['cl']        # not used in the cells below
y_cl = prior_params['y_cl']    # spectra handed to hp.synalm to draw the Gaussian field
shift = prior_params['shift']  # per-map shift subtracted after exponentiating
mu = prior_params['mu']        # per-map offset added to the Gaussian field before exp

In [None]:
def gnomonic_patch_mask(size_deg, nside, theta0=0.0, phi0=0.0, oversample=2):
    """Boolean RING-ordered mask of the ``nside`` pixels under a square gnomonic patch.

    A square grid of side ``size_deg`` degrees is laid out on the tangent plane at
    the patch centre and mapped back onto the sphere with the inverse gnomonic
    projection. Every ``nside`` pixel hit by a grid point is flagged.

    Parameters
    ----------
    size_deg : float
        Side length of the square patch, in degrees.
    nside : int
        HEALPix resolution of the returned mask.
    theta0, phi0 : float
        Colatitude and longitude of the patch centre in radians (default: north pole).
    oversample : int
        The grid spacing is the pixel size at ``oversample * nside``, so the grid is
        finer than the output pixels.

    Returns
    -------
    np.ndarray of bool with ``hp.nside2npix(nside)`` entries.
    """
    resol = hp.nside2resol(oversample * nside)
    n_pts = int(np.ceil(np.deg2rad(size_deg) / resol))
    spacing = (np.arange(n_pts) - n_pts / 2 + 0.5) * resol
    x, y = np.meshgrid(spacing, spacing, indexing='xy')

    # Inverse gnomonic projection: c is the angular distance from the centre.
    rho = np.sqrt(x**2 + y**2)
    c = np.arctan(rho)
    lat0 = np.pi / 2 - theta0  # centre latitude (the projection formulas use latitude)
    # The 1e-8 guards the rho == 0 tangent point; the clip keeps round-off from
    # pushing the arcsin argument outside [-1, 1] for off-pole centres.
    sin_lat = np.cos(c) * np.sin(lat0) + y * np.sin(c) * np.cos(lat0) / (rho + 1e-8)
    theta = np.pi / 2 - np.arcsin(np.clip(sin_lat, -1.0, 1.0))
    phi = phi0 + np.arctan2(x * np.sin(c), rho * np.cos(lat0) * np.cos(c) - y * np.sin(lat0) * np.sin(c))

    patch_mask = np.zeros(hp.nside2npix(nside), dtype=bool)
    patch_mask[hp.ang2pix(nside, theta, phi)] = True
    return patch_mask


# 2.5 deg square patch about the north pole. Saved as "disc_mask" although the
# footprint is a square gnomonic patch.
mask = gnomonic_patch_mask(2.5, nside)
np.save('disc_mask.npy', mask)

In [None]:
# Generation footprint: a 3.5 deg square gnomonic patch centred on the north pole,
# sampled on a grid whose spacing is the pixel size at 2 * nside. The resulting
# `mask` is the one the following cells work with.
resol = hp.nside2resol(2 * nside)
size = int(np.ceil(np.deg2rad(3.5) / resol))
spacing = resol * (np.arange(size) + 0.5 - size / 2)
x, y = np.meshgrid(spacing, spacing, indexing='xy')

# Tangent-plane radius and the corresponding angular distance from the centre.
rho = np.sqrt(x * x + y * y)
c = np.arctan(rho)

# Patch centre: theta0 is converted from colatitude 0 to latitude pi/2.
phi0 = 0
theta0 = np.pi / 2

# Inverse gnomonic projection back to (colatitude, longitude).
theta = np.pi / 2 - np.arcsin(
    np.cos(c) * np.sin(theta0)
    + y * np.sin(c) * np.cos(theta0) / (rho + 1e-8)
)
phi = phi0 + np.arctan2(
    x * np.sin(c),
    rho * np.cos(theta0) * np.cos(c) - y * np.sin(theta0) * np.sin(c),
)

# Flag every nside pixel touched by a grid point.
mask = np.zeros(hp.nside2npix(nside), dtype=bool)
np.put(mask, hp.ang2pix(nside, theta, phi), True)

np.save('gen_mask.npy', mask)

In [None]:
def rotation_matrix_from_vectors(vec1, vec2):
    """ Find the rotation matrix that aligns vec1 to vec2.

    Uses Rodrigues' formula R = I + K + K^2 (1 - c) / s^2, where K is the
    cross-product matrix of v = a x b for the normalised inputs a, b,
    c = a . b and s = |v|.

    :param vec1: A 3d "source" vector (any non-zero length)
    :param vec2: A 3d "destination" vector (any non-zero length)
    :return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2.
        Parallel inputs give the identity; anti-parallel inputs give a rotation by
        pi about an axis perpendicular to vec1 (the general formula is 0/0 there).
    """
    a = (vec1 / np.linalg.norm(vec1)).reshape(3)
    b = (vec2 / np.linalg.norm(vec2)).reshape(3)
    v = np.cross(a, b)
    c = np.dot(a, b)
    s = np.linalg.norm(v)

    if s < 1e-12:
        # Degenerate case: the rotation axis is undefined and (1 - c) / s**2 would be
        # 0/0, producing a NaN matrix.
        if c > 0:
            # Already aligned.
            return np.eye(3)
        # Anti-parallel: rotate by pi about any unit axis u perpendicular to a,
        # i.e. R = 2 u u^T - I. Crossing with the coordinate axis least aligned
        # with a guarantees a well-conditioned u.
        helper = np.eye(3)[np.argmin(np.abs(a))]
        u = np.cross(a, helper)
        u /= np.linalg.norm(u)
        return 2.0 * np.outer(u, u) - np.eye(3)

    kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
    rotation_matrix = np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2))
    return rotation_matrix

# Targets: the centres of all nside=16 HEALPix pixels (default RING ordering),
# stacked as an (npix, 3) array of unit vectors.
ref_vec = np.array([0, 0, 1])
rot_vecs = np.column_stack(hp.pix2vec(16, np.arange(hp.nside2npix(16))))

# One matrix per target: rot_mats[i] carries ref_vec onto rot_vecs[i].
rot_mats = np.stack([rotation_matrix_from_vectors(ref_vec, target) for target in rot_vecs])

In [None]:
nside_fac = 4

# khr_train is generated at gen_nside but later indexed with pixel numbers at
# nside * nside_fac (see precompute_rot / generate_patches), so the two must agree.
assert nside * nside_fac == gen_nside, 'nside_fac must satisfy nside * nside_fac == gen_nside'

# `mask` is whatever the previous cell left: the 3.5 deg generation mask saved as
# gen_mask.npy (the 2.5 deg disc mask was overwritten by then).
# Pixel centres of that mask sampled at 2 * nside, as a (3, n) array of unit vectors.
mask_pix = np.flatnonzero(hp.ud_grade(mask, 2 * nside))
mask_vec = np.array(hp.pix2vec(2 * nside, mask_pix))

# Same mask in NEST ordering, used to select patch pixels from NEST-indexed maps.
nest_mask = hp.reorder(mask, r2n=True)

hr_pix_inds = np.arange(hp.nside2npix(nside * nside_fac))

def precompute_rot(ind):
    """Precompute the pixel mapping for rotation ``rot_mats[ind]``.

    The generation mask (sampled at ``2 * nside``) is rotated by ``rot_mats[ind]``
    and re-pixelised at ``nside``. Every ``nside * nside_fac`` pixel inside that
    rotated footprint is then rotated back by the transpose and assigned to the
    NEST-ordered ``nside`` pixel it lands in.

    Returns
    -------
    pix : np.ndarray of int
        RING indices (at ``nside * nside_fac``) of the high-resolution pixels inside
        the rotated footprint.
    rot_pix : np.ndarray of int
        For each entry of ``pix``, the NEST index at ``nside`` of its back-rotated
        position.
    """
    # One BLAS thread per task: joblib (n_jobs=-1) already fills every core.
    with threadpool_limits(limits=1):
        rot_mat = rot_mats[ind]

        rot_mask_vec = rot_mat @ mask_vec
        rot_mask = np.zeros(hp.nside2npix(nside), dtype=bool)
        rot_mask[hp.vec2pix(nside, *rot_mask_vec)] = True

        rot_mask_hr = hp.ud_grade(rot_mask, nside_fac * nside)
        # Same result as hr_pix_inds[rot_mask_hr], but without referencing that large
        # global index array, which would otherwise be pickled with this
        # interactively-defined function for every joblib worker.
        pix = np.flatnonzero(rot_mask_hr)
        vec = np.array(hp.pix2vec(nside * nside_fac, pix))
        rot_vec = rot_mat.T @ vec  # inverse rotation (rot_mat is a rotation matrix)
        rot_pix = hp.vec2pix(nside, *rot_vec, nest=True)

        return pix, rot_pix

# Evaluate precompute_rot for every rotation on all cores. Results come back in
# input order, so rots[i] is the (high-res pixels, NEST patch pixels) pair for rot_mats[i].
rots = Parallel(n_jobs=-1)(
    delayed(precompute_rot)(rot_ind)
    for rot_ind in tqdm(range(len(rot_mats)))
)

In [None]:
# Start from an empty output tree. data/fake receives patches cut from freshly
# drawn lognormal maps; data/real receives patches cut from the mock maps.
# shutil/os instead of os.system: portable, and failures raise instead of being
# hidden in an ignored shell return code.
shutil.rmtree('data', ignore_errors=True)
os.makedirs('data/real', exist_ok=True)
os.makedirs('data/fake', exist_ok=True)


def _bin_sum(khr, hr_pix, rot_pix, npix):
    """Scatter-add ``khr[:, hr_pix]`` into ``npix`` output pixels given by ``rot_pix``.

    Same result as ``np.add.at`` over (row, rot_pix) index pairs, but
    ``np.bincount`` avoids the very slow unbuffered ``add.at`` path.
    Returns a float64 array with one row per row of ``khr`` and ``npix`` columns.
    """
    return np.stack([np.bincount(rot_pix, weights=row[hr_pix], minlength=npix) for row in khr])


for sim_num in tqdm(range(nmocks)):
    # Lognormal realisation exp(mu + G) - shift, with G drawn from y_cl.
    # NOTE(review): hp.synalm uses numpy's global RNG and no seed is set, so these
    # maps are not reproducible between runs -- seed here if that matters.
    alm = hp.synalm(y_cl, lmax=lmax, new=False)
    khr_train = np.exp(mu[:, None] + hp.alm2map(alm, gen_nside, pol=False)) - shift[:, None]
    # assumes each mock file is (num_bins, ?, npix) with the wanted map at index 0 of axis 1 -- TODO confirm
    khr_sim = hp.ud_grade(np.load(f'{mock_base_dir}/mocks/mock_{sim_num}.npy')[:, 0, :], nside * nside_fac)

    def generate_patches(ind):
        """Average both high-res maps onto rotated patch ``ind`` and save them.

        Each NEST pixel at ``nside`` gets the mean of the high-res pixels that
        rotate into it (the mapping precomputed in ``rots[ind]``). Only pixels in
        ``nest_mask`` are kept. Writes ``data/fake/{sim_num}_{ind}.npy``
        (lognormal) and ``data/real/{sim_num}_{ind}.npy`` (mock) as float32.
        """
        hr_pix, rot_pix = rots[ind]  # high-res pixel indices, their NEST patch pixel
        npix = hp.nside2npix(nside)

        # Number of high-res pixels landing in each patch pixel (the mean's divisor).
        counts = np.bincount(rot_pix, minlength=npix)[nest_mask]
        m_train = _bin_sum(khr_train, hr_pix, rot_pix, npix)[:, nest_mask] / counts
        m_sim = _bin_sum(khr_sim, hr_pix, rot_pix, npix)[:, nest_mask] / counts

        np.save(f'data/fake/{sim_num}_{ind}.npy', m_train.astype(np.float32))
        np.save(f'data/real/{sim_num}_{ind}.npy', m_sim.astype(np.float32))

    for ind in tqdm(range(len(rot_mats))):
        generate_patches(ind)