In [366]:
# import packages
%matplotlib inline
import numpy as np
import math
np.math = math  # Redirect numpy.math to the built-in math module

import matplotlib.pyplot as plt
import torch
import torch.fft

# change the path below to where you download the scattering package
import sys
# sys.path.append('/content/scattering_transform/')
# import os
# os.chdir('/content/scattering')

# import scattering package
import scattering
import importlib
scattering = importlib.reload(scattering)
import utils
utils = importlib.reload(utils)

In [367]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.constants import h, k, c
from astropy import units as u

# Load the column density map from file
logN_H = np.load('Archive/Turb_3.npy')[1]  
# logN_H = utils.downsample_by_four(logN_H)
nx, ny = logN_H.shape

# Define constant dust temperature
T_d = 10  # Typical Planck dust temperature in K

# int_val = []
# for nu in nu_list:
#     int_val.append(np.mean(modified_blackbody(logN_H, T_d, nu*1e9)))
# int_val = np.array(int_val)

# Observation frequency (e.g., 353 GHz in Hz)

# Compute the mock observed intensity map in μK_CMB
nu = (217,353)
I_nu_map_μK_nu1 = utils.modified_blackbody(logN_H, T_d, nu[0]*1e9)
I_nu_map_μK_nu2 = utils.modified_blackbody(logN_H, T_d, nu[1]*1e9)

In [368]:
# Ensure inputs have shape (1, H, W)
dust_nu1 = I_nu_map_μK_nu1[None, :, :]
dust_nu2 = I_nu_map_μK_nu2[None, :, :]

# --- Parameters ---
n_realizations = 100
SNR = 1
amplitude = 2.
spectral_index = -1.7

nx, ny = dust_nu1.shape[-2], dust_nu1.shape[-1]

# --- Compute noise variances ---
variance_nu1 = (np.std(dust_nu1) / SNR) ** 2
variance_nu2 = (np.std(dust_nu2) / SNR) ** 2
variance = (variance_nu1, variance_nu2)

# --- Create contamination_arr with shape (n_realizations, 2, 1, H, W) ---
contamination_arr = np.zeros((n_realizations, 2, 1, nx, ny), dtype=np.float32)

for i in range(n_realizations):
    # Shared CMB: shape (1, H, W)
    cmb_map = utils.generate_cmb_map(n_x=nx, n_y=ny, amplitude=amplitude, spectral_index=spectral_index)
    cmb_map = cmb_map.cpu().numpy()[None, :, :]

    # Independent noise: shape (1, H, W)
    noise_nu1 = np.random.normal(0, np.sqrt(variance_nu1), (1, nx, ny))
    noise_nu2 = np.random.normal(0, np.sqrt(variance_nu2), (1, nx, ny))

    # Total contamination: shape (1, H, W)
    contamination_arr[i, 0] = noise_nu1 + cmb_map
    contamination_arr[i, 1] = noise_nu2 + cmb_map

contamination_arr_nu1 = contamination_arr[:, 0]  # shape: (Mn, 1, H, W)
contamination_arr_nu2 = contamination_arr[:, 1]  # shape: (Mn, 1, H, W)

In [369]:
SNR = 1
variance_nu1 = (np.std(dust_nu1)/SNR)**2
variance_nu2 = (np.std(dust_nu2)/SNR)**2
variance = (variance_nu1, variance_nu2)

noise_nu1 = np.random.normal(0, np.sqrt(variance_nu1), dust_nu1.shape)
noise_nu2 = np.random.normal(0, np.sqrt(variance_nu2), dust_nu2.shape)

cmb_map = utils.generate_cmb_map(n_x=nx, n_y=ny, amplitude=amplitude, spectral_index=spectral_index)
cmb_map = cmb_map.cpu().numpy()[None, :, :]

data_nu1 = dust_nu1 + noise_nu1 + cmb_map
data_nu2 = dust_nu2 + noise_nu2 + cmb_map

# define target
image_target_nu1 = data_nu1 
image_target_nu2 = data_nu2 

# definte initial maps for optimisation
image_init_nu1 = image_target_nu1
image_init_nu2 = image_target_nu2

