In [None]:
# default_exp model_zoo

# Model Zoo

> Standard Pretrained Models

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
# export
from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.vocab import *
from mrl.g_models import *
from mrl.agent import *

from torch.utils.model_zoo import load_url
from torch.hub import download_url_to_file

  return f(*args, **kwds)


In [None]:
# export 

# Base URL that the pretrained weight file names in this module are appended to
# (e.g. f'{S3_PREFIX}/lstmlm_small_zinc.pt') before being fetched with `load_url`.
S3_PREFIX = 'https://dmai-mrl.s3.amazonaws.com/mrl_public'

## LSTM LM

In [None]:
# export

def lstm_lm_small(vocab, drop=True):
    """Build the small `LSTM_LM` preset sized for `vocab`.

    The vocabulary size is `len(vocab.itos)` and the start-token index is
    `vocab.stoi['bos']`. The preset uses a 256-wide embedding, a hidden size
    of 1024 and 3 layers, is not bidirectional and ties weights.

    `drop`: when truthy, both the input and LSTM dropout are 0.3; otherwise 0.

    Returns the constructed `LSTM_LM`.
    """
    # A single switch drives both dropout rates.
    dropout = 0.3 if drop else 0.

    return LSTM_LM(len(vocab.itos),     # d_vocab
                   256,                 # d_embedding
                   1024,                # d_hidden
                   3,                   # n_layers
                   dropout,             # input_dropout
                   dropout,             # lstm_dropout
                   vocab.stoi['bos'],   # bos_idx
                   False,               # bidir
                   True)                # tie_weights

def lstm_lm_large(vocab, drop=True):
    """Build the large `LSTM_LM` preset sized for `vocab`.

    The vocabulary size is `len(vocab.itos)` and the start-token index is
    `vocab.stoi['bos']`. The preset uses a 400-wide embedding, a hidden size
    of 1552 and 5 layers, is not bidirectional and ties weights.

    `drop`: when truthy, both the input and LSTM dropout are 0.3; otherwise 0.

    Returns the constructed `LSTM_LM`.
    """
    # A single switch drives both dropout rates.
    dropout = 0.3 if drop else 0.

    return LSTM_LM(len(vocab.itos),     # d_vocab
                   400,                 # d_embedding
                   1552,                # d_hidden
                   5,                   # n_layers
                   dropout,             # input_dropout
                   dropout,             # lstm_dropout
                   vocab.stoi['bos'],   # bos_idx
                   False,               # bidir
                   True)                # tie_weights

In [None]:
# Smoke test: each LSTM LM preset must build into an `nn.Module`.
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
for builder in (lstm_lm_small, lstm_lm_large):
    model = builder(vocab)
    assert isinstance(model, nn.Module)

In [None]:
# export

class LSTM_LM_Small_ZINC(GenerativeAgent):
    """`GenerativeAgent` around `lstm_lm_small` initialised from `lstmlm_small_zinc.pt`.

    The vocabulary is a `CharacterVocab` over `SMILES_CHAR_VOCAB`, the loss is
    `CrossEntropy()`, and the state dict is fetched from
    `{S3_PREFIX}/lstmlm_small_zinc.pt` with `load_url`, mapped to the CPU.

    `base_update`, `base_update_iter`, `base_model`, `opt_kwargs`, `clip` and
    `name` are forwarded to `GenerativeAgent.__init__`. Leaving `opt_kwargs`
    as `None` forwards a fresh empty dict.
    """
    def __init__(self,
                 base_update=0.97,
                 base_update_iter=5,
                 base_model=True,
                 opt_kwargs=None,
                 clip=1.,
                 name='lstmlm_small_zinc'
                ):

        # A `{}` default is a single dict object shared by every call;
        # build a new one per instance instead.
        if opt_kwargs is None:
            opt_kwargs = {}

        vocab = CharacterVocab(SMILES_CHAR_VOCAB)
        model = lstm_lm_small(vocab)
        location = f'{S3_PREFIX}/lstmlm_small_zinc.pt'
        model.load_state_dict(load_url(location, map_location='cpu'))
        loss_function = CrossEntropy()

        # Single-item dataset. NOTE(review): presumably a placeholder so the
        # agent can be built without training data - confirm against GenerativeAgent.
        dataset = Text_Dataset(['C'], vocab)

        super().__init__(model,
                         vocab,
                         loss_function,
                         dataset,
                         base_update=base_update,
                         base_update_iter=base_update_iter,
                         base_model=base_model,
                         opt_kwargs=opt_kwargs,
                         clip=clip,
                         name=name
                         )

