## Configuration

In [1]:
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from importlib import resources as resources
from typing import Callable

import torch


@dataclass
class ModelCfg:
    model: str = "microsoft/Phi-3.5-vision-instruct"
    device_map: str = 'cpu'#'cuda'
    torch_dtype: torch.dtype | str = torch.float32
    model_kwargs: dict = field(default_factory=lambda: {'trust_remote_code':True, '_attn_implementation': 'flash_attention_2'})
    #CHANGE _ATTN_IMPLENTATION TO 'eager' IF 'flash_attention_2' NOT SUPPORTED
    processor_kwargs: dict = field(default_factory=lambda: {'trust_remote_code':True, 'num_crops':16})


@dataclass
class RetrievalCfg:
    """Parameters of the Wikipedia retrieval step.

    Attributes:
        n_pages: number of Wikipedia search results to fetch.
        n_passages: maximum number of chunks kept after the similarity search.
        chunk_size: target chunk length, measured with ``length_function``.
        chunk_overlap: overlap between consecutive chunks, measured with ``length_function``.
        separators: candidate split points for chunking.
            NOTE(review): the retriever in this notebook does not pass these to its text
            splitter, so the splitter's own defaults are used — confirm intent.
        length_function: callable measuring text length (characters by default).
        query_separator: inserted between the question and the keywords to form the search query.
    """
    n_pages: int = 20
    n_passages: int = 7
    chunk_size: int = 300
    chunk_overlap: int = 150
    separators: list = field(default_factory=lambda: list(("\n\n", "\n", ".", ",")))
    length_function: Callable = len
    query_separator: str = 'Keywords: '


@dataclass
class EmbeddingModelCfg:
    """Settings for the sentence-embedding model that ranks retrieved passages.

    Attributes:
        embedding_model_name: hub id of the embedding model.
        embedding_model_kwargs: passed as ``model_kwargs`` to ``HuggingFaceEmbeddings``.
        encode_kwargs: passed as ``encode_kwargs`` to ``HuggingFaceEmbeddings``.
    """
    embedding_model_name: str = "Snowflake/snowflake-arctic-embed-l"
    embedding_model_kwargs: dict = field(
        default_factory=lambda: dict(device='cpu',
                                     model_kwargs=dict(torch_dtype=torch.float32))
    )
    encode_kwargs: dict = field(default_factory=lambda: dict(normalize_embeddings=True))


@dataclass
class RetrievedPassage:
    """A single chunk of Wikipedia text selected for the answer prompt, with its source."""
    # text of the selected chunk
    passage: str
    # title of the Wikipedia page the chunk was taken from
    page: str
    # URL of that page
    url: str


class TemplateCfg(ABC):
    """Abstract base for prompt-template configurations.

    Concrete subclasses carry a prompt template and the keyword arguments forwarded to
    generation, and implement ``format_template`` to build the final prompt text.
    """
    # Annotation-only declarations: TemplateCfg is not a dataclass, so these create
    # neither fields nor class attributes for subclasses.
    template: str
    generation_kwargs: dict

    @abstractmethod
    def format_template(self, *args, **kwargs):
        """Return the filled-in prompt; every subclass must override this."""
        raise NotImplementedError()


def load_template(file: str, package: str = "templates") -> str:
    """Read a prompt template shipped as a resource file of ``package``.

    Args:
        file: name of the resource inside the package (e.g. "keyword_generation.txt").
        package: importable package holding the template files; defaults to ``templates``.

    Returns:
        The full text of the template.

    The file is decoded explicitly as UTF-8. Without an explicit encoding, text mode uses
    the platform's locale encoding (e.g. cp1252 on Windows), which garbles non-ASCII characters.
    """
    return resources.files(package).joinpath(file).read_text(encoding="utf-8")


@dataclass
class KeywordGenerationCfg(TemplateCfg):
    """Prompt configuration for turning the user's question into Wikipedia search keywords.

    Attributes:
        template: prompt text, read from ``keyword_generation.txt`` in the ``templates``
            package when this class is defined; expected to contain a ``{question}`` placeholder.
        prompt_addition: NOTE(review): not read anywhere in the code shown here.
        generation_kwargs: keyword arguments forwarded to the model's ``generate`` call.
    """
    template: str = load_template("keyword_generation.txt")
    prompt_addition: str = 'Keywords: '
    generation_kwargs: dict = field(default_factory=lambda: dict(max_new_tokens=64))

    def format_template(self, question: str) -> str:
        """Return the template with ``{question}`` replaced by ``question``."""
        prompt = self.template.format(question=question)
        return prompt


