# PCA of multiple molecules

In [None]:
import matplotlib.pyplot as plt
import xarray as xr

import shnitsel as sh
import shnitsel.xarray

In [None]:
# Frame datasets of the four compounds compared in this notebook
# (the molecular formula is encoded in each file name).
A01 = sh.open_frames(
    '/nc/reports/2025-05-21_datasheets/filtered_C2H4.nc'
)
I01 = sh.open_frames(
    '/nc/SHNITSEL_databases/dynamic/I01_ch2nh2_dynamic.nc'
)
A02 = sh.open_frames(
    '/nc/reports/2025-05-21_datasheets/filtered_C3H6.nc'
)
A03 = sh.open_frames(
    '/nc/reports/2025-05-21_datasheets/filtered_C4H8_g0.nc'
)

In [None]:
def vis(frames, charge=0):
    """Build a 2D molecule from the first frame, labelling atoms by index.

    Parameters
    ----------
    frames
        Dataset with an ``atXYZ`` variable that has a ``frame`` dimension;
        only ``frame=0`` is used.
    charge
        Total charge forwarded to ``mol_from_atXYZ`` (default 0).

    Returns
    -------
    The molecule returned by ``mol_from_atXYZ`` (presumably an RDKit ``Mol``),
    with every atom's ``atomNote`` property set to its index so that a
    depiction shows the atom numbering.
    """
    build = sh.dynamic.filter_unphysical.mol_from_atXYZ
    first_geometry = frames.atXYZ.isel(frame=0)
    molecule = build(first_geometry, charge=charge, to2D=True)
    # Atom numbers are needed to decide how to renumber atoms across compounds.
    for atom_obj in molecule.GetAtoms():
        atom_obj.SetProp("atomNote", str(atom_obj.GetIdx()))
    return molecule

In [None]:
# A01 drawn as a neutral molecule
vis(A01, charge=0)

In [None]:
# I01 is drawn with a total charge of +1
vis(I01, charge=+1)

In [None]:
# A02 as loaded, before renumbering
vis(A02, charge=0)

In [None]:
# Keep only the first six atoms of A02 and swap atoms 4 and 5
# (presumably so its atom order matches the other compounds).
A02new = (
    A02
    .isel(atom=slice(0, 6))                  # atoms 0..5 only
    .assign_coords(atom=[0, 1, 2, 3, 5, 4])  # positions 4 and 5 get labels 5 and 4
    .sortby('atom')                          # reorder by label -> atoms 4 and 5 trade places
)
# NOTE(review): charge for drawing the truncated fragment was set by hand; confirm.
vis(A02new, charge=-3)

In [None]:
# A03 as loaded, before renumbering
vis(A03, charge=0)

In [None]:
# Keep atoms 0, 1, 2, 3, 7, 8 of A03 and renumber them
# (original index -> new index): 0->3, 1->1, 2->0, 3->2, 7->4, 8->5.
A03new = (
    A03
    .isel(atom=[0, 1, 2, 3, 7, 8])
    .assign_coords(atom=[3, 1, 0, 2, 4, 5])
    .sortby('atom')
)
# NOTE(review): charge for drawing the truncated fragment was set by hand; confirm.
vis(A03new, charge=-8)

In [None]:
# Stack all compounds along `frame`. Each dataset is first given a `cmpnd`
# label via expand_midx, so later cells can group frames by compound.
# compat='override' skips equality checks on non-concatenated variables and
# takes them from the first dataset.
homologs = xr.concat(
    [
        frames.sh.expand_midx('frame', 'cmpnd', label)
        for label, frames in (
            ('A01', A01),
            ('I01', I01),
            ('A02', A02new),
            ('A03', A03new),
        )
    ],
    dim='frame',
    coords='minimal',
    compat='override',
)
homologs

In [None]:
# Pairwise interatomic distances per frame: differences over combinations of
# atoms (presumably yielding the `atomcomb` dimension), normed over `direction`.
pwdist = (
    homologs['atXYZ']
    .sh.subtract_combinations('atom')
    .sh.norm('direction')
)
# PCA using the atom-pair distances as features.
pca = pwdist.sh.pca('atomcomb')

In [None]:
# One panel per compound: every trajectory of every compound is drawn in grey
# in PC1/PC2, and each panel's own compound is overdrawn in red.
fig, axs = plt.subplot_mosaic([['I01', 'A01'], ['A02', 'A03']], layout='constrained')

