In [7]:
import sys

# A `!export ...` line is executed in a short-lived subshell, so it never
# changes the import path of the running kernel. Extend sys.path directly.
# The second entry was originally spelled "/worspace/nbs/"; that is presumably
# a typo for "/workspace/nbs/" -- confirm against the container layout.
for extra_path in ("/opt/MolBART", "/workspace/nbs/", "."):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

from megatron.initialize import initialize_megatron
from megatron_molbart.train import setup_model_and_optimizer
from megatron import get_args

# NOTE(review): the recorded run of this cell ended with
# "ModuleNotFoundError: No module named 'megatron_bart'". If that module lives
# inside /opt/MolBART/megatron_molbart, that directory may also have to be on
# sys.path -- confirm.
initialize_megatron()
args = get_args()
(model, optimizer, lr_scheduler) = setup_model_and_optimizer(args)

ModuleNotFoundError: No module named 'megatron_bart'

In [5]:
# List the top level of the MolBART checkout to see which packages and scripts it ships.
!ls /opt/MolBART

README.md	megatron_conda_env.yml	   requirements.txt  train_megatron.sh
bart_vocab.txt	megatron_molbart	   setup.py
evaluate.sh	megatron_requirements.txt  test
fine_tune.sh	molbart			   train.sh


In [2]:
# NOTE(review): the recorded output of this cell is an argparse usage message
# for ipykernel_launcher.py followed by `SystemExit: 2`, so the call apparently
# parses the kernel's own command line and rejects it. The arguments probably
# have to be supplied explicitly before this can run in a notebook -- confirm
# against the Megatron version in use.
initialize_megatron()

usage: ipykernel_launcher.py [-h] [--num-layers NUM_LAYERS]
                             [--num-unique-layers NUM_UNIQUE_LAYERS]
                             [--param-sharing-style {grouped,spaced}]
                             [--hidden-size HIDDEN_SIZE]
                             [--num-attention-heads NUM_ATTENTION_HEADS]
                             [--max-position-embeddings MAX_POSITION_EMBEDDINGS]
                             [--make-vocab-size-divisible-by MAKE_VOCAB_SIZE_DIVISIBLE_BY]
                             [--layernorm-epsilon LAYERNORM_EPSILON]
                             [--apply-residual-connection-post-layernorm]
                             [--openai-gelu] [--onnx-safe ONNX_SAFE]
                             [--attention-dropout ATTENTION_DROPOUT]
                             [--hidden-dropout HIDDEN_DROPOUT]
                             [--weight-decay WEIGHT_DECAY]
                             [--clip-grad CLIP_GRAD] [--batch-size BATCH_SIZE]
                 

SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
# The original cell opened with an unfinished `from megatron_molbart.train import`
# statement (a SyntaxError) and never imported MolEncTokeniser, DecodeSampler
# or MegatronBART, all of which are used below.
# NOTE(review): assumes megatron_molbart.train exposes MegatronBART -- confirm.
from megatron_molbart.train import MegatronBART
from molbart.decode import DecodeSampler
# NOTE(review): assumed module path for MolEncTokeniser -- confirm.
from molbart.tokeniser import MolEncTokeniser
from molbart.util import DEFAULT_CHEM_TOKEN_START
from molbart.util import REGEX
from molbart.util import DEFAULT_VOCAB_PATH

# Build the tokenizer from the default vocabulary file.
tokenizer = MolEncTokeniser.from_vocab_file(DEFAULT_VOCAB_PATH, REGEX,
        DEFAULT_CHEM_TOKEN_START)

VOCAB_SIZE = len(tokenizer)
MAX_SEQ_LEN = 512
pad_token_idx = tokenizer.vocab[tokenizer.pad_token]
sampler = DecodeSampler(tokenizer, MAX_SEQ_LEN)

