In [None]:
import xarray as xr

import shnitsel as sh
import shnitsel.xarray

## Load ensembles

In [None]:
# Parse the raw SHARC output of every alkene ensemble, one entry per compound,
# keyed by molecular formula.
cpnds = {}
for cpnd in ('C2H4', 'C3H6', 'C4H8'):
    print("\t", cpnd)  # progress indicator
    cpnds[cpnd] = sh.parse.read_trajs(
        f'/traj/SHNITSEL_alkenes/traj_{cpnd}/',
        kind='sharc',
        parallel=True,
    )

### Checkpoint: loading complete

In [None]:
# Checkpoint each freshly parsed ensemble to its own `.nc` file under /tmp.
for cpnd, frames in cpnds.items():
    sh.xrhelpers.save_frames(frames, f'/tmp/raw_{cpnd}.nc')

In [None]:
# Reload the raw ensembles from the archived checkpoint files.
# The butene file on disk carries a `_g0` suffix. Map it explicitly to the plain
# key 'C4H8' rather than loading under a temporary key and renaming afterwards.
# NOTE(review): the checkpoint cell above writes /tmp/raw_<cpnd>.nc, whereas this
# cell reads /nc/2025-05-20/raw_<stem>.nc -- confirm the files were moved/renamed.
RAW_FILE_STEMS = {'C2H4': 'C2H4', 'C3H6': 'C3H6', 'C4H8': 'C4H8_g0'}
cpnds = {
    cpnd: sh.xrhelpers.open_frames(f'/nc/2025-05-20/raw_{stem}.nc').sh.setup_frames()
    for cpnd, stem in RAW_FILE_STEMS.items()
}

## Filter alkenes by energy

In [None]:
# Energy quantities used for filtering (etot_drift, etot_step, epot_step,
# ekin_step, is_hop, ... -- see the masks below), computed for C2H4 only to
# prototype the criteria.
feat = cpnds['C2H4'].sh.energy_filtranda()
feat

In [None]:
# Per-frame pass/fail checks for C2H4 (True = the frame passes).
# The epot/ekin step checks are not applied to hop frames; 'hop_epot' applies a
# separate, looser bound (1.0 instead of 0.7) to hop frames only.
masks = xr.Dataset(
    dict(
        etot_window=abs(feat['etot_drift']) < 0.2,
        etot_step=abs(feat['etot_step']) < 0.1,
        epot_step=(abs(feat['epot_step']) < 0.7) | feat['is_hop'],
        ekin_step=(abs(feat['ekin_step']) < 0.7) | feat['is_hop'],
        hop_epot=(abs(feat['epot_step']) < 1.0) | ~feat['is_hop'],
    )
)
masks

In [None]:
# Sanity check: are the frames already ordered by (trajid, time)?
# Compare raw values. `==` between two DataArrays first aligns them on their
# indexes, which would undo the re-ordering done by `sortby` and make the check
# succeed unconditionally whenever `frame` carries an index.
(masks['etot_window'].sortby(['trajid', 'time']).time.values == masks['etot_window'].time.values).all()

In [None]:
import numpy as np
def last_time_where(mask):
    """Prototype: per trajectory, the last time before ``mask`` first becomes False.

    Parameters
    ----------
    mask : xarray.DataArray of bool
        Per-frame criterion with a 'trajid' coordinate (used for grouping) and
        a 'time' coordinate. Frames are assumed to be in chronological order
        within each trajectory (see the sanity check in the previous cell).

    Returns
    -------
    xarray.DataArray
        Indexed by 'trajid': the time of the last frame preceding the first
        failing frame (the last frame overall if none fails), or -1 if the
        trajectory's first frame already fails.
    """
    # A frame lies before the first failure iff the running count of failures
    # in its trajectory is still zero.
    before_first_false = ~((~mask).groupby('trajid').cumsum().astype(bool))
    # Trajectories with no frame before the first failure, i.e. whose first frame fails.
    fails_at_start = (~before_first_false).groupby('trajid').all()
    fallback = xr.full_like(fails_at_start, -1, dtype=int)  # -1 indicates first ts fails test
    if not before_first_false.any():
        # Every trajectory fails at its first frame. The selection below would be
        # empty, and grouping an empty array raises, so answer directly.
        return fallback
    upto_first_false = mask.coords['time'][before_first_false].groupby('trajid').last()
    if fails_at_start.any():
        # Those trajectories are absent from `upto_first_false`; fill them with -1.
        return upto_first_false.combine_first(fallback)
    return upto_first_false


