In [9]:
import gradio as gr
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
import re
import requests
from bs4 import BeautifulSoup
from fake_useragent import UserAgent
import random
from time import sleep
import chardet

from langchain_openai import embeddings
from langchain.vectorstores import FAISS
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.text_splitter import CharacterTextSplitter

from sklearn.metrics.pairwise import cosine_similarity
import difflib

def generate_keywords(input):
    """Stream LLM-generated keywords for the given industry.

    The industry name is substituted into a fixed Chinese prompt that asks
    for 3-5 keywords under each of eight numbered topics, and the prompt is
    sent to the chat model served at ``http://localhost:1234/v1``.

    Args:
        input: Industry name inserted into the prompt's ``{input}`` slot.

    Yields:
        The text accumulated so far, once per streamed chunk.
    """
    user_template = '''
    请你扮演一名专业的行业分析师，你应该利用你的行业知识和其他相关因素提供专业的分析。

    请你给出以下行业以下每个主题各3至5个重要关键词，关键词应包括行业术语、概念或趋势，保证关键词的准确、精炼性和与主题的相关度，逻辑连贯，内容低重复度、有深度。禁止使用"："和"（）"，避免使用品牌名称或专用词，关键词要易搜索和理解。
    行业：{input}

    主题：
    1. 行业定义
    2. 行业分类
    3. 行业特征
    4. 发展历程
    5. 产业链分析
    6. 市场规模
    7. 政策分析
    8. 竞争格局

    使用以下格式：
    1. <主题1>
    - <关键词1>
    - <关键词2>
    - <关键词3>
    - <关键词4>
    - <关键词5>

    2. <主题2>
    - <关键词1>
    ..
    ..
    8. ..
    完成上述任务后，请停止生成任何额外内容。

    样例：
    1. 行业定义
    - 水域资源管理
    - 养殖技术进步
    - 水产品质量监控
    - 海洋生态环境保护
    - 农水联动模式

    2. 行业分类
    - 鱼类养殖
    - 虾蟹养殖
    - 牛蛙业
    - 海参养殖
    - 浮游生物养殖

    3. 行业特征
    - 规模化生产
    - 生物多样性
    - 投入成本
    - 市场周期性
    - 环保压力

    4. 发展历程
    - 人工饲料
    - 智能水产养殖
    - 全球化市场拓展
    - 产业链整合
    - 技术驱动革新

    5. 产业链分析
    - 种苗供应
    - 饲料加工
    - 养殖基地建设
    - 销售和分销

    6. 市场规模
    - 全球水产养殖产量
    - 主要消费国
    - 国内市场规模
    - 年度报告与预测

    7. 政策分析
    - 环保政策影响
    - 信贷和补贴支持
    - 农业补贴调整

    8. 竞争格局
    - 主要企业竞争态势
    - 新进入者威胁
    - 品牌差异化策略
    - 外资并购与合作
    '''

    model = ChatOpenAI(
        base_url="http://localhost:1234/v1",
        api_key="lm-studio",
        max_tokens=1280,
        temperature=0.7,
        top_p=0.9,
    )
    keyword_prompt = ChatPromptTemplate.from_messages([("user", user_template)])
    pipeline = keyword_prompt | model | StrOutputParser()

    # Re-yield the whole text each time so the UI textbox is replaced, not appended to.
    text_so_far = ""
    for chunk in pipeline.stream({"input": input}):
        text_so_far = text_so_far + chunk
        yield text_so_far
#     return '''1. 行业定义
# - 水域资源管理  
# - 养殖技术进步
# - 水产品质量监控
# - 海洋生态环境保护
# - 农水联动模式
# '''

def _extract_keywords(text):
    """Return the keywords found in a numbered-topic / dash-bullet outline.

    A line containing ``<digits>.<whitespace>`` is treated as a topic heading
    and skipped. Any other line containing ``-<whitespace>`` contributes the
    text between its first and second dash separators as a keyword. Empty
    keywords are dropped.
    """
    keywords = []
    for line in (text or "").strip().split("\n"):
        if re.match(r".*?\d+\.\s.*", line):  # e.g. "1. 行业定义"
            continue
        if re.match(r".*?-\s.*", line):  # e.g. "- 行业关键词"
            # Split on the same "-<whitespace>" the match used, so a dash
            # followed by a tab cannot cause an IndexError.
            keyword = re.split(r"-\s", line)[1].strip()
            if keyword:
                keywords.append(keyword)
    return keywords


def process_keywords(input):
    """Flatten the generated outline into a keyword list and radio choices.

    Args:
        input: Outline text with numbered topics and dash-prefixed keywords.

    Returns:
        A tuple ``(text, update)``: ``text`` has one keyword per line (each
        followed by a newline) and ``update`` is a ``gr.update`` that sets
        the same keywords as selectable choices.
    """
    keywords = _extract_keywords(input)
    listing = "".join(f"{keyword}\n" for keyword in keywords)
    return listing, gr.update(choices=keywords, value=None, interactive=True)


