# Мультимодальный RAG на описаниях изображений

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')
#root_path = "/content/drive/MyDrive/Diploma-mag"

In [None]:
# pip install -q -U langchain langchain-gigachat open-clip-torch

In [None]:
# The chat model class exported by langchain_gigachat is spelled `GigaChat`;
# importing `Gigachat` raises ImportError.
from langchain_gigachat import GigaChat
from langchain_gigachat.embeddings import GigaChatEmbeddings


def init_gigachat():
    """Create the GigaChat chat model used for answer generation.

    Returns:
        A `GigaChat` client configured for the "GigaChat-Max" model with a
        near-zero temperature and a request timeout of 100.

    Note:
        The credentials value is a placeholder and must be replaced with a
        real authorization key. SSL certificate verification is disabled.
    """
    return GigaChat(
        credentials="ключ_авторизации",
        model="GigaChat-Max",
        verify_ssl_certs=False,
        temperature=1e-15,
        timeout=100,
    )


def init_gigachat_embeddings():
    """Create the GigaChat embeddings client used to index and query chunks.

    Returns:
        A `GigaChatEmbeddings` instance authorised for the personal API scope.

    Note:
        The credentials value is a placeholder and must be replaced with a
        real authorization key. SSL certificate verification is disabled.
    """
    return GigaChatEmbeddings(
        credentials="ключ_авторизации",
        # Was `scope-"GIGACHAT_API_PER"`: a syntax error and a truncated scope name.
        scope="GIGACHAT_API_PERS",
        verify_ssl_certs=False,
    )

In [None]:
from langchain.storage import InMemoryStore

# Metadata key that links a vector-store chunk to its parent document id.
id_key = "doc_id"
doc_ids = []

# On-disk locations for this experiment's stores.
docstore_dir = "./data/multimodal_rag_with_summaries/doc_store"
vectorstore_dir = "./data/multimodal_rag_with_summaries/vectorstore"

# Model clients.
llm = init_gigachat()
embeddings = init_gigachat_embeddings()

# Parent documents are kept in memory only.
docstore = InMemoryStore()

In [None]:
from langchain.retrievers import MultiVectorRetriever
from langchain_chroma import Chroma
from chromadb.config import Settings


# Chroma collection persisted under `vectorstore_dir`; anonymised telemetry is
# switched off in the client settings.
text_vectorstore = Chroma(
    collection_name="mm_rag_text_gigaembeddings",
    embedding_function=embeddings,
    persist_directory=vectorstore_dir,
    client_settings=Settings(anonymized_telemetry=False),
)

# Retriever that searches the vector store and resolves hits to entries of
# `docstore` through the `id_key` metadata field.
retriever = MultiVectorRetriever(
    vectorstore=text_vectorstore,
    docstore=docstore,
    id_key=id_key,
)

In [None]:
import json
# JSON files are UTF-8 encoded; pass the encoding explicitly so that reading
# does not depend on the platform's locale default.
with open("./extracted_data/extracted_texts.json", "r", encoding="utf-8") as f:
    documents = json.load(f)

with open("./extracted_data/image_summary.json", "r", encoding="utf-8") as f:
    summaries = json.load(f)

In [None]:
import uuid

# Parallel lists: the text of every extracted document and the page it came from.
documents_content = [entry["text"] for entry in documents]
documents_page = [entry["metadata"]["page_number"] for entry in documents]

# One random UUID per document, in the same order as `documents_content`.
doc_ids = [str(uuid.uuid4()) for _ in documents_content]

In [None]:
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)

# `prepared_text` holds the chunks that get embedded; `parent_documents` holds
# the full texts that the retriever returns for a matching chunk.
prepared_text = []
parent_documents = []
for doc_id, page_number, document_content in zip(doc_ids, documents_page, documents_content):
    metadata = {"page_number": page_number, "doc_id": doc_id}
    parent_documents.append(Document(page_content=document_content, metadata=metadata))
    prepared_text.extend(
        Document(page_content=chunk, metadata=dict(metadata))
        for chunk in text_splitter.split_text(document_content)
    )

all_chunks = [text.page_content for text in prepared_text]

text_vectorstore.add_documents(prepared_text)
# Store Document objects rather than raw strings: the pipelines read
# `.metadata` and `.page_content` from whatever the retriever returns.
retriever.docstore.mset(list(zip(doc_ids, parent_documents)))

In [None]:
# Parallel lists describing every image summary: its text, page and image path.
summaries_content = []
summaries_page = []
summaries_source = []

for s in summaries:
    summaries_content.append(s["image_summary"])
    summaries_page.append(s["page_number"])
    summaries_source.append(s["source"])

