*   Make sure you run the model using a GPU (On Google Colab: Runtime -> Change Runtime Type -> GPU)

# Setup

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
# Install ASE, a pinned PyTorch 1.8.0, and the PyTorch Geometric companion
# packages from the wheel index built for torch 1.8.0 + CUDA 10.2.
!pip install -q ase
!pip install -q torch==1.8.0
# NOTE(review): this is the only install line without -q, so pip prints its full log.
!pip install  torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
!pip install -q torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
!pip install -q torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
# PyTorch Geometric itself is installed from the 1.7.0 tag of its git repository.
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git@1.7.0

[K     |████████████████████████████████| 2.2MB 14.1MB/s 
[K     |████████████████████████████████| 735.5MB 22kB/s 
[31mERROR: torchvision 0.10.0+cu102 has requirement torch==1.9.0, but you'll have torch 1.8.0 which is incompatible.[0m
[31mERROR: torchtext 0.10.0 has requirement torch==1.9.0, but you'll have torch 1.8.0 which is incompatible.[0m
[?25hLooking in links: https://pytorch-geometric.com/whl/torch-1.8.0+cu102.html
Collecting torch-scatter
[?25l  Downloading https://pytorch-geometric.com/whl/torch-1.8.0%2Bcu102/torch_scatter-2.0.7-cp37-cp37m-linux_x86_64.whl (2.7MB)
[K     |████████████████████████████████| 2.7MB 12.5MB/s 
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.0.7
[K     |████████████████████████████████| 1.6MB 13.7MB/s 
[K     |████████████████████████████████| 1.1MB 7.3MB/s 
[K     |████████████████████████████████| 399kB 7.2MB/s 
[K     |████████████████████████████████| 235kB 14.5MB/s 
[K     |██████████████

In [None]:
import numpy as np
import torch
import ase
import random
from torch_geometric.data import Data, DataLoader
from ase.io import read

In [None]:
SEED = 55555

# Seed every random number generator the notebook relies on: Python's
# `random`, NumPy, PyTorch (CPU) and, when a GPU is present, all CUDA devices.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

## Set training hyperparameters

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DTYPE = torch.float64  # floating-point type used for the data tensors and the model

BASE_PATH = '/content/drive/MyDrive/FeMaterials/'  # path to git repo/colab/gnn_atomistics

TRAINING_RATIO = 0.8  # fraction (0-1) of the dataset used for training
OPTIMIZER = torch.optim.Adam
BATCH_SIZE = 6

In [None]:
# Mount Google Drive (Colab only) so that the folders under BASE_PATH are reachable.
from google.colab import drive
drive.mount('/content/drive')
# Folder the DB*.xyz input files are read from.
DATA_PATH = BASE_PATH + 'input/'
# Folder where save_model writes checkpoints.
# NOTE(review): BASE_PATH already ends with '/', so this path contains a double slash.
MODELS_PATH = BASE_PATH + "/models"

Mounted at /content/drive


# Data

## Define functions


In [None]:
# Cache of supercell transformation matrices, indexed as M[source][target].
M = {}
def extend_atoms(atoms, source, target):
  """Replicate `atoms` into a supercell and scale its stored energy.

  Args:
    atoms: ASE structure; `atoms.info["energy"]` must be set.
    source: number of atoms in `atoms` (used as cache key and energy divisor).
    target: requested supercell size, passed to `find_optimal_cell_shape`.

  Returns:
    The supercell, with `info["energy"]` set to the input energy multiplied
    by `target // source`.

  Note:
    The transformation matrix is cached per (source, target) pair only, so
    the matrix computed for the first cell seen with that pair is reused for
    every later call, whatever the cell of `atoms` is.
  """
  # Imported explicitly: a bare `import ase` does not guarantee that the
  # `ase.build` submodule has been loaded.
  import ase.build

  per_target = M.setdefault(source, {})
  if target not in per_target:
    per_target[target] = ase.build.find_optimal_cell_shape(atoms.get_cell(), target, "sc")
  supercell = ase.build.make_supercell(atoms, per_target[target])
  # The energy is scaled by the replication factor.
  supercell.info["energy"] = atoms.info["energy"] * (target // source)
  return supercell

In [None]:
def data_object(atoms: ase.Atoms):
  """Convert an ASE structure into a PyTorch Geometric `Data` sample.

  Single-atom structures are first expanded with `extend_atoms(atoms, 1, 54)`.
  The sample stores atomic numbers `z`, positions `x`, the cell matrix
  `cell`, the energy `y` (from `atoms.info["energy"]`), the per-atom "force"
  array `f` and the atom count `n`. All tensors are created on DEVICE; the
  floating-point ones use DTYPE.
  """
  if atoms.get_global_number_of_atoms() == 1:
    atoms = extend_atoms(atoms, 1, 54)
  n = atoms.get_global_number_of_atoms()

  def to_tensor(values, dtype):
    return torch.tensor(values, dtype=dtype, device=DEVICE)

  cell = to_tensor(atoms.cell, DTYPE)
  x = to_tensor(atoms.get_positions(), DTYPE)
  z = to_tensor(atoms.get_array("numbers", copy=True), torch.long)
  y = to_tensor(atoms.info["energy"], DTYPE)
  f = to_tensor(atoms.get_array("force", copy=True), DTYPE)

  return Data(z=z, x=x, cell=cell, y=y, f=f, n=n)

## Load data

In [None]:
# Read every database file and turn each structure into a graph sample.
file_names = [f"DB{i}.xyz" for i in range(1,9)]
# file_names = ["DB1_100.xyz"]  # only 100 samples from DB1; for debugging
file_paths = [f"{DATA_PATH}/{f}" for f in file_names]
data = []
for f in file_paths:
  db = read(f, index=":")  # index=":" reads all frames of the file
  data.extend([data_object(structure) for structure in db])
print(len(data))

12171


In [None]:
# Mean and standard deviation of the per-atom energy (y / n) over the whole
# dataset; later passed to the model as `mean` and `std`.
y_atom = torch.tensor([d.y / d.n for d in data], dtype=DTYPE)
y_atom_mean, y_atom_std = torch.mean(y_atom).item(), torch.std(y_atom).item()
print(y_atom_mean, y_atom_std)

-3460.825847482401 0.16576383029449515


In [None]:
# Shuffle in place, then split into training and test subsets.
random.shuffle(data)
train_amount = int(len(data) * TRAINING_RATIO)
train_data, test_data = data[:train_amount], data[train_amount:]
# Only the training loader reshuffles its samples at every epoch.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

## Functions for the training samples cloud (to remove)

In [None]:
# Master switch for the sample-cloud inspection cells below; when False they do nothing.
EVALUATE_SAMPLE_CLOUD = False

In [None]:
# save data
# Writes the full list of samples to BASE_PATH/dati.pt, only when both flags are set.
SAVE_DATA = False
if SAVE_DATA and EVALUATE_SAMPLE_CLOUD:
  torch.save(data, BASE_PATH + 'dati.pt')

In [None]:
# load data
# Reads back the sample list written by the save cell above.
# NOTE: torch.load unpickles the file, so only load files you trust.
if EVALUATE_SAMPLE_CLOUD:
  dati = torch.load(BASE_PATH + 'dati.pt')
  print(len(dati))

12171


In [None]:
# Print one sample, its cell matrix, the first cell row and that row's first component.
if EVALUATE_SAMPLE_CLOUD:
  idx = 12170  # 7000
  print(dati[idx])
  print(dati[idx].cell)
  print(dati[idx].cell[0])
  print(dati[idx].cell[0][0])

Data(cell=[3, 3], f=[12, 3], n=12, x=[12, 3], y=-41530.3393752, z=[12])
tensor([[13.8817,  1.9632,  2.8051],
        [ 0.0000,  2.4540,  0.0000],
        [ 0.0000,  0.0000,  4.0073]], device='cuda:0', dtype=torch.float64)
tensor([13.8817,  1.9632,  2.8051], device='cuda:0', dtype=torch.float64)
tensor(13.8817, device='cuda:0', dtype=torch.float64)


In [None]:
# functions to determine a,b ,c and volume
# warning: for the volume, divide by the number of atoms (12, 54 or 128)
if EVALUATE_SAMPLE_CLOUD:
  # Lengths of the three rows of the cell matrix of sample 0.
  a, b, c = (
      torch.sqrt(dati[0].cell[k][0]**2 + dati[0].cell[k][1]**2 + dati[0].cell[k][2]**2)
      for k in range(3)
  )
  # Determinant of the cell matrix.
  vol = torch.linalg.det(dati[0].cell)
  print('a : ', a)
  print('b : ', b)
  print('c : ', c)
  print('c/a : ', c/a)
  print('vol : ', vol)
  print('vol/54 : ', vol/54)

a :  tensor(8.5316, device='cuda:0', dtype=torch.float64)
b :  tensor(8.5316, device='cuda:0', dtype=torch.float64)
c :  tensor(8.5316, device='cuda:0', dtype=torch.float64)
c/a :  tensor(1., device='cuda:0', dtype=torch.float64)
vol :  tensor(621.0067, device='cuda:0', dtype=torch.float64)
vol/54 :  tensor(11.5001, device='cuda:0', dtype=torch.float64)


In [None]:
# different number of atoms in lattice
# NOTE: despite its name, `atomic_numbers` holds the atom COUNT `n` of every
# sample, not chemical atomic numbers; `num` is the list of distinct counts.
if EVALUATE_SAMPLE_CLOUD:
  atomic_numbers = [d.n for d in dati]
  num = list(set(atomic_numbers))
  print(num)

[128, 129, 130, 12, 53, 54, 123, 124, 125, 126, 127]


# Model

In [None]:
from ase import Atoms
from ase.neighborlist import neighbor_list 

def _pbc_edges_single(cutoff, z, x, cell, compute_sc):
  """Neighbour search for ONE periodic structure.

  Returns `(i, j, s, d, sc)`: edge indices local to this structure, the
  integer cell shifts `s` (stored as DTYPE), the displacement vectors
  `d = x[j] - x[i] + s @ cell` (differentiable w.r.t. `x`), and, when
  `compute_sc` is set, a 9-column tensor holding the cell rows scaled by the
  matching shift component (otherwise `None`).
  """
  # ASE works on NumPy data; detach so positions that require grad can be
  # converted. The atomic numbers go in `numbers` (not `charges`).
  atoms = Atoms(
    numbers = z.detach().cpu().numpy(),
    positions = x.detach().cpu().numpy(),
    cell = cell.detach().cpu().numpy(),
    pbc = True
  )
  nh1, nh2, s = neighbor_list("ijS", atoms, cutoff, self_interaction=False)
  nh1 = torch.tensor(nh1, dtype=torch.long, device=DEVICE)
  nh2 = torch.tensor(nh2, dtype=torch.long, device=DEVICE)
  s = torch.tensor(s, dtype=DTYPE, device=DEVICE)
  d = x[nh2] - x[nh1] + torch.matmul(s, cell)

  sc = None
  if compute_sc:
    # Row k of the cell multiplied by shift component k, flattened to 9 columns.
    sc = (s.unsqueeze(-1) * cell.unsqueeze(0)).reshape(len(d), 9)
  return nh1, nh2, s, d, sc


def pbc_edges(cutoff, z, x, cell, batch, compute_sc=False):
  """Build the periodic neighbour graph of one structure or a batch of them.

  Args:
    cutoff: neighbour search radius handed to ASE's `neighbor_list`.
    z: atomic numbers of all atoms.
    x: positions of all atoms.
    cell: a 3x3 cell (no batch) or the cells of all graphs stacked along
      dim 0, three rows per graph.
    batch: graph id of every atom, or `None` for a single structure. Atoms of
      the same graph are expected to be contiguous and ids to start at 0.
    compute_sc: also return the per-edge scaled cell rows.

  Returns:
    `(NH1, NH2, D, S, SC)`: edge endpoints indexed into `x`, edge lengths,
    cell shifts, and the scaled cells (`None` without a batch when
    `compute_sc` is False).
  """
  if batch is None:
    nh1, nh2, s, d, sc = _pbc_edges_single(cutoff, z, x, cell, compute_sc)
    return nh1, nh2, d.norm(dim=-1), s, sc

  # Atoms per graph, indexed by graph id (deterministic order, unlike a set).
  graph_sizes = torch.bincount(batch.cpu()).tolist()

  nh1_parts, nh2_parts, s_parts, d_parts, sc_parts = [], [], [], [], []
  offset = 0  # index of the first atom of the current graph
  for i, n_atoms in enumerate(graph_sizes):
    if n_atoms == 0:
      continue
    atom_slice = slice(offset, offset + n_atoms)
    nh1, nh2, s, d, sc = _pbc_edges_single(
      cutoff, z[atom_slice], x[atom_slice], cell[3*i:3*(i+1)], compute_sc)
    # Shift local indices so they address the concatenated atom arrays.
    nh1_parts.append(nh1 + offset)
    nh2_parts.append(nh2 + offset)
    s_parts.append(s)
    d_parts.append(d)
    if compute_sc:
      sc_parts.append(sc)
    offset += n_atoms

  def _cat(parts, dtype):
    # Empty fallback keeps the dtypes the function returned for an empty batch.
    if parts:
      return torch.cat(parts, 0)
    return torch.tensor([], dtype=dtype, device=DEVICE)

  NH1 = _cat(nh1_parts, torch.long)
  NH2 = _cat(nh2_parts, torch.long)
  S = _cat(s_parts, torch.long)
  D = _cat(d_parts, DTYPE)
  SC = _cat(sc_parts, DTYPE) if compute_sc else None

  D = D.norm(dim=-1)
  return  NH1, NH2, D, S, SC

In [None]:
import os
import warnings
import os.path as ospa
from math import pi as PI
import torch.nn.functional as F
from torch.nn import Embedding, Sequential, Linear, ModuleList
from torch_scatter import scatter
from torch_geometric.data.makedirs import makedirs
from torch_geometric.data import download_url, extract_zip
from torch_geometric.nn import radius_graph, MessagePassing

class SchNetx(torch.nn.Module):
    """SchNet-style network that can build its graph with periodic boundaries.

    When `forward` receives a `cell`, edges come from `pbc_edges`; otherwise
    they come from `radius_graph` on the raw positions.

    Args:
        hidden_channels: size of the atom embeddings.
        num_filters: number of filters in each interaction block.
        num_interactions: number of interaction blocks.
        num_gaussians: number of Gaussians used to expand distances.
        cutoff: neighbour cutoff distance.
        readout: 'add', 'sum' or 'mean'; forced to 'add' when `dipole` is set.
        dipole: if True, return the norm of the mass-centred weighted sum.
        mean, std: if both are given (and not `dipole`), per-atom outputs are
            rescaled as `h * std + mean` before the readout.
        atomref: optional per-element reference values added per atom.
    """

    def __init__(self, hidden_channels=128, num_filters=128,
                 num_interactions=6, num_gaussians=50, cutoff=10.0,
                 readout='add', dipole=False, mean=None, std=None,
                 atomref=None):
        super(SchNetx, self).__init__()

        assert readout in ['add', 'sum', 'mean']

        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_interactions = num_interactions
        self.num_gaussians = num_gaussians
        self.cutoff = cutoff
        self.readout = readout
        self.dipole = dipole
        self.readout = 'add' if self.dipole else self.readout
        self.mean = mean
        self.std = std
        self.scale = None

        atomic_mass = torch.from_numpy(ase.data.atomic_masses)
        self.register_buffer('atomic_mass', atomic_mass)

        self.embedding = Embedding(100, hidden_channels)
        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)

        self.interactions = ModuleList()
        for _ in range(num_interactions):
            block = InteractionBlock(hidden_channels, num_gaussians,
                                     num_filters, cutoff)
            self.interactions.append(block)

        self.lin1 = Linear(hidden_channels, hidden_channels // 2)
        self.act = ShiftedSoftplus()
        self.lin2 = Linear(hidden_channels // 2, 1)

        # Keep the initial reference values so reset_parameters can restore them.
        self.register_buffer('initial_atomref', atomref)
        self.atomref = None
        if atomref is not None:
            self.atomref = Embedding(100, 1)
            self.atomref.weight.data.copy_(atomref)
        
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding, interaction blocks and output layers."""
        self.embedding.reset_parameters()
        for interaction in self.interactions:
            interaction.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)
        if self.atomref is not None:
            self.atomref.weight.data.copy_(self.initial_atomref)

    def forward(self, z, x, cell=None, batch=None):
        """Predict one value per graph.

        Args:
            z: 1-D long tensor of atomic numbers.
            x: atomic positions.
            cell: cell matrices as expected by `pbc_edges` (three rows per
                graph); `None` disables periodic boundary conditions.
            batch: graph id per atom; `None` means a single graph.
        """
        assert z.dim() == 1 and z.dtype == torch.long
        batch = torch.zeros_like(z) if batch is None else batch

        h = self.embedding(z)

        # `is not None` avoids relying on how a tensor compares against None.
        if cell is not None:
          row, col, edge_weight, _, _ = pbc_edges(self.cutoff, z, x, cell, batch, compute_sc=False)
          edge_index = torch.stack((row, col))
        else:
          edge_index = radius_graph(x, r=self.cutoff, batch=batch)
          row, col = edge_index
          edge_weight = (x[row] - x[col]).norm(dim=-1)
        
        edge_attr = self.distance_expansion(edge_weight)

        for interaction in self.interactions:
            h = h + interaction(h, edge_index, edge_weight, edge_attr)

        h = self.lin1(h)
        h = self.act(h)
        h = self.lin2(h)

        if self.dipole:
            mass = self.atomic_mass[z].view(-1, 1)
            # Mass-weighted centre of each graph.
            c = scatter(mass * x, batch, dim=0) / scatter(mass, batch, dim=0)
            # Positions are named `x` in this class.
            h = h * (x - c[batch])

        if not self.dipole and self.mean is not None and self.std is not None:
            h = h * self.std + self.mean

        if not self.dipole and self.atomref is not None:
            h = h + self.atomref(z)

        out = scatter(h, batch, dim=0, reduce=self.readout)
        
        if self.dipole:
            out = torch.norm(out, dim=-1, keepdim=True)

        if self.scale is not None:
            out = self.scale * out

        return out


    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'hidden_channels={self.hidden_channels}, '
                f'num_filters={self.num_filters}, '
                f'num_interactions={self.num_interactions}, '
                f'num_gaussians={self.num_gaussians}, '
                f'cutoff={self.cutoff})')



class InteractionBlock(torch.nn.Module):
    """One interaction layer: filter MLP + CFConv, then activation and a linear layer.

    Args:
        hidden_channels: size of the atom features (input and output).
        num_gaussians: size of the expanded-distance edge features.
        num_filters: width of the filter-generating MLP.
        cutoff: cutoff distance forwarded to `CFConv`.
    """
    def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff):
        super(InteractionBlock, self).__init__()
        self.mlp = Sequential(
            Linear(num_gaussians, num_filters),
            ShiftedSoftplus(),
            Linear(num_filters, num_filters),
        )
        self.conv = CFConv(hidden_channels, hidden_channels, num_filters,
                           self.mlp, cutoff)
        self.act = ShiftedSoftplus()
        self.lin = Linear(hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise all weights and zero all biases."""
        torch.nn.init.xavier_uniform_(self.mlp[0].weight)
        self.mlp[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.mlp[2].weight)
        # Zero the bias of the SECOND linear layer (index 2 in the Sequential).
        self.mlp[2].bias.data.fill_(0)
        self.conv.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin.weight)
        self.lin.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        """Return the feature update computed from `x` and the edge data."""
        x = self.conv(x, edge_index, edge_weight, edge_attr)
        x = self.act(x)
        x = self.lin(x)
        return x


class CFConv(MessagePassing):
    """Continuous-filter convolution with a cosine cutoff and sum aggregation.

    `nn` maps the expanded edge distances to per-edge filters; neighbour
    features are multiplied by those filters and summed.
    """

    def __init__(self, in_channels, out_channels, num_filters, nn, cutoff):
        super(CFConv, self).__init__(aggr='add')
        self.cutoff = cutoff
        # Sub-modules are registered in this order on purpose: it determines
        # the parameter order.
        self.lin1 = Linear(in_channels, num_filters, bias=False)
        self.lin2 = Linear(num_filters, out_channels)
        self.nn = nn

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise both linear layers and zero the bias of `lin2`."""
        for layer in (self.lin1, self.lin2):
            torch.nn.init.xavier_uniform_(layer.weight)
        self.lin2.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        # Cosine envelope: 1 at distance 0, 0 at distance == cutoff.
        envelope = (torch.cos(edge_weight * PI / self.cutoff) + 1.0) * 0.5
        W = self.nn(edge_attr) * envelope.view(-1, 1)

        projected = self.lin1(x)
        aggregated = self.propagate(edge_index, x=projected, W=W)
        return self.lin2(aggregated)

    def message(self, x_j, W):
        # Neighbour features multiplied by the per-edge filter.
        return x_j * W


class GaussianSmearing(torch.nn.Module):
    """Expand scalar distances on a grid of equally spaced Gaussians.

    The centres are `num_gaussians` points from `start` to `stop`; the width
    is set by the spacing between two neighbouring centres.
    """

    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super().__init__()
        centers = torch.linspace(start, stop, num_gaussians)
        spacing = (centers[1] - centers[0]).item()
        self.coeff = -0.5 / spacing**2
        self.register_buffer('offset', centers)

    def forward(self, dist):
        # (num_distances, 1) - (1, num_gaussians) -> (num_distances, num_gaussians)
        delta = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * delta.pow(2))

class ShiftedSoftplus(torch.nn.Module):
    """Softplus shifted down by log(2), so that the activation is 0 at x = 0."""

    def __init__(self):
        super().__init__()
        # log(2), computed with torch and stored as a plain Python float.
        self.shift = torch.tensor(2.0).log().item()

    def forward(self, x):
        activated = F.softplus(x)
        return activated - self.shift

# Training

## Define functions


In [None]:
def energy_loss(data, p_energies):
  """Mean absolute error between the reference energies `data.y` and `p_energies`."""
  return (data.y - p_energies).abs().mean()

In [None]:
def energy_forces_loss(data, p_energies, p_forces, energy_coeff):
  """Weighted sum of the energy MAE and the force MAE.

  Returns `(total, energy_mae, force_mae)` where
  `total = energy_coeff * energy_mae + (1 - energy_coeff) * force_mae`.
  References are read from `data.y` (energies) and `data.f` (forces).
  """
  energies_loss = (data.y - p_energies).abs().mean()
  forces_loss = (data.f - p_forces).abs().mean()
  weighted_energy = energy_coeff * energies_loss
  weighted_forces = (1 - energy_coeff) * forces_loss
  return weighted_energy + weighted_forces, energies_loss, forces_loss

In [None]:
def train(model, loader, optimizer, use_forces=False, energy_coeff=None):
  """Run one training epoch, on energies only or on energies and forces.

  `energy_coeff` is only used when `use_forces` is True; it is forwarded to
  `train_energy_forces` as the weight of the energy term.
  """
  if use_forces:
    train_energy_forces(model, loader, optimizer, energy_coeff)
  else:
    train_energy(model, loader, optimizer)

In [None]:
def train_energy(model, loader, optimizer):
  """Run one epoch of energy-only training (one optimizer step per batch)."""
  model.train()
  for batch in loader:
    optimizer.zero_grad()

    # The model returns one row per graph; drop the trailing dimension.
    predicted = model(batch.z, batch.x, batch.cell, batch.batch).squeeze(1)
    batch_loss = energy_loss(batch, predicted)

    batch_loss.backward()
    optimizer.step()

In [None]:
def train_energy_forces(model, loader, optimizer, energy_coeff):
  """Run one epoch of training on the combined energy + force loss.

  Forces are obtained as the negative gradient of the predicted energies with
  respect to the positions. The per-batch losses are summed over the epoch
  and printed at the end.
  """
  model.train()
  totals = {"ef": 0, "e": 0, "f": 0}

  for batch in loader:
    batch.x.requires_grad = True  # positions must track gradients to derive forces
    optimizer.zero_grad()

    energies = model(batch.z, batch.x, batch.cell, batch.batch)
    # create_graph=True keeps the force computation differentiable for backward().
    position_grad = torch.autograd.grad(energies, batch.x, grad_outputs=torch.ones_like(energies), create_graph=True, retain_graph=True)[0]
    forces = -1 * position_grad
    energies = energies.squeeze(1)

    ef_loss, e_loss, f_loss = energy_forces_loss(batch, energies, forces, energy_coeff)
    totals["e"] += e_loss.item()
    totals["f"] += f_loss.item()
    totals["ef"] += ef_loss.item()

    ef_loss.backward()
    optimizer.step()

  print("Total training loss\t ef: {}, e: {}, f: {}".format(totals["ef"], totals["e"], totals["f"]))

In [None]:
def test(model, train_loader, test_loader):
  """Return the energy MAE of `model` on the training and on the test loader.

  The model is switched to eval mode and evaluated without gradient tracking.
  Errors are collected on the CPU, and the mean is taken over all samples of
  a loader.
  """
  model.eval()

  def mean_abs_error(loader):
    # Start from an empty DTYPE tensor so an empty loader still yields a mean.
    chunks = [torch.tensor([], dtype=DTYPE)]
    for batch in loader:
      predicted = model(batch.z, batch.x, batch.cell, batch.batch).squeeze(1)
      chunks.append(torch.abs(predicted.view(-1).cpu() - batch.y.cpu()))
    return torch.mean(torch.cat(chunks)).item()

  with torch.no_grad():
    train_mae = mean_abs_error(train_loader)
    test_mae = mean_abs_error(test_loader)

  return train_mae, test_mae

In [None]:
# from datetime import datetime

# train_maes = []
# test_maes = []
# def train_and_test(model, train_loader, test_loader, optimizer, scheduler, 
#                    use_forces=False, energy_coeff=None, 
#                    epochs=100, starting_epoch=1, save_every=5, 
#                    name=None, description=None):
#   if name == None:
#     name = "SchNet"
#   if description == None:
#     description = datetime.now().strftime('%H:%M:%S')

#   for epoch in range(starting_epoch, epochs):
#     print("")
#     print(f"Epoch {epoch} ({datetime.now().strftime('%H:%M:%S')})")
#     print(f"")
#     train(model, train_loader, optimizer, use_forces=use_forces, energy_coeff=energy_coeff)
#     train_mae, test_mae = test(model, train_loader, test_loader)
#     train_maes.append(train_mae)
#     test_maes.append(test_mae)
#     scheduler.step(test_mae)
#     print(f"Energy MAE: train {train_mae} eV, test {test_mae} eV")
#     if epoch % save_every == 0:
#       save_model(name=name, description=description, epoch=epoch, train_err=train_mae, test_err=test_mae)

In [None]:
from datetime import datetime

train_maes = []
test_maes = []
def train_and_test(model, train_loader, test_loader, optimizer, scheduler, 
                   use_forces=False, energy_coeff=None, 
                   epochs=100, starting_epoch=1, save_every=5,
                   early_stopping_starting_epoch=100, patience=100,
                   name=None, description=None,
                   initial_lowest_test_mae=0.05136417386564761):
  """Train with per-epoch evaluation, checkpointing and early stopping.

  Args:
    model, train_loader, test_loader, optimizer: forwarded to `train` / `test`.
    scheduler: stepped with the test MAE after every completed epoch.
    use_forces, energy_coeff: forwarded to `train`.
    epochs: exclusive upper bound of the epoch counter (`range(starting_epoch, epochs)`).
    starting_epoch: first epoch number (useful when resuming from a checkpoint).
    save_every: accepted for backward compatibility; currently unused.
    early_stopping_starting_epoch: up to and including this epoch the model is
      saved every epoch and the reference MAE is simply the latest test MAE.
    patience: number of non-improving epochs tolerated once early stopping is
      active; training stops when it is exceeded.
    name: checkpoint name, defaults to "SchNet".
    description: free-text description, defaults to the current time.
    initial_lowest_test_mae: reference test MAE used when training resumes
      AFTER `early_stopping_starting_epoch` (i.e. the best MAE of the run
      being resumed). The default is the value that used to be hard-coded.

  Side effects: appends to the module-level lists `train_maes` and `test_maes`
  and writes checkpoints through `save_model`.
  """
  if name is None:
    name = "SchNet"
  if description is None:
    description = datetime.now().strftime('%H:%M:%S')

  lowest_test_mae = initial_lowest_test_mae
  epochs_not_improved = 0
  save_model_flag = True
  for epoch in range(starting_epoch, epochs):
    print("")
    print(f"Epoch {epoch} ({datetime.now().strftime('%H:%M:%S')})")
    print(f"")
    train(model, train_loader, optimizer, use_forces=use_forces, energy_coeff=energy_coeff)
    train_mae, test_mae = test(model, train_loader, test_loader)
    train_maes.append(train_mae)
    test_maes.append(test_mae)
    print(f"Energy MAE: train {train_mae} eV, test {test_mae} eV")
    if epoch <= early_stopping_starting_epoch:
      # Warm-up phase: always save and track the latest test MAE.
      lowest_test_mae = test_mae
      save_model_flag = True
    else:
      print('Early stopping active')
      if test_mae < lowest_test_mae:
        epochs_not_improved = 0
        lowest_test_mae = test_mae
        save_model_flag = True
        print('MAE has decreased, saving the model')
      else:
        epochs_not_improved += 1
        save_model_flag = False
        if epochs_not_improved > patience:
          print('Patience has been finished: end of the training')
          break
        else:
          print('MAE has not decreased for ' + str(epochs_not_improved) + ' epochs, go on without saving the model')
    if save_model_flag:
      save_model(name=name, description=description, epoch=epoch, train_err=train_mae, test_err=test_mae)
    scheduler.step(test_mae)

In [None]:
def save_model(name, description, epoch=-1, train_err=-1, test_err=-1):
  """Write a checkpoint of the current training state to MODELS_PATH/<name>/.

  The checkpoint bundles the notebook-level `model`, `optimizer` and
  `scheduler` state dicts with `model_str`, the energy normalisation
  (`y_atom_mean`, `y_atom_std`) and the given errors. The file is named
  `<name>_<epoch>` (or `<name>` when `epoch` is -1). The same data is also
  written to `<name>_best` when `test_err` is lower than the test error
  stored in the existing best file, or when no best file exists.

  Args:
    name: checkpoint base name; also the sub-folder name.
    description: free-text description stored under the "desc" key.
    epoch: epoch number used as file-name suffix; -1 for no suffix.
    train_err, test_err: errors stored with the checkpoint.
  """
  model_data = {
    "desc": description,  # store the caller's description, not a fixed string
    "str": model_str,
    "mean": y_atom_mean,
    "std": y_atom_std,
    "state": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "train_err": train_err,
    "test_err": test_err
  }
  epoch_str = "" if epoch == -1 else "_" + str(epoch)
  actual_name = f"{name}{epoch_str}"

  out_dir = f"{MODELS_PATH}/{name}"
  if not os.path.exists(out_dir):
    os.makedirs(out_dir, exist_ok=True)
    print(f"Created directory {out_dir}")
  torch.save(model_data, f"{out_dir}/{actual_name}")
  print(f"Saved {actual_name}")

  best_path = f"{out_dir}/{name}_best"
  if os.path.isfile(best_path):
    # Only the stored error is needed, so keep the tensors off the GPU.
    best_test_err = torch.load(best_path, map_location="cpu")["test_err"]
  else:
    best_test_err = float("inf")
  if test_err < best_test_err:
    torch.save(model_data, best_path)
    print(f"Saved best")

## Instantiate model

In [None]:
model_name = "schnet_15giu" # name of the folder where checkpoints will be saved
# Free-text description of the run, printed by experiment_summary.
model_description = "SchNet new training should be the same as old one (power something)"

In [None]:
# NOTE: keep the code of this cell identical to the text stored in `model_str`
# (next cell), which is saved with every checkpoint.
cutoff = 5.0
# mean/std are the per-atom energy statistics computed from the dataset.
model = SchNetx(hidden_channels = 128, num_filters=128, num_gaussians=128, cutoff=cutoff, 
                num_interactions=3, readout="sum", mean=y_atom_mean, std=y_atom_std).to(DEVICE)
# Cast the parameters to double precision so they match DTYPE.
if DTYPE == torch.float64:
  model = model.double()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
optimizer.zero_grad()
# Halve the learning rate when the monitored value has not improved by at
# least 5% (relative) for more than 3 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, threshold=0.05, threshold_mode='rel')
use_forces = False  # train on energies only
energy_coeff = 0.  # weight of the energy term; only used when use_forces is True
max_epochs = 100

Copy-paste the content of the previous cell into the following variable, so that the hyperparameters are saved in each checkpoint together with the model parameters.

In [None]:
# Verbatim copy of the model-instantiation cell; save_model stores it under
# the "str" key of every checkpoint.
model_str = """
cutoff = 5.0
model = SchNetx(hidden_channels = 128, num_filters=128, num_gaussians=128, cutoff=cutoff, 
                num_interactions=3, readout="sum", mean=y_atom_mean, std=y_atom_std).to(DEVICE)
if DTYPE == torch.float64:
  model = model.double()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
optimizer.zero_grad()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, threshold=0.05, threshold_mode='rel')
use_forces = False
energy_coeff = 0.
max_epochs = 100
"""

## [Restore from checkpoint]

In [None]:
# Set to False to train from scratch instead of resuming.
RESTORE_FROM_CHECKPOINT = True

if RESTORE_FROM_CHECKPOINT:
  model_fn = "schnet_15giu_62"
  # map_location lets a checkpoint saved on a GPU be restored on whatever
  # device this session runs on (including CPU-only runtimes).
  # NOTE: torch.load unpickles the file; only load checkpoints you trust.
  checkpoint = torch.load(MODELS_PATH + "/schnet_15giu/" + model_fn, map_location=DEVICE)
  # Restore model weights together with optimizer and scheduler state so
  # training continues where it stopped.
  model.load_state_dict(checkpoint["state"])
  optimizer.load_state_dict(checkpoint["optimizer"])
  scheduler.load_state_dict(checkpoint["scheduler"])

## Train

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
def experiment_summary():
  """Print the configuration of the current experiment.

  Reads the notebook-level settings (device, dtype, input files, model,
  optimizer, scheduler, ...) and prints one line per item.
  """
  device = "CPU" if DEVICE.type == "cpu" else f"GPU ({torch.cuda.get_device_name(device=DEVICE)})"
  print(f"Device: {device}")
  print(f"Data type: {DTYPE}")
  print(f"Files used: {file_names}")
  # Report the cutoff the model actually uses rather than a fixed number.
  print(f"Cutoff: {model.cutoff}")
  print(f"Mean: {y_atom_mean}, std: {y_atom_std}")
  print(f"Training ratio: {TRAINING_RATIO}")
  print(f"Batch size: {BATCH_SIZE}")
  print(f"Model: \n    {model_name}: {model_description}\n    {str(model)}")
  loss = "energy" if not use_forces else f"energy+forces ({energy_coeff})"
  print(f"Loss: {loss}")
  print(f"Optimizer: {optimizer}")
  print(f"Scheduler: {scheduler}")
  print(f"Max epochs: {max_epochs}")
experiment_summary()

Device: GPU (Tesla P100-PCIE-16GB)
Data type: torch.float64
Files used: ['DB1.xyz', 'DB2.xyz', 'DB3.xyz', 'DB4.xyz', 'DB5.xyz', 'DB6.xyz', 'DB7.xyz', 'DB8.xyz']
Cutoff: 3.5
Mean: -3460.825847482401, std: 0.16576383029449515
Training ratio: 0.8
Batch size: 6
Model: 
    schnet_15giu: SchNet new training should be the same as old one (power something)
    SchNetx(hidden_channels=128, num_filters=128, num_interactions=3, num_gaussians=128, cutoff=5.0)
Loss: energy
Optimizer: AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 7.8125e-06
    weight_decay: 0.01
)
Scheduler: <torch.optim.lr_scheduler.ReduceLROnPlateau object at 0x7f4c135f2850>
Max epochs: 100


In [None]:
# Display the name of the GPU behind DEVICE.
# NOTE(review): presumably fails on a CPU-only runtime — confirm before running without a GPU.
torch.cuda.get_device_name(device=DEVICE)

'Tesla P100-PCIE-16GB'

In [None]:
#@title  { vertical-output: true, display-mode: "both" }
from psutil import *
print(cpu_count())
print(cpu_stats())
!cat /proc/cpuinfo
!df -h
print(virtual_memory())


4
scpustats(ctx_switches=4792969, interrupts=2965446, soft_interrupts=1709210, syscalls=0)
processor	: 0
vendor_id	: GenuineIntel
cpu family	: 6
model		: 63
model name	: Intel(R) Xeon(R) CPU @ 2.30GHz
stepping	: 0
microcode	: 0x1
cpu MHz		: 2299.998
cache size	: 46080 KB
physical id	: 0
siblings	: 4
core id		: 0
cpu cores	: 2
apicid		: 0
initial apicid	: 0
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs
bogomips	: 4599.99
clflush size	: 64
cache_alig

In [None]:
# Resume training at epoch 63 (the restored checkpoint is "schnet_15giu_62").
# Because starting_epoch is already beyond early_stopping_starting_epoch,
# early stopping is active from the first epoch of this call.
# NOTE: save_every is accepted by train_and_test but not used by it.
train_and_test(model, train_loader, test_loader, optimizer, scheduler, 
               use_forces=use_forces, energy_coeff=energy_coeff,
               epochs=max_epochs, starting_epoch=63, save_every=1,
               early_stopping_starting_epoch=40, patience=10,
               name=model_name, description=model_description)


Epoch 63 (10:37:27)

Energy MAE: train 0.045858823025664394 eV, test 0.053233723264639134 eV
Early stopping active
MAE has not decreased for 1 epochs, go on without saving the model

Epoch 64 (11:06:27)

Energy MAE: train 0.04495390461377666 eV, test 0.05251301282843373 eV
Early stopping active
MAE has not decreased for 2 epochs, go on without saving the model

Epoch 65 (11:35:08)

Energy MAE: train 0.04436545693558506 eV, test 0.0522424130949465 eV
Early stopping active
MAE has not decreased for 3 epochs, go on without saving the model

Epoch 66 (12:03:49)

Energy MAE: train 0.04514432089419587 eV, test 0.05195313950750649 eV
Early stopping active
MAE has not decreased for 4 epochs, go on without saving the model

Epoch 67 (12:32:31)

Energy MAE: train 0.0439818107720717 eV, test 0.05208845941706461 eV
Early stopping active
MAE has not decreased for 5 epochs, go on without saving the model

Epoch 68 (13:02:22)

Energy MAE: train 0.04345585736712474 eV, test 0.05147601548045418 eV
Ear