In [15]:
#Recursively have the model generate reasons from a structure-property relationship dataset
#Batch processing

#Automatically reload imported libraries (autoreload)
%reload_ext autoreload
%autoreload 2

In [16]:
import openai
from tqdm import tqdm
import pandas as pd
import os
import glob
import json
import copy

In [17]:

# Read the OpenAI API key from the environment rather than hard-coding it;
# this raises KeyError if OPENAI_API_KEY is not set.
openai.api_key =os.environ["OPENAI_API_KEY"]

In [18]:
#load dataset as a list of dicts (one dict per CSV row)
csv_path="dataset/BradleyMeltingPointDataset_clean.csv"
df=pd.read_csv(csv_path)
# orient='records' yields [{column: value, ...}, ...]
chemical_records=df.to_dict(orient='records')
# Show the first record as a quick sanity check of the columns.
chemical_records[:1]

[{'mpC': 87.0,
  'name': '((5-((Diphenylphosphino)methyl)-2,2-dimethyl-1,3-dioxolan-4-yl)methyl)(diphenyl)phosphine',
  'smiles': 'P(CC1OC(OC1CP(c1ccccc1)c1ccccc1)(C)C)(c1ccccc1)c1ccccc1',
  'csid': 109291,
  'link': 'http://dx.doi.org/10.1021/ci0500132',
  'source': 'Karthikeyan M.; Glen R.C.; Bender A. General melting point prediction based on a diverse compound dataset and artificial neural networks. J. Chem. Inf. Model.; 2005; 45(3); 581-4277'}]

In [19]:
# Model name passed to json_generate() by the batch loop for reason generation.
# prediction_string_to_number() is called without `model` there, so it keeps its own default.
model="gpt-4-1106-preview"

In [20]:
# System message sent by json_generate() with every request.
# NOTE(review): the "#Example reason" section uses boiling points (benzene -> toluene)
# while the task asks for melting points - confirm this is intentional.
system_prompt="""
Provide the quantitative Reason and Prediction so that a scientist, who does not know the melting point, can predict the value.

#Commands
- You must quantitatively consider how the melting point shifts, focusing on each functional groups.
- Actual value and Prediction must match each other.
- If Actual value and Prediction differ each other, rethink Reason.
- If Prediction does not contain numbers for each functional group effect, rethink Reason

#Example reason
- Target compound: Toluene
- Basic unit, benzene has a boiling point of 80.
- Methyl group: +30 (due to larger molecular weight)
- Prediction: 110

"""

In [21]:
def gen_prompt(chemical_record,reason="",prediction=""):
    """Build the user message describing one compound and the current draft answer.

    Args:
        chemical_record: mapping that provides the "name", "smiles" and "mpC" keys.
        reason: reasoning text produced so far (empty on the first call).
        prediction: prediction text produced so far (empty on the first call).

    Returns:
        A multi-line string with a "#Data" section followed by the requested
        JSON output keys.
    """
    prompt_lines = [
        "",
        "#Data",
        f"-Name: {chemical_record['name']}",
        # The trailing space after the SMILES value is part of the prompt format.
        f"-SMILES: {chemical_record['smiles']} ",
        f"-Actual value: {chemical_record['mpC']}",
        f"-Reason: {reason}",
        f"-Prediction: {prediction}",
        "",
        "#Output (JSON keys)",
        "- Reason, Prediction",
        "",
    ]
    return "\n".join(prompt_lines)


In [22]:
import json

#ask gpt
def json_generate(prompt,model="gpt-3.5-turbo-1106"):
    """Send `prompt` to the chat model in JSON mode and return the decoded reply.

    The module-level `system_prompt` is used as the system message.

    Args:
        prompt: user message content (interpolated into a string).
        model: chat model name.

    Returns:
        The object obtained by JSON-decoding the first choice's message content.
    """
    system_message = {"role": "system", "content": system_prompt}
    user_message = {"role": "user", "content": f"{prompt}"}

    completion = openai.chat.completions.create(
        model=model,
        messages=[system_message, user_message],
        response_format={"type": "json_object"},
    )

    reply_text = completion.choices[0].message.content
    return json.loads(reply_text)


