## run slaternet

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from math import erfc, sqrt, pi, exp
from torch.optim import Adam
from torch.autograd import grad
import matplotlib.pyplot as plt

In [6]:
# ## 1. Simulation constants & helper functions
# Unit system used throughout this file: lengths in nm, energies in meV.

# --- moiré model parameters (values quoted from the paper) ---
a_m = 8.031              # nm      moiré lattice constant (WSe2/WS2 lattice mismatch)
V0 = 15.0                # meV     moiré potential strength
eps_r = 5.0              # relative dielectric constant
# e²/(4πϵ0) = 1.4399645 eV·nm = 1439.9645 meV·nm.
# The previous value 14.399645 is the same constant expressed in eV·Å, which made
# every Coulomb energy 100x too small relative to the meV/nm kinetic and moiré terms.
e2_4pieps0 = 1439.9645   # meV·nm  (|e|²/4πϵ0)
hbar2_over_2m = 108.857  # meV·nm² kinetic prefactor ħ²/2m*
phi = np.pi / 4          # phase of the moiré potential
n_sup = 3                # linear super-cell size (3×3 moiré cells)
N_e = 6                  # number of electrons (= number of occupied orbitals)
cut = 2                  # |G| ≤ cut·|g| plane-wave cutoff (not used by the code visible in this file)

# --- Ewald parameters ---
# NOTE(review): with these values exp(-k²/(4·ew_alpha)) is still about 0.3 at the
# largest on-axis |k| = k_cut·2π/L, so the reciprocal-space sum looks under-converged;
# consider a smaller ew_alpha or a larger k_cut.
ew_alpha = 0.35          # nm⁻²    Ewald splitting parameter α
r_cut = 2.5              # real-space cutoff, in units of L
k_cut = 5                # reciprocal-space cutoff, in units of 2π/L
L = n_sup * a_m          # nm      periodic box length used by the Ewald sums and the sampler

def a_vectors(a_m):
    """Return three real-space vectors built from the moiré lattice constant.

    Args:
        a_m: lattice constant (same length unit as the returned vectors).

    Returns:
        ``[a1, a2, a3]`` as length-2 NumPy arrays, where
        ``a1 = a_m·(1, 0)``, ``a2 = a_m·(1/2, √3/2)`` (60° from ``a1``) and
        ``a3 = -(a1 + a2)``, so the three vectors sum to zero.

    NOTE(review): because ``a1`` and ``a2`` are 60° apart, ``a3`` has length
    ``√3·a_m`` and points at 210°; it is not a third primitive vector at 120°.
    Confirm that this is what callers expect.
    """
    first = np.array([1.0, 0.0]) * a_m
    second = np.array([0.5, np.sqrt(3.0) / 2.0]) * a_m
    closing = -(first + second)
    return [first, second, closing]


def b_vectors(a_m):
    """Return the three moiré reciprocal vectors g_1, g_2, g_3.

    Implements g_j = 4π / (√3·a_m) · (cos(2πj/3), sin(2πj/3)) for j = 1, 2, 3,
    i.e. three vectors of equal length spaced 120° apart.

    Args:
        a_m: moiré lattice constant.

    Returns:
        A list ``[g1, g2, g3]`` of length-2 NumPy arrays (units: 1 / units of a_m).
    """
    scale = 4 * np.pi / (np.sqrt(3) * a_m)
    angles = [2 * np.pi * j / 3 for j in (1, 2, 3)]
    return [scale * np.array([np.cos(theta), np.sin(theta)]) for theta in angles]

# real-space vectors of an n×n super-cell   (n = 3 here since we have 3x3 supercell)
def supercell_vectors(n, a_m):
    """Return the two real-space vectors spanning an n×n super-cell.

    Args:
        n: linear super-cell size (number of moiré cells along each direction).
        a_m: moiré lattice constant.

    Returns:
        Tuple ``(n·a1, n·a2)`` of length-2 NumPy arrays.
    """
    # a_vectors() returns three vectors; only the first two span the cell.
    # (Unpacking all three into two names raised ValueError on every call.)
    a1, a2 = a_vectors(a_m)[:2]
    return n * a1, n * a2


