In [1]:
from gammapy.maps.geom import Geom
from gammapy.maps import Map, MapAxis, MapCoord, WcsGeom
from gammapy.utils.regions import make_region
import numpy as np
from astropy import units as u

In [100]:
from regions import CircleSkyRegion, RectangleSkyRegion, PolygonSkyRegion
from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.wcs import WCS
from astropy.visualization import quantity_support
import copy
import matplotlib.pyplot as plt
from gammapy.spectrum import ReflectedRegionsFinder

In [128]:
class RegionGeom(Geom):
    """Map geometry representing a region on the sky.
    
    Parameters
    ----------
    region : `~regions.SkyRegion`
        Region object.
    axes : list of `MapAxis`
        Non-spatial data axes.
    wcs : `~astropy.wcs.WCS`
        Optional wcs object to project the region if needed.
    """
    is_image = False
    is_allsky = False
    is_hpx = False
    
    def __init__(self, region, axes=None, wcs=None):
        self._region = region
        self._axes = axes
        
        if wcs is None:
            wcs = WcsGeom.create(
            skydir=region.center, binsz=0.001, width=region.radius, proj="TAN"
        ).wcs
        
        self._wcs = wcs
        self.ndim = len(axes)
        self.coordsys = str(region.center.frame.name)
    
    @property
    def region(self):
        return self._region

    @property
    def axes(self):
        return self._axes

    @property
    def wcs(self):
        return self._wcs

    @property
    def center_coord(self):
        """(`astropy.coordinates.SkyCoord`)"""
        return self.pix_to_coord(self.center_pix)
    
    @property
    def center_pix(self):
        return tuple((np.array(self.data_shape) - 1.0) / 2)[::-1]
    
    @property
    def center_skydir(self):
        """Center skydir"""
        return self.region.center
    
    def contains(self, position):
        idx = self.coord_to_idx(coords)
        return np.all(np.stack([t != INVALID_INDEX.int for t in idx]), axis=0)
    
    @property
    def data_shape(self):
        return tuple([ax.nbin for ax in self.axes]) + (1, 1)
    
    def get_coord(self):
        """Get map coordinates from the geometry.

        Returns
        -------
        coord : `~MapCoord`
            Map coordinate object.
        """
        cdict = {}
        cdict["skycoord"] = self.center_skydir
        
        if self.axes is not None:
            for ax in self.axes:
                cdict[ax.name] = ax.center
        
        return MapCoord.create(cdict)
            
    def pad(self):
        raise NotImplementedError("Padding of `RegionGeom` not implemented")
    
    def crop(self):
        raise NotImplementedError("Cropping of `RegionGeom` not implemented")
    
    def solid_angle(self):
        """"""
        return self.region.solid_angle()
    
    def to_cube(self, axes):
        return self._init_copy(axes=axes)
    
    def to_image(self):
        return self._init_copy(axes=None)
    
    def upsample(self, factor, axis):
        axes = copy.deepcopy(self.axes)
        idx = self.get_axis_index_by_name(axis)
        axes[idx] = axes[idx].upsample(factor)
        return self._init_copy(axes=axes)
    
    def downsample(self, factor, axis):
        axes = copy.deepcopy(self.axes)
        idx = self.get_axis_index_by_name(axis)
        axes[idx] = axes[idx].downsample(factor)
        return self._init_copy(axes=axes)
    
    def pix_to_coord(self, pix):
        lon = np.select(pix[0], -0.5, 0.5, np.nan)
        lat = np.select(pix[1], -0.5, 0.5, np.nan)
        coords = (lon, lat)
        coords += axes_pix_to_coord(self.axes, pix[self._slice_non_spatial_axes])
        return coords
    
    def pix_to_idx(self, pix):
        
        return (1, 1)
        
    
    def coord_to_pix(self, coords):
        coords = MapCoord.create(coords, coordsys=self.coordsys)
        in_region = self.region.contains(coords.skycoord, wcs=self.wcs)
        return in_region
    
    def get_idx(self):
        idxs = (0, 0)
        
        for ax in self.axes:
            idxs += np.arange(ax.nbin)
        
        return idxs
    
    def _make_bands_cols(self):
        pass
    
    @classmethod
    def create(cls, region, **kwargs):
        """Create region.
        
        Parameters
        ----------
        region : str or `~regions.SkyRegion`
            Region
            
        """
        if isinstance(region, str):
            region = make_region(region)
        
        return cls(region, **kwargs)

    def __repr__(self):
        axes = ["lon", "lat"] + [_.name for _ in self.axes]
        lon = self.center_skydir.data.lon.deg
        lat = self.center_skydir.data.lat.deg

        return (
            f"{self.__class__.__name__}\n\n"
            f"\taxes       : {axes}\n"
            f"\tshape      : {self.data_shape[::-1]}\n"
            f"\tndim       : {self.ndim}\n"
            f"\tframe      : {self.center_skydir.frame.name}\n"
            f"\tcenter     : {lon:.1f} deg, {lat:.1f} deg\n"
        )
    

