In [5]:
# %%
#再帰的に構造ー物性相関データセットから理由を生成させるバッチ処理

#ライブラリの自動インポート

# %%
import openai
from tqdm import tqdm
import pandas as pd
import os
import glob
import json
import copy
import random

import re

# %%

# Read the OpenAI API key from the environment (a KeyError is raised if OPENAI_API_KEY is not set).
openai.api_key =os.environ["OPENAI_API_KEY"]

# %%
# Load the dataset as a list of dicts, one dict per CSV row.
csv_path="dataset/BradleyMeltingPointDataset_clean.csv"
df=pd.read_csv(csv_path)
chemical_records=df.to_dict(orient='records')
# Shuffle in place. No seed is set, so the processing order differs on every run.
random.shuffle(chemical_records)
# Peek at the first record (only rendered when this is the last expression of a notebook cell).
chemical_records[:1]

# %%
# Chat model name that the batch loop passes to json_generate().
model="gpt-4-1106-preview"

# %%
# System message that json_generate() sends with every request.
# NOTE(review): the task statement talks about the melting point, while the
# worked example below is phrased in terms of a boiling point — confirm that
# this mismatch is intentional before changing the prompt text.
system_prompt="""
Provide the quantitative Reason and Prediction so that a scientist, who does not know the melting point, can predict the value.

#Commands
- You must quantitatively consider how the melting point shifts, focusing on each functional groups.
- Actual value and Prediction must match each other.
- If Actual value and Prediction differ each other, rethink Reason.
- If Prediction does not contain numbers for each functional group effect, rethink Reason

#Example reason
- Target compound: Toluene
- Basic unit, benzene has a boiling point of 80.
- Methyl group: +30 (due to larger molecular weight)
- Prediction: 110

"""

# %%
def gen_prompt(chemical_record,reason="",prediction=""):
    """Build the user message for one compound.

    Args:
        chemical_record: Mapping that provides the keys "name", "smiles" and "mpC".
        reason: Reason text to embed in the prompt ("" when there is none yet).
        prediction: Prediction to embed in the prompt ("" when there is none yet).

    Returns:
        The prompt string: a "#Data" section listing the compound, its actual
        value and the current Reason/Prediction, followed by the JSON keys the
        reply must contain.
    """
    # Look the fields up first, in the same order as before.
    compound_name=chemical_record["name"]
    smiles_text=chemical_record["smiles"]
    actual_value=chemical_record["mpC"]

    # The empty first and last entries produce the leading and trailing newline.
    prompt_lines=[
        "",
        "#Data",
        f"-Name: {compound_name}",
        f"-SMILES: {smiles_text} ",  # the trailing space is part of the prompt
        f"-Actual value: {actual_value}",
        f"-Reason: {reason}",
        f"-Prediction: {prediction}",
        "",
        "#Output (JSON keys)",
        "- Reason, Prediction",
        "",
    ]
    return "\n".join(prompt_lines)


# %%
import json

#ask gpt
def json_generate(prompt,model="gpt-3.5-turbo-1106"):
    """Send `prompt` to the chat completions endpoint and decode the JSON reply.

    The module-level `system_prompt` is sent as the system message. Token log
    probabilities are requested and the reply format is set to a JSON object.

    Args:
        prompt: User message; it is interpolated into an f-string before sending.
        model: Model name forwarded to the API.

    Returns:
        A tuple `(parsed, response)`: the first choice's message content decoded
        with `json.loads`, and the raw response object.
    """
    conversation=[
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"""{prompt}"""},
    ]
    response=openai.chat.completions.create(
        model=model,
        messages=conversation,
        logprobs=True,
        response_format={"type": "json_object"},
    )

    reply_text=response.choices[0].message.content
    return json.loads(reply_text),response


#parse prediction
def prediction_string_to_number(prompt,model="gpt-3.5-turbo-1106"):
    """Ask the model to reduce a free-text prediction to a single integer.

    Args:
        prompt: Prediction text to convert; it is interpolated into the user message.
        model: Model name forwarded to the API.

    Returns:
        The first choice's message content decoded with `json.loads`. The user
        message asks for a JSON object with the key "Prediction".
    """
    # The few-shot examples steer the extraction, so each "Out" must really be
    # the rounded value (or rounded average) of its "In": 75.2 -> 75.
    response = openai.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": """Extract integer from prediction. Use average if multiple numbers are included.
            Examples:
            In: 70.2 - 75.2 degrees Celsius
            Out: 73
            In: 75.2 degrees Celsius
            Out: 75
            In: For 1-naphthalenecarboxaldehyde, starting with the base value for naphthalene with a melting point of 80\u00b0C and subtracting the estimated aldehyde effect of approximately -47 to -50\u00b0C, the predicted melting point would be in the range of 30-33\u00b0C.
            Out: 32
            """,
            },
            {
                "role": "user",
                # No stray quote before the "#Output" header.
                "content": f"""{prompt}
