# 基于LLM预测化学反应产率

In [2]:
!pip install spark_ai_python

Looking in indexes: http://mirrors.aliyun.com/pypi/simple/


In [3]:
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm.auto import tqdm
import logging
import time
import concurrent.futures
import os
import re

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage

# Spark 3.5 Max websocket URL. Other model versions use other URLs, see
# https://www.xfyun.cn/doc/spark/Web.html
SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'
# API credentials from the iFlytek open-platform console
# (https://console.xfyun.cn/services/bm35). They are read from the environment
# so real keys never need to be pasted into (and committed with) the notebook;
# the "xxxx" placeholders remain as fallbacks.
SPARKAI_APP_ID = os.environ.get("SPARKAI_APP_ID", "xxxx")
SPARKAI_API_SECRET = os.environ.get("SPARKAI_API_SECRET", "xxxx")
SPARKAI_API_KEY = os.environ.get("SPARKAI_API_KEY", "xxxx")
# Spark 3.5 Max domain value. Other model versions use other domains, see
# https://www.xfyun.cn/doc/spark/Web.html
SPARKAI_DOMAIN = 'generalv3.5'

def get_completions(text):
    """Send ``text`` to the Spark LLM as a single user message and return the reply.

    A new ``ChatSparkLLM`` client is built on every call from the module-level
    ``SPARKAI_*`` settings, with streaming disabled. Only the text of the first
    generation of the first prompt is returned.
    """
    conversation = [ChatMessage(role="user", content=text)]
    client = ChatSparkLLM(
        spark_api_url=SPARKAI_URL,
        spark_app_id=SPARKAI_APP_ID,
        spark_api_key=SPARKAI_API_KEY,
        spark_api_secret=SPARKAI_API_SECRET,
        spark_llm_domain=SPARKAI_DOMAIN,
        streaming=False,
    )
    llm_result = client.generate([conversation], callbacks=[ChunkPrintHandler()])
    first_generation = llm_result.generations[0][0]
    return first_generation.text

# Smoke test: ask the model a trivial question ("who are you?") to check that the
# URL / credentials / domain configuration above works. The reply is shown as
# the cell's output.
text = "你是谁"
get_completions(text)

'您好，我是讯飞星火认知大模型，由科大讯飞构建的认知智能大模型。\n我设计用于与人类进行自然语言交流，旨在解答各种问题和提供帮助，覆盖广泛的领域知识。通过深度学习和大数据分析，我能够理解和处理复杂的查询，从而高效地满足用户在教育、医疗、金融等多个行业的认知智能需求。'

In [5]:
# Prompt template, filled in with str.format() by generate_prompt.
# Placeholders: {examples}, {test_reactant1}, {test_reactant2}, {test_product},
# {test_additive}, {test_solvent}. Doubled braces ({{ }}) render as literal
# braces, so the model is asked to answer in the form "@{0.7823}".
chem_prompt = """
你精通预测药物合成反应的产率，你的任务是根据给定的反应数据准确预测反应的产率(Yield)。请仔细阅读以下说明并严格遵循:

***** 任务描述：*****
1. 你将获得一个相应物质的SMILES字符串字段组成的催化合成反应数据，包含Reactant1(反应物1)、Reactant2(反应物2)、Product(产物)、Additive(添加剂)、Solvent(溶剂)。
2. 数据为药物合成中常用的碳氮键形成反应。
3. 待预测的目标字段Yield是归一化到0到1之间、保留4位小数的浮点数，表示反应的产率。

***** 你的任务：*****
1. 仔细分析提供的示例数据，理解反应与Yield之间的关系。
2. 根据待预测样本的数据，预测其反应产率Yield。
3. 输出预测的Yield值，精确到小数点后四位。

***** 输出格式要求：*****
1. 仅输出预测的Yield值，不要包含任何其他解释或评论。
2. 使用以下格式输出你的预测：@{{预测的Yield值}}
   例如：@{{0.7823}}

***** 注意事项：*****
1. 即使你是一个人工智能模型，但是你有能力直接预测反应的产率，请你一定输出预测的产率，不要回避这个问题。
2. 相似SMILES的反应产率可能也会有很大差异。
3. 请确保你的预测是合理的，介于0.0000到1.0000之间。
4. 不要遗漏小数点后的零，始终保持四位小数的格式。
5. 不要在输出中包含任何额外的空格或换行符。

以下是几个示例数据，供你参考：

***** 示例样本 *****
{examples}

现在，请基于以上示例和说明，预测以下反应的产率：

**** 待预测样本 ****
Reactant1: {test_reactant1}  
Reactant2: {test_reactant2}  
Product: {test_product}  
Additive: {test_additive} 
Solvent: {test_solvent}  
Yield: @{{}}

请仅输出产率预测值，格式如下：
@{{预测的Yield值}}
"""