# `args` must already exist in the kernel (it is the result of get_args() after
# initialize_megatron() has run successfully).
model = MegatronBART(
    sampler,
    pad_token_idx,
    VOCAB_SIZE,
    args.hidden_size,
    args.num_layers,
    args.num_attention_heads,
    args.hidden_size * 4,  # NOTE(review): presumably the feed-forward width -- confirm
    MAX_SEQ_LEN,
    dropout=0.1,
    )

In [1]:
import torch
import torch.nn
import pickle
from pathlib import Path
import numpy as np
from functools import partial

from molbart.models import BARTModel
from molbart.decode import DecodeSampler
from rdkit import Chem

## Load Tokenizer and Model

The following code will load a pickled tokenizer and model checkpoint.

In [2]:
def load_tokenizer(tokenizer_path):
    """Deserialize a tokenizer object from a pickle file

       Params:
           tokenizer_path: str, location of the pickled tokenizer on disk

       Returns:
           The unpickled object (expected to be a MolEncTokeniser tokenizer)
    """

    # pickle.load can execute arbitrary code; only use files from a trusted source
    with Path(tokenizer_path).open('rb') as stream:
        return pickle.load(stream)


def load_model(model_checkpoint_path, tokenizer, max_seq_len):
    """Load saved model checkpoint and move it to the GPU

       Params:
           model_checkpoint_path: str, path to saved model checkpoint
           tokenizer: MolEncTokeniser tokenizer object
           max_seq_len: int, maximum sequence length

       Returns:
           MolBART trained model, moved to CUDA, with its sampler's device set to "cuda"
    """

    sampler = DecodeSampler(tokenizer, max_seq_len)
    pad_token_idx = tokenizer.vocab[tokenizer.pad_token]

    # Load from the path that was passed in. The previous code read the
    # notebook-level `model_path` global here, so the argument was ignored.
    bart_model = BARTModel.load_from_checkpoint(model_checkpoint_path,
                                                decode_sampler=sampler,
                                                pad_token_idx=pad_token_idx)
    bart_model.sampler.device = "cuda"
    return bart_model.cuda()
    

# Locations of the pickled tokenizer and of the pretrained model checkpoint.
tokenizer_path = '/data/training_data/mol_opt_tokeniser.pickle'
model_path = '/data/training_data/az_molbart_pretrain.ckpt'

# Maximum sequence length handed to load_model (used to build its DecodeSampler).
max_seq_len = 64
tokenizer = load_tokenizer(tokenizer_path)
# `bart_model` is also read as a notebook-level global by smiles2embedding.
bart_model = load_model(model_path, tokenizer, max_seq_len)

## Interpolation Functions

The functions in this section are updated versions of the original interpolation functions that appear later in this notebook; use these versions instead of the originals. The key changes versus the original functions are the ability to set the padding length for SMILES tokens and batch-wise calculation of the interpolated embeddings.

In [3]:
def smiles2embedding(smiles, tokenizer, pad_length=None, model=None):
    """Calculate embedding and padding mask for smiles with optional extra padding

       Params
           smiles: string, input SMILES molecule
           tokenizer: MolEncTokeniser tokenizer object
           pad_length: optional int, total token length to pad the sequence to;
                       must be at least len(smiles) + 2. No extra padding is
                       added when it is None (or 0)
           model: optional MolBART model used for encoding; when omitted the
                  notebook-level `bart_model` is used, as before

       Returns
           embedding array and boolean mask
    """

    assert isinstance(smiles, str)
    if pad_length:
        assert pad_length >= len(smiles) + 2

    # Allow callers to pass the encoder explicitly instead of relying on a
    # hidden global; the global stays the default for backward compatibility.
    encoder = bart_model if model is None else model

    tokens = tokenizer.tokenise([smiles], pad=True)

    # Append to tokens and mask if appropriate
    if pad_length:
        for i in range(len(tokens['original_tokens'])):
            n_pad = pad_length - len(tokens['original_tokens'][i])
            tokens['original_tokens'][i] += [tokenizer.pad_token] * n_pad
            tokens['pad_masks'][i] += [1] * n_pad

    # Both tensors are transposed before being handed to the encoder
    # (presumably to a sequence-first layout -- confirm against BARTModel.encode).
    token_ids = torch.tensor(tokenizer.convert_tokens_to_ids(tokens['original_tokens'])).cuda().T
    pad_mask = torch.tensor(tokens['pad_masks']).bool().cuda().T
    encode_input = {"encoder_input": token_ids, "encoder_pad_mask": pad_mask}

    embedding = encoder.encode(encode_input)
    torch.cuda.empty_cache()

    return embedding, pad_mask


