In [None]:
import os

import astropy.io.fits as fits
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.fft
from IPython.display import display

# Path to the prediction FITS file. Set the KAPPA_MAP_FITS environment variable
# to point at another file instead of editing this machine-specific default.
fname = os.environ.get('KAPPA_MAP_FITS', '/Users/danny/Desktop/WL/kappa_map/prediction/map_24156.fits')

# Read HDUs 0-2, scaled by 100. The multiplication creates in-memory copies
# inside the `with`, so the arrays remain valid after the file is closed.
with fits.open(fname) as hdul:
    pred = hdul[0].data * 100
    true = hdul[1].data * 100
    res = hdul[2].data * 100

# torch.from_numpy requires a native-endian array; astype(np.float32) produces
# one in a single copy regardless of the byte order stored in the file.
img = torch.from_numpy(res.astype(np.float32))
if img.ndim != 2:
    raise ValueError(f"expected a 2D map in HDU 2, got shape {tuple(img.shape)}")
rows, cols = img.shape

In [None]:
def gaus2d(x=0, y=0, mx=0, my=0, sx=1, sy=1):
    """Evaluate an axis-aligned 2D Gaussian density at (x, y).

    Parameters
    ----------
    x, y : float or array-like
        Evaluation coordinates (broadcast together by numpy).
    mx, my : float
        Centre of the Gaussian along x and y.
    sx, sy : float
        Standard deviations along x and y.

    Returns
    -------
    float or numpy.ndarray
        ``exp(-((x-mx)^2/(2 sx^2) + (y-my)^2/(2 sy^2))) / (2 pi sx sy)``.
    """
    x_term = (x - mx)**2. / (2. * sx**2.)
    y_term = (y - my)**2. / (2. * sy**2.)
    amplitude = 1. / (2. * np.pi * sx * sy)
    return amplitude * np.exp(-(x_term + y_term))

def gaussian_mask(rows, cols, sigma):
    """Build a peak-normalized isotropic Gaussian centred on the zero-frequency bin.

    The centre is placed at ``(rows // 2, cols // 2)``, the index where
    ``fftshift`` moves the zero-frequency component, so the mask can be
    multiplied directly with a shifted spectrum of shape ``(rows, cols)``.

    Parameters
    ----------
    rows, cols : int
        Shape of the shifted spectrum the mask will be applied to.
    sigma : float
        Standard deviation of the Gaussian, in pixels (frequency bins). Must be > 0.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(rows, cols)`` with values in [0, 1], exactly 1 at the centre.

    Raises
    ------
    ValueError
        If ``sigma`` is not positive.
    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")
    # Integer pixel coordinates (linspace(0, n, n) would stretch the grid and put
    # the peak between bins). 'ij' indexing yields (rows, cols)-shaped arrays, so
    # the mask matches the image even when rows != cols.
    r, c = np.meshgrid(np.arange(rows), np.arange(cols), indexing='ij')
    crow, ccol = rows // 2, cols // 2
    gaus = gaus2d(r, c, mx=crow, my=ccol, sx=sigma, sy=sigma)
    return gaus / gaus.max()

def draw(ax, data, title, scale=None, cmap=plt.cm.jet):
    """Show a 2D array on ``ax`` with ``imshow`` and set its title.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes.
    data : array-like
        2D image passed straight to ``imshow`` (numpy arrays and the CPU torch
        tensors used in this notebook).
    title : str
        Axes title.
    scale : sequence of two numbers, optional
        ``(vmin, vmax)`` colour limits. ``None`` (default) lets matplotlib
        autoscale to the data range.
    cmap : str or Colormap
        Colormap; defaults to ``jet``.

    Returns
    -------
    matplotlib.image.AxesImage
        The image artist, e.g. for adding a colorbar.
    """
    # Test against None rather than truthiness: a numpy-array scale would raise
    # "truth value is ambiguous", and a malformed scale should not be silently ignored.
    vmin, vmax = (None, None) if scale is None else scale
    im = ax.imshow(data, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set_title(title)
    return im

In [None]:

# Forward FFT of the residual map. `f` holds the raw spectrum; `f_shift` is the
# same spectrum with the zero-frequency bin moved to the array centre, which is
# where the Gaussian masks are centred.
f = torch.fft.fftn(input=img)
f_shift = torch.fft.fftshift(input=f)

# Function to compute and plot the result
def compute_and_plot(sigma1, sigma2, invert_mask):
    fig, axs = plt.subplots(2, 3, figsize=(8,5))

    # Create Gaussian mask
    gaus1 = gaussian_mask(rows, cols, sigma=sigma1)
    gaus2 = gaussian_mask(rows, cols, sigma=sigma2)
    mask = gaus1 - gaus2
    mask /= mask.max()
    if invert_mask:
        mask = 1 - mask
    
    # Apply mask and inverse FFT
    fshift_masked = f_shift*mask
    f_ishift = torch.fft.ifftshift(fshift_masked)
    img_back = torch.fft.ifftn(f_ishift)
    img_back = torch.abs(img_back)
    
    draw(axs[0,0], img, 'Original Image', scale=[-2,5])
    draw(axs[1,0], img_back, 'Inverse FFT Image', scale=[-2,5])
    draw(axs[0,1], torch.log(torch.abs(f_shift)+1), 'FFT Image', cmap='viridis')
    draw(axs[1,1], torch.log(torch.abs(fshift_masked)+1), 'Masked FFT Image', cmap='viridis')
    draw(axs[0,2], mask, 'Gaussian Mask', scale=[0,1], cmap='gray')

    plt.tight_layout()

In [None]:
# Controls for compute_and_plot. continuous_update=False re-runs the FFT filter
# and redraws the six-panel figure only when a slider is released, not for every
# intermediate value while dragging.
invert_mask_checkbox = widgets.Checkbox(value=False, description='Invert Mask')

slider_layout = widgets.Layout(width='500px')
sigma1_slider = widgets.FloatSlider(min=0.1, max=100, step=0.1, value=5, description='sigma1',
                                    layout=slider_layout, continuous_update=False)
sigma2_slider = widgets.FloatSlider(min=0.1, max=100, step=0.1, value=1, description='sigma2',
                                    layout=slider_layout, continuous_update=False)

# Bind the controls to compute_and_plot's parameters and show the widget.
interactive_plot = widgets.interactive(compute_and_plot, sigma1=sigma1_slider, sigma2=sigma2_slider,
                                       invert_mask=invert_mask_checkbox)
display(interactive_plot)