In [54]:
import os
import pandas as pd
import numpy as np
from pdfminer.converter import TextConverter
from pdfminer.pdfdocument import PDFTextExtractionNotAllowed
from pdfminer.pdfinterp import PDFResourceManager, PDFPageInterpreter
from pdfminer.pdfpage import PDFPage
import io
import jieba

from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Weaviate
import weaviate
from weaviate.embedded import EmbeddedOptions

from tqdm import tqdm

from langchain.text_splitter import CharacterTextSplitter
from langchain.text_splitter import TokenTextSplitter

from FlagEmbedding import FlagModel
from sklearn.metrics.pairwise import cosine_similarity

In [55]:
def parsePDF(PDF_path):
    flag = False
    if 'AZ' in PDF_path:
        flag = True
    resource_manager = PDFResourceManager()
    fake_file_handle = io.StringIO()
    converter = TextConverter(resource_manager,fake_file_handle)
    page_interpreter = PDFPageInterpreter(resource_manager,converter)
    with open(PDF_path,'rb') as fh:
        for n_page,page in enumerate(PDFPage.get_pages(fh,caching=True,check_extractable=False)):
            if flag:
                if n_page < 2:
                    continue
            page_interpreter.process_page(page)
        text = fake_file_handle.getvalue()
    converter.close()
    fake_file_handle.close()
    if text:
        return text

In [56]:
def drop_content_from_text(text,drop_content):
    for content in drop_content:
        text = text.replace(content,'')
    return text

In [57]:
document_root = r'data/A榜/A_document'
docs = os.listdir(document_root)

knowledge = []

drop_content = ['\x0c',
                '本文档为2024CCFBDCI比赛用语料的一部分。部分文档使用大语言模型改写生成，内容可能与现实情况不符，可能不具备现实意义，仅允许在本次比赛中使用。']

for doc in tqdm(docs):
    # if doc != 'AZ01.pdf':
    #     continue
    doc_path = os.path.join(document_root,doc)
    text = drop_content_from_text(parsePDF(doc_path),drop_content)
    knowledge.append(text)

knowledge = ''.join(knowledge)

 92%|█████████▏| 110/120 [00:43<00:06,  1.59it/s]The PDF <_io.BufferedReader name='data/A榜/A_document\\AZ01.pdf'> contains a metadata field indicating that it should not allow text extraction. Ignoring this field and proceeding. Use the check_extractable if you want to raise an error in this case
 92%|█████████▎| 111/120 [00:44<00:06,  1.49it/s]The PDF <_io.BufferedReader name='data/A榜/A_document\\AZ02.pdf'> contains a metadata field indicating that it should not allow text extraction. Ignoring this field and proceeding. Use the check_extractable if you want to raise an error in this case
 93%|█████████▎| 112/120 [00:45<00:07,  1.06it/s]The PDF <_io.BufferedReader name='data/A榜/A_document\\AZ03.pdf'> contains a metadata field indicating that it should not allow text extraction. Ignoring this field and proceeding. Use the check_extractable if you want to raise an error in this case
 94%|█████████▍| 113/120 [00:47<00:07,  1.00s/it]The PDF <_io.BufferedReader name='data/A榜/A_document\\AZ0

In [58]:
# text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
# chunks = text_splitter.split_documents(knowledge)

token_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=50)
chunks = token_splitter.split_text(knowledge)

In [59]:
class KnowledgeDataBase:
    def __init__(self,chunks) -> None:
        self.model = FlagModel('BAAI/bge-large-zh-v1.5',
                  query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
                  use_fp16=True)
        self.chunks = [drop_content_from_text(chunk,'�') for chunk in chunks]
        self.db = self.create_db(chunks)

    def create_db(self,chunks):
        print('create db..')
        db = []
        for chunk in tqdm(chunks):
            db.append(self.model.encode(chunk))
        return db

    def search(self,query,n=5):
        query_embedding = self.model.encode(query)
        similarity_scores = [(doc_id, cosine_similarity(query_embedding.reshape(1,-1), doc.reshape(1,-1))[0][0]) 
                            for doc_id, doc in enumerate(self.db)]
        top_n_results = sorted(similarity_scores, key=lambda x: x[1], reverse=True)[:n]
        return top_n_results
    

In [60]:
kdb = KnowledgeDataBase(chunks)

create db..


100%|██████████| 3449/3449 [33:08<00:00,  1.73it/s]


In [61]:
test_path = r'data/A榜/A_question.csv'
df_test = pd.read_csv(test_path)

# df_test['answer'] = df_test['question'].apply(lambda x:''.join([kdb.chunks[id] for id,score in kdb.search(x,n=5)]))
# df_test['embedding'] = df_test['answer'].apply(lambda x:kdb.model.encode(x))
# df_test.to_csv('result.csv')

In [62]:
df_test['id'] = df_test['question'].apply(lambda x:[id for id,score in kdb.search(x,n=1)][0])

In [63]:
df_test['id']

0     1115
1     1623
2     1504
3     1506
4     1571
      ... 
95     817
96     821
97     685
98     694
99     699
Name: id, Length: 100, dtype: int64

In [64]:
df_test['answer'] = df_test['id'].apply(lambda x:kdb.chunks[x])
df_test['embedding'] = df_test['id'].apply(lambda x:str(list(kdb.db[x]))[1:-2])

In [65]:
df_test.to_csv('result.csv',index=False)