## Starter notebook for MLP ansatz for Helium

1. Energy - using Hessian,
2. Gradients - using known formula (update manually),
3. Optimization - ADAM.

First, non-symmetric, without Jastrow factor, to slowly add complexity

In [5]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from torch import vmap

### Model definition

In [6]:
class MLP(nn.Module):

    def __init__(self, input_dim, n_hidden_layers, hidden_dim, output_size):
        super(MLP, self).__init__()

        layers = []

        # Input layer
        layers.append(nn.Linear(input_dim, hidden_dim))
        layers.append(nn.Tanh())

        # Hidden layers
        for _ in range(n_hidden_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.Tanh())

        # Output layer (no activation here by default)
        layers.append(nn.Linear(hidden_dim, output_size))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

In [8]:

input_dim = 6
n_hidden_layers = 2
hidden_dim = 32
output_size = 1

network = MLP(
    input_dim=input_dim,
    n_hidden_layers=n_hidden_layers,
    hidden_dim=hidden_dim,
    output_size=output_size
)

### Metropolis sampling

In [9]:
def metropolis(N: int, n_runs: int, model: nn.Module):  
    """
    Vectorized metropolis loop
    Over N steps, for n_runs. 
    Alphas passes in must be of same dim as n_runs
    """       
    L = 1
    r1 = (torch.rand(n_runs, 3) * 2 * L - L)
    r2 = (torch.rand(n_runs, 3) * 2 * L - L)
    max_steps = 500
    sampled_Xs = []
    rejection_ratio = 0

    for i in tqdm(range(N)):
        chose = torch.rand(n_runs).reshape(n_runs, 1)
        dummy = torch.rand(n_runs)

        perturbed_r1 = r1 + 0.5 * (torch.rand(n_runs, 3) * 2 * L - L)
        perturbed_r2 = r2 + 0.5 * (torch.rand(n_runs, 3) * 2 * L - L)

        r1_trial = torch.where(chose < 0.5, perturbed_r1, r1)
        r2_trial = torch.where(chose >= 0.5, perturbed_r2, r2)
        psi_val = model(torch.cat((r1, r2), axis=1)).squeeze()
        psi_trial_val = model(torch.cat((r1_trial, r2_trial), axis=1)).squeeze()

        
        psi_ratio = (psi_trial_val / psi_val) ** 2

        dummy_comp = psi_ratio > dummy

        condition = dummy_comp

        rejection_ratio += torch.where(condition, 1./N, 0.0)

        condition = condition.reshape(condition.shape[0], 1)

        r1 = torch.where(condition, r1_trial, r1)
        r2 = torch.where(condition, r2_trial, r2)
                
        if i > max_steps:
            sampled_Xs.append(torch.cat((r1, r2), axis=1))

    return torch.stack(sampled_Xs)

### Local energy

In [18]:
# Start with the simplest one - all the positions

def local_energy(positions):
    # positions: [batch_size, 6] with [r1x, r1y, r1z, r2x, r2y, r2z]
    positions.requires_grad_(True)
    psi = network(positions).squeeze()

    # Gradient of log_psi w.r.t positions
    grads = torch.autograd.grad(psi.sum(), positions, create_graph=True)[0]

    # Laplacian: second derivative (sum of second partials)
    laplacian = 0
    for i in range(positions.shape[1]):
        grad_i = grads[:, i]
        grad2 = torch.autograd.grad(grad_i.sum(), positions, create_graph=True)[0][:, i]
        laplacian += grad2

    # Kinetic energy
    kinetic = -0.5 * (laplacian) / psi 

    # Reshape positions
    r1 = positions[:, 0:3]
    r2 = positions[:, 3:6]
    r1_norm = torch.norm(r1, dim=1)
    r2_norm = torch.norm(r2, dim=1)
    r12 = torch.norm(r1 - r2, dim=1)

    potential = -2 / r1_norm - 2 / r2_norm + 1 / r12

    E_local = kinetic + potential
    return E_local


