In [None]:
import matplotlib.pyplot as plp
import scipy.cluster.hierarchy as sch
from  h5py import File
import pandas as pd
from pandas import DataFrame
import h5py
from HDF5er import saveXYZfromTrajGroup,MDA2HDF5,saveXYZfromTrajGroup
import numpy
from MDAnalysis import Universe as mdaUniverse
from SOAPify import (saponifyGroup,
                    createReferencesFromTrajectory,
                    mergeReferences,
                    SOAPdistanceNormalized,
                    saveReferences,
                    getReferencesFromDataset,
                    classify
                    )

# Notebook-level switches.
# NOTE(review): neither flag is read in the visible part of this notebook;
# presumably they toggle loading/computing the SOAP references - confirm
# where (or whether) they are still used.
loadReferences=True
soapReferences=True



def patchBoxFromTopology(hdf5TrajFile:str,topologyFile:str):
    """
    Overwrite the simulation box of every stored trajectory with the one
    read from a topology file.

    The topology is loaded through MDAnalysis (with
    ``atom_style="id type x y z"``), then each group found under
    ``Trajectories`` in the HDF5 file gets all the rows of its ``Box``
    dataset replaced by the dimensions of that universe.

    Args:
        hdf5TrajFile: path of the HDF5 file to patch (opened in append mode).
        topologyFile: path of the topology file passed to MDAnalysis.
    """
    universe = mdaUniverse(topologyFile, atom_style="id type x y z")
    with h5py.File(hdf5TrajFile, "a") as workFile:
        for trajName in workFile['Trajectories']:
            boxes = workFile[f'Trajectories/{trajName}']['Box']
            # one copy of the topology box per row already stored in the dataset
            nFrames = boxes.shape[0]
            boxes[:] = [universe.dimensions] * nFrames

In [None]:
def getCompactedAnnotationsForTmat(tmat) -> list:
    """
    Build the compact text annotations for a transition matrix.

    Each entry of ``tmat`` becomes a short label: ``"<0.01"`` for values
    below 0.01, ``">0.99"`` for values above 0.99, and the value formatted
    with two decimals in every other case.

    Args:
        tmat: two-dimensional array exposing ``.shape`` and ``tmat[row, col]``.

    Returns:
        list: one list of label strings per row of ``tmat``.
    """

    def labelFor(value) -> str:
        # values outside [0.01, 0.99] are collapsed into a fixed label
        if value < 0.01:
            return "<0.01"
        if value > 0.99:
            return ">0.99"
        return f"{value:.2f}"

    nRows = tmat.shape[0]
    return [
        [labelFor(tmat[rowId, colId]) for colId in range(tmat.shape[1])]
        for rowId in range(nRows)
    ]

In [None]:
def convertClusteringV(classification:numpy.ndarray, conversion:numpy.ndarray):
    """
    Translate a classification through a lookup array.

    Every entry ``c`` of ``classification`` is replaced by ``conversion[c]``.
    The lookup is done with NumPy integer array indexing, so it runs without
    a Python-level loop and also works when ``classification`` is empty
    (``numpy.vectorize`` refuses size-0 inputs unless ``otypes`` is given).

    Args:
        classification: integer array of indices into ``conversion``.
        conversion: one-dimensional array; element ``i`` is the new label
            for the original label ``i``.

    Returns:
        numpy.ndarray: new array with the same shape as ``classification``
        and the dtype of ``conversion``.

    Raises:
        IndexError: if an entry of ``classification`` is out of bounds for
            ``conversion``.
    """
    indices = numpy.asarray(classification)
    lookup = numpy.asarray(conversion)
    # integer array indexing returns a fresh array shaped like `indices`
    return lookup[indices]

Here we evaluate the pairwise distances between the SOAP environments.

In [None]:
# Load every reference set stored under "testReferences", keyed by dataset name.
references = {}
with File("../../create_reference/references.hdf5", "r") as refFile:
    g = refFile["testReferences"]
    for k in g.keys():
        references[k] = getReferencesFromDataset(g[k])

