In [None]:
import matplotlib.pyplot as plt
import numpy as np

import ase
from ase.neighborlist import natural_cutoffs, NeighborList
from ase.io.trajectory import Trajectory
from ase.io import read as ase_read

import re

In [None]:
# Relative path of the LAMMPS dump file that serves as demo input.
traj_path = '../test_data/demo_trajectory.lammpstrj'

# index=':' requests every frame at once, so the whole trajectory is held in
# memory. For large files it may be preferable to read one frame at a time,
# e.g. by passing index='%u' % ii for each frame number ii.
trajectory = ase_read(filename=traj_path, index=':')

In [None]:
class HOchainFinder():
    """Find and plot "special" oxygen environments in trajectory snapshots.

    For every oxygen atom the Hill formula of its neighbours (the oxygen
    itself excluded) is built. The oxygen is flagged as special when that
    formula neither matches ``Ti[2-9]`` nor contains ``"H2"``. According to
    the original author's note this is meant to pick out a water molecule
    that is bonded to Ti and has lost one H.

    Parameters
    ----------
    cutoff_mult : float
        Multiplier handed to ``natural_cutoffs`` when the per-atom cutoff
        radii are computed.

    Attributes
    ----------
    ti_regex : compiled pattern ``Ti[2-9]``.
    cutoff_mult : the multiplier given at construction.
    cut_offs : per-atom cutoffs computed by ``test``; ``None`` until then.
    """

    def __init__(self, cutoff_mult=0.75):
        self.ti_regex = re.compile(r'Ti[2-9]')
        self.cutoff_mult = cutoff_mult
        # Set by test(); defined here so the attribute always exists.
        self.cut_offs = None

    def _is_special(self, neighbour_formula):
        """Return True if a neighbour formula marks a special oxygen.

        Parameters
        ----------
        neighbour_formula : str
            Hill formula of the atoms bonded to one oxygen.
        """
        return (
            self.ti_regex.search(neighbour_formula) is None
            and "H2" not in neighbour_formula
        )

    def _special_oxygens(self, snapshot, neighbour_list):
        """Collect ``(oxygen_index, neighbours)`` for each special oxygen.

        ``neighbour_list`` must already have been updated with ``snapshot``.
        Oxygens are visited in the order in which they appear in the snapshot.
        """
        found = []
        for atom in snapshot:
            if atom.symbol != 'O':
                continue
            neighbours = neighbour_list.get_neighbors(atom.index)[0]
            formula = snapshot[neighbours].get_chemical_formula(mode='hill')
            if self._is_special(formula):
                found.append((atom.index, neighbours))
        return found

    def test(self, trajectory):
        """Scan a trajectory for special oxygen environments.

        Periodic boundary conditions are switched on for every snapshot
        (the snapshots are modified in place) and ``self.cut_offs`` is set
        from the first snapshot.

        Parameters
        ----------
        trajectory : sequence of snapshots
            Must support ``len``, indexing and iteration.

        Returns
        -------
        list of list of numpy.ndarray
            One list per snapshot. Each array holds the neighbour indices of
            a special oxygen followed by the index of the oxygen itself as
            the last element.
        """
        if len(trajectory) == 0:
            # Nothing to analyse; avoids an IndexError on trajectory[0].
            return []

        initial_config = trajectory[0]
        initial_config.set_pbc([True, True, True])
        self.cut_offs = natural_cutoffs(initial_config, mult=self.cutoff_mult)
        neighbour_list = NeighborList(self.cut_offs, bothways=True, self_interaction=False)

        special_list = []
        for snapshot in trajectory:
            snapshot.set_pbc([True, True, True])
            neighbour_list.update(snapshot)
            special_list.append([
                np.append(neighbours, oxygen_index)
                for oxygen_index, neighbours
                in self._special_oxygens(snapshot, neighbour_list)
            ])

        return special_list

    def plot_special_config(self, snapshot, cut_offs=None):
        """Plot a snapshot with all of its special oxygen sites highlighted.

        Parameters
        ----------
        snapshot : atoms object
            The configuration to analyse and draw.
        cut_offs : sequence of float, optional
            Per-atom cutoffs for the neighbour search. Defaults to the
            cutoffs stored by ``test``; if ``test`` has not been run yet they
            are computed from ``snapshot`` with ``self.cutoff_mult``.
        """
        if cut_offs is None:
            cut_offs = self.cut_offs
        if cut_offs is None:
            cut_offs = natural_cutoffs(snapshot, mult=self.cutoff_mult)

        neighbour_list = NeighborList(cut_offs, bothways=True, self_interaction=False)
        neighbour_list.update(snapshot)

        specials = self._special_oxygens(snapshot, neighbour_list)
        index_list = [oxygen_index for oxygen_index, _ in specials]
        neigh_list = [neighbours for _, neighbours in specials]

        self.plot_snapshot(snapshot, index_list, neigh_list)

    @staticmethod
    def _mark_masses(masses, snapshot, found_index, neighs):
        """Set the mass entry of one site and of its neighbours to 100.

        ``masses`` is modified in place and also returned. ``snapshot`` is
        not used; it is kept so the signature stays unchanged.
        """
        # Force an integer index array: an empty neighbour list would
        # otherwise produce a float array that cannot be used for indexing.
        mark_indices = np.append(np.asarray(neighs, dtype=int), int(found_index))
        masses[mark_indices] = 100
        return masses

    @staticmethod
    def _highlight_masses(masses, snapshot, found_index, neighs):
        """Mark one site or a list of sites in ``masses``.

        ``found_index`` may be a single integer (Python or NumPy) paired
        with one neighbour array, or a list/tuple of indices paired with a
        sequence of neighbour arrays. Any other type leaves ``masses``
        untouched.
        """
        if isinstance(found_index, (list, tuple)):
            for single_index, single_neighs in zip(found_index, neighs):
                HOchainFinder._mark_masses(masses, snapshot, single_index, single_neighs)
        elif isinstance(found_index, (int, np.integer)):
            # np.integer is included because indices taken from NumPy arrays
            # are not instances of the built-in int.
            HOchainFinder._mark_masses(masses, snapshot, found_index, neighs)
        return masses

    @staticmethod
    def plot_snapshot(snapshot, found_index, neighs):
        """Draw a 3D scatter plot of a snapshot with highlighted sites.

        Atoms are coloured by mass; the highlighted sites and their
        neighbours get the value 100 instead of their mass so that they
        stand out. Marker sizes are the natural cutoffs scaled by 500.

        Parameters
        ----------
        snapshot : atoms object
            The configuration to draw.
        found_index : int or list of int
            Index (or indices) of the sites to highlight.
        neighs : array or list of arrays
            Neighbour indices belonging to ``found_index``.
        """
        fig = plt.figure(figsize=(12, 8))
        ax = fig.add_subplot(projection='3d')

        at_pos = snapshot.get_positions()
        masses = HOchainFinder._highlight_masses(
            snapshot.get_masses(), snapshot, found_index, neighs
        )
        sizes = np.array(natural_cutoffs(snapshot, mult=500))

        ax.scatter(
            at_pos[:, 0], at_pos[:, 1], at_pos[:, 2],
            s=sizes,
            c=masses,
            alpha=1,
            edgecolors="k",
        )

        # Give x and y the same extent as z, centred on their current ranges.
        z_low, z_high = ax.get_zlim()
        z_span = (z_high - z_low) / 2.
        y_mid = np.mean(ax.get_ylim())
        x_mid = np.mean(ax.get_xlim())
        ax.set_ylim([y_mid - z_span, y_mid + z_span])
        ax.set_xlim([x_mid - z_span, x_mid + z_span])
        plt.show()

