In [1]:
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

from finllmqa.api.core import LLM_API_URL

In [2]:
llm = ChatOpenAI(base_url=LLM_API_URL, api_key='null')

In [3]:
import pandas as pd
data = pd.read_excel('question.xlsx')

training_data = pd.concat([data[data['分类'] == '股票投资'].sample(10),
                           data[data['分类'] == '财经百科'].sample(10)])
test_data = data.drop(training_data.index)

In [4]:
def parse_classification(text):
    return '股票投资' if '股票投资' in text else ('财经百科' if '财经百科' in text else '')

In [5]:
take = 5

## Benchmark

In [6]:
base_prompt_template = """
    你是一名金融文本分类专家，请对下列问题进行分类，只能从[股票投资,财经百科,其他]选择一种合适的分类
    只能回答[股票投资,财经百科,其他]中的一种

    问题:{question}
    分类结果:
"""

base_prompt = PromptTemplate.from_template(base_prompt_template)
llm_chain = LLMChain(prompt=base_prompt, llm=llm)


In [34]:
benchmark_score = {'股票投资精确度': 0, '股票投资召回率': 0, '财经百科精确度': 0, '财经百科召回率': 0}
for i in range(take):
    benchmark_correct = {'股票投资':0, '财经百科':0}
    benchmark_count = {'股票投资':0, '财经百科':0}
    for i, row in test_data.iterrows():
        response = llm_chain.invoke(dict(question=row['问题']))['text']
        classification = parse_classification(response)
        if classification:
            benchmark_count[classification] += 1
        if classification == row['分类']:
            benchmark_correct[classification] += 1
    for key in benchmark_correct:
        benchmark_score[key+'精确度'] += benchmark_correct[key] / benchmark_count[key]
        benchmark_score[key+'召回率'] += benchmark_correct[key] * 2 / len(test_data)

for key in benchmark_score.keys():
    benchmark_score[key] /= take

2024-04-17 23:57:47,112 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 23:57:47,785 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 23:57:48,374 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 23:57:48,850 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 23:57:49,347 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 23:57:50,219 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 23:57:50,809 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 23:57:51,453 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 23:57

In [35]:
benchmark_score

{'股票投资精确度': 0.7554411764705883,
 '股票投资召回率': 0.6461538461538461,
 '财经百科精确度': 0.6350523872263002,
 '财经百科召回率': 0.6666666666666667}

## Few Shot

In [36]:
fewshot_template = """
    你是一名金融文本分类专家，请对以下的金融问题进行分类，只能从[股票投资,财经百科,其他]选择一种合适的分类, 
    只能回答[股票投资,财经百科,其他]中的一种

    举例:
    {example}

    问题:{question}
    分类结果:
"""

fewshot_prompt = PromptTemplate.from_template(fewshot_template)
fewshot_chain = LLMChain(prompt=fewshot_prompt, llm=llm)

example_data = training_data[7:12]
example_str = ''
for i, row in example_data.iterrows():
    example_str+=f"问题: {row['问题']}\n 分类结果: {row['分类']}\n "


In [37]:
fewshot_score = {'股票投资精确度': 0, '股票投资召回率': 0, '财经百科精确度': 0, '财经百科召回率': 0}
for i in range(take):
    fewshot_correct = {'股票投资':0, '财经百科':0}
    fewshot_count = {'股票投资':0, '财经百科':0}
    for i, row in test_data.iterrows():
        response = fewshot_chain.invoke(dict(question=row['问题'], example=example_str))['text']
        classification = parse_classification(response)
        if classification:
            fewshot_count[classification] += 1
        if classification == row['分类']:
            fewshot_correct[classification] += 1
    for key in fewshot_correct:
        fewshot_score[key+'精确度'] += fewshot_correct[key] / fewshot_count[key]
        fewshot_score[key+'召回率'] += fewshot_correct[key] * 2 / len(test_data)

for key in fewshot_score.keys():
    fewshot_score[key] /= take

2024-04-18 00:02:04,204 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:02:04,638 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:02:05,267 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:02:05,748 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:02:06,238 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:02:06,751 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:02:07,255 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:02:07,750 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:02

In [39]:
fewshot_score

{'股票投资精确度': 0.8768465280849181,
 '股票投资召回率': 0.8717948717948719,
 '财经百科精确度': 0.8430242272347535,
 '财经百科召回率': 0.8102564102564103}

## Quetion Type Summary + Few Shot

In [6]:
question_classify_template = """
    你是一名金融文本分类专家，请对以下的金融问题进行分类，类别只能是[股票投资,财经百科,其他]中的一种, 类别特征如下：
    {question_type_description}

    举例:
    {example}

    问题:{question}
    分类结果:
"""

question_summary_template = """
    你是一名财经领域的专家，以下是财经领域关于{question_type}的问题，请你总结这类问题的特征。

    问题：
    {question_list_str}
    特征:
"""

In [7]:
example_data = training_data[7:12]
example_str = ''
for i, row in example_data.iterrows():
    example_str+=f"问题: {row['问题']}\n 分类结果: {row['分类']}\n "

In [8]:
question_classify_prompt = PromptTemplate.from_template(question_classify_template)
question_summary_prompt = PromptTemplate.from_template(question_summary_template)
classify_chain = LLMChain(prompt=question_classify_prompt, llm=llm)
summary_chain = LLMChain(prompt=question_summary_prompt, llm=llm)

In [9]:
stock_investment_question_pool = training_data[training_data['分类'] == '股票投资']['问题'].to_list()
financial_knowledge_pool = training_data[training_data['分类'] == '财经百科']['问题'].to_list()

