In [18]:
import torch
import numpy as np
from scipy.fft import fft2, ifft2
from astropy.io import fits

def shear_rec(shear1, shear2):
    """Reconstruct a convergence (kappa) map from two shear-component maps.

    Evaluates kappa = Re(D* (*) gamma) / pi, where gamma = shear1 + i*shear2,
    (*) is a 2-D linear convolution, and the kernel
    D*(theta) = -1 / (theta_x + i*theta_y)**2 (the Kaiser-Squires real-space
    kernel) is sampled on integer pixel offsets in [-(N-1), N-1], with the
    singular zero-offset pixel set to 0.

    Parameters
    ----------
    shear1, shear2 : numpy.ndarray
        Shear components on the same grid. Only ``shear1.shape[0]`` is used as
        the grid size N, so the maps are assumed to be square N x N.

    Returns
    -------
    numpy.ndarray
        Real-valued N x N kappa map aligned with the input grid.
    """
    N_grid = shear1.shape[0]
    # Linear convolution of a (2N-1)-wide kernel with an N-wide map needs
    # (2N-1) + N - 1 = 3N-2 samples; padding both FFTs to that size prevents wrap-around.
    fft_shape = (3 * N_grid - 2, 3 * N_grid - 2)
    theta = np.linspace(-N_grid + 1, N_grid - 1, 2 * N_grid - 1)
    # numpy's default 'xy' indexing: theta_x varies along columns, theta_y along rows.
    theta_x, theta_y = np.meshgrid(theta, theta)
    # The zero-offset pixel evaluates -1/0; silence the expected RuntimeWarning
    # because that pixel is overwritten with 0 right below.
    with np.errstate(divide='ignore', invalid='ignore'):
        D_starkernel = -1. / (theta_x + 1j * theta_y) ** 2
    D_starkernel[N_grid - 1, N_grid - 1] = 0
    full = np.real(ifft2(fft2(D_starkernel, fft_shape) * fft2(shear1 + 1j * shear2, fft_shape))) / np.pi
    # Keep the N x N window aligned with the input grid (kernel centre sits at offset N-1).
    return full[N_grid - 1:2 * N_grid - 1, N_grid - 1:2 * N_grid - 1]

def shear_rec_torch(shear1, shear2, indexing='ij'):
    """PyTorch port of ``shear_rec``: kappa = Re(D* (*) gamma) / pi.

    Builds the kernel D*(theta) = -1 / (theta_x + i*theta_y)**2 on integer pixel
    offsets in [-(N-1), N-1] (zero-offset pixel set to 0) and convolves it with
    gamma = shear1 + i*shear2 through zero-padded FFTs.

    Parameters
    ----------
    shear1, shear2 : torch.Tensor
        Shear components, on the same device. Only ``shear1.shape[0]`` is used
        as the grid size N, so the maps are assumed to be square N x N.
    indexing : {'ij', 'xy'}, optional
        Grid ordering passed to ``torch.meshgrid`` when building the kernel.
        The default ``'ij'`` keeps this function's historical output: theta_x
        then varies along rows, so the kernel is the transpose of the one built
        by ``shear_rec`` (numpy's meshgrid defaults to ``'xy'``). Transposing
        turns D* into -conj(D*), which is equivalent to flipping the sign of
        ``shear1``; hence ``shear_rec_torch(-g1, g2)`` matches
        ``shear_rec(g1, g2)`` up to floating-point precision. Pass ``'xy'`` to
        get ``shear_rec(g1, g2)`` directly from ``shear_rec_torch(g1, g2)``.

    Returns
    -------
    torch.Tensor
        Real-valued N x N kappa map on ``shear1.device``.
    """
    N_grid = shear1.shape[0]
    # Linear convolution of a (2N-1)-wide kernel with an N-wide map needs
    # (2N-1) + N - 1 = 3N-2 samples; padding both FFTs to that size prevents wrap-around.
    fft_shape = (3 * N_grid - 2, 3 * N_grid - 2)
    theta = torch.linspace(-N_grid + 1, N_grid - 1, 2 * N_grid - 1, device=shear1.device)
    # Pass indexing explicitly: relying on torch's implicit default triggers a
    # warning and would silently change results if that default ever changes.
    theta_x, theta_y = torch.meshgrid(theta, theta, indexing=indexing)
    # The zero-offset pixel evaluates -1/0; it is overwritten with 0 immediately after.
    D_starkernel = -1. / (theta_x + 1j * theta_y) ** 2
    D_starkernel[N_grid - 1, N_grid - 1] = 0
    y = torch.fft.ifftn(
        torch.fft.fftn(D_starkernel, s=fft_shape) * torch.fft.fftn(shear1 + 1j * shear2, s=fft_shape)
    )
    # np.pi is a plain Python float, so an ordinary scalar division suffices.
    y = y.real / np.pi
    # Keep the N x N window aligned with the input grid (kernel centre sits at offset N-1).
    return y[N_grid - 1:2 * N_grid - 1, N_grid - 1:2 * N_grid - 1]


# Directory holding the input shear maps and the output kappa map
# (absolute local path; adjust for your machine).
DATA_DIR = '/Users/danny/Desktop'

# Load the two shear components. gamma1 is negated on load, so
# shear1_data = -gamma1(file) and shear2_data = gamma2(file).
shear1_data = - fits.getdata(f'{DATA_DIR}/cos0_Set1_rotate1_area1_37_gamma1.fits')
shear2_data = fits.getdata(f'{DATA_DIR}/cos0_Set1_rotate1_area1_37_gamma2.fits')

# np.float32(...) casts to native-byte-order float32; FITS data come back
# big-endian (e.g. dtype '>f8'), which torch does not ingest directly.
shear1 = torch.tensor(np.float32(shear1_data), device='cpu')
shear2 = torch.tensor(np.float32(shear2_data), device='cpu')