In [374]:
M, N, J, L = dust_nu1.shape[-2], dust_nu1.shape[-1], 7, 4
st_calc = scattering.Scattering2d(M, N, J, L)
st_calc.add_ref(ref=data_nu1)
s_cov = st_calc.scattering_cov(data_nu1, use_ref=True, 
                    normalization='P00', pseudo_coef=1
                )
threshold_func = scattering.threshold_func_test(s_cov)
# threshold_func = None

In [378]:
std_nu1 = scattering.compute_std(image_target_nu1, contamination_arr = contamination_arr_nu1, s_cov_func=threshold_func) 
std_nu2 = scattering.compute_std(image_target_nu2, contamination_arr = contamination_arr_nu2, s_cov_func=threshold_func) 
std_double = scattering.compute_std_double(image_target_nu1, image_target_nu2, contamination_arr = contamination_arr, image_ref1=image_target_nu1, image_ref2 = image_target_nu2,  wavelets='BS') 
std = (std_nu1, std_nu2)

In [379]:
n_epochs = 3 #number of epochs
image_ref = (image_target_nu1, image_target_nu2)
# decontaminate
for i in range(n_epochs):
    print(f'Starting epoch {i+1}')
    running_map = scattering.denoise_double(image_target_nu1, image_target_nu2, contamination_arr = contamination_arr, std = std, std_double=std_double, image_init1 = image_init_nu1, image_init2 = image_init_nu2, seed=0, print_each_step=True, steps = 25, n_batch = 25, s_cov_func=threshold_func)

    std_nu1 = scattering.compute_std(running_map[0], contamination_arr = contamination_arr_nu1, s_cov_func=threshold_func) 
    std_nu2 = scattering.compute_std(running_map[1], contamination_arr = contamination_arr_nu2, s_cov_func=threshold_func) 
    # std_double = scattering.compute_std_double(running_map[0], running_map[1], contamination_arr = contamination_arr, image_ref1=image_target_nu1, image_ref2 = image_target_nu2) 
    # std = (std_nu1, std_nu2)

image_syn_nu1 = running_map[0]
image_syn_nu2 = running_map[1]

Starting epoch 1
# of estimators:  33966
Current Loss: 6.13e+01
Current Loss: 6.17e+01
Current Loss: 6.16e+01
Current Loss: 6.10e+01
Current Loss: 6.10e+01
Current Loss: 1.95e+01
Current Loss: 1.69e+01
Current Loss: 1.29e+01
Current Loss: 9.71e+00
Current Loss: 7.96e+00
Current Loss: 6.81e+00
Current Loss: 5.66e+00
Current Loss: 5.06e+00
Current Loss: 4.70e+00
Current Loss: 4.54e+00
Current Loss: 4.15e+00
Current Loss: 3.91e+00
Current Loss: 3.78e+00
Current Loss: 3.81e+00
Current Loss: 3.61e+00
Current Loss: 3.64e+00
Current Loss: 3.61e+00
Current Loss: 3.66e+00
Current Loss: 3.62e+00
Current Loss: 3.67e+00
Time used:  315.7155568599701 s
Starting epoch 2
# of estimators:  33966
Current Loss: 6.13e+01
Current Loss: 6.17e+01
Current Loss: 6.16e+01
Current Loss: 6.10e+01
Current Loss: 6.10e+01
Current Loss: 1.95e+01
Current Loss: 1.69e+01
Current Loss: 1.29e+01
Current Loss: 9.71e+00
Current Loss: 7.96e+00
Current Loss: 6.81e+00
Current Loss: 5.66e+00
Current Loss: 5.06e+00
Current Loss

In [380]:
# Convert tuples to NumPy arrays
dust = np.stack([dust_nu1[0], dust_nu2[0]])  # Shape: (2, ...)
data = np.stack([data_nu1[0], data_nu2[0]])  # Shape: (2, ...)
image_denoised = np.stack([image_syn_nu1[0], image_syn_nu2[0]])  # Ensure it's an array

cmb = True
noise_optimisation = False
# Create an array of objects to preserve different shapes
results = np.array([dust, data, image_denoised])
np.save(f"nu={nu}_cmb={cmb}_noise_opt={noise_optimisation}_double_threshold2", results)