In [None]:
import os
import time
import re
import ast
import numpy as np
import pandas as pd
import joblib
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from rdkit.Chem import rdMolDescriptors
from openai import OpenAI
from dotenv import load_dotenv

from rdkit import RDLogger
# Disable RDKit log output for every logger matching 'rdApp.*'
# (presumably to keep parser messages for malformed generated SMILES out of
# the cell output -- the workflow below handles invalid SMILES itself).
RDLogger.DisableLog('rdApp.*')

In [None]:
# Load variables from a local .env file (if present) into the environment,
# then read the OpenAI API key that is passed to the LLM call later on.
load_dotenv()
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

# Both an unset variable (None) and an empty/blank placeholder such as
# `OPENAI_API_KEY=` are unusable, so fail fast here rather than at the
# first LLM request.
if OPENAI_API_KEY is None or not OPENAI_API_KEY.strip():
    raise ValueError("API key not found in .env file")

print("Found API key!")

## Set params

In [None]:
# Select LLM version by its short alias.
# Available aliases: "GPT4omini", "GPT4o", "o4mini", "GPT4_1", "GPT5"
LLM_version = "GPT4_1"

# Short alias -> model identifier string handed to the LLM call.
LLM_MODEL_IDS = {
    "GPT4omini": "gpt-4o-mini-2024-07-18",
    "GPT4o": "gpt-4o-2024-11-20",
    "o4mini": "o4-mini-2025-04-16",
    "GPT4_1": "gpt-4.1-2025-04-14",
    "GPT5": "gpt-5-2025-08-07",
}

# Fail immediately on an unknown alias. Merely printing a message would leave
# LLM_version_run undefined and surface as a NameError much later, inside the
# generation loop.
if LLM_version not in LLM_MODEL_IDS:
    raise ValueError(
        f"Language Model Chosen Error! Unknown LLM_version {LLM_version!r}; "
        f"expected one of {sorted(LLM_MODEL_IDS)}"
    )

LLM_version_run = LLM_MODEL_IDS[LLM_version]
print(f"{LLM_version} is using")

# --- Generation budget ------------------------------------------------------
num_mol_per_message = 10  # number of molecules returned per LLM call
num_api_call = 10         # total number of LLM calls made by the workflow loop

# --- Tanimoto similarity windows (min, max) ---------------------------------
# "RefExp": thresholds used when comparing against the experimental SMILES.
tani_thr_RefExp_min, tani_thr_RefExp_max = 0.25, 0.85
# "RefGen": thresholds used when comparing against already generated SMILES.
tani_thr_RefGen_min, tani_thr_RefGen_max = 0.25, 0.85

# --- LogP threshold ---------------------------------------------------------
logP_threshold_min = 3

# --- Metal pair -------------------------------------------------------------
metal1 = "Am(III)"  # Target metal
metal2 = "Eu(III)"  # Other metal

# --- Trained ML model -------------------------------------------------------
# NOTE: joblib.load unpickles the file, so only point this at model files
# from a trusted source.
model_name = "resources/xgb_model0.1-1.sav"
loaded_model = joblib.load(model_name)

# --- Prompt focus and evaluation switch -------------------------------------
# Alternative focus text that was used previously:
#   f"binding ligands with functional groups to target {metal1} metal"
design_focus = "binding ligands must be structurally similar to bis-triazinyl bipyridines (BTBPs)"
eval_onlyGood = True  # forwarded to load_resources() in the next cell

In [None]:
from utils import load_resources

# load_resources returns six objects, unpacked here in the order it yields them.
(
    all_exp_smiles_list,
    all_solvent_df,
    all_metals_df,
    original_smiles_eval_df,
    ml_conditions_list,
    all_ml_feature_series,
) = load_resources(eval_onlyGood)

## Separation conditions

In [None]:
# Separation-condition grid: every list below is one axis, and `combinations`
# enumerates the full Cartesian product of all axes.

all_extractant_conc_list = [0.05, 0.1, 0.2]  # ligand concentration (M)
all_vol_a_list = [1]  # volume ratio of solvent A (from 0-1)
all_vol_b_list = [0]  # volume ratio of solvent B
# organic solvent A descriptors: one descriptor row per selected solvent
all_solvent_a_list2D = all_solvent_df.loc[['n-dodecane', '1-octanol', 'kerosene']].values.tolist()
# organic solvent B descriptors: a single all-zero row of five values
all_solvent_b_list2D = [[0] * 5]
all_acid_dipole_list = [2.17]  # acid dipole moment, HNO3
all_acid_conc_list = [0.5, 1, 2, 3]  # acid concentration (M)
all_T_list = [298]  # temperature (K)
# metal concentrations are stored as log10 values (original note gives the unit as mM)
all_metal1_conc_list = np.log10([1e-8, 0.05])
all_metals_list2D = all_metals_df.loc[[metal1]].values.tolist()  # metal descriptors
all_metal2_conc_list = np.log10([1e-8, 0.05])

