In [1]:
import sys
import os
import math

import time
import datetime as dt

import torch
from torch import nn

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from itertools import combinations

In [None]:
def diff2(f, X, dx=1e-4):
  """Estimate the unmixed second derivatives of ``f`` along every column of ``X``.

  For each column ``i`` of ``X`` the central second difference

      (f(X + dx*e_i) - 2*f(X) + f(X - dx*e_i)) / dx**2

  is evaluated, where ``e_i`` shifts only column ``i``.

  Args:
    f: callable taking a tensor shaped like ``X`` and returning a tensor.
    X: tensor with at least two dimensions; columns are indexed as ``X[:, i]``.
    dx: finite-difference step (defaults to the previously hard-coded 1e-4).

  Returns:
    The per-column estimates stacked along ``dim=1``.

  Note:
    The numerator is divided by ``dx**2`` (1e-8 by default), so round-off in
    ``f`` is amplified by that factor; in float32 the result can be dominated
    by cancellation error. Use float64 inputs or a larger ``dx`` if accuracy
    matters.
  """
  # f(X) does not depend on the shifted column, so evaluate it only once.
  f_center = f(X)
  second_derivs = []
  for i in range(X.shape[1]):
    shift = torch.zeros_like(X)
    shift[:, i] = dx
    second_derivs.append((f(X + shift) - 2 * f_center + f(X - shift)) / dx**2)
  return torch.stack(second_derivs, dim=1)

In [None]:
# class diff_model_enable:
#   def __init__(self, model, X):
#     self.model = model
#     self.X = X
  
#   def __enter__(self):
#     self.X.requires_grad = True
  
#   def __exit__(self, exc_type, exc_val, exc_tb):
#     self.X.requires_grad = False

In [None]:
# def diff2(f, X):
#   grads = torch.autograd.grad(
#     outputs=f, inputs=X, grad_outputs=torch.ones_like(f),
#     create_graph=True
#   )[0]
#   grad_grad = []
#   for i in range(X.shape[1]):
#     df = grads[:,i]
#     df2 = torch.autograd.grad(
#       outputs=df, inputs=X, grad_outputs=torch.ones_like(df),
#       create_graph=True
#     )[0][:,i]
#     grad_grad.append(df2)
#   return torch.stack(grad_grad, dim=1)

# Metropolis sampler

In [2]:
class MetropolisSampler():
  """Random-walk Metropolis sampler that evolves a whole batch of points at once.

  ``self.sample`` holds the current points as a ``(dim1, dim2)`` tensor on
  ``self.device``. Each update proposes a uniform displacement in
  ``[-epsilon, epsilon]`` per coordinate and accepts it row by row with
  probability ``density(proposal) / density(current)``.

  NOTE: the density arguments are spelled ``__distributionDensity``. Inside a
  class body Python name-mangles such identifiers, so in practice they can only
  be passed positionally.
  """

  def __init__(self, dim1, dim2, epsilon=0.1, device='cpu'):
    """Store the proposal half-width and device, then draw the starting points.

    Args:
      dim1: number of points (rows of the sample).
      dim2: number of coordinates per point (columns of the sample).
      epsilon: half-width of the uniform proposal displacement.
      device: device on which the sample and random numbers live.
    """
    self.epsilon = epsilon
    self.device = device
    self.sample = self.initialSample(dim1, dim2)

  def initialSample(self, dim1, dim2) -> torch.Tensor:
    """Return ``(dim1, dim2)`` points drawn uniformly from ``[-1.5, 1.5)``."""
    uniform01 = torch.rand((dim1, dim2))
    return 3 * (uniform01 - 0.5).to(self.device)

  def updateSampleBasOnDistrDens(self, __distributionDensity):
    """Perform one Metropolis sweep over all points, in place on ``self.sample``.

    Args:
      __distributionDensity: callable mapping an ``(N, dim2)`` tensor to one
        density value per row (shape ``(N,)``). It may be unnormalised because
        only the ratio is used.
    """
    # Uniform displacement in [-epsilon, epsilon] for every coordinate.
    proposal = self.sample + self.epsilon * (
      2 * torch.rand_like(self.sample, device=self.device) - 1
    )
    critVal = __distributionDensity(proposal) / __distributionDensity(self.sample)
    # A NaN ratio compares False, i.e. the corresponding point stays put.
    doesPointMove = torch.rand(len(self.sample), device=self.device) <= critVal
    # Row-wise selection; unlike mask multiplication this cannot leak a
    # non-finite entry of the discarded candidate through ``0 * x``.
    self.sample = torch.where(doesPointMove.unsqueeze(1), proposal, self.sample)

  def updateAndGetSample(self, __distributionDensity, num_steps=10):
    """Run ``num_steps`` Metropolis sweeps and return the updated sample.

    Args:
      __distributionDensity: density callable, see ``updateSampleBasOnDistrDens``.
      num_steps: number of consecutive sweeps (default 10, the value that was
        previously hard-coded).

    Returns:
      ``self.sample`` after the sweeps (the stored tensor itself, not a copy).
    """
    for _ in range(num_steps):
      self.updateSampleBasOnDistrDens(__distributionDensity)
    return self.sample

