# Multi-modal RAG in 60 Lines of Code

## 1. Create a multi-modal vector database

In [None]:
from typing import List, Optional, Tuple, Union

import torch
from PIL import Image
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoProcessor,
    AutoTokenizer,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)
from langchain.text_splitter import CharacterTextSplitter


class MultiModalVectorStore:
    """In-memory store of CLIP embeddings for images and text chunks.

    Images and text are embedded with the projection heads of the same CLIP
    checkpoint, and each store is a plain list of ``(item, embedding)`` pairs.
    ``retrieve`` embeds a text query with the text encoder and ranks both
    stores by cosine similarity against it.
    """

    # One checkpoint is shared by the processor, the tokenizer and both encoders.
    _CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

    def __init__(self):
        """Load both CLIP encoders in float16 on the GPU and create empty stores."""
        # image store
        self.img_processor = AutoProcessor.from_pretrained(self._CLIP_CHECKPOINT)
        self.vision_model = CLIPVisionModelWithProjection.from_pretrained(
            self._CLIP_CHECKPOINT, torch_dtype=torch.float16
        ).cuda()
        self.img_store = []

        # text store
        self.splitter = CharacterTextSplitter(separator="\n\n", chunk_size=1000, chunk_overlap=200)
        self.tokenizer = AutoTokenizer.from_pretrained(self._CLIP_CHECKPOINT)
        self.text_model = CLIPTextModelWithProjection.from_pretrained(
            self._CLIP_CHECKPOINT, torch_dtype=torch.float16
        ).cuda()
        self.text_store = []

    def _embed_text(self, text: str) -> torch.Tensor:
        """Return the projected text embedding of ``text`` (one row per input)."""
        # truncation=True: a chunk can be longer than the tokenizer's configured
        # maximum length, which the text encoder cannot process. Only the
        # leading tokens of such a chunk contribute to its embedding.
        inputs = self.tokenizer(
            text, padding=True, truncation=True, return_tensors="pt"
        ).to("cuda")
        # Inference only: without no_grad every stored embedding would keep its
        # autograd graph (and the activations behind it) alive on the GPU.
        with torch.no_grad():
            return self.text_model(**inputs).text_embeds

    @staticmethod
    def _rank(query_emb: torch.Tensor, store: list, k: int) -> List[int]:
        """Return the indices of the ``k`` entries of ``store`` closest to ``query_emb``.

        Indices are ordered from most to least similar (cosine similarity).
        Returns an empty list when ``store`` is empty or ``k`` is not positive.
        """
        k = min(k, len(store))
        if k <= 0:
            return []
        # Score in float32; the stored embeddings come from float16 models.
        embeddings = torch.cat([emb for _, emb in store]).float()
        scores = torch.nn.functional.cosine_similarity(query_emb.float(), embeddings, dim=1)
        return torch.topk(scores, k).indices.tolist()

    def add_image(self, image: Image.Image) -> None:
        """Embed ``image`` and append ``(image, embedding)`` to the image store."""
        inputs = self.img_processor(images=image, return_tensors="pt").to("cuda")
        with torch.no_grad():
            emb = self.vision_model(**inputs).image_embeds
        self.img_store.append((image, emb))

    def add_text(self, text: str) -> None:
        """Split ``text`` into chunks and append ``(chunk, embedding)`` for each one."""
        for chunk in self.splitter.split_text(text):
            self.text_store.append((chunk, self._embed_text(chunk)))

    def retrieve(self, query: str, top_k_text_chunks: int) -> Tuple[List[str], Optional[Image.Image]]:
        """Find the stored text chunks and the stored image closest to ``query``.

        Args:
            query: Natural-language query; embedded with the text encoder.
            top_k_text_chunks: Maximum number of text chunks to return.

        Returns:
            A ``(text_chunks, image)`` pair. ``text_chunks`` holds at most
            ``top_k_text_chunks`` chunks, most similar first (empty when no
            text has been added). ``image`` is the single most similar stored
            image, or ``None`` when no image has been added.
        """
        query_emb = self._embed_text(query)

        text_chunks = [
            self.text_store[i][0]
            for i in self._rank(query_emb, self.text_store, top_k_text_chunks)
        ]

        best_image = self._rank(query_emb, self.img_store, 1)
        image = self.img_store[best_image[0]][0] if best_image else None

        return text_chunks, image

## 2. Create a vision language model

In [None]:
class VLM:
    """Wrapper around the MiniCPM-V-2 chat model that answers from supplied context."""

    def __init__(self):
        """Load the model (bfloat16, on the GPU) and its tokenizer."""
        # NOTE(review): this checkpoint appears to ship custom modeling code, so
        # loading may require trust_remote_code=True — confirm before relying on it.
        checkpoint = 'openbmb/MiniCPM-V-2'
        self.model = AutoModel.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    def chat(self, query: str, txt_context: List[str], img: Image.Image):
        """Answer ``query`` using the given text chunks and image as context.

        Args:
            query: The user's question.
            txt_context: Text chunks, rendered as a bulleted list in the prompt.
            img: Image passed to the model together with the prompt.

        Returns:
            The first element of the tuple returned by the model's ``chat`` call.
        """
        # One " - chunk" bullet per context chunk, newline-separated.
        bullets = []
        for chunk in txt_context:
            bullets.append(f" - {chunk}")
        context_prompt = "\n".join(bullets)

        instruction_prompt = f"""You are a helpful chatbot.
        Use only the following pieces of context to answer the question. Don't make up any new information:
        {context_prompt}

        Question: {query}
        """

        conversation = [{'role': 'user', 'content': instruction_prompt}]
        # Sampling is enabled, so answers can differ between calls.
        reply, _, _ = self.model.chat(
            img,
            conversation,
            context=None,
            tokenizer=self.tokenizer,
            sampling=True,
            temperature=0.7,
        )
        return reply

## 3. Create a RAG

In [None]:
class MMRag:
    """Multi-modal RAG pipeline: a vector store for retrieval plus a VLM for answering."""

    def __init__(self):
        """Create the vector store and the chat model."""
        self.store = MultiModalVectorStore()
        self.chat_model = VLM()

    def add(self, data: Union[str, Image.Image]) -> None:
        """Index ``data``: strings go to the text store, anything else to the image store."""
        if not isinstance(data, str):
            self.store.add_image(data)
            return
        self.store.add_text(data)

    def query(self, query: str, top_n_text: int = 3, top_n_image: int = 3):
        """Retrieve context for ``query`` and return the chat model's answer.

        Args:
            query: The user's question.
            top_n_text: Number of text chunks requested from the store.
            top_n_image: Accepted but not used by this method.

        Returns:
            Whatever the chat model's ``chat`` method returns.
        """
        # The store is expected to hand back a (text_chunks, image) pair.
        retrieved = self.store.retrieve(query, top_n_text)
        text_chunks, img = retrieved
        return self.chat_model.chat(query, text_chunks, img)

## Example

We can finally test our multi-modal RAG system. Let's use a Wikipedia article as our data source.