*   Make sure you run this using a GPU (On Google Colab: Runtime -> Change Runtime Type -> GPU)

# Setup

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
!pip install -q ase
!pip install -q torch==1.8.0
!pip install  torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
!pip install -q torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
!pip install -q torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git@1.7.0

[K     |████████████████████████████████| 2.2MB 8.8MB/s 
[K     |████████████████████████████████| 735.5MB 24kB/s 


In [None]:
import numpy as np
import torch
import ase
import random
from torch_geometric.data import Data, DataLoader
from ase.io import read

In [None]:
# Single seed shared by every random number generator used in this notebook.
SEED = 55555

# Seed NumPy, Python's `random` (used later to shuffle the dataset) and PyTorch.
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Also seed the CUDA generators (current device, then all devices).
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

## Set training hyperparameters

In [None]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use GPU if available
DTYPE = torch.float64  # data type to use for data and model

BASE_PATH = '/content/drive/MyDrive/gnn_atomistics'  # path to git repo/colab/gnn_atomistics

TRAINING_RATIO = 0.8  # fraction (0-1) of the dataset to use for training; the rest is the test set
# NOTE(review): OPTIMIZER is not referenced by the visible cells; the optimizer
# is built directly with torch.optim.Adam when the model is instantiated.
OPTIMIZER = torch.optim.Adam  
BATCH_SIZE = 4  # number of structures per DataLoader batch
CUTOFF = 3.5  # cutoff passed to the neighbour-list search that generates ghost atoms

In [None]:
# Mount Google Drive so that the dataset and the checkpoints under BASE_PATH
# are reachable from the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
DATA_PATH = BASE_PATH + "/data/training"  # folder holding the DB*.xyz training files
MODELS_PATH = BASE_PATH + "/models"  # folder where checkpoints are written and read

# Data

## Define functions

In [None]:
# Cache of supercell transformation matrices, indexed as M[source][target].
# NOTE(review): the cache key does not include the cell itself, so every call
# with the same (source, target) reuses the matrix computed for the first cell
# seen - confirm all single-atom inputs share the same lattice type.
M = {}
def extend_atoms(atoms, source, target):
  """Replace ``atoms`` by a near-cubic supercell and rescale its energy.

  Args:
    atoms: structure to repeat; must carry ``atoms.info["energy"]``.
    source: first-level cache key (the caller passes the atom count of
      ``atoms``).
    target: second-level cache key, also forwarded to
      ``ase.build.find_optimal_cell_shape`` as the target size.

  Returns:
    The supercell, with ``info["energy"]`` set to the energy of ``atoms``
    multiplied by the number of times ``atoms`` is repeated in it.
  """
  # `import ase` alone does not guarantee that the `ase.build` submodule has
  # been loaded, so import it explicitly before using it.
  import ase.build

  per_source = M.setdefault(source, {})
  if target not in per_source:
    per_source[target] = ase.build.find_optimal_cell_shape(atoms.get_cell(), target, "sc")
  supercell = ase.build.make_supercell(atoms, per_source[target])
  # Scale by the real repetition count of the supercell that was built, so the
  # energy stays consistent with the structure even when `source` is not 1.
  repeats = len(supercell) // len(atoms)
  supercell.info["energy"] = atoms.info["energy"] * repeats
  return supercell

In [None]:
def data_object(atoms: ase.Atoms):
  """Convert an ASE structure into a PyG ``Data`` object stored on DEVICE.

  A single-atom structure is first expanded with ``extend_atoms(atoms, 1, 54)``.
  The result holds atomic numbers (``z``), positions (``x``), the cell
  (``cell``), the energy from ``atoms.info["energy"]`` (``y``), the per-atom
  ``"force"`` array (``f``) and the atom count (``n``).
  """
  if atoms.get_global_number_of_atoms() == 1:
    atoms = extend_atoms(atoms, 1, 54)
  n = atoms.get_global_number_of_atoms()

  def to_device(values, dtype):
    # Every field goes to the same device; only the dtype differs.
    return torch.tensor(values, dtype=dtype, device=DEVICE)

  cell = to_device(atoms.cell, DTYPE)
  x = to_device(atoms.get_positions(), DTYPE)
  z = to_device(atoms.get_array("numbers", copy=True), torch.long)
  y = to_device(atoms.info["energy"], DTYPE)
  f = to_device(atoms.get_array("force", copy=True), DTYPE)

  return Data(z=z, x=x, cell=cell, y=y, f=f, n=n)

In [None]:
from ase.neighborlist import neighbor_list 
from ase import Atoms
def data_object_with_ghosts(atoms: ase.Atoms, cutoff: float):
  """Convert an ASE structure into a ``Data`` object padded with ghost atoms.

  Every neighbour found across a cell boundary (non-zero cell shift in the
  ASE neighbour list) is materialised once as an explicit "ghost" atom placed
  at ``x[j] + shift @ cell``. Local atoms come first, ghosts are appended.

  Args:
    atoms: structure with ``info["energy"]`` and a per-atom ``"force"`` array.
      A single-atom structure is first expanded via ``extend_atoms``.
    cutoff: cutoff handed to ``ase.neighborlist.neighbor_list``.

  Returns:
    ``Data`` on the CPU with ``z`` and ``x`` for local + ghost atoms, ``cell``,
    total energy ``y``, forces ``f`` of the local atoms only, ``idx_local``
    (indices ``0..nlocal-1``), ``nlocal`` and ``nghost``.
  """
  nlocal = atoms.get_global_number_of_atoms()
  if nlocal == 1:
    atoms = extend_atoms(atoms, 1, 54)
    nlocal = atoms.get_global_number_of_atoms()

  _, j, S = neighbor_list("ijS", atoms, cutoff, self_interaction=False)
  # Keep only the pairs whose neighbour lies in a periodic image of the cell.
  crosses_boundary = np.any(S != 0, axis=1)
  x = atoms.get_positions()
  cell = atoms.get_cell()

  # De-duplicate (atom, shift) pairs in first-seen order; the set lookup
  # replaces a linear scan over all ghosts collected so far.
  seen = set()
  k = []
  shifts = []
  for parent, shift in zip(j[crosses_boundary].tolist(), S[crosses_boundary].tolist()):
    key = (parent, tuple(shift))
    if key in seen:
      continue
    seen.add(key)
    k.append(parent)
    shifts.append(shift)

  # Explicit dtype/shape keep the indexing and the matmul valid even when the
  # structure has no ghost atoms at all.
  k = np.array(k, dtype=int)
  shifts = np.array(shifts, dtype=int).reshape(-1, 3)

  ghost_x = x[k] + np.matmul(shifts, cell)
  new_x = np.concatenate((x, ghost_x))
  nghost = len(ghost_x)

  # A ghost is an image of local atom k, so it carries that atom's atomic
  # number (previously every atom was hard-coded to 26).
  numbers = atoms.get_atomic_numbers()
  new_z = np.concatenate((numbers, numbers[k]))

  idx_local = torch.arange(nlocal, dtype=torch.long, device="cpu")
  cell = torch.tensor(cell, dtype=DTYPE, device="cpu")
  x = torch.tensor(new_x, dtype=DTYPE, device="cpu")
  z = torch.tensor(new_z, dtype=torch.long, device="cpu")
  y = torch.tensor(atoms.info["energy"], dtype=DTYPE, device="cpu")
  f = torch.tensor(atoms.get_array("force", copy=True), dtype=DTYPE, device="cpu")

  return Data(z=z, x=x, cell=cell, y=y, f=f, idx_local=idx_local, nlocal=nlocal, nghost=nghost)

## Load data

In [None]:
# Read every frame of DB1.xyz ... DB8.xyz and convert each one into a Data
# object padded with ghost atoms.
file_names = [f"DB{idx}.xyz" for idx in range(1, 9)]
file_paths = [f"{DATA_PATH}/{file_name}" for file_name in file_names]
data = []
for f in file_paths:
  db = read(f, index=":")
  data.extend(data_object_with_ghosts(frame, CUTOFF) for frame in db)
print(len(data))

In [None]:
# Energy per local (non-ghost) atom of every structure; its mean and standard
# deviation are later handed to the model for energy de-standardization.
y_atom = torch.tensor([d.y / d.nlocal for d in data], dtype=DTYPE)
y_atom_mean = y_atom.mean().item()
y_atom_std = y_atom.std().item()
print(y_atom_mean, y_atom_std)

In [None]:
# Shuffle once (in place), then cut the list into a training part holding
# TRAINING_RATIO of the structures and a test part holding the remainder.
random.shuffle(data)
train_amount = int(len(data) * TRAINING_RATIO)
train_data, test_data = data[:train_amount], data[train_amount:]
# Only the training loader reshuffles between epochs.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# Model


In [None]:
#@title  { vertical-output: true, form-width: "25%" }
from torch_geometric.nn import DimeNet
from torch_geometric.nn.acts import swish
from math import sqrt, pi as PI

import numpy as np
import torch
from torch.nn import Linear, Embedding
from torch_scatter import scatter
from torch_sparse import SparseTensor
from torch_geometric.nn import radius_graph
from torch_geometric.data import download_url
from torch_geometric.data.makedirs import makedirs

from torch_geometric.nn.models.dimenet import Envelope
from torch_geometric.nn.models.dimenet import BesselBasisLayer
from torch_geometric.nn.models.dimenet import SphericalBasisLayer
from torch_geometric.nn.models.dimenet import ResidualLayer
from torch_geometric.nn.models.dimenet import InteractionBlock
from torch_geometric.nn.models.dimenet import OutputBlock

from torch_geometric.nn.models.dimenet_utils import bessel_basis, real_sph_harm
import ase
from ase.neighborlist import neighbor_list 
from ase import Atoms

from torch_geometric.data import DataLoader

try:
    import sympy as sym
except ImportError:
    sym = None

import os
try:
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    import tensorflow as tf
except ImportError:
    tf = None

# TODO: move this somewhere else
# NOTE(review): re-declares DEVICE and DTYPE with the same values as the
# hyperparameter cell and adds the lowercase alias `device`, which is the name
# used when the model is instantiated.
device = DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DTYPE = torch.float64


def _pbc_edges_single(cutoff, z_part, x_np_part, x, cell_part, offset, compute_sc):
  """Periodic neighbour list of one structure, with indices shifted by ``offset``.

  Returns ``(nh1, nh2, s, d, sc)``: edge end points, integer cell shifts (as
  DTYPE), displacement vectors ``x[nh2] - x[nh1] + s @ cell`` and, when
  ``compute_sc`` is set, the per-edge shift-scaled cell rows (else ``None``).
  """
  # Only the geometry matters for the neighbour search; ASE runs on the CPU.
  atoms = Atoms(charges=z_part.cpu(), positions=x_np_part, cell=cell_part.cpu(), pbc=True)
  nh1, nh2, s = neighbor_list("ijS", atoms, cutoff, self_interaction=False)
  nh1 = torch.tensor(nh1, dtype=torch.long, device=DEVICE) + offset
  nh2 = torch.tensor(nh2, dtype=torch.long, device=DEVICE) + offset
  s = torch.tensor(s, dtype=DTYPE, device=DEVICE)
  d = x[nh2] - x[nh1] + torch.matmul(s, cell_part)

  sc = None
  if compute_sc:
    # Row e is [s[e,0]*cell[0], s[e,1]*cell[1], s[e,2]*cell[2]] flattened to 9
    # numbers, i.e. the three lattice-vector contributions kept separate.
    sc = (s.unsqueeze(-1) * cell_part.unsqueeze(0)).reshape(len(d), 9)
  return nh1, nh2, s, d, sc


def pbc_edges(cutoff, z, x, cell, batch, compute_sc=False):
  """Build the edge list of one structure or a batch under periodic boundaries.

  Args:
    cutoff: neighbour cutoff passed to ASE.
    z: per-atom integer tensor (only sliced and handed to ASE).
    x: per-atom positions; distances are computed from it with torch.
    cell: lattice vectors; with a batch, rows ``3*g .. 3*g+2`` belong to the
      g-th structure of the batch.
    batch: per-atom structure id or ``None`` for a single structure.
      Assumes the atoms of each structure are stored contiguously, as the
      slicing by running offset requires.
    compute_sc: also return the per-edge shift-scaled cell rows.

  Returns:
    ``(NH1, NH2, D, S, SC)``: edge end points, edge lengths, cell shifts and
    the ``(num_edges, 9)`` shift-scaled cell rows (``None`` unless
    ``compute_sc``).
  """
  # ASE needs host memory; take the CPU copy once and use it in both branches
  # (calling .numpy() directly on a CUDA tensor raises).
  x_np = x.detach().cpu().numpy()

  if batch is None:
    NH1, NH2, S, D, SC = _pbc_edges_single(cutoff, z, x_np, x, cell, 0, compute_sc)
  else:
    # Sizes of the structures in storage order, so that the running offset and
    # the cell row block always refer to the same structure.
    _, counts = torch.unique_consecutive(batch, return_counts=True)

    parts = []
    offset = 0
    for g, size in enumerate(counts.tolist()):
      stop = offset + size
      parts.append(_pbc_edges_single(
          cutoff, z[offset:stop], x_np[offset:stop], x,
          cell[3 * g:3 * (g + 1)], offset, compute_sc))
      offset = stop

    if parts:
      nh1s, nh2s, shifts, disps, scs = zip(*parts)
      NH1 = torch.cat(nh1s, 0)
      NH2 = torch.cat(nh2s, 0)
      S = torch.cat(shifts, 0)
      D = torch.cat(disps, 0)
      SC = torch.cat(scs, 0) if compute_sc else None
    else:  # empty batch: return empty, correctly typed tensors
      NH1 = torch.empty(0, dtype=torch.long, device=DEVICE)
      NH2 = torch.empty(0, dtype=torch.long, device=DEVICE)
      S = torch.empty((0, 3), dtype=DTYPE, device=DEVICE)
      D = torch.empty((0, 3), dtype=DTYPE, device=DEVICE)
      SC = torch.empty((0, 9), dtype=DTYPE, device=DEVICE) if compute_sc else None

  D = D.norm(dim=-1)
  return NH1, NH2, D, S, SC

class EmbeddingBlock(torch.nn.Module):
  """Initial edge embedding built from the two end-point atoms and the radial basis.

  Atom types are looked up in a 95-entry embedding table; the radial basis is
  projected linearly to ``hidden_channels``; both are concatenated per edge and
  passed through a linear layer followed by ``act``.
  """

  def __init__(self, num_radial, hidden_channels, act=swish):
    super().__init__()
    self.act = act

    self.emb = Embedding(95, hidden_channels)
    self.lin_rbf = Linear(num_radial, hidden_channels, bias=False)
    self.lin = Linear(3 * hidden_channels, hidden_channels)

    self.reset_parameters()

  def reset_parameters(self):
    """Re-initialise the embedding table uniformly in [-sqrt(3), sqrt(3)] and both linear layers."""
    bound = sqrt(3)
    self.emb.weight.data.uniform_(-bound, bound)
    for layer in (self.lin_rbf, self.lin):
      layer.reset_parameters()

  def forward(self, x, rbf, i, j):
    """Return one embedding per edge ``(i, j)`` from atom types ``x`` and radial features ``rbf``."""
    atom_emb = self.emb(x)
    # The radial projection is deliberately left without an activation
    # function (an earlier version of this line applied `self.act` here).
    rbf_proj = self.lin_rbf(rbf)
    edge_input = torch.cat([atom_emb[i], atom_emb[j], rbf_proj], dim=-1)
    return self.act(self.lin(edge_input))



class DimeNetx(torch.nn.Module):
  """DimeNet-style energy model with optional periodic boundary conditions.

  Edges come from ``pbc_edges`` when a cell is passed to ``forward`` and from
  ``radius_graph`` otherwise. Per-atom outputs of all output blocks are summed,
  optionally de-standardized with ``mean``/``std``, restricted to the atoms in
  ``idx_local`` and finally summed per structure.
  """

  def __init__(self, hidden_channels, out_channels, num_blocks, num_bilinear,
                 num_spherical, num_radial, cutoff=5.0, envelope_exponent=5,
                 num_before_skip=1, num_after_skip=2, num_output_layers=3,
                 act=swish, mean=None, std=None):
    super(DimeNetx, self).__init__()

    self.cutoff = cutoff

    # Mean and standard deviation used to de-standardize the per-atom output.
    self.mean = mean
    self.std = std

    # padding used for PBCs
    # NOTE(review): not referenced by any method of this class; kept so the
    # module's attributes stay unchanged.
    self.padding = torch.nn.ConstantPad2d((0,6,0,0), 0)

    if sym is None:
        raise ImportError('Package `sympy` could not be found.')

    self.num_blocks = num_blocks

    self.rbf = BesselBasisLayer(num_radial, cutoff, envelope_exponent)
    self.sbf = SphericalBasisLayer(num_spherical, num_radial, cutoff,
                                    envelope_exponent)

    self.emb = EmbeddingBlock(num_radial, hidden_channels, act)

    # One output block after the embedding plus one per interaction block.
    self.output_blocks = torch.nn.ModuleList([
        OutputBlock(num_radial, hidden_channels, out_channels,
                    num_output_layers, act) for _ in range(num_blocks + 1)
    ])

    self.interaction_blocks = torch.nn.ModuleList([
        InteractionBlock(hidden_channels, num_bilinear, num_spherical,
                          num_radial, num_before_skip, num_after_skip, act)
        for _ in range(num_blocks)
    ])

    self.reset_parameters()

  def reset_parameters(self):
    """Re-initialise the radial basis, the embedding and every output/interaction block."""
    self.rbf.reset_parameters()
    self.emb.reset_parameters()
    for out in self.output_blocks:
      out.reset_parameters()
    for interaction in self.interaction_blocks:
      interaction.reset_parameters()

  def triplets_original(self, edge_index, num_nodes):
    """Enumerate edge pairs (k->j, j->i) for a graph without periodic images.

    Returns ``(col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji)``.
    """
    row, col = edge_index  # j->i

    value = torch.arange(row.size(0), device=row.device)
    adj_t = SparseTensor(row=col, col=row, value=value,
                         sparse_sizes=(num_nodes, num_nodes))
    adj_t_row = adj_t[row]
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets.
    idx_i = col.repeat_interleave(num_triplets)
    idx_j = row.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    mask = (idx_i != idx_k)  # Remove i == k triplets.
    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

    # Edge indices (k-j, j->i) for triplets.
    idx_kj = adj_t_row.storage.value()[mask]
    idx_ji = adj_t_row.storage.row()[mask]

    return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji

  def triplets(self, edge_index, num_nodes, shift_cells=None, shift=None):
    """Enumerate edge pairs (k->j, j->i), optionally tracking periodic shifts.

    With ``shift_cells``/``shift`` given (one row per edge), a triplet is
    dropped only when ``i`` and ``k`` are the same atom *and* the same
    periodic image; the per-triplet shift data of ``i`` and ``k`` is returned
    as well. Without them, every ``i == k`` triplet is dropped and the four
    shift outputs are ``None``.

    Returns:
      ``(col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji, shift_cells_i,
      shift_i, shift_cells_k, shift_k)``.
    """
    row, col = edge_index  # j->i

    value = torch.arange(row.size(0), device=row.device)
    adj_t = SparseTensor(row=col, col=row, value=value,
                         sparse_sizes=(num_nodes, num_nodes))
    adj_t_row = adj_t[row]
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    idx_i = col.repeat_interleave(num_triplets)
    idx_j = row.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    edge_kj = adj_t_row.storage.value()

    # Triplets whose two outer atoms coincide must be removed.
    same = idx_i == idx_k
    shift_cells_i = shift_i = shift_cells_k = shift_k = None
    if shift_cells is not None: # Update also the shift vectors
      shift_cells_i = shift_cells.repeat_interleave(num_triplets, dim=0)
      shift_i = shift.repeat_interleave(num_triplets, dim=0)
      shift_cells_k = -shift_cells[edge_kj]
      shift_k = -shift[edge_kj]
      # Same atom index but a different image is a distinct neighbour: keep it.
      same = same & torch.all(shift_i == shift_k, dim=1)
    keep = ~same

    idx_i, idx_j, idx_k = idx_i[keep], idx_j[keep], idx_k[keep]
    if shift_cells is not None: # Remove also from the shift vector
      shift_cells_i = shift_cells_i[keep]
      shift_cells_k = shift_cells_k[keep]
      shift_i = shift_i[keep]
      shift_k = shift_k[keep]

    idx_kj = edge_kj[keep]
    idx_ji = adj_t_row.storage.row()[keep]

    return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji, shift_cells_i, shift_i, shift_cells_k, shift_k

  def forward(self, z, pos, idx_local=None, cell=None, batch=None):
    """Predict one output row per structure.

    Args:
      z: atomic numbers, one per atom.
      pos: atom positions.
      idx_local: indices of the atoms whose per-atom output is summed;
        ``None`` sums over all atoms.
      cell: lattice vectors; when given, edges are built with ``pbc_edges``,
        otherwise with ``radius_graph``.
      batch: per-atom structure id or ``None`` for a single structure.

    Returns:
      Sum of the selected per-atom outputs, over everything when ``batch`` is
      ``None`` and per structure otherwise.
    """
    use_pbc = cell is not None
    if use_pbc: # implement PBC
      r1, r2, dist, shift, shift_cells = pbc_edges(self.cutoff, z, pos, cell, batch, compute_sc=True)
      i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji, shift_cells_i, shift_i, shift_cells_k, shift_k = self.triplets(
          [r1, r2], num_nodes=z.size(0), shift_cells=shift_cells, shift=shift)
    else: # old method without PBC
      edge_index = radius_graph(pos, r=self.cutoff, batch=batch)
      i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets_original(
          edge_index, num_nodes=z.size(0))
      dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

    # Define atoms position
    pos_i = pos[idx_i]
    pos_j = pos[idx_j]
    pos_k = pos[idx_k]

    if use_pbc: # Fix coordinates for PBCs
      # Move i and k into the periodic image they occupy relative to j by
      # adding the three shift-scaled lattice-vector contributions.
      pos_i = pos_i + shift_cells_i[:, 0:3] + shift_cells_i[:, 3:6] + shift_cells_i[:, 6:9]
      pos_k = pos_k + shift_cells_k[:, 0:3] + shift_cells_k[:, 3:6] + shift_cells_k[:, 6:9]
      # NOTE(review): a disabled sanity check used to verify here that no
      # shifted i/k position coincides with its j position.

    # Calculate angles - with some Fixes to indexes compared to the orig. version
    pos_ji, pos_kj = pos_j - pos_i, pos_k - pos_j
    dot = (pos_ji * pos_kj).sum(dim=-1)
    # dim must be explicit: without it torch.cross uses the first dimension of
    # size 3, which is the triplet axis whenever there are exactly 3 triplets.
    cross_norm = torch.cross(pos_ji, pos_kj, dim=-1).norm(dim=-1)
    angle = torch.atan2(cross_norm, dot)

    rbf = self.rbf(dist)
    sbf = self.sbf(dist, angle, idx_kj)

    # Embedding block.
    x = self.emb(z, rbf, i, j)
    P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0))

    # Interaction blocks.
    for interaction_block, output_block in zip(self.interaction_blocks,
                                               self.output_blocks[1:]):
      x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
      P += output_block(x, rbf, i, num_nodes=pos.size(0))

    # Energy de-standardization
    if self.std is not None and self.mean is not None:
      P = P * self.std + self.mean

    # Keep only the requested atoms before summing.
    if idx_local is not None:
      P = P[idx_local]
    if batch is None:
      return P.sum(dim=0)
    graph_of_atom = batch if idx_local is None else batch[idx_local]
    return scatter(P, graph_of_atom, dim=0)