@dataclass
class QuestionAnsweringCfg(TemplateCfg):
    """Prompt configuration for answering the question from retrieved passages.

    Attributes:
        qa_template: outer prompt, read from ``question_answering.txt``; filled with
            ``{passages}`` and ``{question}``.
        passage_template: per-passage snippet, read from ``passage_quote.txt``; filled with
            ``{page}`` and ``{passage}``.
        generation_kwargs: keyword arguments forwarded to the model's ``generate`` call.
        model_answer_split: NOTE(review): not read anywhere in the code shown here.

    NOTE(review): unlike the ``TemplateCfg`` annotations, this class defines no ``template``
    attribute; the prompt lives in ``qa_template``.
    """
    qa_template: str = load_template('question_answering.txt')
    passage_template: str = load_template('passage_quote.txt')
    generation_kwargs: dict = field(default_factory=lambda: dict(max_new_tokens=256))
    model_answer_split: str = 'Assistant: '

    def format_template(self, question: str, retrieval_results: list[RetrievedPassage]) -> str:
        """Quote every retrieved passage, then embed the quotes and the question in ``qa_template``."""
        quoted_passages = ''.join(
            self.passage_template.format(page=item.page, passage=item.passage)
            for item in retrieval_results
        )
        return self.qa_template.format(passages=quoted_passages, question=question)

## Wikipedia Retrieval

In [2]:
import warnings

import wikipedia
from wikipedia import WikipediaPage
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter




class WikipediaRetriever:
    """Finds Wikipedia pages for a keyword query and selects the passages most relevant to a question.

    Pages come from ``wikipedia.search``. They are split into overlapping chunks and embedded
    with a Hugging Face embedding model. The chunks are then indexed in an in-memory FAISS store
    and ranked by similarity to ``question + query_separator + keywords``.

    Attributes:
        embeddings_cfg: the embedding model configuration.
        embeddings: the ``HuggingFaceEmbeddings`` instance used for indexing and querying.
    """

    def __init__(self,
                 embedding_cfg: EmbeddingModelCfg = EmbeddingModelCfg(),
                 ):
        self.embeddings_cfg = embedding_cfg
        self.embeddings = HuggingFaceEmbeddings(model_name=embedding_cfg.embedding_model_name,
                                                model_kwargs=embedding_cfg.embedding_model_kwargs,
                                                encode_kwargs=embedding_cfg.encode_kwargs)

    @staticmethod
    def find_pages(keywords: str, n_pages: int) -> list[WikipediaPage]:
        """Return up to ``n_pages`` Wikipedia pages found by searching for ``keywords``.

        Titles for which ``wikipedia.page`` raises are skipped with a warning, so the result
        may be shorter than ``n_pages`` or empty.
        """
        titles = wikipedia.search(keywords, results=n_pages)
        pages = []
        for t in titles:
            try:
                # auto_suggest=False: fetch the title returned by search as-is
                pages.append(wikipedia.page(t, auto_suggest=False))
            except Exception as e:
                # best effort: one unloadable title must not abort the whole retrieval
                warnings.warn(str(e))
        return pages

    def select_passages(self,
                        pages: list[WikipediaPage],
                        question: str,
                        keywords: str,
                        retrieval_cfg: RetrievalCfg) -> list[RetrievedPassage]:
        """Chunk ``pages``, embed the chunks and return those closest to the query.

        Returns at most ``retrieval_cfg.n_passages`` passages. It returns fewer when the
        index holds fewer chunks, and an empty list when the pages produce no chunks at all.
        """
        # NOTE(review): retrieval_cfg.separators is not forwarded, so the splitter uses its
        # own default separators — confirm whether that is intended.
        splitter = RecursiveCharacterTextSplitter(chunk_size=retrieval_cfg.chunk_size,
                                                  chunk_overlap=retrieval_cfg.chunk_overlap,
                                                  length_function=retrieval_cfg.length_function,
                                                  )
        # tag each chunk with the index of its source page so results can be attributed
        page_splits = splitter.create_documents(texts=[p.content for p in pages],
                                                metadatas=[{'page_no': i} for i in range(len(pages))])
        if not page_splits:
            # building a FAISS index needs at least one embedded document
            return []
        db = FAISS.from_documents(page_splits, self.embeddings)
        docs = db.similarity_search(question + retrieval_cfg.query_separator + keywords,
                                    k=retrieval_cfg.n_passages)
        # similarity_search can return fewer than k documents (small index), so iterate
        # over what it actually returned instead of range(n_passages)
        passages = []
        for doc in docs:
            source_page = pages[doc.metadata['page_no']]
            passages.append(RetrievedPassage(passage=doc.page_content,
                                             page=source_page.title,
                                             url=source_page.url))
        return passages

    def __call__(self,
                 question: str,
                 keywords: str,
                 retrieval_cfg: RetrievalCfg) -> list[RetrievedPassage]:
        """Retrieve passages for ``question``, using ``keywords`` as the Wikipedia search query.

        Returns an empty list (and emits a warning) when no page could be loaded.
        """
        pages = self.find_pages(keywords, n_pages=retrieval_cfg.n_pages)
        if len(pages) < 1:
            warnings.warn(f"Unable to retrieve any page with the keywords {keywords}")
            selected_passages = []
        else:
            selected_passages = self.select_passages(pages=pages,
                                                     question=question,
                                                     keywords=keywords,
                                                     retrieval_cfg=retrieval_cfg)
        return selected_passages

