In [2]:
# ====================================
# openai tokens
# ====================================

import os
import glob
import codecs
import pickle
import re
import textwrap
from collections import namedtuple

import openai
import faiss
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter, MarkdownTextSplitter
from langchain.vectorstores import FAISS
from pymongo import MongoClient
from kipp.utils import setup_logger

from sys import path

path.append("/opt/configs/ramjet")
import prd

# ----------------------------------------------
# Azure
# ----------------------------------------------
# os.environ['OPENAI_API_TYPE'] = "azure"
# os.environ['OPENAI_API_VERSION'] = prd.OPENAI_AZURE_VERSION
# os.environ['OPENAI_API_BASE'] = prd.OPENAI_AZURE_API
# os.environ['OPENAI_API_KEY'] = prd.OPENAI_AZURE_TOKEN

# openai.api_type = os.environ['OPENAI_API_TYPE']
# openai.api_version = os.environ['OPENAI_API_VERSION']
# openai.api_base = os.environ['OPENAI_API_BASE']
# openai.api_key = os.environ['OPENAI_API_KEY']

# azure_embeddings_deploymentid = prd.OPENAI_AZURE_DEPLOYMENTS["embeddings"].deployment_id
# azure_gpt_deploymentid = prd.OPENAI_AZURE_DEPLOYMENTS["chat"].deployment_id
# ----------------------------------------------

# ----------------------------------------------
# OpenAI
# ----------------------------------------------
os.environ["OPENAI_API_KEY"] = prd.OPENAI_TOKEN

openai.api_key = os.environ['OPENAI_API_KEY']
# ----------------------------------------------

Index = namedtuple("index", ["store", "scaned_files"])


def pretty_print(text: str) -> str:
    text = text.strip()
    return textwrap.fill(text, width=60, subsequent_indent="    ")


# =============================
# 定义文件路径
# =============================
name = "security"
logger = setup_logger(name)

index_dirpath = "/home/laisky/data/langchain/index-azure"
pdf_dirpath = f"/home/laisky/data/langchain/pdf/{name}"

for path in [index_dirpath, pdf_dirpath]:
    try:
        os.mkdir(path)
    except FileExistsError:
        pass

In [3]:
# ==============================================================
# prepare pdf documents docs.index & docs.store
#
# https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/pdf.html#retain-elements
#
# 通用的函数定义
# ==============================================================

from urllib.parse import quote

from langchain.document_loaders import PyPDFLoader
from langchain.vectorstores import FAISS
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter, MarkdownTextSplitter
from ramjet.tasks.gptchat.embedding.embeddings import reset_eof_of_pdf

text_splitter = CharacterTextSplitter(chunk_size=500, separator="\n")
markdown_splitter = MarkdownTextSplitter(chunk_size=500, chunk_overlap=50)

N_BACTCH_FILES = 5


def is_file_scaned(index: Index, fpath):
    return os.path.split(fpath)[1] in index.scaned_files


def embedding_pdfs(index: Index, fpaths, url, replace_by_url):
    i = 0
    docs = []
    metadatas = []
    for fpath in fpaths:
        fname = os.path.split(fpath)[1]
        if is_file_scaned(index, fname):
            continue

        try:
            reset_eof_of_pdf(fpath)
            loader = PyPDFLoader(fpath)
            for page, data in enumerate(loader.load_and_split()):
                splits = text_splitter.split_text(data.page_content)
                docs.extend(splits)
                for ichunk, _ in enumerate(splits):
                    fnameurl = quote(fpath.removeprefix(replace_by_url), safe="")
                    furl = url + fnameurl
                    metadatas.append({"source": f"{furl}#page={page+1}"})
        except Exception:
            logger.exception(f"skip file {fpath}")
            continue

        index.scaned_files.add(fname)
        print(f"scaned {fpath}")
        i += 1
        if i > N_BACTCH_FILES:
            break

    if i != 0:
        # fix stupid compatability issue in langchain faiss
        if not getattr(index.store, "_normalize_L2", None):
            index.store._normalize_L2 = False
            
        index.store.add_texts(docs, metadatas=metadatas)

    return i


