In [36]:
#  This implementation is an adaptation of the schnetpack implementation of atom wise-BPNNs 
#  https://schnetpack.readthedocs.io/en/stable/_modules/schnetpack/nn/blocks.html#GatedNetwork

# It uses an embedding matrix that calculates Nspecies*Nsamples NN atom-wise contributions
# If the embedding is one-hot for each species it multiplies the "off diagonal" NN contributions with zero
# ie NN_hydrogen(X_central_species_is_Oxygen) * 0
# but this embedding can also be learned
# I have added a second implementation where, if the one_hot_encoding is chosen, the other network evaluations are supposed to be skipped

# 

In [37]:
from typing import List
import numpy as np
import torch
from dataclasses import dataclass
from copy import copy


In [38]:
@dataclass
class EquistoreDummy:
    z: torch.Tensor
    val: torch.Tensor
    idx: torch.Tensor

In [39]:
def init_Embedding_one_hot(embedding):
    pass

def EmbeddingFactory(elements:List[int],one_hot:bool) -> torch.nn.Embedding:
    """Returns an Embedding of dim max_Z,n_unique_elements
    max_Z = 9, n_unique = 2, elements = [1,8]
    Embedding(tensor([8])) -> tensor([0.0,1.0]) (if one hot)
    """
    
    max_int = max(elements) + 1
    n_species = len(elements)

    embedding = torch.nn.Embedding(max_int,n_species)
    
    if one_hot:
        weights = torch.zeros(max_int, n_species)
        
        for idx, Z in enumerate(elements):
            weights[Z, idx] = 1.0

        embedding.weight.data = weights

    return embedding

In [159]:
class SimpleMLP(torch.nn.Module):
    """ A simple MLP 
    """

    # TODO: add n_hidden_layers, activation function option
    def __init__(self, dim_input: int, dim_output: int, layer_size: int) -> None:
        super().__init__()

        self.layer_size = layer_size
        self.dim_input = dim_input
        self.dim_output = dim_output

        self.nn = torch.nn.Sequential(
            torch.nn.Linear(self.dim_input, self.layer_size),
            torch.nn.Tanh(),
            torch.nn.Linear(self.layer_size, self.layer_size),
            torch.nn.Tanh(),
            torch.nn.Linear(self.layer_size, self.dim_output),
        )
    def forward(self,x: torch.tensor) -> torch.tensor:
        return self.nn(x)

class MultiMLP(torch.nn.Module):    
    """ A Multi MLP that contains N_species * SimpleMLPs
    """
    def __init__(self, dim_input: int, dim_output: int, layer_size: int, species: int) -> None:
        super().__init__()

        self.dim_output = dim_output
        self.dim_input = dim_input
        self.layer_size = layer_size
        self.species = species 
        self.n_species = len(self.species)
        self.species_nn = torch.nn.ModuleList([ SimpleMLP(dim_input,dim_output,layer_size) for _ in self.species])
    
    def forward(self, x: torch.tensor) -> torch.tensor:
        return torch.cat([nn(x) for nn in self.species_nn],dim=1)


class MultiMLP_skip(MultiMLP):
    """ A Multi MLP that contains N_species * SimpleMLPs
        This Implementation does only batchwise evaluation of neural networks?
        As this implementation skips 
    """
    
    def forward(self, x: torch.tensor, batch_z: torch.tensor) -> torch.tensor:
        #will this work with autograd?
        
        unique_z_in_batch = torch.unique(batch_z)
        
        model_out = torch.empty((x.shape[0],self.n_species))

        for n, (z, nn) in enumerate(zip(self.species, self.species_nn)):
            if z in unique_z_in_batch:
                model_out[batch_z == z, n] = nn(x[batch_z == z]).flatten()
                model_out[batch_z != z, n] = torch.zeros(x[batch_z != z].shape[0])
            else:
                model_out[:, n] = torch.zeros(x.shape[0])

        return model_out


