In [None]:
import math

import numpy as np
import torch
from torch import nn

In [22]:
# ---------- reciprocal lattice ----------
def reciprocal_vectors(a):
    """Generates the 3 shortest moiré reciprocal vectors G_1,2,3, 60° apart, six-fold symmetry"""
    g = 4 * np.pi / (np.sqrt(3) * a)  # (length |G|=4π/√3a)
    G1 = np.array([g, 0.0])  # first vector along x-axis

    def rot(v, ang):
        return np.array(
            [
                np.cos(ang) * v[0] - np.sin(ang) * v[1],  # rotate G1 by +60°, get G2
                np.sin(ang) * v[0] + np.cos(ang) * v[1],  # rotate G1 by –60°, get G3
            ]
        )

    return np.stack([G1, rot(G1, np.pi / 3), rot(G1, -np.pi / 3)])

In [None]:
class FeedForwardLayer(nn.Module):
    def __init__(self, L: int) -> None:
        super().__init__()

        # W^(l+1) h^l + b^(l+1)
        self.Wl_1p = nn.Linear(L, L)

        # (nonlinear) hyperbolic tangent activation function
        self.tanh = nn.Tanh()

    def forward(self, hl: torch.Tensor) -> torch.Tensor:
        # input should be of shape (N, L)
        # h^l + tanh( W^(l+1) h^l + b^(l+1) )
        return hl + self.tanh(self.Wl_1p(hl))

In [44]:
class SlaterNet(nn.Module):
    def __init__(self, a: float, N: int, L: int = 4, num_layers: int = 3) -> None:
        super().__init__()

        # first, get G vectors
        G_vectors = torch.from_numpy(reciprocal_vectors(a)).float()
        self.G1_T = G_vectors[0].unsqueeze(-1)
        self.G2_T = G_vectors[1].unsqueeze(-1)

        # input embedding matrix
        self.W_0 = nn.Linear(4, L, bias=False)

        # multilayer perceptron neural network layers
        self.MLP_layers = nn.ModuleList(
            [FeedForwardLayer(L) for _ in range(num_layers)]
        )

        # matrix to hold the projection vectors
        # w_2j and w_2j+1 for j = 0, ... N-1
        self.complex_proj = nn.Parameter(
            torch.complex(real=torch.randn(L, N), imag=torch.randn(L, N))
        )

        # denominator of Eq. 2
        # = sqrt(N!)
        self.denominator = math.sqrt(math.factorial(N))

    def forward(self, R: torch.Tensor) -> torch.Tensor:
        # R should be of shape (N, 2)
        # first we need to compute the periodic features
        G1_R = torch.matmul(R, self.G1_T)
        G2_R = torch.matmul(R, self.G2_T)
        features_R = torch.cat(
            (torch.sin(G1_R), torch.sin(G2_R), torch.cos(G1_R), torch.cos(G2_R)), dim=1
        )
        # shape should now be (N, 4)

        # embed in higher_dimensional space to get h^0
        h = self.W_0(features_R)

        # pass through MLP layers
        for layer in self.MLP_layers:
            h = layer(h)

        # form complex matrix as in Eq. 2
        WF_matrix = torch.matmul(h.to(torch.complex64), self.complex_proj)

        # compute determinant
        determinant = torch.linalg.det(WF_matrix)

        # get result
        result = determinant / self.denominator
        return result

### Testing model

In [45]:
a_m = 8.5
test_model = SlaterNet(a=a_m, N=10, L=5, num_layers=3)
test_model.eval()

SlaterNet(
  (W_0): Linear(in_features=4, out_features=5, bias=False)
  (MLP_layers): ModuleList(
    (0-2): 3 x FeedForwardLayer(
      (Wl_1p): Linear(in_features=5, out_features=5, bias=True)
      (tanh): Tanh()
    )
  )
)

In [46]:
R = torch.randn(10, 2)
print(R.dtype)
phi_HF = test_model(R)
print(phi_HF)

torch.float32
tensor(1.4927e-36+1.4403e-36j, grad_fn=<DivBackward0>)
