Data source: https://huggingface.co/datasets/sagawa/pubchem-10m-canonicalized

In [1]:
from ast import Import
from tkinter import YES
from xml.parsers.expat import model
from translation_smi4decoder_with_mask import TrfmSeq2seq
import torch
from torch.nn import functional as F
from tqdm import tqdm
from build_vocab import WordVocab
from torch.utils import data
import pandas as pd 
import math
import numpy as np
import collections
import copy
from typing import Callable, List
from torch.autograd import Variable
from collections import namedtuple
from utils_attention_visiualization import visualize_attention

In [2]:
# from train.TranslationModel import TrfmSeq2seq_pubchem
Hypothesis = namedtuple('Hypothesis', ['value', 'score'])
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

src_vocab = WordVocab.load_vocab('data/pubchem/smi_vocab1.pkl')
tgt_vocab = WordVocab.load_vocab('data/pubchem/pharm_vocab1.pkl')
inchi_vocab = WordVocab.load_vocab('data/pubchem/inchi_vocab1.pkl')
pubchemfp_vocab = WordVocab.load_vocab('data/pubchem/pubfp_vocab1.pkl')

In [3]:
def tokenize_nmt(text):
    source, target = [], []
    for i in tqdm(range(len(text)),desc='tokenize_nmt进度:'):
        line = text[i].strip('\n').strip(' ')
        parts = line.split('\t')
        if len(parts) == 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))
    return source, target



In [4]:
def split_sm(sm):
    '''
    function: Split SMILES into words. Care for Cl, Br, Si, Se, Na etc.
    input: A SMILES
    output: A string with space between words
    '''
    arr = []
    i = 0
    while i < len(sm)-1:
        if not sm[i] in ['%', 'C', 'B', 'S', 'N', 'R', 'X', 'L', 'A', 'M', \
                        'T', 'Z', 's', 't', 'H', '+', '-', 'K', 'F']:
            arr.append(sm[i])
            i += 1
        elif sm[i]=='%':
            arr.append(sm[i:i+3])
            i += 3
        elif sm[i]=='C' and sm[i+1]=='l':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='C' and sm[i+1]=='a':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='C' and sm[i+1]=='u':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='B' and sm[i+1]=='r':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='B' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='B' and sm[i+1]=='a':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='B' and sm[i+1]=='i':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='S' and sm[i+1]=='i':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='S' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='S' and sm[i+1]=='r':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='N' and sm[i+1]=='a':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='N' and sm[i+1]=='i':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='R' and sm[i+1]=='b':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='R' and sm[i+1]=='a':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='X' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='L' and sm[i+1]=='i':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='A' and sm[i+1]=='l':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='A' and sm[i+1]=='s':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='A' and sm[i+1]=='g':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='A' and sm[i+1]=='u':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='M' and sm[i+1]=='g':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='M' and sm[i+1]=='n':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='T' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='Z' and sm[i+1]=='n':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='s' and sm[i+1]=='i':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='s' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='t' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='H' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='+' and sm[i+1]=='2':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='+' and sm[i+1]=='3':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='+' and sm[i+1]=='4':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='-' and sm[i+1]=='2':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='-' and sm[i+1]=='3':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='-' and sm[i+1]=='4':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='K' and sm[i+1]=='r':
            arr.append(sm[i:i+2])
            i += 2
        elif sm[i]=='F' and sm[i+1]=='e':
            arr.append(sm[i:i+2])
            i += 2
        else:
            arr.append(sm[i])
            i += 1
    if i == len(sm)-1:
        arr.append(sm[i])
    return ' '.join(arr)

In [5]:
def generate_square_subsequent_mask(sz, device):
        mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

def build_array_nmt(lines, vocab, num_steps):
    A,masks = [],[]
    for i in range(len(lines)):
        line = lines[i]
        content = [vocab.stoi.get(token, vocab.unk_index) for token in line]
        if len(content) > num_steps -2 : 
            content = content[:254]
        X = content + [vocab.eos_index]
        masks.append([False for _ in range(len(X))] + [True for _ in range(num_steps - len(X))])
        padding = [vocab.pad_index]*(num_steps - len(X))
        X.extend(padding)
        # array = torch.tensor([truncate_pad(l, num_steps, vocab.pad_index) for l in lines])
        A.append(copy.deepcopy(X))
    return torch.tensor(A), torch.tensor(masks)

def load_array(data_arrays, batch_size, is_train=True): #@save
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

def truncate_pad(line, num_steps, padding_token):
    if len(line) > num_steps:
        return line[:num_steps] # 截断
    return line + [padding_token] * (num_steps - len(line)) # 填充

def get_dataset(text, src_vocab, tgt_vocab,seq_len):
    source,target = tokenize_nmt(text)
    target, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps=seq_len) 
    source, src_valid_len = build_array_nmt(source, src_vocab, num_steps=seq_len) 
    dataset = (source, src_valid_len, target, tgt_valid_len)
    return dataset