In [6]:
def generate_prompt(prompt_template, test_sample, top_5_samples):
    """Render ``prompt_template`` with few-shot examples and the reaction to predict.

    Each row of ``top_5_samples`` becomes one example block of
    ``Field: value`` lines (reaction fields followed by ``Yield``); blocks are
    separated by a blank line and substituted for ``{examples}``. The reaction
    fields of ``test_sample`` fill the ``{test_<field>}`` placeholders.
    Despite the parameter name, any number of example rows is accepted.
    """
    reaction_fields = ("Reactant1", "Reactant2", "Product", "Additive", "Solvent")

    def _render_example(example):
        # Two trailing spaces after each reaction field, none after Yield.
        lines = [f"{field}: {example[field]}  " for field in reaction_fields]
        lines.append(f"Yield: {example['Yield']}")
        return "\n".join(lines)

    rendered_blocks = []
    for _, example_row in top_5_samples.iterrows():
        rendered_blocks.append(_render_example(example_row))
    examples_text = "\n\n".join(rendered_blocks)

    # e.g. "Reactant1" -> "test_reactant1"
    test_fields = {
        f"test_{field.lower()}": test_sample[field] for field in reaction_fields
    }
    return prompt_template.format(examples=examples_text, **test_fields)

In [7]:
# Logging setup (INFO level, "time - level - message" lines) and the backoff
# package used for retrying API calls.
import backoff
import logging

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

# Retry the LLM round-trip (request + answer parsing) with exponential backoff.
# The decorator must sit on code that actually raises: it used to decorate
# process_single_sample, which catches every exception and returns an error
# tuple, so no retry could ever happen.
@backoff.on_exception(backoff.expo,
                      Exception,
                      max_tries=5,
                      max_time=300)
def _predict_yield_with_retry(prompt):
    """Query the LLM with ``prompt`` and return the parsed yield string.

    Raises ValueError when no valid yield can be extracted from the response,
    which makes the backoff decorator retry the request.
    """
    prediction = get_completions(prompt)
    yield_value = extract_yield(prediction)
    if yield_value is None:
        raise ValueError(f"无法从预测结果中提取有效的产率值")
    return yield_value


# Process a single test sample.
def process_single_sample(args):
    """Predict the yield of one test reaction with similarity-retrieved few-shot examples.

    ``args`` is a tuple ``(test_sample, train_df, train_tfidf, tfidf)``. The 3
    training rows most similar to the test sample (TF-IDF cosine similarity on
    ``combined_features``) are used as examples in the prompt.

    Returns ``(rxnid, yield_value, None)`` on success or ``(rxnid, None, error)``
    where ``error`` is the exception message; exceptions are never propagated.
    """
    test_sample, train_df, train_tfidf, tfidf = args
    try:
        test_tfidf = tfidf.transform([test_sample['combined_features']])
        similarities = cosine_similarity(test_tfidf, train_tfidf).flatten()
        # Positions of the 3 most similar training reactions, most similar first.
        top_k_indices = similarities.argsort()[-3:][::-1]
        top_k_samples = train_df.iloc[top_k_indices]
        prompt = generate_prompt(chem_prompt, test_sample, top_k_samples)
        yield_value = _predict_yield_with_retry(prompt)
        return test_sample['rxnid'], yield_value, None
    except Exception as e:
        # Report the failure to the caller instead of aborting the whole batch.
        return test_sample['rxnid'], None, str(e)

