In [35]:
import numpy as np
import update_register as ur
from aiapy.calibrate import register, update_pointing
import sunpy.map

from astropy.io import fits
import astropy.units as u
from sunpy.map.sources.sdo import AIAMap, HMIMap
from sunpy.map import map_edges
try:
    import cv2
    HAS_CV2 = True
    print('openCV: ', cv2.__version__)
except ImportError:
    HAS_CV2 = False

try:
    import cupy
    from cupyx.scipy.ndimage import affine_transform as cupy_affine_transform
    HAS_CUPY = True
    print('CuPy ', cupy.__version__)
except ImportError:
    HAS_CUPY = False

openCV:  4.4.0
CuPy  7.5.0


In [2]:
def do_cupy_affine_transform(image, rmatrix, order=1, scale=1.0, image_center=None, recenter=False, missing=0.0):
    """
    Apply a rotation/scaling affine transform to a 2D image on the GPU.

    Adapted from sunpy.image.transform.affine_transform

    ***MODIFIED: uses cupyx.scipy.ndimage.affine_transform
    ***MODIFIED: inputs are copied to the GPU with CuPy and the result is
                 copied back to host memory before returning

    Parameters
    ----------
    image : numpy.ndarray
        2D image to transform. NaNs are replaced through ``np.nan_to_num``
        before the data is sent to the GPU (a warning is emitted first).
    rmatrix : array-like
        Linear transformation matrix (2x2 for a 2D image). It is divided by
        ``scale`` before being applied.
    order : int, optional
        Interpolation order forwarded to ``affine_transform``.
    scale : float, optional
        Scaling factor; the matrix is divided by this value.
    image_center : array-like, optional
        Center used to compute the offset of the transform. Defaults to the
        array center, ``(shape[::-1] - 1) / 2``.
    recenter : bool, optional
        If True the center of rotation is the array center; otherwise it is
        ``image_center``.
    missing : float, optional
        Constant used for output pixels that fall outside the input
        (``mode='constant'``, ``cval=missing``).

    Returns
    -------
    numpy.ndarray
        The transformed image, copied back from the GPU.
    """

    rmatrix = rmatrix / scale
    array_center = (np.array(image.shape)[::-1] - 1) / 2.0

    # Make sure the image center is an array and is where it's supposed to be
    if image_center is not None:
        image_center = np.asanyarray(image_center)
    else:
        image_center = array_center

    # Determine center of rotation based on use (or not) of the recenter keyword
    if recenter:
        rot_center = array_center
    else:
        rot_center = image_center

    # Offset such that the rotation center maps onto image_center
    displacement = np.dot(rmatrix, rot_center)
    shift = image_center - displacement

    if np.any(np.isnan(image)):
        # Imported here because the notebook's import cell provides neither
        # name; without them this branch raised NameError instead of warning.
        import warnings
        from sunpy.util.exceptions import SunpyUserWarning
        warnings.warn("Setting NaNs to 0 for SciPy rotation.", SunpyUserWarning)

    # Move the (NaN-free) image and the matrix to the GPU and transform there
    image = cupy.array(np.nan_to_num(image))
    rmatrix = cupy.array(rmatrix)
    rotated_image = cupy_affine_transform(
        image.T, rmatrix, offset=shift, order=order,
        mode='constant', cval=missing).T

    # Copy the result back to host memory as a NumPy array
    return cupy.asnumpy(rotated_image)

In [10]:
%%time
# Load one AIA level-1 FITS file into a sunpy map.
path = './AIA_data/aia_lev1_171a_2017_09_10t01_17_09_35z_image_lev1.fits'
m = sunpy.map.Map(path)
# aiapy.calibrate.update_pointing returns a map with updated pointing keywords.
# NOTE(review): wall time (2.25 s) far exceeds CPU time (345 ms), so this cell
# is presumably I/O-bound — confirm which call dominates.
m_up = update_pointing(m)
# `smap` is the name the inlined register/rotate code in later cells works on.
smap = m_up
# Fill value for pixels outside the image; None is replaced by smap.min() in
# the metadata-preparation cell.
missing = None

CPU times: user 342 ms, sys: 3.54 ms, total: 345 ms
Wall time: 2.25 s


### Alternate data / header loader - no sunpy - including metadata preparation

In [14]:
%%time
# Read header and image straight from the FITS file (no sunpy map involved).
# The context manager closes the file handle once the data has been copied;
# `astype` returns a new array, so `data` stays valid after the file is closed.
with fits.open(path) as hdul:
    # Repair non-standard header cards silently before reading them
    hdul[1].verify('silentfix')
    header = hdul[1].header
    data = hdul[1].data.astype(np.float64)
