In [1]:
from fairseq.models.transformer import TransformerModel

In [2]:
# Project root; the cells below read checkpoints/, data-bin/ and preprocess/
# beneath it. Set the NLP_RXN_ROOT environment variable to run the notebook
# on another machine; the original location remains the default.
import os

root_dir = os.environ.get('NLP_RXN_ROOT', '/storage/jason/workspace/nlp-rxn')

In [3]:
# Locations of the trained checkpoint, the binarized dataset directory and
# the subword-nmt BPE codes, all resolved relative to the project root.
checkpoint_path = f'{root_dir}/checkpoints/checkpoint_best.pt'
data_bin_path = f'{root_dir}/data-bin/uspto.rct-prd/'
bpe_codes_path = f'{root_dir}/preprocess/uspto/code'

# Build the reactant -> product translation model from the checkpoint.
rxnPred = TransformerModel.from_pretrained(
    root_dir,
    checkpoint_file=checkpoint_path,
    data_name_or_path=data_bin_path,
    bpe='subword_nmt',
    bpe_codes=bpe_codes_path,
)

# Switch to evaluation mode before running any predictions.
rxnPred.eval()

print('Load model OK!')

Load model OK!


In [4]:
# Smoke test: translate a single reactant SMILES string
# (two '.'-separated species) and keep the list of predictions in `ans`.
ans = rxnPred.translate(['CS(=O)(=O)Cl.OCCCBr'])

In [5]:
# Compare the first (only) prediction with the expected product SMILES;
# the boolean result is shown as the cell output.
ans[0] == 'CS(=O)(=O)OCCCBr'

True

# Verify Accuracy

In [6]:
# Directory holding the original test split, and the names of the
# reactant (.rct) and product (.prd) files inside it.
root_test_dir = f'{root_dir}/preprocess/orig/uspto'
test_rct = 'test.rct'
test_prd = 'test.prd'

In [7]:
# Read the reactant and product files into parallel lists, one entry per
# line. Each entry keeps its trailing newline (same as readlines()); the
# evaluation cells strip it when comparing.
with open(f'{root_test_dir}/{test_rct}', 'r') as rct_file:
    test_rcts = list(rct_file)

with open(f'{root_test_dir}/{test_prd}', 'r') as prd_file:
    test_prds = list(prd_file)

In [8]:
# Translate every test reactant line; this rebinds `ans` (previously the
# single smoke-test prediction) to the full list of predictions.
# NOTE(review): the inputs still carry their trailing newline from
# readlines() — presumably harmless to the model's preprocessing; confirm.
ans = rxnPred.translate(test_rcts)

In [9]:
# Exact-string accuracy: a prediction counts as correct only when it is
# character-for-character equal to the reference product line.
count_total = len(test_rcts)

# Fail loudly on an empty or misaligned evaluation set instead of hitting a
# ZeroDivisionError / IndexError or silently scoring a truncated set.
if count_total == 0 or not (len(ans) == len(test_prds) == count_total):
    raise ValueError(
        'Expected non-empty, equally sized inputs; got {} reactants, '
        '{} predictions and {} targets'.format(
            count_total, len(ans), len(test_prds)))

# Strip only line terminators from the reference. A blind [:-1] would chop a
# real character off a final line that has no trailing newline and would
# leave a stray '\r' on CRLF files.
count_correct = sum(
    pred == target.rstrip('\r\n')
    for pred, target in zip(ans, test_prds)
)

print('Accuracy is: {:.2f} %'.format(100.0 * float(count_correct) / count_total))

Accuracy is: 72.24 %


In [10]:
# Display the raw number of exact-match predictions counted above.
count_correct

36305

# RDKit canonical SMILES check

In [11]:
from rdkit import Chem

In [12]:
import rdkit.rdBase as rkrb
import rdkit.RDLogger as rkl
# Quiet RDKit logging: set the RDLogger level to ERROR, then additionally
# disable the 'rdApp.error' log channel.
# NOTE(review): presumably done so unparsable predicted SMILES do not flood
# the cell output. The order of the two calls may matter (setLevel may
# toggle the same channels) — confirm before reordering.
logger = rkl.logger()
logger.setLevel(rkl.ERROR)
rkrb.DisableLog('rdApp.error')

In [13]:
# Canonical-SMILES accuracy: a prediction counts as correct when it and the
# reference product canonicalize to the same string via Chem.CanonSmiles,
# so different spellings of the same SMILES are treated as equal.
count_correct = 0
count_error = 0
count_total = len(test_rcts)

# Fail loudly on an empty or misaligned evaluation set instead of hitting a
# ZeroDivisionError / IndexError or silently scoring a truncated set.
if count_total == 0 or not (len(ans) == len(test_prds) == count_total):
    raise ValueError(
        'Expected non-empty, equally sized inputs; got {} reactants, '
        '{} predictions and {} targets'.format(
            count_total, len(ans), len(test_prds)))

for pred, target in zip(ans, test_prds):
    # Strip only line terminators; a blind [:-1] would corrupt a final line
    # that has no trailing newline.
    target = target.rstrip('\r\n')
    try:
        if Chem.CanonSmiles(pred) == Chem.CanonSmiles(target):
            count_correct += 1
    except Exception:
        # A pair that cannot be canonicalized is tallied as an error.
        # `except Exception` (not a bare `except`) lets KeyboardInterrupt and
        # SystemExit propagate, so a long run can still be interrupted.
        count_error += 1

print('Accuracy is: {:.2f} %'.format(100.0 * float(count_correct) / count_total))

Accuracy is: 73.43 %


In [14]:
# Share of test pairs whose canonicalization raised an exception. The label
# previously read "Accuracy", which misreported this value.
print('Canonicalization error rate is: {:.2f} %'.format(100.0 * float(count_error) / count_total))

Accuracy is: 0.74 %
