In [None]:
import numpy as np
from matplotlib import pyplot as plt
import copy

%matplotlib qt5

In [None]:
class Figure(object):
    """Backend-independent figure: a labelled dataset plus a transformer pipeline.

    ``data`` is indexed as ``data[0]`` / ``data[1]``; ``data[1]`` must expose
    ``.shape`` when no coordinates are supplied.  ``coords`` holds one
    ``(name, values)`` pair per axis of ``data[1]``.
    """

    def __init__(self, data, coords=None):
        """Store the dataset and build default coordinates if none are given.

        Parameters
        ----------
        data : sequence
            The dataset; element 1 is the array-like that is transformed.
        coords : list of (name, values) pairs, optional
            Axis descriptions.  When omitted, each axis ``i`` is named
            ``"Axis i"`` and gets integer positions ``0 .. size - 1``.
        """
        self.data = data
        self.coords = coords
        # Pipeline entries are (transformer class, keyword options) pairs.
        self.transformers = []

        if self.coords is not None:
            return

        dims = self.data[1].shape
        default_coords = []
        for index, length in enumerate(dims):
            default_coords.append((f"Axis {index}", np.arange(length)))
        self.coords = default_coords

    def transform(self):
        """Run every registered transformer in order.

        Each entry of ``self.transformers`` is instantiated without arguments
        and its ``transform`` is called with the current data, the current
        coords and the entry's options as keyword arguments.

        Returns
        -------
        tuple
            ``(data, coords)`` after the last transformer; the stored
            ``self.data`` and ``self.coords`` are not reassigned.
        """
        current_data, current_coords = self.data, self.coords
        for transformer_cls, options in self.transformers:
            step = transformer_cls()
            current_data, current_coords = step.transform(
                current_data, current_coords, **options
            )
        return current_data, current_coords

    def plot(self):
        """Draw the figure; must be provided by a backend-specific subclass."""
        raise NotImplementedError
        
        
class Transformer(object):
    """Base class for one step of a figure's transformation pipeline."""

    def transform(self, data, coords, **kw):
        """Return the transformed ``(data, coords)``; subclasses must override."""
        raise NotImplementedError


class Average(Transformer):
    """Transformer that averages the dataset along one axis."""

    def transform(self, data, coords, axis=0):
        """Average ``data[1]`` along ``axis`` and drop that axis from ``coords``.

        Parameters
        ----------
        data : sequence
            ``data[0]`` is passed through unchanged; ``data[1]`` must provide
            ``.mean(axis=...)``.
        coords : list of (name, values) pairs
            One entry per axis of ``data[1]``.
        axis : int or str, optional
            Axis position, or the name of an entry in ``coords``.  Negative
            positions count from the last axis.

        Returns
        -------
        tuple
            ``([data[0], averaged], remaining_coords)``.

        Raises
        ------
        ValueError
            If ``axis`` is a string that matches no coordinate name.
        """
        if isinstance(axis, str):
            names = [name for name, _ in coords]
            if axis not in names:
                raise ValueError("Unknown axis given.")
            axis = names.index(axis)

        # ``mean`` accepts negative axes, but the coords filter below compares
        # against non-negative positions.  Without this normalisation a
        # negative axis would shrink the data while leaving coords untouched.
        # Assumes len(coords) equals the number of dimensions of data[1].
        dropped = axis
        if isinstance(axis, (int, np.integer)) and axis < 0:
            dropped = axis + len(coords)

        remaining = [coord for i, coord in enumerate(coords) if i != dropped]
        averaged = [data[0], data[1].mean(axis=axis)]

        return averaged, remaining
        
        

class MplPlot(object):
    """Base class for matplotlib-backed plots."""

    # Class-level list: shared by every MplPlot subclass and instance.  Each
    # plot appends the axes it draws on.
    _axes = []


class MplLinePlot(MplPlot):
    """Line plot drawn on a single matplotlib axes."""

    # Keyword arguments passed to ``ax.plot`` unless overridden per call.
    default_opts = dict(
        marker = 'o',
        lw = 1,
        mew = 0,
    )

    def __init__(self, fig, ax=None):
        """Attach the plot to ``fig`` and pick the axes to draw on.

        Parameters
        ----------
        fig : object
            Figure providing ``add_subplot``; only called when ``ax`` is None.
        ax : object, optional
            Existing axes to draw on.  When omitted, a single subplot is
            created via ``fig.add_subplot(1, 1, 1)``.
        """
        self.fig = fig
        if ax is None:
            ax = fig.add_subplot(1, 1, 1)
        self.ax = ax
        # Register the axes actually in use.  Appending the raw ``ax``
        # argument recorded None whenever the subplot was created here.
        self._axes.append(self.ax)

    def plot(self, x, y, **plotkw):
        """Plot ``y`` against ``x``; ``plotkw`` overrides ``default_opts``."""
        opts = {**self.default_opts, **plotkw}
        self.ax.plot(x, y, **opts)


class MplFigure(Figure):
    """Figure rendered with matplotlib."""

    def __init__(self, *arg, **kw):
        """Forward all arguments to ``Figure`` and open a new pyplot figure."""
        super().__init__(*arg, **kw)
        self.figure = plt.figure()

    def plot(self):
        """Run the transformer pipeline and draw the result.

        Only a result with exactly one remaining coordinate axis is drawn, as
        a line plot labelled with the axis name (x) and ``data[0]`` (y).  Any
        other number of axes is silently skipped.
        """
        values, axes_coords = self.transform()

        if len(axes_coords) != 1:
            return

        only_axis = axes_coords[0]
        line = MplLinePlot(self.figure)
        line.plot(only_axis[1], values[1])
        line.ax.set_xlabel(only_axis[0])
        line.ax.set_ylabel(values[0])

In [None]:
# Demo: build a 2-D dataset and plot its average over one axis.
# Both axes are sampled at 51 points from 0 to 10.
x = np.linspace(0, 10, 51)
y = np.linspace(0, 10, 51)

# indexing='ij' keeps matrix ordering: axis 0 of z follows x, axis 1 follows y,
# which matches the order of the entries in `coords` below.
xx, yy = np.meshgrid(x, y, indexing='ij')
z = np.cos(xx) * np.cos(yy)

# One (name, values) pair per axis of z; data is (label, array).
coords = [('x', x), ('y', y)]
data = ('cos(x) * cos(y)', z)

# Averaging over the axis named 'y' leaves a single axis ('x'), which is the
# only case MplFigure.plot draws (as a line plot).
f = MplFigure(data, coords)
f.transformers.append((Average, dict(axis='y')))
f.plot()

In [None]:
# Was the truncated expression `z.m`, which raises AttributeError on a fresh
# Run All (ndarray has no attribute `m`).
# NOTE(review): presumably meant to inspect the mean over the 'y' axis
# (axis 1), i.e. the values plotted by the previous cell -- confirm intent.
z.mean(axis=1)[:5]