# Initial Setup



*   Make sure you run the model using a GPU (On Google Colab: Runtime -> Change Runtime Type -> GPU)





In [None]:
#@title  { vertical-output: true, display-mode: "both" }

# install packages
# NOTE(review): every `-f` wheel index below is the torch-1.7.0 one and the
# version tag is `latest+cu101`, so these lines presumably assume a runtime
# with torch 1.7.0 / CUDA 10.1 — confirm against the current Colab image.
!pip install -q torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
!pip install -q torch-sparse==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
# NOTE(review): PyTorch Geometric is installed from the repository's default
# branch with no tag or commit, so the installed version is not pinned.
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git
!pip install -q torch-cluster==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
!pip install -q ase

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
import numpy as np
import torch
import ase
import random
from torch_geometric.data import DataLoader
from ase.io import read

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed every random number generator used by the notebook with the same value.
seed = 55555
for seed_rng in (np.random.seed, random.seed, torch.manual_seed):
    seed_rng(seed)
# CUDA generators are seeded only when a GPU is available.
if torch.cuda.is_available():
    for seed_rng in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_rng(seed)

In [None]:
#@title  { vertical-output: true, display-mode: "both" }

# Root folder (on the mounted Google Drive) holding inputs and trained models.
BASE_PATH = '/content/drive/MyDrive/FeMaterials/'
TRAINING_RATIO = 0.8
NUCLEAR_CHARGE = 26 # default nuclear charge
OPTIMIZER = torch.optim.Adam
# Floating-point type used for every tensor built by this notebook.
DTYPE = torch.float64

BATCH_SIZE = 6
CRITERION = torch.nn.MSELoss()
CROSSENTROPY = torch.nn.CrossEntropyLoss()

def print_hyperparameters():
  """Print the hyperparameters configured in this cell.

  `LEARNING_RATE` is not defined in this cell, so it is looked up at call
  time; "not set" is printed instead of raising `NameError` when it is missing.
  """
  print("Default nuclear charge:", NUCLEAR_CHARGE)
  print("Training ratio:", TRAINING_RATIO)
  print("Batch size:", BATCH_SIZE)
  print("Optimizer:", OPTIMIZER)
  print("Learning rate:", globals().get("LEARNING_RATE", "not set"))
  print("Criterion:", CRITERION)

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
import time
from datetime import datetime
import os
from numpy import savetxt

# mount drive
# The mount point is /content/drive; BASE_PATH (set in the configuration cell)
# starts with that same prefix, so it only resolves after this call.
from google.colab import drive
drive.mount('/content/drive')

# Folder under BASE_PATH from which every input file of the notebook is read.
path2data = BASE_PATH + 'input/'

# Pre-Processing Functions

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
# Cached supercell transformation matrix for the 54-atom target (see extend_atoms).
M_1_54 = None

