Install all needed libraries

In [None]:
!pip install torch_geometric

Import everything that is needed and choose whether to run on a GPU or a CPU.

In [6]:
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
from torch import Tensor
from torch.functional import F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.data import Batch
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import scatter

# Use the first CUDA device when one is visible; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

`load_data()` is a simple helper that splits the dataset into training, validation and test sets. The split is reproducible because the random generator used for it is seeded.

In [7]:
def load_data(path, batch_size, train_size=0.8, val_size=0.1, shuffle_train=True, seed=42):
    """Load QM9 from ``path`` and split it into train/validation/test loaders.

    Args:
        path: root directory handed to ``QM9(root=...)``.
        batch_size: number of molecules per batch in all three loaders.
        train_size: fraction of molecules placed in the training split.
        val_size: fraction of molecules placed in the validation split; the
            test split receives the remainder.
        shuffle_train: reshuffle the training split every epoch. Validation
            and test loaders keep a fixed order.
        seed: seeds both the split generator and the training-loader shuffle,
            so runs are reproducible. The default (42) reproduces the original
            split.

    Returns:
        ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: if a fraction is negative or ``train_size + val_size > 1``
            (which would otherwise produce a negative test length).
    """
    # Validate before touching the dataset; small tolerance for float round-off.
    if train_size < 0 or val_size < 0 or train_size + val_size > 1 + 1e-9:
        raise ValueError(
            f"invalid split fractions: train_size={train_size}, val_size={val_size}"
        )

    dataset = QM9(root=path)

    # Split lengths; int() truncates, so any rounding slack goes to the test split.
    total_length = len(dataset)
    train_length = int(train_size * total_length)
    val_length = int(val_size * total_length)
    test_length = total_length - train_length - val_length

    train_set, val_set, test_set = torch.utils.data.random_split(
        dataset,
        [train_length, val_length, test_length],
        generator=torch.Generator().manual_seed(seed),
    )

    # Shuffling the training data each epoch avoids always presenting the
    # molecules in the same order; a seeded generator keeps it reproducible.
    train_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=shuffle_train,
        generator=torch.Generator().manual_seed(seed),
    )
    val_loader = DataLoader(val_set, batch_size=batch_size)
    test_loader = DataLoader(test_set, batch_size=batch_size)
    return train_loader, val_loader, test_loader

Choose the target property to train on. A detailed list is given in the readme and in the [PyG documentation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.QM9.html). Indices 0 and 5 of the PyG dataset require special treatment, which we have not implemented.

We also need to choose the hyperparameters for our model.

In [8]:
param = 2  # target column of QM9's y: 0..15, except 0 and 5 (see readme)
config = dict(
    param=param,
    name=f"cafa-param-{param}-std",
    # --- data ---
    batch_size=100,
    train_size=0.8,
    # NOTE: this value is passed to load_data as its *validation* fraction;
    # the test split receives whatever remains.
    test_size=0.1,
    # --- model ---
    num_atoms=10,        # rows of the atomic-number embedding; must exceed every z in the data
    num_embeddings=128,  # width of the scalar and vector features
    cutoff_dist=5,       # neighbour cutoff radius (presumably Å, the unit of QM9 positions)
    hidden_out_dim=128,  # hidden width of the readout MLP
    # --- optimisation ---
    epochs=500,
    learning_rate=5e-4,
    weight_decay=0.01,
    smoothing_factor=0.7,  # weight of the previous value in the EMA of the validation loss
    device=device,
)

In [13]:
# config["test_size"] is used as the *validation* fraction; the test split gets the rest.
train_loader, val_loader, test_loader = load_data(
    "./data",
    config["batch_size"],
    train_size=config["train_size"],
    val_size=config["test_size"],
)

PyG allows us to create special message passing layers by creating a class that defines the messages, aggregation and update functions.
1. We compute edges between every pair of distinct atoms in the same molecule that are at most the cutoff distance (5 Å) apart.
2. `propagate()` calls `message()`, `aggregate()` and `update()` in turn. Its first argument, the edge index, defines the edges of the graph.
3. `message()` computes the message sent from each atom to every neighbour it is connected to.
4. `aggregate()` sums all messages sent to the same atom. The sum is added to the previous scalar and vector representations.
5. `update()` is a function of the aggregated messages for each node. It produces $\Delta s_i$ and $\Delta\mathbf{v}_i$, which are added to the previous representations.