# Training

## Define functions


In [None]:
import gc

In [None]:
def energy_loss(y, p_energies):
  """Return the mean absolute error between target and predicted energies."""
  return (y - p_energies).abs().mean()

In [None]:
def energy_forces_loss(y, p_energies, p_forces, energy_coeff, forces=None):
  """Weighted sum of the energy MAE and the force MAE.

  Args:
    y: target energies, or a batch-like object exposing the target energies
      as ``.y`` and the target forces as ``.f`` (this is how the training loop
      calls it).
    p_energies: predicted energies.
    p_forces: predicted forces, same shape as the target forces.
    energy_coeff: weight of the energy term; the force term gets
      ``1 - energy_coeff``.
    forces: target forces; required when ``y`` is a plain tensor.

  Returns:
    ``(total_loss, energies_loss, forces_loss)``.

  Raises:
    ValueError: if no target forces can be found.
  """
  if forces is None:
    # The targets must come from the arguments, never from a module-level
    # variable that merely happens to be called `data`.
    if not (hasattr(y, "y") and hasattr(y, "f")):
      raise ValueError("target forces are required: pass `forces` or a batch with `.y` and `.f`")
    y, forces = y.y, y.f

  # Targets may still live on the CPU while the predictions are on the GPU.
  y = y.to(p_energies.device)
  forces = forces.to(p_forces.device)

  energies_loss = torch.mean(torch.abs(y - p_energies))
  forces_loss = torch.mean(torch.abs(forces - p_forces))
  total_loss = (energy_coeff)*energies_loss + (1-energy_coeff)*forces_loss
  return total_loss, energies_loss, forces_loss

