# PCA of multiple molecules

## Isoelectronics: ethene and methylene immonium

In [None]:
import matplotlib.pyplot as plt
import xarray as xr

import shnitsel as sh
import shnitsel.xarray

In [None]:
# Load the two isoelectronic datasets from `.nc` files: A01 from the filtered
# C2H4 datasheet and I01 from the CH2NH2 dynamics database.
# NOTE(review): absolute paths -- adjust to the local data location.
A01 = sh.open_frames('/nc/reports/2025-05-21_datasheets/filtered_C2H4.nc')
I01 = sh.open_frames('/nc/SHNITSEL_databases/dynamic/I01_ch2nh2_dynamic.nc')

In [None]:
# Preview only (result is displayed, not stored): I01 with a 'cmpnd' entry of
# 'I01' attached to its 'frame' index.
# NOTE(review): presumably adds a constant multi-index level -- confirm.
I01.sh.expand_midx('frame', 'cmpnd', 'I01')

In [None]:
# Tag each dataset with its compound label on the 'frame' index, then join
# both along 'frame' so they can be analysed together.
isoelec_parts = [
    frames.sh.expand_midx('frame', 'cmpnd', tag)
    for tag, frames in [('A01', A01), ('I01', I01)]
]
isoelec = xr.concat(isoelec_parts, dim='frame')
isoelec

In [None]:
# Feature pipeline on the coordinates: differences over atom combinations,
# their norm along 'direction', then PCA over the 'atomcomb' dimension.
pca = (
    isoelec.atXYZ
    .sh.subtract_combinations('atom')
    .sh.norm('direction')
    .sh.pca('atomcomb')
)
pca

In [None]:
# Colour value 0 for A01 frames and 1 for all other frames.
plt.scatter(
    pca[:, 0],
    pca[:, 1],
    s=0.1,
    c=xr.where(pca.cmpnd == 'A01', 0, 1),
)

In [None]:
# One line per trajectory in the plane of the first two components:
# blue for I01, red for everything else.
for lcmpnd, cmpnd in pca.groupby('cmpnd'):
    if lcmpnd == 'I01':
        c = 'b'
    else:
        c = 'r'
    for ltraj, traj in cmpnd.groupby('trajid'):
        plt.plot(traj[:, 0], traj[:, 1], c=c, alpha=0.5)

## Now parts of a homologous series

In [None]:
# Load the next members of the series: A02 from the filtered C3H6 datasheet
# and A03 from the filtered C4H8 (g0) datasheet.
# NOTE(review): absolute paths -- adjust to the local data location.
A02 = sh.open_frames('/nc/reports/2025-05-21_datasheets/filtered_C3H6.nc')
A03 = sh.open_frames('/nc/reports/2025-05-21_datasheets/filtered_C4H8_g0.nc')

In [None]:
def vis(frames, charge=0):
    """Build a molecule from the first frame and label atoms with their index.

    Args:
        frames: Object exposing ``atXYZ`` with a ``frame`` dimension; only
            frame 0 is used.
        charge: Forwarded to ``mol_from_atXYZ`` as ``charge`` (default 0).

    Returns:
        The molecule returned by ``mol_from_atXYZ`` (called with
        ``to2D=True``), with every atom's ``atomNote`` property set to the
        string form of its index.
    """
    first_geometry = frames.atXYZ.isel(frame=0)
    molecule = sh.dynamic.filter_unphysical.mol_from_atXYZ(
        first_geometry, charge=charge, to2D=True
    )
    for member in molecule.GetAtoms():
        index_label = str(member.GetIdx())
        member.SetProp("atomNote", index_label)
    return molecule

In [None]:
# Show A01 (C2H4 dataset) with atom indices; charge 0.
vis(A01, 0)

In [None]:
# Show I01 (CH2NH2 dataset) with atom indices; charge +1.
vis(I01, +1)

In [None]:
# Show A02 (C3H6 dataset) with atom indices to decide how to renumber it.
vis(A02)

In [None]:
# Keep only the first six atoms of A02 and renumber them so that atoms 4 and 5
# swap places (4 <-> 5); `sortby` then reorders the data to the new numbering.
A02new = A02.isel(atom=slice(0,6)).assign_coords(atom=[0, 1, 2, 3, 5, 4]).sortby('atom')
# The truncated fragment is visualised with charge -3.
# NOTE(review): presumably compensates for the atoms dropped above -- confirm.
vis(A02new, -3)

