In [35]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
import itertools
from typing import Callable, Optional, Sequence, Tuple

In [36]:
def _sc_lattice_vecs(rs: float, nelec: int) -> np.ndarray:
  """Returns simple cubic lattice vectors with Wigner-Seitz radius rs."""
  area = np.pi * (rs**2) * nelec
  length = area**(1 / 2)
  return length * np.eye(2)


# Understanding the functions for electron-electron interactions

In [267]:
def make_2DCoulomb_potential(
    lattice: jnp.ndarray,
    atoms: jnp.ndarray,
    charges: jnp.ndarray,
    truncation_limit: int = 5,
    interaction_energy_scale: float = 1.0,
) -> Callable[[jnp.ndarray, jnp.ndarray], float]:
  """Creates a function to evaluate infinite Coulomb sum for periodic lattice.

    Args:
        lattice: Shape (2, 2). Matrix whose columns are the primitive lattice vectors.
        atoms: Shape (natoms, 2). Positions of the atoms.
        charges: Shape (natoms). Nuclear charges of the atoms.
        nspins: Tuple of the number of spin-up and spin-down electrons.
        truncation_limit: Integer. Half side length of square of nearest neighbours
        to primitive cell which are summed over in evaluation of Ewald sum.

    Returns:
        Callable with signature f(ae, ee, spins), where (ae, ee) are atom-electron and
        electron-electron displacement vectors respectively, and spins are electron spins,
        which evaluates the Coulomb sum for the periodic lattice via the Ewald method.
  """
  del atoms, charges # unused for 2d system without atoms
  print("making 2DCoulomb potential with energy scale: " + str(interaction_energy_scale))
  rec = 2 * jnp.pi * jnp.linalg.inv(lattice)
  #volume = jnp.abs(jnp.linalg.det(lattice))
  area = jnp.abs(jnp.linalg.det(lattice)) #area for 2D
  # the factor gamma tunes the width of the summands in real / reciprocal space
  # and this value is chosen to optimize the convergence trade-off between the
  # two sums. See CASINO QMC manual.
  gamma_factor = 2.4
  gamma = (gamma_factor / area**(1 / 2))**2  # Adjusted for 2D systems

  ordinals = sorted(range(-truncation_limit, truncation_limit + 1), key=abs)
  ordinals = jnp.array(list(itertools.product(ordinals, repeat=2)))  # Adjusted for 2D
  lat_vectors = jnp.einsum('kj,ij->ik', lattice, ordinals)
  rec_vectors = jnp.einsum('jk,ij->ik', rec, ordinals[1:])
  rec_vec_square = jnp.einsum('ij,ij->i', rec_vectors, rec_vectors)
  rec_vec_magnitude = jnp.sqrt(rec_vec_square) # |rec_vectors|, same as kappa
  lat_vec_norm = jnp.linalg.norm(lat_vectors[1:], axis=-1)

  def real_space_ewald(separation: jnp.ndarray):
      """Real-space Ewald potential between charges in 2D.
      """
      displacements = jnp.linalg.norm(separation - lat_vectors, axis=-1)  # |r - R|

      return jnp.sum(
          jax.scipy.special.erfc(gamma**0.5 * displacements) / displacements)

  def recp_space_ewald(separation: jnp.ndarray):
      """Reciprocal-space Ewald potential between charges in 2D.
      """
      phase = jnp.cos(jnp.dot(rec_vectors, separation))

      factor = jax.scipy.special.erfc(rec_vec_magnitude / (2 * gamma**0.5) )
      return (2 * jnp.pi / area) * jnp.sum( phase * factor / rec_vec_magnitude)

  def ewald_sum(separation: jnp.ndarray):
      """Combined real and reciprocal space Ewald potential in 2D.
      """
      return real_space_ewald(separation) + recp_space_ewald(separation)
      
  # Compute Madelung constant components
  # Real-space part
  # xi_S_0 = 0 * gamma**0.5 / jnp.pi**0.5
  madelung_real = jnp.sum(
      jax.scipy.special.erfc(gamma**0.5 * lat_vec_norm) / lat_vec_norm
  )

  # q = 0 contribution of the real-space part
  phi_S_q0 = (2 * jnp.pi) / area / gamma**0.5 / jnp.pi**0.5

  # Reciprocal-space part
  xi_L_0 = 2 * gamma**0.5 / jnp.pi**0.5
  madelung_recip = - 0*(2 * jnp.pi / area) * (1 / (gamma**0.5 * jnp.pi**0.5)) + \
      (2 * jnp.pi / area) * jnp.sum(
          jax.scipy.special.erfc(rec_vec_magnitude / (2 * gamma**0.5)) / rec_vec_magnitude
      ) - xi_L_0

  # Total Madelung constant
  madelung_const = madelung_real + madelung_recip
  batch_ewald_sum = jax.vmap(ewald_sum, in_axes=(0,))

  def electron_electron_potential(ee: jnp.ndarray):
      """Evaluates periodic electron-electron potential with charges.

      We always include neutralizing background term for homogeneous electron gas.
      """
      nelec = ee.shape[0]
      ee = jnp.reshape(ee, [-1, 2])
      ewald = batch_ewald_sum(ee)
      ewald = jnp.reshape(ewald, [nelec, nelec])
      # Set diagonal elements to zero (self-interaction)
      ewald = ewald.at[jnp.diag_indices(nelec)].set(0.0)

      # Add Madelung constant term: (1/2) * N * q_i^2 * Madelung_const
      # Since q_i^2 = 1, this simplifies to (1/2) * N * Madelung_const
      potential = 0.5 * jnp.sum(ewald)  - 0.5 * nelec**2 * phi_S_q0
      return potential

  def potential(ae: jnp.ndarray, ee: jnp.ndarray):
    """Accumulates atom-electron, atom-atom, and electron-electron potential."""
    # Reduce vectors into first unit cell
    del ae # for HEG calculations, there are no atoms
    phase_ee = jnp.einsum('il,jkl->jki', rec / (2 * jnp.pi), ee)
    phase_prim_ee = (phase_ee + 0.5)  % 1 - 0.5
    prim_ee = jnp.einsum('il,jkl->jki', lattice, phase_prim_ee)
    return interaction_energy_scale * jnp.real(
        electron_electron_potential(prim_ee)
    )

  return potential