last_time_where(masks['etot_window'])

In [None]:
# Library accessor version; presumably the packaged form of the prototype
# above -- its output should match.
masks['etot_window'].sh.last_time_where()

In [None]:
# Dataset accessor: cutoff times for all masks at once (C2H4 only).
# NOTE(review): `cutoffs` is rebound to a dict keyed by compound in the next cell.
cutoffs = masks.sh.last_time_where()
cutoffs

In [None]:
def energy_masks(feat, etot_window=0.2, etot_step=0.1, epot_step=0.7, ekin_step=0.7, hop_epot=1.0):
    """Build per-frame pass/fail masks from ``energy_filtranda()`` features.

    Parameters
    ----------
    feat : mapping of DataArrays (presumably an xarray.Dataset)
        Output of ``frames.sh.energy_filtranda()``; must provide 'etot_drift',
        'etot_step', 'epot_step', 'ekin_step' and 'is_hop'.
    etot_window, etot_step, epot_step, ekin_step, hop_epot : float
        Strict upper bounds on the absolute value of the corresponding quantity
        ('etot_window' bounds 'etot_drift'; 'hop_epot' bounds 'epot_step').
        The defaults are the thresholds used throughout this notebook.

    Returns
    -------
    xarray.Dataset
        One boolean variable per criterion; True means the frame passes.
        The epot/ekin step checks are waived on hop frames, while 'hop_epot'
        only constrains hop frames.
    """
    is_hop = feat['is_hop']
    return xr.Dataset(
        {
            'etot_window': abs(feat['etot_drift']) < etot_window,
            'etot_step': abs(feat['etot_step']) < etot_step,
            'epot_step': (abs(feat['epot_step']) < epot_step) | is_hop,
            'ekin_step': (abs(feat['ekin_step']) < ekin_step) | is_hop,
            'hop_epot': (abs(feat['epot_step']) < hop_epot) | ~is_hop,
        }
    )


# Cutoff times per compound. The comprehension keeps the global `feat` (the C2H4
# features displayed earlier) from being overwritten by the last compound's.
cutoffs = {
    cpnd: energy_masks(cpnds[cpnd].sh.energy_filtranda()).sh.last_time_where()
    for cpnd in cpnds
}

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# One column per compound.
# Top row: how many entries carry each 'reason' value (labelled with attrs['reasons']).
# Bottom row: every cutoff variable as a horizontal bar per trajectory, with
# trajectories sorted by their 'earliest', then 'original' value.
# squeeze=False keeps `axs` 2-D even when there is only one compound.
fig, axs = plt.subplots(
    2, len(cutoffs), sharey=False, constrained_layout=True, squeeze=False
)
for (cn, c), axcol in zip(cutoffs.items(), axs.T):
    reasons = c.attrs['reasons']
    nreasons = len(reasons)
    typefreqs = np.bincount(c['reason'], minlength=nreasons)
    xticks = range(nreasons)
    axcol[0].bar(xticks, typefreqs)
    axcol[0].set_xticks(xticks)
    axcol[0].set_xticklabels(labels=reasons, rotation=45, ha='right')
    axcol[0].set_title(cn)
    axcol[0].set_ylabel('count')

    # 'reason' is a categorical index into attrs['reasons'], not a time, so it is
    # excluded from the bar chart of cutoff times.
    da = (
        c.drop_vars('reason', errors='ignore')
        .to_dataarray('cutoff')
        .sortby([c['earliest'], c['original']])
    )
    ypos = np.arange(da.sizes['trajid_'])
    for ctn, ctv in da.groupby('cutoff'):
        axcol[1].barh(ypos, ctv.squeeze(), height=1.0, alpha=0.5, label=str(ctn))
    axcol[1].set_xlabel('cutoff time')
    axcol[1].set_ylabel('trajectory (sorted)')
    axcol[1].legend(fontsize='small')