### Gradients

In [None]:
from functorch import make_functional, vmap, grad

def get_gradients_walkers(x, E, network):
    """
    x: [n_walkers, batch_size, 6]
    E: [n_walkers, batch_size]
    Returns:
        gradients: [n_walkers, n_parameters]
    """
    n_walkers, batch_size, _ = x.shape

    # Convert model to functional
    fmodel, params = make_functional(network)

    # Define single-sample log Ψ
    def log_psi_single(params, x):
        return fmodel(params, x)

    # Compute ∇θ log Ψ per sample
    grads_per_sample = vmap(grad(log_psi_single), in_dims=(None, 0))(params, x)

    # Flatten parameter structure: list of tensors → single [n_walkers * batch_size, n_params]
    grads_flat = torch.cat([g.reshape(x_flat.shape[0], -1) for g in grads_per_sample], dim=1)

    # Reshape grads and E to [n_walkers, batch_size, n_params]
    grads_flat = grads_flat.view(n_walkers, batch_size, -1)
    E = E.view(n_walkers, batch_size)

    # Mean gradients and energies per walker
    mean_grad = grads_flat.mean(dim=1)      # [n_walkers, n_params]
    mean_E = E.mean(dim=1, keepdim=True)    # [n_walkers, 1]

    # Compute centered quantities
    centered_E = E - mean_E                 # [n_walkers, batch_size]
    centered_grads = grads_flat - mean_grad.unsqueeze(1)  # broadcast

    # Final VMC gradient: [n_walkers, n_params]
    vmc_grad = torch.mean(centered_E.unsqueeze(2) * centered_grads, dim=1)

    return vmc_grad  # [n_walkers, n_parameters]


In [118]:
fmodel, params = make_functional(network)

  warn_deprecated('make_functional', 'torch.func.functional_call')


In [137]:
def inf(x):
    return fmodel(params, x)

In [130]:
t = vmap(grad(inf), in_dims=(None, 0))

  warn_deprecated('grad')
  warn_deprecated('vmap', 'torch.vmap')


In [144]:
def mode(params, x):
    return fmodel(params, x).squeeze()

In [147]:
mode(params, x).shape

torch.Size([4499])

In [151]:
x.requires_grad_(True)

tensor([[ -8.1484,   4.6863,  -1.5537,  -0.3887,   1.7971,   2.5163],
        [ -8.1484,   4.6863,  -1.5537,  -0.1113,   1.4717,   2.9053],
        [ -8.1484,   4.6863,  -1.5537,   0.2571,   1.7718,   3.0985],
        ...,
        [-17.7867, -13.3441, -11.6457, -10.5974,   9.3614,   4.9027],
        [-17.7867, -13.3441, -11.6457, -10.4786,   9.2392,   4.6427],
        [-17.7867, -13.3441, -11.6457, -10.8582,   9.0171,   4.4217]],
       requires_grad=True)

### Calling and testing

In [159]:
from functorch import make_functional

fmodel, params = make_functional(network)

def f(params, x):
    return fmodel(params, x.unsqueeze(0)).squeeze()

grad_f = grad(f)
grads = vmap(grad_f, in_dims=(None, 0))(params, x)  # list of [batch_size, ...]


  warn_deprecated('make_functional', 'torch.func.functional_call')
  warn_deprecated('grad')
  warn_deprecated('vmap', 'torch.vmap')


In [172]:
for p in params:
    print(p[0].grad)

None
None
None
None
None
None


  print(p[0].grad)


In [165]:
flat_grads = torch.cat([g.view(-1) for g in grads])

In [167]:
flat_grads.shape

torch.Size([5907187])

In [162]:
grads[0].shape

torch.Size([4499, 32, 6])

In [163]:
grads[1].shape

torch.Size([4499, 32])

In [164]:
len(grads)

6

In [158]:
grads