def extend_atoms(atoms: ase.Atoms, target: int):
  """Replicate `atoms` into a supercell with `target` times the cell volume.

  The transformation matrix comes from `ase.build.find_optimal_cell_shape`
  with the "sc" target shape. For `target == 54` it is computed once and
  cached in the module-level `M_1_54`.

  NOTE(review): the cache is keyed on `target` only, so the matrix found for
  the first cell is reused for every later 54-atom request regardless of that
  cell's shape — confirm this is intended for the distorted Bain-path cells.

  Args:
    atoms: structure to replicate; `atoms.info["energy"]` must be set.
    target: size multiplier passed to `find_optimal_cell_shape`.

  Returns:
    The supercell, with `info["energy"]` set to the input energy multiplied by
    `target // number_of_atoms`.
  """
  # `import ase` alone does not guarantee that the `ase.build` submodule is loaded.
  import ase.build

  global M_1_54
  if target == 54 and M_1_54 is not None:
    transform = M_1_54
  else:
    transform = ase.build.find_optimal_cell_shape(atoms.get_cell(), target, "sc")
    if target == 54:
      M_1_54 = transform
  supercell = ase.build.make_supercell(atoms, transform)
  # The energy is set explicitly on the new supercell, scaled by the replication factor.
  supercell.info["energy"] = atoms.info["energy"] * (target // atoms.get_global_number_of_atoms())
  return supercell

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
from torch_geometric.data import Data

def data_object(atoms: ase.Atoms, num_atoms=54):
  """Convert an ASE structure into a `torch_geometric` `Data` object.

  A structure with exactly one atom is first expanded with
  `extend_atoms(atoms, num_atoms)`; any other structure is used as it is.

  Returns a `Data` with:
    charges: one `NUCLEAR_CHARGE` entry per atom (long tensor)
    x:       atomic positions
    y:       `atoms.info["energy"]`
    cell:    the cell of the structure
    n:       number of atoms
  Tensors are created with dtype `DTYPE` (except `charges`) and moved to `device`.
  """
  # one atom structure - need to expand
  if atoms.get_global_number_of_atoms() == 1:
    atoms = extend_atoms(atoms, num_atoms)

  atom_count = atoms.get_global_number_of_atoms()

  lattice = torch.tensor(atoms.cell, dtype=DTYPE).to(device)
  coords = torch.tensor(atoms.get_positions(), dtype=DTYPE).to(device)

  # Every atom gets the default nuclear charge.
  nuclear = torch.tensor([NUCLEAR_CHARGE] * len(coords), dtype=torch.long).to(device)

  energy = torch.tensor(atoms.info["energy"], dtype=DTYPE).to(device)

  return Data(charges=nuclear, x=coords, y=energy, cell=lattice, n=atom_count)

# DimeNet - Edited Version

**Changes done:**

*   Added support for Periodic Boundary Conditions (PBCs)
*   Added support for energy de-standardization
*   Fixed a mistake in the Embedding Block of the PyTorch Geometric implementation
*   Fixed a mistake in the angle computation of the original implementation (swapped indices)

In [None]:
#@title  { vertical-output: true, form-width: "25%" }
from torch_geometric.nn import DimeNet
from torch_geometric.nn.acts import swish
from math import sqrt, pi as PI

import numpy as np
import torch
from torch.nn import Linear, Embedding
from torch_scatter import scatter
from torch_sparse import SparseTensor
from torch_geometric.nn import radius_graph
from torch_geometric.data import download_url
from torch_geometric.data.makedirs import makedirs

from torch_geometric.nn.models.dimenet import Envelope
from torch_geometric.nn.models.dimenet import BesselBasisLayer
from torch_geometric.nn.models.dimenet import SphericalBasisLayer
from torch_geometric.nn.models.dimenet import ResidualLayer
from torch_geometric.nn.models.dimenet import InteractionBlock
from torch_geometric.nn.models.dimenet import OutputBlock

from torch_geometric.nn.models.dimenet_utils import bessel_basis, real_sph_harm

try:
    import sympy as sym
except ImportError:
    sym = None

import os
try:
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    import tensorflow as tf
except ImportError:
    tf = None


# Integer index -> target name lookup table; key 4 is not present.
# NOTE(review): nothing else in the visible notebook reads this dictionary —
# presumably it was carried over with the upstream DimeNet code; confirm
# before relying on it or removing it.
qm9_target_dict = {
    0: 'mu',
    1: 'alpha',
    2: 'homo',
    3: 'lumo',
    5: 'r2',
    6: 'zpve',
    7: 'U0',
    8: 'U',
    9: 'H',
    10: 'G',
    11: 'Cv',
}

# Implement PBC
from ase.neighborlist import neighbor_list 
from ase import Atoms


class EmbeddingBlock(torch.nn.Module):
    """Builds the initial edge messages from atom embeddings and radial features.

    Each edge message is `act(lin([emb(x)[i], emb(x)[j], lin_rbf(rbf)]))`.
    The projected radial basis is deliberately NOT passed through the
    activation before concatenation (this is the fix relative to the
    PyTorch Geometric version mentioned in the notebook).
    """

    def __init__(self, num_radial, hidden_channels, act=swish):
        super().__init__()
        self.act = act

        # Layers are created in this order so that seeded initialisation is reproducible.
        self.emb = Embedding(95, hidden_channels)
        self.lin_rbf = Linear(num_radial, hidden_channels)
        self.lin = Linear(3 * hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding uniformly in [-sqrt(3), sqrt(3)] and reset both linear layers."""
        bound = sqrt(3)
        self.emb.weight.data.uniform_(-bound, bound)
        for layer in (self.lin_rbf, self.lin):
            layer.reset_parameters()

    def forward(self, x, rbf, i, j):
        """Return one message per edge; `i` and `j` index the two end atoms of each edge."""
        node_emb = self.emb(x)
        # No activation on the radial projection (see class docstring).
        edge_emb = self.lin_rbf(rbf)
        stacked = torch.cat([node_emb[i], node_emb[j], edge_emb], dim=-1)
        return self.act(self.lin(stacked))

class DimeNet2(DimeNet):
    """DimeNet variant with optional periodic boundary conditions (PBC) and
    energy de-standardization.

    When `forward` receives a `cell`, the edges come from ASE's periodic
    neighbour search instead of `radius_graph`, and the periodic shift of each
    edge is carried along to compute the angles. When `mean` and `std` are both
    given, the per-atom outputs are mapped back with `P * std + mean` before
    being summed per structure.
    """

    def __init__(self, hidden_channels, out_channels, num_blocks, num_bilinear,
                 num_spherical, num_radial, cutoff=5.0, envelope_exponent=5,
                 num_before_skip=1, num_after_skip=2, num_output_layers=3,
                 act=swish,
                 mean=None, std=None):
        # `super(DimeNet, self)` skips DimeNet's own constructor on purpose:
        # all sub-modules are built below (with the local EmbeddingBlock).
        super(DimeNet, self).__init__()

        self.cutoff = cutoff

        # mean and standard deviation of the energies (used for de-standardization)
        self.mean = mean
        self.std = std

        # padding used for PBCs
        # NOTE(review): not used by any method of this class — kept only so the
        # attribute set of the module stays unchanged.
        self.padding = torch.nn.ConstantPad2d((0, 6, 0, 0), 0)

        if sym is None:
            raise ImportError('Package `sympy` could not be found.')

        self.num_blocks = num_blocks

        self.rbf = BesselBasisLayer(num_radial, cutoff, envelope_exponent)
        self.sbf = SphericalBasisLayer(num_spherical, num_radial, cutoff,
                                       envelope_exponent)

        self.emb = EmbeddingBlock(num_radial, hidden_channels, act)

        self.output_blocks = torch.nn.ModuleList([
            OutputBlock(num_radial, hidden_channels, out_channels,
                        num_output_layers, act) for _ in range(num_blocks + 1)
        ])

        self.interaction_blocks = torch.nn.ModuleList([
            InteractionBlock(hidden_channels, num_bilinear, num_spherical,
                             num_radial, num_before_skip, num_after_skip, act)
            for _ in range(num_blocks)
        ])

        self.reset_parameters()

    def _cell_neighbors(self, numbers, positions, cell):
        """Periodic neighbour search for ONE structure (NumPy inputs).

        Returns `(first, second, dist, shift_rows)` where `first`/`second` are
        the atom indices of every pair within `self.cutoff`, `dist` their
        distances and `shift_rows` an (E, 9) array holding, per pair, the three
        lattice vectors each multiplied by the matching integer cell shift.
        """
        # Atomic numbers go in `numbers`; `charges` would set initial charges instead.
        atms = Atoms(numbers=numbers, positions=positions, cell=cell, pbc=True)

        nh1, nh2, dist, shift = neighbor_list("ijdS", atms,
                                              self.cutoff,
                                              self_interaction=False)

        # Flattened cell = [a1x a1y a1z a2x ... a3z]; repeating every shift
        # component three times lines it up with its lattice vector.
        cell_flat = np.asarray(cell).reshape(-1)
        shift_rows = np.repeat(shift, 3, axis=1).astype(cell_flat.dtype) * cell_flat
        return nh1, nh2, dist, shift_rows

    def pbc_edges(self, z, pos, cell, batch):
        """Build the edge list of one structure or of a batch under PBC.

        Args:
            z: atomic numbers, one per atom.
            pos: atom positions, (num_atoms, 3).
            cell: lattice vectors; for a batch the 3x3 cells are stacked along
                dim 0, so rows `3*g : 3*g+3` belong to structure `g`.
            batch: structure index per atom, or None for a single structure.
                Atoms of one structure are assumed to be contiguous.

        Returns:
            None when `cell` is None, otherwise the list
            `[first, second, dist, shift_cells]` of tensors on `z.device`
            (atom indices are global, i.e. offset by the preceding structures).
        """
        if cell is None:
            return

        z_np = z.detach().cpu().numpy()
        pos_np = pos.detach().cpu().numpy()
        cell_np = cell.detach().cpu().numpy()

        if batch is not None:
            # Atoms per structure, in order of first appearance (single pass).
            counts = {}
            for graph_id in batch.cpu().numpy().tolist():
                counts[graph_id] = counts.get(graph_id, 0) + 1
            sizes = list(counts.values())
        else:
            # A single structure owns every atom and the whole `cell`.
            sizes = [pos_np.shape[0]]

        # Seeded with empty arrays so an empty input still concatenates.
        first_parts = [np.empty(0, dtype=np.int64)]
        second_parts = [np.empty(0, dtype=np.int64)]
        dist_parts = [np.empty(0)]
        shift_parts = [np.empty((0, 9))]

        offset = 0
        for g, size in enumerate(sizes):
            stop = offset + size
            nh1, nh2, dist, shift_rows = self._cell_neighbors(
                z_np[offset:stop], pos_np[offset:stop], cell_np[3 * g:3 * (g + 1)])

            # Local atom indices -> global indices inside the batch.
            first_parts.append(nh1 + offset)
            second_parts.append(nh2 + offset)
            dist_parts.append(dist)
            shift_parts.append(shift_rows)
            offset = stop

        return [torch.tensor(np.concatenate(first_parts), dtype=torch.long).to(z.device),
                torch.tensor(np.concatenate(second_parts), dtype=torch.long).to(z.device),
                torch.tensor(np.concatenate(dist_parts), dtype=DTYPE).to(z.device),
                torch.tensor(np.concatenate(shift_parts, axis=0), dtype=DTYPE).to(z.device)]

    def triplets(self, edge_index, num_nodes, shift_cells=None):
        """Enumerate the (k->j->i) triplets of the graph.

        For every edge (j->i) all edges (k->j) are collected, dropping the
        triplets whose atom indices satisfy i == k. When `shift_cells` is given
        (one row per edge), the row of the (j->i) edge is repeated for each of
        its triplets and filtered with the same mask.

        Returns:
            col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji, shift_cells
        """
        row, col = edge_index  # j->i

        value = torch.arange(row.size(0), device=row.device)
        adj_t = SparseTensor(row=col, col=row, value=value,
                             sparse_sizes=(num_nodes, num_nodes))
        adj_t_row = adj_t[row]

        num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

        # Node indices (k->j->i) for triplets.
        idx_i = col.repeat_interleave(num_triplets)
        idx_j = row.repeat_interleave(num_triplets)

        if shift_cells is not None:  # Update also the shift vectors
            shift_cells = shift_cells.repeat_interleave(num_triplets, dim=0)

        idx_k = adj_t_row.storage.col()

        # NOTE(review): the comparison is on atom indices, so under PBC a
        # triplet whose k is a different periodic image of atom i is removed
        # too — confirm this is intended.
        mask = (idx_i != idx_k)  # Remove i == k triplets.
        idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]
        if shift_cells is not None:  # Remove also from the shift vector
            shift_cells = shift_cells[mask]

        # Edge indices (k-j, j->i) for triplets.
        idx_kj = adj_t_row.storage.value()[mask]
        idx_ji = adj_t_row.storage.row()[mask]

        return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji, shift_cells

    def forward(self, z, pos, cell=None, batch=None):
        """Predict one output row per structure.

        Args:
            z: atomic numbers, one per atom.
            pos: atom positions.
            cell: lattice vectors; when given, edges are built with PBC.
            batch: structure index per atom, or None for a single structure.

        Returns:
            Per-atom outputs summed over all atoms (`batch is None`) or
            scattered per structure.
        """
        dist = None
        shift_cells = None
        if cell is not None:  # implement PBC
            r1, r2, dist, shift_cells = self.pbc_edges(z, pos, cell, batch)
            edge_index = [r1, r2]
        else:  # old method without PBC
            edge_index = radius_graph(pos, r=self.cutoff, batch=batch)

        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji, shift_cells = self.triplets(
            edge_index, num_nodes=z.size(0), shift_cells=shift_cells)

        # Without PBC the distances are computed here; with PBC they were
        # already returned by the neighbour search.
        if cell is None:
            dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

        # Define atoms position
        pos_i = pos[idx_i]
        pos_j = pos[idx_j]  # central atom
        pos_k = pos[idx_k]

        if cell is not None:  # Fix coordinates for PBCs
            # Sum of the three shifted lattice vectors = displacement of the periodic image.
            # NOTE(review): both pos_i and pos_k receive the shift of the
            # (j->i) edge; the (k->j) edge has its own shift which is not used
            # here. The pretrained weights were obtained with this geometry,
            # so it is left unchanged — confirm before altering.
            image_shift = shift_cells[:, 0:3] + shift_cells[:, 3:6] + shift_cells[:, 6:9]
            pos_i = pos_i + image_shift
            pos_k = pos_k + image_shift

        # Calculate angles - with some fixes to indexes compared to the orig. version
        pos_ji, pos_kj = pos_j - pos_i, pos_k - pos_j

        a = (pos_ji * pos_kj).sum(dim=-1)
        # `dim=-1` is explicit: the legacy default picks the first axis of
        # size 3, which is the wrong axis when there are exactly 3 triplets.
        b = torch.cross(pos_ji, pos_kj, dim=-1).norm(dim=-1)
        angle = torch.atan2(b, a)

        rbf = self.rbf(dist)
        sbf = self.sbf(dist, angle, idx_kj)

        # Embedding block.
        x = self.emb(z, rbf, i, j)
        P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0))

        # Interaction blocks.
        for interaction_block, output_block in zip(self.interaction_blocks,
                                                   self.output_blocks[1:]):
            x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
            P += output_block(x, rbf, i, num_nodes=pos.size(0))

        # Energy de-standardization (applied per atom, before the sum)
        if self.std is not None and self.mean is not None:
            P = P * self.std + self.mean

        res = P.sum(dim=0) if batch is None else scatter(P, batch, dim=0)
        return res

# Network Initialization (Load Trained Model)

In [None]:
# Uncomment only one triple [modelPath, HAS_PBC, model]
# HAS_PBC: set to True if the model was trained using PBCs

'''
modelName = "DimeNet DB1 NO PBC"
modelPath = BASE_PATH + 'trained/db1-nopbc-pretrained.model'
HAS_PBC = False
model = DimeNet2(hidden_channels=128, out_channels=1, num_blocks=7, num_bilinear=6, num_spherical=5, num_radial=5, cutoff=3.5, std=0.22314909777243813, mean=-3460.810642670331).to(device)
'''

'''
modelName = "DimeNet DB1-8 NO PBC"
modelPath = BASE_PATH + 'trained/all-nopbc-pretrained.model'
HAS_PBC = False
model = DimeNet2(hidden_channels=128, out_channels=1, num_blocks=7, num_bilinear=6, num_spherical=5, num_radial=5, cutoff=3.5, std=0.16576383029449515, mean=-3460.825847482401).to(device)
'''

'''
modelName = "DimeNet DB1 PBC"
modelPath = BASE_PATH + 'trained/db1-pbc-pretrained.model'
HAS_PBC = True
model = DimeNet2(hidden_channels=128, out_channels=1, num_blocks=7, num_bilinear=6, num_spherical=5, num_radial=5, cutoff=3.5, std=0.22314909777243813, mean=-3460.810642670331).to(device)
'''

# NOTE(review): this configuration uses num_bilinear=8, num_spherical=7 and
# num_radial=6 while the three alternatives above use 6 / 5 / 5 — the values
# must match the ones the checkpoint was trained with.
modelName = "DimeNet DB1-8 PBC"
modelPath = BASE_PATH + 'trained/all-pbc-pretrained.model'
HAS_PBC = True
model = DimeNet2(hidden_channels=128, out_channels=1, num_blocks=7, num_bilinear=8, num_spherical=7, num_radial=6, cutoff=3.5, std=0.16576383029449515, mean=-3460.825847482401).to(device)



# `map_location` remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU also loads when the notebook has fallen back to the CPU.
model.load_state_dict(torch.load(modelPath, map_location=device))
model = model.to(device)
model = model.double() # necessary if the model was trained using the float64 DTYPE
model.eval()

# Results

## Equation of State Curve

We want the network to be able to reproduce the Equation of State (EOS) Curve, that is the Volume-Energy curve reported in FIG.2 of [Dragoni's Paper](https://arxiv.org/abs/1706.10229)

*dataset_nocut.xyz* is provided in this repository.
It is a set of perfect one-atom BCC structures, i.e. structures whose Bravais lattice vectors are:

*   a1 = <a, 0, 0>
*   a2 = <0, a, 0>
*   a3 = <a/2, a/2, a/2>



In [None]:
# Read the BCC structures whose energies the model has to predict.
tmp_db = read(path2data + 'dataset_nocut.xyz', index=":")

# `volumes` is only read by the guard of the plotting cell; `predictions`
# is rebuilt there.
volumes = []
predictions = []

# Cell volume of every structure exactly as read from the file.
original_volumes = [structure.cell.volume for structure in tmp_db]

# Model inputs, one per structure.
predict_reticoli = [data_object(structure) for structure in tmp_db]

print(predict_reticoli)
print(original_volumes)

In [None]:
# Baseline curves (DFT and GAP) shipped with the repository: volumes and
# energies are stored as separate files and loaded in this order.
dft_v, dft_e, gap_v, gap_e = (
    torch.load(path2data + file_name)
    for file_name in ("dft_fig2_volumes.pt",
                      "dft_fig2_energies.pt",
                      "gap_fig2_volumes.pt",
                      "gap_fig2_energies.pt")
)

In [None]:
import matplotlib.pyplot as plt

# Bigger fonts for thesis images
'''
plt.rcParams.update({'font.size':25})
plt.rcParams.update({'axes.titlesize':30})
plt.rcParams.update({'figure.titlesize':30})
plt.rcParams.update({'xtick.labelsize':15})
plt.rcParams.update({'ytick.labelsize':15})
plt.rcParams.update({'legend.fontsize':25})
'''

predictions = []

# Predict the per-atom energy of every BCC structure with the trained model.
# (The former loop that copied every cell to the CPU into an unused `xyz`
# variable has been removed: it had no effect on the result.)
if len(volumes) == 0:
  test_loader = DataLoader(predict_reticoli, batch_size=1, shuffle=False)

  for data in test_loader:
    with torch.no_grad():
      if HAS_PBC:
        out = model(data.charges, data.x, data.cell, data.batch)
      else:
        out = model(data.charges, data.x, batch=data.batch)

      out = out.squeeze(1)
      predicted_energy = float(out.view(-1))
      predictions.append(predicted_energy / int(data.n)) # divide the output by the number of atoms of the structure


energies = predictions.copy()
# Energies are multiplied by 1000 for plotting.
# NOTE(review): presumably eV -> meV so they are comparable with `dft_e` — confirm.
energies_scaled = [(en*1000) for en in energies]


print("Plotted VO / Energies:")
print(original_volumes)
print(energies)


fig, ax = plt.subplots(figsize=(16,10))
plt.cla()
# Fixed axis window of the figure.
ax.set_xlim(10.99, 12)
ax.set_ylim(-3460934.301071031, -3460914.9599393304)

ax.set_title(modelName)
plt.xlabel("V")
plt.ylabel("E")


b = ax.scatter(original_volumes, energies_scaled, color="blue")
ax.plot(original_volumes, energies_scaled, color="blue")
b.set_label('DimeNet')

c = ax.scatter(dft_v, dft_e, color="black")
ax.plot(dft_v, dft_e, color="black")

c.set_label('DFT')

ax.legend()
plt.show()

## Bain Path Curve

We want the network to be able to reproduce the Bain Path Curve, that is the c/a ratio / Energy curve reported in FIG.3 of [Dragoni's Paper](https://arxiv.org/abs/1706.10229)

In [None]:
# Load DFT baseline
# The file has one "c/a ; energy" pair per line, with ';' as field separator
# and ',' as decimal mark; blank lines are skipped.

dft_fig3_ca = []
dft_fig3_en = []
# `with` closes the file even if parsing fails (the handle used to be left open).
with open(path2data+"dft_fig3.csv", "r") as f:
  lines = f.read().split("\n")
for line in lines:
  if line.strip() == "":
    continue
  fields = line.split(";")
  dft_fig3_ca.append(float(fields[0].replace(",", ".").strip()))
  dft_fig3_en.append(float(fields[1].replace(",", ".").strip()))

out = ""
print(dft_fig3_ca)
print(dft_fig3_en)

In [None]:
# Constant Volume Method: generate the 54 atoms structures by varying the c/a ratio

from ase.atoms import Cell
import math

predictions_vol = []
predictions = []

predict_reticoli_single = []
tmp_dataset = read(path2data + 'DB1.xyz', index=":")

# Reference lattice parameter and template structure.
a = 2.83477
reticolo = tmp_dataset[0]

# c/a ratios from 0.5 upwards in steps of 0.02.
coeff = np.arange(0.5,2.05,0.02)

predict_reticoli_vol = []

dime_x_plot_m1 = []
for sr in coeff.tolist():
  # r stretches the third lattice vector, 1/sqrt(r) shrinks the in-plane
  # ones: the cell volume stays constant while c/a equals `sr`.
  r = sr**(2/3)
  a_c = r**(-1)
  in_plane = math.sqrt(a_c)*a

  tmp_ret = reticolo.copy()
  tmp_ret.cell = Cell([
      [in_plane, 0.0, 0.0],
      [0.0, in_plane, 0.0],
      [in_plane/2, in_plane/2, (r)*a/2],
  ])
  predict_reticoli_vol.append(data_object(tmp_ret))

  dime_x_plot_m1.append(sr)

In [None]:
# Volume Optimization Method: same c/a scan as the Constant Volume Method, but
#                             for every c/a ratio the lattice parameter is also
#                             scanned from 0.95*a up to 1.05*a in steps of 0.02


predict_reticoli_single = []
tmp_dataset = read(path2data + 'DB1.xyz', index=":")


a = 2.83477
reticolo = tmp_dataset[0]

# One list of candidate structures per c/a ratio.
predict_reticoli = []

coeff = np.arange(0.5,2.05,0.02)
# Candidate lattice parameters; the range does not depend on the c/a ratio,
# so it is computed once outside the loop.
a_tmp = np.arange(0.95*a,1.05*a,0.02)
dime_x_plot_m2 = []
for sr in coeff.tolist():
  r = sr**(2/3)
  a_c = r**(-1)
  tmp_arr = []
  for current_a in a_tmp.tolist():
    in_plane = math.sqrt(a_c)*current_a
    tmp_ret = reticolo.copy()
    tmp_ret.cell = Cell([
        [in_plane, 0.0, 0.0],
        [0.0, in_plane, 0.0],
        [in_plane/2, in_plane/2, (r)*current_a/2],
    ])
    tmp_arr.append(data_object(tmp_ret))

  predict_reticoli.append(tmp_arr)
  dime_x_plot_m2.append(sr)


# Print a summary instead of dumping every generated structure.
print(len(predict_reticoli), "c/a ratios x", len(a_tmp), "lattice parameters")

### Draw Figure

Notice that the network performs well in the [0.8 - 1.2] range, which is where the training data is concentrated. Outside that range the network is extrapolating.

In [None]:
import matplotlib.pyplot as plt


# Predictions for the Method 2 (volume optimization): for every c/a ratio
# keep only the best result (lowest per-atom energy among the candidates)
all_points = []
all_points_x = []
elements = 0
if len(predictions) == 0:
  for topredict in predict_reticoli:
    test_loader = DataLoader(topredict, batch_size=1, shuffle=False)
    tmp_predictions = []
    for data in test_loader:
      with torch.no_grad():
        if HAS_PBC:
          out = model(data.charges, data.x, data.cell, data.batch)
        else:
          out = model(data.charges, data.x, batch=data.batch)
        out = out.squeeze(1)
        predicted_energy = float(out.view(-1))
        tmp_predictions.append(predicted_energy / int(data.n)) # divide by the atoms of the structure

    # Keep every candidate (with its c/a ratio) for inspection.
    all_points.extend(tmp_predictions)
    all_points_x.extend([dime_x_plot_m2[elements]] * len(tmp_predictions))

    # 9999 is kept as the upper bound/fallback the manual search used to start from.
    predictions.append(min(tmp_predictions + [9999]))
    elements = elements + 1

# Predictions for the Method 1 (constant volume)
if len(predictions_vol) == 0:
  test_loader = DataLoader(predict_reticoli_vol, batch_size=1, shuffle=False)
  for data in test_loader:
    with torch.no_grad():
      if HAS_PBC:
        out = model(data.charges, data.x, data.cell, data.batch)
      else:
        out = model(data.charges, data.x, batch=data.batch)
      out = out.squeeze(1)
      predicted_energy = float(out.view(-1))
      predictions_vol.append(predicted_energy / int(data.n))


# `energies`     -> Method 2 (volume optimization), x axis: dime_x_plot_m2
# `energies_vol` -> Method 1 (constant volume),     x axis: dime_x_plot_m1
energies = predictions.copy()
energies_vol = predictions_vol.copy()

# Energies are multiplied by 1000 for printing and plotting.
# NOTE(review): the DFT baseline is plotted as read from the csv, without this
# factor — confirm both curves use the same unit and reference.
energies_scaled = [(en*1000) for en in energies]
energies_vol_scaled = [(en*1000) for en in energies_vol]

fig, ax = plt.subplots(figsize=(16,10))
plt.cla()

ax.set_title(modelName)
plt.xlabel("c/a")
plt.ylabel("E")

#ax.set_xlim(10.95, 12.05)
test = [en for en in dft_fig3_en]

ax.set_ylim(min(test)-50, max(test)+150)
ax.set_xlim(0.7, 2.1)

# Plot DFT Baseline
c = ax.scatter(dft_fig3_ca, [en for en in dft_fig3_en], color="black")
ax.plot(dft_fig3_ca, [en for en in dft_fig3_en], color="black")
c.set_label("DFT")

# Each heading is now followed by the energies of that same method
# (the two lists used to be printed under each other's heading).
print("Dimenet Constant Volume:")
print(dime_x_plot_m1)
print(energies_vol_scaled)
print("----")
print("Dimenet Volume Optimization:")
print(dime_x_plot_m2)
print(energies_scaled)

d = ax.scatter(dime_x_plot_m2, energies_scaled, color="blue")
ax.plot(dime_x_plot_m2, energies_scaled, color="blue")
d.set_label("DimeNet Vol. Optimization")

e = ax.scatter(dime_x_plot_m1, energies_vol_scaled, color="red")
ax.plot(dime_x_plot_m1, energies_vol_scaled, color="red")
e.set_label("DimeNet Costant Vol.")


# Draw vertical lines
plt.axvline(x=0.8, linestyle='--', label="bct")
plt.axvline(x=1.0, linestyle='--', color='orange', label="bcc")
plt.axvline(x=math.sqrt(2), linestyle='--', color='brown', label="fcc")

plt.legend()
plt.show()

## Periodic Boundary Conditions Sanity Check

Here we test that PBCs are implemented correctly in the network by building structures of different sizes (e.g. 4, 16, 54 or 128 atoms) from the same Bravais lattice; the cells below use the 4-atom and the 128-atom structures.

The predicted energies, each divided by the number of atoms in its structure, should be roughly the same if PBCs are working correctly. Notice that this sanity check should be performed only when testing a PBC-trained model.

In [None]:
# Build a 4-atom and a 128-atom request from the first DB1 structure.
# NOTE(review): data_object only expands structures with exactly one atom, so
# this presumably relies on db1[0] being a single-atom cell — confirm.
db1 = read(path2data + 'DB1.xyz', index=":")
first_4, first_128 = (data_object(db1[0], num_atoms=size) for size in (4, 128))

print(first_4)
print(first_128)

In [None]:
# Total predicted energy of each structure; stays 0 if its loader yields nothing.
totals = {"e_4": 0, "e_128": 0}

for label, sample in (("e_4", first_4), ("e_128", first_128)):
  test_loader = DataLoader([sample], batch_size=1, shuffle=False)
  for data in test_loader:
    with torch.no_grad():
      if HAS_PBC:
        out = model(data.charges, data.x, cell=data.cell, batch=data.batch)
      else:
        out = model(data.charges, data.x, batch=data.batch)
      out = out.squeeze(1)
      predicted_energy = float(out.view(-1))
      totals[label] = predicted_energy

e_4 = totals["e_4"]
e_128 = totals["e_128"]

# Per-atom energies use the nominal sizes (4 and 128) of the two structures.
print("e_4: ", e_4)
print("e_4 / 4: ", (e_4 / 4))
print("e_128: ", e_128)
print("e_128 / 128: ", (e_128 / 128))

print("error: ", abs((e_128 / 128) - (e_4 / 4)))

## Vacancy Formation Energy

Here we want to predict the vacancy formation energy, i.e. the energy cost of a single vacancy in the structure (the removal of one atom).

The datasets used for this test are given with the repository.


In [None]:
# Read the bulk and single-vacancy structures for both supercell sizes
bulk_128_dataset = read(path2data + 'bulk_128atoms.xyz', index=":")
vac_127_dataset = read(path2data + 'vac_127atoms.xyz', index=":")
bulk_54_dataset = read(path2data + 'bulk_54atoms.xyz', index=":")
vac_53_dataset = read(path2data + 'vac_53atoms.xyz', index=":")

# 128-atom bulk and 127-atom vacancy structures as model inputs
bulk_lattice_128 = [data_object(lattice) for lattice in bulk_128_dataset]
vac_lattice_127 = [data_object(lattice) for lattice in vac_127_dataset]


print(bulk_lattice_128)
print(vac_lattice_127)


# 54-atom bulk and 53-atom vacancy structures as model inputs
bulk_lattice_54 = [data_object(lattice) for lattice in bulk_54_dataset]
vac_lattice_53 = [data_object(lattice) for lattice in vac_53_dataset]

print(bulk_lattice_54)
print(vac_lattice_53)

In [None]:
#Predict energies for the 128 / 127 configurations
# Reference value the predicted vacancy formation energy is compared with.
target = 2.48920984
totals = {"bulk": 0, "vacancy": 0}

# NOTE(review): when a dataset holds more than one structure only the energy
# of the last one is kept — presumably each file contains a single structure.
for label, structures in (("bulk", bulk_lattice_128), ("vacancy", vac_lattice_127)):
  test_loader = DataLoader(structures, batch_size=1, shuffle=False)
  for data in test_loader:
    with torch.no_grad():
      if HAS_PBC:
        out = model(data.charges, data.x, data.cell, data.batch)
      else:
        out = model(data.charges, data.x, batch=data.batch)
      out = out.squeeze(1)
      predicted_energy = float(out.view(-1))
      totals[label] = predicted_energy

e_total = totals["bulk"]
e_total_v = totals["vacancy"]


print("E_Total (128 / 127)", e_total)
print("E_Total_v (128 / 127)", e_total_v)
print("")

# E_v = E(N-1 atoms, with vacancy) - (N-1)/N * E(N atoms, bulk), with N = 128
e_v = e_total_v - e_total * (128 - 1) / 128
print("E_v (128 / 127)", e_v)
print("error %", abs(((target - e_v) / target) * 100))

In [None]:
#Predict energies for the 54 / 53 configurations
# Reference value the predicted vacancy formation energy is compared with.
target = 2.27723546
totals = {"bulk": 0, "vacancy": 0}

# NOTE(review): when a dataset holds more than one structure only the energy
# of the last one is kept — presumably each file contains a single structure.
for label, structures in (("bulk", bulk_lattice_54), ("vacancy", vac_lattice_53)):
  test_loader = DataLoader(structures, batch_size=1, shuffle=False)
  for data in test_loader:
    with torch.no_grad():
      if HAS_PBC:
        out = model(data.charges, data.x, data.cell, data.batch)
      else:
        out = model(data.charges, data.x, batch=data.batch)
      out = out.squeeze(1)
      predicted_energy = float(out.view(-1))
      totals[label] = predicted_energy

e_total = totals["bulk"]
e_total_v = totals["vacancy"]


print("E_Total (54 / 53)", e_total)
print("E_Total_v (54 / 53)", e_total_v)
print("")

# E_v = E(N-1 atoms, with vacancy) - (N-1)/N * E(N atoms, bulk), with N = 54
e_v = e_total_v - e_total * (54 - 1) / 54
print("E_v (54 / 53)", e_v)
print("error %", abs(((target - e_v) / target) * 100))