perf: improve log from pipeline monitors (#1905)
> there are some code coverage notifications, @frascuchon

Let's forget them this time. Thanks for the warning.
frascuchon committed Nov 17, 2022
1 parent 4621ecf commit 8aec980
Showing 10 changed files with 335 additions and 90 deletions.
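In short, this commit swaps the monitors' per-call `log_async` (which kept a private asyncio event loop alive) for a per-dataset `DatasetRecordsConsumer` background thread that batches queued records and retries failed uploads with exponential backoff; the monitor factories now take the `Api` client and a `log_interval` flush period explicitly. A minimal usage sketch against the new signatures — `active_api()` and the model name are assumptions, not part of this diff:

from transformers import pipeline

from argilla.client.api import active_api
from argilla.monitoring._transformers import huggingface_monitor

sentiment = pipeline("sentiment-analysis")
monitored = huggingface_monitor(
    sentiment,
    api=active_api(),          # assumed accessor for the active Api client
    dataset="sentiment-monitoring",
    sample_rate=0.2,           # sample ~20% of predictions
    log_interval=5.0,          # flush queued records roughly every 5 seconds
)

monitored("I love this movie!")  # records are queued in the background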
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -38,7 +38,10 @@ dependencies = [
     "wrapt ~= 1.13.0",
     # weaksupervision
     "numpy",
-    "tqdm >= 4.27.0"
+    "tqdm >= 4.27.0",
+    # monitor background consumers
+    "backoff",
+    "monotonic"
 
 ]
 dynamic = ["version"]
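The two new dependencies serve the background consumer introduced in base.py below: `backoff` provides retry decorators, and `monotonic` a clock that never goes backwards, so upload intervals survive wall-clock adjustments. A tiny sketch of both; `flaky_upload` is hypothetical:

import backoff
import monotonic

@backoff.on_exception(backoff.expo, ConnectionError, max_tries=3)
def flaky_upload():
    """Hypothetical call retried with exponential backoff on ConnectionError."""
    ...

start = monotonic.monotonic()
flaky_upload()
elapsed = monotonic.monotonic() - start  # immune to system clock changes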
17 changes: 13 additions & 4 deletions src/argilla/monitoring/_flair.py
@@ -16,7 +16,7 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 from argilla import TokenClassificationRecord
-from argilla.client.models import BulkResponse
+from argilla.client.api import Api
 from argilla.monitoring.base import BaseMonitor
 from argilla.monitoring.types import MissingType
 
@@ -32,7 +32,8 @@
 
 class FlairMonitor(BaseMonitor):
     def _prepare_log_data(
-        self, data: List[Tuple[Sentence, Dict[str, Any]]]
+        self,
+        data: List[Tuple[Sentence, Dict[str, Any]]],
     ) -> Dict[str, Any]:
         return dict(
             records=[
@@ -74,15 +75,23 @@ def predict(self, sentences: Union[List[Sentence], Sentence], *args, **kwargs):
             if self.is_record_accepted()
         ]
         if filtered_data:
-            self._log_future = self.log_async(filtered_data)
+            self._log_future = self.send_records(filtered_data)
 
         return result
 
 
 def flair_monitor(
     pl: SequenceTagger,
+    api: Api,
     dataset: str,
     sample_rate: float,
+    log_interval: float,
 ) -> Optional[SequenceTagger]:
-
-    return FlairMonitor(pl, dataset=dataset, sample_rate=sample_rate)
+    return FlairMonitor(
+        pl,
+        api=api,
+        dataset=dataset,
+        sample_rate=sample_rate,
+        log_interval=log_interval,
+    )
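A hedged sketch of wrapping a Flair tagger with the updated factory; the model name and `active_api()` are assumptions:

from flair.data import Sentence
from flair.models import SequenceTagger

from argilla.client.api import active_api
from argilla.monitoring._flair import flair_monitor

tagger = SequenceTagger.load("flair/ner-english")
tagger = flair_monitor(
    tagger,
    api=active_api(),
    dataset="flair-ner-monitoring",
    sample_rate=1.0,
    log_interval=2.0,
)

tagger.predict(Sentence("George Washington went to Washington."))  # queued, not logged inline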
25 changes: 20 additions & 5 deletions src/argilla/monitoring/_spacy.py
@@ -16,6 +16,7 @@
 from typing import Any, Dict, Optional, Tuple
 
 from argilla import TokenClassificationRecord
+from argilla.client.api import Api
 from argilla.monitoring.base import BaseMonitor
 from argilla.monitoring.types import MissingType
 
@@ -33,7 +34,9 @@ class SpacyNERMonitor(BaseMonitor):
 
     @staticmethod
     def doc2token_classification(
-        doc: Doc, agent: str, metadata: Optional[Dict[str, Any]]
+        doc: Doc,
+        agent: str,
+        metadata: Optional[Dict[str, Any]],
     ) -> TokenClassificationRecord:
         """
         Converts a spaCy `Doc` into a token classification record
@@ -89,17 +92,29 @@ def pipe(self, *args, **kwargs):
             log_info.append((doc, metadata))
             yield r
 
-        self.log_async(log_info)
+        self.send_records(log_info)
 
     def __call__(self, *args, **kwargs):
         metadata = kwargs.pop("metadata", None)
         doc = self.__wrapped__(*args, **kwargs)
         try:
             if self.is_record_accepted():
-                self.log_async([(doc, metadata)])
+                self.send_records([(doc, metadata)])
         finally:
             return doc
 
 
-def ner_monitor(nlp: Language, dataset: str, sample_rate: float) -> Language:
-    return SpacyNERMonitor(nlp, dataset=dataset, sample_rate=sample_rate)
+def ner_monitor(
+    nlp: Language,
+    api: Api,
+    dataset: str,
+    sample_rate: float,
+    log_interval: float,
+) -> Language:
+    return SpacyNERMonitor(
+        nlp,
+        api=api,
+        dataset=dataset,
+        sample_rate=sample_rate,
+        log_interval=log_interval,
+    )
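Note in the `__call__` above that the spaCy wrapper pops a `metadata` keyword before delegating to the wrapped pipeline, so extra context can ride along with each record. A sketch under the same assumptions as before (`active_api()`, model name):

import spacy

from argilla.client.api import active_api
from argilla.monitoring._spacy import ner_monitor

nlp = spacy.load("en_core_web_sm")
nlp = ner_monitor(
    nlp,
    api=active_api(),
    dataset="spacy-ner-monitoring",
    sample_rate=0.5,
    log_interval=1.0,
)

# `metadata` is consumed by the monitor and attached to the logged record
doc = nlp("Berlin is a city in Germany.", metadata={"source": "unit-test"})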
27 changes: 22 additions & 5 deletions src/argilla/monitoring/_transformers.py
@@ -18,6 +18,7 @@
 from pydantic import BaseModel
 
 from argilla import TextClassificationRecord
+from argilla.client.api import Api
 from argilla.monitoring.base import BaseMonitor
 from argilla.monitoring.types import MissingType
 
@@ -138,7 +139,7 @@ def __call__(
                 if self.is_record_accepted()
             ]
             if filtered_data:
-                self.log_async(filtered_data, multi_label=multi_label)
+                self.send_records(filtered_data, multi_label=multi_label)
 
         finally:
             return batch_predictions
@@ -170,17 +171,33 @@ def __call__(self, inputs, *args, **kwargs):
                 if self.is_record_accepted()
            ]
            if filtered_data:
-                self.log_async(filtered_data)
+                self.send_records(filtered_data)
 
         finally:
             return batch_predictions
 
 
 def huggingface_monitor(
-    pl: Pipeline, dataset: str, sample_rate: float
+    pl: Pipeline,
+    api: Api,
+    dataset: str,
+    sample_rate: float,
+    log_interval: float,
 ) -> Optional[Pipeline]:
     if isinstance(pl, TextClassificationPipeline):
-        return TextClassificationMonitor(pl, dataset=dataset, sample_rate=sample_rate)
+        return TextClassificationMonitor(
+            pl,
+            api=api,
+            dataset=dataset,
+            sample_rate=sample_rate,
+            log_interval=log_interval,
+        )
     if isinstance(pl, ZeroShotClassificationPipeline):
-        return ZeroShotMonitor(pl, dataset=dataset, sample_rate=sample_rate)
+        return ZeroShotMonitor(
+            pl,
+            api=api,
+            dataset=dataset,
+            sample_rate=sample_rate,
+            log_interval=log_interval,
+        )
     return None
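Since `huggingface_monitor` only recognizes text classification and zero-shot pipelines and returns `None` otherwise, callers may want an explicit fallback; a short sketch (`active_api()` assumed as before):

from transformers import pipeline

from argilla.client.api import active_api
from argilla.monitoring._transformers import huggingface_monitor

pl = pipeline("zero-shot-classification")
monitored = huggingface_monitor(
    pl,
    api=active_api(),
    dataset="zero-shot-monitoring",
    sample_rate=1.0,
    log_interval=1.0,
)
pl = monitored or pl  # unsupported pipeline types stay unmonitored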
187 changes: 167 additions & 20 deletions src/argilla/monitoring/base.py
@@ -12,20 +12,144 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import asyncio
+import atexit
+import logging
 import random
 import threading
-from typing import Any, Dict, Optional
+from queue import Empty, Queue
+from typing import Any, Dict, Iterable, List, Optional
 
+import backoff
+import monotonic
 import wrapt
 
-import argilla
+from argilla.client.api import Api
+from argilla.client.models import Record
+from argilla.client.sdk.commons.errors import ArApiResponseError
 
 
 class ModelNotSupportedError(Exception):
     pass
 
 
+class DatasetRecordsConsumer(threading.Thread):
+    """Consumes the records from the dataset queue."""
+
+    log = logging.getLogger("argilla.monitoring")
+
+    def __init__(
+        self,
+        name: str,
+        api: Api,
+        tags: Optional[dict] = None,
+        metadata: Optional[dict] = None,
+        buffer_size: int = 10000,
+        upload_size=256,
+        upload_interval=1.0,
+        retries=10,
+        timeout=15,
+        on_error=None,
+    ):
+        """Create a consumer thread."""
+        threading.Thread.__init__(self)
+        self.daemon = True
+        self.upload_size = upload_size
+        self.upload_interval = upload_interval
+        self.api = api
+        self.on_error = on_error
+        self.queue = Queue(maxsize=buffer_size)
+        self.dataset = name
+        self.tags = tags
+        self.metadata = metadata
+
+        self.running = True
+        self.retries = retries
+        self.timeout = timeout
+
+    def run(self):
+        """Runs the consumer."""
+        while self.running:
+            self.log_next_batch()
+
+    def pause(self):
+        """Pause the consumer."""
+        self.running = False
+
+    def log_next_batch(self):
+        """Upload the next batch of items, return whether successful."""
+        success = False
+        batch = self._next_batch()
+        if len(batch) == 0:
+            return False
+
+        try:
+            self._log_records(batch)
+            success = True
+        except Exception as e:
+            self.log.error("error logging data: %s", e)
+            success = False
+            if self.on_error:
+                self.on_error(e, batch)
+        finally:
+            # mark items as acknowledged from queue
+            for _ in batch:
+                self.queue.task_done()
+            return success
+
+    def _next_batch(self) -> List[Record]:
+        queue = self.queue
+        records = []
+
+        start_time = monotonic.monotonic()
+        while len(records) < self.upload_size:
+            elapsed = monotonic.monotonic() - start_time
+            if elapsed >= self.upload_interval:
+                break
+            try:
+                item = queue.get(
+                    block=True,
+                    timeout=self.upload_interval - elapsed,
+                )
+                records.append(item)
+            except Empty:
+                break
+
+        return records
+
+    def _log_records(self, batch: List[Record]):
+        def fatal_exception(exc):
+            if isinstance(exc, ArApiResponseError):
+                return (400 <= exc.HTTP_STATUS < 500) and exc.HTTP_STATUS != 429
+            else:
+                return False
+
+        @backoff.on_exception(
+            backoff.expo,
+            Exception,
+            max_tries=self.retries + 1,
+            giveup=fatal_exception,
+        )
+        def _inner_log_records():
+            self.api.log(
+                name=self.dataset,
+                records=batch,
+                tags=self.tags,
+                metadata=self.metadata,
+                background=True,
+                verbose=False,
+            )
+
+        _inner_log_records()
+
+    def send(self, records: Iterable[Record]):
+        """Send records to the consumer"""
+        for record in records:
+            self.queue.put(
+                item=record,
+                block=False,
+            )
+
+
 class BaseMonitor(wrapt.ObjectProxy):
     """
     A base monitor class for easy task model monitoring
@@ -41,11 +165,13 @@ class BaseMonitor(wrapt.ObjectProxy):
     def __init__(
         self,
         *args,
+        api: Api,
         dataset: str,
-        sample_rate: float,
+        sample_rate: float = 1.0,
+        log_interval: float = 1.0,
         agent: Optional[str] = None,
         tags: Dict[str, str] = None,
-        **kwargs
+        **kwargs,
    ):
         super().__init__(*args, **kwargs)
 
@@ -58,6 +184,11 @@ def __init__(
         self.sample_rate = sample_rate
         self.agent = agent
         self.tags = tags
+        self._api = api
+        self._log_interval = log_interval
+        self._consumers: Dict[str, DatasetRecordsConsumer] = {}
+
+        atexit.register(self.shutdown)
 
     @property
     def __model__(self):
@@ -71,18 +202,34 @@ def is_record_accepted(self) -> bool:
     def _prepare_log_data(self, *args, **kwargs) -> Dict[str, Any]:
         raise NotImplementedError()
 
-    def log_async(self, *args, **kwargs):
-        log_args = self._prepare_log_data(*args, **kwargs)
-        log_args.pop("verbose", None)
-        log_args.pop("background", None)
-        return argilla.log(**log_args, verbose=False, background=True)
-
-    def _start_event_loop_if_needed(self):
-        """Recreate loop/thread if needed"""
-        if self._event_loop is None:
-            self._event_loop = asyncio.new_event_loop()
-        if self._event_loop_thread is None or not self._event_loop_thread.is_alive():
-            self._thread = threading.Thread(
-                target=self._event_loop.run_forever, daemon=True
-            )
-            self._thread.start()
+    def shutdown(self):
+        """Stop consumers"""
+        for consumer in self._consumers.values():
+            try:
+                consumer.pause()
+                consumer.join()
+            except RuntimeError:
+                pass
+
+    def send_records(self, *args, **kwargs):
+        data = self._prepare_log_data(*args, **kwargs)
+
+        consumer = self._get_consumer_by_dataset(dataset=data["name"])
+        consumer.tags = data.get("tags", {})
+        consumer.metadata = data.get("metadata", {})
+        consumer.send(data["records"])
+
+    def _get_consumer_by_dataset(self, dataset: str):
+        if dataset not in self._consumers:
+            self._consumers[dataset] = self._create_consumer(dataset)
+        return self._consumers[dataset]
+
+    def _create_consumer(self, name: str):
+        consumer = DatasetRecordsConsumer(
+            name=name,
+            api=self._api,
+            upload_interval=self._log_interval,
+        )
+        consumer.start()
+        return consumer
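The heart of the consumer is `_next_batch`: drain the queue until either `upload_size` records are collected or `upload_interval` seconds elapse, whichever comes first, blocking only for the time remaining in the interval. A self-contained sketch of that pattern using the standard library alone (all names here are illustrative, not from the diff):

import time
from queue import Empty, Queue
from typing import List

def next_batch(queue: Queue, max_items: int = 256, max_wait: float = 1.0) -> List:
    """Collect up to max_items from queue, waiting at most max_wait seconds
    in total; may return a partial or empty batch."""
    batch = []
    start = time.monotonic()
    while len(batch) < max_items:
        elapsed = time.monotonic() - start
        if elapsed >= max_wait:
            break
        try:
            # block only for the remainder of the interval
            batch.append(queue.get(block=True, timeout=max_wait - elapsed))
        except Empty:
            break
    return batch

q = Queue()
for i in range(5):
    q.put(i)
print(next_batch(q, max_items=3, max_wait=0.1))  # -> [0, 1, 2]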
