In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as plt
import regex as re
import pickle as pkl
from transvae.data import vae_data_gen 
from collections import Counter


# Compiled once at import time instead of on every tokenizer() call.
# Raw string: same runtime pattern as before, but without invalid escape sequences
# (which warn on modern Python). Alternatives, in order: bracket atoms, B/Br, C/Cl,
# other organic-subset atoms, aromatic atoms, branches, bonds, "\" and "/" stereo bonds,
# misc symbols, two-digit ring closures (%NN), single-digit ring closures.
_SMILES_REGEX = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|_|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
)


def tokenizer(mol, mol_encoding='selfies'):
    """Split a molecule string into a list of tokens.

    Args:
        mol: molecule string written in the notation named by ``mol_encoding``.
        mol_encoding: ``'smiles'`` for regex-based SMILES tokenization, or
            ``'selfies'`` (default) for one token per ``[...]`` bracket group.

    Returns:
        list of str tokens whose concatenation is exactly ``mol``.

    Raises:
        NameError: if ``mol_encoding`` is neither ``'smiles'`` nor ``'selfies'``
            (NameError is kept for compatibility with existing callers).
        AssertionError: (when assertions are enabled) if the tokens do not
            re-join to ``mol``, i.e. some characters were not covered.
    """
    if mol_encoding == 'smiles':
        # The pattern has a single capture group, so findall returns a list of str.
        tokens = _SMILES_REGEX.findall(mol)
    elif mol_encoding == 'selfies':
        # Splitting on "]" leaves a final fragment after the last bracket (empty
        # for well-formed SELFIES); re-append "]" to each piece and drop that fragment.
        tokens = [block + "]" for block in mol.split("]")][:-1]
    else:
        raise NameError('expected mol encoding to be smiles or selfies')
    # Round-trip check: catches characters the tokenizer skipped or dropped.
    assert mol == ''.join(tokens), ("{} could not be joined".format(mol))
    return tokens



In [None]:
# Print the 20 most common motifs for each of lengths 4-20 for SMILES in the MOSES dataset.
# A "motif" is a run of consecutive tokens; its text is the concatenation of those tokens.
MIN_MOTIF_LEN = 4
MAX_MOTIF_LEN = 20
TOP_K = 20
PROGRESS_EVERY = 50000

moses_smiles = pd.read_csv("data/moses_smiles_train.csv")["smiles"]

# counters[n] counts motifs of n tokens; indices below MIN_MOTIF_LEN are never filled.
# NOTE: every distinct motif of every length is kept in memory.
counters = [Counter() for _ in range(MAX_MOTIF_LEN + 1)]
for n_processed, smile in enumerate(moses_smiles, start=1):
    tokens = tokenizer(smile, mol_encoding="smiles")
    # Upper bound is exclusive, so motifs are strictly shorter than the molecule
    # (a whole molecule is never counted as its own motif).
    # NOTE(review): confirm that exclusion is intended; use len(tokens) + 1 to include it.
    for motif_len in range(MIN_MOTIF_LEN, min(len(tokens), MAX_MOTIF_LEN + 1)):
        counters[motif_len].update(
            ''.join(tokens[start:start + motif_len])
            for start in range(len(tokens) - motif_len + 1)
        )
    if n_processed % PROGRESS_EVERY == 0:
        print(n_processed)

for motif_len in range(MIN_MOTIF_LEN, MAX_MOTIF_LEN + 1):
    print(counters[motif_len].most_common(TOP_K))

In [None]:
# Print the 20 most common motifs for each of lengths 4-20 for SELFIES in the MOSES dataset.
# A "motif" is a run of consecutive tokens; its text is the concatenation of those tokens.
MIN_MOTIF_LEN = 4
MAX_MOTIF_LEN = 20
TOP_K = 20
PROGRESS_EVERY = 50000

moses_selfies = pd.read_csv("data/moses_train.csv")["selfies"]

# counters[n] counts motifs of n tokens; indices below MIN_MOTIF_LEN are never filled.
# NOTE: every distinct motif of every length is kept in memory.
counters = [Counter() for _ in range(MAX_MOTIF_LEN + 1)]
for n_processed, selfie in enumerate(moses_selfies, start=1):
    tokens = tokenizer(selfie, mol_encoding="selfies")
    # Upper bound is exclusive, so motifs are strictly shorter than the molecule
    # (a whole molecule is never counted as its own motif).
    # NOTE(review): confirm that exclusion is intended; use len(tokens) + 1 to include it.
    for motif_len in range(MIN_MOTIF_LEN, min(len(tokens), MAX_MOTIF_LEN + 1)):
        counters[motif_len].update(
            ''.join(tokens[start:start + motif_len])
            for start in range(len(tokens) - motif_len + 1)
        )
    if n_processed % PROGRESS_EVERY == 0:
        print(n_processed)

for motif_len in range(MIN_MOTIF_LEN, MAX_MOTIF_LEN + 1):
    print(counters[motif_len].most_common(TOP_K))