In [None]:
def train(model, loader, optimizer, use_forces=False, energy_coeff=None):
  """Run one training epoch, on energies only or on energies plus forces.

  ``energy_coeff`` is only forwarded when ``use_forces`` is set.
  """
  if use_forces:
    train_energy_forces(model, loader, optimizer, energy_coeff)
    return
  train_energy(model, loader, optimizer)

In [None]:
def train_energy(model, loader, optimizer):
  """Run one epoch of energy-only training (mean absolute error loss)."""
  model.train()
  for data in loader:
    optimizer.zero_grad()
    z = data.z.to(DEVICE)
    x = data.x.to(DEVICE)
    batch = data.batch.to(DEVICE)
    y = data.y.to(DEVICE)

    # Each example's idx_local is relative to that example, so it is shifted
    # here by the number of atoms (local + ghost) of the examples before it.
    node_offset = 0
    local_chunks = [torch.tensor([], dtype=torch.long)]
    for graph_id in range(data.num_graphs):
      example = data.get_example(graph_id)
      local_chunks.append(example.idx_local + node_offset)
      node_offset += example.nlocal + example.nghost
    batched_idx_local = torch.cat(local_chunks).to(DEVICE)

    e = model(z, x, batched_idx_local, cell=None, batch=batch).squeeze(1)
    e_loss = energy_loss(y, e)

    e_loss.backward()
    optimizer.step()

    # Drop the batch tensors and release cached GPU memory after every step.
    del z, x, batch, y, batched_idx_local
    gc.collect()
    torch.cuda.empty_cache()
  

