In [45]:
import math
import multiprocessing as mg
import multiprocessing.pool
# import pys2let as ps
import random
import string
import itertools
import os

import jax
jax.config.update("jax_enable_x64", True)
import s2fft
import healpy as hp
import numpy as np
import s2wav
import s2wav
import matplotlib.pyplot as plt
%matplotlib inline 
import skyclean
from skyclean import CMB_data

In [46]:
def mw_alm_2_hp_alm(MW_alm2, lmax):
    # Initialize the 1D hp_alm array with the appropriate size
    hp_alm1 = np.zeros(hp.Alm.getsize(lmax), dtype=np.complex128)
        
    for l in range(lmax + 1):
        for m in range(-l, l + 1):
            index = hp.Alm.getidx(lmax, l, abs(m))
            if m < 0:
                hp_alm1[index] = (-1)**m * np.conj(MW_alm2[l, lmax + m])
            else:
                hp_alm1[index] = MW_alm2[l, lmax + m]

    return hp_alm1

# Doubleworker

## Loaded mw wavelet coefficient map

## Covert to mw alm space

## add zero to the mw alms  (Is it correct? or should I add zeros to the hp alm's and then convert to mw alm's)

## Convert to from mw alm to mw map      (hp)


def Single_Map_doubleworker(MW_Pix_Map):
    '''
    Input: MW_Pix_Map: list of mw maps at different scales 
    Each pixel map is a wavelet pixel map of shape (1, Lmax, 2*Lmax-1) (MW sampling, McEwen & Wiaux)
    It is the output of s2wav.analysis
    (Scale: 0, size (1, 4, 7))

    Process:
    1. Covert MW Pixel Map to MW alm space using s2fft.forward

    2. Add zero to the mw alms  (Is it correct? or should I add zeros to the hp alm's and then convert to mw alm's)
    by adding zeros to the MW alm's we are increasing the resolution of the map
    Double the rows of the mw alms, since, the number of rows represents the L (level of detail)
    
    3. Convert mw alm to mw map 
    
    '''
    print("original Pixel Map size", MW_Pix_Map.shape)
    MW_alm = s2fft.forward(MW_Pix_Map, L = MW_Pix_Map.shape[1])
    print("original alm size", MW_alm.shape)
    
   
    # print("Scale:",i,"original alm size", MW_alm[i].shape)
    padded_alm = np.zeros((MW_alm.shape[0]*2,MW_alm.shape[1]*2))
    # stored_wavelet_coeffs_alm_doubled.append(skyclean.double_resolution(stored_wavelet_coeffs_alm[i]))
    padded_alm[:MW_alm.shape[0], :MW_alm.shape[1]] = MW_alm
    print("padded alm size", padded_alm.shape)
    MW_alm_doubled = padded_alm
    
    MW_Pix_Map_doubled = s2fft.inverse(MW_alm_doubled, L = MW_alm_doubled.shape[0])
    print("Scale:","doubled map size", MW_Pix_Map_doubled.shape)

    return MW_Pix_Map_doubled

## Loaded mw wavelet coefficient map
stored_wavelet_coeffs_pix = [np.load(f"../convolution/wavelet_coefficient/wav_30_{i}.npy", allow_pickle=True) for i in range(12)]
stored_scaling_coeffs_pix = np.load("../convolution/scaling_coefficient/scal_30.npy")


# print(stored_wavelet_coeffs_pix[0].shape)
stored_wavelet_coeffs_pix = stored_wavelet_coeffs_pix[:3]

wavelet_MW_Pix_Map_doubled = Single_Map_doubleworker(stored_wavelet_coeffs_pix[0])

# display(wavelet_MW_Pix_Map_doubled[0])
# Why oen dimension is reduced?
# Is it the spin?
print(wavelet_MW_Pix_Map_doubled.shape)


original Pixel Map size (1, 4, 7)
original alm size (4, 7)
padded alm size (8, 14)
Scale: doubled map size (8, 15)
(8, 15)


