In [None]:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

# --- Milvus connection settings -------------------------------------------
HOST = "milvus"
PORT = "19530"

# Name of the collection that holds the text chunks and their vectors.
COLLECTION_NAME = "TextSearchCollection"

# Index configuration passed to Collection.create_index for the vector field.
INDEX_PARAMS = dict(
    metric_type="COSINE",
    index_type="IVF_FLAT",
    params=dict(nlist=4),
)

# Vector size of the embeddings ("OpenAI Standard" per the original note).
# NOTE: the misspelled name is kept because other cells refer to it.
DIMENSTION = 1536

# Open the default connection used by every pymilvus call in this notebook.
connections.connect(host=HOST, port=PORT)

In [None]:
def create_milvus_collection(collection_name, dim):
    """Create (or recreate) a Milvus collection for text chunks and index it.

    If a collection named ``collection_name`` already exists it is dropped
    first, so everything previously stored under that name is lost.

    Args:
        collection_name: Name of the collection to create.
        dim: Dimensionality of the vectors stored in the ``embedding`` field.

    Returns:
        The newly created ``Collection``, with an index built on the
        ``embedding`` field according to ``INDEX_PARAMS``. The collection is
        not loaded here.
    """
    # Start from a clean slate: remove any previous collection with this name.
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    fields = [
        # Primary key generated by Milvus (auto_id=True).
        FieldSchema(
            name='id',
            dtype=DataType.INT64,
            description='id',
            is_primary=True,
            auto_id=True,
        ),
        # Vector field; its size comes from the `dim` argument rather than a
        # module-level constant so the function honours what the caller passes.
        FieldSchema(
            name='embedding',
            dtype=DataType.FLOAT_VECTOR,
            description='embedding vectors',
            dim=dim,
        ),
        # Raw text of the chunk the vector was computed from.
        FieldSchema(name='content', dtype=DataType.VARCHAR, max_length=4096),
    ]
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=collection_name, schema=schema)

    collection.create_index(field_name="embedding", index_params=INDEX_PARAMS)
    return collection

# Build a fresh collection and load it; searching is only possible once the
# collection has been loaded.
collection = create_milvus_collection(
    collection_name=COLLECTION_NAME,
    dim=DIMENSTION,
)
collection.load()

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.document import Document

from langchain_openai import OpenAIEmbeddings 

def handle_text(text: str, chunk_size: int = 400):
    """Split ``text`` into overlapping chunks wrapped as LangChain documents.

    Args:
        text: The full text to split.
        chunk_size: Maximum chunk size handed to the splitter. Defaults to
            400, the value that was previously hard-coded. Consecutive chunks
            overlap by ``chunk_size // 5``.

    Returns:
        A list of ``Document`` objects, one per chunk, in the order the
        splitter produced them.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_size // 5,
    )

    return [Document(page_content=chunk) for chunk in text_splitter.split_text(text)]

def parse_content(text):
    """Split ``text`` into chunks and pair every chunk with its embedding.

    Args:
        text: The full text to process.

    Returns:
        A list of dicts, one per chunk, each with the keys ``'embedding'``
        (the vector returned by the embedding model) and ``'content'`` (the
        chunk text). The list is empty when the text yields no chunks.
    """
    documents = handle_text(text)
    if not documents:
        # Nothing to embed; skip the embedding request entirely.
        return []

    contents = [d.page_content for d in documents]

    # Embed all chunks in a single embed_documents call instead of issuing one
    # call per chunk; the result list is parallel to `contents`.
    embedding_model = OpenAIEmbeddings()
    embeddings = embedding_model.embed_documents(contents)

    return [
        {
            'embedding': embedding,
            'content': content,
        } for embedding, content in zip(embeddings, contents)
    ]

# Read the whole article. The encoding is passed explicitly so the result does
# not depend on the platform's default encoding.
# NOTE(review): assumes the file is stored as UTF-8 — adjust if it is not.
with open('./articles/天劍一夢.txt', 'r', encoding='utf-8') as file:
    text = file.read()

# Chunking, embedding and insertion happen after the file has been closed;
# none of them needs the file handle.
entities = parse_content(text)
if entities:
    print(entities[0])  # preview of the first row that will be inserted
    collection.insert(entities)
else:
    # Avoid an IndexError on entities[0] and an empty insert.
    print('No chunks were produced; nothing was inserted.')

In [None]:
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import PromptTemplate

# ===== Step 1: Embedding ===== #
# collection.search expects a list of query vectors, hence the wrapping list.
query = '天劍門歷史？'
query_embedding = [OpenAIEmbeddings().embed_query(query)]


# ===== Step 2: Search knowledge ===== #
# Fetch the 5 nearest chunks for each query vector, returning id and content.
search_result = collection.search(
    data=query_embedding,
    anns_field='embedding',
    limit=5,
    param={
        'metric_type': 'COSINE',
        'params': {},
    },
    output_fields=['id', 'content'],
)

# search_result holds one group of hits per query vector.
for hits in search_result:
    for hit in hits:
        print(hit.get('id'), hit.get('content'))


# ====== Step 3: provide context to LLM ===== #
# The outer loop (result groups) must come first in the comprehension.
# With the clauses reversed, the expression iterated over the `hits` variable
# left behind by the print loop above instead of over every group in
# search_result.
context = '\n'.join(
    hit.get('content')
    for hits in search_result
    for hit in hits
)
parser = StrOutputParser()
prompt = PromptTemplate(
    template='請根據以下資料回答使用者問題，只能依據資料回答，不得捏造. 問題:{query}. 資料:{context}',
    input_variables=['query', 'context'],
)
llm = ChatOpenAI(model='gpt-3.5-turbo')

# prompt -> chat model -> plain-string output.
chain = prompt | llm | parser
response = chain.invoke({'query': query, 'context': context})
print(response)
