In [1]:
import sys
import os
import math

import time
import datetime as dt

import torch
from torch import nn

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from itertools import combinations

# Enable model differentiation by X

In [2]:
class diff_model_enable:
  def __init__(self, model, X):
    self.model = model
    self.X = X
  
  def __enter__(self):
    self.X.requires_grad = True
  
  def __exit__(self, exc_type, exc_val, exc_tb):
    self.X.requires_grad = False

In [3]:
def diff2(f, X):
  grads = torch.autograd.grad(
    outputs=f, inputs=X, grad_outputs=torch.ones_like(f),
    create_graph=True
  )[0]
  grad_grad = []
  for i in range(X.shape[1]):
    df = grads[:,i]
    df2 = torch.autograd.grad(
      outputs=df, inputs=X, grad_outputs=torch.ones_like(df),
      create_graph=True
    )[0][:,i]
    grad_grad.append(df2)
  return torch.stack(grad_grad, dim=1)

# Metropolis sampler

In [4]:
class MetropolisSampler():
  """Class for Metropolis sampler"""

  def __init__(self, dim1, dim2, epsilon=0.1, device='cpu'):
    self.epsilon = epsilon
    self.device = device
    self.sample = self.initialSample(dim1, dim2)

  def initialSample(self, dim1, dim2) -> torch.Tensor:
    return 3 * (
      torch.rand(
        (
          dim1, dim2
        )
      ) - 0.5
    ).to(self.device)

  def updateSampleBasOnDistrDens(self, __distributionDensity):
    newSample = self.sample + self.epsilon * (
      2 * torch.rand_like(
        self.sample,
        device=self.device
      ) - 1
    )
    critVal = __distributionDensity(newSample) / __distributionDensity(self.sample)
    doesPointMove = (torch.rand(len(self.sample), device=self.device) <= critVal)
    newSample = (
      torch.mul(doesPointMove.int(), newSample.t()).t()
      + torch.mul((1 - doesPointMove.int()), self.sample.t()).t()
    )
    self.sample = newSample

  def updateAndGetSample(self, __distributionDensity):
    for i in range(10):
      self.updateSampleBasOnDistrDens(__distributionDensity)
    return self.sample

# Integrate

In [5]:
def integrate(f, x: torch.Tensor, density) -> torch.Tensor:
  return torch.mean(f / density)

# Trial function

In [6]:
class TrialFunction(nn.Module):
  def __init__(self, dim_coord, num_states, potential, name,
               num_hidden_layers=3, num_hidden_neurons=60,
               init_mean_weights=0.0, init_std_weights=math.sqrt(0.1),
               activ_fnc=nn.Tanh()):
    super(TrialFunction, self).__init__()
    
    self.dim_coord = dim_coord
    self.num_states = num_states
    self.potential = potential

    self.name = name
    
    self.num_hidden_layers = num_hidden_layers
    self.num_hidden_neurons = num_hidden_neurons
    self.activ_fnc = activ_fnc
    # Layers
    self.layers = nn.Sequential()
    self.layers.append(nn.Linear(dim_coord, num_hidden_neurons))
    self.layers.append(activ_fnc)
    for layer in range(num_hidden_layers-1):
      self.layers.append(nn.Linear(num_hidden_neurons, num_hidden_neurons))
      self.layers.append(activ_fnc)
    # Gaussian weights
    self.gaussian_weights= nn.Linear(dim_coord, num_hidden_neurons, bias=False)
    # Out layer
    self.out_layer = nn.ModuleList(
      [nn.Linear(num_hidden_neurons, 1, bias=False) for state in range(num_states)]
    )
    # Initialise weigths
    self.init_weights(init_mean_weights, init_std_weights)

  def init_weights(self, init_mean_weights, init_std_weights):
    for layer in range(len(self.layers)):
      if 'weight' in dir(self.layers[layer]):
        nn.init.normal_(self.layers[layer].weight, init_mean_weights, init_std_weights)
      if 'bias' in dir(self.layers[layer]):
        nn.init.normal_(self.layers[layer].bias, init_mean_weights, init_std_weights)

    for layer in range(len(self.out_layer)):
      if 'weight' in dir(self.out_layer[layer]):
        nn.init.normal_(
          self.out_layer[layer].weight, 
          1.0/math.sqrt(self.num_hidden_neurons), #init_mean_weights, 
          0.01 #init_std_weights
        )

  def forward(self, x):
    z = self.layers(x)
    
    sqr_gauss_weights = self.gaussian_weights.weight*self.gaussian_weights.weight
    gauss_kernel = torch.exp(-torch.matmul(x*x, torch.transpose(sqr_gauss_weights, 0, 1)))

    z_prime = gauss_kernel * z
    return [psi_n(z_prime) for psi_n in self.out_layer]

  def weigth_function(self, x):
    forward_ = self.forward(x)
    result = torch.zeros(len(x)).to(x.device)
    for psi_n in forward_:
      result += psi_n.squeeze(1) ** 2
    result /= len(forward_)
    return result

