In [None]:
# Directory holding the trained checkpoint ("weights") and its hyper-parameters ("params.pickle").
model_dir = "/home/arvid/models/chemformers/2021-10-06/14:44:00/"
from Transformer import *
from Transformer import Seq2SeqTransformer
from Tools import *
from modules import *
from constants import *
from Reaction_dataset import *

# Hyper-parameters the model was trained with.
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open(model_dir + "params.pickle", "rb") as params_file:
    args = pickle.load(params_file)

DEVICE = "cuda:2"
torch.cuda.set_device(DEVICE)
# Report the name of the device that was actually selected, not always GPU 0.
current_device = torch.cuda.current_device()
print("Using device " + str(current_device) + "/" + str(torch.cuda.device_count())
      +", name: " + str(torch.cuda.get_device_name(current_device)))

chemFormer = Seq2SeqTransformer(num_encoder_layers=args["N_ENCODERS"],
                                num_decoder_layers=args["N_DECODERS"],
                                emb_size=args["EMBEDDING_SIZE"],
                                nhead=args["N_HEADS"],
                                src_vocab_size=SRC_VOCAB_SIZE,
                                tgt_vocab_size=TGT_VOCAB_SIZE,
                                dim_feedforward=args["DIM_FF"],
                                dropout=args["DROPOUT"],
                                DEVICE=DEVICE).to(DEVICE)

# map_location makes the load independent of the GPU the checkpoint was saved from.
chemFormer.load_state_dict(torch.load(model_dir + "/weights", map_location=DEVICE))

# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open("/home/arvid/data/USTPO_paper_5x/USTPO_5x_parsed.pickle", 'rb') as data_file:
    data = pickle.load(data_file)

# Debug switch: keep only the first 64 examples of each split for quick runs.
USE_DEBUG_SUBSET = True
if USE_DEBUG_SUBSET:
    data = {"train": data["train"][0:64], "eval": data["eval"][0:64]}

datasets = {}
dataloaders = {}
for split in ['eval', "train"]:
    datasets[split] = ReactionDataset(data=data,
                                      split=split,
                                      args=args)

    dataloaders[split] = DataLoader(datasets[split],
                                    batch_size=args["BATCH_SIZE"],
                                    shuffle=(split != 'test'),
                                    num_workers=8,
                                    pin_memory=False,
                                    drop_last=True)

# Inference only: switch the model to evaluation mode.
chemFormer.eval()




In [None]:
# Slot-based batched greedy decoding over one dataset split.
# A fixed number of batch columns ("slots") is kept busy: whenever a sequence
# finishes (EOS or length limit) its slot is released and refilled from the queue.
# The encoder is re-run on the whole batch after every refill; this could be more efficient.
split = "train"

# Queue of dataset indices that still have to be inferred.
Q = set(range(len(datasets[split])))
# Maps dataset index -> {"pred": ..., "tgt": ...}, or None when no EOS was produced in time.
inferred = {}
print(len(datasets[split]))

