In [1]:
import re
import json

import pandas as pd
from rdkit import Chem
from tqdm import tqdm
import selfies as sf

In [2]:
def check_validness(smiles: str) -> bool:
    """Return True if RDKit can parse ``smiles`` into a molecule.

    ``Chem.MolFromSmiles`` signals an unparseable string by returning ``None``.
    It may also raise, presumably for non-string input (TODO confirm);
    that case is treated as invalid too.
    """
    try:
        return Chem.MolFromSmiles(smiles) is not None
    except Exception:
        # Narrowed from a bare ``except`` so KeyboardInterrupt / SystemExit
        # are not silently turned into "invalid".
        return False


def check_canonical(smiles: str) -> bool:
    """Return True if ``smiles`` already equals its RDKit canonical form.

    Delegates to ``get_canonical_form``, which returns "" when canonicalisation
    fails. An unparseable input therefore compares unequal, unless the input
    itself is "".
    """
    try:
        return get_canonical_form(smiles) == smiles
    except Exception:
        # Narrowed from a bare ``except``: only genuine errors count as
        # "not canonical", not KeyboardInterrupt / SystemExit.
        return False
    

def get_canonical_form(smiles: str) -> str:
    """Return RDKit's canonical SMILES for ``smiles``.

    On failure (RDKit cannot parse the string, or an RDKit call raises), an
    error message is printed and "" is returned.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        # MolFromSmiles reports a parse failure by returning None. Check it
        # explicitly instead of relying on MolToSmiles(None) raising.
        if mol is not None:
            return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        # Narrowed from a bare ``except``; the failure is reported below.
        pass
    print(f"Error while getting the canonical form for {smiles}.")
    return ""

In [3]:
# Model generations to evaluate.
# NOTE(review): despite the ".csv" extension, this file is loaded with
# pd.read_json in the next cell. Confirm the on-disk format.
path = "OPT_1.2B_ep_1_half_rand_end_sf_848M_2.00E-04_hf_gradacc_32_iter_10000_from_valid.csv"

In [4]:
# Load the generations and keep only the first MAX_ROWS rows (by position).
MAX_ROWS = 10_000
df = pd.read_json(path).iloc[:MAX_ROWS]
df

Unnamed: 0,seq,output
0,[Rand][C][Branch1][C][O][Branch1][Ring1][C][O]...,[C][Branch1][C][O][Branch1][Ring1][C][O][C][=B...
1,[Rand][C][=Branch1][C][=O][N][C][Branch1][C][C...,[C][=Branch1][C][=O][N][C][Branch1][C][C][C][R...
2,[Rand][C][Branch1][C][N][C][Branch1][C][O][C][...,[C][Branch1][C][N][C][Branch1][C][O][C][C][Rin...
3,[Rand][C][C][=C][C][Branch1][C][N][C][Ring1][=...,[C][C][=C][C][Branch1][C][N][C][Ring1][=Branch...
4,[Rand][N][#C][C][=C][Branch1][C][N][C][=C][C][...,[N][#C][C][=C][Branch1][C][N][C][=C][C][Branch...
...,...,...
9995,[Canon][C][O][C][=C][C][O][C][C][Ring1][Branch...,[C][O][C][=C][C][O][C][C][Ring1][Branch1][Bran...
9996,[Canon][C][=C][C][C][Branch1][Branch1][C][O][R...,[C][=C][C][C][Branch1][Branch1][C][O][Ring1][B...
9997,[Canon][C][C][N][C][C][C][N][N][=C][Ring1][#Br...,[C][C][N][C][C][C][N][N][=C][Ring1][#Branch1][...
9998,[Canon][C][C][C][C][C][N][C][C][Branch1][Branc...,[C][C][C][C][C][N][C][C][Branch1][Branch2][C][...


In [5]:
def _leading_token(prompt: str) -> str:
    """Return ``prompt`` up to and including its first ']' (the leading token).

    Raises ValueError if ``prompt`` contains no ']'.
    """
    cut = prompt.index("]") + 1
    return prompt[:cut]


# The ground-truth tag is the first bracketed token of each prompt.
df["tag"] = df["seq"].map(_leading_token)

In [6]:
def return_diff_seq(first_seq: str, second_seq: str, strict: bool = False) -> str:
    """Return the part of ``second_seq`` that follows the ``first_seq`` prefix.

    Only the *length* of ``first_seq`` is used (``second_seq[len(first_seq):]``).
    The result is the intended suffix only when ``second_seq`` really starts
    with ``first_seq``; otherwise it is an arbitrary tail.

    Args:
        first_seq: Expected prefix of ``second_seq``.
        second_seq: Sequence to take the suffix from.
        strict: If True, raise instead of returning an arbitrary tail when
            ``second_seq`` does not start with ``first_seq``. Defaults to
            False, which keeps the original length-only behaviour.

    Raises:
        ValueError: if ``strict`` is True and the prefix does not match.
    """
    if strict and not second_seq.startswith(first_seq):
        raise ValueError(
            f"{second_seq!r} does not start with expected prefix {first_seq!r}"
        )
    return second_seq[len(first_seq):]

In [7]:
# Strip the leading tag from each prompt. Strip only while the prompt still
# starts with its tag, so re-running this cell does not chop off a further
# SELFIES token (the original in-place rewrite was not idempotent).
df["seq"] = [
    return_diff_seq(tag, seq) if seq.startswith(tag) else seq
    for tag, seq in zip(df["tag"], df["seq"])
]

In [8]:
df

Unnamed: 0,seq,output,tag
0,[C][Branch1][C][O][Branch1][Ring1][C][O][C][=B...,[C][Branch1][C][O][Branch1][Ring1][C][O][C][=B...,[Rand]
1,[C][=Branch1][C][=O][N][C][Branch1][C][C][C][R...,[C][=Branch1][C][=O][N][C][Branch1][C][C][C][R...,[Rand]
2,[C][Branch1][C][N][C][Branch1][C][O][C][C][Rin...,[C][Branch1][C][N][C][Branch1][C][O][C][C][Rin...,[Rand]
3,[C][C][=C][C][Branch1][C][N][C][Ring1][=Branch...,[C][C][=C][C][Branch1][C][N][C][Ring1][=Branch...,[Rand]
4,[N][#C][C][=C][Branch1][C][N][C][=C][C][Branch...,[N][#C][C][=C][Branch1][C][N][C][=C][C][Branch...,[Rand]
...,...,...,...
9995,[C][O][C][=C][C][O][C][C][Ring1][Branch1][Bran...,[C][O][C][=C][C][O][C][C][Ring1][Branch1][Bran...,[Canon]
9996,[C][=C][C][C][Branch1][Branch1][C][O][Ring1][B...,[C][=C][C][C][Branch1][Branch1][C][O][Ring1][B...,[Canon]
9997,[C][C][N][C][C][C][N][N][=C][Ring1][#Branch1][...,[C][C][N][C][C][C][N][N][=C][Ring1][#Branch1][...,[Canon]
9998,[C][C][C][C][C][N][C][C][Branch1][Branch2][C][...,[C][C][C][C][C][N][C][C][Branch1][Branch2][C][...,[Canon]


In [9]:
# Whatever follows the tag-stripped prompt in the model output. This assumes
# ``output`` begins with ``seq`` (TODO confirm); the next cells compare it
# to the ground-truth tag.
df["diff"] = [
    return_diff_seq(prompt, generated)
    for prompt, generated in zip(df["seq"], df["output"])
]

In [10]:
df

Unnamed: 0,seq,output,tag,diff
0,[C][Branch1][C][O][Branch1][Ring1][C][O][C][=B...,[C][Branch1][C][O][Branch1][Ring1][C][O][C][=B...,[Rand],[Rand]
1,[C][=Branch1][C][=O][N][C][Branch1][C][C][C][R...,[C][=Branch1][C][=O][N][C][Branch1][C][C][C][R...,[Rand],[Rand]
2,[C][Branch1][C][N][C][Branch1][C][O][C][C][Rin...,[C][Branch1][C][N][C][Branch1][C][O][C][C][Rin...,[Rand],[Rand]
3,[C][C][=C][C][Branch1][C][N][C][Ring1][=Branch...,[C][C][=C][C][Branch1][C][N][C][Ring1][=Branch...,[Rand],[Rand]
4,[N][#C][C][=C][Branch1][C][N][C][=C][C][Branch...,[N][#C][C][=C][Branch1][C][N][C][=C][C][Branch...,[Rand],[Rand]
...,...,...,...,...
9995,[C][O][C][=C][C][O][C][C][Ring1][Branch1][Bran...,[C][O][C][=C][C][O][C][C][Ring1][Branch1][Bran...,[Canon],[Canon]
9996,[C][=C][C][C][Branch1][Branch1][C][O][Ring1][B...,[C][=C][C][C][Branch1][Branch1][C][O][Ring1][B...,[Canon],[Canon]
9997,[C][C][N][C][C][C][N][N][=C][Ring1][#Branch1][...,[C][C][N][C][C][C][N][N][=C][Ring1][#Branch1][...,[Canon],[Canon]
9998,[C][C][C][C][C][N][C][C][Branch1][Branch2][C][...,[C][C][C][C][C][N][C][C][Branch1][Branch2][C][...,[Canon],[Canon]


In [11]:
# Rows where the generated suffix is exactly the ground-truth tag.
is_exact_match = df["diff"].eq(df["tag"])
matched_df = df.loc[is_exact_match]
matched_df

Unnamed: 0,seq,output,tag,diff
0,[C][Branch1][C][O][Branch1][Ring1][C][O][C][=B...,[C][Branch1][C][O][Branch1][Ring1][C][O][C][=B...,[Rand],[Rand]
1,[C][=Branch1][C][=O][N][C][Branch1][C][C][C][R...,[C][=Branch1][C][=O][N][C][Branch1][C][C][C][R...,[Rand],[Rand]
2,[C][Branch1][C][N][C][Branch1][C][O][C][C][Rin...,[C][Branch1][C][N][C][Branch1][C][O][C][C][Rin...,[Rand],[Rand]
3,[C][C][=C][C][Branch1][C][N][C][Ring1][=Branch...,[C][C][=C][C][Branch1][C][N][C][Ring1][=Branch...,[Rand],[Rand]
4,[N][#C][C][=C][Branch1][C][N][C][=C][C][Branch...,[N][#C][C][=C][Branch1][C][N][C][=C][C][Branch...,[Rand],[Rand]
...,...,...,...,...
9995,[C][O][C][=C][C][O][C][C][Ring1][Branch1][Bran...,[C][O][C][=C][C][O][C][C][Ring1][Branch1][Bran...,[Canon],[Canon]
9996,[C][=C][C][C][Branch1][Branch1][C][O][Ring1][B...,[C][=C][C][C][Branch1][Branch1][C][O][Ring1][B...,[Canon],[Canon]
9997,[C][C][N][C][C][C][N][N][=C][Ring1][#Branch1][...,[C][C][N][C][C][C][N][N][=C][Ring1][#Branch1][...,[Canon],[Canon]
9998,[C][C][C][C][C][N][C][C][Branch1][Branch2][C][...,[C][C][C][C][C][N][C][C][Branch1][Branch2][C][...,[Canon],[Canon]


In [12]:
# Number of exact matches whose ground-truth tag is [Rand].
matched_rand = int(matched_df["tag"].eq("[Rand]").sum())
matched_rand

4208

In [13]:
# Number of exact matches whose ground-truth tag is [Canon].
matched_canon = int(matched_df["tag"].eq("[Canon]").sum())
matched_canon

4321

In [14]:
# Rows where the generated suffix differs from the ground-truth tag.
mismatched_df = df.loc[df["diff"].ne(df["tag"])]
mismatched_df

Unnamed: 0,seq,output,tag,diff
8,[C][=Branch1][N][=C][N][=C][N][N][=C][Branch1]...,[C][=Branch1][N][=C][N][=C][N][N][=C][Branch1]...,[Rand],[C][Rand]
20,[C][C][Branch1][C][O][C][Branch1][Ring1][C][O]...,[C][C][Branch1][C][O][C][Branch1][Ring1][C][O]...,[Rand],[Ring1][=N][Rand]
21,[C][Branch1][N][C][=C][C][C][N][C][Ring1][#Bra...,[C][Branch1][N][C][=C][C][C][N][C][Ring1][#Bra...,[Rand],[Ring1][Branch1][Rand]
23,[C][N][C][Ring1][Ring1][C][C][Branch1][=Branch...,[C][N][C][Ring1][Ring1][C][C][Branch1][=Branch...,[Rand],[C][Rand]
31,[N][C][=C][C][=Branch1][C][=O][N][C][C][N][Bra...,[N][C][=C][C][=Branch1][C][=O][N][C][C][N][Bra...,[Rand],[Ring1][N][Rand]
...,...,...,...,...
9952,[C][#C][C][C][C][C][Branch1][C][N][C][C][C][Ri...,[C][#C][C][C][C][C][Branch1][C][N][C][C][C][Ri...,[Canon],[C][Canon]
9956,[C][C][C][C][O][C][C][C][=Branch1][C][=N][N][B...,[C][C][C][C][O][C][C][C][=Branch1][C][=N][N][B...,[Canon],[C][Canon]
9970,[C][C][C][C][C][Ring1][Ring2][C][=Branch1][C][...,[C][C][C][C][C][Ring1][Ring2][C][=Branch1][C][...,[Canon],[=O][Canon]
9973,[C][=C][C][C][C][N][C][Ring1][Branch1][C][Ring...,[C][=C][C][C][C][N][C][Ring1][Branch1][C][Ring...,[Canon],[Ring1][O][Canon]


In [15]:
def contains_fn(tag: str, diff_seq: str) -> bool:
    """Substring test: does ``diff_seq`` contain ``tag`` anywhere?

    Separates mismatched rows where the tag appears after extra tokens
    from rows where it never appears.
    """
    is_present = tag in diff_seq
    return is_present

In [16]:
# Mismatched rows whose generated suffix still contains the ground-truth tag
# (the model emitted extra tokens before it). The mask is an index-aligned
# boolean Series: indexing with a plain Python list would treat an empty list
# as a column selection (df[[]]) rather than a row mask.
tag_in_diff = pd.Series(
    list(map(contains_fn, mismatched_df["tag"], mismatched_df["diff"])),
    index=mismatched_df.index,
    dtype=bool,
)
mismatched_df.loc[tag_in_diff]

Unnamed: 0,seq,output,tag,diff
8,[C][=Branch1][N][=C][N][=C][N][N][=C][Branch1]...,[C][=Branch1][N][=C][N][=C][N][N][=C][Branch1]...,[Rand],[C][Rand]
20,[C][C][Branch1][C][O][C][Branch1][Ring1][C][O]...,[C][C][Branch1][C][O][C][Branch1][Ring1][C][O]...,[Rand],[Ring1][=N][Rand]
21,[C][Branch1][N][C][=C][C][C][N][C][Ring1][#Bra...,[C][Branch1][N][C][=C][C][C][N][C][Ring1][#Bra...,[Rand],[Ring1][Branch1][Rand]
23,[C][N][C][Ring1][Ring1][C][C][Branch1][=Branch...,[C][N][C][Ring1][Ring1][C][C][Branch1][=Branch...,[Rand],[C][Rand]
31,[N][C][=C][C][=Branch1][C][=O][N][C][C][N][Bra...,[N][C][=C][C][=Branch1][C][=O][N][C][C][N][Bra...,[Rand],[Ring1][N][Rand]
...,...,...,...,...
9952,[C][#C][C][C][C][C][Branch1][C][N][C][C][C][Ri...,[C][#C][C][C][C][C][Branch1][C][N][C][C][C][Ri...,[Canon],[C][Canon]
9956,[C][C][C][C][O][C][C][C][=Branch1][C][=N][N][B...,[C][C][C][C][O][C][C][C][=Branch1][C][=N][N][B...,[Canon],[C][Canon]
9970,[C][C][C][C][C][Ring1][Ring2][C][=Branch1][C][...,[C][C][C][C][C][Ring1][Ring2][C][=Branch1][C][...,[Canon],[=O][Canon]
9973,[C][=C][C][C][C][N][C][Ring1][Branch1][C][Ring...,[C][=C][C][C][C][N][C][Ring1][Branch1][C][Ring...,[Canon],[Ring1][O][Canon]


In [17]:
# Rows that didn't contain the ground truth tag anywhere in the generated
# suffix (e.g. the model predicted the other tag). The mask is an
# index-aligned boolean Series, negated with ``~``; a plain Python list would
# be read as a column selection if it were empty.
tag_missing = ~pd.Series(
    list(map(contains_fn, mismatched_df["tag"], mismatched_df["diff"])),
    index=mismatched_df.index,
    dtype=bool,
)
mismatched_df.loc[tag_missing]

Unnamed: 0,seq,output,tag,diff
154,[C][C][C][O][C][C][C][C][N][Branch1][Ring2][C]...,[C][C][C][O][C][C][C][C][N][Branch1][Ring2][C]...,[Rand],[Ring1][=N][Ring1][N][Canon]
297,[C][C][C][C][C][Branch1][Branch2][C][=N][C][=N...,[C][C][C][C][C][Branch1][Branch2][C][=N][C][=N...,[Rand],[Canon]
364,[C][N][C][Branch1][N][C][=N][O][C][Branch1][Ri...,[C][N][C][Branch1][N][C][=N][O][C][Branch1][Ri...,[Rand],[C][Ring1][O][Canon]
454,[C][=C][O][C][C][=C][Branch1][C][O][N][Branch1...,[C][=C][O][C][C][=C][Branch1][C][O][N][Branch1...,[Rand],[Canon]
765,[C][O][C][C][C][C][Branch1][Ring2][O][C][=O][C...,[C][O][C][C][C][C][Branch1][Ring2][O][C][=O][C...,[Rand],[Ring1][#Branch2][C][Canon]
...,...,...,...,...
9488,[C][=C][C][=C][C][=Branch1][Branch2][=N][N][=C...,[C][=C][C][=C][C][=Branch1][Branch2][=N][N][=C...,[Canon],[Rand]
9534,[C][C][=Branch1][Ring1][=N][O][C][C][Branch1][...,[C][C][=Branch1][Ring1][=N][O][C][C][Branch1][...,[Canon],[Ring1][=N][Rand]
9559,[C][=C][C][=Branch1][C][=O][NH1][N][Branch1][=...,[C][=C][C][=Branch1][C][=O][NH1][N][Branch1][=...,[Canon],[Rand]
9585,[C][C][C][C][N][Ring1][Ring1][C][C][Ring1][=Br...,[C][C][C][C][N][Ring1][Ring1][C][C][Ring1][=Br...,[Canon],[Rand]