## Pipeline

In [3]:
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.image_utils import load_image
from PIL import Image





In [4]:



def format_messages(prompt: str) -> list:
    """Build a single-turn user chat message whose content starts with the first-image placeholder.

    ``<|image_1|>`` is the image placeholder used in Phi-3.5-vision prompts; the prompt text
    follows on the next line.
    """
    user_content = "<|image_1|>\n" + prompt
    user_turn = dict(role="user", content=user_content)
    return [user_turn]


class VisualQuestionRAG:
    """Pipeline for Visual Question Answering with Retrieval Augmented Generation from Wikipedia.

    The pipeline runs in three steps:
    (1) the VLM generates search keywords from the image and question;
    (2) the retriever fetches Wikipedia passages for those keywords;
    (3) the VLM answers the question with the passages and the image in the prompt.

       Attributes:
             cfg: the model configuration.
             model: the loaded Vision Language Model.
             processor: the huggingface transformers' processor for the VLM.
             retriever: an instance of the Wikipedia retriever to be used for retrieval.
    """

    def __init__(self,
                 retriever: WikipediaRetriever,
                 cfg: ModelCfg = ModelCfg(),
                 ):
        self.cfg = cfg
        self.model = AutoModelForCausalLM.from_pretrained(cfg.model,
                                                          torch_dtype=cfg.torch_dtype,
                                                          device_map=cfg.device_map,
                                                          **cfg.model_kwargs)
        self.processor = AutoProcessor.from_pretrained(cfg.model, **cfg.processor_kwargs)
        self.retriever = retriever

    def generate(self,
                 prompt: str,
                 img: str | Image.Image,
                 **kwargs) -> str:
        """Generate a reply to ``prompt`` about a single image.

        Args:
            prompt: text of the user turn (the image placeholder is added by ``format_messages``).
            img: a PIL image, or a path/URL string which is loaded with ``load_image``.
            **kwargs: forwarded to ``self.model.generate`` (e.g. ``max_new_tokens``).

        Returns:
            The decoded newly generated text, without the prompt.
        """
        if isinstance(img, str):
            # generate_keywords / rag_question_answering accept paths/URLs; only hand
            # loaded PIL images to the processor
            img = load_image(img)
        messages = format_messages(prompt)
        text = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(text, [img], return_tensors="pt").to(self.model.device)
        prompt_length = inputs['input_ids'].shape[1]
        output_ids = self.model.generate(**inputs,
                                         eos_token_id=self.processor.tokenizer.eos_token_id,
                                         **kwargs)
        # the output starts with the prompt tokens; keep only the continuation
        generated_tokens = output_ids[:, prompt_length:]
        generated_text = self.processor.batch_decode(generated_tokens,
                                                     skip_special_tokens=True,
                                                     clean_up_tokenization_spaces=False)[0]

        return generated_text

    def generate_keywords(self,
                          query: str,
                          img: str | Image.Image,
                          template: KeywordGenerationCfg = KeywordGenerationCfg(),
                          ) -> str:
        """Ask the VLM for Wikipedia search keywords for ``query`` about ``img``."""
        prompt = template.format_template(query)
        generated_text = self.generate(prompt=prompt, img=img, **template.generation_kwargs)
        return generated_text

    def rag_question_answering(self,
                               question: str,
                               retrieval_results: list[RetrievedPassage],
                               img: str | Image.Image,
                               template: QuestionAnsweringCfg = QuestionAnsweringCfg(),
                               ) -> str:
        """Answer ``question`` about ``img`` with ``retrieval_results`` quoted in the prompt."""
        prompt = template.format_template(question=question, retrieval_results=retrieval_results)
        generated_text = self.generate(prompt=prompt,
                                       img=img,
                                       **template.generation_kwargs)
        return generated_text

    def __call__(self,
                 question: str,
                 img_or_img_path: str | Image.Image,
                 keyword_generation_cfg: KeywordGenerationCfg = KeywordGenerationCfg(),
                 retrieval_cfg: RetrievalCfg = RetrievalCfg(),
                 answer_generation_cfg: QuestionAnsweringCfg = QuestionAnsweringCfg(),
                 ) -> tuple[str, str, list[RetrievedPassage]]:
        """
        Args:
            question: The user's question.
            img_or_img_path: If a string, the image will be loaded from the path or url. Otherwise, a PIL Image object
                             is accepted.
            keyword_generation_cfg: an instance of the KeywordGenerationCfg class. It specifies the configuration for
                                    keyword generation.
            retrieval_cfg: an instance of the RetrievalCfg class. It specifies the configuration for passage retrieval.
            answer_generation_cfg: an instance of the QuestionAnsweringCfg class. It specifies the configuration for the
                                   generation of the retrieval augmented answer.


        Returns:
           A tuple containing the answer to the user's question, the keywords used for retrieval, and the selected
           passages.

        """
        # load once so both generation steps reuse the same image
        img = load_image(img_or_img_path)
        keywords = self.generate_keywords(query=question, img=img, template=keyword_generation_cfg)
        selected_passages = self.retriever(question=question, keywords=keywords, retrieval_cfg=retrieval_cfg)
        answer = self.rag_question_answering(question=question,
                                             img=img,
                                             retrieval_results=selected_passages,
                                             template=answer_generation_cfg)
        return answer, keywords, selected_passages