def keywords_to_option(input):
    """Build a ``gr.update`` whose choices are the keywords found in ``input``.

    Args:
        input: Outline text with numbered topics and dash-prefixed keywords.

    Returns:
        A ``gr.update`` with the keywords as choices, no selected value, and
        ``interactive=True``.
    """
    return gr.update(choices=_extract_keywords(input), value=None, interactive=True)


def clear_all():
    """Return reset values for three text outputs and one choice component.

    Returns:
        Three empty strings followed by a ``gr.update`` that empties the
        choices, clears the value and disables interaction.
    """
    blank = ""
    emptied_choices = gr.update(choices=[], value=None, interactive=False)
    return blank, blank, blank, emptied_choices


def handle_selection(industry, keywords_to_search, pages_needed, timeout=10):
    """Search Baidu for "<industry> <keyword>" and list the result links.

    Args:
        industry: Industry name; an empty or missing value short-circuits.
        keywords_to_search: Keyword to search with; an empty or missing value
            (for example no radio option selected) short-circuits.
        pages_needed: Number of result pages to fetch (10 results per page).
        timeout: Seconds to wait for each HTTP request.

    Returns:
        The links as numbered lines ("1. <url>"), or a hint message when the
        industry or keyword is missing.

    Raises:
        requests.exceptions.RequestException: On network or HTTP errors.
        RuntimeError: When the page has no ``content_left`` result container.
    """
    # `not x` also covers None, which `== ""` let through as the text "None".
    if not industry or not keywords_to_search:
        return "请输入行业和关键词"

    ua = UserAgent()
    headers = {
        'accept': '*/*',
        # Only advertise encodings requests can always decode; offering
        # br/zstd without the optional decoders installed yields garbled text.
        'accept-encoding': 'gzip, deflate',
        'accept-language': 'zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6',
        'connection': 'keep-alive',
        'User-Agent': ua.edge
    }

    keyword_combination = f"{industry} {keywords_to_search}"
    all_links = []
    for page_index in range(int(pages_needed)):
        # `params` lets requests URL-encode the query text.
        response = requests.get(
            'https://www.baidu.com/s',
            params={'wd': keyword_combination, 'pn': page_index * 10},
            headers=headers,
            timeout=timeout,
        )
        response.raise_for_status()

        soup = BeautifulSoup(response.text, 'html.parser')
        content_left = soup.find('div', attrs={'id': 'content_left'})
        if content_left is None:
            # Raise so the caller's retry loop can try again.
            raise RuntimeError("Baidu result container 'content_left' not found")

        # Result blocks carry their target URL in the `mu` attribute.
        all_links.extend(
            div.get('mu') for div in content_left.find_all('div', attrs={'mu': True})
        )

    return "\n".join(f"{number}. {link}" for number, link in enumerate(all_links, start=1))

def retry_function(fn, max_retries, delay, industry, keywords_to_search, pages_needed):
    """Call ``fn`` until it succeeds or the attempts run out.

    Args:
        fn: Callable invoked as ``fn(industry, keywords_to_search, pages_needed)``.
        max_retries: Maximum number of attempts; values below 1 are treated
            as 1 so the function always returns a string-or-result.
        delay: Centre of the random wait (``delay`` +/- 0.5 s) between attempts.
        industry: First argument forwarded to ``fn``.
        keywords_to_search: Second argument forwarded to ``fn``.
        pages_needed: Third argument forwarded to ``fn``.

    Returns:
        The result of the first successful call, or ``"Error: <message>"``
        built from the last exception when every attempt failed.
    """
    attempts = max(1, int(max_retries))
    for attempt in range(1, attempts + 1):
        try:
            return fn(industry, keywords_to_search, pages_needed)
        except Exception as e:  # best-effort boundary: any failure is retried
            if attempt == attempts:
                return f"Error: {str(e)}"
            # Clamp at zero: sleep() rejects negative values, which the
            # jitter produces whenever delay < 0.5.
            sleep(max(0.0, random.uniform(delay - 0.5, delay + 0.5)))

def selection_to_links(industry, keywords_to_search, pages_needed):
    """Run ``handle_selection`` through ``retry_function``.

    Up to 5 attempts are made with a delay centred on 1.5 seconds.

    Returns:
        Whatever ``retry_function`` returns for the given arguments.
    """
    retry_settings = {"max_retries": 5, "delay": 1.5}
    return retry_function(
        handle_selection,
        industry=industry,
        keywords_to_search=keywords_to_search,
        pages_needed=pages_needed,
        **retry_settings,
    )