In [10]:
class MessageLayer(MessagePassing):
    """One PaiNN interaction: a message step followed by an update step.

    Feature layouts used throughout:
        s   (scalar features): [natoms, num_embeddings]
        v   (vector features): [3, natoms, num_embeddings]  (spatial axis first)
        pos:                   [natoms, 3]
        batch:                 [natoms], molecule index of each atom

    ``forward`` returns the updated ``(s, v)``. Both steps are residual:
    ``aggregate`` returns ``s + Δs_msg`` and ``update`` adds ``Δs_upd`` on top of
    that result (likewise for ``v``).
    """

    # Types of the keyword arguments handed to ``propagate`` (one entry per name).
    propagate_type = {"neighbours": Optional[Tensor], "s": Optional[Tensor], "v": Optional[Tensor], "r": Optional[Tensor]}

    def __init__(self, num_embeddings, cutoff_dist, device, num_rbf=20):
        """
        Args:
            num_embeddings: width F of the scalar and vector features.
            cutoff_dist: atoms farther apart than this are not connected. The
                cosine cutoff also drives the distance filter to zero here.
            device: stored as ``self.device``. Tensors created inside the layer
                follow the device of its inputs.
            num_rbf: number of sinusoidal radial basis functions (originally 20).
        """
        super().__init__(flow="source_to_target")

        self.device = device

        self.num_embeddings = num_embeddings
        self.cutoff_dist = cutoff_dist
        self.num_rbf = num_rbf

        # message block
        self.linear_phi1 = nn.Linear(num_embeddings, num_embeddings)
        self.linear_phi2 = nn.Linear(num_embeddings, 3 * num_embeddings)
        self.linear_W = nn.Linear(num_rbf, 3 * num_embeddings)

        # update block
        self.linear_U = nn.Linear(num_embeddings, num_embeddings, bias=False)
        self.linear_V = nn.Linear(num_embeddings, num_embeddings, bias=False)
        self.linear_update1 = nn.Linear(2 * num_embeddings, num_embeddings)
        self.linear_update2 = nn.Linear(num_embeddings, 3 * num_embeddings)
        self.initialize_weights()

    def forward(self, embeddings: Tensor, equivar_repr: Tensor, pos: Tensor, batch: Tensor) -> Tuple[Tensor, Tensor]:
        """Run one interaction and return the updated ``(s, v)``.

        embeddings: [natoms, F]; equivar_repr: [3, natoms, F]; pos: [natoms, 3];
        batch: [natoms].
        """
        neighbours = self.get_neighbours_as_edge_index(pos, batch, self.cutoff_dist)

        return self.propagate(neighbours, s=embeddings, v=equivar_repr, r=pos, neighbours=neighbours, size=None)

    # _j is the sending neighbour, _i the receiving atom.
    # s_j : [num_edges, F]      v_j : [3, num_edges, F]
    # r_i : [num_edges, 3]      r_j : [num_edges, 3]
    def message(self, s_j, v_j, r_i, r_j):
        """Compute per-edge scalar and vector messages (Δs_ij, Δv_ij)."""
        phi = self.linear_phi2(F.silu(self.linear_phi1(s_j)))  # [num_edges, 3F]

        rel_pos = r_i - r_j                      # [num_edges, 3]
        distance = torch.norm(rel_pos, dim=1)    # [num_edges]; > 0 since self-loops are excluded
        direction = rel_pos / distance[:, None]  # unit vectors, [num_edges, 3]

        # Sinusoidal radial basis sin(n*pi*d/r_c) / d for n = 1..num_rbf.
        n = torch.arange(1, self.num_rbf + 1, device=distance.device, dtype=distance.dtype)
        rbf = torch.sin(n[None, :] * torch.pi * distance[:, None] / self.cutoff_dist) / distance[:, None]  # [num_edges, num_rbf]

        # Cosine cutoff f_c(d) = 0.5*(cos(pi*d/r_c) + 1) for d <= r_c, else 0.
        # Applied to the filter *after* the linear layer so that the whole filter,
        # bias included, goes smoothly to zero at the cutoff radius.
        cutoff = torch.where(
            distance <= self.cutoff_dist,
            0.5 * (torch.cos(torch.pi * distance / self.cutoff_dist) + 1),
            torch.zeros_like(distance),
        )
        filter_w = self.linear_W(rbf) * cutoff[:, None]  # [num_edges, 3F]

        gate_s, gate_v, gate_r = torch.split(phi * filter_w, self.num_embeddings, dim=1)  # each [num_edges, F]

        delta_s_ij = gate_s  # [num_edges, F]

        # Gate the neighbour's vectors and add a term along the unit direction;
        # broadcasting [1, E, F] with [3, E, F] and [3, E, 1] keeps the [3, E, F] layout.
        delta_v_ij = v_j * gate_v[None, :, :] + gate_r[None, :, :] * direction.t()[:, :, None]  # [3, num_edges, F]

        return delta_s_ij, delta_v_ij

    # s : [natoms, F]     v : [3, natoms, F]
    def aggregate(self, message, neighbours, s, v):
        """Sum messages per receiving atom and add them to ``s`` and ``v``."""
        delta_s_ij, delta_v_ij = message
        target = neighbours[1]
        num_nodes = s.size(0)

        # dim_size keeps one row per atom even if trailing atoms receive no messages.
        delta_s_i = scatter(delta_s_ij, target, dim=0, dim_size=num_nodes, reduce="sum")  # [natoms, F]
        delta_v_i = scatter(delta_v_ij, target, dim=1, dim_size=num_nodes, reduce="sum")  # [3, natoms, F]

        return s + delta_s_i, v + delta_v_i

    def update(self, agg_message, s, v):
        """Apply the PaiNN update block to the message-updated features.

        ``s`` and ``v`` (the layer inputs) are still routed here by ``propagate``
        but are intentionally unused: the update residual builds on the output
        of ``aggregate``.
        """
        s_i, v_i = agg_message
        U = self.linear_U(v_i)  # [3, natoms, F]
        V = self.linear_V(v_i)  # [3, natoms, F]

        stack = torch.cat([s_i, torch.norm(V, dim=0)], dim=1)  # [natoms, 2F]
        a = self.linear_update2(F.silu(self.linear_update1(stack)))  # [natoms, 3F]
        a_vv, a_ss, a_sv = torch.split(a, self.num_embeddings, dim=1)  # each [natoms, F]

        delta_s = a_ss + torch.sum(U * V, dim=0) * a_sv  # [natoms, F]
        delta_v = U * a_vv[None, :, :]                    # [3, natoms, F]

        # Out-of-place so the caller's tensors are not mutated.
        return s_i + delta_s, v_i + delta_v

    def get_neighbours_as_edge_index(self, pos, batch, cutoff_dist):
        """Return ``[2, num_edges]`` int64 edges between distinct atoms of the
        same molecule whose distance is <= ``cutoff_dist``.

        Row 0 holds the source (neighbour j), row 1 the target (atom i).
        """
        distances = torch.cdist(pos, pos, p=2)  # [natoms, natoms]
        same_molecule = batch[:, None] == batch[None, :]
        not_self = ~torch.eye(pos.size(0), dtype=torch.bool, device=pos.device)
        adjacency = (distances <= cutoff_dist) & same_molecule & not_self

        target, source = adjacency.nonzero(as_tuple=True)
        return torch.stack([source, target], dim=0)

    def initialize_weights(self):
        """Kaiming-uniform weights and zero biases for every ``nn.Linear``."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.kaiming_uniform_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)


Now, we can define the model with the conventional PyTorch methods. The output is one number for each molecule.

In [11]:
class PaiNN(nn.Module):
    """PaiNN-style network that predicts one scalar per molecule.

    Args:
        num_atoms: rows of the atomic-number embedding; every value in
            ``data.z`` must be smaller than this.
        num_embeddings: width of the scalar and vector features.
        cutoff_dist: neighbour cutoff passed to each ``MessageLayer``.
        hidden_out_dim: hidden width of the readout MLP.
        device: passed to ``MessageLayer`` and stored as ``self.device``.
        message_layers: number of message-passing iterations.
        shared_weights: if True (the default, original behaviour) one
            ``MessageLayer`` is applied ``message_layers`` times. If False, each
            iteration gets its own ``MessageLayer`` inside an ``nn.ModuleList``,
            so all their parameters are registered with the model.
    """

    def __init__(self, num_atoms, num_embeddings, cutoff_dist, hidden_out_dim, device, message_layers=3, shared_weights=True):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.cutoff_dist = cutoff_dist
        self.device = device
        self.message_layers = message_layers
        self.shared_weights = shared_weights

        self.embeddings = nn.Embedding(num_atoms, num_embeddings, padding_idx=0)

        if shared_weights:
            # Attribute name kept so checkpoints with "messagelayer.*" keys still load.
            self.messagelayer = MessageLayer(num_embeddings, cutoff_dist, device)
        else:
            # A plain Python list would hide these from .parameters() and .to();
            # nn.ModuleList registers them.
            self.message_blocks = nn.ModuleList(
                MessageLayer(num_embeddings, cutoff_dist, device) for _ in range(message_layers)
            )

        self.linear_out1 = nn.Linear(num_embeddings, hidden_out_dim)
        self.linear_out2 = nn.Linear(hidden_out_dim, 1)

    def forward(self, data: Batch) -> Tensor:
        """Return the per-molecule prediction, shape [num_molecules, 1]."""
        # 1. Initial features: learned scalars from atomic numbers, zero vectors.
        s = self.embeddings(data.z)  # [natoms, num_embeddings]
        v = torch.zeros((3, s.shape[0], self.num_embeddings), device=s.device, dtype=s.dtype)  # [3, natoms, num_embeddings]

        # 2. Message passing.
        if self.shared_weights:
            layers = [self.messagelayer] * self.message_layers
        else:
            layers = self.message_blocks
        for layer in layers:
            s, v = layer(s, v, data.pos, data.batch)

        # 3. Per-atom readout, then sum over the atoms of each molecule.
        out = self.linear_out2(F.silu(self.linear_out1(s)))  # [natoms, 1]
        return scatter(out, data.batch, dim=0, reduce="sum")  # [num_molecules, 1]

In [14]:
# Build the network from the config and move its parameters to the chosen device
# (nn.Module.to returns the module itself).
model = PaiNN(
    num_atoms=config["num_atoms"],
    num_embeddings=config["num_embeddings"],
    cutoff_dist=config["cutoff_dist"],
    hidden_out_dim=config["hidden_out_dim"],
    device=device,
).to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=config["learning_rate"],
    weight_decay=config["weight_decay"],
)
# Halve the learning rate after 5 epochs without improvement of the monitored value.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

In [15]:
def _target_stats(loader, param):
    """Return ``(mean, std)`` of target column ``param`` over every batch in ``loader``.

    Computed in float64 for accuracy and cast back to the targets' dtype.
    ``std`` is the unbiased (n - 1) estimate.
    """
    values = torch.cat([batch.y[:, param] for batch in loader])
    as_double = values.double()
    return as_double.mean().to(values.dtype), as_double.std().to(values.dtype)


def _loss_and_mae(model, batch, criterion, param, mean, std):
    """Run ``model`` on ``batch`` and return ``(loss tensor, MAE float)``.

    The model predicts the *standardized* target ``(y - mean) / std``, so the
    loss compares it with the standardized label. The MAE is reported in the
    target's original units multiplied by 1000 (presumably eV -> meV for the
    QM9 energy targets; confirm for the chosen ``param``).
    """
    target = batch.y[:, param]
    pred_std = model(batch).squeeze(-1)  # [num_molecules]; squeeze(-1) keeps batches of size 1 1-D
    # Both sides are scaled by 1000 as in the original code. This multiplies the
    # MSE by 1e6, which also sets how strong Adam's L2 weight_decay is relative to the loss.
    loss = criterion(1000 * pred_std, 1000 * (target - mean) / std)
    pred = pred_std.detach() * std + mean  # back to the target's original units
    mae = F.l1_loss(1000 * pred, 1000 * target).item()
    return loss, mae


def _save_checkpoint(path, epoch, config, model, optimizer, scheduler, smoothed_val_loss):
    """Write everything needed to resume training after ``epoch`` to ``path``."""
    torch.save(
        {
            "epoch": epoch,
            "config": config,
            "model": model.state_dict(),
            "smoothed_val_loss": smoothed_val_loss,
            "scheduler": scheduler.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        path,
    )


def training_loop(model, optimizer, criterion, scheduler, train_loader, val_loader, config, save_path, from_epoch=0, smoothed_val_loss=None):
    """Train ``model`` on standardized targets and checkpoint it.

    Reads ``epochs``, ``device``, ``param`` and ``smoothing_factor`` from
    ``config``. Each epoch runs a training and a validation pass, prints the
    losses and MAEs, and steps ``scheduler`` with an exponential moving average
    of the validation loss. A checkpoint is written to
    ``{save_path}/epoch_{N}.pth`` every 10 epochs and to
    ``{save_path}/final.pth`` at the end. The final checkpoint records
    ``from_epoch - 1`` if no epoch ran.
    """
    epochs = config["epochs"]
    device = torch.device(config["device"])
    param = config["param"]
    smoothing_factor = config["smoothing_factor"]

    # Normalisation statistics come from the training split only, so no
    # information from the validation set leaks into training.
    mean, std = _target_stats(train_loader, param)
    print(f"Mean of data {mean}")
    print(f"Standard dev. {std}")
    mean, std = mean.to(device), std.to(device)

    last_epoch = from_epoch - 1
    for epoch in range(from_epoch, epochs):
        last_epoch = epoch

        # Training phase
        model.train()
        total_train_loss, total_train_mae = 0.0, 0.0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            loss, mae = _loss_and_mae(model, batch, criterion, param, mean, std)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
            total_train_mae += mae

        avg_train_loss = total_train_loss / len(train_loader)
        avg_train_mae = total_train_mae / len(train_loader)
        print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}, Train L1 Loss: {avg_train_mae:.4f}')

        # Validation phase
        model.eval()
        total_val_loss, total_val_mae = 0.0, 0.0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                loss, mae = _loss_and_mae(model, batch, criterion, param, mean, std)
                total_val_loss += loss.item()
                total_val_mae += mae

        avg_val_loss = total_val_loss / len(val_loader)
        avg_val_mae = total_val_mae / len(val_loader)

        # Exponential smoothing of the validation loss, seeded by the first value.
        if smoothed_val_loss is None:
            smoothed_val_loss = avg_val_loss
        else:
            smoothed_val_loss = (smoothing_factor * smoothed_val_loss) + ((1 - smoothing_factor) * avg_val_loss)

        print(f'Epoch [{epoch+1}/{epochs}], Validation Loss: {avg_val_loss:.4f}, Validation L1: {avg_val_mae:.4f}, Smoothed Validation Loss: {smoothed_val_loss:.4f}')

        # Adjust the learning rate based on the smoothed validation loss.
        scheduler.step(smoothed_val_loss)

        if (epoch + 1) % 10 == 0:
            _save_checkpoint(f"{save_path}/epoch_{epoch+1}.pth", epoch, config, model, optimizer, scheduler, smoothed_val_loss)

    _save_checkpoint(f"{save_path}/final.pth", last_epoch, config, model, optimizer, scheduler, smoothed_val_loss)

In [None]:
# Train from scratch; checkpoints are written to the current directory.
training_loop(
    model,
    optimizer,
    criterion,
    scheduler,
    train_loader,
    val_loader,
    config,
    ".",
    from_epoch=0,
    smoothed_val_loss=None,
)