In [None]:
# Raw (unfiltered) C2H4 ensemble, for reference.
cpnds['C2H4']

In [None]:
# Earliest cutoff time per trajectory; this is the value the truncation below uses.
cutoffs['C2H4']['earliest']

In [None]:
# Truncate every trajectory at its earliest cutoff, using the library accessor.
efilt = {
    name: frames.sh.truncate(cutoffs[name]['earliest'])
    for name, frames in cpnds.items()
}
efilt['C2H4']

In [None]:
def truncate(frames, cutoffs):
    """Local re-implementation for comparison with ``.sh.truncate``.

    Keeps only the frames whose time is at or before their trajectory's cutoff.

    Parameters
    ----------
    frames : xarray.Dataset
        Frames along the 'frame' dimension, with 'trajid' and 'time' coordinates.
    cutoffs : xarray.DataArray
        One cutoff time per trajectory, indexed along 'trajid_'.

    Returns
    -------
    xarray.Dataset
        ``frames`` restricted to frames with ``time <= cutoff`` of their trajectory.
    """
    # Look up each frame's own cutoff via its trajid, then discard the lookup coordinate.
    per_frame_cutoff = cutoffs.sel(trajid_=frames.coords['trajid'])
    per_frame_cutoff = per_frame_cutoff.drop_vars('trajid_')
    keep = frames['time'] <= per_frame_cutoff
    return frames.sel(frame=keep)


truncate(cpnds['C2H4'], cutoffs['C2H4']['earliest'])

In [None]:
# Only a cell's last expression is rendered automatically, so display both explicitly.
display(efilt['C3H6'], efilt['C2H4'])

In [None]:
# Last time of each unfiltered C2H4 trajectory, to compare with the truncated data below.
cpnds['C2H4'].time.groupby('trajid').last()

In [None]:
# Latest remaining time per trajectory after truncation; should not exceed
# that trajectory's 'earliest' cutoff.
efilt['C2H4'].time.groupby('trajid').max()


## Eliminate overly short trajectories

In [None]:
# Re-derive the truncated ensembles (so this cell is idempotent), then drop
# trajectories whose truncated last time falls below a fraction of the largest
# time present in that compound's unfiltered ensemble.
MIN_LENGTH_FRACTION = 0.25

efilt = {}
for cpnd, frames in cpnds.items():
    truncated = frames.sh.truncate(cutoffs[cpnd]['earliest'])
    threshold = frames.coords['time'].max() * MIN_LENGTH_FRACTION
    last_times = truncated.time.groupby('trajid').last()
    osids = last_times[last_times < threshold].trajid
    print(cpnd, ':', len(osids))
    efilt[cpnd] = truncated.sh.sel_trajs(osids, invert=True)

## Summarize: numbers of trajectories and frames

In [None]:
# Compare ensemble sizes before (`cpnds`) and after (`efilt`) energy filtering
# and removal of overly short trajectories.
for cpnd in cpnds:
    print(cpnd)
    before = cpnds[cpnd].sizes
    after = efilt[cpnd].sizes
    for dim in ['trajid_', 'frame']:
        if before[dim] == 0:
            # Avoid ZeroDivisionError for an empty ensemble.
            print(f"{dim:<7} {before[dim]:6} -> {after[dim]:6}  (empty before filtering)")
            continue
        # Same width spec on both sides so the columns line up.
        print(f"{dim:<7} {before[dim]:6} -> {after[dim]:6}  retaining {100*after[dim]/before[dim]:05.2f}%")

    print()

### Checkpoint: filtration complete

In [None]:
# Checkpoint the filtered ensembles under the same compound keys as `cpnds`.
for cpnd in cpnds:
    target = f'/tmp/filtered_{cpnd}.nc'
    sh.xrhelpers.save_frames(efilt[cpnd], target)

In [None]:
import shnitsel.dynamic as sh
import shnitsel.dynamic.postprocess as P
import matplotlib.pyplot as plt
import numpy as np
# Reload the filtered checkpoints. The checkpoint cell saves butene as
# `filtered_C4H8.nc` (its key in `cpnds` is 'C4H8'), whereas the cells below
# refer to it as 'C4H8_g0'. Map keys to file stems explicitly so a fresh
# Restart & Run All opens the file that was actually written.
filtered_stems = {'C2H4': 'C2H4', 'C3H6': 'C3H6', 'C4H8_g0': 'C4H8'}
filtered = {
    cpnd: sh.xrhelpers.open_frames(f'/tmp/filtered_{stem}.nc')
    for cpnd, stem in filtered_stems.items()
}