def interpolate_molecules(smiles1, smiles2, num_interp, tokenizer, bart_model, k=1):
    """Decode molecules along the straight line between two embeddings.

       Params
           smiles1: str, SMILES of the starting molecule
           smiles2: str, SMILES of the ending molecule
           num_interp: int, number of intermediate points to decode
           tokenizer: MolEncTokeniser tokenizer object
           bart_model: MolBART trained model
           k: beam width for decoding, default 1. Raising it gives more candidates
              per point when validity is a problem

       Returns
           list of SMILES strings, at most one per interpolation point; a point
           contributes nothing when none of its candidates is both parseable by
           RDKit and not already in the list
    """

    # Pad both inputs to one shared length; the extra 2 covers start / stop
    common_len = 2 + max(len(smiles1), len(smiles2))
    emb_a, mask_a = smiles2embedding(smiles1, tokenizer, pad_length=common_len)
    emb_b, mask_b = smiles2embedding(smiles2, tokenizer, pad_length=common_len)

    # Interior points only: the two endpoints are the input molecules themselves
    weights = torch.linspace(0.0, 1.0, num_interp + 2)[1:-1]
    weights = weights.unsqueeze(0).unsqueeze(-1).cuda()
    blended = torch.lerp(emb_a, emb_b, weights).permute(1, 0, 2).cuda()
    shared_mask = (mask_a & mask_b).bool().cuda()

    accepted = []

    # TODO: parallelize this loop as a batch (each point is decoded with batch_size=1)
    for memory in blended:
        step_fn = partial(bart_model._decode_fn, mem_pad_mask=shared_mask, memory=memory)
        decoded, _ = bart_model.sampler.beam_decode(step_fn, batch_size=1, k=k)
        candidates = [smi for group in decoded for smi in group]

        # Keep the first candidate that RDKit can parse and that is new
        pick = next(
            (smi for smi in candidates
             if Chem.MolFromSmiles(smi) is not None and smi not in accepted),
            None,
        )
        if pick is not None:
            accepted.append(pick)

    return accepted

In [4]:
# The two endpoints of the interpolation, given as SMILES strings.
smiles1 = "CC(=O)OC1=CC=CC=C1C(=O)O"
smiles2 = "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
num_interp = 10  # number of intermediate points requested
k = 1  # beam width forwarded to the decoder

# The result may hold fewer than num_interp entries: a point is dropped when
# none of its decoded candidates is both valid and not already in the list.
interpolate_molecules(smiles1, smiles2, num_interp, tokenizer, bart_model, k=k)

['C1(=O)C(C)=CC(=O)C(C(=O)C)=C1',
 'C1(=O)C(C)=C(C)C(=O)C(C)=C1O',
 'C1(=O)C(C)=C(C)C(=O)C(C)=C1',
 'C1(=O)C(C)=C(C)C(=O)C(C)=C1C',
 'C1(C)=CC(=O)C(C)=C(C)C1=O',
 'C1(C)=C(C(C)C)C(=O)C(C)=CC1=O',
 'C1(C)=CC(=O)C(CCCCC)=C(C)C1',
 'C1(C)=C(CCCCCC)C(=O)C=C(C)C1=O']