In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Prepare the tokenizer and the model.
# Weights are loaded in bfloat16 with device_map="auto" (placement across the
# available devices is left to transformers/accelerate). A 4-bit load was tried
# earlier via `load_in_4bit=True`; recent transformers versions expect
# `quantization_config=BitsAndBytesConfig(load_in_4bit=True)` for that instead.
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=False,
)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 19/19 [00:16<00:00,  1.13it/s]


In [2]:

from transformers import pipeline
# Wrap the already-loaded model/tokenizer in a text-generation pipeline and try a short prompt.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe("Q: hello! how are you? A: ")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[{'generated_text': "Q: hello! how are you? A:  I'm doing well, thank you"}]

In [3]:

# Same smoke test with a Japanese greeting ("hello") to see how non-English input is handled.
pipe("Q: こんにちは  A: ")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[{'generated_text': 'Q: こんにちは  A: こんにちは\n\nQ'}]

In [4]:

#予測周りのutility funcs
import re
import torch
import gc
from IPython.display import clear_output
from trl import AutoModelForCausalLMWithValueHead
def gen_text_stop_word(prompt,model,tokenizer,
                       device="cuda:0",
                       stop_words=("#Problem","#Reason","# Problem"),
                       double_stop_words=("#Prediction",),
                       stream=False,
                       max_tokens=400,
                       stop_at_eos=True,
                       ):
    """Greedily generate a continuation of ``prompt`` one token at a time.

    Each step runs the model on the whole sequence so far (no KV cache, so cost
    grows quadratically with length), takes the arg-max next token, and
    re-decodes the text generated so far.

    Args:
        prompt: Text to continue.
        model: Causal LM. Its output must either expose ``.logits`` (standard
            Hugging Face models) or carry the logits as its first element
            (e.g. trl's ``AutoModelForCausalLMWithValueHead``, which returns a tuple).
        tokenizer: Tokenizer providing ``encode``/``decode`` (and ideally ``eos_token_id``).
        device: Device the prompt token ids are moved to.
        stop_words: Stop as soon as any of these substrings appears in the output.
        double_stop_words: Stop as soon as any of these substrings appears twice.
        stream: If True, re-render the partial output in the notebook after every token.
        max_tokens: Maximum number of tokens to generate.
        stop_at_eos: If True, stop when the tokenizer's EOS token is predicted
            instead of generating past the end of the sequence.

    Returns:
        The generated text without the prompt. A stop word that ended
        generation is included at the end.
    """
    gc.collect()
    torch.cuda.empty_cache()

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    # Strip the prompt by the length of its *decoded* form: decode(encode(prompt))
    # is not always identical to prompt (whitespace handling/normalisation), so
    # slicing by len(prompt) can cut generated characters or leak prompt text.
    prompt_text_len = len(tokenizer.decode(input_ids[0], skip_special_tokens=True))
    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    generated_text = ""

    # Inference only: avoid building an autograd graph on every forward pass.
    with torch.no_grad():
        for _ in range(max_tokens):
            outputs = model(input_ids)
            # HF model outputs expose .logits; tuple outputs (value-head models) carry logits first.
            logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
            next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)

            # The model ended the sequence; anything generated after EOS is noise.
            if stop_at_eos and eos_token_id is not None and next_token.item() == eos_token_id:
                break

            # Append the new token and re-decode everything after the prompt.
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)[prompt_text_len:]

            if stream:
                clear_output(wait=True)
                print(generated_text)

            if any(stop_word in generated_text for stop_word in stop_words):
                break
            # Words that only end generation on their second occurrence
            # (e.g. the model starting a second "#Prediction").
            if any(generated_text.count(check_word) >= 2 for check_word in double_stop_words):
                break

    return generated_text


In [5]:
# Free-form smoke test of the greedy token-by-token generator.
prompt = "What's your hobby?"
gen_text_stop_word(prompt, model, tokenizer, stream=True)



I like to play the piano.

Do you like to play the piano?

Yes, I do. I've been playing the piano for 10 years.

What kind of music do you like to play?




KeyboardInterrupt: 

In [6]:

# Simple arithmetic completion with the same generator.
prompt = "10+20+5+60="
gen_text_stop_word(prompt, model, tokenizer, stream=True)

95

The Brainliest Answer!



KeyboardInterrupt: 

