In [1]:
import warnings

import astropy.units as u
import numpy as np
import sunpy.map
from aiapy.calibrate import register, update_pointing
from astropy.io import fits
from sunpy.map import map_edges
from sunpy.map.sources.sdo import AIAMap, HMIMap
from sunpy.util.exceptions import SunpyUserWarning

try:
    import cv2
    HAS_CV2 = True
except ImportError:
    HAS_CV2 = False

try:
    import cupy
    from cupyx.scipy.ndimage import affine_transform as cupy_affine_transform
    HAS_CUPY = True
except ImportError:
    HAS_CUPY = False

import update_register as ur

In [2]:
def do_cupy_affine_transform(image, rmatrix, order=1, scale=1.0, image_center=None, recenter=False, missing=0.0):
    """
    Apply an affine transformation to a 2-D image on the GPU with CuPy.

    Adapted from ``sunpy.image.transform.affine_transform``. The SciPy call is
    replaced by ``cupyx.scipy.ndimage.affine_transform``, with the data copied
    to the GPU before the transform and back to host memory afterwards.

    Parameters
    ----------
    image : numpy.ndarray
        Image to transform. NaNs are replaced by 0 (with a warning) first.
    rmatrix : numpy.ndarray
        Linear transformation matrix, expressed in (x, y) order because it is
        applied to ``image.T``.
    order : int, optional
        Spline interpolation order passed to CuPy. Default 1.
    scale : float, optional
        ``rmatrix`` is divided by this value before use. Default 1.0.
    image_center : array-like, optional
        (x, y) pixel coordinates of the transformation center. Defaults to the
        array center.
    recenter : bool, optional
        If True, the transform is built about the array center so that
        ``image_center`` in the input lands on the array center of the output.
    missing : float, optional
        Fill value for output pixels that map outside the input. Default 0.0.

    Returns
    -------
    numpy.ndarray
        The transformed image, copied back to host memory.

    Raises
    ------
    ImportError
        If CuPy could not be imported in the setup cell (``HAS_CUPY`` is False).
    """
    # Fail with a clear message instead of a NameError on `cupy` below.
    if not HAS_CUPY:
        raise ImportError(
            "do_cupy_affine_transform requires CuPy, which could not be imported.")

    rmatrix = rmatrix / scale
    # Reverse the (row, col) shape so centers are in (x, y) order, matching the
    # transposed image handed to the transform below.
    array_center = (np.array(image.shape)[::-1] - 1) / 2.0

    if image_center is not None:
        image_center = np.asanyarray(image_center)
    else:
        image_center = array_center

    # Center of rotation depends on the recenter keyword.
    rot_center = array_center if recenter else image_center

    # Offset chosen so that output pixel `rot_center` samples input pixel
    # `image_center` (input = rmatrix @ output + shift).
    displacement = np.dot(rmatrix, rot_center)
    shift = image_center - displacement

    if np.any(np.isnan(image)):
        # Relies on `warnings` and `SunpyUserWarning` imported in the setup cell.
        warnings.warn("Setting NaNs to 0 for CuPy rotation.", SunpyUserWarning)

    # Copy to the GPU and transform. Working on image.T (and transposing the
    # result back) keeps rmatrix and shift in (x, y) order.
    image_gpu = cupy.array(np.nan_to_num(image))
    rmatrix_gpu = cupy.array(rmatrix)
    rotated_image = cupy_affine_transform(
        image_gpu.T, rmatrix_gpu, offset=shift, order=order,
        mode='constant', cval=missing).T

    return cupy.asnumpy(rotated_image)

In [8]:
%%time
# Load one AIA level-1 FITS file (171, 2017-09-10 01:17 UT per the file name)
# from a path relative to the notebook's working directory.
path = './AIA_data/aia_lev1_171a_2017_09_10t01_17_09_35z_image_lev1.fits'
m = sunpy.map.Map(path)
# Apply aiapy's pointing update to the raw map; later cells use `m_up`.
# NOTE(review): wall time is far above CPU time in the recorded output, which
# suggests update_pointing waits on a remote pointing table — confirm.
m_up = update_pointing(m)

CPU times: user 258 ms, sys: 3.42 ms, total: 261 ms
Wall time: 2.24 s


In [6]:
# Working map for the registration cells below; `missing` = None means
# "use the map minimum" (resolved in the plate-scale cell).
smap, missing = m_up, None

In [None]:
%%time
if ((smap.scale[0] / 0.6).round() != 1.0 * u.arcsec / u.pix
        and smap.data.shape != (4096, 4096)):
    scale = (smap.scale[0] / 0.6).round() * 0.6 * u.arcsec
else:
    scale = 0.6 * u.arcsec  # pragma: no cover # can't test this because it needs a full res image
scale_factor = smap.scale[0] / scale

missing = smap.min() if missing is None else missing