In [6]:
# atom features, bond type, graph connectivity, (x,y,z) coordinates  
#   - when we encode the graph, we do it through atom features, bond types, and connectivity (i.e. which atoms are connected to each other, and how?)
#   - the coordinate-based representation is particularly useful 
#   - for reaction centre, find adjacency matrix differences then map to 3D matrix

# convert an MLP to a GNN by swapping torch.nn.Linear for PyG's GNN operators, e.g. a GCN layer

In [22]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GAE
from torch_geometric.utils import train_test_split_edges

import sys
sys.path.insert(0, "Users/rmhavij/3d-reactions/") # azure again
from ts_vae.data_processors.grambow_processor import ReactionDataset

In [23]:
# dataset root: local runs use r'data/'; on azure the repo-relative path below is needed
base_path = r'Users/rmhavij/3d-reactions/data/'
r_dataset = ReactionDataset(base_path, geo_file = 'train_r')

# link-prediction setup: clear node-level masks/labels, then hold out 20% of edges for testing
# NOTE(review): train_test_split_edges is deprecated in recent PyG in favour of
# torch_geometric.transforms.RandomLinkSplit — consider migrating.
data = r_dataset.data
data.train_mask = None
data.val_mask = None
data.test_mask = None
data.y = None
data = train_test_split_edges(data = data, val_ratio = 0, test_ratio = 0.2)

In [24]:
# TODO: rename this to something like MoleculeEncoder

class LinearEncoder(nn.Module):
    """Single-layer GCN node encoder with no activation function.

    Args:
        in_channels: number of input features per node.
        out_channels: size of each node's output embedding.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # the whole encoder is one graph convolution producing per-node embeddings
        self.conv = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Embed node features ``x`` over the graph given by ``edge_index``.

        No nonlinearity is applied after the convolution (deliberately kept "linear").
        """
        embedding = self.conv(x, edge_index)
        return embedding
    


In [25]:
# GAE over the 'train_r' (reactant) graphs: one GCN layer mapping each node to a 2-d latent
num_node_fs = r_dataset.data.num_node_features  # = 11
out_channels = 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# model, node features, training edges and optimiser all target `device`
model = GAE(LinearEncoder(num_node_fs, out_channels)).to(device)
x = data.x.to(device)
train_pos_edge_index = data.train_pos_edge_index.to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.01)


In [None]:
def train():
    """Take one full-batch optimisation step on the notebook-global GAE ``model``.

    Reads the globals ``model``, ``opt``, ``x`` and ``train_pos_edge_index``.

    Returns:
        The reconstruction loss of this step as a Python float.
    """
    model.train()
    opt.zero_grad()
    latent = model.encode(x, train_pos_edge_index)
    recon = model.recon_loss(latent, train_pos_edge_index)
    recon.backward()
    opt.step()
    return float(recon)

def test(pos_edge_index, neg_edge_index):
    """Score the global ``model`` on the given positive/negative edge sets.

    Embeddings are always computed from the *training* edges, so the edges being
    scored are not used by the encoder.

    Returns:
        Whatever ``model.test`` returns (unpacked by the caller as ``auc, ap``).
    """
    model.eval()
    with torch.no_grad():
        embeddings = model.encode(x, train_pos_edge_index)
    scores = model.test(embeddings, pos_edge_index, neg_edge_index)
    return scores

epochs = 10

# the held-out edges never change: move them to `device` once, matching the training edges
test_pos_edge_index = data.test_pos_edge_index.to(device)
test_neg_edge_index = data.test_neg_edge_index.to(device)

