In [173]:
# import packages
%matplotlib inline
import numpy as np
import math
np.math = math  # Redirect numpy.math to the built-in math module

import matplotlib.pyplot as plt
import torch
import torch.fft

# change the path below to where you download the scattering package
import sys
# sys.path.append('/content/scattering_transform/')
# import os
# os.chdir('/content/scattering')

# import scattering package
import scattering
import importlib
scattering = importlib.reload(scattering)
import utils
utils = importlib.reload(utils)

In [174]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.constants import h, k, c
from astropy import units as u

# Load the simulated column-density map (presumably log10 N_H, per the name).
# NOTE(review): index [1] selects one slice of the stored array -- confirm which one.
logN_H = np.load('Archive/Turb_3.npy')[1]

# Optional 4x downsampling of the input map (disabled by default).
DOWNSAMPLE = False
if DOWNSAMPLE:
    logN_H = utils.downsample_by_four(logN_H)
nx, ny = logN_H.shape

# Constant dust temperature in K, assumed uniform over the whole map.
# NOTE(review): Planck all-sky modified-blackbody fits give T_d ~ 19-20 K;
# 10 K is closer to cold dense clouds -- confirm this value is intended.
T_d = 10

# Observation frequencies in GHz (converted to Hz when passed to the model).
nu = (217, 353)

# Mock dust intensity maps at each frequency from a modified blackbody.
# NOTE(review): the "μK" in the names suggests utils.modified_blackbody returns
# μK_CMB -- confirm against utils.
I_nu_map_μK_nu1 = utils.modified_blackbody(logN_H, T_d, nu[0] * 1e9)
I_nu_map_μK_nu2 = utils.modified_blackbody(logN_H, T_d, nu[1] * 1e9)

In [175]:
# Add a leading channel axis so each dust map has shape (1, H, W).
dust_nu1 = I_nu_map_μK_nu1[None, :, :]
dust_nu2 = I_nu_map_μK_nu2[None, :, :]

# --- Parameters ---
SEED = 0                # seeds the global NumPy and torch RNGs for reproducibility
n_realizations = 10     # number of contamination (noise + CMB) realizations
SNR = 1                 # noise std = std(dust) / SNR at each frequency
amplitude = 2.          # passed to utils.generate_cmb_map
spectral_index = -1.7   # passed to utils.generate_cmb_map

# Seed both RNG families. NOTE(review): assumes utils.generate_cmb_map draws
# from the NumPy or torch global RNG -- confirm.
np.random.seed(SEED)
torch.manual_seed(SEED)

nx, ny = dust_nu1.shape[-2], dust_nu1.shape[-1]

# --- Noise variances (white Gaussian noise, one variance per frequency) ---
variance_nu1 = (np.std(dust_nu1) / SNR) ** 2
variance_nu2 = (np.std(dust_nu2) / SNR) ** 2
variance = (variance_nu1, variance_nu2)
# Loop-invariant noise standard deviations.
sigma_nu1 = np.sqrt(variance_nu1)
sigma_nu2 = np.sqrt(variance_nu2)

# contamination_arr[i, f] holds realization i at frequency f, shape (1, H, W).
contamination_arr = np.zeros((n_realizations, 2, 1, nx, ny), dtype=np.float32)

for i in range(n_realizations):
    # One CMB map per realization, shared by both frequencies.
    cmb_map = utils.generate_cmb_map(n_x=nx, n_y=ny, amplitude=amplitude, spectral_index=spectral_index)
    cmb_map = cmb_map.cpu().numpy()[None, :, :]
    assert cmb_map.shape == (1, nx, ny), f"unexpected CMB map shape {cmb_map.shape}"

    # Independent noise draws at each frequency.
    noise_nu1 = np.random.normal(0, sigma_nu1, (1, nx, ny))
    noise_nu2 = np.random.normal(0, sigma_nu2, (1, nx, ny))

    # Total contamination = noise + shared CMB.
    contamination_arr[i, 0] = noise_nu1 + cmb_map
    contamination_arr[i, 1] = noise_nu2 + cmb_map

# Per-frequency views, each of shape (n_realizations, 1, H, W).
contamination_arr_nu1 = contamination_arr[:, 0]
contamination_arr_nu2 = contamination_arr[:, 1]


In [163]:
# Build the mock observations: dust + white noise + a shared CMB map.
# The noise variances (variance_nu1 / variance_nu2) come from the contamination
# cell, so the data noise and the contamination realizations always use the same SNR.
noise_nu1 = np.random.normal(0, np.sqrt(variance_nu1), dust_nu1.shape)
noise_nu2 = np.random.normal(0, np.sqrt(variance_nu2), dust_nu2.shape)

cmb_map = utils.generate_cmb_map(n_x=nx, n_y=ny, amplitude=amplitude, spectral_index=spectral_index)
cmb_map = cmb_map.cpu().numpy()[None, :, :]

data_nu1 = dust_nu1 + noise_nu1 + cmb_map
data_nu2 = dust_nu2 + noise_nu2 + cmb_map

# Define the targets of the denoising.
image_target_nu1 = data_nu1
image_target_nu2 = data_nu2