In [None]:
# Show A03 (C4H8 g0 dataset) with atom indices to decide how to renumber it.
vis(A03)

In [None]:
# Select the atoms at positions 0, 1, 2, 3, 7, 8 of A03 and renumber them:
# 0->3, 1->1, 2->0, 3->2, 7->4, 8->5; `sortby` then reorders the data to the
# new numbering.
A03new = A03.isel(atom=[0,1,2,3,7,8]).assign_coords(atom=[3, 1, 0, 2, 4, 5]).sortby('atom')
# The truncated fragment is visualised with charge -8.
# NOTE(review): presumably compensates for the atoms dropped above -- confirm.
vis(A03new, -8)

In [None]:
# Tag every dataset with its compound label on the 'frame' index and join all
# four along 'frame'. The truncated/renumbered variants of A02 and A03 are
# used here.
homolog_parts = [
    frames.sh.expand_midx('frame', 'cmpnd', tag)
    for tag, frames in [
        ('A01', A01),
        ('I01', I01),
        ('A02', A02new),
        ('A03', A03new),
    ]
]
homologs = xr.concat(
    homolog_parts, dim='frame', coords='minimal', compat='override'
)
homologs

In [None]:
# Same feature pipeline as for the isoelectronic pair, now on all four
# compounds; `pwdist` is kept because the clustering cells reuse it.
pwdist = homologs.atXYZ.sh.subtract_combinations('atom').sh.norm('direction')
pca = pwdist.sh.pca('atomcomb')

# One colour per compound, shared by the scatter points and the trajectory
# lines. (The previous scatter colouring only distinguished A01 from
# everything else, a leftover from the two-compound plot.)
cmpnd_colors = {
    'A01': 'r',
    'A02': 'g',
    'A03': 'b',
    'I01': 'y',
}
plt.scatter(
    pca[:, 0],
    pca[:, 1],
    s=0.1,
    c=[cmpnd_colors[name] for name in pca.cmpnd.values],
)
for lcmpnd, cmpnd in pca.groupby('cmpnd'):
    # The colour depends only on the compound, so look it up once per compound.
    c = cmpnd_colors[lcmpnd]
    for ltraj, traj in cmpnd.groupby('trajid'):
        plt.plot(traj[:, 0], traj[:, 1], c=c, alpha=0.5)

In [None]:
fig, axs = plt.subplot_mosaic([['I01', 'A01'],['A02', 'A03']], layout='constrained')

# Split into individual trajectories once. Grouping by 'cmpnd' first is
# required: grouping by 'trajid' alone would merge trajectories of different
# compounds that happen to share a trajid into a single line.
trajs_by_cmpnd = {
    name: [traj for _, traj in cmpnd.groupby('trajid')]
    for name, cmpnd in pca.groupby('cmpnd')
}

# grey background: every trajectory of every compound on every panel
for ax in axs.values():
    for trajs in trajs_by_cmpnd.values():
        for traj in trajs:
            ax.plot(traj[:, 0], traj[:, 1], c='gray', alpha=0.5)

# foreground: each panel highlights its own compound in red (drawn after the
# background so it ends up on top)
for name, trajs in trajs_by_cmpnd.items():
    ax = axs[name]
    for traj in trajs:
        ax.plot(traj[:, 0], traj[:, 1], c='red', alpha=0.5)
    ax.set_title(name)


### Filter by PCA-adventurousness

In [None]:
# Flatten the 'cmpnd' and 'trajid' levels of the 'frame' index.
# Judging by the loops further down, grouping `fpca` by 'trajid' afterwards
# yields (cmpnd, trajid) pairs as keys, i.e. trajectory ids unique across
# compounds.
fpca = sh.xrhelpers.flatten_levels(pca, 'frame', ['cmpnd', 'trajid'])
fpca

In [None]:
# "Adventurousness" of a trajectory: the largest norm over 'PC' that any of
# its frames reaches. Keep only trajectories that stay below 0.2.
adventurousness = fpca.sh.norm('PC').groupby('trajid').max()
stays_close = adventurousness < 0.2
upca = fpca.sh.sel_trajids(adventurousness[stays_close].trajid)
upca

