diff --git a/assets/multi_modal_files/nutrition_data.json b/assets/multi_modal_files/nutrition_data.json
index 5c57cbb..0a1d11c 100644
--- a/assets/multi_modal_files/nutrition_data.json
+++ b/assets/multi_modal_files/nutrition_data.json
@@ -28,5 +28,5 @@
     "image_url": "https://photos-us.bazaarvoice.com/photo/2/cGhvdG86Ymx1ZWRpYW1vbmRncm93ZXJz/b20cd90a-4e8b-5cfb-aabb-88df46c19be7",
     "title": "Chocolate Almond Milk nutrition facts",
     "content": "Chocolate Almond Milk nutrition facts per cup"
-  }
+  }
 ]
diff --git a/config/empty_image_ret_pipeline.yaml b/config/empty_image_ret_pipeline.yaml
new file mode 100644
index 0000000..9872fb4
--- /dev/null
+++ b/config/empty_image_ret_pipeline.yaml
@@ -0,0 +1,58 @@
+components:
+  prompt_builder:
+    init_parameters:
+      required_variables: null
+      template: "{% for document in documents %}\nImage: <|image_\nThis image is relevant to the question. \n{% endfor %}"
+      variables: null
+    type: fastrag.prompt_builders.multi_modal_prompt_builder.MultiModalPromptBuilder
+  retriever:
+    init_parameters:
+      document_store:
+        init_parameters:
+          bm25_algorithm: BM25L
+          bm25_parameters: {}
+          bm25_tokenization_regex: (?u)\b\w\w+\b
+          embedding_similarity_function: dot_product
+          index: fd8404e2-143c-4b33-9842-34aa23026912
+        type: haystack.document_stores.in_memory.document_store.InMemoryDocumentStore
+      filter_policy: replace
+      filters: null
+      return_embedding: false
+      scale_score: false
+      top_k: 1
+    type: haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever
+  text_embedder:
+    init_parameters:
+      batch_size: 32
+      config_kwargs: null
+      device:
+        device: cpu
+        type: single
+      embedding_separator: '
+
+        '
+      meta_fields_to_embed: []
+      model: openai/clip-vit-large-patch14
+      model_kwargs: null
+      normalize_embeddings: false
+      precision: float32
+      prefix: ''
+      progress_bar: true
+      suffix: ''
+      token:
+        env_vars:
+        - HF_API_TOKEN
+        - HF_TOKEN
+        strict: false
+        type: env_var
+      tokenizer_kwargs: null
+      truncate_dim: null
+      trust_remote_code: false
+    type: fastrag.embedders.image_embedders.SentenceTransformersImageTextEmbedder
+connections:
+- receiver: retriever.query_embedding
+  sender: text_embedder.embedding
+- receiver: prompt_builder.documents
+  sender: retriever.documents
+max_runs_per_component: 100
+metadata: {}
diff --git a/config/empty_ret_llama_32.yaml b/config/empty_ret_llama_32.yaml
new file mode 100644
index 0000000..efadcc8
--- /dev/null
+++ b/config/empty_ret_llama_32.yaml
@@ -0,0 +1,41 @@
+components:
+  embedder:
+    init_parameters:
+      batch_size: 32
+      device:
+        device: cpu
+        type: single
+      model: sentence-transformers/all-MiniLM-L6-v2
+      normalize_embeddings: false
+      prefix: ''
+      progress_bar: true
+      suffix: ''
+      token:
+        env_vars:
+        - HF_API_TOKEN
+        strict: false
+        type: env_var
+    type: haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder
+  prompt_builder:
+    init_parameters:
+      template: "```{% for document in documents %}\nImage: <|image|>\nThis image shows: {{ document.content }} \n{% endfor %}```\nThought: "
+    type: fastrag.prompt_builders.multi_modal_prompt_builder.MultiModalPromptBuilder
+  retriever:
+    init_parameters:
+      document_store:
+        init_parameters:
+          bm25_algorithm: BM25L
+          bm25_parameters: {}
+          bm25_tokenization_regex: (?u)\b\w\w+\b
+          embedding_similarity_function: dot_product
+        type: haystack.document_stores.in_memory.document_store.InMemoryDocumentStore
+      filters: null
+      return_embedding: false
+      scale_score: false
+      top_k: 1
+    type: haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever
+connections:
+- receiver: retriever.query_embedding
+  sender: embedder.embedding
+- receiver: prompt_builder.documents
+  sender: retriever.documents
diff --git a/config/vis_agent_phi_35.yaml b/config/vis_agent_phi_35.yaml
new file mode 100644
index 0000000..2381655
--- /dev/null
+++ b/config/vis_agent_phi_35.yaml
@@ -0,0 +1,24 @@
+chat_model:
+  generator_kwargs:
+    model: microsoft/Phi-3.5-vision-instruct
+    task: "image-to-text"
+    generation_kwargs:
+      max_new_tokens: 300
+      do_sample: false
+    huggingface_pipeline_kwargs:
+      torch_dtype: torch.bfloat16
+      trust_remote_code: true
+      _attn_implementation: eager
+  generator_class: fastrag.generators.llava.Phi35VisionHFGenerator
+tools:
+  - type: doc_with_image
+    params:
+      name: "docRetriever"
+      description: 'useful for when you need to retrieve images and text to answer questions. Use the following format: {{ "input": [your tool input here ] }}.'
+      pipeline_or_yaml_file: "config/empty_retrieval_pipeline.yaml"
+system_tools:
+  - type: doc_with_image_index_from_provider
+    params:
+      name: "docIndex"
+      pipeline_or_yaml_file: "config/empty_index_pipeline.yaml"
+      tool_provider_name: docRetriever
diff --git a/config/vis_agent_phi_35_image.yaml b/config/vis_agent_phi_35_image.yaml
new file mode 100644
index 0000000..8299085
--- /dev/null
+++ b/config/vis_agent_phi_35_image.yaml
@@ -0,0 +1,24 @@
+chat_model:
+  generator_kwargs:
+    model: microsoft/Phi-3.5-vision-instruct
+    task: "image-to-text"
+    generation_kwargs:
+      max_new_tokens: 300
+      do_sample: false
+    huggingface_pipeline_kwargs:
+      torch_dtype: torch.bfloat16
+      trust_remote_code: true
+      _attn_implementation: eager
+  generator_class: fastrag.generators.llava.Phi35VisionHFGenerator
+tools:
+  - type: doc_with_image
+    params:
+      name: imageRetriever
+      description: 'useful for when you need to retrieve images to answer questions. Use the following format: {{ "input": [your tool input here ] }}.'
+      pipeline_or_yaml_file: "config/empty_image_ret_pipeline.yaml"
+system_tools:
+  - type: image
+    params:
+      name: "docIndex"
+      model_name_or_path: "openai/clip-vit-large-patch14"
+      tool_provider_name: imageRetriever
diff --git a/config/visual_chat_agent.yaml b/config/visual_chat_agent.yaml
index 2381655..e765bb7 100644
--- a/config/visual_chat_agent.yaml
+++ b/config/visual_chat_agent.yaml
@@ -1,6 +1,6 @@
 chat_model:
   generator_kwargs:
-    model: microsoft/Phi-3.5-vision-instruct
+    model: meta-llama/Llama-3.2-11B-Vision-Instruct
     task: "image-to-text"
     generation_kwargs:
       max_new_tokens: 300
@@ -8,14 +8,13 @@ chat_model:
     huggingface_pipeline_kwargs:
       torch_dtype: torch.bfloat16
       trust_remote_code: true
-      _attn_implementation: eager
-  generator_class: fastrag.generators.llava.Phi35VisionHFGenerator
+  generator_class: fastrag.generators.llava.LlamaHFGenerator
 tools:
   - type: doc_with_image
     params:
       name: "docRetriever"
       description: 'useful for when you need to retrieve images and text to answer questions. Use the following format: {{ "input": [your tool input here ] }}.'
-      pipeline_or_yaml_file: "config/empty_retrieval_pipeline.yaml"
+      pipeline_or_yaml_file: "config/empty_ret_llama_32.yaml"
 system_tools:
   - type: doc_with_image_index_from_provider
     params:
diff --git a/fastrag/agents/base.py b/fastrag/agents/base.py
index 28e2ce8..81caba0 100644
--- a/fastrag/agents/base.py
+++ b/fastrag/agents/base.py
@@ -421,3 +421,7 @@ def create_agent_step(self, max_steps: Optional[int] = None) -> AgentStep:
         return AgentStep(
             max_steps=max_steps or self.max_steps, final_answer_pattern=self.final_answer_pattern
         )