# Define the initial maps for the optimisation. These are copies, so an in-place
# update of the initial map inside the optimiser cannot silently alter the data.
image_init_nu1 = image_target_nu1.copy()
image_init_nu2 = image_target_nu2.copy()

estimator_name = 's_cov'  # NOTE(review): not passed to any call here -- confirm it is still needed
WAVELETS = 'BS'           # wavelet family used for every scattering statistic

# Standard deviations of the statistics under contamination, per frequency and cross-frequency.
std_nu1 = scattering.compute_std(image_target_nu1, contamination_arr=contamination_arr_nu1, wavelets=WAVELETS)
std_nu2 = scattering.compute_std(image_target_nu2, contamination_arr=contamination_arr_nu2, wavelets=WAVELETS)
std_double = scattering.compute_std_double(
    image_target_nu1, image_target_nu2,
    contamination_arr=contamination_arr,
    image_ref1=image_target_nu1, image_ref2=image_target_nu2,
    wavelets=WAVELETS,
)
std = (std_nu1, std_nu2)

In [164]:
# Iterative decontamination. Each epoch calls scattering.denoise_double starting
# from the same initial maps with the current std / std_double, then re-estimates
# std / std_double around the current denoised maps for the next epoch.
WAVELETS = 'BS'  # must match the wavelets used inside denoise_double below
n_epochs = 3     # number of epochs
assert n_epochs >= 1, "need at least one epoch to produce running_map"

for i in range(n_epochs):
    print(f'Starting epoch {i+1}')
    running_map = scattering.denoise_double(
        image_target_nu1, image_target_nu2,
        contamination_arr=contamination_arr,
        std=std, std_double=std_double,
        image_init1=image_init_nu1, image_init2=image_init_nu2,
        image_ref1=image_target_nu1, image_ref2=image_target_nu2,
        seed=0, wavelets=WAVELETS, print_each_step=True, steps=25,
    )

    # Re-estimate the stds around the current estimate. wavelets is passed
    # explicitly so these stds use the same wavelets as the loss above.
    std_nu1 = scattering.compute_std(running_map[0], contamination_arr=contamination_arr_nu1, wavelets=WAVELETS)
    std_nu2 = scattering.compute_std(running_map[1], contamination_arr=contamination_arr_nu2, wavelets=WAVELETS)
    std_double = scattering.compute_std_double(
        running_map[0], running_map[1],
        contamination_arr=contamination_arr,
        image_ref1=image_target_nu1, image_ref2=image_target_nu2,
        wavelets=WAVELETS,
    )
    std = (std_nu1, std_nu2)

image_syn_nu1 = running_map[0]
image_syn_nu2 = running_map[1]

Starting epoch 1
Current Loss: 1.49e+01
Current Loss: 1.49e+01
Current Loss: 1.49e+01
Current Loss: 1.14e+01
Current Loss: 9.28e+00
Current Loss: 7.06e+00
Current Loss: 5.49e+00
Current Loss: 4.21e+00
Current Loss: 3.19e+00
Current Loss: 2.55e+00
Current Loss: 2.02e+00
Current Loss: 1.65e+00
Current Loss: 1.41e+00
Current Loss: 1.18e+00
Current Loss: 1.06e+00
Current Loss: 9.57e-01
Current Loss: 8.79e-01
Current Loss: 8.28e-01
Current Loss: 7.84e-01
Current Loss: 7.45e-01
Current Loss: 7.21e-01
Current Loss: 6.99e-01
Current Loss: 6.78e-01
Current Loss: 6.64e-01
Current Loss: 6.48e-01
Time used:  291.9786710739136 s
Starting epoch 2
Current Loss: 2.27e+01
Current Loss: 2.27e+01
Current Loss: 2.25e+01
Current Loss: 1.82e+01
Current Loss: 1.38e+01
Current Loss: 1.11e+01
Current Loss: 8.91e+00
Current Loss: 6.62e+00
Current Loss: 5.48e+00
Current Loss: 4.48e+00
Current Loss: 3.38e+00
Current Loss: 2.79e+00
Current Loss: 2.29e+00
Current Loss: 1.77e+00
Current Loss: 1.51e+00
Current Loss: 

In [165]:
# Stack the two frequency channels of each product along a new leading axis.
# dust/data are (2, H, W); the denoised maps are presumably the same (asserted below).
dust = np.stack([dust_nu1[0], dust_nu2[0]])
data = np.stack([data_nu1[0], data_nu2[0]])
image_denoised = np.stack([image_syn_nu1[0], image_syn_nu2[0]])

# Run-description flags encoded in the output filename; set by hand, keep in
# sync with how the data were actually generated.
cmb = True
noise_optimisation = False

# The three products are stored as one regular (not object) array of shape
# (3, 2, H, W), ordered [dust, data, denoised], so their shapes must agree.
assert dust.shape == data.shape == image_denoised.shape, (
    f"shape mismatch: dust {dust.shape}, data {data.shape}, denoised {image_denoised.shape}"
)
results = np.stack([dust, data, image_denoised])
# np.save appends the '.npy' extension to this filename.
np.save(f"nu={nu}_cmb={cmb}_noise_opt={noise_optimisation}_double_test", results)