In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
import codecs
from SmilesPE.pretokenizer import atomwise_tokenizer
from SmilesPE.tokenizer import *

In [3]:
from Levy.levenshteinaugment.levenshtein import Levenshtein_augment
from rdkit import Chem

# Suppress warnings from RDKit
from rdkit import rdBase
from rdkit import RDLogger
# Disable every RDKit log channel matching the 'rdApp.*' pattern.
RDLogger.DisableLog('rdApp.*')

In [22]:
import Chemformer.molbart.util as util
from Chemformer.molbart.data.datasets import ZincSlice
from Chemformer.molbart.data.datamodules import MoleculeDataModule
from pathlib import Path

In [10]:
class Zinc(ZincSlice):
    """ZINC dataset built from one CSV file or from a directory of CSV files.

    Args:
        data_path: Path to a single CSV file, or to a directory whose entries
            are each read with ``pd.read_csv``.
        max_files: Only used when ``data_path`` is a directory. Upper bound on
            the number of files to load; ``None`` loads every file. The
            default of 1 keeps the previous behaviour of this class, which
            stopped after the first file, but makes that limit explicit and
            overridable.

    Raises:
        ValueError: If ``max_files`` is smaller than 1, or the directory
            contains nothing to read.
    """

    def __init__(self, data_path, max_files=1):
        path = Path(data_path)

        # A directory is read file by file (at most ``max_files`` of them);
        # anything else is treated as a single CSV file.
        if path.is_dir():
            df = self._read_dir_df(path, max_files=max_files)
        else:
            df = pd.read_csv(path)

        super().__init__(df)

    def _read_dir_df(self, path, max_files=1):
        """Read up to ``max_files`` entries of ``path`` into one dataframe.

        Entries are visited in sorted order because ``Path.iterdir`` yields
        them in arbitrary order; sorting makes the selected subset
        reproducible across runs and machines.
        """
        if max_files is not None and max_files < 1:
            raise ValueError(f"max_files must be >= 1 or None, got {max_files}")

        files = sorted(path.iterdir())
        if max_files is not None:
            files = files[:max_files]
        if not files:
            # pd.concat would fail on an empty list with a less helpful message.
            raise ValueError(f"No files to read in directory: {path}")

        # Iterating a list (not a generator) lets tqdm show a real total.
        dfs = [pd.read_csv(f_name) for f_name in tqdm(files)]
        zinc_df = pd.concat(dfs, ignore_index=True, copy=False)
        return zinc_df

In [11]:
# Load the tokeniser from the project vocabulary file.
# NOTE(review): 272 is passed straight to util.load_tokeniser; presumably it is
# the index at which chemical tokens start in the vocabulary — TODO confirm.
print("Building tokeniser...")
tokeniser = util.load_tokeniser('Chemformer/my_vocab.txt', 272)
print("Finished tokeniser.")

Building tokeniser...
Finished tokeniser.


In [12]:
# Build the ZINC dataset from a directory of CSV files.
# NOTE(review): this is an absolute, machine-specific path; a configurable
# data directory would make the notebook portable.
zinc_dataset = Zinc('/scratch/arihanth.srikar/zinc/')

0it [00:12, ?it/s]


In [25]:
# Wrap the ZINC dataset and tokeniser in a Chemformer data module, then run
# its setup step so the dataloaders can be requested.
print("Building data module...")
dm = MoleculeDataModule(
        zinc_dataset,
        tokeniser,
        8,    # batch size (the recorded run prints "Using a batch size of 8.")
        512,  # presumably the maximum sequence length — TODO confirm
        task='aug',
        val_idxs=zinc_dataset.val_idxs,
        test_idxs=zinc_dataset.test_idxs,
        train_token_batch_size=None,
        num_buckets=24,
        unified_model=False,
    )
dm.setup()

Building data module...
Using a batch size of 8.
Using molecule data module with augmentations.


In [26]:
# Get the dataloader for the test split from the data module.
test_loader = dm.test_dataloader()

In [28]:
# Pull the first batch from the test loader so its structure can be inspected.
sample = next(iter(test_loader))

In [30]:
# List the fields contained in one batch dictionary.
sample.keys()

dict_keys(['encoder_input', 'encoder_pad_mask', 'decoder_input', 'decoder_pad_mask', 'target', 'target_mask', 'target_smiles'])