for epoch in range(1, epochs + 1):
    loss = train()
    auc, ap = test(test_pos_edge_index, test_neg_edge_index)
    # include the training loss so convergence is visible next to AUC/AP
    print('Epoch: {:03d}, Loss: {:.4f}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, loss, auc, ap))

In [31]:
# build one GAE (encoder + optimiser) per geometry: reactant ('train_r') and product ('train_p')
base_path = r'Users/rmhavij/3d-reactions/data/'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_geometry(geo_file):
    """Load ``geo_file`` as a ReactionDataset and split its edges (80% train / 0% val / 20% test).

    Returns:
        (dataset, split_data) where ``split_data`` is ``dataset.data`` after the edge split.
    """
    dataset = ReactionDataset(base_path, geo_file = geo_file)
    geo_data = dataset.data
    # link prediction only: drop node-level masks/labels before splitting edges
    geo_data.train_mask = geo_data.val_mask = geo_data.test_mask = geo_data.y = None
    geo_data = train_test_split_edges(data = geo_data, val_ratio = 0, test_ratio = 0.2)
    return dataset, geo_data


def _build_geometry_gae(geo_data, out_channels, lr = 0.01):
    """Build a GAE(LinearEncoder) sized for ``geo_data`` and an Adam optimiser for it.

    Returns:
        (encoder, node_features, train_pos_edges, optimiser); tensors/modules are on ``device``.
    """
    encoder = GAE(LinearEncoder(geo_data.num_node_features, out_channels)).to(device)
    node_features = geo_data.x.to(device)
    pos_edges = geo_data.train_pos_edge_index.to(device)
    optimiser = torch.optim.Adam(encoder.parameters(), lr = lr)
    return encoder, node_features, pos_edges, optimiser


out_channels = 2  # latent dimension per node

# reactant data + encoder
r_dataset, reactant_data = _load_geometry('train_r')
num_node_fs = reactant_data.num_node_features  # = 11
reactant_encoder, reactant_x, reactant_pos_edges, reactant_opt = _build_geometry_gae(reactant_data, out_channels)

# product data + encoder
p_dataset, product_data = _load_geometry('train_p')
num_node_fs = product_data.num_node_features  # = 11 (same as reactants here, but kept per-geometry)
product_encoder, product_x, product_pos_edges, product_opt = _build_geometry_gae(product_data, out_channels)

# NOTE: the global `x` ends up holding the *product* features (as before). Later cells that
# read `x` for the reactant encoder get the wrong features — use reactant_x / product_x instead.
x = product_x

In [None]:
class TSDecoder(torch.nn.Module):
    """Inner-product decoder for a transition state (TS) built from reactant and product latents.

    Modelled on PyG's InnerProductDecoder, but it takes the reactant and product node
    embeddings separately. It merges them with :meth:`combine` and scores node pairs
    on the merged embedding.
    """

    @staticmethod
    def combine(reactant_z, product_z):
        """Merge reactant and product node embeddings into a single TS embedding.

        Placeholder choice: concatenate along the feature axis. The inner product of two
        merged rows is then (reactant inner product) + (product inner product).
        Assumes both tensors are (num_nodes, dim) with the same node ordering
        (atom-mapped reactant/product) — TODO confirm against the data pipeline.
        """
        return torch.cat([reactant_z, product_z], dim=-1)

    def forward(self, reactant_z, product_z, reactant_edge_index, product_edge_index, sigmoid=True):
        """Decode the merged latents into edge probabilities for the given node pairs.

        The reactant and product node pairs are scored together: the output holds the
        reactant pairs first, then the product pairs.

        Args:
            reactant_z, product_z: node embeddings from the two encoders.
            reactant_edge_index, product_edge_index: (2, E) node-pair indices to score.
            sigmoid: if True return probabilities, otherwise raw inner products.
        """
        z = self.combine(reactant_z, product_z)
        edge_index = torch.cat([reactant_edge_index, product_edge_index], dim=1)
        value = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1)
        return torch.sigmoid(value) if sigmoid else value

    def forward_all(self, reactant_z, product_z, sigmoid=True):
        """Decode the merged latents into a dense (probabilistic) adjacency matrix."""
        z = self.combine(reactant_z, product_z)
        adj = torch.matmul(z, z.t())
        return torch.sigmoid(adj) if sigmoid else adj

In [None]:
# TODO: does the TS need its own GAE (encoders for R and P, then decode to the TS)?