tensor([[-0.0283, -0.0063,  0.0159, -0.0124, -0.0445,  0.0456],
        [-0.0109, -0.0166,  0.0169,  0.0162, -0.0520,  0.0775],
        [-0.0458, -0.0235,  0.0074,  0.0154, -0.0717,  0.1039],
        [-0.0499, -0.0336, -0.0085,  0.0265, -0.0578,  0.0791],
        [ 0.0020, -0.0320,  0.0385,  0.0206, -0.0658,  0.0646],
        [-0.0322, -0.0060,  0.0163,  0.0084, -0.0385,  0.0681],
        [-0.0312, -0.0502,  0.0234, -0.0011, -0.0780,  0.0813],
        [-0.0636, -0.0264, -0.0113,  0.0265, -0.0612,  0.0998],
        [-0.0243, -0.0439,  0.0143,  0.0185, -0.0659,  0.0828],
        [-0.0272, -0.0013, -0.0066,  0.0192, -0.0511,  0.0587]],
       grad_fn=<SqueezeBackward2>)

In [154]:
grads

tensor([[ 4.0234, -1.2570, -2.6814,  0.7713,  1.9691],
        [-4.0000, -2.2121, -0.6792, -0.9052, -1.9934],
        [ 0.8249,  1.9552, -2.2715, -0.8959, -0.3449],
        [-2.1716, -0.6094,  1.7183,  0.0491, -2.6119],
        [ 2.0293,  3.6102,  1.3915,  3.0094, -0.2820],
        [-2.3896, -3.3258,  0.4831,  0.4983, -0.9949],
        [ 0.5123, -2.2230, -2.4487,  3.0820,  1.6691],
        [ 0.9562, -1.6421,  2.3684, -1.9525, -2.6460],
        [ 4.3113,  0.1937, -4.5661, -0.8182,  3.2101],
        [ 0.0471,  4.0138, -0.4707,  0.2721,  0.0478]])

In [20]:
N = 5000
n_walkers = 50

In [None]:
sampled_Xs = metropolis(N, n_walkers, network)

100%|██████████| 5000/5000 [00:02<00:00, 1981.40it/s]


In [21]:
local_E = torch.stack([local_energy(sampled_Xs[:, i]) for i in range(n_walkers)])

In [111]:
reshaped_Xs = sampled_Xs.permute(1, 0, 2)

In [113]:
reshaped_Xs.shape

torch.Size([50, 4499, 6])

In [117]:
local_E.shape

torch.Size([50, 4499])

In [116]:
grads = get_gradients_walkers(reshaped_Xs, local_E, network)

  warn_deprecated('make_functional', 'torch.func.functional_call')


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [59]:
x = sampled_Xs[:, 0]

In [60]:
x.shape

torch.Size([4499, 6])

In [72]:
model = network
E = local_E
x = sampled_Xs[:, 0]

In [79]:
x.shape

torch.Size([4499, 6])

In [80]:
N

5000

In [92]:
local_E[0].shape

torch.Size([4499])

In [96]:
batch_size = 4499
param_shapes = [p.shape for p in network.parameters()]
n_params = sum(p.numel() for p in network.parameters())

grads = torch.zeros(batch_size, n_params)

for i in range(batch_size):
    xi = x[i].detach().clone().requires_grad_(True)

    log_psi_i = network(xi).squeeze()
    grad_i = torch.autograd.grad(
        log_psi_i, 
        network.parameters(),
        retain_graph=True,
        create_graph=True
    )

# Flatten and concat all parameter gradients for this sample
grad_i_flat = torch.cat([g.reshape(-1) for g in grad_i])
grads[i] = grad_i_flat

mean_grad_psi = torch.mean(grads)
mean_E = torch.mean(E)

part_a = grads - mean_grad_psi
part_b = E - mean_E

t = torch.mean(part_a.T * part_b[0], axis=1)


In [None]:
t.shap

torch.Size([1313])

In [83]:
grads.shape

torch.Size([4499, 1313])

In [84]:
part_b.shape

torch.Size([50, 4499])

In [None]:
grads.T

torch.Size([1313, 4499])