# Comparison of Astropy vs Torch implementation of coordinate transformation
In this notebook, we will compare the astropy wcs `pixel_to_world` transformation and our implementation. For our implementation, we've broken down what we do in `reprojection.reproject.calculate_skyCoords()` for readibility.

In [1]:
from astropy.io import fits
from astropy.wcs import WCS
import numpy as np
import torch

In [2]:
device = 'cpu'
# Load your target WCS
target_hdu = fits.open('./data/Atik1442426-0035_0032_light.fits')[0]
target_wcs = WCS(target_hdu.header)
target_shape = target_hdu.data.shape

# Convert the data to native byte order before creating tensors
target_data = target_hdu.data.astype(target_hdu.data.dtype.newbyteorder('='))
# Now create the tensor
target_image = torch.tensor(target_data, dtype=torch.float64, device=device)

FileNotFoundError: [Errno 2] No such file or directory: './data/Atik1442426-0035_0032_light.fits'

Now let's run the astropy results

In [4]:
target_wcs_astropy = WCS(target_hdu.header)
# Test pixel coordinate
x_test, y_test = 200, 200
# Get pixel grid for target image
H, W = target_image.shape
y_grid, x_grid = np.mgrid[0:H, 0:W]  # Note: 0-based indexing

# Convert to world coordinates (RA, Dec)
ra_astropy, dec_astropy = target_wcs_astropy.wcs_pix2world(x_test, y_test, 0)


35.24380952832857 59.09821844046949
200.00000000000136 200.00000000000182


And now using our implementation

In [5]:
# Get WCS parameters
CRPIX1 = target_wcs_astropy.wcs.crpix[0]
CRPIX2 = target_wcs_astropy.wcs.crpix[1]
CRVAL1 = target_wcs_astropy.wcs.crval[0]  # Reference RA
CRVAL2 = target_wcs_astropy.wcs.crval[1]  # Reference Dec
PC_matrix = target_wcs_astropy.wcs.get_pc()  # PC Matrix
CDELT = np.array(target_wcs_astropy.wcs.cdelt)  # Scaling factors

# Convert numpy arrays to torch tensors if needed
if not isinstance(x_test, torch.Tensor):
    x = torch.tensor(x_test, device=device, dtype=torch.float64)
    y = torch.tensor(y_test, device=device, dtype=torch.float64)

# Step 1: Compute Pixel Offsets - Precisely as in wcsprm::p2x
u = x - (CRPIX1 - 1)
v = y - (CRPIX2 - 1)

# Step 2: Apply PC Matrix (Rotation) and CDELT (Scaling)
CD_matrix = PC_matrix * CDELT  # Construct CD Matrix
CD_matrix = torch.tensor(CD_matrix, device=device, dtype=torch.float64)
# Handle both scalar and array inputs
if u.dim() == 0:  # scalar
    pixel_offsets = torch.tensor([u.item(), v.item()], device=device, dtype=torch.float64)
    transformed = torch.matmul(CD_matrix, pixel_offsets)
    x_scaled, y_scaled = transformed.unbind()
else:  # arrays
    # Reshape for batch matrix multiplication if needed
    if u.dim() > 1:
        original_shape = u.shape
        u_flat = u.reshape(-1)
        v_flat = v.reshape(-1)
    else:
        u_flat = u
        v_flat = v

    # Stack coordinates for batch processing
    pixel_offsets = torch.stack([u_flat, v_flat], dim=1)  # Shape: [N, 2]

    # Perform batch matrix multiplication
    transformed = torch.matmul(pixel_offsets, CD_matrix.T)  # Shape: [N, 2]
    x_scaled = transformed[:, 0]
    y_scaled = transformed[:, 1]

    # Reshape back to original if needed
    if u.dim() > 1:
        x_scaled = x_scaled.reshape(original_shape)
        y_scaled = y_scaled.reshape(original_shape)

# Step 3: Use the exact tanx2s logic from WCSLib
# Compute the radial distance
r = torch.sqrt(x_scaled ** 2 + y_scaled ** 2)
r0 = torch.tensor(180.0 / torch.pi, device=device)  # R2D from WCSLib

# Apply the tanx2s function exactly as in WCSLib
# Note the sign conventions
phi = torch.zeros_like(r)
non_zero_r = r != 0
if torch.any(non_zero_r):
    phi[non_zero_r] = torch.rad2deg(torch.atan2(-x_scaled[non_zero_r], y_scaled[non_zero_r]))

theta = torch.rad2deg(torch.atan2(r0, r))

# Step 4: Now apply the sph2x (spherical to native) transform from prjx2s
# First convert to radians exactly as WCSLib would
phi_rad = torch.deg2rad(phi)
theta_rad = torch.deg2rad(theta)
ra0_rad = torch.tensor(CRVAL1 * torch.pi / 180.0, device=device)
dec0_rad = torch.tensor(CRVAL2 * torch.pi / 180.0, device=device)

# For TAN projection, the pole is at (0,90) in native coordinates
sin_theta = torch.sin(theta_rad)
cos_theta = torch.cos(theta_rad)
sin_phi = torch.sin(phi_rad)
cos_phi = torch.cos(phi_rad)
sin_dec0 = torch.sin(dec0_rad)
cos_dec0 = torch.cos(dec0_rad)

# This is the exact calculation from wcslib's sphx2s function
sin_dec = sin_theta * sin_dec0 + cos_theta * cos_dec0 * cos_phi
dec_rad = torch.arcsin(sin_dec)

# Calculate RA offset - exact formula from WCSLib
y_term = cos_theta * sin_phi
x_term = sin_theta * cos_dec0 - cos_theta * sin_dec0 * cos_phi
ra_rad = ra0_rad + torch.atan2(-y_term, x_term)

