In [10]:
import torch
import torch.nn
import pickle
from pathlib import Path
import numpy as np
from functools import partial
import sys

sys.path.insert(0, "/workspace/nbs")

from molbart.models.pre_train import BARTModel
from molbart.decode import DecodeSampler
from rdkit import Chem

ModuleNotFoundError: No module named 'molbart.models'

## Load Tokenizer and Model

The following code will load a pickled tokenizer and model checkpoint.

In [2]:
def load_tokenizer(tokenizer_path):
    """Read a pickled tokenizer back from disk.

       Params:
           tokenizer_path: str, location of the pickled tokenizer file

       Returns:
           the unpickled object (expected to be a MolEncTokeniser tokenizer)
    """

    # NOTE: unpickling runs code stored in the file, so only load trusted files.
    with Path(tokenizer_path).open('rb') as pickle_file:
        return pickle.load(pickle_file)


def load_model(model_checkpoint_path, tokenizer, max_seq_len):
    """Load saved model checkpoint and move it to the GPU

       Params:
           model_checkpoint_path: str, path to saved model checkpoint
           tokenizer: MolEncTokeniser tokenizer object
           max_seq_len: int, maximum sequence length passed to the decode sampler

       Returns:
           MolBART trained model, moved to CUDA with its sampler device set to "cuda"
    """

    sampler = DecodeSampler(tokenizer, max_seq_len)
    pad_token_idx = tokenizer.vocab[tokenizer.pad_token]

    # Load from the path given as an argument. The previous code read the
    # notebook-level `model_path` variable here and silently ignored the parameter.
    bart_model = BARTModel.load_from_checkpoint(model_checkpoint_path,
                                                decode_sampler=sampler,
                                                pad_token_idx=pad_token_idx)
    bart_model.sampler.device = "cuda"
    return bart_model.cuda()
    

# Locations of the pickled tokenizer and the pretrained model checkpoint
tokenizer_path = '/data/training_data/mol_opt_tokeniser.pickle'
model_path = '/data/training_data/az_molbart_pretrain.ckpt'

# Maximum sequence length handed to the decode sampler inside load_model
max_seq_len = 64
# `tokenizer` and `bart_model` are notebook-level objects that later cells read directly
tokenizer = load_tokenizer(tokenizer_path)
bart_model = load_model(model_path, tokenizer, max_seq_len)

## Interpolation Functions

The functions in this section are updated versions of the original interpolation functions, which are preserved in the "Previous Version" section at the end of this notebook. Use these versions instead of the originals. The key changes are the ability to set the padding length for SMILES tokens and batch-wise calculation of the interpolated embeddings. 

In [3]:
def smiles2embedding(smiles, tokenizer, pad_length=None):
    """Encode a single SMILES string and return its embedding with the padding mask

       Params
           smiles: string, input SMILES molecule
           tokenizer: MolEncTokeniser tokenizer object
           pad_length: optional int, total token length to pad to; must be at
                       least len(smiles) + 2. No extra padding when None (or 0).

       Returns
           tuple of (embedding returned by bart_model.encode, boolean padding mask)

       Note
           The encoder used is the notebook-level `bart_model`; it is not a parameter.
    """

    assert isinstance(smiles, str)
    extra_padding = bool(pad_length)
    if extra_padding:
        assert pad_length >= len(smiles) + 2

    tokenised = tokenizer.tokenise([smiles], pad=True)
    token_rows = tokenised['original_tokens']
    mask_rows = tokenised['pad_masks']

    # Extend every token row (and its mask) up to the requested length
    if extra_padding:
        for row_idx in range(len(token_rows)):
            shortfall = pad_length - len(token_rows[row_idx])
            token_rows[row_idx] += [tokenizer.pad_token] * shortfall
            mask_rows[row_idx] += [1] * shortfall

    id_rows = tokenizer.convert_tokens_to_ids(token_rows)
    token_ids = torch.tensor(id_rows).cuda().T
    pad_mask = torch.tensor(mask_rows).bool().cuda().T

    embedding = bart_model.encode({"encoder_input": token_ids, "encoder_pad_mask": pad_mask})
    torch.cuda.empty_cache()

    return embedding, pad_mask


def interpolate_molecules(smiles1, smiles2, num_interp, tokenizer, bart_model, k=1):
    """Decode molecules from points interpolated between two SMILES embeddings.

       Params
           smiles1: str, input SMILES molecule (interpolation start)
           smiles2: str, input SMILES molecule (interpolation end)
           num_interp: int, number of interpolation points strictly between the two inputs
           tokenizer: MolEncTokeniser tokenizer object
           bart_model: MolBART trained model
           k: number of molecules for beam search, default 1. Can increase if there are issues with validity

       Returns
           list of interpolated SMILES strings. For each interpolation point only the
           first decoded string that RDKit can parse and that is not already in the
           list is kept, so the list may be shorter than num_interp.
    """

    # Both inputs are padded to a common length; +2 covers the start / stop tokens
    common_length = 2 + max(len(smiles1), len(smiles2))
    emb_start, mask_start = smiles2embedding(smiles1, tokenizer, pad_length=common_length)
    emb_end, mask_end = smiles2embedding(smiles2, tokenizer, pad_length=common_length)

    # Drop the first and last weights (0 and 1): those are the input molecules themselves
    weights = torch.linspace(0.0, 1.0, num_interp + 2)[1:-1]
    weights = weights.unsqueeze(0).unsqueeze(-1).cuda()
    memories = torch.lerp(emb_start, emb_end, weights).permute(1, 0, 2).cuda()
    shared_mask = (mask_start & mask_end).bool().cuda()

    decode_batch_size = 1 # TODO: parallelize this loop as a batch
    accepted = []

    for memory in memories:
        decode_fn = partial(bart_model._decode_fn, mem_pad_mask=shared_mask, memory=memory)
        beams, _log_lhs = bart_model.sampler.beam_decode(decode_fn, batch_size=decode_batch_size, k=k)
        candidates = sum(beams, []) # flatten list

        # Lazily pick the first candidate that parses and has not been seen yet
        first_new_valid = next(
            (cand for cand in candidates
             if (Chem.MolFromSmiles(cand) is not None) and (cand not in accepted)),
            None,
        )
        if first_new_valid is not None:
            accepted.append(first_new_valid)

    return accepted