In [None]:
# Number of snapshots in the loaded trajectory.
print(len(trajectory))

In [None]:
# Create the finder with its default cutoff multiplier, then scan every
# snapshot. special_list gets one entry per snapshot.
chain_finder = HOchainFinder()
special_list = chain_finder.test(trajectory=trajectory)

In [None]:
def find_hopping(
    trajectory, neighbour_list_cur, neighbour_list_prev,
    ii_step, special_config, counter
):
    """Follow the atom that left an oxygen between two steps.

    The neighbours the centre oxygen had in the previous step are compared
    with ``special_config``. If exactly one H is missing, the oxygen it is
    bound to in the current step is inspected; while that oxygen forms H2O
    the search continues recursively from there.

    Parameters
    ----------
    trajectory : indexable collection of snapshots
    neighbour_list_cur : neighbour list updated with snapshot ``ii_step``
    neighbour_list_prev : neighbour list updated with the preceding snapshot
    ii_step : int
        Index of the current snapshot.
    special_config : sequence of int
        Neighbour indices followed by the centre oxygen index (last entry).
    counter : int
        Recursion depth; callers start with 0.

    Returns
    -------
    int
        ``counter`` (>= 0) when the receiving oxygen's formula contains
        'HOTi', otherwise one of these codes:

        * -1: no neighbour left, ``counter`` is 0 and both formulas are 'HOTi'
        * -2: receiving oxygen is neither 'H2O' nor contains 'HOTi'
        * -3: no neighbour left while ``counter`` is not 0
        * -4: no neighbour left, ``counter`` is 0, but the formulas differ or
          are not 'HOTi'
        * -5: the first atom that left is Ti

    Raises
    ------
    ValueError
        If the atoms that left are not a single H (and the first is not Ti),
        or if the H that left is not bound to exactly one O.
    """
    centre_oxygen = special_config[-1]
    former_neighbours = neighbour_list_prev.get_neighbors(centre_oxygen)[0]
    departed = np.setdiff1d(former_neighbours, special_config)

    if len(departed) == 0:
        if counter != 0:
            # Ti-O-H-O-H chain, expected to resolve in the next timestep
            return -3
        snapshot = trajectory[ii_step]
        formula_now = snapshot[special_config].get_chemical_formula(mode='hill')
        formula_before = snapshot[
            np.append(former_neighbours, centre_oxygen)
        ].get_chemical_formula(mode='hill')
        if formula_now != formula_before or formula_now != 'HOTi':
            # Not expected to happen
            return -4
        # Water stays in the TiO configuration over several steps
        return -1

    snapshot = trajectory[ii_step]

    # The atom that left has to be one single H
    departed_symbols = snapshot[departed].get_chemical_symbols()
    is_single_hydrogen = len(departed_symbols) == 1 and departed_symbols[0] == 'H'
    if not is_single_hydrogen:
        if departed_symbols[0] == 'Ti':
            # Surface hopping case
            return -5
        raise ValueError(f'Diff is {departed_symbols} instead of just H')

    # In the current step that H must be bound to exactly one O
    receivers = neighbour_list_cur.get_neighbors(departed[0])[0]
    receiver_symbols = snapshot[receivers].get_chemical_symbols()
    if len(receiver_symbols) != 1 or receiver_symbols[0] != 'O':
        raise ValueError(f'Current configuration of diff is {receiver_symbols} instead of just O')

    receiver = receivers[0]
    receiver_config = np.append(
        neighbour_list_cur.get_neighbors(receiver)[0], receiver
    )

    # This environment is what gets tracked until an atom lands on TiO
    receiver_formula = snapshot[receiver_config].get_chemical_formula(mode='hill')
    if receiver_formula == 'H2O':
        return find_hopping(
            trajectory, neighbour_list_cur, neighbour_list_prev,
            ii_step, receiver_config, counter=counter + 1
        )
    if 'HOTi' in receiver_formula:
        return counter
    # Something like H2OTi: H has not split off yet, expected in the next timestep
    return -2

