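Go over the notes and build up from simpler models:

1. R AE: reactant (R) autoencoder (AE)
2. R-P AE: reactant-product (R-P) autoencoder
3. R encoder and TS decoder: reactant encoder with a transition-state (TS) decoder
4. R-P encoder, TS decoder: reactant-product encoder with a transition-state decoder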

In [1]:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GAE
from torch_geometric.utils import train_test_split_edges
from ts_vae.data_processors.grambow_processor import ReactionDataset

## Molecule Autoencoder

In [2]:
def reset(nn):
    """Re-initialise ``nn`` or, when it has children, each direct child.

    Only objects that expose a ``reset_parameters`` method are touched.
    The walk is one level deep: a child that lacks ``reset_parameters``
    is skipped and its own children are not visited.

    Args:
        nn: a module-like object, or ``None`` (in which case nothing happens).
    """
    if nn is None:
        return

    children = list(nn.children()) if hasattr(nn, 'children') else []
    # Without children the object itself is the only reset candidate.
    targets = children if children else [nn]
    for target in targets:
        if hasattr(target, 'reset_parameters'):
            target.reset_parameters()

In [3]:
from sklearn.metrics import roc_auc_score, average_precision_score
from torch_geometric.utils import (negative_sampling, remove_self_loops, add_self_loops)
# Small offset added to probabilities before torch.log in GAE.recon_loss,
# so the logarithm is never taken of an exact zero.
EPS = 1e-15

class GAE(nn.Module):
    """Graph autoencoder: an encoder producing node embeddings plus an edge decoder.

    NOTE: defining this class rebinds the name ``GAE``, which this notebook
    also imports from ``torch_geometric.nn``; after this cell runs, ``GAE``
    refers to the class below.

    Args:
        encoder: module called by :meth:`encode` to produce the embeddings ``z``.
        decoder: optional module called by :meth:`decode` and by the loss and
            evaluation helpers as ``decoder(z, edge_index, sigmoid=True)``.
            When ``None`` (the default) an ``InnerProductDecoder`` is used.
    """

    def __init__(self, encoder, decoder=None):
        super(GAE, self).__init__()
        self.encoder = encoder
        # A caller-supplied decoder used to be silently discarded; honour it
        # and only fall back to the inner-product decoder when none is given.
        self.decoder = InnerProductDecoder() if decoder is None else decoder
        # Invoke this class's implementation explicitly rather than via ``self``.
        GAE.reset_parameters(self)

    def reset_parameters(self):
        """Re-initialise the encoder and the decoder through ``reset``."""
        reset(self.encoder)
        reset(self.decoder)

    def encode(self, *args, **kwargs):
        """Forward all arguments to the encoder and return its output."""
        return self.encoder(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward all arguments to the decoder and return its output."""
        return self.decoder(*args, **kwargs)

    def recon_loss(self, z, pos_edge_index, neg_edge_index=None):
        """Return the binary cross-entropy reconstruction loss for ``z``.

        The loss is the sum of ``-log(p + EPS)`` averaged over the positive
        edges and ``-log(1 - p + EPS)`` averaged over the negative edges,
        where ``p`` is the decoder output with ``sigmoid=True``.

        Args:
            z: node embeddings; ``z.size(0)`` is used as the number of nodes
                when negative edges are sampled.
            pos_edge_index: edges that should be reconstructed.
            neg_edge_index: edges that should not be reconstructed. When
                ``None`` they are drawn with ``negative_sampling``.

        Returns:
            Scalar tensor ``pos_loss + neg_loss``.
        """
        pos_prob = self.decoder(z, pos_edge_index, sigmoid=True)
        pos_loss = -torch.log(pos_prob + EPS).mean()

        if neg_edge_index is None:
            # No self-loops in negative samples: make every self-loop count
            # as an existing edge before sampling. This is only needed when
            # the negatives are actually sampled here.
            known_edges, _ = remove_self_loops(pos_edge_index)
            known_edges, _ = add_self_loops(known_edges)
            neg_edge_index = negative_sampling(known_edges, z.size(0))

        neg_prob = self.decoder(z, neg_edge_index, sigmoid=True)
        neg_loss = -torch.log(1 - neg_prob + EPS).mean()

        return pos_loss + neg_loss

    def test(self, z, pos_edge_index, neg_edge_index):
        """Score held-out edges and return ``(ROC-AUC, average precision)``.

        Positive edges are labelled 1 and negative edges 0; predictions are
        the decoder outputs with ``sigmoid=True``. Both are detached and
        moved to CPU NumPy arrays before being handed to scikit-learn.
        """
        pos_y = z.new_ones(pos_edge_index.size(1))
        neg_y = z.new_zeros(neg_edge_index.size(1))
        y = torch.cat([pos_y, neg_y], dim=0)

        pos_pred = self.decoder(z, pos_edge_index, sigmoid=True)
        neg_pred = self.decoder(z, neg_edge_index, sigmoid=True)
        pred = torch.cat([pos_pred, neg_pred], dim=0)

        y, pred = y.detach().cpu().numpy(), pred.detach().cpu().numpy()

        return roc_auc_score(y, pred), average_precision_score(y, pred)

class MolEncoder(nn.Module):
    """Encoder made of a single ``GCNConv`` layer.

    Args:
        in_channels: first argument passed to ``GCNConv``.
        out_channels: second argument passed to ``GCNConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply the graph convolution to ``x`` over ``edge_index``."""
        embedding = self.conv(x, edge_index)
        return embedding

class InnerProductDecoder(nn.Module):
    """Parameter-free decoder scoring edges by the dot product of embeddings."""

    def forward(self, z, edge_index, sigmoid = True):
        """Score each edge listed in ``edge_index``.

        Rows ``edge_index[0]`` and ``edge_index[1]`` select the two endpoint
        embeddings of every edge; their element-wise product is summed over
        ``dim=1``. The scores are passed through a sigmoid when ``sigmoid``
        is truthy, otherwise returned raw.
        """
        source, target = edge_index[0], edge_index[1]
        score = torch.sum(z[source] * z[target], dim=1)
        if sigmoid:
            return torch.sigmoid(score)
        return score

    def forward_all(self, z, sigmoid = True):
        """Score every node pair at once as ``z @ z.T`` (optionally sigmoided)."""
        gram = z @ z.t()
        if sigmoid:
            return torch.sigmoid(gram)
        return gram


### Training

In [None]:
def train_gae(gae, opt, x, train_pos_edge_index):
    """Run one optimisation step of ``gae`` and return the loss as a float.

    Args:
        gae: model exposing ``train()``, ``encode(x, edge_index)`` and
            ``recon_loss(z, edge_index)``.
        opt: optimizer whose gradients are zeroed before, and which is
            stepped after, the backward pass.
        x: node features handed to ``gae.encode``.
        train_pos_edge_index: edges used both for encoding and for the
            reconstruction loss.
    """
    gae.train()
    opt.zero_grad()

    latent = gae.encode(x, train_pos_edge_index)
    recon = gae.recon_loss(latent, train_pos_edge_index)

    recon.backward()
    opt.step()
    return float(recon)

def test_gae(gae, x, train_pos_edge_index, test_pos_edge_index, test_neg_edge_index):
    """Evaluate ``gae`` on held-out edges and return whatever ``gae.test`` returns.

    The model is switched to eval mode and the embeddings are computed from
    the training edges with gradient tracking disabled; the held-out positive
    and negative edges are then passed to ``gae.test``.
    """
    gae.eval()
    with torch.no_grad():
        embedding = gae.encode(x, train_pos_edge_index)
    metrics = gae.test(embedding, test_pos_edge_index, test_neg_edge_index)
    return metrics

### Reactant Autoencoder

In [4]:
# model data
base_path = r'data/'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
val_ratio = 0
test_ratio = 0.2

# reactant data
r_dataset = ReactionDataset(base_path, geo_file = 'train_r')
r_data = r_dataset.data
# clear the node-level masks and labels before the edge split
r_data.train_mask = r_data.val_mask = r_data.test_mask = r_data.y = None
r_data = train_test_split_edges(data = r_data, val_ratio = val_ratio, test_ratio = test_ratio)
# Move the whole data object to the target device. Previously only `x` and a
# copy of the train edges were moved, while the training loop reads
# `r_data.*_edge_index`, which stayed on the CPU (device mismatch on CUDA).
r_data = r_data.to(device)
r_x = r_data.x
r_train_pos_edge_index = r_data.train_pos_edge_index

# reactant autoencoder
r_num_node_fs = r_data.num_node_features
r_latent_dim = 2
# The model must live on the same device as its inputs; move it before the
# optimizer is built so the optimizer tracks the on-device parameters.
r_ae = GAE(MolEncoder(r_num_node_fs, r_latent_dim)).to(device)
r_opt = torch.optim.Adam(r_ae.parameters(), lr = 0.01)

In [5]:
# Start from freshly initialised weights every time this cell is run.
# NOTE(review): r_opt keeps its Adam state across re-runs of this cell --
# consider re-creating the optimizer together with this reset.
r_ae.reset_parameters()

# Feed the model edge tensors that live on the same device as `r_x`. The
# original passed `r_data.*_edge_index` directly and left the prepared
# `r_train_pos_edge_index` unused. The held-out edges are moved once here
# rather than on every epoch.
# NOTE(review): r_ae itself must also be on `device` -- confirm where it is built.
r_test_pos_edge_index = r_data.test_pos_edge_index.to(device)
r_test_neg_edge_index = r_data.test_neg_edge_index.to(device)

epochs = 100
for epoch in range(1, epochs + 1):
    loss = train_gae(r_ae, r_opt, r_x, r_train_pos_edge_index)
    auc, ap = test_gae(r_ae, r_x, r_train_pos_edge_index, r_test_pos_edge_index, r_test_neg_edge_index)
    # report every 10th epoch only
    if epoch % 10 == 0:
        print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

Epoch: 010, AUC: 0.7424, AP: 0.6088
Epoch: 020, AUC: 0.7838, AP: 0.6450
Epoch: 030, AUC: 0.7891, AP: 0.6672
Epoch: 040, AUC: 0.7893, AP: 0.6733
Epoch: 050, AUC: 0.7941, AP: 0.6827
Epoch: 060, AUC: 0.8118, AP: 0.6922
Epoch: 070, AUC: 0.8162, AP: 0.6925
Epoch: 080, AUC: 0.8214, AP: 0.6945
Epoch: 090, AUC: 0.8159, AP: 0.6860
Epoch: 100, AUC: 0.8178, AP: 0.6866


### Product Autoencoder

In [6]:
# model data
base_path = r'data/'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
val_ratio = 0
test_ratio = 0.2

# product data
p_dataset = ReactionDataset(base_path, geo_file = 'train_p')
p_data = p_dataset.data
# clear the node-level masks and labels before the edge split
p_data.train_mask = p_data.val_mask = p_data.test_mask = p_data.y = None
p_data = train_test_split_edges(data = p_data, val_ratio = val_ratio, test_ratio = test_ratio)
# Move the whole data object to the target device. Previously only `x` and a
# copy of the train edges were moved, while the training loop reads
# `p_data.*_edge_index`, which stayed on the CPU (device mismatch on CUDA).
p_data = p_data.to(device)
p_x = p_data.x
p_train_pos_edge_index = p_data.train_pos_edge_index

# product autoencoder
p_num_node_fs = p_data.num_node_features
p_latent_dim = 2
# The model must live on the same device as its inputs; move it before the
# optimizer is built so the optimizer tracks the on-device parameters.
p_ae = GAE(MolEncoder(p_num_node_fs, p_latent_dim)).to(device)
p_opt = torch.optim.Adam(p_ae.parameters(), lr = 0.01)

In [29]:
# Start from freshly initialised weights every time this cell is run.
# NOTE(review): p_opt keeps its Adam state across re-runs of this cell --
# consider re-creating the optimizer together with this reset.
p_ae.reset_parameters()

# Feed the model edge tensors that live on the same device as `p_x`. The
# original passed `p_data.*_edge_index` directly and left the prepared
# `p_train_pos_edge_index` unused. The held-out edges are moved once here
# rather than on every epoch.
# NOTE(review): p_ae itself must also be on `device` -- confirm where it is built.
p_test_pos_edge_index = p_data.test_pos_edge_index.to(device)
p_test_neg_edge_index = p_data.test_neg_edge_index.to(device)

epochs = 100
for epoch in range(1, epochs + 1):
    loss = train_gae(p_ae, p_opt, p_x, p_train_pos_edge_index)
    auc, ap = test_gae(p_ae, p_x, p_train_pos_edge_index, p_test_pos_edge_index, p_test_neg_edge_index)
    # report every 10th epoch only
    if epoch % 10 == 0:
        print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

Epoch: 010, AUC: 0.5636, AP: 0.5001
Epoch: 020, AUC: 0.7952, AP: 0.6576
Epoch: 030, AUC: 0.4720, AP: 0.4673
Epoch: 040, AUC: 0.7494, AP: 0.7111
Epoch: 050, AUC: 0.8159, AP: 0.7243
Epoch: 060, AUC: 0.8428, AP: 0.7289
Epoch: 070, AUC: 0.8612, AP: 0.7397
Epoch: 080, AUC: 0.8670, AP: 0.7454
Epoch: 090, AUC: 0.8691, AP: 0.7465
Epoch: 100, AUC: 0.8713, AP: 0.7495
