# Labeled map axis: proof of concept

This notebook explores the option of a `LabeledMapAxis` and implements a minimal working solution. However, many questions remain unsolved, such as:
- How to serialise `LabeledMapAxis`?
- Which methods to support?
- How to make sure methods like `Map.plot_grid()` work with it?


In [1]:
from gammapy.maps import Map, MapAxis
import numpy as np

class LabeledMapAxis:
    """Map axis whose bins are identified by unique labels.

    Proof-of-concept axis meant to be passed to gammapy map constructors
    (e.g. ``Map.create(axes=[...])``). Each label is one bin, and the
    coordinates along this axis are the labels themselves.

    Parameters
    ----------
    labels : sequence
        Bin labels. Must be unique.
    name : str
        Axis name.

    Raises
    ------
    ValueError
        If ``labels`` contains duplicates.
    """

    node_type = "label"

    def __init__(self, labels, name=""):
        if len(set(labels)) != len(labels):
            raise ValueError("Node labels must be unique")

        self._labels = np.array(labels)
        self._name = name

    @property
    def name(self):
        """Name of the axis."""
        return self._name

    @property
    def nbin(self):
        """Number of bins, i.e. the number of labels."""
        return len(self._labels)

    def coord_to_idx(self, coord, clip=False):
        """Convert label(s) to bin index/indices.

        Parameters
        ----------
        coord : str or array-like of str
            Label(s) to look up.
        clip : bool
            Unused; accepted only so the signature matches ``pix_to_idx``
            and other axis methods taking ``clip``.

        Returns
        -------
        idx : numpy integer or `~numpy.ndarray`
            Index of each label, with the same shape as ``coord``.

        Raises
        ------
        ValueError
            If any entry of ``coord`` is not one of the axis labels.
        """
        coord = np.asarray(coord)
        # Compare every coordinate against every label; the trailing axis
        # of ``is_equal`` runs over the labels.
        is_equal = coord[..., np.newaxis] == self._labels
        is_valid = np.any(is_equal, axis=-1)

        if not np.all(is_valid):
            # Report the offending labels as a flat list (e.g. ['d4']) rather
            # than as an array carrying the extra comparison dimension.
            invalid = np.atleast_1d(coord[~is_valid]).tolist()
            raise ValueError(f"Not a valid label: {invalid}")

        return np.argmax(is_equal, axis=-1)

    def coord_to_pix(self, coord):
        """Convert label(s) to pixel coordinate(s); identical to the bin index."""
        return self.coord_to_idx(coord)

    def pix_to_idx(self, pix, clip=False):
        """Convert pixel coordinate(s) to bin index/indices.

        Pixels and indices coincide for this axis, so ``pix`` is returned
        unchanged; ``clip`` is accepted for interface compatibility only.
        """
        return pix

    @property
    def center(self):
        """Bin "centers" of the axis: the labels themselves."""
        return self._labels

    @property
    def bin_width(self):
        """Width of every bin, fixed to 1."""
        return np.ones(self.nbin)

    def __repr__(self):
        str_ = self.__class__.__name__ + "\n"
        str_ += "-" * len(self.__class__.__name__) + "\n\n"
        fmt = "\t{:<10s} : {:<10s}\n"
        str_ += fmt.format("name", self.name)
        str_ += fmt.format("nbins", str(self.nbin))
        str_ += fmt.format("node type", self.node_type)
        # ``tolist()`` turns numpy string scalars into plain ``str`` so this
        # prints ['d1', ...] rather than [np.str_('d1'), ...] on NumPy >= 2.
        str_ += fmt.format("labels", str(self._labels.tolist()))
        return str_.expandtabs(tabsize=2)

    def upsample(self):
        """Not implemented for labeled axes."""
        raise NotImplementedError

One could even imagine using a `LabeledMapAxis` as an "index" for `Map.__getitem__`:

In [2]:
def __getitem__(self, idx):
    """Return the image of the map at label ``idx``.

    NOTE(review): the axis name "label" is hard-coded, so this only works
    for maps whose labeled axis is named "label".
    """
    label_coord = {"label": idx}
    return self.get_image_by_coord(label_coord)


# Monkey-patch gammapy's Map class so that ``m["d1"]`` selects by label.
setattr(Map, "__getitem__", __getitem__)

In [3]:
axis = LabeledMapAxis(["d1", "d2", "d3"], name="label")
m = Map.create(width=5, axes=[axis])

# Give each label slice a constant expectation (1, 2, 3, ...) and draw Poisson
# counts from it. A seeded generator keeps the outputs below reproducible under
# "Restart & Run All".
rng = np.random.default_rng(seed=0)
expected = np.arange(1, axis.nbin + 1)[:, np.newaxis, np.newaxis]
m.data = rng.poisson(m.data + expected)

In [4]:
# Show the axis summary produced by LabeledMapAxis.__repr__.
axis_summary = repr(axis)
print(axis_summary)

LabeledMapAxis
--------------

  name       : label     
  nbins      : 3         
  node type  : label     
  labels     : ['d1', 'd2', 'd3']



In [5]:
# Select the image for label "d1" via label-based indexing.
image_d1 = m["d1"]
image_d1.data

array([[0, 2, 1, ..., 1, 0, 0],
       [1, 1, 1, ..., 0, 1, 2],
       [1, 0, 1, ..., 0, 1, 0],
       ...,
       [1, 2, 1, ..., 1, 1, 2],
       [2, 1, 1, ..., 1, 1, 2],
       [1, 1, 0, ..., 1, 1, 2]])

In [6]:
# Select the image for label "d2" via label-based indexing.
image_d2 = m["d2"]
image_d2.data

array([[3, 0, 1, ..., 3, 3, 4],
       [4, 0, 1, ..., 2, 4, 1],
       [2, 1, 2, ..., 2, 0, 2],
       ...,
       [0, 2, 0, ..., 2, 1, 4],
       [4, 3, 3, ..., 1, 4, 2],
       [3, 2, 1, ..., 1, 1, 2]])

In [7]:
# Select the image for label "d3" via label-based indexing.
image_d3 = m["d3"]
image_d3.data

array([[5, 3, 1, ..., 3, 4, 1],
       [6, 3, 4, ..., 3, 4, 2],
       [5, 3, 1, ..., 3, 0, 2],
       ...,
       [3, 1, 1, ..., 3, 4, 4],
       [2, 4, 6, ..., 2, 7, 3],
       [4, 2, 6, ..., 3, 4, 5]])

In [8]:
# Convert the map into a RegionNDMap and print its summary.
region_map = m.to_region_nd_map()
print(region_map)

RegionNDMap

	geom  : RegionGeom 
 	axes  : ['lon', 'lat', 'label']
	shape : (1, 1, 3)
	ndim  : 3
	unit  : 
	dtype : int64