# Matches the "@{value}" answer marker requested by the prompt. It also tolerates
# doubled braces ("@{{0.5}}", the raw template spelling) and surrounding spaces.
_YIELD_MARKER_RE = re.compile(r"@\{+\s*([^{}]*?)\s*\}+")


def extract_yield(prediction):
    """Extract the predicted yield from an LLM response.

    Every ``@{...}`` marker in ``prediction`` is examined in order, and the first
    one holding a number in [0, 1] wins. This way a response that echoes the
    template (e.g. ``@{预测的Yield值}``) before giving its answer is still
    accepted.

    Returns the yield formatted with four decimals (e.g. ``"0.7823"``), or None
    when no valid value is found (a warning is logged for each rejected marker).
    """
    candidates = _YIELD_MARKER_RE.findall(prediction)
    if not candidates:
        # The full response could be added here for debugging.
        logger.warning("无法从预测结果中提取产率值。")
        return None

    for yield_value in candidates:
        try:
            float_yield = float(yield_value)
        except ValueError:
            logger.warning(f"无法将提取的值 '{yield_value}' 转换为浮点数")
            continue
        # NaN fails this comparison and is rejected as out of range.
        if 0 <= float_yield <= 1:
            return f"{float_yield:.4f}"
        logger.warning(f"提取的产率值 {float_yield} 不在有效范围内")
    return None

# Process samples in parallel.
def process_samples_parallel(test_df, train_df, train_tfidf, tfidf, max_workers=None, batch_size=100):
    """Predict yields for all rows of ``test_df`` with a thread pool, batch by batch.

    Each row is handled by ``process_single_sample``. A fresh pool is created per
    batch of ``batch_size`` rows, so a batch finishes before the next one starts.

    Returns ``(results, error_indices)``: ``results`` maps rxnid -> predicted
    yield for successful samples; ``error_indices`` lists the rxnids that failed.

    Raises ValueError if ``batch_size`` is not a positive integer.
    """
    if batch_size < 1:
        # range(..., step=0) raised an obscure error and a negative step silently
        # produced no batches at all.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    results = {}
    error_indices = []
    total_samples = len(test_df)

    logger.info(f"开始并行处理 {total_samples} 个测试样本")

    # Default to 2 workers: per the original note, the API supports at most 2
    # concurrent requests.
    if max_workers is None:
        max_workers = 2

    # Split the rows into positional batches.
    batches = [test_df.iloc[start:start + batch_size]
               for start in range(0, total_samples, batch_size)]

    with tqdm(total=total_samples, desc="处理测试样本", unit="sample") as pbar:
        for batch in batches:
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Keep each future's rxnid so an unexpected exception can still be
                # attributed to its sample instead of aborting the whole run.
                future_to_rxnid = {
                    executor.submit(process_single_sample, (row, train_df, train_tfidf, tfidf)): row['rxnid']
                    for _, row in batch.iterrows()
                }

                for future in concurrent.futures.as_completed(future_to_rxnid):
                    try:
                        rxnid, yield_value, error = future.result()
                    except Exception as exc:
                        rxnid, yield_value, error = future_to_rxnid[future], None, str(exc)

                    if error:
                        logger.error(f"处理样本 {rxnid} 时出错: {error}")
                        error_indices.append(rxnid)
                    else:
                        results[rxnid] = yield_value
                    pbar.update(1)

                    # Update the estimated time to completion shown in the bar.
                    rate = pbar.format_dict['rate']
                    if rate and rate > 0:
                        remaining_time = (total_samples - pbar.n) / rate
                        eta = time.strftime("%H:%M:%S", time.gmtime(remaining_time))
                        pbar.set_postfix({'ETA': eta}, refresh=True)

    return results, error_indices

