## Starter notebook for MLP ansatz for Helium

1. Energy - using Hessian,
2. Gradients - using known formula (update manually),
3. Optimization - ADAM.

First, non-symmetric, without Jastrow factor, to slowly add complexity

In [20]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from torch import vmap
from functorch import make_functional, vmap, grad

### Model definition

In [2]:
class MLP(nn.Module):

    def __init__(self, input_dim, n_hidden_layers, hidden_dim, output_size):
        super(MLP, self).__init__()

        layers = []

        # Input layer
        layers.append(nn.Linear(input_dim, hidden_dim))
        layers.append(nn.Tanh())

        # Hidden layers
        for _ in range(n_hidden_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.Tanh())

        # Output layer (no activation here by default)
        layers.append(nn.Linear(hidden_dim, output_size))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

In [80]:

input_dim = 6
n_hidden_layers = 2
hidden_dim = 32
output_size = 1

model = MLP(
    input_dim=input_dim,
    n_hidden_layers=n_hidden_layers,
    hidden_dim=hidden_dim,
    output_size=output_size
)

In [198]:
device = torch.device('cuda')
cpu = torch.device('cpu')

In [82]:
def psi(x):
    if next(model.parameters()).is_cuda:
        x.to(device)

    return model(x).squeeze()

### Metropolis sampling

In [89]:
def metropolis(N: int, n_runs: int):  

    L = 1
    r1 = (torch.rand(n_runs, 3) * 2 * L - L)
    r2 = (torch.rand(n_runs, 3) * 2 * L - L)
    sampled_Xs = []
    rejection_ratio = 0

    for i in tqdm(range(N)):
        chose = torch.rand(n_runs).reshape(n_runs, 1)
        dummy = torch.rand(n_runs)

        perturbed_r1 = r1 + 0.5 * (torch.rand(n_runs, 3) * 2 * L - L)
        perturbed_r2 = r2 + 0.5 * (torch.rand(n_runs, 3) * 2 * L - L)

        r1_trial = torch.where(chose < 0.5, perturbed_r1, r1)
        r2_trial = torch.where(chose >= 0.5, perturbed_r2, r2)
        psi_val = psi(torch.cat((r1, r2), axis=1))
        psi_trial_val = psi(torch.cat((r1_trial, r2_trial), axis=1))

        
        psi_ratio = (psi_trial_val / psi_val) ** 2

        dummy_comp = psi_ratio > dummy

        condition = dummy_comp

        rejection_ratio += torch.where(condition, 1./N, 0.0)

        condition = condition.reshape(condition.shape[0], 1)

        r1 = torch.where(condition, r1_trial, r1)
        r2 = torch.where(condition, r2_trial, r2)
                
    
        sampled_Xs.append(torch.cat((r1, r2), axis=1))

    return torch.stack(sampled_Xs)

### Local energy

In [99]:
def local_energy(positions):
    
    positions.requires_grad_(True)
    psi_val = psi(positions)

    grads = torch.autograd.grad(psi_val.sum(), positions, create_graph=True)[0]

    laplacian = 0
    for i in range(positions.shape[1]):
        grad_i = grads[:, i]
        grad2 = torch.autograd.grad(grad_i.sum(), positions, create_graph=True)[0][:, i]
        laplacian += grad2

    kinetic = -0.5 * (laplacian) / psi_val 

    r1 = positions[:, 0:3]
    r2 = positions[:, 3:6]
    r1_norm = torch.norm(r1, dim=1)
    r2_norm = torch.norm(r2, dim=1)
    r12 = torch.norm(r1 - r2, dim=1)

    potential = -2 / r1_norm - 2 / r2_norm + 1 / r12

    E_local = kinetic + potential
    return E_local


def get_local_energy(sampled_Xs):

    mc_steps = sampled_Xs.shape[0]
    walkers = sampled_Xs.shape[1]
    reshaped_Xs = sampled_Xs.permute(1, 0, 2) # N_walkers, N, input_dim
    flattened_Xs = reshaped_Xs.flatten(end_dim=1) # N_walkers * N, input_dim
    local_E = local_energy(flattened_Xs) # N_walkers * N, 1
    return local_E.reshape(walkers, mc_steps)
 

