In [None]:
import re

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import ase
from ase.io import read as ase_read
from ase.io.trajectory import Trajectory
from ase.neighborlist import natural_cutoffs, NeighborList

In [None]:
# Location of the LAMMPS dump file that is analysed in this notebook.
traj_path = '../test_data/anatase-001-nd-0/out.lammpstrj'

# index=':' loads every frame of the trajectory into memory at once.
# For large files it could be better to iterate through the file frame by
# frame instead, using index='%u'%ii.
trajectory = ase_read(
    traj_path,
    index=':',
)

In [None]:
class HOchainFinder():
    """Classify the bonding environment of oxygen atoms along a trajectory.

    Neighbours are determined with ASE ``NeighborList`` objects whose
    per-atom cutoffs come from ``natural_cutoffs`` evaluated on the first
    snapshot and scaled by ``cutoff_mult``. An oxygen "configuration" is the
    hill-ordered chemical formula of the oxygen together with (or, in the
    ``_is_*`` predicates, only of) its neighbours.

    Parameters
    ----------
    trajectory :
        Indexable sequence of ASE ``Atoms`` snapshots (``trajectory[i]``,
        ``len(trajectory)``).
    cutoff_mult : float
        Multiplier handed to ``natural_cutoffs``.
    """

    def __init__(self, trajectory, cutoff_mult=0.75):
        self.standard_config = 'H2O' # Sometimes needed to initialise configuration dictionary
        # Matches "Ti" followed by a digit 2-9 inside a formula string
        self.ti_regex = re.compile(r'Ti[2-9]')
        self.cutoff_mult = cutoff_mult
        self.trajectory = trajectory

        self.current_neighbourlist = None
        self.previous_neighbourlist = None

        self._init_neighbourlists()

    def _init_neighbourlists(self):
        """Create the two neighbour lists (current / previous snapshot).

        The cutoffs are computed once from snapshot 0 and stored in
        ``self.cut_offs``; both lists are built with ``bothways=True`` and
        without self interaction.
        """
        self.cut_offs = natural_cutoffs(self.trajectory[0], mult=self.cutoff_mult)
        self.current_neighbourlist = NeighborList(self.cut_offs, bothways=True, self_interaction=False)
        self.previous_neighbourlist = NeighborList(self.cut_offs, bothways=True, self_interaction=False)

    def _is_single_ti_bond(self, ii_snap, oxygen_index, **kwargs):
        """Return True if the oxygen's neighbours contain neither "Ti[2-9]" nor "H2".

        The formula is built from the neighbours in ``current_neighbourlist``
        (the oxygen itself is not included); ``ii_snap`` selects the snapshot
        the symbols are read from. Extra keyword arguments are ignored.
        """
        neighbours = self.current_neighbourlist.get_neighbors(oxygen_index)[0]
        symbols = self.trajectory[ii_snap][neighbours].get_chemical_formula(mode='hill')
        return self.ti_regex.search(symbols) is None and "H2" not in symbols

    def _is_ti_bonded_and_changed(self, ii_snap, oxygen_index):
        """Return True if a Ti-bonded oxygen gained or lost an H neighbour.

        The oxygen qualifies when the formula of its current neighbours
        contains "Ti" but not "H2", its neighbour set differs from the one in
        ``previous_neighbourlist``, and the formula of the differing atoms
        contains "H".
        """
        neighbours = self.current_neighbourlist.get_neighbors(oxygen_index)[0]
        symbols = self.trajectory[ii_snap][neighbours].get_chemical_formula(mode='hill')

        if 'Ti' not in symbols or "H2" in symbols:
            # Not bonded to Ti, or carrying two hydrogens
            return False

        prev_neighbours = self.previous_neighbourlist.get_neighbors(oxygen_index)[0]
        # Atoms that are in exactly one of the two neighbourhoods
        diffs = np.setxor1d(neighbours, prev_neighbours)
        if len(diffs) == 0:
            return False

        diff_symbols = self.trajectory[ii_snap][diffs].get_chemical_formula(mode='hill')
        # Only a change that involves H counts
        return 'H' in diff_symbols

    def find_special_configs(self, is_special=None):
        """Collect the special oxygen sites of every snapshot after the first.

        Parameters
        ----------
        is_special : callable, optional
            ``is_special(ii_snapshot, oxygen_index) -> bool``. Defaults to
            ``self._is_ti_bonded_and_changed``.

        Returns
        -------
        list of list of np.ndarray
            Entry ``k`` belongs to snapshot ``k + 1``. Each array holds the
            neighbour indices of a special oxygen followed by the oxygen
            index itself as last element.

        Notes
        -----
        Sets periodic boundary conditions on ``self.trajectory[0]`` and prints
        the snapshot number every 100 snapshots.
        """
        if is_special is None:
            is_special = self._is_ti_bonded_and_changed

        initial_config = self.trajectory[0]
        initial_config.set_pbc([True, True, True])

        special_list = []
        for ii_snapshot in range(1, len(self.trajectory)):
            if ii_snapshot % 100 == 0:
                print(ii_snapshot)
            # TODO: does every snapshot need snapshot.set_pbc([True, True, True])?
            snapshot = self.trajectory[ii_snapshot]
            self.current_neighbourlist.update(snapshot)
            self.previous_neighbourlist.update(self.trajectory[ii_snapshot-1])

            # TODO: Build indices once, then check if all oxygen indices stay oxygen
            oxygen_indices = [atom.index for atom in snapshot if atom.symbol == 'O']

            special_sites = []
            for oxygen_index in oxygen_indices:
                # Use special picker to determine if oxygen is in analysable configuration
                if is_special(ii_snapshot, oxygen_index):
                    neighbours = self.current_neighbourlist.get_neighbors(oxygen_index)[0]
                    special_sites.append(np.append(neighbours, oxygen_index))

            special_list.append(special_sites)

        return special_list

    def all_configs(self):
        """Find the configuration of every oxygen for all timesteps.

        The oxygen indices are taken from snapshot 0. For every snapshot and
        oxygen the hill formula of the oxygen plus its neighbours is looked up
        in ``config_dict``; unseen formulas get the next free integer code and
        a new zero column in the count array.

        Returns
        -------
        config_dict : dict
            Maps formula string -> integer code. Starts as
            ``{self.standard_config: 0}``.
        count_array : np.ndarray, shape (n_timesteps, n_configs), int16
            Number of oxygens in each configuration per timestep.
        config_codes : np.ndarray, shape (n_timesteps, n_oxygens), int8
            Configuration code of each oxygen per timestep.

        Notes
        -----
        Prints the snapshot number every 100 snapshots.
        """
        # TODO: Build indices once, then check if all oxygen indices stay oxygen
        oxygen_indices = [atom.index for atom in self.trajectory[0] if atom.symbol == 'O']
        n_steps = len(self.trajectory)

        config_dict = {self.standard_config: 0}
        count_array = np.zeros((n_steps, 1), dtype=np.int16)
        config_codes = np.zeros(shape=(n_steps, len(oxygen_indices)), dtype=np.int8)

        for ii_snapshot in range(0, n_steps):
            if ii_snapshot % 100 == 0:
                print(ii_snapshot)
            snapshot = self.trajectory[ii_snapshot]
            self.current_neighbourlist.update(snapshot)

            for ii_index, oxygen_index in enumerate(oxygen_indices):
                neighs = self.current_neighbourlist.get_neighbors(oxygen_index)[0]
                ox_config = snapshot[np.append(neighs, oxygen_index)].get_chemical_formula(mode='hill')

                # Explicit membership test instead of a bare except, so that
                # unrelated errors are not silently swallowed.
                loc = config_dict.get(ox_config)
                if loc is None:
                    # New configuration: register it and add a count column
                    loc = len(config_dict)
                    config_dict[ox_config] = loc
                    count_array = np.append(count_array, np.zeros((n_steps, 1), dtype=np.int16), axis=-1)

                count_array[ii_snapshot, loc] += 1
                config_codes[ii_snapshot, ii_index] = loc

        return config_dict, count_array, config_codes

    def plot_special_config(self, snapshot, cut_offs=None):
        """Plot ``snapshot`` with selected oxygens and their neighbours highlighted.

        An oxygen is highlighted when the hill formula of its neighbours
        neither matches ``self.ti_regex`` nor contains "H2" (the same test as
        ``_is_single_ti_bond``). A fresh neighbour list is built for the
        snapshot from ``cut_offs`` (default: ``self.cut_offs``).
        """
        if cut_offs is None:
            cut_offs = self.cut_offs

        neighbour_list = NeighborList(cut_offs, bothways=True, self_interaction=False)
        neighbour_list.update(snapshot)

        index_list = []
        neigh_list = []
        for atom in snapshot:
            if atom.symbol != 'O':
                continue
            neighbours = neighbour_list.get_neighbors(atom.index)[0]
            symbols = snapshot[neighbours].get_chemical_formula(mode='hill')
            if self.ti_regex.search(symbols) is None and "H2" not in symbols:
                index_list.append(atom.index)
                neigh_list.append(neighbours)

        self.plot_snapshot(snapshot, index_list, neigh_list)

    @staticmethod
    def _mark_masses(masses, snapshot, found_index, neighs):
        """Set ``masses`` to 100 at ``found_index`` and at all ``neighs`` (in place).

        ``snapshot`` is unused. The index array is cast to an integer dtype so
        that an empty ``neighs`` (which numpy would promote to float) still
        works. Returns the modified ``masses`` array.
        """
        found_index = np.array([found_index])
        mark_indices = np.append(neighs, found_index, axis=0).astype(np.intp)
        masses[mark_indices] = 100
        return masses

    @staticmethod
    def plot_snapshot(snapshot, found_index, neighs):
        """3D scatter plot of ``snapshot`` coloured by mass, with highlights.

        Parameters
        ----------
        snapshot :
            ASE ``Atoms`` object to draw.
        found_index : int, numpy integer, or list of those
            Atom(s) to highlight. Any other type draws no highlight.
        neighs :
            Neighbour indices of ``found_index``; a list of index sequences
            when ``found_index`` is a list.

        Highlighting is done by overriding the colour value (mass) of the
        selected atoms with 100. Marker sizes are ``natural_cutoffs`` with
        ``mult=500``.
        """
        fig = plt.figure(figsize=(12, 8))
        ax = fig.add_subplot(projection='3d')

        at_pos = snapshot.get_positions()
        masses = snapshot.get_masses()
        if isinstance(found_index, list):
            for single_index, single_neighs in zip(found_index, neighs):
                HOchainFinder._mark_masses(masses, snapshot, single_index, single_neighs)
        elif isinstance(found_index, (int, np.integer)):
            # np.integer is accepted too: indices taken from numpy arrays
            # (e.g. neighbour lists) are not instances of the builtin int.
            HOchainFinder._mark_masses(masses, snapshot, found_index, neighs)

        sizes = np.array(natural_cutoffs(snapshot, mult=500))

        ax.scatter(
            at_pos[:, 0], at_pos[:, 1], at_pos[:, 2],
            s=sizes,
            c=masses,
            alpha=1,
            edgecolors="k",
        )

        # Give x and y the same half-width as z, centred on their current limits
        z_span = (ax.get_zlim()[1] - ax.get_zlim()[0])/2.
        ax.set_ylim([np.mean(ax.get_ylim()) - z_span, np.mean(ax.get_ylim()) + z_span])
        ax.set_xlim([np.mean(ax.get_xlim()) - z_span, np.mean(ax.get_xlim()) + z_span])
        plt.show()