# Integrate

In [3]:
def integrate(f, density) -> torch.Tensor:
  """Importance-sampling style average: the mean of ``f / density`` over all elements.

  Args:
    f: tensor of integrand values.
    density: tensor (broadcastable against ``f``) of the weights the values
      are divided by.

  Returns:
    A 0-dim tensor holding the mean of the element-wise ratio.
  """
  ratio = f / density
  return ratio.mean()

# Trial function

In [4]:
class TrialFunction(nn.Module):
  """Neural-network trial wave functions for several states at once.

  Architecture (see ``forward``): a fully connected stack maps the coordinates
  to ``num_hidden_neurons`` features; each feature ``k`` is multiplied by a
  Gaussian envelope ``exp(-sum_j w_kj**2 * x_j**2)``; one bias-free linear
  head per state turns the damped features into that state's value.

  The remaining methods build the quantities used for training: a sampling
  weight, a finite-difference Laplacian, the Hamiltonian applied to the trial
  functions, Rayleigh quotients, squared residuals, and penalty terms.
  """

  def __init__(self, dim_coord, num_states, potential, name,
               num_hidden_layers=3, num_hidden_neurons=60,
               init_mean_weights=0.0, init_std_weights=math.sqrt(0.1),
               activ_fnc=None, fd_step=1e-4):
    """Build the network and initialise its weights.

    Args:
      dim_coord: number of input coordinates.
      num_states: number of output states (one linear head each).
      potential: callable ``x -> V(x)`` returning one value per row of ``x``.
      name: identifier stored on the model (used for file names elsewhere).
      num_hidden_layers: number of (Linear, activation) pairs in the stack.
      num_hidden_neurons: width of every hidden layer.
      init_mean_weights: mean of the normal initialisation.
      init_std_weights: standard deviation of the normal initialisation.
      activ_fnc: activation module; ``None`` (default) means a fresh
        ``nn.Sigmoid()`` per model instead of one instance shared by every
        model through the default argument.
      fd_step: step of the finite-difference Laplacian (previously the
        hard-coded 1e-4).
    """
    super().__init__()
    if activ_fnc is None:
      activ_fnc = nn.Sigmoid()

    self.dim_coord = dim_coord
    self.num_states = num_states
    self.potential = potential

    self.name = name

    self.num_hidden_layers = num_hidden_layers
    self.num_hidden_neurons = num_hidden_neurons
    self.activ_fnc = activ_fnc
    self.fd_step = fd_step
    # Hidden stack. The same activation module is reused after every Linear.
    stack = [nn.Linear(dim_coord, num_hidden_neurons), activ_fnc]
    for _ in range(num_hidden_layers - 1):
      stack += [nn.Linear(num_hidden_neurons, num_hidden_neurons), activ_fnc]
    self.layers = nn.Sequential(*stack)
    # Gaussian envelope weights (keep their default initialisation).
    self.gaussian_weights = nn.Linear(dim_coord, num_hidden_neurons, bias=False)
    # One bias-free output head per state.
    self.out_layer = nn.ModuleList(
      [nn.Linear(num_hidden_neurons, 1, bias=False) for _ in range(num_states)]
    )
    # Initialise weights
    self.init_weights(init_mean_weights, init_std_weights)

  def init_weights(self, init_mean_weights, init_std_weights):
    """Draw hidden-stack weights/biases and output-head weights from a normal law.

    Modules without a ``weight``/``bias`` attribute (e.g. the activations), or
    whose attribute is ``None``, are skipped.
    """
    for module in self.layers:
      for attr in ('weight', 'bias'):
        param = getattr(module, attr, None)
        if param is not None:
          nn.init.normal_(param, init_mean_weights, init_std_weights)

    for head in self.out_layer:
      weight = getattr(head, 'weight', None)
      if weight is not None:
        # An alternative mean of 1/sqrt(num_hidden_neurons) was tried and
        # left disabled in the original notebook.
        nn.init.normal_(weight, init_mean_weights, init_std_weights)

  def forward(self, x):
    """Evaluate all states at the rows of ``x``.

    Args:
      x: tensor of shape ``(N, dim_coord)``.

    Returns:
      Tensor of shape ``(N, num_states)``.
    """
    z = self.layers(x)

    sqr_gauss_weights = self.gaussian_weights.weight * self.gaussian_weights.weight
    # (N, num_hidden_neurons): exp(-sum_j w_kj^2 x_j^2) for every hidden unit k.
    gauss_kernel = torch.exp(-torch.matmul(x * x, sqr_gauss_weights.t()))

    z_prime = gauss_kernel * z
    return torch.stack([psi_n(z_prime) for psi_n in self.out_layer], dim=1).squeeze(2)

  # --- helpers that reuse an already computed forward pass -----------------

  def _weights_from(self, psi):
    """Sampling weight computed from precomputed state values ``psi``."""
    psi_sq = psi**2
    # Each state is scaled by its maximum over the batch, then states are averaged.
    return torch.mean(psi_sq / torch.max(psi_sq, dim=0)[0], dim=1).detach()

  def _laplacian_from(self, x, psi):
    """Finite-difference Laplacian given ``psi = forward(x)``."""
    step = self.fd_step
    total = 0
    for i in range(self.dim_coord):
      delta_x = torch.zeros_like(x)
      delta_x[:, i] = step
      total = total + (self.forward(x + delta_x) - 2 * psi + self.forward(x - delta_x)) / step**2
    return total

  def _hamiltonian_from(self, x, psi):
    """``-0.5 * Laplacian + V * psi`` given ``psi = forward(x)``."""
    return (
      -0.5 * self._laplacian_from(x, psi)
      + self.potential(x).unsqueeze(1) * psi
    )

  def _rayleigh_from(self, psi, h_psi, weights):
    """Rayleigh quotient of every state from precomputed ingredients."""
    return [
      integrate(psi[:, s] * h_psi[:, s], weights)
      / integrate(psi[:, s] * psi[:, s], weights)
      for s in range(self.num_states)
    ]

  # --- public API ------------------------------------------------------------

  def weigth_function(self, x):
    """Detached sampling weight, one value per row of ``x``.

    It is the mean over states of ``psi_s**2`` divided by the largest
    ``psi_s**2`` found in the batch, so it depends on the whole batch.
    """
    return self._weights_from(self.forward(x))

  def laplacian(self, x):
    """Sum over coordinates of central second differences of ``forward``.

    The numerator is divided by ``fd_step**2`` (1e-8 by default), which
    amplifies round-off; in float32 the estimate can be dominated by
    cancellation error.
    """
    # forward(x) is evaluated once instead of once per coordinate.
    return self._laplacian_from(x, self.forward(x))

  def hamiltonian(self, x):
    """Apply ``-0.5 * Laplacian + potential`` to every state; shape ``(N, num_states)``."""
    return self._hamiltonian_from(x, self.forward(x))

  def rayleigh(self, x):
    """Return a list with the Rayleigh quotient ``<psi|H|psi> / <psi|psi>`` per state."""
    psi = self.forward(x)
    return self._rayleigh_from(psi, self._hamiltonian_from(x, psi), self._weights_from(psi))

  def sqr_res_per_state(self, x):
    """Return a list with the normalised squared residual of ``H psi = E psi`` per state.

    ``E`` is the state's Rayleigh quotient. The forward pass and the
    Hamiltonian are evaluated once and shared between the quotient and the
    residual.
    """
    psi = self.forward(x)
    h_psi = self._hamiltonian_from(x, psi)
    weights = self._weights_from(psi)
    energies = self._rayleigh_from(psi, h_psi, weights)
    return [
      integrate((h_psi[:, s] - energies[s] * psi[:, s])**2, weights)
      / integrate(psi[:, s] * psi[:, s], weights)
      for s in range(self.num_states)
    ]

  def sqr_res(self, x):
    """Sum of the per-state squared residuals."""
    return sum(self.sqr_res_per_state(x))

  def norm_cond(self):
    """Penalty pushing every output head's weight vector towards unit length.

    Returns a tensor of shape ``(1,)``. A sample-based alternative,
    ``sum_s (integrate(psi_s**2, weight) - 1)**2``, was left disabled in the
    original notebook.
    """
    return sum((torch.sum(A*A, dim=1) - 1)**2 for A in self.out_layer.parameters())

  def orthogon_cond(self, x):
    """Sum over state pairs of the squared, normalised overlap ``<s1|s2>``."""
    psi = self.forward(x)
    weights = self._weights_from(psi)
    # Norms are shared by all pairs that contain the state.
    norms = [integrate(psi[:, s] * psi[:, s], weights) for s in range(self.num_states)]
    result = []
    for s1, s2 in combinations(list(range(self.num_states)), 2):
      result += [
        torch.square(integrate(psi[:, s1] * psi[:, s2], weights))
        / norms[s1]
        / norms[s2]
      ]
    return sum(result)