class TSGAE(nn.Module):
    """Graph autoencoder for transition states: encode reactant and product, decode the TS.

    Args:
        reactant_encoder: module encoding the reactant graph.
        product_encoder: module encoding the product graph.
        ts_decoder: module called as ``ts_decoder(reactant_z, product_z,
            reactant_edge_index, product_edge_index, sigmoid=...)``.
        in_channels, out_channels: currently unused; kept so existing call sites still work.
    """

    # small constant keeping log() finite (same value PyG's GAE uses)
    EPS = 1e-15

    def __init__(self, reactant_encoder, product_encoder, ts_decoder, in_channels, out_channels):
        super(TSGAE, self).__init__()

        self.reactant_encoder = reactant_encoder
        self.product_encoder = product_encoder
        self.ts_decoder = ts_decoder

        TSGAE.reset_parameters(self)

    def reset_parameters(self):
        """Re-initialise both encoders and the decoder via the module-level ``reset`` helper."""
        reset(self.reactant_encoder)
        reset(self.product_encoder)
        reset(self.ts_decoder)

    def combine_reactant_and_product(self, *args, **kwargs):
        """Run the reactant and product encoders and combine their node-wise latents.

        Not implemented yet. It raises instead of silently returning None, so callers fail loudly.
        """
        # TODO: should this be the combine function instead? probably, yeah
        raise NotImplementedError('combine_reactant_and_product is not implemented yet')

    def decode(self, *args, **kwargs):
        """Run the TS decoder and return its edge probabilities (or logits if sigmoid=False)."""
        return self.ts_decoder(*args, **kwargs)

    @staticmethod
    def _negative_edges(pos_edge_index, num_nodes):
        """Sample negative edges for one graph, never proposing self-loops."""
        from torch_geometric.utils import add_self_loops, negative_sampling, remove_self_loops
        # put every self-loop into the "positive" set so negative_sampling cannot return one
        pos_edge_index, _ = remove_self_loops(pos_edge_index)
        pos_edge_index, _ = add_self_loops(pos_edge_index)
        return negative_sampling(pos_edge_index, num_nodes)

    def ts_creation_loss(self, reactant_z, reactant_pos_edges, reactant_neg_edges=None,
                         product_z=None, product_pos_edges=None, product_neg_edges=None):
        """Binary cross-entropy of the TS decoder on positive and negative edges.

        Positive pairs are the given reactant/product edges. Negative pairs are taken from
        ``*_neg_edges`` or, when those are None, sampled per graph (self-loops excluded).

        Raises:
            ValueError: if ``product_z`` or ``product_pos_edges`` is missing (they default to
                None only so the parameters can keep their original order).
        """
        if product_z is None or product_pos_edges is None:
            raise ValueError('product_z and product_pos_edges are required')

        pos_pred = self.ts_decoder(reactant_z, product_z, reactant_pos_edges, product_pos_edges, sigmoid=True)
        pos_loss = -torch.log(pos_pred + self.EPS).mean()

        if reactant_neg_edges is None:
            reactant_neg_edges = self._negative_edges(reactant_pos_edges, reactant_z.size(0))
        if product_neg_edges is None:
            product_neg_edges = self._negative_edges(product_pos_edges, product_z.size(0))
        neg_pred = self.ts_decoder(reactant_z, product_z, reactant_neg_edges, product_neg_edges, sigmoid=True)
        neg_loss = -torch.log(1 - neg_pred + self.EPS).mean()

        return pos_loss + neg_loss

    def test(self, z, pos_edge_index, neg_edge_index,
             product_z=None, product_pos_edges=None, product_neg_edges=None):
        """Compute area under ROC curve (AUC) and average precision (AP) of the TS decoder.

        ``z`` / ``pos_edge_index`` / ``neg_edge_index`` are the reactant-side inputs (the names
        are kept from the original single-graph signature). The product-side inputs are
        required because the TS decoder consumes both graphs.

        Raises:
            ValueError: if any product-side input is missing.
        """
        if product_z is None or product_pos_edges is None or product_neg_edges is None:
            raise ValueError('product_z, product_pos_edges and product_neg_edges are required')

        pos_pred = self.ts_decoder(z, product_z, pos_edge_index, product_pos_edges, sigmoid=True)
        neg_pred = self.ts_decoder(z, product_z, neg_edge_index, product_neg_edges, sigmoid=True)

        # label sizes follow the decoder output, so this does not depend on how it merges edges
        y = torch.cat([pos_pred.new_ones(pos_pred.size(0)), neg_pred.new_zeros(neg_pred.size(0))], dim=0)
        pred = torch.cat([pos_pred, neg_pred], dim=0)

        y, pred = y.detach().cpu().numpy(), pred.detach().cpu().numpy()

        # NOTE(review): roc_auc_score / average_precision_score (sklearn.metrics) are not
        # imported in any visible cell — add them to the imports cell.
        return roc_auc_score(y, pred), average_precision_score(y, pred)

