In [None]:
#%matplotlib notebook
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import importlib
import kk.tools
import kk.uitools
from kk.dggs import DGGS, io as dggs_io
#importlib.reload(kk.tools)


In [None]:
# Load the tiles from the HDF5 file. `dd` is later iterated with `.items()` /
# `.values()`, so it is presumably a mapping of DGGS cell address -> xarray
# Dataset (TODO confirm against kk.dggs.io.h5_load).
dd = dggs_io.h5_load('cbr-google-map.h5')

In [None]:
# Shared DGGS instance; cell_bounds / cell_center read this module-level name.
dg = DGGS()

def is_rgba(ds):
    """Return True when ``ds`` has red, green, blue AND alpha data variables.

    Extra variables are allowed; only the presence of all four matters.
    """
    present = set(ds.data_vars)
    return {'red', 'green', 'blue', 'alpha'} <= present

def is_rgb(ds):
    """Return True when ``ds`` has red, green and blue data variables.

    Extra variables (e.g. alpha) are allowed.
    """
    present = set(ds.data_vars)
    return {'red', 'green', 'blue'} <= present

def as_rgb(ds):
    """Stack the red, green and blue variables of ``ds`` depth-wise.

    Channel order is R, G, B. For 2-D (H, W) variables ``np.dstack``
    yields an (H, W, 3) array suitable for ``imshow``.
    """
    channels = [ds[name] for name in ('red', 'green', 'blue')]
    return np.dstack(channels)

def as_rgba(ds):
    """Stack the red, green, blue and alpha variables of ``ds`` depth-wise.

    Channel order is R, G, B, A. For 2-D (H, W) variables ``np.dstack``
    yields an (H, W, 4) array suitable for ``imshow``.
    """
    channels = [ds[name] for name in ('red', 'green', 'blue', 'alpha')]
    return np.dstack(channels)

def cell_bounds(addr):
    """Return the native-coordinate bounding box of DGGS cell ``addr``.

    Uses the module-level ``dg`` to build a 1x1 pixel-to-coordinate transform
    for the cell (``native=True, no_offset=True``), then maps pixel corners
    (0, 0) and (1, 1) through it.

    Returns ``(x_min, x_max, y_min, y_max)``, the ordering consumed by
    ``merge_extents`` and by matplotlib's ``imshow(extent=...)``.
    NOTE(review): assumes tr(0, 0) is the minimum corner and tr(1, 1) the
    maximum; if the transform flips y, y_min > y_max — confirm in kk.dggs.
    """
    # Only the transform is needed; the second returned value (a width/height
    # pair in the original unpacking) was never used.
    tr, _ = dg.pixel_coord_transform(addr, 1, 1, native=True, no_offset=True)
    x_min, y_min = tr(0, 0)
    x_max, y_max = tr(1, 1)
    return (x_min, x_max, y_min, y_max)

def cell_center(addr):
    """Return the coordinate that pixel (0, 0) of cell ``addr`` maps to.

    Uses the module-level ``dg`` with a 1x1 pixel grid in native coordinates.
    Unlike ``cell_bounds``, the default offset is kept (no ``no_offset=True``),
    presumably so pixel (0, 0) lands on the cell centre rather than its
    corner — TODO confirm in DGGS.pixel_coord_transform.
    """
    # Only the transform is needed; the size pair it also returns is unused.
    tr, _ = dg.pixel_coord_transform(addr, 1, 1, native=True)
    return tr(0, 0)

def merge_extents(e1, e2):
    """Return the smallest extent enclosing both ``e1`` and ``e2``.

    Extents are 4-sequences ``(x_min, x_max, y_min, y_max)``: the minimum is
    taken over indices 0 and 2, the maximum over indices 1 and 3.

    Either argument may be ``None`` ("no extent yet"); the other one is then
    returned unchanged, which makes this usable as a fold starting from
    ``None``. When both are given a new list is returned.
    """
    if e1 is None:
        return e2
    # Symmetric with the e1 check: previously a None e2 raised TypeError.
    if e2 is None:
        return e1

    return [min(e1[0], e2[0]), max(e1[1], e2[1]),
            min(e1[2], e2[2]), max(e1[3], e2[3])]

