# Tutorial of New Architecture



## Imports and Settings

In [55]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from shutil import rmtree
import os

import schnetpack as spk

In [56]:
# --- filesystem locations ---
model_dir = "modeldir"  # output directory name (NOTE(review): confirm which cells read it)
db_path = "data/schnetpack/qm9.db"  # QM9 database file handed to spk.datasets.QM9

# --- architecture hyperparameters ---
n_atom_basis = 128  # feature size shared by representations and output modules
n_filters = 128  # filter size given to the SchNet interaction blocks
n_interactions = 6  # number of interaction blocks per representation
n_gaussians = 25  # size of the distance expansion
cutoff = 10.0  # cutoff given to distance expansion and interaction blocks

# --- training hyperparameters ---
batch_size = 100  # mini-batch size
n_train = 500  # number of training samples requested from the split
n_val = 200  # number of validation samples requested from the split
n_workers = 4  # worker count for the training data loader
lr = 0.001  # learning rate for the Adam optimiser
prop = spk.datasets.QM9.U0  # key of the target property


## Prepare Data


In [57]:
# load qm9 dataset and download if necessary
data = spk.datasets.QM9(db_path)

# Split into training, validation and test subsets.
# NOTE: the name `train` is rebound by the `train(model)` helper defined in the
# Training section; re-run this cell before reusing the split afterwards.
train, val, test = spk.data.train_test_split(data, n_train, n_val)

# Use the configured `batch_size` instead of a repeated literal, so that editing
# the settings cell actually changes the training batches.
train_loader = spk.data.AtomsLoader(
    train, batch_size=batch_size, num_workers=n_workers
)
# No batch size is passed here, so the loader's own default applies.
val_loader = spk.data.AtomsLoader(val)


## Build Architectures


### Classic SchNet

This section builds the classic SchNet model, as it has been used until now. From
a user's point of view, nothing has changed.


In [58]:
# SchNet representation, sized by the architecture settings
schnet = spk.SchNet(
    n_atom_basis=n_atom_basis, n_interactions=n_interactions, n_gaussians=n_gaussians
)

# atom-wise output head predicting the target property
atomwise_output = spk.atomistic.Atomwise(property=prop, n_in=n_atom_basis)

# full model: representation followed by the single output head
classic_schnet = spk.atomistic.AtomisticModel(
    output_modules=[atomwise_output],
    representation=schnet,
)


### PhysNet

This section builds a PhysNet model, analogous to the classic SchNet-style models. The
usage is generally the same, but PhysNet requires some different arguments.


In [59]:
# representation
physnet = spk.PhysNet(
    n_atom_basis=n_atom_basis,
    n_interactions=n_interactions,
    n_gaussians=n_gaussians,
)

# output modules
# The electrostatic correction uses the configured cutoff (same value as the
# former literal 10.).
corrections = [spk.atomistic.ElectrostaticEnergy(cuton=0.0, cutoff=cutoff)]
# Hand the correction list to the output module; it was previously built but
# never passed on, so the correction had no effect on this model.
corr_atomwise_output = spk.atomistic.AtomwiseCorrected(
    n_in=n_atom_basis, corrections=corrections, property=prop
)

# final model
classic_physnet = spk.atomistic.AtomisticModel(
    representation=physnet, output_modules=[corr_atomwise_output]
)


### AtomisticRepresentation: SchNet

Since SchNet and PhysNet generally use a similar architecture, a new parent class 
`AtomisticRepresentation` is introduced. `SchNet` and `PhysNet` both inherit from 
this base representation class. Although the subclasses make the construction of 
representations very easy, it can be useful to use the 
`AtomisticRepresentation` class directly. To do so, the individual building blocks of a 
representation need to be constructed first. This requires an embedding, a 
distance expansion, the interaction blocks, post-interaction layers (if needed) and 
the aggregation mode, which transforms the intermediate interaction outputs into the 
representation features. With these blocks, one can define a SchNet 
representation, a PhysNet representation or a mixture of both.
 

In [60]:
# --- building blocks ---
# atom-type embedding table; index 0 is reserved as padding
s_embedding = nn.Embedding(
    num_embeddings=87, embedding_dim=n_atom_basis, padding_idx=0
)
# distance expansion built from 0.0, the cutoff and the number of Gaussians
s_distance_expansion = spk.nn.GaussianSmearing(0.0, cutoff, n_gaussians)
# one SchNet interaction block per interaction step
s_interactions = nn.ModuleList(
    spk.representation.SchNetInteraction(
        n_atom_basis=n_atom_basis,
        n_spatial_basis=n_gaussians,
        n_filters=n_filters,
        cutoff=cutoff,
    )
    for _ in range(n_interactions)
)
# aggregation in "last" mode
s_interaction_aggregation = spk.representation.InteractionAggregation(mode="last")

# --- representation assembled from the blocks above ---
ar_schnet = spk.AtomisticRepresentation(
    embedding=s_embedding,
    distance_expansion=s_distance_expansion,
    interactions=s_interactions,
    interaction_aggregation=s_interaction_aggregation,
    return_intermediate=False,
    return_distances=False,
    sum_before_interaction_append=True,
)