In [36]:
# Shapes of the encoder tokens and their padding mask.
# NOTE(review): with batch size 8 the recorded (57, 8) suggests a
# (sequence, batch) layout — TODO confirm.
sample['encoder_input'].shape, sample['encoder_pad_mask'].shape

(torch.Size([57, 8]), torch.Size([57, 8]))

In [38]:
# Shapes of the decoder tokens and their padding mask.
sample['decoder_input'].shape, sample['decoder_pad_mask'].shape

(torch.Size([52, 8]), torch.Size([52, 8]))

In [41]:
# Shapes of the target tokens and their mask.
sample['target'].shape, sample['target_mask'].shape

(torch.Size([52, 8]), torch.Size([52, 8]))

In [45]:
# Sanity check: the decoder input with its first position dropped must equal
# the target with its last position dropped, i.e. the target is the decoder
# input shifted by one step along the first axis.
(sample['decoder_input'][1:, :] == sample['target'][:-1, :]).all()

tensor(True)

In [42]:
# Show the target SMILES strings carried alongside the tensors in this batch.
sample['target_smiles']

['Cn1cc(C[C@@H]2CCC[C@H]2NC(=O)NC[C@@]2(C)CCCNC2)cn1',
 'Cn1cc(C(=O)NCc2ncc(-c3ccccc3)[nH]2)c(-c2cccnc2)n1',
 'C[C@]1(Cc2cc(F)c(F)c(F)c2)CCCN1C(=O)NCc1ccc(C#N)c(F)c1',
 'CN(CC(=O)N(C)C1CCCCC1)C[C@H](O)C[C@@]1(O)CCOC1',
 'CC(C)(C)c1coc([C@H]2CCCN(Cc3ccnc(C4CCC4)n3)C2)n1',
 'O=C(NCCNS(=O)(=O)c1cnn(CC2CC2)c1)c1ccccc1Cl',
 'COc1cccc(OC(F)(F)F)c1CNC(=O)N[C@H](CCN)C(F)(F)F',
 'Cc1ccc(C[C@@H](C)C(=O)NCC[C@@H]2CC(C)(C)CO2)cc1']

In [None]:
# Load the USPTO-50 dataframe and preview its first rows.
# NOTE: pd.read_pickle can execute arbitrary code; only load trusted files.
df = pd.read_pickle('data/uspto50/uspto_50.pickle')
df.head()

In [None]:
from rdkit.Chem import rdFMCS

# Column-oriented accumulator: every (reactant fragment, product) pair appends
# exactly one entry to each of these lists.
df_new = {
    column: []
    for column in ('reactants_mol', 'products_mol', 'reaction_type', 'set', 'importance')
}

for row in tqdm(df.itertuples(), total=len(df)):
    product = row.products_mol

    # Round-trip the combined reactant molecule through SMILES so that each
    # '.'-separated fragment becomes a molecule object of its own.
    reactants = [
        Chem.MolFromSmiles(fragment)
        for fragment in Chem.MolToSmiles(row.reactants_mol).split('.')
    ]

    # Score each fragment by its maximum common substructure with the product,
    # then order the fragments by descending number of shared atoms.
    sorted_reactants = []
    for reactant in reactants:
        overlap = rdFMCS.FindMCS([reactant, product])
        sorted_reactants.append((reactant, overlap.numAtoms, overlap.numBonds))
    sorted_reactants.sort(key=lambda scored: scored[1], reverse=True)

    # Emit one output row per fragment; 'importance' is the fragment's rank,
    # with 0 for the fragment sharing the most atoms with the product.
    for i, (reactant, _, _) in enumerate(sorted_reactants):
        df_new['reactants_mol'].append(reactant)
        df_new['products_mol'].append(product)
        df_new['reaction_type'].append(row.reaction_type)
        df_new['set'].append(row.set)
        df_new['importance'].append(i)

In [None]:
# Turn the per-column lists collected in df_new into a dataframe.
sike_df = pd.DataFrame(df_new)

In [None]:
# Display the per-fragment dataframe.
sike_df

In [None]:
# Check which Python type the stored reactant entries have.
type(sike_df['reactants_mol'][0])

In [None]:
# Persist the per-fragment dataframe so later cells can reload it.
sike_df.to_pickle('data/uspto50/uspto_50_sike.pickle')