# Normalise by the exposure time stored in the header
data /= header['EXPTIME']
# Target scale is 0.6 arcsec/px
target_scale = 0.6
scale_factor = header['CDELT1'] / target_scale
# Center of rotation at reference pixel converted to a coordinate origin at 0
reference_pixel = [header['CRPIX1'] - 1, header['CRPIX2'] - 1]
# Rotation angle with openCV uses coordinate origin at top-left corner. For solar images in numpy we need to invert the angle.
angle = -header['CROTA2']

CPU times: user 144 ms, sys: 16.4 ms, total: 160 ms
Wall time: 158 ms


## Sunpy metadata preparation

In [18]:
%%time
# Inlined preamble of the register step followed by the first half of the
# rotate step: pick the target plate scale, then pad the image so that the
# rotated data fits inside the array.

# Target plate scale: the nearest multiple of 0.6 arcsec when the map is not
# a 4096x4096 image at ~0.6 arcsec/pix, otherwise exactly 0.6 arcsec.
if ((smap.scale[0] / 0.6).round() != 1.0 * u.arcsec / u.pix
        and smap.data.shape != (4096, 4096)):
    scale = (smap.scale[0] / 0.6).round() * 0.6 * u.arcsec
else:
    scale = 0.6 * u.arcsec  # pragma: no cover # can't test this because it needs a full res image
scale_factor = smap.scale[0] / scale

# Default fill value for pixels outside the image is the map minimum
missing = smap.min() if missing is None else missing

##### Entering cupy_rotate()

# The call below is what is being inlined; its keyword arguments are set as
# plain variables right after it.
# tempmap = cupy_rotate(smap, recenter=True,
#                       scale=scale_factor.value,
#                       order=order,
#                       missing=missing)
angle = None
rmatrix = None
recenter = True
order = 1
# NOTE: `scale` is rebound here from a Quantity (target plate scale) to the
# plain float scale factor used by the transform.
scale = scale_factor.value

# With both `angle` and `rmatrix` left as None, the rotation matrix comes from
# the map itself. If only `angle` were set, neither branch assigns `rmatrix`.
if angle is not None and rmatrix is not None:
    raise ValueError("You cannot specify both an angle and a rotation matrix.")
elif angle is None and rmatrix is None:
    rmatrix = smap.rotation_matrix
        

# Copy meta data
new_meta = smap.meta.copy()

# Calculate the shape in pixels to contain all of the image data
extent = np.max(np.abs(np.vstack((smap.data.shape @ rmatrix,
                                  smap.data.shape @ rmatrix.T))), axis=0)

# Calculate the needed padding or unpadding
diff = np.asarray(np.ceil((extent - smap.data.shape) / 2), dtype=int).ravel()
# Pad the image array
# (negative entries of `diff` mean "unpad"; they are clipped to 0 here and
# handled after the transform)
pad_x = int(np.max((diff[1], 0)))
pad_y = int(np.max((diff[0], 0)))

new_data = np.pad(smap.data,
                  ((pad_y, pad_y), (pad_x, pad_x)),
                  mode='constant',
                  constant_values=(missing, missing))
# Padding shifts the reference pixel by the same number of pixels
new_meta['crpix1'] += pad_x
new_meta['crpix2'] += pad_y

# All of the following pixel calculations use a pixel origin of 0

# Center of the padded array in (x, y) order (shape is flipped from (y, x))
pixel_array_center = (np.flipud(new_data.shape) - 1) / 2.0

CPU times: user 4.61 ms, sys: 4 ms, total: 8.6 ms
Wall time: 7.59 ms


### Bottleneck in Sunpy metadata preparation ~ 100ms

In [21]:
%%time
# Create a temporary map so we can use it for the data to pixel calculation.
temp_map = smap._new_instance(new_data, new_meta, smap.plot_settings)
# Convert the axis of rotation from data coordinates to pixel coordinates
pixel_rotation_center = u.Quantity(temp_map.world_to_pixel(smap.reference_coordinate,origin=0)).value
del temp_map

CPU times: user 89 ms, sys: 4 ms, total: 93 ms
Wall time: 90.8 ms


## CuPy Affine Transform (GPU processing) - ~100ms

In [29]:
%%time
# Center handed to the transform: the reference-coordinate pixel when
# recentering, otherwise the geometric center of the padded array.
pixel_center = pixel_rotation_center if recenter else pixel_array_center

# The transform works on the transposed image with an (y, x)-ordered center;
# the result is transposed back to the original orientation.
new_data2 = do_cupy_affine_transform(
    new_data.T,
    np.asarray(rmatrix),
    order=order,
    scale=scale,
    image_center=np.flipud(pixel_center),
    recenter=recenter,
    missing=missing,
).T

CPU times: user 124 ms, sys: 20.5 ms, total: 144 ms
Wall time: 143 ms


