In [1]:
#!/usr/bin/env python3

from pathlib import Path

import numpy as np
import torch
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from torch_geometric.data import Batch

from ocpmodels.transfer_learning.common.utils import (
    ATOMS_TO_GRAPH_KWARGS,
    load_xyz_to_pyg_batch,
)
from ocpmodels.transfer_learning.loaders import BaseLoader
from ocpmodels.transfer_learning.models.distribution_regression import (
    EmbeddingKernel,
    GaussianKernel,
    LinearKernel,
    KernelGroupEmbeddingRidgeRegression,
    KernelMeanEmbeddingRidgeRegression,
    LinearGroupEmbeddingKernel,
    LinearMeanEmbeddingKernel,
    StandardizedOutputRegression,
    median_heuristic,
)

# from torch_geometric.data.batch import DataBatch

In [2]:
# Move to the repository root so the relative paths used below ("data/...",
# "transfer_learning/...") resolve; the notebook presumably lives two levels below it.
# A bare `%cd ../../` climbs two more levels every time the cell is re-run, so only move
# when we are not already at the root (where the `transfer_learning` folder is visible).
import os

if not Path("transfer_learning").is_dir():
    os.chdir(Path.cwd().parents[1])
print(Path.cwd())

/home/isak/life/references/projects/src/python_lang/ocp


In [4]:
from dscribe.descriptors import SOAP

In [5]:
### Configuration
# Extended-XYZ trajectory to featurize (relative to the working directory set by the `%cd` cell).
dataset_dir = Path("data/private/2_abinitio-metad_Fe45N2.extxyz")
# Kernel mixing weight: atom-pair kernels are weighted by
# (1 - alpha) * (all-pairs term) + alpha * (same-species term) in the kernel cell.
alpha = 1.0
# Normalization of the aggregated atom-pair kernels: "mean" or "sum".
aggregation = "mean"
# Frame window [start_idx:end_idx] applied to the N2-distance series loaded for the plot.
# NOTE(review): the window is NOT applied to raw_data / soap_features; presumably the
# .extxyz already contains only these frames — confirm, otherwise the plot panels misalign.
start_idx = 1600
end_idx = -10
#lmbda = 1e-9  # NOTE(review): not used anywhere in this notebook
# Output folder for figures; its name encodes the dataset and kernel settings.
# Created eagerly (including parents) so savefig at the end cannot fail on a missing folder.
plot_dir = Path(f"transfer_learning/notebooks/figures/{dataset_dir.stem}/alpha{alpha}_agg{aggregation}_soap")
plot_dir.mkdir(parents=True, exist_ok=True)

### Load data
# NOTE(review): presumably raw_data holds the per-frame ASE Atoms (it is iterated and fed
# to SOAP.create below) and data_batch is a PyG batch exposing `atomic_numbers` — confirm.
# num_frames / num_atoms are not used by later cells in this notebook.
raw_data, data_batch, num_frames, num_atoms = load_xyz_to_pyg_batch(dataset_dir, ATOMS_TO_GRAPH_KWARGS["schnet"])

# SOAP descriptor for Fe/N systems with periodic boundaries; sparse=False requests dense output.
soap = SOAP(species=["Fe", "N"], r_cut=6, n_max=12, l_max=6, periodic=True, sparse=False)

def create_soap_features(systems, soap_object):
    """Compute a descriptor for every system and stack the results into one tensor.

    Args:
        systems: iterable of structures accepted by ``soap_object.create`` (the frames of
            ``raw_data`` in this notebook).
        soap_object: object exposing ``create(atoms)`` that returns an array-like
            (a dscribe ``SOAP`` instance here). Every call must return the same shape.

    Returns:
        Tensor of shape ``(len(systems), *per_system_shape)`` with the dtype of the
        descriptor arrays. An empty ``systems`` yields ``torch.tensor([])``.

    Raises:
        ValueError: if the per-system arrays do not all have the same shape.
    """
    features = [soap_object.create(atoms) for atoms in systems]
    if not features:
        return torch.tensor(features)
    # Stack into a single ndarray first: torch.tensor() on a list of ndarrays converts
    # element by element and is extremely slow (PyTorch emits a warning about it).
    # np.stack returns a fresh array, so from_numpy does not alias the descriptor outputs.
    return torch.from_numpy(np.stack(features))