# Generate all combinations. Each entry is one flat list laid out as:
#   [extractant conc, vol A, vol B, *solvent A descriptors, *solvent B descriptors,
#    acid dipole, acid conc, temperature, metal conc, *metal descriptors]
combinations = [
    [
        ext_conc,
        vol_a,
        vol_b,
        *solvent_a_desc,
        *solvent_b_desc,
        acid_dipole,
        acid_conc,
        temperature,
        metal_conc,
        *metal_desc,
    ]
    for ext_conc in all_extractant_conc_list
    for vol_a in all_vol_a_list
    for vol_b in all_vol_b_list
    for solvent_a_desc in all_solvent_a_list2D
    for solvent_b_desc in all_solvent_b_list2D
    for acid_dipole in all_acid_dipole_list
    for acid_conc in all_acid_conc_list
    for temperature in all_T_list
    for metal_conc in all_metal1_conc_list
    for metal_desc in all_metals_list2D
]

## Start workflow

In [None]:
from utils import callLLM_generate_smiles, similarStructureCheck, generate_ml_input, classify_D, extract_metal_label

# Main generate -> validate -> score loop.
#
# For each LLM call:
#   1. ask the LLM for new SMILES (it is given every SMILES evaluated so far),
#   2. reject invalid SMILES and duplicates within the same reply,
#   3. label similarity to the experimental and the previously generated sets,
#   4. label LogP,
#   5. run the classifier over every separation condition for metal1 and metal2
#      and record the condition pairs predicted as class 2 / class 0,
#   6. append the per-SMILES labels to the evaluation tables.
#
# Outputs used by later cells: new_smiles_conditions_df, gen_smiles_eval_df.

eval_columns = ['SMILES', "Target_metal", "Other_metal", 'Source', 'Similarity_to_Experimental', 'Similarity_to_Generated', 'LogP']

ml_conditions_list_good = ml_conditions_list.copy()
condition_columns = ml_conditions_list_good + ["New_extractants_SMILES", "Pred_class"]
new_smiles_conditions_df = pd.DataFrame(columns=condition_columns)
# Rows for new_smiles_conditions_df are collected in a plain list and turned
# into a DataFrame once per LLM call; concatenating inside the innermost loop
# is quadratic and concatenates onto an empty frame.
good_condition_rows = []
gen_smiles_eval_df = pd.DataFrame(columns=eval_columns)
all_smiles_eval_df = original_smiles_eval_df.copy()

# Descriptor values of the second metal; loop-invariant, so looked up once.
metal2_descriptor_list = all_metals_df.loc[[metal2]].values.flatten().tolist()