#parse prediction
def prediction_string_to_number(prompt,model="gpt-3.5-turbo-1106"):
    """Ask the chat model to reduce a free-text prediction to a single integer.

    The system message contains few-shot examples; the user message is the
    prediction text followed by the requested JSON key.

    Args:
        prompt: free-text prediction (e.g. a sentence containing a value or a range).
        model: chat model name.

    Returns:
        The JSON-decoded reply. The model is asked for a "Prediction" key, but
        this is not validated here - callers should handle a missing key.
    """
    response = openai.chat.completions.create(
    model=model,
    messages=[
        {
            "role": "system",
            # Fixed few-shot example: "75.2 degrees Celsius" previously mapped to 73,
            # a copy-paste of the range example above it.
            "content": """Extract integer from prediction. Use average if multiple numbers are included.
            Examples:
            In: 70.2 - 75.2 degrees Celsius
            Out: 73
            In: 75.2 degrees Celsius
            Out: 75
            In: For 1-naphthalenecarboxaldehyde, starting with the base value for naphthalene with a melting point of 80\u00b0C and subtracting the estimated aldehyde effect of approximately -47 to -50\u00b0C, the predicted melting point would be in the range of 30-33\u00b0C.
            Out: 32
            """,
        },
        {
            "role": "user",
            # A stray leading double quote before "#Output" was removed so the
            # section header matches the one used by gen_prompt().
            "content": f"""{prompt}
#Output (JSON keys)
- Prediction"""
        }
    ],
    response_format={ "type": "json_object" }
    )

    return json.loads(response.choices[0].message.content)

In [23]:
#t=prediction_string_to_number("Considering a starting point of 80\u00b0C for naphthalene and accounting for the influence of the aldehyde functional group, which can reduce the melting point by 47 to 50\u00b0C, the estimated melting point for 1-naphthalenecarboxaldehyde is around 30 to 33\u00b0C, closely aligning with the actual value of 33.5\u00b0C.")
#t

In [27]:

# Directory that holds one JSON history file per processed compound.
save_base_path="dataset/231225AutoReasoning/"
# Create it up front: otherwise the first save in the batch loop raises
# FileNotFoundError only after the API calls for that record were already made.
os.makedirs(save_base_path, exist_ok=True)


#load finished records
# Maps compound name -> saved history (list of record dicts), used to skip finished compounds.
gen_records={}
gen_json_path_list=glob.glob(save_base_path+"*.json")
for gen_json_path in tqdm(gen_json_path_list):
    try:
        with open(gen_json_path, encoding="utf-8") as f:
            gen_hist=json.load(f)
        # The first history entry carries the original record, including its name.
        finished_name=gen_hist[0]["name"]
    except (json.JSONDecodeError, IndexError, KeyError, TypeError) as e:
        # A truncated or empty file (e.g. from an interrupted run) should not abort
        # the resume step; report it and let the compound be regenerated.
        print(f"Skip unreadable history file: {gen_json_path} ({e!r})")
        continue
    gen_records[finished_name]=gen_hist

  0%|          | 0/9 [00:00<?, ?it/s]

100%|██████████| 9/9 [00:00<00:00, 10660.47it/s]


In [30]:
import re

def remove_non_alphabet_characters(s):
    """Return `s` with everything except ASCII letters (a-z, A-Z) dropped."""
    ascii_letters = re.findall(r"[A-Za-z]", s)
    return "".join(ascii_letters)

In [31]:
# Max refinement calls per attempt: each call feeds the previous Reason/Prediction back to the model.
n_recursion=2
# Max attempts per compound; every attempt restarts from an empty Reason/Prediction.
n_random_repeat=3
# Stop early once |Prediction(integer) - mpC| <= this value (same unit as the dataset's mpC column).
error_threshold=10