In [None]:
# slow

agent = LSTM_LM_Small_ZINC()

# Smoke test of the pretrained weights: draw samples, decode them, and require
# that more than 80 of the decoded entries survive the `None` filter.
# NOTE(review): the two 100s are presumably batch size and max length - confirm
# against `sample_no_grad`.
preds, _ = agent.model.sample_no_grad(100, 100)
smiles = agent.reconstruct(preds)
mols = [mol for mol in to_mols(smiles) if mol is not None]
assert len(mols) > 80

In [None]:
# export

class LSTM_LM_Small_Chembl(GenerativeAgent):
    """`GenerativeAgent` around `lstm_lm_small` initialised from `lstmlm_small_chembl.pt`.

    The vocabulary is a `CharacterVocab` over `SMILES_CHAR_VOCAB`, the loss is
    `CrossEntropy()`, and the state dict is fetched from
    `{S3_PREFIX}/lstmlm_small_chembl.pt` with `load_url`, mapped to the CPU.

    `base_update`, `base_update_iter`, `base_model`, `opt_kwargs`, `clip` and
    `name` are forwarded to `GenerativeAgent.__init__`. Leaving `opt_kwargs`
    as `None` forwards a fresh empty dict.
    """
    def __init__(self,
                 base_update=0.97,
                 base_update_iter=5,
                 base_model=True,
                 opt_kwargs=None,
                 clip=1.,
                 name='lstmlm_small_chembl'
                ):

        # A `{}` default is a single dict object shared by every call;
        # build a new one per instance instead.
        if opt_kwargs is None:
            opt_kwargs = {}

        vocab = CharacterVocab(SMILES_CHAR_VOCAB)
        model = lstm_lm_small(vocab)
        location = f'{S3_PREFIX}/lstmlm_small_chembl.pt'
        model.load_state_dict(load_url(location, map_location='cpu'))
        loss_function = CrossEntropy()

        # Single-item dataset. NOTE(review): presumably a placeholder so the
        # agent can be built without training data - confirm against GenerativeAgent.
        dataset = Text_Dataset(['C'], vocab)

        super().__init__(model,
                         vocab,
                         loss_function,
                         dataset,
                         base_update=base_update,
                         base_update_iter=base_update_iter,
                         base_model=base_model,
                         opt_kwargs=opt_kwargs,
                         clip=clip,
                         name=name
                         )

In [None]:
# slow

agent = LSTM_LM_Small_Chembl()

# Smoke test of the pretrained weights: draw samples, decode them, and require
# that more than 80 of the decoded entries survive the `None` filter.
# NOTE(review): the two 100s are presumably batch size and max length - confirm
# against `sample_no_grad`.
preds, _ = agent.model.sample_no_grad(100, 100)
smiles = agent.reconstruct(preds)
mols = [mol for mol in to_mols(smiles) if mol is not None]
assert len(mols) > 80

In [None]:
# export

class LSTM_LM_Small_ZINC_NC(GenerativeAgent):
    """`GenerativeAgent` around `lstm_lm_small` initialised from `lstmlm_small_zinc_nc.pt`.

    The vocabulary is a `CharacterVocab` over `SMILES_CHAR_VOCAB` that uses
    `remove_stereo` as both `prefunc` and `postfunc`. The loss is
    `CrossEntropy()`, and the state dict is fetched from
    `{S3_PREFIX}/lstmlm_small_zinc_nc.pt` with `load_url`, mapped to the CPU.

    `base_update`, `base_update_iter`, `base_model`, `opt_kwargs`, `clip` and
    `name` are forwarded to `GenerativeAgent.__init__`. Leaving `opt_kwargs`
    as `None` forwards a fresh empty dict.
    """
    def __init__(self,
                 base_update=0.97,
                 base_update_iter=5,
                 base_model=True,
                 opt_kwargs=None,
                 clip=1.,
                 name='lstmlm_small_zinc_nc'
                ):

        # A `{}` default is a single dict object shared by every call;
        # build a new one per instance instead.
        if opt_kwargs is None:
            opt_kwargs = {}

        vocab = CharacterVocab(SMILES_CHAR_VOCAB, prefunc=remove_stereo, postfunc=remove_stereo)
        model = lstm_lm_small(vocab)
        location = f'{S3_PREFIX}/lstmlm_small_zinc_nc.pt'
        model.load_state_dict(load_url(location, map_location='cpu'))
        loss_function = CrossEntropy()

        # Single-item dataset. NOTE(review): presumably a placeholder so the
        # agent can be built without training data - confirm against GenerativeAgent.
        dataset = Text_Dataset(['C'], vocab)

        super().__init__(model,
                         vocab,
                         loss_function,
                         dataset,
                         base_update=base_update,
                         base_update_iter=base_update_iter,
                         base_model=base_model,
                         opt_kwargs=opt_kwargs,
                         clip=clip,
                         name=name
                         )