### Butene only: remove cleavages for biplot
Cleaved geometries are removed because the biplot shows a dihedral angle rather than a bond length.

In [None]:
# borrowed from shnitsel-rough/pipelin.ipynb -- consider canonicalizing
def scope(cpnd):
    """Count what the "overlong distance" criteria would remove from / leave in an ensemble.

    Parameters
    ----------
    cpnd : xarray.Dataset
        Frames of one compound; must contain 'atXYZ' and a 'trajid' coordinate.

    Returns
    -------
    removed, remaining : dict
        Each maps a criterion name ('withCH', 'withCC', 'withCHorCC',
        'withCHandCC') to ``[number of frames, number of distinct trajectories]``
        in the selected subset and in its complement, respectively.
    """
    # We begin with S2 removed.
    xyz = cpnd['atXYZ']
    # Ids (presumably trajectory ids) flagged by find_overlong for each element
    # pair; the integer arguments are presumably atomic numbers (1 = H, 6 = C) -- confirm.
    ids = {
        'withCH': sh.filter_unphysical.find_overlong(xyz, 1, 6, cutoff=1.7),
        'withCC': sh.filter_unphysical.find_overlong(xyz, 6, 6, cutoff=2.8),
    }
    ids['withCHorCC'] = list(set(ids['withCH']).union(ids['withCC']))
    ids['withCHandCC'] = list(set(ids['withCH']).intersection(ids['withCC']))
    print(f"{ids['withCC']=}")
    print(f"{ids['withCHandCC']=}")

    def _size(frames):
        # [frame count, number of distinct trajectory ids]
        return [frames.sizes['frame'], len(np.unique(frames.trajid))]

    removed_trajs = {name: sh.xrhelpers.sel_trajids(cpnd, sel) for name, sel in ids.items()}
    removed = {name: _size(sub) for name, sub in removed_trajs.items()}
    remaining_trajs = {name: sh.xrhelpers.sel_trajids(cpnd, sel, invert=True) for name, sel in ids.items()}
    remaining = {name: _size(sub) for name, sub in remaining_trajs.items()}

    return removed, remaining


scope(filtered['C4H8_g0'])

In [None]:
# Build a molecule object from the first butene frame and store its
# atom-numbered SMILES in the 'smiles_map' attribute of atXYZ
# (presumably consumed by filter_cleavage below -- confirm).
mol = sh.filter_unphysical.mol_from_atXYZ(
    filtered['C4H8_g0'].atXYZ.isel(frame=0), to2D=False
)
filtered['C4H8_g0'].atXYZ.attrs['smiles_map'] = sh.filter_unphysical.mol_to_numbered_smiles(mol)

# Remove cleaved geometries, checking both C-C and C-H bonds.
uncleaving = sh.filter_unphysical.filter_cleavage(filtered['C4H8_g0'], CC=True, CH=True)
uncleaving

In [None]:
from pathlib import Path

folder = '/nc/reports/2025-05-21_datasheets/plots'
# savefig() does not create missing directories, so ensure the target exists.
Path(folder).mkdir(parents=True, exist_ok=True)
# at1..at4 presumably select the atoms defining the plotted dihedral -- confirm.
sh.plot.biplot_kde(frames=uncleaving, at1=0, at2=1, at3=2, at4=3, geo_filter=[[0,70],[90,180]], levels=8)
plt.gcf().savefig(f'{folder}/kb_C4H8_g0.png')

In [None]:
# NOTE(review): identical to the plotting call in the previous cell (minus the
# savefig); candidate for removal.
sh.plot.biplot_kde(frames=uncleaving, at1=0, at2=1, at3=2, at4=3, geo_filter=[[0,70],[90,180]], levels=8)

### Checkpoint

In [None]:
# Checkpoint the cleavage-filtered butene ensemble.
uncleaving.sh.save_frames('/tmp/uncleaving.nc')