In [None]:
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt

In [None]:
from mapsy.data import Grid

In [None]:
from ase.cell import Cell

## Minimal Cell

In [None]:
# ASE reference cell built from the 3x3 identity matrix (unit cube); used as a
# point of comparison for the mapsy Grid objects below
minimal_cell = Cell(np.eye(3))

In [None]:
# reciprocal lattice of the ASE reference cell
minimal_cell.reciprocal()

In [None]:
# volume of the ASE reference cell
minimal_cell.volume

In [None]:
minimal_grid = Grid(cell=np.eye(3), scalars = [2, 2, 2])

In [None]:
print(minimal_grid.volume,minimal_grid.ndata)

In [None]:
minimal_grid.coordinates.reshape(3,2*2*2).T

In [None]:
minimal_grid.corners

In [None]:
minimal_grid.coordinates.reshape(3,2*2*2).T

In [None]:
minimal_grid.coordinates[:,0,0,0]

In [None]:
minimal_grid.coordinates[:,0,0,1]

In [None]:
grid = minimal_grid
#
def plot_boundaries_xy(grid):
    """Draw the xy projection of the cell spanned by the first two lattice
    vectors as a dotted blue parallelogram on the current matplotlib axes."""
    a = grid.cell[0,:2]
    b = grid.cell[1,:2]
    corner_o = np.zeros(2)
    corner_ab = grid.cell[0,:2] + grid.cell[1,:2]
    # the four edges, in the order: O-a, O-b, a-(a+b), b-(a+b)
    edges = ((corner_o, a), (corner_o, b), (a, corner_ab), (b, corner_ab))
    for start, end in edges:
        plt.plot([start[0], end[0]], [start[1], end[1]], ':', color='tab:blue')

def plot_gridpoints_xy(grid):
    """Scatter the k = 0 layer of grid points in the xy plane together with
    their periodic images in the eight neighbouring in-plane cells.

    Parameters
    ----------
    grid : Grid
        Object exposing ``cell`` (lattice vectors as rows) and
        ``coordinates`` indexed as [component, i, j, k].
    """
    v1 = grid.cell[0,:2]
    v2 = grid.cell[1,:2]
    x = grid.coordinates[0,:,:,0]
    y = grid.coordinates[1,:,:,0]
    # gridpoints of the reference cell
    plt.scatter(x, y, color='tab:red')
    # faded periodic images of the gridpoints: all eight in-plane neighbour shifts
    for shift in (v1, v2, v1 + v2, -v1, -v2, -(v1 + v2), v2 - v1, v1 - v2):
        plt.scatter(x + shift[0], y + shift[1], color='tab:red', alpha=0.2)

def plot_corners_xy(grid,scale):
    """Draw orange arrows from (0, 0) to every cell corner lying in the z = 0 plane.

    Parameters
    ----------
    grid : Grid
        Object exposing ``corners`` as an (n, 3) array of corner vectors.
    scale : float
        Passed through to ``plt.quiver`` as its ``scale`` argument.
    """
    corners = np.asarray(grid.corners)
    # in-plane corners, skipping the first one (presumably the zero corner, which
    # the analysis loops also skip via corners[1:])
    in_plane = corners[np.isclose(corners[:, 2], 0.0)][1:, :2]
    # one arrow tail per corner, instead of a hard-coded count of three
    tails = np.zeros(len(in_plane))
    plt.quiver(tails, tails, in_plane[:, 0], in_plane[:, 1], color='tab:orange', scale=scale)

def plot_origin_xy(grid,origin):
    """Mark a reference point in the xy plane and its periodic images in the
    eight neighbouring in-plane cells.

    Parameters
    ----------
    grid : Grid
        Object exposing ``cell`` with the lattice vectors as rows.
    origin : array-like
        Reference point; only its first two components are used.
    """
    v1 = grid.cell[0,:2]
    v2 = grid.cell[1,:2]
    plt.scatter(origin[0], origin[1], color='tab:blue')
    # faded periodic images: all eight in-plane neighbour shifts
    for shift in (v1, v2, v1 + v2, -v1, -v2, -(v1 + v2), v2 - v1, v1 - v2):
        plt.scatter(origin[0] + shift[0], origin[1] + shift[1], color='tab:blue', alpha=0.2)