#Output (JSON keys)
- Prediction"""
            },
        ],
        logprobs=True,
        response_format={ "type": "json_object" }
    )

    return (json.loads(response.choices[0].message.content))

# %%

save_base_path="dataset/231225AutoReasoning/"


# Index every previously saved history by the compound name of its first
# entry, so the batch loop can skip compounds that are already done.
gen_json_path_list=glob.glob(f"{save_base_path}*.json")
gen_records={}
for gen_json_path in tqdm(gen_json_path_list):
    with open(gen_json_path) as f:
        gen_hist=json.load(f)
    first_entry=gen_hist[0]
    gen_records[first_entry["name"]]=gen_hist

# %%

def remove_non_alphabet_characters(s):
    """Return `s` with every character other than the letters a-z / A-Z removed."""
    non_letters=re.compile('[^a-zA-Z]')
    return non_letters.sub('', s)

# %%
# Upper bound of the inner "improve reasoning" loop in the batch cell.
n_recursion=2
# Upper bound of the outer retry loop (fresh attempts per compound) in the batch cell.
n_random_repeat=3
# NOTE(review): not referenced in the visible code — presumably the tolerated gap
# between actual value and prediction; confirm against the full batch script.
error_threshold=10

# %%
import time

# Batch generation over all compounds.
# NOTE: the unconditional `break` at the end of every loop level stops the run
# after one API call for the first compound that has not been generated yet.
# `r` and `all_log` from that call stay in the namespace and are read by the
# cells that follow.
for chemical_record in tqdm(chemical_records):

    # Work on a copy so the source record is never modified.
    gen_record=copy.deepcopy(chemical_record)

    # Skip compounds that already have a loaded history.
    if gen_record["name"] in gen_records:
        print(f"Skip because already generated: {gen_record['name']}")
        continue

    record_history=[]

    fin_flag=False
    # Independent attempts, each starting from an empty Reason/Prediction.
    for j in range(n_random_repeat):
        if fin_flag:
            break

        gen_record["Reason"]=""
        gen_record["Prediction"]=""
        if j==0:
            record_history.append(copy.deepcopy(gen_record))

        # Refine the reasoning by feeding the current Reason/Prediction back in.
        for i in range(n_recursion):
            try:
                r,all_log=json_generate(
                    gen_prompt(gen_record,
                            reason=gen_record["Reason"],
                            prediction=gen_record["Prediction"]
                    ),
                    model=model,
                )
                #time.sleep(30)
            except Exception as e:
                # Stay best-effort, but report the failure instead of hiding it.
                # Catching Exception (not a bare except) lets Ctrl-C interrupt the run.
                print(f"Generation failed for {gen_record['name']}: {e!r}")

            break

        break

    break


100%|██████████| 3295/3295 [00:00<00:00, 41970.30it/s]
  0%|          | 0/24889 [00:25<?, ?it/s]


In [28]:
def parse_log_probs(all_log):
    """Collect (token, logprob) pairs for the first choice of a completion.

    Args:
        all_log: Response object exposing `choices[0].logprobs.content`, an
            iterable whose items have `token` and `logprob` attributes.

    Returns:
        List of `(token, logprob)` tuples in iteration order.
    """
    token_entries=all_log.choices[0].logprobs.content
    return [(entry.token,entry.logprob) for entry in token_entries]

In [31]:
# Extract the (token, logprob) pairs from the raw response kept in `all_log`.
probs=parse_log_probs(all_log)
# Mutates `r` in place: the parsed reply gains (or overwrites) a "logprob" entry.
r["logprob"]=probs
# Display the augmented reply.
r

{'Reason': "The compound 1,3-Diphenylpropane-1,3-dione consists of a propane-1,3-dione structure with two phenyl groups attached to the alpha-carbon atoms. To predict its melting point, we need to consider the effects of its functional groups and structural features:\n- Basic unit, propane, would generally have a very low melting point below -100°C due to its small size and lack of strong dipole interactions.\n- Carbonyl groups: Each carbonyl (C=O) group significantly increases the melting point due to the possibility of hydrogen bonding and strong dipolar interactions. If one carbonyl group increases the melting point by roughly +100°C, two such groups would contribute approximately +200°C.\n- Phenyl groups: Aromatic rings usually increase the melting point due to π-π stacking interactions, which provide extra stability to the crystalline lattice. Each phenyl group might contribute around +20°C to the melting point.\n- The alkane chain between the carbonyl groups would normally slight

In [26]:
# Persist the (token, logprob) pairs; JSON stores each tuple as a two-element list.
# `probs` is the module-level list built from parse_log_probs(all_log); the
# previous name `log_probs` only exists as a local inside parse_log_probs, so
# it raised NameError on a fresh kernel.
with open("t.json","w") as f:
    json.dump(probs,f)

In [11]:
import json
import joblib
# Serialize the raw response object to "t.bin" in the working directory.
joblib.dump(all_log,"t.bin")

['t.bin']

In [12]:
# Reload the dumped response. joblib persistence is pickle-based, so only load
# files that were produced locally / come from a trusted source.
t=joblib.load("t.bin")

In [13]:
# Display the reloaded response object to check the dump/load round trip.
t

ChatCompletion(id='chatcmpl-8doaiyuHStjsDeMwgwj3g4ZJDkujg', choices=[Choice(finish_reason='stop', index=0, logprobs=ChoiceLogprobs(content=[ChatCompletionTokenLogprob(token='{\n', bytes=[123, 10], logprob=-3.0232935e-05, top_logprobs=[]), ChatCompletionTokenLogprob(token=' ', bytes=[32], logprob=-0.088981785, top_logprobs=[]), ChatCompletionTokenLogprob(token=' "', bytes=[32, 34], logprob=-1.9361265e-07, top_logprobs=[]), ChatCompletionTokenLogprob(token='Reason', bytes=[82, 101, 97, 115, 111, 110], logprob=0.0, top_logprobs=[]), ChatCompletionTokenLogprob(token='":', bytes=[34, 58], logprob=-0.00036061046, top_logprobs=[]), ChatCompletionTokenLogprob(token=' "', bytes=[32, 34], logprob=-0.0018493895, top_logprobs=[]), ChatCompletionTokenLogprob(token='The', bytes=[84, 104, 101], logprob=-0.2472705, top_logprobs=[]), ChatCompletionTokenLogprob(token=' compound', bytes=[32, 99, 111, 109, 112, 111, 117, 110, 100], logprob=-1.3816159, top_logprobs=[]), ChatCompletionTokenLogprob(token=' '