In [None]:
import time
from math import erfc, exp, factorial, pi, sqrt

import numpy as np
import torch
from torch import nn

### Static Energy

In [2]:
# ## 1. Simulation constants & helper functions
a_m = 8.031  # nm   moiré lattice constant, supercell length, from paper (the lattice mismatch between WSe2 and WS2)
V0 = 15.0  # meV  from paper
eps_r = 5.0  # dielectric constant in paper
e2_4pieps0 = 14.399645  # meV·nm (|e|²/4πϵ0)
hbar2_over_2m = 108.857  # meV·nm² in paper
phi = np.pi / 4  # phase of the moiré potential from paper
n_sup = 3  # 3×3 super-cell
N_e = 6  # # electrons (= # occupied orbitals)
cut = 2  # |G|≤cut·|g| # cutoff for the plane-wave basis (G vectors) in the Hamiltonian

# --- Ewald parameters (renamed) ---
ew_alpha = 0.35  # nm⁻²  (splitting)
r_cut = 2.5  # real-space cutoff in L units
k_cut = 5  # k-space cutoff in 2π/L units
L = n_sup * a_m  # nm   PBC box length used in Ewald


def a_vectors(a_m):
    """Generates the 3 shortest moiré reciprocal vectors G_1,2,3, 60° apart, six-fold symmetry"""
    a1 = a_m * np.array([1.0, 0.0])
    a2 = a_m * np.array([0.5, np.sqrt(3.0) / 2.0])
    a3 = -(a1 + a2)  # optional third vector (120° w.r.t. a1)
    return np.array([a1, a2, a3])  # same interface style as your b_vectors()


def b_vectors(a_m):
    """from paper: g_j = (4*pi / sqrt(3) / a_m) * [cos(2*pi*j/3), sin(2*pi*j/3)], for j=1,2,3"""
    g_list = []
    prefac = 4 * np.pi / (np.sqrt(3) * a_m)
    for j in range(1, 4):  # j = 1, 2, 3
        angle = 2 * np.pi * j / 3
        g = prefac * np.array([np.cos(angle), np.sin(angle)])
        g_list.append(g)
    # print("g list", g_list)
    return np.array(g_list)  # returns [g1, g2, g3]


# real-space vectors of an n×n super-cell   (n = 3 here since we have 3x3 supercell)


def supercell_vectors(n, a_m):
    a1, a2 = a_vectors(a_m)
    return n * a1, n * a2


# ξ_M : configuration-independent Madelung constant
def madelung_offset(
    alpha=ew_alpha,  # α = 1/(4η²)
    r_lim=r_cut,
    k_lim=k_cut,
    L=L,
):
    """Compute ξ_M in Eq. (A12) Returns a scalar (dimensionless).  Multiply by e²/4πϵ₀ϵ_r later."""
    eta = 0.5 / np.sqrt(alpha)  # because α = 1/(4η²)
    area = L * L
    k0 = 2.0 * np.pi / L
    # ---- Real-space images   Σ_{L≠0} erfc(|L|/2η)/|L|
    rsum = 0.0
    maxn = int(np.ceil(r_lim))
    for nx in range(-maxn, maxn + 1):
        for ny in range(-maxn, maxn + 1):
            if nx == 0 and ny == 0:
                continue
            Rvec = np.array([nx, ny]) * L
            R = np.linalg.norm(Rvec)
            if R > r_lim * L:
                continue
            rsum += erfc(R / (2.0 * eta)) / R
    # ---- Reciprocal-space images   (2π/Area) Σ_{G≠0} e^{−η²G²}/G
    ksum = 0.0
    for mx in range(-k_lim, k_lim + 1):
        for my in range(-k_lim, k_lim + 1):
            if mx == 0 and my == 0:
                continue
            Gvec = np.array([mx, my]) * k0
            G = np.linalg.norm(Gvec)
            ksum += np.exp(-((eta * G) ** 2)) / G
    ksum *= 2.0 * np.pi / area
    xi0_L = 1.0 / (eta * np.sqrt(np.pi))  # ---- ξ^L_0 term   1 / (η √π)
    return rsum + ksum - xi0_L  # ---- ξ_M (Eq. A12)


ξ_M = madelung_offset() * e2_4pieps0 / eps_r  # compute once and store — units:   meV

### Defining the SlaterNet

In [3]:
class FeedForwardLayer(nn.Module):
    def __init__(self, embed_dim: int) -> None:
        super().__init__()

        # W^(l+1) h^l + b^(l+1)
        self.Wl_1p = nn.Linear(embed_dim, embed_dim)
        # (nonlinear) hyperbolic tangent activation function
        self.tanh = nn.Tanh()

    def forward(self, hl: torch.Tensor) -> torch.Tensor:
        # input should be of shape (N, embed_dim): h^l + tanh( W^(l+1) h^l + b^(l+1) )
        return hl + self.tanh(self.Wl_1p(hl))

