### JAX implementation of the 2D Ising model simulation

- Random key management: JAX uses explicit pseudo-random number generator (PRNG) keys to keep code reproducible and functionally pure. A key must be split into fresh subkeys every time new randomness is needed; reusing a key reproduces the same random numbers.
- Updates on immutable arrays: JAX arrays are immutable, so a spin is flipped with the array's `.at[index].set(value)` method, which returns a new array with that element replaced.
- Functional approach: the JAX version follows a more functional programming style, avoiding in-place mutation so that operations are pure (i.e., have no side effects).
- Array stacking: results collected over the steps are converted from Python lists into single JAX arrays with `jnp.stack` and `jnp.array`.

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
import torch

In [None]:
def ising2d_jax(N=20, T=1.0, J=1.0, B=0.0, n_steps=20000, out_freq=10):
    '''
    Metropolis Monte Carlo simulator for the 2D Ising model using JAX.

    Parameters:
    N (int): Size of the lattice, lattice is N x N.
    T (float): Temperature.
    J (float): Interaction strength between spins.
    B (float): External magnetic field.
    n_steps (int): Number of Monte Carlo steps.
    out_freq (int): Output frequency for saving spin configurations, energy, and magnetization.

    Returns:
    tuple: Arrays of spin configurations, energies, and magnetizations.
    '''
    key = random.PRNGKey(0)
    
    # Initialize spins
    key, subkey = random.split(key)
    spins = random.randint(subkey, (N, N), 0, 2) * 2 - 1
    n = N * N
    
    neighbors = jnp.roll(spins, 1, axis=0) + jnp.roll(spins, -1, axis=0) + \
                jnp.roll(spins, 1, axis=1) + jnp.roll(spins, -1, axis=1)
    E_t = -J * (spins * neighbors).sum()
   
    S, E, M = [], [], []
    M_t = spins.sum()

    for step in range(n_steps):
        key, subkey = random.split(key)
        i, j = random.randint(subkey, (2,), 0, N)
        
        z = spins[(i + 1) % N, j] + spins[(i - 1) % N, j] + \
            spins[i, (j + 1) % N] + spins[i, (j - 1) % N]
        
        dE = 2 * spins[i, j] * (J * z + B)
        dM = 2 * spins[i, j]

        key, flip_key = random.split(key)
        if random.uniform(flip_key) < jnp.exp(-dE / T):
            spins = spins.at[i, j].set(-spins[i, j])
            E_t += dE
            M_t += dM

        if step % out_freq == 0:
            S.append(spins)
            E.append(E_t / n)
            M.append(M_t / n)

    return jnp.stack(S), jnp.array(E), jnp.array(M)

### PyTorch implementation

In [None]:
def ising2d_torch(N=20, 
                T=1.0, 
                J=1.0, 
                B=0.0, 
                n_steps=20000, 
                out_freq=10):
    '''
    Metropolis Monte Carlo simulator for the 2D Ising model using PyTorch.

    Parameters:
    spins (torch.Tensor): Initial spin configuration.
    T (float): Temperature.
    J (float): Interaction strength between spins.
    B (float): External magnetic field.
    n_steps (int): Number of Monte Carlo steps.
    out_freq (int): Output frequency for saving spin configurations, energy, and magnetization.
    device (str): Device to run the simulation ('cpu' or 'cuda').

    Returns:
    tuple: Arrays of spin configurations, energies, and magnetizations.
    '''
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize spins
    spins = torch.randint(0, 2, (N, N), device=device) * 2 - 1
    n = N * N
    
    M_t = spins.sum()
    
    z     = spins.roll(1, dims=0) + spins.roll(-1, dims=0) + \
            spins.roll(1, dims=1) + spins.roll(-1, dims=1)
    E_t = -J * (spins * neighbors).sum()
   
    S, E, M = [], [], []

    for step in range(n_steps):

        i, j = torch.randint(0, N, (2,), device=device)

        z = spins[(i + 1) % N, j] + spins[(i - 1) % N, j] + spins[i, (j + 1) % N] + spins[i, (j - 1) % N]

        dE = 2 * spins[i, j] * (J * z + B)
        dM = 2 * spins[i, j]

        if torch.rand(1, device=device) < torch.exp(-dE / T):
            spins[i, j] *= -1
            E_t += dE
            M_t += dM

        if step % out_freq == 0:
            S.append(spins.clone())
            E.append(E_t / n)
            M.append(M_t / n)

    return torch.stack(S), torch.tensor(E), torch.tensor(M)

