# reference
https://docs.e3nn.org/en/latest/guide/periodic_boundary_conditions.html#example-crystal-structures

In [1]:
import torch
import e3nn
import ase
import ase.neighborlist
import torch_geometric
import torch_geometric.data

default_dtype = torch.float64
torch.set_default_dtype(default_dtype)

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
# A lattice is a 3 x 3 matrix
# The first index is the lattice vector (a, b, c)
# The second index is a Cartesian index over (x, y, z)

# Polonium with Simple Cubic Lattice
po_lattice = torch.eye(3) * 3.340  # Cubic lattice with edges of length 3.34 AA
po_coords = torch.tensor([[0., 0., 0.,]])
po_types = ['Po']

# Silicon with Diamond Structure
si_lattice = torch.tensor([
    [0.      , 2.734364, 2.734364],
    [2.734364, 0.      , 2.734364],
    [2.734364, 2.734364, 0.      ]
])
si_coords = torch.tensor([
    [1.367182, 1.367182, 1.367182],
    [0.      , 0.      , 0.      ]
])
si_types = ['Si', 'Si']

po = ase.Atoms(symbols=po_types, positions=po_coords, cell=po_lattice, pbc=True)
si = ase.Atoms(symbols=si_types, positions=si_coords, cell=si_lattice, pbc=True)

In [3]:
po.get_positions(), po.cell.array, po.symbols

(array([[0., 0., 0.]]),
 array([[3.34, 0.  , 0.  ],
        [0.  , 3.34, 0.  ],
        [0.  , 0.  , 3.34]]),
 Symbols('Po'))

In [8]:
si.get_positions(), si.cell.array, si.symbols

(array([[1.367182, 1.367182, 1.367182],
        [0.      , 0.      , 0.      ]]),
 array([[0.      , 2.734364, 2.734364],
        [2.734364, 0.      , 2.734364],
        [2.734364, 2.734364, 0.      ]]),
 Symbols('Si2'))

In [6]:
radial_cutoff = 3.5  # Only include edges for neighboring atoms within a radius of 3.5 Angstroms.
type_encoding = {'Po': 0, 'Si': 1}
type_onehot = torch.eye(len(type_encoding))

dataset = []

dummy_energies = torch.randn(2, 1, 1)  # dummy energies for example

for crystal, energy in zip([po, si], dummy_energies):
    # edge_src and edge_dst are the indices of the central and neighboring atom, respectively
    # edge_shift indicates whether the neighbors are in different images / copies of the unit cell
    edge_src, edge_dst, edge_shift = ase.neighborlist.neighbor_list("ijS", a=crystal, cutoff=radial_cutoff, self_interaction=True)
    print(edge_src)
    print(edge_dst)
    print(edge_shift)

    data = torch_geometric.data.Data(
        pos=torch.tensor(crystal.get_positions()),
        lattice=torch.tensor(crystal.cell.array).unsqueeze(0),  # We add a dimension for batching
        x=type_onehot[[type_encoding[atom] for atom in crystal.symbols]],  # Using "dummy" inputs of scalars because they are all C
        edge_index=torch.stack([torch.LongTensor(edge_src), torch.LongTensor(edge_dst)], dim=0),
        edge_shift=torch.tensor(edge_shift, dtype=default_dtype),
        energy=energy  # dummy energy (assumed to be normalized "per atom")
    )

    dataset.append(data)

print(dataset)
len(dataset)

[0 0 0 0 0 0 0]
[0 0 0 0 0 0 0]
[[ 0  0  1]
 [ 0  1  0]
 [ 1  0  0]
 [ 0  0  0]
 [ 0 -1  0]
 [-1  0  0]
 [ 0  0 -1]]
[0 0 0 0 0 1 1 1 1 1]
[1 1 0 1 1 0 1 0 0 0]
[[ 0  1  0]
 [ 1  0  0]
 [ 0  0  0]
 [ 0  0  1]
 [ 0  0  0]
 [ 0 -1  0]
 [ 0  0  0]
 [ 0  0  0]
 [-1  0  0]
 [ 0  0 -1]]
[Data(x=[1, 2], edge_index=[2, 7], pos=[1, 3], lattice=[1, 3, 3], edge_shift=[7, 3], energy=[1, 1]), Data(x=[2, 2], edge_index=[2, 10], pos=[2, 3], lattice=[1, 3, 3], edge_shift=[10, 3], energy=[1, 1])]


2

In [13]:
from ase.lattice import BCC
alphairon = BCC(a = 1)
alphairon.conventional(), alphairon.description() 

(CUB(a=1),
 'BCC(a=1)\n  Variant name: BCC\n  Special point names: GHPN\n  Default path: GHNGPH,PN\n\n  Special point coordinates:\n    G   0.0000  0.0000  0.0000\n    H   0.5000 -0.5000  0.5000\n    P   0.2500  0.2500  0.2500\n    N   0.0000  0.0000  0.5000\n')

In [22]:
from ase.visualize import view
view(si)

<subprocess.Popen at 0x23703f968d0>

In [15]:
alphairon.tocell()

Cell([[-0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [0.5, 0.5, -0.5]])

In [24]:
from ase.lattice import CUB
a = CUB(a=1)
view(a)

TypeError: 'CUB' object is not iterable

In [31]:
fe_lattice = torch.eye(3) * 4
fe_coords = (torch.tensor([
    [0.      , 0.      , 0.      ],
    [0, 0, 1],
    [0, 1, 0],
    [0, 1, 1],
    [1, 0, 0],
    [1, 0, 1],
    [1, 1, 0],
    [1, 1, 1],
    [0.5, 0.5, 0.5]
]) - 0.5) * 4
print(fe_coords)
fe_types = ['Fe'] * 9

fe = ase.Atoms(symbols=fe_types, positions=fe_coords, cell=fe_lattice, pbc=False)
view(fe)

tensor([[-2., -2., -2.],
        [-2., -2.,  2.],
        [-2.,  2., -2.],
        [-2.,  2.,  2.],
        [ 2., -2., -2.],
        [ 2., -2.,  2.],
        [ 2.,  2., -2.],
        [ 2.,  2.,  2.],
        [ 0.,  0.,  0.]])


<subprocess.Popen at 0x23703f96860>