Skip to content

Commit

Permalink
bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
asaboor-gh committed Aug 3, 2023
1 parent 6badd26 commit e966290
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 39 deletions.
71 changes: 36 additions & 35 deletions ipyvasp/_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,16 +1005,22 @@ def splot_bz(
if is3d:
XYZ, UVW = (np.ones_like(s_basis) * z0).T, s_basis.T
quiver3d(
*XYZ, *UVW, C="k", L=0.7, ax=ax, arrowstyle="-|>", mutation_scale=7
*XYZ,
*UVW,
C=color,
L=0.7,
ax=ax,
arrowstyle="-|>",
mutation_scale=7,
)
else:
s_zero = [0 for _ in s_basis] # either 3 or 2.
ax.quiver(
s_zero,
s_zero,
*s_basis[:, idxs[plane]].T,
lw=0.9,
color="navy",
lw=0.7,
color=color,
angles="xy",
scale_units="xy",
scale=1,
Expand Down Expand Up @@ -1415,14 +1421,17 @@ def iplot_bz(


# Cell
def _fix_sites(poscar_data, tol=1e-2, eqv_sites=False, translate=None):
def _fix_sites(
poscar_data, tol=1e-2, eqv_sites=False, translate=None, origin=(0, 0, 0)
):
"""Add equivalent sites to make a full data shape of lattice. Returns same data after fixing.
It should not be exposed mostly be used in visualizations"""
if not isinstance(origin, (tuple, list, np.ndarray)) or len(origin) != 3:
raise ValueError("origin must be a list, tuple or numpy array of length 3.")

pos = (
poscar_data.positions.copy()
) # We can also do poscar_data.copy().positions that copies all contents.
if hasattr(poscar_data.metadata, "origin"):
pos = pos + poscar_data.metadata.origin # Move towards origin of basis

labels = np.array(poscar_data.labels) # We need to store equivalent labels as well
out_dict = poscar_data.to_dict() # For output
Expand Down Expand Up @@ -1490,8 +1499,7 @@ def _fix_sites(poscar_data, tol=1e-2, eqv_sites=False, translate=None):
start += len(new_dict[k]["pos"])

out_dict["positions"] = np.vstack([new_dict[k]["pos"] for k in new_dict.keys()])
if hasattr(poscar_data.metadata, "origin"):
out_dict["positions"] -= poscar_data.metadata.origin # origin given by user
out_dict["positions"] -= origin # origin given by user to subtract

out_dict["metadata"]["eqv_labels"] = np.hstack(
[new_dict[k]["lab"] for k in new_dict.keys()]
Expand Down Expand Up @@ -1581,6 +1589,7 @@ def iplot_lattice(
bond_tol=1e-3,
eqv_sites=True,
translate=None,
origin=(0, 0, 0),
fig=None,
ortho3d=True,
mask_sites=None,
Expand Down Expand Up @@ -1616,7 +1625,7 @@ def iplot_lattice(
kwargs are passed to `iplot_bz`.
"""
poscar_data = _fix_sites(
poscar_data, tol=tol, eqv_sites=eqv_sites, translate=translate
poscar_data, tol=tol, eqv_sites=eqv_sites, translate=translate, origin=origin
)
bond_length = _get_bond_length(poscar_data, given=bond_length)

Expand Down Expand Up @@ -1780,6 +1789,7 @@ def splot_lattice(
bond_tol=1e-3,
eqv_sites=True,
translate=None,
origin=(0, 0, 0),
ax=None,
mask_sites=None,
showlegend=True,
Expand Down Expand Up @@ -1836,7 +1846,7 @@ def splot_lattice(
ix, iy = arr[ind], arr[ind + 1]

poscar_data = _fix_sites(
poscar_data, tol=tol, eqv_sites=eqv_sites, translate=translate
poscar_data, tol=tol, eqv_sites=eqv_sites, translate=translate, origin=origin
)
bond_length = _get_bond_length(poscar_data, given=bond_length)

Expand Down Expand Up @@ -2188,17 +2198,8 @@ def set_origin(poscar_data, origin):
>>> ax = poscar.splot_cell() # plot original cell
>>> poscar_shifted = poscar.scale((3,3,3)).set_origin((1/3,1/3,1/3))
>>> poscar_shifted.splot_lattice(ax=ax, plot_cell=False) # displays sites around original cell
.. warning::
The shifted origin will be used in all subsequent operations such as joining, rotating, writing etc.
You must know what you are doing after setting origin. It is recommended to use this function only for visualization purposes.
"""
if not isinstance(origin, (tuple, list, np.ndarray)) or len(origin) != 3:
raise ValueError("origin must be a list, tuple or numpy array of length 3.")
new_poscar = poscar_data.to_dict()
new_poscar["positions"] = poscar_data.positions - np.array(origin)
new_poscar["metadata"]["origin"] = origin # need this info in fixing sites
return serializer.PoscarData(new_poscar)
return _fix_sites(poscar_data, eqv_sites=False, origin=origin)


def set_zdir(poscar_data, hkl, phi=0):
Expand Down Expand Up @@ -2581,32 +2582,33 @@ def deform_poscar(poscar_data, deformation):
.. note::
This function can change underlying crystal structure if cell shape changes, to just change cell shape, use `transform` function instead.
"""
poscar_dict = poscar_data.to_dict() # make a copy

if callable(deformation):
try:
dmatrix = deformation(*poscar_data.basis)
poscar_dict["basis"] = deformation(*poscar_data.basis)
except:
raise ValueError(
"`deformation` function must be a function(a,b,c) -> 3x3 matrix to multiply with basis."
)
else:
dmatrix = deformation

if not isinstance(dmatrix, np.ndarray):
dmatrix = np.array(dmatrix)
if not isinstance(dmatrix, np.ndarray):
dmatrix = np.array(dmatrix)

if dmatrix.shape != (3, 3):
raise ValueError(
"`deformation` must be a 3x3 matrix or a function(a,b,c) -> 3x3 matrix to multiply with basis."
)
if dmatrix.shape != (3, 3):
raise ValueError(
"`deformation` must be a 3x3 matrix or a function(a,b,c) -> 3x3 matrix to multiply with basis."
)

poscar_data = poscar_data.to_dict() #
poscar_data["basis"] = (
poscar_data["basis"] * dmatrix
) # Update basis by elemetwise multiplication
poscar_data["metadata"][
# Update basis by elemetwise multiplication
poscar_dict["basis"] = poscar_data.basis * dmatrix

poscar_dict["metadata"][
"comment"
] = f'{poscar_data["metadata"]["comment"]} + Deformed POSCAR' # Update comment
return serializer.PoscarData(poscar_data) # Return new POSCAR
] = f'{poscar_data["metadata"]["comment"]} + Deformed POSCAR'
return serializer.PoscarData(poscar_dict) # Return new POSCAR


def view_poscar(poscar_data, **kwargs):
Expand All @@ -2617,4 +2619,3 @@ def view(elev=30, azim=30, roll=0):
ax.view_init(elev=elev, azim=azim, roll=roll)

return interactive(view, elev=(0, 180), azim=(0, 360), roll=(0, 360))

16 changes: 12 additions & 4 deletions ipyvasp/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def ngl_viewer(
height="400px",
plot_vectors=True,
dashboard=False,
origin=(0, 0, 0),
):
"""Display structure in Jupyter notebook using nglview.
Expand Down Expand Up @@ -132,7 +133,9 @@ def ngl_viewer(
raise ImportError("Please install nglview to use this function.")

if plot_cell: # Only show equivalent sites if plotting cell
poscar = POSCAR(data=plat._fix_sites(poscar.data, eqv_sites=True))
poscar = POSCAR(
data=plat._fix_sites(poscar.data, eqv_sites=True, origin=origin)
)

_types = list(poscar.data.types.keys())
_sizes = [0.5 for _ in _types]
Expand Down Expand Up @@ -290,14 +293,19 @@ def view_kpath(self):

return KpathWidget(path=str(self.path.parent), glob=self.path.name)

@_sub_doc(plat.iplot_lattice)
@_sig_kwargs(plat.iplot_lattice, ("poscar_data",))
def view_widegt(self, **kwargs):
self.__class__._update_kws = kwargs # attach to class, not self
return iplot2widget(self.iplot_lattice(**kwargs))

def update_widget(self, handle):
"Update widget (if shown in notebook) after some operation on POSCAR with `handle` returned by `view_widget`"
iplot2widget(self.iplot_lattice(**self._update_kws), fig_widget=handle)
@_sig_kwargs(plat.iplot_lattice, ("poscar_data",))
def update_widget(self, handle, **kwargs):
"""Update widget (if shown in notebook) after some operation on POSCAR with `handle` returned by `view_widget`
kwargs are passed to `self.iplot_lattice` method.
"""
kwargs = {**self.__class__._update_kws, **kwargs}
iplot2widget(self.iplot_lattice(**kwargs), fig_widget=handle)

@classmethod
def from_file(cls, path):
Expand Down

0 comments on commit e966290

Please sign in to comment.