In [1]:
# Copyright: Wentao Shi, 2021
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.functional import softmax

# Copyright: Wentao Shi, 2021
import torch
import re
import yaml
import selfies as sf

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence


# Copyright: Wentao Shi, 2021
import yaml
import os
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from rdkit import Chem
import selfies as sf



# suppress rdkit error
from rdkit import rdBase
rdBase.DisableLog('rdApp.error')

# dataloader

In [2]:



def dataloader_gen(dataset_dir, percentage, which_vocab, vocab_path,
                   batch_size, PADDING_IDX, shuffle, drop_last=True):
    """
    Genrate the dataloader for training
    """
    if which_vocab == "selfies":
        vocab = SELFIEVocab(vocab_path)
    elif which_vocab == "regex":
        vocab = RegExVocab(vocab_path)
    elif which_vocab == "char":
        vocab = CharVocab(vocab_path)
    else:
        raise ValueError("Wrong vacab name for configuration which_vocab!")

    dataset = SMILESDataset(dataset_dir, percentage, vocab)

    def pad_collate(batch):
        """
        Put the sequences of different lengths in a minibatch by paddding.
        """
        lengths = [len(x) for x in batch]

        # embedding layer takes long tensors
        batch = [torch.tensor(x, dtype=torch.long) for x in batch]

        x_padded = pad_sequence(
            batch, 
            batch_first=True,
            padding_value=PADDING_IDX
        )

        return x_padded, lengths

    dataloader = DataLoader(
        dataset=dataset, 
        batch_size=batch_size, 
        shuffle=shuffle,
        drop_last=drop_last, 
        collate_fn=pad_collate
    )

    return dataloader, len(dataset)


class SMILESDataset(Dataset):
    """
    Molecules read from a text file (one per line), tokenized on access.

    When the vocabulary's name is "selfies" the SMILES strings are converted
    to SELFIES once, at construction time; entries for which the encoder
    returns None are dropped.
    """

    def __init__(self, smiles_file, percentage, vocab):
        """
        smiles_file: path to the .smi file containing SMILES.
        percentage: fraction of the file to use, in the interval (0, 1].
        vocab: vocabulary object; must expose `name` and `tokenize_smiles`.
        """
        super(SMILESDataset, self).__init__()
        assert 0 < percentage <= 1, "percentage must be in the interval (0, 1]"

        self.percentage = percentage
        self.vocab = vocab

        # load the leading `percentage` share of the file
        self.data = self.read_smiles_file(smiles_file)
        print("total number of SMILES loaded: ", len(self.data))

        # convert the smiles to selfies
        if self.vocab.name == "selfies":
            # Encode every SMILES exactly once; the previous version called
            # the encoder twice per molecule (once to filter, once to keep).
            encoded = (sf.encoder(smiles) for smiles in self.data)
            self.data = [selfies for selfies in encoded if selfies is not None]
            print("total number of valid SELFIES: ", len(self.data))

    def read_smiles_file(self, path):
        """
        Return the first `self.percentage` share of the lines of `path`.

        Every line is kept as-is apart from its trailing newline; no header
        line is skipped.
        """
        with open(path, "r") as f:
            smiles = [line.strip("\n") for line in f]

        return smiles[:int(len(smiles) * self.percentage)]

    def __getitem__(self, index):
        """Return the molecule at `index` as a list of integer tokens."""
        return self.vocab.tokenize_smiles(self.data[index])

    def __len__(self):
        """Number of molecules kept in the dataset."""
        return len(self.data)


class CharVocab:
    """Character-level SMILES vocabulary that also knows two-character tokens."""

    def __init__(self, vocab_path):
        self.name = "char"

        # the YAML file stores the token -> integer mapping
        with open(vocab_path, 'r') as f:
            self.vocab = yaml.full_load(f)

        # inverse mapping (integer -> token), needed to decode sampled output
        self.int2tocken = {num: token for token, num in self.vocab.items()}

        # key view of the vocabulary, used for membership tests
        self.tokens = self.vocab.keys()

    def tokenize_smiles(self, smiles):
        """
        Convert a SMILES string into a list of integer tokens, wrapped in
        '<sos>' and '<eos>'. At every position a two-character token is
        preferred over a one-character token. The logic references this
        code piece:
        https://github.com/topazape/LSTM_Chem/blob/master/lstm_chem/utils/smiles_tokenizer2.py

        Raises ValueError when neither form is part of the vocabulary.
        """
        total = len(smiles)
        pieces = ['<sos>']
        pos = 0

        while pos < total:
            single = smiles[pos]

            if pos == total - 1:
                # last character: no two-character token can start here
                if single not in self.tokens:
                    raise ValueError(
                        "Unrecognized charater in SMILES: {}".format(single))
                pieces.append(single)
                pos += 1
                continue

            pair = smiles[pos:pos + 2]
            if pair in self.tokens:
                # tokens with length 2 win over their first character
                pieces.append(pair)
                pos += 2
            elif single in self.tokens:
                # fall back to the token with length 1
                pieces.append(single)
                pos += 1
            else:
                raise ValueError(
                    "Unrecognized charater in SMILES: {}, {}".format(single, pair))

        pieces.append('<eos>')

        return [self.vocab[piece] for piece in pieces]

    def combine_list(self, smiles):
        """Concatenate a list of tokens into one SMILES string."""
        return "".join(smiles)