def moire_potential(r, a_m = a_m, V0 = V0, phi = phi):
    """Total moiré potential energy of a set of electron positions.

    Evaluates V(r) = -2·V0·Σ_{j=1..3} cos(g_j · r + phi) for every position and
    adds the results, with g_j taken from ``b_vectors(a_m)``.

    Args:
        r: array of positions with last axis of length 2, e.g. shape (N, 2);
           a single (2,) position also works.
        a_m: moiré lattice constant.
        V0: potential strength.
        phi: phase offset added to every g_j · r.

    Returns:
        Scalar: the potential summed over all positions in ``r``.
    """
    recip = np.array(b_vectors(a_m))           # shape (3, 2)
    arg = np.dot(r, recip.T) + phi             # one column per g_j
    per_position = -2 * V0 * np.cos(arg).sum(axis=-1)
    return np.sum(per_position)

# ---------- Ewald helpers (using ew_alpha) ----------
def pairwise_real_space(R, alpha=ew_alpha, r_lim=r_cut, L=L):
    """Short-range (real-space) part of the Ewald sum.

    Computes Σ_{i<j} Σ_n erfc(√α · r) / r with r = |r_i - r_j + n·L|, where
    n = (nx, ny) runs over a square block of periodic images. Summing over
    i < j is the same as ½ Σ_{i≠j}.

    Args:
        R: particle positions, indexable as R[i] giving length-2 NumPy arrays.
        alpha: Ewald splitting parameter α.
        r_lim: real-space cutoff in units of L; images farther than r_lim·L are dropped.
        L: periodic box length.

    Returns:
        The dimension-1/length sum (no charge or dielectric prefactor applied).
    """
    n_img = int(np.ceil(r_lim))
    image_range = range(-n_img, n_img + 1)
    # Image translations, in the same (nx outer, ny inner) order as the sum is accumulated.
    shifts = [np.array([nx, ny]) * L for nx in image_range for ny in image_range]
    r_max = r_lim * L

    total = 0.0
    n_part = len(R)
    for i in range(n_part):
        for j in range(i + 1, n_part):
            separation = R[i] - R[j]
            for shift in shifts:
                dist = np.linalg.norm(separation + shift)
                # Drop coincident points (r≈0) and anything beyond the cutoff.
                if dist < 1e-9 or dist > r_max:
                    continue
                total += erfc(sqrt(alpha) * dist) / dist
    return total

def structure_factor(R,k):
    """Structure factor S(k) = Σ_i exp(i k·r_i), summed over all rows of ``R``.

    Args:
        R: positions, shape (N, 2).
        k: wavevector, shape (2,).

    Returns:
        Complex scalar Σ cos(k·r_i) + i·Σ sin(k·r_i).
    """
    theta = R @ k
    real_part = np.cos(theta).sum()
    imag_part = np.sin(theta).sum()
    return real_part + 1j * imag_part

def reciprocal_space(R, alpha=ew_alpha, k_lim=k_cut, L=L):
    """Long-range (reciprocal-space) part of the Ewald sum.

    Computes (π / L²) · Σ_{k≠0} exp(-k² / (4α)) · |S(k)|² / k² over the wavevectors
    k = (mx, my)·2π/L with mx, my ∈ [-k_lim, k_lim].

    Args:
        R: particle positions, shape (N, 2).
        alpha: Ewald splitting parameter α.
        k_lim: integer cutoff on |mx|, |my|.
        L: periodic box length.

    Returns:
        The reciprocal-space sum (no charge or dielectric prefactor applied).

    NOTE(review): as written this quantity is dimensionless (1/L² · 1/k²), whereas
    the real-space and self terms it is added to scale as 1/length. The
    exp(-k²/4α)/k² kernel is the 3-D Ewald form; please check it against the
    paper's 2-D expression (Eq. A7/A10) before trusting absolute energies.
    """
    cell_area = L * L
    dk = 2 * pi / L
    modes = range(-k_lim, k_lim + 1)

    acc = 0.0
    for mx in modes:
        for my in modes:
            if (mx, my) == (0, 0):
                continue  # the k = 0 term is excluded
            kvec = np.array([mx, my]) * dk
            k_sq = kvec @ kvec
            s_abs_sq = abs(structure_factor(R, kvec)) ** 2
            acc += exp(-k_sq / (4 * alpha)) * s_abs_sq / k_sq
    return (pi / cell_area) * acc

def self_energy(N, alpha=ew_alpha):
    """Ewald self-interaction correction for N unit charges.

    Args:
        N: number of particles.
        alpha: Ewald splitting parameter α.

    Returns:
        -N · √(α/π) (no charge or dielectric prefactor applied).
    """
    per_particle = sqrt(alpha / pi)
    return -per_particle * N