def embedding_markdowns(index: Index, fpaths, url, replace_by_url):
    i = 0
    docs = []
    metadatas = []
    for fpath in fpaths:
        fname = os.path.split(fpath)[1]
        if is_file_scaned(index, fpath):
            continue

        with codecs.open(fpath, "rb", "utf8") as fp:
            docus = markdown_splitter.create_documents([fp.read()])
            for ichunk, docu in enumerate(docus):
                docs.append(docu.page_content)
                title = quote(docu.page_content.strip().split("\n", maxsplit=1)[0])
                if url:
                    fnameurl = quote(fpath.removeprefix(replace_by_url), safe="")
                    furl = url + fnameurl
                    metadatas.append({"source": f"{furl}#{title}"})
                else:
                    metadatas.append({"source": f"{fname}#{title}"})
                    
        index.scaned_files.add(fname)
        print(f"scaned {fpath}")
        i += 1
        if i > N_BACTCH_FILES:
            break

    if i != 0:
        index.store.add_texts(docs, metadatas=metadatas)

    return i


def load_store(dirpath, name) -> Index:
    """
    Args:
        dirpath: dirpath to store index files
        name: project/file name
    """
    if os.environ.get("OPENAI_API_TYPE") == "azure":
        azure_embeddings_deploymentid = prd.OPENAI_AZURE_DEPLOYMENTS[
            "embeddings"
        ].deployment_id
        # azure_gpt_deploymentid = prd.OPENAI_AZURE_DEPLOYMENTS["chat"].deployment_id

        embedding_model = OpenAIEmbeddings(
            client=None,
            model="text-embedding-ada-002",
            deployment=azure_embeddings_deploymentid,
        )
    else:
        embedding_model = OpenAIEmbeddings(
            client=None,
            model="text-embedding-ada-002",
        )
    
    index = faiss.read_index(f"{os.path.join(dirpath, name)}.index")
    with open(f"{os.path.join(dirpath, name)}.store", "rb") as f:
        store = pickle.load(f)
    store.index = index

    with open(f"{os.path.join(dirpath, name)}.scanedfile", "rb") as f:
        scaned_files = pickle.load(f)
        
    # compatable with azure/openai embeddings
    store.embedding_function = embedding_model.embed_query

    return Index(
        store=store,
        scaned_files=scaned_files,
    )


def new_store() -> Index:
    if os.environ.get("OPENAI_API_TYPE") == "azure":
        embedding_model = OpenAIEmbeddings(
            client=None,
            model="text-embedding-ada-002",
            deployment=azure_embeddings_deploymentid,
        )
    else:
        embedding_model = OpenAIEmbeddings(
            client=None,
            model="text-embedding-ada-002",
        )
        
    store = FAISS.from_texts(["world"], embedding_model, metadatas=[{"source": "hello"}])
    return Index(
        store=store,
        scaned_files=set([]),
    )


def save_store(index: Index, dirpath, name):
    store_index = index.store.index
    fpath_prefix = os.path.join(dirpath, name)
    print(f"save store to {fpath_prefix}")
    faiss.write_index(store_index, f"{fpath_prefix}.index")
    index.store.index = None
    with open(f"{fpath_prefix}.store", "wb") as f:
        pickle.dump(index.store, f)
    index.store.index = store_index

    with open(f"{fpath_prefix}.scanedfile", "wb") as f:
        pickle.dump(index.scaned_files, f)


In [4]:
# incremental scan pdfs
# /home/laisky/data/langchain/pdf/security

def gen_pdfs():
    yield from glob.glob(f"{pdf_dirpath}/**/*.pdf", recursive=True)

def run_scan_pdfs():
#     index = new_store()
#     save_store(
#         index=index, 
#         dirpath=index_dirpath, 
#         name=name,
#     )
    
    total = 0
    while True:
        index = load_store(
            dirpath=index_dirpath,
            name=name,
        )
        n = embedding_pdfs(
            index=index,
            fpaths=gen_pdfs(),
            url=f"https://s3.laisky.com/public/papers/{name}/",
            replace_by_url=f"/home/laisky/data/langchain/pdf/{name}/",
        )
        total += n
        save_store(
            index=index, 
            dirpath=index_dirpath, 
            name=name,
        )
        
#         return
        print(f"scanned {total} files")
        if n == 0:
            return
        
run_scan_pdfs()

