# Initial Setup

*   Run this notebook on a GPU runtime (on Google Colab: Runtime → Change runtime type → Hardware accelerator: GPU).

In [None]:
#@title  { vertical-output: true, display-mode: "both" }

# install packages
# PyTorch Geometric companion libraries (torch-scatter, torch-sparse,
# torch-cluster) installed as prebuilt wheels from the torch-1.7.0 wheel index,
# with the "+cu101" (CUDA 10.1) build tag.
# NOTE(review): "latest" is not a pinned version and pytorch_geometric is built
# from the repository's default branch, so re-running this cell later may pull
# mutually incompatible releases — consider pinning exact versions.
!pip install -q torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
!pip install -q torch-sparse==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git
!pip install -q torch-cluster==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
# ase: atomic structures, XYZ reading (ase.io.read) and periodic neighbour lists
!pip install -q ase

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
import numpy as np
import torch
import ase
import random

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Fix the seed of every random number generator used by this notebook
# (Python's `random`, NumPy and PyTorch) so that runs can be repeated.
seed = 55555
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Additionally seed the CUDA generators explicitly when a GPU is present.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)

In [None]:
#@title  { vertical-output: true, display-mode: "both" }

BASE_PATH = '/content/drive/MyDrive/FeMaterials/'
TRAINING_RATIO = 0.8
NUCLEAR_CHARGE = 26 # default nuclear charge, assigned to every atom (26 = Fe)
OPTIMIZER = torch.optim.Adam
DTYPE = torch.float64
# Learning rate reported by print_hyperparameters(); keep it equal to the `lr`
# passed to the optimizer when the model is set up.
LEARNING_RATE = 0.0001

BATCH_SIZE = 6
CRITERION = torch.nn.MSELoss()
CROSSENTROPY = torch.nn.CrossEntropyLoss()

def print_hyperparameters():
  """Print the global training hyperparameters defined in this cell."""
  print("Default nuclear charge:", NUCLEAR_CHARGE)
  print("Training ratio:", TRAINING_RATIO)
  print("Batch size:", BATCH_SIZE)
  print("Optimizer:", OPTIMIZER)
  print("Learning rate:", LEARNING_RATE)
  print("Criterion:", CRITERION)

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
import time
from datetime import datetime
import os
from numpy import savetxt

# Mount Google Drive so that BASE_PATH (a path under /content/drive) is reachable.
from google.colab import drive
drive.mount('/content/drive')

# Folder that must contain the DB*.xyz input files.
path2data = f"{BASE_PATH}input/"

# One timestamped results folder per run, holding models/, maes/ and graphs/.
now = datetime.now()
resdir = f"{BASE_PATH}dimenet-results/{now:%Y%m%d-%H%M%S}/"
res_models_dir = f"{resdir}models/"
res_maes_dir = f"{resdir}maes/"
res_graphs_dir = f"{resdir}graphs/"

os.makedirs(resdir)
os.makedirs(res_models_dir)
os.makedirs(res_maes_dir)
os.makedirs(res_graphs_dir)

# Process Data

## Pre-Processing Functions

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
# Cached supercell matrix for target == 54, computed on first use by extend_atoms.
M_1_54 = None

def extend_atoms(atoms: ase.Atoms, target: int):
  """Return a supercell made of `target` copies of the cell of `atoms`.

  The transformation matrix comes from ``ase.build.find_optimal_cell_shape``
  (aiming for a simple-cubic shape) and is applied with
  ``ase.build.make_supercell``. ``target`` counts copies of the input cell, so it
  equals the resulting number of atoms only for single-atom inputs (the only
  case in which data_object calls this function).

  The total energy in ``atoms.info["energy"]`` is scaled by the ratio of atom
  counts, since the supercell contains that many copies of the original cell.

  NOTE: for target == 54 the matrix computed for the first call is cached in
  M_1_54 and reused for every later call, regardless of the input cell. Any
  integer matrix with determinant 54 still gives a 54-fold supercell, but it
  may not be the most cubic one for differently shaped cells.
  """
  global M_1_54
  # `import ase` alone does not guarantee that the ase.build submodule is loaded.
  from ase.build import find_optimal_cell_shape, make_supercell

  if M_1_54 is None or target != 54:
    transform = find_optimal_cell_shape(atoms.get_cell(), target, "sc")
    if target == 54:
      M_1_54 = transform
  else:
    transform = M_1_54

  supercell = make_supercell(atoms, transform)
  # Scale the (extensive) total energy by the actual replication factor rather
  # than target / n_atoms, which is only correct for single-atom inputs.
  supercell.info["energy"] = atoms.info["energy"] * (len(supercell) / len(atoms))
  return supercell

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
from torch_geometric.data import Data

def data_object(atoms: ase.Atoms, num_atoms=54):
  """Convert an ase.Atoms structure into a PyG ``Data`` object placed on `device`.

  Single-atom structures are first replicated with extend_atoms into a
  `num_atoms`-fold supercell (extend_atoms also rescales the stored energy).

  The returned Data has:
    charges: (n,) long tensor, every entry NUCLEAR_CHARGE
    x:       (n, 3) atomic positions, dtype DTYPE
    y:       0-d total energy taken from atoms.info["energy"], dtype DTYPE
    cell:    3x3 cell, dtype DTYPE
    n:       number of atoms after any expansion
  """
  if atoms.get_global_number_of_atoms() == 1:  # one-atom structure: expand it
    atoms = extend_atoms(atoms, num_atoms)
  n = atoms.get_global_number_of_atoms()

  positions = torch.tensor(atoms.get_positions(), dtype=DTYPE).to(device)
  cell = torch.tensor(atoms.cell, dtype=DTYPE).to(device)
  # Every atom receives the same (default) nuclear charge.
  charges = torch.full((len(positions),), NUCLEAR_CHARGE, dtype=torch.long).to(device)
  energy = torch.tensor(atoms.info["energy"], dtype=DTYPE).to(device)

  return Data(charges=charges, x=positions, y=energy, cell=cell, n=n)

# Load Dataset

We use Dragoni's bcc-Fe dataset, which can be downloaded [here](https://archive.materialscloud.org/record/2017.0006/v2) in XYZ format (`DB_bccFe_Dragoni.tar.gz`).

Extract the archive and place the DB\*.xyz files in the *path2data* folder (`BASE_PATH + 'input/'`).

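## Option A: load only DB1 (6001 structures, each expanded to 54 atoms)

Run **either** this cell **or** the next one (option B), not both: whichever runs last sets `DBs_to_load`.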

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
# Option A: load only DB1.
# NOTE: the following cell assigns DBs_to_load as well, so with "Run all" the
# later assignment wins — run only one of the two cells.
DBs_to_load = [1]

## Option B: load all 8 DBs (DB1 structures expanded to 54 atoms each; the others are used as-is)

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
# Option B: load all eight databases.
# NOTE: this overrides the DBs_to_load assignment from the previous cell if both
# cells are run — run only one of the two.
DBs_to_load = [1, 2, 3, 4, 5, 6, 7, 8]

## Pre-Process Data

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
from ase.io import read

# File names of the selected databases and their full paths.
dbnames = [f"DB{n}.xyz" for n in DBs_to_load]
fns = [path2data + name for name in dbnames]

# Read every frame (index=":") of each file and convert it to a PyG Data object.
data_list = []
for fn in fns:
  db = read(fn, index=":")
  data_list += [data_object(atoms) for atoms in db]

print(len(data_list))

# Generate mean / std and create train / test set

In [None]:
#@title  { vertical-output: true, display-mode: "both" }

# Per-atom energy of every structure (total energy / number of atoms).
# NOTE(review): these statistics are computed over the whole dataset, i.e.
# before the train/test split of the next cell, so test structures also
# contribute to the normalisation constants.
Y_atom = torch.tensor([(r.y / r.n).item() for r in data_list], dtype=DTYPE)

y_atom_mean, y_atom_std = Y_atom.mean().item(), Y_atom.std().item()

print(y_atom_mean, y_atom_std)

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
from torch_geometric.data import DataLoader

# Shuffle a *copy* of data_list with a dedicated RNG seeded from `seed`. The
# in-place random.shuffle(data_list) made re-running this cell reshuffle the
# already-shuffled list and silently produce a different train/test split.
shuffled_data = list(data_list)
random.Random(seed).shuffle(shuffled_data)

# First TRAINING_RATIO of the shuffled structures for training, rest for testing.
train_amount = int(len(shuffled_data) * TRAINING_RATIO)
train_data = shuffled_data[:train_amount]
test_data = shuffled_data[train_amount:]

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# DimeNet - Edited version

**Changes made:**

*   Added support for periodic boundary conditions (PBC)
*   Added support for energy de-standardization
*   Fixed a mistake in the PyTorch Geometric implementation of the Embedding Block
*   Fixed a mistake in the original implementation of the angles (swapped indices)



In [None]:
#@title  { vertical-output: true, form-width: "25%" }
from torch_geometric.nn import DimeNet
from torch_geometric.nn.acts import swish
from math import sqrt, pi as PI

import numpy as np
import torch
from torch.nn import Linear, Embedding
from torch_scatter import scatter
from torch_sparse import SparseTensor
from torch_geometric.nn import radius_graph
from torch_geometric.data import download_url
from torch_geometric.data.makedirs import makedirs

from torch_geometric.nn.models.dimenet import Envelope
from torch_geometric.nn.models.dimenet import BesselBasisLayer
from torch_geometric.nn.models.dimenet import SphericalBasisLayer
from torch_geometric.nn.models.dimenet import ResidualLayer
from torch_geometric.nn.models.dimenet import InteractionBlock
from torch_geometric.nn.models.dimenet import OutputBlock

from torch_geometric.nn.models.dimenet_utils import bessel_basis, real_sph_harm

try:
    import sympy as sym
except ImportError:
    sym = None

import os
try:
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    import tensorflow as tf
except ImportError:
    tf = None


# Mapping from QM9 target index to property name (index 4 is absent).
# NOTE(review): not referenced anywhere else in this notebook — presumably a
# leftover copied from PyTorch Geometric's DimeNet module; could be removed.
qm9_target_dict = {
    0: 'mu',
    1: 'alpha',
    2: 'homo',
    3: 'lumo',
    5: 'r2',
    6: 'zpve',
    7: 'U0',
    8: 'U',
    9: 'H',
    10: 'G',
    11: 'Cv',
}

# Implement PBC
from ase.neighborlist import neighbor_list 
from ase import Atoms


class EmbeddingBlock(torch.nn.Module):
    """DimeNet embedding block: builds the initial message of every edge (i, j).

    The embeddings of the two endpoint atoms are concatenated with a linear
    projection of the edge's radial basis values. The result is passed through
    one more linear layer and the activation ``act``.

    Args:
        num_radial: size of the radial basis expansion.
        hidden_channels: width of the atom embeddings and of the edge messages.
        act: activation function (swish by default).
    """

    def __init__(self, num_radial, hidden_channels, act=swish):
        super(EmbeddingBlock, self).__init__()
        self.act = act

        # One learnable vector per atomic number; the table has 95 rows (0..94).
        self.emb = Embedding(95, hidden_channels)
        self.lin_rbf = Linear(num_radial, hidden_channels)
        # Input is [emb(z_i), emb(z_j), projected rbf] -> 3 * hidden_channels.
        self.lin = Linear(3 * hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Embedding ~ U(-sqrt(3), sqrt(3)); linear layers use their own reset."""
        limit = sqrt(3)
        self.emb.weight.data.uniform_(-limit, limit)
        for layer in (self.lin_rbf, self.lin):
            layer.reset_parameters()

    def forward(self, x, rbf, i, j):
        """Return one ``hidden_channels`` vector per edge.

        Args:
            x: per-node atomic numbers (used as embedding indices).
            rbf: per-edge radial basis values.
            i, j: endpoint node indices of each edge.
        """
        atom_feats = self.emb(x)
        # Deliberately linear: per the notebook's notes, the activation that the
        # upstream version applied to this projection is omitted here.
        rbf_feats = self.lin_rbf(rbf)
        edge_input = torch.cat([atom_feats[i], atom_feats[j], rbf_feats], dim=-1)
        return self.act(self.lin(edge_input))

class DimeNet2(DimeNet):
    """DimeNet with periodic boundary conditions and energy de-standardisation.

    Differences from ``torch_geometric.nn.DimeNet``:

    * when ``cell`` is passed to ``forward``, edges, distances and lattice
      shifts come from ``ase.neighborlist.neighbor_list`` (see ``pbc_edges``)
      instead of ``radius_graph``, and the neighbour positions used for the
      angles are the matching periodic images;
    * per-atom outputs are mapped back with ``P * std + mean`` before the
      per-graph sum when both ``mean`` and ``std`` are given;
    * the notebook's own ``EmbeddingBlock`` is used;
    * angles are computed from ``pos_j - pos_i`` and ``pos_k - pos_j``.

    Args (beyond DimeNet's):
        mean, std: per-atom energy statistics used to de-standardise the
            per-atom outputs. If either is None, the outputs are used as-is.
    """

    def __init__(self, hidden_channels, out_channels, num_blocks, num_bilinear,
                 num_spherical, num_radial, cutoff=5.0, envelope_exponent=5,
                 num_before_skip=1, num_after_skip=2, num_output_layers=3,
                 act=swish,
                 mean=None, std=None):
        # Deliberately skips DimeNet.__init__ (initialises torch.nn.Module
        # directly); all sub-modules are built below.
        super(DimeNet, self).__init__()

        self.cutoff = cutoff

        # Per-atom energy statistics used for de-standardisation in forward().
        self.mean = mean
        self.std = std

        # Kept for backward compatibility; not referenced elsewhere in this class.
        self.padding = torch.nn.ConstantPad2d((0, 6, 0, 0), 0)

        if sym is None:
            raise ImportError('Package `sympy` could not be found.')

        self.num_blocks = num_blocks

        self.rbf = BesselBasisLayer(num_radial, cutoff, envelope_exponent)
        self.sbf = SphericalBasisLayer(num_spherical, num_radial, cutoff,
                                       envelope_exponent)

        self.emb = EmbeddingBlock(num_radial, hidden_channels, act)

        self.output_blocks = torch.nn.ModuleList([
            OutputBlock(num_radial, hidden_channels, out_channels,
                        num_output_layers, act) for _ in range(num_blocks + 1)
        ])

        self.interaction_blocks = torch.nn.ModuleList([
            InteractionBlock(hidden_channels, num_bilinear, num_spherical,
                             num_radial, num_before_skip, num_after_skip, act)
            for _ in range(num_blocks)
        ])

        self.reset_parameters()

    def pbc_edges(self, z, pos, cell, batch):
        """Build the periodic neighbour list of one structure or a PyG batch.

        Each graph is rebuilt as an ``ase.Atoms`` object with ``pbc=True`` and
        passed to ``neighbor_list("ijdS", ..., self.cutoff)``.

        Args:
            z: per-node atomic numbers.
            pos: (N, 3) node positions.
            cell: per-graph 3x3 cells stacked along dim 0, i.e. rows
                ``3*g:3*(g+1)`` belong to graph ``g`` (a single (3, 3) cell
                when ``batch`` is None).
            batch: per-node graph index, assumed sorted and contiguous as PyG
                batching produces it, or None for a single structure.

        Returns:
            ``[row, col, dist, shift_cells]`` on ``z.device``. ``row``/``col``
            are the ase ``i``/``j`` indices offset into the batch, and ``dist``
            holds the periodic distances. ``shift_cells`` is (E, 9), with
            column groups ``S[0]*a``, ``S[1]*b``, ``S[2]*c``; their sum is the
            Cartesian translation ``S . cell`` that ase applies to ``col``.
            Returns None when ``cell`` is None.
        """
        if cell is None:
            return

        z_np = z.detach().cpu().numpy()
        pos_np = pos.detach().cpu().numpy()
        cell_np = cell.detach().cpu().numpy()

        # Number of atoms of each graph, in graph order.
        if batch is None:
            counts = [len(z_np)]
        else:
            counts = np.bincount(batch.detach().cpu().numpy()).tolist()

        rows, cols, dists, shifts = [], [], [], []
        offset = 0
        for g, count in enumerate(counts):
            g_cell = cell_np[3 * g:3 * (g + 1)]
            atoms = Atoms(numbers=z_np[offset:offset + count],
                          positions=pos_np[offset:offset + count],
                          cell=g_cell, pbc=True)
            src, dst, d, S = neighbor_list("ijdS", atoms, self.cutoff,
                                           self_interaction=False)
            rows.append(src + offset)
            cols.append(dst + offset)
            dists.append(d)
            # (E, 3, 3) with [e, r] = S[e, r] * cell[r], flattened to (E, 9).
            shifts.append((S[:, :, None] * g_cell[None, :, :]).reshape(-1, 9))
            offset += count

        device = z.device
        return [torch.as_tensor(np.concatenate(rows), dtype=torch.long, device=device),
                torch.as_tensor(np.concatenate(cols), dtype=torch.long, device=device),
                torch.as_tensor(np.concatenate(dists), dtype=pos.dtype, device=device),
                torch.as_tensor(np.concatenate(shifts), dtype=pos.dtype, device=device)]

    def triplets(self, edge_index, num_nodes, shift_cells=None):
        """Enumerate k->j->i triplets for DimeNet's directional message passing.

        Edge ``e`` = ``(row[e], col[e])`` is read as ``j -> i`` with
        ``j = row[e]`` and ``i = col[e]``. Every edge ji is paired with all
        edges kj that end in ``j``.

        Args:
            edge_index: ``[row, col]`` long tensors.
            num_nodes: number of nodes.
            shift_cells: optional (E, 9) per-edge lattice shifts from
                ``pbc_edges``. When given, a triplet with ``i == k`` is dropped
                only if both refer to the same periodic image (the ji and kj
                shifts cancel), so distinct images of one atom are kept.

        Returns:
            ``(i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji, shift_cells)``:
            ``i = col``, ``j = row``, the per-triplet node and edge indices,
            and the ji-edge shift of every kept triplet (None if no shifts).
        """
        row, col = edge_index  # j->i

        value = torch.arange(row.size(0), device=row.device)
        adj_t = SparseTensor(row=col, col=row, value=value,
                             sparse_sizes=(num_nodes, num_nodes))
        adj_t_row = adj_t[row]
        num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

        # Node indices (k->j->i) for triplets.
        idx_i = col.repeat_interleave(num_triplets)
        idx_j = row.repeat_interleave(num_triplets)
        idx_k = adj_t_row.storage.col()

        # Edge indices (k->j, j->i) for triplets.
        idx_kj = adj_t_row.storage.value()
        idx_ji = adj_t_row.storage.row()

        mask = idx_i != idx_k  # Remove i == k triplets ...
        if shift_cells is not None:
            shift_vec = shift_cells[:, 0:3] + shift_cells[:, 3:6] + shift_cells[:, 6:9]
            # ... unless i and k are different periodic images, i.e. the
            # ji and kj shifts do not cancel.
            images_differ = (shift_vec[idx_ji] + shift_vec[idx_kj]).abs().sum(dim=-1) > 1e-8
            mask = mask | images_differ
            # Per-triplet copy of the ji-edge shift (== repeat_interleave).
            shift_cells = shift_cells[idx_ji][mask]

        idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]
        idx_kj, idx_ji = idx_kj[mask], idx_ji[mask]

        return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji, shift_cells

    def forward(self, z, pos, cell=None, batch=None):
        """Predict one energy per graph (a single energy when ``batch`` is None).

        Args:
            z: per-node atomic numbers (long).
            pos: (N, 3) positions.
            cell: stacked per-graph 3x3 cells; enables periodic neighbours.
            batch: per-node graph index, or None.
        """
        if cell is not None:  # periodic neighbour list from ASE
            row, col, dist, shift_cells = self.pbc_edges(z, pos, cell, batch)
            edge_index = [row, col]
        else:  # old method without PBC
            shift_cells = None
            edge_index = radius_graph(pos, r=self.cutoff, batch=batch)

        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji, _ = self.triplets(
            edge_index, num_nodes=z.size(0), shift_cells=shift_cells)

        # Calculate distances (already provided by ASE in the periodic case).
        if cell is None:
            dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

        pos_i = pos[idx_i]
        pos_j = pos[idx_j]  # central atom
        pos_k = pos[idx_k]

        if cell is not None:  # use the periodic images seen from the central atom j
            # Cartesian translation S . cell of every edge, shape (E, 3).
            shift_vec = shift_cells[:, 0:3] + shift_cells[:, 3:6] + shift_cells[:, 6:9]
            # Per the ase neighbor_list docs, edge (row, col) has distance vector
            # pos[col] - pos[row] + shift.
            # Edge ji = (row=j, col=i): the image of i seen from j is pos[i] + shift_ji.
            pos_i = pos_i + shift_vec[idx_ji]
            # Edge kj = (row=k, col=j): pos[j] - pos[k] + shift_kj points k -> j,
            # so the image of k seen from j is pos[k] - shift_kj (the kj edge's
            # own shift, not the ji edge's).
            pos_k = pos_k - shift_vec[idx_kj]

        # Calculate angles - with some Fixes to indexes compared to the orig. version
        pos_ji, pos_kj = pos_j - pos_i, pos_k - pos_j

        a = (pos_ji * pos_kj).sum(dim=-1)
        b = torch.cross(pos_ji, pos_kj, dim=-1).norm(dim=-1)
        angle = torch.atan2(b, a)

        rbf = self.rbf(dist)
        sbf = self.sbf(dist, angle, idx_kj)

        # Embedding block.
        x = self.emb(z, rbf, i, j)
        P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0))

        # Interaction blocks.
        for interaction_block, output_block in zip(self.interaction_blocks,
                                                   self.output_blocks[1:]):
            x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
            P += output_block(x, rbf, i, num_nodes=pos.size(0))

        # Energy de-standardization of the per-atom outputs.
        if self.std is not None and self.mean is not None:
            P = P * self.std + self.mean

        return P.sum(dim=0) if batch is None else scatter(P, batch, dim=0)

# Training / Test Functions

In [None]:
#@title  { vertical-output: true, display-mode: "both" }

def experiment_summary():
  """Print the data, model and optimisation settings of the current run.

  Reads the notebook globals dbnames, DTYPE, y_atom_mean, y_atom_std,
  TRAINING_RATIO, BATCH_SIZE, model, optimizer and scheduler.
  """
  print(f"DBs used: {dbnames}")
  print(f"dtype: {DTYPE}")
  print(f"mean: {y_atom_mean} , std: {y_atom_std}")
  print(f"training ratio: {TRAINING_RATIO}")
  print(f"batch size: {BATCH_SIZE}")
  print(f"model: {model}")
  print(f"optimizer: {optimizer}")
  print(f"scheduler: {scheduler}")

In [None]:
#@title  { vertical-output: true, display-mode: "both" }

def train(model, loader, optimizer):
  """Run one training epoch over `loader` and return the sum of the batch losses.

  Each batch loss is CRITERION between the predicted and the reference total
  energies. The global WITH_PBC decides whether the per-graph cells are passed
  to the model.
  """
  model.train()
  total_loss = 0

  for batch_data in loader:
    cell = batch_data.cell if WITH_PBC else None
    prediction = model(batch_data.charges, batch_data.x, cell, batch_data.batch).squeeze(1)
    loss = CRITERION(prediction, batch_data.y)
    total_loss += loss.item()  # .item() yields a detached Python float

    optimizer.zero_grad()  # Clear gradients.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.

  return total_loss

In [None]:
#@title  { vertical-output: true, display-mode: "both" }

def test(model, loader):
  """Return the mean absolute error of `model` over `loader` as a Python float.

  Errors are taken on the per-graph predicted vs reference energies, with
  gradients disabled. The global WITH_PBC decides whether cells are passed to
  the model. Returns nan when `loader` yields no batches.
  """
  model.eval()

  abs_errors = []
  with torch.no_grad():
    for data in loader:  # Iterate in batches over the training/test dataset.
      if WITH_PBC:
        out = model(data.charges, data.x, data.cell, data.batch)
      else:
        out = model(data.charges, data.x, batch=data.batch)

      out = out.squeeze(1)
      abs_errors.append((out.view(-1) - data.y).abs())

  if not abs_errors:  # torch.cat([]) would raise
    return float("nan")
  # .item() turns the 0-d (possibly CUDA) tensor into a float that
  # numpy.savetxt and matplotlib can consume directly.
  return torch.cat(abs_errors, dim=0).mean().item()

In [None]:
#@title  { vertical-output: true, display-mode: "both" }

import matplotlib.pyplot as plt

def save_mae_plot(tr_mae, te_mae, ep):
  """Save a train/test MAE-vs-epoch plot as res_graphs_dir/mae_{ep}_.png.

  Args:
    tr_mae, te_mae: per-epoch MAE values (floats or 0-d tensors).
    ep: label used in the file name (e.g. the current epoch).

  Each call draws on a fresh figure that is closed after saving, so curves from
  earlier calls do not pile up.
  """
  # Do not reuse `ep` here: it is the file-name label, not the x axis.
  epochs = range(1, len(tr_mae) + 1)
  fig, ax = plt.subplots()
  ax.plot(epochs, [float(v) for v in tr_mae], label="train mae")
  ax.plot(epochs, [float(v) for v in te_mae], label="test mae")
  ax.set_ylabel('MAE')
  ax.set_xlabel('EPOCHS')
  ax.legend()
  fig.savefig(res_graphs_dir + 'mae_{}_.png'.format(ep))
  plt.close(fig)

In [None]:
#@title  { vertical-output: true, display-mode: "both" }

from numpy import savetxt

def train_and_test(model, train_loader, test_loader, optimizer, scheduler, epochs=100):
  """Train `model` for `epochs` epochs, measuring train and test MAE after each.

  Every epoch:
    * the MAE histories are rewritten to res_maes_dir/train.txt and test.txt;
    * `scheduler` (if not None) is stepped with metrics=test MAE.
  Every 5th epoch (0, 5, 10, ...) the MAE plot and a model checkpoint are saved.

  Returns:
    (train_maes, test_maes): lists of Python floats, one entry per epoch.
  """
  experiment_summary()
  train_maes = []
  test_maes = []
  for epoch in range(0, epochs):
    loss = train(model, train_loader, optimizer)
    # float() accepts plain floats as well as 0-d (possibly CUDA) tensors;
    # numpy.savetxt below cannot convert CUDA tensors to arrays.
    train_acc = float(test(model, train_loader))
    test_acc = float(test(model, test_loader))

    if scheduler is not None:
      scheduler.step(metrics=test_acc)

    train_maes.append(train_acc)
    test_maes.append(test_acc)

    # Rewritten each epoch so the files always hold the full history so far.
    savetxt(res_maes_dir + "train.txt", train_maes, delimiter=";")
    savetxt(res_maes_dir + "test.txt", test_maes, delimiter=";")
    if epoch % 5 == 0: # save mae graph and a model checkpoint
      save_mae_plot(train_maes, test_maes, epoch)
      torch.save(model.state_dict(), res_models_dir + "dimenet_{}.model".format(epoch))

    print(f'Epoch: {epoch:03d}, Train MAE: {train_acc:.4f}, Test MAE: {test_acc:.4f}, Train Loss: {loss:.4f}')

  return train_maes, test_maes

# Network Training

In [None]:
#@title  { vertical-output: true }
from datetime import datetime

# hyperparameters
EPOCHS = 100
WITH_PBC = True  # pass per-structure cells to the model (periodic neighbour lists)
hidden_channels = 128
out_channels = 1
num_blocks = 7
num_bilinear = 8
num_spherical = 7
num_radial = 6
cutoff = 3.5

# Single source of truth for the constructor arguments. It is used both to build
# the model and to write model_str, so the two cannot drift apart.
model_kwargs = dict(hidden_channels=hidden_channels, out_channels=out_channels,
                    num_blocks=num_blocks, num_bilinear=num_bilinear,
                    num_spherical=num_spherical, num_radial=num_radial,
                    cutoff=cutoff, std=y_atom_std, mean=y_atom_mean)

model = DimeNet2(**model_kwargs).to(device)

# save the model data as string to be easily reused when loading the model
model_str = "model = DimeNet2({}).to(device)".format(
    ", ".join("{}={}".format(key, value) for key, value in model_kwargs.items()))
with open(resdir + "model_str.txt", "w") as model_file:
  model_file.write(model_str)

model = model.double() # necessary when using float64 as DTYPE

# Optimizer and scheduler
optimizer = OPTIMIZER(model.parameters(), lr=0.0001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
    factor=0.1, patience=10, threshold=0.01, threshold_mode='abs')

In [None]:
# Train for EPOCHS epochs, then persist the final weights and the final MAE plot.
train_maes, test_maes = train_and_test(model, train_loader, test_loader, optimizer, scheduler, epochs = EPOCHS)

torch.save(model.state_dict(), res_models_dir + "final.model") # save final model
# 9999 is passed as the `ep` argument to mark the final plot.
save_mae_plot(train_maes, test_maes, 9999) # save final mae plot

## Cell to free the network from memory if you need to instantiate it again

In [None]:
import gc

# Release the model before instantiating a new one. The optimizer and scheduler
# also keep references to the model's parameters (and Adam keeps per-parameter
# state tensors), so they are dropped too. Without that, deleting `model` alone
# does not free that memory.
# The guard lets this cell run on a kernel where the model is already gone
# (e.g. when re-run), instead of raising NameError.
if globals().get("model") is not None:
  model.cpu()
  for _name in ("model", "optimizer", "scheduler"):
    globals().pop(_name, None)
  gc.collect()
  torch.cuda.empty_cache()
  print("Deleted dimenet")