# Losses

In [7]:
def cache_psi(cache_dict):
  def wrapped(func):
    id_val = cache_dict
    def inner(*args):
      curr = id(args[0])
      if curr not in id_val:
        id_val[curr] = func(*args)
      return id_val[curr]
    return inner
  return wrapped  

CacheHamilt = dict()
@cache_psi(CacheHamilt)
def Hamiltonian(psi, x, potential):
  return (
    -0.5 * torch.sum(diff2(psi, x), dim=1).unsqueeze(1)
    + (potential(x) * psi.squeeze(1)).unsqueeze(1)
  )

CacheRayleigh = dict()
@cache_psi(CacheRayleigh)
def Rayleigh(psi, x, potential, density):
  return (
    integrate(psi*Hamiltonian(psi, x, potential), x, density)
    / integrate(psi*psi, x, density)
  )

def SqrResidual(output, x, potential, density):
  return sum(
    integrate((Hamiltonian(psi, x, potential) - Rayleigh(psi, x, potential, density)*psi)**2, x, density)
    / integrate(psi*psi, x, density) for psi in output
  )

def SqrResidualForPsi(psi, x, potential, density):
  return (
    integrate((Hamiltonian(psi, x, potential) - Rayleigh(psi, x, potential)*psi)**2, x, density)
    / integrate(psi*psi, x, density)
  ) 

def NormCond(trial_fnc):
  return sum((torch.sum(A*A, dim=1) - 1)**2 for A in trial_fnc.out_layer.parameters())

def OrthogonCond(output, x, density):
  result = torch.tensor([0.0]).to(x.device)
  for psi_n, psi_m in combinations(output, 2):
    result += (
      torch.square(integrate(psi_n*psi_m, x, density)) 
      / integrate(psi_n*psi_n, x, density) 
      / integrate(psi_m*psi_m, x, density)
    )
  return result

# Checkpoint

In [8]:
def save_checkpoint(model, optimizer, step, loss):
  checkpoint = {
    'model_state_dict' : model.state_dict(),
    'optimizer_state_dict' : optimizer.state_dict(),
    'step' : step,
    'loss' : loss
  }
  fn = f"checkpoints/{model.name}.pt"
  torch.save(checkpoint, fn)

# Test (He)

In [9]:
dim_coord = 6
num_states = 16
potential = lambda x: (
  -2/torch.sqrt(torch.sum(x[:,:3]**2, dim=1))
  -2/torch.sqrt(torch.sum(x[:,3:]**2, dim=1))
  +1/torch.sqrt(torch.sum((x[:,:3]-x[:,3:])**2, dim=1))
)
num_hidden_layers = 3
num_hidden_neurons = 60
init_mean_weights = 0.0
init_std_weights = math.sqrt(0.1)
activ_fnc = nn.Tanh()

alpha = 2
beta = 1
gamma = 40
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
S = 2000
sampler = MetropolisSampler(dim1=S, dim2=dim_coord, epsilon=0.1, device=device)

lr = 1e-3
weight_decay = 5e-4

CHECKPOINT_PERIOD = 10

MAX_SQR_RESIDUAL = 1e-3

load_checkpoint = False

name = "test_He"

In [10]:
START_TIME = time.time()
  
hist = {
  'time' : list(),
  'loss' : list(),
  'sqr res' : list(),
  'norm cond' : list(),
  'orthogon cond' : list(),
}
for state in range(num_states):
  hist[f"Rayleigh{state}"] = list()

model = TrialFunction(dim_coord, num_states, potential, name,
               num_hidden_layers, num_hidden_neurons,
               init_mean_weights, init_std_weights,
               activ_fnc).to(device)

optimizer = torch.optim.Adam(
  params=list(model.parameters()),
  lr=lr,
  weight_decay=weight_decay
)

if load_checkpoint:
  checkpoint = torch.load(f"checkpoints/{model.name}.pt")
  model.load_state_dict(checkpoint['model_state_dict'])
  model.train()
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  step = checkpoint['step']
  loss = checkpoint['loss']
  print(f"Load model and optimizer: step {checkpoint['step']}, loss {checkpoint['loss']}")
  # history
  hist_df = pd.read_csv(f"history/{model.name}.csv")
  hist['time'] = hist_df['time'].to_list()
  hist['loss'] = hist_df['loss'].to_list()
  hist['sqr res'] = hist_df['sqr res'].to_list()
  hist['norm cond'] = hist_df['norm cond'].to_list()
  hist['orthogon cond'] = hist_df['orthogon cond'].to_list()
  for state in range(num_states):
    hist[f"Rayleigh{state}"] = hist_df[f"Rayleigh{state}"].to_list()

x = sampler.updateAndGetSample(model.weigth_function)
step = 0