In [32]:
#batch
# For every compound not processed yet: generate a Reason/Prediction, refine it up to
# n_recursion times per attempt, restart up to n_random_repeat times, and save the
# full history to one JSON file per compound.
for record_index, chemical_record in enumerate(tqdm(chemical_records)):

    #load record
    gen_record=copy.deepcopy(chemical_record)

    #skip if already generated
    if gen_record["name"] in gen_records:
        print(f"Skip because already generated: {gen_record['name']}")
        continue

    record_history=[]

    fin_flag=False
    #restart from scratch up to n_random_repeat times (relies on the model's sampling variability)
    for j in range(n_random_repeat):
        if fin_flag:
            break

        gen_record["Reason"]=""
        gen_record["Prediction"]=""
        if j==0:
            # Keep the untouched input record as the first history entry;
            # the resume step reads the compound name from it.
            record_history.append(copy.deepcopy(gen_record))

        #improve reasoning
        for i in range(n_recursion):
            r=json_generate(
                gen_prompt(
                    gen_record,
                    reason=gen_record["Reason"],
                    prediction=gen_record["Prediction"],
                ),
                model=model,
            )
            #parse prediction string to number
            gen_record.update(r)
            try:
                gen_record["Prediction(integer)"]=float(prediction_string_to_number(gen_record["Prediction"])["Prediction"])
            except Exception:
                # Best-effort parse: any failure is recorded with a sentinel that can never
                # pass the error threshold. `Exception` (not a bare except) keeps
                # Ctrl-C / KeyboardInterrupt able to stop the long batch.
                gen_record["Prediction(integer)"]=99999
            record_history.append(copy.deepcopy(gen_record))

            #finish reasoning if prediction is close to actual value
            if abs(gen_record["Prediction(integer)"]-gen_record["mpC"])<=error_threshold:
                fin_flag=True
                print(f"Finished because good reasoning was achieved: {gen_record['name']}")
                break

    #save
    # Letters-only names collide (e.g. "1-butanol" and "2-butanol" both become "butanol"),
    # which silently overwrote earlier histories. Prefix the dataset row index to make the
    # file name unique; truncate the readable part so long names cannot exceed
    # file-name length limits. Resuming is unaffected: finished records are matched by the
    # "name" stored inside each file, not by the file name.
    save_name=remove_non_alphabet_characters(gen_record["name"])[:100]
    save_path=save_base_path+f"{record_index:05d}_{save_name}.json"
    with open(save_path, 'w') as f:
        json.dump(record_history, f, indent=4)

    gen_records[gen_record["name"]]=record_history

  0%|          | 0/24889 [00:00<?, ?it/s]

In [None]:
#Rough cost estimate
# Uses the last `gen_record` left over from the batch loop, so this cell only
# works after at least one record has been generated in this kernel session.
t=gen_prompt(gen_record,
               reason=gen_record["Reason"],
               prediction=gen_record["Prediction"]
    )
# Space-separated chunk counts stand in for token counts (a crude proxy, not a tokenizer).
user_len=len(t.split(" "))
system_len=len(system_prompt.split(" "))

# NOTE(review): 0.01 / 1000 is presumably the USD price per input token - verify against current pricing.
input_cost=0.01/1000*(user_len+system_len)

gen_len=len(gen_record["Reason"].split(" "))+len(gen_record["Prediction"].split(" "))
# NOTE(review): 0.03 / 1000 is presumably the USD price per output token - verify against current pricing.
output_cost=0.03/1000*gen_len

# Number of model calls assumed per record for this estimate.
n_trials=2
cost=n_trials*(input_cost+output_cost)
print(f"Input tokens: {user_len+system_len}")
print(f"Output tokens: {gen_len}")
print(f"Cost: {cost} USD")
# NOTE(review): 150 looks like an assumed JPY-per-USD exchange rate - confirm.
print(f"Cost: {cost*150} JP")

Input tokens: 298
Output tokens: 205
Cost: 0.01826 USD
Cost: 2.739 JP