### Gradients

In [182]:

def parameter_gradients(x, E, network, n_walkers, mc_steps):
    
    fmodel, params = make_functional(network)

    def psi_func(params, x):
        return fmodel(params, x.unsqueeze(0)).squeeze()

    grad_log_psi = grad(psi_func)

    grads = vmap(grad_log_psi, in_dims=(None, 0))(params, x)
    flat_grads = torch.cat([g.reshape(x.shape[0], -1) for g in grads], dim=1)

    n_parameters = flat_grads.shape[-1]

    mean_grad = flat_grads.mean(dim=1)
    mean_E = E.mean(dim=1, keepdim=True)

    centered_E = E - mean_E
    centered_grads = flat_grads - mean_grad.unsqueeze(1)

    return torch.mean(((centered_grads.T * centered_E.flatten())).reshape(n_walkers, mc_steps, n_parameters), axis=1)

def get_parameter_gradients(sampled_Xs, local_E, network):

    mc_steps = sampled_Xs.shape[0]
    n_walkers = sampled_Xs.shape[1]
    reshaped_Xs = sampled_Xs.permute(1, 0, 2) # N_walkers, N, input_dim
    flattened_Xs = reshaped_Xs.flatten(end_dim=1) # N_walkers * N, input_dim

    return parameter_gradients(flattened_Xs, local_E, network, n_walkers, mc_steps)


### Assigning gradients to model

In [199]:
def assign_gradients_to_model(parameter_gradients, model):
    """Assign a flattened gradient vector to model parameters."""
    pointer = 0
    for p in model.parameters():
        numel = p.numel()
        p.grad = parameter_gradients[pointer:pointer + numel].view_as(p).clone()
        pointer += numel

    return model

### Training loop

In [212]:

for i in range(100):
    N = 5000
    n_walkers = 5

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    sampled_Xs = metropolis(N, n_walkers)[500:]

    local_E = get_local_energy(sampled_Xs)
    print(f"Mean energy is {torch.mean(local_E)}")

    grads = get_parameter_gradients(sampled_Xs, local_E, model.to(cpu))
    model = assign_gradients_to_model(grads[0], model)

    optimizer.step()

100%|██████████| 5000/5000 [00:02<00:00, 1965.27it/s]
  warn_deprecated('make_functional', 'torch.func.functional_call')
  warn_deprecated('grad')
  warn_deprecated('vmap', 'torch.vmap')


Mean energy is -0.28386828303337097


100%|██████████| 5000/5000 [00:02<00:00, 1870.83it/s]


Mean energy is -0.23376259207725525


100%|██████████| 5000/5000 [00:02<00:00, 1906.08it/s]


Mean energy is -0.23676703870296478


100%|██████████| 5000/5000 [00:02<00:00, 1887.01it/s]


Mean energy is -0.28802844882011414


100%|██████████| 5000/5000 [00:02<00:00, 1944.78it/s]


Mean energy is -0.27324503660202026


100%|██████████| 5000/5000 [00:02<00:00, 1773.90it/s]


Mean energy is -0.28660693764686584


100%|██████████| 5000/5000 [00:02<00:00, 1955.43it/s]


Mean energy is -0.21810366213321686


100%|██████████| 5000/5000 [00:02<00:00, 1944.40it/s]


Mean energy is -0.22769327461719513


100%|██████████| 5000/5000 [00:02<00:00, 1936.31it/s]


Mean energy is -0.2886466681957245


100%|██████████| 5000/5000 [00:02<00:00, 1947.11it/s]


Mean energy is -0.27718785405158997


100%|██████████| 5000/5000 [00:02<00:00, 1797.92it/s]


Mean energy is -0.23017017543315887


KeyboardInterrupt: 

Incoporate skipping the first n steps in local energy and gradient calulations