# ξ_M : configuration-independent Madelung constant
def madelung_offset(alpha=ew_alpha,                 # α = 1/(4η²)
                    r_lim=r_cut, k_lim=k_cut, L=L):
    """Configuration-independent Madelung term ξ_M (Eq. A12 of the paper).

    Evaluates, with η = 1 / (2√α):
        ξ_M = Σ_{n≠0} erfc(|nL| / 2η) / |nL|
              + (2π / L²) Σ_{G≠0} exp(-η²G²) / G
              - 1 / (η√π)
    where n·L runs over image cells within r_lim·L and G = (mx, my)·2π/L with
    |mx|, |my| ≤ k_lim.

    Args:
        alpha: Ewald splitting parameter α.
        r_lim: real-space cutoff in units of L.
        k_lim: integer reciprocal-space cutoff.
        L: periodic box length.

    Returns:
        Scalar with units of 1/length; multiply by e²/(4πϵ0·ϵ_r) to get an energy.

    NOTE(review): the reciprocal kernel used here, exp(-η²G²)/G, differs from the
    exp(-k²/4α)/k² kernel used in reciprocal_space(); at most one of them can match
    the paper. Confirm against Eq. A7/A10/A12.
    """
    def _nonzero_cells(m):
        # All integer pairs in [-m, m]² except (0, 0), first index varying slowest.
        for ix in range(-m, m + 1):
            for iy in range(-m, m + 1):
                if ix == 0 and iy == 0:
                    continue
                yield ix, iy

    eta = 0.5 / np.sqrt(alpha)
    cell_area = L * L
    dk = 2.0 * np.pi / L

    # Real-space images of a particle with itself.
    real_part = 0.0
    r_max = r_lim * L
    for ix, iy in _nonzero_cells(int(np.ceil(r_lim))):
        dist = np.linalg.norm(np.array([ix, iy]) * L)
        if dist > r_max:
            continue
        real_part += erfc(dist / (2.0 * eta)) / dist

    # Reciprocal-space images.
    recip_part = 0.0
    for ix, iy in _nonzero_cells(k_lim):
        g_norm = np.linalg.norm(np.array([ix, iy]) * dk)
        recip_part += np.exp(-(eta * g_norm) ** 2) / g_norm
    recip_part *= 2.0 * np.pi / cell_area

    self_term = 1.0 / (eta * np.sqrt(np.pi))
    return real_part + recip_part - self_term

# Madelung term converted to an energy: madelung_offset() (1/nm) times e²/(4πϵ0·ϵ_r) (meV·nm).
# Evaluated once at import time with the default Ewald parameters and reused by coulomb_ewald_2D().
ξ_M = madelung_offset() * e2_4pieps0 / eps_r # compute once and store — units:   meV


def coulomb_ewald_2D(R):
    """Ewald Coulomb energy of the configuration ``R`` in the periodic box.

    Adds the real-space, reciprocal-space and self-energy sums (all evaluated with
    the module-level default Ewald parameters), scales them by e²/(4πϵ0·ϵ_r), and
    then adds the constant shift ½·N·ξ_M.

    Args:
        R: particle positions, shape (N, 2).

    Returns:
        Scalar energy in the units of ``e2_4pieps0`` divided by length.

    NOTE(review): ξ_M (from madelung_offset) already contains -1/(η√π) = -2√(α/π),
    so ½·N·ξ_M includes -N·√(α/π), which is exactly self_energy(N). The self term
    therefore appears to be subtracted twice; likewise the i = j part of |S(k)|² in
    reciprocal_space() overlaps with the reciprocal part of ξ_M. This is a
    configuration-independent offset, but please confirm against Eq. (A11).
    """
    n_particles = len(R)
    prefactor_terms = pairwise_real_space(R) + reciprocal_space(R) + self_energy(n_particles)
    config_energy = prefactor_terms * e2_4pieps0 / eps_r
    madelung_shift = 0.5 * n_particles * ξ_M
    return config_energy + madelung_shift

# ---------- public function ----------
def energy_static(R):
    """Wavefunction-independent potential energy: moiré potential plus Coulomb (Ewald).

    Args:
        R: NumPy array of particle positions, shape (N, 2).

    Returns:
        Scalar ``moire_potential(R) + coulomb_ewald_2D(R)``.
    """
    external = moire_potential(R)
    interaction = coulomb_ewald_2D(R)
    return external + interaction