In [None]:
def train_energy_forces(model, loader, optimizer, energy_coeff):
  """Run one epoch of combined energy + force training and print the summed losses.

  Forces are obtained as the negative gradient of the predicted energies with
  respect to the atom positions.

  Args:
    model: model called as ``model(z, x, idx_local, cell=None, batch=batch)``.
    loader: yields batches whose examples carry ``idx_local``, ``nlocal`` and
      ``nghost``.
    optimizer: optimizer stepped once per batch.
    energy_coeff: weight of the energy term, forwarded to
      ``energy_forces_loss``.
  """
  model.train()
  total_e_loss = 0
  total_f_loss = 0
  total_ef_loss = 0

  for data in loader:
    z, x, batch = data.z.to(DEVICE), data.x.to(DEVICE), data.batch.to(DEVICE)
    x.requires_grad = True
    optimizer.zero_grad()

    # Each example's idx_local is relative to that example, so it is shifted
    # by the number of atoms (local + ghost) of the examples before it.
    node_offset = 0
    local_chunks = [torch.tensor([], dtype=torch.long)]
    for graph_id in range(data.num_graphs):
      example = data.get_example(graph_id)
      local_chunks.append(example.idx_local + node_offset)
      node_offset += example.nlocal + example.nghost
    batched_idx_local = torch.cat(local_chunks).to(DEVICE)

    # The model requires the local-atom indices; omitting them raises.
    e = model(z, x, batched_idx_local, cell=None, batch=batch)
    f = -1 * torch.autograd.grad(e, x, grad_outputs=torch.ones_like(e), create_graph=True, retain_graph=True)[0]
    # Target forces exist for local atoms only, so drop the ghost rows.
    # NOTE(review): gradients accumulated on ghost atoms are discarded rather
    # than folded back onto the atoms they mirror - confirm this is acceptable.
    f = f[batched_idx_local]
    e = e.squeeze(1)

    ef_loss, e_loss, f_loss = energy_forces_loss(data, e, f, energy_coeff)
    total_e_loss += e_loss.item()
    total_f_loss += f_loss.item()
    total_ef_loss += ef_loss.item()

    ef_loss.backward()
    optimizer.step()

    # Drop the batch tensors and release cached GPU memory after every step.
    del z, x, batch, batched_idx_local
    gc.collect()
    torch.cuda.empty_cache()

  print("Total training loss\t ef: {}, e: {}, f: {}".format(total_ef_loss, total_e_loss, total_f_loss))

