In [1]:
#  This implementation is an adaptation of the schnetpack implementation of atom wise-BPNNs 
#  https://schnetpack.readthedocs.io/en/stable/_modules/schnetpack/nn/blocks.html#GatedNetwork

# It uses an embedding matrix that calculates Nspecies*Nsamples NN atom-wise contributions
# If the embedding is one-hot for each species it multiplies the "off diagonal" NN contributions with zero
# ie NN_hydrogen(X_central_species_is_Oxygen) * 0 + NN_oxygen(X_central_species_is_Oxygen) * 1
# but this embedding can also be learned assigning:
# w_oxygen, w_hydrogen for every species that are not one-hot
# Unfortunately this embedding changes the cost of the BPNN to O(N_species*N_atoms)
# I have added a second implementation where, if the one_hot_encoding is chosen, 
# the other network evaluations are supposed to be skipped (multiplied with zeros anyway...)
# reducing the overall complexity to ~O(N_atoms) assuming that N_species << N_atoms


In [2]:
from typing import List
import numpy as np
import torch
from dataclasses import dataclass
from copy import copy


In [3]:
@dataclass
class EquistoreDummy:
    z: torch.Tensor
    val: torch.Tensor
    idx: torch.Tensor

In [4]:
def init_Embedding_one_hot(embedding):
    pass

def EmbeddingFactory(elements:List[int],one_hot:bool) -> torch.nn.Embedding:
    """Returns an Embedding of dim max_Z,n_unique_elements
    max_Z = 9, n_unique = 2, elements = [1,8]
    Embedding(tensor([8])) -> tensor([0.0,1.0]) (if one hot)
    """
    
    # embedding "technically" starts at zero, Z at one
    max_int = max(elements) + 1
    n_species = len(elements)

    #randomly initialize the Embedding
    #TODO: add a initialize_weights routine
    #TODO: maybe solve it with a decorator?
    embedding = torch.nn.Embedding(max_int,n_species)
    
    # If the embedding is one-hot, the weight matrix is diagonal
    if one_hot:
        weights = torch.zeros(max_int, n_species)
        
        for idx, Z in enumerate(elements):
            weights[Z, idx] = 1.0

        embedding.weight.data = weights

    return embedding

In [18]:
class SimpleMLP(torch.nn.Module):
    """ A simple MLP 
    """

    # TODO: add n_hidden_layers, activation function option
    def __init__(self, dim_input: int, dim_output: int, layer_size: int) -> None:
        super().__init__()

        self.layer_size = layer_size
        self.dim_input = dim_input
        self.dim_output = dim_output

        self.nn = torch.nn.Sequential(
            torch.nn.Linear(self.dim_input, self.layer_size),
            torch.nn.Tanh(),
            torch.nn.Linear(self.layer_size, self.layer_size),
            torch.nn.Tanh(),
            torch.nn.Linear(self.layer_size, self.dim_output),
        )
    def forward(self,x: torch.tensor) -> torch.tensor:
        return self.nn(x)

class MultiMLP(torch.nn.Module):    
    """ A Multi MLP that contains N_species * SimpleMLPs
    """
    def __init__(self, dim_input: int, dim_output: int, layer_size: int, species: int) -> None:
        super().__init__()

        self.dim_output = dim_output
        self.dim_input = dim_input
        self.layer_size = layer_size
        self.species = species 
        self.n_species = len(self.species)

        # initialize as many SimpleMLPs as atomic species
        self.species_nn = torch.nn.ModuleList([ SimpleMLP(dim_input,dim_output,layer_size) for _ in self.species])
    
    def forward(self, x: torch.tensor) -> torch.tensor:
        return torch.cat([nn(x) for nn in self.species_nn],dim=1)


class MultiMLP_skip(MultiMLP):
    """ A Multi MLP that contains N_species * SimpleMLPs
        This Implementation does only batchwise evaluation of neural networks?
        As this implementation skips 
    """
    
    def forward(self, x: torch.tensor, batch_z: torch.tensor) -> torch.tensor:
        #will this work with autograd? -> I think it does
        
        #get the unique zs in batch
        unique_z_in_batch = torch.unique(batch_z)

        #initializes an empty torch tensor of shape (N_samples,N_species)
        model_out = torch.empty((x.shape[0],self.n_species))

        #loops over n_total_species
        for n, (z, nn) in enumerate(zip(self.species, self.species_nn)):
            
            # if a z is in a global batch -> then use the NN_central_species on the X_central species
            # fill the rest with zeros
            if z in unique_z_in_batch:
                model_out[batch_z == z, n] = nn(x[batch_z == z]).flatten()
                model_out[batch_z != z, n] = torch.zeros(x[batch_z != z].shape[0])
            
            #else: if z is not in batch at all, simply fill everything with zeros
            else:
                model_out[:, n] = torch.zeros(x.shape[0])

        return model_out