# --- atom-wise output head for the target property ---
ar_atomwise_output = spk.atomistic.Atomwise(property=prop, n_in=n_atom_basis)

# --- complete model ---
ar_schnet_model = spk.atomistic.AtomisticModel(
    output_modules=[ar_atomwise_output],
    representation=ar_schnet,
)


### AtomisticRepresentation: PhysNet


In [61]:
# building blocks
p_embedding = spk.nn.Embedding(n_features=n_atom_basis)
p_distance_expansion = spk.nn.GaussianSmearing(0.0, cutoff, n_gaussians)
# one PhysNet interaction block per interaction step
p_interactions = nn.ModuleList(
    [
        spk.representation.PhysNetInteraction(
            n_features=n_atom_basis,
            n_gaussians=n_gaussians,
            activation=spk.nn.Swish,
        )
        for _ in range(n_interactions)
    ]
)
# one post-interaction layer (residual stack followed by a dense layer) per
# interaction block
p_post_interactions = nn.ModuleList(
    [
        nn.Sequential(
            spk.nn.ResidualStack(
                n_features=n_atom_basis,
                n_blocks=1,
                activation=spk.nn.Swish,
            ),
            spk.nn.Dense(
                in_features=n_atom_basis,
                out_features=n_atom_basis,
                pre_activation=spk.nn.Swish,
            ),
        )
        for _ in range(n_interactions)
    ]
)
p_interaction_aggregation = spk.representation.InteractionAggregation(mode="sum")

# build representation
# The post-interaction layers are now passed to the representation; previously
# they were constructed but never used.
ar_physnet = spk.AtomisticRepresentation(
    embedding=p_embedding,
    distance_expansion=p_distance_expansion,
    interactions=p_interactions,
    post_interactions=p_post_interactions,
    interaction_aggregation=p_interaction_aggregation,
    return_intermediate=False,
    return_distances=True,
    sum_before_interaction_append=False,
)

# build output module
ar_corr_atomwise_output = spk.atomistic.AtomwiseCorrected(
    n_in=n_atom_basis, property=prop
)

# final model
ar_physnet_model = spk.atomistic.AtomisticModel(
    representation=ar_physnet,
    output_modules=[ar_corr_atomwise_output],
)


### AtomisticRepresentation: Mixture Model

The mixture model illustrates how the `AtomisticRepresentation` is able to combine 
all kinds of building blocks with each other. The embedding must be a layer that 
transforms the atom type into a tensor with the feature dimension. 
The distance expansion can be any layer which transforms the interatomic distances into 
basis functions. The new representation is able to use `SchNetInteraction` and 
`PhysNetInteraction` blocks. This example shows that even a mixture of both could 
theoretically be used. The representation loops through all interaction layers and 
collects the intermediate results of the individual layers. Since some models (e.g. 
PhysNet) use an additional output layer that is applied to the interaction tensor 
without contributing to the residual sum (x + v), the `post_interactions` layer 
is applied to the interaction output v after the residual sum and before appending to
the intermediate interactions list:
    
    x = x + v
    intermediate_interactions.append(post_interaction(v))
        
If post-interaction layers are not needed, `post_interactions=None` ignores them. The 
evaluation of the interaction blocks yields a list with `n_interactions` 
interaction tensors. Depending on the model, these interactions are aggregated to an 
atomistic feature representation by the `interaction_aggregation`. While
SchNet only uses the last interaction tensor as a feature representation, PhysNet 
uses a sum aggregation over the intermediate interactions. Independent of the 
selected aggregation type, the `AtomisticRepresentation` returns a feature 
representation with the dimensions `[n_batch, n_atoms, n_atom_basis]`. This allows 
the use of all output modules that are derived from the `Atomwise` class. Although 
there is no direct need for the output modules to receive the intermediate 
interactions anymore, since they are already aggregated in the selected manner, it 
is still possible to return the intermediate interactions. Furthermore, the distance 
calculations can be returned, because they may be required by some output modules 
(e.g. the new `AtomwiseCorrected` output module).


In [62]:
# building blocks
m_embedding = spk.nn.Embedding(n_features=n_atom_basis)
m_distance_expansion = spk.nn.GaussianSmearing(0.0, cutoff, n_gaussians)
# one PhysNet block followed by one SchNet block
m_interactions = nn.ModuleList(
    [
        spk.representation.PhysNetInteraction(
            n_features=n_atom_basis,
            n_gaussians=n_gaussians,
            activation=spk.nn.Swish,
        ),
        spk.representation.SchNetInteraction(
            n_atom_basis=n_atom_basis,
            n_spatial_basis=n_gaussians,
            n_filters=n_filters,
            cutoff=cutoff,
        ),
    ]
)
# one dense post-interaction layer per interaction block
m_post_interactions = nn.ModuleList(
    [
        spk.nn.Dense(in_features=n_atom_basis, out_features=n_atom_basis)
        for _ in m_interactions
    ]
)
m_interaction_aggregation = spk.representation.InteractionAggregation(mode="sum")