+
+    def clear(self):
+        self.memory.clear()
+        self.tm.clear_tool_history()
diff --git a/fastrag/agents/create_agent.py b/fastrag/agents/create_agent.py
index 67ce071..a8102a1 100644
--- a/fastrag/agents/create_agent.py
+++ b/fastrag/agents/create_agent.py
@@ -82,7 +82,7 @@ def get_generator(chat_model_config):
     if not tokenizer.pad_token:
         tokenizer.pad_token = tokenizer.eos_token
 
-    stop_word_list = ["Observation:", "<|eot_id|>", "<|end|>"]
+    stop_word_list = ["Tool Output", "Observation:", "<|eot_id|>", "<|end|>"]
     sw = StopWordsByTextCriteria(tokenizer=tokenizer, stop_words=stop_word_list, device="cpu")
 
     if "generation_kwargs" not in chat_model_config["generator_kwargs"]:
diff --git a/fastrag/agents/tools/image_utils.py b/fastrag/agents/tools/image_utils.py
new file mode 100644
index 0000000..d2c7727
--- /dev/null
+++ b/fastrag/agents/tools/image_utils.py
@@ -0,0 +1,21 @@
+import base64
+from io import BytesIO
+
+import requests
+
+
+def get_base64_from_url(image_url):
+    if image_url.startswith("http"):
+        response = requests.get(
+            image_url,
+            headers={
+                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36",
+            },
+        )
+        buffered = BytesIO(response.content)
+        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+    else:
+        with open(image_url, "rb") as image_file:
+            img_str = base64.b64encode(image_file.read()).decode("utf-8")
+
+    return img_str
diff --git a/fastrag/agents/tools/tools.py b/fastrag/agents/tools/tools.py
index a0ca2ba..3243a9d 100644
--- a/fastrag/agents/tools/tools.py
+++ b/fastrag/agents/tools/tools.py
@@ -6,7 +6,9 @@
 from tqdm import tqdm
 
 from fastrag.agents.base import Tool
+from fastrag.agents.tools.image_utils import get_base64_from_url
 from fastrag.agents.utils import Color, load_text
+from fastrag.embedders.image_embedders import SentenceTransformersImageEmbedder
 
 
 COMPONENT_WITH_STORE = "retriever"