In [None]:
def _energy_abs_errors(model, loader):
  """Return the absolute energy error of every structure served by ``loader`` (CPU tensor)."""
  chunks = [torch.tensor([], dtype=DTYPE)]
  for data in loader:
    z, x, batch = data.z.to(DEVICE), data.x.to(DEVICE), data.batch.to(DEVICE)

    # Each example's idx_local is relative to that example, so it is shifted
    # by the number of atoms (local + ghost) of the examples before it.
    node_offset = 0
    local_chunks = [torch.tensor([], dtype=torch.long)]
    for graph_id in range(data.num_graphs):
      example = data.get_example(graph_id)
      local_chunks.append(example.idx_local + node_offset)
      node_offset += example.nlocal + example.nghost
    batched_idx_local = torch.cat(local_chunks).to(DEVICE)

    e = model(z, x, batched_idx_local, cell=None, batch=batch)
    e = e.squeeze(1)
    chunks.append(torch.abs(e.view(-1).cpu() - data.y.cpu()))

    # Release the batch and cached GPU memory before the next one.
    torch.cuda.empty_cache()
    del z, x, batch, batched_idx_local
    gc.collect()
  return torch.cat(chunks)


def test(model, train_loader, test_loader):
  """Evaluate the energy mean absolute error on both loaders.

  The model is switched to eval mode and gradients are disabled.

  Returns:
    ``(train_mae, test_mae)`` as Python floats.
  """
  model.eval()

  with torch.no_grad():
    train_mae = torch.mean(_energy_abs_errors(model, train_loader)).item()
    test_mae = torch.mean(_energy_abs_errors(model, test_loader)).item()

  return train_mae, test_mae