def plot_minimal_cell_xy(grid,origin=None,plot_corners=False):
    """Plot the xy projection of ``grid`` with axis limits sized for the unit cell.

    Draws the cell boundaries, the k = 0 grid points and their periodic images,
    and optionally the in-plane corner vectors and a reference point.

    Parameters
    ----------
    grid : Grid
        Grid to draw.
    origin : array-like, optional
        Reference point to mark together with its periodic images.
        ``None`` (default) draws no point.
    plot_corners : bool, optional
        If True, also draw arrows from (0, 0) to the in-plane cell corners.
    """
    _, ax = plt.subplots()
    ax.set_aspect('equal', 'box')
    ax.set_xlim(-1.1,1.6)
    ax.set_ylim(-1.1,1.6)
    plot_boundaries_xy(grid)
    plot_gridpoints_xy(grid)
    if plot_corners:
        plot_corners_xy(grid,2.7)
    # test against None rather than origin.any(), so that a reference point
    # located exactly at (0, 0, 0) is still drawn
    if origin is not None:
        plot_origin_xy(grid,origin)

plot_minimal_cell_xy(grid,plot_corners=True)
plt.show()


In [None]:
origin = np.array([0.9, 0.2, 0.0])

In [None]:
r = grid.coordinates - origin[:, np.newaxis, np.newaxis, np.newaxis]
print(r.reshape(3,8).T)

In [None]:
plot_minimal_cell_xy(grid,origin)
# arrow tails: the k = 0 layer of grid points, projected onto the xy plane
origin_xy = grid.coordinates[:,:,:,0].reshape(3,-1).T[:,:2]
# arrows point from each grid point back to the origin (hence the minus sign);
# r0_xy keeps these unwrapped vectors for comparison in the later plots
r0_xy = -r[:,:,:,0].reshape(3,-1).T[:,:2]
plt.quiver(origin_xy[:,0],origin_xy[:,1], r0_xy[:,0], r0_xy[:,1], color='tab:green', scale=2.8)
plt.show()

In [None]:
reciprocal_lattice = grid.reciprocal()
print(reciprocal_lattice)

In [None]:
s = np.einsum('lijk,ml->mijk', r, reciprocal_lattice)
s -= np.floor(s)
r = np.einsum('lm,lijk->mijk', grid.cell, s)

In [None]:
plot_minimal_cell_xy(grid,origin,plot_corners=True)
# arrow tails: the k = 0 layer of grid points in the xy plane
origin_xy = grid.coordinates[:,:,:,0].reshape(3,-1).T[:,:2]
# wrapped vectors (solid) versus the raw ones from before wrapping (faded)
r_xy = -r[:,:,:,0].reshape(3,-1).T[:,:2]
plt.quiver(origin_xy[:,0],origin_xy[:,1], r_xy[:,0], r_xy[:,1], color='tab:green', scale=2.8)
plt.quiver(origin_xy[:,0],origin_xy[:,1], r0_xy[:,0], r0_xy[:,1], color='tab:green', scale=2.8, alpha=0.3)
plt.show()

In [None]:
# Among the wrapped vector r and its images shifted by each cell corner, keep the
# shortest one per grid point. The loop uses its own temporaries and leaves `r`
# untouched, so re-running this cell gives the same result.
rmin = r
r2min = np.einsum('i...,i...', r, r)
for corner in grid.corners[1:]:
    r_image = r + corner[:,np.newaxis,np.newaxis,np.newaxis]
    r2_image = np.einsum('i...,i...', r_image, r_image)
    mask = r2_image < r2min
    rmin = np.where(mask[np.newaxis, :, :, :], r_image, rmin)
    r2min = np.where(mask, r2_image, r2min)


In [None]:
plot_minimal_cell_xy(grid,origin)
# arrow tails: the k = 0 layer of grid points in the xy plane
origin_xy = grid.coordinates[:,:,:,0].reshape(3,-1).T[:,:2]
# minimum-image vectors (solid) versus the raw unwrapped ones (faded)
r_xy = -rmin[:,:,:,0].reshape(3,-1).T[:,:2]
plt.quiver(origin_xy[:,0],origin_xy[:,1], r_xy[:,0], r_xy[:,1], color='tab:green', scale=2.8)
plt.quiver(origin_xy[:,0],origin_xy[:,1], r0_xy[:,0], r0_xy[:,1], color='tab:green', scale=2.8, alpha=0.3)
plt.show()