class RegionNDMap(Map):
    """"""
    def __init__(self, geom, data=None, dtype="float32", meta=None, unit=""):
        if data is None:
            data = np.zeros(geom.data_shape, dtype=dtype)
        
        self.geom = geom
        self.data = data
        self.meta = meta
        self.unit = u.Unit(unit)
    
    @property
    def quantity(self):
        return self.data * self.unit
    
    def plot(self, ax=None):
        """Plot map.
        """
        ax = ax or plt.gca()
        
        if len(self.geom.axes) > 1:
            raise TypeError("Use `.plot_interactive()` if more the one extra axis is present.")
        
        axis = self.geom.axes[0] 
        with quantity_support():
            ax.plot(axis.center, self.quantity.squeeze())
            
        if axis.interp == "log":
            ax.set_xscale("log")
    
    @classmethod
    def create(cls, region, **kwargs):
        """
        """
        if isinstance(region, str):
            region = None
        
        return cls(region, **kwargs)
        
    def downsample(self, factor, axis=None):
        pass
    
    def fill_by_idx(self):
        pass
    
    def get_by_idx(self):
        pass
    
    def interp_by_coord(self):
        pass
    
    def interp_by_idx(self):
        pass
    
    def interp_by_pix(self):
        pass
        
    def set_by_idx(self):
        pass
    
    def upsample(self, factor, axis=None):
        pass
    
    @staticmethod
    def read(cls, filename):
        pass
    
    def write(self, filename):
        pass
    
    def to_hdulist(self):
        pass
    
    @classmethod
    def from_hdulist(cls):
        pass
    
    def crop(self):
        raise NotImplementedError
    
    def pad(self):
        raise NotImplementedError
    
    def sum_over_axes(self):
        raise NotImplementedError
    
    def get_image_by_coord(self):
        raise NotImplementedError

    def get_image_by_idx(self):
        raise NotImplementedError

    def get_image_by_pix(self):
        raise NotImplementedError

    def get_by_coord(self, coords):
        coords = coords.copy()
        coords.setdefault("skycoord", self.geom.center_skydir)
        print(coords)
        return super().get_by_coord(coords)
        

In [129]:
circle = CircleSkyRegion(center=SkyCoord("0 deg", "0 deg", frame="galactic"), radius=0.2 * u.deg)
axis = MapAxis.from_bounds(0.1, 30, 30, unit="TeV", name="energy", interp="log")
geom = RegionGeom(circle, axes=[axis])

# support creation via string
#geom = RegionGeom.create("galactic:()", axes=[axis])

In [130]:
spectrum = RegionNDMap(geom)

In [131]:
print(spectrum.geom)

RegionGeom

	axes       : ['lon', 'lat', 'energy']
	shape      : (1, 1, 30)
	ndim       : 1
	frame      : galactic
	center     : 0.0 deg, 0.0 deg



In [132]:
spectrum.get_by_coord({"energy": [1] * u.TeV})

{'energy': <Quantity [1.] TeV>, 'skycoord': <SkyCoord (Galactic): (l, b) in deg
    (0., 0.)>}


TypeError: pix_to_idx() takes 1 positional argument but 2 were given