In [None]:
from datetime import datetime

# Per-epoch history of the energy MAE; appended to by train_and_test.
train_maes = []
test_maes = []
def train_and_test(model, train_loader, test_loader, optimizer, scheduler, 
                   use_forces=False, energy_coeff=None, 
                   epochs=100, starting_epoch=1, save_every=5, 
                   name=None, description=None):
  """Train, evaluate and checkpoint the model epoch by epoch.

  Args:
    model, train_loader, test_loader, optimizer: forwarded to ``train`` and
      ``test``.
    scheduler: stepped with the test MAE after every epoch.
    use_forces, energy_coeff: forwarded to ``train``.
    epochs: number of the last epoch to run (inclusive).
    starting_epoch: number of the first epoch to run; use the last saved
      epoch + 1 when resuming.
    save_every: a checkpoint is written whenever the epoch number is a
      multiple of this value.
    name: checkpoint name, defaults to ``"SchNet"``.
    description: free text, defaults to the current time.
  """
  if name is None:
    name = "SchNet"
  if description is None:
    description = datetime.now().strftime('%H:%M:%S')

  # Epochs are numbered from 1, so the upper bound has to be inclusive;
  # otherwise the final epoch is never run.
  for epoch in range(starting_epoch, epochs + 1):
    print("")
    print(f"Epoch {epoch} ({datetime.now().strftime('%H:%M:%S')})")
    print(f"")
    train(model, train_loader, optimizer, use_forces=use_forces, energy_coeff=energy_coeff)
    train_mae, test_mae = test(model, train_loader, test_loader)
    train_maes.append(train_mae)
    test_maes.append(test_mae)
    scheduler.step(test_mae)
    print(f"Energy MAE: train {train_mae} eV, test {test_mae} eV")
    if epoch % save_every == 0:
      save_model(name=name, description=description, epoch=epoch, train_err=train_mae, test_err=test_mae)

