# Example: optimizing the 2D Calogero–Sutherland (CS) model with a transformer (TF) ansatz

In [2]:
import sys
from pathlib import Path

# Make the repository root (the parent of the notebook's working directory)
# importable, so that the `src.*` and `utils.*` imports in the next cell
# resolve. The membership check keeps this cell idempotent: re-running it
# does not append the same entry to sys.path again.
repo_root = Path().resolve().parent
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

In [3]:
import jax
from jax import numpy as jnp

# Enable 64-bit floats before importing project modules. JAX's x64 flag is
# meant to be set at startup, before any arrays exist; any arrays the
# project modules below create at import time would otherwise default to
# 32-bit.
jax.config.update("jax_enable_x64", True)

from src.transformer_nqfs import TransformerNQFS
# NOTE(review): star-imports hide where names such as `minimize_energy`
# and the save helpers come from; explicit imports would be clearer.
# utils.plot_utils is imported last, so its names shadow any same-named
# ones from src.optimization. Keep this order.
from src.optimization import *
from utils.plot_utils import *

# NQFS initialization

In [7]:
# --- Transformer architecture -------------------------------------------
num_layers = 2           # transformer layers stacked in the network
embed_dim = 10           # size of each embedding vector
head_dim = 25            # size of each attention head
dim_feedforward = 100    # width of the feed-forward sub-layers
dim_out = 75             # output width of the network (NOTE(review): exact role defined in TransformerNQFS)

# --- Physical system -----------------------------------------------------
L = 10.0                 # linear system size
periodic = False         # non-periodic boundary conditions
phys_dim = 2             # number of spatial dimensions
m = 1/2                  # particle mass
seed = 43                # PRNG seed (also forwarded to the optimiser)
n_max = 10               # maximum number of particles in a configuration

# --- Jastrow factor and embedding ----------------------------------------
jastrow_type = "CS2D"    # Jastrow factor used by the model
embed_type = "Gaussian"  # embedding type used by the model
k = L/5                  # embedding parameter (if the embedding uses one)
g = 5.0                  # interaction strength: used by the Jastrow factor and by W

# --- Particle-number distribution initialisation -------------------------
# NOTE(review): the names suggest the initial mean and inverse-softplus
# width/slope of the particle-number distribution; confirm in TransformerNQFS.
q_n_mean_init = 5.0
q_n_inv_softplus_width_init = 3.0
q_n_inv_softplus_slope_init = 3.0

# TransformerNQFS receives its configuration positionally: keep this order.
nqfs = TransformerNQFS(
    num_layers, embed_dim, head_dim, dim_feedforward, dim_out,
    L, periodic, phys_dim, m, g, jastrow_type, embed_type, k,
    q_n_mean_init, q_n_inv_softplus_width_init, q_n_inv_softplus_slope_init,
)

# Initialize the model from a random dummy configuration of n_max particles
# in phys_dim dimensions.
rng = jax.random.PRNGKey(seed)
rng, inp_rng, init_rng = jax.random.split(rng, 3)
inp = jax.random.uniform(inp_rng, (n_max, phys_dim))

# A particle slot is valid when at least one of its coordinates is not NaN;
# NaN entries are then replaced by 0 so the network never sees NaNs.
mask_valid = jnp.any(~jnp.isnan(inp), axis=1)
inp = jnp.nan_to_num(inp)

params = nqfs.init(init_rng, inp, mask_valid)
nqfs.apply(params, inp, mask_valid)  # smoke-test forward pass; result unused

# Total number of scalar parameters across every leaf of the parameter pytree
print("Number of parameters: ", sum(leaf.size for leaf in jax.tree_util.tree_leaves(params)))

Number of parameters:  129854


# Definition of the external (V) and interaction (W) potentials

In [8]:
from itertools import product
omega = 1     # frequency of the harmonic confinement
mu = 25.0     # chemical potential: V subtracts mu once per valid particle


def V(x, mask_valid):
    """External potential energy of one configuration, including the -mu*N shift.

    Returns sum over valid particles p of (omega**2 * |x_p - L/2|**2 - mu):
    an isotropic harmonic trap centred at L/2 in every coordinate, minus mu
    for each particle present.

    x is assumed to hold one row per particle slot (TODO confirm shape
    (n, d)); mask_valid marks which rows are real particles.
    """
    centre = L/2
    displacement = x - centre
    trap_per_particle = jnp.sum(omega**2 * displacement**2, axis=-1)
    total = jnp.sum(trap_per_particle - mu, where=mask_valid)
    return total.squeeze()