summaries_ids = [str(uuid.uuid4()) for _ in summaries_content]

text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)

# `prepared_text` holds the summary chunks that get embedded;
# `summary_documents` holds the full summaries returned by the retriever.
prepared_text = []
summary_documents = []
for summary_id, page_number, source, summary_content in zip(
    summaries_ids, summaries_page, summaries_source, summaries_content
):
    metadata = {
        # Was `summaries_content[i]`, which stored the summary text as the page number.
        "page_number": page_number,
        "doc_id": summary_id,
        "source": source,
    }
    summary_documents.append(Document(page_content=summary_content, metadata=metadata))
    prepared_text.extend(
        Document(page_content=chunk, metadata=dict(metadata))
        for chunk in text_splitter.split_text(summary_content)
    )

retriever.vectorstore.add_documents(prepared_text)
# The retriever resolves vector hits through the docstore and discards ids it
# cannot find there, so the summaries must be registered as parent documents
# too; the "source" metadata is what marks a result as an image.
retriever.docstore.mset(list(zip(summaries_ids, summary_documents)))

# Пайплайн GigaChat

In [None]:
from langchain_core.messages import HumanMessage
from langchain.prompts import ChatPromptTemplate
from prompts import QA_PROMPT_SYSTEM, QA_PROMPT_USER

def run_pipeline_gigachat(question, text_vectorstore, img_vectorstore, llm):
    """Answer a question with GigaChat using retrieved text and, if found, one image.

    Retrieved documents whose metadata has a "source" entry are treated as
    images; all others are treated as text and joined into the prompt context.
    Only the first retrieved image is uploaded and attached to the user message.

    Args:
        question: The user question.
        text_vectorstore: Unused; retrieval goes through the module-level
            `retriever`. Kept for interface compatibility.
        img_vectorstore: Unused; kept for interface compatibility.
        llm: Chat model exposing `upload_file` and usable in a prompt chain.

    Returns:
        The content of the model's reply.
    """
    response = retriever.invoke(question)
    images = []
    texts = []
    for r in response:
        if r.metadata.get("source") is not None:
            images.append(r)
        else:
            texts.append(r)

    context = "\n\n".join(t.page_content for t in texts)
    text_content = QA_PROMPT_USER.format(context=context, question=question)

    additional_kwargs = {}
    if images:
        # Close the image file as soon as the upload has finished.
        with open(images[0].metadata["source"], "rb") as img_file:
            uploaded = llm.upload_file(img_file)
        additional_kwargs["attachments"] = [uploaded.id_]

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", QA_PROMPT_SYSTEM),
            HumanMessage(content=text_content, additional_kwargs=additional_kwargs),
        ]
    )
    chain = prompt | llm

    # The user message is already formatted, so no variables are passed here.
    return chain.invoke({}).content

In [None]:
question = "Как выделить прямоугольную область на изображении в Adobe Photoshop?"
# No separate `img_vectorstore` is ever created: image summaries are added to
# the same vector store as the text chunks, so that store is passed for both
# arguments.
run_pipeline_gigachat(question, text_vectorstore, text_vectorstore, llm)

# Пайплайн LLaVA

In [None]:
#!pip install -q -U transformers bitsandbytes accelerate

In [None]:
from transformers import BitsAndBytesConfig, LlavaNextProcessor, LlavaNextForConditionalGeneration
from PIL import Image
import io
import pandas as pd
from typing import Tuple

# 8-bit weight quantization via bitsandbytes.
# `low_cpu_mem_usage` and `use_flash_attention_2` were previously passed here,
# but they are not BitsAndBytesConfig options and were ignored; they are
# model-loading arguments and belong in `from_pretrained` if they are wanted.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
)


def get_qa_prompt(model_id: str, system_prompt: str, question: str, context: str, image: "Image.Image | None" = None) -> str:
    """Build the `[INST] ... [/INST]` question-answering prompt.

    Args:
        model_id: Unused; kept for interface compatibility.
        system_prompt: Instructions placed before the context.
        question: The user question.
        context: Retrieved text inserted between the instructions and the question.
        image: Optional image. When one is given, an `<image>` placeholder is
            emitted at the start of the prompt; otherwise a single space.

    Returns:
        The formatted prompt string.
    """
    # Compare with None explicitly: truth-testing an array-like image can raise.
    image_slot = "<image>" if image is not None else " "
    return f"[INST]{image_slot}\n{system_prompt}\n{context}\n\nQuestion:\n{question}\n\n[/INST]"