In [None]:
def save_model(name, description, epoch=-1, train_err=-1, test_err=-1):
  """Write a checkpoint and keep a copy of the best one seen so far.

  The checkpoint bundles the module-level ``model``, ``optimizer`` and
  ``scheduler`` state with ``model_str``, ``y_atom_mean`` and ``y_atom_std``.
  It is saved as ``MODELS_PATH/<name>/<name>[_<epoch>]`` and additionally as
  ``<name>_best`` when ``test_err`` beats the stored best.

  Args:
    name: checkpoint folder and file stem.
    description: free text stored under the ``"description"`` key.
    epoch: appended to the file name unless it is ``-1``.
    train_err: training error stored with the checkpoint (``-1`` = unknown).
    test_err: test error stored with the checkpoint (``-1`` = unknown).
  """
  model_data = {
    "desc": "DimeNet external ghosts",
    "description": description,
    "str": model_str,
    "mean": y_atom_mean,
    "std": y_atom_std,
    "state": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "train_err": train_err,
    "test_err": test_err
  }
  epoch_str = "" if epoch == -1 else "_" + str(epoch)
  actual_name = f"{name}{epoch_str}"

  out_dir = f"{MODELS_PATH}/{name}"
  if not os.path.exists(out_dir):
    os.makedirs(out_dir, exist_ok=True)
    print(f"Created directory {out_dir}")
  torch.save(model_data, f"{out_dir}/{actual_name}")
  print(f"Saved {actual_name}")

  best_path = f"{out_dir}/{name}_best"
  best_test_err = float("inf")
  if os.path.isfile(best_path):
    # Only one number is needed, so keep the stored tensors off the GPU.
    stored = torch.load(best_path, map_location="cpu")["test_err"]
    # A negative stored error is the "unknown" sentinel, not a real record.
    if stored >= 0:
      best_test_err = stored
  # An unknown (negative) error must never be recorded as the best result.
  if 0 <= test_err < best_test_err:
    torch.save(model_data, best_path)
    print(f"Saved best")

