In [None]:
import sys
sys.path.append("../py_src")

from glob import glob
import os

import numpy as np
import matplotlib.pyplot as plt

import sort_neigh

from ase.io import read as ase_read
from ase.neighborlist import natural_cutoffs, NeighborList

In [None]:
# --- Input trajectory --------------------------------------------------------
# Resolve the dump file to an absolute path, then split it into its directory
# and the part of the file name that precedes the first ".".
target_file = "../test_data/220523_cunano_mcswap/mcmd/1000opt.lammpstrj"
t_file = os.path.abspath(target_file)
cur_dir, _t_basename = os.path.split(t_file)
cur_fname = _t_basename.partition(".")[0]

# --- Sorter settings ---------------------------------------------------------
# Keyword arguments forwarded to sort_neigh.NeighbourSort below.
# NOTE(review): these look like SOAP-style descriptor/kernel settings — confirm
# their exact meaning against sort_neigh.
rcut = 9.0
nmax = 12
lmax = 12
sigma = 0.5
gamma_kernel = 0.05
# Passed as `mode=` to the classifier in the neighbour-analysis cell.
mode = 'pre_group'

cur_sorter = sort_neigh.NeighbourSort(
    rcut=rcut,
    nmax=nmax,
    lmax=lmax,
    sigma=sigma,
    gamma_kernel=gamma_kernel,
)

# --- Previously written results ----------------------------------------------
# Both locations sit next to the input file and are derived from its stem.
cur_out_dir = os.path.join(cur_dir, "%s_out/" % cur_fname)
cur_out_file = os.path.join(cur_dir, "%s_count.txt" % cur_fname)

# load_sort_cat returns three values, unpacked here as counts, timesteps and
# categories (see sort_neigh for their exact layout).
(cur_sorted_counts,
 cur_timesteps,
 cur_sorted_cats) = cur_sorter.load_sort_cat(cur_out_file)

In [None]:
# Indices of the rows of `cur_sorted_counts` whose last column is positive.
# np.argwhere on this 1-D mask yields one index row per match, i.e. an array
# of shape (n_matches, 1).
last_column_positive = cur_sorted_counts[:, -1] > 0
where_cond = np.argwhere(last_column_positive)

In [None]:
which_vac = where_cond[0]

ts_path = cur_out_dir+("/ts_%u/cunano_%u.lammpstrj"%(which_vac, which_vac))
site_part = ase_read(ts_path)
site_positions = site_part.get_positions()

%matplotlib auto
plt.ion()

fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(projection='3d')

sc = ax.scatter(
    site_positions[:-15, 0], site_positions[:-15, 1], site_positions[:-15, 2],
    c="tab:grey", alpha=0.07,
    s=400, edgecolors="k"
)
sc = ax.scatter(
    site_positions[-15:, 0], site_positions[-15:, 1], site_positions[-15:, 2], 
    c="tab:blue", alpha=1,
    s=400, edgecolors="k"
)

In [None]:
from ase.visualize import view

# Build a neighbour list over the loaded structure using ASE's per-element
# natural cutoffs scaled by 0.9. Pairs are stored in both directions and each
# atom is also listed as its own neighbour.
cut_off = natural_cutoffs(site_part, mult=0.9)
neighbour_list = NeighborList(cut_off, bothways=True, self_interaction=True)
neighbour_list.update(site_part)

# Classify the environment of each of the last 15 atoms. Any environment the
# classifier reports with more than 12 neighbours is highlighted in red on the
# existing 3D axes and opened in the ASE viewer.
n_atoms = len(site_part)
for index in range(n_atoms - 15, n_atoms):
    neighbour_indices, _offsets = neighbour_list.get_neighbors(index)

    # The classifier gets every listed index except the final one; the scatter
    # further down also skips the first.
    # NOTE(review): presumably the first and last entries are the atom itself
    # (self_interaction=True combined with bothways=True) — confirm against
    # ASE's neighbour ordering.
    neighbour_particle = site_part[neighbour_indices[:-1]]
    n_neigh, class_id = cur_sorter.classifier.classify(
        neighbour_particle, mode=mode, ensure_position=False
    )
    if not n_neigh > 12:
        continue

    shell_xyz = site_positions[neighbour_indices[1:-1]]
    sc = ax.scatter(
        shell_xyz[:, 0], shell_xyz[:, 1], shell_xyz[:, 2],
        c="tab:red", alpha=0.8,
        s=400, edgecolors="g"
    )

    view(neighbour_particle)