Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
[project]
name = "uipath-langchain"
version = "0.0.144"
version = "0.0.145"
description = "UiPath Langchain"
readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.10"
dependencies = [
"uipath>=2.1.103, <2.2.0",
"uipath>=2.1.110, <2.2.0",
"langgraph>=0.5.0, <0.7.0",
"langchain-core>=0.3.34",
"langgraph-checkpoint-sqlite>=2.0.3",
Expand Down Expand Up @@ -111,4 +111,4 @@ asyncio_mode = "auto"
name = "testpypi"
url = "https://test.pypi.org/simple/"
publish-url = "https://test.pypi.org/legacy/"
explicit = true
explicit = true
19 changes: 19 additions & 0 deletions src/uipath_langchain/_cli/_runtime/_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import os
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, AsyncIterator, Optional, Sequence
from uuid import uuid4

from langchain_core.runnables.config import RunnableConfig
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langgraph.errors import EmptyInputError, GraphRecursionError, InvalidUpdateError
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.types import Interrupt, StateSnapshot
from typing_extensions import override
from uipath._cli._runtime._contracts import (
UiPathBaseRuntime,
UiPathBreakpointResult,
Expand All @@ -17,12 +19,14 @@
UiPathRuntimeResult,
UiPathRuntimeStatus,
)
from uipath._cli.models.runtime_schema import Entrypoint
from uipath._events._events import (
UiPathAgentMessageEvent,
UiPathAgentStateEvent,
UiPathRuntimeEvent,
)

from .._utils._schema import generate_schema_from_graph
from ._context import LangGraphRuntimeContext
from ._exception import LangGraphErrorCode, LangGraphRuntimeError
from ._graph_resolver import AsyncResolver, LangGraphJsonResolver
Expand Down Expand Up @@ -481,6 +485,21 @@ def __init__(
self.resolver = LangGraphJsonResolver(entrypoint=entrypoint)
super().__init__(context, self.resolver)

@override
async def get_entrypoint(self) -> Entrypoint:
"""Get entrypoint for this LangGraph runtime."""
graph = await self.resolver()
compiled_graph = graph.compile()
schema = generate_schema_from_graph(compiled_graph)

return Entrypoint(
file_path=self.context.entrypoint, # type: ignore[call-arg]
unique_id=str(uuid4()),
type="agent",
input=schema["input"],
output=schema["output"],
)

async def cleanup(self) -> None:
"""Cleanup runtime resources including resolver."""
await super().cleanup()
Expand Down
85 changes: 85 additions & 0 deletions src/uipath_langchain/_cli/_utils/_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import Any, Dict

from langgraph.graph.state import CompiledStateGraph


def resolve_refs(schema, root=None):
"""Recursively resolves $ref references in a JSON schema."""
if root is None:
root = schema # Store the root schema to resolve $refs

if isinstance(schema, dict):
if "$ref" in schema:
ref_path = schema["$ref"].lstrip("#/").split("/")
ref_schema = root
for part in ref_path:
ref_schema = ref_schema.get(part, {})
return resolve_refs(ref_schema, root)

return {k: resolve_refs(v, root) for k, v in schema.items()}

elif isinstance(schema, list):
return [resolve_refs(item, root) for item in schema]

return schema


def process_nullable_types(
schema: Dict[str, Any] | list[Any] | Any,
) -> Dict[str, Any] | list[Any]:
"""Process the schema to handle nullable types by removing anyOf with null and keeping the base type."""
if isinstance(schema, dict):
if "anyOf" in schema and len(schema["anyOf"]) == 2:
types = [t.get("type") for t in schema["anyOf"]]
if "null" in types:
non_null_type = next(
t for t in schema["anyOf"] if t.get("type") != "null"
)
return non_null_type

return {k: process_nullable_types(v) for k, v in schema.items()}
elif isinstance(schema, list):
return [process_nullable_types(item) for item in schema]
return schema


def generate_schema_from_graph(
graph: CompiledStateGraph[Any, Any, Any],
) -> Dict[str, Any]:
"""Extract input/output schema from a LangGraph graph"""
schema = {
"input": {"type": "object", "properties": {}, "required": []},
"output": {"type": "object", "properties": {}, "required": []},
}

if hasattr(graph, "input_schema"):
if hasattr(graph.input_schema, "model_json_schema"):
input_schema = graph.input_schema.model_json_schema()
unpacked_ref_def_properties = resolve_refs(input_schema)

# Process the schema to handle nullable types
processed_properties = process_nullable_types(
unpacked_ref_def_properties.get("properties", {})
)

schema["input"]["properties"] = processed_properties
schema["input"]["required"] = unpacked_ref_def_properties.get(
"required", []
)

if hasattr(graph, "output_schema"):
if hasattr(graph.output_schema, "model_json_schema"):
output_schema = graph.output_schema.model_json_schema()
unpacked_ref_def_properties = resolve_refs(output_schema)

# Process the schema to handle nullable types
processed_properties = process_nullable_types(
unpacked_ref_def_properties.get("properties", {})
)

schema["output"]["properties"] = processed_properties
schema["output"]["required"] = unpacked_ref_def_properties.get(
"required", []
)

return schema
12 changes: 3 additions & 9 deletions src/uipath_langchain/_cli/cli_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
get_current_span,
)
from uipath._cli._evals._console_progress_reporter import ConsoleProgressReporter
from uipath._cli._evals._evaluate import evaluate
from uipath._cli._evals._progress_reporter import StudioWebProgressReporter
from uipath._cli._evals._runtime import UiPathEvalContext, UiPathEvalRuntime
from uipath._cli._evals._runtime import UiPathEvalContext
from uipath._cli._runtime._contracts import (
UiPathRuntimeFactory,
)
Expand Down Expand Up @@ -82,14 +83,7 @@ def generate_runtime(ctx: LangGraphRuntimeContext) -> LangGraphScriptRuntime:

runtime_factory.add_instrumentor(LangChainInstrumentor, get_current_span)

async def execute():
async with UiPathEvalRuntime.from_eval_context(
factory=runtime_factory, context=eval_context, event_bus=event_bus
) as eval_runtime:
await eval_runtime.execute()
await event_bus.wait_for_all()

asyncio.run(execute())
asyncio.run(evaluate(runtime_factory, eval_context, event_bus))
return MiddlewareResult(should_continue=False)

except Exception as e:
Expand Down
86 changes: 3 additions & 83 deletions src/uipath_langchain/_cli/cli_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import uuid
from collections.abc import Generator
from enum import Enum
from typing import Any, Callable, Dict, overload
from typing import Any, Callable, overload

import click
from langgraph.graph.state import CompiledStateGraph
from uipath._cli._utils._console import ConsoleLogger
from uipath._cli._utils._parse_ast import generate_bindings_json # type: ignore
from uipath._cli.middlewares import MiddlewareResult

from uipath_langchain._cli._utils._schema import generate_schema_from_graph

from ._utils._graph import LangGraphConfig

console = ConsoleLogger()
Expand All @@ -27,88 +29,6 @@ class FileOperationStatus(str, Enum):
SKIPPED = "skipped"


def resolve_refs(schema, root=None):
"""Recursively resolves $ref references in a JSON schema."""
if root is None:
root = schema # Store the root schema to resolve $refs

if isinstance(schema, dict):
if "$ref" in schema:
ref_path = schema["$ref"].lstrip("#/").split("/")
ref_schema = root
for part in ref_path:
ref_schema = ref_schema.get(part, {})
return resolve_refs(ref_schema, root)

return {k: resolve_refs(v, root) for k, v in schema.items()}

elif isinstance(schema, list):
return [resolve_refs(item, root) for item in schema]

return schema


def process_nullable_types(
schema: Dict[str, Any] | list[Any] | Any,
) -> Dict[str, Any] | list[Any]:
"""Process the schema to handle nullable types by removing anyOf with null and keeping the base type."""
if isinstance(schema, dict):
if "anyOf" in schema and len(schema["anyOf"]) == 2:
types = [t.get("type") for t in schema["anyOf"]]
if "null" in types:
non_null_type = next(
t for t in schema["anyOf"] if t.get("type") != "null"
)
return non_null_type

return {k: process_nullable_types(v) for k, v in schema.items()}
elif isinstance(schema, list):
return [process_nullable_types(item) for item in schema]
return schema


def generate_schema_from_graph(
graph: CompiledStateGraph[Any, Any, Any],
) -> Dict[str, Any]:
"""Extract input/output schema from a LangGraph graph"""
schema = {
"input": {"type": "object", "properties": {}, "required": []},
"output": {"type": "object", "properties": {}, "required": []},
}

if hasattr(graph, "input_schema"):
if hasattr(graph.input_schema, "model_json_schema"):
input_schema = graph.input_schema.model_json_schema()
unpacked_ref_def_properties = resolve_refs(input_schema)

# Process the schema to handle nullable types
processed_properties = process_nullable_types(
unpacked_ref_def_properties.get("properties", {})
)

schema["input"]["properties"] = processed_properties
schema["input"]["required"] = unpacked_ref_def_properties.get(
"required", []
)

if hasattr(graph, "output_schema"):
if hasattr(graph.output_schema, "model_json_schema"):
output_schema = graph.output_schema.model_json_schema()
unpacked_ref_def_properties = resolve_refs(output_schema)

# Process the schema to handle nullable types
processed_properties = process_nullable_types(
unpacked_ref_def_properties.get("properties", {})
)

schema["output"]["properties"] = processed_properties
schema["output"]["required"] = unpacked_ref_def_properties.get(
"required", []
)

return schema


def generate_agent_md_file(
target_directory: str,
file_name: str,
Expand Down
21 changes: 21 additions & 0 deletions src/uipath_langchain/runtime_factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Runtime factory for LangGraph projects."""

from uipath._cli._runtime._contracts import UiPathRuntimeFactory

from ._cli._runtime._context import LangGraphRuntimeContext
from ._cli._runtime._runtime import LangGraphScriptRuntime


class LangGraphRuntimeFactory(
UiPathRuntimeFactory[LangGraphScriptRuntime, LangGraphRuntimeContext]
):
"""Factory for LangGraph runtimes."""

def __init__(self):
super().__init__(
LangGraphScriptRuntime,
LangGraphRuntimeContext,
context_generator=lambda **kwargs: LangGraphRuntimeContext.with_defaults(
**kwargs
),
)
8 changes: 4 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.