In [None]:
# alternative: round fractional coordinates to the nearest integer (rint) instead of
# floor + corner checks; recomputed from scratch so the cell is self-contained
r = grid.coordinates - origin[:, np.newaxis, np.newaxis, np.newaxis]
s = np.einsum('lijk,ml->mijk', r, reciprocal_lattice)
s -= np.rint(s)
r = np.einsum('lm,lijk->mijk', grid.cell, s)
#
plot_minimal_cell_xy(grid,origin)
# arrow tails: the k = 0 layer of grid points in the xy plane
origin_xy = grid.coordinates[:,:,:,0].reshape(3,-1).T[:,:2]
r_xy = -r[:,:,:,0].reshape(3,-1).T[:,:2]
plt.quiver(origin_xy[:,0],origin_xy[:,1], r_xy[:,0], r_xy[:,1], color='tab:green', scale=2.8)
plt.show()

In [None]:
# dim = 1: distances are measured from a line passing through the origin
dim = 1
# the line is oriented along lattice vector number `axis` (0, 1 or 2)
axis = 1
origin = np.array([0.9, 0.2, 0.0])

In [None]:
r = grid.coordinates - origin[:, np.newaxis, np.newaxis, np.newaxis]
s = np.einsum('lijk,ml->mijk', r, reciprocal_lattice)
s -= np.floor(s)
r = np.einsum('lm,lijk->mijk', grid.cell, s)

# unit vector giving the direction of the line. Lattice vectors need not have unit
# length, and r - (r.n) n only removes the parallel component when |n| = 1.
n = np.asarray(grid.cell[axis,:], dtype=float)
n = n / np.linalg.norm(n)
# removes the component directed along n
r = r - np.einsum('jkl,i->ijkl',np.einsum('ijkl,i->jkl',r,n),n)

# pre-corner-check results
rmin = r
r2min = np.einsum('i...,i...', r, r)

# check the images shifted by each cell corner (projected the same way); keep the shortest
for corner in grid.corners[1:]:
    r_image = r + corner[:,np.newaxis,np.newaxis,np.newaxis]
    r_image = r_image - np.einsum('jkl,i->ijkl',np.einsum('ijkl,i->jkl',r_image,n),n)
    r2_image = np.einsum('i...,i...', r_image, r_image)
    mask = r2_image < r2min
    rmin = np.where(mask[np.newaxis, :, :, :], r_image, rmin)
    r2min = np.where(mask, r2_image, r2min)

In [None]:
plot_minimal_cell_xy(grid,origin)
# arrow tails: the k = 0 layer of grid points in the xy plane
origin_xy = grid.coordinates[:,:,:,0].reshape(3,-1).T[:,:2]
r_xy = -rmin[:,:,:,0].reshape(3,-1).T[:,:2]
plt.quiver(origin_xy[:,0],origin_xy[:,1], r_xy[:,0], r_xy[:,1], color='tab:green', scale=2.8)
# the line along the second lattice vector of this cubic cell is vertical (x = const)
plt.axvline(x=origin[0], color='tab:blue', linestyle=(0, (5,1)))
# its periodic image shifted by minus the first lattice vector
plt.axvline(x=origin[0]-grid.cell[0,0], color='tab:blue', linestyle=(0, (5,1)),alpha=0.3)
plt.show()

## Hexagonal Cell

In [None]:
# hexagonal lattice vectors as rows: a1 = (1, 0, 0), a2 = (1/2, sqrt(3)/2, 0), a3 = (0, 0, 1)
at = np.array([[1.0, 0.0, 0.0],
               [0.5, np.sqrt(3) * 0.5, 0.0],
               [0.0, 0.0, 1.0]])
# two grid points along each lattice vector
nr = np.full(3, 2)
hexagonal_cell = Grid(cell=at, scalars=nr)

In [None]:
# lattice vectors of the hexagonal grid, one per row
print(hexagonal_cell.cell)

In [None]:
# cell volume; should equal |det(at)| = sqrt(3)/2 for the vectors defined above
print(hexagonal_cell.volume)

In [None]:
# reciprocal lattice, used later to convert Cartesian to fractional coordinates
print(hexagonal_cell.reciprocal())

In [None]:
# print each lattice vector (row i of the cell matrix) with its ordinal name
axis_messages = (
    "The first axis vector is {}",
    "The second axis vector is {}",
    "The third axis vector is {}",
)
for axis_index, axis_message in enumerate(axis_messages):
    print(axis_message.format(hexagonal_cell.cell[axis_index, :]))