In [None]:
panel_layout = [['I01', 'A01'], ['A02', 'A03']]
fig, axs = plt.subplot_mosaic(panel_layout, layout='constrained')

# Background: all trajectories (unfiltered) in grey on every panel.
for name, ax in axs.items():
    ax.set_title(name)
    for (_cmpnd, _trajid), traj in fpca.groupby('trajid'):
        ax.plot(traj[:, 0], traj[:, 1], c='gray', alpha=0.5)

# Foreground: the filtered trajectories in red, each on its compound's panel.
for (cmpnd, trajid), traj in upca.groupby('trajid'):
    axs[cmpnd].plot(traj[:, 0], traj[:, 1], c='r', alpha=0.5)

In [None]:
panel_layout = [['I01', 'A01'], ['A02', 'A03']]
fig, axs = plt.subplot_mosaic(panel_layout, layout='constrained')

# Background: only the filtered trajectories in grey on every panel.
for name, ax in axs.items():
    ax.set_title(name)
    for (_cmpnd, _trajid), traj in upca.groupby('trajid'):
        ax.plot(traj[:, 0], traj[:, 1], c='gray', alpha=0.5)

# Foreground: the filtered trajectories in red, each on its compound's panel.
for (cmpnd, trajid), traj in upca.groupby('trajid'):
    axs[cmpnd].plot(traj[:, 0], traj[:, 1], c='r', alpha=0.5)

## Changes over time

In [None]:
# Preview only: result of `time_grouped_ci()` on the inter-state energies of
# all compounds pooled together.
homologs.sh.get_inter_state().energy.sh.time_grouped_ci()

In [None]:
# Plot the inter-state energies over time: one panel per state combination,
# one curve (plus shaded band) per compound.
ise = homologs.sh.get_inter_state().energy
# `[::-1, None]` reverses the state combinations and makes them a single
# column, so the mosaic stacks one panel per state combination vertically.
fig, axs = plt.subplot_mosaic(
    ise.statecomb.values[::-1, None], layout='constrained', sharey=True
)
for lcmpnd, cmpnd in ise.groupby('cmpnd'):
    ci = cmpnd.sh.time_grouped_ci()
    # Only the first curve of each compound carries the compound label.
    label = lcmpnd
    for lsc, sc in ci.groupby('statecomb'):
        sc = sc.squeeze('statecomb')
        ax = axs[lsc]
        ax.set_title(lsc)
        # 'time', 'mean', 'upper' and 'lower' are looked up by name in `sc`.
        ax.plot('time', 'mean', data=sc, lw=0.5, label=label)
        label = '' # To avoid duplicate labels in the legend
        ax.fill_between('time', 'upper', 'lower', data=sc, alpha=0.3)

# x label on the last panel only; y label on every panel.
list(axs.values())[-1].set_xlabel("$t$ / fs")
for ax in axs.values():
    ax.set_ylabel(r"$\Delta E$ / eV")
fig.legend()

In [None]:
# Plot state populations over time: colour encodes the compound, line style
# encodes the state.
# `label_ls` is True only while the first compound is processed; its curves
# additionally carry the state number so the legend explains the line styles.
label_ls = True
for lcmpnd, cmpnd in homologs.groupby('cmpnd'):
    pops = cmpnd.sh.calc_pops()
    c = {
            'A01': 'r', 
            'A02': 'g', 
            'A03': 'b', 
            'I01': 'purple', 
    }[lcmpnd]
    label = lcmpnd
    for lstate, state in pops.groupby('state'):
        # Only states 1, 2 and 3 have a line style; any other state raises KeyError.
        ls = {1: '-', 2: '--', 3: ':'}[lstate]
        if label_ls:
            label += f' state {lstate}'
        plt.plot(state['time'], state, c=c, ls=ls, label=label, lw=0.5)
        # Reset so that later curves of this compound do not repeat its name.
        label=''
    label_ls = False

plt.legend()
plt.xlabel('$t$ / fs')
plt.ylabel('Population')

# Dihedrals