# energy_static(np.random.rand(6,2))

# # alias for backward compatibility
# energy_moire = energy_static

class FeedForwardLayer(nn.Module):
    """Residual feed-forward block: h -> h + tanh(W h + b).

    The linear map keeps the feature width, so input and output shapes match.
    """

    def __init__(self, L: int) -> None:
        """Create the block.

        Args:
            L: feature width of the layer (input and output size of the linear map).
        """
        super().__init__()
        self.Wl_1p = nn.Linear(L, L)   # affine map W^(l+1) h^l + b^(l+1)
        self.tanh = nn.Tanh()          # element-wise nonlinearity

    def forward(self, hl: torch.Tensor) -> torch.Tensor:
        """Apply the residual update to ``hl`` (last dimension must be ``L``)."""
        pre_activation = self.Wl_1p(hl)
        update = self.tanh(pre_activation)
        return hl + update

class SlaterNet(nn.Module):
    """Single-determinant neural wavefunction with lattice-periodic input features.

    Each electron position is mapped to sin/cos features of g_1·r and g_2·r,
    embedded into an L-dimensional space, passed through residual tanh layers,
    and projected onto N complex orbitals. The wavefunction value is the
    determinant of the resulting N×N matrix.
    """

    def __init__(self, a, N, L = 64, num_layers = 3) -> None:
        """Build the network.

        Args:
            a: lattice constant passed to ``b_vectors`` to build the periodic features.
            N: number of electrons (= number of orbitals).
            L: hidden feature width.
            num_layers: number of residual feed-forward layers.
        """
        super().__init__()
        self.N = N
        self.L = L
        self.a = a
        self.num_layers = num_layers

        # Store the first two reciprocal vectors as (2, 1) column buffers: they move
        # with the module across devices but are not trainable.
        recip = torch.from_numpy(np.array(b_vectors(a))).float()
        for buffer_name, g_vec in zip(('G1_T', 'G2_T'), recip[:2]):
            self.register_buffer(buffer_name, g_vec.unsqueeze(-1))

        # Bias-free embedding of the 4 periodic features into the hidden space.
        self.W_0 = nn.Linear(4, L, bias=False)
        self.MLP_layers = nn.ModuleList(
            [FeedForwardLayer(L) for _ in range(num_layers)]
        )

        # Trainable complex projection: column j turns hidden features into orbital j.
        proj_real = torch.randn(L, N)
        proj_imag = torch.randn(L, N)
        self.complex_proj = nn.Parameter(torch.complex(real=proj_real, imag=proj_imag))

    def forward(self, R: torch.Tensor) -> torch.Tensor:
        """Evaluate the wavefunction for one configuration.

        Args:
            R: electron positions, shape (N, 2).

        Returns:
            Complex scalar tensor: determinant of the (N, N) orbital matrix.
        """
        # Phases g_1·r and g_2·r, one column each -> shape (N, 2).
        phases = torch.cat((torch.matmul(R, self.G1_T), torch.matmul(R, self.G2_T)), dim=1)
        # Feature order: sin(g1·r), sin(g2·r), cos(g1·r), cos(g2·r) -> shape (N, 4).
        features_R = torch.cat((torch.sin(phases), torch.cos(phases)), dim=1)

        hidden = self.W_0(features_R)
        for block in self.MLP_layers:
            hidden = block(hidden)

        # Slater matrix: row = electron, column = orbital.
        slater_matrix = torch.matmul(hidden.to(torch.complex64), self.complex_proj)
        return torch.linalg.det(slater_matrix)

# Run on the CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# to handle complex psi (new)
def complex_grad(outputs, inputs, grad_outputs=None): # outputs is ψ, inputs is R
    """Differentiate a complex-valued output with respect to ``inputs``.

    The real and imaginary parts of ``outputs`` are differentiated separately with
    ``torch.autograd.grad`` (keeping the graph so the result can be differentiated
    again) and recombined as d(Re)/dx + i·d(Im)/dx.

    Args:
        outputs: tensor to differentiate.
        inputs: tensor to differentiate with respect to.
        grad_outputs: forwarded unchanged to ``torch.autograd.grad``.

    Returns:
        Tensor shaped like ``inputs``. If only one part depends on ``inputs`` only
        that part is returned (real part as-is, imaginary part times 1j); if neither
        does, a zero tensor like ``inputs`` is returned.
    """
    def _partial(component):
        # allow_unused=True makes autograd return None instead of raising when
        # `component` does not depend on `inputs`.
        return torch.autograd.grad(component, inputs, grad_outputs=grad_outputs,
                                   create_graph=True, allow_unused=True)[0]

    d_real = _partial(outputs.real)
    d_imag = _partial(outputs.imag)

    if d_real is None and d_imag is None:
        return torch.zeros_like(inputs)
    if d_imag is None:
        return d_real
    if d_real is None:
        return 1j * d_imag
    return d_real + 1j * d_imag