def greedy_decode(mode,model, src, max_len, start_symbol ):

    src = src.cuda()
    src_mask = (src == src_vocab.pad_index).transpose(0,1).cuda()
    memory = model.Encoder(src,src_mask).cuda()  
    ys = torch.ones(1, 1).fill_(start_symbol). \
        type(torch.long).cuda()
    
    for i in range(max_len - 1):
       
        memory = memory.cuda()

        if mode == 1:
            out = model.Decoder1(ys,memory,src_mask)
            out = out.transpose(0, 1) 
            prob = model.out1(out[:, -1])
        if mode == 2:
            out = model.Decoder2(ys,memory,src_mask)
            out = out.transpose(0, 1)
            prob = model.out2(out[:, -1]) 
        if mode == 3:
            out = model.Decoder3(ys,memory,src_mask)
            out = out.transpose(0, 1)
            prob = model.out3(out[:, -1]) 
        if mode == 4:
            out = model.Decoder4(ys,memory,src_mask)
            out = out.transpose(0, 1)
            prob = model.out4(out[:, -1]) 

        _, next_word = torch.max(prob, dim=1)  # 选Choose the one with the maximum probability
        next_word = next_word.item()
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        # The predictive output decoded at the current moment is stacked with all previous results as input to predict the next word
        if next_word == src_vocab.eos_index:  # If the output for the current moment is eos, end the loop.
            break
    return ys


In [6]:
def translate(model, src, mode):
    model.eval()
    src_token = split_sm(src).split()
    tokens = [src_vocab.stoi.get(token, src_vocab.unk_index) for token in src_token]
    if len(tokens) > 254 : 
            tokens = tokens[:254]
    tokens = tokens + [src_vocab.eos_index]
    src_tokens = truncate_pad(tokens,256, src_vocab.pad_index)
    num_tokens = len(tokens)
    src = torch.t(torch.tensor(src_tokens).unsqueeze(0)).cuda()
    tgt_tokens = greedy_decode(mode,model, src, max_len=256,start_symbol=tgt_vocab.sos_index).flatten() 

    if mode == 1:
        return " ".join([tgt_vocab.itos[tok] for tok in tgt_tokens]).replace("<sos>", "").replace("<eos>", "")
    if mode == 2:
        return " ".join([src_vocab.itos[tok] for tok in tgt_tokens]).replace("<sos>", "").replace("<eos>", "")
    if mode == 3:
        return " ".join([inchi_vocab.itos[tok] for tok in tgt_tokens]).replace("<sos>", "").replace("<eos>", "")
    if mode == 4:
        return " ".join([pubchemfp_vocab.itos[tok] for tok in tgt_tokens]).replace("<sos>", "").replace("<eos>", "")


def translate_smi_to_pharm(mode,srcs, hidden_size,n_heads,n_layers):

    translation_model = TrfmSeq2seq(len(src_vocab),hidden_size, len(tgt_vocab), len(src_vocab),
                                         len(inchi_vocab), len(pubchemfp_vocab), 
                                         n_heads,
                                         n_layers,
                                         )

    translation_model.eval()
    translation_model = translation_model.cuda()

    loaded_paras = torch.load('trfm_new_98_10000.pkl') 
    translation_model.load_state_dict(loaded_paras)
    results = []
    for src in tqdm(srcs):
        r = translate(translation_model, src, mode)
        # r = predict_seq2seq(translation_model, src)
        results.append(r)
    return results

In [7]:
def bleu(pred_tokens, label_tokens, k):
    """calculate BLEU"""
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, k + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs[''.join(label_tokens[i: i + n])] += 1
        for i in range(len_pred - n + 1):
            if label_subs[''.join(pred_tokens[i: i + n])] > 0:
                num_matches += 1
                label_subs[''.join(pred_tokens[i: i + n])] -= 1
        # try:
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
        # except:
        #     score = 0
    return score

def predict(srcs, mode):
    results = translate_smi_to_pharm(mode,srcs, hidden_size=256,n_heads=4,n_layers=2)
    
    return results

In [8]:
def get_golden_st(task):
    data = corpus[task]
    results = []
    for i in range(len(data)):
        results.append(data[i].split(' '))
    return results

In [9]:
fileph = 'data/transltation_test/smi_pharm_inchi_corpus1.txt'
corpus = pd.read_csv(fileph,delimiter="\t",header=None)
corpus.columns = ['smiles','pharm','inchi','pubchemfp']
srcs = get_golden_st('smiles')
pharm = get_golden_st('pharm')
inchi = get_golden_st('inchi')
pubchemfp = get_golden_st('pubchemfp')

In [10]:
for mode in [1,2,3,4]:    
    results = predict(srcs,mode)
    output = []
    for i in range(len(results)):
        output.append(results[i].split(' ')[1:-1])
    BLUE = []

    for i in range(len(output)):
        # print(f'bleu {bleu(output[i], pharm[i], k=3):.3f}')
        if mode == 1:
            try:
                BLUE.append(bleu(output[i], pharm[i], k=1))
            except:
                continue
        if mode == 2: 
            try:
                BLUE.append(bleu(output[i], srcs[i], k=3))    
            except:
                continue
        if mode == 3:
            try:
                BLUE.append(bleu(output[i], inchi[i], k=3))
            except:
                continue
        if mode == 4:
            try:
                BLUE.append(bleu(output[i], pubchemfp[i], k=1))
            except:
                continue
    print(np.mean(BLUE), mode)

100%|██████████| 71009/71009 [25:19:02<00:00,  1.28s/it]   


0.7437027200383395 1


100%|██████████| 71009/71009 [13:36:16<00:00,  1.45it/s]  


0.9998477311422149 2


100%|██████████| 71009/71009 [25:33:29<00:00,  1.30s/it]   


0.7381593471383687 3


100%|██████████| 71009/71009 [46:48:41<00:00,  2.37s/it]   


0.968787643673169 4