# Checkpoint

In [5]:
def save_checkpoint(model, optimizer, step, loss):
  """Write the training state to ``checkpoints/<model.name>.pt``.

  The file holds a dict with the keys ``model_state_dict``,
  ``optimizer_state_dict``, ``step`` and ``loss``.

  Args:
    model: module exposing ``state_dict()`` and a ``name`` attribute.
    optimizer: optimizer exposing ``state_dict()``.
    step: training step counter to store.
    loss: loss value to store; a tensor is detached and moved to the CPU so
      the file neither drags the autograd graph along nor requires the
      training device when it is loaded.
  """
  if isinstance(loss, torch.Tensor):
    loss = loss.detach().cpu()
  checkpoint = {
    'model_state_dict' : model.state_dict(),
    'optimizer_state_dict' : optimizer.state_dict(),
    'step' : step,
    'loss' : loss
  }
  fn = f"checkpoints/{model.name}.pt"
  # torch.save does not create missing folders.
  os.makedirs(os.path.dirname(fn), exist_ok=True)
  torch.save(checkpoint, fn)

# Test (He)

In [6]:
# Seed the global torch RNG before the sampler is created so that the initial
# points and the Metropolis proposals are reproducible from a fresh kernel.
SEED = 0
torch.manual_seed(SEED)