In [4]:
class SlaterNet(nn.Module):
    def __init__(
        self, G_vectors: np.ndarray, N: int, embed_dim: int = 4, num_layers: int = 3
    ) -> None:
        super().__init__()

        # N is the number of electrons.
        self.N = N

        # get G vectors
        G_vectors = torch.from_numpy(G_vectors).float()
        self.G1_T = G_vectors[0].unsqueeze(-1)
        self.G2_T = G_vectors[1].unsqueeze(-1)

        # input embedding matrix: projects 4 features to embed_dim
        self.W_0 = nn.Linear(4, embed_dim, bias=False)
        self.MLP_layers = nn.ModuleList(  # MLP layers
            [FeedForwardLayer(embed_dim) for _ in range(num_layers)]
        )

        # matrix to hold the projection vectors (complex projectors for orbital)
        # w_2j and w_2j+1 for j = 0, ... N-1
        self.complex_proj = nn.Parameter(
            torch.complex(
                real=torch.randn(embed_dim, N), imag=torch.randn(embed_dim, N)
            )
        )
        self.denominator = sqrt(factorial(N))

    def forward(self, R: torch.Tensor) -> torch.Tensor:  # R should be of shape (N, 2)
        # compute the periodic features
        G1_R = torch.matmul(R, self.G1_T)
        G2_R = torch.matmul(R, self.G2_T)
        features_R = torch.cat(
            (torch.sin(G1_R), torch.sin(G2_R), torch.cos(G1_R), torch.cos(G2_R)), dim=1
        )  # shape should now be (N, 4)

        # embed in higher_dimensional space to get h^0
        h = self.W_0(features_R)

        # pass through MLP layers
        for layer in self.MLP_layers:
            h = layer(h)

        # form complex matrix as in Eq. 2
        WF_matrix = torch.matmul(h.to(torch.complex64), self.complex_proj)

        # compute determinant
        sign, log_abs_det = torch.linalg.slogdet(WF_matrix)
        determinant = sign * torch.exp(log_abs_det)
        result = determinant / self.denominator
        return result

### Testing model

In [5]:
test_model = SlaterNet(G_vectors=b_vectors(a_m), N=10, embed_dim=5, num_layers=3)
test_model.eval()

SlaterNet(
  (W_0): Linear(in_features=4, out_features=5, bias=False)
  (MLP_layers): ModuleList(
    (0-2): 3 x FeedForwardLayer(
      (Wl_1p): Linear(in_features=5, out_features=5, bias=True)
      (tanh): Tanh()
    )
  )
)

In [6]:
R = torch.randn(10, 2)
print(R.dtype)
phi_HF = test_model(R)
print(phi_HF)

torch.float32
tensor(-7.2255e-36+9.2948e-36j, grad_fn=<DivBackward0>)


### Helper functions to compute local energy

In [7]:
# ---------- Ewald helpers (using ew_alpha) ----------
def pairwise_real_space(R, alpha=ew_alpha, r_lim=r_cut, L=L):
    """Short-range (real-space) Ewald sum (equation A6 from the paper):
    E_real = ½ ∑_{i≠j} ∑_L erfc(√α·r_{ij}^L) / r_{ij}^L.,
    where α = 1/(4η²), and r_{ij}^L = |r_i - r_j + L|"""
    N, E = len(R), 0.0  # N:number particles
    maxn = int(np.ceil(r_lim))  # summing over neighbor cells from -max_n to +max_n
    for i in range(N):
        for j in range(i + 1, N):  # loop over ½ ∑_{i≠j}
            for nx in range(-maxn, maxn + 1):  # loop over n_x n_y
                for ny in range(-maxn, maxn + 1):
                    dr = R[i] - R[j] + np.array([nx, ny]) * L  # dr = r_i - r_j + n·L
                    r = np.linalg.norm(dr)  # r = |dr|
                    if r < 1e-9 or r > r_lim * L:
                        continue  # Skip self‐interaction (r≈0) or beyond cutoff r_lim·L
                    E += erfc(sqrt(alpha) * r) / r  # α = 1/(4η²) we choose
    return E


def structure_factor(R, k):  # Σ e^{ik·r}
    """Structure factor S(k) = Σ e^{ik·r} (sum over all particles)"""
    phase = R @ k
    return np.sum(np.cos(phase)) + 1j * np.sum(np.sin(phase))


