# Setup

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
# Colab environment setup: ASE, PyTorch 1.8.0, the PyTorch Geometric extension
# wheels built for torch-1.8.0+cu102, and PyG itself pinned to the 1.7.0 git tag.
# NOTE(review): installing torch 1.8.0 conflicts with the torchvision/torchtext
# preinstalled on Colab (see the resolver warning in the output); neither is
# imported in the visible notebook.
# NOTE(review): ase and the torch-* extension packages are not version-pinned.
!pip install -q ase
!pip install -q torch==1.8.0
!pip install  torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
!pip install -q torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
!pip install -q torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git@1.7.0

[K     |████████████████████████████████| 2.2 MB 5.4 MB/s 
[K     |████████████████████████████████| 735.5 MB 15 kB/s 
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchvision 0.10.0+cu102 requires torch==1.9.0, but you have torch 1.8.0 which is incompatible.
torchtext 0.10.0 requires torch==1.9.0, but you have torch 1.8.0 which is incompatible.[0m
[?25hLooking in links: https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.8.0%2Bcu102/torch_scatter-2.0.8-cp37-cp37m-linux_x86_64.whl (8.1 MB)
[K     |████████████████████████████████| 8.1 MB 5.3 MB/s 
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.0.8
[K     |████████████████████████████████| 3.1 MB 5.4 MB/s 
[K     |████████████████████████████████| 1.5 MB 5.3 MB/s 
[K     |██████

In [None]:
import torch
import numpy as np
import random
import matplotlib.pyplot as plt
import ase
from pprint import pprint
from torch_geometric.data import DataLoader

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
# Most cells refer to `DEVICE`; `data_object` uses the lowercase `device`,
# so both names are bound to the same object.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = DEVICE
# The models are cast with .double(), so every input tensor is float64.
DTYPE = torch.float64
# NOTE(review): `ff` is not referenced anywhere else in the visible notebook.
ff = torch.tensor(54, dtype=DTYPE)

**If you are not running this notebook on Colab, comment out the first two lines of the following cell (the Google Drive import and mount).**

**Set `BASE_PATH` to the folder that contains the files of the GNN_atomistics repository.**

In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
# Root of the GNN_atomistics files on the mounted Drive; everything below derives from it.
BASE_PATH = '/content/drive/MyDrive/gnn_atomistics'
DATA_PATH = f"{BASE_PATH}/data/evaluation"   # evaluation datasets (EOS, Bain path, ...)
MODELS_PATH = f"{BASE_PATH}/models"          # pretrained checkpoints

# Model definitions

In [None]:
from ase import Atoms
from ase.neighborlist import neighbor_list 

def pbc_edges(cutoff, z, x, cell, batch, compute_sc=False):
  """Periodic neighbour list for one structure or a batch of structures.

  For each structure, ``ase.neighborlist.neighbor_list("ijS", ...)`` is run on
  an ``Atoms`` object with ``pbc=True`` and ``self_interaction=False``. Atom
  indices are offset so that they index the concatenated ``x``.

  Args:
    cutoff: neighbour cutoff radius (same length unit as ``x``).
    z: 1-D per-atom tensor, passed to ``ase.Atoms`` as ``charges``.
    x: (num_atoms, 3) Cartesian positions; structures are concatenated.
    cell: lattice vectors as rows. A (3, 3) matrix for a single structure,
      or the per-structure matrices stacked to (3 * num_structures, 3) when
      ``batch`` is given.
    batch: 1-D tensor assigning each atom to a structure, or ``None`` for a
      single structure. The atoms of a structure are assumed to be contiguous
      and ordered by structure id.
    compute_sc: if True, also return per-edge "shift cells": row k of the
      structure's cell scaled by shift[k], flattened to 9 values.

  Returns:
    ``(i, j, dist, shifts, shift_cells)``. ``dist[e]`` is
    ``|x[j] - x[i] + shifts[e] @ cell|``. ``shift_cells`` is ``None`` unless
    ``compute_sc``.
  """
  # Keep every output on the same device/dtype as the inputs instead of
  # relying on a notebook-level DEVICE/DTYPE global.
  dev = x.device
  fdtype = cell.dtype
  # ASE needs NumPy arrays; detach so tensors that require grad can be converted.
  positions = x.detach().cpu().numpy()

  if batch is None:
    sizes = [x.size(0)]
  else:
    # Sorted unique ids give the structures in the same order as the 3-row
    # blocks of `cell`; counts are the number of atoms per structure.
    sizes = torch.unique(batch, sorted=True, return_counts=True)[1].tolist()

  if not sizes:
    # Empty batch: return empty edge tensors with the usual shapes.
    empty_idx = torch.empty(0, dtype=torch.long, device=dev)
    empty_sc = torch.empty(0, 9, dtype=fdtype, device=dev) if compute_sc else None
    return (empty_idx, empty_idx.clone(), torch.empty(0, dtype=fdtype, device=dev),
            torch.empty(0, 3, dtype=fdtype, device=dev), empty_sc)

  i_parts, j_parts, s_parts, d_parts, sc_parts = [], [], [], [], []
  offset = 0
  for g, n in enumerate(sizes):
    cell_g = cell[3 * g:3 * (g + 1)]
    atoms = Atoms(charges=z[offset:offset + n].detach().cpu().numpy(),
                  positions=positions[offset:offset + n],
                  cell=cell_g.detach().cpu().numpy(),
                  pbc=True)

    nh1, nh2, s = neighbor_list("ijS", atoms, cutoff, self_interaction=False)
    nh1 = torch.as_tensor(nh1, dtype=torch.long, device=dev) + offset
    nh2 = torch.as_tensor(nh2, dtype=torch.long, device=dev) + offset
    s = torch.as_tensor(s, dtype=fdtype, device=dev)

    # Vector from atom i to the periodic image of atom j selected by shift s.
    d_parts.append(x[nh2] - x[nh1] + torch.matmul(s, cell_g))
    i_parts.append(nh1)
    j_parts.append(nh2)
    s_parts.append(s)

    if compute_sc:
      # Row k of cell_g scaled by s[:, k], flattened row-major: [s0*a, s1*b, s2*c].
      sc_parts.append((s.unsqueeze(-1) * cell_g.unsqueeze(0)).reshape(-1, 9))

    offset += n

  # Concatenate once (O(total edges)) instead of growing tensors in the loop.
  NH1 = torch.cat(i_parts, 0)
  NH2 = torch.cat(j_parts, 0)
  S = torch.cat(s_parts, 0)
  D = torch.cat(d_parts, 0).norm(dim=-1)
  SC = torch.cat(sc_parts, 0) if compute_sc else None
  return NH1, NH2, D, S, SC

## SchNet

In [None]:
import os
import warnings
import os.path as ospa
from math import pi as PI
import torch.nn.functional as F
from torch.nn import Embedding, Sequential, Linear, ModuleList
from torch_scatter import scatter
from torch_geometric.data.makedirs import makedirs
from torch_geometric.data import download_url, extract_zip
from torch_geometric.nn import radius_graph, MessagePassing

class SchNet2(torch.nn.Module):
    """SchNet-style energy model that can build its graph with periodic images.

    When ``cell`` is passed to ``forward``, edges and distances come from
    ``pbc_edges`` (ASE neighbour list). Otherwise they come from
    ``radius_graph`` on the raw positions.

    Args:
        hidden_channels: size of the atom embeddings.
        num_filters: width of the continuous-filter convolutions.
        num_interactions: number of interaction blocks.
        num_gaussians: number of Gaussians used to expand distances.
        cutoff: neighbour cutoff radius, also the cosine-cutoff length in CFConv.
        readout: 'add', 'sum' or 'mean' aggregation of per-atom outputs.
        dipole: if True, return the norm of a mass-centred dipole instead of
            a scalar sum.
        mean, std: if both are given (and not ``dipole``), per-atom outputs
            are mapped to ``h * std + mean`` before the readout.
        atomref: optional tensor copied into an ``Embedding(100, 1)`` whose
            per-element values are added to the per-atom outputs.
    """

    def __init__(self, hidden_channels=128, num_filters=128,
                 num_interactions=6, num_gaussians=50, cutoff=10.0,
                 readout='add', dipole=False, mean=None, std=None,
                 atomref=None):
        super().__init__()

        assert readout in ['add', 'sum', 'mean']

        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_interactions = num_interactions
        self.num_gaussians = num_gaussians
        self.cutoff = cutoff
        self.readout = readout
        self.dipole = dipole
        # The dipole path sums per-atom vectors, so it always uses 'add'.
        self.readout = 'add' if self.dipole else self.readout
        self.mean = mean
        self.std = std
        # Optional multiplier applied to the final output; None disables it.
        self.scale = None

        atomic_mass = torch.from_numpy(ase.data.atomic_masses)
        self.register_buffer('atomic_mass', atomic_mass)

        self.embedding = Embedding(100, hidden_channels)
        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)

        # Use the SchNet interaction block through its private alias: the
        # DimeNet cell imports PyG's DimeNet `InteractionBlock` under the same
        # global name, which would otherwise be picked up here.
        self.interactions = ModuleList()
        for _ in range(num_interactions):
            block = SchNetInteractionBlock(hidden_channels, num_gaussians,
                                           num_filters, cutoff)
            self.interactions.append(block)

        self.lin1 = Linear(hidden_channels, hidden_channels // 2)
        self.act = ShiftedSoftplus()
        self.lin2 = Linear(hidden_channels // 2, 1)

        self.register_buffer('initial_atomref', atomref)
        self.atomref = None
        if atomref is not None:
            self.atomref = Embedding(100, 1)
            self.atomref.weight.data.copy_(atomref)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all learnable parameters (atomref from its saved copy)."""
        self.embedding.reset_parameters()
        for interaction in self.interactions:
            interaction.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)
        if self.atomref is not None:
            self.atomref.weight.data.copy_(self.initial_atomref)

    def forward(self, z, x, cell=None, batch=None):
        """Predict one value per structure.

        Args:
            z: 1-D long tensor of atomic numbers (indices into the embedding).
            x: (num_atoms, 3) positions.
            cell: stacked cell matrices for ``pbc_edges``, or ``None`` for a
                non-periodic radius graph.
            batch: structure index per atom; ``None`` means a single structure.

        Returns:
            Tensor with one row per structure (``scatter`` over ``batch``).
        """
        assert z.dim() == 1 and z.dtype == torch.long
        batch = torch.zeros_like(z) if batch is None else batch

        h = self.embedding(z)

        if cell is not None:
            row, col, edge_weight, _, _ = pbc_edges(self.cutoff, z, x, cell, batch, compute_sc=False)
            edge_index = torch.stack((row, col))
        else:
            edge_index = radius_graph(x, r=self.cutoff, batch=batch)
            row, col = edge_index
            edge_weight = (x[row] - x[col]).norm(dim=-1)

        edge_attr = self.distance_expansion(edge_weight)

        # Residual updates of the atom features.
        for interaction in self.interactions:
            h = h + interaction(h, edge_index, edge_weight, edge_attr)

        h = self.lin1(h)
        h = self.act(h)
        h = self.lin2(h)

        if self.dipole:
            # Scale per-atom outputs by the position relative to the centre of mass.
            mass = self.atomic_mass[z].view(-1, 1)
            c = scatter(mass * x, batch, dim=0) / scatter(mass, batch, dim=0)
            h = h * (x - c[batch])

        if not self.dipole and self.mean is not None and self.std is not None:
            h = h * self.std + self.mean

        if not self.dipole and self.atomref is not None:
            h = h + self.atomref(z)

        out = scatter(h, batch, dim=0, reduce=self.readout)

        if self.dipole:
            out = torch.norm(out, dim=-1, keepdim=True)

        if self.scale is not None:
            out = self.scale * out

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'hidden_channels={self.hidden_channels}, '
                f'num_filters={self.num_filters}, '
                f'num_interactions={self.num_interactions}, '
                f'num_gaussians={self.num_gaussians}, '
                f'cutoff={self.cutoff})')


class InteractionBlock(torch.nn.Module):
    """SchNet interaction: CFConv -> shifted softplus -> Linear.

    The filter-generating MLP maps the ``num_gaussians`` distance features to
    ``num_filters`` filter weights.
    """

    def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff):
        super(InteractionBlock, self).__init__()
        self.mlp = Sequential(
            Linear(num_gaussians, num_filters),
            ShiftedSoftplus(),
            Linear(num_filters, num_filters),
        )
        self.conv = CFConv(hidden_channels, hidden_channels, num_filters,
                           self.mlp, cutoff)
        self.act = ShiftedSoftplus()
        self.lin = Linear(hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.mlp[0].weight)
        self.mlp[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.mlp[2].weight)
        # Was `self.mlp[0].bias` a second time, leaving mlp[2].bias at its
        # default initialisation.
        self.mlp[2].bias.data.fill_(0)
        self.conv.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin.weight)
        self.lin.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        x = self.conv(x, edge_index, edge_weight, edge_attr)
        x = self.act(x)
        x = self.lin(x)
        return x


# Stable name for SchNet2: the DimeNet cell rebinds the global `InteractionBlock`.
SchNetInteractionBlock = InteractionBlock


class CFConv(MessagePassing):
    """Continuous-filter convolution with additive aggregation.

    Messages are ``lin1(x)_j * W``, where ``W = nn(edge_attr)`` scaled by the
    cosine cutoff ``0.5 * (cos(d * pi / cutoff) + 1)``.
    """

    def __init__(self, in_channels, out_channels, num_filters, nn, cutoff):
        super(CFConv, self).__init__(aggr='add')
        self.lin1 = Linear(in_channels, num_filters, bias=False)
        self.lin2 = Linear(num_filters, out_channels)
        self.nn = nn
        self.cutoff = cutoff

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0)
        W = self.nn(edge_attr) * C.view(-1, 1)

        x = self.lin1(x)
        x = self.propagate(edge_index, x=x, W=W)
        x = self.lin2(x)
        return x

    def message(self, x_j, W):
        return x_j * W


class GaussianSmearing(torch.nn.Module):
    """Expand scalar distances on ``num_gaussians`` evenly spaced Gaussians.

    Centres are ``linspace(start, stop, num_gaussians)``. Each width is set by
    the spacing between consecutive centres: ``exp(-0.5 * (d - mu)^2 / spacing^2)``.
    """

    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super(GaussianSmearing, self).__init__()
        centers = torch.linspace(start, stop, num_gaussians)
        spacing = (centers[1] - centers[0]).item()
        self.coeff = (-0.5 / spacing**2)
        # Stored as a buffer named 'offset' so checkpoints keep loading.
        self.register_buffer('offset', centers)

    def forward(self, dist):
        """Return a (num_distances, num_gaussians) tensor of Gaussian features."""
        deltas = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * deltas.pow(2))

class ShiftedSoftplus(torch.nn.Module):
    """Softplus shifted down by log(2), so the output at 0 is (approximately) 0."""

    def __init__(self):
        super(ShiftedSoftplus, self).__init__()
        # log(2) computed in float32 by torch, stored as a Python float.
        log_two = torch.tensor(2.0).log()
        self.shift = float(log_two)

    def forward(self, x):
        shifted = torch.nn.functional.softplus(x) - self.shift
        return shifted

## DimeNet

In [None]:
#@title  { vertical-output: true, form-width: "25%" }
from torch_geometric.nn import DimeNet
from torch_geometric.nn.acts import swish
from math import sqrt, pi as PI

import numpy as np
import torch
from torch.nn import Linear, Embedding
from torch_scatter import scatter
from torch_sparse import SparseTensor
from torch_geometric.nn import radius_graph
from torch_geometric.data import download_url
from torch_geometric.data.makedirs import makedirs

from torch_geometric.nn.models.dimenet import Envelope
from torch_geometric.nn.models.dimenet import BesselBasisLayer
from torch_geometric.nn.models.dimenet import SphericalBasisLayer
from torch_geometric.nn.models.dimenet import ResidualLayer
from torch_geometric.nn.models.dimenet import InteractionBlock
from torch_geometric.nn.models.dimenet import OutputBlock

from torch_geometric.nn.models.dimenet_utils import bessel_basis, real_sph_harm
import ase
from ase.neighborlist import neighbor_list 
from ase import Atoms

from torch_geometric.data import DataLoader

try:
    import sympy as sym
except ImportError:
    sym = None

import os
try:
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    import tensorflow as tf
except ImportError:
    tf = None


class EmbeddingBlock(torch.nn.Module):
  """Initial DimeNet edge embedding.

  For every edge (i, j), the embeddings of atoms i and j are concatenated with
  a bias-free linear projection of the radial basis. The result goes through
  ``Linear(3 * hidden, hidden)`` followed by ``act``.
  """

  def __init__(self, num_radial, hidden_channels, act=swish):
    super(EmbeddingBlock, self).__init__()
    self.act = act

    self.emb = Embedding(95, hidden_channels)
    self.lin_rbf = Linear(num_radial, hidden_channels, bias=False)
    self.lin = Linear(3 * hidden_channels, hidden_channels)

    self.reset_parameters()

  def reset_parameters(self):
    bound = sqrt(3)
    self.emb.weight.data.uniform_(-bound, bound)
    for layer in (self.lin_rbf, self.lin):
      layer.reset_parameters()

  def forward(self, x, rbf, i, j):
    atom_feats = self.emb(x)
    # Deliberately no activation on the RBF projection (marked as a fix in
    # the original notebook).
    rbf_proj = self.lin_rbf(rbf)
    edge_feats = torch.cat([atom_feats[i], atom_feats[j], rbf_proj], dim=-1)
    return self.act(self.lin(edge_feats))



class DimeNet2(torch.nn.Module):
  """DimeNet energy model assembled from PyTorch Geometric's DimeNet layers.

  ``forward`` returns ``(P, E)``. ``P`` holds the per-atom outputs,
  de-standardised with ``mean``/``std`` when both are given. ``E`` is their sum
  per structure (``scatter`` over ``batch``, or a plain sum when ``batch`` is None).

  NOTE(review): ``forward`` accepts ``cell`` but does not use it. Edges come
  from ``radius_graph`` on raw positions, so periodic images are not included,
  and the shift-aware ``triplets`` helper and ``self.padding`` are never used.
  TODO confirm whether a PBC path (cf. ``pbc_edges``) was intended.
  """

  def __init__(self, hidden_channels, out_channels, num_blocks, num_bilinear,
               num_spherical, num_radial, cutoff=5.0, envelope_exponent=5,
               num_before_skip=1, num_after_skip=2, num_output_layers=3,
               act=swish, mean=None, std=None):
    super().__init__()
    # Bind DimeNet's InteractionBlock explicitly. The SchNet cell defines its
    # own module-level `InteractionBlock`, so the global name depends on
    # which cell ran last.
    from torch_geometric.nn.models.dimenet import InteractionBlock as DimeNetInteractionBlock

    self.cutoff = cutoff

    # Energy standardisation statistics, applied in forward().
    self.mean = mean
    self.std = std

    # Padding intended for PBC shift vectors (unused by forward()).
    self.padding = torch.nn.ConstantPad2d((0, 6, 0, 0), 0)

    if sym is None:
      raise ImportError('Package `sympy` could not be found.')

    self.num_blocks = num_blocks

    self.rbf = BesselBasisLayer(num_radial, cutoff, envelope_exponent)
    self.sbf = SphericalBasisLayer(num_spherical, num_radial, cutoff,
                                   envelope_exponent)

    self.emb = EmbeddingBlock(num_radial, hidden_channels, act)

    self.output_blocks = torch.nn.ModuleList([
        OutputBlock(num_radial, hidden_channels, out_channels,
                    num_output_layers, act) for _ in range(num_blocks + 1)
    ])

    self.interaction_blocks = torch.nn.ModuleList([
        DimeNetInteractionBlock(hidden_channels, num_bilinear, num_spherical,
                                num_radial, num_before_skip, num_after_skip, act)
        for _ in range(num_blocks)
    ])

    self.reset_parameters()

  def reset_parameters(self):
    self.rbf.reset_parameters()
    self.emb.reset_parameters()
    for out in self.output_blocks:
      out.reset_parameters()
    for interaction in self.interaction_blocks:
      interaction.reset_parameters()

  def triplets_original(self, edge_index, num_nodes):
    """Enumerate k->j->i triplets of a non-periodic graph, dropping i == k.

    Returns ``(i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji)``, where i/j are the
    edge endpoints and idx_kj/idx_ji index the two edges of each triplet.
    """
    row, col = edge_index  # j->i

    value = torch.arange(row.size(0), device=row.device)
    adj_t = SparseTensor(row=col, col=row, value=value,
                         sparse_sizes=(num_nodes, num_nodes))
    adj_t_row = adj_t[row]
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets.
    idx_i = col.repeat_interleave(num_triplets)
    idx_j = row.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    mask = (idx_i != idx_k)  # Remove i == k triplets.
    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

    # Edge indices (k-j, j->i) for triplets.
    idx_kj = adj_t_row.storage.value()[mask]
    idx_ji = adj_t_row.storage.row()[mask]

    return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji

  def triplets(self, edge_index, num_nodes, shift_cells=None, shift=None):
    """Triplet enumeration that can also carry periodic shift vectors.

    With ``shift_cells``/``shift``, a triplet is dropped only when k equals i
    *and* their (index, shift) rows are identical, so a periodic image of i
    is kept. Without them, it falls back to dropping i == k (the original
    raised NameError in that case).

    Returns the outputs of ``triplets_original`` followed by
    ``shift_cells_i, shift_i, shift_cells_k, shift_k`` (all ``None`` when
    ``shift_cells`` is ``None``).
    """
    row, col = edge_index  # j->i

    value = torch.arange(row.size(0), device=row.device)
    adj_t = SparseTensor(row=col, col=row, value=value,
                         sparse_sizes=(num_nodes, num_nodes))
    adj_t_row = adj_t[row]
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    idx_i = col.repeat_interleave(num_triplets)
    idx_j = row.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    edge_kj = adj_t_row.storage.value()

    if shift_cells is not None:  # Update also the shift vectors
      shift_cells_i = shift_cells.repeat_interleave(num_triplets, dim=0)
      shift_i = shift.repeat_interleave(num_triplets, dim=0)
      shift_cells_k = -shift_cells[edge_kj]
      shift_k = -shift[edge_kj]
      # Compare (atom index, shift) rows; cast indices to the shift dtype so
      # the concatenation is well-typed.
      key_i = torch.cat((idx_i.unsqueeze(1).to(shift_i.dtype), shift_i), dim=1)
      key_k = torch.cat((idx_k.unsqueeze(1).to(shift_k.dtype), shift_k), dim=1)
      same = torch.all(key_i == key_k, dim=1)
    else:
      shift_cells_i = shift_i = shift_cells_k = shift_k = None
      same = idx_i == idx_k
    keep = ~same

    idx_i, idx_j, idx_k = idx_i[keep], idx_j[keep], idx_k[keep]
    if shift_cells is not None:  # Remove also from the shift vector
      shift_cells_i = shift_cells_i[keep]
      shift_cells_k = shift_cells_k[keep]
      shift_i = shift_i[keep]
      shift_k = shift_k[keep]

    idx_kj = edge_kj[keep]
    idx_ji = adj_t_row.storage.row()[keep]

    return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji, shift_cells_i, shift_i, shift_cells_k, shift_k

  def forward(self, z, pos, cell=None, batch=None):
    """Compute per-atom outputs and their per-structure sum.

    Args:
      z: atomic numbers (indices into the 95-entry embedding).
      pos: (num_atoms, 3) positions.
      cell: currently ignored (see class NOTE).
      batch: structure index per atom, or None for a single structure.

    Returns:
      ``(P, E)`` — per-atom outputs and their sum per structure.
    """
    edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
    i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets_original(
        edge_index, num_nodes=z.size(0))
    dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

    pos_i = pos[idx_i]
    pos_j = pos[idx_j]
    pos_k = pos[idx_k]

    # Angle between j->i and k->j via atan2(|a x b|, a . b). Pass dim=-1
    # explicitly: without it torch.cross uses the first size-3 dimension,
    # which is dim 0 when there happen to be exactly 3 triplets.
    pos_ji, pos_kj = pos_j - pos_i, pos_k - pos_j
    cos_part = (pos_ji * pos_kj).sum(dim=-1)
    sin_part = torch.cross(pos_ji, pos_kj, dim=-1)
    sin_part = torch.linalg.norm(sin_part + 1e-16, dim=-1)
    angle = torch.atan2(sin_part, cos_part)

    rbf = self.rbf(dist)
    sbf = self.sbf(dist, angle, idx_kj)

    # Embedding block.
    x = self.emb(z, rbf, i, j)
    P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0))

    # Interaction blocks, each followed by its own output block.
    for interaction_block, output_block in zip(self.interaction_blocks,
                                               self.output_blocks[1:]):
      x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
      P = P + output_block(x, rbf, i, num_nodes=pos.size(0))

    # Energy de-standardization
    if self.std is not None and self.mean is not None:
      P = P * self.std + self.mean

    return P, (P.sum(dim=0) if batch is None else scatter(P, batch, dim=0))


# Older cells instantiate the model under this name.
DimeNetx = DimeNet2

# Other function definitions

In [None]:
def volume(coords):
  """Signed volume of a cell: the determinant of its 3x3 lattice matrix.

  ``coords`` may be any array-like with 9 entries (3x3 nested or flat).
  """
  lattice = np.reshape(np.array(coords), (3, 3))
  return np.linalg.det(lattice)

In [None]:
from torch_geometric.data import Data

DTYPE = torch.float64
FE_CHARGE = 26
def data_object(atoms: ase.Atoms, extend_single=None):
  """Convert an ASE ``Atoms`` object into a PyG ``Data`` graph for the models.

  Every atom gets atomic number ``FE_CHARGE``, whatever ``atoms`` contains.

  Args:
    atoms: the structure; must carry ``atoms.info["energy"]``.
    extend_single: whether a single-atom cell is replaced by
      ``extend_atoms(atoms, 54)`` before conversion. ``None`` (default) reads
      the notebook-level flag ``DB1x54`` and treats a missing flag as False.
      (Previously a missing flag raised NameError for every 1-atom cell.)

  Returns:
    ``Data(charges, x, y, cell, n, vol)``. ``vol`` is the signed volume of the
    *input* cell, computed before any extension.
  """
  vol = volume(atoms.cell)

  if extend_single is None:
    # NOTE(review): DB1x54 is not defined in any earlier visible cell.
    extend_single = bool(globals().get("DB1x54", False))

  if atoms.get_global_number_of_atoms() == 1 and extend_single:
    # NOTE(review): extend_atoms is not defined in the visible notebook.
    atoms = extend_atoms(atoms, 54)

  n = atoms.get_global_number_of_atoms()

  cell = torch.tensor(np.asarray(atoms.cell), dtype=DTYPE, device=device)
  positions = torch.tensor(atoms.get_positions(), dtype=DTYPE, device=device)

  charges = torch.full((positions.size(0),), FE_CHARGE, dtype=torch.long, device=device)

  y = torch.tensor(atoms.info["energy"], dtype=DTYPE, device=device)

  return Data(charges=charges, x=positions, y=y, cell=cell, n=n, vol=vol)

In [None]:
from ase import Atoms
def data_object_da_vett(cell: list, positions: list, energy = 0.0):
  """Build a ``Data`` object from raw lattice vectors and Cartesian positions.

  Args:
    cell: 3x3 lattice vectors (rows).
    positions: list of Cartesian atom positions.
    energy: reference energy stored in ``atoms.info["energy"]``.
  """
  n_atoms = len(positions)
  # `charges=` only sets ASE's initial charges; `numbers=` makes the atoms
  # actually Fe (previously they were atomic number 0, symbol 'X').
  # Both are set so anything reading either attribute sees FE_CHARGE.
  atoms = Atoms(numbers=[FE_CHARGE] * n_atoms,
                charges=[FE_CHARGE] * n_atoms,
                positions=positions,
                cell=cell,
                info={ "energy": energy },
                pbc=True)
  return data_object(atoms)

# Instantiate models and load pretrained

## SchNet

In [None]:
schnet_fn = "schnet_PBC_best (15giu)"
schnet_name = schnet_fn
# torch.load unpickles the checkpoint: only load trusted files.
model_data = torch.load(f"{MODELS_PATH}/{schnet_fn}", map_location=DEVICE)
print(f"{model_data['desc']} \n {model_data['str']} \nmean = {model_data['mean']}, std = {model_data['std']}")

In [None]:
schnet = SchNet2(hidden_channels=128, num_filters=128, num_gaussians=128, cutoff=5.0,
                 num_interactions=3, readout="sum", mean=-3460.825847482401,
                 std=0.16576383029449515).double().to(DEVICE)
# Checkpoints saved from a DataParallel-wrapped model prefix every key with
# "module."; strip it (as done for DimeNet below) so either format loads.
_prefix = "module."
schnet_state = {(k[len(_prefix):] if k.startswith(_prefix) else k): v
                for k, v in model_data["state"].items()}
schnet.load_state_dict(schnet_state)
schnet.eval()

## DimeNet

In [None]:
dimenet_fn = "dimenet_PBC_best (10giu_57)"
dimenet_name = dimenet_fn
# torch.load unpickles the checkpoint: only load trusted files.
model_data = torch.load(f"{MODELS_PATH}/{dimenet_fn}", map_location=DEVICE)
print(f"{model_data['desc']} \n {model_data['str']} \nmean = {model_data['mean']}, std = {model_data['std']}")

In [None]:
# `DimeNetx` is not defined anywhere in the notebook; the class is DimeNet2.
dimenet = DimeNet2(hidden_channels=128, out_channels=1, num_blocks=7, num_bilinear=8,
                   num_spherical=7, num_radial=6, cutoff=3.5,
                   mean=-3460.8258474824015, std=0.16576383029449682).double().to(DEVICE)

In [None]:
state = model_data["state"]
first_key = list(state.keys())[0]
if "module" in first_key:
  # Trained with DataParallel: drop everything up to the first "." of each key.
  new_state = {key[key.find(".") + 1:]: value for key, value in state.items()}
else:
  new_state = state
dimenet.load_state_dict(new_state)
dimenet.eval()

# **Equation of state**

## Define functions

In [None]:
def eos_data():
  """Build the structures used for the equation of state.

  Lattice parameters run from 0.98*a0 to 1.02*a0 in steps of 0.0005*a0, plus
  a0 itself, sorted. Each becomes a one-atom cell
  ``[[a, 0, 0], [0, a, 0], [a/2, a/2, a/2]]``.
  """
  a0 = 2.843877166
  lattice_params = np.sort(np.append(np.arange(a0*0.98, a0*1.02, 0.0005*a0), a0))
  data_list = [
      data_object_da_vett([[a, 0, 0], [0, a, 0], [a/2, a/2, a/2]], [[0.0, 0.0, 0.0]])
      for a in lattice_params
  ]
  print(f"generated {len(data_list)} points")
  return data_list

In [None]:
def eos(model):
  """Evaluate ``model`` on every structure produced by ``eos_data()``.

  Works with models that return a tensor (SchNet2) and with models that
  return a ``(per_atom, per_structure)`` tuple (DimeNet2).

  Returns:
    (volumes, energies): lists of floats, one entry per structure.
  """
  out_x = []
  out_y = []
  model.eval()
  data_list = eos_data()
  test_loader = DataLoader(data_list, batch_size=1, shuffle=False)
  with torch.no_grad():
    for data in test_loader:
      # data_object() stores atomic numbers under `charges`; there is no `z`.
      z = data.charges if getattr(data, "charges", None) is not None else data.z
      e = model(z, data.x, cell=data.cell, batch=data.batch)
      if isinstance(e, tuple):
        # Keep the per-structure energy.
        e = e[-1]
      out_y.append(e.squeeze(1).item())
      out_x.append(data.vol.item())

  return out_x, out_y

In [None]:
def equation_of_state(model_name, x_model, y_model, x_dft, y_dft):
  """Plot the model's E(V) curve (energies multiplied by 1000) against DFT.

  The DFT energies are plotted as given. Axis limits are fixed to the
  window used for these structures.
  """
  plt.rcParams.update({
      'font.size': 25,
      'axes.titlesize': 30,
      'figure.titlesize': 30,
      'xtick.labelsize': 15,
      'ytick.labelsize': 15,
      'legend.fontsize': 25,
  })

  _fig, ax = plt.subplots(figsize=(16, 10))
  plt.cla()
  ax.set_xlim(10.99, 12)
  ax.set_ylim(-3460934.301071031, -3460914.9599393304)
  ax.set_title(model_name)
  plt.xlabel("V")
  plt.ylabel("E")

  scaled_model_y = [en * 1000 for en in y_model]
  ax.scatter(x_model, scaled_model_y, color="blue", label=model_name)
  ax.plot(x_model, scaled_model_y, color="blue")

  ax.scatter(x_dft, y_dft, color="black", label="DFT")
  ax.plot(x_dft, y_dft, color="black")

  ax.legend()
  plt.show()

## Load data

In [None]:
# Reference E(V) curves: DFT and GAP volumes/energies.
eos_x_dft, eos_y_dft, eos_x_gap, eos_y_gap = (
    torch.load(f"{DATA_PATH}/equation-of-state/{stem}.pt", map_location=DEVICE)
    for stem in ("dft_v", "dft_e", "gap_v", "gap_e")
)

## Plot

### Schnet

In [None]:
eos_x_schnet, eos_y_schnet = eos(model=schnet)
equation_of_state(model_name=schnet_name, x_model=eos_x_schnet, y_model=eos_y_schnet,
                  x_dft=eos_x_dft, y_dft=eos_y_dft)

### Dimenet

In [None]:
eos_x_dimenet, eos_y_dimenet = eos(model=dimenet)
equation_of_state(model_name=dimenet_name, x_model=eos_x_dimenet, y_model=eos_y_dimenet,
                  x_dft=eos_x_dft, y_dft=eos_y_dft)

# **Bain path**

## Define functions

In [None]:
import math
from math import sqrt
from torch_geometric.data import DataLoader
def bain_vo(model, a_coeff_min=0.95, a_coeff_max=1.05):
  """Bain path where the lattice parameter is scanned at every ratio ("VO").

  For each ``sr`` in ``np.arange(0.5, 2.05, 0.02)``, with ``r = sr**(2/3)``
  and ``a_c = 1/r``, one-atom cells
  ``[[sqrt(a_c)*a, 0, 0], [0, sqrt(a_c)*a, 0], [sqrt(a_c)*a/2, sqrt(a_c)*a/2, r*a/2]]``
  are evaluated for several lattice parameters ``a``. The lowest energy is kept.

  Args:
    model: called as ``model(z, x, cell=..., batch=...)``. It may return a
      tensor or a ``(per_atom, per_structure)`` tuple.
    a_coeff_min, a_coeff_max: NOTE(review): unused. The scanned lattice
      parameters are hard-coded to ``np.arange(0.98*a_0, 1.02*a_0, 0.02)``
      plus ``a_0``. Kept for interface compatibility.

  Returns:
    plot_x: the ``sr`` values.
    plot_y: minimum energy over the scanned lattice parameters, per ``sr``.
    extra_x, extra_y: every evaluated point (x is ``r/sqrt(1/r)``).
    cas: ``r`` for each ``sr``.
  """
  a_0 = 2.83477
  # Lattice parameters scanned at every ratio (they do not depend on the ratio).
  a_values = np.arange(0.98*a_0, 1.02*a_0, 0.02).tolist()
  a_values.append(a_0)

  plot_x = []
  plot_y = []
  extra_x = []
  extra_y = []
  cas = []

  model.eval()
  with torch.no_grad():
    for sr_ in np.arange(0.5, 2.05, 0.02):
      sr = sr_.item()
      r = sr**(2/3)
      a_c = r**(-1)

      data_list = []
      for current_a in a_values:
        side = math.sqrt(a_c)*current_a
        cell = [[side, 0, 0], [0, side, 0], [side/2, side/2, (r)*current_a/2]]
        data_list.append(data_object_da_vett(cell, [[0.0, 0.0, 0.0]], 0.0))

      candidate_ys = []
      for data in DataLoader(data_list, batch_size=1, shuffle=False):
        # data_object() stores atomic numbers under `charges`; there is no `z`.
        z = data.charges if getattr(data, "charges", None) is not None else data.z
        # Was `batch=batch`, an undefined name.
        e = model(z, data.x, cell=data.cell, batch=data.batch)
        if isinstance(e, tuple):
          e = e[-1]  # per-structure energy
        candidate_ys.append(e.squeeze(1).item())

      plot_x.append(sr)
      extra_x += [r/sqrt(1/r)]*len(candidate_ys)
      cas.append(r)
      plot_y.append(min(candidate_ys))
      extra_y += candidate_ys

  return plot_x, plot_y, extra_x, extra_y, cas

In [None]:
def bain_cv(model):
  """Bain path at fixed lattice parameter ("CV"): one cell per ratio.

  For each ``sr`` in ``np.arange(0.5, 2, 0.02)``, with ``r = sr**(2/3)`` and
  ``a_c = 1/r``, the cell
  ``[[sqrt(a_c)*a, 0, 0], [0, sqrt(a_c)*a, 0], [sqrt(a_c)*a/2, sqrt(a_c)*a/2, r*a/2]]``
  is evaluated. Its determinant is a**3/2 for every ratio, so the volume is constant.

  Returns:
    (plot_x, plot_y): ``r/sqrt(1/r)`` per ratio and the model energies.
  """
  a = 2.83477

  plot_x = []
  plot_y = []
  data_list = []

  for sr_ in np.arange(0.5, 2, 0.02):
    sr = sr_.item()
    r = sr**(2/3)
    a_c = r**(-1)

    side = sqrt(a_c)*a
    cell = [[side, 0, 0], [0, side, 0], [side/2, side/2, (r)*a/2]]
    data_list.append(data_object_da_vett(cell, [[0.0, 0.0, 0.0]], 0.0))
    plot_x.append(r/sqrt(1/r))

  model.eval()
  with torch.no_grad():
    for data in DataLoader(data_list, batch_size=1, shuffle=False):
      # data_object() stores atomic numbers under `charges`; there is no `z`.
      z = data.charges if getattr(data, "charges", None) is not None else data.z
      e = model(z, data.x, cell=data.cell, batch=data.batch)
      if isinstance(e, tuple):
        e = e[-1]  # per-structure energy
      plot_y.append(e.squeeze(1).item())

  return plot_x, plot_y

In [None]:
def bain_path(model_name, calc_mode, model_x, model_y, dft_x, dft_y):
  """Plot a model's Bain-path energies (multiplied by 1000) against DFT.

  Dashed vertical lines mark c/a = 0.8 ("bct"), 1.0 ("bcc") and sqrt(2) ("fcc").
  The DFT energies are plotted as given.
  """
  plt.rcParams.update({
      'font.size': 25,
      'axes.titlesize': 30,
      'figure.titlesize': 30,
      'xtick.labelsize': 15,
      'ytick.labelsize': 15,
      'legend.fontsize': 20,
  })

  series_label = model_name + " " + calc_mode
  scaled_model_y = [en * 1000 for en in model_y]
  dft_energies = list(dft_y)

  _fig, ax = plt.subplots(figsize=(16, 10))
  plt.cla()
  ax.set_xlim(0.7, 2.05)
  ax.set_ylim(min(scaled_model_y) - 50, -3.4605e6 + 90)
  ax.set_title(series_label)
  plt.xlabel("c/a")
  plt.ylabel("E")

  plt.axvline(x=0.8, linestyle='--', label="bct")
  plt.axvline(x=1.0, linestyle='--', color='orange', label="bcc")
  plt.axvline(x=math.sqrt(2), linestyle='--', color='brown', label="fcc")

  ax.scatter(dft_x, dft_energies, color="black", label="DFT")
  ax.plot(dft_x, dft_energies, color="black")

  ax.scatter(model_x, scaled_model_y, color="blue")
  ax.plot(model_x, scaled_model_y, color="blue", label=series_label)

  ax.legend()

## Load data

In [None]:
# Read the DFT Bain-path reference; `with` closes the file (the bare open()
# leaked the handle).
with open(f"{DATA_PATH}/bain-path/DFT.csv", "r") as dft_csv:
  a = dft_csv.read()
# Drop ';', turn decimal commas into points, then split on whitespace. The
# tokens alternate x, y, x, y, ...
d = a.replace(";", "").replace(",", ".").split()
bain_x_dft = [float(x) for x in d[0::2]]
bain_y_dft = [float(x) for x in d[1::2]]

## Plot

### Schnet

In [None]:
bain_vo_x_schnet, bain_vo_y_schnet, _, _, _ = bain_vo(model=schnet)
bain_path(model_name=schnet_name, calc_mode="VO", model_x=bain_vo_x_schnet,
          model_y=bain_vo_y_schnet, dft_x=bain_x_dft, dft_y=bain_y_dft)

In [None]:
bain_cv_x_schnet, bain_cv_y_schnet = bain_cv(model=schnet)
bain_path(model_name=schnet_name, calc_mode="CV", model_x=bain_cv_x_schnet,
          model_y=bain_cv_y_schnet, dft_x=bain_x_dft, dft_y=bain_y_dft)

### Dimenet

In [None]:
bain_vo_x_dimenet, bain_vo_y_dimenet, _, _, _ = bain_vo(model=dimenet)
bain_path(model_name=dimenet_name, calc_mode="VO", model_x=bain_vo_x_dimenet,
          model_y=bain_vo_y_dimenet, dft_x=bain_x_dft, dft_y=bain_y_dft)

In [None]:
bain_cv_x_dimenet, bain_cv_y_dimenet = bain_cv(model=dimenet)
bain_path(model_name=dimenet_name, calc_mode="CV", model_x=bain_cv_x_dimenet,
          model_y=bain_cv_y_dimenet, dft_x=bain_x_dft, dft_y=bain_y_dft)

# **Vacancy formation energy**

## Define functions

In [None]:
def eval_vacancy(model, loader):
  """Return the model energy (as a float) for every batch in ``loader``.

  Each batch must hold one structure and expose ``x``, ``cell``, ``batch`` and
  atomic numbers as ``charges`` (as built by ``data_object``) or ``z``.
  The model may return a tensor or a ``(per_atom, per_structure)`` tuple.
  """
  ycap = []
  model.eval()
  with torch.no_grad():
    for data in loader:
      z = data.charges if getattr(data, "charges", None) is not None else data.z
      e = model(z, data.x, cell=data.cell, batch=data.batch)
      if isinstance(e, tuple):
        e = e[-1]  # per-structure energy
      ycap.append(e.squeeze(1).item())
  return ycap

In [None]:
def vacancy_formation_energy(model_name, N, E_total, E_total_v):
  """Print the vacancy formation energy E_v = E_total_v - E_total * (N - 1) / N.

  Args:
    model_name: label printed first.
    N: number of atoms in the perfect cell (the defective one has N - 1).
    E_total: energy of the perfect N-atom cell.
    E_total_v: energy of the cell with one vacancy.
  """
  lines = (
      (model_name,),
      (f"{N} -> {N - 1}",),
      ("E_total_v =", E_total_v),
      ("E_total = ", E_total),
  )
  for args in lines:
    print(*args)
  print("E_v =", E_total_v - (E_total * (N - 1) / N))

## Load data

In [None]:
#bravais lattice vectors
vac_128_lattice_str = """
 11.3360000   0.0000000   0.0000000 
  0.0000000  11.3360000   0.0000000 
  0.0000000   0.0000000  11.3360000 
""" 
#sites
vac_128_sites_str = """
Fe      0.0000000   0.0000000   0.0000000
Fe      1.4170000   1.4170000   1.4170000
Fe      8.5020000   0.0000000   2.8340000
Fe      5.6680000   5.6680000   0.0000000
Fe      8.5020000   2.8340000   0.0000000
Fe      2.8340000   2.8340000   5.6680000
Fe      2.8340000   8.5020000   8.5020000
Fe      0.0000000   0.0000000   8.5020000
Fe      5.6680000   5.6680000   8.5020000
Fe      0.0000000   2.8340000   5.6680000
Fe      5.6680000   0.0000000   5.6680000
Fe      8.5020000   2.8340000   8.5020000
Fe      0.0000000   5.6680000   2.8340000
Fe      2.8340000   0.0000000   5.6680000
Fe      8.5020000   5.6680000   5.6680000
Fe      0.0000000   8.5020000   0.0000000
Fe      2.8340000   2.8340000   2.8340000
Fe      8.5020000   8.5020000   2.8340000
Fe      2.8340000   5.6680000   0.0000000
Fe      5.6680000   8.5020000   5.6680000
Fe      5.6680000   0.0000000   2.8340000
Fe      5.6680000   2.8340000   0.0000000
Fe      5.6680000   2.8340000   2.8340000
Fe      8.5020000   0.0000000   0.0000000
Fe      0.0000000   8.5020000   8.5020000
Fe      2.8340000   5.6680000   8.5020000
Fe      2.8340000   8.5020000   5.6680000
Fe      0.0000000   0.0000000   5.6680000
Fe      5.6680000   2.8340000   8.5020000
Fe      8.5020000   0.0000000   8.5020000
Fe      0.0000000   2.8340000   2.8340000
Fe      5.6680000   5.6680000   5.6680000
Fe      8.5020000   2.8340000   5.6680000
Fe      0.0000000   5.6680000   0.0000000
Fe      2.8340000   0.0000000   2.8340000
Fe      8.5020000   5.6680000   2.8340000
Fe      2.8340000   2.8340000   0.0000000
Fe      5.6680000   8.5020000   2.8340000
Fe      8.5020000   8.5020000   0.0000000
Fe      5.6680000   0.0000000   0.0000000
Fe      0.0000000   5.6680000   8.5020000
Fe      0.0000000   8.5020000   5.6680000
Fe      2.8340000   2.8340000   8.5020000
Fe      8.5020000   8.5020000   8.5020000
Fe      2.8340000   5.6680000   5.6680000
Fe      5.6680000   0.0000000   8.5020000
Fe      2.8340000   8.5020000   2.8340000
Fe      0.0000000   0.0000000   2.8340000
Fe      5.6680000   2.8340000   5.6680000
Fe      8.5020000   0.0000000   5.6680000
Fe      0.0000000   2.8340000   0.0000000
Fe      5.6680000   5.6680000   2.8340000
Fe      8.5020000   2.8340000   2.8340000
Fe      2.8340000   0.0000000   0.0000000
Fe      5.6680000   8.5020000   0.0000000
Fe      8.5020000   5.6680000   0.0000000
Fe      0.0000000   2.8340000   8.5020000
Fe      5.6680000   8.5020000   8.5020000
Fe      0.0000000   5.6680000   5.6680000
Fe      2.8340000   0.0000000   8.5020000
Fe      8.5020000   5.6680000   8.5020000
Fe      0.0000000   8.5020000   2.8340000
Fe      8.5020000   8.5020000   5.6680000
Fe      2.8340000   5.6680000   2.8340000
Fe      2.8340000   8.5020000   0.0000000
Fe      9.9190000   1.4170000   4.2510000
Fe      7.0850000   7.0850000   1.4170000
Fe      9.9190000   4.2510000   1.4170000
Fe      4.2510000   4.2510000   7.0850000
Fe      4.2510000   9.9190000   9.9190000
Fe      1.4170000   1.4170000   9.9190000
Fe      7.0850000   7.0850000   9.9190000
Fe      1.4170000   4.2510000   7.0850000
Fe      7.0850000   1.4170000   7.0850000
Fe      9.9190000   4.2510000   9.9190000
Fe      1.4170000   7.0850000   4.2510000
Fe      4.2510000   1.4170000   7.0850000
Fe      9.9190000   7.0850000   7.0850000
Fe      1.4170000   9.9190000   1.4170000
Fe      4.2510000   4.2510000   4.2510000
Fe      9.9190000   9.9190000   4.2510000
Fe      4.2510000   7.0850000   1.4170000
Fe      7.0850000   9.9190000   7.0850000
Fe      7.0850000   1.4170000   4.2510000
Fe      7.0850000   4.2510000   1.4170000
Fe      7.0850000   4.2510000   4.2510000
Fe      9.9190000   1.4170000   1.4170000
Fe      1.4170000   9.9190000   9.9190000
Fe      4.2510000   7.0850000   9.9190000
Fe      4.2510000   9.9190000   7.0850000
Fe      1.4170000   1.4170000   7.0850000
Fe      7.0850000   4.2510000   9.9190000
Fe      9.9190000   1.4170000   9.9190000
Fe      1.4170000   4.2510000   4.2510000
Fe      7.0850000   7.0850000   7.0850000
Fe      9.9190000   4.2510000   7.0850000
Fe      1.4170000   7.0850000   1.4170000
Fe      4.2510000   1.4170000   4.2510000
Fe      9.9190000   7.0850000   4.2510000
Fe      4.2510000   4.2510000   1.4170000
Fe      7.0850000   9.9190000   4.2510000
Fe      9.9190000   9.9190000   1.4170000
Fe      7.0850000   1.4170000   1.4170000
Fe      1.4170000   7.0850000   9.9190000
Fe      1.4170000   9.9190000   7.0850000
Fe      4.2510000   4.2510000   9.9190000
Fe      9.9190000   9.9190000   9.9190000
Fe      4.2510000   7.0850000   7.0850000
Fe      7.0850000   1.4170000   9.9190000
Fe      4.2510000   9.9190000   4.2510000
Fe      1.4170000   1.4170000   4.2510000
Fe      7.0850000   4.2510000   7.0850000
Fe      9.9190000   1.4170000   7.0850000
Fe      1.4170000   4.2510000   1.4170000
Fe      7.0850000   7.0850000   4.2510000
Fe      9.9190000   4.2510000   4.2510000
Fe      4.2510000   1.4170000   1.4170000
Fe      7.0850000   9.9190000   1.4170000
Fe      9.9190000   7.0850000   1.4170000
Fe      1.4170000   4.2510000   9.9190000
Fe      7.0850000   9.9190000   9.9190000
Fe      1.4170000   7.0850000   7.0850000
Fe      4.2510000   1.4170000   9.9190000
Fe      9.9190000   7.0850000   9.9190000
Fe      1.4170000   9.9190000   4.2510000
Fe      9.9190000   9.9190000   7.0850000
Fe      4.2510000   7.0850000   4.2510000
Fe      4.2510000   9.9190000   1.4170000"""

vac_127_lattice_str = """
 11.3360000   0.0000000   0.0000000 
  0.0000000  11.3360000   0.0000000 
  0.0000000   0.0000000  11.3360000 
"""

vac_127_sites_str = """
Fe      1.4170000   1.4170000   1.4170000
Fe      8.5020000   0.0000000   2.8340000
Fe      5.6680000   5.6680000   0.0000000
Fe      8.5020000   2.8340000   0.0000000
Fe      2.8340000   2.8340000   5.6680000
Fe      2.8340000   8.5020000   8.5020000
Fe      0.0000000   0.0000000   8.5020000
Fe      5.6680000   5.6680000   8.5020000
Fe      0.0000000   2.8340000   5.6680000
Fe      5.6680000   0.0000000   5.6680000
Fe      8.5020000   2.8340000   8.5020000
Fe      0.0000000   5.6680000   2.8340000
Fe      2.8340000   0.0000000   5.6680000
Fe      8.5020000   5.6680000   5.6680000
Fe      0.0000000   8.5020000   0.0000000
Fe      2.8340000   2.8340000   2.8340000
Fe      8.5020000   8.5020000   2.8340000
Fe      2.8340000   5.6680000   0.0000000
Fe      5.6680000   8.5020000   5.6680000
Fe      5.6680000   0.0000000   2.8340000
Fe      5.6680000   2.8340000   0.0000000
Fe      5.6680000   2.8340000   2.8340000
Fe      8.5020000   0.0000000   0.0000000
Fe      0.0000000   8.5020000   8.5020000
Fe      2.8340000   5.6680000   8.5020000
Fe      2.8340000   8.5020000   5.6680000
Fe      0.0000000   0.0000000   5.6680000
Fe      5.6680000   2.8340000   8.5020000
Fe      8.5020000   0.0000000   8.5020000
Fe      0.0000000   2.8340000   2.8340000
Fe      5.6680000   5.6680000   5.6680000
Fe      8.5020000   2.8340000   5.6680000
Fe      0.0000000   5.6680000   0.0000000
Fe      2.8340000   0.0000000   2.8340000
Fe      8.5020000   5.6680000   2.8340000
Fe      2.8340000   2.8340000   0.0000000
Fe      5.6680000   8.5020000   2.8340000
Fe      8.5020000   8.5020000   0.0000000
Fe      5.6680000   0.0000000   0.0000000
Fe      0.0000000   5.6680000   8.5020000
Fe      0.0000000   8.5020000   5.6680000
Fe      2.8340000   2.8340000   8.5020000
Fe      8.5020000   8.5020000   8.5020000
Fe      2.8340000   5.6680000   5.6680000
Fe      5.6680000   0.0000000   8.5020000
Fe      2.8340000   8.5020000   2.8340000
Fe      0.0000000   0.0000000   2.8340000
Fe      5.6680000   2.8340000   5.6680000
Fe      8.5020000   0.0000000   5.6680000
Fe      0.0000000   2.8340000   0.0000000
Fe      5.6680000   5.6680000   2.8340000
Fe      8.5020000   2.8340000   2.8340000
Fe      2.8340000   0.0000000   0.0000000
Fe      5.6680000   8.5020000   0.0000000
Fe      8.5020000   5.6680000   0.0000000
Fe      0.0000000   2.8340000   8.5020000
Fe      5.6680000   8.5020000   8.5020000
Fe      0.0000000   5.6680000   5.6680000
Fe      2.8340000   0.0000000   8.5020000
Fe      8.5020000   5.6680000   8.5020000
Fe      0.0000000   8.5020000   2.8340000
Fe      8.5020000   8.5020000   5.6680000
Fe      2.8340000   5.6680000   2.8340000
Fe      2.8340000   8.5020000   0.0000000
Fe      9.9190000   1.4170000   4.2510000
Fe      7.0850000   7.0850000   1.4170000
Fe      9.9190000   4.2510000   1.4170000
Fe      4.2510000   4.2510000   7.0850000
Fe      4.2510000   9.9190000   9.9190000
Fe      1.4170000   1.4170000   9.9190000
Fe      7.0850000   7.0850000   9.9190000
Fe      1.4170000   4.2510000   7.0850000
Fe      7.0850000   1.4170000   7.0850000
Fe      9.9190000   4.2510000   9.9190000
Fe      1.4170000   7.0850000   4.2510000
Fe      4.2510000   1.4170000   7.0850000
Fe      9.9190000   7.0850000   7.0850000
Fe      1.4170000   9.9190000   1.4170000
Fe      4.2510000   4.2510000   4.2510000
Fe      9.9190000   9.9190000   4.2510000
Fe      4.2510000   7.0850000   1.4170000
Fe      7.0850000   9.9190000   7.0850000
Fe      7.0850000   1.4170000   4.2510000
Fe      7.0850000   4.2510000   1.4170000
Fe      7.0850000   4.2510000   4.2510000
Fe      9.9190000   1.4170000   1.4170000
Fe      1.4170000   9.9190000   9.9190000
Fe      4.2510000   7.0850000   9.9190000
Fe      4.2510000   9.9190000   7.0850000
Fe      1.4170000   1.4170000   7.0850000
Fe      7.0850000   4.2510000   9.9190000
Fe      9.9190000   1.4170000   9.9190000
Fe      1.4170000   4.2510000   4.2510000
Fe      7.0850000   7.0850000   7.0850000
Fe      9.9190000   4.2510000   7.0850000
Fe      1.4170000   7.0850000   1.4170000
Fe      4.2510000   1.4170000   4.2510000
Fe      9.9190000   7.0850000   4.2510000
Fe      4.2510000   4.2510000   1.4170000
Fe      7.0850000   9.9190000   4.2510000
Fe      9.9190000   9.9190000   1.4170000
Fe      7.0850000   1.4170000   1.4170000
Fe      1.4170000   7.0850000   9.9190000
Fe      1.4170000   9.9190000   7.0850000
Fe      4.2510000   4.2510000   9.9190000
Fe      9.9190000   9.9190000   9.9190000
Fe      4.2510000   7.0850000   7.0850000
Fe      7.0850000   1.4170000   9.9190000
Fe      4.2510000   9.9190000   4.2510000
Fe      1.4170000   1.4170000   4.2510000
Fe      7.0850000   4.2510000   7.0850000
Fe      9.9190000   1.4170000   7.0850000
Fe      1.4170000   4.2510000   1.4170000
Fe      7.0850000   7.0850000   4.2510000
Fe      9.9190000   4.2510000   4.2510000
Fe      4.2510000   1.4170000   1.4170000
Fe      7.0850000   9.9190000   1.4170000
Fe      9.9190000   7.0850000   1.4170000
Fe      1.4170000   4.2510000   9.9190000
Fe      7.0850000   9.9190000   9.9190000
Fe      1.4170000   7.0850000   7.0850000
Fe      4.2510000   1.4170000   9.9190000
Fe      9.9190000   7.0850000   9.9190000
Fe      1.4170000   9.9190000   4.2510000
Fe      9.9190000   9.9190000   7.0850000
Fe      4.2510000   7.0850000   4.2510000
Fe      4.2510000   9.9190000   1.4170000"""

vac_53_lattice_str = """
 8.5020000   0.0000000   0.0000000
 0.0000000   8.5020000   0.0000000
 0.0000000   0.0000000   8.5020000
"""

vac_53_sites_str = """
Fe            1.3710331141        1.3710328145        1.3710330239
Fe            2.8636873181        0.0000006056        0.0000002850
Fe            5.6383130622        0.0000006007        0.0000002554
Fe           -0.0000003236        0.0000005923        5.6383131764
Fe           -0.0000003588        5.6780135759        2.8239863030
Fe            2.8239871194        0.0000007435        5.6780131819
Fe            5.6780131857        0.0000007299        5.6780132388
Fe            2.8341486131        5.6678518988        2.8341480191
Fe            5.6678516811        5.6678519082        2.8341480047
Fe           -0.0000003414        2.8636857854        0.0000003193
Fe            2.8239869847        2.8239859741        0.0000003364
Fe            5.6780133287        2.8239860021        0.0000002954
Fe           -0.0000003894        2.8239858259        5.6780134613
Fe            2.8341485983        2.8341475140        5.6678517358
Fe            5.6678517287        2.8341475206        5.6678518095
Fe           -0.0000002984        0.0000005459        2.8636863926
Fe           -0.0000003480        5.6383133214        0.0000003529
Fe            2.8239870666        0.0000006725        2.8239865633
Fe            5.6780131849        0.0000006603        2.8239865538
Fe            2.8239870623        5.6780134537        0.0000003985
Fe            5.6780132879        5.6780134462        0.0000003540
Fe           -0.0000004008        5.6780135795        5.6780133363
Fe            2.8341486930        5.6678518824        5.6678515993
Fe            5.6678516714        5.6678518959        5.6678516707
Fe           -0.0000003469        2.8239859294        2.8239862830
Fe            2.8341485324        2.8341475875        2.8341479943
Fe            5.6678517451        2.8341475991        2.8341479618
Fe            4.2510003036        1.4201541600        1.4201542876
Fe            7.1309666303        1.3710328093        1.3710329475
Fe            1.3710331187        1.3710327648        7.1309672952
Fe            1.4201542656        7.0818467019        4.2509996322
Fe            4.2510003506        1.4201541416        7.0818459679
Fe            7.1309665941        1.3710327507        7.1309673430
Fe            4.2510003140        7.0939924882        4.2509997047
Fe            7.0818453833        7.0818466683        4.2509996776
Fe            1.4201540444        4.2509993812        1.4201540671
Fe            4.2510003236        4.2509994657        1.4080084858
Fe            7.0818455796        4.2509993973        1.4201540235
Fe            1.4201540621        4.2509993032        7.0818463141
Fe            4.2510003715        4.2509993925        7.0939919123
Fe            7.0818455294        4.2509993185        7.0818463189
Fe            1.4201542182        1.4201539663        4.2509997132
Fe            1.3710331908        7.1309676815        1.3710331340
Fe            4.2510002920        1.4080083865        4.2509997882
Fe            7.0818454339        1.4201539908        4.2509997531
Fe            4.2510003261        7.0818465133        1.4201543806
Fe            7.1309665364        7.1309676832        1.3710330651
Fe            1.3710332453        7.1309677188        7.1309671530
Fe            4.2510003731        7.0818465930        7.0818459245
Fe            7.1309664843        7.1309677507        7.1309672224
Fe            1.4080083930        4.2509994420        4.2509996333
Fe            4.2510003258        4.2509994018        4.2509996692
Fe            7.0939911399        4.2509994633        4.2509996798
"""

vac_54_lattice_str = """
 8.5020000   0.0000000   0.0000000
 0.0000000   8.5020000   0.0000000
 0.0000000   0.0000000   8.5020000
"""

vac_54_sites_str = """
Fe      0.0000000   0.0000000   0.0000000
Fe      1.4170000   1.4170000   1.4170000
Fe      2.8340000   0.0000000   0.0000000
Fe      5.6680000   0.0000000   0.0000000
Fe      0.0000000   0.0000000   5.6680000
Fe      0.0000000   5.6680000   2.8340000
Fe      2.8340000   0.0000000   5.6680000
Fe      5.6680000   0.0000000   5.6680000
Fe      2.8340000   5.6680000   2.8340000
Fe      5.6680000   5.6680000   2.8340000
Fe      0.0000000   2.8340000   0.0000000
Fe      2.8340000   2.8340000   0.0000000
Fe      5.6680000   2.8340000   0.0000000
Fe      0.0000000   2.8340000   5.6680000
Fe      2.8340000   2.8340000   5.6680000
Fe      5.6680000   2.8340000   5.6680000
Fe      0.0000000   0.0000000   2.8340000
Fe      0.0000000   5.6680000   0.0000000
Fe      2.8340000   0.0000000   2.8340000
Fe      5.6680000   0.0000000   2.8340000
Fe      2.8340000   5.6680000   0.0000000
Fe      5.6680000   5.6680000   0.0000000
Fe      0.0000000   5.6680000   5.6680000
Fe      2.8340000   5.6680000   5.6680000
Fe      5.6680000   5.6680000   5.6680000
Fe      0.0000000   2.8340000   2.8340000
Fe      2.8340000   2.8340000   2.8340000
Fe      5.6680000   2.8340000   2.8340000
Fe      4.2510000   1.4170000   1.4170000
Fe      7.0850000   1.4170000   1.4170000
Fe      1.4170000   1.4170000   7.0850000
Fe      1.4170000   7.0850000   4.2510000
Fe      4.2510000   1.4170000   7.0850000
Fe      7.0850000   1.4170000   7.0850000
Fe      4.2510000   7.0850000   4.2510000
Fe      7.0850000   7.0850000   4.2510000
Fe      1.4170000   4.2510000   1.4170000
Fe      4.2510000   4.2510000   1.4170000
Fe      7.0850000   4.2510000   1.4170000
Fe      1.4170000   4.2510000   7.0850000
Fe      4.2510000   4.2510000   7.0850000
Fe      7.0850000   4.2510000   7.0850000
Fe      1.4170000   1.4170000   4.2510000
Fe      1.4170000   7.0850000   1.4170000
Fe      4.2510000   1.4170000   4.2510000
Fe      7.0850000   1.4170000   4.2510000
Fe      4.2510000   7.0850000   1.4170000
Fe      7.0850000   7.0850000   1.4170000
Fe      1.4170000   7.0850000   7.0850000
Fe      4.2510000   7.0850000   7.0850000
Fe      7.0850000   7.0850000   7.0850000
Fe      1.4170000   4.2510000   4.2510000
Fe      4.2510000   4.2510000   4.2510000
Fe      7.0850000   4.2510000   4.2510000
"""

In [None]:
def from_lattice_sites(lattice_str, sites_str):
  """Parse a plain-text lattice block and a list of atomic sites.

  Args:
    lattice_str: whitespace-separated lattice components; all of them are read
      in order into one flat list (callers fold the 9 values into 3x3).
    sites_str: one site per line. Non-numeric tokens, such as an element symbol
      ("Fe" or any other species, leading or trailing), are ignored, and blank
      lines are skipped instead of producing empty sites.

  Returns:
    (lattice, sites): a flat list of floats, and one list of floats per site.
  """
  def _number_or_none(token):
    # float() or None for tokens like element symbols.
    try:
      return float(token)
    except ValueError:
      return None

  lattice = [float(v) for v in lattice_str.split()]
  sites = []
  for line in sites_str.splitlines():
    values = [v for v in map(_number_or_none, line.split()) if v is not None]
    if values:  # blank / label-only lines would otherwise become empty sites
      sites.append(values)
  # Quick sanity check of the parsed sizes (lattice components, number of sites).
  print(len(lattice), len(sites))
  return lattice, sites

In [None]:
# Supercell with 128 sites and its counterpart with 127 sites (one removed).
vac_128_lattice, vac_128_sites = from_lattice_sites(vac_128_lattice_str, vac_128_sites_str)
vac_127_lattice, vac_127_sites = from_lattice_sites(vac_127_lattice_str, vac_127_sites_str)

def reshape(x):
  """Fold a flat list of lattice components into three rows of three (row-major)."""
  return [[x[3 * row], x[3 * row + 1], x[3 * row + 2]] for row in range(3)]

vac_127_lattice = reshape(vac_127_lattice)
vac_128_lattice = reshape(vac_128_lattice)
# Third argument 0.0: presumably a placeholder target energy — confirm against data_object_da_vett.
vac_128 = data_object_da_vett(vac_128_lattice, vac_128_sites, 0.0)
vac_127 = data_object_da_vett(vac_127_lattice, vac_127_sites, 0.0)

# Same construction for the 54-site supercell and its 53-site counterpart.
vac_54_lattice, vac_54_sites = from_lattice_sites(vac_54_lattice_str, vac_54_sites_str)
vac_53_lattice, vac_53_sites = from_lattice_sites(vac_53_lattice_str, vac_53_sites_str)
vac_53_lattice = reshape(vac_53_lattice)
vac_54_lattice = reshape(vac_54_lattice)
vac_54 = data_object_da_vett(vac_54_lattice, vac_54_sites, 0.0)
vac_53 = data_object_da_vett(vac_53_lattice, vac_53_sites, 0.0)

In [None]:
# NOTE(review): repeats the DataLoader import from the imports cell; kept so this
# cell still uses torch_geometric's DataLoader even if the name was rebound in between.
from torch_geometric.data import DataLoader
# batch_size=1, shuffle=False: each loader yields the full supercell first and its
# vacancy counterpart second — the order the (E_total, E_total_v) results are unpacked in.
loader_128 = DataLoader([vac_128, vac_127], batch_size=1, shuffle=False)
loader_54 = DataLoader([vac_54, vac_53], batch_size=1, shuffle=False)

## Evaluation

### SchNet

In [None]:
# SchNet, 54-site supercell: loader_54 yields vac_54 then vac_53, matching the unpacking order.
E_total, E_total_v = eval_vacancy(schnet, loader_54)
vacancy_formation_energy(schnet_name, 54, E_total, E_total_v)

In [None]:
# SchNet, 128-site supercell: loader_128 yields vac_128 then vac_127, matching the unpacking order.
E_total, E_total_v = eval_vacancy(schnet, loader_128)
vacancy_formation_energy(schnet_name, 128, E_total, E_total_v)

### DimeNet

In [None]:
# DimeNet, 54-site supercell (rebinds the E_total / E_total_v globals set by the SchNet cells).
E_total, E_total_v = eval_vacancy(dimenet, loader_54)
vacancy_formation_energy(dimenet_name, 54, E_total, E_total_v)

In [None]:
# DimeNet, 128-site supercell (rebinds the E_total / E_total_v globals set by earlier cells).
E_total, E_total_v = eval_vacancy(dimenet, loader_128)
vacancy_formation_energy(dimenet_name, 128, E_total, E_total_v)

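# **Surface energy**

The surface energy is $\gamma = \frac{E_\text{surf} - E_\text{bulk}}{2A}$, where $A$ is the area of the bulk cell's $y$–$z$ face and the factor 2 accounts for the two free surfaces.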


## Define functions


In [None]:
import re  # NOTE(review): no longer used by the parsers below; kept in case other cells rely on it

def get_cell(configuration_string):
  """Return the orthogonal box of a LAMMPS data file as a 3x3 diagonal matrix.

  Reads the "xlo xhi", "ylo yhi" and "zlo zhi" lines (first two fields are the
  lo/hi bounds) and returns diag(xhi - xlo, yhi - ylo, zhi - zlo). Tilt
  factors ("xy xz yz") are not read.

  Raises:
    ValueError: if any of the three bound lines is missing (previously this
      surfaced as an UnboundLocalError).
  """
  keys = ("xlo xhi", "ylo yhi", "zlo zhi")
  bounds = {}
  for line in configuration_string.splitlines():
    for key in keys:
      if key in line:
        # Last matching line wins, as in the original sequential scan.
        lo, hi = (float(v) for v in line.split()[:2])
        bounds[key] = (lo, hi)
  missing = [key for key in keys if key not in bounds]
  if missing:
    raise ValueError(f"LAMMPS data is missing box bound line(s): {missing}")
  (xlo, xhi), (ylo, yhi), (zlo, zhi) = (bounds[key] for key in keys)
  return [[xhi - xlo, 0.0, 0.0], [0.0, yhi - ylo, 0.0], [0.0, 0.0, zhi - zlo]]

def get_positions(configuration_string):
  """Return [x, y, z] for every row of the "Atoms" section of a LAMMPS data file.

  Rows are assumed to be "id type x y z [ix iy iz]" (atom_style atomic —
  TODO confirm for other styles); columns 3-5 are used, so trailing image
  flags are ignored. The section starts after the line whose first token is
  "Atoms" and ends at the first blank line after the atom rows, or at the
  end of the text. As a result, later sections such as "Velocities" are not
  parsed, and a missing trailing newline no longer drops the last atom.

  Raises:
    ValueError: if no "Atoms" section header is found.
  """
  lines = configuration_string.splitlines()
  header = next((i for i, line in enumerate(lines) if line.split()[:1] == ["Atoms"]), None)
  if header is None:
    raise ValueError('LAMMPS data has no "Atoms" section')
  positions = []
  for line in lines[header + 1:]:
    fields = line.split()
    if not fields:
      if positions:
        break     # blank line after the atom rows closes the section
      continue    # blank line(s) between the header and the first row
    positions.append([float(v) for v in fields[2:5]])
  return positions

In [None]:
def data_object_from_lmpfile(fn):
  """Build a data object from the LAMMPS data file at path ``fn``.

  The box (via get_cell) and atom coordinates (via get_positions) are parsed
  from the file's text and passed to data_object_da_vett, without the third
  (energy) argument used by the other call sites.
  """
  with open(fn) as lmp_file:
    contents = lmp_file.read()
  return data_object_da_vett(get_cell(contents), get_positions(contents))

In [None]:
def eV_to_J(e):
  """Convert an energy from electronvolts to joules.

  Uses the exact SI relation 1 eV = 1.602176634e-19 J instead of dividing by
  the rounded 6.242e18 eV/J, which was off by roughly 8e-5 relative.
  """
  joules_per_ev = 1.602176634e-19
  return e * joules_per_ev

In [None]:
def A_to_m(l):
  """Convert a length from angstroms to metres."""
  angstroms_per_metre = 1e10
  return l / angstroms_per_metre

In [None]:
def evaluate(model, data_list):
  """Return the model's predicted energy for each structure in ``data_list``.

  Structures are fed one at a time and in order (batch_size=1, no shuffling),
  after switching ``model`` to eval mode. Each prediction is reduced to a
  Python float via ``squeeze(1).item()``.
  """
  single_batches = DataLoader(data_list, batch_size=1, shuffle=False)
  model.eval()
  return [
    model(graph.z, graph.x, cell=graph.cell, batch=graph.batch).squeeze(1).item()
    for graph in single_batches
  ]

In [None]:
def surface_energy(model, model_name, bulk, surf, id):
  """Print and return the surface energy predicted by ``model``.

  gamma = (E_surf - E_bulk) / (2 * A), where A is the product of the bulk
  cell's [1][1] and [2][2] entries (the y and z box lengths). The two free
  surfaces are therefore taken to be normal to x. NOTE(review): the formula
  also presumes ``bulk`` and ``surf`` contain the same atoms; confirm for the
  lmpmodel_* inputs.

  Args:
    model: model passed through to ``evaluate``.
    model_name: label printed before the results.
    bulk: bulk reference structure; its cell defines A.
    surf: surface structure, evaluated together with ``bulk``.
    id: tag appended to every printed label (e.g. "100").

  Returns:
    gamma in eV/A^2; the J/m^2 value is printed but not returned.
  """
  print(model_name)

  len_y = bulk.cell[1][1].item()
  len_z = bulk.cell[2][2].item()
  A = len_y * len_z
  print(f"A_{id}: {A}")

  E_bulk, E_surf = evaluate(model, [bulk, surf])
  for label, energy in (("E_bulk", E_bulk), ("E_surf", E_surf)):
    print(f"{label}_{id}: {energy}")

  gamma_surf = (E_surf - E_bulk) / (2 * A)

  # Same quantity in SI units: convert energies and box lengths first.
  bulk_joules = eV_to_J(E_bulk)
  surf_joules = eV_to_J(E_surf)
  area_m2 = A_to_m(len_y) * A_to_m(len_z)
  gamma_surf_J_m2 = (surf_joules - bulk_joules) / (2 * area_m2)

  print(f"gamma_surf_{id}: {gamma_surf} (eV/A^2) = {gamma_surf_J_m2} (J/m^2)")
  return gamma_surf

## Load data

In [None]:
# Directory of the LAMMPS data files (lmpmodel_<hkl> and lmpmodel_<hkl>_surf) loaded below;
# the trailing "/" lets file names be appended with "+".
SURF_PATH = f"{DATA_PATH}/surface-energy/"

In [None]:
# (100): bulk reference, then the corresponding surface structure.
bulk_100 = data_object_from_lmpfile(f"{SURF_PATH}lmpmodel_100")
surf_100 = data_object_from_lmpfile(f"{SURF_PATH}lmpmodel_100_surf")

In [None]:
# (110): bulk reference, then the corresponding surface structure.
bulk_110 = data_object_from_lmpfile(f"{SURF_PATH}lmpmodel_110")
surf_110 = data_object_from_lmpfile(f"{SURF_PATH}lmpmodel_110_surf")

In [None]:
# (111): bulk reference, then the corresponding surface structure.
bulk_111 = data_object_from_lmpfile(f"{SURF_PATH}lmpmodel_111")
surf_111 = data_object_from_lmpfile(f"{SURF_PATH}lmpmodel_111_surf")

In [None]:
# (112): bulk reference, then the corresponding surface structure.
bulk_112 = data_object_from_lmpfile(f"{SURF_PATH}lmpmodel_112")
surf_112 = data_object_from_lmpfile(f"{SURF_PATH}lmpmodel_112_surf")

## Evaluation

### SchNet

In [None]:
# SchNet (100) surface energy in eV/A^2 (the J/m^2 value is printed).
gamma_surf_100 = surface_energy(schnet, schnet_name, bulk_100, surf_100, "100")

In [None]:
# SchNet (110) surface energy in eV/A^2 (the J/m^2 value is printed).
gamma_surf_110 = surface_energy(schnet, schnet_name, bulk_110, surf_110, "110")

In [None]:
# SchNet (111) surface energy in eV/A^2 (the J/m^2 value is printed).
gamma_surf_111 = surface_energy(schnet, schnet_name, bulk_111, surf_111, "111")

In [None]:
# SchNet (112) surface energy in eV/A^2 (the J/m^2 value is printed).
gamma_surf_112 = surface_energy(schnet, schnet_name, bulk_112, surf_112, "112")

### DimeNet

In [None]:
# DimeNet (100) surface energy. NOTE(review): rebinds gamma_surf_100, overwriting the SchNet value.
gamma_surf_100 = surface_energy(dimenet, dimenet_name, bulk_100, surf_100, "100")

In [None]:
# DimeNet (110) surface energy. NOTE(review): rebinds gamma_surf_110, overwriting the SchNet value.
gamma_surf_110 = surface_energy(dimenet, dimenet_name, bulk_110, surf_110, "110")

In [None]:
# DimeNet (111) surface energy. NOTE(review): rebinds gamma_surf_111, overwriting the SchNet value.
gamma_surf_111 = surface_energy(dimenet, dimenet_name, bulk_111, surf_111, "111")

In [None]:
# DimeNet (112) surface energy. NOTE(review): rebinds gamma_surf_112, overwriting the SchNet value.
gamma_surf_112 = surface_energy(dimenet, dimenet_name, bulk_112, surf_112, "112")