In [None]:
# slow

agent = LSTM_LM_Small_ZINC_NC()

# Smoke test of the pretrained weights: draw samples, decode them, and require
# that more than 80 of the decoded entries survive the `None` filter.
# NOTE(review): the two 100s are presumably batch size and max length - confirm
# against `sample_no_grad`.
preds, _ = agent.model.sample_no_grad(100, 100)
smiles = agent.reconstruct(preds)
mols = [mol for mol in to_mols(smiles) if mol is not None]
assert len(mols) > 80

In [None]:
# export

class LSTM_LM_Small_Chembl_NC(GenerativeAgent):
    """`GenerativeAgent` around `lstm_lm_small` initialised from `lstmlm_small_chembl_nc.pt`.

    The vocabulary is a `CharacterVocab` over `SMILES_CHAR_VOCAB` that uses
    `remove_stereo` as both `prefunc` and `postfunc`. The loss is
    `CrossEntropy()`, and the state dict is fetched from
    `{S3_PREFIX}/lstmlm_small_chembl_nc.pt` with `load_url`, mapped to the CPU.

    `base_update`, `base_update_iter`, `base_model`, `opt_kwargs`, `clip` and
    `name` are forwarded to `GenerativeAgent.__init__`. Leaving `opt_kwargs`
    as `None` forwards a fresh empty dict.
    """
    def __init__(self,
                 base_update=0.97,
                 base_update_iter=5,
                 base_model=True,
                 opt_kwargs=None,
                 clip=1.,
                 name='lstmlm_small_chembl_nc'
                ):

        # A `{}` default is a single dict object shared by every call;
        # build a new one per instance instead.
        if opt_kwargs is None:
            opt_kwargs = {}

        vocab = CharacterVocab(SMILES_CHAR_VOCAB, prefunc=remove_stereo, postfunc=remove_stereo)
        model = lstm_lm_small(vocab)
        location = f'{S3_PREFIX}/lstmlm_small_chembl_nc.pt'
        model.load_state_dict(load_url(location, map_location='cpu'))
        loss_function = CrossEntropy()

        # Single-item dataset. NOTE(review): presumably a placeholder so the
        # agent can be built without training data - confirm against GenerativeAgent.
        dataset = Text_Dataset(['C'], vocab)

        super().__init__(model,
                         vocab,
                         loss_function,
                         dataset,
                         base_update=base_update,
                         base_update_iter=base_update_iter,
                         base_model=base_model,
                         opt_kwargs=opt_kwargs,
                         clip=clip,
                         name=name
                         )

In [None]:
# slow

agent = LSTM_LM_Small_Chembl_NC()

# Smoke test of the pretrained weights: draw samples, decode them, and require
# that more than 80 of the decoded entries survive the `None` filter.
# NOTE(review): the two 100s are presumably batch size and max length - confirm
# against `sample_no_grad`.
preds, _ = agent.model.sample_no_grad(100, 100)
smiles = agent.reconstruct(preds)
mols = [mol for mol in to_mols(smiles) if mol is not None]
assert len(mols) > 80

In [None]:
# export
        
