In [8]:
# Read the experiment configuration (YAML) used by the optimizer/scheduler cells below.
from easydict import EasyDict
import yaml

def load_config(path):
    """Load a YAML config file into an EasyDict so nested keys read as attributes.

    Args:
        path: Path to the YAML file.

    Returns:
        EasyDict built from the parsed YAML document (e.g. ``config.train.optimizer``).
    """
    # Open in binary mode so PyYAML detects the encoding itself (UTF-8 / UTF-16 via BOM)
    # instead of decoding with the platform's locale default (e.g. cp1252 on Windows).
    # safe_load only constructs plain Python objects, so the file cannot execute code.
    with open(path, 'rb') as f:
        return EasyDict(yaml.safe_load(f))

config_file = './configs/vector_attention.yml'
config = load_config(config_file)

In [5]:
from utils.chem import read_sdf
from dataset.featurizer import featurize_mol
from dataset.molgeom import torchify_dict
from torch_geometric.data import Data, Batch
mols = read_sdf('./data/mol.sdf')

# Featurize every molecule, then pass each feature dict through torchify_dict
# (presumably converting array values to torch tensors — confirm in dataset.molgeom).
mol_feat_dicts = [featurize_mol(mol) for mol in mols]
mols_feats = [torchify_dict(mol_feat_dict) for mol_feat_dict in mol_feat_dicts]

# Build one PyG Data object per molecule from the torchified dicts, so the Data
# fields hold tensors rather than the raw featurizer output.
mol_datas = [
    Data(
        x=mol_feat['atom_feature'],
        pos=mol_feat['pos'],
        edge_index=mol_feat['bond_index'],
        edge_attr=mol_feat['bond_feature'],
        element=mol_feat['element'],
    )
    for mol_feat in mols_feats
]

# Collate all molecules into a single Batch; batch.batch maps each atom to its molecule.
batch = Batch.from_data_list(mol_datas)