In [None]:
# Number of snapshots that were loaded from the trajectory file
print(len(trajectory))

In [None]:
# Build the analyser (default cutoff multiplier) and classify every oxygen in every snapshot.
chain_finder = HOchainFinder(trajectory)
# config_dict: formula -> integer code, counts: per-snapshot count of each code,
# config_codes: code of each oxygen in each snapshot
config_dict, counts, config_codes = chain_finder.all_configs()

In [None]:
# Formula strings ordered by their integer code, i.e. by column of `counts`
configs = [
    formula
    for formula, _ in sorted(config_dict.items(), key=lambda item: item[1])
]

# Sum over all snapshots to get one total per configuration
total_counts = counts.sum(axis=0)
print("In total found: ")
for formula, total in zip(configs, total_counts):
    print(formula + ': ', total)

In [None]:
# Special sites per snapshot, using the default picker (_is_ti_bonded_and_changed).
# Entry k of the list belongs to snapshot k + 1.
special_list = chain_finder.find_special_configs()

In [None]:
# Tally the chemical formula of every special site that was found.
# pandas is imported in the notebook's import cell.
symbols = []
# special_list[k] belongs to snapshot k + 1, hence start=1
for ii_snapshot, specials in enumerate(special_list, start=1):
    snapshot = chain_finder.trajectory[ii_snapshot]
    symbols.extend(
        snapshot[special_config].get_chemical_formula(mode='hill')
        for special_config in specials
    )

