Skip to content

Commit

Permalink
bond_length as dict beside number
Browse files Browse the repository at this point in the history
  • Loading branch information
asaboor-gh committed Aug 5, 2023
1 parent e966290 commit 796ef4c
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 47 deletions.
121 changes: 79 additions & 42 deletions ipyvasp/_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import plotly.graph_objects as go

import matplotlib.pyplot as plt # For viewpoint
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib.collections import LineCollection
from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection
import matplotlib.colors as mplc

from ipywidgets import interactive
Expand Down Expand Up @@ -1519,34 +1520,57 @@ def translate_poscar(poscar_data, offset):
return _fix_sites(poscar_data, translate=offset, eqv_sites=False)


def get_pairs(basis, positions, r, tol=1e-3):
"""Returns a tuple of Lattice (coords,pairs), so coords[pairs] given nearest site bonds.
def get_pairs(coords, r, tol=1e-3):
"""Returns a tuple of Points(coords,pairs, dist), so coords[pairs] given nearest site bonds.
Parameters
----------
basis : array_like
3x3 array of lattice basis vectors in rows.
positions : array_like
Array(N,3) of fractional positions of lattice sites. If coordinates positions, provide unity basis.
coords : array_like
Array(N,3) of cartesian positions of lattice sites.
r : float
Cartesian distance between the pairs in units of Angstrom e.g. 1.2 -> 1.2E-10.
tol : float
Tolerance value. Default is 10^-3.
"""
coords = to_R3(basis, positions)
if np.ndim(coords) != 2 and np.shape(coords)[1] != 3:
raise ValueError("coords must be a 2D array of shape (N,3).")

tree = KDTree(coords)
inds = np.array([[*p] for p in tree.query_pairs(r, eps=tol)])
return serializer.dict2tuple("Lattice", {"coords": coords, "pairs": inds})
if len(inds) > 0:
dist = np.linalg.norm(coords[inds[:, 0],] - coords[inds[:, 1],], axis=1)
else:
dist = np.array([])
return serializer.dict2tuple(
"Points", {"coords": coords, "pairs": inds, "dist": dist}
)


def _get_bond_length(poscar_data, given=None):
"`given` bond length should be in range [0,1] which is scaled to V^(1/3)."
if given is not None:
return given * poscar_data.volume ** (1 / 3)
def _get_bond_length(poscar_data, bond_length=None):
"Given `bond_length` should be in unit of Angstrom, and can be a number of dict like {1.2:['Fe','O'],...}"
if bond_length is not None:
if isinstance(bond_length, (int, float, np.integer)):
return bond_length
elif isinstance(bond_length, dict):
for k, v in bond_length.items():
if not isinstance(k, (int, float, np.integer)):
raise TypeError(
f"Value of key `{k}` should be a number in unit of Angstrom."
)
if not isinstance(v, (list, tuple, np.ndarray)) or len(v) != 2:
raise TypeError(
f"Value of key `{k}` should be a list of two elements like ['Fe', 'O']."
)

return max(
list(bond_length.keys())
) # return the maximum distance, will filter later
else:
raise TypeError("`bon_length` should be a number or a dict.")
else:
keys = list(poscar_data.types.keys())
if len(keys) == 1:
keys = [*keys, *keys] # strill need it to be a list of two elements
keys = [*keys, *keys] # still need it to be a list of two elements

dists = [poscar_data.get_distance(k1, k2) for k1, k2 in combinations(keys, 2)]
return (
Expand Down Expand Up @@ -1579,14 +1603,29 @@ def _masked_data(poscar_data, mask_sites):
return pick # could be duplicate indices


def _filter_pairs(poscar_data, pairs, dist, bond_length):
"""Filter pairs based on bond_length dict like {1.2:['Fe','O'],...}. Returns same pairs otherwise."""
if isinstance(bond_length, dict):
new_pairs = []
for pair, d in zip(pairs, dist):
t1, t2 = [poscar_data.labels[idx].split()[0] for idx in pair]
for k, v in bond_length.items():
if tuple(v) in [(t1, t2), (t2, t1)] and d <= k:
new_pairs.append(pair)

return np.unique(new_pairs, axis=0) # remove duplicates

# Return all pairs otherwise
return pairs # None -> auto calculate bond_length, number -> use that number


# Cell
def iplot_lattice(
poscar_data,
sizes=10,
colors=None,
bond_length=None,
tol=1e-2,
bond_tol=1e-3,
eqv_sites=True,
translate=None,
origin=(0, 0, 0),
Expand All @@ -1607,7 +1646,7 @@ def iplot_lattice(
colors : tuple
Sequence of colors for each type. Automatically generated if not provided.
bond_length : float
Length of bond in fractional unit [0,1]. It is scaled to V^1/3 and auto calculated if not provides.
Length of bond in Angstrom. Auto calculated if not provides. Can be a dict like {1.2:['Fe','O'],...} to specify bond length between specific types.
mask_sites : callable
Provide a mask function `f(index, x,y,z) -> bool` to show only selected sites.
For example, to show only sites with z > 0.5, use `mask_sites = lambda i, x,y,z: x > 0.5`.
Expand All @@ -1627,21 +1666,20 @@ def iplot_lattice(
poscar_data = _fix_sites(
poscar_data, tol=tol, eqv_sites=eqv_sites, translate=translate, origin=origin
)
bond_length = _get_bond_length(poscar_data, given=bond_length)
blen = _get_bond_length(poscar_data, bond_length)

sites = None
pos = poscar_data.positions
coords = poscar_data.coords
if (
mask_sites is not None
): # not None is important, as it can be False given by user
sites = _masked_data(poscar_data, mask_sites)
pos = poscar_data.positions[sites]
coords = poscar_data.coords[sites]
if not sites:
raise ValueError("No sites found with given mask_sites function.")

coords, pairs = get_pairs(
poscar_data.basis, pos, r=bond_length, tol=bond_tol
) # bond tolernce should be smaller than cell tolernce.
coords, pairs, dist = get_pairs(coords, r=blen)
pairs = _filter_pairs(poscar_data, pairs, dist, bond_length)

if not fig:
fig = go.Figure()
Expand Down Expand Up @@ -1786,7 +1824,6 @@ def splot_lattice(
colors=None,
bond_length=None,
tol=1e-2,
bond_tol=1e-3,
eqv_sites=True,
translate=None,
origin=(0, 0, 0),
Expand All @@ -1810,7 +1847,7 @@ def splot_lattice(
colors : tuple
Sequence of colors for each ion type. If None, automatically generated.
bond_length : float
Length of bond in fractional unit [0,1]. It is scaled to V^1/3 and auto calculated if not provides.
Length of bond in Angstrom. Auto calculated if not provides. Can be a dict like {1.2:['Fe','O'],...} to specify bond length between specific types.
alpha : float
Opacity of points and bonds.
mask_sites : callable
Expand All @@ -1822,8 +1859,7 @@ def splot_lattice(
Keyword arguments to pass to `plt.scatter` for plotting sites.
Default is just hint, you can pass any keyword argument that `plt.scatter` accepts.
bond_kws : dict
Keyword arguments to pass to `plt.plot` for plotting bonds.
Default is just hint, you can pass any keyword argument that `plt.plot` accepts.
Keyword arguments to pass to `LineCollection`/`Line3DCollection` for plotting bonds.
fmt_label : callable
If given, each site label is passed to it like fmt_label('Ga 1').
It must return a string or a list/tuple of length 2 with first item as label and second item as dictionary of keywords to pass to `plt.text`.
Expand All @@ -1848,19 +1884,18 @@ def splot_lattice(
poscar_data = _fix_sites(
poscar_data, tol=tol, eqv_sites=eqv_sites, translate=translate, origin=origin
)
bond_length = _get_bond_length(poscar_data, given=bond_length)
blen = _get_bond_length(poscar_data, bond_length)

sites = None
pos = poscar_data.positions # take all sites
coords = poscar_data.coords # take all sites
if mask_sites is not None: # not None is important, user can give anything
sites = _masked_data(poscar_data, mask_sites)
pos = poscar_data.positions[sites]
coords = poscar_data.coords[sites]
if not sites:
raise ValueError("No sites found with given mask_sites function.")

coords, pairs = get_pairs(
poscar_data.basis, positions=pos, r=bond_length, tol=bond_tol
) # bond tolernce should be smaller than cell tolernce.
coords, pairs, dist = get_pairs(coords, r=blen)
pairs = _filter_pairs(poscar_data, pairs, dist, bond_length)

labels = [poscar_data.labels[i] for i in sites] if sites else poscar_data.labels
if fmt_label is not None:
Expand Down Expand Up @@ -1920,16 +1955,16 @@ def splot_lattice(

bond_kws = {
"alpha": 0.7,
"solid_capstyle": "butt",
"capstyle": "butt",
**bond_kws,
} # bond_kws overrides alpha and solid_capstyle only
if not plane:
_ = [ax.plot(*c.T, c=_c, **bond_kws) for c, _c in zip(coords_n, colors_n)]
elif plane in "xyzxzyx":
_ = [
ax.plot(c[:, ix], c[:, iy], c=_c, **bond_kws)
for c, _c in zip(coords_n, colors_n)
]
} # bond_kws overrides alpha and capstyle only
# 3D LineCollection by default, very fast as compared to plot one by one.
lc = Line3DCollection(coords_n, colors=colors_n, **bond_kws)
if plane in "xyzxzyx":
lc = LineCollection(coords_n[:, :, [ix, iy]], colors=colors_n, **bond_kws)

ax.add_collection(lc)
ax.autoscale_view()

if not plane:
site_kws = {
Expand Down Expand Up @@ -2586,7 +2621,9 @@ def deform_poscar(poscar_data, deformation):

if callable(deformation):
try:
poscar_dict["basis"] = deformation(*poscar_data.basis)
poscar_dict["basis"] = np.array(
deformation(*poscar_data.basis)
) # mostly tuple
except:
raise ValueError(
"`deformation` function must be a function(a,b,c) -> 3x3 matrix to multiply with basis."
Expand Down
2 changes: 1 addition & 1 deletion ipyvasp/core/plot_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ def iplot2widget(fig, fig_widget=None, template=None):

fig_widget.layout = fig.layout

with fig_widget.batch_animate(0): # Disable animation to speed up
with fig_widget.batch_update():
for data in fig.data:
fig_widget.add_trace(data)

Expand Down
10 changes: 6 additions & 4 deletions ipyvasp/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,12 @@ def ngl_viewer(
except ImportError:
raise ImportError("Please install nglview to use this function.")

if plot_cell: # Only show equivalent sites if plotting cell
poscar = POSCAR(
data=plat._fix_sites(poscar.data, eqv_sites=True, origin=origin)
# Only show equivalent sites if plotting cell, only shift origin otherwise
poscar = POSCAR( # don't change instance itself, make new one
data=plat._fix_sites(
poscar.data, eqv_sites=True if plot_cell else False, origin=origin
)
)

_types = list(poscar.data.types.keys())
_sizes = [0.5 for _ in _types]
Expand Down Expand Up @@ -295,7 +297,7 @@ def view_kpath(self):

@_sub_doc(plat.iplot_lattice)
@_sig_kwargs(plat.iplot_lattice, ("poscar_data",))
def view_widegt(self, **kwargs):
def view_widget(self, **kwargs):
self.__class__._update_kws = kwargs # attach to class, not self
return iplot2widget(self.iplot_lattice(**kwargs))

Expand Down

0 comments on commit 796ef4c

Please sign in to comment.