# trajid is not guaranteed unique across compounds (other cells group by cmpnd
# first, or flatten cmpnd+trajid, before grouping by trajid). Grouping the whole
# dataset by trajid alone would join same-numbered trajectories of different
# compounds into one polyline and draw spurious connecting lines, so split into
# trajectories per compound once and reuse that split.
trajs_by_cmpnd = [
    (name, [traj for _, traj in cmpnd.groupby('trajid')])
    for name, cmpnd in pca.groupby('cmpnd')
]

# grey background: all trajectories on every panel
for ax in axs.values():
    for _, trajs in trajs_by_cmpnd:
        for traj in trajs:
            ax.plot(traj[:, 0], traj[:, 1], c='gray', alpha=0.5)

# red foreground: each panel highlights its own compound
for name, trajs in trajs_by_cmpnd:
    ax = axs[name]
    for traj in trajs:
        ax.plot(traj[:, 0], traj[:, 1], c='red', alpha=0.5)
    ax.set_title(name)


# Identify clusters
Clustering is performed jointly over all homologs and isoelectronic species.

## 1. On the frame level

### 1.1 $k$-means on frames (on the pairwise distances, before PCA)

In [None]:
from sklearn.cluster import KMeans

# Cluster frames in the space of pairwise distances (the pre-PCA features).
# A fixed random_state makes the cluster labels reproducible between runs.
# Without it, k-means initialisation is random, and label numbers hard-coded
# elsewhere (e.g. selecting one cluster by its number) would refer to a
# different cluster on every run. Re-check such numbers after changing this.
kmc = KMeans(n_clusters=12, random_state=0)
pwdist = pwdist.transpose('frame', 'atomcomb')  # sklearn expects (n_samples, n_features)
res_kmc = kmc.fit_predict(pwdist)  # same as fit() followed by predict()
pca['cluster'] = 'frame', res_kmc  # attach the labels as a coordinate along `frame`

# Overview in the first two PCA columns, coloured by cluster; the last frame
# of every trajectory is annotated with its cluster label.
plt.scatter(pca[:, 0], pca[:, 1], c=res_kmc, s=0.1)
for _, cmpnd in pca.groupby('cmpnd'):
    for _, traj in cmpnd.groupby('trajid'):
        point = traj.isel(frame=-1)
        plt.text(point.item(0), point.item(1), point.cluster.item())

In [None]:
# One panel per compound: all trajectories as a grey background, the
# compound's own frames scattered and coloured by k-means cluster, and each
# trajectory's last frame annotated with its cluster label.
fig, axs = plt.subplot_mosaic([['I01', 'A01'], ['A02', 'A03']], layout='constrained')
w, h = fig.get_size_inches()
fig.set_size_inches(w * 2, h * 2)  # twice the default figure size

# trajid is not guaranteed unique across compounds, so split into
# trajectories per compound. Grouping everything by trajid alone would join
# same-numbered trajectories of different compounds into one polyline.
per_cmpnd = [
    (name, cmpnd, [traj for _, traj in cmpnd.groupby('trajid')])
    for name, cmpnd in pca.groupby('cmpnd')
]

# grey background; zorder 0.9 puts these lines below the scatter points
# (matplotlib defaults: 1 for collections, 2 for lines, 3 for text)
for ax in axs.values():
    for _, _, trajs in per_cmpnd:
        for traj in trajs:
            ax.plot(traj[:, 0], traj[:, 1], c='gray', alpha=0.5, zorder=0.9)

for name, cmpnd, trajs in per_cmpnd:
    ax = axs[name]
    ax.scatter(cmpnd[:, 0], cmpnd[:, 1], c=cmpnd['cluster'], s=0.1)
    ax.set_title(name)
    for traj in trajs:
        point = traj.isel(frame=-1)
        ax.text(point.item(0), point.item(1), point.cluster.item())

In [None]:
from shnitsel.dynamic.plot import p3mhelpers

# Cluster whose member geometries are shown in 3D. k-means label numbers are
# arbitrary, so re-check this value whenever the clustering is re-run.
cluster_id = 11

p3mhelpers.frames3Dgrid(
    homologs.atXYZ.sel(frame=pca.cluster == cluster_id)
    # presumably merges the cmpnd and trajid levels so that trajid alone
    # identifies a trajectory across compounds
    .sh.flatten_levels('frame', ['cmpnd', 'trajid'])
    # one geometry per trajectory: the last of its frames within this cluster
    .groupby('trajid').map(lambda x: x.isel(frame=-1))
    .drop_vars('frame')
    # the per-trajectory geometries become the "frames" of the grid
    .rename({'trajid': 'frame'})
)
None  # keep the cell from echoing frames3Dgrid's return value