In [10]:
question_type_description = {}
stock_investment_question_list_str = '\n'.join(stock_investment_question_pool)
stock_investment_summary_response=summary_chain.invoke(dict(question_type='股票投资', question_list_str=stock_investment_question_list_str))['text']
question_type_description['股票投资'] = stock_investment_summary_response

financial_knowledge_question_list_str = '\n'.join(financial_knowledge_pool)
financial_knowledge_summary_response=summary_chain.invoke(dict(question_type='财经百科', question_list_str=financial_knowledge_question_list_str))['text']
question_type_description['财经百科'] = financial_knowledge_summary_response

2024-04-22 13:46:16,889 - httpx - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-22 13:46:46,213 - httpx - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"


In [45]:
summary_fewshot_score = {'股票投资精确度': 0, '股票投资召回率': 0, '财经百科精确度': 0, '财经百科召回率': 0}
for i in range(take):
    summary_fewshot_correct = {'股票投资':0, '财经百科':0}
    summary_fewshot_count= {'股票投资':0, '财经百科':0}
    for i, row in test_data.iterrows():
        classify_response = classify_chain.invoke(dict(question_type_description=str(question_type_description),
                                                example=example_str,
                                                question=row['问题']))['text']
        classification = parse_classification(classify_response)
        if classification:
            summary_fewshot_count[classification] += 1
        if classification == row['分类']:
            summary_fewshot_correct[classification] += 1
    for key in summary_fewshot_correct:
        summary_fewshot_score[key+'精确度'] += summary_fewshot_correct[key] / summary_fewshot_count[key]
        summary_fewshot_score[key+'召回率'] += summary_fewshot_correct[key] * 2 / len(test_data)

for key in summary_fewshot_score.keys():
    summary_fewshot_score[key] /= take

2024-04-18 00:06:06,388 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:06:06,892 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:06:07,384 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:06:07,877 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:06:08,394 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:06:08,902 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:06:09,381 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:06:09,880 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-18 00:06

In [48]:
summary_fewshot_count

{'股票投资': 21, '财经百科': 17}

In [46]:
summary_fewshot_score

{'股票投资精确度': 0.9241991341991342,
 '股票投资召回率': 0.9948717948717949,
 '财经百科精确度': 0.9652089783281734,
 '财经百科召回率': 0.8512820512820515}

## Question Type Summary + Few Shot + Training

In [19]:
question_classify_template = """
    你是一名金融文本分类专家，请对以下的金融问题进行分类，类别只能是[股票投资,财经百科,其他]中的一种, 类别特征如下：
    {question_type_description}

    举例:
    {example}

    问题:{question}
    分类结果:
"""

question_summary_template = """
    你是一名财经领域的专家，以下是财经领域关于{question_type}的问题，请你总结这类问题的特征。

    问题：
    {question_list_str}
    特征:
"""

In [20]:
question_classify_prompt = PromptTemplate.from_template(question_classify_template)
question_summary_prompt = PromptTemplate.from_template(question_summary_template)
classify_chain = LLMChain(prompt=question_classify_prompt, llm=llm)
summary_chain = LLMChain(prompt=question_summary_prompt, llm=llm)

In [21]:
example_data = training_data[7:12]
example_str = ''
for i, row in example_data.iterrows():
    example_str+=f"问题: {row['问题']}\n 分类结果: {row['分类']}\n "

In [22]:
summary_fewshot_train_correct = {'股票投资':0, '财经百科':0}
summary_fewshot_train_count = {'股票投资':0, '财经百科':0}
for i in range(take):
    stock_investment_question_pool = training_data[training_data['分类'] == '股票投资']['问题'].to_list()
    financial_knowledge_pool = training_data[training_data['分类'] == '财经百科']['问题'].to_list()
    for i, row in test_data.iterrows():
        if i // 5 == 0:
            question_type_description = {}
            stock_investment_question_list_str = '\n'.join(stock_investment_question_pool)
            stock_investment_summary_response=summary_chain.invoke(dict(question_type='股票投资', question_list_str=stock_investment_question_list_str))['text']
            question_type_description['股票投资'] = stock_investment_summary_response

            financial_knowledge_question_list_str = '\n'.join(financial_knowledge_pool)
            financial_knowledge_summary_response=summary_chain.invoke(dict(question_type='财经百科', question_list_str=financial_knowledge_question_list_str))['text']
            question_type_description['财经百科'] = financial_knowledge_summary_response

        classify_response = classify_chain.invoke(dict(question_type_description=str(question_type_description),
                                                    example=example_str,
                                                    question=row['问题']))['text']
        classification = parse_classification(classify_response)
        if classification:
            summary_fewshot_train_count[classification] += 1
        if classification == '股票投资':
            stock_investment_question_pool.append(row['问题'])
        else:
            financial_knowledge_pool.append(row['问题'])

        if classification == row['分类']:
            summary_fewshot_train_correct[classification] += 1

for key in fewshot_correct.keys():
    summary_fewshot_train_correct[key] /= take


2024-04-17 22:56:18,900 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 22:56:21,865 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 22:56:22,399 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 22:56:29,967 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 22:56:32,809 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 22:56:33,296 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 22:56:39,913 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 22:56:42,651 - INFO - HTTP Request: POST http://gemini2.sufe.edu.cn:27282/v1/chat/completions "HTTP/1.1 200 OK"
2024-04-17 22:56

In [23]:
summary_fewshot_train_correct

{'股票投资': 19.8, '财经百科': 17.2}