Wait for all futures (langchain-ai#6554)
- Expose method to wait for all futures
- Wait for submissions in the run_on_dataset functions to ensure runs are fully submitted before cleaning up
vowelparrot authored and aerrober committed Jul 24, 2023
1 parent 02d5ad6 commit 9ba7454
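
For callers, the new hook means tracing can be flushed explicitly before a process exits. A minimal usage sketch (the chain invocation is elided and the construction details are illustrative; only `wait_for_all_tracers` comes from this commit):

    from langchain.callbacks.tracers.langchain import (
        LangChainTracer,
        wait_for_all_tracers,
    )

    tracer = LangChainTracer()  # registers itself in the module-level _TRACERS list
    # ... invoke chains/LLMs with this tracer attached as a callback ...

    # Block until every run queued on any tracer's executor has been
    # POSTed to the endpoint, so no runs are dropped at shutdown.
    wait_for_all_tracers()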
Showing 2 changed files with 62 additions and 15 deletions.
63 changes: 50 additions & 13 deletions langchain/callbacks/tracers/langchain.py
@@ -3,9 +3,9 @@

 import logging
 import os
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import Future, ThreadPoolExecutor, wait
 from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Set, Union
 from uuid import UUID

 from langchainplus_sdk import LangChainPlusClient
@@ -21,6 +21,7 @@

 logger = logging.getLogger(__name__)
 _LOGGED = set()
+_TRACERS: List[LangChainTracer] = []


 def log_error_once(method: str, exception: Exception) -> None:
@@ -32,6 +33,12 @@ def log_error_once(method: str, exception: Exception) -> None:
     logger.error(exception)


+def wait_for_all_tracers() -> None:
+    global _TRACERS
+    for tracer in _TRACERS:
+        tracer.wait_for_futures()
+
+
 class LangChainTracer(BaseTracer):
     """An implementation of the SharedTracer that POSTS to the langchain endpoint."""

@@ -52,6 +59,9 @@ def __init__(
         # set max_workers to 1 to process tasks in order
         self.executor = ThreadPoolExecutor(max_workers=1)
         self.client = client or LangChainPlusClient()
+        self._futures: Set[Future] = set()
+        global _TRACERS
+        _TRACERS.append(self)

     def on_chat_model_start(
         self,
@@ -93,7 +103,7 @@ def _persist_run_single(self, run: Run) -> None:
         extra["runtime"] = get_runtime_environment()
         run_dict["extra"] = extra
         try:
-            run = self.client.create_run(**run_dict, session_name=self.session_name)
+            self.client.create_run(**run_dict, session_name=self.session_name)
         except Exception as e:
             # Errors are swallowed by the thread executor so we need to log them here
             log_error_once("post", e)
@@ -110,40 +120,67 @@ def _update_run_single(self, run: Run) -> None:

     def _on_llm_start(self, run: Run) -> None:
         """Persist an LLM run."""
-        self.executor.submit(self._persist_run_single, run.copy(deep=True))
+        self._futures.add(
+            self.executor.submit(self._persist_run_single, run.copy(deep=True))
+        )

     def _on_chat_model_start(self, run: Run) -> None:
         """Persist an LLM run."""
-        self.executor.submit(self._persist_run_single, run.copy(deep=True))
+        self._futures.add(
+            self.executor.submit(self._persist_run_single, run.copy(deep=True))
+        )

     def _on_llm_end(self, run: Run) -> None:
         """Process the LLM Run."""
-        self.executor.submit(self._update_run_single, run.copy(deep=True))
+        self._futures.add(
+            self.executor.submit(self._update_run_single, run.copy(deep=True))
+        )

     def _on_llm_error(self, run: Run) -> None:
         """Process the LLM Run upon error."""
-        self.executor.submit(self._update_run_single, run.copy(deep=True))
+        self._futures.add(
+            self.executor.submit(self._update_run_single, run.copy(deep=True))
+        )

     def _on_chain_start(self, run: Run) -> None:
         """Process the Chain Run upon start."""
-        self.executor.submit(self._persist_run_single, run.copy(deep=True))
+        self._futures.add(
+            self.executor.submit(self._persist_run_single, run.copy(deep=True))
+        )

     def _on_chain_end(self, run: Run) -> None:
         """Process the Chain Run."""
-        self.executor.submit(self._update_run_single, run.copy(deep=True))
+        self._futures.add(
+            self.executor.submit(self._update_run_single, run.copy(deep=True))
+        )

     def _on_chain_error(self, run: Run) -> None:
         """Process the Chain Run upon error."""
-        self.executor.submit(self._update_run_single, run.copy(deep=True))
+        self._futures.add(
+            self.executor.submit(self._update_run_single, run.copy(deep=True))
+        )

     def _on_tool_start(self, run: Run) -> None:
         """Process the Tool Run upon start."""
-        self.executor.submit(self._persist_run_single, run.copy(deep=True))
+        self._futures.add(
+            self.executor.submit(self._persist_run_single, run.copy(deep=True))
+        )

     def _on_tool_end(self, run: Run) -> None:
         """Process the Tool Run."""
-        self.executor.submit(self._update_run_single, run.copy(deep=True))
+        self._futures.add(
+            self.executor.submit(self._update_run_single, run.copy(deep=True))
+        )

     def _on_tool_error(self, run: Run) -> None:
         """Process the Tool Run upon error."""
-        self.executor.submit(self._update_run_single, run.copy(deep=True))
+        self._futures.add(
+            self.executor.submit(self._update_run_single, run.copy(deep=True))
+        )
+
+    def wait_for_futures(self) -> None:
+        """Wait for the given futures to complete."""
+        futures = list(self._futures)
+        wait(futures)
+        for future in futures:
+            self._futures.remove(future)
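
Aside from the tracer specifics, the new plumbing is a small, reusable pattern: submit work to a single-worker executor, remember each Future, and drain the set on demand with concurrent.futures.wait. A standalone sketch of that pattern (names here are illustrative, not from the commit):

    from concurrent.futures import Future, ThreadPoolExecutor, wait
    from typing import Any, Callable, Set

    executor = ThreadPoolExecutor(max_workers=1)  # one worker preserves submission order
    futures: Set[Future] = set()

    def submit(task: Callable[..., Any], *args: Any) -> None:
        futures.add(executor.submit(task, *args))

    def wait_for_futures() -> None:
        pending = list(futures)
        wait(pending)  # blocks until every snapshotted future completes
        for f in pending:
            futures.discard(f)  # drop only what we waited on; later submissions survive

Snapshotting into a list before waiting matters: futures added while wait() blocks stay in the set for the next call, which is also how the method above avoids mutating the set it iterates.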
14 changes: 12 additions & 2 deletions langchain/client/runner_utils.py
@@ -224,9 +224,17 @@ async def run_coroutine_with_semaphore(
             tracer_queue.put_nowait(tracer)
         return result

-    return await asyncio.gather(
+    results = await asyncio.gather(
         *(run_coroutine_with_semaphore(function) for function in async_funcs)
     )
+    while tracer_queue:
+        try:
+            tracer = tracer_queue.get_nowait()
+        except asyncio.QueueEmpty:
+            break
+        if tracer:
+            tracer.wait_for_futures()
+    return results


 async def _tracer_initializer(session_name: Optional[str]) -> Optional[LangChainTracer]:
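
One subtlety in the drain loop: a plain asyncio.Queue is always truthy (it defines no __bool__), so `while tracer_queue:` only terminates via the QueueEmpty break. A self-contained sketch of the same drain idiom:

    import asyncio

    def drain_nowait(queue: asyncio.Queue) -> list:
        """Empty a queue without awaiting; intended for after the producers are done."""
        items = []
        while True:
            try:
                items.append(queue.get_nowait())
            except asyncio.QueueEmpty:
                break
        return items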
@@ -411,7 +419,9 @@ def run_on_examples(
         )
         if verbose:
             print(f"{i+1} processed", flush=True, end="\r")
-        results[str(example.id)] = result
+        results[str(example.id)] = result
+    if tracer:
+        tracer.wait_for_futures()
     return results
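
Calling wait_for_futures() after the loop but before `return results` is what the commit message means by waiting "before cleaning up": the tracer and its executor are still alive at that point, so every queued run is fully submitted before the function hands back its results.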