In [7]:
import pandas as pd
import random
csv_path="dataset/BradleyMeltingPointDataset_clean.csv"
df=pd.read_csv(csv_path)
chemical_records=df.to_dict(orient='records')
# Seed the shuffle so chemical_records[0] (the compound fed into the prompt
# experiments) is the same on every run instead of changing per kernel session.
random.seed(0)
random.shuffle(chemical_records)
chemical_records[:1]

[{'mpC': 131.0,
  'name': '5-amino-1,3-diphenyl-1h-pyrazole',
  'smiles': 'c1ccc(cc1)c2cc(n(n2)c3ccccc3)N',
  'csid': 173078,
  'link': 'http://www.alfa.com/en/GP100W.pgm?DSSTK=B20464',
  'source': 'Alfa Aesar'}]

In [27]:
# Instruction prompt (with one worked example) for the melting-point reasoning
# experiment: the model is shown the actual value and asked to write a
# quantitative, per-functional-group Reason ending in a matching Prediction.
# NOTE(review): the text has minor grammar slips ("each functional groups",
# "differ each other"); left unchanged so prompts match the recorded outputs.
system_prompt="""
Provide the quantitative Reason and Prediction so that a scientist, who does not know the melting point, can predict the value.

#Commands
- You must quantitatively consider how the melting point shifts, focusing on each functional groups.
- Actual value and Prediction must match each other.
- If Actual value and Prediction differ each other, rethink Reason.
- If Prediction does not contain numbers for each functional group effect, rethink Reason

#Example reason
#Name: Chloroform
-Dichloromethane has a melting point of -97.
-Chloro group: +33 (larger molecular weight)
#Prediction: -64

"""

def gen_prompt(chemical_record,reason="",prediction=""):
    """Build the per-compound "Data" section of the melting-point reasoning prompt.

    Args:
        chemical_record: Mapping with at least "name", "smiles" and "mpC"
            (mpC is rendered as the "Actual value").
        reason: Text placed after "#Reason: "; leave empty to let the model fill it in.
        prediction: Optional predicted value. When non-empty it is emitted as a
            "#Prediction: <value>" line after the reason (it used to be accepted
            but silently ignored). Empty keeps the original output unchanged.

    Returns:
        The formatted section; it starts and ends with a newline.
    """
    name = chemical_record["name"]
    smiles = chemical_record["smiles"]
    value = chemical_record["mpC"]
    prompt = (
        "\nData\n"
        f"#Name: {name}\n"
        f"#SMILES: {smiles} \n"  # trailing space kept so prompts match earlier runs
        f"#Actual value: {value}\n"
        f"#Reason: {reason}\n"
    )
    if prediction != "":
        prompt += f"#Prediction: {prediction}\n"
    return prompt

In [28]:
# Instructions followed by the first (shuffled) compound's Data section.
prompt = system_prompt + gen_prompt(chemical_records[0])
print(prompt)


Provide the quantitative Reason and Prediction so that a scientist, who does not know the melting point, can predict the value.

#Commands
- You must quantitatively consider how the melting point shifts, focusing on each functional groups.
- Actual value and Prediction must match each other.
- If Actual value and Prediction differ each other, rethink Reason.
- If Prediction does not contain numbers for each functional group effect, rethink Reason

#Example reason
#Name: Chloroform
-Dichloromethane has a melting point of -97.
-Chloro group: +33 (larger molecular weight)
#Prediction: -64


Data
#Name: 5-amino-1,3-diphenyl-1h-pyrazole
#SMILES: c1ccc(cc1)c2cc(n(n2)c3ccccc3)N 
#Actual value: 131.0
#Reason: 



In [29]:

# Generate the Reason/Prediction for the prompt built above and show the raw string.
r = gen_text_stop_word(prompt, model, tokenizer, stream=True)
r

-Phenyl group: +50 (larger molecular weight)
-Pyrazole group: +10 (hydrogen bonding)
#Prediction: 161.0

















































KeyboardInterrupt: 

In [36]:
# Load the dataset (one record per compound, with stored Reason/Prediction columns).
import random

import pandas as pd

df = pd.read_csv("dataset/231225AutoReasoning/240104best_reason_record.csv")
dataset = df.to_dict(orient="records")
random.seed(0)  # fixed seed so the shuffled order is reproducible
random.shuffle(dataset)


system_prompt = "You are a professional chemist. Predict the melting point of the following compound."