In [None]:
# Store a 'smiles_map' attribute on each dataset's coordinates, computed from
# its first frame. The charges match those passed to `vis` above
# (A01: default, I01: +1, A02new: -3, A03new: -8).
A01.atXYZ.attrs['smiles_map'] = A01.atXYZ.isel(frame=0).sh.smiles_map()
I01.atXYZ.attrs['smiles_map'] = I01.atXYZ.isel(frame=0).sh.smiles_map(charge=+1)
A02new.atXYZ.attrs['smiles_map'] = A02new.atXYZ.isel(frame=0).sh.smiles_map(charge=-3)
A03new.atXYZ.attrs['smiles_map'] = A03new.atXYZ.isel(frame=0).sh.smiles_map(charge=-8)

In [None]:
# Apply `filter_cleavage` to every dataset ("nc" prefix). The CC criterion is
# switched off for I01 only.
# NOTE(review): presumably because CH2NH2 has no C-C bond -- confirm.
ncA01 = sh.dynamic.filter_unphysical.filter_cleavage(A01, CC=True, CH=True)
ncI01 = sh.dynamic.filter_unphysical.filter_cleavage(I01, CC=False, CH=True)
ncA02 = sh.dynamic.filter_unphysical.filter_cleavage(A02new, CC=True, CH=True)
ncA03 = sh.dynamic.filter_unphysical.filter_cleavage(A03new, CC=True, CH=True)

In [None]:

# Same tagging and concatenation as for `homologs`, but on the
# cleavage-filtered datasets.
nc_parts = [
    frames.sh.expand_midx('frame', 'cmpnd', tag)
    for tag, frames in [
        ('A01', ncA01),
        ('I01', ncI01),
        ('A02', ncA02),
        ('A03', ncA03),
    ]
]
nchomologs = xr.concat(
    nc_parts, dim='frame', coords='minimal', compat='override'
)
nchomologs

In [None]:
# Dihedral over atoms 2, 0, 1, 3 for every frame of the filtered data.
# NOTE(review): `full=True` presumably requests the signed/full-range angle
# (the plot below tests its sign) -- confirm.
fdihs = nchomologs.atXYZ.sh.dihedral(2, 0, 1, 3, full=True)

In [None]:
# Polar plot of the dihedral (angle) against time (radius), one thin black
# line per trajectory.
fig, ax = plt.subplots(1, 1, subplot_kw={'projection': 'polar'}, dpi=400)
for lcmpnd, cmpnd in fdihs.groupby('cmpnd'):
    for ltraj, traj in cmpnd.groupby('trajid'):
        # Mirror trajectories whose dihedral is negative at time 7, so that
        # all trajectories start rotating in the same direction.
        angle = -traj if traj.sel(time=7) < 0 else traj
        ax.plot(angle, traj.time, c='k', lw=0.1)
    # Only the first compound returned by the groupby is plotted.
    break

# Identify clusters
Over all homologs/isoelectronics

## 1. On the frame level

### 1.1 $k$-means on frames before PCA

In [None]:
from sklearn.cluster import KMeans

In [None]:
# Preview the features (computed for all four compounds above) that are
# clustered in the next cell.
pwdist

In [None]:
# Cluster the frames in the original feature space (9 clusters), then show
# the labels in the PCA plane.
pwdist = pwdist.transpose('frame', 'atomcomb')
kmc = KMeans(n_clusters=9)
res_kmc = kmc.fit(pwdist).predict(pwdist)
pca['cluster'] = ('frame', res_kmc)

plt.scatter(pca[:, 0], pca[:, 1], c=res_kmc, s=0.1)
# Annotate the last frame of every trajectory with that frame's cluster.
for lcmpnd, cmpnd in pca.groupby('cmpnd'):
    for ltraj, traj in cmpnd.groupby('trajid'):
        point = traj.isel(frame=-1)
        x_end = point.item(0)
        y_end = point.item(1)
        plt.text(x_end, y_end, point.cluster.item())

### 1.2 $k$-means on frames _after_ PCA

In [None]:
# Cluster the frames in PCA space (9 clusters) and show the labels there.
kmc = KMeans(n_clusters=9)
kmc.fit(pca)
res_kmc = kmc.predict(pca)
pca['cluster'] = ('frame', res_kmc)
plt.scatter(pca[:, 0], pca[:, 1], c=res_kmc, s=0.1)
# Annotate the last frame of every trajectory with that frame's cluster.
for lcmpnd, cmpnd in pca.groupby('cmpnd'):
    for ltraj, traj in cmpnd.groupby('trajid'):
        point = traj.isel(frame=-1)
        x_end = point.item(0)
        y_end = point.item(1)
        plt.text(x_end, y_end, point.cluster.item())