In [5]:
# device_map="auto" lets transformers place the weights across the available devices
# (the load log below shows some are offloaded to CPU/disk); 'eager' attention replaces
# the flash_attention_2 default of ModelCfg.
cfg = ModelCfg(
    model="microsoft/Phi-3.5-vision-instruct",
    device_map="auto",
    model_kwargs=dict(trust_remote_code=True, _attn_implementation='eager'),
    processor_kwargs=dict(trust_remote_code=True, num_crops=16),
)

In [6]:
# the retriever's embedding model uses the EmbeddingModelCfg defaults (device 'cpu')
retriever = WikipediaRetriever()
model = VisualQuestionRAG(retriever=retriever, cfg=cfg)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the disk and cpu.


## Generation

In [7]:
# Inspect GPU memory usage after loading the model
!nvidia-smi

Wed Sep  4 22:24:09 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 551.76                 Driver Version: 551.76         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3070 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   47C    P8             12W /   30W |    6105MiB /   8192MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [8]:


# Example query: an image URL plus a question about the pictured animal.
# NOTE(review): "weight" should read "weigh"; kept unchanged so the recorded outputs below stay valid.
question = "How much does this species weight?"
img = ("https://images.unsplash.com/photo-1589656966895-2f33e7653819"
       "?utm_medium=medium&w=700&q=50&auto=format")

In [9]:
# full pipeline: keyword generation -> Wikipedia retrieval -> retrieval-augmented answer
answer, keywords, selected_passages = model(question, img)

The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.
  attn_output = torch.nn.functional.scaled_dot_product_attention(
You are not running the flash-attention implementation, expect numerical differences.


### Model's Answer

In [10]:
# answer generated with the retrieved Wikipedia passages in the prompt
answer

'The weight of this species, the polar bear, ranges from 300-800 kg (660-1,760 lb) for males and 150-300 kg (330-660 lb) for females.'

In [11]:
# keywords the model generated for the Wikipedia search
keywords

'polar bear, weight'

In [12]:

for p in selected_passages:
    # page title and URL, then the quoted passage on the following line
    print(f'From page {p.page} ({p.url}): \n "{p.passage}".')

From page Polar bear (https://en.wikipedia.org/wiki/Polar_bear): 
 "Males are generally 200–250 cm (6.6–8.2 ft) long with a weight of 300–800 kg (660–1,760 lb). Females are smaller at 180–200 cm (5.9–6.6 ft) with a weight of 150–300 kg (330–660 lb). Sexual dimorphism in the species is particularly high compared with most other mammals. Male polar bears also have".
From page Polar bear (https://en.wikipedia.org/wiki/Polar_bear): 
 "== Notes ==


== References ==


== Bibliography ==


== External links ==
Polar Bears International website
ARKive—images and movies of the polar bear (Ursus maritimus)".
From page Polar bear (https://en.wikipedia.org/wiki/Polar_bear): 
 "weight of 150–300 kg (330–660 lb). Sexual dimorphism in the species is particularly high compared with most other mammals. Male polar bears also have proportionally larger heads than females. The weight of polar bears fluctuates during the year, as they can bulk up on fat and increase their mass by".
From page List of ursid

## Compare with the baseline (same model, no retrieval)

In [13]:

from transformers.image_utils import load_image

# Baseline: the same VLM answers the raw question directly, with no retrieved passages
model.generate(prompt=question, img=load_image(img), max_new_tokens=128)

'Polar bears can weigh between 900 to 1,600 pounds (408 to 727 kilograms).'