# Импортируем необходимые методы

In [None]:
import os
import re
import pickle
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
from pypdf import PdfReader
from typing import List
import matplotlib.pyplot as plt
import google.generativeai as genai
from langchain_milvus.retrievers import MilvusCollectionHybridSearchRetriever
from langchain_milvus.utils.sparse import BM25SparseEmbedding
from pymilvus import (
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    WeightedRanker,
    connections,
    utility
)

from langchain.embeddings.base import Embeddings
from sentence_transformers import SentenceTransformer

# Address of the Milvus server used by this notebook.
CONNECTION_URI = "http://localhost:19530"
# Connect to that server as soon as this cell runs.
connections.connect(uri=CONNECTION_URI)

In [10]:
def load_pdf_file(file_path: str) -> str:
    """
    Loads text from a PDF file and returns it as a single string.

    Parameters:
    - file_path (str): Path to the PDF file.

    Returns:
    - str: The extracted text of all pages, in page order, with line breaks
      (including the breaks between pages) replaced by spaces.
    """
    reader = PdfReader(file_path)
    # `or ""` is a defensive guard in case a page yields no text object.
    page_texts = [page.extract_text() or "" for page in reader.pages]

    # Join pages with a line break so the last word of one page is not fused
    # with the first word of the next; the break then becomes a space below.
    return "\n".join(page_texts).replace('\n', ' ')


def split_text_to_token(text: str) -> List[str]:
    """
    Splits a given text into smaller chunks based on sentences and token limits.

    The text is first cut at every ". " (period followed by a space); each
    non-empty piece is then passed through a token splitter configured with
    256 tokens per chunk and no overlap.

    Parameters:
    - text (str): The full text to be split.

    Returns:
    - List[str]: A list of text chunks, each within the token limit.
    """
    # Raw string: '\.' in a plain literal is an invalid escape sequence and
    # triggers a warning on recent Python versions.
    sentences = re.split(r'\. ', text)
    token_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0, tokens_per_chunk=256)

    token_split_texts = []
    for sentence in sentences:
        # Consecutive or trailing separators produce empty pieces; skip them.
        if sentence:
            token_split_texts.extend(token_splitter.split_text(sentence))

    return token_split_texts