# Reconstruct kappa with the PyTorch port. shear_rec_torch builds its kernel with
# torch.meshgrid's 'ij' ordering, i.e. the transpose of shear_rec's numpy 'xy' grid,
# and that transpose is equivalent to flipping the sign of shear1. Passing -shear1
# (= +gamma1(file)) therefore reproduces shear_rec(shear1_data, shear2_data) up to
# float32 precision. np.array(...) copies the cropped (strided) tensor view into a
# standalone contiguous array for writing.
kappa_python = np.array(shear_rec_torch(-shear1, shear2).cpu())

# Write the kappa map to a .fits file.
fits.writeto(f'{DATA_DIR}/kappa_python_torch.fits', kappa_python, overwrite=True)


In [19]:
# Load the four kappa maps to compare. fits.getdata opens and closes each file
# itself, so no HDUList handle is left open; ext=0 selects the primary HDU.
# Reference map; provenance of this file is not shown here (TODO confirm).
kappa_zhaoan = fits.getdata('/Users/danny/Desktop/cos0_Set1_rotate1_area1_37_ks.fits', ext=0)
kappa_matlab = fits.getdata('/Users/danny/Desktop/kappa_matlab.fits', ext=0)
kappa_python = fits.getdata('/Users/danny/Desktop/kappa_python.fits', ext=0)
kappa_torch = fits.getdata('/Users/danny/Desktop/kappa_python_torch.fits', ext=0)

In [20]:
# Map re-read from kappa_python.fits (float64, dtype '>f8'); presumably the
# numpy/scipy shear_rec result — confirm how that file was produced.
kappa_python

array([[-0.01084101, -0.00526796, -0.00952041, ..., -0.00547499,
        -0.00689655, -0.00453543],
       [-0.0111584 , -0.01148524, -0.01264642, ..., -0.00361932,
        -0.00786871, -0.00921078],
       [-0.00565541, -0.01345704, -0.0163328 , ..., -0.00772127,
        -0.01112676, -0.01391968],
       ...,
       [ 0.0079258 ,  0.00598816, -0.00238138, ...,  0.00239898,
         0.00210277,  0.00282967],
       [ 0.00817869,  0.01125333, -0.00097717, ...,  0.00021383,
         0.00399128,  0.00349119],
       [ 0.00440049,  0.00631212, -0.00102575, ...,  0.00149292,
         0.00220676, -0.00309678]], dtype='>f8')

In [11]:
# MATLAB reference map, as loaded from kappa_matlab.fits.
# NOTE(review): this cell's execution count ([11]) precedes the loading cell ([19]);
# Restart & Run All so the displayed output reflects the current files.
kappa_matlab

array([[-0.01084101, -0.00526797, -0.00952041, ..., -0.00547499,
        -0.00689655, -0.00453543],
       [-0.0111584 , -0.01148524, -0.01264642, ..., -0.00361933,
        -0.00786872, -0.00921078],
       [-0.00565541, -0.01345704, -0.0163328 , ..., -0.00772127,
        -0.01112677, -0.01391968],
       ...,
       [ 0.0079258 ,  0.00598816, -0.00238138, ...,  0.00239898,
         0.00210277,  0.00282967],
       [ 0.00817869,  0.01125333, -0.00097716, ...,  0.00021382,
         0.00399127,  0.00349119],
       [ 0.00440049,  0.00631212, -0.00102574, ...,  0.00149292,
         0.00220676, -0.00309678]], dtype='>f8')

In [12]:
# Reference map loaded from cos0_Set1_rotate1_area1_37_ks.fits.
# NOTE(review): this cell's execution count ([12]) precedes the loading cell ([19]);
# Restart & Run All so the displayed output reflects the current files.
kappa_zhaoan

array([[-0.01084101, -0.00526797, -0.00952041, ..., -0.00547499,
        -0.00689655, -0.00453543],
       [-0.0111584 , -0.01148524, -0.01264642, ..., -0.00361933,
        -0.00786872, -0.00921078],
       [-0.00565541, -0.01345704, -0.0163328 , ..., -0.00772127,
        -0.01112677, -0.01391968],
       ...,
       [ 0.0079258 ,  0.00598816, -0.00238138, ...,  0.00239898,
         0.00210277,  0.00282967],
       [ 0.00817869,  0.01125333, -0.00097716, ...,  0.00021382,
         0.00399127,  0.00349119],
       [ 0.00440049,  0.00631212, -0.00102574, ...,  0.00149292,
         0.00220676, -0.00309678]], dtype='>f8')

In [21]:
# PyTorch reconstruction re-read from kappa_python_torch.fits. It was saved as
# float32 (dtype '>f4'), so it can agree with the float64 maps only to float32 precision.
# TODO: compare numerically, e.g. np.abs(kappa_torch - kappa_matlab).max(),
# rather than eyeballing printed digits.
kappa_torch

array([[-0.01084101, -0.00526797, -0.0095204 , ..., -0.00547499,
        -0.00689654, -0.00453542],
       [-0.0111584 , -0.01148524, -0.01264642, ..., -0.00361932,
        -0.00786871, -0.00921078],
       [-0.00565541, -0.01345704, -0.0163328 , ..., -0.00772127,
        -0.01112676, -0.01391968],
       ...,
       [ 0.0079258 ,  0.00598816, -0.00238138, ...,  0.00239898,
         0.00210277,  0.00282967],
       [ 0.00817868,  0.01125333, -0.00097717, ...,  0.00021382,
         0.00399127,  0.00349118],
       [ 0.00440049,  0.00631211, -0.00102575, ...,  0.00149292,
         0.00220676, -0.00309678]], dtype='>f4')