class MultiSpeciesMLP(torch.nn.Module):
    
    """ Implements a MultiSpecies Behler Parinello neural network
    This implementation scales O(Nspecies*Natoms), but it has a learnable weight matrix, that combines species wise energies
    """

    def __init__(self, species, n_in, n_out, n_hidden, one_hot, embedding_trainable) -> None:
        
        super().__init__()
        
        #just a precaution
        species = copy(species)
        species.sort()

        #print(species)

        self.species = species
        self.nn = MultiMLP(n_in,n_out,n_hidden,species)
        self.embedding = EmbeddingFactory(species, one_hot)

        if not embedding_trainable:
            self.embedding.requires_grad_ = False
            

    def forward(self, descriptor: EquistoreDummy) -> torch.tensor:
        
        x = descriptor.val 
        z = descriptor.z # something like descriptor.

        #The embedding serves a a multiplicative "mask" -> not so nice overall complexity scales as O(N_species*N_samples)
        # whereas an implementation that could "skip" NN evaluations should only scale as O(N_samples)
        return torch.sum(self.nn(x) * self.embedding(z),dim=1)





class MultiSpeciesMLP_skip(torch.nn.Module):
    
    """ Implements a MultiSpecies Behler Parinello neural network
    This implementation should scale O(Natoms) as it skips the neural network evaluations that would be otherwise only multiplied with zeros
    """

    def __init__(self, species, n_in, n_out, n_hidden) -> None:
        
        super().__init__()

        #just a precaution
        species = copy(species)
        species.sort()

        #print(species)
        
        #TODO: Implement this properly in the MultiSpeciesMLP class
        self.n_out = n_out

        self.species = species
        self.nn = MultiMLP_skip(n_in,n_out,n_hidden,species)
        self.embedding = EmbeddingFactory(species, True)
        self.embedding.requires_grad_ = False
            

    def forward(self, descriptor: EquistoreDummy) -> torch.tensor:
        
        x = descriptor.val  # something like descriptor.
        z = descriptor.z 

        # here the embedding multiplication should only introduce a minor overhead
        return torch.sum(self.nn(x,z) * self.embedding(z),dim=1)

class MultiSpeciesMLP_skip_w_index_add(MultiSpeciesMLP_skip):
    
    """ For testing purposes I have added the atomic-contributions to structure wise properties addition
    """
    
    def forward(self, descriptor: EquistoreDummy) -> torch.tensor:
                
        x = descriptor.val  
        z = descriptor.z 
        idx = descriptor.idx

        x.requires_grad_(True)
        num_structures = len(torch.unique(idx))
        
        structure_wise_properties = torch.zeros((num_structures,self.n_out))
        atomic_contributions = torch.sum(self.nn(x,z) * self.embedding(z),dim=1,keepdim=True)

        print(atomic_contributions.flatten())
        
        structure_wise_properties.index_add_(0,idx,atomic_contributions)

        nn_grads = torch.autograd.grad(
                structure_wise_properties,
                x,
                grad_outputs=torch.ones_like(structure_wise_properties),
                create_graph=True,
                retain_graph=True,
            )

        return structure_wise_properties, nn_grads
        

In [163]:
a_block = EquistoreDummy(torch.tensor([1,1,8,1,7]),torch.ones(5,4),torch.tensor([0,0,0,1,1]))
#multimlp = MultiSpeciesMLP_skip([1,8,7],3,1,10)

In [166]:
multimlp_index_add = MultiSpeciesMLP_skip_w_index_add([1,8,7],4,1,10)

In [167]:
multimlp_index_add.forward(a_block)

tensor([-0.0752, -0.0752, -0.1002, -0.0752,  0.1738],
       grad_fn=<ReshapeAliasBackward0>)


(tensor([[-0.2505],
         [ 0.0986]], grad_fn=<IndexAddBackward0>),
 (tensor([[-0.0495,  0.0438,  0.0283,  0.0377],
          [-0.0495,  0.0438,  0.0283,  0.0377],
          [ 0.0351,  0.0911, -0.0211, -0.0484],
          [-0.0495,  0.0438,  0.0283,  0.0377],
          [ 0.0325,  0.0733, -0.1261,  0.0340]], grad_fn=<AddBackward0>),))

In [168]:
sum([-0.0752, -0.0752, -0.1002])

-0.2506

In [169]:
sum([-0.0752,  0.1738])

0.09860000000000001