### 1.3 Other clustering methods on frames
(These take far too long to run on all frames, so they are not viable here.)

In [None]:
# from sklearn.cluster import DBSCAN
# res_sc = DBSCAN().fit_predict(pwdist)
# pca['cluster'] = 'frame', res_sc
# plt.scatter(pca[:, 0], pca[:, 1], c=res_sc, s=0.1)
# for lcmpnd, cmpnd in pca.groupby('cmpnd'):
#     for ltraj, traj in cmpnd.groupby('trajid'):
#         point = traj.isel(frame=-1)
#         plt.text(point.item(0), point.item(1), point.cluster.item())

In [None]:
# from sklearn.cluster import AgglomerativeClustering
# res_sc = AgglomerativeClustering().fit_predict(pwdist)
# pca['cluster'] = 'frame', res_sc
# plt.scatter(pca[:, 0], pca[:, 1], c=res_sc, s=0.1)
# for lcmpnd, cmpnd in pca.groupby('cmpnd'):
#     for ltraj, traj in cmpnd.groupby('trajid'):
#         point = traj.isel(frame=-1)
#         plt.text(point.item(0), point.item(1), point.cluster.item())

## 2. On the trajectory level

### 2.1 Cluster trajectories by final frame

### 2.2 Cluster trajectories by final frame _after_ PCA

In [None]:
kmc = KMeans(n_clusters=7)


def _last_frame(traj):
    """Return the final frame of one trajectory."""
    return traj.isel(frame=-1)


def _last_frame_per_traj(cmpnd):
    """Return the final frame of every trajectory of one compound."""
    return cmpnd.groupby('trajid').map(_last_frame)


# One row per (cmpnd, trajid) pair holding that trajectory's final frame in
# PCA space; pairs that are entirely NaN are dropped.
final_by_cmpnd = pca.groupby('cmpnd').map(_last_frame_per_traj)
pca_final_frames = (
    final_by_cmpnd
    .stack(traj=['cmpnd', 'trajid'])
    .dropna('traj', how='all')
    .transpose('traj', 'PC')
)
pca_final_frames

In [None]:
# Better approach?

# (
#     pca
#     .reset_index('frame')
#     .assign_coords(
#         trajid=pca
#             .indexes['frame']
#             .map(lambda x: '_'.join([x[0], str(x[1])]))
#     )
#     .drop_vars('cmpnd')
# )

In [None]:
# Cluster the trajectories by their final frame, then wrap the labels in a
# DataArray that carries the 'traj' index of `pca_final_frames`.
traj_labels = kmc.fit(pca_final_frames).predict(pca_final_frames)
res_kmc = pca_final_frames.isel(PC=0).copy(data=traj_labels)
res_kmc

In [None]:
# `res_kmc` holds one cluster label per trajectory (dim 'traj'), whereas `pca`
# has one row per frame. Look up each trajectory's label once, then give it to
# all of that trajectory's frames before using it as coordinate and colour.
traj_cluster = {
    (name, trajid): res_kmc.sel(cmpnd=name, trajid=trajid).item()
    for name, cmpnd in pca.groupby('cmpnd')
    for trajid, _ in cmpnd.groupby('trajid')
}
frame_cluster = [
    traj_cluster[key] for key in zip(pca.cmpnd.values, pca.trajid.values)
]
pca['cluster'] = 'frame', frame_cluster

plt.scatter(pca[:, 0], pca[:, 1], c=frame_cluster, s=0.1)
# Annotate the last frame of every trajectory with the trajectory's cluster.
for lcmpnd, cmpnd in pca.groupby('cmpnd'):
    for ltraj, traj in cmpnd.groupby('trajid'):
        point = traj.isel(frame=-1)
        plt.text(point.item(0), point.item(1), point.cluster.item())

In [None]:
cmap = plt.get_cmap('viridis').resampled(12).colors
fig, axs = plt.subplot_mosaic([['I01', 'A01'],['A02', 'A03']], layout='constrained')

# Split into individual trajectories once. Grouping by 'cmpnd' first is
# required: grouping by 'trajid' alone would merge trajectories of different
# compounds that happen to share a trajid into a single line.
trajs_by_cmpnd = {
    name: list(cmpnd.groupby('trajid'))
    for name, cmpnd in pca.groupby('cmpnd')
}