class MultiSpeciesMLP(torch.nn.Module):
    
    """ Implements a MultiSpecies Behler Parinello neural network
    This implementation scales O(Nspecies*Natoms), but it has a learnable weight matrix, that combines species wise energies
    """

    def __init__(self, species, n_in, n_out, n_hidden, one_hot, embedding_trainable) -> None:
        
        super().__init__()
        
        #just a precaution
        species = copy(species)
        species.sort()

        #print(species)

        self.species = species
        self.nn = MultiMLP(n_in,n_out,n_hidden,species)
        self.embedding = EmbeddingFactory(species, one_hot)

        if not embedding_trainable:
            self.embedding.requires_grad_ = False
            

    def forward(self, descriptor: EquistoreDummy) -> torch.tensor:
        
        x = descriptor.val 
        z = descriptor.z # something like descriptor.

        #The embedding serves a a multiplicative "mask" -> not so nice overall complexity scales as O(N_species*N_samples)
        # whereas an implementation that could "skip" NN evaluations should only scale as O(N_samples)
        return torch.sum(self.nn(x) * self.embedding(z),dim=1)





class MultiSpeciesMLP_skip(torch.nn.Module):
    
    """ Implements a MultiSpecies Behler Parinello neural network
    This implementation should scale O(Natoms) as it skips the neural network evaluations that would be otherwise only multiplied with zeros
    """

    def __init__(self, species, n_in, n_out, n_hidden) -> None:
        
        super().__init__()

        #just a precaution
        species = copy(species)
        species.sort()

        #print(species)
    
        #TODO: Implement this properly in the MultiSpeciesMLP class

        self.n_out = n_out
        self.species = species

        # if we want to skip the NN evaluations the Embedding has to be non trainable
        # therefore -> MultiMLP_skip has no trainable kwargs and one_hot in EmbeddingFactory is always true
        # TODO: Implement this properly in the MultiSpeciesMLP class

        self.nn = MultiMLP_skip(n_in,n_out,n_hidden,species)
        self.embedding = EmbeddingFactory(species, True)
        self.embedding.requires_grad_ = False

    def forward(self, x: torch.tensor, z:torch.tensor) -> torch.tensor:
        # here the embedding multiplication should only introduce a minor overhead 
        return torch.sum(self.nn(x,z) * self.embedding(z),dim=1,keepdim=True)   


    """
    def forward(self, descriptor: EquistoreDummy) -> torch.tensor:
        
        x = descriptor.val  # something like descriptor.
        z = descriptor.z 

        # here the embedding multiplication should only introduce a minor overhead 
        return torch.sum(self.nn(x,z) * self.embedding(z),dim=1)
    """

class MultiSpeciesMLP_skip_w_index_add(MultiSpeciesMLP_skip):
    
    """ For testing purposes I have added the atomic-contributions to structure wise properties addition
    """
    
    def forward(self, descriptor: EquistoreDummy) -> torch.tensor:
                
        x = descriptor.val  
        z = descriptor.z 
        idx = descriptor.idx

        x.requires_grad_(True)
        num_structures = len(torch.unique(idx))
        
        structure_wise_properties = torch.zeros((num_structures,self.n_out))

        # In the summation of the atomic contirbutions the dimensions should be kept for autograds
        atomic_contributions = torch.sum(self.nn(x,z) * self.embedding(z),dim=1,keepdim=True)

        print(atomic_contributions.flatten())
        
        structure_wise_properties.index_add_(0,idx,atomic_contributions)

        nn_grads = torch.autograd.grad(
                structure_wise_properties,
                x,
                grad_outputs=torch.ones_like(structure_wise_properties),
                create_graph=True,
                retain_graph=True,
            )

        return structure_wise_properties, nn_grads
        