# Featurize every frame; detach().cpu() guarantees a plain CPU tensor for the kernel cells.
soap_features = create_soap_features(raw_data, soap).detach().cpu()

  cell = torch.Tensor(atoms.get_cell()).view(1, 3, 3)


In [6]:
# Dimensions of the stacked SOAP tensor: t frames, n rows per frame, d features per row.
t, n, d = soap_features.shape
# The second ("y") side of the kernel uses the same frame and row counts
# (the dataset is presumably compared with itself).
l, m = t, n
# Column vector of atomic numbers, one per SOAP row.
# NOTE(review): assumes data_batch.atomic_numbers is ordered frame by frame like the rows
# of soap_features — confirm against load_xyz_to_pyg_batch.
zx = data_batch.atomic_numbers.reshape(t * n, 1)

In [7]:
def distance_mat(x1, x2, sigma=1.0):
    """Pairwise squared Euclidean distances between the rows of ``x1`` and ``x2``.

    For 2-D inputs ``x1`` of shape (a, d) and ``x2`` of shape (b, d), a broadcast
    (a, b, d) difference tensor is formed and reduced over the feature axis.

    Args:
        x1: tensor whose rows are the left-hand vectors.
        x2: tensor whose rows are the right-hand vectors.
        sigma: accepted but ignored (the signature mirrors ``gaussian_kernel``).

    Returns:
        (a, b) tensor whose [i, j] entry is sum_k (x1[i, k] - x2[j, k]) ** 2.
    """
    pairwise_diff = x1[:, None] - x2[None, :]
    return torch.sum(pairwise_diff.pow(2), dim=2)