def gen_compound_text(chemical_record, reason="", prediction=""):
    """Format one compound as a "#Problem" entry for the few-shot prompt.

    When both ``reason`` and ``prediction`` are non-empty the entry is a solved
    example ending in "##Prediction: ...". Otherwise (test mode) it ends with an
    open "##Reason: " slot for the model to complete.
    """
    header = (
        "\n#Problem"
        f"\n##Name: {chemical_record['name']}"
        f"\n##SMILES: {chemical_record['smiles']}"
    )
    has_answer = reason != "" and prediction != ""
    if not has_answer:
        # Test mode: leave the Reason slot open.
        return header + "\n##Reason: \n"
    return header + f"\n##Reason: {reason}\n##Prediction: {prediction}\n"

i = 0
# One solved example (dataset[i]) followed by the query compound (dataset[i+1])
# whose Reason slot is left open for the model.
prompt = (
    system_prompt
    + gen_compound_text(
        dataset[i],
        reason=dataset[i]["Reason"],
        prediction=dataset[i]["Prediction(integer)"],
    )
    + gen_compound_text(dataset[i + 1])
)
print(prompt)


You are a professional chemist. Predict the melting point of the following compound.
#Problem
##Name: (1,2,2,3-tetramethylcyclopentyl)methyl 4-aminobenzoate
##SMILES: O=C(OCC1(C)CCC(C)C1(C)C)c1ccc(N)cc1
##Reason: To predict the melting point of (1,2,2,3-tetramethylcyclopentyl)methyl 4-aminobenzoate, we consider the effects of various functional groups and structural features:
- Basic unit, cyclopentane has a typical melting point around -94°C.
- Methyl groups: Four methyl groups attached to the cyclopentane ring increase molecular weight and steric hindrance; each can contribute an estimated +20°C due to increased van der Waals interactions.
- Benzene ring attached via ester linkage: The aromatic ring contributes to an increase in molecular weight and rigidity; estimated contribution is +80°C.
- Ester group: Ester functionality typically raises the melting point due to polar interactions and possible hydrogen bonding if protic solvents are present or intramolecularly; estimated contrib

In [37]:

# Let the model complete the query compound's Reason/Prediction.
r = gen_text_stop_word(prompt, model, tokenizer, stream=True)
r

- Basic unit, quinoxaline, has a typical melting point around 80°C.
- Chlorine atoms: Two chlorine atoms attached to the quinoxaline ring increase molecular weight and steric hindrance; each can contribute an estimated +20°C due to increased van der Waals interactions.
- Nitrogen atoms in the quinoxaline ring: Nitrogen atoms can engage in hydrogen bonding if protic solvents are present or intramolecularly; estimated contribution is +10°C.
Combining these effects in a qualitative manner leads to the predicted melting point.
##Prediction: 110.0

#Problem


'- Basic unit, quinoxaline, has a typical melting point around 80°C.\n- Chlorine atoms: Two chlorine atoms attached to the quinoxaline ring increase molecular weight and steric hindrance; each can contribute an estimated +20°C due to increased van der Waals interactions.\n- Nitrogen atoms in the quinoxaline ring: Nitrogen atoms can engage in hydrogen bonding if protic solvents are present or intramolecularly; estimated contribution is +10°C.\nCombining these effects in a qualitative manner leads to the predicted melting point.\n##Prediction: 110.0\n\n#Problem'

In [38]:
# Stored record for the query compound (dataset[i+1]), to compare the model's
# generated Reason/Prediction with the stored ones and the actual mpC.
dataset[i+1]

{'name': '2,3-Dichloroquinoxaline',
 'smiles': 'Clc1nc2ccccc2nc1Cl',
 'csid': 15796,
 'link': 'http://msds.chem.ox.ac.uk/',
 'source': 'academic website',
 'Reason': 'The basic unit for comparison could be pyrazine, which has a melting point around 55 °C. Each chlorine substituent on the quinoxaline typically raises the melting point due to increased molecular weight, polarity, and possible intermolecular interactions, such as dipole-dipole attractions and hydrogen bonding with trace water. The presence of two chlorines could be expected to raise the melting point significantly. A single chlorine on an aromatic ring might raise the melting point by approximately 40-50 °C. So, for two chlorines the increment might be twice of this, leading to an increase of about 80-100 °C over the base pyrazine melting point.',
 'mpC': 152.0,
 'Prediction(integer)': 145.0,
 'Abs error': 7.0}