In [1]:
import torch
import numpy as np
from tqdm import tqdm
from analytical_expressions import local_energy
from torch.autograd.functional import jacobian
from torch.func import jacrev
import matplotlib.pyplot as plt
from torch.func import vmap

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
def psi(X):
    x = X[:3]
    y = X[3:6]
    alpha_1, alpha_2, alpha_3, alpha_4 = X[6:]
    r1 = torch.norm(x)
    r2 = torch.norm(y)
    r12 = torch.norm(x - y)

    term1 = torch.exp(-2 * (r1 + r2))
    term2 = 1 + 0.5 * r12 * torch.exp(-alpha_1 * r12)
    term3 = 1 + alpha_2 * (r1 + r2) * r12 + alpha_3 * (r1 - r2)**2 - alpha_4 * r12

    return term1 * term2 * term3

In [4]:
psi_vec = vmap(psi)

In [76]:
def metropolis(N: int, n_runs: int, alphas: torch.tensor):  
    """
    Vectorized metropolis loop
    Over N steps, for n_runs. 
    Alphas passes in must be of same dim as n_runs
    """  
    assert alphas.shape[0] == n_runs        
    L = 1
    r1 = (torch.rand(n_runs, 3, requires_grad=True) * 2 * L - L)
    r2 = (torch.rand(n_runs, 3, requires_grad=True) * 2 * L - L)
    max_steps = 500
    sampled_Xs = []
    rejection_ratio = 0

    for i in tqdm(range(N)):
        chose = torch.rand(n_runs).reshape(n_runs, 1)
        dummy = torch.rand(n_runs)

        perturbed_r1 = r1 + 0.5 * (torch.rand(n_runs, 3) * 2 * L - L)
        perturbed_r2 = r2 + 0.5 * (torch.rand(n_runs, 3) * 2 * L - L)

        r1_trial = torch.where(chose < 0.5, perturbed_r1, r1)
        r2_trial = torch.where(chose >= 0.5, perturbed_r2, r2)
        psi_val = psi_vec(torch.cat((r1, r2, alphas), axis=1))
        psi_trial_val = psi_vec(torch.cat((r1_trial, r2_trial, alphas), axis=1))      
        psi_ratio = psi_trial_val / psi_val

        density_comp = psi_trial_val >= psi_val
        dummy_comp = dummy < psi_ratio

        condition = density_comp + dummy_comp

        rejection_ratio += torch.where(condition, 1./N, 0.0)

        condition = condition.reshape(condition.shape[0], 1)

        # Careful with overwriting
        r1 = torch.where(condition, r1_trial, r1)
        r2 = torch.where(condition, r2_trial, r2)
                
        if i > max_steps:
            sampled_Xs.append(torch.cat((r1, r2, alphas), axis=1))

    return torch.stack(sampled_Xs)

In [6]:
local_e_vec = vmap(local_energy)
local_e_vec_vec = vmap(local_e_vec)

def get_local_energies(X):
    reshaped_X = X.reshape(
        X.shape[1], X.shape[0], X.shape[2])
    return local_e_vec_vec(reshaped_X)

def get_mean_energies(E):
    return torch.mean(torch.mean(E, dim=1))

In [7]:
def dE_dalpha(input):
    return jacrev(local_energy)(input)

dE_dalpha_vec = vmap(dE_dalpha)
dE_dalpha_vec_vec = vmap(dE_dalpha_vec)

def get_dE_dX(X):
    reshaped_X = X.reshape(
        X.shape[1], X.shape[0], X.shape[2])
    return dE_dalpha_vec_vec(X)

In [21]:
from gradient_expressions import get_psi_alpha

def get_gradients_from_expression(X_, E_):
    psi_alpha = vmap(get_psi_alpha)(X_)

    part_1 = psi_alpha - torch.mean(psi_alpha, axis=0)
    part_2 = E_ - torch.mean(E_)
    return torch.mean(part_1.T * part_2, axis=1)

dE_dalpha = vmap(get_gradients_from_expression)

In [8]:
alpha_1 = torch.tensor(1.013, dtype=torch.float64, requires_grad=True) # 1.013
alpha_2 = torch.tensor(0.2119, dtype=torch.float64, requires_grad=True)
alpha_3 = torch.tensor(0.1406, dtype=torch.float64, requires_grad=True)
alpha_4 = torch.tensor(0.003, dtype=torch.float64, requires_grad=True)

In [11]:
device = torch.device("cuda")
cpu = torch.device("cpu")

## Start of simulation

In [32]:
alpha_1 = torch.tensor(1.013, dtype=torch.float64, requires_grad=True) # 1.013
alpha_2 = torch.tensor(0.2119, dtype=torch.float64, requires_grad=True)
alpha_3 = torch.tensor(0.1406, dtype=torch.float64, requires_grad=True)
alpha_4 = torch.tensor(0.003, dtype=torch.float64, requires_grad=True)