In [None]:
grid = hexagonal_cell
#
def plot_hexagonal_cell_xy(grid,origin=None,plot_corners=False):
    """Plot the xy projection of ``grid`` with axis limits sized for the hexagonal cell.

    Draws the cell boundaries, the k = 0 grid points and their periodic images,
    and optionally the in-plane corner vectors and a reference point.

    Parameters
    ----------
    grid : Grid
        Grid to draw.
    origin : array-like, optional
        Reference point to mark together with its periodic images.
        ``None`` (default) draws no point.
    plot_corners : bool, optional
        If True, also draw arrows from (0, 0) to the in-plane cell corners.
    """
    _, ax = plt.subplots()
    ax.set_aspect('equal', 'box')
    ax.set_xlim(-1.7,2.6)
    ax.set_ylim(-1.1,1.6)
    # cell boundaries (previously drawn twice)
    plot_boundaries_xy(grid)
    # gridpoints and their periodic images (previously re-scattered a second time at the end)
    plot_gridpoints_xy(grid)
    if plot_corners:
        plot_corners_xy(grid,4.3)
    # test against None rather than origin.any(), so that a reference point
    # located exactly at (0, 0, 0) is still drawn
    if origin is not None:
        plot_origin_xy(grid,origin)

plot_hexagonal_cell_xy(grid,plot_corners=True)
plt.show()

In [None]:
hexagonal_cell.coordinates.reshape(3,2*2*2).T

In [None]:
mesh: npt.NDArray[np.float64] = np.mgrid[0 : 2, 0 : 2, 0 : 2]

In [None]:
mesh.reshape(3,8).T

In [None]:
basis = hexagonal_cell.cell / np.array([2,2,2])

In [None]:
basis

In [None]:
np.einsum("ij,jklm->iklm", basis.T, mesh).reshape(3,8).T

In [None]:
origin = np.array([1.1,0.1,0.])

In [None]:
r = grid.coordinates - origin[:, np.newaxis, np.newaxis, np.newaxis]
reciprocal_lattice = grid.reciprocal() 
s = np.einsum('lijk,ml->mijk', r, reciprocal_lattice)
s -= np.rint(s)
r = np.einsum('lm,lijk->mijk', grid.cell, s)
#
plot_hexagonal_cell_xy(grid,origin)
#
origin_xy = grid.coordinates[:,:,:,0].reshape(3,4).T[:,:2]
r_xy = -r[:,:,:,0].reshape(3,4).T[:,:2]
plt.quiver(origin_xy[:,0],origin_xy[:,1], r_xy[:,0], r_xy[:,1], color='tab:green', scale=4.5)
plt.show()

In [None]:
dr,dr2 = grid.get_min_distance(origin)
#
plot_hexagonal_cell_xy(grid,origin)
#
origin_xy = grid.coordinates[:,:,:,0].reshape(3,4).T[:,:2]
r_xy = -dr[:,:,:,0].reshape(3,4).T[:,:2]
plt.quiver(origin_xy[:,0],origin_xy[:,1], r_xy[:,0], r_xy[:,1], color='tab:green', scale=4.5)
plt.show()

In [None]:
dr.reshape(3,8).T

In [None]:
origin = np.array([0.35,0.2,0.])
# lattice vector the reference line is parallel to (shared by the computation and the plot)
line_axis = 1
dr,dr2 = grid.get_min_distance(origin,dim=1,axis=line_axis)
#
plot_hexagonal_cell_xy(grid,origin)
# arrow tails: the k = 0 layer of grid points in the xy plane
origin_xy = grid.coordinates[:,:,:,0].reshape(3,-1).T[:,:2]
r_xy = -dr[:,:,:,0].reshape(3,-1).T[:,:2]
plt.quiver(origin_xy[:,0],origin_xy[:,1], r_xy[:,0], r_xy[:,1], color='tab:green', scale=4.5)
# draw the line parametrically along the chosen lattice vector, so it follows
# line_axis (a hard-coded slope would break for other axes, e.g. vertical lines)
direction = grid.cell[line_axis, :2]
line_t = np.linspace(-4, 4, 100)
x = origin[0] + line_t * direction[0]
y = origin[1] + line_t * direction[1]
plt.plot(x,y,linestyle=(0, (5,1)))
plt.show()