In [None]:
#load tools
from tools import GetDOIsTool, GetDOIsCleanTool, GetOriginaltextTool, \
    GetExtracttextTool, StreamModeTool, CleanDBTool, FileRagTool, QueryTool,\
    DataextTool, DatacountTool, FeaextTool, ModelSelTool, ShapplotTool

# (tool class, description) pairs for every tool exposed to the agent. The
# description string is passed to each tool's constructor; create_react_agent
# later hands these tools (and their descriptions) to the LLM.
_TOOL_DESCRIPTIONS = (
    (GetDOIsTool, "Quickly get the doi of the literature."),
    (GetDOIsCleanTool, "Quickly clean the doi of the literature."),
    (GetOriginaltextTool, "Get the original text of the literature."),
    (GetExtracttextTool, "Get the extracted text of the literature."),
    (StreamModeTool, "create the database."),
    (CleanDBTool, "Clean the database."),
    (FileRagTool, "File quiz tool. analysis file"),
    (QueryTool, "File quiz tool. Use when you need to quiz based on uploaded files, no need to upload files again for subsequent queries."),
    (DataextTool, "Extract the jsonl to csv."),
    (DatacountTool, "Count the data."),
    (FeaextTool, "Extract the feature."),
    (ModelSelTool, "Select the model."),
    (ShapplotTool, "Plot the shap."),
)

# Instantiate the tools in the order listed above.
tools = [tool_cls(description=text) for tool_cls, text in _TOOL_DESCRIPTIONS]

# print(f"construct {len(tools)} tools:")
# for tool in tools:
#     print(f"- {tool.name}: {tool.description}")
#     if tool.args_schema:
#         print(f" params: {tool.args_schema.model_json_schema()}")

In [None]:
# Define the agent's building blocks: an in-memory checkpointer, the chat
# model (qwen-plus through DashScope's OpenAI-compatible endpoint), and the
# system prompt text used by the agent cell.
import getpass
import os
import warnings

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent

# NOTE(review): this silences *every* warning for the rest of the kernel
# session (including library deprecation notices); consider narrowing it.
warnings.filterwarnings("ignore")

# In-memory checkpointer: conversation history is kept only while the kernel lives.
memory = MemorySaver()

# Never hard-code credentials in a notebook. Read the key from the environment
# and fall back to an interactive prompt. Any key that was previously committed
# in this cell must be treated as leaked and rotated.
api_key = os.environ.get("DASHSCOPE_API_KEY") or getpass.getpass("DashScope API key: ")

model = ChatOpenAI(
    openai_api_key=api_key,
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    model="qwen-plus",
)

# System prompt for the agent. The encoding is explicit so decoding does not
# depend on the OS locale. Assumes the prompt file is saved as UTF-8 — TODO confirm.
with open("prompts/calling_prompts_gradio.txt", "r", encoding="utf-8") as f:
    prompts = f.read()

In [None]:
# 请帮我查找在2020年到2025年发表的有关二氧化钛纳米复合材料相关文献并返回doi号
# 请帮我查找在2019年电解水相关的文献并返回doi号
# 请帮我查找2025年发表的有关甲烷催化重整的论文，并返回论文的doi
# 请帮我查找在2024年发表的有关mof合成相关文献并返回doi号
# 请帮我查找在2021年到2022年发表的有关高效锂电池相关文献并返回doi号

#请帮我查找在2000年到2025年发表的有关ZrO2纳米粒子液相合成相关文献并返回doi号，改变“液相合成”关键词进行5次查询
#请帮我查找在2000年到2025年发表的有关CeO2纳米粒子液相合成相关文献并返回doi号，改变“液相合成”关键词进行5次查询

In [None]:
# please extract the text from original text

In [None]:
# Build the ReAct agent (LLM + tools + system prompt + checkpointer) and run an
# interactive chat loop. Every turn uses the same thread_id, so the checkpointer
# carries the conversation history from one question to the next.
agent_executor = create_react_agent(model, tools, prompt=prompts, checkpointer=memory)

config = {"configurable": {"thread_id": "abc123"}}