# build representation
m_atomistic_representation = spk.AtomisticRepresentation(
    embedding=m_embedding,
    distance_expansion=m_distance_expansion,
    interactions=m_interactions,
    post_interactions=m_post_interactions,
    interaction_aggregation=m_interaction_aggregation,
    return_intermediate=False,
    return_distances=True,
    sum_before_interaction_append=False,
)

# output modules
# NOTE(review): both output modules are created with the same `property`;
# confirm that their results do not overwrite each other in the model output.
m_atomwise = spk.atomistic.Atomwise(n_in=n_atom_basis, property=prop)
m_corrections = [spk.atomistic.ElectrostaticEnergy(cuton=0.0, cutoff=cutoff)]
# Use this cell's own correction list rather than the `corrections` variable
# left over from the PhysNet cell, so the cell does not depend on hidden state.
m_corr_atomwise = spk.atomistic.AtomwiseCorrected(
    n_in=n_atom_basis, corrections=m_corrections, property=prop
)

# final model
m_atomistic_model = spk.atomistic.AtomisticModel(
    representation=m_atomistic_representation,
    output_modules=[m_atomwise, m_corr_atomwise],
)



## Training


In [63]:
def train(model, n_epochs=1, device="cpu"):
    """Train ``model`` on the notebook's data loaders with an MSE loss on ``prop``.

    An existing model directory is removed first, so every call starts from an
    empty directory. The optimiser is Adam with the learning rate ``lr``.

    Note: defining this function rebinds the global name ``train``, which the
    data-preparation cell also uses for the training split.

    Args:
        model: module whose parameters are optimised and which is handed to
            ``spk.train.Trainer``.
        n_epochs (int): number of epochs passed to the trainer. Defaults to 1.
        device (str or torch.device): device passed to the trainer.
            Defaults to ``"cpu"``.
    """
    # create trainer
    print("setting up trainer...")
    # Reuse the configured directory instead of a duplicated string literal.
    modeldir = model_dir
    if os.path.exists(modeldir):
        rmtree(modeldir)

    def loss(batch, result):
        # mean squared error between predicted and reference target property
        return F.mse_loss(result[prop], batch[prop])

    opt = Adam(model.parameters(), lr=lr)
    trainer = spk.train.Trainer(modeldir, model, loss, opt, train_loader, val_loader)

    # start training
    print("training...")
    trainer.train(torch.device(device), n_epochs=n_epochs)


### Train Classic SchNet


In [64]:
# run the shared training helper on the classic SchNet model
train(classic_schnet)


setting up trainer...
training...
tensor(1.2431e+08, grad_fn=<MseLossBackward>)
tensor(1.2511e+08, grad_fn=<MseLossBackward>)
tensor(1.2737e+08, grad_fn=<MseLossBackward>)
tensor(1.2721e+08, grad_fn=<MseLossBackward>)
tensor(1.2621e+08, grad_fn=<MseLossBackward>)


### Train PhysNet


In [65]:
# run the shared training helper on the classic PhysNet model
train(classic_physnet)


setting up trainer...
training...
tensor(1.2450e+08, grad_fn=<MseLossBackward>)
tensor(1.2533e+08, grad_fn=<MseLossBackward>)
tensor(1.2758e+08, grad_fn=<MseLossBackward>)
tensor(1.2744e+08, grad_fn=<MseLossBackward>)
tensor(1.2645e+08, grad_fn=<MseLossBackward>)


### Train AtomisticRepresentation SchNet


In [66]:
# run the shared training helper on the SchNet built from AtomisticRepresentation
train(ar_schnet_model)


setting up trainer...
training...
tensor(1.2449e+08, grad_fn=<MseLossBackward>)
tensor(1.2521e+08, grad_fn=<MseLossBackward>)
tensor(1.2736e+08, grad_fn=<MseLossBackward>)
tensor(1.2711e+08, grad_fn=<MseLossBackward>)
tensor(1.2600e+08, grad_fn=<MseLossBackward>)


### Train AtomisticRepresentation PhysNet


In [67]:
# run the shared training helper on the PhysNet built from AtomisticRepresentation
train(ar_physnet_model)


setting up trainer...
training...
tensor(1.2450e+08, grad_fn=<MseLossBackward>)
tensor(1.2520e+08, grad_fn=<MseLossBackward>)
tensor(1.2733e+08, grad_fn=<MseLossBackward>)
tensor(1.2708e+08, grad_fn=<MseLossBackward>)
tensor(1.2597e+08, grad_fn=<MseLossBackward>)


### Train AtomisticRepresentation Mixed Model


In [68]:
# run the shared training helper on the mixed PhysNet/SchNet model
train(m_atomistic_model)



setting up trainer...
training...
tensor(1.2450e+08, grad_fn=<MseLossBackward>)
tensor(1.2538e+08, grad_fn=<MseLossBackward>)
tensor(1.2770e+08, grad_fn=<MseLossBackward>)
tensor(1.2761e+08, grad_fn=<MseLossBackward>)
tensor(1.2669e+08, grad_fn=<MseLossBackward>)