class RegExVocab:
    """
    SMILES vocabulary in which a bracket expression such as '[nH]' is one
    token and the halogens 'Br' / 'Cl' are stored as the single letters
    'R' / 'L'.
    """

    # A bracket atom: '[' + 1 to 6 non-bracket characters + ']'. Written as a
    # raw string: in a normal string literal '\[' is an invalid escape
    # sequence, which Python reports with a warning. The capturing group
    # makes re.split keep the matched bracket atoms in its result.
    _BRACKET_PATTERN = re.compile(r'(\[[^\[\]]{1,6}\])')

    def __init__(self, vocab_path):
        self.name = "regex"

        # load the pre-computed vocabulary (token -> integer)
        with open(vocab_path, 'r') as f:
            self.vocab = yaml.full_load(f)

        # integer -> token mapping used for sampling; the single-letter
        # halogen placeholders are mapped back to real SMILES atoms
        placeholders = {"R": "Br", "L": "Cl"}
        self.int2tocken = {
            num: placeholders.get(token, token)
            for token, num in self.vocab.items()
        }

    def tokenize_smiles(self, smiles):
        """Takes a SMILES string and returns a list of integer tokens.
        This will swap 'Cl' and 'Br' to 'L' and 'R' and treat
        '[xx]' as one token. The result starts with '<sos>' and ends
        with '<eos>'. A KeyError is raised for a token that is missing
        from the vocabulary."""
        smiles = self.replace_halogen(smiles)

        tokenized = ['<sos>']
        for chunk in self._BRACKET_PATTERN.split(smiles):
            if chunk.startswith('['):
                tokenized.append(chunk)
            else:
                # everything outside brackets is one token per character
                tokenized.extend(chunk)
        tokenized.append('<eos>')

        # convert tokens to integer tokens
        return [self.vocab[token] for token in tokenized]

    def replace_halogen(self, string):
        """Replace every 'Br' with 'R' and every 'Cl' with 'L'."""
        return string.replace('Br', 'R').replace('Cl', 'L')

    def combine_list(self, smiles):
        """Concatenate a list of tokens into one SMILES string."""
        return "".join(smiles)


class SELFIEVocab:
    """Vocabulary whose tokens are the symbols of a SELFIES string."""

    def __init__(self, vocab_path):
        self.name = "selfies"

        # the YAML file stores the symbol -> integer mapping
        with open(vocab_path, 'r') as f:
            self.vocab = yaml.full_load(f)

        # inverse mapping (integer -> symbol) for decoding sampled output
        self.int2tocken = {}
        for symbol, index in self.vocab.items():
            self.int2tocken[index] = symbol

    def tokenize_smiles(self, mol):
        """Split a SELFIES string into its symbols and return the matching
        integer tokens, framed by the '<sos>' and '<eos>' tokens. `mol` has
        to be SELFIES already; no SMILES conversion happens here."""
        start = self.vocab['<sos>']

        symbols = list(sf.split_selfies(mol))
        body = [self.vocab[symbol] for symbol in symbols]

        return [start, *body, self.vocab['<eos>']]

    def combine_list(self, selfies):
        """Concatenate a list of symbols into one SELFIES string."""
        return "".join(selfies)

  regex = '(\[[^\[\]]{1,6}\])'


# model

In [3]:


class RNN(torch.nn.Module):
    """
    Token-level recurrent language model: embedding -> LSTM/GRU -> linear.

    Keys read from `rnn_config`: 'num_embeddings', 'embedding_dim',
    'rnn_type' ('LSTM' or 'GRU'), 'input_size', 'hidden_size', 'num_layers'
    and 'dropout'. The last embedding index (num_embeddings - 1) is the
    padding index, and the linear layer has num_embeddings - 2 outputs.
    """

    def __init__(self, rnn_config):
        super(RNN, self).__init__()

        self.embedding_layer = nn.Embedding(
            num_embeddings=rnn_config['num_embeddings'],
            embedding_dim=rnn_config['embedding_dim'],
            padding_idx=rnn_config['num_embeddings'] - 1
        )

        # both recurrent layers take exactly the same arguments
        rnn_classes = {'LSTM': nn.LSTM, 'GRU': nn.GRU}
        if rnn_config['rnn_type'] not in rnn_classes:
            raise ValueError(
                "rnn_type should be either 'LSTM' or 'GRU'."
            )
        self.rnn = rnn_classes[rnn_config['rnn_type']](
            input_size=rnn_config['input_size'],
            hidden_size=rnn_config['hidden_size'],
            num_layers=rnn_config['num_layers'],
            batch_first=True,
            dropout=rnn_config['dropout']
        )

        # output does not include <sos> and <pad>, so
        # decrease the num_embeddings by 2
        self.linear = nn.Linear(
            rnn_config['hidden_size'], rnn_config['num_embeddings'] - 2
        )

    def forward(self, data, lengths):
        """
        Teacher-forced pass over a padded batch.

        data: padded integer tensor with the batch dimension first.
        lengths: number of time steps to use from each sequence. The
            training loop passes the true lengths minus one, because <eos>
            is never an input and <sos> is never a target.

        Returns the linear layer's output for the packed time steps only
        (one row per kept step, num_embeddings - 2 columns), in the order
        produced by pack_padded_sequence. Targets must be packed the same
        way before they are compared with it.
        """
        embeddings = self.embedding_layer(data)

        # pack the padded input so that padding steps are skipped
        embeddings = pack_padded_sequence(
            input=embeddings,
            lengths=lengths,
            batch_first=True,
            enforce_sorted=False
        )

        # recurrent network, discard the final hidden state. Teacher forcing
        # is used here, so the whole sequence is fed to the model directly.
        embeddings, _ = self.rnn(embeddings)

        # linear layer to generate the input of softmax
        return self.linear(embeddings.data)

    def sample(self, batch_size, vocab, device, max_length=140):
        """
        Sample `batch_size` token sequences in parallel on `device`.

        Sampling starts from <sos> and stops as soon as every sequence has
        produced <eos> at least once, or after max_length + 1 tokens.
        Returns an integer tensor with one row per sequence; positions after
        a row's first <eos> hold arbitrary samples and must be ignored by
        the caller. No gradients are recorded.
        """
        start_int = vocab.vocab['<sos>']
        eos_int = vocab.vocab['<eos>']

        # tensor of shape [batch_size, seq_step=1] filled with <sos>
        x = torch.full(
            [batch_size, 1], start_int,
            dtype=torch.long, device=device
        )
        hidden = None  # the recurrent layer starts from a zero state

        # marks the sequences that have already produced <eos>
        finish = torch.zeros(batch_size, dtype=torch.bool, device=device)

        output = []
        with torch.no_grad():
            for _ in range(max(max_length, 0) + 1):
                x = self.embedding_layer(x)
                x, hidden = self.rnn(x, hidden)
                x = softmax(self.linear(x), dim=-1)

                # Drop only the time-step axis. A bare squeeze() would also
                # drop the batch axis when batch_size == 1 and break the
                # next recurrent step.
                x = torch.multinomial(x.squeeze(1), 1)
                output.append(x)

                # The first sampled token is checked as well, so a batch
                # whose sequences all start with <eos> ends immediately.
                finish = torch.logical_or(finish, (x == eos_int).squeeze(1))
                if torch.all(finish):
                    break

        return torch.cat(output, -1)

    def sample_cpu(self, vocab, max_length=None):
        """
        Sample a single sequence on the CPU and return it as a string.

        vocab: vocabulary object providing `vocab`, `int2tocken` and
            `combine_list`.
        max_length: optional cap on the number of tokens. With the default
            (None) sampling continues until <eos> is drawn, however long
            that takes.

        The returned string does not contain <eos>.
        """
        eos_int = vocab.vocab['<eos>']

        # tensor of shape [batch_size=1, seq_step=1] holding <sos>
        x = torch.tensor([[vocab.vocab['<sos>']]], dtype=torch.long)
        hidden = None

        sampled = []
        with torch.no_grad():
            while max_length is None or len(sampled) < max_length:
                x = self.embedding_layer(x)
                x, hidden = self.rnn(x, hidden)
                x = softmax(self.linear(x), dim=-1)
                # keep the [1, 1] shape so x can be fed back in directly
                x = torch.multinomial(x.squeeze(1), 1)

                token = x.item()
                if token == eos_int:
                    break
                sampled.append(token)

        # convert integers to tokens, then to a single string
        tokens = [vocab.int2tocken[token] for token in sampled]
        return vocab.combine_list(tokens)