[2023-07-24 03:06:36,189 - ERROR - /tmp/ipykernel_3350828/996360124.py:47 - security] - skip file /home/laisky/data/langchain/pdf/security/RFC2986_Certification Request Syntax Specification.pdf
Traceback (most recent call last):
  File "/tmp/ipykernel_3350828/996360124.py", line 39, in embedding_pdfs
    for page, data in enumerate(loader.load_and_split()):
  File "/home/laisky/.pyenv/versions/3.9.7/lib/python3.9/site-packages/langchain/document_loaders/base.py", line 43, in load_and_split
    docs = self.load()
  File "/home/laisky/.pyenv/versions/3.9.7/lib/python3.9/site-packages/langchain/document_loaders/pdf.py", line 118, in load
    return list(self.lazy_load())
  File "/home/laisky/.pyenv/versions/3.9.7/lib/python3.9/site-packages/langchain/document_loaders/pdf.py", line 125, in lazy_load
    yield from self.parser.parse(blob)
  File "/home/laisky/.pyenv/versions/3.9.7/lib/python3.9/site-packages/langchain/document_loaders/base.py", line 95, in parse
    return list(self.lazy_pa

Created a chunk of size 803, which is longer than the specified 500


scaned /home/laisky/data/langchain/pdf/security/TEE/Honeycomb- Secure and Efficient GPU Executions via Static Validation.pdf
scaned /home/laisky/data/langchain/pdf/security/TEE/TIO/Security Protocol and Data Model (SPDM) Specification.pdf
scaned /home/laisky/data/langchain/pdf/security/TEE/TIO/Secured Messages using SPDM Specification.pdf
scaned /home/laisky/data/langchain/pdf/security/TEE/TIO/Integrity and Data Encryption (IDE) ECN Deep Dive.pdf
save store to /home/laisky/data/langchain/index-azure/security
scanned 6 files
[2023-07-24 03:18:45,225 - ERROR - /tmp/ipykernel_3350828/996360124.py:47 - security] - skip file /home/laisky/data/langchain/pdf/security/RFC2986_Certification Request Syntax Specification.pdf
Traceback (most recent call last):
  File "/tmp/ipykernel_3350828/996360124.py", line 39, in embedding_pdfs
    for page, data in enumerate(loader.load_and_split()):
  File "/home/laisky/.pyenv/versions/3.9.7/lib/python3.9/site-packages/langchain/document_loaders/base.py", li

In [5]:
# ====================================
# 生成用于问答的 query chain
# ====================================

from langchain.chains import VectorDBQAWithSourcesChain, RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.chains.question_answering import load_qa_chain
from langchain.chains import LLMChain

system_template="""Use the following pieces of context to answer the users question.
Take note of the sources and include them in the answer in the format: "SOURCES: source1 source2", use "SOURCES" in capital letters regardless of the number of sources.
If you don't know the answer, just say that "I don't know", don't try to make up an answer.
----------------
{summaries}"""
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}")
]
prompt = ChatPromptTemplate.from_messages(messages)

index = load_store(
    dirpath=index_dirpath,
    name=name,
)

llm = ChatOpenAI(
    client=None,
    model="gpt-3.5-turbo", 
    temperature=0, 
    max_tokens=2000,
    streaming=False,
)  

# llm = AzureChatOpenAI(
#     deployment_name=azure_gpt_deploymentid,
#     model_name="gpt-3.5-turbo",
#     max_tokens=2000,
# )

chain = load_qa_chain(llm, chain_type="stuff")
# chain = LLMChain(llm=llm, prompt=prompt)

In [6]:
# ====================================
# ask pdf embeddings
# ====================================
query = "list tpm's features"


related_docs = index.store.similarity_search(
    query=query,
    k=5,
)

response = chain.run(
#     context=';'.join([d.page_content for d in related_docs]), 
    input_documents=related_docs,
    question=query,
)

print(f"🤔️: {query}\n")
print(f"📖: {response}\n")

🤔️: list tpm's features

📖: Some of the features of TPM (Trusted Platform Module) include:

1. Validation of acceptable integrity metrics: By isolating processes, the set of acceptable platform configurations can be reduced to one operating system and one application only.

2. TPM Components: The TPM can be implemented as an IC (Integrated Circuit) or in software. Most commercially available implementations are hardware-based.

