In [1]:
import pandas as pd

import torch
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader
from torch.functional import F

from rdkit import Chem

from tqdm import tqdm

from models.gcn_molclr import GCN
from molclr import MolCLR
from dataset.dataset_contrastive import USPTO50_contrastive

  from .autonotebook import tqdm as notebook_tqdm


Please install apex for mixed precision training from: https://github.com/NVIDIA/apex


In [2]:
# Read the pickled USPTO-50 retrieval table into a dataframe.
# NOTE(review): unpickling can execute arbitrary code; only load this file
# from a trusted source.
uspto_triplets_dataset = pd.read_pickle('dataset/uspto_50_retrieval.pickle')

# Wrap the dataframe in the contrastive dataset class over every split,
# asking it to also return the index of each sample.
USPTO_triplets_dataclass = USPTO50_contrastive(
    uspto_triplets_dataset,
    split='all',
    return_index=True,
)

# Last expression of the cell: render the dataframe.
uspto_triplets_dataset

Unnamed: 0,reactants_mol,products_mol,reaction_type,set,exclude_indices
0,<rdkit.Chem.rdchem.Mol object at 0x7f7ea004eca0>,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bf5de40>,<RX_1>,train,"[0, 1]"
1,<rdkit.Chem.rdchem.Mol object at 0x7f7ea004e700>,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bf5de40>,<RX_1>,train,"[0, 1]"
2,<rdkit.Chem.rdchem.Mol object at 0x7f7ea004e660>,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bfede90>,<RX_6>,train,[2]
3,<rdkit.Chem.rdchem.Mol object at 0x7f7ea004e5c0>,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bfbdee0>,<RX_9>,train,"[3, 4]"
4,<rdkit.Chem.rdchem.Mol object at 0x7f7ea004e520>,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bfbdee0>,<RX_9>,train,"[3, 4]"
...,...,...,...,...,...
85533,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bfbdcb0>,<rdkit.Chem.rdchem.Mol object at 0x7f7e99b86f20>,<RX_7>,test,[85533]
85534,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bfbdd00>,<rdkit.Chem.rdchem.Mol object at 0x7f7e99b66f70>,<RX_10>,test,"[85534, 85535]"
85535,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bfbdd50>,<rdkit.Chem.rdchem.Mol object at 0x7f7e99b66f70>,<RX_10>,test,"[85534, 85535]"
85536,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bfedda0>,<rdkit.Chem.rdchem.Mol object at 0x7f7e99afefc0>,<RX_1>,test,"[85536, 85537]"


### Initialising the GCN model (loading of the finetuned weights is currently commented out, so the model keeps its initial weights)

In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the GCN encoder with feat_dim=512.
gcn_model = GCN(feat_dim=512)

# Checkpoint loading is intentionally left disabled here, so the model keeps
# the weights it was constructed with. Uncomment to load the checkpoint.
# gcn_model.load_state_dict(torch.load('ckpt/TripletMarginCosineDistanceCheckpoints/checkpoints/model.pth'))

# Switch to evaluation mode.
gcn_model.eval()

# Move the model to the chosen device; the returned module is the cell's
# displayed value.
gcn_model.to(device)

GCN(
  (x_embedding1): Embedding(119, 300)
  (x_embedding2): Embedding(3, 300)
  (gnns): ModuleList(
    (0-4): 5 x GCNConv()
  )
  (batch_norms): ModuleList(
    (0-4): 5 x BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (feat_lin): Linear(in_features=300, out_features=512, bias=True)
  (out_lin): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
  )
)

In [4]:
# Batch the dataset for inference. shuffle=False keeps the batches in dataset
# order, which the embedding cell relies on when it assigns the results back
# to the dataframe row by row.
uspto_graph_retrieval_dataloader = DataLoader(
    USPTO_triplets_dataclass,
    batch_size=32,
    shuffle=False,
    num_workers=16,
    pin_memory=True,
)

### Converting the reactant and product graphs to embeddings with the GCN model and storing them in the `reactants_embedding` and `products_embedding` columns

In [5]:
# Embed every anchor/positive graph with the GCN and store the vectors in the
# dataframe as the `products_embedding` / `reactants_embedding` columns.
#
# Anchor embeddings are written to `products_embedding` and positive
# embeddings to `reactants_embedding`.
# NOTE(review): this presumes the dataset yields the product graph as the
# anchor and the reactant graph as the positive - confirm against
# USPTO50_contrastive.
reactants_embedding = []
products_embedding = []