@@ -190,10 +192,55 @@ def __init__(
 
     def example_to_doc(self, ex):
         return Document(
-            content=ex["content"], meta={"title": ex["title"], "image_url": ex["image_url"]}
+            content=ex["content"],
+            meta={
+                "title": ex["title"],
+                "image_url": ex["image_url"],
+                "image_base64": get_base64_from_url(ex["image_url"]),
+            },
         )
 
 
+class ImageHaystackIndexTool(Tool):
+    def __init__(
+        self,
+        name: str,
+        description: str = "",
+        logging_color: Color = Color.YELLOW,
+        model_name_or_path=None,
+        tool_provider_map: Dict[str, Tool] = None,
+        tool_provider_name: str = None,
+    ):
+        # get the store from the correct tool with the document store to use
+        document_store = tool_provider_map[tool_provider_name].get_store()
+        super().__init__(
+            name=name,
+            description=description,
+            logging_color=logging_color,
+        )
+        self.doc_embedder = SentenceTransformersImageEmbedder(model=model_name_or_path)
+        self.doc_embedder.warm_up()
+
+        self.document_store = document_store
+
+    def run(self, tool_input: Union[str, List[dict]], params: Optional[dict] = None) -> str:
+        if isinstance(tool_input, str):
+            tool_input = json.loads(tool_input)
+
+        elif isinstance(tool_input, dict) and "docs" in tool_input:
+            tool_input = tool_input["docs"]
+
+        docs = [
+            Document(
+                content=element["content"],
+                meta=dict(image_base64=get_base64_from_url(element["content"])),
+            )
+            for element in tool_input
+        ]
+        docs_with_embeddings = self.doc_embedder.run(docs)
+        self.document_store.write_documents(docs_with_embeddings["documents"])
+
+
 class DocWithImageFromProvidersHaystackIndexTool(DocWithImageHaystackIndexTool):
     def __init__(
         self,
@@ -220,4 +267,5 @@ def __init__(
     "doc_with_image_index": DocWithImageHaystackIndexTool,
     "doc_with_image_index_from_provider": DocWithImageFromProvidersHaystackIndexTool,
     "doc": HaystackQueryTool,
+    "image": ImageHaystackIndexTool,
 }
diff --git a/fastrag/embedders/image_embedders.py b/fastrag/embedders/image_embedders.py
new file mode 100644
index 0000000..eb67a81
--- /dev/null
+++ b/fastrag/embedders/image_embedders.py
@@ -0,0 +1,68 @@
+import base64
+from io import BytesIO
+from typing import List
+
+import torch
+from haystack import Document, component
+from haystack.components.embedders import (
+    SentenceTransformersDocumentEmbedder,
+    SentenceTransformersTextEmbedder,
+)
+from PIL import Image
+from transformers import CLIPModel, CLIPProcessor
+
+
+def base64_to_image(base_64_input):
+    bytes_io = BytesIO(base64.b64decode(base_64_input))
+    return Image.open(bytes_io)
+
+
+class BaseSentenceTransformersImageEmbedder(SentenceTransformersDocumentEmbedder):
+    def warm_up(self):
+        """
+        Initializes the component.
+        """
+
+        # Load CLIP model
+        self.embedding_backend = CLIPModel.from_pretrained(self.model)
+        self.processor = CLIPProcessor.from_pretrained(self.model)
+
+
+@component
+class SentenceTransformersImageEmbedder(
+    BaseSentenceTransformersImageEmbedder, SentenceTransformersDocumentEmbedder
+):
+    @component.output_types(documents=List[Document])
+    def run(self, documents: List[Document]):
+        images = [base64_to_image(doc.meta["image_base64"]) for doc in documents]
+        inputs = self.processor(images=images, return_tensors="pt", padding=True)
+        with torch.no_grad():
+            embeddings = self.embedding_backend.get_image_features(**inputs)
+
+        for doc, emb in zip(documents, embeddings):
+            doc.embedding = emb
+
+        return {"documents": documents}
+
+
+@component
+class SentenceTransformersImageTextEmbedder(
+    BaseSentenceTransformersImageEmbedder, SentenceTransformersTextEmbedder
+):
+    @component.output_types(embedding=List[float])
+    def run(self, text: str):
+        """
+        Embed a single string.
+
+        :param text:
+            Text to embed.
+
+        :returns:
+            A dictionary with the following keys:
+            - `embedding`: The embedding of the input text.
+        """
+        inputs = self.processor(text=text, return_tensors="pt", padding=True)
+        with torch.no_grad():
+            text_outputs = self.embedding_backend.get_text_features(**inputs)
+
+        return {"embedding": text_outputs.tolist()[0]}
diff --git a/fastrag/generators/image_caption_generator.py b/fastrag/generators/image_caption_generator.py
new file mode 100644
index 0000000..cbafc41
--- /dev/null
+++ b/fastrag/generators/image_caption_generator.py
@@ -0,0 +1,73 @@
+import torch
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor
+from transformers.image_utils import load_image
+
+
+def get_default_device():
+    if torch.cuda.is_available():
+        return "cuda"
+    if torch.xpu.is_available():
+        return "xpu"
+    return "cpu"
+
+
+class ImageCaptionGenerator:
+    def __init__(self, model_name_or_path="microsoft/caption-large"):
+        print("Loading Image Caption model ...")
+        self.model_name_or_path = model_name_or_path
+        self.device = get_default_device()
+        self.processor = AutoProcessor.from_pretrained(model_name_or_path)
+
+    def caption(self, path):
+        raise NotImplementedError()
+
+
+class GitImageCaptionGenerator(ImageCaptionGenerator):
+    def __init__(self, model_name_or_path="microsoft/git-large-r-textcaps"):
+        super().__init__(model_name_or_path)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(self.device)
+
+    def caption(self, path):
+        image = Image.open(path)
+
+        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
+        generated_ids = self.model.generate(
+            pixel_values=pixel_values.to(self.model.device), max_length=30, do_sample=False
+        )
+        generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        return generated_caption
+
+
+class VLLMImageCaptionGenerator(ImageCaptionGenerator):
+    def __init__(self, model_name_or_path="HuggingFaceTB/SmolVLM-Instruct"):
+        super().__init__(model_name_or_path)
+
+        self.model = AutoModelForVision2Seq.from_pretrained(
+            model_name_or_path,
+            torch_dtype=torch.bfloat16,
+            _attn_implementation="flash_attention_2" if self.device == "cuda" else "eager",
+        )
+
+    def caption(self, path):
+        # Load images
+        image1 = load_image(path)
+
+        # Create input messages
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": "Can you describe the image?"},
+                ],
+            },
+        ]
+
+        # Prepare inputs
+        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
+        inputs = self.processor(text=prompt, images=[image1], return_tensors="pt")
+        inputs = inputs.to(self.model.device)
+        # Generate outputs
+        generated_ids = self.model.generate(**inputs, max_new_tokens=20, do_sample=False)
+        return self.processor.batch_decode(generated_ids[:, inputs["input_ids"].shape[1] :])[0]
diff --git a/fastrag/generators/llava.py b/fastrag/generators/llava.py
index 6070114..476b860 100644
--- a/fastrag/generators/llava.py
+++ b/fastrag/generators/llava.py
@@ -113,7 +113,13 @@ def __init__(
             del self.generation_kwargs["stopping_criteria"]
 
         self.processor = AutoProcessor.from_pretrained(model)
-        self.image_token = "<image>"
+        image_tokens = [
+            v
+            for v in self.processor.tokenizer.added_tokens_decoder.values()
+            if "image" in v.content
+        ]
+        assert len(image_tokens) > 0, "No image token found in the tokenizer"
+        self.image_token = image_tokens[0].content
 
     @component.output_types(replies=List[str])
     def run(
@@ -208,6 +214,92 @@ def get_user_text(self, chat_snippet):
         return user_text
 
 
+class LlamaHFGenerator(LlavaHFGenerator):
+    @component.output_types(replies=List[str])
+    def run(
+        self,
+        prompt: str,
+        images: List[str] = None,
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Run the text generation model on the given prompt.
+
+        :param prompt:
+            A string representing the prompt.
+        :param images:
+            A list of base64 strings representing the input images.
+        :param generation_kwargs:
+            Additional keyword arguments for text generation.
+
+        :returns:
+            A dictionary containing the generated replies.
+            - replies: A list of strings representing the generated replies.
+            - raw_images: A list of the raw PIL images.
+        """
+        if self.pipeline is None:
+            raise RuntimeError(
+                "The generation model has not been loaded. Please call warm_up() before running."
+            )
+
+        if not prompt:
+            return {"replies": []}
+
+        # merge generation kwargs from init method with those from run method
+        updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+
+        raw_images = None
+        if images:
+            raw_images = [base64_to_image(img) for img in images]
+
+            present_image_token_count = prompt.count(self.image_token)
+            image_token_count_diff = len(images) - present_image_token_count
+
+            # check if we need to add additional image tokens
+            if image_token_count_diff > 0:
+                image_token_full_str = " ".join(
+                    [self.image_token for _ in range(image_token_count_diff)]
+                )
+                prompt = f"Current Images: {image_token_full_str}\n" + prompt
+
+            print(f"USING {len(raw_images)=}!!!!!!!!")
+            inputs = self.processor(
+                images=raw_images, text=prompt, add_special_tokens=False, return_tensors="pt"
+            )
+        else:
+            inputs = self.processor(
+                images=None, text=prompt, add_special_tokens=False, return_tensors="pt"
+            )
+
+        updated_generation_kwargs["max_length"] = updated_generation_kwargs.get("max_length", 32000)
+        print(f"{prompt=}")
+        stop_strings = self.stopping_criteria_list[0].stop_words_text
+        output = self.pipeline.model.generate(
+            **inputs,
+            stop_strings=stop_strings,
+            tokenizer=self.processor.tokenizer,
+            **updated_generation_kwargs,
+        )
+
+        replies = self.processor.batch_decode(
+            output[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True
+        )
+
+        gen_stop_words = self.get_stop_words_from_kwargs()
+
+        stop_words = None
+        stop_words = gen_stop_words if gen_stop_words else self.stop_words
+        if stop_words:
+            # the output of the pipeline includes the stop word
+            replies = [
+                reply.replace(stop_word, "").rstrip()
+                for reply in replies
+                for stop_word in stop_words
+            ]
+
+        return {"replies": replies, "raw_images": raw_images}
+
+
 class Phi35VisionHFGenerator(HuggingFaceLocalGenerator):
     """
     Generator based on a Phi35 Hugging Face model loaded.
diff --git a/fastrag/prompt_builders/multi_modal_prompt_builder.py b/fastrag/prompt_builders/multi_modal_prompt_builder.py
index 3f17528..6ccada2 100644
--- a/fastrag/prompt_builders/multi_modal_prompt_builder.py
+++ b/fastrag/prompt_builders/multi_modal_prompt_builder.py
@@ -31,22 +31,10 @@ def run(self, **kwargs):
             - `prompt`: The updated prompt text after rendering the prompt template.
         """
         prompt_dict = {"prompt": self.template.render(kwargs)}
-        prompt_dict["images"] = [
-            self.get_base64_from_url(doc.meta["image_url"]) for doc in kwargs["documents"]
-        ]
-        return prompt_dict
+        prompt_dict["images"] = [doc.meta["image_base64"] for doc in kwargs["documents"]]
 
-    def get_base64_from_url(self, image_url):
-        response = requests.get(
-            image_url,
-            headers={
-                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36",
-            },
-        )
-        buffered = BytesIO(response.content)
-        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
-        return img_str
+        return prompt_dict
 
 
 component._component(MultiModalPromptBuilder)
diff --git a/fastrag/ui/chainlit_multi_modal_agent.py b/fastrag/ui/chainlit_multi_modal_agent.py
index 9c1942e..4390781 100644
--- a/fastrag/ui/chainlit_multi_modal_agent.py
+++ b/fastrag/ui/chainlit_multi_modal_agent.py
@@ -45,12 +45,12 @@
 @cl.on_chat_end
 def chat_end():
     # clear memory
-    agent.memory.clear()
+    agent.clear()
 
 
 def add_images_to_message(additional_params):
     image_elements = []
-    last_image = [additional_params["images"][-1]]
+    last_image = additional_params["images"]
     for image_base64_index, image_base64 in enumerate(last_image):
         image_uuid = str(uuid.uuid4())
         bytes_io = BytesIO(base64.b64decode(image_base64))
@@ -75,6 +75,16 @@ def parse_element(element, params):
     if "json" in element.mime:
         params["docs"] = json.load(open(element.path, "r"))
 
+    if any([image_suffix in element.mime for image_suffix in ["png", "jpeg", "jpg"]]):
+        if "docs" not in params:
+            params["docs"] = []
+
+        image_element = {
+            "content": element.path,
+        }
+
+        params["docs"].append(image_element)
+
 
 # params for the agent
 params = {}
@@ -95,8 +105,12 @@ def parse_element(element, params):
     answer = agent_result["answers"][0].answer
 
     # display retrieved image, if exists
-    additional_params = agent.memory.get_additional_params()
-    if "images" in additional_params and len(additional_params["images"]) > 0:
+    additional_params = agent.memory.list[-1].get("additional_params", None)
+    if (
+        additional_params
+        and "images" in additional_params
+        and len(additional_params["images"]) > 0
+    ):
         image_elements = add_images_to_message(additional_params)
 
     _ = await cl.Message(
diff --git a/setup.cfg b/setup.cfg
index c0790e1..a852d67 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -19,6 +19,7 @@ install_requires =
     chainlit>=1.0.506
     sentence-transformers>=2.3.0
     events
+    pydantic==2.10.1
 
 [options.extras_require]
 dev =
diff --git a/setup.py b/setup.py
index 9636edc..30f8e29 100644
--- a/setup.py
+++ b/setup.py
@@ -30,7 +30,7 @@ def get_version(rel_path):
     long_description_content_type="text/markdown",
     url="https://github.com/IntelLabs/fastRAG",
     license="Apache-2.0",
-    python_requires=">=3.8, <3.12",
+    python_requires=">=3.8",
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: Apache Software License",