initial_config = trajectory[0]

# Two separate neighbour lists are kept, one for the current and one for the
# previous snapshot of each step.
# NOTE(review): both lists are created without an explicit skin argument, so
# the library default applies and update() may reuse an earlier build —
# confirm that this is intended for the bond analysis.
neighbour_list_cur = NeighborList(chain_finder.cut_offs, bothways=True, self_interaction=False)
neighbour_list_prev = NeighborList(chain_finder.cut_offs, bothways=True, self_interaction=False)

# hop_list collects one find_hopping result per special configuration,
# starting from the second snapshot (the first one has no predecessor).
hop_list = []
for ii_step, specials in enumerate(special_list[1:], start=1):
    print(ii_step)

    # Skip snapshots without special configurations. The original guard
    # tested the whole special_list, which is never empty inside this loop,
    # so the neighbour lists were updated even when nothing was analysed.
    if len(specials) == 0:
        continue

    cur_snapshot = trajectory[ii_step]
    cur_snapshot.set_pbc([True, True, True])

    # Unfortunately this is necessary, neighbourlist has no copy function
    neighbour_list_cur.update(cur_snapshot)
    neighbour_list_prev.update(trajectory[ii_step - 1])

    # Go through all found special configurations of this snapshot
    for special_config in specials:
        hops = find_hopping(
            trajectory, neighbour_list_cur, neighbour_list_prev,
            ii_step, special_config, counter=0
        )
        hop_list.append(hops)

In [None]:
# Distribution of the find_hopping results: values >= 0 are hop counts,
# negative values are the function's status codes.
hops = np.asarray(hop_list, dtype=int)

if hops.size == 0:
    # min/max of an empty array raise, so there is nothing to plot.
    hop_counts = np.zeros(0, dtype=int)
    print('No hopping results to plot.')
else:
    lowest = hops.min()
    highest = hops.max()
    # Shift by the minimum because bincount only accepts non-negative input.
    hop_counts = np.bincount(hops - lowest)

    # Bin edges sit on half-integers starting at -5.5, so every integer
    # result from -5 upwards gets its own bin.
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.hist(hops, bins=np.arange(-6, highest + 1) + 0.5)
    ax.set_xlabel('find_hopping result')
    ax.set_ylabel('count')
    plt.show()

    plt.bar(np.arange(lowest, highest + 1), hop_counts)
    plt.xlabel('find_hopping result')
    plt.ylabel('count')
    plt.show()

In [None]:
# Number of special sites found in each snapshot.
n_specials = np.array([len(special_sites) for special_sites in special_list], dtype=np.int32)

# Print the snapshots that contain a special site with three hydrogens.
# Each entry of special_sites is an array of atom indices (neighbours plus
# the oxygen), so the check has to be made on the chemical formula of those
# atoms; comparing the string "H3" with the index arrays can never match.
for ii_special, special_sites in enumerate(special_list):
    has_h3 = any(
        "H3" in trajectory[ii_special][site].get_chemical_formula(mode='hill')
        for site in special_sites
    )
    if has_h3:
        print(ii_special)

In [None]:
%matplotlib auto

# Index of the snapshot to draw.
plot_num = 1
print(n_specials[plot_num])

# Use the finder's own multiplier instead of repeating the literal 0.75, so
# the plot stays consistent if the finder is built with another value.
snapshot_to_plot = trajectory[plot_num]
chain_finder.plot_special_config(
    snapshot_to_plot,
    natural_cutoffs(snapshot_to_plot, mult=chain_finder.cutoff_mult),
)