while True:
    try:
        user_input = input("please enter your problem (type 'exit' to end the conversation): ")
    except (EOFError, KeyboardInterrupt):
        # stdin closed or the kernel was interrupted: end the conversation cleanly.
        break
    query = user_input.strip()
    if query.lower() == "exit":
        break
    if not query:
        # Don't send an empty message to the model; just ask again.
        continue

    # stream_mode="updates" yields one {node_name: state_update} dict per graph
    # step. It is pinned explicitly so the chunk shape does not depend on defaults.
    for chunk in agent_executor.stream(
        {"messages": [HumanMessage(content=user_input)]}, config, stream_mode="updates"
    ):
        for node_name, update in chunk.items():
            if not isinstance(update, dict):
                # Non-message updates (e.g. interrupts): show them raw.
                print(f"{node_name}: {update}")
                continue
            for message in update.get("messages", []):
                message.pretty_print()
        print("----")


In [None]:
# Run the project's feature-extraction step (data_analysis.FeatureExtractor,
# defined outside this notebook). Its inputs and outputs are not visible here.
# NOTE(review): presumably it writes the outputs/X.csv that the SHAP cell
# loads — TODO confirm.
from data_analysis import FeatureExtractor
fea = FeatureExtractor()
fea.exe()

In [None]:
# Run the project's model-selection step (model_sel.ModelSelector, defined
# outside this notebook).
# NOTE(review): the SHAP cell reads outputs/saved_models/model_res.csv and
# <model>_best_model.joblib; presumably this step produces them — confirm.
from model_sel import ModelSelector
modelsel = ModelSelector()
modelsel.exe()

In [None]:
# Imports for the SHAP analysis: standard library first, then third-party.
import os

import joblib
import matplotlib.pyplot as plt
import pandas as pd
import shap

# Load SHAP's JavaScript visualization support into the notebook front end.
shap.initjs()

In [None]:
# Explain the best model from model selection with SHAP and save a styled
# summary (beeswarm) plot to outputs/shap.png.
model_res_dir = r"outputs/saved_models"
input_dir = r"outputs"
SHAP_BACKGROUND_SIZE = 5  # rows passed to shap.sample() as the KernelExplainer background
SHAP_MAX_DISPLAY = 8      # number of features shown in the summary plot


def _bold_black(labels):
    """Set every matplotlib Text object in `labels` to bold, black text."""
    for text in labels:
        text.set_fontweight('bold')
        text.set_color('black')


# Pick the model with the highest best_r2. idxmax() skips NaN scores. The old
# `== max(...)` mask used builtin max, which returns NaN when the first score is
# NaN; the mask then matched nothing and `.values[0]` raised IndexError.
df = pd.read_csv(os.path.join(model_res_dir, "model_res.csv"))
model_name = df.loc[df['best_r2'].idxmax(), 'model']
X = pd.read_csv(os.path.join(input_dir, "X.csv"), index_col=0)

# Named best_model (not `model`) so this cell does not overwrite the ChatOpenAI
# `model` that the agent cell passes to create_react_agent.
# joblib.load unpickles the file, so only load model files this project wrote.
best_model = joblib.load(os.path.join(model_res_dir, f"{model_name}_best_model.joblib"))

# KernelExplainer only needs a predict function and a small background sample.
explainer = shap.KernelExplainer(best_model.predict, shap.sample(X, SHAP_BACKGROUND_SIZE))
shap_values = explainer.shap_values(X)

# Hide columns whose names mention "name" or "unit" from the plot. SHAP values
# are still computed for them; only the display is filtered.
excluded_features = [col for col in X.columns if 'name' in col.lower() or 'unit' in col.lower()]
display_features = [col for col in X.columns if col not in excluded_features]
# Positional indices via enumerate. Index.get_loc returns a slice/mask instead
# of an int when a column name is duplicated.
display_feature_indices = [i for i, col in enumerate(X.columns) if col not in excluded_features]
X_display = X.iloc[:, display_feature_indices]
shap_values_display = shap_values[:, display_feature_indices]
shap.summary_plot(shap_values_display, X_display, show=False, max_display=SHAP_MAX_DISPLAY)

# summary_plot(show=False) leaves the figure current; restyle its text in place.
fig = plt.gcf()
ax = plt.gca()
ax.set_xlabel(ax.get_xlabel(), color='black', fontdict={'weight': 'bold', 'size': 12})
_bold_black(ax.get_xticklabels())
_bold_black(ax.get_yticklabels())

# If a second axes was added (the color bar), it is the figure's last axes.
if len(fig.axes) > 1 and fig.axes[-1] is not ax:
    cbar_ax = fig.axes[-1]
    if cbar_ax.get_ylabel():
        _bold_black([cbar_ax.yaxis.label])
    _bold_black(cbar_ax.get_xticklabels() + cbar_ax.get_yticklabels())

save_path = os.path.join(input_dir, "shap.png")
fig.savefig(save_path, format='png', dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)