def W(x, mask_valid):
    """Interaction energy of one configuration (2-body + 3-body terms).

    Computes, restricted to valid particles,

        sum_{i<j} g / |x_i - x_j|^2
      + sum_{i<j, k != i, j} g * (x_k - x_j).(x_k - x_i) / (|x_k - x_j|^2 * |x_k - x_i|^2)

    where `g` is the module-level interaction strength.

    Args:
        x: particle coordinates, one row per particle slot (assumed shape
           (n, d)); rows of invalid slots do not contribute.
        mask_valid: boolean array of length n, True for slots holding a
           real particle.

    Returns:
        0-d array with the total interaction energy. Each term is 0 when
        there are too few slots to form a pair / triplet.
    """
    n = x.shape[0]

    # --- 2-body term: all unordered pairs i < j ---
    row_idx, col_idx = jnp.triu_indices(n, k=1)
    diffs = x[row_idx] - x[col_idx]
    interparticle_seps = jnp.linalg.norm(diffs, axis=-1)
    upper_mask = mask_valid[row_idx] & mask_valid[col_idx]
    # Pairs involving an invalid slot may divide by zero (e.g. two padded
    # slots at the same position); `where=` drops them from the sum.
    W2 = jnp.sum(g / interparticle_seps**2, where=upper_mask)

    # --- 3-body term: i < j, k distinct from both ---
    # The index table depends only on the static size n, so it is built in
    # Python. reshape(-1, 3) keeps it 2-D when n < 3 (no triplets); without
    # it jnp.array([]) is 1-D and triplets[:, 0] raises an IndexError.
    triplets = jnp.array(
        [(a, b, c) for a, b, c in product(range(n), repeat=3)
         if a < b and c != a and c != b],
        dtype=jnp.int32,
    ).reshape(-1, 3)
    # Named *_idx so the global hyperparameter `k` is not shadowed.
    i_idx, j_idx, k_idx = triplets[:, 0], triplets[:, 1], triplets[:, 2]
    xi, xj, xk = x[i_idx], x[j_idx], x[k_idx]

    vec_kj = xk - xj
    vec_ki = xk - xi
    rkj = jnp.linalg.norm(vec_kj, axis=-1)**2  # squared distances
    rki = jnp.linalg.norm(vec_ki, axis=-1)**2

    f3 = g * jnp.sum(vec_kj * vec_ki, axis=-1) / (rkj * rki)

    mask3 = mask_valid[i_idx] & mask_valid[j_idx] & mask_valid[k_idx]
    W3 = jnp.sum(f3, where=mask3)
    return (W2 + W3).squeeze()

# Optimization

In [9]:
from jax.scipy.linalg import solve


def solver(A, b):
    """Linear solver passed to the optimiser: solve A x = b, assuming A is symmetric positive definite."""
    return solve(A, b, assume_a="pos")


n_samples = 50

# Start from freshly initialised parameters on every run of this cell:
# the optimiser call at the bottom overwrites `params` with its result.
params = nqfs.init(init_rng, inp, mask_valid)

params_dict = dict(
    w=0.05 * L,
    L=L,
    n_max=n_max,
    n_0=2,
    phys_dim=phys_dim,
    seed=seed,
    params=params,
    n_samples=n_samples,
    n_chains=100,
    warmup=50,
    sweep_size=30,
    m=m,
    mu=mu,
    V=V,
    W=W,
    lr=1e-2,
    lr_q=1e-2,
    n_iters=100,
    # both chunk sizes equal n_samples
    chunk_size=n_samples,
    chunk_grads_size=n_samples,
    optimizer_type="sgd",
    optimization="min_sr",
    diag_shift=1e-5,
    diag_shift_step=None,
    diag_shift_red=11,
    solver=solver,
    pm=0.25,
)

Es, E_stds, n_means, n_stds, params, KEs, VEs, WEs = minimize_energy(nqfs.apply, **params_dict)


Iteration: 1/100


KeyboardInterrupt: 

# Save the optimization data and the final ansatz parameters

In [None]:
# Persist the optimisation history and the final parameters; the file names
# are tagged with the value of mu.
save_Energy_number(
    Es, E_stds, n_means, n_stds,
    f"../results/TF_CS_2D_E_n_mu_{mu}",
)
save_optimized_params(params, f"../results/params_TF_CS_2D_mu_{mu}")