# webpage_to_text functions
def remove_ads_by_tag(soup):
    """Remove elements whose class/id or tag name marks them as page chrome.

    Elements whose ``class`` or ``id`` matches a case-insensitive regular
    expression built from the keyword list are decomposed, followed by every
    ``iframe``, ``aside``, ``header`` and ``footer`` element.

    Args:
        soup: Parsed document; it is modified in place.

    Returns:
        The same ``soup`` object.
    """
    # Class names / ids commonly used for ads and page furniture.
    tag_keywords = ['next', 'post_top_tie', 'jubao', 'search', 'comment_area', 'share', 'nav', 'ad', 'recommend', 'tool', 'advertisement', 'ads', 'sponsored', 'promo', 'banner', 'adsense', 'aside', 'footer', 'header', 'side-bar', 'column', 'sidebar', 'list', 'sideColumn', 'side']
    pattern = re.compile('|'.join(tag_keywords), re.IGNORECASE)

    # Pass 1: match on class, pass 2: match on id.
    for attribute in ('class_', 'id'):
        for element in soup.find_all(**{attribute: pattern}):
            element.decompose()

    # Embedded frames plus the head, tail and side sections of the page.
    for tag_name in ('iframe', 'aside', 'header', 'footer'):
        for element in soup.find_all(tag_name):
            element.decompose()

    return soup

def remove_ads_by_text(soup):
    """Remove the parent element of every text node matching an ad phrase.

    Each entry of the phrase list is used as a regular expression; for every
    matching string in the document, its parent element is decomposed.

    Args:
        soup: Parsed document; it is modified in place.

    Returns:
        The same ``soup`` object.
    """
    ad_keywords = ['优惠券', '阅读原文', '扫一扫', '限时抢购', '免费试用', '立即注册', '超值折扣', '注册有礼', '免费领取', '立即购买', '关注该公众号', '微信扫码', '分享至', '下载(.*?)客户端', '返回(.*?)首页', '阅读下一篇', '特别声明：以上内容', 'Notice: The content above', '打开(.*?)体验更佳', '热搜', '打开(.*?)新闻', '查看精彩图片']
    for phrase_regex in map(re.compile, ad_keywords):
        for text_node in soup.find_all(string=phrase_regex):
            container = text_node.find_parent()
            if not container:
                continue
            container.decompose()

    return soup

