In [1]:
import gc
from tqdm import tqdm
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from itertools import chain

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.optim import AdamW

from mlm_pytorch.mlm_pytorch.mlm_pytorch import MLM
from x_transformers.x_transformers import TransformerWrapper, Encoder, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

import codecs
from SmilesPE.pretokenizer import atomwise_tokenizer
from SmilesPE.tokenizer import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class USPTOIFT(Dataset):
    
    def __init__(self, data_dir: str='data/uspto_IFT', split: str='val', to_gen: int=-1, extra: str='') -> None:

        # target is the reactant
        with open(f'{data_dir}/{split}/reactants.txt', 'r') as f:
            self.reactants = f.read().splitlines()
        self.reactants = [r.split(' ') for r in self.reactants]

        # source or input is the product
        with open(f'{data_dir}/{split}/products.txt', 'r') as f:
            self.products = f.read().splitlines()
        self.products = [p.split(' ') for p in self.products]

        # verify that the dataset is consistent
        assert len(self.reactants) == len(self.products), 'Mismatched length of reactants and products'
        self.to_gen = to_gen if to_gen > 0 else len(self.reactants)

        # vocab and tokenizer
        with open(f'{data_dir}/vocab{extra}.txt', 'r') as f:
            self.token_decoder = f.read().splitlines()
        self.token_encoder = {t: i for i, t in enumerate(self.token_decoder)}

        # sanity check the tokenizer
        print(f'Performing sanity check on vocab and tokenizer...')
        reactant_set = set(chain.from_iterable(self.reactants))
        product_set = set(chain.from_iterable(self.products))
        all_chars = reactant_set.union(product_set)
        assert all_chars <= set(self.token_encoder.keys()), "Tokenizer is not consistent with the dataset"

        # additional information
        self.vocab_size = len(self.token_decoder)
        self.pad_token_id = self.token_encoder['<pad>']
        self.mask_token_id = self.token_encoder['<mask>']
        self.mask_ignore_token_ids = [v for k, v in self.token_encoder.items() if '<' in k and '>' in k]

    def __len__(self):
        return self.to_gen
    
    def __getitem__(self, idx):
        r, p = self.reactants[idx], self.products[idx]
        num_reactants, num_products = r.count('.')+1, p.count('.')+1

        r = [f'<{num_reactants}>'] + ['<sos>'] + r + ['<eos>']
        p = [f'<{num_products}>']  + ['<sos>'] + p + ['<eos>']
        
        r = [self.token_encoder[t] for t in r]
        p = [self.token_encoder[t] for t in p]

        src_mask = [True] * len(p)

        r, p, src_mask = torch.tensor(r), torch.tensor(p), torch.tensor(src_mask).bool()

        return r, p, src_mask
    
    def collate_fn(self, data):

        # unpack the input data
        r, p, src_mask = zip(*data)
        
        # pad the encoder stuff
        p = pad_sequence(p, batch_first=True, padding_value=self.pad_token_id)
        src_mask = pad_sequence(src_mask, batch_first=True, padding_value=False).bool()
        
        # pad the decoder stuff
        r = pad_sequence(r, batch_first=True, padding_value=self.pad_token_id)
        
        return r, p, src_mask

In [6]:
val_dataset = USPTOIFT(split='train', to_gen=-1)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=val_dataset.collate_fn)

Performing sanity check on vocab and tokenizer...


In [7]:
r, p, src_mask = next(iter(val_loader))

In [8]:
r.shape, p.shape, src_mask.shape

(torch.Size([32, 80]), torch.Size([32, 60]), torch.Size([32, 60]))

In [2]:
from Levy.levenshteinaugment.levenshtein import Levenshtein_augment

#Supress warnings from RDKit
from rdkit import rdBase
rdBase.DisableLog('rdApp.error')
rdBase.DisableLog('rdApp.warning')

In [3]:
df = pd.read_pickle('data/uspto50/processed.pickle')
df.head()