In [6]:
a_block = EquistoreDummy(torch.tensor([1,1,8,1,7]),torch.ones(5,4),torch.tensor([0,0,0,1,1]))
#multimlp = MultiSpeciesMLP_skip([1,8,7],3,1,10)

In [7]:
multimlp_index_add = MultiSpeciesMLP_skip_w_index_add([1,8,7],4,1,10)

In [8]:
multimlp_index_add.forward(a_block)

tensor([0.1124, 0.1124, 0.0163, 0.1124, 0.0560],
       grad_fn=<ReshapeAliasBackward0>)


(tensor([[0.2410],
         [0.1684]], grad_fn=<IndexAddBackward0>),
 (tensor([[ 0.0689, -0.0983,  0.0124, -0.0464],
          [ 0.0689, -0.0983,  0.0124, -0.0464],
          [-0.0497, -0.0061,  0.0588,  0.1362],
          [ 0.0689, -0.0983,  0.0124, -0.0464],
          [-0.0140,  0.0829,  0.0698, -0.1322]], grad_fn=<AddBackward0>),))

In [9]:
sum([-0.0752, -0.0752, -0.1002])

-0.2506

In [10]:
sum([-0.0752,  0.1738])

0.09860000000000001

Now combine it with the existing NN/BPNN code

In [19]:
import numpy as np
import torch

from utils.operations import StructureMap


class NNModel(torch.nn.Module):
    def __init__(self, layer_size=100):
        super().__init__()

        self.nn = None
        self.layer_size = layer_size

    # build a combined 
    def initialize_model_weights(self, descriptor, energies, forces=None, seed=None):
        if seed is not None:
            torch.manual_seed(seed)

        X = descriptor.block().values

        # initialize nn with zero weights ??
        def init_zero_weights(m):
            if isinstance(m, torch.nn.Linear):
                m.weight.data.fill_(0)
                m.bias.data.fill_(0)

        # 

        self.nn = torch.nn.Sequential(
            torch.nn.Linear(X.shape[-1], self.layer_size),
            torch.nn.Tanh(),
            torch.nn.Linear(self.layer_size, self.layer_size),
            torch.nn.Tanh(),
            torch.nn.Linear(self.layer_size, 1),
        )

    def forward(self, descriptor, with_forces=False):
        if self.nn is None:
            raise Exception("call initialize_weights first")

        ps_block = descriptor.block()
        ps_tensor = ps_block.values

        if with_forces:
            # TODO(guillaume): can this have unintended side effects???
            ps_tensor.requires_grad_(True)

        structure_map, new_samples, _ = StructureMap(
            ps_block.samples["structure"], ps_tensor.device
        )

        nn_per_atom = self.nn(ps_tensor)
        
        #structure is actually atomic envs
        nn_per_structure = torch.zeros((len(new_samples), 1), device=ps_tensor.device)
        
        #adds atomic contributions per structure
        nn_per_structure.index_add_(0, structure_map, nn_per_atom)

        energies = nn_per_structure
        if with_forces:
            
            # computes dnn/dg for dnn/dg dg/dx
            nn_grads = torch.autograd.grad(
                nn_per_structure,
                ps_tensor,
                grad_outputs=torch.ones_like(nn_per_structure),
                create_graph=True,
                retain_graph=True,
            )


            ps_gradient = descriptor.block().gradient("positions")
            ps_tensor_grad = ps_gradient.data.reshape(-1, 3, ps_tensor.shape[-1])

            gradient_samples_Aj = np.asarray(
                ps_gradient.samples[["structure", "atom"]], dtype=tuple
            )

            #why is a unique gradient necessary
            unique_gradient, unique_gradient_idx = np.unique(
                gradient_samples_Aj, return_index=True
            )
            # new_gradient_samples = gradient_samples_Aj[np.sort(unique_gradient_idx)]

            # the logic is analogous to that for the structures: we have to map
            # positions in the full (A,i,j) vector to the position where they
            # will have to be accumulated
            gradient_replace_rule = dict(
                zip(unique_gradient, range(len(unique_gradient)))
            )


            gradient_map = torch.tensor(
                [gradient_replace_rule[i] for i in gradient_samples_Aj],
                dtype=torch.long,
                device=ps_tensor.device,
            )

            new_gradient_data = torch.zeros(
                (len(unique_gradient), 3, 1),
                device=ps_tensor.device,
            )
            # ... and then contracting the gradients is just one call
            nn_per_atom_forces = -torch.sum(
                ps_tensor_grad * nn_grads[0][gradient_map][:, None, :], -1
            )

            #why the index add here ?
            new_gradient_data.index_add_(
                0, gradient_map, nn_per_atom_forces[:, :, None]
            )
            forces = new_gradient_data.reshape(-1, 3)
        else:
            forces = None
        return energies, forces