def clean_html(html_content):
    """Convert an HTML document into whitespace-normalised plain text.

    The HTML is parsed with the ``lxml`` parser, passed through
    ``remove_ads_by_tag`` and ``remove_ads_by_text``, stripped of script and
    style elements, and then flattened to text.

    Args:
        html_content: HTML markup to clean.

    Returns:
        The extracted text with runs of newlines collapsed to one newline and
        any run of two or more whitespace characters collapsed to one space.
    """
    soup = remove_ads_by_text(remove_ads_by_tag(BeautifulSoup(html_content, 'lxml')))

    # Scripts and styles carry no readable content.
    for element in soup(['script', 'style']):
        element.decompose()

    # Put a line break after block-level elements so their text stays apart.
    block_level = ['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'tr', 'div', 'br', 'hr']
    for element in soup.find_all(block_level):
        element.insert_after('\n')

    single_newlines = re.sub(r'\n+', '\n', soup.get_text())
    return re.sub(r'\s{2,}', ' ', single_newlines)

def webpage_to_text(links_found, min_slider, max_slider, num, timeout=10):
    """Download the selected result pages and return their cleaned text.

    Args:
        links_found: Numbered link list, one "<n>. <url>" entry per line.
            Lines not of that form (hints, error messages, blanks) are ignored.
        min_slider: First link number to fetch; 0 means nothing is selected.
        max_slider: Last link number to fetch; 0 or a value below the
            minimum means "same as the minimum".
        num: Optional manual selection, "N" or "N-M", overriding the sliders.
        timeout: Seconds to wait for each HTTP request.

    Returns:
        A tuple ``(text, minnum, maxnum)``. ``text`` holds one section per
        fetched page, headed by "第<n>个网页：" and the URL, or a hint message
        when there is nothing to fetch. ``minnum``/``maxnum`` are the
        effective bounds. All paths return three values because the caller
        binds the result to three output components.
    """
    # Keep the number shown to the user so section headers match the list.
    numbered_links = []
    for line in (links_found or "").splitlines():
        index_part, separator, link = line.partition('. ')
        if separator and index_part.strip().isdigit() and link.strip():
            numbered_links.append((int(index_part), link.strip()))

    minnum = int(min_slider)
    maxnum = int(max_slider)

    # A manual entry is either a single number or a "start-end" range.
    manual_range = re.match(r'^(\d+)(?:-(\d+))?$', str(num or "").strip())
    if manual_range:
        start_num = int(manual_range.group(1))
        end_num = int(manual_range.group(2)) if manual_range.group(2) else start_num
        minnum = start_num if start_num > 0 else minnum
        maxnum = end_num

    if maxnum == 0 or maxnum < minnum:
        maxnum = minnum

    if minnum == 0 or not numbered_links:
        return "请选择链接", minnum, maxnum

    ua = UserAgent()
    headers = {
        'accept': '*/*',
        # Only advertise encodings requests can always decode; offering
        # br/zstd without the optional decoders installed yields garbled text.
        'accept-encoding': 'gzip, deflate',
        'accept-language': 'zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6',
        'connection': 'keep-alive',
        'User-Agent': ua.edge
    }

    sections = []
    for count, link in numbered_links:
        if count < minnum or count > maxnum:
            continue
        sections.append(f"第{count}个网页：\n{link}\n")

        # These sites are recorded as errors without being requested.
        if 'zhihu' in link:
            sections.append("Error: Zhihu\n\n")
            continue
        if 'weibo' in link:
            sections.append("Error: Sina Visitor System\n\n")
            continue

        try:
            response = requests.get(link, headers=headers, timeout=timeout)
            response.raise_for_status()

            # Decode with the detected charset; weixin pages are forced to UTF-8.
            response.encoding = chardet.detect(response.content)['encoding']
            if 'weixin' in link:
                response.encoding = 'utf-8'

            # The title identifies verification and Wenku pages.
            soup = BeautifulSoup(response.text, 'html.parser')
            if soup.title and soup.title.string:
                title = soup.title.string
                if title == "百度安全验证":
                    sections.append("Error: Baidu Security Verification\n\n")
                    continue
                if '百度文库' in title:
                    sections.append("Error: Baidu Wenku\n\n")
                    continue

            sections.append(f"{clean_html(response.text)}\n\n")

        except requests.exceptions.RequestException as e:
            sections.append(f"Failed to retrieve {link}: {e}\n\n")

    return "".join(sections), minnum, maxnum


def parse_pages(pages_str, minnum, maxnum):
    """Parse a page selection such as "1,3-5" into a list of page numbers.

    Parts may be separated by ``, ; ； / . ， 。 、``; whitespace is ignored.
    Each part is either a single number or a "start-end" range (a reversed
    range is accepted). Parts that are empty or not numeric are skipped.

    Args:
        pages_str: The selection text; ``None`` yields an empty list.
        minnum: Lowest page number to keep (inclusive).
        maxnum: Highest page number to keep (inclusive).

    Returns:
        Page numbers within ``[minnum, maxnum]`` in the order first
        mentioned, without duplicates.
    """
    if pages_str is None:
        return []

    # Normalise every accepted separator to a comma, then drop whitespace.
    normalized = re.sub(r'[;；\/.，。、]', ',', str(pages_str))
    normalized = re.sub(r'\s+', '', normalized)

    # Loose integer bounds so a huge range like "1-1000000000" is never
    # iterated in full; the exact filter below still uses minnum/maxnum.
    low_bound = int(minnum)
    high_bound = int(maxnum) + 1

    result = []
    seen = set()
    for part in normalized.split(','):
        parsed = re.fullmatch(r'(\d+)(?:-(\d+))?', part)
        if not parsed:
            continue  # empty part (trailing comma) or not a number/range
        start = int(parsed.group(1))
        end = int(parsed.group(2)) if parsed.group(2) else start
        if start > end:
            start, end = end, start
        for page in range(max(start, low_bound), min(end, high_bound) + 1):
            if minnum <= page <= maxnum and page not in seen:
                seen.add(page)
                result.append(page)

    return result

def calculate_final_score(relevance, accuracy, completeness, timeliness, authority, readability):
    """Combine six dimension scores into one weighted total.

    Weights: relevance 0.25, accuracy 0.20, completeness 0.15,
    timeliness 0.15, authority 0.15, readability 0.10.

    Returns:
        The sum of each score multiplied by its weight.
    """
    # (score, weight) pairs, in the order in which they are summed.
    weighted_dimensions = (
        (relevance, 0.25),
        (accuracy, 0.20),
        (completeness, 0.15),
        (timeliness, 0.15),
        (authority, 0.15),
        (readability, 0.10),
    )

    first_score, first_weight = weighted_dimensions[0]
    total = first_score * first_weight
    for score, weight in weighted_dimensions[1:]:
        total = total + score * weight
    return total

def calculate_by_vector(text1, text2, min_length=20, split_char=r'[|\n]'):
    """Find the segment of ``text2`` most similar to ``text1`` by embedding.

    ``text2`` is split with the ``split_char`` regular expression; segments
    shorter than ``min_length`` after stripping are discarded. ``text1`` and
    every remaining segment are embedded with the model served at
    ``http://localhost:1234/v1`` and compared by cosine similarity.

    Args:
        text1: Reference text.
        text2: Text to split into candidate segments.
        min_length: Minimum stripped length for a segment to be scored.
        split_char: Regular expression used to split ``text2``.

    Returns:
        ``(best_similarity, best_segment)``; ``(0.00, '')`` when no segment
        is long enough.
    """
    embedding = embeddings.OpenAIEmbeddings(check_embedding_ctx_length=False, base_url="http://localhost:1234/v1", api_key="lm-studio")
    reference_vector = embedding.embed_query(text1)

    # Keep only segments long enough to be meaningful.
    candidates = []
    for raw_segment in re.split(split_char, text2):
        trimmed = raw_segment.strip()
        if len(trimmed) >= min_length:
            candidates.append(trimmed)

    scored = [
        (segment, cosine_similarity([reference_vector], [embedding.embed_query(segment)])[0][0])
        for segment in candidates
    ]
    if not scored:
        return 0.00, ''

    top_segment, top_similarity = max(scored, key=lambda pair: pair[1])
    return top_similarity, top_segment

def batch_analyze_webpage(analyze_all, webpage_text, keywords_processed, webpage_to_analyze, industry, keywords_to_search, minnum, maxnum):
    """Analyze several fetched pages in turn, streaming the combined report.

    Args:
        analyze_all: When true, every page from ``minnum`` to ``maxnum`` is
            analyzed and ``webpage_to_analyze`` is ignored.
        webpage_text: Fetched page text, passed through to ``analyze_webpage``.
        keywords_processed: Reference keywords, passed through.
        webpage_to_analyze: Page selection text understood by ``parse_pages``.
        industry: Industry name, passed through.
        keywords_to_search: Search keyword, passed through.
        minnum: Lowest fetched page number, or ``None`` if nothing was fetched.
        maxnum: Highest fetched page number, or ``None`` if nothing was fetched.

    Yields:
        The report accumulated so far. A single hint message is yielded when
        there is nothing to analyze.
    """
    if minnum is None or maxnum is None:
        yield "请先获取网页内容"
        return

    # The bounds may arrive as floats; range() needs integers.
    first_page, last_page = int(minnum), int(maxnum)

    pages_to_analyze = []
    if analyze_all:  # "select all" takes priority over the typed selection
        pages_to_analyze = list(range(first_page, last_page + 1))
    elif webpage_to_analyze:
        pages_to_analyze = parse_pages(webpage_to_analyze, first_page, last_page)

    if not pages_to_analyze:
        yield "没有可分析的页面，请输入有效的页面号"
        return

    analyzed_webpage = ""
    for i in pages_to_analyze:
        analyzed_webpage += f"正在分析第{i}个网页：\n"
        yield analyzed_webpage
        part = ""
        for part in analyze_webpage(webpage_text, keywords_processed, i, industry, keywords_to_search):
            yield analyzed_webpage + part
        # Each yielded part is cumulative for its page, so keep only the last.
        analyzed_webpage += part

def analyze_webpage(webpage_text, keywords_processed, webpage_to_analyze, industry, keywords_to_search):
    """Summarise and score one fetched page with the local LLM, streaming text.

    The section headed "第<n>个网页：" is extracted from ``webpage_text``. The
    model is first asked whether the page is missing or an error page; a
    negative verdict is double-checked with one more call. Complete pages are
    then summarised, tagged with keywords and scored on six dimensions, after
    which a weighted total and an embedding-similarity score are appended.

    Args:
        webpage_text: Text of all fetched pages.
        keywords_processed: Reference keywords offered to the model.
        webpage_to_analyze: Number ``n`` of the page to analyze.
        industry: Industry name used in the prompts.
        keywords_to_search: Search keyword used in the prompts.

    Yields:
        The report for this page accumulated so far.
    """
    n = webpage_to_analyze  # number of the page to extract
    pattern = rf'第{n}个网页：\n(?:https?:\/\/[^\n]*\n)?([\s\S]*?)(?=\n第\d+个网页：|\Z)'
    match = re.search(pattern, webpage_text)
    if not match:
        yield f"未找到第{n}个页面的内容"
        return

    content = match.group(1).strip()
    link_match = re.search(rf'第{n}个网页：\n(https?:\/\/[^\s]+)', webpage_text)
    # The URL line is optional in the section pattern, so it may be absent.
    weblink = link_match[1] if link_match else "未知"

    # Built only once a page was found; "---END---" stops generation.
    llm = ChatOpenAI(base_url="http://localhost:1234/v1", api_key="lm-studio", max_tokens=512, temperature=0.6, top_p=0.8, stop_sequences=["---END---"])

    prompt0 = ChatPromptTemplate.from_messages([("user", 
'''
请判断以下网页爬虫获取的、与{input}行业以及{keywords_to_search}相关的文章页面是否存在缺失（只有标题、没有正文、字数过少、少于100字等）或报错（Error、Failed、错误、验证码等）。请忽略页面中的可能出现的广告、推荐、导航栏等无关信息。
文章：
{text}

若存在缺失或报错，则直接输出"Error---END---"；文章完整，则直接输出"Complete---END---"。
''')])
    chain0 = prompt0 | llm | StrOutputParser()
    check_inputs = {"input": industry, "text": content, "keywords_to_search": keywords_to_search}

    # The page counts as complete if either of up to two checks says so;
    # the second call is made only after a first "Error" verdict.
    iscomplete = False
    for _ in range(2):
        reply0 = chain0.invoke(check_inputs)
        if not re.search(r'Error', reply0, re.IGNORECASE):
            iscomplete = True
            break

    if not iscomplete:
        yield f"网址：{weblink}\n\n文章内容缺失。\n\n综合评分：0/100\n向量评分：0/100（最高余弦相似度）\n----------\n\n"
        return

    prompt = ChatPromptTemplate.from_messages([("user", 
'''
#START#-#END#为输出格式的范围，[]内为需要填写的部分。

第一步：请阅读以下与{input}行业及{keywords_to_search}的文章，并根据文章内容写一个自然段的简介，字数不超过300字。请忽略文章中的广告、推荐和导航栏等无关信息。

文章：
{text}

输出格式：#START#
标题：[此处填写标题]
简介：
[此处撰写简介（字数不超过300字）]
#END#

第二步：根据文章内容，提取与{input}行业以及{keywords_to_search}相关的**至多8个**关键词。

可参考的关键词：（
{keywords}）

输出格式：#START#
关键词：[关键词1]；[关键词2]；...；[关键词n（关键词最多8个）]
#END#

第三步：请对文章与{input}行业及{keywords_to_search}的相关性等各维度进行评分，范围是0到100，现在是2024年。

评分标准：尽量不给90-100，优秀文章给70-89，普通文章给50-69，质量较差给30-49

请按以下格式输出：#START#
相关性：[n]/100
准确性：[n]/100
完整性：[n]/100
时效性：[n]/100
权威性：[n]/100
可读性：[n]/100
---END---
#END#
---END---
''')])
    chain = prompt | llm | StrOutputParser()

    accumulated_text = f"网址：{weblink}\n"
    yield accumulated_text
    for message in chain.stream({"input": industry, "text": content, "keywords": keywords_processed, "keywords_to_search": keywords_to_search}):
        accumulated_text += message
        # Strip the format markers from the whole text each time, so a
        # marker split across two chunks is still removed.
        accumulated_text = accumulated_text.replace("---END---", "").replace("#START#", "").replace("#END#", "")
        yield accumulated_text

    # Collect every "<dimension>：<score>/100" line the model produced.
    score_pattern = r"(?P<dimension>[\u4e00-\u9fa5]+)：(?P<score>\d+)/100"
    scores = {dimension: int(score) for dimension, score in re.findall(score_pattern, accumulated_text)}

    # Skip the computed scores only if the model itself reported a zero total.
    if scores.get("综合评分", 60) != 0:
        final_score = calculate_final_score(scores.get("相关性", 0), scores.get("准确性", 0), scores.get("完整性", 0), scores.get("时效性", 0), scores.get("权威性", 0), scores.get("可读性", 0))
        accumulated_text += f"综合评分：{round(final_score, 2)}/100\n"

        vector_score = calculate_by_vector(industry + ' ' + keywords_to_search, content, 20, r'[|\n]')[0]
        accumulated_text += f"向量评分：{vector_score*100:.0f}/100（最高余弦相似度）\n"
        accumulated_text += "----------\n\n"

    accumulated_text = re.sub(r'\n{3,}', '\n\n', accumulated_text)
    yield accumulated_text

def match_keywords(text):
    """Turn the "关键词：" line of an analysis into choice-list settings.

    Only the first line starting with "关键词：" is used. Full stops are
    removed and the rest is split on Chinese/ASCII commas, the enumeration
    comma and Chinese/ASCII semicolons; blank items are dropped.

    Args:
        text: Analysis text to search.

    Returns:
        A ``gr.update`` with the keywords as choices, or one with the single
        choice "Error" when no such line exists.
    """
    found = re.search(r'^关键词：([^\n]+)', text, re.MULTILINE)
    if not found:
        return gr.update(choices=["Error"], value=None, interactive=False)

    raw_keywords = found.group(1).strip().replace('.', '').replace('。', '')
    pieces = re.split('[，,、；;]', raw_keywords)
    keyword_list = [piece.strip() for piece in pieces if piece.strip()]

    return gr.update(choices=keyword_list, value=True, interactive=False)

def sort_results(text):
    """Re-order per-page analysis reports by vector score, highest first.

    A report is recognised as "正在分析第<n>个网页：", a "网址：" line, any
    content, and a "向量评分：<score>/100" line. The content is not allowed to
    run into the next page's header, so a page without a vector score is
    left out rather than borrowing the following page's score.

    Args:
        text: Concatenated analysis reports.

    Returns:
        The recognised reports, each headed by "第<n>个网页：" and ending with
        its vector score. Pages with equal scores keep their original order.
    """
    page_header = r"正在分析第\d+个网页："
    pattern = re.compile(
        r"正在分析第(?P<number>\d+)个网页：\n网址：\s*"
        # Advance one character at a time, never across the next header.
        rf"(?P<content>(?:(?!{page_header}).)*?)"
        r"向量评分：(?P<score>\d+(?:\.\d{1,2})?)/100",
        re.DOTALL,
    )

    results = [
        (found.group("number"), found.group("content").strip(), float(found.group("score")))
        for found in pattern.finditer(text)
    ]
    results.sort(key=lambda item: item[2], reverse=True)

    return "".join(
        f"第{number}个网页：\n{content}\n\n向量评分：{score}/100（最高余弦相似度）\n----------\n\n"
        for number, content, score in results
    )

# 问答函数
def chatbot_response(message, chat_history, webpage_text):
    """Answer a question from the fetched page text, streaming the reply.

    The page text is split into chunks and indexed in a FAISS store; the
    chunks returned by a similarity search for the question are numbered and
    given to the chat model as context. After the answer has streamed, each
    "第<k>个文本块" the model cited is traced back to the page it came from.

    Args:
        message: The user's question.
        chat_history: List of ``(label, text)`` tuples, extended in place;
            ``None`` is treated as an empty history.
        webpage_text: Text of all fetched pages, with "第<n>个网页" headers.

    Yields:
        ``(chat_history, text)`` pairs. ``text`` is the reply so far while
        streaming and an empty string in the final pair.
    """
    if chat_history is None:
        chat_history = []

    # Record the user's message first.
    chat_history.append(("用户", message))
    accumulated_text = ""
    yield chat_history, accumulated_text

    # Nothing to index: building a vector store from no text would fail.
    if not webpage_text or not webpage_text.strip():
        chat_history.append(("回答", "请先获取网页内容"))
        yield chat_history, ""
        return

    llm = ChatOpenAI(base_url="http://localhost:1234/v1", api_key="lm-studio", stop_sequences=["### Response", "<|endoftext|>", "###", "---"])
    embedding = embeddings.OpenAIEmbeddings(check_embedding_ctx_length=False, base_url="http://localhost:1234/v1", api_key="lm-studio")

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=100)
    text_splitter2 = CharacterTextSplitter(separator="\n\n")
    final_texts = []
    for text in text_splitter.split_text(webpage_text):
        final_texts.extend(text_splitter2.split_text(text))
    # Embed the chunks.
    vectorstore = FAISS.from_texts(final_texts, embedding)

    template ="""{context}
1. 仅根据以上给出的信息回答问题，不要使用自身的知识。
2. 指明回答参考的文本块号。
问题: {question}

回答示例（问题：小家电的种类）：
回答：按照小家电的使用功能，可以将其分为四类：是厨房小家电产品、家居小家电产品、个人生活小家电产品、个人使用数码产品
来源：第2个文本块、第5个文本块、第11个文本块
"""
    prompt = ChatPromptTemplate.from_template(template)

    query = message
    docs = vectorstore.similarity_search(query)
    print(len(docs),"个doc")
    formated_docs = "\n\n".join(
        f"第{index}个文本块：\n" + doc.page_content for index, doc in enumerate(docs, start=1)
    )

    retrieval_chain = prompt | llm | StrOutputParser()

    for chunk in retrieval_chain.stream({"context": formated_docs, "question": query}):
        accumulated_text += chunk
        yield chat_history, accumulated_text

    answer = accumulated_text
    accumulated_text += "\n"

    # Distinct cited block numbers, in numeric order.
    unique_list = sorted(set(re.findall(r'第(\d+)个文本块', answer)), key=int)
    print(unique_list,"个unique_list")

    # Split the page text into per-page sections once, not once per citation.
    article = re.sub(r'\s+', ' ', webpage_text.replace('\n', '').strip())
    header_pattern = r'第\d+个网页'
    paragraph_pattern = re.compile(rf'({header_pattern})(.*?)(?={header_pattern}|$)', re.DOTALL)
    paragraphs = paragraph_pattern.findall(article)

    for match in unique_list:
        doc_index = int(match) - 1
        if not 0 <= doc_index < len(docs):
            # The model cited a block it was never given; skip it.
            continue

        accumulated_text += (f'第{match}个文本块')
        # Normalise the chunk the same way as the article before searching.
        target_sentence = re.sub(r'\s+', ' ', docs[doc_index].page_content.replace('\n', '').strip())

        for header, content in paragraphs:
            if target_sentence in content:
                accumulated_text += f"来源于：{header}\n"
                break
        else:
            accumulated_text += "来源于：未知\n"

        yield chat_history, accumulated_text

    # Record the model's reply.
    chat_history.append(("回答", accumulated_text))

    yield chat_history, ""