for i in range(0, num_api_call):
    print("Times of LLM call: ", i)
    message_new_smiles_list = callLLM_generate_smiles(
        all_smiles_eval_df,
        OPENAI_API_KEY,
        num_mol_per_message,
        metal1,
        metal2,
        tani_thr_RefExp_min,
        tani_thr_RefExp_max,
        tani_thr_RefGen_max,
        logP_threshold_min,
        design_focus,
        LLM_version_run,
    )
    # Init the eval lists for one API call. Every SMILES in the reply must add
    # exactly one entry to each of these lists so the columns stay aligned.
    target_metal_eval = []
    other_metal_eval = []
    sim_RefExp_eval = []
    sim_RefGen_eval = []
    logP_eval = []
    per_smiles_eval_lists = (sim_RefExp_eval, sim_RefGen_eval, logP_eval, target_metal_eval, other_metal_eval)
    # Fingerprints of the valid, non-duplicate SMILES already seen in this reply.
    message_fingerprints = []

    for one_new_smiles in message_new_smiles_list:
        # Check if the new SMILES is valid
        try:
            m = rdkit.Chem.MolFromSmiles(one_new_smiles)
            fp = rdkit.Chem.AllChem.GetMorganFingerprintAsBitVect(m, 4, nBits=2048) if m else None
        except Exception:
            # Exception (not a bare except) so Ctrl-C still interrupts the run.
            print("Found invalid SMILES in workflow...")
            m = None
            fp = None

        if not m:
            for one_eval_list in per_smiles_eval_lists:
                one_eval_list.append('Invalid SMILES')
            continue

        # Check whether this SMILES duplicates an earlier one in the same reply.
        # Comparing against cached fingerprints avoids re-parsing earlier entries,
        # so a malformed earlier entry can no longer mislabel this one.
        if any(rdkit.DataStructs.TanimotoSimilarity(fp, earlier_fp) >= 1 for earlier_fp in message_fingerprints):
            print("Found duplicate SMILES in workflow...")
            for one_eval_list in per_smiles_eval_lists:
                one_eval_list.append('Duplicate SMILES')
            continue
        message_fingerprints.append(fp)

        # Calculate max tanimoto score to experimental SMILES
        max_similarity_score, _ = similarStructureCheck(one_new_smiles, all_exp_smiles_list)
        print("Similarity calculation to Experimental is done.")
        if max_similarity_score < tani_thr_RefExp_min:
            sim_RefExp_eval.append("LOW")
        elif max_similarity_score > tani_thr_RefExp_max:
            sim_RefExp_eval.append("HIGH")
        else:
            sim_RefExp_eval.append("MEDIUM")

        # Calculate max tanimoto score to generated SMILES (from earlier LLM calls)
        if not gen_smiles_eval_df['SMILES'].empty:
            max_similarity_score, _ = similarStructureCheck(one_new_smiles, gen_smiles_eval_df["SMILES"])
            print("Similarity calculation to Generated is done.")
            if max_similarity_score < tani_thr_RefGen_min:
                sim_RefGen_eval.append("LOW")
            elif max_similarity_score > tani_thr_RefGen_max:
                sim_RefGen_eval.append("HIGH")
            else:
                sim_RefGen_eval.append("MEDIUM")
        else:
            # Nothing generated yet (first call): no comparison possible.
            sim_RefGen_eval.append("")

        # Calculate log P. The ORGANIC cut-off uses the configured
        # logP_threshold_min (the same value sent to the LLM and written to
        # params.txt) instead of a separate hard-coded 3.
        logP = rdMolDescriptors.CalcCrippenDescriptors(m)[0]
        print("Log P calculation is done.")
        if logP > logP_threshold_min:
            logP_eval.append('ORGANIC')
        elif logP >= 2:
            logP_eval.append('UNSELECTIVE')
        else:
            logP_eval.append('AQUEOUS')

        print("Start conditions search...")
        OUA_tuple_list = []
        for one_condition_list in combinations:
            # Get the ML input and make a prediction for metal_1
            # (reshaped to a single row of 1860 features).
            ml_input = generate_ml_input(one_new_smiles, one_condition_list, ml_conditions_list, all_ml_feature_series)
            ml_input = np.array(ml_input).reshape(1, 1860)
            predicted_class = loaded_model.predict(ml_input)

            # Same condition with the trailing metal descriptors swapped to metal 2
            one_condition_list_metalchange = one_condition_list.copy()
            start_index = len(one_condition_list_metalchange) - len(metal2_descriptor_list)
            one_condition_list_metalchange[start_index:] = metal2_descriptor_list

            # Loop the metal 2 conc
            for one_metal2_conc in all_metal2_conc_list:
                # The metal concentration sits directly before the metal descriptors
                one_condition_list_metalchange[start_index-1] = one_metal2_conc
                # Get the ML input and make a prediction for 2nd metal
                ml_input2 = generate_ml_input(one_new_smiles, one_condition_list_metalchange, ml_conditions_list, all_ml_feature_series)
                ml_input2 = np.array(ml_input2).reshape(1, 1860)
                predicted_class2 = loaded_model.predict(ml_input2)

                if predicted_class[0] == 2 and predicted_class2[0] == 0:
                    OUA_tuple_list.append(("ORGANIC", "AQUEOUS"))
                    print("Good D Detected!")
                    # `+` builds fresh lists, so later in-place edits of
                    # one_condition_list_metalchange do not alter stored rows.
                    good_condition_rows.append(one_condition_list + [one_new_smiles] + [predicted_class[0]])
                    good_condition_rows.append(one_condition_list_metalchange + [one_new_smiles] + [predicted_class2[0]])
                else:
                    OUA_tuple_list.append((classify_D(predicted_class[0]), classify_D(predicted_class2[0])))

        metal1_label, metal2_label = extract_metal_label(OUA_tuple_list)
        target_metal_eval.append(metal1_label)
        other_metal_eval.append(metal2_label)

    # Assemble this call's evaluation table. The 'Source' column is sized from
    # the SMILES actually returned, not from num_mol_per_message, so a reply
    # with more or fewer molecules than requested no longer gives ragged columns.
    smiles_this_call = list(message_new_smiles_list)
    gpt_smiles_one_eval_df = pd.DataFrame(
        {
            'SMILES': smiles_this_call,
            "Target_metal": target_metal_eval,
            "Other_metal": other_metal_eval,
            'Source': ['LLM generated'] * len(smiles_this_call),
            'Similarity_to_Experimental': sim_RefExp_eval,
            'Similarity_to_Generated': sim_RefGen_eval,
            'LogP': logP_eval,
        },
        columns=eval_columns,
    )

    # Add LLM generated SMILES to all_smiles_eval_df for another API call
    all_smiles_eval_df = pd.concat([all_smiles_eval_df, gpt_smiles_one_eval_df], ignore_index=True)
    gen_smiles_eval_df = pd.concat([gen_smiles_eval_df, gpt_smiles_one_eval_df], ignore_index=True)

    # Refresh the conditions table after every call so partial results remain
    # available if the run is interrupted.
    new_smiles_conditions_df = pd.DataFrame(good_condition_rows, columns=condition_columns)
    time.sleep(3)

