## Build Qdrant Vector Stores

In [None]:
from llama_index.legacy.vector_stores import QdrantVectorStore
from custom_vectore_store import MultiModalQdrantVectorStore
from custom_embeddings import custom_sparse_doc_vectors, custom_sparse_query_vectors

from functools import partial

from qdrant_client import QdrantClient
from qdrant_client.http import models as qd_models

# Open the local, on-disk Qdrant database stored under ./qdrant_db.
# Every later cell needs `client`, so a failure here is reported together
# with its cause and then re-raised. Swallowing it would only defer the
# problem to a confusing NameError on `client` further down.
try:
    client = QdrantClient(path="qdrant_db")
except Exception as exc:
    # `Exception` (not a bare `except`) so Ctrl-C / SystemExit still propagate.
    print(f"Failed to connect to Qdrant: {exc!r}")
    raise
else:
    print("Connected to Qdrant")



import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

def _load_splade(model_dir):
    """Return the (tokenizer, masked-LM model) pair stored under ``model_dir``."""
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForMaskedLM.from_pretrained(model_dir)
    return tokenizer, model


# Separate SPLADE checkpoints are loaded for queries and for documents.
SPLADE_QUERY_PATH = "./embedding_models/efficient-splade-VI-BT-large-query"
splade_q_tokenizer, splade_q_model = _load_splade(SPLADE_QUERY_PATH)

SPLADE_DOC_PATH = "./embedding_models/efficient-splade-VI-BT-large-doc"
splade_d_tokenizer, splade_d_model = _load_splade(SPLADE_DOC_PATH)

# Pre-bind tokenizer, model and the literal 512 (third positional argument;
# presumably a maximum token length -- confirm in custom_embeddings) so the
# resulting callables only need the remaining arguments.
custom_sparse_doc_fn = partial(
    custom_sparse_doc_vectors,
    splade_d_tokenizer,
    splade_d_model,
    512,
)
custom_sparse_query_fn = partial(
    custom_sparse_query_vectors,
    splade_q_tokenizer,
    splade_q_model,
    512,
)


In [None]:
# Arguments common to both collections: the same Qdrant client, hybrid search
# switched on, and the same sparse query/document encoder callables.
_hybrid_store_kwargs = {
    "client": client,
    "enable_hybrid": True,
    "sparse_query_fn": custom_sparse_query_fn,
    "sparse_doc_fn": custom_sparse_doc_fn,
}

# Text nodes go to "text_collection"; this store is created with stores_text=True.
text_store = QdrantVectorStore(
    collection_name="text_collection",
    stores_text=True,
    **_hybrid_store_kwargs,
)

# Image nodes go to "image_collection"; this store is created with stores_text=False.
image_store = MultiModalQdrantVectorStore(
    collection_name="image_collection",
    stores_text=False,
    **_hybrid_store_kwargs,
)

In [None]:
from llama_index.legacy.embeddings import HuggingFaceEmbedding
from custom_embeddings import CustomizedCLIPEmbedding

# Local checkpoint directories for the two embedding models.
BGE_PATH = "./embedding_models/bge-small-en-v1.5"
CLIP_PATH = "./embedding_models/clip-vit-base-patch32"

# NOTE(review): the BGE model card recommends CLS pooling; "mean" is kept
# because the already-indexed vectors were presumably built with it -- confirm
# before changing, since query and stored embeddings must match.
bge_embedding = HuggingFaceEmbedding(
    model_name=BGE_PATH,
    device="cpu",
    pooling="mean",
)
clip_embedding = CustomizedCLIPEmbedding(
    model_name=CLIP_PATH,
    device="cpu",
)


## Customized Multi-modal Retriever with Reranker

In [None]:
from llama_index.core.postprocessor import SentenceTransformerRerank

# Reranker backed by the local bge-reranker-base checkpoint, run on CPU and
# configured with top_n=3; retrieval scores are not kept alongside the
# reranker scores (keep_retrieval_score=False).
bge_reranker = SentenceTransformerRerank(
    model="./embedding_models/bge-reranker-base",
    device="cpu",
    top_n=3,
    keep_retrieval_score=False,
)


In [None]:
from mm_retriever import MultiModalQdrantRetriever

# Wire the stores, embedding models and reranker into one retriever.
mm_retriever = MultiModalQdrantRetriever(
    # Text side.
    text_vector_store=text_store,
    text_embed_model=bge_embedding,
    text_similarity_top_k=5,
    text_sparse_top_k=5,
    text_rerank_top_n=3,
    # Image side.
    image_vector_store=image_store,
    mm_embed_model=clip_embedding,
    image_similarity_top_k=5,
    image_sparse_top_k=5,
    image_rerank_top_n=1,
    # Shared components.
    reranker=bge_reranker,
    sparse_query_fn=custom_sparse_query_fn,
)

