In [None]:
%matplotlib inline
from lixtools.hdf import h5xs_scan
from py4xs.data2d import Data2d,unflip_array
from py4xs.plot import show_data,show_data_qphi
import pylab as plt
import numpy as np
from dask.distributed import Client
from IPython.display import display,HTML
import warnings

# Silence library warnings so the notebook output stays readable.
warnings.filterwarnings('ignore')

# Open the scan file and load it. "00template00.h5" appears to be a template
# placeholder that is substituted before the notebook is executed.
fn = "00template00.h5"
dt = h5xs_scan(fn)
dt.load_data()

# Try to connect to a Dask scheduler ("00scheduler_addr00" also looks like a
# substituted placeholder). When no client is available, `client` is None.
try:
    client = Client("00scheduler_addr00")
except Exception as err:
    # Catch ordinary errors only (a bare `except:` would also swallow
    # KeyboardInterrupt/SystemExit), and report why the connection failed.
    print(f"Could not create Dask client ... ({err})")
    client = None
else:
    print("created Dask client ...")
    
# Show the averaged 2D patterns and their q-phi maps for up to 3 samples,
# spread evenly over the sample list: 1 sample when there are <= 4 samples,
# 2 when there are 5-7, and 3 when there are more than 7.
samples = list(dt.h5xs.keys())
if len(samples) > 7:
    N = 3
elif len(samples) > 4:
    N = 2
else:
    N = 1
# With no samples, np.linspace(0, -1, 1) would still yield index 0 and
# samples[0] would raise IndexError; capping N makes the loop a no-op instead.
N = min(N, len(samples))
for i in np.linspace(0, len(samples)-1, N, dtype=int):
    sn = samples[i]
    if 'avg_data' in dt.proc_data[sn]:
        # Averaged images are already stored: wrap each detector's image in a
        # Data2d using that detector's experimental parameters.
        d2s = {}
        for ext, img in dt.proc_data[sn]['avg_data'].items():
            ep = dt.det(ext).exp_para
            d2 = Data2d(unflip_array(img, ep.flip), exp=ep)
            d2.md['frame #'] = "average"
            d2s[ext] = d2
    elif client:
        # Compute the average with the Dask client, then store it in
        # proc_data and save it so the next run can take the branch above.
        d2s = dt.h5xs[sn].get_d2(frn="average", detectors=dt.detectors, client=client)
        dt.proc_data[sn]['avg_data'] = {k: d.data.d for k, d in d2s.items()}
        dt.save_data(save_sns=sn, save_data_keys="avg_data")
    else:
        print(f"skip plotting {sn}, averaged data not available")
        continue

    print(f"plotting {sn} ...")
    # Detector images.
    fig = plt.figure(figsize=(10, 5))
    show_data(d2s, fig=fig, aspect=1, showMask=False, cmap='jet')
    # Combined q-phi map across all detectors.
    fig = plt.figure(figsize=(10, 5))
    show_data_qphi(d2s, dt.detectors, ax=fig.gca(), Nq=dt.qgrid, Nphi=31, clim="auto",
                   aspect='auto', logScale=True, cmap='jet', apply_symmetry=True, sc_factor="x")
    plt.show()

    
# Pick the entry whose maps are shown: the combined "overall" entry when
# there are several samples, otherwise the single sample (None if there are none).
if len(samples) > 1:
    sn = "overall"
elif samples:
    sn = samples[0]
else:
    sn = None
# Show the maps, two or three per row; aspect ratio per kind of map.
asp = {"maps": "auto", "tomo": 1}
for k in ["maps", "tomo"]:
    print("\n\n", k)
    if sn is None or sn not in dt.proc_data or k not in dt.proc_data[sn]:
        # Not every file has both map types; skip instead of raising KeyError.
        print(f"no {k} data available")
        continue
    # Go with absorption instead of transmission.
    sks = [key for key in dt.proc_data[sn][k].keys() if key != "transmission"]
    nsk = len(sks)
    if nsk == 0:
        # Nothing to plot; avoid creating an empty, zero-height figure.
        continue
    nc = 3 if nsk > 4 else 2
    nr = int(np.ceil(nsk / nc))
    fig = plt.figure(figsize=(nc*5, nr*5))
    for i, sk in enumerate(sks):
        ax = fig.add_subplot(nr, nc, i+1)
        # Each entry is expected to provide .plot(ax=..., aspect=..., cmap=...)
        # as used by the original code.
        dt.proc_data[sn][k][sk].plot(ax=ax, aspect=asp[k], cmap='jet')
        ax.set_title(sk)
    fig.subplots_adjust(wspace=0.375, hspace=0.375)
    plt.show()