print("New Extractants Search End!")

## Extract metal and solvent name

In [None]:
# Join keys: the first five solvent descriptor column names, each with the
# "_A" suffix appended.
matching_columns_solventA = [
    descriptor_name + "_A" for descriptor_name in all_solvent_df.columns[0:5]
]

# Solvent lookup table: the index becomes a regular column, renamed from
# 'Solvent' to 'Solvent_A_name', and the descriptor columns get the "_A" names.
all_solvent_A_df = all_solvent_df.reset_index().rename(
    columns={
        'Solvent': 'Solvent_A_name',
        'Molar_mass(g/mol)': 'Molar_mass(g/mol)_A',
        'Log_P': 'Log_P_A',
        'Boiling_point(K)': 'Boiling_point(K)_A',
        'Melting_point(K)': 'Melting_point(K)_A',
        'Density(g/mL)': 'Density(g/mL)_A',
    }
)

# Left join keeps every condition row and attaches the matching solvent name.
merged_new_smiles_conditions_df = pd.merge(
    new_smiles_conditions_df,
    all_solvent_A_df,
    how='left',
    on=matching_columns_solventA,
)
print(merged_new_smiles_conditions_df.shape)

In [None]:
# Attach the metal name to every condition row by matching on the metal
# descriptor columns.
matching_columns_metal = ['Atomic_number','Ionic_radius_nm','Oxidation_state']

# Build the name lookup under its own variable instead of rebinding
# all_metals_df. Earlier cells index all_metals_df by metal label
# (all_metals_df.loc[[metal1]]), so overwriting it with a re-indexed frame
# made those cells, and this one, fail or misbehave when re-run.
all_metals_named_df = all_metals_df.reset_index().rename(columns={'Metal': 'Metal_name'})

# Drop a 'Metal_name' column left by a previous run of this cell so that
# re-running does not produce Metal_name_x / Metal_name_y duplicates.
merged_new_smiles_conditions_df = (
    merged_new_smiles_conditions_df
    .drop(columns=['Metal_name'], errors='ignore')
    .merge(
        all_metals_named_df[['Metal_name'] + matching_columns_metal],
        on=matching_columns_metal,
        how='left',
    )
)
print(merged_new_smiles_conditions_df.shape)
merged_new_smiles_conditions_df.head(2)

## Save results

In [None]:
# Create a per-run results folder named after the run time, the LLM alias and
# the generation budget.
# Timestamp fields are in chronological order (year, month, day, hour, minute)
# so folder names sort by run time; the previous "%Y%H%M%m%d" pattern put
# hour and minute between the year and the month.
current_time = time.strftime("%Y%m%d%H%M")
folderName_results = "results_AmEu"
folderName2_results = f"results{current_time}_{LLM_version}_SMILES{num_api_call}x{num_mol_per_message}"
folderPath_results = os.path.join(folderName_results, folderName2_results)
# exist_ok=True: re-running within the same minute reuses the folder.
os.makedirs(folderPath_results, exist_ok=True)

In [None]:
# Output locations inside the per-run results folder.
filePath_genSMILESeval = os.path.join(folderPath_results, "genSMILESeval.xlsx")
filePath_genSMILESconditions = os.path.join(folderPath_results, "genSMILESconditions.xlsx")
filePath_params = os.path.join(folderPath_results, "params.txt")

# Evaluation labels of every generated SMILES ('SMILES' column exported as 'Gen_SMILES').
gen_smiles_eval_df = gen_smiles_eval_df.rename(columns={'SMILES':'Gen_SMILES'})
gen_smiles_eval_df.to_excel(filePath_genSMILESeval, index=False)

# Condition rows recorded by the workflow, with solvent and metal names attached.
merged_new_smiles_conditions_df.to_excel(filePath_genSMILESconditions, index=False)

# Plain-text record of the run parameters, one "Key: value" line each.
param_lines = [
    f"Number_of_LLM_Call: {num_api_call}\n",
    f"Number_of_Molecules_Per_LLM_Call: {num_mol_per_message}\n",
    f"Threshold_LogP: {logP_threshold_min}\n",
    f"Target_Metal: {metal1}\n",
    f"Other_Metal: {metal2}\n",
    f"Supervised_ML: {model_name}\n",
    f"Design_Focus: {design_focus}\n",
    f"Eval_onlyGood: {eval_onlyGood}\n",
]
with open(filePath_params, 'w') as f:
    f.writelines(param_lines)