def reset(nn):
    """Reset ``nn``'s parameters by calling ``reset_parameters()`` wherever it is defined.

    If ``nn`` exposes at least one direct child via ``.children()``, each child that has
    ``reset_parameters`` is reset (one level only, not recursive). Otherwise ``nn`` itself
    is reset if it has that method. Passing ``None`` does nothing.
    """
    if nn is None:
        return

    has_children = hasattr(nn, 'children') and len(list(nn.children())) > 0
    targets = nn.children() if has_children else (nn,)
    for target in targets:
        if hasattr(target, 'reset_parameters'):
            target.reset_parameters()

In [None]:
def train_individual_geometry(geometry_encoder, opt, pos_edge_indices, node_features=None):
    """Take one optimisation step of a single-geometry GAE (reactant or product).

    Args:
        geometry_encoder: GAE whose ``encode`` / ``recon_loss`` are used.
        opt: optimiser over ``geometry_encoder``'s parameters.
        pos_edge_indices: training positive edges for this geometry.
        node_features: node feature matrix for this geometry. If None, falls back to the
            notebook-global ``x`` for backward compatibility. The setup cell leaves that
            global holding the *product* features, so pass features explicitly.

    Returns:
        (z, loss): node embeddings computed in this step (before ``opt.step()``) and the
        reconstruction loss as a float.
    """
    features = x if node_features is None else node_features
    geometry_encoder.train()
    opt.zero_grad()
    z = geometry_encoder.encode(features, pos_edge_indices)
    loss = geometry_encoder.recon_loss(z, pos_edge_indices)
    loss.backward()
    opt.step()
    # NOTE(review): z still carries autograd history from this step. Backpropagating a later
    # TS loss through it would need retain_graph=True or a fresh encode — TODO decide.
    return z, float(loss)

def train_reaction():
    """Train the reactant and product encoders for one step each, then (WIP) combine into a TS.

    Uses the notebook-global encoders, optimisers, edge sets, data objects and ``device``
    from the setup cell.
    """
    # TODO: pass in encoders, etc.?

    # each encoder gets its own geometry's node features (the global `x` holds only one of them)
    reactant_z, reactant_loss = train_individual_geometry(
        reactant_encoder, reactant_opt, reactant_pos_edges, reactant_data.x.to(device))
    product_z, product_loss = train_individual_geometry(
        product_encoder, product_opt, product_pos_edges, product_data.x.to(device))
    # TODO: how to use the losses?

    # combine latent representation of reactant and product
    # NOTE(review): `combine` is not defined in any visible cell; this raises NameError until it is.
    ts_initialisation = combine(reactant_z, product_z)

    # need some index in order to reconstruct the TS from this

    return


In [None]:
# lucky's work
# PairFeatures: probably a manual message-passing (MP) step — it has to be, otherwise what he's doing isn't a GNN at all.

# set edges
#   iterate:
#       compute features (i.e. MP) -> MLP(features) -> update edges
#       compute features (i.e. MP) -> MLP(MLP(edges)) -> update vertices