class SentenceTransformerEmbeddings(Embeddings):
    """Embeddings adapter that delegates to a ``SentenceTransformer`` model."""

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Instantiate the sentence-transformers model named ``model_name``."""
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, documents):
        """Encode a batch of texts and return the result of ``model.encode``."""
        encoded_batch = self.model.encode(documents)
        return encoded_batch

    def embed_query(self, query):
        """Encode one query as a single-item batch and return its first row."""
        single_item_batch = [query]
        encoded_batch = self.model.encode(single_item_batch)
        return encoded_batch[0]
    
def get_hybrid_function(
        text_field: str, 
        config_path: str,
        k: int=5, 
        metric: str='IP'
    ) -> MilvusCollectionHybridSearchRetriever:
    """
    Builds a hybrid (dense + sparse) retriever from a pickled configuration.

    Parameters:
    - text_field (str): Name of the collection field holding the raw text.
    - config_path (str): Path to the pickle file holding the collection name,
      the dense/sparse field names and the two embedding functions.
    - k (int): Number of results requested from the retriever (``top_k``).
    - metric (str): Metric type used for both the dense and the sparse search.

    Returns:
    - MilvusCollectionHybridSearchRetriever: Retriever over the loaded
      collection, reranking dense and sparse hits with equal weights (0.5 / 0.5).
    """
    # NOTE: unpickling can execute arbitrary code; only load config files
    # that were produced locally by a trusted run.
    with open(config_path, "rb") as config_file:
        settings = pickle.load(config_file)

    # The key is spelled 'colection_name' on purpose: it must match the key
    # stored in the pickle file.
    target_collection = Collection(settings['colection_name'])
    target_collection.load()

    dense_params = {"metric_type": metric, "params": {}}
    sparse_params = {"metric_type": metric}

    # The order dense -> sparse is kept aligned across the three list arguments.
    return MilvusCollectionHybridSearchRetriever(
        collection=target_collection,
        rerank=WeightedRanker(0.5, 0.5),
        anns_fields=[settings['dense_field'], settings['sparse_field']],
        field_embeddings=[settings['dense_embedding_func'], settings['sparse_embedding_func']],
        field_search_params=[dense_params, sparse_params],
        top_k=k,
        text_field=text_field,
        nprobe=20,
    )


def save_params_for_retriever(
    collection: Collection, 
    dense_field: str,
    sparse_field: str,
    dense_embedding_func: SentenceTransformerEmbeddings,
    sparse_embedding_func: BM25SparseEmbedding,
    config_path: str = "config_params.pkl"
    ) -> None:
    """
    Pickles everything needed to rebuild a hybrid retriever later.

    Parameters:
    - collection (Collection): Collection whose name is stored in the config.
    - dense_field (str): Name of the dense-vector field.
    - sparse_field (str): Name of the sparse-vector field.
    - dense_embedding_func (SentenceTransformerEmbeddings): Dense embedder to store.
    - sparse_embedding_func (BM25SparseEmbedding): Sparse embedder to store.
    - config_path (str): Destination file. Defaults to "config_params.pkl",
      the path that was previously hard-coded, so existing callers are unaffected.

    Returns:
    - None. The configuration is written to ``config_path`` (overwriting it).
    """
    # The key is spelled 'colection_name' on purpose: get_hybrid_function
    # reads the config back under exactly this key.
    config = {
        'colection_name': collection.name,
        'dense_field': dense_field,
        'sparse_field': sparse_field,
        'dense_embedding_func': dense_embedding_func,
        'sparse_embedding_func': sparse_embedding_func,
    }

    with open(config_path, "wb") as f:
        pickle.dump(config, f)

  split_text = re.split('\. ', text)


# Create MILVUS Database

In [None]:
# Drop the collections left over from previous runs

collections = utility.list_collections()

# Drop every collection in the list
for collection_name in collections:
    collection = Collection(name=collection_name)
    collection.drop()
    print(f"Коллекция {collection_name} удалена.")

print("Все коллекции удалены.")

Коллекция Pushkin удалена.
Все коллекции удалены.


In [5]:
# Field names shared by the schema, the index creation and the insert cells.
pk_field = "doc_id"
dense_field = "dense_vector"
sparse_field = "sparse_vector"
text_field = "text"
# Schema fields: primary key, dense vector, sparse vector and the raw text.
fields = [
    # String primary key; auto_id=True, so inserted entities do not supply it.
    FieldSchema(
        name=pk_field,
        dtype=DataType.VARCHAR,
        is_primary=True,
        auto_id=True,
        max_length=100,
    ),
    # NOTE(review): dim=384 presumably matches the dense embedding model's output size — confirm.
    FieldSchema(name=dense_field, dtype=DataType.FLOAT_VECTOR, dim=384),
    FieldSchema(name=sparse_field, dtype=DataType.SPARSE_FLOAT_VECTOR),
    # Raw chunk text, stored next to its vectors.
    FieldSchema(name=text_field, dtype=DataType.VARCHAR, max_length=65_535),
]

In [6]:
# Build the schema from the fields above; dynamic fields are disabled.
schema = CollectionSchema(fields=fields, enable_dynamic_field=False)
# Bind `collection` to the "Moscow" collection using this schema and strong consistency.
collection = Collection(
    name="Moscow", schema=schema, consistency_level="Strong"
)

In [7]:
# Index both vector fields with the inner-product ("IP") metric.
# The field names come from the same variables that define the schema, so the
# index targets cannot drift from the schema if a field is renamed.
dense_index = {"index_type": "FLAT", "metric_type": "IP"}
collection.create_index(dense_field, dense_index)
sparse_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}
collection.create_index(sparse_field, sparse_index)
collection.flush()

In [39]:
# Read the source PDF into one string, then cut it into token-limited chunks.
pdf_text = load_pdf_file(file_path='dataset/skazka_o_rubake_i_rubke.pdf')
chunked_text = split_text_to_token(text=pdf_text)

In [10]:
# Dense embedder with its default model name.
dense_embedding_func = SentenceTransformerEmbeddings()
# Sparse BM25 embedder, constructed from the chunk corpus.
sparse_embedding_func = BM25SparseEmbedding(corpus=chunked_text)

In [11]:
# Embed the whole corpus with one call per embedder instead of two calls per
# chunk; both embedders accept a list of texts and return one vector per text.
dense_vectors = dense_embedding_func.embed_documents(chunked_text)
sparse_vectors = sparse_embedding_func.embed_documents(chunked_text)

# One entity per chunk: its dense vector, its sparse vector and the raw text.
entities = [
    {
        dense_field: dense_vector,
        sparse_field: sparse_vector,
        text_field: text,
    }
    for text, dense_vector, sparse_vector in zip(chunked_text, dense_vectors, sparse_vectors)
]
collection.insert(entities)
collection.load()

In [12]:
# Persist the collection name, field names and both embedders so that the
# retriever can be rebuilt later from the pickle file.
save_params_for_retriever(
    collection=collection, 
    dense_field=dense_field, 
    sparse_field=sparse_field,
    dense_embedding_func=dense_embedding_func,
    sparse_embedding_func=sparse_embedding_func
)

# Используем MILVUS для поиска

In [None]:
# Name used to pick the saved retriever config file in the next cell.
# NOTE(review): the collection created earlier in this notebook is named "Moscow" — confirm which one is intended.
name_of_collection = 'Pushkin'

In [None]:
# Rebuild the hybrid retriever from the saved config of the chosen collection.
# The keyword must be `config_path` (the parameter name declared by
# get_hybrid_function); `config_name` raised a TypeError.
retriever = get_hybrid_function(
    text_field='text', 
    config_path=f'config_params_{name_of_collection}.pkl'
)

In [12]:
def make_rag_prompt(query:str, relevant_passage:str) -> str:
    """
    Builds the prompt sent to the language model for a RAG answer.

    Parameters:
    - query (str): The user's question, inserted verbatim.
    - relevant_passage (str): Retrieved context. Single and double quotes are
      removed and line breaks are replaced by spaces before insertion.

    Returns:
    - str: The prompt, with the context wrapped in <context>...</context> and
      the question wrapped in <question>...</question>.
    """
    # One pass over the passage: drop both quote characters, flatten newlines.
    cleanup_table = str.maketrans({"'": None, '"': None, "\n": " "})
    escaped = relevant_passage.translate(cleanup_table)
    # The instructions name both the opening and the closing tag of each
    # section so that they match the delimiters actually used below.
    prompt = f"""
    Human: You are a helpful and informative bot that answers questions using text from the passage below. 
    Be sure to answer in a full sentence, exhaustively, including all relevant background information.
    However, you are speaking to a non-technical audience, so be sure to break down complex concepts and
    keep your tone friendly and accommodating.
    If the passage is not relevant to the answer, you can ignore it.
    Context is given between <context> and </context>.
    The question is given between <question> and </question>.
    Thank you!

    <context>
    {escaped}
    </context>

    <question>
    {query}
    </question>

    Assistant:
    """

    return prompt

In [13]:
def get_gpt_model(model_name: str = 'gemini-pro') -> genai.GenerativeModel:
    """
    Configures the Gemini client from the environment and returns a model.

    Parameters:
    - model_name (str): Name of the generative model to instantiate. Defaults
      to 'gemini-pro', the value that was previously hard-coded.

    Returns:
    - genai.GenerativeModel: Model object for ``model_name``.

    Raises:
    - ValueError: If the GEMINI_API_KEY environment variable is missing or empty.
    """
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        raise ValueError("Gemini API Key not provided. Please provide GEMINI_API_KEY as an environment variable")
    genai.configure(api_key=gemini_api_key)

    return genai.GenerativeModel(model_name)

In [None]:
def generate_answer_promt(query:str, model: genai.GenerativeModel, collection:MilvusCollectionHybridSearchRetriever) -> str:
    """
    Answers a question using passages retrieved for it.

    Parameters:
    - query (str): The user's question.
    - model (genai.GenerativeModel): Model used to generate the answer.
    - collection (MilvusCollectionHybridSearchRetriever): Retriever invoked with
      the query (despite the parameter name, this is a retriever object).

    Returns:
    - str: The text of the generated answer.
    """
    similar_docs = collection.invoke(query)

    # Upper-case only the first character of each passage. `str.capitalize()`
    # would also lower-case the rest and corrupt names and acronyms.
    passages = [
        doc.page_content[:1].upper() + doc.page_content[1:]
        for doc in similar_docs
    ]
    relevant_text = ". ".join(passages)

    prompt = make_rag_prompt(query=query, relevant_passage=relevant_text)

    answer = model.generate_content(prompt)

    return answer.text

In [None]:
# Build the Gemini model, then answer one sample question through the retriever.
model = get_gpt_model()
query="What are the components of pain signals?"
answer = generate_answer_promt(query=query, model=model, collection=retriever)
# Bare expression on the last line of the cell.
answer

'Pain signals are made up of two components: first pain and second pain. First pain is rapidly transmitted and has high spatial resolution, meaning it can be precisely localized. Second pain is much slower, poorly localized, and poorly tolerated.'

In [18]:
# collection.drop()

In [86]:
# embeddings = db.get(include=['embeddings'])['embeddings']
# umap_transform = umap.UMAP(random_state=0, transform_seed=0).fit(embeddings)

In [87]:
# def project_embeddings(embeddings, umap_transform):
#     umap_embeddings = np.empty((len(embeddings),2))
#     for i, embedding in enumerate(tqdm(embeddings)): 
#         umap_embeddings[i] = umap_transform.transform([embedding])
#     return umap_embeddings   

# projected_dataset_embeddings = project_embeddings(embeddings, umap_transform)

In [88]:
# query_embedding = embedding_function([query])[0]
# retrieved_embeddings = relevant_text['embeddings'][0]

# projected_query_embedding = project_embeddings([query_embedding], umap_transform)
# projected_retrieved_embeddings = project_embeddings(retrieved_embeddings, umap_transform)

In [89]:
# plt.figure()
# plt.scatter(projected_dataset_embeddings[:, 0], projected_dataset_embeddings[:, 1], s=10, color='gray')
# plt.scatter(projected_query_embedding[:, 0], projected_query_embedding[:, 1], s=50, marker='X', color='r')
# plt.scatter(projected_retrieved_embeddings[:, 0], projected_retrieved_embeddings[:, 1], s=100, facecolors='none', edgecolors='g')

# plt.gca().set_aspect('equal', 'datalim')
# plt.title(f'{query}')
# plt.show()
# # plt.axis('off')

In [34]:
# augmented_query_embedding = embedding_function([joint_query])
# original_query_embedding = embedding_function([query])
# retrieved_embeddings = relevant_text['embeddings'][0]

In [35]:
# projected_original_query_embedding = project_embeddings(original_query_embedding, umap_transform)
# projected_augmented_query_embedding = project_embeddings(augmented_query_embedding, umap_transform)
# projected_retrieved_embeddings = project_embeddings(retrieved_embeddings, umap_transform)

In [36]:
# plt.figure()
# plt.scatter(projected_dataset_embeddings[:, 0], projected_dataset_embeddings[:, 1], s=10, color='gray')
# plt.scatter(projected_retrieved_embeddings[:, 0], projected_retrieved_embeddings[:, 1], s=100, facecolors='none', edgecolors='g')
# plt.scatter(projected_original_query_embedding[:, 0], projected_original_query_embedding[:, 1], s=50, marker='X', color='r')
# plt.scatter(projected_augmented_query_embedding[:, 0], projected_augmented_query_embedding[:, 1], s=50, marker='X', color='blue')

# plt.gca().set_aspect('equal', 'datalim')
# plt.title(f'{query}')
# plt.show()