class LSTM_LM_Small_ZINC_Selfies(GenerativeAgent):
    """`GenerativeAgent` around `lstm_lm_small` initialised from `lstmlm_small_zinc_selfies.pt`.

    The vocabulary is a `FuncVocab` over `SELFIES_VOCAB` that tokenizes with
    `split_selfie`, with `smile_to_selfie` as `prefunc` and `selfie_to_smile`
    as `postfunc`. The loss is `CrossEntropy()`, and the state dict is fetched
    from `{S3_PREFIX}/lstmlm_small_zinc_selfies.pt` with `load_url`, mapped to
    the CPU.

    `base_update`, `base_update_iter`, `base_model`, `opt_kwargs`, `clip` and
    `name` are forwarded to `GenerativeAgent.__init__`. Leaving `opt_kwargs`
    as `None` forwards a fresh empty dict.
    """
    def __init__(self,
                 base_update=0.97,
                 base_update_iter=5,
                 base_model=True,
                 opt_kwargs=None,
                 clip=1.,
                 name='lstmlm_small_zinc_selfies'
                ):

        # A `{}` default is a single dict object shared by every call;
        # build a new one per instance instead.
        if opt_kwargs is None:
            opt_kwargs = {}

        vocab = FuncVocab(SELFIES_VOCAB, split_selfie,
                          prefunc=smile_to_selfie, postfunc=selfie_to_smile)
        model = lstm_lm_small(vocab)
        location = f'{S3_PREFIX}/lstmlm_small_zinc_selfies.pt'
        model.load_state_dict(load_url(location, map_location='cpu'))
        loss_function = CrossEntropy()

        # Single-item dataset. NOTE(review): presumably a placeholder so the
        # agent can be built without training data - confirm against GenerativeAgent.
        dataset = Text_Dataset(['C'], vocab)

        super().__init__(model,
                         vocab,
                         loss_function,
                         dataset,
                         base_update=base_update,
                         base_update_iter=base_update_iter,
                         base_model=base_model,
                         opt_kwargs=opt_kwargs,
                         clip=clip,
                         name=name
                         )

In [None]:
# slow

agent = LSTM_LM_Small_ZINC_Selfies()

# Smoke test of the pretrained weights: draw samples, decode them, and require
# that more than 80 of the decoded entries survive the `None` filter.
# NOTE(review): the two 100s are presumably batch size and max length - confirm
# against `sample_no_grad`.
preds, _ = agent.model.sample_no_grad(100, 100)
smiles = agent.reconstruct(preds)
mols = [mol for mol in to_mols(smiles) if mol is not None]
assert len(mols) > 80

In [None]:
# export
        
class LSTM_LM_Small_Chembl_Selfies(GenerativeAgent):
    """`GenerativeAgent` around `lstm_lm_small` initialised from `lstmlm_small_chembl_selfies.pt`.

    The vocabulary is a `FuncVocab` over `SELFIES_VOCAB` that tokenizes with
    `split_selfie`, with `smile_to_selfie` as `prefunc` and `selfie_to_smile`
    as `postfunc`. The loss is `CrossEntropy()`, and the state dict is fetched
    from `{S3_PREFIX}/lstmlm_small_chembl_selfies.pt` with `load_url`, mapped
    to the CPU.

    `base_update`, `base_update_iter`, `base_model`, `opt_kwargs`, `clip` and
    `name` are forwarded to `GenerativeAgent.__init__`. Leaving `opt_kwargs`
    as `None` forwards a fresh empty dict.
    """
    def __init__(self,
                 base_update=0.97,
                 base_update_iter=5,
                 base_model=True,
                 opt_kwargs=None,
                 clip=1.,
                 name='lstmlm_small_chembl_selfies'
                ):

        # A `{}` default is a single dict object shared by every call;
        # build a new one per instance instead.
        if opt_kwargs is None:
            opt_kwargs = {}

        vocab = FuncVocab(SELFIES_VOCAB, split_selfie,
                          prefunc=smile_to_selfie, postfunc=selfie_to_smile)
        model = lstm_lm_small(vocab)
        location = f'{S3_PREFIX}/lstmlm_small_chembl_selfies.pt'
        model.load_state_dict(load_url(location, map_location='cpu'))
        loss_function = CrossEntropy()

        # Single-item dataset. NOTE(review): presumably a placeholder so the
        # agent can be built without training data - confirm against GenerativeAgent.
        dataset = Text_Dataset(['C'], vocab)

        super().__init__(model,
                         vocab,
                         loss_function,
                         dataset,
                         base_update=base_update,
                         base_update_iter=base_update_iter,
                         base_model=base_model,
                         opt_kwargs=opt_kwargs,
                         clip=clip,
                         name=name
                         )

