Skip to content

Commit

Permalink
[DEVX-419]: Adding ID support for labels (#338)
Browse files Browse the repository at this point in the history
* ID_support_for_labels

* clarifai_extra_requires
  • Loading branch information
sanjaychelliah committed Apr 25, 2024
1 parent 008be90 commit 88fabfa
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 13 deletions.
33 changes: 26 additions & 7 deletions clarifai/client/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def _get_proto(input_id: str,
text_pb: Text = None,
geo_info: List = None,
labels: List = None,
label_ids: List = None,
metadata: Struct = None) -> Input:
"""Create input proto for image data type.
Args:
Expand All @@ -82,22 +83,35 @@ def _get_proto(input_id: str,
audio_pb (Audio): The audio proto to be used for the input.
text_pb (Text): The text proto to be used for the input.
geo_info (list): A list of longitude and latitude for the geo point.
labels (list): A list of labels for the input.
labels (list): A list of label names for the input.
label_ids (list): A list of label ids for the input.
metadata (Struct): A Struct of metadata for the input.
Returns:
Input: An Input object for the specified input ID.
"""
assert geo_info is None or isinstance(
geo_info, list), "geo_info must be a list of longitude and latitude"
assert labels is None or isinstance(labels, list), "labels must be a list of strings"
assert label_ids is None or isinstance(label_ids, list), "label_ids must be a list of strings"
assert metadata is None or isinstance(metadata, Struct), "metadata must be a Struct"
geo_pb = resources_pb2.Geo(geo_point=resources_pb2.GeoPoint(
longitude=geo_info[0], latitude=geo_info[1])) if geo_info else None
concepts=[
if labels:
if not label_ids:
concepts=[
resources_pb2.Concept(
id=f"id-{''.join(_label.split(' '))}", name=_label, value=1.)\
for _label in labels
]if labels else None
]
else:
assert len(labels) == len(label_ids), "labels and label_ids must be of the same length"
concepts=[
resources_pb2.Concept(
id=label_id, name=_label, value=1.)\
for label_id, _label in zip(label_ids, labels)
]
else:
concepts = None

if dataset_id:
return resources_pb2.Input(
Expand Down Expand Up @@ -467,13 +481,14 @@ def get_text_inputs_from_folder(folder_path: str, dataset_id: str = None,
return input_protos

@staticmethod
def get_bbox_proto(input_id: str, label: str, bbox: List) -> Annotation:
def get_bbox_proto(input_id: str, label: str, bbox: List, label_id: str = None) -> Annotation:
"""Create an annotation proto for each bounding box, label input pair.
Args:
input_id (str): The input ID for the annotation to create.
label (str): annotation label
label (str): annotation label name
bbox (List): a list of a single bbox's coordinates. # bbox ordering: [xmin, ymin, xmax, ymax]
label_id (str): annotation label ID
Returns:
An annotation object for the specified input ID.
Expand All @@ -500,19 +515,22 @@ def get_bbox_proto(input_id: str, label: str, bbox: List) -> Annotation:
data=resources_pb2.Data(concepts=[
resources_pb2.Concept(
id=f"id-{''.join(label.split(' '))}", name=label, value=1.)
if not label_id else resources_pb2.Concept(id=label_id, name=label, value=1.)
]))
]))

return input_annot_proto

@staticmethod
def get_mask_proto(input_id: str, label: str, polygons: List[List[float]]) -> Annotation:
def get_mask_proto(input_id: str, label: str, polygons: List[List[float]],
label_id: str = None) -> Annotation:
"""Create an annotation proto for each polygon box, label input pair.
Args:
input_id (str): The input ID for the annotation to create.
label (str): annotation label
label (str): annotation label name
polygons (List): Polygon x,y points iterable
label_id (str): annotation label ID
Returns:
An annotation object for the specified input ID.
Expand All @@ -537,6 +555,7 @@ def get_mask_proto(input_id: str, label: str, polygons: List[List[float]]) -> An
data=resources_pb2.Data(concepts=[
resources_pb2.Concept(
id=f"id-{''.join(label.split(' '))}", name=label, value=1.)
if not label_id else resources_pb2.Concept(id=label_id, name=label, value=1.)
]))
]))

Expand Down
4 changes: 4 additions & 0 deletions clarifai/datasets/upload/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class TextFeatures:
labels: List[Union[str, int]] # List[str or int] to cater for multi-class tasks
id: Optional[int] = None # text_id
metadata: Optional[dict] = None
label_ids: Optional[List[str]] = None


@dataclass
Expand All @@ -21,6 +22,7 @@ class VisualClassificationFeatures:
id: Optional[int] = None # image_id
metadata: Optional[dict] = None
image_bytes: Optional[bytes] = None
label_ids: Optional[List[str]] = None


@dataclass
Expand All @@ -33,6 +35,7 @@ class VisualDetectionFeatures:
id: Optional[int] = None # image_id
metadata: Optional[dict] = None
image_bytes: Optional[bytes] = None
label_ids: Optional[List[str]] = None