In [None]:
from llama_index.legacy.schema import QueryBundle
# Example question wrapped in a QueryBundle; the query cell at the end of the
# notebook passes this object to the query engine.
query_bundle = QueryBundle(
    query_str="How does Llama 2 perform compared to other open-source models?",
)

# text_query_result = mm_retriever.retrieve_text_nodes(query_bundle=query_bundle, query_mode="hybrid")
# reranked_text_nodes = mm_retriever.rerank_text_nodes(query_bundle, text_query_result)
# image_query_result = mm_retriever.retrieve_image_nodes(query_bundle=query_bundle, query_mode="hybrid")
# reranked_image_nodes = mm_retriever.rerank_image_nodes(query_bundle, image_query_result)

## Load Quantized LLaVA-1.6 with the llama-cpp framework

In [None]:
from llama_cpp.llama_chat_format import Llava15ChatHandler

# Chat handler built from the f16 mmproj GGUF file that ships next to the
# LLaVA-1.6 weights.
# NOTE(review): this is the LLaVA-1.5 handler used with a 1.6 model; newer
# llama-cpp-python releases also expose a Llava16ChatHandler -- confirm which
# one the installed version supports before switching.
llava_chat_handler = Llava15ChatHandler(
    clip_model_path="LLMs/llava-1.6-mistral-7b-gguf/mmproj-model-f16.gguf",
    verbose=False,
)


## Load LLaVA with the original llama-cpp Python bindings

# from llama_cpp import Llama

# llava_1_6 = Llama(
#     model_path="LLMs/llava-1.6-mistral-7b-gguf/llava-v1.6-mistral-7b.Q4_K_M.gguf",
#     chat_format="llava-1-5",
#     chat_handler=llava_chat_handler, # Optional chat handler to use when calling create_chat_completion.
#     n_ctx=2048, # (context window size) Text context, 0 = from model
#     logits_all=True, # Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
#     offload_kqv=True, # Offload K, Q, V to GPU.
#     n_gpu_layers=40,  # Number of layers to offload to GPU (-ngl). If -1, all layers are offloaded.
#     last_n_tokens_size=64, # maximum number of tokens to keep in the last_n_tokens deque.
#     verbose=True,

#     ## LoRA Params
#     # lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
#     # lora_scale: float = 1.0,
#     # lora_path: Path to a LoRA file to apply to the model.

#     ## Tokenizer Override
#     # tokenizer: Optional[BaseLlamaTokenizer] = None,
# )

In [None]:
## Load LLaVA with customized llama-index integration
from llava_llamacpp import Llava_LlamaCPP

# Extra keyword arguments passed to Llava_LlamaCPP through ``model_kwargs``.
model_kwargs = {
    "chat_format": "llava-1-5",
    "chat_handler": llava_chat_handler,
    "logits_all": True,  # return logits for all tokens, not just the last one
    "offload_kqv": True,  # offload K, Q, V to GPU
    "n_gpu_layers": 40,  # number of layers to offload to GPU
    "last_n_tokens_size": 64,  # max tokens kept in the last_n_tokens deque
    #
    # Optional LoRA parameters (not set here):
    #   lora_base:  path to a base model, useful when the base is quantized
    #               and the LoRA should be applied to an f16 model
    #   lora_scale: float = 1.0
    #   lora_path:  path to a LoRA file to apply to the model
    #
    # Optional tokenizer override (not set here):
    #   tokenizer: Optional[BaseLlamaTokenizer] = None
}

# Instantiate the llama-index LLM wrapper around the Q3_K_M quantized
# LLaVA-1.6 (Mistral-7B) GGUF weights, forwarding the extra keyword
# arguments collected in `model_kwargs`.
llava_1_6 = Llava_LlamaCPP(
    model_path="LLMs/llava-1.6-mistral-7b-gguf/llava-v1.6-mistral-7b.Q3_K_M.gguf",
    model_kwargs=model_kwargs,
    context_window=4096,
    max_new_tokens=1024,
    temperature=0.5,
    verbose=True,
)

## Build Query Engine

In [None]:
from mm_query_engine import CustomMultiModalQueryEngine

# Combine the multi-modal retriever and the LLaVA model into a query engine.
query_engine = CustomMultiModalQueryEngine(
    retriever=mm_retriever,
    multi_modal_llm=llava_1_6,
)

In [None]:
# retrieval_results = query_engine.retrieve(query_bundle=query_bundle, text_query_mode="hybrid", image_query_mode="default")
# response = query_engine.synthesize(query_bundle, retrieval_results)

In [None]:
# Run the example question through the query engine in a single call; the
# commented-out cell above shows the same flow split into explicit
# retrieve/synthesize steps.
response = query_engine.query(query_bundle)