def compute_laplacian_complex(psi, R):
    """Laplacian ∇²Ψ of a complex wavefunction value with respect to all coordinates.

    Takes the first derivatives with ``complex_grad`` and then, for every
    coordinate R[i, d], differentiates ∂Ψ/∂R[i, d] once more and keeps only the
    diagonal entry ∂²Ψ/∂R[i, d]².

    Args:
        psi: scalar wavefunction value computed from ``R`` (graph attached).
        R: 2-D coordinate tensor with ``requires_grad=True``.

    Returns:
        Σ_{i,d} ∂²Ψ/∂R[i, d]² (the integer 0 if ``R`` has no entries).
    """
    first_derivs = complex_grad(psi, R)
    n_particles, n_dims = R.shape

    lap = 0
    # One extra backward pass per coordinate; flat index walks (i, d) row-major.
    for flat in range(n_particles * n_dims):
        i, d = divmod(flat, n_dims)
        second_derivs = complex_grad(first_derivs[i, d], R)
        lap = lap + second_derivs[i, d]
    return lap


def local_energy(net, R):
    """Local energy E_loc(R) = Re[-(ħ²/2m)·∇²Ψ/Ψ] + V(R) for one configuration.

    Args:
        net: wavefunction model; ``net(R)`` must return a scalar Ψ(R).
        R: electron positions tensor; it is copied and detached, so the caller's
           tensor is not modified.

    Returns:
        Real scalar tensor on ``device``. Only the real part of the kinetic term is
        kept. The result still carries the autograd graph of the kinetic term.
    """
    coords = R.clone().detach().to(device).requires_grad_(True)
    psi = net(coords)
    laplacian = compute_laplacian_complex(psi, coords)

    kinetic = (-hbar2_over_2m * (laplacian / psi)).real

    # The static potential is evaluated in NumPy, outside the autograd graph.
    potential_np = energy_static(coords.detach().cpu().numpy())
    potential = torch.tensor(potential_np, device=device, dtype=kinetic.dtype)
    return kinetic + potential


def mcmc_sampler(net, R_init, n_steps=200, step_size=0.1):
    """Metropolis–Hastings sampler drawing configurations distributed as |Ψ|².

    One step sweeps over the electrons; each electron gets a Gaussian trial
    displacement, is wrapped back into the periodic box [0, L)², and the move is
    accepted with probability min(1, |Ψ_new|² / |Ψ_old|²).

    Args:
        net: wavefunction model; ``net(R)`` returns a scalar Ψ(R).
        R_init: starting positions, shape (N_e, 2). Not modified.
        n_steps: number of sweeps.
        step_size: standard deviation of the Gaussian proposal.

    Returns:
        List of ``n_steps`` tensors, the configuration after each sweep. The
        tensors carry no autograd history.
    """
    R = R_init.detach().clone().to(device)
    samples = []
    # Sampling needs only values of Ψ, never gradients: disabling autograd avoids
    # building a graph for every one of the n_steps × N_e network evaluations.
    with torch.no_grad():
        psi_sq = torch.abs(net(R))**2     # initial |Ψ|²
        for _ in range(n_steps):
            for i in range(R.shape[0]):     # propose a move for each electron in turn
                R_prop = R.clone()
                R_prop[i] += step_size * torch.randn_like(R[i])   # Gaussian random displacement
                R_prop[i] = R_prop[i] % L                         # enforce PBC in a box of length L
                psi_sq_prop = torch.abs(net(R_prop))**2           # |Ψ|² at the proposed configuration
                # The 1e-12 guards against division by zero when the current |Ψ|² vanishes.
                if (psi_sq_prop / (psi_sq + 1e-12)) > torch.rand(1, device=device):
                    R[i] = R_prop[i]
                    psi_sq = psi_sq_prop
            samples.append(R.clone())
    return samples