In [100]:
n_steps = 200
mc_steps = 5000
alphas = torch.tensor([alpha_1, alpha_2, alpha_3, alpha_4]).unsqueeze(0).repeat(n_steps, 1)
sampled_Xs = metropolis(mc_steps, n_steps, alphas=alphas)

100%|██████████| 5000/5000 [00:06<00:00, 823.01it/s]


In [101]:
E = get_local_energies(sampled_Xs.to(device))
mean_E = get_mean_energies(E.to(cpu))
print(f"Mean energy is {torch.mean(torch.mean(E, axis=1))}")

Mean energy is -2.9016276879729537


In [102]:
torch.mean(E - torch.mean(E, axis=1).unsqueeze(1).repeat(1, E.shape[1]), axis=1)

tensor([-2.7559e-16,  2.7165e-16, -4.2642e-16, -2.9218e-16,  2.8902e-16,
        -9.0022e-17,  6.0015e-17, -1.6583e-16, -1.9268e-16, -9.4760e-18,
        -1.3977e-16, -3.5535e-18, -2.9297e-16,  2.4578e-16,  7.8967e-17,
         1.5793e-17, -5.2118e-17, -1.9386e-16,  9.4760e-18,  2.8744e-16,
        -5.6382e-16, -4.7656e-16, -3.0007e-17, -3.4745e-17, -2.1005e-16,
        -3.8733e-16, -1.1529e-16,  1.9584e-16,  1.2635e-16, -2.5269e-16,
        -3.2060e-16, -3.5693e-16,  2.9455e-16, -2.2032e-16, -3.5535e-17,
        -1.2003e-16, -1.7057e-16,  9.1601e-17, -3.3166e-16, -4.9038e-16,
        -2.8862e-16, -3.7272e-16, -1.5793e-16, -3.3877e-16, -3.4745e-17,
         3.6562e-16, -2.9376e-16, -1.0266e-16,  2.2979e-16, -1.1806e-16,
        -2.4361e-16, -2.0531e-17,  4.1142e-16, -2.3374e-16, -2.6888e-16,
        -9.4760e-17,  1.5162e-16, -2.5032e-16, -4.2326e-16, -2.2900e-17,
        -1.7531e-16,  4.5801e-17, -1.3030e-16, -2.4203e-16, -5.4566e-16,
        -1.6820e-16,  2.7875e-16, -1.9110e-16,  1.5

In [79]:
E.shape

torch.Size([30, 49499])

In [87]:
torch.mean(E[5])

tensor(-2.8646, dtype=torch.float64, grad_fn=<MeanBackward0>)

In [68]:
E.shape

torch.Size([10, 148999])

In [103]:
torch.mean(torch.mean(E, axis=1) - torch.mean(torch.mean(E, axis=1), axis=0))

tensor(1.4433e-16, device='cuda:0', dtype=torch.float64,
       grad_fn=<MeanBackward0>)

In [65]:
E[0] - torch.mean(E[0])

tensor([ 0.4783,  1.3553,  0.7382,  ..., -1.7697, -0.0597,  1.5159],
       dtype=torch.float64, grad_fn=<SubBackward0>)

In [61]:
variance = torch.sqrt((E - (torch.mean(E, axis=1)).unsqueeze(1).repeat(1, 148999)) ** 2)

In [62]:
torch.mean(torch.mean(variance, axis=1), axis=0)

tensor(1.2890, dtype=torch.float64, grad_fn=<MeanBackward1>)

In [43]:
torch.mean(E, axis=1)

tensor([-2.9562, -2.9060, -2.9096, -2.9558, -2.8629, -2.8733, -2.9745, -2.9979,
        -2.9435, -2.9074], dtype=torch.float64, grad_fn=<MeanBackward1>)

In [16]:
## Getting gradients
X_ = sampled_Xs[:,0]

In [95]:
gradients = get_dE_dX(sampled_Xs)

In [80]:
gradients.shape

torch.Size([8999, 5, 10])

In [87]:
reshaped_X = sampled_Xs.reshape(
        sampled_Xs.shape[1], sampled_Xs.shape[0], sampled_Xs.shape[2])

Energy value should be −2.901188

The actual value is −2.9037243770

## Optimization loop

In [10]:
import math

In [104]:
E_true = -2.9037243770

In [None]:
alpha_1 = torch.tensor(0.013, dtype=torch.float64, requires_grad=True) # 1.013
alpha_2 = torch.tensor(0.6419, dtype=torch.float64, requires_grad=True) # 0.2119
alpha_3 = torch.tensor(0.1406, dtype=torch.float64, requires_grad=True) # 0.1406
alpha_4 = torch.tensor(0.103, dtype=torch.float64, requires_grad=True) # 0.003

In [106]:
device = torch.device("cuda")
cpu = torch.device("cpu")

In [112]:
# Naive approach - define loss as true energy - found energy

epochs = 2000
alphas = [alpha_1, alpha_2, alpha_3, alpha_4]
losses = []
n_walkers = 200
met_steps = 5000
optimizer = torch.optim.Adam(alphas, lr=0.001)

for i in range(epochs):

    alphas_metropolis = torch.tensor(alphas).unsqueeze(0).repeat(n_walkers, 1)
    sampled_Xs = metropolis(met_steps, n_walkers, alphas=alphas_metropolis)

    with torch.no_grad():
        E = get_local_energies(sampled_Xs.to(device))
        mean_E = get_mean_energies(E)
        loss = torch.abs(E_true - mean_E)

        print(f"Mean energy is {mean_E}")
        print(f"Loss is {loss}")
        losses.append(loss.item())

        reshaped_X = sampled_Xs.reshape(
            sampled_Xs.shape[1], sampled_Xs.shape[0], sampled_Xs.shape[2])
        gradients = dE_dalpha(reshaped_X.to(device), E).to(cpu)

        # Now, assume you already have gradients computed externally:
        # Example: for step t, these are your gradients (replace with actual values)
        gradients = torch.mean(gradients, axis=0)
        external_grads = gradients.detach()

        # Step 3: Assign gradients manually
        for p, g in zip(alphas, external_grads):
            p.grad = g  # assign your externally computed gradient

    # Step 4: Optimizer step
    optimizer.step()
    optimizer.zero_grad()

    torch.cuda.empty_cache()
    del sampled_Xs
    del reshaped_X
    del E




100%|██████████| 5000/5000 [00:06<00:00, 813.79it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 

In [26]:
print(gradients)

tensor([-0.0752,  3.5492, -0.0362, -0.6221], dtype=torch.float64,
       grad_fn=<MeanBackward1>)


In [27]:
print(alphas)

[tensor(0.5331, dtype=torch.float64, requires_grad=True), tensor(-0.0537, dtype=torch.float64, requires_grad=True), tensor(0.7332, dtype=torch.float64, requires_grad=True), tensor(-0.0505, dtype=torch.float64, requires_grad=True)]


In [137]:
alpha_1 = torch.tensor(2.013, dtype=torch.float64, requires_grad=True) # 1.013
alpha_2 = torch.tensor(0.6419, dtype=torch.float64, requires_grad=True) # 0.2119
alpha_3 = torch.tensor(2.1406, dtype=torch.float64, requires_grad=True) # 0.1406
alpha_4 = torch.tensor(3.003, dtype=torch.float64, requires_grad=True) # 0.003

In [113]:
alpha_1

tensor(2.0130, dtype=torch.float64, requires_grad=True)

In [106]:
external_grads

tensor([-0.0080,  0.0429,  0.0155, -0.0305], dtype=torch.float64)

## Gradient values

In [34]:
def dE_dalpha(input):
    return jacrev(local_energy)(input)

t = dE_dalpha_vec(torch.stack(inputs_arr[0]))

In [72]:
dE_dalpha_mean = torch.mean(t, axis=0)

In [36]:
psi_vmap = vmap(psi)

In [38]:
psi_values = psi_vmap(torch.stack(inputs_arr[0]))

In [61]:
mean_energy = sum(energies[0])/(len(energies[0]))

In [63]:
El_Etheta = energies[0] - mean_energy

In [65]:
mean_psi = torch.mean(psi_values)

In [73]:
dE_dalpha_mean.shape

torch.Size([10])

In [86]:
t[0].shape

torch.Size([10])

In [89]:
psi_values.shape

torch.Size([9500])

In [92]:
psi_dalph = torch.stack([psi_values[i] * t[i] for i in range(len(t))])

In [97]:
psi_dalph.shape

torch.Size([9500, 10])

In [98]:
dE_dalpha_mean.shape

torch.Size([10])

In [95]:
mean_psi

tensor(0.0207, dtype=torch.float64, grad_fn=<MeanBackward0>)

In [77]:
t.shape

torch.Size([9500, 10])

In [111]:
a = psi_values.unsqueeze(1).repeat(1, 10) * t

In [110]:
b = (mean_psi * dE_dalpha_mean).unsqueeze(0).repeat(9500, 1)

In [119]:
c = (energies[0] - mean_energy).unsqueeze(1).repeat(1, 10)

In [115]:
mean_energy.shape

torch.Size([])

In [120]:
gradients = (a - b) * (c)

In [124]:
torch.mean(gradients, axis=0)

tensor([ 0.0652,  0.4838,  0.4329,  0.0146,  0.1017,  0.0883,  0.0167, -0.0222,
        -0.0307,  0.0271], dtype=torch.float64, grad_fn=<MeanBackward1>)

In [52]:
energies[0][1]

tensor(-1.6312, dtype=torch.float64, grad_fn=<SelectBackward0>)

In [53]:
E_fixed = [energies[0][i] / psi_values[i] for i in range(len(inputs_arr[0]))]

In [55]:
torch.mean(torch.stack(E_fixed))

tensor(-2546.7772, dtype=torch.float64, grad_fn=<MeanBackward0>)