In [8]:
logger.info("开始读取数据...")
train_df = pd.read_csv('round1_train_data.csv')

# Mean training-set yield, computed while Yield is still numeric.
train_yield_mean = train_df['Yield'].mean()
logger.info(f"训练集产率平均值: {train_yield_mean:.4f}")

# Store training yields as 4-decimal strings so the prompt examples show them in
# the same "0.7800" format the model is asked to answer in.
train_df['Yield'] = train_df['Yield'].apply(lambda x: f"{float(x):.4f}")
test_df = pd.read_csv('round1_test_data.csv')
logger.info("数据读取完成")


# Preview the first rows of the training set.
print("训练集数据预览：")
print(train_df.head())

# Preview the first rows of the test set.
print("\n测试集数据预览：")
print(test_df.head())

# DataFrame.info() prints its report itself and returns None; wrapping it in
# print() only appended a stray "None" line.
print("\n训练集基本信息：")
train_df.info()

print("\n测试集基本信息：")
test_df.info()

2024-07-24 15:02:22,815 - INFO - 开始读取数据...
2024-07-24 15:02:22,879 - INFO - 训练集产率平均值: 0.6299
2024-07-24 15:02:22,901 - INFO - 数据读取完成


训练集数据预览：
    rxnid               Reactant1        Reactant2  \
0  train1  c1ccc2c(c1)Nc1ccccc1O2      Brc1ccccc1I   
1  train2  c1ccc2c(c1)Nc1ccccc1O2      Brc1ccccc1I   
2  train3  c1ccc2c(c1)Nc1ccccc1O2      Brc1ccccc1I   
3  train4                C1COCCN1  Fc1cnc(Cl)nc1Cl   
4  train5                C1COCCN1  Fc1cnc(Cl)nc1Cl   

                          Product  \
0  Brc1ccccc1N1c2ccccc2Oc2ccccc21   
1  Brc1ccccc1N1c2ccccc2Oc2ccccc21   
2  Brc1ccccc1N1c2ccccc2Oc2ccccc21   
3           Fc1cnc(Cl)nc1N1CCOCC1   
4           Fc1cnc(Cl)nc1N1CCOCC1   

                                            Additive         Solvent   Yield  
