In [None]:
import matplotlib.pyplot as plt
import numpy as np

import ase
from ase.neighborlist import natural_cutoffs, NeighborList
from ase.io.trajectory import Trajectory
from ase.io import read as ase_read

import fnmatch
import re

In [None]:
# LAMMPS dump analysed throughout this notebook.
traj_path = '../test_data/demo_trajectory.lammpstrj'

# NOTE: index=':' eagerly reads every frame into a list held in memory. For long
# trajectories, consider reading frames one at a time (e.g. index='%u' % ii)
# instead of materialising the whole trajectory.
trajectory = ase_read(
    traj_path,
    index=':',
)

In [None]:
class HOchainFinder:
    """Locate oxygen atoms whose bonding environment is "special".

    For each oxygen atom in a snapshot, its neighbours (from an ASE
    ``NeighborList`` built with ``natural_cutoffs`` scaled by ``cutoff_mult``)
    are summarised as a Hill-notation chemical formula, with the oxygen itself
    excluded. The oxygen is *not* special when that formula matches
    ``ti_regex`` ("Ti" immediately followed by a digit 2-9) or contains the
    substring "H2". Every other oxygen is reported as special.

    NOTE(review): ``Ti[2-9]`` is presumably meant as "bonded to two or more Ti".
    Counts of 10 or more (e.g. "Ti10") do not match it.

    Parameters
    ----------
    cutoff_mult : float
        Multiplier passed to ``natural_cutoffs`` when building neighbour cutoffs.
    pbc : sequence of 3 bool
        Periodic boundary conditions applied to every analysed snapshot.
        The default is periodic in x and y only.
    """

    # Value written into the per-atom colour array to highlight marked atoms.
    MARK_VALUE = 100

    def __init__(self, cutoff_mult=0.75, pbc=(True, True, False)):
        self.ti_regex = re.compile(r'Ti[2-9]')
        self.cutoff_mult = cutoff_mult
        self.pbc = list(pbc)

    def _is_special(self, formula):
        """Return True if an oxygen with this neighbour formula is special.

        ``formula`` is the Hill formula of the oxygen's neighbours. It is
        special when it matches neither ``self.ti_regex`` nor contains "H2".
        """
        return self.ti_regex.search(formula) is None and "H2" not in formula

    def _neighbour_list(self, snapshot, cut_offs):
        """Build a fresh, bidirectional NeighborList and update it for ``snapshot``.

        A new list is built for each snapshot on purpose. Per the ASE
        documentation of ``skin``, ``update`` only rebuilds a list once some
        atom has moved farther than the skin distance. A list reused across
        frames can therefore return neighbours from an earlier frame. Building
        fresh makes each frame's result depend only on that frame, so it
        matches what ``plot_special_config`` shows.
        """
        neighbour_list = NeighborList(cut_offs, bothways=True, self_interaction=False)
        neighbour_list.update(snapshot)
        return neighbour_list

    def _find_special_sites(self, snapshot, neighbour_list):
        """Return ``(oxygen_index, neighbour_indices, formula)`` for every special oxygen."""
        sites = []
        for oxygen_index, symbol in enumerate(snapshot.get_chemical_symbols()):
            if symbol != 'O':
                continue
            neighbours = neighbour_list.get_neighbors(oxygen_index)[0]
            formula = snapshot[neighbours].get_chemical_formula(mode='hill')
            if self._is_special(formula):
                sites.append((oxygen_index, neighbours, formula))
        return sites

    def test(self, trajectory, stop_index=900):
        """Scan ``trajectory`` and collect the special oxygen sites of each frame.

        Side effects: ``set_pbc(self.pbc)`` is called on every visited snapshot.
        ``self.cut_offs`` is computed from the first snapshot and reused for all
        frames, so the atom count and order are assumed constant along the
        trajectory.

        Parameters
        ----------
        trajectory : sequence of ase.Atoms
            Indexable frames, e.g. the list returned by ``ase.io.read(..., index=':')``.
        stop_index : int or None
            Index of the last frame to analyse (inclusive). The default of 900
            matches the previous hard-coded limit. Pass None to scan every frame.

        Returns
        -------
        list of list of str
            One entry per analysed frame, holding the neighbour Hill formula of
            each special oxygen in that frame.
        """
        if len(trajectory) == 0:
            return []

        initial_config = trajectory[0]
        initial_config.set_pbc(self.pbc)
        self.cut_offs = natural_cutoffs(initial_config, mult=self.cutoff_mult)

        special_list = []
        for ii_snap, snapshot in enumerate(trajectory):
            snapshot.set_pbc(self.pbc)
            neighbour_list = self._neighbour_list(snapshot, self.cut_offs)
            sites = self._find_special_sites(snapshot, neighbour_list)
            special_list.append([formula for _, _, formula in sites])
            if stop_index is not None and ii_snap >= stop_index:
                break
        return special_list

    def plot_special_config(self, snapshot, cut_offs=None):
        """Plot ``snapshot`` in 3D, highlighting each special oxygen and its neighbours.

        Parameters
        ----------
        snapshot : ase.Atoms
            Configuration to plot. It is copied and ``self.pbc`` is applied to
            the copy, so the caller's object is left untouched and the detected
            sites match what ``test`` reports for the same frame.
        cut_offs : list of float, optional
            Per-atom cutoffs. Defaults to ``self.cut_offs`` from a previous
            ``test`` call. If ``test`` has not been run, they are computed with
            ``natural_cutoffs(snapshot, mult=self.cutoff_mult)``.
        """
        if cut_offs is None:
            cut_offs = getattr(self, 'cut_offs', None)
        if cut_offs is None:
            cut_offs = natural_cutoffs(snapshot, mult=self.cutoff_mult)

        snapshot = snapshot.copy()
        snapshot.set_pbc(self.pbc)
        neighbour_list = self._neighbour_list(snapshot, cut_offs)
        sites = self._find_special_sites(snapshot, neighbour_list)

        index_list = [oxygen_index for oxygen_index, _, _ in sites]
        neigh_list = [neighbours for _, neighbours, _ in sites]
        self.plot_snapshot(snapshot, index_list, neigh_list)

    @staticmethod
    def _mark_masses(masses, snapshot, found_index, neighs):
        """Set ``masses`` to ``MARK_VALUE`` at ``found_index`` and at every index in ``neighs``.

        ``masses`` is modified in place and also returned. ``snapshot`` is
        unused and kept only for call compatibility. ``neighs`` may be any
        integer sequence, including an empty list.
        """
        # Coerce to an int array: np.append on an empty Python list would yield
        # a float array, which cannot be used as an index.
        neighs = np.asarray(neighs, dtype=int).ravel()
        mark_indices = np.append(neighs, found_index)
        masses[mark_indices] = HOchainFinder.MARK_VALUE
        return masses

    @staticmethod
    def plot_snapshot(snapshot, found_index, neighs):
        """Scatter-plot atom positions in 3D, colour-highlighting the marked atoms.

        Parameters
        ----------
        snapshot : ase.Atoms
            Configuration to plot.
        found_index : int or list of int
            Index of the atom to highlight, or a list of such indices.
        neighs : sequence of int, or list of sequences
            Neighbour indices highlighted together with ``found_index``. When
            ``found_index`` is a list, ``neighs`` is the matching list.
        """
        fig = plt.figure(figsize=(12, 8))
        ax = fig.add_subplot(projection='3d')

        at_pos = snapshot.get_positions()
        # Atomic masses are used as colour values; highlighted atoms are overwritten.
        masses = snapshot.get_masses()

        if isinstance(found_index, list):
            for single_index, single_neighs in zip(found_index, neighs):
                HOchainFinder._mark_masses(masses, snapshot, single_index, single_neighs)
        elif isinstance(found_index, (int, np.integer)):
            # np.integer also accepts numpy scalar indices, which are not ``int`` instances.
            HOchainFinder._mark_masses(masses, snapshot, found_index, neighs)

        # Marker sizes come from the per-atom natural cutoffs, scaled up for visibility.
        sizes = np.array(natural_cutoffs(snapshot, mult=500))

        ax.scatter(
            at_pos[:, 0], at_pos[:, 1], at_pos[:, 2],
            s=sizes,
            c=masses,
            alpha=1,
            edgecolors="k",
        )

        # Give all three axes the largest of the three spans, each centred on its
        # current midpoint, so the structure is neither distorted nor clipped.
        limits = np.array([ax.get_xlim(), ax.get_ylim(), ax.get_zlim()])
        half_span = (limits[:, 1] - limits[:, 0]).max() / 2.
        centres = limits.mean(axis=1)
        ax.set_xlim(centres[0] - half_span, centres[0] + half_span)
        ax.set_ylim(centres[1] - half_span, centres[1] + half_span)
        ax.set_zlim(centres[2] - half_span, centres[2] + half_span)
        plt.show()

In [None]:
# Scan the loaded trajectory. special_list[i] holds the neighbour formulas of
# every special oxygen found in frame i.
chain_finder = HOchainFinder(cutoff_mult=0.75)
special_list = chain_finder.test(trajectory)

In [None]:
# Number of special oxygen sites detected in each scanned frame.
n_specials = np.array(list(map(len, special_list)), dtype=np.int32)

# Report every frame that contains an oxygen whose neighbours are exactly three H atoms.
for ii_special, special_sites in enumerate(special_list):
    if "H3" not in special_sites:
        continue
    print(ii_special)

In [None]:
%matplotlib auto

plot_num = 335
print(n_specials[plot_num])
chain_finder.plot_special_config(trajectory[plot_num], natural_cutoffs(trajectory[plot_num], mult=0.75))