In [1]:
import torch
from torch import nn
import numpy as np

In [2]:
# ---------- reciprocal lattice ----------
def reciprocal_vectors(a):
    """Return the three shortest reciprocal vectors of a triangular moiré lattice.

    All three share the length |G| = 4π / (√3 a). G1 points along +x, and G2
    and G3 are G1 rotated by +60° and -60° respectively. So G1 sits 60° from
    each of the others, and G2 and G3 are 120° apart.

    Args:
        a: moiré lattice constant (real-space period).

    Returns:
        np.ndarray of shape (3, 2); row k holds the (Gx, Gy) components of G_{k+1}.
    """
    g_len = 4 * np.pi / (np.sqrt(3) * a)
    # Polar angles of G1, G2, G3, measured counter-clockwise from the +x axis.
    angles = np.array([0.0, np.pi / 3, -np.pi / 3])
    return np.stack([g_len * np.cos(angles), g_len * np.sin(angles)], axis=1)

In [12]:
# Moiré lattice constant used by the cells below.
a_m = 8.5
# Rows are G1, G2, G3 as (Gx, Gy), converted to a float32 torch tensor.
G_vectors = torch.from_numpy(reciprocal_vectors(a_m)).float()
print(G_vectors)
# (2, 1) column views of G1 and G2, so that an (N, 2) position matrix times G gives the phases G·r.
G1, G2 = G_vectors[0, :, None], G_vectors[1, :, None]
print(G1)
print(G2)

tensor([[ 0.8536,  0.0000],
        [ 0.4268,  0.7392],
        [ 0.4268, -0.7392]])
tensor([[0.8536],
        [0.0000]])
tensor([[0.4268],
        [0.7392]])


In [14]:
class FFN(nn.Module):
    """Placeholder for a feed-forward network; not implemented yet.

    TODO: define the layers and ``forward``. As written, the module has no
    parameters and inherits ``nn.Module.forward``, so calling it raises
    NotImplementedError.
    """

In [13]:
class SlaterNet(nn.Module):
    """Work-in-progress Slater-determinant network on a moiré lattice.

    Only the periodic input layer is implemented so far. Each position r is
    mapped to [sin(G1·r), sin(G2·r), cos(G1·r), cos(G2·r)], where G1 and G2 are
    the first two vectors returned by ``reciprocal_vectors(a)``.

    Args:
        a: moiré lattice constant passed to ``reciprocal_vectors``.
        N: stored as ``self.N``. Presumably the number of particles; not used
            yet (TODO confirm).
    """

    def __init__(self, a: float, N: int) -> None:
        super().__init__()
        self.a = a
        self.N = N

        # Use the constructor argument, not the notebook-global `a_m`.
        G_vectors = torch.from_numpy(reciprocal_vectors(a)).float()
        # (2, 1) column vectors. Registered as non-persistent buffers so they
        # follow .to(device) / .double() without being added to state_dict().
        self.register_buffer("G1_T", G_vectors[0].unsqueeze(-1), persistent=False)
        self.register_buffer("G2_T", G_vectors[1].unsqueeze(-1), persistent=False)

    def forward(self, R: torch.Tensor) -> torch.Tensor:
        """Map positions to periodic lattice features.

        Args:
            R: positions of shape (N, 2), or (..., N, 2) for batched input.

        Returns:
            Tensor of shape (..., N, 4) whose last axis is
            [sin(G1·r), sin(G2·r), cos(G1·r), cos(G2·r)].

        Raises:
            ValueError: if R is not at least 2-D with a last dimension of 2.
        """
        if R.dim() < 2 or R.shape[-1] != 2:
            raise ValueError(f"R must have shape (..., N, 2); got {tuple(R.shape)}")
        # Put G1 and G2 side by side as columns -> (2, 2). Cast to R's dtype so
        # float64 positions do not hit a dtype mismatch in matmul.
        G_T = torch.cat((self.G1_T, self.G2_T), dim=1).to(dtype=R.dtype)
        phases = torch.matmul(R, G_T)  # (..., N, 2): [G1·r, G2·r]
        # TODO: feed these features through the FFN and build the Slater determinant.
        return torch.cat((torch.sin(phases), torch.cos(phases)), dim=-1)