def format_output(raw_output, processor: LlavaNextProcessor, prompt: str) -> str:
    """Decode the first generated sequence and remove the echoed prompt from it.

    Args:
        raw_output: Generation output; only element 0 is decoded.
        processor: Processor used to decode token ids, skipping special tokens.
        prompt: The prompt that was sent to the model.

    Returns:
        The decoded text with the prompt removed and surrounding whitespace stripped.
    """
    decoded = processor.decode(raw_output[0], skip_special_tokens=True)
    # The `<image>` placeholder is swapped for a space before the prompt is
    # searched for in the decoded text.
    prompt_echo = prompt.replace("<image>", " ").strip()
    return decoded.replace(prompt_echo, "").strip()


def get_prompt(task: str, model_id: str, system_prompt: str, text: str, image: Image, question: str) -> str:
    """Return the question-answering prompt for the given text and image.

    `task` is accepted but not used; the call is forwarded to `get_qa_prompt`
    with `text` as the context.
    """
    return get_qa_prompt(
        model_id=model_id,
        system_prompt=system_prompt,
        question=question,
        context=text,
        image=image,
    )


def llava_call(prompt: str, model: LlavaNextForConditionalGeneration, processor: LlavaNextProcessor, device: str, image: Image=None) -> str:
    """Run one LLaVA generation and return the answer without the echoed prompt.

    Args:
        prompt: Fully formatted prompt.
        model: Loaded LLaVA-NeXT model.
        processor: Matching processor used to encode inputs and decode outputs.
        device: Device the encoded inputs are moved to.
        image: Optional image passed to the processor together with the prompt.

    Returns:
        The generated text (at most 300 new tokens) as cleaned by `format_output`.
    """
    # Pass text and image by keyword: the positional order of the processor's
    # parameters is not the same in every transformers release.
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    raw_output = model.generate(**inputs, max_new_tokens=300)
    return format_output(raw_output, processor, prompt)


def load_llava_model(model_id: str, cache_dir: "str | None" = None) -> Tuple[LlavaNextForConditionalGeneration, LlavaNextProcessor]:
    """Load a quantized LLaVA-NeXT model and its processor.

    Args:
        model_id: Model identifier passed to `from_pretrained`.
        cache_dir: Directory for the downloaded model weights. When omitted,
            `root_path + "/models"` is used if a global `root_path` is defined;
            otherwise the library's default cache location is used.

    Returns:
        A `(model, processor)` tuple.
    """
    if cache_dir is None:
        # `root_path` is only assigned in the optional (commented-out) Colab
        # cell, so it must not be required here.
        base_dir = globals().get("root_path")
        if base_dir is not None:
            cache_dir = base_dir + "/models"

    processor = LlavaNextProcessor.from_pretrained(model_id)
    model = LlavaNextForConditionalGeneration.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="auto",
        cache_dir=cache_dir,
    )
    # Unquantized alternative:
    # model = LlavaNextForConditionalGeneration.from_pretrained(model_id, device_map="auto")
    return model, processor

In [None]:
from prompts import QA_PROMPT_SYSTEM, QA_PROMPT_USER

def run_pileline_llava(question, text_vectorstore, img_vectorstore, model, processor, device="cuda"):
    """Answer a question with LLaVA using retrieved text and, if found, one image.

    Retrieved documents whose metadata has a "source" entry are treated as
    images; all others are treated as text and joined into the prompt context.
    Only the first retrieved image is passed to the model.

    Args:
        question: The user question.
        text_vectorstore: Unused; retrieval goes through the module-level
            `retriever`. Kept for interface compatibility.
        img_vectorstore: Unused; kept for interface compatibility.
        model: Loaded LLaVA-NeXT model.
        processor: Matching processor.
        device: Device the model inputs are moved to.

    Returns:
        The generated answer text.
    """
    response = retriever.invoke(question)
    images = []
    texts = []
    for r in response:
        if r.metadata.get("source") is not None:
            images.append(r)
        else:
            texts.append(r)

    # Was iterating over the undefined name `text_content`.
    context = "\n\n".join(t.page_content for t in texts)
    model_id = "llava-hf/llava-v1.6-mistral-7b-hf"

    if images:
        # Keep the image file open only for the duration of the model call.
        with Image.open(images[0].metadata["source"]) as image:
            img_prompt = get_qa_prompt(model_id, QA_PROMPT_SYSTEM, question, context, image)
            return llava_call(img_prompt, model, processor, device, image)

    no_img_prompt = get_qa_prompt(model_id, QA_PROMPT_SYSTEM, question, context)
    return llava_call(no_img_prompt, model, processor, device)

In [None]:
question = "Как выделить прямоугольную область на изображении в Adobe Photoshop?"
model, processor = load_llava_model("llava-hf/llava-v1.6-mistral-7b-hf")
model = model.eval()

