In [2]:
# ====================================
# openai tokens
# ====================================

import os
import glob
import codecs
import pickle
import re
import textwrap
from collections import namedtuple

import faiss
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter, MarkdownTextSplitter
from langchain.vectorstores import FAISS
from pymongo import MongoClient

from sys import path

path.append("/opt/configs/ramjet")
import prd

os.environ["OPENAI_API_KEY"] = prd.OPENAI_TOKEN_ME


Index = namedtuple("index", ["store", "scaned_files"])


def pretty_print(text: str) -> str:
    text = text.strip()
    return textwrap.fill(text, width=60, subsequent_indent="    ")



In [22]:
# ==============================================================
# prepare pdf documents docs.index & docs.store
#
# https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/pdf.html#retain-elements
#
# 通用的函数定义
# ==============================================================

from urllib.parse import quote

from langchain.document_loaders import PyPDFLoader

# from langchain.document_loaders import UnstructuredPDFLoader
from langchain.vectorstores import FAISS
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter, MarkdownTextSplitter

text_splitter = CharacterTextSplitter(chunk_size=500, separator="\n")
markdown_splitter = MarkdownTextSplitter(chunk_size=500, chunk_overlap=50)

N_BACTCH_FILES = 5


def is_file_scaned(index: Index, fpath):
    return os.path.split(fpath)[1] in index.scaned_files


def embedding_pdfs(index: Index, fpaths, url, replace_by_url):
    i = 0
    docs = []
    metadatas = []
    for fpath in fpaths:
        fname = os.path.split(fpath)[1]
        if is_file_scaned(index, fname):
            continue

        loader = PyPDFLoader(fpath)
        for page, data in enumerate(loader.load_and_split()):
            splits = text_splitter.split_text(data.page_content)
            docs.extend(splits)
            for ichunk, _ in enumerate(splits):
                fnameurl = quote(fpath.removeprefix(replace_by_url), safe="")
                furl = url + fnameurl
                metadatas.append({"source": f"{furl}#page={page+1}"})

        index.scaned_files.add(fname)
        print(f"scaned {fpath}")
        i += 1
        if i > N_BACTCH_FILES:
            break

    if i != 0:
        index.store.add_texts(docs, metadatas=metadatas)

    return i


def embedding_markdowns(index: Index, fpaths, url, replace_by_url):
    i = 0
    docs = []
    metadatas = []
    for fpath in fpaths:
        fname = os.path.split(fpath)[1]
        if is_file_scaned(index, fpath):
            continue

        with codecs.open(fpath, "rb", "utf8") as fp:
            docus = markdown_splitter.create_documents([fp.read()])
            for ichunk, docu in enumerate(docus):
                docs.append(docu.page_content)
                title = quote(docu.page_content.strip().split("\n", maxsplit=1)[0])
                if url:
                    fnameurl = quote(fpath.removeprefix(replace_by_url), safe="")
                    furl = url + fnameurl
                    metadatas.append({"source": f"{furl}#{title}"})
                else:
                    metadatas.append({"source": f"{fname}#{title}"})
                    
        index.scaned_files.add(fname)
        print(f"scaned {fpath}")
        i += 1
        if i > N_BACTCH_FILES:
            break

    if i != 0:
        index.store.add_texts(docs, metadatas=metadatas)

    return i


def load_store(dirpath, name) -> Index:
    """
    Args:
        dirpath: dirpath to store index files
        name: project/file name
    """
    index = faiss.read_index(f"{os.path.join(dirpath, name)}.index")
    with open(f"{os.path.join(dirpath, name)}.store", "rb") as f:
        store = pickle.load(f)
    store.index = index

    with open(f"{os.path.join(dirpath, name)}.scanedfile", "rb") as f:
        scaned_files = pickle.load(f)

    return Index(
        store=store,
        scaned_files=scaned_files,
    )


def new_store() -> Index:
    store = FAISS.from_texts(["world"], OpenAIEmbeddings(), metadatas=[{"source": "hello"}])
    return Index(
        store=store,
        scaned_files=set([]),
    )


def save_store(index: Index, dirpath, name):
    store_index = index.store.index
    fpath_prefix = os.path.join(dirpath, name)
    print(f"save store to {fpath_prefix}")
    faiss.write_index(store_index, f"{fpath_prefix}.index")
    index.store.index = None
    with open(f"{fpath_prefix}.store", "wb") as f:
        pickle.dump(index.store, f)
    index.store.index = store_index

    with open(f"{fpath_prefix}.scanedfile", "wb") as f:
        pickle.dump(index.scaned_files, f)


In [None]:
# incremental scan pdfs