def hide_axis(ax):
    """Hide all axis decoration (ticks, labels, spines) on a matplotlib axes.

    The x and y axis objects are made invisible first, then the whole axis
    is switched off.
    """
    for axis_obj in (ax.xaxis, ax.yaxis):
        axis_obj.set_visible(False)
    ax.axis('off')
    
def _tile_image(ds, band_idx=0):
    """Choose the array to draw for one tile dataset.

    RGBA tiles are stacked with ``as_rgba`` and RGB tiles with ``as_rgb``.
    Any other tile yields the ``band_idx``-th data variable as a raw array.
    """
    if is_rgba(ds):
        return as_rgba(ds)
    if is_rgb(ds):
        return as_rgb(ds)
    return list(ds.data_vars.values())[band_idx].values


def plot_all(data, band_idx=0, south_square=0, north_square=0, ax=None):
    """Draw every tile in ``data`` onto a single matplotlib axes.

    Parameters
    ----------
    data : mapping
        Iterated with ``.items()`` as ``(addr, ds)`` pairs. Each ``ds`` must
        expose ``data_vars`` (presumably an xarray Dataset keyed by DGGS cell
        address — TODO confirm).
    band_idx : int
        Index into ``ds.data_vars`` of the variable drawn for tiles that are
        neither RGB nor RGBA.
    south_square, north_square : int
        Forwarded unchanged to ``DGGS.mk_display_helper``.
    ax : matplotlib Axes, optional
        Target axes. When omitted, a new figure with one axes covering the
        whole canvas is created.

    Returns
    -------
    The newly created Figure when ``ax`` is None, otherwise None.

    Tiles whose image is entirely NaN are skipped. When nothing is drawn
    (empty ``data`` or all tiles NaN) the axes limits are left untouched.
    """
    # The display helper maps (addr, image) -> (display image, extent).
    display_helper = DGGS().mk_display_helper(south_square=south_square,
                                              north_square=north_square)

    if ax is None:
        fig = plt.figure()
        ax = fig.add_axes([0, 0, 1, 1])
    else:
        fig = None

    extents = None
    for addr, ds in data.items():
        im = _tile_image(ds, band_idx)
        if np.isnan(im).all():
            continue  # nothing to show for this tile
        im, ee = display_helper(addr, im)
        extents = merge_extents(extents, ee)
        ax.imshow(im, extent=ee)

    # Guard: with no drawable tiles `extents` is still None, and slicing it
    # would raise TypeError.
    if extents is not None:
        ax.set_xlim(*extents[:2])
        ax.set_ylim(*extents[2:])
    return fig

In [None]:
# A 12x6-inch figure with a single axes spanning the whole canvas
# ([left, bottom, width, height] = [0, 0, 1, 1] in figure fractions).
fig = plt.figure(figsize=(12,6))
ax_main = fig.add_axes([0,0,1,1])
# Anchor the axes to the top-left corner. This only has an effect if the
# axes box gets shrunk to honour an aspect constraint (presumably the equal
# aspect imshow applies).
ax_main.set_anchor('NW')

# Strip ticks, labels and frame from every axes so only the imagery shows.
for ax in fig.axes:
    hide_axis(ax)

# Draw every tile of `dd` onto ax_main. south_square=3 is passed through to
# DGGS.mk_display_helper, whose meaning is defined there (TODO document).
# With ax given, plot_all returns None, so the cell shows no extra repr.
plot_all(dd, south_square=3, ax=ax_main)

In [None]:
# Inspect the first loaded tile (raises IndexError if `dd` is empty).
ds = [*dd.values()][0]
ds