# Convert to degrees and normalize
ra = torch.rad2deg(ra_rad) % 360.0
dec = torch.rad2deg(dec_rad)


Let's compare our results with astropy now.

In [6]:
print(f"Final celestial coordinates Torch:   RA={ra}, Dec={dec}")
print(f"Final celestial coordinates Astropy: RA={ra_astropy}, Dec={dec_astropy}")

Final celestial coordinates Torch:   RA=35.24380952587278, Dec=59.09821841758037
Final celestial coordinates Astropy: RA=35.24380952832857, Dec=59.09821844046949


They are extremely close! This is well below the arcsecond precision we need.

Now let's compare the world_to_pixel results. We are really doing a round trip test, so the coordinates, after the inverse operations, should be at the starting value.

In [9]:
# Convert numpy arrays to torch tensors if needed
if not isinstance(ra, torch.Tensor):
    ra = torch.tensor(ra, device=device)
    dec = torch.tensor(dec, device=device)

# Helper functions for trigonometric calculations
def atan2d(y, x):
    """PyTorch implementation of WCSLib's atan2d function"""
    return torch.rad2deg(torch.atan2(y, x))

def sincosd(angle_deg):
    """PyTorch implementation of WCSLib's sincosd function"""
    angle_rad = torch.deg2rad(angle_deg)
    return torch.sin(angle_rad), torch.cos(angle_rad)

# Step 1: Convert from world to native spherical coordinates
# Convert to radians
ra_rad = torch.deg2rad(ra)
dec_rad = torch.deg2rad(dec)
ra0_rad = torch.tensor(CRVAL1 * torch.pi / 180.0, device=device)
dec0_rad = torch.tensor(CRVAL2 * torch.pi / 180.0, device=device)

# Calculate the difference in RA
delta_ra = ra_rad - ra0_rad

# Calculate sine and cosine values
sin_dec = torch.sin(dec_rad)
cos_dec = torch.cos(dec_rad)
sin_dec0 = torch.sin(dec0_rad)
cos_dec0 = torch.cos(dec0_rad)
sin_delta_ra = torch.sin(delta_ra)
cos_delta_ra = torch.cos(delta_ra)

# Calculate the native spherical coordinates using the correct sign conventions
# Calculate the numerator for phi (native longitude)
y_phi = -cos_dec * sin_delta_ra  # Note the negative sign

# Calculate the denominator for phi
x_phi = sin_dec * cos_dec0 - cos_dec * sin_dec0 * cos_delta_ra

# Calculate native longitude (phi)
phi = atan2d(y_phi, x_phi)

# Calculate native latitude (theta)
theta = torch.rad2deg(torch.arcsin(sin_dec * sin_dec0 + cos_dec * cos_dec0 * cos_delta_ra))

# Step 2: Apply the TAN projection (tans2x function from WCSLib)
# Calculate sine and cosine of phi and theta
sin_phi, cos_phi = sincosd(phi)
sin_theta, cos_theta = sincosd(theta)

# Check for singularity (when sin_theta is zero)
eps = 1e-10
if torch.any(torch.abs(sin_theta) < eps):
    raise ValueError("Singularity in tans2x: theta close to 0 degrees")

# r0 is the radius scaling factor (typically 180.0/π)
r0 = torch.tensor(180.0 / torch.pi, device=device)

# Calculate the scaling factor r with correct sign
r = r0 * cos_theta / sin_theta

# Calculate intermediate world coordinates (x_scaled, y_scaled)
# With the corrected signs based on your findings
x_scaled = -r * sin_phi  # Note the negative sign
y_scaled = r * cos_phi

# Step 3: Apply the inverse of the CD matrix to get pixel offsets
# First, construct the CD matrix
CD_matrix = PC_matrix * CDELT
CD_matrix = torch.tensor(CD_matrix, device=device)
# Calculate the inverse of the CD matrix
CD_inv = torch.linalg.inv(CD_matrix)

# Handle batch processing for arrays
if ra.dim() == 0:  # scalar inputs
    standard_coords = torch.tensor([x_scaled.item(), y_scaled.item()], device=device, dtype=torch.float64)
    pixel_offsets = torch.matmul(CD_inv, standard_coords)
    u = pixel_offsets[0]
    v = pixel_offsets[1]
else:  # array inputs
    # Reshape for batch processing if needed
    if ra.dim() > 1:
        original_shape = ra.shape
        x_scaled_flat = x_scaled.reshape(-1)
        y_scaled_flat = y_scaled.reshape(-1)
    else:
        x_scaled_flat = x_scaled
        y_scaled_flat = y_scaled

    # Stack for batch matrix multiplication
    standard_coords = torch.stack([x_scaled_flat, y_scaled_flat], dim=1)  # Shape: [N, 2]

    # Use batch matrix multiplication
    pixel_offsets = torch.matmul(standard_coords, CD_inv.T)  # Shape: [N, 2]
    u = pixel_offsets[:, 0]
    v = pixel_offsets[:, 1]

    # Reshape back to original dimensions if needed
    if ra.dim() > 1:
        u = u.reshape(original_shape)
        v = v.reshape(original_shape)

# Step 4: Add the reference pixel to get final pixel coordinates
# Remember to add (CRPIX-1) to account for 1-based indexing in FITS/WCS
x_pixel = u + (CRPIX1 - 1)
y_pixel = v + (CRPIX2 - 1)


print(f"Final: x={x_pixel}, y={y_pixel}")
print(f"Difference in x: {x_pixel - x_test}")
print(f"Difference in y: {y_pixel - y_test}")

Final: x=200.00000000004002, y=200.00000000003547
Difference in x: 4.001776687800884e-11
Difference in y: 3.54702933691442e-11


So this is well below what we need :)