In [1]:
import os

from fairseq.models.transformer import TransformerModel

# Top-k translation helper

In [2]:
def translate_top_k(model, inputs, k=5):
    """Return up to ``k`` distinct translations per input, in generation order.

    Each input string goes through the hub interface's
    tokenize -> apply_bpe -> binarize pipeline and is decoded with beam search
    of width ``k``. Every hypothesis is then mapped back through
    string -> remove_bpe -> detokenize.

    Args:
        model: a fairseq hub interface (e.g. the object returned by
            ``TransformerModel.from_pretrained``); only the methods called
            below are used.
        inputs: iterable of source strings (SMILES in this notebook).
        k: beam width, and the maximum number of results kept per input.

    Returns:
        A list with one entry per input. Each entry is a list of at most ``k``
        unique output strings, kept in the order ``model.generate`` returned
        the hypotheses. For fairseq beam search that is presumably best score
        first; verify against the fairseq version in use. Hypotheses that
        detokenize to the same string are collapsed, so an entry may hold
        fewer than ``k`` strings.
    """
    in_bins = [
        model.binarize(model.apply_bpe(model.tokenize(inp)))
        for inp in inputs
    ]
    if not in_bins:
        # Nothing to translate: don't ask the model to batch an empty list.
        return []

    out_bins = model.generate(in_bins, beam=k)
    results = []
    for hypos in out_bins:
        outs = [
            model.detokenize(model.remove_bpe(model.string(hypo['tokens'])))
            for hypo in hypos
        ]
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # ranking of the hypotheses survives the [:k] slice. A plain set()
        # loses that order, and its iteration order varies with string hashing.
        results.append(list(dict.fromkeys(outs))[:k])
    return results

# Forward prediction model (uspto.rct-prd: reactants → product)

In [3]:
# Training directory of the forward model (relative to the notebook's cwd).
forward_root_dir = "../forward"

In [4]:
# Load the forward-prediction (uspto.rct-prd) model from its training directory.
# NOTE(review): fairseq's hub loader appears to join `checkpoint_file` (and a
# `data_name_or_path` starting with '.') onto `model_name_or_path`. The old
# relative paths only resolved because the root looks like '../<name>'.
# Absolute paths resolve identically whichever branch fairseq takes; verify
# against fairseq.hub_utils.from_pretrained.
forward_root_abs = os.path.abspath(forward_root_dir)

forwardPred = TransformerModel.from_pretrained(
    forward_root_abs,
    checkpoint_file=os.path.join(forward_root_abs, 'checkpoints', 'checkpoint_best.pt'),
    data_name_or_path=os.path.join(forward_root_abs, 'data-bin', 'uspto.rct-prd'),
    bpe='subword_nmt',
    bpe_codes=os.path.join(forward_root_abs, 'preprocess', 'uspto', 'code'),
)

# Switch to evaluation mode before running inference.
forwardPred.eval()

print('Load FORWARD model OK!')

Load FORWARD model OK!


In [5]:
# Top-1 forward prediction for a single dot-separated SMILES input.
ans = forwardPred.translate(['CCN(CC)CC.CCOCC.CS(=O)(=O)Cl.OCCCBr'])

In [6]:
# Expected top-1 forward prediction for the input above.
targets = ['CS(=O)(=O)OCCCBr']

In [7]:
# Top-1 sanity check: every forward prediction must equal its expected product.
# The assert makes a regression fail "Restart & Run All" instead of merely
# displaying False; the final expression keeps the original True display.
assert ans == targets, f'top-1 mismatch: got {ans!r}, expected {targets!r}'
ans == targets

True

In [8]:
inputs = ['CCN(CC)CC.CCOCC.CS(=O)(=O)Cl.OCCCBr']

# Up to five distinct forward predictions for the same input.
translate_top_k(forwardPred, inputs, k=5)

[['CS(=O)(=O)OCCCBr',
  'CS(=O)(=O)C(O)CCBr',
  'CS(=O)(=O)NCCCBr',
  'CS(=O)(=O)CCCBr',
  'N#CCCCBr']]

# Retrosynthesis model (uspto.prd-rct: product → reactants)

In [9]:
# Training directory of the retrosynthesis model (relative to the notebook's cwd).
retro_root_dir = "../retro"

In [10]:
# Load the retrosynthesis (uspto.prd-rct) model from its training directory.
# NOTE(review): fairseq's hub loader appears to join `checkpoint_file` (and a
# `data_name_or_path` starting with '.') onto `model_name_or_path`. The old
# relative paths only resolved because the root looks like '../<name>'.
# Absolute paths resolve identically whichever branch fairseq takes; verify
# against fairseq.hub_utils.from_pretrained.
retro_root_abs = os.path.abspath(retro_root_dir)

retroPred = TransformerModel.from_pretrained(
    retro_root_abs,
    checkpoint_file=os.path.join(retro_root_abs, 'checkpoints', 'checkpoint_best.pt'),
    data_name_or_path=os.path.join(retro_root_abs, 'data-bin', 'uspto.prd-rct'),
    bpe='subword_nmt',
    bpe_codes=os.path.join(retro_root_abs, 'preprocess', 'uspto', 'code'),
)

# Switch to evaluation mode before running inference.
retroPred.eval()

print('Load RETRO model OK!')

Load RETRO model OK!


In [11]:
# Top-1 retrosynthesis prediction for a single product SMILES.
ans = retroPred.translate(['CS(=O)(=O)OCCCBr'])

In [12]:
# Expected top-1 retrosynthesis prediction for the product above.
targets = ['CS(=O)(=O)Cl.OCCCBr']


In [13]:
# Top-1 sanity check: every retrosynthesis prediction must equal its expected
# reactant set. The assert makes a regression fail "Restart & Run All" instead
# of merely displaying False; the final expression keeps the True display.
assert ans == targets, f'top-1 mismatch: got {ans!r}, expected {targets!r}'
ans == targets

True

In [14]:
inputs = ['CS(=O)(=O)OCCCBr']

# Up to five distinct retrosynthesis predictions for the same product.
translate_top_k(retroPred, inputs, k=5)

[['CC(O)Br.CS(=O)(=O)Cl.O=C([O-])[O-]',
  'CC(O)Br.CS(=O)(=O)Cl',
  'CS(=O)(=O)Cl.OCCCBr',
  'CCC(O)Br.CS(=O)(=O)Cl.O=C([O-])O',
  'BrCCCBr.CS(=O)(=O)O']]