# Number of sequences decoded in parallel.
capacity = 32
# Token-id buffers of shape [MAX_SEQ_LEN, capacity]; integer dtype because they hold token ids.
src = torch.zeros([MAX_SEQ_LEN, capacity], dtype=torch.long, device=DEVICE)
tgt = torch.zeros([MAX_SEQ_LEN, capacity], dtype=torch.long, device=DEVICE)
# Maps slot index -> [dataset index, position of the last token written to tgt].
element_tracker = {}
# Slots that currently hold no active sequence.
free_ixs_tracker = set(range(capacity))
j = 0

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    # Continue while items are queued or at least one slot is still decoding.
    while Q or len(free_ixs_tracker) != capacity:
        j += 1
        print(j)

        # Refill free slots; stop as soon as either the queue or the free slots run out
        # (popping from an empty queue would raise KeyError).
        while Q and free_ixs_tracker:
            free_ix = free_ixs_tracker.pop()
            data_ix = Q.pop()
            element_tracker[free_ix] = [data_ix, 0]
            product_tokens = torch.tensor(datasets[split][data_ix]["ps"]).to(DEVICE)
            src[:, free_ix] = product_tokens
            tgt[:, free_ix] = PAD_IDX
            tgt[0, free_ix] = BOS_IDX

        # Masks for the encoder pass.
        src_mask, tgt_mask, src_key_padding_mask, tgt_key_padding_mask = create_mask(src=src,
                                                                                     tgt=tgt,
                                                                                     DEVICE=DEVICE)
        # Encode the (partially refreshed) source batch.
        memory = chemFormer.encode(src=src,
                                   src_mask=src_mask,
                                   src_key_padding_mask=src_key_padding_mask)

        # Decode one token per active slot per iteration until some slot finishes.
        while True:
            # tgt changed in the previous step, so the masks have to be rebuilt.
            src_mask, tgt_mask, src_key_padding_mask, tgt_key_padding_mask = create_mask(src=src,
                                                                                         tgt=tgt,
                                                                                         DEVICE=DEVICE)

            logits = chemFormer.decode(tgt=tgt,
                                       memory=memory,
                                       tgt_mask=tgt_mask,
                                       memory_key_padding_mask=src_key_padding_mask,
                                       tgt_key_padding_mask=tgt_key_padding_mask)
            probs = chemFormer.generator(logits)

            should_break = False
            for k, tracked in element_tracker.items():
                data_ix, pos = tracked
                # Slot still holds an already finished sequence (not refilled yet).
                if data_ix in inferred:
                    continue

                # Length limit reached without EOS: record a failure and release the slot.
                if pos == MAX_SEQ_LEN - 1:
                    inferred[data_ix] = None
                    free_ixs_tracker.add(k)
                    should_break = True
                    continue

                # Greedy choice based on the output at the last written position.
                next_word = torch.argmax(probs[pos, k, :]).item()
                tracked[1] = pos + 1
                tgt[pos + 1, k] = next_word

                # Sequence finished: store prediction and ground truth, release the slot.
                if next_word == EOS_IDX:
                    seq = tgt[:, k]
                    # Drop padding, then strip BOS and EOS.
                    seq = tokens_to_smiles(seq[torch.where(seq != PAD_IDX)][1:-1].cpu().numpy())
                    gt = datasets[split][data_ix]["rs"]
                    gt = tokens_to_smiles(gt[np.where(gt != PAD_IDX)][1:-1])

                    inferred[data_ix] = {"pred": seq, "tgt": gt}
                    free_ixs_tracker.add(k)
                    should_break = True

            # Go back to the outer loop to refill slots and re-encode.
            if should_break:
                break
    

In [None]:
from rdkit import RDLogger
# Silence RDKit parser warnings triggered by invalid predicted SMILES.
RDLogger.DisableLog('rdApp.*')

# TP counts predictions whose canonical SMILES set equals the target's;
# FP counts every other outcome (no EOS, unparsable SMILES, mismatch).
TP = 0
FP = 0
for k, v in inferred.items():
    # None marks a sequence that hit the length limit without producing EOS.
    if v is None:
        FP += 1
        continue
    out_smiles = v["pred"]
    tgt_smiles = v["tgt"]
    out_mols = [Chem.MolFromSmiles(e) for e in out_smiles]
    tgt_mols = [Chem.MolFromSmiles(e) for e in tgt_smiles]
    # MolFromSmiles returns None for unparsable SMILES; MolToSmiles(None) would raise,
    # so an unparsable prediction or target is counted as a miss.
    if None in out_mols or None in tgt_mols:
        FP += 1
    else:
        out_smiles_canonical = [Chem.MolToSmiles(e) for e in out_mols]
        tgt_smiles_canonical = [Chem.MolToSmiles(e) for e in tgt_mols]

        # Order-insensitive comparison of canonical SMILES.
        if set(out_smiles_canonical) == set(tgt_smiles_canonical):
            TP += 1
        else:
            FP += 1

# Fraction of exactly matched examples; guard against an empty result dict.
n_scored = TP + FP
if n_scored == 0:
    print("No inferred examples to score.")
else:
    print(TP/n_scored)