def reciprocal_space(R, alpha=ew_alpha, k_lim=k_cut, L=L):
    """Long-range (reciprocal‐space) Ewald sum (equation A7 and A10 from the paper):
    E_recip = (π/V) ∑_{k≠0} [ e^{-k²/(4α)} / k² ] |S(k)|²
    with V = L² in 2D, fast convergence."""
    area, k0 = L * L, 2 * pi / L
    E = 0.0
    # 2) Sum over discrete wavevectors q = (m_x, m_y)·(2π/L), the paper has ∑_{q≠0};
    # here we loop m_x, m_y ∈ [−k_lim,…,+k_lim]
    for mx in range(-k_lim, k_lim + 1):
        for my in range(-k_lim, k_lim + 1):
            if mx == 0 and my == 0:
                continue  # skip the q = 0 term
            k = np.array([mx, my]) * k0  # q_vec = (m_x, m_y)·(2π/L)
            k2 = k @ k  # q² = |q_vec|² = q_x² + q_y²
            E += (
                exp(-k2 / (4 * alpha)) * abs(structure_factor(R, k)) ** 2 / k2
            )  # factor e^{–q²/(4α)} # e^{-q²/(4α)}/q² · |S(q)|²
    return (pi / area) * E


def self_energy(N, alpha=ew_alpha):
    """(equation A12 from the paper, Madelung constant) Self‐interaction correction: E_self = - ∑_i (√α / √π) · q_i²
    Here q_i are unit charges, so E_self = -N·(√α/√π)."""
    return -sqrt(alpha / pi) * N

In [8]:
def moire_potential(r, a_m=a_m, V0=V0, phi=phi):
    """V(r) = -2*V0*sum_{j=1}^{3} cos(g_j · r + phi)where g_j are 3 reciprocal lattice vectors (from paper)."""
    G = torch.from_numpy(
        np.array(b_vectors(a_m), dtype=np.float32)
    )  # Get the three reciprocal vectors, shape (3, 2)
    phase = torch.matmul(r, G.T) + phi  # r @ G.T + phi
    one_electron_moire = -2 * V0 * torch.sum(torch.cos(phase), dim=-1)
    return torch.sum(one_electron_moire)

In [9]:
def coulomb_ewald_2D(R):
    """
    Full 2-D Ewald energy for the set of positions R of shape (N, 2).
    Now includes ½ Σ_b ξ_M  so the result matches Eq. (A11) exactly.
    E_total = (e²/(4πϵ₀ ε_r))·( E_real + E_recip + E_self )
    """
    N = len(R)
    # position-dependent part (your original implementation)
    E_config = (
        (pairwise_real_space(R) + reciprocal_space(R) + self_energy(N))
        * e2_4pieps0
        / eps_r
    )
    # constant Madelung shift
    return E_config + 0.5 * N * ξ_M


def energy_static(R):
    """V_ext + V_ee  (independent of Ψ)."""
    return moire_potential(R)  # + coulomb_ewald_2D(R)

In [10]:
def compute_laplacian_hessian(net, R):
    # first, determine the shape of R
    N, D = R.shape  # N: number of electrons, D: dimensions (2D here)
    R_flattened = R.reshape(N * D)  # Flatten R to (N * D)

    # Define helper runctions that take an input of shape (N * D) and output a real scalar:
    def fn_real(X):
        return net(X.reshape(N, D)).real

    def fn_imag(X):
        return net(X.reshape(N, D)).imag

    # Compute the Hessian matrix w.r.t. both functions
    real_Hessian = torch.func.hessian(fn_real)(R_flattened)
    imag_Hessian = torch.func.hessian(fn_imag)(R_flattened)

    # Compute the Laplacian as the trace of the Hessian
    real_laplacian = torch.trace(real_Hessian)
    imag_laplacian = torch.trace(imag_Hessian)
    return real_laplacian + (1j * imag_laplacian)

In [11]:
# compute the local energy
def local_energy(net, R):
    """Local energy using the Hessian-based Laplacian."""

    psi = net(R)  # complex scalar

    # The only change: use compute_laplacian_hessian instead of nested loops.
    lap_psi = compute_laplacian_hessian(net, R)  # complex

    # Compute kinetic energy term: T = -ħ²/2m ∇²ψ / ψ
    kin = -(hbar2_over_2m * (lap_psi / psi)).real

    # Compute potential energy term: V = V_ext + V_ee
    V = energy_static(R)
    total_energy = kin + V
    return total_energy, psi  # return both energy and wavefunction value


vmapped_local_energy = torch.func.vmap(local_energy, in_dims=(None, 0))

In [12]:
x = torch.randn(5, 10, 2)
e, psi = vmapped_local_energy(test_model, x)
print(e.shape, psi.shape)

torch.Size([5]) torch.Size([5])


### VMC Loop

