### Import basic packages and set up data loading

In [None]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import pickle

In [None]:
import sys

# Reading the RADYN data needs the helpers from the MsLightweaver repository;
# change this path if the repository lives somewhere else.
msLightweaverPath = '/home/osborne/MsLightweaver'
sys.path.append(msLightweaverPath)
from ReadAtmost import read_atmost

# atmost.dat stores the RADYN thermodynamic state for every internal timestep.
# to_SI() is called once here (its return value is unused, so it acts in place)
# so that later cells can work in SI units.
atmost = read_atmost('atmost.dat')
atmost.to_SI()

In [None]:
# Define a couple of functions to load a particular timestep of the LI (full
# treatment, "AdvNrLosses") and LE (Lyman lines excluded, "NoLybbLosses") runs.
# Wavelength grid saved alongside the LE timesteps
wavelengthLE = np.load('TimestepsNoLybbLosses/Wavelength.npy')
def load_step_LE(idx):
    """Unpickle and return timestep `idx` of the LE (Lyman lines excluded) run.

    The file read is TimestepsNoLybbLosses/Step_<idx zero-padded to 6 digits>.pickle,
    relative to the current working directory.
    """
    stepPath = 'TimestepsNoLybbLosses/Step_%.6d.pickle' % idx
    with open(stepPath, 'rb') as stepFile:
        return pickle.load(stepFile)
# Wavelength grid saved alongside the LI (full treatment) timesteps
wavelength = np.load('TimestepsAdvNrLosses/Wavelength.npy')
def load_step(idx):
    """Unpickle and return timestep `idx` of the LI (full treatment) run.

    Reads TimestepsAdvNrLosses/Step_<idx zero-padded to 6 digits>.pickle from the
    current working directory.
    """
    fileName = 'TimestepsAdvNrLosses/Step_%.6d.pickle' % idx
    with open(fileName, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded

In [None]:
# Load the standard RADYN output (radyn_out.cdf) through radynpy's LazyRadynData
# NOTE(review): presumably variables are read from the CDF on first access — confirm.
from radynpy.cdf import LazyRadynData
radyn = LazyRadynData('radyn_out.cdf')

In [None]:
import astropy.units as u
import lightweaver as lw

# Get the RADYN intensity for a line down a particular direction; in SI
def line_intensity_with_cont(data, kr, muIdx):
    """Return the RADYN emergent line profile (line + continuum) for transition `kr`.

    Parameters
    ----------
    data : RADYN output object (e.g. LazyRadynData)
        Must provide cont, alamb, q, nq, qnorm, cc and outint.
    kr : int
        RADYN transition index; must be a bound-bound transition (cont[kr] falsy).
    muIdx : int
        Index of the viewing direction in ``outint``.

    Returns
    -------
    tuple of astropy Quantity or None
        (wavelength in nm, specific intensity in J/m2/sr/Hz/s with one row per
        time in ``outint``). Returns None after printing a message if `kr` is a
        bound-free transition.
    """
    if data.cont[kr]:
        print('line_intensity cannot compute bf intensity')
        return

    nq = data.nq[kr]
    # Wavelength of each frequency point from its offset q (scaled by qnorm);
    # NOTE(review): 1e5 presumably converts km/s to cm/s to match cc — confirm.
    wl = data.alamb[kr] / (data.q[0:nq, kr] * data.qnorm * 1e5 / data.cc + 1)
    # outint[:, 0] is added to every line point (index 0 appears to hold the
    # continuum the line points are stored relative to — TODO confirm against RADYN),
    # then cc * 1e8 / wl**2 converts per-frequency to per-Angstrom.
    intens = (data.outint[:, 1:nq + 1, muIdx, kr]
              + data.outint[:, 0, muIdx, kr][:, np.newaxis]) \
        * data.cc * 1e8 / (wl**2)[np.newaxis, :]
    # At this point wl is in angstrom and intens in erg/cm^2/sr/A/s.
    # Reverse the frequency ordering (presumably giving ascending wavelength),
    # attach units and convert the intensity to SI per-Hz units.
    wl = wl[::-1] << u.Angstrom
    intens = lw.utils.convert_specific_intensity(
        wl, intens[:, ::-1] << u.erg / u.cm**2 / u.sr / u.Angstrom / u.s, 'J/m2/sr/Hz/s')

    return wl << u.nm, intens

In [None]:
# Load Ca 8542 data from RADYN
# NOTE(review): transition index 20 is assumed to be Ca II 854.2 nm in this RADYN
# model and muIdx=-1 the last direction in outint — confirm against radyn.alamb.
lineData = line_intensity_with_cont(radyn, 20, -1)

### Plot 8542 Line Profiles

In [None]:
from matplotlib.ticker import MultipleLocator, AutoMinorLocator
from MsLightweaverAtoms import CaII
# The seaborn style sheets were renamed 'seaborn-v0_8-*' in matplotlib 3.6 and the
# old names removed in 3.8; use the new name when this matplotlib provides it.
if 'seaborn-v0_8-colorblind' in plt.style.available:
    colorblindStyle = 'seaborn-v0_8-colorblind'
else:
    colorblindStyle = 'seaborn-colorblind'
with plt.style.context(colorblindStyle):
    # Simulation times [s] shown in the four panels
    plotTimes = [5, 11, 20, 40]
    fig, ax = plt.subplots(1, 4, figsize=(7.5, 3), sharex=True, sharey=True, gridspec_kw={'wspace':0, 'width_ratios':[1,1,1,1]})
    ax = ax.ravel()

    ca = CaII()
    # Rest wavelength of the last line of the Ca II model (presumably the 854.2 nm line)
    lambda0 = ca.lines[-1].lambda0
    # Only the first panel carries labels so the figure legend has one entry per model
    firstPanelLabels = ('RADYN', 'Lightweaver LI', 'Lightweaver LE')
    for i, time in enumerate(plotTimes):
        # First saved Lightweaver step / RADYN output at or after `time`
        lwIdx = np.searchsorted(atmost.time, time)
        radynIdx = np.searchsorted(radyn.time, time)
        step = load_step(lwIdx)
        stepNoLybb = load_step_LE(lwIdx)

        radynLabel, liLabel, leLabel = firstPanelLabels if i == 0 else (None, None, None)
        ax[i].plot(lineData[0].value - lambda0, lineData[1][radynIdx], label=radynLabel)
        ax[i].plot(wavelength - lambda0, step['Iwave'][:, -1], '--', label=liLabel)
        ax[i].plot(wavelengthLE - lambda0, stepNoLybb['Iwave'][:, -1], '--', label=leLabel)

        ax[i].set_title('%.2f s' % time)
        ax[i].xaxis.set_minor_locator(AutoMinorLocator())
        ax[i].tick_params(which='both', direction='in')

    ax[0].set_yscale('log')
    ax[0].set_ylim(1e-8, 1.3e-7)
    ax[0].set_xlim(-0.08, 0.08)

    ax[0].yaxis.offsetText.set_visible(False)
    ax[0].set_ylabel('Specific Intensity [SI]')
    for a in ax:
        # Raw string: '\D' and '\l' are invalid escape sequences in a plain literal
        a.set_xlabel(r'$\Delta\lambda$ [nm]')
    fig.subplots_adjust(bottom=0.15)
    fig.legend(*ax[0].get_legend_handles_labels(), loc=(0.71,0.66), frameon=False, handletextpad=0.5, handlelength=1.5)
#     fig.savefig('LineProfiles1e10_8542.png', dpi=300)


### Contribution Functions

The contribution functions are not saved during the simulation, so they must be recomputed from the saved populations and atmospheric structure. Lightweaver can easily do this here in the notebook.

In [None]:
# Total hydrogen number density for every RADYN timestep, from the mass density
nHTot = atmost.d1 / (lw.DefaultAtomicAbundance.massPerH * lw.Amu)

# Initial (first timestep) atmospheric state. Defensive copies are taken so that
# the in-place updates made later (e.g. atmos.temperature[:] = ...) cannot alias
# the RADYN arrays.
initialState = dict(
    depthScale=np.copy(atmost.z1[0]),
    temperature=np.copy(atmost.tg1[0]),
    vlos=np.copy(atmost.vz1[0]),
    vturb=np.copy(atmost.vturb),
    ne=np.copy(atmost.ne1[0]),
    nHTot=np.copy(nHTot[0]),
)
atmos = lw.Atmosphere.make_1d(scale=lw.ScaleType.Geometric, **initialState)
atmos.quadrature(5)

In [None]:
from lightweaver.rh_atoms import H_6_atom, C_atom, O_atom, OI_ord_atom, Si_atom, Al_atom, Fe_atom, FeI_atom, MgII_atom, N_atom, Na_atom, S_atom, CaII_atom, He_9_atom
from MsLightweaverAtoms import H_6, CaII, H_6_nasa, CaII_nasa, H_6_nobb
from lightweaver.LwCompiled import FastBackground
from weno4 import weno4
from matplotlib.colors import LogNorm, SymLogNorm
# Model atoms: H and Ca II come from MsLightweaverAtoms (NOTE(review): presumably
# the variants with the Fang non-thermal collisional rates mentioned below),
# the rest from Lightweaver's RH atom set.
FchromaAtoms = [H_6(), CaII(), He_9_atom(), C_atom(), O_atom(), Si_atom(), Fe_atom(),
                MgII_atom(), N_atom(), Na_atom(), S_atom()]

# Set up our atoms and Lightweaver context
# Only H and Ca are marked active; the other atoms are passive.
aSet = lw.RadiativeSet(FchromaAtoms)
aSet.set_active('H', 'Ca')
spect = aSet.compute_wavelength_grid()
eqPops = aSet.compute_eq_pops(atmos)
# Background-opacity provider: FastBackground with a configurable thread count (default 12)
def fast_background(*args, Nthreads=12):
    return FastBackground(*args, Nthreads=Nthreads)
ctx = lw.Context(atmos, spect, eqPops, initSol=lw.InitialSolution.Lte, conserveCharge=False, Nthreads=12, backgroundProvider=fast_background)
# Request depth-resolved data; compute_tau below reads ctx.depthData.chi
ctx.depthData.fill = True

# Add the extra information expected by our Atoms with Fang non-thermal beam collisional rates
# Negligible (1e-20) beam heating initially; load_timestep overwrites bHeat per timestep.
atmos.bHeat = np.ones_like(atmost.bheat1[0]) * 1e-20
atmos.hPops = eqPops['H']

In [None]:
# Display the Ca model atom attached to the population store (cell output only)
eqPops.atomicPops['Ca'].model

In [None]:
# An omission from Lightweaver (will be added in a later version), computes tau from the complete chi array
def compute_tau(ctx, mu, upDown=1):
    """Integrate optical depth from ctx.depthData.chi along ray `mu`.

    The opacity is integrated with the trapezoidal rule between neighbouring
    depth points, using the path length (height[k-1] - height[k]) / muz[mu].
    The first depth point is given tau = 1e-20 rather than exactly zero.

    Parameters
    ----------
    ctx : lw.Context
        Context whose depthData.chi is populated (depthData.fill = True) and whose
        kwargs contain the 'atmos' used to build it.
    mu : int
        Ray index into chi's second axis and atmos.muz.
    upDown : int, optional
        Index into chi's third axis. Defaults to 1, the value the original
        hard-coded (NOTE(review): presumably the outgoing direction — confirm).

    Returns
    -------
    np.ndarray
        tau with shape (chi.shape[0], chi.shape[3]): one row per wavelength and
        one column per depth point.
    """
    chiMu = ctx.depthData.chi[:, mu, upDown, :]
    atmos = ctx.kwargs['atmos']

    # Path-length element between consecutive depth points, broadcast over wavelength
    dz = (atmos.height[:-1] - atmos.height[1:])[np.newaxis, :]
    dtau = 0.5 * (chiMu[:, 1:] + chiMu[:, :-1]) * dz / atmos.muz[mu]

    # NOTE(cmo): Compute tau for all wavelengths.
    # cumsum accumulates sequentially, giving the same running sums as an explicit loop.
    tau = np.empty_like(chiMu)
    tau[:, 0] = 1e-20
    tau[:, 1:] = dtau
    return np.cumsum(tau, axis=1)

# Load the data from a particular timestep into the Lightweaver context
def load_timestep(path, idx):
    """Load RADYN state `idx` and the saved Lightweaver step into the global context.

    Updates the notebook-global `atmos` and `eqPops` in place, using `atmost` and
    `nHTot` for the thermodynamic state and the pickled step for populations and
    electron density. It then calls ctx.update_deps().

    Parameters
    ----------
    path : str
        Directory holding the Step_XXXXXX.pickle files; a trailing separator is
        optional.
    idx : int
        Timestep index into the atmost arrays and the pickled steps.
    """
    import os

    with open(os.path.join(path, 'Step_%.6d.pickle' % idx), 'rb') as pkl:
        step = pickle.load(pkl)

    atmos.temperature[:] = atmost.tg1[idx]
    atmos.vlos[:] = atmost.vz1[idx]
    atmos.nHTot[:] = nHTot[idx]
    atmos.bHeat[:] = atmost.bheat1[idx]

    atmos.height[:] = atmost.z1[idx]

    for name, pops in step['eqPops'].items():
        # 'n' is None for atoms without saved NLTE populations; nStar is always saved
        if pops['n'] is not None:
            eqPops.atomicPops[name].pops[:] = pops['n']
        eqPops.atomicPops[name].nStar[:] = pops['nStar']
    # Electron density is taken from the saved Lightweaver step. RADYN's ne1 is not
    # used here; the original wrote it first and then always overwrote it.
    atmos.ne[:] = step['ne']
    ctx.update_deps()
    
# Support functions for plotting the contribution function
def scale_cfn(cfn, scaleLimits=None):
    """Return a copy of `cfn` clipped to a [low, high] range.

    When `scaleLimits` is None, low/high are the smallest and largest finite
    entries of `cfn` (so infinities are clamped to them). Otherwise they are
    scaleLimits[0] and scaleLimits[1]. No logarithm is applied.
    """
    if scaleLimits is not None:
        low, high = scaleLimits[0], scaleLimits[1]
    else:
        finiteVals = cfn[np.isfinite(cfn)]
        low, high = np.min(finiteVals), np.max(finiteVals)
    return np.clip(cfn, low, high)

# Scale a line profile to overlie the contfn
def scale_profile(wavelengthGrid, profile, wavelengthRange, scaleRange):
    """Affinely rescale a copy of `profile` for overplotting.

    The reference window runs from index searchsorted(wavelengthGrid,
    wavelengthRange[0]) up to and including searchsorted(wavelengthGrid,
    wavelengthRange[1]). Its minimum is mapped to scaleRange[0] and its maximum
    to scaleRange[1]. Points outside the window get the same transform, so they
    may fall outside scaleRange. `profile` itself is left untouched.
    """
    start = np.searchsorted(wavelengthGrid, wavelengthRange[0])
    stop = np.searchsorted(wavelengthGrid, wavelengthRange[1]) + 1
    window = slice(start, stop)

    scaled = np.copy(profile)
    scaled -= scaled[window].min()
    scaled /= scaled[window].max()
    scaled *= (scaleRange[1] - scaleRange[0])
    scaled += scaleRange[0]
    return scaled

# Compute the tau=1 line via interpolation
def tau1_line(tau, z):
    """Return, for each row of `tau`, the value of `z` at which tau reaches 1.

    Each row is interpolated with weno4 at tau = 1.0, using the row as the
    abscissa and `z` as the ordinate.
    NOTE(review): assumes each row of tau is monotonically increasing — confirm.
    """
    nRows = tau.shape[0]
    tau1 = np.zeros(nRows)
    for la, tauRow in enumerate(tau):
        tau1[la] = weno4(1.0, tauRow, z)
    return tau1

In [None]:
# This cell computes and plots the contribution function
timeToPlot = 11.0
lwIdx = np.searchsorted(atmost.time, timeToPlot)


def _iterate_ctx(maxIter=10, tol=1e-3):
    """Run up to `maxIter` formal solutions on the global ctx, stopping once dJ < `tol`."""
    for _ in range(maxIter):
        dJ = ctx.formal_sol_gamma_matrices()
        if dJ < tol:
            break


# Use ctx for LI case
load_timestep('TimestepsAdvNrLosses/', lwIdx)
_iterate_ctx()

fullNe = np.copy(ctx.kwargs['atmos'].ne)
fullTau = compute_tau(ctx, mu=-1)
fullCfn = lw.compute_contribution_fn(ctx, mu=-1)
caPopsFull = np.copy(ctx.eqPops['Ca'])

# LE case
load_timestep('TimestepsNoLybbLosses/', lwIdx)
_iterate_ctx()

noLybbNe = np.copy(ctx.kwargs['atmos'].ne)
noLybbTau = compute_tau(ctx, mu=-1)
noLybbCfn = lw.compute_contribution_fn(ctx, mu=-1)
caPopsNoLybb = np.copy(ctx.eqPops['Ca'])

# Saved spectra for the same timestep, for the overplotted profiles
stepNoHbb = load_step_LE(lwIdx)
step = load_step(lwIdx)

# Now plot
fig, ax = plt.subplots(1, 3, sharey=True, figsize=(7.5, 3), gridspec_kw={'wspace':0})
heightEdges = lw.compute_height_edges(ctx) / 1e6
wlEdges = lw.compute_wavelength_edges(ctx)

# Colour-scale limits shared by both cases (only maxVal is currently used)
minVal = min(np.min(noLybbCfn), np.min(fullCfn))
maxVal = max(np.max(noLybbCfn), np.max(fullCfn))

line = ca.lines[-1]
lambda0 = line.lambda0
waveRange = (-0.1, 0.1)
waveRangeFull = (854.3444, 854.5444)
# Wavelength cells covering waveRange (with one cell of padding each side).
# Cells lowerIdx..higherIdx-1 are bounded by edges lowerIdx..higherIdx, so the
# edge slice must be one longer than the data slice for flat shading.
lowerIdx = max(np.searchsorted(wlEdges-lambda0, waveRange[0]) - 1, 0)
higherIdx = np.searchsorted(wlEdges-lambda0, waveRange[1]) + 1
cellSlice = slice(lowerIdx, higherIdx)
edgeSlice = slice(lowerIdx, higherIdx + 1)
ax[0].pcolormesh(wlEdges[edgeSlice]-lambda0, heightEdges, fullCfn.T[:, cellSlice], cmap='Blues', norm=SymLogNorm(linthresh=5e-14, vmin=1e-20, vmax=maxVal))
ax[1].pcolormesh(wlEdges[edgeSlice]-lambda0, heightEdges, noLybbCfn.T[:, cellSlice], cmap='Blues', norm=SymLogNorm(linthresh=5e-14, vmin=1e-20, vmax=maxVal))

ax[0].set_xlim(*waveRange)
ax[1].set_xlim(*waveRange)
ax[0].set_ylim(None, 2.4)

# Overplot the emergent profile (rescaled onto the height axis) and the tau=1 height
ax[0].plot(wavelength-lambda0, scale_profile(wavelength, step['Iwave'][:, -1], waveRangeFull, ax[0].get_ylim()), alpha=0.5)
ax[0].plot(wavelength-lambda0, tau1_line(fullTau, atmost.z1[lwIdx] / 1e6), 'r', alpha=0.5)
ax[1].plot(wavelengthLE-lambda0, scale_profile(wavelengthLE, stepNoHbb['Iwave'][:, -1], waveRangeFull, ax[0].get_ylim()), '--', alpha=0.5)
ax[1].plot(wavelength-lambda0, tau1_line(noLybbTau, atmost.z1[lwIdx] / 1e6), 'r', alpha=0.5)

# Upper/lower level and last-level (Ca III) populations; solid = LI, dashed = LE
ax[2].semilogx(caPopsFull[line.j], atmost.z1[lwIdx] / 1e6, c='C0', label='upper')
ax[2].semilogx(caPopsFull[line.i], atmost.z1[lwIdx] / 1e6, c='C1', label='lower')
ax[2].semilogx(caPopsNoLybb[line.j], atmost.z1[lwIdx] / 1e6, '--', c='C0')
ax[2].semilogx(caPopsNoLybb[line.i], atmost.z1[lwIdx] / 1e6, '--', c='C1')
ax[2].semilogx(caPopsFull[-1], atmost.z1[lwIdx] / 1e6, c='C3', label='Ca ɪɪɪ')
ax[2].semilogx(caPopsNoLybb[-1], atmost.z1[lwIdx] / 1e6, '--', c='C3')
# Single-point line (not visible) that only supplies a 'Temperature' legend entry for the twin axis
ax[2].semilogx(caPopsNoLybb[-1][0], atmost.z1[lwIdx][0] / 1e6, 'C2', label='Temperature')

ax[2].tick_params(axis='x', which='both')
ax[0].set_xticks([-0.05, 0, 0.05])
ax[0].xaxis.set_minor_locator(AutoMinorLocator())
ax[1].set_xticks([-0.05, 0, 0.05])
ax[1].xaxis.set_minor_locator(AutoMinorLocator())
ax3 = ax[2].twiny()
fig.subplots_adjust(top=0.82, bottom=0.18, hspace=0.5)
ax3.semilogx(atmost.tg1[lwIdx], atmost.z1[lwIdx] / 1e6, c='C2')
ax3.set_xlabel('T [K]', c='C2')
ax[2].set_xlabel('Number Density [m$^{-3}$]')
# The x axes show offsets from lambda0; raw strings avoid invalid escape sequences
ax[0].set_xlabel(r'$\Delta\lambda$ [nm]')
ax[1].set_xlabel(r'$\Delta\lambda$ [nm]')
ax[0].set_ylabel('Height [Mm]')
ax[0].set_title('Full treatment', size=11)
ax[1].set_title('Lyman lines excluded', size=11)
for a in ax:
    a.tick_params(which='both', direction='in')
ax3.tick_params(which='both', direction='in')
leg = fig.legend(*ax[2].get_legend_handles_labels(), loc=(0.67,0.23), frameon=False, handletextpad=0, handlelength=0)

# Colour each legend label like its line. Legend.legendHandles was renamed
# legend_handles in matplotlib 3.7 and later removed.
legendHandles = getattr(leg, 'legend_handles', None)
if legendHandles is None:
    legendHandles = leg.legendHandles
for handle, label in zip(legendHandles, leg.texts):
    label.set_color(handle.get_color())
fig.suptitle('Time: %.2f s' % atmost.time[lwIdx])
# fig.savefig('ContFn1e10_8542_%d.png' % atmost.time[lwIdx], dpi=300)

### Plot the loss comparisons

In [None]:
# Select RADYN transitions whose cooling is summed below. ielrad == 1 is treated
# as H and ielrad == 2 as Ca; cont == 0 marks lines and cont == 1 continua.
#   radynCaIdxs  : Ca lines
#   radynTotIdxs : all lines, all H transitions, He continua with alamb < 900
#   radynHCaIdxs : H and Ca lines, H continua, He continua with alamb < 900
radynCaIdxs = []
radynHCaIdxs = []
radynTotIdxs = []

for kr in range(radyn.ielrad.shape[0]):
    atomIdx = radyn.ielrad[kr]
    isLine = radyn.cont[kr] == 0
    isCont = radyn.cont[kr] == 1
    isShortHeCont = (str(radyn.atomid[0, atomIdx - 1]).startswith('he')
                     and isCont and radyn.alamb[kr] < 900.0)

    if atomIdx == 2 and isLine:
        radynCaIdxs.append(kr)

    if atomIdx == 1 or isLine:
        radynTotIdxs.append(kr)
    if isShortHeCont:
        radynTotIdxs.append(kr)

    if atomIdx in (1, 2) and isLine:
        radynHCaIdxs.append(kr)
    if atomIdx == 1 and isCont:
        radynHCaIdxs.append(kr)
    if isShortHeCont:
        radynHCaIdxs.append(kr)

# Sum |cooling| over the selected transitions (last axis of radyn.cool)
radynCaLosses = np.sum(np.abs(radyn.cool[:, :, radynCaIdxs]), axis=-1)
radynAbsCool = np.sum(np.abs(radyn.cool[:, :, radynTotIdxs]), axis=-1)

In [None]:
import matplotlib.ticker as ticker
from weno4 import weno4
fig = plt.figure(figsize=(7.5, 3), constrained_layout=True)
# Two plot panels, each followed by a narrow colourbar axis
gs = fig.add_gridspec(1, 4, width_ratios=[1, 0.1, 1, 0.1], height_ratios=[1])
gs.update(left=0.1, right=0.95, wspace=0.2, hspace=0.4)
caLossesBig = []
caLossesLEBig = []
radynCaLossesBig = []
radynAbsCoolBig = []
tempInterp = []  # NOTE(review): filled below but not used in this cell
# Common height grid [cm] from the bottom of the initial RADYN grid up to 2.5e8 cm (2.5 Mm)
fixedZGrid = np.linspace(radyn.z1[0].min(), 2.5e8, 10000)
# Get the losses for each of the "typical" 500 RADYN timesteps (0.1s cadence)
for tIdx in range(500):
    # First Lightweaver step at or after this RADYN output time
    lwIdx = np.searchsorted(atmost.time, radyn.time[tIdx])
    stepNoHbb = load_step_LE(lwIdx)
    step = load_step(lwIdx)
    tempInterp.append(weno4(fixedZGrid, radyn.z1[tIdx], radyn.tg1[tIdx]))
    # Sum of |losses| over the last five entries of step['losses']
    # (NOTE(review): presumably the Ca II lines — confirm against step['lines'])
    caLosses = np.sum(np.abs(np.stack(step['losses'][-5:])), axis=0)
    caLossesLE = np.sum(np.abs(np.stack(stepNoHbb['losses'][-5:])), axis=0)
    # Interpolate everything onto fixedZGrid; the Lightweaver losses are assumed to
    # be on RADYN's z1 grid at this time (NOTE(review): confirm atmost/radyn grids agree)
    caLossesBig.append(weno4(fixedZGrid, radyn.z1[tIdx], caLosses))
    caLossesLEBig.append(weno4(fixedZGrid, radyn.z1[tIdx], caLossesLE))
    radynCaLossesBig.append(weno4(fixedZGrid, radyn.z1[tIdx], radynCaLosses[tIdx]))
    radynAbsCoolBig.append(weno4(fixedZGrid, radyn.z1[tIdx], radynAbsCool[tIdx]))
caLossesBig = np.stack(caLossesBig)
caLossesLEBig = np.stack(caLossesLEBig)
radynCaLossesBig = np.stack(radynCaLossesBig)
radynAbsCoolBig = np.stack(radynAbsCoolBig)
tempInterp = np.stack(tempInterp)

# Cell edges: midpoints between grid points, extrapolated by half a spacing at each end
zEdges = np.concatenate(((fixedZGrid[0] + 0.5 * (fixedZGrid[0] - fixedZGrid[1]),),
                0.5 * (fixedZGrid[1:] + fixedZGrid[:-1]),
                (fixedZGrid[-1] + 0.5 * (fixedZGrid[-1] - fixedZGrid[-2]),)
               ))
# Time edges: 0 followed by midpoints of radyn.time. This gives len(radyn.time) edges,
# so the 500-column panels presume radyn.time has 501 entries (NOTE(review): confirm).
timeEdges = np.concatenate(((0,), 0.5 * (radyn.time[:-1] + radyn.time[1:])))

ax0 = fig.add_subplot(gs[0])
cb0 = fig.add_subplot(gs[1])
ax1 = fig.add_subplot(gs[2])
cb1 = fig.add_subplot(gs[3])
# Left: relative change of Ca line losses, (LE - LI) / LE. Right: Ca share of total |cooling|.
lpanel = ((caLossesLEBig - caLossesBig) / (caLossesLEBig)).T
rpanel = (radynCaLossesBig / radynAbsCoolBig).T
# Symmetric limits so zero sits at the centre of the diverging colormap
maxVal = max(lpanel.max(), abs(lpanel.min()))
mesh0 = ax0.pcolormesh(timeEdges, zEdges / 1e8, 
                       lpanel
                       , cmap='RdBu_r', norm=SymLogNorm(linthresh=1e-2, vmax=maxVal, vmin=-maxVal))
mesh1 = ax1.pcolormesh(timeEdges, zEdges / 1e8, 
                       rpanel
                       , cmap='Spectral_r', vmax=0.4)
fig.colorbar(mesh0, cax=cb0)
fig.colorbar(mesh1, cax=cb1)
ax0.set_title('Relative change of Ca line losses\nwith LE and LI treatments')
ax1.set_title('Proportion of total radiative\nlosses due to Ca lines')
ax0.set_xlabel('t [s]')
ax1.set_xlabel('t [s]')
ax0.set_ylabel('Height [Mm]')

# fig.savefig('Losses1e10_Im.png', dpi=300)
# with open('F10LossPanels.pickle', 'wb') as pkl:
#     pickle.dump({'timeEdges': timeEdges, 'zEdges': zEdges / 1e8, 'data': lpanel * rpanel}, pkl)

The final plot is simply the product of the left and right panels of the previous figure. As shown in the commented-out code above, these products were saved for both the F10 and F9 simulations; the following cell reloads them.

In [None]:
# Reload the saved panels (lpanel * rpanel, time edges and height edges [Mm]) for both simulations
with open('F10LossPanels.pickle', 'rb') as pkl:
    F10LossData = pickle.load(pkl)
with open('F9LossPanels.pickle', 'rb') as pkl:
    F9LossData = pickle.load(pkl)

fig = plt.figure(figsize=(7.5, 3), constrained_layout=True)
# Two plot panels, each followed by a narrow colourbar axis
gs = fig.add_gridspec(1, 4, width_ratios=[1, 0.1, 1, 0.1], height_ratios=[1])
gs.update(left=0.1, right=0.95, wspace=0.2, hspace=0.4)
ax0, cb0, ax1, cb1 = (fig.add_subplot(gs[k]) for k in range(4))


def _symmetric_loss_mesh(axis, lossData):
    """Draw lossData['data'] on a symmetric-log diverging scale; return (mesh, limit)."""
    panel = lossData['data']
    limit = max(panel.max(), abs(panel.min()))
    mesh = axis.pcolormesh(lossData['timeEdges'], lossData['zEdges'], panel,
                           cmap='RdBu_r', norm=SymLogNorm(1e-2, vmin=-limit, vmax=limit))
    return mesh, limit


mesh0, maxVal = _symmetric_loss_mesh(ax0, F9LossData)
mesh1, maxVal = _symmetric_loss_mesh(ax1, F10LossData)
fig.colorbar(mesh0, cax=cb0)
fig.colorbar(mesh1, cax=cb1)
ax0.set_title('Variation in losses due to calcium\ntreatment in F9 simulation')
ax1.set_title('Variation in losses due to calcium\ntreatment in F10 simulation')
for panelAxis in (ax0, ax1):
    panelAxis.set_xlabel('t [s]')
ax0.set_ylabel('Height [Mm]')

To plot only the case computed in this notebook, we can simply do:

In [None]:
fig = plt.figure(figsize=(4.5, 3), constrained_layout=True)
# One plot panel followed by a narrow colourbar axis
gs = fig.add_gridspec(1, 2, width_ratios=[1, 0.1], height_ratios=[1])
gs.update(left=0.1, right=0.95, wspace=0.2, hspace=0.4)
ax0 = fig.add_subplot(gs[0])
cb0 = fig.add_subplot(gs[1])
# Relative change of Ca line losses weighted by the Ca share of total RADYN losses
data = lpanel * rpanel
# Symmetric limits so zero sits at the centre of the diverging colormap
maxVal = max(abs(data.max()), abs(data.min()))
mesh0 = ax0.pcolormesh(timeEdges, zEdges / 1e8,
                       data, cmap='RdBu_r', norm=SymLogNorm(1e-2, vmin=-maxVal, vmax=maxVal)
                       )
fig.colorbar(mesh0, cax=cb0)
ax0.set_title('Variation in losses due to calcium\ntreatment in F10 simulation')
ax0.set_xlabel('t [s]')
# (The former ax1.set_xlabel call targeted an axis of the previous figure and is removed.)
ax0.set_ylabel('Height [Mm]')

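Functions to create a Zarr store from the pickled timesteps (the second definition, which buffers everything in memory before writing, replaces the first).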

In [None]:
import zarr
from tqdm import tqdm
def convert_to_zarr(outFile, step_loader, wavelength):
    """Write all pickled timesteps into a Zarr store at `outFile`, one step at a time.

    The arrays are first created in the store (zero-filled) and then filled by
    assigning each timestep directly into the store.

    Parameters
    ----------
    outFile : str
        Path of the Zarr store, opened with mode='w'.
    step_loader : callable
        step_loader(t) returns the pickled step dict for timestep t
        (e.g. load_step or load_step_LE).
    wavelength : array_like
        Wavelength grid stored as out['wavelength'].
    """
    # NOTE(review): the final atmost time is skipped — presumably no step is saved for it.
    Ntime = atmost.time.shape[0]-1
    Nspace = atmost.z1.shape[1]
    # Step 0 defines the layout (elements, level counts, rate pairs, lines)
    step = step_loader(0)
    out = zarr.convenience.open(outFile, mode='w')
    out['time'] = atmost.time[:-1]
    out['wavelength'] = wavelength
    eqPops = out.require_group('eqPops')
    for ele in step['eqPops'].keys():
        g = eqPops.require_group(ele)
        Nlevel = step['eqPops'][ele]['nStar'].shape[0]
        # Populations 'n' and radiative rates are stored only when 'n' is not None
        if step['eqPops'][ele]['n'] is not None:
            rates = g.require_group('radiativeRates')
            for ratePair in step['eqPops'][ele]['radiativeRates'].keys():
                rates[repr(ratePair)] = np.zeros((Ntime, Nspace))
            # Was `np.zeroS`, which raised AttributeError
            g['n'] = np.zeros((Ntime, Nlevel, Nspace))
        g['nStar'] = np.zeros((Ntime, Nlevel, Nspace))

    out['ne'] = np.zeros((Ntime, Nspace))
    out['Iwave'] = np.zeros((Ntime, *step['Iwave'].shape))
    out.attrs['lines'] = [(l.atom.element.name, l.lambda0) for l in step['lines']]
    out['losses'] = np.zeros((Ntime, len(step['lines']), Nspace))

    for t in tqdm(range(Ntime)):
        step = step_loader(t)
        for ele in step['eqPops'].keys():
            g = eqPops[ele]
            g['nStar'][t, ...] = step['eqPops'][ele]['nStar']
            if step['eqPops'][ele]['n'] is not None:
                g['n'][t, ...] = step['eqPops'][ele]['n']
                for ratePair in step['eqPops'][ele]['radiativeRates'].keys():
                    g['radiativeRates'][repr(ratePair)][t, :] = step['eqPops'][ele]['radiativeRates'][ratePair]
        out['ne'][t, :] = step['ne']
        out['Iwave'][t, :] = step['Iwave']
        for idx, loss in enumerate(step['losses']):
            out['losses'][t, idx, :] = loss

In [None]:
import zarr
from tqdm import tqdm as tqdm
def convert_to_zarr(outFile, step_loader, wavelength):
    """Write all pickled timesteps into a Zarr store at `outFile`.

    Every timestep is first gathered into in-memory numpy arrays, which are then
    written to the store in one pass per array.

    Parameters
    ----------
    outFile : str
        Path of the Zarr store, opened with mode='w'.
    step_loader : callable
        step_loader(t) returns the pickled step dict for timestep t
        (e.g. load_step or load_step_LE).
    wavelength : array_like
        Wavelength grid stored as out['wavelength'].
    """
    # NOTE(review): the final atmost time is skipped — presumably no step is saved for it.
    Ntime = atmost.time.shape[0]-1
    Nspace = atmost.z1.shape[1]
    # Step 0 defines the layout (elements, level counts, rate pairs, lines)
    step = step_loader(0)

    # In-memory mirror of the store layout
    memStore = {}
    memStore['eqPops'] = {}
    memPops = memStore['eqPops']
    for ele in step['eqPops'].keys():
        memPops[ele] = {}
        g = memPops[ele]
        Nlevel = step['eqPops'][ele]['nStar'].shape[0]
        # Populations 'n' and radiative rates are kept only when 'n' is not None
        if step['eqPops'][ele]['n'] is not None:
            g['radiativeRates'] = {}
            rates = g['radiativeRates']
            for ratePair in step['eqPops'][ele]['radiativeRates'].keys():
                rates[repr(ratePair)] = np.zeros((Ntime, Nspace))
            g['n'] = np.zeros((Ntime, Nlevel, Nspace))
        g['nStar'] = np.zeros((Ntime, Nlevel, Nspace))
    memStore['ne'] = np.zeros((Ntime, Nspace))
    memStore['Iwave'] = np.zeros((Ntime, *step['Iwave'].shape))
    memStore['losses'] = np.zeros((Ntime, len(step['lines']), Nspace))

    for t in tqdm(range(Ntime)):
        step = step_loader(t)
        for ele in step['eqPops'].keys():
            g = memPops[ele]
            g['nStar'][t, ...] = step['eqPops'][ele]['nStar']
            if step['eqPops'][ele]['n'] is not None:
                g['n'][t, ...] = step['eqPops'][ele]['n']
                for ratePair in step['eqPops'][ele]['radiativeRates'].keys():
                    g['radiativeRates'][repr(ratePair)][t, :] = step['eqPops'][ele]['radiativeRates'][ratePair]
        memStore['ne'][t, :] = step['ne']
        memStore['Iwave'][t, :] = step['Iwave']
        for idx, loss in enumerate(step['losses']):
            memStore['losses'][t, idx, :] = loss

    # Write the buffered arrays; the element/line layout is taken from the last step read
    out = zarr.convenience.open(outFile, mode='w')
    out['time'] = atmost.time[:-1]
    out['wavelength'] = wavelength
    zarrPops = out.require_group('eqPops')
    for ele in step['eqPops'].keys():
        g = zarrPops.require_group(ele)
        if step['eqPops'][ele]['n'] is not None:
            rates = g.require_group('radiativeRates')
            for ratePair in step['eqPops'][ele]['radiativeRates'].keys():
                rates[repr(ratePair)] = memPops[ele]['radiativeRates'][repr(ratePair)]
            g['n'] = memPops[ele]['n']
        g['nStar'] = memPops[ele]['nStar']

    out['ne'] = memStore['ne']
    out['Iwave'] = memStore['Iwave']
    out.attrs['lines'] = [(l.atom.element.name, l.lambda0) for l in step['lines']]
    out['losses'] = memStore['losses']

In [None]:
# Convert the LI (full treatment) timesteps to Zarr; the output path is machine-specific
convert_to_zarr('/local1/scratch/cmo/Flat1e10NoIncRadLI.zarr', load_step, wavelength)