In [260]:
a1 = jnp.array([jnp.sqrt(3)/2,-0.5])
a2 = jnp.array([0,1])
Tmatrix = jnp.array([[3,0], [0, 3]])  # Identity matrix for unit cell
T = lattice_vecs(a1, a2, Tmatrix)
rec = 2*jnp.pi*jnp.linalg.inv(T)

In [261]:
potential = make_2DCoulomb_potential(T, jnp.array([0.0]), jnp.array([0,0]))

making 2DCoulomb potential with energy scale: 1.0


In [262]:
def construct_input_features(
    pos: jnp.ndarray,
    atoms: jnp.ndarray,
    ndim: int = 3) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Constructs inputs to Fermi Net from raw electron and atomic positions.

  Args:
    pos: electron positions. Shape (nelectrons*ndim,).
    atoms: atom positions. Shape (natoms, ndim).
    ndim: dimension of system. Change only with caution.

  Returns:
    ae, ee, r_ae, r_ee tuple, where:
      ae: atom-electron vector. Shape (nelectron, natom, ndim).
      ee: electron-electron vector. Shape (nelectron, nelectron, ndim).
      r_ae: atom-electron distance. Shape (nelectron, natom, 1).
      r_ee: electron-electron distance. Shape (nelectron, nelectron, 1).
    The diagonal terms in r_ee are masked out such that the gradients of these
    terms are also zero.
  """
  assert atoms.shape[1] == ndim
  ae = jnp.reshape(pos, [-1, 1, ndim]) - atoms[None, ...]
  ee = jnp.reshape(pos, [1, -1, ndim]) - jnp.reshape(pos, [-1, 1, ndim])

  r_ae = jnp.linalg.norm(ae, axis=2, keepdims=True)
  # Avoid computing the norm of zero, as is has undefined grad
  n = ee.shape[0]
  r_ee = (
      jnp.linalg.norm(ee + jnp.eye(n)[..., None], axis=-1) * (1.0 - jnp.eye(n)))
  return ae, ee, r_ae, r_ee[..., None]

In [263]:
ae,ee,r_ae, r_ee = construct_input_features(jnp.array([[0.0, 0.0], [1.0, 1.0]]).flatten(), jnp.array([[0.0, 0.0]]),2)

In [264]:
jnp.array([[0.0, 0.0], [1.0+T[0,0], 1.0+T[1,0]]]).flatten().shape

(4,)

In [265]:
T[:,0] + np.array([0,0])

Array([ 2.598076, -1.5     ], dtype=float32)

In [266]:
potential(ae,ee)

Array(-1.0038664, dtype=float32)

# Periodic Potential

In [5]:
import jax.numpy as jnp
from typing import Callable

def make_cosine_potential_with_input_features(
    potential_lattice: jnp.ndarray,
    coefficients: jnp.ndarray,
    phases: jnp.ndarray,
) -> Callable[[jnp.ndarray], float]:
    """
    Creates a function to evaluate a periodic potential as a sum of three cosines with phases,
    adapted to take `ae` in the form outputted by `construct_input_features`.

    Args:
        potential_lattice: Shape (2, 2). Matrix whose columns are the primitive lattice vectors.
        coefficients: Shape (3,). Coefficients for the three cosine terms.
        phases: Shape (3,). Phases for the three cosine terms.

    Returns:
        Callable with signature f(ae), where ae is an array of atom-electron
        displacement vectors of (shape (nelectron, natom, ndim)), which evaluates the periodic potential.
    """
    # Compute reciprocal lattice vectors
    rec = 2 * jnp.pi * jnp.linalg.inv(potential_lattice)

    # Define the cosine potential function
    def potential(ae: jnp.ndarray) -> float:
        """
        Evaluates the periodic potential using atom-electron displacement vectors.

        Args:
            ae: Shape (nelec * natom, 2). Flattened array of atom-electron displacement vectors.

        Returns:
            The value of the potential summed over all displacement vectors.
        """
        # Reshape `ae` to (nelec, 2) 
        ae = jnp.reshape(ae, (-1, 2))

        # Compute the cosine terms with phases for each displacement vector
        cos_term1 = coefficients[0] * jnp.cos(jnp.dot(ae, rec[0,:]) + phases[0])
        cos_term2 = coefficients[1] * jnp.cos(jnp.dot(ae, rec[1,:]) + phases[1])
        cos_term3 = coefficients[2] * jnp.cos(jnp.dot(ae, rec[0,:] - rec[1,:]) + phases[2])

        # Sum the cosine terms over all displacement vectors
        return jnp.sum(cos_term1 + cos_term2 + cos_term3)

    return potential

In [6]:
# Define a lattice
lattice = jnp.array([[1.0, 0.0], [0.0, 1.0]])

# Define coefficients and phases for the cosine terms
coefficients = jnp.array([1.0, 1.0, 1.0])
phases = jnp.array([0.0, 0.0, 0.0])

# Create the potential function
cosine_potential = make_cosine_potential_with_input_features(lattice, coefficients, phases)

# Example `ae` array (flattened atom-electron displacement vectors)
#ae = jnp.array([[0.5, 0.5], [0.3, 0.7], [0.1, 0.9], [0.4, 0.6]]).flatten()

# Evaluate the potential
#value = cosine_potential(ae)
#print("Potential value:", value)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [7]:
print(ae.shape)

NameError: name 'ae' is not defined

In [38]:
def lattice_vecs(a1:jnp.ndarray, a2:jnp.ndarray,Tmatrix:jnp.ndarray) -> np.ndarray:
  "Return the basis T1,T2 of the super-cell built from the unit cell lattice vectors a1 and a2"
  T1 = Tmatrix[0,0]*a1 + Tmatrix[0,1]*a2
  T2 = Tmatrix[1,0]*a1 + Tmatrix[1,1]*a2
  return jnp.column_stack([T1, T2])
def reciprocal_vecs(a1: jnp.ndarray, a2: jnp.ndarray, Tmatrix: jnp.ndarray) -> jnp.ndarray:
  """
  Return the reciprocal basis vectors g1, g2 such that T_i · g_j = 2π δ_ij,
  where T1, T2 are the supercell vectors built from the unit cell vectors a1 and a2.
  """
  # Construct supercell lattice vectors T1, T2
  T1 = Tmatrix[0,0] * a1 + Tmatrix[0,1] * a2
  T2 = Tmatrix[1,0] * a1 + Tmatrix[1,1] * a2

  # Compute the oriented area A = T1 x T2 (scalar in 2D)
  A = T1[0] * T2[1] - T1[1] * T2[0]

  # Compute reciprocal vectors satisfying T_i · g_j = 2π δ_ij
  g1 = 2 * jnp.pi / A * jnp.array([ T2[1], -T2[0] ])
  g2 = 2 * jnp.pi / A * jnp.array([ -T1[1], T1[0] ])

  return jnp.column_stack([g1, g2])

In [47]:
a1 = jnp.array([jnp.sqrt(3)/2,-0.5])
a2 = jnp.array([0,1])
Tmatrix = jnp.array([[-2,2], [4, 2]])  # Identity matrix for unit cell
T = lattice_vecs(a1, a2, Tmatrix)
rec = 2*jnp.pi*jnp.linalg.inv(T)

In [48]:
area = jnp.abs(jnp.linalg.det(T))

In [49]:
area

Array(10.392304, dtype=float32)

In [50]:
v1 = -2*a1 + 2*a2
v2 = 4*a1 + 2*a2
area = np.abs(v1[1]*v2[0] - v1[0]*v2[1])

In [51]:
area

10.392304

In [17]:
def hbar2_over_2m_eff(meff: float) -> float:
    """
    Returns hbar^2 / (2 * m_eff) in units of meV·nm².
    `meff` should be given in units of the electron mass m_e.
    """
    # Physical constants
    hbar = 1.054571817e-34  # J·s
    m_e = 9.10938356e-31    # kg
    eV = 1.602176634e-19    # J
    meV = 1e-3 * eV
    nm2 = 1e-18             # m²

    # Convert hbar^2 / (2 m_eff) to meV·nm²
    value = hbar**2 / (2 * meff * m_e)  # in J·m²
    value_meV_nm2 = value / meV / nm2  # convert to meV·nm²

    return value_meV_nm2


In [7]:
kin = hbar2_over_2m_eff(0.35)

In [11]:
kin*2

217.71326689531793

In [37]:
def coulomb_prefactor(epsilon_r: float) -> float:
    """
    Compute e^2 / (4 * pi * epsilon_0 * epsilon_r * r) at r = 1 nm,
    return result in meV units.
    """

    # Constants
    e = 1.602176634e-19        # C
    epsilon_0 = 8.854187817e-12  # F/m
    r_nm = 1e-9                # m
    J_to_meV = 1 / 1.602176634e-22  # 1 J = this many meV

    # Energy in J at r = 1 nm
    energy_J = e**2 / (4 * jnp.pi * epsilon_0 * epsilon_r * r_nm)

    # Convert to meV
    energy_meV = energy_J * J_to_meV

    return energy_meV


In [30]:
prefac = coulomb_prefactor(20)

In [32]:
prefac/hbar2_over_2m_eff(1.0)/2

0.9448630483400831

# Attemp at adding a vector potential

In [51]:
import jax.numpy as jnp
from typing import Callable

def make_vectorpotential(
    Bfield_lattice: jnp.ndarray,
    flux: jnp.ndarray,
    phase: jnp.ndarray,
) -> Callable[[jnp.ndarray], jnp.ndarray]:
    """
    Creates a function to evaluate the vector potential

    Args:
        Bfield_lattice: Shape (2, 2). Matrix whose columns are the primitive lattice vectors.
        coefficients: Shape (3,). Coefficients for the three cosine terms.
        phases: Shape (3,). Phases for the three cosine terms.

    Returns:
        Callable with signature f(ae), where ae is a flattened array of atom-electron
        displacement vectors (shape (nelec * natom, 2)), which evaluates the vector potential and returns a vector.
    """
    # Compute reciprocal lattice vectors
    rec = 2 * jnp.pi * jnp.linalg.inv(Bfield_lattice)
    Glist = jnp.array([rec[0,:], rec[1,:], -rec[0,:] + rec[1,:], -rec[0,:], -rec[1,:], rec[0,:] - rec[1,:]])
    # Precompute norms and coefficients for speed
    Gcoeff_x = 1.0j * Glist[:, 1] * flux
    Gcoeff_y = -1.0j * Glist[:, 0] * flux
    print(Glist), print(Gcoeff_x), print(Gcoeff_y)
    # Define the vector potential component function
    def vector_potential_comp(r: jnp.ndarray) -> jnp.ndarray:
        # Compute dot products for all G vectors at once
        dot_products = jnp.dot(Glist, r)
        exp_terms = jnp.exp(1.0j * dot_products + phase) 

        # Compute outx and outy using vectorized operations
        outx = jnp.sum(exp_terms * Gcoeff_x)
        outy = jnp.sum(exp_terms * Gcoeff_y)
        return jnp.array([jnp.real(outx), jnp.real(outy)])

    # Vectorize the component function using jax.vmap
    vectorized_potential_comp = jax.vmap(vector_potential_comp)

    # Define the vector potential function
    def potential(ae: jnp.ndarray) -> jnp.ndarray:
        """
        Evaluates the vector potential using atom-electron displacement vectors.

        Args:
            ae: Shape (nelec , natom, 2). Array of atom-electron displacement vectors.

        Returns:
            The vector potential as a vector of shape (nelec, 2)
        """
        # Reshape `ae` to (nelec * natom, 2) if necessary
        ae = jnp.reshape(ae, (-1, 2))
        return vectorized_potential_comp(ae),jnp.reshape(vectorized_potential_comp(ae), (-1,))

    return potential

In [54]:
glist = jnp.array([rec[0,:],rec[1,:],rec[0,:] - rec[1,:],-rec[0,:],-rec[1,:],-rec[0,:] + rec[1,:]])

In [7]:
def vector_potential_comp(r:jnp.ndarray) -> jnp.ndarray:
    outx = 0.0 + 0.0j
    outy = 0.0 + 0.0j
    for i in range(6):
        outx += (jax.numpy.exp(jnp.dot(r, glist[i])))*(1.0j*glist[i][1]*(Bfield/(jnp.linalg.norm(glist[i])**2)))
        outy += (jax.numpy.exp(jnp.dot(r, glist[i])))*(-1.0j*glist[i][2]*(Bfield/(jnp.linalg.norm(glist[i])**2)))
    return jnp.array([outx,outy])
vectorized_function = jax.vmap(vector_potential_comp)


In [52]:
outfunc = make_vectorpotential(T,jnp.array(-0.05),0.0)

[[ 7.2551975  0.       ]
 [ 3.6275988  6.2831855]
 [-3.6275988  6.2831855]
 [-7.2551975 -0.       ]
 [-3.6275988 -6.2831855]
 [ 3.6275988 -6.2831855]]
[-0.+0.j         -0.-0.31415927j -0.-0.31415927j  0.-0.j
  0.+0.31415927j  0.+0.31415927j]
[ 0.+0.3627599j   0.+0.18137994j -0.-0.18137994j -0.-0.3627599j
 -0.-0.18137994j  0.+0.18137994j]


Array(-0.23, dtype=float32, weak_type=True)

In [63]:
ae,ee,r_ae, r_ee = construct_input_features(jnp.array([[1, 2],[5,6]]).flatten(), jnp.array([[0.0, 0.0]]),2)

In [64]:
ae

Array([[[1., 2.]],

       [[5., 6.]]], dtype=float32)

In [65]:
outfunc(ae)

(Array([[-3.8743019e-07, -2.6040596e-01],
        [ 8.9406967e-08,  1.1914012e+00]], dtype=float32),
 Array([-3.8743019e-07, -2.6040596e-01,  8.9406967e-08,  1.1914012e+00],      dtype=float32))

In [62]:
-2.6040596e-01

-0.26040596

In [165]:
def flatten_AvectorPotential(Avec: jnp.ndarray) -> jnp.ndarray:
    """
    Flattens the vector potential output from the vector potential function.

    Args:
        Avec: Shape (nelec, 2). The vector potential output.

    Returns:
        Flattened vector potential as a 1D array of shape (ndim * nelec,).
    """
    return jnp.reshape(Avec, (-1,))  # Flatten to 1D array

In [20]:
import jax.numpy as jnp

# Compute the sum of squared norms of each 2D vector in Avec
def sum_squared_norm(Avec):
    return jnp.sum(jnp.sum(jnp.square(Avec), axis=-1))

# Example usage
Avec = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # Shape (nelec, 2)
result = sum_squared_norm(outfunc(ae))  # Output: 1^2 + 2^2 + 3^2 + 4^2 + 5^2 + 6^2 = 91
print(result)

33.182137


In [176]:
testout2 = flatten_AvectorPotential(testout)

In [178]:
testout

Array([[0.  +0.j      , 0.  +0.j      ],
       [0. +10.346515j, 0. -10.346515j],
       [0. +69.381805j, 0. -69.381805j],
       [0.+486.46396j , 0.-486.46396j ]], dtype=complex64)

In [177]:
testout2

Array([0.  +0.j      , 0.  +0.j      , 0. +10.346515j, 0. -10.346515j,
       0. +69.381805j, 0. -69.381805j, 0.+486.46396j , 0.-486.46396j ],      dtype=complex64)

In [None]:
gradient_Avec_energy = Callable[
    [networks.ParamTree, networks.FermiNetData], jnp.ndarray
]
def local_gradient_vectorpotential_energy(
    f: networks.FermiNetLike,
    Avec:  Callable[[jnp.ndarray], jnp.ndarray]
) -> gradient_Avec_energy:
  r"""Creates a function for the local dot product of vector potential and gradient of wavefunction, A(r).grad(\psi)

  Args:
    f: Callable which evaluates the wavefunction as a
      (sign or phase, log magnitude) tuple.
    Avec: Callable which evaluates the vector potential A(r) at positions r and output a vector
        of shape (ndim*N,2) with ndim = 2

  Returns:
    Callable which evaluates A(r).grad(\psi)
  """
    phase_f = utils.select_output(f, 0)
    logabs_f = utils.select_output(f, 1)
    def Avec_dot_grad_over_f(params, data):
      n = data.positions.shape[0]
      eye = jnp.eye(n)
      grad_logabs_f = jax.grad(logabs_f, argnums=1)
      grad_phase_f = jax.grad(phase_f,argnums = 1)
      return jnp.dot(Avec(data.pos), grad_logabs_f(params, data.pos, data.spins, data.atoms, data.charges) + 1.j * grad_phase_f(params, data.pos, data.spins, data.atoms, data.charges) ) 
    
    return Avec_dot_grad_over_f

In [217]:
def div_A(
    Avec: Callable[[jnp.ndarray], jnp.ndarray]
) -> Callable[[jnp.ndarray], float]:
    r"""Computes the sum of the divergence of the vector potential A(r).

    Args:
        Avec: Callable which evaluates the vector potential A(r) at positions r and outputs a vector
            of shape (ndim*N, 2) with ndim = 2.

    Returns:
        Callable which evaluates sum of div(A)at positions r_i.
    """
    def div_A_func(r: jnp.ndarray) -> jnp.ndarray:
        Avec_val = Avec(r)
        return jax.jacobian(lambda x: Avec_val)(r).sum(axis=-1)

    return div_A_func

In [199]:
outfunc

<function __main__.make_vectorpotential.<locals>.potential(ae: jax.Array) -> jax.Array>

In [222]:
div_func = div_A(outfunc)

In [223]:
div_func(ae)

Array([[[[0.],
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         [0.]]],


       [[[0.],
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         [0.]]],


       [[[0.],
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         [0.]]],


       [[[0.],
         [0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.],
         [0.]]]], dtype=float32)

# Madelung constant

In [27]:

def get_Madelung_constant_2DCoulomb_potential(
    lattice: jnp.ndarray,
    truncation_limit: int = 5,
) -> float:
  """ Returns Madelung constants. 
  These capture the contribution of Coulomb interaction of particles with their images to the total Coulomb energy, as
  E_Madelung = 0.5 * N_elec * madelung_const 
  """
  rec = 2 * jnp.pi * jnp.linalg.inv(lattice)
  print(rec)
  area = jnp.abs(jnp.linalg.det(lattice)) #area for 2D
  gamma_factor = 2.8
  gamma = (gamma_factor / area**(1 / 2))**2  # Adjusted for 2D systems

  ordinals = sorted(range(-truncation_limit, truncation_limit + 1), key=abs)
  ordinals = jnp.array(list(itertools.product(ordinals, repeat=2)))  # Adjusted for 2D
  lat_vectors = jnp.einsum('kj,ij->ik', lattice, ordinals)
  lat_vec_norm = jnp.linalg.norm(lat_vectors[1:], axis=-1)
  rec_vectors = jnp.einsum('jk,ij->ik', rec, ordinals[1:])
  rec_vec_square = jnp.einsum('ij,ij->i', rec_vectors, rec_vectors)
  rec_vec_magnitude = jnp.sqrt(rec_vec_square) # |rec_vectors|, same as kappa

  # Real space part
  madelung_real = jnp.sum(
      jax.scipy.special.erfc(gamma**0.5 * lat_vec_norm) / lat_vec_norm
  )

  # Reciprocal-space part
  xi_L_0 = 2 * gamma**0.5 / jnp.pi**0.5
  madelung_recip = - 0*(2 * jnp.pi / area) * (1 / (gamma**0.5 * jnp.pi**0.5)) + \
      (2 * jnp.pi / area) * jnp.sum(
          jax.scipy.special.erfc(rec_vec_magnitude / (2 * gamma**0.5)) / rec_vec_magnitude
      ) - xi_L_0
  
  # q = 0 contribution of short-range part
  phi_S_q0 = (2 * jnp.pi) / area / gamma**0.5 / jnp.pi**0.5

  # Total Madelung constant
  madelung_const = madelung_real + madelung_recip

  return madelung_const, phi_S_q0

In [28]:
madconstant, phi_S_q0 = get_Madelung_constant_2DCoulomb_potential(T)

[[2.4183993 0.       ]
 [1.2091997 2.0943952]]


In [33]:
madconstant*0.5*3*0.9448630483400831

Array(-1.3478357, dtype=float32)

In [34]:
2.8851 - -1.3478357

4.2329357000000005

In [20]:
- 0.5 * 3**2 * phi_S_q0* prefac

Array(-587.69855, dtype=float32)

In [11]:
4.7496 - -0.5480

5.2976