# --- Problem definition ------------------------------------------------------
dim_coord = 6    # two particles x three Cartesian coordinates
num_states = 5   # number of states trained simultaneously


def potential(x):
  """Potential for the "He" test: two attractive centre terms plus one repulsion.

  With ``r1 = x[:, :3]`` and ``r2 = x[:, 3:]`` this returns
  ``-2/|r1| - 2/|r2| + 1/|r1 - r2|`` for every row of ``x``.
  The expression is singular where ``r1``, ``r2`` or ``r1 - r2`` vanish.
  Units are presumably atomic units (TODO confirm).
  """
  r1 = x[:, :3]
  r2 = x[:, 3:]
  return (
    -2/torch.sqrt(torch.sum(r1**2, dim=1))
    -2/torch.sqrt(torch.sum(r2**2, dim=1))
    +1/torch.sqrt(torch.sum((r1-r2)**2, dim=1))
  )


# --- Network -------------------------------------------------------------------
num_hidden_layers = 3
num_hidden_neurons = 60
init_mean_weights = 0.0
init_std_weights = math.sqrt(0.1)
activ_fnc = nn.Tanh()

# --- Loss weights (used by the training cell) ----------------------------------
alpha = 2    # multiplies the sum of Rayleigh quotients
beta = 1     # multiplies the normalisation penalty
gamma = 40   # multiplies the orthogonality penalty