In [13]:
# Check if CUDA is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [None]:
def mcmc_sampler(net, R_init, n_steps=200, step_size=0.1):
    """Metropolis-Hastings sampler to draw samples ~ |Ψ|².
    Inputs:
      - net: SlaterNet model
      - R_init: torch.Tensor (N_e,2), starting positions
      - n_steps: number of MCMC steps
      - step_size: Gaussian proposal standard deviation
    Returns:
      - samples: list of torch.Tensor configurations (one per step)"""

    R = R_init.clone().to(device)
    psi_sq = torch.abs(net(R)) ** 2  # Compute initial |Ψ|²
    samples = []
    for _ in range(n_steps):
        for i in range(R.shape[0]):  # propose move for each electron
            R_prop = R.clone()
            R_prop[i] += step_size * torch.randn_like(
                R[i]
            )  # Gaussian random displacement
            R_prop[i] = R_prop[i] % L  # Enforce PBC in a box of length L
            psi_sq_prop = (
                torch.abs(net(R_prop)) ** 2
            )  # evalualte |Ψ|² at the proposed R
            if (psi_sq_prop / (psi_sq + 1e-12)) > torch.rand(1, device=device):
                R[i] = R_prop[i]
                psi_sq = psi_sq_prop
        samples.append(R.clone())
    return samples


# training loop with Variational Monte Carlo
def train_vmc(net, n_iter=200, n_walkers=8, n_steps=200, step_size=0.1):
    """VMC training with a decaying learning rate: η(t) = η_0 * (1 + t/t0)^(-1) (page 13)"""

    η_0 = 1e-4  # initial learning rate (Table II)
    t0 = 1e5  # decay “time constant”
    rho = 5.0  # clipping threshold for local energy

    optimizer = torch.optim.Adam(net.parameters(), lr=η_0)
    # Adam will perform parameter updates θ ← θ – η·∇_θ L

    # Each walker is one configuration R of N_e electrons in 2D, drawn uniformly in [0, L)^2
    walkers = [L * torch.rand(net.N, 2, device=device) for _ in range(n_walkers)]

    # Main VMC loop ──────────────────────────────────────────────────────────────────
    for it in range(1, n_iter + 1):
        start_time = time.time()

        # — Update the learning rate —
        lr_t = η_0 * (1 + it / t0) ** -1
        for pg in optimizer.param_groups:
            pg["lr"] = lr_t

        # Prepare lists to collect data over all walkers & steps
        new_walkers = []  # to hold final config of each walker
        samples = []  # to hold all samples for each walker

        # Sample from each walker
        for w in walkers:
            samples.extend(mcmc_sampler(net, w, n_steps, step_size))
            new_walkers.append(samples[-1])

        # Convert list of samples to a tensor of shape (n_walkers * n_steps, N_e, 2)
        samples = torch.stack(samples)

        # compute the local energy for all the samples
        E_full, psi = vmapped_local_energy(net, samples)

        # reshape and process the energies
        E_full = E_full.squeeze(-1)
        E_clip = torch.clamp(E_full, -rho, +rho)
        logpsi = torch.log(torch.abs(psi.squeeze(-1)) + 1e-12)

        # true mean energy (for reporting)
        E_mean_full = E_full.mean()

        # clipped mean energy (for loss)
        E_mean_clip = E_clip.mean().detach()

        # VMC loss uses clipped energies
        loss = torch.mean((E_clip - E_mean_clip) * logpsi)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        walkers = new_walkers
        print(
            f"Iter {it:3d}/{n_iter:3d} | <E> = {E_mean_full.item():.6f} meV | lr = {lr_t:.3e}"
        )

        end_time = time.time()
        print(f"Time taken: {end_time - start_time}\n")

    return net

In [18]:
# train to show time
net = SlaterNet(
    G_vectors=b_vectors(a_m), N=N_e, embed_dim=64, num_layers=3
)  # -----number layers: 3; Perceptron dim: 64
trained_net = train_vmc(
    net,
    n_iter=5,  # paper says 150000 ------- Training iterations
    n_walkers=16,
    n_steps=256,
    step_size=0.05,
)

# ------n_walkers * n_steps = 4096 (MCMC batch size) namely gather 4096 (𝑅,𝐸_loc, lnΨ) where
# “n_walkers”：how many independent MCMC chains we run in parallel

Iter   1/  5 | <E> = 4471858208440320.000000 meV | lr = 1.000e-04
Time taken: 14.542210102081299
Iter   2/  5 | <E> = -14814428528640.000000 meV | lr = 1.000e-04
Time taken: 17.229722023010254
Iter   3/  5 | <E> = -596572031156224.000000 meV | lr = 1.000e-04
Time taken: 16.84843897819519
Iter   4/  5 | <E> = 8836599932518400.000000 meV | lr = 1.000e-04
Time taken: 14.976014137268066
Iter   5/  5 | <E> = -11950712946688.000000 meV | lr = 1.000e-04
Time taken: 16.461782932281494