In [None]:
# slow

agent = LSTM_LM_Small_Chembl_Selfies()

# Smoke test of the pretrained weights: draw samples, decode them, and require
# that more than 80 of the decoded entries survive the `None` filter.
# NOTE(review): the two 100s are presumably batch size and max length - confirm
# against `sample_no_grad`.
preds, _ = agent.model.sample_no_grad(100, 100)
smiles = agent.reconstruct(preds)
mols = [mol for mol in to_mols(smiles) if mol is not None]
assert len(mols) > 80

## Conditional LSTM LM

In [None]:
# export

def cond_lstm_small(vocab, encoder, drop=True):
    """Build the small `Conditional_LSTM_LM` preset around `encoder`.

    The vocabulary size is `len(vocab.itos)` and the start-token index is
    `vocab.stoi['bos']`. The preset uses a 512-wide latent, a 256-wide
    embedding, a hidden size of 1024 and 3 layers; the latent is normalised
    and conditions the hidden state but not the output.

    `drop`: when truthy, both the input and LSTM dropout are 0.3; otherwise 0.

    Returns the constructed `Conditional_LSTM_LM`.
    """
    # A single switch drives both dropout rates.
    dropout = 0.3 if drop else 0.

    return Conditional_LSTM_LM(encoder,
                               len(vocab.itos),     # d_vocab
                               256,                 # d_embedding
                               1024,                # d_hidden
                               512,                 # d_latent
                               3,                   # n_layers
                               dropout,             # input_dropout
                               dropout,             # lstm_dropout
                               True,                # norm_latent
                               True,                # condition_hidden
                               False,               # condition_output
                               vocab.stoi['bos'])   # bos_idx

def cond_lstm_large(vocab, encoder, drop=True):
    """Build the large `Conditional_LSTM_LM` preset around `encoder`.

    The vocabulary size is `len(vocab.itos)` and the start-token index is
    `vocab.stoi['bos']`. The preset uses a 512-wide latent, a 400-wide
    embedding, a hidden size of 1552 and 5 layers; the latent is normalised
    and conditions the hidden state but not the output.

    `drop`: when truthy, both the input and LSTM dropout are 0.3; otherwise 0.

    Returns the constructed `Conditional_LSTM_LM`.
    """
    # A single switch drives both dropout rates.
    dropout = 0.3 if drop else 0.

    return Conditional_LSTM_LM(encoder,
                               len(vocab.itos),     # d_vocab
                               400,                 # d_embedding
                               1552,                # d_hidden
                               512,                 # d_latent
                               5,                   # n_layers
                               dropout,             # input_dropout
                               dropout,             # lstm_dropout
                               True,                # norm_latent
                               True,                # condition_hidden
                               False,               # condition_output
                               vocab.stoi['bos'])   # bos_idx

def mlp_cond_lstm_small(vocab, drop=True):
    """Build `cond_lstm_small` with a two-layer `MLP_Encoder` as its encoder.

    `drop`: when truthy, the encoder dropouts are 0.1 each, otherwise 0. The
    flag is also forwarded to `cond_lstm_small`.

    NOTE(review): the encoder arguments are presumably (input width, hidden
    widths, output width, dropouts) - confirm against `MLP_Encoder`.
    """
    enc_drops = [0.1, 0.1] if drop else [0., 0.]
    return cond_lstm_small(vocab,
                           MLP_Encoder(2048, [1024, 512], 512, enc_drops),
                           drop=drop)

def mlp_cond_lstm_large(vocab, drop=True):
    """Build `cond_lstm_large` with a four-layer `MLP_Encoder` as its encoder.

    `drop`: when truthy, the encoder dropouts are 0.2 each, otherwise 0. The
    flag is also forwarded to `cond_lstm_large`, so `drop=False` disables
    dropout in both the encoder and the LSTM.

    NOTE(review): the encoder arguments are presumably (input width, hidden
    widths, output width, dropouts) - confirm against `MLP_Encoder`.
    """
    if drop:
        enc_drops = [0.2, 0.2, 0.2, 0.2]
    else:
        enc_drops = [0., 0., 0., 0.]

    # Use the selected dropouts (not a hard-coded list) and the large decoder
    # preset, and pass `drop` through so the decoder honours it too.
    encoder = MLP_Encoder(2048, [1024, 512, 512, 512], 512, enc_drops)
    return cond_lstm_large(vocab, encoder, drop=drop)