# --- Sampling ------------------------------------------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
S = 8000     # number of Metropolis points
sampler = MetropolisSampler(dim1=S, dim2=dim_coord, epsilon=0.1, device=device)

# --- Optimiser -----------------------------------------------------------------
lr = 1e-2
weight_decay = 5e-4

# --- Logging / stopping ----------------------------------------------------------
CHECKPOINT_PERIOD = 10     # log and checkpoint every this many steps

MAX_SQR_RESIDUAL = 1e-3    # a state counts as converged below this residual

load_checkpoint = False    # resume from checkpoints/<name>.pt and history/<name>.csv

name = "test_He"

In [None]:
START_TIME = time.time()

# Per-step training history: one list per logged quantity.
hist = {
  'time' : list(),
  'loss' : list(),
  'sqr res' : list(),
  'norm cond' : list(),
  'orthogon cond' : list(),
}
for state in range(num_states):
  hist[f"Rayleigh{state}"] = list()

model = TrialFunction(dim_coord, num_states, potential, name,
               num_hidden_layers, num_hidden_neurons,
               init_mean_weights, init_std_weights,
               activ_fnc).to(device)

optimizer = torch.optim.Adam(
  params=list(model.parameters()),
  lr=lr,
  weight_decay=weight_decay
)

# Neither torch.save nor DataFrame.to_csv creates missing folders.
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("history", exist_ok=True)

step = 0
if load_checkpoint:
  # map_location lets a checkpoint written on another device be resumed here.
  checkpoint = torch.load(f"checkpoints/{model.name}.pt", map_location=device)
  model.load_state_dict(checkpoint['model_state_dict'])
  model.train()
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  # Continue counting from the stored step (it used to be reset to 0 below).
  step = checkpoint['step']
  loss = checkpoint['loss']
  print(f"Load model and optimizer: step {checkpoint['step']}, loss {checkpoint['loss']}")
  # history
  hist_df = pd.read_csv(f"history/{model.name}.csv")
  for column in list(hist):
    hist[column] = hist_df[column].to_list()

# Initial sweeps so the points follow the current model before training starts.
x = sampler.updateAndGetSample(model.weigth_function)

# Column header of the progress log, printed once (also when resuming).
header = (
    "step, time, loss, sqr res,"
    + " norm cond, orthogon cond,"
)
header += "".join(f", Rayleigh{i}" for i in range(num_states))
header += ","
header += "".join(f", sqr_res{i}" for i in range(num_states))
print(header)