In [20]:
class SpeciesWiseBPNN(torch.nn.Module):
    def __init__(self, layer_size=100):
        super().__init__()

        self.nn = None
        self.layer_size = layer_size
        # for now we only want to predict energies
        self.n_out = 1

    # build a combined 
    def initialize_model_weights(self, descriptor, energies, forces=None, seed=None):
        if seed is not None:
            torch.manual_seed(seed)

        X = descriptor.block().values
        z = torch.tensor(descriptor.block().samples["species_center"])

        species_unique = torch.unique(z).tolist()


        # initialize nn with zero weights ??
        def init_zero_weights(m):
            if isinstance(m, torch.nn.Linear):
                m.weight.data.fill_(0)
                m.bias.data.fill_(0)

        # 
        n_feat_descriptor = X.shape[-1]


        #MultiSpeciesMLP_skip: feat --> species wise NN, skipping evals --> atomic contributions out
        self.nn = MultiSpeciesMLP_skip(species_unique,n_feat_descriptor,self.n_out,self.layer_size)

    def forward(self, descriptor, with_forces=False):
        if self.nn is None:
            raise Exception("call initialize_weights first")

        ps_block = descriptor.block()
        ps_tensor = ps_block.values #is this a torch tensor? check

        # obtaining central species of batch
        ps_z = torch.tensor(ps_block.samples["species_center"])

        if with_forces:
            # TODO(guillaume): can this have unintended side effects???
            ps_tensor.requires_grad_(True)

        structure_map, new_samples, _ = StructureMap(
            ps_block.samples["structure"], ps_tensor.device
        )

        nn_per_atom = self.nn(ps_tensor, ps_z)
        
        #structure is actually atomic envs
        nn_per_structure = torch.zeros((len(new_samples), 1), device=ps_tensor.device)
        
        #adds atomic contributions per structure
        nn_per_structure.index_add_(0, structure_map, nn_per_atom)

        energies = nn_per_structure
        if with_forces:
            
            # computes dnn/dg for dnn/dg dg/dx
            nn_grads = torch.autograd.grad(
                nn_per_structure,
                ps_tensor,
                grad_outputs=torch.ones_like(nn_per_structure),
                create_graph=True,
                retain_graph=True,
            )


            ps_gradient = descriptor.block().gradient("positions")
            ps_tensor_grad = ps_gradient.data.reshape(-1, 3, ps_tensor.shape[-1])

            gradient_samples_Aj = np.asarray(
                ps_gradient.samples[["structure", "atom"]], dtype=tuple
            )

            #why is a unique gradient necessary
            unique_gradient, unique_gradient_idx = np.unique(
                gradient_samples_Aj, return_index=True
            )
            # new_gradient_samples = gradient_samples_Aj[np.sort(unique_gradient_idx)]

            # the logic is analogous to that for the structures: we have to map
            # positions in the full (A,i,j) vector to the position where they
            # will have to be accumulated
            gradient_replace_rule = dict(
                zip(unique_gradient, range(len(unique_gradient)))
            )


            gradient_map = torch.tensor(
                [gradient_replace_rule[i] for i in gradient_samples_Aj],
                dtype=torch.long,
                device=ps_tensor.device,
            )

            new_gradient_data = torch.zeros(
                (len(unique_gradient), 3, 1),
                device=ps_tensor.device,
            )
            # ... and then contracting the gradients is just one call
            nn_per_atom_forces = -torch.sum(
                ps_tensor_grad * nn_grads[0][gradient_map][:, None, :], -1
            )

            #why the index add here ?
            new_gradient_data.index_add_(
                0, gradient_map, nn_per_atom_forces[:, :, None]
            )
            forces = new_gradient_data.reshape(-1, 3)
        else:
            forces = None
        return energies, forces

Testing the z-wise implementation

In [13]:
import argparse
import json
import os
import sys
import time
from datetime import datetime