## Instantiate model

In [None]:
# Both values are passed to train_and_test as `name` and `description`.
model_name = "dimenet_35_extghost3" # name of the folder where checkpoints will be saved
model_description = "DimeNet with external ghosts, batched idx local masking"

In [None]:
# Keep this cell identical to the text stored in `model_str`, which is saved
# with every checkpoint.
# NOTE(review): the cutoff is written as the literal 3.5 here rather than taken
# from CUTOFF - confirm the two are meant to stay equal.
model = DimeNetx(hidden_channels=128, out_channels=1, num_blocks=7, num_bilinear=8, num_spherical=7, num_radial=6, cutoff=3.5, std=y_atom_std, mean=y_atom_mean).to(device)

# Match the parameter precision to the float64 data.
if DTYPE == torch.float64:
  model = model.double()

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
optimizer.zero_grad()
# The scheduler is stepped with the test MAE after every epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.01, threshold_mode='abs')
use_forces = False  # train on energies only
energy_coeff = 0.  # only forwarded to the loss when use_forces is True
max_epochs = 100

Copy the content of the previous cell into the following variable, so that the model definition and its hyperparameters are saved together with every checkpoint.

In [None]:
# Verbatim copy of the instantiation cell; save_model stores it under the
# "str" key of every checkpoint.
model_str = """
model = DimeNetx(hidden_channels=128, out_channels=1, num_blocks=7, num_bilinear=8, num_spherical=7, num_radial=6, cutoff=3.5, std=y_atom_std, mean=y_atom_mean).to(device)

if DTYPE == torch.float64:
  model = model.double()

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
optimizer.zero_grad()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.01, threshold_mode='abs')
use_forces = False
energy_coeff = 0.
max_epochs = 100
"""

# [Restore from checkpoint]

In [None]:
# Resume from a saved checkpoint: restore the model weights together with the
# optimizer and scheduler state.
model_fn = "dimenet_35_extghost3_45"
# map_location places the stored tensors on the device of the current runtime,
# so a checkpoint written on a GPU can also be restored on a CPU-only runtime.
checkpoint = torch.load(MODELS_PATH + "/dimenet_35_extghost3/" + model_fn, map_location=DEVICE)
model.load_state_dict(checkpoint["state"])
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])

# Info

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
def experiment_summary():
  """Print the hardware, data, model and optimisation settings of the current run."""
  # Separate local name so the module-level `device` is not shadowed.
  device_desc = "CPU" if DEVICE == torch.device("cpu") else f"GPU ({torch.cuda.get_device_name(device=DEVICE)})"
  print(f"Device: {device_desc}")
  print(f"Data type: {DTYPE}")
  print(f"Files used: {file_names}")
  # Report the configured cutoff instead of a literal that can go stale.
  print(f"Cutoff: {CUTOFF}")
  print(f"Mean: {y_atom_mean}, std: {y_atom_std}")
  print(f"Training ratio: {TRAINING_RATIO}")
  print(f"Batch size: {BATCH_SIZE}")
  print(f"Model: \n    {model_name}: {model_description}\n    {str(model)}")
  loss = "energy" if not use_forces else f"energy+forces ({energy_coeff})"
  print(f"Loss: {loss}")
  print(f"Optimizer: {optimizer}")
  print(f"Scheduler: {scheduler}")
  print(f"Max epochs: {max_epochs}")
experiment_summary()

In [None]:
# Show the name of the GPU in use.
# NOTE(review): presumably only meaningful when DEVICE is a CUDA device - confirm before running on a CPU-only runtime.
torch.cuda.get_device_name(device=DEVICE)

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
# Report the CPU resources of the runtime.
# NOTE(review): the star import copies every public psutil name into the
# notebook namespace; `import psutil` with qualified calls would avoid
# accidental name clashes.
from psutil import *
print(cpu_count())
print(cpu_stats())
!cat /proc/cpuinfo
!df -h
print(virtual_memory())


In [None]:
# Continue training from epoch 46, i.e. right after the restored checkpoint
# "dimenet_35_extghost3_45", and write a checkpoint after every epoch.
train_and_test(model, train_loader, test_loader, optimizer, scheduler, 
               use_forces=use_forces, energy_coeff=energy_coeff,
               epochs=max_epochs, starting_epoch=46, save_every=1,
               name=model_name, description=model_description)

## [Clean memory]

In [None]:
import gc
# Move the model off the GPU, drop it and release cached GPU memory.
# Looking the name up in globals() keeps the cell re-runnable: after a first
# run `model` no longer exists and a bare reference would raise NameError.
if globals().get("model") is not None:
  model.cpu()
  del model
  gc.collect()
  torch.cuda.empty_cache()