# train

In [4]:



def make_vocab(config):
    """
    Create the vocabulary object selected by the configuration.

    Reads config["which_vocab"] ("selfies", "regex" or "char") and
    config["vocab_path"], and returns an instance of the matching
    vocabulary class. Raises ValueError for any other name.
    """
    name = config["which_vocab"]
    path = config["vocab_path"]

    if name == "char":
        return CharVocab(path)
    if name == "regex":
        return RegExVocab(path)
    if name == "selfies":
        return SELFIEVocab(path)

    raise ValueError(
        "Wrong vacab name for configuration which_vocab!"
    )


def sample(model, vocab, batch_size, device=None):
    """
    Sample a batch of molecules from the current model.

    model: network exposing `sample(batch_size, vocab, device)`; it is
        switched to evaluation mode and left in that mode.
    vocab: vocabulary object providing `name` and `int2tocken`.
    batch_size: number of molecules to sample.
    device: device to sample on. When omitted, the device of the model's
        parameters is used, so the function no longer depends on a
        notebook-level `device` variable being defined beforehand.

    Returns a list of strings, each cut at its first '<eos>'. For the
    "selfies" vocabulary every string is passed through the SELFIES decoder
    before it is returned.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()

    # sampling never needs gradients
    with torch.no_grad():
        sampled_ints = model.sample(
            batch_size=batch_size,
            vocab=vocab,
            device=device
        )

    # convert integers back to tokens, stopping at the first <eos>
    molecules = []
    for ints in sampled_ints.tolist():
        molecule = []
        for x in ints:
            token = vocab.int2tocken[x]
            if token == '<eos>':
                break
            molecule.append(token)
        molecules.append("".join(molecule))

    # convert SELFIES back to SMILES
    if vocab.name == 'selfies':
        molecules = [sf.decoder(x) for x in molecules]

    return molecules


def compute_valid_rate(molecules):
    """Count how many of the given SMILES strings RDKit can parse.

    Returns the tuple (number of valid, number of invalid); turning the
    two counts into a rate is left to the caller."""
    parsed_ok = [Chem.MolFromSmiles(smiles) is not None for smiles in molecules]

    valid_count = sum(parsed_ok)
    invalid_count = len(parsed_ok) - valid_count

    return valid_count, invalid_count


In [9]:
# detect cpu or gpu
if torch.cuda.is_available():
    # The second GPU stays the preferred one, but 'cuda:1' does not exist
    # on a machine with a single GPU, so fall back to 'cuda:0' there.
    device = torch.device(
        'cuda:{}'.format(1 if torch.cuda.device_count() > 1 else 0)
    )
else:
    device = torch.device('cpu')
print('device: ', device)

config_dir = "./train.yaml"
with open(config_dir, 'r') as f:
    config = yaml.full_load(f)

# directory for results
out_dir = config['out_dir']
os.makedirs(out_dir, exist_ok=True)

# os.path.join keeps the files inside out_dir even when the configured
# path has no trailing separator
trained_model_dir = os.path.join(out_dir, 'trained_model.pt')

# save the configuration file for future reference
with open(os.path.join(out_dir, 'config.yaml'), 'w') as f:
    yaml.dump(config, f)

device:  cuda:1


In [5]:





# training data: where the molecules come from, which tokenizer to use and
# which fraction of the file to load
dataset_dir = config['dataset_dir']
which_vocab = config['which_vocab']
vocab_path = config['vocab_path']
percentage = config['percentage']

# create dataloader
batch_size = config['batch_size']
shuffle = config['shuffle']
# the last embedding index is the padding token; RNN uses the same index as
# the padding_idx of its embedding layer
PADDING_IDX = config['rnn_config']['num_embeddings'] - 1
# NOTE(review): num_workers is only printed. It is not passed to
# dataloader_gen, which has no such parameter, so the message below does not
# describe how the DataLoader actually loads data.
num_workers = os.cpu_count()
print('number of workers to load data: ', num_workers)
print('which vocabulary to use: ', which_vocab)
# drop_last=False keeps the final, smaller batch; train_size is the number of
# molecules in the dataset and is used later to normalize the epoch loss
dataloader, train_size = dataloader_gen(
    dataset_dir, percentage, which_vocab,
    vocab_path, batch_size, PADDING_IDX,
    shuffle, drop_last=False
)

device:  cuda
number of workers to load data:  32
which vocabulary to use:  selfies
total number of SMILES loaded:  538247
total number of valid SELFIES:  538247


In [11]:
# Show the module summary of the network.
# NOTE(review): `model` is created in the training cell further down, so
# this cell only works after that cell has been executed.
model

RNN(
  (embedding_layer): Embedding(48, 256, padding_idx=47)
  (rnn): GRU(256, 512, num_layers=3, batch_first=True)
  (linear): Linear(in_features=512, out_features=46, bias=True)
)

In [10]:




print("########################################")
# model and training configuration
rnn_config = config['rnn_config']
model = RNN(rnn_config).to(device)
learning_rate = config['learning_rate']
weight_decay = config['weight_decay']

# Making reduction="sum" makes huge difference
# in valid rate of sampled molecules.
loss_function = nn.CrossEntropyLoss(reduction='sum')

# create optimizer
if config['which_optimizer'] == "adam":
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate,
        weight_decay=weight_decay, amsgrad=True
    )
# The key used to be misspelled as 'which_opti0mizer'. That raised a
# KeyError for every optimizer other than "adam" and made the ValueError
# below unreachable.
elif config['which_optimizer'] == "sgd":
    optimizer = torch.optim.SGD(
        model.parameters(), lr=learning_rate,
        weight_decay=weight_decay, momentum=0.9
    )
else:
    raise ValueError(
        "Wrong optimizer! Select between 'adam' and 'sgd'."
    )

# learning rate scheduler: halve the learning rate when the monitored value
# (the epoch's training loss, see the training loop) stops decreasing
scheduler = ReduceLROnPlateau(
    optimizer, mode='min',
    factor=0.5, patience=5,
    cooldown=10, min_lr=0.0001,
    verbose=True
)

# vocabulary object used by the sample() function
vocab = make_vocab(config)

# training loop; per-epoch training losses are collected and saved at the end
train_losses = []
best_valid_rate = 0
num_epoch = config['num_epoch']

print('begin training...')
for epoch in range(1, 1 + num_epoch):
    # sample() switches the model to eval mode, so re-enable training mode
    model.train()
    train_loss = 0
    for data, lengths in tqdm(dataloader):
        # the lengths are decreased by 1 because we don't
        # use <eos> for input and we don't need <sos> for
        # output during training.
        lengths = [length - 1 for length in lengths]

        optimizer.zero_grad()
        data = data.to(device)
        preds = model(data, lengths)

        # The <sos> token is removed before packing, because
        # we don't need <sos> of output during training.
        # the image_captioning project uses the same method
        # which directly feeds the packed sequences to
        # the loss function:
        # https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/03-advanced/image_captioning/train.py
        targets = pack_padded_sequence(
            data[:, 1:],
            lengths,
            batch_first=True,
            enforce_sorted=False
        ).data

        loss = loss_function(preds, targets)
        loss.backward()
        optimizer.step()

        # accumulate the loss over mini-batches; with reduction='sum' it is
        # already summed over every token of the batch
        train_loss += loss.item()

    # average loss per molecule of the dataset
    train_losses.append(train_loss / train_size)

    print('epoch {}, train loss: {}.'.format(epoch, train_losses[-1]))

    scheduler.step(train_losses[-1])

    # sample 1024 SMILES each epoch
    sampled_molecules = sample(model, vocab, batch_size=1024)

    # print the valid rate each epoch
    num_valid, num_invalid = compute_valid_rate(sampled_molecules)
    valid_rate = num_valid / (num_valid + num_invalid)

    print('valid rate: {}'.format(valid_rate))

    # keep the checkpoint with the best valid rate of sampled molecules
    # (ties are overwritten by the more recent epoch)
    if valid_rate >= best_valid_rate:
        best_valid_rate = valid_rate
        print('model saved at epoch {}'.format(epoch))
        torch.save(model.state_dict(), trained_model_dir)

# save the training losses; os.path.join keeps the file inside out_dir even
# when the configured path has no trailing separator
with open(os.path.join(out_dir, 'loss.yaml'), 'w') as f:
    yaml.dump(train_losses, f)




########################################


RuntimeError: CUDA error: CUDA-capable device(s) is/are busy or unavailable
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