def gen_pdfs():
    yield from glob.glob("/home/laisky/data/langchain/pdf/security/**/*.pdf", recursive=True)

def run_scan_pdfs():
#     index = new_store()
    total = 0
    while True:
        index = load_store(
            dirpath="/home/laisky/data/langchain/index",
            name="security",
        )
        n = embedding_pdfs(
            index=index,
            fpaths=gen_pdfs(),
            url="https://s3.laisky.com/public/papers/security/",
            replace_by_url="/home/laisky/data/langchain/pdf/security/",
        )
        total += n
        save_store(
            index=index, 
            dirpath="/home/laisky/data/langchain/index", 
            name="security",
        )
        
        print(f"scanned {total} files")
        if n == 0:
            return
        
run_scan_pdfs()

scaned /home/laisky/data/langchain/pdf/security/TEE/The trusted computer system evaluation criteria.pdf
scaned /home/laisky/data/langchain/pdf/security/TEE/TPM/TPM 2.0 - Part4_SuppRoutines_pub.pdf
scaned /home/laisky/data/langchain/pdf/security/TEE/TPM/TPM 2.0 - Part2_Structures_pub.pdf
scaned /home/laisky/data/langchain/pdf/security/TEE/TPM/TCG PC Client Platform TPM Profile Specification for TPM 2.0.pdf
scaned /home/laisky/data/langchain/pdf/security/TEE/TPM/TPM 2.0 Keys for Device Identity and Attestation.pdf
scaned /home/laisky/data/langchain/pdf/security/TEE/TPM/TPM 2.0 - Part4_SuppRoutines_code_pub.pdf


In [None]:
# incremental scan markdowns

def gen_markdowns():
    yield "/home/laisky/data/langchain/basebit/doc/content/terms.md"
    yield from glob.glob("/home/laisky/data/langchain/basebit/doc/content/research/**/*.md", recursive=True)
    

def run_scan_markdowns():
#         index = new_store()
    while True:
        index = load_store(
            dirpath="/home/laisky/data/langchain/index",
            name="security",
        )
        files = gen_markdowns()
        n = embedding_markdowns(
            index=index,
            fpaths=files,
            url="https://s3.laisky.com/public/papers/security/",
            replace_by_url="/home/laisky/data/langchain/pdf/security/",
        )
        save_store(
            index=index,
            dirpath="/home/laisky/data/langchain/index/", 
            name="security",
        )
        
        print(f"{n=}")
        if n == 0:
            return
        
        
run_scan_markdowns()

In [36]:
# ====================================
# 生成用于问答的 query chain
# ====================================

from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

system_template="""Use the following pieces of context to answer the users question.
Take note of the sources and include them in the answer in the format: "SOURCES: source1 source2", use "SOURCES" in capital letters regardless of the number of sources.
If you don't know the answer, just say that "I don't know", don't try to make up an answer.
----------------
{summaries}"""

messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}")
]
prompt = ChatPromptTemplate.from_messages(messages)

from langchain.chat_models import ChatOpenAI
from langchain.chains import ChatVectorDBChain

chain_type_kwargs = {"prompt": prompt}
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo", 
    temperature=0, 
    max_tokens=2000)  
chain = VectorDBQAWithSourcesChain.from_chain_type(
    llm=llm,
    vectorstore=index.store,
    return_source_documents=True,
    chain_type_kwargs=chain_type_kwargs,
    reduce_k_below_max_tokens=True,
)

In [37]:
# ====================================
# ask pdf embeddings
# ====================================
question = "TPM 的 PCR 是啥"
result = chain(
    {
        "question": question,
    },
    return_only_outputs=True,
)

print(f"🤔️: {question}\n")
print(f"🤖️: {pretty_print(result['answer'])}\n")
print(f"📖: {result['sources']}")

🤔️: TPM 的 PCR 是啥

🤖️: TPM 的 PCR 是指 Trusted Platform Module 的 Platform
    Configuration Registers，是一种用于记录计算机系统状态的寄存器。TPM 初始化时会将所有
    PCR 的值设置为 0 或 1，之后记录系统启动过程中的关键软件和配置信息的哈希值。PCR
    可以用于验证系统的完整性和安全性，以及实现安全策略。

📖: https://s3.laisky.com/public/papers/security/TEE%2FTPM%2FTPM%202.0%20-%20Part1-%20Architecture.pdf#page=21, https://s3.laisky.com/public/papers/security/TEE%2FTPM%2FA%20Practical%20Guide%20to%20TPM%202.0.pdf#page=156, https://s3.laisky.com/public/papers/security/TEE%2FTPM%2FTPM%202.0%20-%20Part1-%20Architecture.pdf#page=133
