In [None]:
import shnitsel as sh
import shnitsel.xarray
import matplotlib.pyplot as plt

In [None]:
# Open the frames dataset stored in `filtered_C4H8_g0.nc` and show its repr.
# NOTE(review): this is an absolute, machine-specific path — consider deriving it
# from a configurable data directory so the notebook runs elsewhere.
frames_path = '/nc/reports/2025-05-21_datasheets/filtered_C4H8_g0.nc'
frames = sh.open_frames(frames_path)
frames

In [None]:
# Select trajectories without cleavages
# Build a molecule object from the coordinates of the first frame only
# (`to2D=False` is forwarded to `to_mol`).
mol = frames.atXYZ.isel(frame=0).sh.to_mol(to2D=False)
# Attach the numbered SMILES of that molecule to the coordinate array's attrs.
# NOTE(review): presumably `filter_cleavage` below reads `smiles_map` to identify
# the bonds to check — confirm; if so, this line must stay before the filter call.
frames.atXYZ.attrs['smiles_map'] = sh.core.postprocess.mol_to_numbered_smiles(mol)
# Rebind `frames` to the filter result, with both the CC and CH flags enabled.
# NOTE(review): because `frames` is overwritten, re-running this cell operates on
# the already-filtered data rather than on the freshly loaded file.
frames = sh.core.filter_unphysical.filter_cleavage(frames, CC=True, CH=True)

In [None]:
# The four indices (presumably atom indices — confirm against the molecule's
# numbering) that define the dihedral of interest.
dihedral_atoms = (0, 1, 2, 3)
# `deg=True` matches the '°' units used in the axis labels further down.
dih = frames.atXYZ.sh.dihedral(*dihedral_atoms, full=False, deg=True)
dih

In [None]:
from sklearn.cluster import KMeans

### $k$-means clustering using final timestep

In [None]:
# One value per trajectory: the dihedral at that trajectory's last frame.
Y = dih.groupby('trajid').map(lambda traj: traj[{'frame': -1}])
# scikit-learn estimators expect a 2-D (n_samples, n_features) array,
# so the 1-D values are reshaped into a single-feature column.
final_dihedral = Y.data.reshape(-1, 1)
# A fixed seed keeps the cluster assignment (and therefore the colouring)
# identical across kernel restarts; k-means initialisation is otherwise random.
kmc = KMeans(n_clusters=2, random_state=0)
# fit_predict fits `kmc` and returns the label of every sample in one pass,
# replacing the separate fit(...) / predict(...) calls on the same data.
cluster_labels = kmc.fit_predict(final_dihedral)
# The data are one-dimensional; the constant y value only spreads the points on a line.
plt.scatter(Y, [1]*len(Y), c=cluster_labels)
plt.xlabel('Final dihedral / °');

### $k$-means clustering using final 2 timesteps

In [None]:
# Two values per trajectory: the dihedral at its last two frames. After
# unstacking 'frame', the time coordinate is relabelled to the relative
# positions -2 and -1 so that all trajectories share the same labels.
Y = dih.groupby('trajid').map(
    lambda traj: traj[{'frame': slice(-2, None)}]
    .unstack('frame')
    .assign_coords(time=[-2, -1])
)
# A fixed seed keeps the cluster assignment (and therefore the colouring)
# identical across kernel restarts; k-means initialisation is otherwise random.
kmc = KMeans(n_clusters=2, random_state=0)
# fit_predict fits `kmc` and returns the label of every sample in one pass,
# replacing the separate fit(Y) / predict(Y) calls.
cluster_labels = kmc.fit_predict(Y)
# Column 0 is plotted as the penultimate value and column 1 as the final one.
plt.scatter(Y[:, 0], Y[:, 1], c=cluster_labels)
plt.xlabel('Penultimate dihedral / °')
plt.ylabel('Final dihedral / °');

### $k$-means clustering using frames $-10$ and $-1$

In [None]:
# Two values per trajectory: the dihedral at frame -10 and at the last frame.
# After unstacking 'frame', the time coordinate is relabelled to -10 and -1
# so that all trajectories share the same labels.
Y = dih.groupby('trajid').map(
    lambda traj: traj[{'frame': [-10, -1]}]
    .unstack('frame')
    .assign_coords(time=[-10, -1])
)
# A fixed seed keeps the cluster assignment (and therefore the colouring)
# identical across kernel restarts; k-means initialisation is otherwise random.
kmc = KMeans(n_clusters=2, random_state=0)
# fit_predict fits `kmc` and returns the label of every sample in one pass,
# replacing the separate fit(Y) / predict(Y) calls.
cluster_labels = kmc.fit_predict(Y)
plt.scatter(Y[:, 0], Y[:, 1], c=cluster_labels)
# NOTE(review): the "5 fs" in the label depends on the simulation timestep, which
# is not visible in this notebook — confirm it matches frame -10.
plt.xlabel('Dihedral 5 fs before end / °')
plt.ylabel('Final dihedral / °');

### $k$-means over the final 10 frames + PCA

In [None]:
def tmp(traj):
    """Reduce one trajectory to its last ten entries along 'frame'.

    The selected entries are unstacked along 'frame' and the resulting
    'time' coordinate is overwritten with the ten labels
    ``range(-20, 0, 2)`` (-20, -18, ..., -2), so every trajectory ends up
    with the same time labels.

    NOTE(review): assumes 'frame' is a stacked index that unstacks into a
    'time' dimension, and that each trajectory has at least ten frames
    (otherwise the ten labels will not match) — confirm.
    """
    last_ten = {'frame': slice(-10, None)}
    return (
        traj[last_ten]
        .unstack('frame')
        .assign_coords(time=range(-20, 0, 2))
    )

# Apply `tmp` to each trajectory separately (grouped by 'trajid') and combine the results.
dih_by_traj = dih.groupby('trajid')
Ynew = dih_by_traj.map(tmp)
Ynew

In [None]:
# Cluster on the full per-trajectory feature matrix `Ynew`. A fixed seed keeps
# the cluster assignment (and therefore the colouring) identical across kernel
# restarts; k-means initialisation is otherwise random.
kmc = KMeans(n_clusters=2, random_state=0)
# fit_predict fits `kmc` and returns the label of every sample in one pass,
# replacing the separate fit(Ynew) / predict(Ynew) calls.
cluster_labels = kmc.fit_predict(Ynew)
# PCA along 'time' is used only to obtain two axes for plotting; the clustering
# above was done on `Ynew` itself, not on the principal components.
pca_new = Ynew.sh.pca('time')
plt.scatter(pca_new[:, 0], pca_new[:, 1], c=cluster_labels)
plt.xlabel('PC1')
plt.ylabel('PC2');