with torch.no_grad():
    # Each batch is (anchor, positive, negative, index); only the first two
    # are needed here.
    for anchor, positive, _negative, _index in tqdm(uspto_graph_retrieval_dataloader):
        # The DataLoader is already created with pin_memory=True, so no
        # explicit .pin_memory() call is made here: it would be redundant on
        # GPU machines and raises on machines without CUDA, breaking the CPU
        # fallback chosen for `device`.
        anchor = anchor.to(device, non_blocking=True)
        positive = positive.to(device, non_blocking=True)

        anchor_embedding = gcn_model(anchor)
        positive_embedding = gcn_model(positive)

        # Extending with a 2-D array appends one 1-D vector per sample.
        # No .detach() is needed because gradients are disabled above.
        reactants_embedding.extend(positive_embedding.cpu().numpy())
        products_embedding.extend(anchor_embedding.cpu().numpy())

# The loader does not shuffle, so the i-th embedding belongs to the i-th row.
# Fail with a clear message if the counts do not line up.
assert len(reactants_embedding) == len(products_embedding) == len(uspto_triplets_dataset), (
    "number of embeddings does not match the number of dataframe rows"
)

uspto_triplets_dataset['reactants_embedding'] = reactants_embedding
uspto_triplets_dataset['products_embedding'] = products_embedding

  0%|          | 0/2674 [00:00<?, ?it/s]

  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = torch.cat(values, dim=cat_dim or 0, out=out)
  value = 

In [7]:
# Persist the dataframe to a pickle file in the current working directory.
uspto_triplets_dataset.to_pickle(path='uspto50_random_model.pickle')

In [6]:
# Render the dataframe to inspect the embedding columns added above.
uspto_triplets_dataset

Unnamed: 0,reactants_mol,products_mol,reaction_type,set,exclude_indices,reactants_embedding,products_embedding
0,<rdkit.Chem.rdchem.Mol object at 0x7f7ea004eca0>,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bf5de40>,<RX_1>,train,"[0, 1]","[0.6397088, -1.6009799, -2.5030267, -0.1623832...","[0.7623692, -1.5720319, -2.6945984, -0.4495493..."
1,<rdkit.Chem.rdchem.Mol object at 0x7f7ea004e700>,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bf5de40>,<RX_1>,train,"[0, 1]","[0.77561533, -1.5849286, -2.7274234, -0.475469...","[0.762369, -1.5720319, -2.6945984, -0.449549, ..."
2,<rdkit.Chem.rdchem.Mol object at 0x7f7ea004e660>,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bfede90>,<RX_6>,train,[2],"[0.80680054, -1.7902905, -2.9642751, -0.377863...","[0.8163652, -1.8162354, -3.0039322, -0.3777662..."
3,<rdkit.Chem.rdchem.Mol object at 0x7f7ea004e5c0>,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bfbdee0>,<RX_9>,train,"[3, 4]","[0.6237804, -1.2752599, -2.1893618, -0.3953007...","[0.74899876, -1.5312753, -2.6279905, -0.453936..."
4,<rdkit.Chem.rdchem.Mol object at 0x7f7ea004e520>,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bfbdee0>,<RX_9>,train,"[3, 4]","[0.78942806, -1.6140804, -2.7687526, -0.471496...","[0.74899924, -1.5312757, -2.6279907, -0.453935..."
...,...,...,...,...,...,...,...
85533,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bfbdcb0>,<rdkit.Chem.rdchem.Mol object at 0x7f7e99b86f20>,<RX_7>,test,[85533],"[0.74037355, -1.5510249, -2.6315265, -0.395771...","[0.73333186, -1.5179448, -2.5880501, -0.417379..."
85534,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bfbdd00>,<rdkit.Chem.rdchem.Mol object at 0x7f7e99b66f70>,<RX_10>,test,"[85534, 85535]","[0.7819213, -1.6302356, -2.7852437, -0.4385047...","[0.760665, -1.5546442, -2.6778135, -0.47068048..."
85535,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bfbdd50>,<rdkit.Chem.rdchem.Mol object at 0x7f7e99b66f70>,<RX_10>,test,"[85534, 85535]","[0.34974918, -0.71762013, -1.2140527, -0.24230...","[0.7606651, -1.554644, -2.6778138, -0.47068062..."
85536,<rdkit.Chem.rdchem.Mol object at 0x7f7e9bfedda0>,<rdkit.Chem.rdchem.Mol object at 0x7f7e99afefc0>,<RX_1>,test,"[85536, 85537]","[0.77234185, -1.6063658, -2.7360032, -0.427054...","[0.7837186, -1.6214052, -2.7719252, -0.4481021..."
