In [None]:
!pip install ratelimit gseapy rdkit pubchempy

In [None]:
from agent import TCMAgent
from tools.llm_api import *
from tools.json_tool import *
import time
import json
from typing import Dict, List, Optional
from tqdm import tqdm
from tools.pubmed_api import PubMedFetcher
import os
import pandas as pd

def fetch_all_pages(
    fetcher,
    query: str,
    max_pages: int = 5,
    results_per_page: int = 10,
    sleep_time: float = 3,
    output_file: Optional[str] = None
) -> Dict:
    """
    Fetch several pages of PubMed search results and merge them into one dict.

    The first page is requested up front to learn the total hit count; the
    remaining pages (up to ``max_pages``) are then fetched one by one with a
    pause between requests. A page that fails is reported and skipped.

    Args:
        fetcher: object with ``search(query=..., max_results=..., start=...)``
            returning a dict with ``"papers"`` (list) and
            ``"metadata"["total_results"]`` (int), e.g. a PubMedFetcher.
        query: search query string.
        max_pages: maximum number of pages to fetch.
        results_per_page: number of results requested per page.
        sleep_time: pause between page requests, in seconds.
        output_file: optional path; when given, the merged result is written
            there as UTF-8 JSON.

    Returns:
        ``{"papers": [...], "metadata": {...}}``. ``metadata["pages_retrieved"]``
        counts the pages that were actually fetched successfully. If the first
        request fails, ``{"papers": [], "metadata": {}}`` is returned; if a
        later step fails, whatever was collected so far is returned.
    """
    all_results = {
        "papers": [],
        "metadata": {}
    }

    try:
        # Fetch the first page to learn the total number of hits.
        first_page = fetcher.search(
            query=query,
            max_results=results_per_page,
            start=0
        )

        total_results = first_page["metadata"]["total_results"]
        # Ceiling division, capped at max_pages.
        total_pages = min(max_pages, (total_results + results_per_page - 1) // results_per_page)

        print(f"找到 {total_results} 篇文章，将获取 {total_pages} 页")

        all_results["papers"].extend(first_page["papers"])
        metadata = {
            "total_results": total_results,
            # Pages actually obtained: the first page counts when at least one
            # page was planned; later pages are added only on success.
            "pages_retrieved": min(1, total_pages),
            "results_per_page": results_per_page,
            "query": query,
            "total_papers_retrieved": len(first_page["papers"])
        }
        all_results["metadata"] = metadata

        # Fetch the remaining pages.
        if total_pages > 1:
            with tqdm(range(1, total_pages), desc="获取页面") as pbar:
                for page in pbar:
                    time.sleep(sleep_time)  # throttle successive requests

                    try:
                        result = fetcher.search(
                            query=query,
                            max_results=results_per_page,
                            start=page * results_per_page
                        )
                        all_results["papers"].extend(result["papers"])
                    except Exception as e:
                        # Best effort: report the failed page and keep going.
                        # tqdm.write keeps the progress bar intact.
                        tqdm.write(f"获取第 {page + 1} 页时出错: {e}")
                        continue

                    metadata["pages_retrieved"] += 1
                    metadata["total_papers_retrieved"] = len(all_results["papers"])

                    pbar.set_postfix({
                        "已获取文章": len(all_results["papers"]),
                        "当前页文章数": len(result["papers"])
                    })

        # Save the merged results when an output file was requested.
        if output_file:
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(all_results, f, ensure_ascii=False, indent=2)
            print(f"\n结果已保存至: {output_file}")

        return all_results

    except Exception as e:
        # Deliberate best effort: report and return what was collected.
        print(f"获取过程中出错: {str(e)}")
        return all_results

def format_results(paper):
    """
    Flatten one parsed PubMed paper record into a single-row DataFrame.

    Args:
        paper: dict with ``title``, ``pmid``, ``journal["title"]``, ``authors``
            (list of dicts with ``fore_name``, ``last_name``, ``affiliations``),
            ``abstract`` (``structured`` flag plus ``sections`` dict or
            ``complete`` text), ``keywords``, ``urls`` (url type -> url) and
            ``metadata`` (dict; may be empty or None).

    Returns:
        pandas.DataFrame with one row. Columns, in order: title, pmid, journal,
        affiliations, authors, one column per abstract section (or
        ``abstract``), keywords, one column per non-empty URL, metadata,
        is_open_access, fetch_time.

        ``authors`` is "First Last" names joined by ", ". ``affiliations`` is
        the first affiliation of each author (de-duplicated, in author order)
        joined by "; ", or None when no author has one. ``is_open_access`` and
        ``fetch_time`` are None when ``metadata`` is missing or empty.
    """
    names = []
    affiliations = []
    for author in paper.get('authors') or []:
        # Skip missing name parts instead of rendering "None".
        name_parts = (author.get('fore_name'), author.get('last_name'))
        names.append(" ".join(part for part in name_parts if part))
        author_affiliations = author.get('affiliations') or []
        if author_affiliations and author_affiliations[0] not in affiliations:
            affiliations.append(author_affiliations[0])

    # Build the row as an ordered dict; insertion order fixes column order.
    row = {
        'title': paper['title'],
        'pmid': paper['pmid'],
        'journal': paper['journal']['title'],
        'affiliations': "; ".join(affiliations) if affiliations else None,
        'authors': ", ".join(names),
    }

    abstract = paper['abstract']
    if abstract['structured']:
        # One column per labelled abstract section.
        for section, text in abstract['sections'].items():
            row[section] = text
    else:
        row['abstract'] = abstract['complete']

    keywords = paper.get('keywords')
    row['keywords'] = ", ".join(keywords) if keywords else None

    # Only URLs that are actually present become columns.
    for url_type, url in (paper.get('urls') or {}).items():
        if url:
            row[url_type] = url

    metadata = paper.get('metadata') or {}
    row['metadata'] = metadata or None
    row['is_open_access'] = metadata.get('is_open_access')
    row['fetch_time'] = metadata.get('fetch_time')

    return pd.DataFrame([row])

# Assemble the latest question/answer from the LLM chat history into one text,
# used as input for keyword extraction.
def combine_context(conversations=None):
    """
    Build the keyword-extraction context from a chat history.

    Args:
        conversations: list of message dicts with a ``"content"`` key.
            Defaults to the global ``agent.conversations``.

    Returns:
        str containing ``conversations[-4]`` as the human question and
        ``conversations[-1]`` as the LLM answer.

    Raises:
        IndexError: if the history has fewer than 4 messages.
    """
    if conversations is None:
        conversations = agent.conversations
    # NOTE(review): index -4 assumes the question sits three messages before
    # the final answer (presumably intermediate tool/agent turns in between);
    # confirm against TCMAgent.work_flow.
    llm_ans = conversations[-1]["content"]
    human_question = conversations[-4]["content"]
    content = f"""
    human_question：
    {human_question}
    ---
    llm answer:
    {llm_ans}
    """
    return content

def get_keywords(content, model="deepseek-chat", temperature=0.95):
    """
    Ask an LLM to extract English search keywords from chat content.

    The model's answer is streamed to stdout as it arrives, then parsed as
    JSON with ``get_json``.

    Args:
        content: chat-history text (e.g. from ``combine_context``).
        model: model name passed to ``get_llm_answer``.
        temperature: sampling temperature passed to ``get_llm_answer``.

    Returns:
        The ``"keywords"`` value of the parsed JSON answer. The prompt asks for
        at most 4 English keywords, but the count is not enforced here.
    """
    PROMPT = """
# 你的任务
基于我下面提供的聊天记录，提取与主题有关的关键词

# 关键词要求
- 关键词必须是英文
- 关键词控制在4个以内
- 涉及到药物的必须要有相关药物的关键词，比如：中药复方的名称、疾病、关键成分、关键靶点
- 你的输出必须是JSON，key必须是"keywords"

# 示例
用户问题：冠心宁注射液（丹参、川芎）治疗心力衰竭（Heart Failure, HF）的系统药理学分析，包括关键成分、关键靶点和KEGG、WikiPathways通路，系统解读分析药理机制结果
你的输出：
```json
{{
"keywords": ["Guanxinning Injection", "Heart Failure", "Tanshinone Compounds (e.g., Tanshinone IIA)"]
}}
```

# 内容
{content}

现在，请输出关键词JSON：

"""

    prompt = PROMPT.format(content=content)
    chunks = []
    for char in get_llm_answer(prompt, model, temperature=temperature):
        print(char, end="", flush=True)  # stream the answer as it arrives
        chunks.append(char)
    print()
    keywords = get_json("".join(chunks))["keywords"]
    return keywords

# Fetch PubMed search results for each keyword and render them as text.
def get_search_results(keywords, max_pages=1, results_per_page=5):
    """
    Search PubMed for each keyword and list the found papers as plain text.

    Args:
        keywords: iterable of query strings (e.g. from ``get_keywords``).
        max_pages: pages fetched per keyword.
        results_per_page: results requested per page.

    Returns:
        str with one ``title：...``/``pubmed_url：...``/``===`` entry per paper.
        A paper returned for several keywords (same PMID) is listed once.
        Empty string when ``keywords`` is empty.
    """
    keywords = list(keywords)
    if not keywords:
        return ""

    # One client shared by all queries instead of a new one per keyword.
    fetcher = PubMedFetcher(api_key=os.getenv("PUBMED_API_KEY"))

    search_results_all = []
    seen_pmids = set()
    for i, query in enumerate(keywords):
        print(f"正在获取第{i+1}个关键词的检索结果：{query}")
        # NOTE(review): the same output file is rewritten for every keyword,
        # so it only holds the last query's results.
        results = fetch_all_pages(
            fetcher,
            query=query,
            max_pages=max_pages,
            results_per_page=results_per_page,
            sleep_time=3,
            output_file="pubmed_results.json"
        )
        for paper in results["papers"]:
            pmid = paper.get("pmid")
            if pmid is not None:
                if pmid in seen_pmids:
                    continue  # already listed via an earlier keyword
                seen_pmids.add(pmid)
            search_results_all.append(paper)

    # Title and PubMed link of every collected paper.
    entries = [
        f"title：{paper['title']}\npubmed_url：{paper['urls']['pubmed']}\n===\n"
        for paper in search_results_all
    ]
    return "".join(entries)


# Example questions. Only one is used; the alternative is kept as a comment
# rather than as an assignment that is immediately overwritten.
# question = "伊马替尼(Imatinib)与BCR-ABL激酶的结合活性是多少？"
question = "冠心宁注射液（丹参、川芎）治疗心力衰竭（Heart Failure, HF）的系统药理学分析，包括关键成分、关键靶点和KEGG、WikiPathways通路，系统解读分析药理机制结果"

# Models handed to TCMAgent as its main, tool and flash models.
main_model = "deepseek-chat"
tool_model = "deepseek-chat"
flash_model = "deepseek-chat"

agent = TCMAgent(main_model, tool_model, flash_model)

# Ask the agent which tools the question calls for, then stream its workflow.
intention_tools = agent.get_conversation_intention_tools(question)
print("intention_tools:", intention_tools)

for char in agent.work_flow(question, intention_tools):
    print(char, end="", flush=True)



In [None]:

# Build the context from the global agent's conversation history, extract
# keywords from it, then search PubMed with them and print titles/links.
content = combine_context()
keywords = get_keywords(content)

searched_results = get_search_results(keywords)
print(searched_results)

# Using the question, the AI answer and the search-result text, ask the LLM to
# select the papers relevant to the question and list their titles and links
# in markdown.
def get_related_papers(content, searched_results, model="deepseek-chat", temperature=0.95):
    """
    Ask the LLM which searched papers are relevant and return its markdown list.

    The answer is streamed to stdout while it arrives.

    Args:
        content: chat-history text (question + answer), e.g. from ``combine_context``.
        searched_results: paper listing text, e.g. from ``get_search_results``.
        model: model name passed to ``get_llm_answer``.
        temperature: sampling temperature passed to ``get_llm_answer``.

    Returns:
        The complete LLM answer as a string.
    """
    # The prompt is kept flush-left: a 4-space indent on every line makes
    # markdown treat the whole prompt as a code block, hiding the headings.
    SELECT_PAPER_PROMPT = """
# 你的任务
基于我下面提供的聊天记录，筛选出与问题相关的论文，以markdown格式展示论文标题及链接，

# 聊天记录
{content}  

# 搜索结果文本
{searched_results}

现在，请判断搜索结果中，哪些论文与问题相关，以markdown格式展示论文标题及链接，示例格式：[论文标题](论文链接)
你的输出格式：
# 相关论文
1. [论文标题](论文链接)
2. [论文标题](论文链接)
3. [论文标题](论文链接)
...
"""
    prompt = SELECT_PAPER_PROMPT.format(content=content, searched_results=searched_results)
    chunks = []
    for char in get_llm_answer(prompt, model=model, temperature=temperature):
        print(char, end="", flush=True)  # stream the answer as it arrives
        chunks.append(char)
    print()
    return "".join(chunks)



In [None]:
# Streaming variant of the paper selection: using the question, the AI answer
# and the search-result text, ask the LLM for the relevant papers (titles and
# links in markdown) and yield its answer as it arrives.
def get_related_papers_yield(content, searched_results, model="deepseek-chat", temperature=0.95):
    """
    Stream the LLM's markdown list of papers relevant to the question.

    Args:
        content: chat-history text (question + answer).
        searched_results: paper listing text, e.g. from ``get_search_results``.
        model: model name passed to ``get_llm_answer``.
        temperature: sampling temperature passed to ``get_llm_answer``.

    Yields:
        Each chunk produced by ``get_llm_answer``, then a final newline.
    """
    SELECT_PAPER_PROMPT = """
# 你的任务
基于我下面提供的聊天记录，筛选出与问题相关的论文，以markdown格式展示论文标题及链接，

# 聊天记录
{content}  

# 搜索结果文本
{searched_results}

现在，请判断搜索结果中，哪些论文与问题相关，以markdown格式展示论文标题及链接，示例格式：[论文标题](论文链接)
你的输出格式：
# 相关论文
1. [论文标题](论文链接)
2. [论文标题](论文链接)
3. [论文标题](论文链接)
...
"""
    prompt = SELECT_PAPER_PROMPT.format(content=content, searched_results=searched_results)
    for char in get_llm_answer(prompt, model=model, temperature=temperature):
        yield char
    yield "\n"

In [15]:
# Read data/conversation.json.
# NOTE: expects a JSON array whose last element has a "content" key.
with open("data/conversation.json", "r", encoding="utf-8") as f:
    conversation = json.load(f)
content = conversation[-1]["content"]
# Save the last message's content to data/content.txt.
with open("data/content.txt", "w", encoding="utf-8") as f:
    f.write(content)


In [None]:
from agents.pathway_enrich import *
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Pathway-enrichment demo: get the gene-set database names for a free-text
# request, then run enrichment on a fixed example gene list.
gpmap = GSEAPyDatabaseMap()
# NOTE(review): only the database names (KEGG / GO / Reactome) appear to be
# used from this text; the herb list inside it looks truncated ("当归各").
user_input = "请评估复元活血汤（柴胡半两、瓜蒌根、当归各、红花、甘草、大黄、桃仁）在纤维化中的作用，以及在kegg和GO、Reactome的通路富集情况"
databases = gpmap.get_databases(user_input)
print("提取的数据库名称:", databases)

# Example gene list
genes = [
    "TP53", "BRCA1", "EGFR", "MYC", "AKT1", 
    "VEGFA", "PTEN", "KRAS", "CDKN2A", "IL6", 
    "TNF", "MAPK1", "STAT3", "JUN", "FOS", 
    "HIF1A", "NFKB1", "PIK3CA", "RB1", "CCND1"
]

# Disease names to look up (passed as disease_names to the analysis)
disease_names = ['Sclerosis', 'Cancer', 'Heart']

agent2 = GeneEnrichmentAnalyzer()
# Run the analysis; `summary` is consumed by the plotting cell.
summary, enrichment_results = agent2.get_enrichment_summary(genes, gene_sets=databases, disease_names=disease_names)
# Bare expression: displayed as the notebook cell's output.
enrichment_results


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import datetime
import numpy as np

def plot_terms_by_pvalue(data, term_col='Term', pvalue_col='Adjusted P-value',
                         figsize=(10, 12), color='skyblue',
                         title='Terms Sorted by Adjusted P-value',
                         save_path=None, dpi=300, format='png',
                         log_transform=False):
    """
    Create a horizontal bar chart of Terms sorted by Adjusted P-value.

    Terms are always ordered from most to least significant (smallest
    adjusted p-value first, i.e. at the top of the chart), whether the bars
    show raw p-values or -log10(p).

    Parameters:
    -----------
    data : pandas.DataFrame or path to file
        DataFrame containing the Term and Adjusted P-value columns, or a path
        to a CSV / Excel file with those columns. A DataFrame is copied, never
        modified.
    term_col : str, default='Term'
        Name of the column containing the terms
    pvalue_col : str, default='Adjusted P-value'
        Name of the column containing the adjusted p-values
    figsize : tuple, default=(10, 12)
        Size of the figure (width, height)
    color : str, default='skyblue'
        Color of the bars
    title : str, default='Terms Sorted by Adjusted P-value'
        Title of the plot
    save_path : str, default=None
        Path to save the figure. If None, the figure is not saved. The
        extension ``.{format}`` is appended when missing, and the parent
        directory is created if needed.
    dpi : int, default=300
        Resolution of the saved figure
    format : str, default='png'
        Format of the saved figure ('png', 'jpg', 'pdf', 'svg', etc.)
    log_transform : bool, default=False
        Whether to plot -log10(p) instead of p. p-values of 0 are clipped to
        the smallest positive float so the transform stays finite.

    Returns:
    --------
    fig, ax : matplotlib figure and axis objects

    Raises:
    -------
    ValueError
        If ``data`` is a path that is neither CSV nor Excel.
    """
    # If data is a string, assume it's a file path and load it
    if isinstance(data, str):
        if data.endswith('.csv'):
            df = pd.read_csv(data)
        elif data.endswith(('.xls', '.xlsx')):
            df = pd.read_excel(data)
        else:
            raise ValueError("File format not supported. Please provide a CSV or Excel file.")
    else:
        # Deep copy so the caller's data is never modified
        df = data.copy(deep=True)

    # Sort on the raw p-values so the most significant term comes first in
    # both modes (sorting -log10 values ascending would reverse the order).
    df_sorted = df.sort_values(by=pvalue_col, ascending=True)

    if log_transform:
        finite_p = df_sorted[pvalue_col].clip(lower=np.finfo(float).tiny)
        df_sorted = df_sorted.assign(**{pvalue_col: -np.log10(finite_p)})
        x_label = '-log10(Adjusted P-value)'
        value_fmt = '{:.2f}'
    else:
        x_label = 'Adjusted P-value'
        value_fmt = '{:.2e}'

    fig, ax = plt.subplots(figsize=figsize)

    # Horizontal bars; the first row of df_sorted is drawn at the top.
    sns.barplot(x=pvalue_col, y=term_col, data=df_sorted, color=color, ax=ax)

    ax.set_xlabel(x_label)
    ax.set_ylabel('Term')
    ax.set_title(title)

    # Annotate each bar just past its end with the plotted value
    for i, value in enumerate(df_sorted[pvalue_col]):
        ax.text(value + value * 0.01, i, value_fmt.format(value), va='center')

    # Lay out this figure explicitly rather than whatever figure is "current"
    fig.tight_layout()

    if save_path is not None:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        # Add file extension if not in the save_path
        if not save_path.lower().endswith(f'.{format.lower()}'):
            save_path = f"{save_path}.{format}"

        fig.savefig(save_path, dpi=dpi, format=format, bbox_inches='tight')
        print(f"Figure saved to: {save_path}")

    return fig, ax


def draw_pvalue_analysis(summary, base_dir="files"):
    """
    Plot and save one -log10(adjusted p-value) bar chart per enrichment result.

    Args:
        summary: mapping of result name -> enrichment table (a DataFrame or a
            CSV/Excel path accepted by ``plot_terms_by_pvalue``).
        base_dir: directory the PNG files are written to (created if missing).

    Returns:
        str with one line per saved figure, naming the result and its path.
    """
    import re

    os.makedirs(base_dir, exist_ok=True)
    # One timestamp per call, shared by all figures of this batch.
    current_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")

    messages = []
    for index, data in summary.items():
        # The key becomes part of a file name: replace path separators,
        # spaces and other unsafe characters with "_".
        safe_name = re.sub(r"[^\w.-]+", "_", str(index))
        save_path = os.path.join(base_dir, f"pvalue_analysis_{safe_name}_{current_time}")

        # log_transform is applied inside the plotting function; data is not modified.
        fig, _ = plot_terms_by_pvalue(
            data,
            save_path=save_path,
            log_transform=True,
            title=f"{index} - Terms Sorted by -log10(Adjusted P-value)"
        )
        plt.close(fig)  # release the figure's memory

        # plot_terms_by_pvalue appends ".png" (its default format).
        messages.append(f"{index} 的富集结果图已保存到 {save_path}.png\n")
    return "".join(messages)

# Plot every enrichment result in `summary` (from the enrichment cell) and
# keep the returned save messages.
save_info = draw_pvalue_analysis(summary)



In [None]:
import json 

# Chinese herb names used by the rest of this cell; presumably the herbs of
# the formula in the enrichment cell's user_input (confirm).
herb_names = ["柴胡", "天花粉", "当归", "红花", "甘草", "大黄", "桃仁"]

# Read herb_targets.json