In [None]:
%%capture install
try:
  import imlms
  print('Already installed')
except ImportError:
  %pip install git+https://github.com/Mads-PeterVC/imlms

In [None]:
print(install.stdout.splitlines()[-1])

## Pytorch Basics

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from typing import Dict

from matscipy.neighbours import neighbour_list
from ase.data.colors import jmol_colors

from torch_geometric.nn import MessagePassing

#### Warning: This tutorial is probably less pedagogical and more difficult than the others. There is no code you need to write, though, so you can simply run everything and explore at your own pace.

## Graphs in Pytorch

Now that we have a feeling for the basic operations of `torch`, we are 
ready to look at more advanced network types. 

The state-of-the-art neural networks for fitting potentials (and for other tasks in materials science) are graph neural networks.

It is therefore worth understanding how to work with them. 

Remember that a graph consists of
* $V$: Vertices (or nodes).
* $E$: Edges. 

And optionally
* Vertex features: Such as the atomic number or an embedding thereof. 
* Edge features: Such as the distance between the two connected vertices. 

It can be quite helpful to store these properties in an appropriate data structure, such as the graph structure provided by `torch_geometric`. 

Let's make a simple graph using `torch_geometric.data.Data`.

In [None]:
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long) # Edges 

x = torch.tensor([[-1], [0], [1]], dtype=torch.float) # Vertex features.

data = Data(x=x, edge_index=edge_index)

This graph looks like this

![graph_example](https://pytorch-geometric.readthedocs.io/en/latest/_images/graph.svg)
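
To make the `edge_index` convention concrete, here is a minimal plain-Python sketch (no `torch` needed) that turns the same two-row edge list into a list of incoming neighbours per node. The `neighbours` dictionary is only an illustration and is not something `torch_geometric` provides.

```python
# Same edges as in the cell above, written as plain Python lists.
edge_index = [[0, 1, 1, 2],   # source node of each directed edge
              [1, 0, 2, 1]]   # target node of each directed edge

num_nodes = 3
neighbours = {i: [] for i in range(num_nodes)}

# Each column of edge_index is one directed edge: source -> target.
for source, target in zip(*edge_index):
    neighbours[target].append(source)

print(neighbours)  # {0: [1], 1: [0, 2], 2: [1]}
```

An undirected bond shows up as two columns, one per direction, which is why this three-node chain has four edges.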

The cell below defines two functions, one that creates a random graph and one that plots the graph. 

The random graph has one node attribute: an integer between 0 and 9.

In [None]:
colors = plt.get_cmap('coolwarm')
norm = plt.Normalize(0, 9)

# Create a random graph: 
def random_graph(num_nodes, cutoff=2.5, box_size=10):

    positions = []
    for i in range(num_nodes):

        new_position = torch.rand(1, 2) * box_size

        if len(positions) > 0:
            all_positions = torch.vstack(positions)
            while torch.any(torch.cdist(all_positions, new_position) < 0.75):
                new_position = torch.rand(1, 2) * box_size
        
        positions.append(new_position)

    positions = torch.vstack(positions)
    D = torch.cdist(positions, positions)

    edge_index = []
    for i in range(D.shape[0]):
        for j in range(D.shape[1]):
            if D[i, j] < cutoff:
                edge_index.append([i, j])

    edge_index = torch.tensor(np.array(edge_index).T, dtype=torch.int64).reshape(2, -1)

    x = torch.randint(0, 10, (num_nodes, 1))

    return Data(edge_index=edge_index, pos=positions, x=x)

def plot_graph(ax, graph):
    positions = graph.pos.detach().numpy()
    numbers = graph.x.detach().numpy()

    # Circle outline (radius r), reused to draw every node:
    r = 0.35
    theta = np.linspace(0, 2*np.pi, 100)
    x = r * np.cos(theta)
    y = r * np.sin(theta)

    # Plot the edges:
    for idx, edge in enumerate(graph.edge_index.T):
        source = positions[edge[0]]
        target = positions[edge[1]]

        # Plot the target node:
        ax.plot(target[0]+x, target[1]+y, c='black')
        ax.fill_between(target[0]+x, target[1]+y, target[1], color=colors(norm(numbers[edge[1]])))

        ax.plot([source[0], target[0]], [source[1], target[1]], c='black', zorder=0)

    ax.axis('equal')

### Message Passing

Now we will build a simple message-passing layer to illustrate how it works. In general 
a message-passing layer can be defined by

$$
x_i^{k+1} = \phi \left( x^k_i, \bigoplus_{j \in N_i} \psi(x^k_i, x^k_j, e_{ji}) \right)
$$

This equation is intimidating at first encounter, so let's break it down:

- $x_i^k$ : The features of node $i$ before the update.
- $x_i^{k+1}$ : The features of node $i$ after the update.
- $\phi$, $\psi$ : Differentiable functions, typically represented by neural networks.
- $\bigoplus$ : A *permutation invariant* aggregation operation. Aggregation means an operation such 
as a sum, an average or a maximum, and permutation invariant means that the result does not depend on the order in which the elements are processed. 
- $N_i$ : The set of neighbours of node $i$.
- $e_{ji}$ : The edge features of the directed edge from $j$ to $i$.
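
The general update rule can be sketched in a few lines of plain Python for scalar node features. The function name `message_passing_step` and the particular choices of `psi`, `phi` and `aggregate` below are assumptions made for illustration; they are not the layer built later in this notebook.

```python
def message_passing_step(x, edges, psi, aggregate, phi, edge_features=None):
    """One update x^k -> x^{k+1} for scalar node features.

    x         : list with x[i] = x_i^k
    edges     : list of directed edges (j, i), meaning j sends to i
    psi       : message function psi(x_i, x_j, e_ji)
    aggregate : permutation-invariant reduction over a list of messages
    phi       : update function phi(x_i, aggregated_message)
    """
    inbox = [[] for _ in x]
    for n, (j, i) in enumerate(edges):
        e_ji = None if edge_features is None else edge_features[n]
        inbox[i].append(psi(x[i], x[j], e_ji))
    # In this sketch a node without neighbours keeps its old value.
    return [phi(x[i], aggregate(msgs)) if msgs else x[i]
            for i, msgs in enumerate(inbox)]


def psi(x_i, x_j, e_ji):
    return x_j            # the message is the sender's feature


def phi(x_i, m):
    return x_i + m        # add the aggregated message to the old feature


x = [1.0, 2.0, 4.0]
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]

new_x = message_passing_step(x, edges, psi=psi, aggregate=sum, phi=phi)
print(new_x)  # [3.0, 7.0, 6.0]

# Permutation invariance: processing the edges in another order changes nothing.
shuffled = message_passing_step(x, edges[::-1], psi=psi, aggregate=sum, phi=phi)
print(shuffled == new_x)  # True
```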

As this is our first application of message passing we will simplify a bit. Our first 
simplification is to choose 

$$\phi(q, p) = p$$
which means that the outer function $\phi$ just picks its second input.
Additionally we choose
$$
\psi(q, p, z) = p,
$$
that is, $\psi$ ignores $q$ and $z$ and returns $p$. With these choices 
we can restate the equation from above as

$$
x_i^{k+1} = \bigoplus_{j \in N_i} x^k_j
$$
Now we choose $\mathrm{max}$ as our aggregation operation $\bigoplus$:

$$
x_i^{k+1} = \max_{j \in N_i} \ x^k_j
$$

Now we define a message-passing layer using `MessagePassing` from `torch_geometric`. 
This layer updates each node attribute to be the maximum of its neighbours' attributes.

In [None]:
class MaxMessage(MessagePassing):

    def __init__(self):
        super(MaxMessage, self).__init__(aggr='max')

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        return x_j

message_passing = MaxMessage()

sz = 4
graph = random_graph(100, cutoff=1.5, box_size=15)

fig, axes = plt.subplots(1, 4, figsize=(4*sz, 1*sz))

for i, ax in enumerate(axes.flatten()):

    if i != 0:
        new_x = message_passing(graph.x, graph.edge_index)
        graph.x = new_x

    plot_graph(ax, graph)
    ax.set_xticks([])
    ax.set_yticks([])

    ax.set_title(f'{i} Message Passing Steps')



In these plots a dark red color corresponds to a high node feature value and a dark blue color to a low one.
As more propagation steps are taken, each node becomes as red as the reddest node in its connected subgraph.
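
The same behaviour can be reproduced on a tiny hand-made graph in plain Python. This is a sketch under assumed data: a chain of four nodes plus a separate pair, all with self-loops as in the random graphs above. The helper `max_step` is hypothetical and only mimics what max aggregation does.

```python
def max_step(x, edges):
    """Replace each x_i by the max of x_j over all senders j (edges are (j, i))."""
    incoming = [[] for _ in x]
    for j, i in edges:
        incoming[i].append(x[j])
    return [max(vals) if vals else x[i] for i, vals in enumerate(incoming)]


# Two disconnected pieces: the chain 0-1-2-3 and the pair 4-5.
bonds = [(0, 1), (1, 2), (2, 3), (4, 5)]
n_nodes = 6
edges = (bonds
         + [(b, a) for a, b in bonds]          # both directions
         + [(i, i) for i in range(n_nodes)])   # self-loops

history = [[9, 0, 1, 2, 3, 5]]
for _ in range(3):
    history.append(max_step(history[-1], edges))

for step, values in enumerate(history):
    print(step, values)
# 0 [9, 0, 1, 2, 3, 5]
# 1 [9, 9, 2, 2, 5, 5]
# 2 [9, 9, 9, 2, 5, 5]
# 3 [9, 9, 9, 9, 5, 5]
```

The value 9 needs three steps to travel along the chain, and it never reaches nodes 4 and 5 because no edge connects the two pieces.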

#### Exercises:

1. Change the message or the aggregation method (`aggr` in the `__init__` call) in some way, e.g. try minimum or mean aggregation.

2. The constructed graph has self-interactions, i.e. each node sees and sends messages to itself. Try turning that off by changing how the `edge_index` tensor is constructed. 

### Atomic Graphs

We would like to work with atomic structures in this format, so we need to define which atoms are connected by an edge and decide on the features.

In [None]:
class AtomsGraph(Data):

    @classmethod
    def from_atoms(cls, atoms, cutoff=5.0, dtype=torch.float, energy=None, forces=None):

        # Build the neighbour list:
        i, j, S = neighbour_list('ijS', atoms=atoms, cutoff=cutoff)

        # Edges: Defines which atoms are connected, first row is the sender, 
        # second row is the receiver (as the graph is directed).
        ij = np.array([i, j])
        edge_index = torch.tensor(ij, dtype=torch.long)

        # Nodes: The initial node features are just the atomic numbers.
        node_feat = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long).reshape(-1)

        # Shift vectors:
        shifts = torch.tensor(S, dtype=dtype) # These are integer shifts in the unit cell.
        cell = torch.tensor(np.array(atoms.get_cell()), dtype=dtype)
        shift_vectors = torch.einsum('ij,jk->ik', shifts, cell) # Convert the shifts to vectors in Å.

        # Positions: These require gradients because we will want to compute forces.
        positions = torch.tensor(atoms.get_positions(), dtype=dtype, requires_grad=True)

        # Target properties: Might be neater to keep this on the data loader, but I haven't 
        # gotten that to work well yet.
        if energy is not None:
            energy = torch.tensor(energy, dtype=dtype).reshape(1, 1)
        if forces is not None:
            forces = torch.tensor(forces, dtype=dtype)

        return cls(pos=positions,
                   edge_index=edge_index, 
                   x=node_feat,
                   shift_vectors=shift_vectors, 
                   cell=cell, 
                   energy=energy,
                   forces=forces)

    def plot(self, ax, node_radius=0.4, plot_cell=True):

        positions = self.pos.detach().numpy()
        numbers = self.x.detach().numpy()

        # Circle outline (radius r), reused to draw every node:
        r = node_radius
        theta = np.linspace(0, 2*np.pi, 100)
        x = r * np.cos(theta)
        y = r * np.sin(theta)

        # Plot the edges:
        for idx, edge in enumerate(self.edge_index.T):

            shift = self.shift_vectors[idx].detach().numpy()
            if shift.any():
                linestyle = '--'
                marker = 'x'
            else:
                linestyle = '-'
                marker=None


            source = positions[edge[0]]
            target = positions[edge[1]] + shift

            # Plot the target node:
            ax.plot(target[0]+x, target[1]+y, c='black')
            ax.plot(target[0], target[1], marker=marker, c='red', zorder=2)
            ax.fill_between(target[0]+x, target[1]+y, target[1], color=jmol_colors[numbers[edge[1]]])

            ax.plot([source[0], target[0]], [source[1], target[1]], c='black', linestyle=linestyle, zorder=0)

        if plot_cell:
            cell = self.cell.detach().numpy()
            ax.plot([0, cell[0, 0]], [0, cell[0, 1]], c='black')
            ax.plot([cell[0, 0], cell[0, 0]+cell[1, 0]], [cell[0, 1], cell[0, 1]+cell[1, 1]], c='black')
            ax.plot([0, cell[1, 0]], [0, cell[1, 1]], c='black')
            ax.plot([cell[1, 0], cell[0, 0]+cell[1, 0]], [cell[1, 1], cell[0, 1]+cell[1, 1]], c='black')

        ax.axis('equal')
        

In [None]:
from ase.build import molecule
from ase.build import mx2

atoms = molecule('C6H6')
atoms.set_cell([[10, 0, 0], [0, 10, 0], [0, 0, 10]])
atoms.center()
molecule_graph = AtomsGraph.from_atoms(atoms, cutoff=2.0) # You can play with the cutoff to get cool graphs.

mos2 = mx2(formula='MoS2', kind='2H', a=3.18, thickness=3.19, size=(1, 1, 1), vacuum=10.0)
mos2 = mos2.repeat([3, 3, 1])
mos2_graph = AtomsGraph.from_atoms(mos2, cutoff=5.0)


fig, axes = plt.subplots(1, 2, figsize=(10, 5))
molecule_graph.plot(axes[0])
mos2_graph.plot(axes[1])



Here the nodes marked with a $\times$ are periodic images.

Note that we only keep track of node features for the atoms inside the cell, but these atoms can receive messages 
from periodic images. 
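
A small numeric sketch shows why the shift vectors matter. The positions, the cell and the integer shift `S` below are made-up values, not taken from the structures plotted above; the conversion from integer shift to a vector in Å mirrors the `einsum` in `from_atoms`.

```python
import math

# Cubic 10 Å cell with two atoms close to opposite faces along x.
cell = [[10.0, 0.0, 0.0],
        [0.0, 10.0, 0.0],
        [0.0, 0.0, 10.0]]
pos_i = [0.5, 5.0, 5.0]
pos_j = [9.5, 5.0, 5.0]

# Assumed integer shift for the edge i -> j: use the image of j one cell to the left.
S = [-1, 0, 0]

# Integer shift -> shift vector in Å (sum over cell vectors, as in 'ij,jk->ik').
shift_vector = [sum(S[a] * cell[a][k] for a in range(3)) for k in range(3)]

# Distance without periodic images.
direct = math.dist(pos_i, pos_j)

# Edge vector including the shift: pos_j - pos_i + shift_vector.
edge_vector = [pos_j[k] - pos_i[k] + shift_vector[k] for k in range(3)]
periodic = math.hypot(*edge_vector)

print(direct)    # 9.0
print(periodic)  # 1.0
```

Without the shift these two atoms would look 9 Å apart and fall outside a 5 Å cutoff, even though atom `i` has an image of `j` only 1 Å away.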

## SchNet in Pytorch

With the `AtomsGraph`-class defined we can move on to building a message-passing neural 
network. 

### SchNet

Now we're ready to implement SchNet, so let's recap what happens in SchNet:

![SchNet](https://www.researchgate.net/profile/Kristof-Schuett/publication/317954658/figure/fig4/AS:530501098524672@1503492726745/Illustration-of-SchNet-with-an-architectural-overview-left-the-interaction-block.png)

The SchNet architecture is shown in the first column of the figure. 
A few things to note: 
* The first layer is an `embedding` layer, which takes an atomic number and represents it as a vector (64 elements in the figure, 32 by default in the code below). 
* The embedding is followed by the interaction (message-passing) layers.
* `Atomwise` just means a fully connected layer applied to each atom separately.
* Finally the local energies are calculated and summed to yield the total energy.

The interaction block consists of two branches (middle column), the left branch is just a skip connection and 
all the interesting things happen in the right-hand branch.
1. The node features are updated with a fully connected layer. 
2. Messages from other nodes are computed using the `cfconv` layer. 
3. This is followed by more linear layers and a `softplus` activation function.

The `cfconv` layer (rightmost column) is where message passing is performed. 
1. The distance between nodes `i` and `j` is expanded in a basis of Gaussians in the `rbf` layer. 
2. A filter is formed by passing this vector through a series of linear layers.
3. The filter is multiplied elementwise with the node features to produce the message.

We will start by defining the embedding layer

In [None]:
class NodeEmbedding(torch.nn.Module):

    def __init__(self, max_z=100, hidden_dim=32):
        super(NodeEmbedding, self).__init__()
        self.n_atom_types = max_z
        self.linear = torch.nn.Linear(self.n_atom_types, hidden_dim, bias=False)

    def forward(self, numbers):
        x = torch.nn.functional.one_hot(numbers-1, num_classes=self.n_atom_types).float()
        return self.linear(x)

This could be simplified by using `torch.nn.Embedding`, but that would hide what is going on. 
The embedding first transforms the atomic number into a one-hot vector. Passing this vector through the linear layer picks out the 
vector of weights assigned to that element. 

Because of the linear layer this becomes a learnable embedding. 
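
The "picking out" can be checked by hand in plain Python. This is a sketch with a made-up weight table `W` for four element types and three hidden features. Here `W` is stored with one row per element; `torch.nn.Linear` stores its weight with shape `(out_features, in_features)`, so there the same vector is a column.

```python
n_types, hidden_dim = 4, 3

# Hypothetical weights: row z-1 holds the embedding of atomic number z.
W = [[0.1, 0.2, 0.3],
     [1.0, 1.1, 1.2],
     [2.0, 2.1, 2.2],
     [3.0, 3.1, 3.2]]


def one_hot(index, n):
    return [1.0 if k == index else 0.0 for k in range(n)]


def embed(z):
    # Bias-free linear layer applied to the one-hot vector of z - 1.
    v = one_hot(z - 1, n_types)
    return [sum(v[a] * W[a][k] for a in range(n_types)) for k in range(hidden_dim)]


print(one_hot(1, n_types))  # [0.0, 1.0, 0.0, 0.0]
print(embed(2))             # [1.0, 1.1, 1.2]
print(embed(2) == W[1])     # True
```

Since only one row of `W` contributes for a given element, a gradient step only changes the rows of elements that actually occur in the batch.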

Next let's look at the distance expansion part.

In [None]:
class DistanceEmbedding(torch.nn.Module):

    def __init__(self, distance_dim=32, cutoff=5.0, gamma=10):
        super(DistanceEmbedding, self).__init__()
        self.r0 = torch.nn.Parameter(torch.linspace(0, 1.5*cutoff, distance_dim), requires_grad=False)
        self.gamma = gamma

    def forward(self, distances):
        x = torch.exp(-self.gamma*(distances - self.r0)**2)
        return x

This is relatively simple: the distances are expanded in a basis of Gaussians

$e_k(d_{ij}) = \exp (-\gamma (d_{ij} - \mu_k)^2)$

where we have chosen the centers $\mu_k$ evenly spaced between 0 and 1.5 times the cutoff distance.
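
To get a feeling for the numbers, the expansion can be evaluated by hand for a single distance. This sketch uses 16 basis functions (an assumption chosen so the centers land on multiples of 0.5 Å) together with the same `gamma=10` and `cutoff=5.0` as the defaults above; `gaussian_expansion` is a hypothetical helper name.

```python
import math


def gaussian_expansion(d, cutoff=5.0, n_basis=16, gamma=10.0):
    # Centers mu_k evenly spaced between 0 and 1.5 * cutoff.
    centers = [1.5 * cutoff * k / (n_basis - 1) for k in range(n_basis)]
    return [math.exp(-gamma * (d - mu) ** 2) for mu in centers]


basis = gaussian_expansion(1.0)

print(len(basis))                # 16
print(basis.index(max(basis)))   # 2, the center at 1.0 Å
print(basis[2])                  # 1.0
print(round(basis[1], 4))        # 0.0821, the center at 0.5 Å
```

Each distance therefore lights up only the few basis functions whose centers are nearby, which gives the filter network a smooth, localized description of the distance.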

`torch` does not provide the shifted softplus activation we want out of the box, so we implement it ourselves:

In [None]:
class ShiftedSoftPlus(torch.nn.Module):

    def __init__(self):
        super(ShiftedSoftPlus, self).__init__()

    def forward(self, x):
        return torch.log(0.5 * torch.exp(x) + 0.5)
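
The expression above is the ordinary softplus shifted down by $\ln 2$, since $\ln(0.5 e^x + 0.5) = \ln(1 + e^x) - \ln 2$, so the activation is zero at $x = 0$. The plain-Python sketch below checks this identity and shows a rearranged form that avoids evaluating `exp(x)` for large `x`. Both function names are hypothetical, and the rearranged form is an alternative, not what the class above does.

```python
import math


def shifted_softplus_naive(x):
    # Same formula as in the class above, for a single float.
    return math.log(0.5 * math.exp(x) + 0.5)


def shifted_softplus_stable(x):
    # softplus(x) = max(x, 0) + log(1 + exp(-|x|)) never exponentiates a large number.
    softplus = max(x, 0.0) + math.log1p(math.exp(-abs(x)))
    return softplus - math.log(2.0)


print(shifted_softplus_naive(0.0))  # 0.0

for value in (-3.0, 0.5, 3.0):
    difference = shifted_softplus_naive(value) - shifted_softplus_stable(value)
    print(value, abs(difference) < 1e-12)  # True for each value

# For large inputs the activation approaches x - ln(2).
print(round(shifted_softplus_stable(1000.0), 4))  # 999.3069
```

For inputs of moderate size, as in this tutorial, the two forms agree to rounding error, so the simpler expression is fine here.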


And now we're ready to define the SchNet message-passing layer.

Here we use the `MessagePassing` class from `torch_geometric`, as this makes it easy to 
implement the operation efficiently. 

In [None]:
class SchNetMessage(MessagePassing):

    def __init__(self, embedding_dim=32, distance_dim=32, activation_fn=ShiftedSoftPlus):
        super().__init__(aggr='add')

        # This is the `CFConv` filter-producing function (CFconv block).
        self.filter = torch.nn.Sequential(
            torch.nn.Linear(distance_dim, embedding_dim),
            activation_fn(),
            torch.nn.Linear(embedding_dim, embedding_dim),
            activation_fn()
        )

        # First atom-wise transformation (Interaction block)
        self.atom_wise_0 = torch.nn.Linear(embedding_dim, embedding_dim)

        # This is the last atom-wise transformation (Interaction block)
        self.atom_wise_1 = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, embedding_dim),
            activation_fn(),
            torch.nn.Linear(embedding_dim, embedding_dim)
        )

    @staticmethod
    def cosine_cutoff(d, cutoff=5.0):
        return 0.5 * (torch.cos(np.pi * d / cutoff) + 1.0)

    def forward(self, x, edge_index, edge_attr, distances):

        # Calculate the filter: The cosine cutoff makes the filter go smoothly to zero at the cutoff
        # distance. Note that `cosine_cutoff` is called with its default cutoff of 5.0 here.
        cont_filter = self.filter(edge_attr) * self.cosine_cutoff(distances)

        # Pass node features through an mlp.
        x_out = self.atom_wise_0(x)

        # Propagate messages: 'x_out' now contains the new node features.
        x_out = self.propagate(edge_index, x=x_out, cont_filter=cont_filter)

        # Skip connection & final mlp.
        x_out = x + self.atom_wise_1(x_out)

        return x_out

    def message(self, x_j, cont_filter):
        # Message function: This is the `CFConv` operation. 
        # torch_geometric does fancy stuff behind the scenes to make this efficient, 
        # and we don't need to worry about indexing.
        return x_j * cont_filter
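
What `propagate` and `message` do together can be written out with explicit loops. The sketch below is plain Python with made-up node features and filters; `cfconv` is a hypothetical helper and not how `torch_geometric` implements the operation internally. It also evaluates the cosine cutoff at a few distances.

```python
import math


def cosine_cutoff(d, cutoff=5.0):
    # Same formula as the static method above, for a single float.
    return 0.5 * (math.cos(math.pi * d / cutoff) + 1.0)


def cfconv(x, edges, filters):
    """Sum over incoming edges of (sender features * edge filter), elementwise.

    x       : node feature vectors, x[j] is a list of floats
    edges   : (sender j, receiver i) pairs
    filters : one filter vector per edge, same length as the node features
    """
    out = [[0.0] * len(x[0]) for _ in x]
    for (j, i), w in zip(edges, filters):
        for k in range(len(w)):
            out[i][k] += x[j][k] * w[k]
    return out


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
edges = [(1, 0), (2, 0), (0, 1)]
filters = [[1.0, 0.5], [0.5, 0.5], [1.0, 1.0]]

print(cfconv(x, edges, filters))  # [[5.5, 5.0], [1.0, 2.0], [0.0, 0.0]]

print(cosine_cutoff(0.0))            # 1.0
print(round(cosine_cutoff(2.5), 6))  # 0.5
print(cosine_cutoff(5.0))            # 0.0
```

Node 2 receives no messages in this example, so its output is zero; in the layer above the skip connection `x + ...` still carries its old features forward.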

Now we have all the ingredients we need to create the model, so all that's left to do is to 
combine them.

In [None]:
from torch_geometric.utils import scatter
from torch_geometric.nn import SumAggregation

class SchNetModel(torch.nn.Module):

    def __init__(self, embedding_dim=32, distance_dim=300, n_blocks=1, cutoff=5.0,
                 activation_fn=ShiftedSoftPlus):
        super(SchNetModel, self).__init__()
        self.node_embedding = NodeEmbedding(hidden_dim=embedding_dim)
        self.distance_embedding = DistanceEmbedding(distance_dim=distance_dim, cutoff=cutoff)

        self.sch_blocks = torch.nn.ModuleList([
            SchNetMessage(embedding_dim=embedding_dim, 
                          distance_dim=distance_dim, 
                          activation_fn=activation_fn)
                          for _ in range(n_blocks)
            ])
        
        self.energy_head = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, embedding_dim),
            torch.nn.Softplus(),
            torch.nn.Linear(embedding_dim, 1)
        )

        self.aggr = SumAggregation()

    def forward(self, data, compute_forces=False):
        x = self.node_embedding(data.x) # Embed the atomic numbers.

        # Edge vectors: The vectors between the atoms.
        distance_vectors = data.pos[data.edge_index[1]] - data.pos[data.edge_index[0]] + data.shift_vectors

        # We only use the radial information, so we compute the distances between the atoms.
        distances = torch.norm(distance_vectors, dim=1).reshape(-1, 1)

        # Expand the distances in a basis. 
        edge_attr = self.distance_embedding(distances) 

        for block in self.sch_blocks: # Apply the SchNet blocks.
            x = block(x, data.edge_index, edge_attr, distances) # Iteratively update the node embeddings.

        E_atomic = self.energy_head(x) # Compute the atomic energies.
        E_total = self.aggr(E_atomic, data.batch)

        if compute_forces:
            # We sum the energy of all structures and let torch autograd compute the gradients, which takes into 
            # account which atoms contributed to which energy and thus the forces.
            forces = -torch.autograd.grad(E_total.sum(), data.pos, create_graph=True, retain_graph=True)[0] # Compute the forces.
            return E_total, forces

        return E_total    
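
The force computation in `forward` relies on the relation $F = -\partial E / \partial x$. The sketch below checks this relation for a one-dimensional Morse-type energy, comparing the analytic derivative with a central finite difference. The functional form and the parameters `epsilon`, `r0` and `rho0` are illustrative assumptions and are not guaranteed to match any particular calculator.

```python
import math


def morse_energy(r, epsilon=1.0, r0=1.0, rho0=6.0):
    # E(r) = epsilon * (exp(-2a) - 2 exp(-a)) with a = rho0 * (r / r0 - 1)
    expf = math.exp(rho0 * (1.0 - r / r0))
    return epsilon * expf * (expf - 2.0)


def morse_force(r, epsilon=1.0, r0=1.0, rho0=6.0):
    # Analytic -dE/dr along the bond.
    expf = math.exp(rho0 * (1.0 - r / r0))
    return 2.0 * epsilon * rho0 / r0 * expf * (expf - 1.0)


def finite_difference_force(r, h=1e-6):
    # Central difference approximation of -dE/dr.
    return -(morse_energy(r + h) - morse_energy(r - h)) / (2.0 * h)


r = 1.3
print(abs(morse_force(r) - finite_difference_force(r)) < 1e-6)  # True

print(morse_energy(1.0))     # -1.0, the minimum at r = r0
print(morse_force(1.0))      # 0.0, no force at the minimum
print(morse_force(0.9) > 0)  # True, compressed bond is pushed apart
print(morse_force(1.3) < 0)  # True, stretched bond is pulled together
```

Automatic differentiation plays the role of the analytic derivative here: the model only has to define the energy, and the forces follow from it.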

### Morse Dataset

In [None]:
from ase.calculators.morse import MorsePotential
from ase import Atoms

def get_atoms(r):
    atoms = Atoms('H2', positions=[[0, 0, 0], [0, 0, r]])
    atoms.set_cell(np.eye(3) * 10)
    atoms.center()
    atoms.calc = MorsePotential()
    F = atoms.get_forces()
    E = atoms.get_potential_energy()
    return atoms, E, F

def get_morse_dataset(n_data, r_max=2.0, r_min=0.85):

    r_values = np.linspace(r_max, r_min, n_data)
    atoms = [None for _ in range(n_data)]
    E = np.zeros(n_data)
    F = np.zeros((n_data, 2, 3))

    for i, r in enumerate(r_values):
        atoms[i], E[i], F[i] = get_atoms(r)

    return atoms, E, F, r_values

atoms, E, F, r_values = get_morse_dataset(100)

fig, ax = plt.subplots()
ax.plot(r_values, E, '-o', label='Energy')

In [None]:
from torch_geometric.loader import DataLoader
from tqdm import trange

batch_size = 16
epochs = 100

# Define the dataset with a loader:
atoms, E, F, r_values = get_morse_dataset(32)
graphs = [AtomsGraph.from_atoms(a, cutoff=5.0, energy=e, forces=f) for a, e, f in zip(atoms, E, F)]
loader = DataLoader(graphs, batch_size=batch_size, shuffle=True)

# Instantiate a model:
model = SchNetModel(n_blocks=1, embedding_dim=32, distance_dim=256)

# Parameters:
print('Parameters')
total = 0
for name, param in model.named_parameters():
    print(f'\t{name}: {param.shape}, {param.numel()}')
    total += param.numel()
print(f'Total parameters: {total}')

# Make an optimizer: 
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Loss function:
loss = torch.nn.MSELoss()

for epoch in trange(epochs):
    for graph_batch in loader:

        energy_batch = graph_batch.energy
        force_batch = graph_batch.forces

        E_pred, F_pred = model(graph_batch, compute_forces=True)
        E_loss = loss(E_pred, energy_batch)
        F_loss = loss(F_pred, force_batch)

        total_loss = E_loss + F_loss
        model.zero_grad()
        total_loss.backward()
        optimizer.step()


In [None]:
test_data, test_E, test_F, test_r = get_morse_dataset(100)
test_graphs = [AtomsGraph.from_atoms(a, cutoff=5.0, energy=e, forces=f) for a, e, f in zip(test_data, test_E, test_F)]

test_loader = DataLoader(test_graphs, batch_size=1, shuffle=False)

test_pred = []
for graph_batch in test_loader:
    test_pred.append(model(graph_batch).item())

test_pred = np.array(test_pred).flatten()

fig, ax = plt.subplots()

ax.plot(r_values, E, 'o', label='Energy', color='C0')
ax.plot(test_r, test_E, color='C0')
ax.plot(test_r, test_pred, '-', label='Predicted Energy', color='C1')
ax.legend()

### Exercises: 

1. Try varying the importance of the force and the energy in the loss function. Choose values such that you can convince yourself that both parts work. 
2. Try changing the size of the embedding. What's the limit at which the network becomes unable to fit the potential?
3. Make sense of the shape and size of the parameters of the model. Which input parameters influence them? 

4. In the original SchNet paper there's a nice figure showing the filters learned on a specific dataset. Fix the code below to produce that kind of figure.

The steps are as follows: 
1. Make a tensor of linearly spaced distances
2. Compute the RBF expansion of this tensor. 
3. Compute the filter for each expansion.

In [None]:
distance = torch.linspace(0, 5*1.25, 100).reshape(-1, 1)

distance_embedding_block = model.distance_embedding
filter_block = model.sch_blocks[0].filter

distance_embedding = distance_embedding_block(distance)
W = filter_block(distance_embedding).detach().numpy()

azm = np.linspace(0, 2 * np.pi, 200)
r, th = np.meshgrid(distance.detach().numpy(), azm)

nrows = 4 # Assumes an embedding dimension of 32.
ncols = 8
sz = 1.5
fig, axes = plt.subplots(nrows, ncols, subplot_kw=dict(projection='polar'), figsize=(ncols*sz, nrows*sz))

for i, ax in enumerate(axes.flatten()):
    z = np.tile(W[:, i], (r.shape[0], 1))
    ax.pcolormesh(th, r, z, cmap='coolwarm')
    ax.axis('off')