In [48]:
def smoothworker(MW_Map1, MW_Map2, smoothing_lmax, scale_fwhm):
    '''
    Input: MW_Map1, MW_Map2: same size MW pixel wavelet maps at different frequencies
    output: R_map: smoothed coveriance map beteen MW_Map1 and MW_Map2
    '''
    # Get the real part of the map
    map1 = np.real(MW_Map1)
    map2 = np.real(MW_Map2)
    # Covariance matrix
    R = np.multiply(map1,map2) + 0.j #Add back in zero imaginary part
    # print("R", R.shape)

    R_MW_alm = s2fft.forward(R, L = smoothing_lmax)
    # print("R_MW_alm", R_MW_alm.shape)
    gauss_smooth = hp.gauss_beam(scale_fwhm,lmax=smoothing_lmax-1)
    MW_alm_beam_convolved = np.zeros(R_MW_alm.shape, dtype=np.complex128)

    # Convolve the MW alms with the beam
    for i in range(R_MW_alm.shape[1]):
        MW_alm_beam_convolved[:, i] = R_MW_alm[:, i] * gauss_smooth
    
    R_covariance_map = s2fft.inverse(MW_alm_beam_convolved, L = smoothing_lmax)

    return R_covariance_map

npix = hp.nside2npix(1<<(int(0.5*wavelet_MW_Pix_Map_doubled.shape[0])-1).bit_length())
# (int(0.5*scale_lmax)-1).bit_length() calculates the number of bits necessary to represent the integer int(0.5*scale_lmax)-1 in binary.
# 1 << (int(0.5*scale_lmax)-1).bit_length() performs a bitwise left shift, essentially calculating 2^(number of bits).
scale_fwhm = 4.0 * math.sqrt(1200 / npix)
R_covariance_map = smoothworker(wavelet_MW_Pix_Map_doubled, wavelet_MW_Pix_Map_doubled,wavelet_MW_Pix_Map_doubled.shape[0], scale_fwhm)


In [49]:
R_covariance_map.shape

(8, 15)

In [None]:
def ILC(MW_Pix_Map):

    
    Current_Wavelet_Map = MW_Pix_Map
    
    # Define the size of the smoothing beam 
    # 1200 pixels, size of the sphere
    nsamp = 1200.0
    lmax_at_scale_j = Current_Wavelet_Map.shape[0]
    npix = hp.nside2npix(1<<(int(0.5*lmax_at_scale_j)-1).bit_length())
    # (int(0.5*scale_lmax)-1).bit_length() calculates the number of bits necessary to represent the integer int(0.5*scale_lmax)-1 in binary.
    # 1 << (int(0.5*scale_lmax)-1).bit_length() performs a bitwise left shift, essentially calculating 2^(number of bits).
    scale_fwhm = 4.0 * math.sqrt(nsamp / npix)
    # for high resolution maps, it is still the same number pixels sampled by the actual range is smaller.
    # the beam will become narrow?
    print(lmax_at_scale_j,npix, scale_fwhm)





In [None]:
a = np.array([[1,2],[3,4]])
b = np.array([[5,6],[7,8]])
# display(a)
# display(b)
# display(np.multiply(a,b))
# display(np.multiply(wavelet_MW_Pix_Map_doubled[0],wavelet_MW_Pix_Map_doubled[1]))

In [None]:
def ILC(MW_Pix_Map):

    
    Current_Wavelet_Map = MW_Pix_Map
    
    # Define the size of the smoothing beam 
    # 1200 pixels, size of the sphere
    nsamp = 1200.0
    lmax_at_scale_j = Current_Wavelet_Map.shape[0]
    npix = hp.nside2npix(1<<(int(0.5*lmax_at_scale_j)-1).bit_length())
    # (int(0.5*scale_lmax)-1).bit_length() calculates the number of bits necessary to represent the integer int(0.5*scale_lmax)-1 in binary.
    # 1 << (int(0.5*scale_lmax)-1).bit_length() performs a bitwise left shift, essentially calculating 2^(number of bits).
    scale_fwhm = 4.0 * math.sqrt(nsamp / npix)
    # for high resolution maps, it is still the same number pixels sampled by the actual range is smaller.
    # the beam will become very narrow.
    print(lmax_at_scale_j,npix, scale_fwhm)





    

# ILC(wavelet_MW_Pix_Map_doubled)

# print(wavelet_MW_Pix_Map_doubled.shape)
    