# Gradio UI: builds the page layout and wires each button to its handler.
# Components are created in display order, so statement order matters here.
with gr.Blocks() as demo:
    title = gr.Markdown("## 行业信息获取器")
    description = gr.Markdown("输入行业，生成相关关键词。")

    # Row 1: industry input, raw generated keywords, flattened keyword list.
    with gr.Row():
        industry = gr.Textbox(label="输入行业名", lines=1, max_lines=1)
        keywords_generated = gr.Textbox(label="行业关键词", lines=8, max_lines=8, show_copy_button=True)
        keywords_processed = gr.Textbox(label="处理后的关键词", lines=8, max_lines=8, show_copy_button=True)
        

    with gr.Row():
    # Behaviour of the first button
        generate_button = gr.Button("生成关键词")
        generate_button.click(fn=generate_keywords, inputs=industry, outputs=keywords_generated)
        process_button = gr.Button("处理关键词")
        clear_button = gr.Button("清除关键词")
        

    # Row 3: keyword picker on the left, search controls and link list on the right.
    with gr.Row():
        with gr.Column():
            keywords_to_search = gr.Radio(label="搜索关键词", choices=[])
            
        with gr.Column():
            pages_needed = gr.Slider(label="搜索页数", minimum=1, maximum=20, step=1, value=1)
            links_found = gr.Textbox(label="获取到的链接", lines=10, max_lines=10)
            submit_button = gr.Button("获取链接")

        # Behaviour of the third button
        # option_button = gr.Button("获取关键词")
        # option_button.click(fn=, inputs=keywords_generated, outputs=)
        process_button.click(fn=process_keywords, inputs=keywords_generated, outputs=[keywords_processed, keywords_to_search])
        submit_button.click(fn=selection_to_links, inputs=[industry, keywords_to_search, pages_needed], outputs=links_found)
        clear_button.click(fn=clear_all, inputs=[], outputs=[industry, keywords_generated, keywords_processed, keywords_to_search])

    # Row 4: page-range selection on the left, fetched page text on the right.
    with gr.Row():
        with gr.Column():
            min_slider = gr.Slider(label="获取页面最小值", minimum=0, maximum=20, step=1, value=1)
            max_slider = gr.Slider(label="获取页面最大值", minimum=0, maximum=20, step=1, value=0)
            no_slider = gr.Textbox(label="手动输入（例：1-10，例：25）", lines=1, max_lines=1)
            with gr.Row():
                webpage_to_analyze = gr.Textbox(label="输入需要分析的页面号", lines=1, max_lines=1, value=1)
                analyze_all = gr.Checkbox(label="选择全部", value=False)
            clear_slider = gr.ClearButton(components=[max_slider, no_slider], value="重新输入", )
            
            
        with gr.Column():
            webpage_text = gr.Textbox(label="网页内容", lines=16, max_lines=16, show_copy_button=True)
            
            
    with gr.Row():
        webpage_button = gr.Button("获取网页内容")
        webpage_analyze = gr.Button("分析网页内容")

    # Row 6: analysis report with its sort button, and the sorted report.
    with gr.Row():
        # webpage_to_analyze = gr.Textbox(label="输入需要分析的页面号", lines=1, max_lines=1, value=1)
        # analyzed_keywords = gr.CheckboxGroup(label="分析关键词", choices=[])
        with gr.Column():
            analyzed_text = gr.Textbox(label="分析结果", lines=20, max_lines=20, show_copy_button=True)
            sort_button = gr.Button("排序")

        sorted_results = gr.Textbox(label="排序结果", lines=20, max_lines=20, show_copy_button=True)

    # Row 7: question answering over the fetched page text.
    with gr.Row():
        chatbot = gr.Chatbot(label="问答机器人")
        with gr.Column():
            msg = gr.Textbox(label="网页问答", placeholder="输入你的问题")
            clear = gr.Button("清除问答记录")

        # chatbot_response yields (history, text); the text goes back into the input box.
        msg.submit(chatbot_response, [msg, chatbot, webpage_text], [chatbot, msg])
        clear.click(lambda: None, None, chatbot, queue=False)

    # Hidden fields carrying the fetched page range from webpage_to_text to batch_analyze_webpage.
    minnum = gr.Number(label="最小值", visible=False, interactive=False)
    maxnum = gr.Number(label="最大值", visible=False, interactive=False)
    webpage_button.click(fn=webpage_to_text, inputs=[links_found, min_slider, max_slider, no_slider], outputs=[webpage_text, minnum, maxnum])
    # webpage_analyze.click(fn=analyze_webpage, inputs=[webpage_text, keywords_processed, webpage_to_analyze, industry, keywords_to_search], outputs=[analyzed_text, analyzed_keywords])
    webpage_analyze.click(fn=batch_analyze_webpage, inputs=[analyze_all, webpage_text, keywords_processed, webpage_to_analyze, industry, keywords_to_search, minnum, maxnum], outputs=analyzed_text)
    sort_button.click(fn=sort_results, inputs=[analyzed_text], outputs=sorted_results)

demo.launch()


Running on local URL:  http://127.0.0.1:7868

To create a public link, set `share=True` in `launch()`.




4 个doc
['3'] 个unique_list