@dataclass
Expand All @@ -45,3 +48,4 @@ class VisualSegmentationFeatures:
id: Optional[int] = None # image_id
metadata: Optional[dict] = None
image_bytes: Optional[bytes] = None
label_ids: Optional[List[str]] = None
27 changes: 25 additions & 2 deletions clarifai/datasets/upload/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def process_data_item(id):
image_path = data_item.image_path
labels = data_item.labels if isinstance(data_item.labels,
list) else [data_item.labels] # clarifai concept
label_ids = data_item.label_ids
input_id = f"{self.dataset_id}-{uuid.uuid4().hex[:8]}" if data_item.id is None else f"{self.dataset_id}-{str(data_item.id)}"
geo_info = data_item.geo_info
if data_item.metadata is not None:
Expand All @@ -49,6 +50,7 @@ def process_data_item(id):
image_bytes=data_item.image_bytes,
dataset_id=self.dataset_id,
labels=labels,
label_ids=label_ids,
geo_info=geo_info,
metadata=metadata))
else:
Expand All @@ -58,6 +60,7 @@ def process_data_item(id):
image_file=image_path,
dataset_id=self.dataset_id,
labels=labels,
label_ids=label_ids,
geo_info=geo_info,
metadata=metadata))

Expand Down Expand Up @@ -91,6 +94,12 @@ def process_data_item(id):
metadata = Struct()
image = data_item.image_path
labels = data_item.labels # list:[l1,...,ln]
if data_item.label_ids is not None:
assert len(labels) == len(
data_item.label_ids), "Length of labels and label_ids must be equal"
label_ids = data_item.label_ids
else:
label_ids = None
bboxes = data_item.bboxes # [[xmin,ymin,xmax,ymax],...,[xmin,ymin,xmax,ymax]]
input_id = f"{self.dataset_id}-{uuid.uuid4().hex[:8]}" if data_item.id is None else f"{self.dataset_id}-{str(data_item.id)}"
if data_item.metadata is not None:
Expand Down Expand Up @@ -120,7 +129,11 @@ def process_data_item(id):
# one id could have more than one bbox and label
for i in range(len(bboxes)):
annotation_protos.append(
Inputs.get_bbox_proto(input_id=input_id, label=labels[i], bbox=bboxes[i]))
Inputs.get_bbox_proto(
input_id=input_id,
label=labels[i],
bbox=bboxes[i],
label_id=label_ids[i] if label_ids else None))

with ThreadPoolExecutor(max_workers=4) as executor:
futures = [executor.submit(process_data_item, id) for id in batch_input_ids]
Expand Down Expand Up @@ -152,6 +165,12 @@ def process_data_item(id):
metadata = Struct()
image = data_item.image_path
labels = data_item.labels
if data_item.label_ids is not None:
assert len(labels) == len(
data_item.label_ids), "Length of labels and label_ids must be equal"
label_ids = data_item.label_ids
else:
label_ids = None
_polygons = data_item.polygons # list of polygons: [[[x,y],...,[x,y]],...]
input_id = f"{self.dataset_id}-{uuid.uuid4().hex[:8]}" if data_item.id is None else f"{self.dataset_id}-{str(data_item.id)}"
if data_item.metadata is not None:
Expand Down Expand Up @@ -183,7 +202,11 @@ def process_data_item(id):
for i, _polygon in enumerate(_polygons):
try:
annotation_protos.append(
Inputs.get_mask_proto(input_id=input_id, label=labels[i], polygons=_polygon))
Inputs.get_mask_proto(
input_id=input_id,
label=labels[i],
polygons=_polygon,
label_id=label_ids[i] if label_ids else None))
except IndexError:
continue

Expand Down
9 changes: 7 additions & 2 deletions clarifai/datasets/upload/loaders/coco_captions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@

import os

from pycocotools.coco import COCO

from clarifai.datasets.upload.base import ClarifaiDataLoader

from ..features import VisualClassificationFeatures

# pycocotools is an optional dependency (installed via the `clarifai[all]` extra) required by this loader
try:
from pycocotools.coco import COCO
except ImportError:
raise ImportError("Could not import pycocotools package. "
"Please do `pip install 'clarifai[all]'` to import pycocotools.")


class COCOCaptionsDataLoader(ClarifaiDataLoader):
"""COCO Image Captioning Dataset."""
Expand Down
9 changes: 7 additions & 2 deletions clarifai/datasets/upload/loaders/coco_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@

import os

from pycocotools.coco import COCO

from ..base import ClarifaiDataLoader

from ..features import VisualDetectionFeatures

# pycocotools is an optional dependency (installed via the `clarifai[all]` extra) required by this loader
try:
from pycocotools.coco import COCO
except ImportError:
raise ImportError("Could not import pycocotools package. "
"Please do `pip install 'clarifai[all]'` to import pycocotools.")


class COCODetectionDataLoader(ClarifaiDataLoader):

Expand Down
2 changes: 2 additions & 0 deletions clarifai/datasets/upload/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def process_data_item(id):
text = data_item.text
labels = data_item.labels if isinstance(data_item.labels,
list) else [data_item.labels] # clarifai concept
label_ids = data_item.label_ids
input_id = f"{self.dataset_id}-{id}" if data_item.id is None else f"{self.dataset_id}-{str(data_item.id)}"
if data_item.metadata is not None:
metadata.update(data_item.metadata)
Expand All @@ -43,6 +44,7 @@ def process_data_item(id):
raw_text=text,
dataset_id=self.dataset_id,
labels=labels,
label_ids=label_ids,
metadata=metadata))

with ThreadPoolExecutor(max_workers=4) as executor:
Expand Down

0 comments on commit 88fabfa

Please sign in to comment.