In [1]:
from rxnmapper import RXNMapper

# High-level rxnmapper entry point; used in the next cell to atom-map a batch of reactions.
rxn_mapper = RXNMapper()

In [3]:
# Example reaction SMILES ("reactants>>products") to atom-map.
rxns = [
    'CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F',
    'C1COCCO1.CC(C)(C)OC(=O)CONC(=O)NCc1cccc2ccccc12.Cl>>O=C(O)CONC(=O)NCc1cccc2ccccc12',
]
# One dict per reaction with the mapped SMILES ('mapped_rxn') and a 'confidence' score.
results = rxn_mapper.get_attention_guided_atom_maps(rxns)
results

[{'mapped_rxn': 'CN(C)C=O.F[c:5]1[n:6][cH:7][cH:8][cH:9][c:10]1[F:11].O=C([O-])[O-].[CH3:1][CH:2]([CH3:3])[SH:4].[K+].[K+]>>[CH3:1][CH:2]([CH3:3])[S:4][c:5]1[n:6][cH:7][cH:8][cH:9][c:10]1[F:11]',
  'confidence': 0.9565620983992522},
 {'mapped_rxn': 'C1COCCO1.CC(C)(C)[O:3][C:2](=[O:1])[CH2:4][O:5][NH:6][C:7](=[O:8])[NH:9][CH2:10][c:11]1[cH:12][cH:13][cH:14][c:15]2[cH:16][cH:17][cH:18][cH:19][c:20]12.Cl>>[O:1]=[C:2]([OH:3])[CH2:4][O:5][NH:6][C:7](=[O:8])[NH:9][CH2:10][c:11]1[cH:12][cH:13][cH:14][c:15]2[cH:16][cH:17][cH:18][cH:19][c:20]12',
  'confidence': 0.9704424699539764}]

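## Step by step

Instead of calling `RXNMapper` directly, load the underlying ALBERT masked-language model and its SMILES tokenizer, then inspect what the model predicts when each token of a reaction is masked in turn.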

In [36]:
import os

import torch
from transformers import AlbertForMaskedLM

from rxnmapper.tokenization_smiles import SmilesTokenizer

# Locate the ALBERT checkpoint shipped inside the installed rxnmapper package.
# `pkg_resources` is deprecated (and scheduled for removal from setuptools), so use
# `importlib.resources` where available and only fall back on older interpreters.
try:
    from importlib.resources import files as _package_files  # Python >= 3.9
except ImportError:  # Python < 3.9: deprecated API, still works there
    import pkg_resources

    model_path = pkg_resources.resource_filename(
        "rxnmapper", "models/transformers/albert_heads_8_uspto_all_1310k"
    )
else:
    # NOTE(review): str() yields a real filesystem path only when rxnmapper is installed
    # as a regular directory (not zipped) — true for normal pip installs.
    model_path = str(
        _package_files("rxnmapper") / "models" / "transformers" / "albert_heads_8_uspto_all_1310k"
    )
vocab_path = os.path.join(model_path, "vocab.txt")

# Attentions / hidden states are not needed for the masked-token analysis below.
# NOTE(review): confirm `output_past` is still a recognised config option in the
# installed transformers version.
model = AlbertForMaskedLM.from_pretrained(
    model_path,
    output_attentions=False,
    output_past=False,
    output_hidden_states=False,
)
# Tokenizer length limit matches the model's positional-embedding table size.
tokenizer = SmilesTokenizer(vocab_path, max_len=model.config.max_position_embeddings)

In [37]:
# Prefer the GPU when one is available; every tensor fed to the model must live on `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# nn.Module.to moves the parameters in place; the trailing ';' stops IPython from
# rendering the (very long) module repr as this cell's output.
model.to(device);


AlbertForMaskedLM(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(591, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=256, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertAttention(
                (query): Linear(in_features=256, out_features=256, bias=True)
                (key): Linear(in_features=256, out_features=256, bias=True)
                (value): Linear(in_features=256, out_features=256, bias=True)
  

In [7]:
from rxnmapper.smiles_utils import process_reaction

# Canonicalize every reaction SMILES before tokenizing.
rxns_canon = list(map(process_reaction, rxns))
# padding=True pads each reaction to the longest one so the batch stacks into one tensor.
encoded_ids = tokenizer.batch_encode_plus(rxns_canon, padding=True, return_tensors="pt")
# Show the ids next to their decoded tokens to check the tokenization round-trips.
encoded_ids, tokenizer.batch_decode(encoded_ids.input_ids)

({'input_ids': tensor([[12, 16, 16, 17, 16, 18, 34, 24, 16, 23, 17, 16, 18, 16, 22, 19, 24, 27,
          15, 20, 15, 15, 15, 25, 15, 20, 27, 24, 19, 22, 16, 17, 36, 18, 36, 24,
          51, 24, 51, 29, 16, 16, 17, 16, 18, 34, 15, 20, 25, 15, 15, 15, 15, 20,
          27, 13,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [12, 16, 20, 16, 19, 16, 16, 19, 20, 24, 16, 16, 17, 16, 18, 17, 16, 18,
          19, 16, 17, 22, 19, 18, 16, 19, 23, 16, 17, 22, 19, 18, 23, 16, 15, 20,
          15, 15, 15, 15, 21, 15, 15, 15, 15, 15, 20, 21, 24, 28, 29, 19, 22, 16,
          17, 19, 18, 16, 19, 23, 16, 17, 22, 19, 18, 23, 16, 15, 20, 15, 15, 15,
          15, 21, 15, 15, 15, 15, 15, 20, 21, 13]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [10]:
# Move every encoded field to the model's device before the forward pass.
parsed_input = {field: tensor.to(device) for field, tensor in encoded_ids.items()}
with torch.no_grad():  # inference only; no autograd graph needed
    output = model(**parsed_input)
# logits: one vocabulary-sized score vector per token of each reaction.
output, output.logits.shape

(MaskedLMOutput(loss=None, logits=tensor([[[ -7.5936, -15.5652, -16.5842,  ..., -15.5250, -15.0628, -16.4616],
          [ -8.6925, -14.5749,  -9.4798,  ..., -14.4005, -11.6564, -13.1309],
          [ -9.1922, -14.8548,  -9.7666,  ..., -16.3378, -13.1260, -15.1557],
          ...,
          [ -8.8833, -10.0049, -12.6157,  ..., -13.3744, -13.0116,  -8.9807],
          [-11.6483, -12.4890, -13.6977,  ..., -14.6326, -15.4447, -13.7454],
          [ -8.8414, -12.1751, -11.9319,  ..., -13.9214, -13.8918, -13.2628]],
 
         [[ -8.0817, -15.9391, -16.5254,  ..., -15.6952, -15.3414, -16.5301],
          [-11.4589, -15.8037, -10.1328,  ..., -14.1449, -11.3160, -15.0684],
          [-14.1538, -16.3130, -16.4539,  ..., -18.7450, -18.1558, -17.0857],
          ...,
          [ -9.7115, -13.8946, -15.2012,  ..., -16.1713, -16.7350, -14.5115],
          [ -4.6475, -13.9797, -13.7251,  ..., -15.7515, -13.5745, -13.5006],
          [-10.1286, -13.5812, -12.6343,  ..., -14.7465, -12.8332, -14.5790]

In [13]:
(tokenizer.mask_token_id, rxns[0])  # id used for [MASK], next to the raw first reaction

(14, 'CC(C)S.CN(C)C=O.Fc1cccnc1F.O=C([O-])[O-].[K+].[K+]>>CC(C)Sc1ncccc1F')

In [30]:
# Build one copy of the tokenized reaction per token position and mask the diagonal,
# so row i asks the model to predict token i from all the other tokens.
# (Rows 0 and lx-1 mask the [CLS] and [SEP] special tokens.)
one_rxn = rxns[0]
one_rxn_canon = process_reaction(one_rxn)
one_encoded_ids = tokenizer.encode(one_rxn_canon, return_tensors="pt")  # shape (1, lx)
lx = one_encoded_ids.shape[1]
squared_input = one_encoded_ids.repeat(lx, 1)  # shape (lx, lx)
squared_input.fill_diagonal_(tokenizer.mask_token_id)
squared_input, tokenizer.batch_decode(squared_input)

(tensor([[14, 16, 16,  ..., 20, 27, 13],
         [12, 14, 16,  ..., 20, 27, 13],
         [12, 16, 14,  ..., 20, 27, 13],
         ...,
         [12, 16, 16,  ..., 14, 27, 13],
         [12, 16, 16,  ..., 20, 14, 13],
         [12, 16, 16,  ..., 20, 27, 14]]),
 ['[MASK] C C ( C ) S. C N ( C ) C = O. F c 1 c c c n c 1 F. O = C ( [O-] ) [O-]. [K+]. [K+] >> C C ( C ) S c 1 n c c c c 1 F [SEP]',
  '[CLS] [MASK] C ( C ) S. C N ( C ) C = O. F c 1 c c c n c 1 F. O = C ( [O-] ) [O-]. [K+]. [K+] >> C C ( C ) S c 1 n c c c c 1 F [SEP]',
  '[CLS] C [MASK] ( C ) S. C N ( C ) C = O. F c 1 c c c n c 1 F. O = C ( [O-] ) [O-]. [K+]. [K+] >> C C ( C ) S c 1 n c c c c 1 F [SEP]',
  '[CLS] C C [MASK] C ) S. C N ( C ) C = O. F c 1 c c c n c 1 F. O = C ( [O-] ) [O-]. [K+]. [K+] >> C C ( C ) S c 1 n c c c c 1 F [SEP]',
  '[CLS] C C ( [MASK] ) S. C N ( C ) C = O. F c 1 c c c n c 1 F. O = C ( [O-] ) [O-]. [K+]. [K+] >> C C ( C ) S c 1 n c c c c 1 F [SEP]',
  '[CLS] C C ( C [MASK] S. C N ( C ) C = O. F c 1 c 

In [31]:
squared_input = squared_input.to(device=device)  # inputs must sit on the same device as the model

In [55]:
# The masked-copy batch has lx rows of length lx, so a single forward pass needs memory
# that grows roughly quadratically with reaction length. Run it in fixed-size chunks
# and collect the logits on the CPU.
MASK_BATCH_SIZE = 64  # lower this if the GPU still runs out of memory
with torch.no_grad():
    logits = torch.cat(
        [
            model(chunk).logits.to("cpu")
            for chunk in torch.split(squared_input, MASK_BATCH_SIZE)
        ]
    )
# Return cached GPU blocks from the forward passes (does nothing if CUDA was never initialised).
torch.cuda.empty_cache()

In [73]:
# Row i of the batch masked position i, so keep only the logits at (row i, position i).
mask_positions = torch.arange(logits.shape[0])
logits_for_masks = logits[mask_positions, mask_positions]
logits_for_masks.shape

torch.Size([56, 591])

In [83]:
# Per-position probability distribution over the vocabulary.
prob_for_masks = logits_for_masks.softmax(dim=-1)
# Sanity check: each row sums to 1, so the grand total equals the number of positions.
prob_for_masks.shape, prob_for_masks.sum()

(torch.Size([56, 591]), tensor(56.0000))

In [84]:
# Ten most probable vocabulary ids (and their probabilities) for every masked position.
top_k_calc = prob_for_masks.topk(k=10)
top_k_calc.indices, top_k_calc

(tensor([[ 21,  16,  15,  25,  20,  22,  17, 148,  75, 179],
         [ 16, 188, 209,  47,  37,  64, 166,  27, 146, 127],
         [ 16, 155,  15,  23,  33,  35, 146,  22, 166,  18],
         [ 17,  18,  16,  19,  38,  24,  30,  22, 100,  23],
         [ 16, 188, 249,  23,  19,  27, 299, 155, 209, 146],
         [ 18,  16,  19,  24,  23,  27,  34,  22,  37,  38],
         [ 34, 112,  48,  19,  23, 249, 173,  28,  87, 126],
         [ 24,  16,  18,  30,  38,  48,  22,  29,  15, 101],
         [ 16, 157, 146,  64,  92, 188, 196,  43,  24,  68],
         [ 23, 242,  16, 244, 121, 139,  61,  26, 155, 184],
         [ 17, 100,  24,  30,  29, 126,  18,  16,  19,  22],
         [ 16,  23, 146,  36,  61, 199, 154,  20, 210,  35],
         [ 18,  16,  24,  22,  17,  30,  20,  39,  29,  19],
         [ 16, 146, 164,  41,  35,  15,  27,  36,  23, 210],
         [ 22,  16,  24,  30,  99,  11,  38, 247,  17,  36],
         [ 19, 116,  23,  16,  61,  34, 121, 198,  24,  35],
         [ 24,  18,  30,

In [85]:

tokenizer.batch_decode(top_k_calc.indices.tolist(), clean_up_tokenization_spaces=False)  # one line of top-k tokens per position

['2 C c n 1 = ( [Hg] [Cu] [Bi]',
 'C [CH-] [Ge] [Si] Br [2H] [Ce] F [C+] [SiH2]',
 'C [CH2+] c N [C@H] [C@@H] [C+] = [Ce] )',
 '( ) C O # . ~ = [Mn] N',
 'C [CH-] [SeH] N O F [Cr+2] [CH2+] [Ge] [C+]',
 ') C O . N F S = Br #',
 'S [S-] I O N [SeH] [SH-] Cl [Zn] [Se]',
 '. C ) ~ # I = >> c [AlH]',
 'C [CH2-] [C+] [2H] [BH3-] [CH-] [AlH2-] 5 . [Cs+]',
 'N [N@@] C [N@] [NH+] [SH] [N-] 3 [CH2+] [As]',
 '( [Mn] . ~ >> [Se] ) C O =',
 'C N [C+] [O-] [N-] [Rh+2] [Mo] 1 [Pb+4] [C@@H]',
 ') C . = ( ~ 1 / >> O',
 'C [C+] [PH2] [N+] [C@@H] c F [O-] N [Pb+4]',
 '= C . ~ [Cu+2] [UNK] # [Al+] ( [O-]',
 'O [NH2+] N C [N-] S [NH+] [Ti+3] . [C@@H]',
 '. ) ~ N C [O-] # [N@@] [Se] Cl',
 'F Cl C N O [O-] Br I [F-] [SiH3]',
 'c n [CH2+] p [C+] [Fe-3] C [n+] [Pt-2] [Ge]',
 '1 [K] [SH] [NH4+] c [Rh] [Na] [Ag+] [Zn+] [NH+]',
 'c n s C [Ce+3] [nH] [CH2+] p 3 [Pb+4]',
 'c n [te] [CH2+] [Pb+4] [NH4+] 1 [Ce+3] [o+] [Fe-3]',
 'c n [o+] [CH2+] C [NH4+] [te] [CH-] [Rh+2] [C+]',
 'n c [o+] [te] [N@@] [n+] [nH] = N [Sc

In [104]:
# Rank (0 = most probable) of the true token at every masked position.
# Each row of t_ranks lists the vocabulary ids sorted by probability, so each true id
# appears exactly once per row and its column index is the rank.
t_probs, t_ranks = torch.sort(prob_for_masks, descending=True)
true_ids = one_encoded_ids[0]
ranks_correct = (t_ranks == true_ids.unsqueeze(1)).nonzero()[:, 1].tolist()
# NOTE(review): the first and last entries are the [CLS]/[SEP] positions; consider
# excluding them from any summary statistics.
ranks_correct


[379,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 15]

In [105]:
# Collect per-position statistics as plain Python lists.
topk_ids = top_k_calc.indices.tolist()
all_stats = {
    "topk_probs": top_k_calc.values.tolist(),
    "topk_tokens": [tokenizer.convert_ids_to_tokens(ids) for ids in topk_ids],
    "ranks": ranks_correct,
}

all_stats

{'topk_probs': [[0.11130529642105103,
   0.09437178075313568,
   0.09109871834516525,
   0.06392545253038406,
   0.047387197613716125,
   0.03893246129155159,
   0.03733941912651062,
   0.03398270159959793,
   0.033031564205884933,
   0.03241202235221863],
  [0.9999030828475952,
   4.965277184965089e-05,
   1.1889126653841231e-05,
   3.971201294916682e-06,
   3.1189997571345884e-06,
   2.3762270302540855e-06,
   2.3427471660397714e-06,
   1.769982645782875e-06,
   1.3530710702980286e-06,
   1.2868537169197225e-06],
  [0.9999681711196899,
   1.1352793080732226e-05,
   8.843916475598235e-06,
   1.8527322254158207e-06,
   1.4515687780658482e-06,
   9.642348004490486e-07,
   8.140411864587804e-07,
   8.041430987759668e-07,
   7.745834409433883e-07,
   6.737744797646883e-07],
  [0.9999972581863403,
   1.054706672221073e-06,
   6.787763027205074e-07,
   2.5513239165775303e-07,
   1.328741774386799e-07,
   8.70189396096066e-08,
   8.390105676880921e-08,
   7.94309187313047e-08,
   6.889070647