In [None]:
# %pip install -qU pypdf
# %pip install -qU langchain-unstructured
# %pip install -qU "unstructured[pdf]"
# %pip install -qU matplotlib PyMuPDF pillow

## Extracting text with PyPDFLoader

In [None]:
from langchain_community.document_loaders import PyPDFLoader

# Path of the PDF handed to PyPDFLoader.
file_path = "data/scholar_corpus/0KCrVeiC5zEJ.pdf"

loader = PyPDFLoader(file_path)
pages = []
# Top-level `async for` relies on the notebook kernel's support for top-level
# await; the same statement is a syntax error in a plain Python script.
async for page in loader.alazy_load():
    pages.append(page)

In [None]:
# Show the metadata, a blank line, then the extracted text of the first loaded item.
print(f"{pages[0].metadata}\n")
print(pages[0].page_content)

## Extracting text, images & tables with the Unstructured API

In [None]:
import getpass
import os

# Prompt for the key only when the environment does not already provide it,
# so the secret is never written into the notebook itself.
if "UNSTRUCTURED_API_KEY" not in os.environ:
    os.environ["UNSTRUCTURED_API_KEY"] = getpass.getpass("Unstructured API Key:")

In [None]:
from langchain_unstructured import UnstructuredLoader

# PDF to partition through the Unstructured API.
file_path = "data/scholar_corpus/0KCrVeiC5zEJ.pdf"

loader = UnstructuredLoader(
    file_path=file_path,
    strategy="hi_res",
    partition_via_api=True,
    coordinates=True,
)
# Drain the lazy iterator into a list in a single step.
docs = list(loader.lazy_load())

In [None]:
# Number of documents yielded by the loader.
print(len(docs))

In [None]:
from pprint import pprint
# Pretty-print the first document to inspect its structure.
pprint(docs[0])

In [None]:
# Keep only the documents whose metadata places them on page 1.
first_page_docs = []
for doc in docs:
    if doc.metadata.get("page_number") == 1:
        first_page_docs.append(doc)

# Print each document's text followed by an empty line.
for doc in first_page_docs:
    print(doc.page_content, end="\n\n")

In [None]:
## fzz: build an index keyed by element_id so that a document (and the
## coordinates fields kept in its metadata) can be looked up directly by id.
index_dict = {
    page_doc.metadata['element_id']: page_doc for page_doc in first_page_docs
}

In [None]:
# NOTE(review): the element_id below is hard-coded; it presumably belongs to the
# sample PDF loaded above, and the lookup raises KeyError when the id is not in
# index_dict — confirm before reusing with another file.
print(index_dict['c1ca1e8812f489e94f6cf80276f8bff6'])

In [None]:
### fzz: try to output the page_content pieces that should be joined together
def judge_end(st):
    """Heuristically decide whether a piece of text ends a complete paragraph.

    Args:
        st: Text of one (possibly merged) document element.

    Returns:
        True when the text, ignoring trailing whitespace, ends with '.', '?'
        or '!' and contains as many '(' as ')'. False otherwise, including
        for empty or whitespace-only input.
    """
    text = st.rstrip()
    # Empty input has no last character to inspect; report it as incomplete
    # instead of raising IndexError.
    if not text:
        return False
    # Paragraphs in a paper are expected to end with sentence punctuation.
    if not text.endswith((".", "?", "!")):
        return False
    # The final period may belong to a citation inside a parenthesis that is
    # still open, so also require the parentheses to be balanced (by count).
    return text.count("(") == text.count(")")

import copy

# Rebuild the first-page list, merging paragraph fragments that were split
# into separate elements. Elements with fewer than 15 space-separated words
# are passed through unchanged and never take part in a merge.
new_first_page_docs = []
pending_doc = None  # merged copy of a paragraph that has not reached its end yet
pending_ids = []    # element_ids of every fragment merged into pending_doc
for doc in first_page_docs:
    if len(doc.page_content.split(" ")) < 15:
        new_first_page_docs.append(doc)
        continue
    if pending_doc is None:
        if judge_end(doc.page_content):
            new_first_page_docs.append(doc)
        else:
            # Merge into a copy: the original element is shared with `docs`
            # and `index_dict`, so appending to it in place would alter those
            # too and make a re-run of this cell append the same text again.
            pending_doc = copy.deepcopy(doc)
            pending_ids = [doc.metadata['element_id']]
        continue
    # Put a space between fragments so the last word of one is not fused with
    # the first word of the next; a fragment ending in '-' is joined directly
    # (presumably a word hyphenated across a line break).
    separator = "" if pending_doc.page_content.endswith("-") else " "
    pending_doc.page_content += separator + doc.page_content
    pending_ids.append(doc.metadata['element_id'])
    pending_doc.metadata['element_id_ls'] = list(pending_ids)
    # Keep the list reachable through index_dict as well, under the id of the
    # first fragment. Assigning (not appending) keeps re-runs idempotent.
    index_dict[pending_ids[0]].metadata['element_id_ls'] = list(pending_ids)
    if judge_end(pending_doc.page_content):
        new_first_page_docs.append(pending_doc)
        pending_doc = None

# A paragraph that is still open after the last element (for example because
# it continues on the next page) would otherwise be lost; keep what we have.
if pending_doc is not None:
    new_first_page_docs.append(pending_doc)
    pending_doc = None


In [None]:
# Print every entry of the merged list, each followed by an empty line.
for doc in new_first_page_docs:
    print(doc.page_content, end="\n\n")

In [None]:
# Inspect a single document: its id, metadata, `type` attribute and Python class.
print(first_page_docs[0].id)
pprint(first_page_docs[0].metadata)
print(first_page_docs[0].type)
print(type(first_page_docs[0]))