Unnamed: 0,reactants_mol,products_mol,reaction_type,set,num_reacts,num_prods,ratio
0,"[CS(=O)(=O)OC[C@H]1CCC(=O)O1, Fc1ccc(Nc2ncnc3c...",[O=C1CC[C@H](CN2CCN(CCOc3cc4ncnc(Nc5ccc(F)c(Cl...,<RX_1>,train,2,1,1.121951
1,[COC(=O)c1cc(CCCc2cc3c(=O)[nH]c(N)nc3[nH]2)cs1],[Nc1nc2[nH]c(CCCc3csc(C(=O)O)c3)cc2c(=O)[nH]1],<RX_6>,train,1,1,1.045455
2,"[CC1(C)OB(B2OC(C)(C)C(C)(C)O2)OC1(C)C, FC(F)(F...",[CC1(C)OB(c2cccc(Nc3nccc(C(F)(F)F)n3)c2)OC1(C)C],<RX_9>,train,2,1,1.384615
3,[CC(C)(C)OC(=O)NCC(=O)CCC(=O)OCCCC(=O)OCc1ccccc1],[CC(C)(C)OC(=O)NCC(=O)CCC(=O)OCCCC(=O)O],<RX_6>,train,1,1,1.318182
4,"[Fc1cc2c(Cl)ncnc2cn1, NC1CCCCCC1]",[Fc1cc2c(NC3CCCCCC3)ncnc2cn1],<RX_1>,train,2,1,1.052632


In [6]:
augmenter = Levenshtein_augment(source_augmentation=2, randomization_tries=1000)

def augment(reactants, products):
    reactants = sorted(reactants, key=lambda x: len(x), reverse=True)
    products  = sorted(products, key=lambda x: len(x), reverse=True)
    
    new_reactants, new_products, all_score = [], [], []
    for i in range(1, len(reactants)+1):
        reactant, product = '.'.join(reactants[:i]), '.'.join(products[:i])
    
        pairs = augmenter.levenshtein_pairing(reactant, product)
        augmentations = augmenter.sample_pairs(pairs)
    
        for new_reactant, new_product, score in augmentations:
            new_reactants.append(new_reactant)
            new_products.append(new_product)
            all_score.append(score)
    
    return new_reactants, new_products, all_score

In [11]:
# random_idx = np.random.randint(0, len(df))
# reactants, products = df['reactants_mol'].iloc[random_idx], df['products_mol'].iloc[random_idx]

new_df = df[df['reactants_mol'].apply(len) >= 3]

tok_reactants = []
tok_products  = []

# for i, (reactants, products) in enumerate(tqdm(zip(new_df['reactants_mol'], new_df['products_mol']), total=len(new_df))):
for i, (reactants, products) in enumerate(tqdm(zip(df['reactants_mol'], df['products_mol']), total=len(df))):
    new_reactants, new_products, score = augment(reactants, products)
    # print(f'{".".join(reactants)} -> {".".join(products)}')
    for reactant, product, sc in zip(new_reactants, new_products, score):
        # print(f'{sc:.2f}: {reactant} -> {product}')
        tok_reactants.append(atomwise_tokenizer(reactant))
        tok_products.append(atomwise_tokenizer(product))
    # print()
    if i >= 3:
        break

  0%|          | 0/50037 [00:00<?, ?it/s]

  0%|          | 3/50037 [00:03<17:37:03,  1.27s/it]


In [12]:
rand_idx = 6
print(f"{' '.join(tok_reactants[rand_idx])} -> {' '.join(tok_products[rand_idx])}")

O 1 C ( C ) ( C ) C ( C ) ( C ) O B 1 B 1 O C ( C ) ( C ) C ( C ) ( C ) O 1 -> O 1 C ( C ) ( C ) C ( C ) ( C ) O B 1 c 1 c c c c ( N c 2 n c c c ( C ( F ) ( F ) F ) n 2 ) c 1


In [2]:
with open('data/zinc/new_vocab.txt') as f:
    zinc_vocab = [x.strip() for x in f.readlines()]
print(len(zinc_vocab), zinc_vocab)

with open('data/rooted/vocab.txt') as f:
    rooted_vocab = [x.strip() for x in f.readlines()]
print(len(rooted_vocab), rooted_vocab)

145 ['C', 'S', '(', '=', 'O', ')', '[C@H]', '1', '.', 'F', 'c', 'N', '2', 'n', '3', '4', 'Cl', '[nH]', 's', 'B', 'Br', '[N+]', '[O-]', '[C@@H]', '-', '5', 'o', '/', '[Li]', '[N-]', '#', '[C@@]', '[Si]', 'I', 'P', '[Mg+]', '[P+]', '[S-]', '[Se]', '[C@]', '\\', '[Sn]', '[NH4+]', '[SiH2]', '[NH3+]', '[K]', '[SiH]', '[Zn+]', '6', '[C-]', '[Cu]', '[n+]', '[S@@]', '[PH]', '[se]', '[BH3-]', '[SH]', '[SnH]', '[S@]', '[BH-]', '[S+]', '[PH2]', '7', '[OH-]', '[NH2+]', '[s+]', '[PH4]', '[Pt]', '[Cl-]', '[Zn]', '[n-]', '[Mg]', '[NH+]', '[Br-]', '[NH-]', '[B-]', '[Fe]', '[Pd]', '[Cl+3]', 'p', '[Pb]', '[SiH3]', '[I+]', '8', '9', '[N@+]', '[N@@+]', '[C]', '[N]', '[P@]', '[CH2-]', '[CH]', '[S@@+]', '[CH-]', '[S@@H]', '[O]', '[CH2]', '[P@@]', '[cH-]', '[S@+]', '[P@@H]', '[c-]', '[P@H]', '[F+]', '[N@@H+]', '[SH2]', '[11CH3]', '[P@@+]', '[o+]', '[S]', '[B@-]', '[SH3]', '[18F]', '[B@@H-]', '[125I]', '[124I]', '[P@+]', '[123I]', '[CH+]', '[BH2-]', '[18OH]', '[B@H-]', '%10', '[C+]', '[IH2]', '[O+]', '[Sn+2]'

In [3]:
special_tokens = rooted_vocab[-11:]

In [4]:
set1 = set(zinc_vocab[:-11])
set2 = set(rooted_vocab[:-11])
print(len(set1), len(set2), len(set1 & set2))

134 72 72


In [5]:
final_vocab = set1.union(set2)
final_vocab = list(final_vocab) + special_tokens

In [6]:
print(len(final_vocab), final_vocab)

145 ['[N@@+]', '[S+]', '[125I]', '%10', '[SiH3]', '[18OH]', 'P', '[Fe]', '[B@@H-]', '[Cl-]', '[Cl+3]', '[BH3-]', '[Pb]', '[Se]', '[SH]', '[Mg]', '[SiH]', '[S@@]', 'c', 'F', 's', 'N', '[CH2-]', '/', '[3H]', '[C+]', '[O+]', '[P@H]', '[Br-]', 'O', '[C]', '[CH2]', '[Si]', '[B@@-]', '[NH+]', '[123I]', 'C', '[Pt]', '7', '[CH]', '[o+]', '[B@H-]', '[O-]', '[O]', '[BH2-]', '[I+]', '[N+]', '[Sn+2]', 'B', '[F+]', '[nH]', '-', '2', '[N-]', '9', '3', '4', 'n', '[Zn+]', '\\', '[P+]', '[C-]', '[S@@H]', '[P@]', '[NH4+]', 'Cl', '.', '[PH2]', 'S', '[s+]', 'p', '[124I]', '[C@@H]', '[SH2]', '[SH3]', '[SiH2]', '[11CH3]', '[pH]', '[Br+]', '[N@+]', '5', '[Li]', '[Cu]', '[S@@+]', '[K]', '[Sn]', '[SnH2]', '[NH2+]', '=', '[se]', '[SnH]', '[n+]', '[B@-]', '[S@]', '1', '[P@+]', '[c-]', '[S]', '[CH+]', '[Mg+]', '[C@@]', '[N@@H+]', '[cH-]', '[PH]', '6', '[B-]', '[OH-]', '[18F]', ')', '[PH4]', '[CH-]', '[BH-]', '[S-]', '[N]', 'I', '(', '8', '[S@+]', '[P@@]', '[Sn+3]', 'o', '[Zn]', '[C@]', '[P@@H]', '[NH-]', 'Br', '[

In [7]:
with open('data/final_vocab.txt', 'w') as f:
    for item in final_vocab:
        f.write("%s\n" % item)

In [None]:
def editDistDP(str1, str2, m, n):
    # Create a table to store results of subproblems
    dp = [[0 for x in range(n + 1)] for x in range(m + 1)]
 
    # Fill d[][] in bottom up manner
    for i in range(m + 1):
        for j in range(n + 1):
 
            # If first string is empty, only option is to
            # insert all characters of second string
            if i == 0:
                dp[i][j] = j    # Min. operations = j
 
            # If second string is empty, only option is to
            # remove all characters of second string
            elif j == 0:
                dp[i][j] = i    # Min. operations = i
 
            # If last characters are same, ignore last char
            # and recur for remaining string
            elif str1[i-1] == str2[j-1]:
                dp[i][j] = dp[i-1][j-1]
 
            # If last character are different, consider all
            # possibilities and find minimum
            else:
                dp[i][j] = 1 + min(dp[i][j-1],        # Insert
                                   dp[i-1][j],        # Remove
                                   dp[i-1][j-1])    # Replace
    # ans = dp[m][n]
    return dp

In [None]:
r_smiles = []
p_smiles = []

for split in ['train', 'val', 'test']:
    with open(f'data/rooted/{split}/src-{split}.txt') as f:
        src = f.read().splitlines()
    with open(f'data/rooted/{split}/tgt-{split}.txt') as f:
        tgt = f.read().splitlines()

    for s in tqdm(src, desc=f'{split}'):
        p_smiles.append(''.join(s.split()))
    for t in tqdm(tgt, desc=f'{split}'):
        r_smiles.append(''.join(t.split()))
    # for t in tqdm(tgt, desc=f'{split}'):
    #     for s_ind in ''.join(t.split()).split('.'):
    #         smiles.append(s_ind)

In [None]:
print(len(r_smiles), r_smiles[3234:3238])
print(len(p_smiles), p_smiles[3234:3238])

In [None]:
dp = editDistDP(r_smiles[3234], p_smiles[3234], len(r_smiles[3234]), len(p_smiles[3234]))
dp[-1][-1]

In [None]:
sns.set()
plt.figure(figsize=(10, 10))
sns.heatmap(dp, xticklabels=list(p_smiles[3234]), yticklabels=list(r_smiles[3234]), cbar=False)
plt.tick_params(axis='both', which='major', labelsize=10, labelbottom = False, bottom=False, top = False, labeltop=True)
plt.show()

In [None]:
for vocab_size in ['', '100', '250', '750', '2000']:
    vocab = set()

    for split in ['train', 'val', 'test']:
        src_name = f'data/rooted/{split}/src-{split}_{vocab_size}.txt' if vocab_size != '' else f'data/rooted/{split}/src-{split}.txt'
        tgt_name = f'data/rooted/{split}/tgt-{split}_{vocab_size}.txt' if vocab_size != '' else f'data/rooted/{split}/tgt-{split}.txt'
        with open(src_name) as f:
            src = f.read().splitlines()
        with open(tgt_name) as f:
            tgt = f.read().splitlines()

        for s in tqdm(src, desc=f'{vocab_size}-{split}'):
            vocab.update(s.split())
        for t in tqdm(tgt, desc=f'{vocab_size}-{split}'):
            vocab.update(t.split())

    extra = ['<unk>', '<sos>', '<eos>', '<mask>', '<sum_pred>', '<sum_react>', '<0>', '<1>', '<2>', '<3>', '<pad>']
    vocab = list(vocab) + extra

    with open(f'data/rooted/vocab{vocab_size}.txt', 'w') as f:
        f.write('\n'.join(vocab))
    
    break

In [None]:
print(len(vocab))
print(vocab[-20:])

In [None]:
spe_vob = codecs.open('data/vocab_pairs/SPE_vocab_pairs_2000.txt')
spe = SPE_Tokenizer(spe_vob)

In [None]:
smi = ''.join('COC(=O) [C@H](C CCCN )N C(=O)N c1cc(OC )cc(C (C)(C)C)c1 O'.split())
smi

In [None]:
spe.tokenize(smi)

In [None]:
torch.cuda.empty_cache()
gc.collect()

In [None]:
torch.set_float32_matmul_precision('medium')

In [None]:
class Zinc(Dataset):
    def __init__(self, data_dir: str='/scratch/arihanth.srikar', split: str='train', to_gen: int=-1):
        extra = ''
        
        # dataset files
        # df = pd.read_pickle(f'{data_dir}/x001{extra}.pickle')
        df = pd.read_csv(f'{data_dir}/x001.csv')
        df = df[df['set'] == split].copy()
        
        # read entire dataset and convert to list
        self.smiles = df['smiles'].tolist()
        
        # clear memory
        del df
        
        # load specified number of samples
        self.to_gen = to_gen if to_gen > 0 else len(self.smiles)
        
        # token encoder and decoder
        with open(f'{data_dir}/vocab{extra}.txt', 'r') as f:
            self.token_decoder = f.read().splitlines()
        self.token_encoder = {k: v for v, k in enumerate(self.token_decoder)}

        self.vocab_size = len(self.token_decoder)
        self.pad_token_id = self.token_encoder['<pad>']

    def __len__(self):
        return self.to_gen

    def __getitem__(self, idx):
        
        # pick random indices if not utilizing entire dataset
        if self.to_gen != len(self.smiles):
            idx = torch.randint(0, len(self.smiles), (1,)).item()
        
        # treat the smiles as products
        p = self.smiles[idx]
        p = [self.token_encoder[tok] for tok in atomwise_tokenizer(p)]
        
        # append end of products token
        p = [self.token_encoder['<sop>']] + p + [self.token_encoder['<eop>']]
        mask = [1] * len(p)
        
        return torch.tensor(p), torch.tensor(mask)


    def collate_fn(self, batch):
        smiles, mask = zip(*batch)
        smiles = torch.nn.utils.rnn.pad_sequence(smiles, batch_first=True, padding_value=self.token_encoder['<pad>'])
        mask = (smiles != self.token_encoder['<pad>']).bool()
        return smiles, mask

In [None]:
train_dataset = Zinc(split='train', to_gen=100*384)

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=128, collate_fn=train_dataset.collate_fn, shuffle=True, num_workers=8, pin_memory=True, prefetch_factor=4)

In [None]:
train_dataset.vocab_size

In [None]:
smiles, mask = next(iter(train_dataloader))

In [None]:
encoder_transformer = TransformerWrapper(
    num_tokens = train_dataset.vocab_size,
    max_seq_len = 512,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rel_pos_bias = True
    )
)

In [None]:
decoder = TransformerWrapper(
    num_tokens = train_dataset.vocab_size,
    max_seq_len = 512,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rel_pos_bias = True,
        cross_attend = True
    )
)

In [None]:
encoder = MLM(
    encoder_transformer,
    mask_token_id = train_dataset.token_encoder['<mask>'],          # the token id reserved for masking
    pad_token_id = train_dataset.token_encoder['<pad>'],           # the token id for padding
    mask_prob = 0.15,           # masking probability for masked language modeling
    replace_prob = 0.90,        # ~10% probability that token will not be masked, but included in loss, as detailed in the epaper
    mask_ignore_token_ids = []  # other tokens to exclude from masking, include the [cls] and [sep] here
)

In [None]:
decoder = AutoregressiveWrapper(
    decoder,
    pad_value = train_dataset.token_encoder['<pad>'],
    ignore_index=train_dataset.token_encoder['<pad>'],
)

In [None]:
smiles, mask = next(iter(train_dataloader))
smiles.shape, mask.shape

In [None]:
encoder = encoder.cuda()
smiles, mask = smiles.cuda(), mask.cuda()

In [None]:
with torch.no_grad():
    logits, enc, loss = encoder(smiles, mask=mask, return_logits_and_embeddings=True)

In [None]:
logits.shape, enc.shape, loss.item()

In [None]:
decoder = decoder.cuda()

In [None]:
with torch.no_grad():
    decoder_logits, decoder_loss = decoder(smiles, context=enc, context_mask=mask)

In [None]:
decoder_logits.shape, decoder_loss.item()

In [None]:
device = 'cuda'
optimizer = AdamW(list(encoder.parameters())+list(decoder.parameters()), lr=1e-4)

encoder.to(device)
decoder.to(device)
for epoch in range(10):
    avg_encoder_loss, avg_decoder_loss = 0, 0
    with tqdm(train_dataloader) as pbar:
        pbar.set_description(f'Epoch {epoch+1}')
        for i, (smiles, mask) in enumerate(pbar):
            smiles, mask = smiles.to(device), mask.to(device)
            
            encoder_logits, enc, encoder_loss = encoder(smiles, mask=mask, return_logits_and_embeddings=True)
            decoder_logits, decoder_loss = decoder(smiles, context=enc, context_mask=mask)

            encoder_loss.backward()
            # decoder_loss.backward()
            
            optimizer.step()
            optimizer.zero_grad()
            
            avg_encoder_loss += encoder_loss.item()
            avg_decoder_loss += decoder_loss.item()
            
            pbar.set_postfix({
                'encoder_loss': encoder_loss.item(),
                'decoder_loss': decoder_loss.item(),
                'avg_encoder_loss': avg_encoder_loss/(i+1),
                'avg_decoder_loss': avg_decoder_loss/(i+1)
                })