In [1]:
import pandas as pd

idx_to_char = set()
smiles = []

with open('sars_protease.tsv') as f:
    data = f.readlines()
    for row in data[1:]:
        smile = row.split('\t')[1]
        for c in smile:
            idx_to_char.add(c)
        smiles.append(smile)
    idx_to_char = list(idx_to_char)
    del(data)
    
print(len(smiles))
print(idx_to_char)

290893
['s', '3', 'N', 'd', '-', 'A', '4', '5', 'e', 'C', 'P', '9', '@', '6', 'a', ')', 'B', 'F', '7', 'H', 'L', 'O', 'K', 'S', '[', '(', '8', 'l', '\\', '1', 'n', 'i', 'I', '/', ']', 'Z', '#', 'r', '+', 'g', '.', '=', 'M', '2', 't']


In [2]:
char_to_idx = dict([(c, i) for i,c in enumerate(idx_to_char)])
char_size = len(idx_to_char)
smiles_indices = [[char_to_idx[c] for c in smile] for smile in smiles]
print(char_to_idx)
print(char_size)
print(smiles_indices[0])

{'s': 0, '3': 1, 'N': 2, 'd': 3, '-': 4, 'A': 5, '4': 6, '5': 7, 'e': 8, 'C': 9, 'P': 10, '9': 11, '@': 12, '6': 13, 'a': 14, ')': 15, 'B': 16, 'F': 17, '7': 18, 'H': 19, 'L': 20, 'O': 21, 'K': 22, 'S': 23, '[': 24, '(': 25, '8': 26, 'l': 27, '\\': 28, '1': 29, 'n': 30, 'i': 31, 'I': 32, '/': 33, ']': 34, 'Z': 35, '#': 36, 'r': 37, '+': 38, 'g': 39, '.': 40, '=': 41, 'M': 42, '2': 43, 't': 44}
45
[9, 9, 21, 9, 9, 9, 2, 9, 9, 25, 41, 21, 15, 2, 9, 29, 41, 9, 9, 41, 9, 25, 9, 41, 9, 29, 15, 21, 9, 25, 17, 15, 25, 17, 15, 17, 40, 9, 27]


In [3]:
import torch, time, math
from torch import nn, optim
import torch.nn.functional as F
import numpy as np
import d2lzh_pytorch as d2l


device = 'cpu'

In [4]:
num_input, num_hiddens, num_outputs = char_size, 256, char_size
print('will use', device)

def get_params():
    def _one(shape):
        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)
        return torch.nn.Parameter(ts, requires_grad=True)

    # hidden
    W_xh = _one((num_input, num_hiddens))
    W_hh = _one((num_hiddens, num_hiddens))
    b_h = torch.nn.Parameter(torch.zeros(num_hiddens, device=device, requires_grad=True))

    # output
    W_hq = _one((num_hiddens, num_outputs))
    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, requires_grad=True))
    return nn.ParameterList([W_xh, W_hh, b_h, W_hq, b_q])

will use cpu


In [5]:
def init_rnn_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),)

In [6]:
def rnn(inputs, state, params):
    W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        H = torch.tanh(torch.matmul(X, W_xh) + torch.matmul(H, W_hh) + b_h)
        Y = torch.matmul(H, W_hq) + b_q
        outputs.append(Y)
    return outputs, (H,)

In [7]:
num_epochs, num_steps, batch_size, lr, clipping_theta = 2, 5, 4, 0.1, 1e-2
pred_len = 24
# 测试用例为所有训练的smiles的前5个字符
prefixes = [smile[:5] for smile in smiles]

print(prefixes[:10])

['CCOCC', 'COCCN', 'COCCN', 'C1CCC', 'COC1=', 'CCOC(', 'CCOC(', 'CC(=O', 'CCCCO', 'CC(C(']