In [None]:
# List the instance attribute names of the first document, skipping any that
# start with a double underscore.
for prop in vars(first_page_docs[0]):
    if not prop.startswith('__'):
        print(prop)
# first_page_docs[0].page_content

In [None]:
import fitz
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from PIL import Image


def plot_pdf_with_boxes(pdf_page, segments):
    """Draw a rendered PDF page with one outlined polygon per segment.

    Args:
        pdf_page: Page object providing ``get_pixmap()``; the pixmap's
            ``width``, ``height`` and ``samples`` are used to build the image.
        segments: Iterable of mappings, each with a ``"category"`` entry and a
            ``"coordinates"`` entry holding ``"points"``, ``"layout_width"``
            and ``"layout_height"``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    pixmap = pdf_page.get_pixmap()
    page_image = Image.frombytes(
        "RGB", [pixmap.width, pixmap.height], pixmap.samples
    )

    _, ax = plt.subplots(1, figsize=(10, 10))
    ax.imshow(page_image)

    # Categories with a dedicated colour; every other category uses text_color.
    highlight_colors = {
        "Title": "orchid",
        "Image": "forestgreen",
        "Table": "tomato",
    }
    text_color = "deepskyblue"
    seen_categories = set()

    for segment in segments:
        coords = segment["coordinates"]
        points = coords["points"]
        layout_w = coords["layout_width"]
        layout_h = coords["layout_height"]
        # Rescale each point from the layout's width/height to pixmap pixels.
        outline = [
            (x * pixmap.width / layout_w, y * pixmap.height / layout_h)
            for x, y in points
        ]
        category = segment["category"]
        seen_categories.add(category)
        ax.add_patch(
            patches.Polygon(
                outline,
                linewidth=1,
                edgecolor=highlight_colors.get(category, text_color),
                facecolor="none",
            )
        )

    # Legend: "Text" is always listed; highlighted categories only when drawn.
    legend_handles = [patches.Patch(color=text_color, label="Text")]
    legend_handles.extend(
        patches.Patch(color=color, label=name)
        for name, color in highlight_colors.items()
        if name in seen_categories
    )
    ax.axis("off")
    ax.legend(handles=legend_handles, loc="upper right")
    plt.tight_layout()
    plt.show()


def render_page(doc_list: list, page_number: int, print_text=True) -> None:
    """Plot one PDF page with its element boxes and optionally print the texts.

    Args:
        doc_list: Documents whose ``metadata`` holds a ``page_number`` plus
            the fields consumed by ``plot_pdf_with_boxes``.
        page_number: 1-based page number to render.
        print_text: When true, print each matching document's text after the
            plot.

    Note:
        The PDF is read from the module-level ``file_path``.
    """
    page_docs = [
        doc for doc in doc_list if doc.metadata.get("page_number") == page_number
    ]
    segments = [doc.metadata for doc in page_docs]
    # Open the PDF in a context manager so it is closed once the page has been
    # drawn; the page object is only used while the document is still open.
    with fitz.open(file_path) as pdf:
        plot_pdf_with_boxes(pdf.load_page(page_number - 1), segments)
    if print_text:
        for doc in page_docs:
            print(f"{doc.page_content}\n")

In [None]:
# Render page 1 with its boxes and print the text of its documents.
render_page(docs,1)

## Transform into Embeddings

In [None]:
from langchain_community.vectorstores.utils import filter_complex_metadata
# Filter out the coordinates metadata, which is not supported by the vector store.
filtered_docs = filter_complex_metadata(docs)
# NOTE(review): this prints every filtered document in full; consider showing
# only a slice to keep the output readable.
print(filtered_docs)

In [None]:
from langchain_community.embeddings import OpenAIEmbeddings
# Import Chroma from langchain_community, like the other LangChain imports in
# this notebook, instead of going through the deprecated `langchain.vectorstores`
# re-export.
from langchain_community.vectorstores import Chroma

# Embed the filtered documents with OpenAIEmbeddings and store them in a Chroma
# collection persisted in the "scholar_embeddings" directory.
chroma_store = Chroma.from_documents(
    documents=filtered_docs,
    embedding=OpenAIEmbeddings(),
    persist_directory="scholar_embeddings"
)

# No longer useful as docs are automatically persisted.
# https://python.langchain.com/api_reference/community/vectorstores/langchain_community.vectorstores.chroma.Chroma.html#langchain_community.vectorstores.chroma.Chroma.persist
# chroma_store.persist()

## Ask questions

In [None]:
# Similarity search with a bare keyword phrase.
chroma_store.search('Contrastive Decoding', search_type="similarity")

In [None]:
# Similarity search with the same topic phrased as a full question.
chroma_store.search('Does this paper mentions contrastive decoding?', search_type="similarity")

In [None]:
from langchain.chains import VectorDBQA
from langchain_community.chat_models import ChatOllama

# Chat model used to answer questions over the vector store.
model_local = ChatOllama(model="qwen:7b")
# NOTE(review): persist_directory is assigned here but not read anywhere in the
# visible notebook; the store built earlier is passed to the chain directly.
persist_directory = "scholar_embeddings"

# TODO: Needs refactoring; VectorDBQA is a deprecated class.
# https://python.langchain.com/api_reference/core/vectorstores/langchain_core.vectorstores.in_memory.InMemoryVectorStore.html#langchain_core.vectorstores.in_memory.InMemoryVectorStore
qa = VectorDBQA.from_chain_type(llm=model_local, chain_type="stuff", vectorstore=chroma_store)

In [None]:
# Ask the QA chain a question and print its answer.
query = "Does this paper mentions contrastive decoding?"
result = qa.run(query)

print(result)