valid_states = 0
while (valid_states < num_states):
  step += 1
  x = sampler.updateAndGetSample(model.weigth_function)

  with diff_model_enable(model, x):
    output = model(x)
    density = model.weigth_function(x)
    
    sqr_residual = SqrResidual(output, x, potential, density)
    
    energies = [Rayleigh(psi, x, potential, density) for psi in output]
    energ_cond = alpha * sum(energies)
    
    norm_cond = beta * NormCond(model)
    
    orthogon_cond = gamma * OrthogonCond(output, x, density)
    
    loss = (
      sqr_residual 
      + energ_cond 
      + norm_cond 
      + orthogon_cond
    )

    valid_states = sum(
      (SqrResidualForPsi(psi, x, potential, density).detach().cpu() < MAX_SQR_RESIDUAL) for psi in output
    )
  
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  CacheRayleigh.clear()
  CacheHamilt.clear()

  hist['time'] += [time.time() - START_TIME]
  hist['loss'] += [loss.item()]
  hist['sqr res'] += [sqr_residual.item()]
  hist['norm cond'] += [norm_cond.item()]
  hist['orthogon cond'] += [orthogon_cond.item()]
  for state in range(num_states):
    hist[f"Rayleigh{state}"] += [energies[state].item()]

  if step == 1:
    header = (
        "step, time, loss, sqr res,"
        + " norm cond, orthogon cond,"
    )
    for i in range(num_states):
      header += f", Rayleigh{i}"
    print(header)
  
  if (step % CHECKPOINT_PERIOD) == 0 or step == 1:            
    info = (
      f"{step}, "
      + f"{time.time() - START_TIME:.2f}, "
      + f"{loss.item():.2e}, "
      + f"{sqr_residual.item():.2e}, "
      + f"{norm_cond.item():.2e}, "
      + f"{orthogon_cond.item():.2e},"
    )
    for Rayleigh_ in energies:
      info += f", {Rayleigh_.item():.2e}"
    print(info)
      
    hist_df = pd.DataFrame(data=hist)
    hist_df.to_csv(f"history/{model.name}.csv")

    save_checkpoint(model, optimizer, step, loss)

step, time, loss, sqr res, norm cond, orthogon cond,, Rayleigh0, Rayleigh1, Rayleigh2, Rayleigh3, Rayleigh4, Rayleigh5, Rayleigh6, Rayleigh7, Rayleigh8, Rayleigh9, Rayleigh10, Rayleigh11, Rayleigh12, Rayleigh13, Rayleigh14, Rayleigh15
1, 5.37, 4.80e+03, 8.78e+01, 4.38e-03, 4.73e+03,, -4.90e-01, -4.90e-01, -5.21e-01, -4.66e-01, -5.30e-01, -5.40e-01, -5.08e-01, -4.46e-01, -4.94e-01, -5.31e-01, -5.17e-01, -4.80e-01, -4.86e-01, -5.01e-01, -4.53e-01, -5.20e-01
10, 45.82, 2.83e+03, 1.20e+02, 4.67e-03, 2.71e+03,, -2.29e-01, -1.54e-03, 4.22e-01, 4.59e-01, 1.02e-01, 1.90e-02, -1.72e-01, 3.95e-01, 3.03e-02, 8.98e-02, 1.60e-01, 2.57e-01, -1.16e-01, 1.35e-02, 9.11e-02, -2.86e-01
20, 91.30, 1.63e+03, 1.07e+02, 8.15e-03, 1.56e+03,, -1.45e+00, -1.40e+00, -1.23e+00, -1.16e+00, -7.24e-01, -7.96e-01, -8.86e-01, -1.13e+00, -1.25e+00, -5.99e-01, -4.51e-01, -5.56e-01, -9.11e-01, -1.18e+00, -5.20e-01, -1.00e+00
30, 137.07, 1.23e+03, 7.80e+01, 1.33e-02, 1.19e+03,, -1.77e+00, -1.72e+00, -1.62e+00, -1.53e+00, 

KeyboardInterrupt: 

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

hist = pd.read_csv("history/test_He.csv")
xmax = len(hist)
xmin = 0 #xmax - 1000

fig, ax = plt.subplots(2,1,figsize=(16,7))

hist[['sqr res', 'norm cond',
       'orthogon cond']][xmin:xmax].plot(xlim=(xmin, xmax), ax = ax[0])
# ax[0].plot(np.arange(xmin,xmax), np.sum(hist[[f'Rayleigh{i}' for i in range(16)]][xmin:xmax].to_numpy(), axis=1))
ax[0].set_yscale('log')
ax[0].grid(axis='y')

hist[[f'Rayleigh{i}' for i in range(16)]][xmin:xmax].plot(xlim=(xmin, xmax), legend=False, ax = ax[1], alpha=0.7)
ax[1].axhline(y=0, color='k', linestyle='--')
ax[1].grid(axis='y')
# ax[1].set_yscale('symlog')