import ase.io
import numpy as np
import torch
import rascaline
from utils.combine import UnitCombineSpecies, CombineSpecies
from utils.dataset import AtomisticDataset, create_dataloader
from utils.model import AlchemicalModel, SoapBpnn, CombinedPowerSpectrum
from utils.soap import PowerSpectrum
from ase.io import read


torch.set_default_dtype(torch.float64)

def extract_energy_forces(frames):
    energies = (
        torch.tensor([frame.info["TotEnergy"] for frame in frames])
        .reshape(-1, 1)
        .to(dtype=torch.get_default_dtype())
    )

    forces = [
        torch.tensor(frame.arrays["force"]).to(dtype=torch.get_default_dtype())
        for frame in frames
    ]

    return energies, forces

frames = read("./data/water_converted.xyz",index=":2")
energies, forces = extract_energy_forces(frames)

hypers_rs = {
        "cutoff": 6.0,
        "max_angular": 0,
        "max_radial": 12,
        "atomic_gaussian_width": 0.25,
        "cutoff_function": {"ShiftedCosine": {"width": 0.5}},
        "radial_basis": {"SplinedGto": {"accuracy": 1e-6}},
        "center_atom_weight": 1.0,
        "radial_scaling":  {"Willatt2018": { "scale": 2.0, "rate": 0.8, "exponent": 2}}
    }
hypers_ps = {
    "cutoff": 4.0,
    "max_angular": 4,
    "max_radial": 8,
    "atomic_gaussian_width": 0.3,
    "cutoff_function": {"ShiftedCosine": {"width": 0.5}},
    "radial_basis": {"SplinedGto": {"accuracy": 1e-6}},
    "center_atom_weight": 1.0,
    "radial_scaling":  {"Willatt2018": { "scale": 2.0, "rate": 0.8, "exponent": 2}}
}

atomic_data = AtomisticDataset(frames,all_species=[1,8],hypers={"radial_spectrum": hypers_rs, "spherical_expansion": hypers_ps},energies=energies,forces=forces,do_gradients=True)
train_dataloader_no_batch = create_dataloader(
    atomic_data,
    batch_size=len(atomic_data),
    shuffle=False,
    device="cpu",
)
composition, radial_spectrum, spherical_expansions, energies, forces = next(
    iter(train_dataloader_no_batch)
)

combiner = UnitCombineSpecies([1,8],2)
soap_map = CombinedPowerSpectrum(combiner)

SOAPS = soap_map.forward(spherical_expansions)

Timing for key2prop 0.008439064025878906
Timing for key2prop 0.008617877960205078
Timing for key2prop 0.01029205322265625
Timing for key2prop 0.007021188735961914
Timing for key2prop 0.008301019668579102
Timing for key2prop 0.010734081268310547
Timing for key2prop 0.015706777572631836
Timing for key2prop 0.01627206802368164
Species mixing init with eigenvalues  [0.99898595 1.00101405]


In [21]:
model = SpeciesWiseBPNN(layer_size=10)

In [22]:
model.initialize_model_weights(SOAPS,energies)

In [23]:
model(SOAPS,with_forces=True)

(tensor([[31.6962],
         [31.7011]], grad_fn=<IndexAddBackward0>),
 tensor([[ 6.8181e-03, -1.6958e-04, -3.8749e-03],
         [-1.7751e-05,  2.9390e-03, -2.8651e-03],
         [ 8.0560e-03,  7.4388e-04, -1.0027e-03],
         ...,
         [ 6.6841e-03,  3.8505e-04, -4.9294e-03],
         [-1.7343e-03,  1.0702e-03, -3.0131e-03],
         [-3.4347e-03,  6.9243e-05,  3.1525e-03]],
        grad_fn=<ReshapeAliasBackward0>))

In [25]:
model_NN = NNModel(layer_size=10)
model_NN.initialize_model_weights(SOAPS,energies)

In [26]:
model_NN(SOAPS,with_forces=True)

(tensor([[-38.3957],
         [-38.4049]], grad_fn=<IndexAddBackward0>),
 tensor([[ 0.0050,  0.0004, -0.0024],
         [ 0.0019, -0.0004,  0.0014],
         [-0.0053, -0.0004, -0.0007],
         ...,
         [ 0.0056, -0.0024,  0.0003],
         [ 0.0021,  0.0029,  0.0082],
         [ 0.0019, -0.0037, -0.0083]], grad_fn=<ReshapeAliasBackward0>))