count = pd.Series(symbols).value_counts()
print(count)

In [None]:
class hoppingTest(HOchainFinder):
    """Follow hydrogen hops that start at a special oxygen site.

    The instance re-uses the state of an existing ``HOchainFinder``: the
    trajectory and both neighbour-list objects are shared (not copied), so
    calls on this object update the neighbour lists of ``chain_finder`` too.

    Attributes
    ----------
    counter : int
        Diagnostic count, incremented in ``find_hopping`` when a site that
        just turned from "OTi" into "HOTi" is still "HOTi" one snapshot later.
    """

    def __init__(self, chain_finder):
        # standard_config and cut_offs are copied as well so that the inherited
        # all_configs() and plot_special_config() work on this object.
        self.standard_config = chain_finder.standard_config
        self.ti_regex = chain_finder.ti_regex
        self.cutoff_mult = chain_finder.cutoff_mult
        self.cut_offs = chain_finder.cut_offs
        self.trajectory = chain_finder.trajectory

        self.current_neighbourlist = chain_finder.current_neighbourlist
        self.previous_neighbourlist = chain_finder.previous_neighbourlist

        self.counter = 0

    def _h30_processor(self, config, n_list):
        '''Fix configuration if neighbouring H is accidentally too close to O
        Alternative: Remove H from config that has two Os as neighbours

        config must be sorted according to distance to O

        If the atoms in ``config`` form "H3O", the first H that has more than
        one neighbour in ``n_list`` is removed from ``config``. Symbols are
        read from snapshot 0. Returns the (possibly shortened) index array.
        Currently not called: the calls in find_hopping are commented out.
        '''
        atom = self.trajectory[0][config]
        if atom.get_chemical_formula(mode='hill') == 'H3O':
            symbols = atom.get_chemical_symbols()
            for ii_elem in range(len(symbols)):
                # First evaluate symbol H, then if it has more than one (Oxygen) neighbour, remove
                if symbols[ii_elem] == 'H' and len(n_list.get_neighbors(config[ii_elem])[0]) > 1:
                    config = np.delete(config, ii_elem)
                    break
        return config

    def find_hopping(self, oxygen_index, cur_step, counter, prev_step=None, verbose=False):
        """Trace where the H that left/joined ``oxygen_index`` ended up.

        Compares the neighbourhood of ``oxygen_index`` at ``cur_step`` with the
        one at ``prev_step`` (default ``cur_step - 1``) and follows the atom
        that disappeared from it, recursing along the chain of oxygens.
        Both shared neighbour lists are updated as a side effect.

        Parameters
        ----------
        oxygen_index : int
            Oxygen whose neighbourhood change is analysed.
        cur_step, prev_step : int
            Snapshot indices that are compared.
        counter : int
            Number of hops followed so far; pass 0 for the initial call.
        verbose : bool
            Print intermediate information.

        Returns
        -------
        int
            ``counter`` (>= 0) when the oxygen now holding the atom has
            "HOTi" in its formula, otherwise a negative outcome code:

            * -1: first call, site contains "OTi" and no "H"
            * -2: first call, site is "HOTi" now and was H-free "OTi" before
            * -3: first call and the departed atom had more than one
              neighbour before, or the site is no longer special one
              snapshot later
            * -4: departed atom's current neighbours are neither "O" nor "O2"
            * -5: more than one atom left the neighbourhood
            * -6: nothing left and nothing was added
            * -7: site is "H2O" now and was "HOTi" before
            * -8: nothing left, but something was added
            * -9: chain is 100 hops long and still continues

        Notes
        -----
        NOTE(review): the recursive calls with ``cur_step+1`` and the final
        look-ahead are not bounds-checked and can index past the end of the
        trajectory near the last snapshot; which code to return in that case
        still has to be decided.
        """
        if prev_step is None:
            prev_step = cur_step-1

        self.current_neighbourlist.update(self.trajectory[cur_step])
        self.previous_neighbourlist.update(self.trajectory[prev_step])

        # Index arrays: neighbours first, the oxygen itself as last element
        special_config = self.current_neighbourlist.get_neighbors(oxygen_index)[0]
        special_config = np.append(special_config, oxygen_index)
        # special_config = self._h30_processor(special_config, self.current_neighbourlist)
        prev_neighbours = self.previous_neighbourlist.get_neighbors(oxygen_index)[0]
        prev_config = np.append(prev_neighbours, oxygen_index)
        # prev_config = self._h30_processor(special_config, self.previous_neighbourlist)

        which_are_gone = np.setdiff1d(prev_config, special_config)
        which_are_added = np.setdiff1d(special_config, prev_config)

        cur_form = self.trajectory[cur_step][special_config].get_chemical_formula(mode='hill')
        # Previous configuration is read from the previous snapshot
        prev_form = self.trajectory[prev_step][prev_config].get_chemical_formula(mode='hill')

        # Cases:
        # HOTiX which HO have bonded to: Either H2OTi and H has left or H20 and H has left
        # or whole H20 is double bonded
        # OTiX which has been left by H: Problem for later
        #
        # if in loop H20 or H20Ti which got an additional H and hopfully lost one H

        # First check if it's a OTiX where an H has left
        if counter == 0:
            if (('OTi' in cur_form) and not ('H' in cur_form)):
                # Simply ignore OTiX
                return -1
            if 'HOTi' in cur_form and (('OTi' in prev_form) and not ('H' in prev_form)):
                # Site just received an H. Look one snapshot ahead (if there
                # is one) and count how often the H is still there.
                if cur_step + 1 < len(self.trajectory):
                    self.current_neighbourlist.update(self.trajectory[cur_step+1])
                    neighs = self.current_neighbourlist.get_neighbors(oxygen_index)[0]
                    atom = self.trajectory[cur_step][np.append(neighs, oxygen_index)]

                    future_form = atom.get_chemical_formula(mode='hill')
                    if 'HOTi' in future_form:
                        self.counter += 1

                    # Restore the neighbour list to the current snapshot
                    self.current_neighbourlist.update(self.trajectory[cur_step])
                return -2

        if ("H2O" in cur_form) and ('HOTi' in prev_form):
            # Special case where path is H20 -> HOTi + H but on other site HOTi + H -> H20
            return -7

        if verbose: # len(which_are_gone):
            print(self.trajectory[0][which_are_gone].get_chemical_formula())


        # Now it should be HOTiX
        if len(which_are_gone) != 1:
            # Optimally this should only be an H atom
            if len(which_are_gone) != 0:
                return -5 # TODO: If not, find H in gone atoms, surface hopping
            # Nothing left, so H atom must have been added
            which_are_gone = which_are_added
            if len(which_are_gone) == 0:
                # For some reason: no change
                return -6 # fixed
            else:
                return -8 # fixed

        # From here on exactly one atom has left the neighbourhood
        prev_bound_oxygen = self.previous_neighbourlist.get_neighbors(which_are_gone[0])[0]
        if len(prev_bound_oxygen) > 1 and (counter == 0):
            # Flyby of H20 at TiO site
            return -3

        if verbose:
            print("####################")
            print("Steps")
            print(cur_step)
            print(prev_step)
            print("Configs")
            print(special_config, prev_config)
            print(self.trajectory[0][special_config].get_chemical_symbols(), self.trajectory[0][prev_config].get_chemical_symbols())
            print(which_are_gone)
            print(which_are_added)
            print("Forms")
            print(cur_form, prev_form)
            print("Gone: ", which_are_gone)
            print("Oxygen")

        # Find O that H is bond to now
        cur_gone_bound_oxygen = self.current_neighbourlist.get_neighbors(which_are_gone[0])[0]
        cur_bound = self.trajectory[cur_step][cur_gone_bound_oxygen].get_chemical_formula(mode='hill')
        if cur_bound != 'O':
            # H20 close to TiO, sharing bond
            if not (cur_bound == 'O2'):
                # This should never happen
                return -4 # fixed
            # Shared between two oxygens: retry one snapshot later
            # (same counter, same prev_step)
            return self.find_hopping(oxygen_index, cur_step+1, counter, prev_step, verbose=verbose)

        # Find all neighbours of that oxygen
        cur_oxy_bond_inds = np.append(
            self.current_neighbourlist.get_neighbors(cur_gone_bound_oxygen[0])[0],
            cur_gone_bound_oxygen[0]
        )

        if verbose:
            print(self.previous_neighbourlist.get_neighbors(which_are_gone[0])[0])
            print(cur_gone_bound_oxygen)
            print(cur_bound)

        # cur_oxy_bond_inds = self._h30_processor(cur_oxy_bond_inds, self.current_neighbourlist)
        cur_oxy_bond = self.trajectory[cur_step][cur_oxy_bond_inds]
        cur_oxy_symbols = cur_oxy_bond.get_chemical_formula(mode='hill')
        # print(cur_oxy_symbols)
        if 'HOTi' in cur_oxy_symbols:
            # Chain ends at a Ti-bonded OH: report its length
            return counter
        elif 'H2O' in cur_oxy_symbols:
            if counter < 100:
                # Continue the chain from the oxygen that received the atom
                return self.find_hopping(cur_gone_bound_oxygen[0], cur_step, counter=counter+1, prev_step=prev_step, verbose=verbose)
            else:
                print(-9, cur_step, prev_step)
                return -9 # fixed
        else:
            # Probably H30, see how it develops over steps
            self.current_neighbourlist.update(self.trajectory[cur_step+1])
            if not self._is_ti_bonded_and_changed(cur_step+1, oxygen_index):
                # Isn't special site after all
                return -3
            return self.find_hopping(oxygen_index, cur_step+1, counter=counter+1, prev_step=prev_step, verbose=verbose)
        