3. Enhanced functionality: TPM 2.0 has been enhanced from TPM 1.2 to support more platforms. This includes adding encryption algorithms, enhancing availability for applications, enhancing authentication features, simplifying TPM management, and adding features that enhance the security of platform services.

4. Monotonic counters: TPM provides secure mechanisms, such as monotonic counters, to prevent replay attacks.

5. Time-stamping: TPM provides the ability to measure time intervals, although absolute measurement of time is not possible.

6. Audit trail manage

In [86]:
# ====================================
# manually qa based on embedidngs step by step
# ====================================
from typing import List
import re

from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage
)
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain


system_template="""Use the following pieces of context to answer the users question.
If you don't know the answer, or you think more information is needed to provide a better answer, 
just say in this strict format: "I need more informations about: [list keywords that will be used to search more informations]" to ask more informations, 
don't try to make up an answer.
----------------
context: {summaries}"""
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}")
]
prompt = ChatPromptTemplate.from_messages(messages)


def query_for_more_info(query: str) -> str:
    related_docs = index.store.similarity_search(
        query=query,
        k=5,
    )

    return "; ".join([d.page_content for d in related_docs]) 


chain = LLMChain(llm=llm, prompt=prompt)


query = "how to measure host os in vm by vtpm"


n = 0
last_sub_query = ""
regexp = re.compile(r'I need more information about "([^"]+)"')
while n<3: 
    n += 1
    resp = chain.run({
        "summaries": query_for_more_info(query),
        "question": query,
    })
    matched = regexp.findall(resp)
    if len(matched) == 0:
        break
        
    sub_query = matched[0]
    if sub_query == last_sub_query:
        break
    last_sub_query = sub_query
    
    print(f"require more informations about: {sub_query}")
    query += f"; {query_for_more_info(sub_query)}"
    
print(resp)

To measure the host OS in a virtual machine (VM) using a virtual Trusted Platform Module (vTPM), you can follow these steps:

1. Ensure that your VM has a vTPM enabled. This can usually be done through the VM settings or configuration options provided by your virtualization software.

2. Install the necessary software and drivers for the vTPM in both the host OS and the guest OS. This may involve installing specific TPM device drivers or modules.

3. Once the vTPM is successfully emulated in the VM, you can access the TPM device in the guest OS. In the provided context, it mentions that a device named "tpm0" will be created under "/sys/class/misc/" in the guest OS.

4. Use the appropriate TPM management tools or APIs in the guest OS to measure the host OS. These tools will typically provide functions to read and write to the Platform Configuration Registers (PCR) in the TPM.

5. In the host OS, maintain a log that indicates what has been measured into each PCR register. This log can be

In [70]:
cnt = 'I need more information about "Amber GPU-CC" to provide an accurate answer. Could you please provide more context or clarify your question?'

regexp = re.compile(r'I need more information about "([^\)]+)"')

regexp.findall(cnt)

['Amber GPU-CC']

In [56]:
# ====================================
# use vectore store in functions(agents)
# ====================================

from langchain.agents import initialize_agent, Tool
from langchain.agents import AgentType
from langchain.tools import BaseTool
from langchain.llms import OpenAI
from langchain import LLMMathChain, SerpAPIWrapper


related_docs = index.store.similarity_search(
    query=query,
    k=5,
)


def query_for_agent(query: str) -> str:
    related_docs = index.store.similarity_search(
        query=query,
        k=5,
    )

    return "\n".join([d.page_content for d in related_docs])

def context_for_agent(query: str) -> str:
    related_docs = index.store.similarity_search(
        query=query,
        k=5,
    )
        
    response = chain.run(
        input_documents=related_docs,
        question=query,
    )
    
    return response
    
    
tools = [
    Tool(
        name="Search",
        func=query_for_agent,
        description="useful for when you need to answer questions, this function takes a string as input and returns a string. This function is capable of vectorizing the input string and searching for similar information in a vector database. Your AI can call this function to retrieve the data it needs based on its requirements.",
    ),
]

query = "what is tee-io"

agent = initialize_agent(
#     tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
    tools, llm, agent=AgentType.OPENAI_FUNCTIONS, verbose=True,
)

agent.run(
    input_documents=related_docs,
    question=query,
)



ValueError: Missing some input keys: {'input'}

In [51]:
agent.run??