In [1]:
import re

from rdkit import Chem
from tqdm import tqdm
import selfies as sf

In [2]:
# Prompt the generations were sampled with (also the suffix of the input file):
#   "[Canon]" / "[Rand]" - the tag was given as the prompt and the SELFIES follows it;
#   "_"                  - the tag is expected at the end of each generated line.
prompt = "_"
# Any other value would silently be treated as a tag prompt by the scoring cell.
assert prompt in {"[Canon]", "[Rand]", "_"}, f"unexpected prompt: {prompt!r}"

In [3]:
# Generated samples for the selected prompt; the prompt is the file-name suffix.
path = (
    "OPT_1.2B_ep_1_half_rand_end_sf_848M_2.00E-04_hf_gradacc_32_gen_10000_"
    f"{prompt}.csv"
)

In [4]:
# Read the generations, one sample per line. Decode explicitly as UTF-8 so the
# result does not depend on the platform's default locale encoding.
with open(path, "r", encoding="utf-8") as f:
    data = f.read().splitlines()

In [5]:
# Evaluate at most this many generations (matches the "gen_10000" in the file
# name); change here to score a smaller or larger slice.
N_SAMPLES = 10_000
data = data[:N_SAMPLES]

In [6]:
# Drop duplicate generations so each distinct string is scored once.
# dict.fromkeys keeps first-occurrence order (a set's order depends on the
# per-process string hash seed), so iteration order - and the order of any
# printed errors below - is reproducible across kernel restarts.
unique_data = list(dict.fromkeys(data))

In [7]:
# Number of distinct generations that the scoring cell will evaluate.
n_unique = len(unique_data)
n_unique

10000

In [8]:
# Score every distinct generation: is its SELFIES string the canonical form of
# the molecule it encodes, and (for prompt "_") which tag did the model append?
# Each line is counted in exactly one bucket of `stats`.

# Only these two tags are valid at the end of a "_"-prompted generation.
END_TAG_PATTERN = re.compile(r"(.*)(\[Canon\]|\[Rand\])")


def _canonical_selfies(selfies):
    """Return the canonical SELFIES string for the molecule encoded by ``selfies``.

    The string is decoded to SMILES, canonicalized with RDKit and re-encoded, so
    comparing the result with the input tells whether the input was already in
    canonical form.

    Raises:
        ValueError: if ``selfies`` does not decode to a non-empty molecule.
        Errors raised by ``sf.decoder`` / ``sf.encoder`` propagate unchanged.
    """
    smiles = sf.decoder(selfies)
    mol = Chem.MolFromSmiles(smiles) if smiles else None
    # MolFromSmiles("") returns an empty (non-None) Mol; an empty generation
    # must not be scored as a valid canonical molecule.
    if mol is None or mol.GetNumAtoms() == 0:
        raise ValueError(f"not a valid molecule: {selfies!r}")
    return sf.encoder(Chem.MolToSmiles(mol, canonical=True))


if prompt == "_":
    # The tag was generated at the end: "<SELFIES>[Canon]" or "<SELFIES>[Rand]".
    stats = {
        "invalid_count": 0,
        "canon_count_[Canon]": 0,
        "canon_count_[Rand]": 0,
        "rand_count_[Canon]": 0,
        "rand_count_[Rand]": 0,
    }

    for line in tqdm(unique_data):
        try:
            # A generic "\[...\]$" would also match an ordinary SELFIES token
            # such as "[C]", silently eating it and miscounting the line as [Rand].
            match = END_TAG_PATTERN.fullmatch(line)
            if match is None:
                raise ValueError("missing [Canon]/[Rand] end tag")
            selfies, end = match.groups()

            kind = "canon" if selfies == _canonical_selfies(selfies) else "rand"
            stats[f"{kind}_count_{end}"] += 1
        except Exception as error:
            print(f"{line!r}: {error}")
            stats["invalid_count"] += 1
else:
    # The prompt is the tag ([Canon] or [Rand]); the SELFIES is whatever follows
    # its last occurrence in the line (the whole line if it does not occur).
    stats = {
        "invalid_count": 0,
        "canon_count": 0,
        "non_canon_count": 0,
    }

    for line in tqdm(unique_data):
        try:
            selfies = line.rpartition(prompt)[2]
            if selfies == _canonical_selfies(selfies):
                stats["canon_count"] += 1
            else:
                stats["non_canon_count"] += 1
        except Exception as error:
            print(f"{line!r}: {error}")
            stats["invalid_count"] += 1

# NOTE: an earlier version of this cell compared the decoded SMILES with its
# RDKit-canonical SMILES instead of comparing SELFIES strings.

# Every distinct generation lands in exactly one bucket.
assert sum(stats.values()) == len(unique_data)


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:03<00:00, 3048.30it/s]


In [9]:
# Display the prompt setting next to the counts it produced.
(prompt, stats)

('_',
 {'invalid_count': 0,
  'canon_count_[Canon]': 4777,
  'canon_count_[Rand]': 21,
  'rand_count_[Canon]': 118,
  'rand_count_[Rand]': 5084})