In [None]:
# Smoke test: each MLP-conditioned LSTM preset must build into an `nn.Module`.
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
for builder in (mlp_cond_lstm_small, mlp_cond_lstm_large):
    model = builder(vocab)
    assert isinstance(model, nn.Module)

## VAE

In [None]:
# export

def mlp_vae(vocab, drop=True):
    """Build the `MLP_VAE` preset sized for `vocab`.

    The vocabulary size is `len(vocab.itos)` and the start-token index is
    `vocab.stoi['bos']`. The encoder takes a 2048-wide input through
    `[1024, 512]` dims; the decoder uses a 256-wide embedding, a hidden size
    of 1024, 3 layers and a 512-wide latent that conditions both the hidden
    state and the output.

    `drop`: when truthy, the encoder dropouts are 0.2 and the input/LSTM
    dropouts are 0.3; otherwise all dropouts are 0.

    Returns the constructed `MLP_VAE`.
    """
    encoder_drops = [0.2, 0.2] if drop else [0., 0.]
    dropout = 0.3 if drop else 0.

    return MLP_VAE(
        len(vocab.itos),    # d_vocab
        256,                # d_embedding
        2048,               # encoder_d_in
        [1024, 512],        # encoder_dims
        encoder_drops,
        1024,               # d_hidden
        3,                  # n_layers
        512,                # d_latent
        input_dropout=dropout,
        lstm_dropout=dropout,
        condition_hidden=True,
        condition_output=True,
        bos_idx=vocab.stoi['bos'],
    )
    
def conv_vae(vocab, drop=True):
    """Build the `Conv_VAE` preset sized for `vocab`.

    The vocabulary size is `len(vocab.itos)` and the start-token index is
    `vocab.stoi['bos']`. The encoder uses three conv stages with filters
    `[256, 512, 512]`, kernel size 7 and stride 2; the decoder uses a
    256-wide embedding, a hidden size of 1024, 3 layers and a 512-wide latent
    that conditions both the hidden state and the output.

    `drop`: when truthy, the conv dropouts are 0.2 and the input/LSTM
    dropouts are 0.3; otherwise all dropouts are 0.

    Returns the constructed `Conv_VAE`.
    """
    conv_drops = [0.2, 0.2, 0.2] if drop else [0., 0., 0.]
    dropout = 0.3 if drop else 0.

    return Conv_VAE(
        len(vocab.itos),    # d_vocab
        256,                # d_embedding
        [256, 512, 512],    # conv_filters
        [7, 7, 7],          # kernel_sizes
        [2, 2, 2],          # strides
        conv_drops,
        1024,               # d_hidden
        3,                  # n_layers
        512,                # d_latent
        input_dropout=dropout,
        lstm_dropout=dropout,
        condition_hidden=True,
        condition_output=True,
        bos_idx=vocab.stoi['bos'])
    
def lstm_vae(vocab, drop=True):
    """Build the `LSTM_VAE` preset sized for `vocab`.

    The vocabulary size is `len(vocab.itos)` and the start-token index is
    `vocab.stoi['bos']`. The preset uses a 256-wide embedding, a hidden size
    of 1024, 3 layers and a 512-wide latent that conditions both the hidden
    state and the output.

    `drop`: when truthy, both the input and LSTM dropout are 0.3; otherwise 0.

    Returns the constructed `LSTM_VAE`.
    """
    # A single switch drives both dropout rates.
    dropout = 0.3 if drop else 0.

    return LSTM_VAE(
        len(vocab.itos),    # d_vocab
        256,                # d_embedding
        1024,               # d_hidden
        3,                  # n_layers
        512,                # d_latent
        input_dropout=dropout,
        lstm_dropout=dropout,
        condition_hidden=True,
        condition_output=True,
        bos_idx=vocab.stoi['bos'],
    )

In [None]:
# Smoke test: each VAE preset must build into an `nn.Module`.
vocab = CharacterVocab(SMILES_CHAR_VOCAB)
for builder in (mlp_vae, conv_vae, lstm_vae):
    model = builder(vocab)
    assert isinstance(model, nn.Module)