In [4]:
from numba import njit
import numpy as np

def init_ising2d(params, spins=None):
    '''Initialise a 2D Ising lattice and compute its energy and magnetization.

    The energy is H = -J * sum_<ij> s_i s_j - B * sum_i s_i on a periodic
    lattice. Neighbour sums are obtained by rolling the lattice one site in
    each of the four directions with np.roll; every bond is then counted
    twice, hence the factor 1/2.

    params: dict providing 'J' (coupling), 'B' (field) and 'N' (lattice size)
    spins:  optional 2D array of +/-1; if None a random N x N lattice is drawn
            with np.random.choice
    ---
    Returns:
    spins: the lattice (the same object that was passed in, if one was given)
    eng:   total energy of the lattice
    mag:   total magnetization (sum of spins)
    '''
    J, B, N = params['J'], params['B'], params['N']

    # Must be an identity test: `spins == None` is elementwise on an ndarray
    # and raises ValueError when used as a condition.
    if spins is None:
        spins = np.random.choice([-1, 1], size=(N, N))

    mag = np.sum(spins)

    z = (np.roll(spins, 1, axis=0) + np.roll(spins, -1, axis=0)
         + np.roll(spins, 1, axis=1) + np.roll(spins, -1, axis=1))

    eng = np.sum(-J * spins * z / 2) - B * mag

    return spins, eng, mag
    

@njit
def run_ising2d(spins, eng, mag, N, J, B, T, n_steps, out_freq):
    '''Basic Metropolis Monte Carlo simulator of the 2D Ising model
    ---
    spins:    (int, int) 2D numpy array of +/-1 spins; flipped in place
    eng, mag: total energy and magnetization of `spins` (e.g. as returned
              by init_ising2d); updated incrementally on each accepted flip
    N:        accepted so a params dict can be unpacked into the call; the
              lattice size actually used is len(spins)
    T, J, B:  (floats) corresponding to temperature, coupling and field variables
    n_steps:  (int), number of single-spin-flip attempts
    out_freq: (int), How often to save data
    ---
    Returns:
    S: list of 2D spin configurations (copies), one every out_freq steps
    E: per-spin energy (eng / N**2) at those steps
    M: per-spin magnetization (mag / N**2) at those steps
    '''

    #### Run MC Simulation
    S, E, M = [], [], []
    N = len(spins)
    for step in range(n_steps):

        # Pick random spin
        i, j = np.random.randint(N), np.random.randint(N)

        # Compute energy change resulting from a flip of spin at i,j
        z = spins[(i+1) % N, j] + spins[(i-1) % N, j] + spins[i, (j+1) % N] + spins[i, (j-1) % N]
        dE = 2*spins[i, j]*(J*z + B)
        dM = 2*spins[i, j]

        # Metropolis condition: accept with probability min(1, exp(-dE/T))
        if np.exp(-dE/T) > np.random.rand():
            spins[i, j] *= -1
            eng += dE
            mag += dM

        # Save Thermo data (energy goes to E, magnetization to M)
        if step % out_freq == 0:
            E.append(eng/N**2)
            M.append(mag/N**2)
            S.append(spins.copy())

    return S, E, M

In [None]:
# Simulation parameters. The key names must match run_ising2d's keyword
# arguments, because the dict is unpacked straight into that call below.
params = dict(
    N=20,
    J=1,
    B=0,
    T=4,
    n_steps=10000,
    out_freq=10,
)

# Draw a random initial lattice together with its energy and magnetization,
# then run the Metropolis sampler starting from that state.
spins, eng, mag = init_ising2d(params)
S, E, M = run_ising2d(spins, eng, mag, **params)