In [5]:
import os
import re
import numpy as np
import textwrap
import torch
from accelerate import init_empty_weights, Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer
from custom_modeling_opt import CustomOPTForCausalLM
from matplotlib import pyplot as plt

In [6]:
# Reproducibility: seed every RNG this notebook touches from a single constant.
seed_value = 42
torch.manual_seed(seed_value)  # previously hard-coded 42, bypassing seed_value
np.random.seed(seed_value)     # numpy is imported above, so seed it too
if torch.cuda.is_available():
    # manual_seed_all covers every visible GPU, including the current one.
    torch.cuda.manual_seed_all(seed_value)
    # Trade cuDNN autotuning speed for deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [7]:
import rdkit.Chem as Chem
import sys
from rdkit.Chem import RDConfig, MACCSkeys, QED
from rdkit.Chem.rdMolDescriptors import CalcTPSA, CalcCrippenDescriptors
from rdkit.Chem import Descriptors
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer

In [8]:
# Prefer the second GPU (the original configuration), then the first GPU, then CPU,
# so the notebook does not crash on machines with fewer GPUs.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [9]:
# Tokenizer directory; set CHEMLACTICA_TOKENIZER_PATH to run on another machine.
tokenizer_path = os.environ.get(
    "CHEMLACTICA_TOKENIZER_PATH",
    "/auto/home/menuab/code/ChemLacticaTestSuite/src/tokenizer/ChemLacticaTokenizer_50066/",
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
print('tokenizer size: ', len(tokenizer))

tokenizer size:  50066


In [10]:
# Candidate checkpoints. Previously all three were assigned to the same variable in a row,
# so only the last one ever took effect; select explicitly instead.
CHECKPOINTS = {
    "f2c6": "/auto/home/menuab/code/checkpoints/f2c6ebb289994595a478f513/125m_126k_f2c6/",
    "f3fb": "/auto/home/menuab/code/checkpoints/f3fbd012918247a388efa732/125m_126k_f3fb/",
    "26d3": "/auto/home/menuab/code/checkpoints/26d322857a184fcbafda5d4a/125m_118k_26d3/",
}
checkpoint_path = CHECKPOINTS["26d3"]
checkpoint_path

'/auto/home/menuab/code/checkpoints/26d322857a184fcbafda5d4a/125m_118k_26d3/'

In [11]:
# Load the checkpoint in bfloat16 with flash attention enabled, switch to inference mode,
# and move it to the selected device.
model = CustomOPTForCausalLM.from_pretrained(
    checkpoint_path,
    use_flash_attn=True,
    torch_dtype=torch.bfloat16,
)
model.eval()
model.to(device)

# The embedding table must match the tokenizer vocabulary exactly; a mismatch means the
# tokenizer and checkpoint do not belong together.
num_embeddings = model.model.decoder.embed_tokens.num_embeddings
print(f'model loaded with embedding size of : {num_embeddings}')
assert num_embeddings == len(tokenizer), (
    f"embedding size {num_embeddings} != tokenizer size {len(tokenizer)}"
)

model loaded with embedding size of : 50066


In [50]:
def calculate_tanimoto_distance(smiles1, smiles2):
    """Return the Tanimoto distance (1 - similarity) between two molecules' MACCS keys.

    Args:
        smiles1: SMILES string of the first molecule.
        smiles2: SMILES string of the second molecule.

    Returns:
        float in [0, 1]: 0.0 when the fingerprints are identical, 1.0 when they share
        no on-bits. If neither fingerprint has any on-bit, the fingerprints are
        identical and 0.0 is returned (instead of dividing by zero).

    Raises:
        ValueError: if either SMILES cannot be parsed by RDKit.
    """
    # RDKit signals a parse failure by returning None rather than raising.
    mol1 = Chem.MolFromSmiles(smiles1)
    mol2 = Chem.MolFromSmiles(smiles2)
    invalid = [smi for smi, mol in ((smiles1, mol1), (smiles2, mol2)) if mol is None]
    if invalid:
        raise ValueError(f"Invalid SMILES representation: {invalid}")

    keys1 = MACCSkeys.GenMACCSKeys(mol1)
    keys2 = MACCSkeys.GenMACCSKeys(mol2)

    # Bitwise AND/OR on the bit vectors runs natively instead of a per-bit Python loop.
    common_bits = (keys1 & keys2).GetNumOnBits()
    total_bits = (keys1 | keys2).GetNumOnBits()
    if total_bits == 0:
        # Two all-zero fingerprints are indistinguishable.
        return 0.0

    return 1.0 - common_bits / total_bits

In [12]:
# Reference molecules as SMILES, one per line; mols[0] is used in the prompt cell below.
mols = """
COc1ccccc1NS(=O)(=O)c1ccc(NC(C)=O)cc1
CC(=O)Oc1cc2oc(=O)cc(C)c2cc1OC(C)=O
COc1ccc(Oc2coc3cc(OC(C)=O)ccc3c2=O)cc1
COc1ccc(C(=O)OCC(=O)Nc2ccc(C)cc2)cc1OC
COc1ccc(C(=O)OCC(=O)Nc2ccc(F)cc2)cc1OC
COc1ccc(C(=O)OCC(=O)Nc2ccc(Cl)cc2)cc1OC
COc1ccc(OC)c(-c2oc3ccccc3c(=O)c2O)c1
O=S(=O)(NCc1ccc(F)cc1F)c1ccc(F)cc1F
CC1(O)CCC2C3CCC4=C(CCC(=O)C4)C3CCC21C
NS(=O)(=O)c1cc(C(=O)Nc2ccccc2Cl)ccc1F
""".split()

In [57]:
# Build the conditioning prompt: the reference SMILES plus a numeric value inside
# [SIMILAR]...[/SIMILAR] (presumably the target similarity — confirm against the training
# format), then an open [START_SMILES] span for the model to complete.
similarity_value = 0.9
prompt_text = f'[SIMILAR]{mols[0]} {similarity_value}[/SIMILAR][START_SMILES]'

# Keep the raw text and the token-id tensor under separate names.
prompt = tokenizer(prompt_text, return_tensors="pt").to(device).input_ids

In [58]:
# Greedy decoding (do_sample=False) until the stop token or the token budget is reached.
# Token id 20 is the stop token; the decoded output ends with [END_SMILES], which suggests
# that is what id 20 maps to — TODO confirm with tokenizer.convert_ids_to_tokens(20).
END_SMILES_TOKEN_ID = 20
MAX_NEW_TOKENS = 300
out = model.generate(
    prompt,
    # Single unpadded prompt: every position is attended; passing it explicitly avoids
    # relying on generate()'s attention-mask inference.
    attention_mask=torch.ones_like(prompt),
    do_sample=False,
    max_new_tokens=MAX_NEW_TOKENS,
    eos_token_id=END_SMILES_TOKEN_ID,
    return_dict_in_generate=True,
    output_scores=True,
)

In [56]:
# NOTE(review): this cell (In [56]) ran before the current prompt/generate cells
# (In [57]/[58]). Its saved output below came from an earlier prompt built from the
# Kekulé-form SMILES of mols[0], so that output is stale. The cell duplicates the decode
# cell that follows and can be removed.
tokenizer.batch_decode(out.sequences)

['[SIMILAR]CC(=O)NC1=CC=C(C=C1)S(=O)(=O)NC2=CC=CC=C2OC 0.9[/SIMILAR][START_SMILES]CC(=O)NC1=CC=C(C=C1)S(=O)(=O)NC2=CC=CC=C2OC[END_SMILES]']

In [59]:
# Decode prompt + continuation back to text, keeping the special tokens (e.g.
# [START_SMILES]/[END_SMILES]) so the generated SMILES span stays visible.
tokenizer.batch_decode(sequences=out.sequences, skip_special_tokens=False)

['[SIMILAR]COc1ccccc1NS(=O)(=O)c1ccc(NC(C)=O)cc1 0.9[/SIMILAR][START_SMILES]COC1=NS(=O)(=O)N[C@@H]1NC(=O)C=C[END_SMILES]']

In [61]:
# Compare the prompt's reference molecule with the SMILES the model actually generated.
# Both strings were previously copied by hand from earlier outputs, so they did not follow
# a re-run. The hand-copied reference was a Kekulé form of mols[0] (same molecule, see the
# canonicalization cell), so the MACCS-based distance is unchanged.
decoded_text = tokenizer.batch_decode(out.sequences)[0]
smiles_match = re.search(r"\[START_SMILES\](.*?)\[END_SMILES\]", decoded_text)
if smiles_match is None:
    # Happens e.g. when generation hits max_new_tokens before emitting [END_SMILES].
    raise ValueError(f"No [START_SMILES]...[END_SMILES] span in generated text: {decoded_text!r}")
generated_smiles = smiles_match.group(1).strip()
calculate_tanimoto_distance(mols[0], generated_smiles)

0.4305555555555556

In [30]:
# Canonicalize this Kekulé-form SMILES (stereochemistry dropped via isomericSmiles=False)
# to check which molecule it is. The parsed Mol gets its own name instead of rebinding
# `mol` from a Mol object to a str; `mol` remains the canonical SMILES string.
query_mol = Chem.MolFromSmiles('CC(=O)NC1=CC=C(C=C1)S(=O)(=O)NC2=CC=CC=C2OC')
mol = Chem.MolToSmiles(query_mol, isomericSmiles=False, canonical=True)

In [31]:
# Display the canonical, non-isomeric SMILES string computed in the previous cell.
mol

'COc1ccccc1NS(=O)(=O)c1ccc(NC(C)=O)cc1'