# Merge the references of the four surfaces into a single dictionary of environments.
wholeData = mergeReferences(
    references["111"], references["110"], references["211"], references["210"])

ref1=wholeData
ndataset = len(ref1)
# Condensed distance vector: one entry per unordered pair (i < j), filled
# row by row, which is the layout scipy's `linkage` accepts as input.
r1 = numpy.zeros((int(ndataset * (ndataset - 1) / 2)))
cpos = 0
for i in range(ndataset):
    for j in range(i + 1, ndataset):
        r1[cpos] = SOAPdistanceNormalized(
            ref1.spectra[i], ref1.spectra[j]
        )
        cpos += 1

# Complete-linkage hierarchical clustering of the reference environments.
links=sch.linkage(r1,method="complete")
# Flat cluster labels for each distance threshold; the key is the threshold
# with "." replaced by "_" (0.01 -> "0_01").
clusterCut={}
for cut in [0.01]: 
    c=sch.fcluster(links,t=cut, criterion="distance")
    clusterCut[str(cut).replace(".","_")]=c 

Here we obtain the hierarchy dendrogram cut at 0.01 d_soap. This allows us to merge very similar SOAP environments into
common macro-clusters, e.g., bulk (b*), sub-surface (ss*), surface (s).

In [None]:
# Distance threshold (in SOAP distance) used to cut the dendrogram.
cut=0.01
import matplotlib.pyplot as plt
links=sch.linkage(r1,method="complete")
# The figure defaults are read when the figure is created, so they have to
# be set BEFORE plt.subplots; otherwise they only take effect when the cell
# is run a second time.
plt.rcParams["figure.figsize"] = [5.50, 5.50]
plt.rcParams["figure.autolayout"] = True
fig,ax=plt.subplots(1,1, dpi=150)
# Branches below the threshold are coloured per cluster, the rest is black.
dendro=sch.dendrogram(links,color_threshold=cut,labels=wholeData.names, orientation="left",
                      ax=ax,above_threshold_color='black')
# Flat cluster labels (starting from 1) of each reference environment at the chosen cut.
clusters=sch.fcluster(links,t=cut, criterion="distance")
# Name_dict: environment name -> zero-based cluster index.
# names: for each cluster, the space-separated names of its members.
Name_dict=dict()
names=["" for i in range (0,max(clusters))]
for i,n in enumerate(wholeData.names):
    print(i,n,clusters[i])
    Name_dict[n]=clusters[i]-1
    names[clusters[i]-1]+=' ' + n
names

Here we classify the trajectory using the atomic environments defined in the complete dendrogram.

In [None]:
# Classify the SOAP fingerprints stored in the "SOAP" group of the trajectory
# file against the merged reference dictionary `wholeData`, using the
# normalized SOAP distance.
with h5py.File('../211.hdf5', "r") as workFile:
        g=workFile[f"SOAP"]
        for key in workFile[f"SOAP"].keys():
            # NOTE(review): `cls` is re-created at every iteration, so after the
            # loop it only holds the classification of the LAST key; this is
            # fine only if the group contains a single dataset - confirm.
            cls = {}
            t= classify(g[key], wholeData, SOAPdistanceNormalized, True)
            # classification expressed with the labels of the whole dictionary
            cls[f"whole"] = t.references

Here we convert the classification derived from the complete dictionary to the one obtained with the cut at 0.01 d_soap.
We then save the trajectory with the new classification.

In [None]:
# Re-express the "whole" classification with the cluster labels of every cut
# stored in `clusterCut`, then export the trajectory as XYZ with all the
# classifications in `cls` passed as keyword arguments.
with h5py.File('../211.hdf5', "r") as workFile:
        for key in workFile[f"SOAP"].keys():

            # `cls` comes from the classification cell; one new entry per cut
            for k in clusterCut:
                cls[k] = convertClusteringV(cls[f"whole"],clusterCut[k]) 
            # NOTE(review): the output name is fixed, so with more than one key
            # in "SOAP" each iteration would write to the same file - confirm
            # that a single trajectory is expected here.
            saveXYZfromTrajGroup(
                f"211_T_700_001.xyz",
                workFile[f"Trajectories/{key}"],
                **cls,
            )