In [96]:
from models.gvat import *
import torch
class VectorAttention(Module):
    """SO(3) graph encoder made of stacked vector-neuron attention interaction blocks.

    Each atom carries a scalar channel (mapped from ``node_attr``) and a vector channel
    (mapped from ``pos``). Every block in ``self.interactions`` returns an update that
    is added residually to both channels.

    Design notes:
      * Meant to capture both global and local structure.
      * Vector x vector attention stays equivariant when scored with the Frobenius
        norm, i.e. summing over the last dimension twice.
      * An equivariant example is ``AttentionEdges`` in ./models/tri_attention.py.

    Args:
        node_dim: Input width of the per-atom scalar features (``node_attr``).
        edge_dim: Forwarded to each block as ``num_edge_types``.
        node_hiddens: ``[scalar_hidden, vector_hidden]`` widths; ``None`` means ``[256, 64]``.
        edge_hidden: Edge hidden width forwarded to each block.
        key_channels: Forwarded to each block.
        num_heads: Forwarded to each block.
        num_blocks: Number of stacked interaction blocks.
        k: Stored on the instance; not used inside this class.
        cutoff: Distance cutoff forwarded to each block.
    """

    def __init__(self, node_dim=45, edge_dim=5, node_hiddens=None, edge_hidden=64, key_channels=128, num_heads=4, num_blocks=4, k=32, cutoff=10.0):
        super().__init__()
        # Fresh list per instance: avoids a shared mutable default and decouples the
        # model from later mutation of a caller-supplied sequence.
        node_hiddens = [256, 64] if node_hiddens is None else list(node_hiddens)
        self.node_hiddens = node_hiddens
        self.edge_hidden = edge_hidden
        self.key_channels = key_channels  # forwarded to every interaction block below
        self.num_heads = num_heads  # forwarded to every interaction block below
        self.num_blocks = num_blocks
        self.k = k  # stored only; not read anywhere in this class
        self.cutoff = cutoff
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        # Input projections: scalar features -> node_hiddens[0]; one vector channel -> node_hiddens[1].
        self.atom_sca_mapper = Linear(node_dim, node_hiddens[0])
        self.atom_vec_mapper = VNLinear(1, node_hiddens[1])
        self.interactions = ModuleList()
        for _ in range(num_blocks):
            block = AttentionInteractionBlockVN(
                node_hiddens=node_hiddens,
                edge_hidden=edge_hidden,
                num_edge_types=edge_dim,
                key_channels=key_channels,
                num_heads=num_heads,
                cutoff=cutoff,
            )
            self.interactions.append(block)

    def forward(self, node_attr, pos, edge_index, edge_feature):
        """Encode a (batched) molecular graph.

        Args:
            node_attr: Per-atom scalar features fed to ``atom_sca_mapper``.
            pos: Atom coordinates. If not 3-D, a channel axis is inserted at dim 1 before
                the vector mapper (presumably (N, 3) -> (N, 1, 3) — TODO confirm).
            edge_index: ``edge_index[0]`` / ``edge_index[1]`` index the endpoints of each edge.
            edge_feature: Per-edge features passed to each interaction block.

        Returns:
            ``[scalar_hidden, vector_hidden]`` after all interaction blocks.
        """
        if pos.dim() != 3:
            vector_feature = pos.unsqueeze(1)
        else:
            vector_feature = pos

        # NOTE(review): when pos is already 3-D, edge_vector is 3-D too — confirm the blocks accept it.
        edge_vector = pos[edge_index[0]] - pos[edge_index[1]]
        atom_sca_hidden0 = self.atom_sca_mapper(node_attr)
        atom_vec_hidden0 = self.atom_vec_mapper(vector_feature)
        # TODO possible attention designs:
        #   1. edge-vector attention combined with node features;
        #   2. edge-scalar attention;
        #   3. self-attention on each atom.

        h = [atom_sca_hidden0, atom_vec_hidden0]
        for interaction in self.interactions:
            # Residual update of both the scalar and the vector channel.
            delta_h = interaction(h, edge_index, edge_feature, edge_vector)
            h[0] = h[0] + delta_h[0]
            h[1] = h[1] + delta_h[1]
        # A global (graph-level) term could be incorporated here.
        return h

In [97]:
# Seed torch's RNG so the random weight initialisation (and the outputs derived from it)
# is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
gvat = VectorAttention()

In [106]:
# Graph attention over the whole batch: the encoder returns [scalar features, vector features].
sca_feature, vec_feature = gvat(
    node_attr=batch.x,
    pos=batch.pos,
    edge_index=batch.edge_index,
    edge_feature=batch.edge_attr,
)

In [122]:
from torch_scatter import scatter_mean
# Example read-out: average the per-atom predictions within each graph (one score per molecule).

# Read-out head: a single linear output unit over each atom's scalar hidden features.
prediction_layer = Linear(sca_feature.size(-1), 1)
# Keep only the single output column, giving a 1-D vector of per-atom scores.
prediction = prediction_layer(sca_feature)[..., 0]

# batch.batch assigns each atom to its graph; average the atom scores per graph.
graph_prediction = scatter_mean(prediction, batch.batch)
graph_prediction
# then the loss could be obtained

tensor([0.1255, 0.1244, 0.1127, 0.1149, 0.1222, 0.1163, 0.1121, 0.1120, 0.1105,
        0.1119], grad_fn=<DivBackward0>)

In [99]:
from utils.mics import get_optimizer, get_scheduler

In [105]:
# The use of the scheduler and optimizer follows ResGen: https://github.com/HaotianZhangAI4Science/ResGen
# Train the encoder AND the read-out head: wrapping both in a ModuleList gives one module
# whose parameters cover each of them; otherwise prediction_layer stays at its random init.
# NOTE(review): assumes get_optimizer builds the optimizer from model.parameters() — verify in utils.mics.
trainable_modules = torch.nn.ModuleList([gvat, prediction_layer])
optimizer = get_optimizer(config.train.optimizer, trainable_modules)
scheduler = get_scheduler(config.train.scheduler, optimizer)