def compute_dist_matrix(X, chunk_size=10):
    """Pairwise squared-distance matrix of the rows of ``X``, built block by block.

    Row blocks of ``chunk_size`` are compared with ``distance_mat`` so the broadcast
    (chunk, chunk, d) intermediate stays small. The matrix is symmetric, so only blocks
    on or above the diagonal are computed and the blocks below are filled by transposing
    them ((x - y) ** 2 == (y - x) ** 2 exactly), roughly halving the work.

    Args:
        X: 2-D tensor of shape (N, d), one vector per row.
        chunk_size: positive number of rows per block.

    Returns:
        (N, N) tensor allocated with ``torch.zeros``, i.e. torch's default float dtype
        (float32 unless changed); float64 inputs are down-cast when stored.

    Raises:
        ValueError: if ``X`` is not 2-D or ``chunk_size`` is not positive.
    """
    if X.dim() != 2:
        raise ValueError(f"X must be a 2-D (N, d) tensor, got shape {tuple(X.shape)}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    t_n = X.shape[0]
    D = torch.zeros((t_n, t_n))

    for i in range(0, t_n, chunk_size):
        i_end = min(i + chunk_size, t_n)
        X_chunk1 = X[i:i_end]
        # Start at the diagonal block; blocks left of it were mirrored in earlier rows.
        for j in range(i, t_n, chunk_size):
            j_end = min(j + chunk_size, t_n)
            D_chunk = distance_mat(X_chunk1, X[j:j_end])
            D[i:i_end, j:j_end] = D_chunk
            if j != i:
                D[j:j_end, i:i_end] = D_chunk.T

    return D

In [None]:
# Squared Euclidean distances between all SOAP rows (every atom of every frame), a
# (t * n, t * n) matrix — expensive in time and memory for long trajectories.
# NOTE(review): `D` is not referenced by any later cell in this notebook; presumably it was
# meant for a kernel-bandwidth heuristic (`median_heuristic` is imported) — use it or drop this cell.
D = compute_dist_matrix(soap_features.reshape(t * n, -1))

In [None]:
def gaussian_kernel(x1, x2, sigma=1.0):
    """Gaussian (RBF) kernel between every row of ``x1`` and every row of ``x2``.

    Evaluates exp(-||x1_i - x2_j||^2 / (2 * sigma^2)) for all pairs by broadcasting; for
    2-D inputs (a, d) and (b, d) this materializes an (a, b, d) difference tensor.

    Args:
        x1: tensor whose rows are the left-hand vectors.
        x2: tensor whose rows are the right-hand vectors.
        sigma: kernel bandwidth.

    Returns:
        (a, b) tensor of kernel values in [0, 1].
    """
    diffs = x1[:, None] - x2[None, :]
    sq_dist = torch.sum(diffs.pow(2), dim=2)
    denom = 2 * (sigma**2)
    return torch.exp(-sq_dist / denom)

def compute_kernel_matrix(X, chunk_size=100, sigma=1.0):
    """Gaussian kernel matrix of the rows of ``X`` against themselves, built in blocks.

    Row blocks of ``chunk_size`` are passed to ``gaussian_kernel`` so the broadcast
    (chunk, chunk, d) intermediate stays small. The kernel is symmetric, so only blocks on
    or above the diagonal are computed and the blocks below are filled by transposition.

    Args:
        X: 2-D tensor of shape (N, d), one vector per row.
        chunk_size: positive number of rows per block.
        sigma: bandwidth forwarded to ``gaussian_kernel`` (previously always its default 1.0).

    Returns:
        (N, N) tensor allocated with ``torch.zeros``, i.e. torch's default float dtype
        (float32 unless changed); float64 kernel values are down-cast when stored.

    Raises:
        ValueError: if ``X`` is not 2-D or ``chunk_size`` is not positive.
    """
    if X.dim() != 2:
        raise ValueError(f"X must be a 2-D (N, d) tensor, got shape {tuple(X.shape)}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    t_n = X.shape[0]
    K = torch.zeros((t_n, t_n))

    for i in range(0, t_n, chunk_size):
        i_end = min(i + chunk_size, t_n)
        X_chunk1 = X[i:i_end]
        # Start at the diagonal block; blocks left of it were mirrored in earlier rows.
        for j in range(i, t_n, chunk_size):
            j_end = min(j + chunk_size, t_n)
            K_chunk = gaussian_kernel(X_chunk1, X[j:j_end], sigma=sigma)
            K[i:i_end, j:j_end] = K_chunk
            if j != i:
                K[j:j_end, i:i_end] = K_chunk.T

    return K

In [None]:
# Re-derive the dimensions here: a later cell rebinds the global `t` to a frame-index
# array, which would otherwise break this cell when it is re-run.
t, n = soap_features.shape[:2]
l, m = t, n
# NOTE(review): `zy` was never defined in this notebook. The "y" side has the same
# dimensions as the "x" side (l = t, m = n), i.e. the frames are compared with
# themselves, so the same atomic numbers are used.
zy = zx
# NOTE(review): `k0` (atom-level base kernel, shape (t * n, l * m)) was also undefined —
# presumably left over from a deleted cell. Assumed to be the Gaussian kernel of the SOAP
# rows from `compute_kernel_matrix` with its default bandwidth — TODO confirm the bandwidth.
k0 = compute_kernel_matrix(soap_features.reshape(t * n, -1))

# delta[i, j] is True when atoms i and j have the same atomic number: (t * n, l * m).
delta = zx == zy.T

agg_c = 1.0
if aggregation == "mean":
    # All-pairs term: average over the n * m atom pairs of each frame pair.
    agg_c /= n * m
    groups = torch.unique(torch.cat([zx, zy]))
    # Species membership per atom, grouped by frame: (frames, atoms, groups).
    select_zx = (zx == groups).reshape(t, n, -1)
    select_zy = (zy == groups).reshape(l, m, -1)
    # Number of atoms of each species in each frame: (frames, groups).
    group_nx = select_zx.sum(axis=1)
    group_ny = select_zy.sum(axis=1)
    # Per atom: how many atoms of its own species its frame contains.
    group_nx_flatten = (select_zx * group_nx[:, None, :]).reshape(t * n, -1).sum(axis=1)
    group_ny_flatten = (select_zy * group_ny[:, None, :]).reshape(l * m, -1).sum(axis=1)
    # Normalize by the number of kernels in the group \sum_s K_s / S
    mask = (1.0 / group_nx_flatten[:, None]) * (1.0 / group_ny_flatten[None, :])
elif aggregation == "sum":
    mask = 1.0
else:
    raise ValueError(f"Unknown aggregation {aggregation}")
group_c = mask * delta

# Frame-level kernel: weighted atom-pair kernels summed over the atoms of each frame pair -> (t, l).
k = (k0 * ((1 - alpha) * agg_c + alpha * group_c)).reshape(t, n, l, m).sum(axis=(1, 3))
# NOTE(review): the clustering and plotting cells use `K` (a frame-by-frame similarity
# matrix), which was otherwise undefined; it is taken to be this frame-level kernel.
K = k

In [None]:

from sklearn.cluster import SpectralClustering

# Partition the frames into two groups, using K as a precomputed affinity matrix
# (symmetric, larger values meaning more similar frames).
n_clusters = 2

# random_state pins the eigensolver start vector and the k-means initialisation, so the
# cluster labels (and therefore the shading in the plot) are reproducible across runs.
spec_cluster = SpectralClustering(n_clusters=n_clusters, affinity='precomputed', random_state=0)
labels = spec_cluster.fit_predict(np.asarray(K))
# The N2-distance series `ts` used for plotting is loaded in the next cell.

In [None]:
# N2-distance series (one value per frame, per the file name), trimmed to the same
# [start_idx:end_idx] window defined in the configuration cell.
ts = np.loadtxt(
    "./transfer_learning/notebooks/2_abinitio-metad_Fe45N2_N2distance.txt"
)[start_idx:end_idx]
# NOTE(review): this rebinds `t` (the frame count in earlier cells) to a frame-index array;
# re-running earlier cells that use `t` as an integer afterwards will fail.
t = np.arange(len(ts))

In [None]:
import matplotlib as mpl
import seaborn as sns
from matplotlib import pyplot as plt

# Base style first; the overrides below are applied on top of it.
plt.style.use('seaborn-v0_8-paper')

# Thick solid lines, 12/14 pt Times New Roman text, and LaTeX text rendering.
font = {'family' : 'Times New Roman',
        'weight' : 'normal',
        'size'   : 12}
mpl.rcParams.update({
    'lines.linewidth': 2,
    'lines.linestyle': '-',
    'mathtext.default': 'regular',
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'axes.labelsize': 14,
    **{f'font.{key}': value for key, value in font.items()},
    'text.usetex': True,
})

# Figure-size knobs (currently unused by the plotting cell).
heatmap_height = 1
ls_ratio = 0.4


from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import numpy as np

# x coordinates for the cluster shading, one per frame of the distance series. Kept local
# rather than reusing the global `t`, which earlier cells use as the frame count.
frame_idx = np.arange(ts.shape[0])
# The shading mask and the distance series share the x axis, so they must cover the same frames.
assert len(labels) == len(frame_idx), (
    f"frame count mismatch: {len(labels)} cluster labels vs {len(frame_idx)} distance values"
)

# Distance series on top, similarity heatmap below (height ratio 1:3), shared x axis.
fig, ax = plt.subplots(2, 1, gridspec_kw={'height_ratios': [1, 3]}, sharex=True)

# --- Top panel: N-N distance with the two spectral clusters shaded ---
ax[0].plot(ts)
cm = sns.color_palette("Set2")
eq1 = np.array(labels == 1)
ax[0].fill_between(frame_idx, 0, 1, where=eq1,
                   color=cm[0], alpha=0.3, transform=ax[0].get_xaxis_transform())
# Cluster 0: shade frames whose left or right neighbour is in cluster 0 (np.roll wraps
# around at the ends); presumably this closes the gaps at cluster boundaries.
ax[0].fill_between(frame_idx, 0, 1, where=np.roll(~eq1, -1) | np.roll(~eq1, 1),
                   color=cm[1], alpha=0.3, transform=ax[0].get_xaxis_transform())
# Raw string: "\m" and "\A" are invalid Python escape sequences in a normal literal
# (warned about by recent Pythons); the text passed to LaTeX is unchanged.
ax[0].set_ylabel(r"$d(\mathrm{N},\mathrm{N})$ [\AA]")
ax[0].yaxis.set_ticks([0, 4.0])
ax[0].set_box_aspect(1/3)

# --- Bottom panel: frame-by-frame similarity matrix ---
im = ax[1].imshow(K)
ax[1].set_aspect("equal")
ax[1].xaxis.set_major_locator(plt.MaxNLocator(4))
ax[1].yaxis.set_major_locator(plt.MaxNLocator(4))
ax[1].set_xlabel("t")
ax[1].set_ylabel("t")
ax[1].set_box_aspect(1)

# Colorbar in an inset axis to the right of the heatmap (negative borderpad pushes it outside).
axins = inset_axes(ax[1],
                   width="5%",
                   height="100%",
                   loc='right',
                   borderpad=-3)
cbar = fig.colorbar(im, cax=axins, orientation="vertical")
cbar.set_label("Similarity")
fig.savefig(plot_dir / "spectral_clustering", dpi=150, bbox_inches='tight')