Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ coverage.xml
.hypothesis/
.pytest_cache/

testing/*
!testing/test.ipynb

# Translations
*.mo
*.pot
Expand Down
8 changes: 6 additions & 2 deletions clarifai_datautils/multimodal/pipeline/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __getitem__(self, index: int):
meta.pop('coordinates', None)
meta.pop('detection_class_prob', None)
image_data = meta.pop('image_base64', None)
id = meta.get('input_id', None)
Comment on lines 29 to +30

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: why are we adding this new field ID and will it be used?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is to identify the corresponding summary that was generated for the images in the PDF!

if image_data is not None:
# Ensure image_data is already bytes before encoding
image_data = base64.b64decode(image_data)
Expand All @@ -39,7 +40,7 @@ def __getitem__(self, index: int):
meta['type'] = 'table'

return MultiModalFeatures(
text=text, image_bytes=image_data, labels=[self.pipeline_name], metadata=meta)
text=text, image_bytes=image_data, labels=[self.pipeline_name], metadata=meta, id=id)

def __len__(self):
return len(self.elements)
Expand All @@ -61,10 +62,13 @@ def task(self):
return DATASET_UPLOAD_TASKS.TEXT_CLASSIFICATION #TODO: Better dataset name in SDK

def __getitem__(self, index: int):
id = self.elements[index].to_dict().get('element_id', None)
id = id[:48] if id is not None else None
return TextFeatures(
text=self.elements[index].text,
labels=self.pipeline_name,
metadata=self.elements[index].metadata.to_dict())
metadata=self.elements[index].metadata.to_dict(),
id=id)

def __len__(self):
return len(self.elements)
102 changes: 102 additions & 0 deletions clarifai_datautils/multimodal/pipeline/summarizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import base64
import random
from typing import List

try:
from unstructured.documents.elements import CompositeElement, ElementMetadata, Image
except ImportError:
raise ImportError(
"Could not import unstructured package. "
"Please install it with `pip install 'unstructured[pdf] @ git+https://github.com/clarifai/unstructured.git@support_clarifai_model'`."
)

from clarifai.client.input import Inputs
from clarifai.client.model import Model

from .basetransform import BaseTransform

SUMMARY_PROMPT = """You are an assistant tasked with summarizing images for retrieval. \
These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval."""


class ImageSummarizer(BaseTransform):
""" Summarizes image elements. """

def __init__(self,
model_url: str = "https://clarifai.com/qwen/qwen-VL/models/qwen-VL-Chat",
pat: str = None,
prompt: str = SUMMARY_PROMPT):
"""Initializes an ImageSummarizer object.

Args:
pat (str): Clarifai PAT.
model_url (str): Model URL to use for summarization.
prompt (str): Prompt to use for summarization.
"""
self.pat = pat
self.model_url = model_url
self.model = Model(url=model_url, pat=pat)
self.summary_prompt = prompt

def __call__(self, elements: List) -> List:
"""Applies the transformation.

Args:
elements (List[str]): List of all elements.

Returns:
List of transformed elements along with added summarized elements.

"""
img_elements = []
for _, element in enumerate(elements):
element.metadata.update(ElementMetadata.from_dict({'is_original': True}))
if isinstance(element, Image):
element.metadata.update(
ElementMetadata.from_dict({
'input_id': f'{random.randint(1000000, 99999999)}'
}))
img_elements.append(element)
new_elements = self._summarize_image(img_elements)
elements.extend(new_elements)
return elements

def _summarize_image(self, image_elements: List[Image]) -> List[CompositeElement]:
"""Summarizes an image element.

Args:
image_elements (List[Image]): Image elements to summarize.

Returns:
Summarized image elements list.

"""
img_inputs = []
for element in image_elements:
if not isinstance(element, Image):
continue
new_input_id = "summarize_" + element.metadata.input_id
input_proto = Inputs.get_multimodal_input(
input_id=new_input_id,
image_bytes=base64.b64decode(element.metadata.image_base64),
raw_text=self.summary_prompt)
img_inputs.append(input_proto)
resp = self.model.predict(img_inputs)
del img_inputs

new_elements = []
for i, output in enumerate(resp.outputs):
summary = ""
if image_elements[i].text:
Copy link
Contributor

@sanjaychelliah sanjaychelliah Nov 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe image elements will not have text, so why this check here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I observed that some image elements had text too... it can be seen in the output of 9th cell in this notebook

summary = image_elements[i].text
summary = summary + " \n " + output.data.text.raw
eid = image_elements[i].metadata.input_id
meta_dict = {'source_input_id': eid, 'is_original': False}
comp_element = CompositeElement(
text=summary,
metadata=ElementMetadata.from_dict(meta_dict),
element_id="summarized_" + eid)
new_elements.append(comp_element)

return new_elements
425 changes: 425 additions & 0 deletions testing/test.ipynb

Large diffs are not rendered by default.

32 changes: 31 additions & 1 deletion tests/pipelines/test_multimodal_pipelines.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os.path as osp

import pytest

PDF_FILE_PATH = osp.abspath(
Expand Down Expand Up @@ -66,3 +65,34 @@ def test_pipeline_run_loader(self,):
assert elements.__class__.__name__ == 'MultiModalLoader'
assert len(elements) == 14
assert elements.elements[0].metadata.to_dict()['filename'] == 'Multimodal_sample_file.pdf'

def test_pipeline_summarize(self,):
"""Tests for pipeline run with summarizer"""
import os

from clarifai_datautils.multimodal import Pipeline
from clarifai_datautils.multimodal.pipeline.cleaners import Clean_extra_whitespace
from clarifai_datautils.multimodal.pipeline.PDF import PDFPartitionMultimodal
from clarifai_datautils.multimodal.pipeline.summarizer import ImageSummarizer

pipeline = Pipeline(
name='pipeline-1',
transformations=[
PDFPartitionMultimodal(chunking_strategy="by_title", max_characters=1024),
Clean_extra_whitespace(),
ImageSummarizer(pat=os.environ.get("CLARIFAI_PAT"))
])
elements = pipeline.run(files=PDF_FILE_PATH, loader=False)

assert len(elements) == 17
assert isinstance(elements, list)
assert elements[0].metadata.to_dict()['filename'] == 'Multimodal_sample_file.pdf'
assert elements[0].metadata.to_dict()['page_number'] == 1
assert elements[6].__class__.__name__ == 'Table'
assert elements[-3].__class__.__name__ == 'Image'
assert elements[-3].metadata.is_original is True
assert elements[-3].metadata.input_id is not None
id = elements[-3].metadata.input_id
assert elements[-1].__class__.__name__ == 'CompositeElement'
assert elements[-1].metadata.is_original is False
assert elements[-1].metadata.source_input_id == id
Loading