# grey background: every trajectory of every compound on every panel
for ax in axs.values():
    for trajs in trajs_by_cmpnd.values():
        for _, traj in trajs:
            ax.plot(traj[:, 0], traj[:, 1], c='gray', alpha=0.5)

# foreground: each panel shows its own compound, coloured by the trajectory's
# cluster; every cluster gets a single legend entry per panel
for name, trajs in trajs_by_cmpnd.items():
    seen_clusters = set()
    ax = axs[name]
    for trajid, traj in trajs:
        cluster = res_kmc.sel(cmpnd=name, trajid=trajid).item()
        c = cmap[cluster]
        ax.plot(
            traj[:, 0], traj[:, 1], c=c, alpha=0.5,
            label='' if cluster in seen_clusters else cluster,
        )
        seen_clusters.add(cluster)
        point = traj.isel(frame=-1)
        # Label with the trajectory-level cluster looked up above, not with
        # whatever frame-level 'cluster' coordinate `pca` currently carries.
        ax.text(point.item(0), point.item(1), str(cluster))
    ax.set_title(name)
    ax.legend()

In [None]:
# Inspect the first 'trajid' group only: printing its 'cmpnd' values shows
# whether grouping by 'trajid' alone mixes frames of several compounds.
# (The scatter call that followed the `break` was unreachable and has been
# removed.)
for trajid, traj in pca.groupby('trajid'):
    print(trajid, traj.cmpnd)
    break

### (Previous work)

In [None]:
# Reshape to one row per trajectory: unstack 'frame', combine
# (cmpnd, trajid) into 'traj' and (time, atomcomb) into 'special', then drop
# 'traj' entries that are entirely NaN.
upwdist = (
    pwdist
    .unstack('frame')
    .stack(traj=['cmpnd', 'trajid'])
    .stack(special=['time', 'atomcomb'])
    .dropna('traj', how='all')
)
upwdist

In [None]:
# Drop every 'special' entry (time/atomcomb pair) that is NaN for at least
# one trajectory, leaving a matrix without NaNs along 'special'.
dupwdist = upwdist.dropna('special', how='any')#.dropna('traj', how='any'))
dupwdist

In [None]:
# Preview the first 'special' entry; the next cell uses the same slice as a
# template for the per-trajectory cluster labels.
dupwdist.isel(special=0)

In [None]:
# Cluster whole trajectories (10 clusters) on the stacked features, then wrap
# the labels in a DataArray indexed by 'traj'.
kmc = KMeans(n_clusters=10)
cluster_ids = kmc.fit(dupwdist).predict(dupwdist)
tmp = dupwdist.isel(special=0).drop_vars(['special', 'time', 'atomcomb'])
res_kmc = tmp.copy(data=cluster_ids)
res_kmc

In [None]:
# Colours of the 'tab10' colormap, indexed by cluster label in the next cell.
cmap = plt.get_cmap('tab10').colors#.resampled(12).colors

In [None]:
fig, axs = plt.subplot_mosaic([['I01', 'A01'],['A02', 'A03']], layout='constrained')

# Split into individual trajectories once. Grouping by 'cmpnd' first is
# required: grouping by 'trajid' alone would merge trajectories of different
# compounds that happen to share a trajid into a single line.
trajs_by_cmpnd = {
    name: list(cmpnd.groupby('trajid'))
    for name, cmpnd in pca.groupby('cmpnd')
}

# grey background: every trajectory of every compound on every panel
for ax in axs.values():
    for trajs in trajs_by_cmpnd.values():
        for _, traj in trajs:
            ax.plot(traj[:, 0], traj[:, 1], c='gray', alpha=0.5)

# foreground: each panel shows its own compound, coloured by the trajectory's
# cluster; every cluster gets a single legend entry per panel
for name, trajs in trajs_by_cmpnd.items():
    seen_clusters = set()
    ax = axs[name]
    for trajid, traj in trajs:
        cluster = res_kmc.sel(cmpnd=name, trajid=trajid).item()
        c = cmap[cluster]
        ax.plot(
            traj[:, 0], traj[:, 1], c=c, alpha=0.5,
            label='' if cluster in seen_clusters else cluster,
        )
        seen_clusters.add(cluster)
    ax.set_title(name)
    ax.legend()