In [34]:
%%timeit 
# Same GPU transform call as in the cell above, repeated under %%timeit to
# measure its run time over several loops. It relies on `pixel_center`,
# `new_data`, `rmatrix`, `order`, `scale`, `recenter` and `missing` already
# being defined by earlier cells.
new_data2 = do_cupy_affine_transform(new_data.T,
                            np.asarray(rmatrix),
                            order=order, scale=scale,
                            image_center=np.flipud(pixel_center),
                            recenter=recenter, missing=missing).T

96.2 ms ± 493 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [30]:
%%time
# Second half of the inlined rotate step: update the WCS metadata to describe
# the transformed image, undo any excess padding and build the output map.

if recenter:
    new_reference_pixel = pixel_array_center
else:
    # Calculate new pixel coordinates for the rotation center
    new_reference_pixel = pixel_center + np.dot(rmatrix, pixel_rotation_center - pixel_center)
    new_reference_pixel = np.array(new_reference_pixel).ravel()

# The FITS-WCS transform is by definition defined around the
# reference coordinate in the header. -- THIS TAKES 30ms!!!
lon, lat = smap._get_lon_lat(smap.reference_coordinate.frame)
rotation_center = u.Quantity([lon, lat])

# Define the new reference_pixel
new_meta['crval1'] = rotation_center[0].value
new_meta['crval2'] = rotation_center[1].value
new_meta['crpix1'] = new_reference_pixel[0] + 1  # FITS pixel origin is 1
new_meta['crpix2'] = new_reference_pixel[1] + 1  # FITS pixel origin is 1

# Work on the transformed array returned by the GPU affine transform
# (`new_data2`). Using the padded but untransformed `new_data` here would pair
# unrotated pixels with the de-rotated metadata written below. A separate name
# also keeps `new_data` intact, so this cell can be re-run.
rotated_data = new_data2

# Unpad the array if necessary
unpad_x = -np.min((diff[1], 0))
if unpad_x > 0:
    rotated_data = rotated_data[:, unpad_x:-unpad_x]
    new_meta['crpix1'] -= unpad_x
unpad_y = -np.min((diff[0], 0))
if unpad_y > 0:
    rotated_data = rotated_data[unpad_y:-unpad_y, :]
    new_meta['crpix2'] -= unpad_y

# Calculate the new rotation matrix to store in the header by
# "subtracting" the rotation matrix used in the rotate from the old one
# That being calculate the dot product of the old header data with the
# inverse of the rotation matrix.
pc_C = np.dot(smap.rotation_matrix, np.linalg.inv(rmatrix))
new_meta['PC1_1'] = pc_C[0, 0]
new_meta['PC1_2'] = pc_C[0, 1]
new_meta['PC2_1'] = pc_C[1, 0]
new_meta['PC2_2'] = pc_C[1, 1]

# Update pixel size if image has been scaled.
if scale != 1.0:
    new_meta['cdelt1'] = (smap.scale[0] / scale).value
    new_meta['cdelt2'] = (smap.scale[1] / scale).value

# Remove old CROTA kwargs because we have saved a new PCi_j matrix.
new_meta.pop('CROTA1', None)
new_meta.pop('CROTA2', None)
# Remove CDi_j header
new_meta.pop('CD1_1', None)
new_meta.pop('CD1_2', None)
new_meta.pop('CD2_1', None)
new_meta.pop('CD2_2', None)

# Create new map with the modification
tempmap = smap._new_instance(rotated_data, new_meta, smap.plot_settings)

# end of cupy_rotate()

CPU times: user 44.6 ms, sys: 517 µs, total: 45.1 ms
Wall time: 35.7 ms


###  -> back into cupy_register()

In [31]:
%%time
# Crop the rotated map back to a square with the side length of the original
# image (smap.data.shape[0]), centered on the integer part of crpix1.
center = np.floor(tempmap.meta['crpix1'])
# Lower and upper pixel bounds of the crop; the same pair is used for x and y
range_side = (center + np.array([-1, 1]) * smap.data.shape[0] / 2) * u.pix
# top_right is reduced by one pixel relative to the upper bound
newmap = tempmap.submap(
    u.Quantity([range_side[0], range_side[0]]),
    top_right=u.Quantity([range_side[1], range_side[1]]) - 1*u.pix)

# Header bookkeeping for the registered image.
# NOTE(review): r_sun is presumably the solar radius in pixels (rsun_obs
# divided by the plate scale) — confirm units of rsun_obs.
newmap.meta['r_sun'] = newmap.meta['rsun_obs'] / newmap.meta['cdelt1']
newmap.meta['lvl_num'] = 1.5
# NOTE(review): -64 is presumably the FITS code for 64-bit floats — confirm
# it matches the dtype of the data actually stored in `newmap`.
newmap.meta['bitpix'] = -64

CPU times: user 9.85 ms, sys: 3.96 ms, total: 13.8 ms
Wall time: 12.6 ms


## - Total Sunpy overhead: ~ 150ms; 2s with Sunpy data loader
## - CuPy interpolation (linear only): 100 ms