# No separate `img_vectorstore` is ever created: image summaries are added to
# the same vector store as the text chunks, so that store is passed for both
# arguments.
response = run_pileline_llava(question, text_vectorstore, text_vectorstore, model, processor)

# Оценка ответа (DeepSeek-R1)

In [None]:
# The chat model class exported by langchain_gigachat is spelled `GigaChat`;
# importing `Gigachat` raises ImportError.
from langchain_gigachat import GigaChat


def init_deepseek():
    """Create the chat client used as the evaluation (judge) model.

    Returns:
        A `GigaChat` client configured with model name "DeepSeek-R1", a
        near-zero temperature and a request timeout of 100.

    Note:
        The credentials value is a placeholder and must be replaced with a
        real authorization key. SSL certificate verification is disabled.
    """
    return GigaChat(
        credentials="ваш_ключ_авторизации",
        model="DeepSeek-R1",
        verify_ssl_certs=False,
        temperature=1e-15,
        timeout=100,
    )

In [None]:
from langchain.prompts import ChatPromptTemplate
from prompts import ANSWER_CORRECTNESS_SYSTEM, ANSWER_CORRECTNESS_USER

# Judge prompt that compares a generated answer with a reference answer.
answer_correctness = ChatPromptTemplate.from_messages(
    [
        ("system", ANSWER_CORRECTNESS_SYSTEM),
        ("human", ANSWER_CORRECTNESS_USER),
    ]
)
# NOTE: this rebinds the global `llm` to the evaluation model.
llm = init_deepseek()

answer_correctness_chain = answer_correctness | llm
# Chain inputs must be passed as a single dict (the bare `key: value`
# arguments were a syntax error).
# NOTE: `reference_answer` and `generated_answer` are not assigned in this
# cell; they must be defined before it is run.
response = answer_correctness_chain.invoke(
    {
        "question": question,
        "reference_answer": reference_answer,
        "generated_answer": generated_answer,
    }
)

In [None]:
from langchain.prompts import ChatPromptTemplate
from prompts import ANSWER_RELEVANCE_SYSTEM, ANSWER_RELEVANCE_USER

# Judge prompt that rates how relevant the generated answer is to the question.
answer_relevance = ChatPromptTemplate.from_messages(
    [
        ("system", ANSWER_RELEVANCE_SYSTEM),
        ("human", ANSWER_RELEVANCE_USER),
    ]
)
# NOTE: this rebinds the global `llm` to the evaluation model.
llm = init_deepseek()

answer_relevance_chain = answer_relevance | llm
generated_answer = """Чтобы выделить прямоугольную область в Photoshop, используйте инструмент "Прямоугольная область" (Rectangular Marquee Tool) на панели инструментов"""
# Chain inputs must be passed as a single dict (the bare `key: value`
# arguments were a syntax error).
response = answer_relevance_chain.invoke({"query": question, "text": generated_answer})

In [None]:
from langchain.prompts import ChatPromptTemplate
from prompts import CONTEXT_RELEVANCE_TEXT_SYSTEM, CONTEXT_RELEVANCE_TEXT_USER

# Judge prompt that rates how relevant the retrieved text context is to the question.
context_relevance_text = ChatPromptTemplate.from_messages(
    [
        ("system", CONTEXT_RELEVANCE_TEXT_SYSTEM),
        ("human", CONTEXT_RELEVANCE_TEXT_USER),
    ]
)
# NOTE: this rebinds the global `llm` to the evaluation model.
llm = init_deepseek()


context = """Чтобы выделить прямоугольную область в Photoshop, используйте инструмент "Прямоугольная область" (Rectangular Marquee Tool) на панели инструментов"""
context_relevance_text_chain = context_relevance_text | llm
# Chain inputs must be passed as a single dict (the bare `key: value`
# arguments were a syntax error).
response = context_relevance_text_chain.invoke({"query": question, "context": context})

In [None]:
from langchain.prompts import ChatPromptTemplate
from prompts import CONTEXT_RELEVANCE_IMAGE_SYSTEM, CONTEXT_RELEVANCE_IMAGE_USER

# Judge prompt for the relevance of the retrieved image context to the question.
context_relevance_image = ChatPromptTemplate.from_messages(
    [
        ("system", CONTEXT_RELEVANCE_IMAGE_SYSTEM),
        ("human", CONTEXT_RELEVANCE_IMAGE_USER),
    ]
)

# NOTE: this rebinds the global `llm` to the evaluation model.
llm = init_deepseek()

context_relevance_image_chain = context_relevance_image | llm
# Chain inputs must be passed as a single dict (the bare `key: value`
# argument was a syntax error).
response = context_relevance_image_chain.invoke({"query": question})