# Train until every state's squared residual is below MAX_SQR_RESIDUAL.
valid_states = 0
while (valid_states < num_states):
  step += 1
  x = sampler.updateAndGetSample(model.weigth_function)

  # The per-state residuals are computed once and reused for both the loss
  # and the convergence test.
  sqr_res_per_state = model.sqr_res_per_state(x)
  sqr_residual = sum(sqr_res_per_state)

  energies = model.rayleigh(x)

  energ_cond = alpha * sum(energies)

  norm_cond = beta * model.norm_cond()

  orthogon_cond = gamma * model.orthogon_cond(x)

  loss = (
    sqr_residual
    + energ_cond
    + norm_cond
    + orthogon_cond
  )

  sqr_res_per_state_ = torch.stack(sqr_res_per_state, dim=0).detach().cpu()
  # Evaluated before the parameter update, so it describes the model that
  # produced this step's loss.
  valid_states = int((sqr_res_per_state_ < MAX_SQR_RESIDUAL).sum())

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  hist['time'].append(time.time() - START_TIME)
  hist['loss'].append(loss.item())
  hist['sqr res'].append(sqr_residual.item())
  hist['norm cond'].append(norm_cond.item())
  hist['orthogon cond'].append(orthogon_cond.item())
  for state in range(num_states):
    hist[f"Rayleigh{state}"].append(energies[state].item())

  if (step % CHECKPOINT_PERIOD) == 0 or step == 1:
    info = (
      f"{step}, "
      + f"{time.time() - START_TIME:.2f}, "
      + f"{loss.item():.2e}, "
      + f"{sqr_residual.item():.2e}, "
      + f"{norm_cond.item():.2e}, "
      + f"{orthogon_cond.item():.2e},"
    )
    for Rayleigh_ in energies:
      info += f", {Rayleigh_.item():.2e}"
    info += ","
    for i in range(num_states):
      info += f", {sqr_res_per_state_[i]:.2e}"
    print(info)

    # Persist the full history and the training state together.
    hist_df = pd.DataFrame(data=hist)
    hist_df.to_csv(f"history/{model.name}.csv")

    save_checkpoint(model, optimizer, step, loss)

step, time, loss, sqr res, norm cond, orthogon cond,, Rayleigh0, Rayleigh1, Rayleigh2, Rayleigh3, Rayleigh4,, sqr_res0, sqr_res1, sqr_res2, sqr_res3, sqr_res4
1, 1.45, 1.22e+05, 1.22e+05, 1.35e+02, 2.55e+01,, -1.55e+00, -2.79e-01, -2.05e+00, 2.14e+00, -1.63e+00,, 2.65e+04, 2.12e+04, 1.05e+04, 3.69e+04, 2.67e+04
10, 3.51, 1.14e+04, 1.12e+04, 1.55e+02, 7.39e+01,, -1.86e+00, -2.35e+00, -8.48e-01, -2.03e+00, -1.77e+00,, 1.92e+03, 2.34e+03, 2.52e+03, 2.21e+03, 2.19e+03
20, 5.60, 8.83e+03, 8.56e+03, 1.92e+02, 9.56e+01,, -1.59e+00, -1.62e+00, -9.97e-01, -1.49e+00, -2.41e+00,, 1.56e+03, 1.76e+03, 1.76e+03, 1.66e+03, 1.82e+03
30, 7.61, 7.92e+03, 7.58e+03, 2.40e+02, 1.11e+02,, -1.37e+00, -1.66e+00, -1.12e+00, -1.32e+00, -1.51e+00,, 1.45e+03, 1.55e+03, 1.54e+03, 1.43e+03, 1.62e+03
40, 9.56, 7.27e+03, 6.87e+03, 2.96e+02, 1.23e+02,, -1.37e+00, -1.64e+00, -1.67e+00, -1.54e+00, -1.04e+00,, 1.35e+03, 1.33e+03, 1.37e+03, 1.34e+03, 1.49e+03
50, 11.41, 6.74e+03, 6.26e+03, 3.60e+02, 1.33e+02,, -2.57e+00, 

In [None]:
# import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt

# hist = pd.read_csv("history/test_He.csv")
# xmax = len(hist)
# xmin = 0 #xmax - 1000

# fig, ax = plt.subplots(2,1,figsize=(16,7))

# hist[['sqr res', 'norm cond',
#        'orthogon cond']][xmin:xmax].plot(xlim=(xmin, xmax), ax = ax[0])
# # ax[0].plot(np.arange(xmin,xmax), np.sum(hist[[f'Rayleigh{i}' for i in range(16)]][xmin:xmax].to_numpy(), axis=1))
# ax[0].set_yscale('log')
# ax[0].grid(axis='y')

# hist[[f'Rayleigh{i}' for i in range(16)]][xmin:xmax].plot(xlim=(xmin, xmax), legend=False, ax = ax[1], alpha=0.7)
# ax[1].axhline(y=0, color='k', linestyle='--')
# ax[1].grid(axis='y')
# # ax[1].set_yscale('symlog')