In [None]:
# Reload the per-fragment dataframe from disk and display it.
sike_df = pd.read_pickle('data/uspto50/uspto_50_sike.pickle')
sike_df

In [None]:
# Add an 'IFT' column that turns the 0-based importance rank into a 1-based
# tag string: importance 0 -> '<IFT_1>', importance 1 -> '<IFT_2>', ...
sike_df['IFT'] = sike_df['importance'].apply(lambda x: f'<IFT_{x+1}>')
# Saving is left disabled, so the 'IFT' column is not written back to disk here.
# sike_df.to_pickle('data/uspto50/uspto_50_sike.pickle')

In [None]:
# Look at the reactant entry of the first row.
sike_df.iloc[0]['reactants_mol']

In [None]:
# Load the processed USPTO-50 dataframe and preview it.
# NOTE: this rebinds `df`, which an earlier cell bound to the contents of
# 'data/uspto50/uspto_50.pickle'; cells below see the processed frame.
df = pd.read_pickle('data/uspto50/processed.pickle')
df.head()

In [None]:
# Module-level augmenter used by the augment() helper defined in this notebook.
augmenter = Levenshtein_augment(source_augmentation=1, randomization_tries=1000)

In [None]:
def augment(reactants, products):
    """Produce Levenshtein-paired augmentations for one reaction.

    Both arguments are iterables of SMILES strings; each is joined with '.'
    into a single string before being handed to the module-level ``augmenter``.

    Returns:
        A tuple of three parallel lists built from the sampled pairs: the
        first elements, the second elements and the scores, in the order the
        augmenter yields them.
    """
    joined_reactants = '.'.join(reactants)
    joined_products = '.'.join(products)

    candidate_pairs = augmenter.levenshtein_pairing(joined_reactants, joined_products)

    # Materialise the samples once so a single-pass iterable is handled too.
    sampled = [
        (first, second, score)
        for first, second, score in augmenter.sample_pairs(candidate_pairs)
    ]

    return (
        [first for first, _, _ in sampled],
        [second for _, second, _ in sampled],
        [score for _, _, score in sampled],
    )

In [None]:
# Build the token <-> index lookup tables from the vocabulary file, where the
# line number of a token is its index.
#
# idx2char is built from the file lines themselves, so idx2char[i] is always
# the token on line i. Previously it enumerated the keys of the char2idx dict,
# which silently drifts out of sync with char2idx as soon as the file contains
# a repeated line.
with open('Chemformer/my_vocab.txt') as f:
    idx2char = dict(enumerate(f.read().split('\n')))
# For a repeated token the last line wins, as before.
char2idx = {c: i for i, c in idx2char.items()}

In [None]:
# Preview the augmentation on the first reaction only (note the `break`).
for i, (reactants, products) in enumerate(tqdm(zip(df['reactants_mol'], df['products_mol']), total=len(df))):
    # augment() is deliberately called with the arguments swapped (products
    # first) and its first two return values are unpacked in the matching
    # swapped order, so `new_products` holds the strings derived from `products`.
    new_products, new_reactants, score = augment(products, reactants)
    print(f'{".".join(reactants)} -> {".".join(products)}')
    for reactant, product, sc in zip(new_reactants, new_products, score):
        print(f'{sc:.2f}: {reactant} -> {product}')
        # Encode each string as '^' + atom-wise token ids + '&'. The encoded
        # lists are only held in the loop variables and are not stored anywhere.
        reactant = [char2idx['^']] + [char2idx[char] for char in atomwise_tokenizer(reactant)] + [char2idx['&']]
        product  = [char2idx['^']] + [char2idx[char] for char in atomwise_tokenizer(product)] + [char2idx['&']]
    break

In [None]:
%%timeit
# Time a single augment() call on the first reaction (products passed first).
augment(df['products_mol'].iloc[0], df['reactants_mol'].iloc[0])