In [None]:
# Display the raw inference results: dataset index -> {"pred": ..., "tgt": ...},
# or None for sequences that reached MAX_SEQ_LEN without emitting EOS.
inferred

In [None]:
from Tools import tokens_to_smiles
# greedy_decode is used below to spot-check the model on a single example.
def greedy_decode(model, src, src_mask, max_len, start_symbol, DEVICE):
    """Greedily decode one target sequence for a single source sequence.

    Args:
        model: object exposing ``encode(src, src_mask)``,
            ``decode(tgt, memory, tgt_mask)`` and ``generator(out)``.
        src: 1-D tensor of source token ids; a trailing batch dimension of
            size 1 is added here.
        src_mask: source attention mask passed to ``model.encode``.
        max_len: maximum number of tokens in the result, including the
            start symbol.
        start_symbol: token id the decoded sequence starts with.
        DEVICE: device the computation runs on.

    Returns:
        1-D numpy array of token ids, starting with ``start_symbol`` and
        ending with ``EOS_IDX`` unless ``max_len`` was reached first.
    """
    src = src.to(DEVICE).unsqueeze(-1)
    src_mask = src_mask.to(DEVICE)

    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        # The encoder output does not change during decoding, so compute and move it once.
        memory = model.encode(src, src_mask).to(DEVICE)
        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)

        for _ in range(max_len - 1):
            tgt_mask = (generate_square_subsequent_mask(ys.size(0), DEVICE)
                        .type(torch.bool))
            out = model.decode(ys, memory, tgt_mask)
            probs = model.generator(out)
            # Highest-scoring token at the last decoded position.
            next_word = torch.argmax(probs, dim=2)[-1, :].squeeze().item()

            # Append with ys' own dtype/device so the sequence is never promoted by src's dtype.
            next_token = torch.full((1, 1), next_word, dtype=ys.dtype, device=ys.device)
            ys = torch.cat([ys, next_token], dim=0)
            if next_word == EOS_IDX:
                break

    return ys.squeeze(-1).cpu().numpy()

# Spot-check: greedily decode one random evaluation example and compare with its target.
# Sample within the actual dataset size; a fixed upper bound can exceed it.
i = np.random.randint(len(datasets["eval"]))

data = datasets["eval"][i]
src = torch.tensor(data["ps"]).to(dtype=torch.int64)
tgt = data["rs"]
# Drop padding, then strip the first and last remaining token (BOS/EOS).
tgt = tgt[np.where(tgt != PAD_IDX)][1:-1]
src = src[torch.where(src != PAD_IDX)]
num_tokens = src.shape[0]
# All-False mask: no source position is masked out.
src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
decoded_tokens = greedy_decode(chemFormer,
                               src,
                               src_mask,
                               max_len=MAX_SEQ_LEN,
                               start_symbol=BOS_IDX,
                               DEVICE=DEVICE).flatten()
# Strip BOS; strip the last token only when it really is EOS (decoding may stop at max_len).
if decoded_tokens[-1] == EOS_IDX:
    out_tokens = decoded_tokens[1:-1]
else:
    out_tokens = decoded_tokens[1:]
out_smiles = tokens_to_smiles(out_tokens)
tgt_smiles = tokens_to_smiles(tgt)
print(out_smiles)
print(tgt_smiles)
out_mols = [Chem.MolFromSmiles(e) for e in out_smiles]
tgt_mols = [Chem.MolFromSmiles(e) for e in tgt_smiles]
# MolFromSmiles returns None for unparsable SMILES; MolToSmiles(None) would raise.
if None in out_mols or None in tgt_mols:
    print("FALSE")
else:
    out_smiles_canonical = [Chem.MolToSmiles(e) for e in out_mols]
    tgt_smiles_canonical = [Chem.MolToSmiles(e) for e in tgt_mols]
    print(out_smiles_canonical)
    print(tgt_smiles_canonical)

    # Order-insensitive comparison of canonical SMILES.
    if set(out_smiles_canonical) == set(tgt_smiles_canonical):
        print("TRUE")
    else:
        print("FALSE")
print()