hopping_test = hoppingTest(chain_finder)

# Snapshot used when debugging a single step: to analyse only that one, loop over
# zip([run_special + 1], [special_list[run_special]]) instead of the enumerate below.
run_special = 3466

# Outcome of find_hopping for every special site of every snapshot
hop_list = []
# special_list[k] belongs to snapshot k + 1, hence start=1
for ii_step, specials in enumerate(special_list, start=1):
    if ii_step % 100 == 0:
        print(ii_step)

    for special_config in specials:
        # The last entry of a special configuration is the oxygen index
        hops = hopping_test.find_hopping(special_config[-1], ii_step, counter=0, verbose=False)
        hop_list.append(hops)

print(hopping_test.counter)

In [None]:
%matplotlib inline

hops = np.asarray(hop_list)
print(np.min(hop_list), np.max(hop_list))
hop_counts = np.bincount(hop_list-np.min(hop_list))
print(hop_counts)

#-8, -7, -6, -5 ... -1
labels = ["Exchange of H?", "H ends at\nHOTi->H2O Ti", "H ends at\nHOTi->H2O Ti", "No change at site", "Surface Hopping", "Wrong Oxygen Observed", "H20 close to HOTi", "End of Chain", "OTi losing H"]
labels += np.arange(0, np.max(hop_list)+1).tolist()


fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.set_title("Analysis of Ti hopping for H2O Splitting")
ax.set_xlabel('Error Codes \ Number of Hops')
ax.set_ylabel('# of occurences')

ax.bar(np.arange(np.min(hop_list), np.max(hop_list)+1), hop_counts)

ax.set_xticks(
    np.arange(np.min(hop_list), np.max(hop_list)+1),
    labels[-(np.max(hop_list) + 1 - np.min(hop_list)):],
    rotation=-45, ha="left"
    )
plt.show()

In [None]:
# List every configuration formula with its integer code
# (the codes are what the hard-coded selections in the heatmap cell refer to)
for key, value in config_dict.items():
    print(key, value)

In [None]:
# Configuration codes of all oxygens, without the first snapshot
draw_codes = config_codes[1:, :].copy()

# Codes counted as "bound in water" / "bound in bulk"; the numbers refer to
# the codes printed from config_dict.
water_bound = np.array([0, 3, 4], dtype=np.int8)
bulk_bound = np.array([1, 2], dtype=np.int8)

# Element-wise membership masks, shape (n_snapshots - 1, n_oxygens)
wb_tf = np.isin(draw_codes, water_bound)
bulk_tf = np.isin(draw_codes, bulk_bound)

# Filter out oxygen that never changes state
has_changed = (draw_codes != draw_codes[0, :]).any(axis=0)
always_waterbound = wb_tf.all(axis=0)
always_bulk = bulk_tf.all(axis=0)
always_relevant = ~(always_waterbound | always_bulk)