In [9]:
def my_train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,
                          vocab_size, device, corpus_indices, idx_to_char,
                          char_to_idx, is_random_iter, num_epochs, num_steps,
                          lr, clipping_theta, batch_size, pred_len, prefixes):
    if is_random_iter:
        data_iter_fn = d2l.data_iter_random
    else:
        data_iter_fn = d2l.data_iter_consecutive
    params = get_params()
    loss = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        mm = 0
        for smile in corpus_indices:
            if not is_random_iter:
                state = init_rnn_state(batch_size, num_hiddens, device)
            l_sum, n, start = 0.0, 0, time.time()
            data_iter = data_iter_fn(smile, batch_size, num_steps, device)
            for X, Y in data_iter:
                if is_random_iter:
                    state = init_rnn_state(batch_size, num_hiddens, device)
                else: 
                    for s in state:
                        s.detach_()
                
                inputs = d2l.to_onehot(X, vocab_size)
                (outputs, state) = rnn(inputs, state, params)
                outputs = torch.cat(outputs, dim=0)
                y = torch.transpose(Y, 0, 1).contiguous().view(-1)
                l = loss(outputs, y.long())
                
                if params[0].grad is not None:
                    for param in params:
                        param.grad.data.zero_()
                l.backward()
                d2l.grad_clipping(params, clipping_theta, device)
                d2l.sgd(params, lr, 1)
                l_sum += l.item() * y.shape[0]
                n += y.shape[0]

            mm += 1
            if mm % 20000 == 0:
                print(mm, " smiles were trained")
        print("Training has done.")

        return [d2l.predict_rnn(prefix, pred_len, rnn, params, init_rnn_state,
                num_hiddens, vocab_size, device, idx_to_char, char_to_idx) for prefix in prefixes]


In [10]:
ret = my_train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens, char_size, device, smiles_indices,
                            idx_to_char, char_to_idx, False, num_epochs, num_steps, lr, clipping_theta, batch_size,
                            pred_len, prefixes)
print("Done")

20000  smiles were trained
40000  smiles were trained
60000  smiles were trained
80000  smiles were trained
100000  smiles were trained
120000  smiles were trained
140000  smiles were trained
160000  smiles were trained
180000  smiles were trained
200000  smiles were trained
220000  smiles were trained
240000  smiles were trained
260000  smiles were trained
280000  smiles were trained
Training has done.
Done


In [11]:
for i in range(10):
    print(ret[i])

CCOCC1=CC=CC=C(C=C1)C(=O)CCC(
COCCN1CCC1=CC=CC=C(C=C1)C(=O)
COCCN1CCC1=CC=CC=C(C=C1)C(=O)
C1CCC(=O)NC(=O)C2=CC=CC=C(C=C
COC1=CC=CC=C(C=C1)C(=O)CCC(=O
CCOC(=O)CCC(=O)NC(=O)C2=CC=CC
CCOC(=O)CCC(=O)NC(=O)C2=CC=CC
CC(=O)NC(=O)C2=CC=CC=C(C=C1)C
CCCCOC1=CC=CC=C(C=C1)C(=O)CCC
CC(C(=O)NC(=O)C2=CC=CC=C(C=C1


In [12]:
from rdkit import Chem


cnt = 0
for smile in ret:
    if Chem.MolFromSmiles(smile) != None:
        cnt += 1

print(cnt / len(ret) * 100, end='%\n')

32.14205910764439%


In [14]:
""" 
    由于生成序列很难保证括号的合法性，因此加一步检验括号合法性的筛选，
    在所有括号合法的smiles里检验正确率
"""
def if_bracket_legal(smile):
    cnt = 0
    for c in smile:
        if c == '(':
            cnt += 1
        elif c == ')':
            cnt -= 1
        if cnt < 0:
            return False
    return cnt == 0

cnt, all_num = 0, 0
for smile in ret:
    if if_bracket_legal(smile):
        all_num += 1
        if Chem.MolFromSmiles(smile) != None:
            cnt += 1

print(cnt / all_num * 100, end='%\n')

54.03535741737125%
