-
Notifications
You must be signed in to change notification settings - Fork 0
[DEVX-828] Added image summarization in multimodal pipeline #31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
eb8d6e9
2979c62
af46344
d52f5a6
fe0d75e
f7dd88a
09c977c
c0cbfb4
9203227
38ccc91
e3a9f13
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,6 +51,9 @@ coverage.xml | |
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
testing/* | ||
!testing/test.ipynb | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import base64
import random
from typing import List

# `unstructured` is an optional, forked dependency; fail fast at import time
# with an actionable install hint rather than at first use.
try:
    from unstructured.documents.elements import CompositeElement, ElementMetadata, Image
except ImportError as e:
    # Chain the original error (`from e`) so the real import failure stays visible.
    raise ImportError(
        "Could not import unstructured package. "
        "Please install it with `pip install 'unstructured[pdf] @ git+https://github.com/clarifai/unstructured.git@support_clarifai_model'`."
    ) from e

from clarifai.client.input import Inputs
from clarifai.client.model import Model

from .basetransform import BaseTransform

# Default instruction sent with every image; tuned for retrieval-oriented summaries.
SUMMARY_PROMPT = """You are an assistant tasked with summarizing images for retrieval. \
These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval."""
||
class ImageSummarizer(BaseTransform):
    """Summarizes image elements with a Clarifai multimodal model.

    Tags every element with ``is_original: True``, assigns each ``Image``
    element a random ``input_id``, and appends one ``CompositeElement``
    summary per image (linked back via ``source_input_id``).
    """

    def __init__(self,
                 model_url: str = "https://clarifai.com/qwen/qwen-VL/models/qwen-VL-Chat",
                 pat: str = None,
                 prompt: str = SUMMARY_PROMPT):
        """Initializes an ImageSummarizer object.

        Args:
            model_url (str): Model URL to use for summarization.
            pat (str): Clarifai PAT; may be None if configured elsewhere.
            prompt (str): Prompt to use for summarization.
        """
        self.pat = pat
        self.model_url = model_url
        self.model = Model(url=model_url, pat=pat)
        self.summary_prompt = prompt

    def __call__(self, elements: List) -> List:
        """Applies the transformation.

        Args:
            elements (List): List of all elements.

        Returns:
            List of transformed elements along with added summarized elements.
        """
        img_elements = []
        for element in elements:
            # Mark every incoming element as an original (not a generated summary).
            element.metadata.update(ElementMetadata.from_dict({'is_original': True}))
            if isinstance(element, Image):
                # Assign a random id so the generated summary can reference this image.
                # NOTE(review): random ids can collide; a uuid would be safer — confirm.
                element.metadata.update(
                    ElementMetadata.from_dict({
                        'input_id': f'{random.randint(1000000, 99999999)}'
                    }))
                img_elements.append(element)
        elements.extend(self._summarize_image(img_elements))
        return elements

    def _summarize_image(self, image_elements: List[Image]) -> List[CompositeElement]:
        """Summarizes image elements via one batched model prediction.

        Args:
            image_elements (List[Image]): Image elements to summarize.

        Returns:
            Summarized image elements list (empty if there are no images).
        """
        img_inputs = []
        # Track only the elements we actually submit, so model outputs stay
        # aligned with their source elements even if a non-Image is skipped.
        summarized_elements = []
        for element in image_elements:
            if not isinstance(element, Image):
                continue
            new_input_id = "summarize_" + element.metadata.input_id
            input_proto = Inputs.get_multimodal_input(
                input_id=new_input_id,
                image_bytes=base64.b64decode(element.metadata.image_base64),
                raw_text=self.summary_prompt)
            img_inputs.append(input_proto)
            summarized_elements.append(element)

        # Avoid an empty (wasted) model call when there are no images.
        if not img_inputs:
            return []

        resp = self.model.predict(img_inputs)
        # Release the potentially large image payloads before building results.
        del img_inputs

        new_elements = []
        for element, output in zip(summarized_elements, resp.outputs):
            summary = ""
            # Some image elements carry extracted text; prepend it to the summary.
            if element.text:
                summary = element.text
            summary = summary + " \n " + output.data.text.raw
            eid = element.metadata.input_id
            # Link the summary to its source image and flag it as generated.
            meta_dict = {'source_input_id': eid, 'is_original': False}
            comp_element = CompositeElement(
                text=summary,
                metadata=ElementMetadata.from_dict(meta_dict),
                element_id="summarized_" + eid)
            new_elements.append(comp_element)

        return new_elements
Large diffs are not rendered by default.
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
QQ: why are we adding this new field ID and will it be used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is to identify the corresponding summary that was generated for the images in the PDF!