In [None]:
class LevySMILES(Dataset):
    """Dataset of Levenshtein-augmented product/reactant token sequences.

    Rows are read from 'data/uspto50/processed.pickle' and filtered on the
    'set' column. Each item is a dict with the token ids of the augmented
    products under 'encoder_input' and of the augmented reactants under
    'encoder_output'.
    """

    def __init__(self, split: str='val') -> None:
        # 'val' is translated to 'valid' before filtering on the 'set' column;
        # any other value is used unchanged.
        if split == 'val':
            split = 'valid'
        self.split = split

        full_frame = pd.read_pickle('data/uspto50/processed.pickle')
        self.df = full_frame[full_frame['set'] == self.split]

        self.augmenter = Levenshtein_augment(source_augmentation=1, randomization_tries=1000)

        # One vocabulary entry per line; the line number is the token id.
        with open('Chemformer/my_vocab.txt') as vocab_file:
            vocab_lines = vocab_file.read().split('\n')
        self.idx2char = dict(enumerate(vocab_lines))
        self.char2idx = {token: index for index, token in enumerate(vocab_lines)}

        # Ids of the start ('^'), end ('&') and padding ('<PAD>') tokens.
        self.start_token, self.end_token, self.pad_token = (
            self.char2idx[symbol] for symbol in ('^', '&', '<PAD>')
        )

    def augment(self, reactants, products):
        """Join each argument with '.' and return the sampled augmentations.

        Returns a tuple of three parallel lists: the first elements, the
        second elements and the scores of the sampled pairs.
        """
        joined_reactants = '.'.join(reactants)
        joined_products = '.'.join(products)

        candidate_pairs = self.augmenter.levenshtein_pairing(joined_reactants, joined_products)

        # Materialise once so a single-pass iterable is handled too.
        sampled = [
            (first, second, score)
            for first, second, score in self.augmenter.sample_pairs(candidate_pairs)
        ]

        return (
            [first for first, _, _ in sampled],
            [second for _, second, _ in sampled],
            [score for _, _, score in sampled],
        )

    def __len__(self) -> int:
        """Number of rows in the selected split."""
        return len(self.df)

    def _encode(self, smiles):
        """Return start id + atom-wise token ids + end id as a long tensor."""
        ids = [self.start_token]
        ids.extend(self.char2idx[token] for token in atomwise_tokenizer(smiles))
        ids.append(self.end_token)
        return torch.tensor(ids, dtype=torch.long)

    def __getitem__(self, idx: int) -> dict:
        """Augment the reaction at position ``idx`` and tokenise the first pair."""
        reactants = self.df['reactants_mol'].iloc[idx]
        products = self.df['products_mol'].iloc[idx]

        # Products are passed as the first argument, so the first returned
        # list belongs to the products and the second to the reactants.
        aug_products, aug_reactants, _ = self.augment(products, reactants)

        # Only the first sampled pair is used.
        return {
            'encoder_output': self._encode(aug_reactants[0]),
            'encoder_input': self._encode(aug_products[0]),
        }

    def collate_fn(self, batch):
        """Pad every sequence in ``batch`` to the longest one, batch-first."""
        def padded(key):
            sequences = [item[key] for item in batch]
            return torch.nn.utils.rnn.pad_sequence(
                sequences, batch_first=True, padding_value=self.pad_token
            )

        return {
            'encoder_input': padded('encoder_input'),
            'encoder_output': padded('encoder_output'),
        }

In [None]:
# Validation dataset and loader; batches are padded by the dataset's own
# collate_fn. Augmentation runs inside __getitem__, i.e. in the 16 workers.
val_dataset = LevySMILES(split='val')
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=16, collate_fn=val_dataset.collate_fn)

In [None]:
# Iterate over the whole loader without using the batches: a smoke test that
# every item can be augmented, tokenised and collated, with tqdm showing speed.
for batch in tqdm(val_loader):
    pass

In [None]:
# Display the token -> index mapping.
char2idx

In [None]:
# Reload the per-fragment dataframe and display it.
# NOTE: this rebinds `df` again, now to the contents of 'uspto_50_sike.pickle'.
df = pd.read_pickle('data/uspto50/uspto_50_sike.pickle')
df

In [None]:
# Keep only the rows with importance 0, renumber the index from 0 and save
# the result as its own pickle.
single_df = df[df['importance'] == 0]
single_df = single_df.reset_index(drop=True)
single_df.to_pickle('data/uspto50/uspto_50_sike_single.pickle')

In [None]:
# Reload the importance-0 subset from disk and display it.
single_df = pd.read_pickle('data/uspto50/uspto_50_sike_single.pickle')
single_df