0  CC(C)(C)[O-].CC(C)(C)[PH+](C(C)(C)C)C(C)(C)C.F...       Cc1ccccc1  0.7800  
1  C1COCCOCCOCCOCCOCCO1.O=C([O-])[O-].[Cu+].[I-]....    Clc1ccccc1Cl  0.9000  
2  CC(=O)[O-].CC(=O)[O-].CC(C)(C)[O-].CC(C)(C)[PH...  CC1(C)C=CC=CC1  0.8464  
3                                    CCN(C(C)C)C(C)C             CCO  0.9500  
4                                    CCN(C(C)C)C

In [9]:
logger.info("开始特征提取...")

# TF-IDF vectorizer over the concatenated reaction strings, default settings.
# NOTE(review): with library defaults the text is lower-cased and split on
# word characters, which may merge aromatic/aliphatic SMILES atoms ('c' vs 'C')
# — confirm this is acceptable for the similarity search.
tfidf = TfidfVectorizer()


def combine_features(row):
    """Join the five reaction fields of ``row`` into one space-separated string."""
    feature_columns = ('Reactant1', 'Reactant2', 'Product', 'Additive', 'Solvent')
    return ' '.join(str(row[column]) for column in feature_columns)


# Build the text feature for both splits; the vectorizer is fitted on train only.
train_df['combined_features'] = train_df.apply(combine_features, axis=1)
test_df['combined_features'] = test_df.apply(combine_features, axis=1)
train_tfidf = tfidf.fit_transform(train_df['combined_features'])

logger.info("特征提取完成")

2024-07-24 15:02:22,942 - INFO - 开始特征提取...
2024-07-24 15:02:23,683 - INFO - 特征提取完成


In [11]:
logger.info("开始并行处理测试样本...")
# LLM predictions for every test sample: rxnid -> yield, plus the failed rxnids.
results, error_indices = process_samples_parallel(
    test_df=test_df,
    train_df=train_df,
    train_tfidf=train_tfidf,
    tfidf=tfidf,
)

2024-07-24 15:02:31,518 - INFO - 开始并行处理测试样本...
2024-07-24 15:02:31,519 - INFO - 开始并行处理 20 个测试样本
处理测试样本:   0%|          | 0/20 [00:00<?, ?sample/s]2024-07-24 15:02:31,629 - INFO - Websocket connected
2024-07-24 15:02:31,643 - INFO - Websocket connected
处理测试样本:   5%|▌         | 1/20 [00:02<00:42,  2.23s/sample, ETA=00:00:42]2024-07-24 15:02:33,848 - INFO - Websocket connected
处理测试样本:  10%|█         | 2/20 [00:02<00:19,  1.11s/sample, ETA=00:00:19]2024-07-24 15:02:34,159 - INFO - Websocket connected
处理测试样本:  15%|█▌        | 3/20 [00:04<00:25,  1.52s/sample, ETA=00:00:25]2024-07-24 15:02:36,258 - INFO - Websocket connected
处理测试样本:  20%|██        | 4/20 [00:05<00:20,  1.28s/sample, ETA=00:00:20]2024-07-24 15:02:37,088 - INFO - Websocket connected
处理测试样本:  25%|██▌       | 5/20 [00:06<00:20,  1.35s/sample, ETA=00:00:20]2024-07-24 15:02:38,558 - INFO - Websocket connected
处理测试样本:  30%|███       | 6/20 [00:07<00:17,  1.22s/sample, ETA=00:00:17]2024-07-24 15:02:39,523 - INFO - Websocket connecte

In [12]:
# Fallback predictions used to fill samples the LLM could not predict: for each
# test reaction, the similarity-weighted mean yield of its 3 most similar
# training reactions (TF-IDF cosine similarity). default_yield is aligned with
# the row order of test_df.

default_yield = []

for index, test_sample in tqdm(test_df.iterrows(), total=len(test_df), desc="处理测试样本"):
    # TF-IDF vector of this test reaction and its similarity to every training row.
    test_tfidf = tfidf.transform([test_sample['combined_features']])
    similarities = cosine_similarity(test_tfidf, train_tfidf).flatten()

    # Positions of the 3 most similar training reactions, most similar first.
    top_k_indices = similarities.argsort()[-3:][::-1]
    top_k_yields = train_df.iloc[top_k_indices]['Yield'].astype(float).values

    # Similarities of those neighbours, used as weights.
    weights = similarities[top_k_indices]
    weight_sum = weights.sum()
    if weight_sum > 0:
        weights = weights / weight_sum
        weighted_yield = np.dot(top_k_yields, weights)
    else:
        # No TF-IDF term shared with any training reaction: every similarity is 0,
        # normalising would divide by zero (NaN in the submission), and the
        # "neighbours" are arbitrary. Fall back to the training-set mean yield.
        weighted_yield = train_yield_mean

    default_yield.append(weighted_yield)

处理测试样本: 100%|██████████| 2616/2616 [00:23<00:00, 110.40it/s]


In [13]:
# Preview the first five fallback yields (shown as the cell output).
default_yield[:5]

[0.792216186596149,
 0.7648363827925556,
 0.779355390246026,
 0.5899540884916881,
 0.5184287119948298]

In [14]:
# Retry the samples whose LLM prediction failed; any that still fail get their
# similarity-weighted fallback yield from default_yield.
if error_indices:
    logger.info(f"有 {len(error_indices)} 个样本处理出错，正在重新处理...")
    error_df = test_df[test_df['rxnid'].isin(error_indices)]
    retry_results, retry_error_indices = process_samples_parallel(error_df, train_df, train_tfidf, tfidf)
    results.update(retry_results)

    if retry_error_indices:
        logger.warning(f"仍有 {len(retry_error_indices)} 个样本处理失败")
        # default_yield was built by iterating test_df row by row, so pair it with
        # test_df['rxnid'] by position. The old int(rxnid[4:]) - 1 lookup assumed
        # ids of the form "test<N>" listed in file order.
        fallback_by_rxnid = dict(zip(test_df['rxnid'], default_yield))
        for rxnid in retry_error_indices:
            # Samples that could never be processed get the fallback value.
            results[rxnid] = fallback_by_rxnid[rxnid]

In [15]:
logger.info("开始写入结果...")
# default_yield follows the row order of test_df, so pair it with the rxnids by
# position rather than parsing the number out of each id string.
fallback_by_rxnid = dict(zip(test_df['rxnid'], default_yield))
with open('submit001.txt', 'w') as f:
    f.write('rxnid,Yield\n')
    for rxnid in test_df['rxnid']:
        # LLM prediction when available, otherwise the similarity-based fallback.
        raw_value = results.get(rxnid)
        if raw_value is None:
            raw_value = fallback_by_rxnid[rxnid]
        yield_value = float(raw_value)
        f.write(f"{rxnid},{yield_value:.4f}\n")
logger.info("结果已保存到submit001.txt文件中")

2024-07-24 15:03:36,504 - INFO - 开始写入结果...
2024-07-24 15:03:36,515 - INFO - 结果已保存到submit001.txt文件中


In [None]:
# 试一试：直接使用相似K个样本的均值作为预测结果

# with open('submit0.txt', 'w') as f:
#     f.write('rxnid,Yield\n')
#     total_samples = len(test_df)
#     for index, test_sample in tqdm(test_df.iterrows(), total=len(test_df), desc="处理测试样本"):
#         # 对测试样本进行TF-IDF转换
#         test_tfidf = tfidf.transform([test_sample['combined_features']])
        
#         # 计算与训练集的相似度
#         similarities = cosine_similarity(test_tfidf, train_tfidf).flatten()
        
#         # 获取最相似的5个样本的索引
#         top_5_indices = similarities.argsort()[-5:][::-1]
#         # print(top_5_indices)
#         # 获取最相似的5个样本的产率
#         top_5_yields = train_df.iloc[top_5_indices]['Yield'].astype(float).values[0]
#         # print(top_5_yields)
#         # # 获取相似度权重
#         weights = similarities[top_5_indices]
#         # # 对权重进行归一化
#         weights = weights / weights.sum()
        
#         # # 计算这五个samples的产率加权平均
#         weighted_yield = np.dot(top_5_yields, weights) 

#         f.write(f"test{index+1},{weighted_yield:.4f}\n")

处理测试样本:   0%|          | 6/2616 [00:00<00:45, 57.24it/s]

处理测试样本:   3%|▎         | 71/2616 [00:01<00:44, 56.90it/s]

[3467 3468 3476 3477]


处理测试样本:   4%|▍         | 99/2616 [00:01<00:38, 65.35it/s]

[820 822 823]


处理测试样本:   5%|▌         | 131/2616 [00:02<00:33, 73.10it/s]

[1572 8131 8132]


处理测试样本:  11%|█         | 283/2616 [00:04<00:33, 68.96it/s]

[ 574 8236 8565]
[5652 5653 5657 5658 5659 5660]


处理测试样本:  17%|█▋        | 451/2616 [00:06<00:29, 73.92it/s]

[  796 21313 21314]


处理测试样本:  19%|█▉        | 491/2616 [00:07<00:28, 74.67it/s]

[  216 10202 20999 21003]
[ 2297 13639 13640]


处理测试样本:  22%|██▏       | 564/2616 [00:07<00:26, 76.02it/s]

[2767 2768 2772 2773 2774 2775 2777]


处理测试样本:  24%|██▍       | 637/2616 [00:08<00:26, 75.78it/s]

[4676 4677 4678 5275 5276 5277 5278]


处理测试样本:  26%|██▋       | 688/2616 [00:09<00:24, 77.66it/s]

[11059 11768 11769]
[ 3307 16091 16099]


处理测试样本:  27%|██▋       | 712/2616 [00:09<00:25, 74.01it/s]

[18266 18268 18269 18270 18271 18272 18273 18274 18275 18276 18277 18278
 18279]


处理测试样本:  35%|███▌      | 926/2616 [00:12<00:21, 77.98it/s]

[19427 19428 19444 19446]


处理测试样本:  40%|████      | 1050/2616 [00:14<00:19, 78.58it/s]

[13359 13360 13361 13414 13415 13416]


处理测试样本:  42%|████▏     | 1091/2616 [00:14<00:19, 77.69it/s]

[16393 16395 16577]


处理测试样本:  43%|████▎     | 1115/2616 [00:15<00:20, 73.90it/s]

[14230 14232 14796]


处理测试样本:  52%|█████▏    | 1360/2616 [00:18<00:17, 72.51it/s]

[  216  8109 10202 20999 21000 21003]
[ 3641  3642  3852  3854  3855 16381 17063 17072 17073 17082 17083 17084]


处理测试样本:  53%|█████▎    | 1384/2616 [00:19<00:17, 71.44it/s]

[19991 19995 19996 20010 20011]


处理测试样本:  54%|█████▎    | 1400/2616 [00:19<00:16, 72.23it/s]

[19426 19429 19443 19445]


处理测试样本:  58%|█████▊    | 1528/2616 [00:21<00:14, 73.49it/s]

[ 9223  9224 17780 20113 20114 20115]


处理测试样本:  66%|██████▌   | 1730/2616 [00:23<00:11, 75.43it/s]

[13537 14949 14951 16431]


处理测试样本:  71%|███████▏  | 1865/2616 [00:25<00:10, 74.12it/s]

[11708 11709 11710]
[ 2110 13190 13192]


处理测试样本:  73%|███████▎  | 1897/2616 [00:26<00:09, 72.68it/s]

[18266 18268 18269 18270 18271 18272 18273 18274 18275 18276 18277 18278
 18279]


处理测试样本:  77%|███████▋  | 2003/2616 [00:27<00:07, 77.28it/s]

[8180 9119 9120 9121]


处理测试样本:  79%|███████▉  | 2075/2616 [00:28<00:07, 75.32it/s]

[8368 8369 8370 8371]


处理测试样本:  82%|████████▏ | 2140/2616 [00:29<00:06, 76.14it/s]

[7362 7363 7364]
[ 2177  2289 12985 13383 13384 13386]


处理测试样本:  85%|████████▍ | 2212/2616 [00:30<00:05, 74.01it/s]

[13537 13538 14948 14949 14951 14960 14962]
[ 5982  5983 11664]


处理测试样本:  86%|████████▌ | 2252/2616 [00:30<00:04, 75.55it/s]

[5967 5969 8737]


处理测试样本:  89%|████████▉ | 2341/2616 [00:32<00:03, 73.06it/s]

[23528 23529 23530]


处理测试样本:  90%|█████████ | 2357/2616 [00:32<00:03, 71.05it/s]

[ 2110 13190 13192]
[22596 22598 22600 22602 22605 22606 22608]


处理测试样本:  94%|█████████▍| 2455/2616 [00:33<00:02, 67.78it/s]

[22603 22604 22607]


处理测试样本:  95%|█████████▌| 2486/2616 [00:34<00:01, 71.01it/s]

[  216  8109 10202 20999 21000 21003]
[ 6703  6704 10089 21717]


处理测试样本:  98%|█████████▊| 2558/2616 [00:35<00:00, 73.27it/s]

[21837 22277 22278]


处理测试样本: 100%|██████████| 2616/2616 [00:35<00:00, 72.72it/s]