particle_selection = has_changed & always_relevant

draw_codes = draw_codes[:, particle_selection]

# One colour per configuration code; rows are oxygens, columns are snapshots
code_values = list(config_dict.values())
norm = mpl.colors.Normalize(vmin=np.min(code_values), vmax=np.max(code_values))
fig, ax = plt.subplots(1, 1, figsize=(25, 6))
im = ax.imshow(
    draw_codes.T,
    aspect='auto',
    norm=norm,
    cmap="tab10",
    interpolation='nearest',
)
ax.set_xlim([0, draw_codes.shape[0]-1])
cb = fig.colorbar(im)
cb.ax.set_yticks(np.arange(len(config_dict)), labels=list(config_dict))
plt.show()

In [None]:
%matplotlib auto

# Hand-picked debugging case: hard-coded snapshot and atom indices.
cur_step = 774
oxygen_index = 416
cur_gone_bound_oxygen = [65, 130, 64, 60, 70, 71]
neighs = [60, 63]

# oxygen_index is only printed here; the atom highlighted in the plot is
# cur_gone_bound_oxygen[0] together with the atoms listed in neighs.
print(hopping_test.trajectory[cur_step][oxygen_index])
HOchainFinder.plot_snapshot(hopping_test.trajectory[cur_step], cur_gone_bound_oxygen[0], neighs)

In [None]:
%matplotlib auto

# Snapshot to visualise
plot_num = 1
# NOTE(review): n_specials is not defined in any cell visible in this notebook,
# so this line raises NameError on a fresh kernel - confirm where it should come from.
print(n_specials[plot_num])
# Cutoffs are recomputed for the plotted snapshot with the same multiplier (0.75)
# that HOchainFinder uses by default.
chain_finder.plot_special_config(trajectory[plot_num], natural_cutoffs(trajectory[plot_num], mult=0.75))