In [4]:
# Example inputs: the two endpoint molecules and the decoding settings
smiles1 = "CC(=O)OC1=CC=CC=C1C(=O)O"
smiles2 = "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
num_interp = 10  # number of interpolation points requested between the endpoints
k = 1  # beam width passed through to the beam search

# The returned list can be shorter than num_interp: invalid or repeated SMILES are skipped
interpolate_molecules(smiles1, smiles2, num_interp, tokenizer, bart_model, k=k)

['C1(=O)C(C)=CC(=O)C(C(=O)C)=C1',
 'C1(=O)C(C)=C(C)C(=O)C(C)=C1O',
 'C1(=O)C(C)=C(C)C(=O)C(C)=C1',
 'C1(=O)C(C)=C(C)C(=O)C(C)=C1C',
 'C1(C)=CC(=O)C(C)=C(C)C1=O',
 'C1(C)=C(C(C)C)C(=O)C(C)=CC1=O',
 'C1(C)=CC(=O)C(CCCCC)=C(C)C1',
 'C1(C)=C(CCCCCC)C(=O)C=C(C)C1=O']

## Previous Version

These are the previous versions from Rahul M. They have been preserved for testing only. The original interpolation function also contains several bugs, which are marked with `BUG` comments in the code below.

In [5]:
def is_valid_molecule(smiles):
    """Return True when RDKit can parse `smiles` into a molecule, otherwise False."""
    return Chem.MolFromSmiles(smiles) is not None

def ORIGsmiles2embedding(smiles1, smiles2):
    """Legacy helper: tokenise two SMILES strings together and encode them in one call.

       Relies on the notebook-level `tokenizer` and `bart_model` objects.

       Params
           smiles1: input SMILES molecule
           smiles2: input SMILES molecule

       Returns
           tuple of (embedding returned by bart_model.encode, boolean padding mask)
    """

    # Tokenise both molecules in one call and build the input tensors
    tokenised = tokenizer.tokenise([smiles1, smiles2], pad=True)
    id_rows = tokenizer.convert_tokens_to_ids(tokenised['original_tokens'])
    token_ids = torch.tensor(id_rows).cuda().T
    pad_mask = torch.tensor(tokenised['pad_masks']).bool().cuda().T

    # Run the encoder, then release cached GPU memory
    embedding = bart_model.encode({"encoder_input": token_ids, "encoder_pad_mask": pad_mask})
    torch.cuda.empty_cache()

    return embedding, pad_mask


def ORIGinterpolate_molecules(molecule1, molecule2):
    """Legacy interpolation routine, kept unchanged in behaviour for comparison.

       Encodes both molecules together, decodes one molecule at each of ten fixed
       interpolation weights, and prints every decoded string with its validity.
       The known bugs flagged in the comments are intentionally left in place.
       Relies on the notebook-level `bart_model`.

       Params
           molecule1: input SMILES molecule
           molecule2: input SMILES molecule

       Returns
           None (results are printed)
    """
    pair_emb, pair_mask = ORIGsmiles2embedding(molecule1, molecule2)
    emb_first = pair_emb[:, 0, :]
    emb_second = pair_emb[:, 1, :]

    # BUG -- will default to shortest mask instead of longest
    combined_mask = (pair_mask[:, 0] | pair_mask[:, 1]).unsqueeze(0).T

    # BUG -- assumes interpolation starts/ends at 0.1 and 0.9, respectively
    for weight in np.linspace(0.1, 0.9, num=10):
        blended = torch.lerp(emb_first, emb_second, torch.full_like(emb_first, weight))
        memory = blended.unsqueeze(0).permute(1, 0, 2)
        mem_mask = pair_mask.clone().cuda()  # never read; retained from the original code
        bart_model.sampler.device = "cuda"
        decode_fn = partial(bart_model._decode_fn, memory=memory.cuda(), mem_pad_mask=combined_mask.bool())
        decoded, _log_lhs = bart_model.sampler.beam_decode(decode_fn, 1, 1)
        for beam in decoded:
            top_smiles = beam[0]
            print("Generated molecule: " + str(top_smiles) + ", valid: " + str(is_valid_molecule(top_smiles)))

In [6]:
# Run the legacy implementation on the same smiles1 / smiles2 pair; it prints its results
ORIGinterpolate_molecules(smiles1, smiles2)

Generated molecule: C1(=O)C=CC(=O)C(CCC=C)=C1O, valid: True
Generated molecule: C1(=O)C=CC(=O)C(CCCC)=C1, valid: True
Generated molecule: C1(=O)C=CC(=O)C(CCCCC)=C1, valid: True
Generated molecule: C1(=O)C=CC(=O)C(C)=C1CCC, valid: True
Generated molecule: C1(C)=C(C)C(=O)C(C)=CC1, valid: True
Generated molecule: C1(CCCCCC)=CC(=O)C=CC1, valid: True
Generated molecule: C(CCCCC1=CC(=O)CC1)C=C, valid: True
Generated molecule: C1(C)=CCC(C)=CC1=CCCC, valid: True
Generated molecule: C1CC(C(=C)C)=CCC1CC=C, valid: True
Generated molecule: C1CCC(CCC=C(C)C)=CC1=C, valid: True
