In [1]:
import math
import numpy as np
import torch
from torch import nn

In [None]:
# def b_vectors(a_m): # a_m = lattice constant, b_vectors = reciprocal lattice vectors
#     b1 = (2*np.pi/a_m) * np.array([ 1.0, -1/np.sqrt(3.0) ])
#     b2 = (2*np.pi/a_m) * np.array([ 0.0, 2/np.sqrt(3.0) ])
#     b3 = -(b1+b2)
#     return np.stack([b1, b2, b3])

def b_vectors(a_m):
    """from paper: g_j = (4*pi / sqrt(3) / a_m) * [cos(2*pi*j/3), sin(2*pi*j/3)], for j=1,2,3"""
    g_list = []
    prefac = 4 * np.pi / (np.sqrt(3) * a_m)
    for j in range(1, 4):  # j = 1, 2, 3
        angle = 2 * np.pi * j / 3
        g = prefac * np.array([np.cos(angle), np.sin(angle)])
        g_list.append(g)
    return g_list  # returns [g1, g2, g3]

In [None]:
class FeedForwardLayer(nn.Module):
    def __init__(self, L: int) -> None:
        super().__init__()

        # W^(l+1) h^l + b^(l+1)
        self.Wl_1p = nn.Linear(L, L)
        # (nonlinear) hyperbolic tangent activation function
        self.tanh = nn.Tanh()

    def forward(self, hl: torch.Tensor) -> torch.Tensor:
        # input should be of shape (N, L): h^l + tanh( W^(l+1) h^l + b^(l+1) )
        return hl + self.tanh(self.Wl_1p(hl))

In [None]:
class SlaterNet(nn.Module):
    def __init__(self, a: float, N: int, L: int = 4, num_layers: int = 3) -> None:
        super().__init__()

        # get G vectors
        G_vectors = torch.from_numpy(b_vectors(a)).float()
        self.G1_T = G_vectors[0].unsqueeze(-1)
        self.G2_T = G_vectors[1].unsqueeze(-1)

        # input embedding matrix: projects 4 features to L-dim
        self.W_0 = nn.Linear(4, L, bias=False)
        self.MLP_layers = nn.ModuleList( # MLP layers
            [FeedForwardLayer(L) for _ in range(num_layers)]
        )

        # matrix to hold the projection vectors (complex projectors for orbital) 
        # w_2j and w_2j+1 for j = 0, ... N-1
        self.complex_proj = nn.Parameter(
            torch.complex(real=torch.randn(L, N), imag=torch.randn(L, N))
        )
        self.denominator = math.sqrt(math.factorial(N))

    def forward(self, R: torch.Tensor) -> torch.Tensor:  # R should be of shape (N, 2)
        # compute the periodic features
        G1_R = torch.matmul(R, self.G1_T)
        G2_R = torch.matmul(R, self.G2_T)
        features_R = torch.cat(
            (torch.sin(G1_R), torch.sin(G2_R), torch.cos(G1_R), torch.cos(G2_R)), dim=1
        ) # shape should now be (N, 4)

        # embed in higher_dimensional space to get h^0
        h = self.W_0(features_R)

        # pass through MLP layers
        for layer in self.MLP_layers:
            h = layer(h)

        # form complex matrix as in Eq. 2
        WF_matrix = torch.matmul(h.to(torch.complex64), self.complex_proj)

        # compute determinant
        determinant = torch.linalg.det(WF_matrix)
        result = determinant / self.denominator
        return result

### Testing model

In [7]:
a_m = 8.031
test_model = SlaterNet(a=a_m, N=10, L=5, num_layers=3)
test_model.eval()

SlaterNet(
  (W_0): Linear(in_features=4, out_features=5, bias=False)
  (MLP_layers): ModuleList(
    (0): FeedForwardLayer(
      (Wl_1p): Linear(in_features=5, out_features=5, bias=True)
      (tanh): Tanh()
    )
    (1): FeedForwardLayer(
      (Wl_1p): Linear(in_features=5, out_features=5, bias=True)
      (tanh): Tanh()
    )
    (2): FeedForwardLayer(
      (Wl_1p): Linear(in_features=5, out_features=5, bias=True)
      (tanh): Tanh()
    )
  )
)

In [9]:
R = torch.randn(10, 2)
print(R.dtype)
phi_HF = test_model(R)
print(phi_HF)

torch.float32
tensor(-1.3063e-36+2.4554e-36j, grad_fn=<DivBackward0>)
