In [None]:
import xarray as xr

import shnitsel as sh
import shnitsel.xarray

## Load ensembles

In [None]:
# Parse the raw SHARC trajectories of every alkene (in parallel), keyed by
# compound name; the compound is printed first as a progress indicator.
cpnds = {}
for name in ['C2H4', 'C3H6', 'C4H8']:
    print("\t", name)
    traj_dir = f'/traj/SHNITSEL_alkenes/traj_{name}/'
    cpnds[name] = sh.parse.read_trajs(traj_dir, kind='sharc', parallel=True)

### Checkpoint: loading complete

In [None]:
# Optional: cache the freshly parsed ensembles to disk (uncomment to run).
# NOTE(review): this writes to /tmp/, whereas the next cell reads from
# /nc/2025-05-20/ and expects 'raw_C4H8_g0.nc' rather than 'raw_C4H8.nc';
# presumably the files were moved/renamed by hand - confirm.
# for cpnd in cpnds:
#     cpnds[cpnd].sh.save_frames(f'/tmp/raw_{cpnd}.nc')

In [None]:
# Reload the cached raw ensembles and prepare their frames. The C4H8 data is
# stored under the stem 'C4H8_g0' but is keyed as 'C4H8' from here on.
cpnds = {}
for cpnd in ['C2H4', 'C3H6', 'C4H8_g0']:
    frames = sh.open_frames(f'/nc/2025-05-20/raw_{cpnd}.nc')
    cpnds[cpnd] = frames.sh.setup_frames()
cpnds['C4H8'] = cpnds.pop('C4H8_g0')

In [None]:
# Re-apply frame setup to every ensemble.
# NOTE(review): the reload cell above already calls .sh.setup_frames() on
# each ensemble, so this applies it a second time; presumably harmless only
# if setup_frames is idempotent - confirm, or drop one of the two calls.
cpnds = {name: frames.sh.setup_frames() for name, frames in cpnds.items()}

## Filter alkenes by energy

In [None]:
# Energy-based filter quantities ("filtranda") of the C2H4 ensemble; the mask
# cells below read 'etot_drift', 'etot_step', 'epot_step', 'ekin_step' and
# 'is_hop' from it.
ethene = cpnds['C2H4']
feat = ethene.sh.energy_filtranda()
feat

In [None]:
import numpy as np

def last_time_where(mask):
    """Per trajectory, return the last time up to which `mask` holds without interruption.

    Parameters
    ----------
    mask : xr.DataArray
        Per-frame boolean mask with a stacked 'frame' dimension whose levels
        include 'trajid' and 'time'. Any further dimensions (e.g. 'cutoff')
        are carried through.

    Returns
    -------
    xr.DataArray named 'time'
        Dimensions 'trajid' plus any extra dimensions of `mask`. Each entry
        is the time of the last frame in the leading all-True run of that
        trajectory, or -1 if the very first frame already fails.

    Notes
    -----
    Frames that do not exist for a trajectory are filled with False when
    unstacking, so they count as failing. An always-True mask therefore
    yields the last time of each trajectory, assuming its frames are
    contiguous on the common time grid (TODO confirm).
    """
    mask = mask.unstack('frame', fill_value=False).transpose('trajid', 'time', ...)
    # Number of leading True frames: the running AND stays True only until the
    # first failing frame. This is also correct for 0/1 integer masks, where
    # bitwise `~` would not act as a logical negation.
    n_leading = np.logical_and.accumulate(mask.values, axis=1).sum(axis=1)
    # Prepend the sentinel so that 0 leading frames -> -1, and n leading
    # frames -> time of frame n-1. np.concatenate (unlike the np.concat alias)
    # is also available on NumPy < 2.0.
    times = np.concatenate([[-1], mask.time.values])
    return mask.isel(time=0).copy(data=times[n_leading]).drop_vars('time').rename('time')

# Demo: per trajectory, the last time before `feat['e_kin'] < 1` first fails
# (-1 if it already fails on the first frame).
# NOTE(review): 'e_kin' is not among the variables the mask cells below use
# ('etot_drift', 'etot_step', 'epot_step', 'ekin_step', 'is_hop'); confirm
# that energy_filtranda() provides it.
last_time_where(mask=feat['e_kin'] < 1)

In [None]:
# Validity criteria for the ensemble in `feat` (True = frame passes), stacked
# along a new 'cutoff' dimension. 'original' is always True; the epot/ekin
# step criteria are waived at hops, while 'hop_epot' is checked only at hops.
is_hop = feat['is_hop']
criteria = {
    'original'    : True,
    'etot_window' : abs(feat['etot_drift']) < 0.2,
    'etot_step'   : abs(feat['etot_step']) < 0.1,
    'epot_step'   : (abs(feat['epot_step']) < 0.7) | is_hop,
    'ekin_step'   : (abs(feat['ekin_step']) < 0.7) | is_hop,
    'hop_epot'    : (abs(feat['epot_step']) < 1.0) | ~is_hop,
}
masks = xr.Dataset(criteria).to_dataarray('cutoff')
masks

In [None]:
# Per trajectory and criterion: the last time up to which the criterion held.
cutoffs = last_time_where(mask=masks)
cutoffs

In [None]:
# Attach, as coordinates along 'trajid', the earliest cutoff over all
# criteria and the index of the criterion responsible for it (argmin picks
# the first criterion on ties).
earliest = cutoffs.min('cutoff')
cutoffs = cutoffs.assign_coords(earliest=earliest)
cutoffs = cutoffs.assign_coords(reason=cutoffs.argmin('cutoff'))
cutoffs

In [None]:
def energy_masks(feat, etot_window=0.2, etot_step=0.1, epot_step=0.7,
                 ekin_step=0.7, hop_epot=1.0):
    """Stack the per-frame energy validity criteria along a new 'cutoff' dimension.

    Parameters
    ----------
    feat : presumably xr.Dataset (output of ``.sh.energy_filtranda()``)
        Must contain 'etot_drift', 'etot_step', 'epot_step', 'ekin_step'
        and the boolean 'is_hop'.
    etot_window, etot_step, epot_step, ekin_step, hop_epot : float
        Upper bounds on the absolute values of the respective quantities
        (same units as `feat`).

    Returns
    -------
    xr.DataArray
        True where a frame passes a criterion; the 'cutoff' labels are the
        criterion names. 'original' is always True; 'epot_step'/'ekin_step'
        are waived at hops, while 'hop_epot' is only checked at hops.
    """
    is_hop = feat['is_hop']
    return xr.Dataset(
        {
            'original'    : True,
            'etot_window' : abs(feat['etot_drift']) < etot_window,
            'etot_step'   : abs(feat['etot_step']) < etot_step,
            'epot_step'   : (abs(feat['epot_step']) < epot_step) | is_hop,
            'ekin_step'   : (abs(feat['ekin_step']) < ekin_step) | is_hop,
            'hop_epot'    : (abs(feat['epot_step']) < hop_epot) | ~is_hop,
        }
    ).to_dataarray('cutoff')


# Per-compound cutoff times (trajid x cutoff) using the default thresholds.
cutoffs = {}
for cpnd in cpnds:
    feat = cpnds[cpnd].sh.energy_filtranda()
    masks = energy_masks(feat)
    cutoffs[cpnd] = last_time_where(masks)


In [None]:
# Quick look at the geometry of the first C2H4 frame via the `to_mol` accessor.
first_geom = cpnds['C2H4'].atXYZ.isel(frame=0)
first_geom.sh.to_mol()

In [None]:
import matplotlib as mpl

# Single-colour palette for the figures below (an earlier three-colour
# variant was '#7E5273', '#C4A000', '#2c3e50'). matplotlib maps integer
# inputs >= N to the 'over' colour, which by default is the last entry, so
# cmap(i) yields this one colour for every i.
palette = ['#2c3e50']
cmap = mpl.colors.ListedColormap(palette)
cmap

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Overview of the energy-based truncation, one column per compound.
# Top row: how many trajectories have their earliest cutoff due to each
# criterion (argmin: ties go to the first criterion listed).
# Bottom row: per-trajectory cutoff times of every criterion, trajectories
# sorted by earliest cutoff; shaded areas are rolling medians over
# neighbouring trajectories, the white line is each trajectory's earliest cutoff.
ROLLING_WINDOW = 10  # trajectories per rolling-median window

fig, axs = plt.subplots(2, 3, sharey=False, constrained_layout=True, height_ratios=[1, 3])
for (cn, c), axcol in zip(cutoffs.items(), axs.T):
    nreasons = c.sizes['cutoff']
    typefreqs = np.bincount(c.argmin('cutoff'), minlength=nreasons)
    reason_pos = range(nreasons)
    axcol[0].bar(reason_pos, typefreqs, color='#2c3e50')
    axcol[0].set_xticks(reason_pos)
    axcol[0].set_xticklabels(labels=c['cutoff'].data, rotation=45, ha='right')
    axcol[0].set_title(cn)

    # Primary sort key: earliest cutoff; ties broken by the individual
    # criteria, least frequent reason first.
    sortorder = [c.isel(cutoff=k) for k in typefreqs.argsort()]
    c_sorted = c.sortby([c.min('cutoff'), *sortorder])
    rows = np.arange(c_sorted.sizes['trajid'])
    for i, (_, ctv) in enumerate(c_sorted.groupby('cutoff')):
        ctv = ctv.squeeze()
        # Centred window with partial windows at the edges: a trailing window
        # would shift the band by half a window relative to the white line
        # and leave the first ROLLING_WINDOW - 1 rows unshaded.
        smoothed = ctv.rolling(trajid=ROLLING_WINDOW, center=True, min_periods=1).median()
        axcol[1].fill_betweenx(rows, smoothed, alpha=0.2, color=cmap(i), lw=1, step='mid')
    axcol[1].plot(c_sorted.min('cutoff'), rows, c='white', lw=1, zorder=10)

# Shared labels, set once after all columns are drawn.
axs[0, 0].set_ylabel("# trajs truncated\nfor given reason")
axs[1, 0].set_ylabel("trajs (sorted by earliest cutoff)")
for ax in axs[1, :]:
    ax.set_xlabel("$t$ / fs")

In [None]:
import matplotlib as mpl


def _label_bbox():
    """Return fresh bbox properties (light-grey box, no edge) for panel labels."""
    return dict(facecolor='0.9', edgecolor='none', pad=3.0)


def outlabel(ax, letter):
    """Draw a bold panel label just outside the top-left corner of `ax`.

    The text is anchored at axes coordinates (0, 1) and shifted 20 pt to the
    left and 7 pt up (1 pt = 1/72 inch), bottom-aligned. Returns the Text.
    """
    dx_pt, dy_pt = -20, +7
    shift = mpl.transforms.ScaledTranslation(
        dx_pt / 72, dy_pt / 72, ax.figure.dpi_scale_trans
    )
    return ax.text(
        0.0, 1.0, letter,
        transform=ax.transAxes + shift,
        va='bottom',
        fontweight='bold',
        bbox=_label_bbox(),
    )


def inlabel(ax, letter):
    """Draw a bold panel label near the top-right corner of `ax`.

    The text is anchored at axes fraction (1, 1), offset one font size to
    the left and half a font size down, top-aligned. Returns the Annotation.
    """
    return ax.annotate(
        letter,
        xy=(1, 1), xycoords='axes fraction',
        xytext=(-1, -0.5), textcoords='offset fontsize',
        va='top',
        fontweight='bold',
        bbox=_label_bbox(),
    )

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Publication figure (saved as PDF): truncation statistics for C2H4 (A01)
# and C4H8 (A03) only.
# Top row: number of trajectories whose earliest cutoff is due to each
# criterion. Bottom row: per-trajectory cutoff times of every criterion,
# trajectories sorted by earliest cutoff; the white line marks earliest
# cutoffs that lie before the ensemble's last simulated time.
FIG_PATH = '/nc/reports/2025-06-30_4-step_figure/energy_filtration_stats.pdf'

# compound -> (column in `axs`, panel title); explicit, so the layout does not
# depend on the iteration order of `cutoffs`.
panels = {
    'C2H4': (0, r'C$_2$H$_4$ ($\mathbf{A01}$)'),
    'C4H8': (1, r'C$_4$H$_8$ ($\mathbf{A03}$)'),
}

fig, axs = plt.subplots(2, 2, sharey=False, constrained_layout=True, height_ratios=[1, 3])
fig.set_size_inches(4.4, 4.8)
for cn, (col, title) in panels.items():
    c = cutoffs[cn]
    axcol = axs[:, col]

    nreasons = c.sizes['cutoff']
    typefreqs = np.bincount(c.argmin('cutoff'), minlength=nreasons)
    reason_pos = range(nreasons)
    axcol[0].bar(reason_pos, typefreqs, color=cmap(0))
    axcol[0].set_xticks(reason_pos)
    axcol[0].set_xticklabels(labels=c['cutoff'].data, rotation=45, ha='right')
    axcol[0].set_title(title)

    # Sort by earliest cutoff; break ties by the individual criteria,
    # least frequent reason first.
    sortorder = [c.isel(cutoff=k) for k in typefreqs.argsort()]
    c_sorted = c.sortby([c.min('cutoff'), *sortorder])
    rows = np.arange(c_sorted.sizes['trajid'])
    # Colour by this loop's own criterion index (previously `cmap(i)` relied
    # on an `i` left over from an earlier cell).
    for k, (_, ctv) in enumerate(c_sorted.groupby('cutoff')):
        ctv = ctv.squeeze()
        axcol[1].fill_betweenx(rows, ctv, alpha=0.2, color=cmap(k), lw=1, step='mid')
    earliest = c_sorted.min('cutoff')
    # Cutoffs at (or beyond) the ensemble's last time become NaN and are not drawn.
    axcol[1].plot(earliest.where(earliest < cpnds[cn].time.max()), rows, c='white', lw=1, zorder=10)

# Shared labels, set once after all columns are drawn.
axs[0, 0].set_ylabel("# trajs truncated\nfor given reason")
axs[1, 0].set_ylabel("trajs (sorted by earliest cutoff)")
for ax in axs[1, :]:
    ax.set_xlabel("$t$ / fs")

# Panel letters a-d in row-major order.
for ax, letter in zip(axs.flat, 'abcd'):
    inlabel(ax, letter)

fig.savefig(FIG_PATH)

In [None]:
def _energy_criteria(feat):
    """Per-frame validity criteria (True = frame passes) built from energy filtranda."""
    is_hop = feat['is_hop']
    return {
        'etot_window' : abs(feat['etot_drift']) < 0.2,
        'etot_step'   : abs(feat['etot_step']) < 0.1,
        'epot_step'   : (abs(feat['epot_step']) < 0.7) | is_hop,
        'ekin_step'   : (abs(feat['ekin_step']) < 0.7) | is_hop,
        'hop_epot'    : (abs(feat['epot_step']) < 1.0) | ~is_hop,
    }


# Cutoffs per compound, this time computed by the shnitsel `get_cutoffs` accessor.
# NOTE(review): unlike the earlier mask cells there is no 'original' criterion
# here, yet the plotting cell below reads c['original']; presumably
# get_cutoffs adds it - confirm.
cutoffs = {
    cpnd: xr.Dataset(_energy_criteria(cpnds[cpnd].sh.energy_filtranda())).sh.get_cutoffs()
    for cpnd in cpnds
}

In [None]:
# Inspect the C2H4 cutoffs as a pandas table.
c2h4_cutoffs = cutoffs['C2H4']
c2h4_cutoffs.to_pandas()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Overview based on the `get_cutoffs` results, one column per compound.
# Top row: how many trajectories are cut for each reason.
# Bottom row: one horizontal bar per trajectory and criterion whose length
# is the cutoff time; trajectories sorted by earliest cutoff.
fig, axs = plt.subplots(2, 3, sharey=False, constrained_layout=True, height_ratios=[1, 3])
for (cn, c), axcol in zip(cutoffs.items(), axs.T):
    reasons = c.attrs['reasons']
    nreasons = len(reasons)
    typefreqs = np.bincount(c['reason'], minlength=nreasons)
    reason_pos = range(nreasons)
    axcol[0].bar(reason_pos, typefreqs)
    axcol[0].set_xticks(reason_pos)
    axcol[0].set_xticklabels(labels=reasons, rotation=45, ha='right')
    axcol[0].set_title(cn)

    # NOTE(review): to_dataarray() stacks *all* data variables, which may
    # include 'earliest' and 'reason' besides the criteria - confirm that
    # drawing those as bars is intended.
    da = c.to_dataarray('cutoff').sortby([c['earliest'], c['original']])
    rows = np.arange(da.sizes['trajid_'])
    for _, ctv in da.groupby('cutoff'):
        ctv = ctv.squeeze()
        axcol[1].barh(rows, ctv, height=1.0, alpha=0.5)

# barh places trajectories along y and cutoff times (bar widths) along x,
# so the bottom-row labels are: y = trajectories, x = time.
axs[0, 0].set_ylabel("# trajs truncated\nfor given reason")
axs[1, 0].set_ylabel("trajs (sorted by earliest cutoff)")
for ax in axs[1, :]:
    ax.set_xlabel("$t$ / fs")

In [None]:
# Truncate each ensemble at the per-trajectory earliest cutoff and preview C2H4.
efilt = {name: frames.sh.truncate(cutoffs[name]['earliest']) for name, frames in cpnds.items()}
efilt['C2H4']

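## Eliminate overly short trajectories

Discard trajectories whose truncated length is below 25% of the longest simulated time in the unfiltered ensemble.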

In [None]:
# A truncated trajectory is kept only if it still reaches this fraction of
# the longest time in the *unfiltered* ensemble; shorter ("overshort") ones
# are discarded.
MIN_TIME_FRACTION = 0.25

# Truncate every ensemble at the earliest cutoff, then drop overshort
# trajectories; prints the number dropped per compound. The result dict is
# built in one pass instead of overwriting entries while iterating over it.
efilt = {}
for cpnd, frames in cpnds.items():
    truncated = frames.sh.truncate(cutoffs[cpnd]['earliest'])
    threshold = frames.coords['time'].max() * MIN_TIME_FRACTION
    # Last remaining time per trajectory (assumes frames are time-ordered
    # within each trajectory - TODO confirm).
    last_times = truncated.time.groupby('trajid').last()
    osids = last_times[last_times < threshold].trajid
    print(cpnd, ':', len(osids))
    efilt[cpnd] = truncated.sh.sel_trajs(osids, invert=True)

## Summary: numbers of trajectories and frames

In [None]:
# For each compound, report how many trajectories ('trajid_') and frames
# survive the filtration, before -> after, with the retained percentage.
for cpnd, unfiltered in cpnds.items():
    print(cpnd)
    b, a = unfiltered.sizes, efilt[cpnd].sizes
    for x in ('trajid_', 'frame'):
        print(f"{x:<7} {b[x]:6} -> {a[x]: 6}  retaining {100*a[x]/b[x]:05.2f}%")
    print()

### Checkpoint: filtration complete

In [None]:
# Checkpoint: write each filtered ensemble to /tmp/filtered_<compound>.nc.
for cpnd in list(cpnds):
    ensemble = efilt[cpnd]
    ensemble.sh.save_frames(f'/tmp/filtered_{cpnd}.nc')