In [1]:
from functools import wraps
import inspect
import numpy as np


def view_as_nddata(expected_type):
    """Build a decorator that lets a function operate on ``expected_type``.

    The data object given to the decorated function is first converted with
    ``expected_type.__from_nddata__(data)``. The wrapped function is called
    with the converted object, and its result is converted back by calling
    ``__from_nddata__(result, data)`` on the original data object.

    The data object is the first positional argument. If no positional
    arguments are given, it is looked up among the keyword arguments under
    the name of the wrapped function's first parameter.

    Parameters
    ----------
    expected_type : type
        Class exposing ``__from_nddata__(new_data_object,
        original_data_object=None)``; used to build the object the wrapped
        function receives.

    Returns
    -------
    callable
        A decorator. The function it produces raises ``TypeError`` when it
        is called without a data object.
    """
    def decorator(func):
        @wraps(func)
        def func_wrapper(*args, **kwargs):
            if args:
                input_data, extra_args = args[0], args[1:]
            else:
                # No positional data: fall back to the keyword named after
                # the wrapped function's first parameter.
                try:
                    parameters = list(
                        inspect.signature(func).parameters.values())
                except (TypeError, ValueError):
                    # Some callables cannot be introspected.
                    parameters = []

                first = parameters[0] if parameters else None
                # Only a positional-or-keyword parameter can be supplied by
                # name and then safely re-passed positionally below.
                if (first is None
                        or first.kind is not inspect.Parameter.POSITIONAL_OR_KEYWORD
                        or first.name not in kwargs):
                    raise TypeError(
                        "{}() requires a data object as its first "
                        "argument".format(getattr(func, '__name__', 'function')))

                input_data = kwargs.pop(first.name)
                extra_args = ()

            # Convert the input to the type the wrapped function expects.
            new_input_data = expected_type.__from_nddata__(input_data)

            output_data = func(new_input_data, *extra_args, **kwargs)

            # Convert the result back, handing over the original object so
            # the conversion can reuse information from it.
            return input_data.__from_nddata__(output_data, input_data)
        return func_wrapper
    return decorator

- Decorator: wraps an individual function in the package and is passed the class representing the data type that the function expects to operate on.
- Before the function is called, the input data object is converted to the type specified in the decorator by calling that type's `__from_nddata__` method.
- After the function has operated on the converted data object, `__from_nddata__(new_data_object, original_data_object=None)` is called on the original input data object to convert the result back.


In [10]:
from specutils import Spectrum1D
from ndcube import NDCube


# Might be able to use the support_nddata decorator on the initializer of Spectrum1D

class NewSpectrum1D(Spectrum1D):
    """``Spectrum1D`` subclass that can be built from another data object."""

    @classmethod
    def __from_nddata__(cls, new_data_object, original_data_object=None):
        """Create an instance of ``cls`` from ``new_data_object``.

        The ``data``, ``unit``, ``wcs``, ``meta`` and ``uncertainty``
        attributes of ``new_data_object`` are forwarded as keyword arguments
        to the class initializer. ``original_data_object`` is accepted for
        signature compatibility but is not used here.
        """
        source = new_data_object
        init_kwargs = dict(
            data=source.data,
            unit=source.unit,
            wcs=source.wcs,
            meta=source.meta,
            uncertainty=source.uncertainty,
        )
        return cls(**init_kwargs)
    
    
class NewNDCube(NDCube):
    """``NDCube`` subclass that can be built from another data object."""

    @classmethod
    def __from_nddata__(cls, new_data_object, original_data_object=None):
        """Create an instance of ``cls`` from ``new_data_object``.

        Parameters
        ----------
        new_data_object
            Object passed to the initializer as the data. Its ``wcs`` must
            expose the object to use as the cube's WCS under ``_wcs``
            (NOTE(review): private attribute -- presumably the wrapped
            low-level WCS; confirm against the WCS adapter in use).
        original_data_object : optional
            Object whose ``extra_coords`` are carried over to the new cube.
            When omitted (``None``), no ``extra_coords`` are forwarded.

        Returns
        -------
        An instance of ``cls``.
        """
        init_kwargs = {'wcs': new_data_object.wcs._wcs}

        # The default of None must not be dereferenced: only forward
        # extra_coords when there is an original object to take them from.
        if original_data_object is not None:
            init_kwargs['extra_coords'] = original_data_object.extra_coords

        return cls(new_data_object, **init_kwargs)

In [11]:
@view_as_nddata(NewSpectrum1D)
def test_function(spec):
    """Return a spectrum of constant 1 Jy flux on ``spec``'s spectral axis.

    Demonstrates the ``view_as_nddata`` round trip: ``spec`` arrives here
    already converted to ``NewSpectrum1D``.

    Note: ``u`` (``astropy.units``) is looked up when the function is
    called, so it must have been imported by then.
    """
    spectral_axis = spec.spectral_axis
    # Size the flux from the input instead of hard-coding 100 samples, so the
    # flux and spectral axis lengths agree for any input length.
    unit_flux = np.ones(spectral_axis.shape) * u.Jy
    return Spectrum1D(flux=unit_flux, spectral_axis=spectral_axis)

In [12]:
import gwcs
import astropy.units as u

# A 100-sample spectrum of random flux (Jy) on a 0..99 Angstrom axis.
spec = Spectrum1D(
    flux=np.random.sample(100) * u.Jy,
    spectral_axis=np.arange(100) * u.AA,
)

# Take the object held in the spectrum WCS's private `_wcs` attribute and
# show its type.
wcs = spec.wcs._wcs
print(type(wcs))

# Build a cube from fresh random data, reusing that WCS.
cube = NewNDCube(
    data=np.random.sample(100),
    wcs=wcs,
)

<class 'gwcs.wcs.WCS'>


In [14]:
# Call the decorated function with the NDCube-based `cube` and display the
# `.data` of the object it returns.
test_function(cube).data

INFO:astropy:overwriting NDData's current wcs with specified wcs.


INFO: overwriting NDData's current wcs with specified wcs. [astropy.nddata.nddata]


array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [None]:
# Leftover `%debug` post-mortem magic removed from the saved notebook; run it
# interactively only when investigating a traceback.