We select only the atoms belonging to the topmost layers to calculate the transition matrix.

In [None]:
# Boolean mask over the atoms: third coordinate (presumably z - confirm) of
# the first stored frame greater than 5.0.
with h5py.File('../211.hdf5', "r") as workFile:
    mask=workFile["Trajectories/211_T_700/Trajectory"][0,:,2]>5.0
# Select the clustering explicitly (last cut stored in `clusterCut`) instead of
# relying on the value a loop variable happened to have in a previous cell.
k=list(clusterCut)[-1]
# Keep, for every frame, only the columns of the selected atoms.
x=cls[k][:,mask]
print(x.shape)

Here we assign each cluster to one of the ideal surfaces (0 = 111, 1 = 110, 2 = 211, 3 = 210), in order to obtain the transition matrix between native and non-native environments.

In [None]:
# Ideal surface assigned to each cluster label of the cut
# (0 = 111, 1 = 110, 2 = 211, 3 = 210).
new_clusters_cut_211 = {
    '1': 0,
    '2': 2,
    '3': 3,
    '4': 1,
    '5': 3,
    '6': 2,
    '7': 1,
    '8': 2,
    '9': 2,
    '10': 0,
    '11': 3,
    '12': 2,
    '13': 2,
    '14': 1,
    '15': 2,
    '16': 3,
}

# Cluster labels and the surface index assigned to each of them, in matching order.
clusterLabels = numpy.array([int(label) for label in new_clusters_cut_211])
surfaceIds = numpy.array(list(new_clusters_cut_211.values()), dtype=numpy.int64)

# A label missing from the dictionary must still fail loudly, as the plain
# dictionary lookup would.
unmappedLabels = numpy.setdiff1d(x, clusterLabels)
if unmappedLabels.size:
    raise KeyError(
        f"cluster labels without an assigned surface: {unmappedLabels.tolist()}"
    )

# Dense lookup table (position = cluster label) applied to the whole array at
# once instead of a Python loop over every element.
surfaceOfCluster = numpy.zeros(clusterLabels.max() + 1, dtype=numpy.int64)
surfaceOfCluster[clusterLabels] = surfaceIds
new_array = surfaceOfCluster[numpy.asarray(x)]

In [None]:
from pandas import DataFrame
import matplotlib.pyplot as plt
import seaborn as sns
from SOAPify import (
    SOAPclassification,
    transitionMatrixFromSOAPClassificationNormalized as TransitionMatrixMaker,
    transitionMatrixFromSOAPClassification as TransitionMatrixMakerNotNorm,
)

# Labels of the four ideal surfaces, ordered like the indices stored in new_array.
names_4 = ['111', '110', '211', '210']
plt.rcParams.update({'font.size': 35})

# The first 250 entries along the first axis of new_array are skipped
# (presumably the initial frames - confirm).
classification = SOAPclassification(None, new_array[250:], names_4[:])
matrix_cl = TransitionMatrixMaker(classification, 1)

# Cells whose transition value is exactly zero are not drawn.
mask = matrix_cl == 0
fig, ax = plt.subplots(figsize=(10, 10))
matrix_name = DataFrame(matrix_cl, index=names_4, columns=names_4)

# The matrix is multiplied by 100 and annotated without decimals.
heatmapOptions = dict(
    linewidths=0.1,
    ax=ax,
    fmt=".0f",
    annot=True,
    square=True,
    cmap="rocket_r",
    vmax=100,
    vmin=0,
    mask=mask,
    cbar=False,
)
ax = sns.heatmap(matrix_name * 100, **heatmapOptions)