def train_vmc(net,
              n_iter=200,
              n_walkers=8,
              n_steps=200,
              step_size=0.1):
    """Variational Monte Carlo training with a decaying learning rate.

    Every iteration:
      1. sets the learning rate to η(t) = η_0 · (1 + t/t0)^(-1);
      2. advances each walker with ``mcmc_sampler`` and collects every visited
         configuration (n_walkers · n_steps samples per iteration);
      3. evaluates the local energy and ln|Ψ| at each sample;
      4. minimises the surrogate loss mean[(E_clip - mean(E_clip)) · ln|Ψ|], whose
         gradient is the score-function estimate of the energy gradient.

    Args:
        net: wavefunction model exposing ``net.N`` (number of electrons).
        n_iter: number of optimisation steps.
        n_walkers: number of independent Markov chains.
        n_steps: MCMC sweeps per walker per iteration.
        step_size: proposal standard deviation for the sampler.

    Returns:
        The same ``net`` object, updated in place.

    NOTE(review): a step size of 10 is unusually large for Adam — confirm that the
    paper's Table II value refers to this optimiser.
    NOTE(review): Ψ is complex but only ln|Ψ| and Re(E_loc) enter the loss, so the
    contribution of the phase of Ψ to the gradient is not included — confirm this
    is intended.
    """

    η_0 = 10.0    # initial learning rate (Table II)
    # η_0 = 10e-4
    t0  = 1e5     # decay "time constant"
    rho = 5.0     # clipping half-width, in units of the batch mean absolute deviation

    optimizer = Adam(net.parameters(), lr=η_0)

    # Each walker is one configuration of N_e electrons in 2D, drawn uniformly in [0, L)^2.
    walkers = [L * torch.rand(net.N, 2, device=device) for _ in range(n_walkers)]

    for it in range(1, n_iter + 1):
        # --- learning-rate schedule ---
        lr_t = η_0 * (1 + it / t0) ** -1
        for pg in optimizer.param_groups:
            pg['lr'] = lr_t

        batch_E_full = []   # unclipped E_loc (constants w.r.t. the parameters)
        batch_logpsi = []   # ln|Ψ(R)|, the only term that carries gradients
        new_walkers  = []   # final configuration of each chain

        for w in walkers:
            samples = mcmc_sampler(net, w, n_steps, step_size)
            new_walkers.append(samples[-1])
            for R in samples:
                # E_loc must be treated as a constant in the score-function estimator.
                # Detaching it also frees the second-derivative graph of each sample
                # immediately instead of keeping it alive for the whole batch.
                batch_E_full.append(local_energy(net, R).detach())
                batch_logpsi.append(torch.log(torch.abs(net(R)) + 1e-12))

        E_full = torch.stack(batch_E_full)
        logpsi = torch.stack(batch_logpsi)

        # True (unclipped) mean energy, for reporting only.
        E_mean_full = E_full.mean()

        # Clip outliers to a window centred on the batch median, of half-width
        # rho × (mean absolute deviation from the median). Clamping to the absolute
        # range [-rho, +rho] would saturate almost every sample, because E_loc carries
        # an arbitrary offset that is much larger than rho.
        E_center = E_full.median()
        E_spread = (E_full - E_center).abs().mean()
        E_clip = torch.clamp(E_full,
                             (E_center - rho * E_spread).item(),
                             (E_center + rho * E_spread).item())
        E_mean_clip = E_clip.mean()

        # Surrogate loss: its gradient is <(E_clip - <E_clip>) ∇ ln|Ψ|>.
        loss = torch.mean((E_clip - E_mean_clip) * logpsi)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        walkers = new_walkers
        print(f'Iter {it:3d}/{n_iter:3d} | <E> = {E_mean_full.item():.6f} meV | lr = {lr_t:.3e}')

    return net

In [None]:
# Model: 3 residual layers, hidden width 64.
net = SlaterNet(a_m, N_e, L=64, num_layers=3)
net = net.to(device)

# n_walkers * n_steps = 4096 is the MCMC batch size: each iteration gathers 4096
# samples of (R, E_loc, ln|Ψ|), where
#   "n_walkers": number of independent MCMC chains, and
#   "n_steps":   number of successive sweeps each walker makes.
trained_net = train_vmc(
    net,
    n_iter=20,        # training iterations (the paper uses 150000)
    n_walkers=16,
    n_steps=256,
    step_size=0.05,
)